package heepay import ( "bytes" "context" "crypto/md5" "encoding/base64" "encoding/hex" "encoding/json" "fmt" "io" "log/slog" "net/http" "strings" "time" "pay-bridge/internal/channel" "pay-bridge/internal/model" ) const ( ChannelCode = "HEEPAY" codeSuccess = "10000" ) // cst 中国标准时间(UTC+8),避免依赖系统 time.Local var cst = time.FixedZone("CST", 8*3600) // Adapter 汇元支付适配器 type Adapter struct { config *model.ChannelConfig payURL string // 支付网关地址 merchantURL string // 进件网关地址 client *http.Client } func New(config *model.ChannelConfig, urls channel.URLs) channel.PaymentChannel { return &Adapter{ config: config, payURL: urls.PayURL, merchantURL: urls.MerchantURL, client: &http.Client{Timeout: 30 * time.Second}, } } func init() { channel.Register(ChannelCode, New) } func (a *Adapter) Code() string { return ChannelCode } // CreateOrder 统一下单 func (a *Adapter) CreateOrder(ctx context.Context, req *channel.CreateOrderReq) (*channel.CreateOrderResp, error) { tradeType, err := mapPayMethod(req.PayMethod) if err != nil { return nil, err } biz := map[string]any{ "out_trade_no": req.TradeNo, "body": req.Subject, "total_fee": req.Amount, "notify_url": req.NotifyURL, "trade_type": tradeType, } if req.Extra != nil { if openid, ok := req.Extra["openid"].(string); ok { biz["openid"] = openid } if subAppID, ok := req.Extra["sub_appid"].(string); ok { biz["sub_appid"] = subAppID } } resp, err := a.post(ctx, "pay.heepay.trade.create", biz) if err != nil { return nil, err } channelTradeNo, _ := resp["transaction_id"].(string) if channelTradeNo == "" { channelTradeNo, _ = resp["prepay_id"].(string) } raw, _ := json.Marshal(resp) return &channel.CreateOrderResp{ ChannelTradeNo: channelTradeNo, PayCredential: resp, RawResponse: raw, }, nil } // QueryOrder 查询订单 func (a *Adapter) QueryOrder(ctx context.Context, req *channel.QueryOrderReq) (*channel.QueryOrderResp, error) { biz := map[string]any{ "out_trade_no": req.TradeNo, } if req.ChannelTradeNo != "" { biz["transaction_id"] = req.ChannelTradeNo } resp, err := a.post(ctx, "pay.heepay.trade.query", biz) if err != nil { return nil, err } result := &channel.QueryOrderResp{ TradeNo: req.TradeNo, ChannelTradeNo: strVal(resp, "transaction_id"), } switch strVal(resp, "trade_state") { case "SUCCESS": result.Status = model.TradeStatusPaid if s := strVal(resp, "pay_time"); s != "" { t, _ := time.ParseInLocation("2006-01-02 15:04:05", s, cst) result.PayTime = &t } case "CLOSED", "REVOKED": result.Status = model.TradeStatusClosed case "PAYERROR": result.Status = model.TradeStatusFailed default: result.Status = model.TradeStatusPaying } if v, ok := resp["total_fee"].(float64); ok { result.Amount = int64(v) } return result, nil } // CloseOrder 关闭订单 func (a *Adapter) CloseOrder(ctx context.Context, req *channel.CloseOrderReq) error { biz := map[string]any{ "out_trade_no": req.TradeNo, } _, err := a.post(ctx, "pay.heepay.trade.close", biz) return err } // Refund 发起退款 func (a *Adapter) Refund(ctx context.Context, req *channel.RefundReq) (*channel.RefundResp, error) { biz := map[string]any{ "out_trade_no": req.TradeNo, "out_refund_no": req.RefundNo, "total_fee": req.TotalAmount, "refund_fee": req.RefundAmount, "refund_desc": req.Reason, "notify_url": req.NotifyURL, } if req.ChannelTradeNo != "" { biz["transaction_id"] = req.ChannelTradeNo } resp, err := a.post(ctx, "pay.heepay.trade.refund", biz) if err != nil { return nil, err } return &channel.RefundResp{ RefundNo: req.RefundNo, ChannelRefundNo: strVal(resp, "refund_id"), Status: model.RefundStatusProcessing, }, nil } // QueryRefund 查询退款 func (a *Adapter) QueryRefund(ctx context.Context, req *channel.QueryRefundReq) (*channel.QueryRefundResp, error) { biz := map[string]any{ "out_refund_no": req.RefundNo, } if req.ChannelRefundNo != "" { biz["refund_id"] = req.ChannelRefundNo } resp, err := a.post(ctx, "pay.heepay.trade.refund.query", biz) if err != nil { return nil, err } result := &channel.QueryRefundResp{ RefundNo: req.RefundNo, ChannelRefundNo: strVal(resp, "refund_id"), } switch strVal(resp, "refund_status") { case "SUCCESS": result.Status = model.RefundStatusSuccess if s := strVal(resp, "refund_success_time"); s != "" { t, _ := time.ParseInLocation("2006-01-02 15:04:05", s, cst) result.RefundTime = &t } case "PROCESSING": result.Status = model.RefundStatusProcessing case "FAIL": result.Status = model.RefundStatusFailed default: result.Status = model.RefundStatusPending } if v, ok := resp["refund_fee"].(float64); ok { result.RefundAmount = int64(v) } return result, nil } // ExtractTradeNo 从回调 body 中提取平台交易号(在验签前调用,仅做 JSON 解析) func (a *Adapter) ExtractTradeNo(rawBody []byte) (string, error) { var outer struct { Data json.RawMessage `json:"data"` } if err := json.Unmarshal(rawBody, &outer); err != nil { return "", fmt.Errorf("heepay ExtractTradeNo: unmarshal body: %w", err) } var bizData struct { OutTradeNo string `json:"out_trade_no"` } if err := json.Unmarshal(outer.Data, &bizData); err != nil { return "", fmt.Errorf("heepay ExtractTradeNo: unmarshal data: %w", err) } if bizData.OutTradeNo == "" { return "", fmt.Errorf("heepay ExtractTradeNo: out_trade_no is empty") } return bizData.OutTradeNo, nil } // VerifyNotify 验证上游回调签名并解析 // 汇元回调的签名规则与响应验签相同:公共参数+data整体 按字典序排列,用汇元公钥验签 func (a *Adapter) VerifyNotify(ctx context.Context, rawBody []byte, headers map[string]string) (*channel.NotifyData, error) { // 解析外层公共参数 var outer struct { Code string `json:"code"` Msg string `json:"msg"` TradeID string `json:"trade_id"` Sign string `json:"sign"` Data json.RawMessage `json:"data"` } if err := json.Unmarshal(rawBody, &outer); err != nil { return nil, fmt.Errorf("unmarshal notify body: %w", err) } // 验签(data 整体作为字符串参与验签) verifyParams := map[string]string{ "code": outer.Code, "msg": outer.Msg, "trade_id": outer.TradeID, "data": string(outer.Data), } if err := VerifyResponse(verifyParams, outer.Sign, a.config.PublicKey); err != nil { return nil, fmt.Errorf("verify notify sign: %w", err) } if outer.Code != codeSuccess { return nil, fmt.Errorf("notify code not success: %s %s", outer.Code, outer.Msg) } // 解析业务数据 var bizData map[string]any if err := json.Unmarshal(outer.Data, &bizData); err != nil { return nil, fmt.Errorf("unmarshal notify data: %w", err) } notifyType := strVal(bizData, "notify_type") result := &channel.NotifyData{ TradeNo: strVal(bizData, "out_trade_no"), ChannelTradeNo: strVal(bizData, "transaction_id"), RawData: rawBody, } switch notifyType { case "payment": result.NotifyType = model.NotifyTypePayment result.Status = model.TradeStatusPaid if v, ok := bizData["total_fee"].(float64); ok { result.Amount = int64(v) } if s := strVal(bizData, "pay_time"); s != "" { t, _ := time.ParseInLocation("2006-01-02 15:04:05", s, cst) result.PayTime = &t } case "refund": result.NotifyType = model.NotifyTypeRefund result.RefundNo = strVal(bizData, "out_refund_no") result.RefundStatus = model.RefundStatusSuccess if v, ok := bizData["refund_fee"].(float64); ok { result.RefundAmount = int64(v) } default: return nil, fmt.Errorf("unknown notify_type: %s", notifyType) } return result, nil } // ProfitSharing 分账 func (a *Adapter) ProfitSharing(ctx context.Context, req *channel.ProfitSharingReq) (*channel.ProfitSharingResp, error) { biz := map[string]any{ "transaction_id": req.ChannelTradeNo, "out_order_no": req.SharingNo, "receivers": []map[string]any{ { "type": "MERCHANT_ID", "account": req.ReceiverMerchantID, "amount": req.Amount, "description": "分润", }, }, } resp, err := a.post(ctx, "pay.heepay.trade.profitsharing", biz) if err != nil { return nil, err } return &channel.ProfitSharingResp{ SharingNo: req.SharingNo, ChannelSharingNo: strVal(resp, "order_id"), }, nil } // RollbackProfitSharing 回退分账 func (a *Adapter) RollbackProfitSharing(ctx context.Context, req *channel.RollbackSharingReq) error { biz := map[string]any{ "transaction_id": req.TradeNo, "out_order_no": req.SharingNo, "out_return_no": req.SharingNo + "_R", "return_account": a.config.MerchantID, "return_amount": 0, "description": "退款回退", } _, err := a.post(ctx, "pay.heepay.trade.profitsharing.return", biz) return err } // DownloadBill 下载对账账单 func (a *Adapter) DownloadBill(ctx context.Context, req *channel.DownloadBillReq) (*channel.BillData, error) { biz := map[string]any{ "bill_date": req.BillDate, "bill_type": "ALL", } resp, err := a.post(ctx, "pay.heepay.trade.bill.download", biz) if err != nil { return nil, err } // 账单为 CSV 内容,暂存 raw 由对账服务解析 _ = resp return &channel.BillData{}, nil } // UploadFile 上传文件到汇元进件网关(customer.file.upload) func (a *Adapter) UploadFile(ctx context.Context, req *channel.UploadFileReq) (*channel.UploadFileResp, error) { // 计算 MD5 签名 h := md5.Sum(req.FileContent) fileSign := hex.EncodeToString(h[:]) // Base64 编码文件内容 fileContentB64 := base64.StdEncoding.EncodeToString(req.FileContent) biz := map[string]any{ "file_content": fileContentB64, "file_sign": fileSign, "file_media_type": req.FileMediaType, } resp, err := a.postToMerchant(ctx, "customer.file.upload", biz) if err != nil { return nil, err } fileID := strVal(resp, "file_id") if fileID == "" { return nil, fmt.Errorf("heepay upload file: empty file_id in response") } return &channel.UploadFileResp{FileID: fileID}, nil } // MerchantApply 企业入网申请(customer.enter.enterprise.apply) func (a *Adapter) MerchantApply(ctx context.Context, req *channel.MerchantApplyReq) (*channel.MerchantApplyResp, error) { resp, err := a.postToMerchant(ctx, "customer.enter.enterprise.apply", req.BizContent) if err != nil { return nil, err } return &channel.MerchantApplyResp{ RequestNo: strVal(resp, "request_no"), }, nil } // QueryMerchantStatus 查询商户状态(customer.enter.query) func (a *Adapter) QueryMerchantStatus(ctx context.Context, channelMerchantID string) (*channel.MerchantStatusResp, error) { biz := map[string]any{ "request_no": channelMerchantID, } resp, err := a.postToMerchant(ctx, "customer.enter.query", biz) if err != nil { return nil, err } return &channel.MerchantStatusResp{ ChannelMerchantID: strVal(resp, "merch_id"), Status: strVal(resp, "audit_state"), FailReason: strVal(resp, "audit_reason"), }, nil } // --- 底层通信 --- // heepayRequest 公共请求结构 type heepayRequest struct { AppID string `json:"app_id"` Method string `json:"method"` Format string `json:"format"` Charset string `json:"charset"` SignType string `json:"sign_type"` Timestamp string `json:"timestamp"` Version string `json:"version"` BizContent string `json:"biz_content"` Sign string `json:"sign"` } // heepayResponse 公共响应结构 // 注意:支付网关 code 为字符串("10000"),进件网关 code 为数字(10000), // 使用 json.RawMessage 兼容,再通过 codeStr() 统一转为字符串比较。 type heepayResponse struct { Code json.RawMessage `json:"code"` Msg string `json:"msg"` SubCode string `json:"sub_code"` SubMsg string `json:"sub_msg"` Sign string `json:"sign"` TradeID string `json:"trade_id"` Data json.RawMessage `json:"data"` } // codeStr 将 code 字段统一转为字符串(兼容字符串和数字两种格式) func (r *heepayResponse) codeStr() string { if len(r.Code) == 0 { return "" } // 去掉引号(字符串形式)或直接返回数字字符串 raw := string(r.Code) if len(raw) >= 2 && raw[0] == '"' { return raw[1 : len(raw)-1] } return raw } // post 调用汇元支付网关(payURL) func (a *Adapter) post(ctx context.Context, method string, bizParams map[string]any) (map[string]any, error) { return a.postURL(ctx, a.payURL, method, bizParams) } // postToMerchant 调用汇元进件网关(merchantURL) func (a *Adapter) postToMerchant(ctx context.Context, method string, bizParams map[string]any) (map[string]any, error) { return a.postURL(ctx, a.merchantURL, method, bizParams) } // post 调用汇元支付网关(payURL) // 注意:原 post 方法内部调用 postURL func (a *Adapter) postURL(ctx context.Context, gatewayURL, method string, bizParams map[string]any) (map[string]any, error) { bizJSON, err := json.Marshal(bizParams) if err != nil { return nil, fmt.Errorf("marshal biz_content: %w", err) } bizContent := string(bizJSON) timestamp := time.Now().In(cst).Format("2006-01-02 15:04:05") signParams := map[string]string{ "app_id": a.config.MerchantID, "method": method, "format": "JSON", "charset": "utf-8", "sign_type": SignTypeRSA2, "timestamp": timestamp, "version": "1.0", "biz_content": bizContent, } sign, err := Sign(signParams, a.config.PrivateKey) if err != nil { return nil, fmt.Errorf("sign request: %w", err) } reqBody := heepayRequest{ AppID: a.config.MerchantID, Method: method, Format: "JSON", Charset: "utf-8", SignType: SignTypeRSA2, Timestamp: timestamp, Version: "1.0", BizContent: bizContent, Sign: sign, } bodyBytes, err := json.Marshal(reqBody) if err != nil { return nil, err } slog.DebugContext(ctx, "heepay request", "method", method, "url", gatewayURL) httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, gatewayURL, bytes.NewReader(bodyBytes)) if err != nil { return nil, err } httpReq.Header.Set("Content-Type", "application/json") httpResp, err := a.client.Do(httpReq) if err != nil { return nil, fmt.Errorf("heepay http request: %w", err) } defer httpResp.Body.Close() respBytes, err := io.ReadAll(httpResp.Body) if err != nil { return nil, err } slog.DebugContext(ctx, "heepay response", "method", method, "body", string(respBytes)) if len(respBytes) == 0 { return nil, fmt.Errorf("heepay empty response from %s (method=%s)", gatewayURL, method) } // 先解析为 raw map,确保所有字段都能参与验签 var rawMap map[string]json.RawMessage if err := json.Unmarshal(respBytes, &rawMap); err != nil { return nil, fmt.Errorf("unmarshal response: %w", err) } signRaw, _ := rawMap["sign"] signB64 := strings.Trim(string(signRaw), `"`) // 构建验签参数:所有非 sign、非空字段 verifyParams := make(map[string]string, len(rawMap)) for k, v := range rawMap { if k == "sign" { continue } raw := string(v) if raw == "null" || raw == `""` { continue } // 字符串类型去引号并反转义,数字/对象/数组直接用 raw 值 if len(raw) >= 2 && raw[0] == '"' { var s string json.Unmarshal(v, &s) verifyParams[k] = s } else { verifyParams[k] = raw } } if err := VerifyResponse(verifyParams, signB64, a.config.PublicKey); err != nil { return nil, fmt.Errorf("verify response sign: %w", err) } // 再用强类型 struct 方便后续处理 var resp heepayResponse json.Unmarshal(respBytes, &resp) codeStr := resp.codeStr() if codeStr != codeSuccess { return nil, fmt.Errorf("heepay error [%s] %s: %s %s", codeStr, resp.Msg, resp.SubCode, resp.SubMsg) } var bizResult map[string]any if len(resp.Data) > 0 { if err := json.Unmarshal(resp.Data, &bizResult); err != nil { return nil, fmt.Errorf("unmarshal response data: %w", err) } } return bizResult, nil } // mapPayMethod 将内部支付方式映射为汇元 trade_type func mapPayMethod(m model.PayMethod) (string, error) { mapping := map[model.PayMethod]string{ model.PayMethodWechatJSAPI: "JSAPI", model.PayMethodWechatH5: "MWEB", model.PayMethodWechatNative: "NATIVE", model.PayMethodWechatMini: "MINIAPP", model.PayMethodAlipay: "ALI_NATIVE", model.PayMethodQuickPay: "QUICK", } t, ok := mapping[m] if !ok { return "", fmt.Errorf("unsupported pay method: %s", m) } return t, nil } func strVal(m map[string]any, key string) string { if v, ok := m[key]; ok { if s, ok := v.(string); ok { return s } return fmt.Sprintf("%v", v) } return "" }