125 lines
3.6 KiB
Go
125 lines
3.6 KiB
Go
package heepay
|
||
|
||
import (
|
||
"crypto"
|
||
"crypto/rand"
|
||
"crypto/rsa"
|
||
"crypto/sha256"
|
||
"crypto/x509"
|
||
"encoding/base64"
|
||
"encoding/pem"
|
||
"errors"
|
||
"fmt"
|
||
"sort"
|
||
"strings"
|
||
)
|
||
|
||
// SignTypeRSA2 is the identifier of the "RSA2" signature scheme.
// NOTE(review): presumably this names the SHA256withRSA scheme implemented by
// Sign / VerifyResponse and is what callers send as the request's sign type —
// confirm against the Heepay API documentation.
const (
	SignTypeRSA2 = "RSA2"
)
|
||
|
||
// Sign 对请求参数签名(商户私钥)
|
||
// params 为公共参数(不含 sign),biz_content 已作为整体字符串放入 params["biz_content"]
|
||
func Sign(params map[string]string, privateKeyPEM string) (string, error) {
|
||
payload := sortAndJoin(params)
|
||
return signRSA2(payload, privateKeyPEM)
|
||
}
|
||
|
||
// VerifyResponse 验证汇元响应签名(汇元公钥)
|
||
// params 为响应公共参数(不含 sign),data 已作为整体 JSON 字符串放入 params["data"]
|
||
func VerifyResponse(params map[string]string, sign, publicKeyPEM string) error {
|
||
payload := sortAndJoin(params)
|
||
return verifyRSA2(payload, sign, publicKeyPEM)
|
||
}
|
||
|
||
// sortAndJoin builds the canonical signing string from params.
//
// Every entry except "sign", and except those with an empty value, is rendered
// as key=value. Entries are ordered by key in ascending byte order
// (sort.Strings) and joined with '&'. Keys and values are used verbatim, with
// no escaping. A map with no eligible entries yields "".
func sortAndJoin(params map[string]string) string {
	names := make([]string, 0, len(params))
	for name, value := range params {
		// "sign" carries the signature itself, and empty fields are not part of
		// the signed content.
		if name != "sign" && value != "" {
			names = append(names, name)
		}
	}
	sort.Strings(names)

	var b strings.Builder
	for i, name := range names {
		if i > 0 {
			b.WriteByte('&')
		}
		b.WriteString(name)
		b.WriteByte('=')
		b.WriteString(params[name])
	}
	return b.String()
}
|
||
|
||
// signRSA2 SHA256WithRSA 签名,Base64 编码
|
||
func signRSA2(payload, privateKeyPEM string) (string, error) {
|
||
privKey, err := parsePrivateKey(privateKeyPEM)
|
||
if err != nil {
|
||
return "", fmt.Errorf("parse private key: %w", err)
|
||
}
|
||
hash := sha256.Sum256([]byte(payload))
|
||
sig, err := rsa.SignPKCS1v15(rand.Reader, privKey, crypto.SHA256, hash[:])
|
||
if err != nil {
|
||
return "", fmt.Errorf("rsa sign: %w", err)
|
||
}
|
||
return base64.StdEncoding.EncodeToString(sig), nil
|
||
}
|
||
|
||
// verifyRSA2 验证 SHA256WithRSA 签名
|
||
func verifyRSA2(payload, signB64, publicKeyPEM string) error {
|
||
pubKey, err := parsePublicKey(publicKeyPEM)
|
||
if err != nil {
|
||
return fmt.Errorf("parse public key: %w", err)
|
||
}
|
||
sig, err := base64.StdEncoding.DecodeString(signB64)
|
||
if err != nil {
|
||
return fmt.Errorf("decode sign base64: %w", err)
|
||
}
|
||
hash := sha256.Sum256([]byte(payload))
|
||
return rsa.VerifyPKCS1v15(pubKey, crypto.SHA256, hash[:], sig)
|
||
}
|
||
|
||
// parsePrivateKey loads an RSA private key from pemStr.
//
// pemStr is either a PEM block or the bare Base64 of the DER bytes. Heepay's
// documentation hands out the key without a PEM header. For bare Base64, outer
// whitespace and '\n' characters are removed before decoding.
//
// The DER is parsed as PKCS#8 first. A PKCS#8 key that is not RSA is rejected
// with "not an RSA private key". Only if PKCS#8 parsing fails is PKCS#1 tried,
// and that parser's result is returned unchanged.
func parsePrivateKey(pemStr string) (*rsa.PrivateKey, error) {
	var der []byte
	if block, _ := pem.Decode([]byte(pemStr)); block == nil {
		// No PEM framing: treat the whole text as Base64-encoded DER.
		compact := strings.TrimSpace(pemStr)
		compact = strings.ReplaceAll(compact, "\n", "")
		decoded, decodeErr := base64.StdEncoding.DecodeString(compact)
		if decodeErr != nil {
			return nil, fmt.Errorf("private key is neither PEM nor valid base64: %w", decodeErr)
		}
		der = decoded
	} else {
		der = block.Bytes
	}

	parsed, pkcs8Err := x509.ParsePKCS8PrivateKey(der)
	if pkcs8Err != nil {
		return x509.ParsePKCS1PrivateKey(der)
	}
	rsaKey, isRSA := parsed.(*rsa.PrivateKey)
	if !isRSA {
		return nil, errors.New("not an RSA private key")
	}
	return rsaKey, nil
}
|
||
|
||
// parsePublicKey loads an RSA public key from pemStr.
//
// pemStr is either a PEM block or the bare Base64 of the DER bytes. Heepay's
// documentation ships the key without a PEM header. For bare Base64, outer
// whitespace and '\n' characters are removed before decoding.
//
// The DER is parsed as an X.509 SubjectPublicKeyInfo (PKIX, "PUBLIC KEY")
// first. If that fails, it is parsed as a PKCS#1 RSAPublicKey ("RSA PUBLIC
// KEY"). This mirrors parsePrivateKey, which accepts both PKCS#8 and PKCS#1.
// A PKIX key of another algorithm is rejected with "not an RSA public key".
//
// Returned errors deliberately do not start with "parse public key", because
// verifyRSA2 already wraps them with that prefix.
func parsePublicKey(pemStr string) (*rsa.PublicKey, error) {
	var der []byte
	if block, _ := pem.Decode([]byte(pemStr)); block != nil {
		der = block.Bytes
	} else {
		// No PEM framing: treat the whole text as Base64-encoded DER.
		cleaned := strings.ReplaceAll(strings.TrimSpace(pemStr), "\n", "")
		var err error
		der, err = base64.StdEncoding.DecodeString(cleaned)
		if err != nil {
			return nil, fmt.Errorf("public key is neither PEM nor valid base64: %w", err)
		}
	}

	pub, pkixErr := x509.ParsePKIXPublicKey(der)
	if pkixErr != nil {
		// ParsePKIXPublicKey rejects bare PKCS#1 keys, so try that format next.
		rsaPub, pkcs1Err := x509.ParsePKCS1PublicKey(der)
		if pkcs1Err != nil {
			return nil, fmt.Errorf("not a PKIX or PKCS#1 public key: %w (pkcs1: %v)", pkixErr, pkcs1Err)
		}
		return rsaPub, nil
	}

	rsaPub, ok := pub.(*rsa.PublicKey)
	if !ok {
		return nil, errors.New("not an RSA public key")
	}
	return rsaPub, nil
}
|