Compare commits

15 Commits

Author SHA1 Message Date
澄潭
e68a8ac25f add model-mapper plugin & optimize model-router plugin (#1538) 2024-11-22 22:24:42 +08:00
Kent Dong
96575b982e fix: Refresh go.mod and go.sum file contents (#1525) 2024-11-22 13:34:55 +08:00
EnableAsync
c2d405b2a7 feat: Enhance ai-cache Plugin with Vector Similarity-Based LLM Cache Recall and Multi-DB Support (#1248) 2024-11-21 16:57:41 +08:00
Jingze
6efb3109f2 fix: update oidc plugin go.mod dependencies (#1522) 2024-11-19 17:33:42 +08:00
Se7en
1b1c08afb7 fix: apitoken failover for coze (#1515) 2024-11-18 15:36:26 +08:00
Se7en
d24123a55f feat: implement apiToken failover mechanism (#1256) 2024-11-16 19:03:09 +08:00
澄潭
f2a5df3949 use the body returned by the ext auth server when auth fails (#1510) 2024-11-14 18:50:33 +08:00
澄潭
ebc5b2987e fix compile of wasm cpp plugins (#1511) 2024-11-14 18:49:21 +08:00
007gzs
ca97cbd75a fix workflows build-and-push-wasm-plugin-image (#1508) 2024-11-13 17:39:24 +08:00
hanans426
a787e237ce Add a quick-deploy solution for Alibaba Cloud (#1506) 2024-11-13 16:26:55 +08:00
纪卓志
6a1bf90d42 feat: supports custom prepare build script (#1490) 2024-11-12 13:45:28 +08:00
007gzs
60e476da87 fix example sse build error (#1503) 2024-11-11 17:47:26 +08:00
rinfx
2cb8558cda Optimize AI security guard plugin (#1473)
Co-authored-by: Kent Dong <ch3cho@qq.com>
2024-11-11 14:49:17 +08:00
littlejian
4d1a037942 feat: Automatically generating markdown documentation for helm charts with helm-docs (#1496) 2024-11-11 11:34:38 +08:00
xingyunyang01
39b6eac9d0 AI Agent plugin adds JSON formatting output feature (#1374) 2024-11-11 11:11:02 +08:00
167 changed files with 5258 additions and 4831 deletions


@@ -42,17 +42,19 @@ jobs:
plugin_type="${{ github.event.inputs.plugin_type }}"
plugin_name="${{ github.event.inputs.plugin_name }}"
version="${{ github.event.inputs.version }}"
builder_image="higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/wasm-rust-builder:rust${{ env.RUST_VERSION }}-oras${{ env.ORAS_VERSION }}"
else
ref_name=${{ github.ref_name }}
plugin_type=${ref_name#*-} # strip the field before the plugin type (wasm-)
plugin_type=${plugin_type%-*} # strip the fields after the plugin type (-{plugin_name}-vX.Y.Z)
plugin_type=${plugin_type%%-*} # strip the fields after the plugin type (-{plugin_name}-vX.Y.Z)
plugin_name=${ref_name#*-*-} # strip the fields before the plugin name (wasm-go-)
plugin_name=${plugin_name%-*} # strip the field after the plugin name (-vX.Y.Z)
version=$(echo "$ref_name" | awk -F'v' '{print $2}')
fi
if [[ "$plugin_type" == "rust" ]]; then
builder_image="higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/wasm-rust-builder:rust${{ env.RUST_VERSION }}-oras${{ env.ORAS_VERSION }}"
else
builder_image="higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/wasm-go-builder:go${{ env.GO_VERSION }}-tinygo${{ env.TINYGO_VERSION }}-oras${{ env.ORAS_VERSION }}"
fi
echo "PLUGIN_TYPE=$plugin_type" >> $GITHUB_ENV
echo "PLUGIN_NAME=$plugin_name" >> $GITHUB_ENV
echo "VERSION=$version" >> $GITHUB_ENV
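
The hunk above shows two variants of the suffix-stripping line, `%-*` and `%%-*`. Assuming the second replaces the first, the difference is shortest-match versus longest-match suffix removal, which matters when the plugin name itself contains a hyphen. A minimal sketch with a hypothetical tag name (`wasm-go-ai-proxy-v1.0.0` is an assumption for illustration, not a tag taken from the repository):

```shell
#!/usr/bin/env bash
# Hypothetical tag; assumes the naming scheme wasm-{type}-{plugin_name}-vX.Y.Z.
ref_name="wasm-go-ai-proxy-v1.0.0"

rest=${ref_name#*-}            # shortest prefix up to the first "-" removed: go-ai-proxy-v1.0.0
old_type=${rest%-*}            # shortest suffix "-*" removed: go-ai-proxy
new_type=${rest%%-*}           # longest suffix "-*" removed:  go

plugin_name=${ref_name#*-*-}   # "wasm-go-" removed: ai-proxy-v1.0.0
plugin_name=${plugin_name%-*}  # "-v1.0.0" removed:  ai-proxy

# Splits on every "v", so this relies on no "v" appearing before the version.
version=$(echo "$ref_name" | awk -F'v' '{print $2}')   # 1.0.0

echo "old_type=$old_type new_type=$new_type plugin_name=$plugin_name version=$version"
```

With `%-*` the type would come out as `go-ai-proxy`, so the later `[[ "$plugin_type" == "rust" ]]` comparison and the `PLUGIN_TYPE` variable would not see a clean `go` or `rust` value. `%%-*` keeps only the text before the first remaining hyphen.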

.github/workflows/helm-docs.yaml vendored Normal file

@@ -0,0 +1,35 @@
name: "Helm Docs"
on:
pull_request:
branches:
- "*"
push:
jobs:
helm:
name: Helm Docs
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 1
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: '1.22.9'
- name: Run helm-docs
run: |
GOBIN=$PWD GO111MODULE=on go install github.com/norwoodj/helm-docs/cmd/helm-docs@v1.14.2
./helm-docs -c ${GITHUB_WORKSPACE}/helm/higress -f ../core/values.yaml
DIFF=$(git diff ${GITHUB_WORKSPACE}/helm/higress/*md)
if [ ! -z "$DIFF" ]; then
echo "Please use helm-docs in your clone, of your fork, of the project, and commit an updated README.md for the chart."
fi
git diff --exit-code
rm -f ./helm-docs
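
The final steps of this workflow rely on `git diff --exit-code` returning a non-zero status when regenerated files differ from what is committed. The sketch below reproduces only that behaviour in a throwaway repository; it does not run helm-docs, and the file contents and identity settings are placeholders chosen for illustration.

```shell
#!/usr/bin/env bash
# Sketch of the drift check in a temporary repository (no helm-docs involved).
repo=$(mktemp -d)
cd "$repo" || exit 1
git init -q .
echo "generated v1" > README.md
git add README.md
# Placeholder identity so the commit works on machines without git config.
git -c user.name=demo -c user.email=demo@example.com -c commit.gpgsign=false \
    commit -q -m "initial docs"

# Working tree matches the commit: exit status stays 0.
clean_status=0
git diff --exit-code --quiet || clean_status=$?

# Simulate a generator producing output that differs from the committed file.
echo "generated v2" > README.md
drift_status=0
git diff --exit-code --quiet || drift_status=$?

echo "clean=$clean_status drift=$drift_status"
```

In the workflow, the same non-zero status is what fails the job when the committed chart README is out of date.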

.gitignore vendored

@@ -16,4 +16,5 @@ helm/**/charts/**.tgz
target/
tools/hack/cluster.conf
envoy/1.20
istio/1.12
istio/1.12
Cargo.lock


@@ -64,6 +64,10 @@ docker run -d --rm --name higress-ai -v ${PWD}:/data \
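For other installation methods, such as deploying with Helm on K8s, see the [Quick Start documentation](https://higress.cn/docs/latest/user/quickstart/) on the official website.
If you are deploying in the cloud, the [Enterprise Edition](https://higress.io/cloud/) is recommended for production; for development and testing, you can use the one-click deployment of the Community Edition below: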
[![Deploy on AlibabaCloud ComputeNest](https://service-info-public.oss-cn-hangzhou.aliyuncs.com/computenest.svg)](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Higress社区版)
## Use Cases


@@ -10,7 +10,7 @@ global:
onDemandRDS: false
hostRDSMergeSubset: false
onlyPushRouteCluster: true
# IngressClass filters which ingress resources the higress controller watches.
# -- IngressClass filters which ingress resources the higress controller watches.
# The default ingress class is higress.
# There are some special cases for special ingress class.
# 1. When the ingress class is set as nginx, the higress controller will watch ingress
@@ -18,28 +18,40 @@ global:
# 2. When the ingress class is set empty, the higress controller will watch all ingress
# resources in the k8s cluster.
ingressClass: "higress"
# -- If not empty, Higress Controller will only watch resources in the specified namespace.
# When isolating different business systems using K8s namespace,
# if each namespace requires a standalone gateway instance,
# this parameter can be used to confine the Ingress watching of Higress within the given namespace.
watchNamespace: ""
# -- Whether to disable HTTP/2 in ALPN
disableAlpnH2: false
# -- If true, Higress Controller will update the status field of Ingress resources.
# When migrating from Nginx Ingress, in order to avoid status field of Ingress objects being overwritten,
# this parameter needs to be set to false,
# so Higress won't write the entry IP to the status field of the corresponding Ingress object.
enableStatus: true
# whether to use autoscaling/v2 template for HPA settings
# -- whether to use autoscaling/v2 template for HPA settings
# for internal usage only, not to be configured by users.
autoscalingv2API: true
local: false # When deploying to a local cluster (e.g.: kind cluster), set this to true.
# -- When deploying to a local cluster (e.g.: kind cluster), set this to true.
local: false
kind: false # Deprecated. Please use "global.local" instead. Will be removed later.
# -- If true, Higress Controller will monitor istio resources as well
enableIstioAPI: true
# -- If true, Higress Controller will monitor Gateway API resources as well
enableGatewayAPI: false
# Deprecated
enableHigressIstio: false
# Used to locate istiod.
# -- Used to locate istiod.
istioNamespace: istio-system
# enable pod disruption budget for the control plane, which is used to
# -- enable pod disruption budget for the control plane, which is used to
# ensure Istio control plane components are gradually upgraded or recovered.
defaultPodDisruptionBudget:
enabled: false
# The values aren't mutable due to a current PodDisruptionBudget limitation
# minAvailable: 1
# A minimal set of requested resources to be applied to all deployments so that
# -- A minimal set of requested resources to be applied to all deployments so that
# Horizontal Pod Autoscaler will be able to function (if set).
# Each component can overwrite these default values by adding its own resources
# block in the relevant section below and setting the desired resources values.
@@ -51,16 +63,16 @@ global:
# cpu: 100m
# memory: 128Mi
# Default hub for Istio images.
# -- Default hub for Istio images.
# Releases are published to docker hub under 'istio' project.
# Dev builds from prow are on gcr.io
hub: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress
# Specify image pull policy if default behavior isn't desired.
# -- Specify image pull policy if default behavior isn't desired.
# Default behavior: latest images will be Always else IfNotPresent.
imagePullPolicy: ""
# ImagePullSecrets for all ServiceAccount, list of secrets in the same namespace
# -- ImagePullSecrets for all ServiceAccount, list of secrets in the same namespace
# to use for pulling any images in pods that reference this ServiceAccount.
# For components that don't use ServiceAccounts (i.e. grafana, servicegraph, tracing)
# ImagePullSecrets will be added to the corresponding Deployment(StatefulSet) objects.
@@ -68,14 +80,14 @@ global:
imagePullSecrets: []
# - private-registry-key
# Enabled by default in master for maximising testing.
# -- Enabled by default in master for maximising testing.
istiod:
enableAnalysis: false
# To output all istio components logs in json format by adding --log_as_json argument to each container argument
# -- To output all istio components logs in json format by adding --log_as_json argument to each container argument
logAsJson: false
# Comma-separated minimum per-scope logging level of messages to output, in the form of <scope>:<level>,<scope>:<level>
# -- Comma-separated minimum per-scope logging level of messages to output, in the form of <scope>:<level>,<scope>:<level>
# The control plane has different scopes depending on component, but can configure default log level across all components
# If empty, default scope and level will be used as configured in code
logging:
@@ -83,11 +95,11 @@ global:
omitSidecarInjectorConfigMap: false
# Whether to restrict the applications namespace the controller manages;
# -- Whether to restrict the applications namespace the controller manages;
# If not set, controller watches all namespaces
oneNamespace: false
# Configure whether Operator manages webhook configurations. The current behavior
# -- Configure whether Operator manages webhook configurations. The current behavior
# of Istiod is to manage its own webhook configurations.
# When this option is set as true, Istio Operator, instead of webhooks, manages the
# webhook configurations. When this option is set as false, webhooks manage their
@@ -106,7 +118,7 @@ global:
#- global
#- "{{ valueOrDefault .DeploymentMeta.Namespace \"default\" }}.global"
# Kubernetes >=v1.11.0 will create two PriorityClass, including system-cluster-critical and
# -- Kubernetes >=v1.11.0 will create two PriorityClass, including system-cluster-critical and
# system-node-critical, it is better to configure this in order to make sure your Istio pods
# will not be killed because of low priority class.
# Refer to https://kubernetes.io/docs/concepts/configuration/pod-priority-preemption/#priorityclass
@@ -116,18 +128,18 @@ global:
proxy:
image: proxyv2
# This controls the 'policy' in the sidecar injector.
# -- This controls the 'policy' in the sidecar injector.
autoInject: enabled
# CAUTION: It is important to ensure that all Istio helm charts specify the same clusterDomain value
# -- CAUTION: It is important to ensure that all Istio helm charts specify the same clusterDomain value
# cluster domain. Default value is "cluster.local".
clusterDomain: "cluster.local"
# Per Component log level for proxy, applies to gateways and sidecars. If a component level is
# -- Per Component log level for proxy, applies to gateways and sidecars. If a component level is
# not set, then the global "logLevel" will be used.
componentLogLevel: "misc:error"
# If set, newly injected sidecars will have core dumps enabled.
# -- If set, newly injected sidecars will have core dumps enabled.
enableCoreDump: false
# istio ingress capture allowlist
@@ -136,7 +148,7 @@ global:
excludeInboundPorts: ""
includeInboundPorts: "*"
# istio egress capture allowlist
# -- istio egress capture allowlist
# https://istio.io/docs/tasks/traffic-management/egress.html#calling-external-services-directly
# example: includeIPRanges: "172.30.0.0/16,172.20.0.0/16"
# would only capture egress traffic on those two IP Ranges, all other outbound traffic would
@@ -146,29 +158,29 @@ global:
includeOutboundPorts: ""
excludeOutboundPorts: ""
# Log level for proxy, applies to gateways and sidecars.
# -- Log level for proxy, applies to gateways and sidecars.
# Expected values are: trace|debug|info|warning|error|critical|off
logLevel: warning
#If set to true, istio-proxy container will have privileged securityContext
# -- If set to true, istio-proxy container will have privileged securityContext
privileged: false
# The number of successive failed probes before indicating readiness failure.
# -- The number of successive failed probes before indicating readiness failure.
readinessFailureThreshold: 30
# The number of successive successful probes before indicating readiness success.
# -- The number of successive successful probes before indicating readiness success.
readinessSuccessThreshold: 30
# The initial delay for readiness probes in seconds.
# -- The initial delay for readiness probes in seconds.
readinessInitialDelaySeconds: 1
# The period between readiness probes.
# -- The period between readiness probes.
readinessPeriodSeconds: 2
# The readiness timeout seconds
# -- The readiness timeout seconds
readinessTimeoutSeconds: 3
# Resources for the sidecar.
# -- Resources for the sidecar.
resources:
requests:
cpu: 100m
@@ -177,18 +189,18 @@ global:
cpu: 2000m
memory: 1024Mi
# Default port for Pilot agent health checks. A value of 0 will disable health checking.
# -- Default port for Pilot agent health checks. A value of 0 will disable health checking.
statusPort: 15020
# Specify which tracer to use. One of: lightstep, datadog, stackdriver.
# -- Specify which tracer to use. One of: lightstep, datadog, stackdriver.
# If using stackdriver tracer outside GCP, set env GOOGLE_APPLICATION_CREDENTIALS to the GCP credential file.
tracer: ""
# Controls if sidecar is injected at the front of the container list and blocks the start of the other containers until the proxy is ready
# -- Controls if sidecar is injected at the front of the container list and blocks the start of the other containers until the proxy is ready
holdApplicationUntilProxyStarts: false
proxy_init:
# Base name for the proxy_init container, used to configure iptables.
# -- Base name for the proxy_init container, used to configure iptables.
image: proxyv2
resources:
limits:
@@ -198,7 +210,7 @@ global:
cpu: 10m
memory: 10Mi
# configure remote pilot and istiod service and endpoint
# -- configure remote pilot and istiod service and endpoint
remotePilotAddress: ""
##############################################################################################
@@ -206,20 +218,20 @@ global:
# make sure they are consistent across your Istio helm charts #
##############################################################################################
# The customized CA address to retrieve certificates for the pods in the cluster.
# -- The customized CA address to retrieve certificates for the pods in the cluster.
# CSR clients such as the Istio Agent and ingress gateways can use this to specify the CA endpoint.
# If not set explicitly, default to the Istio discovery address.
caAddress: ""
# Configure a remote cluster data plane controlled by an external istiod.
# -- Configure a remote cluster data plane controlled by an external istiod.
# When set to true, istiod is not deployed locally and only a subset of the other
# discovery charts are enabled.
externalIstiod: false
# Configure a remote cluster as the config cluster for an external istiod.
# -- Configure a remote cluster as the config cluster for an external istiod.
configCluster: false
# Configure the policy for validating JWT.
# -- Configure the policy for validating JWT.
# Currently, two options are supported: "third-party-jwt" and "first-party-jwt".
jwtPolicy: "third-party-jwt"
@@ -241,7 +253,7 @@ global:
# of migration TBD, and it may be a disruptive operation to change the Mesh
# ID post-install.
#
# If the mesh admin does not specify a value, Istio will use the value of the
# -- If the mesh admin does not specify a value, Istio will use the value of the
# mesh's Trust Domain. The best practice is to select a proper Trust Domain
# value.
meshID: ""
@@ -275,68 +287,69 @@ global:
#
meshNetworks: {}
# Use the user-specified, secret volume mounted key and certs for Pilot and workloads.
# -- Use the user-specified, secret volume mounted key and certs for Pilot and workloads.
mountMtlsCerts: false
multiCluster:
# Set to true to connect two kubernetes clusters via their respective
# -- Set to true to connect two kubernetes clusters via their respective
# ingressgateway services when pods in each cluster cannot directly
# talk to one another. All clusters should be using Istio mTLS and must
# have a shared root CA for this model to work.
enabled: true
# Should be set to the name of the cluster this installation will run in. This is required for sidecar injection
# -- Should be set to the name of the cluster this installation will run in. This is required for sidecar injection
# to properly label proxies
clusterName: ""
# Network defines the network this cluster belongs to. This name
# -- Network defines the network this cluster belongs to. This name
# corresponds to the networks in the map of mesh networks.
network: ""
# Configure the certificate provider for control plane communication.
# -- Configure the certificate provider for control plane communication.
# Currently, two providers are supported: "kubernetes" and "istiod".
# As some platforms may not have kubernetes signing APIs,
# Istiod is the default
pilotCertProvider: istiod
sds:
# The JWT token for SDS and the aud field of such JWT. See RFC 7519, section 4.1.3.
# -- The JWT token for SDS and the aud field of such JWT. See RFC 7519, section 4.1.3.
# When a CSR is sent from Istio Agent to the CA (e.g. Istiod), this aud is to make sure the
# JWT is intended for the CA.
token:
aud: istio-ca
sts:
# The service port used by Security Token Service (STS) server to handle token exchange requests.
# -- The service port used by Security Token Service (STS) server to handle token exchange requests.
# Setting this port to a non-zero value enables STS server.
servicePort: 0
# Configuration for each of the supported tracers
# -- Configuration for each of the supported tracers
tracer:
# Configuration for envoy to send trace data to LightStep.
# -- Configuration for envoy to send trace data to LightStep.
# Disabled by default.
# address: the <host>:<port> of the satellite pool
# accessToken: required for sending data to the pool
#
datadog:
# Host:Port for submitting traces to the Datadog agent.
# -- Host:Port for submitting traces to the Datadog agent.
address: "$(HOST_IP):8126"
lightstep:
address: "" # example: lightstep-satellite:443
accessToken: "" # example: abcdefg1234567
# -- example: lightstep-satellite:443
address: ""
# -- example: abcdefg1234567
accessToken: ""
stackdriver:
# enables trace output to stdout.
# -- enables trace output to stdout.
debug: false
# The global default max number of message events per span.
# -- The global default max number of message events per span.
maxNumberOfMessageEvents: 200
# The global default max number of annotation events per span.
# -- The global default max number of annotation events per span.
maxNumberOfAnnotations: 200
# The global default max number of attributes per span.
# -- The global default max number of attributes per span.
maxNumberOfAttributes: 200
# Use the Mesh Control Protocol (MCP) for configuring Istiod. Requires an MCP source.
# -- Use the Mesh Control Protocol (MCP) for configuring Istiod. Requires an MCP source.
useMCP: false
# Observability (o11y) configurations
# -- Observability (o11y) configurations
o11y:
enabled: false
promtail:
@@ -350,7 +363,7 @@ global:
memory: 2Gi
securityContext: {}
# The name of the CA for workload certificates.
# -- The name of the CA for workload certificates.
# For example, when caName=GkeWorkloadCertificate, GKE workload certificates
# will be used as the certificates for workloads.
# The default value is "" and when caName="", the CA will be configured by other
@@ -359,7 +372,7 @@ global:
hub: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress
clusterName: ""
# meshConfig defines runtime configuration of components, including Istiod and istio-agent behavior
# -- meshConfig defines runtime configuration of components, including Istiod and istio-agent behavior
# See https://istio.io/docs/reference/config/istio.mesh.v1alpha1/ for all available options
meshConfig:
enablePrometheusMerge: true
@@ -370,14 +383,13 @@ meshConfig:
# and gradual adoption by setting capture only on specific workloads. It also allows
# VMs to use other DNS options, like dnsmasq or unbound.
# The namespace to treat as the administrative root namespace for Istio configuration.
# -- The namespace to treat as the administrative root namespace for Istio configuration.
# When processing a leaf namespace Istio will search for declarations in that namespace first
# and if none are found it will search in the root namespace. Any matching declaration found in the root namespace
# is processed as if it were declared in the leaf namespace.
rootNamespace:
# The trust domain corresponds to the trust root of a system
# -- The trust domain corresponds to the trust root of a system
# Refer to https://github.com/spiffe/spiffe/blob/master/standards/SPIFFE-ID.md#21-trust-domain
trustDomain: "cluster.local"
@@ -391,56 +403,57 @@ meshConfig:
gateway:
name: "higress-gateway"
# -- Number of Higress Gateway pods
replicas: 2
image: gateway
# -- Use a `DaemonSet` or `Deployment`
kind: Deployment
# The number of successive failed probes before indicating readiness failure.
# -- The number of successive failed probes before indicating readiness failure.
readinessFailureThreshold: 30
# The number of successive successful probes before indicating readiness success.
# -- The number of successive successful probes before indicating readiness success.
readinessSuccessThreshold: 1
# The initial delay for readiness probes in seconds.
# -- The initial delay for readiness probes in seconds.
readinessInitialDelaySeconds: 1
# The period between readiness probes.
# -- The period between readiness probes.
readinessPeriodSeconds: 2
# The readiness timeout seconds
# -- The readiness timeout seconds
readinessTimeoutSeconds: 3
hub: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress
tag: ""
# revision declares which revision this gateway is a part of
# -- revision declares which revision this gateway is a part of
revision: ""
rbac:
# If enabled, roles will be created to enable accessing certificates from Gateways. This is not needed
# -- If enabled, roles will be created to enable accessing certificates from Gateways. This is not needed
# when using http://gateway-api.org/.
enabled: true
serviceAccount:
# If set, a service account will be created. Otherwise, the default is used
# -- If set, a service account will be created. Otherwise, the default is used
create: true
# Annotations to add to the service account
# -- Annotations to add to the service account
annotations: {}
# The name of the service account to use.
# -- The name of the service account to use.
# If not set, the release name is used
name: ""
# Pod environment variables
# -- Pod environment variables
env: {}
httpPort: 80
httpsPort: 443
hostNetwork: false
# Labels to apply to all resources
# -- Labels to apply to all resources
labels: {}
# Annotations to apply to all resources
# -- Annotations to apply to all resources
annotations: {}
podAnnotations:
@@ -449,14 +462,14 @@ gateway:
prometheus.io/path: "/stats/prometheus"
sidecar.istio.io/inject: "false"
# Define the security context for the pod.
# -- Define the security context for the pod.
# If unset, this will be automatically set to the minimum privileges required to bind to port 80 and 443.
# On Kubernetes 1.22+, this only requires the `net.ipv4.ip_unprivileged_port_start` sysctl.
securityContext: ~
containerSecurityContext: ~
service:
# Type of service. Set to "None" to disable the service entirely
# -- Type of service. Set to "None" to disable the service entirely
type: LoadBalancer
ports:
- name: http2
@@ -496,28 +509,29 @@ gateway:
affinity: {}
# If specified, the gateway will act as a network gateway for the given network.
# -- If specified, the gateway will act as a network gateway for the given network.
networkGateway: ""
metrics:
# If true, create PodMonitor or VMPodScrape for gateway
# -- If true, create PodMonitor or VMPodScrape for gateway
enabled: false
# provider group name for CustomResourceDefinition, can be monitoring.coreos.com or operator.victoriametrics.com
# -- provider group name for CustomResourceDefinition, can be monitoring.coreos.com or operator.victoriametrics.com
provider: monitoring.coreos.com
interval: ""
scrapeTimeout: ""
honorLabels: false
# for monitoring.coreos.com/v1.PodMonitor
# -- for monitoring.coreos.com/v1.PodMonitor
metricRelabelings: []
relabelings: []
# for operator.victoriametrics.com/v1beta1.VMPodScrape
# -- for operator.victoriametrics.com/v1beta1.VMPodScrape
metricRelabelConfigs: []
relabelConfigs: []
# some more raw podMetricsEndpoints spec
# -- some more raw podMetricsEndpoints spec
rawSpec: {}
controller:
name: "higress-controller"
# -- Number of Higress Controller pods
replicas: 1
image: higress
@@ -541,12 +555,12 @@ controller:
create: true
serviceAccount:
# Specifies whether a service account should be created
# -- Specifies whether a service account should be created
create: true
# Annotations to add to the service account
# -- Annotations to add to the service account
annotations: {}
# The name of the service account to use.
# If not set and create is true, a name is generated using the fullname template
# -- The name of the service account to use.
# -- If not set and create is true, a name is generated using the fullname template
name: ""
podAnnotations: {}
@@ -602,7 +616,7 @@ controller:
enabled: true
email: ""
## Discovery Settings
## -- Discovery Settings
pilot:
autoscaleEnabled: false
autoscaleMin: 1
@@ -614,11 +628,11 @@ pilot:
hub: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress
tag: ""
# Can be a full hub/image:tag
# -- Can be a full hub/image:tag
image: pilot
traceSampling: 1.0
# Resources for a small pilot install
# -- Resources for a small pilot install
resources:
requests:
cpu: 500m
@@ -633,21 +647,21 @@ pilot:
cpu:
targetAverageUtilization: 80
# if protocol sniffing is enabled for outbound
# -- if protocol sniffing is enabled for outbound
enableProtocolSniffingForOutbound: true
# if protocol sniffing is enabled for inbound
# -- if protocol sniffing is enabled for inbound
enableProtocolSniffingForInbound: true
nodeSelector: {}
podAnnotations: {}
serviceAnnotations: {}
# You can use jwksResolverExtraRootCA to provide a root certificate
# -- You can use jwksResolverExtraRootCA to provide a root certificate
# in PEM format. This will then be trusted by pilot when resolving
# JWKS URIs.
jwksResolverExtraRootCA: ""
# This is used to set the source of configuration for
# -- This is used to set the source of configuration for
# the associated address in configSource, if nothing is specified
# the default MCP is assumed.
configSource:
@@ -655,21 +669,21 @@ pilot:
plugins: []
# The following is used to limit how long a sidecar can be connected
# -- The following is used to limit how long a sidecar can be connected
# to a pilot. It balances out load across pilot instances at the cost of
# increasing system churn.
keepaliveMaxServerConnectionAge: 30m
# Additional labels to apply to the deployment.
# -- Additional labels to apply to the deployment.
deploymentLabels: {}
## Mesh config settings
# Install the mesh config map, generated from values.yaml.
# -- Install the mesh config map, generated from values.yaml.
# If false, pilot will use default values (by default) or user-supplied values.
configMap: true
# Additional labels to apply on the pod level for monitoring and logging configuration.
# -- Additional labels to apply on the pod level for monitoring and logging configuration.
podLabels: {}
# Tracing config settings
@@ -685,7 +699,7 @@ tracing:
# service: ""
# port: 9411
# Downstream config settings
# -- Downstream config settings
downstream:
idleTimeout: 180
maxRequestHeadersKb: 60
@@ -696,7 +710,7 @@ downstream:
initialConnectionWindowSize: 1048576
routeTimeout: 0
# Upstream config settings
# -- Upstream config settings
upstream:
idleTimeout: 10
connectionBufferLimits: 10485760


@@ -1,57 +1,276 @@
# Higress Helm Chart
Installs the cloud-native gateway [Higress](http://higress.io/)
## Get Repo Info
```console
helm repo add higress.io https://higress.io/helm-charts
helm repo update
```
_See [helm repo](https://helm.sh/docs/helm/helm_repo/) for command documentation._
## Installing the Chart
To install the chart with the release name `higress`:
```console
helm install higress -n higress-system higress.io/higress --create-namespace --render-subchart-notes
```
## Uninstalling the Chart
To uninstall/delete the higress deployment:
```console
helm delete higress -n higress-system
```
The command removes all the Kubernetes components associated with the chart and deletes the release.
## Configuration
| **Parameter** | **Description** | **Default** |
|---|---|---|
| **Global Parameters** | | |
| global.local | Set to `true` if installing to a local K8s cluster (e.g.: Kind, Rancher Desktop, etc.) | false |
| global.ingressClass | [IngressClass](https://kubernetes.io/zh-cn/docs/concepts/services-networking/ingress/#ingress-class) which is used to filter Ingress resources Higress Controller watches.<br />If there are multiple gateway instances deployed in the cluster, this parameter can be used to distinguish the scope of each gateway instance.<br />There are some special cases for special IngressClass values:<br />1. If set to "nginx", Higress Controller will watch Ingress resources with the `nginx` IngressClass or without any Ingress class.<br />2. If set to empty, Higress Controller will watch all Ingress resources in the K8s cluster. | higress |
| global.watchNamespace | If not empty, Higress Controller will only watch resources in the specified namespace. When isolating different business systems using K8s namespace, if each namespace requires a standalone gateway instance, this parameter can be used to confine the Ingress watching of Higress within the given namespace. | "" |
| global.disableAlpnH2 | Whether to disable HTTP/2 in ALPN | true |
| global.enableStatus | If `true`, Higress Controller will update the `status` field of Ingress resources.<br />When migrating from Nginx Ingress, in order to avoid `status` field of Ingress objects being overwritten, this parameter needs to be set to false, so Higress won't write the entry IP to the `status` field of the corresponding Ingress object. | true |
| global.enableIstioAPI | If `true`, Higress Controller will monitor istio resources as well | false |
| global.enableGatewayAPI | If `true`, Higress Controller will monitor Gateway API resources as well | false |
| global.istioNamespace | The namespace istio is installed to | istio-system |
| **Core Parameters** | | |
| higress-core.gateway.replicas | Number of Higress Gateway pods | 2 |
| higress-core.controller.replicas | Number of Higress Controller pods | 1 |
| **Console Parameters** | | |
| higress-console.replicaCount | Number of Higress Console pods | 1 |
| higress-console.service.type | K8s service type used by Higress Console | ClusterIP |
| higress-console.domain | Domain used to access Higress Console | console.higress.io |
| higress-console.tlsSecretName | Name of Secret resource used by TLS connections. | "" |
| higress-console.web.login.prompt | Prompt message to be displayed on the login page | "" |
| higress-console.admin.password.value | If not empty, the admin password will be configured to the specified value. | "" |
| higress-console.admin.password.length | The length of random admin password generated during installation. Only works when `higress-console.admin.password.value` is not set. | 8 |
| higress-console.o11y.enabled | If `true`, o11y suite (Grafana + Prometheus) will be installed. | false |
| higress-console.pvc.rwxSupported | Set to `false` when installing to a standard K8s cluster and the target cluster doesn't support the ReadWriteMany access mode of PersistentVolumeClaim. | true |
## Higress for Kubernetes
Higress is a cloud-native API gateway based on Alibaba's internal gateway practices.
Powered by Istio and Envoy, Higress combines the traffic gateway, microservice gateway and security gateway into a single architecture, greatly reducing the costs of deployment, operation and maintenance.
## Setup Repo Info
```console
helm repo add higress.io https://higress.io/helm-charts
helm repo update
```
## Install
To install the chart with the release name `higress`:
```console
helm install higress -n higress-system higress.io/higress --create-namespace --render-subchart-notes
```
## Uninstall
To uninstall/delete the higress deployment:
```console
helm delete higress -n higress-system
```
The command removes all the Kubernetes components associated with the chart and deletes the release.
## Parameters
## Values
| Key | Type | Default | Description |
|-----|------|---------|-------------|
| clusterName | string | `""` | |
| controller.affinity | object | `{}` | |
| controller.automaticHttps.email | string | `""` | |
| controller.automaticHttps.enabled | bool | `true` | |
| controller.autoscaling.enabled | bool | `false` | |
| controller.autoscaling.maxReplicas | int | `5` | |
| controller.autoscaling.minReplicas | int | `1` | |
| controller.autoscaling.targetCPUUtilizationPercentage | int | `80` | |
| controller.env | object | `{}` | |
| controller.hub | string | `"higress-registry.cn-hangzhou.cr.aliyuncs.com/higress"` | |
| controller.image | string | `"higress"` | |
| controller.imagePullSecrets | list | `[]` | |
| controller.labels | object | `{}` | |
| controller.name | string | `"higress-controller"` | |
| controller.nodeSelector | object | `{}` | |
| controller.podAnnotations | object | `{}` | |
| controller.podSecurityContext | object | `{}` | |
| controller.ports[0].name | string | `"http"` | |
| controller.ports[0].port | int | `8888` | |
| controller.ports[0].protocol | string | `"TCP"` | |
| controller.ports[0].targetPort | int | `8888` | |
| controller.ports[1].name | string | `"http-solver"` | |
| controller.ports[1].port | int | `8889` | |
| controller.ports[1].protocol | string | `"TCP"` | |
| controller.ports[1].targetPort | int | `8889` | |
| controller.ports[2].name | string | `"grpc"` | |
| controller.ports[2].port | int | `15051` | |
| controller.ports[2].protocol | string | `"TCP"` | |
| controller.ports[2].targetPort | int | `15051` | |
| controller.probe.httpGet.path | string | `"/ready"` | |
| controller.probe.httpGet.port | int | `8888` | |
| controller.probe.initialDelaySeconds | int | `1` | |
| controller.probe.periodSeconds | int | `3` | |
| controller.probe.timeoutSeconds | int | `5` | |
| controller.rbac.create | bool | `true` | |
| controller.replicas | int | `1` | Number of Higress Controller pods |
| controller.resources.limits.cpu | string | `"1000m"` | |
| controller.resources.limits.memory | string | `"2048Mi"` | |
| controller.resources.requests.cpu | string | `"500m"` | |
| controller.resources.requests.memory | string | `"2048Mi"` | |
| controller.securityContext | object | `{}` | |
| controller.service.type | string | `"ClusterIP"` | |
| controller.serviceAccount.annotations | object | `{}` | Annotations to add to the service account |
| controller.serviceAccount.create | bool | `true` | Specifies whether a service account should be created |
| controller.serviceAccount.name | string | `""` | If not set and create is true, a name is generated using the fullname template |
| controller.tag | string | `""` | |
| controller.tolerations | list | `[]` | |
| downstream | object | `{"connectionBufferLimits":32768,"http2":{"initialConnectionWindowSize":1048576,"initialStreamWindowSize":65535,"maxConcurrentStreams":100},"idleTimeout":180,"maxRequestHeadersKb":60,"routeTimeout":0}` | Downstream config settings |
| gateway.affinity | object | `{}` | |
| gateway.annotations | object | `{}` | Annotations to apply to all resources |
| gateway.autoscaling.enabled | bool | `false` | |
| gateway.autoscaling.maxReplicas | int | `5` | |
| gateway.autoscaling.minReplicas | int | `1` | |
| gateway.autoscaling.targetCPUUtilizationPercentage | int | `80` | |
| gateway.containerSecurityContext | string | `nil` | |
| gateway.env | object | `{}` | Pod environment variables |
| gateway.hostNetwork | bool | `false` | |
| gateway.httpPort | int | `80` | |
| gateway.httpsPort | int | `443` | |
| gateway.hub | string | `"higress-registry.cn-hangzhou.cr.aliyuncs.com/higress"` | |
| gateway.image | string | `"gateway"` | |
| gateway.kind | string | `"Deployment"` | Use a `DaemonSet` or `Deployment` |
| gateway.labels | object | `{}` | Labels to apply to all resources |
| gateway.metrics.enabled | bool | `false` | If true, create PodMonitor or VMPodScrape for gateway |
| gateway.metrics.honorLabels | bool | `false` | |
| gateway.metrics.interval | string | `""` | |
| gateway.metrics.metricRelabelConfigs | list | `[]` | for operator.victoriametrics.com/v1beta1.VMPodScrape |
| gateway.metrics.metricRelabelings | list | `[]` | for monitoring.coreos.com/v1.PodMonitor |
| gateway.metrics.provider | string | `"monitoring.coreos.com"` | provider group name for CustomResourceDefinition, can be monitoring.coreos.com or operator.victoriametrics.com |
| gateway.metrics.rawSpec | object | `{}` | some more raw podMetricsEndpoints spec |
| gateway.metrics.relabelConfigs | list | `[]` | |
| gateway.metrics.relabelings | list | `[]` | |
| gateway.metrics.scrapeTimeout | string | `""` | |
| gateway.name | string | `"higress-gateway"` | |
| gateway.networkGateway | string | `""` | If specified, the gateway will act as a network gateway for the given network. |
| gateway.nodeSelector | object | `{}` | |
| gateway.podAnnotations."prometheus.io/path" | string | `"/stats/prometheus"` | |
| gateway.podAnnotations."prometheus.io/port" | string | `"15020"` | |
| gateway.podAnnotations."prometheus.io/scrape" | string | `"true"` | |
| gateway.podAnnotations."sidecar.istio.io/inject" | string | `"false"` | |
| gateway.rbac.enabled | bool | `true` | If enabled, roles will be created to enable accessing certificates from Gateways. This is not needed when using http://gateway-api.org/. |
| gateway.readinessFailureThreshold | int | `30` | The number of successive failed probes before indicating readiness failure. |
| gateway.readinessInitialDelaySeconds | int | `1` | The initial delay for readiness probes in seconds. |
| gateway.readinessPeriodSeconds | int | `2` | The period between readiness probes. |
| gateway.readinessSuccessThreshold | int | `1` | The number of successive successful probes before indicating readiness success. |
| gateway.readinessTimeoutSeconds | int | `3` | The readiness timeout seconds |
| gateway.replicas | int | `2` | Number of Higress Gateway pods |
| gateway.resources.limits.cpu | string | `"2000m"` | |
| gateway.resources.limits.memory | string | `"2048Mi"` | |
| gateway.resources.requests.cpu | string | `"2000m"` | |
| gateway.resources.requests.memory | string | `"2048Mi"` | |
| gateway.revision | string | `""` | revision declares which revision this gateway is a part of |
| gateway.rollingMaxSurge | string | `"100%"` | |
| gateway.rollingMaxUnavailable | string | `"25%"` | |
| gateway.securityContext | string | `nil` | Define the security context for the pod. If unset, this will be automatically set to the minimum privileges required to bind to port 80 and 443. On Kubernetes 1.22+, this only requires the `net.ipv4.ip_unprivileged_port_start` sysctl. |
| gateway.service.annotations | object | `{}` | |
| gateway.service.externalTrafficPolicy | string | `""` | |
| gateway.service.loadBalancerClass | string | `""` | |
| gateway.service.loadBalancerIP | string | `""` | |
| gateway.service.loadBalancerSourceRanges | list | `[]` | |
| gateway.service.ports[0].name | string | `"http2"` | |
| gateway.service.ports[0].port | int | `80` | |
| gateway.service.ports[0].protocol | string | `"TCP"` | |
| gateway.service.ports[0].targetPort | int | `80` | |
| gateway.service.ports[1].name | string | `"https"` | |
| gateway.service.ports[1].port | int | `443` | |
| gateway.service.ports[1].protocol | string | `"TCP"` | |
| gateway.service.ports[1].targetPort | int | `443` | |
| gateway.service.type | string | `"LoadBalancer"` | Type of service. Set to "None" to disable the service entirely |
| gateway.serviceAccount.annotations | object | `{}` | Annotations to add to the service account |
| gateway.serviceAccount.create | bool | `true` | If set, a service account will be created. Otherwise, the default is used |
| gateway.serviceAccount.name | string | `""` | The name of the service account to use. If not set, the release name is used |
| gateway.tag | string | `""` | |
| gateway.tolerations | list | `[]` | |
| global.autoscalingv2API | bool | `true` | Whether to use the autoscaling/v2 template for HPA settings. For internal use only; not to be configured by users. |
| global.caAddress | string | `""` | The customized CA address to retrieve certificates for the pods in the cluster. CSR clients such as the Istio Agent and ingress gateways can use this to specify the CA endpoint. If not set explicitly, default to the Istio discovery address. |
| global.caName | string | `""` | The name of the CA for workload certificates. For example, when caName=GkeWorkloadCertificate, GKE workload certificates will be used as the certificates for workloads. The default value is "" and when caName="", the CA will be configured by other mechanisms (e.g., environmental variable CA_PROVIDER). |
| global.configCluster | bool | `false` | Configure a remote cluster as the config cluster for an external istiod. |
| global.defaultPodDisruptionBudget | object | `{"enabled":false}` | enable pod disruption budget for the control plane, which is used to ensure Istio control plane components are gradually upgraded or recovered. |
| global.defaultResources | object | `{"requests":{"cpu":"10m"}}` | A minimal set of requested resources to be applied to all deployments so that Horizontal Pod Autoscaler will be able to function (if set). Each component can overwrite these default values by adding its own resources block in the relevant section below and setting the desired resources values. |
| global.defaultUpstreamConcurrencyThreshold | int | `10000` | |
| global.disableAlpnH2 | bool | `false` | Whether to disable HTTP/2 in ALPN |
| global.enableGatewayAPI | bool | `false` | If true, Higress Controller will monitor Gateway API resources as well |
| global.enableH3 | bool | `false` | |
| global.enableHigressIstio | bool | `false` | |
| global.enableIPv6 | bool | `false` | |
| global.enableIstioAPI | bool | `true` | If true, Higress Controller will monitor istio resources as well |
| global.enableProxyProtocol | bool | `false` | |
| global.enableSRDS | bool | `true` | |
| global.enableStatus | bool | `true` | If true, Higress Controller will update the status field of Ingress resources. When migrating from Nginx Ingress, in order to avoid status field of Ingress objects being overwritten, this parameter needs to be set to false, so Higress won't write the entry IP to the status field of the corresponding Ingress object. |
| global.externalIstiod | bool | `false` | Configure a remote cluster data plane controlled by an external istiod. When set to true, istiod is not deployed locally and only a subset of the other discovery charts are enabled. |
| global.hostRDSMergeSubset | bool | `false` | |
| global.hub | string | `"higress-registry.cn-hangzhou.cr.aliyuncs.com/higress"` | Default hub for Istio images. Releases are published to docker hub under 'istio' project. Dev builds from prow are on gcr.io |
| global.imagePullPolicy | string | `""` | Specify image pull policy if default behavior isn't desired. Default behavior: latest images will be Always else IfNotPresent. |
| global.imagePullSecrets | list | `[]` | ImagePullSecrets for all ServiceAccount, list of secrets in the same namespace to use for pulling any images in pods that reference this ServiceAccount. For components that don't use ServiceAccounts (i.e. grafana, servicegraph, tracing) ImagePullSecrets will be added to the corresponding Deployment(StatefulSet) objects. Must be set for any cluster configured with private docker registry. |
| global.ingressClass | string | `"higress"` | IngressClass filters which ingress resources the higress controller watches. The default ingress class is higress. There are some special cases for special ingress class. 1. When the ingress class is set as nginx, the higress controller will watch ingress resources with the nginx ingress class or without any ingress class. 2. When the ingress class is set empty, the higress controller will watch all ingress resources in the k8s cluster. |
| global.istioNamespace | string | `"istio-system"` | Used to locate istiod. |
| global.istiod | object | `{"enableAnalysis":false}` | Enabled by default in master for maximising testing. |
| global.jwtPolicy | string | `"third-party-jwt"` | Configure the policy for validating JWT. Currently, two options are supported: "third-party-jwt" and "first-party-jwt". |
| global.kind | bool | `false` | |
| global.liteMetrics | bool | `true` | |
| global.local | bool | `false` | When deploying to a local cluster (e.g.: kind cluster), set this to true. |
| global.logAsJson | bool | `false` | |
| global.logging | object | `{"level":"default:info"}` | Comma-separated minimum per-scope logging level of messages to output, in the form of `<scope>:<level>,<scope>:<level>`. The control plane has different scopes depending on component, but can configure a default log level across all components. If empty, the default scope and level will be used as configured in code. |
| global.meshID | string | `""` | If the mesh admin does not specify a value, Istio will use the value of the mesh's Trust Domain. The best practice is to select a proper Trust Domain value. |
| global.meshNetworks | object | `{}` | |
| global.mountMtlsCerts | bool | `false` | Use the user-specified, secret volume mounted key and certs for Pilot and workloads. |
| global.multiCluster.clusterName | string | `""` | Should be set to the name of the cluster this installation will run in. This is required for sidecar injection to properly label proxies |
| global.multiCluster.enabled | bool | `true` | Set to true to connect two kubernetes clusters via their respective ingressgateway services when pods in each cluster cannot directly talk to one another. All clusters should be using Istio mTLS and must have a shared root CA for this model to work. |
| global.network | string | `""` | Network defines the network this cluster belongs to. This name corresponds to the networks in the map of mesh networks. |
| global.o11y | object | `{"enabled":false,"promtail":{"image":{"repository":"higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/promtail","tag":"2.9.4"},"port":3101,"resources":{"limits":{"cpu":"500m","memory":"2Gi"}},"securityContext":{}}}` | Observability (o11y) configurations |
| global.omitSidecarInjectorConfigMap | bool | `false` | |
| global.onDemandRDS | bool | `false` | |
| global.oneNamespace | bool | `false` | Whether to restrict the applications namespace the controller manages; If not set, controller watches all namespaces |
| global.onlyPushRouteCluster | bool | `true` | |
| global.operatorManageWebhooks | bool | `false` | Configure whether Operator manages webhook configurations. The current behavior of Istiod is to manage its own webhook configurations. When this option is set as true, Istio Operator, instead of webhooks, manages the webhook configurations. When this option is set as false, webhooks manage their own webhook configurations. |
| global.pilotCertProvider | string | `"istiod"` | Configure the certificate provider for control plane communication. Currently, two providers are supported: "kubernetes" and "istiod". As some platforms may not have kubernetes signing APIs, Istiod is the default |
| global.priorityClassName | string | `""` | Kubernetes >=v1.11.0 will create two PriorityClass, including system-cluster-critical and system-node-critical, it is better to configure this in order to make sure your Istio pods will not be killed because of low priority class. Refer to https://kubernetes.io/docs/concepts/configuration/pod-priority-preemption/#priorityclass for more detail. |
| global.proxy.autoInject | string | `"enabled"` | This controls the 'policy' in the sidecar injector. |
| global.proxy.clusterDomain | string | `"cluster.local"` | CAUTION: It is important to ensure that all Istio helm charts specify the same clusterDomain value. Default value is "cluster.local". |
| global.proxy.componentLogLevel | string | `"misc:error"` | Per Component log level for proxy, applies to gateways and sidecars. If a component level is not set, then the global "logLevel" will be used. |
| global.proxy.enableCoreDump | bool | `false` | If set, newly injected sidecars will have core dumps enabled. |
| global.proxy.excludeIPRanges | string | `""` | |
| global.proxy.excludeInboundPorts | string | `""` | |
| global.proxy.excludeOutboundPorts | string | `""` | |
| global.proxy.holdApplicationUntilProxyStarts | bool | `false` | Controls if sidecar is injected at the front of the container list and blocks the start of the other containers until the proxy is ready |
| global.proxy.image | string | `"proxyv2"` | |
| global.proxy.includeIPRanges | string | `"*"` | istio egress capture allowlist https://istio.io/docs/tasks/traffic-management/egress.html#calling-external-services-directly example: includeIPRanges: "172.30.0.0/16,172.20.0.0/16" would only capture egress traffic on those two IP Ranges, all other outbound traffic would be allowed by the sidecar |
| global.proxy.includeInboundPorts | string | `"*"` | |
| global.proxy.includeOutboundPorts | string | `""` | |
| global.proxy.logLevel | string | `"warning"` | Log level for proxy, applies to gateways and sidecars. Expected values are: trace, debug, info, warning, error, critical, off |
| global.proxy.privileged | bool | `false` | If set to true, istio-proxy container will have privileged securityContext |
| global.proxy.readinessFailureThreshold | int | `30` | The number of successive failed probes before indicating readiness failure. |
| global.proxy.readinessInitialDelaySeconds | int | `1` | The initial delay for readiness probes in seconds. |
| global.proxy.readinessPeriodSeconds | int | `2` | The period between readiness probes. |
| global.proxy.readinessSuccessThreshold | int | `30` | The number of successive successful probes before indicating readiness success. |
| global.proxy.readinessTimeoutSeconds | int | `3` | The readiness timeout seconds |
| global.proxy.resources | object | `{"limits":{"cpu":"2000m","memory":"1024Mi"},"requests":{"cpu":"100m","memory":"128Mi"}}` | Resources for the sidecar. |
| global.proxy.statusPort | int | `15020` | Default port for Pilot agent health checks. A value of 0 will disable health checking. |
| global.proxy.tracer | string | `""` | Specify which tracer to use. One of: lightstep, datadog, stackdriver. If using stackdriver tracer outside GCP, set env GOOGLE_APPLICATION_CREDENTIALS to the GCP credential file. |
| global.proxy_init.image | string | `"proxyv2"` | Base name for the proxy_init container, used to configure iptables. |
| global.proxy_init.resources.limits.cpu | string | `"2000m"` | |
| global.proxy_init.resources.limits.memory | string | `"1024Mi"` | |
| global.proxy_init.resources.requests.cpu | string | `"10m"` | |
| global.proxy_init.resources.requests.memory | string | `"10Mi"` | |
| global.remotePilotAddress | string | `""` | configure remote pilot and istiod service and endpoint |
| global.sds.token | object | `{"aud":"istio-ca"}` | The JWT token for SDS and the aud field of such JWT. See RFC 7519, section 4.1.3. When a CSR is sent from Istio Agent to the CA (e.g. Istiod), this aud is to make sure the JWT is intended for the CA. |
| global.sts.servicePort | int | `0` | The service port used by Security Token Service (STS) server to handle token exchange requests. Setting this port to a non-zero value enables STS server. |
| global.tracer | object | `{"datadog":{"address":"$(HOST_IP):8126"},"lightstep":{"accessToken":"","address":""},"stackdriver":{"debug":false,"maxNumberOfAnnotations":200,"maxNumberOfAttributes":200,"maxNumberOfMessageEvents":200}}` | Configuration for each of the supported tracers |
| global.tracer.datadog | object | `{"address":"$(HOST_IP):8126"}` | Configuration for Envoy to send trace data to the Datadog agent. |
| global.tracer.datadog.address | string | `"$(HOST_IP):8126"` | Host:Port for submitting traces to the Datadog agent. |
| global.tracer.lightstep.accessToken | string | `""` | example: abcdefg1234567 |
| global.tracer.lightstep.address | string | `""` | example: lightstep-satellite:443 |
| global.tracer.stackdriver.debug | bool | `false` | enables trace output to stdout. |
| global.tracer.stackdriver.maxNumberOfAnnotations | int | `200` | The global default max number of annotation events per span. |
| global.tracer.stackdriver.maxNumberOfAttributes | int | `200` | The global default max number of attributes per span. |
| global.tracer.stackdriver.maxNumberOfMessageEvents | int | `200` | The global default max number of message events per span. |
| global.useMCP | bool | `false` | Use the Mesh Control Protocol (MCP) for configuring Istiod. Requires an MCP source. |
| global.watchNamespace | string | `""` | If not empty, Higress Controller will only watch resources in the specified namespace. When isolating different business systems using K8s namespace, if each namespace requires a standalone gateway instance, this parameter can be used to confine the Ingress watching of Higress within the given namespace. |
| global.xdsMaxRecvMsgSize | string | `"104857600"` | |
| hub | string | `"higress-registry.cn-hangzhou.cr.aliyuncs.com/higress"` | |
| meshConfig | object | `{"enablePrometheusMerge":true,"rootNamespace":null,"trustDomain":"cluster.local"}` | meshConfig defines runtime configuration of components, including Istiod and istio-agent behavior. See https://istio.io/docs/reference/config/istio.mesh.v1alpha1/ for all available options. |
| meshConfig.rootNamespace | string | `nil` | The namespace to treat as the administrative root namespace for Istio configuration. When processing a leaf namespace Istio will search for declarations in that namespace first and if none are found it will search in the root namespace. Any matching declaration found in the root namespace is processed as if it were declared in the leaf namespace. |
| meshConfig.trustDomain | string | `"cluster.local"` | The trust domain corresponds to the trust root of a system. Refer to https://github.com/spiffe/spiffe/blob/master/standards/SPIFFE-ID.md#21-trust-domain |
| pilot.autoscaleEnabled | bool | `false` | |
| pilot.autoscaleMax | int | `5` | |
| pilot.autoscaleMin | int | `1` | |
| pilot.configMap | bool | `true` | Install the mesh config map, generated from values.yaml. If false, pilot will use default values (by default) or user-supplied values. |
| pilot.configSource | object | `{"subscribedResources":[]}` | This is used to set the source of configuration for the associated address in configSource, if nothing is specified the default MCP is assumed. |
| pilot.cpu.targetAverageUtilization | int | `80` | |
| pilot.deploymentLabels | object | `{}` | Additional labels to apply to the deployment. |
| pilot.enableProtocolSniffingForInbound | bool | `true` | if protocol sniffing is enabled for inbound |
| pilot.enableProtocolSniffingForOutbound | bool | `true` | if protocol sniffing is enabled for outbound |
| pilot.env.PILOT_ENABLE_CROSS_CLUSTER_WORKLOAD_ENTRY | string | `"false"` | |
| pilot.env.PILOT_ENABLE_METADATA_EXCHANGE | string | `"false"` | |
| pilot.env.PILOT_SCOPE_GATEWAY_TO_NAMESPACE | string | `"false"` | |
| pilot.env.VALIDATION_ENABLED | string | `"false"` | |
| pilot.hub | string | `"higress-registry.cn-hangzhou.cr.aliyuncs.com/higress"` | |
| pilot.image | string | `"pilot"` | Can be a full hub/image:tag |
| pilot.jwksResolverExtraRootCA | string | `""` | You can use jwksResolverExtraRootCA to provide a root certificate in PEM format. This will then be trusted by pilot when resolving JWKS URIs. |
| pilot.keepaliveMaxServerConnectionAge | string | `"30m"` | The following is used to limit how long a sidecar can be connected to a pilot. It balances out load across pilot instances at the cost of increasing system churn. |
| pilot.nodeSelector | object | `{}` | |
| pilot.plugins | list | `[]` | |
| pilot.podAnnotations | object | `{}` | |
| pilot.podLabels | object | `{}` | Additional labels to apply on the pod level for monitoring and logging configuration. |
| pilot.replicaCount | int | `1` | |
| pilot.resources | object | `{"requests":{"cpu":"500m","memory":"2048Mi"}}` | Resources for a small pilot install |
| pilot.rollingMaxSurge | string | `"100%"` | |
| pilot.rollingMaxUnavailable | string | `"25%"` | |
| pilot.serviceAnnotations | object | `{}` | |
| pilot.tag | string | `""` | |
| pilot.traceSampling | float | `1` | |
| revision | string | `""` | |
| tracing.enable | bool | `false` | |
| tracing.sampling | int | `100` | |
| tracing.skywalking.port | int | `11800` | |
| tracing.skywalking.service | string | `""` | |
| tracing.timeout | int | `500` | |
| upstream | object | `{"connectionBufferLimits":10485760,"idleTimeout":10}` | Upstream config settings |

View File

@@ -0,0 +1,34 @@
## Higress for Kubernetes
Higress is a cloud-native API gateway based on Alibaba's internal gateway practices.
Powered by Istio and Envoy, Higress combines the traffic gateway, microservice gateway and security gateway into a single architecture, greatly reducing the costs of deployment, operation and maintenance.
## Setup Repo Info
```console
helm repo add higress.io https://higress.io/helm-charts
helm repo update
```
## Install
To install the chart with the release name `higress`:
```console
helm install higress -n higress-system higress.io/higress --create-namespace --render-subchart-notes
```
## Uninstall
To uninstall/delete the higress deployment:
```console
helm delete higress -n higress-system
```
The command removes all the Kubernetes components associated with the chart and deletes the release.
## Parameters
{{ template "chart.valuesSection" . }}

View File

@@ -27,13 +27,17 @@ http_archive(
url = "https://github.com/higress-group/proxy-wasm-cpp-sdk/archive/" + PROXY_WASM_CPP_SDK_SHA + ".tar.gz",
)
load("@proxy_wasm_cpp_sdk//bazel/dep:deps.bzl", "wasm_dependencies")
load("@proxy_wasm_cpp_sdk//bazel:repositories.bzl", "proxy_wasm_cpp_sdk_repositories")
wasm_dependencies()
proxy_wasm_cpp_sdk_repositories()
load("@proxy_wasm_cpp_sdk//bazel/dep:deps_extra.bzl", "wasm_dependencies_extra")
load("@proxy_wasm_cpp_sdk//bazel:dependencies.bzl", "proxy_wasm_cpp_sdk_dependencies")
wasm_dependencies_extra()
proxy_wasm_cpp_sdk_dependencies()
load("@proxy_wasm_cpp_sdk//bazel:dependencies_extra.bzl", "proxy_wasm_cpp_sdk_dependencies_extra")
proxy_wasm_cpp_sdk_dependencies_extra()
load("@istio_ecosystem_wasm_extensions//bazel:wasm.bzl", "wasm_libraries")

View File

@@ -2,16 +2,16 @@ diff --git a/absl/time/internal/cctz/src/time_zone_format.cc b/absl/time/interna
index d8cb047..0c5f182 100644
--- a/absl/time/internal/cctz/src/time_zone_format.cc
+++ b/absl/time/internal/cctz/src/time_zone_format.cc
@@ -18,6 +18,8 @@
#endif
#endif
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#define HAS_STRPTIME 0
+
#if defined(HAS_STRPTIME) && HAS_STRPTIME
#if !defined(_XOPEN_SOURCE)
#define _XOPEN_SOURCE // Definedness suffices for strptime.
@@ -58,7 +60,7 @@ namespace {
#if !defined(HAS_STRPTIME)
#if !defined(_MSC_VER) && !defined(__MINGW32__)
#define HAS_STRPTIME 1 // assume everyone has strptime() except windows
@@ -58,7 +60,7 @@
#if !HAS_STRPTIME
// Build a strptime() using C++11's std::get_time().
@@ -20,7 +20,7 @@ index d8cb047..0c5f182 100644
std::istringstream input(s);
input >> std::get_time(tm, fmt);
if (input.fail()) return nullptr;
@@ -648,7 +650,7 @@ const char* ParseSubSeconds(const char* dp, detail::femtoseconds* subseconds) {
@@ -648,7 +650,7 @@
// Parses a string into a std::tm using strptime(3).
const char* ParseTM(const char* dp, const char* fmt, std::tm* tm) {
if (dp != nullptr) {

View File

@@ -9,9 +9,9 @@ load(
def wasm_libraries():
http_archive(
name = "com_google_absl",
sha256 = "ec8ef47335310cc3382bdc0d0cc1097a001e67dc83fcba807845aa5696e7e1e4",
strip_prefix = "abseil-cpp-302b250e1d917ede77b5ff00a6fd9f28430f1563",
url = "https://github.com/abseil/abseil-cpp/archive/302b250e1d917ede77b5ff00a6fd9f28430f1563.tar.gz",
sha256 = "3a0bb3d2e6f53352526a8d1a7e7b5749c68cd07f2401766a404fb00d2853fa49",
strip_prefix = "abseil-cpp-4bbdb026899fea9f882a95cbd7d6a4adaf49b2dd",
url = "https://github.com/abseil/abseil-cpp/archive/4bbdb026899fea9f882a95cbd7d6a4adaf49b2dd.tar.gz",
patch_args = ["-p1"],
patches = ["//bazel:absl.patch"],
)
@@ -33,8 +33,8 @@ def wasm_libraries():
urls = ["https://github.com/google/googletest/archive/release-1.10.0.tar.gz"],
)
PROXY_WASM_CPP_HOST_SHA = "7850d1721fe3dd2ccfb86a06116f76c23b1f1bf8"
PROXY_WASM_CPP_HOST_SHA256 = "740690fc1d749849f6e24b5bc48a07dabc0565a7d03b6cd13425dba693956c57"
PROXY_WASM_CPP_HOST_SHA = "ecf42a27fcf78f42e64037d4eff1a0ca5a61e403"
PROXY_WASM_CPP_HOST_SHA256 = "9748156731e9521837686923321bf12725c32c9fa8355218209831cc3ee87080"
http_archive(
name = "proxy_wasm_cpp_host",

View File

@@ -19,7 +19,6 @@
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_split.h"
#include "common/common_util.h"
namespace Wasm::Common::Http {
@@ -190,7 +189,8 @@ std::vector<std::string> getAllOfHeader(std::string_view key) {
std::vector<std::string> result;
auto headers = getRequestHeaderPairs()->pairs();
for (auto& header : headers) {
if (absl::EqualsIgnoreCase(Wasm::Common::stdToAbsl(header.first), Wasm::Common::stdToAbsl(key))) {
if (absl::EqualsIgnoreCase(Wasm::Common::stdToAbsl(header.first),
Wasm::Common::stdToAbsl(key))) {
result.push_back(std::string(header.second));
}
}
@@ -225,7 +225,8 @@ void forEachCookie(
v = v.substr(1, v.size() - 2);
}
if (!cookie_consumer(Wasm::Common::abslToStd(k), Wasm::Common::abslToStd(v))) {
if (!cookie_consumer(Wasm::Common::abslToStd(k),
Wasm::Common::abslToStd(v))) {
return;
}
}
@@ -265,7 +266,63 @@ std::string buildOriginalUri(std::optional<uint32_t> max_path_length) {
auto scheme = scheme_ptr->view();
auto host_ptr = getRequestHeader(Header::Host);
auto host = host_ptr->view();
return absl::StrCat(Wasm::Common::stdToAbsl(scheme), "://", Wasm::Common::stdToAbsl(host), Wasm::Common::stdToAbsl(final_path));
return absl::StrCat(Wasm::Common::stdToAbsl(scheme), "://",
Wasm::Common::stdToAbsl(host),
Wasm::Common::stdToAbsl(final_path));
}
void extractHostPathFromUri(const absl::string_view& uri,
absl::string_view& host, absl::string_view& path) {
/**
* URI RFC: https://www.ietf.org/rfc/rfc2396.txt
*
* Example:
* uri = "https://example.com:8443/certs"
* pos: ^
* host_pos: ^
* path_pos: ^
* host = "example.com:8443"
* path = "/certs"
*/
const auto pos = uri.find("://");
// Start position of the host
const auto host_pos = (pos == std::string::npos) ? 0 : pos + 3;
// Start position of the path
const auto path_pos = uri.find('/', host_pos);
if (path_pos == std::string::npos) {
// If uri doesn't have "/", the whole string is treated as host.
host = uri.substr(host_pos);
path = "/";
} else {
host = uri.substr(host_pos, path_pos - host_pos);
path = uri.substr(path_pos);
}
}
void extractPathWithoutArgsFromUri(const std::string_view& uri,
std::string_view& path_without_args) {
auto params_pos = uri.find('?');
size_t uri_end;
if (params_pos == std::string::npos) {
uri_end = uri.size();
} else {
uri_end = params_pos;
}
path_without_args = uri.substr(0, uri_end);
}
bool hasRequestBody() {
auto contentType = getRequestHeader("content-type")->toString();
auto contentLengthStr = getRequestHeader("content-length")->toString();
auto transferEncoding = getRequestHeader("transfer-encoding")->toString();
if (!contentType.empty()) {
return true;
}
if (!contentLengthStr.empty()) {
return true;
}
return transferEncoding.find("chunked") != std::string::npos;
}
} // namespace Wasm::Common::Http
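
The two URI helpers added in this hunk can be tried outside the proxy. The following is a minimal sketch, assuming plain `std::string_view` in place of `absl::string_view`; `HostPath`, `splitHostPath`, `stripQuery` and `uriHelpersExample` are hypothetical names used only for illustration and are not part of `common/http_util`.

```cpp
#include <cassert>
#include <string_view>

// Hypothetical result type; the real helper writes into two out-parameters.
struct HostPath {
  std::string_view host;
  std::string_view path;
};

// Splits "scheme://host[:port]/path" into host and path.
inline HostPath splitHostPath(std::string_view uri) {
  const auto scheme_pos = uri.find("://");
  // The host starts right after "://", or at offset 0 when no scheme is given.
  const auto host_pos =
      (scheme_pos == std::string_view::npos) ? 0 : scheme_pos + 3;
  const auto path_pos = uri.find('/', host_pos);
  if (path_pos == std::string_view::npos) {
    // No "/" after the host: the rest is the host and the path defaults to "/".
    return {uri.substr(host_pos), "/"};
  }
  return {uri.substr(host_pos, path_pos - host_pos), uri.substr(path_pos)};
}

// Returns the part of the URI before the first '?'.
inline std::string_view stripQuery(std::string_view uri) {
  return uri.substr(0, uri.find('?'));
}

// Usage example, mirroring the comment in extractHostPathFromUri.
inline void uriHelpersExample() {
  const HostPath hp = splitHostPath("https://example.com:8443/certs");
  assert(hp.host == "example.com:8443");
  assert(hp.path == "/certs");
  assert(stripQuery("/v1/chat/completions?stream=true") ==
         "/v1/chat/completions");
}
```

Both functions return views into the caller's buffer, so the input string must outlive the result.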


@@ -42,6 +42,12 @@ namespace Wasm::Common::Http {
using QueryParams = std::map<std::string, std::string>;
using SystemTime = std::chrono::time_point<std::chrono::system_clock>;
namespace Status {
constexpr int OK = 200;
constexpr int InternalServerError = 500;
constexpr int Unauthorized = 401;
} // namespace Status
namespace Header {
constexpr std::string_view Scheme(":scheme");
constexpr std::string_view Method(":method");
@@ -52,14 +58,17 @@ constexpr std::string_view Accept("accept");
constexpr std::string_view ContentMD5("content-md5");
constexpr std::string_view ContentType("content-type");
constexpr std::string_view ContentLength("content-length");
constexpr std::string_view TransferEncoding("transfer-encoding");
constexpr std::string_view UserAgent("user-agent");
constexpr std::string_view Date("date");
constexpr std::string_view Cookie("cookie");
constexpr std::string_view StrictTransportSecurity("strict-transport-security");
} // namespace Header
namespace ContentTypeValues {
constexpr std::string_view Grpc{"application/grpc"};
}
constexpr std::string_view Json{"application/json"};
} // namespace ContentTypeValues
class PercentEncoding {
public:
@@ -142,4 +151,10 @@ std::unordered_map<std::string, std::string> parseCookies(
std::string buildOriginalUri(std::optional<uint32_t> max_path_length);
void extractHostPathFromUri(const absl::string_view& uri,
absl::string_view& host, absl::string_view& path);
void extractPathWithoutArgsFromUri(const std::string_view& uri,
std::string_view& path_without_args);
bool hasRequestBody();
} // namespace Wasm::Common::Http
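
The decision made by `hasRequestBody()` can be expressed as a pure function, which makes it easy to check by hand. This is a sketch only: `looksLikeRequestWithBody` is a hypothetical name, and the header values are passed in as arguments instead of being read through `getRequestHeader()`.

```cpp
#include <cassert>
#include <string_view>

// Hypothetical pure-function version of the check in hasRequestBody().
// An absent header is represented by an empty string.
inline bool looksLikeRequestWithBody(std::string_view content_type,
                                     std::string_view content_length,
                                     std::string_view transfer_encoding) {
  if (!content_type.empty()) {
    return true;  // any content-type implies a body
  }
  if (!content_length.empty()) {
    return true;  // any content-length value implies a body
  }
  // Otherwise only a chunked transfer-encoding indicates a body.
  return transfer_encoding.find("chunked") != std::string_view::npos;
}

// Usage example.
inline void looksLikeRequestWithBodyExample() {
  assert(looksLikeRequestWithBody("application/json", "", ""));
  assert(looksLikeRequestWithBody("", "", "gzip, chunked"));
  assert(!looksLikeRequestWithBody("", "", ""));
}
```

Note that in this sketch any non-empty `content-length`, including `0`, counts as a body.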


@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("@proxy_wasm_cpp_sdk//bazel/wasm:wasm.bzl", "wasm_cc_binary")
load("@proxy_wasm_cpp_sdk//bazel:defs.bzl", "proxy_wasm_cc_binary")
load("//bazel:wasm.bzl", "declare_wasm_image_targets")
wasm_cc_binary(
proxy_wasm_cc_binary(
name = "bot_detect.wasm",
srcs = [
"plugin.cc",
@@ -28,7 +28,6 @@ wasm_cc_binary(
"//common:http_util",
"//common:regex_util",
"//common:rule_util",
"@proxy_wasm_cpp_sdk//:proxy_wasm_intrinsics",
],
)


@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("@proxy_wasm_cpp_sdk//bazel/wasm:wasm.bzl", "wasm_cc_binary")
load("@proxy_wasm_cpp_sdk//bazel:defs.bzl", "proxy_wasm_cc_binary")
load("//bazel:wasm.bzl", "declare_wasm_image_targets")
wasm_cc_binary(
proxy_wasm_cc_binary(
name = "custom_response.wasm",
srcs = [
"plugin.cc",
@@ -27,7 +27,6 @@ wasm_cc_binary(
"//common:json_util",
"//common:http_util",
"//common:rule_util",
"@proxy_wasm_cpp_sdk//:proxy_wasm_intrinsics",
],
)


@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("@proxy_wasm_cpp_sdk//bazel/wasm:wasm.bzl", "wasm_cc_binary")
load("@proxy_wasm_cpp_sdk//bazel:defs.bzl", "proxy_wasm_cc_binary")
load("//bazel:wasm.bzl", "declare_wasm_image_targets")
wasm_cc_binary(
proxy_wasm_cc_binary(
name = "hmac_auth.wasm",
srcs = [
"plugin.cc",
@@ -30,7 +30,6 @@ wasm_cc_binary(
"//common:crypto_util",
"//common:http_util",
"//common:rule_util",
"@proxy_wasm_cpp_sdk//:proxy_wasm_intrinsics",
],
)


@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("@proxy_wasm_cpp_sdk//bazel/wasm:wasm.bzl", "wasm_cc_binary")
load("@proxy_wasm_cpp_sdk//bazel:defs.bzl", "proxy_wasm_cc_binary")
load("//bazel:wasm.bzl", "declare_wasm_image_targets")
wasm_cc_binary(
@@ -33,7 +33,6 @@ wasm_cc_binary(
"//common:json_util",
"//common:http_util",
"//common:rule_util",
"@proxy_wasm_cpp_sdk//:proxy_wasm_intrinsics",
],
)


@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("@proxy_wasm_cpp_sdk//bazel/wasm:wasm.bzl", "wasm_cc_binary")
load("@proxy_wasm_cpp_sdk//bazel:defs.bzl", "proxy_wasm_cc_binary")
load("//bazel:wasm.bzl", "declare_wasm_image_targets")
wasm_cc_binary(
proxy_wasm_cc_binary(
name = "key_auth.wasm",
srcs = [
"plugin.cc",
@@ -28,7 +28,6 @@ wasm_cc_binary(
"//common:json_util",
"//common:http_util",
"//common:rule_util",
"@proxy_wasm_cpp_sdk//:proxy_wasm_intrinsics",
],
)


@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("@proxy_wasm_cpp_sdk//bazel/wasm:wasm.bzl", "wasm_cc_binary")
load("@proxy_wasm_cpp_sdk//bazel:defs.bzl", "proxy_wasm_cc_binary")
load("//bazel:wasm.bzl", "declare_wasm_image_targets")
wasm_cc_binary(
proxy_wasm_cc_binary(
name = "key_rate_limit.wasm",
srcs = [
"plugin.cc",
@@ -29,7 +29,6 @@ wasm_cc_binary(
"//common:json_util",
"//common:http_util",
"//common:rule_util",
"@proxy_wasm_cpp_sdk//:proxy_wasm_intrinsics",
],
)


@@ -0,0 +1,70 @@
# Copyright (c) 2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@proxy_wasm_cpp_sdk//bazel:defs.bzl", "proxy_wasm_cc_binary")
load("//bazel:wasm.bzl", "declare_wasm_image_targets")
proxy_wasm_cc_binary(
name = "model_mapper.wasm",
srcs = [
"plugin.cc",
"plugin.h",
],
deps = [
"@proxy_wasm_cpp_sdk//:proxy_wasm_intrinsics_higress",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
"//common:json_util",
"//common:http_util",
"//common:rule_util",
],
)
cc_library(
name = "model_mapper_lib",
srcs = [
"plugin.cc",
],
hdrs = [
"plugin.h",
],
copts = ["-DNULL_PLUGIN"],
deps = [
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
"//common:json_util",
"@proxy_wasm_cpp_host//:lib",
"//common:http_util_nullvm",
"//common:rule_util_nullvm",
],
)
cc_test(
name = "model_mapper_test",
srcs = [
"plugin_test.cc",
],
copts = ["-DNULL_PLUGIN"],
deps = [
":model_mapper_lib",
"@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main",
"@proxy_wasm_cpp_host//:lib",
],
)
declare_wasm_image_targets(
name = "model_mapper",
wasm_file = ":model_mapper.wasm",
)


@@ -0,0 +1,63 @@
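## Function Description
The `model-mapper` plugin rewrites the `model` parameter in LLM-protocol request bodies according to a configurable mapping table.
## Configuration Fields
| Name | Data Type | Filling Requirement | Default Value | Description |
| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
| `modelKey` | string | Optional | model | The location of the model parameter in the request body |
| `modelMapping` | map of string | Optional | - | AI model mapping table, used to map the model names in the request to the model names supported by the service provider.<br/>1. Supports prefix matching. For example, use "gpt-3-*" to match all models whose names start with "gpt-3-";<br/>2. Supports using "*" as the key to configure a generic fallback mapping;<br/>3. If the target name in the mapping is the empty string "", the original model name is kept. |
| `enableOnPathSuffix` | array of string | Optional | ["/v1/chat/completions"] | Only takes effect for requests with these specific path suffixes |
## Runtime Properties
Plugin execution phase: authentication phase
Plugin execution priority: 800
## Effect Description
With the following configuration: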
```yaml
modelMapping:
'gpt-4-*': "qwen-max"
'gpt-4o': "qwen-vl-plus"
'*': "qwen-turbo"
```
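Once enabled, model names starting with `gpt-4-` are rewritten to `qwen-max`, `gpt-4o` is rewritten to `qwen-vl-plus`, and all other models are rewritten to `qwen-turbo`.
For example, if the original request is: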
```json
{
"model": "gpt-4o",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "higress项目主仓库的github地址是什么"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```
After passing through this plugin, the original LLM request body is changed to:
```json
{
"model": "qwen-vl-plus",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "higress项目主仓库的github地址是什么"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```


@@ -0,0 +1,65 @@
## Function Description
The `model-mapper` plugin rewrites the `model` parameter in LLM-protocol request bodies according to a configurable mapping table.
## Configuration Fields
| Name | Data Type | Filling Requirement | Default Value | Description |
| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
| `modelKey` | string | Optional | model | The location of the model parameter in the request body. |
| `modelMapping` | map of string | Optional | - | AI model mapping table, used to map the model names in the request to the model names supported by the service provider.<br/>1. Supports prefix matching. For example, use "gpt-3-*" to match all models whose names start with “gpt-3-”;<br/>2. Supports using "*" as the key to configure a generic fallback mapping relationship;<br/>3. If the target name in the mapping is an empty string "", it means to keep the original model name. |
| `enableOnPathSuffix` | array of string | Optional | ["/v1/chat/completions"] | Only applies to requests with these specific path suffixes. |
## Runtime Properties
Plugin execution phase: Authentication phase
Plugin execution priority: 800
## Effect Description
With the following configuration:
```yaml
modelMapping:
'gpt-4-*': "qwen-max"
'gpt-4o': "qwen-vl-plus"
'*': "qwen-turbo"
```
Once enabled, model names starting with `gpt-4-` are rewritten to `qwen-max`, `gpt-4o` is rewritten to `qwen-vl-plus`, and all other models are rewritten to `qwen-turbo`.
For example, if the original request was:
```json
{
"model": "gpt-4o",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "What is the GitHub address of the main repository for the higress project?"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```
After processing by this plugin, the original LLM request body will be modified to:
```json
{
"model": "qwen-vl-plus",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "What is the GitHub address of the main repository for the higress project?"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```
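
The lookup order described above (exact match first, then prefix match, then the `*` fallback, with an empty target meaning "keep the original name") can be sketched as follows. This is an illustration under stated assumptions, not the plugin's code: `ModelRules`, `resolveModel` and `resolveModelExample` are hypothetical names, and only the C++ standard library is used.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the parsed plugin configuration.
struct ModelRules {
  std::map<std::string, std::string> exact;                  // "gpt-4o" -> target
  std::vector<std::pair<std::string, std::string>> prefix;   // "gpt-4-" -> target
  std::string fallback;                                      // target of the "*" key
};

inline std::string resolveModel(const ModelRules& rules,
                                const std::string& requested) {
  // Start from the "*" fallback; an empty fallback keeps the requested name.
  std::string target = rules.fallback.empty() ? requested : rules.fallback;
  if (auto it = rules.exact.find(requested); it != rules.exact.end()) {
    target = it->second;  // an exact match has the highest priority
  } else {
    for (const auto& [prefix, mapped] : rules.prefix) {
      if (requested.compare(0, prefix.size(), prefix) == 0) {
        target = mapped;  // the first matching prefix wins
        break;
      }
    }
  }
  // An empty target means "keep the original model name".
  return target.empty() ? requested : target;
}

// Usage example with the configuration from this README.
inline void resolveModelExample() {
  ModelRules rules;
  rules.exact["gpt-4o"] = "qwen-vl-plus";
  rules.prefix.emplace_back("gpt-4-", "qwen-max");
  rules.fallback = "qwen-turbo";
  assert(resolveModel(rules, "gpt-4o") == "qwen-vl-plus");
  assert(resolveModel(rules, "gpt-4-turbo") == "qwen-max");
  assert(resolveModel(rules, "gpt-3.5") == "qwen-turbo");
}
```

Because the exact table is consulted before the prefixes, `gpt-4o` maps to `qwen-vl-plus` even if a broader prefix rule would also match it.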


@@ -0,0 +1,243 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "extensions/model_mapper/plugin.h"
#include <array>
#include <limits>
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
#include "common/http_util.h"
#include "common/json_util.h"
using ::nlohmann::json;
using ::Wasm::Common::JsonArrayIterate;
using ::Wasm::Common::JsonGetField;
using ::Wasm::Common::JsonObjectIterate;
using ::Wasm::Common::JsonValueAs;
#ifdef NULL_PLUGIN
namespace proxy_wasm {
namespace null_plugin {
namespace model_mapper {
PROXY_WASM_NULL_PLUGIN_REGISTRY
#endif
static RegisterContextFactory register_ModelMapper(
CONTEXT_FACTORY(PluginContext), ROOT_FACTORY(PluginRootContext));
namespace {
constexpr std::string_view SetDecoderBufferLimitKey =
"SetRequestBodyBufferLimit";
constexpr std::string_view DefaultMaxBodyBytes = "10485760";
} // namespace
bool PluginRootContext::parsePluginConfig(const json& configuration,
ModelMapperConfigRule& rule) {
if (auto it = configuration.find("modelKey"); it != configuration.end()) {
if (it->is_string()) {
rule.model_key_ = it->get<std::string>();
} else {
LOG_ERROR("Invalid type for modelKey. Expected string.");
return false;
}
}
if (auto it = configuration.find("modelMapping"); it != configuration.end()) {
if (!it->is_object()) {
LOG_ERROR("Invalid type for modelMapping. Expected object.");
return false;
}
auto model_mapping = it->get<Wasm::Common::JsonObject>();
if (!JsonObjectIterate(model_mapping, [&](std::string key) -> bool {
auto model_json = model_mapping.find(key);
if (!model_json->is_string()) {
LOG_ERROR(
"Invalid type for item in modelMapping. Expected string.");
return false;
}
if (key == "*") {
rule.default_model_mapping_ = model_json->get<std::string>();
return true;
}
if (absl::EndsWith(key, "*")) {
rule.prefix_model_mapping_.emplace_back(
absl::StripSuffix(key, "*"), model_json->get<std::string>());
return true;
}
auto ret = rule.exact_model_mapping_.emplace(
key, model_json->get<std::string>());
if (!ret.second) {
LOG_ERROR("Duplicate key in modelMapping: " + key);
return false;
}
return true;
})) {
return false;
}
}
if (!JsonArrayIterate(
configuration, "enableOnPathSuffix", [&](const json& item) -> bool {
if (item.is_string()) {
rule.enable_on_path_suffix_.emplace_back(item.get<std::string>());
return true;
}
return false;
})) {
LOG_WARN("Invalid type for item in enableOnPathSuffix. Expected string.");
return false;
}
return true;
}
bool PluginRootContext::onConfigure(size_t size) {
// Parse configuration JSON string.
if (size > 0 && !configure(size)) {
LOG_WARN("configuration has errors, initialization will not continue.");
return false;
}
return true;
}
bool PluginRootContext::configure(size_t configuration_size) {
auto configuration_data = getBufferBytes(WasmBufferType::PluginConfiguration,
0, configuration_size);
// Parse configuration JSON string.
auto result = ::Wasm::Common::JsonParse(configuration_data->view());
if (!result) {
LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ",
configuration_data->view()));
return false;
}
if (!parseRuleConfig(result.value())) {
LOG_WARN(absl::StrCat("cannot parse plugin rule configuration: ",
configuration_data->view()));
return false;
}
return true;
}
FilterHeadersStatus PluginRootContext::onHeader(
const ModelMapperConfigRule& rule) {
if (!Wasm::Common::Http::hasRequestBody()) {
return FilterHeadersStatus::Continue;
}
auto path = getRequestHeader(Wasm::Common::Http::Header::Path)->toString();
auto params_pos = path.find('?');
size_t uri_end;
if (params_pos == std::string::npos) {
uri_end = path.size();
} else {
uri_end = params_pos;
}
bool enable = false;
for (const auto& enable_suffix : rule.enable_on_path_suffix_) {
if (absl::EndsWith({path.c_str(), uri_end}, enable_suffix)) {
enable = true;
break;
}
}
if (!enable) {
return FilterHeadersStatus::Continue;
}
auto content_type_value =
getRequestHeader(Wasm::Common::Http::Header::ContentType);
if (!absl::StrContains(content_type_value->view(),
Wasm::Common::Http::ContentTypeValues::Json)) {
return FilterHeadersStatus::Continue;
}
removeRequestHeader(Wasm::Common::Http::Header::ContentLength);
setFilterState(SetDecoderBufferLimitKey, DefaultMaxBodyBytes);
return FilterHeadersStatus::StopIteration;
}
FilterDataStatus PluginRootContext::onBody(const ModelMapperConfigRule& rule,
std::string_view body) {
const auto& exact_model_mapping = rule.exact_model_mapping_;
const auto& prefix_model_mapping = rule.prefix_model_mapping_;
const auto& default_model_mapping = rule.default_model_mapping_;
const auto& model_key = rule.model_key_;
auto body_json_opt = ::Wasm::Common::JsonParse(body);
if (!body_json_opt) {
LOG_WARN(absl::StrCat("cannot parse body to JSON string: ", body));
return FilterDataStatus::Continue;
}
auto body_json = body_json_opt.value();
std::string old_model;
if (body_json.contains(model_key)) {
old_model = body_json[model_key];
}
std::string model =
default_model_mapping.empty() ? old_model : default_model_mapping;
if (auto it = exact_model_mapping.find(old_model);
it != exact_model_mapping.end()) {
model = it->second;
} else {
for (auto& prefix_model_pair : prefix_model_mapping) {
if (absl::StartsWith(old_model, prefix_model_pair.first)) {
model = prefix_model_pair.second;
break;
}
}
}
if (!model.empty() && model != old_model) {
body_json[model_key] = model;
setBuffer(WasmBufferType::HttpRequestBody, 0,
std::numeric_limits<size_t>::max(), body_json.dump());
LOG_DEBUG(
absl::StrCat("model mapped, before:", old_model, ", after:", model));
}
return FilterDataStatus::Continue;
}
FilterHeadersStatus PluginContext::onRequestHeaders(uint32_t, bool) {
auto* rootCtx = rootContext();
return rootCtx->onHeaders([rootCtx, this](const auto& config) {
auto ret = rootCtx->onHeader(config);
if (ret == FilterHeadersStatus::StopIteration) {
this->config_ = &config;
}
return ret;
});
}
FilterDataStatus PluginContext::onRequestBody(size_t body_size,
bool end_stream) {
if (config_ == nullptr) {
return FilterDataStatus::Continue;
}
body_total_size_ += body_size;
if (!end_stream) {
return FilterDataStatus::StopIterationAndBuffer;
}
auto body =
getBufferBytes(WasmBufferType::HttpRequestBody, 0, body_total_size_);
auto* rootCtx = rootContext();
return rootCtx->onBody(*config_, body->view());
}
#ifdef NULL_PLUGIN
} // namespace model_mapper
} // namespace null_plugin
} // namespace proxy_wasm
#endif
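
The path check in `onHeader` compares only the part of `:path` before the query string against the configured suffixes. A minimal stand-alone sketch of that check is shown below; `pathEnabled` and `pathEnabledExample` are hypothetical names, and the standard library is used in place of `absl::EndsWith`.

```cpp
#include <cassert>
#include <string>
#include <string_view>
#include <vector>

// Returns true when the path, with any query string removed, ends with one of
// the configured suffixes.
inline bool pathEnabled(std::string_view path,
                        const std::vector<std::string>& suffixes) {
  // Only the part before the first '?' takes part in the comparison.
  const std::string_view bare = path.substr(0, path.find('?'));
  for (const auto& suffix : suffixes) {
    const std::string_view sv(suffix);
    if (bare.size() >= sv.size() &&
        bare.compare(bare.size() - sv.size(), sv.size(), sv) == 0) {
      return true;
    }
  }
  return false;
}

// Usage example with the default suffix list.
inline void pathEnabledExample() {
  const std::vector<std::string> suffixes{"/v1/chat/completions"};
  assert(pathEnabled("/v1/chat/completions", suffixes));
  assert(pathEnabled("/api/v1/chat/completions?stream=true", suffixes));
  assert(!pathEnabled("/v1/embeddings", suffixes));
}
```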


@@ -0,0 +1,87 @@
/*
* Copyright (c) 2022 Alibaba Group Holding Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <assert.h>
#include <map>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "common/route_rule_matcher.h"
#define ASSERT(_X) assert(_X)
#ifndef NULL_PLUGIN
#include "proxy_wasm_intrinsics.h"
#else
#include "include/proxy-wasm/null_plugin.h"
namespace proxy_wasm {
namespace null_plugin {
namespace model_mapper {
#endif
struct ModelMapperConfigRule {
std::string model_key_ = "model";
std::map<std::string, std::string> exact_model_mapping_;
std::vector<std::pair<std::string, std::string>> prefix_model_mapping_;
std::string default_model_mapping_;
std::vector<std::string> enable_on_path_suffix_ = {"/v1/chat/completions"};
};
// PluginRootContext is the root context for all streams processed by the
// thread. It has the same lifetime as the worker thread and acts as the target
// for interactions that outlive an individual stream, e.g. timers and async
// calls.
class PluginRootContext : public RootContext,
public RouteRuleMatcher<ModelMapperConfigRule> {
public:
PluginRootContext(uint32_t id, std::string_view root_id)
: RootContext(id, root_id) {}
~PluginRootContext() {}
bool onConfigure(size_t) override;
FilterHeadersStatus onHeader(const ModelMapperConfigRule&);
FilterDataStatus onBody(const ModelMapperConfigRule&, std::string_view);
bool configure(size_t);
private:
bool parsePluginConfig(const json&, ModelMapperConfigRule&) override;
};
// Per-stream context.
class PluginContext : public Context {
public:
explicit PluginContext(uint32_t id, RootContext* root) : Context(id, root) {}
FilterHeadersStatus onRequestHeaders(uint32_t, bool) override;
FilterDataStatus onRequestBody(size_t, bool) override;
private:
inline PluginRootContext* rootContext() {
return dynamic_cast<PluginRootContext*>(this->root());
}
size_t body_total_size_ = 0;
const ModelMapperConfigRule* config_ = nullptr;
};
#ifdef NULL_PLUGIN
} // namespace model_mapper
} // namespace null_plugin
} // namespace proxy_wasm
#endif
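
`ModelMapperConfigRule` keeps three separate containers because each `modelMapping` key falls into one of three classes: the literal `*`, a key ending in `*`, or anything else. The sketch below shows that classification on a plain `std::map` instead of `nlohmann::json`; `ParsedMapping`, `classifyMapping` and `classifyMappingExample` are hypothetical names, and unlike the plugin this sketch does not validate value types or report duplicate keys.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical mirror of the three mapping members of ModelMapperConfigRule.
struct ParsedMapping {
  std::map<std::string, std::string> exact;
  std::vector<std::pair<std::string, std::string>> prefix;
  std::string fallback;
};

inline ParsedMapping classifyMapping(
    const std::map<std::string, std::string>& raw) {
  ParsedMapping out;
  for (const auto& [key, target] : raw) {
    if (key == "*") {
      out.fallback = target;  // generic fallback
    } else if (!key.empty() && key.back() == '*') {
      // Prefix rule: store the key without its trailing '*'.
      out.prefix.emplace_back(key.substr(0, key.size() - 1), target);
    } else {
      out.exact.emplace(key, target);  // exact model name
    }
  }
  return out;
}

// Usage example.
inline void classifyMappingExample() {
  const ParsedMapping parsed = classifyMapping(
      {{"*", "qwen-long"}, {"gpt-4*", "qwen-max"}, {"gpt-4o", "qwen-turbo"}});
  assert(parsed.fallback == "qwen-long");
  assert(parsed.prefix.size() == 1 && parsed.prefix[0].first == "gpt-4");
  assert(parsed.exact.at("gpt-4o") == "qwen-turbo");
}
```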


@@ -0,0 +1,301 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "extensions/model_mapper/plugin.h"
#include <cstddef>
#include <iostream>
#include <memory>
#include <string>
#include "absl/strings/match.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "include/proxy-wasm/context.h"
#include "include/proxy-wasm/null.h"
namespace proxy_wasm {
namespace null_plugin {
namespace model_mapper {
NullPluginRegistry* context_registry_;
RegisterNullVmPluginFactory register_model_mapper_plugin("model_mapper", []() {
return std::make_unique<NullPlugin>(model_mapper::context_registry_);
});
class MockContext : public proxy_wasm::ContextBase {
public:
MockContext(WasmBase* wasm) : ContextBase(wasm) {}
MOCK_METHOD(BufferInterface*, getBuffer, (WasmBufferType));
MOCK_METHOD(WasmResult, log, (uint32_t, std::string_view));
MOCK_METHOD(WasmResult, setBuffer,
(WasmBufferType, size_t, size_t, std::string_view));
MOCK_METHOD(WasmResult, getHeaderMapValue,
(WasmHeaderMapType /* type */, std::string_view /* key */,
std::string_view* /*result */));
MOCK_METHOD(WasmResult, replaceHeaderMapValue,
(WasmHeaderMapType /* type */, std::string_view /* key */,
std::string_view /* value */));
MOCK_METHOD(WasmResult, removeHeaderMapValue,
(WasmHeaderMapType /* type */, std::string_view /* key */));
MOCK_METHOD(WasmResult, addHeaderMapValue,
(WasmHeaderMapType, std::string_view, std::string_view));
MOCK_METHOD(WasmResult, getProperty, (std::string_view, std::string*));
MOCK_METHOD(WasmResult, setProperty, (std::string_view, std::string_view));
};
class ModelMapperTest : public ::testing::Test {
protected:
ModelMapperTest() {
// Initialize test VM
test_vm_ = createNullVm();
wasm_base_ = std::make_unique<WasmBase>(
std::move(test_vm_), "test-vm", "", "",
std::unordered_map<std::string, std::string>{},
AllowedCapabilitiesMap{});
wasm_base_->load("model_mapper");
wasm_base_->initialize();
// Initialize host side context
mock_context_ = std::make_unique<MockContext>(wasm_base_.get());
current_context_ = mock_context_.get();
// Initialize Wasm sandbox context
root_context_ = std::make_unique<PluginRootContext>(0, "");
context_ = std::make_unique<PluginContext>(1, root_context_.get());
ON_CALL(*mock_context_, log(testing::_, testing::_))
.WillByDefault([](uint32_t, std::string_view m) {
std::cerr << m << "\n";
return WasmResult::Ok;
});
ON_CALL(*mock_context_, getBuffer(testing::_))
.WillByDefault([&](WasmBufferType type) {
if (type == WasmBufferType::HttpRequestBody) {
return &body_;
}
return &config_;
});
ON_CALL(*mock_context_, getHeaderMapValue(WasmHeaderMapType::RequestHeaders,
testing::_, testing::_))
.WillByDefault([&](WasmHeaderMapType, std::string_view header,
std::string_view* result) {
if (header == "content-type") {
*result = "application/json";
} else if (header == "content-length") {
*result = "1024";
} else if (header == ":path") {
*result = path_;
}
return WasmResult::Ok;
});
ON_CALL(*mock_context_,
replaceHeaderMapValue(WasmHeaderMapType::RequestHeaders, testing::_,
testing::_))
.WillByDefault([&](WasmHeaderMapType, std::string_view key,
std::string_view value) { return WasmResult::Ok; });
ON_CALL(*mock_context_,
removeHeaderMapValue(WasmHeaderMapType::RequestHeaders, testing::_))
.WillByDefault([&](WasmHeaderMapType, std::string_view key) {
return WasmResult::Ok;
});
ON_CALL(*mock_context_, addHeaderMapValue(WasmHeaderMapType::RequestHeaders,
testing::_, testing::_))
.WillByDefault([&](WasmHeaderMapType, std::string_view header,
std::string_view value) { return WasmResult::Ok; });
ON_CALL(*mock_context_, getProperty(testing::_, testing::_))
.WillByDefault([&](std::string_view path, std::string* result) {
if (absl::StartsWith(path, "route_name")) {
*result = route_name_;
} else if (absl::StartsWith(path, "cluster_name")) {
*result = service_name_;
}
return WasmResult::Ok;
});
ON_CALL(*mock_context_, setProperty(testing::_, testing::_))
.WillByDefault(
[&](std::string_view, std::string_view) { return WasmResult::Ok; });
}
~ModelMapperTest() override {}
std::unique_ptr<WasmBase> wasm_base_;
std::unique_ptr<WasmVm> test_vm_;
std::unique_ptr<MockContext> mock_context_;
std::unique_ptr<PluginRootContext> root_context_;
std::unique_ptr<PluginContext> context_;
std::string route_name_;
std::string service_name_;
std::string path_;
BufferBase body_;
BufferBase config_;
};
TEST_F(ModelMapperTest, ModelMappingTest) {
std::string configuration = R"(
{
"modelMapping": {
"*": "qwen-long",
"gpt-4*": "qwen-max",
"gpt-4o": "qwen-turbo",
"gpt-4o-mini": "qwen-plus",
"text-embedding-v1": ""
}
})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
path_ = "/v1/chat/completions";
std::string request_json = R"({"model": "gpt-3.5"})";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
EXPECT_EQ(body, R"({"model":"qwen-long"})");
return WasmResult::Ok;
});
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
request_json = R"({"model": "gpt-4"})";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
EXPECT_EQ(body, R"({"model":"qwen-max"})");
return WasmResult::Ok;
});
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
request_json = R"({"model": "gpt-4o"})";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
EXPECT_EQ(body, R"({"model":"qwen-turbo"})");
return WasmResult::Ok;
});
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
request_json = R"({"model": "gpt-4o-mini"})";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
EXPECT_EQ(body, R"({"model":"qwen-plus"})");
return WasmResult::Ok;
});
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
request_json = R"({"model": "text-embedding-v1"})";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.Times(0);
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
request_json = R"({})";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
EXPECT_EQ(body, R"({"model":"qwen-long"})");
return WasmResult::Ok;
});
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
}
TEST_F(ModelMapperTest, RouteLevelModelMappingTest) {
std::string configuration = R"(
{
"_rules_": [
{
"_match_route_": ["route-a"],
"_match_service_": ["service-1"],
"modelMapping": {
"*": "qwen-long"
}
},
{
"_match_route_": ["route-b"],
"_match_service_": ["service-2"],
"modelMapping": {
"*": "qwen-max"
}
},
{
"_match_route_": ["route-b"],
"_match_service_": ["service-3"],
"modelMapping": {
"*": "qwen-turbo"
}
}
]})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
path_ = "/api/v1/chat/completions";
std::string request_json = R"({"model": "gpt-4"})";
body_.set(request_json);
route_name_ = "route-a";
service_name_ = "outbound|80||service-1";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
EXPECT_EQ(body, R"({"model":"qwen-long"})");
return WasmResult::Ok;
});
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
route_name_ = "route-b";
service_name_ = "outbound|80||service-2";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
EXPECT_EQ(body, R"({"model":"qwen-max"})");
return WasmResult::Ok;
});
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
route_name_ = "route-b";
service_name_ = "outbound|80||service-3";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.WillOnce([&](WasmBufferType, size_t, size_t, std::string_view body) {
EXPECT_EQ(body, R"({"model":"qwen-turbo"})");
return WasmResult::Ok;
});
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(20, true), FilterDataStatus::Continue);
}
} // namespace model_mapper
} // namespace null_plugin
} // namespace proxy_wasm


@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("@proxy_wasm_cpp_sdk//bazel/wasm:wasm.bzl", "wasm_cc_binary")
load("@proxy_wasm_cpp_sdk//bazel:defs.bzl", "proxy_wasm_cc_binary")
load("//bazel:wasm.bzl", "declare_wasm_image_targets")
wasm_cc_binary(
proxy_wasm_cc_binary(
name = "model_router.wasm",
srcs = [
"plugin.cc",


@@ -1,33 +1,35 @@
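## Function Description
The `model-router` plugin implements routing based on the model parameter in the LLM protocol.
## Runtime Properties
Plugin execution phase: `default phase`
Plugin execution priority: `260`
## Configuration Fields
| Name | Data Type | Filling Requirement | Default Value | Description |
| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
| `enable` | bool | Optional | false | Whether to enable routing based on the model parameter |
| `model_key` | string | Optional | model | The location of the model parameter in the request body |
| `add_header_key` | string | Optional | x-higress-llm-provider | Request header that receives the provider name parsed from the model parameter |
| Name | Data Type | Filling Requirement | Default Value | Description |
| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
| `modelKey` | string | Optional | model | The location of the model parameter in the request body |
| `addProviderHeader` | string | Optional | - | Request header that receives the provider name parsed from the model parameter |
| `modelToHeader` | string | Optional | - | Request header that receives the model parameter value as-is |
| `enableOnPathSuffix` | array of string | Optional | ["/v1/chat/completions"] | Only takes effect for requests with these specific path suffixes |
## Runtime Properties
Plugin execution phase: authentication phase
Plugin execution priority: 900
## Effect Description
Enable routing based on the model parameter as follows:
### Routing based on the model parameter
The following configuration is required: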
```yaml
enable: true
modelToHeader: x-higress-llm-model
```
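Once enabled, the plugin extracts the provider part (if any) of the model parameter in the request, sets it in the x-higress-llm-provider request header for subsequent routing, and rewrites the model parameter to the model-name part. For example, the original LLM request body is:
The plugin extracts the model parameter from the request and sets it in the x-higress-llm-model request header for subsequent routing. For example, the original LLM request body is: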
```json
{
"model": "qwen/qwen-long",
"model": "qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
@@ -43,7 +45,39 @@ enable: true
After passing through this plugin, the following request header is added (it can be used for route matching):
x-higress-llm-provider: qwen
x-higress-llm-model: qwen-long
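### Extracting the provider field from the model parameter for routing
> Note: this mode requires the client to specify the provider in the model parameter, separated by `/`.
The following configuration is required: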
```yaml
addProviderHeader: x-higress-llm-provider
```
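The plugin extracts the provider part (if any) from the model parameter in the request and sets it in the x-higress-llm-provider request header for subsequent routing, and rewrites the model parameter to the model name part. For example, the original LLM request body is: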
```json
{
"model": "dashscope/qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "higress项目主仓库的github地址是什么"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```
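After processing by this plugin, the following request header (which can be used for route matching) is added:
x-higress-llm-provider: dashscope
The original LLM request body is changed to: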

View File

@@ -1,38 +1,41 @@
## Function Description
The `model-router` plugin implements the functionality of routing based on the `model` parameter in the LLM protocol.
## Runtime Properties
Plugin Execution Phase: `Default Phase`
Plugin Execution Priority: `260`
The `model-router` plugin implements the function of routing based on the model parameter in the LLM protocol.
## Configuration Fields
| Name | Data Type | Filling Requirement | Default Value | Description |
| -------------------- | ------------- | --------------------- | ---------------------- | ----------------------------------------------------- |
| `enable` | bool | Optional | false | Whether to enable routing based on the `model` parameter |
| `model_key` | string | Optional | model | The location of the `model` parameter in the request body |
| `add_header_key` | string | Optional | x-higress-llm-provider | The header where the parsed provider name from the `model` parameter will be placed |
| Name | Data Type | Filling Requirement | Default Value | Description |
| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
| `modelKey` | string | Optional | model | The location of the model parameter in the request body |
| `addProviderHeader` | string | Optional | - | The request header in which to place the provider name parsed from the model parameter |
| `modelToHeader` | string | Optional | - | The request header in which to place the model parameter value directly |
| `enableOnPathSuffix` | array of string | Optional | ["/v1/chat/completions"] | Only effective for requests with these specific path suffixes |
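As a rough illustration of the `enableOnPathSuffix` check, the sketch below drops the query string and compares the remaining path against the configured suffixes. It is written in Go for brevity and is not the plugin's implementation (the plugin itself is C++); the function name `matchesPathSuffix` is hypothetical.

```go
package main

import (
	"fmt"
	"strings"
)

// matchesPathSuffix reports whether the request path, with any query string
// removed, ends with one of the configured suffixes.
// Hypothetical helper for illustration only.
func matchesPathSuffix(path string, suffixes []string) bool {
	// Ignore everything from the first '?' onward.
	if i := strings.IndexByte(path, '?'); i >= 0 {
		path = path[:i]
	}
	for _, suffix := range suffixes {
		if strings.HasSuffix(path, suffix) {
			return true
		}
	}
	return false
}

func main() {
	suffixes := []string{"/v1/chat/completions"} // the documented default
	fmt.Println(matchesPathSuffix("/api/v1/chat/completions?debug=1", suffixes))
	fmt.Println(matchesPathSuffix("/v1/chat/xxxx", suffixes))
}
```

Because the comparison is a suffix match, a prefixed path such as `/api/v1/chat/completions` still qualifies under the default setting.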
## Runtime Attributes
Plugin execution phase: Authentication phase
Plugin execution priority: 900
## Effect Description
To enable routing based on the `model` parameter, use the following configuration:
### Routing Based on the model Parameter
The following configuration is required:
```yaml
enable: true
modelToHeader: x-higress-llm-model
```
After enabling, the plugin extracts the provider part (if any) from the `model` parameter in the request, and sets it in the `x-higress-llm-provider` request header for subsequent routing. It also rewrites the `model` parameter to the model name part. For example, the original LLM request body is:
The plugin will extract the model parameter from the request and set it in the x-higress-llm-model request header, which can be used for subsequent routing. For example, the original LLM request body:
```json
{
"model": "openai/gpt-4o",
"model": "qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "What is the GitHub address for the main repository of the Higress project?"
"content": "What is the GitHub address of the main repository for the higress project"
}],
"presence_penalty": 0,
"temperature": 0.7,
@@ -40,24 +43,55 @@ After enabling, the plugin extracts the provider part (if any) from the `model`
}
```
After processing by the plugin, the following request header (which can be used for routing matching) will be added:
After processing by this plugin, the following request header (which can be used for route matching) will be added:
`x-higress-llm-provider: openai`
x-higress-llm-model: qwen-long
The original LLM request body will be modified to:
### Extracting the provider Field from the model Parameter for Routing
> Note that this mode requires the client to specify the provider using a `/` separator in the model parameter.
The following configuration is required:
```yaml
addProviderHeader: x-higress-llm-provider
```
The plugin will extract the provider part (if present) from the model parameter in the request and set it in the x-higress-llm-provider request header, which can be used for subsequent routing, and rewrite the model parameter to the model name part. For example, the original LLM request body:
```json
{
"model": "gpt-4o",
"model": "dashscope/qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "What is the GitHub address for the main repository of the Higress project?"
"content": "What is the GitHub address of the main repository for the higress project"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}
```
After processing by this plugin, the following request header (which can be used for route matching) will be added:
x-higress-llm-provider: dashscope
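The split between provider and model name can be sketched as follows, assuming the provider is everything before the first `/` in the model value. This is an illustration in Go, not the plugin's C++ code, and `splitProviderModel` is a hypothetical name.

```go
package main

import (
	"fmt"
	"strings"
)

// splitProviderModel splits "provider/model" at the first '/'.
// When there is no '/', it returns the input as the model and ok=false.
// Hypothetical helper for illustration only.
func splitProviderModel(value string) (provider, model string, ok bool) {
	provider, model, ok = strings.Cut(value, "/")
	if !ok {
		return "", value, false
	}
	return provider, model, true
}

func main() {
	provider, model, ok := splitProviderModel("dashscope/qwen-long")
	fmt.Println(provider, model, ok)

	provider, model, ok = splitProviderModel("qwen-long")
	fmt.Println(provider, model, ok)
}
```

A value without a `/` is left untouched, which corresponds to the case where no provider header is added.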
The original LLM request body will be changed to:
```json
{
"model": "qwen-long",
"frequency_penalty": 0,
"max_tokens": 800,
"stream": false,
"messages": [{
"role": "user",
"content": "What is the GitHub address of the main repository for the higress project"
}],
"presence_penalty": 0,
"temperature": 0.7,
"top_p": 0.95
}

View File

@@ -51,41 +51,54 @@ constexpr std::string_view DefaultMaxBodyBytes = "10485760";
bool PluginRootContext::parsePluginConfig(const json& configuration,
ModelRouterConfigRule& rule) {
if (auto it = configuration.find("enable"); it != configuration.end()) {
if (it->is_boolean()) {
rule.enable_ = it->get<bool>();
} else {
LOG_WARN("Invalid type for enable. Expected boolean.");
return false;
}
}
if (auto it = configuration.find("model_key"); it != configuration.end()) {
if (auto it = configuration.find("modelKey"); it != configuration.end()) {
if (it->is_string()) {
rule.model_key_ = it->get<std::string>();
} else {
LOG_WARN("Invalid type for model_key. Expected string.");
LOG_ERROR("Invalid type for modelKey. Expected string.");
return false;
}
}
if (auto it = configuration.find("add_header_key");
if (auto it = configuration.find("addProviderHeader");
it != configuration.end()) {
if (it->is_string()) {
rule.add_header_key_ = it->get<std::string>();
rule.add_provider_header_ = it->get<std::string>();
} else {
LOG_WARN("Invalid type for add_header_key. Expected string.");
LOG_ERROR("Invalid type for addProviderHeader. Expected string.");
return false;
}
}
if (auto it = configuration.find("modelToHeader");
it != configuration.end()) {
if (it->is_string()) {
rule.model_to_header_ = it->get<std::string>();
} else {
LOG_ERROR("Invalid type for modelToHeader. Expected string.");
return false;
}
}
if (!JsonArrayIterate(
configuration, "enableOnPathSuffix", [&](const json& item) -> bool {
if (item.is_string()) {
rule.enable_on_path_suffix_.emplace_back(item.get<std::string>());
return true;
}
return false;
})) {
LOG_ERROR("Invalid type for item in enableOnPathSuffix. Expected string.");
return false;
}
return true;
}
bool PluginRootContext::onConfigure(size_t size) {
// Parse configuration JSON string.
if (size > 0 && !configure(size)) {
LOG_WARN("configuration has errors initialization will not continue.");
LOG_ERROR("configuration has errors initialization will not continue.");
return false;
}
return true;
@@ -97,13 +110,13 @@ bool PluginRootContext::configure(size_t configuration_size) {
// Parse configuration JSON string.
auto result = ::Wasm::Common::JsonParse(configuration_data->view());
if (!result) {
LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ",
configuration_data->view()));
LOG_ERROR(absl::StrCat("cannot parse plugin configuration JSON string: ",
configuration_data->view()));
return false;
}
if (!parseRuleConfig(result.value())) {
LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ",
configuration_data->view()));
LOG_ERROR(absl::StrCat("cannot parse plugin configuration JSON string: ",
configuration_data->view()));
return false;
}
return true;
@@ -111,7 +124,25 @@ bool PluginRootContext::configure(size_t configuration_size) {
FilterHeadersStatus PluginRootContext::onHeader(
const ModelRouterConfigRule& rule) {
if (!rule.enable_ || !Wasm::Common::Http::hasRequestBody()) {
if (!Wasm::Common::Http::hasRequestBody()) {
return FilterHeadersStatus::Continue;
}
auto path = getRequestHeader(Wasm::Common::Http::Header::Path)->toString();
auto params_pos = path.find('?');
size_t uri_end;
if (params_pos == std::string::npos) {
uri_end = path.size();
} else {
uri_end = params_pos;
}
bool enable = false;
for (const auto& enable_suffix : rule.enable_on_path_suffix_) {
if (absl::EndsWith({path.c_str(), uri_end}, enable_suffix)) {
enable = true;
break;
}
}
if (!enable) {
return FilterHeadersStatus::Continue;
}
auto content_type_value =
@@ -128,7 +159,8 @@ FilterHeadersStatus PluginRootContext::onHeader(
FilterDataStatus PluginRootContext::onBody(const ModelRouterConfigRule& rule,
std::string_view body) {
const auto& model_key = rule.model_key_;
const auto& add_header_key = rule.add_header_key_;
const auto& add_provider_header = rule.add_provider_header_;
const auto& model_to_header = rule.model_to_header_;
auto body_json_opt = ::Wasm::Common::JsonParse(body);
if (!body_json_opt) {
LOG_WARN(absl::StrCat("cannot parse body to JSON string: ", body));
@@ -137,18 +169,24 @@ FilterDataStatus PluginRootContext::onBody(const ModelRouterConfigRule& rule,
auto body_json = body_json_opt.value();
if (body_json.contains(model_key)) {
std::string model_value = body_json[model_key];
auto pos = model_value.find('/');
if (pos != std::string::npos) {
const auto& provider = model_value.substr(0, pos);
const auto& model = model_value.substr(pos + 1);
replaceRequestHeader(add_header_key, provider);
body_json[model_key] = model;
setBuffer(WasmBufferType::HttpRequestBody, 0,
std::numeric_limits<size_t>::max(), body_json.dump());
LOG_DEBUG(absl::StrCat("model route to provider:", provider,
", model:", model));
} else {
LOG_DEBUG(absl::StrCat("model route not work, model:", model_value));
if (!model_to_header.empty()) {
replaceRequestHeader(model_to_header, model_value);
}
if (!add_provider_header.empty()) {
auto pos = model_value.find('/');
if (pos != std::string::npos) {
const auto& provider = model_value.substr(0, pos);
const auto& model = model_value.substr(pos + 1);
replaceRequestHeader(add_provider_header, provider);
body_json[model_key] = model;
setBuffer(WasmBufferType::HttpRequestBody, 0,
std::numeric_limits<size_t>::max(), body_json.dump());
LOG_DEBUG(absl::StrCat("model route to provider:", provider,
", model:", model));
} else {
LOG_DEBUG(absl::StrCat("model route to provider not work, model:",
model_value));
}
}
}
return FilterDataStatus::Continue;

View File

@@ -37,9 +37,10 @@ namespace model_router {
#endif
struct ModelRouterConfigRule {
bool enable_ = false;
std::string model_key_ = "model";
std::string add_header_key_ = "x-higress-llm-provider";
std::string add_provider_header_;
std::string model_to_header_;
std::vector<std::string> enable_on_path_suffix_ = {"/v1/chat/completions"};
};
// PluginRootContext is the root context for all streams processed by the

View File

@@ -89,6 +89,8 @@ class ModelRouterTest : public ::testing::Test {
*result = "application/json";
} else if (header == "content-length") {
*result = "1024";
} else if (header == ":path") {
*result = path_;
}
return WasmResult::Ok;
});
@@ -122,6 +124,7 @@ class ModelRouterTest : public ::testing::Test {
std::unique_ptr<PluginRootContext> root_context_;
std::unique_ptr<PluginContext> context_;
std::string route_name_;
std::string path_;
BufferBase body_;
BufferBase config_;
};
@@ -129,12 +132,13 @@ class ModelRouterTest : public ::testing::Test {
TEST_F(ModelRouterTest, RewriteModelAndHeader) {
std::string configuration = R"(
{
"enable": true
"addProviderHeader": "x-higress-llm-provider"
})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
path_ = "/v1/chat/completions";
std::string request_json = R"({"model": "qwen/qwen-long"})";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
@@ -154,19 +158,73 @@ TEST_F(ModelRouterTest, RewriteModelAndHeader) {
EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue);
}
TEST_F(ModelRouterTest, ModelToHeader) {
std::string configuration = R"(
{
"modelToHeader": "x-higress-llm-model"
})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
path_ = "/v1/chat/completions";
std::string request_json = R"({"model": "qwen-long"})";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.Times(0);
EXPECT_CALL(
*mock_context_,
replaceHeaderMapValue(testing::_, std::string_view("x-higress-llm-model"),
std::string_view("qwen-long")));
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue);
}
TEST_F(ModelRouterTest, IgnorePath) {
std::string configuration = R"(
{
"addProviderHeader": "x-higress-llm-provider"
})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
path_ = "/v1/chat/xxxx";
std::string request_json = R"({"model": "qwen/qwen-long"})";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.Times(0);
EXPECT_CALL(*mock_context_,
replaceHeaderMapValue(testing::_,
std::string_view("x-higress-llm-provider"),
std::string_view("qwen")))
.Times(0);
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::Continue);
EXPECT_EQ(context_->onRequestBody(28, true), FilterDataStatus::Continue);
}
TEST_F(ModelRouterTest, RouteLevelRewriteModelAndHeader) {
std::string configuration = R"(
{
"_rules_": [
{
"_match_route_": ["route-a"],
"enable": true
"addProviderHeader": "x-higress-llm-provider"
}
]})";
config_.set(configuration);
EXPECT_TRUE(root_context_->configure(configuration.size()));
path_ = "/api/v1/chat/completions";
std::string request_json = R"({"model": "qwen/qwen-long"})";
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))

View File

@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("@proxy_wasm_cpp_sdk//bazel/wasm:wasm.bzl", "wasm_cc_binary")
load("@proxy_wasm_cpp_sdk//bazel:defs.bzl", "proxy_wasm_cc_binary")
load("//bazel:wasm.bzl", "declare_wasm_image_targets")
wasm_cc_binary(
proxy_wasm_cc_binary(
name = "request_block.wasm",
srcs = [
"plugin.cc",
@@ -27,7 +27,6 @@ wasm_cc_binary(
"//common:json_util",
"//common:http_util",
"//common:rule_util",
"@proxy_wasm_cpp_sdk//:proxy_wasm_intrinsics",
],
)

View File

@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("@proxy_wasm_cpp_sdk//bazel/wasm:wasm.bzl", "wasm_cc_binary")
load("@proxy_wasm_cpp_sdk//bazel:defs.bzl", "proxy_wasm_cc_binary")
load("//bazel:wasm.bzl", "declare_wasm_image_targets")
wasm_cc_binary(
proxy_wasm_cc_binary(
name = "sni_misdirect.wasm",
srcs = [
"plugin.cc",
@@ -25,7 +25,6 @@ wasm_cc_binary(
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings",
"//common:http_util",
"@proxy_wasm_cpp_sdk//:proxy_wasm_intrinsics",
],
)

View File

@@ -5,7 +5,7 @@ description: AI Agent插件配置参考
---
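## Function Description
A customizable API AI Agent. It supports configuring APIs whose HTTP method is GET or POST, multi-turn conversations, and both streaming and non-streaming modes.
A customizable API AI Agent. It supports configuring APIs whose HTTP method is GET or POST, multi-turn conversations, both streaming and non-streaming modes, and formatting results as custom JSON.
The agent flow chart is as follows: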
![ai-agent](https://img.alicdn.com/imgextra/i1/O1CN01PGSDW31WQfEPm173u_!!6000000002783-0-tps-2733-1473.jpg)
@@ -21,6 +21,7 @@ agent流程图如下
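| `llm` | object | Required | - | Configuration information for the AI service provider |
| `apis` | object | Required | - | Configuration information for the external API service providers |
| `promptTemplate` | object | Optional | - | Configuration information for the Agent ReAct template |
| `jsonResp` | object | Optional | - | Configuration information for JSON formatting |
The configuration fields for `llm` are as follows: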
@@ -78,7 +79,14 @@ agent流程图如下
| `observation` | string | Optional | - | The observation part of the Agent ReAct template |
| `thought2` | string | Optional | - | The thought2 part of the Agent ReAct template |
## Usage Example
The configuration fields for `jsonResp` are as follows:
| Name | Data Type | Requirement | Default Value | Description |
|--------------------|-----------|---------|--------|-----------------------------------|
| `enable` | bool | Optional | false | Whether to enable JSON formatting. |
| `jsonSchema` | object | Optional | - | Custom JSON schema |
## Usage Example: JSON formatting disabled
**Configuration Information**
@@ -293,7 +301,7 @@ deepl提供了一个工具用于翻译给定的句子支持多语言。。
**Request Example**
```shell
curl 'http://<replace with gateway public IP>/api/openai/v1/chat/completions' \
curl 'http://<replace with gateway address>/api/openai/v1/chat/completions' \
-H 'Accept: application/json, text/event-stream' \
-H 'Content-Type: application/json' \
--data-raw '{"model":"qwen","frequency_penalty":0,"max_tokens":800,"stream":false,"messages":[{"role":"user","content":"I want to have coffee near Xinsheng Building in Jinan. Recommend a few places."}],"presence_penalty":0,"temperature":0,"top_p":0}'
@@ -308,7 +316,7 @@ curl 'http://<这里换成网关公网IP>/api/openai/v1/chat/completions' \
**Request Example**
```shell
curl 'http://<replace with gateway public IP>/api/openai/v1/chat/completions' \
curl 'http://<replace with gateway address>/api/openai/v1/chat/completions' \
-H 'Accept: application/json, text/event-stream' \
-H 'Content-Type: application/json' \
--data-raw '{"model":"qwen","frequency_penalty":0,"max_tokens":800,"stream":false,"messages":[{"role":"user","content":"What is the current weather in Jinan?"}],"presence_penalty":0,"temperature":0,"top_p":0}'
@@ -323,7 +331,7 @@ curl 'http://<这里换成网关公网IP>/api/openai/v1/chat/completions' \
**Request Example**
```shell
curl 'http://<replace with gateway public IP>/api/openai/v1/chat/completions' \
curl 'http://<replace with gateway address>/api/openai/v1/chat/completions' \
-H 'Accept: application/json, text/event-stream' \
-H 'Content-Type: application/json' \
--data-raw '{"model":"qwen","frequency_penalty":0,"max_tokens":800,"stream":false,"messages":[{"role": "user","content": "How is the weather in Jinan?"},{ "role": "assistant","content": "The weather in Jinan is currently cloudy with a temperature of 24℃. The data was updated at 21:50:14 on September 12, 2024."},{"role": "user","content": "What about Beijing?"}],"presence_penalty":0,"temperature":0,"top_p":0}'
@@ -338,7 +346,7 @@ curl 'http://<这里换成网关公网IP>/api/openai/v1/chat/completions' \
**Request Example**
```shell
curl 'http://<replace with gateway public IP>/api/openai/v1/chat/completions' \
curl 'http://<replace with gateway address>/api/openai/v1/chat/completions' \
-H 'Accept: application/json, text/event-stream' \
-H 'Content-Type: application/json' \
--data-raw '{"model":"qwen","frequency_penalty":0,"max_tokens":800,"stream":false,"messages":[{"role":"user","content":"What is the current weather in Jinan? Express the temperature in Fahrenheit and answer in Japanese"}],"presence_penalty":0,"temperature":0,"top_p":0}'
@@ -353,7 +361,7 @@ curl 'http://<这里换成网关公网IP>/api/openai/v1/chat/completions' \
**Request Example**
```shell
curl 'http://<replace with gateway public IP>/api/openai/v1/chat/completions' \
curl 'http://<replace with gateway address>/api/openai/v1/chat/completions' \
-H 'Accept: application/json, text/event-stream' \
-H 'Content-Type: application/json' \
--data-raw '{"model":"qwen","frequency_penalty":0,"max_tokens":800,"stream":false,"messages":[{"role":"user","content":"Translate the following sentence into German for me: Hail Hydra!"}],"presence_penalty":0,"temperature":0,"top_p":0}'
@@ -364,3 +372,71 @@ curl 'http://<这里换成网关公网IP>/api/openai/v1/chat/completions' \
```json
{"id":"65dcf12c-61ff-9e68-bffa-44fc9e6070d5","choices":[{"index":0,"message":{"role":"assistant","content":" The German translation of \"Hail Hydra!\" is \"Hoch lebe Hydra!\"."},"finish_reason":"stop"}],"created":1724043865,"model":"qwen-max-0403","object":"chat.completion","usage":{"prompt_tokens":908,"completion_tokens":52,"total_tokens":960}}
```
## Usage Example: JSON formatting enabled
**Configuration Information**
Add the jsonResp configuration on top of the configuration above:
```yaml
jsonResp:
  enable: true
```
**Request Example**
```shell
curl 'http://<replace with gateway address>/api/openai/v1/chat/completions' \
-H 'Accept: application/json, text/event-stream' \
-H 'Content-Type: application/json' \
--data-raw '{"model":"qwen","frequency_penalty":0,"max_tokens":800,"stream":false,"messages":[{"role":"user","content":"What is the current weather in Beijing?"}],"presence_penalty":0,"temperature":0,"top_p":0}'
```
**Response Example**
```json
{"id":"ebd6ea91-8e38-9e14-9a5b-90178d2edea4","choices":[{"index":0,"message":{"role":"assistant","content": "{\"city\": \"Beijing\", \"weather_condition\": \"cloudy\", \"temperature\": \"19℃\", \"data_update_time\": \"October 9, 2024, 16:37:53\"}"},"finish_reason":"stop"}],"created":1723187991,"model":"qwen-max-0403","object":"chat.completion","usage":{"prompt_tokens":890,"completion_tokens":56,"total_tokens":946}}
```
If no custom JSON schema is configured, the LLM chooses the JSON structure on its own.
**Configuration Information**
Add a custom JSON schema to the configuration:
```yaml
jsonResp:
  enable: true
  jsonSchema:
    title: WeatherSchema
    type: object
    properties:
      location:
        type: string
        description: City name.
      weather:
        type: string
        description: Weather conditions.
      temperature:
        type: string
        description: Temperature.
      update_time:
        type: string
        description: Time the data was last updated.
    required:
      - location
      - weather
      - temperature
    additionalProperties: false
```
**Request Example**
```shell
curl 'http://<replace with gateway address>/api/openai/v1/chat/completions' \
-H 'Accept: application/json, text/event-stream' \
-H 'Content-Type: application/json' \
--data-raw '{"model":"qwen","frequency_penalty":0,"max_tokens":800,"stream":false,"messages":[{"role":"user","content":"What is the current weather in Beijing?"}],"presence_penalty":0,"temperature":0,"top_p":0}'
```
**Response Example**
```json
{"id":"ebd6ea91-8e38-9e14-9a5b-90178d2edea4","choices":[{"index":0,"message":{"role":"assistant","content": "{\"location\": \"Beijing\", \"weather\": \"cloudy\", \"temperature\": \"19℃\", \"update_time\": \"October 9, 2024, 16:37:53\"}"},"finish_reason":"stop"}],"created":1723187991,"model":"qwen-max-0403","object":"chat.completion","usage":{"prompt_tokens":890,"completion_tokens":56,"total_tokens":946}}
```

View File

@@ -4,7 +4,7 @@ keywords: [ AI Gateway, AI Agent ]
description: AI Agent plugin configuration reference
---
## Functional Description
A customizable API AI Agent that supports configuring HTTP method types as GET and POST APIs. Supports multiple dialogue rounds, streaming and non-streaming modes.
A customizable API AI Agent that supports configuring APIs whose HTTP method is GET or POST. It supports multi-turn dialogue, streaming and non-streaming modes, and formatting results as custom JSON.
The agent flow chart is as follows:
![ai-agent](https://github.com/user-attachments/assets/b0761a0c-1afa-496c-a98e-bb9f38b340f8)
@@ -20,6 +20,7 @@ Plugin execution priority: `200`
| `llm` | object | Required | - | Configuration information for AI service provider |
| `apis` | object | Required | - | Configuration information for external API service provider |
| `promptTemplate` | object | Optional | - | Configuration information for Agent ReAct template |
| `jsonResp` | object | Optional | - | Configuration information for JSON formatting |
The configuration fields for `llm` are as follows:
| Name | Data Type | Requirement | Default Value | Description |
@@ -71,7 +72,13 @@ The configuration fields for `chTemplate` and `enTemplate` are as follows:
| `observation` | string | Optional | - | The observation part of the Agent ReAct template |
| `thought2` | string | Optional | - | The thought2 part of the Agent ReAct template |
## Usage Example
The configuration fields for `jsonResp` are as follows:
| Name | Data Type | Requirement | Default Value | Description |
|--------------------|-----------|-------------|---------------|------------------------------------|
| `enable` | bool | Optional | false | Whether to enable JSON formatting. |
| `jsonSchema` | object | Optional | - | Custom JSON schema |
## Usage Example-disable json formatting
**Configuration Information**
```yaml
llm:
@@ -335,3 +342,68 @@ curl 'http://<replace with gateway public IP>/api/openai/v1/chat/completions' \
{"id":"65dcf12c-61ff-9e68-bffa-44fc9e6070d5","choices":[{"index":0,"message":{"role":"assistant","content":" The German translation of \"Hail Hydra!\" is \"Hoch lebe Hydra!\"."},"finish_reason":"stop"}],"created":1724043865,"model":"qwen-max-0403","object":"chat.completion","usage":{"prompt_tokens":908,"completion_tokens":52,"total_tokens":960}}
```
## Usage Example-enable json formatting
**Configuration Information**
Add jsonResp configuration to the above configuration
```yaml
jsonResp:
  enable: true
```
**Request Example**
```shell
curl 'http://<replace with gateway public IP>/api/openai/v1/chat/completions' \
-H 'Accept: application/json, text/event-stream' \
-H 'Content-Type: application/json' \
--data-raw '{"model":"qwen","frequency_penalty":0,"max_tokens":800,"stream":false,"messages":[{"role":"user","content":"What is the current weather in Beijing ?"}],"presence_penalty":0,"temperature":0,"top_p":0}'
```
**Response Example**
```json
{"id":"ebd6ea91-8e38-9e14-9a5b-90178d2edea4","choices":[{"index":0,"message":{"role":"assistant","content": "{\"city\": \"BeiJing\", \"weather_condition\": \"cloudy\", \"temperature\": \"19℃\", \"data_update_time\": \"Oct 9, 2024, at 16:37\"}"},"finish_reason":"stop"}],"created":1723187991,"model":"qwen-max-0403","object":"chat.completion","usage":{"prompt_tokens":890,"completion_tokens":56,"total_tokens":946}}
```
If no custom JSON schema is configured, the LLM chooses the JSON structure on its own.
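Model output often wraps the JSON object in extra prose. One plausible way to recover it is to take the text between the first `{` and the last `}` and check that it parses. The Go sketch below shows that heuristic under those assumptions; the name `extractJSONObject` is hypothetical, and the approach does not handle output that contains several separate objects.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"strings"
)

// extractJSONObject returns the substring between the first '{' and the last
// '}' of text, provided that substring parses as a JSON object.
// Hypothetical helper for illustration only.
func extractJSONObject(text string) (string, error) {
	start := strings.Index(text, "{")
	end := strings.LastIndex(text, "}")
	if start == -1 || end < start {
		return "", errors.New("no JSON object found")
	}
	candidate := text[start : end+1]

	// Parse once to confirm the candidate is a well-formed JSON object.
	var probe map[string]interface{}
	if err := json.Unmarshal([]byte(candidate), &probe); err != nil {
		return "", err
	}
	return candidate, nil
}

func main() {
	out, err := extractJSONObject(`Here is the result: {"city": "Beijing"} Hope it helps.`)
	fmt.Println(out, err)
}
```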
**Configuration Information**
Add custom json schema configuration
```yaml
jsonResp:
  enable: true
  jsonSchema:
    title: WeatherSchema
    type: object
    properties:
      location:
        type: string
        description: city name.
      weather:
        type: string
        description: weather conditions.
      temperature:
        type: string
        description: temperature.
      update_time:
        type: string
        description: the update time of data.
    required:
      - location
      - weather
      - temperature
    additionalProperties: false
```
**Request Example**
```shell
curl 'http://<replace with gateway public IP>/api/openai/v1/chat/completions' \
-H 'Accept: application/json, text/event-stream' \
-H 'Content-Type: application/json' \
--data-raw '{"model":"qwen","frequency_penalty":0,"max_tokens":800,"stream":false,"messages":[{"role":"user","content":"What is the current weather in Beijing ?"}],"presence_penalty":0,"temperature":0,"top_p":0}'
```
**Response Example**
```json
{"id":"ebd6ea91-8e38-9e14-9a5b-90178d2edea4","choices":[{"index":0,"message":{"role":"assistant","content": "{\"location\": \"Beijing\", \"weather\": \"cloudy\", \"temperature\": \"19℃\", \"update_time\": \"Oct 9, 2024, at 16:37\"}"},"finish_reason":"stop"}],"created":1723187991,"model":"qwen-max-0403","object":"chat.completion","usage":{"prompt_tokens":890,"completion_tokens":56,"total_tokens":946}}
```
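In the response examples above, the formatted result arrives as an escaped JSON string inside `choices[0].message.content`, so a client has to decode twice. A minimal Go sketch of that client-side step follows; the type `chatResponse` and the function `decodeContent` are hypothetical names, and only the fields needed here are modelled.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// chatResponse models only the fields of the completion response used here.
type chatResponse struct {
	Choices []struct {
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
	} `json:"choices"`
}

// decodeContent parses the response body, then parses the JSON document that
// is carried as a string in the first choice's message content.
// Hypothetical helper for illustration only.
func decodeContent(body []byte) (map[string]interface{}, error) {
	var resp chatResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}
	if len(resp.Choices) == 0 {
		return nil, fmt.Errorf("response has no choices")
	}
	var payload map[string]interface{}
	if err := json.Unmarshal([]byte(resp.Choices[0].Message.Content), &payload); err != nil {
		return nil, err
	}
	return payload, nil
}

func main() {
	// Shortened version of the response example above.
	body := []byte(`{"choices":[{"index":0,"message":{"role":"assistant","content":"{\"location\": \"Beijing\", \"weather\": \"cloudy\"}"},"finish_reason":"stop"}]}`)
	payload, err := decodeContent(body)
	if err != nil {
		fmt.Println("decode failed:", err)
		return
	}
	fmt.Println(payload["location"], payload["weather"])
}
```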

View File

@@ -211,6 +211,15 @@ type LLMInfo struct {
MaxTokens int64 `yaml:"maxToken" json:"maxTokens"`
}
type JsonResp struct {
// @Title zh-CN Enable
// @Description zh-CN 是否要启用json格式化输出
Enable bool `yaml:"enable" json:"enable"`
// @Title zh-CN Json Schema
// @Description zh-CN 用以验证响应json的Json Schema, 为空则只验证返回的响应是否为合法json
JsonSchema map[string]interface{} `required:"false" json:"jsonSchema" yaml:"jsonSchema"`
}
type PluginConfig struct {
// @Title zh-CN 返回 HTTP 响应的模版
// @Description zh-CN 用 %s 标记需要被 cache value 替换的部分
@@ -225,6 +234,7 @@ type PluginConfig struct {
LLMClient wrapper.HttpClient `yaml:"-" json:"-"`
APIsParam []APIsParam `yaml:"-" json:"-"`
PromptTemplate PromptTemplate `yaml:"promptTemplate" json:"promptTemplate"`
JsonResp JsonResp `yaml:"jsonResp" json:"jsonResp"`
}
func initResponsePromptTpl(gjson gjson.Result, c *PluginConfig) {
@@ -402,3 +412,15 @@ func initLLMClient(gjson gjson.Result, c *PluginConfig) {
Host: c.LLMInfo.Domain,
})
}
func initJsonResp(gjson gjson.Result, c *PluginConfig) {
c.JsonResp.Enable = false
if c.JsonResp.Enable = gjson.Get("jsonResp.enable").Bool(); c.JsonResp.Enable {
c.JsonResp.JsonSchema = nil
if jsonSchemaValue := gjson.Get("jsonResp.jsonSchema"); jsonSchemaValue.Exists() {
if schemaValue, ok := jsonSchemaValue.Value().(map[string]interface{}); ok {
c.JsonResp.JsonSchema = schemaValue
}
}
}
}

View File

@@ -2,8 +2,10 @@ package main
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"regexp"
"strings"
@@ -47,6 +49,8 @@ func parseConfig(gjson gjson.Result, c *PluginConfig, log wrapper.Log) error {
initLLMClient(gjson, c)
initJsonResp(gjson, c)
return nil
}
@@ -76,10 +80,10 @@ func firstReq(ctx wrapper.HttpContext, config PluginConfig, prompt string, rawRe
log.Debugf("[onHttpRequestBody] newRequestBody: %s", string(newbody))
err := proxywasm.ReplaceHttpRequestBody(newbody)
if err != nil {
log.Debug("替换失败")
log.Debugf("failed replace err: %s", err.Error())
proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnResponseTemplate, "替换失败"+err.Error())), -1)
}
log.Debug("[onHttpRequestBody] request替换成功")
log.Debug("[onHttpRequestBody] replace request success")
return types.ActionContinue
}
}
@@ -175,11 +179,103 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wra
return types.ActionContinue
}
func toolsCallResult(ctx wrapper.HttpContext, config PluginConfig, content string, rawResponse Response, log wrapper.Log, statusCode int, responseBody []byte) {
func extractJson(bodyStr string) (string, error) {
// simply extract json from response body string
startIndex := strings.Index(bodyStr, "{")
endIndex := strings.LastIndex(bodyStr, "}") + 1
// if not found
if startIndex == -1 || startIndex >= endIndex {
return "", errors.New("cannot find json in the response body")
}
jsonStr := bodyStr[startIndex:endIndex]
// attempt to parse the JSON
var result map[string]interface{}
err := json.Unmarshal([]byte(jsonStr), &result)
if err != nil {
return "", err
}
return jsonStr, nil
}
func jsonFormat(llmClient wrapper.HttpClient, llmInfo LLMInfo, jsonSchema map[string]interface{}, assistantMessage Message, actionInput string, headers [][2]string, streamMode bool, rawResponse Response, log wrapper.Log) string {
prompt := fmt.Sprintf(prompttpl.Json_Resp_Template, jsonSchema, actionInput)
messages := []dashscope.Message{{Role: "user", Content: prompt}}
completion := dashscope.Completion{
Model: llmInfo.Model,
Messages: messages,
}
completionSerialized, _ := json.Marshal(completion)
var content string
err := llmClient.Post(
llmInfo.Path,
headers,
completionSerialized,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
// got the response from the LLM
var responseCompletion dashscope.CompletionResponse
_ = json.Unmarshal(responseBody, &responseCompletion)
log.Infof("[jsonFormat] content: %s", responseCompletion.Choices[0].Message.Content)
content = responseCompletion.Choices[0].Message.Content
jsonStr, err := extractJson(content)
if err != nil {
log.Debugf("[onHttpRequestBody] extractJson err: %s", err.Error())
jsonStr = content
}
if streamMode {
stream(jsonStr, rawResponse, log)
} else {
noneStream(assistantMessage, jsonStr, rawResponse, log)
}
}, uint32(llmInfo.MaxExecutionTime))
if err != nil {
log.Debugf("[onHttpRequestBody] completion err: %s", err.Error())
proxywasm.ResumeHttpRequest()
}
return content
}
func noneStream(assistantMessage Message, actionInput string, rawResponse Response, log wrapper.Log) {
assistantMessage.Role = "assistant"
assistantMessage.Content = actionInput
rawResponse.Choices[0].Message = assistantMessage
newbody, err := json.Marshal(rawResponse)
if err != nil {
proxywasm.ResumeHttpResponse()
return
} else {
proxywasm.ReplaceHttpResponseBody(newbody)
log.Debug("[onHttpResponseBody] replace response success")
proxywasm.ResumeHttpResponse()
}
}
func stream(actionInput string, rawResponse Response, log wrapper.Log) {
headers := [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}
proxywasm.ReplaceHttpResponseHeaders(headers)
// Remove quotes from actionInput
actionInput = strings.Trim(actionInput, "\"")
returnStreamResponseTemplate := `data:{"id":"%s","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"%s","object":"chat.completion","usage":{"prompt_tokens":%d,"completion_tokens":%d,"total_tokens":%d}}` + "\n\ndata:[DONE]\n\n"
newbody := fmt.Sprintf(returnStreamResponseTemplate, rawResponse.ID, actionInput, rawResponse.Model, rawResponse.Usage.PromptTokens, rawResponse.Usage.CompletionTokens, rawResponse.Usage.TotalTokens)
log.Infof("[onHttpResponseBody] newResponseBody: %s", newbody)
proxywasm.ReplaceHttpResponseBody([]byte(newbody))
log.Debug("[onHttpResponseBody] replace response success")
proxywasm.ResumeHttpResponse()
}
func toolsCallResult(ctx wrapper.HttpContext, llmClient wrapper.HttpClient, llmInfo LLMInfo, jsonResp JsonResp, aPIsParam []APIsParam, aPIClient []wrapper.HttpClient, content string, rawResponse Response, log wrapper.Log, statusCode int, responseBody []byte) {
if statusCode != http.StatusOK {
log.Debugf("statusCode: %d", statusCode)
}
log.Info("========函数返回结果========")
log.Info("========function result========")
log.Infof(string(responseBody))
observation := "Observation: " + string(responseBody)
@@ -187,15 +283,15 @@ func toolsCallResult(ctx wrapper.HttpContext, config PluginConfig, content strin
dashscope.MessageStore.AddForUser(observation)
completion := dashscope.Completion{
Model: config.LLMInfo.Model,
Model: llmInfo.Model,
Messages: dashscope.MessageStore,
MaxTokens: config.LLMInfo.MaxTokens,
MaxTokens: llmInfo.MaxTokens,
}
headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.LLMInfo.APIKey}}
headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + llmInfo.APIKey}}
completionSerialized, _ := json.Marshal(completion)
err := config.LLMClient.Post(
config.LLMInfo.Path,
err := llmClient.Post(
llmInfo.Path,
headers,
completionSerialized,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
@@ -205,42 +301,31 @@ func toolsCallResult(ctx wrapper.HttpContext, config PluginConfig, content strin
log.Infof("[toolsCall] content: %s", responseCompletion.Choices[0].Message.Content)
if responseCompletion.Choices[0].Message.Content != "" {
retType, actionInput := toolsCall(ctx, config, responseCompletion.Choices[0].Message.Content, rawResponse, log)
retType, actionInput := toolsCall(ctx, llmClient, llmInfo, jsonResp, aPIsParam, aPIClient, responseCompletion.Choices[0].Message.Content, rawResponse, log)
if retType == types.ActionContinue {
// Got the Final Answer
var assistantMessage Message
var streamMode bool
if ctx.GetContext(StreamContextKey) == nil {
assistantMessage.Role = "assistant"
assistantMessage.Content = actionInput
rawResponse.Choices[0].Message = assistantMessage
newbody, err := json.Marshal(rawResponse)
if err != nil {
proxywasm.ResumeHttpResponse()
return
streamMode = false
if jsonResp.Enable {
jsonFormat(llmClient, llmInfo, jsonResp.JsonSchema, assistantMessage, actionInput, headers, streamMode, rawResponse, log)
} else {
proxywasm.ReplaceHttpResponseBody(newbody)
log.Debug("[onHttpResponseBody] response替换成功")
proxywasm.ResumeHttpResponse()
noneStream(assistantMessage, actionInput, rawResponse, log)
}
} else {
headers := [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}
proxywasm.ReplaceHttpResponseHeaders(headers)
// Remove quotes from actionInput
actionInput = strings.Trim(actionInput, "\"")
returnStreamResponseTemplate := `data:{"id":"%s","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"%s","object":"chat.completion","usage":{"prompt_tokens":%d,"completion_tokens":%d,"total_tokens":%d}}` + "\n\ndata:[DONE]\n\n"
newbody := fmt.Sprintf(returnStreamResponseTemplate, rawResponse.ID, actionInput, rawResponse.Model, rawResponse.Usage.PromptTokens, rawResponse.Usage.CompletionTokens, rawResponse.Usage.TotalTokens)
log.Infof("[onHttpResponseBody] newResponseBody: ", newbody)
proxywasm.ReplaceHttpResponseBody([]byte(newbody))
log.Debug("[onHttpResponseBody] response替换成功")
proxywasm.ResumeHttpResponse()
streamMode = true
if jsonResp.Enable {
jsonFormat(llmClient, llmInfo, jsonResp.JsonSchema, assistantMessage, actionInput, headers, streamMode, rawResponse, log)
} else {
stream(actionInput, rawResponse, log)
}
}
}
} else {
proxywasm.ResumeHttpRequest()
}
}, uint32(config.LLMInfo.MaxExecutionTime))
}, uint32(llmInfo.MaxExecutionTime))
if err != nil {
log.Debugf("[onHttpRequestBody] completion err: %s", err.Error())
proxywasm.ResumeHttpRequest()
@@ -294,7 +379,7 @@ func outputParser(response string, log wrapper.Log) (string, string) {
return "", ""
}
func toolsCall(ctx wrapper.HttpContext, config PluginConfig, content string, rawResponse Response, log wrapper.Log) (types.Action, string) {
func toolsCall(ctx wrapper.HttpContext, llmClient wrapper.HttpClient, llmInfo LLMInfo, jsonResp JsonResp, aPIsParam []APIsParam, aPIClient []wrapper.HttpClient, content string, rawResponse Response, log wrapper.Log) (types.Action, string) {
dashscope.MessageStore.AddForAssistant(content)
action, actionInput := outputParser(content, log)
@@ -305,9 +390,9 @@ func toolsCall(ctx wrapper.HttpContext, config PluginConfig, content string, raw
}
count := ctx.GetContext(ToolCallsCount).(int)
count++
log.Debugf("toolCallsCount:%d, config.LLMInfo.MaxIterations=%d", count, config.LLMInfo.MaxIterations)
log.Debugf("toolCallsCount:%d, config.LLMInfo.MaxIterations=%d", count, llmInfo.MaxIterations)
// The recursive call count has reached the configured number of iterations, so force termination
if int64(count) > config.LLMInfo.MaxIterations {
if int64(count) > llmInfo.MaxIterations {
ctx.SetContext(ToolCallsCount, 0)
return types.ActionContinue, ""
} else {
@@ -316,15 +401,14 @@ func toolsCall(ctx wrapper.HttpContext, config PluginConfig, content string, raw
// No final answer yet
var url string
var urlStr string
var headers [][2]string
var apiClient wrapper.HttpClient
var method string
var reqBody []byte
var key string
var maxExecutionTime int64
for i, apisParam := range config.APIsParam {
for i, apisParam := range aPIsParam {
maxExecutionTime = apisParam.MaxExecutionTime
for _, tools_param := range apisParam.ToolsParam {
if action == tools_param.ToolName {
@@ -340,28 +424,37 @@ func toolsCall(ctx wrapper.HttpContext, config PluginConfig, content string, raw
method = tools_param.Method
// Assemble headers and key
headers = [][2]string{{"Content-Type", "application/json"}}
if apisParam.APIKey.Name != "" {
if apisParam.APIKey.In == "query" {
key = "?" + apisParam.APIKey.Name + "=" + apisParam.APIKey.Value
} else if apisParam.APIKey.In == "header" {
headers = append(headers, [2]string{"Authorization", apisParam.APIKey.Name + " " + apisParam.APIKey.Value})
// Assemble the URL and request body
urlStr = apisParam.URL + tools_param.Path
// Parse the URL template to find path parameters
urlParts := strings.Split(urlStr, "/")
for i, part := range urlParts {
if strings.Contains(part, "{") && strings.Contains(part, "}") {
for _, param := range tools_param.ParamName {
paramNameInPath := part[1 : len(part)-1]
if paramNameInPath == param {
if value, ok := data[param]; ok {
// Remove the parameter that has already been consumed
delete(data, param)
// Replace the placeholder in the template
urlParts[i] = url.QueryEscape(value.(string))
}
}
}
}
}
// Assemble the URL and request body
url = apisParam.URL + tools_param.Path + key
// Reassemble the URL
urlStr = strings.Join(urlParts, "/")
queryParams := make([][2]string, 0)
if method == "GET" {
queryParams := make([]string, 0, len(tools_param.ParamName))
for _, param := range tools_param.ParamName {
if value, ok := data[param]; ok {
queryParams = append(queryParams, fmt.Sprintf("%s=%v", param, value))
queryParams = append(queryParams, [2]string{param, fmt.Sprintf("%v", value)})
}
}
if len(queryParams) > 0 {
url += "&" + strings.Join(queryParams, "&")
}
} else if method == "POST" {
var err error
reqBody, err = json.Marshal(data)
@@ -371,9 +464,30 @@ func toolsCall(ctx wrapper.HttpContext, config PluginConfig, content string, raw
}
}
log.Infof("url: %s", url)
// Assemble headers and key
headers = [][2]string{{"Content-Type", "application/json"}}
if apisParam.APIKey.Name != "" {
if apisParam.APIKey.In == "query" {
queryParams = append(queryParams, [2]string{apisParam.APIKey.Name, apisParam.APIKey.Value})
} else if apisParam.APIKey.In == "header" {
headers = append(headers, [2]string{"Authorization", apisParam.APIKey.Name + " " + apisParam.APIKey.Value})
}
}
apiClient = config.APIClient[i]
if len(queryParams) > 0 {
// Append the query parameters, including the key, to the URL
urlStr += "?"
for i, param := range queryParams {
if i != 0 {
urlStr += "&"
}
urlStr += url.QueryEscape(param[0]) + "=" + url.QueryEscape(param[1])
}
}
log.Debugf("url: %s", urlStr)
apiClient = aPIClient[i]
break
}
}
@@ -382,11 +496,11 @@ func toolsCall(ctx wrapper.HttpContext, config PluginConfig, content string, raw
if apiClient != nil {
err := apiClient.Call(
method,
url,
urlStr,
headers,
reqBody,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
toolsCallResult(ctx, config, content, rawResponse, log, statusCode, responseBody)
toolsCallResult(ctx, llmClient, llmInfo, jsonResp, aPIsParam, aPIClient, content, rawResponse, log, statusCode, responseBody)
}, uint32(maxExecutionTime))
if err != nil {
log.Debugf("tool calls error: %s", err.Error())
@@ -415,7 +529,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, body []byt
// If the content returned by GPT is not empty
if rawResponse.Choices[0].Message.Content != "" {
// Enter the agent's loop of reasoning and tool calls
retType, _ := toolsCall(ctx, config, rawResponse.Choices[0].Message.Content, rawResponse, log)
retType, _ := toolsCall(ctx, config.LLMClient, config.LLMInfo, config.JsonResp, config.APIsParam, config.APIClient, rawResponse.Choices[0].Message.Content, rawResponse, log)
return retType
} else {
return types.ActionContinue
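The URL assembly in `toolsCall` (path-template substitution followed by query-string building) can be sketched in isolation as below. This is a minimal sketch, not the plugin's code: `fillPathParams`, `appendQuery`, and the `api.example.com` host are hypothetical names chosen for illustration. It also makes two deliberate substitutions: it uses `url.PathEscape` for path segments where the diff uses `url.QueryEscape`, and it formats values with `%v` rather than asserting `value.(string)`, so a non-string argument does not panic.

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// fillPathParams replaces "{name}" path segments with escaped values from data
// and deletes each consumed key so it is not sent again as a query parameter.
// Hypothetical helper; the segment must consist of the placeholder only.
func fillPathParams(rawURL string, data map[string]interface{}) string {
	parts := strings.Split(rawURL, "/")
	for i, part := range parts {
		if len(part) > 2 && strings.HasPrefix(part, "{") && strings.HasSuffix(part, "}") {
			name := part[1 : len(part)-1]
			if value, ok := data[name]; ok {
				parts[i] = url.PathEscape(fmt.Sprintf("%v", value))
				delete(data, name)
			}
		}
	}
	return strings.Join(parts, "/")
}

// appendQuery appends escaped key/value pairs to rawURL in the given order.
// Hypothetical helper; it assumes rawURL has no query string yet.
func appendQuery(rawURL string, params [][2]string) string {
	if len(params) == 0 {
		return rawURL
	}
	pairs := make([]string, 0, len(params))
	for _, p := range params {
		pairs = append(pairs, url.QueryEscape(p[0])+"="+url.QueryEscape(p[1]))
	}
	return rawURL + "?" + strings.Join(pairs, "&")
}

func main() {
	data := map[string]interface{}{"city": "New York", "unit": "c"}
	u := fillPathParams("http://api.example.com/weather/{city}", data)
	u = appendQuery(u, [][2]string{{"unit", "c"}, {"key", "secret"}})
	fmt.Println(u)
}
```

Keeping the query parameters as ordered pairs, as the diff does with `[][2]string`, makes the resulting URL deterministic, which a `map` would not.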

View File

@@ -167,3 +167,7 @@ Action:` + "```" + `
%s
Question: %s
`
const Json_Resp_Template = `
Given the Json Schema: %s, please help me convert the following content to a pure json: %s
Do not respond other content except the pure json!!!!
`
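The diff calls an `extractJson` helper on the model's reply but does not show its body. The sketch below is one plausible way to do that step, under stated assumptions: `extractJSONObject` is a hypothetical function, not the plugin's implementation, and it simply takes the text between the first `{` and the last `}` and checks it with `json.Valid`. The template constant is copied from the diff under a local name.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"strings"
)

// jsonRespTemplate mirrors Json_Resp_Template from the diff.
const jsonRespTemplate = `
Given the Json Schema: %s, please help me convert the following content to a pure json: %s
Do not respond other content except the pure json!!!!
`

// extractJSONObject is a hypothetical stand-in for the plugin's extractJson.
// It returns the substring from the first "{" to the last "}" if it is valid JSON.
func extractJSONObject(reply string) (string, error) {
	start := strings.Index(reply, "{")
	end := strings.LastIndex(reply, "}")
	if start < 0 || end < start {
		return "", errors.New("no JSON object found in reply")
	}
	candidate := reply[start : end+1]
	if !json.Valid([]byte(candidate)) {
		return "", errors.New("extracted text is not valid JSON")
	}
	return candidate, nil
}

func main() {
	// Build the prompt that asks the model to reformat its final answer.
	prompt := fmt.Sprintf(jsonRespTemplate, `{"type":"object"}`, "The city is Hangzhou")
	fmt.Print(prompt)

	// Models often wrap the JSON in extra prose, so extract before using it.
	reply := `Sure: {"city": "Hangzhou"} Hope this helps.`
	jsonStr, err := extractJSONObject(reply)
	fmt.Println(jsonStr, err)
}
```

On failure the diff falls back to the raw content (`jsonStr = content`), so a caller of a helper like this would likely do the same rather than abort the response.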

View File

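@@ -60,7 +60,7 @@ LLM result caching plugin; the default configuration can be used directly for the openai protocol
| vector.apiKey | string | optional | "" | API key for the vector store service |
| vector.topK | int | optional | 1 | Number of top-K results to return; defaults to 1 |
| vector.timeout | uint32 | optional | 10000 | Timeout for requests to the vector store service, in milliseconds. Defaults to 10000, i.e. 10 seconds |
| vector.collectionID | string | optional | "" | Collection ID of the dashvector vector store service |
| vector.collectionID | string | optional | "" | Collection ID of the vector store service |
| vector.threshold | float64 | optional | 1000 | Threshold for the vector similarity measure |
| vector.thresholdRelation | string | optional | lt | Similarity measures include `Cosine`, `DotProduct`, `Euclidean`, and others. For the first two, a larger value means higher similarity; for the last, a smaller value means higher similarity. Choose `gt` for `Cosine` and `DotProduct`, and `lt` for `Euclidean`. Defaults to `lt`. The supported relations are `lt` (less than), `lte` (less than or equal to), `gt` (greater than), and `gte` (greater than or equal to) |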
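
How `vector.threshold` and `vector.thresholdRelation` combine can be illustrated with a small comparison function. This is a sketch only: `meetsThreshold` is a hypothetical name and the plugin's actual comparison code is not shown in this diff.

```go
package main

import "fmt"

// meetsThreshold reports whether score satisfies the configured relation
// against threshold. Hypothetical helper for illustration.
func meetsThreshold(score, threshold float64, relation string) (bool, error) {
	switch relation {
	case "lt":
		return score < threshold, nil
	case "lte":
		return score <= threshold, nil
	case "gt":
		return score > threshold, nil
	case "gte":
		return score >= threshold, nil
	default:
		return false, fmt.Errorf("unsupported thresholdRelation %q", relation)
	}
}

func main() {
	// Cosine similarity: larger is more similar, so use "gt".
	hit, _ := meetsThreshold(0.92, 0.8, "gt")
	fmt.Println("cosine hit:", hit)

	// Euclidean distance: smaller is more similar, so use "lt".
	hit, _ = meetsThreshold(1200, 1000, "lt")
	fmt.Println("euclidean hit:", hit)
}
```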
@@ -99,6 +99,45 @@ LLM result caching plugin; the default configuration can be used directly for the openai protocol
| responseTemplate | string | optional | `{"id":"ai-cache.hit","choices":[{"index":0,"message":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` | Template for the returned HTTP response; `%s` marks the part to be replaced by the cache value |
| streamResponseTemplate | string | optional | `data:{"id":"ai-cache.hit","choices":[{"index":0,"delta":{"role":"assistant","content":%s},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}\n\ndata:[DONE]\n\n` | Template for the returned streaming HTTP response; `%s` marks the part to be replaced by the cache value |
# Provider-specific configuration for vector databases
## Chroma
For Chroma, set `vector.type` to `chroma`. It has no provider-specific configuration fields. Create the Collection in advance and put its Collection ID in `vector.collectionID`; an example Collection ID is `52bbb8b3-724c-477b-a4ce-d5b578214612`.
## DashVector
For DashVector, set `vector.type` to `dashvector`. It has no provider-specific configuration fields. Create the Collection in advance and put the `Collection name` in `vector.collectionID`.
## ElasticSearch
For ElasticSearch, set `vector.type` to `elasticsearch`. Create the Index in advance and put the Index Name in `vector.collectionID`.
It currently relies on the [KNN](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html) method, so make sure your ES version supports `KNN`. It has been tested on version `8.16`.
Its provider-specific configuration fields are:
| Name | Data type | Requirement | Default | Description |
|-------------------|----------|----------|--------|-------------------------------------------------------------------------------|
| `vector.esUsername` | string | optional | - | ElasticSearch username |
| `vector.esPassword` | string | optional | - | ElasticSearch password |
`vector.esUsername` and `vector.esPassword` are used for Basic authentication. API key authentication is also supported: when `vector.apiKey` is set, API key authentication is enabled. If you use the SaaS version, fill in the `encoded` value.
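
The choice between the two ElasticSearch authentication modes can be sketched as follows. It mirrors the `getCredentials` method added in this change set; `esAuthorization` is a standalone name used here for illustration, and the credentials in `main` are placeholders.

```go
package main

import (
	"encoding/base64"
	"fmt"
)

// esAuthorization builds the Authorization header value.
// A non-empty API key takes precedence over username/password.
func esAuthorization(apiKey, username, password string) string {
	if apiKey != "" {
		return "ApiKey " + apiKey
	}
	credentials := username + ":" + password
	return "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
}

func main() {
	// Basic authentication with placeholder credentials.
	fmt.Println(esAuthorization("", "elastic", "changeme"))
	// API key authentication; the key is passed through unchanged.
	fmt.Println(esAuthorization("my-encoded-key", "", ""))
}
```

Because the API key is passed through as-is, it has to be the already encoded form, which is why the SaaS note above asks for the `encoded` value.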
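## Milvus
For Milvus, set `vector.type` to `milvus`. It has no provider-specific configuration fields. Create the Collection in advance and put the Collection Name in `vector.collectionID`.
## Pinecone
For Pinecone, set `vector.type` to `pinecone`. It has no provider-specific configuration fields. Create the Index in advance and put the Index access domain in `vector.serviceHost`.
The Pinecone `Namespace` parameter is configured through the plugin's `vector.collectionID`; if `vector.collectionID` is left empty, the Default Namespace is used.
## Qdrant
For Qdrant, set `vector.type` to `qdrant`. It has no provider-specific configuration fields. Create the Collection in advance and put the Collection Name in `vector.collectionID`.
## Weaviate
For Weaviate, set `vector.type` to `weaviate`. It has no provider-specific configuration fields.
Create the Collection in advance and put the Collection Name in `vector.collectionID`.
Note that Weaviate automatically capitalizes the first letter, so the first letter must be uppercase when you fill in `collectionID`.
If you use the SaaS offering, the `vector.serviceHost` parameter is required.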
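
The capitalization rule for Weaviate collection names can be checked with a few lines of Go before writing the value into the configuration. This is an illustrative sketch; `weaviateCollectionName` is a hypothetical helper and is not part of the plugin.

```go
package main

import (
	"fmt"
	"unicode"
	"unicode/utf8"
)

// weaviateCollectionName upper-cases the first letter of name and leaves
// the rest unchanged. Hypothetical helper for illustration.
func weaviateCollectionName(name string) string {
	if name == "" {
		return name
	}
	r, size := utf8.DecodeRuneInString(name)
	return string(unicode.ToUpper(r)) + name[size:]
}

func main() {
	fmt.Println(weaviateCollectionName("higress"))
}
```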
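## Configuration examples
### Basic configuration
@@ -144,4 +183,4 @@ GJSON PATH supports conditional syntax, for example to take the last one whose role is user
## FAQ
1. If the returned error is `error status returned by host: bad argument`, check whether `serviceName` correctly includes the service type suffix (`.dns`, etc.).
1. If the returned error is `error status returned by host: bad argument`, check whether `serviceName` correctly includes the service type suffix (`.dns`, etc.).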

View File

@@ -1,27 +0,0 @@
package embedding
// import (
// "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
// )
// const (
// weaviateURL = "172.17.0.1:8081"
// )
// type weaviateProviderInitializer struct {
// }
// func (d *weaviateProviderInitializer) ValidateConfig(config ProviderConfig) error {
// return nil
// }
// func (d *weaviateProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
// return &DSProvider{
// config: config,
// client: wrapper.NewClusterClient(wrapper.DnsCluster{
// ServiceName: config.ServiceName,
// Port: dashScopePort,
// Domain: dashScopeDomain,
// }),
// }, nil
// }

View File

@@ -0,0 +1,201 @@
package vector
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)
type chromaProviderInitializer struct{}
func (c *chromaProviderInitializer) ValidateConfig(config ProviderConfig) error {
if len(config.collectionID) == 0 {
return errors.New("[Chroma] collectionID is required")
}
if len(config.serviceName) == 0 {
return errors.New("[Chroma] serviceName is required")
}
return nil
}
func (c *chromaProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
return &ChromaProvider{
config: config,
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: config.serviceName,
Host: config.serviceHost,
Port: int64(config.servicePort),
}),
}, nil
}
type ChromaProvider struct {
config ProviderConfig
client wrapper.HttpClient
}
func (c *ChromaProvider) GetProviderType() string {
return PROVIDER_TYPE_CHROMA
}
func (d *ChromaProvider) QueryEmbedding(
emb []float64,
ctx wrapper.HttpContext,
log wrapper.Log,
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
// The minimum required parameters are collection_id, embeddings and ids
// Below is an example
// {
// "where": {}, // used for metadata filtering, optional
// "where_document": {}, // used for document filtering, optional
// "query_embeddings": [
// [1.1, 2.3, 3.2]
// ],
// "limit": 5,
// "include": [
// "metadatas", // optional
// "documents", // required if the answer is needed
// "distances"
// ]
// }
requestBody, err := json.Marshal(chromaQueryRequest{
QueryEmbeddings: []chromaEmbedding{emb},
Limit: d.config.topK,
Include: []string{"distances", "documents"},
})
if err != nil {
log.Errorf("[Chroma] Failed to marshal query embedding request body: %v", err)
return err
}
return d.client.Post(
fmt.Sprintf("/api/v1/collections/%s/query", d.config.collectionID),
[][2]string{
{"Content-Type", "application/json"},
},
requestBody,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Debugf("[Chroma] Query embedding response: %d, %s", statusCode, responseBody)
results, err := d.parseQueryResponse(responseBody, log)
if err != nil {
err = fmt.Errorf("[Chroma] Failed to parse query response: %v", err)
}
callback(results, ctx, log, err)
},
d.config.timeout,
)
}
func (d *ChromaProvider) UploadAnswerAndEmbedding(
queryString string,
queryEmb []float64,
queryAnswer string,
ctx wrapper.HttpContext,
log wrapper.Log,
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
// The minimum required parameters are collection_id, embeddings and ids
// Below is an example
// {
// "embeddings": [
// [1.1, 2.3, 3.2]
// ],
// "ids": [
// "Have you eaten?"
// ],
// "documents": [
// "I have eaten."
// ]
// }
// To add an answer, follow the example below
// {
// "embeddings": [
// [1.1, 2.3, 3.2]
// ],
// "documents": [
// "answer1"
// ],
// "ids": [
// "id1"
// ]
// }
requestBody, err := json.Marshal(chromaInsertRequest{
Embeddings: []chromaEmbedding{queryEmb},
IDs: []string{queryString}, // queryString is the question the user asked
Documents: []string{queryAnswer}, // queryAnswer is the answer to the user's question
})
if err != nil {
log.Errorf("[Chroma] Failed to marshal upload embedding request body: %v", err)
return err
}
err = d.client.Post(
fmt.Sprintf("/api/v1/collections/%s/add", d.config.collectionID),
[][2]string{
{"Content-Type", "application/json"},
},
requestBody,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Debugf("[Chroma] statusCode:%d, responseBody:%s", statusCode, string(responseBody))
callback(ctx, log, err)
},
d.config.timeout,
)
return err
}
type chromaEmbedding []float64
type chromaMetadataMap map[string]string
type chromaInsertRequest struct {
Embeddings []chromaEmbedding `json:"embeddings"`
Metadatas []chromaMetadataMap `json:"metadatas,omitempty"` // optional
Documents []string `json:"documents,omitempty"` // optional
IDs []string `json:"ids"`
}
type chromaQueryRequest struct {
Where map[string]string `json:"where,omitempty"` // optional
WhereDocument map[string]string `json:"where_document,omitempty"` // optional
QueryEmbeddings []chromaEmbedding `json:"query_embeddings"`
Limit int `json:"limit"`
Include []string `json:"include"`
}
type chromaQueryResponse struct {
Ids [][]string `json:"ids"` // first dimension is the batch query, second dimension is the ids found
Distances [][]float64 `json:"distances,omitempty"` // one-to-one with Ids
Metadatas []chromaMetadataMap `json:"metadatas,omitempty"` // optional
Embeddings []chromaEmbedding `json:"embeddings,omitempty"` // optional
Documents [][]string `json:"documents,omitempty"` // one-to-one with Ids
Uris []string `json:"uris,omitempty"` // optional
Data []interface{} `json:"data,omitempty"` // optional
Included []string `json:"included"`
}
func (d *ChromaProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) {
var queryResp chromaQueryResponse
err := json.Unmarshal(responseBody, &queryResp)
if err != nil {
return nil, err
}
log.Debugf("[Chroma] queryResp Ids len: %d", len(queryResp.Ids))
// Guard the empty outer slice too; otherwise queryResp.Ids[0] below would panic.
if len(queryResp.Ids) == 0 || len(queryResp.Ids[0]) == 0 {
return nil, errors.New("no query results found in response")
}
results := make([]QueryResult, 0, len(queryResp.Ids[0]))
for i := range queryResp.Ids[0] {
result := QueryResult{
Text: queryResp.Ids[0][i],
Score: queryResp.Distances[0][i],
Answer: queryResp.Documents[0][i],
}
results = append(results, result)
}
return results, nil
}
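The Chroma response parsing above indexes `Ids[0]`, `Distances[0]`, and `Documents[0]` in parallel. A defensive variant can be sketched as below. It is a standalone illustration under stated assumptions: `queryResult`, `chromaResponse`, and `parseChromaResponse` are local, hypothetical names that mirror the types in the diff, and the extra length checks are an addition, not the plugin's behavior.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

// queryResult mirrors the fields of the plugin's QueryResult used here.
type queryResult struct {
	Text   string
	Score  float64
	Answer string
}

// chromaResponse keeps only the fields the parser reads.
type chromaResponse struct {
	Ids       [][]string  `json:"ids"`
	Distances [][]float64 `json:"distances"`
	Documents [][]string  `json:"documents"`
}

// parseChromaResponse converts the first batch of a query response into results,
// rejecting responses whose parallel arrays are missing or differ in length.
func parseChromaResponse(body []byte) ([]queryResult, error) {
	var resp chromaResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}
	if len(resp.Ids) == 0 || len(resp.Ids[0]) == 0 {
		return nil, errors.New("no query results found in response")
	}
	if len(resp.Distances) == 0 || len(resp.Documents) == 0 {
		return nil, errors.New("response lacks distances or documents")
	}
	ids, dists, docs := resp.Ids[0], resp.Distances[0], resp.Documents[0]
	if len(dists) != len(ids) || len(docs) != len(ids) {
		return nil, errors.New("ids, distances and documents differ in length")
	}
	results := make([]queryResult, 0, len(ids))
	for i := range ids {
		results = append(results, queryResult{Text: ids[i], Score: dists[i], Answer: docs[i]})
	}
	return results, nil
}

func main() {
	body := []byte(`{"ids":[["q1"]],"distances":[[0.25]],"documents":[["a1"]]}`)
	results, err := parseChromaResponse(body)
	fmt.Println(results, err)
}
```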

View File

@@ -0,0 +1,200 @@
package vector
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)
type esProviderInitializer struct{}
func (c *esProviderInitializer) ValidateConfig(config ProviderConfig) error {
if len(config.collectionID) == 0 {
return errors.New("[ES] collectionID is required")
}
if len(config.serviceName) == 0 {
return errors.New("[ES] serviceName is required")
}
return nil
}
func (c *esProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
return &ESProvider{
config: config,
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: config.serviceName,
Host: config.serviceHost,
Port: int64(config.servicePort),
}),
}, nil
}
type ESProvider struct {
config ProviderConfig
client wrapper.HttpClient
}
func (c *ESProvider) GetProviderType() string {
return PROVIDER_TYPE_ES
}
func (d *ESProvider) QueryEmbedding(
emb []float64,
ctx wrapper.HttpContext,
log wrapper.Log,
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
requestBody, err := json.Marshal(esQueryRequest{
Source: Source{Excludes: []string{"embedding"}},
Knn: knn{
Field: "embedding",
QueryVector: emb,
K: d.config.topK,
},
Size: d.config.topK,
})
if err != nil {
log.Errorf("[ES] Failed to marshal query embedding request body: %v", err)
return err
}
return d.client.Post(
fmt.Sprintf("/%s/_search", d.config.collectionID),
[][2]string{
{"Content-Type", "application/json"},
{"Authorization", d.getCredentials()},
},
requestBody,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Debugf("[ES] Query embedding response: %d, %s", statusCode, responseBody)
results, err := d.parseQueryResponse(responseBody, log)
if err != nil {
err = fmt.Errorf("[ES] Failed to parse query response: %v", err)
}
callback(results, ctx, log, err)
},
d.config.timeout,
)
}
// Base64-encode the ES credentials string, or use the API key
func (d *ESProvider) getCredentials() string {
if len(d.config.apiKey) != 0 {
return fmt.Sprintf("ApiKey %s", d.config.apiKey)
} else {
credentials := fmt.Sprintf("%s:%s", d.config.esUsername, d.config.esPassword)
encodedCredentials := base64.StdEncoding.EncodeToString([]byte(credentials))
return fmt.Sprintf("Basic %s", encodedCredentials)
}
}
func (d *ESProvider) UploadAnswerAndEmbedding(
queryString string,
queryEmb []float64,
queryAnswer string,
ctx wrapper.HttpContext,
log wrapper.Log,
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
// The minimum required parameters are index, embedding and question
// Below is an example, matching the fields of esInsertRequest
// POST /<index>/_doc
// {
// "embedding": [1.1, 2.3, 3.2],
// "question": "Have you eaten?",
// "answer": "I have eaten."
// }
requestBody, err := json.Marshal(esInsertRequest{
Embedding: queryEmb,
Question: queryString,
Answer: queryAnswer,
})
if err != nil {
log.Errorf("[ES] Failed to marshal upload embedding request body: %v", err)
return err
}
return d.client.Post(
fmt.Sprintf("/%s/_doc", d.config.collectionID),
[][2]string{
{"Content-Type", "application/json"},
{"Authorization", d.getCredentials()},
},
requestBody,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Debugf("[ES] statusCode:%d, responseBody:%s", statusCode, string(responseBody))
callback(ctx, log, err)
},
d.config.timeout,
)
}
type esInsertRequest struct {
Embedding []float64 `json:"embedding"`
Question string `json:"question"`
Answer string `json:"answer"`
}
type knn struct {
Field string `json:"field"`
QueryVector []float64 `json:"query_vector"`
K int `json:"k"`
}
type Source struct {
Excludes []string `json:"excludes"`
}
type esQueryRequest struct {
Source Source `json:"_source"`
Knn knn `json:"knn"`
Size int `json:"size"`
}
type esQueryResponse struct {
Took int `json:"took"`
TimedOut bool `json:"timed_out"`
Hits struct {
Total struct {
Value int `json:"value"`
Relation string `json:"relation"`
} `json:"total"`
Hits []struct {
Index string `json:"_index"`
ID string `json:"_id"`
Score float64 `json:"_score"`
Source map[string]interface{} `json:"_source"`
} `json:"hits"`
} `json:"hits"`
}
func (d *ESProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) {
log.Infof("[ES] responseBody: %s", string(responseBody))
var queryResp esQueryResponse
err := json.Unmarshal(responseBody, &queryResp)
if err != nil {
return []QueryResult{}, err
}
log.Debugf("[ES] queryResp Hits len: %d", len(queryResp.Hits.Hits))
if len(queryResp.Hits.Hits) == 0 {
return nil, errors.New("no query results found in response")
}
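// Size the slice by the hits actually returned, not by the total match count.
results := make([]QueryResult, 0, len(queryResp.Hits.Hits))
for _, hit := range queryResp.Hits.Hits {
// Use checked type assertions so a hit without string fields does not panic.
question, ok := hit.Source["question"].(string)
if !ok {
return nil, errors.New("question missing or not a string in hit")
}
answer, ok := hit.Source["answer"].(string)
if !ok {
return nil, errors.New("answer missing or not a string in hit")
}
results = append(results, QueryResult{Text: question, Score: hit.Score, Answer: answer})
}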
return results, nil
}

View File

@@ -0,0 +1,206 @@
package vector
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
type milvusProviderInitializer struct{}
func (c *milvusProviderInitializer) ValidateConfig(config ProviderConfig) error {
if len(config.serviceName) == 0 {
return errors.New("[Milvus] serviceName is required")
}
if len(config.collectionID) == 0 {
return errors.New("[Milvus] collectionID is required")
}
return nil
}
func (c *milvusProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
return &milvusProvider{
config: config,
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: config.serviceName,
Host: config.serviceHost,
Port: int64(config.servicePort),
}),
}, nil
}
type milvusProvider struct {
config ProviderConfig
client wrapper.HttpClient
}
func (c *milvusProvider) GetProviderType() string {
return PROVIDER_TYPE_MILVUS
}
type milvusData struct {
Vector []float64 `json:"vector"`
Question string `json:"question,omitempty"`
Answer string `json:"answer,omitempty"`
}
type milvusInsertRequest struct {
CollectionName string `json:"collectionName"`
Data []milvusData `json:"data"`
}
func (d *milvusProvider) UploadAnswerAndEmbedding(
queryString string,
queryEmb []float64,
queryAnswer string,
ctx wrapper.HttpContext,
log wrapper.Log,
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
// The minimum required parameters are collectionName, data and Authorization. question and answer are optional
// An id must be provided; otherwise v2.4.13-hotfix reports invalid syntax: invalid parameter[expected=Int64][actual=]
// To omit the id, set autoId to true when creating the collection
// Below is an example
// {
// "collectionName": "higress",
// "data": [
// {
// "question": "this is the question",
// "answer": "this is the answer",
// "vector": [
// 0.9,
// 0.1,
// 0.1
// ]
// }
// ]
// }
requestBody, err := json.Marshal(milvusInsertRequest{
CollectionName: d.config.collectionID,
Data: []milvusData{
{
Question: queryString,
Answer: queryAnswer,
Vector: queryEmb,
},
},
})
if err != nil {
log.Errorf("[Milvus] Failed to marshal upload embedding request body: %v", err)
return err
}
return d.client.Post(
"/v2/vectordb/entities/insert",
[][2]string{
{"Content-Type", "application/json"},
{"Authorization", fmt.Sprintf("Bearer %s", d.config.apiKey)},
},
requestBody,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Debugf("[Milvus] statusCode:%d, responseBody:%s", statusCode, string(responseBody))
callback(ctx, log, err)
},
d.config.timeout,
)
}
type milvusQueryRequest struct {
CollectionName string `json:"collectionName"`
Data [][]float64 `json:"data"`
AnnsField string `json:"annsField"`
Limit int `json:"limit"`
OutputFields []string `json:"outputFields"`
}
func (d *milvusProvider) QueryEmbedding(
emb []float64,
ctx wrapper.HttpContext,
log wrapper.Log,
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
// The minimum required parameters are collectionName, data and annsField. outputFields is optional
// Below is an example
// {
// "collectionName": "quick_setup",
// "data": [
// [
// 0.3580376395471989,
// "Unknown type",
// 0.18414012509913835,
// "Unknown type",
// 0.9029438446296592
// ]
// ],
// "annsField": "vector",
// "limit": 3,
// "outputFields": [
// "color"
// ]
// }
requestBody, err := json.Marshal(milvusQueryRequest{
CollectionName: d.config.collectionID,
Data: [][]float64{emb},
AnnsField: "vector",
Limit: d.config.topK,
OutputFields: []string{
"question",
"answer",
},
})
if err != nil {
log.Errorf("[Milvus] Failed to marshal query embedding: %v", err)
return err
}
return d.client.Post(
"/v2/vectordb/entities/search",
[][2]string{
{"Content-Type", "application/json"},
{"Authorization", fmt.Sprintf("Bearer %s", d.config.apiKey)},
},
requestBody,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Debugf("[Milvus] Query embedding response: %d, %s", statusCode, responseBody)
results, err := d.parseQueryResponse(responseBody, log)
if err != nil {
err = fmt.Errorf("[Milvus] Failed to parse query response: %v", err)
}
callback(results, ctx, log, err)
},
d.config.timeout,
)
}
func (d *milvusProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) {
if !gjson.GetBytes(responseBody, "data.0.distance").Exists() {
log.Errorf("[Milvus] No distance found in response body: %s", responseBody)
return nil, errors.New("[Milvus] No distance found in response body")
}
if !gjson.GetBytes(responseBody, "data.0.question").Exists() {
log.Errorf("[Milvus] No question found in response body: %s", responseBody)
return nil, errors.New("[Milvus] No question found in response body")
}
if !gjson.GetBytes(responseBody, "data.0.answer").Exists() {
log.Errorf("[Milvus] No answer found in response body: %s", responseBody)
return nil, errors.New("[Milvus] No answer found in response body")
}
resultNum := gjson.GetBytes(responseBody, "data.#").Int()
results := make([]QueryResult, 0, resultNum)
for i := 0; i < int(resultNum); i++ {
result := QueryResult{
Text: gjson.GetBytes(responseBody, fmt.Sprintf("data.%d.question", i)).String(),
Score: gjson.GetBytes(responseBody, fmt.Sprintf("data.%d.distance", i)).Float(),
Answer: gjson.GetBytes(responseBody, fmt.Sprintf("data.%d.answer", i)).String(),
}
results = append(results, result)
}
return results, nil
}
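The Milvus parser above reads the response with `gjson` paths such as `data.0.distance`. For comparison, the same shape can be decoded with the standard library alone. This is a sketch under stated assumptions: the response layout is inferred from those paths, `milvusSearchResponse` and `parseMilvusResponse` are hypothetical names, and unlike the original it validates every item rather than only the first.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

// queryResult mirrors the fields of the plugin's QueryResult used here.
type queryResult struct {
	Text   string
	Score  float64
	Answer string
}

// milvusSearchResponse uses pointer fields so that missing keys can be
// told apart from zero values.
type milvusSearchResponse struct {
	Data []struct {
		Distance *float64 `json:"distance"`
		Question *string  `json:"question"`
		Answer   *string  `json:"answer"`
	} `json:"data"`
}

// parseMilvusResponse decodes a search response into results and fails if
// any item lacks distance, question or answer.
func parseMilvusResponse(body []byte) ([]queryResult, error) {
	var resp milvusSearchResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, err
	}
	if len(resp.Data) == 0 {
		return nil, errors.New("no results found in response body")
	}
	results := make([]queryResult, 0, len(resp.Data))
	for i, item := range resp.Data {
		if item.Distance == nil || item.Question == nil || item.Answer == nil {
			return nil, fmt.Errorf("result %d is missing distance, question or answer", i)
		}
		results = append(results, queryResult{
			Text:   *item.Question,
			Score:  *item.Distance,
			Answer: *item.Answer,
		})
	}
	return results, nil
}

func main() {
	body := []byte(`{"code":0,"data":[{"distance":0.12,"question":"q1","answer":"a1"}]}`)
	results, err := parseMilvusResponse(body)
	fmt.Println(results, err)
}
```

The `gjson` approach in the diff avoids declaring response types, while typed decoding trades that brevity for a single pass over the body.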

View File

@@ -0,0 +1,194 @@
package vector
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/google/uuid"
"github.com/tidwall/gjson"
)
type pineconeProviderInitializer struct{}
func (c *pineconeProviderInitializer) ValidateConfig(config ProviderConfig) error {
if len(config.serviceHost) == 0 {
return errors.New("[Pinecone] serviceHost is required")
}
if len(config.serviceName) == 0 {
return errors.New("[Pinecone] serviceName is required")
}
if len(config.apiKey) == 0 {
return errors.New("[Pinecone] apiKey is required")
}
return nil
}
func (c *pineconeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
return &pineconeProvider{
config: config,
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: config.serviceName,
Host: config.serviceHost,
Port: int64(config.servicePort),
}),
}, nil
}
type pineconeProvider struct {
config ProviderConfig
client wrapper.HttpClient
}
func (c *pineconeProvider) GetProviderType() string {
return PROVIDER_TYPE_PINECONE
}
type pineconeMetadata struct {
Question string `json:"question"`
Answer string `json:"answer"`
}
type pineconeVector struct {
ID string `json:"id"`
Values []float64 `json:"values"`
Properties pineconeMetadata `json:"metadata"`
}
type pineconeInsertRequest struct {
Vectors []pineconeVector `json:"vectors"`
Namespace string `json:"namespace"`
}
func (d *pineconeProvider) UploadAnswerAndEmbedding(
queryString string,
queryEmb []float64,
queryAnswer string,
ctx wrapper.HttpContext,
log wrapper.Log,
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
// The minimum required parameters are vector and question
// Below is an example
// {
// "vectors": [
// {
// "id": "A",
// "values": [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
// "metadata": {"question": "Hello", "answer": "Hello to you too"}
// }
// ]
// }
requestBody, err := json.Marshal(pineconeInsertRequest{
Vectors: []pineconeVector{
{
ID: uuid.New().String(),
Values: queryEmb,
Properties: pineconeMetadata{Question: queryString, Answer: queryAnswer},
},
},
Namespace: d.config.collectionID,
})
if err != nil {
log.Errorf("[Pinecone] Failed to marshal upload embedding request body: %v", err)
return err
}
return d.client.Post(
"/vectors/upsert",
[][2]string{
{"Content-Type", "application/json"},
{"Api-Key", d.config.apiKey},
},
requestBody,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Debugf("[Pinecone] statusCode:%d, responseBody:%s", statusCode, string(responseBody))
callback(ctx, log, err)
},
d.config.timeout,
)
}
type pineconeQueryRequest struct {
Namespace string `json:"namespace"`
Vector []float64 `json:"vector"`
TopK int `json:"topK"`
IncludeMetadata bool `json:"includeMetadata"`
IncludeValues bool `json:"includeValues"`
}
func (d *pineconeProvider) QueryEmbedding(
emb []float64,
ctx wrapper.HttpContext,
log wrapper.Log,
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
// The minimum required parameter is vector
// Below is an example
// {
// "namespace": "higress",
// "vector": [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
// "topK": 1,
// "includeMetadata": false
// }
requestBody, err := json.Marshal(pineconeQueryRequest{
Namespace: d.config.collectionID,
Vector: emb,
TopK: d.config.topK,
IncludeMetadata: true,
IncludeValues: false,
})
if err != nil {
log.Errorf("[Pinecone] Failed to marshal query embedding: %v", err)
return err
}
return d.client.Post(
"/query",
[][2]string{
{"Content-Type", "application/json"},
{"Api-Key", d.config.apiKey},
},
requestBody,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Debugf("[Pinecone] Query embedding response: %d, %s", statusCode, responseBody)
results, err := d.parseQueryResponse(responseBody, log)
if err != nil {
err = fmt.Errorf("[Pinecone] Failed to parse query response: %v", err)
}
callback(results, ctx, log, err)
},
d.config.timeout,
)
}
func (d *pineconeProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) {
if !gjson.GetBytes(responseBody, "matches.0.score").Exists() {
log.Errorf("[Pinecone] No distance found in response body: %s", responseBody)
return nil, errors.New("[Pinecone] No distance found in response body")
}
if !gjson.GetBytes(responseBody, "matches.0.metadata.question").Exists() {
log.Errorf("[Pinecone] No question found in response body: %s", responseBody)
return nil, errors.New("[Pinecone] No question found in response body")
}
if !gjson.GetBytes(responseBody, "matches.0.metadata.answer").Exists() {
log.Errorf("[Pinecone] No answer found in response body: %s", responseBody)
return nil, errors.New("[Pinecone] No answer found in response body")
}
resultNum := gjson.GetBytes(responseBody, "matches.#").Int()
results := make([]QueryResult, 0, resultNum)
for i := 0; i < int(resultNum); i++ {
result := QueryResult{
Text: gjson.GetBytes(responseBody, fmt.Sprintf("matches.%d.metadata.question", i)).String(),
Score: gjson.GetBytes(responseBody, fmt.Sprintf("matches.%d.score", i)).Float(),
Answer: gjson.GetBytes(responseBody, fmt.Sprintf("matches.%d.metadata.answer", i)).String(),
}
results = append(results, result)
}
return results, nil
}
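The response parsing above can be sketched with only the standard library. This is a minimal illustration, not the plugin's code: the response shape is assumed from the JSON paths the provider reads (`matches.N.score`, `matches.N.metadata.question`, `matches.N.metadata.answer`), and `queryResult`, `pineconeMatch`, and `parseMatches` are hypothetical names. Unlike the original, which checks only the first match, this sketch validates every match.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

// queryResult mirrors the three fields extracted per match above.
type queryResult struct {
	Text   string
	Score  float64
	Answer string
}

// pineconeMatch is a hypothetical type covering only the JSON paths read above.
// Pointers distinguish "field absent" from "zero value".
type pineconeMatch struct {
	Score    *float64 `json:"score"`
	Metadata *struct {
		Question *string `json:"question"`
		Answer   *string `json:"answer"`
	} `json:"metadata"`
}

type pineconeQueryResponse struct {
	Matches []pineconeMatch `json:"matches"`
}

// parseMatches converts a query response body into results,
// failing if any match lacks a score, question, or answer.
func parseMatches(body []byte) ([]queryResult, error) {
	var resp pineconeQueryResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, fmt.Errorf("invalid JSON: %w", err)
	}
	if len(resp.Matches) == 0 {
		return nil, errors.New("no matches found in response body")
	}
	results := make([]queryResult, 0, len(resp.Matches))
	for i, m := range resp.Matches {
		if m.Score == nil || m.Metadata == nil ||
			m.Metadata.Question == nil || m.Metadata.Answer == nil {
			return nil, fmt.Errorf("match %d is missing score, question, or answer", i)
		}
		results = append(results, queryResult{
			Text:   *m.Metadata.Question,
			Score:  *m.Score,
			Answer: *m.Metadata.Answer,
		})
	}
	return results, nil
}

func main() {
	body := []byte(`{"matches":[{"score":0.9,"metadata":{"question":"hi","answer":"hello"}}]}`)
	results, err := parseMatches(body)
	fmt.Println(results, err)
}
```

Decoding into typed structs trades the flexibility of path lookups for compile-time field names; whether that fits the Wasm build constraints of the plugin is not something this sketch verifies.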

View File

@@ -10,6 +10,11 @@ import (
const (
PROVIDER_TYPE_DASH_VECTOR = "dashvector"
PROVIDER_TYPE_CHROMA = "chroma"
PROVIDER_TYPE_ES = "elasticsearch"
PROVIDER_TYPE_WEAVIATE = "weaviate"
PROVIDER_TYPE_PINECONE = "pinecone"
PROVIDER_TYPE_QDRANT = "qdrant"
PROVIDER_TYPE_MILVUS = "milvus"
)
type providerInitializer interface {
@@ -20,7 +25,12 @@ type providerInitializer interface {
var (
providerInitializers = map[string]providerInitializer{
PROVIDER_TYPE_DASH_VECTOR: &dashVectorProviderInitializer{},
// PROVIDER_TYPE_CHROMA: &chromaProviderInitializer{},
PROVIDER_TYPE_CHROMA: &chromaProviderInitializer{},
PROVIDER_TYPE_ES: &esProviderInitializer{},
PROVIDER_TYPE_WEAVIATE: &weaviateProviderInitializer{},
PROVIDER_TYPE_PINECONE: &pineconeProviderInitializer{},
PROVIDER_TYPE_QDRANT: &qdrantProviderInitializer{},
PROVIDER_TYPE_MILVUS: &milvusProviderInitializer{},
}
)
@@ -71,10 +81,6 @@ type StringQuerier interface {
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error
}
type SimilarityThresholdProvider interface {
GetSimilarityThreshold() float64
}
type ProviderConfig struct {
// @Title en-US Vector store provider type
// @Description en-US Type of the vector store provider, for example dashvector or chroma
@@ -97,8 +103,8 @@ type ProviderConfig struct {
// @Title en-US Request timeout
// @Description en-US Timeout for requests to the vector store service, in milliseconds. Defaults to 10000, i.e. 10 seconds
timeout uint32
// @Title en-US DashVector vector store Collection ID
// @Description en-US Collection ID of the DashVector vector store service
// @Title en-US Vector store Collection ID
// @Description en-US Collection ID of the vector store service
collectionID string
// @Title en-US Similarity threshold
// @Description en-US Default similarity threshold. Defaults to 1000.
@@ -109,6 +115,14 @@ type ProviderConfig struct {
// so the comparison must be configurable: choose gt for Cosine and DotProduct, and lt for Euclidean.
// Defaults to lt. Allowed values are lt (less than), lte (less than or equal to), gt (greater than), and gte (greater than or equal to).
ThresholdRelation string
// ES configuration
// @Title en-US ES username
// @Description en-US ES username
esUsername string
// @Title en-US ES password
// @Description en-US ES password
esPassword string
}
func (c *ProviderConfig) GetProviderType() string {
@@ -117,7 +131,6 @@ func (c *ProviderConfig) GetProviderType() string {
func (c *ProviderConfig) FromJson(json gjson.Result) {
c.typ = json.Get("type").String()
// DashVector
c.serviceName = json.Get("serviceName").String()
c.serviceHost = json.Get("serviceHost").String()
c.servicePort = int64(json.Get("servicePort").Int())
@@ -142,6 +155,10 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
if c.ThresholdRelation == "" {
c.ThresholdRelation = "lt"
}
// ES
c.esUsername = json.Get("esUsername").String()
c.esPassword = json.Get("esPassword").String()
}
func (c *ProviderConfig) Validate() error {
@@ -152,6 +169,9 @@ func (c *ProviderConfig) Validate() error {
if !has {
return errors.New("unknown vector database service provider type: " + c.typ)
}
if !isRelationValid(c.ThresholdRelation) {
return errors.New("invalid thresholdRelation: " + c.ThresholdRelation)
}
if err := initializer.ValidateConfig(*c); err != nil {
return err
}
@@ -165,3 +185,12 @@ func CreateProvider(pc ProviderConfig) (Provider, error) {
}
return initializer.CreateProvider(pc)
}
func isRelationValid(relation string) bool {
for _, r := range []string{"lt", "lte", "gt", "gte"} {
if r == relation {
return true
}
}
return false
}
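The four accepted `thresholdRelation` values suggest how a score is compared against the configured threshold. The following is a minimal sketch of that comparison, assuming the relation is applied as `score <relation> threshold`; `meetsThreshold` is a hypothetical helper and is not taken from the plugin.

```go
package main

import "fmt"

// meetsThreshold reports whether score satisfies the relation against threshold.
// The relation names match the values accepted by isRelationValid above;
// the function itself is an illustrative assumption.
func meetsThreshold(relation string, score, threshold float64) (bool, error) {
	switch relation {
	case "lt":
		return score < threshold, nil
	case "lte":
		return score <= threshold, nil
	case "gt":
		return score > threshold, nil
	case "gte":
		return score >= threshold, nil
	default:
		return false, fmt.Errorf("invalid thresholdRelation: %s", relation)
	}
}

func main() {
	// Cosine-style similarity: larger is closer, so "gt" is the natural choice.
	ok, err := meetsThreshold("gt", 0.92, 0.8)
	fmt.Println(ok, err)
	// Euclidean-style distance: smaller is closer, so "lt" is the natural choice.
	ok, err = meetsThreshold("lt", 0.15, 0.5)
	fmt.Println(ok, err)
}
```

Returning an error for unknown relations keeps the helper safe even if validation was skipped earlier.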

View File

@@ -0,0 +1,208 @@
package vector
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/google/uuid"
"github.com/tidwall/gjson"
)
type qdrantProviderInitializer struct{}
func (c *qdrantProviderInitializer) ValidateConfig(config ProviderConfig) error {
if len(config.serviceName) == 0 {
return errors.New("[Qdrant] serviceName is required")
}
if len(config.collectionID) == 0 {
return errors.New("[Qdrant] collectionID is required")
}
return nil
}
func (c *qdrantProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
return &qdrantProvider{
config: config,
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: config.serviceName,
Host: config.serviceHost,
Port: int64(config.servicePort),
}),
}, nil
}
type qdrantProvider struct {
config ProviderConfig
client wrapper.HttpClient
}
func (c *qdrantProvider) GetProviderType() string {
return PROVIDER_TYPE_QDRANT
}
type qdrantPayload struct {
Question string `json:"question"`
Answer string `json:"answer"`
}
type qdrantPoint struct {
ID string `json:"id"`
Vector []float64 `json:"vector"`
Payload qdrantPayload `json:"payload"`
}
type qdrantInsertRequest struct {
Points []qdrantPoint `json:"points"`
}
func (d *qdrantProvider) UploadAnswerAndEmbedding(
queryString string,
queryEmb []float64,
queryAnswer string,
ctx wrapper.HttpContext,
log wrapper.Log,
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
// The minimum required parameters are id and vector; payload is optional.
// Example request body:
// {
// "points": [
// {
// "id": "76874cce-1fb9-4e16-9b0b-f085ac06ed6f",
// "payload": {
// "question": "the question goes here",
// "answer": "the answer goes here"
// },
// "vector": [
// 0.9,
// 0.1,
// 0.1
// ]
// }
// ]
// }
requestBody, err := json.Marshal(qdrantInsertRequest{
Points: []qdrantPoint{
{
ID: uuid.New().String(),
Vector: queryEmb,
Payload: qdrantPayload{Question: queryString, Answer: queryAnswer},
},
},
})
if err != nil {
log.Errorf("[Qdrant] Failed to marshal upload embedding request body: %v", err)
return err
}
return d.client.Put(
fmt.Sprintf("/collections/%s/points", d.config.collectionID),
[][2]string{
{"Content-Type", "application/json"},
{"api-key", d.config.apiKey},
},
requestBody,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Debugf("[Qdrant] statusCode:%d, responseBody:%s", statusCode, string(responseBody))
callback(ctx, log, err)
},
d.config.timeout,
)
}
type qdrantQueryRequest struct {
Vector []float64 `json:"vector"`
Limit int `json:"limit"`
WithPayload bool `json:"with_payload"`
}
func (d *qdrantProvider) QueryEmbedding(
emb []float64,
ctx wrapper.HttpContext,
log wrapper.Log,
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
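// The minimum required parameters are vector and limit. with_payload is optional, but it is set here so that the question and answer are returned directly.
// Example request body: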
// {
// "vector": [
// 0.2,
// 0.1,
// 0.9,
// 0.7
// ],
// "limit": 1
// }
requestBody, err := json.Marshal(qdrantQueryRequest{
Vector: emb,
Limit: d.config.topK,
WithPayload: true,
})
if err != nil {
log.Errorf("[Qdrant] Failed to marshal query embedding: %v", err)
return err
}
return d.client.Post(
fmt.Sprintf("/collections/%s/points/search", d.config.collectionID),
[][2]string{
{"Content-Type", "application/json"},
{"api-key", d.config.apiKey},
},
requestBody,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Debugf("[Qdrant] Query embedding response: %d, %s", statusCode, responseBody)
results, err := d.parseQueryResponse(responseBody, log)
if err != nil {
err = fmt.Errorf("[Qdrant] Failed to parse query response: %v", err)
}
callback(results, ctx, log, err)
},
d.config.timeout,
)
}
func (d *qdrantProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) {
// Example response body:
// {
// "time": 0.002,
// "status": "ok",
// "result": [
// {
// "id": 42,
// "version": 3,
// "score": 0.75,
// "payload": {
// "question": "London",
// "answer": "green"
// },
// "shard_key": "region_1",
// "order_value": 42
// }
// ]
// }
if !gjson.GetBytes(responseBody, "result.0.score").Exists() {
log.Errorf("[Qdrant] No distance found in response body: %s", responseBody)
return nil, errors.New("[Qdrant] No distance found in response body")
}
if !gjson.GetBytes(responseBody, "result.0.payload.answer").Exists() {
log.Errorf("[Qdrant] No answer found in response body: %s", responseBody)
return nil, errors.New("[Qdrant] No answer found in response body")
}
resultNum := gjson.GetBytes(responseBody, "result.#").Int()
results := make([]QueryResult, 0, resultNum)
for i := 0; i < int(resultNum); i++ {
result := QueryResult{
Text: gjson.GetBytes(responseBody, fmt.Sprintf("result.%d.payload.question", i)).String(),
Score: gjson.GetBytes(responseBody, fmt.Sprintf("result.%d.score", i)).Float(),
Answer: gjson.GetBytes(responseBody, fmt.Sprintf("result.%d.payload.answer", i)).String(),
}
results = append(results, result)
}
return results, nil
}
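To see what the upsert request body looks like on the wire, the request types from this file can be marshalled in isolation. This is a sketch only: the struct definitions are copied from the diff above, while `buildUpsertBody` is a hypothetical helper and the point ID is passed in as a fixed string rather than generated with `uuid.New()`, so that the example needs only the standard library.

```go
package main

import (
	"encoding/json"
	"fmt"
)

type qdrantPayload struct {
	Question string `json:"question"`
	Answer   string `json:"answer"`
}

type qdrantPoint struct {
	ID      string        `json:"id"`
	Vector  []float64     `json:"vector"`
	Payload qdrantPayload `json:"payload"`
}

type qdrantInsertRequest struct {
	Points []qdrantPoint `json:"points"`
}

// buildUpsertBody is a hypothetical helper that assembles the JSON body
// for a single point holding one question/answer pair.
func buildUpsertBody(id string, vector []float64, question, answer string) ([]byte, error) {
	return json.Marshal(qdrantInsertRequest{
		Points: []qdrantPoint{
			{
				ID:      id,
				Vector:  vector,
				Payload: qdrantPayload{Question: question, Answer: answer},
			},
		},
	})
}

func main() {
	body, err := buildUpsertBody(
		"76874cce-1fb9-4e16-9b0b-f085ac06ed6f",
		[]float64{0.9, 0.1, 0.1},
		"What is Higress?",
		"A cloud-native API gateway.",
	)
	if err != nil {
		fmt.Println("marshal failed:", err)
		return
	}
	fmt.Println(string(body))
}
```

`encoding/json` emits struct fields in declaration order, so the keys appear as `id`, `vector`, `payload`.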

View File

@@ -0,0 +1,188 @@
package vector
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
type weaviateProviderInitializer struct{}
func (c *weaviateProviderInitializer) ValidateConfig(config ProviderConfig) error {
if len(config.collectionID) == 0 {
return errors.New("[Weaviate] collectionID is required")
}
if len(config.serviceName) == 0 {
return errors.New("[Weaviate] serviceName is required")
}
return nil
}
func (c *weaviateProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
return &WeaviateProvider{
config: config,
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: config.serviceName,
Host: config.serviceHost,
Port: int64(config.servicePort),
}),
}, nil
}
type WeaviateProvider struct {
config ProviderConfig
client wrapper.HttpClient
}
func (c *WeaviateProvider) GetProviderType() string {
return PROVIDER_TYPE_WEAVIATE
}
func (d *WeaviateProvider) QueryEmbedding(
emb []float64,
ctx wrapper.HttpContext,
log wrapper.Log,
callback func(results []QueryResult, ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
// The minimum required parameters are class and vector.
// Example request body:
// {"query": "{ Get { Higress ( limit: 2 nearVector: { vector: [0.1, 0.2, 0.3] } ) { question _additional { distance } } } }"}
embString, err := json.Marshal(emb)
if err != nil {
log.Errorf("[Weaviate] Failed to marshal query embedding: %v", err)
return err
}
// Results are returned in ascending order of distance by default, so no additional sorting is needed.
graphql := fmt.Sprintf(`
{
Get {
%s (
limit: %d
nearVector: {
vector: %s
}
) {
question
answer
_additional {
distance
}
}
}
}
`, d.config.collectionID, d.config.topK, embString)
requestBody, err := json.Marshal(weaviateQueryRequest{
Query: graphql,
})
if err != nil {
log.Errorf("[Weaviate] Failed to marshal query embedding request body: %v", err)
return err
}
err = d.client.Post(
"/v1/graphql",
[][2]string{
{"Content-Type", "application/json"},
{"Authorization", fmt.Sprintf("Bearer %s", d.config.apiKey)},
},
requestBody,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Debugf("[Weaviate] Query embedding response: %d, %s", statusCode, responseBody)
results, err := d.parseQueryResponse(responseBody, log)
if err != nil {
err = fmt.Errorf("[Weaviate] Failed to parse query response: %v", err)
}
callback(results, ctx, log, err)
},
d.config.timeout,
)
return err
}
func (d *WeaviateProvider) UploadAnswerAndEmbedding(
queryString string,
queryEmb []float64,
queryAnswer string,
ctx wrapper.HttpContext,
log wrapper.Log,
callback func(ctx wrapper.HttpContext, log wrapper.Log, err error)) error {
// The minimum required parameters are class, vector, question, and answer.
// Example request body:
// {"class": "Higress", "vector": [0.1, 0.2, 0.3], "properties": {"question": "the question goes here", "answer": "the answer goes here"}}
requestBody, err := json.Marshal(weaviateInsertRequest{
Class: d.config.collectionID,
Vector: queryEmb,
Properties: weaviateProperties{Question: queryString, Answer: queryAnswer}, // queryString is the question asked by the user
})
if err != nil {
log.Errorf("[Weaviate] Failed to marshal upload embedding request body: %v", err)
return err
}
return d.client.Post(
"/v1/objects",
[][2]string{
{"Content-Type", "application/json"},
{"Authorization", fmt.Sprintf("Bearer %s", d.config.apiKey)},
},
requestBody,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Debugf("[Weaviate] statusCode: %d, responseBody: %s", statusCode, string(responseBody))
callback(ctx, log, err)
},
d.config.timeout,
)
}
type weaviateProperties struct {
Question string `json:"question"`
Answer string `json:"answer"`
}
type weaviateInsertRequest struct {
Class string `json:"class"`
Vector []float64 `json:"vector"`
Properties weaviateProperties `json:"properties"`
}
type weaviateQueryRequest struct {
Query string `json:"query"`
}
func (d *WeaviateProvider) parseQueryResponse(responseBody []byte, log wrapper.Log) ([]QueryResult, error) {
log.Infof("[Weaviate] queryResp: %s", string(responseBody))
if !gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.0._additional.distance", d.config.collectionID)).Exists() {
log.Errorf("[Weaviate] No distance found in response body: %s", responseBody)
return nil, errors.New("[Weaviate] No distance found in response body")
}
if !gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.0.question", d.config.collectionID)).Exists() {
log.Errorf("[Weaviate] No question found in response body: %s", responseBody)
return nil, errors.New("[Weaviate] No question found in response body")
}
if !gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.0.answer", d.config.collectionID)).Exists() {
log.Errorf("[Weaviate] No answer found in response body: %s", responseBody)
return nil, errors.New("[Weaviate] No answer found in response body")
}
resultNum := gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.#", d.config.collectionID)).Int()
results := make([]QueryResult, 0, resultNum)
for i := 0; i < int(resultNum); i++ {
result := QueryResult{
Text: gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.%d.question", d.config.collectionID, i)).String(),
Score: gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.%d._additional.distance", d.config.collectionID, i)).Float(),
Answer: gjson.GetBytes(responseBody, fmt.Sprintf("data.Get.%s.%d.answer", d.config.collectionID, i)).String(),
}
results = append(results, result)
}
return results, nil
}
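The GraphQL request built in `QueryEmbedding` can be reproduced on its own to inspect the final JSON body. The sketch below follows the same approach (marshal the vector, interpolate it into a `Get` query, wrap the query in a `{"query": ...}` object) but collapses the query onto one line; `buildNearVectorQuery` is a hypothetical name, and the class name `Higress` is taken from the comment example above.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// buildNearVectorQuery returns the JSON request body for a nearVector search.
// class and limit correspond to collectionID and topK in the provider config.
func buildNearVectorQuery(class string, limit int, vector []float64) (string, error) {
	vec, err := json.Marshal(vector)
	if err != nil {
		return "", err
	}
	// %s on a []byte prints its contents as text, so vec is inlined as a JSON array.
	graphql := fmt.Sprintf(
		"{ Get { %s ( limit: %d nearVector: { vector: %s } ) { question answer _additional { distance } } } }",
		class, limit, vec,
	)
	body, err := json.Marshal(map[string]string{"query": graphql})
	if err != nil {
		return "", err
	}
	return string(body), nil
}

func main() {
	body, err := buildNearVectorQuery("Higress", 2, []float64{0.1, 0.2, 0.3})
	if err != nil {
		fmt.Println("build failed:", err)
		return
	}
	fmt.Println(body)
}
```

Because the class name is interpolated directly into the query text, a caller would presumably want to restrict it to a known-safe identifier; the sketch does not do this.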

View File

@@ -8,8 +8,8 @@ replace github.com/alibaba/higress/plugins/wasm-go => ../..
require (
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240528060522-53bccf89f441
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
github.com/tidwall/gjson v1.14.3
github.com/higress-group/proxy-wasm-go-sdk v1.0.0
github.com/tidwall/gjson v1.17.3
github.com/tidwall/resp v0.1.1
)

View File

@@ -5,12 +5,14 @@ github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbG
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=

View File

@@ -6,7 +6,7 @@ replace github.com/alibaba/higress/plugins/wasm-go => ../..
require (
github.com/alibaba/higress/plugins/wasm-go v1.4.2
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
github.com/higress-group/proxy-wasm-go-sdk v1.0.0
)
require (
@@ -14,7 +14,7 @@ require (
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
github.com/magefile/mage v1.14.0 // indirect
github.com/santhosh-tekuri/jsonschema v1.2.4 // indirect
github.com/tidwall/gjson v1.14.3 // indirect
github.com/tidwall/gjson v1.17.3 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/resp v0.1.1 // indirect

View File

@@ -9,6 +9,7 @@ github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240318034951-d5306e367c43/go
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@@ -17,6 +18,7 @@ github.com/santhosh-tekuri/jsonschema v1.2.4/go.mod h1:TEAUOeZSmIxTTuHatJzrvARHi
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=

View File

@@ -6,8 +6,8 @@ replace github.com/alibaba/higress/plugins/wasm-go => ../..
require (
github.com/alibaba/higress/plugins/wasm-go v1.3.5
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
github.com/tidwall/gjson v1.14.3
github.com/higress-group/proxy-wasm-go-sdk v1.0.0
github.com/tidwall/gjson v1.17.3
)
require (

View File

@@ -5,6 +5,7 @@ github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbG
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@@ -12,6 +13,7 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=

View File

@@ -6,8 +6,8 @@ replace github.com/alibaba/higress/plugins/wasm-go => ../..
require (
github.com/alibaba/higress/plugins/wasm-go v1.3.5
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
github.com/tidwall/gjson v1.14.3
github.com/higress-group/proxy-wasm-go-sdk v1.0.0
github.com/tidwall/gjson v1.17.3
)
require (

View File

@@ -8,6 +8,7 @@ github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@@ -15,6 +16,7 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=

View File

@@ -31,15 +31,16 @@ description: AI 代理插件配置参考
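The configuration fields of `provider` are described below:
| Name | Data type | Requirement | Default | Description |
| -------------- | --------------- | -------- | ------ | ------------------------------------------------------------ |
| `type` | string | Required | - | Name of the AI service provider |
| `apiTokens` | array of string | Optional | - | Tokens used for authentication when accessing the AI service. If multiple tokens are configured, the plugin picks one at random for each request. Some service providers support only a single token. |
| `timeout` | number | Optional | - | Timeout for accessing the AI service, in milliseconds. Defaults to 120000, i.e. 2 minutes |
| `modelMapping` | map of string | Optional | - | AI model mapping table, used to map model names in requests to model names supported by the service provider.<br/>1. Prefix matching is supported. For example, "gpt-3-*" matches all models whose names start with "gpt-3-";<br/>2. "*" can be used as a key to configure a catch-all fallback mapping;<br/>3. If the mapped target name is the empty string "", the original model name is kept. |
| `protocol` | string | Optional | - | API contract exposed by the plugin. Supported values: openai (default; uses the OpenAI API contract) and original (uses the target service provider's native API contract) |
| `context` | object | Optional | - | Configures AI conversation context information |
| `customSettings` | array of customSetting | Optional | - | Specifies parameters to override or fill in for AI requests |
| Name | Data type | Requirement | Default | Description |
|------------------| --------------- | -------- | ------ |------------------------------------------------------------|
| `type` | string | Required | - | Name of the AI service provider |
| `apiTokens` | array of string | Optional | - | Tokens used for authentication when accessing the AI service. If multiple tokens are configured, the plugin picks one at random for each request. Some service providers support only a single token. |
| `timeout` | number | Optional | - | Timeout for accessing the AI service, in milliseconds. Defaults to 120000, i.e. 2 minutes |
| `modelMapping` | map of string | Optional | - | AI model mapping table, used to map model names in requests to model names supported by the service provider.<br/>1. Prefix matching is supported. For example, "gpt-3-*" matches all models whose names start with "gpt-3-";<br/>2. "*" can be used as a key to configure a catch-all fallback mapping;<br/>3. If the mapped target name is the empty string "", the original model name is kept. |
| `protocol` | string | Optional | - | API contract exposed by the plugin. Supported values: openai (default; uses the OpenAI API contract) and original (uses the target service provider's native API contract) |
| `context` | object | Optional | - | Configures AI conversation context information |
| `customSettings` | array of customSetting | Optional | - | Specifies parameters to override or fill in for AI requests |
| `failover` | object | Optional | - | Configures the failover policy for apiTokens. When an apiToken becomes unavailable it is removed from the apiToken list, and it is added back once health checks pass |
The configuration fields of `context` are described below: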
@@ -75,6 +76,16 @@ custom-setting会遵循如下表格根据`name`和协议来替换对应的字
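If raw mode is enabled, custom-setting uses the given `name` and `value` directly to modify the JSON content of the request, without restricting or altering the parameter name.
For most protocols, custom-setting modifies or fills in parameters at the root path of the JSON content. For the `qwen` protocol, ai-proxy applies the settings under the `parameters` sub-path. For the `gemini` protocol, they are applied under the `generation_config` sub-path.
The configuration fields of `failover` are described below:
| Name | Data type | Requirement | Default | Description |
|------------------|--------|------|-------|-----------------------------|
| enabled | bool | Optional | false | Whether to enable the apiToken failover mechanism |
| failureThreshold | int | Optional | 3 | Number of consecutive failed requests that triggers failover |
| successThreshold | int | Optional | 1 | Number of successful health checks required before a token is restored |
| healthCheckInterval | int | Optional | 5000 | Interval between health checks, in milliseconds |
| healthCheckTimeout | int | Optional | 5000 | Timeout for health checks, in milliseconds |
| healthCheckModel | string | Required | | Model used for health checks |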
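The interaction between `failureThreshold` and `successThreshold` can be illustrated with a small in-memory model. This is a sketch under stated assumptions, not the plugin's implementation: `tokenPool` and its methods are hypothetical, a token is assumed to be removed once its consecutive failure count reaches `failureThreshold`, and real health checks (interval, timeout, model) are replaced by an explicit `onHealthCheck` call.

```go
package main

import "fmt"

// tokenPool is a hypothetical model of apiToken failover bookkeeping.
type tokenPool struct {
	failureThreshold int
	successThreshold int
	available        map[string]bool
	failures         map[string]int // consecutive request failures per token
	successes        map[string]int // consecutive healthy checks per removed token
}

func newTokenPool(tokens []string, failureThreshold, successThreshold int) *tokenPool {
	p := &tokenPool{
		failureThreshold: failureThreshold,
		successThreshold: successThreshold,
		available:        map[string]bool{},
		failures:         map[string]int{},
		successes:        map[string]int{},
	}
	for _, t := range tokens {
		p.available[t] = true
	}
	return p
}

// onRequestFailed counts a failure and removes the token at the threshold.
func (p *tokenPool) onRequestFailed(token string) {
	if !p.available[token] {
		return
	}
	p.failures[token]++
	if p.failures[token] >= p.failureThreshold {
		p.available[token] = false
		p.failures[token] = 0
		p.successes[token] = 0
	}
}

// onRequestSucceeded resets the consecutive failure count.
func (p *tokenPool) onRequestSucceeded(token string) {
	p.failures[token] = 0
}

// onHealthCheck records a health check result for a removed token and
// restores the token after enough consecutive healthy results.
func (p *tokenPool) onHealthCheck(token string, healthy bool) {
	if p.available[token] {
		return
	}
	if !healthy {
		p.successes[token] = 0
		return
	}
	p.successes[token]++
	if p.successes[token] >= p.successThreshold {
		p.available[token] = true
		p.successes[token] = 0
	}
}

func main() {
	p := newTokenPool([]string{"token-a", "token-b"}, 3, 1)
	for i := 0; i < 3; i++ {
		p.onRequestFailed("token-a")
	}
	fmt.Println("token-a available after 3 failures:", p.available["token-a"])
	p.onHealthCheck("token-a", true)
	fmt.Println("token-a available after healthy check:", p.available["token-a"])
}
```

Resetting the failure count on every success is what makes the threshold apply to consecutive failures rather than to a running total.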
### Provider-specific configuration

View File

@@ -1,9 +1,9 @@
package config
import (
"github.com/tidwall/gjson"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
// @Name ai-proxy
@@ -75,13 +75,17 @@ func (c *PluginConfig) Validate() error {
return nil
}
func (c *PluginConfig) Complete() error {
func (c *PluginConfig) Complete(log wrapper.Log) error {
if c.activeProviderConfig == nil {
c.activeProvider = nil
return nil
}
var err error
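c.activeProvider, err = provider.CreateProvider(*c.activeProviderConfig)
// Return early so that a provider creation error is not overwritten below.
if err != nil {
return err
}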
providerConfig := c.GetProviderConfig()
err = providerConfig.SetApiTokensFailover(log, c.activeProvider)
return err
}

View File

@@ -44,9 +44,10 @@ func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig, log
if err := pluginConfig.Validate(); err != nil {
return err
}
if err := pluginConfig.Complete(); err != nil {
if err := pluginConfig.Complete(log); err != nil {
return err
}
return nil
}
@@ -59,9 +60,10 @@ func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, plug
if err := pluginConfig.Validate(); err != nil {
return err
}
if err := pluginConfig.Complete(); err != nil {
if err := pluginConfig.Complete(log); err != nil {
return err
}
return nil
}
@@ -80,7 +82,13 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
path, _ := url.Parse(rawPath)
apiName := getOpenAiApiName(path.Path)
providerConfig := pluginConfig.GetProviderConfig()
if apiName == "" && !providerConfig.IsOriginal() {
if providerConfig.IsOriginal() {
if handler, ok := activeProvider.(provider.ApiNameHandler); ok {
apiName = handler.GetApiName(path.Path)
}
}
if apiName == "" {
log.Debugf("[onHttpRequestHeader] unsupported path: %s", path.Path)
// _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path)
log.Debugf("[onHttpRequestHeader] no send response")
@@ -89,8 +97,11 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
ctx.SetContext(ctxKeyApiName, apiName)
if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok {
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
ctx.DisableReroute()
// Set the apiToken for the current request.
providerConfig.SetApiTokenInUse(ctx, log)
hasRequestBody := wrapper.HasRequestBody()
action, err := handler.OnRequestHeaders(ctx, apiName, log)
if err == nil {
@@ -102,6 +113,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
}
return action
}
_ = util.SendResponse(500, "ai-proxy.proc_req_headers_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process request headers: %v", err))
return types.ActionContinue
}
@@ -156,15 +168,24 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
log.Debugf("[onHttpResponseHeaders] provider=%s", activeProvider.GetProviderType())
providerConfig := pluginConfig.GetProviderConfig()
apiTokenInUse := providerConfig.GetApiTokenInUse(ctx)
status, err := proxywasm.GetHttpResponseHeader(":status")
if err != nil || status != "200" {
if err != nil {
log.Errorf("unable to load :status header from response: %v", err)
}
ctx.DontReadResponseBody()
providerConfig.OnRequestFailed(ctx, apiTokenInUse, log)
return types.ActionContinue
}
// Reset ctxApiTokenRequestFailureCount if the request is successful,
// the apiToken is removed only when the number of consecutive request failures exceeds the threshold.
providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse, log)
if handler, ok := activeProvider.(provider.ResponseHeadersHandler); ok {
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
action, err := handler.OnResponseHeaders(ctx, apiName, log)
@@ -233,16 +254,6 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
return types.ActionContinue
}
func getOpenAiApiName(path string) provider.ApiName {
if strings.HasSuffix(path, "/v1/chat/completions") {
return provider.ApiNameChatCompletion
}
if strings.HasSuffix(path, "/v1/embeddings") {
return provider.ApiNameEmbeddings
}
return ""
}
func checkStream(ctx *wrapper.HttpContext, log *wrapper.Log) {
contentType, err := proxywasm.GetHttpResponseHeader("Content-Type")
if err != nil || !strings.HasPrefix(contentType, "text/event-stream") {
@@ -252,3 +263,13 @@ func checkStream(ctx *wrapper.HttpContext, log *wrapper.Log) {
(*ctx).BufferResponseBody()
}
}
func getOpenAiApiName(path string) provider.ApiName {
if strings.HasSuffix(path, "/v1/chat/completions") {
return provider.ApiNameChatCompletion
}
if strings.HasSuffix(path, "/v1/embeddings") {
return provider.ApiNameEmbeddings
}
return ""
}
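The suffix-based routing in `getOpenAiApiName` can be exercised in isolation. The sketch below mirrors its logic; `apiName`, the two constant values, and `apiNameForPath` are hypothetical stand-ins, since the real `provider.ApiName` constants are defined outside this diff.

```go
package main

import (
	"fmt"
	"strings"
)

// apiName stands in for provider.ApiName; the string values are assumptions.
type apiName string

const (
	apiNameChatCompletion apiName = "chatCompletion"
	apiNameEmbeddings     apiName = "embeddings"
)

// apiNameForPath maps a request path to an API name by suffix,
// returning "" for paths that are not recognized.
func apiNameForPath(path string) apiName {
	if strings.HasSuffix(path, "/v1/chat/completions") {
		return apiNameChatCompletion
	}
	if strings.HasSuffix(path, "/v1/embeddings") {
		return apiNameEmbeddings
	}
	return ""
}

func main() {
	paths := []string{"/v1/chat/completions", "/api/openai/v1/embeddings", "/v1/models"}
	for _, p := range paths {
		fmt.Printf("%q -> %q\n", p, apiNameForPath(p))
	}
}
```

Matching on the suffix rather than the full path lets the same check work when the gateway route adds a prefix in front of the OpenAI-style path.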

View File

@@ -1,14 +1,12 @@
package provider
import (
"encoding/json"
"errors"
"fmt"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
// ai360Provider is the provider for 360 OpenAI service.
@@ -46,10 +44,7 @@ func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestHost(ai360Domain)
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
_ = proxywasm.ReplaceHttpRequestHeader("Authorization", m.config.GetRandomToken())
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
@@ -58,47 +53,12 @@ func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
if apiName == ApiNameChatCompletion {
return m.onChatCompletionRequestBody(ctx, body, log)
}
if apiName == ApiNameEmbeddings {
return m.onEmbeddingsRequestBody(ctx, body, log)
}
return types.ActionContinue, errUnsupportedApiName
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
func (m *ai360Provider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
if request.Model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
// Map the model
mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
request.Model = mappedModel
return types.ActionContinue, replaceJsonRequestBody(request, log)
}
func (m *ai360Provider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
request := &embeddingsRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.Model == "" {
return types.ActionContinue, errors.New("missing model in embeddings request")
}
// Map the model
mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
request.Model = mappedModel
return types.ActionContinue, replaceJsonRequestBody(request, log)
func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, ai360Domain)
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
}


@@ -3,16 +3,15 @@ package provider
import (
"errors"
"fmt"
"net/http"
"net/url"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
// azureProvider is the provider for Azure OpenAI service.
type azureProviderInitializer struct {
}
@@ -55,47 +54,23 @@ func (m *azureProvider) GetProviderType() string {
}
func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
_ = util.OverwriteRequestPath(m.serviceUrl.RequestURI())
_ = util.OverwriteRequestHost(m.serviceUrl.Host)
_ = proxywasm.ReplaceHttpRequestHeader("api-key", m.config.apiTokens[0])
if apiName == ApiNameChatCompletion {
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
} else {
ctx.DontReadRequestBody()
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
// We don't need to process the request body for other APIs.
return types.ActionContinue, nil
return types.ActionContinue, errUnsupportedApiName
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
if m.contextCache == nil {
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.openai.set_include_usage_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
return types.ActionContinue, nil
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.azure.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.azure.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, m.serviceUrl.RequestURI())
util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host)
headers.Set("api-key", m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}


@@ -2,11 +2,10 @@ package provider
import (
"errors"
"fmt"
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -47,10 +46,7 @@ func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(baichuanChatCompletionPath)
_ = util.OverwriteRequestHost(baichuanDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -58,28 +54,12 @@ func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
if m.contextCache == nil {
return types.ActionContinue, nil
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.baichuan.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.baichuan.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
func (m *baichuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, baichuanChatCompletionPath)
util.OverwriteRequestHostHeader(headers, baichuanDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}


@@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
@@ -16,7 +17,8 @@ import (
// baiduProvider is the provider for baidu ernie bot service.
const (
baiduDomain = "aip.baidubce.com"
baiduDomain = "aip.baidubce.com"
baiduChatCompletionPath = "/chat"
)
var baiduModelToPathSuffixMap = map[string]string{
@@ -60,98 +62,35 @@ func (b *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestHost(baiduDomain)
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
b.config.handleRequestHeaders(b, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
func (b *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, baiduDomain)
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
}
func (b *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
// Use the native ERNIE Bot (Wenxin Yiyan) API protocol
if b.config.protocol == protocolOriginal {
request := &baiduTextGenRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.Model == "" {
return types.ActionContinue, errors.New("request model is empty")
}
// Rewrite requestPath according to the model
path := b.getRequestPath(request.Model)
_ = util.OverwriteRequestPath(path)
return b.config.handleRequestBody(b, b.contextCache, ctx, apiName, body, log)
}
if b.config.context == nil {
return types.ActionContinue, nil
}
err := b.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.baidu.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
b.setSystemContent(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.baidu.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}
func (b *baiduProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
err := b.config.parseRequestAndMapModel(ctx, request, body, log)
if err != nil {
return nil, err
}
path := b.getRequestPath(ctx, request.Model)
util.OverwriteRequestPathHeader(headers, path)
// Map the model and rewrite requestPath
model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
ctx.SetContext(ctxKeyOriginalRequestModel, model)
mappedModel := getMappedModel(model, b.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
path := b.getRequestPath(mappedModel)
_ = util.OverwriteRequestPath(path)
if b.config.context == nil {
baiduRequest := b.baiduTextGenRequest(request)
return types.ActionContinue, replaceJsonRequestBody(baiduRequest, log)
}
err := b.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.baidu.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
baiduRequest := b.baiduTextGenRequest(request)
if err := replaceJsonRequestBody(baiduRequest, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.baidu.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
baiduRequest := b.baiduTextGenRequest(request)
return json.Marshal(baiduRequest)
}
func (b *baiduProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
@@ -226,13 +165,13 @@ type baiduTextGenRequest struct {
UserId string `json:"user_id,omitempty"`
}
func (b *baiduProvider) getRequestPath(baiduModel string) string {
func (b *baiduProvider) getRequestPath(ctx wrapper.HttpContext, baiduModel string) string {
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
suffix, ok := baiduModelToPathSuffixMap[baiduModel]
if !ok {
suffix = baiduModel
}
return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, b.config.GetRandomToken())
return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, b.config.GetApiTokenInUse(ctx))
}
func (b *baiduProvider) setSystemContent(request *baiduTextGenRequest, content string) {
@@ -339,3 +278,10 @@ func (b *baiduProvider) streamResponseBaidu2OpenAI(ctx wrapper.HttpContext, resp
func (b *baiduProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}
func (b *baiduProvider) GetApiName(path string) ApiName {
if strings.Contains(path, baiduChatCompletionPath) {
return ApiNameChatCompletion
}
return ""
}


@@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
@@ -105,102 +106,39 @@ func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
c.config.handleRequestHeaders(c, ctx, apiName, log)
return types.ActionContinue, nil
}
_ = util.OverwriteRequestPath(claudeChatCompletionPath)
_ = util.OverwriteRequestHost(claudeDomain)
_ = proxywasm.ReplaceHttpRequestHeader("x-api-key", c.config.GetRandomToken())
func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, claudeChatCompletionPath)
util.OverwriteRequestHostHeader(headers, claudeDomain)
headers.Add("x-api-key", c.config.GetApiTokenInUse(ctx))
if c.config.claudeVersion == "" {
c.config.claudeVersion = defaultVersion
}
_ = proxywasm.AddHttpRequestHeader("anthropic-version", c.config.claudeVersion)
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
return types.ActionContinue, nil
headers.Add("anthropic-version", c.config.claudeVersion)
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
}
func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
}
// use original protocol
if c.config.protocol == protocolOriginal {
if c.config.context == nil {
return types.ActionContinue, nil
}
request := &claudeTextGenRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
err := c.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.claude.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.claude.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}
// use openai protocol
func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
if err := c.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
return nil, err
}
model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
ctx.SetContext(ctxKeyOriginalRequestModel, model)
mappedModel := getMappedModel(model, c.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
streaming := request.Stream
if streaming {
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
}
if c.config.context == nil {
claudeRequest := c.buildClaudeTextGenRequest(request)
return types.ActionContinue, replaceJsonRequestBody(claudeRequest, log)
}
err := c.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.claude.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
claudeRequest := c.buildClaudeTextGenRequest(request)
if err := replaceJsonRequestBody(claudeRequest, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.claude.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
claudeRequest := c.buildClaudeTextGenRequest(request)
return json.Marshal(claudeRequest)
}
func (c *claudeProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -369,3 +307,25 @@ func createChatCompletionResponse(ctx wrapper.HttpContext, response *claudeTextG
func (c *claudeProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}
func (c *claudeProvider) insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) {
request := &claudeTextGenRequest{}
if err := json.Unmarshal(body, request); err != nil {
return nil, fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.System == "" {
request.System = content
} else {
request.System = content + "\n" + request.System
}
return json.Marshal(request)
}
func (c *claudeProvider) GetApiName(path string) ApiName {
if strings.Contains(path, claudeChatCompletionPath) {
return ApiNameChatCompletion
}
return ""
}
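
Review note: `insertHttpContextMessage` on `claudeProvider` acts as an optional hook. Providers that define it get custom context insertion, while the rest fall back to a default. The sketch below shows that optional-interface dispatch in isolation; every type and function name in it is hypothetical and simplified (strings instead of JSON bodies), not the plugin's actual API.

```go
package main

import "fmt"

// contextInserter is a hypothetical, simplified optional hook.
type contextInserter interface {
	insertContext(body, content string) string
}

// plainProvider does not implement the hook and gets the default behavior.
type plainProvider struct{}

// systemFieldProvider implements the hook by prepending the content,
// loosely modeled on how the Claude provider prepends to its System field.
type systemFieldProvider struct{}

func (systemFieldProvider) insertContext(body, content string) string {
	if body == "" {
		return content
	}
	return content + "\n" + body
}

// defaultInsert is an assumed fallback used when no hook is present.
func defaultInsert(body, content string) string {
	return "[system] " + content + " | " + body
}

// applyContext picks the hook via a type assertion, else the default.
func applyContext(provider interface{}, body, content string) string {
	if ins, ok := provider.(contextInserter); ok {
		return ins.insertContext(body, content)
	}
	return defaultInsert(body, content)
}

func main() {
	fmt.Println(applyContext(systemFieldProvider{}, "be brief", "ctx"))
	fmt.Println(applyContext(plainProvider{}, "hi", "ctx"))
}
```

The type assertion keeps the base provider interface small: only providers with a non-OpenAI request shape need to implement the extra method.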


@@ -2,12 +2,11 @@ package provider
import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -47,13 +46,7 @@ func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName A
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1))
_ = util.OverwriteRequestHost(cloudflareDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + c.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
c.config.handleRequestHeaders(c, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -61,49 +54,13 @@ func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiN
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
ctx.SetContext(ctxKeyOriginalRequestModel, model)
mappedModel := getMappedModel(model, c.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
streaming := request.Stream
if streaming {
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
}
if c.contextCache == nil {
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.cloudflare.transform_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
return types.ActionContinue, nil
}
err := c.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.cloudflare.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.cloudflare.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
}
func (c *cloudflareProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1))
util.OverwriteRequestHostHeader(headers, cloudflareDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+c.config.GetApiTokenInUse(ctx))
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
}
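
Review note: the new `TransformRequestHeaders` methods edit an `http.Header` value instead of calling the proxy-wasm header API directly. A rough standalone sketch of that style follows, using only the standard library. The path template and the function names are hypothetical, since the real `cloudflareChatCompletionPath` value and the `util` helpers are not shown in this diff.

```go
package main

import (
	"fmt"
	"net/http"
	"strings"
)

// pathTemplate is a hypothetical template used only for this sketch.
const pathTemplate = "/accounts/{account_id}/chat"

// sampleHeaders builds an incoming header set for the demo.
func sampleHeaders() http.Header {
	h := http.Header{}
	h.Set("Content-Length", "42")
	h.Set("Accept-Encoding", "gzip")
	return h
}

// transformHeaders returns a modified copy of the headers plus the
// upstream path with the account placeholder filled in.
func transformHeaders(in http.Header, accountID, token string) (http.Header, string) {
	out := in.Clone()
	path := strings.Replace(pathTemplate, "{account_id}", accountID, 1)
	out.Set("Authorization", "Bearer "+token)
	// The body may be rewritten later, so the stale length and the
	// encoding preference are dropped.
	out.Del("Accept-Encoding")
	out.Del("Content-Length")
	return out, path
}

func main() {
	out, path := transformHeaders(sampleHeaders(), "acct-123", "token-abc")
	fmt.Println(path)
	fmt.Println(out.Get("Authorization"))
	fmt.Println(out.Get("Content-Length") == "")
}
```

Working on a plain header map makes the transformation testable without a running proxy, which appears to be one motivation for the refactoring.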


@@ -3,17 +3,16 @@ package provider
import (
"encoding/json"
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
"strings"
)
const (
cohereDomain = "api.cohere.com"
chatCompletionPath = "/v1/chat"
cohereDomain = "api.cohere.com"
cohereChatCompletionPath = "/v1/chat"
)
type cohereProviderInitializer struct{}
@@ -27,12 +26,14 @@ func (m *cohereProviderInitializer) ValidateConfig(config ProviderConfig) error
func (m *cohereProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
return &cohereProvider{
config: config,
config: config,
contextCache: createContextCache(&config),
}, nil
}
type cohereProvider struct {
config ProviderConfig
config ProviderConfig
contextCache *contextCache
}
type cohereTextGenRequest struct {
@@ -57,10 +58,7 @@ func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestHost(cohereDomain)
_ = util.OverwriteRequestPath(chatCompletionPath)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -68,30 +66,7 @@ func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
if m.config.protocol == protocolOriginal {
request := &cohereTextGenRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
return m.handleRequestBody(log, request)
}
origin := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, origin); err != nil {
return types.ActionContinue, err
}
request := m.buildCohereRequest(origin)
return m.handleRequestBody(log, request)
}
func (m *cohereProvider) handleRequestBody(log wrapper.Log, request interface{}) (types.Action, error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
err := replaceJsonRequestBody(request, log)
if err != nil {
_ = util.SendResponse(500, "ai-proxy.cohere.proxy_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
return types.ActionContinue, err
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
func (m *cohereProvider) buildCohereRequest(origin *chatCompletionRequest) *cohereTextGenRequest {
@@ -112,3 +87,27 @@ func (m *cohereProvider) buildCohereRequest(origin *chatCompletionRequest) *cohe
PresencePenalty: origin.PresencePenalty,
}
}
func (m *cohereProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, cohereChatCompletionPath)
util.OverwriteRequestHostHeader(headers, cohereDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}
func (m *cohereProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
return nil, err
}
cohereRequest := m.buildCohereRequest(request)
return json.Marshal(cohereRequest)
}
func (m *cohereProvider) GetApiName(path string) ApiName {
if strings.Contains(path, cohereChatCompletionPath) {
return ApiNameChatCompletion
}
return ""
}


@@ -1,12 +1,15 @@
package provider
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/tidwall/gjson"
)
@@ -57,6 +60,10 @@ type contextCache struct {
content string
}
type ContextInserter interface {
insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error)
}
func (c *contextCache) GetContent(callback func(string, error), log wrapper.Log) error {
if callback == nil {
return errors.New("callback is nil")
@@ -98,3 +105,79 @@ func createContextCache(providerConfig *ProviderConfig) *contextCache {
timeout: providerConfig.timeout,
}
}
func (c *contextCache) GetContextFromFile(ctx wrapper.HttpContext, provider Provider, body []byte, log wrapper.Log) error {
// Fetching the context file overwrites the original request host and path,
// so save them here in case they are needed for the apiToken health check.
ctx.SetContext(ctxRequestHost, wrapper.GetRequestHost())
ctx.SetContext(ctxRequestPath, wrapper.GetRequestPath())
if c.loaded {
log.Debugf("context file loaded from cache")
insertContext(provider, c.content, nil, body, log)
return nil
}
log.Infof("loading context file from %s", c.fileUrl.String())
return c.client.Get(c.fileUrl.Path, nil, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
if statusCode != http.StatusOK {
insertContext(provider, "", fmt.Errorf("failed to load context file, status: %d", statusCode), nil, log)
return
}
c.content = string(responseBody)
c.loaded = true
log.Debugf("content: %s", c.content)
insertContext(provider, c.content, nil, body, log)
}, c.timeout)
}
func insertContext(provider Provider, content string, err error, body []byte, log wrapper.Log) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
typ := provider.GetProviderType()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.load_ctx_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
if inserter, ok := provider.(ContextInserter); ok {
body, err = inserter.insertHttpContextMessage(body, content, false)
} else {
body, err = defaultInsertHttpContextMessage(body, content)
}
if err != nil {
_ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.insert_ctx_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to insert context message: %v", err))
}
if err := replaceHttpJsonRequestBody(body, log); err != nil {
_ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.replace_request_body_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}
func defaultInsertHttpContextMessage(body []byte, content string) ([]byte, error) {
request := &chatCompletionRequest{}
if err := json.Unmarshal(body, request); err != nil {
return nil, fmt.Errorf("unable to unmarshal request: %v", err)
}
fileMessage := chatMessage{
Role: roleSystem,
Content: content,
}
var firstNonSystemMessageIndex int
for i, message := range request.Messages {
if message.Role != roleSystem {
firstNonSystemMessageIndex = i
break
}
}
if firstNonSystemMessageIndex == 0 {
request.Messages = append([]chatMessage{fileMessage}, request.Messages...)
} else {
request.Messages = append(request.Messages[:firstNonSystemMessageIndex], append([]chatMessage{fileMessage}, request.Messages[firstNonSystemMessageIndex:]...)...)
}
return json.Marshal(request)
}
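
Review note: `defaultInsertHttpContextMessage` places the context file as a system message just before the first non-system message, and prepends it when that index is 0. The sketch below reproduces that placement rule on a plain slice, without JSON handling; the types are simplified stand-ins, and it builds a fresh slice rather than appending in place.

```go
package main

import "fmt"

// chatMessage is a simplified stand-in for the plugin's message type.
type chatMessage struct {
	Role    string
	Content string
}

const roleSystem = "system"

// insertContextMessage returns a new slice with a system message holding
// content placed before the first non-system message. As in the diff, the
// index defaults to 0, so the message is prepended when the first message
// is non-system, when every message is a system message, or when the
// list is empty.
func insertContextMessage(messages []chatMessage, content string) []chatMessage {
	idx := 0
	for i, m := range messages {
		if m.Role != roleSystem {
			idx = i
			break
		}
	}
	out := make([]chatMessage, 0, len(messages)+1)
	out = append(out, messages[:idx]...)
	out = append(out, chatMessage{Role: roleSystem, Content: content})
	out = append(out, messages[idx:]...)
	return out
}

func main() {
	msgs := []chatMessage{
		{Role: roleSystem, Content: "You are helpful."},
		{Role: "user", Content: "Hello"},
	}
	for _, m := range insertContextMessage(msgs, "file context") {
		fmt.Printf("%s: %s\n", m.Role, m.Content)
	}
}
```

Copying into a new slice avoids any aliasing with the caller's backing array, at the cost of one extra allocation.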


@@ -2,6 +2,7 @@ package provider
import (
"errors"
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
@@ -38,7 +39,12 @@ func (m *cozeProvider) GetProviderType() string {
}
func (m *cozeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
_ = util.OverwriteRequestHost(cozeDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
func (m *cozeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, cozeDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}


@@ -4,6 +4,8 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
@@ -78,49 +80,38 @@ func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(deeplChatCompletionPath)
_ = util.OverwriteRequestAuthorization("DeepL-Auth-Key " + d.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
d.config.handleRequestHeaders(d, ctx, apiName, log)
return types.HeaderStopIteration, nil
}
func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, deeplChatCompletionPath)
util.OverwriteRequestAuthorizationHeader(headers, "DeepL-Auth-Key "+d.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
headers.Del("Accept-Encoding")
}
func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
if d.config.protocol == protocolOriginal {
request := &deeplRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if err := d.overwriteRequestHost(request.Model); err != nil {
return types.ActionContinue, err
}
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
return types.ActionContinue, replaceJsonRequestBody(request, log)
} else {
originRequest := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, originRequest); err != nil {
return types.ActionContinue, err
}
if err := d.overwriteRequestHost(originRequest.Model); err != nil {
return types.ActionContinue, err
}
ctx.SetContext(ctxKeyFinalRequestModel, originRequest.Model)
deeplRequest := &deeplRequest{
Text: make([]string, 0),
TargetLang: d.config.targetLang,
}
for _, msg := range originRequest.Messages {
if msg.Role == roleSystem {
deeplRequest.Context = msg.StringContent()
} else {
deeplRequest.Text = append(deeplRequest.Text, msg.StringContent())
}
}
return types.ActionContinue, replaceJsonRequestBody(deeplRequest, log)
return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log)
}
func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return nil, err
}
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
err := d.overwriteRequestHost(headers, request.Model)
if err != nil {
return nil, err
}
deeplRequest := d.deeplTextGenRequest(request)
return json.Marshal(deeplRequest)
}
func (d *deeplProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
@@ -164,13 +155,35 @@ func (d *deeplProvider) responseDeepl2OpenAI(ctx wrapper.HttpContext, deeplRespo
}
}
func (d *deeplProvider) overwriteRequestHost(model string) error {
func (d *deeplProvider) overwriteRequestHost(headers http.Header, model string) error {
if model == "Pro" {
_ = util.OverwriteRequestHost(deeplHostPro)
util.OverwriteRequestHostHeader(headers, deeplHostPro)
} else if model == "Free" {
_ = util.OverwriteRequestHost(deeplHostFree)
util.OverwriteRequestHostHeader(headers, deeplHostFree)
} else {
return errors.New(`deepl model should be "Free" or "Pro"`)
}
return nil
}
func (d *deeplProvider) deeplTextGenRequest(request *chatCompletionRequest) *deeplRequest {
deeplRequest := &deeplRequest{
Text: make([]string, 0),
TargetLang: d.config.targetLang,
}
for _, msg := range request.Messages {
if msg.Role == roleSystem {
deeplRequest.Context = msg.StringContent()
} else {
deeplRequest.Text = append(deeplRequest.Text, msg.StringContent())
}
}
return deeplRequest
}
func (d *deeplProvider) GetApiName(path string) ApiName {
if strings.Contains(path, deeplChatCompletionPath) {
return ApiNameChatCompletion
}
return ""
}
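
The message-to-translation mapping done by deeplTextGenRequest above can be sketched in isolation. This is a minimal sketch under stated assumptions, not the plugin's code: chatMessage, translateRequest and toTranslateRequest are hypothetical stand-ins for the plugin's own types, and the "system" role string is assumed to match roleSystem.

```go
package main

import "fmt"

// chatMessage is a hypothetical, simplified stand-in for an
// OpenAI-style chat message.
type chatMessage struct {
	Role    string
	Content string
}

// translateRequest is a hypothetical, simplified stand-in for the
// translation request body built by the provider.
type translateRequest struct {
	Text       []string
	TargetLang string
	Context    string
}

// toTranslateRequest maps chat messages onto a translation request:
// a system message becomes the context, every other message becomes
// a text entry to translate. If several system messages are present,
// the last one wins, mirroring the loop in the diff.
func toTranslateRequest(messages []chatMessage, targetLang string) translateRequest {
	req := translateRequest{Text: make([]string, 0), TargetLang: targetLang}
	for _, msg := range messages {
		if msg.Role == "system" { // assumed value of roleSystem
			req.Context = msg.Content
		} else {
			req.Text = append(req.Text, msg.Content)
		}
	}
	return req
}

func main() {
	req := toTranslateRequest([]chatMessage{
		{Role: "system", Content: "money"},
		{Role: "user", Content: "bank"},
		{Role: "user", Content: "account"},
	}, "DE")
	fmt.Printf("%+v\n", req)
}
```

Keeping the mapping in a pure function like this is what lets the same conversion serve both the live request path and the health check request.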


@@ -2,12 +2,10 @@ package provider
import (
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
)
// deepseekProvider is the provider for the DeepSeek AI service.
@@ -47,10 +45,7 @@ func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(deepseekChatCompletionPath)
_ = util.OverwriteRequestHost(deepseekDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -58,28 +53,12 @@ func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
if m.contextCache == nil {
return types.ActionContinue, nil
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.deepseek.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.deepseek.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
func (m *deepseekProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, deepseekChatCompletionPath)
util.OverwriteRequestHostHeader(headers, deepseekDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}


@@ -2,12 +2,11 @@ package provider
import (
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
"strings"
)
const (
@@ -41,17 +40,10 @@ func (m *doubaoProvider) GetProviderType() string {
}
func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
_ = util.OverwriteRequestHost(doubaoDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
if m.config.protocol == protocolOriginal {
ctx.DontReadRequestBody()
return types.ActionContinue, nil
}
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(doubaoChatCompletionPath)
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -59,44 +51,19 @@ func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
mappedModel := getMappedModel(model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel
if m.contextCache != nil {
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.doubao.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.doubao.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
} else {
return types.ActionContinue, err
}
} else {
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.doubao.transform_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
return types.ActionContinue, err
}
_ = proxywasm.ResumeHttpRequest()
return types.ActionPause, nil
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
func (m *doubaoProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, doubaoChatCompletionPath)
util.OverwriteRequestHostHeader(headers, doubaoDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}
func (m *doubaoProvider) GetApiName(path string) ApiName {
if strings.Contains(path, doubaoChatCompletionPath) {
return ApiNameChatCompletion
}
return ""
}


@@ -0,0 +1,594 @@
package provider
import (
"encoding/json"
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/google/uuid"
"math/rand"
"net/http"
"strings"
"time"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
)
type failover struct {
// @Title en-US Whether to enable the apiToken failover mechanism
enabled bool `required:"true" yaml:"enabled" json:"enabled"`
// @Title en-US Number of consecutive request failures that triggers failover
failureThreshold int64 `required:"false" yaml:"failureThreshold" json:"failureThreshold"`
// @Title en-US Success threshold for health checks
successThreshold int64 `required:"false" yaml:"successThreshold" json:"successThreshold"`
// @Title en-US Health check interval, in milliseconds
healthCheckInterval int64 `required:"false" yaml:"healthCheckInterval" json:"healthCheckInterval"`
// @Title en-US Health check timeout, in milliseconds
healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"`
// @Title en-US Model used for health checks
healthCheckModel string `required:"true" yaml:"healthCheckModel" json:"healthCheckModel"`
// @Title en-US apiToken used for the current request
ctxApiTokenInUse string
// @Title en-US Records the number of failed requests per apiToken; key is the apiToken, value is the failure count
ctxApiTokenRequestFailureCount string
// @Title en-US Records the number of successful health checks per apiToken; key is the apiToken, value is the success count
ctxApiTokenRequestSuccessCount string
// @Title en-US List of all available apiTokens
ctxApiTokens string
// @Title en-US List of all unavailable apiTokens
ctxUnavailableApiTokens string
// @Title en-US Records the request's cluster, host and path, used to build the health check request
ctxHealthCheckEndpoint string
// @Title en-US Leader election for health checks; only the Wasm VM elected as leader performs health checks
ctxVmLease string
ctxVmLease string
}
type Lease struct {
VMID string `json:"vmID"`
Timestamp int64 `json:"timestamp"`
}
type HealthCheckEndpoint struct {
Host string `json:"host"`
Path string `json:"path"`
Cluster string `json:"cluster"`
}
const (
casMaxRetries = 10
addApiTokenOperation = "addApiToken"
removeApiTokenOperation = "removeApiToken"
addApiTokenRequestCountOperation = "addApiTokenRequestCount"
resetApiTokenRequestCountOperation = "resetApiTokenRequestCount"
ctxRequestHost = "requestHost"
ctxRequestPath = "requestPath"
)
var (
healthCheckClient wrapper.HttpClient
)
func (f *failover) FromJson(json gjson.Result) {
f.enabled = json.Get("enabled").Bool()
f.failureThreshold = json.Get("failureThreshold").Int()
if f.failureThreshold == 0 {
f.failureThreshold = 3
}
f.successThreshold = json.Get("successThreshold").Int()
if f.successThreshold == 0 {
f.successThreshold = 1
}
f.healthCheckInterval = json.Get("healthCheckInterval").Int()
if f.healthCheckInterval == 0 {
f.healthCheckInterval = 5000
}
f.healthCheckTimeout = json.Get("healthCheckTimeout").Int()
if f.healthCheckTimeout == 0 {
f.healthCheckTimeout = 5000
}
f.healthCheckModel = json.Get("healthCheckModel").String()
}
func (f *failover) Validate() error {
if f.healthCheckModel == "" {
return errors.New("missing healthCheckModel in failover config")
}
return nil
}
func (c *ProviderConfig) initVariable() {
// Set provider name as prefix to differentiate shared data
provider := c.GetType()
c.failover.ctxApiTokenInUse = provider + "-apiTokenInUse"
c.failover.ctxApiTokenRequestFailureCount = provider + "-apiTokenRequestFailureCount"
c.failover.ctxApiTokenRequestSuccessCount = provider + "-apiTokenRequestSuccessCount"
c.failover.ctxApiTokens = provider + "-apiTokens"
c.failover.ctxUnavailableApiTokens = provider + "-unavailableApiTokens"
c.failover.ctxHealthCheckEndpoint = provider + "-requestHostAndPath"
c.failover.ctxVmLease = provider + "-vmLease"
}
func parseConfig(json gjson.Result, config *any, log wrapper.Log) error {
return nil
}
func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log, activeProvider Provider) error {
c.initVariable()
// Reset shared data in case plugin configuration is updated
log.Debugf("ai-proxy plugin configuration is updated, reset shared data")
c.resetSharedData()
if c.isFailoverEnabled() {
log.Debugf("ai-proxy plugin failover is enabled")
vmID := generateVMID()
err := c.initApiTokens()
if err != nil {
return fmt.Errorf("failed to init apiTokens: %v", err)
}
wrapper.RegisteTickFunc(c.failover.healthCheckInterval, func() {
// Only the Wasm VM that successfully acquires the lease will perform health check
if c.isFailoverEnabled() && c.tryAcquireOrRenewLease(vmID, log) {
log.Debugf("Successfully acquired or renewed lease for %v: %v", vmID, c.GetType())
unavailableTokens, _, err := getApiTokens(c.failover.ctxUnavailableApiTokens)
if err != nil {
log.Errorf("Failed to get unavailable tokens: %v", err)
return
}
if len(unavailableTokens) > 0 {
for _, apiToken := range unavailableTokens {
log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(unavailableTokens, ", "))
healthCheckEndpoint, headers, body := c.generateRequestHeadersAndBody(log)
healthCheckClient = wrapper.NewClusterClient(wrapper.TargetCluster{
Host: healthCheckEndpoint.Host,
Cluster: healthCheckEndpoint.Cluster,
})
ctx := createHttpContext()
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
modifiedHeaders, modifiedBody, err := c.transformRequestHeadersAndBody(ctx, activeProvider, headers, body, log)
if err != nil {
log.Errorf("Failed to transform request headers and body: %v", err)
// Skip this apiToken rather than posting an untransformed (nil) request
continue
}
// The apiToken for ChatCompletion and Embeddings can be the same, so we only need to health check ChatCompletion
err = healthCheckClient.Post(healthCheckEndpoint.Path, modifiedHeaders, modifiedBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
if statusCode == 200 {
c.handleAvailableApiToken(apiToken, log)
}
}, uint32(c.failover.healthCheckTimeout))
if err != nil {
log.Errorf("Failed to perform health check request: %v", err)
}
}
}
}
})
}
return nil
}
func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, headers [][2]string, body []byte, log wrapper.Log) ([][2]string, []byte, error) {
originalHeaders := util.SliceToHeader(headers)
if handler, ok := activeProvider.(TransformRequestHeadersHandler); ok {
handler.TransformRequestHeaders(ctx, ApiNameChatCompletion, originalHeaders, log)
}
var err error
if handler, ok := activeProvider.(TransformRequestBodyHandler); ok {
body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body, log)
} else if handler, ok := activeProvider.(TransformRequestBodyHeadersHandler); ok {
headers := util.GetOriginalHttpHeaders()
body, err = handler.TransformRequestBodyHeaders(ctx, ApiNameChatCompletion, body, originalHeaders, log)
util.ReplaceOriginalHttpHeaders(headers)
} else {
body, err = c.defaultTransformRequestBody(ctx, ApiNameChatCompletion, body, log)
}
if err != nil {
return nil, nil, fmt.Errorf("failed to transform request body: %v", err)
}
modifiedHeaders := util.HeaderToSlice(originalHeaders)
return modifiedHeaders, body, nil
}
func createHttpContext() *wrapper.CommonHttpCtx[any] {
setParseConfig := wrapper.ParseConfigBy[any](parseConfig)
vmCtx := wrapper.NewCommonVmCtx[any]("health-check", setParseConfig)
pluginCtx := vmCtx.NewPluginContext(rand.Uint32())
ctx := pluginCtx.NewHttpContext(rand.Uint32()).(*wrapper.CommonHttpCtx[any])
return ctx
}
func (c *ProviderConfig) generateRequestHeadersAndBody(log wrapper.Log) (HealthCheckEndpoint, [][2]string, []byte) {
data, _, err := proxywasm.GetSharedData(c.failover.ctxHealthCheckEndpoint)
if err != nil {
log.Errorf("Failed to get request host and path: %v", err)
}
var healthCheckEndpoint HealthCheckEndpoint
err = json.Unmarshal(data, &healthCheckEndpoint)
if err != nil {
log.Errorf("Failed to unmarshal request host and path: %v", err)
}
headers := [][2]string{
{"content-type", "application/json"},
}
body := []byte(fmt.Sprintf(`{
"model": "%s",
"messages": [
{
"role": "user",
"content": "who are you?"
}
]
}`, c.failover.healthCheckModel))
return healthCheckEndpoint, headers, body
}
func (c *ProviderConfig) tryAcquireOrRenewLease(vmID string, log wrapper.Log) bool {
now := time.Now().Unix()
data, cas, err := proxywasm.GetSharedData(c.failover.ctxVmLease)
if err != nil {
if errors.Is(err, types.ErrorStatusNotFound) {
return c.setLease(vmID, now, cas, log)
} else {
log.Errorf("Failed to get lease: %v", err)
return false
}
}
if data == nil {
return c.setLease(vmID, now, cas, log)
}
var lease Lease
err = json.Unmarshal(data, &lease)
if err != nil {
log.Errorf("Failed to unmarshal lease data: %v", err)
return false
}
// If vmID is itself, try to renew the lease directly
// If the lease is expired (60s), try to acquire the lease
if lease.VMID == vmID || now-lease.Timestamp > 60 {
lease.VMID = vmID
lease.Timestamp = now
return c.setLease(vmID, now, cas, log)
}
return false
}
func (c *ProviderConfig) setLease(vmID string, timestamp int64, cas uint32, log wrapper.Log) bool {
lease := Lease{
VMID: vmID,
Timestamp: timestamp,
}
leaseByte, err := json.Marshal(lease)
if err != nil {
log.Errorf("Failed to marshal lease data: %v", err)
return false
}
if err := proxywasm.SetSharedData(c.failover.ctxVmLease, leaseByte, cas); err != nil {
log.Errorf("Failed to set or renew lease: %v", err)
return false
}
return true
}
func generateVMID() string {
return uuid.New().String()
}
// When the number of successful health checks reaches the threshold,
// add the apiToken back to the available list and remove it from the unavailable list
func (c *ProviderConfig) handleAvailableApiToken(apiToken string, log wrapper.Log) {
successApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount)
if err != nil {
log.Errorf("Failed to get successApiTokenRequestCount: %v", err)
return
}
successCount := successApiTokenRequestCount[apiToken] + 1
if successCount >= c.failover.successThreshold {
log.Infof("apiToken %s is available now, add it back to the apiTokens list", apiToken)
removeApiToken(c.failover.ctxUnavailableApiTokens, apiToken, log)
addApiToken(c.failover.ctxApiTokens, apiToken, log)
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken, log)
} else {
log.Debugf("apiToken %s is still unavailable, the number of health check passed: %d, continue to health check...", apiToken, successCount)
addApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken, log)
}
}
// When the number of request failures reaches the threshold,
// remove the apiToken from the available list and add it to the unavailable list
func (c *ProviderConfig) handleUnavailableApiToken(ctx wrapper.HttpContext, apiToken string, log wrapper.Log) {
failureApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount)
if err != nil {
log.Errorf("Failed to get failureApiTokenRequestCount: %v", err)
return
}
availableTokens, _, err := getApiTokens(c.failover.ctxApiTokens)
if err != nil {
log.Errorf("Failed to get available apiToken: %v", err)
return
}
// unavailable apiToken has been removed from the available list
if !containsElement(availableTokens, apiToken) {
return
}
failureCount := failureApiTokenRequestCount[apiToken] + 1
if failureCount >= c.failover.failureThreshold {
log.Infof("apiToken %s is unavailable now, remove it from apiTokens list", apiToken)
removeApiToken(c.failover.ctxApiTokens, apiToken, log)
addApiToken(c.failover.ctxUnavailableApiTokens, apiToken, log)
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken, log)
// Set the request host and path to shared data in case they are needed in apiToken health check
c.setHealthCheckEndpoint(ctx, log)
} else {
log.Debugf("apiToken %s is still available as it has not reached the failure threshold, the number of failed request: %d", apiToken, failureCount)
addApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken, log)
}
}
func addApiToken(key, apiToken string, log wrapper.Log) {
modifyApiToken(key, apiToken, addApiTokenOperation, log)
}
func removeApiToken(key, apiToken string, log wrapper.Log) {
modifyApiToken(key, apiToken, removeApiTokenOperation, log)
}
func modifyApiToken(key, apiToken, op string, log wrapper.Log) {
for attempt := 1; attempt <= casMaxRetries; attempt++ {
apiTokens, cas, err := getApiTokens(key)
if err != nil {
log.Errorf("Failed to get %s: %v", key, err)
continue
}
exists := containsElement(apiTokens, apiToken)
if op == addApiTokenOperation && exists {
log.Debugf("%s already exists in %s", apiToken, key)
return
} else if op == removeApiTokenOperation && !exists {
log.Debugf("%s does not exist in %s", apiToken, key)
return
}
if op == addApiTokenOperation {
apiTokens = append(apiTokens, apiToken)
} else {
apiTokens = removeElement(apiTokens, apiToken)
}
if err := setApiTokens(key, apiTokens, cas); err == nil {
log.Debugf("Successfully updated %s in %s", apiToken, key)
return
} else if !errors.Is(err, types.ErrorStatusCasMismatch) {
log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err)
return
}
log.Errorf("CAS mismatch when setting %s, retrying...", key)
}
}
func getApiTokens(key string) ([]string, uint32, error) {
data, cas, err := proxywasm.GetSharedData(key)
if err != nil {
if errors.Is(err, types.ErrorStatusNotFound) {
return []string{}, cas, nil
}
return nil, 0, err
}
if data == nil {
return []string{}, cas, nil
}
var apiTokens []string
if err = json.Unmarshal(data, &apiTokens); err != nil {
return nil, 0, fmt.Errorf("failed to unmarshal tokens: %v", err)
}
return apiTokens, cas, nil
}
func setApiTokens(key string, apiTokens []string, cas uint32) error {
data, err := json.Marshal(apiTokens)
if err != nil {
return fmt.Errorf("failed to marshal tokens: %v", err)
}
return proxywasm.SetSharedData(key, data, cas)
}
func removeElement(slice []string, s string) []string {
for i := 0; i < len(slice); i++ {
if slice[i] == s {
slice = append(slice[:i], slice[i+1:]...)
i--
}
}
return slice
}
func containsElement(slice []string, s string) bool {
for _, item := range slice {
if item == s {
return true
}
}
return false
}
func getApiTokenRequestCount(key string) (map[string]int64, uint32, error) {
data, cas, err := proxywasm.GetSharedData(key)
if err != nil {
if errors.Is(err, types.ErrorStatusNotFound) {
return make(map[string]int64), cas, nil
}
return nil, 0, err
}
if data == nil {
return make(map[string]int64), cas, nil
}
var apiTokens map[string]int64
err = json.Unmarshal(data, &apiTokens)
if err != nil {
return nil, 0, err
}
return apiTokens, cas, nil
}
func addApiTokenRequestCount(key, apiToken string, log wrapper.Log) {
modifyApiTokenRequestCount(key, apiToken, addApiTokenRequestCountOperation, log)
}
func resetApiTokenRequestCount(key, apiToken string, log wrapper.Log) {
modifyApiTokenRequestCount(key, apiToken, resetApiTokenRequestCountOperation, log)
}
func (c *ProviderConfig) ResetApiTokenRequestFailureCount(apiTokenInUse string, log wrapper.Log) {
if c.isFailoverEnabled() {
failureApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount)
if err != nil {
log.Errorf("failed to get failureApiTokenRequestCount: %v", err)
}
if _, ok := failureApiTokenRequestCount[apiTokenInUse]; ok {
log.Infof("reset apiToken %s request failure count", apiTokenInUse)
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiTokenInUse, log)
}
}
}
func modifyApiTokenRequestCount(key, apiToken string, op string, log wrapper.Log) {
for attempt := 1; attempt <= casMaxRetries; attempt++ {
apiTokenRequestCount, cas, err := getApiTokenRequestCount(key)
if err != nil {
log.Errorf("Failed to get %s: %v", key, err)
continue
}
if op == resetApiTokenRequestCountOperation {
delete(apiTokenRequestCount, apiToken)
} else {
apiTokenRequestCount[apiToken]++
}
apiTokenRequestCountByte, err := json.Marshal(apiTokenRequestCount)
if err != nil {
log.Errorf("failed to marshal apiTokenRequestCount: %v", err)
// Do not write invalid data to shared data
return
}
if err := proxywasm.SetSharedData(key, apiTokenRequestCountByte, cas); err == nil {
log.Debugf("Successfully updated the count of %s in %s", apiToken, key)
return
} else if !errors.Is(err, types.ErrorStatusCasMismatch) {
log.Errorf("Failed to set %s after %d attempts: %v", key, attempt, err)
return
}
log.Errorf("CAS mismatch when setting %s, retrying...", key)
}
}
func (c *ProviderConfig) initApiTokens() error {
return setApiTokens(c.failover.ctxApiTokens, c.apiTokens, 0)
}
func (c *ProviderConfig) GetGlobalRandomToken(log wrapper.Log) string {
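apiTokens, _, err := getApiTokens(c.failover.ctxApiTokens)
if err != nil {
// Check this error before err is reassigned by the next call
log.Errorf("failed to get apiTokens: %v", err)
return ""
}
unavailableApiTokens, _, err := getApiTokens(c.failover.ctxUnavailableApiTokens)
if err != nil {
return ""
}
log.Debugf("apiTokens: %v, unavailableApiTokens: %v", apiTokens, unavailableApiTokens)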
count := len(apiTokens)
switch count {
case 0:
return ""
case 1:
return apiTokens[0]
default:
return apiTokens[rand.Intn(count)]
}
}
func (c *ProviderConfig) isFailoverEnabled() bool {
return c.failover.enabled
}
func (c *ProviderConfig) resetSharedData() {
_ = proxywasm.SetSharedData(c.failover.ctxVmLease, nil, 0)
_ = proxywasm.SetSharedData(c.failover.ctxApiTokens, nil, 0)
_ = proxywasm.SetSharedData(c.failover.ctxUnavailableApiTokens, nil, 0)
_ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestSuccessCount, nil, 0)
_ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0)
}
func (c *ProviderConfig) OnRequestFailed(ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) {
if c.isFailoverEnabled() {
c.handleUnavailableApiToken(ctx, apiTokenInUse, log)
}
}
func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
return ctx.GetContext(c.failover.ctxApiTokenInUse).(string)
}
func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) {
var apiToken string
if c.isFailoverEnabled() {
// if enable apiToken failover, only use available apiToken
apiToken = c.GetGlobalRandomToken(log)
} else {
apiToken = c.GetRandomToken()
}
log.Debugf("[onHttpRequestHeader] use apiToken %s to send request", apiToken)
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
}
func (c *ProviderConfig) setHealthCheckEndpoint(ctx wrapper.HttpContext, log wrapper.Log) {
cluster, err := proxywasm.GetProperty([]string{"cluster_name"})
if err != nil {
log.Errorf("Failed to get cluster_name: %v", err)
}
host := wrapper.GetRequestHost()
if host == "" {
host = ctx.GetContext(ctxRequestHost).(string)
}
path := wrapper.GetRequestPath()
if path == "" {
path = ctx.GetContext(ctxRequestPath).(string)
}
healthCheckEndpoint := HealthCheckEndpoint{
Host: host,
Path: path,
Cluster: string(cluster),
}
healthCheckEndpointByte, err := json.Marshal(healthCheckEndpoint)
if err != nil {
log.Errorf("Failed to marshal request host and path: %v", err)
}
err = proxywasm.SetSharedData(c.failover.ctxHealthCheckEndpoint, healthCheckEndpointByte, 0)
if err != nil {
log.Errorf("Failed to set request host and path: %v", err)
}
}


@@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
@@ -17,8 +18,11 @@ import (
// geminiProvider is the provider for google gemini/gemini flash service.
const (
geminiApiKeyHeader = "x-goog-api-key"
geminiDomain = "generativelanguage.googleapis.com"
geminiApiKeyHeader = "x-goog-api-key"
geminiDomain = "generativelanguage.googleapis.com"
geminiChatCompletionPath = "generateContent"
geminiChatCompletionStreamPath = "streamGenerateContent?alt=sse"
geminiEmbeddingPath = "batchEmbedContents"
)
type geminiProviderInitializer struct {
@@ -51,157 +55,56 @@ func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
_ = proxywasm.ReplaceHttpRequestHeader(geminiApiKeyHeader, g.config.GetRandomToken())
_ = util.OverwriteRequestHost(geminiDomain)
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
g.config.handleRequestHeaders(g, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, geminiDomain)
headers.Add(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx))
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
}
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
}
func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
if apiName == ApiNameChatCompletion {
return g.onChatCompletionRequestBody(ctx, body, log)
} else if apiName == ApiNameEmbeddings {
return g.onEmbeddingsRequestBody(ctx, body, log)
return g.onChatCompletionRequestBody(ctx, body, headers, log)
} else {
return g.onEmbeddingsRequestBody(ctx, body, headers, log)
}
return types.ActionContinue, errUnsupportedApiName
}
func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
// Use the Gemini API protocol
if g.config.protocol == protocolOriginal {
request := &geminiChatRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.Model == "" {
return types.ActionContinue, errors.New("request model is empty")
}
// Rewrite requestPath according to the model
path := g.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
_ = util.OverwriteRequestPath(path)
// Remove the redundant model and stream fields
request = &geminiChatRequest{
Contents: request.Contents,
SafetySettings: request.SafetySettings,
GenerationConfig: request.GenerationConfig,
Tools: request.Tools,
}
if g.config.context == nil {
return types.ActionContinue, replaceJsonRequestBody(request, log)
}
err := g.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.gemini.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
g.setSystemContent(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.gemini.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}
func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
err := g.config.parseRequestAndMapModel(ctx, request, body, log)
if err != nil {
return nil, err
}
path := g.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
util.OverwriteRequestPathHeader(headers, path)
// Map the model and rewrite requestPath
model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
ctx.SetContext(ctxKeyOriginalRequestModel, model)
mappedModel := getMappedModel(model, g.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
path := g.getRequestPath(ApiNameChatCompletion, mappedModel, request.Stream)
_ = util.OverwriteRequestPath(path)
if g.config.context == nil {
geminiRequest := g.buildGeminiChatRequest(request)
return types.ActionContinue, replaceJsonRequestBody(geminiRequest, log)
}
err := g.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.gemini.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
geminiRequest := g.buildGeminiChatRequest(request)
if err := replaceJsonRequestBody(geminiRequest, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.gemini.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
geminiRequest := g.buildGeminiChatRequest(request)
return json.Marshal(geminiRequest)
}
func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
// Use the native Gemini API protocol
if g.config.protocol == protocolOriginal {
request := &geminiBatchEmbeddingRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.Model == "" {
return types.ActionContinue, errors.New("request model is empty")
}
// Rewrite the request path based on the model
path := g.getRequestPath(ApiNameEmbeddings, request.Model, false)
_ = util.OverwriteRequestPath(path)
// Remove the redundant model field
request = &geminiBatchEmbeddingRequest{
Requests: request.Requests,
}
return types.ActionContinue, replaceJsonRequestBody(request, log)
}
func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
request := &embeddingsRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
if err := g.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
return nil, err
}
// Map the model and rewrite the request path
model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in embeddings request")
}
ctx.SetContext(ctxKeyOriginalRequestModel, model)
mappedModel := getMappedModel(model, g.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
path := g.getRequestPath(ApiNameEmbeddings, mappedModel, false)
_ = util.OverwriteRequestPath(path)
path := g.getRequestPath(ApiNameEmbeddings, request.Model, false)
util.OverwriteRequestPathHeader(headers, path)
geminiRequest := g.buildBatchEmbeddingRequest(request)
return types.ActionContinue, replaceJsonRequestBody(geminiRequest, log)
return json.Marshal(geminiRequest)
}
func (g *geminiProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
@@ -285,11 +188,11 @@ func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body
func (g *geminiProvider) getRequestPath(apiName ApiName, geminiModel string, stream bool) string {
action := ""
if apiName == ApiNameEmbeddings {
action = "batchEmbedContents"
action = geminiEmbeddingPath
} else if stream {
action = "streamGenerateContent?alt=sse"
action = geminiChatCompletionStreamPath
} else {
action = "generateContent"
action = geminiChatCompletionPath
}
return fmt.Sprintf("/v1/models/%s:%s", geminiModel, action)
}
@@ -605,3 +508,13 @@ func (g *geminiProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, gemini
func (g *geminiProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}
func (g *geminiProvider) GetApiName(path string) ApiName {
if strings.Contains(path, geminiChatCompletionPath) || strings.Contains(path, geminiChatCompletionStreamPath) {
return ApiNameChatCompletion
}
if strings.Contains(path, geminiEmbeddingPath) {
return ApiNameEmbeddings
}
return ""
}
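
Review note: the new `GetApiName` recovers the API from the same path fragments that `getRequestPath` emits. The sketch below is a standalone approximation of that round trip, not the plugin's code. The three constant values are assumed from the literals the diff replaces, plain strings stand in for the `ApiName` type, and the model names are placeholders.

```go
package main

import (
	"fmt"
	"strings"
)

// Assumed values: the diff swaps these literals for named constants,
// but the constant definitions themselves are not shown in this hunk.
const (
	geminiChatCompletionPath       = "generateContent"
	geminiChatCompletionStreamPath = "streamGenerateContent?alt=sse"
	geminiEmbeddingPath            = "batchEmbedContents"
)

// requestPath mirrors getRequestPath from the diff.
func requestPath(embeddings bool, model string, stream bool) string {
	action := geminiChatCompletionPath
	if embeddings {
		action = geminiEmbeddingPath
	} else if stream {
		action = geminiChatCompletionStreamPath
	}
	return fmt.Sprintf("/v1/models/%s:%s", model, action)
}

// apiName mirrors GetApiName; plain strings stand in for ApiName values.
func apiName(path string) string {
	if strings.Contains(path, geminiChatCompletionPath) ||
		strings.Contains(path, geminiChatCompletionStreamPath) {
		return "chatCompletion"
	}
	if strings.Contains(path, geminiEmbeddingPath) {
		return "embeddings"
	}
	return ""
}

func main() {
	// Placeholder model names, used only to exercise the helpers.
	paths := []string{
		requestPath(false, "model-a", false),
		requestPath(false, "model-a", true),
		requestPath(true, "model-b", false),
	}
	for _, p := range paths {
		fmt.Printf("%s -> %q\n", p, apiName(p))
	}
}
```

Under these assumed values the streaming fragment needs its own check, because `streamGenerateContent` does not contain the lowercase-initial substring `generateContent`.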

View File

@@ -1,14 +1,12 @@
package provider
import (
"encoding/json"
"errors"
"fmt"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
"strings"
)
// githubProvider is the provider for GitHub OpenAI service.
@@ -48,16 +46,7 @@ func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestHost(githubDomain)
if apiName == ApiNameChatCompletion {
_ = util.OverwriteRequestPath(githubCompletionPath)
}
if apiName == ApiNameEmbeddings {
_ = util.OverwriteRequestPath(githubEmbeddingPath)
}
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
_ = proxywasm.ReplaceHttpRequestHeader("Authorization", m.config.GetRandomToken())
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
@@ -66,47 +55,28 @@ func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
func (m *githubProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, githubDomain)
if apiName == ApiNameChatCompletion {
return m.onChatCompletionRequestBody(ctx, body, log)
util.OverwriteRequestPathHeader(headers, githubCompletionPath)
}
if apiName == ApiNameEmbeddings {
return m.onEmbeddingsRequestBody(ctx, body, log)
util.OverwriteRequestPathHeader(headers, githubEmbeddingPath)
}
return types.ActionContinue, errUnsupportedApiName
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
}
func (m *githubProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
func (m *githubProvider) GetApiName(path string) ApiName {
if strings.Contains(path, githubCompletionPath) {
return ApiNameChatCompletion
}
if request.Model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
if strings.Contains(path, githubEmbeddingPath) {
return ApiNameEmbeddings
}
// Map the model
mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
request.Model = mappedModel
return types.ActionContinue, replaceJsonRequestBody(request, log)
}
func (m *githubProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
request := &embeddingsRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.Model == "" {
return types.ActionContinue, errors.New("missing model in embeddings request")
}
// Map the model
mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
request.Model = mappedModel
return types.ActionContinue, replaceJsonRequestBody(request, log)
return ""
}

View File

@@ -2,11 +2,11 @@ package provider
import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -18,14 +18,14 @@ const (
type groqProviderInitializer struct{}
func (m *groqProviderInitializer) ValidateConfig(config ProviderConfig) error {
func (g *groqProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
}
func (m *groqProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
func (g *groqProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
return &groqProvider{
config: config,
contextCache: createContextCache(&config),
@@ -37,47 +37,35 @@ type groqProvider struct {
contextCache *contextCache
}
func (m *groqProvider) GetProviderType() string {
func (g *groqProvider) GetProviderType() string {
return providerTypeGroq
}
func (m *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(groqChatCompletionPath)
_ = util.OverwriteRequestHost(groqDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
g.config.handleRequestHeaders(g, ctx, apiName, log)
return types.ActionContinue, nil
}
func (m *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
if m.contextCache == nil {
return types.ActionContinue, nil
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.groq.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.groq.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
}
func (g *groqProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, groqChatCompletionPath)
util.OverwriteRequestHostHeader(headers, groqDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}
func (g *groqProvider) GetApiName(path string) ApiName {
if strings.Contains(path, groqChatCompletionPath) {
return ApiNameChatCompletion
}
return ""
}

View File

@@ -8,6 +8,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
@@ -114,26 +115,27 @@ func (m *hunyuanProvider) GetProviderType() string {
}
func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
// log.Debugf("hunyuanProvider.OnRequestHeaders called! hunyunSecretKey/id is: %s/%s", m.config.hunyuanAuthKey, m.config.hunyuanAuthId)
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestHost(hunyuanDomain)
_ = util.OverwriteRequestPath(hunyuanRequestPath)
// Add the custom header fields required by hunyuan
_ = proxywasm.ReplaceHttpRequestHeader(actionKey, hunyuanChatCompletionTCAction)
_ = proxywasm.ReplaceHttpRequestHeader(versionKey, versionValue)
// Remove some header fields
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, hunyuanDomain)
util.OverwriteRequestPathHeader(headers, hunyuanRequestPath)
// Add the custom header fields required by hunyuan
headers.Add(actionKey, hunyuanChatCompletionTCAction)
headers.Add(versionKey, versionValue)
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
}
// hunyuan's OnRequestBody signs the headers and must recompute the signature after the context is inserted, so it cannot reuse the handleRequestBody method
func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
@@ -142,7 +144,6 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
// Add the timestamp header here: signing the body depends on the timestamp, so it is created during body processing
var timestamp int64 = time.Now().Unix()
_ = proxywasm.ReplaceHttpRequestHeader(timestampKey, fmt.Sprintf("%d", timestamp))
// log.Debugf("#debug nash5# OnRequestBody set timestamp header: ", timestamp)
// Use hunyuan's native API protocol
if m.config.protocol == protocolOriginal {
@@ -198,7 +199,6 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
// log.Debugf("#debug nash5# OnRequestBody call hunyuan api using openai's api!")
model := request.Model
if model == "" {
@@ -235,18 +235,6 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
string(body),
)
_ = util.OverwriteRequestAuthorization(authorizedValueNew)
// log.Debugf("#debug nash5# OnRequestBody done, body is: ", string(body))
// // Print all headers
// headers, err2 := proxywasm.GetHttpRequestHeaders()
// if err2 != nil {
// log.Errorf("failed to get request headers: %v", err2)
// } else {
// // Iterate over and print all request headers
// for _, header := range headers {
// log.Infof("#debug nash5# inB Request header - %s: %s", header[0], header[1])
// }
// }
return types.ActionContinue, replaceJsonRequestBody(hunyuanRequest, log)
}
@@ -277,6 +265,32 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
return types.ActionContinue, err
}
// hunyuan's TransformRequestBodyHeaders method is only called during failover health checks
func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
err := m.config.parseRequestAndMapModel(ctx, request, body, log)
if err != nil {
return nil, err
}
hunyuanRequest := m.buildHunyuanTextGenerationRequest(request)
var timestamp int64 = time.Now().Unix()
_ = proxywasm.ReplaceHttpRequestHeader(timestampKey, fmt.Sprintf("%d", timestamp))
// Sign the finalized payload
body, _ = json.Marshal(hunyuanRequest)
authorizedValueNew := GetTC3Authorizationcode(
m.config.hunyuanAuthId,
m.config.hunyuanAuthKey,
timestamp,
hunyuanDomain,
hunyuanChatCompletionTCAction,
string(body),
)
util.OverwriteRequestAuthorizationHeader(headers, authorizedValueNew)
return json.Marshal(hunyuanRequest)
}
func (m *hunyuanProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
@@ -561,3 +575,7 @@ func GetTC3Authorizationcode(secretId string, secretKey string, timestamp int64,
// fmt.Println(curl)
return authorization
}
func (m *hunyuanProvider) GetApiName(path string) ApiName {
return ApiNameChatCompletion
}
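
Review note: in the hunyuan header changes above, the removed lines use `proxywasm.ReplaceHttpRequestHeader` for the action and version headers, while the added lines use `headers.Add`. For Go's `http.Header`, `Add` appends a value and `Set` replaces existing ones. Whether that difference matters here depends on whether the incoming request already carries those headers, which the diff does not show. The sketch below only demonstrates the standard-library behavior; the header name and values are illustrative.

```go
package main

import (
	"fmt"
	"net/http"
)

// Illustrative header name; the real actionKey value is not shown in the diff.
const exampleKey = "X-Example-Action"

// withAdd starts from a header that already has a value and appends another.
func withAdd(existing, value string) http.Header {
	h := http.Header{}
	h.Set(exampleKey, existing)
	h.Add(exampleKey, value) // keeps the existing value and appends a second
	return h
}

// withSet starts from the same header and replaces its value.
func withSet(existing, value string) http.Header {
	h := http.Header{}
	h.Set(exampleKey, existing)
	h.Set(exampleKey, value) // discards the existing value
	return h
}

func main() {
	added := withAdd("FromClient", "FromPlugin")
	replaced := withSet("FromClient", "FromPlugin")

	fmt.Println("Add values:", added.Values(exampleKey))
	fmt.Println("Set values:", replaced.Values(exampleKey))
}
```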

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
@@ -78,14 +79,17 @@ func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestHost(minimaxDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, minimaxDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}
func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
@@ -107,51 +111,16 @@ func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
return m.handleRequestBodyByChatCompletionPro(body, log)
} else {
// Use the ChatCompletion v2 API
return m.handleRequestBodyByChatCompletionV2(body, log)
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
}
func (m *minimaxProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
return m.handleRequestBodyByChatCompletionV2(body, headers, log)
}
// handleRequestBodyByChatCompletionPro handles the request body using the ChatCompletion Pro API
func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log wrapper.Log) (types.Action, error) {
// Use the native minimax API protocol
if m.config.protocol == protocolOriginal {
request := &minimaxChatCompletionV2Request{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.Model == "" {
return types.ActionContinue, errors.New("request model is empty")
}
// Rewrite the request path based on the model
if m.config.minimaxGroupId == "" {
return types.ActionContinue, errors.New(fmt.Sprintf("missing minimaxGroupId in provider config when use %s model ", request.Model))
}
_ = util.OverwriteRequestPath(fmt.Sprintf("%s?GroupId=%s", minimaxChatCompletionProPath, m.config.minimaxGroupId))
if m.config.context == nil {
return types.ActionContinue, nil
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.minimax.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
m.setBotSettings(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.minimax.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
@@ -174,6 +143,9 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.minimax.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
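// minimaxChatCompletionV2 (same format as OpenAI) and minimaxChatCompletionPro (different format from OpenAI) need different insertHttpContextMessage logic, so a single provider cannot handle both uniformly
// Therefore the context message must be handled manually for minimaxChatCompletionPro
// minimaxChatCompletionV2 leaves it to the default defaultInsertHttpContextMessage method to insert the context message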
minimaxRequest := m.buildMinimaxChatCompletionV2Request(request, content)
if err := replaceJsonRequestBody(minimaxRequest, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.minimax.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace Request body: %v", err))
@@ -186,37 +158,17 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log
}
// handleRequestBodyByChatCompletionV2 handles the request body using the ChatCompletion v2 API
func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, log wrapper.Log) (types.Action, error) {
func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
return nil, err
}
// Map the model and rewrite the request path
request.Model = getMappedModel(request.Model, m.config.modelMapping, log)
_ = util.OverwriteRequestPath(minimaxChatCompletionV2Path)
util.OverwriteRequestPathHeader(headers, minimaxChatCompletionV2Path)
if m.contextCache == nil {
return types.ActionContinue, replaceJsonRequestBody(request, log)
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.minimax.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.minimax.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return body, nil
}
func (m *minimaxProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
@@ -474,3 +426,10 @@ func (m *minimaxProvider) responseV2ToOpenAI(response *minimaxChatCompletionV2Re
func (m *minimaxProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}
func (m *minimaxProvider) GetApiName(path string) ApiName {
if strings.Contains(path, minimaxChatCompletionV2Path) || strings.Contains(path, minimaxChatCompletionProPath) {
return ApiNameChatCompletion
}
return ""
}

View File

@@ -2,12 +2,10 @@ package provider
import (
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
)
const (
@@ -43,9 +41,7 @@ func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestHost(mistralDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -53,28 +49,11 @@ func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
if m.contextCache == nil {
return types.ActionContinue, nil
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.mistral.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.mistral.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
func (m *mistralProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, mistralDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}

View File

@@ -3,13 +3,12 @@ package provider
import (
"errors"
"fmt"
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"net/http"
)
// moonshotProvider is the provider for Moonshot AI service.
@@ -58,33 +57,29 @@ func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName Api
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(moonshotChatCompletionPath)
_ = util.OverwriteRequestHost(moonshotDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, moonshotChatCompletionPath)
util.OverwriteRequestHostHeader(headers, moonshotDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}
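// moonshot has its own configuration for obtaining context (moonshotFileId), so it cannot reuse the handleRequestBody method
// moonshot does not modify the body, so it need not implement TransformRequestBody; the default defaultTransformRequestBody method is used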
func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
return types.ActionContinue, err
}
model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
mappedModel := getMappedModel(model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel
if m.config.moonshotFileId == "" && m.contextCache == nil {
return types.ActionContinue, replaceJsonRequestBody(request, log)
}

View File

@@ -3,11 +3,10 @@ package provider
import (
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
)
// ollamaProvider is the provider for Ollama service.
@@ -53,10 +52,7 @@ func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(ollamaChatCompletionPath)
_ = util.OverwriteRequestHost(m.serviceDomain)
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -64,51 +60,11 @@ func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
if m.config.modelMapping == nil && m.contextCache == nil {
return types.ActionContinue, nil
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
mappedModel := getMappedModel(model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel
if m.contextCache != nil {
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.ollama.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.ollama.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
} else {
return types.ActionContinue, err
}
} else {
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.ollama.transform_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
return types.ActionContinue, err
}
_ = proxywasm.ResumeHttpRequest()
return types.ActionPause, nil
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
func (m *ollamaProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, ollamaChatCompletionPath)
util.OverwriteRequestHostHeader(headers, m.serviceDomain)
headers.Del("Content-Length")
}

View File

@@ -1,12 +1,13 @@
package provider
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -57,27 +58,31 @@ func (m *openaiProvider) GetProviderType() string {
}
func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
if m.customPath == "" {
switch apiName {
case ApiNameChatCompletion:
_ = util.OverwriteRequestPath(defaultOpenaiChatCompletionPath)
util.OverwriteRequestPathHeader(headers, defaultOpenaiChatCompletionPath)
case ApiNameEmbeddings:
ctx.DontReadRequestBody()
_ = util.OverwriteRequestPath(defaultOpenaiEmbeddingsPath)
util.OverwriteRequestPathHeader(headers, defaultOpenaiEmbeddingsPath)
}
} else {
_ = util.OverwriteRequestPath(m.customPath)
util.OverwriteRequestPathHeader(headers, m.customPath)
}
if m.customDomain == "" {
_ = util.OverwriteRequestHost(defaultOpenaiDomain)
util.OverwriteRequestHostHeader(headers, defaultOpenaiDomain)
} else {
_ = util.OverwriteRequestHost(m.customDomain)
util.OverwriteRequestHostHeader(headers, m.customDomain)
}
if len(m.config.apiTokens) > 0 {
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
}
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
return types.ActionContinue, nil
headers.Del("Content-Length")
}
func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -85,9 +90,13 @@ func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
// We don't need to process the request body for other APIs.
return types.ActionContinue, nil
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
func (m *openaiProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
return nil, err
}
if m.config.responseJsonSchema != nil {
log.Debugf("[ai-proxy] set response format to %s", m.config.responseJsonSchema)
@@ -101,27 +110,5 @@ func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
request.StreamOptions.IncludeUsage = true
}
}
if m.contextCache == nil {
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.openai.set_include_usage_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
return types.ActionContinue, nil
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.openai.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.openai.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return json.Marshal(request)
}

View File

@@ -1,14 +1,17 @@
package provider
import (
"encoding/json"
"errors"
"math/rand"
"net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)
type ApiName string
@@ -110,14 +113,32 @@ type Provider interface {
GetProviderType() string
}
type ApiNameHandler interface {
GetApiName(path string) ApiName
}
type RequestHeadersHandler interface {
OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error)
}
type TransformRequestHeadersHandler interface {
TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
}
type RequestBodyHandler interface {
OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
}
type TransformRequestBodyHandler interface {
TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
}
// TransformRequestBodyHeadersHandler lets a provider transform request headers based on the request body.
// Some providers (e.g. baidu, gemini) transform request headers (e.g., path) based on the request body (e.g., model).
type TransformRequestBodyHeadersHandler interface {
TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error)
}
type ResponseHeadersHandler interface {
OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error)
}
@@ -143,6 +164,9 @@ type ProviderConfig struct {
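// @Title zh-CN Request timeout
// @Description zh-CN Timeout for requests to the AI service, in milliseconds. Defaults to 120000, i.e. 2 minutes.
timeout uint32 `required:"false" yaml:"timeout" json:"timeout"`
// @Title zh-CN apiToken failover
// @Description zh-CN When an apiToken becomes unavailable, remove it from the apiTokens list, run health checks on the removed apiToken, and add it back to the apiTokens list once it is available again.
failover *failover `required:"false" yaml:"failover" json:"failover"`
// @Title zh-CN Custom backend URL based on the OpenAI protocol
// @Description zh-CN Only applies to services that support the openai protocol.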
openaiCustomUrl string `required:"false" yaml:"openaiCustomUrl" json:"openaiCustomUrl"`
@@ -289,6 +313,14 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
}
}
}
failoverJson := json.Get("failover")
c.failover = &failover{
enabled: false,
}
if failoverJson.Exists() {
c.failover.FromJson(failoverJson)
}
}
func (c *ProviderConfig) Validate() error {
@@ -304,6 +336,12 @@ func (c *ProviderConfig) Validate() error {
}
}
if c.failover.enabled {
if err := c.failover.Validate(); err != nil {
return err
}
}
if c.typ == "" {
return errors.New("missing type in provider config")
}
@@ -355,6 +393,60 @@ func CreateProvider(pc ProviderConfig) (Provider, error) {
return initializer.CreateProvider(pc)
}
func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, request interface{}, body []byte, log wrapper.Log) error {
switch req := request.(type) {
case *chatCompletionRequest:
if err := decodeChatCompletionRequest(body, req); err != nil {
return err
}
streaming := req.Stream
if streaming {
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
}
return c.setRequestModel(ctx, req, log)
case *embeddingsRequest:
if err := decodeEmbeddingsRequest(body, req); err != nil {
return err
}
return c.setRequestModel(ctx, req, log)
default:
return errors.New("unsupported request type")
}
}
func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interface{}, log wrapper.Log) error {
var model *string
switch req := request.(type) {
case *chatCompletionRequest:
model = &req.Model
case *embeddingsRequest:
model = &req.Model
default:
return errors.New("unsupported request type")
}
return c.mapModel(ctx, model, log)
}
func (c *ProviderConfig) mapModel(ctx wrapper.HttpContext, model *string, log wrapper.Log) error {
if *model == "" {
return errors.New("missing model in request")
}
ctx.SetContext(ctxKeyOriginalRequestModel, *model)
mappedModel := getMappedModel(*model, c.modelMapping, log)
if mappedModel == "" {
return errors.New("model becomes empty after applying the configured mapping")
}
*model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, *model)
return nil
}
func getMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string {
mappedModel := doGetMappedModel(model, modelMapping, log)
if len(mappedModel) != 0 {
@@ -391,3 +483,62 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.
return ""
}
func (c *ProviderConfig) handleRequestBody(
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log,
) (types.Action, error) {
// use original protocol
if c.protocol == protocolOriginal {
return types.ActionContinue, nil
}
// use openai protocol
var err error
if handler, ok := provider.(TransformRequestBodyHandler); ok {
body, err = handler.TransformRequestBody(ctx, apiName, body, log)
} else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {
headers := util.GetOriginalHttpHeaders()
body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers, log)
util.ReplaceOriginalHttpHeaders(headers)
} else {
body, err = c.defaultTransformRequestBody(ctx, apiName, body, log)
}
if err != nil {
return types.ActionContinue, err
}
if apiName == ApiNameChatCompletion {
if c.context == nil {
return types.ActionContinue, replaceHttpJsonRequestBody(body, log)
}
err = contextCache.GetContextFromFile(ctx, provider, body, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}
return types.ActionContinue, replaceHttpJsonRequestBody(body, log)
}
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) {
if handler, ok := provider.(TransformRequestHeadersHandler); ok {
originalHeaders := util.GetOriginalHttpHeaders()
handler.TransformRequestHeaders(ctx, apiName, originalHeaders, log)
util.ReplaceOriginalHttpHeaders(originalHeaders)
}
}
func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
var request interface{}
if apiName == ApiNameChatCompletion {
request = &chatCompletionRequest{}
} else {
request = &embeddingsRequest{}
}
if err := c.parseRequestAndMapModel(ctx, request, body, log); err != nil {
return nil, err
}
return json.Marshal(request)
}
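
Reviewer note: `handleRequestBody` above picks a transformer by type-asserting the provider against optional interfaces and falls back to `defaultTransformRequestBody` when none is implemented. The sketch below shows that dispatch pattern in isolation. It is a simplified illustration, not the plugin's code: the `provider` and `bodyTransformer` interfaces, both demo providers, and `transformBody` are hypothetical stand-ins, and the proxy-wasm context and logger arguments are dropped.

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

// provider is a hypothetical stand-in for the Provider interface in the diff.
type provider interface {
	GetProviderType() string
}

// bodyTransformer mirrors the shape of TransformRequestBodyHandler, minus the
// proxy-wasm context and logger arguments (a simplification for this sketch).
type bodyTransformer interface {
	TransformRequestBody(body []byte) ([]byte, error)
}

// plainProvider implements only the base interface.
type plainProvider struct{}

func (plainProvider) GetProviderType() string { return "plain" }

// upperProvider additionally implements the optional transformer.
type upperProvider struct{}

func (upperProvider) GetProviderType() string { return "upper" }

func (upperProvider) TransformRequestBody(body []byte) ([]byte, error) {
	return []byte(strings.ToUpper(string(body))), nil
}

// defaultTransform is the fallback used when a provider has no transformer.
func defaultTransform(body []byte) ([]byte, error) {
	return body, nil
}

// transformBody prefers the provider's own transformer and falls back to the default.
func transformBody(p provider, body []byte) ([]byte, error) {
	if p == nil {
		return nil, errors.New("nil provider")
	}
	if h, ok := p.(bodyTransformer); ok {
		return h.TransformRequestBody(body)
	}
	return defaultTransform(body)
}

func main() {
	for _, p := range []provider{plainProvider{}, upperProvider{}} {
		out, err := transformBody(p, []byte(`{"model":"m"}`))
		fmt.Println(p.GetProviderType(), string(out), err)
	}
}
```

The benefit of this shape is that a provider opts in by adding a method, and the shared handler needs no per-provider switch.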

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"math"
"net/http"
"reflect"
"strings"
"time"
@@ -58,35 +59,50 @@ func (m *qwenProviderInitializer) CreateProvider(config ProviderConfig) (Provide
}
type qwenProvider struct {
config ProviderConfig
config ProviderConfig
contextCache *contextCache
}
func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, qwenDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
if m.config.qwenEnableCompatible {
util.OverwriteRequestPathHeader(headers, qwenCompatiblePath)
} else if apiName == ApiNameChatCompletion {
util.OverwriteRequestPathHeader(headers, qwenChatCompletionPath)
} else if apiName == ApiNameEmbeddings {
util.OverwriteRequestPathHeader(headers, qwenTextEmbeddingPath)
}
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
}
func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
if apiName == ApiNameChatCompletion {
return m.onChatCompletionRequestBody(ctx, body, headers, log)
} else {
return m.onEmbeddingsRequestBody(ctx, body, log)
}
}
func (m *qwenProvider) GetProviderType() string {
return providerTypeQwen
}
func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
_ = util.OverwriteRequestHost(qwenDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
if m.config.protocol == protocolOriginal {
ctx.DontReadRequestBody()
return types.ActionContinue, nil
} else if m.config.qwenEnableCompatible {
_ = util.OverwriteRequestPath(qwenCompatiblePath)
} else if apiName == ApiNameChatCompletion {
_ = util.OverwriteRequestPath(qwenChatCompletionPath)
} else if apiName == ApiNameEmbeddings {
_ = util.OverwriteRequestPath(qwenTextEmbeddingPath)
} else {
return types.ActionContinue, errUnsupportedApiName
}
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
@@ -121,65 +137,23 @@ func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b
}
return types.ActionContinue, nil
}
if apiName == ApiNameChatCompletion {
return m.onChatCompletionRequestBody(ctx, body, log)
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
if apiName == ApiNameEmbeddings {
return m.onEmbeddingsRequestBody(ctx, body, log)
}
return types.ActionContinue, errUnsupportedApiName
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
if m.config.protocol == protocolOriginal {
if m.config.context == nil {
return types.ActionContinue, nil
}
request := &qwenTextGenRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.qwen.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
m.insertContextMessage(request, content, false)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.qwen.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}
func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
err := m.config.parseRequestAndMapModel(ctx, request, body, log)
if err != nil {
return nil, err
}
model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
ctx.SetContext(ctxKeyOriginalRequestModel, model)
mappedModel := getMappedModel(model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
// Use the qwen multimodal model generation API
if strings.HasPrefix(request.Model, qwenVlModelPrefixName) {
_ = util.OverwriteRequestPath(qwenMultimodalGenerationPath)
util.OverwriteRequestPathHeader(headers, qwenMultimodalGenerationPath)
}
streaming := request.Stream
@@ -191,62 +165,20 @@ func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body
_ = proxywasm.RemoveHttpRequestHeader("X-DashScope-SSE")
}
if m.config.context == nil {
qwenRequest := m.buildQwenTextGenerationRequest(request, streaming)
if streaming {
ctx.SetContext(ctxKeyIncrementalStreaming, qwenRequest.Parameters.IncrementalOutput)
}
return types.ActionContinue, replaceJsonRequestBody(qwenRequest, log)
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.qwen.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
qwenRequest := m.buildQwenTextGenerationRequest(request, streaming)
if streaming {
ctx.SetContext(ctxKeyIncrementalStreaming, qwenRequest.Parameters.IncrementalOutput)
}
if err := replaceJsonRequestBody(qwenRequest, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.qwen.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return m.buildQwenTextGenerationRequest(ctx, request, streaming)
}
func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
request := &embeddingsRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
return nil, err
}
log.Debugf("=== embeddings request: %v", request)
model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in the request")
}
ctx.SetContext(ctxKeyOriginalRequestModel, model)
mappedModel := getMappedModel(model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
if qwenRequest, err := m.buildQwenTextEmbeddingRequest(request); err == nil {
return types.ActionContinue, replaceJsonRequestBody(qwenRequest, log)
} else {
return types.ActionContinue, err
qwenRequest, err := m.buildQwenTextEmbeddingRequest(request)
if err != nil {
return nil, err
}
return json.Marshal(qwenRequest)
}
func (m *qwenProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
@@ -375,7 +307,7 @@ func (m *qwenProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []
return types.ActionContinue, replaceJsonResponseBody(response, log)
}
func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletionRequest, streaming bool) *qwenTextGenRequest {
func (m *qwenProvider) buildQwenTextGenerationRequest(ctx wrapper.HttpContext, origRequest *chatCompletionRequest, streaming bool) ([]byte, error) {
messages := make([]qwenMessage, 0, len(origRequest.Messages))
for i := range origRequest.Messages {
messages = append(messages, chatMessage2QwenMessage(origRequest.Messages[i]))
@@ -397,6 +329,11 @@ func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletio
Tools: origRequest.Tools,
},
}
if streaming {
ctx.SetContext(ctxKeyIncrementalStreaming, request.Parameters.IncrementalOutput)
}
if len(m.config.qwenFileIds) != 0 && origRequest.Model == qwenLongModelName {
builder := strings.Builder{}
for _, fileId := range m.config.qwenFileIds {
@@ -406,13 +343,15 @@ func (m *qwenProvider) buildQwenTextGenerationRequest(origRequest *chatCompletio
builder.WriteString("fileid://")
builder.WriteString(fileId)
}
contextMessageId := m.insertContextMessage(request, builder.String(), true)
if contextMessageId == 0 {
// The context message cannot come first. We need to add another dummy system message before it.
request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: qwenDummySystemMessageContent}}, request.Input.Messages...)
body, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("unable to marshal request: %v", err)
}
return m.insertHttpContextMessage(body, builder.String(), true)
}
return request
return json.Marshal(request)
}
func (m *qwenProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse) *chatCompletionResponse {
@@ -569,7 +508,12 @@ func (m *qwenProvider) convertStreamEvent(ctx wrapper.HttpContext, responseBuild
return nil
}
func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content string, onlyOneSystemBeforeFile bool) int {
func (m *qwenProvider) insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) {
request := &qwenTextGenRequest{}
if err := json.Unmarshal(body, request); err != nil {
return nil, fmt.Errorf("unable to unmarshal request: %v", err)
}
fileMessage := qwenMessage{
Role: roleSystem,
Content: content,
@@ -586,10 +530,8 @@ func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content
}
if firstNonSystemMessageIndex == 0 {
request.Input.Messages = append([]qwenMessage{fileMessage}, request.Input.Messages...)
return 0
} else if !onlyOneSystemBeforeFile {
request.Input.Messages = append(request.Input.Messages[:firstNonSystemMessageIndex], append([]qwenMessage{fileMessage}, request.Input.Messages[firstNonSystemMessageIndex:]...)...)
return firstNonSystemMessageIndex
} else {
builder := strings.Builder{}
for _, message := range request.Input.Messages[:firstNonSystemMessageIndex] {
@@ -599,8 +541,15 @@ func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content
builder.WriteString(message.StringContent())
}
request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: builder.String()}, fileMessage}, request.Input.Messages[firstNonSystemMessageIndex:]...)
return 1
firstNonSystemMessageIndex = 1
}
if firstNonSystemMessageIndex == 0 {
// The context message cannot come first. We need to add another dummy system message before it.
request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: qwenDummySystemMessageContent}}, request.Input.Messages...)
}
return json.Marshal(request)
}
func (m *qwenProvider) appendStreamEvent(responseBuilder *strings.Builder, event *streamEvent) {
@@ -804,3 +753,16 @@ func chatMessage2QwenMessage(chatMessage chatMessage) qwenMessage {
}
}
}
func (m *qwenProvider) GetApiName(path string) ApiName {
switch {
case strings.Contains(path, qwenChatCompletionPath),
strings.Contains(path, qwenMultimodalGenerationPath),
strings.Contains(path, qwenCompatiblePath):
return ApiNameChatCompletion
case strings.Contains(path, qwenTextEmbeddingPath):
return ApiNameEmbeddings
default:
return ""
}
}
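
Reviewer note: `insertHttpContextMessage` now takes and returns raw bytes, so the body is decoded, modified, and re-encoded in one place. Below is a minimal sketch of that round trip, assuming a simplified request shape. The `message` and `request` types and `insertContext` are hypothetical, and only the "insert after the leading system messages" branch is modelled; the system-message merging and dummy-message handling in the diff are left out.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// message and request are simplified, hypothetical stand-ins for the qwen types.
type message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type request struct {
	Messages []message `json:"messages"`
}

// insertContext decodes body, places a system message carrying content after
// the leading system messages, and returns the re-encoded body.
func insertContext(body []byte, content string) ([]byte, error) {
	req := &request{}
	if err := json.Unmarshal(body, req); err != nil {
		return nil, fmt.Errorf("unable to unmarshal request: %v", err)
	}
	// Find the first message whose role is not "system".
	idx := len(req.Messages)
	for i, m := range req.Messages {
		if m.Role != "system" {
			idx = i
			break
		}
	}
	// Build a fresh slice so the original backing array is never aliased.
	out := make([]message, 0, len(req.Messages)+1)
	out = append(out, req.Messages[:idx]...)
	out = append(out, message{Role: "system", Content: content})
	out = append(out, req.Messages[idx:]...)
	req.Messages = out
	return json.Marshal(req)
}

func main() {
	body := []byte(`{"messages":[{"role":"system","content":"s"},{"role":"user","content":"u"}]}`)
	out, err := insertContext(body, "ctx")
	fmt.Println(string(out), err)
}
```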

View File

@@ -3,7 +3,6 @@ package provider
import (
"encoding/json"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
)
@@ -18,6 +17,13 @@ func decodeChatCompletionRequest(body []byte, request *chatCompletionRequest) er
return nil
}
func decodeEmbeddingsRequest(body []byte, request *embeddingsRequest) error {
if err := json.Unmarshal(body, request); err != nil {
return fmt.Errorf("unable to unmarshal request: %v", err)
}
return nil
}
func replaceJsonRequestBody(request interface{}, log wrapper.Log) error {
body, err := json.Marshal(request)
if err != nil {
@@ -31,6 +37,15 @@ func replaceJsonRequestBody(request interface{}, log wrapper.Log) error {
return err
}
func replaceHttpJsonRequestBody(body []byte, log wrapper.Log) error {
log.Debugf("request body: %s", string(body))
err := proxywasm.ReplaceHttpRequestBody(body)
if err != nil {
return fmt.Errorf("unable to replace the original request body: %v", err)
}
return nil
}
func insertContextMessage(request *chatCompletionRequest, content string) {
fileMessage := chatMessage{
Role: roleSystem,

View File

@@ -2,8 +2,8 @@ package provider
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
@@ -71,11 +71,7 @@ func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNam
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestHost(sparkHost)
_ = util.OverwriteRequestPath(sparkChatCompletionPath)
_ = util.OverwriteRequestAuthorization("Bearer " + p.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
p.config.handleRequestHeaders(p, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -83,36 +79,7 @@ func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
// Use the Spark protocol
if p.config.protocol == protocolOriginal {
request := &sparkRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.Model == "" {
return types.ActionContinue, errors.New("request model is empty")
}
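// Spark currently falls back to generalv3 even when the model name is wrong; we still use the model name from the request as the model name in the response.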
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
return types.ActionContinue, replaceJsonRequestBody(request, log)
} else {
// Use the OpenAI protocol
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
if request.Model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
// Map the model
mappedModel := getMappedModel(request.Model, p.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
request.Model = mappedModel
return types.ActionContinue, replaceJsonRequestBody(request, log)
}
return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log)
}
func (p *sparkProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
@@ -205,3 +172,11 @@ func (p *sparkProvider) streamResponseSpark2OpenAI(ctx wrapper.HttpContext, resp
func (p *sparkProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}
func (p *sparkProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, sparkChatCompletionPath)
util.OverwriteRequestHostHeader(headers, sparkHost)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+p.config.GetApiTokenInUse(ctx))
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
}
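
Reviewer note: the inline model check and mapping removed from `OnRequestBody` in this file now goes through the shared `parseRequestAndMapModel` / `mapModel` path. The sketch below illustrates that step through a pointer to the model field, as `mapModel` does. The `lookup` helper is an assumption made for illustration (empty table keeps the name, then exact match, then a `*` wildcard); the real `getMappedModel` body is not fully visible in this diff and may behave differently. The context writes are replaced by a returned value.

```go
package main

import (
	"errors"
	"fmt"
)

// lookup is a hypothetical mapping rule for this sketch: with no table the
// name is kept, otherwise an exact match wins, then a "*" wildcard entry.
func lookup(model string, mapping map[string]string) string {
	if len(mapping) == 0 {
		return model
	}
	if v, ok := mapping[model]; ok {
		return v
	}
	return mapping["*"]
}

// mapModel rewrites *model in place and returns the original name, which the
// plugin would instead record in the request context.
func mapModel(model *string, mapping map[string]string) (string, error) {
	if *model == "" {
		return "", errors.New("missing model in request")
	}
	original := *model
	mapped := lookup(original, mapping)
	if mapped == "" {
		return original, errors.New("model becomes empty after applying the configured mapping")
	}
	*model = mapped
	return original, nil
}

func main() {
	model := "gpt-4"
	original, err := mapModel(&model, map[string]string{"gpt-4": "generalv3"})
	fmt.Println(original, model, err)
}
```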

View File

@@ -2,12 +2,10 @@ package provider
import (
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
)
const (
@@ -45,10 +43,7 @@ func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(stepfunChatCompletionPath)
_ = util.OverwriteRequestHost(stepfunDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -56,28 +51,12 @@ func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
if m.contextCache == nil {
return types.ActionContinue, nil
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.stepfun.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.stepfun.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
func (m *stepfunProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, stepfunChatCompletionPath)
util.OverwriteRequestHostHeader(headers, stepfunDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}
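
Reviewer note: the new `TransformRequestHeaders` methods edit an `http.Header` value instead of calling proxy-wasm per header, which makes them testable without a host. A minimal sketch of that idea using only the standard library follows. `rewriteHeaders`, `newHeaders`, and the sample host, path, and token are hypothetical. Unlike `OverwriteRequestHostHeader` in the diff, which reads the original `:authority` from the live request, this sketch reads it from the map.

```go
package main

import (
	"fmt"
	"net/http"
)

// newHeaders builds a header map from alternating key/value arguments
// (hypothetical helper for this sketch; a trailing odd key is ignored).
func newHeaders(kvs ...string) http.Header {
	h := http.Header{}
	for i := 0; i+1 < len(kvs); i += 2 {
		h.Set(kvs[i], kvs[i+1])
	}
	return h
}

// rewriteHeaders points the request at a new host and path, preserves the
// originals in X-ENVOY-ORIGINAL-* headers, sets a bearer token, and drops
// Content-Length because the body may be rewritten later.
func rewriteHeaders(headers http.Header, host, path, token string) {
	if orig := headers.Get(":authority"); orig != "" {
		headers.Set("X-ENVOY-ORIGINAL-HOST", orig)
	}
	headers.Set(":authority", host)
	if orig := headers.Get(":path"); orig != "" {
		headers.Set("X-ENVOY-ORIGINAL-PATH", orig)
	}
	headers.Set(":path", path)
	headers.Set("Authorization", "Bearer "+token)
	headers.Del("Content-Length")
}

func main() {
	headers := newHeaders(":authority", "gateway.local", ":path", "/old", "Content-Length", "42")
	rewriteHeaders(headers, "api.example.com", "/v1/chat/completions", "sk-demo")
	fmt.Println(headers.Get(":authority"), headers.Get("X-ENVOY-ORIGINAL-HOST"))
}
```

`http.Header` canonicalizes keys on `Set`, `Get`, and `Del` alike, so lookups stay consistent as long as every access goes through those methods rather than direct map indexing.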

View File

@@ -2,11 +2,10 @@ package provider
import (
"errors"
"fmt"
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -45,10 +44,7 @@ func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName,
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(yiChatCompletionPath)
_ = util.OverwriteRequestHost(yiDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -56,28 +52,12 @@ func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, bod
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
if m.contextCache == nil {
return types.ActionContinue, nil
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.yi.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.yi.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
func (m *yiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, yiChatCompletionPath)
util.OverwriteRequestHostHeader(headers, yiDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}

View File

@@ -2,11 +2,11 @@ package provider
import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -44,10 +44,7 @@ func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(zhipuAiChatCompletionPath)
_ = util.OverwriteRequestHost(zhipuAiDomain)
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil
}
@@ -55,28 +52,19 @@ func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
if m.contextCache == nil {
return types.ActionContinue, nil
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.zhihupai.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.zhihupai.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
func (m *zhipuAiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, zhipuAiChatCompletionPath)
util.OverwriteRequestHostHeader(headers, zhipuAiDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}
func (m *zhipuAiProvider) GetApiName(path string) ApiName {
if strings.Contains(path, zhipuAiChatCompletionPath) {
return ApiNameChatCompletion
}
return ""
}

View File

@@ -1,6 +1,10 @@
package util
import "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
import (
"net/http"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
)
const (
HeaderContentType = "Content-Type"
@@ -21,13 +25,6 @@ func CreateHeaders(kvs ...string) [][2]string {
return headers
}
func OverwriteRequestHost(host string) error {
if originHost, err := proxywasm.GetHttpRequestHeader(":authority"); err == nil {
_ = proxywasm.ReplaceHttpRequestHeader("X-ENVOY-ORIGINAL-HOST", originHost)
}
return proxywasm.ReplaceHttpRequestHeader(":authority", host)
}
func OverwriteRequestPath(path string) error {
if originPath, err := proxywasm.GetHttpRequestHeader(":path"); err == nil {
_ = proxywasm.ReplaceHttpRequestHeader("X-ENVOY-ORIGINAL-PATH", originPath)
@@ -43,3 +40,56 @@ func OverwriteRequestAuthorization(credential string) error {
}
return proxywasm.ReplaceHttpRequestHeader("Authorization", credential)
}
func OverwriteRequestHostHeader(headers http.Header, host string) {
if originHost, err := proxywasm.GetHttpRequestHeader(":authority"); err == nil {
headers.Set("X-ENVOY-ORIGINAL-HOST", originHost)
}
headers.Set(":authority", host)
}
func OverwriteRequestPathHeader(headers http.Header, path string) {
if originPath, err := proxywasm.GetHttpRequestHeader(":path"); err == nil {
headers.Set("X-ENVOY-ORIGINAL-PATH", originPath)
}
headers.Set(":path", path)
}
func OverwriteRequestAuthorizationHeader(headers http.Header, credential string) {
if exist := headers.Get("X-HI-ORIGINAL-AUTH"); exist == "" {
if originAuth := headers.Get("Authorization"); originAuth != "" {
headers.Set("X-HI-ORIGINAL-AUTH", originAuth)
}
}
headers.Set("Authorization", credential)
}
func HeaderToSlice(header http.Header) [][2]string {
slice := make([][2]string, 0, len(header))
for key, values := range header {
for _, value := range values {
slice = append(slice, [2]string{key, value})
}
}
return slice
}
func SliceToHeader(slice [][2]string) http.Header {
header := make(http.Header)
for _, pair := range slice {
key := pair[0]
value := pair[1]
header.Add(key, value)
}
return header
}
func GetOriginalHttpHeaders() http.Header {
originalHeaders, _ := proxywasm.GetHttpRequestHeaders()
return SliceToHeader(originalHeaders)
}
func ReplaceOriginalHttpHeaders(headers http.Header) {
modifiedHeaders := HeaderToSlice(headers)
_ = proxywasm.ReplaceHttpRequestHeaders(modifiedHeaders)
}
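The header helpers added above work on an in-memory `http.Header` rather than on the live proxy-wasm request. A minimal sketch of that round trip, using only the standard library: the lowercase function names are local copies written for illustration (not the plugin's exported API), and the header values are made up.

```go
package main

import (
	"fmt"
	"net/http"
)

// sliceToHeader mirrors SliceToHeader from the diff: each pair is added to an http.Header.
func sliceToHeader(slice [][2]string) http.Header {
	header := make(http.Header)
	for _, pair := range slice {
		header.Add(pair[0], pair[1])
	}
	return header
}

// headerToSlice mirrors HeaderToSlice: one pair is emitted per header value.
func headerToSlice(header http.Header) [][2]string {
	slice := make([][2]string, 0, len(header))
	for key, values := range header {
		for _, value := range values {
			slice = append(slice, [2]string{key, value})
		}
	}
	return slice
}

// overwriteAuthorization mirrors OverwriteRequestAuthorizationHeader: the client's
// credential is saved once, then replaced by the provider credential.
func overwriteAuthorization(headers http.Header, credential string) {
	if headers.Get("X-HI-ORIGINAL-AUTH") == "" {
		if origin := headers.Get("Authorization"); origin != "" {
			headers.Set("X-HI-ORIGINAL-AUTH", origin)
		}
	}
	headers.Set("Authorization", credential)
}

func main() {
	// Hypothetical request headers, as they might come back from the proxy.
	headers := sliceToHeader([][2]string{
		{":path", "/v1/chat/completions"},
		{"Authorization", "Bearer client-token"},
	})

	// Calling the helper twice (for example after a token failover) keeps the
	// first saved credential instead of overwriting it with a provider token.
	overwriteAuthorization(headers, "Bearer provider-token-1")
	overwriteAuthorization(headers, "Bearer provider-token-2")

	fmt.Println(headers.Get("Authorization"))
	fmt.Println(headers.Get("X-HI-ORIGINAL-AUTH"))
	fmt.Println(len(headerToSlice(headers)))
}
```

Note that `http.Header.Set` and `Get` canonicalize keys, so `X-HI-ORIGINAL-AUTH` is stored as `X-Hi-Original-Auth`; lookups through `Get` are unaffected, but the slice returned by `headerToSlice` carries the canonical spelling.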

View File

@@ -6,8 +6,8 @@ replace github.com/alibaba/higress/plugins/wasm-go => ../..
require (
github.com/alibaba/higress/plugins/wasm-go v1.3.5
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
github.com/tidwall/gjson v1.14.3
github.com/higress-group/proxy-wasm-go-sdk v1.0.0
github.com/tidwall/gjson v1.17.3
)
require (

View File

@@ -5,12 +5,14 @@ github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbG
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=

View File

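@@ -30,19 +30,23 @@ description: Alibaba Cloud content security detection
| `denyCode` | int | optional | 200 | Response status code returned when the content is non-compliant |
| `denyMessage` | string | optional | OpenAI-format streaming/non-streaming response | Response content returned when the content is non-compliant |
| `protocol` | string | optional | openai | Protocol format; set to `original` for non-OpenAI protocols |
| `riskLevelBar` | string | optional | high | Risk level at which requests are blocked; one of max, high, medium, low |
A note on `denyMessage`: for OpenAI-format requests, non-compliant requests are handled as follows:
- If `denyMessage` is configured
- Prefer the answer suggested by Alibaba Cloud Content Security, formatted as an OpenAI-format streaming/non-streaming response
- If Alibaba Cloud Content Security returns no suggested answer, return the configured `denyMessage`, formatted as an OpenAI-format streaming/non-streaming response
- If `denyMessage` is not configured
- Prefer the answer suggested by Alibaba Cloud Content Security, formatted as an OpenAI-format streaming/non-streaming response
- If Alibaba Cloud Content Security returns no suggested answer, return the built-in fallback answer `"很抱歉,我无法回答您的问题"` ("Sorry, I cannot answer your question"), formatted as an OpenAI-format streaming/non-streaming response
A note on `denyMessage`: non-compliant requests are handled as follows:
- If `denyMessage` is configured, return the configured `denyMessage`, formatted as an OpenAI-format streaming/non-streaming response
- If `denyMessage` is not configured, prefer the answer suggested by Alibaba Cloud Content Security, formatted as an OpenAI-format streaming/non-streaming response
- If Alibaba Cloud Content Security returns no suggested answer, return the built-in fallback answer `"很抱歉,我无法回答您的问题"` ("Sorry, I cannot answer your question"), formatted as an OpenAI-format streaming/non-streaming response
If a non-OpenAI protocol is used, `denyMessage` should be configured; non-compliant requests are then handled as follows:
- Return the user-configured `denyMessage`; it can be set to a serialized JSON string so that it matches the format returned by the normal request interface
- If `denyMessage` is empty, prefer the answer suggested by Alibaba Cloud Content Security, as plain text
- If Alibaba Cloud Content Security returns no suggested answer, return the built-in fallback answer `"很抱歉,我无法回答您的问题"` ("Sorry, I cannot answer your question"), as plain text
If a non-OpenAI protocol is used, non-compliant requests are handled as follows:
- If `denyMessage` is configured, return the user-configured `denyMessage` as a non-streaming response
- If `denyMessage` is not configured, prefer the answer suggested by Alibaba Cloud Content Security, as a non-streaming response
- If Alibaba Cloud Content Security returns no suggested answer, return the built-in fallback answer `"很抱歉,我无法回答您的问题"` ("Sorry, I cannot answer your question"), as a non-streaming response
A note on the four `riskLevelBar` levels:
- `max`: request/response content is checked, but nothing is blocked
- `high`: block when the content security result reports risk level `high`
- `medium`: block when the content security result reports risk level >= `medium`
- `low`: block when the content security result reports risk level >= `low`
## Configuration example
### Prerequisites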

View File

@@ -7,7 +7,7 @@ replace github.com/alibaba/higress/plugins/wasm-go => ../..
require (
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
github.com/tidwall/gjson v1.14.3
github.com/tidwall/gjson v1.17.3
)
require (

View File

@@ -9,8 +9,8 @@ github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=

View File

@@ -35,12 +35,16 @@ func main() {
}
const (
OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":%s,"choices":[{"index":0,"message":{"role":"assistant","content":%s},"logprobs":null,"finish_reason":"stop"}]}`
OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":%s,"choices":[{"index":0,"delta":{"role":"assistant","content":%s},"logprobs":null,"finish_reason":null}]}`
OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model": %s,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`
OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]`
MaxRisk = "max"
HighRisk = "high"
MediumRisk = "medium"
LowRisk = "low"
NoRisk = "none"
TracingPrefix = "trace_span_tag."
OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"%s","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}]}`
OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"%s","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"%s","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`
OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]`
DefaultRequestCheckService = "llm_query_moderation"
DefaultResponseCheckService = "llm_response_moderation"
@@ -53,10 +57,37 @@ const (
AliyunUserAgent = "CIPFrom/AIGateway"
)
type Response struct {
Code int `json:"Code"`
Message string `json:"Message"`
RequestId string `json:"RequestId"`
Data Data `json:"Data"`
}
type Data struct {
RiskLevel string `json:"RiskLevel"`
Result []Result `json:"Result,omitempty"`
Advice []Advice `json:"Advice,omitempty"`
}
type Result struct {
RiskWords string `json:"RiskWords,omitempty"`
Description string `json:"Description,omitempty"`
Confidence float64 `json:"Confidence,omitempty"`
Label string `json:"Label,omitempty"`
}
type Advice struct {
Answer string `json:"Answer,omitempty"`
HitLabel string `json:"HitLabel,omitempty"`
HitLibName string `json:"HitLibName,omitempty"`
}
type AISecurityConfig struct {
client wrapper.HttpClient
ak string
sk string
token string
checkRequest bool
requestCheckService string
requestContentJsonPath string
@@ -67,6 +98,7 @@ type AISecurityConfig struct {
denyCode int64
denyMessage string
protocolOriginal bool
riskLevelBar string
metrics map[string]proxywasm.MetricCounter
}
@@ -79,12 +111,31 @@ func (config *AISecurityConfig) incrementCounter(metricName string, inc uint64)
counter.Increment(inc)
}
func riskLevelToInt(riskLevel string) int {
switch riskLevel {
case MaxRisk:
return 4
case HighRisk:
return 3
case MediumRisk:
return 2
case LowRisk:
return 1
case NoRisk:
return 0
default:
return -1
}
}
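The integer mapping above is what lets the plugin compare a detected risk level against the configured `riskLevelBar`. A small sketch of that comparison follows; `shouldBlock` is a hypothetical helper written for illustration, and the plugin itself inlines the `<` check in its callbacks.

```go
package main

import "fmt"

// riskLevelToInt is copied from the diff, with the constants inlined as string literals.
func riskLevelToInt(riskLevel string) int {
	switch riskLevel {
	case "max":
		return 4
	case "high":
		return 3
	case "medium":
		return 2
	case "low":
		return 1
	case "none":
		return 0
	default:
		return -1
	}
}

// shouldBlock is a hypothetical helper: a request is blocked only when the
// detected level is at least as severe as the configured bar.
func shouldBlock(detected, bar string) bool {
	return riskLevelToInt(detected) >= riskLevelToInt(bar)
}

func main() {
	for _, bar := range []string{"max", "high", "medium", "low"} {
		for _, detected := range []string{"none", "low", "medium", "high"} {
			fmt.Printf("bar=%-6s detected=%-6s block=%v\n", bar, detected, shouldBlock(detected, bar))
		}
	}
}
```

With the bar set to `max`, no detected level from `none` to `high` reaches 4, which matches the documented "check but never block" behavior. An unrecognized detected level maps to -1 and is therefore never blocked.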
func urlEncoding(rawStr string) string {
encodedStr := url.PathEscape(rawStr)
encodedStr = strings.ReplaceAll(encodedStr, "+", "%2B")
encodedStr = strings.ReplaceAll(encodedStr, ":", "%3A")
encodedStr = strings.ReplaceAll(encodedStr, "=", "%3D")
encodedStr = strings.ReplaceAll(encodedStr, "&", "%26")
encodedStr = strings.ReplaceAll(encodedStr, "$", "%24")
encodedStr = strings.ReplaceAll(encodedStr, "@", "%40")
return encodedStr
}
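To see why the extra `$` and `@` replacements matter, the function can be run on its own. `url.PathEscape` leaves a handful of reserved characters (`$ & + : = @`) untouched, so the manual replacements appear intended to produce an encoding where only RFC 3986 unreserved characters stay literal. The sketch below copies the function as shown in the diff; the sample input is made up.

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// urlEncoding is copied from the diff: PathEscape first, then the reserved
// characters that PathEscape leaves as-is are percent-encoded by hand.
func urlEncoding(rawStr string) string {
	encodedStr := url.PathEscape(rawStr)
	encodedStr = strings.ReplaceAll(encodedStr, "+", "%2B")
	encodedStr = strings.ReplaceAll(encodedStr, ":", "%3A")
	encodedStr = strings.ReplaceAll(encodedStr, "=", "%3D")
	encodedStr = strings.ReplaceAll(encodedStr, "&", "%26")
	encodedStr = strings.ReplaceAll(encodedStr, "$", "%24")
	encodedStr = strings.ReplaceAll(encodedStr, "@", "%40")
	return encodedStr
}

func main() {
	// A made-up value containing every character handled explicitly above.
	fmt.Println(urlEncoding("a b+c:d=e&f$g@h"))
	// Unreserved characters pass through unchanged.
	fmt.Println(urlEncoding("abc-_.~123"))
}
```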
@@ -130,6 +181,7 @@ func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) e
if config.ak == "" || config.sk == "" {
return errors.New("invalid AK/SK config")
}
config.token = json.Get("securityToken").String()
config.checkRequest = json.Get("checkRequest").Bool()
config.checkResponse = json.Get("checkResponse").Bool()
config.protocolOriginal = json.Get("protocol").String() == "original"
@@ -164,6 +216,14 @@ func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) e
} else {
config.responseStreamContentJsonPath = DefaultStreamingResponseJsonPath
}
if obj := json.Get("riskLevelBar"); obj.Exists() {
config.riskLevelBar = obj.String()
if riskLevelToInt(config.riskLevelBar) <= 0 {
return errors.New("invalid risk level, value must be one of [max, high, medium, low]")
}
} else {
config.riskLevelBar = HighRisk
}
config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: serviceName,
Port: servicePort,
@@ -192,105 +252,82 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
log.Debugf("checking request body...")
content := gjson.GetBytes(body, config.requestContentJsonPath).Raw
model := gjson.GetBytes(body, "model").Raw
content := gjson.GetBytes(body, config.requestContentJsonPath).String()
model := gjson.GetBytes(body, "model").String()
ctx.SetContext("requestModel", model)
log.Debugf("Raw request content is: %s", content)
if len(content) > 0 {
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
randomID, _ := generateHexID(16)
params := map[string]string{
"Format": "JSON",
"Version": "2022-03-02",
"SignatureMethod": "Hmac-SHA1",
"SignatureNonce": randomID,
"SignatureVersion": "1.0",
"Action": "TextModerationPlus",
"AccessKeyId": config.ak,
"Timestamp": timestamp,
"Service": config.requestCheckService,
"ServiceParameters": fmt.Sprintf(`{"content": %s}`, content),
}
signature := getSign(params, config.sk+"&")
reqParams := url.Values{}
for k, v := range params {
reqParams.Add(k, v)
}
reqParams.Add("Signature", signature)
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
if statusCode != 200 {
log.Error(string(responseBody))
proxywasm.ResumeHttpRequest()
return
}
respData := gjson.GetBytes(responseBody, "Data")
if respData.Exists() {
respAdvice := respData.Get("Advice")
respResult := respData.Get("Result")
var denyMessage string
messageNeedSerialization := true
if config.protocolOriginal {
// not openai
if config.denyMessage != "" {
denyMessage = config.denyMessage
} else if respAdvice.Exists() {
denyMessage = respAdvice.Array()[0].Get("Answer").Raw
messageNeedSerialization = false
} else {
denyMessage = DefaultDenyMessage
}
} else {
// openai
if respAdvice.Exists() {
denyMessage = respAdvice.Array()[0].Get("Answer").Raw
messageNeedSerialization = false
} else if config.denyMessage != "" {
denyMessage = config.denyMessage
} else {
denyMessage = DefaultDenyMessage
}
}
if messageNeedSerialization {
if data, err := json.Marshal(denyMessage); err == nil {
denyMessage = string(data)
} else {
denyMessage = fmt.Sprintf("\"%s\"", DefaultDenyMessage)
}
}
if respResult.Array()[0].Get("Label").String() != "nonLabel" {
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("request"))
config.incrementCounter("ai_sec_request_deny", 1)
if config.protocolOriginal {
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(denyMessage), -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, denyMessage, randomID, model))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, denyMessage))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
ctx.DontReadResponseBody()
} else {
proxywasm.ResumeHttpRequest()
}
} else {
proxywasm.ResumeHttpRequest()
}
},
)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
return types.ActionContinue
}
return types.ActionPause
} else {
log.Debugf("request content is empty. skip")
if len(content) == 0 {
log.Info("request content is empty. skip")
return types.ActionContinue
}
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
randomID, _ := generateHexID(16)
params := map[string]string{
"Format": "JSON",
"Version": "2022-03-02",
"SignatureMethod": "Hmac-SHA1",
"SignatureNonce": randomID,
"SignatureVersion": "1.0",
"Action": "TextModerationPlus",
"AccessKeyId": config.ak,
"Timestamp": timestamp,
"Service": config.requestCheckService,
"ServiceParameters": fmt.Sprintf(`{"content": "%s"}`, marshalStr(content, log)),
}
if config.token != "" {
params["SecurityToken"] = config.token
}
signature := getSign(params, config.sk+"&")
reqParams := url.Values{}
for k, v := range params {
reqParams.Add(k, v)
}
reqParams.Add("Signature", signature)
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
proxywasm.ResumeHttpRequest()
return
}
var response Response
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Error("failed to unmarshal aliyun content security response at request phase")
proxywasm.ResumeHttpRequest()
return
}
if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) {
proxywasm.ResumeHttpRequest()
return
}
denyMessage := DefaultDenyMessage
if config.denyMessage != "" {
denyMessage = config.denyMessage
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
denyMessage = response.Data.Advice[0].Answer
}
marshalledDenyMessage := marshalStr(denyMessage, log)
if config.protocolOriginal {
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
ctx.DontReadResponseBody()
config.incrementCounter("ai_sec_request_deny", 1)
},
)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
return types.ActionContinue
}
return types.ActionPause
}
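The new callback picks the deny message in a fixed order: the configured `denyMessage`, then the first suggested answer from the content security response, then the built-in default. A minimal sketch of that priority is below. `chooseDenyMessage` is a hypothetical helper, the default text is taken from the README hunk rather than from the constant itself, and the sketch checks `len(advice) > 0` instead of `advice != nil`, which is an assumption about how an empty `Advice` array should be treated.

```go
package main

import "fmt"

// defaultDenyMessage is assumed to match the fallback text quoted in the README.
const defaultDenyMessage = "很抱歉,我无法回答您的问题"

// Advice keeps only the field this sketch needs from the response type in the diff.
type Advice struct {
	Answer string
}

// chooseDenyMessage is a hypothetical helper implementing the priority:
// configured message, then first suggested answer, then the default.
func chooseDenyMessage(configured string, advice []Advice) string {
	if configured != "" {
		return configured
	}
	// Checking the length also covers an empty, non-nil slice (a JSON "[]").
	if len(advice) > 0 && advice[0].Answer != "" {
		return advice[0].Answer
	}
	return defaultDenyMessage
}

func main() {
	fmt.Println(chooseDenyMessage("blocked by policy", []Advice{{Answer: "suggested"}}))
	fmt.Println(chooseDenyMessage("", []Advice{{Answer: "suggested"}}))
	fmt.Println(chooseDenyMessage("", []Advice{}))
}
```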
func convertHeaders(hs [][2]string) map[string][]string {
@@ -341,92 +378,81 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
if isStreamingResponse {
content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath)
} else {
content = gjson.GetBytes(body, config.responseContentJsonPath).Raw
content = gjson.GetBytes(body, config.responseContentJsonPath).String()
}
log.Debugf("Raw response content is: %s", content)
if len(content) > 0 {
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
randomID, _ := generateHexID(16)
params := map[string]string{
"Format": "JSON",
"Version": "2022-03-02",
"SignatureMethod": "Hmac-SHA1",
"SignatureNonce": randomID,
"SignatureVersion": "1.0",
"Action": "TextModerationPlus",
"AccessKeyId": config.ak,
"Timestamp": timestamp,
"Service": config.responseCheckService,
"ServiceParameters": fmt.Sprintf(`{"content": %s}`, content),
}
signature := getSign(params, config.sk+"&")
reqParams := url.Values{}
for k, v := range params {
reqParams.Add(k, v)
}
reqParams.Add("Signature", signature)
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
defer proxywasm.ResumeHttpResponse()
if statusCode != 200 {
log.Error(string(responseBody))
return
}
respData := gjson.GetBytes(responseBody, "Data")
if respData.Exists() {
respAdvice := respData.Get("Advice")
respResult := respData.Get("Result")
var denyMessage string
if config.protocolOriginal {
// not openai
if config.denyMessage != "" {
denyMessage = config.denyMessage
} else if respAdvice.Exists() {
denyMessage = respAdvice.Array()[0].Get("Answer").Raw
} else {
denyMessage = DefaultDenyMessage
}
} else {
// openai
if respAdvice.Exists() {
denyMessage = respAdvice.Array()[0].Get("Answer").Raw
} else if config.denyMessage != "" {
denyMessage = config.denyMessage
} else {
denyMessage = DefaultDenyMessage
}
}
if respResult.Array()[0].Get("Label").String() != "nonLabel" {
var jsonData []byte
if config.protocolOriginal {
jsonData = []byte(denyMessage)
} else if isStreamingResponse {
randomID := generateRandomID()
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, denyMessage, randomID, model))
} else {
randomID := generateRandomID()
jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, denyMessage))
}
delete(hdsMap, "content-length")
hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)}
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
proxywasm.ReplaceHttpResponseBody(jsonData)
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_risklabel"}, []byte(respResult.Array()[0].Get("Label").String()))
proxywasm.SetProperty([]string{TracingPrefix, "ai_sec_deny_phase"}, []byte("response"))
config.incrementCounter("ai_sec_response_deny", 1)
}
}
},
)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
return types.ActionContinue
}
return types.ActionPause
} else {
log.Debugf("request content is empty. skip")
if len(content) == 0 {
log.Info("response content is empty. skip")
return types.ActionContinue
}
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
randomID, _ := generateHexID(16)
params := map[string]string{
"Format": "JSON",
"Version": "2022-03-02",
"SignatureMethod": "Hmac-SHA1",
"SignatureNonce": randomID,
"SignatureVersion": "1.0",
"Action": "TextModerationPlus",
"AccessKeyId": config.ak,
"Timestamp": timestamp,
"Service": config.responseCheckService,
"ServiceParameters": fmt.Sprintf(`{"content": "%s"}`, marshalStr(content, log)),
}
if config.token != "" {
params["SecurityToken"] = config.token
}
signature := getSign(params, config.sk+"&")
reqParams := url.Values{}
for k, v := range params {
reqParams.Add(k, v)
}
reqParams.Add("Signature", signature)
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
defer proxywasm.ResumeHttpResponse()
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
return
}
var response Response
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Error("failed to unmarshal aliyun content security response at response phase")
return
}
if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) {
return
}
denyMessage := DefaultDenyMessage
if config.denyMessage != "" {
denyMessage = config.denyMessage
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
denyMessage = response.Data.Advice[0].Answer
}
marshalledDenyMessage := marshalStr(denyMessage, log)
var jsonData []byte
if config.protocolOriginal {
jsonData = []byte(marshalledDenyMessage)
} else if isStreamingResponse {
randomID := generateRandomID()
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model))
} else {
randomID := generateRandomID()
jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage))
}
delete(hdsMap, "content-length")
hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)}
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
proxywasm.ReplaceHttpResponseBody(jsonData)
config.incrementCounter("ai_sec_response_deny", 1)
},
)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
return types.ActionContinue
}
return types.ActionPause
}
func extractMessageFromStreamingBody(data []byte, jsonPath string) string {
@@ -434,10 +460,21 @@ func extractMessageFromStreamingBody(data []byte, jsonPath string) string {
strChunks := []string{}
for _, chunk := range chunks {
// Example: "choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]
jsonRaw := gjson.GetBytes(chunk, jsonPath).Raw
if len(jsonRaw) > 2 {
strChunks = append(strChunks, jsonRaw[1:len(jsonRaw)-1])
}
strChunks = append(strChunks, gjson.GetBytes(chunk, jsonPath).String())
}
return strings.Join(strChunks, "")
}
func marshalStr(raw string, log wrapper.Log) string {
helper := map[string]string{
"placeholder": raw,
}
marshalledHelper, _ := json.Marshal(helper)
marshalledRaw := gjson.GetBytes(marshalledHelper, "placeholder").Raw
if len(marshalledRaw) >= 2 {
return marshalledRaw[1 : len(marshalledRaw)-1]
} else {
log.Errorf("failed to marshal json string, raw string is: %s", raw)
return ""
}
return fmt.Sprintf(`"%s"`, strings.Join(strChunks, ""))
}

View File

@@ -6,8 +6,8 @@ replace github.com/alibaba/higress/plugins/wasm-go => ../..
require (
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
github.com/tidwall/gjson v1.14.3
github.com/higress-group/proxy-wasm-go-sdk v1.0.0
github.com/tidwall/gjson v1.17.3
)
require (

View File

@@ -5,12 +5,14 @@ github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbG
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=

View File

@@ -6,8 +6,8 @@ replace github.com/alibaba/higress/plugins/wasm-go => ../..
require (
github.com/alibaba/higress/plugins/wasm-go v1.4.1-0.20240617024146-5f150179637c
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
github.com/tidwall/gjson v1.14.3
github.com/higress-group/proxy-wasm-go-sdk v1.0.0
github.com/tidwall/gjson v1.17.3
github.com/wasilibs/go-re2 v1.5.3
github.com/zmap/go-iptree v0.0.0-20210731043055-d4e632617837
)

View File

@@ -7,6 +7,7 @@ github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbG
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@@ -15,6 +16,7 @@ github.com/tetratelabs/wazero v1.7.1 h1:QtSfd6KLc41DIMpDYlJdoMc6k7QTN246DM2+n2Y/
github.com/tetratelabs/wazero v1.7.1/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=

View File

@@ -6,8 +6,8 @@ replace github.com/alibaba/higress/plugins/wasm-go => ../..
require (
github.com/alibaba/higress/plugins/wasm-go v1.4.0
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
github.com/tidwall/gjson v1.14.3
github.com/higress-group/proxy-wasm-go-sdk v1.0.0
github.com/tidwall/gjson v1.17.3
)
require (

View File

@@ -12,6 +12,7 @@ github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@@ -19,6 +20,7 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=

Some files were not shown because too many files have changed in this diff