100 lines
2.3 KiB
Go
100 lines
2.3 KiB
Go
package middleware
|
||
|
||
import (
|
||
"context"
|
||
"crypto/hmac"
|
||
"crypto/sha256"
|
||
"encoding/hex"
|
||
"net/http"
|
||
"strconv"
|
||
"time"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"pay-bridge/internal/api/handler"
|
||
"pay-bridge/internal/errcode"
|
||
)
|
||
|
||
// AppLoader is the lookup dependency the Auth middleware uses to resolve
// per-application credentials from an app ID.
type AppLoader interface {
	// GetAppSecret returns the signing secret registered for appID, or a
	// non-nil error when the secret cannot be obtained.
	GetAppSecret(ctx context.Context, appID string) (string, error)
}
|
||
|
||
// Auth 鉴权中间件
|
||
// 请求头:X-App-Id、X-Timestamp、X-Sign
|
||
// 签名算法:HMAC-SHA256(appId + timestamp + body, appSecret)
|
||
func Auth(loader AppLoader) gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
appID := c.GetHeader("X-App-Id")
|
||
timestamp := c.GetHeader("X-Timestamp")
|
||
sign := c.GetHeader("X-Sign")
|
||
|
||
if appID == "" || timestamp == "" || sign == "" {
|
||
handler.Unauthorized(c, errcode.ErrUnauthorized, errcode.Message(errcode.ErrUnauthorized))
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 时间戳防重放(5分钟内有效)
|
||
ts, err := strconv.ParseInt(timestamp, 10, 64)
|
||
if err != nil || abs(time.Now().Unix()-ts) > 300 {
|
||
handler.Unauthorized(c, errcode.ErrUnauthorized, "请求已过期")
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
appSecret, err := loader.GetAppSecret(c.Request.Context(), appID)
|
||
if err != nil {
|
||
handler.Unauthorized(c, errcode.ErrAppNotFound, errcode.Message(errcode.ErrAppNotFound))
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// 读取 body(注意:body 只能读一次,需要提前 cache)
|
||
body := bodyFromContext(c)
|
||
|
||
expectedSign := sign256(appID+timestamp+string(body), appSecret)
|
||
if !hmac.Equal([]byte(expectedSign), []byte(sign)) {
|
||
handler.Unauthorized(c, errcode.ErrUnauthorized, errcode.Message(errcode.ErrUnauthorized))
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
c.Set("app_id", appID)
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// ChannelCallback 渠道回调鉴权(由渠道适配器验签,此中间件只做基础检查)
|
||
func ChannelCallback() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
channelCode := c.Param("channelCode")
|
||
if channelCode == "" {
|
||
c.AbortWithStatus(http.StatusBadRequest)
|
||
return
|
||
}
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// sign256 computes HMAC-SHA256 over payload keyed with secret and returns
// the digest as a lowercase hexadecimal string.
func sign256(payload, secret string) string {
	mac := hmac.New(sha256.New, []byte(secret))
	mac.Write([]byte(payload))

	digest := mac.Sum(nil)
	encoded := make([]byte, hex.EncodedLen(len(digest)))
	hex.Encode(encoded, digest)
	return string(encoded)
}
|
||
|
||
// abs returns the absolute value of n.
//
// The magnitude of the minimum int64 (-1<<63) is not representable as an
// int64; negating it wraps back to the same negative number. Instead of
// returning a negative "absolute value", the result saturates to the maximum
// int64 in that one case, so callers comparing against an upper bound behave
// correctly.
func abs(n int64) int64 {
	const maxInt64 = 1<<63 - 1

	switch {
	case n >= 0:
		return n
	case n == -maxInt64-1:
		return maxInt64
	default:
		return -n
	}
}
|
||
|
||
func bodyFromContext(c *gin.Context) []byte {
|
||
if v, exists := c.Get("raw_body"); exists {
|
||
if b, ok := v.([]byte); ok {
|
||
return b
|
||
}
|
||
}
|
||
return nil
|
||
}
|