This commit is contained in:
2026-03-13 15:51:59 +08:00
parent 4db2386bbf
commit 4e91f4cede
133 changed files with 19502 additions and 37 deletions

View File

@@ -0,0 +1,200 @@
package service
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"pay-bridge/internal/model"
)
// mockMerchantRepo 实现 merchantRepo interface
type mockMerchantRepo struct {
mock.Mock
}
func (m *mockMerchantRepo) Create(ctx context.Context, merchant *model.Merchant) error {
return m.Called(ctx, merchant).Error(0)
}
func (m *mockMerchantRepo) GetByMerchantID(ctx context.Context, merchantID string) (*model.Merchant, error) {
args := m.Called(ctx, merchantID)
return args.Get(0).(*model.Merchant), args.Error(1)
}
func (m *mockMerchantRepo) GetByMerchantIDAndAppID(ctx context.Context, merchantID, appID string) (*model.Merchant, error) {
args := m.Called(ctx, merchantID, appID)
v, _ := args.Get(0).(*model.Merchant)
return v, args.Error(1)
}
func (m *mockMerchantRepo) UpdateStatus(ctx context.Context, merchantID string, status model.MerchantStatus, updates map[string]any) error {
return m.Called(ctx, merchantID, status, updates).Error(0)
}
func (m *mockMerchantRepo) List(ctx context.Context, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
args := m.Called(ctx, status, limit, offset)
return args.Get(0).([]*model.Merchant), args.Error(1)
}
func (m *mockMerchantRepo) ListByAppID(ctx context.Context, appID string, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
args := m.Called(ctx, appID, status, limit, offset)
return args.Get(0).([]*model.Merchant), args.Error(1)
}
func (m *mockMerchantRepo) ListAnomalous(ctx context.Context) ([]*model.Merchant, error) {
args := m.Called(ctx)
return args.Get(0).([]*model.Merchant), args.Error(1)
}
func (m *mockMerchantRepo) CreateApplication(ctx context.Context, app *model.MerchantApplication) error {
return m.Called(ctx, app).Error(0)
}
func (m *mockMerchantRepo) GetLatestApplication(ctx context.Context, merchantID string) (*model.MerchantApplication, error) {
args := m.Called(ctx, merchantID)
v, _ := args.Get(0).(*model.MerchantApplication)
return v, args.Error(1)
}
func (m *mockMerchantRepo) GetApprovedApplicationByChannel(ctx context.Context, merchantID, channelCode string) (*model.MerchantApplication, error) {
args := m.Called(ctx, merchantID, channelCode)
v, _ := args.Get(0).(*model.MerchantApplication)
return v, args.Error(1)
}
func (m *mockMerchantRepo) UpdateApplication(ctx context.Context, applicationID string, updates map[string]any) error {
return m.Called(ctx, applicationID, updates).Error(0)
}
// newTestMerchantService 创建注入了 mock repo 的 servicechannelSvc 为 nil仅测不涉及渠道的方法
func newTestMerchantService(repo merchantRepo) *MerchantService {
return &MerchantService{merchantRepo: repo}
}
var ctx = context.Background()
// --- GetMerchantForApp ---
func TestGetMerchantForApp_OK(t *testing.T) {
repo := new(mockMerchantRepo)
want := &model.Merchant{MerchantID: "m001", AppID: "app1"}
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return(want, nil)
svc := newTestMerchantService(repo)
got, err := svc.GetMerchantForApp(ctx, "app1", "m001")
assert.NoError(t, err)
assert.Equal(t, want, got)
repo.AssertExpectations(t)
}
func TestGetMerchantForApp_NotFound(t *testing.T) {
repo := new(mockMerchantRepo)
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return((*model.Merchant)(nil), nil)
svc := newTestMerchantService(repo)
_, err := svc.GetMerchantForApp(ctx, "app1", "m001")
assert.EqualError(t, err, "30001")
}
func TestGetMerchantForApp_WrongAppID(t *testing.T) {
repo := new(mockMerchantRepo)
// 商户存在但属于 other_appGetByMerchantIDAndAppID 返回 nil
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "evil_app").Return((*model.Merchant)(nil), nil)
svc := newTestMerchantService(repo)
_, err := svc.GetMerchantForApp(ctx, "evil_app", "m001")
assert.EqualError(t, err, "30001", "跨 appID 访问应返回 not found而不是泄露商户信息")
}
func TestGetMerchantForApp_DBError(t *testing.T) {
repo := new(mockMerchantRepo)
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return((*model.Merchant)(nil), errors.New("db error"))
svc := newTestMerchantService(repo)
_, err := svc.GetMerchantForApp(ctx, "app1", "m001")
assert.EqualError(t, err, "db error")
}
// --- CreateMerchantForApp ---
func TestCreateMerchantForApp_SetsAppID(t *testing.T) {
repo := new(mockMerchantRepo)
repo.On("Create", ctx, mock.MatchedBy(func(m *model.Merchant) bool {
return m.AppID == "app1" && m.MerchantID == "m001"
})).Return(nil)
svc := newTestMerchantService(repo)
m := &model.Merchant{MerchantID: "m001"}
err := svc.CreateMerchantForApp(ctx, "app1", m)
assert.NoError(t, err)
assert.Equal(t, "app1", m.AppID, "AppID 应被强制写入")
repo.AssertExpectations(t)
}
// --- ListMerchantsForApp ---
func TestListMerchantsForApp_OnlyReturnsOwnApp(t *testing.T) {
repo := new(mockMerchantRepo)
want := []*model.Merchant{{MerchantID: "m001", AppID: "app1"}}
repo.On("ListByAppID", ctx, "app1", model.MerchantStatus(""), 20, 0).Return(want, nil)
svc := newTestMerchantService(repo)
got, err := svc.ListMerchantsForApp(ctx, "app1", "", 20, 0)
assert.NoError(t, err)
assert.Len(t, got, 1)
repo.AssertExpectations(t)
}
// --- ApplyForApp ---
func TestApplyForApp_MerchantNotBelongToApp(t *testing.T) {
repo := new(mockMerchantRepo)
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return((*model.Merchant)(nil), nil)
svc := newTestMerchantService(repo)
_, err := svc.ApplyForApp(ctx, "app1", "m001", "HEEPAY", nil)
assert.EqualError(t, err, "30001", "不属于该 app 的商户不能提交进件")
}
// --- GetChannelMerchantID ---
func TestGetChannelMerchantID_Approved(t *testing.T) {
repo := new(mockMerchantRepo)
app := &model.MerchantApplication{
ChannelMerchantID: "ch_m_999",
}
repo.On("GetApprovedApplicationByChannel", ctx, "m001", "HEEPAY").Return(app, nil)
svc := newTestMerchantService(repo)
id, err := svc.GetChannelMerchantID(ctx, "m001", "HEEPAY")
assert.NoError(t, err)
assert.Equal(t, "ch_m_999", id)
}
func TestGetChannelMerchantID_NotApproved(t *testing.T) {
repo := new(mockMerchantRepo)
repo.On("GetApprovedApplicationByChannel", ctx, "m001", "ALIPAY").Return((*model.MerchantApplication)(nil), nil)
svc := newTestMerchantService(repo)
id, err := svc.GetChannelMerchantID(ctx, "m001", "ALIPAY")
assert.NoError(t, err)
assert.Empty(t, id, "未在该渠道进件时返回空字符串")
}
func TestGetChannelMerchantID_MultiChannel(t *testing.T) {
repo := new(mockMerchantRepo)
repo.On("GetApprovedApplicationByChannel", ctx, "m001", "HEEPAY").
Return(&model.MerchantApplication{ChannelMerchantID: "hee_001"}, nil)
repo.On("GetApprovedApplicationByChannel", ctx, "m001", "ALIPAY").
Return(&model.MerchantApplication{ChannelMerchantID: "ali_001"}, nil)
svc := newTestMerchantService(repo)
heeID, _ := svc.GetChannelMerchantID(ctx, "m001", "HEEPAY")
aliID, _ := svc.GetChannelMerchantID(ctx, "m001", "ALIPAY")
assert.Equal(t, "hee_001", heeID, "不同渠道应返回各自的 channel_merchant_id")
assert.Equal(t, "ali_001", aliID)
}