mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 12:47:28 +08:00
feat:add oidc wasm plugin (#568)
This commit is contained in:
219
plugins/wasm-go/extensions/oidc/oc/verifer.go
Normal file
219
plugins/wasm-go/extensions/oidc/oc/verifer.go
Normal file
@@ -0,0 +1,219 @@
|
||||
/*
|
||||
Copyright 2023 go-oidc
|
||||
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package oc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
)
|
||||
|
||||
// IDTokenVerifierConfig
|
||||
type IDTokenVerifier struct {
|
||||
config *IDConfig
|
||||
issuer string
|
||||
}
|
||||
|
||||
const (
|
||||
issuerGoogleAccounts = "https://accounts.google.com"
|
||||
issuerGoogleAccountsNoScheme = "accounts.google.com"
|
||||
|
||||
LEEWAY = 5 * time.Minute
|
||||
)
|
||||
|
||||
type IDToken struct {
|
||||
Issuer string
|
||||
Audience []string
|
||||
Subject string
|
||||
Expiry time.Time
|
||||
IssuedAt time.Time
|
||||
Nonce string
|
||||
AccessTokenHash string
|
||||
sigAlgorithm string
|
||||
claims []byte
|
||||
distributedClaims map[string]claimSource
|
||||
}
|
||||
type TokenExpiredError struct {
|
||||
Expiry time.Time
|
||||
}
|
||||
|
||||
func (e *TokenExpiredError) Error() string {
|
||||
return fmt.Sprintf("oidc: token is expired (Token Expiry: %v)", e.Expiry)
|
||||
}
|
||||
func (i *IDToken) Claims(v interface{}) error {
|
||||
if i.claims == nil {
|
||||
return errors.New("oidc: claims not set")
|
||||
}
|
||||
return json.Unmarshal(i.claims, v)
|
||||
}
|
||||
|
||||
func (v *IDTokenVerifier) VerifyToken(rawIDToken string, keySet jose.JSONWebKeySet) (*IDToken, error) {
|
||||
var log wrapper.Log
|
||||
payload, err := parseJWT(rawIDToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(" malformed jwt: %v", err)
|
||||
}
|
||||
var token idToken
|
||||
if err := json.Unmarshal(payload, &token); err != nil {
|
||||
log.Errorf("idToken Unmarshal error : %v ", err)
|
||||
return nil, fmt.Errorf("failed to unmarshal claims: %v", err)
|
||||
}
|
||||
|
||||
distributedClaims := make(map[string]claimSource)
|
||||
|
||||
//step through the token to map claim names to claim sources
|
||||
for cn, src := range token.ClaimNames {
|
||||
if src == "" {
|
||||
return nil, fmt.Errorf("failed to obtain source from claim name")
|
||||
}
|
||||
s, ok := token.ClaimSources[src]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("source does not exist")
|
||||
}
|
||||
distributedClaims[cn] = s
|
||||
}
|
||||
|
||||
t := &IDToken{
|
||||
Issuer: token.Issuer,
|
||||
Subject: token.Subject,
|
||||
Audience: []string(token.Audience),
|
||||
Expiry: time.Time(token.Expiry),
|
||||
IssuedAt: time.Time(token.IssuedAt),
|
||||
Nonce: token.Nonce,
|
||||
AccessTokenHash: token.AtHash,
|
||||
claims: payload,
|
||||
distributedClaims: distributedClaims,
|
||||
}
|
||||
|
||||
// Check issuer.
|
||||
if !v.config.SkipIssuerCheck && t.Issuer != v.issuer {
|
||||
// Google sometimes returns "accounts.google.com" as the issuer claim instead of
|
||||
// the required "https://accounts.google.com". Detect this case and allow it only
|
||||
// for Google.
|
||||
//
|
||||
// We will not add hooks to let other providers go off spec like this.
|
||||
if !(v.issuer == issuerGoogleAccounts && t.Issuer == issuerGoogleAccountsNoScheme) {
|
||||
return nil, fmt.Errorf("oidc: id token issued by a different provider, expected %q got %q", v.issuer, t.Issuer)
|
||||
}
|
||||
}
|
||||
|
||||
if v.config.ClientID != "" {
|
||||
if !contains(t.Audience, v.config.ClientID) {
|
||||
return nil, fmt.Errorf("oidc: expected audience %q got %q", v.config.ClientID, t.Audience)
|
||||
}
|
||||
}
|
||||
|
||||
// If a SkipExpiryCheck is false, make sure token is not expired.
|
||||
if !v.config.SkipExpiryCheck {
|
||||
now := time.Now
|
||||
if v.config.Now != nil {
|
||||
now = v.config.Now
|
||||
}
|
||||
nowTime := now()
|
||||
|
||||
if t.Expiry.Before(nowTime) {
|
||||
return nil, &TokenExpiredError{Expiry: t.Expiry}
|
||||
}
|
||||
|
||||
// If nbf claim is provided in token, ensure that it is indeed in the past.
|
||||
if token.NotBefore != nil {
|
||||
nbfTime := time.Time(*token.NotBefore)
|
||||
// Set to 5 minutes since this is what other OpenID Connect providers do to deal with clock skew.
|
||||
// https://github.com/AzureAD/azure-activedirectory-identitymodel-extensions-for-dotnet/blob/6.12.2/src/Microsoft.IdentityModel.Tokens/TokenValidationParameters.cs#L149-L153
|
||||
|
||||
if nowTime.Add(LEEWAY).Before(nbfTime) {
|
||||
return nil, fmt.Errorf("oidc: current time %v before the nbf (not before) time: %v", nowTime, nbfTime)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jws, err := jose.ParseSigned(rawIDToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
|
||||
}
|
||||
|
||||
switch len(jws.Signatures) {
|
||||
case 0:
|
||||
return nil, fmt.Errorf("oidc: id token not signed")
|
||||
case 1:
|
||||
default:
|
||||
return nil, fmt.Errorf("oidc: multiple signatures on id token not supported")
|
||||
}
|
||||
|
||||
sig := jws.Signatures[0]
|
||||
supportedSigAlgs := v.config.SupportedSigningAlgs
|
||||
|
||||
if len(supportedSigAlgs) == 0 {
|
||||
supportedSigAlgs = []string{RS256}
|
||||
}
|
||||
|
||||
if !contains(supportedSigAlgs, sig.Header.Algorithm) {
|
||||
return nil, fmt.Errorf("oidc: id token signed with unsupported algorithm, expected %q got %q", supportedSigAlgs, sig.Header.Algorithm)
|
||||
}
|
||||
|
||||
t.sigAlgorithm = sig.Header.Algorithm
|
||||
|
||||
keyID := ""
|
||||
for _, sig := range jws.Signatures {
|
||||
keyID = sig.Header.KeyID
|
||||
break
|
||||
}
|
||||
|
||||
for _, key := range keySet.Keys {
|
||||
if keyID == "" || key.KeyID == keyID {
|
||||
if gotPayload, err := jws.Verify(&key); err == nil {
|
||||
if !bytes.Equal(gotPayload, payload) {
|
||||
return nil, errors.New("oidc: internal error, payload parsed did not match previous payload")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
func contains(sli []string, ele string) bool {
|
||||
for _, s := range sli {
|
||||
if s == ele {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
func parseJWT(p string) ([]byte, error) {
|
||||
parts := strings.Split(p, ".")
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts))
|
||||
}
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oidc: malformed jwt payload: %v", err)
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
func (cfg *Oatuh2Config) Verifier(config *IDConfig) *IDTokenVerifier {
|
||||
return &IDTokenVerifier{
|
||||
config: config,
|
||||
issuer: cfg.Issuer,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user