mirror of
https://github.com/alibaba/higress.git
synced 2026-06-06 11:17:29 +08:00
feat(jwt-auth): support remote JWKS (#3838)
Signed-off-by: Betula-L <6059935+Betula-L@users.noreply.github.com> Co-authored-by: Betula-L <6059935+Betula-L@users.noreply.github.com>
This commit is contained in:
@@ -15,7 +15,8 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -33,6 +34,12 @@ type ErrDenied struct {
|
||||
denied func() types.Action
|
||||
}
|
||||
|
||||
type logSafeJWT string
|
||||
|
||||
type verifiedConsumer struct {
|
||||
claims map[string]any
|
||||
}
|
||||
|
||||
type Logger interface {
|
||||
Warnf(format string, args ...interface{})
|
||||
}
|
||||
@@ -61,36 +68,72 @@ func (e *ErrDenied) Error() string {
|
||||
return e.msg
|
||||
}
|
||||
|
||||
func consumerVerify(consumer *cfg.Consumer, verifyTime time.Time, header HeaderProvider, log Logger) error {
|
||||
tokenStr := extractToken(*consumer.KeepToken, consumer, header, log)
|
||||
func consumerVerify(consumer *cfg.Consumer, verifyTime time.Time, header HeaderProvider, log Logger) (*verifiedConsumer, error) {
|
||||
tokenStr := extractToken(true, consumer, header, log)
|
||||
if tokenStr == "" {
|
||||
return &ErrDenied{
|
||||
return nil, &ErrDenied{
|
||||
msg: fmt.Sprintf("jwt is missing, consumer: %s", consumer.Name),
|
||||
denied: deniedJWTMissing,
|
||||
}
|
||||
}
|
||||
tokenLogValue := jwtLogValue(tokenStr)
|
||||
|
||||
// 当前版本的higress暂不支持jwe,此处用ParseSigned
|
||||
token, err := jwt.ParseSigned(tokenStr)
|
||||
if err != nil {
|
||||
return &ErrDenied{
|
||||
return nil, &ErrDenied{
|
||||
msg: fmt.Sprintf("jwt parse failed, consumer: %s, token: %s, reason: %s",
|
||||
consumer.Name,
|
||||
tokenStr,
|
||||
tokenLogValue,
|
||||
err.Error(),
|
||||
),
|
||||
denied: deniedJWTVerificationFails,
|
||||
}
|
||||
}
|
||||
|
||||
// 此处可以直接使用 JSON 反序列 jwks
|
||||
jwks := jose.JSONWebKeySet{}
|
||||
err = json.Unmarshal([]byte(consumer.JWKs), &jwks)
|
||||
if consumer.RemoteJWKs != nil {
|
||||
// Avoid remote JWKS fetches for tokens that cannot belong to this issuer.
|
||||
// Signature and time claims are still verified after keys are loaded.
|
||||
unsafeClaims := jwt.Claims{}
|
||||
if err := token.UnsafeClaimsWithoutVerification(&unsafeClaims); err != nil {
|
||||
return nil, &ErrDenied{
|
||||
msg: fmt.Sprintf("jwt verify failed, consumer: %s, token: %s, reason: failed to parse unsafe claims: %s",
|
||||
consumer.Name,
|
||||
tokenLogValue,
|
||||
err.Error(),
|
||||
),
|
||||
denied: deniedJWTVerificationFails,
|
||||
}
|
||||
}
|
||||
if unsafeClaims.Issuer != consumer.Issuer {
|
||||
return nil, &ErrDenied{
|
||||
msg: fmt.Sprintf("jwt verify failed, consumer: %s, token: %s, reason: issuer does not equal",
|
||||
consumer.Name,
|
||||
tokenLogValue,
|
||||
),
|
||||
denied: deniedJWTVerificationFails,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jwks, err := consumerJWKs(consumer, verifyTime)
|
||||
if err != nil {
|
||||
return &ErrDenied{
|
||||
if isRemoteJWKsCacheMiss(err) {
|
||||
return nil, err
|
||||
}
|
||||
if isRemoteJWKsRefreshThrottled(err) {
|
||||
return nil, &ErrDenied{
|
||||
msg: fmt.Sprintf("jwt verify failed, consumer: %s, token: %s, reason: remote jwks refresh is throttled",
|
||||
consumer.Name,
|
||||
tokenLogValue,
|
||||
),
|
||||
denied: deniedJWTVerificationFails,
|
||||
}
|
||||
}
|
||||
return nil, &ErrDenied{
|
||||
msg: fmt.Sprintf("jwt parse failed, consumer: %s, token: %s, reason: %s",
|
||||
consumer.Name,
|
||||
tokenStr,
|
||||
tokenLogValue,
|
||||
err.Error(),
|
||||
),
|
||||
denied: deniedJWTVerificationFails,
|
||||
@@ -98,9 +141,9 @@ func consumerVerify(consumer *cfg.Consumer, verifyTime time.Time, header HeaderP
|
||||
}
|
||||
|
||||
out := jwt.Claims{}
|
||||
rawClaims := map[string]any{}
|
||||
|
||||
// 提前确认 kid 状态
|
||||
// Check the token key ID before signature verification so remote JWKS can
|
||||
// refresh on unknown keys without trying an arbitrary cached key.
|
||||
var kid string
|
||||
var key jose.JSONWebKey
|
||||
for _, header := range token.Headers {
|
||||
@@ -109,13 +152,38 @@ func consumerVerify(consumer *cfg.Consumer, verifyTime time.Time, header HeaderP
|
||||
break
|
||||
}
|
||||
}
|
||||
// 没有 kid 时选择第一个 key
|
||||
if kid == "" {
|
||||
key = jwks.Keys[0]
|
||||
}
|
||||
|
||||
keys := jwks.Key(kid)
|
||||
if len(keys) == 0 { // kid 不存在时选择第一个 key
|
||||
if consumer.RemoteJWKs == nil {
|
||||
if keys := jwks.Key(""); len(keys) > 0 {
|
||||
key = keys[0]
|
||||
} else {
|
||||
key = jwks.Keys[0]
|
||||
}
|
||||
} else if len(jwks.Keys) == 1 {
|
||||
key = jwks.Keys[0]
|
||||
} else {
|
||||
return nil, &ErrDenied{
|
||||
msg: fmt.Sprintf("jwt verify failed, consumer: %s, token: %s, reason: kid is required for multi-key remote jwks",
|
||||
consumer.Name,
|
||||
tokenLogValue,
|
||||
),
|
||||
denied: deniedJWTVerificationFails,
|
||||
}
|
||||
}
|
||||
} else if keys := jwks.Key(kid); len(keys) == 0 {
|
||||
if consumer.RemoteJWKs != nil {
|
||||
if remoteJWKsFetchedAfter(consumer, verifyTime) {
|
||||
return nil, &ErrDenied{
|
||||
msg: fmt.Sprintf("jwt verify failed, consumer: %s, token: %s, reason: kid does not match remote jwks",
|
||||
consumer.Name,
|
||||
tokenLogValue,
|
||||
),
|
||||
denied: deniedJWTVerificationFails,
|
||||
}
|
||||
} else {
|
||||
return nil, errRemoteJWKsCacheMiss
|
||||
}
|
||||
}
|
||||
key = jwks.Keys[0]
|
||||
} else {
|
||||
key = keys[0]
|
||||
@@ -123,24 +191,24 @@ func consumerVerify(consumer *cfg.Consumer, verifyTime time.Time, header HeaderP
|
||||
|
||||
// Claims 支持直接传入 jose 的 jwk
|
||||
// 无需额外调用verify,claims内部已进行验证
|
||||
err = token.Claims(key, &out)
|
||||
rawClaims := map[string]any{}
|
||||
err = token.Claims(key, &out, &rawClaims)
|
||||
if err != nil {
|
||||
return &ErrDenied{
|
||||
return nil, &ErrDenied{
|
||||
msg: fmt.Sprintf("jwt verify failed, consumer: %s, token: %s, reason: %s",
|
||||
consumer.Name,
|
||||
tokenStr,
|
||||
tokenLogValue,
|
||||
err.Error(),
|
||||
),
|
||||
denied: deniedJWTVerificationFails,
|
||||
}
|
||||
}
|
||||
token.UnsafeClaimsWithoutVerification(&rawClaims)
|
||||
|
||||
if out.Issuer != consumer.Issuer {
|
||||
return &ErrDenied{
|
||||
return nil, &ErrDenied{
|
||||
msg: fmt.Sprintf("jwt verify failed, consumer: %s, token: %s, reason: issuer does not equal",
|
||||
consumer.Name,
|
||||
tokenStr,
|
||||
tokenLogValue,
|
||||
),
|
||||
denied: deniedJWTVerificationFails,
|
||||
}
|
||||
@@ -155,20 +223,35 @@ func consumerVerify(consumer *cfg.Consumer, verifyTime time.Time, header HeaderP
|
||||
time.Duration(*consumer.ClockSkewSeconds)*time.Second,
|
||||
)
|
||||
if err != nil {
|
||||
return &ErrDenied{
|
||||
return nil, &ErrDenied{
|
||||
msg: fmt.Sprintf("jwt verify failed, consumer: %s, token: %s, reason: %s",
|
||||
consumer.Name,
|
||||
tokenStr,
|
||||
tokenLogValue,
|
||||
err.Error(),
|
||||
),
|
||||
denied: deniedJWTExpired,
|
||||
}
|
||||
}
|
||||
|
||||
if consumer.ClaimsToHeaders != nil {
|
||||
claimsToHeader(rawClaims, *consumer.ClaimsToHeaders)
|
||||
return &verifiedConsumer{claims: rawClaims}, nil
|
||||
}
|
||||
|
||||
func jwtLogValue(token string) logSafeJWT {
|
||||
return logSafeJWT(token)
|
||||
}
|
||||
|
||||
func (t logSafeJWT) String() string {
|
||||
sum := sha256.Sum256([]byte(t))
|
||||
return "sha256:" + hex.EncodeToString(sum[:8])
|
||||
}
|
||||
|
||||
func applyConsumerSideEffects(consumer *cfg.Consumer, verified *verifiedConsumer, header HeaderProvider, log Logger) {
|
||||
if !*consumer.KeepToken {
|
||||
_ = extractToken(false, consumer, header, log)
|
||||
}
|
||||
if consumer.ClaimsToHeaders != nil {
|
||||
claimsToHeader(verified.claims, *consumer.ClaimsToHeaders)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func deniedJWTMissing() types.Action {
|
||||
|
||||
Reference in New Issue
Block a user