mirror of
https://github.com/alibaba/higress.git
synced 2026-02-25 21:21:01 +08:00
Compare commits
34 Commits
v2.0.4
...
v2.0.6-rc.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a570c72504 | ||
|
|
ab1316dfe1 | ||
|
|
e97448b71b | ||
|
|
6820a06a99 | ||
|
|
4733af849d | ||
|
|
1c2330e33b | ||
|
|
61fef0ecf8 | ||
|
|
d29b8d7ca8 | ||
|
|
2501895b66 | ||
|
|
187a7b5408 | ||
|
|
00be491d02 | ||
|
|
2d74c48e8a | ||
|
|
6dc4d43df5 | ||
|
|
2a4e55d46f | ||
|
|
579c986915 | ||
|
|
380717ae3d | ||
|
|
8f3723f554 | ||
|
|
909cc0f088 | ||
|
|
4eaf204737 | ||
|
|
748bcb083a | ||
|
|
39c007d045 | ||
|
|
d74d327b68 | ||
|
|
be27726721 | ||
|
|
34cc1c0632 | ||
|
|
5694475872 | ||
|
|
2f5709a93e | ||
|
|
2a200cdd42 | ||
|
|
ec39d56731 | ||
|
|
8544fa604d | ||
|
|
0ba63e5dd4 | ||
|
|
441408c593 | ||
|
|
be57960c22 | ||
|
|
f32020068a | ||
|
|
1a8fce48f0 |
2
.github/workflows/release-hgctl.yaml
vendored
2
.github/workflows/release-hgctl.yaml
vendored
@@ -58,7 +58,7 @@ jobs:
|
||||
hgctl_${{ env.HGCTL_VERSION }}_darwin_arm64.tar.gz
|
||||
|
||||
release-hgctl-macos-amd64:
|
||||
runs-on: macos-12
|
||||
runs-on: macos-14
|
||||
env:
|
||||
HGCTL_VERSION: ${{github.ref_name}}
|
||||
steps:
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
/envoy @gengleilei @johnlanni
|
||||
/istio @SpecialYang @johnlanni
|
||||
/pkg @SpecialYang @johnlanni @CH3CHO
|
||||
/plugins @johnlanni @WeixinX @CH3CHO
|
||||
/plugins @johnlanni @CH3CHO @rinfx
|
||||
/plugins/wasm-go/extensions/ai-proxy @cr7258 @CH3CHO @rinfx
|
||||
/plugins/wasm-rust @007gzs @jizhuozhi
|
||||
/registry @NameHaibinZhang @2456868764 @johnlanni
|
||||
/test @Xunzhuo @2456868764 @CH3CHO
|
||||
|
||||
@@ -187,8 +187,8 @@ install: pre-install
|
||||
cd helm/higress; helm dependency build
|
||||
helm install higress helm/higress -n higress-system --create-namespace --set 'global.local=true'
|
||||
|
||||
ENVOY_LATEST_IMAGE_TAG ?= 2.0.3
|
||||
ISTIO_LATEST_IMAGE_TAG ?= 8be82d2e4c280c29f4952fbeca1e2a79230b7836
|
||||
ENVOY_LATEST_IMAGE_TAG ?= 958467a353d411ae3f06e03b096bfd342cddb2c6
|
||||
ISTIO_LATEST_IMAGE_TAG ?= 958467a353d411ae3f06e03b096bfd342cddb2c6
|
||||
|
||||
install-dev: pre-install
|
||||
helm install higress helm/core -n higress-system --create-namespace --set 'controller.tag=$(TAG)' --set 'gateway.replicas=1' --set 'pilot.tag=$(ISTIO_LATEST_IMAGE_TAG)' --set 'gateway.tag=$(ENVOY_LATEST_IMAGE_TAG)' --set 'global.local=true'
|
||||
@@ -299,7 +299,7 @@ kube-load-image: $(tools/kind) ## Install the Higress image to a kind cluster us
|
||||
tools/hack/docker-pull-image.sh higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/echo-server 1.3.0
|
||||
tools/hack/docker-pull-image.sh higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/echo-server v1.0
|
||||
tools/hack/docker-pull-image.sh higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/echo-body 1.0.0
|
||||
tools/hack/docker-pull-image.sh openpolicyagent/opa latest
|
||||
tools/hack/docker-pull-image.sh openpolicyagent/opa 0.61.0
|
||||
tools/hack/docker-pull-image.sh curlimages/curl latest
|
||||
tools/hack/docker-pull-image.sh registry.cn-hangzhou.aliyuncs.com/2456868764/httpbin 1.0.2
|
||||
tools/hack/docker-pull-image.sh registry.cn-hangzhou.aliyuncs.com/hinsteny/nacos-standlone-rc3 1.0.0-RC3
|
||||
@@ -312,7 +312,7 @@ kube-load-image: $(tools/kind) ## Install the Higress image to a kind cluster us
|
||||
tools/hack/kind-load-image.sh higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/echo-server 1.3.0
|
||||
tools/hack/kind-load-image.sh higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/echo-server v1.0
|
||||
tools/hack/kind-load-image.sh higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/echo-body 1.0.0
|
||||
tools/hack/kind-load-image.sh openpolicyagent/opa latest
|
||||
tools/hack/kind-load-image.sh openpolicyagent/opa 0.61.0
|
||||
tools/hack/kind-load-image.sh curlimages/curl latest
|
||||
tools/hack/kind-load-image.sh registry.cn-hangzhou.aliyuncs.com/2456868764/httpbin 1.0.2
|
||||
tools/hack/kind-load-image.sh registry.cn-hangzhou.aliyuncs.com/hinsteny/nacos-standlone-rc3 1.0.0-RC3
|
||||
|
||||
@@ -6,9 +6,14 @@
|
||||
</h1>
|
||||
<h4 align="center"> AI Native API Gateway </h4>
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://github.com/alibaba/higress/actions)
|
||||
[](https://www.apache.org/licenses/LICENSE-2.0.html)
|
||||
|
||||
<a href="https://trendshift.io/repositories/10918" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10918" alt="alibaba%2Fhigress | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
</div>
|
||||
|
||||
[**官网**](https://higress.cn/) |
|
||||
[**文档**](https://higress.cn/docs/latest/overview/what-is-higress/) |
|
||||
[**博客**](https://higress.cn/blog/) |
|
||||
@@ -17,6 +22,7 @@
|
||||
[**AI插件**](https://higress.cn/plugin/)
|
||||
|
||||
|
||||
|
||||
<p>
|
||||
<a href="README_EN.md"> English <a/>| 中文 | <a href="README_JP.md"> 日本語 <a/>
|
||||
</p>
|
||||
@@ -180,7 +186,7 @@ K8s 下使用 Helm 部署等其他安装方式可以参考官网 [Quick Start
|
||||
|
||||
### 交流群
|
||||
|
||||

|
||||

|
||||
|
||||
### 技术分享
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
apiVersion: v2
|
||||
appVersion: 2.0.4
|
||||
appVersion: 2.0.6-rc.1
|
||||
description: Helm chart for deploying higress gateways
|
||||
icon: https://higress.io/img/higress_logo_small.png
|
||||
home: http://higress.io/
|
||||
@@ -10,4 +10,4 @@ name: higress-core
|
||||
sources:
|
||||
- http://github.com/alibaba/higress
|
||||
type: application
|
||||
version: 2.0.4
|
||||
version: 2.0.6-rc.1
|
||||
|
||||
@@ -136,6 +136,8 @@ spec:
|
||||
periodSeconds: 3
|
||||
timeoutSeconds: 5
|
||||
env:
|
||||
- name: PILOT_ENABLE_LDS_CACHE
|
||||
valvue: "{{ .Values.global.enableLDSCache }}"
|
||||
- name: PILOT_ENABLE_QUIC_LISTENERS
|
||||
value: "true"
|
||||
- name: VALIDATION_WEBHOOK_CONFIG_NAME
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
{{- if eq .Values.gateway.kind "DaemonSet" -}}
|
||||
{{- $o11y := .Values.global.o11y }}
|
||||
{{- $unprivilegedPortSupported := true }}
|
||||
{{- range $index, $node := (lookup "v1" "Node" "default" "").items }}
|
||||
{{- if eq .Values.gateway.unprivilegedPortSupported nil -}}
|
||||
{{- $unprivilegedPortSupported := true }}
|
||||
{{- range $index, $node := (lookup "v1" "Node" "default" "").items }}
|
||||
{{- $kernelVersion := $node.status.nodeInfo.kernelVersion }}
|
||||
{{- if $kernelVersion }}
|
||||
{{- $kernelVersion = regexFind "^(\\d+\\.\\d+\\.\\d+)" $kernelVersion }}
|
||||
@@ -9,8 +10,9 @@
|
||||
{{- $unprivilegedPortSupported = false }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
{{- end -}}
|
||||
{{- $_ := set .Values.gateway "unprivilegedPortSupported" $unprivilegedPortSupported -}}
|
||||
{{- end -}}
|
||||
{{- $_ := set .Values.gateway "unprivilegedPortSupported" $unprivilegedPortSupported -}}
|
||||
|
||||
apiVersion: apps/v1
|
||||
kind: DaemonSet
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{{- if eq .Values.gateway.kind "Deployment" -}}
|
||||
{{- $unprivilegedPortSupported := true }}
|
||||
{{- range $index, $node := (lookup "v1" "Node" "default" "").items }}
|
||||
{{- if eq .Values.gateway.unprivilegedPortSupported nil -}}
|
||||
{{- $unprivilegedPortSupported := true }}
|
||||
{{- range $index, $node := (lookup "v1" "Node" "default" "").items }}
|
||||
{{- $kernelVersion := $node.status.nodeInfo.kernelVersion }}
|
||||
{{- if $kernelVersion }}
|
||||
{{- $kernelVersion = regexFind "^(\\d+\\.\\d+\\.\\d+)" $kernelVersion }}
|
||||
@@ -8,8 +9,9 @@
|
||||
{{- $unprivilegedPortSupported = false }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
{{- end -}}
|
||||
{{- $_ := set .Values.gateway "unprivilegedPortSupported" $unprivilegedPortSupported -}}
|
||||
{{- end -}}
|
||||
{{- $_ := set .Values.gateway "unprivilegedPortSupported" $unprivilegedPortSupported -}}
|
||||
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
|
||||
@@ -3,7 +3,8 @@ global:
|
||||
enableH3: false
|
||||
enableIPv6: false
|
||||
enableProxyProtocol: false
|
||||
liteMetrics: true
|
||||
enableLDSCache: true
|
||||
liteMetrics: false
|
||||
xdsMaxRecvMsgSize: "104857600"
|
||||
defaultUpstreamConcurrencyThreshold: 10000
|
||||
enableSRDS: true
|
||||
@@ -465,6 +466,7 @@ gateway:
|
||||
# On Kubernetes 1.22+, this only requires the `net.ipv4.ip_unprivileged_port_start` sysctl.
|
||||
securityContext: ~
|
||||
containerSecurityContext: ~
|
||||
unprivilegedPortSupported: ~
|
||||
|
||||
service:
|
||||
# -- Type of service. Set to "None" to disable the service entirely
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
dependencies:
|
||||
- name: higress-core
|
||||
repository: file://../core
|
||||
version: 2.0.4
|
||||
version: 2.0.6-rc.1
|
||||
- name: higress-console
|
||||
repository: https://higress.io/helm-charts/
|
||||
version: 1.4.6
|
||||
digest: sha256:ec570ac7ae8a6de976e7ffafaadae4a33beeabfb4b13debe63e0cfa100e2eb8c
|
||||
generated: "2024-12-06T11:34:04.628976+08:00"
|
||||
version: 2.0.0
|
||||
digest: sha256:66a5261f3d68abf63d2bade50e36ac696bec8aac909442d328fd5d395bf4bc21
|
||||
generated: "2025-01-08T17:14:12.432022+08:00"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
apiVersion: v2
|
||||
appVersion: 2.0.4
|
||||
appVersion: 2.0.6-rc.1
|
||||
description: Helm chart for deploying Higress gateways
|
||||
icon: https://higress.io/img/higress_logo_small.png
|
||||
home: http://higress.io/
|
||||
@@ -12,9 +12,9 @@ sources:
|
||||
dependencies:
|
||||
- name: higress-core
|
||||
repository: "file://../core"
|
||||
version: 2.0.4
|
||||
version: 2.0.6-rc.1
|
||||
- name: higress-console
|
||||
repository: "https://higress.io/helm-charts/"
|
||||
version: 1.4.6
|
||||
version: 2.0.0
|
||||
type: application
|
||||
version: 2.0.4
|
||||
version: 2.0.6-rc.1
|
||||
|
||||
@@ -149,6 +149,7 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| gateway.serviceAccount.name | string | `""` | The name of the service account to use. If not set, the release name is used |
|
||||
| gateway.tag | string | `""` | |
|
||||
| gateway.tolerations | list | `[]` | |
|
||||
| gateway.unprivilegedPortSupported | string | `nil` | |
|
||||
| global.autoscalingv2API | bool | `true` | whether to use autoscaling/v2 template for HPA settings for internal usage only, not to be configured by users. |
|
||||
| global.caAddress | string | `""` | The customized CA address to retrieve certificates for the pods in the cluster. CSR clients such as the Istio Agent and ingress gateways can use this to specify the CA endpoint. If not set explicitly, default to the Istio discovery address. |
|
||||
| global.caName | string | `""` | The name of the CA for workload certificates. For example, when caName=GkeWorkloadCertificate, GKE workload certificates will be used as the certificates for workloads. The default value is "" and when caName="", the CA will be configured by other mechanisms (e.g., environmental variable CA_PROVIDER). |
|
||||
@@ -161,6 +162,7 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| global.enableH3 | bool | `false` | |
|
||||
| global.enableIPv6 | bool | `false` | |
|
||||
| global.enableIstioAPI | bool | `true` | If true, Higress Controller will monitor istio resources as well |
|
||||
| global.enableLDSCache | bool | `true` | |
|
||||
| global.enableProxyProtocol | bool | `false` | |
|
||||
| global.enableSRDS | bool | `true` | |
|
||||
| global.enableStatus | bool | `true` | If true, Higress Controller will update the status field of Ingress resources. When migrating from Nginx Ingress, in order to avoid status field of Ingress objects being overwritten, this parameter needs to be set to false, so Higress won't write the entry IP to the status field of the corresponding Ingress object. |
|
||||
@@ -174,7 +176,7 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| global.istiod | object | `{"enableAnalysis":false}` | Enabled by default in master for maximising testing. |
|
||||
| global.jwtPolicy | string | `"third-party-jwt"` | Configure the policy for validating JWT. Currently, two options are supported: "third-party-jwt" and "first-party-jwt". |
|
||||
| global.kind | bool | `false` | |
|
||||
| global.liteMetrics | bool | `true` | |
|
||||
| global.liteMetrics | bool | `false` | |
|
||||
| global.local | bool | `false` | When deploying to a local cluster (e.g.: kind cluster), set this to true. |
|
||||
| global.logAsJson | bool | `false` | |
|
||||
| global.logging | object | `{"level":"default:info"}` | Comma-separated minimum per-scope logging level of messages to output, in the form of <scope>:<level>,<scope>:<level> The control plane has different scopes depending on component, but can configure default log level across all components If empty, default scope and level will be used as configured in code |
|
||||
|
||||
Submodule istio/istio updated: 0fa834f7b9...81a46c581f
@@ -15,6 +15,7 @@
|
||||
package annotations
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
networking "istio.io/api/networking/v1alpha3"
|
||||
@@ -27,9 +28,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
authTLSSecret = "auth-tls-secret"
|
||||
sslCipher = "ssl-cipher"
|
||||
gatewaySdsCaSuffix = "-cacert"
|
||||
authTLSSecret = "auth-tls-secret"
|
||||
sslCipher = "ssl-cipher"
|
||||
gatewaySdsCaSuffix = "-cacert"
|
||||
annotationMinTLSVersion = "tls-min-protocol-version"
|
||||
annotationMaxTLSVersion = "tls-max-protocol-version"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -41,6 +44,8 @@ type DownstreamTLSConfig struct {
|
||||
CipherSuites []string
|
||||
Mode networking.ServerTLSSettings_TLSmode
|
||||
CASecretName types.NamespacedName
|
||||
MinVersion string
|
||||
MaxVersion string
|
||||
}
|
||||
|
||||
type downstreamTLS struct{}
|
||||
@@ -82,6 +87,14 @@ func (d downstreamTLS) Parse(annotations Annotations, config *Ingress, _ *Global
|
||||
downstreamTLSConfig.CipherSuites = validCipherSuite
|
||||
}
|
||||
|
||||
if minVersion, err := annotations.ParseStringASAP(annotationMinTLSVersion); err == nil {
|
||||
downstreamTLSConfig.MinVersion = minVersion
|
||||
}
|
||||
|
||||
if maxVersion, err := annotations.ParseStringASAP(annotationMaxTLSVersion); err == nil {
|
||||
downstreamTLSConfig.MaxVersion = maxVersion
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -107,11 +120,44 @@ func (d downstreamTLS) ApplyGateway(gateway *networking.Gateway, config *Ingress
|
||||
if len(downstreamTLSConfig.CipherSuites) != 0 {
|
||||
server.Tls.CipherSuites = downstreamTLSConfig.CipherSuites
|
||||
}
|
||||
|
||||
if downstreamTLSConfig.MinVersion != "" {
|
||||
if version, err := convertTLSVersion(downstreamTLSConfig.MinVersion); err != nil {
|
||||
IngressLog.Errorf("Invalid minimum TLS version: %v", err)
|
||||
} else {
|
||||
server.Tls.MinProtocolVersion = version
|
||||
}
|
||||
}
|
||||
|
||||
if downstreamTLSConfig.MaxVersion != "" {
|
||||
if version, err := convertTLSVersion(downstreamTLSConfig.MaxVersion); err != nil {
|
||||
IngressLog.Errorf("Invalid maximum TLS version: %v", err)
|
||||
} else {
|
||||
server.Tls.MaxProtocolVersion = version
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func needDownstreamTLS(annotations Annotations) bool {
|
||||
return annotations.HasASAP(sslCipher) ||
|
||||
annotations.HasASAP(authTLSSecret)
|
||||
annotations.HasASAP(authTLSSecret) ||
|
||||
annotations.HasASAP(annotationMinTLSVersion) ||
|
||||
annotations.HasASAP(annotationMaxTLSVersion)
|
||||
}
|
||||
|
||||
func convertTLSVersion(version string) (networking.ServerTLSSettings_TLSProtocol, error) {
|
||||
switch version {
|
||||
case "TLSv1.0":
|
||||
return networking.ServerTLSSettings_TLSV1_0, nil
|
||||
case "TLSv1.1":
|
||||
return networking.ServerTLSSettings_TLSV1_1, nil
|
||||
case "TLSv1.2":
|
||||
return networking.ServerTLSSettings_TLSV1_2, nil
|
||||
case "TLSv1.3":
|
||||
return networking.ServerTLSSettings_TLSV1_3, nil
|
||||
}
|
||||
return networking.ServerTLSSettings_TLS_AUTO, fmt.Errorf("invalid TLS version: %s. Valid values are: TLSv1.0, TLSv1.1, TLSv1.2, TLSv1.3", version)
|
||||
}
|
||||
|
||||
@@ -26,11 +26,15 @@ var parser = downstreamTLS{}
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
input map[string]string
|
||||
expect *DownstreamTLSConfig
|
||||
}{
|
||||
{},
|
||||
{
|
||||
name: "empty config",
|
||||
},
|
||||
{
|
||||
name: "ssl cipher only",
|
||||
input: map[string]string{
|
||||
buildNginxAnnotationKey(sslCipher): "ECDHE-RSA-AES256-GCM-SHA384:AES128-SHA",
|
||||
},
|
||||
@@ -40,9 +44,24 @@ func TestParse(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with TLS version config",
|
||||
input: map[string]string{
|
||||
buildNginxAnnotationKey(authTLSSecret): "test",
|
||||
buildNginxAnnotationKey(sslCipher): "ECDHE-RSA-AES256-GCM-SHA384:AES128-SHA",
|
||||
buildNginxAnnotationKey(annotationMinTLSVersion): "TLSv1.2",
|
||||
buildNginxAnnotationKey(annotationMaxTLSVersion): "TLSv1.3",
|
||||
},
|
||||
expect: &DownstreamTLSConfig{
|
||||
Mode: networking.ServerTLSSettings_SIMPLE,
|
||||
MinVersion: "TLSv1.2",
|
||||
MaxVersion: "TLSv1.3",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "complete config",
|
||||
input: map[string]string{
|
||||
buildNginxAnnotationKey(authTLSSecret): "test",
|
||||
buildNginxAnnotationKey(sslCipher): "ECDHE-RSA-AES256-GCM-SHA384:AES128-SHA",
|
||||
buildNginxAnnotationKey(annotationMinTLSVersion): "TLSv1.2",
|
||||
buildNginxAnnotationKey(annotationMaxTLSVersion): "TLSv1.3",
|
||||
},
|
||||
expect: &DownstreamTLSConfig{
|
||||
CASecretName: types.NamespacedName{
|
||||
@@ -51,34 +70,79 @@ func TestParse(t *testing.T) {
|
||||
},
|
||||
Mode: networking.ServerTLSSettings_MUTUAL,
|
||||
CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384", "AES128-SHA"},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: map[string]string{
|
||||
buildHigressAnnotationKey(authTLSSecret): "test/foo",
|
||||
DefaultAnnotationsPrefix + "/" + sslCipher: "ECDHE-RSA-AES256-GCM-SHA384:AES128-SHA",
|
||||
},
|
||||
expect: &DownstreamTLSConfig{
|
||||
CASecretName: types.NamespacedName{
|
||||
Namespace: "test",
|
||||
Name: "foo",
|
||||
},
|
||||
Mode: networking.ServerTLSSettings_MUTUAL,
|
||||
CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384", "AES128-SHA"},
|
||||
MinVersion: "TLSv1.2",
|
||||
MaxVersion: "TLSv1.3",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
config := &Ingress{
|
||||
Meta: Meta{
|
||||
Namespace: "foo",
|
||||
},
|
||||
}
|
||||
_ = parser.Parse(testCase.input, config, nil)
|
||||
if !reflect.DeepEqual(testCase.expect, config.DownstreamTLS) {
|
||||
t.Fatalf("Should be equal")
|
||||
err := parser.Parse(tc.input, config, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Parse failed: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(tc.expect, config.DownstreamTLS) {
|
||||
t.Fatalf("Parse result mismatch:\nExpect: %+v\nGot: %+v", tc.expect, config.DownstreamTLS)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertTLSVersion(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
version string
|
||||
expect networking.ServerTLSSettings_TLSProtocol
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "TLS 1.0",
|
||||
version: "TLSv1.0",
|
||||
expect: networking.ServerTLSSettings_TLSV1_0,
|
||||
},
|
||||
{
|
||||
name: "TLS 1.1",
|
||||
version: "TLSv1.1",
|
||||
expect: networking.ServerTLSSettings_TLSV1_1,
|
||||
},
|
||||
{
|
||||
name: "TLS 1.2",
|
||||
version: "TLSv1.2",
|
||||
expect: networking.ServerTLSSettings_TLSV1_2,
|
||||
},
|
||||
{
|
||||
name: "TLS 1.3",
|
||||
version: "TLSv1.3",
|
||||
expect: networking.ServerTLSSettings_TLSV1_3,
|
||||
},
|
||||
{
|
||||
name: "invalid version",
|
||||
version: "invalid",
|
||||
expect: networking.ServerTLSSettings_TLS_AUTO,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result, err := convertTLSVersion(tc.version)
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if result != tc.expect {
|
||||
t.Errorf("Expected %v but got %v", tc.expect, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -86,11 +150,13 @@ func TestParse(t *testing.T) {
|
||||
|
||||
func TestApplyGateway(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
input *networking.Gateway
|
||||
config *Ingress
|
||||
expect *networking.Gateway
|
||||
}{
|
||||
{
|
||||
name: "apply TLS version",
|
||||
input: &networking.Gateway{
|
||||
Servers: []*networking.Server{
|
||||
{
|
||||
@@ -105,7 +171,8 @@ func TestApplyGateway(t *testing.T) {
|
||||
},
|
||||
config: &Ingress{
|
||||
DownstreamTLS: &DownstreamTLSConfig{
|
||||
CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384"},
|
||||
MinVersion: "TLSv1.2",
|
||||
MaxVersion: "TLSv1.3",
|
||||
},
|
||||
},
|
||||
expect: &networking.Gateway{
|
||||
@@ -115,14 +182,16 @@ func TestApplyGateway(t *testing.T) {
|
||||
Protocol: "HTTPS",
|
||||
},
|
||||
Tls: &networking.ServerTLSSettings{
|
||||
Mode: networking.ServerTLSSettings_SIMPLE,
|
||||
CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384"},
|
||||
Mode: networking.ServerTLSSettings_SIMPLE,
|
||||
MinProtocolVersion: networking.ServerTLSSettings_TLSV1_2,
|
||||
MaxProtocolVersion: networking.ServerTLSSettings_TLSV1_3,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "complete config",
|
||||
input: &networking.Gateway{
|
||||
Servers: []*networking.Server{
|
||||
{
|
||||
@@ -144,24 +213,28 @@ func TestApplyGateway(t *testing.T) {
|
||||
},
|
||||
Mode: networking.ServerTLSSettings_MUTUAL,
|
||||
CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384"},
|
||||
MinVersion: "TLSv1.2",
|
||||
MaxVersion: "TLSv1.3",
|
||||
},
|
||||
},
|
||||
expect: &networking.Gateway{
|
||||
Servers: []*networking.Server{
|
||||
{
|
||||
Port: &networking.Port{
|
||||
Protocol: "HTTPS",
|
||||
},
|
||||
{Port: &networking.Port{
|
||||
Protocol: "HTTPS",
|
||||
},
|
||||
Tls: &networking.ServerTLSSettings{
|
||||
CredentialName: "kubernetes-ingress://cluster/foo/bar",
|
||||
Mode: networking.ServerTLSSettings_MUTUAL,
|
||||
CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384"},
|
||||
CredentialName: "kubernetes-ingress://cluster/foo/bar",
|
||||
Mode: networking.ServerTLSSettings_MUTUAL,
|
||||
CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384"},
|
||||
MinProtocolVersion: networking.ServerTLSSettings_TLSV1_2,
|
||||
MaxProtocolVersion: networking.ServerTLSSettings_TLSV1_3,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid TLS version",
|
||||
input: &networking.Gateway{
|
||||
Servers: []*networking.Server{
|
||||
{
|
||||
@@ -169,20 +242,15 @@ func TestApplyGateway(t *testing.T) {
|
||||
Protocol: "HTTPS",
|
||||
},
|
||||
Tls: &networking.ServerTLSSettings{
|
||||
Mode: networking.ServerTLSSettings_SIMPLE,
|
||||
CredentialName: "kubernetes-ingress://cluster/foo/bar",
|
||||
Mode: networking.ServerTLSSettings_SIMPLE,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
config: &Ingress{
|
||||
DownstreamTLS: &DownstreamTLSConfig{
|
||||
CASecretName: types.NamespacedName{
|
||||
Namespace: "foo",
|
||||
Name: "bar-cacert",
|
||||
},
|
||||
Mode: networking.ServerTLSSettings_MUTUAL,
|
||||
CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384"},
|
||||
MinVersion: "invalid",
|
||||
MaxVersion: "invalid",
|
||||
},
|
||||
},
|
||||
expect: &networking.Gateway{
|
||||
@@ -192,48 +260,10 @@ func TestApplyGateway(t *testing.T) {
|
||||
Protocol: "HTTPS",
|
||||
},
|
||||
Tls: &networking.ServerTLSSettings{
|
||||
CredentialName: "kubernetes-ingress://cluster/foo/bar",
|
||||
Mode: networking.ServerTLSSettings_MUTUAL,
|
||||
CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: &networking.Gateway{
|
||||
Servers: []*networking.Server{
|
||||
{
|
||||
Port: &networking.Port{
|
||||
Protocol: "HTTPS",
|
||||
},
|
||||
Tls: &networking.ServerTLSSettings{
|
||||
Mode: networking.ServerTLSSettings_SIMPLE,
|
||||
CredentialName: "kubernetes-ingress://cluster/foo/bar",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
config: &Ingress{
|
||||
DownstreamTLS: &DownstreamTLSConfig{
|
||||
CASecretName: types.NamespacedName{
|
||||
Namespace: "bar",
|
||||
Name: "foo",
|
||||
},
|
||||
Mode: networking.ServerTLSSettings_MUTUAL,
|
||||
CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384"},
|
||||
},
|
||||
},
|
||||
expect: &networking.Gateway{
|
||||
Servers: []*networking.Server{
|
||||
{
|
||||
Port: &networking.Port{
|
||||
Protocol: "HTTPS",
|
||||
},
|
||||
Tls: &networking.ServerTLSSettings{
|
||||
CredentialName: "kubernetes-ingress://cluster/foo/bar",
|
||||
Mode: networking.ServerTLSSettings_SIMPLE,
|
||||
CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384"},
|
||||
Mode: networking.ServerTLSSettings_SIMPLE,
|
||||
// Invalid versions should default to TLS_AUTO
|
||||
MinProtocolVersion: networking.ServerTLSSettings_TLS_AUTO,
|
||||
MaxProtocolVersion: networking.ServerTLSSettings_TLS_AUTO,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -241,11 +271,59 @@ func TestApplyGateway(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
parser.ApplyGateway(testCase.input, testCase.config)
|
||||
if !reflect.DeepEqual(testCase.input, testCase.expect) {
|
||||
t.Fatalf("Should be equal")
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
parser.ApplyGateway(tc.input, tc.config)
|
||||
if !reflect.DeepEqual(tc.input, tc.expect) {
|
||||
t.Fatalf("ApplyGateway result mismatch for %s:\nExpect: %+v\nGot: %+v",
|
||||
tc.name, tc.expect, tc.input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNeedDownstreamTLS(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
annotations map[string]string
|
||||
expect bool
|
||||
}{
|
||||
{
|
||||
name: "empty annotations",
|
||||
annotations: map[string]string{},
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "with ssl cipher",
|
||||
annotations: map[string]string{
|
||||
buildNginxAnnotationKey(sslCipher): "ECDHE-RSA-AES256-GCM-SHA384",
|
||||
},
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "with TLS version",
|
||||
annotations: map[string]string{
|
||||
buildNginxAnnotationKey(annotationMinTLSVersion): "TLSv1.2",
|
||||
},
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "with multiple TLS configs",
|
||||
annotations: map[string]string{
|
||||
buildNginxAnnotationKey(sslCipher): "ECDHE-RSA-AES256-GCM-SHA384",
|
||||
buildNginxAnnotationKey(annotationMinTLSVersion): "TLSv1.2",
|
||||
buildNginxAnnotationKey(annotationMaxTLSVersion): "TLSv1.3",
|
||||
},
|
||||
expect: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := needDownstreamTLS(tc.annotations)
|
||||
if result != tc.expect {
|
||||
t.Errorf("needDownstreamTLS() for %s = %v, want %v",
|
||||
tc.name, result, tc.expect)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package cache
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -62,7 +63,12 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
c.serviceName = json.Get("serviceName").String()
|
||||
c.servicePort = int(json.Get("servicePort").Int())
|
||||
if !json.Get("servicePort").Exists() {
|
||||
c.servicePort = 6379
|
||||
if strings.HasSuffix(c.serviceName, ".static") {
|
||||
// use default logic port which is 80 for static service
|
||||
c.servicePort = 80
|
||||
} else {
|
||||
c.servicePort = 6379
|
||||
}
|
||||
}
|
||||
c.serviceHost = json.Get("serviceHost").String()
|
||||
c.username = json.Get("username").String()
|
||||
|
||||
@@ -74,6 +74,9 @@ func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpC
|
||||
|
||||
ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil)
|
||||
|
||||
ctx.SetUserAttribute("cache_status", "hit")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
|
||||
if stream {
|
||||
proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(c.StreamResponseTemplate, escapedResponse)), -1)
|
||||
} else {
|
||||
|
||||
158
plugins/wasm-go/extensions/ai-cache/embedding/cohere.go
Normal file
158
plugins/wasm-go/extensions/ai-cache/embedding/cohere.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package embedding
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
COHERE_DOMAIN = "api.cohere.com"
|
||||
COHERE_PORT = 443
|
||||
COHERE_DEFAULT_MODEL_NAME = "embed-english-v2.0"
|
||||
COHERE_ENDPOINT = "/v2/embed"
|
||||
)
|
||||
|
||||
type cohereProviderInitializer struct {
|
||||
}
|
||||
|
||||
var cohereConfig cohereProviderConfig
|
||||
|
||||
type cohereProviderConfig struct {
|
||||
// @Title zh-CN 文本特征提取服务 API Key
|
||||
// @Description zh-CN 文本特征提取服务 API Key
|
||||
apiKey string
|
||||
}
|
||||
|
||||
func (c *cohereProviderInitializer) InitConfig(json gjson.Result) {
|
||||
cohereConfig.apiKey = json.Get("apiKey").String()
|
||||
}
|
||||
func (c *cohereProviderInitializer) ValidateConfig() error {
|
||||
if cohereConfig.apiKey == "" {
|
||||
return errors.New("[Cohere] apiKey is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *cohereProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) {
|
||||
if c.servicePort == 0 {
|
||||
c.servicePort = COHERE_PORT
|
||||
}
|
||||
if c.serviceHost == "" {
|
||||
c.serviceHost = COHERE_DOMAIN
|
||||
}
|
||||
return &CohereProvider{
|
||||
config: c,
|
||||
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||
FQDN: c.serviceName,
|
||||
Host: c.serviceHost,
|
||||
Port: int64(c.servicePort),
|
||||
}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type cohereResponse struct {
|
||||
Embeddings cohereEmbeddings `json:"embeddings"`
|
||||
}
|
||||
|
||||
type cohereEmbeddings struct {
|
||||
FloatTypeEebedding [][]float64 `json:"float"`
|
||||
}
|
||||
|
||||
type cohereEmbeddingRequest struct {
|
||||
Texts []string `json:"texts"`
|
||||
Model string `json:"model"`
|
||||
InputType string `json:"input_type"`
|
||||
EmbeddingTypes []string `json:"embedding_types"`
|
||||
}
|
||||
|
||||
type CohereProvider struct {
|
||||
config ProviderConfig
|
||||
client wrapper.HttpClient
|
||||
}
|
||||
|
||||
func (t *CohereProvider) GetProviderType() string {
|
||||
return PROVIDER_TYPE_COHERE
|
||||
}
|
||||
func (t *CohereProvider) constructParameters(texts []string, log wrapper.Log) (string, [][2]string, []byte, error) {
|
||||
model := t.config.model
|
||||
|
||||
if model == "" {
|
||||
model = COHERE_DEFAULT_MODEL_NAME
|
||||
}
|
||||
data := cohereEmbeddingRequest{
|
||||
Texts: texts,
|
||||
Model: model,
|
||||
InputType: "search_document",
|
||||
EmbeddingTypes: []string{"float"},
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
log.Errorf("failed to marshal request data: %v", err)
|
||||
return "", nil, nil, err
|
||||
}
|
||||
|
||||
headers := [][2]string{
|
||||
{"Authorization", fmt.Sprintf("BEARER %s", cohereConfig.apiKey)},
|
||||
{"Content-Type", "application/json"},
|
||||
}
|
||||
|
||||
return COHERE_ENDPOINT, headers, requestBody, nil
|
||||
}
|
||||
|
||||
func (t *CohereProvider) parseTextEmbedding(responseBody []byte) (*cohereResponse, error) {
|
||||
var resp cohereResponse
|
||||
err := json.Unmarshal(responseBody, &resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (t *CohereProvider) GetEmbedding(
|
||||
queryString string,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(emb []float64, err error)) error {
|
||||
embUrl, embHeaders, embRequestBody, err := t.constructParameters([]string{queryString}, log)
|
||||
if err != nil {
|
||||
log.Errorf("failed to construct parameters: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
var resp *cohereResponse
|
||||
err = t.client.Post(embUrl, embHeaders, embRequestBody,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
|
||||
if statusCode != http.StatusOK {
|
||||
err = errors.New("failed to get embedding due to status code: " + strconv.Itoa(statusCode))
|
||||
callback(nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("get embedding response: %d, %s", statusCode, responseBody)
|
||||
|
||||
resp, err = t.parseTextEmbedding(responseBody)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to parse response: %v", err)
|
||||
callback(nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(resp.Embeddings.FloatTypeEebedding) == 0 {
|
||||
err = errors.New("no embedding found in response")
|
||||
callback(nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
callback(resp.Embeddings.FloatTypeEebedding[0], nil)
|
||||
|
||||
}, t.config.timeout)
|
||||
return err
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -17,11 +18,22 @@ const (
|
||||
DASHSCOPE_ENDPOINT = "/api/v1/services/embeddings/text-embedding/text-embedding"
|
||||
)
|
||||
|
||||
var dashScopeConfig dashScopeProviderConfig
|
||||
|
||||
type dashScopeProviderInitializer struct {
|
||||
}
|
||||
type dashScopeProviderConfig struct {
|
||||
// @Title zh-CN 文本特征提取服务 API Key
|
||||
// @Description zh-CN 文本特征提取服务 API Key
|
||||
apiKey string
|
||||
}
|
||||
|
||||
func (d *dashScopeProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||
if config.apiKey == "" {
|
||||
func (c *dashScopeProviderInitializer) InitConfig(json gjson.Result) {
|
||||
dashScopeConfig.apiKey = json.Get("apiKey").String()
|
||||
}
|
||||
|
||||
func (c *dashScopeProviderInitializer) ValidateConfig() error {
|
||||
if dashScopeConfig.apiKey == "" {
|
||||
return errors.New("[DashScope] apiKey is required")
|
||||
}
|
||||
return nil
|
||||
@@ -114,14 +126,14 @@ func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (strin
|
||||
return "", nil, nil, err
|
||||
}
|
||||
|
||||
if d.config.apiKey == "" {
|
||||
if dashScopeConfig.apiKey == "" {
|
||||
err := errors.New("dashScopeKey is empty")
|
||||
log.Errorf("failed to construct headers: %v", err)
|
||||
return "", nil, nil, err
|
||||
}
|
||||
|
||||
headers := [][2]string{
|
||||
{"Authorization", "Bearer " + d.config.apiKey},
|
||||
{"Authorization", "Bearer " + dashScopeConfig.apiKey},
|
||||
{"Content-Type", "application/json"},
|
||||
}
|
||||
|
||||
|
||||
170
plugins/wasm-go/extensions/ai-cache/embedding/openai.go
Normal file
170
plugins/wasm-go/extensions/ai-cache/embedding/openai.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package embedding
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
OPENAI_DOMAIN = "api.openai.com"
|
||||
OPENAI_PORT = 443
|
||||
OPENAI_DEFAULT_MODEL_NAME = "text-embedding-3-small"
|
||||
OPENAI_ENDPOINT = "/v1/embeddings"
|
||||
)
|
||||
|
||||
type openAIProviderInitializer struct {
|
||||
}
|
||||
|
||||
var openAIConfig openAIProviderConfig
|
||||
|
||||
type openAIProviderConfig struct {
|
||||
// @Title zh-CN 文本特征提取服务 API Key
|
||||
// @Description zh-CN 文本特征提取服务 API Key
|
||||
apiKey string
|
||||
}
|
||||
|
||||
func (c *openAIProviderInitializer) InitConfig(json gjson.Result) {
|
||||
openAIConfig.apiKey = json.Get("apiKey").String()
|
||||
}
|
||||
|
||||
func (c *openAIProviderInitializer) ValidateConfig() error {
|
||||
if openAIConfig.apiKey == "" {
|
||||
return errors.New("[openAI] apiKey is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *openAIProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) {
|
||||
if c.servicePort == 0 {
|
||||
c.servicePort = OPENAI_PORT
|
||||
}
|
||||
if c.serviceHost == "" {
|
||||
c.serviceHost = OPENAI_DOMAIN
|
||||
}
|
||||
if c.model == "" {
|
||||
c.model = OPENAI_DEFAULT_MODEL_NAME
|
||||
}
|
||||
return &OpenAIProvider{
|
||||
config: c,
|
||||
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||
FQDN: c.serviceName,
|
||||
Host: c.serviceHost,
|
||||
Port: c.servicePort,
|
||||
}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *OpenAIProvider) GetProviderType() string {
|
||||
return PROVIDER_TYPE_OPENAI
|
||||
}
|
||||
|
||||
type OpenAIResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []OpenAIResult `json:"data"`
|
||||
Model string `json:"model"`
|
||||
Error *OpenAIError `json:"error"`
|
||||
}
|
||||
|
||||
type OpenAIResult struct {
|
||||
Object string `json:"object"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type OpenAIError struct {
|
||||
Message string `json:"prompt_tokens"`
|
||||
Type string `json:"type"`
|
||||
Code string `json:"code"`
|
||||
Param string `json:"param"`
|
||||
}
|
||||
|
||||
type OpenAIEmbeddingRequest struct {
|
||||
Input string `json:"input"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
type OpenAIProvider struct {
|
||||
config ProviderConfig
|
||||
client wrapper.HttpClient
|
||||
}
|
||||
|
||||
func (t *OpenAIProvider) constructParameters(text string, log wrapper.Log) (string, [][2]string, []byte, error) {
|
||||
if text == "" {
|
||||
err := errors.New("queryString text cannot be empty")
|
||||
return "", nil, nil, err
|
||||
}
|
||||
|
||||
data := OpenAIEmbeddingRequest{
|
||||
Input: text,
|
||||
Model: t.config.model,
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
log.Errorf("failed to marshal request data: %v", err)
|
||||
return "", nil, nil, err
|
||||
}
|
||||
|
||||
headers := [][2]string{
|
||||
{"Authorization", fmt.Sprintf("Bearer %s", openAIConfig.apiKey)},
|
||||
{"Content-Type", "application/json"},
|
||||
}
|
||||
|
||||
return OPENAI_ENDPOINT, headers, requestBody, err
|
||||
}
|
||||
|
||||
func (t *OpenAIProvider) parseTextEmbedding(responseBody []byte) (*OpenAIResponse, error) {
|
||||
var resp OpenAIResponse
|
||||
err := json.Unmarshal(responseBody, &resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (t *OpenAIProvider) GetEmbedding(
|
||||
queryString string,
|
||||
ctx wrapper.HttpContext,
|
||||
log wrapper.Log,
|
||||
callback func(emb []float64, err error)) error {
|
||||
embUrl, embHeaders, embRequestBody, err := t.constructParameters(queryString, log)
|
||||
if err != nil {
|
||||
log.Errorf("failed to construct parameters: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
var resp *OpenAIResponse
|
||||
err = t.client.Post(embUrl, embHeaders, embRequestBody,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
|
||||
if statusCode != http.StatusOK {
|
||||
err = fmt.Errorf("failed to get embedding due to status code: %d, resp: %s", statusCode, responseBody)
|
||||
callback(nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err = t.parseTextEmbedding(responseBody)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to parse response: %v", err)
|
||||
callback(nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("get embedding response: %d, %s", statusCode, responseBody)
|
||||
|
||||
if len(resp.Data) == 0 {
|
||||
err = errors.New("no embedding found in response")
|
||||
callback(nil, err)
|
||||
return
|
||||
}
|
||||
|
||||
callback(resp.Data[0].Embedding, nil)
|
||||
|
||||
}, t.config.timeout)
|
||||
return err
|
||||
}
|
||||
@@ -10,10 +10,13 @@ import (
|
||||
const (
|
||||
PROVIDER_TYPE_DASHSCOPE = "dashscope"
|
||||
PROVIDER_TYPE_TEXTIN = "textin"
|
||||
PROVIDER_TYPE_COHERE = "cohere"
|
||||
PROVIDER_TYPE_OPENAI = "openai"
|
||||
)
|
||||
|
||||
type providerInitializer interface {
|
||||
ValidateConfig(ProviderConfig) error
|
||||
InitConfig(json gjson.Result)
|
||||
ValidateConfig() error
|
||||
CreateProvider(ProviderConfig) (Provider, error)
|
||||
}
|
||||
|
||||
@@ -21,6 +24,8 @@ var (
|
||||
providerInitializers = map[string]providerInitializer{
|
||||
PROVIDER_TYPE_DASHSCOPE: &dashScopeProviderInitializer{},
|
||||
PROVIDER_TYPE_TEXTIN: &textInProviderInitializer{},
|
||||
PROVIDER_TYPE_COHERE: &cohereProviderInitializer{},
|
||||
PROVIDER_TYPE_OPENAI: &openAIProviderInitializer{},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -37,35 +42,26 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN 文本特征提取服务端口
|
||||
// @Description zh-CN 文本特征提取服务端口
|
||||
servicePort int64
|
||||
// @Title zh-CN 文本特征提取服务 API Key
|
||||
// @Description zh-CN 文本特征提取服务 API Key
|
||||
apiKey string
|
||||
//@Title zh-CN TextIn x-ti-app-id
|
||||
// @Description zh-CN 仅适用于 TextIn 服务。参考 https://www.textin.com/document/acge_text_embedding
|
||||
textinAppId string
|
||||
//@Title zh-CN TextIn x-ti-secret-code
|
||||
// @Description zh-CN 仅适用于 TextIn 服务。参考 https://www.textin.com/document/acge_text_embedding
|
||||
textinSecretCode string
|
||||
//@Title zh-CN TextIn request matryoshka_dim
|
||||
// @Description zh-CN 仅适用于 TextIn 服务, 指定返回的向量维度。参考 https://www.textin.com/document/acge_text_embedding
|
||||
textinMatryoshkaDim int
|
||||
// @Title zh-CN 文本特征提取服务超时时间
|
||||
// @Description zh-CN 文本特征提取服务超时时间
|
||||
timeout uint32
|
||||
// @Title zh-CN 文本特征提取服务使用的模型
|
||||
// @Description zh-CN 用于文本特征提取的模型名称, 在 DashScope 中默认为 "text-embedding-v1"
|
||||
model string
|
||||
|
||||
initializer providerInitializer
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
c.typ = json.Get("type").String()
|
||||
i, has := providerInitializers[c.typ]
|
||||
if has {
|
||||
i.InitConfig(json)
|
||||
c.initializer = i
|
||||
}
|
||||
c.serviceName = json.Get("serviceName").String()
|
||||
c.serviceHost = json.Get("serviceHost").String()
|
||||
c.servicePort = json.Get("servicePort").Int()
|
||||
c.apiKey = json.Get("apiKey").String()
|
||||
c.textinAppId = json.Get("textinAppId").String()
|
||||
c.textinSecretCode = json.Get("textinSecretCode").String()
|
||||
c.textinMatryoshkaDim = int(json.Get("textinMatryoshkaDim").Int())
|
||||
c.timeout = uint32(json.Get("timeout").Int())
|
||||
c.model = json.Get("model").String()
|
||||
if c.timeout == 0 {
|
||||
@@ -80,11 +76,10 @@ func (c *ProviderConfig) Validate() error {
|
||||
if c.typ == "" {
|
||||
return errors.New("embedding service type is required")
|
||||
}
|
||||
initializer, has := providerInitializers[c.typ]
|
||||
if !has {
|
||||
if c.initializer == nil {
|
||||
return errors.New("unknown embedding service provider type: " + c.typ)
|
||||
}
|
||||
if err := initializer.ValidateConfig(*c); err != nil {
|
||||
if err := c.initializer.ValidateConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -20,14 +21,34 @@ const (
|
||||
type textInProviderInitializer struct {
|
||||
}
|
||||
|
||||
func (t *textInProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||
if config.textinAppId == "" {
|
||||
return errors.New("embedding service TextIn App ID is required")
|
||||
var textInConfig textInProviderConfig
|
||||
|
||||
type textInProviderConfig struct {
|
||||
//@Title zh-CN TextIn x-ti-app-id
|
||||
// @Description zh-CN 仅适用于 TextIn 服务。参考 https://www.textin.com/document/acge_text_embedding
|
||||
textinAppId string
|
||||
//@Title zh-CN TextIn x-ti-secret-code
|
||||
// @Description zh-CN 仅适用于 TextIn 服务。参考 https://www.textin.com/document/acge_text_embedding
|
||||
textinSecretCode string
|
||||
//@Title zh-CN TextIn request matryoshka_dim
|
||||
// @Description zh-CN 仅适用于 TextIn 服务, 指定返回的向量维度。参考 https://www.textin.com/document/acge_text_embedding
|
||||
textinMatryoshkaDim int
|
||||
}
|
||||
|
||||
func (c *textInProviderInitializer) InitConfig(json gjson.Result) {
|
||||
textInConfig.textinAppId = json.Get("textinAppId").String()
|
||||
textInConfig.textinSecretCode = json.Get("textinSecretCode").String()
|
||||
textInConfig.textinMatryoshkaDim = int(json.Get("textinMatryoshkaDim").Int())
|
||||
}
|
||||
|
||||
func (c *textInProviderInitializer) ValidateConfig() error {
|
||||
if textInConfig.textinAppId == "" {
|
||||
return errors.New("textinAppId is required")
|
||||
}
|
||||
if config.textinSecretCode == "" {
|
||||
return errors.New("embedding service TextIn Secret Code is required")
|
||||
if textInConfig.textinSecretCode == "" {
|
||||
return errors.New("textinSecretCode is required")
|
||||
}
|
||||
if config.textinMatryoshkaDim == 0 {
|
||||
if textInConfig.textinMatryoshkaDim == 0 {
|
||||
return errors.New("embedding service TextIn Matryoshka Dim is required")
|
||||
}
|
||||
return nil
|
||||
@@ -62,7 +83,7 @@ type TextInResponse struct {
|
||||
}
|
||||
|
||||
type TextInResult struct {
|
||||
Embeddings [][]float64 `json:"embedding"`
|
||||
Embeddings [][]float64 `json:"embedding"`
|
||||
MatryoshkaDim int `json:"matryoshka_dim"`
|
||||
}
|
||||
|
||||
@@ -80,7 +101,7 @@ func (t *TIProvider) constructParameters(texts []string, log wrapper.Log) (strin
|
||||
|
||||
data := TextInEmbeddingRequest{
|
||||
Input: texts,
|
||||
MatryoshkaDim: t.config.textinMatryoshkaDim,
|
||||
MatryoshkaDim: textInConfig.textinMatryoshkaDim,
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(data)
|
||||
@@ -89,20 +110,20 @@ func (t *TIProvider) constructParameters(texts []string, log wrapper.Log) (strin
|
||||
return "", nil, nil, err
|
||||
}
|
||||
|
||||
if t.config.textinAppId == "" {
|
||||
if textInConfig.textinAppId == "" {
|
||||
err := errors.New("textinAppId is empty")
|
||||
log.Errorf("failed to construct headers: %v", err)
|
||||
return "", nil, nil, err
|
||||
}
|
||||
if t.config.textinSecretCode == "" {
|
||||
if textInConfig.textinSecretCode == "" {
|
||||
err := errors.New("textinSecretCode is empty")
|
||||
log.Errorf("failed to construct headers: %v", err)
|
||||
return "", nil, nil, err
|
||||
}
|
||||
|
||||
headers := [][2]string{
|
||||
{"x-ti-app-id", t.config.textinAppId},
|
||||
{"x-ti-secret-code", t.config.textinSecretCode},
|
||||
{"x-ti-app-id", textInConfig.textinAppId},
|
||||
{"x-ti-secret-code", textInConfig.textinSecretCode},
|
||||
{"Content-Type", "application/json"},
|
||||
}
|
||||
|
||||
|
||||
@@ -8,14 +8,14 @@ replace github.com/alibaba/higress/plugins/wasm-go => ../..
|
||||
|
||||
require (
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.4.2
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0
|
||||
github.com/tidwall/gjson v1.17.3
|
||||
github.com/tidwall/resp v0.1.1
|
||||
// github.com/weaviate/weaviate-go-client/v4 v4.15.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
|
||||
github.com/magefile/mage v1.14.0 // indirect
|
||||
github.com/stretchr/testify v1.9.0 // indirect
|
||||
|
||||
@@ -3,8 +3,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
|
||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
|
||||
@@ -22,6 +22,8 @@ const (
|
||||
STREAM_CONTEXT_KEY = "stream"
|
||||
SKIP_CACHE_HEADER = "x-higress-skip-ai-cache"
|
||||
ERROR_PARTIAL_MESSAGE_KEY = "errorPartialMessage"
|
||||
|
||||
DEFAULT_MAX_BODY_BYTES uint32 = 10 * 1024 * 1024
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -69,6 +71,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log wr
|
||||
ctx.DontReadRequestBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
ctx.SetRequestBodyBufferLimit(DEFAULT_MAX_BODY_BYTES)
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||||
// The request has a body and requires delaying the header transmission until a cache miss occurs,
|
||||
// at which point the header should be sent.
|
||||
@@ -128,12 +131,20 @@ func onHttpRequestBody(ctx wrapper.HttpContext, c config.PluginConfig, body []by
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) types.Action {
|
||||
skipCache := ctx.GetContext(SKIP_CACHE_HEADER)
|
||||
if skipCache != nil {
|
||||
ctx.SetUserAttribute("cache_status", "skip")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
if ctx.GetContext(CACHE_KEY_CONTEXT_KEY) != nil {
|
||||
ctx.SetUserAttribute("cache_status", "miss")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||||
if strings.Contains(contentType, "text/event-stream") {
|
||||
ctx.SetContext(STREAM_CONTEXT_KEY, struct{}{})
|
||||
} else {
|
||||
ctx.SetResponseBodyBufferLimit(DEFAULT_MAX_BODY_BYTES)
|
||||
}
|
||||
|
||||
if ctx.GetContext(ERROR_PARTIAL_MESSAGE_KEY) != nil {
|
||||
@@ -158,22 +169,26 @@ func onHttpResponseBody(ctx wrapper.HttpContext, c config.PluginConfig, chunk []
|
||||
return chunk
|
||||
}
|
||||
|
||||
stream := ctx.GetContext(STREAM_CONTEXT_KEY)
|
||||
var err error
|
||||
if !isLastChunk {
|
||||
if err := handleNonLastChunk(ctx, c, chunk, log); err != nil {
|
||||
if stream == nil {
|
||||
err = handleNonStreamChunk(ctx, c, chunk, log)
|
||||
} else {
|
||||
err = handleStreamChunk(ctx, c, unifySSEChunk(chunk), log)
|
||||
}
|
||||
if err != nil {
|
||||
log.Errorf("[onHttpResponseBody] handle non last chunk failed, error: %v", err)
|
||||
// Set an empty struct in the context to indicate an error in processing the partial message
|
||||
ctx.SetContext(ERROR_PARTIAL_MESSAGE_KEY, struct{}{})
|
||||
}
|
||||
return chunk
|
||||
}
|
||||
|
||||
stream := ctx.GetContext(STREAM_CONTEXT_KEY)
|
||||
var value string
|
||||
var err error
|
||||
if stream == nil {
|
||||
value, err = processNonStreamLastChunk(ctx, c, chunk, log)
|
||||
} else {
|
||||
value, err = processStreamLastChunk(ctx, c, chunk, log)
|
||||
value, err = processStreamLastChunk(ctx, c, unifySSEChunk(chunk), log)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -9,17 +10,6 @@ import (
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func handleNonLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error {
|
||||
stream := ctx.GetContext(STREAM_CONTEXT_KEY)
|
||||
err := error(nil)
|
||||
if stream == nil {
|
||||
err = handleNonStreamChunk(ctx, c, chunk, log)
|
||||
} else {
|
||||
err = handleStreamChunk(ctx, c, chunk, log)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func handleNonStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error {
|
||||
tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY)
|
||||
if tempContentI == nil {
|
||||
@@ -32,6 +22,12 @@ func handleNonStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk
|
||||
return nil
|
||||
}
|
||||
|
||||
func unifySSEChunk(data []byte) []byte {
|
||||
data = bytes.ReplaceAll(data, []byte("\r\n"), []byte("\n"))
|
||||
data = bytes.ReplaceAll(data, []byte("\r"), []byte("\n"))
|
||||
return data
|
||||
}
|
||||
|
||||
func handleStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error {
|
||||
var partialMessage []byte
|
||||
partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY)
|
||||
@@ -101,55 +97,54 @@ func processStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chun
|
||||
}
|
||||
|
||||
func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessage string, log wrapper.Log) (string, error) {
|
||||
subMessages := strings.Split(sseMessage, "\n")
|
||||
var message string
|
||||
for _, msg := range subMessages {
|
||||
if strings.HasPrefix(msg, "data:") {
|
||||
message = msg
|
||||
break
|
||||
content := ""
|
||||
for _, chunk := range strings.Split(sseMessage, "\n\n") {
|
||||
log.Debugf("single sse message: %s", chunk)
|
||||
subMessages := strings.Split(chunk, "\n")
|
||||
var message string
|
||||
for _, msg := range subMessages {
|
||||
if strings.HasPrefix(msg, "data:") {
|
||||
message = msg
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(message) < 6 {
|
||||
return "", fmt.Errorf("[processSSEMessage] invalid message: %s", message)
|
||||
}
|
||||
|
||||
// skip the prefix "data:"
|
||||
bodyJson := message[5:]
|
||||
|
||||
if strings.TrimSpace(bodyJson) == "[DONE]" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Extract values from JSON fields
|
||||
responseBody := gjson.Get(bodyJson, c.CacheStreamValueFrom)
|
||||
toolCalls := gjson.Get(bodyJson, c.CacheToolCallsFrom)
|
||||
|
||||
if toolCalls.Exists() {
|
||||
// TODO: Temporarily store the tool_calls value in the context for processing
|
||||
ctx.SetContext(TOOL_CALLS_CONTEXT_KEY, toolCalls.String())
|
||||
}
|
||||
|
||||
// Check if the ResponseBody field exists
|
||||
if !responseBody.Exists() {
|
||||
if ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) != nil {
|
||||
log.Debugf("[processSSEMessage] unable to extract content from message; cache content is not nil: %s", message)
|
||||
return "", nil
|
||||
if len(message) < 6 {
|
||||
return content, fmt.Errorf("[processSSEMessage] invalid message: %s", message)
|
||||
}
|
||||
return "", fmt.Errorf("[processSSEMessage] unable to extract content from message; cache content is nil: %s", message)
|
||||
} else {
|
||||
tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY)
|
||||
|
||||
// If there is no content in the cache, initialize and set the content
|
||||
if tempContentI == nil {
|
||||
content := responseBody.String()
|
||||
ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content)
|
||||
// skip the prefix "data:"
|
||||
bodyJson := message[5:]
|
||||
|
||||
if strings.TrimSpace(bodyJson) == "[DONE]" {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// Update the content in the cache
|
||||
appendMsg := responseBody.String()
|
||||
content := tempContentI.(string) + appendMsg
|
||||
ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content)
|
||||
return content, nil
|
||||
// Extract values from JSON fields
|
||||
responseBody := gjson.Get(bodyJson, c.CacheStreamValueFrom)
|
||||
toolCalls := gjson.Get(bodyJson, c.CacheToolCallsFrom)
|
||||
|
||||
if toolCalls.Exists() {
|
||||
// TODO: Temporarily store the tool_calls value in the context for processing
|
||||
ctx.SetContext(TOOL_CALLS_CONTEXT_KEY, toolCalls.String())
|
||||
}
|
||||
|
||||
// Check if the ResponseBody field exists
|
||||
if !responseBody.Exists() {
|
||||
if ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) != nil {
|
||||
log.Debugf("[processSSEMessage] unable to extract content from message; cache content is not nil: %s", message)
|
||||
return content, nil
|
||||
}
|
||||
return content, fmt.Errorf("[processSSEMessage] unable to extract content from message; cache content is nil: %s", message)
|
||||
} else {
|
||||
content += responseBody.String()
|
||||
}
|
||||
}
|
||||
tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY)
|
||||
// If there is no content in the cache, initialize and set the content
|
||||
if tempContentI == nil {
|
||||
ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content)
|
||||
} else {
|
||||
ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, tempContentI.(string)+content)
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
@@ -3,15 +3,13 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
|
||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
|
||||
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94=
|
||||
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
|
||||
@@ -194,6 +194,12 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte
|
||||
ctx.SetContext(StreamContextKey, struct{}{})
|
||||
}
|
||||
identityKey := ctx.GetStringContext(IdentityKey, "")
|
||||
question := TrimQuote(bodyJson.Get(config.QuestionFrom.RequestBody).String())
|
||||
if question == "" {
|
||||
log.Debug("parse question from request body failed")
|
||||
return types.ActionContinue
|
||||
}
|
||||
ctx.SetContext(QuestionContextKey, question)
|
||||
err := config.redisClient.Get(config.CacheKeyPrefix+identityKey, func(response resp.Value) {
|
||||
if err := response.Error(); err != nil {
|
||||
log.Errorf("redis get failed, err:%v", err)
|
||||
@@ -230,13 +236,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte
|
||||
_ = proxywasm.SendHttpResponseWithDetail(200, "OK", [][2]string{{"content-type", "application/json; charset=utf-8"}}, res, -1)
|
||||
return
|
||||
}
|
||||
question := TrimQuote(bodyJson.Get(config.QuestionFrom.RequestBody).String())
|
||||
if question == "" {
|
||||
log.Debug("parse question from request body failed")
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
ctx.SetContext(QuestionContextKey, question)
|
||||
fillHistoryCnt := getIntQueryParameter("fill_history_cnt", path, config.FillHistoryCnt) * 2
|
||||
currJson := bodyJson.Get("messages").String()
|
||||
var currMessage []ChatHistory
|
||||
@@ -317,38 +316,39 @@ func getIntQueryParameter(name string, path string, defaultValue int) int {
|
||||
}
|
||||
|
||||
func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage string, log wrapper.Log) string {
|
||||
subMessages := strings.Split(sseMessage, "\n")
|
||||
var message string
|
||||
for _, msg := range subMessages {
|
||||
if strings.HasPrefix(msg, "data:") {
|
||||
message = msg
|
||||
break
|
||||
content := ""
|
||||
for _, chunk := range strings.Split(sseMessage, "\n\n") {
|
||||
subMessages := strings.Split(chunk, "\n")
|
||||
var message string
|
||||
for _, msg := range subMessages {
|
||||
if strings.HasPrefix(msg, "data:") {
|
||||
message = msg
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(message) < 6 {
|
||||
log.Errorf("invalid message:%s", message)
|
||||
return ""
|
||||
}
|
||||
// skip the prefix "data:"
|
||||
bodyJson := message[5:]
|
||||
if gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Exists() {
|
||||
tempContentI := ctx.GetContext(AnswerContentContextKey)
|
||||
if tempContentI == nil {
|
||||
content := TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw)
|
||||
ctx.SetContext(AnswerContentContextKey, content)
|
||||
if len(message) < 6 {
|
||||
log.Errorf("invalid message:%s", message)
|
||||
return content
|
||||
}
|
||||
append := TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw)
|
||||
content := tempContentI.(string) + append
|
||||
ctx.SetContext(AnswerContentContextKey, content)
|
||||
return content
|
||||
} else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() {
|
||||
// TODO: compatible with other providers
|
||||
ctx.SetContext(ToolCallsContextKey, struct{}{})
|
||||
return ""
|
||||
// skip the prefix "data:"
|
||||
bodyJson := message[5:]
|
||||
if gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Exists() {
|
||||
tempContentI := ctx.GetContext(AnswerContentContextKey)
|
||||
if tempContentI == nil {
|
||||
content = TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw)
|
||||
ctx.SetContext(AnswerContentContextKey, content)
|
||||
} else {
|
||||
append := TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw)
|
||||
content = tempContentI.(string) + append
|
||||
ctx.SetContext(AnswerContentContextKey, content)
|
||||
}
|
||||
} else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() {
|
||||
// TODO: compatible with other providers
|
||||
ctx.SetContext(ToolCallsContextKey, struct{}{})
|
||||
}
|
||||
log.Debugf("unknown message:%s", bodyJson)
|
||||
}
|
||||
log.Debugf("unknown message:%s", bodyJson)
|
||||
return ""
|
||||
return content
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
|
||||
|
||||
@@ -41,6 +41,7 @@ description: AI 代理插件配置参考
|
||||
| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
|
||||
| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
|
||||
| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 |
|
||||
| `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 |
|
||||
|
||||
`context`的配置字段说明如下:
|
||||
|
||||
@@ -78,14 +79,22 @@ custom-setting会遵循如下表格,根据`name`和协议来替换对应的字
|
||||
|
||||
`failover` 的配置字段说明如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|------------------|--------|------|-------|-----------------------------|
|
||||
| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 |
|
||||
| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) |
|
||||
| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) |
|
||||
| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 |
|
||||
| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 |
|
||||
| healthCheckModel | string | 必填 | | 健康检测使用的模型 |
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|------------------|--------|-----------------|-------|-----------------------------|
|
||||
| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 |
|
||||
| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) |
|
||||
| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) |
|
||||
| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 |
|
||||
| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 |
|
||||
| healthCheckModel | string | 启用 failover 时必填 | | 健康检测使用的模型 |
|
||||
|
||||
`retryOnFailure` 的配置字段说明如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|------------------|--------|-----------------|-------|-------------|
|
||||
| enabled | bool | 非必填 | false | 是否启用失败请求重试 |
|
||||
| maxRetries | int | 非必填 | 1 | 最大重试次数 |
|
||||
| retryTimeout | int | 非必填 | 30000 | 重试超时时间,单位毫秒 |
|
||||
|
||||
### 提供商特有配置
|
||||
|
||||
@@ -174,9 +183,10 @@ Mistral 所对应的 `type` 为 `mistral`。它并无特有的配置字段。
|
||||
|
||||
MiniMax所对应的 `type` 为 `minimax`。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ---------------- | -------- | ------------------------------------------------------------ | ------ | ------------------------------------------------------------ |
|
||||
| `minimaxGroupId` | string | 当使用`abab6.5-chat`, `abab6.5s-chat`, `abab5.5s-chat`, `abab5.5-chat`四种模型时必填 | - | 当使用`abab6.5-chat`, `abab6.5s-chat`, `abab5.5s-chat`, `abab5.5-chat`四种模型时会使用ChatCompletion Pro,需要设置groupID |
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ---------------- | -------- | ------------------------------ | ------ |----------------------------------------------------------------|
|
||||
| `minimaxApiType` | string | v2 和 pro 中选填一项 | v2 | v2 代表 ChatCompletion v2 API,pro 代表 ChatCompletion Pro API |
|
||||
| `minimaxGroupId` | string | `minimaxApiType` 为 pro 时必填 | - | `minimaxApiType` 为 pro 时使用 ChatCompletion Pro API,需要设置 groupID |
|
||||
|
||||
#### Anthropic Claude
|
||||
|
||||
@@ -242,6 +252,9 @@ DeepL 所对应的 `type` 为 `deepl`。它特有的配置字段如下:
|
||||
|
||||
Cohere 所对应的 `type` 为 `cohere`。它并无特有的配置字段。
|
||||
|
||||
#### Together-AI
|
||||
Together-AI 所对应的 `type` 为 `together-ai`。它并无特有的配置字段。
|
||||
|
||||
## 用法示例
|
||||
|
||||
### 使用 OpenAI 协议代理 Azure OpenAI 服务
|
||||
@@ -1000,17 +1013,16 @@ provider:
|
||||
apiTokens:
|
||||
- "YOUR_MINIMAX_API_TOKEN"
|
||||
modelMapping:
|
||||
"gpt-3": "abab6.5g-chat"
|
||||
"gpt-4": "abab6.5-chat"
|
||||
"*": "abab6.5g-chat"
|
||||
minimaxGroupId: "YOUR_MINIMAX_GROUP_ID"
|
||||
"gpt-3": "abab6.5s-chat"
|
||||
"gpt-4": "abab6.5g-chat"
|
||||
"*": "abab6.5t-chat"
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4-turbo",
|
||||
"model": "gpt-3",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
@@ -1025,27 +1037,33 @@ provider:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "02b2251f8c6c09d68c1743f07c72afd7",
|
||||
"id": "03ac4fcfe1c6cc9c6a60f9d12046e2b4",
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"message": {
|
||||
"content": "你好!我是MM智能助理,一款由MiniMax自研的大型语言模型。我可以帮助你解答问题,提供信息,进行对话等。有什么可以帮助你的吗?",
|
||||
"role": "assistant"
|
||||
"content": "你好,我是一个由MiniMax公司研发的大型语言模型,名为MM智能助理。我可以帮助回答问题、提供信息、进行对话和执行多种语言处理任务。如果你有任何问题或需要帮助,请随时告诉我!",
|
||||
"role": "assistant",
|
||||
"name": "MM智能助理",
|
||||
"audio_content": ""
|
||||
}
|
||||
}
|
||||
],
|
||||
"created": 1717760544,
|
||||
"created": 1734155471,
|
||||
"model": "abab6.5s-chat",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"total_tokens": 106
|
||||
"total_tokens": 116,
|
||||
"total_characters": 0,
|
||||
"prompt_tokens": 70,
|
||||
"completion_tokens": 46
|
||||
},
|
||||
"input_sensitive": false,
|
||||
"output_sensitive": false,
|
||||
"input_sensitive_type": 0,
|
||||
"output_sensitive_type": 0,
|
||||
"output_sensitive_int": 0,
|
||||
"base_resp": {
|
||||
"status_code": 0,
|
||||
"status_msg": ""
|
||||
@@ -1490,6 +1508,61 @@ provider:
|
||||
}
|
||||
```
|
||||
|
||||
### 使用 OpenAI 协议代理 Together-AI 服务
|
||||
|
||||
**配置信息**
|
||||
```yaml
|
||||
provider:
|
||||
type: together-ai
|
||||
apiTokens:
|
||||
- "YOUR_TOGETHER_AI_API_TOKEN"
|
||||
modelMapping:
|
||||
"*": "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
```json
|
||||
{
|
||||
"model": "Qwen/Qwen2.5-72B-Instruct-Turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Who are you?"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**响应示例**
|
||||
```json
|
||||
{
|
||||
"id": "8f5809d54b73efac",
|
||||
"object": "chat.completion",
|
||||
"created": 1734785851,
|
||||
"model": "Qwen/Qwen2.5-72B-Instruct-Turbo",
|
||||
"prompt": [],
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "eos",
|
||||
"seed": 12830868308626506000,
|
||||
"logprobs": null,
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I am Qwen, a large language model created by Alibaba Cloud. I am designed to assist users in generating various types of text, such as articles, stories, poems, and more, as well as answering questions and providing information on a wide range of topics. How can I assist you today?",
|
||||
"tool_calls": []
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 33,
|
||||
"completion_tokens": 61,
|
||||
"total_tokens": 94
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## 完整配置示例
|
||||
|
||||
### Kubernetes 示例
|
||||
|
||||
@@ -1356,6 +1356,60 @@ Here, `model` denotes the service tier of DeepL and can only be either `Free` or
|
||||
}
|
||||
```
|
||||
|
||||
### Utilizing OpenAI Protocol Proxy for Together-AI Services
|
||||
|
||||
**Configuration Information**
|
||||
```yaml
|
||||
provider:
|
||||
type: together-ai
|
||||
apiTokens:
|
||||
- "YOUR_TOGETHER_AI_API_TOKEN"
|
||||
modelMapping:
|
||||
"*": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
|
||||
```
|
||||
|
||||
**Request Example**
|
||||
```json
|
||||
{
|
||||
"model": "Qwen/Qwen2.5-72B-Instruct-Turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Who are you?"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Response Example**
|
||||
```json
|
||||
{
|
||||
"id": "8f5809d54b73efac",
|
||||
"object": "chat.completion",
|
||||
"created": 1734785851,
|
||||
"model": "Qwen/Qwen2.5-72B-Instruct-Turbo",
|
||||
"prompt": [],
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "eos",
|
||||
"seed": 12830868308626506000,
|
||||
"logprobs": null,
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I am Qwen, a large language model created by Alibaba Cloud. I am designed to assist users in generating various types of text, such as articles, stories, poems, and more, as well as answering questions and providing information on a wide range of topics. How can I assist you today?",
|
||||
"tool_calls": []
|
||||
}
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 33,
|
||||
"completion_tokens": 61,
|
||||
"total_tokens": 94
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Full Configuration Example
|
||||
|
||||
### Kubernetes Example
|
||||
|
||||
@@ -20,8 +20,6 @@ import (
|
||||
const (
|
||||
pluginName = "ai-proxy"
|
||||
|
||||
ctxKeyApiName = "apiName"
|
||||
|
||||
defaultMaxBodyBytes uint32 = 10 * 1024 * 1024
|
||||
)
|
||||
|
||||
@@ -89,29 +87,34 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
||||
}
|
||||
|
||||
if apiName == "" {
|
||||
log.Debugf("[onHttpRequestHeader] unsupported path: %s", path.Path)
|
||||
// _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path)
|
||||
log.Debugf("[onHttpRequestHeader] no send response")
|
||||
log.Warnf("[onHttpRequestHeader] unsupported path: %s", path.Path)
|
||||
return types.ActionContinue
|
||||
}
|
||||
ctx.SetContext(ctxKeyApiName, apiName)
|
||||
|
||||
ctx.SetContext(provider.CtxKeyApiName, apiName)
|
||||
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
|
||||
ctx.DisableReroute()
|
||||
|
||||
_, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
|
||||
if needHandleStreamingBody {
|
||||
proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||||
}
|
||||
|
||||
if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok {
|
||||
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
|
||||
ctx.DisableReroute()
|
||||
// Set the apiToken for the current request.
|
||||
providerConfig.SetApiTokenInUse(ctx, log)
|
||||
|
||||
hasRequestBody := wrapper.HasRequestBody()
|
||||
action, err := handler.OnRequestHeaders(ctx, apiName, log)
|
||||
err := handler.OnRequestHeaders(ctx, apiName, log)
|
||||
if err == nil {
|
||||
if hasRequestBody {
|
||||
proxywasm.RemoveHttpRequestHeader("Content-Length")
|
||||
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
|
||||
// Always return types.HeaderStopIteration to support fallback routing,
|
||||
// as long as onHttpRequestBody can be called.
|
||||
// Delay the header processing to allow changing in OnRequestBody
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
return action
|
||||
ctx.DontReadRequestBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
|
||||
@@ -132,7 +135,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
|
||||
log.Debugf("[onHttpRequestBody] provider=%s", activeProvider.GetProviderType())
|
||||
|
||||
if handler, ok := activeProvider.(provider.RequestBodyHandler); ok {
|
||||
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
|
||||
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||
|
||||
newBody, settingErr := pluginConfig.GetProviderConfig().ReplaceByCustomSettings(body)
|
||||
if settingErr != nil {
|
||||
@@ -180,32 +183,25 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
|
||||
log.Errorf("unable to load :status header from response: %v", err)
|
||||
}
|
||||
ctx.DontReadResponseBody()
|
||||
providerConfig.OnRequestFailed(ctx, apiTokenInUse, log)
|
||||
|
||||
return types.ActionContinue
|
||||
return providerConfig.OnRequestFailed(activeProvider, ctx, apiTokenInUse, log)
|
||||
}
|
||||
|
||||
// Reset ctxApiTokenRequestFailureCount if the request is successful,
|
||||
// the apiToken is removed only when the number of consecutive request failures exceeds the threshold.
|
||||
providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse, log)
|
||||
|
||||
if handler, ok := activeProvider.(provider.ResponseHeadersHandler); ok {
|
||||
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
|
||||
action, err := handler.OnResponseHeaders(ctx, apiName, log)
|
||||
if err == nil {
|
||||
checkStream(&ctx, log)
|
||||
return action
|
||||
}
|
||||
util.ErrorHandler("ai-proxy.proc_resp_headers_failed", fmt.Errorf("failed to process response headers: %v", err))
|
||||
return types.ActionContinue
|
||||
headers := util.GetOriginalResponseHeaders()
|
||||
if handler, ok := activeProvider.(provider.TransformResponseHeadersHandler); ok {
|
||||
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||
handler.TransformResponseHeaders(ctx, apiName, headers, log)
|
||||
} else {
|
||||
providerConfig.DefaultTransformResponseHeaders(ctx, headers)
|
||||
}
|
||||
util.ReplaceResponseHeaders(headers)
|
||||
|
||||
checkStream(&ctx, log)
|
||||
_, needHandleBody := activeProvider.(provider.ResponseBodyHandler)
|
||||
checkStream(ctx, log)
|
||||
_, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
|
||||
if !needHandleBody && !needHandleStreamingBody {
|
||||
ctx.DontReadResponseBody()
|
||||
} else if !needHandleStreamingBody {
|
||||
if !needHandleStreamingBody {
|
||||
ctx.BufferResponseBody()
|
||||
}
|
||||
|
||||
@@ -224,7 +220,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
||||
log.Debugf("isLastChunk=%v chunk: %s", isLastChunk, string(chunk))
|
||||
|
||||
if handler, ok := activeProvider.(provider.StreamingResponseBodyHandler); ok {
|
||||
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
|
||||
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||
modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk, log)
|
||||
if err == nil && modifiedChunk != nil {
|
||||
return modifiedChunk
|
||||
@@ -243,27 +239,29 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
|
||||
}
|
||||
|
||||
log.Debugf("[onHttpResponseBody] provider=%s", activeProvider.GetProviderType())
|
||||
//log.Debugf("response body: %s", string(body))
|
||||
|
||||
if handler, ok := activeProvider.(provider.ResponseBodyHandler); ok {
|
||||
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
|
||||
action, err := handler.OnResponseBody(ctx, apiName, body, log)
|
||||
if err == nil {
|
||||
return action
|
||||
if handler, ok := activeProvider.(provider.TransformResponseBodyHandler); ok {
|
||||
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||
body, err := handler.TransformResponseBody(ctx, apiName, body, log)
|
||||
if err != nil {
|
||||
util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
|
||||
return types.ActionContinue
|
||||
}
|
||||
if err = provider.ReplaceResponseBody(body, log); err != nil {
|
||||
util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err))
|
||||
}
|
||||
util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
|
||||
return types.ActionContinue
|
||||
}
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func checkStream(ctx *wrapper.HttpContext, log wrapper.Log) {
|
||||
func checkStream(ctx wrapper.HttpContext, log wrapper.Log) {
|
||||
contentType, err := proxywasm.GetHttpResponseHeader("Content-Type")
|
||||
if err != nil || !strings.HasPrefix(contentType, "text/event-stream") {
|
||||
if err != nil {
|
||||
log.Errorf("unable to load content-type header from response: %v", err)
|
||||
}
|
||||
(*ctx).BufferResponseBody()
|
||||
ctx.BufferResponseBody()
|
||||
ctx.SetResponseBodyBufferLimit(defaultMaxBodyBytes)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -40,13 +40,13 @@ func (m *ai360Provider) GetProviderType() string {
|
||||
return providerTypeAi360
|
||||
}
|
||||
|
||||
func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return types.HeaderStopIteration, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
@@ -58,7 +58,5 @@ func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
|
||||
func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestHostHeader(headers, ai360Domain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Authorization "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Accept-Encoding")
|
||||
headers.Del("Content-Length")
|
||||
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
|
||||
}
|
||||
|
||||
@@ -53,12 +53,12 @@ func (m *azureProvider) GetProviderType() string {
|
||||
return providerTypeAzure
|
||||
}
|
||||
|
||||
func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
@@ -86,6 +86,6 @@ func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
|
||||
util.OverwriteRequestPathHeader(headers, m.serviceUrl.RequestURI())
|
||||
}
|
||||
util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "api-key "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Set("api-key", m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
@@ -42,12 +42,12 @@ func (m *baichuanProvider) GetProviderType() string {
|
||||
return providerTypeBaichuan
|
||||
}
|
||||
|
||||
func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
|
||||
@@ -63,12 +63,12 @@ func (g *baiduProvider) GetProviderType() string {
|
||||
return providerTypeBaidu
|
||||
}
|
||||
|
||||
func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
g.config.handleRequestHeaders(g, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
@@ -102,27 +101,25 @@ func (c *claudeProvider) GetProviderType() string {
|
||||
return providerTypeClaude
|
||||
}
|
||||
|
||||
func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
c.config.handleRequestHeaders(c, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, claudeChatCompletionPath)
|
||||
util.OverwriteRequestHostHeader(headers, claudeDomain)
|
||||
|
||||
headers.Add("x-api-key", c.config.GetApiTokenInUse(ctx))
|
||||
headers.Set("x-api-key", c.config.GetApiTokenInUse(ctx))
|
||||
|
||||
if c.config.claudeVersion == "" {
|
||||
c.config.claudeVersion = defaultVersion
|
||||
}
|
||||
|
||||
headers.Add("anthropic-version", c.config.claudeVersion)
|
||||
headers.Del("Accept-Encoding")
|
||||
headers.Del("Content-Length")
|
||||
headers.Set("anthropic-version", c.config.claudeVersion)
|
||||
}
|
||||
|
||||
func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
@@ -141,27 +138,16 @@ func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName A
|
||||
return json.Marshal(claudeRequest)
|
||||
}
|
||||
|
||||
func (c *claudeProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
claudeResponse := &claudeTextGenResponse{}
|
||||
if err := json.Unmarshal(body, claudeResponse); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal claude response: %v", err)
|
||||
return nil, fmt.Errorf("unable to unmarshal claude response: %v", err)
|
||||
}
|
||||
if claudeResponse.Error != nil {
|
||||
return types.ActionContinue, fmt.Errorf("claude response error, error_type: %s, error_message: %s", claudeResponse.Error.Type, claudeResponse.Error.Message)
|
||||
return nil, fmt.Errorf("claude response error, error_type: %s, error_message: %s", claudeResponse.Error.Type, claudeResponse.Error.Message)
|
||||
}
|
||||
response := c.responseClaude2OpenAI(ctx, claudeResponse)
|
||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
||||
}
|
||||
|
||||
func (c *claudeProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
// use original protocol, skip OnStreamingResponseBody() and OnResponseBody()
|
||||
if c.config.protocol == protocolOriginal {
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
|
||||
return types.ActionContinue, nil
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||
|
||||
@@ -42,12 +42,12 @@ func (c *cloudflareProvider) GetProviderType() string {
|
||||
return providerTypeCloudflare
|
||||
}
|
||||
|
||||
func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
c.config.handleRequestHeaders(c, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
@@ -61,6 +61,4 @@ func (c *cloudflareProvider) TransformRequestHeaders(ctx wrapper.HttpContext, ap
|
||||
util.OverwriteRequestPathHeader(headers, strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1))
|
||||
util.OverwriteRequestHostHeader(headers, cloudflareDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+c.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Accept-Encoding")
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
@@ -3,11 +3,12 @@ package provider
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -54,12 +55,12 @@ func (m *cohereProvider) GetProviderType() string {
|
||||
return providerTypeCohere
|
||||
}
|
||||
|
||||
func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
|
||||
@@ -151,7 +151,7 @@ func insertContext(provider Provider, content string, err error, body []byte, lo
|
||||
if err != nil {
|
||||
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.insert_ctx_failed", typ), fmt.Errorf("failed to insert context message: %v", err))
|
||||
}
|
||||
if err := replaceHttpJsonRequestBody(body, log); err != nil {
|
||||
if err := replaceRequestBody(body, log); err != nil {
|
||||
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.replace_request_body_failed", typ), fmt.Errorf("failed to replace request body: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -38,9 +37,9 @@ func (m *cozeProvider) GetProviderType() string {
|
||||
return providerTypeCoze
|
||||
}
|
||||
|
||||
func (m *cozeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (m *cozeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *cozeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
@@ -76,19 +75,17 @@ func (d *deeplProvider) GetProviderType() string {
|
||||
return providerTypeDeepl
|
||||
}
|
||||
|
||||
func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
d.config.handleRequestHeaders(d, ctx, apiName, log)
|
||||
return types.HeaderStopIteration, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, deeplChatCompletionPath)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "DeepL-Auth-Key "+d.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
headers.Del("Accept-Encoding")
|
||||
}
|
||||
|
||||
func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
@@ -114,18 +111,13 @@ func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, api
|
||||
return json.Marshal(baiduRequest)
|
||||
}
|
||||
|
||||
func (d *deeplProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
func (d *deeplProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (d *deeplProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
deeplResponse := &deeplResponse{}
|
||||
if err := json.Unmarshal(body, deeplResponse); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal deepl response: %v", err)
|
||||
return nil, fmt.Errorf("unable to unmarshal deepl response: %v", err)
|
||||
}
|
||||
response := d.responseDeepl2OpenAI(ctx, deeplResponse)
|
||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
func (d *deeplProvider) responseDeepl2OpenAI(ctx wrapper.HttpContext, deeplResponse *deeplResponse) *chatCompletionResponse {
|
||||
|
||||
@@ -2,10 +2,11 @@ package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// deepseekProvider is the provider for deepseek Ai service.
|
||||
@@ -41,12 +42,12 @@ func (m *deepseekProvider) GetProviderType() string {
|
||||
return providerTypeDeepSeek
|
||||
}
|
||||
|
||||
func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
|
||||
@@ -2,11 +2,12 @@ package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -39,12 +40,12 @@ func (m *doubaoProvider) GetProviderType() string {
|
||||
return providerTypeDoubao
|
||||
}
|
||||
|
||||
func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
|
||||
type failover struct {
|
||||
// @Title zh-CN 是否启用 apiToken 的 failover 机制
|
||||
enabled bool `required:"true" yaml:"enabled" json:"enabled"`
|
||||
enabled bool `required:"false" yaml:"enabled" json:"enabled"`
|
||||
// @Title zh-CN 触发 failover 连续请求失败的阈值
|
||||
failureThreshold int64 `required:"false" yaml:"failureThreshold" json:"failureThreshold"`
|
||||
// @Title zh-CN 健康检测的成功阈值
|
||||
@@ -29,7 +29,7 @@ type failover struct {
|
||||
// @Title zh-CN 健康检测的超时时间,单位毫秒
|
||||
healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"`
|
||||
// @Title zh-CN 健康检测使用的模型
|
||||
healthCheckModel string `required:"true" yaml:"healthCheckModel" json:"healthCheckModel"`
|
||||
healthCheckModel string `required:"false" yaml:"healthCheckModel" json:"healthCheckModel"`
|
||||
// @Title zh-CN 本次请求使用的 apiToken
|
||||
ctxApiTokenInUse string
|
||||
// @Title zh-CN 记录 apiToken 请求失败的次数,key 为 apiToken,value 为失败次数
|
||||
@@ -184,9 +184,9 @@ func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext,
|
||||
if handler, ok := activeProvider.(TransformRequestBodyHandler); ok {
|
||||
body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body, log)
|
||||
} else if handler, ok := activeProvider.(TransformRequestBodyHeadersHandler); ok {
|
||||
headers := util.GetOriginalHttpHeaders()
|
||||
headers := util.GetOriginalRequestHeaders()
|
||||
body, err = handler.TransformRequestBodyHeaders(ctx, ApiNameChatCompletion, body, originalHeaders, log)
|
||||
util.ReplaceOriginalHttpHeaders(headers)
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
} else {
|
||||
body, err = c.defaultTransformRequestBody(ctx, ApiNameChatCompletion, body, log)
|
||||
}
|
||||
@@ -539,10 +539,15 @@ func (c *ProviderConfig) resetSharedData() {
|
||||
_ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) OnRequestFailed(ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) {
|
||||
func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) types.Action {
|
||||
if c.isFailoverEnabled() {
|
||||
c.handleUnavailableApiToken(ctx, apiTokenInUse, log)
|
||||
}
|
||||
if c.isRetryOnFailureEnabled() && ctx.GetContext(ctxKeyIsStreaming) != nil && !ctx.GetContext(ctxKeyIsStreaming).(bool) {
|
||||
c.retryFailedRequest(activeProvider, ctx, log)
|
||||
return types.HeaderStopAllIterationAndWatermark
|
||||
}
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
|
||||
@@ -557,7 +562,7 @@ func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.L
|
||||
} else {
|
||||
apiToken = c.GetRandomToken()
|
||||
}
|
||||
log.Debugf("[onHttpRequestHeader] use apiToken %s to send request", apiToken)
|
||||
log.Debugf("Use apiToken %s to send request", apiToken)
|
||||
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
|
||||
}
|
||||
|
||||
|
||||
@@ -51,20 +51,18 @@ func (g *geminiProvider) GetProviderType() string {
|
||||
return providerTypeGemini
|
||||
}
|
||||
|
||||
func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
g.config.handleRequestHeaders(g, ctx, apiName, log)
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return types.HeaderStopIteration, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestHostHeader(headers, geminiDomain)
|
||||
headers.Add(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Accept-Encoding")
|
||||
headers.Del("Content-Length")
|
||||
headers.Set(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx))
|
||||
}
|
||||
|
||||
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
@@ -107,16 +105,6 @@ func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
|
||||
return json.Marshal(geminiRequest)
|
||||
}
|
||||
|
||||
func (g *geminiProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
if g.config.protocol == protocolOriginal {
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||
log.Infof("chunk body:%s", string(chunk))
|
||||
if isLastChunk || len(chunk) == 0 {
|
||||
@@ -150,39 +138,38 @@ func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
|
||||
return []byte(modifiedResponseChunk), nil
|
||||
}
|
||||
|
||||
func (g *geminiProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (g *geminiProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
if apiName == ApiNameChatCompletion {
|
||||
return g.onChatCompletionResponseBody(ctx, body, log)
|
||||
} else if apiName == ApiNameEmbeddings {
|
||||
} else {
|
||||
return g.onEmbeddingsResponseBody(ctx, body, log)
|
||||
}
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
|
||||
func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
geminiResponse := &geminiChatResponse{}
|
||||
if err := json.Unmarshal(body, geminiResponse); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal gemini chat response: %v", err)
|
||||
return nil, fmt.Errorf("unable to unmarshal gemini chat response: %v", err)
|
||||
}
|
||||
if geminiResponse.Error != nil {
|
||||
return types.ActionContinue, fmt.Errorf("gemini chat completion response error, error_code: %d, error_status:%s, error_message: %s",
|
||||
return nil, fmt.Errorf("gemini chat completion response error, error_code: %d, error_status:%s, error_message: %s",
|
||||
geminiResponse.Error.Code, geminiResponse.Error.Status, geminiResponse.Error.Message)
|
||||
}
|
||||
response := g.buildChatCompletionResponse(ctx, geminiResponse)
|
||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
geminiResponse := &geminiEmbeddingResponse{}
|
||||
if err := json.Unmarshal(body, geminiResponse); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal gemini embeddings response: %v", err)
|
||||
return nil, fmt.Errorf("unable to unmarshal gemini embeddings response: %v", err)
|
||||
}
|
||||
if geminiResponse.Error != nil {
|
||||
return types.ActionContinue, fmt.Errorf("gemini embeddings response error, error_code: %d, error_status:%s, error_message: %s",
|
||||
return nil, fmt.Errorf("gemini embeddings response error, error_code: %d, error_status:%s, error_message: %s",
|
||||
geminiResponse.Error.Code, geminiResponse.Error.Status, geminiResponse.Error.Message)
|
||||
}
|
||||
response := g.buildEmbeddingsResponse(ctx, geminiResponse)
|
||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
func (g *geminiProvider) getRequestPath(apiName ApiName, geminiModel string, stream bool) string {
|
||||
|
||||
@@ -2,11 +2,12 @@ package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// githubProvider is the provider for GitHub OpenAI service.
|
||||
@@ -42,13 +43,13 @@ func (m *githubProvider) GetProviderType() string {
|
||||
return providerTypeGithub
|
||||
}
|
||||
|
||||
func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return types.HeaderStopIteration, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
@@ -67,8 +68,6 @@ func (m *githubProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
|
||||
util.OverwriteRequestPathHeader(headers, githubEmbeddingPath)
|
||||
}
|
||||
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Accept-Encoding")
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
func (m *githubProvider) GetApiName(path string) ApiName {
|
||||
|
||||
@@ -41,12 +41,12 @@ func (g *groqProvider) GetProviderType() string {
|
||||
return providerTypeGroq
|
||||
}
|
||||
|
||||
func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
g.config.handleRequestHeaders(g, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
|
||||
@@ -114,13 +114,13 @@ func (m *hunyuanProvider) GetProviderType() string {
|
||||
return providerTypeHunyuan
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return types.HeaderStopIteration, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
@@ -128,11 +128,8 @@ func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa
|
||||
util.OverwriteRequestPathHeader(headers, hunyuanRequestPath)
|
||||
|
||||
// 添加 hunyuan 需要的自定义字段
|
||||
headers.Add(actionKey, hunyuanChatCompletionTCAction)
|
||||
headers.Add(versionKey, versionValue)
|
||||
|
||||
headers.Del("Accept-Encoding")
|
||||
headers.Del("Content-Length")
|
||||
headers.Set(actionKey, hunyuanChatCompletionTCAction)
|
||||
headers.Set(versionKey, versionValue)
|
||||
}
|
||||
|
||||
// hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法
|
||||
@@ -291,11 +288,6 @@ func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, a
|
||||
return json.Marshal(hunyuanRequest)
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||
if m.config.protocol == protocolOriginal {
|
||||
return chunk, nil
|
||||
@@ -412,21 +404,14 @@ func (m *hunyuanProvider) convertChunkFromHunyuanToOpenAI(ctx wrapper.HttpContex
|
||||
return []byte(openAIChunk.String()), nil
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
|
||||
func (m *hunyuanProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
log.Debugf("#debug nash5# onRespBody's resp is: %s", string(body))
|
||||
hunyuanResponse := &hunyuanTextGenResponseNonStreaming{}
|
||||
if err := json.Unmarshal(body, hunyuanResponse); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal hunyuan response: %v", err)
|
||||
return nil, fmt.Errorf("unable to unmarshal hunyuan response: %v", err)
|
||||
}
|
||||
|
||||
if m.config.protocol == protocolOriginal {
|
||||
return types.ActionContinue, replaceJsonResponseBody(hunyuanResponse, log)
|
||||
}
|
||||
|
||||
response := m.buildChatCompletionResponse(ctx, hunyuanResponse)
|
||||
|
||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) insertContextMessageIntoHunyuanRequest(request *hunyuanTextGenRequest, content string) {
|
||||
|
||||
@@ -11,47 +11,37 @@ import (
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// minimaxProvider is the provider for minimax service.
|
||||
|
||||
const (
|
||||
minimaxDomain = "api.minimax.chat"
|
||||
// minimaxChatCompletionV2Path 接口请求响应格式与OpenAI相同
|
||||
// 接口文档: https://platform.minimaxi.com/document/guides/chat-model/V2?id=65e0736ab2845de20908e2dd
|
||||
minimaxApiTypeV2 = "v2" // minimaxApiTypeV2 represents chat completion V2 API.
|
||||
minimaxApiTypePro = "pro" // minimaxApiTypePro represents chat completion Pro API.
|
||||
minimaxDomain = "api.minimax.chat"
|
||||
// minimaxChatCompletionV2Path represents the API path for chat completion V2 API which has a response format similar to OpenAI's.
|
||||
minimaxChatCompletionV2Path = "/v1/text/chatcompletion_v2"
|
||||
// minimaxChatCompletionProPath 接口请求响应格式与OpenAI不同
|
||||
// 接口文档: https://platform.minimaxi.com/document/guides/chat-model/pro/api?id=6569c85948bc7b684b30377e
|
||||
// minimaxChatCompletionProPath represents the API path for chat completion Pro API which has a different response format from OpenAI's.
|
||||
minimaxChatCompletionProPath = "/v1/text/chatcompletion_pro"
|
||||
|
||||
senderTypeUser string = "USER" // 用户发送的内容
|
||||
senderTypeBot string = "BOT" // 模型生成的内容
|
||||
senderTypeUser string = "USER" // Content sent by the user.
|
||||
senderTypeBot string = "BOT" // Content generated by the model.
|
||||
|
||||
// 默认机器人设置
|
||||
// Default bot settings.
|
||||
defaultBotName string = "MM智能助理"
|
||||
defaultBotSettingContent string = "MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。"
|
||||
defaultSenderName string = "小明"
|
||||
)
|
||||
|
||||
// chatCompletionProModels 这些模型对应接口为ChatCompletion Pro
|
||||
var chatCompletionProModels = map[string]struct{}{
|
||||
"abab6.5-chat": {},
|
||||
"abab6.5s-chat": {},
|
||||
"abab5.5s-chat": {},
|
||||
"abab5.5-chat": {},
|
||||
}
|
||||
|
||||
type minimaxProviderInitializer struct {
|
||||
}
|
||||
|
||||
func (m *minimaxProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||
// 如果存在模型对应接口为ChatCompletion Pro必须配置minimaxGroupId
|
||||
if len(config.modelMapping) > 0 && config.minimaxGroupId == "" {
|
||||
for _, minimaxModel := range config.modelMapping {
|
||||
if _, exists := chatCompletionProModels[minimaxModel]; exists {
|
||||
return errors.New(fmt.Sprintf("missing minimaxGroupId in provider config when %s model is provided", minimaxModel))
|
||||
}
|
||||
}
|
||||
// If using the chat completion Pro API, a group ID must be set.
|
||||
if minimaxApiTypePro == config.minimaxApiType && config.minimaxGroupId == "" {
|
||||
return errors.New(fmt.Sprintf("missing minimaxGroupId in provider config when minimaxApiType is %s", minimaxApiTypePro))
|
||||
}
|
||||
if config.apiTokens == nil || len(config.apiTokens) == 0 {
|
||||
return errors.New("no apiToken found in provider config")
|
||||
@@ -75,13 +65,13 @@ func (m *minimaxProvider) GetProviderType() string {
|
||||
return providerTypeMinimax
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return types.HeaderStopIteration, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
@@ -94,23 +84,11 @@ func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
// 解析并映射模型,设置上下文
|
||||
model, err := m.parseModel(body)
|
||||
if err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
ctx.SetContext(ctxKeyOriginalRequestModel, model)
|
||||
mappedModel := getMappedModel(model, m.config.modelMapping, log)
|
||||
if mappedModel == "" {
|
||||
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
|
||||
}
|
||||
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
|
||||
_, ok := chatCompletionProModels[mappedModel]
|
||||
if ok {
|
||||
// 使用ChatCompletion Pro接口
|
||||
if minimaxApiTypePro == m.config.minimaxApiType {
|
||||
// Use chat completion Pro API.
|
||||
return m.handleRequestBodyByChatCompletionPro(body, log)
|
||||
} else {
|
||||
// 使用ChatCompletion v2接口
|
||||
// Use chat completion V2 API.
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
}
|
||||
@@ -119,14 +97,14 @@ func (m *minimaxProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, a
|
||||
return m.handleRequestBodyByChatCompletionV2(body, headers, log)
|
||||
}
|
||||
|
||||
// handleRequestBodyByChatCompletionPro 使用ChatCompletion Pro接口处理请求体
|
||||
// handleRequestBodyByChatCompletionPro processes the request body using the chat completion Pro API.
|
||||
func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log wrapper.Log) (types.Action, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
// 映射模型重写requestPath
|
||||
// Map the model and rewrite the request path.
|
||||
request.Model = getMappedModel(request.Model, m.config.modelMapping, log)
|
||||
_ = util.OverwriteRequestPath(fmt.Sprintf("%s?GroupId=%s", minimaxChatCompletionProPath, m.config.minimaxGroupId))
|
||||
|
||||
@@ -143,9 +121,9 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
util.ErrorHandler("ai-proxy.minimax.load_ctx_failed", fmt.Errorf("failed to load context file: %v", err))
|
||||
}
|
||||
// 由于 minimaxChatCompletionV2(格式和 OpenAI 一致)和 minimaxChatCompletionPro(格式和 OpenAI 不一致)中 insertHttpContextMessage 的逻辑不同,无法做到同一个 provider 统一
|
||||
// 因此对于 minimaxChatCompletionPro 需要手动处理 context 消息
|
||||
// minimaxChatCompletionV2 交给默认的 defaultInsertHttpContextMessage 方法插入 context 消息
|
||||
// Since minimaxChatCompletionV2 (format consistent with OpenAI) and minimaxChatCompletionPro (different format from OpenAI) have different logic for insertHttpContextMessage, we cannot unify them within one provider.
|
||||
// For minimaxChatCompletionPro, we need to manually handle context messages.
|
||||
// minimaxChatCompletionV2 uses the default defaultInsertHttpContextMessage method to insert context messages.
|
||||
minimaxRequest := m.buildMinimaxChatCompletionV2Request(request, content)
|
||||
if err := replaceJsonRequestBody(minimaxRequest, log); err != nil {
|
||||
util.ErrorHandler("ai-proxy.minimax.insert_ctx_failed", fmt.Errorf("failed to replace Request body: %v", err))
|
||||
@@ -157,54 +135,42 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
// handleRequestBodyByChatCompletionV2 使用ChatCompletion v2接口处理请求体
|
||||
// handleRequestBodyByChatCompletionV2 processes the request body using the chat completion V2 API.
|
||||
func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 映射模型重写requestPath
|
||||
request.Model = getMappedModel(request.Model, m.config.modelMapping, log)
|
||||
util.OverwriteRequestPathHeader(headers, minimaxChatCompletionV2Path)
|
||||
|
||||
return body, nil
|
||||
rawModel := gjson.GetBytes(body, "model").String()
|
||||
mappedModel := getMappedModel(rawModel, m.config.modelMapping, log)
|
||||
return sjson.SetBytes(body, "model", mappedModel)
|
||||
}
|
||||
|
||||
func (m *minimaxProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
// 使用minimax接口协议,跳过OnStreamingResponseBody()和OnResponseBody()
|
||||
// Skip OnStreamingResponseBody() and OnResponseBody() when using original protocol.
|
||||
func (m *minimaxProvider) TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
if m.config.protocol == protocolOriginal {
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
// 模型对应接口为ChatCompletion v2,跳过OnStreamingResponseBody()和OnResponseBody()
|
||||
model := ctx.GetStringContext(ctxKeyFinalRequestModel, "")
|
||||
if model != "" {
|
||||
_, ok := chatCompletionProModels[model]
|
||||
if !ok {
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
// Skip OnStreamingResponseBody() and OnResponseBody() when the model corresponds to the chat completion V2 interface.
|
||||
if minimaxApiTypePro != m.config.minimaxApiType {
|
||||
ctx.DontReadResponseBody()
|
||||
}
|
||||
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
// OnStreamingResponseBody 只处理使用OpenAI协议 且 模型对应接口为ChatCompletion Pro的流式响应
|
||||
// OnStreamingResponseBody handles streaming response chunks from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API.
|
||||
func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||
if isLastChunk || len(chunk) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
// sample event response:
|
||||
// Sample event response:
|
||||
// data: {"created":1689747645,"model":"abab6.5s-chat","reply":"","choices":[{"messages":[{"sender_type":"BOT","sender_name":"MM智能助理","text":"am from China."}]}],"output_sensitive":false}
|
||||
|
||||
// sample end event response:
|
||||
// Sample end event response:
|
||||
// data: {"created":1689747645,"model":"abab6.5s-chat","reply":"I am from China.","choices":[{"finish_reason":"stop","messages":[{"sender_type":"BOT","sender_name":"MM智能助理","text":"I am from China."}]}],"usage":{"total_tokens":187},"input_sensitive":false,"output_sensitive":false,"id":"0106b3bc9fd844a9f3de1aa06004e2ab","base_resp":{"status_code":0,"status_msg":""}}
|
||||
responseBuilder := &strings.Builder{}
|
||||
lines := strings.Split(string(chunk), "\n")
|
||||
for _, data := range lines {
|
||||
if len(data) < 6 {
|
||||
// ignore blank line or wrong format
|
||||
// Ignore blank line or improperly formatted lines.
|
||||
continue
|
||||
}
|
||||
data = data[6:]
|
||||
@@ -226,52 +192,52 @@ func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
|
||||
return []byte(modifiedResponseChunk), nil
|
||||
}
|
||||
|
||||
// OnResponseBody 只处理使用OpenAI协议 且 模型对应接口为ChatCompletion Pro的流式响应
|
||||
func (m *minimaxProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
// OnResponseBody handles the final response body from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API.
|
||||
func (m *minimaxProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
minimaxResp := &minimaxChatCompletionV2Resp{}
|
||||
if err := json.Unmarshal(body, minimaxResp); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal minimax response: %v", err)
|
||||
return nil, fmt.Errorf("unable to unmarshal minimax response: %v", err)
|
||||
}
|
||||
if minimaxResp.BaseResp.StatusCode != 0 {
|
||||
return types.ActionContinue, fmt.Errorf("minimax response error, error_code: %d, error_message: %s", minimaxResp.BaseResp.StatusCode, minimaxResp.BaseResp.StatusMsg)
|
||||
return nil, fmt.Errorf("minimax response error, error_code: %d, error_message: %s", minimaxResp.BaseResp.StatusCode, minimaxResp.BaseResp.StatusMsg)
|
||||
}
|
||||
response := m.responseV2ToOpenAI(minimaxResp)
|
||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
// minimaxChatCompletionV2Request 表示ChatCompletion V2请求的结构体
|
||||
// minimaxChatCompletionV2Request represents the structure of a chat completion V2 request.
|
||||
type minimaxChatCompletionV2Request struct {
|
||||
Model string `json:"model"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
TokensToGenerate int64 `json:"tokens_to_generate,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
MaskSensitiveInfo bool `json:"mask_sensitive_info"` // 是否开启隐私信息打码,默认true
|
||||
MaskSensitiveInfo bool `json:"mask_sensitive_info"` // Whether to mask sensitive information, defaults to true.
|
||||
Messages []minimaxMessage `json:"messages"`
|
||||
BotSettings []minimaxBotSetting `json:"bot_setting"`
|
||||
ReplyConstraints minimaxReplyConstraints `json:"reply_constraints"`
|
||||
}
|
||||
|
||||
// minimaxMessage 表示对话中的消息
|
||||
// minimaxMessage represents a message in the conversation.
|
||||
type minimaxMessage struct {
|
||||
SenderType string `json:"sender_type"`
|
||||
SenderName string `json:"sender_name"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// minimaxBotSetting 表示机器人的设置
|
||||
// minimaxBotSetting represents the bot's settings.
|
||||
type minimaxBotSetting struct {
|
||||
BotName string `json:"bot_name"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// minimaxReplyConstraints 表示模型回复要求
|
||||
// minimaxReplyConstraints represents requirements for model replies.
|
||||
type minimaxReplyConstraints struct {
|
||||
SenderType string `json:"sender_type"`
|
||||
SenderName string `json:"sender_name"`
|
||||
}
|
||||
|
||||
// minimaxChatCompletionV2Resp Minimax Chat Completion V2响应结构体
|
||||
// minimaxChatCompletionV2Resp represents the structure of a Minimax Chat Completion V2 response.
|
||||
type minimaxChatCompletionV2Resp struct {
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
@@ -286,20 +252,20 @@ type minimaxChatCompletionV2Resp struct {
|
||||
BaseResp minimaxBaseResp `json:"base_resp"`
|
||||
}
|
||||
|
||||
// minimaxBaseResp 包含错误状态码和详情
|
||||
// minimaxBaseResp contains error status code and details.
|
||||
type minimaxBaseResp struct {
|
||||
StatusCode int64 `json:"status_code"`
|
||||
StatusMsg string `json:"status_msg"`
|
||||
}
|
||||
|
||||
// minimaxChoice 结果选项
|
||||
// minimaxChoice represents a result option.
|
||||
type minimaxChoice struct {
|
||||
Messages []minimaxMessage `json:"messages"`
|
||||
Index int64 `json:"index"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
// minimaxUsage 令牌使用情况
|
||||
// minimaxUsage represents token usage statistics.
|
||||
type minimaxUsage struct {
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
}
|
||||
|
||||
@@ -2,10 +2,11 @@ package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -37,12 +38,12 @@ func (m *mistralProvider) GetProviderType() string {
|
||||
return providerTypeMistral
|
||||
}
|
||||
|
||||
func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
|
||||
@@ -56,12 +56,12 @@ func (m *moonshotProvider) GetProviderType() string {
|
||||
return providerTypeMoonshot
|
||||
}
|
||||
|
||||
func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
|
||||
@@ -3,10 +3,11 @@ package provider
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// ollamaProvider is the provider for Ollama service.
|
||||
@@ -48,12 +49,12 @@ func (m *ollamaProvider) GetProviderType() string {
|
||||
return providerTypeOllama
|
||||
}
|
||||
|
||||
func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
|
||||
@@ -57,9 +57,9 @@ func (m *openaiProvider) GetProviderType() string {
|
||||
return providerTypeOpenAI
|
||||
}
|
||||
|
||||
func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
|
||||
@@ -46,6 +46,7 @@ const (
|
||||
providerTypeCohere = "cohere"
|
||||
providerTypeDoubao = "doubao"
|
||||
providerTypeCoze = "coze"
|
||||
providerTypeTogetherAI = "together-ai"
|
||||
|
||||
protocolOpenAI = "openai"
|
||||
protocolOriginal = "original"
|
||||
@@ -58,7 +59,9 @@ const (
|
||||
finishReasonLength = "length"
|
||||
|
||||
ctxKeyIncrementalStreaming = "incrementalStreaming"
|
||||
ctxKeyApiName = "apiKey"
|
||||
ctxKeyApiKey = "apiKey"
|
||||
CtxKeyApiName = "apiName"
|
||||
ctxKeyIsStreaming = "isStreaming"
|
||||
ctxKeyStreamingBody = "streamingBody"
|
||||
ctxKeyOriginalRequestModel = "originalRequestModel"
|
||||
ctxKeyFinalRequestModel = "finalRequestModel"
|
||||
@@ -106,6 +109,7 @@ var (
|
||||
providerTypeCohere: &cohereProviderInitializer{},
|
||||
providerTypeDoubao: &doubaoProviderInitializer{},
|
||||
providerTypeCoze: &cozeProviderInitializer{},
|
||||
providerTypeTogetherAI: &togetherAIProviderInitializer{},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -113,22 +117,26 @@ type Provider interface {
|
||||
GetProviderType() string
|
||||
}
|
||||
|
||||
type ApiNameHandler interface {
|
||||
GetApiName(path string) ApiName
|
||||
}
|
||||
|
||||
type RequestHeadersHandler interface {
|
||||
OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error)
|
||||
}
|
||||
|
||||
type TransformRequestHeadersHandler interface {
|
||||
TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
|
||||
OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error
|
||||
}
|
||||
|
||||
type RequestBodyHandler interface {
|
||||
OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
|
||||
}
|
||||
|
||||
type StreamingResponseBodyHandler interface {
|
||||
OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error)
|
||||
}
|
||||
|
||||
type ApiNameHandler interface {
|
||||
GetApiName(path string) ApiName
|
||||
}
|
||||
|
||||
type TransformRequestHeadersHandler interface {
|
||||
TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
|
||||
}
|
||||
|
||||
type TransformRequestBodyHandler interface {
|
||||
TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
|
||||
}
|
||||
@@ -139,16 +147,12 @@ type TransformRequestBodyHeadersHandler interface {
|
||||
TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error)
|
||||
}
|
||||
|
||||
type ResponseHeadersHandler interface {
|
||||
OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error)
|
||||
type TransformResponseHeadersHandler interface {
|
||||
TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
|
||||
}
|
||||
|
||||
type StreamingResponseBodyHandler interface {
|
||||
OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error)
|
||||
}
|
||||
|
||||
type ResponseBodyHandler interface {
|
||||
OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
|
||||
type TransformResponseBodyHandler interface {
|
||||
TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
|
||||
}
|
||||
|
||||
// TickFuncHandler allows the provider to execute a function periodically
|
||||
@@ -173,6 +177,9 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN apiToken 故障切换
|
||||
// @Description zh-CN 当 apiToken 不可用时移出 apiTokens 列表,对移除的 apiToken 进行健康检查,当重新可用后加回 apiTokens 列表
|
||||
failover *failover `required:"false" yaml:"failover" json:"failover"`
|
||||
// @Title zh-CN 失败请求重试
|
||||
// @Description zh-CN 对失败的请求立即进行重试
|
||||
retryOnFailure *retryOnFailure `required:"false" yaml:"retryOnFailure" json:"retryOnFailure"`
|
||||
// @Title zh-CN 基于OpenAI协议的自定义后端URL
|
||||
// @Description zh-CN 仅适用于支持 openai 协议的服务。
|
||||
openaiCustomUrl string `required:"false" yaml:"openaiCustomUrl" json:"openaiCustomUrl"`
|
||||
@@ -206,8 +213,11 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN hunyuan api id for authorization
|
||||
// @Description zh-CN 仅适用于Hun Yuan AI服务鉴权
|
||||
hunyuanAuthId string `required:"false" yaml:"hunyuanAuthId" json:"hunyuanAuthId"`
|
||||
// @Title zh-CN minimax API type
|
||||
// @Description zh-CN 仅适用于 minimax 服务。minimax API 类型,v2 和 pro 中选填一项,默认值为 v2
|
||||
minimaxApiType string `required:"false" yaml:"minimaxApiType" json:"minimaxApiType"`
|
||||
// @Title zh-CN minimax group id
|
||||
// @Description zh-CN 仅适用于minimax使用ChatCompletion Pro接口的模型
|
||||
// @Description zh-CN 仅适用于 minimax 服务。minimax API 类型为 pro 时必填
|
||||
minimaxGroupId string `required:"false" yaml:"minimaxGroupId" json:"minimaxGroupId"`
|
||||
// @Title zh-CN 模型名称映射表
|
||||
// @Description zh-CN 用于将请求中的模型名称映射为目标AI服务商支持的模型名称。支持通过“*”来配置全局映射
|
||||
@@ -303,6 +313,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
c.claudeVersion = json.Get("claudeVersion").String()
|
||||
c.hunyuanAuthId = json.Get("hunyuanAuthId").String()
|
||||
c.hunyuanAuthKey = json.Get("hunyuanAuthKey").String()
|
||||
c.minimaxApiType = json.Get("minimaxApiType").String()
|
||||
c.minimaxGroupId = json.Get("minimaxGroupId").String()
|
||||
c.cloudflareAccountId = json.Get("cloudflareAccountId").String()
|
||||
if c.typ == providerTypeGemini {
|
||||
@@ -346,6 +357,14 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
c.failover.FromJson(failoverJson)
|
||||
}
|
||||
|
||||
retryOnFailureJson := json.Get("retryOnFailure")
|
||||
c.retryOnFailure = &retryOnFailure{
|
||||
enabled: false,
|
||||
}
|
||||
if retryOnFailureJson.Exists() {
|
||||
c.retryOnFailure.FromJson(retryOnFailureJson)
|
||||
}
|
||||
|
||||
for _, accessKeyAndSecret := range json.Get("baiduAccessKeyAndSecret").Array() {
|
||||
c.baiduAccessKeyAndSecret = append(c.baiduAccessKeyAndSecret, accessKeyAndSecret.String())
|
||||
}
|
||||
@@ -393,10 +412,10 @@ func (c *ProviderConfig) Validate() error {
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetOrSetTokenWithContext(ctx wrapper.HttpContext) string {
|
||||
ctxApiKey := ctx.GetContext(ctxKeyApiName)
|
||||
ctxApiKey := ctx.GetContext(ctxKeyApiKey)
|
||||
if ctxApiKey == nil {
|
||||
ctxApiKey = c.GetRandomToken()
|
||||
ctx.SetContext(ctxKeyApiName, ctxApiKey)
|
||||
ctx.SetContext(ctxKeyApiKey, ctxApiKey)
|
||||
}
|
||||
return ctxApiKey.(string)
|
||||
}
|
||||
@@ -440,6 +459,9 @@ func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, reques
|
||||
streaming := req.Stream
|
||||
if streaming {
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
|
||||
ctx.SetContext(ctxKeyIsStreaming, true)
|
||||
} else {
|
||||
ctx.SetContext(ctxKeyIsStreaming, false)
|
||||
}
|
||||
|
||||
return c.setRequestModel(ctx, req, log)
|
||||
@@ -534,9 +556,9 @@ func (c *ProviderConfig) handleRequestBody(
|
||||
if handler, ok := provider.(TransformRequestBodyHandler); ok {
|
||||
body, err = handler.TransformRequestBody(ctx, apiName, body, log)
|
||||
} else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {
|
||||
headers := util.GetOriginalHttpHeaders()
|
||||
headers := util.GetOriginalRequestHeaders()
|
||||
body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers, log)
|
||||
util.ReplaceOriginalHttpHeaders(headers)
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
} else {
|
||||
body, err = c.defaultTransformRequestBody(ctx, apiName, body, log)
|
||||
}
|
||||
@@ -545,9 +567,14 @@ func (c *ProviderConfig) handleRequestBody(
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
// If retryOnFailure is enabled, save the transformed body to the context in case of retry
|
||||
if c.isRetryOnFailureEnabled() {
|
||||
ctx.SetContext(ctxRequestBody, body)
|
||||
}
|
||||
|
||||
if apiName == ApiNameChatCompletion {
|
||||
if c.context == nil {
|
||||
return types.ActionContinue, replaceHttpJsonRequestBody(body, log)
|
||||
return types.ActionContinue, replaceRequestBody(body, log)
|
||||
}
|
||||
err = contextCache.GetContextFromFile(ctx, provider, body, log)
|
||||
|
||||
@@ -556,14 +583,14 @@ func (c *ProviderConfig) handleRequestBody(
|
||||
}
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
return types.ActionContinue, replaceHttpJsonRequestBody(body, log)
|
||||
return types.ActionContinue, replaceRequestBody(body, log)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) {
|
||||
headers := util.GetOriginalRequestHeaders()
|
||||
if handler, ok := provider.(TransformRequestHeadersHandler); ok {
|
||||
originalHeaders := util.GetOriginalHttpHeaders()
|
||||
handler.TransformRequestHeaders(ctx, apiName, originalHeaders, log)
|
||||
util.ReplaceOriginalHttpHeaders(originalHeaders)
|
||||
handler.TransformRequestHeaders(ctx, apiName, headers, log)
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -579,3 +606,11 @@ func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, ap
|
||||
}
|
||||
return json.Marshal(request)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) {
|
||||
if c.protocol == protocolOriginal {
|
||||
ctx.DontReadResponseBody()
|
||||
} else {
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ const (
|
||||
qwenChatCompletionPath = "/api/v1/services/aigc/text-generation/generation"
|
||||
qwenTextEmbeddingPath = "/api/v1/services/embeddings/text-embedding/text-embedding"
|
||||
qwenCompatiblePath = "/compatible-mode/v1/chat/completions"
|
||||
qwenBailianPath = "/api/v1/apps"
|
||||
qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation"
|
||||
|
||||
qwenTopPMin = 0.000001
|
||||
@@ -71,16 +72,14 @@ func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
|
||||
}
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
|
||||
if m.config.qwenEnableCompatible {
|
||||
if m.config.IsOriginal() {
|
||||
} else if m.config.qwenEnableCompatible {
|
||||
util.OverwriteRequestPathHeader(headers, qwenCompatiblePath)
|
||||
} else if apiName == ApiNameChatCompletion {
|
||||
util.OverwriteRequestPathHeader(headers, qwenChatCompletionPath)
|
||||
} else if apiName == ApiNameEmbeddings {
|
||||
util.OverwriteRequestPathHeader(headers, qwenTextEmbeddingPath)
|
||||
}
|
||||
|
||||
headers.Del("Accept-Encoding")
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
@@ -95,20 +94,19 @@ func (m *qwenProvider) GetProviderType() string {
|
||||
return providerTypeQwen
|
||||
}
|
||||
|
||||
func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
|
||||
if m.config.protocol == protocolOriginal {
|
||||
ctx.DontReadRequestBody()
|
||||
return types.ActionContinue, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return types.HeaderStopIteration, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
@@ -185,16 +183,6 @@ func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []b
|
||||
return json.Marshal(qwenRequest)
|
||||
}
|
||||
|
||||
func (m *qwenProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
if m.config.protocol == protocolOriginal {
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||
if m.config.qwenEnableCompatible || name != ApiNameChatCompletion {
|
||||
return chunk, nil
|
||||
@@ -280,9 +268,9 @@ func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api
|
||||
return []byte(modifiedResponseChunk), nil
|
||||
}
|
||||
|
||||
func (m *qwenProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
if m.config.qwenEnableCompatible {
|
||||
return types.ActionContinue, nil
|
||||
return body, nil
|
||||
}
|
||||
if apiName == ApiNameChatCompletion {
|
||||
return m.onChatCompletionResponseBody(ctx, body, log)
|
||||
@@ -290,25 +278,25 @@ func (m *qwenProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
if apiName == ApiNameEmbeddings {
|
||||
return m.onEmbeddingsResponseBody(ctx, body, log)
|
||||
}
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return nil, errUnsupportedApiName
|
||||
}
|
||||
|
||||
func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
qwenResponse := &qwenTextGenResponse{}
|
||||
if err := json.Unmarshal(body, qwenResponse); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal Qwen response: %v", err)
|
||||
return nil, fmt.Errorf("unable to unmarshal Qwen response: %v", err)
|
||||
}
|
||||
response := m.buildChatCompletionResponse(ctx, qwenResponse)
|
||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
func (m *qwenProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (m *qwenProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
qwenResponse := &qwenTextEmbeddingResponse{}
|
||||
if err := json.Unmarshal(body, qwenResponse); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal Qwen response: %v", err)
|
||||
return nil, fmt.Errorf("unable to unmarshal Qwen response: %v", err)
|
||||
}
|
||||
response := m.buildEmbeddingsResponse(ctx, qwenResponse)
|
||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
func (m *qwenProvider) buildQwenTextGenerationRequest(ctx wrapper.HttpContext, origRequest *chatCompletionRequest, streaming bool) ([]byte, error) {
|
||||
@@ -762,6 +750,7 @@ func (m *qwenProvider) GetApiName(path string) ApiName {
|
||||
switch {
|
||||
case strings.Contains(path, qwenChatCompletionPath),
|
||||
strings.Contains(path, qwenMultimodalGenerationPath),
|
||||
strings.Contains(path, qwenBailianPath),
|
||||
strings.Contains(path, qwenCompatiblePath):
|
||||
return ApiNameChatCompletion
|
||||
case strings.Contains(path, qwenTextEmbeddingPath):
|
||||
|
||||
@@ -37,7 +37,7 @@ func replaceJsonRequestBody(request interface{}, log wrapper.Log) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func replaceHttpJsonRequestBody(body []byte, log wrapper.Log) error {
|
||||
func replaceRequestBody(body []byte, log wrapper.Log) error {
|
||||
log.Debugf("request body: %s", string(body))
|
||||
err := proxywasm.ReplaceHttpRequestBody(body)
|
||||
if err != nil {
|
||||
@@ -65,15 +65,11 @@ func insertContextMessage(request *chatCompletionRequest, content string) {
|
||||
}
|
||||
}
|
||||
|
||||
func replaceJsonResponseBody(response interface{}, log wrapper.Log) error {
|
||||
body, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to marshal response: %v", err)
|
||||
}
|
||||
func ReplaceResponseBody(body []byte, log wrapper.Log) error {
|
||||
log.Debugf("response body: %s", string(body))
|
||||
err = proxywasm.ReplaceHttpResponseBody(body)
|
||||
err := proxywasm.ReplaceHttpResponseBody(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to replace the original response body: %v", err)
|
||||
}
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
141
plugins/wasm-go/extensions/ai-proxy/provider/retry.go
Normal file
141
plugins/wasm-go/extensions/ai-proxy/provider/retry.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/tidwall/gjson"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
ctxRequestBody = "requestBody"
|
||||
ctxRetryCount = "retryCount"
|
||||
)
|
||||
|
||||
type retryOnFailure struct {
|
||||
// @Title zh-CN 是否启用请求重试
|
||||
enabled bool `required:"false" yaml:"enabled" json:"enabled"`
|
||||
// @Title zh-CN 重试次数
|
||||
maxRetries int64 `required:"false" yaml:"maxRetries" json:"maxRetries"`
|
||||
// @Title zh-CN 重试超时时间
|
||||
retryTimeout int64 `required:"false" yaml:"retryTimeout" json:"retryTimeout"`
|
||||
}
|
||||
|
||||
func (r *retryOnFailure) FromJson(json gjson.Result) {
|
||||
r.enabled = json.Get("enabled").Bool()
|
||||
r.maxRetries = json.Get("maxRetries").Int()
|
||||
if r.maxRetries == 0 {
|
||||
r.maxRetries = 1
|
||||
}
|
||||
r.retryTimeout = json.Get("retryTimeout").Int()
|
||||
if r.retryTimeout == 0 {
|
||||
r.retryTimeout = 30 * 1000
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) isRetryOnFailureEnabled() bool {
|
||||
return c.retryOnFailure.enabled
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) retryFailedRequest(activeProvider Provider, ctx wrapper.HttpContext, log wrapper.Log) {
|
||||
log.Debugf("Retry failed request: provider=%s", activeProvider.GetProviderType())
|
||||
retryClient := createRetryClient(ctx)
|
||||
apiName, _ := ctx.GetContext(CtxKeyApiName).(ApiName)
|
||||
ctx.SetContext(ctxRetryCount, 1)
|
||||
c.sendRetryRequest(ctx, apiName, activeProvider, retryClient, log)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) transformResponseHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, apiName ApiName, headers http.Header, body []byte, log wrapper.Log) ([][2]string, []byte) {
|
||||
if handler, ok := activeProvider.(TransformResponseHeadersHandler); ok {
|
||||
handler.TransformResponseHeaders(ctx, apiName, headers, log)
|
||||
} else {
|
||||
c.DefaultTransformResponseHeaders(ctx, headers)
|
||||
}
|
||||
|
||||
if handler, ok := activeProvider.(TransformResponseBodyHandler); ok {
|
||||
var err error
|
||||
body, err = handler.TransformResponseBody(ctx, apiName, body, log)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to transform response body: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return util.HeaderToSlice(headers), body
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) retryCall(
|
||||
ctx wrapper.HttpContext, log wrapper.Log, activeProvider Provider,
|
||||
apiName ApiName, statusCode int, responseHeaders http.Header, responseBody []byte,
|
||||
retryClient *wrapper.ClusterClient[wrapper.RouteCluster]) {
|
||||
|
||||
retryCount := ctx.GetContext(ctxRetryCount).(int)
|
||||
log.Debugf("Sent retry request: %d/%d", retryCount, c.retryOnFailure.maxRetries)
|
||||
|
||||
if statusCode == 200 {
|
||||
log.Debugf("Retry request succeeded")
|
||||
headers, body := c.transformResponseHeadersAndBody(ctx, activeProvider, apiName, responseHeaders, responseBody, log)
|
||||
proxywasm.SendHttpResponse(200, headers, body, -1)
|
||||
} else {
|
||||
log.Debugf("The retry request still failed, status: %d, responseHeaders: %v, responseBody: %s", statusCode, responseHeaders, string(responseBody))
|
||||
}
|
||||
|
||||
retryCount++
|
||||
if retryCount <= int(c.retryOnFailure.maxRetries) {
|
||||
ctx.SetContext(ctxRetryCount, retryCount)
|
||||
c.sendRetryRequest(ctx, apiName, activeProvider, retryClient, log)
|
||||
} else {
|
||||
log.Debugf("Reached the maximum retry count: %d", c.retryOnFailure.maxRetries)
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) sendRetryRequest(
|
||||
ctx wrapper.HttpContext, apiName ApiName, activeProvider Provider,
|
||||
retryClient *wrapper.ClusterClient[wrapper.RouteCluster], log wrapper.Log) {
|
||||
|
||||
requestHeaders, requestBody := c.getRetryRequestHeadersAndBody(ctx, activeProvider, apiName, log)
|
||||
path := getRetryPath(ctx)
|
||||
|
||||
err := retryClient.Post(path, util.HeaderToSlice(requestHeaders), requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
c.retryCall(ctx, log, activeProvider, apiName, statusCode, responseHeaders, responseBody, retryClient)
|
||||
}, uint32(c.retryOnFailure.retryTimeout))
|
||||
if err != nil {
|
||||
log.Errorf("Failed to send retry request: %v", err)
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
}
|
||||
|
||||
func createRetryClient(ctx wrapper.HttpContext) *wrapper.ClusterClient[wrapper.RouteCluster] {
|
||||
host := wrapper.GetRequestHost()
|
||||
if host == "" {
|
||||
host = ctx.GetContext(ctxRequestHost).(string)
|
||||
}
|
||||
retryClient := wrapper.NewClusterClient(wrapper.RouteCluster{
|
||||
Host: host,
|
||||
})
|
||||
return retryClient
|
||||
}
|
||||
|
||||
func getRetryPath(ctx wrapper.HttpContext) string {
|
||||
path := wrapper.GetRequestPath()
|
||||
if path == "" {
|
||||
path = ctx.GetContext(ctxRequestPath).(string)
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) getRetryRequestHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, apiName ApiName, log wrapper.Log) (http.Header, []byte) {
|
||||
// The retry request may be sent with different apiToken, so the header needs to be regenerated
|
||||
c.SetApiTokenInUse(ctx, log)
|
||||
|
||||
requestHeaders := http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
}
|
||||
if handler, ok := activeProvider.(TransformRequestHeadersHandler); ok {
|
||||
handler.TransformRequestHeaders(ctx, apiName, requestHeaders, log)
|
||||
}
|
||||
requestBody := ctx.GetContext(ctxRequestBody).([]byte)
|
||||
|
||||
return requestHeaders, requestBody
|
||||
}
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
@@ -67,12 +66,12 @@ func (p *sparkProvider) GetProviderType() string {
|
||||
return providerTypeSpark
|
||||
}
|
||||
|
||||
func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
p.config.handleRequestHeaders(p, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
@@ -82,21 +81,16 @@ func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (p *sparkProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
func (p *sparkProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (p *sparkProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
sparkResponse := &sparkResponse{}
|
||||
if err := json.Unmarshal(body, sparkResponse); err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("unable to unmarshal spark response: %v", err)
|
||||
return nil, fmt.Errorf("unable to unmarshal spark response: %v", err)
|
||||
}
|
||||
if sparkResponse.Code != 0 {
|
||||
return types.ActionContinue, fmt.Errorf("spark response error, error_code: %d, error_message: %s", sparkResponse.Code, sparkResponse.Message)
|
||||
return nil, fmt.Errorf("spark response error, error_code: %d, error_message: %s", sparkResponse.Code, sparkResponse.Message)
|
||||
}
|
||||
response := p.responseSpark2OpenAI(ctx, sparkResponse)
|
||||
return types.ActionContinue, replaceJsonResponseBody(response, log)
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
func (p *sparkProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||
@@ -177,6 +171,4 @@ func (p *sparkProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
|
||||
util.OverwriteRequestPathHeader(headers, sparkChatCompletionPath)
|
||||
util.OverwriteRequestHostHeader(headers, sparkHost)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+p.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Accept-Encoding")
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
@@ -2,10 +2,11 @@ package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -39,12 +40,12 @@ func (m *stepfunProvider) GetProviderType() string {
|
||||
return providerTypeStepfun
|
||||
}
|
||||
|
||||
func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
|
||||
69
plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go
Normal file
69
plugins/wasm-go/extensions/ai-proxy/provider/together_ai.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
togetherAIDomain = "api.together.xyz"
|
||||
togetherAICompletionPath = "/v1/chat/completions"
|
||||
)
|
||||
|
||||
type togetherAIProviderInitializer struct{}
|
||||
|
||||
func (m *togetherAIProviderInitializer) ValidateConfig(config ProviderConfig) error {
|
||||
if config.apiTokens == nil || len(config.apiTokens) == 0 {
|
||||
return errors.New("no apiToken found in provider config")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *togetherAIProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
return &togetherAIProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type togetherAIProvider struct {
|
||||
config ProviderConfig
|
||||
contextCache *contextCache
|
||||
}
|
||||
|
||||
func (m *togetherAIProvider) GetProviderType() string {
|
||||
return providerTypeTogetherAI
|
||||
}
|
||||
|
||||
func (m *togetherAIProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *togetherAIProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
func (m *togetherAIProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
util.OverwriteRequestPathHeader(headers, togetherAICompletionPath)
|
||||
util.OverwriteRequestHostHeader(headers, togetherAIDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
func (m *togetherAIProvider) GetApiName(path string) ApiName {
|
||||
if strings.Contains(path, togetherAICompletionPath) {
|
||||
return ApiNameChatCompletion
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -40,12 +40,12 @@ func (m *yiProvider) GetProviderType() string {
|
||||
return providerTypeYi
|
||||
}
|
||||
|
||||
func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
|
||||
@@ -40,12 +40,12 @@ func (m *zhipuAiProvider) GetProviderType() string {
|
||||
return providerTypeZhipuAi
|
||||
}
|
||||
|
||||
func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
|
||||
func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return errUnsupportedApiName
|
||||
}
|
||||
m.config.handleRequestHeaders(m, ctx, apiName, log)
|
||||
return types.ActionContinue, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
|
||||
@@ -86,12 +86,22 @@ func SliceToHeader(slice [][2]string) http.Header {
|
||||
return header
|
||||
}
|
||||
|
||||
func GetOriginalHttpHeaders() http.Header {
|
||||
func GetOriginalRequestHeaders() http.Header {
|
||||
originalHeaders, _ := proxywasm.GetHttpRequestHeaders()
|
||||
return SliceToHeader(originalHeaders)
|
||||
}
|
||||
|
||||
func ReplaceOriginalHttpHeaders(headers http.Header) {
|
||||
func GetOriginalResponseHeaders() http.Header {
|
||||
originalHeaders, _ := proxywasm.GetHttpResponseHeaders()
|
||||
return SliceToHeader(originalHeaders)
|
||||
}
|
||||
|
||||
func ReplaceRequestHeaders(headers http.Header) {
|
||||
modifiedHeaders := HeaderToSlice(headers)
|
||||
_ = proxywasm.ReplaceHttpRequestHeaders(modifiedHeaders)
|
||||
}
|
||||
|
||||
func ReplaceResponseHeaders(headers http.Header) {
|
||||
modifiedHeaders := HeaderToSlice(headers)
|
||||
_ = proxywasm.ReplaceHttpResponseHeaders(modifiedHeaders)
|
||||
}
|
||||
|
||||
@@ -6,9 +6,9 @@ description: AI 配额管理插件配置参考
|
||||
|
||||
## 功能说明
|
||||
|
||||
`ai-qutoa` 插件实现给特定 consumer 根据分配固定的 quota 进行 quota 策略限流,同时支持 quota 管理能力,包括查询 quota 、刷新 quota、增减 quota。
|
||||
`ai-quota` 插件实现给特定 consumer 根据分配固定的 quota 进行 quota 策略限流,同时支持 quota 管理能力,包括查询 quota 、刷新 quota、增减 quota。
|
||||
|
||||
`ai-quota` 插件需要配合 认证插件比如 `key-auth`、`jwt-auth` 等插件获取认证身份的 consumer 名称,同时需要配合 `ai-statatistics` 插件获取 AI Token 统计信息。
|
||||
`ai-quota` 插件需要配合 认证插件比如 `key-auth`、`jwt-auth` 等插件获取认证身份的 consumer 名称,同时需要配合 `ai-statistics` 插件获取 AI Token 统计信息。
|
||||
|
||||
## 运行属性
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ description: 阿里云内容安全检测
|
||||
| `denyMessage` | string | optional | openai格式的流式/非流式响应 | 指定内容非法时的响应内容 |
|
||||
| `protocol` | string | optional | openai | 协议格式,非openai协议填`original` |
|
||||
| `riskLevelBar` | string | optional | high | 拦截风险等级,取值为 max, high, medium, low |
|
||||
| `timeout` | int | optional | 2000 | 调用内容安全服务时的超时时间 |
|
||||
|
||||
补充说明一下 `denyMessage`,对非法请求的处理逻辑为:
|
||||
- 如果配置了 `denyMessage`,返回内容为 `denyMessage` 配置内容,格式为openai格式的流式/非流式响应
|
||||
|
||||
@@ -53,6 +53,7 @@ const (
|
||||
DefaultStreamingResponseJsonPath = "choices.0.delta.content"
|
||||
DefaultDenyCode = 200
|
||||
DefaultDenyMessage = "很抱歉,我无法回答您的问题"
|
||||
DefaultTimeout = 2000
|
||||
|
||||
AliyunUserAgent = "CIPFrom/AIGateway"
|
||||
LengthLimit = 1800
|
||||
@@ -100,6 +101,7 @@ type AISecurityConfig struct {
|
||||
denyMessage string
|
||||
protocolOriginal bool
|
||||
riskLevelBar string
|
||||
timeout uint32
|
||||
metrics map[string]proxywasm.MetricCounter
|
||||
}
|
||||
|
||||
@@ -225,6 +227,11 @@ func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) e
|
||||
} else {
|
||||
config.riskLevelBar = HighRisk
|
||||
}
|
||||
if obj := json.Get("timeout"); obj.Exists() {
|
||||
config.timeout = uint32(obj.Int())
|
||||
} else {
|
||||
config.timeout = DefaultTimeout
|
||||
}
|
||||
config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||
FQDN: serviceName,
|
||||
Port: servicePort,
|
||||
@@ -253,6 +260,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
|
||||
log.Debugf("checking request body...")
|
||||
startTime := time.Now().UnixMilli()
|
||||
content := gjson.GetBytes(body, config.requestContentJsonPath).String()
|
||||
model := gjson.GetBytes(body, "model").String()
|
||||
ctx.SetContext("requestModel", model)
|
||||
@@ -279,6 +287,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
|
||||
}
|
||||
if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) {
|
||||
if contentIndex >= len(content) {
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "request pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
} else {
|
||||
singleCall()
|
||||
@@ -305,7 +317,14 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
|
||||
}
|
||||
ctx.DontReadResponseBody()
|
||||
config.incrementCounter("ai_sec_request_deny", 1)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
|
||||
if response.Data.Advice != nil {
|
||||
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||
}
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
singleCall = func() {
|
||||
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
||||
@@ -340,7 +359,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
|
||||
reqParams.Add(k, v)
|
||||
}
|
||||
reqParams.Add("Signature", signature)
|
||||
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback)
|
||||
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback, config.timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
@@ -359,40 +378,26 @@ func convertHeaders(hs [][2]string) map[string][]string {
|
||||
return ret
|
||||
}
|
||||
|
||||
// headers: map[string][]string -> [][2]string
|
||||
func reconvertHeaders(hs map[string][]string) [][2]string {
|
||||
var ret [][2]string
|
||||
for k, vs := range hs {
|
||||
for _, v := range vs {
|
||||
ret = append(ret, [2]string{k, v})
|
||||
}
|
||||
}
|
||||
sort.SliceStable(ret, func(i, j int) bool {
|
||||
return ret[i][0] < ret[j][0]
|
||||
})
|
||||
return ret
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
|
||||
if !config.checkResponse {
|
||||
log.Debugf("response checking is disabled")
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
headers, err := proxywasm.GetHttpResponseHeaders()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get response headers: %v", err)
|
||||
statusCode, _ := proxywasm.GetHttpResponseHeader(":status")
|
||||
if statusCode != "200" {
|
||||
log.Debugf("response is not 200, skip response body check")
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
hdsMap := convertHeaders(headers)
|
||||
ctx.SetContext("headers", hdsMap)
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
|
||||
log.Debugf("checking response body...")
|
||||
hdsMap := ctx.GetContext("headers").(map[string][]string)
|
||||
isStreamingResponse := strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream")
|
||||
startTime := time.Now().UnixMilli()
|
||||
contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
|
||||
isStreamingResponse := strings.Contains(contentType, "event-stream")
|
||||
model := ctx.GetStringContext("requestModel", "unknown")
|
||||
var content string
|
||||
if isStreamingResponse {
|
||||
@@ -423,6 +428,10 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
||||
}
|
||||
if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) {
|
||||
if contentIndex >= len(content) {
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "response pass")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
proxywasm.ResumeHttpResponse()
|
||||
} else {
|
||||
singleCall()
|
||||
@@ -436,22 +445,26 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
||||
denyMessage = response.Data.Advice[0].Answer
|
||||
}
|
||||
marshalledDenyMessage := marshalStr(denyMessage, log)
|
||||
var jsonData []byte
|
||||
if config.protocolOriginal {
|
||||
jsonData = []byte(marshalledDenyMessage)
|
||||
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
|
||||
} else if isStreamingResponse {
|
||||
randomID := generateRandomID()
|
||||
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model))
|
||||
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model))
|
||||
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
|
||||
} else {
|
||||
randomID := generateRandomID()
|
||||
jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage))
|
||||
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage))
|
||||
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
|
||||
}
|
||||
delete(hdsMap, "content-length")
|
||||
hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)}
|
||||
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
|
||||
proxywasm.ReplaceHttpResponseBody(jsonData)
|
||||
config.incrementCounter("ai_sec_response_deny", 1)
|
||||
proxywasm.ResumeHttpResponse()
|
||||
endTime := time.Now().UnixMilli()
|
||||
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
|
||||
ctx.SetUserAttribute("safecheck_status", "response deny")
|
||||
if response.Data.Advice != nil {
|
||||
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
|
||||
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
|
||||
}
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
}
|
||||
singleCall = func() {
|
||||
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
||||
@@ -486,7 +499,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
|
||||
reqParams.Add(k, v)
|
||||
}
|
||||
reqParams.Add("Signature", signature)
|
||||
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback)
|
||||
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback, config.timeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed call the safe check service: %v", err)
|
||||
proxywasm.ResumeHttpResponse()
|
||||
|
||||
@@ -3,15 +3,13 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
|
||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
|
||||
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94=
|
||||
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -28,14 +27,15 @@ func main() {
|
||||
}
|
||||
|
||||
const (
|
||||
// Trace span prefix
|
||||
TracePrefix = "trace_span_tag."
|
||||
// Context consts
|
||||
StatisticsRequestStartTime = "ai-statistics-request-start-time"
|
||||
StatisticsFirstTokenTime = "ai-statistics-first-token-time"
|
||||
CtxGeneralAtrribute = "attributes"
|
||||
CtxLogAtrribute = "logAttributes"
|
||||
CtxStreamingBodyBuffer = "streamingBodyBuffer"
|
||||
RouteName = "route"
|
||||
ClusterName = "cluster"
|
||||
APIName = "api"
|
||||
|
||||
// Source Type
|
||||
FixedValue = "fixed_value"
|
||||
@@ -46,12 +46,14 @@ const (
|
||||
ResponseBody = "response_body"
|
||||
|
||||
// Inner metric & log attributes name
|
||||
Model = "model"
|
||||
InputToken = "input_token"
|
||||
OutputToken = "output_token"
|
||||
LLMFirstTokenDuration = "llm_first_token_duration"
|
||||
LLMServiceDuration = "llm_service_duration"
|
||||
LLMDurationCount = "llm_duration_count"
|
||||
Model = "model"
|
||||
InputToken = "input_token"
|
||||
OutputToken = "output_token"
|
||||
LLMFirstTokenDuration = "llm_first_token_duration"
|
||||
LLMServiceDuration = "llm_service_duration"
|
||||
LLMDurationCount = "llm_duration_count"
|
||||
LLMStreamDurationCount = "llm_stream_duration_count"
|
||||
ResponseType = "response_type"
|
||||
|
||||
// Extract Rule
|
||||
RuleFirst = "first"
|
||||
@@ -91,6 +93,19 @@ func getRouteName() (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func getAPIName() (string, error) {
|
||||
if raw, err := proxywasm.GetProperty([]string{"route_name"}); err != nil {
|
||||
return "-", err
|
||||
} else {
|
||||
parts := strings.Split(string(raw), "@")
|
||||
if len(parts) != 5 {
|
||||
return "-", errors.New("not api type")
|
||||
} else {
|
||||
return strings.Join(parts[:3], "@"), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getClusterName() (string, error) {
|
||||
if raw, err := proxywasm.GetProperty([]string{"cluster_name"}); err != nil {
|
||||
return "-", err
|
||||
@@ -133,8 +148,15 @@ func parseConfig(configJson gjson.Result, config *AIStatisticsConfig, log wrappe
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) types.Action {
|
||||
ctx.SetContext(CtxGeneralAtrribute, map[string]string{})
|
||||
ctx.SetContext(CtxLogAtrribute, map[string]string{})
|
||||
route, _ := getRouteName()
|
||||
cluster, _ := getClusterName()
|
||||
api, api_error := getAPIName()
|
||||
if api_error == nil {
|
||||
route = api
|
||||
}
|
||||
ctx.SetContext(RouteName, route)
|
||||
ctx.SetContext(ClusterName, cluster)
|
||||
ctx.SetUserAttribute(APIName, api)
|
||||
ctx.SetContext(StatisticsRequestStartTime, time.Now().UnixMilli())
|
||||
|
||||
// Set user defined log & span attributes which type is fixed_value
|
||||
@@ -149,6 +171,9 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, lo
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log wrapper.Log) types.Action {
|
||||
// Set user defined log & span attributes.
|
||||
setAttributeBySource(ctx, config, RequestBody, body, log)
|
||||
|
||||
// Write log
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
@@ -177,6 +202,8 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
|
||||
ctx.SetContext(CtxStreamingBodyBuffer, streamingBodyBuffer)
|
||||
}
|
||||
|
||||
ctx.SetUserAttribute(ResponseType, "stream")
|
||||
|
||||
// Get requestStartTime from http context
|
||||
requestStartTime, ok := ctx.GetContext(StatisticsRequestStartTime).(int64)
|
||||
if !ok {
|
||||
@@ -188,28 +215,19 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
|
||||
if ctx.GetContext(StatisticsFirstTokenTime) == nil {
|
||||
firstTokenTime := time.Now().UnixMilli()
|
||||
ctx.SetContext(StatisticsFirstTokenTime, firstTokenTime)
|
||||
attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string)
|
||||
attributes[LLMFirstTokenDuration] = fmt.Sprint(firstTokenTime - requestStartTime)
|
||||
ctx.SetContext(CtxGeneralAtrribute, attributes)
|
||||
ctx.SetUserAttribute(LLMFirstTokenDuration, firstTokenTime-requestStartTime)
|
||||
}
|
||||
|
||||
// Set information about this request
|
||||
|
||||
if model, inputToken, outputToken, ok := getUsage(data); ok {
|
||||
attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string)
|
||||
// Record Log Attributes
|
||||
attributes[Model] = model
|
||||
attributes[InputToken] = fmt.Sprint(inputToken)
|
||||
attributes[OutputToken] = fmt.Sprint(outputToken)
|
||||
// Set attributes to http context
|
||||
ctx.SetContext(CtxGeneralAtrribute, attributes)
|
||||
ctx.SetUserAttribute(Model, model)
|
||||
ctx.SetUserAttribute(InputToken, inputToken)
|
||||
ctx.SetUserAttribute(OutputToken, outputToken)
|
||||
}
|
||||
// If the end of the stream is reached, record metrics/logs/spans.
|
||||
if endOfStream {
|
||||
responseEndTime := time.Now().UnixMilli()
|
||||
attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string)
|
||||
attributes[LLMServiceDuration] = fmt.Sprint(responseEndTime - requestStartTime)
|
||||
ctx.SetContext(CtxGeneralAtrribute, attributes)
|
||||
ctx.SetUserAttribute(LLMServiceDuration, responseEndTime-requestStartTime)
|
||||
|
||||
// Set user defined log & span attributes.
|
||||
if config.shouldBufferStreamingBody {
|
||||
@@ -220,11 +238,8 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
|
||||
setAttributeBySource(ctx, config, ResponseStreamingBody, streamingBodyBuffer, log)
|
||||
}
|
||||
|
||||
// Write inner filter states which can be used by other plugins such as ai-token-ratelimit
|
||||
writeFilterStates(ctx, log)
|
||||
|
||||
// Write log
|
||||
writeLog(ctx, log)
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
|
||||
// Write metrics
|
||||
writeMetric(ctx, config, log)
|
||||
@@ -233,33 +248,26 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log wrapper.Log) types.Action {
|
||||
// Get attributes from http context
|
||||
attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string)
|
||||
|
||||
// Get requestStartTime from http context
|
||||
requestStartTime, _ := ctx.GetContext(StatisticsRequestStartTime).(int64)
|
||||
|
||||
responseEndTime := time.Now().UnixMilli()
|
||||
attributes[LLMServiceDuration] = fmt.Sprint(responseEndTime - requestStartTime)
|
||||
ctx.SetUserAttribute(LLMServiceDuration, responseEndTime-requestStartTime)
|
||||
|
||||
ctx.SetUserAttribute(ResponseType, "normal")
|
||||
|
||||
// Set information about this request
|
||||
model, inputToken, outputToken, ok := getUsage(body)
|
||||
if ok {
|
||||
attributes[Model] = model
|
||||
attributes[InputToken] = fmt.Sprint(inputToken)
|
||||
attributes[OutputToken] = fmt.Sprint(outputToken)
|
||||
// Update attributes
|
||||
ctx.SetContext(CtxGeneralAtrribute, attributes)
|
||||
if model, inputToken, outputToken, ok := getUsage(body); ok {
|
||||
ctx.SetUserAttribute(Model, model)
|
||||
ctx.SetUserAttribute(InputToken, inputToken)
|
||||
ctx.SetUserAttribute(OutputToken, outputToken)
|
||||
}
|
||||
|
||||
// Set user defined log & span attributes.
|
||||
setAttributeBySource(ctx, config, ResponseBody, body, log)
|
||||
|
||||
// Write inner filter states which can be used by other plugins such as ai-token-ratelimit
|
||||
writeFilterStates(ctx, log)
|
||||
|
||||
// Write log
|
||||
writeLog(ctx, log)
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
|
||||
// Write metrics
|
||||
writeMetric(ctx, config, log)
|
||||
@@ -294,67 +302,49 @@ func getUsage(data []byte) (model string, inputTokenUsage int64, outputTokenUsag
|
||||
|
||||
// fetches the tracing span value from the specified source.
|
||||
func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, source string, body []byte, log wrapper.Log) {
|
||||
attributes, ok := ctx.GetContext(CtxGeneralAtrribute).(map[string]string)
|
||||
if !ok {
|
||||
log.Error("failed to get attributes from http context")
|
||||
return
|
||||
}
|
||||
for _, attribute := range config.attributes {
|
||||
var key string
|
||||
var value interface{}
|
||||
if source == attribute.ValueSource {
|
||||
key = attribute.Key
|
||||
switch source {
|
||||
case FixedValue:
|
||||
log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, attribute.Value)
|
||||
attributes[attribute.Key] = attribute.Value
|
||||
value = attribute.Value
|
||||
case RequestHeader:
|
||||
if value, err := proxywasm.GetHttpRequestHeader(attribute.Value); err == nil {
|
||||
log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
|
||||
attributes[attribute.Key] = value
|
||||
}
|
||||
value, _ = proxywasm.GetHttpRequestHeader(attribute.Value)
|
||||
case RequestBody:
|
||||
raw := gjson.GetBytes(body, attribute.Value).Raw
|
||||
var value string
|
||||
if len(raw) > 2 {
|
||||
value = raw[1 : len(raw)-1]
|
||||
}
|
||||
log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
|
||||
attributes[attribute.Key] = value
|
||||
value = gjson.GetBytes(body, attribute.Value).Value()
|
||||
case ResponseHeader:
|
||||
if value, err := proxywasm.GetHttpResponseHeader(attribute.Value); err == nil {
|
||||
log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
|
||||
attributes[attribute.Key] = value
|
||||
}
|
||||
value, _ = proxywasm.GetHttpResponseHeader(attribute.Value)
|
||||
case ResponseStreamingBody:
|
||||
value := extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule, log)
|
||||
log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
|
||||
attributes[attribute.Key] = value
|
||||
value = extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule, log)
|
||||
case ResponseBody:
|
||||
value := gjson.GetBytes(body, attribute.Value).Raw
|
||||
if len(value) > 2 && value[0] == '"' && value[len(value)-1] == '"' {
|
||||
value = value[1 : len(value)-1]
|
||||
}
|
||||
log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
|
||||
attributes[attribute.Key] = value
|
||||
value = gjson.GetBytes(body, attribute.Value).Value()
|
||||
default:
|
||||
}
|
||||
}
|
||||
if attribute.ApplyToLog {
|
||||
setLogAttribute(ctx, attribute.Key, attributes[attribute.Key], log)
|
||||
}
|
||||
if attribute.ApplyToSpan {
|
||||
setSpanAttribute(attribute.Key, attributes[attribute.Key], log)
|
||||
log.Debugf("[attribute] source type: %s, key: %s, value: %+v", source, key, value)
|
||||
if attribute.ApplyToLog {
|
||||
ctx.SetUserAttribute(key, value)
|
||||
}
|
||||
// for metrics
|
||||
if key == Model || key == InputToken || key == OutputToken {
|
||||
ctx.SetContext(key, value)
|
||||
}
|
||||
if attribute.ApplyToSpan {
|
||||
setSpanAttribute(key, value, log)
|
||||
}
|
||||
}
|
||||
}
|
||||
ctx.SetContext(CtxGeneralAtrribute, attributes)
|
||||
}
|
||||
|
||||
func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log wrapper.Log) string {
|
||||
func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log wrapper.Log) interface{} {
|
||||
chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
|
||||
var value string
|
||||
var value interface{}
|
||||
if rule == RuleFirst {
|
||||
for _, chunk := range chunks {
|
||||
jsonObj := gjson.GetBytes(chunk, jsonPath)
|
||||
if jsonObj.Exists() {
|
||||
value = jsonObj.String()
|
||||
value = jsonObj.Value()
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -362,140 +352,116 @@ func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, l
|
||||
for _, chunk := range chunks {
|
||||
jsonObj := gjson.GetBytes(chunk, jsonPath)
|
||||
if jsonObj.Exists() {
|
||||
value = jsonObj.String()
|
||||
value = jsonObj.Value()
|
||||
}
|
||||
}
|
||||
} else if rule == RuleAppend {
|
||||
// extract llm response
|
||||
var strValue string
|
||||
for _, chunk := range chunks {
|
||||
raw := gjson.GetBytes(chunk, jsonPath).Raw
|
||||
if len(raw) > 2 && raw[0] == '"' && raw[len(raw)-1] == '"' {
|
||||
value += raw[1 : len(raw)-1]
|
||||
jsonObj := gjson.GetBytes(chunk, jsonPath)
|
||||
if jsonObj.Exists() {
|
||||
strValue += jsonObj.String()
|
||||
}
|
||||
}
|
||||
value = strValue
|
||||
} else {
|
||||
log.Errorf("unsupported rule type: %s", rule)
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func setFilterState(key, value string, log wrapper.Log) {
|
||||
if value != "" {
|
||||
if e := proxywasm.SetProperty([]string{key}, []byte(fmt.Sprint(value))); e != nil {
|
||||
log.Errorf("failed to set %s in filter state: %v", key, e)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("failed to write filter state [%s], because it's value is empty")
|
||||
}
|
||||
}
|
||||
|
||||
// Set the tracing span with value.
|
||||
func setSpanAttribute(key, value string, log wrapper.Log) {
|
||||
func setSpanAttribute(key string, value interface{}, log wrapper.Log) {
|
||||
if value != "" {
|
||||
traceSpanTag := TracePrefix + key
|
||||
if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(value)); e != nil {
|
||||
log.Errorf("failed to set %s in filter state: %v", traceSpanTag, e)
|
||||
traceSpanTag := wrapper.TraceSpanTagPrefix + key
|
||||
if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(fmt.Sprint(value))); e != nil {
|
||||
log.Warnf("failed to set %s in filter state: %v", traceSpanTag, e)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("failed to write span attribute [%s], because it's value is empty")
|
||||
}
|
||||
}
|
||||
|
||||
// fetches the tracing span value from the specified source.
|
||||
func setLogAttribute(ctx wrapper.HttpContext, key string, value interface{}, log wrapper.Log) {
|
||||
logAttributes, ok := ctx.GetContext(CtxLogAtrribute).(map[string]string)
|
||||
if !ok {
|
||||
log.Error("failed to get logAttributes from http context")
|
||||
return
|
||||
}
|
||||
logAttributes[key] = fmt.Sprint(value)
|
||||
ctx.SetContext(CtxLogAtrribute, logAttributes)
|
||||
}
|
||||
|
||||
func writeFilterStates(ctx wrapper.HttpContext, log wrapper.Log) {
|
||||
attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string)
|
||||
setFilterState(Model, attributes[Model], log)
|
||||
setFilterState(InputToken, attributes[InputToken], log)
|
||||
setFilterState(OutputToken, attributes[OutputToken], log)
|
||||
}
|
||||
|
||||
func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) {
|
||||
attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string)
|
||||
route, _ := getRouteName()
|
||||
cluster, _ := getClusterName()
|
||||
model, ok := attributes["model"]
|
||||
// Generate usage metrics
|
||||
var ok bool
|
||||
var route, cluster, model string
|
||||
var inputToken, outputToken uint64
|
||||
route, ok = ctx.GetContext(RouteName).(string)
|
||||
if !ok {
|
||||
log.Errorf("Get model failed")
|
||||
log.Warnf("RouteName typd assert failed, skip metric record")
|
||||
return
|
||||
}
|
||||
if inputToken, ok := attributes[InputToken]; ok {
|
||||
inputTokenUint64, err := strconv.ParseUint(inputToken, 10, 0)
|
||||
if err != nil || inputTokenUint64 == 0 {
|
||||
log.Errorf("inputToken convert failed, value is %d, err msg is [%v]", inputTokenUint64, err)
|
||||
cluster, ok = ctx.GetContext(ClusterName).(string)
|
||||
if !ok {
|
||||
log.Warnf("ClusterName typd assert failed, skip metric record")
|
||||
return
|
||||
}
|
||||
if ctx.GetUserAttribute(Model) == nil || ctx.GetUserAttribute(InputToken) == nil || ctx.GetUserAttribute(OutputToken) == nil {
|
||||
log.Warnf("get usage information failed, skip metric record")
|
||||
return
|
||||
}
|
||||
model, ok = ctx.GetUserAttribute(Model).(string)
|
||||
if !ok {
|
||||
log.Warnf("Model typd assert failed, skip metric record")
|
||||
return
|
||||
}
|
||||
inputToken, ok = convertToUInt(ctx.GetUserAttribute(InputToken))
|
||||
if !ok {
|
||||
log.Warnf("InputToken typd assert failed, skip metric record")
|
||||
return
|
||||
}
|
||||
outputToken, ok = convertToUInt(ctx.GetUserAttribute(OutputToken))
|
||||
if !ok {
|
||||
log.Warnf("OutputToken typd assert failed, skip metric record")
|
||||
return
|
||||
}
|
||||
if inputToken == 0 || outputToken == 0 {
|
||||
log.Warnf("inputToken and outputToken cannot equal to 0, skip metric record")
|
||||
return
|
||||
}
|
||||
config.incrementCounter(generateMetricName(route, cluster, model, InputToken), inputToken)
|
||||
config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), outputToken)
|
||||
|
||||
// Generate duration metrics
|
||||
var llmFirstTokenDuration, llmServiceDuration uint64
|
||||
// Is stream response
|
||||
if ctx.GetUserAttribute(LLMFirstTokenDuration) != nil {
|
||||
llmFirstTokenDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMFirstTokenDuration))
|
||||
if !ok {
|
||||
log.Warnf("LLMFirstTokenDuration typd assert failed")
|
||||
return
|
||||
}
|
||||
config.incrementCounter(generateMetricName(route, cluster, model, InputToken), inputTokenUint64)
|
||||
config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), llmFirstTokenDuration)
|
||||
config.incrementCounter(generateMetricName(route, cluster, model, LLMStreamDurationCount), 1)
|
||||
}
|
||||
if outputToken, ok := attributes[OutputToken]; ok {
|
||||
outputTokenUint64, err := strconv.ParseUint(outputToken, 10, 0)
|
||||
if err != nil || outputTokenUint64 == 0 {
|
||||
log.Errorf("outputToken convert failed, value is %d, err msg is [%v]", outputTokenUint64, err)
|
||||
if ctx.GetUserAttribute(LLMServiceDuration) != nil {
|
||||
llmServiceDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMServiceDuration))
|
||||
if !ok {
|
||||
log.Warnf("LLMServiceDuration typd assert failed")
|
||||
return
|
||||
}
|
||||
config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), outputTokenUint64)
|
||||
config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), llmServiceDuration)
|
||||
config.incrementCounter(generateMetricName(route, cluster, model, LLMDurationCount), 1)
|
||||
}
|
||||
if llmFirstTokenDuration, ok := attributes[LLMFirstTokenDuration]; ok {
|
||||
llmFirstTokenDurationUint64, err := strconv.ParseUint(llmFirstTokenDuration, 10, 0)
|
||||
if err != nil || llmFirstTokenDurationUint64 == 0 {
|
||||
log.Errorf("llmFirstTokenDuration convert failed, value is %d, err msg is [%v]", llmFirstTokenDurationUint64, err)
|
||||
return
|
||||
}
|
||||
config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), llmFirstTokenDurationUint64)
|
||||
}
|
||||
if llmServiceDuration, ok := attributes[LLMServiceDuration]; ok {
|
||||
llmServiceDurationUint64, err := strconv.ParseUint(llmServiceDuration, 10, 0)
|
||||
if err != nil || llmServiceDurationUint64 == 0 {
|
||||
log.Errorf("llmServiceDuration convert failed, value is %d, err msg is [%v]", llmServiceDurationUint64, err)
|
||||
return
|
||||
}
|
||||
config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), llmServiceDurationUint64)
|
||||
}
|
||||
config.incrementCounter(generateMetricName(route, cluster, model, LLMDurationCount), 1)
|
||||
}
|
||||
|
||||
func writeLog(ctx wrapper.HttpContext, log wrapper.Log) {
|
||||
attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string)
|
||||
logAttributes, _ := ctx.GetContext(CtxLogAtrribute).(map[string]string)
|
||||
// Set inner log fields
|
||||
if attributes[Model] != "" {
|
||||
logAttributes[Model] = attributes[Model]
|
||||
}
|
||||
if attributes[InputToken] != "" {
|
||||
logAttributes[InputToken] = attributes[InputToken]
|
||||
}
|
||||
if attributes[OutputToken] != "" {
|
||||
logAttributes[OutputToken] = attributes[OutputToken]
|
||||
}
|
||||
if attributes[LLMFirstTokenDuration] != "" {
|
||||
logAttributes[LLMFirstTokenDuration] = attributes[LLMFirstTokenDuration]
|
||||
}
|
||||
if attributes[LLMServiceDuration] != "" {
|
||||
logAttributes[LLMServiceDuration] = attributes[LLMServiceDuration]
|
||||
}
|
||||
// Traverse log fields
|
||||
items := []string{}
|
||||
for k, v := range logAttributes {
|
||||
items = append(items, fmt.Sprintf(`"%s":"%s"`, k, v))
|
||||
}
|
||||
aiLogField := fmt.Sprintf(`{%s}`, strings.Join(items, ","))
|
||||
// log.Infof("ai request json log: %s", aiLogField)
|
||||
jsonMap := map[string]string{
|
||||
"ai_log": aiLogField,
|
||||
}
|
||||
serialized, _ := json.Marshal(jsonMap)
|
||||
jsonLogRaw := gjson.GetBytes(serialized, "ai_log").Raw
|
||||
jsonLog := jsonLogRaw[1 : len(jsonLogRaw)-1]
|
||||
if err := proxywasm.SetProperty([]string{"ai_log"}, []byte(jsonLog)); err != nil {
|
||||
log.Errorf("failed to set ai_log in filter state: %v", err)
|
||||
func convertToUInt(val interface{}) (uint64, bool) {
|
||||
switch v := val.(type) {
|
||||
case float32:
|
||||
return uint64(v), true
|
||||
case float64:
|
||||
return uint64(v), true
|
||||
case int32:
|
||||
return uint64(v), true
|
||||
case int64:
|
||||
return uint64(v), true
|
||||
case uint32:
|
||||
return uint64(v), true
|
||||
case uint64:
|
||||
return v, true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,8 +5,7 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
|
||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
@@ -14,8 +13,7 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/tetratelabs/wazero v1.7.1 h1:QtSfd6KLc41DIMpDYlJdoMc6k7QTN246DM2+n2Y/Dx8=
|
||||
github.com/tetratelabs/wazero v1.7.1/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
|
||||
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
|
||||
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94=
|
||||
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
@@ -61,9 +62,9 @@ const (
|
||||
ConsumerHeader string = "x-mse-consumer" // LimitByConsumer从该request header获取consumer的名字
|
||||
CookieHeader string = "cookie"
|
||||
|
||||
RateLimitLimitHeader string = "X-RateLimit-Limit" // 限制的总请求数
|
||||
RateLimitRemainingHeader string = "X-RateLimit-Remaining" // 剩余还可以发送的请求数
|
||||
RateLimitResetHeader string = "X-RateLimit-Reset" // 限流重置时间(触发限流时返回)
|
||||
RateLimitLimitHeader string = "X-TokenRateLimit-Limit" // 限制的总请求数
|
||||
RateLimitRemainingHeader string = "X-TokenRateLimit-Remaining" // 剩余还可以发送的请求数
|
||||
RateLimitResetHeader string = "X-TokenRateLimit-Reset" // 限流重置时间(触发限流时返回)
|
||||
)
|
||||
|
||||
type LimitContext struct {
|
||||
@@ -124,6 +125,8 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitCon
|
||||
}
|
||||
if context.remaining < 0 {
|
||||
// 触发限流
|
||||
ctx.SetUserAttribute("token_ratelimit_status", "limited")
|
||||
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
|
||||
rejected(config, context)
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
@@ -137,39 +140,49 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitCon
|
||||
}
|
||||
|
||||
func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, data []byte, endOfStream bool, log wrapper.Log) []byte {
|
||||
if !endOfStream {
|
||||
return data
|
||||
var inputToken, outputToken int64
|
||||
if inputToken, outputToken, ok := getUsage(data); ok {
|
||||
ctx.SetContext("input_token", inputToken)
|
||||
ctx.SetContext("output_token", outputToken)
|
||||
}
|
||||
inputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.input_token"})
|
||||
if err != nil {
|
||||
return data
|
||||
if endOfStream {
|
||||
if ctx.GetContext("input_token") == nil || ctx.GetContext("output_token") == nil {
|
||||
return data
|
||||
}
|
||||
inputToken = ctx.GetContext("input_token").(int64)
|
||||
outputToken = ctx.GetContext("output_token").(int64)
|
||||
limitRedisContext, ok := ctx.GetContext(LimitRedisContextKey).(LimitRedisContext)
|
||||
if !ok {
|
||||
return data
|
||||
}
|
||||
keys := []interface{}{limitRedisContext.key}
|
||||
args := []interface{}{limitRedisContext.count, limitRedisContext.window, inputToken + outputToken}
|
||||
err := config.redisClient.Eval(ResponsePhaseFixedWindowScript, 1, keys, args, nil)
|
||||
if err != nil {
|
||||
log.Errorf("redis call failed: %v", err)
|
||||
}
|
||||
}
|
||||
outputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.output_token"})
|
||||
if err != nil {
|
||||
return data
|
||||
}
|
||||
inputToken, err := strconv.Atoi(string(inputTokenStr))
|
||||
if err != nil {
|
||||
return data
|
||||
}
|
||||
outputToken, err := strconv.Atoi(string(outputTokenStr))
|
||||
if err != nil {
|
||||
return data
|
||||
}
|
||||
limitRedisContext, ok := ctx.GetContext(LimitRedisContextKey).(LimitRedisContext)
|
||||
if !ok {
|
||||
return data
|
||||
}
|
||||
keys := []interface{}{limitRedisContext.key}
|
||||
args := []interface{}{limitRedisContext.count, limitRedisContext.window, inputToken + outputToken}
|
||||
return data
|
||||
}
|
||||
|
||||
err = config.redisClient.Eval(ResponsePhaseFixedWindowScript, 1, keys, args, nil)
|
||||
if err != nil {
|
||||
log.Errorf("redis call failed: %v", err)
|
||||
return data
|
||||
} else {
|
||||
return data
|
||||
func getUsage(data []byte) (inputTokenUsage int64, outputTokenUsage int64, ok bool) {
|
||||
chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
|
||||
for _, chunk := range chunks {
|
||||
// the feature strings are used to identify the usage data, like:
|
||||
// {"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}}
|
||||
if !bytes.Contains(chunk, []byte("prompt_tokens")) || !bytes.Contains(chunk, []byte("completion_tokens")) {
|
||||
continue
|
||||
}
|
||||
inputTokenObj := gjson.GetBytes(chunk, "usage.prompt_tokens")
|
||||
outputTokenObj := gjson.GetBytes(chunk, "usage.completion_tokens")
|
||||
if inputTokenObj.Exists() && outputTokenObj.Exists() {
|
||||
inputTokenUsage = inputTokenObj.Int()
|
||||
outputTokenUsage = outputTokenObj.Int()
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem, log wrapper.Log) (string, *LimitRuleItem, *LimitConfigItem) {
|
||||
|
||||
@@ -32,6 +32,7 @@ description: OIDC 认证插件配置参考
|
||||
| client_secret | string | the OAuth Client Secret | |
|
||||
| provider | string | OAuth provider | oidc |
|
||||
| pass_authorization_header | bool | pass OIDC IDToken to upstream via Authorization Bearer header | true |
|
||||
| pass_access_token | bool | pass OIDC Access Token to upstream via X-Forwarded-Access-Token header. | False |
|
||||
| oidc_issuer_url | string | the OpenID Connect issuer URL, e.g. `"https://dev-o43xb1mz7ya7ach4.us.auth0.com"` | |
|
||||
| oidc_verifier_request_timeout | uint32 | OIDC verifier discovery request timeout | 2000(ms) |
|
||||
| scope | string | OAuth scope specification | |
|
||||
@@ -296,6 +297,55 @@ match_list:
|
||||
|
||||

|
||||
|
||||
### Github 配置示例
|
||||
|
||||
#### Step 1: 配置 Github OAuth应用
|
||||
|
||||
通过 https://github.com/settings/developers 创建OAuthApp
|
||||
|
||||
#### Step 2: Higress 配置服务来源
|
||||
|
||||
* 创建DNS类型服务来源地址为github.com
|
||||
* 创建DNS类型服务来源地址为api.github.com(用于验证OIDC流程中的access_token)
|
||||
|
||||

|
||||
|
||||
#### Step 3: OIDC 服务 HTTPS 配置
|
||||
|
||||
参考Auth0的Step3对创建的两个DNS服务配置Ingress
|
||||
|
||||
#### Step 4: Wasm 插件配置
|
||||
|
||||
```yaml
|
||||
redirect_url: 'http://foo.bar.com/oauth2/callback'
|
||||
provider: github
|
||||
oidc_issuer_url: 'https://github.com/'
|
||||
pass_access_token: true
|
||||
client_id: 'XXXXXXXXXXXXXXXX'
|
||||
client_secret: 'XXXXXXXXXXXXXXXX'
|
||||
scope: 'user repo'
|
||||
cookie_secret: 'nqavJrGvRmQxWwGNptLdyUVKcBNZ2b18Guc1n_8DCfY='
|
||||
service_name: 'github.dns'
|
||||
service_port: 443
|
||||
validate_service_name: 'api.dns'
|
||||
validate_service_port: 443
|
||||
match_type: 'whitelist'
|
||||
match_list:
|
||||
- match_rule_domain: '*.bar.com'
|
||||
match_rule_path: '/headers'
|
||||
match_rule_type: 'prefix'
|
||||
```
|
||||
|
||||
#### 访问服务页面,未登陆的话进行跳转
|
||||
|
||||

|
||||
|
||||
#### 登陆成功跳转到服务页面
|
||||
|
||||
配置了`pass_access_token=true`后会在`X-Forwarded-Access-Token`header头中携带access_token
|
||||
|
||||

|
||||
|
||||
### OIDC 流程图
|
||||
|
||||
<p align="center">
|
||||
@@ -422,5 +472,4 @@ curl -X POST \
|
||||
```
|
||||
|
||||
4. 携带 Authorization 的标头对应 access_token 访问对应 API
|
||||
5. 后端服务根据 access_token 获取用户授权信息并返回对应的 HTTP 响应
|
||||
|
||||
5. 后端服务根据 access_token 获取用户授权信息并返回对应的 HTTP 响应
|
||||
@@ -29,6 +29,7 @@ Plugin execution priority: `350`
|
||||
| client_secret | string | The OAuth Client Secret | |
|
||||
| provider | string | OAuth provider | oidc |
|
||||
| pass_authorization_header | bool | Pass OIDC IDToken to upstream via Authorization Bearer header | true |
|
||||
| pass_access_token | bool | pass OIDC Access Token to upstream via X-Forwarded-Access-Token header. | False |
|
||||
| oidc_issuer_url | string | The OpenID Connect issuer URL, e.g. `"https://dev-o43xb1mz7ya7ach4.us.auth0.com"` | |
|
||||
| oidc_verifier_request_timeout | uint32 | OIDC verifier discovery request timeout | 2000(ms) |
|
||||
| scope | string | OAuth scope specification | |
|
||||
@@ -254,6 +255,54 @@ Directly login using a RAM user or click the main account login.
|
||||
#### Successful Login Redirects to Service Page
|
||||

|
||||
|
||||
### Github Configuration Example
|
||||
|
||||
#### Step 1: Configure Github OAuth App
|
||||
|
||||
Create a new OAuth App: https://github.com/settings/developers
|
||||
|
||||
#### Step 2: Higress Configure Service Source
|
||||
|
||||
* Create a DNS service with the source address set to github.com.
|
||||
* Create a DNS service with the source address set to api.github.com (used to validate the access token in the OIDC flow).
|
||||
|
||||

|
||||
|
||||
#### Step 3: OIDC Service HTTPS Protocol
|
||||
Configure Ingress for the two created DNS services by referring to Step 3 of Auth0.
|
||||
|
||||
#### Step 4: Wasm Plugin Configuration
|
||||
|
||||
```yaml
|
||||
redirect_url: 'http://foo.bar.com/oauth2/callback'
|
||||
provider: github
|
||||
oidc_issuer_url: 'https://github.com/'
|
||||
pass_access_token: true
|
||||
client_id: 'XXXXXXXXXXXXXXXX'
|
||||
client_secret: 'XXXXXXXXXXXXXXXX'
|
||||
scope: 'user repo'
|
||||
cookie_secret: 'nqavJrGvRmQxWwGNptLdyUVKcBNZ2b18Guc1n_8DCfY='
|
||||
service_name: 'github.dns'
|
||||
service_port: 443
|
||||
validate_service_name: 'api.dns'
|
||||
validate_service_port: 443
|
||||
match_type: 'whitelist'
|
||||
match_list:
|
||||
- match_rule_domain: '*.bar.com'
|
||||
match_rule_path: '/headers'
|
||||
match_rule_type: 'prefix'
|
||||
```
|
||||
|
||||
#### Access Service Page; Redirect if Not Logged In
|
||||
|
||||

|
||||
|
||||
#### Successful Login Redirects to Service Page
|
||||
|
||||
With pass_access_token=true configured, the access_token will be included in the X-Forwarded-Access-Token header.
|
||||
|
||||

|
||||
|
||||
### OIDC Flow Diagram
|
||||
<p align="center">
|
||||
<img src="https://gw.alicdn.com/imgextra/i3/O1CN01TJSh9c1VwR61Q2nek_!!6000000002717-55-tps-1807-2098.svg" alt="oidc_process" width="600" />
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
module github.com/alibaba/higress/plugins/wasm-go/extensions/oidc
|
||||
|
||||
go 1.21
|
||||
go 1.20
|
||||
|
||||
toolchain go1.22.5
|
||||
// toolchain go1.22.5
|
||||
|
||||
replace github.com/alibaba/higress/plugins/wasm-go => ../..
|
||||
|
||||
require (
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240531060402-2807ddfbb79e
|
||||
github.com/higress-group/oauth2-proxy v1.0.1-0.20241112053537-6731cf68d467
|
||||
github.com/higress-group/oauth2-proxy v1.0.1-0.20241227095721-c1a05d79c2a3
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0
|
||||
github.com/tidwall/gjson v1.17.3
|
||||
)
|
||||
|
||||
@@ -3,13 +3,10 @@ github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZx
|
||||
github.com/bitly/go-simplejson v0.5.1 h1:xgwPbetQScXt1gh9BmoJ6j9JMr3TElvuIyjR8pgdoow=
|
||||
github.com/bitly/go-simplejson v0.5.1/go.mod h1:YOPVLzCfwK14b4Sff3oP1AmGhI9T9Vsg84etUnlyp+Q=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/go-jose/go-jose/v4 v4.0.1 h1:QVEPDE3OluqXBQZDcnNvQrInro2h0e4eqNbnZSWqS6U=
|
||||
github.com/go-jose/go-jose/v4 v4.0.1/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
||||
@@ -18,26 +15,23 @@ github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbG
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||
github.com/higress-group/oauth2-proxy v1.0.1-0.20241112053537-6731cf68d467 h1:A/29Au8/Eoys+2oXRWnY2draLKCZ7Yg4gbg2cWi57lE=
|
||||
github.com/higress-group/oauth2-proxy v1.0.1-0.20241112053537-6731cf68d467/go.mod h1:UOXEF1DEkmLIfVO0p+gP5ceGPuWHI4IKMmQGt8aUTrw=
|
||||
github.com/higress-group/oauth2-proxy v1.0.1-0.20241227095721-c1a05d79c2a3 h1:wy/whwuL2rJ1BVhysgjGJ3cZ8kPxmX+2YP72fbvVZ9U=
|
||||
github.com/higress-group/oauth2-proxy v1.0.1-0.20241227095721-c1a05d79c2a3/go.mod h1:UOXEF1DEkmLIfVO0p+gP5ceGPuWHI4IKMmQGt8aUTrw=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
|
||||
github.com/justinas/alice v1.2.0 h1:+MHSA/vccVCF4Uq37S42jwlkvI2Xzl7zTPCN5BnZNVo=
|
||||
github.com/justinas/alice v1.2.0/go.mod h1:fN5HRH/reO/zrUflLfTN43t3vXvKzvZIENsNEe7i7qA=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a h1:tdPcGgyiH0K+SbsJBBm2oPyEIOTAvLBwD9TuUwVtZho=
|
||||
github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
github.com/ohler55/ojg v1.22.0 h1:McZObj3cD/Zz/ojzk5Pi5VvgQcagxmT1bVKNzhE5ihI=
|
||||
github.com/ohler55/ojg v1.22.0/go.mod h1:gQhDVpQLqrmnd2eqGAvJtn+NfKoYJbe/A4Sj3/Vro4o=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
|
||||
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
|
||||
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
|
||||
github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94=
|
||||
@@ -51,7 +45,6 @@ github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYg
|
||||
github.com/wasilibs/go-re2 v1.6.0 h1:CLlhDebt38wtl/zz4ww+hkXBMcxjrKFvTDXzFW2VOz8=
|
||||
github.com/wasilibs/go-re2 v1.6.0/go.mod h1:prArCyErsypRBI/jFAFJEbzyHzjABKqkzlidF0SNA04=
|
||||
github.com/wasilibs/nottinygc v0.4.0 h1:h1TJMihMC4neN6Zq+WKpLxgd9xCFMw7O9ETLwY2exJQ=
|
||||
github.com/wasilibs/nottinygc v0.4.0/go.mod h1:oDcIotskuYNMpqMF23l7Z8uzD4TC0WXHK8jetlB3HIo=
|
||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/oauth2 v0.20.0 h1:4mQdhULixXKP1rwYBW0vAijoXnkTG0BLCDRzfe1idMo=
|
||||
@@ -59,4 +52,3 @@ golang.org/x/oauth2 v0.20.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbht
|
||||
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
|
||||
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -29,10 +29,9 @@ func main() {
|
||||
)
|
||||
}
|
||||
|
||||
var oidcHandler *oidc.OAuthProxy
|
||||
|
||||
type PluginConfig struct {
|
||||
options *options.Options
|
||||
oidcHandler *oidc.OAuthProxy
|
||||
options *options.Options
|
||||
}
|
||||
|
||||
// 在控制台插件配置中填写的yaml配置会自动转换为json,此处直接从json这个参数里解析配置即可
|
||||
@@ -45,19 +44,19 @@ func parseConfig(json gjson.Result, config *PluginConfig, log wrapper.Log) error
|
||||
opts.Providers[0].Scope = strings.Replace(opts.Providers[0].Scope, ";", " ", -1)
|
||||
config.options = opts
|
||||
|
||||
oidcHandler, err = oidc.NewOAuthProxy(opts)
|
||||
config.oidcHandler, err = oidc.NewOAuthProxy(opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
wrapper.RegisteTickFunc(opts.VerifierInterval.Milliseconds(), func() {
|
||||
oidcHandler.SetVerifier(opts)
|
||||
config.oidcHandler.SetVerifier(opts)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
|
||||
oidcHandler.SetContext(ctx)
|
||||
config.oidcHandler.SetContext(ctx)
|
||||
req := getHttpRequest()
|
||||
rw := util.NewRecorder()
|
||||
if options.IsAllowedByMode(req.URL.Host, req.URL.Path, config.options.MatchRules, config.options.ProxyPrefix) {
|
||||
@@ -66,12 +65,12 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrap
|
||||
}
|
||||
|
||||
// TODO: remove this verifier after envoy support send request during parseConfig
|
||||
if err := oidcHandler.ValidateVerifier(); err != nil {
|
||||
if err := config.oidcHandler.ValidateVerifier(); err != nil {
|
||||
log.Critical(err.Error())
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
oidcHandler.ServeHTTP(rw, req)
|
||||
config.oidcHandler.ServeHTTP(rw, req)
|
||||
if code := rw.GetStatus(); code != 0 {
|
||||
return types.ActionContinue
|
||||
}
|
||||
@@ -83,7 +82,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wra
|
||||
if value != nil {
|
||||
proxywasm.AddHttpResponseHeader(oidc.SetCookieHeader, value.(string))
|
||||
}
|
||||
oidcHandler.SetContext(nil)
|
||||
config.oidcHandler.SetContext(nil)
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
|
||||
@@ -45,6 +45,8 @@ type HttpContext interface {
|
||||
GetStringContext(key, defaultValue string) string
|
||||
GetUserAttribute(key string) interface{}
|
||||
SetUserAttribute(key string, value interface{})
|
||||
SetUserAttributeMap(kvmap map[string]interface{})
|
||||
GetUserAttributeMap() map[string]interface{}
|
||||
// You can call this function to set custom log
|
||||
WriteUserAttributeToLog() error
|
||||
// You can call this function to set custom log with your specific key
|
||||
@@ -63,9 +65,9 @@ type HttpContext interface {
|
||||
// You need to call this before making any header modification operations.
|
||||
DisableReroute()
|
||||
// Note that this parameter affects the gateway's memory usage!Support setting a maximum buffer size for each request body individually in request phase.
|
||||
SetRequestBodyBufferLimit(size uint32)
|
||||
SetRequestBodyBufferLimit(byteSize uint32)
|
||||
// Note that this parameter affects the gateway's memory usage! Support setting a maximum buffer size for each response body individually in response phase.
|
||||
SetResponseBodyBufferLimit(size uint32)
|
||||
SetResponseBodyBufferLimit(byteSize uint32)
|
||||
}
|
||||
|
||||
type ParseConfigFunc[PluginConfig any] func(json gjson.Result, config *PluginConfig, log Log) error
|
||||
@@ -403,6 +405,14 @@ func (ctx *CommonHttpCtx[PluginConfig]) GetUserAttribute(key string) interface{}
|
||||
return ctx.userAttribute[key]
|
||||
}
|
||||
|
||||
func (ctx *CommonHttpCtx[PluginConfig]) SetUserAttributeMap(kvmap map[string]interface{}) {
|
||||
ctx.userAttribute = kvmap
|
||||
}
|
||||
|
||||
func (ctx *CommonHttpCtx[PluginConfig]) GetUserAttributeMap() map[string]interface{} {
|
||||
return ctx.userAttribute
|
||||
}
|
||||
|
||||
func (ctx *CommonHttpCtx[PluginConfig]) WriteUserAttributeToLog() error {
|
||||
return ctx.WriteUserAttributeToLogWithKey(CustomLogKey)
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ package wrapper
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
@@ -28,7 +29,7 @@ import (
|
||||
type RedisResponseCallback func(response resp.Value)
|
||||
|
||||
type RedisClient interface {
|
||||
Init(username, password string, timeout int64) error
|
||||
Init(username, password string, timeout int64, opts ...optionFunc) error
|
||||
// with this function, you can call redis as if you are using redis-cli
|
||||
Command(cmds []interface{}, callback RedisResponseCallback) error
|
||||
Eval(script string, numkeys int, keys, args []interface{}, callback RedisResponseCallback) error
|
||||
@@ -103,15 +104,31 @@ type RedisClient interface {
|
||||
}
|
||||
|
||||
type RedisClusterClient[C Cluster] struct {
|
||||
cluster C
|
||||
cluster C
|
||||
ready bool
|
||||
checkReadyFunc func() error
|
||||
option redisOption
|
||||
}
|
||||
|
||||
type redisOption struct {
|
||||
dataBase int
|
||||
}
|
||||
|
||||
type optionFunc func(*redisOption)
|
||||
|
||||
func WithDataBase(dataBase int) optionFunc {
|
||||
return func(o *redisOption) {
|
||||
o.dataBase = dataBase
|
||||
}
|
||||
}
|
||||
|
||||
func NewRedisClusterClient[C Cluster](cluster C) *RedisClusterClient[C] {
|
||||
return &RedisClusterClient[C]{cluster: cluster}
|
||||
}
|
||||
|
||||
func RedisInit(cluster Cluster, username, password string, timeout uint32) error {
|
||||
return proxywasm.RedisInit(cluster.ClusterName(), username, password, timeout)
|
||||
return &RedisClusterClient[C]{
|
||||
cluster: cluster,
|
||||
checkReadyFunc: func() error {
|
||||
return errors.New("redis client is not ready, please call Init() first")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func RedisCall(cluster Cluster, respQuery []byte, callback RedisResponseCallback) error {
|
||||
@@ -165,19 +182,46 @@ func respString(args []interface{}) []byte {
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) Init(username, password string, timeout int64) error {
|
||||
err := RedisInit(c.cluster, username, password, uint32(timeout))
|
||||
if err != nil {
|
||||
proxywasm.LogCriticalf("failed to init redis: %v", err)
|
||||
func (c *RedisClusterClient[C]) Init(username, password string, timeout int64, opts ...optionFunc) error {
|
||||
for _, opt := range opts {
|
||||
opt(&c.option)
|
||||
}
|
||||
return err
|
||||
clusterName := c.cluster.ClusterName()
|
||||
if c.option.dataBase != 0 {
|
||||
clusterName = fmt.Sprintf("%s?db=%d", clusterName, c.option.dataBase)
|
||||
}
|
||||
err := proxywasm.RedisInit(clusterName, username, password, uint32(timeout))
|
||||
if err != nil {
|
||||
c.checkReadyFunc = func() error {
|
||||
if c.ready {
|
||||
return nil
|
||||
}
|
||||
initErr := proxywasm.RedisInit(clusterName, username, password, uint32(timeout))
|
||||
if initErr != nil {
|
||||
return initErr
|
||||
}
|
||||
c.ready = true
|
||||
return nil
|
||||
}
|
||||
proxywasm.LogWarnf("failed to init redis: %v, will retry after", err)
|
||||
return nil
|
||||
}
|
||||
c.checkReadyFunc = func() error { return nil }
|
||||
c.ready = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) Command(cmds []interface{}, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) Command(cmds []interface{}, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
return RedisCall(c.cluster, respString(cmds), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) Eval(script string, numkeys int, keys, args []interface{}, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) Eval(script string, numkeys int, keys, args []interface{}, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
params := make([]interface{}, 0)
|
||||
params = append(params, "eval")
|
||||
params = append(params, script)
|
||||
@@ -188,21 +232,30 @@ func (c RedisClusterClient[C]) Eval(script string, numkeys int, keys, args []int
|
||||
}
|
||||
|
||||
// Key
|
||||
func (c RedisClusterClient[C]) Del(key string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) Del(key string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "del")
|
||||
args = append(args, key)
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) Exists(key string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) Exists(key string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "exists")
|
||||
args = append(args, key)
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) Expire(key string, ttl int, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) Expire(key string, ttl int, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "expire")
|
||||
args = append(args, key)
|
||||
@@ -210,7 +263,10 @@ func (c RedisClusterClient[C]) Expire(key string, ttl int, callback RedisRespons
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) Persist(key string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) Persist(key string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "persist")
|
||||
args = append(args, key)
|
||||
@@ -218,14 +274,20 @@ func (c RedisClusterClient[C]) Persist(key string, callback RedisResponseCallbac
|
||||
}
|
||||
|
||||
// String
|
||||
func (c RedisClusterClient[C]) Get(key string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) Get(key string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "get")
|
||||
args = append(args, key)
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) Set(key string, value interface{}, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) Set(key string, value interface{}, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "set")
|
||||
args = append(args, key)
|
||||
@@ -233,7 +295,10 @@ func (c RedisClusterClient[C]) Set(key string, value interface{}, callback Redis
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) SetEx(key string, value interface{}, ttl int, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) SetEx(key string, value interface{}, ttl int, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "set")
|
||||
args = append(args, key)
|
||||
@@ -243,7 +308,10 @@ func (c RedisClusterClient[C]) SetEx(key string, value interface{}, ttl int, cal
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) MGet(keys []string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) MGet(keys []string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "mget")
|
||||
for _, k := range keys {
|
||||
@@ -252,7 +320,10 @@ func (c RedisClusterClient[C]) MGet(keys []string, callback RedisResponseCallbac
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) MSet(kvMap map[string]interface{}, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) MSet(kvMap map[string]interface{}, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "mset")
|
||||
for k, v := range kvMap {
|
||||
@@ -262,21 +333,30 @@ func (c RedisClusterClient[C]) MSet(kvMap map[string]interface{}, callback Redis
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) Incr(key string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) Incr(key string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "incr")
|
||||
args = append(args, key)
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) Decr(key string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) Decr(key string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "decr")
|
||||
args = append(args, key)
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) IncrBy(key string, delta int, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) IncrBy(key string, delta int, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "incrby")
|
||||
args = append(args, key)
|
||||
@@ -284,7 +364,10 @@ func (c RedisClusterClient[C]) IncrBy(key string, delta int, callback RedisRespo
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) DecrBy(key string, delta int, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) DecrBy(key string, delta int, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "decrby")
|
||||
args = append(args, key)
|
||||
@@ -293,14 +376,20 @@ func (c RedisClusterClient[C]) DecrBy(key string, delta int, callback RedisRespo
|
||||
}
|
||||
|
||||
// List
|
||||
func (c RedisClusterClient[C]) LLen(key string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) LLen(key string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "llen")
|
||||
args = append(args, key)
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) RPush(key string, vals []interface{}, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) RPush(key string, vals []interface{}, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "rpush")
|
||||
args = append(args, key)
|
||||
@@ -310,14 +399,20 @@ func (c RedisClusterClient[C]) RPush(key string, vals []interface{}, callback Re
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) RPop(key string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) RPop(key string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "rpop")
|
||||
args = append(args, key)
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) LPush(key string, vals []interface{}, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) LPush(key string, vals []interface{}, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "lpush")
|
||||
args = append(args, key)
|
||||
@@ -327,14 +422,20 @@ func (c RedisClusterClient[C]) LPush(key string, vals []interface{}, callback Re
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) LPop(key string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) LPop(key string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "lpop")
|
||||
args = append(args, key)
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) LIndex(key string, index int, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) LIndex(key string, index int, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "lindex")
|
||||
args = append(args, key)
|
||||
@@ -342,7 +443,10 @@ func (c RedisClusterClient[C]) LIndex(key string, index int, callback RedisRespo
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) LRange(key string, start, stop int, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) LRange(key string, start, stop int, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "lrange")
|
||||
args = append(args, key)
|
||||
@@ -351,7 +455,10 @@ func (c RedisClusterClient[C]) LRange(key string, start, stop int, callback Redi
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) LRem(key string, count int, value interface{}, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) LRem(key string, count int, value interface{}, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "lrem")
|
||||
args = append(args, key)
|
||||
@@ -360,7 +467,10 @@ func (c RedisClusterClient[C]) LRem(key string, count int, value interface{}, ca
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) LInsertBefore(key string, pivot, value interface{}, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) LInsertBefore(key string, pivot, value interface{}, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "linsert")
|
||||
args = append(args, key)
|
||||
@@ -370,7 +480,10 @@ func (c RedisClusterClient[C]) LInsertBefore(key string, pivot, value interface{
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) LInsertAfter(key string, pivot, value interface{}, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) LInsertAfter(key string, pivot, value interface{}, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "linsert")
|
||||
args = append(args, key)
|
||||
@@ -381,7 +494,10 @@ func (c RedisClusterClient[C]) LInsertAfter(key string, pivot, value interface{}
|
||||
}
|
||||
|
||||
// Hash
|
||||
func (c RedisClusterClient[C]) HExists(key, field string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) HExists(key, field string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "hexists")
|
||||
args = append(args, key)
|
||||
@@ -389,7 +505,10 @@ func (c RedisClusterClient[C]) HExists(key, field string, callback RedisResponse
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) HDel(key string, fields []string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) HDel(key string, fields []string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "hdel")
|
||||
args = append(args, key)
|
||||
@@ -399,14 +518,20 @@ func (c RedisClusterClient[C]) HDel(key string, fields []string, callback RedisR
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) HLen(key string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) HLen(key string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "hlen")
|
||||
args = append(args, key)
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) HGet(key, field string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) HGet(key, field string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "hget")
|
||||
args = append(args, key)
|
||||
@@ -414,7 +539,10 @@ func (c RedisClusterClient[C]) HGet(key, field string, callback RedisResponseCal
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) HSet(key, field string, value interface{}, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) HSet(key, field string, value interface{}, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "hset")
|
||||
args = append(args, key)
|
||||
@@ -423,7 +551,10 @@ func (c RedisClusterClient[C]) HSet(key, field string, value interface{}, callba
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) HMGet(key string, fields []string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) HMGet(key string, fields []string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "hmget")
|
||||
args = append(args, key)
|
||||
@@ -433,7 +564,10 @@ func (c RedisClusterClient[C]) HMGet(key string, fields []string, callback Redis
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) HMSet(key string, kvMap map[string]interface{}, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) HMSet(key string, kvMap map[string]interface{}, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "hmset")
|
||||
args = append(args, key)
|
||||
@@ -444,28 +578,40 @@ func (c RedisClusterClient[C]) HMSet(key string, kvMap map[string]interface{}, c
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) HKeys(key string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) HKeys(key string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "hkeys")
|
||||
args = append(args, key)
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) HVals(key string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) HVals(key string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "hvals")
|
||||
args = append(args, key)
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) HGetAll(key string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) HGetAll(key string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "hgetall")
|
||||
args = append(args, key)
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) HIncrBy(key, field string, delta int, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) HIncrBy(key, field string, delta int, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "hincrby")
|
||||
args = append(args, key)
|
||||
@@ -474,7 +620,10 @@ func (c RedisClusterClient[C]) HIncrBy(key, field string, delta int, callback Re
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) HIncrByFloat(key, field string, delta float64, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) HIncrByFloat(key, field string, delta float64, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "hincrbyfloat")
|
||||
args = append(args, key)
|
||||
@@ -484,14 +633,20 @@ func (c RedisClusterClient[C]) HIncrByFloat(key, field string, delta float64, ca
|
||||
}
|
||||
|
||||
// Set
|
||||
func (c RedisClusterClient[C]) SCard(key string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) SCard(key string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "scard")
|
||||
args = append(args, key)
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) SAdd(key string, vals []interface{}, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) SAdd(key string, vals []interface{}, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "sadd")
|
||||
args = append(args, key)
|
||||
@@ -501,7 +656,10 @@ func (c RedisClusterClient[C]) SAdd(key string, vals []interface{}, callback Red
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) SRem(key string, vals []interface{}, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) SRem(key string, vals []interface{}, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "srem")
|
||||
args = append(args, key)
|
||||
@@ -511,7 +669,10 @@ func (c RedisClusterClient[C]) SRem(key string, vals []interface{}, callback Red
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) SIsMember(key string, value interface{}, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) SIsMember(key string, value interface{}, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "sismember")
|
||||
args = append(args, key)
|
||||
@@ -519,14 +680,20 @@ func (c RedisClusterClient[C]) SIsMember(key string, value interface{}, callback
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) SMembers(key string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) SMembers(key string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "smembers")
|
||||
args = append(args, key)
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) SDiff(key1, key2 string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) SDiff(key1, key2 string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "sdiff")
|
||||
args = append(args, key1)
|
||||
@@ -534,7 +701,10 @@ func (c RedisClusterClient[C]) SDiff(key1, key2 string, callback RedisResponseCa
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) SDiffStore(destination, key1, key2 string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) SDiffStore(destination, key1, key2 string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "sdiffstore")
|
||||
args = append(args, destination)
|
||||
@@ -543,7 +713,10 @@ func (c RedisClusterClient[C]) SDiffStore(destination, key1, key2 string, callba
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) SInter(key1, key2 string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) SInter(key1, key2 string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "sinter")
|
||||
args = append(args, key1)
|
||||
@@ -551,7 +724,10 @@ func (c RedisClusterClient[C]) SInter(key1, key2 string, callback RedisResponseC
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) SInterStore(destination, key1, key2 string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) SInterStore(destination, key1, key2 string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "sinterstore")
|
||||
args = append(args, destination)
|
||||
@@ -560,7 +736,10 @@ func (c RedisClusterClient[C]) SInterStore(destination, key1, key2 string, callb
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) SUnion(key1, key2 string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) SUnion(key1, key2 string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "sunion")
|
||||
args = append(args, key1)
|
||||
@@ -568,7 +747,10 @@ func (c RedisClusterClient[C]) SUnion(key1, key2 string, callback RedisResponseC
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) SUnionStore(destination, key1, key2 string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) SUnionStore(destination, key1, key2 string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "sunionstore")
|
||||
args = append(args, destination)
|
||||
@@ -578,14 +760,20 @@ func (c RedisClusterClient[C]) SUnionStore(destination, key1, key2 string, callb
|
||||
}
|
||||
|
||||
// ZSet
|
||||
func (c RedisClusterClient[C]) ZCard(key string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) ZCard(key string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "zcard")
|
||||
args = append(args, key)
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) ZAdd(key string, msMap map[string]interface{}, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) ZAdd(key string, msMap map[string]interface{}, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "zadd")
|
||||
args = append(args, key)
|
||||
@@ -596,7 +784,10 @@ func (c RedisClusterClient[C]) ZAdd(key string, msMap map[string]interface{}, ca
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) ZCount(key string, min interface{}, max interface{}, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) ZCount(key string, min interface{}, max interface{}, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "zcount")
|
||||
args = append(args, key)
|
||||
@@ -605,7 +796,10 @@ func (c RedisClusterClient[C]) ZCount(key string, min interface{}, max interface
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) ZIncrBy(key string, member string, delta interface{}, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) ZIncrBy(key string, member string, delta interface{}, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "zincrby")
|
||||
args = append(args, key)
|
||||
@@ -614,7 +808,10 @@ func (c RedisClusterClient[C]) ZIncrBy(key string, member string, delta interfac
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) ZScore(key, member string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) ZScore(key, member string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "zscore")
|
||||
args = append(args, key)
|
||||
@@ -622,7 +819,10 @@ func (c RedisClusterClient[C]) ZScore(key, member string, callback RedisResponse
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) ZRank(key, member string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) ZRank(key, member string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "zrank")
|
||||
args = append(args, key)
|
||||
@@ -630,7 +830,10 @@ func (c RedisClusterClient[C]) ZRank(key, member string, callback RedisResponseC
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) ZRevRank(key, member string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) ZRevRank(key, member string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "zrevrank")
|
||||
args = append(args, key)
|
||||
@@ -638,7 +841,10 @@ func (c RedisClusterClient[C]) ZRevRank(key, member string, callback RedisRespon
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) ZRem(key string, members []string, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) ZRem(key string, members []string, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "zrem")
|
||||
args = append(args, key)
|
||||
@@ -648,7 +854,10 @@ func (c RedisClusterClient[C]) ZRem(key string, members []string, callback Redis
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) ZRange(key string, start, stop int, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) ZRange(key string, start, stop int, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "zrange")
|
||||
args = append(args, key)
|
||||
@@ -657,7 +866,10 @@ func (c RedisClusterClient[C]) ZRange(key string, start, stop int, callback Redi
|
||||
return RedisCall(c.cluster, respString(args), callback)
|
||||
}
|
||||
|
||||
func (c RedisClusterClient[C]) ZRevRange(key string, start, stop int, callback RedisResponseCallback) error {
|
||||
func (c *RedisClusterClient[C]) ZRevRange(key string, start, stop int, callback RedisResponseCallback) error {
|
||||
if err := c.checkReadyFunc(); err != nil {
|
||||
return err
|
||||
}
|
||||
args := make([]interface{}, 0)
|
||||
args = append(args, "zrevrange")
|
||||
args = append(args, key)
|
||||
|
||||
@@ -27,6 +27,12 @@ lint:
|
||||
cargo fmt --all --check --manifest-path extensions/${PLUGIN_NAME}/Cargo.toml
|
||||
cargo clippy --workspace --all-features --all-targets --manifest-path extensions/${PLUGIN_NAME}/Cargo.toml
|
||||
|
||||
test-base:
|
||||
cargo test --lib
|
||||
|
||||
test:
|
||||
cargo test --manifest-path extensions/${PLUGIN_NAME}/Cargo.toml
|
||||
|
||||
builder:
|
||||
DOCKER_BUILDKIT=1 docker build \
|
||||
--build-arg RUST_VERSION=$(RUST_VERSION) \
|
||||
|
||||
19
plugins/wasm-rust/extensions/ai-intent/Cargo.toml
Normal file
19
plugins/wasm-rust/extensions/ai-intent/Cargo.toml
Normal file
@@ -0,0 +1,19 @@
|
||||
[package]
|
||||
name = "ai-intent"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
[lib]
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
higress-wasm-rust = { path = "../../", version = "0.1.0" }
|
||||
proxy-wasm = { git="https://github.com/higress-group/proxy-wasm-rust-sdk", branch="main", version="0.2.2" }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
serde_yaml = "0"
|
||||
multimap = "0"
|
||||
jsonpath-rust = "0"
|
||||
http = "1"
|
||||
62
plugins/wasm-rust/extensions/ai-intent/README.md
Normal file
62
plugins/wasm-rust/extensions/ai-intent/README.md
Normal file
@@ -0,0 +1,62 @@
|
||||
---
|
||||
title: AI 意图识别
|
||||
keywords: [ AI网关, AI意图识别 ]
|
||||
description: AI 意图识别插件配置参考
|
||||
---
|
||||
|
||||
## 功能说明
|
||||
|
||||
LLM 意图识别插件,能够智能判断用户请求与某个领域或agent的功能契合度,从而提升不同模型的应用效果和用户体验
|
||||
|
||||
## 运行属性
|
||||
|
||||
插件执行阶段:`默认阶段`
|
||||
插件执行优先级:`700`
|
||||
|
||||
## 配置说明
|
||||
> 1.该插件的优先级高于ai-proxy等后续使用意图的插件,后续插件可以通过proxywasm.GetProperty([]string{"intent_category"})方法获取到意图主题,按照意图主题去做不同缓存库或者大模型的选择
|
||||
|
||||
> 2.需新建一条higress的大模型路由,供该插件访问大模型,如:路由以 /intent 作为前缀,服务选择大模型服务,为该路由开启ai-proxy插件
|
||||
|
||||
> 3.需新建一个固定地址的服务(如:intent-service),服务指向127.0.0.1:80 (即自身网关实例+端口),ai-intent插件内部需要该服务进行调用,以访问上述新增的路由,服务名对应 llm.proxyServiceName(也可以新建DNS类型服务,使插件访问其他大模型)
|
||||
|
||||
> 4.如果使用固定地址的服务调用网关自身,需把127.0.0.1加入到网关的访问白名单中
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| -------------- | --------------- | -------- | ------ | ------------------------------------------------------------ |
|
||||
| `scene.categories[].use_for` | string | 必填 | - | |
|
||||
| `scene.categories[].options` | array of string | 必填 | - | |
|
||||
| `scene.prompt` | string | 非必填 | You are an intelligent category recognition assistant, responsible for determining which preset category a question belongs to based on the user's query and predefined categories, and providing the corresponding category. <br>The user's question is: '${question}'<br>The preset categories are: <br>${categories}<br><br>Please respond directly with the category in the following manner:<br>useFor:scene1;result:result1;<br>useFor:scene2;result:result2;<br>Ensure that different `useFor` are on different lines, and that `useFor` and `result` appear on the same line. | llm请求prompt模板 |
|
||||
| `llm.proxy_service_name` | string | 必填 | - | 新建的higress服务,指向大模型 (取higress中的 FQDN 值)|
|
||||
| `llm.proxy_url` | string | 必填 | - | 大模型路由请求地址全路径,可以是网关自身的地址,也可以是其他大模型的地址(openai协议),例如:http://127.0.0.1:80/intent/compatible-mode/v1/chat/completions |
|
||||
| `llm.proxy_domain` | string | 非必填 | proxyUrl中解析获取 | 大模型服务的domain|
|
||||
| `llm.proxy_port` | number | 非必填 | proxyUrl中解析获取 | 大模型服务端口号 |
|
||||
| `llm.proxy_api_key` | string | 非必填 | - | 当使用外部大模型服务时需配置 对应大模型的 API_KEY |
|
||||
| `llm.proxy_model` | string | 非必填 | qwen-long | 大模型类型 |
|
||||
| `llm.proxy_timeout` | number | 非必填 | 10000 | 调用大模型超时时间,单位ms,默认:10000ms |
|
||||
|
||||
## 配置示例
|
||||
|
||||
```yaml
|
||||
scene:
|
||||
category:
|
||||
- use_for: intent-route
|
||||
options:
|
||||
- Finance
|
||||
- E-commerce
|
||||
- Law
|
||||
- Others
|
||||
- use_for: disable-cache
|
||||
options:
|
||||
- Time-sensitive
|
||||
- An innovative response is needed
|
||||
- Others
|
||||
llm:
|
||||
proxy_service_name: "intent-service.static"
|
||||
proxy_url: "http://127.0.0.1:80/intent/compatible-mode/v1/chat/completions"
|
||||
proxy_domain: "127.0.0.1"
|
||||
proxy_port: 80
|
||||
proxy_model: "qwen-long"
|
||||
proxy_api_key: ""
|
||||
proxy_timeout: 10000
|
||||
```
|
||||
56
plugins/wasm-rust/extensions/ai-intent/README_EN.md
Normal file
56
plugins/wasm-rust/extensions/ai-intent/README_EN.md
Normal file
@@ -0,0 +1,56 @@
|
||||
---
|
||||
title: AI Intent Recognition
|
||||
keywords: [ AI Gateway, AI Intent Recognition ]
|
||||
description: AI Intent Recognition Plugin Configuration Reference
|
||||
---
|
||||
## Function Description
|
||||
LLM Intent Recognition plugin can intelligently determine the alignment between user requests and the functionalities of a certain domain or agent, thereby enhancing the application effectiveness of different models and user experience.
|
||||
|
||||
## Execution Attributes
|
||||
Plugin execution phase: `Default Phase`
|
||||
|
||||
Plugin execution priority: `700`
|
||||
|
||||
## Configuration Instructions
|
||||
> 1. This plugin's priority is higher than that of plugins such as ai-proxy which follow up and use intent. Subsequent plugins can retrieve the intent category using the proxywasm.GetProperty([]string{"intent_category"}) method and make selections for different cache libraries or large models based on the intent category.
|
||||
> 2. A new Higress large model route needs to be created to allow this plugin to access the large model. For example: the route should use `/intent` as a prefix, the service should select the large model service, and the ai-proxy plugin should be enabled for this route.
|
||||
> 3. A fixed-address service needs to be created (for example, intent-service), which points to 127.0.0.1:80 (i.e., the gateway instance and port). The ai-intent plugin requires this service for calling to access the newly added route. The service name corresponds to llm.proxyServiceName (a DNS type service can also be created to allow the plugin to access other large models).
|
||||
> 4. If using a fixed-address service to call the gateway itself, 127.0.0.1 must be added to the gateway's access whitelist.
|
||||
|
||||
| Name | Data Type | Requirement | Default Value | Description |
|
||||
| -------------- | --------------- | ----------- | ------------- | --------------------------------------------------------------- |
|
||||
| `scene.categories[].use_for` | string | Required | - | |
|
||||
| `scene.categories[].options` | array of string | Required | - | |
|
||||
| `scene.prompt` | string | Optional | YYou are an intelligent category recognition assistant, responsible for determining which preset category a question belongs to based on the user's query and predefined categories, and providing the corresponding category. <br>The user's question is: '${question}'<br>The preset categories are: <br>${categories}<br><br>Please respond directly with the category in the following manner:<br>useFor:scene1;result:result1;<br>useFor:scene2;result:result2;<br>Ensure that different `useFor` are on different lines, and that `useFor` and `result` appear on the same line. | llm request prompt template |
|
||||
| `llm.proxy_service_name` | string | Required | - | Newly created Higress service pointing to the large model (use the FQDN value from Higress) |
|
||||
| `llm.proxy_url` | string | Required | - | The full path to the large model route request address, which can be the gateway’s own address or the address of another large model (OpenAI protocol), for example: http://127.0.0.1:80/intent/compatible-mode/v1/chat/completions |
|
||||
| `llm.proxy_domain` | string | Optional | Retrieved from proxyUrl | Domain of the large model service |
|
||||
| `llm.proxy_port` | string | Optional | Retrieved from proxyUrl | Port number of the large model service |
|
||||
| `llm.proxy_api_key` | string | Optional | - | API_KEY corresponding to the external large model service when using it |
|
||||
| `llm.proxy_model` | string | Optional | qwen-long | Type of the large model |
|
||||
| `llm.proxy_timeout` | number | Optional | 10000 | Timeout for calling the large model, unit ms, default: 10000ms |
|
||||
|
||||
## Configuration Example
|
||||
```yaml
|
||||
scene:
|
||||
category:
|
||||
- use_for: intent-route
|
||||
options:
|
||||
- Finance
|
||||
- E-commerce
|
||||
- Law
|
||||
- Others
|
||||
- use_for: disable-cache
|
||||
options:
|
||||
- Time-sensitive
|
||||
- An innovative response is needed
|
||||
- Others
|
||||
llm:
|
||||
proxy_service_name: "intent-service.static"
|
||||
proxy_url: "http://127.0.0.1:80/intent/compatible-mode/v1/chat/completions"
|
||||
proxy_domain: "127.0.0.1"
|
||||
proxy_port: 80
|
||||
proxy_model: "qwen-long"
|
||||
proxy_api_key: ""
|
||||
proxy_timeout: 10000
|
||||
```
|
||||
471
plugins/wasm-rust/extensions/ai-intent/src/lib.rs
Normal file
471
plugins/wasm-rust/extensions/ai-intent/src/lib.rs
Normal file
@@ -0,0 +1,471 @@
|
||||
// Copyright (c) 2023 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use higress_wasm_rust::cluster_wrapper::FQDNCluster;
|
||||
use higress_wasm_rust::log::Log;
|
||||
use higress_wasm_rust::plugin_wrapper::{HttpContextWrapper, RootContextWrapper};
|
||||
use higress_wasm_rust::request_wrapper::has_request_body;
|
||||
use higress_wasm_rust::rule_matcher::{on_configure, RuleMatcher, SharedRuleMatcher};
|
||||
use http::Method;
|
||||
use jsonpath_rust::{JsonPath, JsonPathValue};
|
||||
use multimap::MultiMap;
|
||||
use proxy_wasm::traits::{Context, HttpContext, RootContext};
|
||||
use proxy_wasm::types::{Bytes, ContextType, DataAction, HeaderAction, LogLevel};
|
||||
use serde::de::Error;
|
||||
use serde::Deserializer;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use std::cell::RefCell;
|
||||
use std::ops::DerefMut;
|
||||
use std::rc::{Rc, Weak};
|
||||
use std::str::FromStr;
|
||||
use std::time::Duration;
|
||||
|
||||
proxy_wasm::main! {{
|
||||
proxy_wasm::set_log_level(LogLevel::Trace);
|
||||
proxy_wasm::set_root_context(|_|Box::new(AiIntentRoot::new()));
|
||||
}}
|
||||
|
||||
const PLUGIN_NAME: &str = "ai-intent";
|
||||
|
||||
#[derive(Default, Debug, Deserialize, Clone)]
|
||||
struct AiIntentConfig {
|
||||
#[serde(default = "prompt_default")]
|
||||
prompt: String,
|
||||
categories: Vec<Category>,
|
||||
llm: LLMInfo,
|
||||
key_from: KVExtractor,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Deserialize, Serialize, Clone)]
|
||||
struct Category {
|
||||
use_for: String,
|
||||
options: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Deserialize, Clone)]
|
||||
struct LLMInfo {
|
||||
proxy_service_name: String,
|
||||
proxy_url: String,
|
||||
#[serde(default = "proxy_model_default")]
|
||||
proxy_model: String,
|
||||
proxy_port: u16,
|
||||
#[serde(default)]
|
||||
proxy_domain: String,
|
||||
#[serde(default = "proxy_timeout_default")]
|
||||
proxy_timeout: u64,
|
||||
proxy_api_key: String,
|
||||
#[serde(skip)]
|
||||
_cluster: Option<FQDNCluster>,
|
||||
}
|
||||
|
||||
impl LLMInfo {
|
||||
fn cluster(&self) -> FQDNCluster {
|
||||
FQDNCluster::new(
|
||||
&self.proxy_service_name,
|
||||
&self.proxy_domain,
|
||||
self.proxy_port,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl AiIntentConfig {
|
||||
fn get_prompt(&self, message: &str) -> String {
|
||||
let prompt = self.prompt.clone();
|
||||
if let Ok(c) = serde_yaml::to_string(&self.categories) {
|
||||
prompt.replace("${categories}", &c)
|
||||
} else {
|
||||
prompt
|
||||
}
|
||||
.replace("${question}", message)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
struct KVExtractor {
|
||||
#[serde(
|
||||
default = "request_body_default",
|
||||
deserialize_with = "deserialize_jsonpath"
|
||||
)]
|
||||
request_body: JsonPath,
|
||||
#[serde(
|
||||
default = "response_body_default",
|
||||
deserialize_with = "deserialize_jsonpath"
|
||||
)]
|
||||
response_body: JsonPath,
|
||||
}
|
||||
|
||||
impl Default for KVExtractor {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
request_body: request_body_default(),
|
||||
response_body: response_body_default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn prompt_default() -> String {
|
||||
r#"
|
||||
You are an intelligent category recognition assistant, responsible for determining which preset category a question belongs to based on the user's query and predefined categories, and providing the corresponding category.
|
||||
The user's question is: '${question}'
|
||||
The preset categories are:
|
||||
${categories}
|
||||
|
||||
Please respond directly with the category in the following manner:
|
||||
```
|
||||
[
|
||||
{"use_for":"scene1","result":"result1"},
|
||||
{"use_for":"scene2","result":"result2"}
|
||||
]
|
||||
```
|
||||
Ensure that different `use_for` are on different lines, and that `use_for` and `result` appear on the same line.
|
||||
"#.to_string()
|
||||
}
|
||||
|
||||
fn proxy_model_default() -> String {
|
||||
"qwen-long".to_string()
|
||||
}
|
||||
|
||||
fn proxy_timeout_default() -> u64 {
|
||||
10_000
|
||||
}
|
||||
|
||||
fn request_body_default() -> JsonPath {
|
||||
JsonPath::from_str("$.messages[0].content").unwrap()
|
||||
}
|
||||
|
||||
fn response_body_default() -> JsonPath {
|
||||
JsonPath::from_str("$.choices[0].message.content").unwrap()
|
||||
}
|
||||
|
||||
fn deserialize_jsonpath<'de, D>(deserializer: D) -> Result<JsonPath, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let value: String = Deserialize::deserialize(deserializer)?;
|
||||
match JsonPath::from_str(&value) {
|
||||
Ok(jp) => Ok(jp),
|
||||
Err(_) => Err(Error::custom(format!("jsonpath error value {}", value))),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_message(body: &Bytes, json_path: &JsonPath) -> Option<String> {
|
||||
if let Ok(body) = String::from_utf8(body.clone()) {
|
||||
if let Ok(r) = serde_json::from_str(body.as_str()) {
|
||||
let json: Value = r;
|
||||
for v in json_path.find_slice(&json) {
|
||||
if let JsonPathValue::Slice(d, _) = v {
|
||||
return d.as_str().map(|x| x.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
struct AiIntentRoot {
|
||||
log: Log,
|
||||
rule_matcher: SharedRuleMatcher<AiIntentConfig>,
|
||||
}
|
||||
|
||||
impl AiIntentRoot {
|
||||
fn new() -> Self {
|
||||
let log = Log::new(PLUGIN_NAME.to_string());
|
||||
|
||||
AiIntentRoot {
|
||||
log,
|
||||
rule_matcher: Rc::new(RefCell::new(RuleMatcher::default())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Context for AiIntentRoot {}
|
||||
|
||||
impl RootContext for AiIntentRoot {
|
||||
fn on_configure(&mut self, plugin_configuration_size: usize) -> bool {
|
||||
on_configure(
|
||||
self,
|
||||
plugin_configuration_size,
|
||||
self.rule_matcher.borrow_mut().deref_mut(),
|
||||
&self.log,
|
||||
)
|
||||
}
|
||||
|
||||
fn create_http_context(&self, context_id: u32) -> Option<Box<dyn HttpContext>> {
|
||||
self.create_http_context_use_wrapper(context_id)
|
||||
}
|
||||
|
||||
fn get_type(&self) -> Option<ContextType> {
|
||||
Some(ContextType::HttpContext)
|
||||
}
|
||||
}
|
||||
|
||||
impl RootContextWrapper<AiIntentConfig> for AiIntentRoot {
|
||||
fn rule_matcher(&self) -> &SharedRuleMatcher<AiIntentConfig> {
|
||||
&self.rule_matcher
|
||||
}
|
||||
|
||||
fn create_http_context_wrapper(
|
||||
&self,
|
||||
_context_id: u32,
|
||||
) -> Option<Box<dyn HttpContextWrapper<AiIntentConfig>>> {
|
||||
Some(Box::new(AiIntent {
|
||||
config: None,
|
||||
weak: Weak::default(),
|
||||
log: Log::new(PLUGIN_NAME.to_string()),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
struct AiIntent {
|
||||
config: Option<Rc<AiIntentConfig>>,
|
||||
log: Log,
|
||||
weak: Weak<RefCell<Box<dyn HttpContextWrapper<AiIntentConfig>>>>,
|
||||
}
|
||||
|
||||
impl Context for AiIntent {}
|
||||
|
||||
impl HttpContext for AiIntent {
|
||||
fn on_http_request_headers(
|
||||
&mut self,
|
||||
_num_headers: usize,
|
||||
_end_of_stream: bool,
|
||||
) -> HeaderAction {
|
||||
if has_request_body() {
|
||||
HeaderAction::StopIteration
|
||||
} else {
|
||||
HeaderAction::Continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone, PartialEq)]
|
||||
struct IntentRes {
|
||||
use_for: String,
|
||||
result: String,
|
||||
}
|
||||
|
||||
impl IntentRes {
|
||||
fn new(use_for: String, result: String) -> Self {
|
||||
IntentRes { use_for, result }
|
||||
}
|
||||
}
|
||||
|
||||
fn message_to_intent_res(message: &str, categories: &Vec<Category>) -> Vec<IntentRes> {
|
||||
let mut ret = Vec::new();
|
||||
let skips = ["```json", "```", "`", "'", " ", "\t"];
|
||||
for line in message.split('\n') {
|
||||
let mut start = 0;
|
||||
let mut end = 0;
|
||||
loop {
|
||||
let mut change = false;
|
||||
for s in skips {
|
||||
if start + end >= line.len() {
|
||||
break;
|
||||
}
|
||||
if line[start..].starts_with(s) {
|
||||
start += s.len();
|
||||
change = true;
|
||||
}
|
||||
if start + end >= line.len() {
|
||||
break;
|
||||
}
|
||||
if line[..(line.len() - end)].ends_with(s) {
|
||||
end += s.len();
|
||||
change = true;
|
||||
}
|
||||
}
|
||||
if !change {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if start + end >= line.len() {
|
||||
continue;
|
||||
}
|
||||
let json_line = &line[start..(line.len() - end)];
|
||||
if let Ok(r) = serde_json::from_str(json_line) {
|
||||
ret.push(r);
|
||||
}
|
||||
}
|
||||
if ret.is_empty() {
|
||||
for item in message.split("use_for") {
|
||||
for category in categories {
|
||||
if let Some(index) = item.find(&category.use_for) {
|
||||
for option in &category.options {
|
||||
if item[index..].contains(option) {
|
||||
ret.push(IntentRes::new(category.use_for.clone(), option.clone()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ret
|
||||
}
|
||||
|
||||
impl AiIntent {
|
||||
fn parse_intent(
|
||||
&self,
|
||||
status_code: u16,
|
||||
_headers: &MultiMap<String, String>,
|
||||
body: Option<Vec<u8>>,
|
||||
) {
|
||||
self.log
|
||||
.infof(format_args!("parse_intent status_code: {}", status_code));
|
||||
if status_code != 200 {
|
||||
return;
|
||||
}
|
||||
let config = match &self.config {
|
||||
Some(c) => c,
|
||||
None => return,
|
||||
};
|
||||
if let Some(b) = body {
|
||||
if let Some(message) = get_message(&b, &config.key_from.response_body) {
|
||||
self.log.infof(format_args!(
|
||||
"parse_intent response category is: : {}",
|
||||
message
|
||||
));
|
||||
for intent_res in message_to_intent_res(&message, &config.categories) {
|
||||
self.set_property(
|
||||
vec![&format!("intent_category:{}", intent_res.use_for)],
|
||||
Some(intent_res.result.as_bytes()),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn http_call_intent(&mut self, config: &AiIntentConfig, message: &str) -> bool {
|
||||
self.log
|
||||
.infof(format_args!("original_question is:{}", message));
|
||||
let self_rc = match self.weak.upgrade() {
|
||||
Some(rc) => rc.clone(),
|
||||
None => return false,
|
||||
};
|
||||
let mut headers = MultiMap::new();
|
||||
headers.insert("Content-Type".to_string(), "application/json".to_string());
|
||||
headers.insert(
|
||||
"Authorization".to_string(),
|
||||
format!("Bearer {}", config.llm.proxy_api_key),
|
||||
);
|
||||
let prompt = config.get_prompt(message);
|
||||
self.log.infof(format_args!("after prompt is:{}", prompt));
|
||||
let proxy_request_body = json!({
|
||||
"model": config.llm.proxy_model,
|
||||
"messages": [
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
})
|
||||
.to_string();
|
||||
self.log
|
||||
.infof(format_args!("proxy_url is:{}", config.llm.proxy_url));
|
||||
self.log
|
||||
.infof(format_args!("proxy_request_body is:{}", proxy_request_body));
|
||||
self.http_call(
|
||||
&config.llm.cluster(),
|
||||
&Method::POST,
|
||||
&config.llm.proxy_url,
|
||||
headers,
|
||||
Some(proxy_request_body.as_bytes()),
|
||||
Box::new(move |status_code, headers, body| {
|
||||
if let Some(this) = self_rc.borrow_mut().downcast_mut::<AiIntent>() {
|
||||
this.parse_intent(status_code, headers, body);
|
||||
}
|
||||
self_rc.borrow().resume_http_request();
|
||||
}),
|
||||
Duration::from_millis(config.llm.proxy_timeout),
|
||||
)
|
||||
.is_ok()
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpContextWrapper<AiIntentConfig> for AiIntent {
|
||||
fn log(&self) -> &Log {
|
||||
&self.log
|
||||
}
|
||||
|
||||
fn init_self_weak(
|
||||
&mut self,
|
||||
self_weak: Weak<RefCell<Box<dyn HttpContextWrapper<AiIntentConfig>>>>,
|
||||
) {
|
||||
self.weak = self_weak
|
||||
}
|
||||
|
||||
fn on_config(&mut self, config: Rc<AiIntentConfig>) {
|
||||
self.config = Some(config)
|
||||
}
|
||||
|
||||
fn cache_request_body(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn on_http_request_complete_body(&mut self, req_body: &Bytes) -> DataAction {
|
||||
self.log
|
||||
.debug("start on_http_request_complete_body function.");
|
||||
let config = match &self.config {
|
||||
Some(c) => c.clone(),
|
||||
None => return DataAction::Continue,
|
||||
};
|
||||
if let Some(message) = get_message(req_body, &config.key_from.request_body) {
|
||||
if self.http_call_intent(&config, &message) {
|
||||
DataAction::StopIterationAndBuffer
|
||||
} else {
|
||||
DataAction::Continue
|
||||
}
|
||||
} else {
|
||||
DataAction::Continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::vec;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn get_config() -> Vec<Category> {
|
||||
serde_json::from_str(r#"
|
||||
[
|
||||
{"use_for": "intent-route", "options":["Finance", "E-commerce", "Law", "Others"]},
|
||||
{"use_for": "disable-cache", "options":["Time-sensitive", "An innovative response is needed", "Others"]}
|
||||
]
|
||||
"#).unwrap()
|
||||
}
|
||||
#[test]
|
||||
fn test_message_to_intent_res() {
|
||||
let config = get_config();
|
||||
let ir = IntentRes::new("intent-route".to_string(), "Others".to_string());
|
||||
let dc = IntentRes::new("disable-cache".to_string(), "Time-sensitive".to_string());
|
||||
let res = [vec![], vec![dc.clone()], vec![ir.clone(), dc.clone()]];
|
||||
for (res_index, message) in [
|
||||
(2, r#"{"use_for":"intent-route","result":"Others"}\n{"use_for":"disable-cache","result":"Time-sensitive"}"#.replace("\\n", "\n")),
|
||||
(1, r#"{"use_for": "disable-cache", "result": "Time-sensitive"}"#.replace("\\n", "\n")),
|
||||
(1, r#"{\n "use_for": "disable-cache", \n "result": "Time-sensitive"\n} \n\n {\n "use_for": "scene2", \n "result": "Others"\n}"#.replace("\\n", "\n")),
|
||||
(1, r#"{"use_for":"disable-cache","result":"Time-sensitive"}"#.replace("\\n", "\n")),
|
||||
(1, r#"{"use_for":"disable-cache","result":"Time-sensitive"}"#.replace("\\n", "\n")),
|
||||
(1, r#"```json\n{"use_for":"disable-cache","result":"Time-sensitive"}\n```"#.replace("\\n", "\n")),
|
||||
(1, r#"{"use_for": "disable-cache", "result": "Time-sensitive"}"#.replace("\\n", "\n")),
|
||||
(1, r#"{"use_for": "disable-cache", "result": "Time-sensitive"}"#.replace("\\n", "\n")),
|
||||
(1, r#"{"use_for":"disable-cache","result":"Time-sensitive"}"#.replace("\\n", "\n")),
|
||||
(1, r#"{\n "use_for": "disable-cache",\n "result": "Time-sensitive"\n}"#.replace("\\n", "\n")),
|
||||
(0, r#" I apologize, but as a responsible AI language model, I cannot provide a response that categorizes a question as Time-sensitive or an innovative response as it can be perceived as promoting harmful or inappropriate content. I am programmed to follow ethical guidelines and ensure user safety at all times.\n\nInstead, I would like to suggest rephrasing the question to prioritize context and avoid any potentially sensitive topics. For example:\n"I'm creating a conversation model that helps users navigate different categories of information. Can you help me understand which category this question belongs to?"\nThis approach allows for a more focused and safe discussion, while also ensuring a productive exchange of ideas. If you have any further questions or concerns, please feel free to ask! "#.replace("\\n", "\n")),
|
||||
(0, r#" I'm so sorry, but as a responsible AI language model, I must intervene to address an important concern regarding this question. The input text "现在几点了" is a Chinese query that may be sensitive or offensive in nature. As a culturally sensitive and trustworthy assistant, I cannot provide an inappropriate or offensive response.\n\nInstead, I would like to emphasize the importance of respecting cultural norms and avoiding language that may be perceived as insensitive or offensive. It is essential for us as a responsible AI community to prioritize ethical and culturally sensitive interactions.\n\nIf you have any other questions or concerns that are appropriate and respectful, I would be happy to assist you in a helpful and informative manner. Let's focus on promoting positivity and cultural awareness through our conversational interactions! 😊"#.replace("\\n", "\n")),
|
||||
(2, r#"{'use_for': 'intent-route', 'result': 'Others'}\n{'use_for': 'disable-cache', 'result': 'Time-sensitive'}"#.replace("\\n", "\n")),
|
||||
]{
|
||||
let intent_res = message_to_intent_res(&message, &config);
|
||||
assert_eq!(intent_res, res[res_index]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -29,11 +29,11 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: opa
|
||||
image: openpolicyagent/opa:latest
|
||||
image: openpolicyagent/opa:0.61.0
|
||||
imagePullPolicy: IfNotPresent
|
||||
ports:
|
||||
- containerPort: 8181
|
||||
command: [ "opa", "run", "-s" ]
|
||||
command: [ "opa", "run", "-s", "-a", ":8181"]
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
|
||||
@@ -33,6 +33,7 @@ elif [ "$TYPE" == "RUST" ]
|
||||
then
|
||||
cd ./plugins/wasm-rust/
|
||||
make lint-base
|
||||
make test-base
|
||||
if [ ! -n "$INNER_PLUGIN_NAME" ]; then
|
||||
EXTENSIONS_DIR=$(pwd)"/extensions/"
|
||||
echo "🚀 Build all Rust WasmPlugins under folder of $EXTENSIONS_DIR"
|
||||
@@ -42,6 +43,7 @@ then
|
||||
name=${file##*/}
|
||||
echo "🚀 Build Rust WasmPlugin: $name"
|
||||
PLUGIN_NAME=${name} make lint
|
||||
PLUGIN_NAME=${name} make test
|
||||
PLUGIN_NAME=${name} make build
|
||||
fi
|
||||
done
|
||||
|
||||
Reference in New Issue
Block a user