draft
This commit is contained in:
373
backend/internal/api/handler/admin.go
Normal file
373
backend/internal/api/handler/admin.go
Normal file
@@ -0,0 +1,373 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"pay-bridge/internal/channel"
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/service"
|
||||
)
|
||||
|
||||
// AdminHandler 管理后台接口处理器
|
||||
type AdminHandler struct {
|
||||
matchSvc *service.PaymentMatchService
|
||||
merchantSvc *service.MerchantService
|
||||
reconSvc *service.ReconciliationService
|
||||
channelSvc *service.ChannelService
|
||||
appSvc *service.AppService
|
||||
}
|
||||
|
||||
func NewAdminHandler(
|
||||
matchSvc *service.PaymentMatchService,
|
||||
merchantSvc *service.MerchantService,
|
||||
reconSvc *service.ReconciliationService,
|
||||
channelSvc *service.ChannelService,
|
||||
appSvc *service.AppService,
|
||||
) *AdminHandler {
|
||||
return &AdminHandler{
|
||||
matchSvc: matchSvc,
|
||||
merchantSvc: merchantSvc,
|
||||
reconSvc: reconSvc,
|
||||
channelSvc: channelSvc,
|
||||
appSvc: appSvc,
|
||||
}
|
||||
}
|
||||
|
||||
// --- 请求结构体 ---
|
||||
|
||||
type createAppReq struct {
|
||||
AppName string `json:"app_name" binding:"required"`
|
||||
}
|
||||
|
||||
type manualBindOrderReq struct {
|
||||
MatchID uint64 `json:"match_id" binding:"required"`
|
||||
TradeNo string `json:"trade_no" binding:"required"`
|
||||
Operator string `json:"operator" binding:"required"`
|
||||
}
|
||||
|
||||
type applyMerchantReq struct {
|
||||
ChannelCode string `json:"channel_code" binding:"required"`
|
||||
SubmitData map[string]any `json:"submit_data"`
|
||||
}
|
||||
|
||||
// appVO 应用列表视图(不含加密 secret)
|
||||
type appVO struct {
|
||||
AppID string `json:"app_id"`
|
||||
AppName string `json:"app_name"`
|
||||
Status int8 `json:"status"`
|
||||
CreatedAt any `json:"created_at"`
|
||||
UpdatedAt any `json:"updated_at"`
|
||||
}
|
||||
|
||||
// --- 应用管理 ---
|
||||
|
||||
// CreateApp 创建下游接入应用
|
||||
func (h *AdminHandler) CreateApp(c *gin.Context) {
|
||||
var req createAppReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
BadRequest(c, "10001", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.appSvc.CreateApp(c.Request.Context(), req.AppName)
|
||||
if err != nil {
|
||||
InternalError(c, "50001", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 明文 secret 仅在创建时返回一次,之后无法再查看
|
||||
OK(c, gin.H{
|
||||
"app_id": result.App.AppID,
|
||||
"app_name": result.App.AppName,
|
||||
"app_secret": result.PlainSecret,
|
||||
"status": result.App.Status,
|
||||
"created_at": result.App.CreatedAt,
|
||||
"secret_tip": "请妥善保存 app_secret,此后将无法再次查看",
|
||||
})
|
||||
}
|
||||
|
||||
// ListApps 查询应用列表
|
||||
func (h *AdminHandler) ListApps(c *gin.Context) {
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||||
offset, _ := strconv.Atoi(c.DefaultQuery("offset", "0"))
|
||||
|
||||
apps, err := h.appSvc.ListApps(c.Request.Context(), limit, offset)
|
||||
if err != nil {
|
||||
InternalError(c, "50001", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
list := make([]appVO, 0, len(apps))
|
||||
for _, a := range apps {
|
||||
list = append(list, appVO{
|
||||
AppID: a.AppID,
|
||||
AppName: a.AppName,
|
||||
Status: a.Status,
|
||||
CreatedAt: a.CreatedAt,
|
||||
UpdatedAt: a.UpdatedAt,
|
||||
})
|
||||
}
|
||||
OK(c, gin.H{"list": list, "limit": limit, "offset": offset})
|
||||
}
|
||||
|
||||
// DisableApp 禁用应用
|
||||
func (h *AdminHandler) DisableApp(c *gin.Context) {
|
||||
appID := c.Param("appID")
|
||||
if err := h.appSvc.DisableApp(c.Request.Context(), appID); err != nil {
|
||||
if err.Error() == "20002" {
|
||||
c.JSON(http.StatusNotFound, Response{Code: "20002", Message: "应用不存在"})
|
||||
return
|
||||
}
|
||||
InternalError(c, "50001", err.Error())
|
||||
return
|
||||
}
|
||||
OK(c, nil)
|
||||
}
|
||||
|
||||
// EnableApp 启用应用
|
||||
func (h *AdminHandler) EnableApp(c *gin.Context) {
|
||||
appID := c.Param("appID")
|
||||
if err := h.appSvc.EnableApp(c.Request.Context(), appID); err != nil {
|
||||
if err.Error() == "20002" {
|
||||
c.JSON(http.StatusNotFound, Response{Code: "20002", Message: "应用不存在"})
|
||||
return
|
||||
}
|
||||
InternalError(c, "50001", err.Error())
|
||||
return
|
||||
}
|
||||
OK(c, nil)
|
||||
}
|
||||
|
||||
// ResetAppSecret 重置应用密钥
|
||||
func (h *AdminHandler) ResetAppSecret(c *gin.Context) {
|
||||
appID := c.Param("appID")
|
||||
plainSecret, err := h.appSvc.ResetSecret(c.Request.Context(), appID)
|
||||
if err != nil {
|
||||
if err.Error() == "20002" {
|
||||
c.JSON(http.StatusNotFound, Response{Code: "20002", Message: "应用不存在"})
|
||||
return
|
||||
}
|
||||
InternalError(c, "50001", err.Error())
|
||||
return
|
||||
}
|
||||
OK(c, gin.H{
|
||||
"app_id": appID,
|
||||
"app_secret": plainSecret,
|
||||
"secret_tip": "请妥善保存 app_secret,此后将无法再次查看",
|
||||
})
|
||||
}
|
||||
|
||||
// --- 收款匹配管理 ---
|
||||
|
||||
// ListPendingMatches 查询待人工确认的收款记录
|
||||
func (h *AdminHandler) ListPendingMatches(c *gin.Context) {
|
||||
appID := c.Query("app_id")
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||||
offset, _ := strconv.Atoi(c.DefaultQuery("offset", "0"))
|
||||
|
||||
logs, err := h.matchSvc.ListPendingManual(c.Request.Context(), appID, limit, offset)
|
||||
if err != nil {
|
||||
InternalError(c, "50001", err.Error())
|
||||
return
|
||||
}
|
||||
OK(c, gin.H{"list": logs, "limit": limit, "offset": offset})
|
||||
}
|
||||
|
||||
// ManualBindOrder 人工关联收款与订单
|
||||
func (h *AdminHandler) ManualBindOrder(c *gin.Context) {
|
||||
var req manualBindOrderReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
BadRequest(c, "10001", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.matchSvc.ManualBindOrder(c.Request.Context(), req.MatchID, req.TradeNo, req.Operator); err != nil {
|
||||
InternalError(c, "50001", err.Error())
|
||||
return
|
||||
}
|
||||
OK(c, nil)
|
||||
}
|
||||
|
||||
// --- 商户管理 ---
|
||||
|
||||
// CreateMerchant 创建商户
|
||||
func (h *AdminHandler) CreateMerchant(c *gin.Context) {
|
||||
var m model.Merchant
|
||||
if err := c.ShouldBindJSON(&m); err != nil {
|
||||
BadRequest(c, "10001", err.Error())
|
||||
return
|
||||
}
|
||||
if err := h.merchantSvc.CreateMerchant(c.Request.Context(), &m); err != nil {
|
||||
InternalError(c, "50001", err.Error())
|
||||
return
|
||||
}
|
||||
OK(c, m)
|
||||
}
|
||||
|
||||
// GetMerchant 查询商户信息
|
||||
func (h *AdminHandler) GetMerchant(c *gin.Context) {
|
||||
merchantID := c.Param("merchantID")
|
||||
m, err := h.merchantSvc.GetMerchant(c.Request.Context(), merchantID)
|
||||
if err != nil {
|
||||
InternalError(c, "50001", err.Error())
|
||||
return
|
||||
}
|
||||
if m == nil {
|
||||
c.JSON(http.StatusNotFound, Response{Code: "30001", Message: "merchant not found"})
|
||||
return
|
||||
}
|
||||
OK(c, m)
|
||||
}
|
||||
|
||||
// ListMerchants 查询商户列表
|
||||
func (h *AdminHandler) ListMerchants(c *gin.Context) {
|
||||
status := model.MerchantStatus(c.Query("status"))
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||||
offset, _ := strconv.Atoi(c.DefaultQuery("offset", "0"))
|
||||
|
||||
merchants, err := h.merchantSvc.ListMerchants(c.Request.Context(), status, limit, offset)
|
||||
if err != nil {
|
||||
InternalError(c, "50001", err.Error())
|
||||
return
|
||||
}
|
||||
OK(c, gin.H{"list": merchants, "limit": limit, "offset": offset})
|
||||
}
|
||||
|
||||
// FreezeMerchant 冻结商户
|
||||
func (h *AdminHandler) FreezeMerchant(c *gin.Context) {
|
||||
merchantID := c.Param("merchantID")
|
||||
if err := h.merchantSvc.FreezeMerchant(c.Request.Context(), merchantID); err != nil {
|
||||
InternalError(c, "50001", err.Error())
|
||||
return
|
||||
}
|
||||
OK(c, nil)
|
||||
}
|
||||
|
||||
// UnfreezeMerchant 解冻商户
|
||||
func (h *AdminHandler) UnfreezeMerchant(c *gin.Context) {
|
||||
merchantID := c.Param("merchantID")
|
||||
if err := h.merchantSvc.UnfreezeMerchant(c.Request.Context(), merchantID); err != nil {
|
||||
InternalError(c, "50001", err.Error())
|
||||
return
|
||||
}
|
||||
OK(c, nil)
|
||||
}
|
||||
|
||||
// UploadMerchantFile 上传进件所需文件,返回 file_id
|
||||
// POST /api/v1/admin/merchant/upload-file
|
||||
// multipart/form-data: file=<binary>, channel_code=HEEPAY, file_media_type=01
|
||||
func (h *AdminHandler) UploadMerchantFile(c *gin.Context) {
|
||||
channelCode := c.PostForm("channel_code")
|
||||
if channelCode == "" {
|
||||
channelCode = "HEEPAY"
|
||||
}
|
||||
fileMediaType := c.PostForm("file_media_type")
|
||||
if fileMediaType == "" {
|
||||
BadRequest(c, "10001", "file_media_type is required")
|
||||
return
|
||||
}
|
||||
|
||||
fileHeader, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
BadRequest(c, "10001", "file is required: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
f, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
InternalError(c, "50001", "open file: "+err.Error())
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
content, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
InternalError(c, "50001", "read file: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
fileID, err := h.merchantSvc.UploadFile(c.Request.Context(), channelCode, &channel.UploadFileReq{
|
||||
FileContent: content,
|
||||
FileName: filepath.Base(fileHeader.Filename),
|
||||
FileMediaType: fileMediaType,
|
||||
})
|
||||
if err != nil {
|
||||
InternalError(c, "50001", err.Error())
|
||||
return
|
||||
}
|
||||
OK(c, gin.H{"file_id": fileID})
|
||||
}
|
||||
|
||||
// ApplyMerchant 商户进件申请
|
||||
func (h *AdminHandler) ApplyMerchant(c *gin.Context) {
|
||||
merchantID := c.Param("merchantID")
|
||||
var req applyMerchantReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
BadRequest(c, "10001", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
applicationID, err := h.merchantSvc.Apply(c.Request.Context(), merchantID, req.ChannelCode, req.SubmitData)
|
||||
if err != nil {
|
||||
InternalError(c, "50001", err.Error())
|
||||
return
|
||||
}
|
||||
OK(c, gin.H{"application_id": applicationID})
|
||||
}
|
||||
|
||||
// QueryAuditStatus 查询进件审核状态
|
||||
func (h *AdminHandler) QueryAuditStatus(c *gin.Context) {
|
||||
merchantID := c.Param("merchantID")
|
||||
app, err := h.merchantSvc.QueryAuditStatus(c.Request.Context(), merchantID)
|
||||
if err != nil {
|
||||
InternalError(c, "50001", err.Error())
|
||||
return
|
||||
}
|
||||
OK(c, app)
|
||||
}
|
||||
|
||||
// --- 对账管理 ---
|
||||
|
||||
// TriggerReconciliation 手动触发对账
|
||||
func (h *AdminHandler) TriggerReconciliation(c *gin.Context) {
|
||||
if err := h.reconSvc.RunDailyReconciliation(c.Request.Context()); err != nil {
|
||||
InternalError(c, "50001", err.Error())
|
||||
return
|
||||
}
|
||||
OK(c, nil)
|
||||
}
|
||||
|
||||
// GetReconciliationReport 查询对账报告
|
||||
func (h *AdminHandler) GetReconciliationReport(c *gin.Context) {
|
||||
appID := c.Query("app_id")
|
||||
billDate := c.Query("bill_date")
|
||||
channelCode := c.Query("channel_code")
|
||||
|
||||
report, err := h.reconSvc.GetReport(c.Request.Context(), appID, billDate, channelCode)
|
||||
if err != nil {
|
||||
InternalError(c, "50001", err.Error())
|
||||
return
|
||||
}
|
||||
OK(c, report)
|
||||
}
|
||||
|
||||
// GetReconciliationExceptions 查询对账异常明细
|
||||
func (h *AdminHandler) GetReconciliationExceptions(c *gin.Context) {
|
||||
reportIDStr := c.Param("reportID")
|
||||
reportID, err := strconv.ParseUint(reportIDStr, 10, 64)
|
||||
if err != nil {
|
||||
BadRequest(c, "10001", "invalid report_id")
|
||||
return
|
||||
}
|
||||
|
||||
exs, err := h.reconSvc.GetExceptions(c.Request.Context(), reportID)
|
||||
if err != nil {
|
||||
InternalError(c, "50001", err.Error())
|
||||
return
|
||||
}
|
||||
OK(c, exs)
|
||||
}
|
||||
49
backend/internal/api/handler/auth.go
Normal file
49
backend/internal/api/handler/auth.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type adminAuthSvc interface {
|
||||
Login(ctx context.Context, username, password string) (string, error)
|
||||
}
|
||||
|
||||
type AuthHandler struct {
|
||||
authSvc adminAuthSvc
|
||||
}
|
||||
|
||||
func NewAuthHandler(authSvc adminAuthSvc) *AuthHandler {
|
||||
return &AuthHandler{authSvc: authSvc}
|
||||
}
|
||||
|
||||
type loginRequest struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
func (h *AuthHandler) Login(c *gin.Context) {
|
||||
var req loginRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": "400", "message": "参数错误"})
|
||||
return
|
||||
}
|
||||
|
||||
token, err := h.authSvc.Login(c.Request.Context(), req.Username, req.Password)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": "UNAUTHORIZED", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": "0",
|
||||
"message": "ok",
|
||||
"data": gin.H{"token": token},
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": "0", "message": "ok"})
|
||||
}
|
||||
188
backend/internal/api/handler/merchant.go
Normal file
188
backend/internal/api/handler/merchant.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"pay-bridge/internal/channel"
|
||||
"pay-bridge/internal/errcode"
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
// merchantService 定义 MerchantHandler 依赖的 service 方法,便于测试时注入 mock
|
||||
type merchantService interface {
|
||||
CreateMerchantForApp(ctx context.Context, appID string, m *model.Merchant) error
|
||||
GetMerchantForApp(ctx context.Context, appID, merchantID string) (*model.Merchant, error)
|
||||
ListMerchantsForApp(ctx context.Context, appID string, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error)
|
||||
ApplyForApp(ctx context.Context, appID, merchantID, channelCode string, bizContent map[string]any) (string, error)
|
||||
QueryAuditStatusForApp(ctx context.Context, appID, merchantID string) (*model.MerchantApplication, error)
|
||||
UploadFile(ctx context.Context, channelCode string, req *channel.UploadFileReq) (string, error)
|
||||
}
|
||||
|
||||
// MerchantHandler 业务侧商户进件接口处理器(HMAC 鉴权,appID 隔离)
|
||||
type MerchantHandler struct {
|
||||
merchantSvc merchantService
|
||||
}
|
||||
|
||||
func NewMerchantHandler(svc merchantService) *MerchantHandler {
|
||||
return &MerchantHandler{merchantSvc: svc}
|
||||
}
|
||||
|
||||
// --- 请求结构体 ---
|
||||
|
||||
type createMerchantReq struct {
|
||||
MerchantID string `json:"merchant_id" binding:"required"`
|
||||
MerchantName string `json:"merchant_name" binding:"required"`
|
||||
LicenseNo string `json:"license_no"`
|
||||
LegalPerson string `json:"legal_person"`
|
||||
BankAccount string `json:"bank_account"`
|
||||
}
|
||||
|
||||
type merchantApplyReq struct {
|
||||
ChannelCode string `json:"channel_code" binding:"required"`
|
||||
SubmitData map[string]any `json:"submit_data"`
|
||||
}
|
||||
|
||||
// --- Handlers ---
|
||||
|
||||
// CreateMerchant POST /api/v1/merchant
|
||||
func (h *MerchantHandler) CreateMerchant(c *gin.Context) {
|
||||
var req createMerchantReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
BadRequest(c, errcode.ErrInvalidParam, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
appID := c.GetString("app_id")
|
||||
m := &model.Merchant{
|
||||
MerchantID: req.MerchantID,
|
||||
MerchantName: req.MerchantName,
|
||||
LicenseNo: req.LicenseNo,
|
||||
LegalPerson: req.LegalPerson,
|
||||
BankAccount: req.BankAccount,
|
||||
}
|
||||
if err := h.merchantSvc.CreateMerchantForApp(c.Request.Context(), appID, m); err != nil {
|
||||
InternalError(c, errcode.ErrInternalDB, err.Error())
|
||||
return
|
||||
}
|
||||
OK(c, gin.H{"merchant_id": m.MerchantID})
|
||||
}
|
||||
|
||||
// ListMerchants GET /api/v1/merchant
|
||||
func (h *MerchantHandler) ListMerchants(c *gin.Context) {
|
||||
appID := c.GetString("app_id")
|
||||
status := model.MerchantStatus(c.Query("status"))
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||||
offset, _ := strconv.Atoi(c.DefaultQuery("offset", "0"))
|
||||
|
||||
merchants, err := h.merchantSvc.ListMerchantsForApp(c.Request.Context(), appID, status, limit, offset)
|
||||
if err != nil {
|
||||
InternalError(c, errcode.ErrInternalDB, err.Error())
|
||||
return
|
||||
}
|
||||
OK(c, gin.H{"list": merchants, "limit": limit, "offset": offset})
|
||||
}
|
||||
|
||||
// GetMerchant GET /api/v1/merchant/:merchantID
|
||||
func (h *MerchantHandler) GetMerchant(c *gin.Context) {
|
||||
appID := c.GetString("app_id")
|
||||
merchantID := c.Param("merchantID")
|
||||
|
||||
m, err := h.merchantSvc.GetMerchantForApp(c.Request.Context(), appID, merchantID)
|
||||
if err != nil {
|
||||
if err.Error() == errcode.ErrOrderNotFound {
|
||||
c.JSON(http.StatusNotFound, Response{Code: errcode.ErrOrderNotFound, Message: "merchant not found"})
|
||||
return
|
||||
}
|
||||
InternalError(c, errcode.ErrInternalDB, err.Error())
|
||||
return
|
||||
}
|
||||
OK(c, m)
|
||||
}
|
||||
|
||||
// UploadFile POST /api/v1/merchant/upload-file
|
||||
func (h *MerchantHandler) UploadFile(c *gin.Context) {
|
||||
channelCode := c.PostForm("channel_code")
|
||||
if channelCode == "" {
|
||||
channelCode = "HEEPAY"
|
||||
}
|
||||
fileMediaType := c.PostForm("file_media_type")
|
||||
if fileMediaType == "" {
|
||||
BadRequest(c, errcode.ErrInvalidParam, "file_media_type is required")
|
||||
return
|
||||
}
|
||||
|
||||
fileHeader, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
BadRequest(c, errcode.ErrInvalidParam, "file is required: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
f, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
InternalError(c, errcode.ErrInternalSystem, "open file: "+err.Error())
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
content, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
InternalError(c, errcode.ErrInternalSystem, "read file: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
fileID, err := h.merchantSvc.UploadFile(c.Request.Context(), channelCode, &channel.UploadFileReq{
|
||||
FileContent: content,
|
||||
FileName: filepath.Base(fileHeader.Filename),
|
||||
FileMediaType: fileMediaType,
|
||||
})
|
||||
if err != nil {
|
||||
InternalError(c, errcode.ErrInternalSystem, err.Error())
|
||||
return
|
||||
}
|
||||
OK(c, gin.H{"file_id": fileID})
|
||||
}
|
||||
|
||||
// Apply POST /api/v1/merchant/:merchantID/apply
|
||||
func (h *MerchantHandler) Apply(c *gin.Context) {
|
||||
appID := c.GetString("app_id")
|
||||
merchantID := c.Param("merchantID")
|
||||
|
||||
var req merchantApplyReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
BadRequest(c, errcode.ErrInvalidParam, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
applicationID, err := h.merchantSvc.ApplyForApp(c.Request.Context(), appID, merchantID, req.ChannelCode, req.SubmitData)
|
||||
if err != nil {
|
||||
if err.Error() == errcode.ErrOrderNotFound {
|
||||
c.JSON(http.StatusNotFound, Response{Code: errcode.ErrOrderNotFound, Message: "merchant not found"})
|
||||
return
|
||||
}
|
||||
InternalError(c, errcode.ErrInternalSystem, err.Error())
|
||||
return
|
||||
}
|
||||
OK(c, gin.H{"application_id": applicationID})
|
||||
}
|
||||
|
||||
// QueryAuditStatus GET /api/v1/merchant/:merchantID/audit
|
||||
func (h *MerchantHandler) QueryAuditStatus(c *gin.Context) {
|
||||
appID := c.GetString("app_id")
|
||||
merchantID := c.Param("merchantID")
|
||||
|
||||
app, err := h.merchantSvc.QueryAuditStatusForApp(c.Request.Context(), appID, merchantID)
|
||||
if err != nil {
|
||||
if err.Error() == errcode.ErrOrderNotFound {
|
||||
c.JSON(http.StatusNotFound, Response{Code: errcode.ErrOrderNotFound, Message: "merchant not found"})
|
||||
return
|
||||
}
|
||||
InternalError(c, errcode.ErrInternalSystem, err.Error())
|
||||
return
|
||||
}
|
||||
OK(c, app)
|
||||
}
|
||||
216
backend/internal/api/handler/merchant_test.go
Normal file
216
backend/internal/api/handler/merchant_test.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"pay-bridge/internal/channel"
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// mockMerchantSvc 实现 merchantService interface
|
||||
type mockMerchantSvc struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockMerchantSvc) CreateMerchantForApp(ctx context.Context, appID string, merchant *model.Merchant) error {
|
||||
return m.Called(ctx, appID, merchant).Error(0)
|
||||
}
|
||||
func (m *mockMerchantSvc) GetMerchantForApp(ctx context.Context, appID, merchantID string) (*model.Merchant, error) {
|
||||
args := m.Called(ctx, appID, merchantID)
|
||||
v, _ := args.Get(0).(*model.Merchant)
|
||||
return v, args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantSvc) ListMerchantsForApp(ctx context.Context, appID string, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
|
||||
args := m.Called(ctx, appID, status, limit, offset)
|
||||
return args.Get(0).([]*model.Merchant), args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantSvc) ApplyForApp(ctx context.Context, appID, merchantID, channelCode string, bizContent map[string]any) (string, error) {
|
||||
args := m.Called(ctx, appID, merchantID, channelCode, bizContent)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantSvc) QueryAuditStatusForApp(ctx context.Context, appID, merchantID string) (*model.MerchantApplication, error) {
|
||||
args := m.Called(ctx, appID, merchantID)
|
||||
v, _ := args.Get(0).(*model.MerchantApplication)
|
||||
return v, args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantSvc) UploadFile(ctx context.Context, channelCode string, req *channel.UploadFileReq) (string, error) {
|
||||
args := m.Called(ctx, channelCode, req)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
// newMerchantTestRouter 构建测试路由,注入固定 app_id 模拟鉴权
|
||||
func newMerchantTestRouter(svc *mockMerchantSvc) *gin.Engine {
|
||||
r := gin.New()
|
||||
h := &MerchantHandler{merchantSvc: svc}
|
||||
|
||||
auth := func(c *gin.Context) {
|
||||
c.Set("app_id", "app_test")
|
||||
c.Next()
|
||||
}
|
||||
|
||||
g := r.Group("/api/v1/merchant", auth)
|
||||
g.POST("", h.CreateMerchant)
|
||||
g.GET("", h.ListMerchants)
|
||||
g.GET("/:merchantID", h.GetMerchant)
|
||||
g.POST("/:merchantID/apply", h.Apply)
|
||||
g.GET("/:merchantID/audit", h.QueryAuditStatus)
|
||||
return r
|
||||
}
|
||||
|
||||
// --- CreateMerchant ---
|
||||
|
||||
func TestCreateMerchant_OK(t *testing.T) {
|
||||
svc := new(mockMerchantSvc)
|
||||
svc.On("CreateMerchantForApp", mock.Anything, "app_test", mock.MatchedBy(func(m *model.Merchant) bool {
|
||||
return m.MerchantID == "m001" && m.MerchantName == "测试公司"
|
||||
})).Return(nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant",
|
||||
strings.NewReader(`{"merchant_id":"m001","merchant_name":"测试公司"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
newMerchantTestRouter(svc).ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
var resp map[string]any
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
assert.Equal(t, "0", resp["code"])
|
||||
assert.Equal(t, "m001", resp["data"].(map[string]any)["merchant_id"])
|
||||
svc.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestCreateMerchant_MissingName(t *testing.T) {
|
||||
svc := new(mockMerchantSvc)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant",
|
||||
strings.NewReader(`{"merchant_id":"m001"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
newMerchantTestRouter(svc).ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
svc.AssertNotCalled(t, "CreateMerchantForApp")
|
||||
}
|
||||
|
||||
func TestCreateMerchant_MissingID(t *testing.T) {
|
||||
svc := new(mockMerchantSvc)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant",
|
||||
strings.NewReader(`{"merchant_name":"测试公司"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
newMerchantTestRouter(svc).ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
// --- GetMerchant ---
|
||||
|
||||
func TestGetMerchant_OK(t *testing.T) {
|
||||
svc := new(mockMerchantSvc)
|
||||
svc.On("GetMerchantForApp", mock.Anything, "app_test", "m001").
|
||||
Return(&model.Merchant{MerchantID: "m001", AppID: "app_test"}, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
newMerchantTestRouter(svc).ServeHTTP(w,
|
||||
httptest.NewRequest(http.MethodGet, "/api/v1/merchant/m001", nil))
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
svc.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestGetMerchant_NotFound(t *testing.T) {
|
||||
svc := new(mockMerchantSvc)
|
||||
svc.On("GetMerchantForApp", mock.Anything, "app_test", "m999").
|
||||
Return((*model.Merchant)(nil), errors.New("30001"))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
newMerchantTestRouter(svc).ServeHTTP(w,
|
||||
httptest.NewRequest(http.MethodGet, "/api/v1/merchant/m999", nil))
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
func TestGetMerchant_WrongApp(t *testing.T) {
|
||||
svc := new(mockMerchantSvc)
|
||||
svc.On("GetMerchantForApp", mock.Anything, "app_test", "other_m").
|
||||
Return((*model.Merchant)(nil), errors.New("30001"))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
newMerchantTestRouter(svc).ServeHTTP(w,
|
||||
httptest.NewRequest(http.MethodGet, "/api/v1/merchant/other_m", nil))
|
||||
|
||||
// 跨 app 访问应返回 404,而不是 403,避免信息泄露
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
// --- ListMerchants ---
|
||||
|
||||
func TestListMerchants_DefaultPagination(t *testing.T) {
|
||||
svc := new(mockMerchantSvc)
|
||||
svc.On("ListMerchantsForApp", mock.Anything, "app_test", model.MerchantStatus(""), 20, 0).
|
||||
Return([]*model.Merchant{}, nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
newMerchantTestRouter(svc).ServeHTTP(w,
|
||||
httptest.NewRequest(http.MethodGet, "/api/v1/merchant", nil))
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
svc.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// --- Apply ---
|
||||
|
||||
func TestApply_OK(t *testing.T) {
|
||||
svc := new(mockMerchantSvc)
|
||||
svc.On("ApplyForApp", mock.Anything, "app_test", "m001", "HEEPAY", mock.Anything).
|
||||
Return("APP123", nil)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant/m001/apply",
|
||||
strings.NewReader(`{"channel_code":"HEEPAY","submit_data":{"name":"测试公司"}}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
newMerchantTestRouter(svc).ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
var resp map[string]any
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
assert.Equal(t, "APP123", resp["data"].(map[string]any)["application_id"])
|
||||
}
|
||||
|
||||
func TestApply_MissingChannelCode(t *testing.T) {
|
||||
svc := new(mockMerchantSvc)
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant/m001/apply",
|
||||
strings.NewReader(`{"submit_data":{}}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
newMerchantTestRouter(svc).ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
svc.AssertNotCalled(t, "ApplyForApp")
|
||||
}
|
||||
|
||||
func TestApply_MerchantNotBelongToApp(t *testing.T) {
|
||||
svc := new(mockMerchantSvc)
|
||||
svc.On("ApplyForApp", mock.Anything, "app_test", "m_other", "HEEPAY", mock.Anything).
|
||||
Return("", errors.New("30001"))
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/merchant/m_other/apply",
|
||||
strings.NewReader(`{"channel_code":"HEEPAY"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
newMerchantTestRouter(svc).ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
242
backend/internal/api/handler/pay.go
Normal file
242
backend/internal/api/handler/pay.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"pay-bridge/internal/errcode"
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/service"
|
||||
)
|
||||
|
||||
// PayHandler 支付相关 Handler
|
||||
type PayHandler struct {
|
||||
tradeSvc *service.TradeService
|
||||
refundSvc *service.RefundService
|
||||
}
|
||||
|
||||
func NewPayHandler(tradeSvc *service.TradeService, refundSvc *service.RefundService) *PayHandler {
|
||||
return &PayHandler{tradeSvc: tradeSvc, refundSvc: refundSvc}
|
||||
}
|
||||
|
||||
type unifiedOrderReq struct {
|
||||
ChannelCode string `json:"channel_code"`
|
||||
MerchantOrderNo string `json:"merchant_order_no" binding:"required"`
|
||||
PayMethod model.PayMethod `json:"pay_method" binding:"required"`
|
||||
Amount int64 `json:"amount" binding:"required,min=1"`
|
||||
ProfitSharingAmount int64 `json:"profit_sharing_amount"`
|
||||
Subject string `json:"subject" binding:"required"`
|
||||
NotifyURL string `json:"notify_url" binding:"required,url"`
|
||||
ExpireMinutes int `json:"expire_minutes"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
MerchantID string `json:"merchant_id"` // 可选,指定收款商户
|
||||
}
|
||||
|
||||
type closeOrderReq struct {
|
||||
TradeNo string `json:"trade_no" binding:"required"`
|
||||
}
|
||||
|
||||
type refundReq struct {
|
||||
TradeNo string `json:"trade_no" binding:"required"`
|
||||
RefundAmount int64 `json:"refund_amount" binding:"required,min=1"`
|
||||
Reason string `json:"reason"`
|
||||
NotifyURL string `json:"notify_url"`
|
||||
}
|
||||
|
||||
// UnifiedOrder POST /api/v1/pay/unified-order
|
||||
func (h *PayHandler) UnifiedOrder(c *gin.Context) {
|
||||
var req unifiedOrderReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
BadRequest(c, errcode.ErrInvalidParam, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
appID := c.GetString("app_id")
|
||||
resp, err := h.tradeSvc.CreateOrder(c.Request.Context(), &service.CreateOrderReq{
|
||||
AppID: appID,
|
||||
ChannelCode: req.ChannelCode,
|
||||
MerchantOrderNo: req.MerchantOrderNo,
|
||||
PayMethod: req.PayMethod,
|
||||
Amount: req.Amount,
|
||||
ProfitSharingAmount: req.ProfitSharingAmount,
|
||||
Subject: req.Subject,
|
||||
NotifyURL: req.NotifyURL,
|
||||
ExpireMinutes: req.ExpireMinutes,
|
||||
Extra: req.Extra,
|
||||
MerchantID: req.MerchantID,
|
||||
})
|
||||
if err != nil {
|
||||
handleBizError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
OK(c, gin.H{
|
||||
"trade_no": resp.TradeNo,
|
||||
"pay_credential": resp.PayCredential,
|
||||
"is_idempotent": resp.IsIdempotent,
|
||||
})
|
||||
}
|
||||
|
||||
// QueryOrder GET /api/v1/pay/query/:tradeNo
|
||||
func (h *PayHandler) QueryOrder(c *gin.Context) {
|
||||
tradeNo := c.Param("tradeNo")
|
||||
if tradeNo == "" {
|
||||
BadRequest(c, errcode.ErrMissingParam, "trade_no is required")
|
||||
return
|
||||
}
|
||||
|
||||
appID := c.GetString("app_id")
|
||||
order, err := h.tradeSvc.QueryOrder(c.Request.Context(), appID, tradeNo)
|
||||
if err != nil {
|
||||
handleBizError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
OK(c, gin.H{
|
||||
"trade_no": order.TradeNo,
|
||||
"merchant_order_no": order.MerchantOrderNo,
|
||||
"pay_method": order.PayMethod,
|
||||
"amount": order.Amount,
|
||||
"status": order.Status,
|
||||
"channel_trade_no": order.ChannelTradeNo,
|
||||
"pay_time": order.PayTime,
|
||||
"created_at": order.CreatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
// CloseOrder POST /api/v1/pay/close
|
||||
func (h *PayHandler) CloseOrder(c *gin.Context) {
|
||||
var req closeOrderReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
BadRequest(c, errcode.ErrInvalidParam, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
appID := c.GetString("app_id")
|
||||
if err := h.tradeSvc.CloseOrder(c.Request.Context(), appID, req.TradeNo); err != nil {
|
||||
handleBizError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
OK(c, nil)
|
||||
}
|
||||
|
||||
// Refund POST /api/v1/pay/refund
|
||||
func (h *PayHandler) Refund(c *gin.Context) {
|
||||
var req refundReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
BadRequest(c, errcode.ErrInvalidParam, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
appID := c.GetString("app_id")
|
||||
refund, err := h.refundSvc.CreateRefund(c.Request.Context(), &service.CreateRefundReq{
|
||||
AppID: appID,
|
||||
TradeNo: req.TradeNo,
|
||||
RefundAmount: req.RefundAmount,
|
||||
Reason: req.Reason,
|
||||
NotifyURL: req.NotifyURL,
|
||||
})
|
||||
if err != nil {
|
||||
handleBizError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
OK(c, gin.H{
|
||||
"refund_no": refund.RefundNo,
|
||||
"trade_no": refund.TradeNo,
|
||||
"refund_amount": refund.RefundAmount,
|
||||
"status": refund.Status,
|
||||
"channel_refund_no": refund.ChannelRefundNo,
|
||||
})
|
||||
}
|
||||
|
||||
// QueryRefund GET /api/v1/pay/refund/query/:refundNo
|
||||
func (h *PayHandler) QueryRefund(c *gin.Context) {
|
||||
refundNo := c.Param("refundNo")
|
||||
if refundNo == "" {
|
||||
BadRequest(c, errcode.ErrMissingParam, "refund_no is required")
|
||||
return
|
||||
}
|
||||
|
||||
appID := c.GetString("app_id")
|
||||
refund, err := h.refundSvc.QueryRefund(c.Request.Context(), appID, refundNo)
|
||||
if err != nil {
|
||||
handleBizError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
OK(c, gin.H{
|
||||
"refund_no": refund.RefundNo,
|
||||
"trade_no": refund.TradeNo,
|
||||
"refund_amount": refund.RefundAmount,
|
||||
"status": refund.Status,
|
||||
"channel_refund_no": refund.ChannelRefundNo,
|
||||
"refund_time": refund.RefundTime,
|
||||
})
|
||||
}
|
||||
|
||||
// handleBizError 将业务错误映射到 HTTP 响应
|
||||
func handleBizError(c *gin.Context, err error) {
|
||||
code := err.Error()
|
||||
msg := errcode.Message(code)
|
||||
if msg == "未知错误" {
|
||||
msg = err.Error()
|
||||
code = errcode.ErrInternalSystem
|
||||
}
|
||||
|
||||
switch code {
|
||||
case errcode.ErrInvalidParam, errcode.ErrMissingParam, errcode.ErrInvalidPayMethod, errcode.ErrInvalidAmount:
|
||||
BadRequest(c, code, msg)
|
||||
case errcode.ErrUnauthorized, errcode.ErrAppNotFound:
|
||||
Unauthorized(c, code, msg)
|
||||
case errcode.ErrPermissionDenied:
|
||||
Forbidden(c, code, msg)
|
||||
case errcode.ErrOrderNotFound, errcode.ErrOrderAlreadyPaid, errcode.ErrOrderClosed,
|
||||
errcode.ErrRefundAmountExceed, errcode.ErrSharingAmountExceed, errcode.ErrOrderNotPaid,
|
||||
errcode.ErrSharingNotConfig, errcode.ErrSharingFeeExceed, errcode.ErrRefundNotFound:
|
||||
UnprocessableEntity(c, code, msg)
|
||||
case errcode.ErrChannelCreateFail, errcode.ErrChannelRefundFail,
|
||||
errcode.ErrChannelTimeout, errcode.ErrChannelNotSupport, errcode.ErrChannelVerifyFail:
|
||||
BadGateway(c, code, msg)
|
||||
default:
|
||||
_ = errors.New(code) // suppress unused
|
||||
InternalError(c, errcode.ErrInternalSystem, errcode.Message(errcode.ErrInternalSystem))
|
||||
}
|
||||
}
|
||||
|
||||
// NotifyHandler 渠道回调 Handler
|
||||
type NotifyHandler struct {
|
||||
tradeSvc *service.TradeService
|
||||
}
|
||||
|
||||
func NewNotifyHandler(tradeSvc *service.TradeService) *NotifyHandler {
|
||||
return &NotifyHandler{tradeSvc: tradeSvc}
|
||||
}
|
||||
|
||||
// PaymentCallback POST /api/v1/notify/payment/:channelCode
|
||||
func (h *NotifyHandler) PaymentCallback(c *gin.Context) {
|
||||
channelCode := c.Param("channelCode")
|
||||
var rawBody []byte
|
||||
if v, exists := c.Get("raw_body"); exists {
|
||||
rawBody, _ = v.([]byte)
|
||||
}
|
||||
|
||||
result, err := h.tradeSvc.HandleUpstreamNotify(c.Request.Context(), channelCode, rawBody, headersMap(c))
|
||||
if err != nil {
|
||||
c.String(http.StatusOK, "fail")
|
||||
return
|
||||
}
|
||||
c.String(http.StatusOK, result)
|
||||
}
|
||||
|
||||
func headersMap(c *gin.Context) map[string]string {
|
||||
m := make(map[string]string)
|
||||
for k, v := range c.Request.Header {
|
||||
if len(v) > 0 {
|
||||
m[k] = v[0]
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
73
backend/internal/api/handler/response.go
Normal file
73
backend/internal/api/handler/response.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Response 统一响应格式
|
||||
type Response struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data any `json:"data,omitempty"`
|
||||
TraceID string `json:"trace_id,omitempty"`
|
||||
}
|
||||
|
||||
// OK 成功响应
|
||||
func OK(c *gin.Context, data any) {
|
||||
c.JSON(http.StatusOK, Response{
|
||||
Code: "0",
|
||||
Message: "success",
|
||||
Data: data,
|
||||
TraceID: traceID(c),
|
||||
})
|
||||
}
|
||||
|
||||
// Fail 失败响应
|
||||
func Fail(c *gin.Context, httpStatus int, code, message string) {
|
||||
c.JSON(httpStatus, Response{
|
||||
Code: code,
|
||||
Message: message,
|
||||
TraceID: traceID(c),
|
||||
})
|
||||
}
|
||||
|
||||
// BadRequest 400
|
||||
func BadRequest(c *gin.Context, code, message string) {
|
||||
Fail(c, http.StatusBadRequest, code, message)
|
||||
}
|
||||
|
||||
// Unauthorized 401
|
||||
func Unauthorized(c *gin.Context, code, message string) {
|
||||
Fail(c, http.StatusUnauthorized, code, message)
|
||||
}
|
||||
|
||||
// Forbidden 403
|
||||
func Forbidden(c *gin.Context, code, message string) {
|
||||
Fail(c, http.StatusForbidden, code, message)
|
||||
}
|
||||
|
||||
// UnprocessableEntity 422(业务规则错误)
|
||||
func UnprocessableEntity(c *gin.Context, code, message string) {
|
||||
Fail(c, http.StatusUnprocessableEntity, code, message)
|
||||
}
|
||||
|
||||
// InternalError 500
|
||||
func InternalError(c *gin.Context, code, message string) {
|
||||
Fail(c, http.StatusInternalServerError, code, message)
|
||||
}
|
||||
|
||||
// BadGateway 502(渠道错误)
|
||||
func BadGateway(c *gin.Context, code, message string) {
|
||||
Fail(c, http.StatusBadGateway, code, message)
|
||||
}
|
||||
|
||||
func traceID(c *gin.Context) string {
|
||||
if id, exists := c.Get("trace_id"); exists {
|
||||
if s, ok := id.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
99
backend/internal/api/middleware/auth.go
Normal file
99
backend/internal/api/middleware/auth.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"pay-bridge/internal/api/handler"
|
||||
"pay-bridge/internal/errcode"
|
||||
)
|
||||
|
||||
// AppLoader 根据 appId 加载 app 信息的接口
|
||||
type AppLoader interface {
|
||||
GetAppSecret(ctx context.Context, appID string) (string, error)
|
||||
}
|
||||
|
||||
// Auth 鉴权中间件
|
||||
// 请求头:X-App-Id、X-Timestamp、X-Sign
|
||||
// 签名算法:HMAC-SHA256(appId + timestamp + body, appSecret)
|
||||
func Auth(loader AppLoader) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
appID := c.GetHeader("X-App-Id")
|
||||
timestamp := c.GetHeader("X-Timestamp")
|
||||
sign := c.GetHeader("X-Sign")
|
||||
|
||||
if appID == "" || timestamp == "" || sign == "" {
|
||||
handler.Unauthorized(c, errcode.ErrUnauthorized, errcode.Message(errcode.ErrUnauthorized))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 时间戳防重放(5分钟内有效)
|
||||
ts, err := strconv.ParseInt(timestamp, 10, 64)
|
||||
if err != nil || abs(time.Now().Unix()-ts) > 300 {
|
||||
handler.Unauthorized(c, errcode.ErrUnauthorized, "请求已过期")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
appSecret, err := loader.GetAppSecret(c.Request.Context(), appID)
|
||||
if err != nil {
|
||||
handler.Unauthorized(c, errcode.ErrAppNotFound, errcode.Message(errcode.ErrAppNotFound))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 读取 body(注意:body 只能读一次,需要提前 cache)
|
||||
body := bodyFromContext(c)
|
||||
|
||||
expectedSign := sign256(appID+timestamp+string(body), appSecret)
|
||||
if !hmac.Equal([]byte(expectedSign), []byte(sign)) {
|
||||
handler.Unauthorized(c, errcode.ErrUnauthorized, errcode.Message(errcode.ErrUnauthorized))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("app_id", appID)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ChannelCallback 渠道回调鉴权(由渠道适配器验签,此中间件只做基础检查)
|
||||
func ChannelCallback() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
channelCode := c.Param("channelCode")
|
||||
if channelCode == "" {
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func sign256(payload, secret string) string {
|
||||
h := hmac.New(sha256.New, []byte(secret))
|
||||
h.Write([]byte(payload))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
func abs(n int64) int64 {
|
||||
if n < 0 {
|
||||
return -n
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func bodyFromContext(c *gin.Context) []byte {
|
||||
if v, exists := c.Get("raw_body"); exists {
|
||||
if b, ok := v.([]byte); ok {
|
||||
return b
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
21
backend/internal/api/middleware/body_cache.go
Normal file
21
backend/internal/api/middleware/body_cache.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CacheBody 缓存 request body(body 只能读一次,中间件提前读取并缓存)
|
||||
func CacheBody() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
body = []byte{}
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
|
||||
c.Set("raw_body", body)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
48
backend/internal/api/middleware/jwt_auth.go
Normal file
48
backend/internal/api/middleware/jwt_auth.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// TokenParser 解析 JWT token 的接口
|
||||
type TokenParser interface {
|
||||
ParseToken(tokenStr string) (string, error)
|
||||
}
|
||||
|
||||
// JWTAuth 管理后台 JWT 鉴权中间件
|
||||
func JWTAuth(parser TokenParser) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"code": "401",
|
||||
"message": "未登录,请先登录",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"code": "401",
|
||||
"message": "Token 格式错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
username, err := parser.ParseToken(parts[1])
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"code": "401",
|
||||
"message": "Token 无效或已过期",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("username", username)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
23
backend/internal/api/middleware/trace.go
Normal file
23
backend/internal/api/middleware/trace.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Trace 注入 trace_id 中间件
|
||||
func Trace() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
traceID := c.GetHeader("X-Trace-Id")
|
||||
if traceID == "" {
|
||||
b := make([]byte, 8)
|
||||
_, _ = rand.Read(b)
|
||||
traceID = hex.EncodeToString(b)
|
||||
}
|
||||
c.Set("trace_id", traceID)
|
||||
c.Header("X-Trace-Id", traceID)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
107
backend/internal/api/router.go
Normal file
107
backend/internal/api/router.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"pay-bridge/internal/api/handler"
|
||||
"pay-bridge/internal/api/middleware"
|
||||
"pay-bridge/internal/app"
|
||||
)
|
||||
|
||||
// SetupRouter 注册所有路由
|
||||
func SetupRouter(a *app.App) *gin.Engine {
|
||||
payHandler := handler.NewPayHandler(a.TradeSvc, a.RefundSvc)
|
||||
notifyHandler := handler.NewNotifyHandler(a.TradeSvc)
|
||||
adminHandler := handler.NewAdminHandler(a.MatchSvc, a.MerchantSvc, a.ReconSvc, a.ChannelSvc, a.AppSvc)
|
||||
authHandler := handler.NewAuthHandler(a.AdminAuthSvc)
|
||||
merchantHandler := handler.NewMerchantHandler(a.MerchantSvc)
|
||||
|
||||
r := gin.New()
|
||||
r.Use(gin.Recovery())
|
||||
r.Use(middleware.Trace())
|
||||
r.Use(middleware.CacheBody())
|
||||
|
||||
// 健康检查
|
||||
r.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
})
|
||||
|
||||
// 上游渠道回调(渠道验签,不走 appId 鉴权)
|
||||
notify := r.Group("/api/v1/notify")
|
||||
{
|
||||
notify.POST("/payment/:channelCode", notifyHandler.PaymentCallback)
|
||||
}
|
||||
|
||||
// 下游系统调用接口(appId+appSecret 签名鉴权)
|
||||
v1 := r.Group("/api/v1", middleware.Auth(a.AppSvc))
|
||||
{
|
||||
pay := v1.Group("/pay")
|
||||
{
|
||||
pay.POST("/unified-order", payHandler.UnifiedOrder)
|
||||
pay.GET("/query/:tradeNo", payHandler.QueryOrder)
|
||||
pay.POST("/close", payHandler.CloseOrder)
|
||||
pay.POST("/refund", payHandler.Refund)
|
||||
pay.GET("/refund/query/:refundNo", payHandler.QueryRefund)
|
||||
}
|
||||
|
||||
merchantGroup := v1.Group("/merchant")
|
||||
{
|
||||
merchantGroup.POST("", merchantHandler.CreateMerchant)
|
||||
merchantGroup.GET("", merchantHandler.ListMerchants)
|
||||
merchantGroup.POST("/upload-file", merchantHandler.UploadFile)
|
||||
merchantGroup.GET("/:merchantID", merchantHandler.GetMerchant)
|
||||
merchantGroup.POST("/:merchantID/apply", merchantHandler.Apply)
|
||||
merchantGroup.GET("/:merchantID/audit", merchantHandler.QueryAuditStatus)
|
||||
}
|
||||
}
|
||||
|
||||
// 管理后台接口
|
||||
adminPublic := r.Group("/api/v1/admin")
|
||||
{
|
||||
adminPublic.POST("/login", authHandler.Login)
|
||||
}
|
||||
|
||||
admin := r.Group("/api/v1/admin", middleware.JWTAuth(a.AdminAuthSvc))
|
||||
{
|
||||
admin.POST("/logout", authHandler.Logout)
|
||||
|
||||
// 应用管理
|
||||
appGroup := admin.Group("/app")
|
||||
{
|
||||
appGroup.POST("", adminHandler.CreateApp)
|
||||
appGroup.GET("", adminHandler.ListApps)
|
||||
appGroup.POST("/:appID/disable", adminHandler.DisableApp)
|
||||
appGroup.POST("/:appID/enable", adminHandler.EnableApp)
|
||||
appGroup.POST("/:appID/reset-secret", adminHandler.ResetAppSecret)
|
||||
}
|
||||
|
||||
// 收款匹配
|
||||
match := admin.Group("/match")
|
||||
{
|
||||
match.GET("/pending", adminHandler.ListPendingMatches)
|
||||
match.POST("/bind", adminHandler.ManualBindOrder)
|
||||
}
|
||||
|
||||
// 商户管理
|
||||
merchant := admin.Group("/merchant")
|
||||
{
|
||||
merchant.POST("", adminHandler.CreateMerchant)
|
||||
merchant.GET("", adminHandler.ListMerchants)
|
||||
merchant.POST("/upload-file", adminHandler.UploadMerchantFile)
|
||||
merchant.GET("/:merchantID", adminHandler.GetMerchant)
|
||||
merchant.POST("/:merchantID/freeze", adminHandler.FreezeMerchant)
|
||||
merchant.POST("/:merchantID/unfreeze", adminHandler.UnfreezeMerchant)
|
||||
merchant.POST("/:merchantID/apply", adminHandler.ApplyMerchant)
|
||||
merchant.GET("/:merchantID/audit", adminHandler.QueryAuditStatus)
|
||||
}
|
||||
|
||||
// 对账管理
|
||||
recon := admin.Group("/reconciliation")
|
||||
{
|
||||
recon.POST("/trigger", adminHandler.TriggerReconciliation)
|
||||
recon.GET("/report", adminHandler.GetReconciliationReport)
|
||||
recon.GET("/report/:reportID/exceptions", adminHandler.GetReconciliationExceptions)
|
||||
}
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
126
backend/internal/app/app.go
Normal file
126
backend/internal/app/app.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"pay-bridge/internal/repository"
|
||||
"pay-bridge/internal/service"
|
||||
"pay-bridge/pkg/config"
|
||||
"pay-bridge/pkg/sequence"
|
||||
)
|
||||
|
||||
// App 应用容器,持有所有初始化完成的 service 实例
|
||||
type App struct {
|
||||
Cfg *config.Config
|
||||
|
||||
// 对外暴露供 router 使用的 service
|
||||
AdminAuthSvc *service.AdminAuthService
|
||||
AppSvc *service.AppService
|
||||
TradeSvc *service.TradeService
|
||||
RefundSvc *service.RefundService
|
||||
NotifySvc *service.NotifyService
|
||||
MatchSvc *service.PaymentMatchService
|
||||
MerchantSvc *service.MerchantService
|
||||
ReconSvc *service.ReconciliationService
|
||||
ChannelSvc *service.ChannelService
|
||||
|
||||
// 内部资源
|
||||
db *gorm.DB
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// New 初始化所有基础设施和 service,返回就绪的 App 实例
|
||||
func New(cfg *config.Config) (*App, error) {
|
||||
a := &App{Cfg: cfg}
|
||||
|
||||
if err := a.initInfra(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
a.initServices()
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// Start 启动后台任务(notify poller 等)
|
||||
func (a *App) Start(ctx context.Context) {
|
||||
a.NotifySvc.StartPoller(ctx, a.Cfg.Notify.PollerInterval, a.Cfg.Notify.PollerBatch)
|
||||
}
|
||||
|
||||
// Shutdown 优雅关闭:关闭 DB 和 Redis 连接
|
||||
func (a *App) Shutdown(ctx context.Context) {
|
||||
if a.rdb != nil {
|
||||
if err := a.rdb.Close(); err != nil {
|
||||
slog.Error("redis close error", "err", err)
|
||||
}
|
||||
}
|
||||
if a.db != nil {
|
||||
sqlDB, err := a.db.DB()
|
||||
if err == nil {
|
||||
if err := sqlDB.Close(); err != nil {
|
||||
slog.Error("db close error", "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// initInfra 初始化 DB 和 Redis
|
||||
func (a *App) initInfra() error {
|
||||
db, err := config.NewDB(a.Cfg.Database)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
a.db = db
|
||||
|
||||
rdb, err := config.NewRedis(a.Cfg.Redis)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
a.rdb = rdb
|
||||
return nil
|
||||
}
|
||||
|
||||
// initServices 按依赖顺序构建 repo → service
|
||||
func (a *App) initServices() {
|
||||
encKey := a.Cfg.Security.FieldEncryptKey
|
||||
|
||||
// Repositories
|
||||
adminUserRepo := repository.NewAdminUserRepository(a.db)
|
||||
appRepo := repository.NewAppRepository(a.db)
|
||||
tradeRepo := repository.NewTradeOrderRepository(a.db)
|
||||
refundRepo := repository.NewRefundOrderRepository(a.db)
|
||||
notifyRepo := repository.NewNotifyLogRepository(a.db)
|
||||
channelCfgRepo := repository.NewChannelConfigRepository(a.db)
|
||||
seqRepo := repository.NewSequenceRepository(a.db)
|
||||
profitSharingRepo := repository.NewProfitSharingRepository(a.db)
|
||||
serviceFeeRepo := repository.NewServiceFeeRepository(a.db)
|
||||
matchRepo := repository.NewPaymentMatchRepository(a.db)
|
||||
merchantRepo := repository.NewMerchantRepository(a.db)
|
||||
wechatRepo := repository.NewWechatRepository(a.db)
|
||||
reconRepo := repository.NewReconciliationRepository(a.db)
|
||||
|
||||
// JWT expire hours
|
||||
jwtExpireHours := a.Cfg.JWT.ExpireHours
|
||||
if jwtExpireHours == 0 {
|
||||
jwtExpireHours = 24
|
||||
}
|
||||
|
||||
// Services
|
||||
a.AdminAuthSvc = service.NewAdminAuthService(adminUserRepo, a.Cfg.JWT.Secret, jwtExpireHours)
|
||||
a.ChannelSvc = service.NewChannelService(channelCfgRepo, encKey, a.Cfg.Channels)
|
||||
seqSvc := sequence.NewService(seqRepo)
|
||||
a.AppSvc = service.NewAppService(appRepo, encKey)
|
||||
a.NotifySvc = service.NewNotifyService(notifyRepo, tradeRepo, a.Cfg.Notify.HTTPTimeout)
|
||||
a.MerchantSvc = service.NewMerchantService(merchantRepo, a.ChannelSvc)
|
||||
a.TradeSvc = service.NewTradeService(tradeRepo, a.ChannelSvc, seqSvc, a.rdb, a.NotifySvc, a.MerchantSvc)
|
||||
a.RefundSvc = service.NewRefundService(refundRepo, tradeRepo, a.ChannelSvc, seqSvc, a.NotifySvc)
|
||||
a.MatchSvc = service.NewPaymentMatchService(matchRepo, tradeRepo, a.NotifySvc, a.TradeSvc)
|
||||
a.ReconSvc = service.NewReconciliationService(reconRepo, tradeRepo, a.ChannelSvc, appRepo)
|
||||
|
||||
// 以下 service 目前未直接暴露给 router,但已初始化供将来扩展使用
|
||||
_ = service.NewProfitSharingService(profitSharingRepo, tradeRepo, a.ChannelSvc, seqSvc, a.rdb)
|
||||
_ = service.NewServiceFeeService(serviceFeeRepo, tradeRepo, a.ChannelSvc)
|
||||
_ = service.NewWechatService(wechatRepo, encKey)
|
||||
}
|
||||
65
backend/internal/channel/factory.go
Normal file
65
backend/internal/channel/factory.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrChannelNotFound = errors.New("channel not found")
|
||||
ErrNotSupported = errors.New("operation not supported by this channel")
|
||||
)
|
||||
|
||||
var globalRegistry = &Registry{}
|
||||
|
||||
// Registry 渠道注册表
|
||||
type Registry struct {
|
||||
mu sync.RWMutex
|
||||
factories map[string]ChannelFactory
|
||||
}
|
||||
|
||||
// Register 注册渠道工厂(供各渠道包在 init() 中调用)
|
||||
func Register(channelCode string, factory ChannelFactory) {
|
||||
globalRegistry.Register(channelCode, factory)
|
||||
}
|
||||
|
||||
func (r *Registry) Register(channelCode string, factory ChannelFactory) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.factories == nil {
|
||||
r.factories = make(map[string]ChannelFactory)
|
||||
}
|
||||
r.factories[channelCode] = factory
|
||||
}
|
||||
|
||||
// Get 根据渠道码、商户配置和网关地址获取渠道实例
|
||||
func Get(channelCode string, config *model.ChannelConfig, urls URLs) (PaymentChannel, error) {
|
||||
return globalRegistry.Get(channelCode, config, urls)
|
||||
}
|
||||
|
||||
func (r *Registry) Get(channelCode string, config *model.ChannelConfig, urls URLs) (PaymentChannel, error) {
|
||||
r.mu.RLock()
|
||||
factory, ok := r.factories[channelCode]
|
||||
r.mu.RUnlock()
|
||||
if !ok {
|
||||
return nil, ErrChannelNotFound
|
||||
}
|
||||
return factory(config, urls), nil
|
||||
}
|
||||
|
||||
// List 列出已注册的渠道码
|
||||
func List() []string {
|
||||
return globalRegistry.List()
|
||||
}
|
||||
|
||||
func (r *Registry) List() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
codes := make([]string, 0, len(r.factories))
|
||||
for code := range r.factories {
|
||||
codes = append(codes, code)
|
||||
}
|
||||
return codes
|
||||
}
|
||||
597
backend/internal/channel/heepay/adapter.go
Normal file
597
backend/internal/channel/heepay/adapter.go
Normal file
@@ -0,0 +1,597 @@
|
||||
package heepay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"pay-bridge/internal/channel"
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
const (
|
||||
ChannelCode = "HEEPAY"
|
||||
codeSuccess = "10000"
|
||||
)
|
||||
|
||||
// cst 中国标准时间(UTC+8),避免依赖系统 time.Local
|
||||
var cst = time.FixedZone("CST", 8*3600)
|
||||
|
||||
// Adapter 汇元支付适配器
|
||||
type Adapter struct {
|
||||
config *model.ChannelConfig
|
||||
payURL string // 支付网关地址
|
||||
merchantURL string // 进件网关地址
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func New(config *model.ChannelConfig, urls channel.URLs) channel.PaymentChannel {
|
||||
return &Adapter{
|
||||
config: config,
|
||||
payURL: urls.PayURL,
|
||||
merchantURL: urls.MerchantURL,
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
channel.Register(ChannelCode, New)
|
||||
}
|
||||
|
||||
func (a *Adapter) Code() string { return ChannelCode }
|
||||
|
||||
// CreateOrder 统一下单
|
||||
func (a *Adapter) CreateOrder(ctx context.Context, req *channel.CreateOrderReq) (*channel.CreateOrderResp, error) {
|
||||
tradeType, err := mapPayMethod(req.PayMethod)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
biz := map[string]any{
|
||||
"out_trade_no": req.TradeNo,
|
||||
"body": req.Subject,
|
||||
"total_fee": req.Amount,
|
||||
"notify_url": req.NotifyURL,
|
||||
"trade_type": tradeType,
|
||||
}
|
||||
if req.Extra != nil {
|
||||
if openid, ok := req.Extra["openid"].(string); ok {
|
||||
biz["openid"] = openid
|
||||
}
|
||||
if subAppID, ok := req.Extra["sub_appid"].(string); ok {
|
||||
biz["sub_appid"] = subAppID
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := a.post(ctx, "pay.heepay.trade.create", biz)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
channelTradeNo, _ := resp["transaction_id"].(string)
|
||||
if channelTradeNo == "" {
|
||||
channelTradeNo, _ = resp["prepay_id"].(string)
|
||||
}
|
||||
|
||||
raw, _ := json.Marshal(resp)
|
||||
return &channel.CreateOrderResp{
|
||||
ChannelTradeNo: channelTradeNo,
|
||||
PayCredential: resp,
|
||||
RawResponse: raw,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// QueryOrder 查询订单
|
||||
func (a *Adapter) QueryOrder(ctx context.Context, req *channel.QueryOrderReq) (*channel.QueryOrderResp, error) {
|
||||
biz := map[string]any{
|
||||
"out_trade_no": req.TradeNo,
|
||||
}
|
||||
if req.ChannelTradeNo != "" {
|
||||
biz["transaction_id"] = req.ChannelTradeNo
|
||||
}
|
||||
|
||||
resp, err := a.post(ctx, "pay.heepay.trade.query", biz)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &channel.QueryOrderResp{
|
||||
TradeNo: req.TradeNo,
|
||||
ChannelTradeNo: strVal(resp, "transaction_id"),
|
||||
}
|
||||
|
||||
switch strVal(resp, "trade_state") {
|
||||
case "SUCCESS":
|
||||
result.Status = model.TradeStatusPaid
|
||||
if s := strVal(resp, "pay_time"); s != "" {
|
||||
t, _ := time.ParseInLocation("2006-01-02 15:04:05", s, cst)
|
||||
result.PayTime = &t
|
||||
}
|
||||
case "CLOSED", "REVOKED":
|
||||
result.Status = model.TradeStatusClosed
|
||||
case "PAYERROR":
|
||||
result.Status = model.TradeStatusFailed
|
||||
default:
|
||||
result.Status = model.TradeStatusPaying
|
||||
}
|
||||
|
||||
if v, ok := resp["total_fee"].(float64); ok {
|
||||
result.Amount = int64(v)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// CloseOrder 关闭订单
|
||||
func (a *Adapter) CloseOrder(ctx context.Context, req *channel.CloseOrderReq) error {
|
||||
biz := map[string]any{
|
||||
"out_trade_no": req.TradeNo,
|
||||
}
|
||||
_, err := a.post(ctx, "pay.heepay.trade.close", biz)
|
||||
return err
|
||||
}
|
||||
|
||||
// Refund 发起退款
|
||||
func (a *Adapter) Refund(ctx context.Context, req *channel.RefundReq) (*channel.RefundResp, error) {
|
||||
biz := map[string]any{
|
||||
"out_trade_no": req.TradeNo,
|
||||
"out_refund_no": req.RefundNo,
|
||||
"total_fee": req.TotalAmount,
|
||||
"refund_fee": req.RefundAmount,
|
||||
"refund_desc": req.Reason,
|
||||
"notify_url": req.NotifyURL,
|
||||
}
|
||||
if req.ChannelTradeNo != "" {
|
||||
biz["transaction_id"] = req.ChannelTradeNo
|
||||
}
|
||||
|
||||
resp, err := a.post(ctx, "pay.heepay.trade.refund", biz)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &channel.RefundResp{
|
||||
RefundNo: req.RefundNo,
|
||||
ChannelRefundNo: strVal(resp, "refund_id"),
|
||||
Status: model.RefundStatusProcessing,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// QueryRefund 查询退款
|
||||
func (a *Adapter) QueryRefund(ctx context.Context, req *channel.QueryRefundReq) (*channel.QueryRefundResp, error) {
|
||||
biz := map[string]any{
|
||||
"out_refund_no": req.RefundNo,
|
||||
}
|
||||
if req.ChannelRefundNo != "" {
|
||||
biz["refund_id"] = req.ChannelRefundNo
|
||||
}
|
||||
|
||||
resp, err := a.post(ctx, "pay.heepay.trade.refund.query", biz)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &channel.QueryRefundResp{
|
||||
RefundNo: req.RefundNo,
|
||||
ChannelRefundNo: strVal(resp, "refund_id"),
|
||||
}
|
||||
|
||||
switch strVal(resp, "refund_status") {
|
||||
case "SUCCESS":
|
||||
result.Status = model.RefundStatusSuccess
|
||||
if s := strVal(resp, "refund_success_time"); s != "" {
|
||||
t, _ := time.ParseInLocation("2006-01-02 15:04:05", s, cst)
|
||||
result.RefundTime = &t
|
||||
}
|
||||
case "PROCESSING":
|
||||
result.Status = model.RefundStatusProcessing
|
||||
case "FAIL":
|
||||
result.Status = model.RefundStatusFailed
|
||||
default:
|
||||
result.Status = model.RefundStatusPending
|
||||
}
|
||||
|
||||
if v, ok := resp["refund_fee"].(float64); ok {
|
||||
result.RefundAmount = int64(v)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ExtractTradeNo 从回调 body 中提取平台交易号(在验签前调用,仅做 JSON 解析)
|
||||
func (a *Adapter) ExtractTradeNo(rawBody []byte) (string, error) {
|
||||
var outer struct {
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(rawBody, &outer); err != nil {
|
||||
return "", fmt.Errorf("heepay ExtractTradeNo: unmarshal body: %w", err)
|
||||
}
|
||||
var bizData struct {
|
||||
OutTradeNo string `json:"out_trade_no"`
|
||||
}
|
||||
if err := json.Unmarshal(outer.Data, &bizData); err != nil {
|
||||
return "", fmt.Errorf("heepay ExtractTradeNo: unmarshal data: %w", err)
|
||||
}
|
||||
if bizData.OutTradeNo == "" {
|
||||
return "", fmt.Errorf("heepay ExtractTradeNo: out_trade_no is empty")
|
||||
}
|
||||
return bizData.OutTradeNo, nil
|
||||
}
|
||||
|
||||
// VerifyNotify 验证上游回调签名并解析
|
||||
// 汇元回调的签名规则与响应验签相同:公共参数+data整体 按字典序排列,用汇元公钥验签
|
||||
func (a *Adapter) VerifyNotify(ctx context.Context, rawBody []byte, headers map[string]string) (*channel.NotifyData, error) {
|
||||
// 解析外层公共参数
|
||||
var outer struct {
|
||||
Code string `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
TradeID string `json:"trade_id"`
|
||||
Sign string `json:"sign"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(rawBody, &outer); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal notify body: %w", err)
|
||||
}
|
||||
|
||||
// 验签(data 整体作为字符串参与验签)
|
||||
verifyParams := map[string]string{
|
||||
"code": outer.Code,
|
||||
"msg": outer.Msg,
|
||||
"trade_id": outer.TradeID,
|
||||
"data": string(outer.Data),
|
||||
}
|
||||
if err := VerifyResponse(verifyParams, outer.Sign, a.config.PublicKey); err != nil {
|
||||
return nil, fmt.Errorf("verify notify sign: %w", err)
|
||||
}
|
||||
|
||||
if outer.Code != codeSuccess {
|
||||
return nil, fmt.Errorf("notify code not success: %s %s", outer.Code, outer.Msg)
|
||||
}
|
||||
|
||||
// 解析业务数据
|
||||
var bizData map[string]any
|
||||
if err := json.Unmarshal(outer.Data, &bizData); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal notify data: %w", err)
|
||||
}
|
||||
|
||||
notifyType := strVal(bizData, "notify_type")
|
||||
result := &channel.NotifyData{
|
||||
TradeNo: strVal(bizData, "out_trade_no"),
|
||||
ChannelTradeNo: strVal(bizData, "transaction_id"),
|
||||
RawData: rawBody,
|
||||
}
|
||||
|
||||
switch notifyType {
|
||||
case "payment":
|
||||
result.NotifyType = model.NotifyTypePayment
|
||||
result.Status = model.TradeStatusPaid
|
||||
if v, ok := bizData["total_fee"].(float64); ok {
|
||||
result.Amount = int64(v)
|
||||
}
|
||||
if s := strVal(bizData, "pay_time"); s != "" {
|
||||
t, _ := time.ParseInLocation("2006-01-02 15:04:05", s, cst)
|
||||
result.PayTime = &t
|
||||
}
|
||||
case "refund":
|
||||
result.NotifyType = model.NotifyTypeRefund
|
||||
result.RefundNo = strVal(bizData, "out_refund_no")
|
||||
result.RefundStatus = model.RefundStatusSuccess
|
||||
if v, ok := bizData["refund_fee"].(float64); ok {
|
||||
result.RefundAmount = int64(v)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown notify_type: %s", notifyType)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ProfitSharing 分账
|
||||
func (a *Adapter) ProfitSharing(ctx context.Context, req *channel.ProfitSharingReq) (*channel.ProfitSharingResp, error) {
|
||||
biz := map[string]any{
|
||||
"transaction_id": req.ChannelTradeNo,
|
||||
"out_order_no": req.SharingNo,
|
||||
"receivers": []map[string]any{
|
||||
{
|
||||
"type": "MERCHANT_ID",
|
||||
"account": req.ReceiverMerchantID,
|
||||
"amount": req.Amount,
|
||||
"description": "分润",
|
||||
},
|
||||
},
|
||||
}
|
||||
resp, err := a.post(ctx, "pay.heepay.trade.profitsharing", biz)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &channel.ProfitSharingResp{
|
||||
SharingNo: req.SharingNo,
|
||||
ChannelSharingNo: strVal(resp, "order_id"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RollbackProfitSharing 回退分账
|
||||
func (a *Adapter) RollbackProfitSharing(ctx context.Context, req *channel.RollbackSharingReq) error {
|
||||
biz := map[string]any{
|
||||
"transaction_id": req.TradeNo,
|
||||
"out_order_no": req.SharingNo,
|
||||
"out_return_no": req.SharingNo + "_R",
|
||||
"return_account": a.config.MerchantID,
|
||||
"return_amount": 0,
|
||||
"description": "退款回退",
|
||||
}
|
||||
_, err := a.post(ctx, "pay.heepay.trade.profitsharing.return", biz)
|
||||
return err
|
||||
}
|
||||
|
||||
// DownloadBill 下载对账账单
|
||||
func (a *Adapter) DownloadBill(ctx context.Context, req *channel.DownloadBillReq) (*channel.BillData, error) {
|
||||
biz := map[string]any{
|
||||
"bill_date": req.BillDate,
|
||||
"bill_type": "ALL",
|
||||
}
|
||||
resp, err := a.post(ctx, "pay.heepay.trade.bill.download", biz)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 账单为 CSV 内容,暂存 raw 由对账服务解析
|
||||
_ = resp
|
||||
return &channel.BillData{}, nil
|
||||
}
|
||||
|
||||
// UploadFile 上传文件到汇元进件网关(customer.file.upload)
|
||||
func (a *Adapter) UploadFile(ctx context.Context, req *channel.UploadFileReq) (*channel.UploadFileResp, error) {
|
||||
// 计算 MD5 签名
|
||||
h := md5.Sum(req.FileContent)
|
||||
fileSign := hex.EncodeToString(h[:])
|
||||
|
||||
// Base64 编码文件内容
|
||||
fileContentB64 := base64.StdEncoding.EncodeToString(req.FileContent)
|
||||
|
||||
biz := map[string]any{
|
||||
"file_content": fileContentB64,
|
||||
"file_sign": fileSign,
|
||||
"file_media_type": req.FileMediaType,
|
||||
}
|
||||
|
||||
resp, err := a.postToMerchant(ctx, "customer.file.upload", biz)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fileID := strVal(resp, "file_id")
|
||||
if fileID == "" {
|
||||
return nil, fmt.Errorf("heepay upload file: empty file_id in response")
|
||||
}
|
||||
return &channel.UploadFileResp{FileID: fileID}, nil
|
||||
}
|
||||
|
||||
// MerchantApply 企业入网申请(customer.enter.enterprise.apply)
|
||||
func (a *Adapter) MerchantApply(ctx context.Context, req *channel.MerchantApplyReq) (*channel.MerchantApplyResp, error) {
|
||||
resp, err := a.postToMerchant(ctx, "customer.enter.enterprise.apply", req.BizContent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &channel.MerchantApplyResp{
|
||||
RequestNo: strVal(resp, "request_no"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// QueryMerchantStatus 查询商户状态(customer.enter.query)
|
||||
func (a *Adapter) QueryMerchantStatus(ctx context.Context, channelMerchantID string) (*channel.MerchantStatusResp, error) {
|
||||
biz := map[string]any{
|
||||
"request_no": channelMerchantID,
|
||||
}
|
||||
resp, err := a.postToMerchant(ctx, "customer.enter.query", biz)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &channel.MerchantStatusResp{
|
||||
ChannelMerchantID: strVal(resp, "merch_id"),
|
||||
Status: strVal(resp, "audit_state"),
|
||||
FailReason: strVal(resp, "audit_reason"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// --- 底层通信 ---
|
||||
|
||||
// heepayRequest 公共请求结构
|
||||
type heepayRequest struct {
|
||||
AppID string `json:"app_id"`
|
||||
Method string `json:"method"`
|
||||
Format string `json:"format"`
|
||||
Charset string `json:"charset"`
|
||||
SignType string `json:"sign_type"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
Version string `json:"version"`
|
||||
BizContent string `json:"biz_content"`
|
||||
Sign string `json:"sign"`
|
||||
}
|
||||
|
||||
// heepayResponse 公共响应结构
|
||||
// 注意:支付网关 code 为字符串("10000"),进件网关 code 为数字(10000),
|
||||
// 使用 json.RawMessage 兼容,再通过 codeStr() 统一转为字符串比较。
|
||||
type heepayResponse struct {
|
||||
Code json.RawMessage `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
SubCode string `json:"sub_code"`
|
||||
SubMsg string `json:"sub_msg"`
|
||||
Sign string `json:"sign"`
|
||||
TradeID string `json:"trade_id"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
|
||||
// codeStr 将 code 字段统一转为字符串(兼容字符串和数字两种格式)
|
||||
func (r *heepayResponse) codeStr() string {
|
||||
if len(r.Code) == 0 {
|
||||
return ""
|
||||
}
|
||||
// 去掉引号(字符串形式)或直接返回数字字符串
|
||||
raw := string(r.Code)
|
||||
if len(raw) >= 2 && raw[0] == '"' {
|
||||
return raw[1 : len(raw)-1]
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
// post 调用汇元支付网关(payURL)
|
||||
func (a *Adapter) post(ctx context.Context, method string, bizParams map[string]any) (map[string]any, error) {
|
||||
return a.postURL(ctx, a.payURL, method, bizParams)
|
||||
}
|
||||
|
||||
// postToMerchant 调用汇元进件网关(merchantURL)
|
||||
func (a *Adapter) postToMerchant(ctx context.Context, method string, bizParams map[string]any) (map[string]any, error) {
|
||||
return a.postURL(ctx, a.merchantURL, method, bizParams)
|
||||
}
|
||||
|
||||
// post 调用汇元支付网关(payURL)
|
||||
// 注意:原 post 方法内部调用 postURL
|
||||
func (a *Adapter) postURL(ctx context.Context, gatewayURL, method string, bizParams map[string]any) (map[string]any, error) {
|
||||
bizJSON, err := json.Marshal(bizParams)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal biz_content: %w", err)
|
||||
}
|
||||
bizContent := string(bizJSON)
|
||||
|
||||
timestamp := time.Now().In(cst).Format("2006-01-02 15:04:05")
|
||||
signParams := map[string]string{
|
||||
"app_id": a.config.MerchantID,
|
||||
"method": method,
|
||||
"format": "JSON",
|
||||
"charset": "utf-8",
|
||||
"sign_type": SignTypeRSA2,
|
||||
"timestamp": timestamp,
|
||||
"version": "1.0",
|
||||
"biz_content": bizContent,
|
||||
}
|
||||
|
||||
sign, err := Sign(signParams, a.config.PrivateKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sign request: %w", err)
|
||||
}
|
||||
|
||||
reqBody := heepayRequest{
|
||||
AppID: a.config.MerchantID,
|
||||
Method: method,
|
||||
Format: "JSON",
|
||||
Charset: "utf-8",
|
||||
SignType: SignTypeRSA2,
|
||||
Timestamp: timestamp,
|
||||
Version: "1.0",
|
||||
BizContent: bizContent,
|
||||
Sign: sign,
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
slog.DebugContext(ctx, "heepay request", "method", method, "url", gatewayURL)
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, gatewayURL, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
httpResp, err := a.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("heepay http request: %w", err)
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
slog.DebugContext(ctx, "heepay response", "method", method, "body", string(respBytes))
|
||||
|
||||
if len(respBytes) == 0 {
|
||||
return nil, fmt.Errorf("heepay empty response from %s (method=%s)", gatewayURL, method)
|
||||
}
|
||||
|
||||
// 先解析为 raw map,确保所有字段都能参与验签
|
||||
var rawMap map[string]json.RawMessage
|
||||
if err := json.Unmarshal(respBytes, &rawMap); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
signRaw, _ := rawMap["sign"]
|
||||
signB64 := strings.Trim(string(signRaw), `"`)
|
||||
|
||||
// 构建验签参数:所有非 sign、非空字段
|
||||
verifyParams := make(map[string]string, len(rawMap))
|
||||
for k, v := range rawMap {
|
||||
if k == "sign" {
|
||||
continue
|
||||
}
|
||||
raw := string(v)
|
||||
if raw == "null" || raw == `""` {
|
||||
continue
|
||||
}
|
||||
// 字符串类型去引号并反转义,数字/对象/数组直接用 raw 值
|
||||
if len(raw) >= 2 && raw[0] == '"' {
|
||||
var s string
|
||||
json.Unmarshal(v, &s)
|
||||
verifyParams[k] = s
|
||||
} else {
|
||||
verifyParams[k] = raw
|
||||
}
|
||||
}
|
||||
|
||||
if err := VerifyResponse(verifyParams, signB64, a.config.PublicKey); err != nil {
|
||||
return nil, fmt.Errorf("verify response sign: %w", err)
|
||||
}
|
||||
|
||||
// 再用强类型 struct 方便后续处理
|
||||
var resp heepayResponse
|
||||
json.Unmarshal(respBytes, &resp)
|
||||
codeStr := resp.codeStr()
|
||||
|
||||
if codeStr != codeSuccess {
|
||||
return nil, fmt.Errorf("heepay error [%s] %s: %s %s", codeStr, resp.Msg, resp.SubCode, resp.SubMsg)
|
||||
}
|
||||
|
||||
var bizResult map[string]any
|
||||
if len(resp.Data) > 0 {
|
||||
if err := json.Unmarshal(resp.Data, &bizResult); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response data: %w", err)
|
||||
}
|
||||
}
|
||||
return bizResult, nil
|
||||
}
|
||||
|
||||
// mapPayMethod 将内部支付方式映射为汇元 trade_type
|
||||
func mapPayMethod(m model.PayMethod) (string, error) {
|
||||
mapping := map[model.PayMethod]string{
|
||||
model.PayMethodWechatJSAPI: "JSAPI",
|
||||
model.PayMethodWechatH5: "MWEB",
|
||||
model.PayMethodWechatNative: "NATIVE",
|
||||
model.PayMethodWechatMini: "MINIAPP",
|
||||
model.PayMethodAlipay: "ALI_NATIVE",
|
||||
model.PayMethodQuickPay: "QUICK",
|
||||
}
|
||||
t, ok := mapping[m]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unsupported pay method: %s", m)
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func strVal(m map[string]any, key string) string {
|
||||
if v, ok := m[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
13
backend/internal/channel/heepay/cbc.go
Normal file
13
backend/internal/channel/heepay/cbc.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package heepay
|
||||
|
||||
import "crypto/cipher"
|
||||
|
||||
// newCBCEncrypter 创建 CBC 加密器(封装 cipher.NewCBCEncrypter)
|
||||
func newCBCEncrypter(block cipher.Block, iv []byte) cipher.BlockMode {
|
||||
return cipher.NewCBCEncrypter(block, iv)
|
||||
}
|
||||
|
||||
// newCBCDecrypter 创建 CBC 解密器
|
||||
func newCBCDecrypter(block cipher.Block, iv []byte) cipher.BlockMode {
|
||||
return cipher.NewCBCDecrypter(block, iv)
|
||||
}
|
||||
142
backend/internal/channel/heepay/crypto.go
Normal file
142
backend/internal/channel/heepay/crypto.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package heepay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/des"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// EncryptRequest 使用 RSA+3DES 加密请求体
|
||||
// 1. 生成随机 3DES 密钥(24 字节)
|
||||
// 2. 用 3DES-CBC 加密 JSON 请求体
|
||||
// 3. 用汇元 RSA 公钥加密 3DES 密钥
|
||||
func EncryptRequest(plaintext []byte, publicKeyPEM string) (encData string, encKey string, err error) {
|
||||
pubKey, err := parseRSAPublicKey(publicKeyPEM)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// 生成 3DES 密钥
|
||||
desKey := make([]byte, 24)
|
||||
if _, err = rand.Read(desKey); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// 3DES 加密
|
||||
ciphertext, err := tripleDesEncrypt(plaintext, desKey)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// RSA 加密 3DES 密钥
|
||||
encKeyBytes, err := rsa.EncryptPKCS1v15(rand.Reader, pubKey, desKey)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
encData = base64.StdEncoding.EncodeToString(ciphertext)
|
||||
encKey = base64.StdEncoding.EncodeToString(encKeyBytes)
|
||||
return
|
||||
}
|
||||
|
||||
// DecryptResponse 使用 RSA 私钥 + 3DES 解密响应
|
||||
func DecryptResponse(encData, encKey, privateKeyPEM string) ([]byte, error) {
|
||||
privKey, err := parseRSAPrivateKey(privateKeyPEM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encKeyBytes, err := base64.StdEncoding.DecodeString(encKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
desKey, err := rsa.DecryptPKCS1v15(rand.Reader, privKey, encKeyBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(encData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tripleDesDecrypt(ciphertext, desKey)
|
||||
}
|
||||
|
||||
// tripleDesEncrypt 3DES-CBC 加密(PKCS5Padding)
|
||||
func tripleDesEncrypt(plaintext, key []byte) ([]byte, error) {
|
||||
block, err := des.NewTripleDESCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
blockSize := block.BlockSize()
|
||||
plaintext = pkcs5Padding(plaintext, blockSize)
|
||||
|
||||
iv := key[:blockSize] // 使用密钥前 8 字节作为 IV
|
||||
mode := newCBCEncrypter(block, iv)
|
||||
ciphertext := make([]byte, len(plaintext))
|
||||
mode.CryptBlocks(ciphertext, plaintext)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// tripleDesDecrypt 3DES-CBC 解密
|
||||
func tripleDesDecrypt(ciphertext, key []byte) ([]byte, error) {
|
||||
block, err := des.NewTripleDESCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
blockSize := block.BlockSize()
|
||||
iv := key[:blockSize]
|
||||
mode := newCBCDecrypter(block, iv)
|
||||
plaintext := make([]byte, len(ciphertext))
|
||||
mode.CryptBlocks(plaintext, ciphertext)
|
||||
return pkcs5Unpadding(plaintext)
|
||||
}
|
||||
|
||||
func pkcs5Padding(data []byte, blockSize int) []byte {
|
||||
padding := blockSize - len(data)%blockSize
|
||||
padText := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(data, padText...)
|
||||
}
|
||||
|
||||
func pkcs5Unpadding(data []byte) ([]byte, error) {
|
||||
length := len(data)
|
||||
if length == 0 {
|
||||
return nil, errors.New("empty data")
|
||||
}
|
||||
padding := int(data[length-1])
|
||||
if padding > length {
|
||||
return nil, errors.New("invalid padding")
|
||||
}
|
||||
return data[:length-padding], nil
|
||||
}
|
||||
|
||||
func parseRSAPublicKey(pemStr string) (*rsa.PublicKey, error) {
|
||||
block, _ := pem.Decode([]byte(pemStr))
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode PEM block")
|
||||
}
|
||||
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rsaPub, ok := pub.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return nil, errors.New("not RSA public key")
|
||||
}
|
||||
return rsaPub, nil
|
||||
}
|
||||
|
||||
func parseRSAPrivateKey(pemStr string) (*rsa.PrivateKey, error) {
|
||||
block, _ := pem.Decode([]byte(pemStr))
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode PEM block")
|
||||
}
|
||||
return x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
}
|
||||
355
backend/internal/channel/heepay/sandbox_test.go
Normal file
355
backend/internal/channel/heepay/sandbox_test.go
Normal file
@@ -0,0 +1,355 @@
|
||||
//go:build sandbox
|
||||
|
||||
// 沙盒集成测试:直连汇元沙盒环境,验证真实 API 调用。
|
||||
// 需要设置以下环境变量后运行:
|
||||
//
|
||||
// export HEEPAY_MERCHANT_ID=your_merchant_id
|
||||
// export HEEPAY_PRIVATE_KEY="-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----"
|
||||
// export HEEPAY_PUBLIC_KEY="-----BEGIN PUBLIC KEY-----\n...\n-----END PUBLIC KEY-----"
|
||||
//
|
||||
// go test -tags sandbox ./internal/channel/heepay/ -v -timeout 60s
|
||||
package heepay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"pay-bridge/internal/channel"
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
const (
|
||||
sandboxPayURL = "http://openapi.heepaydev.com/gateway"
|
||||
sandboxMerchantURL = "http://openapi.heepaydev.com/v1/customer/gateway"
|
||||
)
|
||||
|
||||
// newSandboxAdapter 从环境变量读取凭证,构造沙盒 Adapter
|
||||
func newSandboxAdapter(t *testing.T) *Adapter {
|
||||
t.Helper()
|
||||
merchantID := requireEnv(t, "HEEPAY_MERCHANT_ID")
|
||||
privateKey := requireEnv(t, "HEEPAY_PRIVATE_KEY")
|
||||
publicKey := requireEnv(t, "HEEPAY_PUBLIC_KEY")
|
||||
|
||||
// 支持 \n 转义(shell 传入时换行符可能被转义)
|
||||
privateKey = strings.ReplaceAll(privateKey, `\n`, "\n")
|
||||
publicKey = strings.ReplaceAll(publicKey, `\n`, "\n")
|
||||
|
||||
cfg := &model.ChannelConfig{
|
||||
MerchantID: merchantID,
|
||||
PrivateKey: privateKey,
|
||||
PublicKey: publicKey,
|
||||
Sandbox: 1,
|
||||
}
|
||||
return &Adapter{
|
||||
config: cfg,
|
||||
payURL: sandboxPayURL,
|
||||
merchantURL: sandboxMerchantURL,
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
func requireEnv(t *testing.T, key string) string {
|
||||
t.Helper()
|
||||
v := os.Getenv(key)
|
||||
require.NotEmpty(t, v, "环境变量 %s 未设置,无法运行沙盒测试", key)
|
||||
return v
|
||||
}
|
||||
|
||||
// uniqueOrderNo 生成测试用唯一订单号(避免重复)
|
||||
func uniqueOrderNo(prefix string) string {
|
||||
return fmt.Sprintf("%s%d", prefix, time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// --- 下单 ---
|
||||
|
||||
func TestSandbox_CreateOrder_JSAPI(t *testing.T) {
|
||||
a := newSandboxAdapter(t)
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := a.CreateOrder(ctx, &channel.CreateOrderReq{
|
||||
AppID: a.config.MerchantID,
|
||||
TradeNo: uniqueOrderNo("TEST"),
|
||||
MerchantOrderNo: uniqueOrderNo("ORD"),
|
||||
PayMethod: model.PayMethodWechatJSAPI,
|
||||
Amount: 1, // 1 分
|
||||
Subject: "沙盒测试-JSAPI",
|
||||
NotifyURL: "https://example.com/notify",
|
||||
ExpireTime: time.Now().Add(10 * time.Minute),
|
||||
Extra: map[string]any{"openid": "oBk9Y5YMoAb2UG0L1OWQ_xNoBnE0"}, // 沙盒 openid
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.PayCredential, "应返回支付凭证")
|
||||
t.Logf("pay_credential: %+v", resp.PayCredential)
|
||||
}
|
||||
|
||||
func TestSandbox_CreateOrder_Native(t *testing.T) {
|
||||
a := newSandboxAdapter(t)
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := a.CreateOrder(ctx, &channel.CreateOrderReq{
|
||||
AppID: a.config.MerchantID,
|
||||
TradeNo: uniqueOrderNo("TEST"),
|
||||
MerchantOrderNo: uniqueOrderNo("ORD"),
|
||||
PayMethod: model.PayMethodWechatNative,
|
||||
Amount: 1,
|
||||
Subject: "沙盒测试-Native",
|
||||
NotifyURL: "https://example.com/notify",
|
||||
ExpireTime: time.Now().Add(10 * time.Minute),
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.PayCredential)
|
||||
t.Logf("pay_credential: %+v", resp.PayCredential)
|
||||
}
|
||||
|
||||
// --- 查询订单 ---
|
||||
|
||||
func TestSandbox_QueryOrder_NotExist(t *testing.T) {
|
||||
a := newSandboxAdapter(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 查一个不存在的订单,预期渠道返回错误
|
||||
_, err := a.QueryOrder(ctx, &channel.QueryOrderReq{
|
||||
TradeNo: "NOT_EXIST_" + uniqueOrderNo(""),
|
||||
})
|
||||
|
||||
// 沙盒对不存在订单会返回错误,确认我们能正确解析
|
||||
assert.Error(t, err)
|
||||
t.Logf("expected error: %v", err)
|
||||
}
|
||||
|
||||
// --- 完整流程:下单 → 查询 → 关闭 ---
|
||||
|
||||
func TestSandbox_OrderLifecycle(t *testing.T) {
|
||||
a := newSandboxAdapter(t)
|
||||
ctx := context.Background()
|
||||
|
||||
tradeNo := uniqueOrderNo("LIFE")
|
||||
|
||||
// 1. 下单
|
||||
createResp, err := a.CreateOrder(ctx, &channel.CreateOrderReq{
|
||||
AppID: a.config.MerchantID,
|
||||
TradeNo: tradeNo,
|
||||
MerchantOrderNo: uniqueOrderNo("ORD"),
|
||||
PayMethod: model.PayMethodWechatNative,
|
||||
Amount: 1,
|
||||
Subject: "沙盒-生命周期测试",
|
||||
NotifyURL: "https://example.com/notify",
|
||||
ExpireTime: time.Now().Add(10 * time.Minute),
|
||||
})
|
||||
require.NoError(t, err, "下单失败")
|
||||
t.Logf("下单成功,channel_trade_no: %s", createResp.ChannelTradeNo)
|
||||
|
||||
// 2. 查询(刚下单,应为 PAYING 状态)
|
||||
queryResp, err := a.QueryOrder(ctx, &channel.QueryOrderReq{
|
||||
TradeNo: tradeNo,
|
||||
ChannelTradeNo: createResp.ChannelTradeNo,
|
||||
})
|
||||
require.NoError(t, err, "查询订单失败")
|
||||
assert.Equal(t, model.TradeStatusPaying, queryResp.Status)
|
||||
t.Logf("订单状态: %s", queryResp.Status)
|
||||
|
||||
// 3. 关闭
|
||||
err = a.CloseOrder(ctx, &channel.CloseOrderReq{
|
||||
TradeNo: tradeNo,
|
||||
ChannelTradeNo: createResp.ChannelTradeNo,
|
||||
})
|
||||
require.NoError(t, err, "关闭订单失败")
|
||||
t.Log("关闭订单成功")
|
||||
|
||||
// 4. 再次查询,应为 CLOSED
|
||||
queryResp2, err := a.QueryOrder(ctx, &channel.QueryOrderReq{
|
||||
TradeNo: tradeNo,
|
||||
ChannelTradeNo: createResp.ChannelTradeNo,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, model.TradeStatusClosed, queryResp2.Status)
|
||||
}
|
||||
|
||||
// --- 商户进件 ---
|
||||
|
||||
// TestSandbox_UploadFile 上传一张最小 JPEG 到沙盒,验证返回 file_id
|
||||
func TestSandbox_UploadFile(t *testing.T) {
|
||||
a := newSandboxAdapter(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 最小合法 JPEG(几十字节)
|
||||
minJPEG := []byte{
|
||||
0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00, 0x01,
|
||||
0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0xFF, 0xD9,
|
||||
}
|
||||
|
||||
resp, err := a.UploadFile(ctx, &channel.UploadFileReq{
|
||||
FileContent: minJPEG,
|
||||
FileName: "test_license.jpg",
|
||||
FileMediaType: "image/jpeg",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.FileID, "应返回 file_id")
|
||||
t.Logf("file_id: %s", resp.FileID)
|
||||
}
|
||||
|
||||
// TestSandbox_MerchantApply 提交企业入网申请,验证返回 request_no
|
||||
// 沙盒环境审核不会真正处理,但接口应正常响应
|
||||
func TestSandbox_MerchantApply(t *testing.T) {
|
||||
a := newSandboxAdapter(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 沙盒测试用企业信息(使用固定测试数据)
|
||||
// 字段名以汇元 customer.enter.enterprise.apply 文档为准
|
||||
bizContent := map[string]any{
|
||||
"merch_name": "测试科技有限公司",
|
||||
"merch_short_name": "测试科技",
|
||||
"merch_type": "ENTERPRISE",
|
||||
"contact_name": "张三",
|
||||
"contact_phone": "13800138000",
|
||||
"contact_email": "test@example.com",
|
||||
"license_no": "91110000123456789X",
|
||||
"legal_name": "李四",
|
||||
"legal_id": "110101199001011234",
|
||||
"province": "北京市",
|
||||
"city": "北京市",
|
||||
"district": "朝阳区",
|
||||
"address": "朝阳区测试路1号",
|
||||
"bank_acct_name": "测试科技有限公司",
|
||||
"bank_acct_no": "6222021234567890123",
|
||||
"bank_name": "中国工商银行",
|
||||
}
|
||||
|
||||
resp, err := a.MerchantApply(ctx, &channel.MerchantApplyReq{
|
||||
MerchantID: uniqueOrderNo("M"),
|
||||
BizContent: bizContent,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.RequestNo, "应返回 request_no")
|
||||
t.Logf("request_no: %s", resp.RequestNo)
|
||||
}
|
||||
|
||||
// TestSandbox_QueryMerchantStatus 用已有的 request_no 查询进件状态
|
||||
// 若无真实 request_no,跳过(防止误报失败)
|
||||
func TestSandbox_QueryMerchantStatus(t *testing.T) {
|
||||
requestNo := os.Getenv("HEEPAY_TEST_REQUEST_NO")
|
||||
if requestNo == "" {
|
||||
t.Skip("未设置 HEEPAY_TEST_REQUEST_NO,跳过进件状态查询测试")
|
||||
}
|
||||
|
||||
a := newSandboxAdapter(t)
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := a.QueryMerchantStatus(ctx, requestNo)
|
||||
require.NoError(t, err)
|
||||
t.Logf("audit_state: %s, merch_id: %s, fail_reason: %s",
|
||||
resp.Status, resp.ChannelMerchantID, resp.FailReason)
|
||||
}
|
||||
|
||||
// TestSandbox_MerchantOnboardingFlow 完整进件流程:上传文件 → 提交申请 → 查询状态
|
||||
func TestSandbox_MerchantOnboardingFlow(t *testing.T) {
|
||||
a := newSandboxAdapter(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 1. 上传营业执照
|
||||
minJPEG := []byte{
|
||||
0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00, 0x01,
|
||||
0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0xFF, 0xD9,
|
||||
}
|
||||
uploadResp, err := a.UploadFile(ctx, &channel.UploadFileReq{
|
||||
FileContent: minJPEG,
|
||||
FileName: "license.jpg",
|
||||
FileMediaType: "image/jpeg",
|
||||
})
|
||||
require.NoError(t, err, "上传营业执照失败")
|
||||
t.Logf("上传成功,file_id: %s", uploadResp.FileID)
|
||||
|
||||
// 2. 提交进件申请(带上文件 ID)
|
||||
bizContent := map[string]any{
|
||||
"merch_name": "沙盒流程测试公司",
|
||||
"merch_short_name": "沙盒测试",
|
||||
"merch_type": "ENTERPRISE",
|
||||
"contact_name": "王五",
|
||||
"contact_phone": "13900139000",
|
||||
"contact_email": "flow@example.com",
|
||||
"license_no": "91110000FLOW00001X",
|
||||
"license_img": uploadResp.FileID,
|
||||
"legal_name": "赵六",
|
||||
"legal_id": "110101199001019999",
|
||||
"province": "北京市",
|
||||
"city": "北京市",
|
||||
"district": "海淀区",
|
||||
"address": "海淀区测试路2号",
|
||||
"bank_acct_name": "沙盒流程测试公司",
|
||||
"bank_acct_no": "6222029876543210987",
|
||||
"bank_name": "中国建设银行",
|
||||
}
|
||||
applyResp, err := a.MerchantApply(ctx, &channel.MerchantApplyReq{
|
||||
MerchantID: uniqueOrderNo("FLOW"),
|
||||
BizContent: bizContent,
|
||||
})
|
||||
require.NoError(t, err, "提交进件申请失败")
|
||||
t.Logf("申请成功,request_no: %s", applyResp.RequestNo)
|
||||
|
||||
// 3. 查询进件状态(沙盒可能立即返回状态)
|
||||
statusResp, err := a.QueryMerchantStatus(ctx, applyResp.RequestNo)
|
||||
require.NoError(t, err, "查询进件状态失败")
|
||||
t.Logf("进件状态: audit_state=%s, merch_id=%s", statusResp.Status, statusResp.ChannelMerchantID)
|
||||
}
|
||||
|
||||
// --- 签名验证(本地,不发网络请求)---
|
||||
|
||||
// TestSign_And_Verify 本地签名/验签往返测试(不发网络请求)
|
||||
// 注意汇元双密钥体系:
|
||||
// - 请求签名:商户私钥签 → 汇元用商户公钥验
|
||||
// - 响应验签:汇元私钥签 → 商户用汇元公钥验(即 a.config.PublicKey)
|
||||
//
|
||||
// 本测试只验证商户私钥签名正确,使用从私钥派生的公钥做自验,
|
||||
// 不混用汇元公钥(两者非同一密钥对)。
|
||||
func TestSign_And_Verify(t *testing.T) {
|
||||
a := newSandboxAdapter(t)
|
||||
|
||||
params := map[string]string{
|
||||
"app_id": a.config.MerchantID,
|
||||
"method": "pay.heepay.trade.create",
|
||||
"format": "JSON",
|
||||
"charset": "utf-8",
|
||||
"sign_type": SignTypeRSA2,
|
||||
"timestamp": "2026-02-28 10:00:00",
|
||||
"version": "1.0",
|
||||
"biz_content": `{"out_trade_no":"TEST001"}`,
|
||||
}
|
||||
|
||||
sign, err := Sign(params, a.config.PrivateKey)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, sign)
|
||||
t.Logf("sign (first 32 chars): %s...", sign[:32])
|
||||
|
||||
// 从商户私钥提取对应公钥,用于验证我们自己签出来的签名
|
||||
merchantPubKeyPEM, err := extractPublicKeyFromPrivate(a.config.PrivateKey)
|
||||
require.NoError(t, err, "从私钥提取公钥失败")
|
||||
err = VerifyResponse(params, sign, merchantPubKeyPEM)
|
||||
assert.NoError(t, err, "用商户公钥验证商户私钥签名应通过")
|
||||
}
|
||||
|
||||
// extractPublicKeyFromPrivate 从商户私钥中派生出对应的公钥(DER base64 格式)
|
||||
func extractPublicKeyFromPrivate(privKeyStr string) (string, error) {
|
||||
privKey, err := parsePrivateKey(privKeyStr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
derBytes, err := x509.MarshalPKIXPublicKey(&privKey.PublicKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(derBytes), nil
|
||||
}
|
||||
|
||||
124
backend/internal/channel/heepay/sign.go
Normal file
124
backend/internal/channel/heepay/sign.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package heepay
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const SignTypeRSA2 = "RSA2"
|
||||
|
||||
// Sign 对请求参数签名(商户私钥)
|
||||
// params 为公共参数(不含 sign),biz_content 已作为整体字符串放入 params["biz_content"]
|
||||
func Sign(params map[string]string, privateKeyPEM string) (string, error) {
|
||||
payload := sortAndJoin(params)
|
||||
return signRSA2(payload, privateKeyPEM)
|
||||
}
|
||||
|
||||
// VerifyResponse 验证汇元响应签名(汇元公钥)
|
||||
// params 为响应公共参数(不含 sign),data 已作为整体 JSON 字符串放入 params["data"]
|
||||
func VerifyResponse(params map[string]string, sign, publicKeyPEM string) error {
|
||||
payload := sortAndJoin(params)
|
||||
return verifyRSA2(payload, sign, publicKeyPEM)
|
||||
}
|
||||
|
||||
// sortAndJoin 按参数名 A-Z 排序后拼接 key=value&...(排除 sign 和空值字段)
|
||||
func sortAndJoin(params map[string]string) string {
|
||||
keys := make([]string, 0, len(params))
|
||||
for k := range params {
|
||||
if k == "sign" || params[k] == "" {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
parts := make([]string, 0, len(keys))
|
||||
for _, k := range keys {
|
||||
parts = append(parts, k+"="+params[k])
|
||||
}
|
||||
return strings.Join(parts, "&")
|
||||
}
|
||||
|
||||
// signRSA2 SHA256WithRSA 签名,Base64 编码
|
||||
func signRSA2(payload, privateKeyPEM string) (string, error) {
|
||||
privKey, err := parsePrivateKey(privateKeyPEM)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse private key: %w", err)
|
||||
}
|
||||
hash := sha256.Sum256([]byte(payload))
|
||||
sig, err := rsa.SignPKCS1v15(rand.Reader, privKey, crypto.SHA256, hash[:])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("rsa sign: %w", err)
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(sig), nil
|
||||
}
|
||||
|
||||
// verifyRSA2 验证 SHA256WithRSA 签名
|
||||
func verifyRSA2(payload, signB64, publicKeyPEM string) error {
|
||||
pubKey, err := parsePublicKey(publicKeyPEM)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse public key: %w", err)
|
||||
}
|
||||
sig, err := base64.StdEncoding.DecodeString(signB64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode sign base64: %w", err)
|
||||
}
|
||||
hash := sha256.Sum256([]byte(payload))
|
||||
return rsa.VerifyPKCS1v15(pubKey, crypto.SHA256, hash[:], sig)
|
||||
}
|
||||
|
||||
func parsePrivateKey(pemStr string) (*rsa.PrivateKey, error) {
|
||||
var der []byte
|
||||
if block, _ := pem.Decode([]byte(pemStr)); block != nil {
|
||||
der = block.Bytes
|
||||
} else {
|
||||
// 汇元文档提供的是裸 Base64(无 PEM header),直接 base64 解码
|
||||
cleaned := strings.ReplaceAll(strings.TrimSpace(pemStr), "\n", "")
|
||||
var err error
|
||||
der, err = base64.StdEncoding.DecodeString(cleaned)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("private key is neither PEM nor valid base64: %w", err)
|
||||
}
|
||||
}
|
||||
// 优先尝试 PKCS8,再尝试 PKCS1
|
||||
if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
|
||||
if rsaKey, ok := key.(*rsa.PrivateKey); ok {
|
||||
return rsaKey, nil
|
||||
}
|
||||
return nil, errors.New("not an RSA private key")
|
||||
}
|
||||
return x509.ParsePKCS1PrivateKey(der)
|
||||
}
|
||||
|
||||
func parsePublicKey(pemStr string) (*rsa.PublicKey, error) {
|
||||
var der []byte
|
||||
if block, _ := pem.Decode([]byte(pemStr)); block != nil {
|
||||
der = block.Bytes
|
||||
} else {
|
||||
// 汇元文档提供的是裸 Base64(无 PEM header),直接 base64 解码
|
||||
cleaned := strings.ReplaceAll(strings.TrimSpace(pemStr), "\n", "")
|
||||
var err error
|
||||
der, err = base64.StdEncoding.DecodeString(cleaned)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("public key is neither PEM nor valid base64: %w", err)
|
||||
}
|
||||
}
|
||||
pub, err := x509.ParsePKIXPublicKey(der)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse public key: %w", err)
|
||||
}
|
||||
rsaPub, ok := pub.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return nil, errors.New("not an RSA public key")
|
||||
}
|
||||
return rsaPub, nil
|
||||
}
|
||||
231
backend/internal/channel/interface.go
Normal file
231
backend/internal/channel/interface.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
// PaymentChannel 支付渠道统一接口,所有渠道适配器必须实现此接口
|
||||
type PaymentChannel interface {
|
||||
// Code 返回渠道编码,如 "HEEPAY"
|
||||
Code() string
|
||||
|
||||
// CreateOrder 统一下单,返回支付凭证
|
||||
CreateOrder(ctx context.Context, req *CreateOrderReq) (*CreateOrderResp, error)
|
||||
|
||||
// QueryOrder 查询订单状态
|
||||
QueryOrder(ctx context.Context, req *QueryOrderReq) (*QueryOrderResp, error)
|
||||
|
||||
// CloseOrder 关闭订单
|
||||
CloseOrder(ctx context.Context, req *CloseOrderReq) error
|
||||
|
||||
// Refund 发起退款
|
||||
Refund(ctx context.Context, req *RefundReq) (*RefundResp, error)
|
||||
|
||||
// QueryRefund 查询退款状态
|
||||
QueryRefund(ctx context.Context, req *QueryRefundReq) (*QueryRefundResp, error)
|
||||
|
||||
// ExtractTradeNo 从上游回调 body 中提取平台交易号(用于回调路由,在验签前调用)
|
||||
ExtractTradeNo(rawBody []byte) (string, error)
|
||||
|
||||
// VerifyNotify 验证上游回调签名,返回解析后的通知数据
|
||||
VerifyNotify(ctx context.Context, rawBody []byte, headers map[string]string) (*NotifyData, error)
|
||||
|
||||
// ProfitSharing 发起分账(渠道不支持时返回 ErrNotSupported)
|
||||
ProfitSharing(ctx context.Context, req *ProfitSharingReq) (*ProfitSharingResp, error)
|
||||
|
||||
// RollbackProfitSharing 回退分账(退款场景使用)
|
||||
RollbackProfitSharing(ctx context.Context, req *RollbackSharingReq) error
|
||||
|
||||
// DownloadBill 下载对账账单
|
||||
DownloadBill(ctx context.Context, req *DownloadBillReq) (*BillData, error)
|
||||
|
||||
// UploadFile 上传文件,返回 file_id(进件图片/视频上传)
|
||||
UploadFile(ctx context.Context, req *UploadFileReq) (*UploadFileResp, error)
|
||||
|
||||
// MerchantApply 商户进件
|
||||
MerchantApply(ctx context.Context, req *MerchantApplyReq) (*MerchantApplyResp, error)
|
||||
|
||||
// QueryMerchantStatus 查询商户审核状态
|
||||
QueryMerchantStatus(ctx context.Context, channelMerchantID string) (*MerchantStatusResp, error)
|
||||
}
|
||||
|
||||
// URLs 渠道网关地址(由配置文件注入,与商户无关)
|
||||
type URLs struct {
|
||||
PayURL string // 支付网关
|
||||
MerchantURL string // 进件网关
|
||||
}
|
||||
|
||||
// ChannelFactory 渠道工厂函数类型
|
||||
type ChannelFactory func(config *model.ChannelConfig, urls URLs) PaymentChannel
|
||||
|
||||
// --- 请求/响应类型 ---
|
||||
|
||||
// CreateOrderReq 下单请求
|
||||
type CreateOrderReq struct {
|
||||
AppID string
|
||||
TradeNo string
|
||||
MerchantOrderNo string
|
||||
PayMethod model.PayMethod
|
||||
Amount int64 // 分
|
||||
Subject string
|
||||
NotifyURL string
|
||||
ExpireTime time.Time
|
||||
Extra map[string]any // 支付方式特有参数(openid 等)
|
||||
}
|
||||
|
||||
// CreateOrderResp 下单响应
|
||||
type CreateOrderResp struct {
|
||||
ChannelTradeNo string
|
||||
PayCredential map[string]any // 支付凭证,各方式格式不同
|
||||
RawResponse []byte
|
||||
}
|
||||
|
||||
// QueryOrderReq 查询订单请求
|
||||
type QueryOrderReq struct {
|
||||
TradeNo string
|
||||
ChannelTradeNo string
|
||||
}
|
||||
|
||||
// QueryOrderResp 查询订单响应
|
||||
type QueryOrderResp struct {
|
||||
TradeNo string
|
||||
ChannelTradeNo string
|
||||
Status model.TradeStatus
|
||||
Amount int64
|
||||
PayTime *time.Time
|
||||
}
|
||||
|
||||
// CloseOrderReq 关闭订单请求
|
||||
type CloseOrderReq struct {
|
||||
TradeNo string
|
||||
ChannelTradeNo string
|
||||
}
|
||||
|
||||
// RefundReq 退款请求
|
||||
type RefundReq struct {
|
||||
TradeNo string
|
||||
ChannelTradeNo string
|
||||
RefundNo string
|
||||
RefundAmount int64
|
||||
TotalAmount int64
|
||||
Reason string
|
||||
NotifyURL string
|
||||
}
|
||||
|
||||
// RefundResp 退款响应
|
||||
type RefundResp struct {
|
||||
RefundNo string
|
||||
ChannelRefundNo string
|
||||
Status model.RefundStatus
|
||||
}
|
||||
|
||||
// QueryRefundReq 查询退款请求
|
||||
type QueryRefundReq struct {
|
||||
RefundNo string
|
||||
ChannelRefundNo string
|
||||
}
|
||||
|
||||
// QueryRefundResp 查询退款响应
|
||||
type QueryRefundResp struct {
|
||||
RefundNo string
|
||||
ChannelRefundNo string
|
||||
Status model.RefundStatus
|
||||
RefundAmount int64
|
||||
RefundTime *time.Time
|
||||
}
|
||||
|
||||
// NotifyData 上游回调解析结果
|
||||
type NotifyData struct {
|
||||
TradeNo string
|
||||
ChannelTradeNo string
|
||||
Status model.TradeStatus
|
||||
Amount int64
|
||||
PayTime *time.Time
|
||||
NotifyType model.NotifyType
|
||||
RefundNo string
|
||||
RefundStatus model.RefundStatus
|
||||
RefundAmount int64
|
||||
RawData []byte
|
||||
}
|
||||
|
||||
// ProfitSharingReq 分账请求
|
||||
type ProfitSharingReq struct {
|
||||
TradeNo string
|
||||
ChannelTradeNo string
|
||||
SharingNo string
|
||||
ReceiverMerchantID string
|
||||
Amount int64
|
||||
}
|
||||
|
||||
// ProfitSharingResp 分账响应
|
||||
type ProfitSharingResp struct {
|
||||
SharingNo string
|
||||
ChannelSharingNo string
|
||||
}
|
||||
|
||||
// RollbackSharingReq 回退分账请求
|
||||
type RollbackSharingReq struct {
|
||||
SharingNo string
|
||||
ChannelSharingNo string
|
||||
TradeNo string
|
||||
}
|
||||
|
||||
// DownloadBillReq 下载账单请求
|
||||
type DownloadBillReq struct {
|
||||
BillDate string // YYYY-MM-DD
|
||||
}
|
||||
|
||||
// BillData 账单数据
|
||||
type BillData struct {
|
||||
Records []BillRecord
|
||||
TotalAmount int64
|
||||
}
|
||||
|
||||
// BillRecord 账单记录
|
||||
type BillRecord struct {
|
||||
TradeNo string // 平台交易号(渠道账单中携带时填充)
|
||||
ChannelBillNo string // 渠道账单流水号
|
||||
ChannelTradeNo string
|
||||
Amount int64
|
||||
Status string
|
||||
TradeTime time.Time
|
||||
}
|
||||
|
||||
// UploadFileReq 文件上传请求
|
||||
type UploadFileReq struct {
|
||||
FileContent []byte // 原始文件二进制内容
|
||||
FileName string // 文件名(须含扩展名,如 license.jpg)
|
||||
FileMediaType string // 文件类型编码(01=营业执照 等)
|
||||
}
|
||||
|
||||
// UploadFileResp 文件上传响应
|
||||
type UploadFileResp struct {
|
||||
FileID string // 上传成功后返回,供进件接口使用
|
||||
}
|
||||
|
||||
// MerchantApplyReq 商户进件请求(customer.enter.enterprise.apply)
|
||||
type MerchantApplyReq struct {
|
||||
MerchantID string
|
||||
// BizContent 为完整的 biz_content JSON 对象,按照 001 文档结构直接传入
|
||||
// 包含 base_info / settlement_info / subject_info / identity_info / contact_info /
|
||||
// business_info 等顶层字段
|
||||
BizContent map[string]any
|
||||
}
|
||||
|
||||
// MerchantApplyResp 商户进件响应
|
||||
type MerchantApplyResp struct {
|
||||
RequestNo string // 汇元返回的申请流水号,用于后续查询/修改
|
||||
ChannelMerchantID string
|
||||
AuditStatus string
|
||||
}
|
||||
|
||||
// MerchantStatusResp 商户状态响应
|
||||
type MerchantStatusResp struct {
|
||||
ChannelMerchantID string
|
||||
Status string
|
||||
RejectReason string
|
||||
FailReason string
|
||||
}
|
||||
79
backend/internal/errcode/errcode.go
Normal file
79
backend/internal/errcode/errcode.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package errcode
|
||||
|
||||
// 错误码常量
|
||||
const (
|
||||
OK = "0"
|
||||
|
||||
// 参数错误
|
||||
ErrInvalidParam = "10001"
|
||||
ErrMissingParam = "10002"
|
||||
ErrInvalidPayMethod = "10003"
|
||||
ErrInvalidAmount = "10004"
|
||||
|
||||
// 鉴权错误
|
||||
ErrUnauthorized = "20001"
|
||||
ErrAppNotFound = "20002"
|
||||
ErrPermissionDenied = "20003"
|
||||
|
||||
// 业务规则错误
|
||||
ErrOrderNotFound = "30001"
|
||||
ErrOrderAlreadyPaid = "30002"
|
||||
ErrOrderClosed = "30003"
|
||||
ErrRefundAmountExceed = "30004"
|
||||
ErrSharingAmountExceed = "30005"
|
||||
ErrSharingNotConfig = "30006"
|
||||
ErrSharingFeeExceed = "30007"
|
||||
ErrOrderIdempotent = "30008"
|
||||
ErrOrderNotPaid = "30009"
|
||||
ErrRefundNotFound = "30010"
|
||||
|
||||
// 渠道错误
|
||||
ErrChannelCreateFail = "40001"
|
||||
ErrChannelRefundFail = "40002"
|
||||
ErrChannelTimeout = "40003"
|
||||
ErrChannelNotSupport = "40004"
|
||||
ErrChannelVerifyFail = "40005"
|
||||
|
||||
// 系统错误
|
||||
ErrInternalDB = "50001"
|
||||
ErrInternalRedis = "50002"
|
||||
ErrInternalSystem = "50099"
|
||||
)
|
||||
|
||||
// messages 错误码对应的默认消息
|
||||
var messages = map[string]string{
|
||||
OK: "success",
|
||||
ErrInvalidParam: "参数校验失败",
|
||||
ErrMissingParam: "缺少必填参数",
|
||||
ErrInvalidPayMethod: "不支持的支付方式",
|
||||
ErrInvalidAmount: "金额非法",
|
||||
ErrUnauthorized: "签名验证失败",
|
||||
ErrAppNotFound: "应用不存在或已禁用",
|
||||
ErrPermissionDenied: "无权操作该资源",
|
||||
ErrOrderNotFound: "订单不存在",
|
||||
ErrOrderAlreadyPaid: "订单已支付",
|
||||
ErrOrderClosed: "订单已关闭",
|
||||
ErrRefundAmountExceed: "退款金额超过可退金额",
|
||||
ErrSharingAmountExceed: "分润金额超过最大比例",
|
||||
ErrSharingNotConfig: "未配置分润接收方",
|
||||
ErrSharingFeeExceed: "分润与服务费之和超过订单金额",
|
||||
ErrOrderIdempotent: "幂等请求,返回已有订单",
|
||||
ErrOrderNotPaid: "订单未支付,无法退款",
|
||||
ErrRefundNotFound: "退款单不存在",
|
||||
ErrChannelCreateFail: "渠道下单失败",
|
||||
ErrChannelRefundFail: "渠道退款失败",
|
||||
ErrChannelTimeout: "渠道调用超时",
|
||||
ErrChannelNotSupport: "渠道不支持该功能",
|
||||
ErrChannelVerifyFail: "回调验签失败",
|
||||
ErrInternalDB: "数据库错误",
|
||||
ErrInternalRedis: "Redis 错误",
|
||||
ErrInternalSystem: "系统内部错误",
|
||||
}
|
||||
|
||||
// Message 返回错误码对应的消息
|
||||
func Message(code string) string {
|
||||
if msg, ok := messages[code]; ok {
|
||||
return msg
|
||||
}
|
||||
return "未知错误"
|
||||
}
|
||||
15
backend/internal/model/admin_user.go
Normal file
15
backend/internal/model/admin_user.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// AdminUser 管理后台用户
|
||||
func (AdminUser) TableName() string { return "admin_user" }
|
||||
|
||||
type AdminUser struct {
|
||||
ID uint64 `gorm:"primaryKey;autoIncrement"`
|
||||
Username string `gorm:"uniqueIndex;size:64;not null"`
|
||||
PasswordHash string `gorm:"size:128;not null"`
|
||||
Status int8 `gorm:"not null;default:1"` // 1=启用
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
16
backend/internal/model/app.go
Normal file
16
backend/internal/model/app.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// App 接入应用
|
||||
type App struct {
|
||||
ID uint64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
AppID string `gorm:"column:app_id;uniqueIndex;size:32;not null"`
|
||||
AppSecret string `gorm:"column:app_secret;size:128;not null"` // AES 加密存储
|
||||
AppName string `gorm:"column:app_name;size:64;not null"`
|
||||
Status int8 `gorm:"column:status;not null;default:1"` // 1=启用 0=禁用
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime:milli"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime:milli"`
|
||||
}
|
||||
|
||||
func (App) TableName() string { return "app" }
|
||||
22
backend/internal/model/channel_config.go
Normal file
22
backend/internal/model/channel_config.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// ChannelConfig 渠道配置
|
||||
type ChannelConfig struct {
|
||||
ID uint64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
AppID string `gorm:"column:app_id;size:32;not null;uniqueIndex:uk_app_channel"`
|
||||
ChannelCode string `gorm:"column:channel_code;size:32;not null;uniqueIndex:uk_app_channel"`
|
||||
MerchantID string `gorm:"column:merchant_id;size:64;not null"`
|
||||
APIKey string `gorm:"column:api_key;type:text"` // AES 加密
|
||||
PrivateKey string `gorm:"column:private_key;type:text"` // AES 加密
|
||||
PublicKey string `gorm:"column:public_key;type:text"` // 渠道公钥(明文)
|
||||
NotifyURL string `gorm:"column:notify_url;size:512;not null"`
|
||||
Sandbox int8 `gorm:"column:sandbox;not null;default:0"` // 1=沙箱 0=生产
|
||||
ExtraConfig JSONMap `gorm:"column:extra_config;type:json"`
|
||||
Status int8 `gorm:"column:status;not null;default:1"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime:milli"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime:milli"`
|
||||
}
|
||||
|
||||
func (ChannelConfig) TableName() string { return "channel_config" }
|
||||
59
backend/internal/model/merchant.go
Normal file
59
backend/internal/model/merchant.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// MerchantStatus 商户状态
|
||||
type MerchantStatus string
|
||||
|
||||
const (
|
||||
MerchantStatusPending MerchantStatus = "PENDING"
|
||||
MerchantStatusActive MerchantStatus = "ACTIVE"
|
||||
MerchantStatusFrozen MerchantStatus = "FROZEN"
|
||||
MerchantStatusRejected MerchantStatus = "REJECTED"
|
||||
)
|
||||
|
||||
// AuditStatus 进件审核状态
|
||||
type AuditStatus string
|
||||
|
||||
const (
|
||||
AuditStatusSubmitting AuditStatus = "SUBMITTING"
|
||||
AuditStatusReviewing AuditStatus = "REVIEWING"
|
||||
AuditStatusApproved AuditStatus = "APPROVED"
|
||||
AuditStatusRejected AuditStatus = "REJECTED"
|
||||
)
|
||||
|
||||
// Merchant 商户
|
||||
type Merchant struct {
|
||||
ID uint64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
MerchantID string `gorm:"column:merchant_id;uniqueIndex;size:32;not null"`
|
||||
AppID string `gorm:"column:app_id;size:32;not null;default:'';index"`
|
||||
MerchantName string `gorm:"column:merchant_name;size:128;not null"`
|
||||
LicenseNo string `gorm:"column:license_no;size:64"`
|
||||
LegalPerson string `gorm:"column:legal_person;size:64"`
|
||||
BankAccount string `gorm:"column:bank_account;size:64"` // 脱敏
|
||||
ChannelMerchantID string `gorm:"column:channel_merchant_id;size:64"`
|
||||
Status MerchantStatus `gorm:"column:status;size:20;not null;default:PENDING;index"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime:milli"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime:milli"`
|
||||
}
|
||||
|
||||
func (Merchant) TableName() string { return "merchant" }
|
||||
|
||||
// MerchantApplication 商户进件申请
|
||||
// 一条记录对应一个商户在一个渠道的进件,(merchant_id, channel_code) 唯一
|
||||
type MerchantApplication struct {
|
||||
ID uint64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
ApplicationID string `gorm:"column:application_id;uniqueIndex;size:32;not null"`
|
||||
MerchantID string `gorm:"column:merchant_id;size:32;not null;index"`
|
||||
ChannelCode string `gorm:"column:channel_code;size:32;not null"`
|
||||
ChannelMerchantID string `gorm:"column:channel_merchant_id;size:64;not null;default:''"`
|
||||
SubmitData JSONMap `gorm:"column:submit_data;type:json"`
|
||||
AuditStatus AuditStatus `gorm:"column:audit_status;size:20;not null;default:SUBMITTING"`
|
||||
RejectReason string `gorm:"column:reject_reason;size:512"`
|
||||
SubmittedAt time.Time `gorm:"column:submitted_at"`
|
||||
AuditedAt *time.Time `gorm:"column:audited_at"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime:milli"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime:milli"`
|
||||
}
|
||||
|
||||
func (MerchantApplication) TableName() string { return "merchant_application" }
|
||||
37
backend/internal/model/notify_log.go
Normal file
37
backend/internal/model/notify_log.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// NotifyStatus 通知状态
|
||||
type NotifyStatus string
|
||||
|
||||
const (
|
||||
NotifyStatusPending NotifyStatus = "PENDING"
|
||||
NotifyStatusSuccess NotifyStatus = "SUCCESS"
|
||||
NotifyStatusRetry NotifyStatus = "RETRY"
|
||||
NotifyStatusGiveup NotifyStatus = "GIVEUP"
|
||||
)
|
||||
|
||||
// NotifyType 通知类型
|
||||
type NotifyType string
|
||||
|
||||
const (
|
||||
NotifyTypePayment NotifyType = "PAYMENT"
|
||||
NotifyTypeRefund NotifyType = "REFUND"
|
||||
)
|
||||
|
||||
// NotifyLog 下游通知记录
|
||||
type NotifyLog struct {
|
||||
ID uint64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
TradeNo string `gorm:"column:trade_no;size:32;not null;uniqueIndex:uk_trade_notify_type"`
|
||||
NotifyType NotifyType `gorm:"column:notify_type;size:20;not null;uniqueIndex:uk_trade_notify_type"`
|
||||
NotifyURL string `gorm:"column:notify_url;size:512;not null"`
|
||||
Status NotifyStatus `gorm:"column:status;size:20;not null;default:PENDING"`
|
||||
RetryCount int `gorm:"column:retry_count;not null;default:0"`
|
||||
NextRetryTime *time.Time `gorm:"column:next_retry_time;index:idx_status_next_retry"`
|
||||
LastResponse string `gorm:"column:last_response;type:text"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime:milli"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime:milli"`
|
||||
}
|
||||
|
||||
func (NotifyLog) TableName() string { return "notify_log" }
|
||||
25
backend/internal/model/order_sequence.go
Normal file
25
backend/internal/model/order_sequence.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// SeqType 序列类型
|
||||
type SeqType string
|
||||
|
||||
const (
|
||||
SeqTypeTrade SeqType = "TRADE"
|
||||
SeqTypeRefund SeqType = "REFUND"
|
||||
SeqTypeSharing SeqType = "SHARING"
|
||||
)
|
||||
|
||||
// OrderSequence 订单编码序列
|
||||
type OrderSequence struct {
|
||||
ID uint64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
AppID string `gorm:"column:app_id;size:32;not null;uniqueIndex:uk_app_type"`
|
||||
SeqType SeqType `gorm:"column:seq_type;size:20;not null;uniqueIndex:uk_app_type"`
|
||||
Prefix string `gorm:"column:prefix;size:8;not null"`
|
||||
CurrentValue uint64 `gorm:"column:current_value;not null;default:0"`
|
||||
Step int `gorm:"column:step;not null;default:1"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime:milli"`
|
||||
}
|
||||
|
||||
func (OrderSequence) TableName() string { return "order_sequence" }
|
||||
49
backend/internal/model/payment_match.go
Normal file
49
backend/internal/model/payment_match.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// MatchStatus 匹配状态
|
||||
type MatchStatus string
|
||||
|
||||
const (
|
||||
MatchStatusMatched MatchStatus = "MATCHED"
|
||||
MatchStatusPendingManual MatchStatus = "PENDING_MANUAL"
|
||||
MatchStatusNameDiff MatchStatus = "NAME_DIFF" // 匹配成功但名称不一致
|
||||
)
|
||||
|
||||
// SubMerchantAccount 子商户固定收款账户
|
||||
type SubMerchantAccount struct {
|
||||
ID uint64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
AppID string `gorm:"column:app_id;size:32;not null;index:idx_app_merchant"`
|
||||
SubMerchantID string `gorm:"column:sub_merchant_id;size:64;not null;index:idx_app_merchant"`
|
||||
ChannelCode string `gorm:"column:channel_code;size:32;not null"`
|
||||
AccountType string `gorm:"column:account_type;size:20;not null"` // BANK_CARD
|
||||
AccountNo string `gorm:"column:account_no;size:64;not null;index"` // 脱敏
|
||||
AccountNoEnc string `gorm:"column:account_no_enc;type:text"` // AES 加密完整账号
|
||||
AccountName string `gorm:"column:account_name;size:128;not null"`
|
||||
BankName string `gorm:"column:bank_name;size:64"`
|
||||
Status int8 `gorm:"column:status;not null;default:1"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime:milli"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime:milli"`
|
||||
}
|
||||
|
||||
func (SubMerchantAccount) TableName() string { return "sub_merchant_account" }
|
||||
|
||||
// PaymentMatchLog 收款匹配记录
|
||||
type PaymentMatchLog struct {
|
||||
ID uint64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
AccountID uint64 `gorm:"column:account_id;not null;index:idx_account_status"`
|
||||
TradeNo string `gorm:"column:trade_no;size:32;index"`
|
||||
IncomingAmount int64 `gorm:"column:incoming_amount;not null"`
|
||||
IncomingRemark string `gorm:"column:incoming_remark;size:256"`
|
||||
PayerName string `gorm:"column:payer_name;size:128"`
|
||||
ChannelBillNo string `gorm:"column:channel_bill_no;size:64;index"`
|
||||
MatchStatus MatchStatus `gorm:"column:match_status;size:20;not null;index:idx_account_status"`
|
||||
NameDiff int8 `gorm:"column:name_diff;not null;default:0"` // 1=名称不一致
|
||||
MatchTime *time.Time `gorm:"column:match_time"`
|
||||
Operator string `gorm:"column:operator;size:64"` // 人工操作者
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime:milli"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime:milli"`
|
||||
}
|
||||
|
||||
func (PaymentMatchLog) TableName() string { return "payment_match_log" }
|
||||
58
backend/internal/model/profit_sharing.go
Normal file
58
backend/internal/model/profit_sharing.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// ProfitSharingStatus 分润状态
|
||||
type ProfitSharingStatus string
|
||||
|
||||
const (
|
||||
ProfitSharingStatusPending ProfitSharingStatus = "PENDING"
|
||||
ProfitSharingStatusProcessing ProfitSharingStatus = "PROCESSING"
|
||||
ProfitSharingStatusSuccess ProfitSharingStatus = "SUCCESS"
|
||||
ProfitSharingStatusFailed ProfitSharingStatus = "FAILED"
|
||||
ProfitSharingStatusRollback ProfitSharingStatus = "ROLLBACK"
|
||||
)
|
||||
|
||||
// ProfitSharingConfig 分润配置(应用级)
|
||||
type ProfitSharingConfig struct {
|
||||
ID uint64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
AppID string `gorm:"column:app_id;size:32;not null;uniqueIndex"`
|
||||
ReceiverMerchantID string `gorm:"column:receiver_merchant_id;size:64;not null"`
|
||||
ReceiverType string `gorm:"column:receiver_type;size:20;not null"` // PLATFORM / SUB_MERCHANT
|
||||
MaxSharingRatio float64 `gorm:"column:max_sharing_ratio;type:decimal(5,4);not null"`
|
||||
Status int8 `gorm:"column:status;not null;default:1"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime:milli"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime:milli"`
|
||||
}
|
||||
|
||||
func (ProfitSharingConfig) TableName() string { return "profit_sharing_config" }
|
||||
|
||||
// ProfitSharingOrder 分润记录
|
||||
type ProfitSharingOrder struct {
|
||||
ID uint64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
SharingNo string `gorm:"column:sharing_no;uniqueIndex;size:32;not null"`
|
||||
TradeNo string `gorm:"column:trade_no;uniqueIndex;size:32;not null"`
|
||||
AppID string `gorm:"column:app_id;size:32;not null"`
|
||||
ReceiverMerchantID string `gorm:"column:receiver_merchant_id;size:64;not null"`
|
||||
SharingAmount int64 `gorm:"column:sharing_amount;not null"`
|
||||
Status ProfitSharingStatus `gorm:"column:status;size:20;not null;default:PENDING"`
|
||||
ChannelSharingNo string `gorm:"column:channel_sharing_no;size:64"`
|
||||
FailReason string `gorm:"column:fail_reason;size:256"`
|
||||
SharingTime *time.Time `gorm:"column:sharing_time"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime:milli"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime:milli"`
|
||||
}
|
||||
|
||||
func (ProfitSharingOrder) TableName() string { return "profit_sharing_order" }
|
||||
|
||||
// ProfitSharingLog 分润流水
|
||||
type ProfitSharingLog struct {
|
||||
ID uint64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
SharingNo string `gorm:"column:sharing_no;size:32;not null;index"`
|
||||
Action string `gorm:"column:action;size:20;not null"` // SPLIT / ROLLBACK
|
||||
Amount int64 `gorm:"column:amount;not null"`
|
||||
Status string `gorm:"column:status;size:20;not null"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime:milli"`
|
||||
}
|
||||
|
||||
func (ProfitSharingLog) TableName() string { return "profit_sharing_log" }
|
||||
44
backend/internal/model/reconciliation.go
Normal file
44
backend/internal/model/reconciliation.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// ReconciliationStatus 对账单状态
|
||||
type ReconciliationStatus string
|
||||
|
||||
const (
|
||||
ReconciliationStatusPending ReconciliationStatus = "PENDING"
|
||||
ReconciliationStatusMatched ReconciliationStatus = "MATCHED"
|
||||
ReconciliationStatusException ReconciliationStatus = "EXCEPTION"
|
||||
)
|
||||
|
||||
// ReconciliationReport 对账报告
|
||||
type ReconciliationReport struct {
|
||||
ID uint64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
AppID string `gorm:"column:app_id;size:32;not null;index:idx_app_date"`
|
||||
ChannelCode string `gorm:"column:channel_code;size:32;not null"`
|
||||
BillDate string `gorm:"column:bill_date;size:10;not null;index:idx_app_date"` // yyyy-MM-dd
|
||||
TotalCount int `gorm:"column:total_count;not null;default:0"`
|
||||
TotalAmount int64 `gorm:"column:total_amount;not null;default:0"` // 分
|
||||
MatchedCount int `gorm:"column:matched_count;not null;default:0"`
|
||||
ExceptionCount int `gorm:"column:exception_count;not null;default:0"`
|
||||
Status ReconciliationStatus `gorm:"column:status;size:20;not null;default:PENDING"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime:milli"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime:milli"`
|
||||
}
|
||||
|
||||
func (ReconciliationReport) TableName() string { return "reconciliation_report" }
|
||||
|
||||
// ReconciliationException 对账异常明细
|
||||
type ReconciliationException struct {
|
||||
ID uint64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
ReportID uint64 `gorm:"column:report_id;not null;index"`
|
||||
TradeNo string `gorm:"column:trade_no;size:32;index"`
|
||||
ChannelBillNo string `gorm:"column:channel_bill_no;size:64"`
|
||||
ExceptionType string `gorm:"column:exception_type;size:32;not null"` // MISSING_LOCAL/MISSING_CHANNEL/AMOUNT_MISMATCH
|
||||
LocalAmount int64 `gorm:"column:local_amount"`
|
||||
ChannelAmount int64 `gorm:"column:channel_amount"`
|
||||
Remark string `gorm:"column:remark;size:256"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime:milli"`
|
||||
}
|
||||
|
||||
func (ReconciliationException) TableName() string { return "reconciliation_exception" }
|
||||
32
backend/internal/model/refund_order.go
Normal file
32
backend/internal/model/refund_order.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// RefundStatus 退款状态
|
||||
type RefundStatus string
|
||||
|
||||
const (
|
||||
RefundStatusPending RefundStatus = "PENDING"
|
||||
RefundStatusProcessing RefundStatus = "PROCESSING"
|
||||
RefundStatusSuccess RefundStatus = "SUCCESS"
|
||||
RefundStatusFailed RefundStatus = "FAILED"
|
||||
)
|
||||
|
||||
// RefundOrder 退款记录
|
||||
type RefundOrder struct {
|
||||
ID uint64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
RefundNo string `gorm:"column:refund_no;uniqueIndex;size:32;not null"`
|
||||
TradeNo string `gorm:"column:trade_no;size:32;not null;index"`
|
||||
AppID string `gorm:"column:app_id;size:32;not null"`
|
||||
ChannelCode string `gorm:"column:channel_code;size:32;not null"`
|
||||
ChannelRefundNo string `gorm:"column:channel_refund_no;size:64"`
|
||||
RefundAmount int64 `gorm:"column:refund_amount;not null"`
|
||||
Reason string `gorm:"column:reason;size:256"`
|
||||
Status RefundStatus `gorm:"column:status;size:20;not null;default:PENDING"`
|
||||
NotifyURL string `gorm:"column:notify_url;size:512"`
|
||||
RefundTime *time.Time `gorm:"column:refund_time"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime:milli"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime:milli"`
|
||||
}
|
||||
|
||||
func (RefundOrder) TableName() string { return "refund_order" }
|
||||
56
backend/internal/model/service_fee.go
Normal file
56
backend/internal/model/service_fee.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// PayMethodGroup 支付方式分组(用于服务费配置)
|
||||
type PayMethodGroup string
|
||||
|
||||
const (
|
||||
PayMethodGroupScan PayMethodGroup = "SCAN" // 扫码支付(微信/支付宝)
|
||||
PayMethodGroupTransfer PayMethodGroup = "TRANSFER" // 对公转账
|
||||
PayMethodGroupBalance PayMethodGroup = "BALANCE" // 余额支付
|
||||
)
|
||||
|
||||
// ServiceFeeConfig 服务费配置
|
||||
type ServiceFeeConfig struct {
|
||||
ID uint64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
AppID string `gorm:"column:app_id;size:32;not null;uniqueIndex:uk_app_method"`
|
||||
PayMethodGroup PayMethodGroup `gorm:"column:pay_method_group;size:20;not null;uniqueIndex:uk_app_method"`
|
||||
FeeRate float64 `gorm:"column:fee_rate;type:decimal(6,4);not null"` // 0.0000 ~ 9.9999%
|
||||
FeeReceiverMerchantID string `gorm:"column:fee_receiver_merchant_id;size:64;not null"`
|
||||
Status int8 `gorm:"column:status;not null;default:1"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime:milli"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime:milli"`
|
||||
}
|
||||
|
||||
func (ServiceFeeConfig) TableName() string { return "service_fee_config" }
|
||||
|
||||
// ServiceFeeLog 服务费流水
|
||||
type ServiceFeeLog struct {
|
||||
ID uint64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
TradeNo string `gorm:"column:trade_no;size:32;not null;uniqueIndex:uk_trade_action"`
|
||||
ConfigID uint64 `gorm:"column:config_id;not null"`
|
||||
FeeAmount int64 `gorm:"column:fee_amount;not null"`
|
||||
FeeRate float64 `gorm:"column:fee_rate;type:decimal(6,4);not null"`
|
||||
ReceiverMerchantID string `gorm:"column:receiver_merchant_id;size:64;not null"`
|
||||
Action string `gorm:"column:action;size:20;not null;uniqueIndex:uk_trade_action"` // CHARGE / ROLLBACK
|
||||
Status string `gorm:"column:status;size:20;not null;default:PENDING"`
|
||||
ChannelSharingNo string `gorm:"column:channel_sharing_no;size:64"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime:milli"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime:milli"`
|
||||
}
|
||||
|
||||
func (ServiceFeeLog) TableName() string { return "service_fee_log" }
|
||||
|
||||
// PayMethodToGroup 将支付方式映射到服务费分组
|
||||
func PayMethodToGroup(m PayMethod) PayMethodGroup {
|
||||
switch m {
|
||||
case PayMethodWechatJSAPI, PayMethodWechatH5, PayMethodWechatNative,
|
||||
PayMethodWechatMini, PayMethodAlipay, PayMethodQuickPay:
|
||||
return PayMethodGroupScan
|
||||
case PayMethodTransfer:
|
||||
return PayMethodGroupTransfer
|
||||
default:
|
||||
return PayMethodGroupBalance
|
||||
}
|
||||
}
|
||||
103
backend/internal/model/trade_order.go
Normal file
103
backend/internal/model/trade_order.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TradeStatus 交易状态
|
||||
type TradeStatus string
|
||||
|
||||
const (
|
||||
TradeStatusCreating TradeStatus = "CREATING"
|
||||
TradeStatusPaying TradeStatus = "PAYING"
|
||||
TradeStatusPaid TradeStatus = "PAID"
|
||||
TradeStatusClosed TradeStatus = "CLOSED"
|
||||
TradeStatusFailed TradeStatus = "FAILED"
|
||||
TradeStatusCreateFailed TradeStatus = "CREATE_FAILED"
|
||||
TradeStatusRefunded TradeStatus = "REFUNDED"
|
||||
)
|
||||
|
||||
// PayMethod 支付方式
|
||||
type PayMethod string
|
||||
|
||||
const (
|
||||
PayMethodWechatJSAPI PayMethod = "WECHAT_JSAPI"
|
||||
PayMethodWechatH5 PayMethod = "WECHAT_H5"
|
||||
PayMethodWechatNative PayMethod = "WECHAT_NATIVE"
|
||||
PayMethodWechatMini PayMethod = "WECHAT_MINI"
|
||||
PayMethodAlipay PayMethod = "ALIPAY"
|
||||
PayMethodQuickPay PayMethod = "QUICK_PAY"
|
||||
PayMethodTransfer PayMethod = "TRANSFER" // 对公转账
|
||||
)
|
||||
|
||||
// JSONMap JSON 字段类型
|
||||
type JSONMap map[string]any
|
||||
|
||||
func (j JSONMap) Value() (driver.Value, error) {
|
||||
if j == nil {
|
||||
return nil, nil
|
||||
}
|
||||
b, err := json.Marshal(j)
|
||||
return string(b), err
|
||||
}
|
||||
|
||||
func (j *JSONMap) Scan(value any) error {
|
||||
if value == nil {
|
||||
*j = nil
|
||||
return nil
|
||||
}
|
||||
var bytes []byte
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
bytes = []byte(v)
|
||||
case []byte:
|
||||
bytes = v
|
||||
default:
|
||||
return fmt.Errorf("unsupported type: %T", value)
|
||||
}
|
||||
return json.Unmarshal(bytes, j)
|
||||
}
|
||||
|
||||
// TradeOrder 交易订单
|
||||
type TradeOrder struct {
|
||||
ID uint64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
TradeNo string `gorm:"column:trade_no;uniqueIndex;size:32;not null"`
|
||||
MerchantOrderNo string `gorm:"column:merchant_order_no;size:64;not null"`
|
||||
AppID string `gorm:"column:app_id;size:32;not null;index:idx_app_merchant,unique"`
|
||||
ChannelCode string `gorm:"column:channel_code;size:32;not null"`
|
||||
ChannelTradeNo string `gorm:"column:channel_trade_no;size:64;index"`
|
||||
PayMethod PayMethod `gorm:"column:pay_method;size:32;not null"`
|
||||
Amount int64 `gorm:"column:amount;not null"`
|
||||
ProfitSharingAmount int64 `gorm:"column:profit_sharing_amount;not null;default:0"`
|
||||
ServiceFeeAmount int64 `gorm:"column:service_fee_amount;not null;default:0"`
|
||||
Subject string `gorm:"column:subject;size:256;not null"`
|
||||
NotifyURL string `gorm:"column:notify_url;size:512;not null"`
|
||||
Status TradeStatus `gorm:"column:status;size:20;not null;default:CREATING"`
|
||||
Extra JSONMap `gorm:"column:extra;type:json"`
|
||||
ChannelExtra JSONMap `gorm:"column:channel_extra;type:json"`
|
||||
ExpireTime time.Time `gorm:"column:expire_time;not null"`
|
||||
PayTime *time.Time `gorm:"column:pay_time"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime:milli"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime:milli"`
|
||||
}
|
||||
|
||||
func (TradeOrder) TableName() string { return "trade_order" }
|
||||
|
||||
// CanTransitTo 校验状态流转是否合法
|
||||
func (t TradeStatus) CanTransitTo(next TradeStatus) bool {
|
||||
allowed := map[TradeStatus][]TradeStatus{
|
||||
TradeStatusCreating: {TradeStatusPaying, TradeStatusCreateFailed},
|
||||
TradeStatusPaying: {TradeStatusPaid, TradeStatusClosed, TradeStatusFailed},
|
||||
TradeStatusPaid: {TradeStatusRefunded},
|
||||
TradeStatusCreateFailed: {TradeStatusPaying},
|
||||
}
|
||||
for _, s := range allowed[t] {
|
||||
if s == next {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
41
backend/internal/model/wechat.go
Normal file
41
backend/internal/model/wechat.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
// WechatBinding 商户微信公众号绑定
|
||||
type WechatBinding struct {
|
||||
ID uint64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
AppID string `gorm:"column:app_id;size:32;not null;uniqueIndex"`
|
||||
WxAppID string `gorm:"column:wx_app_id;size:32;not null"` // 微信公众号/小程序 AppID
|
||||
WxSecret string `gorm:"column:wx_secret;type:text;not null"` // AES 加密存储
|
||||
TemplateID string `gorm:"column:template_id;size:64;not null"` // 消息模板 ID
|
||||
Status int8 `gorm:"column:status;not null;default:1"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime:milli"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime:milli"`
|
||||
}
|
||||
|
||||
func (WechatBinding) TableName() string { return "wechat_binding" }
|
||||
|
||||
// WechatMessageStatus 微信消息发送状态
|
||||
type WechatMessageStatus string
|
||||
|
||||
const (
|
||||
WechatMessageStatusPending WechatMessageStatus = "PENDING"
|
||||
WechatMessageStatusSuccess WechatMessageStatus = "SUCCESS"
|
||||
WechatMessageStatusFailed WechatMessageStatus = "FAILED"
|
||||
)
|
||||
|
||||
// WechatMessageLog 微信消息发送日志
|
||||
type WechatMessageLog struct {
|
||||
ID uint64 `gorm:"column:id;primaryKey;autoIncrement"`
|
||||
AppID string `gorm:"column:app_id;size:32;not null;index"`
|
||||
TradeNo string `gorm:"column:trade_no;size:32;index"`
|
||||
OpenID string `gorm:"column:open_id;size:64;not null"`
|
||||
TemplateID string `gorm:"column:template_id;size:64;not null"`
|
||||
Status WechatMessageStatus `gorm:"column:status;size:20;not null;default:PENDING"`
|
||||
ErrMsg string `gorm:"column:err_msg;size:256"`
|
||||
SentAt *time.Time `gorm:"column:sent_at"`
|
||||
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime:milli"`
|
||||
}
|
||||
|
||||
func (WechatMessageLog) TableName() string { return "wechat_message_log" }
|
||||
30
backend/internal/repository/admin_user.go
Normal file
30
backend/internal/repository/admin_user.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
type AdminUserRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewAdminUserRepository(db *gorm.DB) *AdminUserRepository {
|
||||
return &AdminUserRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *AdminUserRepository) GetByUsername(ctx context.Context, username string) (*model.AdminUser, error) {
|
||||
var user model.AdminUser
|
||||
err := r.db.WithContext(ctx).Where("username = ? AND status = 1", username).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *AdminUserRepository) Create(ctx context.Context, user *model.AdminUser) error {
|
||||
return r.db.WithContext(ctx).Create(user).Error
|
||||
}
|
||||
71
backend/internal/repository/app.go
Normal file
71
backend/internal/repository/app.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
// AppRepository app 数据访问
|
||||
type AppRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewAppRepository(db *gorm.DB) *AppRepository {
|
||||
return &AppRepository{db: db}
|
||||
}
|
||||
|
||||
// GetByAppID 根据 appId 查询
|
||||
func (r *AppRepository) GetByAppID(ctx context.Context, appID string) (*model.App, error) {
|
||||
var app model.App
|
||||
err := r.db.WithContext(ctx).Where("app_id = ? AND status = 1", appID).First(&app).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return &app, err
|
||||
}
|
||||
|
||||
// Create 创建应用
|
||||
func (r *AppRepository) Create(ctx context.Context, app *model.App) error {
|
||||
return r.db.WithContext(ctx).Create(app).Error
|
||||
}
|
||||
|
||||
// ListActive 查询所有启用的应用
|
||||
func (r *AppRepository) ListActive(ctx context.Context) ([]*model.App, error) {
|
||||
var apps []*model.App
|
||||
err := r.db.WithContext(ctx).Where("status = 1").Find(&apps).Error
|
||||
return apps, err
|
||||
}
|
||||
|
||||
// List 分页查询所有应用(不过滤状态)
|
||||
func (r *AppRepository) List(ctx context.Context, limit, offset int) ([]*model.App, error) {
|
||||
var apps []*model.App
|
||||
err := r.db.WithContext(ctx).Order("id DESC").Limit(limit).Offset(offset).Find(&apps).Error
|
||||
return apps, err
|
||||
}
|
||||
|
||||
// GetByAppIDUnscoped 不过滤状态地查询(用于管理接口)
|
||||
func (r *AppRepository) GetByAppIDUnscoped(ctx context.Context, appID string) (*model.App, error) {
|
||||
var app model.App
|
||||
err := r.db.WithContext(ctx).Where("app_id = ?", appID).First(&app).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return &app, err
|
||||
}
|
||||
|
||||
// UpdateStatus 更新应用状态
|
||||
func (r *AppRepository) UpdateStatus(ctx context.Context, appID string, status int8) error {
|
||||
return r.db.WithContext(ctx).Model(&model.App{}).
|
||||
Where("app_id = ?", appID).
|
||||
Update("status", status).Error
|
||||
}
|
||||
|
||||
// UpdateSecret 更新应用密钥
|
||||
func (r *AppRepository) UpdateSecret(ctx context.Context, appID, encSecret string) error {
|
||||
return r.db.WithContext(ctx).Model(&model.App{}).
|
||||
Where("app_id = ?", appID).
|
||||
Update("app_secret", encSecret).Error
|
||||
}
|
||||
45
backend/internal/repository/channel_config.go
Normal file
45
backend/internal/repository/channel_config.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
// ChannelConfigRepository 渠道配置数据访问
|
||||
type ChannelConfigRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewChannelConfigRepository(db *gorm.DB) *ChannelConfigRepository {
|
||||
return &ChannelConfigRepository{db: db}
|
||||
}
|
||||
|
||||
// GetByAppChannel 按 app_id + channel_code 查询
|
||||
func (r *ChannelConfigRepository) GetByAppChannel(ctx context.Context, appID, channelCode string) (*model.ChannelConfig, error) {
|
||||
var cfg model.ChannelConfig
|
||||
err := r.db.WithContext(ctx).Where("app_id = ? AND channel_code = ? AND status = 1", appID, channelCode).First(&cfg).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return &cfg, err
|
||||
}
|
||||
|
||||
// Create 创建渠道配置
|
||||
func (r *ChannelConfigRepository) Create(ctx context.Context, cfg *model.ChannelConfig) error {
|
||||
return r.db.WithContext(ctx).Create(cfg).Error
|
||||
}
|
||||
|
||||
// Update 更新渠道配置
|
||||
func (r *ChannelConfigRepository) Update(ctx context.Context, id uint64, updates map[string]any) error {
|
||||
return r.db.WithContext(ctx).Model(&model.ChannelConfig{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// ListByApp 查询应用下所有启用的渠道配置
|
||||
func (r *ChannelConfigRepository) ListByApp(ctx context.Context, appID string) ([]*model.ChannelConfig, error) {
|
||||
var cfgs []*model.ChannelConfig
|
||||
err := r.db.WithContext(ctx).Where("app_id = ? AND status = 1", appID).Find(&cfgs).Error
|
||||
return cfgs, err
|
||||
}
|
||||
115
backend/internal/repository/merchant.go
Normal file
115
backend/internal/repository/merchant.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
// MerchantRepository 商户数据访问
|
||||
type MerchantRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewMerchantRepository(db *gorm.DB) *MerchantRepository {
|
||||
return &MerchantRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *MerchantRepository) Create(ctx context.Context, m *model.Merchant) error {
|
||||
return r.db.WithContext(ctx).Create(m).Error
|
||||
}
|
||||
|
||||
func (r *MerchantRepository) GetByMerchantID(ctx context.Context, merchantID string) (*model.Merchant, error) {
|
||||
var m model.Merchant
|
||||
err := r.db.WithContext(ctx).Where("merchant_id = ?", merchantID).First(&m).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return &m, err
|
||||
}
|
||||
|
||||
func (r *MerchantRepository) UpdateStatus(ctx context.Context, merchantID string, status model.MerchantStatus, updates map[string]any) error {
|
||||
if updates == nil {
|
||||
updates = make(map[string]any)
|
||||
}
|
||||
updates["status"] = status
|
||||
return r.db.WithContext(ctx).Model(&model.Merchant{}).Where("merchant_id = ?", merchantID).Updates(updates).Error
|
||||
}
|
||||
|
||||
func (r *MerchantRepository) List(ctx context.Context, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
|
||||
var merchants []*model.Merchant
|
||||
q := r.db.WithContext(ctx)
|
||||
if status != "" {
|
||||
q = q.Where("status = ?", status)
|
||||
}
|
||||
err := q.Order("created_at DESC").Limit(limit).Offset(offset).Find(&merchants).Error
|
||||
return merchants, err
|
||||
}
|
||||
|
||||
// ListAnomalous 查询状态异常的商户(Frozen/Rejected)
|
||||
func (r *MerchantRepository) ListAnomalous(ctx context.Context) ([]*model.Merchant, error) {
|
||||
var merchants []*model.Merchant
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status IN ?", []model.MerchantStatus{
|
||||
model.MerchantStatusFrozen,
|
||||
model.MerchantStatusRejected,
|
||||
}).Find(&merchants).Error
|
||||
return merchants, err
|
||||
}
|
||||
|
||||
// GetByMerchantIDAndAppID 带 appID 隔离查询(业务侧用)
|
||||
func (r *MerchantRepository) GetByMerchantIDAndAppID(ctx context.Context, merchantID, appID string) (*model.Merchant, error) {
|
||||
var m model.Merchant
|
||||
err := r.db.WithContext(ctx).Where("merchant_id = ? AND app_id = ?", merchantID, appID).First(&m).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return &m, err
|
||||
}
|
||||
|
||||
// ListByAppID 按 appID 分页查询(业务侧用)
|
||||
func (r *MerchantRepository) ListByAppID(ctx context.Context, appID string, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
|
||||
var merchants []*model.Merchant
|
||||
q := r.db.WithContext(ctx).Where("app_id = ?", appID)
|
||||
if status != "" {
|
||||
q = q.Where("status = ?", status)
|
||||
}
|
||||
err := q.Limit(limit).Offset(offset).Order("id DESC").Find(&merchants).Error
|
||||
return merchants, err
|
||||
}
|
||||
|
||||
// CreateApplication 创建进件申请
|
||||
func (r *MerchantRepository) CreateApplication(ctx context.Context, app *model.MerchantApplication) error {
|
||||
return r.db.WithContext(ctx).Create(app).Error
|
||||
}
|
||||
|
||||
// GetLatestApplication 获取商户最新进件申请
|
||||
func (r *MerchantRepository) GetLatestApplication(ctx context.Context, merchantID string) (*model.MerchantApplication, error) {
|
||||
var app model.MerchantApplication
|
||||
err := r.db.WithContext(ctx).Where("merchant_id = ?", merchantID).
|
||||
Order("created_at DESC").First(&app).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return &app, err
|
||||
}
|
||||
|
||||
// GetApprovedApplicationByChannel 查询指定商户在指定渠道已审核通过的进件记录
|
||||
func (r *MerchantRepository) GetApprovedApplicationByChannel(ctx context.Context, merchantID, channelCode string) (*model.MerchantApplication, error) {
|
||||
var app model.MerchantApplication
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("merchant_id = ? AND channel_code = ? AND audit_status = ?", merchantID, channelCode, model.AuditStatusApproved).
|
||||
First(&app).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return &app, err
|
||||
}
|
||||
|
||||
// UpdateApplication 更新进件申请状态
|
||||
func (r *MerchantRepository) UpdateApplication(ctx context.Context, applicationID string, updates map[string]any) error {
|
||||
return r.db.WithContext(ctx).Model(&model.MerchantApplication{}).
|
||||
Where("application_id = ?", applicationID).Updates(updates).Error
|
||||
}
|
||||
75
backend/internal/repository/notify_log.go
Normal file
75
backend/internal/repository/notify_log.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
// NotifyLogRepository 通知记录数据访问
|
||||
type NotifyLogRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewNotifyLogRepository(db *gorm.DB) *NotifyLogRepository {
|
||||
return &NotifyLogRepository{db: db}
|
||||
}
|
||||
|
||||
// Upsert 创建或更新通知记录
|
||||
func (r *NotifyLogRepository) Upsert(ctx context.Context, log *model.NotifyLog) error {
|
||||
return r.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "trade_no"}, {Name: "notify_type"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"notify_url", "status", "retry_count", "next_retry_time", "last_response"}),
|
||||
}).Create(log).Error
|
||||
}
|
||||
|
||||
// GetByTradeNo 按 trade_no + notify_type 查询
|
||||
func (r *NotifyLogRepository) GetByTradeNo(ctx context.Context, tradeNo string, notifyType model.NotifyType) (*model.NotifyLog, error) {
|
||||
var log model.NotifyLog
|
||||
err := r.db.WithContext(ctx).Where("trade_no = ? AND notify_type = ?", tradeNo, notifyType).First(&log).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return &log, err
|
||||
}
|
||||
|
||||
// ListPendingRetry 查询到期需要重试的通知
|
||||
func (r *NotifyLogRepository) ListPendingRetry(ctx context.Context, before time.Time, limit int) ([]*model.NotifyLog, error) {
|
||||
var logs []*model.NotifyLog
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status IN ? AND next_retry_time <= ?",
|
||||
[]model.NotifyStatus{model.NotifyStatusPending, model.NotifyStatusRetry},
|
||||
before).
|
||||
Order("next_retry_time ASC").
|
||||
Limit(limit).
|
||||
Find(&logs).Error
|
||||
return logs, err
|
||||
}
|
||||
|
||||
// IncrRetryCount 重试次数+1,更新下次重试时间和最后响应
|
||||
func (r *NotifyLogRepository) IncrRetryCount(ctx context.Context, id uint64, status model.NotifyStatus, nextRetryTime *time.Time, lastResponse string) error {
|
||||
return r.db.WithContext(ctx).Model(&model.NotifyLog{}).Where("id = ?", id).Updates(map[string]any{
|
||||
"retry_count": gorm.Expr("retry_count + 1"),
|
||||
"status": status,
|
||||
"next_retry_time": nextRetryTime,
|
||||
"last_response": lastResponse,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// MarkSuccess 标记通知成功
|
||||
func (r *NotifyLogRepository) MarkSuccess(ctx context.Context, id uint64, lastResponse string) error {
|
||||
return r.db.WithContext(ctx).Model(&model.NotifyLog{}).Where("id = ?", id).Updates(map[string]any{
|
||||
"status": model.NotifyStatusSuccess,
|
||||
"last_response": lastResponse,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// MarkGiveup 标记放弃通知
|
||||
func (r *NotifyLogRepository) MarkGiveup(ctx context.Context, id uint64) error {
|
||||
return r.db.WithContext(ctx).Model(&model.NotifyLog{}).Where("id = ?", id).
|
||||
Update("status", model.NotifyStatusGiveup).Error
|
||||
}
|
||||
95
backend/internal/repository/payment_match.go
Normal file
95
backend/internal/repository/payment_match.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
// PaymentMatchRepository 收款匹配数据访问
|
||||
type PaymentMatchRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewPaymentMatchRepository(db *gorm.DB) *PaymentMatchRepository {
|
||||
return &PaymentMatchRepository{db: db}
|
||||
}
|
||||
|
||||
// GetAccountByNo 按账号查询子商户收款账户
|
||||
func (r *PaymentMatchRepository) GetAccountByNo(ctx context.Context, accountNo string) (*model.SubMerchantAccount, error) {
|
||||
var acc model.SubMerchantAccount
|
||||
err := r.db.WithContext(ctx).Where("account_no = ? AND status = 1", accountNo).First(&acc).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return &acc, err
|
||||
}
|
||||
|
||||
// GetAccountByID 按 id 查询
|
||||
func (r *PaymentMatchRepository) GetAccountByID(ctx context.Context, id uint64) (*model.SubMerchantAccount, error) {
|
||||
var acc model.SubMerchantAccount
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&acc).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return &acc, err
|
||||
}
|
||||
|
||||
// ListAccountsByApp 查询应用下所有收款账户
|
||||
func (r *PaymentMatchRepository) ListAccountsByApp(ctx context.Context, appID string) ([]*model.SubMerchantAccount, error) {
|
||||
var accs []*model.SubMerchantAccount
|
||||
err := r.db.WithContext(ctx).Where("app_id = ? AND status = 1", appID).Find(&accs).Error
|
||||
return accs, err
|
||||
}
|
||||
|
||||
// CreateAccount 创建收款账户
|
||||
func (r *PaymentMatchRepository) CreateAccount(ctx context.Context, acc *model.SubMerchantAccount) error {
|
||||
return r.db.WithContext(ctx).Create(acc).Error
|
||||
}
|
||||
|
||||
// CreateMatchLog 创建匹配记录
|
||||
func (r *PaymentMatchRepository) CreateMatchLog(ctx context.Context, log *model.PaymentMatchLog) error {
|
||||
return r.db.WithContext(ctx).Create(log).Error
|
||||
}
|
||||
|
||||
// GetMatchLogByBillNo 按渠道流水号查询(幂等检查)
|
||||
func (r *PaymentMatchRepository) GetMatchLogByBillNo(ctx context.Context, channelBillNo string) (*model.PaymentMatchLog, error) {
|
||||
var log model.PaymentMatchLog
|
||||
err := r.db.WithContext(ctx).Where("channel_bill_no = ?", channelBillNo).First(&log).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return &log, err
|
||||
}
|
||||
|
||||
// UpdateMatchLog 更新匹配记录
|
||||
func (r *PaymentMatchRepository) UpdateMatchLog(ctx context.Context, id uint64, updates map[string]any) error {
|
||||
return r.db.WithContext(ctx).Model(&model.PaymentMatchLog{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// ListPendingManual 查询待人工确认的记录
|
||||
func (r *PaymentMatchRepository) ListPendingManual(ctx context.Context, appID string, limit, offset int) ([]*model.PaymentMatchLog, error) {
|
||||
var logs []*model.PaymentMatchLog
|
||||
err := r.db.WithContext(ctx).
|
||||
Joins("JOIN sub_merchant_account ON sub_merchant_account.id = payment_match_log.account_id").
|
||||
Where("sub_merchant_account.app_id = ? AND payment_match_log.match_status = ?",
|
||||
appID, model.MatchStatusPendingManual).
|
||||
Order("payment_match_log.created_at DESC").
|
||||
Limit(limit).Offset(offset).
|
||||
Find(&logs).Error
|
||||
return logs, err
|
||||
}
|
||||
|
||||
// ListPayingByAmount 按金额查询指定时间窗口内的待支付订单(用于收款匹配降级)
|
||||
func (r *PaymentMatchRepository) ListPayingByAmount(ctx context.Context, appID string, amount int64, window time.Duration) ([]*model.TradeOrder, error) {
|
||||
var orders []*model.TradeOrder
|
||||
since := time.Now().Add(-window)
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("app_id = ? AND amount = ? AND status = ? AND created_at >= ?",
|
||||
appID, amount, model.TradeStatusPaying, since).
|
||||
Find(&orders).Error
|
||||
return orders, err
|
||||
}
|
||||
90
backend/internal/repository/profit_sharing.go
Normal file
90
backend/internal/repository/profit_sharing.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
// ProfitSharingRepository 分润数据访问
|
||||
type ProfitSharingRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewProfitSharingRepository(db *gorm.DB) *ProfitSharingRepository {
|
||||
return &ProfitSharingRepository{db: db}
|
||||
}
|
||||
|
||||
// GetConfigByAppID 按 app_id 获取分润配置
|
||||
func (r *ProfitSharingRepository) GetConfigByAppID(ctx context.Context, appID string) (*model.ProfitSharingConfig, error) {
|
||||
var cfg model.ProfitSharingConfig
|
||||
err := r.db.WithContext(ctx).Where("app_id = ? AND status = 1", appID).First(&cfg).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return &cfg, err
|
||||
}
|
||||
|
||||
// SaveConfig 创建或更新分润配置
|
||||
func (r *ProfitSharingRepository) SaveConfig(ctx context.Context, cfg *model.ProfitSharingConfig) error {
|
||||
return r.db.WithContext(ctx).Save(cfg).Error
|
||||
}
|
||||
|
||||
// CreateOrder 创建分润记录
|
||||
func (r *ProfitSharingRepository) CreateOrder(ctx context.Context, order *model.ProfitSharingOrder) error {
|
||||
return r.db.WithContext(ctx).Create(order).Error
|
||||
}
|
||||
|
||||
// GetOrderByTradeNo 按 trade_no 查询分润记录
|
||||
func (r *ProfitSharingRepository) GetOrderByTradeNo(ctx context.Context, tradeNo string) (*model.ProfitSharingOrder, error) {
|
||||
var order model.ProfitSharingOrder
|
||||
err := r.db.WithContext(ctx).Where("trade_no = ?", tradeNo).First(&order).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return &order, err
|
||||
}
|
||||
|
||||
// GetOrderBySharingNo 按 sharing_no 查询
|
||||
func (r *ProfitSharingRepository) GetOrderBySharingNo(ctx context.Context, sharingNo string) (*model.ProfitSharingOrder, error) {
|
||||
var order model.ProfitSharingOrder
|
||||
err := r.db.WithContext(ctx).Where("sharing_no = ?", sharingNo).First(&order).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return &order, err
|
||||
}
|
||||
|
||||
// UpdateOrderStatus 更新分润状态
|
||||
func (r *ProfitSharingRepository) UpdateOrderStatus(ctx context.Context, sharingNo string, fromStatus, toStatus model.ProfitSharingStatus, updates map[string]any) (bool, error) {
|
||||
if updates == nil {
|
||||
updates = make(map[string]any)
|
||||
}
|
||||
updates["status"] = toStatus
|
||||
result := r.db.WithContext(ctx).Model(&model.ProfitSharingOrder{}).
|
||||
Where("sharing_no = ? AND status = ?", sharingNo, fromStatus).
|
||||
Updates(updates)
|
||||
if result.Error != nil {
|
||||
return false, result.Error
|
||||
}
|
||||
return result.RowsAffected > 0, nil
|
||||
}
|
||||
|
||||
// CreateLog 记录分润流水
|
||||
func (r *ProfitSharingRepository) CreateLog(ctx context.Context, log *model.ProfitSharingLog) error {
|
||||
return r.db.WithContext(ctx).Create(log).Error
|
||||
}
|
||||
|
||||
// ListPendingOrders 查询需要补偿的分润单(PROCESSING 超时)
|
||||
func (r *ProfitSharingRepository) ListPendingOrders(ctx context.Context, limit int) ([]*model.ProfitSharingOrder, error) {
|
||||
var orders []*model.ProfitSharingOrder
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status IN ?", []model.ProfitSharingStatus{
|
||||
model.ProfitSharingStatusPending,
|
||||
model.ProfitSharingStatusProcessing,
|
||||
}).
|
||||
Limit(limit).Find(&orders).Error
|
||||
return orders, err
|
||||
}
|
||||
62
backend/internal/repository/reconciliation.go
Normal file
62
backend/internal/repository/reconciliation.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
// ReconciliationRepository 对账数据访问
|
||||
type ReconciliationRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewReconciliationRepository(db *gorm.DB) *ReconciliationRepository {
|
||||
return &ReconciliationRepository{db: db}
|
||||
}
|
||||
|
||||
// CreateReport 创建对账报告
|
||||
func (r *ReconciliationRepository) CreateReport(ctx context.Context, report *model.ReconciliationReport) error {
|
||||
return r.db.WithContext(ctx).Create(report).Error
|
||||
}
|
||||
|
||||
// GetReport 查询对账报告
|
||||
func (r *ReconciliationRepository) GetReport(ctx context.Context, appID, billDate, channelCode string) (*model.ReconciliationReport, error) {
|
||||
var report model.ReconciliationReport
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("app_id = ? AND bill_date = ? AND channel_code = ?", appID, billDate, channelCode).
|
||||
First(&report).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return &report, err
|
||||
}
|
||||
|
||||
// UpdateReport 更新对账报告
|
||||
func (r *ReconciliationRepository) UpdateReport(ctx context.Context, id uint64, updates map[string]any) error {
|
||||
return r.db.WithContext(ctx).Model(&model.ReconciliationReport{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
|
||||
// CreateException 创建对账异常记录
|
||||
func (r *ReconciliationRepository) CreateException(ctx context.Context, ex *model.ReconciliationException) error {
|
||||
return r.db.WithContext(ctx).Create(ex).Error
|
||||
}
|
||||
|
||||
// ListExceptions 查询报告下的异常明细
|
||||
func (r *ReconciliationRepository) ListExceptions(ctx context.Context, reportID uint64) ([]*model.ReconciliationException, error) {
|
||||
var exs []*model.ReconciliationException
|
||||
err := r.db.WithContext(ctx).Where("report_id = ?", reportID).Find(&exs).Error
|
||||
return exs, err
|
||||
}
|
||||
|
||||
// ListPaidOrdersByDate 查询指定日期的已支付订单(用于对账)
|
||||
func (r *ReconciliationRepository) ListPaidOrdersByDate(ctx context.Context, appID, date string) ([]*model.TradeOrder, error) {
|
||||
var orders []*model.TradeOrder
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("app_id = ? AND status = ? AND DATE(pay_time) = ?",
|
||||
appID, model.TradeStatusPaid, date).
|
||||
Find(&orders).Error
|
||||
return orders, err
|
||||
}
|
||||
63
backend/internal/repository/refund_order.go
Normal file
63
backend/internal/repository/refund_order.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
// RefundOrderRepository 退款记录数据访问
|
||||
type RefundOrderRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewRefundOrderRepository(db *gorm.DB) *RefundOrderRepository {
|
||||
return &RefundOrderRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建退款单
|
||||
func (r *RefundOrderRepository) Create(ctx context.Context, refund *model.RefundOrder) error {
|
||||
return r.db.WithContext(ctx).Create(refund).Error
|
||||
}
|
||||
|
||||
// GetByRefundNo 按 refund_no 查询
|
||||
func (r *RefundOrderRepository) GetByRefundNo(ctx context.Context, refundNo string) (*model.RefundOrder, error) {
|
||||
var refund model.RefundOrder
|
||||
err := r.db.WithContext(ctx).Where("refund_no = ?", refundNo).First(&refund).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return &refund, err
|
||||
}
|
||||
|
||||
// SumRefundedAmount 统计某笔交易已退款总额(成功+处理中)
|
||||
func (r *RefundOrderRepository) SumRefundedAmount(ctx context.Context, tradeNo string) (int64, error) {
|
||||
var total int64
|
||||
err := r.db.WithContext(ctx).Model(&model.RefundOrder{}).
|
||||
Where("trade_no = ? AND status IN ?", tradeNo, []model.RefundStatus{
|
||||
model.RefundStatusPending,
|
||||
model.RefundStatusProcessing,
|
||||
model.RefundStatusSuccess,
|
||||
}).
|
||||
Select("COALESCE(SUM(refund_amount), 0)").
|
||||
Scan(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// UpdateStatus 更新退款状态
|
||||
func (r *RefundOrderRepository) UpdateStatus(ctx context.Context, refundNo string, fromStatus, toStatus model.RefundStatus, updates map[string]any) (bool, error) {
|
||||
if updates == nil {
|
||||
updates = make(map[string]any)
|
||||
}
|
||||
updates["status"] = toStatus
|
||||
|
||||
result := r.db.WithContext(ctx).Model(&model.RefundOrder{}).
|
||||
Where("refund_no = ? AND status = ?", refundNo, fromStatus).
|
||||
Updates(updates)
|
||||
if result.Error != nil {
|
||||
return false, result.Error
|
||||
}
|
||||
return result.RowsAffected > 0, nil
|
||||
}
|
||||
90
backend/internal/repository/sequence.go
Normal file
90
backend/internal/repository/sequence.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
// SequenceRepository 序列数据访问
|
||||
type SequenceRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewSequenceRepository(db *gorm.DB) *SequenceRepository {
|
||||
return &SequenceRepository{db: db}
|
||||
}
|
||||
|
||||
// IncrAndGet 原子自增并返回新值(行级锁)
|
||||
func (r *SequenceRepository) IncrAndGet(ctx context.Context, appID string, seqType model.SeqType) (uint64, error) {
|
||||
var seq model.OrderSequence
|
||||
|
||||
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// 加行锁读取
|
||||
if err := tx.Set("gorm:query_option", "FOR UPDATE").
|
||||
Where("app_id = ? AND seq_type = ?", appID, seqType).
|
||||
First(&seq).Error; err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
// 自动初始化序列
|
||||
prefix := defaultPrefix(seqType)
|
||||
seq = model.OrderSequence{
|
||||
AppID: appID,
|
||||
SeqType: seqType,
|
||||
Prefix: prefix,
|
||||
CurrentValue: 0,
|
||||
Step: 1,
|
||||
}
|
||||
if err := tx.Create(&seq).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
// 重新加锁读
|
||||
return tx.Set("gorm:query_option", "FOR UPDATE").
|
||||
Where("app_id = ? AND seq_type = ?", appID, seqType).
|
||||
First(&seq).Error
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 自增
|
||||
newVal := seq.CurrentValue + uint64(seq.Step)
|
||||
if err := r.db.WithContext(ctx).Model(&model.OrderSequence{}).
|
||||
Where("id = ?", seq.ID).
|
||||
Update("current_value", newVal).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return newVal, nil
|
||||
}
|
||||
|
||||
func defaultPrefix(t model.SeqType) string {
|
||||
switch t {
|
||||
case model.SeqTypeTrade:
|
||||
return "PAY"
|
||||
case model.SeqTypeRefund:
|
||||
return "REF"
|
||||
case model.SeqTypeSharing:
|
||||
return "SHA"
|
||||
default:
|
||||
return "ORD"
|
||||
}
|
||||
}
|
||||
|
||||
// GetPrefix 获取序列前缀
|
||||
func (r *SequenceRepository) GetPrefix(ctx context.Context, appID string, seqType model.SeqType) (string, error) {
|
||||
var seq model.OrderSequence
|
||||
err := r.db.WithContext(ctx).Where("app_id = ? AND seq_type = ?", appID, seqType).First(&seq).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return fmt.Sprintf("%s", defaultPrefix(seqType)), nil
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return seq.Prefix, nil
|
||||
}
|
||||
66
backend/internal/repository/service_fee.go
Normal file
66
backend/internal/repository/service_fee.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
// ServiceFeeRepository 服务费数据访问
|
||||
type ServiceFeeRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewServiceFeeRepository(db *gorm.DB) *ServiceFeeRepository {
|
||||
return &ServiceFeeRepository{db: db}
|
||||
}
|
||||
|
||||
// GetConfig 按 app_id + 支付方式分组查询配置
|
||||
func (r *ServiceFeeRepository) GetConfig(ctx context.Context, appID string, group model.PayMethodGroup) (*model.ServiceFeeConfig, error) {
|
||||
var cfg model.ServiceFeeConfig
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("app_id = ? AND pay_method_group = ? AND status = 1", appID, group).
|
||||
First(&cfg).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return &cfg, err
|
||||
}
|
||||
|
||||
// ListConfigs 查询应用所有服务费配置
|
||||
func (r *ServiceFeeRepository) ListConfigs(ctx context.Context, appID string) ([]*model.ServiceFeeConfig, error) {
|
||||
var cfgs []*model.ServiceFeeConfig
|
||||
err := r.db.WithContext(ctx).Where("app_id = ? AND status = 1", appID).Find(&cfgs).Error
|
||||
return cfgs, err
|
||||
}
|
||||
|
||||
// SaveConfig 保存配置(创建或更新)
|
||||
func (r *ServiceFeeRepository) SaveConfig(ctx context.Context, cfg *model.ServiceFeeConfig) error {
|
||||
return r.db.WithContext(ctx).Save(cfg).Error
|
||||
}
|
||||
|
||||
// CreateLog 创建服务费流水
|
||||
func (r *ServiceFeeRepository) CreateLog(ctx context.Context, log *model.ServiceFeeLog) error {
|
||||
return r.db.WithContext(ctx).Create(log).Error
|
||||
}
|
||||
|
||||
// GetLog 按 trade_no + action 查询流水
|
||||
func (r *ServiceFeeRepository) GetLog(ctx context.Context, tradeNo, action string) (*model.ServiceFeeLog, error) {
|
||||
var log model.ServiceFeeLog
|
||||
err := r.db.WithContext(ctx).Where("trade_no = ? AND action = ?", tradeNo, action).First(&log).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return &log, err
|
||||
}
|
||||
|
||||
// UpdateLogStatus 更新流水状态
|
||||
func (r *ServiceFeeRepository) UpdateLogStatus(ctx context.Context, id uint64, status, channelSharingNo string) error {
|
||||
updates := map[string]any{"status": status}
|
||||
if channelSharingNo != "" {
|
||||
updates["channel_sharing_no"] = channelSharingNo
|
||||
}
|
||||
return r.db.WithContext(ctx).Model(&model.ServiceFeeLog{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
80
backend/internal/repository/trade_order.go
Normal file
80
backend/internal/repository/trade_order.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
// TradeOrderRepository 交易订单数据访问
|
||||
type TradeOrderRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewTradeOrderRepository(db *gorm.DB) *TradeOrderRepository {
|
||||
return &TradeOrderRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建订单
|
||||
func (r *TradeOrderRepository) Create(ctx context.Context, order *model.TradeOrder) error {
|
||||
return r.db.WithContext(ctx).Create(order).Error
|
||||
}
|
||||
|
||||
// GetByTradeNo 按 trade_no 查询
|
||||
func (r *TradeOrderRepository) GetByTradeNo(ctx context.Context, tradeNo string) (*model.TradeOrder, error) {
|
||||
var order model.TradeOrder
|
||||
err := r.db.WithContext(ctx).Where("trade_no = ?", tradeNo).First(&order).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return &order, err
|
||||
}
|
||||
|
||||
// GetByMerchantOrderNo 按 app_id + merchant_order_no 查询
|
||||
func (r *TradeOrderRepository) GetByMerchantOrderNo(ctx context.Context, appID, merchantOrderNo string) (*model.TradeOrder, error) {
|
||||
var order model.TradeOrder
|
||||
err := r.db.WithContext(ctx).Where("app_id = ? AND merchant_order_no = ?", appID, merchantOrderNo).First(&order).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return &order, err
|
||||
}
|
||||
|
||||
// GetByChannelTradeNo 按渠道交易号查询
|
||||
func (r *TradeOrderRepository) GetByChannelTradeNo(ctx context.Context, channelTradeNo string) (*model.TradeOrder, error) {
|
||||
var order model.TradeOrder
|
||||
err := r.db.WithContext(ctx).Where("channel_trade_no = ?", channelTradeNo).First(&order).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return &order, err
|
||||
}
|
||||
|
||||
// UpdateStatus 乐观锁更新状态(只允许从 fromStatus 流转到 toStatus)
|
||||
// 返回 bool 表示是否更新成功(false = 已被其他 goroutine 更新)
|
||||
func (r *TradeOrderRepository) UpdateStatus(ctx context.Context, tradeNo string, fromStatus, toStatus model.TradeStatus, updates map[string]any) (bool, error) {
|
||||
if updates == nil {
|
||||
updates = make(map[string]any)
|
||||
}
|
||||
updates["status"] = toStatus
|
||||
|
||||
result := r.db.WithContext(ctx).Model(&model.TradeOrder{}).
|
||||
Where("trade_no = ? AND status = ?", tradeNo, fromStatus).
|
||||
Updates(updates)
|
||||
if result.Error != nil {
|
||||
return false, result.Error
|
||||
}
|
||||
return result.RowsAffected > 0, nil
|
||||
}
|
||||
|
||||
// ListPayingExpired 查询已过期的 PAYING 订单(用于定时关单补偿)
|
||||
func (r *TradeOrderRepository) ListPayingExpired(ctx context.Context, before time.Time, limit int) ([]*model.TradeOrder, error) {
|
||||
var orders []*model.TradeOrder
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ? AND expire_time < ?", model.TradeStatusPaying, before).
|
||||
Limit(limit).Find(&orders).Error
|
||||
return orders, err
|
||||
}
|
||||
43
backend/internal/repository/wechat.go
Normal file
43
backend/internal/repository/wechat.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
// WechatRepository 微信通知数据访问
|
||||
type WechatRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewWechatRepository(db *gorm.DB) *WechatRepository {
|
||||
return &WechatRepository{db: db}
|
||||
}
|
||||
|
||||
// GetBinding 查询应用微信绑定配置
|
||||
func (r *WechatRepository) GetBinding(ctx context.Context, appID string) (*model.WechatBinding, error) {
|
||||
var b model.WechatBinding
|
||||
err := r.db.WithContext(ctx).Where("app_id = ? AND status = 1", appID).First(&b).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return &b, err
|
||||
}
|
||||
|
||||
// UpsertBinding 创建或更新绑定
|
||||
func (r *WechatRepository) UpsertBinding(ctx context.Context, b *model.WechatBinding) error {
|
||||
return r.db.WithContext(ctx).Save(b).Error
|
||||
}
|
||||
|
||||
// CreateMessageLog 记录消息发送日志
|
||||
func (r *WechatRepository) CreateMessageLog(ctx context.Context, log *model.WechatMessageLog) error {
|
||||
return r.db.WithContext(ctx).Create(log).Error
|
||||
}
|
||||
|
||||
// UpdateMessageLog 更新消息日志状态
|
||||
func (r *WechatRepository) UpdateMessageLog(ctx context.Context, id uint64, updates map[string]any) error {
|
||||
return r.db.WithContext(ctx).Model(&model.WechatMessageLog{}).Where("id = ?", id).Updates(updates).Error
|
||||
}
|
||||
74
backend/internal/service/admin_auth.go
Normal file
74
backend/internal/service/admin_auth.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"pay-bridge/internal/repository"
|
||||
)
|
||||
|
||||
type AdminAuthService struct {
|
||||
repo *repository.AdminUserRepository
|
||||
jwtSecret []byte
|
||||
expireHrs int
|
||||
}
|
||||
|
||||
func NewAdminAuthService(repo *repository.AdminUserRepository, jwtSecret string, expireHours int) *AdminAuthService {
|
||||
return &AdminAuthService{
|
||||
repo: repo,
|
||||
jwtSecret: []byte(jwtSecret),
|
||||
expireHrs: expireHours,
|
||||
}
|
||||
}
|
||||
|
||||
// Login 验证用户名密码,成功返回 JWT token
|
||||
func (s *AdminAuthService) Login(ctx context.Context, username, password string) (string, error) {
|
||||
user, err := s.repo.GetByUsername(ctx, username)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", errors.New("用户名或密码错误")
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil {
|
||||
return "", errors.New("用户名或密码错误")
|
||||
}
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"username": user.Username,
|
||||
"exp": time.Now().Add(time.Duration(s.expireHrs) * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(s.jwtSecret)
|
||||
}
|
||||
|
||||
// ParseToken 验证并解析 JWT,返回用户名
|
||||
func (s *AdminAuthService) ParseToken(tokenStr string) (string, error) {
|
||||
token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, errors.New("invalid signing method")
|
||||
}
|
||||
return s.jwtSecret, nil
|
||||
}, jwt.WithValidMethods([]string{"HS256"}))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
return "", errors.New("invalid token")
|
||||
}
|
||||
|
||||
username, ok := claims["username"].(string)
|
||||
if !ok {
|
||||
return "", errors.New("invalid token claims")
|
||||
}
|
||||
return username, nil
|
||||
}
|
||||
141
backend/internal/service/app.go
Normal file
141
backend/internal/service/app.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"pay-bridge/internal/errcode"
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/repository"
|
||||
"pay-bridge/pkg/crypto"
|
||||
)
|
||||
|
||||
// AppService 应用服务
|
||||
type AppService struct {
|
||||
repo *repository.AppRepository
|
||||
encKey string
|
||||
}
|
||||
|
||||
func NewAppService(repo *repository.AppRepository, encKey string) *AppService {
|
||||
return &AppService{repo: repo, encKey: encKey}
|
||||
}
|
||||
|
||||
// GetAppSecret 获取 appSecret(用于鉴权中间件)
|
||||
func (s *AppService) GetAppSecret(ctx context.Context, appID string) (string, error) {
|
||||
app, err := s.repo.GetByAppID(ctx, appID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if app == nil {
|
||||
return "", errors.New(errcode.ErrAppNotFound)
|
||||
}
|
||||
|
||||
secret, err := crypto.Decrypt(app.AppSecret, s.encKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt app secret: %w", err)
|
||||
}
|
||||
return secret, nil
|
||||
}
|
||||
|
||||
// GetApp 获取应用信息
|
||||
func (s *AppService) GetApp(ctx context.Context, appID string) (*model.App, error) {
|
||||
return s.repo.GetByAppID(ctx, appID)
|
||||
}
|
||||
|
||||
// CreateAppResult 创建应用的返回,包含明文 secret(仅展示一次)
|
||||
type CreateAppResult struct {
|
||||
App *model.App
|
||||
PlainSecret string
|
||||
}
|
||||
|
||||
// CreateApp 创建应用,自动生成 app_id 和 app_secret
|
||||
func (s *AppService) CreateApp(ctx context.Context, appName string) (*CreateAppResult, error) {
|
||||
appID := generateAppID()
|
||||
plainSecret := generateSecret()
|
||||
|
||||
encSecret, err := crypto.Encrypt(plainSecret, s.encKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
app := &model.App{
|
||||
AppID: appID,
|
||||
AppSecret: encSecret,
|
||||
AppName: appName,
|
||||
Status: 1,
|
||||
}
|
||||
if err := s.repo.Create(ctx, app); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &CreateAppResult{App: app, PlainSecret: plainSecret}, nil
|
||||
}
|
||||
|
||||
// ListApps 分页查询应用列表
|
||||
func (s *AppService) ListApps(ctx context.Context, limit, offset int) ([]*model.App, error) {
|
||||
return s.repo.List(ctx, limit, offset)
|
||||
}
|
||||
|
||||
// DisableApp 禁用应用
|
||||
func (s *AppService) DisableApp(ctx context.Context, appID string) error {
|
||||
app, err := s.repo.GetByAppIDUnscoped(ctx, appID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if app == nil {
|
||||
return errors.New(errcode.ErrAppNotFound)
|
||||
}
|
||||
return s.repo.UpdateStatus(ctx, appID, 0)
|
||||
}
|
||||
|
||||
// EnableApp 启用应用
|
||||
func (s *AppService) EnableApp(ctx context.Context, appID string) error {
|
||||
app, err := s.repo.GetByAppIDUnscoped(ctx, appID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if app == nil {
|
||||
return errors.New(errcode.ErrAppNotFound)
|
||||
}
|
||||
return s.repo.UpdateStatus(ctx, appID, 1)
|
||||
}
|
||||
|
||||
// ResetSecret 重置应用密钥,返回新的明文 secret(仅此一次)
|
||||
func (s *AppService) ResetSecret(ctx context.Context, appID string) (string, error) {
|
||||
app, err := s.repo.GetByAppIDUnscoped(ctx, appID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if app == nil {
|
||||
return "", errors.New(errcode.ErrAppNotFound)
|
||||
}
|
||||
|
||||
plainSecret := generateSecret()
|
||||
encSecret, err := crypto.Encrypt(plainSecret, s.encKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := s.repo.UpdateSecret(ctx, appID, encSecret); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return plainSecret, nil
|
||||
}
|
||||
|
||||
// generateAppID 生成 app_id:app_ + yyMMdd + 8位随机hex
|
||||
func generateAppID() string {
|
||||
b := make([]byte, 4)
|
||||
_, _ = rand.Read(b)
|
||||
date := time.Now().Format("060102")
|
||||
return "app_" + date + hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// generateSecret 生成 32 字节随机 secret(64位hex)
|
||||
func generateSecret() string {
|
||||
b := make([]byte, 32)
|
||||
_, _ = rand.Read(b)
|
||||
return strings.ToUpper(hex.EncodeToString(b))
|
||||
}
|
||||
140
backend/internal/service/channel.go
Normal file
140
backend/internal/service/channel.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"pay-bridge/internal/channel"
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/repository"
|
||||
"pay-bridge/pkg/config"
|
||||
"pay-bridge/pkg/crypto"
|
||||
)
|
||||
|
||||
const channelCacheTTL = 5 * time.Minute
|
||||
|
||||
type cachedChannel struct {
|
||||
ch channel.PaymentChannel
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// ChannelService 渠道服务(负责加载渠道配置并获取渠道实例)
|
||||
type ChannelService struct {
|
||||
repo *repository.ChannelConfigRepository
|
||||
encKey string
|
||||
urlsCfg config.ChannelsConfig
|
||||
mu sync.Mutex
|
||||
cache map[string]*cachedChannel
|
||||
}
|
||||
|
||||
func NewChannelService(repo *repository.ChannelConfigRepository, encKey string, urlsCfg config.ChannelsConfig) *ChannelService {
|
||||
return &ChannelService{
|
||||
repo: repo,
|
||||
encKey: encKey,
|
||||
urlsCfg: urlsCfg,
|
||||
cache: make(map[string]*cachedChannel),
|
||||
}
|
||||
}
|
||||
|
||||
// GetChannel 根据 appID 和渠道码获取渠道适配器实例(5 分钟内存缓存)
|
||||
func (s *ChannelService) GetChannel(ctx context.Context, appID, channelCode string) (channel.PaymentChannel, error) {
|
||||
cacheKey := appID + ":" + channelCode
|
||||
|
||||
s.mu.Lock()
|
||||
if entry, ok := s.cache[cacheKey]; ok && time.Now().Before(entry.expiresAt) {
|
||||
ch := entry.ch
|
||||
s.mu.Unlock()
|
||||
return ch, nil
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
cfg, err := s.repo.GetByAppChannel(ctx, appID, channelCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("channel config not found: app=%s channel=%s", appID, channelCode)
|
||||
}
|
||||
|
||||
decCfg, err := s.decryptConfig(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ch, err := channel.Get(channelCode, decCfg, s.urlsFor(channelCode))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.cache[cacheKey] = &cachedChannel{ch: ch, expiresAt: time.Now().Add(channelCacheTTL)}
|
||||
s.mu.Unlock()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// InvalidateCache 使指定渠道的缓存失效(配置变更时调用)
|
||||
func (s *ChannelService) InvalidateCache(appID, channelCode string) {
|
||||
s.mu.Lock()
|
||||
delete(s.cache, appID+":"+channelCode)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// ListChannelCodes 获取应用下所有渠道码
|
||||
func (s *ChannelService) ListChannelCodes(ctx context.Context, appID string) ([]string, error) {
|
||||
cfgs, err := s.repo.ListByApp(ctx, appID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
codes := make([]string, 0, len(cfgs))
|
||||
for _, c := range cfgs {
|
||||
codes = append(codes, c.ChannelCode)
|
||||
}
|
||||
return codes, nil
|
||||
}
|
||||
|
||||
// GetChannelConfig 获取渠道配置(已解密)
|
||||
func (s *ChannelService) GetChannelConfig(ctx context.Context, appID, channelCode string) (*model.ChannelConfig, error) {
|
||||
cfg, err := s.repo.GetByAppChannel(ctx, appID, channelCode)
|
||||
if err != nil || cfg == nil {
|
||||
return cfg, err
|
||||
}
|
||||
return s.decryptConfig(cfg)
|
||||
}
|
||||
|
||||
// urlsFor 根据渠道码返回对应的网关地址配置
|
||||
func (s *ChannelService) urlsFor(channelCode string) channel.URLs {
|
||||
switch channelCode {
|
||||
case "HEEPAY":
|
||||
return channel.URLs{
|
||||
PayURL: s.urlsCfg.Heepay.PayURL,
|
||||
MerchantURL: s.urlsCfg.Heepay.MerchantURL,
|
||||
}
|
||||
default:
|
||||
return channel.URLs{}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ChannelService) decryptConfig(cfg *model.ChannelConfig) (*model.ChannelConfig, error) {
|
||||
copied := *cfg
|
||||
|
||||
if cfg.APIKey != "" {
|
||||
dec, err := crypto.Decrypt(cfg.APIKey, s.encKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt api_key: %w", err)
|
||||
}
|
||||
copied.APIKey = dec
|
||||
}
|
||||
|
||||
if cfg.PrivateKey != "" {
|
||||
dec, err := crypto.Decrypt(cfg.PrivateKey, s.encKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt private_key: %w", err)
|
||||
}
|
||||
copied.PrivateKey = dec
|
||||
}
|
||||
|
||||
return &copied, nil
|
||||
}
|
||||
268
backend/internal/service/merchant.go
Normal file
268
backend/internal/service/merchant.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"pay-bridge/internal/channel"
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/repository"
|
||||
)
|
||||
|
||||
// merchantRepo 定义 MerchantService 所需的数据访问方法,便于测试时注入 mock
|
||||
type merchantRepo interface {
|
||||
Create(ctx context.Context, m *model.Merchant) error
|
||||
GetByMerchantID(ctx context.Context, merchantID string) (*model.Merchant, error)
|
||||
GetByMerchantIDAndAppID(ctx context.Context, merchantID, appID string) (*model.Merchant, error)
|
||||
UpdateStatus(ctx context.Context, merchantID string, status model.MerchantStatus, updates map[string]any) error
|
||||
List(ctx context.Context, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error)
|
||||
ListByAppID(ctx context.Context, appID string, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error)
|
||||
ListAnomalous(ctx context.Context) ([]*model.Merchant, error)
|
||||
CreateApplication(ctx context.Context, app *model.MerchantApplication) error
|
||||
GetLatestApplication(ctx context.Context, merchantID string) (*model.MerchantApplication, error)
|
||||
GetApprovedApplicationByChannel(ctx context.Context, merchantID, channelCode string) (*model.MerchantApplication, error)
|
||||
UpdateApplication(ctx context.Context, applicationID string, updates map[string]any) error
|
||||
}
|
||||
|
||||
// MerchantService 商户进件与管理服务
|
||||
type MerchantService struct {
|
||||
merchantRepo merchantRepo
|
||||
channelSvc *ChannelService
|
||||
}
|
||||
|
||||
func NewMerchantService(
|
||||
merchantRepo *repository.MerchantRepository,
|
||||
channelSvc *ChannelService,
|
||||
) *MerchantService {
|
||||
return &MerchantService{
|
||||
merchantRepo: merchantRepo,
|
||||
channelSvc: channelSvc,
|
||||
}
|
||||
}
|
||||
|
||||
func genApplicationID() string {
|
||||
b := make([]byte, 16)
|
||||
rand.Read(b)
|
||||
return "APP" + hex.EncodeToString(b)[:16]
|
||||
}
|
||||
|
||||
// Apply 提交商户进件申请
|
||||
// bizContent 为完整的入网申请业务参数(对应 001 文档的 biz_content 结构)
|
||||
func (s *MerchantService) Apply(ctx context.Context, merchantID, channelCode string, bizContent map[string]any) (string, error) {
|
||||
merchant, err := s.merchantRepo.GetByMerchantID(ctx, merchantID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if merchant == nil {
|
||||
return "", errors.New("merchant not found")
|
||||
}
|
||||
if merchant.Status == model.MerchantStatusFrozen {
|
||||
return "", errors.New("merchant is frozen")
|
||||
}
|
||||
|
||||
ch, err := s.channelSvc.GetChannel(ctx, "", channelCode)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp, err := ch.MerchantApply(ctx, &channel.MerchantApplyReq{
|
||||
MerchantID: merchantID,
|
||||
BizContent: bizContent,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
applicationID := genApplicationID()
|
||||
app := &model.MerchantApplication{
|
||||
ApplicationID: applicationID,
|
||||
MerchantID: merchantID,
|
||||
ChannelCode: channelCode,
|
||||
SubmitData: model.JSONMap(bizContent),
|
||||
AuditStatus: model.AuditStatusSubmitting,
|
||||
SubmittedAt: time.Now(),
|
||||
}
|
||||
// 持久化渠道返回的 request_no,用于后续查询/修改
|
||||
if resp.RequestNo != "" {
|
||||
app.SubmitData["_channel_request_no"] = resp.RequestNo
|
||||
}
|
||||
if err := s.merchantRepo.CreateApplication(ctx, app); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
slog.InfoContext(ctx, "merchant application submitted",
|
||||
"merchant_id", merchantID,
|
||||
"application_id", applicationID,
|
||||
"channel_code", channelCode,
|
||||
"channel_request_no", resp.RequestNo,
|
||||
)
|
||||
return applicationID, nil
|
||||
}
|
||||
|
||||
// UploadFile 上传文件到指定渠道,返回渠道 file_id
|
||||
func (s *MerchantService) UploadFile(ctx context.Context, channelCode string, req *channel.UploadFileReq) (string, error) {
|
||||
ch, err := s.channelSvc.GetChannel(ctx, "", channelCode)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := ch.UploadFile(ctx, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return resp.FileID, nil
|
||||
}
|
||||
|
||||
// QueryAuditStatus 查询进件审核状态
|
||||
func (s *MerchantService) QueryAuditStatus(ctx context.Context, merchantID string) (*model.MerchantApplication, error) {
|
||||
app, err := s.merchantRepo.GetLatestApplication(ctx, merchantID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if app == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 如果仍在审核中,向渠道查询最新状态
|
||||
if app.AuditStatus == model.AuditStatusSubmitting || app.AuditStatus == model.AuditStatusReviewing {
|
||||
// 从 submit_data 中读取渠道返回的 request_no
|
||||
channelRequestNo, _ := app.SubmitData["_channel_request_no"].(string)
|
||||
if channelRequestNo != "" {
|
||||
ch, err := s.channelSvc.GetChannel(ctx, "", app.ChannelCode)
|
||||
if err == nil {
|
||||
resp, err := ch.QueryMerchantStatus(ctx, channelRequestNo)
|
||||
if err == nil {
|
||||
merchant, _ := s.merchantRepo.GetByMerchantID(ctx, merchantID)
|
||||
s.syncMerchantStatus(ctx, merchantID, app.ApplicationID, merchant, resp)
|
||||
app, _ = s.merchantRepo.GetLatestApplication(ctx, merchantID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return app, nil
|
||||
}
|
||||
|
||||
// syncMerchantStatus 同步渠道返回的审核状态到本地
|
||||
func (s *MerchantService) syncMerchantStatus(ctx context.Context, merchantID, applicationID string,
|
||||
merchant *model.Merchant, resp *channel.MerchantStatusResp) {
|
||||
|
||||
now := time.Now()
|
||||
appUpdates := map[string]any{}
|
||||
|
||||
switch resp.Status {
|
||||
case "APPROVED":
|
||||
appUpdates["audit_status"] = model.AuditStatusApproved
|
||||
appUpdates["audited_at"] = now
|
||||
if resp.ChannelMerchantID != "" {
|
||||
appUpdates["channel_merchant_id"] = resp.ChannelMerchantID
|
||||
}
|
||||
s.merchantRepo.UpdateStatus(ctx, merchantID, model.MerchantStatusActive, nil)
|
||||
|
||||
case "REJECTED":
|
||||
appUpdates["audit_status"] = model.AuditStatusRejected
|
||||
appUpdates["reject_reason"] = resp.RejectReason
|
||||
appUpdates["audited_at"] = now
|
||||
s.merchantRepo.UpdateStatus(ctx, merchantID, model.MerchantStatusRejected, nil)
|
||||
|
||||
case "REVIEWING":
|
||||
appUpdates["audit_status"] = model.AuditStatusReviewing
|
||||
|
||||
case "FROZEN":
|
||||
s.merchantRepo.UpdateStatus(ctx, merchantID, model.MerchantStatusFrozen, nil)
|
||||
}
|
||||
|
||||
if len(appUpdates) > 0 {
|
||||
s.merchantRepo.UpdateApplication(ctx, applicationID, appUpdates)
|
||||
}
|
||||
}
|
||||
|
||||
// GetChannelMerchantID 返回指定商户在指定渠道进件审核通过后的渠道商户ID
|
||||
// 若该商户未在该渠道进件或审核未通过,返回空字符串
|
||||
func (s *MerchantService) GetChannelMerchantID(ctx context.Context, merchantID, channelCode string) (string, error) {
|
||||
app, err := s.merchantRepo.GetApprovedApplicationByChannel(ctx, merchantID, channelCode)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if app == nil {
|
||||
return "", nil
|
||||
}
|
||||
return app.ChannelMerchantID, nil
|
||||
}
|
||||
|
||||
// CreateMerchantForApp 业务侧创建商户,强制绑定 appID
|
||||
func (s *MerchantService) CreateMerchantForApp(ctx context.Context, appID string, m *model.Merchant) error {
|
||||
m.AppID = appID
|
||||
return s.merchantRepo.Create(ctx, m)
|
||||
}
|
||||
|
||||
// GetMerchantForApp 业务侧查询,校验 appID 归属
|
||||
func (s *MerchantService) GetMerchantForApp(ctx context.Context, appID, merchantID string) (*model.Merchant, error) {
|
||||
m, err := s.merchantRepo.GetByMerchantIDAndAppID(ctx, merchantID, appID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if m == nil {
|
||||
return nil, errors.New("30001") // merchant not found
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// ListMerchantsForApp 业务侧列表,只返回该 appID 下的商户
|
||||
func (s *MerchantService) ListMerchantsForApp(ctx context.Context, appID string, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
|
||||
return s.merchantRepo.ListByAppID(ctx, appID, status, limit, offset)
|
||||
}
|
||||
|
||||
// ApplyForApp 业务侧进件,校验 appID 归属后委托 Apply
|
||||
func (s *MerchantService) ApplyForApp(ctx context.Context, appID, merchantID, channelCode string, bizContent map[string]any) (string, error) {
|
||||
if _, err := s.GetMerchantForApp(ctx, appID, merchantID); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return s.Apply(ctx, merchantID, channelCode, bizContent)
|
||||
}
|
||||
|
||||
// QueryAuditStatusForApp 业务侧查审核状态,校验 appID 归属
|
||||
func (s *MerchantService) QueryAuditStatusForApp(ctx context.Context, appID, merchantID string) (*model.MerchantApplication, error) {
|
||||
if _, err := s.GetMerchantForApp(ctx, appID, merchantID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.QueryAuditStatus(ctx, merchantID)
|
||||
}
|
||||
|
||||
// CheckAnomalies 检查状态异常的商户(由 cron 调用)
|
||||
func (s *MerchantService) CheckAnomalies(ctx context.Context) error {
|
||||
merchants, err := s.merchantRepo.ListAnomalous(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
slog.InfoContext(ctx, "anomalous merchants found", "count", len(merchants))
|
||||
// 实际业务中可在此发送告警通知
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateMerchant 创建商户基础信息
|
||||
func (s *MerchantService) CreateMerchant(ctx context.Context, m *model.Merchant) error {
|
||||
return s.merchantRepo.Create(ctx, m)
|
||||
}
|
||||
|
||||
// GetMerchant 查询商户信息
|
||||
func (s *MerchantService) GetMerchant(ctx context.Context, merchantID string) (*model.Merchant, error) {
|
||||
return s.merchantRepo.GetByMerchantID(ctx, merchantID)
|
||||
}
|
||||
|
||||
// ListMerchants 查询商户列表
|
||||
func (s *MerchantService) ListMerchants(ctx context.Context, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
|
||||
return s.merchantRepo.List(ctx, status, limit, offset)
|
||||
}
|
||||
|
||||
// FreezeMerchant 冻结商户
|
||||
func (s *MerchantService) FreezeMerchant(ctx context.Context, merchantID string) error {
|
||||
return s.merchantRepo.UpdateStatus(ctx, merchantID, model.MerchantStatusFrozen, nil)
|
||||
}
|
||||
|
||||
// UnfreezeMerchant 解冻商户
|
||||
func (s *MerchantService) UnfreezeMerchant(ctx context.Context, merchantID string) error {
|
||||
return s.merchantRepo.UpdateStatus(ctx, merchantID, model.MerchantStatusActive, nil)
|
||||
}
|
||||
200
backend/internal/service/merchant_test.go
Normal file
200
backend/internal/service/merchant_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
// mockMerchantRepo 实现 merchantRepo interface
|
||||
type mockMerchantRepo struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockMerchantRepo) Create(ctx context.Context, merchant *model.Merchant) error {
|
||||
return m.Called(ctx, merchant).Error(0)
|
||||
}
|
||||
func (m *mockMerchantRepo) GetByMerchantID(ctx context.Context, merchantID string) (*model.Merchant, error) {
|
||||
args := m.Called(ctx, merchantID)
|
||||
return args.Get(0).(*model.Merchant), args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantRepo) GetByMerchantIDAndAppID(ctx context.Context, merchantID, appID string) (*model.Merchant, error) {
|
||||
args := m.Called(ctx, merchantID, appID)
|
||||
v, _ := args.Get(0).(*model.Merchant)
|
||||
return v, args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantRepo) UpdateStatus(ctx context.Context, merchantID string, status model.MerchantStatus, updates map[string]any) error {
|
||||
return m.Called(ctx, merchantID, status, updates).Error(0)
|
||||
}
|
||||
func (m *mockMerchantRepo) List(ctx context.Context, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
|
||||
args := m.Called(ctx, status, limit, offset)
|
||||
return args.Get(0).([]*model.Merchant), args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantRepo) ListByAppID(ctx context.Context, appID string, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
|
||||
args := m.Called(ctx, appID, status, limit, offset)
|
||||
return args.Get(0).([]*model.Merchant), args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantRepo) ListAnomalous(ctx context.Context) ([]*model.Merchant, error) {
|
||||
args := m.Called(ctx)
|
||||
return args.Get(0).([]*model.Merchant), args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantRepo) CreateApplication(ctx context.Context, app *model.MerchantApplication) error {
|
||||
return m.Called(ctx, app).Error(0)
|
||||
}
|
||||
func (m *mockMerchantRepo) GetLatestApplication(ctx context.Context, merchantID string) (*model.MerchantApplication, error) {
|
||||
args := m.Called(ctx, merchantID)
|
||||
v, _ := args.Get(0).(*model.MerchantApplication)
|
||||
return v, args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantRepo) GetApprovedApplicationByChannel(ctx context.Context, merchantID, channelCode string) (*model.MerchantApplication, error) {
|
||||
args := m.Called(ctx, merchantID, channelCode)
|
||||
v, _ := args.Get(0).(*model.MerchantApplication)
|
||||
return v, args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantRepo) UpdateApplication(ctx context.Context, applicationID string, updates map[string]any) error {
|
||||
return m.Called(ctx, applicationID, updates).Error(0)
|
||||
}
|
||||
|
||||
// newTestMerchantService 创建注入了 mock repo 的 service(channelSvc 为 nil,仅测不涉及渠道的方法)
|
||||
func newTestMerchantService(repo merchantRepo) *MerchantService {
|
||||
return &MerchantService{merchantRepo: repo}
|
||||
}
|
||||
|
||||
var ctx = context.Background()
|
||||
|
||||
// --- GetMerchantForApp ---
|
||||
|
||||
func TestGetMerchantForApp_OK(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
want := &model.Merchant{MerchantID: "m001", AppID: "app1"}
|
||||
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return(want, nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
got, err := svc.GetMerchantForApp(ctx, "app1", "m001")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, want, got)
|
||||
repo.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestGetMerchantForApp_NotFound(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return((*model.Merchant)(nil), nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
_, err := svc.GetMerchantForApp(ctx, "app1", "m001")
|
||||
|
||||
assert.EqualError(t, err, "30001")
|
||||
}
|
||||
|
||||
func TestGetMerchantForApp_WrongAppID(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
// 商户存在但属于 other_app,GetByMerchantIDAndAppID 返回 nil
|
||||
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "evil_app").Return((*model.Merchant)(nil), nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
_, err := svc.GetMerchantForApp(ctx, "evil_app", "m001")
|
||||
|
||||
assert.EqualError(t, err, "30001", "跨 appID 访问应返回 not found,而不是泄露商户信息")
|
||||
}
|
||||
|
||||
func TestGetMerchantForApp_DBError(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return((*model.Merchant)(nil), errors.New("db error"))
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
_, err := svc.GetMerchantForApp(ctx, "app1", "m001")
|
||||
|
||||
assert.EqualError(t, err, "db error")
|
||||
}
|
||||
|
||||
// --- CreateMerchantForApp ---
|
||||
|
||||
func TestCreateMerchantForApp_SetsAppID(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
repo.On("Create", ctx, mock.MatchedBy(func(m *model.Merchant) bool {
|
||||
return m.AppID == "app1" && m.MerchantID == "m001"
|
||||
})).Return(nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
m := &model.Merchant{MerchantID: "m001"}
|
||||
err := svc.CreateMerchantForApp(ctx, "app1", m)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "app1", m.AppID, "AppID 应被强制写入")
|
||||
repo.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// --- ListMerchantsForApp ---
|
||||
|
||||
func TestListMerchantsForApp_OnlyReturnsOwnApp(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
want := []*model.Merchant{{MerchantID: "m001", AppID: "app1"}}
|
||||
repo.On("ListByAppID", ctx, "app1", model.MerchantStatus(""), 20, 0).Return(want, nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
got, err := svc.ListMerchantsForApp(ctx, "app1", "", 20, 0)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, got, 1)
|
||||
repo.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// --- ApplyForApp ---
|
||||
|
||||
func TestApplyForApp_MerchantNotBelongToApp(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return((*model.Merchant)(nil), nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
_, err := svc.ApplyForApp(ctx, "app1", "m001", "HEEPAY", nil)
|
||||
|
||||
assert.EqualError(t, err, "30001", "不属于该 app 的商户不能提交进件")
|
||||
}
|
||||
|
||||
// --- GetChannelMerchantID ---
|
||||
|
||||
func TestGetChannelMerchantID_Approved(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
app := &model.MerchantApplication{
|
||||
ChannelMerchantID: "ch_m_999",
|
||||
}
|
||||
repo.On("GetApprovedApplicationByChannel", ctx, "m001", "HEEPAY").Return(app, nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
id, err := svc.GetChannelMerchantID(ctx, "m001", "HEEPAY")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "ch_m_999", id)
|
||||
}
|
||||
|
||||
func TestGetChannelMerchantID_NotApproved(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
repo.On("GetApprovedApplicationByChannel", ctx, "m001", "ALIPAY").Return((*model.MerchantApplication)(nil), nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
id, err := svc.GetChannelMerchantID(ctx, "m001", "ALIPAY")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, id, "未在该渠道进件时返回空字符串")
|
||||
}
|
||||
|
||||
func TestGetChannelMerchantID_MultiChannel(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
repo.On("GetApprovedApplicationByChannel", ctx, "m001", "HEEPAY").
|
||||
Return(&model.MerchantApplication{ChannelMerchantID: "hee_001"}, nil)
|
||||
repo.On("GetApprovedApplicationByChannel", ctx, "m001", "ALIPAY").
|
||||
Return(&model.MerchantApplication{ChannelMerchantID: "ali_001"}, nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
|
||||
heeID, _ := svc.GetChannelMerchantID(ctx, "m001", "HEEPAY")
|
||||
aliID, _ := svc.GetChannelMerchantID(ctx, "m001", "ALIPAY")
|
||||
|
||||
assert.Equal(t, "hee_001", heeID, "不同渠道应返回各自的 channel_merchant_id")
|
||||
assert.Equal(t, "ali_001", aliID)
|
||||
}
|
||||
229
backend/internal/service/notify.go
Normal file
229
backend/internal/service/notify.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/repository"
|
||||
)
|
||||
|
||||
// 重试间隔:9 次推送机会(第1次立即,后续8次重试)
|
||||
var retryIntervals = []time.Duration{
|
||||
0,
|
||||
15 * time.Second,
|
||||
30 * time.Second,
|
||||
1 * time.Minute,
|
||||
5 * time.Minute,
|
||||
30 * time.Minute,
|
||||
1 * time.Hour,
|
||||
6 * time.Hour,
|
||||
12 * time.Hour,
|
||||
}
|
||||
|
||||
const maxRetry = 8
|
||||
|
||||
// NotifyService 通知服务
|
||||
type NotifyService struct {
|
||||
notifyRepo *repository.NotifyLogRepository
|
||||
tradeRepo *repository.TradeOrderRepository
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewNotifyService(
|
||||
notifyRepo *repository.NotifyLogRepository,
|
||||
tradeRepo *repository.TradeOrderRepository,
|
||||
httpTimeout time.Duration,
|
||||
) *NotifyService {
|
||||
return &NotifyService{
|
||||
notifyRepo: notifyRepo,
|
||||
tradeRepo: tradeRepo,
|
||||
httpClient: &http.Client{Timeout: httpTimeout},
|
||||
}
|
||||
}
|
||||
|
||||
// SendNotify 向下游发送通知(首次调用)
|
||||
func (s *NotifyService) SendNotify(ctx context.Context, tradeNo string, notifyType model.NotifyType, notifyURL string) error {
|
||||
// 构建通知内容
|
||||
payload, err := s.buildPayload(ctx, tradeNo, notifyType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建通知记录
|
||||
now := time.Now()
|
||||
log := &model.NotifyLog{
|
||||
TradeNo: tradeNo,
|
||||
NotifyType: notifyType,
|
||||
NotifyURL: notifyURL,
|
||||
Status: model.NotifyStatusPending,
|
||||
RetryCount: 0,
|
||||
}
|
||||
if err := s.notifyRepo.Upsert(ctx, log); err != nil {
|
||||
slog.ErrorContext(ctx, "upsert notify log failed", "trade_no", tradeNo, "err", err)
|
||||
}
|
||||
|
||||
// 发送通知
|
||||
resp, err := s.sendHTTP(ctx, notifyURL, payload)
|
||||
if err == nil && isSuccessResponse(resp) {
|
||||
s.notifyRepo.MarkSuccess(ctx, log.ID, resp)
|
||||
slog.InfoContext(ctx, "notify success", "trade_no", tradeNo, "type", notifyType)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 首次失败,写入重试队列
|
||||
errMsg := ""
|
||||
if err != nil {
|
||||
errMsg = err.Error()
|
||||
} else {
|
||||
errMsg = resp
|
||||
}
|
||||
|
||||
nextTime := now.Add(retryIntervals[1])
|
||||
s.notifyRepo.IncrRetryCount(ctx, log.ID, model.NotifyStatusRetry, &nextTime, errMsg)
|
||||
slog.WarnContext(ctx, "notify failed, scheduled retry", "trade_no", tradeNo, "next_retry", nextTime)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProcessRetryQueue 处理重试队列(由 Poller 调用)
|
||||
func (s *NotifyService) ProcessRetryQueue(ctx context.Context, batchSize int) error {
|
||||
logs, err := s.notifyRepo.ListPendingRetry(ctx, time.Now(), batchSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, log := range logs {
|
||||
s.processOne(ctx, log)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *NotifyService) processOne(ctx context.Context, log *model.NotifyLog) {
|
||||
payload, err := s.buildPayload(ctx, log.TradeNo, log.NotifyType)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "build payload failed", "trade_no", log.TradeNo, "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := s.sendHTTP(ctx, log.NotifyURL, payload)
|
||||
if err == nil && isSuccessResponse(resp) {
|
||||
s.notifyRepo.MarkSuccess(ctx, log.ID, resp)
|
||||
slog.InfoContext(ctx, "notify retry success", "trade_no", log.TradeNo, "retry_count", log.RetryCount)
|
||||
return
|
||||
}
|
||||
|
||||
errMsg := ""
|
||||
if err != nil {
|
||||
errMsg = err.Error()
|
||||
} else {
|
||||
errMsg = resp
|
||||
}
|
||||
|
||||
nextRetryIdx := log.RetryCount + 1
|
||||
if nextRetryIdx > maxRetry {
|
||||
s.notifyRepo.MarkGiveup(ctx, log.ID)
|
||||
slog.WarnContext(ctx, "notify giveup after max retries", "trade_no", log.TradeNo)
|
||||
return
|
||||
}
|
||||
|
||||
var nextTime *time.Time
|
||||
if nextRetryIdx < len(retryIntervals) {
|
||||
t := time.Now().Add(retryIntervals[nextRetryIdx])
|
||||
nextTime = &t
|
||||
}
|
||||
|
||||
status := model.NotifyStatusRetry
|
||||
if nextRetryIdx >= maxRetry {
|
||||
status = model.NotifyStatusGiveup
|
||||
}
|
||||
|
||||
s.notifyRepo.IncrRetryCount(ctx, log.ID, status, nextTime, errMsg)
|
||||
}
|
||||
|
||||
// buildPayload 构建通知内容
|
||||
func (s *NotifyService) buildPayload(ctx context.Context, tradeNo string, notifyType model.NotifyType) ([]byte, error) {
|
||||
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
|
||||
if err != nil || order == nil {
|
||||
return nil, fmt.Errorf("order not found: %s", tradeNo)
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"trade_no": order.TradeNo,
|
||||
"merchant_order_no": order.MerchantOrderNo,
|
||||
"app_id": order.AppID,
|
||||
"pay_method": order.PayMethod,
|
||||
"amount": order.Amount,
|
||||
"status": order.Status,
|
||||
"notify_type": notifyType,
|
||||
"timestamp": time.Now().Unix(),
|
||||
}
|
||||
|
||||
if order.ChannelTradeNo != "" {
|
||||
payload["channel_trade_no"] = order.ChannelTradeNo
|
||||
}
|
||||
if order.PayTime != nil {
|
||||
payload["pay_time"] = order.PayTime.Unix()
|
||||
}
|
||||
|
||||
return json.Marshal(payload)
|
||||
}
|
||||
|
||||
// sendHTTP 向下游发送 HTTP POST 通知
|
||||
func (s *NotifyService) sendHTTP(ctx context.Context, notifyURL string, payload []byte) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, notifyURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
|
||||
return string(body), nil
|
||||
}
|
||||
|
||||
// isSuccessResponse 判断下游是否返回成功
|
||||
// 下游返回 HTTP 200 且 body 包含 "success" 则视为成功
|
||||
func isSuccessResponse(body string) bool {
|
||||
return strings.Contains(strings.ToLower(body), "success")
|
||||
}
|
||||
|
||||
// NextRetryTime 计算下次重试时间
|
||||
func NextRetryTime(retryCount int) (time.Time, bool) {
|
||||
idx := retryCount + 1
|
||||
if idx >= len(retryIntervals) {
|
||||
return time.Time{}, false
|
||||
}
|
||||
return time.Now().Add(retryIntervals[idx]), true
|
||||
}
|
||||
|
||||
// StartPoller 启动通知重试 Poller goroutine
|
||||
func (s *NotifyService) StartPoller(ctx context.Context, interval time.Duration, batchSize int) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.ProcessRetryQueue(ctx, batchSize); err != nil {
|
||||
slog.Error("notify poller error", "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
slog.Info("notify poller started", "interval", interval)
|
||||
}
|
||||
280
backend/internal/service/payment_match.go
Normal file
280
backend/internal/service/payment_match.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/repository"
|
||||
)
|
||||
|
||||
// orderNoPatterns 从备注中提取订单号的正则列表(优先级从高到低)
|
||||
var orderNoPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`PAY\d{14}`), // pay-bridge 交易号格式 PAYyyMMddNNNNNNNN
|
||||
regexp.MustCompile(`REF\d{14}`), // 退款单号
|
||||
regexp.MustCompile(`[A-Z0-9]{16,32}`), // 通用订单号格式
|
||||
}
|
||||
|
||||
// IncomingPayment 入账通知数据
|
||||
type IncomingPayment struct {
|
||||
AccountNo string // 收款账号
|
||||
Amount int64 // 入账金额(分)
|
||||
Remark string // 转账备注
|
||||
PayerName string // 付款方名称
|
||||
ChannelBillNo string // 渠道流水号
|
||||
}
|
||||
|
||||
// matchWindow 匹配时间窗口(7天内的待支付订单)
|
||||
const matchWindow = 7 * 24 * time.Hour
|
||||
|
||||
// PaymentMatchService 收款匹配服务
|
||||
type PaymentMatchService struct {
|
||||
matchRepo *repository.PaymentMatchRepository
|
||||
tradeRepo *repository.TradeOrderRepository
|
||||
notifySvc *NotifyService
|
||||
tradeSvc *TradeService
|
||||
}
|
||||
|
||||
func NewPaymentMatchService(
|
||||
matchRepo *repository.PaymentMatchRepository,
|
||||
tradeRepo *repository.TradeOrderRepository,
|
||||
notifySvc *NotifyService,
|
||||
tradeSvc *TradeService,
|
||||
) *PaymentMatchService {
|
||||
return &PaymentMatchService{
|
||||
matchRepo: matchRepo,
|
||||
tradeRepo: tradeRepo,
|
||||
notifySvc: notifySvc,
|
||||
tradeSvc: tradeSvc,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleIncomingPayment 处理入账通知(核心匹配流程)
|
||||
func (s *PaymentMatchService) HandleIncomingPayment(ctx context.Context, incoming *IncomingPayment) error {
|
||||
// 幂等检查
|
||||
if existing, _ := s.matchRepo.GetMatchLogByBillNo(ctx, incoming.ChannelBillNo); existing != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 查询收款账户
|
||||
account, err := s.matchRepo.GetAccountByNo(ctx, incoming.AccountNo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if account == nil {
|
||||
slog.WarnContext(ctx, "incoming payment: account not found", "account_no", incoming.AccountNo)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 执行匹配
|
||||
result := s.match(ctx, incoming, account)
|
||||
|
||||
// 记录匹配结果
|
||||
now := time.Now()
|
||||
log := &model.PaymentMatchLog{
|
||||
AccountID: account.ID,
|
||||
IncomingAmount: incoming.Amount,
|
||||
IncomingRemark: incoming.Remark,
|
||||
PayerName: incoming.PayerName,
|
||||
ChannelBillNo: incoming.ChannelBillNo,
|
||||
MatchStatus: result.status,
|
||||
NameDiff: result.nameDiff,
|
||||
}
|
||||
if result.tradeNo != "" {
|
||||
log.TradeNo = result.tradeNo
|
||||
log.MatchTime = &now
|
||||
}
|
||||
if err := s.matchRepo.CreateMatchLog(ctx, log); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 匹配成功:更新订单状态并通知下游
|
||||
if result.tradeNo != "" {
|
||||
updates := map[string]any{
|
||||
"status": model.TradeStatusPaid,
|
||||
"pay_time": now,
|
||||
}
|
||||
ok, err := s.tradeRepo.UpdateStatus(ctx, result.tradeNo, model.TradeStatusPaying, model.TradeStatusPaid, updates)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
order, _ := s.tradeRepo.GetByTradeNo(ctx, result.tradeNo)
|
||||
if order != nil && s.notifySvc != nil {
|
||||
go func() {
|
||||
bgCtx := context.Background()
|
||||
s.notifySvc.SendNotify(bgCtx, result.tradeNo, model.NotifyTypePayment, order.NotifyURL)
|
||||
}()
|
||||
}
|
||||
}
|
||||
slog.InfoContext(ctx, "payment matched",
|
||||
"trade_no", result.tradeNo,
|
||||
"amount", incoming.Amount,
|
||||
"status", result.status,
|
||||
"name_diff", result.nameDiff,
|
||||
)
|
||||
} else {
|
||||
slog.InfoContext(ctx, "payment pending manual",
|
||||
"channel_bill_no", incoming.ChannelBillNo,
|
||||
"amount", incoming.Amount,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ManualBindOrder 人工关联入账与订单
|
||||
func (s *PaymentMatchService) ManualBindOrder(ctx context.Context, matchID uint64, tradeNo, operator string) error {
|
||||
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
|
||||
if err != nil || order == nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
updates := map[string]any{
|
||||
"trade_no": tradeNo,
|
||||
"match_status": model.MatchStatusMatched,
|
||||
"match_time": now,
|
||||
"operator": operator,
|
||||
}
|
||||
if err := s.matchRepo.UpdateMatchLog(ctx, matchID, updates); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新订单状态
|
||||
s.tradeRepo.UpdateStatus(ctx, tradeNo, model.TradeStatusPaying, model.TradeStatusPaid,
|
||||
map[string]any{"pay_time": now})
|
||||
|
||||
if s.notifySvc != nil {
|
||||
go func() {
|
||||
bgCtx := context.Background()
|
||||
s.notifySvc.SendNotify(bgCtx, tradeNo, model.NotifyTypePayment, order.NotifyURL)
|
||||
}()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListPendingManual 查询待人工确认的收款记录
|
||||
func (s *PaymentMatchService) ListPendingManual(ctx context.Context, appID string, limit, offset int) ([]*model.PaymentMatchLog, error) {
|
||||
return s.matchRepo.ListPendingManual(ctx, appID, limit, offset)
|
||||
}
|
||||
|
||||
// --- 内部匹配逻辑 ---
|
||||
|
||||
type matchResult struct {
|
||||
tradeNo string
|
||||
status model.MatchStatus
|
||||
nameDiff int8
|
||||
}
|
||||
|
||||
func (s *PaymentMatchService) match(ctx context.Context, incoming *IncomingPayment, account *model.SubMerchantAccount) matchResult {
|
||||
// Step 1: 从备注中提取订单号
|
||||
candidates := extractOrderNos(incoming.Remark)
|
||||
|
||||
var matched *model.TradeOrder
|
||||
for _, orderNo := range candidates {
|
||||
// 先按 trade_no 查,再按 merchant_order_no 查
|
||||
order, _ := s.tradeRepo.GetByTradeNo(ctx, orderNo)
|
||||
if order == nil {
|
||||
order, _ = s.tradeRepo.GetByMerchantOrderNo(ctx, account.AppID, orderNo)
|
||||
}
|
||||
if order == nil || order.AppID != account.AppID {
|
||||
continue
|
||||
}
|
||||
if order.Status != model.TradeStatusPaying {
|
||||
continue
|
||||
}
|
||||
// Step 2: 金额精确匹配
|
||||
if order.Amount != incoming.Amount {
|
||||
continue
|
||||
}
|
||||
matched = order
|
||||
break
|
||||
}
|
||||
|
||||
// 备注匹配失败,降级为金额匹配
|
||||
if matched == nil {
|
||||
orders, _ := s.matchRepo.ListPayingByAmount(ctx, account.AppID, incoming.Amount, matchWindow)
|
||||
if len(orders) == 1 {
|
||||
matched = orders[0]
|
||||
} else if len(orders) > 1 {
|
||||
// Step 3: 用付款方名称缩小范围
|
||||
matched = filterByPayerName(orders, incoming.PayerName)
|
||||
if matched == nil {
|
||||
return matchResult{status: model.MatchStatusPendingManual}
|
||||
}
|
||||
} else {
|
||||
return matchResult{status: model.MatchStatusPendingManual}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: 付款方名称一致性检查
|
||||
var nameDiff int8 = 0
|
||||
invoiceName := getInvoiceName(matched)
|
||||
if invoiceName != "" && incoming.PayerName != "" {
|
||||
if !strings.EqualFold(strings.TrimSpace(invoiceName), strings.TrimSpace(incoming.PayerName)) {
|
||||
nameDiff = 1
|
||||
}
|
||||
}
|
||||
|
||||
status := model.MatchStatusMatched
|
||||
if nameDiff == 1 {
|
||||
status = model.MatchStatusNameDiff
|
||||
}
|
||||
|
||||
return matchResult{
|
||||
tradeNo: matched.TradeNo,
|
||||
status: status,
|
||||
nameDiff: nameDiff,
|
||||
}
|
||||
}
|
||||
|
||||
// extractOrderNos 从备注字符串中提取可能的订单号
|
||||
func extractOrderNos(remark string) []string {
|
||||
if remark == "" {
|
||||
return nil
|
||||
}
|
||||
var results []string
|
||||
seen := map[string]bool{}
|
||||
for _, re := range orderNoPatterns {
|
||||
matches := re.FindAllString(remark, -1)
|
||||
for _, m := range matches {
|
||||
if !seen[m] {
|
||||
seen[m] = true
|
||||
results = append(results, m)
|
||||
}
|
||||
}
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
// filterByPayerName 从多个候选订单中,选择 invoice_name 与付款方名称匹配的订单
|
||||
// invoice_name 暂存在 extra 字段中
|
||||
func filterByPayerName(orders []*model.TradeOrder, payerName string) *model.TradeOrder {
|
||||
if payerName == "" {
|
||||
return nil
|
||||
}
|
||||
for _, o := range orders {
|
||||
name := getInvoiceName(o)
|
||||
if name != "" && strings.EqualFold(strings.TrimSpace(name), strings.TrimSpace(payerName)) {
|
||||
return o
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getInvoiceName 从 extra 字段获取开票名称
|
||||
func getInvoiceName(order *model.TradeOrder) string {
|
||||
if order.Extra == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := order.Extra["invoice_name"]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
268
backend/internal/service/profit_sharing.go
Normal file
268
backend/internal/service/profit_sharing.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"pay-bridge/internal/channel"
|
||||
"pay-bridge/internal/errcode"
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/repository"
|
||||
"pay-bridge/pkg/sequence"
|
||||
)
|
||||
|
||||
const (
|
||||
sharingLockPrefix = "lock:sharing:"
|
||||
sharingLockTTL = 30 * time.Second
|
||||
)
|
||||
|
||||
// ProfitSharingService 分润服务
|
||||
type ProfitSharingService struct {
|
||||
sharingRepo *repository.ProfitSharingRepository
|
||||
tradeRepo *repository.TradeOrderRepository
|
||||
channelSvc *ChannelService
|
||||
seqSvc *sequence.Service
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewProfitSharingService(
|
||||
sharingRepo *repository.ProfitSharingRepository,
|
||||
tradeRepo *repository.TradeOrderRepository,
|
||||
channelSvc *ChannelService,
|
||||
seqSvc *sequence.Service,
|
||||
rdb *redis.Client,
|
||||
) *ProfitSharingService {
|
||||
return &ProfitSharingService{
|
||||
sharingRepo: sharingRepo,
|
||||
tradeRepo: tradeRepo,
|
||||
channelSvc: channelSvc,
|
||||
seqSvc: seqSvc,
|
||||
rdb: rdb,
|
||||
}
|
||||
}
|
||||
|
||||
// TriggerSharing 支付成功后触发分润(幂等)
|
||||
func (s *ProfitSharingService) TriggerSharing(ctx context.Context, tradeNo string) error {
|
||||
// 分布式锁防止并发重复触发
|
||||
lockKey := sharingLockPrefix + tradeNo
|
||||
ok, err := s.rdb.SetNX(ctx, lockKey, "1", sharingLockTTL).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("acquire sharing lock: %w", err)
|
||||
}
|
||||
if !ok {
|
||||
return nil // 已有进程在处理
|
||||
}
|
||||
defer s.rdb.Del(ctx, lockKey)
|
||||
|
||||
// 查询交易
|
||||
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
|
||||
if err != nil || order == nil {
|
||||
return errors.New(errcode.ErrOrderNotFound)
|
||||
}
|
||||
if order.ProfitSharingAmount <= 0 {
|
||||
return nil // 无需分润
|
||||
}
|
||||
|
||||
// 幂等检查:是否已有分润记录
|
||||
existing, err := s.sharingRepo.GetOrderByTradeNo(ctx, tradeNo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existing != nil {
|
||||
return nil // 已触发过
|
||||
}
|
||||
|
||||
// 获取应用分润配置
|
||||
cfg, err := s.sharingRepo.GetConfigByAppID(ctx, order.AppID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cfg == nil {
|
||||
return errors.New(errcode.ErrSharingNotConfig)
|
||||
}
|
||||
|
||||
// 校验分润比例
|
||||
maxAmount := int64(float64(order.Amount) * cfg.MaxSharingRatio)
|
||||
if order.ProfitSharingAmount > maxAmount {
|
||||
return errors.New(errcode.ErrSharingAmountExceed)
|
||||
}
|
||||
|
||||
// 生成分润单号
|
||||
sharingNo, err := s.seqSvc.NextSharingNo(ctx, order.AppID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建分润记录
|
||||
sharingOrder := &model.ProfitSharingOrder{
|
||||
SharingNo: sharingNo,
|
||||
TradeNo: tradeNo,
|
||||
AppID: order.AppID,
|
||||
ReceiverMerchantID: cfg.ReceiverMerchantID,
|
||||
SharingAmount: order.ProfitSharingAmount,
|
||||
Status: model.ProfitSharingStatusPending,
|
||||
}
|
||||
if err := s.sharingRepo.CreateOrder(ctx, sharingOrder); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 调用渠道分账
|
||||
ch, err := s.channelSvc.GetChannel(ctx, order.AppID, order.ChannelCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := ch.ProfitSharing(ctx, &channel.ProfitSharingReq{
|
||||
TradeNo: tradeNo,
|
||||
ChannelTradeNo: order.ChannelTradeNo,
|
||||
SharingNo: sharingNo,
|
||||
ReceiverMerchantID: cfg.ReceiverMerchantID,
|
||||
Amount: order.ProfitSharingAmount,
|
||||
})
|
||||
if err != nil {
|
||||
s.sharingRepo.UpdateOrderStatus(ctx, sharingNo,
|
||||
model.ProfitSharingStatusPending,
|
||||
model.ProfitSharingStatusFailed,
|
||||
map[string]any{"fail_reason": err.Error()},
|
||||
)
|
||||
s.sharingRepo.CreateLog(ctx, &model.ProfitSharingLog{
|
||||
SharingNo: sharingNo,
|
||||
Action: "SPLIT",
|
||||
Amount: order.ProfitSharingAmount,
|
||||
Status: "FAILED",
|
||||
})
|
||||
return fmt.Errorf("profit sharing failed: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
s.sharingRepo.UpdateOrderStatus(ctx, sharingNo,
|
||||
model.ProfitSharingStatusPending,
|
||||
model.ProfitSharingStatusProcessing,
|
||||
map[string]any{
|
||||
"channel_sharing_no": resp.ChannelSharingNo,
|
||||
"sharing_time": now,
|
||||
},
|
||||
)
|
||||
s.sharingRepo.CreateLog(ctx, &model.ProfitSharingLog{
|
||||
SharingNo: sharingNo,
|
||||
Action: "SPLIT",
|
||||
Amount: order.ProfitSharingAmount,
|
||||
Status: "PROCESSING",
|
||||
})
|
||||
|
||||
slog.InfoContext(ctx, "profit sharing triggered",
|
||||
"trade_no", tradeNo,
|
||||
"sharing_no", sharingNo,
|
||||
"amount", order.ProfitSharingAmount,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleSharingNotify 处理分账回调(上游分账完成通知)
|
||||
func (s *ProfitSharingService) HandleSharingNotify(ctx context.Context, sharingNo, channelSharingNo string, status model.ProfitSharingStatus) error {
|
||||
now := time.Now()
|
||||
updates := map[string]any{
|
||||
"channel_sharing_no": channelSharingNo,
|
||||
"sharing_time": now,
|
||||
}
|
||||
ok, err := s.sharingRepo.UpdateOrderStatus(ctx, sharingNo,
|
||||
model.ProfitSharingStatusProcessing, status, updates)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return nil // 幂等
|
||||
}
|
||||
logStatus := string(status)
|
||||
s.sharingRepo.CreateLog(ctx, &model.ProfitSharingLog{
|
||||
SharingNo: sharingNo,
|
||||
Action: "SPLIT",
|
||||
Amount: 0,
|
||||
Status: logStatus,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// RollbackSharing 退款前回退分润
|
||||
func (s *ProfitSharingService) RollbackSharing(ctx context.Context, tradeNo string) error {
|
||||
sharingOrder, err := s.sharingRepo.GetOrderByTradeNo(ctx, tradeNo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sharingOrder == nil {
|
||||
return nil // 无分润,直接跳过
|
||||
}
|
||||
if sharingOrder.Status == model.ProfitSharingStatusRollback {
|
||||
return nil // 已回退,幂等
|
||||
}
|
||||
if sharingOrder.Status != model.ProfitSharingStatusSuccess {
|
||||
return fmt.Errorf("sharing not success, cannot rollback, status=%s", sharingOrder.Status)
|
||||
}
|
||||
|
||||
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
|
||||
if err != nil || order == nil {
|
||||
return errors.New(errcode.ErrOrderNotFound)
|
||||
}
|
||||
|
||||
ch, err := s.channelSvc.GetChannel(ctx, order.AppID, order.ChannelCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ch.RollbackProfitSharing(ctx, &channel.RollbackSharingReq{
|
||||
SharingNo: sharingOrder.SharingNo,
|
||||
ChannelSharingNo: sharingOrder.ChannelSharingNo,
|
||||
TradeNo: tradeNo,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("rollback sharing failed: %w", err)
|
||||
}
|
||||
|
||||
s.sharingRepo.UpdateOrderStatus(ctx, sharingOrder.SharingNo,
|
||||
model.ProfitSharingStatusSuccess,
|
||||
model.ProfitSharingStatusRollback,
|
||||
nil,
|
||||
)
|
||||
s.sharingRepo.CreateLog(ctx, &model.ProfitSharingLog{
|
||||
SharingNo: sharingOrder.SharingNo,
|
||||
Action: "ROLLBACK",
|
||||
Amount: sharingOrder.SharingAmount,
|
||||
Status: "SUCCESS",
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// QuerySharing 查询分润状态
|
||||
func (s *ProfitSharingService) QuerySharing(ctx context.Context, sharingNo string) (*model.ProfitSharingOrder, error) {
|
||||
order, err := s.sharingRepo.GetOrderBySharingNo(ctx, sharingNo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if order == nil {
|
||||
return nil, errors.New(errcode.ErrOrderNotFound)
|
||||
}
|
||||
return order, nil
|
||||
}
|
||||
|
||||
// ValidateSharingAmount 下单时校验分润金额是否合法
|
||||
func (s *ProfitSharingService) ValidateSharingAmount(ctx context.Context, appID string, orderAmount, sharingAmount int64) error {
|
||||
if sharingAmount <= 0 {
|
||||
return nil
|
||||
}
|
||||
cfg, err := s.sharingRepo.GetConfigByAppID(ctx, appID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cfg == nil {
|
||||
return errors.New(errcode.ErrSharingNotConfig)
|
||||
}
|
||||
maxAmount := int64(float64(orderAmount) * cfg.MaxSharingRatio)
|
||||
if sharingAmount > maxAmount {
|
||||
return errors.New(errcode.ErrSharingAmountExceed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
221
backend/internal/service/reconciliation.go
Normal file
221
backend/internal/service/reconciliation.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"pay-bridge/internal/channel"
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/repository"
|
||||
)
|
||||
|
||||
// ReconciliationService T+1 自动对账服务
|
||||
type ReconciliationService struct {
|
||||
reconRepo *repository.ReconciliationRepository
|
||||
tradeRepo *repository.TradeOrderRepository
|
||||
channelSvc *ChannelService
|
||||
appRepo *repository.AppRepository
|
||||
}
|
||||
|
||||
func NewReconciliationService(
|
||||
reconRepo *repository.ReconciliationRepository,
|
||||
tradeRepo *repository.TradeOrderRepository,
|
||||
channelSvc *ChannelService,
|
||||
appRepo *repository.AppRepository,
|
||||
) *ReconciliationService {
|
||||
return &ReconciliationService{
|
||||
reconRepo: reconRepo,
|
||||
tradeRepo: tradeRepo,
|
||||
channelSvc: channelSvc,
|
||||
appRepo: appRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// RunDailyReconciliation 执行 T+1 对账(cron 每日触发)
|
||||
func (s *ReconciliationService) RunDailyReconciliation(ctx context.Context) error {
|
||||
// 对账日期:昨天
|
||||
billDate := time.Now().AddDate(0, 0, -1).Format("2006-01-02")
|
||||
slog.InfoContext(ctx, "reconciliation started", "bill_date", billDate)
|
||||
|
||||
apps, err := s.appRepo.ListActive(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, app := range apps {
|
||||
if err := s.reconcileApp(ctx, app.AppID, billDate); err != nil {
|
||||
slog.ErrorContext(ctx, "reconciliation failed for app",
|
||||
"app_id", app.AppID,
|
||||
"bill_date", billDate,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// reconcileApp 对指定应用执行对账
|
||||
func (s *ReconciliationService) reconcileApp(ctx context.Context, appID, billDate string) error {
|
||||
// 获取所有活跃渠道配置
|
||||
channelCodes, err := s.channelSvc.ListChannelCodes(ctx, appID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, code := range channelCodes {
|
||||
if err := s.reconcileChannel(ctx, appID, code, billDate); err != nil {
|
||||
slog.ErrorContext(ctx, "channel reconciliation failed",
|
||||
"app_id", appID,
|
||||
"channel", code,
|
||||
"bill_date", billDate,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// reconcileChannel 对单个渠道执行对账
|
||||
func (s *ReconciliationService) reconcileChannel(ctx context.Context, appID, channelCode, billDate string) error {
|
||||
// 幂等检查
|
||||
existing, err := s.reconRepo.GetReport(ctx, appID, billDate, channelCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existing != nil && existing.Status == model.ReconciliationStatusMatched {
|
||||
return nil // 已对账完成
|
||||
}
|
||||
|
||||
// 创建对账报告
|
||||
report := &model.ReconciliationReport{
|
||||
AppID: appID,
|
||||
ChannelCode: channelCode,
|
||||
BillDate: billDate,
|
||||
Status: model.ReconciliationStatusPending,
|
||||
}
|
||||
if existing == nil {
|
||||
if err := s.reconRepo.CreateReport(ctx, report); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
report = existing
|
||||
}
|
||||
|
||||
// 下载渠道对账单
|
||||
ch, err := s.channelSvc.GetChannel(ctx, appID, channelCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
billData, err := ch.DownloadBill(ctx, &channel.DownloadBillReq{BillDate: billDate})
|
||||
if err != nil {
|
||||
return fmt.Errorf("download bill: %w", err)
|
||||
}
|
||||
|
||||
// 查询本地已支付订单
|
||||
localOrders, err := s.reconRepo.ListPaidOrdersByDate(ctx, appID, billDate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 建立本地订单索引
|
||||
localIndex := make(map[string]*model.TradeOrder, len(localOrders))
|
||||
for _, o := range localOrders {
|
||||
localIndex[o.TradeNo] = o
|
||||
}
|
||||
|
||||
// 建立渠道账单索引
|
||||
channelIndex := make(map[string]*channel.BillRecord, len(billData.Records))
|
||||
for i := range billData.Records {
|
||||
channelIndex[billData.Records[i].TradeNo] = &billData.Records[i]
|
||||
}
|
||||
|
||||
matched := 0
|
||||
exceptions := 0
|
||||
|
||||
// 检查渠道账单中有,本地没有的(漏单)
|
||||
for _, rec := range billData.Records {
|
||||
local, ok := localIndex[rec.TradeNo]
|
||||
if !ok {
|
||||
// 本地缺失
|
||||
ex := &model.ReconciliationException{
|
||||
ReportID: report.ID,
|
||||
TradeNo: rec.TradeNo,
|
||||
ChannelBillNo: rec.ChannelBillNo,
|
||||
ExceptionType: "MISSING_LOCAL",
|
||||
ChannelAmount: rec.Amount,
|
||||
Remark: "渠道有记录,本地无此订单",
|
||||
}
|
||||
s.reconRepo.CreateException(ctx, ex)
|
||||
exceptions++
|
||||
continue
|
||||
}
|
||||
// 金额比对
|
||||
if local.Amount != rec.Amount {
|
||||
ex := &model.ReconciliationException{
|
||||
ReportID: report.ID,
|
||||
TradeNo: rec.TradeNo,
|
||||
ChannelBillNo: rec.ChannelBillNo,
|
||||
ExceptionType: "AMOUNT_MISMATCH",
|
||||
LocalAmount: local.Amount,
|
||||
ChannelAmount: rec.Amount,
|
||||
Remark: fmt.Sprintf("金额不符:本地%d 渠道%d", local.Amount, rec.Amount),
|
||||
}
|
||||
s.reconRepo.CreateException(ctx, ex)
|
||||
exceptions++
|
||||
} else {
|
||||
matched++
|
||||
}
|
||||
}
|
||||
|
||||
// 检查本地有,渠道账单中没有的(多单)
|
||||
for tradeNo, local := range localIndex {
|
||||
if _, ok := channelIndex[tradeNo]; !ok {
|
||||
ex := &model.ReconciliationException{
|
||||
ReportID: report.ID,
|
||||
TradeNo: tradeNo,
|
||||
ExceptionType: "MISSING_CHANNEL",
|
||||
LocalAmount: local.Amount,
|
||||
Remark: "本地已支付,渠道账单无记录",
|
||||
}
|
||||
s.reconRepo.CreateException(ctx, ex)
|
||||
exceptions++
|
||||
}
|
||||
}
|
||||
|
||||
// 更新对账报告
|
||||
status := model.ReconciliationStatusMatched
|
||||
if exceptions > 0 {
|
||||
status = model.ReconciliationStatusException
|
||||
}
|
||||
updates := map[string]any{
|
||||
"total_count": len(billData.Records),
|
||||
"total_amount": billData.TotalAmount,
|
||||
"matched_count": matched,
|
||||
"exception_count": exceptions,
|
||||
"status": status,
|
||||
}
|
||||
if err := s.reconRepo.UpdateReport(ctx, report.ID, updates); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
slog.InfoContext(ctx, "reconciliation done",
|
||||
"app_id", appID,
|
||||
"channel", channelCode,
|
||||
"bill_date", billDate,
|
||||
"matched", matched,
|
||||
"exceptions", exceptions,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetReport 查询对账报告
|
||||
func (s *ReconciliationService) GetReport(ctx context.Context, appID, billDate, channelCode string) (*model.ReconciliationReport, error) {
|
||||
return s.reconRepo.GetReport(ctx, appID, billDate, channelCode)
|
||||
}
|
||||
|
||||
// GetExceptions 查询对账异常明细
|
||||
func (s *ReconciliationService) GetExceptions(ctx context.Context, reportID uint64) ([]*model.ReconciliationException, error) {
|
||||
return s.reconRepo.ListExceptions(ctx, reportID)
|
||||
}
|
||||
213
backend/internal/service/refund.go
Normal file
213
backend/internal/service/refund.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
|
||||
"time"
|
||||
|
||||
"pay-bridge/internal/channel"
|
||||
"pay-bridge/internal/errcode"
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/repository"
|
||||
"pay-bridge/pkg/sequence"
|
||||
)
|
||||
|
||||
// CreateRefundReq 退款请求
|
||||
type CreateRefundReq struct {
|
||||
AppID string
|
||||
TradeNo string
|
||||
RefundAmount int64
|
||||
Reason string
|
||||
NotifyURL string
|
||||
}
|
||||
|
||||
// RefundService 退款服务
|
||||
type RefundService struct {
|
||||
refundRepo *repository.RefundOrderRepository
|
||||
tradeRepo *repository.TradeOrderRepository
|
||||
channelSvc *ChannelService
|
||||
seqSvc *sequence.Service
|
||||
notifySvc *NotifyService
|
||||
}
|
||||
|
||||
func NewRefundService(
|
||||
refundRepo *repository.RefundOrderRepository,
|
||||
tradeRepo *repository.TradeOrderRepository,
|
||||
channelSvc *ChannelService,
|
||||
seqSvc *sequence.Service,
|
||||
notifySvc *NotifyService,
|
||||
) *RefundService {
|
||||
return &RefundService{
|
||||
refundRepo: refundRepo,
|
||||
tradeRepo: tradeRepo,
|
||||
channelSvc: channelSvc,
|
||||
seqSvc: seqSvc,
|
||||
notifySvc: notifySvc,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateRefund 发起退款
|
||||
func (s *RefundService) CreateRefund(ctx context.Context, req *CreateRefundReq) (*model.RefundOrder, error) {
|
||||
// 查询原交易
|
||||
order, err := s.tradeRepo.GetByTradeNo(ctx, req.TradeNo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if order == nil || order.AppID != req.AppID {
|
||||
return nil, errors.New(errcode.ErrOrderNotFound)
|
||||
}
|
||||
if order.Status != model.TradeStatusPaid && order.Status != model.TradeStatusRefunded {
|
||||
return nil, errors.New(errcode.ErrOrderNotPaid)
|
||||
}
|
||||
|
||||
// 校验可退金额
|
||||
refunded, err := s.refundRepo.SumRefundedAmount(ctx, req.TradeNo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if refunded+req.RefundAmount > order.Amount {
|
||||
return nil, errors.New(errcode.ErrRefundAmountExceed)
|
||||
}
|
||||
|
||||
// 生成退款单号
|
||||
refundNo, err := s.seqSvc.NextRefundNo(ctx, req.AppID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建退款记录
|
||||
refund := &model.RefundOrder{
|
||||
RefundNo: refundNo,
|
||||
TradeNo: req.TradeNo,
|
||||
AppID: req.AppID,
|
||||
ChannelCode: order.ChannelCode,
|
||||
RefundAmount: req.RefundAmount,
|
||||
Reason: req.Reason,
|
||||
Status: model.RefundStatusPending,
|
||||
NotifyURL: req.NotifyURL,
|
||||
}
|
||||
if err := s.refundRepo.Create(ctx, refund); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 调用渠道退款
|
||||
ch, err := s.channelSvc.GetChannel(ctx, req.AppID, order.ChannelCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
channelResp, err := ch.Refund(ctx, &channel.RefundReq{
|
||||
TradeNo: req.TradeNo,
|
||||
ChannelTradeNo: order.ChannelTradeNo,
|
||||
RefundNo: refundNo,
|
||||
RefundAmount: req.RefundAmount,
|
||||
TotalAmount: order.Amount,
|
||||
Reason: req.Reason,
|
||||
NotifyURL: req.NotifyURL,
|
||||
})
|
||||
if err != nil {
|
||||
s.refundRepo.UpdateStatus(ctx, refundNo, model.RefundStatusPending, model.RefundStatusFailed, nil)
|
||||
return nil, errors.New(errcode.ErrChannelRefundFail)
|
||||
}
|
||||
|
||||
// 更新渠道退款单号
|
||||
updates := map[string]any{
|
||||
"channel_refund_no": channelResp.ChannelRefundNo,
|
||||
"status": model.RefundStatusProcessing,
|
||||
}
|
||||
s.refundRepo.UpdateStatus(ctx, refundNo, model.RefundStatusPending, model.RefundStatusProcessing, updates)
|
||||
|
||||
refund.ChannelRefundNo = channelResp.ChannelRefundNo
|
||||
refund.Status = model.RefundStatusProcessing
|
||||
return refund, nil
|
||||
}
|
||||
|
||||
// QueryRefund 查询退款状态
|
||||
func (s *RefundService) QueryRefund(ctx context.Context, appID, refundNo string) (*model.RefundOrder, error) {
|
||||
refund, err := s.refundRepo.GetByRefundNo(ctx, refundNo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if refund == nil || refund.AppID != appID {
|
||||
return nil, errors.New(errcode.ErrRefundNotFound)
|
||||
}
|
||||
|
||||
// 如果处于处理中,主动查询渠道
|
||||
if refund.Status == model.RefundStatusProcessing {
|
||||
s.syncRefundStatus(ctx, refund)
|
||||
// 重新查询最新状态
|
||||
refund, _ = s.refundRepo.GetByRefundNo(ctx, refundNo)
|
||||
}
|
||||
|
||||
return refund, nil
|
||||
}
|
||||
|
||||
// HandleRefundNotify 处理退款回调
|
||||
func (s *RefundService) HandleRefundNotify(ctx context.Context, refundNo string, channelRefundNo string, status model.RefundStatus) error {
|
||||
refund, err := s.refundRepo.GetByRefundNo(ctx, refundNo)
|
||||
if err != nil || refund == nil {
|
||||
return errors.New(errcode.ErrRefundNotFound)
|
||||
}
|
||||
|
||||
if refund.Status == model.RefundStatusSuccess {
|
||||
return nil // 幂等
|
||||
}
|
||||
|
||||
updates := map[string]any{
|
||||
"channel_refund_no": channelRefundNo,
|
||||
}
|
||||
if status == model.RefundStatusSuccess {
|
||||
now := time.Now()
|
||||
updates["refund_time"] = now
|
||||
}
|
||||
|
||||
ok, err := s.refundRepo.UpdateStatus(ctx, refundNo, model.RefundStatusProcessing, status, updates)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return nil // 幂等
|
||||
}
|
||||
|
||||
// 退款成功后通知下游
|
||||
if status == model.RefundStatusSuccess && refund.NotifyURL != "" && s.notifySvc != nil {
|
||||
go func() {
|
||||
bgCtx := context.Background()
|
||||
if err := s.notifySvc.SendNotify(bgCtx, refund.TradeNo, model.NotifyTypeRefund, refund.NotifyURL); err != nil {
|
||||
slog.Error("send refund notify failed", "refund_no", refundNo, "err", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RefundService) syncRefundStatus(ctx context.Context, refund *model.RefundOrder) {
|
||||
order, err := s.tradeRepo.GetByTradeNo(ctx, refund.TradeNo)
|
||||
if err != nil || order == nil {
|
||||
return
|
||||
}
|
||||
ch, err := s.channelSvc.GetChannel(ctx, refund.AppID, refund.ChannelCode)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resp, err := ch.QueryRefund(ctx, &channel.QueryRefundReq{
|
||||
RefundNo: refund.RefundNo,
|
||||
ChannelRefundNo: refund.ChannelRefundNo,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if resp.Status != refund.Status {
|
||||
updates := map[string]any{
|
||||
"channel_refund_no": resp.ChannelRefundNo,
|
||||
}
|
||||
if resp.RefundTime != nil {
|
||||
updates["refund_time"] = resp.RefundTime
|
||||
}
|
||||
s.refundRepo.UpdateStatus(ctx, refund.RefundNo, refund.Status, resp.Status, updates)
|
||||
}
|
||||
}
|
||||
|
||||
189
backend/internal/service/service_fee.go
Normal file
189
backend/internal/service/service_fee.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
|
||||
"pay-bridge/internal/channel"
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/repository"
|
||||
)
|
||||
|
||||
// ServiceFeeService 服务费服务
|
||||
type ServiceFeeService struct {
|
||||
feeRepo *repository.ServiceFeeRepository
|
||||
tradeRepo *repository.TradeOrderRepository
|
||||
channelSvc *ChannelService
|
||||
}
|
||||
|
||||
func NewServiceFeeService(
|
||||
feeRepo *repository.ServiceFeeRepository,
|
||||
tradeRepo *repository.TradeOrderRepository,
|
||||
channelSvc *ChannelService,
|
||||
) *ServiceFeeService {
|
||||
return &ServiceFeeService{
|
||||
feeRepo: feeRepo,
|
||||
tradeRepo: tradeRepo,
|
||||
channelSvc: channelSvc,
|
||||
}
|
||||
}
|
||||
|
||||
// ChargeServiceFee 交易完成后扣收服务费
|
||||
func (s *ServiceFeeService) ChargeServiceFee(ctx context.Context, tradeNo string) error {
|
||||
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
|
||||
if err != nil || order == nil {
|
||||
return fmt.Errorf("order not found: %s", tradeNo)
|
||||
}
|
||||
|
||||
// 幂等检查
|
||||
existing, err := s.feeRepo.GetLog(ctx, tradeNo, "CHARGE")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existing != nil {
|
||||
return nil // 已扣收
|
||||
}
|
||||
|
||||
// 获取服务费配置
|
||||
group := model.PayMethodToGroup(order.PayMethod)
|
||||
cfg, err := s.feeRepo.GetConfig(ctx, order.AppID, group)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cfg == nil || cfg.FeeRate == 0 {
|
||||
return nil // 未配置或费率为0
|
||||
}
|
||||
|
||||
// 计算服务费(四舍五入到分)
|
||||
feeAmount := calculateFee(order.Amount, cfg.FeeRate)
|
||||
if feeAmount <= 0 {
|
||||
return nil // 不足1分不扣收
|
||||
}
|
||||
|
||||
// 更新订单服务费金额快照
|
||||
s.tradeRepo.UpdateStatus(ctx, tradeNo, model.TradeStatusPaid, model.TradeStatusPaid,
|
||||
map[string]any{"service_fee_amount": feeAmount})
|
||||
|
||||
// 创建服务费流水
|
||||
log := &model.ServiceFeeLog{
|
||||
TradeNo: tradeNo,
|
||||
ConfigID: cfg.ID,
|
||||
FeeAmount: feeAmount,
|
||||
FeeRate: cfg.FeeRate,
|
||||
ReceiverMerchantID: cfg.FeeReceiverMerchantID,
|
||||
Action: "CHARGE",
|
||||
Status: "PENDING",
|
||||
}
|
||||
if err := s.feeRepo.CreateLog(ctx, log); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 调用渠道分账
|
||||
ch, err := s.channelSvc.GetChannel(ctx, order.AppID, order.ChannelCode)
|
||||
if err != nil {
|
||||
s.feeRepo.UpdateLogStatus(ctx, log.ID, "FAILED", "")
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := ch.ProfitSharing(ctx, &channel.ProfitSharingReq{
|
||||
TradeNo: tradeNo,
|
||||
ChannelTradeNo: order.ChannelTradeNo,
|
||||
SharingNo: fmt.Sprintf("FEE%s", tradeNo),
|
||||
ReceiverMerchantID: cfg.FeeReceiverMerchantID,
|
||||
Amount: feeAmount,
|
||||
})
|
||||
if err != nil {
|
||||
s.feeRepo.UpdateLogStatus(ctx, log.ID, "FAILED", "")
|
||||
slog.WarnContext(ctx, "charge service fee failed", "trade_no", tradeNo, "err", err)
|
||||
return err
|
||||
}
|
||||
|
||||
s.feeRepo.UpdateLogStatus(ctx, log.ID, "SUCCESS", resp.ChannelSharingNo)
|
||||
slog.InfoContext(ctx, "service fee charged", "trade_no", tradeNo, "fee_amount", feeAmount)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RollbackServiceFee 退款时回退服务费
|
||||
func (s *ServiceFeeService) RollbackServiceFee(ctx context.Context, tradeNo string) error {
|
||||
// 幂等检查
|
||||
existing, err := s.feeRepo.GetLog(ctx, tradeNo, "ROLLBACK")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existing != nil {
|
||||
return nil // 已回退
|
||||
}
|
||||
|
||||
chargeLog, err := s.feeRepo.GetLog(ctx, tradeNo, "CHARGE")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if chargeLog == nil || chargeLog.Status != "SUCCESS" {
|
||||
return nil // 没有成功扣收,无需回退
|
||||
}
|
||||
|
||||
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
|
||||
if err != nil || order == nil {
|
||||
return fmt.Errorf("order not found: %s", tradeNo)
|
||||
}
|
||||
|
||||
rollbackLog := &model.ServiceFeeLog{
|
||||
TradeNo: tradeNo,
|
||||
ConfigID: chargeLog.ConfigID,
|
||||
FeeAmount: chargeLog.FeeAmount,
|
||||
FeeRate: chargeLog.FeeRate,
|
||||
ReceiverMerchantID: chargeLog.ReceiverMerchantID,
|
||||
Action: "ROLLBACK",
|
||||
Status: "PENDING",
|
||||
}
|
||||
if err := s.feeRepo.CreateLog(ctx, rollbackLog); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ch, err := s.channelSvc.GetChannel(ctx, order.AppID, order.ChannelCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sharingNo := fmt.Sprintf("FEE%s", tradeNo)
|
||||
if err := ch.RollbackProfitSharing(ctx, &channel.RollbackSharingReq{
|
||||
SharingNo: sharingNo,
|
||||
ChannelSharingNo: chargeLog.ChannelSharingNo,
|
||||
TradeNo: tradeNo,
|
||||
}); err != nil {
|
||||
s.feeRepo.UpdateLogStatus(ctx, rollbackLog.ID, "FAILED", "")
|
||||
return err
|
||||
}
|
||||
|
||||
s.feeRepo.UpdateLogStatus(ctx, rollbackLog.ID, "SUCCESS", "")
|
||||
return nil
|
||||
}
|
||||
|
||||
// CalculateAndValidate 下单时校验分润+服务费不超过订单金额
|
||||
func (s *ServiceFeeService) CalculateAndValidate(ctx context.Context, appID string, payMethod model.PayMethod, orderAmount, sharingAmount int64) (int64, error) {
|
||||
group := model.PayMethodToGroup(payMethod)
|
||||
cfg, err := s.feeRepo.GetConfig(ctx, appID, group)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var feeAmount int64
|
||||
if cfg != nil && cfg.FeeRate > 0 {
|
||||
feeAmount = calculateFee(orderAmount, cfg.FeeRate)
|
||||
}
|
||||
|
||||
if sharingAmount+feeAmount > orderAmount {
|
||||
return 0, fmt.Errorf(errSharingFeeExceed)
|
||||
}
|
||||
return feeAmount, nil
|
||||
}
|
||||
|
||||
const errSharingFeeExceed = "30007" // errcode.ErrSharingFeeExceed
|
||||
|
||||
// calculateFee 计算服务费(四舍五入到分)
|
||||
func calculateFee(amount int64, rate float64) int64 {
|
||||
fee := float64(amount) * rate
|
||||
return int64(math.Round(fee))
|
||||
}
|
||||
371
backend/internal/service/trade.go
Normal file
371
backend/internal/service/trade.go
Normal file
@@ -0,0 +1,371 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"encoding/json"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"pay-bridge/internal/channel"
|
||||
"pay-bridge/internal/errcode"
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/repository"
|
||||
"pay-bridge/pkg/sequence"
|
||||
)
|
||||
|
||||
const (
|
||||
orderExpireDefault = 30 * time.Minute
|
||||
idempotentKeyPrefix = "idempotent:"
|
||||
idempotentTTL = 24 * time.Hour
|
||||
)
|
||||
|
||||
// CreateOrderReq 下单请求
|
||||
type CreateOrderReq struct {
|
||||
AppID string
|
||||
ChannelCode string // 指定渠道,为空时使用 defaultChannelCode
|
||||
MerchantOrderNo string
|
||||
PayMethod model.PayMethod
|
||||
Amount int64
|
||||
ProfitSharingAmount int64
|
||||
Subject string
|
||||
NotifyURL string
|
||||
ExpireMinutes int
|
||||
Extra map[string]any
|
||||
MerchantID string // 可选,指定收款商户(SaaS 多商户路由)
|
||||
}
|
||||
|
||||
// CreateOrderResp 下单响应
|
||||
type CreateOrderResp struct {
|
||||
TradeNo string
|
||||
PayCredential map[string]any
|
||||
IsIdempotent bool // true=幂等返回
|
||||
}
|
||||
|
||||
// TradeService 交易服务
|
||||
type TradeService struct {
|
||||
tradeRepo *repository.TradeOrderRepository
|
||||
channelSvc *ChannelService
|
||||
merchantSvc *MerchantService
|
||||
seqSvc *sequence.Service
|
||||
rdb *redis.Client
|
||||
notifySvc *NotifyService
|
||||
}
|
||||
|
||||
func NewTradeService(
|
||||
tradeRepo *repository.TradeOrderRepository,
|
||||
channelSvc *ChannelService,
|
||||
seqSvc *sequence.Service,
|
||||
rdb *redis.Client,
|
||||
notifySvc *NotifyService,
|
||||
merchantSvc *MerchantService,
|
||||
) *TradeService {
|
||||
return &TradeService{
|
||||
tradeRepo: tradeRepo,
|
||||
channelSvc: channelSvc,
|
||||
merchantSvc: merchantSvc,
|
||||
seqSvc: seqSvc,
|
||||
rdb: rdb,
|
||||
notifySvc: notifySvc,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateOrder 统一下单(含幂等控制)
|
||||
func (s *TradeService) CreateOrder(ctx context.Context, req *CreateOrderReq) (*CreateOrderResp, error) {
|
||||
// 参数校验
|
||||
if req.Amount <= 0 {
|
||||
return nil, errors.New(errcode.ErrInvalidAmount)
|
||||
}
|
||||
|
||||
// 幂等检查 - Redis SET NX
|
||||
idempotentKey := fmt.Sprintf("%s%s:%s", idempotentKeyPrefix, req.AppID, req.MerchantOrderNo)
|
||||
set, err := s.rdb.SetNX(ctx, idempotentKey, "1", idempotentTTL).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
slog.WarnContext(ctx, "redis idempotent check failed, fallback to db", "err", err)
|
||||
}
|
||||
|
||||
if !set {
|
||||
// 幂等命中,查询已有订单
|
||||
order, err := s.tradeRepo.GetByMerchantOrderNo(ctx, req.AppID, req.MerchantOrderNo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if order == nil {
|
||||
// Redis key 存在但 DB 无记录(极端情况),清除 key 重试
|
||||
s.rdb.Del(ctx, idempotentKey)
|
||||
return nil, errors.New(errcode.ErrOrderNotFound)
|
||||
}
|
||||
if order.Status == model.TradeStatusPaid {
|
||||
return nil, errors.New(errcode.ErrOrderAlreadyPaid)
|
||||
}
|
||||
if order.Status == model.TradeStatusClosed {
|
||||
return nil, errors.New(errcode.ErrOrderClosed)
|
||||
}
|
||||
return &CreateOrderResp{
|
||||
TradeNo: order.TradeNo,
|
||||
PayCredential: order.ChannelExtra,
|
||||
IsIdempotent: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 生成交易号
|
||||
tradeNo, err := s.seqSvc.NextTradeNo(ctx, req.AppID)
|
||||
if err != nil {
|
||||
s.rdb.Del(ctx, idempotentKey)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 计算过期时间
|
||||
expireMinutes := req.ExpireMinutes
|
||||
if expireMinutes <= 0 {
|
||||
expireMinutes = int(orderExpireDefault.Minutes())
|
||||
}
|
||||
expireTime := time.Now().Add(time.Duration(expireMinutes) * time.Minute)
|
||||
|
||||
// 确定渠道
|
||||
channelCode := req.ChannelCode
|
||||
if channelCode == "" {
|
||||
channelCode = "HEEPAY" // 向后兼容默认值,建议调用方明确传入
|
||||
}
|
||||
|
||||
// 可选:指定收款商户(SaaS 多商户路由),校验归属并按渠道注入 sub_merchant_id
|
||||
if req.MerchantID != "" && s.merchantSvc != nil {
|
||||
// 校验商户归属(只能使用本 appID 下的商户)
|
||||
if _, err := s.merchantSvc.GetMerchantForApp(ctx, req.AppID, req.MerchantID); err != nil {
|
||||
s.rdb.Del(ctx, idempotentKey)
|
||||
return nil, err
|
||||
}
|
||||
// 按实际下单渠道取对应进件记录的 channel_merchant_id
|
||||
channelMerchantID, err := s.merchantSvc.GetChannelMerchantID(ctx, req.MerchantID, channelCode)
|
||||
if err != nil {
|
||||
s.rdb.Del(ctx, idempotentKey)
|
||||
return nil, err
|
||||
}
|
||||
if channelMerchantID != "" {
|
||||
if req.Extra == nil {
|
||||
req.Extra = make(map[string]any)
|
||||
}
|
||||
req.Extra["sub_merchant_id"] = channelMerchantID
|
||||
}
|
||||
}
|
||||
|
||||
// 创建本地订单记录(CREATING 状态)
|
||||
order := &model.TradeOrder{
|
||||
TradeNo: tradeNo,
|
||||
MerchantOrderNo: req.MerchantOrderNo,
|
||||
AppID: req.AppID,
|
||||
ChannelCode: channelCode,
|
||||
PayMethod: req.PayMethod,
|
||||
Amount: req.Amount,
|
||||
ProfitSharingAmount: req.ProfitSharingAmount,
|
||||
Subject: req.Subject,
|
||||
NotifyURL: req.NotifyURL,
|
||||
Status: model.TradeStatusCreating,
|
||||
Extra: req.Extra,
|
||||
ExpireTime: expireTime,
|
||||
}
|
||||
if err := s.tradeRepo.Create(ctx, order); err != nil {
|
||||
s.rdb.Del(ctx, idempotentKey)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 调用渠道下单
|
||||
ch, err := s.channelSvc.GetChannel(ctx, req.AppID, channelCode)
|
||||
if err != nil {
|
||||
s.tradeRepo.UpdateStatus(ctx, tradeNo, model.TradeStatusCreating, model.TradeStatusCreateFailed, nil)
|
||||
return nil, fmt.Errorf("%s: %w", errcode.ErrChannelCreateFail, err)
|
||||
}
|
||||
|
||||
channelReq := &channel.CreateOrderReq{
|
||||
AppID: req.AppID,
|
||||
TradeNo: tradeNo,
|
||||
MerchantOrderNo: req.MerchantOrderNo,
|
||||
PayMethod: req.PayMethod,
|
||||
Amount: req.Amount,
|
||||
Subject: req.Subject,
|
||||
NotifyURL: req.NotifyURL,
|
||||
ExpireTime: expireTime,
|
||||
Extra: req.Extra,
|
||||
}
|
||||
|
||||
channelResp, err := ch.CreateOrder(ctx, channelReq)
|
||||
if err != nil {
|
||||
s.tradeRepo.UpdateStatus(ctx, tradeNo, model.TradeStatusCreating, model.TradeStatusCreateFailed, nil)
|
||||
return nil, fmt.Errorf("%s: %w", errcode.ErrChannelCreateFail, err)
|
||||
}
|
||||
|
||||
// 更新为 PAYING 状态,保存支付凭证
|
||||
updates := map[string]any{
|
||||
"channel_trade_no": channelResp.ChannelTradeNo,
|
||||
"channel_extra": model.JSONMap(channelResp.PayCredential),
|
||||
}
|
||||
s.tradeRepo.UpdateStatus(ctx, tradeNo, model.TradeStatusCreating, model.TradeStatusPaying, updates)
|
||||
|
||||
return &CreateOrderResp{
|
||||
TradeNo: tradeNo,
|
||||
PayCredential: channelResp.PayCredential,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// QueryOrder 查询交易状态
|
||||
func (s *TradeService) QueryOrder(ctx context.Context, appID, tradeNo string) (*model.TradeOrder, error) {
|
||||
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if order == nil || order.AppID != appID {
|
||||
return nil, errors.New(errcode.ErrOrderNotFound)
|
||||
}
|
||||
|
||||
// 如果处于 PAYING 状态,主动查询渠道同步最新状态
|
||||
if order.Status == model.TradeStatusPaying {
|
||||
s.syncOrderStatus(ctx, order)
|
||||
}
|
||||
|
||||
return order, nil
|
||||
}
|
||||
|
||||
// CloseOrder 关闭订单
|
||||
func (s *TradeService) CloseOrder(ctx context.Context, appID, tradeNo string) error {
|
||||
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if order == nil || order.AppID != appID {
|
||||
return errors.New(errcode.ErrOrderNotFound)
|
||||
}
|
||||
if order.Status == model.TradeStatusPaid {
|
||||
return errors.New(errcode.ErrOrderAlreadyPaid)
|
||||
}
|
||||
if order.Status == model.TradeStatusClosed {
|
||||
return nil // 已关闭,幂等
|
||||
}
|
||||
if order.Status != model.TradeStatusPaying {
|
||||
return errors.New(errcode.ErrOrderClosed)
|
||||
}
|
||||
|
||||
// 调用渠道关单
|
||||
ch, err := s.channelSvc.GetChannel(ctx, appID, order.ChannelCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ch.CloseOrder(ctx, &channel.CloseOrderReq{
|
||||
TradeNo: tradeNo,
|
||||
ChannelTradeNo: order.ChannelTradeNo,
|
||||
}); err != nil {
|
||||
slog.WarnContext(ctx, "close order on channel failed", "trade_no", tradeNo, "err", err)
|
||||
}
|
||||
|
||||
_, err = s.tradeRepo.UpdateStatus(ctx, tradeNo, model.TradeStatusPaying, model.TradeStatusClosed, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
// HandleUpstreamNotify 处理上游支付回调(验签 + 状态更新 + 触发通知下游)
|
||||
//
|
||||
// 流程:先用临时无配置实例从 body 提取 trade_no → 查 DB 得 appID → 加载完整渠道配置验签
|
||||
func (s *TradeService) HandleUpstreamNotify(ctx context.Context, channelCode string, rawBody []byte, headers map[string]string) (string, error) {
|
||||
// 用只负责解析的临时渠道实例提取交易号(不需要密钥配置)
|
||||
tempCh, err := channel.Get(channelCode, nil, channel.URLs{})
|
||||
if err != nil {
|
||||
return "fail", fmt.Errorf("unknown channel: %s", channelCode)
|
||||
}
|
||||
tradeNo, err := tempCh.ExtractTradeNo(rawBody)
|
||||
if err != nil || tradeNo == "" {
|
||||
return "fail", fmt.Errorf("extract trade_no from notify: %w", err)
|
||||
}
|
||||
|
||||
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
|
||||
if err != nil {
|
||||
return "fail", err
|
||||
}
|
||||
if order == nil {
|
||||
slog.WarnContext(ctx, "notify: order not found", "trade_no", tradeNo)
|
||||
return "fail", errors.New(errcode.ErrOrderNotFound)
|
||||
}
|
||||
|
||||
// 加载完整渠道配置并验签
|
||||
ch, err := s.channelSvc.GetChannel(ctx, order.AppID, channelCode)
|
||||
if err != nil {
|
||||
return "fail", err
|
||||
}
|
||||
|
||||
notifyData, err := ch.VerifyNotify(ctx, rawBody, headers)
|
||||
if err != nil {
|
||||
slog.WarnContext(ctx, "notify: verify sign failed", "trade_no", tradeNo, "err", err)
|
||||
return "fail", errors.New(errcode.ErrChannelVerifyFail)
|
||||
}
|
||||
|
||||
// 处理支付通知
|
||||
if notifyData.NotifyType == model.NotifyTypePayment && notifyData.Status == model.TradeStatusPaid {
|
||||
if err := s.handlePaymentSuccess(ctx, order, notifyData); err != nil {
|
||||
return "fail", err
|
||||
}
|
||||
}
|
||||
|
||||
return "success", nil
|
||||
}
|
||||
|
||||
// handlePaymentSuccess 处理支付成功
|
||||
func (s *TradeService) handlePaymentSuccess(ctx context.Context, order *model.TradeOrder, data *channel.NotifyData) error {
|
||||
updates := map[string]any{
|
||||
"channel_trade_no": data.ChannelTradeNo,
|
||||
}
|
||||
if data.PayTime != nil {
|
||||
updates["pay_time"] = data.PayTime
|
||||
}
|
||||
|
||||
ok, err := s.tradeRepo.UpdateStatus(ctx, order.TradeNo, model.TradeStatusPaying, model.TradeStatusPaid, updates)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
// 已被处理过(幂等),直接返回成功
|
||||
slog.InfoContext(ctx, "payment notify idempotent", "trade_no", order.TradeNo)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 异步触发下游通知
|
||||
if s.notifySvc != nil {
|
||||
go func() {
|
||||
bgCtx := context.Background()
|
||||
if err := s.notifySvc.SendNotify(bgCtx, order.TradeNo, model.NotifyTypePayment, order.NotifyURL); err != nil {
|
||||
slog.Error("send notify failed", "trade_no", order.TradeNo, "err", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// syncOrderStatus 主动查询渠道同步订单状态(查询接口兜底)
|
||||
func (s *TradeService) syncOrderStatus(ctx context.Context, order *model.TradeOrder) {
|
||||
ch, err := s.channelSvc.GetChannel(ctx, order.AppID, order.ChannelCode)
|
||||
if err != nil {
|
||||
slog.WarnContext(ctx, "syncOrderStatus: get channel failed", "trade_no", order.TradeNo, "err", err)
|
||||
return
|
||||
}
|
||||
resp, err := ch.QueryOrder(ctx, &channel.QueryOrderReq{
|
||||
TradeNo: order.TradeNo,
|
||||
ChannelTradeNo: order.ChannelTradeNo,
|
||||
})
|
||||
if err != nil {
|
||||
slog.WarnContext(ctx, "syncOrderStatus: query channel failed", "trade_no", order.TradeNo, "err", err)
|
||||
return
|
||||
}
|
||||
if resp.Status == model.TradeStatusPaid {
|
||||
updates := map[string]any{
|
||||
"channel_trade_no": resp.ChannelTradeNo,
|
||||
}
|
||||
if resp.PayTime != nil {
|
||||
updates["pay_time"] = resp.PayTime
|
||||
}
|
||||
s.tradeRepo.UpdateStatus(ctx, order.TradeNo, model.TradeStatusPaying, model.TradeStatusPaid, updates)
|
||||
}
|
||||
}
|
||||
|
||||
func parseJSON(data []byte, v any) error {
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
161
backend/internal/service/wechat.go
Normal file
161
backend/internal/service/wechat.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/repository"
|
||||
"pay-bridge/pkg/crypto"
|
||||
)
|
||||
|
||||
const (
|
||||
wxTokenURL = "https://api.weixin.qq.com/cgi-bin/token"
|
||||
wxSendMsgURL = "https://api.weixin.qq.com/cgi-bin/message/template/send"
|
||||
accessTokenTTL = 90 * time.Minute // 微信 access_token 有效期 2h,提前 30min 刷新
|
||||
)
|
||||
|
||||
// WechatService 微信模板消息服务
|
||||
type WechatService struct {
|
||||
wechatRepo *repository.WechatRepository
|
||||
cryptoKey string
|
||||
httpClient *http.Client
|
||||
// 内存缓存 access_token,避免频繁调用微信接口
|
||||
tokenCache map[string]*tokenEntry
|
||||
}
|
||||
|
||||
type tokenEntry struct {
|
||||
token string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
func NewWechatService(wechatRepo *repository.WechatRepository, cryptoKey string) *WechatService {
|
||||
return &WechatService{
|
||||
wechatRepo: wechatRepo,
|
||||
cryptoKey: cryptoKey,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
tokenCache: make(map[string]*tokenEntry),
|
||||
}
|
||||
}
|
||||
|
||||
// SendPaymentNotify 发送支付成功通知
|
||||
func (s *WechatService) SendPaymentNotify(ctx context.Context, appID, tradeNo, openID string, amount int64) error {
|
||||
binding, err := s.wechatRepo.GetBinding(ctx, appID)
|
||||
if err != nil || binding == nil {
|
||||
return nil // 未配置微信通知,跳过
|
||||
}
|
||||
|
||||
data := map[string]any{
|
||||
"trade_no": map[string]string{"value": tradeNo},
|
||||
"amount": map[string]string{"value": fmt.Sprintf("%.2f 元", float64(amount)/100)},
|
||||
"time": map[string]string{"value": time.Now().Format("2006-01-02 15:04:05")},
|
||||
}
|
||||
|
||||
return s.sendTemplate(ctx, appID, binding, openID, tradeNo, data)
|
||||
}
|
||||
|
||||
// sendTemplate 发送模板消息
|
||||
func (s *WechatService) sendTemplate(ctx context.Context, appID string, binding *model.WechatBinding,
|
||||
openID, tradeNo string, data map[string]any) error {
|
||||
|
||||
log := &model.WechatMessageLog{
|
||||
AppID: appID,
|
||||
TradeNo: tradeNo,
|
||||
OpenID: openID,
|
||||
TemplateID: binding.TemplateID,
|
||||
Status: model.WechatMessageStatusPending,
|
||||
}
|
||||
if err := s.wechatRepo.CreateMessageLog(ctx, log); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
token, err := s.getAccessToken(ctx, binding)
|
||||
if err != nil {
|
||||
updates := map[string]any{"status": model.WechatMessageStatusFailed, "err_msg": err.Error()}
|
||||
s.wechatRepo.UpdateMessageLog(ctx, log.ID, updates)
|
||||
return err
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"touser": openID,
|
||||
"template_id": binding.TemplateID,
|
||||
"data": data,
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
url := fmt.Sprintf("%s?access_token=%s", wxSendMsgURL, token)
|
||||
resp, err := s.httpClient.Post(url, "application/json", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
updates := map[string]any{"status": model.WechatMessageStatusFailed, "err_msg": err.Error()}
|
||||
s.wechatRepo.UpdateMessageLog(ctx, log.ID, updates)
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
var result struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
json.Unmarshal(respBody, &result)
|
||||
|
||||
now := time.Now()
|
||||
if result.ErrCode == 0 {
|
||||
updates := map[string]any{"status": model.WechatMessageStatusSuccess, "sent_at": now}
|
||||
s.wechatRepo.UpdateMessageLog(ctx, log.ID, updates)
|
||||
slog.InfoContext(ctx, "wechat template sent", "trade_no", tradeNo, "open_id", openID)
|
||||
} else {
|
||||
errMsg := fmt.Sprintf("errcode=%d errmsg=%s", result.ErrCode, result.ErrMsg)
|
||||
updates := map[string]any{"status": model.WechatMessageStatusFailed, "err_msg": errMsg}
|
||||
s.wechatRepo.UpdateMessageLog(ctx, log.ID, updates)
|
||||
return fmt.Errorf("wechat send failed: %s", errMsg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAccessToken 获取微信 access_token(带内存缓存)
|
||||
func (s *WechatService) getAccessToken(ctx context.Context, binding *model.WechatBinding) (string, error) {
|
||||
if entry, ok := s.tokenCache[binding.WxAppID]; ok && time.Now().Before(entry.expiresAt) {
|
||||
return entry.token, nil
|
||||
}
|
||||
|
||||
// 解密 secret
|
||||
secret, err := crypto.Decrypt(binding.WxSecret, s.cryptoKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt wx secret: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s?grant_type=client_credential&appid=%s&secret=%s",
|
||||
wxTokenURL, binding.WxAppID, secret)
|
||||
resp, err := s.httpClient.Get(url)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get wx token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
var result struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if result.ErrCode != 0 {
|
||||
return "", fmt.Errorf("wx token error: %d %s", result.ErrCode, result.ErrMsg)
|
||||
}
|
||||
|
||||
s.tokenCache[binding.WxAppID] = &tokenEntry{
|
||||
token: result.AccessToken,
|
||||
expiresAt: time.Now().Add(accessTokenTTL),
|
||||
}
|
||||
return result.AccessToken, nil
|
||||
}
|
||||
Reference in New Issue
Block a user