Files
pay-bridge/backend/internal/channel/heepay/adapter.go
2026-03-13 15:51:59 +08:00

598 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package heepay
import (
"bytes"
"context"
"crypto/md5"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
"time"
"pay-bridge/internal/channel"
"pay-bridge/internal/model"
)
const (
	// ChannelCode is the key under which this adapter's constructor is
	// registered via channel.Register (see init).
	ChannelCode = "HEEPAY"
	// codeSuccess is the top-level "code" value that marks a successful
	// gateway call or callback.
	codeSuccess = "10000"
)

// cst is China Standard Time (UTC+8). It is a fixed zone so that formatted
// and parsed timestamps do not depend on the host's time.Local setting.
var cst = time.FixedZone("CST", 8*3600)
// Adapter is the Heepay payment-channel adapter.
type Adapter struct {
	// config supplies MerchantID (sent as app_id), PrivateKey (request
	// signing) and PublicKey (response/callback verification).
	config      *model.ChannelConfig
	payURL      string // payment gateway endpoint
	merchantURL string // merchant onboarding (进件) gateway endpoint
	client      *http.Client
}
// New constructs a Heepay adapter for one channel configuration.
//
// urls.PayURL is used for payment/refund/sharing calls and urls.MerchantURL
// for merchant onboarding calls. Every outbound HTTP request shares a single
// client with a 30 second timeout.
func New(config *model.ChannelConfig, urls channel.URLs) channel.PaymentChannel {
	adapter := new(Adapter)
	adapter.config = config
	adapter.payURL = urls.PayURL
	adapter.merchantURL = urls.MerchantURL
	adapter.client = &http.Client{Timeout: 30 * time.Second}
	return adapter
}
// init registers New under ChannelCode with channel.Register at package load.
func init() {
	channel.Register(ChannelCode, New)
}
// Code returns the channel identifier ChannelCode ("HEEPAY").
func (a *Adapter) Code() string { return ChannelCode }
// CreateOrder places a unified order on the payment gateway
// (pay.heepay.trade.create).
//
// req.PayMethod is translated to a heepay trade_type; unsupported methods
// fail before any network call. Optional "openid" and "sub_appid" entries of
// req.Extra are forwarded only when they hold strings.
//
// The channel trade number is taken from transaction_id, falling back to
// prepay_id. The whole decoded data object is returned as PayCredential and
// its JSON re-encoding as RawResponse.
func (a *Adapter) CreateOrder(ctx context.Context, req *channel.CreateOrderReq) (*channel.CreateOrderResp, error) {
	tradeType, err := mapPayMethod(req.PayMethod)
	if err != nil {
		return nil, err
	}

	params := make(map[string]any, 7)
	params["out_trade_no"] = req.TradeNo
	params["body"] = req.Subject
	params["total_fee"] = req.Amount
	params["notify_url"] = req.NotifyURL
	params["trade_type"] = tradeType
	// Indexing a nil Extra map is safe and simply yields no value.
	for _, key := range [...]string{"openid", "sub_appid"} {
		if val, ok := req.Extra[key].(string); ok {
			params[key] = val
		}
	}

	data, err := a.post(ctx, "pay.heepay.trade.create", params)
	if err != nil {
		return nil, err
	}

	tradeID, _ := data["transaction_id"].(string)
	if tradeID == "" {
		tradeID, _ = data["prepay_id"].(string)
	}
	// data was produced by json.Unmarshal into map[string]any, so every value
	// is JSON-encodable and re-marshalling cannot fail.
	rawJSON, _ := json.Marshal(data)

	return &channel.CreateOrderResp{
		ChannelTradeNo: tradeID,
		PayCredential:  data,
		RawResponse:    rawJSON,
	}, nil
}
// QueryOrder queries the state of an order on the payment gateway
// (pay.heepay.trade.query).
//
// out_trade_no is always sent; transaction_id is added only when the caller
// already knows the channel trade number. trade_state is mapped as:
//
//	SUCCESS        -> model.TradeStatusPaid (PayTime from pay_time)
//	CLOSED/REVOKED -> model.TradeStatusClosed
//	PAYERROR       -> model.TradeStatusFailed
//	anything else  -> model.TradeStatusPaying
//
// Amount is filled from total_fee when the gateway returns it as a JSON
// number. A pay_time that does not parse as "2006-01-02 15:04:05" (UTC+8) is
// logged and left unset instead of being reported as the zero time.
func (a *Adapter) QueryOrder(ctx context.Context, req *channel.QueryOrderReq) (*channel.QueryOrderResp, error) {
	biz := map[string]any{
		"out_trade_no": req.TradeNo,
	}
	if req.ChannelTradeNo != "" {
		biz["transaction_id"] = req.ChannelTradeNo
	}
	resp, err := a.post(ctx, "pay.heepay.trade.query", biz)
	if err != nil {
		return nil, err
	}

	result := &channel.QueryOrderResp{
		TradeNo:        req.TradeNo,
		ChannelTradeNo: strVal(resp, "transaction_id"),
	}
	switch strVal(resp, "trade_state") {
	case "SUCCESS":
		result.Status = model.TradeStatusPaid
		if s := strVal(resp, "pay_time"); s != "" {
			// Only publish a time that actually parsed; on error ParseInLocation
			// returns the zero time, which would masquerade as a real pay time.
			if t, perr := time.ParseInLocation("2006-01-02 15:04:05", s, cst); perr == nil {
				result.PayTime = &t
			} else {
				slog.WarnContext(ctx, "heepay query order: unparseable pay_time",
					"trade_no", req.TradeNo, "pay_time", s, "err", perr)
			}
		}
	case "CLOSED", "REVOKED":
		result.Status = model.TradeStatusClosed
	case "PAYERROR":
		result.Status = model.TradeStatusFailed
	default:
		result.Status = model.TradeStatusPaying
	}
	// encoding/json decodes JSON numbers in map[string]any as float64.
	if v, ok := resp["total_fee"].(float64); ok {
		result.Amount = int64(v)
	}
	return result, nil
}
// CloseOrder closes an unpaid order on the payment gateway
// (pay.heepay.trade.close), identified by the platform trade number only.
// The response data is not used; only the call's error is returned.
func (a *Adapter) CloseOrder(ctx context.Context, req *channel.CloseOrderReq) error {
	params := map[string]any{"out_trade_no": req.TradeNo}
	if _, err := a.post(ctx, "pay.heepay.trade.close", params); err != nil {
		return err
	}
	return nil
}
// Refund submits a refund request to the payment gateway
// (pay.heepay.trade.refund).
//
// transaction_id is included only when req.ChannelTradeNo is known. On
// success the refund is always reported as model.RefundStatusProcessing; the
// final outcome presumably arrives via req.NotifyURL or QueryRefund.
func (a *Adapter) Refund(ctx context.Context, req *channel.RefundReq) (*channel.RefundResp, error) {
	params := make(map[string]any, 7)
	params["out_trade_no"] = req.TradeNo
	params["out_refund_no"] = req.RefundNo
	params["total_fee"] = req.TotalAmount
	params["refund_fee"] = req.RefundAmount
	params["refund_desc"] = req.Reason
	params["notify_url"] = req.NotifyURL
	if channelTradeNo := req.ChannelTradeNo; channelTradeNo != "" {
		params["transaction_id"] = channelTradeNo
	}

	data, err := a.post(ctx, "pay.heepay.trade.refund", params)
	if err != nil {
		return nil, err
	}

	out := &channel.RefundResp{
		RefundNo:        req.RefundNo,
		ChannelRefundNo: strVal(data, "refund_id"),
		Status:          model.RefundStatusProcessing,
	}
	return out, nil
}
// QueryRefund queries the state of a refund on the payment gateway
// (pay.heepay.trade.refund.query).
//
// out_refund_no is always sent; refund_id is added only when the channel
// refund number is known. refund_status is mapped as:
//
//	SUCCESS       -> model.RefundStatusSuccess (RefundTime from refund_success_time)
//	PROCESSING    -> model.RefundStatusProcessing
//	FAIL          -> model.RefundStatusFailed
//	anything else -> model.RefundStatusPending
//
// RefundAmount is filled from refund_fee when it is a JSON number. A
// refund_success_time that does not parse as "2006-01-02 15:04:05" (UTC+8) is
// logged and left unset instead of being reported as the zero time.
func (a *Adapter) QueryRefund(ctx context.Context, req *channel.QueryRefundReq) (*channel.QueryRefundResp, error) {
	biz := map[string]any{
		"out_refund_no": req.RefundNo,
	}
	if req.ChannelRefundNo != "" {
		biz["refund_id"] = req.ChannelRefundNo
	}
	resp, err := a.post(ctx, "pay.heepay.trade.refund.query", biz)
	if err != nil {
		return nil, err
	}

	result := &channel.QueryRefundResp{
		RefundNo:        req.RefundNo,
		ChannelRefundNo: strVal(resp, "refund_id"),
	}
	switch strVal(resp, "refund_status") {
	case "SUCCESS":
		result.Status = model.RefundStatusSuccess
		if s := strVal(resp, "refund_success_time"); s != "" {
			// On error ParseInLocation returns the zero time; never expose that
			// as the refund time.
			if t, perr := time.ParseInLocation("2006-01-02 15:04:05", s, cst); perr == nil {
				result.RefundTime = &t
			} else {
				slog.WarnContext(ctx, "heepay query refund: unparseable refund_success_time",
					"refund_no", req.RefundNo, "refund_success_time", s, "err", perr)
			}
		}
	case "PROCESSING":
		result.Status = model.RefundStatusProcessing
	case "FAIL":
		result.Status = model.RefundStatusFailed
	default:
		result.Status = model.RefundStatusPending
	}
	// encoding/json decodes JSON numbers in map[string]any as float64.
	if v, ok := resp["refund_fee"].(float64); ok {
		result.RefundAmount = int64(v)
	}
	return result, nil
}
// ExtractTradeNo pulls the platform trade number (data.out_trade_no) out of a
// raw callback body. It only parses JSON and performs NO signature check, so
// it is meant for routing before VerifyNotify is called.
//
// It returns an error when the body or its "data" member is not valid JSON,
// or when out_trade_no is missing or empty.
func (a *Adapter) ExtractTradeNo(rawBody []byte) (string, error) {
	var envelope struct {
		Data json.RawMessage `json:"data"`
	}
	if err := json.Unmarshal(rawBody, &envelope); err != nil {
		return "", fmt.Errorf("heepay ExtractTradeNo: unmarshal body: %w", err)
	}

	var payload struct {
		OutTradeNo string `json:"out_trade_no"`
	}
	if err := json.Unmarshal(envelope.Data, &payload); err != nil {
		return "", fmt.Errorf("heepay ExtractTradeNo: unmarshal data: %w", err)
	}

	if tradeNo := payload.OutTradeNo; tradeNo != "" {
		return tradeNo, nil
	}
	return "", fmt.Errorf("heepay ExtractTradeNo: out_trade_no is empty")
}
// VerifyNotify verifies the signature of an asynchronous heepay callback and
// converts it into a channel.NotifyData.
//
// The signed parameters are the top-level code, msg and trade_id fields plus
// the raw JSON text of "data", checked with the channel public key. The
// callback is rejected when the signature fails, when code is not
// codeSuccess, when data is not a JSON object, or when data.notify_type is
// neither "payment" nor "refund".
//
// A pay_time that does not parse as "2006-01-02 15:04:05" (UTC+8) is logged
// and PayTime is left nil rather than set to the zero time. headers is not
// used by this implementation.
func (a *Adapter) VerifyNotify(ctx context.Context, rawBody []byte, headers map[string]string) (*channel.NotifyData, error) {
	// Top-level envelope fields.
	var outer struct {
		Code    string          `json:"code"`
		Msg     string          `json:"msg"`
		TradeID string          `json:"trade_id"`
		Sign    string          `json:"sign"`
		Data    json.RawMessage `json:"data"`
	}
	if err := json.Unmarshal(rawBody, &outer); err != nil {
		return nil, fmt.Errorf("unmarshal notify body: %w", err)
	}

	// "data" participates in the signature as its raw JSON text.
	// NOTE(review): the original authors state callbacks are signed like
	// gateway responses, yet unlike the response check in postURL empty
	// values are kept here and no other top-level fields participate —
	// confirm against heepay's callback signing rules.
	verifyParams := map[string]string{
		"code":     outer.Code,
		"msg":      outer.Msg,
		"trade_id": outer.TradeID,
		"data":     string(outer.Data),
	}
	if err := VerifyResponse(verifyParams, outer.Sign, a.config.PublicKey); err != nil {
		return nil, fmt.Errorf("verify notify sign: %w", err)
	}
	if outer.Code != codeSuccess {
		return nil, fmt.Errorf("notify code not success: %s %s", outer.Code, outer.Msg)
	}

	// Business payload.
	var bizData map[string]any
	if err := json.Unmarshal(outer.Data, &bizData); err != nil {
		return nil, fmt.Errorf("unmarshal notify data: %w", err)
	}
	notifyType := strVal(bizData, "notify_type")
	result := &channel.NotifyData{
		TradeNo:        strVal(bizData, "out_trade_no"),
		ChannelTradeNo: strVal(bizData, "transaction_id"),
		RawData:        rawBody,
	}
	switch notifyType {
	case "payment":
		result.NotifyType = model.NotifyTypePayment
		result.Status = model.TradeStatusPaid
		if v, ok := bizData["total_fee"].(float64); ok {
			result.Amount = int64(v)
		}
		if s := strVal(bizData, "pay_time"); s != "" {
			// On error ParseInLocation returns the zero time; never expose it.
			if t, perr := time.ParseInLocation("2006-01-02 15:04:05", s, cst); perr == nil {
				result.PayTime = &t
			} else {
				slog.WarnContext(ctx, "heepay notify: unparseable pay_time",
					"trade_no", result.TradeNo, "pay_time", s, "err", perr)
			}
		}
	case "refund":
		// NOTE(review): every refund callback is reported as successful; any
		// status field inside data is not inspected.
		result.NotifyType = model.NotifyTypeRefund
		result.RefundNo = strVal(bizData, "out_refund_no")
		result.RefundStatus = model.RefundStatusSuccess
		if v, ok := bizData["refund_fee"].(float64); ok {
			result.RefundAmount = int64(v)
		}
	default:
		return nil, fmt.Errorf("unknown notify_type: %s", notifyType)
	}
	return result, nil
}
// ProfitSharing splits part of a paid order to one receiver merchant
// (pay.heepay.trade.profitsharing).
//
// The order is identified by its channel trade number; req.SharingNo is the
// platform-side sharing number. A single MERCHANT_ID receiver gets req.Amount.
// The channel's order_id is returned as ChannelSharingNo.
func (a *Adapter) ProfitSharing(ctx context.Context, req *channel.ProfitSharingReq) (*channel.ProfitSharingResp, error) {
	receiver := map[string]any{
		"type":        "MERCHANT_ID",
		"account":     req.ReceiverMerchantID,
		"amount":      req.Amount,
		"description": "分润",
	}
	params := map[string]any{
		"transaction_id": req.ChannelTradeNo,
		"out_order_no":   req.SharingNo,
		"receivers":      []map[string]any{receiver},
	}

	data, err := a.post(ctx, "pay.heepay.trade.profitsharing", params)
	if err != nil {
		return nil, err
	}

	out := &channel.ProfitSharingResp{
		SharingNo:        req.SharingNo,
		ChannelSharingNo: strVal(data, "order_id"),
	}
	return out, nil
}
// RollbackProfitSharing requests a profit-sharing return
// (pay.heepay.trade.profitsharing.return).
//
// out_return_no is derived as req.SharingNo + "_R", and return_account is the
// merchant ID from the channel config.
//
// NOTE(review): return_amount is hard-coded to 0, and transaction_id is taken
// from req.TradeNo whereas ProfitSharing sends req.ChannelTradeNo for the same
// field. Both look suspicious — confirm against the heepay API and the fields
// available on channel.RollbackSharingReq.
func (a *Adapter) RollbackProfitSharing(ctx context.Context, req *channel.RollbackSharingReq) error {
	returnNo := req.SharingNo + "_R"
	params := map[string]any{
		"transaction_id": req.TradeNo,
		"out_order_no":   req.SharingNo,
		"out_return_no":  returnNo,
		"return_account": a.config.MerchantID,
		"return_amount":  0,
		"description":    "退款回退",
	}
	if _, err := a.post(ctx, "pay.heepay.trade.profitsharing.return", params); err != nil {
		return err
	}
	return nil
}
// DownloadBill requests the reconciliation bill for req.BillDate
// (pay.heepay.trade.bill.download, bill_type "ALL").
//
// NOTE(review): the bill is expected to be CSV content for the
// reconciliation service to parse, but the gateway response is currently
// discarded and an empty BillData is returned on success. Callers receive no
// bill data until this is implemented.
func (a *Adapter) DownloadBill(ctx context.Context, req *channel.DownloadBillReq) (*channel.BillData, error) {
	params := map[string]any{
		"bill_date": req.BillDate,
		"bill_type": "ALL",
	}
	if _, err := a.post(ctx, "pay.heepay.trade.bill.download", params); err != nil {
		return nil, err
	}
	// TODO: surface the downloaded bill content in BillData.
	return &channel.BillData{}, nil
}
// UploadFile uploads a document to the heepay merchant onboarding gateway
// (customer.file.upload).
//
// The file bytes are sent base64-encoded (file_content) together with their
// lowercase hex MD5 digest (file_sign) and req.FileMediaType. An empty
// file_id in the response is treated as an error.
func (a *Adapter) UploadFile(ctx context.Context, req *channel.UploadFileReq) (*channel.UploadFileResp, error) {
	content := req.FileContent
	digest := md5.Sum(content)

	params := map[string]any{
		"file_content":    base64.StdEncoding.EncodeToString(content),
		"file_sign":       hex.EncodeToString(digest[:]),
		"file_media_type": req.FileMediaType,
	}
	data, err := a.postToMerchant(ctx, "customer.file.upload", params)
	if err != nil {
		return nil, err
	}

	id := strVal(data, "file_id")
	if id == "" {
		return nil, fmt.Errorf("heepay upload file: empty file_id in response")
	}
	return &channel.UploadFileResp{FileID: id}, nil
}
// MerchantApply submits an enterprise onboarding application to the merchant
// gateway (customer.enter.enterprise.apply).
//
// req.BizContent is forwarded verbatim as the business parameters; the
// response's request_no is returned for later status queries.
func (a *Adapter) MerchantApply(ctx context.Context, req *channel.MerchantApplyReq) (*channel.MerchantApplyResp, error) {
	data, err := a.postToMerchant(ctx, "customer.enter.enterprise.apply", req.BizContent)
	if err != nil {
		return nil, err
	}
	out := new(channel.MerchantApplyResp)
	out.RequestNo = strVal(data, "request_no")
	return out, nil
}
// QueryMerchantStatus queries an onboarding application on the merchant
// gateway (customer.enter.query).
//
// channelMerchantID is sent as request_no — presumably the RequestNo returned
// by MerchantApply; confirm with callers. The response fields merch_id,
// audit_state and audit_reason become ChannelMerchantID, Status (heepay's raw
// audit_state value) and FailReason.
func (a *Adapter) QueryMerchantStatus(ctx context.Context, channelMerchantID string) (*channel.MerchantStatusResp, error) {
	params := map[string]any{"request_no": channelMerchantID}
	data, err := a.postToMerchant(ctx, "customer.enter.query", params)
	if err != nil {
		return nil, err
	}
	out := new(channel.MerchantStatusResp)
	out.ChannelMerchantID = strVal(data, "merch_id")
	out.Status = strVal(data, "audit_state")
	out.FailReason = strVal(data, "audit_reason")
	return out, nil
}
// --- Low-level gateway communication ---

