81 lines
2.6 KiB
Go
81 lines
2.6 KiB
Go
package repository
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"time"
|
||
|
||
"gorm.io/gorm"
|
||
"pay-bridge/internal/model"
|
||
)
|
||
|
||
// TradeOrderRepository 交易订单数据访问
|
||
type TradeOrderRepository struct {
|
||
db *gorm.DB
|
||
}
|
||
|
||
func NewTradeOrderRepository(db *gorm.DB) *TradeOrderRepository {
|
||
return &TradeOrderRepository{db: db}
|
||
}
|
||
|
||
// Create 创建订单
|
||
func (r *TradeOrderRepository) Create(ctx context.Context, order *model.TradeOrder) error {
|
||
return r.db.WithContext(ctx).Create(order).Error
|
||
}
|
||
|
||
// GetByTradeNo 按 trade_no 查询
|
||
func (r *TradeOrderRepository) GetByTradeNo(ctx context.Context, tradeNo string) (*model.TradeOrder, error) {
|
||
var order model.TradeOrder
|
||
err := r.db.WithContext(ctx).Where("trade_no = ?", tradeNo).First(&order).Error
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, nil
|
||
}
|
||
return &order, err
|
||
}
|
||
|
||
// GetByMerchantOrderNo 按 app_id + merchant_order_no 查询
|
||
func (r *TradeOrderRepository) GetByMerchantOrderNo(ctx context.Context, appID, merchantOrderNo string) (*model.TradeOrder, error) {
|
||
var order model.TradeOrder
|
||
err := r.db.WithContext(ctx).Where("app_id = ? AND merchant_order_no = ?", appID, merchantOrderNo).First(&order).Error
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, nil
|
||
}
|
||
return &order, err
|
||
}
|
||
|
||
// GetByChannelTradeNo 按渠道交易号查询
|
||
func (r *TradeOrderRepository) GetByChannelTradeNo(ctx context.Context, channelTradeNo string) (*model.TradeOrder, error) {
|
||
var order model.TradeOrder
|
||
err := r.db.WithContext(ctx).Where("channel_trade_no = ?", channelTradeNo).First(&order).Error
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, nil
|
||
}
|
||
return &order, err
|
||
}
|
||
|
||
// UpdateStatus 乐观锁更新状态(只允许从 fromStatus 流转到 toStatus)
|
||
// 返回 bool 表示是否更新成功(false = 已被其他 goroutine 更新)
|
||
func (r *TradeOrderRepository) UpdateStatus(ctx context.Context, tradeNo string, fromStatus, toStatus model.TradeStatus, updates map[string]any) (bool, error) {
|
||
if updates == nil {
|
||
updates = make(map[string]any)
|
||
}
|
||
updates["status"] = toStatus
|
||
|
||
result := r.db.WithContext(ctx).Model(&model.TradeOrder{}).
|
||
Where("trade_no = ? AND status = ?", tradeNo, fromStatus).
|
||
Updates(updates)
|
||
if result.Error != nil {
|
||
return false, result.Error
|
||
}
|
||
return result.RowsAffected > 0, nil
|
||
}
|
||
|
||
// ListPayingExpired 查询已过期的 PAYING 订单(用于定时关单补偿)
|
||
func (r *TradeOrderRepository) ListPayingExpired(ctx context.Context, before time.Time, limit int) ([]*model.TradeOrder, error) {
|
||
var orders []*model.TradeOrder
|
||
err := r.db.WithContext(ctx).
|
||
Where("status = ? AND expire_time < ?", model.TradeStatusPaying, before).
|
||
Limit(limit).Find(&orders).Error
|
||
return orders, err
|
||
}
|