598 lines
16 KiB
Go
598 lines
16 KiB
Go
package heepay
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"crypto/md5"
|
||
"encoding/base64"
|
||
"encoding/hex"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"log/slog"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
"pay-bridge/internal/channel"
|
||
"pay-bridge/internal/model"
|
||
)
|
||
|
||
// ChannelCode is the code under which this adapter is registered with the
// channel package and reported by Adapter.Code.
const ChannelCode = "HEEPAY"

// codeSuccess is the Heepay response/notification code that marks a
// successful call; any other code is turned into an error.
const codeSuccess = "10000"
|
||
|
||
// cst 中国标准时间(UTC+8),避免依赖系统 time.Local
|
||
var cst = time.FixedZone("CST", 8*3600)
|
||
|
||
// Adapter is the Heepay payment-channel adapter. Trade APIs go to payURL and
// merchant-onboarding APIs go to merchantURL; both share the same signing
// scheme and HTTP client.
type Adapter struct {
	config      *model.ChannelConfig // channel credentials; MerchantID, PrivateKey and PublicKey are read in this file
	payURL      string               // payment gateway endpoint
	merchantURL string               // merchant onboarding gateway endpoint
	client      *http.Client         // HTTP client used for every gateway call
}
|
||
|
||
func New(config *model.ChannelConfig, urls channel.URLs) channel.PaymentChannel {
|
||
return &Adapter{
|
||
config: config,
|
||
payURL: urls.PayURL,
|
||
merchantURL: urls.MerchantURL,
|
||
client: &http.Client{Timeout: 30 * time.Second},
|
||
}
|
||
}
|
||
|
||
// init registers New as the factory for ChannelCode with the channel package.
func init() {
	channel.Register(ChannelCode, New)
}
|
||
|
||
// Code reports the channel code this adapter was registered under.
func (a *Adapter) Code() string {
	return ChannelCode
}
|
||
|
||
// CreateOrder 统一下单
|
||
func (a *Adapter) CreateOrder(ctx context.Context, req *channel.CreateOrderReq) (*channel.CreateOrderResp, error) {
|
||
tradeType, err := mapPayMethod(req.PayMethod)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
biz := map[string]any{
|
||
"out_trade_no": req.TradeNo,
|
||
"body": req.Subject,
|
||
"total_fee": req.Amount,
|
||
"notify_url": req.NotifyURL,
|
||
"trade_type": tradeType,
|
||
}
|
||
if req.Extra != nil {
|
||
if openid, ok := req.Extra["openid"].(string); ok {
|
||
biz["openid"] = openid
|
||
}
|
||
if subAppID, ok := req.Extra["sub_appid"].(string); ok {
|
||
biz["sub_appid"] = subAppID
|
||
}
|
||
}
|
||
|
||
resp, err := a.post(ctx, "pay.heepay.trade.create", biz)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
channelTradeNo, _ := resp["transaction_id"].(string)
|
||
if channelTradeNo == "" {
|
||
channelTradeNo, _ = resp["prepay_id"].(string)
|
||
}
|
||
|
||
raw, _ := json.Marshal(resp)
|
||
return &channel.CreateOrderResp{
|
||
ChannelTradeNo: channelTradeNo,
|
||
PayCredential: resp,
|
||
RawResponse: raw,
|
||
}, nil
|
||
}
|
||
|
||
// QueryOrder 查询订单
|
||
func (a *Adapter) QueryOrder(ctx context.Context, req *channel.QueryOrderReq) (*channel.QueryOrderResp, error) {
|
||
biz := map[string]any{
|
||
"out_trade_no": req.TradeNo,
|
||
}
|
||
if req.ChannelTradeNo != "" {
|
||
biz["transaction_id"] = req.ChannelTradeNo
|
||
}
|
||
|
||
resp, err := a.post(ctx, "pay.heepay.trade.query", biz)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
result := &channel.QueryOrderResp{
|
||
TradeNo: req.TradeNo,
|
||
ChannelTradeNo: strVal(resp, "transaction_id"),
|
||
}
|
||
|
||
switch strVal(resp, "trade_state") {
|
||
case "SUCCESS":
|
||
result.Status = model.TradeStatusPaid
|
||
if s := strVal(resp, "pay_time"); s != "" {
|
||
t, _ := time.ParseInLocation("2006-01-02 15:04:05", s, cst)
|
||
result.PayTime = &t
|
||
}
|
||
case "CLOSED", "REVOKED":
|
||
result.Status = model.TradeStatusClosed
|
||
case "PAYERROR":
|
||
result.Status = model.TradeStatusFailed
|
||
default:
|
||
result.Status = model.TradeStatusPaying
|
||
}
|
||
|
||
if v, ok := resp["total_fee"].(float64); ok {
|
||
result.Amount = int64(v)
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// CloseOrder 关闭订单
|
||
func (a *Adapter) CloseOrder(ctx context.Context, req *channel.CloseOrderReq) error {
|
||
biz := map[string]any{
|
||
"out_trade_no": req.TradeNo,
|
||
}
|
||
_, err := a.post(ctx, "pay.heepay.trade.close", biz)
|
||
return err
|
||
}
|
||
|
||
// Refund 发起退款
|
||
func (a *Adapter) Refund(ctx context.Context, req *channel.RefundReq) (*channel.RefundResp, error) {
|
||
biz := map[string]any{
|
||
"out_trade_no": req.TradeNo,
|
||
"out_refund_no": req.RefundNo,
|
||
"total_fee": req.TotalAmount,
|
||
"refund_fee": req.RefundAmount,
|
||
"refund_desc": req.Reason,
|
||
"notify_url": req.NotifyURL,
|
||
}
|
||
if req.ChannelTradeNo != "" {
|
||
biz["transaction_id"] = req.ChannelTradeNo
|
||
}
|
||
|
||
resp, err := a.post(ctx, "pay.heepay.trade.refund", biz)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return &channel.RefundResp{
|
||
RefundNo: req.RefundNo,
|
||
ChannelRefundNo: strVal(resp, "refund_id"),
|
||
Status: model.RefundStatusProcessing,
|
||
}, nil
|
||
}
|
||
|
||
// QueryRefund 查询退款
|
||
func (a *Adapter) QueryRefund(ctx context.Context, req *channel.QueryRefundReq) (*channel.QueryRefundResp, error) {
|
||
biz := map[string]any{
|
||
"out_refund_no": req.RefundNo,
|
||
}
|
||
if req.ChannelRefundNo != "" {
|
||
biz["refund_id"] = req.ChannelRefundNo
|
||
}
|
||
|
||
resp, err := a.post(ctx, "pay.heepay.trade.refund.query", biz)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
result := &channel.QueryRefundResp{
|
||
RefundNo: req.RefundNo,
|
||
ChannelRefundNo: strVal(resp, "refund_id"),
|
||
}
|
||
|
||
switch strVal(resp, "refund_status") {
|
||
case "SUCCESS":
|
||
result.Status = model.RefundStatusSuccess
|
||
if s := strVal(resp, "refund_success_time"); s != "" {
|
||
t, _ := time.ParseInLocation("2006-01-02 15:04:05", s, cst)
|
||
result.RefundTime = &t
|
||
}
|
||
case "PROCESSING":
|
||
result.Status = model.RefundStatusProcessing
|
||
case "FAIL":
|
||
result.Status = model.RefundStatusFailed
|
||
default:
|
||
result.Status = model.RefundStatusPending
|
||
}
|
||
|
||
if v, ok := resp["refund_fee"].(float64); ok {
|
||
result.RefundAmount = int64(v)
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// ExtractTradeNo 从回调 body 中提取平台交易号(在验签前调用,仅做 JSON 解析)
|
||
func (a *Adapter) ExtractTradeNo(rawBody []byte) (string, error) {
|
||
var outer struct {
|
||
Data json.RawMessage `json:"data"`
|
||
}
|
||
if err := json.Unmarshal(rawBody, &outer); err != nil {
|
||
return "", fmt.Errorf("heepay ExtractTradeNo: unmarshal body: %w", err)
|
||
}
|
||
var bizData struct {
|
||
OutTradeNo string `json:"out_trade_no"`
|
||
}
|
||
if err := json.Unmarshal(outer.Data, &bizData); err != nil {
|
||
return "", fmt.Errorf("heepay ExtractTradeNo: unmarshal data: %w", err)
|
||
}
|
||
if bizData.OutTradeNo == "" {
|
||
return "", fmt.Errorf("heepay ExtractTradeNo: out_trade_no is empty")
|
||
}
|
||
return bizData.OutTradeNo, nil
|
||
}
|
||
|
||
// VerifyNotify 验证上游回调签名并解析
|
||
// 汇元回调的签名规则与响应验签相同:公共参数+data整体 按字典序排列,用汇元公钥验签
|
||
func (a *Adapter) VerifyNotify(ctx context.Context, rawBody []byte, headers map[string]string) (*channel.NotifyData, error) {
|
||
// 解析外层公共参数
|
||
var outer struct {
|
||
Code string `json:"code"`
|
||
Msg string `json:"msg"`
|
||
TradeID string `json:"trade_id"`
|
||
Sign string `json:"sign"`
|
||
Data json.RawMessage `json:"data"`
|
||
}
|
||
if err := json.Unmarshal(rawBody, &outer); err != nil {
|
||
return nil, fmt.Errorf("unmarshal notify body: %w", err)
|
||
}
|
||
|
||
// 验签(data 整体作为字符串参与验签)
|
||
verifyParams := map[string]string{
|
||
"code": outer.Code,
|
||
"msg": outer.Msg,
|
||
"trade_id": outer.TradeID,
|
||
"data": string(outer.Data),
|
||
}
|
||
if err := VerifyResponse(verifyParams, outer.Sign, a.config.PublicKey); err != nil {
|
||
return nil, fmt.Errorf("verify notify sign: %w", err)
|
||
}
|
||
|
||
if outer.Code != codeSuccess {
|
||
return nil, fmt.Errorf("notify code not success: %s %s", outer.Code, outer.Msg)
|
||
}
|
||
|
||
// 解析业务数据
|
||
var bizData map[string]any
|
||
if err := json.Unmarshal(outer.Data, &bizData); err != nil {
|
||
return nil, fmt.Errorf("unmarshal notify data: %w", err)
|
||
}
|
||
|
||
notifyType := strVal(bizData, "notify_type")
|
||
result := &channel.NotifyData{
|
||
TradeNo: strVal(bizData, "out_trade_no"),
|
||
ChannelTradeNo: strVal(bizData, "transaction_id"),
|
||
RawData: rawBody,
|
||
}
|
||
|
||
switch notifyType {
|
||
case "payment":
|
||
result.NotifyType = model.NotifyTypePayment
|
||
result.Status = model.TradeStatusPaid
|
||
if v, ok := bizData["total_fee"].(float64); ok {
|
||
result.Amount = int64(v)
|
||
}
|
||
if s := strVal(bizData, "pay_time"); s != "" {
|
||
t, _ := time.ParseInLocation("2006-01-02 15:04:05", s, cst)
|
||
result.PayTime = &t
|
||
}
|
||
case "refund":
|
||
result.NotifyType = model.NotifyTypeRefund
|
||
result.RefundNo = strVal(bizData, "out_refund_no")
|
||
result.RefundStatus = model.RefundStatusSuccess
|
||
if v, ok := bizData["refund_fee"].(float64); ok {
|
||
result.RefundAmount = int64(v)
|
||
}
|
||
default:
|
||
return nil, fmt.Errorf("unknown notify_type: %s", notifyType)
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// ProfitSharing 分账
|
||
func (a *Adapter) ProfitSharing(ctx context.Context, req *channel.ProfitSharingReq) (*channel.ProfitSharingResp, error) {
|
||
biz := map[string]any{
|
||
"transaction_id": req.ChannelTradeNo,
|
||
"out_order_no": req.SharingNo,
|
||
"receivers": []map[string]any{
|
||
{
|
||
"type": "MERCHANT_ID",
|
||
"account": req.ReceiverMerchantID,
|
||
"amount": req.Amount,
|
||
"description": "分润",
|
||
},
|
||
},
|
||
}
|
||
resp, err := a.post(ctx, "pay.heepay.trade.profitsharing", biz)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &channel.ProfitSharingResp{
|
||
SharingNo: req.SharingNo,
|
||
ChannelSharingNo: strVal(resp, "order_id"),
|
||
}, nil
|
||
}
|
||
|
||
// RollbackProfitSharing 回退分账
|
||
func (a *Adapter) RollbackProfitSharing(ctx context.Context, req *channel.RollbackSharingReq) error {
|
||
biz := map[string]any{
|
||
"transaction_id": req.TradeNo,
|
||
"out_order_no": req.SharingNo,
|
||
"out_return_no": req.SharingNo + "_R",
|
||
"return_account": a.config.MerchantID,
|
||
"return_amount": 0,
|
||
"description": "退款回退",
|
||
}
|
||
_, err := a.post(ctx, "pay.heepay.trade.profitsharing.return", biz)
|
||
return err
|
||
}
|
||
|
||
// DownloadBill 下载对账账单
|
||
func (a *Adapter) DownloadBill(ctx context.Context, req *channel.DownloadBillReq) (*channel.BillData, error) {
|
||
biz := map[string]any{
|
||
"bill_date": req.BillDate,
|
||
"bill_type": "ALL",
|
||
}
|
||
resp, err := a.post(ctx, "pay.heepay.trade.bill.download", biz)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
// 账单为 CSV 内容,暂存 raw 由对账服务解析
|
||
_ = resp
|
||
return &channel.BillData{}, nil
|
||
}
|
||
|
||
// UploadFile 上传文件到汇元进件网关(customer.file.upload)
|
||
func (a *Adapter) UploadFile(ctx context.Context, req *channel.UploadFileReq) (*channel.UploadFileResp, error) {
|
||
// 计算 MD5 签名
|
||
h := md5.Sum(req.FileContent)
|
||
fileSign := hex.EncodeToString(h[:])
|
||
|
||
// Base64 编码文件内容
|
||
fileContentB64 := base64.StdEncoding.EncodeToString(req.FileContent)
|
||
|
||
biz := map[string]any{
|
||
"file_content": fileContentB64,
|
||
"file_sign": fileSign,
|
||
"file_media_type": req.FileMediaType,
|
||
}
|
||
|
||
resp, err := a.postToMerchant(ctx, "customer.file.upload", biz)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
fileID := strVal(resp, "file_id")
|
||
if fileID == "" {
|
||
return nil, fmt.Errorf("heepay upload file: empty file_id in response")
|
||
}
|
||
return &channel.UploadFileResp{FileID: fileID}, nil
|
||
}
|
||
|
||
// MerchantApply 企业入网申请(customer.enter.enterprise.apply)
|
||
func (a *Adapter) MerchantApply(ctx context.Context, req *channel.MerchantApplyReq) (*channel.MerchantApplyResp, error) {
|
||
resp, err := a.postToMerchant(ctx, "customer.enter.enterprise.apply", req.BizContent)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &channel.MerchantApplyResp{
|
||
RequestNo: strVal(resp, "request_no"),
|
||
}, nil
|
||
}
|
||
|
||
// QueryMerchantStatus 查询商户状态(customer.enter.query)
|
||
func (a *Adapter) QueryMerchantStatus(ctx context.Context, channelMerchantID string) (*channel.MerchantStatusResp, error) {
|
||
biz := map[string]any{
|
||
"request_no": channelMerchantID,
|
||
}
|
||
resp, err := a.postToMerchant(ctx, "customer.enter.query", biz)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &channel.MerchantStatusResp{
|
||
ChannelMerchantID: strVal(resp, "merch_id"),
|
||
Status: strVal(resp, "audit_state"),
|
||
FailReason: strVal(resp, "audit_reason"),
|
||
}, nil
|
||
}
|
||
|
||
// --- 底层通信 ---
|
||
|
||
// heepayRequest is the JSON body POSTed to either Heepay gateway. postURL
// signs every field except Sign; BizContent carries the method-specific
// parameters as a JSON-encoded string. Field order determines the key order
// of the encoded request body.
type heepayRequest struct {
	AppID      string `json:"app_id"`      // merchant ID (config.MerchantID in postURL)
	Method     string `json:"method"`      // API name, e.g. pay.heepay.trade.create
	Format     string `json:"format"`      // set to "JSON" by postURL
	Charset    string `json:"charset"`     // set to "utf-8" by postURL
	SignType   string `json:"sign_type"`   // set to SignTypeRSA2 by postURL
	Timestamp  string `json:"timestamp"`   // "2006-01-02 15:04:05" in UTC+8
	Version    string `json:"version"`     // set to "1.0" by postURL
	BizContent string `json:"biz_content"` // JSON-encoded business parameters
	Sign       string `json:"sign"`        // signature over the other fields
}
|
||
|
||
// heepayResponse is the common response envelope of both Heepay gateways.
//
// The payment gateway returns code as a JSON string ("10000"), while the
// onboarding gateway returns it as a JSON number (10000). Code is therefore
// kept as raw JSON and normalised by codeStr before comparison.
type heepayResponse struct {
	Code    json.RawMessage `json:"code"`
	Msg     string          `json:"msg"`
	SubCode string          `json:"sub_code"`
	SubMsg  string          `json:"sub_msg"`
	Sign    string          `json:"sign"`
	TradeID string          `json:"trade_id"`
	Data    json.RawMessage `json:"data"`
}

// codeStr normalises Code to a plain string so both gateway formats compare
// the same way:
//   - a JSON string is decoded (quotes removed, escapes resolved);
//   - JSON null or a missing field yields "";
//   - any other token, such as a number, is returned as its literal text.
func (r *heepayResponse) codeStr() string {
	if len(r.Code) == 0 {
		return ""
	}
	var s string
	if err := json.Unmarshal(r.Code, &s); err == nil {
		return s
	}
	// Not a JSON string (e.g. 10000): the literal token is the code.
	return string(r.Code)
}
|
||
|
||
// post 调用汇元支付网关(payURL)
|
||
func (a *Adapter) post(ctx context.Context, method string, bizParams map[string]any) (map[string]any, error) {
|
||
return a.postURL(ctx, a.payURL, method, bizParams)
|
||
}
|
||
|
||
// postToMerchant 调用汇元进件网关(merchantURL)
|
||
func (a *Adapter) postToMerchant(ctx context.Context, method string, bizParams map[string]any) (map[string]any, error) {
|
||
return a.postURL(ctx, a.merchantURL, method, bizParams)
|
||
}
|
||
|
||
// post 调用汇元支付网关(payURL)
|
||
// 注意:原 post 方法内部调用 postURL
|
||
func (a *Adapter) postURL(ctx context.Context, gatewayURL, method string, bizParams map[string]any) (map[string]any, error) {
|
||
bizJSON, err := json.Marshal(bizParams)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("marshal biz_content: %w", err)
|
||
}
|
||
bizContent := string(bizJSON)
|
||
|
||
timestamp := time.Now().In(cst).Format("2006-01-02 15:04:05")
|
||
signParams := map[string]string{
|
||
"app_id": a.config.MerchantID,
|
||
"method": method,
|
||
"format": "JSON",
|
||
"charset": "utf-8",
|
||
"sign_type": SignTypeRSA2,
|
||
"timestamp": timestamp,
|
||
"version": "1.0",
|
||
"biz_content": bizContent,
|
||
}
|
||
|
||
sign, err := Sign(signParams, a.config.PrivateKey)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("sign request: %w", err)
|
||
}
|
||
|
||
reqBody := heepayRequest{
|
||
AppID: a.config.MerchantID,
|
||
Method: method,
|
||
Format: "JSON",
|
||
Charset: "utf-8",
|
||
SignType: SignTypeRSA2,
|
||
Timestamp: timestamp,
|
||
Version: "1.0",
|
||
BizContent: bizContent,
|
||
Sign: sign,
|
||
}
|
||
|
||
bodyBytes, err := json.Marshal(reqBody)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
slog.DebugContext(ctx, "heepay request", "method", method, "url", gatewayURL)
|
||
|
||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, gatewayURL, bytes.NewReader(bodyBytes))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
httpReq.Header.Set("Content-Type", "application/json")
|
||
|
||
httpResp, err := a.client.Do(httpReq)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("heepay http request: %w", err)
|
||
}
|
||
defer httpResp.Body.Close()
|
||
|
||
respBytes, err := io.ReadAll(httpResp.Body)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
slog.DebugContext(ctx, "heepay response", "method", method, "body", string(respBytes))
|
||
|
||
if len(respBytes) == 0 {
|
||
return nil, fmt.Errorf("heepay empty response from %s (method=%s)", gatewayURL, method)
|
||
}
|
||
|
||
// 先解析为 raw map,确保所有字段都能参与验签
|
||
var rawMap map[string]json.RawMessage
|
||
if err := json.Unmarshal(respBytes, &rawMap); err != nil {
|
||
return nil, fmt.Errorf("unmarshal response: %w", err)
|
||
}
|
||
|
||
signRaw, _ := rawMap["sign"]
|
||
signB64 := strings.Trim(string(signRaw), `"`)
|
||
|
||
// 构建验签参数:所有非 sign、非空字段
|
||
verifyParams := make(map[string]string, len(rawMap))
|
||
for k, v := range rawMap {
|
||
if k == "sign" {
|
||
continue
|
||
}
|
||
raw := string(v)
|
||
if raw == "null" || raw == `""` {
|
||
continue
|
||
}
|
||
// 字符串类型去引号并反转义,数字/对象/数组直接用 raw 值
|
||
if len(raw) >= 2 && raw[0] == '"' {
|
||
var s string
|
||
json.Unmarshal(v, &s)
|
||
verifyParams[k] = s
|
||
} else {
|
||
verifyParams[k] = raw
|
||
}
|
||
}
|
||
|
||
if err := VerifyResponse(verifyParams, signB64, a.config.PublicKey); err != nil {
|
||
return nil, fmt.Errorf("verify response sign: %w", err)
|
||
}
|
||
|
||
// 再用强类型 struct 方便后续处理
|
||
var resp heepayResponse
|
||
json.Unmarshal(respBytes, &resp)
|
||
codeStr := resp.codeStr()
|
||
|
||
if codeStr != codeSuccess {
|
||
return nil, fmt.Errorf("heepay error [%s] %s: %s %s", codeStr, resp.Msg, resp.SubCode, resp.SubMsg)
|
||
}
|
||
|
||
var bizResult map[string]any
|
||
if len(resp.Data) > 0 {
|
||
if err := json.Unmarshal(resp.Data, &bizResult); err != nil {
|
||
return nil, fmt.Errorf("unmarshal response data: %w", err)
|
||
}
|
||
}
|
||
return bizResult, nil
|
||
}
|
||
|
||
// mapPayMethod 将内部支付方式映射为汇元 trade_type
|
||
func mapPayMethod(m model.PayMethod) (string, error) {
|
||
mapping := map[model.PayMethod]string{
|
||
model.PayMethodWechatJSAPI: "JSAPI",
|
||
model.PayMethodWechatH5: "MWEB",
|
||
model.PayMethodWechatNative: "NATIVE",
|
||
model.PayMethodWechatMini: "MINIAPP",
|
||
model.PayMethodAlipay: "ALI_NATIVE",
|
||
model.PayMethodQuickPay: "QUICK",
|
||
}
|
||
t, ok := mapping[m]
|
||
if !ok {
|
||
return "", fmt.Errorf("unsupported pay method: %s", m)
|
||
}
|
||
return t, nil
|
||
}
|
||
|
||
// strVal returns m[key] rendered as a string, or "" when the key is absent
// or its value is JSON null.
//
// Values decoded by encoding/json into map[string]any arrive as string,
// float64, bool, nil, []any or map[string]any. Strings are returned as-is.
// Numbers are rendered in plain decimal, so an ID 12345678 becomes
// "12345678" rather than fmt's "1.2345678e+07". json.Number values (if a
// caller ever decodes with UseNumber) keep their literal text. Everything
// else falls back to %v. Integers above 2^53 have already lost precision in
// float64 and cannot be recovered here.
func strVal(m map[string]any, key string) string {
	switch v := m[key].(type) {
	case nil:
		return ""
	case string:
		return v
	case float64:
		// encoding/json renders float64 without an exponent for magnitudes in
		// [1e-6, 1e21), which covers realistic amounts and numeric IDs.
		// Marshal only fails for NaN/Inf, which JSON decoding never produces.
		if b, err := json.Marshal(v); err == nil {
			return string(b)
		}
	case json.Number:
		return v.String()
	}
	return fmt.Sprintf("%v", m[key])
}
|