// heepayRequest is the common request envelope POSTed as JSON to both heepay
// gateways. BizContent carries the JSON-encoded business parameters; Sign is
// the signature over the other eight fields.
type heepayRequest struct {
	AppID      string `json:"app_id"`      // merchant ID from the channel config
	Method     string `json:"method"`      // gateway API name, e.g. "pay.heepay.trade.create"
	Format     string `json:"format"`      // "JSON" in this file
	Charset    string `json:"charset"`     // "utf-8" in this file
	SignType   string `json:"sign_type"`   // SignTypeRSA2 in this file
	Timestamp  string `json:"timestamp"`   // "2006-01-02 15:04:05" in UTC+8
	Version    string `json:"version"`     // "1.0" in this file
	BizContent string `json:"biz_content"` // JSON text of the business parameters
	Sign       string `json:"sign"`
}
// heepayResponse is the common response envelope of both heepay gateways.
//
// The payment gateway returns code as a JSON string ("10000") while the
// merchant onboarding gateway returns it as a JSON number (10000). Code is
// therefore kept as json.RawMessage and normalised by codeStr.
type heepayResponse struct {
	Code    json.RawMessage `json:"code"`
	Msg     string          `json:"msg"`
	SubCode string          `json:"sub_code"`
	SubMsg  string          `json:"sub_msg"`
	Sign    string          `json:"sign"`
	TradeID string          `json:"trade_id"`
	Data    json.RawMessage `json:"data"`
}

