feat(jwt-auth): support remote JWKS (#3838)

Signed-off-by: Betula-L <6059935+Betula-L@users.noreply.github.com>
Co-authored-by: Betula-L <6059935+Betula-L@users.noreply.github.com>
This commit is contained in:
Betula-L
2026-05-25 16:04:10 +08:00
committed by GitHub
parent e6fc09b14f
commit a86aaadaa4
17 changed files with 2780 additions and 115 deletions

View File

@@ -17,19 +17,21 @@ package config
import (
"encoding/json"
"fmt"
"strings"
"github.com/go-jose/go-jose/v3"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/tidwall/gjson"
)
// RuleSet 插件是否至少在一个 domain 或 route 上生效
var RuleSet bool
const maxJWKsFetchTimeout = int64(10 * 1000) // milliseconds
const maxJWKsCacheDuration = int64(7 * 24 * 3600) // seconds
const minJWKsCacheDuration = RemoteJWKsMinRefreshIntervalSeconds
// ParseGlobalConfig 从wrapper提供的配置中解析并转换到插件运行时需要使用的配置。
// 此处解析的是全局配置,域名和路由级配置由 ParseRuleConfig 负责。
func ParseGlobalConfig(json gjson.Result, config *JWTAuthConfig, log log.Log) error {
RuleSet = false
config.RuleSet = len(json.Get("_rules_").Array()) > 0
consumers := json.Get("consumers")
if !consumers.IsArray() {
return fmt.Errorf("failed to parse configuration for consumers: consumers is not a array")
@@ -70,7 +72,7 @@ func ParseRuleConfig(json gjson.Result, global JWTAuthConfig, config *JWTAuthCon
config.Allow = append(config.Allow, item.String())
}
RuleSet = true
config.RuleSet = true
return nil
}
@@ -88,11 +90,39 @@ func ParseConsumer(consumer gjson.Result, names map[string]struct{}) (c *Consume
return nil, fmt.Errorf("consumer already exists: %s", c.Name)
}
// 检查JWKs是否合法
jwks := &jose.JSONWebKeySet{}
err = json.Unmarshal([]byte(c.JWKs), jwks)
if err != nil {
return nil, fmt.Errorf("jwks is invalid, consumer:%s, status:%s, jwks:%s", c.Name, err.Error(), c.JWKs)
c.Issuer = strings.TrimSpace(c.Issuer)
c.JWKs = strings.TrimSpace(c.JWKs)
if c.RemoteJWKs != nil {
normalizeRemoteJWKs(c.RemoteJWKs)
}
if c.JWKs == "" && c.RemoteJWKs == nil {
return nil, fmt.Errorf("one of jwks and remote_jwks is required, consumer:%s", c.Name)
}
if c.JWKs != "" && c.RemoteJWKs != nil {
return nil, fmt.Errorf("only one of jwks and remote_jwks can be configured, consumer:%s", c.Name)
}
if c.JWKs != "" {
if c.JWKsCacheDuration != nil || c.JWKsFetchTimeout != nil {
return nil, fmt.Errorf("jwks_cache_duration and jwks_fetch_timeout only apply to remote_jwks, consumer:%s", c.Name)
}
// Validate inline JWKS before accepting the consumer.
jwks := &jose.JSONWebKeySet{}
err = json.Unmarshal([]byte(c.JWKs), jwks)
if err != nil {
return nil, fmt.Errorf("jwks is invalid, consumer:%s, status:%s, jwks:%s", c.Name, err.Error(), c.JWKs)
}
if len(jwks.Keys) == 0 {
return nil, fmt.Errorf("jwks is empty, consumer:%s", c.Name)
}
c.ParsedJWKs = jwks
}
if c.RemoteJWKs != nil {
if c.Issuer == "" {
return nil, fmt.Errorf("issuer is required when remote_jwks is set, consumer:%s", c.Name)
}
if err := validateRemoteJWKs(c.RemoteJWKs); err != nil {
return nil, fmt.Errorf("remote_jwks is invalid, consumer:%s, reason:%s", c.Name, err.Error())
}
}
// 检查是否需要使用默认jwt抽取来源
@@ -132,7 +162,87 @@ func ParseConsumer(consumer gjson.Result, names map[string]struct{}) (c *Consume
c.KeepToken = &DefaultKeepToken
}
if c.RemoteJWKs != nil {
// Fill the default remote JWKS cache duration.
if c.JWKsCacheDuration == nil {
v := DefaultJWKsCacheDuration
c.JWKsCacheDuration = &v
}
if *c.JWKsCacheDuration <= 0 {
return nil, fmt.Errorf("jwks_cache_duration must be positive, consumer:%s", c.Name)
}
if *c.JWKsCacheDuration < minJWKsCacheDuration {
return nil, fmt.Errorf("jwks_cache_duration must be greater than or equal to %d, consumer:%s", minJWKsCacheDuration, c.Name)
}
if *c.JWKsCacheDuration > maxJWKsCacheDuration {
return nil, fmt.Errorf("jwks_cache_duration must be less than or equal to %d, consumer:%s", maxJWKsCacheDuration, c.Name)
}
// Fill the default remote JWKS fetch timeout.
if c.JWKsFetchTimeout == nil {
v := DefaultJWKsFetchTimeout
c.JWKsFetchTimeout = &v
}
if *c.JWKsFetchTimeout <= 0 {
return nil, fmt.Errorf("jwks_fetch_timeout must be positive, consumer:%s", c.Name)
}
if *c.JWKsFetchTimeout > maxJWKsFetchTimeout {
return nil, fmt.Errorf("jwks_fetch_timeout must be less than or equal to %d, consumer:%s", maxJWKsFetchTimeout, c.Name)
}
}
// consumer合法记录consumer名称
names[c.Name] = struct{}{}
return c, nil
}
func normalizeRemoteJWKs(remote *RemoteJWKs) {
remote.ServiceName = strings.TrimSpace(remote.ServiceName)
remote.ServiceHost = strings.TrimSpace(remote.ServiceHost)
remote.Path = strings.TrimSpace(remote.Path)
if remote.ServicePort == nil {
v := int64(443)
remote.ServicePort = &v
}
}
func validateRemoteJWKs(remote *RemoteJWKs) error {
if remote.ServiceName == "" {
return fmt.Errorf("service_name is required")
}
if hasInvalidRemoteJWKsFieldChar(remote.ServiceName) || strings.ContainsAny(remote.ServiceName, "|/?#@:") {
return fmt.Errorf("service_name must not contain whitespace, control characters, or URI separators")
}
if remote.ServiceHost == "" {
return fmt.Errorf("service_host is required")
}
if hasInvalidRemoteJWKsFieldChar(remote.ServiceHost) {
return fmt.Errorf("service_host must not contain whitespace or control characters")
}
if strings.ContainsAny(remote.ServiceHost, "/?#@:") || strings.Contains(remote.ServiceHost, "://") {
return fmt.Errorf("service_host must be a host without port")
}
if remote.Path == "" || !strings.HasPrefix(remote.Path, "/") {
return fmt.Errorf("path must start with /")
}
if hasInvalidRemoteJWKsFieldChar(remote.Path) {
return fmt.Errorf("path must not contain whitespace or control characters")
}
if *remote.ServicePort <= 0 || *remote.ServicePort > 65535 {
return fmt.Errorf("service_port is invalid")
}
return nil
}
func hasInvalidRemoteJWKsFieldChar(value string) bool {
return strings.ContainsAny(value, " \t\r\n") || hasControlChar(value)
}
func hasControlChar(value string) bool {
for _, r := range value {
if r < 0x20 || r == 0x7f {
return true
}
}
return false
}