372 lines
11 KiB
Go
372 lines
11 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"log/slog"
|
||
"time"
|
||
|
||
"encoding/json"
|
||
|
||
"github.com/go-redis/redis/v8"
|
||
"pay-bridge/internal/channel"
|
||
"pay-bridge/internal/errcode"
|
||
"pay-bridge/internal/model"
|
||
"pay-bridge/internal/repository"
|
||
"pay-bridge/pkg/sequence"
|
||
)
|
||
|
||
// Order lifetime and idempotency-key settings used by TradeService.CreateOrder.
const (
	// orderExpireDefault is the order lifetime applied when the caller does not
	// pass a positive ExpireMinutes.
	orderExpireDefault = time.Minute * 30

	// idempotentKeyPrefix prefixes the Redis key "<prefix><appID>:<merchantOrderNo>"
	// that guards CreateOrder against duplicate submissions.
	idempotentKeyPrefix = "idempotent:"

	// idempotentTTL is how long a claimed idempotency key is kept in Redis.
	idempotentTTL = time.Hour * 24
)
|
||
|
||
// CreateOrderReq 下单请求
|
||
type CreateOrderReq struct {
|
||
AppID string
|
||
ChannelCode string // 指定渠道,为空时使用 defaultChannelCode
|
||
MerchantOrderNo string
|
||
PayMethod model.PayMethod
|
||
Amount int64
|
||
ProfitSharingAmount int64
|
||
Subject string
|
||
NotifyURL string
|
||
ExpireMinutes int
|
||
Extra map[string]any
|
||
MerchantID string // 可选,指定收款商户(SaaS 多商户路由)
|
||
}
|
||
|
||
// CreateOrderResp is the result of TradeService.CreateOrder.
type CreateOrderResp struct {
	// TradeNo is the platform trade number assigned to the order.
	TradeNo string
	// PayCredential is the channel-provided data the payer needs to pay.
	PayCredential map[string]any
	// IsIdempotent is true when the response describes an order created by an
	// earlier call with the same AppID and MerchantOrderNo.
	IsIdempotent bool
}
|
||
|
||
// TradeService 交易服务
|
||
type TradeService struct {
|
||
tradeRepo *repository.TradeOrderRepository
|
||
channelSvc *ChannelService
|
||
merchantSvc *MerchantService
|
||
seqSvc *sequence.Service
|
||
rdb *redis.Client
|
||
notifySvc *NotifyService
|
||
}
|
||
|
||
func NewTradeService(
|
||
tradeRepo *repository.TradeOrderRepository,
|
||
channelSvc *ChannelService,
|
||
seqSvc *sequence.Service,
|
||
rdb *redis.Client,
|
||
notifySvc *NotifyService,
|
||
merchantSvc *MerchantService,
|
||
) *TradeService {
|
||
return &TradeService{
|
||
tradeRepo: tradeRepo,
|
||
channelSvc: channelSvc,
|
||
merchantSvc: merchantSvc,
|
||
seqSvc: seqSvc,
|
||
rdb: rdb,
|
||
notifySvc: notifySvc,
|
||
}
|
||
}
|
||
|
||
// CreateOrder 统一下单(含幂等控制)
|
||
func (s *TradeService) CreateOrder(ctx context.Context, req *CreateOrderReq) (*CreateOrderResp, error) {
|
||
// 参数校验
|
||
if req.Amount <= 0 {
|
||
return nil, errors.New(errcode.ErrInvalidAmount)
|
||
}
|
||
|
||
// 幂等检查 - Redis SET NX
|
||
idempotentKey := fmt.Sprintf("%s%s:%s", idempotentKeyPrefix, req.AppID, req.MerchantOrderNo)
|
||
set, err := s.rdb.SetNX(ctx, idempotentKey, "1", idempotentTTL).Result()
|
||
if err != nil && !errors.Is(err, redis.Nil) {
|
||
slog.WarnContext(ctx, "redis idempotent check failed, fallback to db", "err", err)
|
||
}
|
||
|
||
if !set {
|
||
// 幂等命中,查询已有订单
|
||
order, err := s.tradeRepo.GetByMerchantOrderNo(ctx, req.AppID, req.MerchantOrderNo)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if order == nil {
|
||
// Redis key 存在但 DB 无记录(极端情况),清除 key 重试
|
||
s.rdb.Del(ctx, idempotentKey)
|
||
return nil, errors.New(errcode.ErrOrderNotFound)
|
||
}
|
||
if order.Status == model.TradeStatusPaid {
|
||
return nil, errors.New(errcode.ErrOrderAlreadyPaid)
|
||
}
|
||
if order.Status == model.TradeStatusClosed {
|
||
return nil, errors.New(errcode.ErrOrderClosed)
|
||
}
|
||
return &CreateOrderResp{
|
||
TradeNo: order.TradeNo,
|
||
PayCredential: order.ChannelExtra,
|
||
IsIdempotent: true,
|
||
}, nil
|
||
}
|
||
|
||
// 生成交易号
|
||
tradeNo, err := s.seqSvc.NextTradeNo(ctx, req.AppID)
|
||
if err != nil {
|
||
s.rdb.Del(ctx, idempotentKey)
|
||
return nil, err
|
||
}
|
||
|
||
// 计算过期时间
|
||
expireMinutes := req.ExpireMinutes
|
||
if expireMinutes <= 0 {
|
||
expireMinutes = int(orderExpireDefault.Minutes())
|
||
}
|
||
expireTime := time.Now().Add(time.Duration(expireMinutes) * time.Minute)
|
||
|
||
// 确定渠道
|
||
channelCode := req.ChannelCode
|
||
if channelCode == "" {
|
||
channelCode = "HEEPAY" // 向后兼容默认值,建议调用方明确传入
|
||
}
|
||
|
||
// 可选:指定收款商户(SaaS 多商户路由),校验归属并按渠道注入 sub_merchant_id
|
||
if req.MerchantID != "" && s.merchantSvc != nil {
|
||
// 校验商户归属(只能使用本 appID 下的商户)
|
||
if _, err := s.merchantSvc.GetMerchantForApp(ctx, req.AppID, req.MerchantID); err != nil {
|
||
s.rdb.Del(ctx, idempotentKey)
|
||
return nil, err
|
||
}
|
||
// 按实际下单渠道取对应进件记录的 channel_merchant_id
|
||
channelMerchantID, err := s.merchantSvc.GetChannelMerchantID(ctx, req.MerchantID, channelCode)
|
||
if err != nil {
|
||
s.rdb.Del(ctx, idempotentKey)
|
||
return nil, err
|
||
}
|
||
if channelMerchantID != "" {
|
||
if req.Extra == nil {
|
||
req.Extra = make(map[string]any)
|
||
}
|
||
req.Extra["sub_merchant_id"] = channelMerchantID
|
||
}
|
||
}
|
||
|
||
// 创建本地订单记录(CREATING 状态)
|
||
order := &model.TradeOrder{
|
||
TradeNo: tradeNo,
|
||
MerchantOrderNo: req.MerchantOrderNo,
|
||
AppID: req.AppID,
|
||
ChannelCode: channelCode,
|
||
PayMethod: req.PayMethod,
|
||
Amount: req.Amount,
|
||
ProfitSharingAmount: req.ProfitSharingAmount,
|
||
Subject: req.Subject,
|
||
NotifyURL: req.NotifyURL,
|
||
Status: model.TradeStatusCreating,
|
||
Extra: req.Extra,
|
||
ExpireTime: expireTime,
|
||
}
|
||
if err := s.tradeRepo.Create(ctx, order); err != nil {
|
||
s.rdb.Del(ctx, idempotentKey)
|
||
return nil, err
|
||
}
|
||
|
||
// 调用渠道下单
|
||
ch, err := s.channelSvc.GetChannel(ctx, req.AppID, channelCode)
|
||
if err != nil {
|
||
s.tradeRepo.UpdateStatus(ctx, tradeNo, model.TradeStatusCreating, model.TradeStatusCreateFailed, nil)
|
||
return nil, fmt.Errorf("%s: %w", errcode.ErrChannelCreateFail, err)
|
||
}
|
||
|
||
channelReq := &channel.CreateOrderReq{
|
||
AppID: req.AppID,
|
||
TradeNo: tradeNo,
|
||
MerchantOrderNo: req.MerchantOrderNo,
|
||
PayMethod: req.PayMethod,
|
||
Amount: req.Amount,
|
||
Subject: req.Subject,
|
||
NotifyURL: req.NotifyURL,
|
||
ExpireTime: expireTime,
|
||
Extra: req.Extra,
|
||
}
|
||
|
||
channelResp, err := ch.CreateOrder(ctx, channelReq)
|
||
if err != nil {
|
||
s.tradeRepo.UpdateStatus(ctx, tradeNo, model.TradeStatusCreating, model.TradeStatusCreateFailed, nil)
|
||
return nil, fmt.Errorf("%s: %w", errcode.ErrChannelCreateFail, err)
|
||
}
|
||
|
||
// 更新为 PAYING 状态,保存支付凭证
|
||
updates := map[string]any{
|
||
"channel_trade_no": channelResp.ChannelTradeNo,
|
||
"channel_extra": model.JSONMap(channelResp.PayCredential),
|
||
}
|
||
s.tradeRepo.UpdateStatus(ctx, tradeNo, model.TradeStatusCreating, model.TradeStatusPaying, updates)
|
||
|
||
return &CreateOrderResp{
|
||
TradeNo: tradeNo,
|
||
PayCredential: channelResp.PayCredential,
|
||
}, nil
|
||
}
|
||
|
||
// QueryOrder 查询交易状态
|
||
func (s *TradeService) QueryOrder(ctx context.Context, appID, tradeNo string) (*model.TradeOrder, error) {
|
||
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if order == nil || order.AppID != appID {
|
||
return nil, errors.New(errcode.ErrOrderNotFound)
|
||
}
|
||
|
||
// 如果处于 PAYING 状态,主动查询渠道同步最新状态
|
||
if order.Status == model.TradeStatusPaying {
|
||
s.syncOrderStatus(ctx, order)
|
||
}
|
||
|
||
return order, nil
|
||
}
|
||
|
||
// CloseOrder 关闭订单
|
||
func (s *TradeService) CloseOrder(ctx context.Context, appID, tradeNo string) error {
|
||
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if order == nil || order.AppID != appID {
|
||
return errors.New(errcode.ErrOrderNotFound)
|
||
}
|
||
if order.Status == model.TradeStatusPaid {
|
||
return errors.New(errcode.ErrOrderAlreadyPaid)
|
||
}
|
||
if order.Status == model.TradeStatusClosed {
|
||
return nil // 已关闭,幂等
|
||
}
|
||
if order.Status != model.TradeStatusPaying {
|
||
return errors.New(errcode.ErrOrderClosed)
|
||
}
|
||
|
||
// 调用渠道关单
|
||
ch, err := s.channelSvc.GetChannel(ctx, appID, order.ChannelCode)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if err := ch.CloseOrder(ctx, &channel.CloseOrderReq{
|
||
TradeNo: tradeNo,
|
||
ChannelTradeNo: order.ChannelTradeNo,
|
||
}); err != nil {
|
||
slog.WarnContext(ctx, "close order on channel failed", "trade_no", tradeNo, "err", err)
|
||
}
|
||
|
||
_, err = s.tradeRepo.UpdateStatus(ctx, tradeNo, model.TradeStatusPaying, model.TradeStatusClosed, nil)
|
||
return err
|
||
}
|
||
|
||
// HandleUpstreamNotify 处理上游支付回调(验签 + 状态更新 + 触发通知下游)
|
||
//
|
||
// 流程:先用临时无配置实例从 body 提取 trade_no → 查 DB 得 appID → 加载完整渠道配置验签
|
||
func (s *TradeService) HandleUpstreamNotify(ctx context.Context, channelCode string, rawBody []byte, headers map[string]string) (string, error) {
|
||
// 用只负责解析的临时渠道实例提取交易号(不需要密钥配置)
|
||
tempCh, err := channel.Get(channelCode, nil, channel.URLs{})
|
||
if err != nil {
|
||
return "fail", fmt.Errorf("unknown channel: %s", channelCode)
|
||
}
|
||
tradeNo, err := tempCh.ExtractTradeNo(rawBody)
|
||
if err != nil || tradeNo == "" {
|
||
return "fail", fmt.Errorf("extract trade_no from notify: %w", err)
|
||
}
|
||
|
||
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
|
||
if err != nil {
|
||
return "fail", err
|
||
}
|
||
if order == nil {
|
||
slog.WarnContext(ctx, "notify: order not found", "trade_no", tradeNo)
|
||
return "fail", errors.New(errcode.ErrOrderNotFound)
|
||
}
|
||
|
||
// 加载完整渠道配置并验签
|
||
ch, err := s.channelSvc.GetChannel(ctx, order.AppID, channelCode)
|
||
if err != nil {
|
||
return "fail", err
|
||
}
|
||
|
||
notifyData, err := ch.VerifyNotify(ctx, rawBody, headers)
|
||
if err != nil {
|
||
slog.WarnContext(ctx, "notify: verify sign failed", "trade_no", tradeNo, "err", err)
|
||
return "fail", errors.New(errcode.ErrChannelVerifyFail)
|
||
}
|
||
|
||
// 处理支付通知
|
||
if notifyData.NotifyType == model.NotifyTypePayment && notifyData.Status == model.TradeStatusPaid {
|
||
if err := s.handlePaymentSuccess(ctx, order, notifyData); err != nil {
|
||
return "fail", err
|
||
}
|
||
}
|
||
|
||
return "success", nil
|
||
}
|
||
|
||
// handlePaymentSuccess 处理支付成功
|
||
func (s *TradeService) handlePaymentSuccess(ctx context.Context, order *model.TradeOrder, data *channel.NotifyData) error {
|
||
updates := map[string]any{
|
||
"channel_trade_no": data.ChannelTradeNo,
|
||
}
|
||
if data.PayTime != nil {
|
||
updates["pay_time"] = data.PayTime
|
||
}
|
||
|
||
ok, err := s.tradeRepo.UpdateStatus(ctx, order.TradeNo, model.TradeStatusPaying, model.TradeStatusPaid, updates)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if !ok {
|
||
// 已被处理过(幂等),直接返回成功
|
||
slog.InfoContext(ctx, "payment notify idempotent", "trade_no", order.TradeNo)
|
||
return nil
|
||
}
|
||
|
||
// 异步触发下游通知
|
||
if s.notifySvc != nil {
|
||
go func() {
|
||
bgCtx := context.Background()
|
||
if err := s.notifySvc.SendNotify(bgCtx, order.TradeNo, model.NotifyTypePayment, order.NotifyURL); err != nil {
|
||
slog.Error("send notify failed", "trade_no", order.TradeNo, "err", err)
|
||
}
|
||
}()
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// syncOrderStatus 主动查询渠道同步订单状态(查询接口兜底)
|
||
func (s *TradeService) syncOrderStatus(ctx context.Context, order *model.TradeOrder) {
|
||
ch, err := s.channelSvc.GetChannel(ctx, order.AppID, order.ChannelCode)
|
||
if err != nil {
|
||
slog.WarnContext(ctx, "syncOrderStatus: get channel failed", "trade_no", order.TradeNo, "err", err)
|
||
return
|
||
}
|
||
resp, err := ch.QueryOrder(ctx, &channel.QueryOrderReq{
|
||
TradeNo: order.TradeNo,
|
||
ChannelTradeNo: order.ChannelTradeNo,
|
||
})
|
||
if err != nil {
|
||
slog.WarnContext(ctx, "syncOrderStatus: query channel failed", "trade_no", order.TradeNo, "err", err)
|
||
return
|
||
}
|
||
if resp.Status == model.TradeStatusPaid {
|
||
updates := map[string]any{
|
||
"channel_trade_no": resp.ChannelTradeNo,
|
||
}
|
||
if resp.PayTime != nil {
|
||
updates["pay_time"] = resp.PayTime
|
||
}
|
||
s.tradeRepo.UpdateStatus(ctx, order.TradeNo, model.TradeStatusPaying, model.TradeStatusPaid, updates)
|
||
}
|
||
}
|
||
|
||
// parseJSON decodes the JSON document in data into v, which must be a non-nil
// pointer. It is a thin wrapper over json.Unmarshal and returns its error
// unchanged.
func parseJSON(data []byte, v any) error {
	decodeErr := json.Unmarshal(data, v)
	return decodeErr
}
|