Files
higress/plugins/wasm-go/extensions/jwt-auth/config/parser.go
2026-05-25 16:04:10 +08:00

249 lines
7.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Copyright (c) 2023 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package config
import (
	"encoding/json"
	"fmt"
	"strings"
	"unicode"

	"github.com/go-jose/go-jose/v3"
	"github.com/higress-group/wasm-go/pkg/log"
	"github.com/tidwall/gjson"
)
// Limits applied by ParseConsumer to remote JWKS settings.
const (
	// maxJWKsFetchTimeout is the upper bound for jwks_fetch_timeout, in
	// milliseconds (10 seconds).
	maxJWKsFetchTimeout = int64(10_000)
	// maxJWKsCacheDuration is the upper bound for jwks_cache_duration, in
	// seconds (7 days).
	maxJWKsCacheDuration = int64(7 * 24 * 60 * 60)
	// minJWKsCacheDuration is the lower bound for jwks_cache_duration; it
	// mirrors RemoteJWKsMinRefreshIntervalSeconds, declared elsewhere in
	// this package.
	minJWKsCacheDuration = RemoteJWKsMinRefreshIntervalSeconds
)
// ParseGlobalConfig converts the plugin-wide (global) configuration supplied
// by the wrapper into the runtime representation used by the plugin. Domain-
// and route-level configuration is handled separately by ParseRuleConfig.
//
// config.RuleSet is set to true when "_rules_" contains at least one element.
// Each element of "consumers" is parsed with ParseConsumer; entries that fail
// (including duplicate names) are logged as warnings and skipped instead of
// failing the whole configuration.
//
// An error is returned when "consumers" is not an array, or when no consumer
// survives parsing.
func ParseGlobalConfig(json gjson.Result, config *JWTAuthConfig, log log.Log) error {
	config.RuleSet = len(json.Get("_rules_").Array()) > 0

	consumerList := json.Get("consumers")
	if !consumerList.IsArray() {
		return fmt.Errorf("failed to parse configuration for consumers: consumers is not a array")
	}

	// Shared across all entries so ParseConsumer can reject duplicate names.
	seen := make(map[string]struct{})
	for _, raw := range consumerList.Array() {
		consumer, parseErr := ParseConsumer(raw, seen)
		if parseErr != nil {
			log.Warn(parseErr.Error())
			continue
		}
		config.Consumers = append(config.Consumers, consumer)
	}

	if len(config.Consumers) > 0 {
		return nil
	}
	return fmt.Errorf("at least one consumer should be configured for a rule")
}
// ParseRuleConfig converts a domain- or route-level rule supplied by the
// wrapper into the runtime representation used by the plugin. The rule starts
// as a copy of the global configuration (parsed by ParseGlobalConfig) and then
// layers rule-specific fields on top.
//
// The rule must contain a non-empty "allow" entry; each element is appended to
// config.Allow as a string. On success config.RuleSet is set to true.
//
// Errors are returned when "allow" is missing or empty. In that case config
// has already been overwritten with the global configuration.
func ParseRuleConfig(json gjson.Result, global JWTAuthConfig, config *JWTAuthConfig, log log.Log) error {
	// Inherit everything from the global configuration.
	*config = global

	allow := json.Get("allow")
	if !allow.Exists() {
		return fmt.Errorf("allow is required")
	}
	items := allow.Array()
	if len(items) == 0 {
		return fmt.Errorf("allow cannot be empty")
	}

	// `*config = global` copies only the slice header of Allow. Appending to it
	// directly could write into spare capacity of the global config's backing
	// array, leaking entries between rules built from the same global config.
	// Clipping the capacity forces append to allocate a private array.
	config.Allow = config.Allow[:len(config.Allow):len(config.Allow)]
	for _, item := range items {
		config.Allow = append(config.Allow, item.String())
	}
	config.RuleSet = true
	return nil
}
// ParseConsumer decodes one entry of the "consumers" array, validates it, and
// fills in default values.
//
// names holds the consumer names accepted so far. A consumer whose name is
// already present is rejected. On success the new name is added to names.
// When an error is returned, the returned consumer is nil and names is left
// unchanged.
//
// Validation rules:
//   - exactly one of jwks and remote_jwks must be configured;
//   - inline jwks must decode as a JSON Web Key Set with at least one key, and
//     must not be combined with jwks_cache_duration or jwks_fetch_timeout;
//   - remote_jwks requires a non-empty issuer and must pass validateRemoteJWKs;
//   - ClaimsToHeaders must not map two claims onto the same header;
//   - with remote_jwks, jwks_cache_duration must be positive and within
//     [minJWKsCacheDuration, maxJWKsCacheDuration], and jwks_fetch_timeout
//     must be positive and at most maxJWKsFetchTimeout.
//
// Decoding and validation errors are wrapped with %w so callers can unwrap
// the underlying cause.
func ParseConsumer(consumer gjson.Result, names map[string]struct{}) (c *Consumer, err error) {
	c = &Consumer{}
	// Decode the raw JSON held by gjson with the standard library, so field
	// mapping lives in the struct tags instead of hand-written gjson lookups.
	if err = json.Unmarshal([]byte(consumer.Raw), c); err != nil {
		return nil, fmt.Errorf("failed to parse consumer: %w", err)
	}
	// Consumer names must be unique across the configuration.
	if _, ok := names[c.Name]; ok {
		return nil, fmt.Errorf("consumer already exists: %s", c.Name)
	}

	c.Issuer = strings.TrimSpace(c.Issuer)
	c.JWKs = strings.TrimSpace(c.JWKs)
	if c.RemoteJWKs != nil {
		normalizeRemoteJWKs(c.RemoteJWKs)
	}

	// Exactly one key source must be configured.
	if c.JWKs == "" && c.RemoteJWKs == nil {
		return nil, fmt.Errorf("one of jwks and remote_jwks is required, consumer:%s", c.Name)
	}
	if c.JWKs != "" && c.RemoteJWKs != nil {
		return nil, fmt.Errorf("only one of jwks and remote_jwks can be configured, consumer:%s", c.Name)
	}

	if c.JWKs != "" {
		if c.JWKsCacheDuration != nil || c.JWKsFetchTimeout != nil {
			return nil, fmt.Errorf("jwks_cache_duration and jwks_fetch_timeout only apply to remote_jwks, consumer:%s", c.Name)
		}
		// Validate the inline JWKS before accepting the consumer and keep the
		// parsed form on the consumer.
		jwks := &jose.JSONWebKeySet{}
		if err = json.Unmarshal([]byte(c.JWKs), jwks); err != nil {
			return nil, fmt.Errorf("jwks is invalid, consumer:%s, status:%w, jwks:%s", c.Name, err, c.JWKs)
		}
		if len(jwks.Keys) == 0 {
			return nil, fmt.Errorf("jwks is empty, consumer:%s", c.Name)
		}
		c.ParsedJWKs = jwks
	}

	if c.RemoteJWKs != nil {
		if c.Issuer == "" {
			return nil, fmt.Errorf("issuer is required when remote_jwks is set, consumer:%s", c.Name)
		}
		if err = validateRemoteJWKs(c.RemoteJWKs); err != nil {
			return nil, fmt.Errorf("remote_jwks is invalid, consumer:%s, reason:%w", c.Name, err)
		}
	}

	// Use the default token sources only when none of them is configured.
	if c.FromHeaders == nil && c.FromParams == nil && c.FromCookies == nil {
		c.FromHeaders = &DefaultFromHeader
		c.FromParams = &DefaultFromParams
		c.FromCookies = &DefaultFromCookies
	}

	if c.ClaimsToHeaders != nil {
		// Dereference into a local so elements can be indexed and updated in
		// place. NOTE(review): the Override defaults below only persist because
		// ClaimsToHeaders is a pointer to a slice, whose copy shares the
		// backing array — confirm against the type definition.
		claims := *c.ClaimsToHeaders
		seenHeaders := make(map[string]struct{}, len(claims))
		for i := range claims {
			header := claims[i].Header
			if _, dup := seenHeaders[header]; dup {
				// Report the offending header name (previously the map's
				// struct{} value was formatted here by mistake).
				return nil, fmt.Errorf("claim to header already exists: %s", header)
			}
			seenHeaders[header] = struct{}{}
			if claims[i].Override == nil {
				claims[i].Override = &DefaultClaimToHeaderOverride
			}
		}
	}

	if c.ClockSkewSeconds == nil {
		c.ClockSkewSeconds = &DefaultClockSkewSeconds
	}
	if c.KeepToken == nil {
		c.KeepToken = &DefaultKeepToken
	}

	if c.RemoteJWKs != nil {
		if c.JWKsCacheDuration == nil {
			// Copy the default so each consumer owns its own value.
			v := DefaultJWKsCacheDuration
			c.JWKsCacheDuration = &v
		}
		if *c.JWKsCacheDuration <= 0 {
			return nil, fmt.Errorf("jwks_cache_duration must be positive, consumer:%s", c.Name)
		}
		if *c.JWKsCacheDuration < minJWKsCacheDuration {
			return nil, fmt.Errorf("jwks_cache_duration must be greater than or equal to %d, consumer:%s", minJWKsCacheDuration, c.Name)
		}
		if *c.JWKsCacheDuration > maxJWKsCacheDuration {
			return nil, fmt.Errorf("jwks_cache_duration must be less than or equal to %d, consumer:%s", maxJWKsCacheDuration, c.Name)
		}

		if c.JWKsFetchTimeout == nil {
			v := DefaultJWKsFetchTimeout
			c.JWKsFetchTimeout = &v
		}
		if *c.JWKsFetchTimeout <= 0 {
			return nil, fmt.Errorf("jwks_fetch_timeout must be positive, consumer:%s", c.Name)
		}
		if *c.JWKsFetchTimeout > maxJWKsFetchTimeout {
			return nil, fmt.Errorf("jwks_fetch_timeout must be less than or equal to %d, consumer:%s", maxJWKsFetchTimeout, c.Name)
		}
	}

	// The consumer is valid: record its name so later duplicates are rejected.
	names[c.Name] = struct{}{}
	return c, nil
}
// normalizeRemoteJWKs trims surrounding whitespace from the service name,
// service host and path of remote, and sets ServicePort to 443 when it was
// not configured. An explicitly configured port is left untouched.
func normalizeRemoteJWKs(remote *RemoteJWKs) {
	for _, field := range []*string{&remote.ServiceName, &remote.ServiceHost, &remote.Path} {
		*field = strings.TrimSpace(*field)
	}
	if remote.ServicePort != nil {
		return
	}
	defaultPort := int64(443)
	remote.ServicePort = &defaultPort
}
// validateRemoteJWKs checks that a remote JWKS endpoint definition is usable.
// It is meant to run after normalizeRemoteJWKs, which trims the string fields
// and defaults ServicePort. A nil ServicePort is nevertheless reported as an
// error instead of being dereferenced.
//
// Rules:
//   - service_name is required and must not contain whitespace, control
//     characters, or any of | / ? # @ :
//   - service_host is required, must not contain whitespace or control
//     characters, and must be a bare host: / ? # @ : are all rejected, which
//     excludes schemes, paths, queries, fragments, userinfo and ports (and,
//     as a consequence, IPv6 literals);
//   - path must start with "/" and must not contain whitespace or control
//     characters;
//   - service_port must be within 1..65535.
func validateRemoteJWKs(remote *RemoteJWKs) error {
	if remote.ServiceName == "" {
		return fmt.Errorf("service_name is required")
	}
	if hasInvalidRemoteJWKsFieldChar(remote.ServiceName) || strings.ContainsAny(remote.ServiceName, "|/?#@:") {
		return fmt.Errorf("service_name must not contain whitespace, control characters, or URI separators")
	}
	if remote.ServiceHost == "" {
		return fmt.Errorf("service_host is required")
	}
	if hasInvalidRemoteJWKsFieldChar(remote.ServiceHost) {
		return fmt.Errorf("service_host must not contain whitespace or control characters")
	}
	// Rejecting ':' alone already covers both "scheme://" prefixes and
	// ":port" suffixes.
	if strings.ContainsAny(remote.ServiceHost, "/?#@:") {
		return fmt.Errorf("service_host must be a host without port")
	}
	// HasPrefix also rejects the empty path.
	if !strings.HasPrefix(remote.Path, "/") {
		return fmt.Errorf("path must start with /")
	}
	if hasInvalidRemoteJWKsFieldChar(remote.Path) {
		return fmt.Errorf("path must not contain whitespace or control characters")
	}
	if remote.ServicePort == nil || *remote.ServicePort <= 0 || *remote.ServicePort > 65535 {
		return fmt.Errorf("service_port is invalid")
	}
	return nil
}
// hasInvalidRemoteJWKsFieldChar reports whether value contains whitespace or
// a control character.
//
// Whitespace is judged by unicode.IsSpace. That covers ASCII space, tab, CR
// and LF, and also Unicode spaces such as U+00A0 (no-break space) and U+3000
// (ideographic space), which are easy to paste into configuration unnoticed.
func hasInvalidRemoteJWKsFieldChar(value string) bool {
	return strings.IndexFunc(value, unicode.IsSpace) >= 0 || hasControlChar(value)
}
// hasControlChar reports whether value contains a control character as
// defined by unicode.IsControl: the C0 range U+0000–U+001F, DEL (U+007F), or
// the C1 range U+0080–U+009F (for example U+0085 NEXT LINE).
// Invalid UTF-8 bytes decode to U+FFFD and are not treated as control
// characters.
func hasControlChar(value string) bool {
	return strings.IndexFunc(value, unicode.IsControl) >= 0
}