draft
This commit is contained in:
200
backend/internal/service/merchant_test.go
Normal file
200
backend/internal/service/merchant_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
// mockMerchantRepo 实现 merchantRepo interface
|
||||
type mockMerchantRepo struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockMerchantRepo) Create(ctx context.Context, merchant *model.Merchant) error {
|
||||
return m.Called(ctx, merchant).Error(0)
|
||||
}
|
||||
func (m *mockMerchantRepo) GetByMerchantID(ctx context.Context, merchantID string) (*model.Merchant, error) {
|
||||
args := m.Called(ctx, merchantID)
|
||||
return args.Get(0).(*model.Merchant), args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantRepo) GetByMerchantIDAndAppID(ctx context.Context, merchantID, appID string) (*model.Merchant, error) {
|
||||
args := m.Called(ctx, merchantID, appID)
|
||||
v, _ := args.Get(0).(*model.Merchant)
|
||||
return v, args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantRepo) UpdateStatus(ctx context.Context, merchantID string, status model.MerchantStatus, updates map[string]any) error {
|
||||
return m.Called(ctx, merchantID, status, updates).Error(0)
|
||||
}
|
||||
func (m *mockMerchantRepo) List(ctx context.Context, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
|
||||
args := m.Called(ctx, status, limit, offset)
|
||||
return args.Get(0).([]*model.Merchant), args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantRepo) ListByAppID(ctx context.Context, appID string, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
|
||||
args := m.Called(ctx, appID, status, limit, offset)
|
||||
return args.Get(0).([]*model.Merchant), args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantRepo) ListAnomalous(ctx context.Context) ([]*model.Merchant, error) {
|
||||
args := m.Called(ctx)
|
||||
return args.Get(0).([]*model.Merchant), args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantRepo) CreateApplication(ctx context.Context, app *model.MerchantApplication) error {
|
||||
return m.Called(ctx, app).Error(0)
|
||||
}
|
||||
func (m *mockMerchantRepo) GetLatestApplication(ctx context.Context, merchantID string) (*model.MerchantApplication, error) {
|
||||
args := m.Called(ctx, merchantID)
|
||||
v, _ := args.Get(0).(*model.MerchantApplication)
|
||||
return v, args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantRepo) GetApprovedApplicationByChannel(ctx context.Context, merchantID, channelCode string) (*model.MerchantApplication, error) {
|
||||
args := m.Called(ctx, merchantID, channelCode)
|
||||
v, _ := args.Get(0).(*model.MerchantApplication)
|
||||
return v, args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantRepo) UpdateApplication(ctx context.Context, applicationID string, updates map[string]any) error {
|
||||
return m.Called(ctx, applicationID, updates).Error(0)
|
||||
}
|
||||
|
||||
// newTestMerchantService 创建注入了 mock repo 的 service(channelSvc 为 nil,仅测不涉及渠道的方法)
|
||||
func newTestMerchantService(repo merchantRepo) *MerchantService {
|
||||
return &MerchantService{merchantRepo: repo}
|
||||
}
|
||||
|
||||
var ctx = context.Background()
|
||||
|
||||
// --- GetMerchantForApp ---
|
||||
|
||||
func TestGetMerchantForApp_OK(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
want := &model.Merchant{MerchantID: "m001", AppID: "app1"}
|
||||
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return(want, nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
got, err := svc.GetMerchantForApp(ctx, "app1", "m001")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, want, got)
|
||||
repo.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestGetMerchantForApp_NotFound(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return((*model.Merchant)(nil), nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
_, err := svc.GetMerchantForApp(ctx, "app1", "m001")
|
||||
|
||||
assert.EqualError(t, err, "30001")
|
||||
}
|
||||
|
||||
func TestGetMerchantForApp_WrongAppID(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
// 商户存在但属于 other_app,GetByMerchantIDAndAppID 返回 nil
|
||||
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "evil_app").Return((*model.Merchant)(nil), nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
_, err := svc.GetMerchantForApp(ctx, "evil_app", "m001")
|
||||
|
||||
assert.EqualError(t, err, "30001", "跨 appID 访问应返回 not found,而不是泄露商户信息")
|
||||
}
|
||||
|
||||
func TestGetMerchantForApp_DBError(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return((*model.Merchant)(nil), errors.New("db error"))
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
_, err := svc.GetMerchantForApp(ctx, "app1", "m001")
|
||||
|
||||
assert.EqualError(t, err, "db error")
|
||||
}
|
||||
|
||||
// --- CreateMerchantForApp ---
|
||||
|
||||
func TestCreateMerchantForApp_SetsAppID(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
repo.On("Create", ctx, mock.MatchedBy(func(m *model.Merchant) bool {
|
||||
return m.AppID == "app1" && m.MerchantID == "m001"
|
||||
})).Return(nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
m := &model.Merchant{MerchantID: "m001"}
|
||||
err := svc.CreateMerchantForApp(ctx, "app1", m)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "app1", m.AppID, "AppID 应被强制写入")
|
||||
repo.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// --- ListMerchantsForApp ---
|
||||
|
||||
func TestListMerchantsForApp_OnlyReturnsOwnApp(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
want := []*model.Merchant{{MerchantID: "m001", AppID: "app1"}}
|
||||
repo.On("ListByAppID", ctx, "app1", model.MerchantStatus(""), 20, 0).Return(want, nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
got, err := svc.ListMerchantsForApp(ctx, "app1", "", 20, 0)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, got, 1)
|
||||
repo.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// --- ApplyForApp ---
|
||||
|
||||
func TestApplyForApp_MerchantNotBelongToApp(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return((*model.Merchant)(nil), nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
_, err := svc.ApplyForApp(ctx, "app1", "m001", "HEEPAY", nil)
|
||||
|
||||
assert.EqualError(t, err, "30001", "不属于该 app 的商户不能提交进件")
|
||||
}
|
||||
|
||||
// --- GetChannelMerchantID ---
|
||||
|
||||
func TestGetChannelMerchantID_Approved(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
app := &model.MerchantApplication{
|
||||
ChannelMerchantID: "ch_m_999",
|
||||
}
|
||||
repo.On("GetApprovedApplicationByChannel", ctx, "m001", "HEEPAY").Return(app, nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
id, err := svc.GetChannelMerchantID(ctx, "m001", "HEEPAY")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "ch_m_999", id)
|
||||
}
|
||||
|
||||
func TestGetChannelMerchantID_NotApproved(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
repo.On("GetApprovedApplicationByChannel", ctx, "m001", "ALIPAY").Return((*model.MerchantApplication)(nil), nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
id, err := svc.GetChannelMerchantID(ctx, "m001", "ALIPAY")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, id, "未在该渠道进件时返回空字符串")
|
||||
}
|
||||
|
||||
func TestGetChannelMerchantID_MultiChannel(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
repo.On("GetApprovedApplicationByChannel", ctx, "m001", "HEEPAY").
|
||||
Return(&model.MerchantApplication{ChannelMerchantID: "hee_001"}, nil)
|
||||
repo.On("GetApprovedApplicationByChannel", ctx, "m001", "ALIPAY").
|
||||
Return(&model.MerchantApplication{ChannelMerchantID: "ali_001"}, nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
|
||||
heeID, _ := svc.GetChannelMerchantID(ctx, "m001", "HEEPAY")
|
||||
aliID, _ := svc.GetChannelMerchantID(ctx, "m001", "ALIPAY")
|
||||
|
||||
assert.Equal(t, "hee_001", heeID, "不同渠道应返回各自的 channel_merchant_id")
|
||||
assert.Equal(t, "ali_001", aliID)
|
||||
}
|
||||
Reference in New Issue
Block a user