// codeStr returns the code field as plain text so either wire form can be
// compared with codeSuccess: a quoted value loses its surrounding quotes
// (escapes are not processed), anything else is returned as its raw JSON
// text, and an absent code yields "".
func (r *heepayResponse) codeStr() string {
	text := string(r.Code)
	if len(text) < 2 || text[0] != '"' {
		return text
	}
	return text[1 : len(text)-1]
}
// post calls the heepay payment gateway (payURL) via postURL and returns the
// decoded response data.
func (a *Adapter) post(ctx context.Context, method string, bizParams map[string]any) (map[string]any, error) {
	return a.postURL(ctx, a.payURL, method, bizParams)
}

// postToMerchant calls the heepay merchant onboarding gateway (merchantURL)
// via postURL and returns the decoded response data.
func (a *Adapter) postToMerchant(ctx context.Context, method string, bizParams map[string]any) (map[string]any, error) {
	return a.postURL(ctx, a.merchantURL, method, bizParams)
}
// postURL signs bizParams as a request for the heepay API method, POSTs it to
// gatewayURL and returns the decoded "data" object of the response.
//
// Steps:
//  1. biz_content is the JSON encoding of bizParams; the eight envelope
//     fields are signed with the merchant private key (sign_type
//     SignTypeRSA2) and sent as a JSON body.
//  2. A non-2xx HTTP status or an empty body is an error.
//  3. The response signature is verified with the channel public key before
//     anything in the response is trusted.
//  4. A business code other than codeSuccess is returned as an error carrying
//     code, msg, sub_code and sub_msg.
//
// The returned map is nil (with a nil error) when the response has no data.
func (a *Adapter) postURL(ctx context.Context, gatewayURL, method string, bizParams map[string]any) (map[string]any, error) {
	reqBody, err := a.buildRequestBody(method, bizParams)
	if err != nil {
		return nil, err
	}
	respBytes, err := a.doRequest(ctx, gatewayURL, method, reqBody)
	if err != nil {
		return nil, err
	}
	if err := a.verifyResponseSign(respBytes); err != nil {
		return nil, err
	}

	var resp heepayResponse
	if err := json.Unmarshal(respBytes, &resp); err != nil {
		// The body already decoded as a JSON object during verification, so the
		// only possible failure here is a field-type mismatch (e.g. a numeric
		// sub_code). encoding/json still fills every other field in that case,
		// so continue best-effort but leave a trace.
		slog.WarnContext(ctx, "heepay response envelope type mismatch", "method", method, "err", err)
	}
	if code := resp.codeStr(); code != codeSuccess {
		return nil, fmt.Errorf("heepay error [%s] %s: %s %s", code, resp.Msg, resp.SubCode, resp.SubMsg)
	}

	var bizResult map[string]any
	if len(resp.Data) > 0 {
		if err := json.Unmarshal(resp.Data, &bizResult); err != nil {
			return nil, fmt.Errorf("unmarshal response data: %w", err)
		}
	}
	return bizResult, nil
}

