Files
pay-bridge/backend/internal/api/handler/merchant_test.go
2026-03-13 15:51:59 +08:00

217 lines
7.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handler
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"pay-bridge/internal/channel"
"pay-bridge/internal/model"
)
// init switches gin into test mode before any test in this package runs.
func init() {
gin.SetMode(gin.TestMode)
}
// mockMerchantSvc is a testify mock meant to stand in for the merchantService
// interface used by MerchantHandler. Each method forwards its arguments to
// mock.Mock.Called and returns whatever the test configured.
type mockMerchantSvc struct {
mock.Mock
}
// CreateMerchantForApp records the call on the mock and returns the error
// configured as its first (and only) return value.
func (m *mockMerchantSvc) CreateMerchantForApp(ctx context.Context, appID string, merchant *model.Merchant) error {
	ret := m.Called(ctx, appID, merchant)
	return ret.Error(0)
}
// GetMerchantForApp records the call on the mock and returns the configured
// merchant and error. If the first configured return value is not a
// *model.Merchant (for example an untyped nil), a nil merchant is returned.
func (m *mockMerchantSvc) GetMerchantForApp(ctx context.Context, appID, merchantID string) (*model.Merchant, error) {
	ret := m.Called(ctx, appID, merchantID)

	var merchant *model.Merchant
	if got, ok := ret.Get(0).(*model.Merchant); ok {
		merchant = got
	}
	return merchant, ret.Error(1)
}
// ListMerchantsForApp records the call on the mock and returns the configured
// merchant slice and error.
//
// The first return value is read with the comma-ok form, matching the other
// pointer-returning mocks in this file, so a test may configure
// Return(nil, err) without the mock panicking; in that case a nil slice is
// returned.
func (m *mockMerchantSvc) ListMerchantsForApp(ctx context.Context, appID string, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
	args := m.Called(ctx, appID, status, limit, offset)
	// ok is deliberately ignored: a missing or mistyped value maps to a nil slice.
	merchants, _ := args.Get(0).([]*model.Merchant)
	return merchants, args.Error(1)
}
// ApplyForApp records the call on the mock and returns the configured string
// (first return value) and error (second return value).
func (m *mockMerchantSvc) ApplyForApp(ctx context.Context, appID, merchantID, channelCode string, bizContent map[string]any) (string, error) {
	ret := m.Called(ctx, appID, merchantID, channelCode, bizContent)
	applicationID := ret.String(0)
	err := ret.Error(1)
	return applicationID, err
}
// QueryAuditStatusForApp records the call on the mock and returns the
// configured application and error. If the first configured return value is
// not a *model.MerchantApplication, a nil application is returned.
func (m *mockMerchantSvc) QueryAuditStatusForApp(ctx context.Context, appID, merchantID string) (*model.MerchantApplication, error) {
	ret := m.Called(ctx, appID, merchantID)

	var application *model.MerchantApplication
	if got, ok := ret.Get(0).(*model.MerchantApplication); ok {
		application = got
	}
	return application, ret.Error(1)
}
// UploadFile records the call on the mock and returns the configured string
// (first return value) and error (second return value).
func (m *mockMerchantSvc) UploadFile(ctx context.Context, channelCode string, req *channel.UploadFileReq) (string, error) {
	ret := m.Called(ctx, channelCode, req)
	result := ret.String(0)
	err := ret.Error(1)
	return result, err
}
// newMerchantTestRouter builds a gin engine wired to a MerchantHandler backed
// by svc. A stub middleware sets a fixed "app_id" on every request to simulate
// authentication.
func newMerchantTestRouter(svc *mockMerchantSvc) *gin.Engine {
	handler := &MerchantHandler{merchantSvc: svc}

	// Stand-in for the real auth step: pin the app ID, then continue the chain.
	injectAppID := func(c *gin.Context) {
		c.Set("app_id", "app_test")
		c.Next()
	}

	engine := gin.New()
	merchants := engine.Group("/api/v1/merchant", injectAppID)
	merchants.POST("", handler.CreateMerchant)
	merchants.GET("", handler.ListMerchants)
	merchants.GET("/:merchantID", handler.GetMerchant)
	merchants.POST("/:merchantID/apply", handler.Apply)
	merchants.GET("/:merchantID/audit", handler.QueryAuditStatus)

	return engine
}
// --- CreateMerchant ---

