This commit is contained in:
2026-03-13 15:51:59 +08:00
parent 4db2386bbf
commit 4e91f4cede
133 changed files with 19502 additions and 37 deletions

View File

@@ -0,0 +1,74 @@
package service
import (
"context"
"errors"
"time"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
"pay-bridge/internal/repository"
)
type AdminAuthService struct {
repo *repository.AdminUserRepository
jwtSecret []byte
expireHrs int
}
func NewAdminAuthService(repo *repository.AdminUserRepository, jwtSecret string, expireHours int) *AdminAuthService {
return &AdminAuthService{
repo: repo,
jwtSecret: []byte(jwtSecret),
expireHrs: expireHours,
}
}
// Login 验证用户名密码,成功返回 JWT token
func (s *AdminAuthService) Login(ctx context.Context, username, password string) (string, error) {
user, err := s.repo.GetByUsername(ctx, username)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", errors.New("用户名或密码错误")
}
return "", err
}
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil {
return "", errors.New("用户名或密码错误")
}
claims := jwt.MapClaims{
"username": user.Username,
"exp": time.Now().Add(time.Duration(s.expireHrs) * time.Hour).Unix(),
"iat": time.Now().Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(s.jwtSecret)
}
// ParseToken 验证并解析 JWT返回用户名
func (s *AdminAuthService) ParseToken(tokenStr string) (string, error) {
token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (any, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, errors.New("invalid signing method")
}
return s.jwtSecret, nil
}, jwt.WithValidMethods([]string{"HS256"}))
if err != nil {
return "", err
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
return "", errors.New("invalid token")
}
username, ok := claims["username"].(string)
if !ok {
return "", errors.New("invalid token claims")
}
return username, nil
}

View File

@@ -0,0 +1,141 @@
package service
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"strings"
"time"
"pay-bridge/internal/errcode"
"pay-bridge/internal/model"
"pay-bridge/internal/repository"
"pay-bridge/pkg/crypto"
)
// AppService 应用服务
type AppService struct {
repo *repository.AppRepository
encKey string
}
func NewAppService(repo *repository.AppRepository, encKey string) *AppService {
return &AppService{repo: repo, encKey: encKey}
}
// GetAppSecret 获取 appSecret用于鉴权中间件
func (s *AppService) GetAppSecret(ctx context.Context, appID string) (string, error) {
app, err := s.repo.GetByAppID(ctx, appID)
if err != nil {
return "", err
}
if app == nil {
return "", errors.New(errcode.ErrAppNotFound)
}
secret, err := crypto.Decrypt(app.AppSecret, s.encKey)
if err != nil {
return "", fmt.Errorf("decrypt app secret: %w", err)
}
return secret, nil
}
// GetApp 获取应用信息
func (s *AppService) GetApp(ctx context.Context, appID string) (*model.App, error) {
return s.repo.GetByAppID(ctx, appID)
}
// CreateAppResult 创建应用的返回,包含明文 secret仅展示一次
type CreateAppResult struct {
App *model.App
PlainSecret string
}
// CreateApp 创建应用,自动生成 app_id 和 app_secret
func (s *AppService) CreateApp(ctx context.Context, appName string) (*CreateAppResult, error) {
appID := generateAppID()
plainSecret := generateSecret()
encSecret, err := crypto.Encrypt(plainSecret, s.encKey)
if err != nil {
return nil, err
}
app := &model.App{
AppID: appID,
AppSecret: encSecret,
AppName: appName,
Status: 1,
}
if err := s.repo.Create(ctx, app); err != nil {
return nil, err
}
return &CreateAppResult{App: app, PlainSecret: plainSecret}, nil
}
// ListApps 分页查询应用列表
func (s *AppService) ListApps(ctx context.Context, limit, offset int) ([]*model.App, error) {
return s.repo.List(ctx, limit, offset)
}
// DisableApp 禁用应用
func (s *AppService) DisableApp(ctx context.Context, appID string) error {
app, err := s.repo.GetByAppIDUnscoped(ctx, appID)
if err != nil {
return err
}
if app == nil {
return errors.New(errcode.ErrAppNotFound)
}
return s.repo.UpdateStatus(ctx, appID, 0)
}
// EnableApp 启用应用
func (s *AppService) EnableApp(ctx context.Context, appID string) error {
app, err := s.repo.GetByAppIDUnscoped(ctx, appID)
if err != nil {
return err
}
if app == nil {
return errors.New(errcode.ErrAppNotFound)
}
return s.repo.UpdateStatus(ctx, appID, 1)
}
// ResetSecret 重置应用密钥,返回新的明文 secret仅此一次
func (s *AppService) ResetSecret(ctx context.Context, appID string) (string, error) {
app, err := s.repo.GetByAppIDUnscoped(ctx, appID)
if err != nil {
return "", err
}
if app == nil {
return "", errors.New(errcode.ErrAppNotFound)
}
plainSecret := generateSecret()
encSecret, err := crypto.Encrypt(plainSecret, s.encKey)
if err != nil {
return "", err
}
if err := s.repo.UpdateSecret(ctx, appID, encSecret); err != nil {
return "", err
}
return plainSecret, nil
}
// generateAppID 生成 app_idapp_ + yyMMdd + 8位随机hex
func generateAppID() string {
b := make([]byte, 4)
_, _ = rand.Read(b)
date := time.Now().Format("060102")
return "app_" + date + hex.EncodeToString(b)
}
// generateSecret 生成 32 字节随机 secret64位hex
func generateSecret() string {
b := make([]byte, 32)
_, _ = rand.Read(b)
return strings.ToUpper(hex.EncodeToString(b))
}

View File

@@ -0,0 +1,140 @@
package service
import (
"context"
"fmt"
"sync"
"time"
"pay-bridge/internal/channel"
"pay-bridge/internal/model"
"pay-bridge/internal/repository"
"pay-bridge/pkg/config"
"pay-bridge/pkg/crypto"
)
const channelCacheTTL = 5 * time.Minute
type cachedChannel struct {
ch channel.PaymentChannel
expiresAt time.Time
}
// ChannelService 渠道服务(负责加载渠道配置并获取渠道实例)
type ChannelService struct {
repo *repository.ChannelConfigRepository
encKey string
urlsCfg config.ChannelsConfig
mu sync.Mutex
cache map[string]*cachedChannel
}
func NewChannelService(repo *repository.ChannelConfigRepository, encKey string, urlsCfg config.ChannelsConfig) *ChannelService {
return &ChannelService{
repo: repo,
encKey: encKey,
urlsCfg: urlsCfg,
cache: make(map[string]*cachedChannel),
}
}
// GetChannel 根据 appID 和渠道码获取渠道适配器实例5 分钟内存缓存)
func (s *ChannelService) GetChannel(ctx context.Context, appID, channelCode string) (channel.PaymentChannel, error) {
cacheKey := appID + ":" + channelCode
s.mu.Lock()
if entry, ok := s.cache[cacheKey]; ok && time.Now().Before(entry.expiresAt) {
ch := entry.ch
s.mu.Unlock()
return ch, nil
}
s.mu.Unlock()
cfg, err := s.repo.GetByAppChannel(ctx, appID, channelCode)
if err != nil {
return nil, err
}
if cfg == nil {
return nil, fmt.Errorf("channel config not found: app=%s channel=%s", appID, channelCode)
}
decCfg, err := s.decryptConfig(cfg)
if err != nil {
return nil, err
}
ch, err := channel.Get(channelCode, decCfg, s.urlsFor(channelCode))
if err != nil {
return nil, err
}
s.mu.Lock()
s.cache[cacheKey] = &cachedChannel{ch: ch, expiresAt: time.Now().Add(channelCacheTTL)}
s.mu.Unlock()
return ch, nil
}
// InvalidateCache 使指定渠道的缓存失效(配置变更时调用)
func (s *ChannelService) InvalidateCache(appID, channelCode string) {
s.mu.Lock()
delete(s.cache, appID+":"+channelCode)
s.mu.Unlock()
}
// ListChannelCodes 获取应用下所有渠道码
func (s *ChannelService) ListChannelCodes(ctx context.Context, appID string) ([]string, error) {
cfgs, err := s.repo.ListByApp(ctx, appID)
if err != nil {
return nil, err
}
codes := make([]string, 0, len(cfgs))
for _, c := range cfgs {
codes = append(codes, c.ChannelCode)
}
return codes, nil
}
// GetChannelConfig 获取渠道配置(已解密)
func (s *ChannelService) GetChannelConfig(ctx context.Context, appID, channelCode string) (*model.ChannelConfig, error) {
cfg, err := s.repo.GetByAppChannel(ctx, appID, channelCode)
if err != nil || cfg == nil {
return cfg, err
}
return s.decryptConfig(cfg)
}
// urlsFor 根据渠道码返回对应的网关地址配置
func (s *ChannelService) urlsFor(channelCode string) channel.URLs {
switch channelCode {
case "HEEPAY":
return channel.URLs{
PayURL: s.urlsCfg.Heepay.PayURL,
MerchantURL: s.urlsCfg.Heepay.MerchantURL,
}
default:
return channel.URLs{}
}
}
func (s *ChannelService) decryptConfig(cfg *model.ChannelConfig) (*model.ChannelConfig, error) {
copied := *cfg
if cfg.APIKey != "" {
dec, err := crypto.Decrypt(cfg.APIKey, s.encKey)
if err != nil {
return nil, fmt.Errorf("decrypt api_key: %w", err)
}
copied.APIKey = dec
}
if cfg.PrivateKey != "" {
dec, err := crypto.Decrypt(cfg.PrivateKey, s.encKey)
if err != nil {
return nil, fmt.Errorf("decrypt private_key: %w", err)
}
copied.PrivateKey = dec
}
return &copied, nil
}

View File

