package repository import ( "context" "errors" "time" "gorm.io/gorm" "pay-bridge/internal/model" ) // TradeOrderRepository 交易订单数据访问 type TradeOrderRepository struct { db *gorm.DB } func NewTradeOrderRepository(db *gorm.DB) *TradeOrderRepository { return &TradeOrderRepository{db: db} } // Create 创建订单 func (r *TradeOrderRepository) Create(ctx context.Context, order *model.TradeOrder) error { return r.db.WithContext(ctx).Create(order).Error } // GetByTradeNo 按 trade_no 查询 func (r *TradeOrderRepository) GetByTradeNo(ctx context.Context, tradeNo string) (*model.TradeOrder, error) { var order model.TradeOrder err := r.db.WithContext(ctx).Where("trade_no = ?", tradeNo).First(&order).Error if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } return &order, err } // GetByMerchantOrderNo 按 app_id + merchant_order_no 查询 func (r *TradeOrderRepository) GetByMerchantOrderNo(ctx context.Context, appID, merchantOrderNo string) (*model.TradeOrder, error) { var order model.TradeOrder err := r.db.WithContext(ctx).Where("app_id = ? AND merchant_order_no = ?", appID, merchantOrderNo).First(&order).Error if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } return &order, err } // GetByChannelTradeNo 按渠道交易号查询 func (r *TradeOrderRepository) GetByChannelTradeNo(ctx context.Context, channelTradeNo string) (*model.TradeOrder, error) { var order model.TradeOrder err := r.db.WithContext(ctx).Where("channel_trade_no = ?", channelTradeNo).First(&order).Error if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } return &order, err } // UpdateStatus 乐观锁更新状态(只允许从 fromStatus 流转到 toStatus) // 返回 bool 表示是否更新成功(false = 已被其他 goroutine 更新) func (r *TradeOrderRepository) UpdateStatus(ctx context.Context, tradeNo string, fromStatus, toStatus model.TradeStatus, updates map[string]any) (bool, error) { if updates == nil { updates = make(map[string]any) } updates["status"] = toStatus result := r.db.WithContext(ctx).Model(&model.TradeOrder{}). Where("trade_no = ? AND status = ?", tradeNo, fromStatus). Updates(updates) if result.Error != nil { return false, result.Error } return result.RowsAffected > 0, nil } // ListPayingExpired 查询已过期的 PAYING 订单(用于定时关单补偿) func (r *TradeOrderRepository) ListPayingExpired(ctx context.Context, before time.Time, limit int) ([]*model.TradeOrder, error) { var orders []*model.TradeOrder err := r.db.WithContext(ctx). Where("status = ? AND expire_time < ?", model.TradeStatusPaying, before). Limit(limit).Find(&orders).Error return orders, err }