Files
2026-03-13 15:51:59 +08:00

243 lines
6.9 KiB
Go

package handler
import (
"errors"
"net/http"
"github.com/gin-gonic/gin"
"pay-bridge/internal/errcode"
"pay-bridge/internal/model"
"pay-bridge/internal/service"
)
// PayHandler groups the payment-related HTTP handlers.
type PayHandler struct {
// tradeSvc backs the order endpoints (create, query, close).
tradeSvc *service.TradeService
// refundSvc backs the refund endpoints (create, query).
refundSvc *service.RefundService
}
// NewPayHandler returns a PayHandler wired to the given trade and refund services.
func NewPayHandler(tradeSvc *service.TradeService, refundSvc *service.RefundService) *PayHandler {
	h := new(PayHandler)
	h.tradeSvc = tradeSvc
	h.refundSvc = refundSvc
	return h
}
// unifiedOrderReq is the JSON body bound by UnifiedOrder. Validation is driven
// by the binding tags; every field is forwarded unchanged to the trade service.
type unifiedOrderReq struct {
// ChannelCode is not validated here.
ChannelCode string `json:"channel_code"`
// MerchantOrderNo is the caller-side order number (required).
MerchantOrderNo string `json:"merchant_order_no" binding:"required"`
// PayMethod is required.
PayMethod model.PayMethod `json:"pay_method" binding:"required"`
// Amount is required and must be at least 1.
// NOTE(review): unit is not visible here (presumably minor currency units) — confirm.
Amount int64 `json:"amount" binding:"required,min=1"`
// ProfitSharingAmount is not validated here.
ProfitSharingAmount int64 `json:"profit_sharing_amount"`
// Subject is required.
Subject string `json:"subject" binding:"required"`
// NotifyURL is required and must pass the "url" validator.
NotifyURL string `json:"notify_url" binding:"required,url"`
// ExpireMinutes is not validated here.
ExpireMinutes int `json:"expire_minutes"`
// Extra carries free-form key/value data.
Extra map[string]any `json:"extra"`
MerchantID string `json:"merchant_id"` // optional; designates the receiving merchant
}
// closeOrderReq is the JSON body bound by CloseOrder.
type closeOrderReq struct {
// TradeNo identifies the order to close (required).
TradeNo string `json:"trade_no" binding:"required"`
}
// refundReq is the JSON body bound by Refund.
type refundReq struct {
// TradeNo identifies the order to refund (required).
TradeNo string `json:"trade_no" binding:"required"`
// RefundAmount is required and must be at least 1.
RefundAmount int64 `json:"refund_amount" binding:"required,min=1"`
// Reason is optional.
Reason string `json:"reason"`
// NotifyURL is optional and, unlike unifiedOrderReq.NotifyURL, is not
// format-checked at bind time.
NotifyURL string `json:"notify_url"`
}
// UnifiedOrder handles POST /api/v1/pay/unified-order.
//
// It binds the JSON body into unifiedOrderReq, forwards it to the trade
// service together with the "app_id" value stored on the gin context, and
// replies with the trade number, the pay credential and the idempotency flag.
// Binding failures produce a bad-request response; service errors go through
// handleBizError.
func (h *PayHandler) UnifiedOrder(c *gin.Context) {
	var body unifiedOrderReq
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		BadRequest(c, errcode.ErrInvalidParam, bindErr.Error())
		return
	}

	orderReq := &service.CreateOrderReq{
		AppID:               c.GetString("app_id"),
		ChannelCode:         body.ChannelCode,
		MerchantOrderNo:     body.MerchantOrderNo,
		PayMethod:           body.PayMethod,
		Amount:              body.Amount,
		ProfitSharingAmount: body.ProfitSharingAmount,
		Subject:             body.Subject,
		NotifyURL:           body.NotifyURL,
		ExpireMinutes:       body.ExpireMinutes,
		Extra:               body.Extra,
		MerchantID:          body.MerchantID,
	}

	created, err := h.tradeSvc.CreateOrder(c.Request.Context(), orderReq)
	if err != nil {
		handleBizError(c, err)
		return
	}

	payload := gin.H{}
	payload["trade_no"] = created.TradeNo
	payload["pay_credential"] = created.PayCredential
	payload["is_idempotent"] = created.IsIdempotent
	OK(c, payload)
}
// QueryOrder handles GET /api/v1/pay/query/:tradeNo.
//
// An empty tradeNo path parameter is rejected as a missing parameter.
// Otherwise the order is looked up for the "app_id" stored on the gin context
// and a summary of it is returned; service errors go through handleBizError.
func (h *PayHandler) QueryOrder(c *gin.Context) {
	tradeNo := c.Param("tradeNo")
	if len(tradeNo) == 0 {
		BadRequest(c, errcode.ErrMissingParam, "trade_no is required")
		return
	}

	found, err := h.tradeSvc.QueryOrder(c.Request.Context(), c.GetString("app_id"), tradeNo)
	if err != nil {
		handleBizError(c, err)
		return
	}

	view := gin.H{}
	view["trade_no"] = found.TradeNo
	view["merchant_order_no"] = found.MerchantOrderNo
	view["pay_method"] = found.PayMethod
	view["amount"] = found.Amount
	view["status"] = found.Status
	view["channel_trade_no"] = found.ChannelTradeNo
	view["pay_time"] = found.PayTime
	view["created_at"] = found.CreatedAt
	OK(c, view)
}
// CloseOrder handles POST /api/v1/pay/close.
//
// It binds the JSON body into closeOrderReq and asks the trade service to
// close that order for the "app_id" stored on the gin context. Success is
// reported with an empty payload; service errors go through handleBizError.
func (h *PayHandler) CloseOrder(c *gin.Context) {
	body := closeOrderReq{}
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		BadRequest(c, errcode.ErrInvalidParam, bindErr.Error())
		return
	}

	closeErr := h.tradeSvc.CloseOrder(c.Request.Context(), c.GetString("app_id"), body.TradeNo)
	if closeErr != nil {
		handleBizError(c, closeErr)
		return
	}

	OK(c, nil)
}
// Refund handles POST /api/v1/pay/refund.
//
// It binds the JSON body into refundReq, forwards it to the refund service
// together with the "app_id" stored on the gin context, and replies with the
// created refund's identifiers, amount and status. Binding failures produce a
// bad-request response; service errors go through handleBizError.
func (h *PayHandler) Refund(c *gin.Context) {
	var body refundReq
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		BadRequest(c, errcode.ErrInvalidParam, bindErr.Error())
		return
	}

	refundArgs := &service.CreateRefundReq{
		AppID:        c.GetString("app_id"),
		TradeNo:      body.TradeNo,
		RefundAmount: body.RefundAmount,
		Reason:       body.Reason,
		NotifyURL:    body.NotifyURL,
	}

	created, err := h.refundSvc.CreateRefund(c.Request.Context(), refundArgs)
	if err != nil {
		handleBizError(c, err)
		return
	}

	payload := gin.H{}
	payload["refund_no"] = created.RefundNo
	payload["trade_no"] = created.TradeNo
	payload["refund_amount"] = created.RefundAmount
	payload["status"] = created.Status
	payload["channel_refund_no"] = created.ChannelRefundNo
	OK(c, payload)
}
// QueryRefund handles GET /api/v1/pay/refund/query/:refundNo.
//
// An empty refundNo path parameter is rejected as a missing parameter.
// Otherwise the refund is looked up for the "app_id" stored on the gin context
// and a summary of it is returned; service errors go through handleBizError.
func (h *PayHandler) QueryRefund(c *gin.Context) {
	refundNo := c.Param("refundNo")
	if len(refundNo) == 0 {
		BadRequest(c, errcode.ErrMissingParam, "refund_no is required")
		return
	}

	found, err := h.refundSvc.QueryRefund(c.Request.Context(), c.GetString("app_id"), refundNo)
	if err != nil {
		handleBizError(c, err)
		return
	}

	view := gin.H{}
	view["refund_no"] = found.RefundNo
	view["trade_no"] = found.TradeNo
	view["refund_amount"] = found.RefundAmount
	view["status"] = found.Status
	view["channel_refund_no"] = found.ChannelRefundNo
	view["refund_time"] = found.RefundTime
	OK(c, view)
}
// handleBizError maps a business error from the service layer to an HTTP response.
//
// The error text is interpreted as an errcode code. Because a service may wrap
// a coded error with additional context (fmt.Errorf("...: %w", err)), the
// Unwrap chain is walked and the first error whose text is a known code wins.
// Errors that carry no known code (and a nil err) are reported as a generic
// internal error; their raw text is never sent to the client.
func handleBizError(c *gin.Context, err error) {
	code := errcode.ErrInternalSystem
	for e := err; e != nil; e = errors.Unwrap(e) {
		// errcode.Message returns "未知错误" for codes it does not know.
		if candidate := e.Error(); errcode.Message(candidate) != "未知错误" {
			code = candidate
			break
		}
	}
	msg := errcode.Message(code)

	switch code {
	case errcode.ErrInvalidParam, errcode.ErrMissingParam, errcode.ErrInvalidPayMethod, errcode.ErrInvalidAmount:
		BadRequest(c, code, msg)
	case errcode.ErrUnauthorized, errcode.ErrAppNotFound:
		Unauthorized(c, code, msg)
	case errcode.ErrPermissionDenied:
		Forbidden(c, code, msg)
	case errcode.ErrOrderNotFound, errcode.ErrOrderAlreadyPaid, errcode.ErrOrderClosed,
		errcode.ErrRefundAmountExceed, errcode.ErrSharingAmountExceed, errcode.ErrOrderNotPaid,
		errcode.ErrSharingNotConfig, errcode.ErrSharingFeeExceed, errcode.ErrRefundNotFound:
		UnprocessableEntity(c, code, msg)
	case errcode.ErrChannelCreateFail, errcode.ErrChannelRefundFail,
		errcode.ErrChannelTimeout, errcode.ErrChannelNotSupport, errcode.ErrChannelVerifyFail:
		BadGateway(c, code, msg)
	default:
		// Known-but-unmapped codes and unknown errors both end up here.
		InternalError(c, errcode.ErrInternalSystem, errcode.Message(errcode.ErrInternalSystem))
	}
}
// NotifyHandler handles callbacks sent by payment channels.
type NotifyHandler struct {
// tradeSvc processes the upstream notification payload.
tradeSvc *service.TradeService
}
// NewNotifyHandler returns a NotifyHandler wired to the given trade service.
func NewNotifyHandler(tradeSvc *service.TradeService) *NotifyHandler {
	h := new(NotifyHandler)
	h.tradeSvc = tradeSvc
	return h
}
// PaymentCallback handles POST /api/v1/notify/payment/:channelCode.
//
// The raw request body is read from the "raw_body" key on the gin context; if
// the key is absent or does not hold a []byte, a nil body is passed on. The
// response status is always 200: the body is "fail" when the trade service
// returns an error, otherwise the string the service produced.
func (h *NotifyHandler) PaymentCallback(c *gin.Context) {
	channelCode := c.Param("channelCode")

	// A missing key yields a nil interface, for which the assertion also
	// produces a nil slice, so both "absent" and "wrong type" collapse to nil.
	stored, _ := c.Get("raw_body")
	rawBody, _ := stored.([]byte)

	reply, err := h.tradeSvc.HandleUpstreamNotify(c.Request.Context(), channelCode, rawBody, headersMap(c))
	if err != nil {
		reply = "fail"
	}
	c.String(http.StatusOK, reply)
}
// headersMap flattens the request headers into a single-valued map, keeping
// only the first value of each header. Headers with no values are omitted.
func headersMap(c *gin.Context) map[string]string {
	headers := c.Request.Header
	flat := map[string]string{}
	for name := range headers {
		values := headers[name]
		if len(values) == 0 {
			continue
		}
		flat[name] = values[0]
	}
	return flat
}