@@ -0,0 +1,268 @@
package service
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"log/slog"
"time"
"pay-bridge/internal/channel"
"pay-bridge/internal/model"
"pay-bridge/internal/repository"
)
// merchantRepo 定义 MerchantService 所需的数据访问方法,便于测试时注入 mock
type merchantRepo interface {
Create(ctx context.Context, m *model.Merchant) error
GetByMerchantID(ctx context.Context, merchantID string) (*model.Merchant, error)
GetByMerchantIDAndAppID(ctx context.Context, merchantID, appID string) (*model.Merchant, error)
UpdateStatus(ctx context.Context, merchantID string, status model.MerchantStatus, updates map[string]any) error
List(ctx context.Context, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error)
ListByAppID(ctx context.Context, appID string, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error)
ListAnomalous(ctx context.Context) ([]*model.Merchant, error)
CreateApplication(ctx context.Context, app *model.MerchantApplication) error
GetLatestApplication(ctx context.Context, merchantID string) (*model.MerchantApplication, error)
GetApprovedApplicationByChannel(ctx context.Context, merchantID, channelCode string) (*model.MerchantApplication, error)
UpdateApplication(ctx context.Context, applicationID string, updates map[string]any) error
}
// MerchantService 商户进件与管理服务
type MerchantService struct {
merchantRepo merchantRepo
channelSvc *ChannelService
}
func NewMerchantService(
merchantRepo *repository.MerchantRepository,
channelSvc *ChannelService,
) *MerchantService {
return &MerchantService{
merchantRepo: merchantRepo,
channelSvc: channelSvc,
}
}
func genApplicationID() string {
b := make([]byte, 16)
rand.Read(b)
return "APP" + hex.EncodeToString(b)[:16]
}
// Apply 提交商户进件申请
// bizContent 为完整的入网申请业务参数(对应 001 文档的 biz_content 结构)
func (s *MerchantService) Apply(ctx context.Context, merchantID, channelCode string, bizContent map[string]any) (string, error) {
merchant, err := s.merchantRepo.GetByMerchantID(ctx, merchantID)
if err != nil {
return "", err
}
if merchant == nil {
return "", errors.New("merchant not found")
}
if merchant.Status == model.MerchantStatusFrozen {
return "", errors.New("merchant is frozen")
}
ch, err := s.channelSvc.GetChannel(ctx, "", channelCode)
if err != nil {
return "", err
}
resp, err := ch.MerchantApply(ctx, &channel.MerchantApplyReq{
MerchantID: merchantID,
BizContent: bizContent,
})
if err != nil {
return "", err
}
applicationID := genApplicationID()
app := &model.MerchantApplication{
ApplicationID: applicationID,
MerchantID: merchantID,
ChannelCode: channelCode,
SubmitData: model.JSONMap(bizContent),
AuditStatus: model.AuditStatusSubmitting,
SubmittedAt: time.Now(),
}
// 持久化渠道返回的 request_no用于后续查询/修改
if resp.RequestNo != "" {
app.SubmitData["_channel_request_no"] = resp.RequestNo
}
if err := s.merchantRepo.CreateApplication(ctx, app); err != nil {
return "", err
}
slog.InfoContext(ctx, "merchant application submitted",
"merchant_id", merchantID,
"application_id", applicationID,
"channel_code", channelCode,
"channel_request_no", resp.RequestNo,
)
return applicationID, nil
}
// UploadFile 上传文件到指定渠道,返回渠道 file_id
func (s *MerchantService) UploadFile(ctx context.Context, channelCode string, req *channel.UploadFileReq) (string, error) {
ch, err := s.channelSvc.GetChannel(ctx, "", channelCode)
if err != nil {
return "", err
}
resp, err := ch.UploadFile(ctx, req)
if err != nil {
return "", err
}
return resp.FileID, nil
}
// QueryAuditStatus 查询进件审核状态
func (s *MerchantService) QueryAuditStatus(ctx context.Context, merchantID string) (*model.MerchantApplication, error) {
app, err := s.merchantRepo.GetLatestApplication(ctx, merchantID)
if err != nil {
return nil, err
}
if app == nil {
return nil, nil
}
// 如果仍在审核中,向渠道查询最新状态
if app.AuditStatus == model.AuditStatusSubmitting || app.AuditStatus == model.AuditStatusReviewing {
// 从 submit_data 中读取渠道返回的 request_no
channelRequestNo, _ := app.SubmitData["_channel_request_no"].(string)
if channelRequestNo != "" {
ch, err := s.channelSvc.GetChannel(ctx, "", app.ChannelCode)
if err == nil {
resp, err := ch.QueryMerchantStatus(ctx, channelRequestNo)
if err == nil {
merchant, _ := s.merchantRepo.GetByMerchantID(ctx, merchantID)
s.syncMerchantStatus(ctx, merchantID, app.ApplicationID, merchant, resp)
app, _ = s.merchantRepo.GetLatestApplication(ctx, merchantID)
}
}
}
}
return app, nil
}
// syncMerchantStatus 同步渠道返回的审核状态到本地
func (s *MerchantService) syncMerchantStatus(ctx context.Context, merchantID, applicationID string,
merchant *model.Merchant, resp *channel.MerchantStatusResp) {
now := time.Now()
appUpdates := map[string]any{}
switch resp.Status {
case "APPROVED":
appUpdates["audit_status"] = model.AuditStatusApproved
appUpdates["audited_at"] = now
if resp.ChannelMerchantID != "" {
appUpdates["channel_merchant_id"] = resp.ChannelMerchantID
}
s.merchantRepo.UpdateStatus(ctx, merchantID, model.MerchantStatusActive, nil)
case "REJECTED":
appUpdates["audit_status"] = model.AuditStatusRejected
appUpdates["reject_reason"] = resp.RejectReason
appUpdates["audited_at"] = now
s.merchantRepo.UpdateStatus(ctx, merchantID, model.MerchantStatusRejected, nil)
case "REVIEWING":
appUpdates["audit_status"] = model.AuditStatusReviewing
case "FROZEN":
s.merchantRepo.UpdateStatus(ctx, merchantID, model.MerchantStatusFrozen, nil)
}
if len(appUpdates) > 0 {
s.merchantRepo.UpdateApplication(ctx, applicationID, appUpdates)
}
}
// GetChannelMerchantID 返回指定商户在指定渠道进件审核通过后的渠道商户ID
// 若该商户未在该渠道进件或审核未通过,返回空字符串
func (s *MerchantService) GetChannelMerchantID(ctx context.Context, merchantID, channelCode string) (string, error) {
app, err := s.merchantRepo.GetApprovedApplicationByChannel(ctx, merchantID, channelCode)
if err != nil {
return "", err
}
if app == nil {
return "", nil
}
return app.ChannelMerchantID, nil
}
// CreateMerchantForApp 业务侧创建商户,强制绑定 appID
func (s *MerchantService) CreateMerchantForApp(ctx context.Context, appID string, m *model.Merchant) error {
m.AppID = appID
return s.merchantRepo.Create(ctx, m)
}
// GetMerchantForApp 业务侧查询,校验 appID 归属
func (s *MerchantService) GetMerchantForApp(ctx context.Context, appID, merchantID string) (*model.Merchant, error) {
m, err := s.merchantRepo.GetByMerchantIDAndAppID(ctx, merchantID, appID)
if err != nil {
return nil, err
}
if m == nil {
return nil, errors.New("30001") // merchant not found
}
return m, nil
}
// ListMerchantsForApp 业务侧列表,只返回该 appID 下的商户
func (s *MerchantService) ListMerchantsForApp(ctx context.Context, appID string, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
return s.merchantRepo.ListByAppID(ctx, appID, status, limit, offset)
}
// ApplyForApp 业务侧进件,校验 appID 归属后委托 Apply
func (s *MerchantService) ApplyForApp(ctx context.Context, appID, merchantID, channelCode string, bizContent map[string]any) (string, error) {
if _, err := s.GetMerchantForApp(ctx, appID, merchantID); err != nil {
return "", err
}
return s.Apply(ctx, merchantID, channelCode, bizContent)
}
// QueryAuditStatusForApp 业务侧查审核状态,校验 appID 归属
func (s *MerchantService) QueryAuditStatusForApp(ctx context.Context, appID, merchantID string) (*model.MerchantApplication, error) {
if _, err := s.GetMerchantForApp(ctx, appID, merchantID); err != nil {
return nil, err
}
return s.QueryAuditStatus(ctx, merchantID)
}
// CheckAnomalies 检查状态异常的商户(由 cron 调用)
func (s *MerchantService) CheckAnomalies(ctx context.Context) error {
merchants, err := s.merchantRepo.ListAnomalous(ctx)
if err != nil {
return err
}
slog.InfoContext(ctx, "anomalous merchants found", "count", len(merchants))
// 实际业务中可在此发送告警通知
return nil
}
// CreateMerchant 创建商户基础信息
func (s *MerchantService) CreateMerchant(ctx context.Context, m *model.Merchant) error {
return s.merchantRepo.Create(ctx, m)
}
// GetMerchant 查询商户信息
func (s *MerchantService) GetMerchant(ctx context.Context, merchantID string) (*model.Merchant, error) {
return s.merchantRepo.GetByMerchantID(ctx, merchantID)
}
// ListMerchants 查询商户列表
func (s *MerchantService) ListMerchants(ctx context.Context, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
return s.merchantRepo.List(ctx, status, limit, offset)
}
// FreezeMerchant 冻结商户
func (s *MerchantService) FreezeMerchant(ctx context.Context, merchantID string) error {
return s.merchantRepo.UpdateStatus(ctx, merchantID, model.MerchantStatusFrozen, nil)
}
// UnfreezeMerchant 解冻商户
func (s *MerchantService) UnfreezeMerchant(ctx context.Context, merchantID string) error {
return s.merchantRepo.UpdateStatus(ctx, merchantID, model.MerchantStatusActive, nil)
}

View File

