draft
This commit is contained in:
597
backend/internal/channel/heepay/adapter.go
Normal file
597
backend/internal/channel/heepay/adapter.go
Normal file
@@ -0,0 +1,597 @@
|
||||
package heepay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"pay-bridge/internal/channel"
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
const (
|
||||
ChannelCode = "HEEPAY"
|
||||
codeSuccess = "10000"
|
||||
)
|
||||
|
||||
// cst 中国标准时间(UTC+8),避免依赖系统 time.Local
|
||||
var cst = time.FixedZone("CST", 8*3600)
|
||||
|
||||
// Adapter 汇元支付适配器
|
||||
type Adapter struct {
|
||||
config *model.ChannelConfig
|
||||
payURL string // 支付网关地址
|
||||
merchantURL string // 进件网关地址
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func New(config *model.ChannelConfig, urls channel.URLs) channel.PaymentChannel {
|
||||
return &Adapter{
|
||||
config: config,
|
||||
payURL: urls.PayURL,
|
||||
merchantURL: urls.MerchantURL,
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
channel.Register(ChannelCode, New)
|
||||
}
|
||||
|
||||
func (a *Adapter) Code() string { return ChannelCode }
|
||||
|
||||
// CreateOrder 统一下单
|
||||
func (a *Adapter) CreateOrder(ctx context.Context, req *channel.CreateOrderReq) (*channel.CreateOrderResp, error) {
|
||||
tradeType, err := mapPayMethod(req.PayMethod)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
biz := map[string]any{
|
||||
"out_trade_no": req.TradeNo,
|
||||
"body": req.Subject,
|
||||
"total_fee": req.Amount,
|
||||
"notify_url": req.NotifyURL,
|
||||
"trade_type": tradeType,
|
||||
}
|
||||
if req.Extra != nil {
|
||||
if openid, ok := req.Extra["openid"].(string); ok {
|
||||
biz["openid"] = openid
|
||||
}
|
||||
if subAppID, ok := req.Extra["sub_appid"].(string); ok {
|
||||
biz["sub_appid"] = subAppID
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := a.post(ctx, "pay.heepay.trade.create", biz)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
channelTradeNo, _ := resp["transaction_id"].(string)
|
||||
if channelTradeNo == "" {
|
||||
channelTradeNo, _ = resp["prepay_id"].(string)
|
||||
}
|
||||
|
||||
raw, _ := json.Marshal(resp)
|
||||
return &channel.CreateOrderResp{
|
||||
ChannelTradeNo: channelTradeNo,
|
||||
PayCredential: resp,
|
||||
RawResponse: raw,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// QueryOrder 查询订单
|
||||
func (a *Adapter) QueryOrder(ctx context.Context, req *channel.QueryOrderReq) (*channel.QueryOrderResp, error) {
|
||||
biz := map[string]any{
|
||||
"out_trade_no": req.TradeNo,
|
||||
}
|
||||
if req.ChannelTradeNo != "" {
|
||||
biz["transaction_id"] = req.ChannelTradeNo
|
||||
}
|
||||
|
||||
resp, err := a.post(ctx, "pay.heepay.trade.query", biz)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &channel.QueryOrderResp{
|
||||
TradeNo: req.TradeNo,
|
||||
ChannelTradeNo: strVal(resp, "transaction_id"),
|
||||
}
|
||||
|
||||
switch strVal(resp, "trade_state") {
|
||||
case "SUCCESS":
|
||||
result.Status = model.TradeStatusPaid
|
||||
if s := strVal(resp, "pay_time"); s != "" {
|
||||
t, _ := time.ParseInLocation("2006-01-02 15:04:05", s, cst)
|
||||
result.PayTime = &t
|
||||
}
|
||||
case "CLOSED", "REVOKED":
|
||||
result.Status = model.TradeStatusClosed
|
||||
case "PAYERROR":
|
||||
result.Status = model.TradeStatusFailed
|
||||
default:
|
||||
result.Status = model.TradeStatusPaying
|
||||
}
|
||||
|
||||
if v, ok := resp["total_fee"].(float64); ok {
|
||||
result.Amount = int64(v)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// CloseOrder 关闭订单
|
||||
func (a *Adapter) CloseOrder(ctx context.Context, req *channel.CloseOrderReq) error {
|
||||
biz := map[string]any{
|
||||
"out_trade_no": req.TradeNo,
|
||||
}
|
||||
_, err := a.post(ctx, "pay.heepay.trade.close", biz)
|
||||
return err
|
||||
}
|
||||
|
||||
// Refund 发起退款
|
||||
func (a *Adapter) Refund(ctx context.Context, req *channel.RefundReq) (*channel.RefundResp, error) {
|
||||
biz := map[string]any{
|
||||
"out_trade_no": req.TradeNo,
|
||||
"out_refund_no": req.RefundNo,
|
||||
"total_fee": req.TotalAmount,
|
||||
"refund_fee": req.RefundAmount,
|
||||
"refund_desc": req.Reason,
|
||||
"notify_url": req.NotifyURL,
|
||||
}
|
||||
if req.ChannelTradeNo != "" {
|
||||
biz["transaction_id"] = req.ChannelTradeNo
|
||||
}
|
||||
|
||||
resp, err := a.post(ctx, "pay.heepay.trade.refund", biz)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &channel.RefundResp{
|
||||
RefundNo: req.RefundNo,
|
||||
ChannelRefundNo: strVal(resp, "refund_id"),
|
||||
Status: model.RefundStatusProcessing,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// QueryRefund 查询退款
|
||||
func (a *Adapter) QueryRefund(ctx context.Context, req *channel.QueryRefundReq) (*channel.QueryRefundResp, error) {
|
||||
biz := map[string]any{
|
||||
"out_refund_no": req.RefundNo,
|
||||
}
|
||||
if req.ChannelRefundNo != "" {
|
||||
biz["refund_id"] = req.ChannelRefundNo
|
||||
}
|
||||
|
||||
resp, err := a.post(ctx, "pay.heepay.trade.refund.query", biz)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &channel.QueryRefundResp{
|
||||
RefundNo: req.RefundNo,
|
||||
ChannelRefundNo: strVal(resp, "refund_id"),
|
||||
}
|
||||
|
||||
switch strVal(resp, "refund_status") {
|
||||
case "SUCCESS":
|
||||
result.Status = model.RefundStatusSuccess
|
||||
if s := strVal(resp, "refund_success_time"); s != "" {
|
||||
t, _ := time.ParseInLocation("2006-01-02 15:04:05", s, cst)
|
||||
result.RefundTime = &t
|
||||
}
|
||||
case "PROCESSING":
|
||||
result.Status = model.RefundStatusProcessing
|
||||
case "FAIL":
|
||||
result.Status = model.RefundStatusFailed
|
||||
default:
|
||||
result.Status = model.RefundStatusPending
|
||||
}
|
||||
|
||||
if v, ok := resp["refund_fee"].(float64); ok {
|
||||
result.RefundAmount = int64(v)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ExtractTradeNo 从回调 body 中提取平台交易号(在验签前调用,仅做 JSON 解析)
|
||||
func (a *Adapter) ExtractTradeNo(rawBody []byte) (string, error) {
|
||||
var outer struct {
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(rawBody, &outer); err != nil {
|
||||
return "", fmt.Errorf("heepay ExtractTradeNo: unmarshal body: %w", err)
|
||||
}
|
||||
var bizData struct {
|
||||
OutTradeNo string `json:"out_trade_no"`
|
||||
}
|
||||
if err := json.Unmarshal(outer.Data, &bizData); err != nil {
|
||||
return "", fmt.Errorf("heepay ExtractTradeNo: unmarshal data: %w", err)
|
||||
}
|
||||
if bizData.OutTradeNo == "" {
|
||||
return "", fmt.Errorf("heepay ExtractTradeNo: out_trade_no is empty")
|
||||
}
|
||||
return bizData.OutTradeNo, nil
|
||||
}
|
||||
|
||||
// VerifyNotify 验证上游回调签名并解析
|
||||
// 汇元回调的签名规则与响应验签相同:公共参数+data整体 按字典序排列,用汇元公钥验签
|
||||
func (a *Adapter) VerifyNotify(ctx context.Context, rawBody []byte, headers map[string]string) (*channel.NotifyData, error) {
|
||||
// 解析外层公共参数
|
||||
var outer struct {
|
||||
Code string `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
TradeID string `json:"trade_id"`
|
||||
Sign string `json:"sign"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(rawBody, &outer); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal notify body: %w", err)
|
||||
}
|
||||
|
||||
// 验签(data 整体作为字符串参与验签)
|
||||
verifyParams := map[string]string{
|
||||
"code": outer.Code,
|
||||
"msg": outer.Msg,
|
||||
"trade_id": outer.TradeID,
|
||||
"data": string(outer.Data),
|
||||
}
|
||||
if err := VerifyResponse(verifyParams, outer.Sign, a.config.PublicKey); err != nil {
|
||||
return nil, fmt.Errorf("verify notify sign: %w", err)
|
||||
}
|
||||
|
||||
if outer.Code != codeSuccess {
|
||||
return nil, fmt.Errorf("notify code not success: %s %s", outer.Code, outer.Msg)
|
||||
}
|
||||
|
||||
// 解析业务数据
|
||||
var bizData map[string]any
|
||||
if err := json.Unmarshal(outer.Data, &bizData); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal notify data: %w", err)
|
||||
}
|
||||
|
||||
notifyType := strVal(bizData, "notify_type")
|
||||
result := &channel.NotifyData{
|
||||
TradeNo: strVal(bizData, "out_trade_no"),
|
||||
ChannelTradeNo: strVal(bizData, "transaction_id"),
|
||||
RawData: rawBody,
|
||||
}
|
||||
|
||||
switch notifyType {
|
||||
case "payment":
|
||||
result.NotifyType = model.NotifyTypePayment
|
||||
result.Status = model.TradeStatusPaid
|
||||
if v, ok := bizData["total_fee"].(float64); ok {
|
||||
result.Amount = int64(v)
|
||||
}
|
||||
if s := strVal(bizData, "pay_time"); s != "" {
|
||||
t, _ := time.ParseInLocation("2006-01-02 15:04:05", s, cst)
|
||||
result.PayTime = &t
|
||||
}
|
||||
case "refund":
|
||||
result.NotifyType = model.NotifyTypeRefund
|
||||
result.RefundNo = strVal(bizData, "out_refund_no")
|
||||
result.RefundStatus = model.RefundStatusSuccess
|
||||
if v, ok := bizData["refund_fee"].(float64); ok {
|
||||
result.RefundAmount = int64(v)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown notify_type: %s", notifyType)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ProfitSharing 分账
|
||||
func (a *Adapter) ProfitSharing(ctx context.Context, req *channel.ProfitSharingReq) (*channel.ProfitSharingResp, error) {
|
||||
biz := map[string]any{
|
||||
"transaction_id": req.ChannelTradeNo,
|
||||
"out_order_no": req.SharingNo,
|
||||
"receivers": []map[string]any{
|
||||
{
|
||||
"type": "MERCHANT_ID",
|
||||
"account": req.ReceiverMerchantID,
|
||||
"amount": req.Amount,
|
||||
"description": "分润",
|
||||
},
|
||||
},
|
||||
}
|
||||
resp, err := a.post(ctx, "pay.heepay.trade.profitsharing", biz)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &channel.ProfitSharingResp{
|
||||
SharingNo: req.SharingNo,
|
||||
ChannelSharingNo: strVal(resp, "order_id"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RollbackProfitSharing 回退分账
|
||||
func (a *Adapter) RollbackProfitSharing(ctx context.Context, req *channel.RollbackSharingReq) error {
|
||||
biz := map[string]any{
|
||||
"transaction_id": req.TradeNo,
|
||||
"out_order_no": req.SharingNo,
|
||||
"out_return_no": req.SharingNo + "_R",
|
||||
"return_account": a.config.MerchantID,
|
||||
"return_amount": 0,
|
||||
"description": "退款回退",
|
||||
}
|
||||
_, err := a.post(ctx, "pay.heepay.trade.profitsharing.return", biz)
|
||||
return err
|
||||
}
|
||||
|
||||
// DownloadBill 下载对账账单
|
||||
func (a *Adapter) DownloadBill(ctx context.Context, req *channel.DownloadBillReq) (*channel.BillData, error) {
|
||||
biz := map[string]any{
|
||||
"bill_date": req.BillDate,
|
||||
"bill_type": "ALL",
|
||||
}
|
||||
resp, err := a.post(ctx, "pay.heepay.trade.bill.download", biz)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 账单为 CSV 内容,暂存 raw 由对账服务解析
|
||||
_ = resp
|
||||
return &channel.BillData{}, nil
|
||||
}
|
||||
|
||||
// UploadFile 上传文件到汇元进件网关(customer.file.upload)
|
||||
func (a *Adapter) UploadFile(ctx context.Context, req *channel.UploadFileReq) (*channel.UploadFileResp, error) {
|
||||
// 计算 MD5 签名
|
||||
h := md5.Sum(req.FileContent)
|
||||
fileSign := hex.EncodeToString(h[:])
|
||||
|
||||
// Base64 编码文件内容
|
||||
fileContentB64 := base64.StdEncoding.EncodeToString(req.FileContent)
|
||||
|
||||
biz := map[string]any{
|
||||
"file_content": fileContentB64,
|
||||
"file_sign": fileSign,
|
||||
"file_media_type": req.FileMediaType,
|
||||
}
|
||||
|
||||
resp, err := a.postToMerchant(ctx, "customer.file.upload", biz)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fileID := strVal(resp, "file_id")
|
||||
if fileID == "" {
|
||||
return nil, fmt.Errorf("heepay upload file: empty file_id in response")
|
||||
}
|
||||
return &channel.UploadFileResp{FileID: fileID}, nil
|
||||
}
|
||||
|
||||
// MerchantApply 企业入网申请(customer.enter.enterprise.apply)
|
||||
func (a *Adapter) MerchantApply(ctx context.Context, req *channel.MerchantApplyReq) (*channel.MerchantApplyResp, error) {
|
||||
resp, err := a.postToMerchant(ctx, "customer.enter.enterprise.apply", req.BizContent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &channel.MerchantApplyResp{
|
||||
RequestNo: strVal(resp, "request_no"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// QueryMerchantStatus 查询商户状态(customer.enter.query)
|
||||
func (a *Adapter) QueryMerchantStatus(ctx context.Context, channelMerchantID string) (*channel.MerchantStatusResp, error) {
|
||||
biz := map[string]any{
|
||||
"request_no": channelMerchantID,
|
||||
}
|
||||
resp, err := a.postToMerchant(ctx, "customer.enter.query", biz)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &channel.MerchantStatusResp{
|
||||
ChannelMerchantID: strVal(resp, "merch_id"),
|
||||
Status: strVal(resp, "audit_state"),
|
||||
FailReason: strVal(resp, "audit_reason"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// --- 底层通信 ---
|
||||
|
||||
// heepayRequest 公共请求结构
|
||||
type heepayRequest struct {
|
||||
AppID string `json:"app_id"`
|
||||
Method string `json:"method"`
|
||||
Format string `json:"format"`
|
||||
Charset string `json:"charset"`
|
||||
SignType string `json:"sign_type"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
Version string `json:"version"`
|
||||
BizContent string `json:"biz_content"`
|
||||
Sign string `json:"sign"`
|
||||
}
|
||||
|
||||
// heepayResponse 公共响应结构
|
||||
// 注意:支付网关 code 为字符串("10000"),进件网关 code 为数字(10000),
|
||||
// 使用 json.RawMessage 兼容,再通过 codeStr() 统一转为字符串比较。
|
||||
type heepayResponse struct {
|
||||
Code json.RawMessage `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
SubCode string `json:"sub_code"`
|
||||
SubMsg string `json:"sub_msg"`
|
||||
Sign string `json:"sign"`
|
||||
TradeID string `json:"trade_id"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
|
||||
// codeStr 将 code 字段统一转为字符串(兼容字符串和数字两种格式)
|
||||
func (r *heepayResponse) codeStr() string {
|
||||
if len(r.Code) == 0 {
|
||||
return ""
|
||||
}
|
||||
// 去掉引号(字符串形式)或直接返回数字字符串
|
||||
raw := string(r.Code)
|
||||
if len(raw) >= 2 && raw[0] == '"' {
|
||||
return raw[1 : len(raw)-1]
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
// post 调用汇元支付网关(payURL)
|
||||
func (a *Adapter) post(ctx context.Context, method string, bizParams map[string]any) (map[string]any, error) {
|
||||
return a.postURL(ctx, a.payURL, method, bizParams)
|
||||
}
|
||||
|
||||
// postToMerchant 调用汇元进件网关(merchantURL)
|
||||
func (a *Adapter) postToMerchant(ctx context.Context, method string, bizParams map[string]any) (map[string]any, error) {
|
||||
return a.postURL(ctx, a.merchantURL, method, bizParams)
|
||||
}
|
||||
|
||||
// post 调用汇元支付网关(payURL)
|
||||
// 注意:原 post 方法内部调用 postURL
|
||||
func (a *Adapter) postURL(ctx context.Context, gatewayURL, method string, bizParams map[string]any) (map[string]any, error) {
|
||||
bizJSON, err := json.Marshal(bizParams)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal biz_content: %w", err)
|
||||
}
|
||||
bizContent := string(bizJSON)
|
||||
|
||||
timestamp := time.Now().In(cst).Format("2006-01-02 15:04:05")
|
||||
signParams := map[string]string{
|
||||
"app_id": a.config.MerchantID,
|
||||
"method": method,
|
||||
"format": "JSON",
|
||||
"charset": "utf-8",
|
||||
"sign_type": SignTypeRSA2,
|
||||
"timestamp": timestamp,
|
||||
"version": "1.0",
|
||||
"biz_content": bizContent,
|
||||
}
|
||||
|
||||
sign, err := Sign(signParams, a.config.PrivateKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sign request: %w", err)
|
||||
}
|
||||
|
||||
reqBody := heepayRequest{
|
||||
AppID: a.config.MerchantID,
|
||||
Method: method,
|
||||
Format: "JSON",
|
||||
Charset: "utf-8",
|
||||
SignType: SignTypeRSA2,
|
||||
Timestamp: timestamp,
|
||||
Version: "1.0",
|
||||
BizContent: bizContent,
|
||||
Sign: sign,
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
slog.DebugContext(ctx, "heepay request", "method", method, "url", gatewayURL)
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, gatewayURL, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
httpResp, err := a.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("heepay http request: %w", err)
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
slog.DebugContext(ctx, "heepay response", "method", method, "body", string(respBytes))
|
||||
|
||||
if len(respBytes) == 0 {
|
||||
return nil, fmt.Errorf("heepay empty response from %s (method=%s)", gatewayURL, method)
|
||||
}
|
||||
|
||||
// 先解析为 raw map,确保所有字段都能参与验签
|
||||
var rawMap map[string]json.RawMessage
|
||||
if err := json.Unmarshal(respBytes, &rawMap); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
signRaw, _ := rawMap["sign"]
|
||||
signB64 := strings.Trim(string(signRaw), `"`)
|
||||
|
||||
// 构建验签参数:所有非 sign、非空字段
|
||||
verifyParams := make(map[string]string, len(rawMap))
|
||||
for k, v := range rawMap {
|
||||
if k == "sign" {
|
||||
continue
|
||||
}
|
||||
raw := string(v)
|
||||
if raw == "null" || raw == `""` {
|
||||
continue
|
||||
}
|
||||
// 字符串类型去引号并反转义,数字/对象/数组直接用 raw 值
|
||||
if len(raw) >= 2 && raw[0] == '"' {
|
||||
var s string
|
||||
json.Unmarshal(v, &s)
|
||||
verifyParams[k] = s
|
||||
} else {
|
||||
verifyParams[k] = raw
|
||||
}
|
||||
}
|
||||
|
||||
if err := VerifyResponse(verifyParams, signB64, a.config.PublicKey); err != nil {
|
||||
return nil, fmt.Errorf("verify response sign: %w", err)
|
||||
}
|
||||
|
||||
// 再用强类型 struct 方便后续处理
|
||||
var resp heepayResponse
|
||||
json.Unmarshal(respBytes, &resp)
|
||||
codeStr := resp.codeStr()
|
||||
|
||||
if codeStr != codeSuccess {
|
||||
return nil, fmt.Errorf("heepay error [%s] %s: %s %s", codeStr, resp.Msg, resp.SubCode, resp.SubMsg)
|
||||
}
|
||||
|
||||
var bizResult map[string]any
|
||||
if len(resp.Data) > 0 {
|
||||
if err := json.Unmarshal(resp.Data, &bizResult); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal response data: %w", err)
|
||||
}
|
||||
}
|
||||
return bizResult, nil
|
||||
}
|
||||
|
||||
// mapPayMethod 将内部支付方式映射为汇元 trade_type
|
||||
func mapPayMethod(m model.PayMethod) (string, error) {
|
||||
mapping := map[model.PayMethod]string{
|
||||
model.PayMethodWechatJSAPI: "JSAPI",
|
||||
model.PayMethodWechatH5: "MWEB",
|
||||
model.PayMethodWechatNative: "NATIVE",
|
||||
model.PayMethodWechatMini: "MINIAPP",
|
||||
model.PayMethodAlipay: "ALI_NATIVE",
|
||||
model.PayMethodQuickPay: "QUICK",
|
||||
}
|
||||
t, ok := mapping[m]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unsupported pay method: %s", m)
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func strVal(m map[string]any, key string) string {
|
||||
if v, ok := m[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
13
backend/internal/channel/heepay/cbc.go
Normal file
13
backend/internal/channel/heepay/cbc.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package heepay
|
||||
|
||||
import "crypto/cipher"
|
||||
|
||||
// newCBCEncrypter 创建 CBC 加密器(封装 cipher.NewCBCEncrypter)
|
||||
func newCBCEncrypter(block cipher.Block, iv []byte) cipher.BlockMode {
|
||||
return cipher.NewCBCEncrypter(block, iv)
|
||||
}
|
||||
|
||||
// newCBCDecrypter 创建 CBC 解密器
|
||||
func newCBCDecrypter(block cipher.Block, iv []byte) cipher.BlockMode {
|
||||
return cipher.NewCBCDecrypter(block, iv)
|
||||
}
|
||||
142
backend/internal/channel/heepay/crypto.go
Normal file
142
backend/internal/channel/heepay/crypto.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package heepay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/des"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// EncryptRequest 使用 RSA+3DES 加密请求体
|
||||
// 1. 生成随机 3DES 密钥(24 字节)
|
||||
// 2. 用 3DES-CBC 加密 JSON 请求体
|
||||
// 3. 用汇元 RSA 公钥加密 3DES 密钥
|
||||
func EncryptRequest(plaintext []byte, publicKeyPEM string) (encData string, encKey string, err error) {
|
||||
pubKey, err := parseRSAPublicKey(publicKeyPEM)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// 生成 3DES 密钥
|
||||
desKey := make([]byte, 24)
|
||||
if _, err = rand.Read(desKey); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// 3DES 加密
|
||||
ciphertext, err := tripleDesEncrypt(plaintext, desKey)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// RSA 加密 3DES 密钥
|
||||
encKeyBytes, err := rsa.EncryptPKCS1v15(rand.Reader, pubKey, desKey)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
encData = base64.StdEncoding.EncodeToString(ciphertext)
|
||||
encKey = base64.StdEncoding.EncodeToString(encKeyBytes)
|
||||
return
|
||||
}
|
||||
|
||||
// DecryptResponse 使用 RSA 私钥 + 3DES 解密响应
|
||||
func DecryptResponse(encData, encKey, privateKeyPEM string) ([]byte, error) {
|
||||
privKey, err := parseRSAPrivateKey(privateKeyPEM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encKeyBytes, err := base64.StdEncoding.DecodeString(encKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
desKey, err := rsa.DecryptPKCS1v15(rand.Reader, privKey, encKeyBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(encData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tripleDesDecrypt(ciphertext, desKey)
|
||||
}
|
||||
|
||||
// tripleDesEncrypt 3DES-CBC 加密(PKCS5Padding)
|
||||
func tripleDesEncrypt(plaintext, key []byte) ([]byte, error) {
|
||||
block, err := des.NewTripleDESCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
blockSize := block.BlockSize()
|
||||
plaintext = pkcs5Padding(plaintext, blockSize)
|
||||
|
||||
iv := key[:blockSize] // 使用密钥前 8 字节作为 IV
|
||||
mode := newCBCEncrypter(block, iv)
|
||||
ciphertext := make([]byte, len(plaintext))
|
||||
mode.CryptBlocks(ciphertext, plaintext)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// tripleDesDecrypt 3DES-CBC 解密
|
||||
func tripleDesDecrypt(ciphertext, key []byte) ([]byte, error) {
|
||||
block, err := des.NewTripleDESCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
blockSize := block.BlockSize()
|
||||
iv := key[:blockSize]
|
||||
mode := newCBCDecrypter(block, iv)
|
||||
plaintext := make([]byte, len(ciphertext))
|
||||
mode.CryptBlocks(plaintext, ciphertext)
|
||||
return pkcs5Unpadding(plaintext)
|
||||
}
|
||||
|
||||
func pkcs5Padding(data []byte, blockSize int) []byte {
|
||||
padding := blockSize - len(data)%blockSize
|
||||
padText := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(data, padText...)
|
||||
}
|
||||
|
||||
func pkcs5Unpadding(data []byte) ([]byte, error) {
|
||||
length := len(data)
|
||||
if length == 0 {
|
||||
return nil, errors.New("empty data")
|
||||
}
|
||||
padding := int(data[length-1])
|
||||
if padding > length {
|
||||
return nil, errors.New("invalid padding")
|
||||
}
|
||||
return data[:length-padding], nil
|
||||
}
|
||||
|
||||
func parseRSAPublicKey(pemStr string) (*rsa.PublicKey, error) {
|
||||
block, _ := pem.Decode([]byte(pemStr))
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode PEM block")
|
||||
}
|
||||
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rsaPub, ok := pub.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return nil, errors.New("not RSA public key")
|
||||
}
|
||||
return rsaPub, nil
|
||||
}
|
||||
|
||||
func parseRSAPrivateKey(pemStr string) (*rsa.PrivateKey, error) {
|
||||
block, _ := pem.Decode([]byte(pemStr))
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode PEM block")
|
||||
}
|
||||
return x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
}
|
||||
355
backend/internal/channel/heepay/sandbox_test.go
Normal file
355
backend/internal/channel/heepay/sandbox_test.go
Normal file
@@ -0,0 +1,355 @@
|
||||
//go:build sandbox
|
||||
|
||||
// 沙盒集成测试:直连汇元沙盒环境,验证真实 API 调用。
|
||||
// 需要设置以下环境变量后运行:
|
||||
//
|
||||
// export HEEPAY_MERCHANT_ID=your_merchant_id
|
||||
// export HEEPAY_PRIVATE_KEY="-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----"
|
||||
// export HEEPAY_PUBLIC_KEY="-----BEGIN PUBLIC KEY-----\n...\n-----END PUBLIC KEY-----"
|
||||
//
|
||||
// go test -tags sandbox ./internal/channel/heepay/ -v -timeout 60s
|
||||
package heepay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"pay-bridge/internal/channel"
|
||||
"pay-bridge/internal/model"
|
||||
)
|
||||
|
||||
const (
|
||||
sandboxPayURL = "http://openapi.heepaydev.com/gateway"
|
||||
sandboxMerchantURL = "http://openapi.heepaydev.com/v1/customer/gateway"
|
||||
)
|
||||
|
||||
// newSandboxAdapter 从环境变量读取凭证,构造沙盒 Adapter
|
||||
func newSandboxAdapter(t *testing.T) *Adapter {
|
||||
t.Helper()
|
||||
merchantID := requireEnv(t, "HEEPAY_MERCHANT_ID")
|
||||
privateKey := requireEnv(t, "HEEPAY_PRIVATE_KEY")
|
||||
publicKey := requireEnv(t, "HEEPAY_PUBLIC_KEY")
|
||||
|
||||
// 支持 \n 转义(shell 传入时换行符可能被转义)
|
||||
privateKey = strings.ReplaceAll(privateKey, `\n`, "\n")
|
||||
publicKey = strings.ReplaceAll(publicKey, `\n`, "\n")
|
||||
|
||||
cfg := &model.ChannelConfig{
|
||||
MerchantID: merchantID,
|
||||
PrivateKey: privateKey,
|
||||
PublicKey: publicKey,
|
||||
Sandbox: 1,
|
||||
}
|
||||
return &Adapter{
|
||||
config: cfg,
|
||||
payURL: sandboxPayURL,
|
||||
merchantURL: sandboxMerchantURL,
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
func requireEnv(t *testing.T, key string) string {
|
||||
t.Helper()
|
||||
v := os.Getenv(key)
|
||||
require.NotEmpty(t, v, "环境变量 %s 未设置,无法运行沙盒测试", key)
|
||||
return v
|
||||
}
|
||||
|
||||
// uniqueOrderNo 生成测试用唯一订单号(避免重复)
|
||||
func uniqueOrderNo(prefix string) string {
|
||||
return fmt.Sprintf("%s%d", prefix, time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// --- 下单 ---
|
||||
|
||||
func TestSandbox_CreateOrder_JSAPI(t *testing.T) {
|
||||
a := newSandboxAdapter(t)
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := a.CreateOrder(ctx, &channel.CreateOrderReq{
|
||||
AppID: a.config.MerchantID,
|
||||
TradeNo: uniqueOrderNo("TEST"),
|
||||
MerchantOrderNo: uniqueOrderNo("ORD"),
|
||||
PayMethod: model.PayMethodWechatJSAPI,
|
||||
Amount: 1, // 1 分
|
||||
Subject: "沙盒测试-JSAPI",
|
||||
NotifyURL: "https://example.com/notify",
|
||||
ExpireTime: time.Now().Add(10 * time.Minute),
|
||||
Extra: map[string]any{"openid": "oBk9Y5YMoAb2UG0L1OWQ_xNoBnE0"}, // 沙盒 openid
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.PayCredential, "应返回支付凭证")
|
||||
t.Logf("pay_credential: %+v", resp.PayCredential)
|
||||
}
|
||||
|
||||
func TestSandbox_CreateOrder_Native(t *testing.T) {
|
||||
a := newSandboxAdapter(t)
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := a.CreateOrder(ctx, &channel.CreateOrderReq{
|
||||
AppID: a.config.MerchantID,
|
||||
TradeNo: uniqueOrderNo("TEST"),
|
||||
MerchantOrderNo: uniqueOrderNo("ORD"),
|
||||
PayMethod: model.PayMethodWechatNative,
|
||||
Amount: 1,
|
||||
Subject: "沙盒测试-Native",
|
||||
NotifyURL: "https://example.com/notify",
|
||||
ExpireTime: time.Now().Add(10 * time.Minute),
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.PayCredential)
|
||||
t.Logf("pay_credential: %+v", resp.PayCredential)
|
||||
}
|
||||
|
||||
// --- 查询订单 ---
|
||||
|
||||
func TestSandbox_QueryOrder_NotExist(t *testing.T) {
|
||||
a := newSandboxAdapter(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 查一个不存在的订单,预期渠道返回错误
|
||||
_, err := a.QueryOrder(ctx, &channel.QueryOrderReq{
|
||||
TradeNo: "NOT_EXIST_" + uniqueOrderNo(""),
|
||||
})
|
||||
|
||||
// 沙盒对不存在订单会返回错误,确认我们能正确解析
|
||||
assert.Error(t, err)
|
||||
t.Logf("expected error: %v", err)
|
||||
}
|
||||
|
||||
// --- 完整流程:下单 → 查询 → 关闭 ---
|
||||
|
||||
func TestSandbox_OrderLifecycle(t *testing.T) {
|
||||
a := newSandboxAdapter(t)
|
||||
ctx := context.Background()
|
||||
|
||||
tradeNo := uniqueOrderNo("LIFE")
|
||||
|
||||
// 1. 下单
|
||||
createResp, err := a.CreateOrder(ctx, &channel.CreateOrderReq{
|
||||
AppID: a.config.MerchantID,
|
||||
TradeNo: tradeNo,
|
||||
MerchantOrderNo: uniqueOrderNo("ORD"),
|
||||
PayMethod: model.PayMethodWechatNative,
|
||||
Amount: 1,
|
||||
Subject: "沙盒-生命周期测试",
|
||||
NotifyURL: "https://example.com/notify",
|
||||
ExpireTime: time.Now().Add(10 * time.Minute),
|
||||
})
|
||||
require.NoError(t, err, "下单失败")
|
||||
t.Logf("下单成功,channel_trade_no: %s", createResp.ChannelTradeNo)
|
||||
|
||||
// 2. 查询(刚下单,应为 PAYING 状态)
|
||||
queryResp, err := a.QueryOrder(ctx, &channel.QueryOrderReq{
|
||||
TradeNo: tradeNo,
|
||||
ChannelTradeNo: createResp.ChannelTradeNo,
|
||||
})
|
||||
require.NoError(t, err, "查询订单失败")
|
||||
assert.Equal(t, model.TradeStatusPaying, queryResp.Status)
|
||||
t.Logf("订单状态: %s", queryResp.Status)
|
||||
|
||||
// 3. 关闭
|
||||
err = a.CloseOrder(ctx, &channel.CloseOrderReq{
|
||||
TradeNo: tradeNo,
|
||||
ChannelTradeNo: createResp.ChannelTradeNo,
|
||||
})
|
||||
require.NoError(t, err, "关闭订单失败")
|
||||
t.Log("关闭订单成功")
|
||||
|
||||
// 4. 再次查询,应为 CLOSED
|
||||
queryResp2, err := a.QueryOrder(ctx, &channel.QueryOrderReq{
|
||||
TradeNo: tradeNo,
|
||||
ChannelTradeNo: createResp.ChannelTradeNo,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, model.TradeStatusClosed, queryResp2.Status)
|
||||
}
|
||||
|
||||
// --- 商户进件 ---
|
||||
|
||||
// TestSandbox_UploadFile 上传一张最小 JPEG 到沙盒,验证返回 file_id
|
||||
func TestSandbox_UploadFile(t *testing.T) {
|
||||
a := newSandboxAdapter(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 最小合法 JPEG(几十字节)
|
||||
minJPEG := []byte{
|
||||
0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00, 0x01,
|
||||
0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0xFF, 0xD9,
|
||||
}
|
||||
|
||||
resp, err := a.UploadFile(ctx, &channel.UploadFileReq{
|
||||
FileContent: minJPEG,
|
||||
FileName: "test_license.jpg",
|
||||
FileMediaType: "image/jpeg",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.FileID, "应返回 file_id")
|
||||
t.Logf("file_id: %s", resp.FileID)
|
||||
}
|
||||
|
||||
// TestSandbox_MerchantApply 提交企业入网申请,验证返回 request_no
|
||||
// 沙盒环境审核不会真正处理,但接口应正常响应
|
||||
func TestSandbox_MerchantApply(t *testing.T) {
|
||||
a := newSandboxAdapter(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 沙盒测试用企业信息(使用固定测试数据)
|
||||
// 字段名以汇元 customer.enter.enterprise.apply 文档为准
|
||||
bizContent := map[string]any{
|
||||
"merch_name": "测试科技有限公司",
|
||||
"merch_short_name": "测试科技",
|
||||
"merch_type": "ENTERPRISE",
|
||||
"contact_name": "张三",
|
||||
"contact_phone": "13800138000",
|
||||
"contact_email": "test@example.com",
|
||||
"license_no": "91110000123456789X",
|
||||
"legal_name": "李四",
|
||||
"legal_id": "110101199001011234",
|
||||
"province": "北京市",
|
||||
"city": "北京市",
|
||||
"district": "朝阳区",
|
||||
"address": "朝阳区测试路1号",
|
||||
"bank_acct_name": "测试科技有限公司",
|
||||
"bank_acct_no": "6222021234567890123",
|
||||
"bank_name": "中国工商银行",
|
||||
}
|
||||
|
||||
resp, err := a.MerchantApply(ctx, &channel.MerchantApplyReq{
|
||||
MerchantID: uniqueOrderNo("M"),
|
||||
BizContent: bizContent,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, resp.RequestNo, "应返回 request_no")
|
||||
t.Logf("request_no: %s", resp.RequestNo)
|
||||
}
|
||||
|
||||
// TestSandbox_QueryMerchantStatus 用已有的 request_no 查询进件状态
|
||||
// 若无真实 request_no,跳过(防止误报失败)
|
||||
func TestSandbox_QueryMerchantStatus(t *testing.T) {
|
||||
requestNo := os.Getenv("HEEPAY_TEST_REQUEST_NO")
|
||||
if requestNo == "" {
|
||||
t.Skip("未设置 HEEPAY_TEST_REQUEST_NO,跳过进件状态查询测试")
|
||||
}
|
||||
|
||||
a := newSandboxAdapter(t)
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := a.QueryMerchantStatus(ctx, requestNo)
|
||||
require.NoError(t, err)
|
||||
t.Logf("audit_state: %s, merch_id: %s, fail_reason: %s",
|
||||
resp.Status, resp.ChannelMerchantID, resp.FailReason)
|
||||
}
|
||||
|
||||
// TestSandbox_MerchantOnboardingFlow 完整进件流程:上传文件 → 提交申请 → 查询状态
|
||||
func TestSandbox_MerchantOnboardingFlow(t *testing.T) {
|
||||
a := newSandboxAdapter(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 1. 上传营业执照
|
||||
minJPEG := []byte{
|
||||
0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10, 0x4A, 0x46, 0x49, 0x46, 0x00, 0x01,
|
||||
0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0xFF, 0xD9,
|
||||
}
|
||||
uploadResp, err := a.UploadFile(ctx, &channel.UploadFileReq{
|
||||
FileContent: minJPEG,
|
||||
FileName: "license.jpg",
|
||||
FileMediaType: "image/jpeg",
|
||||
})
|
||||
require.NoError(t, err, "上传营业执照失败")
|
||||
t.Logf("上传成功,file_id: %s", uploadResp.FileID)
|
||||
|
||||
// 2. 提交进件申请(带上文件 ID)
|
||||
bizContent := map[string]any{
|
||||
"merch_name": "沙盒流程测试公司",
|
||||
"merch_short_name": "沙盒测试",
|
||||
"merch_type": "ENTERPRISE",
|
||||
"contact_name": "王五",
|
||||
"contact_phone": "13900139000",
|
||||
"contact_email": "flow@example.com",
|
||||
"license_no": "91110000FLOW00001X",
|
||||
"license_img": uploadResp.FileID,
|
||||
"legal_name": "赵六",
|
||||
"legal_id": "110101199001019999",
|
||||
"province": "北京市",
|
||||
"city": "北京市",
|
||||
"district": "海淀区",
|
||||
"address": "海淀区测试路2号",
|
||||
"bank_acct_name": "沙盒流程测试公司",
|
||||
"bank_acct_no": "6222029876543210987",
|
||||
"bank_name": "中国建设银行",
|
||||
}
|
||||
applyResp, err := a.MerchantApply(ctx, &channel.MerchantApplyReq{
|
||||
MerchantID: uniqueOrderNo("FLOW"),
|
||||
BizContent: bizContent,
|
||||
})
|
||||
require.NoError(t, err, "提交进件申请失败")
|
||||
t.Logf("申请成功,request_no: %s", applyResp.RequestNo)
|
||||
|
||||
// 3. 查询进件状态(沙盒可能立即返回状态)
|
||||
statusResp, err := a.QueryMerchantStatus(ctx, applyResp.RequestNo)
|
||||
require.NoError(t, err, "查询进件状态失败")
|
||||
t.Logf("进件状态: audit_state=%s, merch_id=%s", statusResp.Status, statusResp.ChannelMerchantID)
|
||||
}
|
||||
|
||||
// --- 签名验证(本地,不发网络请求)---
|
||||
|
||||
// TestSign_And_Verify 本地签名/验签往返测试(不发网络请求)
|
||||
// 注意汇元双密钥体系:
|
||||
// - 请求签名:商户私钥签 → 汇元用商户公钥验
|
||||
// - 响应验签:汇元私钥签 → 商户用汇元公钥验(即 a.config.PublicKey)
|
||||
//
|
||||
// 本测试只验证商户私钥签名正确,使用从私钥派生的公钥做自验,
|
||||
// 不混用汇元公钥(两者非同一密钥对)。
|
||||
func TestSign_And_Verify(t *testing.T) {
|
||||
a := newSandboxAdapter(t)
|
||||
|
||||
params := map[string]string{
|
||||
"app_id": a.config.MerchantID,
|
||||
"method": "pay.heepay.trade.create",
|
||||
"format": "JSON",
|
||||
"charset": "utf-8",
|
||||
"sign_type": SignTypeRSA2,
|
||||
"timestamp": "2026-02-28 10:00:00",
|
||||
"version": "1.0",
|
||||
"biz_content": `{"out_trade_no":"TEST001"}`,
|
||||
}
|
||||
|
||||
sign, err := Sign(params, a.config.PrivateKey)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, sign)
|
||||
t.Logf("sign (first 32 chars): %s...", sign[:32])
|
||||
|
||||
// 从商户私钥提取对应公钥,用于验证我们自己签出来的签名
|
||||
merchantPubKeyPEM, err := extractPublicKeyFromPrivate(a.config.PrivateKey)
|
||||
require.NoError(t, err, "从私钥提取公钥失败")
|
||||
err = VerifyResponse(params, sign, merchantPubKeyPEM)
|
||||
assert.NoError(t, err, "用商户公钥验证商户私钥签名应通过")
|
||||
}
|
||||
|
||||
// extractPublicKeyFromPrivate 从商户私钥中派生出对应的公钥(DER base64 格式)
|
||||
func extractPublicKeyFromPrivate(privKeyStr string) (string, error) {
|
||||
privKey, err := parsePrivateKey(privKeyStr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
derBytes, err := x509.MarshalPKIXPublicKey(&privKey.PublicKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(derBytes), nil
|
||||
}
|
||||
|
||||
124
backend/internal/channel/heepay/sign.go
Normal file
124
backend/internal/channel/heepay/sign.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package heepay
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const SignTypeRSA2 = "RSA2"
|
||||
|
||||
// Sign 对请求参数签名(商户私钥)
|
||||
// params 为公共参数(不含 sign),biz_content 已作为整体字符串放入 params["biz_content"]
|
||||
func Sign(params map[string]string, privateKeyPEM string) (string, error) {
|
||||
payload := sortAndJoin(params)
|
||||
return signRSA2(payload, privateKeyPEM)
|
||||
}
|
||||
|
||||
// VerifyResponse 验证汇元响应签名(汇元公钥)
|
||||
// params 为响应公共参数(不含 sign),data 已作为整体 JSON 字符串放入 params["data"]
|
||||
func VerifyResponse(params map[string]string, sign, publicKeyPEM string) error {
|
||||
payload := sortAndJoin(params)
|
||||
return verifyRSA2(payload, sign, publicKeyPEM)
|
||||
}
|
||||
|
||||
// sortAndJoin 按参数名 A-Z 排序后拼接 key=value&...(排除 sign 和空值字段)
|
||||
func sortAndJoin(params map[string]string) string {
|
||||
keys := make([]string, 0, len(params))
|
||||
for k := range params {
|
||||
if k == "sign" || params[k] == "" {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
parts := make([]string, 0, len(keys))
|
||||
for _, k := range keys {
|
||||
parts = append(parts, k+"="+params[k])
|
||||
}
|
||||
return strings.Join(parts, "&")
|
||||
}
|
||||
|
||||
// signRSA2 SHA256WithRSA 签名,Base64 编码
|
||||
func signRSA2(payload, privateKeyPEM string) (string, error) {
|
||||
privKey, err := parsePrivateKey(privateKeyPEM)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse private key: %w", err)
|
||||
}
|
||||
hash := sha256.Sum256([]byte(payload))
|
||||
sig, err := rsa.SignPKCS1v15(rand.Reader, privKey, crypto.SHA256, hash[:])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("rsa sign: %w", err)
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(sig), nil
|
||||
}
|
||||
|
||||
// verifyRSA2 验证 SHA256WithRSA 签名
|
||||
func verifyRSA2(payload, signB64, publicKeyPEM string) error {
|
||||
pubKey, err := parsePublicKey(publicKeyPEM)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse public key: %w", err)
|
||||
}
|
||||
sig, err := base64.StdEncoding.DecodeString(signB64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode sign base64: %w", err)
|
||||
}
|
||||
hash := sha256.Sum256([]byte(payload))
|
||||
return rsa.VerifyPKCS1v15(pubKey, crypto.SHA256, hash[:], sig)
|
||||
}
|
||||
|
||||
func parsePrivateKey(pemStr string) (*rsa.PrivateKey, error) {
|
||||
var der []byte
|
||||
if block, _ := pem.Decode([]byte(pemStr)); block != nil {
|
||||
der = block.Bytes
|
||||
} else {
|
||||
// 汇元文档提供的是裸 Base64(无 PEM header),直接 base64 解码
|
||||
cleaned := strings.ReplaceAll(strings.TrimSpace(pemStr), "\n", "")
|
||||
var err error
|
||||
der, err = base64.StdEncoding.DecodeString(cleaned)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("private key is neither PEM nor valid base64: %w", err)
|
||||
}
|
||||
}
|
||||
// 优先尝试 PKCS8,再尝试 PKCS1
|
||||
if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
|
||||
if rsaKey, ok := key.(*rsa.PrivateKey); ok {
|
||||
return rsaKey, nil
|
||||
}
|
||||
return nil, errors.New("not an RSA private key")
|
||||
}
|
||||
return x509.ParsePKCS1PrivateKey(der)
|
||||
}
|
||||
|
||||
func parsePublicKey(pemStr string) (*rsa.PublicKey, error) {
|
||||
var der []byte
|
||||
if block, _ := pem.Decode([]byte(pemStr)); block != nil {
|
||||
der = block.Bytes
|
||||
} else {
|
||||
// 汇元文档提供的是裸 Base64(无 PEM header),直接 base64 解码
|
||||
cleaned := strings.ReplaceAll(strings.TrimSpace(pemStr), "\n", "")
|
||||
var err error
|
||||
der, err = base64.StdEncoding.DecodeString(cleaned)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("public key is neither PEM nor valid base64: %w", err)
|
||||
}
|
||||
}
|
||||
pub, err := x509.ParsePKIXPublicKey(der)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse public key: %w", err)
|
||||
}
|
||||
rsaPub, ok := pub.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
return nil, errors.New("not an RSA public key")
|
||||
}
|
||||
return rsaPub, nil
|
||||
}
|
||||
Reference in New Issue
Block a user