// buildRequestBody encodes bizParams as biz_content, signs the request
// envelope and returns the JSON request body.
func (a *Adapter) buildRequestBody(method string, bizParams map[string]any) ([]byte, error) {
	bizJSON, err := json.Marshal(bizParams)
	if err != nil {
		return nil, fmt.Errorf("marshal biz_content: %w", err)
	}
	req := heepayRequest{
		AppID:      a.config.MerchantID,
		Method:     method,
		Format:     "JSON",
		Charset:    "utf-8",
		SignType:   SignTypeRSA2,
		Timestamp:  time.Now().In(cst).Format("2006-01-02 15:04:05"),
		Version:    "1.0",
		BizContent: string(bizJSON),
	}
	// The signature covers every envelope field except sign itself.
	req.Sign, err = Sign(map[string]string{
		"app_id":      req.AppID,
		"method":      req.Method,
		"format":      req.Format,
		"charset":     req.Charset,
		"sign_type":   req.SignType,
		"timestamp":   req.Timestamp,
		"version":     req.Version,
		"biz_content": req.BizContent,
	}, a.config.PrivateKey)
	if err != nil {
		return nil, fmt.Errorf("sign request: %w", err)
	}
	body, err := json.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("marshal request: %w", err)
	}
	return body, nil
}