@@ -0,0 +1,200 @@
package service
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"pay-bridge/internal/model"
)
// mockMerchantRepo 实现 merchantRepo interface
type mockMerchantRepo struct {
mock.Mock
}
func (m *mockMerchantRepo) Create(ctx context.Context, merchant *model.Merchant) error {
return m.Called(ctx, merchant).Error(0)
}
func (m *mockMerchantRepo) GetByMerchantID(ctx context.Context, merchantID string) (*model.Merchant, error) {
args := m.Called(ctx, merchantID)
return args.Get(0).(*model.Merchant), args.Error(1)
}
func (m *mockMerchantRepo) GetByMerchantIDAndAppID(ctx context.Context, merchantID, appID string) (*model.Merchant, error) {
args := m.Called(ctx, merchantID, appID)
v, _ := args.Get(0).(*model.Merchant)
return v, args.Error(1)
}
func (m *mockMerchantRepo) UpdateStatus(ctx context.Context, merchantID string, status model.MerchantStatus, updates map[string]any) error {
return m.Called(ctx, merchantID, status, updates).Error(0)
}
func (m *mockMerchantRepo) List(ctx context.Context, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
args := m.Called(ctx, status, limit, offset)
return args.Get(0).([]*model.Merchant), args.Error(1)
}
func (m *mockMerchantRepo) ListByAppID(ctx context.Context, appID string, status model.MerchantStatus, limit, offset int) ([]*model.Merchant, error) {
args := m.Called(ctx, appID, status, limit, offset)
return args.Get(0).([]*model.Merchant), args.Error(1)
}
func (m *mockMerchantRepo) ListAnomalous(ctx context.Context) ([]*model.Merchant, error) {
args := m.Called(ctx)
return args.Get(0).([]*model.Merchant), args.Error(1)
}
func (m *mockMerchantRepo) CreateApplication(ctx context.Context, app *model.MerchantApplication) error {
return m.Called(ctx, app).Error(0)
}
func (m *mockMerchantRepo) GetLatestApplication(ctx context.Context, merchantID string) (*model.MerchantApplication, error) {
args := m.Called(ctx, merchantID)
v, _ := args.Get(0).(*model.MerchantApplication)
return v, args.Error(1)
}
func (m *mockMerchantRepo) GetApprovedApplicationByChannel(ctx context.Context, merchantID, channelCode string) (*model.MerchantApplication, error) {
args := m.Called(ctx, merchantID, channelCode)
v, _ := args.Get(0).(*model.MerchantApplication)
return v, args.Error(1)
}
func (m *mockMerchantRepo) UpdateApplication(ctx context.Context, applicationID string, updates map[string]any) error {
return m.Called(ctx, applicationID, updates).Error(0)
}
// newTestMerchantService 创建注入了 mock repo 的 servicechannelSvc 为 nil仅测不涉及渠道的方法
func newTestMerchantService(repo merchantRepo) *MerchantService {
return &MerchantService{merchantRepo: repo}
}
var ctx = context.Background()
// --- GetMerchantForApp ---
func TestGetMerchantForApp_OK(t *testing.T) {
repo := new(mockMerchantRepo)
want := &model.Merchant{MerchantID: "m001", AppID: "app1"}
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return(want, nil)
svc := newTestMerchantService(repo)
got, err := svc.GetMerchantForApp(ctx, "app1", "m001")
assert.NoError(t, err)
assert.Equal(t, want, got)
repo.AssertExpectations(t)
}
func TestGetMerchantForApp_NotFound(t *testing.T) {
repo := new(mockMerchantRepo)
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return((*model.Merchant)(nil), nil)
svc := newTestMerchantService(repo)
_, err := svc.GetMerchantForApp(ctx, "app1", "m001")
assert.EqualError(t, err, "30001")
}
func TestGetMerchantForApp_WrongAppID(t *testing.T) {
repo := new(mockMerchantRepo)
// 商户存在但属于 other_appGetByMerchantIDAndAppID 返回 nil
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "evil_app").Return((*model.Merchant)(nil), nil)
svc := newTestMerchantService(repo)
_, err := svc.GetMerchantForApp(ctx, "evil_app", "m001")
assert.EqualError(t, err, "30001", "跨 appID 访问应返回 not found而不是泄露商户信息")
}
func TestGetMerchantForApp_DBError(t *testing.T) {
repo := new(mockMerchantRepo)
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return((*model.Merchant)(nil), errors.New("db error"))
svc := newTestMerchantService(repo)
_, err := svc.GetMerchantForApp(ctx, "app1", "m001")
assert.EqualError(t, err, "db error")
}
// --- CreateMerchantForApp ---
func TestCreateMerchantForApp_SetsAppID(t *testing.T) {
repo := new(mockMerchantRepo)
repo.On("Create", ctx, mock.MatchedBy(func(m *model.Merchant) bool {
return m.AppID == "app1" && m.MerchantID == "m001"
})).Return(nil)
svc := newTestMerchantService(repo)
m := &model.Merchant{MerchantID: "m001"}
err := svc.CreateMerchantForApp(ctx, "app1", m)
assert.NoError(t, err)
assert.Equal(t, "app1", m.AppID, "AppID 应被强制写入")
repo.AssertExpectations(t)
}
// --- ListMerchantsForApp ---
func TestListMerchantsForApp_OnlyReturnsOwnApp(t *testing.T) {
repo := new(mockMerchantRepo)
want := []*model.Merchant{{MerchantID: "m001", AppID: "app1"}}
repo.On("ListByAppID", ctx, "app1", model.MerchantStatus(""), 20, 0).Return(want, nil)
svc := newTestMerchantService(repo)
got, err := svc.ListMerchantsForApp(ctx, "app1", "", 20, 0)
assert.NoError(t, err)
assert.Len(t, got, 1)
repo.AssertExpectations(t)
}
// --- ApplyForApp ---
func TestApplyForApp_MerchantNotBelongToApp(t *testing.T) {
repo := new(mockMerchantRepo)
repo.On("GetByMerchantIDAndAppID", ctx, "m001", "app1").Return((*model.Merchant)(nil), nil)
svc := newTestMerchantService(repo)
_, err := svc.ApplyForApp(ctx, "app1", "m001", "HEEPAY", nil)
assert.EqualError(t, err, "30001", "不属于该 app 的商户不能提交进件")
}
// --- GetChannelMerchantID ---
func TestGetChannelMerchantID_Approved(t *testing.T) {
repo := new(mockMerchantRepo)
app := &model.MerchantApplication{
ChannelMerchantID: "ch_m_999",
}
repo.On("GetApprovedApplicationByChannel", ctx, "m001", "HEEPAY").Return(app, nil)
svc := newTestMerchantService(repo)
id, err := svc.GetChannelMerchantID(ctx, "m001", "HEEPAY")
assert.NoError(t, err)
assert.Equal(t, "ch_m_999", id)
}
func TestGetChannelMerchantID_NotApproved(t *testing.T) {
repo := new(mockMerchantRepo)
repo.On("GetApprovedApplicationByChannel", ctx, "m001", "ALIPAY").Return((*model.MerchantApplication)(nil), nil)
svc := newTestMerchantService(repo)
id, err := svc.GetChannelMerchantID(ctx, "m001", "ALIPAY")
assert.NoError(t, err)
assert.Empty(t, id, "未在该渠道进件时返回空字符串")
}
func TestGetChannelMerchantID_MultiChannel(t *testing.T) {
repo := new(mockMerchantRepo)
repo.On("GetApprovedApplicationByChannel", ctx, "m001", "HEEPAY").
Return(&model.MerchantApplication{ChannelMerchantID: "hee_001"}, nil)
repo.On("GetApprovedApplicationByChannel", ctx, "m001", "ALIPAY").
Return(&model.MerchantApplication{ChannelMerchantID: "ali_001"}, nil)
svc := newTestMerchantService(repo)
heeID, _ := svc.GetChannelMerchantID(ctx, "m001", "HEEPAY")
aliID, _ := svc.GetChannelMerchantID(ctx, "m001", "ALIPAY")
assert.Equal(t, "hee_001", heeID, "不同渠道应返回各自的 channel_merchant_id")
assert.Equal(t, "ali_001", aliID)
}

View File

@@ -0,0 +1,229 @@
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
"time"
"pay-bridge/internal/model"
"pay-bridge/internal/repository"
)
// 重试间隔9 次推送机会第1次立即后续8次重试
var retryIntervals = []time.Duration{
0,
15 * time.Second,
30 * time.Second,
1 * time.Minute,
5 * time.Minute,
30 * time.Minute,
1 * time.Hour,
6 * time.Hour,
12 * time.Hour,
}
const maxRetry = 8
// NotifyService 通知服务
type NotifyService struct {
notifyRepo *repository.NotifyLogRepository
tradeRepo *repository.TradeOrderRepository
httpClient *http.Client
}
func NewNotifyService(
notifyRepo *repository.NotifyLogRepository,
tradeRepo *repository.TradeOrderRepository,
httpTimeout time.Duration,
) *NotifyService {
return &NotifyService{
notifyRepo: notifyRepo,
tradeRepo: tradeRepo,
httpClient: &http.Client{Timeout: httpTimeout},
}
}
// SendNotify 向下游发送通知(首次调用)
func (s *NotifyService) SendNotify(ctx context.Context, tradeNo string, notifyType model.NotifyType, notifyURL string) error {
// 构建通知内容
payload, err := s.buildPayload(ctx, tradeNo, notifyType)
if err != nil {
return err
}
// 创建通知记录
now := time.Now()
log := &model.NotifyLog{
TradeNo: tradeNo,
NotifyType: notifyType,
NotifyURL: notifyURL,
Status: model.NotifyStatusPending,
RetryCount: 0,
}
if err := s.notifyRepo.Upsert(ctx, log); err != nil {
slog.ErrorContext(ctx, "upsert notify log failed", "trade_no", tradeNo, "err", err)
}
// 发送通知
resp, err := s.sendHTTP(ctx, notifyURL, payload)
if err == nil && isSuccessResponse(resp) {
s.notifyRepo.MarkSuccess(ctx, log.ID, resp)
slog.InfoContext(ctx, "notify success", "trade_no", tradeNo, "type", notifyType)
return nil
}
// 首次失败,写入重试队列
errMsg := ""
if err != nil {
errMsg = err.Error()
} else {
errMsg = resp
}
nextTime := now.Add(retryIntervals[1])
s.notifyRepo.IncrRetryCount(ctx, log.ID, model.NotifyStatusRetry, &nextTime, errMsg)
slog.WarnContext(ctx, "notify failed, scheduled retry", "trade_no", tradeNo, "next_retry", nextTime)
return nil
}
// ProcessRetryQueue 处理重试队列(由 Poller 调用)
func (s *NotifyService) ProcessRetryQueue(ctx context.Context, batchSize int) error {
logs, err := s.notifyRepo.ListPendingRetry(ctx, time.Now(), batchSize)
if err != nil {
return err
}
for _, log := range logs {
s.processOne(ctx, log)
}
return nil
}
func (s *NotifyService) processOne(ctx context.Context, log *model.NotifyLog) {
payload, err := s.buildPayload(ctx, log.TradeNo, log.NotifyType)
if err != nil {
slog.ErrorContext(ctx, "build payload failed", "trade_no", log.TradeNo, "err", err)
return
}
resp, err := s.sendHTTP(ctx, log.NotifyURL, payload)
if err == nil && isSuccessResponse(resp) {
s.notifyRepo.MarkSuccess(ctx, log.ID, resp)
slog.InfoContext(ctx, "notify retry success", "trade_no", log.TradeNo, "retry_count", log.RetryCount)
return
}
errMsg := ""
if err != nil {
errMsg = err.Error()
} else {
errMsg = resp
}
nextRetryIdx := log.RetryCount + 1
if nextRetryIdx > maxRetry {
s.notifyRepo.MarkGiveup(ctx, log.ID)
slog.WarnContext(ctx, "notify giveup after max retries", "trade_no", log.TradeNo)
return
}
var nextTime *time.Time
if nextRetryIdx < len(retryIntervals) {
t := time.Now().Add(retryIntervals[nextRetryIdx])
nextTime = &t
}
status := model.NotifyStatusRetry
if nextRetryIdx >= maxRetry {
status = model.NotifyStatusGiveup
}
s.notifyRepo.IncrRetryCount(ctx, log.ID, status, nextTime, errMsg)
}
// buildPayload 构建通知内容
func (s *NotifyService) buildPayload(ctx context.Context, tradeNo string, notifyType model.NotifyType) ([]byte, error) {
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
if err != nil || order == nil {
return nil, fmt.Errorf("order not found: %s", tradeNo)
}
payload := map[string]any{
"trade_no": order.TradeNo,
"merchant_order_no": order.MerchantOrderNo,
"app_id": order.AppID,
"pay_method": order.PayMethod,
"amount": order.Amount,
"status": order.Status,
"notify_type": notifyType,
"timestamp": time.Now().Unix(),
}
if order.ChannelTradeNo != "" {
payload["channel_trade_no"] = order.ChannelTradeNo
}
if order.PayTime != nil {
payload["pay_time"] = order.PayTime.Unix()
}
return json.Marshal(payload)
}
// sendHTTP 向下游发送 HTTP POST 通知
func (s *NotifyService) sendHTTP(ctx context.Context, notifyURL string, payload []byte) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, notifyURL, bytes.NewReader(payload))
if err != nil {
return "", err
}
req.Header.Set("Content-Type", "application/json")
resp, err := s.httpClient.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
body, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
return string(body), nil
}
// isSuccessResponse 判断下游是否返回成功
// 下游返回 HTTP 200 且 body 包含 "success" 则视为成功
func isSuccessResponse(body string) bool {
return strings.Contains(strings.ToLower(body), "success")
}
// NextRetryTime 计算下次重试时间
func NextRetryTime(retryCount int) (time.Time, bool) {
idx := retryCount + 1
if idx >= len(retryIntervals) {
return time.Time{}, false
}
return time.Now().Add(retryIntervals[idx]), true
}
// StartPoller 启动通知重试 Poller goroutine
func (s *NotifyService) StartPoller(ctx context.Context, interval time.Duration, batchSize int) {
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.ProcessRetryQueue(ctx, batchSize); err != nil {
slog.Error("notify poller error", "err", err)
}
}
}
}()
slog.Info("notify poller started", "interval", interval)
}

