mirror of
https://github.com/alibaba/higress.git
synced 2026-02-25 21:21:01 +08:00
Compare commits
63 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
516b016584 | ||
|
|
69f481d25e | ||
|
|
87357ae9ac | ||
|
|
82b6830af3 | ||
|
|
3b0f0bb31f | ||
|
|
bae535c980 | ||
|
|
caf00bfeac | ||
|
|
ce298054f1 | ||
|
|
24c69fb0b7 | ||
|
|
a38be77b9e | ||
|
|
27999dcc59 | ||
|
|
811179a6a0 | ||
|
|
5f43dd0224 | ||
|
|
e23ab3ca7c | ||
|
|
032a69556f | ||
|
|
ee6bb11730 | ||
|
|
fc600f204a | ||
|
|
357418853f | ||
|
|
e8586cccd7 | ||
|
|
d55b9a0837 | ||
|
|
4f04ac067b | ||
|
|
c7028bd7f2 | ||
|
|
95ff52cde9 | ||
|
|
7c7205b572 | ||
|
|
f342f50ca4 | ||
|
|
659d136bfe | ||
|
|
541e5e206f | ||
|
|
387c337654 | ||
|
|
8024a96881 | ||
|
|
f71c1900a8 | ||
|
|
1199946d36 | ||
|
|
b1571de6f0 | ||
|
|
20dae295a8 | ||
|
|
9a1f9e4606 | ||
|
|
6f4ef33590 | ||
|
|
fef8ecc822 | ||
|
|
0ade9504be | ||
|
|
6311fecfce | ||
|
|
5c225de080 | ||
|
|
bf9ef5eefd | ||
|
|
26f5737a80 | ||
|
|
50c1a5e78c | ||
|
|
647304eb45 | ||
|
|
0a7fc9f412 | ||
|
|
c9253264ef | ||
|
|
8c80084ada | ||
|
|
9f5ee99c2d | ||
|
|
3770bd2f55 | ||
|
|
698a395e89 | ||
|
|
2c72767203 | ||
|
|
bb3ac59834 | ||
|
|
6c1fe57034 | ||
|
|
5c5cc6ac90 | ||
|
|
265da8e4d6 | ||
|
|
119698eea4 | ||
|
|
18d20ca135 | ||
|
|
9978db2ac6 | ||
|
|
1582fa6ef9 | ||
|
|
2b49fd5b26 | ||
|
|
48433a6549 | ||
|
|
8ec48b3b85 | ||
|
|
32007d2ab8 | ||
|
|
27b088fc7e |
@@ -31,7 +31,8 @@ jobs:
|
||||
- name: Upload to OSS
|
||||
uses: go-choppy/ossutil-github-action@master
|
||||
with:
|
||||
ossArgs: 'cp -r -u ./artifact/ oss://higress-website-cn-hongkong/standalone/'
|
||||
ossArgs: 'cp -r -u ./artifact/ oss://higress-ai/standalone/'
|
||||
accessKey: ${{ secrets.ACCESS_KEYID }}
|
||||
accessSecret: ${{ secrets.ACCESS_KEYSECRET }}
|
||||
endpoint: oss-cn-hongkong.aliyuncs.com
|
||||
|
||||
|
||||
5
.github/workflows/deploy-to-oss.yaml
vendored
5
.github/workflows/deploy-to-oss.yaml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: Download Helm Charts Index
|
||||
uses: go-choppy/ossutil-github-action@master
|
||||
with:
|
||||
ossArgs: 'cp oss://higress-website-cn-hongkong/helm-charts/index.yaml ./artifact/'
|
||||
ossArgs: 'cp oss://higress-ai/helm-charts/index.yaml ./artifact/'
|
||||
accessKey: ${{ secrets.ACCESS_KEYID }}
|
||||
accessSecret: ${{ secrets.ACCESS_KEYSECRET }}
|
||||
endpoint: oss-cn-hongkong.aliyuncs.com
|
||||
@@ -48,7 +48,8 @@ jobs:
|
||||
- name: Upload to OSS
|
||||
uses: go-choppy/ossutil-github-action@master
|
||||
with:
|
||||
ossArgs: 'cp -r -u ./artifact/ oss://higress-website-cn-hongkong/helm-charts/'
|
||||
ossArgs: 'cp -r -u ./artifact/ oss://higress-ai/helm-charts/'
|
||||
accessKey: ${{ secrets.ACCESS_KEYID }}
|
||||
accessSecret: ${{ secrets.ACCESS_KEYSECRET }}
|
||||
endpoint: oss-cn-hongkong.aliyuncs.com
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ header:
|
||||
- 'hgctl/pkg/manifests'
|
||||
- 'pkg/ingress/kube/gateway/istio/testdata'
|
||||
- 'release-notes/**'
|
||||
- '.cursor/**'
|
||||
|
||||
comment: on-failure
|
||||
dependency:
|
||||
|
||||
13
ADOPTERS.md
Normal file
13
ADOPTERS.md
Normal file
@@ -0,0 +1,13 @@
|
||||
# Adopters of Higress
|
||||
|
||||
Below are the adopters of the Higress project. If you are using Higress in your organization, please add your name to the list by submitting a pull request: this will help foster the Higress community. Kindly ensure the list remains in alphabetical order.
|
||||
|
||||
|
||||
| Organization | Contact (GitHub User Name) | Environment | Description of Use |
|
||||
|---------------------------------------|----------------------------------------|--------------------------------------------|-----------------------------------------------------------------------|
|
||||
| [antdigital](https://antdigital.com/) | [@Lovelcp](https://github.com/Lovelcp) | Production | Ingress Gateway, Microservice gateway, LLM Gateway, MCP Gateway |
|
||||
| [kuaishou](https://ir.kuaishou.com/) | [@maplecap](https://github.com/maplecap) | Production | LLM Gateway |
|
||||
| [Trip.com](https://www.trip.com/) | [@CH3CHO](https://github.com/CH3CHO) | Production | LLM Gateway, MCP Gateway |
|
||||
| [vipshop](https://github.com/vipshop/) | [@firebook](https://github.com/firebook) | Production | LLM Gateway, MCP Gateway, Inference Gateway |
|
||||
| [labring](https://github.com/labring/) | [@zzjin](https://github.com/zzjin) | Production | Ingress Gateway |
|
||||
| < company name here> | < your github handle here > | <Production/Testing/Experimenting/etc> | <Ingress Gateway/Microservice gateway/LLM Gateway/MCP Gateway/Inference Gateway> |
|
||||
@@ -1 +1 @@
|
||||
higress-console: v2.1.9
|
||||
higress-console: v2.1.11
|
||||
@@ -146,7 +146,7 @@ docker-buildx-push: clean-env docker.higress-buildx
|
||||
export PARENT_GIT_TAG:=$(shell cat VERSION)
|
||||
export PARENT_GIT_REVISION:=$(TAG)
|
||||
|
||||
export ENVOY_PACKAGE_URL_PATTERN?=https://github.com/higress-group/proxy/releases/download/v2.2.0/envoy-symbol-ARCH.tar.gz
|
||||
export ENVOY_PACKAGE_URL_PATTERN?=https://github.com/higress-group/proxy/releases/download/v2.1.11/envoy-symbol-ARCH.tar.gz
|
||||
|
||||
build-envoy: prebuild
|
||||
./tools/hack/build-envoy.sh
|
||||
|
||||
15
README.md
15
README.md
@@ -45,7 +45,7 @@ Higress was born within Alibaba to solve the issues of Tengine reload affecting
|
||||
|
||||
You can click the button below to install the enterprise version of Higress:
|
||||
|
||||
[](https://www.aliyun.com/product/apigateway?spm=higress-github.topbar.0.0.0)
|
||||
[](https://www.aliyun.com/product/api-gateway?spm=higress-github.topbar.0.0.0)
|
||||
|
||||
|
||||
If you use open-source Higress and wish to obtain enterprise-level support, you can contact the project maintainer johnlanni's email: **zty98751@alibaba-inc.com** or social media accounts (WeChat ID: **nomadao**, DingTalk ID: **chengtanzty**). Please note **Higress** when adding as a friend :)
|
||||
@@ -82,6 +82,8 @@ Port descriptions:
|
||||
>
|
||||
> If you experience a timeout when pulling image from `higress-registry.cn-hangzhou.cr.aliyuncs.com`, you can try replacing it with the following docker registry mirror source:
|
||||
>
|
||||
> **North America**: `higress-registry.us-west-1.cr.aliyuncs.com`
|
||||
>
|
||||
> **Southeast Asia**: `higress-registry.ap-southeast-7.cr.aliyuncs.com`
|
||||
|
||||
For other installation methods such as Helm deployment under K8s, please refer to the official [Quick Start documentation](https://higress.io/en-us/docs/user/quickstart).
|
||||
@@ -117,7 +119,16 @@ If you are deploying on the cloud, it is recommended to use the [Enterprise Edit
|
||||
|
||||
Higress can function as a feature-rich ingress controller, which is compatible with many annotations of K8s' nginx ingress controller.
|
||||
|
||||
[Gateway API](https://gateway-api.sigs.k8s.io/) support is coming soon and will support smooth migration from Ingress API to Gateway API.
|
||||
[Gateway API](https://gateway-api.sigs.k8s.io/) is already supported, and it supports a smooth migration from Ingress API to Gateway API.
|
||||
|
||||
Compared to ingress-nginx, the resource overhead has significantly decreased, and the speed at which route changes take effect has improved by ten times.
|
||||
|
||||
> The following resource overhead comparison comes from [sealos](https://github.com/labring).
|
||||
>
|
||||
> For details, you can read this [article](https://sealos.io/blog/sealos-envoy-vs-nginx-2000-tenants) to understand how sealos migrates the monitoring of **tens of thousands of ingress** resources from nginx ingress to higress.
|
||||
|
||||

|
||||
|
||||
|
||||
- **Microservice gateway**:
|
||||
|
||||
|
||||
Submodule envoy/envoy updated: 3fe314c698...384e5aab43
@@ -1,5 +1,5 @@
|
||||
apiVersion: v2
|
||||
appVersion: 2.1.9
|
||||
appVersion: 2.1.11
|
||||
description: Helm chart for deploying higress gateways
|
||||
icon: https://higress.io/img/higress_logo_small.png
|
||||
home: http://higress.io/
|
||||
@@ -15,4 +15,4 @@ dependencies:
|
||||
repository: "file://../redis"
|
||||
version: 0.0.1
|
||||
type: application
|
||||
version: 2.1.9
|
||||
version: 2.1.11
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
# Declare variables to be passed into your templates.
|
||||
global:
|
||||
# -- Specify the image registry and pull policy
|
||||
hub: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress
|
||||
# Will inherit from parent chart's global.hub if not set
|
||||
hub: ""
|
||||
# -- Specify image pull policy if default behavior isn't desired.
|
||||
# Default behavior: latest images will be Always else IfNotPresent.
|
||||
imagePullPolicy: ""
|
||||
|
||||
@@ -71,6 +71,11 @@ spec:
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
routeType:
|
||||
enum:
|
||||
- HTTP
|
||||
- GRPC
|
||||
type: string
|
||||
service:
|
||||
items:
|
||||
type: string
|
||||
|
||||
@@ -203,7 +203,7 @@ template:
|
||||
{{- if $o11y.enabled }}
|
||||
{{- $config := $o11y.promtail }}
|
||||
- name: promtail
|
||||
image: {{ $config.image.repository }}:{{ $config.image.tag }}
|
||||
image: {{ $config.image.repository | default (printf "%s/promtail" .Values.global.hub) }}:{{ $config.image.tag }}
|
||||
imagePullPolicy: IfNotPresent
|
||||
args:
|
||||
- -config.file=/etc/promtail/promtail.yaml
|
||||
@@ -250,6 +250,10 @@ template:
|
||||
tolerations:
|
||||
{{- toYaml . | nindent 6 }}
|
||||
{{- end }}
|
||||
{{- with .Values.gateway.topologySpreadConstraints }}
|
||||
topologySpreadConstraints:
|
||||
{{- toYaml . | nindent 6 }}
|
||||
{{- end }}
|
||||
volumes:
|
||||
- emptyDir: {}
|
||||
name: workload-socket
|
||||
|
||||
@@ -289,6 +289,10 @@ spec:
|
||||
tolerations:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.controller.topologySpreadConstraints }}
|
||||
topologySpreadConstraints:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
volumes:
|
||||
- name: log
|
||||
emptyDir: {}
|
||||
|
||||
@@ -5,6 +5,9 @@ metadata:
|
||||
namespace: {{ .Release.Namespace }}
|
||||
labels:
|
||||
{{- include "gateway.labels" . | nindent 4}}
|
||||
{{- with .Values.gateway.metrics.podMonitorSelector }}
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
annotations:
|
||||
{{- .Values.gateway.annotations | toYaml | nindent 4 }}
|
||||
spec:
|
||||
|
||||
@@ -24,9 +24,6 @@ spec:
|
||||
{{- end }}
|
||||
{{- with .Values.gateway.service.externalTrafficPolicy }}
|
||||
externalTrafficPolicy: "{{ . }}"
|
||||
{{- end }}
|
||||
{{- with .Values.gateway.service.loadBalancerClass}}
|
||||
loadBalancerClass: "{{ . }}"
|
||||
{{- end }}
|
||||
type: {{ .Values.gateway.service.type }}
|
||||
ports:
|
||||
|
||||
@@ -362,7 +362,7 @@ global:
|
||||
enabled: false
|
||||
promtail:
|
||||
image:
|
||||
repository: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/promtail
|
||||
repository: "" # Will use global.hub if not set
|
||||
tag: 2.9.4
|
||||
port: 3101
|
||||
resources:
|
||||
@@ -377,7 +377,7 @@ global:
|
||||
# The default value is "" and when caName="", the CA will be configured by other
|
||||
# mechanisms (e.g., environmental variable CA_PROVIDER).
|
||||
caName: ""
|
||||
hub: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress
|
||||
hub: "" # Will use global.hub if not set
|
||||
|
||||
clusterName: ""
|
||||
# -- meshConfig defines runtime configuration of components, including Istiod and istio-agent behavior
|
||||
@@ -433,7 +433,7 @@ gateway:
|
||||
# -- The readiness timeout seconds
|
||||
readinessTimeoutSeconds: 3
|
||||
|
||||
hub: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress
|
||||
hub: "" # Will use global.hub if not set
|
||||
tag: ""
|
||||
# -- revision declares which revision this gateway is a part of
|
||||
revision: ""
|
||||
@@ -522,12 +522,19 @@ gateway:
|
||||
|
||||
affinity: {}
|
||||
|
||||
topologySpreadConstraints: []
|
||||
|
||||
# -- If specified, the gateway will act as a network gateway for the given network.
|
||||
networkGateway: ""
|
||||
|
||||
metrics:
|
||||
# -- If true, create PodMonitor or VMPodScrape for gateway
|
||||
enabled: false
|
||||
# -- Selector for PodMonitor
|
||||
# When using monitoring.coreos.com/v1.PodMonitor, the selector must match
|
||||
# the label "release: kube-prome" is the default for kube-prometheus-stack
|
||||
podMonitorSelector:
|
||||
release: kube-prome
|
||||
# -- provider group name for CustomResourceDefinition, can be monitoring.coreos.com or operator.victoriametrics.com
|
||||
provider: monitoring.coreos.com
|
||||
interval: ""
|
||||
@@ -548,7 +555,7 @@ controller:
|
||||
replicas: 1
|
||||
image: higress
|
||||
|
||||
hub: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress
|
||||
hub: "" # Will use global.hub if not set
|
||||
tag: ""
|
||||
env: {}
|
||||
|
||||
@@ -624,6 +631,8 @@ controller:
|
||||
|
||||
affinity: {}
|
||||
|
||||
topologySpreadConstraints: []
|
||||
|
||||
autoscaling:
|
||||
enabled: false
|
||||
minReplicas: 1
|
||||
@@ -642,7 +651,7 @@ pilot:
|
||||
rollingMaxSurge: 100%
|
||||
rollingMaxUnavailable: 25%
|
||||
|
||||
hub: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress
|
||||
hub: "" # Will use global.hub if not set
|
||||
tag: ""
|
||||
|
||||
# -- Can be a full hub/image:tag
|
||||
@@ -795,7 +804,7 @@ pluginServer:
|
||||
replicas: 2
|
||||
image: plugin-server
|
||||
|
||||
hub: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress
|
||||
hub: "" # Will use global.hub if not set
|
||||
tag: ""
|
||||
|
||||
imagePullSecrets: []
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
dependencies:
|
||||
- name: higress-core
|
||||
repository: file://../core
|
||||
version: 2.1.9
|
||||
version: 2.1.11
|
||||
- name: higress-console
|
||||
repository: https://higress.io/helm-charts/
|
||||
version: 2.1.9
|
||||
digest: sha256:d696af6726b40219cc16e7cf8de7400101479dfbd8deb3101d7ee736415b9875
|
||||
generated: "2025-11-13T16:33:49.721553+08:00"
|
||||
version: 2.1.11
|
||||
digest: sha256:09058429db5bef8f2d3e5820f3f84457b3dad34f4638878018cd22623fa38f92
|
||||
generated: "2026-02-20T23:47:51.258092+08:00"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
apiVersion: v2
|
||||
appVersion: 2.1.9
|
||||
appVersion: 2.1.11
|
||||
description: Helm chart for deploying Higress gateways
|
||||
icon: https://higress.io/img/higress_logo_small.png
|
||||
home: http://higress.io/
|
||||
@@ -12,9 +12,9 @@ sources:
|
||||
dependencies:
|
||||
- name: higress-core
|
||||
repository: "file://../core"
|
||||
version: 2.1.9
|
||||
version: 2.1.11
|
||||
- name: higress-console
|
||||
repository: "https://higress.io/helm-charts/"
|
||||
version: 2.1.9
|
||||
version: 2.1.11
|
||||
type: application
|
||||
version: 2.1.9
|
||||
version: 2.1.11
|
||||
|
||||
@@ -44,7 +44,7 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| controller.autoscaling.minReplicas | int | `1` | |
|
||||
| controller.autoscaling.targetCPUUtilizationPercentage | int | `80` | |
|
||||
| controller.env | object | `{}` | |
|
||||
| controller.hub | string | `"higress-registry.cn-hangzhou.cr.aliyuncs.com/higress"` | |
|
||||
| controller.hub | string | `""` | |
|
||||
| controller.image | string | `"higress"` | |
|
||||
| controller.imagePullSecrets | list | `[]` | |
|
||||
| controller.labels | object | `{}` | |
|
||||
@@ -83,6 +83,7 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| controller.serviceAccount.name | string | `""` | If not set and create is true, a name is generated using the fullname template |
|
||||
| controller.tag | string | `""` | |
|
||||
| controller.tolerations | list | `[]` | |
|
||||
| controller.topologySpreadConstraints | list | `[]` | |
|
||||
| downstream | object | `{"connectionBufferLimits":32768,"http2":{"initialConnectionWindowSize":1048576,"initialStreamWindowSize":65535,"maxConcurrentStreams":100},"idleTimeout":180,"maxRequestHeadersKb":60,"routeTimeout":0}` | Downstream config settings |
|
||||
| gateway.affinity | object | `{}` | |
|
||||
| gateway.annotations | object | `{}` | Annotations to apply to all resources |
|
||||
@@ -95,7 +96,7 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| gateway.hostNetwork | bool | `false` | |
|
||||
| gateway.httpPort | int | `80` | |
|
||||
| gateway.httpsPort | int | `443` | |
|
||||
| gateway.hub | string | `"higress-registry.cn-hangzhou.cr.aliyuncs.com/higress"` | |
|
||||
| gateway.hub | string | `""` | |
|
||||
| gateway.image | string | `"gateway"` | |
|
||||
| gateway.kind | string | `"Deployment"` | Use a `DaemonSet` or `Deployment` |
|
||||
| gateway.labels | object | `{}` | Labels to apply to all resources |
|
||||
@@ -104,6 +105,7 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| gateway.metrics.interval | string | `""` | |
|
||||
| gateway.metrics.metricRelabelConfigs | list | `[]` | for operator.victoriametrics.com/v1beta1.VMPodScrape |
|
||||
| gateway.metrics.metricRelabelings | list | `[]` | for monitoring.coreos.com/v1.PodMonitor |
|
||||
| gateway.metrics.podMonitorSelector | object | `{"release":"kube-prome"}` | Selector for PodMonitor When using monitoring.coreos.com/v1.PodMonitor, the selector must match the label "release: kube-prome" is the default for kube-prometheus-stack |
|
||||
| gateway.metrics.provider | string | `"monitoring.coreos.com"` | provider group name for CustomResourceDefinition, can be monitoring.coreos.com or operator.victoriametrics.com |
|
||||
| gateway.metrics.rawSpec | object | `{}` | some more raw podMetricsEndpoints spec |
|
||||
| gateway.metrics.relabelConfigs | list | `[]` | |
|
||||
@@ -151,6 +153,7 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| gateway.serviceAccount.name | string | `""` | The name of the service account to use. If not set, the release name is used |
|
||||
| gateway.tag | string | `""` | |
|
||||
| gateway.tolerations | list | `[]` | |
|
||||
| gateway.topologySpreadConstraints | list | `[]` | |
|
||||
| gateway.unprivilegedPortSupported | string | `nil` | |
|
||||
| global.autoscalingv2API | bool | `true` | whether to use autoscaling/v2 template for HPA settings for internal usage only, not to be configured by users. |
|
||||
| global.caAddress | string | `""` | The customized CA address to retrieve certificates for the pods in the cluster. CSR clients such as the Istio Agent and ingress gateways can use this to specify the CA endpoint. If not set explicitly, default to the Istio discovery address. |
|
||||
@@ -191,7 +194,7 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| global.multiCluster.clusterName | string | `""` | Should be set to the name of the cluster this installation will run in. This is required for sidecar injection to properly label proxies |
|
||||
| global.multiCluster.enabled | bool | `true` | Set to true to connect two kubernetes clusters via their respective ingressgateway services when pods in each cluster cannot directly talk to one another. All clusters should be using Istio mTLS and must have a shared root CA for this model to work. |
|
||||
| global.network | string | `""` | Network defines the network this cluster belong to. This name corresponds to the networks in the map of mesh networks. |
|
||||
| global.o11y | object | `{"enabled":false,"promtail":{"image":{"repository":"higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/promtail","tag":"2.9.4"},"port":3101,"resources":{"limits":{"cpu":"500m","memory":"2Gi"}},"securityContext":{}}}` | Observability (o11y) configurations |
|
||||
| global.o11y | object | `{"enabled":false,"promtail":{"image":{"repository":"","tag":"2.9.4"},"port":3101,"resources":{"limits":{"cpu":"500m","memory":"2Gi"}},"securityContext":{}}}` | Observability (o11y) configurations |
|
||||
| global.omitSidecarInjectorConfigMap | bool | `false` | |
|
||||
| global.onDemandRDS | bool | `false` | |
|
||||
| global.oneNamespace | bool | `false` | Whether to restrict the applications namespace the controller manages; If not set, controller watches all namespaces |
|
||||
@@ -243,7 +246,7 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| global.watchNamespace | string | `""` | If not empty, Higress Controller will only watch resources in the specified namespace. When isolating different business systems using K8s namespace, if each namespace requires a standalone gateway instance, this parameter can be used to confine the Ingress watching of Higress within the given namespace. |
|
||||
| global.xdsMaxRecvMsgSize | string | `"104857600"` | |
|
||||
| gzip | object | `{"chunkSize":4096,"compressionLevel":"BEST_COMPRESSION","compressionStrategy":"DEFAULT_STRATEGY","contentType":["text/html","text/css","text/plain","text/xml","application/json","application/javascript","application/xhtml+xml","image/svg+xml"],"disableOnEtagHeader":true,"enable":true,"memoryLevel":5,"minContentLength":1024,"windowBits":12}` | Gzip compression settings |
|
||||
| hub | string | `"higress-registry.cn-hangzhou.cr.aliyuncs.com/higress"` | |
|
||||
| hub | string | `""` | |
|
||||
| meshConfig | object | `{"enablePrometheusMerge":true,"rootNamespace":null,"trustDomain":"cluster.local"}` | meshConfig defines runtime configuration of components, including Istiod and istio-agent behavior See https://istio.io/docs/reference/config/istio.mesh.v1alpha1/ for all available options |
|
||||
| meshConfig.rootNamespace | string | `nil` | The namespace to treat as the administrative root namespace for Istio configuration. When processing a leaf namespace Istio will search for declarations in that namespace first and if none are found it will search in the root namespace. Any matching declaration found in the root namespace is processed as if it were declared in the leaf namespace. |
|
||||
| meshConfig.trustDomain | string | `"cluster.local"` | The trust domain corresponds to the trust root of a system Refer to https://github.com/spiffe/spiffe/blob/master/standards/SPIFFE-ID.md#21-trust-domain |
|
||||
@@ -260,7 +263,7 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| pilot.env.PILOT_ENABLE_METADATA_EXCHANGE | string | `"false"` | |
|
||||
| pilot.env.PILOT_SCOPE_GATEWAY_TO_NAMESPACE | string | `"false"` | |
|
||||
| pilot.env.VALIDATION_ENABLED | string | `"false"` | |
|
||||
| pilot.hub | string | `"higress-registry.cn-hangzhou.cr.aliyuncs.com/higress"` | |
|
||||
| pilot.hub | string | `""` | |
|
||||
| pilot.image | string | `"pilot"` | Can be a full hub/image:tag |
|
||||
| pilot.jwksResolverExtraRootCA | string | `""` | You can use jwksResolverExtraRootCA to provide a root certificate in PEM format. This will then be trusted by pilot when resolving JWKS URIs. |
|
||||
| pilot.keepaliveMaxServerConnectionAge | string | `"30m"` | The following is used to limit how long a sidecar can be connected to a pilot. It balances out load across pilot instances at the cost of increasing system churn. |
|
||||
@@ -275,7 +278,7 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| pilot.serviceAnnotations | object | `{}` | |
|
||||
| pilot.tag | string | `""` | |
|
||||
| pilot.traceSampling | float | `1` | |
|
||||
| pluginServer.hub | string | `"higress-registry.cn-hangzhou.cr.aliyuncs.com/higress"` | |
|
||||
| pluginServer.hub | string | `""` | |
|
||||
| pluginServer.image | string | `"plugin-server"` | |
|
||||
| pluginServer.imagePullSecrets | list | `[]` | |
|
||||
| pluginServer.labels | object | `{}` | |
|
||||
|
||||
@@ -112,6 +112,7 @@ helm delete higress -n higress-system
|
||||
| gateway.metrics.rawSpec | object | `{}` | 额外的度量规范 |
|
||||
| gateway.metrics.relabelConfigs | list | `[]` | 重新标签配置 |
|
||||
| gateway.metrics.relabelings | list | `[]` | 重新标签项 |
|
||||
| gateway.metrics.podMonitorSelector | object | `{"release":"kube-prometheus-stack"}` | PodMonitor 选择器,当使用 prometheus stack 的podmonitor自动发现时,选择器必须匹配标签 "release: kube-prome",这是 kube-prometheus-stack 的默认设置 |
|
||||
| gateway.metrics.scrapeTimeout | string | `""` | 抓取的超时时间 |
|
||||
| gateway.name | string | `"higress-gateway"` | 网关名称 |
|
||||
| gateway.networkGateway | string | `""` | 网络网关指定 |
|
||||
|
||||
Submodule istio/istio updated: 3a661d92b0...1a18c6d5b6
@@ -53,7 +53,6 @@ require (
|
||||
github.com/cockroachdb/errors v1.9.1 // indirect
|
||||
github.com/cockroachdb/logtags v0.0.0-20211118104740-dabe8e521a4f // indirect
|
||||
github.com/cockroachdb/redact v1.1.3 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/deckarep/golang-set v1.7.1 // indirect
|
||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||
github.com/getsentry/sentry-go v0.12.0 // indirect
|
||||
|
||||
@@ -185,10 +185,7 @@ github.com/getsentry/sentry-go v0.12.0/go.mod h1:NSap0JBYWzHND8oMbyi0+XZhUalc1TB
|
||||
github.com/gin-contrib/sse v0.0.0-20190301062529-5545eab6dad3/go.mod h1:VJ0WA2NBN22VlZ2dKZQPAPnyWw5XTlK1KymzLKsr59s=
|
||||
github.com/gin-gonic/gin v1.4.0/go.mod h1:OW2EZn3DO8Ln9oIKOvM++LBO+5UPHJJDH72/q/3rZdM=
|
||||
github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98=
|
||||
github.com/go-errors/errors v1.0.1 h1:LUHzmkK3GUKUrL/1gfBUxAHzcev3apQlezX/+O7ma6w=
|
||||
github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q=
|
||||
github.com/go-faker/faker/v4 v4.1.0 h1:ffuWmpDrducIUOO0QSKSF5Q2dxAht+dhsT9FvVHhPEI=
|
||||
github.com/go-faker/faker/v4 v4.1.0/go.mod h1:uuNc0PSRxF8nMgjGrrrU4Nw5cF30Jc6Kd0/FUTTYbhg=
|
||||
github.com/go-faster/city v1.0.1 h1:4WAxSZ3V2Ws4QRDrscLEDcibJY8uf41H6AhXDrNDcGw=
|
||||
github.com/go-faster/city v1.0.1/go.mod h1:jKcUJId49qdW3L1qKHH/3wPeUstCVpVSXTM6vO3VcTw=
|
||||
github.com/go-faster/errors v0.7.1 h1:MkJTnDoEdi9pDabt1dpWf7AA8/BaSYZqibYyhZ20AYg=
|
||||
@@ -429,7 +426,6 @@ github.com/paulmach/protoscan v0.2.1/go.mod h1:SpcSwydNLrxUGSDvXvO0P7g7AuhJ7lcKf
|
||||
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
|
||||
github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ=
|
||||
github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
|
||||
github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/higress/higress-api"
|
||||
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/higress/higress-ops"
|
||||
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag"
|
||||
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/tool-search"
|
||||
mcp_session "github.com/alibaba/higress/plugins/golang-filter/mcp-session"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
xds "github.com/cncf/xds/go/xds/type/v3"
|
||||
|
||||
144
plugins/golang-filter/mcp-server/servers/tool-search/README.md
Normal file
144
plugins/golang-filter/mcp-server/servers/tool-search/README.md
Normal file
@@ -0,0 +1,144 @@
|
||||
# Tool Search MCP Server
|
||||
|
||||
这是一个基于 Higress Golang Filter 实现的 MCP Server,用于提供工具语义搜索功能。当前实现**仅支持向量语义搜索**(基于 Milvus 向量数据库),**不包含全文检索或混合搜索**。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- **向量语义搜索**:使用 OpenAI 兼容的 Embedding API 将用户查询转换为向量,并在 Milvus 中进行相似度检索
|
||||
- **工具元数据支持**:从数据库中读取完整的工具定义(JSON 格式),并动态拼接工具名称
|
||||
- **全量工具列表**:支持获取数据库中所有可用工具
|
||||
- **可配置 Embedding 模型**:支持自定义模型、维度及 API 端点(如 DashScope)
|
||||
- **Milvus 集成**:通过标准 gRPC 接口连接 Milvus 向量数据库
|
||||
|
||||
## 数据库要求(Milvus)
|
||||
|
||||
本服务依赖 **Milvus 向量数据库**,需预先创建集合(Collection),其 Schema 应包含以下字段:
|
||||
|
||||
| 字段名 | 类型 | 说明 |
|
||||
|--------------|-------------------|-------------------------|
|
||||
| `id` | VarChar(64) | 文档唯一 ID |
|
||||
| `content` | VarChar(64) | 工具描述文本 |
|
||||
| `metadata` | JSON | 完整的工具定义(必须包含 `name` 字段) |
|
||||
| `vector` | FloatVector(1024) | embedding 向量 |
|
||||
| `metadata` | Int64 | 创建时间 |
|
||||
|
||||
|
||||
## 配置参数
|
||||
|
||||
### 根级配置
|
||||
|
||||
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||
|--------------|--------|------|-----------------------------------------------------|------|
|
||||
| `vector` | object | 是 | - | 向量数据库配置(见下文) |
|
||||
| `embedding` | object | 是 | - | Embedding API 配置(见下文) |
|
||||
| `description`| string | 否 | `"Tool search server for semantic similarity search"` | MCP Server 描述信息 |
|
||||
|
||||
### Vector 配置(`vector` 对象)
|
||||
|
||||
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||
|-------------|--------|------|--------------------|------|
|
||||
| `type` | string | 是 | - | **必须为 `"milvus"`** |
|
||||
| `host` | string | 是 | - | Milvus 服务地址(如 `localhost`) |
|
||||
| `port` | int | 是 | - | Milvus gRPC 端口(如 `19530`) |
|
||||
| `database` | string | 否 | `"default"` | Milvus 数据库名 |
|
||||
| `tableName` | string | 否 | `"apig_mcp_tools"` | Milvus 集合名 |
|
||||
| `username` | string | 否 | - | 认证用户名(可选) |
|
||||
| `password` | string | 否 | - | 认证密码(可选) |
|
||||
|
||||
### Embedding 配置(`embedding` 对象)
|
||||
|
||||
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||
|--------------|--------|------|-----------------------------------------------------------|------|
|
||||
| `apiKey` | string | 是 | - | Embedding 服务的 API Key |
|
||||
| `baseURL` | string | 否 | `https://dashscope.aliyuncs.com/compatible-mode/v1` | OpenAI 兼容 API 的 Base URL |
|
||||
| `model` | string | 否 | `text-embedding-v4` | 使用的 Embedding 模型 |
|
||||
| `dimensions` | int | 否 | `1024` | 向量维度 |
|
||||
|
||||
## 配置示例
|
||||
|
||||
|
||||
Tool Search MCP Server 也可以作为 Higress 的一个模块进行配置。以下是一个在 Higress ConfigMap 中配置 Tool Search 的示例:
|
||||
|
||||
```yaml
|
||||
apiVersion: v1
|
||||
kind: ConfigMap
|
||||
metadata:
|
||||
name: higress-config
|
||||
namespace: higress-system
|
||||
data:
|
||||
higress: |
|
||||
mcpServer:
|
||||
enable: true
|
||||
sse_path_suffix: "/sse"
|
||||
redis:
|
||||
address: "<Redis IP>:6379"
|
||||
username: ""
|
||||
password: ""
|
||||
db: 0
|
||||
match_list:
|
||||
- path_rewrite_prefix: ""
|
||||
upstream_type: ""
|
||||
enable_path_rewrite: false
|
||||
match_rule_domain: "*"
|
||||
match_rule_path: "/mcp-servers/tool-search"
|
||||
match_rule_type: "prefix"
|
||||
servers:
|
||||
- path: "/mcp-servers/tool-search"
|
||||
name: "tool-search"
|
||||
type: "tool-search"
|
||||
config:
|
||||
vector:
|
||||
type: "milvus"
|
||||
host: "localhost"
|
||||
port: 19530
|
||||
database: "default"
|
||||
tableName: "apig_mcp_tools"
|
||||
username: "root"
|
||||
password: "Milvus"
|
||||
maxTools: 1000
|
||||
embedding:
|
||||
apiKey: "your-dashscope-api-key"
|
||||
baseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
model: "text-embedding-v4"
|
||||
dimensions: 1024
|
||||
description: "Higress 工具语义搜索服务"
|
||||
```
|
||||
|
||||
## 工具搜索接口
|
||||
|
||||
Tool Search MCP Server 提供以下 MCP 工具:
|
||||
|
||||
### x_higress_tool_search
|
||||
|
||||
基于语义相似度搜索最相关的工具。
|
||||
|
||||
**输入参数**:
|
||||
|
||||
| 参数名 | 类型 | 必填 | 说明 |
|
||||
|---------|--------|------|------|
|
||||
| `query` | string | 是 | 查询语句,用于与工具描述进行语义相似度比较 |
|
||||
| `topK` | int | 否 | 指定需要选择的工具数量,默认选择前10个工具 |
|
||||
|
||||
**输出格式**:
|
||||
|
||||
```
|
||||
{
|
||||
"tools": [
|
||||
{
|
||||
"name": "server_name___tool_name",
|
||||
"title": "Tool Title",
|
||||
"description": "Tool description",
|
||||
"inputSchema": {...},
|
||||
"outputSchema": {...}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## 搜索实现
|
||||
|
||||
通过向量相似度进行搜索,索引配置如下
|
||||
- 使用 HNSW 索引算法进行向量索引
|
||||
- 默认参数:M=8, efConstruction=64
|
||||
- 相似度度量方式:内积(IP)
|
||||
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"vector": {
|
||||
"type": "milvus",
|
||||
"vectorWeight": 0.5,
|
||||
"tableName": "apig_mcp_tools",
|
||||
"host": "localhost",
|
||||
"port": 19530,
|
||||
"database": "default",
|
||||
"username": "root",
|
||||
"password": "Milvus"
|
||||
},
|
||||
"embedding": {
|
||||
"apiKey": "your-dashscope-api-key",
|
||||
"baseURL": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"model": "text-embedding-v4",
|
||||
"dimensions": 1024
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package tool_search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
"github.com/openai/openai-go/v2"
|
||||
"github.com/openai/openai-go/v2/option"
|
||||
)
|
||||
|
||||
// EmbeddingClient handles vector embedding generation using OpenAI-compatible APIs
|
||||
type EmbeddingClient struct {
|
||||
client *openai.Client
|
||||
model string
|
||||
dimensions int
|
||||
}
|
||||
|
||||
// NewEmbeddingClient creates a new EmbeddingClient instance for OpenAI-compatible APIs
|
||||
func NewEmbeddingClient(apiKey, baseURL, model string, dimensions int) *EmbeddingClient {
|
||||
api.LogInfof("Creating EmbeddingClient with baseURL: %s, model: %s, dimensions: %d", baseURL, model, dimensions)
|
||||
|
||||
// Create client with timeout
|
||||
client := openai.NewClient(
|
||||
option.WithAPIKey(apiKey),
|
||||
option.WithBaseURL(baseURL),
|
||||
option.WithRequestTimeout(30*time.Second),
|
||||
)
|
||||
|
||||
return &EmbeddingClient{
|
||||
client: &client,
|
||||
model: model,
|
||||
dimensions: dimensions,
|
||||
}
|
||||
}
|
||||
|
||||
// GetEmbedding generates vector embedding for the given text
|
||||
func (e *EmbeddingClient) GetEmbedding(ctx context.Context, text string) ([]float32, error) {
|
||||
api.LogInfof("Generating embedding for text (length: %d)", len(text))
|
||||
api.LogDebugf("Using model: %s, dimensions: %d", e.model, e.dimensions)
|
||||
|
||||
// Add timeout to context if not already present
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
params := openai.EmbeddingNewParams{
|
||||
Model: e.model,
|
||||
Input: openai.EmbeddingNewParamsInputUnion{
|
||||
OfString: openai.String(text),
|
||||
},
|
||||
Dimensions: openai.Int(int64(e.dimensions)),
|
||||
EncodingFormat: openai.EmbeddingNewParamsEncodingFormatFloat,
|
||||
}
|
||||
|
||||
api.LogDebugf("Calling OpenAI-compatible API for embedding generation")
|
||||
embeddingResp, err := e.client.Embeddings.New(ctx, params)
|
||||
if err != nil {
|
||||
api.LogErrorf("OpenAI-compatible API call failed: %v", err)
|
||||
return nil, fmt.Errorf("failed to generate embedding: %w", err)
|
||||
}
|
||||
|
||||
if len(embeddingResp.Data) == 0 {
|
||||
api.LogErrorf("Empty embedding response from API")
|
||||
return nil, fmt.Errorf("empty embedding response")
|
||||
}
|
||||
|
||||
api.LogDebugf("Successfully received embedding from API")
|
||||
api.LogDebugf("Response data length: %d, embedding dimension: %d", len(embeddingResp.Data), len(embeddingResp.Data[0].Embedding))
|
||||
|
||||
// Convert []float64 to []float32
|
||||
embedding := make([]float32, len(embeddingResp.Data[0].Embedding))
|
||||
for i, v := range embeddingResp.Data[0].Embedding {
|
||||
embedding[i] = float32(v)
|
||||
}
|
||||
|
||||
api.LogInfof("Embedding conversion completed, final dimension: %d", len(embedding))
|
||||
return embedding, nil
|
||||
}
|
||||
204
plugins/golang-filter/mcp-server/servers/tool-search/milvus.go
Normal file
204
plugins/golang-filter/mcp-server/servers/tool-search/milvus.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package tool_search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/schema"
|
||||
"github.com/milvus-io/milvus-sdk-go/v2/client"
|
||||
"github.com/milvus-io/milvus-sdk-go/v2/entity"
|
||||
)
|
||||
|
||||
type MilvusVectorStoreProvider struct {
|
||||
client client.Client
|
||||
collection string
|
||||
dimensions int
|
||||
}
|
||||
|
||||
func NewMilvusVectorStoreProvider(cfg *config.VectorDBConfig, dimensions int) (*MilvusVectorStoreProvider, error) {
|
||||
connectParam := client.Config{
|
||||
Address: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
||||
}
|
||||
connectParam.DBName = cfg.Database
|
||||
|
||||
if cfg.Username != "" && cfg.Password != "" {
|
||||
connectParam.Username = cfg.Username
|
||||
connectParam.Password = cfg.Password
|
||||
}
|
||||
|
||||
milvusClient, err := client.NewClient(context.Background(), connectParam)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create milvus client: %w", err)
|
||||
}
|
||||
|
||||
return &MilvusVectorStoreProvider{
|
||||
client: milvusClient,
|
||||
collection: cfg.Collection,
|
||||
dimensions: dimensions,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *MilvusVectorStoreProvider) ListAllDocs(ctx context.Context, limit int) ([]schema.Document, error) {
|
||||
expr := ""
|
||||
|
||||
outputFields := []string{"id", "content", "metadata", "created_at"}
|
||||
|
||||
var queryResult []entity.Column
|
||||
var err error
|
||||
|
||||
if limit > 0 {
|
||||
queryResult, err = c.client.Query(
|
||||
ctx,
|
||||
c.collection,
|
||||
[]string{}, // partitions
|
||||
expr, // filter condition
|
||||
outputFields,
|
||||
client.WithLimit(int64(limit)),
|
||||
)
|
||||
} else {
|
||||
queryResult, err = c.client.Query(
|
||||
ctx,
|
||||
c.collection,
|
||||
[]string{}, // partitions
|
||||
expr, // filter condition
|
||||
outputFields,
|
||||
)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query all documents: %w", err)
|
||||
}
|
||||
|
||||
if len(queryResult) == 0 {
|
||||
return []schema.Document{}, nil
|
||||
}
|
||||
|
||||
rowCount := queryResult[0].Len()
|
||||
documents := make([]schema.Document, 0, rowCount)
|
||||
|
||||
for i := 0; i < rowCount; i++ {
|
||||
var (
|
||||
id string
|
||||
content string
|
||||
metadata map[string]interface{}
|
||||
createdAt int64
|
||||
)
|
||||
|
||||
for _, col := range queryResult {
|
||||
switch col.Name() {
|
||||
case "id":
|
||||
if v, err := col.(*entity.ColumnVarChar).Get(i); err == nil {
|
||||
id = v.(string)
|
||||
}
|
||||
case "content":
|
||||
if v, err := col.(*entity.ColumnVarChar).Get(i); err == nil {
|
||||
content = v.(string)
|
||||
}
|
||||
case "metadata":
|
||||
if v, err := col.(*entity.ColumnJSONBytes).Get(i); err == nil {
|
||||
if bytes, ok := v.([]byte); ok {
|
||||
_ = json.Unmarshal(bytes, &metadata)
|
||||
}
|
||||
}
|
||||
case "created_at":
|
||||
if v, err := col.(*entity.ColumnInt64).Get(i); err == nil {
|
||||
createdAt = v.(int64)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
doc := schema.Document{
|
||||
ID: id,
|
||||
Content: content,
|
||||
Metadata: metadata,
|
||||
CreatedAt: time.UnixMilli(createdAt),
|
||||
}
|
||||
documents = append(documents, doc)
|
||||
}
|
||||
|
||||
return documents, nil
|
||||
}
|
||||
|
||||
func (c *MilvusVectorStoreProvider) SearchDocs(ctx context.Context, vector []float32, options *schema.SearchOptions) ([]schema.SearchResult, error) {
|
||||
if options == nil {
|
||||
options = &schema.SearchOptions{TopK: 10}
|
||||
}
|
||||
|
||||
sp, err := entity.NewIndexHNSWSearchParam(16) // 默认 HNSW 搜索参数
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build search param: %w", err)
|
||||
}
|
||||
|
||||
outputFields := []string{"id", "content", "metadata"}
|
||||
searchResults, err := c.client.Search(
|
||||
ctx,
|
||||
c.collection,
|
||||
[]string{}, // partition names
|
||||
"", // filter expression
|
||||
outputFields, // output fields
|
||||
[]entity.Vector{entity.FloatVector(vector)},
|
||||
"vector", // anns_field
|
||||
entity.IP, // metric_type
|
||||
options.TopK,
|
||||
sp,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search documents: %w", err)
|
||||
}
|
||||
|
||||
var results []schema.SearchResult
|
||||
for _, result := range searchResults {
|
||||
for i := 0; i < result.ResultCount; i++ {
|
||||
id, _ := result.IDs.Get(i)
|
||||
score := result.Scores[i]
|
||||
|
||||
var content string
|
||||
var metadata map[string]interface{}
|
||||
for _, field := range result.Fields {
|
||||
switch field.Name() {
|
||||
case "content":
|
||||
if contentCol, ok := field.(*entity.ColumnVarChar); ok {
|
||||
if contentVal, err := contentCol.Get(i); err == nil {
|
||||
if contentStr, ok := contentVal.(string); ok {
|
||||
content = contentStr
|
||||
}
|
||||
}
|
||||
}
|
||||
case "metadata":
|
||||
if metaCol, ok := field.(*entity.ColumnJSONBytes); ok {
|
||||
if metaVal, err := metaCol.Get(i); err == nil {
|
||||
if metaBytes, ok := metaVal.([]byte); ok {
|
||||
if err := json.Unmarshal(metaBytes, &metadata); err != nil {
|
||||
metadata = make(map[string]interface{})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
searchResult := schema.SearchResult{
|
||||
Document: schema.Document{
|
||||
ID: fmt.Sprintf("%s", id),
|
||||
Content: content,
|
||||
Metadata: metadata,
|
||||
},
|
||||
Score: float64(score),
|
||||
}
|
||||
results = append(results, searchResult)
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (c *MilvusVectorStoreProvider) Close() error {
|
||||
if c.client != nil {
|
||||
return c.client.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
237
plugins/golang-filter/mcp-server/servers/tool-search/search.go
Normal file
237
plugins/golang-filter/mcp-server/servers/tool-search/search.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package tool_search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/schema"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
)
|
||||
|
||||
// SearchService handles tool search operations
|
||||
type SearchService struct {
|
||||
milvusProvider *MilvusVectorStoreProvider
|
||||
config *config.VectorDBConfig
|
||||
tableName string
|
||||
dimensions int
|
||||
maxTools int // 写死的最大工具数量,仅用于单测
|
||||
embeddingClient *EmbeddingClient
|
||||
}
|
||||
|
||||
// NewSearchService creates a new SearchService instance
|
||||
func NewSearchService(host string, port int, database, username, password, tableName string, embeddingClient *EmbeddingClient, dimensions int, maxTools int) *SearchService {
|
||||
// Create Milvus configuration
|
||||
cfg := &config.VectorDBConfig{
|
||||
Provider: "milvus",
|
||||
Host: host,
|
||||
Port: port,
|
||||
Database: database,
|
||||
Collection: tableName,
|
||||
Username: username,
|
||||
Password: password,
|
||||
}
|
||||
|
||||
// Create Milvus provider
|
||||
provider, err := NewMilvusVectorStoreProvider(cfg, dimensions)
|
||||
if err != nil {
|
||||
api.LogErrorf("Failed to create Milvus provider: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &SearchService{
|
||||
milvusProvider: provider,
|
||||
config: cfg,
|
||||
tableName: tableName,
|
||||
dimensions: dimensions,
|
||||
maxTools: maxTools, // 使用写死的值
|
||||
embeddingClient: embeddingClient,
|
||||
}
|
||||
}
|
||||
|
||||
// ToolSearchResult represents the result of a tool search
|
||||
type ToolSearchResult struct {
|
||||
Tools []ToolDefinition `json:"tools"`
|
||||
}
|
||||
|
||||
// ToolDefinition represents a tool definition in the search result
|
||||
type ToolDefinition map[string]interface{}
|
||||
|
||||
// SearchTools performs semantic search for tools
|
||||
func (s *SearchService) SearchTools(ctx context.Context, query string, topK int) (*ToolSearchResult, error) {
|
||||
api.LogInfof("Starting tool search for query: '%s', topK: %d", query, topK)
|
||||
|
||||
// Generate vector embedding for the query
|
||||
vector, err := s.embeddingClient.GetEmbedding(ctx, query)
|
||||
if err != nil {
|
||||
api.LogErrorf("Failed to generate embedding for query '%s': %v", query, err)
|
||||
return nil, fmt.Errorf("failed to generate embedding: %w", err)
|
||||
}
|
||||
|
||||
api.LogInfof("Embedding generated successfully, vector dimension: %d", len(vector))
|
||||
|
||||
// Perform vector search
|
||||
records, err := s.searchToolsInDB(query, vector, topK)
|
||||
if err != nil {
|
||||
api.LogErrorf("Failed to search tools: %v", err)
|
||||
return nil, fmt.Errorf("failed to search tools: %w", err)
|
||||
}
|
||||
|
||||
api.LogInfof("Vector search completed, found %d records", len(records))
|
||||
|
||||
return s.convertRecordsToResult(records), nil
|
||||
}
|
||||
|
||||
// convertRecordsToResult converts database records to tool search result
|
||||
func (s *SearchService) convertRecordsToResult(records []ToolRecord) *ToolSearchResult {
|
||||
api.LogInfof("Converting %d records to tool definitions", len(records))
|
||||
|
||||
tools := make([]ToolDefinition, 0, len(records))
|
||||
for i, record := range records {
|
||||
var tool ToolDefinition
|
||||
|
||||
// Use metadata if available
|
||||
if len(record.Metadata) > 0 {
|
||||
tool = record.Metadata
|
||||
api.LogDebugf("Successfully parsed metadata for tool %s", record.Name)
|
||||
} else {
|
||||
api.LogDebugf("No metadata found for tool %s, using basic definition", record.Name)
|
||||
// If no metadata, create a basic tool definition
|
||||
tool = ToolDefinition{
|
||||
"name": record.Name,
|
||||
"description": record.Content,
|
||||
}
|
||||
}
|
||||
|
||||
// Update the name to include server name
|
||||
tool["name"] = fmt.Sprintf("%s", record.Name)
|
||||
|
||||
tools = append(tools, tool)
|
||||
|
||||
api.LogDebugf("Tool %d: %s - %s", i+1, tool["name"], record.Content)
|
||||
}
|
||||
|
||||
api.LogInfof("Successfully converted %d tools", len(tools))
|
||||
return &ToolSearchResult{Tools: tools}
|
||||
}
|
||||
|
||||
// GetAllTools retrieves all available tools
|
||||
func (s *SearchService) GetAllTools() (*ToolSearchResult, error) {
|
||||
api.LogInfo("Retrieving all tools")
|
||||
records, err := s.getAllToolsFromDB()
|
||||
if err != nil {
|
||||
api.LogErrorf("Failed to get all tools: %v", err)
|
||||
return nil, fmt.Errorf("failed to get all tools: %w", err)
|
||||
}
|
||||
|
||||
api.LogInfof("Found %d tools in database", len(records))
|
||||
|
||||
// Convert records to tool definitions
|
||||
tools := make([]ToolDefinition, 0, len(records))
|
||||
for _, record := range records {
|
||||
var tool ToolDefinition
|
||||
|
||||
// Use metadata if available
|
||||
if len(record.Metadata) > 0 {
|
||||
tool = record.Metadata
|
||||
api.LogDebugf("Successfully parsed metadata for tool %s", record.Name)
|
||||
} else {
|
||||
api.LogDebugf("No metadata found for tool %s, using basic definition", record.Name)
|
||||
// If no metadata, create a basic tool definition
|
||||
tool = ToolDefinition{
|
||||
"name": record.Name,
|
||||
"description": record.Content,
|
||||
}
|
||||
}
|
||||
|
||||
// Update the name to include server name
|
||||
tool["name"] = fmt.Sprintf("%s", record.Name)
|
||||
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
|
||||
api.LogInfof("Successfully converted %d tools", len(tools))
|
||||
return &ToolSearchResult{Tools: tools}, nil
|
||||
}
|
||||
|
||||
// ToolRecord represents a tool record in the database
|
||||
type ToolRecord struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Content string `json:"content"`
|
||||
Metadata map[string]interface{} `json:"metadata"`
|
||||
}
|
||||
|
||||
func (s *SearchService) searchToolsInDB(query string, vector []float32, topK int) ([]ToolRecord, error) {
|
||||
api.LogInfof("Performing vector search for query: '%s', topK: %d", query, topK)
|
||||
|
||||
// For Milvus, we'll perform vector search directly
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Perform vector search
|
||||
searchOptions := &schema.SearchOptions{
|
||||
TopK: topK,
|
||||
}
|
||||
|
||||
results, err := s.milvusProvider.SearchDocs(ctx, vector, searchOptions)
|
||||
if err != nil {
|
||||
api.LogErrorf("Vector search failed: %v", err)
|
||||
return nil, fmt.Errorf("failed to perform vector search: %w", err)
|
||||
}
|
||||
|
||||
// Convert results to ToolRecords
|
||||
var records []ToolRecord
|
||||
for _, result := range results {
|
||||
doc := result.Document
|
||||
tool := ToolRecord{
|
||||
ID: doc.ID,
|
||||
Content: doc.Content,
|
||||
Metadata: doc.Metadata,
|
||||
}
|
||||
|
||||
if name, ok := doc.Metadata["name"].(string); ok {
|
||||
tool.Name = name
|
||||
}
|
||||
|
||||
records = append(records, tool)
|
||||
}
|
||||
|
||||
api.LogInfof("Vector search completed, found %d results", len(records))
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// getAllToolsFromDB retrieves all tools from the database
|
||||
func (s *SearchService) getAllToolsFromDB() ([]ToolRecord, error) {
|
||||
api.LogInfof("Executing GetAllTools query from collection: %s", s.tableName)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Retrieve all documents with limit
|
||||
docs, err := s.milvusProvider.ListAllDocs(ctx, s.maxTools)
|
||||
if err != nil {
|
||||
api.LogErrorf("Failed to list documents: %v", err)
|
||||
return nil, fmt.Errorf("failed to list documents: %w", err)
|
||||
}
|
||||
|
||||
// Convert documents to ToolRecords
|
||||
var tools []ToolRecord
|
||||
for _, doc := range docs {
|
||||
tool := ToolRecord{
|
||||
ID: doc.ID,
|
||||
Content: doc.Content,
|
||||
Metadata: doc.Metadata,
|
||||
}
|
||||
|
||||
if name, ok := doc.Metadata["name"].(string); ok {
|
||||
tool.Name = name
|
||||
}
|
||||
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
|
||||
api.LogInfof("GetAllTools query completed, found %d tools", len(tools))
|
||||
return tools, nil
|
||||
}
|
||||
196
plugins/golang-filter/mcp-server/servers/tool-search/server.go
Normal file
196
plugins/golang-filter/mcp-server/servers/tool-search/server.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package tool_search
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
Version = "1.0.0"
|
||||
|
||||
// 默认配置值
|
||||
defaultTableName = "apig_mcp_tools"
|
||||
defaultBaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
defaultModel = "text-embedding-v4"
|
||||
defaultDimensions = 1024
|
||||
// 写死最大工具数量为1000,仅用于单测
|
||||
fixedMaxTools = 1000
|
||||
)
|
||||
|
||||
func init() {
|
||||
common.GlobalRegistry.RegisterServer("tool-search", &ToolSearchConfig{})
|
||||
}
|
||||
|
||||
type VectorConfig struct {
|
||||
Type string `json:"type"`
|
||||
VectorWeight float64 `json:"vectorWeight"`
|
||||
TableName string `json:"tableName"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Database string `json:"database"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type EmbeddingConfig struct {
|
||||
APIKey string `json:"apiKey"`
|
||||
BaseURL string `json:"baseURL"`
|
||||
Model string `json:"model"`
|
||||
Dimensions int `json:"dimensions"`
|
||||
}
|
||||
|
||||
type ToolSearchConfig struct {
|
||||
Vector VectorConfig `json:"vector"`
|
||||
Embedding EmbeddingConfig `json:"embedding"`
|
||||
description string
|
||||
}
|
||||
|
||||
func (c *ToolSearchConfig) ParseConfig(config map[string]any) error {
|
||||
// Parse vector configuration
|
||||
vectorConfig, ok := config["vector"].(map[string]any)
|
||||
if !ok {
|
||||
return errors.New("missing vector configuration")
|
||||
}
|
||||
|
||||
if err := c.parseVectorConfig(vectorConfig); err != nil {
|
||||
return fmt.Errorf("failed to parse vector config: %w", err)
|
||||
}
|
||||
|
||||
// Parse embedding configuration
|
||||
embeddingConfig, ok := config["embedding"].(map[string]any)
|
||||
if !ok {
|
||||
return errors.New("missing embedding configuration")
|
||||
}
|
||||
|
||||
if err := c.parseEmbeddingConfig(embeddingConfig); err != nil {
|
||||
return fmt.Errorf("failed to parse embedding config: %w", err)
|
||||
}
|
||||
|
||||
// Optional description
|
||||
if description, ok := config["description"].(string); ok {
|
||||
c.description = description
|
||||
} else {
|
||||
c.description = "Tool search server for semantic similarity search"
|
||||
}
|
||||
|
||||
api.LogDebugf("ToolSearchConfig ParseConfig: %+v", config)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ToolSearchConfig) parseVectorConfig(config map[string]any) error {
|
||||
if vectorType, ok := config["type"].(string); ok {
|
||||
c.Vector.Type = vectorType
|
||||
} else {
|
||||
return errors.New("missing vector.type")
|
||||
}
|
||||
|
||||
if c.Vector.Type != "milvus" {
|
||||
return fmt.Errorf("unsupported vector.type: %s, only 'milvus' is supported", c.Vector.Type)
|
||||
}
|
||||
|
||||
if host, ok := config["host"].(string); ok {
|
||||
c.Vector.Host = host
|
||||
} else {
|
||||
return errors.New("missing vector.host")
|
||||
}
|
||||
|
||||
if port, ok := config["port"].(float64); ok {
|
||||
c.Vector.Port = int(port)
|
||||
} else if port, ok := config["port"].(int); ok {
|
||||
c.Vector.Port = port
|
||||
} else {
|
||||
return errors.New("missing vector.port")
|
||||
}
|
||||
|
||||
if database, ok := config["database"].(string); ok {
|
||||
c.Vector.Database = database
|
||||
} else {
|
||||
c.Vector.Database = "default" // 默认数据库
|
||||
}
|
||||
|
||||
if tableName, ok := config["tableName"].(string); ok {
|
||||
c.Vector.TableName = tableName
|
||||
} else {
|
||||
c.Vector.TableName = defaultTableName
|
||||
}
|
||||
|
||||
if username, ok := config["username"].(string); ok {
|
||||
c.Vector.Username = username
|
||||
}
|
||||
|
||||
if password, ok := config["password"].(string); ok {
|
||||
c.Vector.Password = password
|
||||
}
|
||||
|
||||
// 移除maxTools的解析逻辑
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ToolSearchConfig) parseEmbeddingConfig(config map[string]any) error {
|
||||
// Parse API key (required)
|
||||
if apiKey, ok := config["apiKey"].(string); ok {
|
||||
c.Embedding.APIKey = apiKey
|
||||
} else {
|
||||
return errors.New("missing embedding.apiKey")
|
||||
}
|
||||
|
||||
// Parse optional fields with defaults
|
||||
if baseURL, ok := config["baseURL"].(string); ok {
|
||||
c.Embedding.BaseURL = baseURL
|
||||
} else {
|
||||
c.Embedding.BaseURL = defaultBaseURL
|
||||
}
|
||||
|
||||
if model, ok := config["model"].(string); ok {
|
||||
c.Embedding.Model = model
|
||||
} else {
|
||||
c.Embedding.Model = defaultModel
|
||||
}
|
||||
|
||||
if dimensions, ok := config["dimensions"].(float64); ok {
|
||||
c.Embedding.Dimensions = int(dimensions)
|
||||
} else if dimensions, ok := config["dimensions"].(int); ok {
|
||||
c.Embedding.Dimensions = dimensions
|
||||
} else {
|
||||
c.Embedding.Dimensions = defaultDimensions
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ToolSearchConfig) NewServer(serverName string) (*common.MCPServer, error) {
|
||||
mcpServer := common.NewMCPServer(
|
||||
serverName,
|
||||
Version,
|
||||
common.WithInstructions(c.description),
|
||||
)
|
||||
|
||||
// Create embedding client
|
||||
embeddingClient := NewEmbeddingClient(c.Embedding.APIKey, c.Embedding.BaseURL, c.Embedding.Model, c.Embedding.Dimensions)
|
||||
|
||||
// Create search service,使用写死的fixedMaxTools值
|
||||
searchService := NewSearchService(
|
||||
c.Vector.Host,
|
||||
c.Vector.Port,
|
||||
c.Vector.Database,
|
||||
c.Vector.Username,
|
||||
c.Vector.Password,
|
||||
c.Vector.TableName,
|
||||
embeddingClient,
|
||||
c.Embedding.Dimensions,
|
||||
fixedMaxTools, // 使用写死的值
|
||||
)
|
||||
|
||||
// Add tool search tool
|
||||
mcpServer.AddTool(
|
||||
mcp.NewToolWithRawSchema("x_higress_tool_search", "Higress MCP Tools Searcher", GetToolSearchSchema()),
|
||||
HandleToolSearch(searchService),
|
||||
)
|
||||
|
||||
return mcpServer, nil
|
||||
}
|
||||
@@ -0,0 +1,198 @@
|
||||
package tool_search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// Mock implementation of CommonCAPI for testing
|
||||
type mockCommonCAPI struct {
|
||||
logs []string
|
||||
}
|
||||
|
||||
func (m *mockCommonCAPI) Log(level api.LogType, message string) {
|
||||
fmt.Printf("[%s] %s\n", level, message)
|
||||
m.logs = append(m.logs, message)
|
||||
}
|
||||
|
||||
func (m *mockCommonCAPI) LogLevel() api.LogType {
|
||||
return api.Debug
|
||||
}
|
||||
|
||||
// TestServer is used for local functional testing
|
||||
func TestServer(t *testing.T) {
|
||||
// Setup mock API for logging
|
||||
mockAPI := &mockCommonCAPI{}
|
||||
api.SetCommonCAPI(mockAPI)
|
||||
|
||||
// Load configuration from environment variables or use defaults
|
||||
config := map[string]any{
|
||||
"vector": map[string]any{
|
||||
"type": "milvus",
|
||||
"vectorWeight": 0.6,
|
||||
"tableName": getEnvOrDefault("TEST_TABLE_NAME", "apig_mcp_tools"),
|
||||
"host": getEnvOrDefault("TEST_MILVUS_HOST", "localhost"),
|
||||
"port": getEnvOrDefaultInt("TEST_MILVUS_PORT", 19530),
|
||||
"database": getEnvOrDefault("TEST_MILVUS_DATABASE", "default"),
|
||||
"username": getEnvOrDefault("TEST_MILVUS_USERNAME", "root"),
|
||||
"password": getEnvOrDefault("TEST_MILVUS_PASSWORD", "Milvus"),
|
||||
"maxTools": getEnvOrDefaultInt("TEST_MAX_TOOLS", 1000),
|
||||
},
|
||||
"embedding": map[string]any{
|
||||
"apiKey": getEnvOrDefault("TEST_API_KEY", "your-dashscope-api-key"),
|
||||
"baseURL": getEnvOrDefault("TEST_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
|
||||
"model": getEnvOrDefault("TEST_MODEL", "text-embedding-v4"),
|
||||
"dimensions": 1024,
|
||||
},
|
||||
"description": "Test MCP Tools Search Server",
|
||||
}
|
||||
|
||||
// Create configuration instance
|
||||
toolSearchConfig := &ToolSearchConfig{}
|
||||
if err := toolSearchConfig.ParseConfig(config); err != nil {
|
||||
t.Fatalf("Failed to parse config: %v", err)
|
||||
}
|
||||
|
||||
// Create MCP Server
|
||||
_, err := toolSearchConfig.NewServer("test-tool-search")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create server: %v", err)
|
||||
}
|
||||
|
||||
// Test database connection
|
||||
vectorConfig := config["vector"].(map[string]any)
|
||||
embeddingConfig := config["embedding"].(map[string]any)
|
||||
|
||||
// Test GetAllTools
|
||||
t.Logf("\n=== Testing GetAllTools ===")
|
||||
embeddingClient := NewEmbeddingClient(
|
||||
embeddingConfig["apiKey"].(string),
|
||||
embeddingConfig["baseURL"].(string),
|
||||
embeddingConfig["model"].(string),
|
||||
embeddingConfig["dimensions"].(int),
|
||||
)
|
||||
|
||||
searchService := NewSearchService(
|
||||
vectorConfig["host"].(string),
|
||||
vectorConfig["port"].(int),
|
||||
vectorConfig["database"].(string),
|
||||
vectorConfig["username"].(string),
|
||||
vectorConfig["password"].(string),
|
||||
vectorConfig["tableName"].(string),
|
||||
embeddingClient,
|
||||
embeddingConfig["dimensions"].(int),
|
||||
getEnvOrDefaultInt("TEST_MAX_TOOLS", 1000),
|
||||
)
|
||||
|
||||
allTools, err := searchService.GetAllTools()
|
||||
if err != nil {
|
||||
t.Logf("GetAllTools failed: %v", err)
|
||||
} else {
|
||||
t.Logf("Found %d tools:", len(allTools.Tools))
|
||||
for i, tool := range allTools.Tools {
|
||||
if i < 3 { // Show only first 3 tools
|
||||
toolJSON, _ := json.MarshalIndent(tool, "", " ")
|
||||
t.Logf("Tool %d: %s", i+1, string(toolJSON))
|
||||
}
|
||||
}
|
||||
if len(allTools.Tools) > 3 {
|
||||
t.Logf("... and %d more tools", len(allTools.Tools)-3)
|
||||
}
|
||||
}
|
||||
|
||||
// Test tool search with timing
|
||||
t.Logf("\n=== Testing Tool Search ===")
|
||||
testQueries := []string{
|
||||
"weather data",
|
||||
"database query",
|
||||
"file operations",
|
||||
"HTTP requests",
|
||||
"library documents",
|
||||
}
|
||||
|
||||
for _, query := range testQueries {
|
||||
t.Logf("\n--- Testing query: '%s' ---", query)
|
||||
|
||||
// Create MCP tool call request
|
||||
request := mcp.CallToolRequest{
|
||||
Params: struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]interface{} `json:"arguments,omitempty"`
|
||||
Meta *struct {
|
||||
ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"`
|
||||
} `json:"_meta,omitempty"`
|
||||
}{
|
||||
Name: "x_higress_tool_search",
|
||||
Arguments: map[string]interface{}{
|
||||
"query": query,
|
||||
"topK": 3,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Get tool handler
|
||||
handler := HandleToolSearch(searchService)
|
||||
|
||||
// Execute search with timing
|
||||
start := time.Now()
|
||||
result, err := handler(context.Background(), request)
|
||||
duration := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Logf("Search failed: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Print results with timing information
|
||||
t.Logf("Search completed in %v", duration)
|
||||
if len(result.Content) > 0 {
|
||||
if textContent, ok := result.Content[0].(mcp.TextContent); ok {
|
||||
var toolsResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(textContent.Text), &toolsResult); err == nil {
|
||||
toolsJSON, _ := json.MarshalIndent(toolsResult, "", " ")
|
||||
t.Logf("Tools Result: %s", string(toolsJSON))
|
||||
} else {
|
||||
t.Logf("Text Content: %s", textContent.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test configuration validation
|
||||
t.Logf("\n=== Configuration Validation ===")
|
||||
t.Logf("Host: %s", vectorConfig["host"])
|
||||
t.Logf("Port: %d", vectorConfig["port"])
|
||||
t.Logf("Database: %s", vectorConfig["database"])
|
||||
t.Logf("Table Name: %s", vectorConfig["tableName"])
|
||||
t.Logf("Vector Weight: %f", vectorConfig["vectorWeight"])
|
||||
t.Logf("Text Weight: %f", 1.0-vectorConfig["vectorWeight"].(float64))
|
||||
t.Logf("Model: %s", embeddingConfig["model"])
|
||||
t.Logf("Dimensions: %d", embeddingConfig["dimensions"])
|
||||
t.Logf("API Base URL: %s", embeddingConfig["baseURL"])
|
||||
|
||||
t.Logf("\n=== Test completed ===")
|
||||
}
|
||||
|
||||
// Helper function to get environment variable or default value
|
||||
func getEnvOrDefault(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvOrDefaultInt(key string, defaultValue int) int {
|
||||
if valueStr := os.Getenv(key); valueStr != "" {
|
||||
if value, err := fmt.Sscanf(valueStr, "%d", &defaultValue); err == nil && value == 1 {
|
||||
return defaultValue
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
114
plugins/golang-filter/mcp-server/servers/tool-search/tools.go
Normal file
114
plugins/golang-filter/mcp-server/servers/tool-search/tools.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package tool_search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// HandleToolSearch handles the x_higress_tool_search tool
|
||||
func HandleToolSearch(searchService *SearchService) common.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
api.LogInfo("HandleToolSearch called")
|
||||
|
||||
arguments := request.Params.Arguments
|
||||
api.LogDebugf("Request arguments: %+v", arguments)
|
||||
|
||||
// Get query parameter
|
||||
query, ok := arguments["query"].(string)
|
||||
if !ok {
|
||||
api.LogErrorf("Invalid query argument type: %T", arguments["query"])
|
||||
return nil, fmt.Errorf("invalid query argument")
|
||||
}
|
||||
|
||||
// Validate query
|
||||
if query == "" {
|
||||
api.LogError("Empty query provided")
|
||||
return nil, fmt.Errorf("query cannot be empty")
|
||||
}
|
||||
|
||||
// Get topK parameter (optional, default to 10)
|
||||
topK := 10
|
||||
if topKVal, ok := arguments["topK"]; ok {
|
||||
switch v := topKVal.(type) {
|
||||
case float64:
|
||||
topK = int(v)
|
||||
case int:
|
||||
topK = v
|
||||
case int64:
|
||||
topK = int(v)
|
||||
default:
|
||||
api.LogWarnf("Invalid topK argument type: %T, using default: %d", topKVal, topK)
|
||||
}
|
||||
|
||||
// Validate topK range
|
||||
if topK <= 0 || topK > 100 {
|
||||
api.LogWarnf("Invalid topK value: %d, using default: 10", topK)
|
||||
topK = 10
|
||||
}
|
||||
}
|
||||
|
||||
api.LogInfof("Parsed parameters - query: '%s', topK: %d", query, topK)
|
||||
|
||||
// Add timeout to context
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Perform search
|
||||
result, err := searchService.SearchTools(ctx, query, topK)
|
||||
if err != nil {
|
||||
api.LogErrorf("Search failed: %v", err)
|
||||
return nil, fmt.Errorf("failed to search tools: %w", err)
|
||||
}
|
||||
|
||||
api.LogInfof("Search completed successfully, found %d tools", len(result.Tools))
|
||||
|
||||
// Build response
|
||||
response := map[string]interface{}{
|
||||
"tools": result.Tools,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
api.LogErrorf("Failed to marshal response: %v", err)
|
||||
return nil, fmt.Errorf("failed to marshal search results: %w", err)
|
||||
}
|
||||
|
||||
api.LogDebugf("Response marshaled successfully, JSON length: %d", len(jsonData))
|
||||
api.LogDebugf("Returning MCP CallToolResult")
|
||||
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{
|
||||
Type: "text",
|
||||
Text: string(jsonData),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetToolSearchSchema returns the schema for the tool search tool
|
||||
func GetToolSearchSchema() json.RawMessage {
|
||||
return json.RawMessage(`{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Query statement for semantic similarity comparison with tool descriptions"
|
||||
},
|
||||
"topK": {
|
||||
"type": "integer",
|
||||
"description": "Specify how many tools need to be selected, default is to select the top 10 tools.",
|
||||
"minimum": 1,
|
||||
"maximum": 100
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}`)
|
||||
}
|
||||
@@ -26,8 +26,8 @@ type config struct {
|
||||
matchList []common.MatchRule
|
||||
enableUserLevelServer bool
|
||||
rateLimitConfig *handler.MCPRatelimitConfig
|
||||
defaultServer *common.SSEServer
|
||||
redisClient *common.RedisClient
|
||||
sharedMCPServer *common.MCPServer // Created once, thread-safe with sync.RWMutex
|
||||
}
|
||||
|
||||
func (c *config) Destroy() {
|
||||
@@ -110,6 +110,9 @@ func (p *Parser) Parse(any *anypb.Any, callbacks api.ConfigCallbackHandler) (int
|
||||
}
|
||||
GlobalSSEPathSuffix = ssePathSuffix
|
||||
|
||||
// Create shared MCPServer once during config parsing (thread-safe with sync.RWMutex)
|
||||
conf.sharedMCPServer = common.NewMCPServer(DefaultServerName, Version)
|
||||
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
@@ -125,9 +128,6 @@ func (p *Parser) Merge(parent interface{}, child interface{}) interface{} {
|
||||
if childConfig.rateLimitConfig != nil {
|
||||
newConfig.rateLimitConfig = childConfig.rateLimitConfig
|
||||
}
|
||||
if childConfig.defaultServer != nil {
|
||||
newConfig.defaultServer = childConfig.defaultServer
|
||||
}
|
||||
return &newConfig
|
||||
}
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ type filter struct {
|
||||
skipRequestBody bool
|
||||
skipResponseBody bool
|
||||
cachedResponseBody []byte
|
||||
sseServer *common.SSEServer // SSE server instance for this filter (per-request, not shared)
|
||||
|
||||
userLevelConfig bool
|
||||
mcpConfigHandler *handler.MCPConfigHandler
|
||||
@@ -135,11 +136,13 @@ func (f *filter) processMcpRequestHeadersForRestUpstream(header api.RequestHeade
|
||||
trimmed += "?" + rq
|
||||
}
|
||||
|
||||
f.config.defaultServer = common.NewSSEServer(common.NewMCPServer(DefaultServerName, Version),
|
||||
// Create SSE server instance for this filter (per-request, not shared)
|
||||
// MCPServer is shared (thread-safe), but SSEServer must be per-request (contains request-specific messageEndpoint)
|
||||
f.sseServer = common.NewSSEServer(f.config.sharedMCPServer,
|
||||
common.WithSSEEndpoint(GlobalSSEPathSuffix),
|
||||
common.WithMessageEndpoint(trimmed),
|
||||
common.WithRedisClient(f.config.redisClient))
|
||||
f.serverName = f.config.defaultServer.GetServerName()
|
||||
f.serverName = f.sseServer.GetServerName()
|
||||
body := "SSE connection create"
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusOK, body, nil, 0, "")
|
||||
}
|
||||
@@ -238,10 +241,9 @@ func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.Statu
|
||||
ret := api.Continue
|
||||
api.LogDebugf("Upstream Type: %s", f.matchedRule.UpstreamType)
|
||||
switch f.matchedRule.UpstreamType {
|
||||
case common.RestUpstream, common.StreamableUpstream:
|
||||
case common.RestUpstream:
|
||||
api.LogDebugf("Encoding data from Rest upstream")
|
||||
ret = f.encodeDataFromRestUpstream(buffer, endStream)
|
||||
break
|
||||
case common.SSEUpstream:
|
||||
api.LogDebugf("Encoding data from SSE upstream")
|
||||
ret = f.encodeDataFromSSEUpstream(buffer, endStream)
|
||||
@@ -249,6 +251,8 @@ func (f *filter) EncodeData(buffer api.BufferInstance, endStream bool) api.Statu
|
||||
// Always continue as long as the stream has ended.
|
||||
ret = api.Continue
|
||||
}
|
||||
case common.StreamableUpstream:
|
||||
// Do nothing for streamable upstream
|
||||
}
|
||||
return ret
|
||||
}
|
||||
@@ -274,9 +278,9 @@ func (f *filter) encodeDataFromRestUpstream(buffer api.BufferInstance, endStream
|
||||
|
||||
if f.serverName != "" {
|
||||
if f.config.redisClient != nil {
|
||||
// handle default server
|
||||
// handle SSE server for this filter instance
|
||||
buffer.Reset()
|
||||
f.config.defaultServer.HandleSSE(f.callbacks, f.stopChan)
|
||||
f.sseServer.HandleSSE(f.callbacks, f.stopChan)
|
||||
return api.Running
|
||||
} else {
|
||||
_ = buffer.SetString(RedisNotEnabledResponseBody)
|
||||
|
||||
1
plugins/wasm-cpp/.clang-format
Normal file
1
plugins/wasm-cpp/.clang-format
Normal file
@@ -0,0 +1 @@
|
||||
BasedOnStyle: Google
|
||||
@@ -135,8 +135,40 @@ bool PluginRootContext::configure(size_t configuration_size) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void PluginRootContext::incrementRequestCount() {
|
||||
request_count_++;
|
||||
if (request_count_ >= REBUILD_THRESHOLD) {
|
||||
LOG_DEBUG("Request count reached threshold, triggering rebuild");
|
||||
setFilterState("wasm_need_rebuild", "true");
|
||||
request_count_ = 0; // Reset counter after setting rebuild flag
|
||||
}
|
||||
}
|
||||
|
||||
FilterHeadersStatus PluginRootContext::onHeader(
|
||||
const ModelMapperConfigRule& rule) {
|
||||
// Increment request count and check for rebuild
|
||||
incrementRequestCount();
|
||||
|
||||
// Check memory threshold and trigger rebuild if needed
|
||||
std::string value;
|
||||
if (getValue({"plugin_vm_memory"}, &value)) {
|
||||
// The value is stored as binary uint64_t, convert to string for logging
|
||||
if (value.size() == sizeof(uint64_t)) {
|
||||
uint64_t memory_size;
|
||||
memcpy(&memory_size, value.data(), sizeof(uint64_t));
|
||||
LOG_DEBUG(absl::StrCat("vm memory size is ", memory_size));
|
||||
if (memory_size >= MEMORY_THRESHOLD_BYTES) {
|
||||
LOG_INFO(absl::StrCat("Memory threshold reached (", memory_size, " >= ",
|
||||
MEMORY_THRESHOLD_BYTES, "), triggering rebuild"));
|
||||
setFilterState("wasm_need_rebuild", "true");
|
||||
}
|
||||
} else {
|
||||
LOG_ERROR("invalid memory size format");
|
||||
}
|
||||
} else {
|
||||
LOG_ERROR("get vm memory size failed");
|
||||
}
|
||||
|
||||
if (!Wasm::Common::Http::hasRequestBody()) {
|
||||
return FilterHeadersStatus::Continue;
|
||||
}
|
||||
|
||||
@@ -42,9 +42,10 @@ struct ModelMapperConfigRule {
|
||||
std::vector<std::pair<std::string, std::string>> prefix_model_mapping_;
|
||||
std::string default_model_mapping_;
|
||||
std::vector<std::string> enable_on_path_suffix_ = {
|
||||
"/completions", "/embeddings", "/images/generations",
|
||||
"/audio/speech", "/fine_tuning/jobs", "/moderations",
|
||||
"/image-synthesis", "/video-synthesis"};
|
||||
"/completions", "/embeddings", "/images/generations",
|
||||
"/audio/speech", "/fine_tuning/jobs", "/moderations",
|
||||
"/image-synthesis", "/video-synthesis", "/rerank",
|
||||
"/messages"};
|
||||
};
|
||||
|
||||
// PluginRootContext is the root context for all streams processed by the
|
||||
@@ -60,9 +61,13 @@ class PluginRootContext : public RootContext,
|
||||
FilterHeadersStatus onHeader(const ModelMapperConfigRule&);
|
||||
FilterDataStatus onBody(const ModelMapperConfigRule&, std::string_view);
|
||||
bool configure(size_t);
|
||||
void incrementRequestCount();
|
||||
|
||||
private:
|
||||
bool parsePluginConfig(const json&, ModelMapperConfigRule&) override;
|
||||
uint64_t request_count_ = 0;
|
||||
static constexpr uint64_t REBUILD_THRESHOLD = 1000;
|
||||
static constexpr size_t MEMORY_THRESHOLD_BYTES = 200 * 1024 * 1024;
|
||||
};
|
||||
|
||||
// Per-stream context.
|
||||
|
||||
@@ -123,9 +123,40 @@ bool PluginRootContext::configure(size_t configuration_size) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void PluginRootContext::incrementRequestCount() {
|
||||
request_count_++;
|
||||
if (request_count_ >= REBUILD_THRESHOLD) {
|
||||
LOG_DEBUG("Request count reached threshold, triggering rebuild");
|
||||
setFilterState("wasm_need_rebuild", "true");
|
||||
request_count_ = 0; // Reset counter after setting rebuild flag
|
||||
}
|
||||
}
|
||||
|
||||
FilterHeadersStatus PluginRootContext::onHeader(
|
||||
PluginContext& ctx,
|
||||
const ModelRouterConfigRule& rule) {
|
||||
PluginContext& ctx, const ModelRouterConfigRule& rule) {
|
||||
// Increment request count and check for rebuild
|
||||
incrementRequestCount();
|
||||
|
||||
// Check memory threshold and trigger rebuild if needed
|
||||
std::string value;
|
||||
if (getValue({"plugin_vm_memory"}, &value)) {
|
||||
// The value is stored as binary uint64_t, convert to string for logging
|
||||
if (value.size() == sizeof(uint64_t)) {
|
||||
uint64_t memory_size;
|
||||
memcpy(&memory_size, value.data(), sizeof(uint64_t));
|
||||
LOG_DEBUG(absl::StrCat("vm memory size is ", memory_size));
|
||||
if (memory_size >= MEMORY_THRESHOLD_BYTES) {
|
||||
LOG_INFO(absl::StrCat("Memory threshold reached (", memory_size, " >= ",
|
||||
MEMORY_THRESHOLD_BYTES, "), triggering rebuild"));
|
||||
setFilterState("wasm_need_rebuild", "true");
|
||||
}
|
||||
} else {
|
||||
LOG_ERROR("invalid memory size format");
|
||||
}
|
||||
} else {
|
||||
LOG_ERROR("get vm memory size failed");
|
||||
}
|
||||
|
||||
if (!Wasm::Common::Http::hasRequestBody()) {
|
||||
return FilterHeadersStatus::Continue;
|
||||
}
|
||||
@@ -157,7 +188,7 @@ FilterHeadersStatus PluginRootContext::onHeader(
|
||||
auto content_type_value = content_type_ptr->view();
|
||||
LOG_DEBUG(absl::StrCat("Content-Type: ", content_type_value));
|
||||
if (absl::StrContains(content_type_value,
|
||||
Wasm::Common::Http::ContentTypeValues::Json)) {
|
||||
Wasm::Common::Http::ContentTypeValues::Json)) {
|
||||
ctx.mode_ = MODE_JSON;
|
||||
LOG_DEBUG("Enable JSON mode.");
|
||||
removeRequestHeader(Wasm::Common::Http::Header::ContentLength);
|
||||
@@ -165,12 +196,15 @@ FilterHeadersStatus PluginRootContext::onHeader(
|
||||
LOG_INFO(absl::StrCat("SetRequestBodyBufferLimit: ", DefaultMaxBodyBytes));
|
||||
return FilterHeadersStatus::StopIteration;
|
||||
}
|
||||
if (absl::StrContains(content_type_value,
|
||||
Wasm::Common::Http::ContentTypeValues::MultipartFormData)) {
|
||||
if (absl::StrContains(
|
||||
content_type_value,
|
||||
Wasm::Common::Http::ContentTypeValues::MultipartFormData)) {
|
||||
// Get the boundary from the content type
|
||||
auto boundary_start = content_type_value.find("boundary=");
|
||||
if (boundary_start == std::string::npos) {
|
||||
LOG_WARN(absl::StrCat("No boundary found in a multipart/form-data content-type: ", content_type_value));
|
||||
LOG_WARN(absl::StrCat(
|
||||
"No boundary found in a multipart/form-data content-type: ",
|
||||
content_type_value));
|
||||
return FilterHeadersStatus::Continue;
|
||||
}
|
||||
boundary_start += 9;
|
||||
@@ -181,21 +215,27 @@ FilterHeadersStatus PluginRootContext::onHeader(
|
||||
auto boundary_length = boundary_end - boundary_start;
|
||||
if (boundary_length < 1 || boundary_length > 70) {
|
||||
// See https://www.w3.org/Protocols/rfc1341/7_2_Multipart.html
|
||||
LOG_WARN(absl::StrCat("Invalid boundary value in a multipart/form-data content-type: ", content_type_value));
|
||||
LOG_WARN(absl::StrCat(
|
||||
"Invalid boundary value in a multipart/form-data content-type: ",
|
||||
content_type_value));
|
||||
return FilterHeadersStatus::Continue;
|
||||
}
|
||||
auto boundary_value = content_type_value.substr(boundary_start, boundary_end - boundary_start);
|
||||
auto boundary_value = content_type_value.substr(
|
||||
boundary_start, boundary_end - boundary_start);
|
||||
ctx.mode_ = MODE_MULTIPART;
|
||||
ctx.boundary_ = boundary_value;
|
||||
LOG_DEBUG(absl::StrCat("Enable multipart/form-data mode. Boundary=", boundary_value));
|
||||
LOG_DEBUG(absl::StrCat("Enable multipart/form-data mode. Boundary=",
|
||||
boundary_value));
|
||||
removeRequestHeader(Wasm::Common::Http::Header::ContentLength);
|
||||
setFilterState(SetDecoderBufferLimitKey, DefaultMaxBodyBytes);
|
||||
LOG_INFO(absl::StrCat("SetRequestBodyBufferLimit: ", DefaultMaxBodyBytes));
|
||||
return FilterHeadersStatus::StopIteration;
|
||||
}
|
||||
return FilterHeadersStatus::Continue;
|
||||
}
|
||||
|
||||
FilterDataStatus PluginRootContext::onJsonBody(const ModelRouterConfigRule& rule,
|
||||
std::string_view body) {
|
||||
FilterDataStatus PluginRootContext::onJsonBody(
|
||||
const ModelRouterConfigRule& rule, std::string_view body) {
|
||||
const auto& model_key = rule.model_key_;
|
||||
const auto& add_provider_header = rule.add_provider_header_;
|
||||
const auto& model_to_header = rule.model_to_header_;
|
||||
@@ -231,18 +271,18 @@ FilterDataStatus PluginRootContext::onJsonBody(const ModelRouterConfigRule& rule
|
||||
}
|
||||
|
||||
FilterDataStatus PluginRootContext::onMultipartBody(
|
||||
PluginContext& ctx,
|
||||
const ModelRouterConfigRule& rule,
|
||||
WasmDataPtr& body,
|
||||
PluginContext& ctx, const ModelRouterConfigRule& rule, WasmDataPtr& body,
|
||||
bool end_stream) {
|
||||
const auto& add_provider_header = rule.add_provider_header_;
|
||||
const auto& model_to_header = rule.model_to_header_;
|
||||
|
||||
const auto boundary = ctx.boundary_;
|
||||
const auto body_view = body->view();
|
||||
const auto model_param_header = absl::StrCat("Content-Disposition: form-data; name=\"", rule.model_key_, "\"");
|
||||
const auto model_param_header = absl::StrCat(
|
||||
"Content-Disposition: form-data; name=\"", rule.model_key_, "\"");
|
||||
|
||||
for (size_t pos = 0; (pos = body_view.find(boundary, pos)) != std::string_view::npos;) {
|
||||
for (size_t pos = 0;
|
||||
(pos = body_view.find(boundary, pos)) != std::string_view::npos;) {
|
||||
LOG_DEBUG(absl::StrCat("Found boundary at ", pos));
|
||||
pos += boundary.length();
|
||||
size_t end_pos = body_view.find(boundary, pos);
|
||||
@@ -264,7 +304,7 @@ FilterDataStatus PluginRootContext::onMultipartBody(
|
||||
LOG_DEBUG("No value start found in part");
|
||||
break;
|
||||
}
|
||||
value_start += 4; // Skip the "\r\n\r\n"
|
||||
value_start += 4; // Skip the "\r\n\r\n"
|
||||
// model parameter should be only one line
|
||||
size_t value_end = part.find(CRLF, value_start);
|
||||
if (value_end == std::string_view::npos) {
|
||||
@@ -283,8 +323,12 @@ FilterDataStatus PluginRootContext::onMultipartBody(
|
||||
const auto& model = model_value.substr(pos + 1);
|
||||
replaceRequestHeader(add_provider_header, provider);
|
||||
size_t new_size = 0;
|
||||
auto new_buffer_data = absl::StrCat(body_view.substr(0, part_pos + value_start), model, body_view.substr(part_pos + value_end));
|
||||
auto result = setBuffer(WasmBufferType::HttpRequestBody, 0, std::numeric_limits<size_t>::max(), new_buffer_data, &new_size);
|
||||
auto new_buffer_data =
|
||||
absl::StrCat(body_view.substr(0, part_pos + value_start), model,
|
||||
body_view.substr(part_pos + value_end));
|
||||
auto result = setBuffer(WasmBufferType::HttpRequestBody, 0,
|
||||
std::numeric_limits<size_t>::max(),
|
||||
new_buffer_data, &new_size);
|
||||
LOG_DEBUG(absl::StrCat("model route to provider:", provider,
|
||||
", model:", model));
|
||||
LOG_DEBUG(absl::StrCat("result=", result, " new_size=", new_size));
|
||||
@@ -294,7 +338,8 @@ FilterDataStatus PluginRootContext::onMultipartBody(
|
||||
}
|
||||
}
|
||||
// We are done now. We can stop processing the body.
|
||||
LOG_DEBUG(absl::StrCat("Done processing multipart body after caching ", body_view.length() , " bytes."));
|
||||
LOG_DEBUG(absl::StrCat("Done processing multipart body after caching ",
|
||||
body_view.length(), " bytes."));
|
||||
ctx.mode_ = MODE_BYPASS;
|
||||
return FilterDataStatus::Continue;
|
||||
}
|
||||
@@ -324,8 +369,7 @@ FilterDataStatus PluginContext::onRequestBody(size_t body_size,
|
||||
auto* rootCtx = rootContext();
|
||||
body_total_size_ += body_size;
|
||||
switch (mode_) {
|
||||
case MODE_JSON:
|
||||
{
|
||||
case MODE_JSON: {
|
||||
if (!end_stream) {
|
||||
return FilterDataStatus::StopIterationAndBuffer;
|
||||
}
|
||||
@@ -333,8 +377,7 @@ FilterDataStatus PluginContext::onRequestBody(size_t body_size,
|
||||
getBufferBytes(WasmBufferType::HttpRequestBody, 0, body_total_size_);
|
||||
return rootCtx->onJsonBody(*config_, body->view());
|
||||
}
|
||||
case MODE_MULTIPART:
|
||||
{
|
||||
case MODE_MULTIPART: {
|
||||
auto body =
|
||||
getBufferBytes(WasmBufferType::HttpRequestBody, 0, body_total_size_);
|
||||
return rootCtx->onMultipartBody(*this, *config_, body, end_stream);
|
||||
|
||||
@@ -48,9 +48,10 @@ struct ModelRouterConfigRule {
|
||||
std::string add_provider_header_;
|
||||
std::string model_to_header_;
|
||||
std::vector<std::string> enable_on_path_suffix_ = {
|
||||
"/completions", "/embeddings", "/images/generations",
|
||||
"/audio/speech", "/fine_tuning/jobs", "/moderations",
|
||||
"/image-synthesis", "/video-synthesis"};
|
||||
"/completions", "/embeddings", "/images/generations",
|
||||
"/audio/speech", "/fine_tuning/jobs", "/moderations",
|
||||
"/image-synthesis", "/video-synthesis", "/rerank",
|
||||
"/messages"};
|
||||
};
|
||||
|
||||
class PluginContext;
|
||||
@@ -65,13 +66,20 @@ class PluginRootContext : public RootContext,
|
||||
: RootContext(id, root_id) {}
|
||||
~PluginRootContext() {}
|
||||
bool onConfigure(size_t) override;
|
||||
FilterHeadersStatus onHeader(PluginContext& ctx, const ModelRouterConfigRule&);
|
||||
FilterHeadersStatus onHeader(PluginContext& ctx,
|
||||
const ModelRouterConfigRule&);
|
||||
FilterDataStatus onJsonBody(const ModelRouterConfigRule&, std::string_view);
|
||||
FilterDataStatus onMultipartBody(PluginContext& ctx, const ModelRouterConfigRule& rule, WasmDataPtr& body, bool end_stream);
|
||||
FilterDataStatus onMultipartBody(PluginContext& ctx,
|
||||
const ModelRouterConfigRule& rule,
|
||||
WasmDataPtr& body, bool end_stream);
|
||||
bool configure(size_t);
|
||||
void incrementRequestCount();
|
||||
|
||||
private:
|
||||
bool parsePluginConfig(const json&, ModelRouterConfigRule&) override;
|
||||
uint64_t request_count_ = 0;
|
||||
static constexpr uint64_t REBUILD_THRESHOLD = 1000;
|
||||
static constexpr size_t MEMORY_THRESHOLD_BYTES = 200 * 1024 * 1024;
|
||||
};
|
||||
|
||||
// Per-stream context.
|
||||
@@ -98,4 +106,4 @@ class PluginContext : public Context {
|
||||
} // namespace null_plugin
|
||||
} // namespace proxy_wasm
|
||||
|
||||
#endif
|
||||
#endif
|
||||
@@ -157,7 +157,8 @@ TEST_F(ModelRouterTest, RewriteModelAndHeader) {
|
||||
body_.set(request_json);
|
||||
EXPECT_EQ(context_->onRequestHeaders(0, false),
|
||||
FilterHeadersStatus::StopIteration);
|
||||
EXPECT_EQ(context_->onRequestBody(request_json.length(), true), FilterDataStatus::Continue);
|
||||
EXPECT_EQ(context_->onRequestBody(request_json.length(), true),
|
||||
FilterDataStatus::Continue);
|
||||
}
|
||||
|
||||
TEST_F(ModelRouterTest, ModelToHeader) {
|
||||
@@ -183,7 +184,8 @@ TEST_F(ModelRouterTest, ModelToHeader) {
|
||||
body_.set(request_json);
|
||||
EXPECT_EQ(context_->onRequestHeaders(0, false),
|
||||
FilterHeadersStatus::StopIteration);
|
||||
EXPECT_EQ(context_->onRequestBody(request_json.length(), true), FilterDataStatus::Continue);
|
||||
EXPECT_EQ(context_->onRequestBody(request_json.length(), true),
|
||||
FilterDataStatus::Continue);
|
||||
}
|
||||
|
||||
TEST_F(ModelRouterTest, IgnorePath) {
|
||||
@@ -210,7 +212,8 @@ TEST_F(ModelRouterTest, IgnorePath) {
|
||||
body_.set(request_json);
|
||||
EXPECT_EQ(context_->onRequestHeaders(0, false),
|
||||
FilterHeadersStatus::Continue);
|
||||
EXPECT_EQ(context_->onRequestBody(request_json.length(), true), FilterDataStatus::Continue);
|
||||
EXPECT_EQ(context_->onRequestBody(request_json.length(), true),
|
||||
FilterDataStatus::Continue);
|
||||
}
|
||||
|
||||
TEST_F(ModelRouterTest, RouteLevelRewriteModelAndHeader) {
|
||||
@@ -244,10 +247,10 @@ TEST_F(ModelRouterTest, RouteLevelRewriteModelAndHeader) {
|
||||
route_name_ = "route-a";
|
||||
EXPECT_EQ(context_->onRequestHeaders(0, false),
|
||||
FilterHeadersStatus::StopIteration);
|
||||
EXPECT_EQ(context_->onRequestBody(request_json.length(), true), FilterDataStatus::Continue);
|
||||
EXPECT_EQ(context_->onRequestBody(request_json.length(), true),
|
||||
FilterDataStatus::Continue);
|
||||
}
|
||||
|
||||
|
||||
TEST_F(ModelRouterTest, RewriteModelAndHeaderMultipartFormData) {
|
||||
std::string configuration = R"({
|
||||
"addProviderHeader": "x-higress-llm-provider"
|
||||
@@ -257,8 +260,11 @@ TEST_F(ModelRouterTest, RewriteModelAndHeaderMultipartFormData) {
|
||||
EXPECT_TRUE(root_context_->configure(configuration.size()));
|
||||
|
||||
path_ = "/v1/chat/completions";
|
||||
content_type_ = "multipart/form-data; boundary=--------------------------100751621174704322650451";
|
||||
std::string request_data = std::regex_replace(R"(
|
||||
content_type_ =
|
||||
"multipart/form-data; "
|
||||
"boundary=--------------------------100751621174704322650451";
|
||||
std::string request_data = std::regex_replace(
|
||||
R"(
|
||||
----------------------------100751621174704322650451
|
||||
Content-Disposition: form-data; name="purpose"
|
||||
|
||||
@@ -274,16 +280,21 @@ Content-Type: application/json
|
||||
[
|
||||
]
|
||||
----------------------------100751621174704322650451--
|
||||
)", std::regex("\n"), "\r\n"); // Multipart data requires CRLF line endings
|
||||
)",
|
||||
std::regex("\n"), "\r\n"); // Multipart data requires CRLF line endings
|
||||
EXPECT_CALL(*mock_context_,
|
||||
setBuffer(testing::_, testing::_, testing::_, testing::_))
|
||||
.WillOnce([&](WasmBufferType, size_t start, size_t length, std::string_view body) {
|
||||
std::cerr << "===============" << "\n";
|
||||
.WillOnce([&](WasmBufferType, size_t start, size_t length,
|
||||
std::string_view body) {
|
||||
std::cerr << "==============="
|
||||
<< "\n";
|
||||
std::cerr << body << "\n";
|
||||
std::cerr << "===============" << "\n";
|
||||
std::cerr << "==============="
|
||||
<< "\n";
|
||||
EXPECT_EQ(start, 0);
|
||||
EXPECT_EQ(length, std::numeric_limits<size_t>::max());
|
||||
auto expected_body= std::regex_replace(R"(
|
||||
auto expected_body = std::regex_replace(
|
||||
R"(
|
||||
----------------------------100751621174704322650451
|
||||
Content-Disposition: form-data; name="purpose"
|
||||
|
||||
@@ -292,7 +303,9 @@ batch
|
||||
Content-Disposition: form-data; name="model"
|
||||
|
||||
qwen-turbo
|
||||
)", std::regex("\n"), "\r\n"); // Multipart data requires CRLF line endings
|
||||
)",
|
||||
std::regex("\n"),
|
||||
"\r\n"); // Multipart data requires CRLF line endings
|
||||
EXPECT_EQ(body, expected_body);
|
||||
return WasmResult::Ok;
|
||||
});
|
||||
@@ -308,42 +321,54 @@ qwen-turbo
|
||||
|
||||
auto last_body_size = 0;
|
||||
|
||||
auto body = request_data.substr(0, request_data.find("batch") + 5 + 2 /* batch + CRLF */);
|
||||
auto body = request_data.substr(
|
||||
0, request_data.find("batch") + 5 + 2 /* batch + CRLF */);
|
||||
body_.set(body);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||
FilterDataStatus::StopIterationAndBuffer);
|
||||
last_body_size = body.size();
|
||||
|
||||
body = request_data.substr(0, request_data.find("\"model\"") + 5 + 2 + 2 /* "model" + CRLF + CRLF */);
|
||||
body = request_data.substr(0, request_data.find("\"model\"") + 5 + 2 +
|
||||
2 /* "model" + CRLF + CRLF */);
|
||||
body_.set(body);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||
FilterDataStatus::StopIterationAndBuffer);
|
||||
last_body_size = body.size();
|
||||
|
||||
body = request_data.substr(0, request_data.find("qwen") + 4 /* "qwen" */);
|
||||
body_.set(body);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||
FilterDataStatus::StopIterationAndBuffer);
|
||||
last_body_size = body.size();
|
||||
|
||||
body = request_data.substr(0, request_data.find("qwen-turbo") + 10 /* "qwen-turbo" */);
|
||||
body = request_data.substr(
|
||||
0, request_data.find("qwen-turbo") + 10 /* "qwen-turbo" */);
|
||||
body_.set(body);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||
FilterDataStatus::StopIterationAndBuffer);
|
||||
last_body_size = body.size();
|
||||
|
||||
body = request_data.substr(0, request_data.find("qwen-turbo") + 10 + 2 /* "qwen-turbo" + CRLF */);
|
||||
body = request_data.substr(
|
||||
0, request_data.find("qwen-turbo") + 10 + 2 /* "qwen-turbo" + CRLF */);
|
||||
body_.set(body);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::Continue);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||
FilterDataStatus::Continue);
|
||||
last_body_size = body.size();
|
||||
|
||||
body = request_data.substr(0, request_data.find("qwen-turbo") + 10 + 2 + 50 /* "qwen-turbo" + CRLF + boundary */);
|
||||
body = request_data.substr(0, request_data.find("qwen-turbo") + 10 + 2 +
|
||||
50 /* "qwen-turbo" + CRLF + boundary */);
|
||||
body_.set(body);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::Continue);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||
FilterDataStatus::Continue);
|
||||
last_body_size = body.size();
|
||||
|
||||
body_.set(request_data);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, true), FilterDataStatus::Continue);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, true),
|
||||
FilterDataStatus::Continue);
|
||||
}
|
||||
|
||||
TEST_F(ModelRouterTest, ModelToHeaderMultipartFormData) {
|
||||
std::string configuration = R"(
|
||||
std::string configuration = R"(
|
||||
{
|
||||
"modelToHeader": "x-higress-llm-model"
|
||||
})";
|
||||
@@ -352,8 +377,11 @@ TEST_F(ModelRouterTest, ModelToHeaderMultipartFormData) {
|
||||
EXPECT_TRUE(root_context_->configure(configuration.size()));
|
||||
|
||||
path_ = "/v1/chat/completions";
|
||||
content_type_ = "multipart/form-data; boundary=--------------------------100751621174704322650451";
|
||||
std::string request_data = std::regex_replace(R"(
|
||||
content_type_ =
|
||||
"multipart/form-data; "
|
||||
"boundary=--------------------------100751621174704322650451";
|
||||
std::string request_data = std::regex_replace(
|
||||
R"(
|
||||
----------------------------100751621174704322650451
|
||||
Content-Disposition: form-data; name="purpose"
|
||||
|
||||
@@ -369,7 +397,8 @@ Content-Type: application/json
|
||||
[
|
||||
]
|
||||
----------------------------100751621174704322650451--
|
||||
)", std::regex("\n"), "\r\n"); // Multipart data requires CRLF line endings
|
||||
)",
|
||||
std::regex("\n"), "\r\n"); // Multipart data requires CRLF line endings
|
||||
EXPECT_CALL(*mock_context_,
|
||||
setBuffer(testing::_, testing::_, testing::_, testing::_))
|
||||
.Times(0);
|
||||
@@ -384,38 +413,50 @@ Content-Type: application/json
|
||||
|
||||
auto last_body_size = 0;
|
||||
|
||||
auto body = request_data.substr(0, request_data.find("batch") + 5 + 2 /* batch + CRLF */);
|
||||
auto body = request_data.substr(
|
||||
0, request_data.find("batch") + 5 + 2 /* batch + CRLF */);
|
||||
body_.set(body);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||
FilterDataStatus::StopIterationAndBuffer);
|
||||
last_body_size = body.size();
|
||||
|
||||
body = request_data.substr(0, request_data.find("\"model\"") + 5 + 2 + 2 /* "model" + CRLF + CRLF */);
|
||||
body = request_data.substr(0, request_data.find("\"model\"") + 5 + 2 +
|
||||
2 /* "model" + CRLF + CRLF */);
|
||||
body_.set(body);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||
FilterDataStatus::StopIterationAndBuffer);
|
||||
last_body_size = body.size();
|
||||
|
||||
body = request_data.substr(0, request_data.find("qwen") + 4 /* "qwen" */);
|
||||
body_.set(body);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||
FilterDataStatus::StopIterationAndBuffer);
|
||||
last_body_size = body.size();
|
||||
|
||||
body = request_data.substr(0, request_data.find("qwen-max") + 8 /* "qwen-max" */);
|
||||
body = request_data.substr(
|
||||
0, request_data.find("qwen-max") + 8 /* "qwen-max" */);
|
||||
body_.set(body);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::StopIterationAndBuffer);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||
FilterDataStatus::StopIterationAndBuffer);
|
||||
last_body_size = body.size();
|
||||
|
||||
body = request_data.substr(0, request_data.find("qwen-max") + 8 + 2 /* "qwen-max" + CRLF */);
|
||||
body = request_data.substr(
|
||||
0, request_data.find("qwen-max") + 8 + 2 /* "qwen-max" + CRLF */);
|
||||
body_.set(body);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::Continue);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||
FilterDataStatus::Continue);
|
||||
last_body_size = body.size();
|
||||
|
||||
body = request_data.substr(0, request_data.find("qwen-max") + 8 + 2 + 50 /* "qwen-max" + CRLF */);
|
||||
body = request_data.substr(
|
||||
0, request_data.find("qwen-max") + 8 + 2 + 50 /* "qwen-max" + CRLF */);
|
||||
body_.set(body);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false), FilterDataStatus::Continue);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, false),
|
||||
FilterDataStatus::Continue);
|
||||
last_body_size = body.size();
|
||||
|
||||
body_.set(request_data);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, true), FilterDataStatus::Continue);
|
||||
EXPECT_EQ(context_->onRequestBody(body.size() - last_body_size, true),
|
||||
FilterDataStatus::Continue);
|
||||
}
|
||||
|
||||
} // namespace model_router
|
||||
|
||||
@@ -7,21 +7,21 @@ go 1.24.1
|
||||
toolchain go1.24.4
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
|
||||
github.com/higress-group/wasm-go v1.0.6-0.20251103065747-41d65dbb2f9e
|
||||
github.com/higress-group/wasm-go v1.0.3-0.20251011083635-792cb1547bac
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/resp v0.1.1
|
||||
// github.com/weaviate/weaviate-go-client/v4 v4.15.1
|
||||
)
|
||||
|
||||
require github.com/tetratelabs/wazero v1.7.2 // indirect
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.7.2 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/resp v0.1.1
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -4,10 +4,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
|
||||
github.com/higress-group/wasm-go v1.0.6-0.20251103065747-41d65dbb2f9e h1:wYW/DXjyQniQLaB26c+J9NQk3+AhqByzS1r18NShvB4=
|
||||
github.com/higress-group/wasm-go v1.0.6-0.20251103065747-41d65dbb2f9e/go.mod h1:B8C6+OlpnyYyZUBEdUXA7tYZYD+uwZTNjfkE5FywA+A=
|
||||
github.com/higress-group/wasm-go v1.0.3-0.20251011083635-792cb1547bac h1:tdJzS56Xa6BSHAi9P2omvb98bpI8qFGg6jnCPtPmDgA=
|
||||
github.com/higress-group/wasm-go v1.0.3-0.20251011083635-792cb1547bac/go.mod h1:B8C6+OlpnyYyZUBEdUXA7tYZYD+uwZTNjfkE5FywA+A=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
|
||||
@@ -38,6 +38,7 @@ func init() {
|
||||
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
||||
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
|
||||
wrapper.ProcessStreamingResponseBodyBy(onHttpResponseBody),
|
||||
wrapper.WithRebuildAfterRequests[config.PluginConfig](1000),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -15,13 +15,18 @@ description: 针对LLM服务的负载均衡策略
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `lb_type` | string | 选填 | endpoint | 负载均衡类型,可选`endpoint`,`cluster` |
|
||||
| `lb_policy` | string | 必填 | | 负载均衡策略类型 |
|
||||
| `lb_config` | object | 必填 | | 当前负载均衡策略类型的配置 |
|
||||
|
||||
目前支持的负载均衡策略包括:
|
||||
`lb_type` 为 `endpoint` 时支持的负载均衡策略包括:
|
||||
|
||||
- `global_least_request`: 基于redis实现的全局最小请求数负载均衡
|
||||
- `prefix_cache`: 基于 prompt 前缀匹配选择后端节点,如果通过前缀匹配无法匹配到节点,则通过全局最小请求数进行服务节点的选择
|
||||
- `least_busy`: [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md) 的 wasm 实现
|
||||
- `endpoint_metrics`: 基于 llm 服务暴露的 metrics 进行负载均衡
|
||||
|
||||
`lb_type` 为 `cluster` 时支持的负载均衡策略包括:
|
||||
- `cluster_metrics`: 基于网关统计的不同service的指标进行服务之间的负载均衡
|
||||
|
||||
# 全局最小请求数
|
||||
## 功能说明
|
||||
@@ -59,6 +64,7 @@ sequenceDiagram
|
||||
## 配置示例
|
||||
|
||||
```yaml
|
||||
lb_type: endpoint
|
||||
lb_policy: global_least_request
|
||||
lb_config:
|
||||
serviceFQDN: redis.static
|
||||
@@ -116,11 +122,12 @@ lb_config:
|
||||
| `password` | string | 选填 | 空 | redis 密码 |
|
||||
| `timeout` | int | 选填 | 3000ms | redis 请求超时时间 |
|
||||
| `database` | int | 选填 | 0 | redis 数据库序号 |
|
||||
| `redisKeyTTL` | int | 选填 | 1800ms | prompt 前缀对应的key的ttl |
|
||||
| `redisKeyTTL` | int | 选填 | 1800s | prompt 前缀对应的key的ttl |
|
||||
|
||||
## 配置示例
|
||||
|
||||
```yaml
|
||||
lb_type: endpoint
|
||||
lb_policy: prefix_cache
|
||||
lb_config:
|
||||
serviceFQDN: redis.static
|
||||
@@ -161,14 +168,73 @@ sequenceDiagram
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `criticalModels` | []string | 选填 | | critical的模型列表 |
|
||||
| `metric_policy` | string | 必填 | | 如何使用llm暴露的metrics做负载均衡,当前支持`[default, least, most]` |
|
||||
| `target_metric` | string | 选填 | | 要使用的metric名称,`metric_policy` 取值为 `least` 或者 `most` 时生效 |
|
||||
| `rate_limit` | string | 选填 | 1 | 单个节点处理请求比例上限,取值范围0~1 |
|
||||
|
||||
|
||||
## 配置示例
|
||||
|
||||
使用 [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md) 中的算法
|
||||
|
||||
```yaml
|
||||
lb_type: endpoint
|
||||
lb_policy: metrics_based
|
||||
lb_config:
|
||||
metric_policy: default
|
||||
rate_limit: 0.6 # 单个节点承载的最大请求比例
|
||||
```
|
||||
|
||||
根据当前排队请求数进行负载均衡
|
||||
|
||||
```yaml
|
||||
lb_type: endpoint
|
||||
lb_policy: metrics_based
|
||||
lb_config:
|
||||
metric_policy: least
|
||||
target_metric: vllm:num_requests_waiting
|
||||
rate_limit: 0.6 # 单个节点承载的最大请求比例
|
||||
```
|
||||
|
||||
根据当前GPU中正在处理的请求数进行负载均衡
|
||||
|
||||
```yaml
|
||||
lb_type: endpoint
|
||||
lb_policy: metrics_based
|
||||
lb_config:
|
||||
metric_policy: least
|
||||
target_metric: vllm:num_requests_running
|
||||
rate_limit: 0.6 # 单个节点承载的最大请求比例
|
||||
```
|
||||
|
||||
|
||||
# 跨服务负载均衡
|
||||
|
||||
## 配置说明
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `mode` | string | 必填 | | 如何使用服务级指标做负载均衡,当前支持`[LeastBusy, LeastTotalLatency, LeastFirstTokenLatency ]` |
|
||||
| `service_list` | []string | 必填 | | 路由后端服务列表 |
|
||||
| `rate_limit` | string | 选填 | 1 | 单个服务处理请求比例上限,取值范围0~1 |
|
||||
| `cluster_header` | string | 选填 | `x-envoy-target-cluster` | 通过取该header的值得知需要路由到哪个后端服务 |
|
||||
| `queue_size` | int | 选填 | 100 | 根据最近的多少个请求进行观测指标的计算 |
|
||||
|
||||
`mode` 各取值含义如下:
|
||||
- `LeastBusy`: 路由到当前并发请求数最少的服务
|
||||
- `LeastTotalLatency`: 路由到当前RT最低的服务
|
||||
- `LeastFirstTokenLatency`: 路由到当前首包RT最低的服务
|
||||
|
||||
## 配置示例
|
||||
|
||||
```yaml
|
||||
lb_policy: least_busy
|
||||
lb_type: cluster
|
||||
lb_policy: cluster_metrics
|
||||
lb_config:
|
||||
criticalModels:
|
||||
- meta-llama/Llama-2-7b-hf
|
||||
- sql-lora
|
||||
mode: LeastTotalLatency # 策略名称
|
||||
queue_size: 100 # 统计指标时使用的最近请求数
|
||||
rate_limit: 0.6 # 单个服务承载的最大请求比例
|
||||
service_list:
|
||||
- outbound|80||test-1.dns
|
||||
- outbound|80||test-2.static
|
||||
```
|
||||
@@ -15,14 +15,19 @@ The configuration is:
|
||||
|
||||
| Name | Type | Required | default | description |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `lb_policy` | string | required | | load balance type |
|
||||
| `lb_type` | string | optional | endpoint | load balance policy type, `endpoint` or `cluster` |
|
||||
| `lb_policy` | string | required | | load balance policy type |
|
||||
| `lb_config` | object | required | | configuration for the current load balance type |
|
||||
|
||||
Current supported load balance policies are:
|
||||
When `lb_type = endpoint`, current supported load balance policies are:
|
||||
|
||||
- `global_least_request`: global least request based on redis
|
||||
- `prefix_cache`: Select the backend node based on the prompt prefix match. If the node cannot be matched by prefix matching, the service node is selected based on the global minimum number of requests.
|
||||
- `least_busy`: implementation for [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md)
|
||||
- `endpoint_metrics`: Load balancing based on metrics exposed by the llm service
|
||||
|
||||
When `lb_type = cluster`, current supported load balance policies are:
|
||||
- `cluster_metrics`: Load balancing based on metrics of clusters
|
||||
|
||||
|
||||
# Global Least Request
|
||||
## Introduction
|
||||
@@ -60,6 +65,7 @@ sequenceDiagram
|
||||
## Configuration Example
|
||||
|
||||
```yaml
|
||||
lb_type: endpoint
|
||||
lb_policy: global_least_request
|
||||
lb_config:
|
||||
serviceFQDN: redis.static
|
||||
@@ -118,11 +124,12 @@ Then subsequent requests with the same prefix will also be routed to pod 1:
|
||||
| `password` | string | optional | `` | redis password |
|
||||
| `timeout` | int | optional | 3000ms | redis request timeout |
|
||||
| `database` | int | optional | 0 | redis database number |
|
||||
| `redisKeyTTL` | int | optional | 1800ms | prompt prefix key's ttl |
|
||||
| `redisKeyTTL` | int | optional | 1800s | prompt prefix key's ttl |
|
||||
|
||||
## Configuration Example
|
||||
|
||||
```yaml
|
||||
lb_type: endpoint
|
||||
lb_policy: prefix_cache
|
||||
lb_config:
|
||||
serviceFQDN: redis.static
|
||||
@@ -164,14 +171,71 @@ sequenceDiagram
|
||||
|
||||
| Name | Type | Required | default | description |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `criticalModels` | []string | required | | critical model names |
|
||||
| `metric_policy` | string | required | | How to use the metrics exposed by LLM for load balancing, currently supporting `[default, least, most]` |
|
||||
| `target_metric` | string | optional | | The metric name to use. This is valid only when `metric_policy` is `least` or `most` |
|
||||
| `rate_limit` | string | optional | 1 | The maximum percentage of requests a single node can receive, 0~1 |
|
||||
|
||||
## Configuration Example
|
||||
|
||||
Use the algorithm of [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md):
|
||||
|
||||
```yaml
|
||||
lb_type: endpoint
|
||||
lb_policy: metrics_based
|
||||
lb_config:
|
||||
metric_policy: default
|
||||
rate_limit: 0.6
|
||||
```
|
||||
|
||||
Load balancing based on the current number of queued requests:
|
||||
|
||||
```yaml
|
||||
lb_type: endpoint
|
||||
lb_policy: metrics_based
|
||||
lb_config:
|
||||
metric_policy: least
|
||||
target_metric: vllm:num_requests_waiting
|
||||
rate_limit: 0.6
|
||||
```
|
||||
|
||||
Load balancing based on the number of requests currently being processed by the GPU:
|
||||
|
||||
```yaml
|
||||
lb_type: endpoint
|
||||
lb_policy: metrics_based
|
||||
lb_config:
|
||||
metric_policy: least
|
||||
target_metric: vllm:num_requests_running
|
||||
rate_limit: 0.6
|
||||
```
|
||||
|
||||
# Cross-service load balancing
|
||||
|
||||
## Configuration
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `mode` | string | required | | how to use cluster metrics, value of `[LeastBusy, LeastTotalLatency, LeastFirstTokenLatency ]` |
|
||||
| `service_list` | []string | required | | service list of current route |
|
||||
| `rate_limit` | string | optional | 1 | The maximum percentage of requests a single node can receive, value of 0~1 |
|
||||
| `cluster_header` | string | optional | `x-envoy-target-cluster` | By retrieving the value of this header, we can determine which backend service to route to |
|
||||
| `queue_size` | int | optional | 100 | The metrics is calculated based on the number of most recent requests. |
|
||||
|
||||
The meanings of the values for `mode` are as follows:
|
||||
|
||||
- `LeastBusy`: Routes to the service with the fewest concurrent requests.
|
||||
- `LeastTotalLatency`: Routes to the service with the lowest response time (RT).
|
||||
- `LeastFirstTokenLatency`: Routes to the service with the lowest RT for the first packet.
|
||||
|
||||
## Configuration Example
|
||||
|
||||
```yaml
|
||||
lb_policy: least_busy
|
||||
lb_type: cluster
|
||||
lb_policy: cluster_metrics
|
||||
lb_config:
|
||||
criticalModels:
|
||||
- meta-llama/Llama-2-7b-hf
|
||||
- sql-lora
|
||||
```
|
||||
mode: LeastTotalLatency
|
||||
rate_limit: 0.6
|
||||
service_list:
|
||||
- outbound|80||test-1.dns
|
||||
- outbound|80||test-2.static
|
||||
```
|
||||
@@ -0,0 +1,218 @@
|
||||
package cluster_metrics
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/utils"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultQueueSize = 100
|
||||
DefaultClusterHeader = "x-higress-target-cluster"
|
||||
)
|
||||
|
||||
type ClusterEndpointLoadBalancer struct {
|
||||
// Configurations
|
||||
Mode string
|
||||
ClusterHeader string
|
||||
ServiceList []string
|
||||
RateLimit float64
|
||||
// Statistic
|
||||
ServiceRequestOngoing map[string]int
|
||||
ServiceRequestCount map[string]int
|
||||
FirstTokenLatencyRequests map[string]*utils.FixedQueue[float64]
|
||||
TotalLatencyRequests map[string]*utils.FixedQueue[float64]
|
||||
}
|
||||
|
||||
func NewClusterEndpointLoadBalancer(json gjson.Result) (ClusterEndpointLoadBalancer, error) {
|
||||
lb := ClusterEndpointLoadBalancer{}
|
||||
lb.ServiceRequestOngoing = make(map[string]int)
|
||||
lb.ServiceRequestCount = make(map[string]int)
|
||||
lb.FirstTokenLatencyRequests = make(map[string]*utils.FixedQueue[float64])
|
||||
lb.TotalLatencyRequests = make(map[string]*utils.FixedQueue[float64])
|
||||
|
||||
lb.Mode = json.Get("mode").String()
|
||||
lb.ClusterHeader = json.Get("cluster_header").String()
|
||||
if lb.ClusterHeader == "" {
|
||||
lb.ClusterHeader = DefaultClusterHeader
|
||||
}
|
||||
if json.Get("rate_limit").Exists() {
|
||||
lb.RateLimit = json.Get("rate_limit").Float()
|
||||
} else {
|
||||
lb.RateLimit = 1.0
|
||||
}
|
||||
queueSize := int(json.Get("queue_size").Int())
|
||||
if queueSize == 0 {
|
||||
queueSize = DefaultQueueSize
|
||||
}
|
||||
|
||||
for _, svc := range json.Get("service_list").Array() {
|
||||
serviceName := svc.String()
|
||||
lb.ServiceList = append(lb.ServiceList, serviceName)
|
||||
lb.ServiceRequestOngoing[serviceName] = 0
|
||||
lb.ServiceRequestCount[serviceName] = 0
|
||||
lb.FirstTokenLatencyRequests[serviceName] = utils.NewFixedQueue[float64](queueSize)
|
||||
lb.TotalLatencyRequests[serviceName] = utils.NewFixedQueue[float64](queueSize)
|
||||
}
|
||||
return lb, nil
|
||||
}
|
||||
|
||||
func (lb ClusterEndpointLoadBalancer) getRequestRate(serviceName string) float64 {
|
||||
totalRequestCount := 0
|
||||
for _, v := range lb.ServiceRequestCount {
|
||||
totalRequestCount += v
|
||||
}
|
||||
if totalRequestCount != 0 {
|
||||
return float64(lb.ServiceRequestCount[serviceName]) / float64(totalRequestCount)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (lb ClusterEndpointLoadBalancer) getServiceTTFT(serviceName string) float64 {
|
||||
queue, ok := lb.FirstTokenLatencyRequests[serviceName]
|
||||
if !ok || queue.Size() == 0 {
|
||||
return 0
|
||||
}
|
||||
value := 0.0
|
||||
queue.ForEach(func(i int, item float64) {
|
||||
value += float64(item)
|
||||
})
|
||||
return value / float64(queue.Size())
|
||||
}
|
||||
|
||||
func (lb ClusterEndpointLoadBalancer) getServiceTotalRT(serviceName string) float64 {
|
||||
queue, ok := lb.TotalLatencyRequests[serviceName]
|
||||
if !ok || queue.Size() == 0 {
|
||||
return 0
|
||||
}
|
||||
value := 0.0
|
||||
queue.ForEach(func(i int, item float64) {
|
||||
value += float64(item)
|
||||
})
|
||||
return value / float64(queue.Size())
|
||||
}
|
||||
|
||||
// Callbacks which are called in request path
|
||||
func (lb ClusterEndpointLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
ctx.SetContext("request_start", time.Now().UnixMilli())
|
||||
candidate := lb.ServiceList[rand.Int()%len(lb.ServiceList)]
|
||||
var debugInfo string
|
||||
switch lb.Mode {
|
||||
case "LeastBusy":
|
||||
for svc, ongoingNum := range lb.ServiceRequestOngoing {
|
||||
if candidate == svc {
|
||||
continue
|
||||
}
|
||||
if lb.getRequestRate(candidate) >= lb.RateLimit {
|
||||
candidate = svc
|
||||
} else if ongoingNum < lb.ServiceRequestOngoing[candidate] && lb.getRequestRate(svc) < lb.RateLimit {
|
||||
candidate = svc
|
||||
}
|
||||
}
|
||||
for svc := range lb.ServiceRequestOngoing {
|
||||
debugInfo += fmt.Sprintf("[service: %s] {ongoing request: %d, total request: %d, request rate: %.2f}, ",
|
||||
svc, lb.ServiceRequestOngoing[svc], lb.ServiceRequestCount[svc], lb.getRequestRate(svc))
|
||||
}
|
||||
case "LeastFirstTokenLatency":
|
||||
candidateTTFT := lb.getServiceTTFT(candidate)
|
||||
for _, svc := range lb.ServiceList {
|
||||
if candidate == svc {
|
||||
continue
|
||||
}
|
||||
if lb.getRequestRate(candidate) >= lb.RateLimit {
|
||||
candidate = svc
|
||||
candidateTTFT = lb.getServiceTTFT(svc)
|
||||
} else if lb.getServiceTTFT(svc) < candidateTTFT && lb.getRequestRate(svc) < lb.RateLimit {
|
||||
candidate = svc
|
||||
candidateTTFT = lb.getServiceTTFT(svc)
|
||||
}
|
||||
}
|
||||
for _, svc := range lb.ServiceList {
|
||||
debugInfo += fmt.Sprintf("[service: %s] {average ttft: %.2f, total request: %d, request rate: %.2f}, ",
|
||||
svc, lb.getServiceTTFT(svc), lb.ServiceRequestCount[svc], lb.getRequestRate(svc))
|
||||
}
|
||||
case "LeastTotalLatency":
|
||||
candidateTotalRT := lb.getServiceTotalRT(candidate)
|
||||
for _, svc := range lb.ServiceList {
|
||||
if candidate == svc {
|
||||
continue
|
||||
}
|
||||
if lb.getRequestRate(candidate) >= lb.RateLimit {
|
||||
candidate = svc
|
||||
candidateTotalRT = lb.getServiceTotalRT(svc)
|
||||
} else if lb.getServiceTotalRT(svc) < candidateTotalRT && lb.getRequestRate(svc) < lb.RateLimit {
|
||||
candidate = svc
|
||||
candidateTotalRT = lb.getServiceTotalRT(svc)
|
||||
}
|
||||
}
|
||||
for _, svc := range lb.ServiceList {
|
||||
debugInfo += fmt.Sprintf("[service: %s] {average latency: %.2f, total request: %d, request rate: %.2f}, ",
|
||||
svc, lb.getServiceTotalRT(svc), lb.ServiceRequestCount[svc], lb.getRequestRate(svc))
|
||||
}
|
||||
}
|
||||
debugInfo += fmt.Sprintf("final service: %s", candidate)
|
||||
log.Debug(debugInfo)
|
||||
proxywasm.ReplaceHttpRequestHeader(lb.ClusterHeader, candidate)
|
||||
ctx.SetContext(lb.ClusterHeader, candidate)
|
||||
lb.ServiceRequestOngoing[candidate] += 1
|
||||
lb.ServiceRequestCount[candidate] += 1
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb ClusterEndpointLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb ClusterEndpointLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
statusCode, _ := proxywasm.GetHttpResponseHeader(":status")
|
||||
ctx.SetContext("statusCode", statusCode)
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb ClusterEndpointLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
|
||||
if ctx.GetContext("ttft_recorded") == nil {
|
||||
candidate := ctx.GetContext(lb.ClusterHeader).(string)
|
||||
duration := float64(time.Now().UnixMilli() - ctx.GetContext("request_start").(int64))
|
||||
// punish failed request
|
||||
if ctx.GetContext("statusCode").(string) != "200" {
|
||||
for _, svc := range lb.ServiceList {
|
||||
ttft := lb.getServiceTTFT(svc)
|
||||
if duration < ttft {
|
||||
duration = ttft
|
||||
}
|
||||
}
|
||||
duration *= 2
|
||||
}
|
||||
lb.FirstTokenLatencyRequests[candidate].Enqueue(duration)
|
||||
ctx.SetContext("ttft_recorded", struct{}{})
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func (lb ClusterEndpointLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb ClusterEndpointLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpContext) {
|
||||
candidate := ctx.GetContext(lb.ClusterHeader).(string)
|
||||
lb.ServiceRequestOngoing[candidate] -= 1
|
||||
duration := float64(time.Now().UnixMilli() - ctx.GetContext("request_start").(int64))
|
||||
// punish failed request
|
||||
if ctx.GetContext("statusCode").(string) != "200" {
|
||||
for _, svc := range lb.ServiceList {
|
||||
rt := lb.getServiceTotalRT(svc)
|
||||
if duration < rt {
|
||||
duration = rt
|
||||
}
|
||||
}
|
||||
duration *= 2
|
||||
}
|
||||
lb.TotalLatencyRequests[candidate].Enqueue(duration)
|
||||
}
|
||||
@@ -40,13 +40,19 @@ type Metrics struct {
|
||||
KvCacheMaxTokenCapacity int
|
||||
}
|
||||
|
||||
type UserSelectedMetric struct {
|
||||
MetricName string
|
||||
MetricValue float64
|
||||
}
|
||||
|
||||
type PodMetrics struct {
|
||||
Pod
|
||||
Metrics
|
||||
UserSelectedMetric
|
||||
}
|
||||
|
||||
func (pm *PodMetrics) String() string {
|
||||
return fmt.Sprintf("Pod: %+v; Metrics: %+v", pm.Pod, pm.Metrics)
|
||||
return fmt.Sprintf("Pod: %+v; Metrics: %+v, UserSelectedMetric: %+v", pm.Pod, pm.Metrics, pm.UserSelectedMetric)
|
||||
}
|
||||
|
||||
func (pm *PodMetrics) Clone() *PodMetrics {
|
||||
@@ -63,6 +69,10 @@ func (pm *PodMetrics) Clone() *PodMetrics {
|
||||
KVCacheUsagePercent: pm.KVCacheUsagePercent,
|
||||
KvCacheMaxTokenCapacity: pm.KvCacheMaxTokenCapacity,
|
||||
},
|
||||
UserSelectedMetric: UserSelectedMetric{
|
||||
MetricName: pm.MetricName,
|
||||
MetricValue: pm.MetricValue,
|
||||
},
|
||||
}
|
||||
return clone
|
||||
}
|
||||
@@ -23,7 +23,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/backend"
|
||||
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
"go.uber.org/multierr"
|
||||
@@ -53,6 +53,16 @@ func PromToPodMetrics(
|
||||
) (*backend.PodMetrics, error) {
|
||||
var errs error
|
||||
updated := existing.Clone()
|
||||
// User selected metric
|
||||
if updated.MetricName != "" {
|
||||
metricValue, err := getLatestMetric(metricFamilies, updated.MetricName)
|
||||
errs = multierr.Append(errs, err)
|
||||
if err == nil {
|
||||
updated.MetricValue = metricValue.GetGauge().GetValue()
|
||||
}
|
||||
return updated, errs
|
||||
}
|
||||
// Default metric
|
||||
runningQueueSize, err := getLatestMetric(metricFamilies, RunningQueueSizeMetricName)
|
||||
errs = multierr.Append(errs, err)
|
||||
if err == nil {
|
||||
@@ -0,0 +1,120 @@
|
||||
package endpoint_metrics
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/scheduling"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/utils"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
FixedQueueSize = 100
|
||||
)
|
||||
|
||||
type MetricsEndpointLoadBalancer struct {
|
||||
metricPolicy string
|
||||
targetMetric string
|
||||
endpointRequests *utils.FixedQueue[string]
|
||||
maxRate float64
|
||||
}
|
||||
|
||||
func NewMetricsEndpointLoadBalancer(json gjson.Result) (MetricsEndpointLoadBalancer, error) {
|
||||
lb := MetricsEndpointLoadBalancer{}
|
||||
if json.Get("metric_policy").Exists() {
|
||||
lb.metricPolicy = json.Get("metric_policy").String()
|
||||
} else {
|
||||
lb.metricPolicy = scheduling.MetricPolicyDefault
|
||||
}
|
||||
if json.Get("target_metric").Exists() {
|
||||
lb.targetMetric = json.Get("target_metric").String()
|
||||
}
|
||||
if json.Get("rate_limit").Exists() {
|
||||
lb.maxRate = json.Get("rate_limit").Float()
|
||||
} else {
|
||||
lb.maxRate = 1.0
|
||||
}
|
||||
lb.endpointRequests = utils.NewFixedQueue[string](FixedQueueSize)
|
||||
return lb, nil
|
||||
}
|
||||
|
||||
// Callbacks which are called in request path
|
||||
func (lb MetricsEndpointLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
// If return types.ActionContinue, SetUpstreamOverrideHost will not take effect
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func (lb MetricsEndpointLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
requestModel := gjson.GetBytes(body, "model")
|
||||
if !requestModel.Exists() {
|
||||
return types.ActionContinue
|
||||
}
|
||||
llmReq := &scheduling.LLMRequest{
|
||||
Model: requestModel.String(),
|
||||
Critical: true,
|
||||
}
|
||||
hostInfos, err := proxywasm.GetUpstreamHosts()
|
||||
if err != nil {
|
||||
return types.ActionContinue
|
||||
}
|
||||
hostMetrics := make(map[string]string)
|
||||
for _, hostInfo := range hostInfos {
|
||||
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
|
||||
hostMetrics[hostInfo[0]] = gjson.Get(hostInfo[1], "metrics").String()
|
||||
}
|
||||
}
|
||||
scheduler, err := scheduling.GetScheduler(hostMetrics, lb.metricPolicy, lb.targetMetric)
|
||||
if err != nil {
|
||||
log.Debugf("initial scheduler failed: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
targetPod, err := scheduler.Schedule(llmReq)
|
||||
log.Debugf("targetPod: %+v", targetPod.Address)
|
||||
if err != nil {
|
||||
log.Debugf("pod select failed: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
finalAddress := targetPod.Address
|
||||
otherHosts := []string{} // 如果当前host超过请求数限制,那么在其中随机挑选一个
|
||||
currentRate := 0.0
|
||||
for k := range hostMetrics {
|
||||
if k != finalAddress {
|
||||
otherHosts = append(otherHosts, k)
|
||||
}
|
||||
}
|
||||
if lb.endpointRequests.Size() != 0 {
|
||||
count := 0.0
|
||||
lb.endpointRequests.ForEach(func(i int, item string) {
|
||||
if item == finalAddress {
|
||||
count += 1
|
||||
}
|
||||
})
|
||||
currentRate = count / float64(lb.endpointRequests.Size())
|
||||
}
|
||||
if currentRate > lb.maxRate && len(otherHosts) > 0 {
|
||||
finalAddress = otherHosts[rand.Intn(len(otherHosts))]
|
||||
}
|
||||
lb.endpointRequests.Enqueue(finalAddress)
|
||||
log.Debugf("pod %s is selected", finalAddress)
|
||||
proxywasm.SetUpstreamOverrideHost([]byte(finalAddress))
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb MetricsEndpointLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb MetricsEndpointLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
func (lb MetricsEndpointLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb MetricsEndpointLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpContext) {}
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/backend"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
)
|
||||
@@ -20,15 +20,22 @@ package scheduling
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend/vllm"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/backend"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics/backend/vllm"
|
||||
|
||||
"github.com/prometheus/common/expfmt"
|
||||
)
|
||||
|
||||
const (
|
||||
MetricPolicyDefault = "default"
|
||||
MetricPolicyLeast = "least"
|
||||
MetricPolicyMost = "most"
|
||||
)
|
||||
|
||||
const (
|
||||
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable.
|
||||
kvCacheThreshold = 0.8
|
||||
@@ -107,11 +114,11 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
func NewScheduler(pm []*backend.PodMetrics) *Scheduler {
|
||||
func NewScheduler(pm []*backend.PodMetrics, filter Filter) *Scheduler {
|
||||
|
||||
return &Scheduler{
|
||||
podMetrics: pm,
|
||||
filter: defaultFilter,
|
||||
filter: filter,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -130,7 +137,7 @@ func (s *Scheduler) Schedule(req *LLMRequest) (targetPod backend.Pod, err error)
|
||||
return pods[i].Pod, nil
|
||||
}
|
||||
|
||||
func GetScheduler(hostMetrics map[string]string) (*Scheduler, error) {
|
||||
func GetScheduler(hostMetrics map[string]string, metricPolicy string, targetMetric string) (*Scheduler, error) {
|
||||
if len(hostMetrics) == 0 {
|
||||
return nil, errors.New("backend is not support llm scheduling")
|
||||
}
|
||||
@@ -147,6 +154,9 @@ func GetScheduler(hostMetrics map[string]string) (*Scheduler, error) {
|
||||
Address: addr,
|
||||
},
|
||||
Metrics: backend.Metrics{},
|
||||
UserSelectedMetric: backend.UserSelectedMetric{
|
||||
MetricName: targetMetric,
|
||||
},
|
||||
}
|
||||
pm, err = vllm.PromToPodMetrics(metricFamilies, pm)
|
||||
if err != nil {
|
||||
@@ -154,5 +164,60 @@ func GetScheduler(hostMetrics map[string]string) (*Scheduler, error) {
|
||||
}
|
||||
pms = append(pms, pm)
|
||||
}
|
||||
return NewScheduler(pms), nil
|
||||
if metricPolicy == MetricPolicyLeast {
|
||||
filterFunc := func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
|
||||
min := math.MaxFloat64
|
||||
max := 0.0
|
||||
filtered := []*backend.PodMetrics{}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.MetricValue <= min {
|
||||
min = pod.MetricValue
|
||||
}
|
||||
if pod.MetricValue >= max {
|
||||
max = pod.MetricValue
|
||||
}
|
||||
}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.MetricValue >= min && pod.MetricValue <= min+(max-min)/float64(len(pods)) {
|
||||
filtered = append(filtered, pod)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
filter := filter{
|
||||
name: "least user selected metric",
|
||||
filter: filterFunc,
|
||||
}
|
||||
return NewScheduler(pms, &filter), nil
|
||||
} else if metricPolicy == MetricPolicyMost {
|
||||
filterFunc := func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
|
||||
min := math.MaxFloat64
|
||||
max := 0.0
|
||||
filtered := []*backend.PodMetrics{}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.MetricValue <= min {
|
||||
min = pod.MetricValue
|
||||
}
|
||||
if pod.MetricValue >= max {
|
||||
max = pod.MetricValue
|
||||
}
|
||||
}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.MetricValue <= max && pod.MetricValue >= max-(max-min)/float64(len(pods)) {
|
||||
filtered = append(filtered, pod)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
filter := filter{
|
||||
name: "most user selected metric",
|
||||
filter: filterFunc,
|
||||
}
|
||||
return NewScheduler(pms, &filter), nil
|
||||
}
|
||||
return NewScheduler(pms, defaultFilter), nil
|
||||
}
|
||||
@@ -16,40 +16,91 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
RedisKeyFormat = "higress:global_least_request_table:%s:%s"
|
||||
RedisLua = `local seed = KEYS[1]
|
||||
RedisKeyFormat = "higress:global_least_request_table:%s:%s"
|
||||
RedisLastCleanKeyFormat = "higress:global_least_request_table:last_clean_time:%s:%s"
|
||||
RedisLua = `local seed = tonumber(KEYS[1])
|
||||
local hset_key = KEYS[2]
|
||||
local current_target = KEYS[3]
|
||||
local current_count = 0
|
||||
local last_clean_key = KEYS[3]
|
||||
local clean_interval = tonumber(KEYS[4])
|
||||
local current_target = KEYS[5]
|
||||
local healthy_count = tonumber(KEYS[6])
|
||||
local enable_detail_log = KEYS[7]
|
||||
|
||||
math.randomseed(seed)
|
||||
|
||||
local function randomBool()
|
||||
return math.random() >= 0.5
|
||||
end
|
||||
-- 1. Selection
|
||||
local current_count = 0
|
||||
local same_count_hits = 0
|
||||
|
||||
if redis.call('HEXISTS', hset_key, current_target) == 1 then
|
||||
current_count = redis.call('HGET', hset_key, current_target)
|
||||
for i = 4, #KEYS do
|
||||
if redis.call('HEXISTS', hset_key, KEYS[i]) == 1 then
|
||||
local count = redis.call('HGET', hset_key, KEYS[i])
|
||||
if tonumber(count) < tonumber(current_count) then
|
||||
current_target = KEYS[i]
|
||||
current_count = count
|
||||
elseif count == current_count and randomBool() then
|
||||
current_target = KEYS[i]
|
||||
end
|
||||
end
|
||||
end
|
||||
for i = 8, 8 + healthy_count - 1 do
|
||||
local host = KEYS[i]
|
||||
local count = 0
|
||||
local val = redis.call('HGET', hset_key, host)
|
||||
if val then
|
||||
count = tonumber(val) or 0
|
||||
end
|
||||
|
||||
if same_count_hits == 0 or count < current_count then
|
||||
current_target = host
|
||||
current_count = count
|
||||
same_count_hits = 1
|
||||
elseif count == current_count then
|
||||
same_count_hits = same_count_hits + 1
|
||||
if math.random(same_count_hits) == 1 then
|
||||
current_target = host
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
redis.call("HINCRBY", hset_key, current_target, 1)
|
||||
local new_count = redis.call("HGET", hset_key, current_target)
|
||||
|
||||
return current_target`
|
||||
-- Collect host counts for logging
|
||||
local host_details = {}
|
||||
if enable_detail_log == "1" then
|
||||
local fields = {}
|
||||
for i = 8, #KEYS do
|
||||
table.insert(fields, KEYS[i])
|
||||
end
|
||||
if #fields > 0 then
|
||||
local values = redis.call('HMGET', hset_key, (table.unpack or unpack)(fields))
|
||||
for i, val in ipairs(values) do
|
||||
table.insert(host_details, fields[i])
|
||||
table.insert(host_details, tostring(val or 0))
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- 2. Cleanup
|
||||
local current_time = math.floor(seed / 1000000)
|
||||
local last_clean_time = tonumber(redis.call('GET', last_clean_key) or 0)
|
||||
|
||||
if current_time - last_clean_time >= clean_interval then
|
||||
local all_keys = redis.call('HKEYS', hset_key)
|
||||
if #all_keys > 0 then
|
||||
-- Create a lookup table for current hosts (from index 8 onwards)
|
||||
local current_hosts = {}
|
||||
for i = 8, #KEYS do
|
||||
current_hosts[KEYS[i]] = true
|
||||
end
|
||||
-- Remove keys not in current hosts
|
||||
for _, host in ipairs(all_keys) do
|
||||
if not current_hosts[host] then
|
||||
redis.call('HDEL', hset_key, host)
|
||||
end
|
||||
end
|
||||
end
|
||||
redis.call('SET', last_clean_key, current_time)
|
||||
end
|
||||
|
||||
return {current_target, new_count, host_details}`
|
||||
)
|
||||
|
||||
type GlobalLeastRequestLoadBalancer struct {
|
||||
redisClient wrapper.RedisClient
|
||||
redisClient wrapper.RedisClient
|
||||
maxRequestCount int64
|
||||
cleanInterval int64 // seconds
|
||||
enableDetailLog bool
|
||||
}
|
||||
|
||||
func NewGlobalLeastRequestLoadBalancer(json gjson.Result) (GlobalLeastRequestLoadBalancer, error) {
|
||||
@@ -72,6 +123,18 @@ func NewGlobalLeastRequestLoadBalancer(json gjson.Result) (GlobalLeastRequestLoa
|
||||
}
|
||||
// database default is 0
|
||||
database := json.Get("database").Int()
|
||||
lb.maxRequestCount = json.Get("maxRequestCount").Int()
|
||||
lb.cleanInterval = json.Get("cleanInterval").Int()
|
||||
if lb.cleanInterval == 0 {
|
||||
lb.cleanInterval = 60 * 60 // default 60 minutes
|
||||
} else {
|
||||
lb.cleanInterval = lb.cleanInterval * 60 // convert minutes to seconds
|
||||
}
|
||||
lb.enableDetailLog = true
|
||||
if val := json.Get("enableDetailLog"); val.Exists() {
|
||||
lb.enableDetailLog = val.Bool()
|
||||
}
|
||||
log.Infof("redis client init, serviceFQDN: %s, servicePort: %d, timeout: %d, database: %d, maxRequestCount: %d, cleanInterval: %d minutes, enableDetailLog: %v", serviceFQDN, servicePort, timeout, database, lb.maxRequestCount, lb.cleanInterval/60, lb.enableDetailLog)
|
||||
return lb, lb.redisClient.Init(username, password, int64(timeout), wrapper.WithDataBase(int(database)))
|
||||
}
|
||||
|
||||
@@ -100,9 +163,11 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpC
|
||||
ctx.SetContext("error", true)
|
||||
return types.ActionContinue
|
||||
}
|
||||
allHostMap := make(map[string]struct{})
|
||||
// Only healthy host can be selected
|
||||
healthyHostArray := []string{}
|
||||
for _, hostInfo := range hostInfos {
|
||||
allHostMap[hostInfo[0]] = struct{}{}
|
||||
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
|
||||
healthyHostArray = append(healthyHostArray, hostInfo[0])
|
||||
}
|
||||
@@ -113,10 +178,37 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpC
|
||||
}
|
||||
randomIndex := rand.Intn(len(healthyHostArray))
|
||||
hostSelected := healthyHostArray[randomIndex]
|
||||
keys := []interface{}{time.Now().UnixMicro(), fmt.Sprintf(RedisKeyFormat, routeName, clusterName), hostSelected}
|
||||
|
||||
// KEYS structure: [seed, hset_key, last_clean_key, clean_interval, host_selected, healthy_count, ...healthy_hosts, enableDetailLog, ...unhealthy_hosts]
|
||||
keys := []interface{}{
|
||||
time.Now().UnixMicro(),
|
||||
fmt.Sprintf(RedisKeyFormat, routeName, clusterName),
|
||||
fmt.Sprintf(RedisLastCleanKeyFormat, routeName, clusterName),
|
||||
lb.cleanInterval,
|
||||
hostSelected,
|
||||
len(healthyHostArray),
|
||||
"0",
|
||||
}
|
||||
if lb.enableDetailLog {
|
||||
keys[6] = "1"
|
||||
}
|
||||
for _, v := range healthyHostArray {
|
||||
keys = append(keys, v)
|
||||
}
|
||||
// Append unhealthy hosts (those in allHostMap but not in healthyHostArray)
|
||||
for host := range allHostMap {
|
||||
isHealthy := false
|
||||
for _, hh := range healthyHostArray {
|
||||
if host == hh {
|
||||
isHealthy = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isHealthy {
|
||||
keys = append(keys, host)
|
||||
}
|
||||
}
|
||||
|
||||
err = lb.redisClient.Eval(RedisLua, len(keys), keys, []interface{}{}, func(response resp.Value) {
|
||||
if err := response.Error(); err != nil {
|
||||
log.Errorf("HGetAll failed: %+v", err)
|
||||
@@ -124,17 +216,54 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpC
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
hostSelected = response.String()
|
||||
valArray := response.Array()
|
||||
if len(valArray) < 2 {
|
||||
log.Errorf("redis eval lua result format error, expect at least [host, count], got: %+v", valArray)
|
||||
ctx.SetContext("error", true)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
hostSelected = valArray[0].String()
|
||||
currentCount := valArray[1].Integer()
|
||||
|
||||
// detail log
|
||||
if lb.enableDetailLog && len(valArray) >= 3 {
|
||||
detailLogStr := "host and count: "
|
||||
details := valArray[2].Array()
|
||||
for i := 0; i+1 < len(details); i += 2 {
|
||||
h := details[i].String()
|
||||
c := details[i+1].String()
|
||||
detailLogStr += fmt.Sprintf("{%s: %s}, ", h, c)
|
||||
}
|
||||
log.Debugf("host_selected: %s + 1, %s", hostSelected, detailLogStr)
|
||||
}
|
||||
|
||||
// check rate limit
|
||||
if !lb.checkRateLimit(hostSelected, int64(currentCount), ctx, routeName, clusterName) {
|
||||
ctx.SetContext("error", true)
|
||||
log.Warnf("host_selected: %s, current_count: %d, exceed max request limit %d", hostSelected, currentCount, lb.maxRequestCount)
|
||||
// return 429
|
||||
proxywasm.SendHttpResponse(429, [][2]string{}, []byte("Exceeded maximum request limit from ai-load-balancer."), -1)
|
||||
ctx.DontReadResponseBody()
|
||||
return
|
||||
}
|
||||
|
||||
if err := proxywasm.SetUpstreamOverrideHost([]byte(hostSelected)); err != nil {
|
||||
ctx.SetContext("error", true)
|
||||
log.Errorf("override upstream host failed, fallback to default lb policy, error informations: %+v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("host_selected: %s", hostSelected)
|
||||
|
||||
// finally resume the request
|
||||
ctx.SetContext("host_selected", hostSelected)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
})
|
||||
if err != nil {
|
||||
ctx.SetContext("error", true)
|
||||
log.Errorf("redis eval failed, fallback to default lb policy, error informations: %+v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
return types.ActionPause
|
||||
@@ -161,7 +290,10 @@ func (lb GlobalLeastRequestLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpCo
|
||||
if host_selected == "" {
|
||||
log.Errorf("get host_selected failed")
|
||||
} else {
|
||||
lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil)
|
||||
err := lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil)
|
||||
if err != nil {
|
||||
log.Errorf("host_selected: %s - 1, failed to update count from redis: %v", host_selected, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,220 @@
|
||||
-- Mocking Redis environment
|
||||
local redis_data = {
|
||||
hset = {},
|
||||
kv = {}
|
||||
}
|
||||
|
||||
local redis = {
|
||||
call = function(cmd, ...)
|
||||
local args = {...}
|
||||
if cmd == "HGET" then
|
||||
local key, field = args[1], args[2]
|
||||
return redis_data.hset[field]
|
||||
elseif cmd == "HSET" then
|
||||
local key, field, val = args[1], args[2], args[3]
|
||||
redis_data.hset[field] = val
|
||||
elseif cmd == "HINCRBY" then
|
||||
local key, field, increment = args[1], args[2], args[3]
|
||||
local val = tonumber(redis_data.hset[field] or 0)
|
||||
redis_data.hset[field] = tostring(val + increment)
|
||||
return redis_data.hset[field]
|
||||
elseif cmd == "HKEYS" then
|
||||
local keys = {}
|
||||
for k, _ in pairs(redis_data.hset) do
|
||||
table.insert(keys, k)
|
||||
end
|
||||
return keys
|
||||
elseif cmd == "HDEL" then
|
||||
local key, field = args[1], args[2]
|
||||
redis_data.hset[field] = nil
|
||||
elseif cmd == "GET" then
|
||||
return redis_data.kv[args[1]]
|
||||
elseif cmd == "HMGET" then
|
||||
local key = args[1]
|
||||
local res = {}
|
||||
for i = 2, #args do
|
||||
table.insert(res, redis_data.hset[args[i]])
|
||||
end
|
||||
return res
|
||||
elseif cmd == "SET" then
|
||||
redis_data.kv[args[1]] = args[2]
|
||||
end
|
||||
end
|
||||
}
|
||||
|
||||
-- The actual logic from lb_policy.go
|
||||
local function run_lb_logic(KEYS)
|
||||
local seed = tonumber(KEYS[1])
|
||||
local hset_key = KEYS[2]
|
||||
local last_clean_key = KEYS[3]
|
||||
local clean_interval = tonumber(KEYS[4])
|
||||
local current_target = KEYS[5]
|
||||
local healthy_count = tonumber(KEYS[6])
|
||||
local enable_detail_log = KEYS[7]
|
||||
|
||||
math.randomseed(seed)
|
||||
|
||||
-- 1. Selection
|
||||
local current_count = 0
|
||||
local same_count_hits = 0
|
||||
|
||||
for i = 8, 8 + healthy_count - 1 do
|
||||
local host = KEYS[i]
|
||||
local count = 0
|
||||
local val = redis.call('HGET', hset_key, host)
|
||||
if val then
|
||||
count = tonumber(val) or 0
|
||||
end
|
||||
|
||||
if same_count_hits == 0 or count < current_count then
|
||||
current_target = host
|
||||
current_count = count
|
||||
same_count_hits = 1
|
||||
elseif count == current_count then
|
||||
same_count_hits = same_count_hits + 1
|
||||
if math.random(same_count_hits) == 1 then
|
||||
current_target = host
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
redis.call("HINCRBY", hset_key, current_target, 1)
|
||||
local new_count = redis.call("HGET", hset_key, current_target)
|
||||
|
||||
-- Collect host counts for logging
|
||||
local host_details = {}
|
||||
if enable_detail_log == "1" then
|
||||
local fields = {}
|
||||
for i = 8, #KEYS do
|
||||
table.insert(fields, KEYS[i])
|
||||
end
|
||||
if #fields > 0 then
|
||||
local values = redis.call('HMGET', hset_key, (table.unpack or unpack)(fields))
|
||||
for i, val in ipairs(values) do
|
||||
table.insert(host_details, fields[i])
|
||||
table.insert(host_details, tostring(val or 0))
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- 2. Cleanup
|
||||
local current_time = math.floor(seed / 1000000)
|
||||
local last_clean_time = tonumber(redis.call('GET', last_clean_key) or 0)
|
||||
|
||||
if current_time - last_clean_time >= clean_interval then
|
||||
local all_keys = redis.call('HKEYS', hset_key)
|
||||
if #all_keys > 0 then
|
||||
-- Create a lookup table for current hosts (from index 8 onwards)
|
||||
local current_hosts = {}
|
||||
for i = 8, #KEYS do
|
||||
current_hosts[KEYS[i]] = true
|
||||
end
|
||||
-- Remove keys not in current hosts
|
||||
for _, host in ipairs(all_keys) do
|
||||
if not current_hosts[host] then
|
||||
redis.call('HDEL', hset_key, host)
|
||||
end
|
||||
end
|
||||
end
|
||||
redis.call('SET', last_clean_key, current_time)
|
||||
end
|
||||
|
||||
return {current_target, new_count, host_details}
|
||||
end
|
||||
|
||||
-- --- Test 1: Load Balancing Distribution ---
|
||||
print("--- Test 1: Load Balancing Distribution ---")
|
||||
local hosts = {"host1", "host2", "host3", "host4", "host5"}
|
||||
local iterations = 100000
|
||||
local results = {}
|
||||
for _, h in ipairs(hosts) do results[h] = 0 end
|
||||
|
||||
-- Reset redis
|
||||
redis_data.hset = {}
|
||||
for _, h in ipairs(hosts) do redis_data.hset[h] = "0" end
|
||||
|
||||
print(string.format("Running %d iterations with %d hosts (all counts started at 0)...", iterations, #hosts))
|
||||
|
||||
for i = 1, iterations do
|
||||
local initial_host = hosts[math.random(#hosts)]
|
||||
-- KEYS structure: [seed, hset_key, last_clean_key, clean_interval, host_selected, healthy_count, enable_detail_log, ...healthy_hosts]
|
||||
local keys = {i * 1000000, "table_key", "clean_key", 3600, initial_host, #hosts, "1"}
|
||||
for _, h in ipairs(hosts) do table.insert(keys, h) end
|
||||
|
||||
local res = run_lb_logic(keys)
|
||||
local selected = res[1]
|
||||
results[selected] = results[selected] + 1
|
||||
end
|
||||
|
||||
for _, h in ipairs(hosts) do
|
||||
local percentage = (results[h] / iterations) * 100
|
||||
print(string.format("%s: %6d (%.2f%%)", h, results[h], percentage))
|
||||
end
|
||||
|
||||
-- --- Test 2: IP Cleanup Logic ---
|
||||
print("\n--- Test 2: IP Cleanup Logic ---")
|
||||
|
||||
local function test_cleanup()
|
||||
redis_data.hset = {
|
||||
["host1"] = "10",
|
||||
["host2"] = "5",
|
||||
["old_ip_1"] = "1",
|
||||
["old_ip_2"] = "1",
|
||||
}
|
||||
redis_data.kv["clean_key"] = "1000" -- Last cleaned at 1000s
|
||||
|
||||
local current_hosts = {"host1", "host2"}
|
||||
local current_time_ms = 1000 * 1000000 + 500 * 1000000 -- 1500s (interval is 300s, let's say)
|
||||
local clean_interval = 300
|
||||
|
||||
print("Initial Redis IPs:", table.concat((function() local res={} for k,_ in pairs(redis_data.hset) do table.insert(res, k) end return res end)(), ", "))
|
||||
|
||||
-- Run logic (seed is microtime)
|
||||
local keys = {current_time_ms, "table_key", "clean_key", clean_interval, "host1", #current_hosts, "1"}
|
||||
for _, h in ipairs(current_hosts) do table.insert(keys, h) end
|
||||
|
||||
run_lb_logic(keys)
|
||||
|
||||
print("After Cleanup Redis IPs:", table.concat((function() local res={} for k,_ in pairs(redis_data.hset) do table.insert(res, k) end table.sort(res) return res end)(), ", "))
|
||||
|
||||
local exists_old1 = redis_data.hset["old_ip_1"] ~= nil
|
||||
local exists_old2 = redis_data.hset["old_ip_2"] ~= nil
|
||||
|
||||
if not exists_old1 and not exists_old2 then
|
||||
print("Success: Outdated IPs removed.")
|
||||
else
|
||||
print("Failure: Outdated IPs still exist.")
|
||||
end
|
||||
|
||||
print("New last_clean_time:", redis_data.kv["clean_key"])
|
||||
end
|
||||
|
||||
test_cleanup()
|
||||
|
||||
-- --- Test 3: No Cleanup if Interval Not Reached ---
|
||||
print("\n--- Test 3: No Cleanup if Interval Not Reached ---")
|
||||
|
||||
local function test_no_cleanup()
|
||||
redis_data.hset = {
|
||||
["host1"] = "10",
|
||||
["old_ip_1"] = "1",
|
||||
}
|
||||
redis_data.kv["clean_key"] = "1000"
|
||||
|
||||
local current_hosts = {"host1"}
|
||||
local current_time_ms = 1000 * 1000000 + 100 * 1000000 -- 1100s (interval 300s, not reached)
|
||||
local clean_interval = 300
|
||||
|
||||
local keys = {current_time_ms, "table_key", "clean_key", clean_interval, "host1", #current_hosts, "0"}
|
||||
for _, h in ipairs(current_hosts) do table.insert(keys, h) end
|
||||
|
||||
run_lb_logic(keys)
|
||||
|
||||
if redis_data.hset["old_ip_1"] then
|
||||
print("Success: Cleanup not triggered as expected.")
|
||||
else
|
||||
print("Failure: Cleanup triggered unexpectedly.")
|
||||
end
|
||||
end
|
||||
|
||||
test_no_cleanup()
|
||||
@@ -0,0 +1,24 @@
|
||||
package global_least_request
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
func (lb GlobalLeastRequestLoadBalancer) checkRateLimit(hostSelected string, currentCount int64, ctx wrapper.HttpContext, routeName string, clusterName string) bool {
|
||||
// 如果没有配置最大请求数,直接通过
|
||||
if lb.maxRequestCount <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// 如果当前请求数大于最大请求数,则限流
|
||||
// 注意:Lua脚本已经加了1,所以这里比较的是加1后的值
|
||||
if currentCount > lb.maxRequestCount {
|
||||
// 恢复 Redis 计数
|
||||
lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), hostSelected, -1, nil)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -1,81 +0,0 @@
|
||||
package least_busy
|
||||
|
||||
import (
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/scheduling"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type LeastBusyLoadBalancer struct {
|
||||
criticalModels map[string]struct{}
|
||||
}
|
||||
|
||||
func NewLeastBusyLoadBalancer(json gjson.Result) (LeastBusyLoadBalancer, error) {
|
||||
lb := LeastBusyLoadBalancer{}
|
||||
lb.criticalModels = make(map[string]struct{})
|
||||
for _, model := range json.Get("criticalModels").Array() {
|
||||
lb.criticalModels[model.String()] = struct{}{}
|
||||
}
|
||||
return lb, nil
|
||||
}
|
||||
|
||||
// Callbacks which are called in request path
|
||||
func (lb LeastBusyLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
// If return types.ActionContinue, SetUpstreamOverrideHost will not take effect
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func (lb LeastBusyLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
requestModel := gjson.GetBytes(body, "model")
|
||||
if !requestModel.Exists() {
|
||||
return types.ActionContinue
|
||||
}
|
||||
_, isCritical := lb.criticalModels[requestModel.String()]
|
||||
llmReq := &scheduling.LLMRequest{
|
||||
Model: requestModel.String(),
|
||||
Critical: isCritical,
|
||||
}
|
||||
hostInfos, err := proxywasm.GetUpstreamHosts()
|
||||
if err != nil {
|
||||
return types.ActionContinue
|
||||
}
|
||||
hostMetrics := make(map[string]string)
|
||||
for _, hostInfo := range hostInfos {
|
||||
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
|
||||
hostMetrics[hostInfo[0]] = gjson.Get(hostInfo[1], "metrics").String()
|
||||
}
|
||||
}
|
||||
scheduler, err := scheduling.GetScheduler(hostMetrics)
|
||||
if err != nil {
|
||||
log.Debugf("initial scheduler failed: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
targetPod, err := scheduler.Schedule(llmReq)
|
||||
log.Debugf("targetPod: %+v", targetPod.Address)
|
||||
if err != nil {
|
||||
log.Debugf("pod select failed: %v", err)
|
||||
proxywasm.SendHttpResponseWithDetail(429, "limited resources", nil, []byte("limited resources"), 0)
|
||||
} else {
|
||||
proxywasm.SetUpstreamOverrideHost([]byte(targetPod.Address))
|
||||
}
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb LeastBusyLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb LeastBusyLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
func (lb LeastBusyLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb LeastBusyLoadBalancer) HandleHttpStreamDone(ctx wrapper.HttpContext) {}
|
||||
@@ -7,9 +7,10 @@ import (
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
global_least_request "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/global_least_request"
|
||||
least_busy "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy"
|
||||
prefix_cache "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/prefix_cache"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/cluster_metrics"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/endpoint_metrics"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/global_least_request"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/prefix_cache"
|
||||
)
|
||||
|
||||
func main() {}
|
||||
@@ -37,34 +38,57 @@ type LoadBalancer interface {
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
policy string
|
||||
lb LoadBalancer
|
||||
lbType string
|
||||
lbPolicy string
|
||||
lb LoadBalancer
|
||||
}
|
||||
|
||||
const (
|
||||
LeastBusyLoadBalancerPolicy = "least_busy"
|
||||
GlobalLeastRequestLoadBalancerPolicy = "global_least_request"
|
||||
PrefixCache = "prefix_cache"
|
||||
ClusterLoadBalancerType = "cluster"
|
||||
EndpointLoadBalancerType = "endpoint"
|
||||
// Cluster load balancer policies
|
||||
MetricsBasedCluster = "cluster_metrics"
|
||||
// Endpoint load balancer policies
|
||||
MetricsBasedEndpoint = "endpoint_metrics"
|
||||
MetricsBasedEndpointDeprecated = "metrics_based" // Compatible with old configurations, equal to `endpoint_metrics`
|
||||
GlobalLeastRequestEndpoint = "global_least_request"
|
||||
PrefixCacheEndpoint = "prefix_cache"
|
||||
)
|
||||
|
||||
func parseConfig(json gjson.Result, config *Config) error {
|
||||
config.policy = json.Get("lb_policy").String()
|
||||
config.lbType = json.Get("lb_type").String()
|
||||
// Compatible with old configurations
|
||||
if config.lbType == "" {
|
||||
config.lbType = EndpointLoadBalancerType
|
||||
}
|
||||
config.lbPolicy = json.Get("lb_policy").String()
|
||||
var err error
|
||||
switch config.policy {
|
||||
case LeastBusyLoadBalancerPolicy:
|
||||
config.lb, err = least_busy.NewLeastBusyLoadBalancer(json.Get("lb_config"))
|
||||
case GlobalLeastRequestLoadBalancerPolicy:
|
||||
config.lb, err = global_least_request.NewGlobalLeastRequestLoadBalancer(json.Get("lb_config"))
|
||||
case PrefixCache:
|
||||
config.lb, err = prefix_cache.NewPrefixCacheLoadBalancer(json.Get("lb_config"))
|
||||
switch config.lbType {
|
||||
case ClusterLoadBalancerType:
|
||||
switch config.lbPolicy {
|
||||
case MetricsBasedCluster:
|
||||
config.lb, err = cluster_metrics.NewClusterEndpointLoadBalancer(json.Get("lb_config"))
|
||||
default:
|
||||
err = fmt.Errorf("lb_policy %s is not supported", config.lbPolicy)
|
||||
}
|
||||
case EndpointLoadBalancerType:
|
||||
switch config.lbPolicy {
|
||||
case MetricsBasedEndpoint, MetricsBasedEndpointDeprecated:
|
||||
config.lb, err = endpoint_metrics.NewMetricsEndpointLoadBalancer(json.Get("lb_config"))
|
||||
case GlobalLeastRequestEndpoint:
|
||||
config.lb, err = global_least_request.NewGlobalLeastRequestLoadBalancer(json.Get("lb_config"))
|
||||
case PrefixCacheEndpoint:
|
||||
config.lb, err = prefix_cache.NewPrefixCacheLoadBalancer(json.Get("lb_config"))
|
||||
default:
|
||||
err = fmt.Errorf("lb_psolicy %s is not supported", config.lbPolicy)
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("lb_policy %s is not supported", config.policy)
|
||||
err = fmt.Errorf("lb_type %s is not supported", config.lbType)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config) types.Action {
|
||||
ctx.DisableReroute()
|
||||
return config.lb.HandleHttpRequestHeaders(ctx)
|
||||
}
|
||||
|
||||
|
||||
175
plugins/wasm-go/extensions/ai-load-balancer/utils/queue.go
Normal file
175
plugins/wasm-go/extensions/ai-load-balancer/utils/queue.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
// FixedQueue 实现了一个固定容量的环形缓冲区队列
|
||||
// 当队列满时,新元素会覆盖最旧的元素
|
||||
type FixedQueue[T any] struct {
|
||||
data []T
|
||||
head int
|
||||
tail int
|
||||
size int
|
||||
cap int
|
||||
}
|
||||
|
||||
// NewFixed 创建一个指定容量的固定队列
|
||||
func NewFixedQueue[T any](capacity int) *FixedQueue[T] {
|
||||
if capacity <= 0 {
|
||||
capacity = 16
|
||||
}
|
||||
return &FixedQueue[T]{
|
||||
data: make([]T, capacity),
|
||||
head: 0,
|
||||
tail: 0,
|
||||
size: 0,
|
||||
cap: capacity,
|
||||
}
|
||||
}
|
||||
|
||||
// Enqueue 入队操作
|
||||
// 如果队列已满,会覆盖最旧的元素
|
||||
func (q *FixedQueue[T]) Enqueue(item T) {
|
||||
if q.size < q.cap {
|
||||
// 队列未满,正常插入
|
||||
q.data[q.tail] = item
|
||||
q.tail = (q.tail + 1) % q.cap
|
||||
q.size++
|
||||
} else {
|
||||
// 队列已满,覆盖最旧元素
|
||||
q.data[q.tail] = item
|
||||
q.head = (q.head + 1) % q.cap // 移动head,丢弃最旧元素
|
||||
q.tail = (q.tail + 1) % q.cap // tail正常移动
|
||||
// size保持不变(仍然是cap)
|
||||
}
|
||||
}
|
||||
|
||||
// Dequeue 出队操作
|
||||
func (q *FixedQueue[T]) Dequeue() (T, error) {
|
||||
var zero T
|
||||
if q.size == 0 {
|
||||
return zero, errors.New("queue is empty")
|
||||
}
|
||||
|
||||
item := q.data[q.head]
|
||||
// 清除引用,避免内存泄漏
|
||||
var zeroVal T
|
||||
q.data[q.head] = zeroVal
|
||||
|
||||
q.head = (q.head + 1) % q.cap
|
||||
q.size--
|
||||
|
||||
return item, nil
|
||||
}
|
||||
|
||||
// Peek 查看队头元素但不移除
|
||||
func (q *FixedQueue[T]) Peek() (T, error) {
|
||||
var zero T
|
||||
if q.size == 0 {
|
||||
return zero, errors.New("queue is empty")
|
||||
}
|
||||
return q.data[q.head], nil
|
||||
}
|
||||
|
||||
// Size 返回队列中元素的数量
|
||||
func (q *FixedQueue[T]) Size() int {
|
||||
return q.size
|
||||
}
|
||||
|
||||
// Capacity 返回队列的固定容量
|
||||
func (q *FixedQueue[T]) Capacity() int {
|
||||
return q.cap
|
||||
}
|
||||
|
||||
// IsEmpty 判断队列是否为空
|
||||
func (q *FixedQueue[T]) IsEmpty() bool {
|
||||
return q.size == 0
|
||||
}
|
||||
|
||||
// IsFull 判断队列是否已满
|
||||
func (q *FixedQueue[T]) IsFull() bool {
|
||||
return q.size == q.cap
|
||||
}
|
||||
|
||||
// OverwriteCount 返回被覆盖的元素数量
|
||||
// 注意:这个实现中我们不直接跟踪覆盖次数,
|
||||
// 但可以通过其他方式计算(如果需要的话)
|
||||
func (q *FixedQueue[T]) OverwriteCount() int {
|
||||
// 如果需要跟踪覆盖次数,可以添加一个字段
|
||||
// 目前这个实现不提供此功能
|
||||
return 0
|
||||
}
|
||||
|
||||
// Clear 清空队列
|
||||
func (q *FixedQueue[T]) Clear() {
|
||||
// 清除所有引用
|
||||
for i := 0; i < q.size; i++ {
|
||||
idx := (q.head + i) % q.cap
|
||||
var zero T
|
||||
q.data[idx] = zero
|
||||
}
|
||||
q.head = 0
|
||||
q.tail = 0
|
||||
q.size = 0
|
||||
}
|
||||
|
||||
// ToSlice 返回队列元素的切片副本(按队列顺序,从最旧到最新)
|
||||
func (q *FixedQueue[T]) ToSlice() []T {
|
||||
if q.size == 0 {
|
||||
return []T{}
|
||||
}
|
||||
|
||||
result := make([]T, q.size)
|
||||
if q.head <= q.tail || q.size == q.cap {
|
||||
if q.head < q.tail {
|
||||
// 数据连续且未满
|
||||
copy(result, q.data[q.head:q.tail])
|
||||
} else {
|
||||
// 数据连续但已满(head == tail)
|
||||
// 或者数据跨越边界
|
||||
if q.head == q.tail && q.size == q.cap {
|
||||
// 已满且head == tail的情况
|
||||
copy(result, q.data[q.head:])
|
||||
if len(result) > q.cap-q.head {
|
||||
copy(result[q.cap-q.head:], q.data[:q.tail])
|
||||
}
|
||||
} else {
|
||||
// 跨越边界
|
||||
copy(result, q.data[q.head:])
|
||||
copy(result[q.cap-q.head:], q.data[:q.tail])
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 跨越边界的情况
|
||||
copy(result, q.data[q.head:])
|
||||
copy(result[q.cap-q.head:], q.data[:q.tail])
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Oldest 返回最旧的元素(队头)
|
||||
func (q *FixedQueue[T]) Oldest() (T, error) {
|
||||
return q.Peek()
|
||||
}
|
||||
|
||||
// Newest 返回最新的元素(队尾的前一个元素)
|
||||
func (q *FixedQueue[T]) Newest() (T, error) {
|
||||
var zero T
|
||||
if q.size == 0 {
|
||||
return zero, errors.New("queue is empty")
|
||||
}
|
||||
|
||||
// tail指向下一个插入位置,所以最新元素在 (tail - 1 + cap) % cap
|
||||
newestIndex := (q.tail - 1 + q.cap) % q.cap
|
||||
return q.data[newestIndex], nil
|
||||
}
|
||||
|
||||
// ForEach 对队列中的每个元素执行回调函数
|
||||
func (q *FixedQueue[T]) ForEach(fn func(index int, item T)) {
|
||||
for i := 0; i < q.size; i++ {
|
||||
idx := (q.head + i) % q.cap
|
||||
fn(i, q.data[idx])
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,7 @@
|
||||
.DEFAULT:
|
||||
build:
|
||||
tinygo build -o ai-proxy.wasm -scheduler=none -target=wasi -gc=custom -tags='custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100' ./main.go
|
||||
mv ai-proxy.wasm ../../../../docker-compose-test/
|
||||
mv ai-proxy.wasm ../../../../docker-compose-test/
|
||||
|
||||
build-go:
|
||||
GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o main.wasm main.go
|
||||
@@ -231,6 +231,18 @@ Ollama 所对应的 `type` 为 `ollama`。它特有的配置字段如下:
|
||||
| `ollamaServerHost` | string | 必填 | - | Ollama 服务器的主机地址 |
|
||||
| `ollamaServerPort` | number | 必填 | - | Ollama 服务器的端口号,默认为 11434 |
|
||||
|
||||
#### 通用代理(Generic)
|
||||
|
||||
当只需要借助 AI Proxy 的鉴权、basePath 处理或首包超时能力,且不希望插件改写路径时,可将 `provider.type` 设置为 `generic`。该 Provider 不绑定任何模型厂商,也不会做能力映射。
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ------------- | -------- | -------- | ------ | -------------------------------------------------------------------- |
|
||||
| `genericHost` | string | 非必填 | - | 指定要转发到的目标 Host;未配置时沿用客户端请求的 Host。 |
|
||||
|
||||
- 配置了 `apiTokens` 时,会自动写入 `Authorization: Bearer <token>` 请求头,复用全局的 Token 轮询能力。
|
||||
- 当配置了 `firstByteTimeout` 时,会自动注入 `x-envoy-upstream-rq-first-byte-timeout-ms`。
|
||||
- `basePath` 与 `basePathHandling` 同样适用,可在通用转发中快捷地移除或添加统一前缀。
|
||||
|
||||
#### 混元
|
||||
|
||||
混元所对应的 `type` 为 `hunyuan`。它特有的配置字段如下:
|
||||
@@ -297,7 +309,9 @@ Dify 所对应的 `type` 为 `dify`。它特有的配置字段如下:
|
||||
|
||||
#### Google Vertex AI
|
||||
|
||||
Google Vertex AI 所对应的 type 为 vertex。它特有的配置字段如下:
|
||||
Google Vertex AI 所对应的 type 为 vertex。支持两种认证模式:
|
||||
|
||||
**标准模式**(使用 Service Account):
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|-----------------------------|---------------|--------|--------|-------------------------------------------------------------------------------|
|
||||
@@ -308,25 +322,56 @@ Google Vertex AI 所对应的 type 为 vertex。它特有的配置字段如下
|
||||
| `geminiSafetySetting` | map of string | 非必填 | - | Gemini AI 内容过滤和安全级别设定。参考[Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) |
|
||||
| `vertexTokenRefreshAhead` | number | 非必填 | - | Vertex access token刷新提前时间(单位秒) |
|
||||
|
||||
**Express Mode**(使用 API Key,简化配置):
|
||||
|
||||
Express Mode 是 Vertex AI 推出的简化访问模式,只需 API Key 即可快速开始使用,无需配置 Service Account。详见 [Vertex AI Express Mode 文档](https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview)。
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|-----------------------------|---------------|--------|--------|-------------------------------------------------------------------------------|
|
||||
| `apiTokens` | array of string | 必填 | - | Express Mode 使用的 API Key,从 Google Cloud Console 的 API & Services > Credentials 获取 |
|
||||
| `geminiSafetySetting` | map of string | 非必填 | - | Gemini AI 内容过滤和安全级别设定。参考[Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) |
|
||||
|
||||
**OpenAI 兼容模式**(使用 Vertex AI Chat Completions API):
|
||||
|
||||
Vertex AI 提供了 OpenAI 兼容的 Chat Completions API 端点,可以直接使用 OpenAI 格式的请求和响应,无需进行协议转换。详见 [Vertex AI OpenAI 兼容性文档](https://cloud.google.com/vertex-ai/generative-ai/docs/migrate/openai/overview)。
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|-----------------------------|---------------|--------|--------|-------------------------------------------------------------------------------|
|
||||
| `vertexOpenAICompatible` | boolean | 非必填 | false | 启用 OpenAI 兼容模式。启用后将使用 Vertex AI 的 OpenAI-compatible Chat Completions API |
|
||||
| `vertexAuthKey` | string | 必填 | - | 用于认证的 Google Service Account JSON Key |
|
||||
| `vertexRegion` | string | 必填 | - | Google Cloud 区域(如 us-central1, europe-west4 等) |
|
||||
| `vertexProjectId` | string | 必填 | - | Google Cloud 项目 ID |
|
||||
| `vertexAuthServiceName` | string | 必填 | - | 用于 OAuth2 认证的服务名称 |
|
||||
|
||||
**注意**:OpenAI 兼容模式与 Express Mode 互斥,不能同时配置 `apiTokens` 和 `vertexOpenAICompatible`。
|
||||
|
||||
#### AWS Bedrock
|
||||
|
||||
AWS Bedrock 所对应的 type 为 bedrock。它特有的配置字段如下:
|
||||
AWS Bedrock 所对应的 type 为 bedrock。它支持两种认证方式:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|---------------------------|--------|------|-----|------------------------------|
|
||||
| `modelVersion` | string | 非必填 | - | 用于指定 Triton Server 中 model version |
|
||||
| `tritonDomain` | string | 非必填 | - | Triton Server 部署的指定请求 Domain |
|
||||
1. **AWS Signature V4 认证**:使用 `awsAccessKey` 和 `awsSecretKey` 进行 AWS 标准签名认证
|
||||
2. **Bearer Token 认证**:使用 `apiTokens` 配置 AWS Bearer Token(适用于 IAM Identity Center 等场景)
|
||||
|
||||
**注意**:两种认证方式二选一,如果同时配置了 `apiTokens`,将优先使用 Bearer Token 认证方式。
|
||||
|
||||
它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|---------------------------|---------------|-------------------|-------|---------------------------------------------------|
|
||||
| `apiTokens` | array of string | 与 ak/sk 二选一 | - | AWS Bearer Token,用于 Bearer Token 认证方式 |
|
||||
| `awsAccessKey` | string | 与 apiTokens 二选一 | - | AWS Access Key,用于 AWS Signature V4 认证 |
|
||||
| `awsSecretKey` | string | 与 apiTokens 二选一 | - | AWS Secret Access Key,用于 AWS Signature V4 认证 |
|
||||
| `awsRegion` | string | 必填 | - | AWS 区域,例如:us-east-1 |
|
||||
| `bedrockAdditionalFields` | map | 非必填 | - | Bedrock 额外模型请求参数 |
|
||||
|
||||
#### NVIDIA Triton Interference Server
|
||||
|
||||
NVIDIA Triton Interference Server 所对应的 type 为 triton。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|---------------------------|--------|------|-----|------------------------------|
|
||||
| `awsAccessKey` | string | 必填 | - | AWS Access Key,用于身份认证 |
|
||||
| `awsSecretKey` | string | 必填 | - | AWS Secret Access Key,用于身份认证 |
|
||||
| `awsRegion` | string | 必填 | - | AWS 区域,例如:us-east-1 |
|
||||
| `bedrockAdditionalFields` | map | 非必填 | - | Bedrock 额外模型请求参数 |
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|----------------------|--------|--------|-------|------------------------------------------|
|
||||
| `tritonModelVersion` | string | 非必填 | - | 用于指定 Triton Server 中 model version |
|
||||
| `tritonDomain` | string | 非必填 | - | Triton Server 部署的指定请求 Domain |
|
||||
|
||||
## 用法示例
|
||||
|
||||
@@ -1935,7 +1980,7 @@ provider:
|
||||
}
|
||||
```
|
||||
|
||||
### 使用 OpenAI 协议代理 Google Vertex 服务
|
||||
### 使用 OpenAI 协议代理 Google Vertex 服务(标准模式)
|
||||
|
||||
**配置信息**
|
||||
|
||||
@@ -1997,8 +2042,134 @@ provider:
|
||||
}
|
||||
```
|
||||
|
||||
### 使用 OpenAI 协议代理 Google Vertex 服务(Express Mode)
|
||||
|
||||
Express Mode 是 Vertex AI 的简化访问模式,只需 API Key 即可快速开始使用。
|
||||
|
||||
**配置信息**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: vertex
|
||||
apiTokens:
|
||||
- "YOUR_API_KEY"
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gemini-2.5-flash",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好,你是谁?"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
**响应示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "chatcmpl-0000000000000",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "你好!我是 Gemini,由 Google 开发的人工智能助手。有什么我可以帮您的吗?"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"created": 1729986750,
|
||||
"model": "gemini-2.5-flash",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 25,
|
||||
"total_tokens": 35
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 使用 OpenAI 协议代理 Google Vertex 服务(OpenAI 兼容模式)
|
||||
|
||||
OpenAI 兼容模式使用 Vertex AI 的 OpenAI-compatible Chat Completions API,请求和响应都使用 OpenAI 格式,无需进行协议转换。
|
||||
|
||||
**配置信息**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: vertex
|
||||
vertexOpenAICompatible: true
|
||||
vertexAuthKey: |
|
||||
{
|
||||
"type": "service_account",
|
||||
"project_id": "your-project-id",
|
||||
"private_key_id": "your-private-key-id",
|
||||
"private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n",
|
||||
"client_email": "your-service-account@your-project.iam.gserviceaccount.com",
|
||||
"token_uri": "https://oauth2.googleapis.com/token"
|
||||
}
|
||||
vertexRegion: us-central1
|
||||
vertexProjectId: your-project-id
|
||||
vertexAuthServiceName: your-auth-service-name
|
||||
modelMapping:
|
||||
"gpt-4": "gemini-2.0-flash"
|
||||
"*": "gemini-1.5-flash"
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好,你是谁?"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
**响应示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "chatcmpl-abc123",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "你好!我是由 Google 开发的 Gemini 模型。我可以帮助回答问题、提供信息和进行对话。有什么我可以帮您的吗?"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"created": 1729986750,
|
||||
"model": "gemini-2.0-flash",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"prompt_tokens": 12,
|
||||
"completion_tokens": 35,
|
||||
"total_tokens": 47
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 使用 OpenAI 协议代理 AWS Bedrock 服务
|
||||
|
||||
AWS Bedrock 支持两种认证方式:
|
||||
|
||||
#### 方式一:使用 AWS Access Key/Secret Key 认证(AWS Signature V4)
|
||||
|
||||
**配置信息**
|
||||
|
||||
```yaml
|
||||
@@ -2006,7 +2177,21 @@ provider:
|
||||
type: bedrock
|
||||
awsAccessKey: "YOUR_AWS_ACCESS_KEY_ID"
|
||||
awsSecretKey: "YOUR_AWS_SECRET_ACCESS_KEY"
|
||||
awsRegion: "YOUR_AWS_REGION"
|
||||
awsRegion: "us-east-1"
|
||||
bedrockAdditionalFields:
|
||||
top_k: 200
|
||||
```
|
||||
|
||||
#### 方式二:使用 Bearer Token 认证(适用于 IAM Identity Center 等场景)
|
||||
|
||||
**配置信息**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: bedrock
|
||||
apiTokens:
|
||||
- "YOUR_AWS_BEARER_TOKEN"
|
||||
awsRegion: "us-east-1"
|
||||
bedrockAdditionalFields:
|
||||
top_k: 200
|
||||
```
|
||||
@@ -2015,7 +2200,7 @@ provider:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
"model": "us.anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
|
||||
@@ -197,6 +197,18 @@ For Ollama, the corresponding `type` is `ollama`. Its unique configuration field
|
||||
| `ollamaServerHost` | string | Required | - | The host address of the Ollama server. |
|
||||
| `ollamaServerPort` | number | Required | - | The port number of the Ollama server, defaults to 11434. |
|
||||
|
||||
#### Generic
|
||||
|
||||
For a vendor-agnostic passthrough, set the provider `type` to `generic`. Requests are forwarded without path remapping, while still benefiting from the shared header/basePath utilities.
|
||||
|
||||
| Name | Data Type | Requirement | Default | Description |
|
||||
|----------------|-----------|-------------|---------|----------------------------------------------------------------------------------------------------------|
|
||||
| `genericHost` | string | Optional | - | Overrides the upstream `Host` header. Use it to route traffic to a specific backend domain for generic proxying. |
|
||||
|
||||
- When `apiTokens` are configured, the Generic provider injects `Authorization: Bearer <token>` automatically.
|
||||
- `firstByteTimeout` applies to any request whose body sets `stream: true`, ensuring consistent streaming behavior even without capability definitions.
|
||||
- `basePath` and `basePathHandling` remain available to strip or prepend prefixes before forwarding.
|
||||
|
||||
#### Hunyuan
|
||||
|
||||
For Hunyuan, the corresponding `type` is `hunyuan`. Its unique configuration fields are:
|
||||
@@ -243,7 +255,9 @@ For DeepL, the corresponding `type` is `deepl`. Its unique configuration field i
|
||||
| `targetLang` | string | Required | - | The target language required by the DeepL translation service |
|
||||
|
||||
#### Google Vertex AI
|
||||
For Vertex, the corresponding `type` is `vertex`. Its unique configuration field is:
|
||||
For Vertex, the corresponding `type` is `vertex`. It supports two authentication modes:
|
||||
|
||||
**Standard Mode** (using Service Account):
|
||||
|
||||
| Name | Data Type | Requirement | Default | Description |
|
||||
|-----------------------------|---------------|---------------| ------ |-------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
@@ -254,16 +268,47 @@ For Vertex, the corresponding `type` is `vertex`. Its unique configuration field
|
||||
| `vertexGeminiSafetySetting` | map of string | Optional | - | Gemini model content safety filtering settings. |
|
||||
| `vertexTokenRefreshAhead` | number | Optional | - | Vertex access token refresh ahead time in seconds |
|
||||
|
||||
**Express Mode** (using API Key, simplified configuration):
|
||||
|
||||
Express Mode is a simplified access mode introduced by Vertex AI. You can quickly get started with just an API Key, without configuring a Service Account. See [Vertex AI Express Mode documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview).
|
||||
|
||||
| Name | Data Type | Requirement | Default | Description |
|
||||
|-----------------------------|------------------|---------------| ------ |-------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `apiTokens` | array of string | Required | - | API Key for Express Mode, obtained from Google Cloud Console under API & Services > Credentials |
|
||||
| `vertexGeminiSafetySetting` | map of string | Optional | - | Gemini model content safety filtering settings. |
|
||||
|
||||
**OpenAI Compatible Mode** (using Vertex AI Chat Completions API):
|
||||
|
||||
Vertex AI provides an OpenAI-compatible Chat Completions API endpoint, allowing you to use OpenAI format requests and responses directly without protocol conversion. See [Vertex AI OpenAI Compatibility documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/migrate/openai/overview).
|
||||
|
||||
| Name | Data Type | Requirement | Default | Description |
|
||||
|-----------------------------|------------------|---------------| ------ |-------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `vertexOpenAICompatible` | boolean | Optional | false | Enable OpenAI compatible mode. When enabled, uses Vertex AI's OpenAI-compatible Chat Completions API |
|
||||
| `vertexAuthKey` | string | Required | - | Google Service Account JSON Key for authentication |
|
||||
| `vertexRegion` | string | Required | - | Google Cloud region (e.g., us-central1, europe-west4) |
|
||||
| `vertexProjectId` | string | Required | - | Google Cloud Project ID |
|
||||
| `vertexAuthServiceName` | string | Required | - | Service name for OAuth2 authentication |
|
||||
|
||||
**Note**: OpenAI Compatible Mode and Express Mode are mutually exclusive. You cannot configure both `apiTokens` and `vertexOpenAICompatible` at the same time.
|
||||
|
||||
#### AWS Bedrock
|
||||
|
||||
For AWS Bedrock, the corresponding `type` is `bedrock`. Its unique configuration field is:
|
||||
For AWS Bedrock, the corresponding `type` is `bedrock`. It supports two authentication methods:
|
||||
|
||||
| Name | Data Type | Requirement | Default | Description |
|
||||
|---------------------------|-----------|-------------|---------|---------------------------------------------------------|
|
||||
| `awsAccessKey` | string | Required | - | AWS Access Key used for authentication |
|
||||
| `awsSecretKey` | string | Required | - | AWS Secret Access Key used for authentication |
|
||||
| `awsRegion` | string | Required | - | AWS region, e.g., us-east-1 |
|
||||
| `bedrockAdditionalFields` | map | Optional | - | Additional inference parameters that the model supports |
|
||||
1. **AWS Signature V4 Authentication**: Uses `awsAccessKey` and `awsSecretKey` for standard AWS signature authentication
|
||||
2. **Bearer Token Authentication**: Uses `apiTokens` to configure AWS Bearer Token (suitable for IAM Identity Center and similar scenarios)
|
||||
|
||||
**Note**: Choose one of the two authentication methods. If `apiTokens` is configured, Bearer Token authentication will be used preferentially.
|
||||
|
||||
Its unique configuration fields are:
|
||||
|
||||
| Name | Data Type | Requirement | Default | Description |
|
||||
|---------------------------|-----------------|--------------------------|---------|-------------------------------------------------------------------|
|
||||
| `apiTokens` | array of string | Either this or ak/sk | - | AWS Bearer Token for Bearer Token authentication |
|
||||
| `awsAccessKey` | string | Either this or apiTokens | - | AWS Access Key for AWS Signature V4 authentication |
|
||||
| `awsSecretKey` | string | Either this or apiTokens | - | AWS Secret Access Key for AWS Signature V4 authentication |
|
||||
| `awsRegion` | string | Required | - | AWS region, e.g., us-east-1 |
|
||||
| `bedrockAdditionalFields` | map | Optional | - | Additional inference parameters that the model supports |
|
||||
|
||||
## Usage Examples
|
||||
|
||||
@@ -1708,7 +1753,7 @@ provider:
|
||||
}
|
||||
```
|
||||
|
||||
### Utilizing OpenAI Protocol Proxy for Google Vertex Services
|
||||
### Utilizing OpenAI Protocol Proxy for Google Vertex Services (Standard Mode)
|
||||
**Configuration Information**
|
||||
```yaml
|
||||
provider:
|
||||
@@ -1766,14 +1811,148 @@ provider:
|
||||
}
|
||||
```
|
||||
|
||||
### Utilizing OpenAI Protocol Proxy for Google Vertex Services (Express Mode)
|
||||
|
||||
Express Mode is a simplified access mode for Vertex AI. You only need an API Key to get started quickly.
|
||||
|
||||
**Configuration Information**
|
||||
```yaml
|
||||
provider:
|
||||
type: vertex
|
||||
apiTokens:
|
||||
- "YOUR_API_KEY"
|
||||
```
|
||||
|
||||
**Request Example**
|
||||
```json
|
||||
{
|
||||
"model": "gemini-2.5-flash",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Who are you?"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
**Response Example**
|
||||
```json
|
||||
{
|
||||
"id": "chatcmpl-0000000000000",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello! I am Gemini, an AI assistant developed by Google. How can I help you today?"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"created": 1729986750,
|
||||
"model": "gemini-2.5-flash",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 25,
|
||||
"total_tokens": 35
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Utilizing OpenAI Protocol Proxy for Google Vertex Services (OpenAI Compatible Mode)
|
||||
|
||||
OpenAI Compatible Mode uses Vertex AI's OpenAI-compatible Chat Completions API. Both requests and responses use OpenAI format, requiring no protocol conversion.
|
||||
|
||||
**Configuration Information**
|
||||
```yaml
|
||||
provider:
|
||||
type: vertex
|
||||
vertexOpenAICompatible: true
|
||||
vertexAuthKey: |
|
||||
{
|
||||
"type": "service_account",
|
||||
"project_id": "your-project-id",
|
||||
"private_key_id": "your-private-key-id",
|
||||
"private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n",
|
||||
"client_email": "your-service-account@your-project.iam.gserviceaccount.com",
|
||||
"token_uri": "https://oauth2.googleapis.com/token"
|
||||
}
|
||||
vertexRegion: us-central1
|
||||
vertexProjectId: your-project-id
|
||||
vertexAuthServiceName: your-auth-service-name
|
||||
modelMapping:
|
||||
"gpt-4": "gemini-2.0-flash"
|
||||
"*": "gemini-1.5-flash"
|
||||
```
|
||||
|
||||
**Request Example**
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, who are you?"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
**Response Example**
|
||||
```json
|
||||
{
|
||||
"id": "chatcmpl-abc123",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello! I am Gemini, an AI model developed by Google. I can help answer questions, provide information, and engage in conversations. How can I assist you today?"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"created": 1729986750,
|
||||
"model": "gemini-2.0-flash",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"prompt_tokens": 12,
|
||||
"completion_tokens": 35,
|
||||
"total_tokens": 47
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Utilizing OpenAI Protocol Proxy for AWS Bedrock Services
|
||||
|
||||
AWS Bedrock supports two authentication methods:
|
||||
|
||||
#### Method 1: Using AWS Access Key/Secret Key Authentication (AWS Signature V4)
|
||||
|
||||
**Configuration Information**
|
||||
```yaml
|
||||
provider:
|
||||
type: bedrock
|
||||
awsAccessKey: "YOUR_AWS_ACCESS_KEY_ID"
|
||||
awsSecretKey: "YOUR_AWS_SECRET_ACCESS_KEY"
|
||||
awsRegion: "YOUR_AWS_REGION"
|
||||
awsRegion: "us-east-1"
|
||||
bedrockAdditionalFields:
|
||||
top_k: 200
|
||||
```
|
||||
|
||||
#### Method 2: Using Bearer Token Authentication (suitable for IAM Identity Center and similar scenarios)
|
||||
|
||||
**Configuration Information**
|
||||
```yaml
|
||||
provider:
|
||||
type: bedrock
|
||||
apiTokens:
|
||||
- "YOUR_AWS_BEARER_TOKEN"
|
||||
awsRegion: "us-east-1"
|
||||
bedrockAdditionalFields:
|
||||
top_k: 200
|
||||
```
|
||||
@@ -1781,7 +1960,7 @@ provider:
|
||||
**Request Example**
|
||||
```json
|
||||
{
|
||||
"model": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
"model": "us.anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
|
||||
@@ -7,8 +7,8 @@ go 1.24.1
|
||||
toolchain go1.24.4
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
|
||||
github.com/higress-group/wasm-go v1.0.3-0.20251011083635-792cb1547bac
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
|
||||
github.com/higress-group/wasm-go v1.0.9-0.20251226032831-95da539a1ec7
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
)
|
||||
|
||||
@@ -2,10 +2,12 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.3-0.20251011083635-792cb1547bac h1:tdJzS56Xa6BSHAi9P2omvb98bpI8qFGg6jnCPtPmDgA=
|
||||
github.com/higress-group/wasm-go v1.0.3-0.20251011083635-792cb1547bac/go.mod h1:B8C6+OlpnyYyZUBEdUXA7tYZYD+uwZTNjfkE5FywA+A=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c h1:DdVPyaMHSYBqO5jwB9Wl3PqsBGIf4u29BHMI0uIVB1Y=
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251209122854-7e766df5675c/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
|
||||
github.com/higress-group/wasm-go v1.0.9-0.20251226032831-95da539a1ec7 h1:ddAkPFIIf6isVQNymS6+X6QO51/WV0Af4Afb9a2z9TE=
|
||||
github.com/higress-group/wasm-go v1.0.9-0.20251226032831-95da539a1ec7/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
|
||||
@@ -102,6 +102,7 @@ func init() {
|
||||
wrapper.ProcessStreamingResponseBody(onStreamingResponseBody),
|
||||
wrapper.ProcessResponseBody(onHttpResponseBody),
|
||||
wrapper.WithRebuildAfterRequests[config.PluginConfig](1000),
|
||||
wrapper.WithRebuildMaxMemBytes[config.PluginConfig](200*1024*1024),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -145,7 +146,7 @@ func initContext(ctx wrapper.HttpContext) {
|
||||
ctx.SetContext(ctxKey, value)
|
||||
}
|
||||
for _, originHeader := range headerToOriginalHeaderMapping {
|
||||
proxywasm.RemoveHttpRequestHeader(originHeader)
|
||||
_ = proxywasm.RemoveHttpRequestHeader(originHeader)
|
||||
}
|
||||
originalAuth, _ := proxywasm.GetHttpRequestHeader(util.HeaderOriginalAuth)
|
||||
if originalAuth == "" {
|
||||
@@ -204,23 +205,24 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
||||
if handler, ok := activeProvider.(provider.ApiNameHandler); ok {
|
||||
apiName = handler.GetApiName(path.Path)
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-detect protocol based on request path and handle conversion if needed
|
||||
// If request is Claude format (/v1/messages) but provider doesn't support it natively,
|
||||
// convert to OpenAI format (/v1/chat/completions)
|
||||
if apiName == provider.ApiNameAnthropicMessages && !providerConfig.IsSupportedAPI(provider.ApiNameAnthropicMessages) {
|
||||
// Provider doesn't support Claude protocol natively, convert to OpenAI format
|
||||
newPath := strings.Replace(path.Path, provider.PathAnthropicMessages, provider.PathOpenAIChatCompletions, 1)
|
||||
_ = proxywasm.ReplaceHttpRequestHeader(":path", newPath)
|
||||
// Update apiName to match the new path
|
||||
apiName = provider.ApiNameChatCompletion
|
||||
// Mark that we need to convert response back to Claude format
|
||||
ctx.SetContext("needClaudeResponseConversion", true)
|
||||
log.Debugf("[Auto Protocol] Claude request detected, provider doesn't support natively, converted path from %s to %s, apiName: %s", path.Path, newPath, apiName)
|
||||
} else if apiName == provider.ApiNameAnthropicMessages {
|
||||
// Provider supports Claude protocol natively, no conversion needed
|
||||
log.Debugf("[Auto Protocol] Claude request detected, provider supports natively, keeping original path: %s, apiName: %s", path.Path, apiName)
|
||||
} else {
|
||||
// Only perform protocol conversion for non-original protocols.
|
||||
// Auto-detect protocol based on request path and handle conversion if needed
|
||||
// If request is Claude format (/v1/messages) but provider doesn't support it natively,
|
||||
// convert to OpenAI format (/v1/chat/completions)
|
||||
if apiName == provider.ApiNameAnthropicMessages && !providerConfig.IsSupportedAPI(provider.ApiNameAnthropicMessages) {
|
||||
// Provider doesn't support Claude protocol natively, convert to OpenAI format
|
||||
newPath := strings.Replace(path.Path, provider.PathAnthropicMessages, provider.PathOpenAIChatCompletions, 1)
|
||||
_ = proxywasm.ReplaceHttpRequestHeader(":path", newPath)
|
||||
// Update apiName to match the new path
|
||||
apiName = provider.ApiNameChatCompletion
|
||||
// Mark that we need to convert response back to Claude format
|
||||
ctx.SetContext("needClaudeResponseConversion", true)
|
||||
log.Debugf("[Auto Protocol] Claude request detected, provider doesn't support natively, converted path from %s to %s, apiName: %s", path.Path, newPath, apiName)
|
||||
} else if apiName == provider.ApiNameAnthropicMessages {
|
||||
// Provider supports Claude protocol natively, no conversion needed
|
||||
log.Debugf("[Auto Protocol] Claude request detected, provider supports natively, keeping original path: %s, apiName: %s", path.Path, apiName)
|
||||
}
|
||||
}
|
||||
|
||||
if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !strings.Contains(contentType, util.MimeTypeApplicationJson) {
|
||||
@@ -247,8 +249,8 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
||||
providerConfig.SetAvailableApiTokens(ctx)
|
||||
|
||||
// save the original request host and path in case they are needed for apiToken health check and retry
|
||||
ctx.SetContext(provider.CtxRequestHost, wrapper.GetRequestHost())
|
||||
ctx.SetContext(provider.CtxRequestPath, wrapper.GetRequestPath())
|
||||
ctx.SetContext(provider.CtxRequestHost, ctx.Host())
|
||||
ctx.SetContext(provider.CtxRequestPath, ctx.Path())
|
||||
|
||||
err := handler.OnRequestHeaders(ctx, apiName)
|
||||
if err != nil {
|
||||
@@ -256,7 +258,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
hasRequestBody := wrapper.HasRequestBody()
|
||||
hasRequestBody := ctx.HasRequestBody()
|
||||
if hasRequestBody {
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
|
||||
|
||||
@@ -27,6 +27,10 @@ func Test_getApiName(t *testing.T) {
|
||||
{"openai files", "/v1/files", provider.ApiNameFiles},
|
||||
{"openai retrieve file", "/v1/files/fileid", provider.ApiNameRetrieveFile},
|
||||
{"openai retrieve file content", "/v1/files/fileid/content", provider.ApiNameRetrieveFileContent},
|
||||
{"openai videos", "/v1/videos", provider.ApiNameVideos},
|
||||
{"openai retrieve video", "/v1/videos/videoid", provider.ApiNameRetrieveVideo},
|
||||
{"openai retrieve video content", "/v1/videos/videoid/content", provider.ApiNameRetrieveVideoContent},
|
||||
{"openai video remix", "/v1/videos/videoid/remix", provider.ApiNameVideoRemix},
|
||||
{"openai models", "/v1/models", provider.ApiNameModels},
|
||||
{"openai fine tuning jobs", "/v1/fine_tuning/jobs", provider.ApiNameFineTuningJobs},
|
||||
{"openai retrieve fine tuning job", "/v1/fine_tuning/jobs/jobid", provider.ApiNameRetrieveFineTuningJob},
|
||||
@@ -102,6 +106,7 @@ func TestAzure(t *testing.T) {
|
||||
test.RunAzureOnHttpRequestBodyTests(t)
|
||||
test.RunAzureOnHttpResponseHeadersTests(t)
|
||||
test.RunAzureOnHttpResponseBodyTests(t)
|
||||
test.RunAzureBasePathHandlingTests(t)
|
||||
}
|
||||
|
||||
func TestFireworks(t *testing.T) {
|
||||
@@ -109,3 +114,33 @@ func TestFireworks(t *testing.T) {
|
||||
test.RunFireworksOnHttpRequestHeadersTests(t)
|
||||
test.RunFireworksOnHttpRequestBodyTests(t)
|
||||
}
|
||||
|
||||
func TestMinimax(t *testing.T) {
|
||||
test.RunMinimaxBasePathHandlingTests(t)
|
||||
}
|
||||
|
||||
func TestUtil(t *testing.T) {
|
||||
test.RunMapRequestPathByCapabilityTests(t)
|
||||
}
|
||||
|
||||
func TestGeneric(t *testing.T) {
|
||||
test.RunGenericParseConfigTests(t)
|
||||
test.RunGenericOnHttpRequestHeadersTests(t)
|
||||
test.RunGenericOnHttpRequestBodyTests(t)
|
||||
}
|
||||
|
||||
func TestVertex(t *testing.T) {
|
||||
test.RunVertexParseConfigTests(t)
|
||||
test.RunVertexExpressModeOnHttpRequestHeadersTests(t)
|
||||
test.RunVertexExpressModeOnHttpRequestBodyTests(t)
|
||||
test.RunVertexExpressModeOnHttpResponseBodyTests(t)
|
||||
test.RunVertexExpressModeOnStreamingResponseBodyTests(t)
|
||||
}
|
||||
|
||||
func TestBedrock(t *testing.T) {
|
||||
test.RunBedrockParseConfigTests(t)
|
||||
test.RunBedrockOnHttpRequestHeadersTests(t)
|
||||
test.RunBedrockOnHttpRequestBodyTests(t)
|
||||
test.RunBedrockOnHttpResponseHeadersTests(t)
|
||||
test.RunBedrockOnHttpResponseBodyTests(t)
|
||||
}
|
||||
|
||||
@@ -174,12 +174,14 @@ func (m *azureProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName Ap
|
||||
}
|
||||
|
||||
func (m *azureProvider) transformRequestPath(ctx wrapper.HttpContext, apiName ApiName) string {
|
||||
originalPath := util.GetOriginalRequestPath()
|
||||
|
||||
// When using original protocol, don't overwrite the path.
|
||||
// This ensures basePathHandling works correctly even in TransformRequestBody stage.
|
||||
if m.config.IsOriginal() {
|
||||
return originalPath
|
||||
return ""
|
||||
}
|
||||
|
||||
originalPath := util.GetOriginalRequestPath()
|
||||
|
||||
if m.serviceUrlType == azureServiceUrlTypeFull {
|
||||
log.Debugf("azureProvider: use configured path %s", m.serviceUrlFullPath)
|
||||
return m.serviceUrlFullPath
|
||||
|
||||
@@ -43,8 +43,11 @@ const (
|
||||
type bedrockProviderInitializer struct{}
|
||||
|
||||
func (b *bedrockProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
if len(config.awsAccessKey) == 0 || len(config.awsSecretKey) == 0 {
|
||||
return errors.New("missing bedrock access authentication parameters")
|
||||
hasAkSk := len(config.awsAccessKey) > 0 && len(config.awsSecretKey) > 0
|
||||
hasApiToken := len(config.apiTokens) > 0
|
||||
|
||||
if !hasAkSk && !hasApiToken {
|
||||
return errors.New("missing bedrock access authentication parameters: either apiTokens or (awsAccessKey + awsSecretKey) is required")
|
||||
}
|
||||
if len(config.awsRegion) == 0 {
|
||||
return errors.New("missing bedrock region parameters")
|
||||
@@ -107,9 +110,8 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
|
||||
chatChoice.Delta.Content = nil
|
||||
chatChoice.Delta.ToolCalls = []toolCall{
|
||||
{
|
||||
Id: bedrockEvent.Start.ToolUse.ToolUseID,
|
||||
Index: 0,
|
||||
Type: "function",
|
||||
Id: bedrockEvent.Start.ToolUse.ToolUseID,
|
||||
Type: "function",
|
||||
Function: functionCall{
|
||||
Name: bedrockEvent.Start.ToolUse.Name,
|
||||
Arguments: "",
|
||||
@@ -138,8 +140,7 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
|
||||
if bedrockEvent.Delta.ToolUse != nil {
|
||||
chatChoice.Delta.ToolCalls = []toolCall{
|
||||
{
|
||||
Index: 0,
|
||||
Type: "function",
|
||||
Type: "function",
|
||||
Function: functionCall{
|
||||
Arguments: bedrockEvent.Delta.ToolUse.Input,
|
||||
},
|
||||
@@ -168,7 +169,6 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
|
||||
TotalTokens: bedrockEvent.Usage.TotalTokens,
|
||||
}
|
||||
}
|
||||
|
||||
openAIFormattedChunkBytes, _ := json.Marshal(openAIFormattedChunk)
|
||||
var openAIChunk strings.Builder
|
||||
openAIChunk.WriteString(ssePrefix)
|
||||
@@ -637,6 +637,13 @@ func (b *bedrockProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiN
|
||||
|
||||
func (b *bedrockProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestHostHeader(headers, fmt.Sprintf(bedrockDefaultDomain, b.config.awsRegion))
|
||||
|
||||
// If apiTokens is configured, set Bearer token authentication here
|
||||
// This follows the same pattern as other providers (qwen, zhipuai, etc.)
|
||||
// AWS SigV4 authentication is handled in setAuthHeaders because it requires the request body
|
||||
if len(b.config.apiTokens) > 0 {
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+b.config.GetApiTokenInUse(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
@@ -662,18 +669,18 @@ func (b *bedrockProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName
|
||||
case ApiNameChatCompletion:
|
||||
return b.onChatCompletionResponseBody(ctx, body)
|
||||
case ApiNameImageGeneration:
|
||||
return b.onImageGenerationResponseBody(ctx, body)
|
||||
return b.onImageGenerationResponseBody(body)
|
||||
}
|
||||
return nil, errUnsupportedApiName
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) onImageGenerationResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
||||
func (b *bedrockProvider) onImageGenerationResponseBody(body []byte) ([]byte, error) {
|
||||
bedrockResponse := &bedrockImageGenerationResponse{}
|
||||
if err := json.Unmarshal(body, bedrockResponse); err != nil {
|
||||
log.Errorf("unable to unmarshal bedrock image gerneration response: %v", err)
|
||||
return nil, fmt.Errorf("unable to unmarshal bedrock image generation response: %v", err)
|
||||
}
|
||||
response := b.buildBedrockImageGenerationResponse(ctx, bedrockResponse)
|
||||
response := b.buildBedrockImageGenerationResponse(bedrockResponse)
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
@@ -713,7 +720,7 @@ func (b *bedrockProvider) buildBedrockImageGenerationRequest(origRequest *imageG
|
||||
return requestBytes, err
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) buildBedrockImageGenerationResponse(ctx wrapper.HttpContext, bedrockResponse *bedrockImageGenerationResponse) *imageGenerationResponse {
|
||||
func (b *bedrockProvider) buildBedrockImageGenerationResponse(bedrockResponse *bedrockImageGenerationResponse) *imageGenerationResponse {
|
||||
data := make([]imageGenerationData, len(bedrockResponse.Images))
|
||||
for i, image := range bedrockResponse.Images {
|
||||
data[i] = imageGenerationData{
|
||||
@@ -1062,17 +1069,19 @@ func chatToolMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
|
||||
Text: text,
|
||||
},
|
||||
}
|
||||
openaiContent := chatMessage.ParseContent()
|
||||
for _, part := range openaiContent {
|
||||
var content bedrockMessageContent
|
||||
if part.Type == contentTypeText {
|
||||
content.Text = part.Text
|
||||
} else {
|
||||
continue
|
||||
} else if contentList, ok := chatMessage.Content.([]any); ok {
|
||||
for _, contentItem := range contentList {
|
||||
contentMap, ok := contentItem.(map[string]any)
|
||||
if ok && contentMap["type"] == contentTypeText {
|
||||
if text, ok := contentMap[contentTypeText].(string); ok {
|
||||
toolResultContent.Content = append(toolResultContent.Content, toolResultContentBlock{
|
||||
Text: text,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Warnf("only text content is supported, current content is %v", chatMessage.Content)
|
||||
log.Warnf("the content type is not supported, current content is %v", chatMessage.Content)
|
||||
}
|
||||
return bedrockMessage{
|
||||
Role: roleUser,
|
||||
@@ -1139,6 +1148,13 @@ func chatMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) setAuthHeaders(body []byte, headers http.Header) {
|
||||
// Bearer token authentication is already set in TransformRequestHeaders
|
||||
// This function only handles AWS SigV4 authentication which requires the request body
|
||||
if len(b.config.apiTokens) > 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Use AWS Signature V4 authentication
|
||||
t := time.Now().UTC()
|
||||
amzDate := t.Format("20060102T150405Z")
|
||||
dateStamp := t.Format("20060102")
|
||||
|
||||
@@ -203,19 +203,20 @@ type claudeThinkingConfig struct {
|
||||
}
|
||||
|
||||
type claudeTextGenRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []claudeChatMessage `json:"messages"`
|
||||
System *claudeSystemPrompt `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
ToolChoice *claudeToolChoice `json:"tool_choice,omitempty"`
|
||||
Tools []claudeTool `json:"tools,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Thinking *claudeThinkingConfig `json:"thinking,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []claudeChatMessage `json:"messages"`
|
||||
System *claudeSystemPrompt `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
ToolChoice *claudeToolChoice `json:"tool_choice,omitempty"`
|
||||
Tools []claudeTool `json:"tools,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Thinking *claudeThinkingConfig `json:"thinking,omitempty"`
|
||||
AnthropicVersion string `json:"anthropic_version,omitempty"`
|
||||
}
|
||||
|
||||
type claudeTextGenResponse struct {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
@@ -59,7 +60,18 @@ func (d *difyProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName
|
||||
func (d *difyProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
if d.config.difyApiUrl != "" {
|
||||
log.Debugf("use local host: %s", d.config.difyApiUrl)
|
||||
util.OverwriteRequestHostHeader(headers, d.config.difyApiUrl)
|
||||
// Extract hostname, including Full URL or Domain
|
||||
host := d.config.difyApiUrl
|
||||
if parsedUrl, err := url.Parse(d.config.difyApiUrl); err == nil && parsedUrl.Host != "" {
|
||||
host = parsedUrl.Host
|
||||
} else {
|
||||
host = strings.TrimPrefix(strings.TrimPrefix(d.config.difyApiUrl, "http://"), "https://")
|
||||
if idx := strings.Index(host, "/"); idx != -1 {
|
||||
host = host[:idx]
|
||||
}
|
||||
}
|
||||
log.Debugf("extracted hostname: %s", host)
|
||||
util.OverwriteRequestHostHeader(headers, host)
|
||||
} else {
|
||||
util.OverwriteRequestHostHeader(headers, difyDomain)
|
||||
}
|
||||
|
||||
@@ -70,7 +70,11 @@ func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
|
||||
func (m *doubaoProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, doubaoDomain)
|
||||
if m.config.doubaoDomain != "" {
|
||||
util.OverwriteRequestHostHeader(headers, m.config.doubaoDomain)
|
||||
} else {
|
||||
util.OverwriteRequestHostHeader(headers, doubaoDomain)
|
||||
}
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
85
plugins/wasm-go/extensions/ai-proxy/provider/generic.go
Normal file
85
plugins/wasm-go/extensions/ai-proxy/provider/generic.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
// genericProviderInitializer 用于创建一个不做能力映射的通用 Provider。
|
||||
type genericProviderInitializer struct{}
|
||||
|
||||
// ValidateConfig 通用 Provider 不需要额外的配置校验。
|
||||
func (m *genericProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DefaultCapabilities 返回空映射,表示不会做路径或能力重写。
|
||||
func (m *genericProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
// CreateProvider 创建 generic provider,并沿用通用的上下文缓存能力。
|
||||
func (m *genericProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
return &genericProvider{
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// genericProvider 只负责公共的头部、请求体处理逻辑,不绑定任何厂商。
|
||||
type genericProvider struct {
|
||||
config ProviderConfig
|
||||
}
|
||||
|
||||
func (m *genericProvider) GetProviderType() string {
|
||||
return providerTypeGeneric
|
||||
}
|
||||
|
||||
// OnRequestHeaders 复用通用的 handleRequestHeaders,并在配置首包超时时写入相关头部。
|
||||
func (m *genericProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName)
|
||||
if m.config.firstByteTimeout > 0 {
|
||||
ctx.SetContext(ctxKeyIsStreaming, true)
|
||||
m.applyFirstByteTimeout()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *genericProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
// TransformRequestHeaders 只处理鉴权与 Host 改写,不做路径重写。
|
||||
func (m *genericProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
if len(m.config.apiTokens) > 0 {
|
||||
if token := m.config.GetApiTokenInUse(ctx); token != "" {
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+token)
|
||||
}
|
||||
}
|
||||
if m.config.genericHost != "" {
|
||||
util.OverwriteRequestHostHeader(headers, m.config.genericHost)
|
||||
}
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
// applyFirstByteTimeout 在配置了 firstByteTimeout 时,为所有流式请求写入超时头。
|
||||
func (m *genericProvider) applyFirstByteTimeout() {
|
||||
if m.config.firstByteTimeout == 0 {
|
||||
return
|
||||
}
|
||||
err := proxywasm.ReplaceHttpRequestHeader(
|
||||
"x-envoy-upstream-rq-first-byte-timeout-ms",
|
||||
strconv.FormatUint(uint64(m.config.firstByteTimeout), 10),
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("generic provider: failed to set first byte timeout header: %v", err)
|
||||
return
|
||||
}
|
||||
log.Debugf("[generic][firstByteTimeout] %d", m.config.firstByteTimeout)
|
||||
}
|
||||
@@ -106,8 +106,11 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte) (typ
|
||||
}
|
||||
|
||||
// Map the model and rewrite the request path.
|
||||
// When using original protocol, don't overwrite the path to ensure basePathHandling works correctly.
|
||||
request.Model = getMappedModel(request.Model, m.config.modelMapping)
|
||||
_ = util.OverwriteRequestPath(fmt.Sprintf("%s?GroupId=%s", minimaxChatCompletionProPath, m.config.minimaxGroupId))
|
||||
if !m.config.IsOriginal() {
|
||||
_ = util.OverwriteRequestPath(fmt.Sprintf("%s?GroupId=%s", minimaxChatCompletionProPath, m.config.minimaxGroupId))
|
||||
}
|
||||
|
||||
if m.config.context == nil {
|
||||
minimaxRequest := m.buildMinimaxChatCompletionProRequest(request, "")
|
||||
|
||||
@@ -393,7 +393,7 @@ func (m *chatMessage) ParseContent() []chatMessageContent {
|
||||
}
|
||||
|
||||
type toolCall struct {
|
||||
Index int `json:"index,omitempty"`
|
||||
Index int `json:"index"`
|
||||
Id string `json:"id,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Function functionCall `json:"function"`
|
||||
|
||||
@@ -144,6 +144,7 @@ const (
|
||||
providerTypeLongcat = "longcat"
|
||||
providerTypeFireworks = "fireworks"
|
||||
providerTypeVllm = "vllm"
|
||||
providerTypeGeneric = "generic"
|
||||
|
||||
protocolOpenAI = "openai"
|
||||
protocolOriginal = "original"
|
||||
@@ -227,6 +228,7 @@ var (
|
||||
providerTypeLongcat: &longcatProviderInitializer{},
|
||||
providerTypeFireworks: &fireworksProviderInitializer{},
|
||||
providerTypeVllm: &vllmProviderInitializer{},
|
||||
providerTypeGeneric: &genericProviderInitializer{},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -385,6 +387,9 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN Vertex token刷新提前时间
|
||||
// @Description zh-CN 用于Google服务账号认证,access token过期时间判定提前刷新,单位为秒,默认值为60秒
|
||||
vertexTokenRefreshAhead int64 `required:"false" yaml:"vertexTokenRefreshAhead" json:"vertexTokenRefreshAhead"`
|
||||
// @Title zh-CN Vertex AI OpenAI兼容模式
|
||||
// @Description zh-CN 启用后将使用Vertex AI的OpenAI兼容API,请求和响应均使用OpenAI格式,无需协议转换。与Express Mode(apiTokens)互斥。
|
||||
vertexOpenAICompatible bool `required:"false" yaml:"vertexOpenAICompatible" json:"vertexOpenAICompatible"`
|
||||
// @Title zh-CN 翻译服务需指定的目标语种
|
||||
// @Description zh-CN 翻译结果的语种,目前仅适用于DeepL服务。
|
||||
targetLang string `required:"false" yaml:"targetLang" json:"targetLang"`
|
||||
@@ -409,6 +414,9 @@ type ProviderConfig struct {
|
||||
basePath string `required:"false" yaml:"basePath" json:"basePath"`
|
||||
// @Title zh-CN basePathHandling用于指定basePath的处理方式,可选值:removePrefix、prepend
|
||||
basePathHandling basePathHandling `required:"false" yaml:"basePathHandling" json:"basePathHandling"`
|
||||
// @Title zh-CN generic Provider 对应的Host
|
||||
// @Description zh-CN 仅适用于generic provider,用于覆盖请求转发的目标Host
|
||||
genericHost string `required:"false" yaml:"genericHost" json:"genericHost"`
|
||||
// @Title zh-CN 首包超时
|
||||
// @Description zh-CN 流式请求中收到上游服务第一个响应包的超时时间,单位为毫秒。默认值为 0,表示不开启首包超时
|
||||
firstByteTimeout uint32 `required:"false" yaml:"firstByteTimeout" json:"firstByteTimeout"`
|
||||
@@ -424,6 +432,9 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN vLLM主机地址
|
||||
// @Description zh-CN 仅适用于vLLM服务,指定vLLM服务器的主机地址,例如:vllm-service.cluster.local
|
||||
vllmServerHost string `required:"false" yaml:"vllmServerHost" json:"vllmServerHost"`
|
||||
// @Title zh-CN 豆包服务域名
|
||||
// @Description zh-CN 仅适用于豆包服务,默认转发域名为 ark.cn-beijing.volces.com
|
||||
doubaoDomain string `required:"false" yaml:"doubaoDomain" json:"doubaoDomain"`
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetId() string {
|
||||
@@ -532,6 +543,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
if c.vertexTokenRefreshAhead == 0 {
|
||||
c.vertexTokenRefreshAhead = 60
|
||||
}
|
||||
c.vertexOpenAICompatible = json.Get("vertexOpenAICompatible").Bool()
|
||||
c.targetLang = json.Get("targetLang").String()
|
||||
|
||||
if schemaValue, ok := json.Get("responseJsonSchema").Value().(map[string]interface{}); ok {
|
||||
@@ -619,8 +631,10 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
if c.basePath != "" && c.basePathHandling == "" {
|
||||
c.basePathHandling = basePathHandlingRemovePrefix
|
||||
}
|
||||
c.genericHost = json.Get("genericHost").String()
|
||||
c.vllmServerHost = json.Get("vllmServerHost").String()
|
||||
c.vllmCustomUrl = json.Get("vllmCustomUrl").String()
|
||||
c.doubaoDomain = json.Get("doubaoDomain").String()
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) Validate() error {
|
||||
@@ -963,12 +977,33 @@ func (c *ProviderConfig) handleRequestBody(
|
||||
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName) {
|
||||
headers := util.GetRequestHeaders()
|
||||
originPath := headers.Get(":path")
|
||||
|
||||
// Record the path after removePrefix processing
|
||||
var removePrefixPath string
|
||||
if c.basePath != "" && c.basePathHandling == basePathHandlingRemovePrefix {
|
||||
headers.Set(":path", strings.TrimPrefix(originPath, c.basePath))
|
||||
removePrefixPath = strings.TrimPrefix(originPath, c.basePath)
|
||||
headers.Set(":path", removePrefixPath)
|
||||
}
|
||||
|
||||
if handler, ok := provider.(TransformRequestHeadersHandler); ok {
|
||||
handler.TransformRequestHeaders(ctx, apiName, headers)
|
||||
}
|
||||
|
||||
// When using original protocol with removePrefix, restore the basePath-processed path.
|
||||
// This ensures basePathHandling works correctly even when TransformRequestHeaders
|
||||
// overwrites the path (which most providers do).
|
||||
//
|
||||
// TODO: Most providers (OpenAI, vLLM, DeepSeek, Claude, etc.) unconditionally overwrite
|
||||
// the path in TransformRequestHeaders without checking IsOriginal(). Ideally, each provider
|
||||
// should check IsOriginal() before overwriting the path (like Qwen does). Once all providers
|
||||
// are updated to handle protocol correctly, this workaround can be removed.
|
||||
// Affected providers: OpenAI, vLLM, ZhipuAI, Moonshot, Longcat, DeepSeek, Azure, Yi,
|
||||
// TogetherAI, Stepfun, Ollama, Hunyuan, GitHub, Doubao, Cohere, Baichuan, AI360, Claude,
|
||||
// Groq, Grok, Spark, Fireworks, Cloudflare, Baidu, OpenRouter, DeepL (24+ providers)
|
||||
if c.IsOriginal() && removePrefixPath != "" {
|
||||
headers.Set(":path", removePrefixPath)
|
||||
}
|
||||
|
||||
if c.basePath != "" && c.basePathHandling == basePathHandlingPrepend && !strings.HasPrefix(headers.Get(":path"), c.basePath) {
|
||||
headers.Set(":path", path.Join(c.basePath, headers.Get(":path")))
|
||||
}
|
||||
|
||||
@@ -21,21 +21,60 @@ import (
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
vertexAuthDomain = "oauth2.googleapis.com"
|
||||
vertexDomain = "{REGION}-aiplatform.googleapis.com"
|
||||
vertexDomain = "aiplatform.googleapis.com"
|
||||
// /v1/projects/{PROJECT_ID}/locations/{REGION}/publishers/google/models/{MODEL_ID}:{ACTION}
|
||||
vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s"
|
||||
vertexChatCompletionAction = "generateContent"
|
||||
vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse"
|
||||
vertexEmbeddingAction = "predict"
|
||||
vertexPathTemplate = "/v1/projects/%s/locations/%s/publishers/google/models/%s:%s"
|
||||
vertexPathAnthropicTemplate = "/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s"
|
||||
// Express Mode 路径模板 (不含 project/location)
|
||||
vertexExpressPathTemplate = "/v1/publishers/google/models/%s:%s"
|
||||
vertexExpressPathAnthropicTemplate = "/v1/publishers/anthropic/models/%s:%s"
|
||||
// OpenAI-compatible endpoint 路径模板
|
||||
// /v1beta1/projects/{PROJECT_ID}/locations/{LOCATION}/endpoints/openapi/chat/completions
|
||||
vertexOpenAICompatiblePathTemplate = "/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions"
|
||||
vertexChatCompletionAction = "generateContent"
|
||||
vertexChatCompletionStreamAction = "streamGenerateContent?alt=sse"
|
||||
vertexAnthropicMessageAction = "rawPredict"
|
||||
vertexAnthropicMessageStreamAction = "streamRawPredict"
|
||||
vertexEmbeddingAction = "predict"
|
||||
vertexGlobalRegion = "global"
|
||||
contextClaudeMarker = "isClaudeRequest"
|
||||
contextOpenAICompatibleMarker = "isOpenAICompatibleRequest"
|
||||
vertexAnthropicVersion = "vertex-2023-10-16"
|
||||
)
|
||||
|
||||
type vertexProviderInitializer struct{}
|
||||
|
||||
func (v *vertexProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
// Express Mode: 如果配置了 apiTokens,则使用 API Key 认证
|
||||
if len(config.apiTokens) > 0 {
|
||||
// Express Mode 与 OpenAI 兼容模式互斥
|
||||
if config.vertexOpenAICompatible {
|
||||
return errors.New("vertexOpenAICompatible is not compatible with Express Mode (apiTokens)")
|
||||
}
|
||||
// Express Mode 不需要其他配置
|
||||
return nil
|
||||
}
|
||||
|
||||
// OpenAI 兼容模式: 需要 OAuth 认证配置
|
||||
if config.vertexOpenAICompatible {
|
||||
if config.vertexAuthKey == "" {
|
||||
return errors.New("missing vertexAuthKey in vertex provider config for OpenAI compatible mode")
|
||||
}
|
||||
if config.vertexRegion == "" || config.vertexProjectId == "" {
|
||||
return errors.New("missing vertexRegion or vertexProjectId in vertex provider config for OpenAI compatible mode")
|
||||
}
|
||||
if config.vertexAuthServiceName == "" {
|
||||
return errors.New("missing vertexAuthServiceName in vertex provider config for OpenAI compatible mode")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 标准模式: 保持原有验证逻辑
|
||||
if config.vertexAuthKey == "" {
|
||||
return errors.New("missing vertexAuthKey in vertex provider config")
|
||||
}
|
||||
@@ -57,21 +96,45 @@ func (v *vertexProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
|
||||
func (v *vertexProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(v.DefaultCapabilities())
|
||||
return &vertexProvider{
|
||||
config: config,
|
||||
client: wrapper.NewClusterClient(wrapper.DnsCluster{
|
||||
|
||||
provider := &vertexProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
claude: &claudeProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
},
|
||||
}
|
||||
|
||||
// 仅标准模式需要 OAuth 客户端(Express Mode 通过 apiTokens 配置)
|
||||
if !provider.isExpressMode() {
|
||||
provider.client = wrapper.NewClusterClient(wrapper.DnsCluster{
|
||||
Domain: vertexAuthDomain,
|
||||
ServiceName: config.vertexAuthServiceName,
|
||||
Port: 443,
|
||||
}),
|
||||
contextCache: createContextCache(&config),
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// isExpressMode 检测是否启用 Express Mode
|
||||
// 如果配置了 apiTokens,则使用 Express Mode(API Key 认证)
|
||||
func (v *vertexProvider) isExpressMode() bool {
|
||||
return len(v.config.apiTokens) > 0
|
||||
}
|
||||
|
||||
// isOpenAICompatibleMode 检测是否启用 OpenAI 兼容模式
|
||||
// 使用 Vertex AI 的 OpenAI-compatible Chat Completions API
|
||||
func (v *vertexProvider) isOpenAICompatibleMode() bool {
|
||||
return v.config.vertexOpenAICompatible
|
||||
}
|
||||
|
||||
type vertexProvider struct {
|
||||
client wrapper.HttpClient
|
||||
config ProviderConfig
|
||||
contextCache *contextCache
|
||||
claude *claudeProvider
|
||||
}
|
||||
|
||||
func (v *vertexProvider) GetProviderType() string {
|
||||
@@ -94,8 +157,21 @@ func (v *vertexProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
|
||||
}
|
||||
|
||||
func (v *vertexProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
vertexRegionDomain := strings.Replace(vertexDomain, "{REGION}", v.config.vertexRegion, 1)
|
||||
util.OverwriteRequestHostHeader(headers, vertexRegionDomain)
|
||||
var finalVertexDomain string
|
||||
|
||||
if v.isExpressMode() {
|
||||
// Express Mode: 固定域名,不带 region 前缀
|
||||
finalVertexDomain = vertexDomain
|
||||
} else {
|
||||
// 标准模式: 带 region 前缀
|
||||
if v.config.vertexRegion != vertexGlobalRegion {
|
||||
finalVertexDomain = fmt.Sprintf("%s-%s", v.config.vertexRegion, vertexDomain)
|
||||
} else {
|
||||
finalVertexDomain = vertexDomain
|
||||
}
|
||||
}
|
||||
|
||||
util.OverwriteRequestHostHeader(headers, finalVertexDomain)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) getToken() (cached bool, err error) {
|
||||
@@ -137,8 +213,42 @@ func (v *vertexProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
if v.config.IsOriginal() {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
headers := util.GetRequestHeaders()
|
||||
|
||||
// OpenAI 兼容模式: 不转换请求体,只设置路径和进行模型映射
|
||||
if v.isOpenAICompatibleMode() {
|
||||
ctx.SetContext(contextOpenAICompatibleMarker, true)
|
||||
body, err := v.onOpenAICompatibleRequestBody(ctx, apiName, body, headers)
|
||||
headers.Set("Content-Length", fmt.Sprint(len(body)))
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
_ = proxywasm.ReplaceHttpRequestBody(body)
|
||||
if err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
// OpenAI 兼容模式需要 OAuth token
|
||||
cached, err := v.getToken()
|
||||
if cached {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
body, err := v.TransformRequestBodyHeaders(ctx, apiName, body, headers)
|
||||
headers.Set("Content-Length", fmt.Sprint(len(body)))
|
||||
|
||||
if v.isExpressMode() {
|
||||
// Express Mode: 不需要 Authorization header,API Key 已在 URL 中
|
||||
headers.Del("Authorization")
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
_ = proxywasm.ReplaceHttpRequestBody(body)
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
// 标准模式: 需要获取 OAuth token
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
_ = proxywasm.ReplaceHttpRequestBody(body)
|
||||
if err != nil {
|
||||
@@ -162,17 +272,58 @@ func (v *vertexProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, ap
|
||||
}
|
||||
}
|
||||
|
||||
// onOpenAICompatibleRequestBody 处理 OpenAI 兼容模式的请求
|
||||
// 不转换请求体格式,只进行模型映射和路径设置
|
||||
func (v *vertexProvider) onOpenAICompatibleRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return nil, fmt.Errorf("OpenAI compatible mode only supports chat completions API")
|
||||
}
|
||||
|
||||
// 解析请求进行模型映射
|
||||
request := &chatCompletionRequest{}
|
||||
if err := v.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 设置 OpenAI 兼容端点路径
|
||||
path := v.getOpenAICompatibleRequestPath()
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
// 如果模型被映射,需要更新请求体中的模型字段
|
||||
if request.Model != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", request.Model)
|
||||
}
|
||||
|
||||
// 保持 OpenAI 格式,直接返回(可能更新了模型字段)
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func (v *vertexProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
err := v.config.parseRequestAndMapModel(ctx, request, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
path := v.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
if strings.HasPrefix(request.Model, "claude") {
|
||||
ctx.SetContext(contextClaudeMarker, true)
|
||||
path := v.getAhthropicRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
vertexRequest := v.buildVertexChatRequest(request)
|
||||
return json.Marshal(vertexRequest)
|
||||
claudeRequest := v.claude.buildClaudeTextGenRequest(request)
|
||||
claudeRequest.Model = ""
|
||||
claudeRequest.AnthropicVersion = vertexAnthropicVersion
|
||||
claudeBody, err := json.Marshal(claudeRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return claudeBody, nil
|
||||
} else {
|
||||
path := v.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
|
||||
vertexRequest := v.buildVertexChatRequest(request)
|
||||
return json.Marshal(vertexRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
@@ -188,6 +339,15 @@ func (v *vertexProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
|
||||
}
|
||||
|
||||
func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
|
||||
// OpenAI 兼容模式: 透传响应,但需要解码 Unicode 转义序列
|
||||
// Vertex AI OpenAI-compatible API 返回 ASCII-safe JSON,将非 ASCII 字符编码为 \uXXXX
|
||||
if ctx.GetContext(contextOpenAICompatibleMarker) != nil && ctx.GetContext(contextOpenAICompatibleMarker).(bool) {
|
||||
return util.DecodeUnicodeEscapesInSSE(chunk), nil
|
||||
}
|
||||
|
||||
if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) {
|
||||
return v.claude.OnStreamingResponseBody(ctx, name, chunk, isLastChunk)
|
||||
}
|
||||
log.Infof("[vertexProvider] receive chunk body: %s", string(chunk))
|
||||
if isLastChunk {
|
||||
return []byte(ssePrefix + "[DONE]\n\n"), nil
|
||||
@@ -225,6 +385,15 @@ func (v *vertexProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
|
||||
}
|
||||
|
||||
func (v *vertexProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
// OpenAI 兼容模式: 透传响应,但需要解码 Unicode 转义序列
|
||||
// Vertex AI OpenAI-compatible API 返回 ASCII-safe JSON,将非 ASCII 字符编码为 \uXXXX
|
||||
if ctx.GetContext(contextOpenAICompatibleMarker) != nil && ctx.GetContext(contextOpenAICompatibleMarker).(bool) {
|
||||
return util.DecodeUnicodeEscapes(body), nil
|
||||
}
|
||||
|
||||
if ctx.GetContext(contextClaudeMarker) != nil && ctx.GetContext(contextClaudeMarker).(bool) {
|
||||
return v.claude.TransformResponseBody(ctx, apiName, body)
|
||||
}
|
||||
if apiName == ApiNameChatCompletion {
|
||||
return v.onChatCompletionResponseBody(ctx, body)
|
||||
} else {
|
||||
@@ -252,6 +421,9 @@ func (v *vertexProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, re
|
||||
PromptTokens: response.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: response.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: response.UsageMetadata.TotalTokenCount,
|
||||
CompletionTokensDetails: &completionTokensDetails{
|
||||
ReasoningTokens: response.UsageMetadata.ThoughtsTokenCount,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, candidate := range response.Candidates {
|
||||
@@ -320,6 +492,7 @@ func (v *vertexProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, vertex
|
||||
|
||||
func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, vertexResp *vertexChatResponse) *chatCompletionResponse {
|
||||
var choice chatCompletionChoice
|
||||
choice.Delta = &chatMessage{}
|
||||
if len(vertexResp.Candidates) > 0 && len(vertexResp.Candidates[0].Content.Parts) > 0 {
|
||||
part := vertexResp.Candidates[0].Content.Parts[0]
|
||||
if part.FunctionCall != nil {
|
||||
@@ -361,6 +534,9 @@ func (v *vertexProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpConte
|
||||
PromptTokens: vertexResp.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: vertexResp.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: vertexResp.UsageMetadata.TotalTokenCount,
|
||||
CompletionTokensDetails: &completionTokensDetails{
|
||||
ReasoningTokens: vertexResp.UsageMetadata.ThoughtsTokenCount,
|
||||
},
|
||||
},
|
||||
}
|
||||
return &streamResponse
|
||||
@@ -370,6 +546,32 @@ func (v *vertexProvider) appendResponse(responseBuilder *strings.Builder, respon
|
||||
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
|
||||
}
|
||||
|
||||
func (v *vertexProvider) getAhthropicRequestPath(apiName ApiName, modelId string, stream bool) string {
|
||||
action := ""
|
||||
if stream {
|
||||
action = vertexAnthropicMessageStreamAction
|
||||
} else {
|
||||
action = vertexAnthropicMessageAction
|
||||
}
|
||||
|
||||
if v.isExpressMode() {
|
||||
// Express Mode: 简化路径 + API Key 参数
|
||||
basePath := fmt.Sprintf(vertexExpressPathAnthropicTemplate, modelId, action)
|
||||
apiKey := v.config.GetRandomToken()
|
||||
// 如果 action 已经包含 ?,使用 & 拼接
|
||||
var fullPath string
|
||||
if strings.Contains(action, "?") {
|
||||
fullPath = basePath + "&key=" + apiKey
|
||||
} else {
|
||||
fullPath = basePath + "?key=" + apiKey
|
||||
}
|
||||
return fullPath
|
||||
}
|
||||
|
||||
path := fmt.Sprintf(vertexPathAnthropicTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
|
||||
return path
|
||||
}
|
||||
|
||||
func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream bool) string {
|
||||
action := ""
|
||||
if apiName == ApiNameEmbeddings {
|
||||
@@ -379,7 +581,28 @@ func (v *vertexProvider) getRequestPath(apiName ApiName, modelId string, stream
|
||||
} else {
|
||||
action = vertexChatCompletionAction
|
||||
}
|
||||
return fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
|
||||
|
||||
if v.isExpressMode() {
|
||||
// Express Mode: 简化路径 + API Key 参数
|
||||
basePath := fmt.Sprintf(vertexExpressPathTemplate, modelId, action)
|
||||
apiKey := v.config.GetRandomToken()
|
||||
// 如果 action 已经包含 ?(如 streamGenerateContent?alt=sse),使用 & 拼接
|
||||
var fullPath string
|
||||
if strings.Contains(action, "?") {
|
||||
fullPath = basePath + "&key=" + apiKey
|
||||
} else {
|
||||
fullPath = basePath + "?key=" + apiKey
|
||||
}
|
||||
return fullPath
|
||||
}
|
||||
|
||||
path := fmt.Sprintf(vertexPathTemplate, v.config.vertexProjectId, v.config.vertexRegion, modelId, action)
|
||||
return path
|
||||
}
|
||||
|
||||
// getOpenAICompatibleRequestPath 获取 OpenAI 兼容模式的请求路径
|
||||
func (v *vertexProvider) getOpenAICompatibleRequestPath() string {
|
||||
return fmt.Sprintf(vertexOpenAICompatiblePathTemplate, v.config.vertexProjectId, v.config.vertexRegion)
|
||||
}
|
||||
|
||||
func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest) *vertexChatRequest {
|
||||
@@ -400,19 +623,22 @@ func (v *vertexProvider) buildVertexChatRequest(request *chatCompletionRequest)
|
||||
},
|
||||
}
|
||||
if request.ReasoningEffort != "" {
|
||||
thinkingBudget := 1024 // default
|
||||
switch request.ReasoningEffort {
|
||||
case "low":
|
||||
thinkingBudget = 1024
|
||||
case "medium":
|
||||
thinkingBudget = 4096
|
||||
case "high":
|
||||
thinkingBudget = 16384
|
||||
}
|
||||
vertexRequest.GenerationConfig.ThinkingConfig = vertexThinkingConfig{
|
||||
thinkingConfig := vertexThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
ThinkingBudget: thinkingBudget,
|
||||
ThinkingBudget: 1024,
|
||||
}
|
||||
switch request.ReasoningEffort {
|
||||
case "none":
|
||||
thinkingConfig.IncludeThoughts = false
|
||||
thinkingConfig.ThinkingBudget = 0
|
||||
case "low":
|
||||
thinkingConfig.ThinkingBudget = 1024
|
||||
case "medium":
|
||||
thinkingConfig.ThinkingBudget = 4096
|
||||
case "high":
|
||||
thinkingConfig.ThinkingBudget = 16384
|
||||
}
|
||||
vertexRequest.GenerationConfig.ThinkingConfig = thinkingConfig
|
||||
}
|
||||
if request.Tools != nil {
|
||||
functions := make([]function, 0, len(request.Tools))
|
||||
@@ -633,6 +859,7 @@ type vertexUsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount,omitempty"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
|
||||
TotalTokenCount int `json:"totalTokenCount,omitempty"`
|
||||
ThoughtsTokenCount int `json:"thoughtsTokenCount,omitempty"`
|
||||
}
|
||||
|
||||
type vertexEmbeddingResponse struct {
|
||||
|
||||
@@ -160,6 +160,59 @@ var azureResponseAPIConfig = func() json.RawMessage {
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:Azure OpenAI basePath移除 + original协议
|
||||
var azureBasePathRemovePrefixOriginalConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "azure",
|
||||
"apiTokens": []string{
|
||||
"sk-azure-basepath-original",
|
||||
},
|
||||
"azureServiceUrl": "https://basepath-test.openai.azure.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-15-preview",
|
||||
"basePath": "/azure-gpt4",
|
||||
"basePathHandling": "removePrefix",
|
||||
"protocol": "original",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:Azure OpenAI basePath移除 + openai协议
|
||||
var azureBasePathRemovePrefixOpenAIConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "azure",
|
||||
"apiTokens": []string{
|
||||
"sk-azure-basepath-openai",
|
||||
},
|
||||
"azureServiceUrl": "https://basepath-test.openai.azure.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-15-preview",
|
||||
"basePath": "/azure-gpt4",
|
||||
"basePathHandling": "removePrefix",
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-4",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:Azure OpenAI basePath prepend + original协议
|
||||
var azureBasePathPrependOriginalConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "azure",
|
||||
"apiTokens": []string{
|
||||
"sk-azure-prepend-original",
|
||||
},
|
||||
"azureServiceUrl": "https://prepend-test.openai.azure.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-15-preview",
|
||||
"basePath": "/api/v1",
|
||||
"basePathHandling": "prepend",
|
||||
"protocol": "original",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
func RunAzureParseConfigTests(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
// 测试基本Azure OpenAI配置解析
|
||||
@@ -682,3 +735,218 @@ func RunAzureOnHttpResponseBodyTests(t *testing.T) {
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// RunAzureBasePathHandlingTests 测试 basePath 处理在不同协议下的行为
|
||||
func RunAzureBasePathHandlingTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 核心用例:测试 basePath removePrefix 在 original 协议下能正常工作
|
||||
// 重要:此测试验证在 TransformRequestBody 阶段后 path 仍然保持正确
|
||||
// 之前的 bug 是 transformRequestPath 在 IsOriginal() 时返回 originalPath,
|
||||
// 导致在 Body 阶段 path 被重新覆盖为包含 basePath 的原始路径
|
||||
t.Run("azure basePath removePrefix with original protocol after body processing", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(azureBasePathRemovePrefixOriginalConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 模拟带有 basePath 前缀的请求
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/azure-gpt4/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 在 Headers 阶段后验证 path(此时 handleRequestHeaders 已执行)
|
||||
headersAfterHeaderStage := host.GetRequestHeaders()
|
||||
pathAfterHeaders, _ := test.GetHeaderValue(headersAfterHeaderStage, ":path")
|
||||
// Headers 阶段后,basePath 应该已被移除
|
||||
require.NotContains(t, pathAfterHeaders, "/azure-gpt4",
|
||||
"After headers stage: basePath should be removed")
|
||||
|
||||
// 执行 Body 阶段(此时 TransformRequestBody 会被调用)
|
||||
requestBody := `{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 核心验证:在 Body 阶段后验证 path
|
||||
// 这是关键测试点:确保 TransformRequestBody 中的 transformRequestPath
|
||||
// 不会将 path 重新覆盖为包含 basePath 的原始路径
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
// basePath "/azure-gpt4" 不应该出现在最终路径中
|
||||
require.NotContains(t, pathValue, "/azure-gpt4",
|
||||
"After body stage: basePath should still be removed (not restored by TransformRequestBody)")
|
||||
// 路径应该是移除 basePath 后的结果
|
||||
require.Equal(t, "/v1/chat/completions", pathValue,
|
||||
"Path should be the original path without basePath after full request processing")
|
||||
|
||||
// 验证 Host 被正确设置
|
||||
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
|
||||
require.True(t, hasHost, "Host header should exist")
|
||||
require.Equal(t, "basepath-test.openai.azure.com", hostValue)
|
||||
|
||||
// 验证 api-key 被正确设置
|
||||
apiKeyValue, hasApiKey := test.GetHeaderValue(requestHeaders, "api-key")
|
||||
require.True(t, hasApiKey, "api-key header should exist")
|
||||
require.Equal(t, "sk-azure-basepath-original", apiKeyValue)
|
||||
})
|
||||
|
||||
// 测试 basePath removePrefix 在 openai 协议下的行为
|
||||
// 在 openai 协议下,path 会被转换为 Azure 格式,但 basePath 仍然应该被移除
|
||||
t.Run("azure basePath removePrefix with openai protocol after body processing", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(azureBasePathRemovePrefixOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 模拟带有 basePath 前缀的请求
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/azure-gpt4/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 执行 Body 阶段(TransformRequestBody 会被调用)
|
||||
requestBody := `{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 在 Body 阶段后验证请求头
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// basePath 应该被移除,路径会被转换为 Azure 路径格式
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
// basePath "/azure-gpt4" 不应该出现在最终路径中
|
||||
require.NotContains(t, pathValue, "/azure-gpt4",
|
||||
"After body stage: basePath should be removed from path")
|
||||
// 在 openai 协议下,路径会被转换为 Azure 的路径格式
|
||||
require.Contains(t, pathValue, "/openai/deployments/gpt-4/chat/completions",
|
||||
"Path should be transformed to Azure format")
|
||||
require.Contains(t, pathValue, "api-version=2024-02-15-preview",
|
||||
"Path should contain API version")
|
||||
})
|
||||
|
||||
// 测试 basePath prepend 在 original 协议下能正常工作
|
||||
// 验证在 Body 阶段后 prepend 的 basePath 仍然保持
|
||||
t.Run("azure basePath prepend with original protocol after body processing", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(azureBasePathPrependOriginalConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 模拟不带 basePath 的请求
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 执行 Body 阶段
|
||||
requestBody := `{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 在 Body 阶段后验证请求头
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// 验证 basePath 被正确添加且在 Body 阶段后保持
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
// basePath "/api/v1" 应该被添加到路径前面
|
||||
require.True(t, strings.HasPrefix(pathValue, "/api/v1"),
|
||||
"After body stage: Path should still start with prepended basePath")
|
||||
})
|
||||
|
||||
// 测试 original 协议下请求体不被修改,同时验证 path 处理
|
||||
t.Run("azure original protocol preserves request body and path", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(azureBasePathRemovePrefixOriginalConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/azure-gpt4/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 设置请求体(包含自定义字段)
|
||||
requestBody := `{
|
||||
"model": "custom-model-name",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"custom_field": "custom_value"
|
||||
}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证请求体被保持原样
|
||||
transformedBody := host.GetRequestBody()
|
||||
require.NotNil(t, transformedBody)
|
||||
|
||||
var bodyMap map[string]interface{}
|
||||
err := json.Unmarshal(transformedBody, &bodyMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
// model 应该保持原样(original 协议不做模型映射)
|
||||
model, exists := bodyMap["model"]
|
||||
require.True(t, exists, "Model should exist")
|
||||
require.Equal(t, "custom-model-name", model, "Model should remain unchanged")
|
||||
|
||||
// 自定义字段应该保持原样
|
||||
customField, exists := bodyMap["custom_field"]
|
||||
require.True(t, exists, "Custom field should exist")
|
||||
require.Equal(t, "custom_value", customField, "Custom field should remain unchanged")
|
||||
|
||||
// 同时验证 path 在 Body 阶段后仍然正确
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
require.NotContains(t, pathValue, "/azure-gpt4",
|
||||
"After body stage: basePath should be removed")
|
||||
require.Equal(t, "/v1/chat/completions", pathValue,
|
||||
"Path should be correct after body processing")
|
||||
})
|
||||
|
||||
// 测试无 basePath 前缀的请求(removePrefix 配置不影响)
|
||||
// 验证在 Body 阶段后 path 仍然保持正确
|
||||
t.Run("azure request without basePath prefix after body processing", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(azureBasePathRemovePrefixOriginalConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 模拟不带 basePath 前缀的请求
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 执行 Body 阶段
|
||||
requestBody := `{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 在 Body 阶段后验证请求头
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
// 路径应该保持原样(没有 basePath 前缀时,removePrefix 不会改变 path)
|
||||
// 同时验证 TransformRequestBody 没有覆盖 path
|
||||
require.Equal(t, "/v1/chat/completions", pathValue,
|
||||
"After body stage: Path should remain unchanged when no basePath prefix")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
527
plugins/wasm-go/extensions/ai-proxy/test/bedrock.go
Normal file
527
plugins/wasm-go/extensions/ai-proxy/test/bedrock.go
Normal file
@@ -0,0 +1,527 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Test config: Basic Bedrock config with AWS Access Key/Secret Key (AWS Signature V4)
|
||||
var basicBedrockConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "bedrock",
|
||||
"awsAccessKey": "test-ak-for-unit-test",
|
||||
"awsSecretKey": "test-sk-for-unit-test",
|
||||
"awsRegion": "us-east-1",
|
||||
"modelMapping": map[string]string{
|
||||
"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// Test config: Bedrock config with Bearer Token authentication
|
||||
var bedrockApiTokenConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "bedrock",
|
||||
"apiTokens": []string{
|
||||
"test-token-for-unit-test",
|
||||
},
|
||||
"awsRegion": "us-east-1",
|
||||
"modelMapping": map[string]string{
|
||||
"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// Test config: Bedrock config with multiple Bearer Tokens
|
||||
var bedrockMultiTokenConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "bedrock",
|
||||
"apiTokens": []string{
|
||||
"test-token-1-for-unit-test",
|
||||
"test-token-2-for-unit-test",
|
||||
},
|
||||
"awsRegion": "us-west-2",
|
||||
"modelMapping": map[string]string{
|
||||
"gpt-4": "anthropic.claude-3-opus-20240229-v1:0",
|
||||
"*": "anthropic.claude-3-haiku-20240307-v1:0",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// Test config: Bedrock config with additional fields
|
||||
var bedrockWithAdditionalFieldsConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "bedrock",
|
||||
"awsAccessKey": "test-ak-for-unit-test",
|
||||
"awsSecretKey": "test-sk-for-unit-test",
|
||||
"awsRegion": "us-east-1",
|
||||
"bedrockAdditionalFields": map[string]interface{}{
|
||||
"top_k": 200,
|
||||
},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// Test config: Invalid config - missing both apiTokens and ak/sk
|
||||
var bedrockInvalidConfigMissingAuth = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "bedrock",
|
||||
"awsRegion": "us-east-1",
|
||||
"modelMapping": map[string]string{
|
||||
"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// Test config: Invalid config - missing region
|
||||
var bedrockInvalidConfigMissingRegion = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "bedrock",
|
||||
"apiTokens": []string{
|
||||
"test-token-for-unit-test",
|
||||
},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// Test config: Invalid config - only has access key without secret key
|
||||
var bedrockInvalidConfigPartialAkSk = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "bedrock",
|
||||
"awsAccessKey": "test-ak-for-unit-test",
|
||||
"awsRegion": "us-east-1",
|
||||
"modelMapping": map[string]string{
|
||||
"*": "anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
func RunBedrockParseConfigTests(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
// Test basic Bedrock config with AWS Signature V4 authentication
|
||||
t.Run("basic bedrock config with ak/sk", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicBedrockConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// Test Bedrock config with Bearer Token authentication
|
||||
t.Run("bedrock config with api token", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// Test Bedrock config with multiple tokens
|
||||
t.Run("bedrock config with multiple tokens", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockMultiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// Test Bedrock config with additional fields
|
||||
t.Run("bedrock config with additional fields", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockWithAdditionalFieldsConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// Test invalid config - missing authentication
|
||||
t.Run("bedrock invalid config missing auth", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockInvalidConfigMissingAuth)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||||
})
|
||||
|
||||
// Test invalid config - missing region
|
||||
t.Run("bedrock invalid config missing region", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockInvalidConfigMissingRegion)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||||
})
|
||||
|
||||
// Test invalid config - partial ak/sk (only access key, no secret key)
|
||||
t.Run("bedrock invalid config partial ak/sk", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockInvalidConfigPartialAkSk)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunBedrockOnHttpRequestHeadersTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// Test Bedrock request headers processing with AWS Signature V4
|
||||
t.Run("bedrock chat completion request headers with ak/sk", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicBedrockConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// Set request headers
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// Verify request headers
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// Verify Host is changed to Bedrock service domain
|
||||
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
|
||||
require.True(t, hasHost, "Host header should exist")
|
||||
require.Contains(t, hostValue, "bedrock-runtime.us-east-1.amazonaws.com", "Host should be changed to Bedrock service domain")
|
||||
})
|
||||
|
||||
// Test Bedrock request headers processing with Bearer Token
|
||||
t.Run("bedrock chat completion request headers with api token", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// Set request headers
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// Verify request headers
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// Verify Host is changed to Bedrock service domain
|
||||
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
|
||||
require.True(t, hasHost, "Host header should exist")
|
||||
require.Contains(t, hostValue, "bedrock-runtime.us-east-1.amazonaws.com", "Host should be changed to Bedrock service domain")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunBedrockOnHttpRequestBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// Test Bedrock request body processing with Bearer Token authentication
|
||||
t.Run("bedrock chat completion request body with api token", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// Set request headers
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// Set request body
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, how are you?"
|
||||
}
|
||||
],
|
||||
"temperature": 0.7
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// Verify request headers for Bearer Token authentication
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// Verify Authorization header uses Bearer token
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.Contains(t, authValue, "Bearer ", "Authorization should use Bearer token")
|
||||
require.Contains(t, authValue, "test-token-for-unit-test", "Authorization should contain the configured token")
|
||||
|
||||
// Verify path is transformed to Bedrock format
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
require.Contains(t, pathValue, "/model/", "Path should contain Bedrock model path")
|
||||
require.Contains(t, pathValue, "/converse", "Path should contain converse endpoint")
|
||||
})
|
||||
|
||||
// Test Bedrock request body processing with AWS Signature V4 authentication
|
||||
t.Run("bedrock chat completion request body with ak/sk", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicBedrockConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// Set request headers
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// Set request body
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, how are you?"
|
||||
}
|
||||
],
|
||||
"temperature": 0.7
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// Verify request headers for AWS Signature V4 authentication
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// Verify Authorization header uses AWS Signature
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.Contains(t, authValue, "AWS4-HMAC-SHA256", "Authorization should use AWS4-HMAC-SHA256 signature")
|
||||
require.Contains(t, authValue, "Credential=", "Authorization should contain Credential")
|
||||
require.Contains(t, authValue, "Signature=", "Authorization should contain Signature")
|
||||
|
||||
// Verify X-Amz-Date header exists
|
||||
dateValue, hasDate := test.GetHeaderValue(requestHeaders, "X-Amz-Date")
|
||||
require.True(t, hasDate, "X-Amz-Date header should exist for AWS Signature V4")
|
||||
require.NotEmpty(t, dateValue, "X-Amz-Date should not be empty")
|
||||
|
||||
// Verify path is transformed to Bedrock format
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
require.Contains(t, pathValue, "/model/", "Path should contain Bedrock model path")
|
||||
require.Contains(t, pathValue, "/converse", "Path should contain converse endpoint")
|
||||
})
|
||||
|
||||
// Test Bedrock streaming request
|
||||
t.Run("bedrock streaming request", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// Set request headers
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// Set streaming request body
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello"
|
||||
}
|
||||
],
|
||||
"stream": true
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// Verify path is transformed to Bedrock streaming format
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
require.Contains(t, pathValue, "/model/", "Path should contain Bedrock model path")
|
||||
require.Contains(t, pathValue, "/converse-stream", "Path should contain converse-stream endpoint for streaming")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunBedrockOnHttpResponseHeadersTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// Test Bedrock response headers processing
|
||||
t.Run("bedrock response headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// Set request headers
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// Set request body
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello"
|
||||
}
|
||||
]
|
||||
}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// Process response headers
|
||||
action = host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"X-Amzn-Requestid", "test-request-id-12345"},
|
||||
})
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// Verify response headers
|
||||
responseHeaders := host.GetResponseHeaders()
|
||||
require.NotNil(t, responseHeaders)
|
||||
|
||||
// Verify status code
|
||||
statusValue, hasStatus := test.GetHeaderValue(responseHeaders, ":status")
|
||||
require.True(t, hasStatus, "Status header should exist")
|
||||
require.Equal(t, "200", statusValue, "Status should be 200")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunBedrockOnHttpResponseBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// Test Bedrock response body processing
|
||||
t.Run("bedrock response body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(bedrockApiTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// Set request headers
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// Set request body
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello"
|
||||
}
|
||||
]
|
||||
}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// Set response property to ensure IsResponseFromUpstream() returns true
|
||||
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||
|
||||
// Process response headers (must include :status 200 for body processing)
|
||||
action = host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// Process response body (Bedrock format)
|
||||
responseBody := `{
|
||||
"output": {
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"text": "Hello! How can I help you today?"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"stopReason": "end_turn",
|
||||
"usage": {
|
||||
"inputTokens": 10,
|
||||
"outputTokens": 15,
|
||||
"totalTokens": 25
|
||||
}
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpResponseBody([]byte(responseBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// Verify response body is transformed to OpenAI format
|
||||
transformedResponseBody := host.GetResponseBody()
|
||||
require.NotNil(t, transformedResponseBody)
|
||||
|
||||
var responseMap map[string]interface{}
|
||||
err := json.Unmarshal(transformedResponseBody, &responseMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify choices exist in transformed response
|
||||
choices, exists := responseMap["choices"]
|
||||
require.True(t, exists, "Choices should exist in response body")
|
||||
require.NotNil(t, choices, "Choices should not be nil")
|
||||
|
||||
// Verify usage exists
|
||||
usage, exists := responseMap["usage"]
|
||||
require.True(t, exists, "Usage should exist in response body")
|
||||
require.NotNil(t, usage, "Usage should not be nil")
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -213,7 +213,9 @@ func RunFireworksOnHttpRequestHeadersTests(t *testing.T) {
|
||||
{":method", "GET"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
// TODO: Due to the limitations of the test framework, we just treat it as a request with body here.
|
||||
//require.Equal(t, types.ActionContinue, action)
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 验证请求头处理
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
|
||||
239
plugins/wasm-go/extensions/ai-proxy/test/generic.go
Normal file
239
plugins/wasm-go/extensions/ai-proxy/test/generic.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// 通用测试配置:最简配置,覆盖 host 与 token 注入。
|
||||
var genericBasicConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "generic",
|
||||
"apiTokens": []string{"sk-generic-basic"},
|
||||
"genericHost": "generic.backend.internal",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 通用测试配置:开启 basePath removePrefix。
|
||||
var genericBasePathConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "generic",
|
||||
"apiTokens": []string{"sk-generic-basepath"},
|
||||
"genericHost": "basepath.backend.internal",
|
||||
"basePath": "/proxy",
|
||||
"basePathHandling": "removePrefix",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 通用测试配置:开启 basePath prepend。
|
||||
var genericPrependBasePathConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "generic",
|
||||
"apiTokens": []string{"sk-generic-prepend"},
|
||||
"genericHost": "prepend.backend.internal",
|
||||
"basePath": "/custom",
|
||||
"basePathHandling": "prepend",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 通用测试配置:覆盖 firstByteTimeout,用于流式能力验证。
|
||||
var genericStreamingConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "generic",
|
||||
"apiTokens": []string{"sk-generic-stream"},
|
||||
"genericHost": "stream.backend.internal",
|
||||
"firstByteTimeout": 1500,
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 通用测试配置:无 token,也不设置 host。
|
||||
var genericNoTokenConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "generic",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
func RunGenericParseConfigTests(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
t.Run("generic basic config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(genericBasicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
t.Run("generic config without token", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(genericNoTokenConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
})
|
||||
|
||||
t.Run("generic config with streaming options", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(genericStreamingConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunGenericOnHttpRequestHeadersTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
t.Run("generic injects token and custom host", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(genericBasicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "client.local"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "generic.backend.internal"))
|
||||
require.True(t, test.HasHeaderWithValue(requestHeaders, "Authorization", "Bearer sk-generic-basic"))
|
||||
|
||||
_, hasContentLength := test.GetHeaderValue(requestHeaders, "Content-Length")
|
||||
require.False(t, hasContentLength, "generic provider should remove Content-Length")
|
||||
})
|
||||
|
||||
t.Run("generic removes basePath prefix", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(genericBasePathConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "client.local"},
|
||||
{":path", "/proxy/service/echo"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.True(t, test.HasHeaderWithValue(requestHeaders, ":path", "/service/echo"))
|
||||
require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "basepath.backend.internal"))
|
||||
})
|
||||
|
||||
t.Run("generic prepends basePath when configured", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(genericPrependBasePathConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "client.local"},
|
||||
{":path", "/v1/echo"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.True(t, test.HasHeaderWithValue(requestHeaders, ":path", "/custom/v1/echo"))
|
||||
})
|
||||
|
||||
t.Run("generic firstByteTimeout injects timeout header only", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(genericStreamingConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "client.local"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.True(t, test.HasHeaderWithValue(requestHeaders, "x-envoy-upstream-rq-first-byte-timeout-ms", "1500"))
|
||||
|
||||
_, hasAccept := test.GetHeaderValue(requestHeaders, "Accept")
|
||||
require.False(t, hasAccept, "Accept header should remain untouched when enabling firstByteTimeout")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunGenericOnHttpRequestBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
t.Run("generic body passthrough keeps headers unchanged with timeout", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(genericStreamingConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "client.local"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
body := `{"model":"gpt-any","stream":true}`
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.True(t, test.HasHeaderWithValue(requestHeaders, "x-envoy-upstream-rq-first-byte-timeout-ms", "1500"))
|
||||
_, hasAccept := test.GetHeaderValue(requestHeaders, "Accept")
|
||||
require.False(t, hasAccept, "Accept header should remain untouched even when firstByteTimeout is enabled")
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.JSONEq(t, body, string(processedBody))
|
||||
})
|
||||
|
||||
t.Run("generic without first byte timeout keeps headers untouched", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(genericBasicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "client.local"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
body := `{"model":"gpt-any","stream":true}`
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
_, hasAccept := test.GetHeaderValue(requestHeaders, "Accept")
|
||||
require.False(t, hasAccept, "Accept header should remain untouched when first byte timeout is disabled")
|
||||
|
||||
_, hasTimeout := test.GetHeaderValue(requestHeaders, "x-envoy-upstream-rq-first-byte-timeout-ms")
|
||||
require.False(t, hasTimeout, "timeout header should not be added when first byte timeout is disabled")
|
||||
|
||||
processedBody := host.GetRequestBody()
|
||||
require.JSONEq(t, body, string(processedBody))
|
||||
})
|
||||
})
|
||||
}
|
||||
251
plugins/wasm-go/extensions/ai-proxy/test/minimax.go
Normal file
251
plugins/wasm-go/extensions/ai-proxy/test/minimax.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// 测试配置:Minimax Pro API + basePath removePrefix + original 协议
|
||||
var minimaxProBasePathRemovePrefixOriginalConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "minimax",
|
||||
"apiTokens": []string{
|
||||
"sk-minimax-test",
|
||||
},
|
||||
"minimaxApiType": "pro",
|
||||
"minimaxGroupId": "test-group-id",
|
||||
"basePath": "/minimax-api",
|
||||
"basePathHandling": "removePrefix",
|
||||
"protocol": "original",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:Minimax Pro API + basePath removePrefix + 默认协议(openai)
|
||||
var minimaxProBasePathRemovePrefixOpenAIConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "minimax",
|
||||
"apiTokens": []string{
|
||||
"sk-minimax-openai",
|
||||
},
|
||||
"minimaxApiType": "pro",
|
||||
"minimaxGroupId": "test-group-id",
|
||||
"basePath": "/minimax-api",
|
||||
"basePathHandling": "removePrefix",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:Minimax V2 API + basePath removePrefix + original 协议
|
||||
var minimaxV2BasePathRemovePrefixOriginalConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "minimax",
|
||||
"apiTokens": []string{
|
||||
"sk-minimax-v2",
|
||||
},
|
||||
"minimaxApiType": "v2",
|
||||
"basePath": "/minimax-v2",
|
||||
"basePathHandling": "removePrefix",
|
||||
"protocol": "original",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// RunMinimaxBasePathHandlingTests 测试 Minimax basePath 处理在不同协议下的行为
|
||||
func RunMinimaxBasePathHandlingTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 核心用例:测试 Minimax Pro API + basePath removePrefix + original 协议
|
||||
// 重要:此测试验证在 handleRequestBodyByChatCompletionPro 阶段后 path 仍然保持正确
|
||||
// 之前的 bug 是 handleRequestBodyByChatCompletionPro 无条件覆盖 path,
|
||||
// 导致在 Body 阶段 path 被重新覆盖为 minimaxChatCompletionProPath
|
||||
t.Run("minimax pro basePath removePrefix with original protocol after body processing", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(minimaxProBasePathRemovePrefixOriginalConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 模拟带有 basePath 前缀的请求
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/minimax-api/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 在 Headers 阶段后验证 path(此时 handleRequestHeaders 已执行)
|
||||
headersAfterHeaderStage := host.GetRequestHeaders()
|
||||
pathAfterHeaders, _ := test.GetHeaderValue(headersAfterHeaderStage, ":path")
|
||||
// Headers 阶段后,basePath 应该已被移除
|
||||
require.NotContains(t, pathAfterHeaders, "/minimax-api",
|
||||
"After headers stage: basePath should be removed")
|
||||
|
||||
// 执行 Body 阶段(此时 handleRequestBodyByChatCompletionPro 会被调用)
|
||||
requestBody := `{"model": "abab5.5-chat", "messages": [{"role": "user", "content": "Hello"}]}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 核心验证:在 Body 阶段后验证 path
|
||||
// 这是关键测试点:确保 handleRequestBodyByChatCompletionPro
|
||||
// 不会将 path 重新覆盖为 minimaxChatCompletionProPath
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
// basePath "/minimax-api" 不应该出现在最终路径中
|
||||
require.NotContains(t, pathValue, "/minimax-api",
|
||||
"After body stage: basePath should still be removed")
|
||||
// original 协议下,path 不应该被覆盖为 minimaxChatCompletionProPath
|
||||
require.NotContains(t, pathValue, "chatcompletion_pro",
|
||||
"With original protocol: path should not be overwritten to minimax pro path")
|
||||
// 路径应该是移除 basePath 后的结果
|
||||
require.Equal(t, "/v1/chat/completions", pathValue,
|
||||
"Path should be the original path without basePath after full request processing")
|
||||
|
||||
// 验证 Host 被正确设置
|
||||
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
|
||||
require.True(t, hasHost, "Host header should exist")
|
||||
require.Equal(t, "api.minimax.chat", hostValue)
|
||||
|
||||
// 验证 Authorization 被正确设置
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.Equal(t, "Bearer sk-minimax-test", authValue)
|
||||
})
|
||||
|
||||
// 测试 Minimax Pro API + basePath removePrefix + 默认协议(openai)
|
||||
// 在 openai 协议下,path 应该被覆盖为 minimaxChatCompletionProPath
|
||||
t.Run("minimax pro basePath removePrefix with openai protocol after body processing", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(minimaxProBasePathRemovePrefixOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 模拟带有 basePath 前缀的请求
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/minimax-api/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 执行 Body 阶段
|
||||
requestBody := `{"model": "abab5.5-chat", "messages": [{"role": "user", "content": "Hello"}]}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 在 Body 阶段后验证请求头
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
// basePath "/minimax-api" 不应该出现在最终路径中
|
||||
require.NotContains(t, pathValue, "/minimax-api",
|
||||
"After body stage: basePath should be removed from path")
|
||||
// 在 openai 协议下,path 应该被覆盖为 minimaxChatCompletionProPath
|
||||
require.True(t, strings.Contains(pathValue, "chatcompletion_pro"),
|
||||
"With openai protocol: path should be overwritten to minimax pro path")
|
||||
require.Contains(t, pathValue, "GroupId=test-group-id",
|
||||
"Path should contain GroupId parameter")
|
||||
})
|
||||
|
||||
// 测试 Minimax V2 API + basePath removePrefix + original 协议
|
||||
// V2 API 使用 handleRequestBody 而不是 handleRequestBodyByChatCompletionPro
|
||||
t.Run("minimax v2 basePath removePrefix with original protocol after body processing", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(minimaxV2BasePathRemovePrefixOriginalConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 模拟带有 basePath 前缀的请求
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/minimax-v2/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 执行 Body 阶段
|
||||
requestBody := `{"model": "abab5.5-chat", "messages": [{"role": "user", "content": "Hello"}]}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 在 Body 阶段后验证请求头
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
// basePath "/minimax-v2" 不应该出现在最终路径中
|
||||
require.NotContains(t, pathValue, "/minimax-v2",
|
||||
"After body stage: basePath should be removed from path")
|
||||
// 路径应该是移除 basePath 后的结果
|
||||
require.Equal(t, "/v1/chat/completions", pathValue,
|
||||
"Path should be the original path without basePath")
|
||||
})
|
||||
|
||||
// 测试 original 协议下请求体保持原样
|
||||
t.Run("minimax pro original protocol preserves request body and path", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(minimaxProBasePathRemovePrefixOriginalConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/minimax-api/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 设置请求体(包含自定义字段)
|
||||
requestBody := `{
|
||||
"model": "custom-model",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"custom_field": "custom_value"
|
||||
}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证请求体被保持原样
|
||||
transformedBody := host.GetRequestBody()
|
||||
require.NotNil(t, transformedBody)
|
||||
|
||||
var bodyMap map[string]interface{}
|
||||
err := json.Unmarshal(transformedBody, &bodyMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
// model 应该保持原样(original 协议不做模型映射)
|
||||
model, exists := bodyMap["model"]
|
||||
require.True(t, exists, "Model should exist")
|
||||
require.Equal(t, "custom-model", model, "Model should remain unchanged")
|
||||
|
||||
// 自定义字段应该保持原样
|
||||
customField, exists := bodyMap["custom_field"]
|
||||
require.True(t, exists, "Custom field should exist")
|
||||
require.Equal(t, "custom_value", customField, "Custom field should remain unchanged")
|
||||
|
||||
// 同时验证 path 在 Body 阶段后仍然正确
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
require.NotContains(t, pathValue, "/minimax-api",
|
||||
"After body stage: basePath should be removed")
|
||||
require.Equal(t, "/v1/chat/completions", pathValue,
|
||||
"Path should be correct after body processing")
|
||||
})
|
||||
})
|
||||
}
|
||||
116
plugins/wasm-go/extensions/ai-proxy/test/util.go
Normal file
116
plugins/wasm-go/extensions/ai-proxy/test/util.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
)
|
||||
|
||||
func RunMapRequestPathByCapabilityTests(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
apiName string
|
||||
origin string
|
||||
mapping map[string]string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no mapping returns empty",
|
||||
apiName: "openai/v1/chatcompletions",
|
||||
origin: "/v1/chat/completions",
|
||||
mapping: map[string]string{},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "file placeholder is replaced",
|
||||
apiName: "openai/v1/retrievefile",
|
||||
origin: "/openai/v1/files/file-abc",
|
||||
mapping: map[string]string{
|
||||
"openai/v1/retrievefile": "/v1/files/{file_id}",
|
||||
},
|
||||
expected: "/v1/files/file-abc",
|
||||
},
|
||||
{
|
||||
name: "file content keeps query parameters",
|
||||
apiName: "openai/v1/retrievefilecontent",
|
||||
origin: "/openai/v1/files/file-123/content?variant=thumbnail",
|
||||
mapping: map[string]string{
|
||||
"openai/v1/retrievefilecontent": "/v1/files/{file_id}/content",
|
||||
},
|
||||
expected: "/v1/files/file-123/content?variant=thumbnail",
|
||||
},
|
||||
{
|
||||
name: "file content merges query string with mapped query",
|
||||
apiName: "openai/v1/retrievefilecontent",
|
||||
origin: "/openai/v1/files/file-123/content?variant=thumbnail",
|
||||
mapping: map[string]string{
|
||||
"openai/v1/retrievefilecontent": "/v1/files/{file_id}/content?download=1",
|
||||
},
|
||||
expected: "/v1/files/file-123/content?download=1&variant=thumbnail",
|
||||
},
|
||||
{
|
||||
name: "retrieve batch replaces batch id",
|
||||
apiName: "openai/v1/retrievebatch",
|
||||
origin: "/openai/v1/batches/batch-001",
|
||||
mapping: map[string]string{
|
||||
"openai/v1/retrievebatch": "/v1/batches/{batch_id}",
|
||||
},
|
||||
expected: "/v1/batches/batch-001",
|
||||
},
|
||||
{
|
||||
name: "cancel batch replaces batch id",
|
||||
apiName: "openai/v1/cancelbatch",
|
||||
origin: "/openai/v1/batches/batch-002/cancel",
|
||||
mapping: map[string]string{
|
||||
"openai/v1/cancelbatch": "/v1/batches/{batch_id}/cancel",
|
||||
},
|
||||
expected: "/v1/batches/batch-002/cancel",
|
||||
},
|
||||
{
|
||||
name: "video placeholder is replaced",
|
||||
apiName: "openai/v1/retrievevideo",
|
||||
origin: "/openai/v1/videos/video-xyz",
|
||||
mapping: map[string]string{
|
||||
"openai/v1/retrievevideo": "/v1/videos/{video_id}",
|
||||
},
|
||||
expected: "/v1/videos/video-xyz",
|
||||
},
|
||||
{
|
||||
name: "video content placeholder with query",
|
||||
apiName: "openai/v1/retrievevideocontent",
|
||||
origin: "/openai/v1/videos/video-xyz/content?variant=thumbnail",
|
||||
mapping: map[string]string{
|
||||
"openai/v1/retrievevideocontent": "/v1/videos/{video_id}/content",
|
||||
},
|
||||
expected: "/v1/videos/video-xyz/content?variant=thumbnail",
|
||||
},
|
||||
{
|
||||
name: "video remix placeholder is replaced",
|
||||
apiName: "openai/v1/videoremix",
|
||||
origin: "/openai/v1/videos/video-xyz/remix",
|
||||
mapping: map[string]string{
|
||||
"openai/v1/videoremix": "/v1/videos/{video_id}/remix",
|
||||
},
|
||||
expected: "/v1/videos/video-xyz/remix",
|
||||
},
|
||||
{
|
||||
name: "non placeholder mapping returns mapped path directly",
|
||||
apiName: "openai/v1/videos",
|
||||
origin: "/openai/v1/videos",
|
||||
mapping: map[string]string{
|
||||
"openai/v1/videos": "/v1/videos",
|
||||
},
|
||||
expected: "/v1/videos",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := util.MapRequestPathByCapability(tc.apiName, tc.origin, tc.mapping)
|
||||
if got != tc.expected {
|
||||
t.Fatalf("expected %q, got %q", tc.expected, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
888
plugins/wasm-go/extensions/ai-proxy/test/vertex.go
Normal file
888
plugins/wasm-go/extensions/ai-proxy/test/vertex.go
Normal file
@@ -0,0 +1,888 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// 测试配置:Vertex 标准模式配置
|
||||
var basicVertexConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "vertex",
|
||||
"vertexAuthKey": `{"type":"service_account","client_email":"test@test.iam.gserviceaccount.com","private_key":"-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7k1v5C7y8L4SN\n-----END PRIVATE KEY-----\n","token_uri":"https://oauth2.googleapis.com/token"}`,
|
||||
"vertexRegion": "us-central1",
|
||||
"vertexProjectId": "test-project-id",
|
||||
"vertexAuthServiceName": "test-auth-service",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:Vertex Express Mode 配置(使用 apiTokens)
|
||||
var vertexExpressModeConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "vertex",
|
||||
"apiTokens": []string{"test-api-key-123456789"},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:Vertex Express Mode 配置(含模型映射)
|
||||
var vertexExpressModeWithModelMappingConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "vertex",
|
||||
"apiTokens": []string{"test-api-key-123456789"},
|
||||
"modelMapping": map[string]string{
|
||||
"gpt-4": "gemini-2.5-flash",
|
||||
"gpt-3.5-turbo": "gemini-2.5-flash-lite",
|
||||
"text-embedding-ada-002": "text-embedding-001",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:Vertex Express Mode 配置(含安全设置)
|
||||
var vertexExpressModeWithSafetyConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "vertex",
|
||||
"apiTokens": []string{"test-api-key-123456789"},
|
||||
"geminiSafetySetting": map[string]string{
|
||||
"HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"HARM_CATEGORY_HATE_SPEECH": "BLOCK_LOW_AND_ABOVE",
|
||||
"HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_NONE",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:无效 Vertex 标准模式配置(缺少 vertexAuthKey)
|
||||
var invalidVertexStandardModeConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "vertex",
|
||||
// 缺少必需的标准模式配置
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:Vertex OpenAI 兼容模式配置
|
||||
var vertexOpenAICompatibleModeConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "vertex",
|
||||
"vertexOpenAICompatible": true,
|
||||
"vertexAuthKey": `{"type":"service_account","client_email":"test@test.iam.gserviceaccount.com","private_key":"-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7k1v5C7y8L4SN\n-----END PRIVATE KEY-----\n","token_uri":"https://oauth2.googleapis.com/token"}`,
|
||||
"vertexRegion": "us-central1",
|
||||
"vertexProjectId": "test-project-id",
|
||||
"vertexAuthServiceName": "test-auth-service",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:Vertex OpenAI 兼容模式配置(含模型映射)
|
||||
var vertexOpenAICompatibleModeWithModelMappingConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "vertex",
|
||||
"vertexOpenAICompatible": true,
|
||||
"vertexAuthKey": `{"type":"service_account","client_email":"test@test.iam.gserviceaccount.com","private_key":"-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7k1v5C7y8L4SN\n-----END PRIVATE KEY-----\n","token_uri":"https://oauth2.googleapis.com/token"}`,
|
||||
"vertexRegion": "us-central1",
|
||||
"vertexProjectId": "test-project-id",
|
||||
"vertexAuthServiceName": "test-auth-service",
|
||||
"modelMapping": map[string]string{
|
||||
"gpt-4": "gemini-2.0-flash",
|
||||
"gpt-3.5-turbo": "gemini-1.5-flash",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:无效配置 - Express Mode 与 OpenAI 兼容模式互斥
|
||||
var invalidVertexExpressAndOpenAICompatibleConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "vertex",
|
||||
"apiTokens": []string{"test-api-key"},
|
||||
"vertexOpenAICompatible": true,
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
func RunVertexParseConfigTests(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
// 测试 Vertex 标准模式配置解析
|
||||
t.Run("vertex standard mode config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicVertexConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试 Vertex Express Mode 配置解析
|
||||
t.Run("vertex express mode config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试 Vertex Express Mode 配置(含模型映射)
|
||||
t.Run("vertex express mode with model mapping config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeWithModelMappingConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试无效 Vertex 标准模式配置(缺少 vertexAuthKey)
|
||||
t.Run("invalid vertex standard mode config - missing auth key", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(invalidVertexStandardModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||||
})
|
||||
|
||||
// 测试 Vertex Express Mode 配置(含安全设置)
|
||||
t.Run("vertex express mode with safety setting config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeWithSafetyConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试 Vertex OpenAI 兼容模式配置解析
|
||||
t.Run("vertex openai compatible mode config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试 Vertex OpenAI 兼容模式配置(含模型映射)
|
||||
t.Run("vertex openai compatible mode with model mapping config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexOpenAICompatibleModeWithModelMappingConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试无效配置 - Express Mode 与 OpenAI 兼容模式互斥
|
||||
t.Run("invalid config - express mode and openai compatible mode conflict", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(invalidVertexExpressAndOpenAICompatibleConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunVertexExpressModeOnHttpRequestHeadersTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试 Vertex Express Mode 请求头处理(聊天完成接口)
|
||||
t.Run("vertex express mode chat completion request headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 应该返回HeaderStopIteration,因为需要处理请求体
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 验证请求头是否被正确处理
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// 验证Host是否被改为 vertex 域名(Express Mode 使用不带 region 前缀的域名)
|
||||
require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "aiplatform.googleapis.com"), "Host header should be changed to vertex domain without region prefix")
|
||||
|
||||
// 检查是否有相关的处理日志
|
||||
debugLogs := host.GetDebugLogs()
|
||||
hasVertexLogs := false
|
||||
for _, log := range debugLogs {
|
||||
if strings.Contains(log, "vertex") {
|
||||
hasVertexLogs = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasVertexLogs, "Should have vertex processing logs")
|
||||
})
|
||||
|
||||
// 测试 Vertex Express Mode 请求头处理(嵌入接口)
|
||||
t.Run("vertex express mode embeddings request headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/embeddings"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 验证嵌入接口的请求头处理
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// 验证Host转换
|
||||
require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "aiplatform.googleapis.com"), "Host header should be changed to vertex domain")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunVertexExpressModeOnHttpRequestBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试 Vertex Express Mode 请求体处理(聊天完成接口)
|
||||
t.Run("vertex express mode chat completion request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}]}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// Express Mode 不需要暂停等待 OAuth token
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证请求体是否被正确处理
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
// 验证请求体被转换为 Vertex 格式
|
||||
require.Contains(t, string(processedBody), "contents", "Request should be converted to vertex format")
|
||||
require.Contains(t, string(processedBody), "generationConfig", "Request should contain vertex generation config")
|
||||
|
||||
// 验证路径包含 API Key
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathHeader := ""
|
||||
for _, header := range requestHeaders {
|
||||
if header[0] == ":path" {
|
||||
pathHeader = header[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key as query parameter")
|
||||
require.Contains(t, pathHeader, "/v1/publishers/google/models/", "Path should use Express Mode format without project/location")
|
||||
|
||||
// 验证没有 Authorization header(Express Mode 使用 URL 参数)
|
||||
hasAuthHeader := false
|
||||
for _, header := range requestHeaders {
|
||||
if header[0] == "Authorization" && header[1] != "" {
|
||||
hasAuthHeader = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.False(t, hasAuthHeader, "Authorization header should be removed in Express Mode")
|
||||
|
||||
// 检查是否有相关的处理日志
|
||||
debugLogs := host.GetDebugLogs()
|
||||
hasVertexLogs := false
|
||||
for _, log := range debugLogs {
|
||||
if strings.Contains(log, "vertex") {
|
||||
hasVertexLogs = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasVertexLogs, "Should have vertex processing logs")
|
||||
})
|
||||
|
||||
// 测试 Vertex Express Mode 请求体处理(嵌入接口)
|
||||
t.Run("vertex express mode embeddings request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/embeddings"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"text-embedding-001","input":"test text"}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证嵌入接口的请求体处理
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
// 验证请求体被转换为 Vertex 格式
|
||||
require.Contains(t, string(processedBody), "instances", "Request should be converted to vertex format")
|
||||
|
||||
// 验证路径包含 API Key
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathHeader := ""
|
||||
for _, header := range requestHeaders {
|
||||
if header[0] == ":path" {
|
||||
pathHeader = header[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key as query parameter")
|
||||
})
|
||||
|
||||
// 测试 Vertex Express Mode 请求体处理(流式请求)
|
||||
t.Run("vertex express mode streaming request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置流式请求体
|
||||
requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}],"stream":true}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证路径包含流式 action
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathHeader := ""
|
||||
for _, header := range requestHeaders {
|
||||
if header[0] == ":path" {
|
||||
pathHeader = header[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Contains(t, pathHeader, "streamGenerateContent", "Path should contain streaming action")
|
||||
require.Contains(t, pathHeader, "key=test-api-key-123456789", "Path should contain API key")
|
||||
})
|
||||
|
||||
// 测试 Vertex Express Mode 请求体处理(含模型映射)
|
||||
t.Run("vertex express mode with model mapping request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeWithModelMappingConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体(使用 OpenAI 模型名)
|
||||
requestBody := `{"model":"gpt-4","messages":[{"role":"user","content":"test"}]}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证路径包含映射后的模型名
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathHeader := ""
|
||||
for _, header := range requestHeaders {
|
||||
if header[0] == ":path" {
|
||||
pathHeader = header[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Contains(t, pathHeader, "gemini-2.5-flash", "Path should contain mapped model name")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunVertexExpressModeOnHttpResponseBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试 Vertex Express Mode 响应体处理(聊天完成接口)
|
||||
t.Run("vertex express mode chat completion response body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}]}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 设置响应属性,确保IsResponseFromUpstream()返回true
|
||||
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||
|
||||
// 设置响应头
|
||||
responseHeaders := [][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
}
|
||||
host.CallOnHttpResponseHeaders(responseHeaders)
|
||||
|
||||
// 设置响应体(Vertex 格式)
|
||||
responseBody := `{
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"parts": [{
|
||||
"text": "Hello! How can I help you today?"
|
||||
}]
|
||||
},
|
||||
"finishReason": "STOP",
|
||||
"index": 0
|
||||
}],
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 9,
|
||||
"candidatesTokenCount": 12,
|
||||
"totalTokenCount": 21
|
||||
}
|
||||
}`
|
||||
action := host.CallOnHttpResponseBody([]byte(responseBody))
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证响应体是否被正确处理
|
||||
processedResponseBody := host.GetResponseBody()
|
||||
require.NotNil(t, processedResponseBody)
|
||||
|
||||
// 验证响应体内容(转换为OpenAI格式)
|
||||
responseStr := string(processedResponseBody)
|
||||
|
||||
// 检查响应体是否被转换
|
||||
if strings.Contains(responseStr, "chat.completion") {
|
||||
require.Contains(t, responseStr, "assistant", "Response should contain assistant role")
|
||||
require.Contains(t, responseStr, "usage", "Response should contain usage information")
|
||||
}
|
||||
|
||||
// 检查是否有相关的处理日志
|
||||
debugLogs := host.GetDebugLogs()
|
||||
hasResponseBodyLogs := false
|
||||
for _, log := range debugLogs {
|
||||
if strings.Contains(log, "response") || strings.Contains(log, "body") || strings.Contains(log, "vertex") {
|
||||
hasResponseBodyLogs = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasResponseBodyLogs, "Should have response body processing logs")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunVertexOpenAICompatibleModeOnHttpRequestHeadersTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试 Vertex OpenAI 兼容模式请求头处理
|
||||
t.Run("vertex openai compatible mode request headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 应该返回HeaderStopIteration,因为需要处理请求体
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 验证请求头是否被正确处理
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// 验证Host是否被改为 vertex 域名(带 region 前缀)
|
||||
require.True(t, test.HasHeaderWithValue(requestHeaders, ":authority", "us-central1-aiplatform.googleapis.com"), "Host header should be changed to vertex domain with region prefix")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunVertexOpenAICompatibleModeOnHttpRequestBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试 Vertex OpenAI 兼容模式请求体处理(不转换格式,保持 OpenAI 格式)
|
||||
t.Run("vertex openai compatible mode request body - no format conversion", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体(OpenAI 格式)
|
||||
requestBody := `{"model":"gemini-2.0-flash","messages":[{"role":"user","content":"test"}]}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// OpenAI 兼容模式需要等待 OAuth token,所以返回 ActionPause
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// 验证请求体保持 OpenAI 格式(不转换为 Vertex 原生格式)
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
// OpenAI 兼容模式应该保持 messages 字段,而不是转换为 contents
|
||||
require.Contains(t, string(processedBody), "messages", "Request should keep OpenAI format with messages field")
|
||||
require.NotContains(t, string(processedBody), "contents", "Request should NOT be converted to vertex native format")
|
||||
|
||||
// 验证路径为 OpenAI 兼容端点
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathHeader := ""
|
||||
for _, header := range requestHeaders {
|
||||
if header[0] == ":path" {
|
||||
pathHeader = header[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Contains(t, pathHeader, "/v1beta1/projects/", "Path should use OpenAI compatible endpoint format")
|
||||
require.Contains(t, pathHeader, "/endpoints/openapi/chat/completions", "Path should contain openapi chat completions endpoint")
|
||||
})
|
||||
|
||||
// 测试 Vertex OpenAI 兼容模式请求体处理(含模型映射)
|
||||
t.Run("vertex openai compatible mode with model mapping", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexOpenAICompatibleModeWithModelMappingConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体(使用 OpenAI 模型名)
|
||||
requestBody := `{"model":"gpt-4","messages":[{"role":"user","content":"test"}]}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// 验证请求体中的模型名被映射
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
// 模型名应该被映射为 gemini-2.0-flash
|
||||
require.Contains(t, string(processedBody), "gemini-2.0-flash", "Model name should be mapped to gemini-2.0-flash")
|
||||
})
|
||||
|
||||
// 测试 Vertex OpenAI 兼容模式不支持 Embeddings API
|
||||
t.Run("vertex openai compatible mode - embeddings not supported", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/embeddings"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"text-embedding-001","input":"test text"}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// OpenAI 兼容模式只支持 chat completions,embeddings 应该返回错误
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunVertexExpressModeOnStreamingResponseBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试 Vertex Express Mode 流式响应处理
|
||||
t.Run("vertex express mode streaming response body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexExpressModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置流式请求体
|
||||
requestBody := `{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"test"}],"stream":true}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 设置流式响应头
|
||||
responseHeaders := [][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "text/event-stream"},
|
||||
}
|
||||
host.CallOnHttpResponseHeaders(responseHeaders)
|
||||
|
||||
// 模拟流式响应体
|
||||
chunk1 := `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"},"finishReason":"","index":0}],"usageMetadata":{"promptTokenCount":9,"candidatesTokenCount":5,"totalTokenCount":14}}`
|
||||
chunk2 := `data: {"candidates":[{"content":{"parts":[{"text":"Hello! How can I help you today?"}],"role":"model"},"finishReason":"STOP","index":0}],"usageMetadata":{"promptTokenCount":9,"candidatesTokenCount":12,"totalTokenCount":21}}`
|
||||
|
||||
// 处理流式响应体
|
||||
action1 := host.CallOnHttpStreamingResponseBody([]byte(chunk1), false)
|
||||
require.Equal(t, types.ActionContinue, action1)
|
||||
|
||||
action2 := host.CallOnHttpStreamingResponseBody([]byte(chunk2), true)
|
||||
require.Equal(t, types.ActionContinue, action2)
|
||||
|
||||
// 验证流式响应处理
|
||||
debugLogs := host.GetDebugLogs()
|
||||
hasStreamingLogs := false
|
||||
for _, log := range debugLogs {
|
||||
if strings.Contains(log, "streaming") || strings.Contains(log, "chunk") || strings.Contains(log, "vertex") {
|
||||
hasStreamingLogs = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasStreamingLogs, "Should have streaming response processing logs")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunVertexOpenAICompatibleModeOnHttpResponseBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试 Vertex OpenAI 兼容模式响应体处理(直接透传,不转换格式)
|
||||
t.Run("vertex openai compatible mode response body - passthrough", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"gemini-2.0-flash","messages":[{"role":"user","content":"test"}]}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 设置响应属性,确保IsResponseFromUpstream()返回true
|
||||
host.SetProperty([]string{"response", "code_details"}, []byte("via_upstream"))
|
||||
|
||||
// 设置响应头
|
||||
responseHeaders := [][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
}
|
||||
host.CallOnHttpResponseHeaders(responseHeaders)
|
||||
|
||||
// 设置响应体(OpenAI 格式 - 因为 Vertex AI OpenAI-compatible API 返回的就是 OpenAI 格式)
|
||||
responseBody := `{
|
||||
"id": "chatcmpl-abc123",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I help you today?"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"created": 1729986750,
|
||||
"model": "gemini-2.0-flash",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"prompt_tokens": 9,
|
||||
"completion_tokens": 12,
|
||||
"total_tokens": 21
|
||||
}
|
||||
}`
|
||||
action := host.CallOnHttpResponseBody([]byte(responseBody))
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证响应体被直接透传(不进行格式转换)
|
||||
processedResponseBody := host.GetResponseBody()
|
||||
require.NotNil(t, processedResponseBody)
|
||||
|
||||
// 响应应该保持原样
|
||||
responseStr := string(processedResponseBody)
|
||||
require.Contains(t, responseStr, "chatcmpl-abc123", "Response should be passed through unchanged")
|
||||
require.Contains(t, responseStr, "chat.completion", "Response should contain original object type")
|
||||
})
|
||||
|
||||
// 测试 Vertex OpenAI 兼容模式流式响应处理(直接透传)
|
||||
t.Run("vertex openai compatible mode streaming response - passthrough", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置流式请求体
|
||||
requestBody := `{"model":"gemini-2.0-flash","messages":[{"role":"user","content":"test"}],"stream":true}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 设置流式响应头
|
||||
responseHeaders := [][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "text/event-stream"},
|
||||
}
|
||||
host.CallOnHttpResponseHeaders(responseHeaders)
|
||||
|
||||
// 模拟 OpenAI 格式的流式响应(Vertex AI OpenAI-compatible API 返回)
|
||||
chunk1 := `data: {"id":"chatcmpl-abc123","object":"chat.completion.chunk","created":1729986750,"model":"gemini-2.0-flash","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}`
|
||||
chunk2 := `data: {"id":"chatcmpl-abc123","object":"chat.completion.chunk","created":1729986750,"model":"gemini-2.0-flash","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":"stop"}]}`
|
||||
|
||||
// 处理流式响应体 - 应该直接透传
|
||||
action1 := host.CallOnHttpStreamingResponseBody([]byte(chunk1), false)
|
||||
require.Equal(t, types.ActionContinue, action1)
|
||||
|
||||
action2 := host.CallOnHttpStreamingResponseBody([]byte(chunk2), true)
|
||||
require.Equal(t, types.ActionContinue, action2)
|
||||
})
|
||||
|
||||
// 测试 Vertex OpenAI 兼容模式流式响应处理(Unicode 转义解码)
|
||||
t.Run("vertex openai compatible mode streaming response - unicode escape decoding", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置流式请求体
|
||||
requestBody := `{"model":"gemini-2.0-flash","messages":[{"role":"user","content":"test"}],"stream":true}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 设置流式响应头
|
||||
responseHeaders := [][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "text/event-stream"},
|
||||
}
|
||||
host.CallOnHttpResponseHeaders(responseHeaders)
|
||||
|
||||
// 模拟带有 Unicode 转义的流式响应(Vertex AI OpenAI-compatible API 可能返回的格式)
|
||||
// \u4e2d\u6587 = 中文
|
||||
chunkWithUnicode := `data: {"id":"chatcmpl-abc123","object":"chat.completion.chunk","created":1729986750,"model":"gemini-2.0-flash","choices":[{"index":0,"delta":{"role":"assistant","content":"\u4e2d\u6587\u6d4b\u8bd5"},"finish_reason":null}]}`
|
||||
|
||||
// 处理流式响应体 - 应该解码 Unicode 转义
|
||||
action := host.CallOnHttpStreamingResponseBody([]byte(chunkWithUnicode), false)
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证响应体中的 Unicode 转义已被解码
|
||||
responseBody := host.GetResponseBody()
|
||||
require.NotNil(t, responseBody)
|
||||
|
||||
responseStr := string(responseBody)
|
||||
// 应该包含解码后的中文字符,而不是 \uXXXX 转义序列
|
||||
require.Contains(t, responseStr, "中文测试", "Unicode escapes should be decoded to Chinese characters")
|
||||
require.NotContains(t, responseStr, `\u4e2d`, "Should not contain Unicode escape sequences")
|
||||
})
|
||||
|
||||
// 测试 Vertex OpenAI 兼容模式非流式响应处理(Unicode 转义解码)
|
||||
t.Run("vertex openai compatible mode response body - unicode escape decoding", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(vertexOpenAICompatibleModeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"gemini-2.0-flash","messages":[{"role":"user","content":"test"}]}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 设置响应头
|
||||
responseHeaders := [][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
}
|
||||
host.CallOnHttpResponseHeaders(responseHeaders)
|
||||
|
||||
// 模拟带有 Unicode 转义的响应体
|
||||
// \u76c8\u5229\u80fd\u529b = 盈利能力
|
||||
responseBodyWithUnicode := `{"id":"chatcmpl-abc123","object":"chat.completion","created":1729986750,"model":"gemini-2.0-flash","choices":[{"index":0,"message":{"role":"assistant","content":"\u76c8\u5229\u80fd\u529b\u5206\u6790"},"finish_reason":"stop"}]}`
|
||||
|
||||
// 处理响应体 - 应该解码 Unicode 转义
|
||||
action := host.CallOnHttpResponseBody([]byte(responseBodyWithUnicode))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证响应体中的 Unicode 转义已被解码
|
||||
processedResponseBody := host.GetResponseBody()
|
||||
require.NotNil(t, processedResponseBody)
|
||||
|
||||
responseStr := string(processedResponseBody)
|
||||
// 应该包含解码后的中文字符
|
||||
require.Contains(t, responseStr, "盈利能力分析", "Unicode escapes should be decoded to Chinese characters")
|
||||
require.NotContains(t, responseStr, `\u76c8`, "Should not contain Unicode escape sequences")
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -93,6 +93,19 @@ func MapRequestPathByCapability(apiName string, originPath string, mapping map[s
|
||||
if !exist {
|
||||
return ""
|
||||
}
|
||||
mappedPathOnly := mappedPath
|
||||
mappedQuery := ""
|
||||
if queryIndex := strings.Index(mappedPathOnly, "?"); queryIndex >= 0 {
|
||||
mappedPathOnly = mappedPathOnly[:queryIndex]
|
||||
mappedQuery = mappedPath[queryIndex:]
|
||||
}
|
||||
// 将查询字符串从原始路径中剥离,避免干扰正则匹配 video_id 等占位符
|
||||
pathOnly := originPath
|
||||
query := ""
|
||||
if queryIndex := strings.Index(originPath, "?"); queryIndex >= 0 {
|
||||
pathOnly = originPath[:queryIndex]
|
||||
query = originPath[queryIndex:]
|
||||
}
|
||||
if strings.Contains(mappedPath, "{") && strings.Contains(mappedPath, "}") {
|
||||
replacements := []struct {
|
||||
regx *regexp.Regexp
|
||||
@@ -108,8 +121,8 @@ func MapRequestPathByCapability(apiName string, originPath string, mapping map[s
|
||||
}
|
||||
|
||||
for _, r := range replacements {
|
||||
if r.regx.MatchString(originPath) {
|
||||
subMatch := r.regx.FindStringSubmatch(originPath)
|
||||
if r.regx.MatchString(pathOnly) {
|
||||
subMatch := r.regx.FindStringSubmatch(pathOnly)
|
||||
if subMatch == nil {
|
||||
continue
|
||||
}
|
||||
@@ -118,12 +131,25 @@ func MapRequestPathByCapability(apiName string, originPath string, mapping map[s
|
||||
continue
|
||||
}
|
||||
id := subMatch[index]
|
||||
mappedPath = r.regx.ReplaceAllStringFunc(mappedPath, func(s string) string {
|
||||
mappedPathOnly = r.regx.ReplaceAllStringFunc(mappedPathOnly, func(s string) string {
|
||||
return strings.Replace(s, "{"+r.key+"}", id, 1)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
if mappedQuery != "" {
|
||||
mappedPath = mappedPathOnly + mappedQuery
|
||||
} else {
|
||||
mappedPath = mappedPathOnly
|
||||
}
|
||||
if query != "" {
|
||||
// 保留原始查询参数,例如 variant=thumbnail
|
||||
if strings.Contains(mappedPath, "?") {
|
||||
mappedPath = mappedPath + "&" + strings.TrimPrefix(query, "?")
|
||||
} else {
|
||||
mappedPath += query
|
||||
}
|
||||
}
|
||||
return mappedPath
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package util
|
||||
|
||||
import "regexp"
|
||||
import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func StripPrefix(s string, prefix string) string {
|
||||
if len(prefix) != 0 && len(s) >= len(prefix) && s[0:len(prefix)] == prefix {
|
||||
@@ -18,3 +22,43 @@ func MatchStatus(status string, patterns []string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// unicodeEscapeRegex matches Unicode escape sequences like \uXXXX
|
||||
var unicodeEscapeRegex = regexp.MustCompile(`\\u([0-9a-fA-F]{4})`)
|
||||
|
||||
// DecodeUnicodeEscapes decodes Unicode escape sequences (\uXXXX) in a string to UTF-8 characters.
|
||||
// This is useful when a JSON response contains ASCII-safe encoded non-ASCII characters.
|
||||
func DecodeUnicodeEscapes(input []byte) []byte {
|
||||
result := unicodeEscapeRegex.ReplaceAllFunc(input, func(match []byte) []byte {
|
||||
// match is like \uXXXX, extract the hex part (XXXX)
|
||||
hexStr := string(match[2:6])
|
||||
codePoint, err := strconv.ParseInt(hexStr, 16, 32)
|
||||
if err != nil {
|
||||
return match // return original if parse fails
|
||||
}
|
||||
return []byte(string(rune(codePoint)))
|
||||
})
|
||||
return result
|
||||
}
|
||||
|
||||
// DecodeUnicodeEscapesInSSE decodes Unicode escape sequences in SSE formatted data.
|
||||
// It processes each line that starts with "data: " and decodes Unicode escapes in the JSON payload.
|
||||
func DecodeUnicodeEscapesInSSE(input []byte) []byte {
|
||||
lines := strings.Split(string(input), "\n")
|
||||
var result strings.Builder
|
||||
for i, line := range lines {
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
// Decode Unicode escapes in the JSON payload
|
||||
jsonData := line[6:]
|
||||
decodedData := DecodeUnicodeEscapes([]byte(jsonData))
|
||||
result.WriteString("data: ")
|
||||
result.Write(decodedData)
|
||||
} else {
|
||||
result.WriteString(line)
|
||||
}
|
||||
if i < len(lines)-1 {
|
||||
result.WriteString("\n")
|
||||
}
|
||||
}
|
||||
return []byte(result.String())
|
||||
}
|
||||
|
||||
108
plugins/wasm-go/extensions/ai-proxy/util/string_test.go
Normal file
108
plugins/wasm-go/extensions/ai-proxy/util/string_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestDecodeUnicodeEscapes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Chinese characters",
|
||||
input: `\u4e2d\u6587\u6d4b\u8bd5`,
|
||||
expected: `中文测试`,
|
||||
},
|
||||
{
|
||||
name: "Mixed content",
|
||||
input: `Hello \u4e16\u754c World`,
|
||||
expected: `Hello 世界 World`,
|
||||
},
|
||||
{
|
||||
name: "No escape sequences",
|
||||
input: `Hello World`,
|
||||
expected: `Hello World`,
|
||||
},
|
||||
{
|
||||
name: "JSON with Unicode escapes",
|
||||
input: `{"content":"\u76c8\u5229\u80fd\u529b"}`,
|
||||
expected: `{"content":"盈利能力"}`,
|
||||
},
|
||||
{
|
||||
name: "Full width parentheses",
|
||||
input: `\uff08\u76c8\u5229\uff09`,
|
||||
expected: `(盈利)`,
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
input: ``,
|
||||
expected: ``,
|
||||
},
|
||||
{
|
||||
name: "Invalid escape sequence (not modified)",
|
||||
input: `\u00GG`,
|
||||
expected: `\u00GG`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := DecodeUnicodeEscapes([]byte(tt.input))
|
||||
assert.Equal(t, tt.expected, string(result))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeUnicodeEscapesInSSE(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "SSE data with Unicode escapes",
|
||||
input: `data: {"choices":[{"delta":{"content":"\u4e2d\u6587"}}]}
|
||||
|
||||
`,
|
||||
expected: `data: {"choices":[{"delta":{"content":"中文"}}]}
|
||||
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "Multiple SSE data lines",
|
||||
input: `data: {"content":"\u4e2d\u6587"}
|
||||
data: {"content":"\u82f1\u6587"}
|
||||
data: [DONE]
|
||||
`,
|
||||
expected: `data: {"content":"中文"}
|
||||
data: {"content":"英文"}
|
||||
data: [DONE]
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "Non-data lines unchanged",
|
||||
input: ": comment\nevent: message\ndata: test\n",
|
||||
expected: ": comment\nevent: message\ndata: test\n",
|
||||
},
|
||||
{
|
||||
name: "Real Vertex AI response format",
|
||||
input: `data: {"choices":[{"delta":{"content":"\uff08\u76c8\u5229\u80fd\u529b\uff09","role":"assistant"},"index":0}],"created":1768307454,"id":"test","model":"gemini","object":"chat.completion.chunk"}
|
||||
|
||||
`,
|
||||
expected: `data: {"choices":[{"delta":{"content":"(盈利能力)","role":"assistant"},"index":0}],"created":1768307454,"id":"test","model":"gemini","object":"chat.completion.chunk"}
|
||||
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := DecodeUnicodeEscapesInSSE([]byte(tt.input))
|
||||
assert.Equal(t, tt.expected, string(result))
|
||||
})
|
||||
}
|
||||
}
|
||||
585
plugins/wasm-go/extensions/ai-security-guard/config/config.go
Normal file
585
plugins/wasm-go/extensions/ai-security-guard/config/config.go
Normal file
@@ -0,0 +1,585 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
MaxRisk = "max"
|
||||
HighRisk = "high"
|
||||
MediumRisk = "medium"
|
||||
LowRisk = "low"
|
||||
NoRisk = "none"
|
||||
|
||||
S4Sensitive = "s4"
|
||||
S3Sensitive = "s3"
|
||||
S2Sensitive = "s2"
|
||||
S1Sensitive = "s1"
|
||||
NoSensitive = "s0"
|
||||
|
||||
ContentModerationType = "contentModeration"
|
||||
PromptAttackType = "promptAttack"
|
||||
SensitiveDataType = "sensitiveData"
|
||||
MaliciousUrlDataType = "maliciousUrl"
|
||||
ModelHallucinationDataType = "modelHallucination"
|
||||
|
||||
// Default configurations
|
||||
OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
|
||||
OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
|
||||
OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
|
||||
OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]`
|
||||
|
||||
DefaultDenyCode = 200
|
||||
DefaultDenyMessage = "很抱歉,我无法回答您的问题"
|
||||
DefaultTimeout = 2000
|
||||
|
||||
AliyunUserAgent = "CIPFrom/AIGateway"
|
||||
LengthLimit = 1800
|
||||
|
||||
DefaultRequestCheckService = "llm_query_moderation"
|
||||
DefaultResponseCheckService = "llm_response_moderation"
|
||||
DefaultRequestJsonPath = "messages.@reverse.0.content"
|
||||
DefaultResponseJsonPath = "choices.0.message.content"
|
||||
DefaultStreamingResponseJsonPath = "choices.0.delta.content"
|
||||
|
||||
// Actions
|
||||
MultiModalGuard = "MultiModalGuard"
|
||||
MultiModalGuardForBase64 = "MultiModalGuardForBase64"
|
||||
TextModerationPlus = "TextModerationPlus"
|
||||
|
||||
// Services
|
||||
DefaultMultiModalGuardTextInputCheckService = "query_security_check"
|
||||
DefaultMultiModalGuardTextOutputCheckService = "response_security_check"
|
||||
DefaultMultiModalGuardImageInputCheckService = "img_query_security_check"
|
||||
|
||||
DefaultTextModerationPlusTextInputCheckService = "llm_query_moderation"
|
||||
DefaultTextModerationPlusTextOutputCheckService = "llm_response_moderation"
|
||||
)
|
||||
|
||||
// api types
|
||||
|
||||
const (
|
||||
ApiTextGeneration = "text_generation"
|
||||
ApiImageGeneration = "image_generation"
|
||||
)
|
||||
|
||||
// provider types
|
||||
const (
|
||||
ProviderOpenAI = "openai"
|
||||
ProviderQwen = "qwen"
|
||||
ProviderComfyUI = "comfyui"
|
||||
)
|
||||
|
||||
type Response struct {
|
||||
Code int `json:"Code"`
|
||||
Message string `json:"Message"`
|
||||
RequestId string `json:"RequestId"`
|
||||
Data Data `json:"Data"`
|
||||
}
|
||||
|
||||
type Data struct {
|
||||
RiskLevel string `json:"RiskLevel,omitempty"`
|
||||
AttackLevel string `json:"AttackLevel,omitempty"`
|
||||
Result []Result `json:"Result,omitempty"`
|
||||
Advice []Advice `json:"Advice,omitempty"`
|
||||
Detail []Detail `json:"Detail,omitempty"`
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
RiskWords string `json:"RiskWords,omitempty"`
|
||||
Description string `json:"Description,omitempty"`
|
||||
Confidence float64 `json:"Confidence,omitempty"`
|
||||
Label string `json:"Label,omitempty"`
|
||||
}
|
||||
|
||||
type Advice struct {
|
||||
Answer string `json:"Answer,omitempty"`
|
||||
HitLabel string `json:"HitLabel,omitempty"`
|
||||
HitLibName string `json:"HitLibName,omitempty"`
|
||||
}
|
||||
|
||||
type Detail struct {
|
||||
Suggestion string `json:"Suggestion,omitempty"`
|
||||
Type string `json:"Type,omitempty"`
|
||||
Level string `json:"Level,omitempty"`
|
||||
}
|
||||
|
||||
type Matcher struct {
|
||||
Exact string
|
||||
Prefix string
|
||||
Re *regexp.Regexp
|
||||
}
|
||||
|
||||
func (m *Matcher) match(consumer string) bool {
|
||||
if m.Exact != "" {
|
||||
return consumer == m.Exact
|
||||
} else if m.Prefix != "" {
|
||||
return strings.HasPrefix(consumer, m.Prefix)
|
||||
} else if m.Re != nil {
|
||||
return m.Re.MatchString(consumer)
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
type AISecurityConfig struct {
|
||||
Client wrapper.HttpClient
|
||||
Host string
|
||||
AK string
|
||||
SK string
|
||||
Token string
|
||||
Action string
|
||||
CheckRequest bool
|
||||
CheckRequestImage bool
|
||||
RequestCheckService string
|
||||
RequestImageCheckService string
|
||||
RequestContentJsonPath string
|
||||
CheckResponse bool
|
||||
ResponseCheckService string
|
||||
ResponseImageCheckService string
|
||||
ResponseContentJsonPath string
|
||||
ResponseStreamContentJsonPath string
|
||||
DenyCode int64
|
||||
DenyMessage string
|
||||
ProtocolOriginal bool
|
||||
RiskLevelBar string
|
||||
ContentModerationLevelBar string
|
||||
PromptAttackLevelBar string
|
||||
SensitiveDataLevelBar string
|
||||
MaliciousUrlLevelBar string
|
||||
ModelHallucinationLevelBar string
|
||||
Timeout uint32
|
||||
BufferLimit int
|
||||
Metrics map[string]proxywasm.MetricCounter
|
||||
ConsumerRequestCheckService []map[string]interface{}
|
||||
ConsumerResponseCheckService []map[string]interface{}
|
||||
ConsumerRiskLevel []map[string]interface{}
|
||||
// text_generation, image_generation, etc.
|
||||
ApiType string
|
||||
// openai, qwen, comfyui, etc.
|
||||
ProviderType string
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) Parse(json gjson.Result) error {
|
||||
serviceName := json.Get("serviceName").String()
|
||||
servicePort := json.Get("servicePort").Int()
|
||||
serviceHost := json.Get("serviceHost").String()
|
||||
config.Host = serviceHost
|
||||
if serviceName == "" || servicePort == 0 || serviceHost == "" {
|
||||
return errors.New("invalid service config")
|
||||
}
|
||||
config.AK = json.Get("accessKey").String()
|
||||
config.SK = json.Get("secretKey").String()
|
||||
if config.AK == "" || config.SK == "" {
|
||||
return errors.New("invalid AK/SK config")
|
||||
}
|
||||
config.Token = json.Get("securityToken").String()
|
||||
// set action
|
||||
if obj := json.Get("action"); obj.Exists() {
|
||||
config.Action = json.Get("action").String()
|
||||
} else {
|
||||
config.Action = TextModerationPlus
|
||||
}
|
||||
// set default values
|
||||
config.SetDefaultValues()
|
||||
// set values
|
||||
if obj := json.Get("riskLevelBar"); obj.Exists() {
|
||||
config.RiskLevelBar = obj.String()
|
||||
}
|
||||
if obj := json.Get("requestCheckService"); obj.Exists() {
|
||||
config.RequestCheckService = obj.String()
|
||||
}
|
||||
if obj := json.Get("requestImageCheckService"); obj.Exists() {
|
||||
config.RequestImageCheckService = obj.String()
|
||||
}
|
||||
if obj := json.Get("responseCheckService"); obj.Exists() {
|
||||
config.ResponseCheckService = obj.String()
|
||||
}
|
||||
if obj := json.Get("responseImageCheckService"); obj.Exists() {
|
||||
config.ResponseImageCheckService = obj.String()
|
||||
}
|
||||
config.CheckRequest = json.Get("checkRequest").Bool()
|
||||
config.CheckRequestImage = json.Get("checkRequestImage").Bool()
|
||||
config.CheckResponse = json.Get("checkResponse").Bool()
|
||||
config.ProtocolOriginal = json.Get("protocol").String() == "original"
|
||||
config.DenyMessage = json.Get("denyMessage").String()
|
||||
if obj := json.Get("denyCode"); obj.Exists() {
|
||||
config.DenyCode = obj.Int()
|
||||
}
|
||||
if obj := json.Get("requestContentJsonPath"); obj.Exists() {
|
||||
config.RequestContentJsonPath = obj.String()
|
||||
}
|
||||
if obj := json.Get("responseContentJsonPath"); obj.Exists() {
|
||||
config.ResponseContentJsonPath = obj.String()
|
||||
}
|
||||
if obj := json.Get("responseStreamContentJsonPath"); obj.Exists() {
|
||||
config.ResponseStreamContentJsonPath = obj.String()
|
||||
}
|
||||
if obj := json.Get("contentModerationLevelBar"); obj.Exists() {
|
||||
config.ContentModerationLevelBar = obj.String()
|
||||
if LevelToInt(config.ContentModerationLevelBar) <= 0 {
|
||||
return errors.New("invalid contentModerationLevelBar, value must be one of [max, high, medium, low]")
|
||||
}
|
||||
}
|
||||
if obj := json.Get("promptAttackLevelBar"); obj.Exists() {
|
||||
config.PromptAttackLevelBar = obj.String()
|
||||
if LevelToInt(config.PromptAttackLevelBar) <= 0 {
|
||||
return errors.New("invalid promptAttackLevelBar, value must be one of [max, high, medium, low]")
|
||||
}
|
||||
}
|
||||
if obj := json.Get("sensitiveDataLevelBar"); obj.Exists() {
|
||||
config.SensitiveDataLevelBar = obj.String()
|
||||
if LevelToInt(config.SensitiveDataLevelBar) <= 0 {
|
||||
return errors.New("invalid sensitiveDataLevelBar, value must be one of [S4, S3, S2, S1]")
|
||||
}
|
||||
}
|
||||
if obj := json.Get("modelHallucinationLevelBar"); obj.Exists() {
|
||||
config.ModelHallucinationLevelBar = obj.String()
|
||||
if LevelToInt(config.ModelHallucinationLevelBar) <= 0 {
|
||||
return errors.New("invalid modelHallucinationLevelBar, value must be one of [max, high, medium, low]")
|
||||
}
|
||||
}
|
||||
if obj := json.Get("maliciousUrlLevelBar"); obj.Exists() {
|
||||
config.MaliciousUrlLevelBar = obj.String()
|
||||
if LevelToInt(config.MaliciousUrlLevelBar) <= 0 {
|
||||
return errors.New("invalid maliciousUrlLevelBar, value must be one of [max, high, medium, low]")
|
||||
}
|
||||
}
|
||||
if obj := json.Get("timeout"); obj.Exists() {
|
||||
config.Timeout = uint32(obj.Int())
|
||||
}
|
||||
if obj := json.Get("bufferLimit"); obj.Exists() {
|
||||
config.BufferLimit = int(obj.Int())
|
||||
}
|
||||
if obj := json.Get("consumerRequestCheckService"); obj.Exists() {
|
||||
for _, item := range json.Get("consumerRequestCheckService").Array() {
|
||||
m := make(map[string]interface{})
|
||||
for k, v := range item.Map() {
|
||||
m[k] = v.Value()
|
||||
}
|
||||
consumerName, ok1 := m["name"]
|
||||
matchType, ok2 := m["matchType"]
|
||||
if !ok1 || !ok2 {
|
||||
continue
|
||||
}
|
||||
switch fmt.Sprint(matchType) {
|
||||
case "exact":
|
||||
m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)}
|
||||
case "prefix":
|
||||
m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)}
|
||||
case "regexp":
|
||||
m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))}
|
||||
}
|
||||
config.ConsumerRequestCheckService = append(config.ConsumerRequestCheckService, m)
|
||||
}
|
||||
}
|
||||
if obj := json.Get("consumerResponseCheckService"); obj.Exists() {
|
||||
for _, item := range json.Get("consumerResponseCheckService").Array() {
|
||||
m := make(map[string]interface{})
|
||||
for k, v := range item.Map() {
|
||||
m[k] = v.Value()
|
||||
}
|
||||
consumerName, ok1 := m["name"]
|
||||
matchType, ok2 := m["matchType"]
|
||||
if !ok1 || !ok2 {
|
||||
continue
|
||||
}
|
||||
switch fmt.Sprint(matchType) {
|
||||
case "exact":
|
||||
m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)}
|
||||
case "prefix":
|
||||
m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)}
|
||||
case "regexp":
|
||||
m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))}
|
||||
}
|
||||
config.ConsumerResponseCheckService = append(config.ConsumerResponseCheckService, m)
|
||||
}
|
||||
}
|
||||
if obj := json.Get("consumerRiskLevel"); obj.Exists() {
|
||||
for _, item := range json.Get("consumerRiskLevel").Array() {
|
||||
m := make(map[string]interface{})
|
||||
for k, v := range item.Map() {
|
||||
m[k] = v.Value()
|
||||
}
|
||||
consumerName, ok1 := m["name"]
|
||||
matchType, ok2 := m["matchType"]
|
||||
if !ok1 || !ok2 {
|
||||
continue
|
||||
}
|
||||
switch fmt.Sprint(matchType) {
|
||||
case "exact":
|
||||
m["matcher"] = Matcher{Exact: fmt.Sprint(consumerName)}
|
||||
case "prefix":
|
||||
m["matcher"] = Matcher{Prefix: fmt.Sprint(consumerName)}
|
||||
case "regexp":
|
||||
m["matcher"] = Matcher{Re: regexp.MustCompile(fmt.Sprint(consumerName))}
|
||||
}
|
||||
config.ConsumerRiskLevel = append(config.ConsumerRiskLevel, m)
|
||||
}
|
||||
}
|
||||
if obj := json.Get("apiType"); obj.Exists() {
|
||||
config.ApiType = obj.String()
|
||||
}
|
||||
if obj := json.Get("providerType"); obj.Exists() {
|
||||
config.ProviderType = obj.String()
|
||||
}
|
||||
config.Client = wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||
FQDN: serviceName,
|
||||
Port: servicePort,
|
||||
Host: serviceHost,
|
||||
})
|
||||
config.Metrics = make(map[string]proxywasm.MetricCounter)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) SetDefaultValues() {
|
||||
switch config.Action {
|
||||
case TextModerationPlus:
|
||||
config.RequestCheckService = DefaultTextModerationPlusTextInputCheckService
|
||||
config.ResponseCheckService = DefaultTextModerationPlusTextOutputCheckService
|
||||
case MultiModalGuard:
|
||||
config.RequestCheckService = DefaultMultiModalGuardTextInputCheckService
|
||||
config.RequestImageCheckService = DefaultMultiModalGuardImageInputCheckService
|
||||
config.ResponseCheckService = DefaultMultiModalGuardTextOutputCheckService
|
||||
}
|
||||
config.RiskLevelBar = HighRisk
|
||||
config.DenyCode = DefaultDenyCode
|
||||
config.RequestContentJsonPath = DefaultRequestJsonPath
|
||||
config.ResponseContentJsonPath = DefaultResponseJsonPath
|
||||
config.ResponseStreamContentJsonPath = DefaultStreamingResponseJsonPath
|
||||
config.ContentModerationLevelBar = MaxRisk
|
||||
config.PromptAttackLevelBar = MaxRisk
|
||||
config.SensitiveDataLevelBar = S4Sensitive
|
||||
config.ModelHallucinationLevelBar = MaxRisk
|
||||
config.MaliciousUrlLevelBar = MaxRisk
|
||||
config.Timeout = DefaultTimeout
|
||||
config.BufferLimit = 1000
|
||||
config.ApiType = ApiTextGeneration
|
||||
config.ProviderType = ProviderOpenAI
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) IncrementCounter(metricName string, inc uint64) {
|
||||
counter, ok := config.Metrics[metricName]
|
||||
if !ok {
|
||||
counter = proxywasm.DefineCounterMetric(metricName)
|
||||
config.Metrics[metricName] = counter
|
||||
}
|
||||
counter.Increment(inc)
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) GetRequestCheckService(consumer string) string {
|
||||
result := config.RequestCheckService
|
||||
for _, obj := range config.ConsumerRequestCheckService {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if requestCheckService, ok := obj["requestCheckService"]; ok {
|
||||
result, _ = requestCheckService.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) GetRequestImageCheckService(consumer string) string {
|
||||
result := config.RequestImageCheckService
|
||||
for _, obj := range config.ConsumerRequestCheckService {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if requestCheckService, ok := obj["requestImageCheckService"]; ok {
|
||||
result, _ = requestCheckService.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) GetResponseCheckService(consumer string) string {
|
||||
result := config.ResponseCheckService
|
||||
for _, obj := range config.ConsumerResponseCheckService {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if responseCheckService, ok := obj["responseCheckService"]; ok {
|
||||
result, _ = responseCheckService.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) GetResponseImageCheckService(consumer string) string {
|
||||
result := config.ResponseImageCheckService
|
||||
for _, obj := range config.ConsumerResponseCheckService {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if responseCheckService, ok := obj["responseImageCheckService"]; ok {
|
||||
result, _ = responseCheckService.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) GetRiskLevelBar(consumer string) string {
|
||||
result := config.RiskLevelBar
|
||||
for _, obj := range config.ConsumerRiskLevel {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if riskLevelBar, ok := obj["riskLevelBar"]; ok {
|
||||
result, _ = riskLevelBar.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) GetContentModerationLevelBar(consumer string) string {
|
||||
result := config.ContentModerationLevelBar
|
||||
for _, obj := range config.ConsumerRiskLevel {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if contentModerationLevelBar, ok := obj["contentModerationLevelBar"]; ok {
|
||||
result, _ = contentModerationLevelBar.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) GetPromptAttackLevelBar(consumer string) string {
|
||||
result := config.PromptAttackLevelBar
|
||||
for _, obj := range config.ConsumerRiskLevel {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if promptAttackLevelBar, ok := obj["promptAttackLevelBar"]; ok {
|
||||
result, _ = promptAttackLevelBar.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) GetSensitiveDataLevelBar(consumer string) string {
|
||||
result := config.SensitiveDataLevelBar
|
||||
for _, obj := range config.ConsumerRiskLevel {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if sensitiveDataLevelBar, ok := obj["sensitiveDataLevelBar"]; ok {
|
||||
result, _ = sensitiveDataLevelBar.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) GetMaliciousUrlLevelBar(consumer string) string {
|
||||
result := config.MaliciousUrlLevelBar
|
||||
for _, obj := range config.ConsumerRiskLevel {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if maliciousUrlLevelBar, ok := obj["maliciousUrlLevelBar"]; ok {
|
||||
result, _ = maliciousUrlLevelBar.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (config *AISecurityConfig) GetModelHallucinationLevelBar(consumer string) string {
|
||||
result := config.ModelHallucinationLevelBar
|
||||
for _, obj := range config.ConsumerRiskLevel {
|
||||
if matcher, ok := obj["matcher"].(Matcher); ok {
|
||||
if matcher.match(consumer) {
|
||||
if modelHallucinationLevelBar, ok := obj["modelHallucinationLevelBar"]; ok {
|
||||
result, _ = modelHallucinationLevelBar.(string)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func LevelToInt(riskLevel string) int {
|
||||
// First check against our defined constants
|
||||
switch strings.ToLower(riskLevel) {
|
||||
case MaxRisk, S4Sensitive:
|
||||
return 4
|
||||
case HighRisk, S3Sensitive:
|
||||
return 3
|
||||
case MediumRisk, S2Sensitive:
|
||||
return 2
|
||||
case LowRisk, S1Sensitive:
|
||||
return 1
|
||||
case NoRisk, NoSensitive:
|
||||
return 0
|
||||
default:
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
func IsRiskLevelAcceptable(action string, data Data, config AISecurityConfig, consumer string) bool {
|
||||
if action == MultiModalGuard || action == MultiModalGuardForBase64 {
|
||||
// Check top-level risk levels for MultiModalGuard
|
||||
if LevelToInt(data.RiskLevel) >= LevelToInt(config.GetContentModerationLevelBar(consumer)) {
|
||||
return false
|
||||
}
|
||||
// Also check AttackLevel for prompt attack detection
|
||||
if LevelToInt(data.AttackLevel) >= LevelToInt(config.GetPromptAttackLevelBar(consumer)) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check detailed results for backward compatibility
|
||||
for _, detail := range data.Detail {
|
||||
switch detail.Type {
|
||||
case ContentModerationType:
|
||||
if LevelToInt(detail.Level) >= LevelToInt(config.GetContentModerationLevelBar(consumer)) {
|
||||
return false
|
||||
}
|
||||
case PromptAttackType:
|
||||
if LevelToInt(detail.Level) >= LevelToInt(config.GetPromptAttackLevelBar(consumer)) {
|
||||
return false
|
||||
}
|
||||
case SensitiveDataType:
|
||||
if LevelToInt(detail.Level) >= LevelToInt(config.GetSensitiveDataLevelBar(consumer)) {
|
||||
return false
|
||||
}
|
||||
case MaliciousUrlDataType:
|
||||
if LevelToInt(detail.Level) >= LevelToInt(config.GetMaliciousUrlLevelBar(consumer)) {
|
||||
return false
|
||||
}
|
||||
case ModelHallucinationDataType:
|
||||
if LevelToInt(detail.Level) >= LevelToInt(config.GetModelHallucinationLevelBar(consumer)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
} else {
|
||||
return LevelToInt(data.RiskLevel) < LevelToInt(config.GetRiskLevelBar(consumer))
|
||||
}
|
||||
}
|
||||
@@ -5,8 +5,8 @@ go 1.24.1
|
||||
toolchain go1.24.4
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251118110253-ba77116c6ddd
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
)
|
||||
@@ -20,5 +20,6 @@ require (
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -4,8 +4,12 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2 h1:NY33OrWCJJ+DFiLc+lsBY4Ywor2Ik61ssk6qkGF8Ypo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20251103120604-77e9cce339d2/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251118110253-ba77116c6ddd h1:acTs8sqXf+qP+IypxFg3cu5Cluj7VT5BI+IDRlY5sag=
|
||||
github.com/higress-group/wasm-go v1.0.7-0.20251118110253-ba77116c6ddd/go.mod h1:uKVYICbRaxTlKqdm8E0dpjbysxM8uCPb9LV26hF3Km8=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
@@ -24,6 +28,8 @@ github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
|
||||
@@ -0,0 +1,249 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"sort"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
ALGORITHM = "ACS3-HMAC-SHA256"
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
httpMethod string
|
||||
canonicalUri string
|
||||
host string
|
||||
xAcsAction string
|
||||
xAcsVersion string
|
||||
headers map[string]string
|
||||
body []byte
|
||||
queryParam map[string]interface{}
|
||||
}
|
||||
|
||||
func newRequest(httpMethod, canonicalUri, host, xAcsAction, xAcsVersion string) *Request {
|
||||
req := &Request{
|
||||
httpMethod: httpMethod,
|
||||
canonicalUri: canonicalUri,
|
||||
host: host,
|
||||
xAcsAction: xAcsAction,
|
||||
xAcsVersion: xAcsVersion,
|
||||
headers: make(map[string]string),
|
||||
queryParam: make(map[string]interface{}),
|
||||
}
|
||||
req.headers["host"] = host
|
||||
req.headers["x-acs-action"] = xAcsAction
|
||||
req.headers["x-acs-version"] = xAcsVersion
|
||||
req.headers["x-acs-date"] = time.Now().UTC().Format(time.RFC3339)
|
||||
req.headers["x-acs-signature-nonce"] = uuid.New().String()
|
||||
return req
|
||||
}
|
||||
|
||||
func getAuthorization(req *Request, AccessKeyId, AccessKeySecret, SecurityToken string) {
|
||||
newQueryParams := make(map[string]interface{})
|
||||
processObject(newQueryParams, "", req.queryParam)
|
||||
req.queryParam = newQueryParams
|
||||
canonicalQueryString := ""
|
||||
keys := maps.Keys(req.queryParam)
|
||||
sort.Strings(keys)
|
||||
for _, k := range keys {
|
||||
v := req.queryParam[k]
|
||||
canonicalQueryString += percentCode(url.QueryEscape(k)) + "=" + percentCode(url.QueryEscape(fmt.Sprintf("%v", v))) + "&"
|
||||
}
|
||||
canonicalQueryString = strings.TrimSuffix(canonicalQueryString, "&")
|
||||
|
||||
var bodyContent []byte
|
||||
if req.body == nil {
|
||||
bodyContent = []byte("")
|
||||
} else {
|
||||
bodyContent = req.body
|
||||
}
|
||||
hashedRequestPayload := sha256Hex(bodyContent)
|
||||
req.headers["x-acs-content-sha256"] = hashedRequestPayload
|
||||
|
||||
if SecurityToken != "" {
|
||||
req.headers["x-acs-security-token"] = SecurityToken
|
||||
}
|
||||
|
||||
canonicalHeaders := ""
|
||||
signedHeaders := ""
|
||||
HeadersKeys := maps.Keys(req.headers)
|
||||
sort.Strings(HeadersKeys)
|
||||
for _, k := range HeadersKeys {
|
||||
lowerKey := strings.ToLower(k)
|
||||
if lowerKey == "host" || strings.HasPrefix(lowerKey, "x-acs-") || lowerKey == "content-type" {
|
||||
canonicalHeaders += lowerKey + ":" + req.headers[k] + "\n"
|
||||
signedHeaders += lowerKey + ";"
|
||||
}
|
||||
}
|
||||
signedHeaders = strings.TrimSuffix(signedHeaders, ";")
|
||||
|
||||
canonicalRequest := req.httpMethod + "\n" + req.canonicalUri + "\n" + canonicalQueryString + "\n" + canonicalHeaders + "\n" + signedHeaders + "\n" + hashedRequestPayload
|
||||
|
||||
hashedCanonicalRequest := sha256Hex([]byte(canonicalRequest))
|
||||
stringToSign := ALGORITHM + "\n" + hashedCanonicalRequest
|
||||
|
||||
byteData, err := hmac256([]byte(AccessKeySecret), stringToSign)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
panic(err)
|
||||
}
|
||||
signature := strings.ToLower(hex.EncodeToString(byteData))
|
||||
|
||||
authorization := ALGORITHM + " Credential=" + AccessKeyId + ",SignedHeaders=" + signedHeaders + ",Signature=" + signature
|
||||
req.headers["Authorization"] = authorization
|
||||
}
|
||||
|
||||
func hmac256(key []byte, toSignString string) ([]byte, error) {
|
||||
h := hmac.New(sha256.New, key)
|
||||
_, err := h.Write([]byte(toSignString))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return h.Sum(nil), nil
|
||||
}
|
||||
|
||||
func sha256Hex(byteArray []byte) string {
|
||||
hash := sha256.New()
|
||||
_, _ = hash.Write(byteArray)
|
||||
hexString := hex.EncodeToString(hash.Sum(nil))
|
||||
return hexString
|
||||
}
|
||||
|
||||
func percentCode(str string) string {
|
||||
str = strings.ReplaceAll(str, "+", "%20")
|
||||
str = strings.ReplaceAll(str, "*", "%2A")
|
||||
str = strings.ReplaceAll(str, "%7E", "~")
|
||||
return str
|
||||
}
|
||||
|
||||
func formDataToString(formData map[string]interface{}) *string {
|
||||
tmp := make(map[string]interface{})
|
||||
processObject(tmp, "", formData)
|
||||
res := ""
|
||||
urlEncoder := url.Values{}
|
||||
for key, value := range tmp {
|
||||
v := fmt.Sprintf("%v", value)
|
||||
urlEncoder.Add(key, v)
|
||||
}
|
||||
res = urlEncoder.Encode()
|
||||
return &res
|
||||
}
|
||||
|
||||
// processObject 递归处理对象,将复杂对象(如Map和List)展开为平面的键值对
|
||||
func processObject(mapResult map[string]interface{}, key string, value interface{}) {
|
||||
if value == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case []interface{}:
|
||||
for i, item := range v {
|
||||
processObject(mapResult, fmt.Sprintf("%s.%d", key, i+1), item)
|
||||
}
|
||||
case map[string]interface{}:
|
||||
for subKey, subValue := range v {
|
||||
processObject(mapResult, fmt.Sprintf("%s.%s", key, subKey), subValue)
|
||||
}
|
||||
default:
|
||||
if strings.HasPrefix(key, ".") {
|
||||
key = key[1:]
|
||||
}
|
||||
if b, ok := v.([]byte); ok {
|
||||
mapResult[key] = string(b)
|
||||
} else {
|
||||
mapResult[key] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateRequestForText(config cfg.AISecurityConfig, checkAction, checkService, text, sessionID string) (path string, headers [][2]string, reqBody []byte) {
|
||||
httpMethod := "POST"
|
||||
canonicalUri := "/"
|
||||
xAcsVersion := "2022-03-02"
|
||||
req := newRequest(httpMethod, canonicalUri, config.Host, checkAction, xAcsVersion)
|
||||
|
||||
req.queryParam["Service"] = checkService
|
||||
|
||||
body := make(map[string]interface{})
|
||||
serviceParameters := make(map[string]interface{})
|
||||
serviceParameters["content"] = text
|
||||
serviceParameters["sessionId"] = sessionID
|
||||
serviceParameters["requestFrom"] = cfg.AliyunUserAgent
|
||||
serviceParametersJSON, _ := json.Marshal(serviceParameters)
|
||||
body["ServiceParameters"] = serviceParametersJSON
|
||||
str := formDataToString(body)
|
||||
req.body = []byte(*str)
|
||||
req.headers["content-type"] = "application/x-www-form-urlencoded"
|
||||
req.headers["User-Agent"] = cfg.AliyunUserAgent
|
||||
|
||||
getAuthorization(req, config.AK, config.SK, config.Token)
|
||||
|
||||
q := url.Values{}
|
||||
keys := maps.Keys(req.queryParam)
|
||||
sort.Strings(keys)
|
||||
for _, k := range keys {
|
||||
v := req.queryParam[k]
|
||||
q.Set(k, fmt.Sprintf("%v", v))
|
||||
}
|
||||
for k, v := range req.headers {
|
||||
if k != "host" {
|
||||
headers = append(headers, [2]string{k, v})
|
||||
}
|
||||
}
|
||||
return "?" + q.Encode(), headers, req.body
|
||||
}
|
||||
|
||||
func GenerateRequestForImage(config cfg.AISecurityConfig, checkAction, checkService, imgUrl, imgBase64 string) (path string, headers [][2]string, reqBody []byte) {
|
||||
httpMethod := "POST"
|
||||
canonicalUri := "/"
|
||||
xAcsVersion := "2022-03-02"
|
||||
req := newRequest(httpMethod, canonicalUri, config.Host, checkAction, xAcsVersion)
|
||||
|
||||
req.queryParam["Service"] = checkService
|
||||
|
||||
body := make(map[string]interface{})
|
||||
serviceParameters := make(map[string]interface{})
|
||||
if imgUrl != "" {
|
||||
serviceParameters["imageUrls"] = []string{imgUrl}
|
||||
}
|
||||
serviceParameters["requestFrom"] = cfg.AliyunUserAgent
|
||||
serviceParametersJSON, _ := json.Marshal(serviceParameters)
|
||||
body["ServiceParameters"] = serviceParametersJSON
|
||||
if imgBase64 != "" {
|
||||
body["ImageBase64Str"] = imgBase64
|
||||
}
|
||||
str := formDataToString(body)
|
||||
req.body = []byte(*str)
|
||||
req.headers["content-type"] = "application/x-www-form-urlencoded"
|
||||
req.headers["User-Agent"] = cfg.AliyunUserAgent
|
||||
|
||||
getAuthorization(req, config.AK, config.SK, config.Token)
|
||||
|
||||
q := url.Values{}
|
||||
keys := maps.Keys(req.queryParam)
|
||||
sort.Strings(keys)
|
||||
for _, k := range keys {
|
||||
v := req.queryParam[k]
|
||||
q.Set(k, fmt.Sprintf("%v", v))
|
||||
}
|
||||
for k, v := range req.headers {
|
||||
// host will be added by envoy automatically
|
||||
if k != "host" {
|
||||
headers = append(headers, [2]string{k, v})
|
||||
}
|
||||
}
|
||||
return "?" + q.Encode(), headers, req.body
|
||||
}
|
||||
@@ -0,0 +1,634 @@
|
||||
// Copyright (c) 2024 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSha256Hex(t *testing.T) {
|
||||
t.Run("empty input", func(t *testing.T) {
|
||||
result := sha256Hex([]byte(""))
|
||||
// SHA256 of empty string
|
||||
expected := "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
|
||||
require.Equal(t, expected, result)
|
||||
})
|
||||
|
||||
t.Run("simple string", func(t *testing.T) {
|
||||
result := sha256Hex([]byte("hello"))
|
||||
expected := "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
|
||||
require.Equal(t, expected, result)
|
||||
})
|
||||
|
||||
t.Run("unicode string", func(t *testing.T) {
|
||||
result := sha256Hex([]byte("你好"))
|
||||
// Just verify it returns a valid hex string
|
||||
require.Len(t, result, 64)
|
||||
_, err := hex.DecodeString(result)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHmac256(t *testing.T) {
|
||||
t.Run("valid hmac", func(t *testing.T) {
|
||||
key := []byte("test-key")
|
||||
message := "test-message"
|
||||
result, err := hmac256(key, message)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, result, 32) // SHA256 produces 32 bytes
|
||||
})
|
||||
|
||||
t.Run("empty key", func(t *testing.T) {
|
||||
key := []byte("")
|
||||
message := "test-message"
|
||||
result, err := hmac256(key, message)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, result, 32)
|
||||
})
|
||||
|
||||
t.Run("empty message", func(t *testing.T) {
|
||||
key := []byte("test-key")
|
||||
message := ""
|
||||
result, err := hmac256(key, message)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, result, 32)
|
||||
})
|
||||
|
||||
t.Run("verify hmac consistency", func(t *testing.T) {
|
||||
key := []byte("test-key")
|
||||
message := "test-message"
|
||||
result1, err1 := hmac256(key, message)
|
||||
result2, err2 := hmac256(key, message)
|
||||
require.NoError(t, err1)
|
||||
require.NoError(t, err2)
|
||||
require.Equal(t, result1, result2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPercentCode(t *testing.T) {
|
||||
t.Run("replace plus sign", func(t *testing.T) {
|
||||
input := "test+value"
|
||||
result := percentCode(input)
|
||||
require.Equal(t, "test%20value", result)
|
||||
})
|
||||
|
||||
t.Run("replace asterisk", func(t *testing.T) {
|
||||
input := "test*value"
|
||||
result := percentCode(input)
|
||||
require.Equal(t, "test%2Avalue", result)
|
||||
})
|
||||
|
||||
t.Run("replace tilde encoding", func(t *testing.T) {
|
||||
input := "test%7Evalue"
|
||||
result := percentCode(input)
|
||||
require.Equal(t, "test~value", result)
|
||||
})
|
||||
|
||||
t.Run("multiple replacements", func(t *testing.T) {
|
||||
input := "test+value*test%7E"
|
||||
result := percentCode(input)
|
||||
require.Equal(t, "test%20value%2Atest~", result)
|
||||
})
|
||||
|
||||
t.Run("no replacements needed", func(t *testing.T) {
|
||||
input := "test-value"
|
||||
result := percentCode(input)
|
||||
require.Equal(t, "test-value", result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestProcessObject(t *testing.T) {
|
||||
t.Run("simple string value", func(t *testing.T) {
|
||||
result := make(map[string]interface{})
|
||||
processObject(result, "key", "value")
|
||||
require.Equal(t, "value", result["key"])
|
||||
})
|
||||
|
||||
t.Run("simple int value", func(t *testing.T) {
|
||||
result := make(map[string]interface{})
|
||||
processObject(result, "key", 123)
|
||||
require.Equal(t, "123", result["key"])
|
||||
})
|
||||
|
||||
t.Run("nil value", func(t *testing.T) {
|
||||
result := make(map[string]interface{})
|
||||
processObject(result, "key", nil)
|
||||
require.Empty(t, result)
|
||||
})
|
||||
|
||||
t.Run("map value", func(t *testing.T) {
|
||||
result := make(map[string]interface{})
|
||||
input := map[string]interface{}{
|
||||
"subkey1": "value1",
|
||||
"subkey2": "value2",
|
||||
}
|
||||
processObject(result, "key", input)
|
||||
require.Equal(t, "value1", result["key.subkey1"])
|
||||
require.Equal(t, "value2", result["key.subkey2"])
|
||||
})
|
||||
|
||||
t.Run("array value", func(t *testing.T) {
|
||||
result := make(map[string]interface{})
|
||||
input := []interface{}{"item1", "item2", "item3"}
|
||||
processObject(result, "key", input)
|
||||
require.Equal(t, "item1", result["key.1"])
|
||||
require.Equal(t, "item2", result["key.2"])
|
||||
require.Equal(t, "item3", result["key.3"])
|
||||
})
|
||||
|
||||
t.Run("nested map", func(t *testing.T) {
|
||||
result := make(map[string]interface{})
|
||||
input := map[string]interface{}{
|
||||
"level1": map[string]interface{}{
|
||||
"level2": "value",
|
||||
},
|
||||
}
|
||||
processObject(result, "key", input)
|
||||
require.Equal(t, "value", result["key.level1.level2"])
|
||||
})
|
||||
|
||||
t.Run("nested array", func(t *testing.T) {
|
||||
result := make(map[string]interface{})
|
||||
input := []interface{}{
|
||||
[]interface{}{"nested1", "nested2"},
|
||||
}
|
||||
processObject(result, "key", input)
|
||||
require.Equal(t, "nested1", result["key.1.1"])
|
||||
require.Equal(t, "nested2", result["key.1.2"])
|
||||
})
|
||||
|
||||
t.Run("key with leading dot", func(t *testing.T) {
|
||||
result := make(map[string]interface{})
|
||||
processObject(result, ".key", "value")
|
||||
require.Equal(t, "value", result["key"])
|
||||
})
|
||||
|
||||
t.Run("byte array value", func(t *testing.T) {
|
||||
result := make(map[string]interface{})
|
||||
input := []byte("test")
|
||||
processObject(result, "key", input)
|
||||
require.Equal(t, "test", result["key"])
|
||||
})
|
||||
|
||||
t.Run("complex nested structure", func(t *testing.T) {
|
||||
result := make(map[string]interface{})
|
||||
input := map[string]interface{}{
|
||||
"array": []interface{}{
|
||||
map[string]interface{}{
|
||||
"item": "value",
|
||||
},
|
||||
},
|
||||
}
|
||||
processObject(result, "key", input)
|
||||
require.Equal(t, "value", result["key.array.1.item"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestFormDataToString(t *testing.T) {
|
||||
t.Run("simple map", func(t *testing.T) {
|
||||
input := map[string]interface{}{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
}
|
||||
result := formDataToString(input)
|
||||
require.NotNil(t, result)
|
||||
require.Contains(t, *result, "key1=value1")
|
||||
require.Contains(t, *result, "key2=value2")
|
||||
})
|
||||
|
||||
t.Run("map with array", func(t *testing.T) {
|
||||
input := map[string]interface{}{
|
||||
"key": []interface{}{"item1", "item2"},
|
||||
}
|
||||
result := formDataToString(input)
|
||||
require.NotNil(t, result)
|
||||
require.Contains(t, *result, "key.1=item1")
|
||||
require.Contains(t, *result, "key.2=item2")
|
||||
})
|
||||
|
||||
t.Run("map with nested map", func(t *testing.T) {
|
||||
input := map[string]interface{}{
|
||||
"key": map[string]interface{}{
|
||||
"subkey": "value",
|
||||
},
|
||||
}
|
||||
result := formDataToString(input)
|
||||
require.NotNil(t, result)
|
||||
require.Contains(t, *result, "key.subkey=value")
|
||||
})
|
||||
|
||||
t.Run("empty map", func(t *testing.T) {
|
||||
input := map[string]interface{}{}
|
||||
result := formDataToString(input)
|
||||
require.NotNil(t, result)
|
||||
require.Empty(t, *result)
|
||||
})
|
||||
|
||||
t.Run("map with nil value", func(t *testing.T) {
|
||||
input := map[string]interface{}{
|
||||
"key1": "value1",
|
||||
"key2": nil,
|
||||
}
|
||||
result := formDataToString(input)
|
||||
require.NotNil(t, result)
|
||||
require.Contains(t, *result, "key1=value1")
|
||||
require.NotContains(t, *result, "key2")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateRequestForText(t *testing.T) {
|
||||
config := cfg.AISecurityConfig{
|
||||
Host: "security.example.com",
|
||||
AK: "test-ak",
|
||||
SK: "test-sk",
|
||||
Token: "",
|
||||
}
|
||||
|
||||
t.Run("basic text request", func(t *testing.T) {
|
||||
path, headers, body := GenerateRequestForText(
|
||||
config,
|
||||
"TextModerationPlus",
|
||||
"llm_query_moderation",
|
||||
"test content",
|
||||
"test-session-id",
|
||||
)
|
||||
|
||||
require.NotEmpty(t, path)
|
||||
require.True(t, strings.HasPrefix(path, "?"))
|
||||
require.Contains(t, path, "Service=llm_query_moderation")
|
||||
|
||||
require.NotEmpty(t, headers)
|
||||
headerMap := make(map[string]string)
|
||||
for _, h := range headers {
|
||||
headerMap[h[0]] = h[1]
|
||||
}
|
||||
|
||||
require.Equal(t, "TextModerationPlus", headerMap["x-acs-action"])
|
||||
require.Equal(t, "2022-03-02", headerMap["x-acs-version"])
|
||||
require.Equal(t, "application/x-www-form-urlencoded", headerMap["content-type"])
|
||||
require.Equal(t, cfg.AliyunUserAgent, headerMap["User-Agent"])
|
||||
require.Contains(t, headerMap, "Authorization")
|
||||
require.Contains(t, headerMap, "x-acs-date")
|
||||
require.Contains(t, headerMap, "x-acs-signature-nonce")
|
||||
require.Contains(t, headerMap, "x-acs-content-sha256")
|
||||
|
||||
require.NotEmpty(t, body)
|
||||
bodyStr := string(body)
|
||||
require.Contains(t, bodyStr, "ServiceParameters")
|
||||
// Body is URL encoded, so decode it to check content
|
||||
decodedBody, err := url.QueryUnescape(bodyStr)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, decodedBody, "test content")
|
||||
require.Contains(t, decodedBody, "test-session-id")
|
||||
require.Contains(t, decodedBody, cfg.AliyunUserAgent)
|
||||
})
|
||||
|
||||
t.Run("request with security token", func(t *testing.T) {
|
||||
configWithToken := config
|
||||
configWithToken.Token = "test-token"
|
||||
path, headers, body := GenerateRequestForText(
|
||||
configWithToken,
|
||||
"TextModerationPlus",
|
||||
"llm_query_moderation",
|
||||
"test content",
|
||||
"test-session-id",
|
||||
)
|
||||
|
||||
require.NotEmpty(t, path)
|
||||
require.NotEmpty(t, headers)
|
||||
headerMap := make(map[string]string)
|
||||
for _, h := range headers {
|
||||
headerMap[h[0]] = h[1]
|
||||
}
|
||||
|
||||
require.Equal(t, "test-token", headerMap["x-acs-security-token"])
|
||||
require.NotEmpty(t, body)
|
||||
})
|
||||
|
||||
t.Run("empty content", func(t *testing.T) {
|
||||
path, headers, body := GenerateRequestForText(
|
||||
config,
|
||||
"TextModerationPlus",
|
||||
"llm_query_moderation",
|
||||
"",
|
||||
"test-session-id",
|
||||
)
|
||||
|
||||
require.NotEmpty(t, path)
|
||||
require.NotEmpty(t, headers)
|
||||
require.NotEmpty(t, body)
|
||||
bodyStr := string(body)
|
||||
require.Contains(t, bodyStr, "ServiceParameters")
|
||||
decodedBody, err := url.QueryUnescape(bodyStr)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, decodedBody, `"content":""`)
|
||||
})
|
||||
|
||||
t.Run("different check service", func(t *testing.T) {
|
||||
path, headers, body := GenerateRequestForText(
|
||||
config,
|
||||
"TextModerationPlus",
|
||||
"llm_response_moderation",
|
||||
"test content",
|
||||
"test-session-id",
|
||||
)
|
||||
|
||||
require.Contains(t, path, "Service=llm_response_moderation")
|
||||
require.NotEmpty(t, headers)
|
||||
require.NotEmpty(t, body)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateRequestForImage(t *testing.T) {
|
||||
config := cfg.AISecurityConfig{
|
||||
Host: "security.example.com",
|
||||
AK: "test-ak",
|
||||
SK: "test-sk",
|
||||
Token: "",
|
||||
}
|
||||
|
||||
t.Run("image request with URL", func(t *testing.T) {
|
||||
path, headers, body := GenerateRequestForImage(
|
||||
config,
|
||||
"MultiModalGuard",
|
||||
"llm_image_moderation",
|
||||
"https://example.com/image.jpg",
|
||||
"",
|
||||
)
|
||||
|
||||
require.NotEmpty(t, path)
|
||||
require.True(t, strings.HasPrefix(path, "?"))
|
||||
require.Contains(t, path, "Service=llm_image_moderation")
|
||||
|
||||
require.NotEmpty(t, headers)
|
||||
headerMap := make(map[string]string)
|
||||
for _, h := range headers {
|
||||
headerMap[h[0]] = h[1]
|
||||
}
|
||||
|
||||
require.Equal(t, "MultiModalGuard", headerMap["x-acs-action"])
|
||||
require.Equal(t, "2022-03-02", headerMap["x-acs-version"])
|
||||
require.Equal(t, "application/x-www-form-urlencoded", headerMap["content-type"])
|
||||
require.Equal(t, cfg.AliyunUserAgent, headerMap["User-Agent"])
|
||||
require.Contains(t, headerMap, "Authorization")
|
||||
require.Contains(t, headerMap, "x-acs-date")
|
||||
require.Contains(t, headerMap, "x-acs-signature-nonce")
|
||||
require.Contains(t, headerMap, "x-acs-content-sha256")
|
||||
|
||||
require.NotEmpty(t, body)
|
||||
bodyStr := string(body)
|
||||
require.Contains(t, bodyStr, "ServiceParameters")
|
||||
decodedBody, err := url.QueryUnescape(bodyStr)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, decodedBody, "https://example.com/image.jpg")
|
||||
require.Contains(t, decodedBody, cfg.AliyunUserAgent)
|
||||
})
|
||||
|
||||
t.Run("image request with base64", func(t *testing.T) {
|
||||
base64Data := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||
path, headers, body := GenerateRequestForImage(
|
||||
config,
|
||||
"MultiModalGuard",
|
||||
"llm_image_moderation",
|
||||
"",
|
||||
base64Data,
|
||||
)
|
||||
|
||||
require.NotEmpty(t, path)
|
||||
require.NotEmpty(t, headers)
|
||||
require.NotEmpty(t, body)
|
||||
bodyStr := string(body)
|
||||
require.Contains(t, bodyStr, "ImageBase64Str")
|
||||
// Base64 data is URL encoded, decode to check
|
||||
decodedBody, err := url.QueryUnescape(bodyStr)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, decodedBody, base64Data)
|
||||
})
|
||||
|
||||
t.Run("image request with both URL and base64", func(t *testing.T) {
|
||||
path, headers, body := GenerateRequestForImage(
|
||||
config,
|
||||
"MultiModalGuard",
|
||||
"llm_image_moderation",
|
||||
"https://example.com/image.jpg",
|
||||
"base64data",
|
||||
)
|
||||
|
||||
require.NotEmpty(t, path)
|
||||
require.NotEmpty(t, headers)
|
||||
require.NotEmpty(t, body)
|
||||
bodyStr := string(body)
|
||||
require.Contains(t, bodyStr, "ImageBase64Str")
|
||||
decodedBody, err := url.QueryUnescape(bodyStr)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, decodedBody, "https://example.com/image.jpg")
|
||||
require.Contains(t, decodedBody, "base64data")
|
||||
})
|
||||
|
||||
t.Run("image request with security token", func(t *testing.T) {
|
||||
configWithToken := config
|
||||
configWithToken.Token = "test-token"
|
||||
path, headers, body := GenerateRequestForImage(
|
||||
configWithToken,
|
||||
"MultiModalGuard",
|
||||
"llm_image_moderation",
|
||||
"https://example.com/image.jpg",
|
||||
"",
|
||||
)
|
||||
|
||||
require.NotEmpty(t, path)
|
||||
require.NotEmpty(t, headers)
|
||||
headerMap := make(map[string]string)
|
||||
for _, h := range headers {
|
||||
headerMap[h[0]] = h[1]
|
||||
}
|
||||
|
||||
require.Equal(t, "test-token", headerMap["x-acs-security-token"])
|
||||
require.NotEmpty(t, body)
|
||||
})
|
||||
|
||||
t.Run("empty image URL and base64", func(t *testing.T) {
|
||||
path, headers, body := GenerateRequestForImage(
|
||||
config,
|
||||
"MultiModalGuard",
|
||||
"llm_image_moderation",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
|
||||
require.NotEmpty(t, path)
|
||||
require.NotEmpty(t, headers)
|
||||
require.NotEmpty(t, body)
|
||||
bodyStr := string(body)
|
||||
require.Contains(t, bodyStr, "ServiceParameters")
|
||||
decodedBody, err := url.QueryUnescape(bodyStr)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, decodedBody, cfg.AliyunUserAgent)
|
||||
require.NotContains(t, decodedBody, "imageUrls")
|
||||
require.NotContains(t, decodedBody, "ImageBase64Str")
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewRequest(t *testing.T) {
|
||||
// Test newRequest indirectly through GenerateRequestForText
|
||||
// Since it's a private function, we test it through public API
|
||||
t.Run("request structure", func(t *testing.T) {
|
||||
config := cfg.AISecurityConfig{
|
||||
Host: "security.example.com",
|
||||
AK: "test-ak",
|
||||
SK: "test-sk",
|
||||
Token: "",
|
||||
}
|
||||
|
||||
path, headers, _ := GenerateRequestForText(
|
||||
config,
|
||||
"TextModerationPlus",
|
||||
"llm_query_moderation",
|
||||
"test",
|
||||
"session-id",
|
||||
)
|
||||
|
||||
// Verify that newRequest was called correctly by checking headers
|
||||
headerMap := make(map[string]string)
|
||||
for _, h := range headers {
|
||||
headerMap[h[0]] = h[1]
|
||||
}
|
||||
|
||||
// Verify headers set by newRequest
|
||||
require.Equal(t, "TextModerationPlus", headerMap["x-acs-action"])
|
||||
require.Equal(t, "2022-03-02", headerMap["x-acs-version"])
|
||||
require.Contains(t, headerMap, "x-acs-date")
|
||||
require.Contains(t, headerMap, "x-acs-signature-nonce")
|
||||
require.NotEmpty(t, path)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetAuthorization(t *testing.T) {
|
||||
// Test getAuthorization indirectly through GenerateRequestForText
|
||||
// Since it's a private function, we test it through public API
|
||||
t.Run("authorization header format", func(t *testing.T) {
|
||||
config := cfg.AISecurityConfig{
|
||||
Host: "security.example.com",
|
||||
AK: "test-ak",
|
||||
SK: "test-sk",
|
||||
Token: "",
|
||||
}
|
||||
|
||||
_, headers, _ := GenerateRequestForText(
|
||||
config,
|
||||
"TextModerationPlus",
|
||||
"llm_query_moderation",
|
||||
"test content",
|
||||
"test-session-id",
|
||||
)
|
||||
|
||||
headerMap := make(map[string]string)
|
||||
for _, h := range headers {
|
||||
headerMap[h[0]] = h[1]
|
||||
}
|
||||
|
||||
authHeader := headerMap["Authorization"]
|
||||
require.NotEmpty(t, authHeader)
|
||||
require.Contains(t, authHeader, "ACS3-HMAC-SHA256")
|
||||
require.Contains(t, authHeader, "Credential=test-ak")
|
||||
require.Contains(t, authHeader, "SignedHeaders=")
|
||||
require.Contains(t, authHeader, "Signature=")
|
||||
|
||||
// Verify content SHA256 is set
|
||||
require.Contains(t, headerMap, "x-acs-content-sha256")
|
||||
require.Len(t, headerMap["x-acs-content-sha256"], 64) // SHA256 hex string length
|
||||
})
|
||||
|
||||
t.Run("authorization with security token", func(t *testing.T) {
|
||||
config := cfg.AISecurityConfig{
|
||||
Host: "security.example.com",
|
||||
AK: "test-ak",
|
||||
SK: "test-sk",
|
||||
Token: "test-token",
|
||||
}
|
||||
|
||||
_, headers, _ := GenerateRequestForText(
|
||||
config,
|
||||
"TextModerationPlus",
|
||||
"llm_query_moderation",
|
||||
"test content",
|
||||
"test-session-id",
|
||||
)
|
||||
|
||||
headerMap := make(map[string]string)
|
||||
for _, h := range headers {
|
||||
headerMap[h[0]] = h[1]
|
||||
}
|
||||
|
||||
require.Equal(t, "test-token", headerMap["x-acs-security-token"])
|
||||
require.Contains(t, headerMap, "Authorization")
|
||||
})
|
||||
|
||||
t.Run("authorization signature consistency", func(t *testing.T) {
|
||||
config := cfg.AISecurityConfig{
|
||||
Host: "security.example.com",
|
||||
AK: "test-ak",
|
||||
SK: "test-sk",
|
||||
Token: "",
|
||||
}
|
||||
|
||||
// Generate two requests with same content
|
||||
_, headers1, body1 := GenerateRequestForText(
|
||||
config,
|
||||
"TextModerationPlus",
|
||||
"llm_query_moderation",
|
||||
"test content",
|
||||
"test-session-id",
|
||||
)
|
||||
|
||||
_, headers2, body2 := GenerateRequestForText(
|
||||
config,
|
||||
"TextModerationPlus",
|
||||
"llm_query_moderation",
|
||||
"test content",
|
||||
"test-session-id",
|
||||
)
|
||||
|
||||
// Bodies should be the same (except for sessionId which is random)
|
||||
require.NotEmpty(t, body1)
|
||||
require.NotEmpty(t, body2)
|
||||
|
||||
// Headers should have authorization
|
||||
headerMap1 := make(map[string]string)
|
||||
for _, h := range headers1 {
|
||||
headerMap1[h[0]] = h[1]
|
||||
}
|
||||
|
||||
headerMap2 := make(map[string]string)
|
||||
for _, h := range headers2 {
|
||||
headerMap2[h[0]] = h[1]
|
||||
}
|
||||
|
||||
require.Contains(t, headerMap1, "Authorization")
|
||||
require.Contains(t, headerMap2, "Authorization")
|
||||
// Signatures will be different due to nonce and timestamp, but format should be same
|
||||
require.Contains(t, headerMap1["Authorization"], "ACS3-HMAC-SHA256")
|
||||
require.Contains(t, headerMap2["Authorization"], "ACS3-HMAC-SHA256")
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,250 @@
|
||||
package text
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func HandleTextGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||||
ctx.SetContext("end_of_stream_received", false)
|
||||
ctx.SetContext("during_call", false)
|
||||
ctx.SetContext("risk_detected", false)
|
||||
sessionID, _ := utils.GenerateHexID(20)
|
||||
ctx.SetContext("sessionID", sessionID)
|
||||
if strings.Contains(contentType, "text/event-stream") {
|
||||
ctx.NeedPauseStreamingResponse()
|
||||
return types.ActionContinue
|
||||
} else {
|
||||
ctx.BufferResponseBody()
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
}
|
||||
|
||||
func HandleTextGenerationStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
var sessionID string
|
||||
if ctx.GetContext("sessionID") == nil {
|
||||
sessionID, _ = utils.GenerateHexID(20)
|
||||
ctx.SetContext("sessionID", sessionID)
|
||||
} else {
|
||||
sessionID, _ = ctx.GetContext("sessionID").(string)
|
||||
}
|
||||
var bufferQueue [][]byte
|
||||
var singleCall func()
|
||||
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
if ctx.GetContext("end_of_stream_received").(bool) {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
ctx.SetContext("during_call", false)
|
||||
return
|
||||
}
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Error("failed to unmarshal aliyun content security response at response phase")
|
||||
if ctx.GetContext("end_of_stream_received").(bool) {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
ctx.SetContext("during_call", false)
|
||||
return
|
||||
}
|
||||
if !cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = "\n" + response.Data.Advice[0].Answer
|
||||
} else if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
||||
proxywasm.InjectEncodedDataToFilterChain(jsonData, true)
|
||||
return
|
||||
}
|
||||
endStream := ctx.GetContext("end_of_stream_received").(bool) && ctx.BufferQueueSize() == 0
|
||||
proxywasm.InjectEncodedDataToFilterChain(bytes.Join(bufferQueue, []byte("")), endStream)
|
||||
bufferQueue = [][]byte{}
|
||||
if !endStream {
|
||||
ctx.SetContext("during_call", false)
|
||||
singleCall()
|
||||
}
|
||||
}
|
||||
singleCall = func() {
|
||||
if ctx.GetContext("during_call").(bool) {
|
||||
return
|
||||
}
|
||||
if ctx.BufferQueueSize() >= config.BufferLimit || ctx.GetContext("end_of_stream_received").(bool) {
|
||||
var buffer string
|
||||
for ctx.BufferQueueSize() > 0 {
|
||||
front := ctx.PopBuffer()
|
||||
bufferQueue = append(bufferQueue, front)
|
||||
msg := gjson.GetBytes(front, config.ResponseStreamContentJsonPath).String()
|
||||
buffer += msg
|
||||
if len([]rune(buffer)) >= config.BufferLimit {
|
||||
break
|
||||
}
|
||||
}
|
||||
// case 1: streaming body has reasoning_content, part of buffer maybe empty
|
||||
// case 2: streaming body has toolcall result, part of buffer maybe empty
|
||||
log.Debugf("current content piece: %s", buffer)
|
||||
if len(buffer) == 0 {
|
||||
buffer = "[empty content]"
|
||||
}
|
||||
ctx.SetContext("during_call", true)
|
||||
log.Debugf("current content piece: %s", buffer)
|
||||
checkService := config.GetResponseCheckService(consumer)
|
||||
path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, buffer, sessionID)
|
||||
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
if ctx.GetContext("end_of_stream_received").(bool) {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !ctx.GetContext("risk_detected").(bool) {
|
||||
unifiedChunk := wrapper.UnifySSEChunk(data)
|
||||
hasTrailingSeparator := bytes.HasSuffix(unifiedChunk, []byte("\n\n"))
|
||||
trimmedChunk := bytes.TrimSpace(unifiedChunk)
|
||||
chunks := bytes.Split(trimmedChunk, []byte("\n\n"))
|
||||
// Filter out empty chunks
|
||||
nonEmptyChunks := make([][]byte, 0, len(chunks))
|
||||
for _, chunk := range chunks {
|
||||
if len(chunk) > 0 {
|
||||
nonEmptyChunks = append(nonEmptyChunks, chunk)
|
||||
}
|
||||
}
|
||||
// Restore separators
|
||||
for i := range len(nonEmptyChunks) - 1 {
|
||||
nonEmptyChunks[i] = append(nonEmptyChunks[i], []byte("\n\n")...)
|
||||
}
|
||||
if hasTrailingSeparator && len(nonEmptyChunks) > 0 {
|
||||
nonEmptyChunks[len(nonEmptyChunks)-1] = append(nonEmptyChunks[len(nonEmptyChunks)-1], []byte("\n\n")...)
|
||||
}
|
||||
for _, chunk := range nonEmptyChunks {
|
||||
ctx.PushBuffer(chunk)
|
||||
}
|
||||
// for _, chunk := range bytes.Split(bytes.TrimSpace(wrapper.UnifySSEChunk(data)), []byte("\n\n")) {
|
||||
// ctx.PushBuffer([]byte(string(chunk) + "\n\n"))
|
||||
// }
|
||||
ctx.SetContext("end_of_stream_received", endOfStream)
|
||||
if !ctx.GetContext("during_call").(bool) {
|
||||
singleCall()
|
||||
}
|
||||
} else if endOfStream {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
func HandleTextGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
log.Debugf("checking response body...")
|
||||
startTime := time.Now().UnixMilli()
|
||||
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||||
isStreamingResponse := strings.Contains(contentType, "event-stream")
|
||||
var content string
|
||||
if isStreamingResponse {
|
||||
content = utils.ExtractMessageFromStreamingBody(body, config.ResponseStreamContentJsonPath)
|
||||
} else {
|
||||
content = gjson.GetBytes(body, config.ResponseContentJsonPath).String()
|
||||
}
|
||||
log.Debugf("Raw response content is: %s", content)
|
||||
if len(content) == 0 {
|
||||
log.Info("response content is empty. skip")
|
||||
return types.ActionContinue
|
||||
}
|
||||
contentIndex := 0
|
||||
sessionID, _ := utils.GenerateHexID(20)
|
||||
var singleCall func()
|
||||
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
return
|
||||
}
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Error("failed to unmarshal aliyun content security response at response phase")
|
||||
proxywasm.ResumeHttpResponse()
|
||||
return
|
||||
}
|
||||
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
if contentIndex >= len(content) {
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "response pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
proxywasm.ResumeHttpResponse()
|
||||
} else {
|
||||
singleCall()
|
||||
}
|
||||
return
|
||||
}
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
if config.ProtocolOriginal {
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
} else if isStreamingResponse {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
||||
} else {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||
}
|
||||
config.IncrementCounter("ai_sec_response_deny", 1)
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "response deny")
|
||||
if response.Data.Advice != nil {
|
||||
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||
}
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
singleCall = func() {
|
||||
var nextContentIndex int
|
||||
if contentIndex+cfg.LengthLimit >= len(content) {
|
||||
nextContentIndex = len(content)
|
||||
} else {
|
||||
nextContentIndex = contentIndex + cfg.LengthLimit
|
||||
}
|
||||
contentPiece := content[contentIndex:nextContentIndex]
|
||||
contentIndex = nextContentIndex
|
||||
log.Debugf("current content piece: %s", contentPiece)
|
||||
checkService := config.GetResponseCheckService(consumer)
|
||||
path, headers, body := common.GenerateRequestForText(config, config.Action, checkService, contentPiece, sessionID)
|
||||
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
}
|
||||
singleCall()
|
||||
return types.ActionPause
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package multi_modal_guard
|
||||
|
||||
import (
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
common_text "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/image"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/multi_modal_guard/text"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
func OnHttpRequestHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func OnHttpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
switch config.ApiType {
|
||||
case cfg.ApiTextGeneration:
|
||||
return text.HandleTextGenerationRequestBody(ctx, config, body)
|
||||
case cfg.ApiImageGeneration:
|
||||
switch config.ProviderType {
|
||||
case cfg.ProviderOpenAI:
|
||||
return image.HandleOpenAIImageGenerationRequestBody(ctx, config, body)
|
||||
case cfg.ProviderQwen:
|
||||
return image.HandleQwenImageGenerationRequestBody(ctx, config, body)
|
||||
default:
|
||||
log.Errorf("[on request body] image generation api don't support provider: %s", config.ProviderType)
|
||||
return types.ActionContinue
|
||||
}
|
||||
default:
|
||||
log.Errorf("[on request body] multi_modal_guard don't support api: %s", config.ApiType)
|
||||
return types.ActionContinue
|
||||
}
|
||||
}
|
||||
|
||||
func OnHttpResponseHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||
switch config.ApiType {
|
||||
case cfg.ApiTextGeneration:
|
||||
return common_text.HandleTextGenerationResponseHeader(ctx, config)
|
||||
case cfg.ApiImageGeneration:
|
||||
switch config.ProviderType {
|
||||
case cfg.ProviderOpenAI, cfg.ProviderQwen:
|
||||
return image.HandleImageGenerationResponseHeader(ctx, config)
|
||||
default:
|
||||
log.Errorf("[on response header] image generation api don't support provider: %s", config.ProviderType)
|
||||
return types.ActionContinue
|
||||
}
|
||||
default:
|
||||
log.Errorf("[on response header] multi_modal_guard don't support api: %s", config.ApiType)
|
||||
return types.ActionContinue
|
||||
}
|
||||
}
|
||||
|
||||
func OnHttpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte {
|
||||
switch config.ApiType {
|
||||
case cfg.ApiTextGeneration:
|
||||
return common_text.HandleTextGenerationStreamingResponseBody(ctx, config, data, endOfStream)
|
||||
default:
|
||||
log.Errorf("[on streaming response body] multi_modal_guard don't support api: %s", config.ApiType)
|
||||
return data
|
||||
}
|
||||
}
|
||||
|
||||
func OnHttpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
switch config.ApiType {
|
||||
case cfg.ApiTextGeneration:
|
||||
return common_text.HandleTextGenerationResponseBody(ctx, config, body)
|
||||
case cfg.ApiImageGeneration:
|
||||
switch config.ProviderType {
|
||||
case cfg.ProviderOpenAI:
|
||||
return image.HandleOpenAIImageGenerationResponseBody(ctx, config, body)
|
||||
case cfg.ProviderQwen:
|
||||
return image.HandleQwenImageGenerationResponseBody(ctx, config, body)
|
||||
default:
|
||||
log.Errorf("[on response body] image generation api don't support provider: %s", config.ProviderType)
|
||||
return types.ActionContinue
|
||||
}
|
||||
default:
|
||||
log.Errorf("[on response body] multi_modal_guard don't support api: %s", config.ApiType)
|
||||
return types.ActionContinue
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package image
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
type ImageItem struct {
|
||||
Content string
|
||||
Type string // URL or BASE64
|
||||
}
|
||||
|
||||
func HandleImageGenerationResponseHeader(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||||
ctx.SetContext("risk_detected", false)
|
||||
if strings.Contains(contentType, "text/event-stream") {
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue
|
||||
} else {
|
||||
ctx.BufferResponseBody()
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,271 @@
|
||||
package image
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func parseOpenAIRequest(body []byte) (text string, images []ImageItem) {
|
||||
text = gjson.GetBytes(body, "prompt").String()
|
||||
return text, images
|
||||
}
|
||||
|
||||
func parseOpenAIResponse(body []byte) []ImageItem {
|
||||
// qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126
|
||||
result := []ImageItem{}
|
||||
for _, part := range gjson.GetBytes(body, "data").Array() {
|
||||
if url := part.Get("url").String(); url != "" {
|
||||
result = append(result, ImageItem{
|
||||
Content: url,
|
||||
Type: "URL",
|
||||
})
|
||||
}
|
||||
if b64 := part.Get("b64_json").String(); b64 != "" {
|
||||
result = append(result, ImageItem{
|
||||
Content: b64,
|
||||
Type: "BASE64",
|
||||
})
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func HandleOpenAIImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
checkService := config.GetRequestCheckService(consumer)
|
||||
checkImageService := config.GetRequestImageCheckService(consumer)
|
||||
startTime := time.Now().UnixMilli()
|
||||
content, images := parseOpenAIRequest(body)
|
||||
log.Debugf("Raw request content is: %s", content)
|
||||
if len(content) == 0 && len(images) == 0 {
|
||||
log.Info("request content is empty. skip")
|
||||
return types.ActionContinue
|
||||
}
|
||||
contentIndex := 0
|
||||
imageIndex := 0
|
||||
sessionID, _ := utils.GenerateHexID(20)
|
||||
var singleCall func()
|
||||
var singleCallForImage func()
|
||||
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Errorf("%+v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
if contentIndex >= len(content) {
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
if len(images) > 0 && config.CheckRequestImage {
|
||||
singleCallForImage()
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
} else {
|
||||
singleCall()
|
||||
}
|
||||
return
|
||||
}
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
ctx.DontReadResponseBody()
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
if response.Data.Advice != nil {
|
||||
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||
}
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
singleCall = func() {
|
||||
var nextContentIndex int
|
||||
if contentIndex+cfg.LengthLimit >= len(content) {
|
||||
nextContentIndex = len(content)
|
||||
} else {
|
||||
nextContentIndex = contentIndex + cfg.LengthLimit
|
||||
}
|
||||
contentPiece := content[contentIndex:nextContentIndex]
|
||||
contentIndex = nextContentIndex
|
||||
log.Debugf("current content piece: %s", contentPiece)
|
||||
path, headers, body := common.GenerateRequestForText(config, cfg.MultiModalGuard, checkService, contentPiece, sessionID)
|
||||
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
}
|
||||
|
||||
callbackForImage := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
imageIndex += 1
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
if imageIndex < len(images) {
|
||||
singleCallForImage()
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
return
|
||||
}
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Errorf("%+v", err)
|
||||
if imageIndex < len(images) {
|
||||
singleCallForImage()
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
return
|
||||
}
|
||||
endTime := time.Now().UnixMilli()
|
||||
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
if imageIndex >= len(images) {
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
} else {
|
||||
singleCallForImage()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
ctx.DontReadResponseBody()
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
if response.Data.Advice != nil {
|
||||
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||
}
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
singleCallForImage = func() {
|
||||
img := images[imageIndex]
|
||||
imgUrl := ""
|
||||
imgBase64 := ""
|
||||
if img.Type == "BASE64" {
|
||||
imgBase64 = img.Content
|
||||
} else {
|
||||
imgUrl = img.Content
|
||||
}
|
||||
path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, imgBase64)
|
||||
err := config.Client.Post(path, headers, body, callbackForImage, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
}
|
||||
if len(content) > 0 {
|
||||
singleCall()
|
||||
} else {
|
||||
singleCallForImage()
|
||||
}
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func HandleOpenAIImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
log.Debugf("checking response body...")
|
||||
checkImageService := config.GetResponseImageCheckService(consumer)
|
||||
startTime := time.Now().UnixMilli()
|
||||
imgResults := parseOpenAIResponse(body)
|
||||
if len(imgResults) == 0 {
|
||||
return types.ActionContinue
|
||||
}
|
||||
imageIndex := 0
|
||||
var singleCall func()
|
||||
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
imageIndex += 1
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
if imageIndex < len(imgResults) {
|
||||
singleCall()
|
||||
} else {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
return
|
||||
}
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Errorf("%+v", err)
|
||||
if imageIndex < len(imgResults) {
|
||||
singleCall()
|
||||
} else {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
return
|
||||
}
|
||||
endTime := time.Now().UnixMilli()
|
||||
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
if imageIndex >= len(imgResults) {
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
proxywasm.ResumeHttpResponse()
|
||||
} else {
|
||||
singleCall()
|
||||
}
|
||||
return
|
||||
}
|
||||
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte("illegal image"), -1)
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
singleCall = func() {
|
||||
img := imgResults[imageIndex]
|
||||
imgUrl := ""
|
||||
imgBase64 := ""
|
||||
if img.Type == "BASE64" {
|
||||
imgBase64 = img.Content
|
||||
} else {
|
||||
imgUrl = img.Content
|
||||
}
|
||||
path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, imgBase64)
|
||||
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
}
|
||||
singleCall()
|
||||
return types.ActionPause
|
||||
}
|
||||
@@ -0,0 +1,429 @@
|
||||
package image
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func parseImage(body []byte, jsonPath string) *ImageItem {
|
||||
if gjson.GetBytes(body, jsonPath).Exists() {
|
||||
imgContent := gjson.GetBytes(body, jsonPath).String()
|
||||
if strings.HasPrefix(imgContent, "data:image") {
|
||||
return &ImageItem{
|
||||
Content: imgContent,
|
||||
Type: "BASE64",
|
||||
}
|
||||
} else {
|
||||
return &ImageItem{
|
||||
Content: imgContent,
|
||||
Type: "URL",
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseImageArray(body []byte, jsonPath string) []ImageItem {
|
||||
result := []ImageItem{}
|
||||
if gjson.GetBytes(body, jsonPath).Exists() {
|
||||
for _, item := range gjson.GetBytes(body, jsonPath).Array() {
|
||||
imgContent := item.String()
|
||||
if strings.HasPrefix(imgContent, "data:image") {
|
||||
result = append(result, ImageItem{
|
||||
Content: imgContent,
|
||||
Type: "BASE64",
|
||||
})
|
||||
} else {
|
||||
result = append(result, ImageItem{
|
||||
Content: imgContent,
|
||||
Type: "URL",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parseQwenRequest(body []byte) (text string, images []ImageItem) {
|
||||
// qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126
|
||||
images = []ImageItem{}
|
||||
// 文生图/文生图v1/文生图v2
|
||||
if gjson.GetBytes(body, "input.prompt").Exists() {
|
||||
text += gjson.GetBytes(body, "input.prompt").String()
|
||||
}
|
||||
// 图像背景生成
|
||||
if gjson.GetBytes(body, "input.ref_prompt").Exists() {
|
||||
text += gjson.GetBytes(body, "input.ref_prompt").String()
|
||||
}
|
||||
if gjson.GetBytes(body, "input.reference_edge.foreground_edge_prompt").Exists() {
|
||||
for _, item := range gjson.GetBytes(body, "input.reference_edge.foreground_edge_prompt").Array() {
|
||||
text += item.String()
|
||||
}
|
||||
}
|
||||
if gjson.GetBytes(body, "input.reference_edge.background_edge_prompt").Exists() {
|
||||
for _, item := range gjson.GetBytes(body, "input.reference_edge.background_edge_prompt").Array() {
|
||||
text += item.String()
|
||||
}
|
||||
}
|
||||
// 创意文字
|
||||
if gjson.GetBytes(body, "input.text").Exists() {
|
||||
text += gjson.GetBytes(body, "input.text").String()
|
||||
}
|
||||
if gjson.GetBytes(body, "input.negative_prompt").Exists() {
|
||||
text += gjson.GetBytes(body, "input.negative_prompt").String()
|
||||
}
|
||||
// 图像编辑
|
||||
if gjson.GetBytes(body, "input.messages.0.content").Exists() {
|
||||
for _, item := range gjson.GetBytes(body, "input.messages.0.content").Array() {
|
||||
if item.Get("text").Exists() {
|
||||
text += item.Get("text").String()
|
||||
} else if item.Get("image").Exists() {
|
||||
imgContent := item.Get("image").String()
|
||||
if strings.HasPrefix(imgContent, "data:image") {
|
||||
images = append(images, ImageItem{
|
||||
Content: imgContent,
|
||||
Type: "BASE64",
|
||||
})
|
||||
} else {
|
||||
images = append(images, ImageItem{
|
||||
Content: imgContent,
|
||||
Type: "URL",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// image json path
|
||||
imageJsonPath := []string{
|
||||
"input.image_url", // 图像翻译/人像风格重绘/图像画面扩展/人物实例分割/图像擦除补全
|
||||
"input.base_image_url", // 通用图像编辑2.1/图像局部重绘/虚拟模特
|
||||
"input.mask_image_url", // 通用图像编辑2.1/图像局部重绘/虚拟模特
|
||||
"input.sketch_image_url", // 涂鸦作画
|
||||
"input.template_image_url", // 鞋靴模特
|
||||
"input.shoe_image_url", // 鞋靴模特
|
||||
"input.base_image_url", // 图像背景生成
|
||||
"input.ref_image_url", // 图像背景生成
|
||||
"input.mask_url", // 图像擦除补全
|
||||
"input.foreground_url", // 图像擦除补全
|
||||
"input.person_image_url", // AI试衣
|
||||
"input.top_garment_url", // AI试衣
|
||||
"input.bottom_garment_url", // AI试衣
|
||||
"input.coarse_image_url", // AI试衣
|
||||
"input.template_url", // 人物写真生成
|
||||
}
|
||||
for _, jsonPath := range imageJsonPath {
|
||||
tmpImage := parseImage(body, jsonPath)
|
||||
if tmpImage != nil {
|
||||
images = append(images, *tmpImage)
|
||||
}
|
||||
}
|
||||
// image array json path
|
||||
imageArrayJsonPath := []string{
|
||||
"input.images", // 通用图像编辑2.5/人物图像检测
|
||||
"input.reference_edge.foreground_edge", // 图像背景生成
|
||||
"input.reference_edge.background_edge", // 图像背景生成
|
||||
"input.user_urls", // 人物写真生成
|
||||
}
|
||||
for _, jsonPath := range imageArrayJsonPath {
|
||||
tmpImageArray := parseImageArray(body, jsonPath)
|
||||
images = append(images, tmpImageArray...)
|
||||
}
|
||||
return text, images
|
||||
}
|
||||
|
||||
func parseQwenResponse(body []byte) []string {
|
||||
// qwen api: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2975126
|
||||
result := []string{}
|
||||
// 文生图/文生图v1/文生图v2/通用图像编辑2.5/通用图像编辑2.1/涂鸦作画/图像局部重绘/人像风格重绘
|
||||
// 虚拟模特/图像背景生成/人物写真FaceChain/文生图StableDiffusion/文生图FLUX/文字纹理生成API
|
||||
for _, part := range gjson.GetBytes(body, "output.results").Array() {
|
||||
if url := part.Get("url").String(); url != "" {
|
||||
result = append(result, url)
|
||||
}
|
||||
}
|
||||
// 图像编辑
|
||||
for _, part := range gjson.GetBytes(body, "output.choices.0.message.content").Array() {
|
||||
if url := part.Get("image").String(); url != "" {
|
||||
result = append(result, url)
|
||||
}
|
||||
}
|
||||
// 图像翻译/AI试衣OutfitAnyone
|
||||
if url := gjson.GetBytes(body, "output.image_url").String(); url != "" {
|
||||
result = append(result, url)
|
||||
}
|
||||
// 图像画面扩展/(part of)人物实例分割/图像擦除补全
|
||||
if url := gjson.GetBytes(body, "output.output_image_url").String(); url != "" {
|
||||
result = append(result, url)
|
||||
}
|
||||
// 鞋靴模特
|
||||
if url := gjson.GetBytes(body, "output.result_url").String(); url != "" {
|
||||
result = append(result, url)
|
||||
}
|
||||
// 创意海报生成
|
||||
for _, part := range gjson.GetBytes(body, "output.render_urls").Array() {
|
||||
if url := part.String(); url != "" {
|
||||
result = append(result, url)
|
||||
}
|
||||
}
|
||||
for _, part := range gjson.GetBytes(body, "output.bg_urls").Array() {
|
||||
if url := part.String(); url != "" {
|
||||
result = append(result, url)
|
||||
}
|
||||
}
|
||||
// 人物实例分割
|
||||
if url := gjson.GetBytes(body, "output.output_vis_image_url").String(); url != "" {
|
||||
result = append(result, url)
|
||||
}
|
||||
// 文字变形API
|
||||
for _, part := range gjson.GetBytes(body, "output.results").Array() {
|
||||
if url := part.Get("png_url").String(); url != "" {
|
||||
result = append(result, url)
|
||||
}
|
||||
if url := part.Get("svg_url").String(); url != "" {
|
||||
result = append(result, url)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func HandleQwenImageGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
checkService := config.GetRequestCheckService(consumer)
|
||||
checkImageService := config.GetRequestImageCheckService(consumer)
|
||||
startTime := time.Now().UnixMilli()
|
||||
// content := gjson.GetBytes(body, config.RequestContentJsonPath).String()
|
||||
content, images := parseQwenRequest(body)
|
||||
log.Debugf("Raw request content is: %s", content)
|
||||
if len(content) == 0 && len(images) == 0 {
|
||||
log.Info("request content is empty. skip")
|
||||
return types.ActionContinue
|
||||
}
|
||||
contentIndex := 0
|
||||
imageIndex := 0
|
||||
sessionID, _ := utils.GenerateHexID(20)
|
||||
var singleCall func()
|
||||
var singleCallForImage func()
|
||||
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Errorf("%+v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
if contentIndex >= len(content) {
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
if len(images) > 0 && config.CheckRequestImage {
|
||||
singleCallForImage()
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
} else {
|
||||
singleCall()
|
||||
}
|
||||
return
|
||||
}
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
ctx.DontReadResponseBody()
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
if response.Data.Advice != nil {
|
||||
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||
}
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
singleCall = func() {
|
||||
var nextContentIndex int
|
||||
if contentIndex+cfg.LengthLimit >= len(content) {
|
||||
nextContentIndex = len(content)
|
||||
} else {
|
||||
nextContentIndex = contentIndex + cfg.LengthLimit
|
||||
}
|
||||
contentPiece := content[contentIndex:nextContentIndex]
|
||||
contentIndex = nextContentIndex
|
||||
log.Debugf("current content piece: %s", contentPiece)
|
||||
path, headers, body := common.GenerateRequestForText(config, cfg.MultiModalGuard, checkService, contentPiece, sessionID)
|
||||
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
}
|
||||
|
||||
callbackForImage := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
imageIndex += 1
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
if imageIndex < len(images) {
|
||||
singleCallForImage()
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
return
|
||||
}
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Errorf("%+v", err)
|
||||
if imageIndex < len(images) {
|
||||
singleCallForImage()
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
return
|
||||
}
|
||||
endTime := time.Now().UnixMilli()
|
||||
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
if imageIndex >= len(images) {
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
} else {
|
||||
singleCallForImage()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
ctx.DontReadResponseBody()
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
if response.Data.Advice != nil {
|
||||
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||
}
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
singleCallForImage = func() {
|
||||
img := images[imageIndex]
|
||||
imgUrl := ""
|
||||
imgBase64 := ""
|
||||
if img.Type == "BASE64" {
|
||||
imgBase64 = img.Content
|
||||
} else {
|
||||
imgUrl = img.Content
|
||||
}
|
||||
path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, imgBase64)
|
||||
err := config.Client.Post(path, headers, body, callbackForImage, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
}
|
||||
if len(content) > 0 {
|
||||
singleCall()
|
||||
} else {
|
||||
singleCallForImage()
|
||||
}
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func HandleQwenImageGenerationResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
log.Debugf("checking response body...")
|
||||
checkImageService := config.GetResponseImageCheckService(consumer)
|
||||
startTime := time.Now().UnixMilli()
|
||||
imgUrls := parseQwenResponse(body)
|
||||
if len(imgUrls) == 0 {
|
||||
return types.ActionContinue
|
||||
}
|
||||
imageIndex := 0
|
||||
var singleCall func()
|
||||
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
imageIndex += 1
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
if imageIndex < len(imgUrls) {
|
||||
singleCall()
|
||||
} else {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
return
|
||||
}
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Errorf("%+v", err)
|
||||
if imageIndex < len(imgUrls) {
|
||||
singleCall()
|
||||
} else {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
return
|
||||
}
|
||||
endTime := time.Now().UnixMilli()
|
||||
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
if imageIndex >= len(imgUrls) {
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
proxywasm.ResumeHttpResponse()
|
||||
} else {
|
||||
singleCall()
|
||||
}
|
||||
return
|
||||
}
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
proxywasm.SendHttpResponse(403, [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
singleCall = func() {
|
||||
imgUrl := imgUrls[imageIndex]
|
||||
path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, "")
|
||||
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
}
|
||||
singleCall()
|
||||
return types.ActionPause
|
||||
}
|
||||
@@ -0,0 +1,231 @@
|
||||
package text
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type ImageItem struct {
|
||||
Content string
|
||||
Type string // URL or BASE64
|
||||
}
|
||||
|
||||
func parseContent(json gjson.Result) (text string, images []ImageItem) {
|
||||
images = []ImageItem{}
|
||||
if json.IsArray() {
|
||||
for _, item := range json.Array() {
|
||||
switch item.Get("type").String() {
|
||||
case "text":
|
||||
text += item.Get("text").String()
|
||||
case "image_url":
|
||||
imgContent := item.Get("image_url.url").String()
|
||||
if strings.HasPrefix(imgContent, "data:image") {
|
||||
images = append(images, ImageItem{
|
||||
Content: imgContent,
|
||||
Type: "BASE64",
|
||||
})
|
||||
} else {
|
||||
images = append(images, ImageItem{
|
||||
Content: imgContent,
|
||||
Type: "URL",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
text = json.String()
|
||||
}
|
||||
return text, images
|
||||
}
|
||||
|
||||
func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
checkService := config.GetRequestCheckService(consumer)
|
||||
checkImageService := config.GetRequestImageCheckService(consumer)
|
||||
startTime := time.Now().UnixMilli()
|
||||
// content := gjson.GetBytes(body, config.RequestContentJsonPath).String()
|
||||
content, images := parseContent(gjson.GetBytes(body, config.RequestContentJsonPath))
|
||||
log.Debugf("Raw request content is: %s", content)
|
||||
if len(content) == 0 && len(images) == 0 {
|
||||
log.Info("request content is empty. skip")
|
||||
return types.ActionContinue
|
||||
}
|
||||
contentIndex := 0
|
||||
imageIndex := 0
|
||||
sessionID, _ := utils.GenerateHexID(20)
|
||||
var singleCall func()
|
||||
var singleCallForImage func()
|
||||
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Errorf("%+v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
if contentIndex >= len(content) {
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
if len(images) > 0 && config.CheckRequestImage {
|
||||
singleCallForImage()
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
} else {
|
||||
singleCall()
|
||||
}
|
||||
return
|
||||
}
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
if config.ProtocolOriginal {
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
} else if gjson.GetBytes(body, "stream").Bool() {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
||||
} else {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||
}
|
||||
ctx.DontReadResponseBody()
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
if response.Data.Advice != nil {
|
||||
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||
}
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
singleCall = func() {
|
||||
var nextContentIndex int
|
||||
if contentIndex+cfg.LengthLimit >= len(content) {
|
||||
nextContentIndex = len(content)
|
||||
} else {
|
||||
nextContentIndex = contentIndex + cfg.LengthLimit
|
||||
}
|
||||
contentPiece := content[contentIndex:nextContentIndex]
|
||||
contentIndex = nextContentIndex
|
||||
log.Debugf("current content piece: %s", contentPiece)
|
||||
path, headers, body := common.GenerateRequestForText(config, cfg.MultiModalGuard, checkService, contentPiece, sessionID)
|
||||
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
}
|
||||
|
||||
callbackForImage := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
imageIndex += 1
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
if imageIndex < len(images) {
|
||||
singleCallForImage()
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
return
|
||||
}
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Errorf("%+v", err)
|
||||
if imageIndex < len(images) {
|
||||
singleCallForImage()
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
return
|
||||
}
|
||||
endTime := time.Now().UnixMilli()
|
||||
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
if imageIndex >= len(images) {
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
} else {
|
||||
singleCallForImage()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
if config.ProtocolOriginal {
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
} else if gjson.GetBytes(body, "stream").Bool() {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
||||
} else {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||
}
|
||||
ctx.DontReadResponseBody()
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
if response.Data.Advice != nil {
|
||||
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||
}
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
singleCallForImage = func() {
|
||||
img := images[imageIndex]
|
||||
imgUrl := ""
|
||||
imgBase64 := ""
|
||||
if img.Type == "BASE64" {
|
||||
imgBase64 = img.Content
|
||||
} else {
|
||||
imgUrl = img.Content
|
||||
}
|
||||
path, headers, body := common.GenerateRequestForImage(config, cfg.MultiModalGuardForBase64, checkImageService, imgUrl, imgBase64)
|
||||
err := config.Client.Post(path, headers, body, callbackForImage, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
}
|
||||
if len(content) > 0 {
|
||||
singleCall()
|
||||
} else {
|
||||
singleCallForImage()
|
||||
}
|
||||
return types.ActionPause
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package text_moderation_plus
|
||||
|
||||
import (
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
common_text "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common/text"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/text_moderation_plus/text"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
func OnHttpRequestHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func OnHttpRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
return text.HandleTextGenerationRequestBody(ctx, config, body)
|
||||
}
|
||||
|
||||
func OnHttpResponseHeaders(ctx wrapper.HttpContext, config cfg.AISecurityConfig) types.Action {
|
||||
switch config.ApiType {
|
||||
case cfg.ApiTextGeneration:
|
||||
return common_text.HandleTextGenerationResponseHeader(ctx, config)
|
||||
default:
|
||||
log.Errorf("text_moderation_plus don't support api: %s", config.ApiType)
|
||||
return types.ActionContinue
|
||||
}
|
||||
}
|
||||
|
||||
func OnHttpStreamingResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, data []byte, endOfStream bool) []byte {
|
||||
switch config.ApiType {
|
||||
case cfg.ApiTextGeneration:
|
||||
return common_text.HandleTextGenerationStreamingResponseBody(ctx, config, data, endOfStream)
|
||||
default:
|
||||
log.Errorf("text_moderation_plus don't support api: %s", config.ApiType)
|
||||
return data
|
||||
}
|
||||
}
|
||||
|
||||
func OnHttpResponseBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
switch config.ApiType {
|
||||
case cfg.ApiTextGeneration:
|
||||
return common_text.HandleTextGenerationResponseBody(ctx, config, body)
|
||||
default:
|
||||
log.Errorf("text_moderation_plus don't support api: %s", config.ApiType)
|
||||
return types.ActionContinue
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package text
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
cfg "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/config"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/lvwang/common"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-security-guard/utils"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func HandleTextGenerationRequestBody(ctx wrapper.HttpContext, config cfg.AISecurityConfig, body []byte) types.Action {
|
||||
consumer, _ := ctx.GetContext("consumer").(string)
|
||||
startTime := time.Now().UnixMilli()
|
||||
content := gjson.GetBytes(body, config.RequestContentJsonPath).String()
|
||||
log.Debugf("Raw request content is: %s", content)
|
||||
if len(content) == 0 {
|
||||
log.Info("request content is empty. skip")
|
||||
return types.ActionContinue
|
||||
}
|
||||
contentIndex := 0
|
||||
sessionID, _ := utils.GenerateHexID(20)
|
||||
var singleCall func()
|
||||
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
log.Info(string(responseBody))
|
||||
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
var response cfg.Response
|
||||
err := json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
log.Error("failed to unmarshal aliyun content security response at request phase")
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
if cfg.IsRiskLevelAcceptable(config.Action, response.Data, config, consumer) {
|
||||
if contentIndex >= len(content) {
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
} else {
|
||||
singleCall()
|
||||
}
|
||||
return
|
||||
}
|
||||
denyMessage := cfg.DefaultDenyMessage
|
||||
if config.DenyMessage != "" {
|
||||
denyMessage = config.DenyMessage
|
||||
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
}
|
||||
marshalledDenyMessage := wrapper.MarshalStr(denyMessage)
|
||||
if config.ProtocolOriginal {
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
} else if gjson.GetBytes(body, "stream").Bool() {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
||||
} else {
|
||||
randomID := utils.GenerateRandomChatID()
|
||||
jsonData := []byte(fmt.Sprintf(cfg.OpenAIResponseFormat, randomID, marshalledDenyMessage))
|
||||
proxywasm.SendHttpResponse(uint32(config.DenyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||
}
|
||||
ctx.DontReadResponseBody()
|
||||
config.IncrementCounter("ai_sec_request_deny", 1)
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
if response.Data.Advice != nil {
|
||||
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||
}
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
singleCall = func() {
|
||||
var nextContentIndex int
|
||||
if contentIndex+cfg.LengthLimit >= len(content) {
|
||||
nextContentIndex = len(content)
|
||||
} else {
|
||||
nextContentIndex = contentIndex + cfg.LengthLimit
|
||||
}
|
||||
contentPiece := content[contentIndex:nextContentIndex]
|
||||
contentIndex = nextContentIndex
|
||||
checkService := config.GetRequestCheckService(consumer)
|
||||
path, headers, body := common.GenerateRequestForText(config, cfg.TextModerationPlus, checkService, contentPiece, sessionID)
|
||||
err := config.Client.Post(path, headers, body, callback, config.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
}
|
||||
singleCall()
|
||||
return types.ActionPause
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user