package handler import ( "context" "encoding/json" "errors" "net/http" "net/http/httptest" "strings" "testing" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "pay-bridge/internal/channel" "pay-bridge/internal/model" ) func init() { gin.SetMode(gin.TestMode) } // mockMerchantSvc 实现 merchantService interface type mockMerchantSvc struct { mock.Mock } func (m *mockMerchantSvc) CreateMerchantForApp(ctx context.Context, appID string, merchant *model.Merchant) error { return m.Called(ctx, appID, merchant).Error(0) } func (m *mockMerchantSvc) GetMerchantForApp(ctx context.Context, appID, merchantID string) (*model.Merchant, error) { args := m.Called(ctx, appID, merchantID) v, _ := args.Get(0).(*model.Merchant) return v, args.Error(1) } func (m *mockMerchantSvc) ListMerchantsForApp(ctx context.Context, appID string, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) { args := m.Called(ctx, appID, status, limit, offset) return args.Get(0).([]*model.Merchant), args.Error(1) } func (m *mockMerchantSvc) ApplyForApp(ctx context.Context, appID, merchantID, channelCode string, bizContent map[string]any) (string, error) { args := m.Called(ctx, appID, merchantID, channelCode, bizContent) return args.String(0), args.Error(1) } func (m *mockMerchantSvc) QueryAuditStatusForApp(ctx context.Context, appID, merchantID string) (*model.MerchantApplication, error) { args := m.Called(ctx, appID, merchantID) v, _ := args.Get(0).(*model.MerchantApplication) return v, args.Error(1) } func (m *mockMerchantSvc) UploadFile(ctx context.Context, channelCode string, req *channel.UploadFileReq) (string, error) { args := m.Called(ctx, channelCode, req) return args.String(0), args.Error(1) } // newMerchantTestRouter 构建测试路由,注入固定 app_id 模拟鉴权 func newMerchantTestRouter(svc *mockMerchantSvc) *gin.Engine { r := gin.New() h := &MerchantHandler{merchantSvc: svc} auth := func(c *gin.Context) { c.Set("app_id", "app_test") c.Next() } g := r.Group("/api/v1/merchant", auth) g.POST("", h.CreateMerchant) g.GET("", h.ListMerchants) g.GET("/:merchantID", h.GetMerchant) g.POST("/:merchantID/apply", h.Apply) g.GET("/:merchantID/audit", h.QueryAuditStatus) return r } // --- CreateMerchant --- func TestCreateMerchant_OK(t *testing.T) { svc := new(mockMerchantSvc) svc.On("CreateMerchantForApp", mock.Anything, "app_test", mock.MatchedBy(func(m *model.Merchant) bool { return m.MerchantID == "m001" && m.MerchantName == "测试公司" })).Return(nil) w := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant", strings.NewReader(`{"merchant_id":"m001","merchant_name":"测试公司"}`)) req.Header.Set("Content-Type", "application/json") newMerchantTestRouter(svc).ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) var resp map[string]any json.Unmarshal(w.Body.Bytes(), &resp) assert.Equal(t, "0", resp["code"]) assert.Equal(t, "m001", resp["data"].(map[string]any)["merchant_id"]) svc.AssertExpectations(t) } func TestCreateMerchant_MissingName(t *testing.T) { svc := new(mockMerchantSvc) w := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant", strings.NewReader(`{"merchant_id":"m001"}`)) req.Header.Set("Content-Type", "application/json") newMerchantTestRouter(svc).ServeHTTP(w, req) assert.Equal(t, http.StatusBadRequest, w.Code) svc.AssertNotCalled(t, "CreateMerchantForApp") } func TestCreateMerchant_MissingID(t *testing.T) { svc := new(mockMerchantSvc) w := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant", strings.NewReader(`{"merchant_name":"测试公司"}`)) req.Header.Set("Content-Type", "application/json") newMerchantTestRouter(svc).ServeHTTP(w, req) assert.Equal(t, http.StatusBadRequest, w.Code) } // --- GetMerchant --- func TestGetMerchant_OK(t *testing.T) { svc := new(mockMerchantSvc) svc.On("GetMerchantForApp", mock.Anything, "app_test", "m001"). Return(&model.Merchant{MerchantID: "m001", AppID: "app_test"}, nil) w := httptest.NewRecorder() newMerchantTestRouter(svc).ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/v1/merchant/m001", nil)) assert.Equal(t, http.StatusOK, w.Code) svc.AssertExpectations(t) } func TestGetMerchant_NotFound(t *testing.T) { svc := new(mockMerchantSvc) svc.On("GetMerchantForApp", mock.Anything, "app_test", "m999"). Return((*model.Merchant)(nil), errors.New("30001")) w := httptest.NewRecorder() newMerchantTestRouter(svc).ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/v1/merchant/m999", nil)) assert.Equal(t, http.StatusNotFound, w.Code) } func TestGetMerchant_WrongApp(t *testing.T) { svc := new(mockMerchantSvc) svc.On("GetMerchantForApp", mock.Anything, "app_test", "other_m"). Return((*model.Merchant)(nil), errors.New("30001")) w := httptest.NewRecorder() newMerchantTestRouter(svc).ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/v1/merchant/other_m", nil)) // 跨 app 访问应返回 404,而不是 403,避免信息泄露 assert.Equal(t, http.StatusNotFound, w.Code) } // --- ListMerchants --- func TestListMerchants_DefaultPagination(t *testing.T) { svc := new(mockMerchantSvc) svc.On("ListMerchantsForApp", mock.Anything, "app_test", model.MerchantStatus(""), 20, 0). Return([]*model.Merchant{}, nil) w := httptest.NewRecorder() newMerchantTestRouter(svc).ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/v1/merchant", nil)) assert.Equal(t, http.StatusOK, w.Code) svc.AssertExpectations(t) } // --- Apply --- func TestApply_OK(t *testing.T) { svc := new(mockMerchantSvc) svc.On("ApplyForApp", mock.Anything, "app_test", "m001", "HEEPAY", mock.Anything). Return("APP123", nil) w := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant/m001/apply", strings.NewReader(`{"channel_code":"HEEPAY","submit_data":{"name":"测试公司"}}`)) req.Header.Set("Content-Type", "application/json") newMerchantTestRouter(svc).ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) var resp map[string]any json.Unmarshal(w.Body.Bytes(), &resp) assert.Equal(t, "APP123", resp["data"].(map[string]any)["application_id"]) } func TestApply_MissingChannelCode(t *testing.T) { svc := new(mockMerchantSvc) w := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant/m001/apply", strings.NewReader(`{"submit_data":{}}`)) req.Header.Set("Content-Type", "application/json") newMerchantTestRouter(svc).ServeHTTP(w, req) assert.Equal(t, http.StatusBadRequest, w.Code) svc.AssertNotCalled(t, "ApplyForApp") } func TestApply_MerchantNotBelongToApp(t *testing.T) { svc := new(mockMerchantSvc) svc.On("ApplyForApp", mock.Anything, "app_test", "m_other", "HEEPAY", mock.Anything). Return("", errors.New("30001")) w := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant/m_other/apply", strings.NewReader(`{"channel_code":"HEEPAY"}`)) req.Header.Set("Content-Type", "application/json") newMerchantTestRouter(svc).ServeHTTP(w, req) assert.Equal(t, http.StatusNotFound, w.Code) }