View File

@@ -0,0 +1,280 @@
package service
import (
"context"
"log/slog"
"regexp"
"strings"
"time"
"pay-bridge/internal/model"
"pay-bridge/internal/repository"
)
// orderNoPatterns 从备注中提取订单号的正则列表(优先级从高到低)
var orderNoPatterns = []*regexp.Regexp{
regexp.MustCompile(`PAY\d{14}`), // pay-bridge 交易号格式 PAYyyMMddNNNNNNNN
regexp.MustCompile(`REF\d{14}`), // 退款单号
regexp.MustCompile(`[A-Z0-9]{16,32}`), // 通用订单号格式
}
// IncomingPayment 入账通知数据
type IncomingPayment struct {
AccountNo string // 收款账号
Amount int64 // 入账金额(分)
Remark string // 转账备注
PayerName string // 付款方名称
ChannelBillNo string // 渠道流水号
}
// matchWindow 匹配时间窗口7天内的待支付订单
const matchWindow = 7 * 24 * time.Hour
// PaymentMatchService 收款匹配服务
type PaymentMatchService struct {
matchRepo *repository.PaymentMatchRepository
tradeRepo *repository.TradeOrderRepository
notifySvc *NotifyService
tradeSvc *TradeService
}
func NewPaymentMatchService(
matchRepo *repository.PaymentMatchRepository,
tradeRepo *repository.TradeOrderRepository,
notifySvc *NotifyService,
tradeSvc *TradeService,
) *PaymentMatchService {
return &PaymentMatchService{
matchRepo: matchRepo,
tradeRepo: tradeRepo,
notifySvc: notifySvc,
tradeSvc: tradeSvc,
}
}
// HandleIncomingPayment 处理入账通知(核心匹配流程)
func (s *PaymentMatchService) HandleIncomingPayment(ctx context.Context, incoming *IncomingPayment) error {
// 幂等检查
if existing, _ := s.matchRepo.GetMatchLogByBillNo(ctx, incoming.ChannelBillNo); existing != nil {
return nil
}
// 查询收款账户
account, err := s.matchRepo.GetAccountByNo(ctx, incoming.AccountNo)
if err != nil {
return err
}
if account == nil {
slog.WarnContext(ctx, "incoming payment: account not found", "account_no", incoming.AccountNo)
return nil
}
// 执行匹配
result := s.match(ctx, incoming, account)
// 记录匹配结果
now := time.Now()
log := &model.PaymentMatchLog{
AccountID: account.ID,
IncomingAmount: incoming.Amount,
IncomingRemark: incoming.Remark,
PayerName: incoming.PayerName,
ChannelBillNo: incoming.ChannelBillNo,
MatchStatus: result.status,
NameDiff: result.nameDiff,
}
if result.tradeNo != "" {
log.TradeNo = result.tradeNo
log.MatchTime = &now
}
if err := s.matchRepo.CreateMatchLog(ctx, log); err != nil {
return err
}
// 匹配成功:更新订单状态并通知下游
if result.tradeNo != "" {
updates := map[string]any{
"status": model.TradeStatusPaid,
"pay_time": now,
}
ok, err := s.tradeRepo.UpdateStatus(ctx, result.tradeNo, model.TradeStatusPaying, model.TradeStatusPaid, updates)
if err != nil {
return err
}
if ok {
order, _ := s.tradeRepo.GetByTradeNo(ctx, result.tradeNo)
if order != nil && s.notifySvc != nil {
go func() {
bgCtx := context.Background()
s.notifySvc.SendNotify(bgCtx, result.tradeNo, model.NotifyTypePayment, order.NotifyURL)
}()
}
}
slog.InfoContext(ctx, "payment matched",
"trade_no", result.tradeNo,
"amount", incoming.Amount,
"status", result.status,
"name_diff", result.nameDiff,
)
} else {
slog.InfoContext(ctx, "payment pending manual",
"channel_bill_no", incoming.ChannelBillNo,
"amount", incoming.Amount,
)
}
return nil
}
// ManualBindOrder 人工关联入账与订单
func (s *PaymentMatchService) ManualBindOrder(ctx context.Context, matchID uint64, tradeNo, operator string) error {
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
if err != nil || order == nil {
return err
}
now := time.Now()
updates := map[string]any{
"trade_no": tradeNo,
"match_status": model.MatchStatusMatched,
"match_time": now,
"operator": operator,
}
if err := s.matchRepo.UpdateMatchLog(ctx, matchID, updates); err != nil {
return err
}
// 更新订单状态
s.tradeRepo.UpdateStatus(ctx, tradeNo, model.TradeStatusPaying, model.TradeStatusPaid,
map[string]any{"pay_time": now})
if s.notifySvc != nil {
go func() {
bgCtx := context.Background()
s.notifySvc.SendNotify(bgCtx, tradeNo, model.NotifyTypePayment, order.NotifyURL)
}()
}
return nil
}
// ListPendingManual 查询待人工确认的收款记录
func (s *PaymentMatchService) ListPendingManual(ctx context.Context, appID string, limit, offset int) ([]*model.PaymentMatchLog, error) {
return s.matchRepo.ListPendingManual(ctx, appID, limit, offset)
}
// --- 内部匹配逻辑 ---
type matchResult struct {
tradeNo string
status model.MatchStatus
nameDiff int8
}
func (s *PaymentMatchService) match(ctx context.Context, incoming *IncomingPayment, account *model.SubMerchantAccount) matchResult {
// Step 1: 从备注中提取订单号
candidates := extractOrderNos(incoming.Remark)
var matched *model.TradeOrder
for _, orderNo := range candidates {
// 先按 trade_no 查,再按 merchant_order_no 查
order, _ := s.tradeRepo.GetByTradeNo(ctx, orderNo)
if order == nil {
order, _ = s.tradeRepo.GetByMerchantOrderNo(ctx, account.AppID, orderNo)
}
if order == nil || order.AppID != account.AppID {
continue
}
if order.Status != model.TradeStatusPaying {
continue
}
// Step 2: 金额精确匹配
if order.Amount != incoming.Amount {
continue
}
matched = order
break
}
// 备注匹配失败,降级为金额匹配
if matched == nil {
orders, _ := s.matchRepo.ListPayingByAmount(ctx, account.AppID, incoming.Amount, matchWindow)
if len(orders) == 1 {
matched = orders[0]
} else if len(orders) > 1 {
// Step 3: 用付款方名称缩小范围
matched = filterByPayerName(orders, incoming.PayerName)
if matched == nil {
return matchResult{status: model.MatchStatusPendingManual}
}
} else {
return matchResult{status: model.MatchStatusPendingManual}
}
}
// Step 3: 付款方名称一致性检查
var nameDiff int8 = 0
invoiceName := getInvoiceName(matched)
if invoiceName != "" && incoming.PayerName != "" {
if !strings.EqualFold(strings.TrimSpace(invoiceName), strings.TrimSpace(incoming.PayerName)) {
nameDiff = 1
}
}
status := model.MatchStatusMatched
if nameDiff == 1 {
status = model.MatchStatusNameDiff
}
return matchResult{
tradeNo: matched.TradeNo,
status: status,
nameDiff: nameDiff,
}
}
// extractOrderNos 从备注字符串中提取可能的订单号
func extractOrderNos(remark string) []string {
if remark == "" {
return nil
}
var results []string
seen := map[string]bool{}
for _, re := range orderNoPatterns {
matches := re.FindAllString(remark, -1)
for _, m := range matches {
if !seen[m] {
seen[m] = true
results = append(results, m)
}
}
}
return results
}
// filterByPayerName 从多个候选订单中,选择 invoice_name 与付款方名称匹配的订单
// invoice_name 暂存在 extra 字段中
func filterByPayerName(orders []*model.TradeOrder, payerName string) *model.TradeOrder {
if payerName == "" {
return nil
}
for _, o := range orders {
name := getInvoiceName(o)
if name != "" && strings.EqualFold(strings.TrimSpace(name), strings.TrimSpace(payerName)) {
return o
}
}
return nil
}
// getInvoiceName 从 extra 字段获取开票名称
func getInvoiceName(order *model.TradeOrder) string {
if order.Extra == nil {
return ""
}
if v, ok := order.Extra["invoice_name"]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}

View File

