draft
This commit is contained in:
74
backend/internal/service/admin_auth.go
Normal file
74
backend/internal/service/admin_auth.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"pay-bridge/internal/repository"
|
||||
)
|
||||
|
||||
type AdminAuthService struct {
|
||||
repo *repository.AdminUserRepository
|
||||
jwtSecret []byte
|
||||
expireHrs int
|
||||
}
|
||||
|
||||
func NewAdminAuthService(repo *repository.AdminUserRepository, jwtSecret string, expireHours int) *AdminAuthService {
|
||||
return &AdminAuthService{
|
||||
repo: repo,
|
||||
jwtSecret: []byte(jwtSecret),
|
||||
expireHrs: expireHours,
|
||||
}
|
||||
}
|
||||
|
||||
// Login 验证用户名密码,成功返回 JWT token
|
||||
func (s *AdminAuthService) Login(ctx context.Context, username, password string) (string, error) {
|
||||
user, err := s.repo.GetByUsername(ctx, username)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return "", errors.New("用户名或密码错误")
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil {
|
||||
return "", errors.New("用户名或密码错误")
|
||||
}
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"username": user.Username,
|
||||
"exp": time.Now().Add(time.Duration(s.expireHrs) * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(s.jwtSecret)
|
||||
}
|
||||
|
||||
// ParseToken 验证并解析 JWT,返回用户名
|
||||
func (s *AdminAuthService) ParseToken(tokenStr string) (string, error) {
|
||||
token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (any, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, errors.New("invalid signing method")
|
||||
}
|
||||
return s.jwtSecret, nil
|
||||
}, jwt.WithValidMethods([]string{"HS256"}))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
return "", errors.New("invalid token")
|
||||
}
|
||||
|
||||
username, ok := claims["username"].(string)
|
||||
if !ok {
|
||||
return "", errors.New("invalid token claims")
|
||||
}
|
||||
return username, nil
|
||||
}
|
||||
141
backend/internal/service/app.go
Normal file
141
backend/internal/service/app.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"pay-bridge/internal/errcode"
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/repository"
|
||||
"pay-bridge/pkg/crypto"
|
||||
)
|
||||
|
||||
// AppService 应用服务
|
||||
type AppService struct {
|
||||
repo *repository.AppRepository
|
||||
encKey string
|
||||
}
|
||||
|
||||
func NewAppService(repo *repository.AppRepository, encKey string) *AppService {
|
||||
return &AppService{repo: repo, encKey: encKey}
|
||||
}
|
||||
|
||||
// GetAppSecret 获取 appSecret(用于鉴权中间件)
|
||||
func (s *AppService) GetAppSecret(ctx context.Context, appID string) (string, error) {
|
||||
app, err := s.repo.GetByAppID(ctx, appID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if app == nil {
|
||||
return "", errors.New(errcode.ErrAppNotFound)
|
||||
}
|
||||
|
||||
secret, err := crypto.Decrypt(app.AppSecret, s.encKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt app secret: %w", err)
|
||||
}
|
||||
return secret, nil
|
||||
}
|
||||
|
||||
// GetApp 获取应用信息
|
||||
func (s *AppService) GetApp(ctx context.Context, appID string) (*model.App, error) {
|
||||
return s.repo.GetByAppID(ctx, appID)
|
||||
}
|
||||
|
||||
// CreateAppResult 创建应用的返回,包含明文 secret(仅展示一次)
|
||||
type CreateAppResult struct {
|
||||
App *model.App
|
||||
PlainSecret string
|
||||
}
|
||||
|
||||
// CreateApp 创建应用,自动生成 app_id 和 app_secret
|
||||
func (s *AppService) CreateApp(ctx context.Context, appName string) (*CreateAppResult, error) {
|
||||
appID := generateAppID()
|
||||
plainSecret := generateSecret()
|
||||
|
||||
encSecret, err := crypto.Encrypt(plainSecret, s.encKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
app := &model.App{
|
||||
AppID: appID,
|
||||
AppSecret: encSecret,
|
||||
AppName: appName,
|
||||
Status: 1,
|
||||
}
|
||||
if err := s.repo.Create(ctx, app); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &CreateAppResult{App: app, PlainSecret: plainSecret}, nil
|
||||
}
|
||||
|
||||
// ListApps 分页查询应用列表
|
||||
func (s *AppService) ListApps(ctx context.Context, limit, offset int) ([]*model.App, error) {
|
||||
return s.repo.List(ctx, limit, offset)
|
||||
}
|
||||
|
||||
// DisableApp 禁用应用
|
||||
func (s *AppService) DisableApp(ctx context.Context, appID string) error {
|
||||
app, err := s.repo.GetByAppIDUnscoped(ctx, appID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if app == nil {
|
||||
return errors.New(errcode.ErrAppNotFound)
|
||||
}
|
||||
return s.repo.UpdateStatus(ctx, appID, 0)
|
||||
}
|
||||
|
||||
// EnableApp 启用应用
|
||||
func (s *AppService) EnableApp(ctx context.Context, appID string) error {
|
||||
app, err := s.repo.GetByAppIDUnscoped(ctx, appID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if app == nil {
|
||||
return errors.New(errcode.ErrAppNotFound)
|
||||
}
|
||||
return s.repo.UpdateStatus(ctx, appID, 1)
|
||||
}
|
||||
|
||||
// ResetSecret 重置应用密钥,返回新的明文 secret(仅此一次)
|
||||
func (s *AppService) ResetSecret(ctx context.Context, appID string) (string, error) {
|
||||
app, err := s.repo.GetByAppIDUnscoped(ctx, appID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if app == nil {
|
||||
return "", errors.New(errcode.ErrAppNotFound)
|
||||
}
|
||||
|
||||
plainSecret := generateSecret()
|
||||
encSecret, err := crypto.Encrypt(plainSecret, s.encKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := s.repo.UpdateSecret(ctx, appID, encSecret); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return plainSecret, nil
|
||||
}
|
||||
|
||||
// generateAppID 生成 app_id:app_ + yyMMdd + 8位随机hex
|
||||
func generateAppID() string {
|
||||
b := make([]byte, 4)
|
||||
_, _ = rand.Read(b)
|
||||
date := time.Now().Format("060102")
|
||||
return "app_" + date + hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// generateSecret 生成 32 字节随机 secret(64位hex)
|
||||
func generateSecret() string {
|
||||
b := make([]byte, 32)
|
||||
_, _ = rand.Read(b)
|
||||
return strings.ToUpper(hex.EncodeToString(b))
|
||||
}
|
||||
140
backend/internal/service/channel.go
Normal file
140
backend/internal/service/channel.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"pay-bridge/internal/channel"
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/repository"
|
||||
"pay-bridge/pkg/config"
|
||||
"pay-bridge/pkg/crypto"
|
||||
)
|
||||
|
||||
const channelCacheTTL = 5 * time.Minute
|
||||
|
||||
type cachedChannel struct {
|
||||
ch channel.PaymentChannel
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// ChannelService 渠道服务(负责加载渠道配置并获取渠道实例)
|
||||
type ChannelService struct {
|
||||
repo *repository.ChannelConfigRepository
|
||||
encKey string
|
||||
urlsCfg config.ChannelsConfig
|
||||
mu sync.Mutex
|
||||
cache map[string]*cachedChannel
|
||||
}
|
||||
|
||||
func NewChannelService(repo *repository.ChannelConfigRepository, encKey string, urlsCfg config.ChannelsConfig) *ChannelService {
|
||||
return &ChannelService{
|
||||
repo: repo,
|
||||
encKey: encKey,
|
||||
urlsCfg: urlsCfg,
|
||||
cache: make(map[string]*cachedChannel),
|
||||
}
|
||||
}
|
||||
|
||||
// GetChannel 根据 appID 和渠道码获取渠道适配器实例(5 分钟内存缓存)
|
||||
func (s *ChannelService) GetChannel(ctx context.Context, appID, channelCode string) (channel.PaymentChannel, error) {
|
||||
cacheKey := appID + ":" + channelCode
|
||||
|
||||
s.mu.Lock()
|
||||
if entry, ok := s.cache[cacheKey]; ok && time.Now().Before(entry.expiresAt) {
|
||||
ch := entry.ch
|
||||
s.mu.Unlock()
|
||||
return ch, nil
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
cfg, err := s.repo.GetByAppChannel(ctx, appID, channelCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("channel config not found: app=%s channel=%s", appID, channelCode)
|
||||
}
|
||||
|
||||
decCfg, err := s.decryptConfig(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ch, err := channel.Get(channelCode, decCfg, s.urlsFor(channelCode))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.cache[cacheKey] = &cachedChannel{ch: ch, expiresAt: time.Now().Add(channelCacheTTL)}
|
||||
s.mu.Unlock()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// InvalidateCache 使指定渠道的缓存失效(配置变更时调用)
|
||||
func (s *ChannelService) InvalidateCache(appID, channelCode string) {
|
||||
s.mu.Lock()
|
||||
delete(s.cache, appID+":"+channelCode)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// ListChannelCodes 获取应用下所有渠道码
|
||||
func (s *ChannelService) ListChannelCodes(ctx context.Context, appID string) ([]string, error) {
|
||||
cfgs, err := s.repo.ListByApp(ctx, appID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
codes := make([]string, 0, len(cfgs))
|
||||
for _, c := range cfgs {
|
||||
codes = append(codes, c.ChannelCode)
|
||||
}
|
||||
return codes, nil
|
||||
}
|
||||
|
||||
// GetChannelConfig 获取渠道配置(已解密)
|
||||
func (s *ChannelService) GetChannelConfig(ctx context.Context, appID, channelCode string) (*model.ChannelConfig, error) {
|
||||
cfg, err := s.repo.GetByAppChannel(ctx, appID, channelCode)
|
||||
if err != nil || cfg == nil {
|
||||
return cfg, err
|
||||
}
|
||||
return s.decryptConfig(cfg)
|
||||
}
|
||||
|
||||
// urlsFor 根据渠道码返回对应的网关地址配置
|
||||
func (s *ChannelService) urlsFor(channelCode string) channel.URLs {
|
||||
switch channelCode {
|
||||
case "HEEPAY":
|
||||
return channel.URLs{
|
||||
PayURL: s.urlsCfg.Heepay.PayURL,
|
||||
MerchantURL: s.urlsCfg.Heepay.MerchantURL,
|
||||
}
|
||||
default:
|
||||
return channel.URLs{}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ChannelService) decryptConfig(cfg *model.ChannelConfig) (*model.ChannelConfig, error) {
|
||||
copied := *cfg
|
||||
|
||||
if cfg.APIKey != "" {
|
||||
dec, err := crypto.Decrypt(cfg.APIKey, s.encKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt api_key: %w", err)
|
||||
}
|
||||
copied.APIKey = dec
|
||||
}
|
||||
|
||||
if cfg.PrivateKey != "" {
|
||||
dec, err := crypto.Decrypt(cfg.PrivateKey, s.encKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt private_key: %w", err)
|
||||
}
|
||||
copied.PrivateKey = dec
|
||||
}
|
||||
|
||||
return &copied, nil
|
||||
}
|
||||
268
backend/internal/service/merchant.go
Normal file
268
backend/internal/service/merchant.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"pay-bridge/internal/channel"
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/repository"
|
||||
)
|
||||
|
||||
// merchantRepo 定义 MerchantService 所需的数据访问方法,便于测试时注入 mock
|
||||
type merchantRepo interface {
|
||||
Create(ctx context.Context, m *model.Merchant) error
|
||||
GetByMerchantID(ctx context.Context, merchantID string) (*model.Merchant, error)
|
||||
GetByMerchantIDAndAppID(ctx context.Context, merchantID, appID string) (*model.Merchant, error)
|
||||
UpdateStatus(ctx context.Context, merchantID string, status model.MerchantStatus, updates map[string]any) error
|
||||
List(ctx context.Context, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error)
|
||||
ListByAppID(ctx context.Context, appID string, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error)
|
||||
ListAnomalous(ctx context.Context) ([]*model.Merchant, error)
|
||||
CreateApplication(ctx context.Context, app *model.MerchantApplication) error
|
||||
GetLatestApplication(ctx context.Context, merchantID string) (*model.MerchantApplication, error)
|
||||
GetApprovedApplicationByChannel(ctx context.Context, merchantID, channelCode string) (*model.MerchantApplication, error)
|
||||
UpdateApplication(ctx context.Context, applicationID string, updates map[string]any) error
|
||||
}
|
||||
|
||||
// MerchantService 商户进件与管理服务
|
||||
type MerchantService struct {
|
||||
merchantRepo merchantRepo
|
||||
channelSvc *ChannelService
|
||||
}
|
||||
|
||||
func NewMerchantService(
|
||||
merchantRepo *repository.MerchantRepository,
|
||||
channelSvc *ChannelService,
|
||||
) *MerchantService {
|
||||
return &MerchantService{
|
||||
merchantRepo: merchantRepo,
|
||||
channelSvc: channelSvc,
|
||||
}
|
||||
}
|
||||
|
||||
func genApplicationID() string {
|
||||
b := make([]byte, 16)
|
||||
rand.Read(b)
|
||||
return "APP" + hex.EncodeToString(b)[:16]
|
||||
}
|
||||
|
||||
// Apply 提交商户进件申请
|
||||
// bizContent 为完整的入网申请业务参数(对应 001 文档的 biz_content 结构)
|
||||
func (s *MerchantService) Apply(ctx context.Context, merchantID, channelCode string, bizContent map[string]any) (string, error) {
|
||||
merchant, err := s.merchantRepo.GetByMerchantID(ctx, merchantID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if merchant == nil {
|
||||
return "", errors.New("merchant not found")
|
||||
}
|
||||
if merchant.Status == model.MerchantStatusFrozen {
|
||||
return "", errors.New("merchant is frozen")
|
||||
}
|
||||
|
||||
ch, err := s.channelSvc.GetChannel(ctx, "", channelCode)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp, err := ch.MerchantApply(ctx, &channel.MerchantApplyReq{
|
||||
MerchantID: merchantID,
|
||||
BizContent: bizContent,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
applicationID := genApplicationID()
|
||||
app := &model.MerchantApplication{
|
||||
ApplicationID: applicationID,
|
||||
MerchantID: merchantID,
|
||||
ChannelCode: channelCode,
|
||||
SubmitData: model.JSONMap(bizContent),
|
||||
AuditStatus: model.AuditStatusSubmitting,
|
||||
SubmittedAt: time.Now(),
|
||||
}
|
||||
// 持久化渠道返回的 request_no,用于后续查询/修改
|
||||
if resp.RequestNo != "" {
|
||||
app.SubmitData["_channel_request_no"] = resp.RequestNo
|
||||
}
|
||||
if err := s.merchantRepo.CreateApplication(ctx, app); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
slog.InfoContext(ctx, "merchant application submitted",
|
||||
"merchant_id", merchantID,
|
||||
"application_id", applicationID,
|
||||
"channel_code", channelCode,
|
||||
"channel_request_no", resp.RequestNo,
|
||||
)
|
||||
return applicationID, nil
|
||||
}
|
||||
|
||||
// UploadFile 上传文件到指定渠道,返回渠道 file_id
|
||||
func (s *MerchantService) UploadFile(ctx context.Context, channelCode string, req *channel.UploadFileReq) (string, error) {
|
||||
ch, err := s.channelSvc.GetChannel(ctx, "", channelCode)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := ch.UploadFile(ctx, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return resp.FileID, nil
|
||||
}
|
||||
|
||||
// QueryAuditStatus 查询进件审核状态
|
||||
func (s *MerchantService) QueryAuditStatus(ctx context.Context, merchantID string) (*model.MerchantApplication, error) {
|
||||
app, err := s.merchantRepo.GetLatestApplication(ctx, merchantID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if app == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 如果仍在审核中,向渠道查询最新状态
|
||||
if app.AuditStatus == model.AuditStatusSubmitting || app.AuditStatus == model.AuditStatusReviewing {
|
||||
// 从 submit_data 中读取渠道返回的 request_no
|
||||
channelRequestNo, _ := app.SubmitData["_channel_request_no"].(string)
|
||||
if channelRequestNo != "" {
|
||||
ch, err := s.channelSvc.GetChannel(ctx, "", app.ChannelCode)
|
||||
if err == nil {
|
||||
resp, err := ch.QueryMerchantStatus(ctx, channelRequestNo)
|
||||
if err == nil {
|
||||
merchant, _ := s.merchantRepo.GetByMerchantID(ctx, merchantID)
|
||||
s.syncMerchantStatus(ctx, merchantID, app.ApplicationID, merchant, resp)
|
||||
app, _ = s.merchantRepo.GetLatestApplication(ctx, merchantID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return app, nil
|
||||
}
|
||||
|
||||
// syncMerchantStatus 同步渠道返回的审核状态到本地
|
||||
func (s *MerchantService) syncMerchantStatus(ctx context.Context, merchantID, applicationID string,
|
||||
merchant *model.Merchant, resp *channel.MerchantStatusResp) {
|
||||
|
||||
now := time.Now()
|
||||
appUpdates := map[string]any{}
|
||||
|
||||
switch resp.Status {
|
||||
case "APPROVED":
|
||||
appUpdates["audit_status"] = model.AuditStatusApproved
|
||||
appUpdates["audited_at"] = now
|
||||
if resp.ChannelMerchantID != "" {
|
||||
appUpdates["channel_merchant_id"] = resp.ChannelMerchantID
|
||||
}
|
||||
s.merchantRepo.UpdateStatus(ctx, merchantID, model.MerchantStatusActive, nil)
|
||||
|
||||
case "REJECTED":
|
||||
appUpdates["audit_status"] = model.AuditStatusRejected
|
||||
appUpdates["reject_reason"] = resp.RejectReason
|
||||
appUpdates["audited_at"] = now
|
||||
s.merchantRepo.UpdateStatus(ctx, merchantID, model.MerchantStatusRejected, nil)
|
||||
|
||||
case "REVIEWING":
|
||||
appUpdates["audit_status"] = model.AuditStatusReviewing
|
||||
|
||||
case "FROZEN":
|
||||
s.merchantRepo.UpdateStatus(ctx, merchantID, model.MerchantStatusFrozen, nil)
|
||||
}
|
||||
|
||||
if len(appUpdates) > 0 {
|
||||
s.merchantRepo.UpdateApplication(ctx, applicationID, appUpdates)
|
||||
}
|
||||
}
|
||||
|
||||
// GetChannelMerchantID 返回指定商户在指定渠道进件审核通过后的渠道商户ID
|
||||
// 若该商户未在该渠道进件或审核未通过,返回空字符串
|
||||
func (s *MerchantService) GetChannelMerchantID(ctx context.Context, merchantID, channelCode string) (string, error) {
|
||||
app, err := s.merchantRepo.GetApprovedApplicationByChannel(ctx, merchantID, channelCode)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if app == nil {
|
||||
return "", nil
|
||||
}
|
||||
return app.ChannelMerchantID, nil
|
||||
}
|
||||
|
||||
// CreateMerchantForApp 业务侧创建商户,强制绑定 appID
|
||||
func (s *MerchantService) CreateMerchantForApp(ctx context.Context, appID string, m *model.Merchant) error {
|
||||
m.AppID = appID
|
||||
return s.merchantRepo.Create(ctx, m)
|
||||
}
|
||||
|
||||
// GetMerchantForApp 业务侧查询,校验 appID 归属
|
||||
func (s *MerchantService) GetMerchantForApp(ctx context.Context, appID, merchantID string) (*model.Merchant, error) {
|
||||
m, err := s.merchantRepo.GetByMerchantIDAndAppID(ctx, merchantID, appID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if m == nil {
|
||||
return nil, errors.New("30001") // merchant not found
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// ListMerchantsForApp 业务侧列表,只返回该 appID 下的商户
|
||||
func (s *MerchantService) ListMerchantsForApp(ctx context.Context, appID string, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
|
||||
return s.merchantRepo.ListByAppID(ctx, appID, status, limit, offset)
|
||||
}
|
||||
|
||||
// ApplyForApp 业务侧进件,校验 appID 归属后委托 Apply
|
||||
func (s *MerchantService) ApplyForApp(ctx context.Context, appID, merchantID, channelCode string, bizContent map[string]any) (string, error) {
|
||||
if _, err := s.GetMerchantForApp(ctx, appID, merchantID); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return s.Apply(ctx, merchantID, channelCode, bizContent)
|
||||
}
|
||||
|
||||
// QueryAuditStatusForApp 业务侧查审核状态,校验 appID 归属
|
||||
func (s *MerchantService) QueryAuditStatusForApp(ctx context.Context, appID, merchantID string) (*model.MerchantApplication, error) {
|
||||
if _, err := s.GetMerchantForApp(ctx, appID, merchantID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.QueryAuditStatus(ctx, merchantID)
|
||||
}
|
||||
|
||||
// CheckAnomalies 检查状态异常的商户(由 cron 调用)
|
||||
func (s *MerchantService) CheckAnomalies(ctx context.Context) error {
|
||||
merchants, err := s.merchantRepo.ListAnomalous(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
slog.InfoContext(ctx, "anomalous merchants found", "count", len(merchants))
|
||||
// 实际业务中可在此发送告警通知
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateMerchant 创建商户基础信息
|
||||
func (s *MerchantService) CreateMerchant(ctx context.Context, m *model.Merchant) error {
|
||||
return s.merchantRepo.Create(ctx, m)
|
||||
}
|
||||
|
||||
// GetMerchant 查询商户信息
|
||||
func (s *MerchantService) GetMerchant(ctx context.Context, merchantID string) (*model.Merchant, error) {
|
||||
return s.merchantRepo.GetByMerchantID(ctx, merchantID)
|
||||
}
|
||||
|
||||
// ListMerchants 查询商户列表
|
||||
func (s *MerchantService) ListMerchants(ctx context.Context, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
|
||||
return s.merchantRepo.List(ctx, status, limit, offset)
|
||||
}
|
||||
|
||||
// FreezeMerchant 冻结商户
|
||||
func (s *MerchantService) FreezeMerchant(ctx context.Context, merchantID string) error {
|
||||
return s.merchantRepo.UpdateStatus(ctx, merchantID, model.MerchantStatusFrozen, nil)
|
||||
}
|
||||
|
||||
// UnfreezeMerchant 解冻商户
|
||||
func (s *MerchantService) UnfreezeMerchant(ctx context.Context, merchantID string) error {
|
||||
return s.merchantRepo.UpdateStatus(ctx, merchantID, model.MerchantStatusActive, nil)
|
||||
}
|
||||
200
backend/internal/service/merchant_test.go
Normal file
200
backend/internal/service/merchant_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
// mockMerchantRepo 实现 merchantRepo interface
|
||||
type mockMerchantRepo struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockMerchantRepo) Create(ctx context.Context, merchant *model.Merchant) error {
|
||||
return m.Called(ctx, merchant).Error(0)
|
||||
}
|
||||
func (m *mockMerchantRepo) GetByMerchantID(ctx context.Context, merchantID string) (*model.Merchant, error) {
|
||||
args := m.Called(ctx, merchantID)
|
||||
return args.Get(0).(*model.Merchant), args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantRepo) GetByMerchantIDAndAppID(ctx context.Context, merchantID, appID string) (*model.Merchant, error) {
|
||||
args := m.Called(ctx, merchantID, appID)
|
||||
v, _ := args.Get(0).(*model.Merchant)
|
||||
return v, args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantRepo) UpdateStatus(ctx context.Context, merchantID string, status model.MerchantStatus, updates map[string]any) error {
|
||||
return m.Called(ctx, merchantID, status, updates).Error(0)
|
||||
}
|
||||
func (m *mockMerchantRepo) List(ctx context.Context, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
|
||||
args := m.Called(ctx, status, limit, offset)
|
||||
return args.Get(0).([]*model.Merchant), args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantRepo) ListByAppID(ctx context.Context, appID string, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
|
||||
args := m.Called(ctx, appID, status, limit, offset)
|
||||
return args.Get(0).([]*model.Merchant), args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantRepo) ListAnomalous(ctx context.Context) ([]*model.Merchant, error) {
|
||||
args := m.Called(ctx)
|
||||
return args.Get(0).([]*model.Merchant), args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantRepo) CreateApplication(ctx context.Context, app *model.MerchantApplication) error {
|
||||
return m.Called(ctx, app).Error(0)
|
||||
}
|
||||
func (m *mockMerchantRepo) GetLatestApplication(ctx context.Context, merchantID string) (*model.MerchantApplication, error) {
|
||||
args := m.Called(ctx, merchantID)
|
||||
v, _ := args.Get(0).(*model.MerchantApplication)
|
||||
return v, args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantRepo) GetApprovedApplicationByChannel(ctx context.Context, merchantID, channelCode string) (*model.MerchantApplication, error) {
|
||||
args := m.Called(ctx, merchantID, channelCode)
|
||||
v, _ := args.Get(0).(*model.MerchantApplication)
|
||||
return v, args.Error(1)
|
||||
}
|
||||
func (m *mockMerchantRepo) UpdateApplication(ctx context.Context, applicationID string, updates map[string]any) error {
|
||||
return m.Called(ctx, applicationID, updates).Error(0)
|
||||
}
|
||||
|
||||
// newTestMerchantService 创建注入了 mock repo 的 service(channelSvc 为 nil,仅测不涉及渠道的方法)
|
||||
func newTestMerchantService(repo merchantRepo) *MerchantService {
|
||||
return &MerchantService{merchantRepo: repo}
|
||||
}
|
||||
|
||||
var ctx = context.Background()
|
||||
|
||||
// --- GetMerchantForApp ---
|
||||
|
||||
func TestGetMerchantForApp_OK(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
want := &model.Merchant{MerchantID: "m001", AppID: "app1"}
|
||||
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return(want, nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
got, err := svc.GetMerchantForApp(ctx, "app1", "m001")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, want, got)
|
||||
repo.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestGetMerchantForApp_NotFound(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return((*model.Merchant)(nil), nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
_, err := svc.GetMerchantForApp(ctx, "app1", "m001")
|
||||
|
||||
assert.EqualError(t, err, "30001")
|
||||
}
|
||||
|
||||
func TestGetMerchantForApp_WrongAppID(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
// 商户存在但属于 other_app,GetByMerchantIDAndAppID 返回 nil
|
||||
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "evil_app").Return((*model.Merchant)(nil), nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
_, err := svc.GetMerchantForApp(ctx, "evil_app", "m001")
|
||||
|
||||
assert.EqualError(t, err, "30001", "跨 appID 访问应返回 not found,而不是泄露商户信息")
|
||||
}
|
||||
|
||||
func TestGetMerchantForApp_DBError(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return((*model.Merchant)(nil), errors.New("db error"))
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
_, err := svc.GetMerchantForApp(ctx, "app1", "m001")
|
||||
|
||||
assert.EqualError(t, err, "db error")
|
||||
}
|
||||
|
||||
// --- CreateMerchantForApp ---
|
||||
|
||||
func TestCreateMerchantForApp_SetsAppID(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
repo.On("Create", ctx, mock.MatchedBy(func(m *model.Merchant) bool {
|
||||
return m.AppID == "app1" && m.MerchantID == "m001"
|
||||
})).Return(nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
m := &model.Merchant{MerchantID: "m001"}
|
||||
err := svc.CreateMerchantForApp(ctx, "app1", m)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "app1", m.AppID, "AppID 应被强制写入")
|
||||
repo.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// --- ListMerchantsForApp ---
|
||||
|
||||
func TestListMerchantsForApp_OnlyReturnsOwnApp(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
want := []*model.Merchant{{MerchantID: "m001", AppID: "app1"}}
|
||||
repo.On("ListByAppID", ctx, "app1", model.MerchantStatus(""), 20, 0).Return(want, nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
got, err := svc.ListMerchantsForApp(ctx, "app1", "", 20, 0)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, got, 1)
|
||||
repo.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// --- ApplyForApp ---
|
||||
|
||||
func TestApplyForApp_MerchantNotBelongToApp(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return((*model.Merchant)(nil), nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
_, err := svc.ApplyForApp(ctx, "app1", "m001", "HEEPAY", nil)
|
||||
|
||||
assert.EqualError(t, err, "30001", "不属于该 app 的商户不能提交进件")
|
||||
}
|
||||
|
||||
// --- GetChannelMerchantID ---
|
||||
|
||||
func TestGetChannelMerchantID_Approved(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
app := &model.MerchantApplication{
|
||||
ChannelMerchantID: "ch_m_999",
|
||||
}
|
||||
repo.On("GetApprovedApplicationByChannel", ctx, "m001", "HEEPAY").Return(app, nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
id, err := svc.GetChannelMerchantID(ctx, "m001", "HEEPAY")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "ch_m_999", id)
|
||||
}
|
||||
|
||||
func TestGetChannelMerchantID_NotApproved(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
repo.On("GetApprovedApplicationByChannel", ctx, "m001", "ALIPAY").Return((*model.MerchantApplication)(nil), nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
id, err := svc.GetChannelMerchantID(ctx, "m001", "ALIPAY")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, id, "未在该渠道进件时返回空字符串")
|
||||
}
|
||||
|
||||
func TestGetChannelMerchantID_MultiChannel(t *testing.T) {
|
||||
repo := new(mockMerchantRepo)
|
||||
repo.On("GetApprovedApplicationByChannel", ctx, "m001", "HEEPAY").
|
||||
Return(&model.MerchantApplication{ChannelMerchantID: "hee_001"}, nil)
|
||||
repo.On("GetApprovedApplicationByChannel", ctx, "m001", "ALIPAY").
|
||||
Return(&model.MerchantApplication{ChannelMerchantID: "ali_001"}, nil)
|
||||
|
||||
svc := newTestMerchantService(repo)
|
||||
|
||||
heeID, _ := svc.GetChannelMerchantID(ctx, "m001", "HEEPAY")
|
||||
aliID, _ := svc.GetChannelMerchantID(ctx, "m001", "ALIPAY")
|
||||
|
||||
assert.Equal(t, "hee_001", heeID, "不同渠道应返回各自的 channel_merchant_id")
|
||||
assert.Equal(t, "ali_001", aliID)
|
||||
}
|
||||
229
backend/internal/service/notify.go
Normal file
229
backend/internal/service/notify.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/repository"
|
||||
)
|
||||
|
||||
// 重试间隔:9 次推送机会(第1次立即,后续8次重试)
|
||||
var retryIntervals = []time.Duration{
|
||||
0,
|
||||
15 * time.Second,
|
||||
30 * time.Second,
|
||||
1 * time.Minute,
|
||||
5 * time.Minute,
|
||||
30 * time.Minute,
|
||||
1 * time.Hour,
|
||||
6 * time.Hour,
|
||||
12 * time.Hour,
|
||||
}
|
||||
|
||||
const maxRetry = 8
|
||||
|
||||
// NotifyService 通知服务
|
||||
type NotifyService struct {
|
||||
notifyRepo *repository.NotifyLogRepository
|
||||
tradeRepo *repository.TradeOrderRepository
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewNotifyService(
|
||||
notifyRepo *repository.NotifyLogRepository,
|
||||
tradeRepo *repository.TradeOrderRepository,
|
||||
httpTimeout time.Duration,
|
||||
) *NotifyService {
|
||||
return &NotifyService{
|
||||
notifyRepo: notifyRepo,
|
||||
tradeRepo: tradeRepo,
|
||||
httpClient: &http.Client{Timeout: httpTimeout},
|
||||
}
|
||||
}
|
||||
|
||||
// SendNotify 向下游发送通知(首次调用)
|
||||
func (s *NotifyService) SendNotify(ctx context.Context, tradeNo string, notifyType model.NotifyType, notifyURL string) error {
|
||||
// 构建通知内容
|
||||
payload, err := s.buildPayload(ctx, tradeNo, notifyType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建通知记录
|
||||
now := time.Now()
|
||||
log := &model.NotifyLog{
|
||||
TradeNo: tradeNo,
|
||||
NotifyType: notifyType,
|
||||
NotifyURL: notifyURL,
|
||||
Status: model.NotifyStatusPending,
|
||||
RetryCount: 0,
|
||||
}
|
||||
if err := s.notifyRepo.Upsert(ctx, log); err != nil {
|
||||
slog.ErrorContext(ctx, "upsert notify log failed", "trade_no", tradeNo, "err", err)
|
||||
}
|
||||
|
||||
// 发送通知
|
||||
resp, err := s.sendHTTP(ctx, notifyURL, payload)
|
||||
if err == nil && isSuccessResponse(resp) {
|
||||
s.notifyRepo.MarkSuccess(ctx, log.ID, resp)
|
||||
slog.InfoContext(ctx, "notify success", "trade_no", tradeNo, "type", notifyType)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 首次失败,写入重试队列
|
||||
errMsg := ""
|
||||
if err != nil {
|
||||
errMsg = err.Error()
|
||||
} else {
|
||||
errMsg = resp
|
||||
}
|
||||
|
||||
nextTime := now.Add(retryIntervals[1])
|
||||
s.notifyRepo.IncrRetryCount(ctx, log.ID, model.NotifyStatusRetry, &nextTime, errMsg)
|
||||
slog.WarnContext(ctx, "notify failed, scheduled retry", "trade_no", tradeNo, "next_retry", nextTime)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProcessRetryQueue 处理重试队列(由 Poller 调用)
|
||||
func (s *NotifyService) ProcessRetryQueue(ctx context.Context, batchSize int) error {
|
||||
logs, err := s.notifyRepo.ListPendingRetry(ctx, time.Now(), batchSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, log := range logs {
|
||||
s.processOne(ctx, log)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *NotifyService) processOne(ctx context.Context, log *model.NotifyLog) {
|
||||
payload, err := s.buildPayload(ctx, log.TradeNo, log.NotifyType)
|
||||
if err != nil {
|
||||
slog.ErrorContext(ctx, "build payload failed", "trade_no", log.TradeNo, "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := s.sendHTTP(ctx, log.NotifyURL, payload)
|
||||
if err == nil && isSuccessResponse(resp) {
|
||||
s.notifyRepo.MarkSuccess(ctx, log.ID, resp)
|
||||
slog.InfoContext(ctx, "notify retry success", "trade_no", log.TradeNo, "retry_count", log.RetryCount)
|
||||
return
|
||||
}
|
||||
|
||||
errMsg := ""
|
||||
if err != nil {
|
||||
errMsg = err.Error()
|
||||
} else {
|
||||
errMsg = resp
|
||||
}
|
||||
|
||||
nextRetryIdx := log.RetryCount + 1
|
||||
if nextRetryIdx > maxRetry {
|
||||
s.notifyRepo.MarkGiveup(ctx, log.ID)
|
||||
slog.WarnContext(ctx, "notify giveup after max retries", "trade_no", log.TradeNo)
|
||||
return
|
||||
}
|
||||
|
||||
var nextTime *time.Time
|
||||
if nextRetryIdx < len(retryIntervals) {
|
||||
t := time.Now().Add(retryIntervals[nextRetryIdx])
|
||||
nextTime = &t
|
||||
}
|
||||
|
||||
status := model.NotifyStatusRetry
|
||||
if nextRetryIdx >= maxRetry {
|
||||
status = model.NotifyStatusGiveup
|
||||
}
|
||||
|
||||
s.notifyRepo.IncrRetryCount(ctx, log.ID, status, nextTime, errMsg)
|
||||
}
|
||||
|
||||
// buildPayload 构建通知内容
|
||||
func (s *NotifyService) buildPayload(ctx context.Context, tradeNo string, notifyType model.NotifyType) ([]byte, error) {
|
||||
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
|
||||
if err != nil || order == nil {
|
||||
return nil, fmt.Errorf("order not found: %s", tradeNo)
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"trade_no": order.TradeNo,
|
||||
"merchant_order_no": order.MerchantOrderNo,
|
||||
"app_id": order.AppID,
|
||||
"pay_method": order.PayMethod,
|
||||
"amount": order.Amount,
|
||||
"status": order.Status,
|
||||
"notify_type": notifyType,
|
||||
"timestamp": time.Now().Unix(),
|
||||
}
|
||||
|
||||
if order.ChannelTradeNo != "" {
|
||||
payload["channel_trade_no"] = order.ChannelTradeNo
|
||||
}
|
||||
if order.PayTime != nil {
|
||||
payload["pay_time"] = order.PayTime.Unix()
|
||||
}
|
||||
|
||||
return json.Marshal(payload)
|
||||
}
|
||||
|
||||
// sendHTTP 向下游发送 HTTP POST 通知
|
||||
func (s *NotifyService) sendHTTP(ctx context.Context, notifyURL string, payload []byte) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, notifyURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
|
||||
return string(body), nil
|
||||
}
|
||||
|
||||
// isSuccessResponse 判断下游是否返回成功
|
||||
// 下游返回 HTTP 200 且 body 包含 "success" 则视为成功
|
||||
func isSuccessResponse(body string) bool {
|
||||
return strings.Contains(strings.ToLower(body), "success")
|
||||
}
|
||||
|
||||
// NextRetryTime 计算下次重试时间
|
||||
func NextRetryTime(retryCount int) (time.Time, bool) {
|
||||
idx := retryCount + 1
|
||||
if idx >= len(retryIntervals) {
|
||||
return time.Time{}, false
|
||||
}
|
||||
return time.Now().Add(retryIntervals[idx]), true
|
||||
}
|
||||
|
||||
// StartPoller 启动通知重试 Poller goroutine
|
||||
func (s *NotifyService) StartPoller(ctx context.Context, interval time.Duration, batchSize int) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := s.ProcessRetryQueue(ctx, batchSize); err != nil {
|
||||
slog.Error("notify poller error", "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
slog.Info("notify poller started", "interval", interval)
|
||||
}
|
||||
280
backend/internal/service/payment_match.go
Normal file
280
backend/internal/service/payment_match.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/repository"
|
||||
)
|
||||
|
||||
// orderNoPatterns 从备注中提取订单号的正则列表(优先级从高到低)
|
||||
var orderNoPatterns = []*regexp.Regexp{
|
||||
regexp.MustCompile(`PAY\d{14}`), // pay-bridge 交易号格式 PAYyyMMddNNNNNNNN
|
||||
regexp.MustCompile(`REF\d{14}`), // 退款单号
|
||||
regexp.MustCompile(`[A-Z0-9]{16,32}`), // 通用订单号格式
|
||||
}
|
||||
|
||||
// IncomingPayment 入账通知数据
|
||||
type IncomingPayment struct {
|
||||
AccountNo string // 收款账号
|
||||
Amount int64 // 入账金额(分)
|
||||
Remark string // 转账备注
|
||||
PayerName string // 付款方名称
|
||||
ChannelBillNo string // 渠道流水号
|
||||
}
|
||||
|
||||
// matchWindow 匹配时间窗口(7天内的待支付订单)
|
||||
const matchWindow = 7 * 24 * time.Hour
|
||||
|
||||
// PaymentMatchService 收款匹配服务
|
||||
type PaymentMatchService struct {
|
||||
matchRepo *repository.PaymentMatchRepository
|
||||
tradeRepo *repository.TradeOrderRepository
|
||||
notifySvc *NotifyService
|
||||
tradeSvc *TradeService
|
||||
}
|
||||
|
||||
func NewPaymentMatchService(
|
||||
matchRepo *repository.PaymentMatchRepository,
|
||||
tradeRepo *repository.TradeOrderRepository,
|
||||
notifySvc *NotifyService,
|
||||
tradeSvc *TradeService,
|
||||
) *PaymentMatchService {
|
||||
return &PaymentMatchService{
|
||||
matchRepo: matchRepo,
|
||||
tradeRepo: tradeRepo,
|
||||
notifySvc: notifySvc,
|
||||
tradeSvc: tradeSvc,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleIncomingPayment 处理入账通知(核心匹配流程)
|
||||
func (s *PaymentMatchService) HandleIncomingPayment(ctx context.Context, incoming *IncomingPayment) error {
|
||||
// 幂等检查
|
||||
if existing, _ := s.matchRepo.GetMatchLogByBillNo(ctx, incoming.ChannelBillNo); existing != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 查询收款账户
|
||||
account, err := s.matchRepo.GetAccountByNo(ctx, incoming.AccountNo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if account == nil {
|
||||
slog.WarnContext(ctx, "incoming payment: account not found", "account_no", incoming.AccountNo)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 执行匹配
|
||||
result := s.match(ctx, incoming, account)
|
||||
|
||||
// 记录匹配结果
|
||||
now := time.Now()
|
||||
log := &model.PaymentMatchLog{
|
||||
AccountID: account.ID,
|
||||
IncomingAmount: incoming.Amount,
|
||||
IncomingRemark: incoming.Remark,
|
||||
PayerName: incoming.PayerName,
|
||||
ChannelBillNo: incoming.ChannelBillNo,
|
||||
MatchStatus: result.status,
|
||||
NameDiff: result.nameDiff,
|
||||
}
|
||||
if result.tradeNo != "" {
|
||||
log.TradeNo = result.tradeNo
|
||||
log.MatchTime = &now
|
||||
}
|
||||
if err := s.matchRepo.CreateMatchLog(ctx, log); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 匹配成功:更新订单状态并通知下游
|
||||
if result.tradeNo != "" {
|
||||
updates := map[string]any{
|
||||
"status": model.TradeStatusPaid,
|
||||
"pay_time": now,
|
||||
}
|
||||
ok, err := s.tradeRepo.UpdateStatus(ctx, result.tradeNo, model.TradeStatusPaying, model.TradeStatusPaid, updates)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
order, _ := s.tradeRepo.GetByTradeNo(ctx, result.tradeNo)
|
||||
if order != nil && s.notifySvc != nil {
|
||||
go func() {
|
||||
bgCtx := context.Background()
|
||||
s.notifySvc.SendNotify(bgCtx, result.tradeNo, model.NotifyTypePayment, order.NotifyURL)
|
||||
}()
|
||||
}
|
||||
}
|
||||
slog.InfoContext(ctx, "payment matched",
|
||||
"trade_no", result.tradeNo,
|
||||
"amount", incoming.Amount,
|
||||
"status", result.status,
|
||||
"name_diff", result.nameDiff,
|
||||
)
|
||||
} else {
|
||||
slog.InfoContext(ctx, "payment pending manual",
|
||||
"channel_bill_no", incoming.ChannelBillNo,
|
||||
"amount", incoming.Amount,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ManualBindOrder 人工关联入账与订单
|
||||
func (s *PaymentMatchService) ManualBindOrder(ctx context.Context, matchID uint64, tradeNo, operator string) error {
|
||||
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
|
||||
if err != nil || order == nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
updates := map[string]any{
|
||||
"trade_no": tradeNo,
|
||||
"match_status": model.MatchStatusMatched,
|
||||
"match_time": now,
|
||||
"operator": operator,
|
||||
}
|
||||
if err := s.matchRepo.UpdateMatchLog(ctx, matchID, updates); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新订单状态
|
||||
s.tradeRepo.UpdateStatus(ctx, tradeNo, model.TradeStatusPaying, model.TradeStatusPaid,
|
||||
map[string]any{"pay_time": now})
|
||||
|
||||
if s.notifySvc != nil {
|
||||
go func() {
|
||||
bgCtx := context.Background()
|
||||
s.notifySvc.SendNotify(bgCtx, tradeNo, model.NotifyTypePayment, order.NotifyURL)
|
||||
}()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListPendingManual 查询待人工确认的收款记录
|
||||
func (s *PaymentMatchService) ListPendingManual(ctx context.Context, appID string, limit, offset int) ([]*model.PaymentMatchLog, error) {
|
||||
return s.matchRepo.ListPendingManual(ctx, appID, limit, offset)
|
||||
}
|
||||
|
||||
// --- 内部匹配逻辑 ---
|
||||
|
||||
type matchResult struct {
|
||||
tradeNo string
|
||||
status model.MatchStatus
|
||||
nameDiff int8
|
||||
}
|
||||
|
||||
func (s *PaymentMatchService) match(ctx context.Context, incoming *IncomingPayment, account *model.SubMerchantAccount) matchResult {
|
||||
// Step 1: 从备注中提取订单号
|
||||
candidates := extractOrderNos(incoming.Remark)
|
||||
|
||||
var matched *model.TradeOrder
|
||||
for _, orderNo := range candidates {
|
||||
// 先按 trade_no 查,再按 merchant_order_no 查
|
||||
order, _ := s.tradeRepo.GetByTradeNo(ctx, orderNo)
|
||||
if order == nil {
|
||||
order, _ = s.tradeRepo.GetByMerchantOrderNo(ctx, account.AppID, orderNo)
|
||||
}
|
||||
if order == nil || order.AppID != account.AppID {
|
||||
continue
|
||||
}
|
||||
if order.Status != model.TradeStatusPaying {
|
||||
continue
|
||||
}
|
||||
// Step 2: 金额精确匹配
|
||||
if order.Amount != incoming.Amount {
|
||||
continue
|
||||
}
|
||||
matched = order
|
||||
break
|
||||
}
|
||||
|
||||
// 备注匹配失败,降级为金额匹配
|
||||
if matched == nil {
|
||||
orders, _ := s.matchRepo.ListPayingByAmount(ctx, account.AppID, incoming.Amount, matchWindow)
|
||||
if len(orders) == 1 {
|
||||
matched = orders[0]
|
||||
} else if len(orders) > 1 {
|
||||
// Step 3: 用付款方名称缩小范围
|
||||
matched = filterByPayerName(orders, incoming.PayerName)
|
||||
if matched == nil {
|
||||
return matchResult{status: model.MatchStatusPendingManual}
|
||||
}
|
||||
} else {
|
||||
return matchResult{status: model.MatchStatusPendingManual}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: 付款方名称一致性检查
|
||||
var nameDiff int8 = 0
|
||||
invoiceName := getInvoiceName(matched)
|
||||
if invoiceName != "" && incoming.PayerName != "" {
|
||||
if !strings.EqualFold(strings.TrimSpace(invoiceName), strings.TrimSpace(incoming.PayerName)) {
|
||||
nameDiff = 1
|
||||
}
|
||||
}
|
||||
|
||||
status := model.MatchStatusMatched
|
||||
if nameDiff == 1 {
|
||||
status = model.MatchStatusNameDiff
|
||||
}
|
||||
|
||||
return matchResult{
|
||||
tradeNo: matched.TradeNo,
|
||||
status: status,
|
||||
nameDiff: nameDiff,
|
||||
}
|
||||
}
|
||||
|
||||
// extractOrderNos 从备注字符串中提取可能的订单号
|
||||
func extractOrderNos(remark string) []string {
|
||||
if remark == "" {
|
||||
return nil
|
||||
}
|
||||
var results []string
|
||||
seen := map[string]bool{}
|
||||
for _, re := range orderNoPatterns {
|
||||
matches := re.FindAllString(remark, -1)
|
||||
for _, m := range matches {
|
||||
if !seen[m] {
|
||||
seen[m] = true
|
||||
results = append(results, m)
|
||||
}
|
||||
}
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
// filterByPayerName 从多个候选订单中,选择 invoice_name 与付款方名称匹配的订单
|
||||
// invoice_name 暂存在 extra 字段中
|
||||
func filterByPayerName(orders []*model.TradeOrder, payerName string) *model.TradeOrder {
|
||||
if payerName == "" {
|
||||
return nil
|
||||
}
|
||||
for _, o := range orders {
|
||||
name := getInvoiceName(o)
|
||||
if name != "" && strings.EqualFold(strings.TrimSpace(name), strings.TrimSpace(payerName)) {
|
||||
return o
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getInvoiceName 从 extra 字段获取开票名称
|
||||
func getInvoiceName(order *model.TradeOrder) string {
|
||||
if order.Extra == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := order.Extra["invoice_name"]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
268
backend/internal/service/profit_sharing.go
Normal file
268
backend/internal/service/profit_sharing.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"pay-bridge/internal/channel"
|
||||
"pay-bridge/internal/errcode"
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/repository"
|
||||
"pay-bridge/pkg/sequence"
|
||||
)
|
||||
|
||||
const (
|
||||
sharingLockPrefix = "lock:sharing:"
|
||||
sharingLockTTL = 30 * time.Second
|
||||
)
|
||||
|
||||
// ProfitSharingService 分润服务
|
||||
type ProfitSharingService struct {
|
||||
sharingRepo *repository.ProfitSharingRepository
|
||||
tradeRepo *repository.TradeOrderRepository
|
||||
channelSvc *ChannelService
|
||||
seqSvc *sequence.Service
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewProfitSharingService(
|
||||
sharingRepo *repository.ProfitSharingRepository,
|
||||
tradeRepo *repository.TradeOrderRepository,
|
||||
channelSvc *ChannelService,
|
||||
seqSvc *sequence.Service,
|
||||
rdb *redis.Client,
|
||||
) *ProfitSharingService {
|
||||
return &ProfitSharingService{
|
||||
sharingRepo: sharingRepo,
|
||||
tradeRepo: tradeRepo,
|
||||
channelSvc: channelSvc,
|
||||
seqSvc: seqSvc,
|
||||
rdb: rdb,
|
||||
}
|
||||
}
|
||||
|
||||
// TriggerSharing 支付成功后触发分润(幂等)
|
||||
func (s *ProfitSharingService) TriggerSharing(ctx context.Context, tradeNo string) error {
|
||||
// 分布式锁防止并发重复触发
|
||||
lockKey := sharingLockPrefix + tradeNo
|
||||
ok, err := s.rdb.SetNX(ctx, lockKey, "1", sharingLockTTL).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("acquire sharing lock: %w", err)
|
||||
}
|
||||
if !ok {
|
||||
return nil // 已有进程在处理
|
||||
}
|
||||
defer s.rdb.Del(ctx, lockKey)
|
||||
|
||||
// 查询交易
|
||||
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
|
||||
if err != nil || order == nil {
|
||||
return errors.New(errcode.ErrOrderNotFound)
|
||||
}
|
||||
if order.ProfitSharingAmount <= 0 {
|
||||
return nil // 无需分润
|
||||
}
|
||||
|
||||
// 幂等检查:是否已有分润记录
|
||||
existing, err := s.sharingRepo.GetOrderByTradeNo(ctx, tradeNo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existing != nil {
|
||||
return nil // 已触发过
|
||||
}
|
||||
|
||||
// 获取应用分润配置
|
||||
cfg, err := s.sharingRepo.GetConfigByAppID(ctx, order.AppID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cfg == nil {
|
||||
return errors.New(errcode.ErrSharingNotConfig)
|
||||
}
|
||||
|
||||
// 校验分润比例
|
||||
maxAmount := int64(float64(order.Amount) * cfg.MaxSharingRatio)
|
||||
if order.ProfitSharingAmount > maxAmount {
|
||||
return errors.New(errcode.ErrSharingAmountExceed)
|
||||
}
|
||||
|
||||
// 生成分润单号
|
||||
sharingNo, err := s.seqSvc.NextSharingNo(ctx, order.AppID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建分润记录
|
||||
sharingOrder := &model.ProfitSharingOrder{
|
||||
SharingNo: sharingNo,
|
||||
TradeNo: tradeNo,
|
||||
AppID: order.AppID,
|
||||
ReceiverMerchantID: cfg.ReceiverMerchantID,
|
||||
SharingAmount: order.ProfitSharingAmount,
|
||||
Status: model.ProfitSharingStatusPending,
|
||||
}
|
||||
if err := s.sharingRepo.CreateOrder(ctx, sharingOrder); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 调用渠道分账
|
||||
ch, err := s.channelSvc.GetChannel(ctx, order.AppID, order.ChannelCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := ch.ProfitSharing(ctx, &channel.ProfitSharingReq{
|
||||
TradeNo: tradeNo,
|
||||
ChannelTradeNo: order.ChannelTradeNo,
|
||||
SharingNo: sharingNo,
|
||||
ReceiverMerchantID: cfg.ReceiverMerchantID,
|
||||
Amount: order.ProfitSharingAmount,
|
||||
})
|
||||
if err != nil {
|
||||
s.sharingRepo.UpdateOrderStatus(ctx, sharingNo,
|
||||
model.ProfitSharingStatusPending,
|
||||
model.ProfitSharingStatusFailed,
|
||||
map[string]any{"fail_reason": err.Error()},
|
||||
)
|
||||
s.sharingRepo.CreateLog(ctx, &model.ProfitSharingLog{
|
||||
SharingNo: sharingNo,
|
||||
Action: "SPLIT",
|
||||
Amount: order.ProfitSharingAmount,
|
||||
Status: "FAILED",
|
||||
})
|
||||
return fmt.Errorf("profit sharing failed: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
s.sharingRepo.UpdateOrderStatus(ctx, sharingNo,
|
||||
model.ProfitSharingStatusPending,
|
||||
model.ProfitSharingStatusProcessing,
|
||||
map[string]any{
|
||||
"channel_sharing_no": resp.ChannelSharingNo,
|
||||
"sharing_time": now,
|
||||
},
|
||||
)
|
||||
s.sharingRepo.CreateLog(ctx, &model.ProfitSharingLog{
|
||||
SharingNo: sharingNo,
|
||||
Action: "SPLIT",
|
||||
Amount: order.ProfitSharingAmount,
|
||||
Status: "PROCESSING",
|
||||
})
|
||||
|
||||
slog.InfoContext(ctx, "profit sharing triggered",
|
||||
"trade_no", tradeNo,
|
||||
"sharing_no", sharingNo,
|
||||
"amount", order.ProfitSharingAmount,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleSharingNotify 处理分账回调(上游分账完成通知)
|
||||
func (s *ProfitSharingService) HandleSharingNotify(ctx context.Context, sharingNo, channelSharingNo string, status model.ProfitSharingStatus) error {
|
||||
now := time.Now()
|
||||
updates := map[string]any{
|
||||
"channel_sharing_no": channelSharingNo,
|
||||
"sharing_time": now,
|
||||
}
|
||||
ok, err := s.sharingRepo.UpdateOrderStatus(ctx, sharingNo,
|
||||
model.ProfitSharingStatusProcessing, status, updates)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return nil // 幂等
|
||||
}
|
||||
logStatus := string(status)
|
||||
s.sharingRepo.CreateLog(ctx, &model.ProfitSharingLog{
|
||||
SharingNo: sharingNo,
|
||||
Action: "SPLIT",
|
||||
Amount: 0,
|
||||
Status: logStatus,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// RollbackSharing 退款前回退分润
|
||||
func (s *ProfitSharingService) RollbackSharing(ctx context.Context, tradeNo string) error {
|
||||
sharingOrder, err := s.sharingRepo.GetOrderByTradeNo(ctx, tradeNo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sharingOrder == nil {
|
||||
return nil // 无分润,直接跳过
|
||||
}
|
||||
if sharingOrder.Status == model.ProfitSharingStatusRollback {
|
||||
return nil // 已回退,幂等
|
||||
}
|
||||
if sharingOrder.Status != model.ProfitSharingStatusSuccess {
|
||||
return fmt.Errorf("sharing not success, cannot rollback, status=%s", sharingOrder.Status)
|
||||
}
|
||||
|
||||
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
|
||||
if err != nil || order == nil {
|
||||
return errors.New(errcode.ErrOrderNotFound)
|
||||
}
|
||||
|
||||
ch, err := s.channelSvc.GetChannel(ctx, order.AppID, order.ChannelCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ch.RollbackProfitSharing(ctx, &channel.RollbackSharingReq{
|
||||
SharingNo: sharingOrder.SharingNo,
|
||||
ChannelSharingNo: sharingOrder.ChannelSharingNo,
|
||||
TradeNo: tradeNo,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("rollback sharing failed: %w", err)
|
||||
}
|
||||
|
||||
s.sharingRepo.UpdateOrderStatus(ctx, sharingOrder.SharingNo,
|
||||
model.ProfitSharingStatusSuccess,
|
||||
model.ProfitSharingStatusRollback,
|
||||
nil,
|
||||
)
|
||||
s.sharingRepo.CreateLog(ctx, &model.ProfitSharingLog{
|
||||
SharingNo: sharingOrder.SharingNo,
|
||||
Action: "ROLLBACK",
|
||||
Amount: sharingOrder.SharingAmount,
|
||||
Status: "SUCCESS",
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// QuerySharing 查询分润状态
|
||||
func (s *ProfitSharingService) QuerySharing(ctx context.Context, sharingNo string) (*model.ProfitSharingOrder, error) {
|
||||
order, err := s.sharingRepo.GetOrderBySharingNo(ctx, sharingNo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if order == nil {
|
||||
return nil, errors.New(errcode.ErrOrderNotFound)
|
||||
}
|
||||
return order, nil
|
||||
}
|
||||
|
||||
// ValidateSharingAmount 下单时校验分润金额是否合法
|
||||
func (s *ProfitSharingService) ValidateSharingAmount(ctx context.Context, appID string, orderAmount, sharingAmount int64) error {
|
||||
if sharingAmount <= 0 {
|
||||
return nil
|
||||
}
|
||||
cfg, err := s.sharingRepo.GetConfigByAppID(ctx, appID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cfg == nil {
|
||||
return errors.New(errcode.ErrSharingNotConfig)
|
||||
}
|
||||
maxAmount := int64(float64(orderAmount) * cfg.MaxSharingRatio)
|
||||
if sharingAmount > maxAmount {
|
||||
return errors.New(errcode.ErrSharingAmountExceed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
221
backend/internal/service/reconciliation.go
Normal file
221
backend/internal/service/reconciliation.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"pay-bridge/internal/channel"
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/repository"
|
||||
)
|
||||
|
||||
// ReconciliationService T+1 自动对账服务
|
||||
type ReconciliationService struct {
|
||||
reconRepo *repository.ReconciliationRepository
|
||||
tradeRepo *repository.TradeOrderRepository
|
||||
channelSvc *ChannelService
|
||||
appRepo *repository.AppRepository
|
||||
}
|
||||
|
||||
func NewReconciliationService(
|
||||
reconRepo *repository.ReconciliationRepository,
|
||||
tradeRepo *repository.TradeOrderRepository,
|
||||
channelSvc *ChannelService,
|
||||
appRepo *repository.AppRepository,
|
||||
) *ReconciliationService {
|
||||
return &ReconciliationService{
|
||||
reconRepo: reconRepo,
|
||||
tradeRepo: tradeRepo,
|
||||
channelSvc: channelSvc,
|
||||
appRepo: appRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// RunDailyReconciliation 执行 T+1 对账(cron 每日触发)
|
||||
func (s *ReconciliationService) RunDailyReconciliation(ctx context.Context) error {
|
||||
// 对账日期:昨天
|
||||
billDate := time.Now().AddDate(0, 0, -1).Format("2006-01-02")
|
||||
slog.InfoContext(ctx, "reconciliation started", "bill_date", billDate)
|
||||
|
||||
apps, err := s.appRepo.ListActive(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, app := range apps {
|
||||
if err := s.reconcileApp(ctx, app.AppID, billDate); err != nil {
|
||||
slog.ErrorContext(ctx, "reconciliation failed for app",
|
||||
"app_id", app.AppID,
|
||||
"bill_date", billDate,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// reconcileApp 对指定应用执行对账
|
||||
func (s *ReconciliationService) reconcileApp(ctx context.Context, appID, billDate string) error {
|
||||
// 获取所有活跃渠道配置
|
||||
channelCodes, err := s.channelSvc.ListChannelCodes(ctx, appID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, code := range channelCodes {
|
||||
if err := s.reconcileChannel(ctx, appID, code, billDate); err != nil {
|
||||
slog.ErrorContext(ctx, "channel reconciliation failed",
|
||||
"app_id", appID,
|
||||
"channel", code,
|
||||
"bill_date", billDate,
|
||||
"error", err,
|
||||
)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// reconcileChannel 对单个渠道执行对账
|
||||
func (s *ReconciliationService) reconcileChannel(ctx context.Context, appID, channelCode, billDate string) error {
|
||||
// 幂等检查
|
||||
existing, err := s.reconRepo.GetReport(ctx, appID, billDate, channelCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existing != nil && existing.Status == model.ReconciliationStatusMatched {
|
||||
return nil // 已对账完成
|
||||
}
|
||||
|
||||
// 创建对账报告
|
||||
report := &model.ReconciliationReport{
|
||||
AppID: appID,
|
||||
ChannelCode: channelCode,
|
||||
BillDate: billDate,
|
||||
Status: model.ReconciliationStatusPending,
|
||||
}
|
||||
if existing == nil {
|
||||
if err := s.reconRepo.CreateReport(ctx, report); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
report = existing
|
||||
}
|
||||
|
||||
// 下载渠道对账单
|
||||
ch, err := s.channelSvc.GetChannel(ctx, appID, channelCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
billData, err := ch.DownloadBill(ctx, &channel.DownloadBillReq{BillDate: billDate})
|
||||
if err != nil {
|
||||
return fmt.Errorf("download bill: %w", err)
|
||||
}
|
||||
|
||||
// 查询本地已支付订单
|
||||
localOrders, err := s.reconRepo.ListPaidOrdersByDate(ctx, appID, billDate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 建立本地订单索引
|
||||
localIndex := make(map[string]*model.TradeOrder, len(localOrders))
|
||||
for _, o := range localOrders {
|
||||
localIndex[o.TradeNo] = o
|
||||
}
|
||||
|
||||
// 建立渠道账单索引
|
||||
channelIndex := make(map[string]*channel.BillRecord, len(billData.Records))
|
||||
for i := range billData.Records {
|
||||
channelIndex[billData.Records[i].TradeNo] = &billData.Records[i]
|
||||
}
|
||||
|
||||
matched := 0
|
||||
exceptions := 0
|
||||
|
||||
// 检查渠道账单中有,本地没有的(漏单)
|
||||
for _, rec := range billData.Records {
|
||||
local, ok := localIndex[rec.TradeNo]
|
||||
if !ok {
|
||||
// 本地缺失
|
||||
ex := &model.ReconciliationException{
|
||||
ReportID: report.ID,
|
||||
TradeNo: rec.TradeNo,
|
||||
ChannelBillNo: rec.ChannelBillNo,
|
||||
ExceptionType: "MISSING_LOCAL",
|
||||
ChannelAmount: rec.Amount,
|
||||
Remark: "渠道有记录,本地无此订单",
|
||||
}
|
||||
s.reconRepo.CreateException(ctx, ex)
|
||||
exceptions++
|
||||
continue
|
||||
}
|
||||
// 金额比对
|
||||
if local.Amount != rec.Amount {
|
||||
ex := &model.ReconciliationException{
|
||||
ReportID: report.ID,
|
||||
TradeNo: rec.TradeNo,
|
||||
ChannelBillNo: rec.ChannelBillNo,
|
||||
ExceptionType: "AMOUNT_MISMATCH",
|
||||
LocalAmount: local.Amount,
|
||||
ChannelAmount: rec.Amount,
|
||||
Remark: fmt.Sprintf("金额不符:本地%d 渠道%d", local.Amount, rec.Amount),
|
||||
}
|
||||
s.reconRepo.CreateException(ctx, ex)
|
||||
exceptions++
|
||||
} else {
|
||||
matched++
|
||||
}
|
||||
}
|
||||
|
||||
// 检查本地有,渠道账单中没有的(多单)
|
||||
for tradeNo, local := range localIndex {
|
||||
if _, ok := channelIndex[tradeNo]; !ok {
|
||||
ex := &model.ReconciliationException{
|
||||
ReportID: report.ID,
|
||||
TradeNo: tradeNo,
|
||||
ExceptionType: "MISSING_CHANNEL",
|
||||
LocalAmount: local.Amount,
|
||||
Remark: "本地已支付,渠道账单无记录",
|
||||
}
|
||||
s.reconRepo.CreateException(ctx, ex)
|
||||
exceptions++
|
||||
}
|
||||
}
|
||||
|
||||
// 更新对账报告
|
||||
status := model.ReconciliationStatusMatched
|
||||
if exceptions > 0 {
|
||||
status = model.ReconciliationStatusException
|
||||
}
|
||||
updates := map[string]any{
|
||||
"total_count": len(billData.Records),
|
||||
"total_amount": billData.TotalAmount,
|
||||
"matched_count": matched,
|
||||
"exception_count": exceptions,
|
||||
"status": status,
|
||||
}
|
||||
if err := s.reconRepo.UpdateReport(ctx, report.ID, updates); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
slog.InfoContext(ctx, "reconciliation done",
|
||||
"app_id", appID,
|
||||
"channel", channelCode,
|
||||
"bill_date", billDate,
|
||||
"matched", matched,
|
||||
"exceptions", exceptions,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetReport 查询对账报告
|
||||
func (s *ReconciliationService) GetReport(ctx context.Context, appID, billDate, channelCode string) (*model.ReconciliationReport, error) {
|
||||
return s.reconRepo.GetReport(ctx, appID, billDate, channelCode)
|
||||
}
|
||||
|
||||
// GetExceptions 查询对账异常明细
|
||||
func (s *ReconciliationService) GetExceptions(ctx context.Context, reportID uint64) ([]*model.ReconciliationException, error) {
|
||||
return s.reconRepo.ListExceptions(ctx, reportID)
|
||||
}
|
||||
213
backend/internal/service/refund.go
Normal file
213
backend/internal/service/refund.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
|
||||
"time"
|
||||
|
||||
"pay-bridge/internal/channel"
|
||||
"pay-bridge/internal/errcode"
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/repository"
|
||||
"pay-bridge/pkg/sequence"
|
||||
)
|
||||
|
||||
// CreateRefundReq 退款请求
|
||||
type CreateRefundReq struct {
|
||||
AppID string
|
||||
TradeNo string
|
||||
RefundAmount int64
|
||||
Reason string
|
||||
NotifyURL string
|
||||
}
|
||||
|
||||
// RefundService 退款服务
|
||||
type RefundService struct {
|
||||
refundRepo *repository.RefundOrderRepository
|
||||
tradeRepo *repository.TradeOrderRepository
|
||||
channelSvc *ChannelService
|
||||
seqSvc *sequence.Service
|
||||
notifySvc *NotifyService
|
||||
}
|
||||
|
||||
func NewRefundService(
|
||||
refundRepo *repository.RefundOrderRepository,
|
||||
tradeRepo *repository.TradeOrderRepository,
|
||||
channelSvc *ChannelService,
|
||||
seqSvc *sequence.Service,
|
||||
notifySvc *NotifyService,
|
||||
) *RefundService {
|
||||
return &RefundService{
|
||||
refundRepo: refundRepo,
|
||||
tradeRepo: tradeRepo,
|
||||
channelSvc: channelSvc,
|
||||
seqSvc: seqSvc,
|
||||
notifySvc: notifySvc,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateRefund 发起退款
|
||||
func (s *RefundService) CreateRefund(ctx context.Context, req *CreateRefundReq) (*model.RefundOrder, error) {
|
||||
// 查询原交易
|
||||
order, err := s.tradeRepo.GetByTradeNo(ctx, req.TradeNo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if order == nil || order.AppID != req.AppID {
|
||||
return nil, errors.New(errcode.ErrOrderNotFound)
|
||||
}
|
||||
if order.Status != model.TradeStatusPaid && order.Status != model.TradeStatusRefunded {
|
||||
return nil, errors.New(errcode.ErrOrderNotPaid)
|
||||
}
|
||||
|
||||
// 校验可退金额
|
||||
refunded, err := s.refundRepo.SumRefundedAmount(ctx, req.TradeNo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if refunded+req.RefundAmount > order.Amount {
|
||||
return nil, errors.New(errcode.ErrRefundAmountExceed)
|
||||
}
|
||||
|
||||
// 生成退款单号
|
||||
refundNo, err := s.seqSvc.NextRefundNo(ctx, req.AppID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建退款记录
|
||||
refund := &model.RefundOrder{
|
||||
RefundNo: refundNo,
|
||||
TradeNo: req.TradeNo,
|
||||
AppID: req.AppID,
|
||||
ChannelCode: order.ChannelCode,
|
||||
RefundAmount: req.RefundAmount,
|
||||
Reason: req.Reason,
|
||||
Status: model.RefundStatusPending,
|
||||
NotifyURL: req.NotifyURL,
|
||||
}
|
||||
if err := s.refundRepo.Create(ctx, refund); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 调用渠道退款
|
||||
ch, err := s.channelSvc.GetChannel(ctx, req.AppID, order.ChannelCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
channelResp, err := ch.Refund(ctx, &channel.RefundReq{
|
||||
TradeNo: req.TradeNo,
|
||||
ChannelTradeNo: order.ChannelTradeNo,
|
||||
RefundNo: refundNo,
|
||||
RefundAmount: req.RefundAmount,
|
||||
TotalAmount: order.Amount,
|
||||
Reason: req.Reason,
|
||||
NotifyURL: req.NotifyURL,
|
||||
})
|
||||
if err != nil {
|
||||
s.refundRepo.UpdateStatus(ctx, refundNo, model.RefundStatusPending, model.RefundStatusFailed, nil)
|
||||
return nil, errors.New(errcode.ErrChannelRefundFail)
|
||||
}
|
||||
|
||||
// 更新渠道退款单号
|
||||
updates := map[string]any{
|
||||
"channel_refund_no": channelResp.ChannelRefundNo,
|
||||
"status": model.RefundStatusProcessing,
|
||||
}
|
||||
s.refundRepo.UpdateStatus(ctx, refundNo, model.RefundStatusPending, model.RefundStatusProcessing, updates)
|
||||
|
||||
refund.ChannelRefundNo = channelResp.ChannelRefundNo
|
||||
refund.Status = model.RefundStatusProcessing
|
||||
return refund, nil
|
||||
}
|
||||
|
||||
// QueryRefund 查询退款状态
|
||||
func (s *RefundService) QueryRefund(ctx context.Context, appID, refundNo string) (*model.RefundOrder, error) {
|
||||
refund, err := s.refundRepo.GetByRefundNo(ctx, refundNo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if refund == nil || refund.AppID != appID {
|
||||
return nil, errors.New(errcode.ErrRefundNotFound)
|
||||
}
|
||||
|
||||
// 如果处于处理中,主动查询渠道
|
||||
if refund.Status == model.RefundStatusProcessing {
|
||||
s.syncRefundStatus(ctx, refund)
|
||||
// 重新查询最新状态
|
||||
refund, _ = s.refundRepo.GetByRefundNo(ctx, refundNo)
|
||||
}
|
||||
|
||||
return refund, nil
|
||||
}
|
||||
|
||||
// HandleRefundNotify 处理退款回调
|
||||
func (s *RefundService) HandleRefundNotify(ctx context.Context, refundNo string, channelRefundNo string, status model.RefundStatus) error {
|
||||
refund, err := s.refundRepo.GetByRefundNo(ctx, refundNo)
|
||||
if err != nil || refund == nil {
|
||||
return errors.New(errcode.ErrRefundNotFound)
|
||||
}
|
||||
|
||||
if refund.Status == model.RefundStatusSuccess {
|
||||
return nil // 幂等
|
||||
}
|
||||
|
||||
updates := map[string]any{
|
||||
"channel_refund_no": channelRefundNo,
|
||||
}
|
||||
if status == model.RefundStatusSuccess {
|
||||
now := time.Now()
|
||||
updates["refund_time"] = now
|
||||
}
|
||||
|
||||
ok, err := s.refundRepo.UpdateStatus(ctx, refundNo, model.RefundStatusProcessing, status, updates)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return nil // 幂等
|
||||
}
|
||||
|
||||
// 退款成功后通知下游
|
||||
if status == model.RefundStatusSuccess && refund.NotifyURL != "" && s.notifySvc != nil {
|
||||
go func() {
|
||||
bgCtx := context.Background()
|
||||
if err := s.notifySvc.SendNotify(bgCtx, refund.TradeNo, model.NotifyTypeRefund, refund.NotifyURL); err != nil {
|
||||
slog.Error("send refund notify failed", "refund_no", refundNo, "err", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RefundService) syncRefundStatus(ctx context.Context, refund *model.RefundOrder) {
|
||||
order, err := s.tradeRepo.GetByTradeNo(ctx, refund.TradeNo)
|
||||
if err != nil || order == nil {
|
||||
return
|
||||
}
|
||||
ch, err := s.channelSvc.GetChannel(ctx, refund.AppID, refund.ChannelCode)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resp, err := ch.QueryRefund(ctx, &channel.QueryRefundReq{
|
||||
RefundNo: refund.RefundNo,
|
||||
ChannelRefundNo: refund.ChannelRefundNo,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if resp.Status != refund.Status {
|
||||
updates := map[string]any{
|
||||
"channel_refund_no": resp.ChannelRefundNo,
|
||||
}
|
||||
if resp.RefundTime != nil {
|
||||
updates["refund_time"] = resp.RefundTime
|
||||
}
|
||||
s.refundRepo.UpdateStatus(ctx, refund.RefundNo, refund.Status, resp.Status, updates)
|
||||
}
|
||||
}
|
||||
|
||||
189
backend/internal/service/service_fee.go
Normal file
189
backend/internal/service/service_fee.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
|
||||
"pay-bridge/internal/channel"
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/repository"
|
||||
)
|
||||
|
||||
// ServiceFeeService 服务费服务
|
||||
type ServiceFeeService struct {
|
||||
feeRepo *repository.ServiceFeeRepository
|
||||
tradeRepo *repository.TradeOrderRepository
|
||||
channelSvc *ChannelService
|
||||
}
|
||||
|
||||
func NewServiceFeeService(
|
||||
feeRepo *repository.ServiceFeeRepository,
|
||||
tradeRepo *repository.TradeOrderRepository,
|
||||
channelSvc *ChannelService,
|
||||
) *ServiceFeeService {
|
||||
return &ServiceFeeService{
|
||||
feeRepo: feeRepo,
|
||||
tradeRepo: tradeRepo,
|
||||
channelSvc: channelSvc,
|
||||
}
|
||||
}
|
||||
|
||||
// ChargeServiceFee 交易完成后扣收服务费
|
||||
func (s *ServiceFeeService) ChargeServiceFee(ctx context.Context, tradeNo string) error {
|
||||
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
|
||||
if err != nil || order == nil {
|
||||
return fmt.Errorf("order not found: %s", tradeNo)
|
||||
}
|
||||
|
||||
// 幂等检查
|
||||
existing, err := s.feeRepo.GetLog(ctx, tradeNo, "CHARGE")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existing != nil {
|
||||
return nil // 已扣收
|
||||
}
|
||||
|
||||
// 获取服务费配置
|
||||
group := model.PayMethodToGroup(order.PayMethod)
|
||||
cfg, err := s.feeRepo.GetConfig(ctx, order.AppID, group)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cfg == nil || cfg.FeeRate == 0 {
|
||||
return nil // 未配置或费率为0
|
||||
}
|
||||
|
||||
// 计算服务费(四舍五入到分)
|
||||
feeAmount := calculateFee(order.Amount, cfg.FeeRate)
|
||||
if feeAmount <= 0 {
|
||||
return nil // 不足1分不扣收
|
||||
}
|
||||
|
||||
// 更新订单服务费金额快照
|
||||
s.tradeRepo.UpdateStatus(ctx, tradeNo, model.TradeStatusPaid, model.TradeStatusPaid,
|
||||
map[string]any{"service_fee_amount": feeAmount})
|
||||
|
||||
// 创建服务费流水
|
||||
log := &model.ServiceFeeLog{
|
||||
TradeNo: tradeNo,
|
||||
ConfigID: cfg.ID,
|
||||
FeeAmount: feeAmount,
|
||||
FeeRate: cfg.FeeRate,
|
||||
ReceiverMerchantID: cfg.FeeReceiverMerchantID,
|
||||
Action: "CHARGE",
|
||||
Status: "PENDING",
|
||||
}
|
||||
if err := s.feeRepo.CreateLog(ctx, log); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 调用渠道分账
|
||||
ch, err := s.channelSvc.GetChannel(ctx, order.AppID, order.ChannelCode)
|
||||
if err != nil {
|
||||
s.feeRepo.UpdateLogStatus(ctx, log.ID, "FAILED", "")
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := ch.ProfitSharing(ctx, &channel.ProfitSharingReq{
|
||||
TradeNo: tradeNo,
|
||||
ChannelTradeNo: order.ChannelTradeNo,
|
||||
SharingNo: fmt.Sprintf("FEE%s", tradeNo),
|
||||
ReceiverMerchantID: cfg.FeeReceiverMerchantID,
|
||||
Amount: feeAmount,
|
||||
})
|
||||
if err != nil {
|
||||
s.feeRepo.UpdateLogStatus(ctx, log.ID, "FAILED", "")
|
||||
slog.WarnContext(ctx, "charge service fee failed", "trade_no", tradeNo, "err", err)
|
||||
return err
|
||||
}
|
||||
|
||||
s.feeRepo.UpdateLogStatus(ctx, log.ID, "SUCCESS", resp.ChannelSharingNo)
|
||||
slog.InfoContext(ctx, "service fee charged", "trade_no", tradeNo, "fee_amount", feeAmount)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RollbackServiceFee 退款时回退服务费
|
||||
func (s *ServiceFeeService) RollbackServiceFee(ctx context.Context, tradeNo string) error {
|
||||
// 幂等检查
|
||||
existing, err := s.feeRepo.GetLog(ctx, tradeNo, "ROLLBACK")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existing != nil {
|
||||
return nil // 已回退
|
||||
}
|
||||
|
||||
chargeLog, err := s.feeRepo.GetLog(ctx, tradeNo, "CHARGE")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if chargeLog == nil || chargeLog.Status != "SUCCESS" {
|
||||
return nil // 没有成功扣收,无需回退
|
||||
}
|
||||
|
||||
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
|
||||
if err != nil || order == nil {
|
||||
return fmt.Errorf("order not found: %s", tradeNo)
|
||||
}
|
||||
|
||||
rollbackLog := &model.ServiceFeeLog{
|
||||
TradeNo: tradeNo,
|
||||
ConfigID: chargeLog.ConfigID,
|
||||
FeeAmount: chargeLog.FeeAmount,
|
||||
FeeRate: chargeLog.FeeRate,
|
||||
ReceiverMerchantID: chargeLog.ReceiverMerchantID,
|
||||
Action: "ROLLBACK",
|
||||
Status: "PENDING",
|
||||
}
|
||||
if err := s.feeRepo.CreateLog(ctx, rollbackLog); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ch, err := s.channelSvc.GetChannel(ctx, order.AppID, order.ChannelCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sharingNo := fmt.Sprintf("FEE%s", tradeNo)
|
||||
if err := ch.RollbackProfitSharing(ctx, &channel.RollbackSharingReq{
|
||||
SharingNo: sharingNo,
|
||||
ChannelSharingNo: chargeLog.ChannelSharingNo,
|
||||
TradeNo: tradeNo,
|
||||
}); err != nil {
|
||||
s.feeRepo.UpdateLogStatus(ctx, rollbackLog.ID, "FAILED", "")
|
||||
return err
|
||||
}
|
||||
|
||||
s.feeRepo.UpdateLogStatus(ctx, rollbackLog.ID, "SUCCESS", "")
|
||||
return nil
|
||||
}
|
||||
|
||||
// CalculateAndValidate 下单时校验分润+服务费不超过订单金额
|
||||
func (s *ServiceFeeService) CalculateAndValidate(ctx context.Context, appID string, payMethod model.PayMethod, orderAmount, sharingAmount int64) (int64, error) {
|
||||
group := model.PayMethodToGroup(payMethod)
|
||||
cfg, err := s.feeRepo.GetConfig(ctx, appID, group)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var feeAmount int64
|
||||
if cfg != nil && cfg.FeeRate > 0 {
|
||||
feeAmount = calculateFee(orderAmount, cfg.FeeRate)
|
||||
}
|
||||
|
||||
if sharingAmount+feeAmount > orderAmount {
|
||||
return 0, fmt.Errorf(errSharingFeeExceed)
|
||||
}
|
||||
return feeAmount, nil
|
||||
}
|
||||
|
||||
const errSharingFeeExceed = "30007" // errcode.ErrSharingFeeExceed
|
||||
|
||||
// calculateFee 计算服务费(四舍五入到分)
|
||||
func calculateFee(amount int64, rate float64) int64 {
|
||||
fee := float64(amount) * rate
|
||||
return int64(math.Round(fee))
|
||||
}
|
||||
371
backend/internal/service/trade.go
Normal file
371
backend/internal/service/trade.go
Normal file
@@ -0,0 +1,371 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"encoding/json"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"pay-bridge/internal/channel"
|
||||
"pay-bridge/internal/errcode"
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/repository"
|
||||
"pay-bridge/pkg/sequence"
|
||||
)
|
||||
|
||||
const (
|
||||
orderExpireDefault = 30 * time.Minute
|
||||
idempotentKeyPrefix = "idempotent:"
|
||||
idempotentTTL = 24 * time.Hour
|
||||
)
|
||||
|
||||
// CreateOrderReq 下单请求
|
||||
type CreateOrderReq struct {
|
||||
AppID string
|
||||
ChannelCode string // 指定渠道,为空时使用 defaultChannelCode
|
||||
MerchantOrderNo string
|
||||
PayMethod model.PayMethod
|
||||
Amount int64
|
||||
ProfitSharingAmount int64
|
||||
Subject string
|
||||
NotifyURL string
|
||||
ExpireMinutes int
|
||||
Extra map[string]any
|
||||
MerchantID string // 可选,指定收款商户(SaaS 多商户路由)
|
||||
}
|
||||
|
||||
// CreateOrderResp 下单响应
|
||||
type CreateOrderResp struct {
|
||||
TradeNo string
|
||||
PayCredential map[string]any
|
||||
IsIdempotent bool // true=幂等返回
|
||||
}
|
||||
|
||||
// TradeService 交易服务
|
||||
type TradeService struct {
|
||||
tradeRepo *repository.TradeOrderRepository
|
||||
channelSvc *ChannelService
|
||||
merchantSvc *MerchantService
|
||||
seqSvc *sequence.Service
|
||||
rdb *redis.Client
|
||||
notifySvc *NotifyService
|
||||
}
|
||||
|
||||
func NewTradeService(
|
||||
tradeRepo *repository.TradeOrderRepository,
|
||||
channelSvc *ChannelService,
|
||||
seqSvc *sequence.Service,
|
||||
rdb *redis.Client,
|
||||
notifySvc *NotifyService,
|
||||
merchantSvc *MerchantService,
|
||||
) *TradeService {
|
||||
return &TradeService{
|
||||
tradeRepo: tradeRepo,
|
||||
channelSvc: channelSvc,
|
||||
merchantSvc: merchantSvc,
|
||||
seqSvc: seqSvc,
|
||||
rdb: rdb,
|
||||
notifySvc: notifySvc,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateOrder 统一下单(含幂等控制)
|
||||
func (s *TradeService) CreateOrder(ctx context.Context, req *CreateOrderReq) (*CreateOrderResp, error) {
|
||||
// 参数校验
|
||||
if req.Amount <= 0 {
|
||||
return nil, errors.New(errcode.ErrInvalidAmount)
|
||||
}
|
||||
|
||||
// 幂等检查 - Redis SET NX
|
||||
idempotentKey := fmt.Sprintf("%s%s:%s", idempotentKeyPrefix, req.AppID, req.MerchantOrderNo)
|
||||
set, err := s.rdb.SetNX(ctx, idempotentKey, "1", idempotentTTL).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
slog.WarnContext(ctx, "redis idempotent check failed, fallback to db", "err", err)
|
||||
}
|
||||
|
||||
if !set {
|
||||
// 幂等命中,查询已有订单
|
||||
order, err := s.tradeRepo.GetByMerchantOrderNo(ctx, req.AppID, req.MerchantOrderNo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if order == nil {
|
||||
// Redis key 存在但 DB 无记录(极端情况),清除 key 重试
|
||||
s.rdb.Del(ctx, idempotentKey)
|
||||
return nil, errors.New(errcode.ErrOrderNotFound)
|
||||
}
|
||||
if order.Status == model.TradeStatusPaid {
|
||||
return nil, errors.New(errcode.ErrOrderAlreadyPaid)
|
||||
}
|
||||
if order.Status == model.TradeStatusClosed {
|
||||
return nil, errors.New(errcode.ErrOrderClosed)
|
||||
}
|
||||
return &CreateOrderResp{
|
||||
TradeNo: order.TradeNo,
|
||||
PayCredential: order.ChannelExtra,
|
||||
IsIdempotent: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 生成交易号
|
||||
tradeNo, err := s.seqSvc.NextTradeNo(ctx, req.AppID)
|
||||
if err != nil {
|
||||
s.rdb.Del(ctx, idempotentKey)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 计算过期时间
|
||||
expireMinutes := req.ExpireMinutes
|
||||
if expireMinutes <= 0 {
|
||||
expireMinutes = int(orderExpireDefault.Minutes())
|
||||
}
|
||||
expireTime := time.Now().Add(time.Duration(expireMinutes) * time.Minute)
|
||||
|
||||
// 确定渠道
|
||||
channelCode := req.ChannelCode
|
||||
if channelCode == "" {
|
||||
channelCode = "HEEPAY" // 向后兼容默认值,建议调用方明确传入
|
||||
}
|
||||
|
||||
// 可选:指定收款商户(SaaS 多商户路由),校验归属并按渠道注入 sub_merchant_id
|
||||
if req.MerchantID != "" && s.merchantSvc != nil {
|
||||
// 校验商户归属(只能使用本 appID 下的商户)
|
||||
if _, err := s.merchantSvc.GetMerchantForApp(ctx, req.AppID, req.MerchantID); err != nil {
|
||||
s.rdb.Del(ctx, idempotentKey)
|
||||
return nil, err
|
||||
}
|
||||
// 按实际下单渠道取对应进件记录的 channel_merchant_id
|
||||
channelMerchantID, err := s.merchantSvc.GetChannelMerchantID(ctx, req.MerchantID, channelCode)
|
||||
if err != nil {
|
||||
s.rdb.Del(ctx, idempotentKey)
|
||||
return nil, err
|
||||
}
|
||||
if channelMerchantID != "" {
|
||||
if req.Extra == nil {
|
||||
req.Extra = make(map[string]any)
|
||||
}
|
||||
req.Extra["sub_merchant_id"] = channelMerchantID
|
||||
}
|
||||
}
|
||||
|
||||
// 创建本地订单记录(CREATING 状态)
|
||||
order := &model.TradeOrder{
|
||||
TradeNo: tradeNo,
|
||||
MerchantOrderNo: req.MerchantOrderNo,
|
||||
AppID: req.AppID,
|
||||
ChannelCode: channelCode,
|
||||
PayMethod: req.PayMethod,
|
||||
Amount: req.Amount,
|
||||
ProfitSharingAmount: req.ProfitSharingAmount,
|
||||
Subject: req.Subject,
|
||||
NotifyURL: req.NotifyURL,
|
||||
Status: model.TradeStatusCreating,
|
||||
Extra: req.Extra,
|
||||
ExpireTime: expireTime,
|
||||
}
|
||||
if err := s.tradeRepo.Create(ctx, order); err != nil {
|
||||
s.rdb.Del(ctx, idempotentKey)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 调用渠道下单
|
||||
ch, err := s.channelSvc.GetChannel(ctx, req.AppID, channelCode)
|
||||
if err != nil {
|
||||
s.tradeRepo.UpdateStatus(ctx, tradeNo, model.TradeStatusCreating, model.TradeStatusCreateFailed, nil)
|
||||
return nil, fmt.Errorf("%s: %w", errcode.ErrChannelCreateFail, err)
|
||||
}
|
||||
|
||||
channelReq := &channel.CreateOrderReq{
|
||||
AppID: req.AppID,
|
||||
TradeNo: tradeNo,
|
||||
MerchantOrderNo: req.MerchantOrderNo,
|
||||
PayMethod: req.PayMethod,
|
||||
Amount: req.Amount,
|
||||
Subject: req.Subject,
|
||||
NotifyURL: req.NotifyURL,
|
||||
ExpireTime: expireTime,
|
||||
Extra: req.Extra,
|
||||
}
|
||||
|
||||
channelResp, err := ch.CreateOrder(ctx, channelReq)
|
||||
if err != nil {
|
||||
s.tradeRepo.UpdateStatus(ctx, tradeNo, model.TradeStatusCreating, model.TradeStatusCreateFailed, nil)
|
||||
return nil, fmt.Errorf("%s: %w", errcode.ErrChannelCreateFail, err)
|
||||
}
|
||||
|
||||
// 更新为 PAYING 状态,保存支付凭证
|
||||
updates := map[string]any{
|
||||
"channel_trade_no": channelResp.ChannelTradeNo,
|
||||
"channel_extra": model.JSONMap(channelResp.PayCredential),
|
||||
}
|
||||
s.tradeRepo.UpdateStatus(ctx, tradeNo, model.TradeStatusCreating, model.TradeStatusPaying, updates)
|
||||
|
||||
return &CreateOrderResp{
|
||||
TradeNo: tradeNo,
|
||||
PayCredential: channelResp.PayCredential,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// QueryOrder 查询交易状态
|
||||
func (s *TradeService) QueryOrder(ctx context.Context, appID, tradeNo string) (*model.TradeOrder, error) {
|
||||
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if order == nil || order.AppID != appID {
|
||||
return nil, errors.New(errcode.ErrOrderNotFound)
|
||||
}
|
||||
|
||||
// 如果处于 PAYING 状态,主动查询渠道同步最新状态
|
||||
if order.Status == model.TradeStatusPaying {
|
||||
s.syncOrderStatus(ctx, order)
|
||||
}
|
||||
|
||||
return order, nil
|
||||
}
|
||||
|
||||
// CloseOrder 关闭订单
|
||||
func (s *TradeService) CloseOrder(ctx context.Context, appID, tradeNo string) error {
|
||||
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if order == nil || order.AppID != appID {
|
||||
return errors.New(errcode.ErrOrderNotFound)
|
||||
}
|
||||
if order.Status == model.TradeStatusPaid {
|
||||
return errors.New(errcode.ErrOrderAlreadyPaid)
|
||||
}
|
||||
if order.Status == model.TradeStatusClosed {
|
||||
return nil // 已关闭,幂等
|
||||
}
|
||||
if order.Status != model.TradeStatusPaying {
|
||||
return errors.New(errcode.ErrOrderClosed)
|
||||
}
|
||||
|
||||
// 调用渠道关单
|
||||
ch, err := s.channelSvc.GetChannel(ctx, appID, order.ChannelCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ch.CloseOrder(ctx, &channel.CloseOrderReq{
|
||||
TradeNo: tradeNo,
|
||||
ChannelTradeNo: order.ChannelTradeNo,
|
||||
}); err != nil {
|
||||
slog.WarnContext(ctx, "close order on channel failed", "trade_no", tradeNo, "err", err)
|
||||
}
|
||||
|
||||
_, err = s.tradeRepo.UpdateStatus(ctx, tradeNo, model.TradeStatusPaying, model.TradeStatusClosed, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
// HandleUpstreamNotify 处理上游支付回调(验签 + 状态更新 + 触发通知下游)
|
||||
//
|
||||
// 流程:先用临时无配置实例从 body 提取 trade_no → 查 DB 得 appID → 加载完整渠道配置验签
|
||||
func (s *TradeService) HandleUpstreamNotify(ctx context.Context, channelCode string, rawBody []byte, headers map[string]string) (string, error) {
|
||||
// 用只负责解析的临时渠道实例提取交易号(不需要密钥配置)
|
||||
tempCh, err := channel.Get(channelCode, nil, channel.URLs{})
|
||||
if err != nil {
|
||||
return "fail", fmt.Errorf("unknown channel: %s", channelCode)
|
||||
}
|
||||
tradeNo, err := tempCh.ExtractTradeNo(rawBody)
|
||||
if err != nil || tradeNo == "" {
|
||||
return "fail", fmt.Errorf("extract trade_no from notify: %w", err)
|
||||
}
|
||||
|
||||
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
|
||||
if err != nil {
|
||||
return "fail", err
|
||||
}
|
||||
if order == nil {
|
||||
slog.WarnContext(ctx, "notify: order not found", "trade_no", tradeNo)
|
||||
return "fail", errors.New(errcode.ErrOrderNotFound)
|
||||
}
|
||||
|
||||
// 加载完整渠道配置并验签
|
||||
ch, err := s.channelSvc.GetChannel(ctx, order.AppID, channelCode)
|
||||
if err != nil {
|
||||
return "fail", err
|
||||
}
|
||||
|
||||
notifyData, err := ch.VerifyNotify(ctx, rawBody, headers)
|
||||
if err != nil {
|
||||
slog.WarnContext(ctx, "notify: verify sign failed", "trade_no", tradeNo, "err", err)
|
||||
return "fail", errors.New(errcode.ErrChannelVerifyFail)
|
||||
}
|
||||
|
||||
// 处理支付通知
|
||||
if notifyData.NotifyType == model.NotifyTypePayment && notifyData.Status == model.TradeStatusPaid {
|
||||
if err := s.handlePaymentSuccess(ctx, order, notifyData); err != nil {
|
||||
return "fail", err
|
||||
}
|
||||
}
|
||||
|
||||
return "success", nil
|
||||
}
|
||||
|
||||
// handlePaymentSuccess 处理支付成功
|
||||
func (s *TradeService) handlePaymentSuccess(ctx context.Context, order *model.TradeOrder, data *channel.NotifyData) error {
|
||||
updates := map[string]any{
|
||||
"channel_trade_no": data.ChannelTradeNo,
|
||||
}
|
||||
if data.PayTime != nil {
|
||||
updates["pay_time"] = data.PayTime
|
||||
}
|
||||
|
||||
ok, err := s.tradeRepo.UpdateStatus(ctx, order.TradeNo, model.TradeStatusPaying, model.TradeStatusPaid, updates)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
// 已被处理过(幂等),直接返回成功
|
||||
slog.InfoContext(ctx, "payment notify idempotent", "trade_no", order.TradeNo)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 异步触发下游通知
|
||||
if s.notifySvc != nil {
|
||||
go func() {
|
||||
bgCtx := context.Background()
|
||||
if err := s.notifySvc.SendNotify(bgCtx, order.TradeNo, model.NotifyTypePayment, order.NotifyURL); err != nil {
|
||||
slog.Error("send notify failed", "trade_no", order.TradeNo, "err", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// syncOrderStatus 主动查询渠道同步订单状态(查询接口兜底)
|
||||
func (s *TradeService) syncOrderStatus(ctx context.Context, order *model.TradeOrder) {
|
||||
ch, err := s.channelSvc.GetChannel(ctx, order.AppID, order.ChannelCode)
|
||||
if err != nil {
|
||||
slog.WarnContext(ctx, "syncOrderStatus: get channel failed", "trade_no", order.TradeNo, "err", err)
|
||||
return
|
||||
}
|
||||
resp, err := ch.QueryOrder(ctx, &channel.QueryOrderReq{
|
||||
TradeNo: order.TradeNo,
|
||||
ChannelTradeNo: order.ChannelTradeNo,
|
||||
})
|
||||
if err != nil {
|
||||
slog.WarnContext(ctx, "syncOrderStatus: query channel failed", "trade_no", order.TradeNo, "err", err)
|
||||
return
|
||||
}
|
||||
if resp.Status == model.TradeStatusPaid {
|
||||
updates := map[string]any{
|
||||
"channel_trade_no": resp.ChannelTradeNo,
|
||||
}
|
||||
if resp.PayTime != nil {
|
||||
updates["pay_time"] = resp.PayTime
|
||||
}
|
||||
s.tradeRepo.UpdateStatus(ctx, order.TradeNo, model.TradeStatusPaying, model.TradeStatusPaid, updates)
|
||||
}
|
||||
}
|
||||
|
||||
func parseJSON(data []byte, v any) error {
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
161
backend/internal/service/wechat.go
Normal file
161
backend/internal/service/wechat.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"pay-bridge/internal/model"
|
||||
"pay-bridge/internal/repository"
|
||||
"pay-bridge/pkg/crypto"
|
||||
)
|
||||
|
||||
const (
|
||||
wxTokenURL = "https://api.weixin.qq.com/cgi-bin/token"
|
||||
wxSendMsgURL = "https://api.weixin.qq.com/cgi-bin/message/template/send"
|
||||
accessTokenTTL = 90 * time.Minute // 微信 access_token 有效期 2h,提前 30min 刷新
|
||||
)
|
||||
|
||||
// WechatService 微信模板消息服务
|
||||
type WechatService struct {
|
||||
wechatRepo *repository.WechatRepository
|
||||
cryptoKey string
|
||||
httpClient *http.Client
|
||||
// 内存缓存 access_token,避免频繁调用微信接口
|
||||
tokenCache map[string]*tokenEntry
|
||||
}
|
||||
|
||||
type tokenEntry struct {
|
||||
token string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
func NewWechatService(wechatRepo *repository.WechatRepository, cryptoKey string) *WechatService {
|
||||
return &WechatService{
|
||||
wechatRepo: wechatRepo,
|
||||
cryptoKey: cryptoKey,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
tokenCache: make(map[string]*tokenEntry),
|
||||
}
|
||||
}
|
||||
|
||||
// SendPaymentNotify 发送支付成功通知
|
||||
func (s *WechatService) SendPaymentNotify(ctx context.Context, appID, tradeNo, openID string, amount int64) error {
|
||||
binding, err := s.wechatRepo.GetBinding(ctx, appID)
|
||||
if err != nil || binding == nil {
|
||||
return nil // 未配置微信通知,跳过
|
||||
}
|
||||
|
||||
data := map[string]any{
|
||||
"trade_no": map[string]string{"value": tradeNo},
|
||||
"amount": map[string]string{"value": fmt.Sprintf("%.2f 元", float64(amount)/100)},
|
||||
"time": map[string]string{"value": time.Now().Format("2006-01-02 15:04:05")},
|
||||
}
|
||||
|
||||
return s.sendTemplate(ctx, appID, binding, openID, tradeNo, data)
|
||||
}
|
||||
|
||||
// sendTemplate 发送模板消息
|
||||
func (s *WechatService) sendTemplate(ctx context.Context, appID string, binding *model.WechatBinding,
|
||||
openID, tradeNo string, data map[string]any) error {
|
||||
|
||||
log := &model.WechatMessageLog{
|
||||
AppID: appID,
|
||||
TradeNo: tradeNo,
|
||||
OpenID: openID,
|
||||
TemplateID: binding.TemplateID,
|
||||
Status: model.WechatMessageStatusPending,
|
||||
}
|
||||
if err := s.wechatRepo.CreateMessageLog(ctx, log); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
token, err := s.getAccessToken(ctx, binding)
|
||||
if err != nil {
|
||||
updates := map[string]any{"status": model.WechatMessageStatusFailed, "err_msg": err.Error()}
|
||||
s.wechatRepo.UpdateMessageLog(ctx, log.ID, updates)
|
||||
return err
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"touser": openID,
|
||||
"template_id": binding.TemplateID,
|
||||
"data": data,
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
url := fmt.Sprintf("%s?access_token=%s", wxSendMsgURL, token)
|
||||
resp, err := s.httpClient.Post(url, "application/json", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
updates := map[string]any{"status": model.WechatMessageStatusFailed, "err_msg": err.Error()}
|
||||
s.wechatRepo.UpdateMessageLog(ctx, log.ID, updates)
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
var result struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
json.Unmarshal(respBody, &result)
|
||||
|
||||
now := time.Now()
|
||||
if result.ErrCode == 0 {
|
||||
updates := map[string]any{"status": model.WechatMessageStatusSuccess, "sent_at": now}
|
||||
s.wechatRepo.UpdateMessageLog(ctx, log.ID, updates)
|
||||
slog.InfoContext(ctx, "wechat template sent", "trade_no", tradeNo, "open_id", openID)
|
||||
} else {
|
||||
errMsg := fmt.Sprintf("errcode=%d errmsg=%s", result.ErrCode, result.ErrMsg)
|
||||
updates := map[string]any{"status": model.WechatMessageStatusFailed, "err_msg": errMsg}
|
||||
s.wechatRepo.UpdateMessageLog(ctx, log.ID, updates)
|
||||
return fmt.Errorf("wechat send failed: %s", errMsg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAccessToken 获取微信 access_token(带内存缓存)
|
||||
func (s *WechatService) getAccessToken(ctx context.Context, binding *model.WechatBinding) (string, error) {
|
||||
if entry, ok := s.tokenCache[binding.WxAppID]; ok && time.Now().Before(entry.expiresAt) {
|
||||
return entry.token, nil
|
||||
}
|
||||
|
||||
// 解密 secret
|
||||
secret, err := crypto.Decrypt(binding.WxSecret, s.cryptoKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt wx secret: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s?grant_type=client_credential&appid=%s&secret=%s",
|
||||
wxTokenURL, binding.WxAppID, secret)
|
||||
resp, err := s.httpClient.Get(url)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get wx token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
var result struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if result.ErrCode != 0 {
|
||||
return "", fmt.Errorf("wx token error: %d %s", result.ErrCode, result.ErrMsg)
|
||||
}
|
||||
|
||||
s.tokenCache[binding.WxAppID] = &tokenEntry{
|
||||
token: result.AccessToken,
|
||||
expiresAt: time.Now().Add(accessTokenTTL),
|
||||
}
|
||||
return result.AccessToken, nil
|
||||
}
|
||||
Reference in New Issue
Block a user