This commit is contained in:
2026-03-13 15:51:59 +08:00
parent 4db2386bbf
commit 4e91f4cede
133 changed files with 19502 additions and 37 deletions

View File

@@ -0,0 +1,216 @@
package handler
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"pay-bridge/internal/channel"
"pay-bridge/internal/model"
)
func init() {
gin.SetMode(gin.TestMode)
}
// mockMerchantSvc 实现 merchantService interface
type mockMerchantSvc struct {
mock.Mock
}
func (m *mockMerchantSvc) CreateMerchantForApp(ctx context.Context, appID string, merchant *model.Merchant) error {
return m.Called(ctx, appID, merchant).Error(0)
}
func (m *mockMerchantSvc) GetMerchantForApp(ctx context.Context, appID, merchantID string) (*model.Merchant, error) {
args := m.Called(ctx, appID, merchantID)
v, _ := args.Get(0).(*model.Merchant)
return v, args.Error(1)
}
func (m *mockMerchantSvc) ListMerchantsForApp(ctx context.Context, appID string, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
args := m.Called(ctx, appID, status, limit, offset)
return args.Get(0).([]*model.Merchant), args.Error(1)
}
func (m *mockMerchantSvc) ApplyForApp(ctx context.Context, appID, merchantID, channelCode string, bizContent map[string]any) (string, error) {
args := m.Called(ctx, appID, merchantID, channelCode, bizContent)
return args.String(0), args.Error(1)
}
func (m *mockMerchantSvc) QueryAuditStatusForApp(ctx context.Context, appID, merchantID string) (*model.MerchantApplication, error) {
args := m.Called(ctx, appID, merchantID)
v, _ := args.Get(0).(*model.MerchantApplication)
return v, args.Error(1)
}
func (m *mockMerchantSvc) UploadFile(ctx context.Context, channelCode string, req *channel.UploadFileReq) (string, error) {
args := m.Called(ctx, channelCode, req)
return args.String(0), args.Error(1)
}
// newMerchantTestRouter 构建测试路由,注入固定 app_id 模拟鉴权
func newMerchantTestRouter(svc *mockMerchantSvc) *gin.Engine {
r := gin.New()
h := &MerchantHandler{merchantSvc: svc}
auth := func(c *gin.Context) {
c.Set("app_id", "app_test")
c.Next()
}
g := r.Group("/api/v1/merchant", auth)
g.POST("", h.CreateMerchant)
g.GET("", h.ListMerchants)
g.GET("/:merchantID", h.GetMerchant)
g.POST("/:merchantID/apply", h.Apply)
g.GET("/:merchantID/audit", h.QueryAuditStatus)
return r
}
// --- CreateMerchant ---
func TestCreateMerchant_OK(t *testing.T) {
svc := new(mockMerchantSvc)
svc.On("CreateMerchantForApp", mock.Anything, "app_test", mock.MatchedBy(func(m *model.Merchant) bool {
return m.MerchantID == "m001" && m.MerchantName == "测试公司"
})).Return(nil)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant",
strings.NewReader(`{"merchant_id":"m001","merchant_name":"测试公司"}`))
req.Header.Set("Content-Type", "application/json")
newMerchantTestRouter(svc).ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var resp map[string]any
json.Unmarshal(w.Body.Bytes(), &resp)
assert.Equal(t, "0", resp["code"])
assert.Equal(t, "m001", resp["data"].(map[string]any)["merchant_id"])
svc.AssertExpectations(t)
}
func TestCreateMerchant_MissingName(t *testing.T) {
svc := new(mockMerchantSvc)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant",
strings.NewReader(`{"merchant_id":"m001"}`))
req.Header.Set("Content-Type", "application/json")
newMerchantTestRouter(svc).ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
svc.AssertNotCalled(t, "CreateMerchantForApp")
}
func TestCreateMerchant_MissingID(t *testing.T) {
svc := new(mockMerchantSvc)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant",
strings.NewReader(`{"merchant_name":"测试公司"}`))
req.Header.Set("Content-Type", "application/json")
newMerchantTestRouter(svc).ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
}
// --- GetMerchant ---
func TestGetMerchant_OK(t *testing.T) {
svc := new(mockMerchantSvc)
svc.On("GetMerchantForApp", mock.Anything, "app_test", "m001").
Return(&model.Merchant{MerchantID: "m001", AppID: "app_test"}, nil)
w := httptest.NewRecorder()
newMerchantTestRouter(svc).ServeHTTP(w,
httptest.NewRequest(http.MethodGet, "/api/v1/merchant/m001", nil))
assert.Equal(t, http.StatusOK, w.Code)
svc.AssertExpectations(t)
}
func TestGetMerchant_NotFound(t *testing.T) {
svc := new(mockMerchantSvc)
svc.On("GetMerchantForApp", mock.Anything, "app_test", "m999").
Return((*model.Merchant)(nil), errors.New("30001"))
w := httptest.NewRecorder()
newMerchantTestRouter(svc).ServeHTTP(w,
httptest.NewRequest(http.MethodGet, "/api/v1/merchant/m999", nil))
assert.Equal(t, http.StatusNotFound, w.Code)
}
func TestGetMerchant_WrongApp(t *testing.T) {
svc := new(mockMerchantSvc)
svc.On("GetMerchantForApp", mock.Anything, "app_test", "other_m").
Return((*model.Merchant)(nil), errors.New("30001"))
w := httptest.NewRecorder()
newMerchantTestRouter(svc).ServeHTTP(w,
httptest.NewRequest(http.MethodGet, "/api/v1/merchant/other_m", nil))
// 跨 app 访问应返回 404而不是 403避免信息泄露
assert.Equal(t, http.StatusNotFound, w.Code)
}
// --- ListMerchants ---
func TestListMerchants_DefaultPagination(t *testing.T) {
svc := new(mockMerchantSvc)
svc.On("ListMerchantsForApp", mock.Anything, "app_test", model.MerchantStatus(""), 20, 0).
Return([]*model.Merchant{}, nil)
w := httptest.NewRecorder()
newMerchantTestRouter(svc).ServeHTTP(w,
httptest.NewRequest(http.MethodGet, "/api/v1/merchant", nil))
assert.Equal(t, http.StatusOK, w.Code)
svc.AssertExpectations(t)
}
// --- Apply ---
func TestApply_OK(t *testing.T) {
svc := new(mockMerchantSvc)
svc.On("ApplyForApp", mock.Anything, "app_test", "m001", "HEEPAY", mock.Anything).
Return("APP123", nil)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant/m001/apply",
strings.NewReader(`{"channel_code":"HEEPAY","submit_data":{"name":"测试公司"}}`))
req.Header.Set("Content-Type", "application/json")
newMerchantTestRouter(svc).ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var resp map[string]any
json.Unmarshal(w.Body.Bytes(), &resp)
assert.Equal(t, "APP123", resp["data"].(map[string]any)["application_id"])
}
func TestApply_MissingChannelCode(t *testing.T) {
svc := new(mockMerchantSvc)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant/m001/apply",
strings.NewReader(`{"submit_data":{}}`))
req.Header.Set("Content-Type", "application/json")
newMerchantTestRouter(svc).ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
svc.AssertNotCalled(t, "ApplyForApp")
}
func TestApply_MerchantNotBelongToApp(t *testing.T) {
svc := new(mockMerchantSvc)
svc.On("ApplyForApp", mock.Anything, "app_test", "m_other", "HEEPAY", mock.Anything).
Return("", errors.New("30001"))
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant/m_other/apply",
strings.NewReader(`{"channel_code":"HEEPAY"}`))
req.Header.Set("Content-Type", "application/json")
newMerchantTestRouter(svc).ServeHTTP(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}