@@ -0,0 +1,268 @@
package service
import (
"context"
"errors"
"fmt"
"log/slog"
"time"
"github.com/go-redis/redis/v8"
"pay-bridge/internal/channel"
"pay-bridge/internal/errcode"
"pay-bridge/internal/model"
"pay-bridge/internal/repository"
"pay-bridge/pkg/sequence"
)
const (
sharingLockPrefix = "lock:sharing:"
sharingLockTTL = 30 * time.Second
)
// ProfitSharingService 分润服务
type ProfitSharingService struct {
sharingRepo *repository.ProfitSharingRepository
tradeRepo *repository.TradeOrderRepository
channelSvc *ChannelService
seqSvc *sequence.Service
rdb *redis.Client
}
func NewProfitSharingService(
sharingRepo *repository.ProfitSharingRepository,
tradeRepo *repository.TradeOrderRepository,
channelSvc *ChannelService,
seqSvc *sequence.Service,
rdb *redis.Client,
) *ProfitSharingService {
return &ProfitSharingService{
sharingRepo: sharingRepo,
tradeRepo: tradeRepo,
channelSvc: channelSvc,
seqSvc: seqSvc,
rdb: rdb,
}
}
// TriggerSharing 支付成功后触发分润(幂等)
func (s *ProfitSharingService) TriggerSharing(ctx context.Context, tradeNo string) error {
// 分布式锁防止并发重复触发
lockKey := sharingLockPrefix + tradeNo
ok, err := s.rdb.SetNX(ctx, lockKey, "1", sharingLockTTL).Result()
if err != nil {
return fmt.Errorf("acquire sharing lock: %w", err)
}
if !ok {
return nil // 已有进程在处理
}
defer s.rdb.Del(ctx, lockKey)
// 查询交易
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
if err != nil || order == nil {
return errors.New(errcode.ErrOrderNotFound)
}
if order.ProfitSharingAmount <= 0 {
return nil // 无需分润
}
// 幂等检查:是否已有分润记录
existing, err := s.sharingRepo.GetOrderByTradeNo(ctx, tradeNo)
if err != nil {
return err
}
if existing != nil {
return nil // 已触发过
}
// 获取应用分润配置
cfg, err := s.sharingRepo.GetConfigByAppID(ctx, order.AppID)
if err != nil {
return err
}
if cfg == nil {
return errors.New(errcode.ErrSharingNotConfig)
}
// 校验分润比例
maxAmount := int64(float64(order.Amount) * cfg.MaxSharingRatio)
if order.ProfitSharingAmount > maxAmount {
return errors.New(errcode.ErrSharingAmountExceed)
}
// 生成分润单号
sharingNo, err := s.seqSvc.NextSharingNo(ctx, order.AppID)
if err != nil {
return err
}
// 创建分润记录
sharingOrder := &model.ProfitSharingOrder{
SharingNo: sharingNo,
TradeNo: tradeNo,
AppID: order.AppID,
ReceiverMerchantID: cfg.ReceiverMerchantID,
SharingAmount: order.ProfitSharingAmount,
Status: model.ProfitSharingStatusPending,
}
if err := s.sharingRepo.CreateOrder(ctx, sharingOrder); err != nil {
return err
}
// 调用渠道分账
ch, err := s.channelSvc.GetChannel(ctx, order.AppID, order.ChannelCode)
if err != nil {
return err
}
resp, err := ch.ProfitSharing(ctx, &channel.ProfitSharingReq{
TradeNo: tradeNo,
ChannelTradeNo: order.ChannelTradeNo,
SharingNo: sharingNo,
ReceiverMerchantID: cfg.ReceiverMerchantID,
Amount: order.ProfitSharingAmount,
})
if err != nil {
s.sharingRepo.UpdateOrderStatus(ctx, sharingNo,
model.ProfitSharingStatusPending,
model.ProfitSharingStatusFailed,
map[string]any{"fail_reason": err.Error()},
)
s.sharingRepo.CreateLog(ctx, &model.ProfitSharingLog{
SharingNo: sharingNo,
Action: "SPLIT",
Amount: order.ProfitSharingAmount,
Status: "FAILED",
})
return fmt.Errorf("profit sharing failed: %w", err)
}
now := time.Now()
s.sharingRepo.UpdateOrderStatus(ctx, sharingNo,
model.ProfitSharingStatusPending,
model.ProfitSharingStatusProcessing,
map[string]any{
"channel_sharing_no": resp.ChannelSharingNo,
"sharing_time": now,
},
)
s.sharingRepo.CreateLog(ctx, &model.ProfitSharingLog{
SharingNo: sharingNo,
Action: "SPLIT",
Amount: order.ProfitSharingAmount,
Status: "PROCESSING",
})
slog.InfoContext(ctx, "profit sharing triggered",
"trade_no", tradeNo,
"sharing_no", sharingNo,
"amount", order.ProfitSharingAmount,
)
return nil
}
// HandleSharingNotify 处理分账回调(上游分账完成通知)
func (s *ProfitSharingService) HandleSharingNotify(ctx context.Context, sharingNo, channelSharingNo string, status model.ProfitSharingStatus) error {
now := time.Now()
updates := map[string]any{
"channel_sharing_no": channelSharingNo,
"sharing_time": now,
}
ok, err := s.sharingRepo.UpdateOrderStatus(ctx, sharingNo,
model.ProfitSharingStatusProcessing, status, updates)
if err != nil {
return err
}
if !ok {
return nil // 幂等
}
logStatus := string(status)
s.sharingRepo.CreateLog(ctx, &model.ProfitSharingLog{
SharingNo: sharingNo,
Action: "SPLIT",
Amount: 0,
Status: logStatus,
})
return nil
}
// RollbackSharing 退款前回退分润
func (s *ProfitSharingService) RollbackSharing(ctx context.Context, tradeNo string) error {
sharingOrder, err := s.sharingRepo.GetOrderByTradeNo(ctx, tradeNo)
if err != nil {
return err
}
if sharingOrder == nil {
return nil // 无分润,直接跳过
}
if sharingOrder.Status == model.ProfitSharingStatusRollback {
return nil // 已回退,幂等
}
if sharingOrder.Status != model.ProfitSharingStatusSuccess {
return fmt.Errorf("sharing not success, cannot rollback, status=%s", sharingOrder.Status)
}
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
if err != nil || order == nil {
return errors.New(errcode.ErrOrderNotFound)
}
ch, err := s.channelSvc.GetChannel(ctx, order.AppID, order.ChannelCode)
if err != nil {
return err
}
if err := ch.RollbackProfitSharing(ctx, &channel.RollbackSharingReq{
SharingNo: sharingOrder.SharingNo,
ChannelSharingNo: sharingOrder.ChannelSharingNo,
TradeNo: tradeNo,
}); err != nil {
return fmt.Errorf("rollback sharing failed: %w", err)
}
s.sharingRepo.UpdateOrderStatus(ctx, sharingOrder.SharingNo,
model.ProfitSharingStatusSuccess,
model.ProfitSharingStatusRollback,
nil,
)
s.sharingRepo.CreateLog(ctx, &model.ProfitSharingLog{
SharingNo: sharingOrder.SharingNo,
Action: "ROLLBACK",
Amount: sharingOrder.SharingAmount,
Status: "SUCCESS",
})
return nil
}
// QuerySharing 查询分润状态
func (s *ProfitSharingService) QuerySharing(ctx context.Context, sharingNo string) (*model.ProfitSharingOrder, error) {
order, err := s.sharingRepo.GetOrderBySharingNo(ctx, sharingNo)
if err != nil {
return nil, err
}
if order == nil {
return nil, errors.New(errcode.ErrOrderNotFound)
}
return order, nil
}
// ValidateSharingAmount 下单时校验分润金额是否合法
func (s *ProfitSharingService) ValidateSharingAmount(ctx context.Context, appID string, orderAmount, sharingAmount int64) error {
if sharingAmount <= 0 {
return nil
}
cfg, err := s.sharingRepo.GetConfigByAppID(ctx, appID)
if err != nil {
return err
}
if cfg == nil {
return errors.New(errcode.ErrSharingNotConfig)
}
maxAmount := int64(float64(orderAmount) * cfg.MaxSharingRatio)
if sharingAmount > maxAmount {
return errors.New(errcode.ErrSharingAmountExceed)
}
return nil
}

View File

@@ -0,0 +1,221 @@
package service
import (
"context"
"fmt"
"log/slog"
"time"
"pay-bridge/internal/channel"
"pay-bridge/internal/model"
"pay-bridge/internal/repository"
)
// ReconciliationService T+1 自动对账服务
type ReconciliationService struct {
reconRepo *repository.ReconciliationRepository
tradeRepo *repository.TradeOrderRepository
channelSvc *ChannelService
appRepo *repository.AppRepository
}
func NewReconciliationService(
reconRepo *repository.ReconciliationRepository,
tradeRepo *repository.TradeOrderRepository,
channelSvc *ChannelService,
appRepo *repository.AppRepository,
) *ReconciliationService {
return &ReconciliationService{
reconRepo: reconRepo,
tradeRepo: tradeRepo,
channelSvc: channelSvc,
appRepo: appRepo,
}
}
// RunDailyReconciliation 执行 T+1 对账cron 每日触发)
func (s *ReconciliationService) RunDailyReconciliation(ctx context.Context) error {
// 对账日期:昨天
billDate := time.Now().AddDate(0, 0, -1).Format("2006-01-02")
slog.InfoContext(ctx, "reconciliation started", "bill_date", billDate)
apps, err := s.appRepo.ListActive(ctx)
if err != nil {
return err
}
for _, app := range apps {
if err := s.reconcileApp(ctx, app.AppID, billDate); err != nil {
slog.ErrorContext(ctx, "reconciliation failed for app",
"app_id", app.AppID,
"bill_date", billDate,
"error", err,
)
}
}
return nil
}
// reconcileApp 对指定应用执行对账
func (s *ReconciliationService) reconcileApp(ctx context.Context, appID, billDate string) error {
// 获取所有活跃渠道配置
channelCodes, err := s.channelSvc.ListChannelCodes(ctx, appID)
if err != nil {
return err
}
for _, code := range channelCodes {
if err := s.reconcileChannel(ctx, appID, code, billDate); err != nil {
slog.ErrorContext(ctx, "channel reconciliation failed",
"app_id", appID,
"channel", code,
"bill_date", billDate,
"error", err,
)
}
}
return nil
}
// reconcileChannel 对单个渠道执行对账
func (s *ReconciliationService) reconcileChannel(ctx context.Context, appID, channelCode, billDate string) error {
// 幂等检查
existing, err := s.reconRepo.GetReport(ctx, appID, billDate, channelCode)
if err != nil {
return err
}
if existing != nil && existing.Status == model.ReconciliationStatusMatched {
return nil // 已对账完成
}
// 创建对账报告
report := &model.ReconciliationReport{
AppID: appID,
ChannelCode: channelCode,
BillDate: billDate,
Status: model.ReconciliationStatusPending,
}
if existing == nil {
if err := s.reconRepo.CreateReport(ctx, report); err != nil {
return err
}
} else {
report = existing
}
// 下载渠道对账单
ch, err := s.channelSvc.GetChannel(ctx, appID, channelCode)
if err != nil {
return err
}
billData, err := ch.DownloadBill(ctx, &channel.DownloadBillReq{BillDate: billDate})
if err != nil {
return fmt.Errorf("download bill: %w", err)
}
// 查询本地已支付订单
localOrders, err := s.reconRepo.ListPaidOrdersByDate(ctx, appID, billDate)
if err != nil {
return err
}
// 建立本地订单索引
localIndex := make(map[string]*model.TradeOrder, len(localOrders))
for _, o := range localOrders {
localIndex[o.TradeNo] = o
}
// 建立渠道账单索引
channelIndex := make(map[string]*channel.BillRecord, len(billData.Records))
for i := range billData.Records {
channelIndex[billData.Records[i].TradeNo] = &billData.Records[i]
}
matched := 0
exceptions := 0
// 检查渠道账单中有,本地没有的(漏单)
for _, rec := range billData.Records {
local, ok := localIndex[rec.TradeNo]
if !ok {
// 本地缺失
ex := &model.ReconciliationException{
ReportID: report.ID,
TradeNo: rec.TradeNo,
ChannelBillNo: rec.ChannelBillNo,
ExceptionType: "MISSING_LOCAL",
ChannelAmount: rec.Amount,
Remark: "渠道有记录,本地无此订单",
}
s.reconRepo.CreateException(ctx, ex)
exceptions++
continue
}
// 金额比对
if local.Amount != rec.Amount {
ex := &model.ReconciliationException{
ReportID: report.ID,
TradeNo: rec.TradeNo,
ChannelBillNo: rec.ChannelBillNo,
ExceptionType: "AMOUNT_MISMATCH",
LocalAmount: local.Amount,
ChannelAmount: rec.Amount,
Remark: fmt.Sprintf("金额不符:本地%d 渠道%d", local.Amount, rec.Amount),
}
s.reconRepo.CreateException(ctx, ex)
exceptions++
} else {
matched++
}
}
// 检查本地有,渠道账单中没有的(多单)
for tradeNo, local := range localIndex {
if _, ok := channelIndex[tradeNo]; !ok {
ex := &model.ReconciliationException{
ReportID: report.ID,
TradeNo: tradeNo,
ExceptionType: "MISSING_CHANNEL",
LocalAmount: local.Amount,
Remark: "本地已支付,渠道账单无记录",
}
s.reconRepo.CreateException(ctx, ex)
exceptions++
}
}
// 更新对账报告
status := model.ReconciliationStatusMatched
if exceptions > 0 {
status = model.ReconciliationStatusException
}
updates := map[string]any{
"total_count": len(billData.Records),
"total_amount": billData.TotalAmount,
"matched_count": matched,
"exception_count": exceptions,
"status": status,
}
if err := s.reconRepo.UpdateReport(ctx, report.ID, updates); err != nil {
return err
}
slog.InfoContext(ctx, "reconciliation done",
"app_id", appID,
"channel", channelCode,
"bill_date", billDate,
"matched", matched,
"exceptions", exceptions,
)
return nil
}
// GetReport 查询对账报告
func (s *ReconciliationService) GetReport(ctx context.Context, appID, billDate, channelCode string) (*model.ReconciliationReport, error) {
return s.reconRepo.GetReport(ctx, appID, billDate, channelCode)
}
// GetExceptions 查询对账异常明细
func (s *ReconciliationService) GetExceptions(ctx context.Context, reportID uint64) ([]*model.ReconciliationException, error) {
return s.reconRepo.ListExceptions(ctx, reportID)
}