// TestCreateMerchant_OK posts a complete merchant payload and expects HTTP 200,
// code "0", the merchant ID echoed back under data.merchant_id, and exactly the
// expected service call.
func TestCreateMerchant_OK(t *testing.T) {
	svc := new(mockMerchantSvc)
	svc.On("CreateMerchantForApp", mock.Anything, "app_test", mock.MatchedBy(func(m *model.Merchant) bool {
		// Guard against a nil merchant so a handler bug fails the match instead of panicking here.
		return m != nil && m.MerchantID == "m001" && m.MerchantName == "测试公司"
	})).Return(nil)

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant",
		strings.NewReader(`{"merchant_id":"m001","merchant_name":"测试公司"}`))
	req.Header.Set("Content-Type", "application/json")
	newMerchantTestRouter(svc).ServeHTTP(w, req)

	assert.Equal(t, http.StatusOK, w.Code)

	// A malformed body must fail the test with a clear message rather than
	// surfacing later as a confusing nil-map assertion.
	var resp map[string]any
	if !assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp), "response body must be valid JSON") {
		return
	}
	assert.Equal(t, "0", resp["code"])

	// Checked type assertion: a missing or non-object "data" field is reported
	// as a test failure instead of a panic.
	data, ok := resp["data"].(map[string]any)
	if assert.True(t, ok, "data must be a JSON object, got %T", resp["data"]) {
		assert.Equal(t, "m001", data["merchant_id"])
	}

	svc.AssertExpectations(t)
}
// TestCreateMerchant_MissingName posts a payload without merchant_name and
// expects HTTP 400 with no call reaching the service.
func TestCreateMerchant_MissingName(t *testing.T) {
	svc := new(mockMerchantSvc)

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant",
		strings.NewReader(`{"merchant_id":"m001"}`))
	req.Header.Set("Content-Type", "application/json")
	newMerchantTestRouter(svc).ServeHTTP(w, req)

	assert.Equal(t, http.StatusBadRequest, w.Code)
	// AssertNotCalled with no arguments only matches calls made with zero
	// arguments, so it cannot catch a real CreateMerchantForApp call. Counting
	// calls by method name is unambiguous.
	svc.AssertNumberOfCalls(t, "CreateMerchantForApp", 0)
}
// TestCreateMerchant_MissingID posts a payload without merchant_id and expects
// HTTP 400 with no call reaching the service.
func TestCreateMerchant_MissingID(t *testing.T) {
	svc := new(mockMerchantSvc)

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant",
		strings.NewReader(`{"merchant_name":"测试公司"}`))
	req.Header.Set("Content-Type", "application/json")
	newMerchantTestRouter(svc).ServeHTTP(w, req)

	assert.Equal(t, http.StatusBadRequest, w.Code)
	// Validation must reject the request before the service is consulted,
	// mirroring the missing-name case.
	svc.AssertNumberOfCalls(t, "CreateMerchantForApp", 0)
}
// --- GetMerchant ---

// TestGetMerchant_OK requests an existing merchant and expects HTTP 200 plus
// the configured service call.
func TestGetMerchant_OK(t *testing.T) {
	svc := new(mockMerchantSvc)
	found := &model.Merchant{MerchantID: "m001", AppID: "app_test"}
	svc.On("GetMerchantForApp", mock.Anything, "app_test", "m001").Return(found, nil)

	req := httptest.NewRequest(http.MethodGet, "/api/v1/merchant/m001", nil)
	rec := httptest.NewRecorder()
	router := newMerchantTestRouter(svc)
	router.ServeHTTP(rec, req)

	assert.Equal(t, http.StatusOK, rec.Code)
	svc.AssertExpectations(t)
}
// TestGetMerchant_NotFound makes the service return error "30001" for an
// unknown merchant and expects the handler to answer HTTP 404.
func TestGetMerchant_NotFound(t *testing.T) {
	svc := new(mockMerchantSvc)
	svc.On("GetMerchantForApp", mock.Anything, "app_test", "m999").
		Return((*model.Merchant)(nil), errors.New("30001"))

	w := httptest.NewRecorder()
	newMerchantTestRouter(svc).ServeHTTP(w,
		httptest.NewRequest(http.MethodGet, "/api/v1/merchant/m999", nil))

	assert.Equal(t, http.StatusNotFound, w.Code)
	// Confirm the 404 came from the service lookup with the expected app and
	// merchant IDs, not from an unrelated early exit.
	svc.AssertExpectations(t)
}
// TestGetMerchant_WrongApp makes the service return error "30001" for a
// merchant requested under app "app_test" and expects HTTP 404.
func TestGetMerchant_WrongApp(t *testing.T) {
	svc := new(mockMerchantSvc)
	svc.On("GetMerchantForApp", mock.Anything, "app_test", "other_m").
		Return((*model.Merchant)(nil), errors.New("30001"))

	w := httptest.NewRecorder()
	newMerchantTestRouter(svc).ServeHTTP(w,
		httptest.NewRequest(http.MethodGet, "/api/v1/merchant/other_m", nil))

	// Cross-app access should return 404 rather than 403 to avoid leaking
	// whether the merchant exists.
	assert.Equal(t, http.StatusNotFound, w.Code)
	// Confirm the lookup was scoped to the caller's app ID.
	svc.AssertExpectations(t)
}
// --- ListMerchants ---

