Files
pay-bridge/backend/internal/service/merchant_test.go
2026-03-13 15:51:59 +08:00

201 lines
7.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"pay-bridge/internal/model"
)
// mockMerchantRepo is a testify-based fake of the merchantRepo interface used
// by MerchantService. The embedded mock.Mock lets tests program results with
// On(...).Return(...) and verify them with AssertExpectations.
type mockMerchantRepo struct{ mock.Mock }
// Create records the call and returns the error configured for it.
func (m *mockMerchantRepo) Create(ctx context.Context, merchant *model.Merchant) error {
	call := m.Called(ctx, merchant)
	return call.Error(0)
}
// GetByMerchantID records the call and returns the configured merchant and error.
//
// The comma-ok assertion means a test can write .Return(nil, err) with an
// untyped nil and get a nil *model.Merchant back instead of a panic. This
// matches GetByMerchantIDAndAppID and the application getters.
func (m *mockMerchantRepo) GetByMerchantID(ctx context.Context, merchantID string) (*model.Merchant, error) {
	args := m.Called(ctx, merchantID)
	merchant, _ := args.Get(0).(*model.Merchant) // untyped nil in Return(...) -> nil pointer
	return merchant, args.Error(1)
}
// GetByMerchantIDAndAppID records the call and returns the configured merchant
// and error. If the first configured value is not a *model.Merchant (e.g. an
// untyped nil), a nil merchant is returned.
func (m *mockMerchantRepo) GetByMerchantIDAndAppID(ctx context.Context, merchantID, appID string) (*model.Merchant, error) {
	ret := m.Called(ctx, merchantID, appID)
	var found *model.Merchant
	if v, ok := ret.Get(0).(*model.Merchant); ok {
		found = v
	}
	return found, ret.Error(1)
}
// UpdateStatus records the call (including the updates map) and returns the
// error configured for it.
func (m *mockMerchantRepo) UpdateStatus(ctx context.Context, merchantID string, status model.MerchantStatus, updates map[string]any) error {
	call := m.Called(ctx, merchantID, status, updates)
	return call.Error(0)
}
// List records the call and returns the configured merchants and error.
//
// The comma-ok assertion means .Return(nil, err) with an untyped nil yields a
// nil slice instead of a panic.
func (m *mockMerchantRepo) List(ctx context.Context, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
	args := m.Called(ctx, status, limit, offset)
	merchants, _ := args.Get(0).([]*model.Merchant) // untyped nil in Return(...) -> nil slice
	return merchants, args.Error(1)
}
// ListByAppID records the call and returns the configured merchants and error.
//
// The comma-ok assertion means .Return(nil, err) with an untyped nil yields a
// nil slice instead of a panic.
func (m *mockMerchantRepo) ListByAppID(ctx context.Context, appID string, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
	args := m.Called(ctx, appID, status, limit, offset)
	merchants, _ := args.Get(0).([]*model.Merchant) // untyped nil in Return(...) -> nil slice
	return merchants, args.Error(1)
}
// ListAnomalous records the call and returns the configured merchants and error.
//
// The comma-ok assertion means .Return(nil, err) with an untyped nil yields a
// nil slice instead of a panic.
func (m *mockMerchantRepo) ListAnomalous(ctx context.Context) ([]*model.Merchant, error) {
	args := m.Called(ctx)
	merchants, _ := args.Get(0).([]*model.Merchant) // untyped nil in Return(...) -> nil slice
	return merchants, args.Error(1)
}
// CreateApplication records the call and returns the error configured for it.
func (m *mockMerchantRepo) CreateApplication(ctx context.Context, app *model.MerchantApplication) error {
	call := m.Called(ctx, app)
	return call.Error(0)
}
// GetLatestApplication records the call and returns the configured application
// and error. A first value that is not a *model.MerchantApplication (e.g. an
// untyped nil) yields a nil application.
func (m *mockMerchantRepo) GetLatestApplication(ctx context.Context, merchantID string) (*model.MerchantApplication, error) {
	ret := m.Called(ctx, merchantID)
	var latest *model.MerchantApplication
	if v, ok := ret.Get(0).(*model.MerchantApplication); ok {
		latest = v
	}
	return latest, ret.Error(1)
}
// GetApprovedApplicationByChannel records the call and returns the configured
// application and error. A first value that is not a
// *model.MerchantApplication (e.g. an untyped nil) yields a nil application.
func (m *mockMerchantRepo) GetApprovedApplicationByChannel(ctx context.Context, merchantID, channelCode string) (*model.MerchantApplication, error) {
	ret := m.Called(ctx, merchantID, channelCode)
	switch app := ret.Get(0).(type) {
	case *model.MerchantApplication:
		return app, ret.Error(1)
	default:
		return nil, ret.Error(1)
	}
}
// UpdateApplication records the call (including the updates map) and returns
// the error configured for it.
func (m *mockMerchantRepo) UpdateApplication(ctx context.Context, applicationID string, updates map[string]any) error {
	call := m.Called(ctx, applicationID, updates)
	return call.Error(0)
}
// newTestMerchantService builds a MerchantService whose only dependency is the
// given repo. channelSvc is left nil, so use it only to test methods that do
// not touch a channel.
func newTestMerchantService(repo merchantRepo) *MerchantService {
	svc := new(MerchantService)
	svc.merchantRepo = repo
	return svc
}
// ctx is the context shared by every test in this file. Tests pass this same
// value both to repo.On(...) expectations and to the service calls, so the
// mock's argument matching succeeds.
var ctx context.Context = context.Background()
// --- GetMerchantForApp ---
func TestGetMerchantForApp_OK(t *testing.T) {
repo := new(mockMerchantRepo)
want := &model.Merchant{MerchantID: "m001", AppID: "app1"}
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return(want, nil)
svc := newTestMerchantService(repo)
got, err := svc.GetMerchantForApp(ctx, "app1", "m001")
assert.NoError(t, err)
assert.Equal(t, want, got)
repo.AssertExpectations(t)
}
// TestGetMerchantForApp_NotFound checks that when the repo returns a nil
// merchant with no error, the service fails with error "30001". It also
// asserts the repo lookup actually happened.
func TestGetMerchantForApp_NotFound(t *testing.T) {
	repo := new(mockMerchantRepo)
	repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return((*model.Merchant)(nil), nil)
	svc := newTestMerchantService(repo)
	_, err := svc.GetMerchantForApp(ctx, "app1", "m001")
	assert.EqualError(t, err, "30001")
	repo.AssertExpectations(t)
}
// TestGetMerchantForApp_WrongAppID checks that looking up a merchant through
// an app it does not belong to fails with error "30001". The test asserts a
// not-found error rather than one that would reveal the merchant exists.
func TestGetMerchantForApp_WrongAppID(t *testing.T) {
	repo := new(mockMerchantRepo)
	// The merchant exists but belongs to other_app, so the app-scoped lookup
	// for evil_app returns nil.
	repo.On("GetByMerchantIDAndAppID", ctx, "m001", "evil_app").Return((*model.Merchant)(nil), nil)
	svc := newTestMerchantService(repo)
	_, err := svc.GetMerchantForApp(ctx, "evil_app", "m001")
	assert.EqualError(t, err, "30001", "跨 appID 访问应返回 not found而不是泄露商户信息")
	repo.AssertExpectations(t)
}
// TestGetMerchantForApp_DBError checks that a repository error is returned to
// the caller with its message ("db error") intact. It also asserts the repo
// was queried.
func TestGetMerchantForApp_DBError(t *testing.T) {
	repo := new(mockMerchantRepo)
	repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return((*model.Merchant)(nil), errors.New("db error"))
	svc := newTestMerchantService(repo)
	_, err := svc.GetMerchantForApp(ctx, "app1", "m001")
	assert.EqualError(t, err, "db error")
	repo.AssertExpectations(t)
}
// --- CreateMerchantForApp ---
func TestCreateMerchantForApp_SetsAppID(t *testing.T) {
repo := new(mockMerchantRepo)
repo.On("Create", ctx, mock.MatchedBy(func(m *model.Merchant) bool {
return m.AppID == "app1" && m.MerchantID == "m001"
})).Return(nil)
svc := newTestMerchantService(repo)
m := &model.Merchant{MerchantID: "m001"}
err := svc.CreateMerchantForApp(ctx, "app1", m)
assert.NoError(t, err)
assert.Equal(t, "app1", m.AppID, "AppID 应被强制写入")
repo.AssertExpectations(t)
}
// --- ListMerchantsForApp ---
// TestListMerchantsForApp_OnlyReturnsOwnApp checks that listing goes through
// the app-scoped ListByAppID query (empty status, limit 20, offset 0). It also
// checks that the service returns exactly what that query produced.
func TestListMerchantsForApp_OnlyReturnsOwnApp(t *testing.T) {
	repo := new(mockMerchantRepo)
	want := []*model.Merchant{{MerchantID: "m001", AppID: "app1"}}
	repo.On("ListByAppID", ctx, "app1", model.MerchantStatus(""), 20, 0).Return(want, nil)
	svc := newTestMerchantService(repo)
	got, err := svc.ListMerchantsForApp(ctx, "app1", "", 20, 0)
	assert.NoError(t, err)
	// Compare contents, not just length: a Len check would accept any single merchant.
	assert.Equal(t, want, got)
	repo.AssertExpectations(t)
}
// --- ApplyForApp ---
// TestApplyForApp_MerchantNotBelongToApp checks that an application for a
// merchant the app-scoped lookup cannot find is rejected with error "30001".
// NOTE(review): channelSvc is nil here, so this presumably relies on the
// ownership check rejecting the request before any channel call — confirm.
func TestApplyForApp_MerchantNotBelongToApp(t *testing.T) {
	repo := new(mockMerchantRepo)
	repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return((*model.Merchant)(nil), nil)
	svc := newTestMerchantService(repo)
	_, err := svc.ApplyForApp(ctx, "app1", "m001", "HEEPAY", nil)
	assert.EqualError(t, err, "30001", "不属于该 app 的商户不能提交进件")
	repo.AssertExpectations(t)
}
// --- GetChannelMerchantID ---
// TestGetChannelMerchantID_Approved checks that the channel merchant ID is
// taken from the approved application the repo returns for that channel.
func TestGetChannelMerchantID_Approved(t *testing.T) {
	repo := new(mockMerchantRepo)
	app := &model.MerchantApplication{
		ChannelMerchantID: "ch_m_999",
	}
	repo.On("GetApprovedApplicationByChannel", ctx, "m001", "HEEPAY").Return(app, nil)
	svc := newTestMerchantService(repo)
	id, err := svc.GetChannelMerchantID(ctx, "m001", "HEEPAY")
	assert.NoError(t, err)
	assert.Equal(t, "ch_m_999", id)
	repo.AssertExpectations(t)
}
// TestGetChannelMerchantID_NotApproved checks the result when the repo has no
// approved application for the channel. The service must return an empty ID
// and no error.
func TestGetChannelMerchantID_NotApproved(t *testing.T) {
	repo := new(mockMerchantRepo)
	repo.On("GetApprovedApplicationByChannel", ctx, "m001", "ALIPAY").Return((*model.MerchantApplication)(nil), nil)
	svc := newTestMerchantService(repo)
	id, err := svc.GetChannelMerchantID(ctx, "m001", "ALIPAY")
	assert.NoError(t, err)
	assert.Empty(t, id, "未在该渠道进件时返回空字符串")
	repo.AssertExpectations(t)
}
// TestGetChannelMerchantID_MultiChannel checks that a merchant approved on
// several channels gets each channel's own channel_merchant_id. Errors are
// asserted rather than discarded so a failing lookup cannot hide behind an
// empty ID.
func TestGetChannelMerchantID_MultiChannel(t *testing.T) {
	repo := new(mockMerchantRepo)
	repo.On("GetApprovedApplicationByChannel", ctx, "m001", "HEEPAY").
		Return(&model.MerchantApplication{ChannelMerchantID: "hee_001"}, nil)
	repo.On("GetApprovedApplicationByChannel", ctx, "m001", "ALIPAY").
		Return(&model.MerchantApplication{ChannelMerchantID: "ali_001"}, nil)
	svc := newTestMerchantService(repo)

	heeID, err := svc.GetChannelMerchantID(ctx, "m001", "HEEPAY")
	assert.NoError(t, err)
	aliID, err := svc.GetChannelMerchantID(ctx, "m001", "ALIPAY")
	assert.NoError(t, err)

	assert.Equal(t, "hee_001", heeID, "不同渠道应返回各自的 channel_merchant_id")
	assert.Equal(t, "ali_001", aliID)
	repo.AssertExpectations(t)
}