View File

@@ -0,0 +1,213 @@
package service
import (
"context"
"errors"
"log/slog"
"time"
"pay-bridge/internal/channel"
"pay-bridge/internal/errcode"
"pay-bridge/internal/model"
"pay-bridge/internal/repository"
"pay-bridge/pkg/sequence"
)
// CreateRefundReq 退款请求
type CreateRefundReq struct {
AppID string
TradeNo string
RefundAmount int64
Reason string
NotifyURL string
}
// RefundService 退款服务
type RefundService struct {
refundRepo *repository.RefundOrderRepository
tradeRepo *repository.TradeOrderRepository
channelSvc *ChannelService
seqSvc *sequence.Service
notifySvc *NotifyService
}
func NewRefundService(
refundRepo *repository.RefundOrderRepository,
tradeRepo *repository.TradeOrderRepository,
channelSvc *ChannelService,
seqSvc *sequence.Service,
notifySvc *NotifyService,
) *RefundService {
return &RefundService{
refundRepo: refundRepo,
tradeRepo: tradeRepo,
channelSvc: channelSvc,
seqSvc: seqSvc,
notifySvc: notifySvc,
}
}
// CreateRefund 发起退款
func (s *RefundService) CreateRefund(ctx context.Context, req *CreateRefundReq) (*model.RefundOrder, error) {
// 查询原交易
order, err := s.tradeRepo.GetByTradeNo(ctx, req.TradeNo)
if err != nil {
return nil, err
}
if order == nil || order.AppID != req.AppID {
return nil, errors.New(errcode.ErrOrderNotFound)
}
if order.Status != model.TradeStatusPaid && order.Status != model.TradeStatusRefunded {
return nil, errors.New(errcode.ErrOrderNotPaid)
}
// 校验可退金额
refunded, err := s.refundRepo.SumRefundedAmount(ctx, req.TradeNo)
if err != nil {
return nil, err
}
if refunded+req.RefundAmount > order.Amount {
return nil, errors.New(errcode.ErrRefundAmountExceed)
}
// 生成退款单号
refundNo, err := s.seqSvc.NextRefundNo(ctx, req.AppID)
if err != nil {
return nil, err
}
// 创建退款记录
refund := &model.RefundOrder{
RefundNo: refundNo,
TradeNo: req.TradeNo,
AppID: req.AppID,
ChannelCode: order.ChannelCode,
RefundAmount: req.RefundAmount,
Reason: req.Reason,
Status: model.RefundStatusPending,
NotifyURL: req.NotifyURL,
}
if err := s.refundRepo.Create(ctx, refund); err != nil {
return nil, err
}
// 调用渠道退款
ch, err := s.channelSvc.GetChannel(ctx, req.AppID, order.ChannelCode)
if err != nil {
return nil, err
}
channelResp, err := ch.Refund(ctx, &channel.RefundReq{
TradeNo: req.TradeNo,
ChannelTradeNo: order.ChannelTradeNo,
RefundNo: refundNo,
RefundAmount: req.RefundAmount,
TotalAmount: order.Amount,
Reason: req.Reason,
NotifyURL: req.NotifyURL,
})
if err != nil {
s.refundRepo.UpdateStatus(ctx, refundNo, model.RefundStatusPending, model.RefundStatusFailed, nil)
return nil, errors.New(errcode.ErrChannelRefundFail)
}
// 更新渠道退款单号
updates := map[string]any{
"channel_refund_no": channelResp.ChannelRefundNo,
"status": model.RefundStatusProcessing,
}
s.refundRepo.UpdateStatus(ctx, refundNo, model.RefundStatusPending, model.RefundStatusProcessing, updates)
refund.ChannelRefundNo = channelResp.ChannelRefundNo
refund.Status = model.RefundStatusProcessing
return refund, nil
}
// QueryRefund 查询退款状态
func (s *RefundService) QueryRefund(ctx context.Context, appID, refundNo string) (*model.RefundOrder, error) {
refund, err := s.refundRepo.GetByRefundNo(ctx, refundNo)
if err != nil {
return nil, err
}
if refund == nil || refund.AppID != appID {
return nil, errors.New(errcode.ErrRefundNotFound)
}
// 如果处于处理中,主动查询渠道
if refund.Status == model.RefundStatusProcessing {
s.syncRefundStatus(ctx, refund)
// 重新查询最新状态
refund, _ = s.refundRepo.GetByRefundNo(ctx, refundNo)
}
return refund, nil
}
// HandleRefundNotify 处理退款回调
func (s *RefundService) HandleRefundNotify(ctx context.Context, refundNo string, channelRefundNo string, status model.RefundStatus) error {
refund, err := s.refundRepo.GetByRefundNo(ctx, refundNo)
if err != nil || refund == nil {
return errors.New(errcode.ErrRefundNotFound)
}
if refund.Status == model.RefundStatusSuccess {
return nil // 幂等
}
updates := map[string]any{
"channel_refund_no": channelRefundNo,
}
if status == model.RefundStatusSuccess {
now := time.Now()
updates["refund_time"] = now
}
ok, err := s.refundRepo.UpdateStatus(ctx, refundNo, model.RefundStatusProcessing, status, updates)
if err != nil {
return err
}
if !ok {
return nil // 幂等
}
// 退款成功后通知下游
if status == model.RefundStatusSuccess && refund.NotifyURL != "" && s.notifySvc != nil {
go func() {
bgCtx := context.Background()
if err := s.notifySvc.SendNotify(bgCtx, refund.TradeNo, model.NotifyTypeRefund, refund.NotifyURL); err != nil {
slog.Error("send refund notify failed", "refund_no", refundNo, "err", err)
}
}()
}
return nil
}
func (s *RefundService) syncRefundStatus(ctx context.Context, refund *model.RefundOrder) {
order, err := s.tradeRepo.GetByTradeNo(ctx, refund.TradeNo)
if err != nil || order == nil {
return
}
ch, err := s.channelSvc.GetChannel(ctx, refund.AppID, refund.ChannelCode)
if err != nil {
return
}
resp, err := ch.QueryRefund(ctx, &channel.QueryRefundReq{
RefundNo: refund.RefundNo,
ChannelRefundNo: refund.ChannelRefundNo,
})
if err != nil {
return
}
if resp.Status != refund.Status {
updates := map[string]any{
"channel_refund_no": resp.ChannelRefundNo,
}
if resp.RefundTime != nil {
updates["refund_time"] = resp.RefundTime
}
s.refundRepo.UpdateStatus(ctx, refund.RefundNo, refund.Status, resp.Status, updates)
}
}

View File