// TestListMerchants_DefaultPagination lists merchants without query parameters
// and expects the service to be called with an empty status, limit 20 and
// offset 0.
func TestListMerchants_DefaultPagination(t *testing.T) {
	svc := new(mockMerchantSvc)
	noMerchants := []*model.Merchant{}
	svc.On("ListMerchantsForApp", mock.Anything, "app_test", model.MerchantStatus(""), 20, 0).
		Return(noMerchants, nil)

	req := httptest.NewRequest(http.MethodGet, "/api/v1/merchant", nil)
	rec := httptest.NewRecorder()
	router := newMerchantTestRouter(svc)
	router.ServeHTTP(rec, req)

	assert.Equal(t, http.StatusOK, rec.Code)
	svc.AssertExpectations(t)
}
// --- Apply ---

// TestApply_OK submits an application for merchant m001 on channel HEEPAY and
// expects HTTP 200 with the service's application ID under
// data.application_id.
func TestApply_OK(t *testing.T) {
	svc := new(mockMerchantSvc)
	svc.On("ApplyForApp", mock.Anything, "app_test", "m001", "HEEPAY", mock.Anything).
		Return("APP123", nil)

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant/m001/apply",
		strings.NewReader(`{"channel_code":"HEEPAY","submit_data":{"name":"测试公司"}}`))
	req.Header.Set("Content-Type", "application/json")
	newMerchantTestRouter(svc).ServeHTTP(w, req)

	assert.Equal(t, http.StatusOK, w.Code)

	// Fail with a clear message on a malformed body instead of ignoring the
	// decode error.
	var resp map[string]any
	if !assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp), "response body must be valid JSON") {
		return
	}

	// Checked type assertion: a missing or non-object "data" field is reported
	// as a test failure instead of a panic.
	data, ok := resp["data"].(map[string]any)
	if assert.True(t, ok, "data must be a JSON object, got %T", resp["data"]) {
		assert.Equal(t, "APP123", data["application_id"])
	}

	// Confirm the handler passed the expected app, merchant and channel to the service.
	svc.AssertExpectations(t)
}
// TestApply_MissingChannelCode submits an application without channel_code and
// expects HTTP 400 with no call reaching the service.
func TestApply_MissingChannelCode(t *testing.T) {
	svc := new(mockMerchantSvc)

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant/m001/apply",
		strings.NewReader(`{"submit_data":{}}`))
	req.Header.Set("Content-Type", "application/json")
	newMerchantTestRouter(svc).ServeHTTP(w, req)

	assert.Equal(t, http.StatusBadRequest, w.Code)
	// AssertNotCalled with no arguments only matches calls made with zero
	// arguments, so it cannot catch a real ApplyForApp call. Counting calls by
	// method name is unambiguous.
	svc.AssertNumberOfCalls(t, "ApplyForApp", 0)
}
// TestApply_MerchantNotBelongToApp makes the service return error "30001" when
// applying for merchant m_other under app "app_test" and expects HTTP 404.
func TestApply_MerchantNotBelongToApp(t *testing.T) {
	svc := new(mockMerchantSvc)
	svc.On("ApplyForApp", mock.Anything, "app_test", "m_other", "HEEPAY", mock.Anything).
		Return("", errors.New("30001"))

	w := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant/m_other/apply",
		strings.NewReader(`{"channel_code":"HEEPAY"}`))
	req.Header.Set("Content-Type", "application/json")
	newMerchantTestRouter(svc).ServeHTTP(w, req)

	assert.Equal(t, http.StatusNotFound, w.Code)
	// Confirm the 404 came from the service call with the expected app,
	// merchant and channel, not from an unrelated early exit.
	svc.AssertExpectations(t)
}