mirror of
https://github.com/alibaba/higress.git
synced 2026-06-01 00:27:26 +08:00
Signed-off-by: Betula-L <6059935+Betula-L@users.noreply.github.com> Co-authored-by: Betula-L <6059935+Betula-L@users.noreply.github.com>
330 lines
10 KiB
Go
330 lines
10 KiB
Go
// Copyright (c) 2023 Alibaba Group Holding Ltd.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package handler
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"strconv"
|
|
"time"
|
|
|
|
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/jwt-auth/config"
|
|
"github.com/go-jose/go-jose/v3"
|
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
|
"github.com/higress-group/wasm-go/pkg/log"
|
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
|
)
|
|
|
|
type cachedJWKs struct {
|
|
keys *jose.JSONWebKeySet
|
|
fetchedAt time.Time
|
|
}
|
|
|
|
type remoteJWKsFetchState struct {
|
|
inFlight bool
|
|
startedAt time.Time
|
|
deadline time.Time
|
|
lastFailedAt time.Time
|
|
}
|
|
|
|
type remoteJWKsCacheKey struct {
|
|
serviceName string
|
|
serviceHost string
|
|
servicePort int64
|
|
path string
|
|
}
|
|
|
|
// These maps are process-local caches in the single-threaded proxy-wasm VM.
|
|
var remoteJWKsCache = map[remoteJWKsCacheKey]cachedJWKs{}
|
|
var remoteJWKsFetchStates = map[remoteJWKsCacheKey]remoteJWKsFetchState{}
|
|
|
|
var errRemoteJWKsCacheMiss = errors.New("remote jwks cache is missing or expired")
|
|
var errRemoteJWKsRefreshThrottled = errors.New("remote jwks refresh is throttled")
|
|
|
|
var dispatchRemoteJWKsHTTPCall = proxywasm.DispatchHttpCall
|
|
|
|
// Failed remote JWKS fetches are backed off per remote service reference.
|
|
// In-flight requests are not coalesced because proxy-wasm callbacks are bound
|
|
// to one HTTP stream context.
|
|
const remoteJWKsMinRefreshInterval = time.Duration(cfg.RemoteJWKsMinRefreshIntervalSeconds) * time.Second
|
|
const maxRemoteJWKsResponseSize = 64 * 1024
|
|
|
|
func remoteJWKsCacheKeyForConsumer(consumer *cfg.Consumer) remoteJWKsCacheKey {
|
|
remote := consumer.RemoteJWKs
|
|
if remote == nil {
|
|
return remoteJWKsCacheKey{}
|
|
}
|
|
return remoteJWKsCacheKey{
|
|
serviceName: remote.ServiceName,
|
|
serviceHost: remote.ServiceHost,
|
|
servicePort: remoteJWKsServicePort(remote),
|
|
path: remote.Path,
|
|
}
|
|
}
|
|
|
|
func PruneRemoteJWKsCache(consumers []*cfg.Consumer) {
|
|
active := make(map[remoteJWKsCacheKey]struct{}, len(consumers))
|
|
for _, consumer := range consumers {
|
|
if consumer != nil && consumer.RemoteJWKs != nil {
|
|
active[remoteJWKsCacheKeyForConsumer(consumer)] = struct{}{}
|
|
}
|
|
}
|
|
for key := range remoteJWKsCache {
|
|
if _, ok := active[key]; !ok {
|
|
delete(remoteJWKsCache, key)
|
|
}
|
|
}
|
|
for key := range remoteJWKsFetchStates {
|
|
if _, ok := active[key]; !ok {
|
|
delete(remoteJWKsFetchStates, key)
|
|
}
|
|
}
|
|
}
|
|
|
|
func consumerJWKs(consumer *cfg.Consumer, now time.Time) (*jose.JSONWebKeySet, error) {
|
|
raw := consumer.JWKs
|
|
if raw == "" {
|
|
cached, ok := remoteJWKsCache[remoteJWKsCacheKeyForConsumer(consumer)]
|
|
cacheDuration := remoteJWKsCacheDuration(consumer)
|
|
if ok && now.Before(cached.fetchedAt.Add(time.Duration(cacheDuration)*time.Second)) {
|
|
return cached.keys, nil
|
|
}
|
|
if ok && remoteJWKsFetchInFlight(consumer, now) {
|
|
return cached.keys, nil
|
|
}
|
|
if !remoteJWKsFetchAllowed(consumer, now) {
|
|
return nil, errRemoteJWKsRefreshThrottled
|
|
}
|
|
return nil, errRemoteJWKsCacheMiss
|
|
}
|
|
|
|
if consumer.ParsedJWKs != nil {
|
|
return consumer.ParsedJWKs, nil
|
|
}
|
|
return parseJWKs(raw)
|
|
}
|
|
|
|
// remoteJWKsFetchedAfter tells the verifier whether this same request has
|
|
// already retried with a freshly fetched JWKS.
|
|
func remoteJWKsFetchedAfter(consumer *cfg.Consumer, t time.Time) bool {
|
|
cached, ok := remoteJWKsCache[remoteJWKsCacheKeyForConsumer(consumer)]
|
|
return ok && cached.fetchedAt.After(t)
|
|
}
|
|
|
|
func isRemoteJWKsCacheMiss(err error) bool {
|
|
return errors.Is(err, errRemoteJWKsCacheMiss)
|
|
}
|
|
|
|
func isRemoteJWKsRefreshThrottled(err error) bool {
|
|
return errors.Is(err, errRemoteJWKsRefreshThrottled)
|
|
}
|
|
|
|
func fetchRemoteJWKs(consumer *cfg.Consumer, log log.Log, callback func()) error {
|
|
cluster, path, err := remoteJWKsFetchCluster(consumer)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
timeout := uint32(remoteJWKsFetchTimeout(consumer))
|
|
startedAt := time.Now()
|
|
if !recordRemoteJWKsFetchStart(consumer, startedAt) {
|
|
return errRemoteJWKsRefreshThrottled
|
|
}
|
|
headers := [][2]string{{"Accept", "application/json"}, {":method", http.MethodGet}, {":path", path}, {":authority", remoteJWKsAuthority(consumer)}}
|
|
_, err = dispatchRemoteJWKsHTTPCall(cluster.ClusterName(), headers, nil, nil, timeout, func(numHeaders, bodySize, numTrailers int) {
|
|
statusCode, err := remoteJWKsResponseStatus()
|
|
if err != nil {
|
|
recordRemoteJWKsFetchFailure(consumer, startedAt, time.Now())
|
|
log.Warnf("failed to read remote jwks response status, consumer:%s, reason:%s", consumer.Name, err.Error())
|
|
callback()
|
|
return
|
|
}
|
|
if statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices {
|
|
recordRemoteJWKsFetchFailure(consumer, startedAt, time.Now())
|
|
log.Warnf("failed to fetch remote jwks, consumer:%s, status:%d", consumer.Name, statusCode)
|
|
callback()
|
|
return
|
|
}
|
|
if bodySize > maxRemoteJWKsResponseSize {
|
|
recordRemoteJWKsFetchFailure(consumer, startedAt, time.Now())
|
|
log.Warnf("remote jwks is invalid, consumer:%s, status:%d, reason:jwks response exceeds %d bytes", consumer.Name, statusCode, maxRemoteJWKsResponseSize)
|
|
callback()
|
|
return
|
|
}
|
|
body, err := proxywasm.GetHttpCallResponseBody(0, bodySize)
|
|
if err != nil {
|
|
recordRemoteJWKsFetchFailure(consumer, startedAt, time.Now())
|
|
log.Warnf("failed to read remote jwks response body, consumer:%s, status:%d, reason:%s", consumer.Name, statusCode, err.Error())
|
|
callback()
|
|
return
|
|
}
|
|
keys, err := parseRemoteJWKsResponse(string(body))
|
|
if err != nil {
|
|
recordRemoteJWKsFetchFailure(consumer, startedAt, time.Now())
|
|
log.Warnf("remote jwks is invalid, consumer:%s, status:%d, reason:%s", consumer.Name, statusCode, err.Error())
|
|
callback()
|
|
return
|
|
}
|
|
cacheRemoteJWKs(consumer, keys, startedAt, time.Now())
|
|
callback()
|
|
})
|
|
if err != nil {
|
|
recordRemoteJWKsFetchFailure(consumer, startedAt, time.Now())
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func remoteJWKsResponseStatus() (int, error) {
|
|
headers, err := proxywasm.GetHttpCallResponseHeaders()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
for _, header := range headers {
|
|
if header[0] == ":status" {
|
|
return strconv.Atoi(header[1])
|
|
}
|
|
}
|
|
return 0, fmt.Errorf("missing :status")
|
|
}
|
|
|
|
func remoteJWKsFetchCluster(consumer *cfg.Consumer) (wrapper.FQDNCluster, string, error) {
|
|
remote := consumer.RemoteJWKs
|
|
if remote == nil || remote.ServiceName == "" || remote.Path == "" {
|
|
return wrapper.FQDNCluster{}, "", fmt.Errorf("remote_jwks is not configured")
|
|
}
|
|
return wrapper.FQDNCluster{
|
|
FQDN: remote.ServiceName,
|
|
Host: remote.ServiceHost,
|
|
Port: remoteJWKsServicePort(remote),
|
|
}, remote.Path, nil
|
|
}
|
|
|
|
func remoteJWKsServicePort(remote *cfg.RemoteJWKs) int64 {
|
|
if remote.ServicePort == nil {
|
|
return 443
|
|
}
|
|
return *remote.ServicePort
|
|
}
|
|
|
|
func remoteJWKsAuthority(consumer *cfg.Consumer) string {
|
|
remote := consumer.RemoteJWKs
|
|
if remote == nil {
|
|
return ""
|
|
}
|
|
port := remoteJWKsServicePort(remote)
|
|
if port == 80 || port == 443 {
|
|
return remote.ServiceHost
|
|
}
|
|
return remote.ServiceHost + ":" + strconv.FormatInt(port, 10)
|
|
}
|
|
|
|
func remoteJWKsCacheDuration(consumer *cfg.Consumer) int64 {
|
|
if consumer.JWKsCacheDuration == nil {
|
|
return cfg.DefaultJWKsCacheDuration
|
|
}
|
|
return *consumer.JWKsCacheDuration
|
|
}
|
|
|
|
func remoteJWKsFetchTimeout(consumer *cfg.Consumer) int64 {
|
|
if consumer.JWKsFetchTimeout == nil {
|
|
return cfg.DefaultJWKsFetchTimeout
|
|
}
|
|
return *consumer.JWKsFetchTimeout
|
|
}
|
|
|
|
func parseRemoteJWKsResponse(raw string) (*jose.JSONWebKeySet, error) {
|
|
if len(raw) > maxRemoteJWKsResponseSize {
|
|
return nil, fmt.Errorf("jwks response exceeds %d bytes", maxRemoteJWKsResponseSize)
|
|
}
|
|
return parseJWKs(raw)
|
|
}
|
|
|
|
func parseJWKs(raw string) (*jose.JSONWebKeySet, error) {
|
|
jwks := &jose.JSONWebKeySet{}
|
|
if err := json.Unmarshal([]byte(raw), jwks); err != nil {
|
|
return nil, err
|
|
}
|
|
if len(jwks.Keys) == 0 {
|
|
return nil, fmt.Errorf("jwks has no keys")
|
|
}
|
|
return jwks, nil
|
|
}
|
|
|
|
// Initial cold/expired fetches are only backed off after failures. A successful
|
|
// completed fetch must not block the next TTL-driven refresh.
|
|
func remoteJWKsFetchAllowed(consumer *cfg.Consumer, now time.Time) bool {
|
|
state, ok := remoteJWKsFetchStates[remoteJWKsCacheKeyForConsumer(consumer)]
|
|
if !ok {
|
|
return true
|
|
}
|
|
if remoteJWKsInFlight(state, now) {
|
|
return false
|
|
}
|
|
return state.lastFailedAt.IsZero() || now.Sub(state.lastFailedAt) >= remoteJWKsMinRefreshInterval
|
|
}
|
|
|
|
func remoteJWKsFetchInFlight(consumer *cfg.Consumer, now time.Time) bool {
|
|
state, ok := remoteJWKsFetchStates[remoteJWKsCacheKeyForConsumer(consumer)]
|
|
return ok && remoteJWKsInFlight(state, now)
|
|
}
|
|
|
|
func remoteJWKsInFlight(state remoteJWKsFetchState, now time.Time) bool {
|
|
return state.inFlight && now.Before(state.deadline)
|
|
}
|
|
|
|
func recordRemoteJWKsFetchStart(consumer *cfg.Consumer, now time.Time) bool {
|
|
if !remoteJWKsFetchAllowed(consumer, now) {
|
|
return false
|
|
}
|
|
state := remoteJWKsFetchStates[remoteJWKsCacheKeyForConsumer(consumer)]
|
|
state.inFlight = true
|
|
state.startedAt = now
|
|
state.deadline = now.Add(time.Duration(remoteJWKsFetchTimeout(consumer)) * time.Millisecond)
|
|
remoteJWKsFetchStates[remoteJWKsCacheKeyForConsumer(consumer)] = state
|
|
return true
|
|
}
|
|
|
|
func recordRemoteJWKsFetchFailure(consumer *cfg.Consumer, startedAt time.Time, now time.Time) {
|
|
state := remoteJWKsFetchStates[remoteJWKsCacheKeyForConsumer(consumer)]
|
|
if !state.startedAt.Equal(startedAt) {
|
|
return
|
|
}
|
|
state.inFlight = false
|
|
state.startedAt = time.Time{}
|
|
state.deadline = time.Time{}
|
|
state.lastFailedAt = now
|
|
remoteJWKsFetchStates[remoteJWKsCacheKeyForConsumer(consumer)] = state
|
|
}
|
|
|
|
func cacheRemoteJWKs(consumer *cfg.Consumer, keys *jose.JSONWebKeySet, startedAt time.Time, now time.Time) {
|
|
cacheKey := remoteJWKsCacheKeyForConsumer(consumer)
|
|
state := remoteJWKsFetchStates[cacheKey]
|
|
if !state.startedAt.Equal(startedAt) {
|
|
return
|
|
}
|
|
remoteJWKsCache[cacheKey] = cachedJWKs{
|
|
keys: keys,
|
|
fetchedAt: now,
|
|
}
|
|
state.inFlight = false
|
|
state.startedAt = time.Time{}
|
|
state.deadline = time.Time{}
|
|
state.lastFailedAt = time.Time{}
|
|
remoteJWKsFetchStates[cacheKey] = state
|
|
}
|