@@ -0,0 +1,189 @@
package service
import (
"context"
"fmt"
"log/slog"
"math"
"pay-bridge/internal/channel"
"pay-bridge/internal/model"
"pay-bridge/internal/repository"
)
// ServiceFeeService 服务费服务
type ServiceFeeService struct {
feeRepo *repository.ServiceFeeRepository
tradeRepo *repository.TradeOrderRepository
channelSvc *ChannelService
}
func NewServiceFeeService(
feeRepo *repository.ServiceFeeRepository,
tradeRepo *repository.TradeOrderRepository,
channelSvc *ChannelService,
) *ServiceFeeService {
return &ServiceFeeService{
feeRepo: feeRepo,
tradeRepo: tradeRepo,
channelSvc: channelSvc,
}
}
// ChargeServiceFee 交易完成后扣收服务费
func (s *ServiceFeeService) ChargeServiceFee(ctx context.Context, tradeNo string) error {
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
if err != nil || order == nil {
return fmt.Errorf("order not found: %s", tradeNo)
}
// 幂等检查
existing, err := s.feeRepo.GetLog(ctx, tradeNo, "CHARGE")
if err != nil {
return err
}
if existing != nil {
return nil // 已扣收
}
// 获取服务费配置
group := model.PayMethodToGroup(order.PayMethod)
cfg, err := s.feeRepo.GetConfig(ctx, order.AppID, group)
if err != nil {
return err
}
if cfg == nil || cfg.FeeRate == 0 {
return nil // 未配置或费率为0
}
// 计算服务费(四舍五入到分)
feeAmount := calculateFee(order.Amount, cfg.FeeRate)
if feeAmount <= 0 {
return nil // 不足1分不扣收
}
// 更新订单服务费金额快照
s.tradeRepo.UpdateStatus(ctx, tradeNo, model.TradeStatusPaid, model.TradeStatusPaid,
map[string]any{"service_fee_amount": feeAmount})
// 创建服务费流水
log := &model.ServiceFeeLog{
TradeNo: tradeNo,
ConfigID: cfg.ID,
FeeAmount: feeAmount,
FeeRate: cfg.FeeRate,
ReceiverMerchantID: cfg.FeeReceiverMerchantID,
Action: "CHARGE",
Status: "PENDING",
}
if err := s.feeRepo.CreateLog(ctx, log); err != nil {
return err
}
// 调用渠道分账
ch, err := s.channelSvc.GetChannel(ctx, order.AppID, order.ChannelCode)
if err != nil {
s.feeRepo.UpdateLogStatus(ctx, log.ID, "FAILED", "")
return err
}
resp, err := ch.ProfitSharing(ctx, &channel.ProfitSharingReq{
TradeNo: tradeNo,
ChannelTradeNo: order.ChannelTradeNo,
SharingNo: fmt.Sprintf("FEE%s", tradeNo),
ReceiverMerchantID: cfg.FeeReceiverMerchantID,
Amount: feeAmount,
})
if err != nil {
s.feeRepo.UpdateLogStatus(ctx, log.ID, "FAILED", "")
slog.WarnContext(ctx, "charge service fee failed", "trade_no", tradeNo, "err", err)
return err
}
s.feeRepo.UpdateLogStatus(ctx, log.ID, "SUCCESS", resp.ChannelSharingNo)
slog.InfoContext(ctx, "service fee charged", "trade_no", tradeNo, "fee_amount", feeAmount)
return nil
}
// RollbackServiceFee 退款时回退服务费
func (s *ServiceFeeService) RollbackServiceFee(ctx context.Context, tradeNo string) error {
// 幂等检查
existing, err := s.feeRepo.GetLog(ctx, tradeNo, "ROLLBACK")
if err != nil {
return err
}
if existing != nil {
return nil // 已回退
}
chargeLog, err := s.feeRepo.GetLog(ctx, tradeNo, "CHARGE")
if err != nil {
return err
}
if chargeLog == nil || chargeLog.Status != "SUCCESS" {
return nil // 没有成功扣收,无需回退
}
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
if err != nil || order == nil {
return fmt.Errorf("order not found: %s", tradeNo)
}
rollbackLog := &model.ServiceFeeLog{
TradeNo: tradeNo,
ConfigID: chargeLog.ConfigID,
FeeAmount: chargeLog.FeeAmount,
FeeRate: chargeLog.FeeRate,
ReceiverMerchantID: chargeLog.ReceiverMerchantID,
Action: "ROLLBACK",
Status: "PENDING",
}
if err := s.feeRepo.CreateLog(ctx, rollbackLog); err != nil {
return err
}
ch, err := s.channelSvc.GetChannel(ctx, order.AppID, order.ChannelCode)
if err != nil {
return err
}
sharingNo := fmt.Sprintf("FEE%s", tradeNo)
if err := ch.RollbackProfitSharing(ctx, &channel.RollbackSharingReq{
SharingNo: sharingNo,
ChannelSharingNo: chargeLog.ChannelSharingNo,
TradeNo: tradeNo,
}); err != nil {
s.feeRepo.UpdateLogStatus(ctx, rollbackLog.ID, "FAILED", "")
return err
}
s.feeRepo.UpdateLogStatus(ctx, rollbackLog.ID, "SUCCESS", "")
return nil
}
// CalculateAndValidate 下单时校验分润+服务费不超过订单金额
func (s *ServiceFeeService) CalculateAndValidate(ctx context.Context, appID string, payMethod model.PayMethod, orderAmount, sharingAmount int64) (int64, error) {
group := model.PayMethodToGroup(payMethod)
cfg, err := s.feeRepo.GetConfig(ctx, appID, group)
if err != nil {
return 0, err
}
var feeAmount int64
if cfg != nil && cfg.FeeRate > 0 {
feeAmount = calculateFee(orderAmount, cfg.FeeRate)
}
if sharingAmount+feeAmount > orderAmount {
return 0, fmt.Errorf(errSharingFeeExceed)
}
return feeAmount, nil
}
const errSharingFeeExceed = "30007" // errcode.ErrSharingFeeExceed
// calculateFee 计算服务费(四舍五入到分)
func calculateFee(amount int64, rate float64) int64 {
fee := float64(amount) * rate
return int64(math.Round(fee))
}

View File

@@ -0,0 +1,371 @@
package service
import (
"context"
"errors"
"fmt"
"log/slog"
"time"
"encoding/json"
"github.com/go-redis/redis/v8"
"pay-bridge/internal/channel"
"pay-bridge/internal/errcode"
"pay-bridge/internal/model"
"pay-bridge/internal/repository"
"pay-bridge/pkg/sequence"
)
const (
orderExpireDefault = 30 * time.Minute
idempotentKeyPrefix = "idempotent:"
idempotentTTL = 24 * time.Hour
)
// CreateOrderReq 下单请求
type CreateOrderReq struct {
AppID string
ChannelCode string // 指定渠道,为空时使用 defaultChannelCode
MerchantOrderNo string
PayMethod model.PayMethod
Amount int64
ProfitSharingAmount int64
Subject string
NotifyURL string
ExpireMinutes int
Extra map[string]any
MerchantID string // 可选指定收款商户SaaS 多商户路由)
}
// CreateOrderResp 下单响应
type CreateOrderResp struct {
TradeNo string
PayCredential map[string]any
IsIdempotent bool // true=幂等返回
}
// TradeService 交易服务
type TradeService struct {
tradeRepo *repository.TradeOrderRepository
channelSvc *ChannelService
merchantSvc *MerchantService
seqSvc *sequence.Service
rdb *redis.Client
notifySvc *NotifyService
}
func NewTradeService(
tradeRepo *repository.TradeOrderRepository,
channelSvc *ChannelService,
seqSvc *sequence.Service,
rdb *redis.Client,
notifySvc *NotifyService,
merchantSvc *MerchantService,
) *TradeService {
return &TradeService{
tradeRepo: tradeRepo,
channelSvc: channelSvc,
merchantSvc: merchantSvc,
seqSvc: seqSvc,
rdb: rdb,
notifySvc: notifySvc,
}
}
// CreateOrder 统一下单(含幂等控制)
func (s *TradeService) CreateOrder(ctx context.Context, req *CreateOrderReq) (*CreateOrderResp, error) {
// 参数校验
if req.Amount <= 0 {
return nil, errors.New(errcode.ErrInvalidAmount)
}
// 幂等检查 - Redis SET NX
idempotentKey := fmt.Sprintf("%s%s:%s", idempotentKeyPrefix, req.AppID, req.MerchantOrderNo)
set, err := s.rdb.SetNX(ctx, idempotentKey, "1", idempotentTTL).Result()
if err != nil && !errors.Is(err, redis.Nil) {
slog.WarnContext(ctx, "redis idempotent check failed, fallback to db", "err", err)
}
if !set {
// 幂等命中,查询已有订单
order, err := s.tradeRepo.GetByMerchantOrderNo(ctx, req.AppID, req.MerchantOrderNo)
if err != nil {
return nil, err
}
if order == nil {
// Redis key 存在但 DB 无记录(极端情况),清除 key 重试
s.rdb.Del(ctx, idempotentKey)
return nil, errors.New(errcode.ErrOrderNotFound)
}
if order.Status == model.TradeStatusPaid {
return nil, errors.New(errcode.ErrOrderAlreadyPaid)
}
if order.Status == model.TradeStatusClosed {
return nil, errors.New(errcode.ErrOrderClosed)
}
return &CreateOrderResp{
TradeNo: order.TradeNo,
PayCredential: order.ChannelExtra,
IsIdempotent: true,
}, nil
}
// 生成交易号
tradeNo, err := s.seqSvc.NextTradeNo(ctx, req.AppID)
if err != nil {
s.rdb.Del(ctx, idempotentKey)
return nil, err
}
// 计算过期时间
expireMinutes := req.ExpireMinutes
if expireMinutes <= 0 {
expireMinutes = int(orderExpireDefault.Minutes())
}
expireTime := time.Now().Add(time.Duration(expireMinutes) * time.Minute)
// 确定渠道
channelCode := req.ChannelCode
if channelCode == "" {
channelCode = "HEEPAY" // 向后兼容默认值,建议调用方明确传入
}
// 可选指定收款商户SaaS 多商户路由),校验归属并按渠道注入 sub_merchant_id
if req.MerchantID != "" && s.merchantSvc != nil {
// 校验商户归属(只能使用本 appID 下的商户)
if _, err := s.merchantSvc.GetMerchantForApp(ctx, req.AppID, req.MerchantID); err != nil {
s.rdb.Del(ctx, idempotentKey)
return nil, err
}
// 按实际下单渠道取对应进件记录的 channel_merchant_id
channelMerchantID, err := s.merchantSvc.GetChannelMerchantID(ctx, req.MerchantID, channelCode)
if err != nil {
s.rdb.Del(ctx, idempotentKey)
return nil, err
}
if channelMerchantID != "" {
if req.Extra == nil {
req.Extra = make(map[string]any)
}
req.Extra["sub_merchant_id"] = channelMerchantID
}
}
// 创建本地订单记录CREATING 状态)
order := &model.TradeOrder{
TradeNo: tradeNo,
MerchantOrderNo: req.MerchantOrderNo,
AppID: req.AppID,
ChannelCode: channelCode,
PayMethod: req.PayMethod,
Amount: req.Amount,
ProfitSharingAmount: req.ProfitSharingAmount,
Subject: req.Subject,
NotifyURL: req.NotifyURL,
Status: model.TradeStatusCreating,
Extra: req.Extra,
ExpireTime: expireTime,
}
if err := s.tradeRepo.Create(ctx, order); err != nil {
s.rdb.Del(ctx, idempotentKey)
return nil, err
}
// 调用渠道下单
ch, err := s.channelSvc.GetChannel(ctx, req.AppID, channelCode)
if err != nil {
s.tradeRepo.UpdateStatus(ctx, tradeNo, model.TradeStatusCreating, model.TradeStatusCreateFailed, nil)
return nil, fmt.Errorf("%s: %w", errcode.ErrChannelCreateFail, err)
}
channelReq := &channel.CreateOrderReq{
AppID: req.AppID,
TradeNo: tradeNo,
MerchantOrderNo: req.MerchantOrderNo,
PayMethod: req.PayMethod,
Amount: req.Amount,
Subject: req.Subject,
NotifyURL: req.NotifyURL,
ExpireTime: expireTime,
Extra: req.Extra,
}
channelResp, err := ch.CreateOrder(ctx, channelReq)
if err != nil {
s.tradeRepo.UpdateStatus(ctx, tradeNo, model.TradeStatusCreating, model.TradeStatusCreateFailed, nil)
return nil, fmt.Errorf("%s: %w", errcode.ErrChannelCreateFail, err)
}
// 更新为 PAYING 状态,保存支付凭证
updates := map[string]any{
"channel_trade_no": channelResp.ChannelTradeNo,
"channel_extra": model.JSONMap(channelResp.PayCredential),
}
s.tradeRepo.UpdateStatus(ctx, tradeNo, model.TradeStatusCreating, model.TradeStatusPaying, updates)
return &CreateOrderResp{
TradeNo: tradeNo,
PayCredential: channelResp.PayCredential,
}, nil
}
// QueryOrder 查询交易状态
func (s *TradeService) QueryOrder(ctx context.Context, appID, tradeNo string) (*model.TradeOrder, error) {
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
if err != nil {
return nil, err
}
if order == nil || order.AppID != appID {
return nil, errors.New(errcode.ErrOrderNotFound)
}
// 如果处于 PAYING 状态,主动查询渠道同步最新状态
if order.Status == model.TradeStatusPaying {
s.syncOrderStatus(ctx, order)
}
return order, nil
}
// CloseOrder 关闭订单
func (s *TradeService) CloseOrder(ctx context.Context, appID, tradeNo string) error {
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
if err != nil {
return err
}
if order == nil || order.AppID != appID {
return errors.New(errcode.ErrOrderNotFound)
}
if order.Status == model.TradeStatusPaid {
return errors.New(errcode.ErrOrderAlreadyPaid)
}
if order.Status == model.TradeStatusClosed {
return nil // 已关闭,幂等
}
if order.Status != model.TradeStatusPaying {
return errors.New(errcode.ErrOrderClosed)
}
// 调用渠道关单
ch, err := s.channelSvc.GetChannel(ctx, appID, order.ChannelCode)
if err != nil {
return err
}
if err := ch.CloseOrder(ctx, &channel.CloseOrderReq{
TradeNo: tradeNo,
ChannelTradeNo: order.ChannelTradeNo,
}); err != nil {
slog.WarnContext(ctx, "close order on channel failed", "trade_no", tradeNo, "err", err)
}
_, err = s.tradeRepo.UpdateStatus(ctx, tradeNo, model.TradeStatusPaying, model.TradeStatusClosed, nil)
return err
}
// HandleUpstreamNotify 处理上游支付回调(验签 + 状态更新 + 触发通知下游)
//
// 流程:先用临时无配置实例从 body 提取 trade_no → 查 DB 得 appID → 加载完整渠道配置验签
func (s *TradeService) HandleUpstreamNotify(ctx context.Context, channelCode string, rawBody []byte, headers map[string]string) (string, error) {
// 用只负责解析的临时渠道实例提取交易号(不需要密钥配置)
tempCh, err := channel.Get(channelCode, nil, channel.URLs{})
if err != nil {
return "fail", fmt.Errorf("unknown channel: %s", channelCode)
}
tradeNo, err := tempCh.ExtractTradeNo(rawBody)
if err != nil || tradeNo == "" {
return "fail", fmt.Errorf("extract trade_no from notify: %w", err)
}
order, err := s.tradeRepo.GetByTradeNo(ctx, tradeNo)
if err != nil {
return "fail", err
}
if order == nil {
slog.WarnContext(ctx, "notify: order not found", "trade_no", tradeNo)
return "fail", errors.New(errcode.ErrOrderNotFound)
}
// 加载完整渠道配置并验签
ch, err := s.channelSvc.GetChannel(ctx, order.AppID, channelCode)
if err != nil {
return "fail", err
}
notifyData, err := ch.VerifyNotify(ctx, rawBody, headers)
if err != nil {
slog.WarnContext(ctx, "notify: verify sign failed", "trade_no", tradeNo, "err", err)
return "fail", errors.New(errcode.ErrChannelVerifyFail)
}
// 处理支付通知
if notifyData.NotifyType == model.NotifyTypePayment && notifyData.Status == model.TradeStatusPaid {
if err := s.handlePaymentSuccess(ctx, order, notifyData); err != nil {
return "fail", err
}
}
return "success", nil
}
// handlePaymentSuccess 处理支付成功
func (s *TradeService) handlePaymentSuccess(ctx context.Context, order *model.TradeOrder, data *channel.NotifyData) error {
updates := map[string]any{
"channel_trade_no": data.ChannelTradeNo,
}
if data.PayTime != nil {
updates["pay_time"] = data.PayTime
}
ok, err := s.tradeRepo.UpdateStatus(ctx, order.TradeNo, model.TradeStatusPaying, model.TradeStatusPaid, updates)
if err != nil {
return err
}
if !ok {
// 已被处理过(幂等),直接返回成功
slog.InfoContext(ctx, "payment notify idempotent", "trade_no", order.TradeNo)
return nil
}
// 异步触发下游通知
if s.notifySvc != nil {
go func() {
bgCtx := context.Background()
if err := s.notifySvc.SendNotify(bgCtx, order.TradeNo, model.NotifyTypePayment, order.NotifyURL); err != nil {
slog.Error("send notify failed", "trade_no", order.TradeNo, "err", err)
}
}()
}
return nil
}
// syncOrderStatus 主动查询渠道同步订单状态(查询接口兜底)
func (s *TradeService) syncOrderStatus(ctx context.Context, order *model.TradeOrder) {
ch, err := s.channelSvc.GetChannel(ctx, order.AppID, order.ChannelCode)
if err != nil {
slog.WarnContext(ctx, "syncOrderStatus: get channel failed", "trade_no", order.TradeNo, "err", err)
return
}
resp, err := ch.QueryOrder(ctx, &channel.QueryOrderReq{
TradeNo: order.TradeNo,
ChannelTradeNo: order.ChannelTradeNo,
})
if err != nil {
slog.WarnContext(ctx, "syncOrderStatus: query channel failed", "trade_no", order.TradeNo, "err", err)
return
}
if resp.Status == model.TradeStatusPaid {
updates := map[string]any{
"channel_trade_no": resp.ChannelTradeNo,
}
if resp.PayTime != nil {
updates["pay_time"] = resp.PayTime
}
s.tradeRepo.UpdateStatus(ctx, order.TradeNo, model.TradeStatusPaying, model.TradeStatusPaid, updates)
}
}
func parseJSON(data []byte, v any) error {
return json.Unmarshal(data, v)
}