// doRequest POSTs body to gatewayURL and returns the full response body,
// rejecting non-2xx statuses and empty bodies.
func (a *Adapter) doRequest(ctx context.Context, gatewayURL, method string, body []byte) ([]byte, error) {
	slog.DebugContext(ctx, "heepay request", "method", method, "url", gatewayURL)
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, gatewayURL, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("build heepay http request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")

	httpResp, err := a.client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("heepay http request: %w", err)
	}
	defer httpResp.Body.Close()

	respBytes, err := io.ReadAll(httpResp.Body)
	if err != nil {
		return nil, fmt.Errorf("read heepay response: %w", err)
	}
	slog.DebugContext(ctx, "heepay response", "method", method, "status", httpResp.StatusCode, "body", string(respBytes))

	// Surface gateway/proxy failures with their status instead of a confusing
	// JSON or signature error further down. %.512q caps the echoed body.
	if httpResp.StatusCode < 200 || httpResp.StatusCode > 299 {
		return nil, fmt.Errorf("heepay http status %d from %s (method=%s): %.512q",
			httpResp.StatusCode, gatewayURL, method, string(respBytes))
	}
	if len(respBytes) == 0 {
		return nil, fmt.Errorf("heepay empty response from %s (method=%s)", gatewayURL, method)
	}
	return respBytes, nil
}

