Files
pay-bridge/backend/internal/service/refund.go
2026-03-13 15:51:59 +08:00

214 lines
5.6 KiB
Go

package service
import (
"context"
"errors"
"log/slog"
"time"
"pay-bridge/internal/channel"
"pay-bridge/internal/errcode"
"pay-bridge/internal/model"
"pay-bridge/internal/repository"
"pay-bridge/pkg/sequence"
)
// CreateRefundReq is the input to RefundService.CreateRefund.
type CreateRefundReq struct {
	AppID, TradeNo string // calling application and the trade to refund; the trade must belong to AppID

	RefundAmount int64 // checked against the trade's Amount minus what has already been refunded

	Reason, NotifyURL string // Reason is forwarded to the channel; an empty NotifyURL skips the merchant callback
}
// RefundService creates refunds, tracks their status with the payment
// channels and relays successful refunds to the merchant's NotifyURL.
type RefundService struct {
	// Persistence for refund records and for the trades they refund.
	refundRepo *repository.RefundOrderRepository
	tradeRepo  *repository.TradeOrderRepository

	// channelSvc resolves the channel client for an app / channel code pair.
	channelSvc *ChannelService

	// seqSvc issues refund numbers.
	seqSvc *sequence.Service

	// notifySvc delivers merchant callbacks; a nil value disables them.
	notifySvc *NotifyService
}
// NewRefundService wires a RefundService from its collaborators.
// notifySvc may be nil, in which case merchant refund notifications are skipped.
func NewRefundService(
	refundRepo *repository.RefundOrderRepository,
	tradeRepo *repository.TradeOrderRepository,
	channelSvc *ChannelService,
	seqSvc *sequence.Service,
	notifySvc *NotifyService,
) *RefundService {
	svc := new(RefundService)
	svc.refundRepo = refundRepo
	svc.tradeRepo = tradeRepo
	svc.channelSvc = channelSvc
	svc.seqSvc = seqSvc
	svc.notifySvc = notifySvc
	return svc
}
// CreateRefund starts a refund against a paid trade.
//
// It checks that the trade exists and belongs to req.AppID, and that the trade
// is Paid or Refunded. It also requires req.RefundAmount to be positive and to
// fit in the remaining refundable balance (trade Amount minus
// SumRefundedAmount). It then persists a Pending refund record and submits it
// to the trade's channel.
//
// On channel success the record moves Pending -> Processing and the returned
// refund reflects that; the final outcome is applied later by
// HandleRefundNotify or QueryRefund. On channel failure the record is marked
// Failed and errcode.ErrChannelRefundFail is returned.
func (s *RefundService) CreateRefund(ctx context.Context, req *CreateRefundReq) (*model.RefundOrder, error) {
	order, err := s.tradeRepo.GetByTradeNo(ctx, req.TradeNo)
	if err != nil {
		return nil, err
	}
	if order == nil || order.AppID != req.AppID {
		return nil, errors.New(errcode.ErrOrderNotFound)
	}
	// Trades already in Refunded status may be refunded again; the balance
	// check below decides whether anything is left.
	if order.Status != model.TradeStatusPaid && order.Status != model.TradeStatusRefunded {
		return nil, errors.New(errcode.ErrOrderNotPaid)
	}

	// A non-positive amount would otherwise pass the balance check and be sent
	// to the channel.
	// NOTE(review): no dedicated "invalid amount" errcode is visible here, so
	// the out-of-range code is reused; switch to a more specific one if it exists.
	if req.RefundAmount <= 0 {
		return nil, errors.New(errcode.ErrRefundAmountExceed)
	}
	refunded, err := s.refundRepo.SumRefundedAmount(ctx, req.TradeNo)
	if err != nil {
		return nil, err
	}
	// Compare against the remaining balance instead of adding, so a huge
	// request cannot overflow int64 and slip past the check.
	// NOTE(review): check-then-insert is not atomic; concurrent requests for the
	// same trade can together exceed order.Amount unless the repository
	// serializes them (row lock / constraint). Confirm.
	if req.RefundAmount > order.Amount-refunded {
		return nil, errors.New(errcode.ErrRefundAmountExceed)
	}

	// Resolve the channel before persisting anything so a lookup failure does
	// not leave an orphaned Pending refund record behind.
	ch, err := s.channelSvc.GetChannel(ctx, req.AppID, order.ChannelCode)
	if err != nil {
		return nil, err
	}

	refundNo, err := s.seqSvc.NextRefundNo(ctx, req.AppID)
	if err != nil {
		return nil, err
	}
	refund := &model.RefundOrder{
		RefundNo:     refundNo,
		TradeNo:      req.TradeNo,
		AppID:        req.AppID,
		ChannelCode:  order.ChannelCode,
		RefundAmount: req.RefundAmount,
		Reason:       req.Reason,
		Status:       model.RefundStatusPending,
		NotifyURL:    req.NotifyURL,
	}
	if err := s.refundRepo.Create(ctx, refund); err != nil {
		return nil, err
	}

	channelResp, err := ch.Refund(ctx, &channel.RefundReq{
		TradeNo:        req.TradeNo,
		ChannelTradeNo: order.ChannelTradeNo,
		RefundNo:       refundNo,
		RefundAmount:   req.RefundAmount,
		TotalAmount:    order.Amount,
		Reason:         req.Reason,
		NotifyURL:      req.NotifyURL,
	})

	// The channel has been called, so record the outcome even if ctx has been
	// cancelled or has expired in the meantime.
	persistCtx := context.WithoutCancel(ctx)

	if err != nil {
		slog.Error("channel refund failed", "refund_no", refundNo, "trade_no", req.TradeNo, "err", err)
		// NOTE(review): every channel error, including a timeout after the
		// channel may already have accepted the refund, marks the record Failed.
		// Consider leaving ambiguous errors in Pending/Processing for QueryRefund
		// to reconcile.
		if _, uerr := s.refundRepo.UpdateStatus(persistCtx, refundNo, model.RefundStatusPending, model.RefundStatusFailed, nil); uerr != nil {
			slog.Error("mark refund failed", "refund_no", refundNo, "err", uerr)
		}
		return nil, errors.New(errcode.ErrChannelRefundFail)
	}

	updates := map[string]any{
		"channel_refund_no": channelResp.ChannelRefundNo,
		"status":            model.RefundStatusProcessing,
	}
	// The channel accepted the refund, so a failed bookkeeping update is logged
	// rather than returned; QueryRefund / HandleRefundNotify reconcile the record.
	if ok, uerr := s.refundRepo.UpdateStatus(persistCtx, refundNo, model.RefundStatusPending, model.RefundStatusProcessing, updates); uerr != nil || !ok {
		slog.Warn("mark refund processing failed", "refund_no", refundNo, "updated", ok, "err", uerr)
	}
	refund.ChannelRefundNo = channelResp.ChannelRefundNo
	refund.Status = model.RefundStatusProcessing
	return refund, nil
}
// QueryRefund returns the refund identified by refundNo, provided it belongs
// to appID; otherwise it returns errcode.ErrRefundNotFound.
//
// A refund still in Processing is first synced against its channel so the
// caller sees the freshest status. Failures during that refresh are not
// returned: the record loaded before the refresh is returned instead.
func (s *RefundService) QueryRefund(ctx context.Context, appID, refundNo string) (*model.RefundOrder, error) {
	refund, err := s.refundRepo.GetByRefundNo(ctx, refundNo)
	if err != nil {
		return nil, err
	}
	if refund == nil || refund.AppID != appID {
		return nil, errors.New(errcode.ErrRefundNotFound)
	}
	if refund.Status != model.RefundStatusProcessing {
		return refund, nil
	}

	s.syncRefundStatus(ctx, refund)

	// Re-read to pick up any change written by syncRefundStatus. If the re-read
	// fails, fall back to the record we already have rather than returning
	// (nil, nil).
	latest, err := s.refundRepo.GetByRefundNo(ctx, refundNo)
	if err != nil || latest == nil {
		slog.Warn("reload refund after sync failed", "refund_no", refundNo, "err", err)
		return refund, nil
	}
	return latest, nil
}
// HandleRefundNotify applies a refund result reported asynchronously by a
// channel.
//
// refundNo is the platform refund number, channelRefundNo the channel's own
// refund id, and status the reported outcome. Only a record currently in
// Processing is updated. A refund that is already Success, or whose
// conditional update matches no row, is treated as a duplicate and nil is
// returned. On a transition to Success, a merchant notification is sent in the
// background if the refund has a NotifyURL and notifySvc is configured.
func (s *RefundService) HandleRefundNotify(ctx context.Context, refundNo string, channelRefundNo string, status model.RefundStatus) error {
	refund, err := s.refundRepo.GetByRefundNo(ctx, refundNo)
	if err != nil {
		// Return storage errors as-is so they are not mistaken for a permanent
		// "refund not found" answer.
		return err
	}
	if refund == nil {
		return errors.New(errcode.ErrRefundNotFound)
	}
	if refund.Status == model.RefundStatusSuccess {
		return nil // already applied; duplicate callback
	}

	updates := map[string]any{}
	// Do not overwrite a stored channel refund number with an empty one.
	if channelRefundNo != "" {
		updates["channel_refund_no"] = channelRefundNo
	}
	if status == model.RefundStatusSuccess {
		updates["refund_time"] = time.Now()
	}
	// NOTE(review): only Processing -> status is attempted. A callback that
	// arrives while the record is still Pending (before CreateRefund stored the
	// channel response) is dropped here; only a later QueryRefund sync can
	// reconcile it.
	ok, err := s.refundRepo.UpdateStatus(ctx, refundNo, model.RefundStatusProcessing, status, updates)
	if err != nil {
		return err
	}
	if !ok {
		return nil // lost the conditional update: status changed concurrently
	}

	if status == model.RefundStatusSuccess && refund.NotifyURL != "" && s.notifySvc != nil {
		// A fresh background context keeps the notification independent of
		// ctx's cancellation.
		go func() {
			if err := s.notifySvc.SendNotify(context.Background(), refund.TradeNo, model.NotifyTypeRefund, refund.NotifyURL); err != nil {
				slog.Error("send refund notify failed", "refund_no", refundNo, "err", err)
			}
		}()
	}
	return nil
}
// syncRefundStatus asks the refund's channel for its current status and, when
// it differs from refund.Status, persists the change. The write is conditional
// on the stored status still equalling refund.Status.
//
// It is best-effort: channel lookup, channel query and storage failures are
// logged and otherwise ignored, leaving the record untouched. A transition to
// Success sends the same merchant notification as HandleRefundNotify. Without
// it, a success discovered only by polling would never be reported downstream.
// HandleRefundNotify returns early for refunds already in Success, so a late
// callback does not notify twice.
func (s *RefundService) syncRefundStatus(ctx context.Context, refund *model.RefundOrder) {
	ch, err := s.channelSvc.GetChannel(ctx, refund.AppID, refund.ChannelCode)
	if err != nil {
		slog.Warn("sync refund: get channel failed", "refund_no", refund.RefundNo, "err", err)
		return
	}
	resp, err := ch.QueryRefund(ctx, &channel.QueryRefundReq{
		RefundNo:        refund.RefundNo,
		ChannelRefundNo: refund.ChannelRefundNo,
	})
	if err != nil {
		slog.Warn("sync refund: channel query failed", "refund_no", refund.RefundNo, "err", err)
		return
	}
	if resp.Status == refund.Status {
		return
	}

	updates := map[string]any{}
	// Do not overwrite a stored channel refund number with an empty one.
	if resp.ChannelRefundNo != "" {
		updates["channel_refund_no"] = resp.ChannelRefundNo
	}
	switch {
	case resp.RefundTime != nil:
		updates["refund_time"] = resp.RefundTime
	case resp.Status == model.RefundStatusSuccess:
		// Match HandleRefundNotify: a successful refund always gets a refund_time.
		updates["refund_time"] = time.Now()
	}

	ok, err := s.refundRepo.UpdateStatus(ctx, refund.RefundNo, refund.Status, resp.Status, updates)
	if err != nil {
		slog.Warn("sync refund: update status failed", "refund_no", refund.RefundNo, "err", err)
		return
	}
	// Only the caller that won the conditional update sends the notification.
	if ok && resp.Status == model.RefundStatusSuccess && refund.NotifyURL != "" && s.notifySvc != nil {
		tradeNo, notifyURL, refundNo := refund.TradeNo, refund.NotifyURL, refund.RefundNo
		go func() {
			if err := s.notifySvc.SendNotify(context.Background(), tradeNo, model.NotifyTypeRefund, notifyURL); err != nil {
				slog.Error("send refund notify failed", "refund_no", refundNo, "err", err)
			}
		}()
	}
}