package heepay import ( "crypto" "crypto/rand" "crypto/rsa" "crypto/sha256" "crypto/x509" "encoding/base64" "encoding/pem" "errors" "fmt" "sort" "strings" ) const SignTypeRSA2 = "RSA2" // Sign 对请求参数签名(商户私钥) // params 为公共参数(不含 sign),biz_content 已作为整体字符串放入 params["biz_content"] func Sign(params map[string]string, privateKeyPEM string) (string, error) { payload := sortAndJoin(params) return signRSA2(payload, privateKeyPEM) } // VerifyResponse 验证汇元响应签名(汇元公钥) // params 为响应公共参数(不含 sign),data 已作为整体 JSON 字符串放入 params["data"] func VerifyResponse(params map[string]string, sign, publicKeyPEM string) error { payload := sortAndJoin(params) return verifyRSA2(payload, sign, publicKeyPEM) } // sortAndJoin 按参数名 A-Z 排序后拼接 key=value&...(排除 sign 和空值字段) func sortAndJoin(params map[string]string) string { keys := make([]string, 0, len(params)) for k := range params { if k == "sign" || params[k] == "" { continue } keys = append(keys, k) } sort.Strings(keys) parts := make([]string, 0, len(keys)) for _, k := range keys { parts = append(parts, k+"="+params[k]) } return strings.Join(parts, "&") } // signRSA2 SHA256WithRSA 签名,Base64 编码 func signRSA2(payload, privateKeyPEM string) (string, error) { privKey, err := parsePrivateKey(privateKeyPEM) if err != nil { return "", fmt.Errorf("parse private key: %w", err) } hash := sha256.Sum256([]byte(payload)) sig, err := rsa.SignPKCS1v15(rand.Reader, privKey, crypto.SHA256, hash[:]) if err != nil { return "", fmt.Errorf("rsa sign: %w", err) } return base64.StdEncoding.EncodeToString(sig), nil } // verifyRSA2 验证 SHA256WithRSA 签名 func verifyRSA2(payload, signB64, publicKeyPEM string) error { pubKey, err := parsePublicKey(publicKeyPEM) if err != nil { return fmt.Errorf("parse public key: %w", err) } sig, err := base64.StdEncoding.DecodeString(signB64) if err != nil { return fmt.Errorf("decode sign base64: %w", err) } hash := sha256.Sum256([]byte(payload)) return rsa.VerifyPKCS1v15(pubKey, crypto.SHA256, hash[:], sig) } func parsePrivateKey(pemStr string) (*rsa.PrivateKey, error) { var der []byte if block, _ := pem.Decode([]byte(pemStr)); block != nil { der = block.Bytes } else { // 汇元文档提供的是裸 Base64(无 PEM header),直接 base64 解码 cleaned := strings.ReplaceAll(strings.TrimSpace(pemStr), "\n", "") var err error der, err = base64.StdEncoding.DecodeString(cleaned) if err != nil { return nil, fmt.Errorf("private key is neither PEM nor valid base64: %w", err) } } // 优先尝试 PKCS8,再尝试 PKCS1 if key, err := x509.ParsePKCS8PrivateKey(der); err == nil { if rsaKey, ok := key.(*rsa.PrivateKey); ok { return rsaKey, nil } return nil, errors.New("not an RSA private key") } return x509.ParsePKCS1PrivateKey(der) } func parsePublicKey(pemStr string) (*rsa.PublicKey, error) { var der []byte if block, _ := pem.Decode([]byte(pemStr)); block != nil { der = block.Bytes } else { // 汇元文档提供的是裸 Base64(无 PEM header),直接 base64 解码 cleaned := strings.ReplaceAll(strings.TrimSpace(pemStr), "\n", "") var err error der, err = base64.StdEncoding.DecodeString(cleaned) if err != nil { return nil, fmt.Errorf("public key is neither PEM nor valid base64: %w", err) } } pub, err := x509.ParsePKIXPublicKey(der) if err != nil { return nil, fmt.Errorf("parse public key: %w", err) } rsaPub, ok := pub.(*rsa.PublicKey) if !ok { return nil, errors.New("not an RSA public key") } return rsaPub, nil }