package service

import (
	"context"
	"errors"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"

	"pay-bridge/internal/model"
)

// Compile-time check that the mock satisfies the interface MerchantService consumes.
var _ merchantRepo = (*mockMerchantRepo)(nil)

// mockMerchantRepo implements the merchantRepo interface on top of testify's mock.Mock.
//
// Every pointer/slice result is extracted with a comma-ok type assertion, so a test
// may stub a "not found" result with either a typed nil (e.g. (*model.Merchant)(nil))
// or a plain untyped nil without the mock panicking.
type mockMerchantRepo struct {
	mock.Mock
}

func (m *mockMerchantRepo) Create(ctx context.Context, merchant *model.Merchant) error {
	return m.Called(ctx, merchant).Error(0)
}

func (m *mockMerchantRepo) GetByMerchantID(ctx context.Context, merchantID string) (*model.Merchant, error) {
	args := m.Called(ctx, merchantID)
	v, _ := args.Get(0).(*model.Merchant)
	return v, args.Error(1)
}

func (m *mockMerchantRepo) GetByMerchantIDAndAppID(ctx context.Context, merchantID, appID string) (*model.Merchant, error) {
	args := m.Called(ctx, merchantID, appID)
	v, _ := args.Get(0).(*model.Merchant)
	return v, args.Error(1)
}

func (m *mockMerchantRepo) UpdateStatus(ctx context.Context, merchantID string, status model.MerchantStatus, updates map[string]any) error {
	return m.Called(ctx, merchantID, status, updates).Error(0)
}

func (m *mockMerchantRepo) List(ctx context.Context, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
	args := m.Called(ctx, status, limit, offset)
	v, _ := args.Get(0).([]*model.Merchant)
	return v, args.Error(1)
}

func (m *mockMerchantRepo) ListByAppID(ctx context.Context, appID string, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
	args := m.Called(ctx, appID, status, limit, offset)
	v, _ := args.Get(0).([]*model.Merchant)
	return v, args.Error(1)
}

func (m *mockMerchantRepo) ListAnomalous(ctx context.Context) ([]*model.Merchant, error) {
	args := m.Called(ctx)
	v, _ := args.Get(0).([]*model.Merchant)
	return v, args.Error(1)
}

func (m *mockMerchantRepo) CreateApplication(ctx context.Context, app *model.MerchantApplication) error {
	return m.Called(ctx, app).Error(0)
}

func (m *mockMerchantRepo) GetLatestApplication(ctx context.Context, merchantID string) (*model.MerchantApplication, error) {
	args := m.Called(ctx, merchantID)
	v, _ := args.Get(0).(*model.MerchantApplication)
	return v, args.Error(1)
}

func (m *mockMerchantRepo) GetApprovedApplicationByChannel(ctx context.Context, merchantID, channelCode string) (*model.MerchantApplication, error) {
	args := m.Called(ctx, merchantID, channelCode)
	v, _ := args.Get(0).(*model.MerchantApplication)
	return v, args.Error(1)
}

func (m *mockMerchantRepo) UpdateApplication(ctx context.Context, applicationID string, updates map[string]any) error {
	return m.Called(ctx, applicationID, updates).Error(0)
}

// newTestMerchantService builds a MerchantService wired to the given repo.
// channelSvc is left nil, so only methods that do not touch channels can be tested with it.
func newTestMerchantService(repo merchantRepo) *MerchantService {
	return &MerchantService{merchantRepo: repo}
}

// ctx is the shared background context passed to the service and matched by the mocks.
var ctx = context.Background()

// --- GetMerchantForApp ---

func TestGetMerchantForApp_OK(t *testing.T) {
	repo := new(mockMerchantRepo)
	want := &model.Merchant{MerchantID: "m001", AppID: "app1"}
	repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return(want, nil)

	svc := newTestMerchantService(repo)
	got, err := svc.GetMerchantForApp(ctx, "app1", "m001")

	assert.NoError(t, err)
	assert.Equal(t, want, got)
	repo.AssertExpectations(t)
}

func TestGetMerchantForApp_NotFound(t *testing.T) {
	repo := new(mockMerchantRepo)
	repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return((*model.Merchant)(nil), nil)

	svc := newTestMerchantService(repo)
	_, err := svc.GetMerchantForApp(ctx, "app1", "m001")

	assert.EqualError(t, err, "30001")
	repo.AssertExpectations(t)
}

func TestGetMerchantForApp_WrongAppID(t *testing.T) {
	repo := new(mockMerchantRepo)
	// The merchant exists but belongs to another app, so the scoped lookup
	// GetByMerchantIDAndAppID returns nil for "evil_app".
	repo.On("GetByMerchantIDAndAppID", ctx, "m001", "evil_app").Return((*model.Merchant)(nil), nil)

	svc := newTestMerchantService(repo)
	_, err := svc.GetMerchantForApp(ctx, "evil_app", "m001")

	assert.EqualError(t, err, "30001", "跨 appID 访问应返回 not found,而不是泄露商户信息")
	repo.AssertExpectations(t)
}

