Files
higress/plugins/wasm-go/extensions/jwt-auth/handler/jwks_cache.go
2026-05-25 16:04:10 +08:00

330 lines
10 KiB
Go

// Copyright (c) 2023 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package handler
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"time"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/jwt-auth/config"
"github.com/go-jose/go-jose/v3"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
// cachedJWKs is the value type of remoteJWKsCache: a parsed key set together
// with the time it was stored.
type cachedJWKs struct {
// keys is the parsed JSON Web Key Set stored by cacheRemoteJWKs.
keys *jose.JSONWebKeySet
// fetchedAt is the completion time recorded by cacheRemoteJWKs; consumerJWKs
// adds the consumer's cache duration to it to decide whether the entry is
// still fresh.
fetchedAt time.Time
}
// remoteJWKsFetchState is the value type of remoteJWKsFetchStates. It tracks
// the current fetch attempt and the most recent failure for one remote
// reference.
type remoteJWKsFetchState struct {
// inFlight is set by recordRemoteJWKsFetchStart and cleared when the attempt
// is recorded as failed or its result is cached.
inFlight bool
// startedAt is the start time of the current attempt. Completion handlers
// compare it with the start time they were given and ignore the result when
// the two differ.
startedAt time.Time
// deadline is startedAt plus the fetch timeout; once it has passed, inFlight
// is no longer honoured by remoteJWKsInFlight.
deadline time.Time
// lastFailedAt is the time of the last recorded failure. It is the zero time
// when no failure has been recorded or after a successful fetch.
lastFailedAt time.Time
}
// remoteJWKsCacheKey identifies one remote JWKS endpoint. It is the key type of
// both remoteJWKsCache and remoteJWKsFetchStates, so consumers whose remote
// settings are identical map to the same entries.
type remoteJWKsCacheKey struct {
// serviceName is copied from RemoteJWKs.ServiceName.
serviceName string
// serviceHost is copied from RemoteJWKs.ServiceHost.
serviceHost string
// servicePort is the port resolved by remoteJWKsServicePort (default applied).
servicePort int64
// path is copied from RemoteJWKs.Path.
path string
}
// Package-level state for remote JWKS handling. The maps are process-local
// caches in the single-threaded proxy-wasm VM.
var (
	// remoteJWKsCache holds the last successfully fetched key set per remote
	// reference.
	remoteJWKsCache = map[remoteJWKsCacheKey]cachedJWKs{}

	// remoteJWKsFetchStates holds the in-flight / last-failure bookkeeping per
	// remote reference.
	remoteJWKsFetchStates = map[remoteJWKsCacheKey]remoteJWKsFetchState{}

	// errRemoteJWKsCacheMiss is returned by consumerJWKs when no usable cached
	// key set exists and a new fetch is permitted.
	errRemoteJWKsCacheMiss = errors.New("remote jwks cache is missing or expired")

	// errRemoteJWKsRefreshThrottled is returned when no usable cached key set
	// exists and a new fetch is currently not permitted.
	errRemoteJWKsRefreshThrottled = errors.New("remote jwks refresh is throttled")

	// dispatchRemoteJWKsHTTPCall is the function used to issue the JWKS request.
	// NOTE(review): presumably a variable so tests can substitute it; confirm.
	dispatchRemoteJWKsHTTPCall = proxywasm.DispatchHttpCall
)
// Failed remote JWKS fetches are backed off per remote service reference.
// In-flight requests are not coalesced because proxy-wasm callbacks are bound
// to one HTTP stream context.
const (
	// remoteJWKsMinRefreshInterval is the minimum wait after a recorded failure
	// before another fetch for the same remote reference is allowed.
	remoteJWKsMinRefreshInterval = time.Duration(cfg.RemoteJWKsMinRefreshIntervalSeconds) * time.Second

	// maxRemoteJWKsResponseSize is the largest JWKS response body, in bytes,
	// that will be read and parsed.
	maxRemoteJWKsResponseSize = 64 * 1024
)
// remoteJWKsCacheKeyForConsumer builds the cache key for the consumer's remote
// JWKS settings. A consumer without RemoteJWKs yields the zero-value key.
func remoteJWKsCacheKeyForConsumer(consumer *cfg.Consumer) remoteJWKsCacheKey {
	var key remoteJWKsCacheKey
	if ref := consumer.RemoteJWKs; ref != nil {
		key.serviceName = ref.ServiceName
		key.serviceHost = ref.ServiceHost
		key.servicePort = remoteJWKsServicePort(ref)
		key.path = ref.Path
	}
	return key
}
// PruneRemoteJWKsCache removes cached key sets and fetch states whose key does
// not belong to any of the given consumers. Nil consumers and consumers
// without RemoteJWKs are ignored when collecting the keys to keep.
func PruneRemoteJWKsCache(consumers []*cfg.Consumer) {
	keep := make(map[remoteJWKsCacheKey]struct{}, len(consumers))
	for _, c := range consumers {
		if c == nil || c.RemoteJWKs == nil {
			continue
		}
		keep[remoteJWKsCacheKeyForConsumer(c)] = struct{}{}
	}

	for k := range remoteJWKsCache {
		if _, live := keep[k]; live {
			continue
		}
		delete(remoteJWKsCache, k)
	}

	for k := range remoteJWKsFetchStates {
		if _, live := keep[k]; live {
			continue
		}
		delete(remoteJWKsFetchStates, k)
	}
}
// consumerJWKs returns the key set to verify tokens for consumer at time now.
//
// When the consumer carries inline JWKs, the pre-parsed set is returned if
// present, otherwise the inline text is parsed. When it does not, the remote
// cache is consulted:
//   - a cached set that is still within the cache duration is returned;
//   - an expired cached set is still returned while a fetch is in flight;
//   - otherwise errRemoteJWKsRefreshThrottled is returned if a new fetch is not
//     currently allowed, and errRemoteJWKsCacheMiss if it is.
func consumerJWKs(consumer *cfg.Consumer, now time.Time) (*jose.JSONWebKeySet, error) {
	if inline := consumer.JWKs; inline != "" {
		if consumer.ParsedJWKs != nil {
			return consumer.ParsedJWKs, nil
		}
		return parseJWKs(inline)
	}

	entry, found := remoteJWKsCache[remoteJWKsCacheKeyForConsumer(consumer)]
	if found {
		// The configured cache duration is expressed in seconds.
		ttl := time.Duration(remoteJWKsCacheDuration(consumer)) * time.Second
		if now.Before(entry.fetchedAt.Add(ttl)) {
			return entry.keys, nil
		}
		// Expired, but a refresh is already running: keep serving the old set.
		if remoteJWKsFetchInFlight(consumer, now) {
			return entry.keys, nil
		}
	}

	if remoteJWKsFetchAllowed(consumer, now) {
		return nil, errRemoteJWKsCacheMiss
	}
	return nil, errRemoteJWKsRefreshThrottled
}
// remoteJWKsFetchedAfter reports whether the consumer's cached remote key set
// was stored strictly after t. The verifier uses it to tell whether the same
// request has already retried with a freshly fetched JWKS.
func remoteJWKsFetchedAfter(consumer *cfg.Consumer, t time.Time) bool {
	entry, found := remoteJWKsCache[remoteJWKsCacheKeyForConsumer(consumer)]
	if !found {
		return false
	}
	return entry.fetchedAt.After(t)
}
// isRemoteJWKsCacheMiss reports whether err is, or wraps,
// errRemoteJWKsCacheMiss.
func isRemoteJWKsCacheMiss(err error) bool {
	return errors.Is(err, errRemoteJWKsCacheMiss)
}
// isRemoteJWKsRefreshThrottled reports whether err is, or wraps,
// errRemoteJWKsRefreshThrottled.
func isRemoteJWKsRefreshThrottled(err error) bool {
	return errors.Is(err, errRemoteJWKsRefreshThrottled)
}
// fetchRemoteJWKs dispatches an HTTP GET for the consumer's remote JWKS and
// stores the parsed result in the cache when the response arrives.
//
// callback is invoked exactly once from the response handler, after the
// outcome (success or failure) has been recorded. It is not invoked when this
// function returns an error.
//
// An error is returned when the remote reference is not configured, when a
// fetch is currently not allowed (errRemoteJWKsRefreshThrottled), or when the
// HTTP call could not be dispatched; in the last case the attempt is recorded
// as failed.
func fetchRemoteJWKs(consumer *cfg.Consumer, log log.Log, callback func()) error {
	cluster, path, err := remoteJWKsFetchCluster(consumer)
	if err != nil {
		return err
	}
	timeoutMs := uint32(remoteJWKsFetchTimeout(consumer))
	startedAt := time.Now()
	if !recordRemoteJWKsFetchStart(consumer, startedAt) {
		return errRemoteJWKsRefreshThrottled
	}

	// abort is the shared failure path of the response handler: record the
	// failure for this attempt, log the reason, then hand control back.
	abort := func(format string, args ...interface{}) {
		recordRemoteJWKsFetchFailure(consumer, startedAt, time.Now())
		log.Warnf(format, args...)
		callback()
	}

	onResponse := func(numHeaders, bodySize, numTrailers int) {
		status, statusErr := remoteJWKsResponseStatus()
		if statusErr != nil {
			abort("failed to read remote jwks response status, consumer:%s, reason:%s", consumer.Name, statusErr.Error())
			return
		}
		// Only 2xx responses are accepted.
		if status < http.StatusOK || status >= http.StatusMultipleChoices {
			abort("failed to fetch remote jwks, consumer:%s, status:%d", consumer.Name, status)
			return
		}
		// Reject oversized bodies before reading them.
		if bodySize > maxRemoteJWKsResponseSize {
			abort("remote jwks is invalid, consumer:%s, status:%d, reason:jwks response exceeds %d bytes", consumer.Name, status, maxRemoteJWKsResponseSize)
			return
		}
		body, readErr := proxywasm.GetHttpCallResponseBody(0, bodySize)
		if readErr != nil {
			abort("failed to read remote jwks response body, consumer:%s, status:%d, reason:%s", consumer.Name, status, readErr.Error())
			return
		}
		keys, parseErr := parseRemoteJWKsResponse(string(body))
		if parseErr != nil {
			abort("remote jwks is invalid, consumer:%s, status:%d, reason:%s", consumer.Name, status, parseErr.Error())
			return
		}
		cacheRemoteJWKs(consumer, keys, startedAt, time.Now())
		callback()
	}

	headers := [][2]string{
		{"Accept", "application/json"},
		{":method", http.MethodGet},
		{":path", path},
		{":authority", remoteJWKsAuthority(consumer)},
	}
	if _, err := dispatchRemoteJWKsHTTPCall(cluster.ClusterName(), headers, nil, nil, timeoutMs, onResponse); err != nil {
		recordRemoteJWKsFetchFailure(consumer, startedAt, time.Now())
		return err
	}
	return nil
}
// remoteJWKsResponseStatus reads the headers of the current HTTP call response
// and returns the integer value of the ":status" pseudo-header. It returns an
// error when the headers cannot be read, when ":status" is absent, or when its
// value is not a valid integer.
func remoteJWKsResponseStatus() (int, error) {
	hdrs, err := proxywasm.GetHttpCallResponseHeaders()
	if err != nil {
		return 0, err
	}
	for i := range hdrs {
		if hdrs[i][0] != ":status" {
			continue
		}
		return strconv.Atoi(hdrs[i][1])
	}
	return 0, fmt.Errorf("missing :status")
}
// remoteJWKsFetchCluster returns the cluster descriptor and request path for
// the consumer's remote JWKS endpoint. It fails when RemoteJWKs is nil or when
// its ServiceName or Path is empty.
func remoteJWKsFetchCluster(consumer *cfg.Consumer) (wrapper.FQDNCluster, string, error) {
	ref := consumer.RemoteJWKs
	configured := ref != nil && ref.ServiceName != "" && ref.Path != ""
	if !configured {
		return wrapper.FQDNCluster{}, "", fmt.Errorf("remote_jwks is not configured")
	}

	var cluster wrapper.FQDNCluster
	cluster.FQDN = ref.ServiceName
	cluster.Host = ref.ServiceHost
	cluster.Port = remoteJWKsServicePort(ref)
	return cluster, ref.Path, nil
}
// remoteJWKsServicePort returns the configured service port, or 443 when none
// is set. remote must be non-nil.
func remoteJWKsServicePort(remote *cfg.RemoteJWKs) int64 {
	if p := remote.ServicePort; p != nil {
		return *p
	}
	return 443
}
// remoteJWKsAuthority builds the ":authority" value for the JWKS request.
//
// The host is RemoteJWKs.ServiceHost, falling back to RemoteJWKs.ServiceName
// when no host is configured; remoteJWKsFetchCluster accepts an empty
// ServiceHost, and without the fallback the request would carry an empty or
// ":<port>"-only authority. The port is appended unless it is 80 or 443.
// A consumer without RemoteJWKs yields the empty string.
func remoteJWKsAuthority(consumer *cfg.Consumer) string {
	remote := consumer.RemoteJWKs
	if remote == nil {
		return ""
	}
	host := remote.ServiceHost
	if host == "" {
		// NOTE(review): intended to mirror the cluster wrapper's host-name
		// fallback to the FQDN; confirm against wrapper.FQDNCluster.
		host = remote.ServiceName
	}
	port := remoteJWKsServicePort(remote)
	if port == 80 || port == 443 {
		return host
	}
	return host + ":" + strconv.FormatInt(port, 10)
}
// remoteJWKsCacheDuration returns the consumer's JWKS cache duration, or
// cfg.DefaultJWKsCacheDuration when none is set. consumerJWKs interprets the
// value as seconds.
func remoteJWKsCacheDuration(consumer *cfg.Consumer) int64 {
	if override := consumer.JWKsCacheDuration; override != nil {
		return *override
	}
	return cfg.DefaultJWKsCacheDuration
}
// remoteJWKsFetchTimeout returns the consumer's JWKS fetch timeout, or
// cfg.DefaultJWKsFetchTimeout when none is set. recordRemoteJWKsFetchStart
// interprets the value as milliseconds.
func remoteJWKsFetchTimeout(consumer *cfg.Consumer) int64 {
	if override := consumer.JWKsFetchTimeout; override != nil {
		return *override
	}
	return cfg.DefaultJWKsFetchTimeout
}
// parseRemoteJWKsResponse parses a remote JWKS response body. Bodies longer
// than maxRemoteJWKsResponseSize bytes are rejected without being parsed.
func parseRemoteJWKsResponse(raw string) (*jose.JSONWebKeySet, error) {
	if len(raw) <= maxRemoteJWKsResponseSize {
		return parseJWKs(raw)
	}
	return nil, fmt.Errorf("jwks response exceeds %d bytes", maxRemoteJWKsResponseSize)
}
// parseJWKs decodes raw as a JSON Web Key Set. It returns the decoding error
// if raw is not valid, and an error if the decoded set contains no keys.
func parseJWKs(raw string) (*jose.JSONWebKeySet, error) {
	var set jose.JSONWebKeySet
	err := json.Unmarshal([]byte(raw), &set)
	switch {
	case err != nil:
		return nil, err
	case len(set.Keys) == 0:
		return nil, fmt.Errorf("jwks has no keys")
	}
	return &set, nil
}
// remoteJWKsFetchAllowed reports whether a new fetch may be started for the
// consumer's remote reference at time now.
//
// Initial cold/expired fetches are only backed off after failures. A successful
// completed fetch must not block the next TTL-driven refresh. A fetch that is
// still in flight blocks a new one.
func remoteJWKsFetchAllowed(consumer *cfg.Consumer, now time.Time) bool {
	state, tracked := remoteJWKsFetchStates[remoteJWKsCacheKeyForConsumer(consumer)]
	switch {
	case !tracked:
		return true
	case remoteJWKsInFlight(state, now):
		return false
	case state.lastFailedAt.IsZero():
		return true
	}
	return now.Sub(state.lastFailedAt) >= remoteJWKsMinRefreshInterval
}
// remoteJWKsFetchInFlight reports whether a fetch for the consumer's remote
// reference is recorded as in flight and has not yet passed its deadline.
func remoteJWKsFetchInFlight(consumer *cfg.Consumer, now time.Time) bool {
	state, tracked := remoteJWKsFetchStates[remoteJWKsCacheKeyForConsumer(consumer)]
	if !tracked {
		return false
	}
	return remoteJWKsInFlight(state, now)
}
// remoteJWKsInFlight reports whether state describes a fetch that is marked in
// flight and whose deadline is still after now.
func remoteJWKsInFlight(state remoteJWKsFetchState, now time.Time) bool {
	if !state.inFlight {
		return false
	}
	return now.Before(state.deadline)
}
// recordRemoteJWKsFetchStart marks a fetch for the consumer's remote reference
// as started at now and returns true. It returns false, changing nothing, when
// a fetch is not currently allowed. The deadline is now plus the consumer's
// fetch timeout in milliseconds; any previously recorded lastFailedAt is kept.
func recordRemoteJWKsFetchStart(consumer *cfg.Consumer, now time.Time) bool {
	if !remoteJWKsFetchAllowed(consumer, now) {
		return false
	}
	key := remoteJWKsCacheKeyForConsumer(consumer)
	timeout := time.Duration(remoteJWKsFetchTimeout(consumer)) * time.Millisecond

	state := remoteJWKsFetchStates[key]
	state.inFlight, state.startedAt, state.deadline = true, now, now.Add(timeout)
	remoteJWKsFetchStates[key] = state
	return true
}
// recordRemoteJWKsFetchFailure records that the fetch started at startedAt
// failed at now. It does nothing when the tracked attempt has a different
// start time, so a completion that belongs to an older attempt cannot
// overwrite the state of a newer one.
func recordRemoteJWKsFetchFailure(consumer *cfg.Consumer, startedAt time.Time, now time.Time) {
	key := remoteJWKsCacheKeyForConsumer(consumer)
	if tracked := remoteJWKsFetchStates[key]; !tracked.startedAt.Equal(startedAt) {
		return
	}
	// Clear the in-flight markers and keep only the failure timestamp.
	remoteJWKsFetchStates[key] = remoteJWKsFetchState{lastFailedAt: now}
}
// cacheRemoteJWKs stores keys as the consumer's cached remote key set with
// fetch time now, and resets the fetch state (in-flight markers and last
// failure). It does nothing when the tracked attempt has a different start
// time than startedAt, so the result of an older attempt is discarded.
func cacheRemoteJWKs(consumer *cfg.Consumer, keys *jose.JSONWebKeySet, startedAt time.Time, now time.Time) {
	key := remoteJWKsCacheKeyForConsumer(consumer)
	if tracked := remoteJWKsFetchStates[key]; !tracked.startedAt.Equal(startedAt) {
		return
	}
	remoteJWKsCache[key] = cachedJWKs{keys: keys, fetchedAt: now}
	// Success clears every field, including the failure back-off timestamp.
	remoteJWKsFetchStates[key] = remoteJWKsFetchState{}
}