mirror of
https://github.com/alibaba/higress.git
synced 2026-06-04 01:57:26 +08:00
Signed-off-by: Betula-L <6059935+Betula-L@users.noreply.github.com> Co-authored-by: Betula-L <6059935+Betula-L@users.noreply.github.com>
249 lines
7.9 KiB
Go
249 lines
7.9 KiB
Go
// Copyright (c) 2023 Alibaba Group Holding Ltd.
|
||
//
|
||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||
// you may not use this file except in compliance with the License.
|
||
// You may obtain a copy of the License at
|
||
//
|
||
// http://www.apache.org/licenses/LICENSE-2.0
|
||
//
|
||
// Unless required by applicable law or agreed to in writing, software
|
||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
// See the License for the specific language governing permissions and
|
||
// limitations under the License.
|
||
|
||
package config
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"strings"
|
||
|
||
"github.com/go-jose/go-jose/v3"
|
||
"github.com/higress-group/wasm-go/pkg/log"
|
||
"github.com/tidwall/gjson"
|
||
)
|
||
|
||
const maxJWKsFetchTimeout = int64(10 * 1000) // milliseconds
|
||
const maxJWKsCacheDuration = int64(7 * 24 * 3600) // seconds
|
||
const minJWKsCacheDuration = RemoteJWKsMinRefreshIntervalSeconds
|
||
|
||
// ParseGlobalConfig parses the plugin-wide (global) configuration supplied by
// the wrapper into the runtime JWTAuthConfig. Domain- and route-level settings
// are parsed separately by ParseRuleConfig.
//
// Every entry of "consumers" is handed to ParseConsumer. Entries that fail to
// parse are logged as warnings and skipped. An error is returned when
// "consumers" is not an array or when no consumer ends up configured.
func ParseGlobalConfig(json gjson.Result, config *JWTAuthConfig, log log.Log) error {
	// A non-empty "_rules_" array marks the presence of rule-level configuration.
	rules := json.Get("_rules_").Array()
	config.RuleSet = len(rules) > 0

	rawConsumers := json.Get("consumers")
	if !rawConsumers.IsArray() {
		return fmt.Errorf("failed to parse configuration for consumers: consumers is not a array")
	}

	// Names accepted so far; ParseConsumer uses it to reject duplicates.
	seen := make(map[string]struct{})
	for _, entry := range rawConsumers.Array() {
		parsed, parseErr := ParseConsumer(entry, seen)
		if parseErr == nil {
			config.Consumers = append(config.Consumers, parsed)
			continue
		}
		// Best effort: an invalid consumer is reported but does not abort parsing.
		log.Warn(parseErr.Error())
	}

	if len(config.Consumers) > 0 {
		return nil
	}
	return fmt.Errorf("at least one consumer should be configured for a rule")
}
|
||
|
||
// ParseRuleConfig parses a domain- or route-level rule into config. It starts
// from a copy of the global configuration, so the global consumer definitions
// are inherited. Plugin-wide settings are parsed by ParseGlobalConfig.
//
// The rule must carry a non-empty "allow" list naming the consumers permitted
// on the matched domain/route; otherwise an error is returned. On success
// config.Allow holds the global entries (if any) followed by the rule's
// entries, and config.RuleSet is true.
func ParseRuleConfig(json gjson.Result, global JWTAuthConfig, config *JWTAuthConfig, log log.Log) error {
	// Start from the global configuration (consumers etc.).
	*config = global

	allow := json.Get("allow")
	if !allow.Exists() {
		return fmt.Errorf("allow is required")
	}

	items := allow.Array()
	if len(items) == 0 {
		return fmt.Errorf("allow cannot be empty")
	}

	// Build a fresh slice rather than appending to the one copied from global.
	// `*config = global` copies only the slice header. Appending to a
	// global.Allow that has spare capacity would write into a backing array
	// shared with the global config and with every other rule derived from it.
	allowed := make([]string, 0, len(global.Allow)+len(items))
	allowed = append(allowed, global.Allow...)
	for _, item := range items {
		allowed = append(allowed, item.String())
	}
	config.Allow = allowed

	config.RuleSet = true
	return nil
}
|
||
|
||
// ParseConsumer decodes one entry of the "consumers" array, validates it, and
// fills in defaults for optional fields.
//
// names holds the consumer names accepted so far. A consumer whose name is
// already present is rejected. A consumer that passes every check has its name
// added to names. On failure the returned consumer is nil, names is left
// unchanged, and err describes the first problem found (wrapping the
// underlying decode/validation error where there is one).
//
// Exactly one of jwks (inline key set) or remote_jwks must be configured:
//   - inline jwks must decode to a non-empty JSON Web Key Set and must not be
//     combined with jwks_cache_duration / jwks_fetch_timeout;
//   - remote_jwks requires an issuer, a valid remote endpoint, a
//     jwks_cache_duration within [minJWKsCacheDuration, maxJWKsCacheDuration]
//     seconds, and a jwks_fetch_timeout within (0, maxJWKsFetchTimeout]
//     milliseconds.
func ParseConsumer(consumer gjson.Result, names map[string]struct{}) (c *Consumer, err error) {
	c = &Consumer{}

	// Take the raw JSON text from gjson and decode it with encoding/json so the
	// Consumer struct tags drive the field mapping.
	err = json.Unmarshal([]byte(consumer.Raw), c)
	if err != nil {
		return nil, fmt.Errorf("failed to parse consumer: %w", err)
	}

	// Reject duplicate consumer names.
	if _, ok := names[c.Name]; ok {
		return nil, fmt.Errorf("consumer already exists: %s", c.Name)
	}

	c.Issuer = strings.TrimSpace(c.Issuer)
	c.JWKs = strings.TrimSpace(c.JWKs)
	if c.RemoteJWKs != nil {
		normalizeRemoteJWKs(c.RemoteJWKs)
	}
	if c.JWKs == "" && c.RemoteJWKs == nil {
		return nil, fmt.Errorf("one of jwks and remote_jwks is required, consumer:%s", c.Name)
	}
	if c.JWKs != "" && c.RemoteJWKs != nil {
		return nil, fmt.Errorf("only one of jwks and remote_jwks can be configured, consumer:%s", c.Name)
	}

	if c.JWKs != "" {
		if c.JWKsCacheDuration != nil || c.JWKsFetchTimeout != nil {
			return nil, fmt.Errorf("jwks_cache_duration and jwks_fetch_timeout only apply to remote_jwks, consumer:%s", c.Name)
		}
		// Validate the inline JWKS before accepting the consumer and keep the
		// parsed key set on the consumer.
		jwks := &jose.JSONWebKeySet{}
		err = json.Unmarshal([]byte(c.JWKs), jwks)
		if err != nil {
			// NOTE(review): the raw jwks text is echoed into this error, and
			// callers such as ParseGlobalConfig log it; confirm it can never
			// carry private key material.
			return nil, fmt.Errorf("jwks is invalid, consumer:%s, status:%w, jwks:%s", c.Name, err, c.JWKs)
		}
		if len(jwks.Keys) == 0 {
			return nil, fmt.Errorf("jwks is empty, consumer:%s", c.Name)
		}
		c.ParsedJWKs = jwks
	}

	if c.RemoteJWKs != nil {
		if c.Issuer == "" {
			return nil, fmt.Errorf("issuer is required when remote_jwks is set, consumer:%s", c.Name)
		}
		if verr := validateRemoteJWKs(c.RemoteJWKs); verr != nil {
			return nil, fmt.Errorf("remote_jwks is invalid, consumer:%s, reason:%w", c.Name, verr)
		}
	}

	// Fall back to the default token sources only when none is configured.
	if c.FromHeaders == nil && c.FromParams == nil && c.FromCookies == nil {
		// These point at package-level defaults shared by every consumer, so
		// they must not be mutated through an individual consumer.
		c.FromHeaders = &DefaultFromHeader
		c.FromParams = &DefaultFromParams
		c.FromCookies = &DefaultFromCookies
	}

	// Validate claims_to_headers: target header names must be unique.
	if c.ClaimsToHeaders != nil {
		// claims shares its backing array with *c.ClaimsToHeaders (a slice), so
		// the Override defaults written through claims[i] land on the consumer.
		claims := *c.ClaimsToHeaders
		seenHeaders := make(map[string]struct{}, len(claims))
		for i := range claims {
			header := claims[i].Header
			if _, dup := seenHeaders[header]; dup {
				// Report the duplicated header name itself. Formatting the map
				// value (struct{}{}) would always print "{}".
				return nil, fmt.Errorf("claim to header already exists: %s", header)
			}
			seenHeaders[header] = struct{}{}

			// Default Override when unset.
			if claims[i].Override == nil {
				claims[i].Override = &DefaultClaimToHeaderOverride
			}
		}
	}

	// Default ClockSkewSeconds when unset.
	if c.ClockSkewSeconds == nil {
		c.ClockSkewSeconds = &DefaultClockSkewSeconds
	}

	// Default KeepToken when unset.
	if c.KeepToken == nil {
		c.KeepToken = &DefaultKeepToken
	}

	if c.RemoteJWKs != nil {
		// Default and bound-check the remote JWKS cache duration (seconds).
		// A per-consumer copy is taken so the default value is never shared.
		if c.JWKsCacheDuration == nil {
			v := DefaultJWKsCacheDuration
			c.JWKsCacheDuration = &v
		}
		if *c.JWKsCacheDuration <= 0 {
			return nil, fmt.Errorf("jwks_cache_duration must be positive, consumer:%s", c.Name)
		}
		if *c.JWKsCacheDuration < minJWKsCacheDuration {
			return nil, fmt.Errorf("jwks_cache_duration must be greater than or equal to %d, consumer:%s", minJWKsCacheDuration, c.Name)
		}
		if *c.JWKsCacheDuration > maxJWKsCacheDuration {
			return nil, fmt.Errorf("jwks_cache_duration must be less than or equal to %d, consumer:%s", maxJWKsCacheDuration, c.Name)
		}

		// Default and bound-check the remote JWKS fetch timeout (milliseconds).
		if c.JWKsFetchTimeout == nil {
			v := DefaultJWKsFetchTimeout
			c.JWKsFetchTimeout = &v
		}
		if *c.JWKsFetchTimeout <= 0 {
			return nil, fmt.Errorf("jwks_fetch_timeout must be positive, consumer:%s", c.Name)
		}
		if *c.JWKsFetchTimeout > maxJWKsFetchTimeout {
			return nil, fmt.Errorf("jwks_fetch_timeout must be less than or equal to %d, consumer:%s", maxJWKsFetchTimeout, c.Name)
		}
	}

	// The consumer is valid: record its name so later duplicates are rejected.
	names[c.Name] = struct{}{}
	return c, nil
}
|
||
|
||
// normalizeRemoteJWKs prepares a remote_jwks definition for validation. It
// mutates remote in place:
//   - it trims surrounding whitespace from service_name, service_host and path;
//   - it defaults service_port to 443 when the port is omitted.
func normalizeRemoteJWKs(remote *RemoteJWKs) {
	for _, field := range []*string{&remote.ServiceName, &remote.ServiceHost, &remote.Path} {
		*field = strings.TrimSpace(*field)
	}

	if remote.ServicePort != nil {
		return
	}
	defaultPort := int64(443)
	remote.ServicePort = &defaultPort
}
|
||
|
||
// validateRemoteJWKs checks a remote_jwks definition and returns an error
// describing the first invalid field, or nil if it is acceptable. It expects
// the values normalizeRemoteJWKs produces (trimmed strings, defaulted port).
// A nil definition or a missing port is reported as an error instead of
// panicking on a nil dereference.
func validateRemoteJWKs(remote *RemoteJWKs) error {
	if remote == nil {
		return fmt.Errorf("remote_jwks is required")
	}

	if remote.ServiceName == "" {
		return fmt.Errorf("service_name is required")
	}
	// service_name additionally must not contain '|' or URI separators.
	if hasInvalidRemoteJWKsFieldChar(remote.ServiceName) || strings.ContainsAny(remote.ServiceName, "|/?#@:") {
		return fmt.Errorf("service_name must not contain whitespace, control characters, or URI separators")
	}

	if remote.ServiceHost == "" {
		return fmt.Errorf("service_host is required")
	}
	if hasInvalidRemoteJWKsFieldChar(remote.ServiceHost) {
		return fmt.Errorf("service_host must not contain whitespace or control characters")
	}
	// A bare host only: rejecting ':' and '/' also rules out schemes ("://")
	// and ports, and '?', '#', '@' rule out query, fragment and userinfo.
	if strings.ContainsAny(remote.ServiceHost, "/?#@:") {
		return fmt.Errorf("service_host must be a host without port")
	}

	// HasPrefix is false for the empty string, so this also rejects "".
	if !strings.HasPrefix(remote.Path, "/") {
		return fmt.Errorf("path must start with /")
	}
	if hasInvalidRemoteJWKsFieldChar(remote.Path) {
		return fmt.Errorf("path must not contain whitespace or control characters")
	}

	if remote.ServicePort == nil || *remote.ServicePort <= 0 || *remote.ServicePort > 65535 {
		return fmt.Errorf("service_port is invalid")
	}
	return nil
}
|
||
|
||
// hasInvalidRemoteJWKsFieldChar reports whether value contains a character
// that is never allowed in a remote_jwks field. That means either a control
// character (as defined by hasControlChar) or one of the ASCII whitespace
// characters space, tab, CR or LF.
func hasInvalidRemoteJWKsFieldChar(value string) bool {
	if hasControlChar(value) {
		return true
	}
	return strings.IndexAny(value, " \t\r\n") >= 0
}
|
||
|
||
// hasControlChar reports whether value contains a Unicode control character
// (general category Cc). These are the C0 range U+0000–U+001F, DEL (U+007F),
// and the C1 range U+0080–U+009F. This is the same set unicode.IsControl
// accepts.
//
// The previous implementation stopped at DEL, so C1 controls such as NEL
// (U+0085) passed the "must not contain control characters" checks.
func hasControlChar(value string) bool {
	for _, r := range value {
		if r < 0x20 || (r >= 0x7f && r <= 0x9f) {
			return true
		}
	}
	return false
}
|