Compare commits

...

56 Commits

Author SHA1 Message Date
澄潭
ce298054f1 Release 2.1.10 (#3447) 2026-02-03 19:13:11 +08:00
TianHao Zhang
24c69fb0b7 fix concurrent SSE connections returning wrong endpoint (#3341) 2026-01-19 11:56:57 +08:00
johnlanni
a38be77b9e Fix the issue of backend errors not being propagated in streamable proxy mode 2026-01-19 11:56:57 +08:00
johnlanni
27999dcc59 fix(mcp): remove accept-encoding header to prevent response compression 2026-01-19 11:56:56 +08:00
aias00
811179a6a0 fix: skip unhealthy or disabled services form nacos and always marshal AllowTools field (#3220)
Co-authored-by: EricaLiu <30773688+Erica177@users.noreply.github.com>
2026-01-19 11:56:56 +08:00
woody
5f43dd0224 feat/ai proxy vertex ai compatible (#3324) 2026-01-19 11:56:56 +08:00
rinfx
e23ab3ca7c Replace model-router and model-mapper with Go implementation (#3317) 2026-01-19 11:54:41 +08:00
woody
032a69556f feat(vertex): 为 ai-proxy 插件的 Vertex AI Provider 添加 Express Mode 支持 || feat(vertex): Add Express Mode support to Vertex AI Provider of ai-proxy plug-in (#3301) 2026-01-19 11:54:41 +08:00
qshuai
ee6bb11730 docs: unknown config entry <show_limit_quota_header> in ai-token-ratelimit plugin (#3241) 2026-01-19 11:54:41 +08:00
CZJCC
fc600f204a feat(ai-proxy): add Bearer Token authentication support for Bedrock p… (#3305) 2026-01-19 11:54:41 +08:00
澄潭
357418853f Update README.md 2026-01-19 11:54:41 +08:00
澄潭
e8586cccd7 Update README.md 2026-01-19 11:54:40 +08:00
nixidexiangjiao
d55b9a0837 feat(ai-load-balancer): enhance global least request load balancer (#3255) 2026-01-19 11:54:28 +08:00
johnlanni
4f04ac067b update helm README.md
Change-Id: Ic216d36c4cb0e570c9084b63c9f250c9ab6f4cec
2026-01-19 11:54:27 +08:00
Wilson Wu
c7028bd7f2 feat: add topology spread constraints for gateway and controller (#3171)
Signed-off-by: Wilson Wu <iwilsonwu@gmail.com>
2026-01-19 11:54:27 +08:00
Kent Dong
95ff52cde9 feat: Add traffic-editor plugin (#2825) 2026-01-19 11:54:27 +08:00
steven
7c7205b572 fix(helm,podmonitor): add podMonitorSelector for gateway metrics configuration (#3022) 2026-01-19 11:54:27 +08:00
Jingze
f342f50ca4 feat: Add response-cache plugin (#3061)
Co-authored-by: mirror58229 <674958229@qq.com>
2026-01-19 11:54:27 +08:00
github-actions[bot]
659d136bfe Update CRD file in the helm folder (#3155)
Co-authored-by: johnlanni <6763318+johnlanni@users.noreply.github.com>
2026-01-19 11:54:26 +08:00
澄潭
541e5e206f fix(mcp-server): fix MCP server version negotiation to comply with spec (#3258) 2026-01-19 11:54:26 +08:00
rinfx
387c337654 support disable thinking and add reasoning token usage (#3261) 2026-01-19 11:54:14 +08:00
Bingkun Zhao
8024a96881 fix: ai-proxy dify provider extract hostname from difyApiUrl (#3257) 2026-01-19 11:54:14 +08:00
firebook
f71c1900a8 upgrade vipshop Description of Use in ADOPTERS.md (#3250) 2026-01-19 11:54:14 +08:00
rinfx
1199946d36 special handling for cases where extracted content is empty and add unit test (#3251) 2026-01-19 11:54:14 +08:00
rinfx
b1571de6f0 Cross provider lb bugfix (#3252) 2026-01-19 11:54:13 +08:00
zzjin
20dae295a8 add: Include labring as an adopter in ADOPTERS.md (#3249)
Signed-off-by: zzjin <tczzjin@gmail.com>
2026-01-19 11:54:13 +08:00
Maple Lee
9a1f9e4606 Add kuaishou to ADOPTERS.md (#3244) 2026-01-19 11:54:13 +08:00
Wangzy
6f4ef33590 Add tool-search server (#3136)
Co-authored-by: 澄潭 <zty98751@alibaba-inc.com>
2026-01-19 11:54:13 +08:00
Kent Dong
fef8ecc822 fix: Switch to the new HasRequestBody logic in ai-proxy (#3211) 2026-01-19 11:53:58 +08:00
rinfx
0ade9504be [feat] ai-security-guard support checking prompt and image in request body (#3206) 2026-01-19 11:53:58 +08:00
rinfx
6311fecfce add rebuild logic for ai-cache (#3185) 2026-01-19 11:53:58 +08:00
Kent Dong
5c225de080 fix: Enlarge the request body buffer size when processing multipart data in model-router (#3237) 2026-01-19 11:53:58 +08:00
rinfx
bf9ef5eefd support vertex's claude (#3236) 2026-01-19 11:53:57 +08:00
澄潭
26f5737a80 Update ADOPTERS.md 2026-01-19 11:53:57 +08:00
firebook
50c1a5e78c Add vipshop to ADOPTERS.md (#3234) 2026-01-19 11:53:57 +08:00
Kent Dong
647304eb45 doc: Add Trip.com to the adopters list (#3233) 2026-01-19 11:53:57 +08:00
澄潭
0a7fc9f412 Add ADOPTERS.md to document project adopters (#3231) 2026-01-19 11:53:44 +08:00
007gzs
c9253264ef Rust Plugin add Rule matcher test (#3230) 2026-01-19 11:53:43 +08:00
woody
8c80084ada fix(ai-proxy): ensure basePathHandling works with original protocol (#3225) 2026-01-19 11:53:43 +08:00
澄潭
9f5ee99c2d Update README.md 2026-01-19 11:53:43 +08:00
澄潭
3770bd2f55 Update README.md 2026-01-19 11:53:43 +08:00
rinfx
698a395e89 vertex support global region (#3213) 2026-01-19 11:53:43 +08:00
澄潭
2c72767203 feat: enhance model mapper and router with rebuild triggers and path extensions (#3218) 2026-01-19 11:53:43 +08:00
Liang Deng
bb3ac59834 feat(ai-proxy): support handle array content in chatToolMessage2BedrockMessage (#3200)
Signed-off-by: Liang Deng <ytdengliang@gmail.com>
Co-authored-by: rinfx <yucheng.lxr@alibaba-inc.com>
2026-01-19 11:53:42 +08:00
johnlanni
6c1fe57034 fix(ai-proxy): only perform protocol conversion for non-original protocols
Change-Id: Ib8ae3ebf6b47284108663c97777032d6282bb53c
2026-01-19 11:53:25 +08:00
johnlanni
5c5cc6ac90 update go sum 2026-01-19 11:53:25 +08:00
johnlanni
265da8e4d6 add wrapper.WithRebuildMaxMemBytes(200MB) to ai-statistics&ai-proxy 2026-01-19 11:53:25 +08:00
johnlanni
119698eea4 update wasm-go dep of mcp-server 2026-01-19 11:53:25 +08:00
rinfx
18d20ca135 doubao support configuration for domain (#3184) 2026-01-19 11:53:24 +08:00
rinfx
9978db2ac6 [feat] ai-security-guard refactor & support checking multimoadl input (#3075) 2026-01-19 11:53:24 +08:00
Kent Dong
1582fa6ef9 fix: Bypass the response body processing for MCP streamable transport (#3187) 2026-01-19 11:53:24 +08:00
woody
2b49fd5b26 implement generic provider for vendor-agnostic passthrough (#3175) 2026-01-19 11:53:24 +08:00
woody
48433a6549 Fix OpenAI capability rewrite dropping query string (#3168) 2026-01-19 11:52:50 +08:00
rinfx
8ec48b3b85 [feat] load balancing across different clusters and endpoints based on metrics (#3063) 2026-01-19 11:52:43 +08:00
rinfx
32007d2ab8 remove omitempty for toolcall index (#3148)
Co-authored-by: 澄潭 <zty98751@alibaba-inc.com>
2026-01-19 11:52:33 +08:00
澄潭
27b088fc7e Update .licenserc.yaml 2026-01-19 11:52:26 +08:00
153 changed files with 15375 additions and 1456 deletions

View File

@@ -35,6 +35,7 @@ header:
- 'hgctl/pkg/manifests'
- 'pkg/ingress/kube/gateway/istio/testdata'
- 'release-notes/**'
- '.cursor/**'
comment: on-failure
dependency:

13
ADOPTERS.md Normal file
View File

@@ -0,0 +1,13 @@
# Adopters of Higress
Below are the adopters of the Higress project. If you are using Higress in your organization, please add your name to the list by submitting a pull request: this will help foster the Higress community. Kindly ensure the list remains in alphabetical order.
| Organization | Contact (GitHub User Name) | Environment | Description of Use |
|---------------------------------------|----------------------------------------|--------------------------------------------|-----------------------------------------------------------------------|
| [antdigital](https://antdigital.com/) | [@Lovelcp](https://github.com/Lovelcp) | Production | Ingress Gateway, Microservice gateway, LLM Gateway, MCP Gateway |
| [kuaishou](https://ir.kuaishou.com/) | [@maplecap](https://github.com/maplecap) | Production | LLM Gateway |
| [labring](https://github.com/labring/) | [@zzjin](https://github.com/zzjin) | Production | Ingress Gateway |
| [Trip.com](https://www.trip.com/) | [@CH3CHO](https://github.com/CH3CHO) | Production | LLM Gateway, MCP Gateway |
| [vipshop](https://github.com/vipshop/) | [@firebook](https://github.com/firebook) | Production | LLM Gateway, MCP Gateway, Inference Gateway |
| < company name here> | < your github handle here > | <Production/Testing/Experimenting/etc> | <Ingress Gateway/Microservice gateway/LLM Gateway/MCP Gateway/Inference Gateway> |

View File

@@ -146,7 +146,7 @@ docker-buildx-push: clean-env docker.higress-buildx
export PARENT_GIT_TAG:=$(shell cat VERSION)
export PARENT_GIT_REVISION:=$(TAG)
export ENVOY_PACKAGE_URL_PATTERN?=https://github.com/higress-group/proxy/releases/download/v2.2.0/envoy-symbol-ARCH.tar.gz
export ENVOY_PACKAGE_URL_PATTERN?=https://github.com/higress-group/proxy/releases/download/v2.1.10/envoy-symbol-ARCH.tar.gz
build-envoy: prebuild
./tools/hack/build-envoy.sh

View File

@@ -45,7 +45,7 @@ Higress was born within Alibaba to solve the issues of Tengine reload affecting
You can click the button below to install the enterprise version of Higress:
[![Deploy on AlibabaCloud](https://img.alicdn.com/imgextra/i1/O1CN01e6vwe71EWTHoZEcpK_!!6000000000359-55-tps-170-40.svg)](https://www.aliyun.com/product/apigateway?spm=higress-github.topbar.0.0.0)
[![Deploy on AlibabaCloud](https://img.alicdn.com/imgextra/i1/O1CN01e6vwe71EWTHoZEcpK_!!6000000000359-55-tps-170-40.svg)](https://www.aliyun.com/product/api-gateway?spm=higress-github.topbar.0.0.0)
If you use open-source Higress and wish to obtain enterprise-level support, you can contact the project maintainer johnlanni's email: **zty98751@alibaba-inc.com** or social media accounts (WeChat ID: **nomadao**, DingTalk ID: **chengtanzty**). Please note **Higress** when adding as a friend :)
@@ -82,6 +82,8 @@ Port descriptions:
>
> If you experience a timeout when pulling image from `higress-registry.cn-hangzhou.cr.aliyuncs.com`, you can try replacing it with the following docker registry mirror source:
>
> **North America**: `higress-registry.us-west-1.cr.aliyuncs.com`
>
> **Southeast Asia**: `higress-registry.ap-southeast-7.cr.aliyuncs.com`
For other installation methods such as Helm deployment under K8s, please refer to the official [Quick Start documentation](https://higress.io/en-us/docs/user/quickstart).
@@ -117,7 +119,16 @@ If you are deploying on the cloud, it is recommended to use the [Enterprise Edit
Higress can function as a feature-rich ingress controller, which is compatible with many annotations of K8s' nginx ingress controller.
[Gateway API](https://gateway-api.sigs.k8s.io/) support is coming soon and will support smooth migration from Ingress API to Gateway API.
[Gateway API](https://gateway-api.sigs.k8s.io/) is already supported, and it supports a smooth migration from Ingress API to Gateway API.
Compared to ingress-nginx, the resource overhead has significantly decreased, and the speed at which route changes take effect has improved by ten times.
> The following resource overhead comparison comes from [sealos](https://github.com/labring).
>
> For details, you can read this [article](https://sealos.io/blog/sealos-envoy-vs-nginx-2000-tenants) to understand how sealos migrates the monitoring of **tens of thousands of ingress** resources from nginx ingress to higress.
![](https://img.alicdn.com/imgextra/i1/O1CN01bhEtb229eeMNBWmdP_!!6000000008093-2-tps-750-547.png)
- **Microservice gateway**:

View File

@@ -1 +1 @@
v2.1.9
v2.1.10

View File

@@ -1,5 +1,5 @@
apiVersion: v2
appVersion: 2.1.9
appVersion: 2.1.10
description: Helm chart for deploying higress gateways
icon: https://higress.io/img/higress_logo_small.png
home: http://higress.io/
@@ -15,4 +15,4 @@ dependencies:
repository: "file://../redis"
version: 0.0.1
type: application
version: 2.1.9
version: 2.1.10

View File

@@ -3,7 +3,8 @@
# Declare variables to be passed into your templates.
global:
# -- Specify the image registry and pull policy
hub: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress
# Will inherit from parent chart's global.hub if not set
hub: ""
# -- Specify image pull policy if default behavior isn't desired.
# Default behavior: latest images will be Always else IfNotPresent.
imagePullPolicy: ""

View File

@@ -71,6 +71,11 @@ spec:
items:
type: string
type: array
routeType:
enum:
- HTTP
- GRPC
type: string
service:
items:
type: string

View File

@@ -203,7 +203,7 @@ template:
{{- if $o11y.enabled }}
{{- $config := $o11y.promtail }}
- name: promtail
image: {{ $config.image.repository }}:{{ $config.image.tag }}
image: {{ $config.image.repository | default (printf "%s/promtail" .Values.global.hub) }}:{{ $config.image.tag }}
imagePullPolicy: IfNotPresent
args:
- -config.file=/etc/promtail/promtail.yaml
@@ -250,6 +250,10 @@ template:
tolerations:
{{- toYaml . | nindent 6 }}
{{- end }}
{{- with .Values.gateway.topologySpreadConstraints }}
topologySpreadConstraints:
{{- toYaml . | nindent 6 }}
{{- end }}
volumes:
- emptyDir: {}
name: workload-socket

View File

@@ -289,6 +289,10 @@ spec:
tolerations:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.controller.topologySpreadConstraints }}
topologySpreadConstraints:
{{- toYaml . | nindent 8 }}
{{- end }}
volumes:
- name: log
emptyDir: {}

View File

@@ -5,6 +5,9 @@ metadata:
namespace: {{ .Release.Namespace }}
labels:
{{- include "gateway.labels" . | nindent 4}}
{{- with .Values.gateway.metrics.podMonitorSelector }}
{{- toYaml . | nindent 4 }}
{{- end }}
annotations:
{{- .Values.gateway.annotations | toYaml | nindent 4 }}
spec:

View File

@@ -24,9 +24,6 @@ spec:
{{- end }}
{{- with .Values.gateway.service.externalTrafficPolicy }}
externalTrafficPolicy: "{{ . }}"
{{- end }}
{{- with .Values.gateway.service.loadBalancerClass}}
loadBalancerClass: "{{ . }}"
{{- end }}
type: {{ .Values.gateway.service.type }}
ports:

View File

@@ -362,7 +362,7 @@ global:
enabled: false
promtail:
image:
repository: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/promtail
repository: "" # Will use global.hub if not set
tag: 2.9.4
port: 3101
resources:
@@ -377,7 +377,7 @@ global:
# The default value is "" and when caName="", the CA will be configured by other
# mechanisms (e.g., environmental variable CA_PROVIDER).
caName: ""
hub: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress
hub: "" # Will use global.hub if not set
clusterName: ""
# -- meshConfig defines runtime configuration of components, including Istiod and istio-agent behavior
@@ -433,7 +433,7 @@ gateway:
# -- The readiness timeout seconds
readinessTimeoutSeconds: 3
hub: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress
hub: "" # Will use global.hub if not set
tag: ""
# -- revision declares which revision this gateway is a part of
revision: ""
@@ -522,12 +522,19 @@ gateway:
affinity: {}
topologySpreadConstraints: []
# -- If specified, the gateway will act as a network gateway for the given network.
networkGateway: ""
metrics:
# -- If true, create PodMonitor or VMPodScrape for gateway
enabled: false
# -- Selector for PodMonitor
# When using monitoring.coreos.com/v1.PodMonitor, the selector must match
# the label "release: kube-prome" is the default for kube-prometheus-stack
podMonitorSelector:
release: kube-prome
# -- provider group name for CustomResourceDefinition, can be monitoring.coreos.com or operator.victoriametrics.com
provider: monitoring.coreos.com
interval: ""
@@ -548,7 +555,7 @@ controller:
replicas: 1
image: higress
hub: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress
hub: "" # Will use global.hub if not set
tag: ""
env: {}
@@ -624,6 +631,8 @@ controller:
affinity: {}
topologySpreadConstraints: []
autoscaling:
enabled: false
minReplicas: 1
@@ -642,7 +651,7 @@ pilot:
rollingMaxSurge: 100%
rollingMaxUnavailable: 25%
hub: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress
hub: "" # Will use global.hub if not set
tag: ""
# -- Can be a full hub/image:tag
@@ -795,7 +804,7 @@ pluginServer:
replicas: 2
image: plugin-server
hub: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress
hub: "" # Will use global.hub if not set
tag: ""
imagePullSecrets: []

View File

@@ -1,9 +1,9 @@
dependencies:
- name: higress-core
repository: file://../core
version: 2.1.9
version: 2.1.10
- name: higress-console
repository: https://higress.io/helm-charts/
version: 2.1.9
digest: sha256:d696af6726b40219cc16e7cf8de7400101479dfbd8deb3101d7ee736415b9875
generated: "2025-11-13T16:33:49.721553+08:00"
digest: sha256:fbb896461a8bdc1d5a4f8403253a59497b3b7a13909e9b92a4f3ce3f4f8d999d
generated: "2026-02-03T16:05:30.300315+08:00"

View File

@@ -1,5 +1,5 @@
apiVersion: v2
appVersion: 2.1.9
appVersion: 2.1.10
description: Helm chart for deploying Higress gateways
icon: https://higress.io/img/higress_logo_small.png
home: http://higress.io/
@@ -12,9 +12,9 @@ sources:
dependencies:
- name: higress-core
repository: "file://../core"
version: 2.1.9
version: 2.1.10
- name: higress-console
repository: "https://higress.io/helm-charts/"
version: 2.1.9
type: application
version: 2.1.9
version: 2.1.10

View File

@@ -44,7 +44,7 @@ The command removes all the Kubernetes components associated with the chart and
| controller.autoscaling.minReplicas | int | `1` | |
| controller.autoscaling.targetCPUUtilizationPercentage | int | `80` | |
| controller.env | object | `{}` | |
| controller.hub | string | `"higress-registry.cn-hangzhou.cr.aliyuncs.com/higress"` | |
| controller.hub | string | `""` | |
| controller.image | string | `"higress"` | |
| controller.imagePullSecrets | list | `[]` | |
| controller.labels | object | `{}` | |
@@ -83,6 +83,7 @@ The command removes all the Kubernetes components associated with the chart and
| controller.serviceAccount.name | string | `""` | If not set and create is true, a name is generated using the fullname template |
| controller.tag | string | `""` | |
| controller.tolerations | list | `[]` | |
| controller.topologySpreadConstraints | list | `[]` | |
| downstream | object | `{"connectionBufferLimits":32768,"http2":{"initialConnectionWindowSize":1048576,"initialStreamWindowSize":65535,"maxConcurrentStreams":100},"idleTimeout":180,"maxRequestHeadersKb":60,"routeTimeout":0}` | Downstream config settings |
| gateway.affinity | object | `{}` | |
| gateway.annotations | object | `{}` | Annotations to apply to all resources |
@@ -95,7 +96,7 @@ The command removes all the Kubernetes components associated with the chart and
| gateway.hostNetwork | bool | `false` | |
| gateway.httpPort | int | `80` | |
| gateway.httpsPort | int | `443` | |
| gateway.hub | string | `"higress-registry.cn-hangzhou.cr.aliyuncs.com/higress"` | |
| gateway.hub | string | `""` | |
| gateway.image | string | `"gateway"` | |
| gateway.kind | string | `"Deployment"` | Use a `DaemonSet` or `Deployment` |
| gateway.labels | object | `{}` | Labels to apply to all resources |
@@ -104,6 +105,7 @@ The command removes all the Kubernetes components associated with the chart and
| gateway.metrics.interval | string | `""` | |
| gateway.metrics.metricRelabelConfigs | list | `[]` | for operator.victoriametrics.com/v1beta1.VMPodScrape |
| gateway.metrics.metricRelabelings | list | `[]` | for monitoring.coreos.com/v1.PodMonitor |
| gateway.metrics.podMonitorSelector | object | `{"release":"kube-prome"}` | Selector for PodMonitor When using monitoring.coreos.com/v1.PodMonitor, the selector must match the label "release: kube-prome" is the default for kube-prometheus-stack |
| gateway.metrics.provider | string | `"monitoring.coreos.com"` | provider group name for CustomResourceDefinition, can be monitoring.coreos.com or operator.victoriametrics.com |
| gateway.metrics.rawSpec | object | `{}` | some more raw podMetricsEndpoints spec |
| gateway.metrics.relabelConfigs | list | `[]` | |
@@ -151,6 +153,7 @@ The command removes all the Kubernetes components associated with the chart and
| gateway.serviceAccount.name | string | `""` | The name of the service account to use. If not set, the release name is used |
| gateway.tag | string | `""` | |
| gateway.tolerations | list | `[]` | |
| gateway.topologySpreadConstraints | list | `[]` | |
| gateway.unprivilegedPortSupported | string | `nil` | |
| global.autoscalingv2API | bool | `true` | whether to use autoscaling/v2 template for HPA settings for internal usage only, not to be configured by users. |
| global.caAddress | string | `""` | The customized CA address to retrieve certificates for the pods in the cluster. CSR clients such as the Istio Agent and ingress gateways can use this to specify the CA endpoint. If not set explicitly, default to the Istio discovery address. |
@@ -191,7 +194,7 @@ The command removes all the Kubernetes components associated with the chart and
| global.multiCluster.clusterName | string | `""` | Should be set to the name of the cluster this installation will run in. This is required for sidecar injection to properly label proxies |
| global.multiCluster.enabled | bool | `true` | Set to true to connect two kubernetes clusters via their respective ingressgateway services when pods in each cluster cannot directly talk to one another. All clusters should be using Istio mTLS and must have a shared root CA for this model to work. |
| global.network | string | `""` | Network defines the network this cluster belong to. This name corresponds to the networks in the map of mesh networks. |
| global.o11y | object | `{"enabled":false,"promtail":{"image":{"repository":"higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/promtail","tag":"2.9.4"},"port":3101,"resources":{"limits":{"cpu":"500m","memory":"2Gi"}},"securityContext":{}}}` | Observability (o11y) configurations |
| global.o11y | object | `{"enabled":false,"promtail":{"image":{"repository":"","tag":"2.9.4"},"port":3101,"resources":{"limits":{"cpu":"500m","memory":"2Gi"}},"securityContext":{}}}` | Observability (o11y) configurations |
| global.omitSidecarInjectorConfigMap | bool | `false` | |
| global.onDemandRDS | bool | `false` | |
| global.oneNamespace | bool | `false` | Whether to restrict the applications namespace the controller manages; If not set, controller watches all namespaces |
@@ -243,7 +246,7 @@ The command removes all the Kubernetes components associated with the chart and
| global.watchNamespace | string | `""` | If not empty, Higress Controller will only watch resources in the specified namespace. When isolating different business systems using K8s namespace, if each namespace requires a standalone gateway instance, this parameter can be used to confine the Ingress watching of Higress within the given namespace. |
| global.xdsMaxRecvMsgSize | string | `"104857600"` | |
| gzip | object | `{"chunkSize":4096,"compressionLevel":"BEST_COMPRESSION","compressionStrategy":"DEFAULT_STRATEGY","contentType":["text/html","text/css","text/plain","text/xml","application/json","application/javascript","application/xhtml+xml","image/svg+xml"],"disableOnEtagHeader":true,"enable":true,"memoryLevel":5,"minContentLength":1024,"windowBits":12}` | Gzip compression settings |
| hub | string | `"higress-registry.cn-hangzhou.cr.aliyuncs.com/higress"` | |
| hub | string | `""` | |
| meshConfig | object | `{"enablePrometheusMerge":true,"rootNamespace":null,"trustDomain":"cluster.local"}` | meshConfig defines runtime configuration of components, including Istiod and istio-agent behavior See https://istio.io/docs/reference/config/istio.mesh.v1alpha1/ for all available options |
| meshConfig.rootNamespace | string | `nil` | The namespace to treat as the administrative root namespace for Istio configuration. When processing a leaf namespace Istio will search for declarations in that namespace first and if none are found it will search in the root namespace. Any matching declaration found in the root namespace is processed as if it were declared in the leaf namespace. |
| meshConfig.trustDomain | string | `"cluster.local"` | The trust domain corresponds to the trust root of a system Refer to https://github.com/spiffe/spiffe/blob/master/standards/SPIFFE-ID.md#21-trust-domain |
@@ -260,7 +263,7 @@ The command removes all the Kubernetes components associated with the chart and
| pilot.env.PILOT_ENABLE_METADATA_EXCHANGE | string | `"false"` | |
| pilot.env.PILOT_SCOPE_GATEWAY_TO_NAMESPACE | string | `"false"` | |
| pilot.env.VALIDATION_ENABLED | string | `"false"` | |
| pilot.hub | string | `"higress-registry.cn-hangzhou.cr.aliyuncs.com/higress"` | |
| pilot.hub | string | `""` | |
| pilot.image | string | `"pilot"` | Can be a full hub/image:tag |
| pilot.jwksResolverExtraRootCA | string | `""` | You can use jwksResolverExtraRootCA to provide a root certificate in PEM format. This will then be trusted by pilot when resolving JWKS URIs. |
| pilot.keepaliveMaxServerConnectionAge | string | `"30m"` | The following is used to limit how long a sidecar can be connected to a pilot. It balances out load across pilot instances at the cost of increasing system churn. |
@@ -275,7 +278,7 @@ The command removes all the Kubernetes components associated with the chart and
| pilot.serviceAnnotations | object | `{}` | |
| pilot.tag | string | `""` | |
| pilot.traceSampling | float | `1` | |
| pluginServer.hub | string | `"higress-registry.cn-hangzhou.cr.aliyuncs.com/higress"` | |
| pluginServer.hub | string | `""` | |
| pluginServer.image | string | `"plugin-server"` | |
| pluginServer.imagePullSecrets | list | `[]` | |
| pluginServer.labels | object | `{}` | |

View File

@@ -112,6 +112,7 @@ helm delete higress -n higress-system
| gateway.metrics.rawSpec | object | `{}` | 额外的度量规范 |
| gateway.metrics.relabelConfigs | list | `[]` | 重新标签配置 |
| gateway.metrics.relabelings | list | `[]` | 重新标签项 |
| gateway.metrics.podMonitorSelector | object | `{"release":"kube-prome"}` | PodMonitor 选择器,当使用 prometheus stack 的podmonitor自动发现时选择器必须匹配标签 "release: kube-prome",这是 kube-prometheus-stack 的默认设置 |
| gateway.metrics.scrapeTimeout | string | `""` | 抓取的超时时间 |
| gateway.name | string | `"higress-gateway"` | 网关名称 |
| gateway.networkGateway | string | `""` | 网络网关指定 |

View File

@@ -53,7 +53,6 @@ require (
github.com/cockroachdb/errors v1.9.1 // indirect
github.com/cockroachdb/logtags v0.0.0-20211118104740-dabe8e521a4f // indirect
github.com/cockroachdb/redact v1.1.3 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/deckarep/golang-set v1.7.1 // indirect
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/getsentry/sentry-go v0.12.0 // indirect

View File

@@ -185,10 +185,7 @@ github.com/getsentry/sentry-go v0.12.0/go.mod h1:NSap0JBYWzHND8oMbyi0+XZhUalc1TB
github.com/gin-contrib/sse v0.0.0-20190301062529-5545eab6dad3/go.mod h1:VJ0WA2NBN22VlZ2dKZQPAPnyWw5XTlK1KymzLKsr59s=
github.com/gin-gonic/gin v1.4.0/go.mod h1:OW2EZn3DO8Ln9oIKOvM++LBO+5UPHJJDH72/q/3rZdM=
github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98=
github.com/go-errors/errors v1.0.1 h1:LUHzmkK3GUKUrL/1gfBUxAHzcev3apQlezX/+O7ma6w=
github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q=
github.com/go-faker/faker/v4 v4.1.0 h1:ffuWmpDrducIUOO0QSKSF5Q2dxAht+dhsT9FvVHhPEI=
github.com/go-faker/faker/v4 v4.1.0/go.mod h1:uuNc0PSRxF8nMgjGrrrU4Nw5cF30Jc6Kd0/FUTTYbhg=
github.com/go-faster/city v1.0.1 h1:4WAxSZ3V2Ws4QRDrscLEDcibJY8uf41H6AhXDrNDcGw=
github.com/go-faster/city v1.0.1/go.mod h1:jKcUJId49qdW3L1qKHH/3wPeUstCVpVSXTM6vO3VcTw=
github.com/go-faster/errors v0.7.1 h1:MkJTnDoEdi9pDabt1dpWf7AA8/BaSYZqibYyhZ20AYg=
@@ -429,7 +426,6 @@ github.com/paulmach/protoscan v0.2.1/go.mod h1:SpcSwydNLrxUGSDvXvO0P7g7AuhJ7lcKf
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ=
github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=

View File

@@ -8,6 +8,7 @@ import (
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/higress/higress-api"
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/higress/higress-ops"
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag"
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/tool-search"
mcp_session "github.com/alibaba/higress/plugins/golang-filter/mcp-session"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
xds "github.com/cncf/xds/go/xds/type/v3"

View File

@@ -0,0 +1,144 @@
# Tool Search MCP Server
这是一个基于 Higress Golang Filter 实现的 MCP Server,用于提供工具语义搜索功能。当前实现**仅支持向量语义搜索**(基于 Milvus 向量数据库),**不包含全文检索或混合搜索**。
## 功能特性
- **向量语义搜索**:使用 OpenAI 兼容的 Embedding API 将用户查询转换为向量,并在 Milvus 中进行相似度检索
- **工具元数据支持**:从数据库中读取完整的工具定义(JSON 格式),并动态拼接工具名称
- **全量工具列表**:支持获取数据库中所有可用工具
- **可配置 Embedding 模型**:支持自定义模型、维度及 API 端点(如 DashScope)
- **Milvus 集成**:通过标准 gRPC 接口连接 Milvus 向量数据库
## 数据库要求Milvus
本服务依赖 **Milvus 向量数据库**,需预先创建集合(Collection),其 Schema 应包含以下字段:
| 字段名 | 类型 | 说明 |
|--------------|-------------------|-------------------------|
| `id` | VarChar(64) | 文档唯一 ID |
| `content` | VarChar(64) | 工具描述文本 |
| `metadata` | JSON | 完整的工具定义(必须包含 `name` 字段) |
| `vector` | FloatVector(1024) | embedding 向量 |
| `metadata` | Int64 | 创建时间 |
## 配置参数
### 根级配置
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|--------------|--------|------|-----------------------------------------------------|------|
| `vector` | object | 是 | - | 向量数据库配置(见下文) |
| `embedding` | object | 是 | - | Embedding API 配置(见下文) |
| `description`| string | 否 | `"Tool search server for semantic similarity search"` | MCP Server 描述信息 |
### Vector 配置(`vector` 对象)
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|-------------|--------|------|--------------------|------|
| `type` | string | 是 | - | **必须为 `"milvus"`** |
| `host` | string | 是 | - | Milvus 服务地址(如 `localhost`) |
| `port` | int | 是 | - | Milvus gRPC 端口(如 `19530`) |
| `database` | string | 否 | `"default"` | Milvus 数据库名 |
| `tableName` | string | 否 | `"apig_mcp_tools"` | Milvus 集合名 |
| `username` | string | 否 | - | 认证用户名(可选) |
| `password` | string | 否 | - | 认证密码(可选) |
### Embedding 配置(`embedding` 对象)
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|--------------|--------|------|-----------------------------------------------------------|------|
| `apiKey` | string | 是 | - | Embedding 服务的 API Key |
| `baseURL` | string | 否 | `https://dashscope.aliyuncs.com/compatible-mode/v1` | OpenAI 兼容 API 的 Base URL |
| `model` | string | 否 | `text-embedding-v4` | 使用的 Embedding 模型 |
| `dimensions` | int | 否 | `1024` | 向量维度 |
## 配置示例
Tool Search MCP Server 也可以作为 Higress 的一个模块进行配置。以下是一个在 Higress ConfigMap 中配置 Tool Search 的示例:
```yaml
apiVersion: v1
kind: ConfigMap
metadata:
name: higress-config
namespace: higress-system
data:
higress: |
mcpServer:
enable: true
sse_path_suffix: "/sse"
redis:
address: "<Redis IP>:6379"
username: ""
password: ""
db: 0
match_list:
- path_rewrite_prefix: ""
upstream_type: ""
enable_path_rewrite: false
match_rule_domain: "*"
match_rule_path: "/mcp-servers/tool-search"
match_rule_type: "prefix"
servers:
- path: "/mcp-servers/tool-search"
name: "tool-search"
type: "tool-search"
config:
vector:
type: "milvus"
host: "localhost"
port: 19530
database: "default"
tableName: "apig_mcp_tools"
username: "root"
password: "Milvus"
maxTools: 1000
embedding:
apiKey: "your-dashscope-api-key"
baseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1"
model: "text-embedding-v4"
dimensions: 1024
description: "Higress 工具语义搜索服务"
```
## 工具搜索接口
Tool Search MCP Server 提供以下 MCP 工具:
### x_higress_tool_search
基于语义相似度搜索最相关的工具。
**输入参数**:
| 参数名 | 类型 | 必填 | 说明 |
|---------|--------|------|------|
| `query` | string | 是 | 查询语句,用于与工具描述进行语义相似度比较 |
| `topK` | int | 否 | 指定需要选择的工具数量默认选择前10个工具 |
**输出格式**:
```
{
"tools": [
{
"name": "server_name___tool_name",
"title": "Tool Title",
"description": "Tool description",
"inputSchema": {...},
"outputSchema": {...}
}
]
}
```
## 搜索实现
通过向量相似度进行搜索,索引配置如下:
- 使用 HNSW 索引算法进行向量索引
- 默认参数:M=8, efConstruction=64
- 相似度度量方式:内积(IP)

View File

@@ -0,0 +1,18 @@
{
"vector": {
"type": "milvus",
"vectorWeight": 0.5,
"tableName": "apig_mcp_tools",
"host": "localhost",
"port": 19530,
"database": "default",
"username": "root",
"password": "Milvus"
},
"embedding": {
"apiKey": "your-dashscope-api-key",
"baseURL": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"model": "text-embedding-v4",
"dimensions": 1024
}
}

View File

@@ -0,0 +1,79 @@
package tool_search
import (
"context"
"fmt"
"time"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
"github.com/openai/openai-go/v2"
"github.com/openai/openai-go/v2/option"
)
// EmbeddingClient handles vector embedding generation using OpenAI-compatible APIs.
type EmbeddingClient struct {
	client     *openai.Client // OpenAI-compatible SDK client (30s request timeout set at construction)
	model      string         // embedding model name, e.g. "text-embedding-v4"
	dimensions int            // requested output vector dimension
}
// NewEmbeddingClient builds an EmbeddingClient targeting an OpenAI-compatible
// embedding endpoint. Every request made through the returned client carries
// a 30-second timeout.
func NewEmbeddingClient(apiKey, baseURL, model string, dimensions int) *EmbeddingClient {
	api.LogInfof("Creating EmbeddingClient with baseURL: %s, model: %s, dimensions: %d", baseURL, model, dimensions)
	// Assemble the SDK options up front, then construct the client once.
	opts := []option.RequestOption{
		option.WithAPIKey(apiKey),
		option.WithBaseURL(baseURL),
		option.WithRequestTimeout(30 * time.Second),
	}
	client := openai.NewClient(opts...)
	return &EmbeddingClient{
		client:     &client,
		model:      model,
		dimensions: dimensions,
	}
}
// GetEmbedding generates vector embedding for the given text.
//
// It calls the OpenAI-compatible /embeddings endpoint with the configured
// model and requested dimension (float encoding), then converts the SDK's
// []float64 payload into []float32. Returns an error when the API call
// fails or the response carries no embedding data.
func (e *EmbeddingClient) GetEmbedding(ctx context.Context, text string) ([]float32, error) {
	api.LogInfof("Generating embedding for text (length: %d)", len(text))
	api.LogDebugf("Using model: %s, dimensions: %d", e.model, e.dimensions)
	// Add timeout to context if not already present
	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()
	params := openai.EmbeddingNewParams{
		Model: e.model,
		Input: openai.EmbeddingNewParamsInputUnion{
			OfString: openai.String(text),
		},
		Dimensions:     openai.Int(int64(e.dimensions)),
		EncodingFormat: openai.EmbeddingNewParamsEncodingFormatFloat,
	}
	api.LogDebugf("Calling OpenAI-compatible API for embedding generation")
	embeddingResp, err := e.client.Embeddings.New(ctx, params)
	if err != nil {
		api.LogErrorf("OpenAI-compatible API call failed: %v", err)
		return nil, fmt.Errorf("failed to generate embedding: %w", err)
	}
	if len(embeddingResp.Data) == 0 {
		api.LogErrorf("Empty embedding response from API")
		return nil, fmt.Errorf("empty embedding response")
	}
	api.LogDebugf("Successfully received embedding from API")
	api.LogDebugf("Response data length: %d, embedding dimension: %d", len(embeddingResp.Data), len(embeddingResp.Data[0].Embedding))
	// Convert []float64 to []float32
	// Only Data[0] is used: a single input string yields a single embedding entry.
	embedding := make([]float32, len(embeddingResp.Data[0].Embedding))
	for i, v := range embeddingResp.Data[0].Embedding {
		embedding[i] = float32(v)
	}
	api.LogInfof("Embedding conversion completed, final dimension: %d", len(embedding))
	return embedding, nil
}

View File

@@ -0,0 +1,204 @@
package tool_search
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/schema"
"github.com/milvus-io/milvus-sdk-go/v2/client"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
)
// MilvusVectorStoreProvider is a vector-store backend over a single Milvus
// collection, used to list and similarity-search tool documents.
type MilvusVectorStoreProvider struct {
	client     client.Client // connected Milvus SDK client
	collection string        // collection name holding the tool documents
	dimensions int           // embedding vector dimension for this collection
}
// NewMilvusVectorStoreProvider dials Milvus at cfg.Host:cfg.Port (database
// cfg.Database) and returns a provider bound to cfg.Collection. Credentials
// are applied only when both username and password are non-empty.
func NewMilvusVectorStoreProvider(cfg *config.VectorDBConfig, dimensions int) (*MilvusVectorStoreProvider, error) {
	param := client.Config{
		Address: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
		DBName:  cfg.Database,
	}
	if cfg.Username != "" && cfg.Password != "" {
		param.Username = cfg.Username
		param.Password = cfg.Password
	}
	milvusClient, err := client.NewClient(context.Background(), param)
	if err != nil {
		return nil, fmt.Errorf("failed to create milvus client: %w", err)
	}
	provider := &MilvusVectorStoreProvider{
		client:     milvusClient,
		collection: cfg.Collection,
		dimensions: dimensions,
	}
	return provider, nil
}
// ListAllDocs retrieves documents from the collection, returning at most
// `limit` entries when limit > 0, or an unbounded query otherwise.
//
// Milvus returns column-oriented results; each row is reassembled into a
// schema.Document. A metadata blob that fails to JSON-unmarshal is silently
// left nil (best-effort decode).
func (c *MilvusVectorStoreProvider) ListAllDocs(ctx context.Context, limit int) ([]schema.Document, error) {
	expr := "" // empty filter expression: match every entity
	outputFields := []string{"id", "content", "metadata", "created_at"}
	var queryResult []entity.Column
	var err error
	if limit > 0 {
		queryResult, err = c.client.Query(
			ctx,
			c.collection,
			[]string{}, // partitions
			expr,       // filter condition
			outputFields,
			client.WithLimit(int64(limit)),
		)
	} else {
		queryResult, err = c.client.Query(
			ctx,
			c.collection,
			[]string{}, // partitions
			expr,       // filter condition
			outputFields,
		)
	}
	if err != nil {
		return nil, fmt.Errorf("failed to query all documents: %w", err)
	}
	if len(queryResult) == 0 {
		return []schema.Document{}, nil
	}
	// All returned columns share the same row count.
	rowCount := queryResult[0].Len()
	documents := make([]schema.Document, 0, rowCount)
	for i := 0; i < rowCount; i++ {
		var (
			id        string
			content   string
			metadata  map[string]interface{}
			createdAt int64
		)
		// Extract the i-th value from each column, dispatching on column name.
		for _, col := range queryResult {
			switch col.Name() {
			case "id":
				if v, err := col.(*entity.ColumnVarChar).Get(i); err == nil {
					id = v.(string)
				}
			case "content":
				if v, err := col.(*entity.ColumnVarChar).Get(i); err == nil {
					content = v.(string)
				}
			case "metadata":
				if v, err := col.(*entity.ColumnJSONBytes).Get(i); err == nil {
					if bytes, ok := v.([]byte); ok {
						// Best-effort decode; metadata stays nil on failure.
						_ = json.Unmarshal(bytes, &metadata)
					}
				}
			case "created_at":
				if v, err := col.(*entity.ColumnInt64).Get(i); err == nil {
					createdAt = v.(int64)
				}
			}
		}
		// NOTE(review): created_at is interpreted as milliseconds since the
		// Unix epoch — confirm against whatever writes the collection.
		doc := schema.Document{
			ID:        id,
			Content:   content,
			Metadata:  metadata,
			CreatedAt: time.UnixMilli(createdAt),
		}
		documents = append(documents, doc)
	}
	return documents, nil
}
// SearchDocs performs a vector similarity search against the collection's
// "vector" field using the inner-product (IP) metric and returns up to
// options.TopK scored documents. When options is nil, TopK defaults to 10.
func (c *MilvusVectorStoreProvider) SearchDocs(ctx context.Context, vector []float32, options *schema.SearchOptions) ([]schema.SearchResult, error) {
	if options == nil {
		options = &schema.SearchOptions{TopK: 10}
	}
	// Default HNSW search parameter (ef=16).
	sp, err := entity.NewIndexHNSWSearchParam(16)
	if err != nil {
		return nil, fmt.Errorf("failed to build search param: %w", err)
	}
	outputFields := []string{"id", "content", "metadata"}
	searchResults, err := c.client.Search(
		ctx,
		c.collection,
		[]string{}, // partition names
		"",         // filter expression
		outputFields,
		[]entity.Vector{entity.FloatVector(vector)},
		"vector",  // anns_field
		entity.IP, // metric_type
		options.TopK,
		sp,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to search documents: %w", err)
	}
	var results []schema.SearchResult
	for _, result := range searchResults {
		for i := 0; i < result.ResultCount; i++ {
			id, _ := result.IDs.Get(i)
			score := result.Scores[i]
			var content string
			var metadata map[string]interface{}
			for _, field := range result.Fields {
				switch field.Name() {
				case "content":
					if contentCol, ok := field.(*entity.ColumnVarChar); ok {
						if contentVal, err := contentCol.Get(i); err == nil {
							if contentStr, ok := contentVal.(string); ok {
								content = contentStr
							}
						}
					}
				case "metadata":
					if metaCol, ok := field.(*entity.ColumnJSONBytes); ok {
						if metaVal, err := metaCol.Get(i); err == nil {
							if metaBytes, ok := metaVal.([]byte); ok {
								// Fall back to an empty map on a corrupt blob so
								// downstream code can index it safely.
								if err := json.Unmarshal(metaBytes, &metadata); err != nil {
									metadata = make(map[string]interface{})
								}
							}
						}
					}
				}
			}
			// BUGFIX: format the primary key with %v, not %s. IDs.Get returns
			// an interface{}; with a non-string key type (e.g. int64) %s would
			// render "%!s(int64=…)" instead of the value.
			searchResult := schema.SearchResult{
				Document: schema.Document{
					ID:       fmt.Sprintf("%v", id),
					Content:  content,
					Metadata: metadata,
				},
				Score: float64(score),
			}
			results = append(results, searchResult)
		}
	}
	return results, nil
}
// Close releases the underlying Milvus connection. Calling Close on a
// provider whose client was never initialized is a no-op.
func (c *MilvusVectorStoreProvider) Close() error {
	if c.client == nil {
		return nil
	}
	return c.client.Close()
}

View File

@@ -0,0 +1,237 @@
package tool_search
import (
"context"
"fmt"
"time"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/schema"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
)
// SearchService handles tool search operations: it embeds queries through
// the EmbeddingClient and runs similarity search over a Milvus collection.
type SearchService struct {
	milvusProvider  *MilvusVectorStoreProvider // vector store backend
	config          *config.VectorDBConfig     // connection settings the provider was built from
	tableName       string                     // Milvus collection holding tool documents
	dimensions      int                        // embedding vector dimension
	maxTools        int                        // hard-coded maximum number of tools; only used for unit tests
	embeddingClient *EmbeddingClient           // OpenAI-compatible embedding client
}
// NewSearchService creates a new SearchService backed by the given Milvus
// endpoint/collection and embedding client.
//
// NOTE(review): on Milvus connection failure this logs the error and
// returns nil rather than an error; the caller in NewServer does not
// nil-check the result, so a bad address would surface later as a
// nil-pointer panic — consider returning (*SearchService, error).
func NewSearchService(host string, port int, database, username, password, tableName string, embeddingClient *EmbeddingClient, dimensions int, maxTools int) *SearchService {
	// Create Milvus configuration
	cfg := &config.VectorDBConfig{
		Provider:   "milvus",
		Host:       host,
		Port:       port,
		Database:   database,
		Collection: tableName,
		Username:   username,
		Password:   password,
	}
	// Create Milvus provider
	provider, err := NewMilvusVectorStoreProvider(cfg, dimensions)
	if err != nil {
		api.LogErrorf("Failed to create Milvus provider: %v", err)
		return nil
	}
	return &SearchService{
		milvusProvider:  provider,
		config:          cfg,
		tableName:       tableName,
		dimensions:      dimensions,
		maxTools:        maxTools, // hard-coded value supplied by the caller
		embeddingClient: embeddingClient,
	}
}
// ToolSearchResult represents the result of a tool search.
type ToolSearchResult struct {
	Tools []ToolDefinition `json:"tools"` // matched tool definitions, best match first
}

// ToolDefinition represents a tool definition in the search result. It is a
// free-form JSON object — typically the tool's stored metadata document.
type ToolDefinition map[string]interface{}
// SearchTools performs semantic search for tools: the query text is first
// embedded, then the resulting vector is matched against the stored tool
// documents, returning up to topK results.
func (s *SearchService) SearchTools(ctx context.Context, query string, topK int) (*ToolSearchResult, error) {
	api.LogInfof("Starting tool search for query: '%s', topK: %d", query, topK)
	// Generate vector embedding for the query
	vector, err := s.embeddingClient.GetEmbedding(ctx, query)
	if err != nil {
		api.LogErrorf("Failed to generate embedding for query '%s': %v", query, err)
		return nil, fmt.Errorf("failed to generate embedding: %w", err)
	}
	api.LogInfof("Embedding generated successfully, vector dimension: %d", len(vector))
	// Perform vector search
	records, err := s.searchToolsInDB(query, vector, topK)
	if err != nil {
		api.LogErrorf("Failed to search tools: %v", err)
		return nil, fmt.Errorf("failed to search tools: %w", err)
	}
	api.LogInfof("Vector search completed, found %d records", len(records))
	return s.convertRecordsToResult(records), nil
}
// convertRecordsToResult converts database records to a ToolSearchResult.
//
// When a record carries metadata, that map is used as the complete tool
// definition; otherwise a minimal definition (name + description) is
// synthesized. In both cases the "name" entry is overwritten from the
// record so the definition and the record can never disagree.
func (s *SearchService) convertRecordsToResult(records []ToolRecord) *ToolSearchResult {
	api.LogInfof("Converting %d records to tool definitions", len(records))
	tools := make([]ToolDefinition, 0, len(records))
	for i, record := range records {
		var tool ToolDefinition
		// Use metadata if available
		if len(record.Metadata) > 0 {
			tool = record.Metadata
			api.LogDebugf("Successfully parsed metadata for tool %s", record.Name)
		} else {
			api.LogDebugf("No metadata found for tool %s, using basic definition", record.Name)
			// If no metadata, create a basic tool definition
			tool = ToolDefinition{
				"name":        record.Name,
				"description": record.Content,
			}
		}
		// Assign the record's name directly (the previous
		// fmt.Sprintf("%s", record.Name) was a redundant copy of a string).
		tool["name"] = record.Name
		tools = append(tools, tool)
		api.LogDebugf("Tool %d: %s - %s", i+1, tool["name"], record.Content)
	}
	api.LogInfof("Successfully converted %d tools", len(tools))
	return &ToolSearchResult{Tools: tools}
}
// GetAllTools retrieves all available tools from the vector store, bounded
// by the service's maxTools limit.
//
// The record-to-definition conversion is delegated to
// convertRecordsToResult, removing the duplicated conversion loop this
// method previously carried inline.
func (s *SearchService) GetAllTools() (*ToolSearchResult, error) {
	api.LogInfo("Retrieving all tools")
	records, err := s.getAllToolsFromDB()
	if err != nil {
		api.LogErrorf("Failed to get all tools: %v", err)
		return nil, fmt.Errorf("failed to get all tools: %w", err)
	}
	api.LogInfof("Found %d tools in database", len(records))
	return s.convertRecordsToResult(records), nil
}
// ToolRecord represents a tool record in the database.
type ToolRecord struct {
	ID       string                 `json:"id"`       // document primary key
	Name     string                 `json:"name"`     // tool name, taken from Metadata["name"] when present
	Content  string                 `json:"content"`  // tool description text used for embedding
	Metadata map[string]interface{} `json:"metadata"` // full tool definition as stored
}
// searchToolsInDB runs the vector similarity search against Milvus and maps
// the raw results into ToolRecords. The query string is used for logging
// only — matching is done purely on the embedding vector. A 10-second
// timeout bounds the search.
func (s *SearchService) searchToolsInDB(query string, vector []float32, topK int) ([]ToolRecord, error) {
	api.LogInfof("Performing vector search for query: '%s', topK: %d", query, topK)
	// For Milvus, we'll perform vector search directly
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	// Perform vector search
	searchOptions := &schema.SearchOptions{
		TopK: topK,
	}
	results, err := s.milvusProvider.SearchDocs(ctx, vector, searchOptions)
	if err != nil {
		api.LogErrorf("Vector search failed: %v", err)
		return nil, fmt.Errorf("failed to perform vector search: %w", err)
	}
	// Convert results to ToolRecords
	var records []ToolRecord
	for _, result := range results {
		doc := result.Document
		tool := ToolRecord{
			ID:       doc.ID,
			Content:  doc.Content,
			Metadata: doc.Metadata,
		}
		// The tool name lives inside the stored metadata when present.
		if name, ok := doc.Metadata["name"].(string); ok {
			tool.Name = name
		}
		records = append(records, tool)
	}
	api.LogInfof("Vector search completed, found %d results", len(records))
	return records, nil
}
// getAllToolsFromDB retrieves all tools from the database, capped at
// s.maxTools entries, with a 10-second timeout on the listing query.
func (s *SearchService) getAllToolsFromDB() ([]ToolRecord, error) {
	api.LogInfof("Executing GetAllTools query from collection: %s", s.tableName)
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	// Retrieve all documents with limit
	docs, err := s.milvusProvider.ListAllDocs(ctx, s.maxTools)
	if err != nil {
		api.LogErrorf("Failed to list documents: %v", err)
		return nil, fmt.Errorf("failed to list documents: %w", err)
	}
	// Convert documents to ToolRecords
	var tools []ToolRecord
	for _, doc := range docs {
		tool := ToolRecord{
			ID:       doc.ID,
			Content:  doc.Content,
			Metadata: doc.Metadata,
		}
		// The tool name lives inside the stored metadata when present.
		if name, ok := doc.Metadata["name"].(string); ok {
			tool.Name = name
		}
		tools = append(tools, tool)
	}
	api.LogInfof("GetAllTools query completed, found %d tools", len(tools))
	return tools, nil
}

View File

@@ -0,0 +1,196 @@
package tool_search
import (
"errors"
"fmt"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
"github.com/mark3labs/mcp-go/mcp"
)
const (
	// Version is the MCP server version reported to clients.
	Version = "1.0.0"
	// Default configuration values.
	defaultTableName  = "apig_mcp_tools"
	defaultBaseURL    = "https://dashscope.aliyuncs.com/compatible-mode/v1"
	defaultModel      = "text-embedding-v4"
	defaultDimensions = 1024
	// Hard-coded maximum number of tools (1000); only used for unit tests.
	fixedMaxTools = 1000
)
// init registers the "tool-search" server type with the global MCP server
// registry so it can be instantiated from gateway configuration.
func init() {
	common.GlobalRegistry.RegisterServer("tool-search", &ToolSearchConfig{})
}
// VectorConfig holds the Milvus connection settings parsed from the
// "vector" configuration object.
type VectorConfig struct {
	Type         string  `json:"type"`         // vector store type; only "milvus" is supported
	VectorWeight float64 `json:"vectorWeight"` // NOTE(review): never parsed or read in this package — presumably reserved for hybrid scoring; confirm before documenting it as functional
	TableName    string  `json:"tableName"`    // Milvus collection name (default: apig_mcp_tools)
	Host         string  `json:"host"`         // Milvus host (required)
	Port         int     `json:"port"`         // Milvus port (required)
	Database     string  `json:"database"`     // Milvus database (default: "default")
	Username     string  `json:"username"`     // optional credential
	Password     string  `json:"password"`     // optional credential
}

// EmbeddingConfig holds the OpenAI-compatible embedding service settings
// parsed from the "embedding" configuration object.
type EmbeddingConfig struct {
	APIKey     string `json:"apiKey"`     // required API key
	BaseURL    string `json:"baseURL"`    // defaults to the DashScope compatible-mode endpoint
	Model      string `json:"model"`      // defaults to text-embedding-v4
	Dimensions int    `json:"dimensions"` // defaults to 1024
}

// ToolSearchConfig is the parsed configuration for a tool-search MCP server.
type ToolSearchConfig struct {
	Vector    VectorConfig    `json:"vector"`
	Embedding EmbeddingConfig `json:"embedding"`
	// description is the server instructions string; a generic default is
	// used when absent from the configuration.
	description string
}
// ParseConfig populates the ToolSearchConfig from the raw configuration
// map. Both the "vector" and "embedding" objects are required; the
// "description" string is optional and falls back to a generic default.
func (c *ToolSearchConfig) ParseConfig(config map[string]any) error {
	// Parse vector configuration
	vectorConfig, ok := config["vector"].(map[string]any)
	if !ok {
		return errors.New("missing vector configuration")
	}
	if err := c.parseVectorConfig(vectorConfig); err != nil {
		return fmt.Errorf("failed to parse vector config: %w", err)
	}
	// Parse embedding configuration
	embeddingConfig, ok := config["embedding"].(map[string]any)
	if !ok {
		return errors.New("missing embedding configuration")
	}
	if err := c.parseEmbeddingConfig(embeddingConfig); err != nil {
		return fmt.Errorf("failed to parse embedding config: %w", err)
	}
	// Optional description
	if description, ok := config["description"].(string); ok {
		c.description = description
	} else {
		c.description = "Tool search server for semantic similarity search"
	}
	api.LogDebugf("ToolSearchConfig ParseConfig: %+v", config)
	return nil
}
// parseVectorConfig fills c.Vector from the "vector" configuration object.
// type, host, and port are mandatory; database and tableName fall back to
// defaults; username and password stay empty when absent. Only the
// "milvus" store type is accepted.
func (c *ToolSearchConfig) parseVectorConfig(config map[string]any) error {
	vectorType, ok := config["type"].(string)
	if !ok {
		return errors.New("missing vector.type")
	}
	c.Vector.Type = vectorType
	if c.Vector.Type != "milvus" {
		return fmt.Errorf("unsupported vector.type: %s, only 'milvus' is supported", c.Vector.Type)
	}
	host, ok := config["host"].(string)
	if !ok {
		return errors.New("missing vector.host")
	}
	c.Vector.Host = host
	// JSON decoding yields float64 for numbers, but an int may arrive when
	// the config map was built programmatically; accept both.
	switch port := config["port"].(type) {
	case float64:
		c.Vector.Port = int(port)
	case int:
		c.Vector.Port = port
	default:
		return errors.New("missing vector.port")
	}
	// Set defaults first, then override from the config when present.
	c.Vector.Database = "default"
	if database, ok := config["database"].(string); ok {
		c.Vector.Database = database
	}
	c.Vector.TableName = defaultTableName
	if tableName, ok := config["tableName"].(string); ok {
		c.Vector.TableName = tableName
	}
	if username, ok := config["username"].(string); ok {
		c.Vector.Username = username
	}
	if password, ok := config["password"].(string); ok {
		c.Vector.Password = password
	}
	// maxTools is intentionally not parsed here; a fixed value is used instead.
	return nil
}
// parseEmbeddingConfig fills c.Embedding from the "embedding" configuration
// object. apiKey is mandatory; baseURL, model, and dimensions fall back to
// the package defaults when absent.
func (c *ToolSearchConfig) parseEmbeddingConfig(config map[string]any) error {
	apiKey, ok := config["apiKey"].(string)
	if !ok {
		return errors.New("missing embedding.apiKey")
	}
	c.Embedding.APIKey = apiKey
	// Set defaults first, then override from the config when present.
	c.Embedding.BaseURL = defaultBaseURL
	if baseURL, ok := config["baseURL"].(string); ok {
		c.Embedding.BaseURL = baseURL
	}
	c.Embedding.Model = defaultModel
	if model, ok := config["model"].(string); ok {
		c.Embedding.Model = model
	}
	// JSON decoding yields float64 for numbers; ints may appear in
	// programmatically built maps.
	switch dims := config["dimensions"].(type) {
	case float64:
		c.Embedding.Dimensions = int(dims)
	case int:
		c.Embedding.Dimensions = dims
	default:
		c.Embedding.Dimensions = defaultDimensions
	}
	return nil
}
// NewServer builds an MCP server named serverName that exposes the
// x_higress_tool_search tool, wired to a SearchService created from this
// configuration.
func (c *ToolSearchConfig) NewServer(serverName string) (*common.MCPServer, error) {
	mcpServer := common.NewMCPServer(
		serverName,
		Version,
		common.WithInstructions(c.description),
	)
	// Create embedding client
	embeddingClient := NewEmbeddingClient(c.Embedding.APIKey, c.Embedding.BaseURL, c.Embedding.Model, c.Embedding.Dimensions)
	// Create search service, using the hard-coded fixedMaxTools value
	// NOTE(review): NewSearchService returns nil on Milvus connection
	// failure and that nil is not checked here — confirm this is intended.
	searchService := NewSearchService(
		c.Vector.Host,
		c.Vector.Port,
		c.Vector.Database,
		c.Vector.Username,
		c.Vector.Password,
		c.Vector.TableName,
		embeddingClient,
		c.Embedding.Dimensions,
		fixedMaxTools, // hard-coded maximum tool count
	)
	// Add tool search tool
	mcpServer.AddTool(
		mcp.NewToolWithRawSchema("x_higress_tool_search", "Higress MCP Tools Searcher", GetToolSearchSchema()),
		HandleToolSearch(searchService),
	)
	return mcpServer, nil
}

View File

@@ -0,0 +1,198 @@
package tool_search
import (
"context"
"encoding/json"
"fmt"
"os"
"testing"
"time"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
"github.com/mark3labs/mcp-go/mcp"
)
// Mock implementation of CommonCAPI for testing: it prints log messages to
// stdout and records them for later inspection.
type mockCommonCAPI struct {
	logs []string // every message passed to Log, in call order
}

// Log prints the message and appends it to the recorded history.
func (m *mockCommonCAPI) Log(level api.LogType, message string) {
	fmt.Printf("[%s] %s\n", level, message)
	m.logs = append(m.logs, message)
}

// LogLevel reports the most verbose level so every message is emitted.
func (m *mockCommonCAPI) LogLevel() api.LogType {
	return api.Debug
}
// TestServer is used for local functional testing.
//
// NOTE(review): this is an integration test — it needs a reachable Milvus
// instance and a valid embedding API key (configurable via TEST_* env
// vars), and it mostly logs failures with t.Logf instead of failing, so a
// green run does not by itself prove the search path works.
func TestServer(t *testing.T) {
	// Setup mock API for logging
	mockAPI := &mockCommonCAPI{}
	api.SetCommonCAPI(mockAPI)
	// Load configuration from environment variables or use defaults
	config := map[string]any{
		"vector": map[string]any{
			"type":         "milvus",
			"vectorWeight": 0.6,
			"tableName":    getEnvOrDefault("TEST_TABLE_NAME", "apig_mcp_tools"),
			"host":         getEnvOrDefault("TEST_MILVUS_HOST", "localhost"),
			"port":         getEnvOrDefaultInt("TEST_MILVUS_PORT", 19530),
			"database":     getEnvOrDefault("TEST_MILVUS_DATABASE", "default"),
			"username":     getEnvOrDefault("TEST_MILVUS_USERNAME", "root"),
			"password":     getEnvOrDefault("TEST_MILVUS_PASSWORD", "Milvus"),
			"maxTools":     getEnvOrDefaultInt("TEST_MAX_TOOLS", 1000),
		},
		"embedding": map[string]any{
			"apiKey":     getEnvOrDefault("TEST_API_KEY", "your-dashscope-api-key"),
			"baseURL":    getEnvOrDefault("TEST_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
			"model":      getEnvOrDefault("TEST_MODEL", "text-embedding-v4"),
			"dimensions": 1024,
		},
		"description": "Test MCP Tools Search Server",
	}
	// Create configuration instance
	toolSearchConfig := &ToolSearchConfig{}
	if err := toolSearchConfig.ParseConfig(config); err != nil {
		t.Fatalf("Failed to parse config: %v", err)
	}
	// Create MCP Server
	_, err := toolSearchConfig.NewServer("test-tool-search")
	if err != nil {
		t.Fatalf("Failed to create server: %v", err)
	}
	// Test database connection
	vectorConfig := config["vector"].(map[string]any)
	embeddingConfig := config["embedding"].(map[string]any)
	// Test GetAllTools
	t.Logf("\n=== Testing GetAllTools ===")
	embeddingClient := NewEmbeddingClient(
		embeddingConfig["apiKey"].(string),
		embeddingConfig["baseURL"].(string),
		embeddingConfig["model"].(string),
		embeddingConfig["dimensions"].(int),
	)
	searchService := NewSearchService(
		vectorConfig["host"].(string),
		vectorConfig["port"].(int),
		vectorConfig["database"].(string),
		vectorConfig["username"].(string),
		vectorConfig["password"].(string),
		vectorConfig["tableName"].(string),
		embeddingClient,
		embeddingConfig["dimensions"].(int),
		getEnvOrDefaultInt("TEST_MAX_TOOLS", 1000),
	)
	allTools, err := searchService.GetAllTools()
	if err != nil {
		t.Logf("GetAllTools failed: %v", err)
	} else {
		t.Logf("Found %d tools:", len(allTools.Tools))
		for i, tool := range allTools.Tools {
			if i < 3 { // Show only first 3 tools
				toolJSON, _ := json.MarshalIndent(tool, "", " ")
				t.Logf("Tool %d: %s", i+1, string(toolJSON))
			}
		}
		if len(allTools.Tools) > 3 {
			t.Logf("... and %d more tools", len(allTools.Tools)-3)
		}
	}
	// Test tool search with timing
	t.Logf("\n=== Testing Tool Search ===")
	testQueries := []string{
		"weather data",
		"database query",
		"file operations",
		"HTTP requests",
		"library documents",
	}
	for _, query := range testQueries {
		t.Logf("\n--- Testing query: '%s' ---", query)
		// Create MCP tool call request
		request := mcp.CallToolRequest{
			Params: struct {
				Name      string                 `json:"name"`
				Arguments map[string]interface{} `json:"arguments,omitempty"`
				Meta      *struct {
					ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"`
				} `json:"_meta,omitempty"`
			}{
				Name: "x_higress_tool_search",
				Arguments: map[string]interface{}{
					"query": query,
					"topK":  3,
				},
			},
		}
		// Get tool handler
		handler := HandleToolSearch(searchService)
		// Execute search with timing
		start := time.Now()
		result, err := handler(context.Background(), request)
		duration := time.Since(start)
		if err != nil {
			t.Logf("Search failed: %v", err)
			continue
		}
		// Print results with timing information
		t.Logf("Search completed in %v", duration)
		if len(result.Content) > 0 {
			if textContent, ok := result.Content[0].(mcp.TextContent); ok {
				var toolsResult map[string]interface{}
				if err := json.Unmarshal([]byte(textContent.Text), &toolsResult); err == nil {
					toolsJSON, _ := json.MarshalIndent(toolsResult, "", " ")
					t.Logf("Tools Result: %s", string(toolsJSON))
				} else {
					t.Logf("Text Content: %s", textContent.Text)
				}
			}
		}
	}
	// Test configuration validation
	t.Logf("\n=== Configuration Validation ===")
	t.Logf("Host: %s", vectorConfig["host"])
	t.Logf("Port: %d", vectorConfig["port"])
	t.Logf("Database: %s", vectorConfig["database"])
	t.Logf("Table Name: %s", vectorConfig["tableName"])
	t.Logf("Vector Weight: %f", vectorConfig["vectorWeight"])
	t.Logf("Text Weight: %f", 1.0-vectorConfig["vectorWeight"].(float64))
	t.Logf("Model: %s", embeddingConfig["model"])
	t.Logf("Dimensions: %d", embeddingConfig["dimensions"])
	t.Logf("API Base URL: %s", embeddingConfig["baseURL"])
	t.Logf("\n=== Test completed ===")
}
// Helper function to get environment variable or default value
func getEnvOrDefault(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
func getEnvOrDefaultInt(key string, defaultValue int) int {
if valueStr := os.Getenv(key); valueStr != "" {
if value, err := fmt.Sscanf(valueStr, "%d", &defaultValue); err == nil && value == 1 {
return defaultValue
}
}
return defaultValue
}

View File

@@ -0,0 +1,114 @@
package tool_search
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
"github.com/mark3labs/mcp-go/mcp"
)
// HandleToolSearch handles the x_higress_tool_search tool.
//
// It validates the "query" argument (required, non-empty string) and the
// optional "topK" argument (defaults to 10; reset to 10 when outside
// 1..100), runs the semantic search with a 30-second timeout, and returns
// the matched tools as a JSON text content block.
func HandleToolSearch(searchService *SearchService) common.ToolHandlerFunc {
	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		api.LogInfo("HandleToolSearch called")
		arguments := request.Params.Arguments
		api.LogDebugf("Request arguments: %+v", arguments)
		// Get query parameter
		query, ok := arguments["query"].(string)
		if !ok {
			api.LogErrorf("Invalid query argument type: %T", arguments["query"])
			return nil, fmt.Errorf("invalid query argument")
		}
		// Validate query
		if query == "" {
			api.LogError("Empty query provided")
			return nil, fmt.Errorf("query cannot be empty")
		}
		// Get topK parameter (optional, default to 10)
		topK := 10
		if topKVal, ok := arguments["topK"]; ok {
			// JSON decoding normally yields float64; int variants are accepted
			// for programmatically built argument maps.
			switch v := topKVal.(type) {
			case float64:
				topK = int(v)
			case int:
				topK = v
			case int64:
				topK = int(v)
			default:
				api.LogWarnf("Invalid topK argument type: %T, using default: %d", topKVal, topK)
			}
			// Validate topK range (only reached when the caller supplied topK;
			// the default of 10 is always in range).
			if topK <= 0 || topK > 100 {
				api.LogWarnf("Invalid topK value: %d, using default: 10", topK)
				topK = 10
			}
		}
		api.LogInfof("Parsed parameters - query: '%s', topK: %d", query, topK)
		// Add timeout to context
		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
		defer cancel()
		// Perform search
		result, err := searchService.SearchTools(ctx, query, topK)
		if err != nil {
			api.LogErrorf("Search failed: %v", err)
			return nil, fmt.Errorf("failed to search tools: %w", err)
		}
		api.LogInfof("Search completed successfully, found %d tools", len(result.Tools))
		// Build response
		response := map[string]interface{}{
			"tools": result.Tools,
		}
		jsonData, err := json.Marshal(response)
		if err != nil {
			api.LogErrorf("Failed to marshal response: %v", err)
			return nil, fmt.Errorf("failed to marshal search results: %w", err)
		}
		api.LogDebugf("Response marshaled successfully, JSON length: %d", len(jsonData))
		api.LogDebugf("Returning MCP CallToolResult")
		return &mcp.CallToolResult{
			Content: []mcp.Content{
				mcp.TextContent{
					Type: "text",
					Text: string(jsonData),
				},
			},
		}, nil
	}
}
// GetToolSearchSchema returns the schema for the tool search tool
func GetToolSearchSchema() json.RawMessage {
return json.RawMessage(`{
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Query statement for semantic similarity comparison with tool descriptions"
},
"topK": {
"type": "integer",
"description": "Specify how many tools need to be selected, default is to select the top 10 tools.",
"minimum": 1,
"maximum": 100
}
},
"required": ["query"]
}`)
}

View File

@@ -26,8 +26,8 @@ type config struct {
matchList []common.MatchRule
enableUserLevelServer bool
rateLimitConfig *handler.MCPRatelimitConfig
defaultServer *common.SSEServer
redisClient *common.RedisClient
sharedMCPServer *common.MCPServer // Created once, thread-safe with sync.RWMutex
}
func (c *config) Destroy() {
@@ -110,6 +110,9 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
}
GlobalSSEPathSuffix = ssePathSuffix
// Create shared MCPServer once during config parsing (thread-safe with sync.RWMutex)
conf.sharedMCPServer = common.NewMCPServer(DefaultServerName, Version)
return conf, nil
}
@@ -125,9 +128,6 @@ func (p *Parser) Merge(parent interface{}, child interface{}) interface{} {
if childConfig.rateLimitConfig != nil {
newConfig.rateLimitConfig = childConfig.rateLimitConfig
}
if childConfig.defaultServer != nil {
newConfig.defaultServer = childConfig.defaultServer
}
return &newConfig
}

View File

@@ -37,6 +37,7 @@ type filter struct {
skipRequestBody bool
skipResponseBody bool
cachedResponseBody []byte
sseServer *common.SSEServer // SSE server instance for this filter (per-request, not shared)
userLevelConfig bool
mcpConfigHandler *handler.MCPConfigHandler
@@ -135,11 +136,13 @@ func (f *filter) processMcpRequestHeadersForRestUpstream(header api.RequestHeade
trimmed += "?" + rq
}
f.config.defaultServer = common.NewSSEServer(common.NewMCPServer(DefaultServerName, Version),
// Create SSE server instance for this filter (per-request, not shared)
// MCPServer is shared (thread-safe), but SSEServer must be per-request (contains request-specific messageEndpoint)
f.sseServer = common.NewSSEServer(f.config.sharedMCPServer,
common.WithSSEEndpoint(GlobalSSEPathSuffix),
common.WithMessageEndpoint(trimmed),
common.WithRedisClient(f.config.redisClient))
f.serverName = f.config.defaultServer.GetServerName()
f.serverName = f.sseServer.GetServerName()
body := "SSE connection create"
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
}
@@ -238,10 +241,9 @@ func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.Statu
ret := api.Continue
api.LogDebugf("Upstream Type: %s", f.matchedRule.UpstreamType)
switch f.matchedRule.UpstreamType {
case common.RestUpstream, common.StreamableUpstream:
case common.RestUpstream:
api.LogDebugf("Encoding data from Rest upstream")
ret = f.encodeDataFromRestUpstream(buffer, endStream)
break
case common.SSEUpstream:
api.LogDebugf("Encoding data from SSE upstream")
ret = f.encodeDataFromSSEUpstream(buffer, endStream)
@@ -249,6 +251,8 @@ func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.Statu
// Always continue as long as the stream has ended.
ret = api.Continue
}
case common.StreamableUpstream:
// Do nothing for streamable upstream
}
return ret
}
@@ -274,9 +278,9 @@ func (f *filter) encodeDataFromRestUpstream(buffer api.BufferInstance, endStream
if f.serverName != "" {
if f.config.redisClient != nil {
// handle default server
// handle SSE server for this filter instance
buffer.Reset()
f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan)
f.sseServer.HandleSSE(f.callbacks, f.stopChan)
return api.Running
} else {
_ = buffer.SetString(RedisNotEnabledResponseBody)

View File

@@ -0,0 +1 @@
BasedOnStyle: Google

View File

@@ -135,8 +135,40 @@ bool PluginRootContext::configure(size_t configuration_size) {
return true;
}
void PluginRootContext::incrementRequestCount() {
request_count_++;
if (request_count_ >= REBUILD_THRESHOLD) {
LOG_DEBUG("Request count reached threshold, triggering rebuild");
setFilterState("wasm_need_rebuild", "true");
request_count_ = 0; // Reset counter after setting rebuild flag
}
}
FilterHeadersStatus PluginRootContext::onHeader(
const ModelMapperConfigRule& rule) {
// Increment request count and check for rebuild
incrementRequestCount();
// Check memory threshold and trigger rebuild if needed
std::string value;
if (getValue({"plugin_vm_memory"}, &value)) {
// The value is stored as binary uint64_t, convert to string for logging
if (value.size() == sizeof(uint64_t)) {
uint64_t memory_size;
memcpy(&memory_size, value.data(), sizeof(uint64_t));
LOG_DEBUG(absl::StrCat("vm memory size is ", memory_size));
if (memory_size >= MEMORY_THRESHOLD_BYTES) {
LOG_INFO(absl::StrCat("Memory threshold reached (", memory_size, " >= ",
MEMORY_THRESHOLD_BYTES, "), triggering rebuild"));
setFilterState("wasm_need_rebuild", "true");
}
} else {
LOG_ERROR("invalid memory size format");
}
} else {
LOG_ERROR("get vm memory size failed");
}
if (!Wasm::Common::Http::hasRequestBody()) {
return FilterHeadersStatus::Continue;
}

View File

@@ -42,9 +42,10 @@ struct ModelMapperConfigRule {
std::vector<std::pair<std::string, std::string>> prefix_model_mapping_;
std::string default_model_mapping_;
std::vector<std::string> enable_on_path_suffix_ = {
"/completions", "/embeddings", "/images/generations",
"/audio/speech", "/fine_tuning/jobs", "/moderations",
"/image-synthesis", "/video-synthesis"};
"/completions", "/embeddings", "/images/generations",
"/audio/speech", "/fine_tuning/jobs", "/moderations",
"/image-synthesis", "/video-synthesis", "/rerank",
"/messages"};
};
// PluginRootContext is the root context for all streams processed by the
@@ -60,9 +61,13 @@ class PluginRootContext : public RootContext,
FilterHeadersStatus onHeader(const ModelMapperConfigRule&);
FilterDataStatus onBody(const ModelMapperConfigRule&, std::string_view);
bool configure(size_t);
void incrementRequestCount();
private:
bool parsePluginConfig(const json&, ModelMapperConfigRule&) override;
uint64_t request_count_ = 0;
static constexpr uint64_t REBUILD_THRESHOLD = 1000;
static constexpr size_t MEMORY_THRESHOLD_BYTES = 200 * 1024 * 1024;
};
// Per-stream context.

View File

@@ -123,9 +123,40 @@ bool PluginRootContext::configure(size_t configuration_size) {
return true;
}
// Counts one request handled by this root context. When the running count
// reaches REBUILD_THRESHOLD, it signals the host to rebuild the wasm VM via
// the "wasm_need_rebuild" filter state and restarts the count from zero.
void PluginRootContext::incrementRequestCount() {
request_count_++;
if (request_count_ >= REBUILD_THRESHOLD) {
LOG_DEBUG("Request count reached threshold, triggering rebuild");
setFilterState("wasm_need_rebuild", "true");
request_count_ = 0; // Reset counter after setting rebuild flag
}
}
FilterHeadersStatus PluginRootContext::onHeader(
PluginContext& ctx,
const ModelRouterConfigRule& rule) {
PluginContext& ctx, const ModelRouterConfigRule& rule) {
// Increment request count and check for rebuild
incrementRequestCount();
// Check memory threshold and trigger rebuild if needed
std::string value;
if (getValue({"plugin_vm_memory"}, &value)) {
// The value is stored as binary uint64_t, convert to string for logging
if (value.size() == sizeof(uint64_t)) {
uint64_t memory_size;
memcpy(&memory_size, value.data(), sizeof(uint64_t));
LOG_DEBUG(absl::StrCat("vm memory size is ", memory_size));
if (memory_size >= MEMORY_THRESHOLD_BYTES) {
LOG_INFO(absl::StrCat("Memory threshold reached (", memory_size, " >= ",
MEMORY_THRESHOLD_BYTES, "), triggering rebuild"));
setFilterState("wasm_need_rebuild", "true");
}
} else {
LOG_ERROR("invalid memory size format");
}
} else {
LOG_ERROR("get vm memory size failed");
}
if (!Wasm::Common::Http::hasRequestBody()) {
return FilterHeadersStatus::Continue;
}
@@ -157,7 +188,7 @@ FilterHeadersStatus PluginRootContext::onHeader(
auto content_type_value = content_type_ptr->view();
LOG_DEBUG(absl::StrCat("Content-Type: ", content_type_value));
if (absl::StrContains(content_type_value,
Wasm::Common::Http::ContentTypeValues::Json)) {
Wasm::Common::Http::ContentTypeValues::Json)) {
ctx.mode_ = MODE_JSON;
LOG_DEBUG("Enable JSON mode.");
removeRequestHeader(Wasm::Common::Http::Header::ContentLength);
@@ -165,12 +196,15 @@ FilterHeadersStatus PluginRootContext::onHeader(
LOG_INFO(absl::StrCat("SetRequestBodyBufferLimit: ", DefaultMaxBodyBytes));
return FilterHeadersStatus::StopIteration;
}
if (absl::StrContains(content_type_value,
Wasm::Common::Http::ContentTypeValues::MultipartFormData)) {
if (absl::StrContains(
content_type_value,
Wasm::Common::Http::ContentTypeValues::MultipartFormData)) {
// Get the boundary from the content type
auto boundary_start = content_type_value.find("boundary=");
if (boundary_start == std::string::npos) {
LOG_WARN(absl::StrCat("No boundary found in a multipart/form-data content-type: ", content_type_value));
LOG_WARN(absl::StrCat(
"No boundary found in a multipart/form-data content-type: ",
content_type_value));
return FilterHeadersStatus::Continue;
}
boundary_start += 9;
@@ -181,21 +215,27 @@ FilterHeadersStatus PluginRootContext::onHeader(
auto boundary_length = boundary_end - boundary_start;
if (boundary_length < 1 || boundary_length > 70) {
// See https://www.w3.org/Protocols/rfc1341/7_2_Multipart.html
LOG_WARN(absl::StrCat("Invalid boundary value in a multipart/form-data content-type: ", content_type_value));
LOG_WARN(absl::StrCat(
"Invalid boundary value in a multipart/form-data content-type: ",
content_type_value));
return FilterHeadersStatus::Continue;
}
auto boundary_value = content_type_value.substr(boundary_start, boundary_end - boundary_start);
auto boundary_value = content_type_value.substr(
boundary_start, boundary_end - boundary_start);
ctx.mode_ = MODE_MULTIPART;
ctx.boundary_ = boundary_value;
LOG_DEBUG(absl::StrCat("Enable multipart/form-data mode. Boundary=", boundary_value));
LOG_DEBUG(absl::StrCat("Enable multipart/form-data mode. Boundary=",
boundary_value));
removeRequestHeader(Wasm::Common::Http::Header::ContentLength);
setFilterState(SetDecoderBufferLimitKey, DefaultMaxBodyBytes);
LOG_INFO(absl::StrCat("SetRequestBodyBufferLimit: ", DefaultMaxBodyBytes));
return FilterHeadersStatus::StopIteration;
}
return FilterHeadersStatus::Continue;
}
FilterDataStatus PluginRootContext::onJsonBody(const ModelRouterConfigRule& rule,
std::string_view body) {
FilterDataStatus PluginRootContext::onJsonBody(
const ModelRouterConfigRule& rule, std::string_view body) {
const auto& model_key = rule.model_key_;
const auto& add_provider_header = rule.add_provider_header_;
const auto& model_to_header = rule.model_to_header_;
@@ -231,18 +271,18 @@ FilterDataStatus PluginRootContext::onJsonBody(const ModelRouterConfigRule& rule
}
FilterDataStatus PluginRootContext::onMultipartBody(
PluginContext& ctx,
const ModelRouterConfigRule& rule,
WasmDataPtr& body,
PluginContext& ctx, const ModelRouterConfigRule& rule, WasmDataPtr& body,
bool end_stream) {
const auto& add_provider_header = rule.add_provider_header_;
const auto& model_to_header = rule.model_to_header_;
const auto boundary = ctx.boundary_;
const auto body_view = body->view();
const auto model_param_header = absl::StrCat("Content-Disposition: form-data; name=\"", rule.model_key_, "\"");
const auto model_param_header = absl::StrCat(
"Content-Disposition: form-data; name=\"", rule.model_key_, "\"");
for (size_t pos = 0; (pos = body_view.find(boundary, pos)) != std::string_view::npos;) {
for (size_t pos = 0;
(pos = body_view.find(boundary, pos)) != std::string_view::npos;) {
LOG_DEBUG(absl::StrCat("Found boundary at ", pos));
pos += boundary.length();
size_t end_pos = body_view.find(boundary, pos);
@@ -264,7 +304,7 @@ FilterDataStatus PluginRootContext::onMultipartBody(
LOG_DEBUG("No value start found in part");
break;
}
value_start += 4; // Skip the "\r\n\r\n"
value_start += 4; // Skip the "\r\n\r\n"
// model parameter should be only one line
size_t value_end = part.find(CRLF, value_start);
if (value_end == std::string_view::npos) {
@@ -283,8 +323,12 @@ FilterDataStatus PluginRootContext::onMultipartBody(
const auto& model = model_value.substr(pos + 1);
replaceRequestHeader(add_provider_header, provider);
size_t new_size = 0;
auto new_buffer_data = absl::StrCat(body_view.substr(0, part_pos + value_start), model, body_view.substr(part_pos + value_end));
auto result = setBuffer(WasmBufferType::HttpRequestBody, 0, std::numeric_limits<size_t>::max(), new_buffer_data, &new_size);
auto new_buffer_data =
absl::StrCat(body_view.substr(0, part_pos + value_start), model,
body_view.substr(part_pos + value_end));
auto result = setBuffer(WasmBufferType::HttpRequestBody, 0,
std::numeric_limits<size_t>::max(),
new_buffer_data, &new_size);
LOG_DEBUG(absl::StrCat("model route to provider:", provider,
", model:", model));
LOG_DEBUG(absl::StrCat("result=", result, " new_size=", new_size));
@@ -294,7 +338,8 @@ FilterDataStatus PluginRootContext::onMultipartBody(
}
}
// We are done now. We can stop processing the body.
LOG_DEBUG(absl::StrCat("Done processing multipart body after caching ", body_view.length() , " bytes."));
LOG_DEBUG(absl::StrCat("Done processing multipart body after caching ",
body_view.length(), " bytes."));
ctx.mode_ = MODE_BYPASS;
return FilterDataStatus::Continue;
}
@@ -324,8 +369,7 @@ FilterDataStatus PluginContext::onRequestBody(size_t body_size,
auto* rootCtx = rootContext();
body_total_size_ += body_size;
switch (mode_) {
case MODE_JSON:
{
case MODE_JSON: {
if (!end_stream) {
return FilterDataStatus::StopIterationAndBuffer;
}
@@ -333,8 +377,7 @@ FilterDataStatus PluginContext::onRequestBody(size_t body_size,
getBufferBytes(WasmBufferType::HttpRequestBody, 0, body_total_size_);
return rootCtx->onJsonBody(*config_, body->view());
}
case MODE_MULTIPART:
{
case MODE_MULTIPART: {
auto body =
getBufferBytes(WasmBufferType::HttpRequestBody, 0, body_total_size_);
return rootCtx->onMultipartBody(*this, *config_, body, end_stream);

View File

@@ -48,9 +48,10 @@ struct ModelRouterConfigRule {
std::string add_provider_header_;
std::string model_to_header_;
std::vector<std::string> enable_on_path_suffix_ = {
"/completions", "/embeddings", "/images/generations",
"/audio/speech", "/fine_tuning/jobs", "/moderations",
"/image-synthesis", "/video-synthesis"};
"/completions", "/embeddings", "/images/generations",
"/audio/speech", "/fine_tuning/jobs", "/moderations",
"/image-synthesis", "/video-synthesis", "/rerank",
"/messages"};
};
class PluginContext;
@@ -65,13 +66,20 @@ class PluginRootContext : public RootContext,
: RootContext(id, root_id) {}
~PluginRootContext() {}
bool onConfigure(size_t) override;
FilterHeadersStatus onHeader(PluginContext& ctx, const ModelRouterConfigRule&);
FilterHeadersStatus onHeader(PluginContext& ctx,
const ModelRouterConfigRule&);
FilterDataStatus onJsonBody(const ModelRouterConfigRule&, std::string_view);
FilterDataStatus onMultipartBody(PluginContext& ctx, const ModelRouterConfigRule& rule, WasmDataPtr& body, bool end_stream);
FilterDataStatus onMultipartBody(PluginContext& ctx,
const ModelRouterConfigRule& rule,
WasmDataPtr& body, bool end_stream);
bool configure(size_t);
void incrementRequestCount();
private:
bool parsePluginConfig(const json&, ModelRouterConfigRule&) override;
uint64_t request_count_ = 0;
static constexpr uint64_t REBUILD_THRESHOLD = 1000;
static constexpr size_t MEMORY_THRESHOLD_BYTES = 200 * 1024 * 1024;
};
// Per-stream context.
@@ -98,4 +106,4 @@ class PluginContext : public Context {
} // namespace null_plugin
} // namespace proxy_wasm
#endif
#endif

View File

@@ -157,7 +157,8 @@ TEST_F(ModelRouterTest, RewriteModelAndHeader) {
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(request_json.length(), true), FilterDataStatus::Continue);
EXPECT_EQ(context_->onRequestBody(request_json.length(), true),
FilterDataStatus::Continue);
}
TEST_F(ModelRouterTest, ModelToHeader) {
@@ -183,7 +184,8 @@ TEST_F(ModelRouterTest, ModelToHeader) {
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(request_json.length(), true), FilterDataStatus::Continue);
EXPECT_EQ(context_->onRequestBody(request_json.length(), true),
FilterDataStatus::Continue);
}
TEST_F(ModelRouterTest, IgnorePath) {
@@ -210,7 +212,8 @@ TEST_F(ModelRouterTest, IgnorePath) {
body_.set(request_json);
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::Continue);
EXPECT_EQ(context_->onRequestBody(request_json.length(), true), FilterDataStatus::Continue);
EXPECT_EQ(context_->onRequestBody(request_json.length(), true),
FilterDataStatus::Continue);
}
TEST_F(ModelRouterTest, RouteLevelRewriteModelAndHeader) {
@@ -244,10 +247,10 @@ TEST_F(ModelRouterTest, RouteLevelRewriteModelAndHeader) {
route_name_ = "route-a";
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
EXPECT_EQ(context_->onRequestBody(request_json.length(), true), FilterDataStatus::Continue);
EXPECT_EQ(context_->onRequestBody(request_json.length(), true),
FilterDataStatus::Continue);
}
TEST_F(ModelRouterTest, RewriteModelAndHeaderMultipartFormData) {
std::string configuration = R"({
"addProviderHeader": "x-higress-llm-provider"
@@ -257,8 +260,11 @@ TEST_F(ModelRouterTest, RewriteModelAndHeaderMultipartFormData) {
EXPECT_TRUE(root_context_->configure(configuration.size()));
path_ = "/v1/chat/completions";
content_type_ = "multipart/form-data; boundary=--------------------------100751621174704322650451";
std::string request_data = std::regex_replace(R"(
content_type_ =
"multipart/form-data; "
"boundary=--------------------------100751621174704322650451";
std::string request_data = std::regex_replace(
R"(
----------------------------100751621174704322650451
Content-Disposition: form-data; name="purpose"
@@ -274,16 +280,21 @@ Content-Type: application/json
[
]
----------------------------100751621174704322650451--
)", std::regex("\n"), "\r\n"); // Multipart data requires CRLF line endings
)",
std::regex("\n"), "\r\n"); // Multipart data requires CRLF line endings
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.WillOnce([&](WasmBufferType, size_t start, size_t length, std::string_view body) {
std::cerr << "===============" << "\n";
.WillOnce([&](WasmBufferType, size_t start, size_t length,
std::string_view body) {
std::cerr << "==============="
<< "\n";
std::cerr << body << "\n";
std::cerr << "===============" << "\n";
std::cerr << "==============="
<< "\n";
EXPECT_EQ(start, 0);
EXPECT_EQ(length, std::numeric_limits<size_t>::max());
auto expected_body= std::regex_replace(R"(
auto expected_body = std::regex_replace(
R"(
----------------------------100751621174704322650451
Content-Disposition: form-data; name="purpose"
@@ -292,7 +303,9 @@ batch
Content-Disposition: form-data; name="model"
qwen-turbo
)", std::regex("\n"), "\r\n"); // Multipart data requires CRLF line endings
)",
std::regex("\n"),
"\r\n"); // Multipart data requires CRLF line endings
EXPECT_EQ(body, expected_body);
return WasmResult::Ok;
});
@@ -308,42 +321,54 @@ qwen-turbo
auto last_body_size = 0;
auto body = request_data.substr(0, request_data.find("batch") + 5 + 2 /* batch + CRLF */);
auto body = request_data.substr(
0, request_data.find("batch") + 5 + 2 /* batch + CRLF */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("\"model\"") + 5 + 2 + 2 /* "model" + CRLF + CRLF */);
body = request_data.substr(0, request_data.find("\"model\"") + 5 + 2 +
2 /* "model" + CRLF + CRLF */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen") + 4 /* "qwen" */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen-turbo") + 10 /* "qwen-turbo" */);
body = request_data.substr(
0, request_data.find("qwen-turbo") + 10 /* "qwen-turbo" */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen-turbo") + 10 + 2 /* "qwen-turbo" + CRLF */);
body = request_data.substr(
0, request_data.find("qwen-turbo") + 10 + 2 /* "qwen-turbo" + CRLF */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::Continue);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
FilterDataStatus::Continue);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen-turbo") + 10 + 2 + 50 /* "qwen-turbo" + CRLF + boundary */);
body = request_data.substr(0, request_data.find("qwen-turbo") + 10 + 2 +
50 /* "qwen-turbo" + CRLF + boundary */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::Continue);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
FilterDataStatus::Continue);
last_body_size = body.size();
body_.set(request_data);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, true), FilterDataStatus::Continue);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, true),
FilterDataStatus::Continue);
}
TEST_F(ModelRouterTest, ModelToHeaderMultipartFormData) {
std::string configuration = R"(
std::string configuration = R"(
{
"modelToHeader": "x-higress-llm-model"
})";
@@ -352,8 +377,11 @@ TEST_F(ModelRouterTest, ModelToHeaderMultipartFormData) {
EXPECT_TRUE(root_context_->configure(configuration.size()));
path_ = "/v1/chat/completions";
content_type_ = "multipart/form-data; boundary=--------------------------100751621174704322650451";
std::string request_data = std::regex_replace(R"(
content_type_ =
"multipart/form-data; "
"boundary=--------------------------100751621174704322650451";
std::string request_data = std::regex_replace(
R"(
----------------------------100751621174704322650451
Content-Disposition: form-data; name="purpose"
@@ -369,7 +397,8 @@ Content-Type: application/json
[
]
----------------------------100751621174704322650451--
)", std::regex("\n"), "\r\n"); // Multipart data requires CRLF line endings
)",
std::regex("\n"), "\r\n"); // Multipart data requires CRLF line endings
EXPECT_CALL(*mock_context_,
setBuffer(testing::_, testing::_, testing::_, testing::_))
.Times(0);
@@ -384,38 +413,50 @@ Content-Type: application/json
auto last_body_size = 0;
auto body = request_data.substr(0, request_data.find("batch") + 5 + 2 /* batch + CRLF */);
auto body = request_data.substr(
0, request_data.find("batch") + 5 + 2 /* batch + CRLF */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("\"model\"") + 5 + 2 + 2 /* "model" + CRLF + CRLF */);
body = request_data.substr(0, request_data.find("\"model\"") + 5 + 2 +
2 /* "model" + CRLF + CRLF */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen") + 4 /* "qwen" */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen-max") + 8 /* "qwen-max" */);
body = request_data.substr(
0, request_data.find("qwen-max") + 8 /* "qwen-max" */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
FilterDataStatus::StopIterationAndBuffer);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen-max") + 8 + 2 /* "qwen-max" + CRLF */);
body = request_data.substr(
0, request_data.find("qwen-max") + 8 + 2 /* "qwen-max" + CRLF */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::Continue);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
FilterDataStatus::Continue);
last_body_size = body.size();
body = request_data.substr(0, request_data.find("qwen-max") + 8 + 2 + 50 /* "qwen-max" + CRLF */);
body = request_data.substr(
0, request_data.find("qwen-max") + 8 + 2 + 50 /* "qwen-max" + CRLF */);
body_.set(body);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::Continue);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
FilterDataStatus::Continue);
last_body_size = body.size();
body_.set(request_data);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, true), FilterDataStatus::Continue);
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, true),
FilterDataStatus::Continue);
}
} // namespace model_router

View File

@@ -7,21 +7,21 @@ go 1.24.1
toolchain go1.24.4
require (
github.com/google/uuid v1.6.0
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
github.com/higress-group/wasm-go v1.0.6-0.20251103065747-41d65dbb2f9e
github.com/higress-group/wasm-go v1.0.3-0.20251011083635-792cb1547bac
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/resp v0.1.1
// github.com/weaviate/weaviate-go-client/v4 v4.15.1
)
require github.com/tetratelabs/wazero v1.7.2 // indirect
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/uuid v1.6.0
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/resp v0.1.1
github.com/tidwall/sjson v1.2.5 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -4,10 +4,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/higress-group/wasm-go v1.0.6-0.20251103065747-41d65dbb2f9e h1:wYW/DXjyQniQLaB26c+J9NQk3+AhqByzS1r18NShvB4=
github.com/higress-group/wasm-go v1.0.6-0.20251103065747-41d65dbb2f9e/go.mod h1:B8C6+OlpnyYyZUBEdUXA7tYZYD+uwZTNjfkE5FywA+A=
github.com/higress-group/wasm-go v1.0.3-0.20251011083635-792cb1547bac h1:tdJzS56Xa6BSHAi9P2omvb98bpI8qFGg6jnCPtPmDgA=
github.com/higress-group/wasm-go v1.0.3-0.20251011083635-792cb1547bac/go.mod h1:B8C6+OlpnyYyZUBEdUXA7tYZYD+uwZTNjfkE5FywA+A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=

View File

@@ -38,6 +38,7 @@ func init() {
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
wrapper.ProcessStreamingResponseBodyBy(onHttpResponseBody),
wrapper.WithRebuildAfterRequests[config.PluginConfig](1000),
)
}

View File

@@ -15,13 +15,18 @@ description: 针对LLM服务的负载均衡策略
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|--------------------|-----------------|------------------|-------------|-------------------------------------|
| `lb_type` | string | 选填 | endpoint | 负载均衡类型,可选`endpoint`,`cluster` |
| `lb_policy` | string | 必填 | | 负载均衡策略类型 |
| `lb_config` | object | 必填 | | 当前负载均衡策略类型的配置 |
目前支持的负载均衡策略包括:
`lb_type``endpoint`支持的负载均衡策略包括:
- `global_least_request`: 基于redis实现的全局最小请求数负载均衡
- `prefix_cache`: 基于 prompt 前缀匹配选择后端节点,如果通过前缀匹配无法匹配到节点,则通过全局最小请求数进行服务节点的选择
- `least_busy`: [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md) 的 wasm 实现
- `endpoint_metrics`: 基于 llm 服务暴露的 metrics 进行负载均衡
`lb_type``cluster` 时支持的负载均衡策略包括:
- `cluster_metrics`: 基于网关统计的不同service的指标进行服务之间的负载均衡
# 全局最小请求数
## 功能说明
@@ -59,6 +64,7 @@ sequenceDiagram
## 配置示例
```yaml
lb_type: endpoint
lb_policy: global_least_request
lb_config:
serviceFQDN: redis.static
@@ -116,11 +122,12 @@ lb_config:
| `password` | string | 选填 | 空 | redis 密码 |
| `timeout` | int | 选填 | 3000ms | redis 请求超时时间 |
| `database` | int | 选填 | 0 | redis 数据库序号 |
| `redisKeyTTL` | int | 选填 | 1800ms | prompt 前缀对应的key的ttl |
| `redisKeyTTL` | int | 选填 | 1800s | prompt 前缀对应的key的ttl |
## 配置示例
```yaml
lb_type: endpoint
lb_policy: prefix_cache
lb_config:
serviceFQDN: redis.static
@@ -161,14 +168,73 @@ sequenceDiagram
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|--------------------|-----------------|------------------|-------------|-------------------------------------|
| `criticalModels` | []string | | | critical的模型列表 |
| `metric_policy` | string | | | 如何使用llm暴露的metrics做负载均衡当前支持`[default, least, most]` |
| `target_metric` | string | 选填 | | 要使用的metric名称`metric_policy` 取值为 `least` 或者 `most` 时生效 |
| `rate_limit` | string | 选填 | 1 | 单个节点处理请求比例上限取值范围0~1 |
## 配置示例
使用 [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md) 中的算法
```yaml
lb_type: endpoint
lb_policy: metrics_based
lb_config:
metric_policy: default
rate_limit: 0.6 # 单个节点承载的最大请求比例
```
根据当前排队请求数进行负载均衡
```yaml
lb_type: endpoint
lb_policy: metrics_based
lb_config:
metric_policy: least
target_metric: vllm:num_requests_waiting
rate_limit: 0.6 # 单个节点承载的最大请求比例
```
根据当前GPU中正在处理的请求数进行负载均衡
```yaml
lb_type: endpoint
lb_policy: metrics_based
lb_config:
metric_policy: least
target_metric: vllm:num_requests_running
rate_limit: 0.6 # 单个节点承载的最大请求比例
```
# 跨服务负载均衡
## 配置说明
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|--------------------|-----------------|------------------|-------------|-------------------------------------|
| `mode` | string | 必填 | | 如何使用服务级指标做负载均衡,当前支持`[LeastBusy, LeastTotalLatency, LeastFirstTokenLatency ]` |
| `service_list` | []string | 必填 | | 路由后端服务列表 |
| `rate_limit` | string | 选填 | 1 | 单个服务处理请求比例上限取值范围0~1 |
| `cluster_header` | string | 选填 | `x-higress-target-cluster` | 通过取该header的值得知需要路由到哪个后端服务 |
| `queue_size` | int | 选填 | 100 | 根据最近的多少个请求进行观测指标的计算 |
`mode` 各取值含义如下:
- `LeastBusy`: 路由到当前并发请求数最少的服务
- `LeastTotalLatency`: 路由到当前RT最低的服务
- `LeastFirstTokenLatency`: 路由到当前首包RT最低的服务
## 配置示例
```yaml
lb_policy: least_busy
lb_type: cluster
lb_policy: cluster_metrics
lb_config:
criticalModels:
- meta-llama/Llama-2-7b-hf
- sql-lora
mode: LeastTotalLatency # 策略名称
queue_size: 100 # 统计指标时使用的最近请求数
rate_limit: 0.6 # 单个服务承载的最大请求比例
service_list:
- outbound|80||test-1.dns
- outbound|80||test-2.static
```

View File

@@ -15,14 +15,19 @@ The configuration is:
| Name | Type | Required | default | description |
|--------------------|-----------------|------------------|-------------|-------------------------------------|
| `lb_policy` | string | required | | load balance type |
| `lb_type` | string | optional | endpoint | load balance policy type, `endpoint` or `cluster` |
| `lb_policy` | string | required | | load balance policy type |
| `lb_config` | object | required | | configuration for the current load balance type |
Current supported load balance policies are:
When `lb_type = endpoint`, current supported load balance policies are:
- `global_least_request`: global least request based on redis
- `prefix_cache`: Select the backend node based on the prompt prefix match. If the node cannot be matched by prefix matching, the service node is selected based on the global minimum number of requests.
- `least_busy`: implementation for [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md)
- `endpoint_metrics`: Load balancing based on metrics exposed by the llm service
When `lb_type = cluster`, current supported load balance policies are:
- `cluster_metrics`: Load balancing based on metrics of clusters
# Global Least Request
## Introduction
@@ -60,6 +65,7 @@ sequenceDiagram
## Configuration Example
```yaml
lb_type: endpoint
lb_policy: global_least_request
lb_config:
serviceFQDN: redis.static
@@ -118,11 +124,12 @@ Then subsequent requests with the same prefix will also be routed to pod 1:
| `password` | string | optional | `` | redis password |
| `timeout` | int | optional | 3000ms | redis request timeout |
| `database` | int | optional | 0 | redis database number |
| `redisKeyTTL` | int | optional | 1800ms | prompt prefix key's ttl |
| `redisKeyTTL` | int | optional | 1800s | prompt prefix key's ttl |
## Configuration Example
```yaml
lb_type: endpoint
lb_policy: prefix_cache
lb_config:
serviceFQDN: redis.static
@@ -164,14 +171,71 @@ sequenceDiagram
| Name | Type | Required | default | description |
|--------------------|-----------------|------------------|-------------|-------------------------------------|
| `criticalModels` | []string | required | | critical model names |
| `metric_policy` | string | required | | How to use the metrics exposed by LLM for load balancing, currently supporting `[default, least, most]` |
| `target_metric` | string | optional | | The metric name to use. This is valid only when `metric_policy` is `least` or `most` |
| `rate_limit` | string | optional | 1 | The maximum percentage of requests a single node can receive, 0~1 |
## Configuration Example
Use the algorithm of [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md):
```yaml
lb_type: endpoint
lb_policy: metrics_based
lb_config:
metric_policy: default
rate_limit: 0.6
```
Load balancing based on the current number of queued requests:
```yaml
lb_type: endpoint
lb_policy: metrics_based
lb_config:
metric_policy: least
target_metric: vllm:num_requests_waiting
rate_limit: 0.6
```
Load balancing based on the number of requests currently being processed by the GPU:
```yaml
lb_type: endpoint
lb_policy: metrics_based
lb_config:
metric_policy: least
target_metric: vllm:num_requests_running
rate_limit: 0.6
```
# Cross-service load balancing
## Configuration
| Name | Type | Required | default | description |
|--------------------|-----------------|------------------|-------------|-------------------------------------|
| `mode` | string | required | | how to use cluster metrics, value of `[LeastBusy, LeastTotalLatency, LeastFirstTokenLatency ]` |
| `service_list` | []string | required | | service list of current route |
| `rate_limit` | string | optional | 1 | The maximum percentage of requests a single node can receive, value of 0~1 |
| `cluster_header` | string | optional | `x-higress-target-cluster` | By retrieving the value of this header, we can determine which backend service to route to |
| `queue_size` | int | optional | 100 | The metrics is calculated based on the number of most recent requests. |
The meanings of the values for `mode` are as follows:
- `LeastBusy`: Routes to the service with the fewest concurrent requests.
- `LeastTotalLatency`: Routes to the service with the lowest response time (RT).
- `LeastFirstTokenLatency`: Routes to the service with the lowest RT for the first packet.
## Configuration Example
```yaml
lb_policy: least_busy
lb_type: cluster
lb_policy: cluster_metrics
lb_config:
criticalModels:
- meta-llama/Llama-2-7b-hf
- sql-lora
```
mode: LeastTotalLatency
rate_limit: 0.6
service_list:
- outbound|80||test-1.dns
- outbound|80||test-2.static
```

View File

@@ -0,0 +1,218 @@
package cluster_metrics
import (
"fmt"
"math/rand"
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
const (
// DefaultQueueSize is the number of recent latency samples kept per service
// when the configuration does not specify queue_size.
DefaultQueueSize = 100
// DefaultClusterHeader is the request header consulted to learn which
// backend cluster a request was routed to when cluster_header is not set.
DefaultClusterHeader = "x-higress-target-cluster"
)
// ClusterEndpointLoadBalancer balances traffic across backend services
// (clusters) based on per-service request statistics and recent latency
// samples.
type ClusterEndpointLoadBalancer struct {
// Configurations
Mode string // balancing mode read from the "mode" config field
ClusterHeader string // header naming the target cluster; defaults to DefaultClusterHeader
ServiceList []string // candidate backend services read from "service_list"
RateLimit float64 // max share of requests a single service may receive (0~1; default 1.0)
// Statistic
ServiceRequestOngoing map[string]int // in-flight request count per service
ServiceRequestCount map[string]int // total counted requests per service
FirstTokenLatencyRequests map[string]*utils.FixedQueue[float64] // recent first-token latency samples per service
TotalLatencyRequests map[string]*utils.FixedQueue[float64] // recent total latency samples per service
}
// NewClusterEndpointLoadBalancer builds a cluster-level load balancer from the
// plugin's JSON configuration.
//
// Recognized fields:
//   - mode:           LeastBusy / LeastTotalLatency / LeastFirstTokenLatency
//   - cluster_header: header used to signal the chosen cluster (default DefaultClusterHeader)
//   - rate_limit:     max fraction of traffic one service may receive (default 1.0)
//   - queue_size:     sliding-window size for latency samples (default DefaultQueueSize)
//   - service_list:   candidate upstream cluster names
func NewClusterEndpointLoadBalancer(json gjson.Result) (ClusterEndpointLoadBalancer, error) {
	lb := ClusterEndpointLoadBalancer{
		ServiceRequestOngoing:     make(map[string]int),
		ServiceRequestCount:       make(map[string]int),
		FirstTokenLatencyRequests: make(map[string]*utils.FixedQueue[float64]),
		TotalLatencyRequests:      make(map[string]*utils.FixedQueue[float64]),
	}
	lb.Mode = json.Get("mode").String()
	lb.ClusterHeader = json.Get("cluster_header").String()
	if lb.ClusterHeader == "" {
		lb.ClusterHeader = DefaultClusterHeader
	}
	if json.Get("rate_limit").Exists() {
		lb.RateLimit = json.Get("rate_limit").Float()
	} else {
		lb.RateLimit = 1.0
	}
	queueSize := int(json.Get("queue_size").Int())
	// Fix: guard negative values as well as zero — a negative queue_size in the
	// config would otherwise be passed straight to NewFixedQueue.
	if queueSize <= 0 {
		queueSize = DefaultQueueSize
	}
	for _, svc := range json.Get("service_list").Array() {
		serviceName := svc.String()
		lb.ServiceList = append(lb.ServiceList, serviceName)
		lb.ServiceRequestOngoing[serviceName] = 0
		lb.ServiceRequestCount[serviceName] = 0
		lb.FirstTokenLatencyRequests[serviceName] = utils.NewFixedQueue[float64](queueSize)
		lb.TotalLatencyRequests[serviceName] = utils.NewFixedQueue[float64](queueSize)
	}
	return lb, nil
}
// getRequestRate returns the share of all recorded requests that were routed
// to serviceName, or 0 when no requests have been recorded yet.
func (lb ClusterEndpointLoadBalancer) getRequestRate(serviceName string) float64 {
	total := 0
	for _, count := range lb.ServiceRequestCount {
		total += count
	}
	if total == 0 {
		return 0
	}
	return float64(lb.ServiceRequestCount[serviceName]) / float64(total)
}
// getServiceTTFT returns the average time-to-first-token (milliseconds) over
// the sliding window of samples recorded for serviceName, or 0 when the
// service is unknown or has no samples yet.
func (lb ClusterEndpointLoadBalancer) getServiceTTFT(serviceName string) float64 {
	queue, ok := lb.FirstTokenLatencyRequests[serviceName]
	if !ok || queue.Size() == 0 {
		return 0
	}
	sum := 0.0
	queue.ForEach(func(i int, item float64) {
		// item is already float64; the previous float64(item) conversion was redundant.
		sum += item
	})
	return sum / float64(queue.Size())
}
// getServiceTotalRT returns the average total response time (milliseconds)
// over the sliding window of samples recorded for serviceName, or 0 when the
// service is unknown or has no samples yet.
func (lb ClusterEndpointLoadBalancer) getServiceTotalRT(serviceName string) float64 {
	queue, ok := lb.TotalLatencyRequests[serviceName]
	if !ok || queue.Size() == 0 {
		return 0
	}
	sum := 0.0
	queue.ForEach(func(i int, item float64) {
		// item is already float64; the previous float64(item) conversion was redundant.
		sum += item
	})
	return sum / float64(queue.Size())
}
// Callbacks which are called in request path

// HandleHttpRequestHeaders selects the target service for this request
// according to lb.Mode, writes the choice into lb.ClusterHeader (both the
// request header and the per-request context), and updates the in-flight and
// total request counters.
//
// Selection starts from a random candidate so ties do not always resolve to
// the same service. In every mode a service whose request share has reached
// lb.RateLimit is avoided when possible.
func (lb ClusterEndpointLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
	// Record the request start so the response-side callbacks can compute latency.
	ctx.SetContext("request_start", time.Now().UnixMilli())
	candidate := lb.ServiceList[rand.Int()%len(lb.ServiceList)]
	var debugInfo string
	switch lb.Mode {
	case "LeastBusy":
		// Prefer the service with the fewest in-flight requests.
		for svc, ongoingNum := range lb.ServiceRequestOngoing {
			if candidate == svc {
				continue
			}
			// If the current candidate already exceeds its traffic share, switch
			// away from it; otherwise switch only to a less busy service that is
			// itself still under the share limit.
			if lb.getRequestRate(candidate) >= lb.RateLimit {
				candidate = svc
			} else if ongoingNum < lb.ServiceRequestOngoing[candidate] && lb.getRequestRate(svc) < lb.RateLimit {
				candidate = svc
			}
		}
		for svc := range lb.ServiceRequestOngoing {
			debugInfo += fmt.Sprintf("[service: %s] {ongoing request: %d, total request: %d, request rate: %.2f}, ",
				svc, lb.ServiceRequestOngoing[svc], lb.ServiceRequestCount[svc], lb.getRequestRate(svc))
		}
	case "LeastFirstTokenLatency":
		// Prefer the service with the lowest average time-to-first-token.
		candidateTTFT := lb.getServiceTTFT(candidate)
		for _, svc := range lb.ServiceList {
			if candidate == svc {
				continue
			}
			if lb.getRequestRate(candidate) >= lb.RateLimit {
				candidate = svc
				candidateTTFT = lb.getServiceTTFT(svc)
			} else if lb.getServiceTTFT(svc) < candidateTTFT && lb.getRequestRate(svc) < lb.RateLimit {
				candidate = svc
				candidateTTFT = lb.getServiceTTFT(svc)
			}
		}
		for _, svc := range lb.ServiceList {
			debugInfo += fmt.Sprintf("[service: %s] {average ttft: %.2f, total request: %d, request rate: %.2f}, ",
				svc, lb.getServiceTTFT(svc), lb.ServiceRequestCount[svc], lb.getRequestRate(svc))
		}
	case "LeastTotalLatency":
		// Prefer the service with the lowest average total response time.
		candidateTotalRT := lb.getServiceTotalRT(candidate)
		for _, svc := range lb.ServiceList {
			if candidate == svc {
				continue
			}
			if lb.getRequestRate(candidate) >= lb.RateLimit {
				candidate = svc
				candidateTotalRT = lb.getServiceTotalRT(svc)
			} else if lb.getServiceTotalRT(svc) < candidateTotalRT && lb.getRequestRate(svc) < lb.RateLimit {
				candidate = svc
				candidateTotalRT = lb.getServiceTotalRT(svc)
			}
		}
		for _, svc := range lb.ServiceList {
			debugInfo += fmt.Sprintf("[service: %s] {average latency: %.2f, total request: %d, request rate: %.2f}, ",
				svc, lb.getServiceTotalRT(svc), lb.ServiceRequestCount[svc], lb.getRequestRate(svc))
		}
	}
	debugInfo += fmt.Sprintf("final service: %s", candidate)
	log.Debug(debugInfo)
	// Propagate the decision to the routing layer and remember it in the
	// per-request context so response callbacks can attribute latency samples.
	proxywasm.ReplaceHttpRequestHeader(lb.ClusterHeader, candidate)
	ctx.SetContext(lb.ClusterHeader, candidate)
	lb.ServiceRequestOngoing[candidate] += 1
	lb.ServiceRequestCount[candidate] += 1
	return types.ActionContinue
}
// HandleHttpRequestBody is a no-op: cluster selection happens in
// HandleHttpRequestHeaders and does not need the request body.
func (lb ClusterEndpointLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
	return types.ActionContinue
}
// HandleHttpResponseHeaders stashes the response status code in the request
// context so the body/stream-done callbacks can penalize failed (non-200)
// requests when recording latency samples.
func (lb ClusterEndpointLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
	statusCode, _ := proxywasm.GetHttpResponseHeader(":status")
	ctx.SetContext("statusCode", statusCode)
	return types.ActionContinue
}
// HandleHttpStreamingResponseBody records the time-to-first-token (TTFT) for
// the selected service on the first response chunk. Failed (non-200) requests
// are penalized by recording at least twice the worst observed TTFT so the
// scheduler steers traffic away from the failing service.
func (lb ClusterEndpointLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
	if ctx.GetContext("ttft_recorded") != nil {
		return data
	}
	// Fix: use comma-ok assertions. A locally generated response (or any path
	// that skipped HandleHttpRequestHeaders) leaves these context keys unset,
	// and the previous bare type assertions would panic.
	candidate, ok := ctx.GetContext(lb.ClusterHeader).(string)
	if !ok {
		return data
	}
	start, ok := ctx.GetContext("request_start").(int64)
	if !ok {
		return data
	}
	duration := float64(time.Now().UnixMilli() - start)
	// punish failed request
	if statusCode, ok := ctx.GetContext("statusCode").(string); ok && statusCode != "200" {
		for _, svc := range lb.ServiceList {
			ttft := lb.getServiceTTFT(svc)
			if duration < ttft {
				duration = ttft
			}
		}
		duration *= 2
	}
	// Guard against an unknown candidate (not present in the configured
	// service list) to avoid a nil-pointer Enqueue.
	if queue := lb.FirstTokenLatencyRequests[candidate]; queue != nil {
		queue.Enqueue(duration)
	}
	ctx.SetContext("ttft_recorded", struct{}{})
	return data
}
// HandleHttpResponseBody is a no-op: latency accounting is done in the
// streaming-body and stream-done callbacks.
func (lb ClusterEndpointLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
	return types.ActionContinue
}
// HandleHttpStreamDone finalizes accounting for the request: it decrements the
// in-flight counter for the selected service and records the total latency.
// Failed (non-200) requests are penalized by recording at least twice the
// worst observed latency.
func (lb ClusterEndpointLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpContext) {
	// Fix: comma-ok assertions — a stream that never went through
	// HandleHttpRequestHeaders leaves these keys unset and the previous bare
	// assertions would panic.
	candidate, ok := ctx.GetContext(lb.ClusterHeader).(string)
	if !ok {
		return
	}
	lb.ServiceRequestOngoing[candidate] -= 1
	start, ok := ctx.GetContext("request_start").(int64)
	if !ok {
		return
	}
	duration := float64(time.Now().UnixMilli() - start)
	// punish failed request
	if statusCode, ok := ctx.GetContext("statusCode").(string); ok && statusCode != "200" {
		for _, svc := range lb.ServiceList {
			rt := lb.getServiceTotalRT(svc)
			if duration < rt {
				duration = rt
			}
		}
		duration *= 2
	}
	// Guard against an unknown candidate to avoid a nil-pointer Enqueue.
	if queue := lb.TotalLatencyRequests[candidate]; queue != nil {
		queue.Enqueue(duration)
	}
}

View File

@@ -40,13 +40,19 @@ type Metrics struct {
KvCacheMaxTokenCapacity int
}
type UserSelectedMetric struct {
MetricName string
MetricValue float64
}
type PodMetrics struct {
Pod
Metrics
UserSelectedMetric
}
func (pm *PodMetrics) String() string {
return fmt.Sprintf("Pod: %+v; Metrics: %+v", pm.Pod, pm.Metrics)
return fmt.Sprintf("Pod: %+v; Metrics: %+v, UserSelectedMetric: %+v", pm.Pod, pm.Metrics, pm.UserSelectedMetric)
}
func (pm *PodMetrics) Clone() *PodMetrics {
@@ -63,6 +69,10 @@ func (pm *PodMetrics) Clone() *PodMetrics {
KVCacheUsagePercent: pm.KVCacheUsagePercent,
KvCacheMaxTokenCapacity: pm.KvCacheMaxTokenCapacity,
},
UserSelectedMetric: UserSelectedMetric{
MetricName: pm.MetricName,
MetricValue: pm.MetricValue,
},
}
return clone
}

View File

@@ -23,7 +23,7 @@ import (
"strings"
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/backend"
dto "github.com/prometheus/client_model/go"
"go.uber.org/multierr"
@@ -53,6 +53,16 @@ func PromToPodMetrics(
) (*backend.PodMetrics, error) {
var errs error
updated := existing.Clone()
// User selected metric
if updated.MetricName != "" {
metricValue, err := getLatestMetric(metricFamilies, updated.MetricName)
errs = multierr.Append(errs, err)
if err == nil {
updated.MetricValue = metricValue.GetGauge().GetValue()
}
return updated, errs
}
// Default metric
runningQueueSize, err := getLatestMetric(metricFamilies, RunningQueueSizeMetricName)
errs = multierr.Append(errs, err)
if err == nil {

View File

@@ -0,0 +1,120 @@
package endpoint_metrics
import (
"math/rand"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/scheduling"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
const (
	// FixedQueueSize is the sliding-window length used to estimate each
	// endpoint's recent share of traffic.
	FixedQueueSize = 100
)

// MetricsEndpointLoadBalancer selects a concrete upstream endpoint based on
// per-host metrics reported by the backends.
type MetricsEndpointLoadBalancer struct {
	metricPolicy     string                    // scheduling policy: default / least / most
	targetMetric     string                    // user-selected metric name, used by least/most policies
	endpointRequests *utils.FixedQueue[string] // recent endpoint selections, used to enforce maxRate
	maxRate          float64                   // max fraction of recent requests one endpoint may receive
}
// NewMetricsEndpointLoadBalancer parses the plugin configuration and returns
// a metrics-driven endpoint load balancer. Missing fields fall back to the
// default metric policy and a rate limit of 1.0 (no limit).
func NewMetricsEndpointLoadBalancer(json gjson.Result) (MetricsEndpointLoadBalancer, error) {
	lb := MetricsEndpointLoadBalancer{
		metricPolicy:     scheduling.MetricPolicyDefault,
		maxRate:          1.0,
		endpointRequests: utils.NewFixedQueue[string](FixedQueueSize),
	}
	if v := json.Get("metric_policy"); v.Exists() {
		lb.metricPolicy = v.String()
	}
	if v := json.Get("target_metric"); v.Exists() {
		lb.targetMetric = v.String()
	}
	if v := json.Get("rate_limit"); v.Exists() {
		lb.maxRate = v.Float()
	}
	return lb, nil
}
// Callbacks which are called in request path

// HandleHttpRequestHeaders pauses header iteration so the upstream host
// override set later in HandleHttpRequestBody can take effect.
func (lb MetricsEndpointLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
	// If return types.ActionContinue, SetUpstreamOverrideHost will not take effect
	return types.HeaderStopIteration
}
// HandleHttpRequestBody selects an upstream endpoint for an LLM request.
//
// It reads the model name from the JSON body, collects the metrics of all
// healthy upstream hosts, asks the scheduler for the best pod, then applies a
// soft per-endpoint limit: if the chosen endpoint already served more than
// maxRate of the recent requests, a random other healthy host is used
// instead. On any failure it returns ActionContinue without overriding the
// host, falling back to the default load-balancing policy.
func (lb MetricsEndpointLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
	requestModel := gjson.GetBytes(body, "model")
	if !requestModel.Exists() {
		return types.ActionContinue
	}
	llmReq := &scheduling.LLMRequest{
		Model:    requestModel.String(),
		Critical: true,
	}
	hostInfos, err := proxywasm.GetUpstreamHosts()
	if err != nil {
		return types.ActionContinue
	}
	// Only healthy hosts are candidates.
	hostMetrics := make(map[string]string)
	for _, hostInfo := range hostInfos {
		if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
			hostMetrics[hostInfo[0]] = gjson.Get(hostInfo[1], "metrics").String()
		}
	}
	scheduler, err := scheduling.GetScheduler(hostMetrics, lb.metricPolicy, lb.targetMetric)
	if err != nil {
		log.Debugf("initial scheduler failed: %v", err)
		return types.ActionContinue
	}
	targetPod, err := scheduler.Schedule(llmReq)
	// Fix: check the scheduling error before using targetPod; the previous
	// code logged targetPod.Address first, which is meaningless on failure.
	if err != nil {
		log.Debugf("pod select failed: %v", err)
		return types.ActionContinue
	}
	log.Debugf("targetPod: %+v", targetPod.Address)
	finalAddress := targetPod.Address
	// If the chosen host exceeds the request-share limit, pick a random one of
	// the remaining healthy hosts instead.
	otherHosts := []string{}
	currentRate := 0.0
	for k := range hostMetrics {
		if k != finalAddress {
			otherHosts = append(otherHosts, k)
		}
	}
	if lb.endpointRequests.Size() != 0 {
		count := 0.0
		lb.endpointRequests.ForEach(func(i int, item string) {
			if item == finalAddress {
				count += 1
			}
		})
		currentRate = count / float64(lb.endpointRequests.Size())
	}
	if currentRate > lb.maxRate && len(otherHosts) > 0 {
		finalAddress = otherHosts[rand.Intn(len(otherHosts))]
	}
	lb.endpointRequests.Enqueue(finalAddress)
	log.Debugf("pod %s is selected", finalAddress)
	// Best effort: if the override fails, the default LB policy applies.
	if err := proxywasm.SetUpstreamOverrideHost([]byte(finalAddress)); err != nil {
		log.Debugf("override upstream host failed: %v", err)
	}
	return types.ActionContinue
}
// HandleHttpResponseHeaders disables response-body buffering; this balancer
// makes no response-side decisions.
func (lb MetricsEndpointLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
	ctx.DontReadResponseBody()
	return types.ActionContinue
}

// HandleHttpStreamingResponseBody passes streamed chunks through unchanged.
func (lb MetricsEndpointLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
	return data
}

// HandleHttpResponseBody is a no-op.
func (lb MetricsEndpointLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
	return types.ActionContinue
}

// HandleHttpStreamDone is a no-op.
func (lb MetricsEndpointLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpContext) {}

View File

@@ -20,7 +20,7 @@ import (
"errors"
"math"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/backend"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
)

View File

@@ -20,15 +20,22 @@ package scheduling
import (
"errors"
"fmt"
"math"
"math/rand"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend/vllm"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/backend"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/backend/vllm"
"github.com/prometheus/common/expfmt"
)
const (
MetricPolicyDefault = "default"
MetricPolicyLeast = "least"
MetricPolicyMost = "most"
)
const (
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable.
kvCacheThreshold = 0.8
@@ -107,11 +114,11 @@ var (
}
)
func NewScheduler(pm []*backend.PodMetrics) *Scheduler {
func NewScheduler(pm []*backend.PodMetrics, filter Filter) *Scheduler {
return &Scheduler{
podMetrics: pm,
filter: defaultFilter,
filter: filter,
}
}
@@ -130,7 +137,7 @@ func (s *Scheduler) Schedule(req *LLMRequest) (targetPod backend.Pod, err error)
return pods[i].Pod, nil
}
func GetScheduler(hostMetrics map[string]string) (*Scheduler, error) {
func GetScheduler(hostMetrics map[string]string, metricPolicy string, targetMetric string) (*Scheduler, error) {
if len(hostMetrics) == 0 {
return nil, errors.New("backend is not support llm scheduling")
}
@@ -147,6 +154,9 @@ func GetScheduler(hostMetrics map[string]string) (*Scheduler, error) {
Address: addr,
},
Metrics: backend.Metrics{},
UserSelectedMetric: backend.UserSelectedMetric{
MetricName: targetMetric,
},
}
pm, err = vllm.PromToPodMetrics(metricFamilies, pm)
if err != nil {
@@ -154,5 +164,60 @@ func GetScheduler(hostMetrics map[string]string) (*Scheduler, error) {
}
pms = append(pms, pm)
}
return NewScheduler(pms), nil
if metricPolicy == MetricPolicyLeast {
filterFunc := func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
min := math.MaxFloat64
max := 0.0
filtered := []*backend.PodMetrics{}
for _, pod := range pods {
if pod.MetricValue <= min {
min = pod.MetricValue
}
if pod.MetricValue >= max {
max = pod.MetricValue
}
}
for _, pod := range pods {
if pod.MetricValue >= min && pod.MetricValue <= min+(max-min)/float64(len(pods)) {
filtered = append(filtered, pod)
}
}
return filtered, nil
}
filter := filter{
name: "least user selected metric",
filter: filterFunc,
}
return NewScheduler(pms, &filter), nil
} else if metricPolicy == MetricPolicyMost {
filterFunc := func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
min := math.MaxFloat64
max := 0.0
filtered := []*backend.PodMetrics{}
for _, pod := range pods {
if pod.MetricValue <= min {
min = pod.MetricValue
}
if pod.MetricValue >= max {
max = pod.MetricValue
}
}
for _, pod := range pods {
if pod.MetricValue <= max && pod.MetricValue >= max-(max-min)/float64(len(pods)) {
filtered = append(filtered, pod)
}
}
return filtered, nil
}
filter := filter{
name: "most user selected metric",
filter: filterFunc,
}
return NewScheduler(pms, &filter), nil
}
return NewScheduler(pms, defaultFilter), nil
}

View File

@@ -16,40 +16,91 @@ import (
)
const (
RedisKeyFormat = "higress:global_least_request_table:%s:%s"
RedisLua = `local seed = KEYS[1]
RedisKeyFormat = "higress:global_least_request_table:%s:%s"
RedisLastCleanKeyFormat = "higress:global_least_request_table:last_clean_time:%s:%s"
RedisLua = `local seed = tonumber(KEYS[1])
local hset_key = KEYS[2]
local current_target = KEYS[3]
local current_count = 0
local last_clean_key = KEYS[3]
local clean_interval = tonumber(KEYS[4])
local current_target = KEYS[5]
local healthy_count = tonumber(KEYS[6])
local enable_detail_log = KEYS[7]
math.randomseed(seed)
local function randomBool()
return math.random() >= 0.5
end
-- 1. Selection
local current_count = 0
local same_count_hits = 0
if redis.call('HEXISTS', hset_key, current_target) == 1 then
current_count = redis.call('HGET', hset_key, current_target)
for i = 4, #KEYS do
if redis.call('HEXISTS', hset_key, KEYS[i]) == 1 then
local count = redis.call('HGET', hset_key, KEYS[i])
if tonumber(count) < tonumber(current_count) then
current_target = KEYS[i]
current_count = count
elseif count == current_count and randomBool() then
current_target = KEYS[i]
end
end
end
for i = 8, 8 + healthy_count - 1 do
local host = KEYS[i]
local count = 0
local val = redis.call('HGET', hset_key, host)
if val then
count = tonumber(val) or 0
end
if same_count_hits == 0 or count < current_count then
current_target = host
current_count = count
same_count_hits = 1
elseif count == current_count then
same_count_hits = same_count_hits + 1
if math.random(same_count_hits) == 1 then
current_target = host
end
end
end
redis.call("HINCRBY", hset_key, current_target, 1)
local new_count = redis.call("HGET", hset_key, current_target)
return current_target`
-- Collect host counts for logging
local host_details = {}
if enable_detail_log == "1" then
local fields = {}
for i = 8, #KEYS do
table.insert(fields, KEYS[i])
end
if #fields > 0 then
local values = redis.call('HMGET', hset_key, (table.unpack or unpack)(fields))
for i, val in ipairs(values) do
table.insert(host_details, fields[i])
table.insert(host_details, tostring(val or 0))
end
end
end
-- 2. Cleanup
local current_time = math.floor(seed / 1000000)
local last_clean_time = tonumber(redis.call('GET', last_clean_key) or 0)
if current_time - last_clean_time >= clean_interval then
local all_keys = redis.call('HKEYS', hset_key)
if #all_keys > 0 then
-- Create a lookup table for current hosts (from index 8 onwards)
local current_hosts = {}
for i = 8, #KEYS do
current_hosts[KEYS[i]] = true
end
-- Remove keys not in current hosts
for _, host in ipairs(all_keys) do
if not current_hosts[host] then
redis.call('HDEL', hset_key, host)
end
end
end
redis.call('SET', last_clean_key, current_time)
end
return {current_target, new_count, host_details}`
)
type GlobalLeastRequestLoadBalancer struct {
redisClient wrapper.RedisClient
redisClient wrapper.RedisClient
maxRequestCount int64
cleanInterval int64 // seconds
enableDetailLog bool
}
func NewGlobalLeastRequestLoadBalancer(json gjson.Result) (GlobalLeastRequestLoadBalancer, error) {
@@ -72,6 +123,18 @@ func NewGlobalLeastRequestLoadBalancer(json gjson.Result) (GlobalLeastRequestLoa
}
// database default is 0
database := json.Get("database").Int()
lb.maxRequestCount = json.Get("maxRequestCount").Int()
lb.cleanInterval = json.Get("cleanInterval").Int()
if lb.cleanInterval == 0 {
lb.cleanInterval = 60 * 60 // default 60 minutes
} else {
lb.cleanInterval = lb.cleanInterval * 60 // convert minutes to seconds
}
lb.enableDetailLog = true
if val := json.Get("enableDetailLog"); val.Exists() {
lb.enableDetailLog = val.Bool()
}
log.Infof("redis client init, serviceFQDN: %s, servicePort: %d, timeout: %d, database: %d, maxRequestCount: %d, cleanInterval: %d minutes, enableDetailLog: %v", serviceFQDN, servicePort, timeout, database, lb.maxRequestCount, lb.cleanInterval/60, lb.enableDetailLog)
return lb, lb.redisClient.Init(username, password, int64(timeout), wrapper.WithDataBase(int(database)))
}
@@ -100,9 +163,11 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpC
ctx.SetContext("error", true)
return types.ActionContinue
}
allHostMap := make(map[string]struct{})
// Only healthy host can be selected
healthyHostArray := []string{}
for _, hostInfo := range hostInfos {
allHostMap[hostInfo[0]] = struct{}{}
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
healthyHostArray = append(healthyHostArray, hostInfo[0])
}
@@ -113,10 +178,37 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpC
}
randomIndex := rand.Intn(len(healthyHostArray))
hostSelected := healthyHostArray[randomIndex]
keys := []interface{}{time.Now().UnixMicro(), fmt.Sprintf(RedisKeyFormat, routeName, clusterName), hostSelected}
// KEYS structure: [seed, hset_key, last_clean_key, clean_interval, host_selected, healthy_count, ...healthy_hosts, enableDetailLog, ...unhealthy_hosts]
keys := []interface{}{
time.Now().UnixMicro(),
fmt.Sprintf(RedisKeyFormat, routeName, clusterName),
fmt.Sprintf(RedisLastCleanKeyFormat, routeName, clusterName),
lb.cleanInterval,
hostSelected,
len(healthyHostArray),
"0",
}
if lb.enableDetailLog {
keys[6] = "1"
}
for _, v := range healthyHostArray {
keys = append(keys, v)
}
// Append unhealthy hosts (those in allHostMap but not in healthyHostArray)
for host := range allHostMap {
isHealthy := false
for _, hh := range healthyHostArray {
if host == hh {
isHealthy = true
break
}
}
if !isHealthy {
keys = append(keys, host)
}
}
err = lb.redisClient.Eval(RedisLua, len(keys), keys, []interface{}{}, func(response resp.Value) {
if err := response.Error(); err != nil {
log.Errorf("HGetAll failed: %+v", err)
@@ -124,17 +216,54 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpC
proxywasm.ResumeHttpRequest()
return
}
hostSelected = response.String()
valArray := response.Array()
if len(valArray) < 2 {
log.Errorf("redis eval lua result format error, expect at least [host, count], got: %+v", valArray)
ctx.SetContext("error", true)
proxywasm.ResumeHttpRequest()
return
}
hostSelected = valArray[0].String()
currentCount := valArray[1].Integer()
// detail log
if lb.enableDetailLog && len(valArray) >= 3 {
detailLogStr := "host and count: "
details := valArray[2].Array()
for i := 0; i+1 < len(details); i += 2 {
h := details[i].String()
c := details[i+1].String()
detailLogStr += fmt.Sprintf("{%s: %s}, ", h, c)
}
log.Debugf("host_selected: %s + 1, %s", hostSelected, detailLogStr)
}
// check rate limit
if !lb.checkRateLimit(hostSelected, int64(currentCount), ctx, routeName, clusterName) {
ctx.SetContext("error", true)
log.Warnf("host_selected: %s, current_count: %d, exceed max request limit %d", hostSelected, currentCount, lb.maxRequestCount)
// return 429
proxywasm.SendHttpResponse(429, [][2]string{}, []byte("Exceeded maximum request limit from ai-load-balancer."), -1)
ctx.DontReadResponseBody()
return
}
if err := proxywasm.SetUpstreamOverrideHost([]byte(hostSelected)); err != nil {
ctx.SetContext("error", true)
log.Errorf("override upstream host failed, fallback to default lb policy, error informations: %+v", err)
proxywasm.ResumeHttpRequest()
return
}
log.Debugf("host_selected: %s", hostSelected)
// finally resume the request
ctx.SetContext("host_selected", hostSelected)
proxywasm.ResumeHttpRequest()
})
if err != nil {
ctx.SetContext("error", true)
log.Errorf("redis eval failed, fallback to default lb policy, error informations: %+v", err)
return types.ActionContinue
}
return types.ActionPause
@@ -161,7 +290,10 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpCo
if host_selected == "" {
log.Errorf("get host_selected failed")
} else {
lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil)
err := lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil)
if err != nil {
log.Errorf("host_selected: %s - 1, failed to update count from redis: %v", host_selected, err)
}
}
}
}

View File

@@ -0,0 +1,220 @@
-- Mocking Redis environment
-- In-memory stand-in for the Redis state the load-balancer script touches:
-- `hset` models the single hash of per-host counters, `kv` models plain
-- string keys (used for the last-clean timestamp).
local redis_data = {
    hset = {},
    kv = {}
}
-- Minimal mock of the `redis.call` API, supporting only the commands the
-- script uses (HGET/HSET/HINCRBY/HKEYS/HDEL/GET/HMGET/SET).
-- NOTE(review): the mock ignores the hash key argument and always uses the
-- shared `redis_data.hset` table — fine here, since the tests only ever use
-- one hash.
local redis = {
    call = function(cmd, ...)
        local args = {...}
        if cmd == "HGET" then
            local key, field = args[1], args[2]
            return redis_data.hset[field]
        elseif cmd == "HSET" then
            local key, field, val = args[1], args[2], args[3]
            redis_data.hset[field] = val
        elseif cmd == "HINCRBY" then
            local key, field, increment = args[1], args[2], args[3]
            local val = tonumber(redis_data.hset[field] or 0)
            redis_data.hset[field] = tostring(val + increment)
            return redis_data.hset[field]
        elseif cmd == "HKEYS" then
            local keys = {}
            for k, _ in pairs(redis_data.hset) do
                table.insert(keys, k)
            end
            return keys
        elseif cmd == "HDEL" then
            local key, field = args[1], args[2]
            redis_data.hset[field] = nil
        elseif cmd == "GET" then
            return redis_data.kv[args[1]]
        elseif cmd == "HMGET" then
            local key = args[1]
            local res = {}
            for i = 2, #args do
                table.insert(res, redis_data.hset[args[i]])
            end
            return res
        elseif cmd == "SET" then
            redis_data.kv[args[1]] = args[2]
        end
    end
}
-- The actual logic from lb_policy.go
-- NOTE(review): this must stay token-for-token in sync with the RedisLua
-- constant in the Go plugin; it is duplicated here so the script can be
-- unit-tested against the `redis` mock above.
local function run_lb_logic(KEYS)
    -- KEYS layout: [seed, hset_key, last_clean_key, clean_interval,
    --   current_target, healthy_count, enable_detail_log,
    --   healthy hosts at indices 8 .. 8+healthy_count-1,
    --   any remaining (unhealthy) hosts after that]
    local seed = tonumber(KEYS[1])
    local hset_key = KEYS[2]
    local last_clean_key = KEYS[3]
    local clean_interval = tonumber(KEYS[4])
    local current_target = KEYS[5]
    local healthy_count = tonumber(KEYS[6])
    local enable_detail_log = KEYS[7]
    math.randomseed(seed)
    -- 1. Selection
    -- Pick the healthy host with the lowest counter; ties are broken with a
    -- reservoir-sampling step (each of k tied hosts kept with probability 1/k)
    -- so equal hosts are chosen uniformly.
    local current_count = 0
    local same_count_hits = 0
    for i = 8, 8 + healthy_count - 1 do
        local host = KEYS[i]
        local count = 0
        local val = redis.call('HGET', hset_key, host)
        if val then
            count = tonumber(val) or 0
        end
        if same_count_hits == 0 or count < current_count then
            current_target = host
            current_count = count
            same_count_hits = 1
        elseif count == current_count then
            same_count_hits = same_count_hits + 1
            if math.random(same_count_hits) == 1 then
                current_target = host
            end
        end
    end
    redis.call("HINCRBY", hset_key, current_target, 1)
    local new_count = redis.call("HGET", hset_key, current_target)
    -- Collect host counts for logging
    local host_details = {}
    if enable_detail_log == "1" then
        local fields = {}
        for i = 8, #KEYS do
            table.insert(fields, KEYS[i])
        end
        if #fields > 0 then
            -- table.unpack is Lua 5.2+; unpack covers Lua 5.1 runtimes.
            local values = redis.call('HMGET', hset_key, (table.unpack or unpack)(fields))
            for i, val in ipairs(values) do
                table.insert(host_details, fields[i])
                table.insert(host_details, tostring(val or 0))
            end
        end
    end
    -- 2. Cleanup
    -- The seed is a microsecond timestamp; divide to get seconds.
    local current_time = math.floor(seed / 1000000)
    local last_clean_time = tonumber(redis.call('GET', last_clean_key) or 0)
    if current_time - last_clean_time >= clean_interval then
        local all_keys = redis.call('HKEYS', hset_key)
        if #all_keys > 0 then
            -- Create a lookup table for current hosts (from index 8 onwards)
            local current_hosts = {}
            for i = 8, #KEYS do
                current_hosts[KEYS[i]] = true
            end
            -- Remove keys not in current hosts
            for _, host in ipairs(all_keys) do
                if not current_hosts[host] then
                    redis.call('HDEL', hset_key, host)
                end
            end
        end
        redis.call('SET', last_clean_key, current_time)
    end
    return {current_target, new_count, host_details}
end
-- --- Test 1: Load Balancing Distribution ---
-- Runs the selection logic many times from an all-zero counter table and
-- prints the per-host share; a correct implementation should converge to an
-- even (~20%) split across the five hosts.
print("--- Test 1: Load Balancing Distribution ---")
local hosts = {"host1", "host2", "host3", "host4", "host5"}
local iterations = 100000
local results = {}
for _, h in ipairs(hosts) do results[h] = 0 end
-- Reset redis
redis_data.hset = {}
for _, h in ipairs(hosts) do redis_data.hset[h] = "0" end
print(string.format("Running %d iterations with %d hosts (all counts started at 0)...", iterations, #hosts))
for i = 1, iterations do
    local initial_host = hosts[math.random(#hosts)]
    -- KEYS structure: [seed, hset_key, last_clean_key, clean_interval, host_selected, healthy_count, enable_detail_log, ...healthy_hosts]
    local keys = {i * 1000000, "table_key", "clean_key", 3600, initial_host, #hosts, "1"}
    for _, h in ipairs(hosts) do table.insert(keys, h) end
    local res = run_lb_logic(keys)
    local selected = res[1]
    results[selected] = results[selected] + 1
end
for _, h in ipairs(hosts) do
    local percentage = (results[h] / iterations) * 100
    print(string.format("%s: %6d (%.2f%%)", h, results[h], percentage))
end
-- --- Test 2: IP Cleanup Logic ---
-- Verifies that hosts no longer present in KEYS are deleted from the hash
-- once clean_interval seconds have elapsed since the recorded clean time.
print("\n--- Test 2: IP Cleanup Logic ---")
local function test_cleanup()
    redis_data.hset = {
        ["host1"] = "10",
        ["host2"] = "5",
        ["old_ip_1"] = "1",
        ["old_ip_2"] = "1",
    }
    redis_data.kv["clean_key"] = "1000" -- Last cleaned at 1000s
    local current_hosts = {"host1", "host2"}
    -- Seed is in microseconds: 1500s total, 500s past the last clean.
    local current_time_ms = 1000 * 1000000 + 500 * 1000000 -- 1500s (interval is 300s, let's say)
    local clean_interval = 300
    print("Initial Redis IPs:", table.concat((function() local res={} for k,_ in pairs(redis_data.hset) do table.insert(res, k) end return res end)(), ", "))
    -- Run logic (seed is microtime)
    local keys = {current_time_ms, "table_key", "clean_key", clean_interval, "host1", #current_hosts, "1"}
    for _, h in ipairs(current_hosts) do table.insert(keys, h) end
    run_lb_logic(keys)
    print("After Cleanup Redis IPs:", table.concat((function() local res={} for k,_ in pairs(redis_data.hset) do table.insert(res, k) end table.sort(res) return res end)(), ", "))
    local exists_old1 = redis_data.hset["old_ip_1"] ~= nil
    local exists_old2 = redis_data.hset["old_ip_2"] ~= nil
    if not exists_old1 and not exists_old2 then
        print("Success: Outdated IPs removed.")
    else
        print("Failure: Outdated IPs still exist.")
    end
    print("New last_clean_time:", redis_data.kv["clean_key"])
end
test_cleanup()
-- --- Test 3: No Cleanup if Interval Not Reached ---
-- Verifies the stale-host sweep is skipped while fewer than clean_interval
-- seconds have passed since the last clean.
print("\n--- Test 3: No Cleanup if Interval Not Reached ---")
local function test_no_cleanup()
    redis_data.hset = {
        ["host1"] = "10",
        ["old_ip_1"] = "1",
    }
    redis_data.kv["clean_key"] = "1000"
    local current_hosts = {"host1"}
    -- Seed is in microseconds: 1100s total, only 100s past the last clean.
    local current_time_ms = 1000 * 1000000 + 100 * 1000000 -- 1100s (interval 300s, not reached)
    local clean_interval = 300
    local keys = {current_time_ms, "table_key", "clean_key", clean_interval, "host1", #current_hosts, "0"}
    for _, h in ipairs(current_hosts) do table.insert(keys, h) end
    run_lb_logic(keys)
    if redis_data.hset["old_ip_1"] then
        print("Success: Cleanup not triggered as expected.")
    else
        print("Failure: Cleanup triggered unexpectedly.")
    end
end
test_no_cleanup()

View File

@@ -0,0 +1,24 @@
package global_least_request
import (
	"fmt"

	"github.com/higress-group/wasm-go/pkg/log"
	"github.com/higress-group/wasm-go/pkg/wrapper"
)
// checkRateLimit reports whether hostSelected may serve another request under
// the configured per-host maximum. The Lua script has already incremented the
// host's counter, so currentCount is the value *after* this request was
// counted; when the limit is exceeded the Redis counter is rolled back.
// ctx is currently unused but kept for interface symmetry with the other
// request-path helpers.
func (lb GlobalLeastRequestLoadBalancer) checkRateLimit(hostSelected string, currentCount int64, ctx wrapper.HttpContext, routeName string, clusterName string) bool {
	// No maximum configured: always allow.
	if lb.maxRequestCount <= 0 {
		return true
	}
	if currentCount > lb.maxRequestCount {
		// Roll back the increment for the rejected request. Fix: surface the
		// Redis error instead of dropping it silently, consistent with the
		// error handling in HandleHttpStreamDone.
		if err := lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), hostSelected, -1, nil); err != nil {
			log.Errorf("host_selected: %s - 1, failed to roll back count in redis: %v", hostSelected, err)
		}
		return false
	}
	return true
}

View File

@@ -1,81 +0,0 @@
package least_busy
import (
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/scheduling"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
// LeastBusyLoadBalancer schedules requests onto the least busy upstream pod
// based on backend-reported metrics.
type LeastBusyLoadBalancer struct {
	criticalModels map[string]struct{} // models whose requests are scheduled as critical
}
// NewLeastBusyLoadBalancer builds a least-busy load balancer from the plugin
// configuration; `criticalModels` lists the models to treat as critical.
func NewLeastBusyLoadBalancer(json gjson.Result) (LeastBusyLoadBalancer, error) {
	models := make(map[string]struct{})
	for _, model := range json.Get("criticalModels").Array() {
		models[model.String()] = struct{}{}
	}
	return LeastBusyLoadBalancer{criticalModels: models}, nil
}
// Callbacks which are called in request path

// HandleHttpRequestHeaders pauses header iteration so the upstream host
// override set later in HandleHttpRequestBody can take effect.
func (lb LeastBusyLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
	// If return types.ActionContinue, SetUpstreamOverrideHost will not take effect
	return types.HeaderStopIteration
}
// HandleHttpRequestBody schedules the request onto the least busy healthy
// upstream host, based on the model named in the JSON request body. When the
// scheduler reports no capacity, a 429 is returned to the client; on other
// failures the default load-balancing policy applies.
func (lb LeastBusyLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
	requestModel := gjson.GetBytes(body, "model")
	if !requestModel.Exists() {
		return types.ActionContinue
	}
	_, isCritical := lb.criticalModels[requestModel.String()]
	llmReq := &scheduling.LLMRequest{
		Model:    requestModel.String(),
		Critical: isCritical,
	}
	hostInfos, err := proxywasm.GetUpstreamHosts()
	if err != nil {
		return types.ActionContinue
	}
	// Only healthy hosts are candidates.
	hostMetrics := make(map[string]string)
	for _, hostInfo := range hostInfos {
		if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
			hostMetrics[hostInfo[0]] = gjson.Get(hostInfo[1], "metrics").String()
		}
	}
	scheduler, err := scheduling.GetScheduler(hostMetrics)
	if err != nil {
		log.Debugf("initial scheduler failed: %v", err)
		return types.ActionContinue
	}
	targetPod, err := scheduler.Schedule(llmReq)
	// Fix: check the scheduling error before using targetPod; the previous
	// code logged targetPod.Address first, which is meaningless on failure.
	if err != nil {
		// No capacity available: reject instead of overloading a backend.
		log.Debugf("pod select failed: %v", err)
		proxywasm.SendHttpResponseWithDetail(429, "limited resources", nil, []byte("limited resources"), 0)
		return types.ActionContinue
	}
	log.Debugf("targetPod: %+v", targetPod.Address)
	proxywasm.SetUpstreamOverrideHost([]byte(targetPod.Address))
	return types.ActionContinue
}
func (lb LeastBusyLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
ctx.DontReadResponseBody()
return types.ActionContinue
}
// HandleHttpStreamingResponseBody passes streaming response chunks through
// unmodified.
func (lb LeastBusyLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
	return data
}
// HandleHttpResponseBody is a no-op; the response body is not inspected.
func (lb LeastBusyLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
	return types.ActionContinue
}
// HandleHttpStreamDone is a no-op; no per-stream state needs cleanup.
func (lb LeastBusyLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpContext) {}

View File

@@ -7,9 +7,10 @@ import (
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
global_least_request "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/global_least_request"
least_busy "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy"
prefix_cache "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/prefix_cache"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/cluster_metrics"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/global_least_request"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/prefix_cache"
)
func main() {}
@@ -37,34 +38,57 @@ type LoadBalancer interface {
}
type Config struct {
policy string
lb LoadBalancer
lbType string
lbPolicy string
lb LoadBalancer
}
const (
LeastBusyLoadBalancerPolicy = "least_busy"
GlobalLeastRequestLoadBalancerPolicy = "global_least_request"
PrefixCache = "prefix_cache"
ClusterLoadBalancerType = "cluster"
EndpointLoadBalancerType = "endpoint"
// Cluster load balancer policies
MetricsBasedCluster = "cluster_metrics"
// Endpoint load balancer policies
MetricsBasedEndpoint = "endpoint_metrics"
MetricsBasedEndpointDeprecated = "metrics_based" // Compatible with old configurations, equal to `endpoint_metrics`
GlobalLeastRequestEndpoint = "global_least_request"
PrefixCacheEndpoint = "prefix_cache"
)
func parseConfig(json gjson.Result, config *Config) error {
config.policy = json.Get("lb_policy").String()
config.lbType = json.Get("lb_type").String()
// Compatible with old configurations
if config.lbType == "" {
config.lbType = EndpointLoadBalancerType
}
config.lbPolicy = json.Get("lb_policy").String()
var err error
switch config.policy {
case LeastBusyLoadBalancerPolicy:
config.lb, err = least_busy.NewLeastBusyLoadBalancer(json.Get("lb_config"))
case GlobalLeastRequestLoadBalancerPolicy:
config.lb, err = global_least_request.NewGlobalLeastRequestLoadBalancer(json.Get("lb_config"))
case PrefixCache:
config.lb, err = prefix_cache.NewPrefixCacheLoadBalancer(json.Get("lb_config"))
switch config.lbType {
case ClusterLoadBalancerType:
switch config.lbPolicy {
case MetricsBasedCluster:
config.lb, err = cluster_metrics.NewClusterEndpointLoadBalancer(json.Get("lb_config"))
default:
err = fmt.Errorf("lb_policy %s is not supported", config.lbPolicy)
}
case EndpointLoadBalancerType:
switch config.lbPolicy {
case MetricsBasedEndpoint, MetricsBasedEndpointDeprecated:
config.lb, err = endpoint_metrics.NewMetricsEndpointLoadBalancer(json.Get("lb_config"))
case GlobalLeastRequestEndpoint:
config.lb, err = global_least_request.NewGlobalLeastRequestLoadBalancer(json.Get("lb_config"))
case PrefixCacheEndpoint:
config.lb, err = prefix_cache.NewPrefixCacheLoadBalancer(json.Get("lb_config"))
default:
err = fmt.Errorf("lb_policy %s is not supported", config.lbPolicy)
}
default:
err = fmt.Errorf("lb_policy %s is not supported", config.policy)
err = fmt.Errorf("lb_type %s is not supported", config.lbType)
}
return err
}
// onHttpRequestHeaders dispatches header handling to the configured
// LoadBalancer. DisableReroute presumably prevents Envoy from re-routing
// after the host override is set — TODO confirm against wrapper docs.
func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config) types.Action {
	ctx.DisableReroute()
	return config.lb.HandleHttpRequestHeaders(ctx)
}

View File

@@ -0,0 +1,175 @@
package utils
import (
"errors"
)
// FixedQueue is a fixed-capacity ring-buffer queue.
// When the queue is full, Enqueue overwrites the oldest element.
//
// The zero value is not usable; create instances with NewFixedQueue.
type FixedQueue[T any] struct {
	data []T // backing ring buffer, len(data) == cap
	head int // index of the oldest element
	tail int // index where the next element will be written
	size int // number of stored elements, 0 <= size <= cap
	cap  int // fixed capacity of the queue
}

// NewFixedQueue creates a queue with the given fixed capacity.
// A non-positive capacity falls back to a default of 16.
func NewFixedQueue[T any](capacity int) *FixedQueue[T] {
	if capacity <= 0 {
		capacity = 16
	}
	return &FixedQueue[T]{
		data: make([]T, capacity),
		cap:  capacity,
	}
}

// Enqueue appends item to the queue.
// If the queue is already full, the oldest element is overwritten.
func (q *FixedQueue[T]) Enqueue(item T) {
	// When full, tail == head, so this write replaces the oldest element.
	q.data[q.tail] = item
	q.tail = (q.tail + 1) % q.cap
	if q.size < q.cap {
		q.size++
	} else {
		// Full: drop the oldest element by advancing head; size stays at cap.
		q.head = (q.head + 1) % q.cap
	}
}

// Dequeue removes and returns the oldest element.
// It returns an error if the queue is empty.
func (q *FixedQueue[T]) Dequeue() (T, error) {
	var zero T
	if q.size == 0 {
		return zero, errors.New("queue is empty")
	}
	item := q.data[q.head]
	q.data[q.head] = zero // clear the slot so the GC can reclaim references
	q.head = (q.head + 1) % q.cap
	q.size--
	return item, nil
}

// Peek returns the oldest element without removing it.
// It returns an error if the queue is empty.
func (q *FixedQueue[T]) Peek() (T, error) {
	var zero T
	if q.size == 0 {
		return zero, errors.New("queue is empty")
	}
	return q.data[q.head], nil
}

// Size returns the number of elements currently stored.
func (q *FixedQueue[T]) Size() int {
	return q.size
}

// Capacity returns the fixed capacity of the queue.
func (q *FixedQueue[T]) Capacity() int {
	return q.cap
}

// IsEmpty reports whether the queue holds no elements.
func (q *FixedQueue[T]) IsEmpty() bool {
	return q.size == 0
}

// IsFull reports whether the queue is at capacity.
func (q *FixedQueue[T]) IsFull() bool {
	return q.size == q.cap
}

// OverwriteCount returns the number of elements overwritten by Enqueue.
// NOTE: this implementation does not track overwrites and always returns 0;
// add a counter field if the information is ever needed.
func (q *FixedQueue[T]) OverwriteCount() int {
	return 0
}

// Clear removes all elements, zeroing the occupied slots so that any
// references they held can be garbage collected.
func (q *FixedQueue[T]) Clear() {
	var zero T
	for i := 0; i < q.size; i++ {
		q.data[(q.head+i)%q.cap] = zero
	}
	q.head = 0
	q.tail = 0
	q.size = 0
}

// ToSlice returns a copy of the queue contents in order, oldest first.
func (q *FixedQueue[T]) ToSlice() []T {
	result := make([]T, q.size)
	// Copy the contiguous run starting at head, then the wrapped-around
	// prefix (if any). This replaces the original's tangle of special cases
	// with the standard two-copy ring-buffer linearization.
	n := copy(result, q.data[q.head:])
	if n < q.size {
		copy(result[n:], q.data[:q.size-n])
	}
	return result
}

// Oldest returns the oldest element (the queue head).
func (q *FixedQueue[T]) Oldest() (T, error) {
	return q.Peek()
}

// Newest returns the most recently enqueued element.
// It returns an error if the queue is empty.
func (q *FixedQueue[T]) Newest() (T, error) {
	var zero T
	if q.size == 0 {
		return zero, errors.New("queue is empty")
	}
	// tail points at the next write position, so the newest element is one
	// slot behind it (mod cap).
	return q.data[(q.tail-1+q.cap)%q.cap], nil
}

// ForEach calls fn for each element in queue order (oldest first), passing
// the element's position index and value.
func (q *FixedQueue[T]) ForEach(fn func(index int, item T)) {
	for i := 0; i < q.size; i++ {
		fn(i, q.data[(q.head+i)%q.cap])
	}
}

View File

@@ -1,4 +1,7 @@
.DEFAULT:
build:
tinygo build -o ai-proxy.wasm -scheduler=none -target=wasi -gc=custom -tags='custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100' ./main.go
mv ai-proxy.wasm ../../../../docker-compose-test/
mv ai-proxy.wasm ../../../../docker-compose-test/
build-go:
GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o main.wasm main.go

View File

@@ -231,6 +231,18 @@ Ollama 所对应的 `type` 为 `ollama`。它特有的配置字段如下:
| `ollamaServerHost` | string | 必填 | - | Ollama 服务器的主机地址 |
| `ollamaServerPort` | number | 必填 | - | Ollama 服务器的端口号,默认为 11434 |
#### 通用代理Generic
当只需要借助 AI Proxy 的鉴权、basePath 处理或首包超时能力,且不希望插件改写路径时,可将 `provider.type` 设置为 `generic`。该 Provider 不绑定任何模型厂商,也不会做能力映射。
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
| ------------- | -------- | -------- | ------ | -------------------------------------------------------------------- |
| `genericHost` | string | 非必填 | - | 指定要转发到的目标 Host未配置时沿用客户端请求的 Host。 |
- 配置了 `apiTokens` 时,会自动写入 `Authorization: Bearer <token>` 请求头,复用全局的 Token 轮询能力。
- 当配置了 `firstByteTimeout` 时,会自动注入 `x-envoy-upstream-rq-first-byte-timeout-ms`
- `basePath``basePathHandling` 同样适用,可在通用转发中快捷地移除或添加统一前缀。
#### 混元
混元所对应的 `type``hunyuan`。它特有的配置字段如下:
@@ -297,7 +309,9 @@ Dify 所对应的 `type` 为 `dify`。它特有的配置字段如下:
#### Google Vertex AI
Google Vertex AI 所对应的 type 为 vertex。它特有的配置字段如下
Google Vertex AI 所对应的 type 为 vertex。支持两种认证模式
**标准模式**(使用 Service Account
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|-----------------------------|---------------|--------|--------|-------------------------------------------------------------------------------|
@@ -308,25 +322,56 @@ Google Vertex AI 所对应的 type 为 vertex。它特有的配置字段如下
| `geminiSafetySetting` | map of string | 非必填 | - | Gemini AI 内容过滤和安全级别设定。参考[Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) |
| `vertexTokenRefreshAhead` | number | 非必填 | - | Vertex access token刷新提前时间(单位秒) |
**Express Mode**(使用 API Key简化配置
Express Mode 是 Vertex AI 推出的简化访问模式,只需 API Key 即可快速开始使用,无需配置 Service Account。详见 [Vertex AI Express Mode 文档](https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview)。
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|-----------------------------|---------------|--------|--------|-------------------------------------------------------------------------------|
| `apiTokens` | array of string | 必填 | - | Express Mode 使用的 API Key从 Google Cloud Console 的 API & Services > Credentials 获取 |
| `geminiSafetySetting` | map of string | 非必填 | - | Gemini AI 内容过滤和安全级别设定。参考[Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) |
**OpenAI 兼容模式**(使用 Vertex AI Chat Completions API
Vertex AI 提供了 OpenAI 兼容的 Chat Completions API 端点,可以直接使用 OpenAI 格式的请求和响应,无需进行协议转换。详见 [Vertex AI OpenAI 兼容性文档](https://cloud.google.com/vertex-ai/generative-ai/docs/migrate/openai/overview)。
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|-----------------------------|---------------|--------|--------|-------------------------------------------------------------------------------|
| `vertexOpenAICompatible` | boolean | 非必填 | false | 启用 OpenAI 兼容模式。启用后将使用 Vertex AI 的 OpenAI-compatible Chat Completions API |
| `vertexAuthKey` | string | 必填 | - | 用于认证的 Google Service Account JSON Key |
| `vertexRegion` | string | 必填 | - | Google Cloud 区域(如 us-central1, europe-west4 等) |
| `vertexProjectId` | string | 必填 | - | Google Cloud 项目 ID |
| `vertexAuthServiceName` | string | 必填 | - | 用于 OAuth2 认证的服务名称 |
**注意**OpenAI 兼容模式与 Express Mode 互斥,不能同时配置 `apiTokens``vertexOpenAICompatible`
#### AWS Bedrock
AWS Bedrock 所对应的 type 为 bedrock。它特有的配置字段如下
AWS Bedrock 所对应的 type 为 bedrock。它支持两种认证方式
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|---------------------------|--------|------|-----|------------------------------|
| `modelVersion` | string | 非必填 | - | 用于指定 Triton Server 中 model version |
| `tritonDomain` | string | 非必填 | - | Triton Server 部署的指定请求 Domain |
1. **AWS Signature V4 认证**:使用 `awsAccessKey``awsSecretKey` 进行 AWS 标准签名认证
2. **Bearer Token 认证**:使用 `apiTokens` 配置 AWS Bearer Token适用于 IAM Identity Center 等场景)
**注意**:两种认证方式二选一,如果同时配置了 `apiTokens`,将优先使用 Bearer Token 认证方式。
它特有的配置字段如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|---------------------------|---------------|-------------------|-------|---------------------------------------------------|
| `apiTokens` | array of string | 与 ak/sk 二选一 | - | AWS Bearer Token用于 Bearer Token 认证方式 |
| `awsAccessKey` | string | 与 apiTokens 二选一 | - | AWS Access Key用于 AWS Signature V4 认证 |
| `awsSecretKey` | string | 与 apiTokens 二选一 | - | AWS Secret Access Key用于 AWS Signature V4 认证 |
| `awsRegion` | string | 必填 | - | AWS 区域例如us-east-1 |
| `bedrockAdditionalFields` | map | 非必填 | - | Bedrock 额外模型请求参数 |
#### NVIDIA Triton Inference Server
NVIDIA Triton Inference Server 所对应的 type 为 triton。它特有的配置字段如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|---------------------------|--------|------|-----|------------------------------|
| `awsAccessKey` | string | 必填 | - | AWS Access Key用于身份认证 |
| `awsSecretKey` | string | 必填 | - | AWS Secret Access Key用于身份认证 |
| `awsRegion` | string | 必填 | - | AWS 区域例如us-east-1 |
| `bedrockAdditionalFields` | map | 非必填 | - | Bedrock 额外模型请求参数 |
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|----------------------|--------|--------|-------|------------------------------------------|
| `tritonModelVersion` | string | 必填 | - | 用于指定 Triton Server 中 model version |
| `tritonDomain` | string | 必填 | - | Triton Server 部署的指定请求 Domain |
## 用法示例
@@ -1935,7 +1980,7 @@ provider:
}
```
### 使用 OpenAI 协议代理 Google Vertex 服务
### 使用 OpenAI 协议代理 Google Vertex 服务(标准模式)
**配置信息**
@@ -1997,8 +2042,134 @@ provider:
}
```
### 使用 OpenAI 协议代理 Google Vertex 服务Express Mode
Express Mode 是 Vertex AI 的简化访问模式,只需 API Key 即可快速开始使用。
**配置信息**
```yaml
provider:
type: vertex
apiTokens:
- "YOUR_API_KEY"
```
**请求示例**
```json
{
"model": "gemini-2.5-flash",
"messages": [
{
"role": "user",
"content": "你好,你是谁?"
}
],
"stream": false
}
```
**响应示例**
```json
{
"id": "chatcmpl-0000000000000",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "你好!我是 Gemini由 Google 开发的人工智能助手。有什么我可以帮您的吗?"
},
"finish_reason": "stop"
}
],
"created": 1729986750,
"model": "gemini-2.5-flash",
"object": "chat.completion",
"usage": {
"prompt_tokens": 10,
"completion_tokens": 25,
"total_tokens": 35
}
}
```
### 使用 OpenAI 协议代理 Google Vertex 服务OpenAI 兼容模式)
OpenAI 兼容模式使用 Vertex AI 的 OpenAI-compatible Chat Completions API请求和响应都使用 OpenAI 格式,无需进行协议转换。
**配置信息**
```yaml
provider:
type: vertex
vertexOpenAICompatible: true
vertexAuthKey: |
{
"type": "service_account",
"project_id": "your-project-id",
"private_key_id": "your-private-key-id",
"private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n",
"client_email": "your-service-account@your-project.iam.gserviceaccount.com",
"token_uri": "https://oauth2.googleapis.com/token"
}
vertexRegion: us-central1
vertexProjectId: your-project-id
vertexAuthServiceName: your-auth-service-name
modelMapping:
"gpt-4": "gemini-2.0-flash"
"*": "gemini-1.5-flash"
```
**请求示例**
```json
{
"model": "gpt-4",
"messages": [
{
"role": "user",
"content": "你好,你是谁?"
}
],
"stream": false
}
```
**响应示例**
```json
{
"id": "chatcmpl-abc123",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "你好!我是由 Google 开发的 Gemini 模型。我可以帮助回答问题、提供信息和进行对话。有什么我可以帮您的吗?"
},
"finish_reason": "stop"
}
],
"created": 1729986750,
"model": "gemini-2.0-flash",
"object": "chat.completion",
"usage": {
"prompt_tokens": 12,
"completion_tokens": 35,
"total_tokens": 47
}
}
```
### 使用 OpenAI 协议代理 AWS Bedrock 服务
AWS Bedrock 支持两种认证方式:
#### 方式一:使用 AWS Access Key/Secret Key 认证AWS Signature V4
**配置信息**
```yaml
@@ -2006,7 +2177,21 @@ provider:
type: bedrock
awsAccessKey: "YOUR_AWS_ACCESS_KEY_ID"
awsSecretKey: "YOUR_AWS_SECRET_ACCESS_KEY"
awsRegion: "YOUR_AWS_REGION"
awsRegion: "us-east-1"
bedrockAdditionalFields:
top_k: 200
```
#### 方式二:使用 Bearer Token 认证(适用于 IAM Identity Center 等场景)
**配置信息**
```yaml
provider:
type: bedrock
apiTokens:
- "YOUR_AWS_BEARER_TOKEN"
awsRegion: "us-east-1"
bedrockAdditionalFields:
top_k: 200
```
@@ -2015,7 +2200,7 @@ provider:
```json
{
"model": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-5-haiku-20241022-v1:0",
"model": "us.anthropic.claude-3-5-haiku-20241022-v1:0",
"messages": [
{
"role": "user",

View File

@@ -197,6 +197,18 @@ For Ollama, the corresponding `type` is `ollama`. Its unique configuration field
| `ollamaServerHost` | string | Required | - | The host address of the Ollama server. |
| `ollamaServerPort` | number | Required | - | The port number of the Ollama server, defaults to 11434. |
#### Generic
For a vendor-agnostic passthrough, set the provider `type` to `generic`. Requests are forwarded without path remapping, while still benefiting from the shared header/basePath utilities.
| Name | Data Type | Requirement | Default | Description |
|----------------|-----------|-------------|---------|----------------------------------------------------------------------------------------------------------|
| `genericHost` | string | Optional | - | Overrides the upstream `Host` header. Use it to route traffic to a specific backend domain for generic proxying. |
- When `apiTokens` are configured, the Generic provider injects `Authorization: Bearer <token>` automatically.
- `firstByteTimeout` applies to any request whose body sets `stream: true`, ensuring consistent streaming behavior even without capability definitions.
- `basePath` and `basePathHandling` remain available to strip or prepend prefixes before forwarding.
#### Hunyuan
For Hunyuan, the corresponding `type` is `hunyuan`. Its unique configuration fields are:
@@ -243,7 +255,9 @@ For DeepL, the corresponding `type` is `deepl`. Its unique configuration field i
| `targetLang` | string | Required | - | The target language required by the DeepL translation service |
#### Google Vertex AI
For Vertex, the corresponding `type` is `vertex`. Its unique configuration field is:
For Vertex, the corresponding `type` is `vertex`. It supports two authentication modes:
**Standard Mode** (using Service Account):
| Name | Data Type | Requirement | Default | Description |
|-----------------------------|---------------|---------------| ------ |-------------------------------------------------------------------------------------------------------------------------------------------------------------|
@@ -254,16 +268,47 @@ For Vertex, the corresponding `type` is `vertex`. Its unique configuration field
| `vertexGeminiSafetySetting` | map of string | Optional | - | Gemini model content safety filtering settings. |
| `vertexTokenRefreshAhead` | number | Optional | - | Vertex access token refresh ahead time in seconds |
**Express Mode** (using API Key, simplified configuration):
Express Mode is a simplified access mode introduced by Vertex AI. You can quickly get started with just an API Key, without configuring a Service Account. See [Vertex AI Express Mode documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview).
| Name | Data Type | Requirement | Default | Description |
|-----------------------------|------------------|---------------| ------ |-------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `apiTokens` | array of string | Required | - | API Key for Express Mode, obtained from Google Cloud Console under API & Services > Credentials |
| `vertexGeminiSafetySetting` | map of string | Optional | - | Gemini model content safety filtering settings. |
**OpenAI Compatible Mode** (using Vertex AI Chat Completions API):
Vertex AI provides an OpenAI-compatible Chat Completions API endpoint, allowing you to use OpenAI format requests and responses directly without protocol conversion. See [Vertex AI OpenAI Compatibility documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/migrate/openai/overview).
| Name | Data Type | Requirement | Default | Description |
|-----------------------------|------------------|---------------| ------ |-------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `vertexOpenAICompatible` | boolean | Optional | false | Enable OpenAI compatible mode. When enabled, uses Vertex AI's OpenAI-compatible Chat Completions API |
| `vertexAuthKey` | string | Required | - | Google Service Account JSON Key for authentication |
| `vertexRegion` | string | Required | - | Google Cloud region (e.g., us-central1, europe-west4) |
| `vertexProjectId` | string | Required | - | Google Cloud Project ID |
| `vertexAuthServiceName` | string | Required | - | Service name for OAuth2 authentication |
**Note**: OpenAI Compatible Mode and Express Mode are mutually exclusive. You cannot configure both `apiTokens` and `vertexOpenAICompatible` at the same time.
#### AWS Bedrock
For AWS Bedrock, the corresponding `type` is `bedrock`. Its unique configuration field is:
For AWS Bedrock, the corresponding `type` is `bedrock`. It supports two authentication methods:
| Name | Data Type | Requirement | Default | Description |
|---------------------------|-----------|-------------|---------|---------------------------------------------------------|
| `awsAccessKey` | string | Required | - | AWS Access Key used for authentication |
| `awsSecretKey` | string | Required | - | AWS Secret Access Key used for authentication |
| `awsRegion` | string | Required | - | AWS region, e.g., us-east-1 |
| `bedrockAdditionalFields` | map | Optional | - | Additional inference parameters that the model supports |
1. **AWS Signature V4 Authentication**: Uses `awsAccessKey` and `awsSecretKey` for standard AWS signature authentication
2. **Bearer Token Authentication**: Uses `apiTokens` to configure AWS Bearer Token (suitable for IAM Identity Center and similar scenarios)
**Note**: Choose one of the two authentication methods. If `apiTokens` is configured, Bearer Token authentication will be used preferentially.
Its unique configuration fields are:
| Name | Data Type | Requirement | Default | Description |
|---------------------------|-----------------|--------------------------|---------|-------------------------------------------------------------------|
| `apiTokens` | array of string | Either this or ak/sk | - | AWS Bearer Token for Bearer Token authentication |
| `awsAccessKey` | string | Either this or apiTokens | - | AWS Access Key for AWS Signature V4 authentication |
| `awsSecretKey` | string | Either this or apiTokens | - | AWS Secret Access Key for AWS Signature V4 authentication |
| `awsRegion` | string | Required | - | AWS region, e.g., us-east-1 |
| `bedrockAdditionalFields` | map | Optional | - | Additional inference parameters that the model supports |
## Usage Examples
@@ -1708,7 +1753,7 @@ provider:
}
```
### Utilizing OpenAI Protocol Proxy for Google Vertex Services
### Utilizing OpenAI Protocol Proxy for Google Vertex Services (Standard Mode)
**Configuration Information**
```yaml
provider:
@@ -1766,14 +1811,148 @@ provider:
}
```
### Utilizing OpenAI Protocol Proxy for Google Vertex Services (Express Mode)
Express Mode is a simplified access mode for Vertex AI. You only need an API Key to get started quickly.
**Configuration Information**
```yaml
provider:
type: vertex
apiTokens:
- "YOUR_API_KEY"
```
**Request Example**
```json
{
"model": "gemini-2.5-flash",
"messages": [
{
"role": "user",
"content": "Who are you?"
}
],
"stream": false
}
```
**Response Example**
```json
{
"id": "chatcmpl-0000000000000",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello! I am Gemini, an AI assistant developed by Google. How can I help you today?"
},
"finish_reason": "stop"
}
],
"created": 1729986750,
"model": "gemini-2.5-flash",
"object": "chat.completion",
"usage": {
"prompt_tokens": 10,
"completion_tokens": 25,
"total_tokens": 35
}
}
```
### Utilizing OpenAI Protocol Proxy for Google Vertex Services (OpenAI Compatible Mode)
OpenAI Compatible Mode uses Vertex AI's OpenAI-compatible Chat Completions API. Both requests and responses use OpenAI format, requiring no protocol conversion.
**Configuration Information**
```yaml
provider:
type: vertex
vertexOpenAICompatible: true
vertexAuthKey: |
{
"type": "service_account",
"project_id": "your-project-id",
"private_key_id": "your-private-key-id",
"private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n",
"client_email": "your-service-account@your-project.iam.gserviceaccount.com",
"token_uri": "https://oauth2.googleapis.com/token"
}
vertexRegion: us-central1
vertexProjectId: your-project-id
vertexAuthServiceName: your-auth-service-name
modelMapping:
"gpt-4": "gemini-2.0-flash"
"*": "gemini-1.5-flash"
```
**Request Example**
```json
{
"model": "gpt-4",
"messages": [
{
"role": "user",
"content": "Hello, who are you?"
}
],
"stream": false
}
```
**Response Example**
```json
{
"id": "chatcmpl-abc123",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello! I am Gemini, an AI model developed by Google. I can help answer questions, provide information, and engage in conversations. How can I assist you today?"
},
"finish_reason": "stop"
}
],
"created": 1729986750,
"model": "gemini-2.0-flash",
"object": "chat.completion",
"usage": {
"prompt_tokens": 12,
"completion_tokens": 35,
"total_tokens": 47
}
}
```
### Utilizing OpenAI Protocol Proxy for AWS Bedrock Services
AWS Bedrock supports two authentication methods:
#### Method 1: Using AWS Access Key/Secret Key Authentication (AWS Signature V4)
**Configuration Information**
```yaml
provider:
type: bedrock
awsAccessKey: "YOUR_AWS_ACCESS_KEY_ID"
awsSecretKey: "YOUR_AWS_SECRET_ACCESS_KEY"
awsRegion: "YOUR_AWS_REGION"
awsRegion: "us-east-1"
bedrockAdditionalFields:
top_k: 200
```
#### Method 2: Using Bearer Token Authentication (suitable for IAM Identity Center and similar scenarios)
**Configuration Information**
```yaml
provider:
type: bedrock
apiTokens:
- "YOUR_AWS_BEARER_TOKEN"
awsRegion: "us-east-1"
bedrockAdditionalFields:
top_k: 200
```
@@ -1781,7 +1960,7 @@ provider:
**Request Example**
```json
{
"model": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-5-haiku-20241022-v1:0",
"model": "us.anthropic.claude-3-5-haiku-20241022-v1:0",
"messages": [
{
"role": "user",

View File

@@ -7,8 +7,8 @@ go 1.24.1
toolchain go1.24.4
require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
github.com/higress-group/wasm-go v1.0.3-0.20251011083635-792cb1547bac
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
github.com/higress-group/wasm-go v1.0.9-0.20251226032831-95da539a1ec7
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
)

View File

@@ -2,10 +2,12 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.3-0.20251011083635-792cb1547bac h1:tdJzS56Xa6BSHAi9P2omvb98bpI8qFGg6jnCPtPmDgA=
github.com/higress-group/wasm-go v1.0.3-0.20251011083635-792cb1547bac/go.mod h1:B8C6+OlpnyYyZUBEdUXA7tYZYD+uwZTNjfkE5FywA+A=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c h1:DdVPyaMHSYBqO5jwB9Wl3PqsBGIf4u29BHMI0uIVB1Y=
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/higress-group/wasm-go v1.0.9-0.20251226032831-95da539a1ec7 h1:ddAkPFIIf6isVQNymS6+X6QO51/WV0Af4Afb9a2z9TE=
github.com/higress-group/wasm-go v1.0.9-0.20251226032831-95da539a1ec7/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=

View File

@@ -102,6 +102,7 @@ func init() {
wrapper.ProcessStreamingResponseBody(onStreamingResponseBody),
wrapper.ProcessResponseBody(onHttpResponseBody),
wrapper.WithRebuildAfterRequests[config.PluginConfig](1000),
wrapper.WithRebuildMaxMemBytes[config.PluginConfig](200*1024*1024),
)
}
@@ -145,7 +146,7 @@ func initContext(ctx wrapper.HttpContext) {
ctx.SetContext(ctxKey, value)
}
for _, originHeader := range headerToOriginalHeaderMapping {
proxywasm.RemoveHttpRequestHeader(originHeader)
_ = proxywasm.RemoveHttpRequestHeader(originHeader)
}
originalAuth, _ := proxywasm.GetHttpRequestHeader(util.HeaderOriginalAuth)
if originalAuth == "" {
@@ -204,23 +205,24 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
if handler, ok := activeProvider.(provider.ApiNameHandler); ok {
apiName = handler.GetApiName(path.Path)
}
}
// Auto-detect protocol based on request path and handle conversion if needed
// If request is Claude format (/v1/messages) but provider doesn't support it natively,
// convert to OpenAI format (/v1/chat/completions)
if apiName == provider.ApiNameAnthropicMessages && !providerConfig.IsSupportedAPI(provider.ApiNameAnthropicMessages) {
// Provider doesn't support Claude protocol natively, convert to OpenAI format
newPath := strings.Replace(path.Path, provider.PathAnthropicMessages, provider.PathOpenAIChatCompletions, 1)
_ = proxywasm.ReplaceHttpRequestHeader(":path", newPath)
// Update apiName to match the new path
apiName = provider.ApiNameChatCompletion
// Mark that we need to convert response back to Claude format
ctx.SetContext("needClaudeResponseConversion", true)
log.Debugf("[Auto Protocol] Claude request detected, provider doesn't support natively, converted path from %s to %s, apiName: %s", path.Path, newPath, apiName)
} else if apiName == provider.ApiNameAnthropicMessages {
// Provider supports Claude protocol natively, no conversion needed
log.Debugf("[Auto Protocol] Claude request detected, provider supports natively, keeping original path: %s, apiName: %s", path.Path, apiName)
} else {
// Only perform protocol conversion for non-original protocols.
// Auto-detect protocol based on request path and handle conversion if needed
// If request is Claude format (/v1/messages) but provider doesn't support it natively,
// convert to OpenAI format (/v1/chat/completions)
if apiName == provider.ApiNameAnthropicMessages && !providerConfig.IsSupportedAPI(provider.ApiNameAnthropicMessages) {
// Provider doesn't support Claude protocol natively, convert to OpenAI format
newPath := strings.Replace(path.Path, provider.PathAnthropicMessages, provider.PathOpenAIChatCompletions, 1)
_ = proxywasm.ReplaceHttpRequestHeader(":path", newPath)
// Update apiName to match the new path
apiName = provider.ApiNameChatCompletion
// Mark that we need to convert response back to Claude format
ctx.SetContext("needClaudeResponseConversion", true)
log.Debugf("[Auto Protocol] Claude request detected, provider doesn't support natively, converted path from %s to %s, apiName: %s", path.Path, newPath, apiName)
} else if apiName == provider.ApiNameAnthropicMessages {
// Provider supports Claude protocol natively, no conversion needed
log.Debugf("[Auto Protocol] Claude request detected, provider supports natively, keeping original path: %s, apiName: %s", path.Path, apiName)
}
}
if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !strings.Contains(contentType, util.MimeTypeApplicationJson) {
@@ -247,8 +249,8 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
providerConfig.SetAvailableApiTokens(ctx)
// save the original request host and path in case they are needed for apiToken health check and retry
ctx.SetContext(provider.CtxRequestHost, wrapper.GetRequestHost())
ctx.SetContext(provider.CtxRequestPath, wrapper.GetRequestPath())
ctx.SetContext(provider.CtxRequestHost, ctx.Host())
ctx.SetContext(provider.CtxRequestPath, ctx.Path())
err := handler.OnRequestHeaders(ctx, apiName)
if err != nil {
@@ -256,7 +258,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
return types.ActionContinue
}
hasRequestBody := wrapper.HasRequestBody()
hasRequestBody := ctx.HasRequestBody()
if hasRequestBody {
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)

View File

@@ -27,6 +27,10 @@ func Test_getApiName(t *testing.T) {
{"openai files", "/v1/files", provider.ApiNameFiles},
{"openai retrieve file", "/v1/files/fileid", provider.ApiNameRetrieveFile},
{"openai retrieve file content", "/v1/files/fileid/content", provider.ApiNameRetrieveFileContent},
{"openai videos", "/v1/videos", provider.ApiNameVideos},
{"openai retrieve video", "/v1/videos/videoid", provider.ApiNameRetrieveVideo},
{"openai retrieve video content", "/v1/videos/videoid/content", provider.ApiNameRetrieveVideoContent},
{"openai video remix", "/v1/videos/videoid/remix", provider.ApiNameVideoRemix},
{"openai models", "/v1/models", provider.ApiNameModels},
{"openai fine tuning jobs", "/v1/fine_tuning/jobs", provider.ApiNameFineTuningJobs},
{"openai retrieve fine tuning job", "/v1/fine_tuning/jobs/jobid", provider.ApiNameRetrieveFineTuningJob},
@@ -102,6 +106,7 @@ func TestAzure(t *testing.T) {
test.RunAzureOnHttpRequestBodyTests(t)
test.RunAzureOnHttpResponseHeadersTests(t)
test.RunAzureOnHttpResponseBodyTests(t)
test.RunAzureBasePathHandlingTests(t)
}
func TestFireworks(t *testing.T) {
@@ -109,3 +114,33 @@ func TestFireworks(t *testing.T) {
test.RunFireworksOnHttpRequestHeadersTests(t)
test.RunFireworksOnHttpRequestBodyTests(t)
}
// TestMinimax runs the minimax provider's basePath handling test suite.
func TestMinimax(t *testing.T) {
	test.RunMinimaxBasePathHandlingTests(t)
}
// TestUtil runs the shared capability-based request-path mapping tests.
func TestUtil(t *testing.T) {
	test.RunMapRequestPathByCapabilityTests(t)
}
// TestGeneric runs the generic provider's config-parsing and request
// header/body handling test suites.
func TestGeneric(t *testing.T) {
	test.RunGenericParseConfigTests(t)
	test.RunGenericOnHttpRequestHeadersTests(t)
	test.RunGenericOnHttpRequestBodyTests(t)
}
// TestVertex runs the vertex provider's config-parsing tests and the
// Express Mode request/response handling test suites.
func TestVertex(t *testing.T) {
	test.RunVertexParseConfigTests(t)
	test.RunVertexExpressModeOnHttpRequestHeadersTests(t)
	test.RunVertexExpressModeOnHttpRequestBodyTests(t)
	test.RunVertexExpressModeOnHttpResponseBodyTests(t)
	test.RunVertexExpressModeOnStreamingResponseBodyTests(t)
}
// TestBedrock runs the bedrock provider's config-parsing and request/response
// handling test suites.
func TestBedrock(t *testing.T) {
	test.RunBedrockParseConfigTests(t)
	test.RunBedrockOnHttpRequestHeadersTests(t)
	test.RunBedrockOnHttpRequestBodyTests(t)
	test.RunBedrockOnHttpResponseHeadersTests(t)
	test.RunBedrockOnHttpResponseBodyTests(t)
}

View File

@@ -174,12 +174,14 @@ func (m *azureProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName Ap
}
func (m *azureProvider) transformRequestPath(ctx wrapper.HttpContext, apiName ApiName) string {
originalPath := util.GetOriginalRequestPath()
// When using original protocol, don't overwrite the path.
// This ensures basePathHandling works correctly even in TransformRequestBody stage.
if m.config.IsOriginal() {
return originalPath
return ""
}
originalPath := util.GetOriginalRequestPath()
if m.serviceUrlType == azureServiceUrlTypeFull {
log.Debugf("azureProvider: use configured path %s", m.serviceUrlFullPath)
return m.serviceUrlFullPath

View File

@@ -43,8 +43,11 @@ const (
type bedrockProviderInitializer struct{}
func (b *bedrockProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if len(config.awsAccessKey) == 0 || len(config.awsSecretKey) == 0 {
return errors.New("missing bedrock access authentication parameters")
hasAkSk := len(config.awsAccessKey) > 0 && len(config.awsSecretKey) > 0
hasApiToken := len(config.apiTokens) > 0
if !hasAkSk && !hasApiToken {
return errors.New("missing bedrock access authentication parameters: either apiTokens or (awsAccessKey + awsSecretKey) is required")
}
if len(config.awsRegion) == 0 {
return errors.New("missing bedrock region parameters")
@@ -107,9 +110,8 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
chatChoice.Delta.Content = nil
chatChoice.Delta.ToolCalls = []toolCall{
{
Id: bedrockEvent.Start.ToolUse.ToolUseID,
Index: 0,
Type: "function",
Id: bedrockEvent.Start.ToolUse.ToolUseID,
Type: "function",
Function: functionCall{
Name: bedrockEvent.Start.ToolUse.Name,
Arguments: "",
@@ -138,8 +140,7 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
if bedrockEvent.Delta.ToolUse != nil {
chatChoice.Delta.ToolCalls = []toolCall{
{
Index: 0,
Type: "function",
Type: "function",
Function: functionCall{
Arguments: bedrockEvent.Delta.ToolUse.Input,
},
@@ -168,7 +169,6 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
TotalTokens: bedrockEvent.Usage.TotalTokens,
}
}
openAIFormattedChunkBytes, _ := json.Marshal(openAIFormattedChunk)
var openAIChunk strings.Builder
openAIChunk.WriteString(ssePrefix)
@@ -637,6 +637,13 @@ func (b *bedrockProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
// TransformRequestHeaders rewrites the Host header to the region-specific
// Bedrock domain and, when apiTokens are configured, sets Bearer token
// authentication.
func (b *bedrockProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
	util.OverwriteRequestHostHeader(headers, fmt.Sprintf(bedrockDefaultDomain, b.config.awsRegion))
	// If apiTokens is configured, set Bearer token authentication here.
	// This follows the same pattern as other providers (qwen, zhipuai, etc.).
	// AWS SigV4 authentication is handled in setAuthHeaders because it requires the request body.
	if len(b.config.apiTokens) > 0 {
		util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+b.config.GetApiTokenInUse(ctx))
	}
}
func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
@@ -662,18 +669,18 @@ func (b *bedrockProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName
case ApiNameChatCompletion:
return b.onChatCompletionResponseBody(ctx, body)
case ApiNameImageGeneration:
return b.onImageGenerationResponseBody(ctx, body)
return b.onImageGenerationResponseBody(body)
}
return nil, errUnsupportedApiName
}
func (b *bedrockProvider) onImageGenerationResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
func (b *bedrockProvider) onImageGenerationResponseBody(body []byte) ([]byte, error) {
bedrockResponse := &bedrockImageGenerationResponse{}
if err := json.Unmarshal(body, bedrockResponse); err != nil {
log.Errorf("unable to unmarshal bedrock image gerneration response: %v", err)
return nil, fmt.Errorf("unable to unmarshal bedrock image generation response: %v", err)
}
response := b.buildBedrockImageGenerationResponse(ctx, bedrockResponse)
response := b.buildBedrockImageGenerationResponse(bedrockResponse)
return json.Marshal(response)
}
@@ -713,7 +720,7 @@ func (b *bedrockProvider) buildBedrockImageGenerationRequest(origRequest *imageG
return requestBytes, err
}
func (b *bedrockProvider) buildBedrockImageGenerationResponse(ctx wrapper.HttpContext, bedrockResponse *bedrockImageGenerationResponse) *imageGenerationResponse {
func (b *bedrockProvider) buildBedrockImageGenerationResponse(bedrockResponse *bedrockImageGenerationResponse) *imageGenerationResponse {
data := make([]imageGenerationData, len(bedrockResponse.Images))
for i, image := range bedrockResponse.Images {
data[i] = imageGenerationData{
@@ -1062,17 +1069,19 @@ func chatToolMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
Text: text,
},
}
openaiContent := chatMessage.ParseContent()
for _, part := range openaiContent {
var content bedrockMessageContent
if part.Type == contentTypeText {
content.Text = part.Text
} else {
continue
} else if contentList, ok := chatMessage.Content.([]any); ok {
for _, contentItem := range contentList {
contentMap, ok := contentItem.(map[string]any)
if ok && contentMap["type"] == contentTypeText {
if text, ok := contentMap[contentTypeText].(string); ok {
toolResultContent.Content = append(toolResultContent.Content, toolResultContentBlock{
Text: text,
})
}
}
}
} else {
log.Warnf("only text content is supported, current content is %v", chatMessage.Content)
log.Warnf("the content type is not supported, current content is %v", chatMessage.Content)
}
return bedrockMessage{
Role: roleUser,
@@ -1139,6 +1148,13 @@ func chatMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
}
func (b *bedrockProvider) setAuthHeaders(body []byte, headers http.Header) {
// Bearer token authentication is already set in TransformRequestHeaders
// This function only handles AWS SigV4 authentication which requires the request body
if len(b.config.apiTokens) > 0 {
return
}
// Use AWS Signature V4 authentication
t := time.Now().UTC()
amzDate := t.Format("20060102T150405Z")
dateStamp := t.Format("20060102")

View File

@@ -203,19 +203,20 @@ type claudeThinkingConfig struct {
}
type claudeTextGenRequest struct {
Model string `json:"model"`
Messages []claudeChatMessage `json:"messages"`
System *claudeSystemPrompt `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
ToolChoice *claudeToolChoice `json:"tool_choice,omitempty"`
Tools []claudeTool `json:"tools,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
Thinking *claudeThinkingConfig `json:"thinking,omitempty"`
Model string `json:"model,omitempty"`
Messages []claudeChatMessage `json:"messages"`
System *claudeSystemPrompt `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
ToolChoice *claudeToolChoice `json:"tool_choice,omitempty"`
Tools []claudeTool `json:"tools,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
Thinking *claudeThinkingConfig `json:"thinking,omitempty"`
AnthropicVersion string `json:"anthropic_version,omitempty"`
}
type claudeTextGenResponse struct {

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
@@ -59,7 +60,18 @@ func (d *difyProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName
func (d *difyProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
if d.config.difyApiUrl != "" {
log.Debugf("use local host: %s", d.config.difyApiUrl)
util.OverwriteRequestHostHeader(headers, d.config.difyApiUrl)
// Extract hostname, including Full URL or Domain
host := d.config.difyApiUrl
if parsedUrl, err := url.Parse(d.config.difyApiUrl); err == nil && parsedUrl.Host != "" {
host = parsedUrl.Host
} else {
host = strings.TrimPrefix(strings.TrimPrefix(d.config.difyApiUrl, "http://"), "https://")
if idx := strings.Index(host, "/"); idx != -1 {
host = host[:idx]
}
}
log.Debugf("extracted hostname: %s", host)
util.OverwriteRequestHostHeader(headers, host)
} else {
util.OverwriteRequestHostHeader(headers, difyDomain)
}

View File

@@ -70,7 +70,11 @@ func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
func (m *doubaoProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, doubaoDomain)
if m.config.doubaoDomain != "" {
util.OverwriteRequestHostHeader(headers, m.config.doubaoDomain)
} else {
util.OverwriteRequestHostHeader(headers, doubaoDomain)
}
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}

View File

@@ -0,0 +1,85 @@
package provider
import (
"net/http"
"strconv"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
// genericProviderInitializer creates a generic provider that performs no
// capability (path) mapping.
type genericProviderInitializer struct{}
// ValidateConfig performs no extra validation: the generic provider has no
// required configuration of its own.
func (m *genericProviderInitializer) ValidateConfig(config *ProviderConfig) error {
	return nil
}
// DefaultCapabilities returns an empty map, meaning no path or capability
// rewriting is performed.
func (m *genericProviderInitializer) DefaultCapabilities() map[string]string {
	return map[string]string{}
}
// CreateProvider builds the generic provider with the default (empty)
// capability map applied to the config.
func (m *genericProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
	config.setDefaultCapabilities(m.DefaultCapabilities())
	return &genericProvider{
		config: config,
	}, nil
}
// genericProvider only handles the common header/body processing logic and is
// not bound to any specific vendor.
type genericProvider struct {
	config ProviderConfig
}
// GetProviderType returns the provider type identifier ("generic").
func (m *genericProvider) GetProviderType() string {
	return providerTypeGeneric
}
// OnRequestHeaders reuses the shared handleRequestHeaders logic and, when a
// first-byte timeout is configured, writes the corresponding timeout header.
func (m *genericProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
	m.config.handleRequestHeaders(m, ctx, apiName)
	if m.config.firstByteTimeout > 0 {
		// NOTE(review): this marks every request as streaming whenever the
		// timeout is configured — confirm this is intended for non-streaming
		// requests as well.
		ctx.SetContext(ctxKeyIsStreaming, true)
		m.applyFirstByteTimeout()
	}
	return nil
}
// OnRequestBody performs no body transformation for the generic provider.
func (m *genericProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
	return types.ActionContinue, nil
}
// TransformRequestHeaders only handles authorization and Host rewriting; it
// performs no path rewriting.
func (m *genericProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
	var token string
	if len(m.config.apiTokens) > 0 {
		token = m.config.GetApiTokenInUse(ctx)
	}
	if token != "" {
		util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+token)
	}
	if host := m.config.genericHost; host != "" {
		util.OverwriteRequestHostHeader(headers, host)
	}
	headers.Del("Content-Length")
}
// applyFirstByteTimeout writes the first-byte timeout header for streaming
// requests when firstByteTimeout is configured; it is a no-op otherwise.
func (m *genericProvider) applyFirstByteTimeout() {
	timeout := m.config.firstByteTimeout
	if timeout == 0 {
		return
	}
	value := strconv.FormatUint(uint64(timeout), 10)
	if err := proxywasm.ReplaceHttpRequestHeader("x-envoy-upstream-rq-first-byte-timeout-ms", value); err != nil {
		log.Errorf("generic provider: failed to set first byte timeout header: %v", err)
		return
	}
	log.Debugf("[generic][firstByteTimeout] %d", timeout)
}

View File

@@ -106,8 +106,11 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte) (typ
}
// Map the model and rewrite the request path.
// When using original protocol, don't overwrite the path to ensure basePathHandling works correctly.
request.Model = getMappedModel(request.Model, m.config.modelMapping)
_ = util.OverwriteRequestPath(fmt.Sprintf("%s?GroupId=%s", minimaxChatCompletionProPath, m.config.minimaxGroupId))
if !m.config.IsOriginal() {
_ = util.OverwriteRequestPath(fmt.Sprintf("%s?GroupId=%s", minimaxChatCompletionProPath, m.config.minimaxGroupId))
}
if m.config.context == nil {
minimaxRequest := m.buildMinimaxChatCompletionProRequest(request, "")

View File

@@ -393,7 +393,7 @@ func (m *chatMessage) ParseContent() []chatMessageContent {
}
type toolCall struct {
Index int `json:"index,omitempty"`
Index int `json:"index"`
Id string `json:"id,omitempty"`
Type string `json:"type"`
Function functionCall `json:"function"`

View File

@@ -144,6 +144,7 @@ const (
providerTypeLongcat = "longcat"
providerTypeFireworks = "fireworks"
providerTypeVllm = "vllm"
providerTypeGeneric = "generic"
protocolOpenAI = "openai"
protocolOriginal = "original"
@@ -227,6 +228,7 @@ var (
providerTypeLongcat: &longcatProviderInitializer{},
providerTypeFireworks: &fireworksProviderInitializer{},
providerTypeVllm: &vllmProviderInitializer{},
providerTypeGeneric: &genericProviderInitializer{},
}
)
@@ -385,6 +387,9 @@ type ProviderConfig struct {
// @Title zh-CN Vertex token刷新提前时间
// @Description zh-CN 用于Google服务账号认证access token过期时间判定提前刷新单位为秒默认值为60秒
vertexTokenRefreshAhead int64 `required:"false" yaml:"vertexTokenRefreshAhead" json:"vertexTokenRefreshAhead"`
// @Title zh-CN Vertex AI OpenAI兼容模式
// @Description zh-CN 启用后将使用Vertex AI的OpenAI兼容API请求和响应均使用OpenAI格式无需协议转换。与Express Mode(apiTokens)互斥。
vertexOpenAICompatible bool `required:"false" yaml:"vertexOpenAICompatible" json:"vertexOpenAICompatible"`
// @Title zh-CN 翻译服务需指定的目标语种
// @Description zh-CN 翻译结果的语种目前仅适用于DeepL服务。
targetLang string `required:"false" yaml:"targetLang" json:"targetLang"`
@@ -409,6 +414,9 @@ type ProviderConfig struct {
basePath string `required:"false" yaml:"basePath" json:"basePath"`
// @Title zh-CN basePathHandling用于指定basePath的处理方式可选值removePrefix、prepend
basePathHandling basePathHandling `required:"false" yaml:"basePathHandling" json:"basePathHandling"`
// @Title zh-CN generic Provider 对应的Host
// @Description zh-CN 仅适用于generic provider用于覆盖请求转发的目标Host
genericHost string `required:"false" yaml:"genericHost" json:"genericHost"`
// @Title zh-CN 首包超时
// @Description zh-CN 流式请求中收到上游服务第一个响应包的超时时间,单位为毫秒。默认值为 0表示不开启首包超时
firstByteTimeout uint32 `required:"false" yaml:"firstByteTimeout" json:"firstByteTimeout"`
@@ -424,6 +432,9 @@ type ProviderConfig struct {
// @Title zh-CN vLLM主机地址
// @Description zh-CN 仅适用于vLLM服务指定vLLM服务器的主机地址例如vllm-service.cluster.local
vllmServerHost string `required:"false" yaml:"vllmServerHost" json:"vllmServerHost"`
// @Title zh-CN 豆包服务域名
// @Description zh-CN 仅适用于豆包服务,默认转发域名为 ark.cn-beijing.volces.com
doubaoDomain string `required:"false" yaml:"doubaoDomain" json:"doubaoDomain"`
}
func (c *ProviderConfig) GetId() string {
@@ -532,6 +543,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
if c.vertexTokenRefreshAhead == 0 {
c.vertexTokenRefreshAhead = 60
}
c.vertexOpenAICompatible = json.Get("vertexOpenAICompatible").Bool()
c.targetLang = json.Get("targetLang").String()
if schemaValue, ok := json.Get("responseJsonSchema").Value().(map[string]interface{}); ok {
@@ -619,8 +631,10 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
if c.basePath != "" && c.basePathHandling == "" {
c.basePathHandling = basePathHandlingRemovePrefix
}
c.genericHost = json.Get("genericHost").String()
c.vllmServerHost = json.Get("vllmServerHost").String()
c.vllmCustomUrl = json.Get("vllmCustomUrl").String()
c.doubaoDomain = json.Get("doubaoDomain").String()
}
func (c *ProviderConfig) Validate() error {
@@ -963,12 +977,33 @@ func (c *ProviderConfig) handleRequestBody(
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName) {
headers := util.GetRequestHeaders()
originPath := headers.Get(":path")
// Record the path after removePrefix processing
var removePrefixPath string
if c.basePath != "" && c.basePathHandling == basePathHandlingRemovePrefix {
headers.Set(":path", strings.TrimPrefix(originPath, c.basePath))
removePrefixPath = strings.TrimPrefix(originPath, c.basePath)
headers.Set(":path", removePrefixPath)
}
if handler, ok := provider.(TransformRequestHeadersHandler); ok {
handler.TransformRequestHeaders(ctx, apiName, headers)
}
// When using original protocol with removePrefix, restore the basePath-processed path.
// This ensures basePathHandling works correctly even when TransformRequestHeaders
// overwrites the path (which most providers do).
//
// TODO: Most providers (OpenAI, vLLM, DeepSeek, Claude, etc.) unconditionally overwrite
// the path in TransformRequestHeaders without checking IsOriginal(). Ideally, each provider
// should check IsOriginal() before overwriting the path (like Qwen does). Once all providers
// are updated to handle protocol correctly, this workaround can be removed.
// Affected providers: OpenAI, vLLM, ZhipuAI, Moonshot, Longcat, DeepSeek, Azure, Yi,
// TogetherAI, Stepfun, Ollama, Hunyuan, GitHub, Doubao, Cohere, Baichuan, AI360, Claude,
// Groq, Grok, Spark, Fireworks, Cloudflare, Baidu, OpenRouter, DeepL (24+ providers)
if c.IsOriginal() && removePrefixPath != "" {
headers.Set(":path", removePrefixPath)
}
if c.basePath != "" && c.basePathHandling == basePathHandlingPrepend && !strings.HasPrefix(headers.Get(":path"), c.basePath) {
headers.Set(":path", path.Join(c.basePath, headers.Get(":path")))
}

View File

@@ -21,21 +21,60 @@ import (
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
vertexAuthDomain = "oauth2.googleapis.com"
vertexDomain = "{REGION}-aiplatform.googleapis.com"
vertexDomain = "aiplatform.googleapis.com"
// /v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models/{MODEL_ID}:{ACTION}
vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s"
vertexChatCompletionAction = "generateContent"
vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse"
vertexEmbeddingAction = "predict"
vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s"
vertexPathAnthropicTemplate = "/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s"
// Express Mode 路径模板 (不含 project/location)
vertexExpressPathTemplate = "/v1/publishers/google/models/%s:%s"
vertexExpressPathAnthropicTemplate = "/v1/publishers/anthropic/models/%s:%s"
// OpenAI-compatible endpoint 路径模板
// /v1beta1/projects/{PROJECT_ID}/locations/{LOCATION}/endpoints/openapi/chat/completions
vertexOpenAICompatiblePathTemplate = "/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions"
vertexChatCompletionAction = "generateContent"
vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse"
vertexAnthropicMessageAction = "rawPredict"
vertexAnthropicMessageStreamAction = "streamRawPredict"
vertexEmbeddingAction = "predict"
vertexGlobalRegion = "global"
contextClaudeMarker = "isClaudeRequest"
contextOpenAICompatibleMarker = "isOpenAICompatibleRequest"
vertexAnthropicVersion = "vertex-2023-10-16"
)
type vertexProviderInitializer struct{}
func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error {
// Express Mode: 如果配置了 apiTokens则使用 API Key 认证
if len(config.apiTokens) > 0 {
// Express Mode 与 OpenAI 兼容模式互斥
if config.vertexOpenAICompatible {
return errors.New("vertexOpenAICompatible is not compatible with Express Mode (apiTokens)")
}
// Express Mode 不需要其他配置
return nil
}
// OpenAI 兼容模式: 需要 OAuth 认证配置
if config.vertexOpenAICompatible {
if config.vertexAuthKey == "" {
return errors.New("missing vertexAuthKey in vertex provider config for OpenAI compatible mode")
}
if config.vertexRegion == "" || config.vertexProjectId == "" {
return errors.New("missing vertexRegion or vertexProjectId in vertex provider config for OpenAI compatible mode")
}
if config.vertexAuthServiceName == "" {
return errors.New("missing vertexAuthServiceName in vertex provider config for OpenAI compatible mode")
}
return nil
}
// 标准模式: 保持原有验证逻辑
if config.vertexAuthKey == "" {
return errors.New("missing vertexAuthKey in vertex provider config")
}
@@ -57,21 +96,45 @@ func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {
func (v *vertexProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(v.DefaultCapabilities())
return &vertexProvider{
config: config,
client: wrapper.NewClusterClient(wrapper.DnsCluster{
provider := &vertexProvider{
config: config,
contextCache: createContextCache(&config),
claude: &claudeProvider{
config: config,
contextCache: createContextCache(&config),
},
}
// 仅标准模式需要 OAuth 客户端Express Mode 通过 apiTokens 配置)
if !provider.isExpressMode() {
provider.client = wrapper.NewClusterClient(wrapper.DnsCluster{
Domain: vertexAuthDomain,
ServiceName: config.vertexAuthServiceName,
Port: 443,
}),
contextCache: createContextCache(&config),
}, nil
})
}
return provider, nil
}
// isExpressMode reports whether Express Mode is enabled: configuring
// apiTokens switches the provider to API-key authentication.
func (v *vertexProvider) isExpressMode() bool {
	return len(v.config.apiTokens) > 0
}
// isOpenAICompatibleMode reports whether the OpenAI-compatible mode is
// enabled, i.e. Vertex AI's OpenAI-compatible Chat Completions API is used.
func (v *vertexProvider) isOpenAICompatibleMode() bool {
	return v.config.vertexOpenAICompatible
}
type vertexProvider struct {
	client       wrapper.HttpClient // OAuth token client; only set up outside Express Mode
	config       ProviderConfig
	contextCache *contextCache
	claude       *claudeProvider // delegate used for Claude-model requests/responses
}
func (v *vertexProvider) GetProviderType() string {
@@ -94,8 +157,21 @@ func (v *vertexProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
}
func (v *vertexProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
vertexRegionDomain := strings.Replace(vertexDomain, "{REGION}", v.config.vertexRegion, 1)
util.OverwriteRequestHostHeader(headers, vertexRegionDomain)
var finalVertexDomain string
if v.isExpressMode() {
// Express Mode: 固定域名,不带 region 前缀
finalVertexDomain = vertexDomain
} else {
// 标准模式: 带 region 前缀
if v.config.vertexRegion != vertexGlobalRegion {
finalVertexDomain = fmt.Sprintf("%s-%s", v.config.vertexRegion, vertexDomain)
} else {
finalVertexDomain = vertexDomain
}
}
util.OverwriteRequestHostHeader(headers, finalVertexDomain)
}
func (v *vertexProvider) getToken() (cached bool, err error) {
@@ -137,8 +213,42 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if v.config.IsOriginal() {
return types.ActionContinue, nil
}
headers := util.GetRequestHeaders()
// OpenAI 兼容模式: 不转换请求体,只设置路径和进行模型映射
if v.isOpenAICompatibleMode() {
ctx.SetContext(contextOpenAICompatibleMarker, true)
body, err := v.onOpenAICompatibleRequestBody(ctx, apiName, body, headers)
headers.Set("Content-Length", fmt.Sprint(len(body)))
util.ReplaceRequestHeaders(headers)
_ = proxywasm.ReplaceHttpRequestBody(body)
if err != nil {
return types.ActionContinue, err
}
// OpenAI 兼容模式需要 OAuth token
cached, err := v.getToken()
if cached {
return types.ActionContinue, nil
}
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}
body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers)
headers.Set("Content-Length", fmt.Sprint(len(body)))
if v.isExpressMode() {
// Express Mode: 不需要 Authorization headerAPI Key 已在 URL 中
headers.Del("Authorization")
util.ReplaceRequestHeaders(headers)
_ = proxywasm.ReplaceHttpRequestBody(body)
return types.ActionContinue, err
}
// 标准模式: 需要获取 OAuth token
util.ReplaceRequestHeaders(headers)
_ = proxywasm.ReplaceHttpRequestBody(body)
if err != nil {
@@ -162,17 +272,58 @@ func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, ap
}
}
// onOpenAICompatibleRequestBody handles a request in OpenAI-compatible mode:
// the body keeps its OpenAI format; only the model is mapped and the request
// path is rewritten to the OpenAI-compatible endpoint.
// Returns the (possibly model-rewritten) body, or an error if the request
// cannot be parsed or rewritten.
func (v *vertexProvider) onOpenAICompatibleRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
	if apiName != ApiNameChatCompletion {
		return nil, fmt.Errorf("OpenAI compatible mode only supports chat completions API")
	}

	// Parse the request to perform model mapping.
	request := &chatCompletionRequest{}
	if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
		return nil, err
	}

	// Point the request at the OpenAI-compatible endpoint.
	util.OverwriteRequestPathHeader(headers, v.getOpenAICompatibleRequestPath())

	// If the model was mapped, write it back into the body. Previously the
	// sjson error was silently ignored, which could forward a nil body on
	// failure; surface it instead.
	if request.Model != "" {
		newBody, err := sjson.SetBytes(body, "model", request.Model)
		if err != nil {
			return nil, fmt.Errorf("unable to set mapped model in request body: %v", err)
		}
		body = newBody
	}

	// Keep the OpenAI format and return the body as-is otherwise.
	return body, nil
}
// onChatCompletionRequestBody maps the model, rewrites the request path and
// converts the OpenAI-format body to the Vertex AI format. Claude models are
// routed to the Anthropic publisher endpoint using the Claude request format
// (with the model field cleared and the Vertex anthropic_version set);
// all other models use the Google publisher endpoint.
//
// Fix: removed stale statements that unconditionally rewrote the path before
// the Claude branch and left unreachable code after an early return.
func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
	request := &chatCompletionRequest{}
	err := v.config.parseRequestAndMapModel(ctx, request, body)
	if err != nil {
		return nil, err
	}
	if strings.HasPrefix(request.Model, "claude") {
		// Remember that this is a Claude request so the response path can
		// delegate to the Claude provider.
		ctx.SetContext(contextClaudeMarker, true)
		path := v.getAhthropicRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
		util.OverwriteRequestPathHeader(headers, path)
		claudeRequest := v.claude.buildClaudeTextGenRequest(request)
		// Vertex AI carries the model in the URL, not the body, and requires
		// the anthropic_version field instead.
		claudeRequest.Model = ""
		claudeRequest.AnthropicVersion = vertexAnthropicVersion
		claudeBody, err := json.Marshal(claudeRequest)
		if err != nil {
			return nil, err
		}
		return claudeBody, nil
	} else {
		path := v.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
		util.OverwriteRequestPathHeader(headers, path)
		vertexRequest := v.buildVertexChatRequest(request)
		return json.Marshal(vertexRequest)
	}
}
func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
@@ -188,6 +339,15 @@ func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
}
func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
// OpenAI 兼容模式: 透传响应,但需要解码 Unicode 转义序列
// Vertex AI OpenAI-compatible API 返回 ASCII-safe JSON将非 ASCII 字符编码为 \uXXXX
if ctx.GetContext(contextOpenAICompatibleMarker) != nil && ctx.GetContext(contextOpenAICompatibleMarker).(bool) {
return util.DecodeUnicodeEscapesInSSE(chunk), nil
}
if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) {
return v.claude.OnStreamingResponseBody(ctx, name, chunk, isLastChunk)
}
log.Infof("[vertexProvider] receive chunk body: %s", string(chunk))
if isLastChunk {
return []byte(ssePrefix + "[DONE]\n\n"), nil
@@ -225,6 +385,15 @@ func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
}
func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
// OpenAI 兼容模式: 透传响应,但需要解码 Unicode 转义序列
// Vertex AI OpenAI-compatible API 返回 ASCII-safe JSON将非 ASCII 字符编码为 \uXXXX
if ctx.GetContext(contextOpenAICompatibleMarker) != nil && ctx.GetContext(contextOpenAICompatibleMarker).(bool) {
return util.DecodeUnicodeEscapes(body), nil
}
if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) {
return v.claude.TransformResponseBody(ctx, apiName, body)
}
if apiName == ApiNameChatCompletion {
return v.onChatCompletionResponseBody(ctx, body)
} else {
@@ -252,6 +421,9 @@ func (v *vertexProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, re
PromptTokens: response.UsageMetadata.PromptTokenCount,
CompletionTokens: response.UsageMetadata.CandidatesTokenCount,
TotalTokens: response.UsageMetadata.TotalTokenCount,
CompletionTokensDetails: &completionTokensDetails{
ReasoningTokens: response.UsageMetadata.ThoughtsTokenCount,
},
},
}
for _, candidate := range response.Candidates {
@@ -320,6 +492,7 @@ func (v *vertexProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, vertex
func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, vertexResp *vertexChatResponse) *chatCompletionResponse {
var choice chatCompletionChoice
choice.Delta = &chatMessage{}
if len(vertexResp.Candidates) > 0 && len(vertexResp.Candidates[0].Content.Parts) > 0 {
part := vertexResp.Candidates[0].Content.Parts[0]
if part.FunctionCall != nil {
@@ -361,6 +534,9 @@ func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpConte
PromptTokens: vertexResp.UsageMetadata.PromptTokenCount,
CompletionTokens: vertexResp.UsageMetadata.CandidatesTokenCount,
TotalTokens: vertexResp.UsageMetadata.TotalTokenCount,
CompletionTokensDetails: &completionTokensDetails{
ReasoningTokens: vertexResp.UsageMetadata.ThoughtsTokenCount,
},
},
}
return &streamResponse
@@ -370,6 +546,32 @@ func (v *vertexProvider) appendResponse(responseBuilder *strings.Builder, respon
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}
// getAhthropicRequestPath builds the request path for Anthropic (Claude)
// models on Vertex AI. In Express Mode the simplified publisher path is used
// with the API key appended as a query parameter; otherwise the full
// project/location path is returned.
// NOTE(review): the name misspells "Anthropic"; renaming would touch callers.
func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string, stream bool) string {
	action := vertexAnthropicMessageAction
	if stream {
		action = vertexAnthropicMessageStreamAction
	}
	if !v.isExpressMode() {
		return fmt.Sprintf(vertexPathAnthropicTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
	}
	// Express Mode: simplified path plus the API key as a query parameter.
	// Join with '&' when the action already carries a query string.
	basePath := fmt.Sprintf(vertexExpressPathAnthropicTemplate, modelId, action)
	sep := "?"
	if strings.Contains(action, "?") {
		sep = "&"
	}
	return basePath + sep + "key=" + v.config.GetRandomToken()
}
func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string {
action := ""
if apiName == ApiNameEmbeddings {
@@ -379,7 +581,28 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream
} else {
action = vertexChatCompletionAction
}
return fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
if v.isExpressMode() {
// Express Mode: 简化路径 + API Key 参数
basePath := fmt.Sprintf(vertexExpressPathTemplate, modelId, action)
apiKey := v.config.GetRandomToken()
// 如果 action 已经包含 ?(如 streamGenerateContent?alt=sse使用 & 拼接
var fullPath string
if strings.Contains(action, "?") {
fullPath = basePath + "&key=" + apiKey
} else {
fullPath = basePath + "?key=" + apiKey
}
return fullPath
}
path := fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
return path
}
// getOpenAICompatibleRequestPath returns the request path for Vertex AI's
// OpenAI-compatible endpoint, scoped to the configured project and region.
func (v *vertexProvider) getOpenAICompatibleRequestPath() string {
	path := fmt.Sprintf(vertexOpenAICompatiblePathTemplate, v.config.vertexProjectId, v.config.vertexRegion)
	return path
}
func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest) *vertexChatRequest {
@@ -400,19 +623,22 @@ func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest)
},
}
if request.ReasoningEffort != "" {
thinkingBudget := 1024 // default
switch request.ReasoningEffort {
case "low":
thinkingBudget = 1024
case "medium":
thinkingBudget = 4096
case "high":
thinkingBudget = 16384
}
vertexRequest.GenerationConfig.ThinkingConfig = vertexThinkingConfig{
thinkingConfig := vertexThinkingConfig{
IncludeThoughts: true,
ThinkingBudget: thinkingBudget,
ThinkingBudget: 1024,
}
switch request.ReasoningEffort {
case "none":
thinkingConfig.IncludeThoughts = false
thinkingConfig.ThinkingBudget = 0
case "low":
thinkingConfig.ThinkingBudget = 1024
case "medium":
thinkingConfig.ThinkingBudget = 4096
case "high":
thinkingConfig.ThinkingBudget = 16384
}
vertexRequest.GenerationConfig.ThinkingConfig = thinkingConfig
}
if request.Tools != nil {
functions := make([]function, 0, len(request.Tools))
@@ -633,6 +859,7 @@ type vertexUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount,omitempty"`
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
TotalTokenCount int `json:"totalTokenCount,omitempty"`
ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"`
}
type vertexEmbeddingResponse struct {

View File

@@ -160,6 +160,59 @@ var azureResponseAPIConfig = func() json.RawMessage {
return data
}()
// 测试配置Azure OpenAI basePath移除 + original协议
var azureBasePathRemovePrefixOriginalConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "azure",
"apiTokens": []string{
"sk-azure-basepath-original",
},
"azureServiceUrl": "https://basepath-test.openai.azure.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-15-preview",
"basePath": "/azure-gpt4",
"basePathHandling": "removePrefix",
"protocol": "original",
},
})
return data
}()
// 测试配置Azure OpenAI basePath移除 + openai协议
var azureBasePathRemovePrefixOpenAIConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "azure",
"apiTokens": []string{
"sk-azure-basepath-openai",
},
"azureServiceUrl": "https://basepath-test.openai.azure.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-15-preview",
"basePath": "/azure-gpt4",
"basePathHandling": "removePrefix",
"modelMapping": map[string]string{
"*": "gpt-4",
},
},
})
return data
}()
// 测试配置Azure OpenAI basePath prepend + original协议
var azureBasePathPrependOriginalConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "azure",
"apiTokens": []string{
"sk-azure-prepend-original",
},
"azureServiceUrl": "https://prepend-test.openai.azure.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-15-preview",
"basePath": "/api/v1",
"basePathHandling": "prepend",
"protocol": "original",
},
})
return data
}()
func RunAzureParseConfigTests(t *testing.T) {
test.RunGoTest(t, func(t *testing.T) {
// 测试基本Azure OpenAI配置解析
@@ -682,3 +735,218 @@ func RunAzureOnHttpResponseBodyTests(t *testing.T) {
})
})
}
// RunAzureBasePathHandlingTests verifies basePath handling under different
// protocols, and in particular that the path set in the headers stage survives
// the body stage (TransformRequestBody).
func RunAzureBasePathHandlingTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Core case: basePath removePrefix must keep working under the "original" protocol.
		// Important: this test verifies that the path is still correct AFTER the
		// TransformRequestBody stage. The earlier bug was that transformRequestPath
		// returned originalPath when IsOriginal() was true, so during the body stage
		// the path got overwritten back to the original path containing the basePath.
		t.Run("azure basePath removePrefix with original protocol after body processing", func(t *testing.T) {
			host, status := test.NewTestHost(azureBasePathRemovePrefixOriginalConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Simulate a request whose path carries the basePath prefix.
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/azure-gpt4/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Check the path right after the headers stage (handleRequestHeaders has run).
			headersAfterHeaderStage := host.GetRequestHeaders()
			pathAfterHeaders, _ := test.GetHeaderValue(headersAfterHeaderStage, ":path")
			// After the headers stage the basePath should already be stripped.
			require.NotContains(t, pathAfterHeaders, "/azure-gpt4",
				"After headers stage: basePath should be removed")
			// Run the body stage (TransformRequestBody is invoked here).
			requestBody := `{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}`
			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)
			// Core assertion: validate the path after the body stage.
			// Key check: transformRequestPath inside TransformRequestBody must not
			// overwrite the path back to the original one containing the basePath.
			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)
			pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath, "Path header should exist")
			// The basePath "/azure-gpt4" must not appear in the final path.
			require.NotContains(t, pathValue, "/azure-gpt4",
				"After body stage: basePath should still be removed (not restored by TransformRequestBody)")
			// The path should be the original path with the basePath stripped.
			require.Equal(t, "/v1/chat/completions", pathValue,
				"Path should be the original path without basePath after full request processing")
			// Verify the Host header was rewritten to the Azure service host.
			hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
			require.True(t, hasHost, "Host header should exist")
			require.Equal(t, "basepath-test.openai.azure.com", hostValue)
			// Verify the api-key header carries the configured token.
			apiKeyValue, hasApiKey := test.GetHeaderValue(requestHeaders, "api-key")
			require.True(t, hasApiKey, "api-key header should exist")
			require.Equal(t, "sk-azure-basepath-original", apiKeyValue)
		})
		// basePath removePrefix under the openai protocol: the path is converted
		// to Azure format, but the basePath must still be removed.
		t.Run("azure basePath removePrefix with openai protocol after body processing", func(t *testing.T) {
			host, status := test.NewTestHost(azureBasePathRemovePrefixOpenAIConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Simulate a request whose path carries the basePath prefix.
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/azure-gpt4/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Run the body stage (TransformRequestBody is invoked here).
			requestBody := `{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}`
			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)
			// Validate the request headers after the body stage.
			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)
			// The basePath is removed and the path converted to Azure's path format.
			pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath, "Path header should exist")
			// The basePath "/azure-gpt4" must not appear in the final path.
			require.NotContains(t, pathValue, "/azure-gpt4",
				"After body stage: basePath should be removed from path")
			// Under the openai protocol the path is rewritten to Azure's format.
			require.Contains(t, pathValue, "/openai/deployments/gpt-4/chat/completions",
				"Path should be transformed to Azure format")
			require.Contains(t, pathValue, "api-version=2024-02-15-preview",
				"Path should contain API version")
		})
		// basePath prepend under the "original" protocol: the prepended basePath
		// must survive the body stage.
		t.Run("azure basePath prepend with original protocol after body processing", func(t *testing.T) {
			host, status := test.NewTestHost(azureBasePathPrependOriginalConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Simulate a request without the basePath.
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Run the body stage.
			requestBody := `{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}`
			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)
			// Validate the request headers after the body stage.
			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)
			// The basePath must have been prepended and kept through the body stage.
			pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath, "Path header should exist")
			// The basePath "/api/v1" should appear at the front of the path.
			require.True(t, strings.HasPrefix(pathValue, "/api/v1"),
				"After body stage: Path should still start with prepended basePath")
		})
		// The "original" protocol must leave the request body untouched; path
		// handling is verified at the same time.
		t.Run("azure original protocol preserves request body and path", func(t *testing.T) {
			host, status := test.NewTestHost(azureBasePathRemovePrefixOriginalConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Set request headers.
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/azure-gpt4/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Set a request body containing a custom field.
			requestBody := `{
				"model": "custom-model-name",
				"messages": [{"role": "user", "content": "Hello"}],
				"custom_field": "custom_value"
			}`
			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)
			// The request body must be passed through unchanged.
			transformedBody := host.GetRequestBody()
			require.NotNil(t, transformedBody)
			var bodyMap map[string]interface{}
			err := json.Unmarshal(transformedBody, &bodyMap)
			require.NoError(t, err)
			// The model stays as-is (the "original" protocol performs no model mapping).
			model, exists := bodyMap["model"]
			require.True(t, exists, "Model should exist")
			require.Equal(t, "custom-model-name", model, "Model should remain unchanged")
			// Custom fields stay as-is.
			customField, exists := bodyMap["custom_field"]
			require.True(t, exists, "Custom field should exist")
			require.Equal(t, "custom_value", customField, "Custom field should remain unchanged")
			// Also verify the path is still correct after the body stage.
			requestHeaders := host.GetRequestHeaders()
			pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath, "Path header should exist")
			require.NotContains(t, pathValue, "/azure-gpt4",
				"After body stage: basePath should be removed")
			require.Equal(t, "/v1/chat/completions", pathValue,
				"Path should be correct after body processing")
		})
		// A request without the basePath prefix is unaffected by removePrefix;
		// verify the path is still correct after the body stage.
		t.Run("azure request without basePath prefix after body processing", func(t *testing.T) {
			host, status := test.NewTestHost(azureBasePathRemovePrefixOriginalConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Simulate a request without the basePath prefix.
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Run the body stage.
			requestBody := `{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}`
			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)
			// Validate the request headers after the body stage.
			requestHeaders := host.GetRequestHeaders()
			pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath, "Path header should exist")
			// The path stays unchanged (removePrefix has nothing to strip),
			// and TransformRequestBody must not overwrite it.
			require.Equal(t, "/v1/chat/completions", pathValue,
				"After body stage: Path should remain unchanged when no basePath prefix")
		})
	})
}

View File

@@ -0,0 +1,527 @@
package test
import (
"encoding/json"
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
)
// Test config: Basic Bedrock config with AWS Access Key/Secret Key (AWS Signature V4)
var basicBedrockConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "bedrock",
"awsAccessKey": "test-ak-for-unit-test",
"awsSecretKey": "test-sk-for-unit-test",
"awsRegion": "us-east-1",
"modelMapping": map[string]string{
"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
},
},
})
return data
}()
// Test config: Bedrock config with Bearer Token authentication
var bedrockApiTokenConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "bedrock",
"apiTokens": []string{
"test-token-for-unit-test",
},
"awsRegion": "us-east-1",
"modelMapping": map[string]string{
"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
},
},
})
return data
}()
// Test config: Bedrock config with multiple Bearer Tokens
var bedrockMultiTokenConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "bedrock",
"apiTokens": []string{
"test-token-1-for-unit-test",
"test-token-2-for-unit-test",
},
"awsRegion": "us-west-2",
"modelMapping": map[string]string{
"gpt-4": "anthropic.claude-3-opus-20240229-v1:0",
"*": "anthropic.claude-3-haiku-20240307-v1:0",
},
},
})
return data
}()
// Test config: Bedrock config with additional fields
var bedrockWithAdditionalFieldsConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "bedrock",
"awsAccessKey": "test-ak-for-unit-test",
"awsSecretKey": "test-sk-for-unit-test",
"awsRegion": "us-east-1",
"bedrockAdditionalFields": map[string]interface{}{
"top_k": 200,
},
"modelMapping": map[string]string{
"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
},
},
})
return data
}()
// Test config: Invalid config - missing both apiTokens and ak/sk
var bedrockInvalidConfigMissingAuth = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "bedrock",
"awsRegion": "us-east-1",
"modelMapping": map[string]string{
"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
},
},
})
return data
}()
// Test config: Invalid config - missing region
var bedrockInvalidConfigMissingRegion = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "bedrock",
"apiTokens": []string{
"test-token-for-unit-test",
},
"modelMapping": map[string]string{
"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
},
},
})
return data
}()
// Test config: Invalid config - only has access key without secret key
var bedrockInvalidConfigPartialAkSk = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "bedrock",
"awsAccessKey": "test-ak-for-unit-test",
"awsRegion": "us-east-1",
"modelMapping": map[string]string{
"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
},
},
})
return data
}()
// RunBedrockParseConfigTests verifies that valid Bedrock provider configs are
// accepted at plugin start and that invalid ones (missing auth, missing
// region, partial ak/sk) are rejected.
func RunBedrockParseConfigTests(t *testing.T) {
	test.RunGoTest(t, func(t *testing.T) {
		// Test basic Bedrock config with AWS Signature V4 authentication
		t.Run("basic bedrock config with ak/sk", func(t *testing.T) {
			host, status := test.NewTestHost(basicBedrockConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			config, err := host.GetMatchConfig()
			require.NoError(t, err)
			require.NotNil(t, config)
		})
		// Test Bedrock config with Bearer Token authentication
		t.Run("bedrock config with api token", func(t *testing.T) {
			host, status := test.NewTestHost(bedrockApiTokenConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			config, err := host.GetMatchConfig()
			require.NoError(t, err)
			require.NotNil(t, config)
		})
		// Test Bedrock config with multiple tokens
		t.Run("bedrock config with multiple tokens", func(t *testing.T) {
			host, status := test.NewTestHost(bedrockMultiTokenConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			config, err := host.GetMatchConfig()
			require.NoError(t, err)
			require.NotNil(t, config)
		})
		// Test Bedrock config with additional fields
		t.Run("bedrock config with additional fields", func(t *testing.T) {
			host, status := test.NewTestHost(bedrockWithAdditionalFieldsConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			config, err := host.GetMatchConfig()
			require.NoError(t, err)
			require.NotNil(t, config)
		})
		// Test invalid config - missing authentication
		t.Run("bedrock invalid config missing auth", func(t *testing.T) {
			host, status := test.NewTestHost(bedrockInvalidConfigMissingAuth)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusFailed, status)
		})
		// Test invalid config - missing region
		t.Run("bedrock invalid config missing region", func(t *testing.T) {
			host, status := test.NewTestHost(bedrockInvalidConfigMissingRegion)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusFailed, status)
		})
		// Test invalid config - partial ak/sk (only access key, no secret key)
		t.Run("bedrock invalid config partial ak/sk", func(t *testing.T) {
			host, status := test.NewTestHost(bedrockInvalidConfigPartialAkSk)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusFailed, status)
		})
	})
}
// RunBedrockOnHttpRequestHeadersTests verifies request-header processing for
// the Bedrock provider: under both auth modes the :authority header must be
// rewritten to the regional Bedrock runtime endpoint.
func RunBedrockOnHttpRequestHeadersTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Test Bedrock request headers processing with AWS Signature V4
		t.Run("bedrock chat completion request headers with ak/sk", func(t *testing.T) {
			host, status := test.NewTestHost(basicBedrockConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Set request headers
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Verify request headers
			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)
			// Verify Host is changed to Bedrock service domain
			hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
			require.True(t, hasHost, "Host header should exist")
			require.Contains(t, hostValue, "bedrock-runtime.us-east-1.amazonaws.com", "Host should be changed to Bedrock service domain")
		})
		// Test Bedrock request headers processing with Bearer Token
		t.Run("bedrock chat completion request headers with api token", func(t *testing.T) {
			host, status := test.NewTestHost(bedrockApiTokenConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Set request headers
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Verify request headers
			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)
			// Verify Host is changed to Bedrock service domain
			hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
			require.True(t, hasHost, "Host header should exist")
			require.Contains(t, hostValue, "bedrock-runtime.us-east-1.amazonaws.com", "Host should be changed to Bedrock service domain")
		})
	})
}
// RunBedrockOnHttpRequestBodyTests verifies request-body processing for the
// Bedrock provider: path rewriting to the /model/.../converse(-stream)
// endpoints, Bearer-token authorization, and AWS Signature V4 signing.
func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Test Bedrock request body processing with Bearer Token authentication
		t.Run("bedrock chat completion request body with api token", func(t *testing.T) {
			host, status := test.NewTestHost(bedrockApiTokenConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Set request headers
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Set request body
			requestBody := `{
				"model": "gpt-4",
				"messages": [
					{
						"role": "user",
						"content": "Hello, how are you?"
					}
				],
				"temperature": 0.7
			}`
			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)
			// Verify request headers for Bearer Token authentication
			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)
			// Verify Authorization header uses Bearer token
			authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
			require.True(t, hasAuth, "Authorization header should exist")
			require.Contains(t, authValue, "Bearer ", "Authorization should use Bearer token")
			require.Contains(t, authValue, "test-token-for-unit-test", "Authorization should contain the configured token")
			// Verify path is transformed to Bedrock format
			pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath, "Path header should exist")
			require.Contains(t, pathValue, "/model/", "Path should contain Bedrock model path")
			require.Contains(t, pathValue, "/converse", "Path should contain converse endpoint")
		})
		// Test Bedrock request body processing with AWS Signature V4 authentication
		t.Run("bedrock chat completion request body with ak/sk", func(t *testing.T) {
			host, status := test.NewTestHost(basicBedrockConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Set request headers
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Set request body
			requestBody := `{
				"model": "gpt-4",
				"messages": [
					{
						"role": "user",
						"content": "Hello, how are you?"
					}
				],
				"temperature": 0.7
			}`
			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)
			// Verify request headers for AWS Signature V4 authentication
			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)
			// Verify Authorization header uses AWS Signature
			authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
			require.True(t, hasAuth, "Authorization header should exist")
			require.Contains(t, authValue, "AWS4-HMAC-SHA256", "Authorization should use AWS4-HMAC-SHA256 signature")
			require.Contains(t, authValue, "Credential=", "Authorization should contain Credential")
			require.Contains(t, authValue, "Signature=", "Authorization should contain Signature")
			// Verify X-Amz-Date header exists
			dateValue, hasDate := test.GetHeaderValue(requestHeaders, "X-Amz-Date")
			require.True(t, hasDate, "X-Amz-Date header should exist for AWS Signature V4")
			require.NotEmpty(t, dateValue, "X-Amz-Date should not be empty")
			// Verify path is transformed to Bedrock format
			pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath, "Path header should exist")
			require.Contains(t, pathValue, "/model/", "Path should contain Bedrock model path")
			require.Contains(t, pathValue, "/converse", "Path should contain converse endpoint")
		})
		// Test Bedrock streaming request
		t.Run("bedrock streaming request", func(t *testing.T) {
			host, status := test.NewTestHost(bedrockApiTokenConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Set request headers
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Set streaming request body
			requestBody := `{
				"model": "gpt-4",
				"messages": [
					{
						"role": "user",
						"content": "Hello"
					}
				],
				"stream": true
			}`
			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)
			// Verify path is transformed to Bedrock streaming format
			requestHeaders := host.GetRequestHeaders()
			pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath, "Path header should exist")
			require.Contains(t, pathValue, "/model/", "Path should contain Bedrock model path")
			require.Contains(t, pathValue, "/converse-stream", "Path should contain converse-stream endpoint for streaming")
		})
	})
}
// RunBedrockOnHttpResponseHeadersTests verifies response-header processing for
// the Bedrock provider after a full request (headers + body) has been staged.
func RunBedrockOnHttpResponseHeadersTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Test Bedrock response headers processing
		t.Run("bedrock response headers", func(t *testing.T) {
			host, status := test.NewTestHost(bedrockApiTokenConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Set request headers
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Set request body
			requestBody := `{
				"model": "gpt-4",
				"messages": [
					{
						"role": "user",
						"content": "Hello"
					}
				]
			}`
			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)
			// Process response headers
			action = host.CallOnHttpResponseHeaders([][2]string{
				{":status", "200"},
				{"Content-Type", "application/json"},
				{"X-Amzn-Requestid", "test-request-id-12345"},
			})
			require.Equal(t, types.ActionContinue, action)
			// Verify response headers
			responseHeaders := host.GetResponseHeaders()
			require.NotNil(t, responseHeaders)
			// Verify status code
			statusValue, hasStatus := test.GetHeaderValue(responseHeaders, ":status")
			require.True(t, hasStatus, "Status header should exist")
			require.Equal(t, "200", statusValue, "Status should be 200")
		})
	})
}
// RunBedrockOnHttpResponseBodyTests verifies that a Bedrock "converse"
// response body is transformed into the OpenAI chat-completion format
// (choices + usage) on the response path.
func RunBedrockOnHttpResponseBodyTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Test Bedrock response body processing
		t.Run("bedrock response body", func(t *testing.T) {
			host, status := test.NewTestHost(bedrockApiTokenConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Set request headers
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Set request body
			requestBody := `{
				"model": "gpt-4",
				"messages": [
					{
						"role": "user",
						"content": "Hello"
					}
				]
			}`
			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)
			// Set response property to ensure IsResponseFromUpstream() returns true
			host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
			// Process response headers (must include :status 200 for body processing)
			action = host.CallOnHttpResponseHeaders([][2]string{
				{":status", "200"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.ActionContinue, action)
			// Process response body (Bedrock format)
			responseBody := `{
				"output": {
					"message": {
						"role": "assistant",
						"content": [
							{
								"text": "Hello! How can I help you today?"
							}
						]
					}
				},
				"stopReason": "end_turn",
				"usage": {
					"inputTokens": 10,
					"outputTokens": 15,
					"totalTokens": 25
				}
			}`
			action = host.CallOnHttpResponseBody([]byte(responseBody))
			require.Equal(t, types.ActionContinue, action)
			// Verify response body is transformed to OpenAI format
			transformedResponseBody := host.GetResponseBody()
			require.NotNil(t, transformedResponseBody)
			var responseMap map[string]interface{}
			err := json.Unmarshal(transformedResponseBody, &responseMap)
			require.NoError(t, err)
			// Verify choices exist in transformed response
			choices, exists := responseMap["choices"]
			require.True(t, exists, "Choices should exist in response body")
			require.NotNil(t, choices, "Choices should not be nil")
			// Verify usage exists
			usage, exists := responseMap["usage"]
			require.True(t, exists, "Usage should exist in response body")
			require.NotNil(t, usage, "Usage should not be nil")
		})
	})
}

View File

@@ -213,7 +213,9 @@ func RunFireworksOnHttpRequestHeadersTests(t *testing.T) {
{":method", "GET"},
})
require.Equal(t, types.ActionContinue, action)
// TODO: Due to the limitations of the test framework, we just treat it as a request with body here.
//require.Equal(t, types.ActionContinue, action)
require.Equal(t, types.HeaderStopIteration, action)
// 验证请求头处理
requestHeaders := host.GetRequestHeaders()

View File

@@ -0,0 +1,239 @@
package test
import (
"encoding/json"
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
)
// 通用测试配置:最简配置,覆盖 host 与 token 注入。
var genericBasicConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "generic",
"apiTokens": []string{"sk-generic-basic"},
"genericHost": "generic.backend.internal",
},
})
return data
}()
// 通用测试配置:开启 basePath removePrefix。
var genericBasePathConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "generic",
"apiTokens": []string{"sk-generic-basepath"},
"genericHost": "basepath.backend.internal",
"basePath": "/proxy",
"basePathHandling": "removePrefix",
},
})
return data
}()
// 通用测试配置:开启 basePath prepend。
var genericPrependBasePathConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "generic",
"apiTokens": []string{"sk-generic-prepend"},
"genericHost": "prepend.backend.internal",
"basePath": "/custom",
"basePathHandling": "prepend",
},
})
return data
}()
// 通用测试配置:覆盖 firstByteTimeout用于流式能力验证。
var genericStreamingConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "generic",
"apiTokens": []string{"sk-generic-stream"},
"genericHost": "stream.backend.internal",
"firstByteTimeout": 1500,
},
})
return data
}()
// 通用测试配置:无 token也不设置 host。
var genericNoTokenConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "generic",
},
})
return data
}()
// RunGenericParseConfigTests verifies that generic provider configs parse
// successfully at plugin start, including configs without a token and with
// streaming-related options.
func RunGenericParseConfigTests(t *testing.T) {
	test.RunGoTest(t, func(t *testing.T) {
		t.Run("generic basic config", func(t *testing.T) {
			host, status := test.NewTestHost(genericBasicConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			config, err := host.GetMatchConfig()
			require.NoError(t, err)
			require.NotNil(t, config)
		})
		// A missing token is allowed for the generic provider.
		t.Run("generic config without token", func(t *testing.T) {
			host, status := test.NewTestHost(genericNoTokenConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
		})
		t.Run("generic config with streaming options", func(t *testing.T) {
			host, status := test.NewTestHost(genericStreamingConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			config, err := host.GetMatchConfig()
			require.NoError(t, err)
			require.NotNil(t, config)
		})
	})
}
// RunGenericOnHttpRequestHeadersTests exercises the generic provider's
// request-header phase: auth token injection, host rewriting, basePath
// removePrefix/prepend handling, and the first-byte-timeout header.
func RunGenericOnHttpRequestHeadersTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// The provider should rewrite :authority to the configured genericHost,
		// inject the token as a Bearer Authorization header, and drop
		// Content-Length (the body may be rewritten later).
		t.Run("generic injects token and custom host", func(t *testing.T) {
			host, status := test.NewTestHost(genericBasicConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "client.local"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			requestHeaders := host.GetRequestHeaders()
			require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "generic.backend.internal"))
			require.True(t, test.HasHeaderWithValue(requestHeaders, "Authorization", "Bearer sk-generic-basic"))
			_, hasContentLength := test.GetHeaderValue(requestHeaders, "Content-Length")
			require.False(t, hasContentLength, "generic provider should remove Content-Length")
		})
		// removePrefix mode: the configured basePath is stripped from :path.
		t.Run("generic removes basePath prefix", func(t *testing.T) {
			host, status := test.NewTestHost(genericBasePathConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "client.local"},
				{":path", "/proxy/service/echo"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			requestHeaders := host.GetRequestHeaders()
			require.True(t, test.HasHeaderWithValue(requestHeaders, ":path", "/service/echo"))
			require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "basepath.backend.internal"))
		})
		// prepend mode: the configured basePath is prefixed onto :path.
		t.Run("generic prepends basePath when configured", func(t *testing.T) {
			host, status := test.NewTestHost(genericPrependBasePathConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "client.local"},
				{":path", "/v1/echo"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			requestHeaders := host.GetRequestHeaders()
			require.True(t, test.HasHeaderWithValue(requestHeaders, ":path", "/custom/v1/echo"))
		})
		// firstByteTimeout should inject only the Envoy timeout header and
		// leave other headers (e.g. Accept) alone.
		t.Run("generic firstByteTimeout injects timeout header only", func(t *testing.T) {
			host, status := test.NewTestHost(genericStreamingConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "client.local"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			requestHeaders := host.GetRequestHeaders()
			require.True(t, test.HasHeaderWithValue(requestHeaders, "x-envoy-upstream-rq-first-byte-timeout-ms", "1500"))
			_, hasAccept := test.GetHeaderValue(requestHeaders, "Accept")
			require.False(t, hasAccept, "Accept header should remain untouched when enabling firstByteTimeout")
		})
	})
}
// RunGenericOnHttpRequestBodyTests exercises the generic provider's
// request-body phase: the body passes through unchanged, and header behaviour
// established in the header phase (timeout header present/absent) is preserved.
func RunGenericOnHttpRequestBodyTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// With firstByteTimeout configured: body untouched, timeout header kept.
		t.Run("generic body passthrough keeps headers unchanged with timeout", func(t *testing.T) {
			host, status := test.NewTestHost(genericStreamingConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "client.local"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			body := `{"model":"gpt-any","stream":true}`
			action := host.CallOnHttpRequestBody([]byte(body))
			require.Equal(t, types.ActionContinue, action)
			requestHeaders := host.GetRequestHeaders()
			require.True(t, test.HasHeaderWithValue(requestHeaders, "x-envoy-upstream-rq-first-byte-timeout-ms", "1500"))
			_, hasAccept := test.GetHeaderValue(requestHeaders, "Accept")
			require.False(t, hasAccept, "Accept header should remain untouched even when firstByteTimeout is enabled")
			processedBody := host.GetRequestBody()
			require.JSONEq(t, body, string(processedBody))
		})
		// Without firstByteTimeout: neither the timeout header nor Accept appears.
		t.Run("generic without first byte timeout keeps headers untouched", func(t *testing.T) {
			host, status := test.NewTestHost(genericBasicConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "client.local"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			body := `{"model":"gpt-any","stream":true}`
			action := host.CallOnHttpRequestBody([]byte(body))
			require.Equal(t, types.ActionContinue, action)
			requestHeaders := host.GetRequestHeaders()
			_, hasAccept := test.GetHeaderValue(requestHeaders, "Accept")
			require.False(t, hasAccept, "Accept header should remain untouched when first byte timeout is disabled")
			_, hasTimeout := test.GetHeaderValue(requestHeaders, "x-envoy-upstream-rq-first-byte-timeout-ms")
			require.False(t, hasTimeout, "timeout header should not be added when first byte timeout is disabled")
			processedBody := host.GetRequestBody()
			require.JSONEq(t, body, string(processedBody))
		})
	})
}

View File

@@ -0,0 +1,251 @@
package test
import (
"encoding/json"
"strings"
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
)
// Test config: Minimax Pro API + basePath removePrefix + "original" protocol.
var minimaxProBasePathRemovePrefixOriginalConfig = func() json.RawMessage {
	provider := map[string]interface{}{
		"type":             "minimax",
		"apiTokens":        []string{"sk-minimax-test"},
		"minimaxApiType":   "pro",
		"minimaxGroupId":   "test-group-id",
		"basePath":         "/minimax-api",
		"basePathHandling": "removePrefix",
		"protocol":         "original",
	}
	raw, _ := json.Marshal(map[string]interface{}{"provider": provider})
	return raw
}()
// Test config: Minimax Pro API + basePath removePrefix + default (openai) protocol.
var minimaxProBasePathRemovePrefixOpenAIConfig = func() json.RawMessage {
	provider := map[string]interface{}{
		"type":             "minimax",
		"apiTokens":        []string{"sk-minimax-openai"},
		"minimaxApiType":   "pro",
		"minimaxGroupId":   "test-group-id",
		"basePath":         "/minimax-api",
		"basePathHandling": "removePrefix",
	}
	raw, _ := json.Marshal(map[string]interface{}{"provider": provider})
	return raw
}()
// Test config: Minimax V2 API + basePath removePrefix + "original" protocol.
var minimaxV2BasePathRemovePrefixOriginalConfig = func() json.RawMessage {
	provider := map[string]interface{}{
		"type":             "minimax",
		"apiTokens":        []string{"sk-minimax-v2"},
		"minimaxApiType":   "v2",
		"basePath":         "/minimax-v2",
		"basePathHandling": "removePrefix",
		"protocol":         "original",
	}
	raw, _ := json.Marshal(map[string]interface{}{"provider": provider})
	return raw
}()
// RunMinimaxBasePathHandlingTests verifies Minimax basePath handling under
// different protocols ("original" vs. the default openai protocol) and API
// types (pro vs. v2), with basePathHandling set to removePrefix.
func RunMinimaxBasePathHandlingTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Core case: Minimax Pro API + basePath removePrefix + "original" protocol.
		// Important: this verifies the path stays correct even after the
		// handleRequestBodyByChatCompletionPro stage ran. A previous bug had
		// handleRequestBodyByChatCompletionPro unconditionally overwrite the
		// path, so during the body stage it was reset to minimaxChatCompletionProPath.
		t.Run("minimax pro basePath removePrefix with original protocol after body processing", func(t *testing.T) {
			host, status := test.NewTestHost(minimaxProBasePathRemovePrefixOriginalConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Simulate a request that carries the basePath prefix.
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/minimax-api/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Check the path right after the header stage (handleRequestHeaders has run).
			headersAfterHeaderStage := host.GetRequestHeaders()
			pathAfterHeaders, _ := test.GetHeaderValue(headersAfterHeaderStage, ":path")
			// After the header stage the basePath must already be stripped.
			require.NotContains(t, pathAfterHeaders, "/minimax-api",
				"After headers stage: basePath should be removed")
			// Run the body stage (handleRequestBodyByChatCompletionPro is invoked here).
			requestBody := `{"model": "abab5.5-chat", "messages": [{"role": "user", "content": "Hello"}]}`
			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)
			// Core assertion: verify the path after the body stage. This is the
			// key point — handleRequestBodyByChatCompletionPro must not
			// overwrite the path back to minimaxChatCompletionProPath.
			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)
			pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath, "Path header should exist")
			// basePath "/minimax-api" must not appear in the final path.
			require.NotContains(t, pathValue, "/minimax-api",
				"After body stage: basePath should still be removed")
			// Under the original protocol the path must not be rewritten to the pro path.
			require.NotContains(t, pathValue, "chatcompletion_pro",
				"With original protocol: path should not be overwritten to minimax pro path")
			// The path should simply be the original path minus the basePath.
			require.Equal(t, "/v1/chat/completions", pathValue,
				"Path should be the original path without basePath after full request processing")
			// Host must be rewritten to the Minimax endpoint.
			hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
			require.True(t, hasHost, "Host header should exist")
			require.Equal(t, "api.minimax.chat", hostValue)
			// Authorization must carry the configured token.
			authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
			require.True(t, hasAuth, "Authorization header should exist")
			require.Equal(t, "Bearer sk-minimax-test", authValue)
		})
		// Minimax Pro API + basePath removePrefix + default protocol (openai):
		// here the path IS expected to be rewritten to minimaxChatCompletionProPath.
		t.Run("minimax pro basePath removePrefix with openai protocol after body processing", func(t *testing.T) {
			host, status := test.NewTestHost(minimaxProBasePathRemovePrefixOpenAIConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Simulate a request that carries the basePath prefix.
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/minimax-api/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Run the body stage.
			requestBody := `{"model": "abab5.5-chat", "messages": [{"role": "user", "content": "Hello"}]}`
			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)
			// Inspect headers after the body stage.
			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)
			pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath, "Path header should exist")
			// basePath "/minimax-api" must not appear in the final path.
			require.NotContains(t, pathValue, "/minimax-api",
				"After body stage: basePath should be removed from path")
			// Under the openai protocol the path is rewritten to the pro path.
			require.True(t, strings.Contains(pathValue, "chatcompletion_pro"),
				"With openai protocol: path should be overwritten to minimax pro path")
			require.Contains(t, pathValue, "GroupId=test-group-id",
				"Path should contain GroupId parameter")
		})
		// Minimax V2 API + basePath removePrefix + "original" protocol.
		// V2 uses handleRequestBody rather than handleRequestBodyByChatCompletionPro.
		t.Run("minimax v2 basePath removePrefix with original protocol after body processing", func(t *testing.T) {
			host, status := test.NewTestHost(minimaxV2BasePathRemovePrefixOriginalConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Simulate a request that carries the basePath prefix.
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/minimax-v2/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Run the body stage.
			requestBody := `{"model": "abab5.5-chat", "messages": [{"role": "user", "content": "Hello"}]}`
			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)
			// Inspect headers after the body stage.
			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)
			pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath, "Path header should exist")
			// basePath "/minimax-v2" must not appear in the final path.
			require.NotContains(t, pathValue, "/minimax-v2",
				"After body stage: basePath should be removed from path")
			// The path should simply be the original path minus the basePath.
			require.Equal(t, "/v1/chat/completions", pathValue,
				"Path should be the original path without basePath")
		})
		// Under the original protocol the request body must pass through unmodified.
		t.Run("minimax pro original protocol preserves request body and path", func(t *testing.T) {
			host, status := test.NewTestHost(minimaxProBasePathRemovePrefixOriginalConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Set request headers.
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/minimax-api/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Request body containing a custom field that must survive untouched.
			requestBody := `{
				"model": "custom-model",
				"messages": [{"role": "user", "content": "Hello"}],
				"custom_field": "custom_value"
			}`
			action = host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)
			// The body must be preserved as-is.
			transformedBody := host.GetRequestBody()
			require.NotNil(t, transformedBody)
			var bodyMap map[string]interface{}
			err := json.Unmarshal(transformedBody, &bodyMap)
			require.NoError(t, err)
			// model stays unchanged (original protocol performs no model mapping).
			model, exists := bodyMap["model"]
			require.True(t, exists, "Model should exist")
			require.Equal(t, "custom-model", model, "Model should remain unchanged")
			// The custom field stays unchanged too.
			customField, exists := bodyMap["custom_field"]
			require.True(t, exists, "Custom field should exist")
			require.Equal(t, "custom_value", customField, "Custom field should remain unchanged")
			// And the path must still be correct after the body stage.
			requestHeaders := host.GetRequestHeaders()
			pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
			require.True(t, hasPath, "Path header should exist")
			require.NotContains(t, pathValue, "/minimax-api",
				"After body stage: basePath should be removed")
			require.Equal(t, "/v1/chat/completions", pathValue,
				"Path should be correct after body processing")
		})
	})
}

View File

@@ -0,0 +1,116 @@
package test
import (
"testing"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
)
// RunMapRequestPathByCapabilityTests covers util.MapRequestPathByCapability:
// capability-to-path mapping, {placeholder} substitution from the original
// path, and query-string preservation/merging.
func RunMapRequestPathByCapabilityTests(t *testing.T) {
	cases := []struct {
		name    string
		apiName string
		origin  string
		mapping map[string]string
		want    string
	}{
		{
			name:    "no mapping returns empty",
			apiName: "openai/v1/chatcompletions",
			origin:  "/v1/chat/completions",
			mapping: map[string]string{},
			want:    "",
		},
		{
			name:    "file placeholder is replaced",
			apiName: "openai/v1/retrievefile",
			origin:  "/openai/v1/files/file-abc",
			mapping: map[string]string{
				"openai/v1/retrievefile": "/v1/files/{file_id}",
			},
			want: "/v1/files/file-abc",
		},
		{
			name:    "file content keeps query parameters",
			apiName: "openai/v1/retrievefilecontent",
			origin:  "/openai/v1/files/file-123/content?variant=thumbnail",
			mapping: map[string]string{
				"openai/v1/retrievefilecontent": "/v1/files/{file_id}/content",
			},
			want: "/v1/files/file-123/content?variant=thumbnail",
		},
		{
			name:    "file content merges query string with mapped query",
			apiName: "openai/v1/retrievefilecontent",
			origin:  "/openai/v1/files/file-123/content?variant=thumbnail",
			mapping: map[string]string{
				"openai/v1/retrievefilecontent": "/v1/files/{file_id}/content?download=1",
			},
			want: "/v1/files/file-123/content?download=1&variant=thumbnail",
		},
		{
			name:    "retrieve batch replaces batch id",
			apiName: "openai/v1/retrievebatch",
			origin:  "/openai/v1/batches/batch-001",
			mapping: map[string]string{
				"openai/v1/retrievebatch": "/v1/batches/{batch_id}",
			},
			want: "/v1/batches/batch-001",
		},
		{
			name:    "cancel batch replaces batch id",
			apiName: "openai/v1/cancelbatch",
			origin:  "/openai/v1/batches/batch-002/cancel",
			mapping: map[string]string{
				"openai/v1/cancelbatch": "/v1/batches/{batch_id}/cancel",
			},
			want: "/v1/batches/batch-002/cancel",
		},
		{
			name:    "video placeholder is replaced",
			apiName: "openai/v1/retrievevideo",
			origin:  "/openai/v1/videos/video-xyz",
			mapping: map[string]string{
				"openai/v1/retrievevideo": "/v1/videos/{video_id}",
			},
			want: "/v1/videos/video-xyz",
		},
		{
			name:    "video content placeholder with query",
			apiName: "openai/v1/retrievevideocontent",
			origin:  "/openai/v1/videos/video-xyz/content?variant=thumbnail",
			mapping: map[string]string{
				"openai/v1/retrievevideocontent": "/v1/videos/{video_id}/content",
			},
			want: "/v1/videos/video-xyz/content?variant=thumbnail",
		},
		{
			name:    "video remix placeholder is replaced",
			apiName: "openai/v1/videoremix",
			origin:  "/openai/v1/videos/video-xyz/remix",
			mapping: map[string]string{
				"openai/v1/videoremix": "/v1/videos/{video_id}/remix",
			},
			want: "/v1/videos/video-xyz/remix",
		},
		{
			name:    "non placeholder mapping returns mapped path directly",
			apiName: "openai/v1/videos",
			origin:  "/openai/v1/videos",
			mapping: map[string]string{
				"openai/v1/videos": "/v1/videos",
			},
			want: "/v1/videos",
		},
	}
	for _, c := range cases {
		c := c
		t.Run(c.name, func(t *testing.T) {
			if got := util.MapRequestPathByCapability(c.apiName, c.origin, c.mapping); got != c.want {
				t.Fatalf("expected %q, got %q", c.want, got)
			}
		})
	}
}

View File

@@ -0,0 +1,888 @@
package test
import (
"encoding/json"
"strings"
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
)
// Test config: Vertex standard mode (service-account auth via vertexAuthKey).
var basicVertexConfig = func() json.RawMessage {
	provider := map[string]interface{}{
		"type":                  "vertex",
		"vertexAuthKey":         `{"type":"service_account","client_email":"test@test.iam.gserviceaccount.com","private_key":"-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7k1v5C7y8L4SN\n-----END PRIVATE KEY-----\n","token_uri":"https://oauth2.googleapis.com/token"}`,
		"vertexRegion":          "us-central1",
		"vertexProjectId":       "test-project-id",
		"vertexAuthServiceName": "test-auth-service",
	}
	raw, _ := json.Marshal(map[string]interface{}{"provider": provider})
	return raw
}()
// Test config: Vertex Express Mode (API-key auth via apiTokens).
var vertexExpressModeConfig = func() json.RawMessage {
	provider := map[string]interface{}{
		"type":      "vertex",
		"apiTokens": []string{"test-api-key-123456789"},
	}
	raw, _ := json.Marshal(map[string]interface{}{"provider": provider})
	return raw
}()
// Test config: Vertex Express Mode with an OpenAI-to-Gemini model mapping.
var vertexExpressModeWithModelMappingConfig = func() json.RawMessage {
	provider := map[string]interface{}{
		"type":      "vertex",
		"apiTokens": []string{"test-api-key-123456789"},
		"modelMapping": map[string]string{
			"gpt-4":                  "gemini-2.5-flash",
			"gpt-3.5-turbo":          "gemini-2.5-flash-lite",
			"text-embedding-ada-002": "text-embedding-001",
		},
	}
	raw, _ := json.Marshal(map[string]interface{}{"provider": provider})
	return raw
}()
// Test config: Vertex Express Mode with Gemini safety settings.
var vertexExpressModeWithSafetyConfig = func() json.RawMessage {
	provider := map[string]interface{}{
		"type":      "vertex",
		"apiTokens": []string{"test-api-key-123456789"},
		"geminiSafetySetting": map[string]string{
			"HARM_CATEGORY_HARASSMENT":        "BLOCK_MEDIUM_AND_ABOVE",
			"HARM_CATEGORY_HATE_SPEECH":       "BLOCK_LOW_AND_ABOVE",
			"HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
		},
	}
	raw, _ := json.Marshal(map[string]interface{}{"provider": provider})
	return raw
}()
// Test config: invalid Vertex standard mode — vertexAuthKey (and every other
// required standard-mode field) is deliberately omitted.
var invalidVertexStandardModeConfig = func() json.RawMessage {
	provider := map[string]interface{}{"type": "vertex"}
	raw, _ := json.Marshal(map[string]interface{}{"provider": provider})
	return raw
}()
// Test config: Vertex OpenAI-compatible mode (standard auth + vertexOpenAICompatible).
var vertexOpenAICompatibleModeConfig = func() json.RawMessage {
	provider := map[string]interface{}{
		"type":                   "vertex",
		"vertexOpenAICompatible": true,
		"vertexAuthKey":          `{"type":"service_account","client_email":"test@test.iam.gserviceaccount.com","private_key":"-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7k1v5C7y8L4SN\n-----END PRIVATE KEY-----\n","token_uri":"https://oauth2.googleapis.com/token"}`,
		"vertexRegion":           "us-central1",
		"vertexProjectId":        "test-project-id",
		"vertexAuthServiceName":  "test-auth-service",
	}
	raw, _ := json.Marshal(map[string]interface{}{"provider": provider})
	return raw
}()
// Test config: Vertex OpenAI-compatible mode with model mapping.
var vertexOpenAICompatibleModeWithModelMappingConfig = func() json.RawMessage {
	provider := map[string]interface{}{
		"type":                   "vertex",
		"vertexOpenAICompatible": true,
		"vertexAuthKey":          `{"type":"service_account","client_email":"test@test.iam.gserviceaccount.com","private_key":"-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7k1v5C7y8L4SN\n-----END PRIVATE KEY-----\n","token_uri":"https://oauth2.googleapis.com/token"}`,
		"vertexRegion":           "us-central1",
		"vertexProjectId":        "test-project-id",
		"vertexAuthServiceName":  "test-auth-service",
		"modelMapping": map[string]string{
			"gpt-4":         "gemini-2.0-flash",
			"gpt-3.5-turbo": "gemini-1.5-flash",
		},
	}
	raw, _ := json.Marshal(map[string]interface{}{"provider": provider})
	return raw
}()
// Test config: invalid — Express Mode (apiTokens) and OpenAI-compatible mode
// are mutually exclusive, so setting both must fail config parsing.
var invalidVertexExpressAndOpenAICompatibleConfig = func() json.RawMessage {
	provider := map[string]interface{}{
		"type":                   "vertex",
		"apiTokens":              []string{"test-api-key"},
		"vertexOpenAICompatible": true,
	}
	raw, _ := json.Marshal(map[string]interface{}{"provider": provider})
	return raw
}()
// RunVertexParseConfigTests verifies config parsing for the Vertex provider in
// standard, Express and OpenAI-compatible modes, including invalid combinations.
func RunVertexParseConfigTests(t *testing.T) {
	test.RunGoTest(t, func(t *testing.T) {
		cases := []struct {
			name   string
			config json.RawMessage
			wantOK bool // expected plugin-start outcome
		}{
			{"vertex standard mode config", basicVertexConfig, true},
			{"vertex express mode config", vertexExpressModeConfig, true},
			{"vertex express mode with model mapping config", vertexExpressModeWithModelMappingConfig, true},
			{"invalid vertex standard mode config - missing auth key", invalidVertexStandardModeConfig, false},
			{"vertex express mode with safety setting config", vertexExpressModeWithSafetyConfig, true},
			{"vertex openai compatible mode config", vertexOpenAICompatibleModeConfig, true},
			{"vertex openai compatible mode with model mapping config", vertexOpenAICompatibleModeWithModelMappingConfig, true},
			{"invalid config - express mode and openai compatible mode conflict", invalidVertexExpressAndOpenAICompatibleConfig, false},
		}
		for _, tc := range cases {
			tc := tc
			t.Run(tc.name, func(t *testing.T) {
				host, status := test.NewTestHost(tc.config)
				defer host.Reset()
				if !tc.wantOK {
					require.Equal(t, types.OnPluginStartStatusFailed, status)
					return
				}
				require.Equal(t, types.OnPluginStartStatusOK, status)
				config, err := host.GetMatchConfig()
				require.NoError(t, err)
				require.NotNil(t, config)
			})
		}
	})
}
// RunVertexExpressModeOnHttpRequestHeadersTests verifies the request-header
// phase in Vertex Express Mode (API-key auth, host without a region prefix).
func RunVertexExpressModeOnHttpRequestHeadersTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Express Mode header handling for the chat-completion endpoint.
		t.Run("vertex express mode chat completion request headers", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Send the request headers.
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			// HeaderStopIteration is expected because the body still needs processing.
			require.Equal(t, types.HeaderStopIteration, action)
			// Verify the headers were rewritten.
			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)
			// Express Mode targets the Vertex domain WITHOUT a region prefix.
			require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "aiplatform.googleapis.com"), "Host header should be changed to vertex domain without region prefix")
			// Confirm the vertex provider actually produced processing logs.
			debugLogs := host.GetDebugLogs()
			hasVertexLogs := false
			for _, log := range debugLogs {
				if strings.Contains(log, "vertex") {
					hasVertexLogs = true
					break
				}
			}
			require.True(t, hasVertexLogs, "Should have vertex processing logs")
		})
		// Express Mode header handling for the embeddings endpoint.
		t.Run("vertex express mode embeddings request headers", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Send the request headers.
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/embeddings"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			require.Equal(t, types.HeaderStopIteration, action)
			// Verify header handling for the embeddings endpoint.
			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)
			// Host must again be the region-less Vertex domain.
			require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "aiplatform.googleapis.com"), "Host header should be changed to vertex domain")
		})
	})
}
// RunVertexExpressModeOnHttpRequestBodyTests verifies the request-body phase
// in Vertex Express Mode: OpenAI-to-Vertex payload conversion, API-key-in-URL
// authentication, streaming path selection, and model mapping.
//
// Consistency fix: the original hand-rolled `for _, header := range …` scans
// for ":path"/"Authorization" duplicated the test.GetHeaderValue helper used
// everywhere else in this suite; they are replaced with the helper here.
func RunVertexExpressModeOnHttpRequestBodyTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Chat-completion body: converted to Vertex format and authenticated
		// via a "key=" query parameter instead of an Authorization header.
		t.Run("vertex express mode chat completion request body", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Run the header phase first so the provider records the API context.
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}]}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))
			// Express Mode needs no OAuth token fetch, so processing continues.
			require.Equal(t, types.ActionContinue, action)
			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)
			// The body must be converted to the Vertex request schema.
			require.Contains(t, string(processedBody), "contents", "Request should be converted to vertex format")
			require.Contains(t, string(processedBody), "generationConfig", "Request should contain vertex generation config")
			// The path carries the API key and uses the Express Mode layout.
			requestHeaders := host.GetRequestHeaders()
			pathHeader, _ := test.GetHeaderValue(requestHeaders, ":path")
			require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key as query parameter")
			require.Contains(t, pathHeader, "/v1/publishers/google/models/", "Path should use Express Mode format without project/location")
			// Express Mode authenticates via the URL, so no Authorization header.
			authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
			require.False(t, hasAuth && authValue != "", "Authorization header should be removed in Express Mode")
			// Sanity-check that the vertex provider actually ran.
			debugLogs := host.GetDebugLogs()
			hasVertexLogs := false
			for _, log := range debugLogs {
				if strings.Contains(log, "vertex") {
					hasVertexLogs = true
					break
				}
			}
			require.True(t, hasVertexLogs, "Should have vertex processing logs")
		})
		// Embeddings body: converted to the Vertex "instances" schema.
		t.Run("vertex express mode embeddings request body", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Header phase first.
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/embeddings"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			requestBody := `{"model":"text-embedding-001","input":"test text"}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)
			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)
			require.Contains(t, string(processedBody), "instances", "Request should be converted to vertex format")
			// The path carries the API key here as well.
			requestHeaders := host.GetRequestHeaders()
			pathHeader, _ := test.GetHeaderValue(requestHeaders, ":path")
			require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key as query parameter")
		})
		// "stream":true must route to the streaming action on the Vertex path.
		t.Run("vertex express mode streaming request body", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Header phase first.
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}],"stream":true}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)
			requestHeaders := host.GetRequestHeaders()
			pathHeader, _ := test.GetHeaderValue(requestHeaders, ":path")
			require.Contains(t, pathHeader, "streamGenerateContent", "Path should contain streaming action")
			require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key")
		})
		// modelMapping must rewrite the OpenAI model name in the target path.
		t.Run("vertex express mode with model mapping request body", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeWithModelMappingConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Header phase first.
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			// Body uses the OpenAI model name; the mapped Gemini name must appear.
			requestBody := `{"model":"gpt-4","messages":[{"role":"user","content":"test"}]}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionContinue, action)
			requestHeaders := host.GetRequestHeaders()
			pathHeader, _ := test.GetHeaderValue(requestHeaders, ":path")
			require.Contains(t, pathHeader, "gemini-2.5-flash", "Path should contain mapped model name")
		})
	})
}
// RunVertexExpressModeOnHttpResponseBodyTests verifies that a Vertex-format
// response body is processed, and — when converted to OpenAI chat.completion
// format — carries the expected fields.
func RunVertexExpressModeOnHttpResponseBodyTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Express Mode response handling for the chat-completion endpoint.
		t.Run("vertex express mode chat completion response body", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Drive the request header phase first.
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			// Then the request body phase.
			requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}]}`
			host.CallOnHttpRequestBody([]byte(requestBody))
			// Mark the response as coming from upstream so IsResponseFromUpstream() returns true.
			host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
			// Response headers.
			responseHeaders := [][2]string{
				{":status", "200"},
				{"Content-Type", "application/json"},
			}
			host.CallOnHttpResponseHeaders(responseHeaders)
			// Response body in Vertex format.
			responseBody := `{
				"candidates": [{
					"content": {
						"parts": [{
							"text": "Hello! How can I help you today?"
						}]
					},
					"finishReason": "STOP",
					"index": 0
				}],
				"usageMetadata": {
					"promptTokenCount": 9,
					"candidatesTokenCount": 12,
					"totalTokenCount": 21
				}
			}`
			action := host.CallOnHttpResponseBody([]byte(responseBody))
			require.Equal(t, types.ActionContinue, action)
			// Verify the response body was processed.
			processedResponseBody := host.GetResponseBody()
			require.NotNil(t, processedResponseBody)
			// If the body was converted to OpenAI format, spot-check its shape.
			responseStr := string(processedResponseBody)
			if strings.Contains(responseStr, "chat.completion") {
				require.Contains(t, responseStr, "assistant", "Response should contain assistant role")
				require.Contains(t, responseStr, "usage", "Response should contain usage information")
			}
			// Confirm response processing produced logs.
			debugLogs := host.GetDebugLogs()
			hasResponseBodyLogs := false
			for _, log := range debugLogs {
				if strings.Contains(log, "response") || strings.Contains(log, "body") || strings.Contains(log, "vertex") {
					hasResponseBodyLogs = true
					break
				}
			}
			require.True(t, hasResponseBodyLogs, "Should have response body processing logs")
		})
	})
}
// RunVertexOpenAICompatibleModeOnHttpRequestHeadersTests verifies request
// header handling for Vertex OpenAI-compatible mode: the Host header must be
// rewritten to the regional Vertex endpoint.
func RunVertexOpenAICompatibleModeOnHttpRequestHeadersTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Vertex OpenAI-compatible mode: request header processing.
		t.Run("vertex openai compatible mode request headers", func(t *testing.T) {
			host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Send the request headers.
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			// HeaderStopIteration is expected because the request body still
			// needs to be processed.
			require.Equal(t, types.HeaderStopIteration, action)
			// Verify the request headers were handled correctly.
			requestHeaders := host.GetRequestHeaders()
			require.NotNil(t, requestHeaders)
			// Verify the Host header was changed to the Vertex domain
			// (with the region prefix).
			require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "us-central1-aiplatform.googleapis.com"), "Host header should be changed to vertex domain with region prefix")
		})
	})
}
// RunVertexOpenAICompatibleModeOnHttpRequestBodyTests verifies request body
// handling for Vertex OpenAI-compatible mode: the body stays in the OpenAI
// format, model mapping is applied, and unsupported APIs are rejected.
func RunVertexOpenAICompatibleModeOnHttpRequestBodyTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Request body processing: no format conversion, the OpenAI format is kept.
		t.Run("vertex openai compatible mode request body - no format conversion", func(t *testing.T) {
			host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Send the request headers first.
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			// Send the request body (OpenAI format).
			requestBody := `{"model":"gemini-2.0-flash","messages":[{"role":"user","content":"test"}]}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))
			// OpenAI-compatible mode has to wait for the OAuth token, so
			// ActionPause is expected.
			require.Equal(t, types.ActionPause, action)
			// Verify the request body keeps the OpenAI format (not converted
			// to the Vertex native format).
			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)
			// OpenAI-compatible mode must keep the "messages" field rather
			// than converting it to "contents".
			require.Contains(t, string(processedBody), "messages", "Request should keep OpenAI format with messages field")
			require.NotContains(t, string(processedBody), "contents", "Request should NOT be converted to vertex native format")
			// Verify the path points at the OpenAI-compatible endpoint.
			requestHeaders := host.GetRequestHeaders()
			pathHeader := ""
			for _, header := range requestHeaders {
				if header[0] == ":path" {
					pathHeader = header[1]
					break
				}
			}
			require.Contains(t, pathHeader, "/v1beta1/projects/", "Path should use OpenAI compatible endpoint format")
			require.Contains(t, pathHeader, "/endpoints/openapi/chat/completions", "Path should contain openapi chat completions endpoint")
		})
		// Request body processing with model mapping enabled.
		t.Run("vertex openai compatible mode with model mapping", func(t *testing.T) {
			host, status := test.NewTestHost(vertexOpenAICompatibleModeWithModelMappingConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Send the request headers first.
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			// Send the request body (using an OpenAI model name).
			requestBody := `{"model":"gpt-4","messages":[{"role":"user","content":"test"}]}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))
			require.Equal(t, types.ActionPause, action)
			// Verify the model name inside the request body was mapped.
			processedBody := host.GetRequestBody()
			require.NotNil(t, processedBody)
			// The model name should be mapped to gemini-2.0-flash.
			require.Contains(t, string(processedBody), "gemini-2.0-flash", "Model name should be mapped to gemini-2.0-flash")
		})
		// The Embeddings API is not supported in OpenAI-compatible mode.
		t.Run("vertex openai compatible mode - embeddings not supported", func(t *testing.T) {
			host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Send the request headers first.
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/embeddings"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			// Send the request body.
			requestBody := `{"model":"text-embedding-001","input":"test text"}`
			action := host.CallOnHttpRequestBody([]byte(requestBody))
			// OpenAI-compatible mode only supports chat completions;
			// embeddings should produce an error response.
			require.Equal(t, types.ActionContinue, action)
		})
	})
}
// RunVertexExpressModeOnStreamingResponseBodyTests verifies streaming (SSE)
// response handling for Vertex Express Mode.
func RunVertexExpressModeOnStreamingResponseBodyTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Vertex Express Mode: streaming response processing.
		t.Run("vertex express mode streaming response body", func(t *testing.T) {
			host, status := test.NewTestHost(vertexExpressModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Send the request headers first.
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			// Send a streaming request body.
			requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}],"stream":true}`
			host.CallOnHttpRequestBody([]byte(requestBody))
			// Send the streaming response headers.
			responseHeaders := [][2]string{
				{":status", "200"},
				{"Content-Type", "text/event-stream"},
			}
			host.CallOnHttpResponseHeaders(responseHeaders)
			// Simulate streamed response chunks in the Vertex native format.
			chunk1 := `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"},"finishReason":"","index":0}],"usageMetadata":{"promptTokenCount":9,"candidatesTokenCount":5,"totalTokenCount":14}}`
			chunk2 := `data: {"candidates":[{"content":{"parts":[{"text":"Hello! How can I help you today?"}],"role":"model"},"finishReason":"STOP","index":0}],"usageMetadata":{"promptTokenCount":9,"candidatesTokenCount":12,"totalTokenCount":21}}`
			// Feed the streaming response chunks.
			action1 := host.CallOnHttpStreamingResponseBody([]byte(chunk1), false)
			require.Equal(t, types.ActionContinue, action1)
			action2 := host.CallOnHttpStreamingResponseBody([]byte(chunk2), true)
			require.Equal(t, types.ActionContinue, action2)
			// Verify streaming processing left related debug logs behind.
			debugLogs := host.GetDebugLogs()
			hasStreamingLogs := false
			for _, log := range debugLogs {
				if strings.Contains(log, "streaming") || strings.Contains(log, "chunk") || strings.Contains(log, "vertex") {
					hasStreamingLogs = true
					break
				}
			}
			require.True(t, hasStreamingLogs, "Should have streaming response processing logs")
		})
	})
}
// RunVertexOpenAICompatibleModeOnHttpResponseBodyTests verifies response
// handling for Vertex OpenAI-compatible mode: responses already arrive in the
// OpenAI format and are passed through unchanged, except that Unicode escape
// sequences (\uXXXX) are decoded.
func RunVertexOpenAICompatibleModeOnHttpResponseBodyTests(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		// Non-streaming response body: pass-through without format conversion.
		t.Run("vertex openai compatible mode response body - passthrough", func(t *testing.T) {
			host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Send the request headers first.
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			// Send the request body.
			requestBody := `{"model":"gemini-2.0-flash","messages":[{"role":"user","content":"test"}]}`
			host.CallOnHttpRequestBody([]byte(requestBody))
			// Set the response attribute so IsResponseFromUpstream() returns true.
			host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
			// Send the response headers.
			responseHeaders := [][2]string{
				{":status", "200"},
				{"Content-Type", "application/json"},
			}
			host.CallOnHttpResponseHeaders(responseHeaders)
			// Response body in the OpenAI format (the Vertex AI
			// OpenAI-compatible API already returns the OpenAI format).
			responseBody := `{
"id": "chatcmpl-abc123",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello! How can I help you today?"
},
"finish_reason": "stop"
}],
"created": 1729986750,
"model": "gemini-2.0-flash",
"object": "chat.completion",
"usage": {
"prompt_tokens": 9,
"completion_tokens": 12,
"total_tokens": 21
}
}`
			action := host.CallOnHttpResponseBody([]byte(responseBody))
			require.Equal(t, types.ActionContinue, action)
			// Verify the response body was passed through (no conversion).
			processedResponseBody := host.GetResponseBody()
			require.NotNil(t, processedResponseBody)
			// The response should remain unchanged.
			responseStr := string(processedResponseBody)
			require.Contains(t, responseStr, "chatcmpl-abc123", "Response should be passed through unchanged")
			require.Contains(t, responseStr, "chat.completion", "Response should contain original object type")
		})
		// Streaming response: pass-through without format conversion.
		t.Run("vertex openai compatible mode streaming response - passthrough", func(t *testing.T) {
			host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Send the request headers first.
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			// Send a streaming request body.
			requestBody := `{"model":"gemini-2.0-flash","messages":[{"role":"user","content":"test"}],"stream":true}`
			host.CallOnHttpRequestBody([]byte(requestBody))
			// Send the streaming response headers.
			responseHeaders := [][2]string{
				{":status", "200"},
				{"Content-Type", "text/event-stream"},
			}
			host.CallOnHttpResponseHeaders(responseHeaders)
			// Simulate OpenAI-format streaming chunks (as returned by the
			// Vertex AI OpenAI-compatible API).
			chunk1 := `data: {"id":"chatcmpl-abc123","object":"chat.completion.chunk","created":1729986750,"model":"gemini-2.0-flash","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}`
			chunk2 := `data: {"id":"chatcmpl-abc123","object":"chat.completion.chunk","created":1729986750,"model":"gemini-2.0-flash","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":"stop"}]}`
			// Feed the chunks - they should be passed through unchanged.
			action1 := host.CallOnHttpStreamingResponseBody([]byte(chunk1), false)
			require.Equal(t, types.ActionContinue, action1)
			action2 := host.CallOnHttpStreamingResponseBody([]byte(chunk2), true)
			require.Equal(t, types.ActionContinue, action2)
		})
		// Streaming response: Unicode escape sequences are decoded.
		t.Run("vertex openai compatible mode streaming response - unicode escape decoding", func(t *testing.T) {
			host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Send the request headers first.
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			// Send a streaming request body.
			requestBody := `{"model":"gemini-2.0-flash","messages":[{"role":"user","content":"test"}],"stream":true}`
			host.CallOnHttpRequestBody([]byte(requestBody))
			// Send the streaming response headers.
			responseHeaders := [][2]string{
				{":status", "200"},
				{"Content-Type", "text/event-stream"},
			}
			host.CallOnHttpResponseHeaders(responseHeaders)
			// Simulate a streaming chunk containing Unicode escapes (a format
			// the Vertex AI OpenAI-compatible API may return).
			// \u4e2d\u6587 = 中文 ("Chinese").
			chunkWithUnicode := `data: {"id":"chatcmpl-abc123","object":"chat.completion.chunk","created":1729986750,"model":"gemini-2.0-flash","choices":[{"index":0,"delta":{"role":"assistant","content":"\u4e2d\u6587\u6d4b\u8bd5"},"finish_reason":null}]}`
			// Feed the chunk - the Unicode escapes should be decoded.
			action := host.CallOnHttpStreamingResponseBody([]byte(chunkWithUnicode), false)
			require.Equal(t, types.ActionContinue, action)
			// Verify the Unicode escapes in the response body were decoded.
			responseBody := host.GetResponseBody()
			require.NotNil(t, responseBody)
			responseStr := string(responseBody)
			// The body should contain the decoded Chinese characters instead
			// of \uXXXX escape sequences.
			require.Contains(t, responseStr, "中文测试", "Unicode escapes should be decoded to Chinese characters")
			require.NotContains(t, responseStr, `\u4e2d`, "Should not contain Unicode escape sequences")
		})
		// Non-streaming response: Unicode escape sequences are decoded.
		t.Run("vertex openai compatible mode response body - unicode escape decoding", func(t *testing.T) {
			host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)
			// Send the request headers first.
			host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/v1/chat/completions"},
				{":method", "POST"},
				{"Content-Type", "application/json"},
			})
			// Send the request body.
			requestBody := `{"model":"gemini-2.0-flash","messages":[{"role":"user","content":"test"}]}`
			host.CallOnHttpRequestBody([]byte(requestBody))
			// Send the response headers.
			responseHeaders := [][2]string{
				{":status", "200"},
				{"Content-Type", "application/json"},
			}
			host.CallOnHttpResponseHeaders(responseHeaders)
			// Simulate a response body containing Unicode escapes.
			// \u76c8\u5229\u80fd\u529b = 盈利能力 ("profitability").
			responseBodyWithUnicode := `{"id":"chatcmpl-abc123","object":"chat.completion","created":1729986750,"model":"gemini-2.0-flash","choices":[{"index":0,"message":{"role":"assistant","content":"\u76c8\u5229\u80fd\u529b\u5206\u6790"},"finish_reason":"stop"}]}`
			// Process the response body - the Unicode escapes should be decoded.
			action := host.CallOnHttpResponseBody([]byte(responseBodyWithUnicode))
			require.Equal(t, types.ActionContinue, action)
			// Verify the Unicode escapes in the response body were decoded.
			processedResponseBody := host.GetResponseBody()
			require.NotNil(t, processedResponseBody)
			responseStr := string(processedResponseBody)
			// The decoded Chinese characters should be present.
			require.Contains(t, responseStr, "盈利能力分析", "Unicode escapes should be decoded to Chinese characters")
			require.NotContains(t, responseStr, `\u76c8`, "Should not contain Unicode escape sequences")
		})
	})
}

View File

@@ -93,6 +93,19 @@ func MapRequestPathByCapability(apiName string, originPath string, mapping map[s
if !exist {
return ""
}
mappedPathOnly := mappedPath
mappedQuery := ""
if queryIndex := strings.Index(mappedPathOnly, "?"); queryIndex >= 0 {
mappedPathOnly = mappedPathOnly[:queryIndex]
mappedQuery = mappedPath[queryIndex:]
}
// 将查询字符串从原始路径中剥离,避免干扰正则匹配 video_id 等占位符
pathOnly := originPath
query := ""
if queryIndex := strings.Index(originPath, "?"); queryIndex >= 0 {
pathOnly = originPath[:queryIndex]
query = originPath[queryIndex:]
}
if strings.Contains(mappedPath, "{") && strings.Contains(mappedPath, "}") {
replacements := []struct {
regx *regexp.Regexp
@@ -108,8 +121,8 @@ func MapRequestPathByCapability(apiName string, originPath string, mapping map[s
}
for _, r := range replacements {
if r.regx.MatchString(originPath) {
subMatch := r.regx.FindStringSubmatch(originPath)
if r.regx.MatchString(pathOnly) {
subMatch := r.regx.FindStringSubmatch(pathOnly)
if subMatch == nil {
continue
}
@@ -118,12 +131,25 @@ func MapRequestPathByCapability(apiName string, originPath string, mapping map[s
continue
}
id := subMatch[index]
mappedPath = r.regx.ReplaceAllStringFunc(mappedPath, func(s string) string {
mappedPathOnly = r.regx.ReplaceAllStringFunc(mappedPathOnly, func(s string) string {
return strings.Replace(s, "{"+r.key+"}", id, 1)
})
}
}
}
if mappedQuery != "" {
mappedPath = mappedPathOnly + mappedQuery
} else {
mappedPath = mappedPathOnly
}
if query != "" {
// 保留原始查询参数,例如 variant=thumbnail
if strings.Contains(mappedPath, "?") {
mappedPath = mappedPath + "&" + strings.TrimPrefix(query, "?")
} else {
mappedPath += query
}
}
return mappedPath
}

View File

@@ -1,6 +1,10 @@
package util
import "regexp"
import (
"regexp"
"strconv"
"strings"
)
func StripPrefix(s string, prefix string) string {
if len(prefix) != 0 && len(s) >= len(prefix) && s[0:len(prefix)] == prefix {
@@ -18,3 +22,43 @@ func MatchStatus(status string, patterns []string) bool {
}
return false
}
// unicodeSurrogatePairRegex matches a UTF-16 surrogate pair encoded as two
// consecutive escapes: \uD800-\uDBFF (high) followed by \uDC00-\uDFFF (low).
var unicodeSurrogatePairRegex = regexp.MustCompile(`\\u[dD][89abAB][0-9a-fA-F]{2}\\u[dD][c-fC-F][0-9a-fA-F]{2}`)

// unicodeEscapeRegex matches a single Unicode escape sequence like \uXXXX.
var unicodeEscapeRegex = regexp.MustCompile(`\\u([0-9a-fA-F]{4})`)

// DecodeUnicodeEscapes decodes Unicode escape sequences (\uXXXX) in a string to
// UTF-8 characters. This is useful when a JSON response contains ASCII-safe
// encoded non-ASCII characters.
//
// Surrogate pairs (e.g. \ud83d\ude00 for an emoji) are combined into a single
// code point; decoding each half independently would yield two invalid runes.
// An escape that cannot be decoded (non-hex digits or an unpaired surrogate)
// is left unchanged.
func DecodeUnicodeEscapes(input []byte) []byte {
	// Pass 1: combine surrogate pairs so characters outside the BMP decode
	// into one code point.
	result := unicodeSurrogatePairRegex.ReplaceAllFunc(input, func(match []byte) []byte {
		// match is \uHHHH\uLLLL: high-surrogate hex at [2:6], low at [8:12].
		hi, err1 := strconv.ParseInt(string(match[2:6]), 16, 32)
		lo, err2 := strconv.ParseInt(string(match[8:12]), 16, 32)
		if err1 != nil || err2 != nil {
			return match // unreachable given the pattern, but stay safe
		}
		codePoint := ((hi-0xD800)<<10 | (lo - 0xDC00)) + 0x10000
		return []byte(string(rune(codePoint)))
	})
	// Pass 2: decode the remaining simple (BMP) escapes.
	result = unicodeEscapeRegex.ReplaceAllFunc(result, func(match []byte) []byte {
		// match is like \uXXXX, extract the hex part (XXXX)
		hexStr := string(match[2:6])
		codePoint, err := strconv.ParseInt(hexStr, 16, 32)
		if err != nil {
			return match // return original if parse fails
		}
		if codePoint >= 0xD800 && codePoint <= 0xDFFF {
			// An unpaired surrogate is not a valid code point on its own;
			// leave the escape untouched instead of emitting U+FFFD.
			return match
		}
		return []byte(string(rune(codePoint)))
	})
	return result
}
// DecodeUnicodeEscapesInSSE decodes Unicode escape sequences in SSE formatted
// data. Only the JSON payload of lines beginning with "data: " is decoded;
// comment lines, event lines, and blank lines pass through untouched. Line
// boundaries (including any trailing newline) are preserved exactly.
func DecodeUnicodeEscapesInSSE(input []byte) []byte {
	const dataPrefix = "data: "
	lines := strings.Split(string(input), "\n")
	var out strings.Builder
	out.Grow(len(input))
	for i, line := range lines {
		// Re-insert the separators that Split consumed.
		if i > 0 {
			out.WriteString("\n")
		}
		if !strings.HasPrefix(line, dataPrefix) {
			out.WriteString(line)
			continue
		}
		// Decode escapes only in the payload that follows the prefix.
		out.WriteString(dataPrefix)
		out.Write(DecodeUnicodeEscapes([]byte(line[len(dataPrefix):])))
	}
	return []byte(out.String())
}

View File

@@ -0,0 +1,108 @@
package util
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestDecodeUnicodeEscapes checks that \uXXXX escape sequences are decoded to
// UTF-8 characters while invalid sequences and plain text are left untouched.
func TestDecodeUnicodeEscapes(t *testing.T) {
	cases := []struct {
		name     string
		input    string
		expected string
	}{
		{name: "Chinese characters", input: `\u4e2d\u6587\u6d4b\u8bd5`, expected: `中文测试`},
		{name: "Mixed content", input: `Hello \u4e16\u754c World`, expected: `Hello 世界 World`},
		{name: "No escape sequences", input: `Hello World`, expected: `Hello World`},
		{name: "JSON with Unicode escapes", input: `{"content":"\u76c8\u5229\u80fd\u529b"}`, expected: `{"content":"盈利能力"}`},
		{name: "Full width parentheses", input: `\uff08\u76c8\u5229\uff09`, expected: `(盈利)`},
		{name: "Empty string", input: ``, expected: ``},
		{name: "Invalid escape sequence (not modified)", input: `\u00GG`, expected: `\u00GG`},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := DecodeUnicodeEscapes([]byte(tc.input))
			assert.Equal(t, tc.expected, string(got))
		})
	}
}
// TestDecodeUnicodeEscapesInSSE checks that escapes are decoded only inside
// "data: " payload lines of an SSE stream, with all line boundaries and
// non-data lines preserved exactly.
func TestDecodeUnicodeEscapesInSSE(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		expected string
	}{
		{
			name: "SSE data with Unicode escapes",
			input: `data: {"choices":[{"delta":{"content":"\u4e2d\u6587"}}]}
`,
			expected: `data: {"choices":[{"delta":{"content":"中文"}}]}
`,
		},
		{
			name: "Multiple SSE data lines",
			input: `data: {"content":"\u4e2d\u6587"}
data: {"content":"\u82f1\u6587"}
data: [DONE]
`,
			expected: `data: {"content":"中文"}
data: {"content":"英文"}
data: [DONE]
`,
		},
		{
			name:     "Non-data lines unchanged",
			input:    ": comment\nevent: message\ndata: test\n",
			expected: ": comment\nevent: message\ndata: test\n",
		},
		{
			name: "Real Vertex AI response format",
			input: `data: {"choices":[{"delta":{"content":"\uff08\u76c8\u5229\u80fd\u529b\uff09","role":"assistant"},"index":0}],"created":1768307454,"id":"test","model":"gemini","object":"chat.completion.chunk"}
`,
			expected: `data: {"choices":[{"delta":{"content":"(盈利能力)","role":"assistant"},"index":0}],"created":1768307454,"id":"test","model":"gemini","object":"chat.completion.chunk"}
`,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := DecodeUnicodeEscapesInSSE([]byte(tt.input))
			assert.Equal(t, tt.expected, string(result))
		})
	}
}

View File

@@ -0,0 +1,585 @@
package config
import (
"errors"
"fmt"
"regexp"
"strings"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
const (
	// Risk level identifiers reported by the moderation service
	// (used by the various *LevelBar thresholds).
	MaxRisk    = "max"
	HighRisk   = "high"
	MediumRisk = "medium"
	LowRisk    = "low"
	NoRisk     = "none"
	// Sensitive-data level identifiers (s4 .. s0).
	S4Sensitive = "s4"
	S3Sensitive = "s3"
	S2Sensitive = "s2"
	S1Sensitive = "s1"
	NoSensitive = "s0"
	// Names of the individual detection categories.
	ContentModerationType      = "contentModeration"
	PromptAttackType           = "promptAttack"
	SensitiveDataType          = "sensitiveData"
	MaliciousUrlDataType       = "maliciousUrl"
	ModelHallucinationDataType = "modelHallucination"
	// Default configurations
	// OpenAI-format response templates emitted when content is denied; the
	// %s placeholders take a response id and the deny message respectively.
	OpenAIResponseFormat      = `{"id": "%s","object":"chat.completion","model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
	OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
	OpenAIStreamResponseEnd   = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
	// Full SSE stream for a denied request: one content chunk, a stop chunk,
	// then the [DONE] sentinel.
	OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]`
	// HTTP status code returned for denied content by default.
	DefaultDenyCode = 200
	// Default deny message: "Sorry, I cannot answer your question." (Chinese).
	DefaultDenyMessage = "很抱歉,我无法回答您的问题"
	// Default timeout for moderation calls; presumably milliseconds — TODO
	// confirm against the HttpClient call sites.
	DefaultTimeout = 2000
	// User-Agent value sent to the Aliyun moderation service.
	AliyunUserAgent = "CIPFrom/AIGateway"
	// Maximum content length submitted per moderation call.
	LengthLimit = 1800
	// Default moderation service names and gjson paths used to extract the
	// content to check from requests/responses.
	DefaultRequestCheckService       = "llm_query_moderation"
	DefaultResponseCheckService      = "llm_response_moderation"
	DefaultRequestJsonPath           = "messages.@reverse.0.content"
	DefaultResponseJsonPath          = "choices.0.message.content"
	DefaultStreamingResponseJsonPath = "choices.0.delta.content"
	// Actions
	MultiModalGuard          = "MultiModalGuard"
	MultiModalGuardForBase64 = "MultiModalGuardForBase64"
	TextModerationPlus       = "TextModerationPlus"
	// Services
	DefaultMultiModalGuardTextInputCheckService    = "query_security_check"
	DefaultMultiModalGuardTextOutputCheckService   = "response_security_check"
	DefaultMultiModalGuardImageInputCheckService   = "img_query_security_check"
	DefaultTextModerationPlusTextInputCheckService  = "llm_query_moderation"
	DefaultTextModerationPlusTextOutputCheckService = "llm_response_moderation"
)
// api types
const (
	ApiTextGeneration  = "text_generation"
	ApiImageGeneration = "image_generation"
)

// provider types
const (
	ProviderOpenAI  = "openai"
	ProviderQwen    = "qwen"
	ProviderComfyUI = "comfyui"
)
// Response is the top-level payload returned by the moderation service.
type Response struct {
	Code      int    `json:"Code"`      // service status code
	Message   string `json:"Message"`   // human-readable status message
	RequestId string `json:"RequestId"` // service-side request identifier
	Data      Data   `json:"Data"`      // moderation verdict payload
}

// Data carries the moderation verdict returned inside Response.
type Data struct {
	RiskLevel   string   `json:"RiskLevel,omitempty"`   // overall risk level (see MaxRisk..NoRisk)
	AttackLevel string   `json:"AttackLevel,omitempty"` // NOTE(review): attack-specific level, semantics inferred from the name — confirm
	Result      []Result `json:"Result,omitempty"`      // per-label detection results
	Advice      []Advice `json:"Advice,omitempty"`      // suggested replacement answers
	Detail      []Detail `json:"Detail,omitempty"`      // per-category risk details
}

// Result is a single detection result entry.
type Result struct {
	RiskWords   string  `json:"RiskWords,omitempty"`   // the words that triggered the detection
	Description string  `json:"Description,omitempty"` // description of the detection
	Confidence  float64 `json:"Confidence,omitempty"`  // detection confidence score
	Label       string  `json:"Label,omitempty"`       // detection label
}

// Advice is a service-suggested replacement for denied content.
type Advice struct {
	Answer     string `json:"Answer,omitempty"`     // suggested answer text
	HitLabel   string `json:"HitLabel,omitempty"`   // label that was hit
	HitLibName string `json:"HitLibName,omitempty"` // library that produced the hit
}

// Detail describes the risk found for one detection category.
type Detail struct {
	Suggestion string `json:"Suggestion,omitempty"` // suggested handling
	Type       string `json:"Type,omitempty"`       // detection category (see *Type constants)
	Level      string `json:"Level,omitempty"`      // risk level for this category
}
type Matcher struct {
Exact string
Prefix string
Re *regexp.Regexp
}
func (m *Matcher) match(consumer string) bool {
if m.Exact != "" {
return consumer == m.Exact
} else if m.Prefix != "" {
return strings.HasPrefix(consumer, m.Prefix)
} else if m.Re != nil {
return m.Re.MatchString(consumer)
} else {
return false
}
}
// AISecurityConfig holds the parsed plugin configuration for the AI security
// guard: connection settings for the moderation service, which directions to
// check, how to extract content, risk thresholds, and per-consumer overrides.
type AISecurityConfig struct {
	Client wrapper.HttpClient // client for calling the moderation service
	Host   string             // moderation service host (from "serviceHost")
	AK     string             // access key (from "accessKey")
	SK     string             // secret key (from "secretKey")
	Token  string             // optional STS token (from "securityToken")
	Action string             // moderation action, e.g. TextModerationPlus or MultiModalGuard
	// Request-side checking.
	CheckRequest             bool
	CheckRequestImage        bool
	RequestCheckService      string
	RequestImageCheckService string
	RequestContentJsonPath   string // gjson path used to extract request content
	// Response-side checking.
	CheckResponse             bool
	ResponseCheckService      string
	ResponseImageCheckService string
	ResponseContentJsonPath   string // gjson path for non-streaming responses
	ResponseStreamContentJsonPath string // gjson path for streaming responses
	// Deny response shaping.
	DenyCode         int64
	DenyMessage      string
	ProtocolOriginal bool // true when config "protocol" == "original"
	// Risk thresholds (values are the *Risk / *Sensitive constants).
	RiskLevelBar              string
	ContentModerationLevelBar string
	PromptAttackLevelBar      string
	SensitiveDataLevelBar     string
	MaliciousUrlLevelBar      string
	ModelHallucinationLevelBar string
	Timeout     uint32                             // moderation call timeout (default DefaultTimeout)
	BufferLimit int                                // body buffering limit
	Metrics     map[string]proxywasm.MetricCounter // lazily defined counters (see IncrementCounter)
	// Per-consumer overrides; each entry carries the raw config fields plus a
	// compiled Matcher under the "matcher" key.
	ConsumerRequestCheckService  []map[string]interface{}
	ConsumerResponseCheckService []map[string]interface{}
	ConsumerRiskLevel            []map[string]interface{}
	// text_generation, image_generation, etc.
	ApiType string
	// openai, qwen, comfyui, etc.
	ProviderType string
}
// Parse populates the configuration from the plugin's JSON config. It
// validates the moderation service connection settings and credentials,
// applies action-specific defaults, and then overlays any user-supplied
// overrides. An error is returned when a mandatory field is missing or a
// level-bar field carries an unknown value.
func (config *AISecurityConfig) Parse(json gjson.Result) error {
	// Service connection settings are mandatory.
	serviceName := json.Get("serviceName").String()
	servicePort := json.Get("servicePort").Int()
	serviceHost := json.Get("serviceHost").String()
	config.Host = serviceHost
	if serviceName == "" || servicePort == 0 || serviceHost == "" {
		return errors.New("invalid service config")
	}
	config.AK = json.Get("accessKey").String()
	config.SK = json.Get("secretKey").String()
	if config.AK == "" || config.SK == "" {
		return errors.New("invalid AK/SK config")
	}
	config.Token = json.Get("securityToken").String()
	// Select the moderation action; TextModerationPlus is the default.
	if obj := json.Get("action"); obj.Exists() {
		config.Action = obj.String()
	} else {
		config.Action = TextModerationPlus
	}
	// Apply defaults (some depend on the selected action) before overlaying
	// the user-provided values below.
	config.SetDefaultValues()
	if obj := json.Get("riskLevelBar"); obj.Exists() {
		config.RiskLevelBar = obj.String()
	}
	if obj := json.Get("requestCheckService"); obj.Exists() {
		config.RequestCheckService = obj.String()
	}
	if obj := json.Get("requestImageCheckService"); obj.Exists() {
		config.RequestImageCheckService = obj.String()
	}
	if obj := json.Get("responseCheckService"); obj.Exists() {
		config.ResponseCheckService = obj.String()
	}
	if obj := json.Get("responseImageCheckService"); obj.Exists() {
		config.ResponseImageCheckService = obj.String()
	}
	config.CheckRequest = json.Get("checkRequest").Bool()
	config.CheckRequestImage = json.Get("checkRequestImage").Bool()
	config.CheckResponse = json.Get("checkResponse").Bool()
	// "original" keeps the provider's own wire protocol.
	config.ProtocolOriginal = json.Get("protocol").String() == "original"
	config.DenyMessage = json.Get("denyMessage").String()
	if obj := json.Get("denyCode"); obj.Exists() {
		config.DenyCode = obj.Int()
	}
	if obj := json.Get("requestContentJsonPath"); obj.Exists() {
		config.RequestContentJsonPath = obj.String()
	}
	if obj := json.Get("responseContentJsonPath"); obj.Exists() {
		config.ResponseContentJsonPath = obj.String()
	}
	if obj := json.Get("responseStreamContentJsonPath"); obj.Exists() {
		config.ResponseStreamContentJsonPath = obj.String()
	}
	// Per-category level bars must map to a known level.
	if obj := json.Get("contentModerationLevelBar"); obj.Exists() {
		config.ContentModerationLevelBar = obj.String()
		if LevelToInt(config.ContentModerationLevelBar) <= 0 {
			return errors.New("invalid contentModerationLevelBar, value must be one of [max, high, medium, low]")
		}
	}
	if obj := json.Get("promptAttackLevelBar"); obj.Exists() {
		config.PromptAttackLevelBar = obj.String()
		if LevelToInt(config.PromptAttackLevelBar) <= 0 {
			return errors.New("invalid promptAttackLevelBar, value must be one of [max, high, medium, low]")
		}
	}
	if obj := json.Get("sensitiveDataLevelBar"); obj.Exists() {
		config.SensitiveDataLevelBar = obj.String()
		if LevelToInt(config.SensitiveDataLevelBar) <= 0 {
			return errors.New("invalid sensitiveDataLevelBar, value must be one of [S4, S3, S2, S1]")
		}
	}
	if obj := json.Get("modelHallucinationLevelBar"); obj.Exists() {
		config.ModelHallucinationLevelBar = obj.String()
		if LevelToInt(config.ModelHallucinationLevelBar) <= 0 {
			return errors.New("invalid modelHallucinationLevelBar, value must be one of [max, high, medium, low]")
		}
	}
	if obj := json.Get("maliciousUrlLevelBar"); obj.Exists() {
		config.MaliciousUrlLevelBar = obj.String()
		if LevelToInt(config.MaliciousUrlLevelBar) <= 0 {
			return errors.New("invalid maliciousUrlLevelBar, value must be one of [max, high, medium, low]")
		}
	}
	if obj := json.Get("timeout"); obj.Exists() {
		config.Timeout = uint32(obj.Int())
	}
	if obj := json.Get("bufferLimit"); obj.Exists() {
		config.BufferLimit = int(obj.Int())
	}
	// The three per-consumer override lists share one parsing routine.
	if obj := json.Get("consumerRequestCheckService"); obj.Exists() {
		config.ConsumerRequestCheckService = append(config.ConsumerRequestCheckService, parseConsumerOverrides(obj.Array())...)
	}
	if obj := json.Get("consumerResponseCheckService"); obj.Exists() {
		config.ConsumerResponseCheckService = append(config.ConsumerResponseCheckService, parseConsumerOverrides(obj.Array())...)
	}
	if obj := json.Get("consumerRiskLevel"); obj.Exists() {
		config.ConsumerRiskLevel = append(config.ConsumerRiskLevel, parseConsumerOverrides(obj.Array())...)
	}
	if obj := json.Get("apiType"); obj.Exists() {
		config.ApiType = obj.String()
	}
	if obj := json.Get("providerType"); obj.Exists() {
		config.ProviderType = obj.String()
	}
	config.Client = wrapper.NewClusterClient(wrapper.FQDNCluster{
		FQDN: serviceName,
		Port: servicePort,
		Host: serviceHost,
	})
	config.Metrics = make(map[string]proxywasm.MetricCounter)
	return nil
}

// parseConsumerOverrides converts a list of per-consumer override entries into
// maps, compiling a Matcher (stored under the "matcher" key) from each entry's
// "name" and "matchType" fields. Entries missing either field are skipped;
// entries with an unknown matchType are kept but get no matcher.
// NOTE: an invalid "regexp" pattern makes regexp.MustCompile panic, aborting
// config parsing; config authors must supply valid patterns.
func parseConsumerOverrides(items []gjson.Result) []map[string]interface{} {
	var result []map[string]interface{}
	for _, item := range items {
		m := make(map[string]interface{})
		for k, v := range item.Map() {
			m[k] = v.Value()
		}
		consumerName, ok1 := m["name"]
		matchType, ok2 := m["matchType"]
		if !ok1 || !ok2 {
			continue
		}
		switch fmt.Sprint(matchType) {
		case "exact":
			m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)}
		case "prefix":
			m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)}
		case "regexp":
			m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))}
		}
		result = append(result, m)
	}
	return result
}
// SetDefaultValues resets every tunable option to its built-in default,
// including the moderation service names implied by the configured Action.
// Parse calls it before overlaying user-supplied values.
func (config *AISecurityConfig) SetDefaultValues() {
	// Moderation services depend on the selected action.
	switch config.Action {
	case TextModerationPlus:
		config.RequestCheckService = DefaultTextModerationPlusTextInputCheckService
		config.ResponseCheckService = DefaultTextModerationPlusTextOutputCheckService
	case MultiModalGuard:
		config.RequestCheckService = DefaultMultiModalGuardTextInputCheckService
		config.RequestImageCheckService = DefaultMultiModalGuardImageInputCheckService
		config.ResponseCheckService = DefaultMultiModalGuardTextOutputCheckService
	}
	// Risk thresholds.
	config.RiskLevelBar = HighRisk
	config.ContentModerationLevelBar = MaxRisk
	config.PromptAttackLevelBar = MaxRisk
	config.SensitiveDataLevelBar = S4Sensitive
	config.ModelHallucinationLevelBar = MaxRisk
	config.MaliciousUrlLevelBar = MaxRisk
	// Deny response and content extraction paths.
	config.DenyCode = DefaultDenyCode
	config.RequestContentJsonPath = DefaultRequestJsonPath
	config.ResponseContentJsonPath = DefaultResponseJsonPath
	config.ResponseStreamContentJsonPath = DefaultStreamingResponseJsonPath
	// Runtime limits and API flavor.
	config.Timeout = DefaultTimeout
	config.BufferLimit = 1000
	config.ApiType = ApiTextGeneration
	config.ProviderType = ProviderOpenAI
}
// IncrementCounter adds inc to the named counter metric, defining and caching
// the counter in config.Metrics on first use.
func (config *AISecurityConfig) IncrementCounter(metricName string, inc uint64) {
	if counter, found := config.Metrics[metricName]; found {
		counter.Increment(inc)
		return
	}
	counter := proxywasm.DefineCounterMetric(metricName)
	config.Metrics[metricName] = counter
	counter.Increment(inc)
}
// GetRequestCheckService returns the request moderation service for the given
// consumer. The first per-consumer entry whose matcher matches wins; if it
// carries a "requestCheckService" value that value is used (an empty string
// results from a non-string value), otherwise the global default applies.
func (config *AISecurityConfig) GetRequestCheckService(consumer string) string {
	for _, entry := range config.ConsumerRequestCheckService {
		matcher, ok := entry["matcher"].(Matcher)
		if !ok || !matcher.match(consumer) {
			continue
		}
		if override, found := entry["requestCheckService"]; found {
			value, _ := override.(string)
			return value
		}
		break
	}
	return config.RequestCheckService
}
// GetRequestImageCheckService returns the request image moderation service for
// the given consumer: the first matching per-consumer override wins, otherwise
// the global default is used.
func (config *AISecurityConfig) GetRequestImageCheckService(consumer string) string {
	for _, entry := range config.ConsumerRequestCheckService {
		matcher, ok := entry["matcher"].(Matcher)
		if !ok || !matcher.match(consumer) {
			continue
		}
		if override, found := entry["requestImageCheckService"]; found {
			value, _ := override.(string)
			return value
		}
		break
	}
	return config.RequestImageCheckService
}
// GetResponseCheckService returns the response check service configured for
// the given consumer, falling back to the global default when no matcher applies.
func (config *AISecurityConfig) GetResponseCheckService(consumer string) string {
	selected := config.ResponseCheckService
	for _, entry := range config.ConsumerResponseCheckService {
		matcher, ok := entry["matcher"].(Matcher)
		if !ok || !matcher.match(consumer) {
			continue
		}
		// First matching entry wins; missing key keeps the default.
		if value, exists := entry["responseCheckService"]; exists {
			selected, _ = value.(string)
		}
		break
	}
	return selected
}
// GetResponseImageCheckService returns the response image check service for
// the given consumer, falling back to the global default when no matcher applies.
func (config *AISecurityConfig) GetResponseImageCheckService(consumer string) string {
	selected := config.ResponseImageCheckService
	for _, entry := range config.ConsumerResponseCheckService {
		matcher, ok := entry["matcher"].(Matcher)
		if !ok || !matcher.match(consumer) {
			continue
		}
		// First matching entry wins; missing key keeps the default.
		if value, exists := entry["responseImageCheckService"]; exists {
			selected, _ = value.(string)
		}
		break
	}
	return selected
}
// GetRiskLevelBar returns the overall risk-level threshold for the given
// consumer, falling back to the global default when no matcher applies.
func (config *AISecurityConfig) GetRiskLevelBar(consumer string) string {
	selected := config.RiskLevelBar
	for _, entry := range config.ConsumerRiskLevel {
		matcher, ok := entry["matcher"].(Matcher)
		if !ok || !matcher.match(consumer) {
			continue
		}
		// First matching entry wins; missing key keeps the default.
		if value, exists := entry["riskLevelBar"]; exists {
			selected, _ = value.(string)
		}
		break
	}
	return selected
}
// GetContentModerationLevelBar returns the content-moderation threshold for
// the given consumer, falling back to the global default when no matcher applies.
func (config *AISecurityConfig) GetContentModerationLevelBar(consumer string) string {
	selected := config.ContentModerationLevelBar
	for _, entry := range config.ConsumerRiskLevel {
		matcher, ok := entry["matcher"].(Matcher)
		if !ok || !matcher.match(consumer) {
			continue
		}
		// First matching entry wins; missing key keeps the default.
		if value, exists := entry["contentModerationLevelBar"]; exists {
			selected, _ = value.(string)
		}
		break
	}
	return selected
}
// GetPromptAttackLevelBar returns the prompt-attack threshold for the given
// consumer, falling back to the global default when no matcher applies.
func (config *AISecurityConfig) GetPromptAttackLevelBar(consumer string) string {
	selected := config.PromptAttackLevelBar
	for _, entry := range config.ConsumerRiskLevel {
		matcher, ok := entry["matcher"].(Matcher)
		if !ok || !matcher.match(consumer) {
			continue
		}
		// First matching entry wins; missing key keeps the default.
		if value, exists := entry["promptAttackLevelBar"]; exists {
			selected, _ = value.(string)
		}
		break
	}
	return selected
}
// GetSensitiveDataLevelBar returns the sensitive-data threshold for the given
// consumer, falling back to the global default when no matcher applies.
func (config *AISecurityConfig) GetSensitiveDataLevelBar(consumer string) string {
	selected := config.SensitiveDataLevelBar
	for _, entry := range config.ConsumerRiskLevel {
		matcher, ok := entry["matcher"].(Matcher)
		if !ok || !matcher.match(consumer) {
			continue
		}
		// First matching entry wins; missing key keeps the default.
		if value, exists := entry["sensitiveDataLevelBar"]; exists {
			selected, _ = value.(string)
		}
		break
	}
	return selected
}
// GetMaliciousUrlLevelBar returns the malicious-URL threshold for the given
// consumer, falling back to the global default when no matcher applies.
func (config *AISecurityConfig) GetMaliciousUrlLevelBar(consumer string) string {
	selected := config.MaliciousUrlLevelBar
	for _, entry := range config.ConsumerRiskLevel {
		matcher, ok := entry["matcher"].(Matcher)
		if !ok || !matcher.match(consumer) {
			continue
		}
		// First matching entry wins; missing key keeps the default.
		if value, exists := entry["maliciousUrlLevelBar"]; exists {
			selected, _ = value.(string)
		}
		break
	}
	return selected
}
// GetModelHallucinationLevelBar returns the model-hallucination threshold for
// the given consumer, falling back to the global default when no matcher applies.
func (config *AISecurityConfig) GetModelHallucinationLevelBar(consumer string) string {
	selected := config.ModelHallucinationLevelBar
	for _, entry := range config.ConsumerRiskLevel {
		matcher, ok := entry["matcher"].(Matcher)
		if !ok || !matcher.match(consumer) {
			continue
		}
		// First matching entry wins; missing key keeps the default.
		if value, exists := entry["modelHallucinationLevelBar"]; exists {
			selected, _ = value.(string)
		}
		break
	}
	return selected
}
func LevelToInt(riskLevel string) int {
// First check against our defined constants
switch strings.ToLower(riskLevel) {
case MaxRisk, S4Sensitive:
return 4
case HighRisk, S3Sensitive:
return 3
case MediumRisk, S2Sensitive:
return 2
case LowRisk, S1Sensitive:
return 1
case NoRisk, NoSensitive:
return 0
default:
return -1
}
}
// IsRiskLevelAcceptable reports whether the moderation verdict in data stays
// strictly below the thresholds configured for the given consumer.
//
// For MultiModalGuard actions the top-level RiskLevel/AttackLevel and every
// detail entry are checked against their dimension-specific bars; all other
// actions compare the single overall RiskLevel against the global bar.
func IsRiskLevelAcceptable(action string, data Data, config AISecurityConfig, consumer string) bool {
	if action != MultiModalGuard && action != MultiModalGuardForBase64 {
		return LevelToInt(data.RiskLevel) < LevelToInt(config.GetRiskLevelBar(consumer))
	}
	// Top-level risk level for MultiModalGuard.
	if LevelToInt(data.RiskLevel) >= LevelToInt(config.GetContentModerationLevelBar(consumer)) {
		return false
	}
	// Top-level prompt-attack detection.
	if LevelToInt(data.AttackLevel) >= LevelToInt(config.GetPromptAttackLevelBar(consumer)) {
		return false
	}
	// Detailed per-dimension results, kept for backward compatibility.
	for _, detail := range data.Detail {
		var bar string
		switch detail.Type {
		case ContentModerationType:
			bar = config.GetContentModerationLevelBar(consumer)
		case PromptAttackType:
			bar = config.GetPromptAttackLevelBar(consumer)
		case SensitiveDataType:
			bar = config.GetSensitiveDataLevelBar(consumer)
		case MaliciousUrlDataType:
			bar = config.GetMaliciousUrlLevelBar(consumer)
		case ModelHallucinationDataType:
			bar = config.GetModelHallucinationLevelBar(consumer)
		default:
			// Unknown detail types are ignored, matching the original switch.
			continue
		}
		if LevelToInt(detail.Level) >= LevelToInt(bar) {
			return false
		}
	}
	return true
}

View File

@@ -5,8 +5,8 @@ go 1.24.1
toolchain go1.24.4
require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
github.com/higress-group/wasm-go v1.0.7-0.20251118110253-ba77116c6ddd
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
)
@@ -20,5 +20,6 @@ require (
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/resp v0.1.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -4,8 +4,12 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/higress-group/wasm-go v1.0.7-0.20251118110253-ba77116c6ddd h1:acTs8sqXf+qP+IypxFg3cu5Cluj7VT5BI+IDRlY5sag=
github.com/higress-group/wasm-go v1.0.7-0.20251118110253-ba77116c6ddd/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
@@ -24,6 +28,8 @@ github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -0,0 +1,249 @@
package common
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"sort"
"golang.org/x/exp/maps"
"fmt"
"net/url"
"strings"
"time"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/google/uuid"
)
const (
	// ALGORITHM identifies the Alibaba Cloud V3 request-signing scheme; it is
	// embedded in both the string-to-sign and the Authorization header.
	ALGORITHM = "ACS3-HMAC-SHA256"
)

// Request holds the parts of an outgoing Aliyun OpenAPI call that participate
// in ACS3-HMAC-SHA256 signing.
type Request struct {
	httpMethod   string // HTTP verb, e.g. "POST"
	canonicalUri string // canonical URI path ("/" for RPC-style APIs)
	host         string // endpoint host; also signed via the "host" header
	xAcsAction   string // OpenAPI action name (x-acs-action header)
	xAcsVersion  string // OpenAPI version (x-acs-version header)
	headers      map[string]string // request headers; signing adds entries here
	body         []byte // raw request body (may be nil)
	queryParam   map[string]interface{} // query parameters, flattened before signing
}
// newRequest builds a Request pre-populated with the signing-relevant
// headers: host, action, version, an RFC 3339 UTC timestamp and a random
// nonce to make each signature unique.
func newRequest(httpMethod, canonicalUri, host, xAcsAction, xAcsVersion string) *Request {
	headers := map[string]string{
		"host":                  host,
		"x-acs-action":          xAcsAction,
		"x-acs-version":         xAcsVersion,
		"x-acs-date":            time.Now().UTC().Format(time.RFC3339),
		"x-acs-signature-nonce": uuid.New().String(),
	}
	return &Request{
		httpMethod:   httpMethod,
		canonicalUri: canonicalUri,
		host:         host,
		xAcsAction:   xAcsAction,
		xAcsVersion:  xAcsVersion,
		headers:      headers,
		queryParam:   make(map[string]interface{}),
	}
}
// getAuthorization signs req in place with the ACS3-HMAC-SHA256 scheme and
// stores the result in req.headers["Authorization"].
//
// Steps, per the Alibaba Cloud V3 signature spec:
//  1. flatten nested query params and build the canonical query string,
//  2. hash the body into x-acs-content-sha256 (empty body hashes ""),
//  3. build canonical headers from host / content-type / x-acs-* entries,
//  4. HMAC-SHA256 the string-to-sign with AccessKeySecret.
//
// SecurityToken, when non-empty, is attached as x-acs-security-token (STS).
func getAuthorization(req *Request, AccessKeyId, AccessKeySecret, SecurityToken string) {
	// Flatten maps/lists in the query into dotted keys before canonicalizing.
	newQueryParams := make(map[string]interface{})
	processObject(newQueryParams, "", req.queryParam)
	req.queryParam = newQueryParams

	// Canonical query string: keys sorted, key and value percent-encoded.
	keys := maps.Keys(req.queryParam)
	sort.Strings(keys)
	pairs := make([]string, 0, len(keys))
	for _, k := range keys {
		v := req.queryParam[k]
		pairs = append(pairs, percentCode(url.QueryEscape(k))+"="+percentCode(url.QueryEscape(fmt.Sprintf("%v", v))))
	}
	canonicalQueryString := strings.Join(pairs, "&")

	var bodyContent []byte
	if req.body != nil {
		bodyContent = req.body
	}
	hashedRequestPayload := sha256Hex(bodyContent)
	req.headers["x-acs-content-sha256"] = hashedRequestPayload
	if SecurityToken != "" {
		req.headers["x-acs-security-token"] = SecurityToken
	}

	// Canonical headers: only host, content-type and x-acs-* participate.
	// NOTE(review): keys are sorted case-sensitively; all signed headers are
	// stored lowercase in this package, so the order matches the lowercase
	// sort the spec requires — keep any new signed header keys lowercase.
	canonicalHeaders := ""
	signedHeaders := ""
	headerKeys := maps.Keys(req.headers)
	sort.Strings(headerKeys)
	for _, k := range headerKeys {
		lowerKey := strings.ToLower(k)
		if lowerKey == "host" || strings.HasPrefix(lowerKey, "x-acs-") || lowerKey == "content-type" {
			canonicalHeaders += lowerKey + ":" + req.headers[k] + "\n"
			signedHeaders += lowerKey + ";"
		}
	}
	signedHeaders = strings.TrimSuffix(signedHeaders, ";")

	canonicalRequest := req.httpMethod + "\n" + req.canonicalUri + "\n" + canonicalQueryString + "\n" + canonicalHeaders + "\n" + signedHeaders + "\n" + hashedRequestPayload
	stringToSign := ALGORITHM + "\n" + sha256Hex([]byte(canonicalRequest))
	byteData, err := hmac256([]byte(AccessKeySecret), stringToSign)
	if err != nil {
		// hash.Hash.Write never returns an error, so this branch is
		// unreachable; keep the panic as a loud invariant check instead of
		// printing to stdout (useless inside a wasm plugin).
		panic(err)
	}
	// hex.EncodeToString already emits lowercase hex; no ToLower needed.
	signature := hex.EncodeToString(byteData)
	req.headers["Authorization"] = ALGORITHM + " Credential=" + AccessKeyId + ",SignedHeaders=" + signedHeaders + ",Signature=" + signature
}
func hmac256(key []byte, toSignString string) ([]byte, error) {
h := hmac.New(sha256.New, key)
_, err := h.Write([]byte(toSignString))
if err != nil {
return nil, err
}
return h.Sum(nil), nil
}
// sha256Hex returns the lowercase hex encoding of the SHA-256 digest of byteArray.
func sha256Hex(byteArray []byte) string {
	sum := sha256.Sum256(byteArray)
	return hex.EncodeToString(sum[:])
}
// percentCode post-processes url.QueryEscape output into the RFC 3986 form
// the Aliyun signature expects: "+" → "%20", "*" → "%2A", "%7E" → "~".
// Replacements are applied sequentially, matching the original behavior.
func percentCode(str string) string {
	for _, sub := range [][2]string{{"+", "%20"}, {"*", "%2A"}, {"%7E", "~"}} {
		str = strings.ReplaceAll(str, sub[0], sub[1])
	}
	return str
}
// formDataToString flattens formData (nested maps/lists become dotted keys)
// and encodes the result as an application/x-www-form-urlencoded string,
// returned by pointer.
func formDataToString(formData map[string]interface{}) *string {
	flattened := make(map[string]interface{})
	processObject(flattened, "", formData)
	values := url.Values{}
	for key, value := range flattened {
		values.Add(key, fmt.Sprintf("%v", value))
	}
	encoded := values.Encode()
	return &encoded
}
// processObject recursively flattens value into mapResult: maps append
// ".subKey" to key, lists append 1-based ".N" indices, and scalars are
// stringified (byte slices become their string form). nil values are dropped,
// and a leading "." on the final key (from an empty root key) is trimmed.
func processObject(mapResult map[string]interface{}, key string, value interface{}) {
	if value == nil {
		return
	}
	switch typed := value.(type) {
	case []interface{}:
		// List elements are numbered from 1, per the Aliyun form convention.
		for idx, element := range typed {
			processObject(mapResult, fmt.Sprintf("%s.%d", key, idx+1), element)
		}
	case map[string]interface{}:
		for name, nested := range typed {
			processObject(mapResult, key+"."+name, nested)
		}
	default:
		flatKey := strings.TrimPrefix(key, ".")
		if raw, isBytes := typed.([]byte); isBytes {
			mapResult[flatKey] = string(raw)
		} else {
			mapResult[flatKey] = fmt.Sprintf("%v", typed)
		}
	}
}
// GenerateRequestForText builds a signed Aliyun content-moderation request
// for a text payload.
//
// It returns the request path (a query string selecting checkService), the
// signed headers, and the form-encoded body. The body carries the text,
// session id and request origin inside a JSON-encoded ServiceParameters field.
//
// NOTE(review): the returned path is built with url.Values.Encode (spaces
// become "+"), while the signature canonicalizes with percentCode (spaces
// become "%20"). Current Service values contain no characters where these
// differ — confirm before passing values that need escaping.
func GenerateRequestForText(config cfg.AISecurityConfig, checkAction, checkService, text, sessionID string) (path string, headers [][2]string, reqBody []byte) {
	httpMethod := "POST"
	canonicalUri := "/"
	xAcsVersion := "2022-03-02"
	req := newRequest(httpMethod, canonicalUri, config.Host, checkAction, xAcsVersion)
	req.queryParam["Service"] = checkService
	body := make(map[string]interface{})
	serviceParameters := make(map[string]interface{})
	serviceParameters["content"] = text
	serviceParameters["sessionId"] = sessionID
	serviceParameters["requestFrom"] = cfg.AliyunUserAgent
	serviceParametersJSON, _ := json.Marshal(serviceParameters)
	body["ServiceParameters"] = serviceParametersJSON
	str := formDataToString(body)
	req.body = []byte(*str)
	req.headers["content-type"] = "application/x-www-form-urlencoded"
	req.headers["User-Agent"] = cfg.AliyunUserAgent
	// Signing must run after body and content-type are final: both are
	// covered by the signature.
	getAuthorization(req, config.AK, config.SK, config.Token)
	q := url.Values{}
	keys := maps.Keys(req.queryParam)
	sort.Strings(keys)
	for _, k := range keys {
		v := req.queryParam[k]
		q.Set(k, fmt.Sprintf("%v", v))
	}
	for k, v := range req.headers {
		// host will be added by envoy automatically
		if k != "host" {
			headers = append(headers, [2]string{k, v})
		}
	}
	return "?" + q.Encode(), headers, req.body
}
// GenerateRequestForImage builds a signed Aliyun content-moderation request
// for an image, referenced either by URL (imgUrl) or as inline base64
// (imgBase64); both may be supplied at once, and both may be empty.
//
// It returns the request path (a query string selecting checkService), the
// signed headers, and the form-encoded body.
func GenerateRequestForImage(config cfg.AISecurityConfig, checkAction, checkService, imgUrl, imgBase64 string) (path string, headers [][2]string, reqBody []byte) {
	httpMethod := "POST"
	canonicalUri := "/"
	xAcsVersion := "2022-03-02"
	req := newRequest(httpMethod, canonicalUri, config.Host, checkAction, xAcsVersion)
	req.queryParam["Service"] = checkService
	body := make(map[string]interface{})
	serviceParameters := make(map[string]interface{})
	if imgUrl != "" {
		serviceParameters["imageUrls"] = []string{imgUrl}
	}
	serviceParameters["requestFrom"] = cfg.AliyunUserAgent
	serviceParametersJSON, _ := json.Marshal(serviceParameters)
	body["ServiceParameters"] = serviceParametersJSON
	if imgBase64 != "" {
		// Base64 payload is a top-level form field, not part of ServiceParameters.
		body["ImageBase64Str"] = imgBase64
	}
	str := formDataToString(body)
	req.body = []byte(*str)
	req.headers["content-type"] = "application/x-www-form-urlencoded"
	req.headers["User-Agent"] = cfg.AliyunUserAgent
	// Signing must run after body and content-type are final: both are
	// covered by the signature.
	getAuthorization(req, config.AK, config.SK, config.Token)
	q := url.Values{}
	keys := maps.Keys(req.queryParam)
	sort.Strings(keys)
	for _, k := range keys {
		v := req.queryParam[k]
		q.Set(k, fmt.Sprintf("%v", v))
	}
	for k, v := range req.headers {
		// host will be added by envoy automatically
		if k != "host" {
			headers = append(headers, [2]string{k, v})
		}
	}
	return "?" + q.Encode(), headers, req.body
}

View File

@@ -0,0 +1,634 @@
// Copyright (c) 2024 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package common
import (
"encoding/hex"
"net/url"
"strings"
"testing"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/stretchr/testify/require"
)
func TestSha256Hex(t *testing.T) {
t.Run("empty input", func(t *testing.T) {
result := sha256Hex([]byte(""))
// SHA256 of empty string
expected := "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
require.Equal(t, expected, result)
})
t.Run("simple string", func(t *testing.T) {
result := sha256Hex([]byte("hello"))
expected := "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
require.Equal(t, expected, result)
})
t.Run("unicode string", func(t *testing.T) {
result := sha256Hex([]byte("你好"))
// Just verify it returns a valid hex string
require.Len(t, result, 64)
_, err := hex.DecodeString(result)
require.NoError(t, err)
})
}
// TestHmac256 checks that hmac256 produces a 32-byte MAC for normal, empty-key
// and empty-message inputs, and that it is deterministic.
func TestHmac256(t *testing.T) {
	cases := []struct {
		name    string
		key     string
		message string
	}{
		{"valid hmac", "test-key", "test-message"},
		{"empty key", "", "test-message"},
		{"empty message", "test-key", ""},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			result, err := hmac256([]byte(tc.key), tc.message)
			require.NoError(t, err)
			require.NotNil(t, result)
			require.Len(t, result, 32) // SHA256 produces 32 bytes
		})
	}
	t.Run("verify hmac consistency", func(t *testing.T) {
		first, err1 := hmac256([]byte("test-key"), "test-message")
		second, err2 := hmac256([]byte("test-key"), "test-message")
		require.NoError(t, err1)
		require.NoError(t, err2)
		require.Equal(t, first, second)
	})
}
func TestPercentCode(t *testing.T) {
t.Run("replace plus sign", func(t *testing.T) {
input := "test+value"
result := percentCode(input)
require.Equal(t, "test%20value", result)
})
t.Run("replace asterisk", func(t *testing.T) {
input := "test*value"
result := percentCode(input)
require.Equal(t, "test%2Avalue", result)
})
t.Run("replace tilde encoding", func(t *testing.T) {
input := "test%7Evalue"
result := percentCode(input)
require.Equal(t, "test~value", result)
})
t.Run("multiple replacements", func(t *testing.T) {
input := "test+value*test%7E"
result := percentCode(input)
require.Equal(t, "test%20value%2Atest~", result)
})
t.Run("no replacements needed", func(t *testing.T) {
input := "test-value"
result := percentCode(input)
require.Equal(t, "test-value", result)
})
}
// TestProcessObject covers flattening of scalars, maps, lists and nested
// combinations into dotted keys. Note list indices are 1-based and nil
// values are dropped entirely.
func TestProcessObject(t *testing.T) {
	t.Run("simple string value", func(t *testing.T) {
		result := make(map[string]interface{})
		processObject(result, "key", "value")
		require.Equal(t, "value", result["key"])
	})
	t.Run("simple int value", func(t *testing.T) {
		// Non-string scalars are stringified via fmt.
		result := make(map[string]interface{})
		processObject(result, "key", 123)
		require.Equal(t, "123", result["key"])
	})
	t.Run("nil value", func(t *testing.T) {
		result := make(map[string]interface{})
		processObject(result, "key", nil)
		require.Empty(t, result)
	})
	t.Run("map value", func(t *testing.T) {
		result := make(map[string]interface{})
		input := map[string]interface{}{
			"subkey1": "value1",
			"subkey2": "value2",
		}
		processObject(result, "key", input)
		require.Equal(t, "value1", result["key.subkey1"])
		require.Equal(t, "value2", result["key.subkey2"])
	})
	t.Run("array value", func(t *testing.T) {
		// Indices start at 1, not 0.
		result := make(map[string]interface{})
		input := []interface{}{"item1", "item2", "item3"}
		processObject(result, "key", input)
		require.Equal(t, "item1", result["key.1"])
		require.Equal(t, "item2", result["key.2"])
		require.Equal(t, "item3", result["key.3"])
	})
	t.Run("nested map", func(t *testing.T) {
		result := make(map[string]interface{})
		input := map[string]interface{}{
			"level1": map[string]interface{}{
				"level2": "value",
			},
		}
		processObject(result, "key", input)
		require.Equal(t, "value", result["key.level1.level2"])
	})
	t.Run("nested array", func(t *testing.T) {
		result := make(map[string]interface{})
		input := []interface{}{
			[]interface{}{"nested1", "nested2"},
		}
		processObject(result, "key", input)
		require.Equal(t, "nested1", result["key.1.1"])
		require.Equal(t, "nested2", result["key.1.2"])
	})
	t.Run("key with leading dot", func(t *testing.T) {
		// A leading "." (empty root key) is stripped from the flat key.
		result := make(map[string]interface{})
		processObject(result, ".key", "value")
		require.Equal(t, "value", result["key"])
	})
	t.Run("byte array value", func(t *testing.T) {
		// []byte is stored as its string form, not fmt's byte-slice rendering.
		result := make(map[string]interface{})
		input := []byte("test")
		processObject(result, "key", input)
		require.Equal(t, "test", result["key"])
	})
	t.Run("complex nested structure", func(t *testing.T) {
		result := make(map[string]interface{})
		input := map[string]interface{}{
			"array": []interface{}{
				map[string]interface{}{
					"item": "value",
				},
			},
		}
		processObject(result, "key", input)
		require.Equal(t, "value", result["key.array.1.item"])
	})
}
// TestFormDataToString checks that nested form data is flattened to dotted
// keys and URL-encoded, and that empty maps / nil values are handled.
func TestFormDataToString(t *testing.T) {
	t.Run("simple map", func(t *testing.T) {
		input := map[string]interface{}{
			"key1": "value1",
			"key2": "value2",
		}
		result := formDataToString(input)
		require.NotNil(t, result)
		require.Contains(t, *result, "key1=value1")
		require.Contains(t, *result, "key2=value2")
	})
	t.Run("map with array", func(t *testing.T) {
		// List entries become 1-based dotted keys.
		input := map[string]interface{}{
			"key": []interface{}{"item1", "item2"},
		}
		result := formDataToString(input)
		require.NotNil(t, result)
		require.Contains(t, *result, "key.1=item1")
		require.Contains(t, *result, "key.2=item2")
	})
	t.Run("map with nested map", func(t *testing.T) {
		input := map[string]interface{}{
			"key": map[string]interface{}{
				"subkey": "value",
			},
		}
		result := formDataToString(input)
		require.NotNil(t, result)
		require.Contains(t, *result, "key.subkey=value")
	})
	t.Run("empty map", func(t *testing.T) {
		input := map[string]interface{}{}
		result := formDataToString(input)
		require.NotNil(t, result)
		require.Empty(t, *result)
	})
	t.Run("map with nil value", func(t *testing.T) {
		// nil entries are dropped by processObject before encoding.
		input := map[string]interface{}{
			"key1": "value1",
			"key2": nil,
		}
		result := formDataToString(input)
		require.NotNil(t, result)
		require.Contains(t, *result, "key1=value1")
		require.NotContains(t, *result, "key2")
	})
}
// TestGenerateRequestForText verifies the path, signed headers and
// form-encoded body produced for the text moderation API, including the
// STS-token variant and empty-content edge case.
func TestGenerateRequestForText(t *testing.T) {
	config := cfg.AISecurityConfig{
		Host:  "security.example.com",
		AK:    "test-ak",
		SK:    "test-sk",
		Token: "",
	}
	t.Run("basic text request", func(t *testing.T) {
		path, headers, body := GenerateRequestForText(
			config,
			"TextModerationPlus",
			"llm_query_moderation",
			"test content",
			"test-session-id",
		)
		require.NotEmpty(t, path)
		require.True(t, strings.HasPrefix(path, "?"))
		require.Contains(t, path, "Service=llm_query_moderation")
		require.NotEmpty(t, headers)
		headerMap := make(map[string]string)
		for _, h := range headers {
			headerMap[h[0]] = h[1]
		}
		require.Equal(t, "TextModerationPlus", headerMap["x-acs-action"])
		require.Equal(t, "2022-03-02", headerMap["x-acs-version"])
		require.Equal(t, "application/x-www-form-urlencoded", headerMap["content-type"])
		require.Equal(t, cfg.AliyunUserAgent, headerMap["User-Agent"])
		require.Contains(t, headerMap, "Authorization")
		require.Contains(t, headerMap, "x-acs-date")
		require.Contains(t, headerMap, "x-acs-signature-nonce")
		require.Contains(t, headerMap, "x-acs-content-sha256")
		require.NotEmpty(t, body)
		bodyStr := string(body)
		require.Contains(t, bodyStr, "ServiceParameters")
		// Body is URL encoded, so decode it to check content
		decodedBody, err := url.QueryUnescape(bodyStr)
		require.NoError(t, err)
		require.Contains(t, decodedBody, "test content")
		require.Contains(t, decodedBody, "test-session-id")
		require.Contains(t, decodedBody, cfg.AliyunUserAgent)
	})
	t.Run("request with security token", func(t *testing.T) {
		// A non-empty token must surface as the x-acs-security-token header.
		configWithToken := config
		configWithToken.Token = "test-token"
		path, headers, body := GenerateRequestForText(
			configWithToken,
			"TextModerationPlus",
			"llm_query_moderation",
			"test content",
			"test-session-id",
		)
		require.NotEmpty(t, path)
		require.NotEmpty(t, headers)
		headerMap := make(map[string]string)
		for _, h := range headers {
			headerMap[h[0]] = h[1]
		}
		require.Equal(t, "test-token", headerMap["x-acs-security-token"])
		require.NotEmpty(t, body)
	})
	t.Run("empty content", func(t *testing.T) {
		// Empty text still produces a well-formed body with an empty content field.
		path, headers, body := GenerateRequestForText(
			config,
			"TextModerationPlus",
			"llm_query_moderation",
			"",
			"test-session-id",
		)
		require.NotEmpty(t, path)
		require.NotEmpty(t, headers)
		require.NotEmpty(t, body)
		bodyStr := string(body)
		require.Contains(t, bodyStr, "ServiceParameters")
		decodedBody, err := url.QueryUnescape(bodyStr)
		require.NoError(t, err)
		require.Contains(t, decodedBody, `"content":""`)
	})
	t.Run("different check service", func(t *testing.T) {
		path, headers, body := GenerateRequestForText(
			config,
			"TextModerationPlus",
			"llm_response_moderation",
			"test content",
			"test-session-id",
		)
		require.Contains(t, path, "Service=llm_response_moderation")
		require.NotEmpty(t, headers)
		require.NotEmpty(t, body)
	})
}
// TestGenerateRequestForImage verifies the path, signed headers and body for
// the image moderation API across URL, base64, combined, token and empty
// inputs.
func TestGenerateRequestForImage(t *testing.T) {
	config := cfg.AISecurityConfig{
		Host:  "security.example.com",
		AK:    "test-ak",
		SK:    "test-sk",
		Token: "",
	}
	t.Run("image request with URL", func(t *testing.T) {
		path, headers, body := GenerateRequestForImage(
			config,
			"MultiModalGuard",
			"llm_image_moderation",
			"https://example.com/image.jpg",
			"",
		)
		require.NotEmpty(t, path)
		require.True(t, strings.HasPrefix(path, "?"))
		require.Contains(t, path, "Service=llm_image_moderation")
		require.NotEmpty(t, headers)
		headerMap := make(map[string]string)
		for _, h := range headers {
			headerMap[h[0]] = h[1]
		}
		require.Equal(t, "MultiModalGuard", headerMap["x-acs-action"])
		require.Equal(t, "2022-03-02", headerMap["x-acs-version"])
		require.Equal(t, "application/x-www-form-urlencoded", headerMap["content-type"])
		require.Equal(t, cfg.AliyunUserAgent, headerMap["User-Agent"])
		require.Contains(t, headerMap, "Authorization")
		require.Contains(t, headerMap, "x-acs-date")
		require.Contains(t, headerMap, "x-acs-signature-nonce")
		require.Contains(t, headerMap, "x-acs-content-sha256")
		require.NotEmpty(t, body)
		bodyStr := string(body)
		require.Contains(t, bodyStr, "ServiceParameters")
		decodedBody, err := url.QueryUnescape(bodyStr)
		require.NoError(t, err)
		require.Contains(t, decodedBody, "https://example.com/image.jpg")
		require.Contains(t, decodedBody, cfg.AliyunUserAgent)
	})
	t.Run("image request with base64", func(t *testing.T) {
		// Base64 images travel as the top-level ImageBase64Str form field.
		base64Data := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
		path, headers, body := GenerateRequestForImage(
			config,
			"MultiModalGuard",
			"llm_image_moderation",
			"",
			base64Data,
		)
		require.NotEmpty(t, path)
		require.NotEmpty(t, headers)
		require.NotEmpty(t, body)
		bodyStr := string(body)
		require.Contains(t, bodyStr, "ImageBase64Str")
		// Base64 data is URL encoded, decode to check
		decodedBody, err := url.QueryUnescape(bodyStr)
		require.NoError(t, err)
		require.Contains(t, decodedBody, base64Data)
	})
	t.Run("image request with both URL and base64", func(t *testing.T) {
		path, headers, body := GenerateRequestForImage(
			config,
			"MultiModalGuard",
			"llm_image_moderation",
			"https://example.com/image.jpg",
			"base64data",
		)
		require.NotEmpty(t, path)
		require.NotEmpty(t, headers)
		require.NotEmpty(t, body)
		bodyStr := string(body)
		require.Contains(t, bodyStr, "ImageBase64Str")
		decodedBody, err := url.QueryUnescape(bodyStr)
		require.NoError(t, err)
		require.Contains(t, decodedBody, "https://example.com/image.jpg")
		require.Contains(t, decodedBody, "base64data")
	})
	t.Run("image request with security token", func(t *testing.T) {
		configWithToken := config
		configWithToken.Token = "test-token"
		path, headers, body := GenerateRequestForImage(
			configWithToken,
			"MultiModalGuard",
			"llm_image_moderation",
			"https://example.com/image.jpg",
			"",
		)
		require.NotEmpty(t, path)
		require.NotEmpty(t, headers)
		headerMap := make(map[string]string)
		for _, h := range headers {
			headerMap[h[0]] = h[1]
		}
		require.Equal(t, "test-token", headerMap["x-acs-security-token"])
		require.NotEmpty(t, body)
	})
	t.Run("empty image URL and base64", func(t *testing.T) {
		// With neither input, the body omits both image fields entirely.
		path, headers, body := GenerateRequestForImage(
			config,
			"MultiModalGuard",
			"llm_image_moderation",
			"",
			"",
		)
		require.NotEmpty(t, path)
		require.NotEmpty(t, headers)
		require.NotEmpty(t, body)
		bodyStr := string(body)
		require.Contains(t, bodyStr, "ServiceParameters")
		decodedBody, err := url.QueryUnescape(bodyStr)
		require.NoError(t, err)
		require.Contains(t, decodedBody, cfg.AliyunUserAgent)
		require.NotContains(t, decodedBody, "imageUrls")
		require.NotContains(t, decodedBody, "ImageBase64Str")
	})
}
// TestNewRequest checks the headers that newRequest installs, observed
// through the public GenerateRequestForText API since newRequest is private.
func TestNewRequest(t *testing.T) {
	// Test newRequest indirectly through GenerateRequestForText
	// Since it's a private function, we test it through public API
	t.Run("request structure", func(t *testing.T) {
		config := cfg.AISecurityConfig{
			Host:  "security.example.com",
			AK:    "test-ak",
			SK:    "test-sk",
			Token: "",
		}
		path, headers, _ := GenerateRequestForText(
			config,
			"TextModerationPlus",
			"llm_query_moderation",
			"test",
			"session-id",
		)
		// Verify that newRequest was called correctly by checking headers
		headerMap := make(map[string]string)
		for _, h := range headers {
			headerMap[h[0]] = h[1]
		}
		// Verify headers set by newRequest
		require.Equal(t, "TextModerationPlus", headerMap["x-acs-action"])
		require.Equal(t, "2022-03-02", headerMap["x-acs-version"])
		require.Contains(t, headerMap, "x-acs-date")
		require.Contains(t, headerMap, "x-acs-signature-nonce")
		require.NotEmpty(t, path)
	})
}
// TestGetAuthorization checks the ACS3-HMAC-SHA256 Authorization header
// format, STS token handling, and that the signature format is stable across
// calls — all observed through the public GenerateRequestForText API since
// getAuthorization is private.
func TestGetAuthorization(t *testing.T) {
	// Test getAuthorization indirectly through GenerateRequestForText
	// Since it's a private function, we test it through public API
	t.Run("authorization header format", func(t *testing.T) {
		config := cfg.AISecurityConfig{
			Host:  "security.example.com",
			AK:    "test-ak",
			SK:    "test-sk",
			Token: "",
		}
		_, headers, _ := GenerateRequestForText(
			config,
			"TextModerationPlus",
			"llm_query_moderation",
			"test content",
			"test-session-id",
		)
		headerMap := make(map[string]string)
		for _, h := range headers {
			headerMap[h[0]] = h[1]
		}
		authHeader := headerMap["Authorization"]
		require.NotEmpty(t, authHeader)
		require.Contains(t, authHeader, "ACS3-HMAC-SHA256")
		require.Contains(t, authHeader, "Credential=test-ak")
		require.Contains(t, authHeader, "SignedHeaders=")
		require.Contains(t, authHeader, "Signature=")
		// Verify content SHA256 is set
		require.Contains(t, headerMap, "x-acs-content-sha256")
		require.Len(t, headerMap["x-acs-content-sha256"], 64) // SHA256 hex string length
	})
	t.Run("authorization with security token", func(t *testing.T) {
		config := cfg.AISecurityConfig{
			Host:  "security.example.com",
			AK:    "test-ak",
			SK:    "test-sk",
			Token: "test-token",
		}
		_, headers, _ := GenerateRequestForText(
			config,
			"TextModerationPlus",
			"llm_query_moderation",
			"test content",
			"test-session-id",
		)
		headerMap := make(map[string]string)
		for _, h := range headers {
			headerMap[h[0]] = h[1]
		}
		require.Equal(t, "test-token", headerMap["x-acs-security-token"])
		require.Contains(t, headerMap, "Authorization")
	})
	t.Run("authorization signature consistency", func(t *testing.T) {
		config := cfg.AISecurityConfig{
			Host:  "security.example.com",
			AK:    "test-ak",
			SK:    "test-sk",
			Token: "",
		}
		// Generate two requests with same content
		_, headers1, body1 := GenerateRequestForText(
			config,
			"TextModerationPlus",
			"llm_query_moderation",
			"test content",
			"test-session-id",
		)
		_, headers2, body2 := GenerateRequestForText(
			config,
			"TextModerationPlus",
			"llm_query_moderation",
			"test content",
			"test-session-id",
		)
		// Bodies should be the same (except for sessionId which is random)
		require.NotEmpty(t, body1)
		require.NotEmpty(t, body2)
		// Headers should have authorization
		headerMap1 := make(map[string]string)
		for _, h := range headers1 {
			headerMap1[h[0]] = h[1]
		}
		headerMap2 := make(map[string]string)
		for _, h := range headers2 {
			headerMap2[h[0]] = h[1]
		}
		require.Contains(t, headerMap1, "Authorization")
		require.Contains(t, headerMap2, "Authorization")
		// Signatures will be different due to nonce and timestamp, but format should be same
		require.Contains(t, headerMap1["Authorization"], "ACS3-HMAC-SHA256")
		require.Contains(t, headerMap2["Authorization"], "ACS3-HMAC-SHA256")
	})
}

View File

@@ -0,0 +1,250 @@
package text
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
// HandleTextGenerationResponseHeader initializes per-response bookkeeping and
// picks a body-inspection strategy from the response content type: SSE
// ("text/event-stream") responses are paused and checked chunk by chunk, while
// everything else is fully buffered before the safety check runs.
func HandleTextGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
	// Reset the stream-state flags consumed by the body-phase handlers.
	ctx.SetContext("end_of_stream_received", false)
	ctx.SetContext("during_call", false)
	ctx.SetContext("risk_detected", false)
	id, _ := utils.GenerateHexID(20)
	ctx.SetContext("sessionID", id)
	respContentType, _ := proxywasm.GetHttpResponseHeader("content-type")
	if !strings.Contains(respContentType, "text/event-stream") {
		// Non-streaming: buffer the whole body and hold the headers until the
		// check finishes.
		ctx.BufferResponseBody()
		return types.HeaderStopIteration
	}
	ctx.NeedPauseStreamingResponse()
	return types.ActionContinue
}
// HandleTextGenerationStreamingResponseBody inspects an SSE completion stream
// chunk by chunk. Incoming SSE events are pushed onto the context buffer
// queue; once enough text accumulates (config.BufferLimit) or the stream
// ends, the text is sent to the content-safety service. Accepted chunks are
// re-injected into the filter chain; a rejected chunk replaces the remainder
// of the stream with a deny message.
//
// Context flags shared across invocations:
//   - "end_of_stream_received": upstream has finished sending.
//   - "during_call": a safety-check call is in flight (serializes calls).
//   - "risk_detected": set elsewhere once a risk is found — presumably by the
//     injected-deny path's downstream handling; not written in this file.
func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte {
	consumer, _ := ctx.GetContext("consumer").(string)
	// Reuse the session ID created at header time; create one lazily if the
	// header phase did not run.
	var sessionID string
	if ctx.GetContext("sessionID") == nil {
		sessionID, _ = utils.GenerateHexID(20)
		ctx.SetContext("sessionID", sessionID)
	} else {
		sessionID, _ = ctx.GetContext("sessionID").(string)
	}
	// bufferQueue holds the raw SSE chunks belonging to the in-flight check so
	// they can be forwarded verbatim once the check passes.
	var bufferQueue [][]byte
	var singleCall func()
	// callback handles one safety-service response. callback and singleCall
	// are mutually recursive: each completed check may trigger the next.
	callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
		log.Info(string(responseBody))
		// Transport or service-level failure: fail open — resume the stream if
		// it already ended, and allow the next check to start.
		if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
			if ctx.GetContext("end_of_stream_received").(bool) {
				proxywasm.ResumeHttpResponse()
			}
			ctx.SetContext("during_call", false)
			return
		}
		var response cfg.Response
		err := json.Unmarshal(responseBody, &response)
		if err != nil {
			log.Error("failed to unmarshal aliyun content security response at response phase")
			if ctx.GetContext("end_of_stream_received").(bool) {
				proxywasm.ResumeHttpResponse()
			}
			ctx.SetContext("during_call", false)
			return
		}
		if !cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
			// Risk found: replace the remainder of the stream with a deny
			// message formatted as an OpenAI SSE completion.
			denyMessage := cfg.DefaultDenyMessage
			// NOTE(review): Advice[0] is indexed behind a bare nil check; a
			// non-nil empty Advice slice would panic — confirm the service
			// never returns an empty (non-nil) Advice array.
			if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
				denyMessage = "\n" + response.Data.Advice[0].Answer
			} else if config.DenyMessage != "" {
				denyMessage = config.DenyMessage
			}
			marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
			randomID := utils.GenerateRandomChatID()
			jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
			proxywasm.InjectEncodedDataToFilterChain(jsonData, true)
			return
		}
		// Check passed: forward the buffered chunks; end the stream only when
		// upstream finished and nothing is left queued.
		endStream := ctx.GetContext("end_of_stream_received").(bool) && ctx.BufferQueueSize() == 0
		proxywasm.InjectEncodedDataToFilterChain(bytes.Join(bufferQueue, []byte("")), endStream)
		bufferQueue = [][]byte{}
		if !endStream {
			ctx.SetContext("during_call", false)
			singleCall()
		}
	}
	// singleCall drains up to BufferLimit characters of message text from the
	// buffer queue and submits it to the safety service, if no call is already
	// in flight.
	singleCall = func() {
		if ctx.GetContext("during_call").(bool) {
			return
		}
		if ctx.BufferQueueSize() >= config.BufferLimit || ctx.GetContext("end_of_stream_received").(bool) {
			var buffer string
			for ctx.BufferQueueSize() > 0 {
				front := ctx.PopBuffer()
				bufferQueue = append(bufferQueue, front)
				msg := gjson.GetBytes(front, config.ResponseStreamContentJsonPath).String()
				buffer += msg
				// Stop once enough visible text (in runes) is collected.
				if len([]rune(buffer)) >= config.BufferLimit {
					break
				}
			}
			// case 1: streaming body has reasoning_content, part of buffer maybe empty
			// case 2: streaming body has toolcall result, part of buffer maybe empty
			log.Debugf("current content piece: %s", buffer)
			if len(buffer) == 0 {
				// Placeholder so the service call carries a non-empty payload.
				buffer = "[empty content]"
			}
			ctx.SetContext("during_call", true)
			// NOTE(review): duplicate of the Debugf a few lines above — one of
			// the two could likely be dropped.
			log.Debugf("current content piece: %s", buffer)
			checkService := config.GetResponseCheckService(consumer)
			path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, buffer, sessionID)
			err := config.Client.Post(path, headers, body, callback, config.Timeout)
			if err != nil {
				// Fail open: resume the stream if upstream already finished.
				log.Errorf("failed call the safe check service: %v", err)
				if ctx.GetContext("end_of_stream_received").(bool) {
					proxywasm.ResumeHttpResponse()
				}
			}
		}
	}
	if !ctx.GetContext("risk_detected").(bool) {
		// Normalize the chunk's SSE line endings, then split it into events on
		// blank-line boundaries, preserving the separators we strip.
		unifiedChunk := wrapper.UnifySSEChunk(data)
		hasTrailingSeparator := bytes.HasSuffix(unifiedChunk, []byte("\n\n"))
		trimmedChunk := bytes.TrimSpace(unifiedChunk)
		chunks := bytes.Split(trimmedChunk, []byte("\n\n"))
		// Filter out empty chunks
		nonEmptyChunks := make([][]byte, 0, len(chunks))
		for _, chunk := range chunks {
			if len(chunk) > 0 {
				nonEmptyChunks = append(nonEmptyChunks, chunk)
			}
		}
		// Restore separators
		for i := range len(nonEmptyChunks) - 1 {
			nonEmptyChunks[i] = append(nonEmptyChunks[i], []byte("\n\n")...)
		}
		if hasTrailingSeparator && len(nonEmptyChunks) > 0 {
			nonEmptyChunks[len(nonEmptyChunks)-1] = append(nonEmptyChunks[len(nonEmptyChunks)-1], []byte("\n\n")...)
		}
		for _, chunk := range nonEmptyChunks {
			ctx.PushBuffer(chunk)
		}
		// for _, chunk := range bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n")) {
		// 	ctx.PushBuffer([]byte(string(chunk) + "\n\n"))
		// }
		ctx.SetContext("end_of_stream_received", endOfStream)
		if !ctx.GetContext("during_call").(bool) {
			singleCall()
		}
	} else if endOfStream {
		// A risk was already reported; drop the data and release the response.
		proxywasm.ResumeHttpResponse()
	}
	// Data is held back; forwarding happens via InjectEncodedDataToFilterChain.
	return []byte{}
}
// HandleTextGenerationResponseBody checks a fully-buffered completion
// response against the content-safety service. The extracted text is checked
// in cfg.LengthLimit-sized pieces through chained async calls: the response is
// resumed only after every piece passes, and is replaced with a deny message
// as soon as any piece is rejected. Transport/service errors fail open.
//
// Returns types.ActionPause while the async chain runs, or
// types.ActionContinue when there is nothing to check.
func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
	consumer, _ := ctx.GetContext("consumer").(string)
	log.Debugf("checking response body...")
	startTime := time.Now().UnixMilli()
	contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
	isStreamingResponse := strings.Contains(contentType, "event-stream")
	// Pull the model output text out of either an SSE replay or a plain JSON
	// body, depending on the upstream content type.
	var content string
	if isStreamingResponse {
		content = utils.ExtractMessageFromStreamingBody(body, config.ResponseStreamContentJsonPath)
	} else {
		content = gjson.GetBytes(body, config.ResponseContentJsonPath).String()
	}
	log.Debugf("Raw response content is: %s", content)
	if len(content) == 0 {
		log.Info("response content is empty. skip")
		return types.ActionContinue
	}
	contentIndex := 0
	sessionID, _ := utils.GenerateHexID(20)
	var singleCall func()
	callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
		log.Info(string(responseBody))
		// Fail open on transport errors or non-200 service codes.
		if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
			proxywasm.ResumeHttpResponse()
			return
		}
		var response cfg.Response
		err := json.Unmarshal(responseBody, &response)
		if err != nil {
			log.Error("failed to unmarshal aliyun content security response at response phase")
			proxywasm.ResumeHttpResponse()
			return
		}
		if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
			if contentIndex >= len(content) {
				// All pieces passed: record metrics and release the response.
				endTime := time.Now().UnixMilli()
				ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
				ctx.SetUserAttribute("safecheck_status", "response pass")
				ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
				proxywasm.ResumeHttpResponse()
			} else {
				singleCall()
			}
			return
		}
		// Risk detected: pick the deny message (config override first, then
		// the service-suggested answer, then the built-in default).
		denyMessage := cfg.DefaultDenyMessage
		if config.DenyMessage != "" {
			denyMessage = config.DenyMessage
		} else if len(response.Data.Advice) > 0 && response.Data.Advice[0].Answer != "" {
			// Length guard (was a bare nil check): a non-nil but empty Advice
			// slice must not cause an index-out-of-range panic.
			denyMessage = response.Data.Advice[0].Answer
		}
		marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
		if config.ProtocolOriginal {
			proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
		} else if isStreamingResponse {
			randomID := utils.GenerateRandomChatID()
			jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
			proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
		} else {
			randomID := utils.GenerateRandomChatID()
			jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
			proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
		}
		config.IncrementCounter("ai_sec_response_deny", 1)
		endTime := time.Now().UnixMilli()
		ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
		ctx.SetUserAttribute("safecheck_status", "response deny")
		// Guard on Result — the slice actually indexed below — instead of the
		// original's Advice nil check; the two slices can differ in length,
		// and the mismatched guard could panic on an empty Result.
		if len(response.Data.Result) > 0 {
			ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
			ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
		}
		ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
	}
	// singleCall submits the next LengthLimit-sized window of content.
	singleCall = func() {
		var nextContentIndex int
		if contentIndex+cfg.LengthLimit >= len(content) {
			nextContentIndex = len(content)
		} else {
			nextContentIndex = contentIndex + cfg.LengthLimit
		}
		contentPiece := content[contentIndex:nextContentIndex]
		contentIndex = nextContentIndex
		log.Debugf("current content piece: %s", contentPiece)
		checkService := config.GetResponseCheckService(consumer)
		path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, contentPiece, sessionID)
		err := config.Client.Post(path, headers, body, callback, config.Timeout)
		if err != nil {
			log.Errorf("failed call the safe check service: %v", err)
			proxywasm.ResumeHttpResponse()
		}
	}
	singleCall()
	return types.ActionPause
}

View File

@@ -0,0 +1,83 @@
package multi_modal_guard
import (
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
common_text "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/text"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
// OnHttpRequestHeaders is the request-header hook for the multi_modal_guard
// mode. No header-based checks are performed — all inspection happens in the
// body phases — so processing always continues.
func OnHttpRequestHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
	return types.ActionContinue
}
// OnHttpRequestBody routes the request body to the checker matching the
// configured API type and, for image generation, the provider type.
// Unsupported combinations are logged and passed through unchecked.
func OnHttpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
	if config.ApiType == cfg.ApiTextGeneration {
		return text.HandleTextGenerationRequestBody(ctx, config, body)
	}
	if config.ApiType != cfg.ApiImageGeneration {
		log.Errorf("[on request body] multi_modal_guard don't support api: %s", config.ApiType)
		return types.ActionContinue
	}
	switch config.ProviderType {
	case cfg.ProviderOpenAI:
		return image.HandleOpenAIImageGenerationRequestBody(ctx, config, body)
	case cfg.ProviderQwen:
		return image.HandleQwenImageGenerationRequestBody(ctx, config, body)
	}
	log.Errorf("[on request body] image generation api don't support provider: %s", config.ProviderType)
	return types.ActionContinue
}
// OnHttpResponseHeaders routes the response headers to the handler matching
// the configured API type; for image generation both OpenAI and Qwen
// providers share one header handler. Unsupported combinations are logged and
// passed through unchecked.
func OnHttpResponseHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
	if config.ApiType == cfg.ApiTextGeneration {
		return common_text.HandleTextGenerationResponseHeader(ctx, config)
	}
	if config.ApiType != cfg.ApiImageGeneration {
		log.Errorf("[on response header] multi_modal_guard don't support api: %s", config.ApiType)
		return types.ActionContinue
	}
	if config.ProviderType == cfg.ProviderOpenAI || config.ProviderType == cfg.ProviderQwen {
		return image.HandleImageGenerationResponseHeader(ctx, config)
	}
	log.Errorf("[on response header] image generation api don't support provider: %s", config.ProviderType)
	return types.ActionContinue
}
// OnHttpStreamingResponseBody forwards streaming response chunks to the text
// handler when the API type is text generation; any other API type is logged
// and its data passed through untouched.
func OnHttpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte {
	if config.ApiType == cfg.ApiTextGeneration {
		return common_text.HandleTextGenerationStreamingResponseBody(ctx, config, data, endOfStream)
	}
	log.Errorf("[on streaming response body] multi_modal_guard don't support api: %s", config.ApiType)
	return data
}
// OnHttpResponseBody routes the buffered response body to the checker
// matching the configured API type and, for image generation, the provider
// type. Unsupported combinations are logged and passed through unchecked.
func OnHttpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
	if config.ApiType == cfg.ApiTextGeneration {
		return common_text.HandleTextGenerationResponseBody(ctx, config, body)
	}
	if config.ApiType != cfg.ApiImageGeneration {
		log.Errorf("[on response body] multi_modal_guard don't support api: %s", config.ApiType)
		return types.ActionContinue
	}
	switch config.ProviderType {
	case cfg.ProviderOpenAI:
		return image.HandleOpenAIImageGenerationResponseBody(ctx, config, body)
	case cfg.ProviderQwen:
		return image.HandleQwenImageGenerationResponseBody(ctx, config, body)
	}
	log.Errorf("[on response body] image generation api don't support provider: %s", config.ProviderType)
	return types.ActionContinue
}

View File

@@ -0,0 +1,27 @@
package image
import (
"strings"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
// ImageItem is a single image reference extracted from a request or response
// body, normalized for submission to the content-safety check service.
type ImageItem struct {
	Content string // the image itself: a URL, or base64/data-URI payload
	Type    string // URL or BASE64
}
// HandleImageGenerationResponseHeader prepares the image-generation response
// for checking: SSE responses are skipped entirely (body not read), while
// plain responses are buffered and header iteration is held until the body
// check completes.
func HandleImageGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
	ctx.SetContext("risk_detected", false)
	respContentType, _ := proxywasm.GetHttpResponseHeader("content-type")
	if !strings.Contains(respContentType, "text/event-stream") {
		ctx.BufferResponseBody()
		return types.HeaderStopIteration
	}
	ctx.DontReadResponseBody()
	return types.ActionContinue
}

View File

@@ -0,0 +1,271 @@
package image
import (
"encoding/json"
"net/http"
"time"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
// parseOpenAIRequest extracts the moderation-relevant text from an
// OpenAI-format image-generation request. Only the "prompt" field is read;
// no inline images are extracted, so the returned slice is always nil.
func parseOpenAIRequest(body []byte) (text string, images []ImageItem) {
	prompt := gjson.GetBytes(body, "prompt").String()
	return prompt, nil
}
// parseOpenAIResponse collects every generated image from an OpenAI-format
// image-generation response: each entry of "data" may carry a "url" and/or a
// "b64_json" payload, and each non-empty one becomes an ImageItem.
// qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126
func parseOpenAIResponse(body []byte) []ImageItem {
	items := []ImageItem{}
	for _, entry := range gjson.GetBytes(body, "data").Array() {
		if link := entry.Get("url").String(); link != "" {
			items = append(items, ImageItem{Content: link, Type: "URL"})
		}
		if encoded := entry.Get("b64_json").String(); encoded != "" {
			items = append(items, ImageItem{Content: encoded, Type: "BASE64"})
		}
	}
	return items
}
// HandleOpenAIImageGenerationRequestBody checks an OpenAI-format
// image-generation request: first the prompt text is checked in
// cfg.LengthLimit-sized pieces, then (if enabled) each extracted image is
// checked one at a time. All service calls are asynchronous and chained via
// mutually-recursive closures; the request is paused until the chain either
// resumes it (pass / fail-open) or replaces it with a 403 deny response.
func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
	consumer, _ := ctx.GetContext("consumer").(string)
	checkService := config.GetRequestCheckService(consumer)
	checkImageService := config.GetRequestImageCheckService(consumer)
	startTime := time.Now().UnixMilli()
	content, images := parseOpenAIRequest(body)
	log.Debugf("Raw request content is: %s", content)
	if len(content) == 0 && len(images) == 0 {
		log.Info("request content is empty. skip")
		return types.ActionContinue
	}
	// Progress cursors shared by the callback/call closures below.
	contentIndex := 0
	imageIndex := 0
	sessionID, _ := utils.GenerateHexID(20)
	var singleCall func()
	var singleCallForImage func()
	// callback handles one text-check result; on pass it advances to the next
	// text piece, then to image checks (if any), and finally resumes.
	callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
		log.Info(string(responseBody))
		// Fail open on transport errors or non-200 service codes.
		if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
			proxywasm.ResumeHttpRequest()
			return
		}
		var response cfg.Response
		err := json.Unmarshal(responseBody, &response)
		if err != nil {
			log.Errorf("%+v", err)
			proxywasm.ResumeHttpRequest()
			return
		}
		if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
			if contentIndex >= len(content) {
				endTime := time.Now().UnixMilli()
				ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
				ctx.SetUserAttribute("safecheck_status", "request pass")
				ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
				// Text passed; move on to image checks when configured.
				if len(images) > 0 && config.CheckRequestImage {
					singleCallForImage()
				} else {
					proxywasm.ResumeHttpRequest()
				}
			} else {
				singleCall()
			}
			return
		}
		// Risk detected: reply 403 with the deny message and skip upstream.
		denyMessage := cfg.DefaultDenyMessage
		if config.DenyMessage != "" {
			denyMessage = config.DenyMessage
		} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
			// NOTE(review): Advice[0] indexed behind a bare nil check — a
			// non-nil empty Advice slice would panic; confirm the service
			// contract.
			denyMessage = response.Data.Advice[0].Answer
		}
		marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
		proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
		ctx.DontReadResponseBody()
		config.IncrementCounter("ai_sec_request_deny", 1)
		endTime := time.Now().UnixMilli()
		ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
		// NOTE(review): "reqeust" is a typo for "request"; fixing it may
		// affect log consumers that match the literal value.
		ctx.SetUserAttribute("safecheck_status", "reqeust deny")
		// NOTE(review): guard checks Advice but indexes Result[0] — can panic
		// if Result is empty while Advice is not; confirm with the API schema.
		if response.Data.Advice != nil {
			ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
			ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
		}
		ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
	}
	// singleCall submits the next LengthLimit-sized window of prompt text.
	singleCall = func() {
		var nextContentIndex int
		if contentIndex+cfg.LengthLimit >= len(content) {
			nextContentIndex = len(content)
		} else {
			nextContentIndex = contentIndex + cfg.LengthLimit
		}
		contentPiece := content[contentIndex:nextContentIndex]
		contentIndex = nextContentIndex
		log.Debugf("current content piece: %s", contentPiece)
		path, headers, body := common.GenerateRequestForText(config, cfg.MultiModalGuard, checkService, contentPiece, sessionID)
		err := config.Client.Post(path, headers, body, callback, config.Timeout)
		if err != nil {
			log.Errorf("failed call the safe check service: %v", err)
			proxywasm.ResumeHttpRequest()
		}
	}
	// callbackForImage handles one image-check result; failures fail open to
	// the next image (or resume when done).
	callbackForImage = func(statusCode int, responseHeaders http.Header, responseBody []byte) {
		imageIndex += 1
		log.Info(string(responseBody))
		if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
			if imageIndex < len(images) {
				singleCallForImage()
			} else {
				proxywasm.ResumeHttpRequest()
			}
			return
		}
		var response cfg.Response
		err := json.Unmarshal(responseBody, &response)
		if err != nil {
			log.Errorf("%+v", err)
			if imageIndex < len(images) {
				singleCallForImage()
			} else {
				proxywasm.ResumeHttpRequest()
			}
			return
		}
		endTime := time.Now().UnixMilli()
		if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
			if imageIndex >= len(images) {
				ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
				ctx.SetUserAttribute("safecheck_status", "request pass")
				ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
				proxywasm.ResumeHttpRequest()
			} else {
				singleCallForImage()
			}
			return
		}
		// Risky image: reply 403 with the deny message.
		denyMessage := cfg.DefaultDenyMessage
		if config.DenyMessage != "" {
			denyMessage = config.DenyMessage
		} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
			denyMessage = response.Data.Advice[0].Answer
		}
		marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
		proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
		ctx.DontReadResponseBody()
		config.IncrementCounter("ai_sec_request_deny", 1)
		ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
		// NOTE(review): same "reqeust" typo as above.
		ctx.SetUserAttribute("safecheck_status", "reqeust deny")
		// NOTE(review): same Advice/Result guard mismatch as above.
		if response.Data.Advice != nil {
			ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
			ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
		}
		ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
	}
	// singleCallForImage submits the image at imageIndex, as URL or base64.
	singleCallForImage = func() {
		img := images[imageIndex]
		imgUrl := ""
		imgBase64 := ""
		if img.Type == "BASE64" {
			imgBase64 = img.Content
		} else {
			imgUrl = img.Content
		}
		path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, imgBase64)
		err := config.Client.Post(path, headers, body, callbackForImage, config.Timeout)
		if err != nil {
			log.Errorf("failed call the safe check service: %v", err)
			proxywasm.ResumeHttpRequest()
		}
	}
	// Kick off the chain: text first when present, otherwise images.
	if len(content) > 0 {
		singleCall()
	} else {
		singleCallForImage()
	}
	return types.ActionPause
}
// HandleOpenAIImageGenerationResponseBody checks every image (URL or base64)
// found in an OpenAI-format image-generation response. Images are checked one
// at a time through chained async calls; failed check calls fail open (skip
// to the next image or resume), and a single rejected image replaces the
// whole response with a 403 "illegal image" reply.
//
// NOTE(review): the user attributes ("safecheck_request_rt",
// "request pass"/"request deny") and the counter ("ai_sec_request_deny") use
// request-phase names even though this runs in the response phase —
// presumably copy-paste; confirm before renaming, as dashboards may rely on
// the existing keys.
func HandleOpenAIImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
	consumer, _ := ctx.GetContext("consumer").(string)
	log.Debugf("checking response body...")
	checkImageService := config.GetResponseImageCheckService(consumer)
	startTime := time.Now().UnixMilli()
	imgResults := parseOpenAIResponse(body)
	if len(imgResults) == 0 {
		return types.ActionContinue
	}
	imageIndex := 0
	var singleCall func()
	// callback handles one image-check result and advances the chain.
	callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
		imageIndex += 1
		log.Info(string(responseBody))
		// Fail open on transport errors or non-200 service codes.
		if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
			if imageIndex < len(imgResults) {
				singleCall()
			} else {
				proxywasm.ResumeHttpResponse()
			}
			return
		}
		var response cfg.Response
		err := json.Unmarshal(responseBody, &response)
		if err != nil {
			log.Errorf("%+v", err)
			if imageIndex < len(imgResults) {
				singleCall()
			} else {
				proxywasm.ResumeHttpResponse()
			}
			return
		}
		endTime := time.Now().UnixMilli()
		if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
			if imageIndex >= len(imgResults) {
				// Every image passed: record metrics and release the response.
				ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
				ctx.SetUserAttribute("safecheck_status", "request pass")
				ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
				proxywasm.ResumeHttpResponse()
			} else {
				singleCall()
			}
			return
		}
		// Risky image: replace the upstream response outright.
		proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte("illegal image"), -1)
		config.IncrementCounter("ai_sec_request_deny", 1)
		ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
		// Fixed typo: was "reqeust deny".
		ctx.SetUserAttribute("safecheck_status", "request deny")
		ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
	}
	// singleCall submits the image at imageIndex, as URL or base64.
	singleCall = func() {
		img := imgResults[imageIndex]
		imgUrl := ""
		imgBase64 := ""
		if img.Type == "BASE64" {
			imgBase64 = img.Content
		} else {
			imgUrl = img.Content
		}
		path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, imgBase64)
		err := config.Client.Post(path, headers, body, callback, config.Timeout)
		if err != nil {
			log.Errorf("failed call the safe check service: %v", err)
			proxywasm.ResumeHttpResponse()
		}
	}
	singleCall()
	return types.ActionPause
}

View File

@@ -0,0 +1,429 @@
package image
import (
"encoding/json"
"net/http"
"strings"
"time"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
// parseImage extracts a single image reference at jsonPath from the body.
// Values starting with "data:image" are classified as BASE64 payloads,
// anything else as a URL. Returns nil when the path does not exist.
func parseImage(body []byte, jsonPath string) *ImageItem {
	field := gjson.GetBytes(body, jsonPath)
	if !field.Exists() {
		return nil
	}
	content := field.String()
	kind := "URL"
	if strings.HasPrefix(content, "data:image") {
		kind = "BASE64"
	}
	return &ImageItem{Content: content, Type: kind}
}
// parseImageArray extracts every image reference from the JSON array at
// jsonPath. Each element is classified as BASE64 when it starts with
// "data:image" and as a URL otherwise. Returns an empty slice when the path
// does not exist.
func parseImageArray(body []byte, jsonPath string) []ImageItem {
	items := []ImageItem{}
	field := gjson.GetBytes(body, jsonPath)
	if !field.Exists() {
		return items
	}
	for _, entry := range field.Array() {
		content := entry.String()
		kind := "URL"
		if strings.HasPrefix(content, "data:image") {
			kind = "BASE64"
		}
		items = append(items, ImageItem{Content: content, Type: kind})
	}
	return items
}
// parseQwenRequest extracts all moderation-relevant text and image references
// from a Qwen image-generation request, covering the input layouts of the
// various Qwen image APIs. Returned text is the concatenation of every prompt
// field found; images are classified as BASE64 ("data:image" prefix) or URL.
// qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126
func parseQwenRequest(body []byte) (text string, images []ImageItem) {
	images = []ImageItem{}
	// text-to-image (v1/v2)
	if gjson.GetBytes(body, "input.prompt").Exists() {
		text += gjson.GetBytes(body, "input.prompt").String()
	}
	// image background generation
	if gjson.GetBytes(body, "input.ref_prompt").Exists() {
		text += gjson.GetBytes(body, "input.ref_prompt").String()
	}
	if gjson.GetBytes(body, "input.reference_edge.foreground_edge_prompt").Exists() {
		for _, item := range gjson.GetBytes(body, "input.reference_edge.foreground_edge_prompt").Array() {
			text += item.String()
		}
	}
	if gjson.GetBytes(body, "input.reference_edge.background_edge_prompt").Exists() {
		for _, item := range gjson.GetBytes(body, "input.reference_edge.background_edge_prompt").Array() {
			text += item.String()
		}
	}
	// creative word art
	if gjson.GetBytes(body, "input.text").Exists() {
		text += gjson.GetBytes(body, "input.text").String()
	}
	if gjson.GetBytes(body, "input.negative_prompt").Exists() {
		text += gjson.GetBytes(body, "input.negative_prompt").String()
	}
	// image editing (chat-style content: text parts and inline images)
	if gjson.GetBytes(body, "input.messages.0.content").Exists() {
		for _, item := range gjson.GetBytes(body, "input.messages.0.content").Array() {
			if item.Get("text").Exists() {
				text += item.Get("text").String()
			} else if item.Get("image").Exists() {
				imgContent := item.Get("image").String()
				if strings.HasPrefix(imgContent, "data:image") {
					images = append(images, ImageItem{
						Content: imgContent,
						Type:    "BASE64",
					})
				} else {
					images = append(images, ImageItem{
						Content: imgContent,
						Type:    "URL",
					})
				}
			}
		}
	}
	// Single-image JSON paths used by the various APIs.
	imageJsonPath := []string{
		"input.image_url",          // image translation / portrait style repaint / outpainting / person segmentation / erase & inpaint
		"input.base_image_url",     // general image editing 2.1 / local repaint / virtual model / background generation (deduplicated: was listed twice)
		"input.mask_image_url",     // general image editing 2.1 / local repaint / virtual model
		"input.sketch_image_url",   // sketch-to-image
		"input.template_image_url", // shoe model
		"input.shoe_image_url",     // shoe model
		"input.ref_image_url",      // image background generation
		"input.mask_url",           // image erase & inpaint
		"input.foreground_url",     // image erase & inpaint
		"input.person_image_url",   // AI try-on
		"input.top_garment_url",    // AI try-on
		"input.bottom_garment_url", // AI try-on
		"input.coarse_image_url",   // AI try-on
		"input.template_url",       // portrait photo generation
	}
	for _, jsonPath := range imageJsonPath {
		tmpImage := parseImage(body, jsonPath)
		if tmpImage != nil {
			images = append(images, *tmpImage)
		}
	}
	// Image-array JSON paths.
	imageArrayJsonPath := []string{
		"input.images",                         // general image editing 2.5 / person detection
		"input.reference_edge.foreground_edge", // image background generation
		"input.reference_edge.background_edge", // image background generation
		"input.user_urls",                      // portrait photo generation
	}
	for _, jsonPath := range imageArrayJsonPath {
		tmpImageArray := parseImageArray(body, jsonPath)
		images = append(images, tmpImageArray...)
	}
	return text, images
}
// parseQwenResponse collects every generated-image URL from a Qwen
// image-generation response body, covering the output layouts used by the
// different Qwen image APIs.
// qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126
func parseQwenResponse(body []byte) []string {
	urls := []string{}
	// add appends non-empty URLs only, keeping the original ordering.
	add := func(u string) {
		if u != "" {
			urls = append(urls, u)
		}
	}
	// text-to-image v1/v2, general image editing 2.5/2.1, sketch-to-image,
	// local repaint, portrait style repaint, virtual model, background
	// generation, FaceChain portraits, StableDiffusion/FLUX text-to-image,
	// text texture API
	for _, part := range gjson.GetBytes(body, "output.results").Array() {
		add(part.Get("url").String())
	}
	// image editing (chat-style output)
	for _, part := range gjson.GetBytes(body, "output.choices.0.message.content").Array() {
		add(part.Get("image").String())
	}
	// image translation / OutfitAnyone AI try-on
	add(gjson.GetBytes(body, "output.image_url").String())
	// image outpainting / (part of) person instance segmentation / erase & inpaint
	add(gjson.GetBytes(body, "output.output_image_url").String())
	// shoe model
	add(gjson.GetBytes(body, "output.result_url").String())
	// creative poster generation
	for _, part := range gjson.GetBytes(body, "output.render_urls").Array() {
		add(part.String())
	}
	for _, part := range gjson.GetBytes(body, "output.bg_urls").Array() {
		add(part.String())
	}
	// person instance segmentation (visualization image)
	add(gjson.GetBytes(body, "output.output_vis_image_url").String())
	// word-art deformation API (second pass over output.results)
	for _, part := range gjson.GetBytes(body, "output.results").Array() {
		add(part.Get("png_url").String())
		add(part.Get("svg_url").String())
	}
	return urls
}
// HandleQwenImageGenerationRequestBody checks a Qwen image-generation
// request: first the extracted prompt text is checked in
// cfg.LengthLimit-sized pieces, then (if enabled) each extracted image is
// checked one at a time. All service calls are asynchronous and chained via
// mutually-recursive closures; the request is paused until the chain either
// resumes it (pass / fail-open) or replaces it with a 403 deny response.
func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
	consumer, _ := ctx.GetContext("consumer").(string)
	checkService := config.GetRequestCheckService(consumer)
	checkImageService := config.GetRequestImageCheckService(consumer)
	startTime := time.Now().UnixMilli()
	// content := gjson.GetBytes(body, config.RequestContentJsonPath).String()
	content, images := parseQwenRequest(body)
	log.Debugf("Raw request content is: %s", content)
	if len(content) == 0 && len(images) == 0 {
		log.Info("request content is empty. skip")
		return types.ActionContinue
	}
	// Progress cursors shared by the callback/call closures below.
	contentIndex := 0
	imageIndex := 0
	sessionID, _ := utils.GenerateHexID(20)
	var singleCall func()
	var singleCallForImage func()
	// callback handles one text-check result; on pass it advances to the next
	// text piece, then to image checks (if any), and finally resumes.
	callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
		log.Info(string(responseBody))
		// Fail open on transport errors or non-200 service codes.
		if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
			proxywasm.ResumeHttpRequest()
			return
		}
		var response cfg.Response
		err := json.Unmarshal(responseBody, &response)
		if err != nil {
			log.Errorf("%+v", err)
			proxywasm.ResumeHttpRequest()
			return
		}
		if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
			if contentIndex >= len(content) {
				endTime := time.Now().UnixMilli()
				ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
				ctx.SetUserAttribute("safecheck_status", "request pass")
				ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
				// Text passed; move on to image checks when configured.
				if len(images) > 0 && config.CheckRequestImage {
					singleCallForImage()
				} else {
					proxywasm.ResumeHttpRequest()
				}
			} else {
				singleCall()
			}
			return
		}
		// Risk detected: reply 403 with the deny message and skip upstream.
		denyMessage := cfg.DefaultDenyMessage
		if config.DenyMessage != "" {
			denyMessage = config.DenyMessage
		} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
			// NOTE(review): Advice[0] indexed behind a bare nil check — a
			// non-nil empty Advice slice would panic; confirm the service
			// contract.
			denyMessage = response.Data.Advice[0].Answer
		}
		marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
		proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
		ctx.DontReadResponseBody()
		config.IncrementCounter("ai_sec_request_deny", 1)
		endTime := time.Now().UnixMilli()
		ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
		// NOTE(review): "reqeust" is a typo for "request"; fixing it may
		// affect log consumers that match the literal value.
		ctx.SetUserAttribute("safecheck_status", "reqeust deny")
		// NOTE(review): guard checks Advice but indexes Result[0] — can panic
		// if Result is empty while Advice is not; confirm with the API schema.
		if response.Data.Advice != nil {
			ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
			ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
		}
		ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
	}
	// singleCall submits the next LengthLimit-sized window of prompt text.
	singleCall = func() {
		var nextContentIndex int
		if contentIndex+cfg.LengthLimit >= len(content) {
			nextContentIndex = len(content)
		} else {
			nextContentIndex = contentIndex + cfg.LengthLimit
		}
		contentPiece := content[contentIndex:nextContentIndex]
		contentIndex = nextContentIndex
		log.Debugf("current content piece: %s", contentPiece)
		path, headers, body := common.GenerateRequestForText(config, cfg.MultiModalGuard, checkService, contentPiece, sessionID)
		err := config.Client.Post(path, headers, body, callback, config.Timeout)
		if err != nil {
			log.Errorf("failed call the safe check service: %v", err)
			proxywasm.ResumeHttpRequest()
		}
	}
	// callbackForImage handles one image-check result; failures fail open to
	// the next image (or resume when done).
	callbackForImage := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
		imageIndex += 1
		log.Info(string(responseBody))
		if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
			if imageIndex < len(images) {
				singleCallForImage()
			} else {
				proxywasm.ResumeHttpRequest()
			}
			return
		}
		var response cfg.Response
		err := json.Unmarshal(responseBody, &response)
		if err != nil {
			log.Errorf("%+v", err)
			if imageIndex < len(images) {
				singleCallForImage()
			} else {
				proxywasm.ResumeHttpRequest()
			}
			return
		}
		endTime := time.Now().UnixMilli()
		if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
			if imageIndex >= len(images) {
				ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
				ctx.SetUserAttribute("safecheck_status", "request pass")
				ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
				proxywasm.ResumeHttpRequest()
			} else {
				singleCallForImage()
			}
			return
		}
		// Risky image: reply 403 with the deny message.
		denyMessage := cfg.DefaultDenyMessage
		if config.DenyMessage != "" {
			denyMessage = config.DenyMessage
		} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
			denyMessage = response.Data.Advice[0].Answer
		}
		marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
		proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
		ctx.DontReadResponseBody()
		config.IncrementCounter("ai_sec_request_deny", 1)
		ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
		// NOTE(review): same "reqeust" typo as above.
		ctx.SetUserAttribute("safecheck_status", "reqeust deny")
		// NOTE(review): same Advice/Result guard mismatch as above.
		if response.Data.Advice != nil {
			ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
			ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
		}
		ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
	}
	// singleCallForImage submits the image at imageIndex, as URL or base64.
	singleCallForImage = func() {
		img := images[imageIndex]
		imgUrl := ""
		imgBase64 := ""
		if img.Type == "BASE64" {
			imgBase64 = img.Content
		} else {
			imgUrl = img.Content
		}
		path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, imgBase64)
		err := config.Client.Post(path, headers, body, callbackForImage, config.Timeout)
		if err != nil {
			log.Errorf("failed call the safe check service: %v", err)
			proxywasm.ResumeHttpRequest()
		}
	}
	// Kick off the chain: text first when present, otherwise images.
	if len(content) > 0 {
		singleCall()
	} else {
		singleCallForImage()
	}
	return types.ActionPause
}
// HandleQwenImageGenerationResponseBody moderates the images referenced by a
// Qwen image-generation response. Each image URL extracted from the response
// is checked sequentially against the configured image moderation service;
// the response is released only after every image passes, and a 403 deny
// response replaces it as soon as any image is judged unacceptable.
//
// Returns types.ActionContinue when the response carries no images, otherwise
// types.ActionPause while the asynchronous checks are in flight.
func HandleQwenImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
	consumer, _ := ctx.GetContext("consumer").(string)
	log.Debugf("checking response body...")
	checkImageService := config.GetResponseImageCheckService(consumer)
	startTime := time.Now().UnixMilli()
	imgUrls := parseQwenResponse(body)
	if len(imgUrls) == 0 {
		return types.ActionContinue
	}
	// imageIndex counts images already submitted for checking.
	imageIndex := 0
	var singleCall func()
	callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
		imageIndex += 1
		log.Info(string(responseBody))
		// Moderation-service failure: fail open, continuing with any
		// remaining images before resuming the response.
		if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
			if imageIndex < len(imgUrls) {
				singleCall()
			} else {
				proxywasm.ResumeHttpResponse()
			}
			return
		}
		var response cfg.Response
		err := json.Unmarshal(responseBody, &response)
		if err != nil {
			log.Errorf("%+v", err)
			if imageIndex < len(imgUrls) {
				singleCall()
			} else {
				proxywasm.ResumeHttpResponse()
			}
			return
		}
		endTime := time.Now().UnixMilli()
		if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
			if imageIndex >= len(imgUrls) {
				// All images passed: record audit attributes and release.
				ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
				ctx.SetUserAttribute("safecheck_status", "request pass")
				ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
				proxywasm.ResumeHttpResponse()
			} else {
				singleCall()
			}
			return
		}
		// Risk unacceptable: pick the deny text — explicit config first, then
		// the service's advice, falling back to the default.
		denyMessage := cfg.DefaultDenyMessage
		if config.DenyMessage != "" {
			denyMessage = config.DenyMessage
		} else if len(response.Data.Advice) > 0 && response.Data.Advice[0].Answer != "" {
			// Guard on length, not nil: a non-nil empty Advice slice would
			// panic on Advice[0].
			denyMessage = response.Data.Advice[0].Answer
		}
		marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
		proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
		config.IncrementCounter("ai_sec_request_deny", 1)
		ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
		ctx.SetUserAttribute("safecheck_status", "request deny") // fixed typo: was "reqeust deny"
		ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
	}
	singleCall = func() {
		// Submit the next image URL to the moderation service.
		imgUrl := imgUrls[imageIndex]
		path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, "")
		err := config.Client.Post(path, headers, body, callback, config.Timeout)
		if err != nil {
			// Could not even reach the moderation service: fail open.
			log.Errorf("failed call the safe check service: %v", err)
			proxywasm.ResumeHttpResponse()
		}
	}
	singleCall()
	return types.ActionPause
}

View File

@@ -0,0 +1,231 @@
package text
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
// ImageItem is one image extracted from a multimodal request message part.
type ImageItem struct {
	Content string // the image payload: a remote URL or a data-URI (base64) string
	Type    string // URL or BASE64
}
// parseContent splits an OpenAI-style message content value into its plain
// text and its image parts. A bare string value is returned as text with no
// images; an array value is walked part by part, concatenating "text" parts
// and collecting "image_url" parts (data-URI payloads are tagged BASE64,
// everything else URL). The returned slice is always non-nil.
func parseContent(json gjson.Result) (string, []ImageItem) {
	images := make([]ImageItem, 0)
	if !json.IsArray() {
		// Plain string content: the whole value is text.
		return json.String(), images
	}
	var textBuilder strings.Builder
	for _, part := range json.Array() {
		switch part.Get("type").String() {
		case "text":
			textBuilder.WriteString(part.Get("text").String())
		case "image_url":
			url := part.Get("image_url.url").String()
			kind := "URL"
			if strings.HasPrefix(url, "data:image") {
				kind = "BASE64"
			}
			images = append(images, ImageItem{Content: url, Type: kind})
		}
	}
	return textBuilder.String(), images
}
// HandleTextGenerationRequestBody moderates an incoming request that may mix
// text and image content. The text is checked first, split into
// cfg.LengthLimit-sized pieces sent sequentially to the text moderation
// service; once all text passes (and image checking is enabled), each
// extracted image is checked sequentially as well. The request is resumed
// only after every piece passes, and a deny response is sent as soon as any
// piece is judged unacceptable.
//
// Returns types.ActionContinue when there is nothing to check, otherwise
// types.ActionPause while the asynchronous checks are in flight.
func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
	consumer, _ := ctx.GetContext("consumer").(string)
	checkService := config.GetRequestCheckService(consumer)
	checkImageService := config.GetRequestImageCheckService(consumer)
	startTime := time.Now().UnixMilli()
	content, images := parseContent(gjson.GetBytes(body, config.RequestContentJsonPath))
	log.Debugf("Raw request content is: %s", content)
	if len(content) == 0 && len(images) == 0 {
		log.Info("request content is empty. skip")
		return types.ActionContinue
	}
	// contentIndex walks the text in byte offsets; imageIndex counts images
	// already submitted for checking.
	contentIndex := 0
	imageIndex := 0
	// sessionID correlates the per-piece text checks on the moderation side.
	sessionID, _ := utils.GenerateHexID(20)
	var singleCall func()
	var singleCallForImage func()
	// buildDenyMessage picks the deny text: explicit config first, then the
	// moderation service's advice, falling back to the default.
	buildDenyMessage := func(response cfg.Response) string {
		if config.DenyMessage != "" {
			return config.DenyMessage
		}
		// Guard on length, not nil: a non-nil empty Advice slice would panic
		// on Advice[0].
		if len(response.Data.Advice) > 0 && response.Data.Advice[0].Answer != "" {
			return response.Data.Advice[0].Answer
		}
		return cfg.DefaultDenyMessage
	}
	// sendDenyResponse short-circuits the request with the deny payload,
	// formatted for the original protocol or as an OpenAI (stream or
	// non-stream) completion. Extracted because the identical block appeared
	// verbatim in both the text and image callbacks.
	sendDenyResponse := func(marshalledDenyMessage string) {
		if config.ProtocolOriginal {
			proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
		} else if gjson.GetBytes(body, "stream").Bool() {
			randomID := utils.GenerateRandomChatID()
			jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
			proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
		} else {
			randomID := utils.GenerateRandomChatID()
			jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
			proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
		}
	}
	callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
		log.Info(string(responseBody))
		// Moderation-service failure: fail open and let the request through.
		if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
			proxywasm.ResumeHttpRequest()
			return
		}
		var response cfg.Response
		err := json.Unmarshal(responseBody, &response)
		if err != nil {
			log.Errorf("%+v", err)
			proxywasm.ResumeHttpRequest()
			return
		}
		if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
			if contentIndex >= len(content) {
				// All text pieces passed: record audit attributes, then move
				// on to image checks or resume the request.
				endTime := time.Now().UnixMilli()
				ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
				ctx.SetUserAttribute("safecheck_status", "request pass")
				ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
				if len(images) > 0 && config.CheckRequestImage {
					singleCallForImage()
				} else {
					proxywasm.ResumeHttpRequest()
				}
			} else {
				singleCall()
			}
			return
		}
		marshalledDenyMessage := wrapper.MarshalStr(buildDenyMessage(response))
		sendDenyResponse(marshalledDenyMessage)
		ctx.DontReadResponseBody()
		config.IncrementCounter("ai_sec_request_deny", 1)
		endTime := time.Now().UnixMilli()
		ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
		ctx.SetUserAttribute("safecheck_status", "request deny") // fixed typo: was "reqeust deny"
		// Guard on Result's length (the slice actually indexed); the original
		// checked Advice != nil, which could panic when Result is empty.
		if len(response.Data.Result) > 0 {
			ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
			ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
		}
		ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
	}
	singleCall = func() {
		// Submit the next LengthLimit-sized slice of the text.
		var nextContentIndex int
		if contentIndex+cfg.LengthLimit >= len(content) {
			nextContentIndex = len(content)
		} else {
			nextContentIndex = contentIndex + cfg.LengthLimit
		}
		contentPiece := content[contentIndex:nextContentIndex]
		contentIndex = nextContentIndex
		log.Debugf("current content piece: %s", contentPiece)
		path, headers, body := common.GenerateRequestForText(config, cfg.MultiModalGuard, checkService, contentPiece, sessionID)
		err := config.Client.Post(path, headers, body, callback, config.Timeout)
		if err != nil {
			// Could not reach the moderation service: fail open.
			log.Errorf("failed call the safe check service: %v", err)
			proxywasm.ResumeHttpRequest()
		}
	}
	callbackForImage := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
		imageIndex += 1
		log.Info(string(responseBody))
		// Moderation-service failure: fail open, continuing with any
		// remaining images before resuming the request.
		if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
			if imageIndex < len(images) {
				singleCallForImage()
			} else {
				proxywasm.ResumeHttpRequest()
			}
			return
		}
		var response cfg.Response
		err := json.Unmarshal(responseBody, &response)
		if err != nil {
			log.Errorf("%+v", err)
			if imageIndex < len(images) {
				singleCallForImage()
			} else {
				proxywasm.ResumeHttpRequest()
			}
			return
		}
		endTime := time.Now().UnixMilli()
		if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
			if imageIndex >= len(images) {
				// All images passed: record audit attributes and resume.
				ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
				ctx.SetUserAttribute("safecheck_status", "request pass")
				ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
				proxywasm.ResumeHttpRequest()
			} else {
				singleCallForImage()
			}
			return
		}
		marshalledDenyMessage := wrapper.MarshalStr(buildDenyMessage(response))
		sendDenyResponse(marshalledDenyMessage)
		ctx.DontReadResponseBody()
		config.IncrementCounter("ai_sec_request_deny", 1)
		ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
		ctx.SetUserAttribute("safecheck_status", "request deny") // fixed typo: was "reqeust deny"
		// Guard on Result's length (the slice actually indexed); the original
		// checked Advice != nil, which could panic when Result is empty.
		if len(response.Data.Result) > 0 {
			ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
			ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
		}
		ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
	}
	singleCallForImage = func() {
		// Submit the next image, as either a URL or a base64 payload.
		img := images[imageIndex]
		imgUrl := ""
		imgBase64 := ""
		if img.Type == "BASE64" {
			imgBase64 = img.Content
		} else {
			imgUrl = img.Content
		}
		path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, imgBase64)
		err := config.Client.Post(path, headers, body, callbackForImage, config.Timeout)
		if err != nil {
			// Could not reach the moderation service: fail open.
			log.Errorf("failed call the safe check service: %v", err)
			proxywasm.ResumeHttpRequest()
		}
	}
	// Start with the text when any exists; otherwise go straight to images.
	if len(content) > 0 {
		singleCall()
	} else {
		singleCallForImage()
	}
	return types.ActionPause
}

View File

@@ -0,0 +1,48 @@
package text_moderation_plus
import (
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
common_text "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/text_moderation_plus/text"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
// OnHttpRequestHeaders performs no header-phase processing for this guard
// variant; all request inspection happens in the body phase.
func OnHttpRequestHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
	return types.ActionContinue
}
// OnHttpRequestBody delegates request-body moderation to the
// TextModerationPlus text handler.
func OnHttpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
	return text.HandleTextGenerationRequestBody(ctx, config, body)
}
// OnHttpResponseHeaders routes response-header handling by API type.
// Only text generation is supported; anything else is logged and passed
// through unchanged.
func OnHttpResponseHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
	if config.ApiType == cfg.ApiTextGeneration {
		return common_text.HandleTextGenerationResponseHeader(ctx, config)
	}
	log.Errorf("text_moderation_plus don't support api: %s", config.ApiType)
	return types.ActionContinue
}
// OnHttpStreamingResponseBody routes streamed response chunks by API type.
// Only text generation is supported; other API types are logged and the
// chunk is returned unmodified.
func OnHttpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte {
	if config.ApiType == cfg.ApiTextGeneration {
		return common_text.HandleTextGenerationStreamingResponseBody(ctx, config, data, endOfStream)
	}
	log.Errorf("text_moderation_plus don't support api: %s", config.ApiType)
	return data
}
// OnHttpResponseBody routes buffered response-body handling by API type.
// Only text generation is supported; other API types are logged and allowed
// to continue untouched.
func OnHttpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
	if config.ApiType == cfg.ApiTextGeneration {
		return common_text.HandleTextGenerationResponseBody(ctx, config, body)
	}
	log.Errorf("text_moderation_plus don't support api: %s", config.ApiType)
	return types.ActionContinue
}

View File

@@ -0,0 +1,104 @@
package text
import (
"encoding/json"
"fmt"
"net/http"
"time"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
// HandleTextGenerationRequestBody moderates the text content of an incoming
// request using the TextModerationPlus service. The content is split into
// cfg.LengthLimit-sized pieces that are checked sequentially; the request is
// resumed only after every piece passes, and a deny response is sent as soon
// as any piece is judged unacceptable.
//
// Returns types.ActionContinue when there is no text to check, otherwise
// types.ActionPause while the asynchronous checks are in flight.
func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
	consumer, _ := ctx.GetContext("consumer").(string)
	startTime := time.Now().UnixMilli()
	content := gjson.GetBytes(body, config.RequestContentJsonPath).String()
	log.Debugf("Raw request content is: %s", content)
	if len(content) == 0 {
		log.Info("request content is empty. skip")
		return types.ActionContinue
	}
	// contentIndex walks the content string in byte offsets.
	contentIndex := 0
	// sessionID correlates the per-piece checks on the moderation side.
	sessionID, _ := utils.GenerateHexID(20)
	var singleCall func()
	callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
		log.Info(string(responseBody))
		// Moderation-service failure: fail open and let the request through.
		if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
			proxywasm.ResumeHttpRequest()
			return
		}
		var response cfg.Response
		err := json.Unmarshal(responseBody, &response)
		if err != nil {
			log.Error("failed to unmarshal aliyun content security response at request phase")
			proxywasm.ResumeHttpRequest()
			return
		}
		if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
			if contentIndex >= len(content) {
				// All pieces passed: record audit attributes and resume.
				endTime := time.Now().UnixMilli()
				ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
				ctx.SetUserAttribute("safecheck_status", "request pass")
				ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
				proxywasm.ResumeHttpRequest()
			} else {
				singleCall()
			}
			return
		}
		// Pick the deny text: explicit config first, then the service's
		// advice, falling back to the default.
		denyMessage := cfg.DefaultDenyMessage
		if config.DenyMessage != "" {
			denyMessage = config.DenyMessage
		} else if len(response.Data.Advice) > 0 && response.Data.Advice[0].Answer != "" {
			// Guard on length, not nil: a non-nil empty Advice slice would
			// panic on Advice[0].
			denyMessage = response.Data.Advice[0].Answer
		}
		marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
		if config.ProtocolOriginal {
			proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
		} else if gjson.GetBytes(body, "stream").Bool() {
			randomID := utils.GenerateRandomChatID()
			jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
			proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
		} else {
			randomID := utils.GenerateRandomChatID()
			jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
			proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
		}
		ctx.DontReadResponseBody()
		config.IncrementCounter("ai_sec_request_deny", 1)
		endTime := time.Now().UnixMilli()
		ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
		ctx.SetUserAttribute("safecheck_status", "request deny") // fixed typo: was "reqeust deny"
		// Guard on Result's length (the slice actually indexed); the original
		// checked Advice != nil, which could panic when Result is empty.
		if len(response.Data.Result) > 0 {
			ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
			ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
		}
		ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
	}
	singleCall = func() {
		// Submit the next LengthLimit-sized slice of the content.
		var nextContentIndex int
		if contentIndex+cfg.LengthLimit >= len(content) {
			nextContentIndex = len(content)
		} else {
			nextContentIndex = contentIndex + cfg.LengthLimit
		}
		contentPiece := content[contentIndex:nextContentIndex]
		contentIndex = nextContentIndex
		checkService := config.GetRequestCheckService(consumer)
		path, headers, body := common.GenerateRequestForText(config, cfg.TextModerationPlus, checkService, contentPiece, sessionID)
		err := config.Client.Post(path, headers, body, callback, config.Timeout)
		if err != nil {
			// Could not reach the moderation service: fail open.
			log.Errorf("failed call the safe check service: %v", err)
			proxywasm.ResumeHttpRequest()
		}
	}
	singleCall()
	return types.ActionPause
}

View File

File diff suppressed because it is too large Load Diff

View File

@@ -18,6 +18,8 @@ import (
"encoding/json"
"testing"
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
@@ -143,16 +145,16 @@ func TestParseConfig(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, config)
securityConfig := config.(*AISecurityConfig)
require.Equal(t, "test-ak", securityConfig.ak)
require.Equal(t, "test-sk", securityConfig.sk)
require.Equal(t, true, securityConfig.checkRequest)
require.Equal(t, true, securityConfig.checkResponse)
require.Equal(t, "high", securityConfig.contentModerationLevelBar)
require.Equal(t, "high", securityConfig.promptAttackLevelBar)
require.Equal(t, "S3", securityConfig.sensitiveDataLevelBar)
require.Equal(t, uint32(2000), securityConfig.timeout)
require.Equal(t, 1000, securityConfig.bufferLimit)
securityConfig := config.(*cfg.AISecurityConfig)
require.Equal(t, "test-ak", securityConfig.AK)
require.Equal(t, "test-sk", securityConfig.SK)
require.Equal(t, true, securityConfig.CheckRequest)
require.Equal(t, true, securityConfig.CheckResponse)
require.Equal(t, "high", securityConfig.ContentModerationLevelBar)
require.Equal(t, "high", securityConfig.PromptAttackLevelBar)
require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar)
require.Equal(t, uint32(2000), securityConfig.Timeout)
require.Equal(t, 1000, securityConfig.BufferLimit)
})
// 测试仅检查请求的配置
@@ -164,12 +166,12 @@ func TestParseConfig(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, config)
securityConfig := config.(*AISecurityConfig)
require.Equal(t, true, securityConfig.checkRequest)
require.Equal(t, false, securityConfig.checkResponse)
require.Equal(t, "high", securityConfig.contentModerationLevelBar)
require.Equal(t, "high", securityConfig.promptAttackLevelBar)
require.Equal(t, "S3", securityConfig.sensitiveDataLevelBar)
securityConfig := config.(*cfg.AISecurityConfig)
require.Equal(t, true, securityConfig.CheckRequest)
require.Equal(t, false, securityConfig.CheckResponse)
require.Equal(t, "high", securityConfig.ContentModerationLevelBar)
require.Equal(t, "high", securityConfig.PromptAttackLevelBar)
require.Equal(t, "S3", securityConfig.SensitiveDataLevelBar)
})
// 测试缺少必需字段的配置
@@ -202,13 +204,13 @@ func TestParseConfig(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, config)
securityConfig := config.(*AISecurityConfig)
require.Equal(t, "llm_query_moderation", securityConfig.getRequestCheckService("aaaa"))
require.Equal(t, "llm_query_moderation_1", securityConfig.getRequestCheckService("aaa"))
require.Equal(t, "llm_response_moderation", securityConfig.getResponseCheckService("bb"))
require.Equal(t, "llm_response_moderation_1", securityConfig.getResponseCheckService("bbb-prefix-test"))
require.Equal(t, "high", securityConfig.getMaliciousUrlLevelBar("cc"))
require.Equal(t, "low", securityConfig.getMaliciousUrlLevelBar("ccc-regexp-test"))
securityConfig := config.(*cfg.AISecurityConfig)
require.Equal(t, "llm_query_moderation", securityConfig.GetRequestCheckService("aaaa"))
require.Equal(t, "llm_query_moderation_1", securityConfig.GetRequestCheckService("aaa"))
require.Equal(t, "llm_response_moderation", securityConfig.GetResponseCheckService("bb"))
require.Equal(t, "llm_response_moderation_1", securityConfig.GetResponseCheckService("bbb-prefix-test"))
require.Equal(t, "high", securityConfig.GetMaliciousUrlLevelBar("cc"))
require.Equal(t, "low", securityConfig.GetMaliciousUrlLevelBar("ccc-regexp-test"))
})
})
}
@@ -382,65 +384,100 @@ func TestOnHttpResponseHeaders(t *testing.T) {
})
}
func TestOnHttpResponseBody(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// 测试响应体安全检查通过
t.Run("response body security check pass", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 先设置请求头
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// 设置响应头
host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "application/json"},
})
// 设置响应体
body := `{"choices": [{"message": {"role": "assistant", "content": "Hello, how can I help you?"}}]}`
action := host.CallOnHttpResponseBody([]byte(body))
// 应该返回ActionPause等待安全检查结果
require.Equal(t, types.ActionPause, action)
// 模拟安全检查服务响应(通过)
securityResponse := `{"Code": 200, "Message": "Success", "RequestId": "req-123", "Data": {"RiskLevel": "low"}}`
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(securityResponse))
action = host.GetHttpStreamAction()
require.Equal(t, types.ActionContinue, action)
host.CompleteHttp()
})
// 测试空响应内容
t.Run("empty response content", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// 先设置请求头
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// 设置响应头
host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "application/json"},
})
// 设置空内容的响应体
body := `{"choices": [{"message": {"role": "assistant", "content": ""}}]}`
action := host.CallOnHttpResponseBody([]byte(body))
// 空内容应该直接通过
require.Equal(t, types.ActionContinue, action)
})
})
}
func TestRiskLevelFunctions(t *testing.T) {
// 测试风险等级转换函数
t.Run("risk level conversion", func(t *testing.T) {
require.Equal(t, 4, levelToInt(MaxRisk))
require.Equal(t, 3, levelToInt(HighRisk))
require.Equal(t, 2, levelToInt(MediumRisk))
require.Equal(t, 1, levelToInt(LowRisk))
require.Equal(t, 0, levelToInt(NoRisk))
require.Equal(t, -1, levelToInt("invalid"))
require.Equal(t, 4, cfg.LevelToInt(cfg.MaxRisk))
require.Equal(t, 3, cfg.LevelToInt(cfg.HighRisk))
require.Equal(t, 2, cfg.LevelToInt(cfg.MediumRisk))
require.Equal(t, 1, cfg.LevelToInt(cfg.LowRisk))
require.Equal(t, 0, cfg.LevelToInt(cfg.NoRisk))
require.Equal(t, -1, cfg.LevelToInt("invalid"))
})
// 测试风险等级比较
t.Run("risk level comparison", func(t *testing.T) {
require.True(t, levelToInt(HighRisk) >= levelToInt(MediumRisk))
require.True(t, levelToInt(MediumRisk) >= levelToInt(LowRisk))
require.True(t, levelToInt(LowRisk) >= levelToInt(NoRisk))
require.False(t, levelToInt(LowRisk) >= levelToInt(HighRisk))
require.True(t, cfg.LevelToInt(cfg.HighRisk) >= cfg.LevelToInt(cfg.MediumRisk))
require.True(t, cfg.LevelToInt(cfg.MediumRisk) >= cfg.LevelToInt(cfg.LowRisk))
require.True(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.NoRisk))
require.False(t, cfg.LevelToInt(cfg.LowRisk) >= cfg.LevelToInt(cfg.HighRisk))
})
}
func TestUtilityFunctions(t *testing.T) {
// 测试URL编码函数
t.Run("url encoding", func(t *testing.T) {
original := "test+string:with=special&chars@$"
encoded := urlEncoding(original)
require.NotEqual(t, original, encoded)
require.Contains(t, encoded, "%2B") // + 应该被编码
require.Contains(t, encoded, "%3A") // : 应该被编码
require.Contains(t, encoded, "%3D") // = 应该被编码
require.Contains(t, encoded, "%26") // & 应该被编码
})
// 测试HMAC-SHA1签名函数
t.Run("hmac sha1", func(t *testing.T) {
message := "test message"
secret := "test secret"
signature := hmacSha1(message, secret)
require.NotEmpty(t, signature)
require.NotEqual(t, message, signature)
})
// 测试签名生成函数
t.Run("signature generation", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
params := map[string]string{
"key1": "value1",
"key2": "value2",
}
secret := "test-secret"
signature := getSign(params, secret)
require.NotEmpty(t, signature)
})
// 测试十六进制ID生成函数
t.Run("hex id generation", func(t *testing.T) {
id, err := generateHexID(16)
id, err := utils.GenerateHexID(16)
require.NoError(t, err)
require.Len(t, id, 16)
require.Regexp(t, "^[0-9a-f]+$", id)
@@ -448,7 +485,7 @@ func TestUtilityFunctions(t *testing.T) {
// 测试随机ID生成函数
t.Run("random id generation", func(t *testing.T) {
id := generateRandomID()
id := utils.GenerateRandomChatID()
require.NotEmpty(t, id)
require.Contains(t, id, "chatcmpl-")
require.Len(t, id, 38) // "chatcmpl-" + 29 random chars

View File

@@ -0,0 +1,43 @@
package utils
import (
"bytes"
"crypto/rand"
"encoding/hex"
mrand "math/rand"
"strings"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
// GenerateHexID returns a cryptographically random lowercase hex string of
// the requested length. Note that odd lengths are rounded down: length/2
// random bytes are drawn, so GenerateHexID(7) yields 6 hex characters.
// An error is returned only if the system entropy source fails.
func GenerateHexID(length int) (string, error) {
	// Named "raw" rather than "bytes" to avoid shadowing the imported
	// "bytes" package.
	raw := make([]byte, length/2)
	if _, err := rand.Read(raw); err != nil {
		return "", err
	}
	return hex.EncodeToString(raw), nil
}
// GenerateRandomChatID builds an OpenAI-style chat completion identifier:
// the literal prefix "chatcmpl-" followed by 29 pseudo-random alphanumeric
// characters (38 characters total). Uses math/rand, so the result is not
// suitable for security-sensitive purposes.
func GenerateRandomChatID() string {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	var sb strings.Builder
	sb.Grow(len("chatcmpl-") + 29)
	sb.WriteString("chatcmpl-")
	for i := 0; i < 29; i++ {
		sb.WriteByte(charset[mrand.Intn(len(charset))])
	}
	return sb.String()
}
// ExtractMessageFromStreamingBody concatenates the values found at jsonPath
// across every SSE event in a streaming response body. The body is first
// normalized with wrapper.UnifySSEChunk, then split on blank lines into
// individual events; the jsonPath value of each event (empty when absent)
// is appended to the result in order.
func ExtractMessageFromStreamingBody(data []byte, jsonPath string) string {
	normalized := bytes.TrimSpace(wrapper.UnifySSEChunk(data))
	var sb strings.Builder
	for _, chunk := range bytes.Split(normalized, []byte("\n\n")) {
		// Example event payload: "choices":[{"delta":{"content":"..."}}]
		sb.WriteString(gjson.GetBytes(chunk, jsonPath).String())
	}
	return sb.String()
}
// GetConsumer returns the consumer identity previously stored on the request
// context (under the "consumer" key), or "" when none was recorded.
func GetConsumer(ctx wrapper.HttpContext) string {
	return ctx.GetStringContext("consumer", "")
}

View File

@@ -5,8 +5,8 @@ go 1.24.1
toolchain go1.24.4
require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
github.com/higress-group/wasm-go v1.0.3-0.20251011083635-792cb1547bac
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
)

Some files were not shown because too many files have changed in this diff Show More