// verifyResponseSign checks the signature of a raw gateway response. Every
// top-level field except "sign" whose value is neither null nor "" takes
// part; string values are JSON-unescaped, other values (numbers, objects,
// arrays) are used as their raw JSON text.
func (a *Adapter) verifyResponseSign(respBytes []byte) error {
	var rawMap map[string]json.RawMessage
	if err := json.Unmarshal(respBytes, &rawMap); err != nil {
		return fmt.Errorf("unmarshal response: %w", err)
	}
	verifyParams := make(map[string]string, len(rawMap))
	for k, v := range rawMap {
		if k == "sign" {
			continue
		}
		if raw := string(v); raw == "null" || raw == `""` {
			continue
		}
		verifyParams[k] = jsonValueToSignString(v)
	}
	// Unescape the sign too: base64 contains '/', which some JSON encoders
	// emit as "\/"; plain quote-trimming would leave the backslashes in.
	signB64 := jsonValueToSignString(rawMap["sign"])
	if err := VerifyResponse(verifyParams, signB64, a.config.PublicKey); err != nil {
		return fmt.Errorf("verify response sign: %w", err)
	}
	return nil
}

// jsonValueToSignString returns the decoded text of a JSON string value, or
// the raw JSON text for any other value (a missing value yields "").
func jsonValueToSignString(v json.RawMessage) string {
	raw := string(v)
	if strings.HasPrefix(raw, `"`) {
		var s string
		if err := json.Unmarshal(v, &s); err == nil {
			return s
		}
	}
	return raw
}
// mapPayMethod translates an internal pay method into heepay's trade_type.
//
// A switch replaces the former map literal, which was allocated and filled on
// every call; the mapping and the error for unsupported methods are unchanged.
func mapPayMethod(m model.PayMethod) (string, error) {
	switch m {
	case model.PayMethodWechatJSAPI:
		return "JSAPI", nil
	case model.PayMethodWechatH5:
		return "MWEB", nil
	case model.PayMethodWechatNative:
		return "NATIVE", nil
	case model.PayMethodWechatMini:
		return "MINIAPP", nil
	case model.PayMethodAlipay:
		return "ALI_NATIVE", nil
	case model.PayMethodQuickPay:
		return "QUICK", nil
	default:
		return "", fmt.Errorf("unsupported pay method: %s", m)
	}
}
// strVal returns m[key] rendered as a string, or "" when the key is absent
// or holds JSON null (m may be nil).
//
// Maps decoded by encoding/json hold string, float64, bool, nil, []any or
// map[string]any values. Integral float64 values (IDs, amounts in cents) are
// rendered as plain integers because %v switches to exponent notation for
// large numbers (e.g. "1.234567890123e+12"). Other values use %v.
func strVal(m map[string]any, key string) string {
	v, ok := m[key]
	if !ok || v == nil {
		return ""
	}
	switch x := v.(type) {
	case string:
		return x
	case float64:
		// The range guard keeps the int64 conversion well-defined.
		if x >= -(1<<63) && x < 1<<63 && x == float64(int64(x)) {
			return fmt.Sprintf("%d", int64(x))
		}
	}
	return fmt.Sprintf("%v", v)
}