func TestGetMerchantForApp_DBError(t *testing.T) {
	repo := new(mockMerchantRepo)
	repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return((*model.Merchant)(nil), errors.New("db error"))

	svc := newTestMerchantService(repo)
	_, err := svc.GetMerchantForApp(ctx, "app1", "m001")

	assert.EqualError(t, err, "db error")
	repo.AssertExpectations(t)
}

// --- CreateMerchantForApp ---

func TestCreateMerchantForApp_SetsAppID(t *testing.T) {
	repo := new(mockMerchantRepo)
	repo.On("Create", ctx, mock.MatchedBy(func(m *model.Merchant) bool {
		return m.AppID == "app1" && m.MerchantID == "m001"
	})).Return(nil)

	svc := newTestMerchantService(repo)
	m := &model.Merchant{MerchantID: "m001"}
	err := svc.CreateMerchantForApp(ctx, "app1", m)

	assert.NoError(t, err)
	assert.Equal(t, "app1", m.AppID, "AppID 应被强制写入")
	repo.AssertExpectations(t)
}

// --- ListMerchantsForApp ---

func TestListMerchantsForApp_OnlyReturnsOwnApp(t *testing.T) {
	repo := new(mockMerchantRepo)
	want := []*model.Merchant{{MerchantID: "m001", AppID: "app1"}}
	repo.On("ListByAppID", ctx, "app1", model.MerchantStatus(""), 20, 0).Return(want, nil)

	svc := newTestMerchantService(repo)
	got, err := svc.ListMerchantsForApp(ctx, "app1", "", 20, 0)

	assert.NoError(t, err)
	assert.Len(t, got, 1)
	repo.AssertExpectations(t)
}

// --- ApplyForApp ---

func TestApplyForApp_MerchantNotBelongToApp(t *testing.T) {
	repo := new(mockMerchantRepo)
	repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return((*model.Merchant)(nil), nil)

	svc := newTestMerchantService(repo)
	_, err := svc.ApplyForApp(ctx, "app1", "m001", "HEEPAY", nil)

	assert.EqualError(t, err, "30001", "不属于该 app 的商户不能提交进件")
	repo.AssertExpectations(t)
}

// --- GetChannelMerchantID ---

func TestGetChannelMerchantID_Approved(t *testing.T) {
	repo := new(mockMerchantRepo)
	app := &model.MerchantApplication{
		ChannelMerchantID: "ch_m_999",
	}
	repo.On("GetApprovedApplicationByChannel", ctx, "m001", "HEEPAY").Return(app, nil)

	svc := newTestMerchantService(repo)
	id, err := svc.GetChannelMerchantID(ctx, "m001", "HEEPAY")

	assert.NoError(t, err)
	assert.Equal(t, "ch_m_999", id)
	repo.AssertExpectations(t)
}

func TestGetChannelMerchantID_NotApproved(t *testing.T) {
	repo := new(mockMerchantRepo)
	repo.On("GetApprovedApplicationByChannel", ctx, "m001", "ALIPAY").Return((*model.MerchantApplication)(nil), nil)

	svc := newTestMerchantService(repo)
	id, err := svc.GetChannelMerchantID(ctx, "m001", "ALIPAY")

	assert.NoError(t, err)
	assert.Empty(t, id, "未在该渠道进件时返回空字符串")
	repo.AssertExpectations(t)
}

func TestGetChannelMerchantID_MultiChannel(t *testing.T) {
	repo := new(mockMerchantRepo)
	repo.On("GetApprovedApplicationByChannel", ctx, "m001", "HEEPAY").
		Return(&model.MerchantApplication{ChannelMerchantID: "hee_001"}, nil)
	repo.On("GetApprovedApplicationByChannel", ctx, "m001", "ALIPAY").
		Return(&model.MerchantApplication{ChannelMerchantID: "ali_001"}, nil)

	svc := newTestMerchantService(repo)

	// Check errors too: a failing lookup would otherwise surface only as a
	// confusing empty-ID mismatch below.
	heeID, err := svc.GetChannelMerchantID(ctx, "m001", "HEEPAY")
	assert.NoError(t, err)
	aliID, err := svc.GetChannelMerchantID(ctx, "m001", "ALIPAY")
	assert.NoError(t, err)

	assert.Equal(t, "hee_001", heeID, "不同渠道应返回各自的 channel_merchant_id")
	assert.Equal(t, "ali_001", aliID)
	repo.AssertExpectations(t)
}