View File

@@ -0,0 +1,161 @@
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"time"
"pay-bridge/internal/model"
"pay-bridge/internal/repository"
"pay-bridge/pkg/crypto"
)
const (
wxTokenURL = "https://api.weixin.qq.com/cgi-bin/token"
wxSendMsgURL = "https://api.weixin.qq.com/cgi-bin/message/template/send"
accessTokenTTL = 90 * time.Minute // 微信 access_token 有效期 2h提前 30min 刷新
)
// WechatService 微信模板消息服务
type WechatService struct {
wechatRepo *repository.WechatRepository
cryptoKey string
httpClient *http.Client
// 内存缓存 access_token避免频繁调用微信接口
tokenCache map[string]*tokenEntry
}
type tokenEntry struct {
token string
expiresAt time.Time
}
func NewWechatService(wechatRepo *repository.WechatRepository, cryptoKey string) *WechatService {
return &WechatService{
wechatRepo: wechatRepo,
cryptoKey: cryptoKey,
httpClient: &http.Client{Timeout: 10 * time.Second},
tokenCache: make(map[string]*tokenEntry),
}
}
// SendPaymentNotify 发送支付成功通知
func (s *WechatService) SendPaymentNotify(ctx context.Context, appID, tradeNo, openID string, amount int64) error {
binding, err := s.wechatRepo.GetBinding(ctx, appID)
if err != nil || binding == nil {
return nil // 未配置微信通知,跳过
}
data := map[string]any{
"trade_no": map[string]string{"value": tradeNo},
"amount": map[string]string{"value": fmt.Sprintf("%.2f 元", float64(amount)/100)},
"time": map[string]string{"value": time.Now().Format("2006-01-02 15:04:05")},
}
return s.sendTemplate(ctx, appID, binding, openID, tradeNo, data)
}
// sendTemplate 发送模板消息
func (s *WechatService) sendTemplate(ctx context.Context, appID string, binding *model.WechatBinding,
openID, tradeNo string, data map[string]any) error {
log := &model.WechatMessageLog{
AppID: appID,
TradeNo: tradeNo,
OpenID: openID,
TemplateID: binding.TemplateID,
Status: model.WechatMessageStatusPending,
}
if err := s.wechatRepo.CreateMessageLog(ctx, log); err != nil {
return err
}
token, err := s.getAccessToken(ctx, binding)
if err != nil {
updates := map[string]any{"status": model.WechatMessageStatusFailed, "err_msg": err.Error()}
s.wechatRepo.UpdateMessageLog(ctx, log.ID, updates)
return err
}
payload := map[string]any{
"touser": openID,
"template_id": binding.TemplateID,
"data": data,
}
body, _ := json.Marshal(payload)
url := fmt.Sprintf("%s?access_token=%s", wxSendMsgURL, token)
resp, err := s.httpClient.Post(url, "application/json", bytes.NewReader(body))
if err != nil {
updates := map[string]any{"status": model.WechatMessageStatusFailed, "err_msg": err.Error()}
s.wechatRepo.UpdateMessageLog(ctx, log.ID, updates)
return err
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
var result struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
json.Unmarshal(respBody, &result)
now := time.Now()
if result.ErrCode == 0 {
updates := map[string]any{"status": model.WechatMessageStatusSuccess, "sent_at": now}
s.wechatRepo.UpdateMessageLog(ctx, log.ID, updates)
slog.InfoContext(ctx, "wechat template sent", "trade_no", tradeNo, "open_id", openID)
} else {
errMsg := fmt.Sprintf("errcode=%d errmsg=%s", result.ErrCode, result.ErrMsg)
updates := map[string]any{"status": model.WechatMessageStatusFailed, "err_msg": errMsg}
s.wechatRepo.UpdateMessageLog(ctx, log.ID, updates)
return fmt.Errorf("wechat send failed: %s", errMsg)
}
return nil
}
// getAccessToken 获取微信 access_token带内存缓存
func (s *WechatService) getAccessToken(ctx context.Context, binding *model.WechatBinding) (string, error) {
if entry, ok := s.tokenCache[binding.WxAppID]; ok && time.Now().Before(entry.expiresAt) {
return entry.token, nil
}
// 解密 secret
secret, err := crypto.Decrypt(binding.WxSecret, s.cryptoKey)
if err != nil {
return "", fmt.Errorf("decrypt wx secret: %w", err)
}
url := fmt.Sprintf("%s?grant_type=client_credential&appid=%s&secret=%s",
wxTokenURL, binding.WxAppID, secret)
resp, err := s.httpClient.Get(url)
if err != nil {
return "", fmt.Errorf("get wx token: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
var result struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
if err := json.Unmarshal(body, &result); err != nil {
return "", err
}
if result.ErrCode != 0 {
return "", fmt.Errorf("wx token error: %d %s", result.ErrCode, result.ErrMsg)
}
s.tokenCache[binding.WxAppID] = &tokenEntry{
token: result.AccessToken,
expiresAt: time.Now().Add(accessTokenTTL),
}
return result.AccessToken, nil
}