Compare commits

...

16 Commits

Author SHA1 Message Date
澄潭
85c7b1f501 rel: Release 2.0.4 (#1571) 2024-12-06 13:52:03 +08:00
pepesi
8f660211e3 feat: ai-proxy support custom error handler by cover util.ErrorHandler (#1537) 2024-12-06 11:47:50 +08:00
rinfx
433227323d extension mechanism for custom logs and span attributes (#1451) 2024-12-05 18:39:00 +08:00
pepesi
b36e5ea26b feat: allow cover api-version when use ai-proxy azure provider (#1535) 2024-12-05 13:41:02 +08:00
rinfx
ce66ff68ce solve aliyun lvwang content length limit problem (#1569) 2024-12-05 13:39:20 +08:00
pepesi
d026f0fca5 feat: ai-proxy support dashscope-finance (#1554) 2024-12-05 11:48:09 +08:00
rinfx
22790aa149 fix moonshot usage compatible problem (#1568) 2024-12-05 11:35:25 +08:00
澄潭
7ce6d7aba1 fix xds cache (#1559) 2024-12-04 00:55:29 +08:00
Se7en
e705a0344f fix: qwen stream issue (#1564) 2024-12-03 13:10:47 +08:00
澄潭
d6094974c2 update ai proxy go mod (#1556) 2024-12-02 14:41:55 +08:00
mamba
6187be97e5 fix: 🐛 frontend-grayurl 解析不正确导致路由失败 (#1550) 2024-11-29 13:09:05 +08:00
澄潭
bb64b43f23 set concurrency argument of proxy by cpu limit/request (#1552) 2024-11-28 16:55:57 +08:00
澄潭
ca7458cf1c Optimize the overall log output (#1549) 2024-11-27 20:44:34 +08:00
Se7en
ee2dd76ae1 feat: migrate baidu provider to v2 api (#1527) 2024-11-27 20:12:00 +08:00
pepesi
8154cf95f1 feat: support custom log (#1521) 2024-11-27 20:11:29 +08:00
澄潭
a7593381e1 fix ai fallback (#1541) 2024-11-25 16:48:59 +08:00
52 changed files with 1138 additions and 553 deletions

View File

@@ -144,7 +144,7 @@ docker-buildx-push: clean-env docker.higress-buildx
export PARENT_GIT_TAG:=$(shell cat VERSION)
export PARENT_GIT_REVISION:=$(TAG)
export ENVOY_PACKAGE_URL_PATTERN?=https://github.com/higress-group/proxy/releases/download/v2.0.0/envoy-symbol-ARCH.tar.gz
export ENVOY_PACKAGE_URL_PATTERN?=https://github.com/higress-group/proxy/releases/download/v2.1.0/envoy-symbol-ARCH.tar.gz
build-envoy: prebuild
./tools/hack/build-envoy.sh
@@ -187,8 +187,8 @@ install: pre-install
cd helm/higress; helm dependency build
helm install higress helm/higress -n higress-system --create-namespace --set 'global.local=true'
ENVOY_LATEST_IMAGE_TAG ?= 2.0.1
ISTIO_LATEST_IMAGE_TAG ?= 2.0.1
ENVOY_LATEST_IMAGE_TAG ?= 2.0.3
ISTIO_LATEST_IMAGE_TAG ?= 8be82d2e4c280c29f4952fbeca1e2a79230b7836
install-dev: pre-install
helm install higress helm/core -n higress-system --create-namespace --set 'controller.tag=$(TAG)' --set 'gateway.replicas=1' --set 'pilot.tag=$(ISTIO_LATEST_IMAGE_TAG)' --set 'gateway.tag=$(ENVOY_LATEST_IMAGE_TAG)' --set 'global.local=true'

View File

@@ -1 +1 @@
v2.0.3
v2.0.4

View File

@@ -1,5 +1,5 @@
apiVersion: v2
appVersion: 2.0.3
appVersion: 2.0.4
description: Helm chart for deploying higress gateways
icon: https://higress.io/img/higress_logo_small.png
home: http://higress.io/
@@ -10,4 +10,4 @@ name: higress-core
sources:
- http://github.com/alibaba/higress
type: application
version: 2.0.3
version: 2.0.4

View File

@@ -7,9 +7,6 @@ Rendering the pod template of gateway component.
template:
metadata:
annotations:
{{- if .Values.global.enableHigressIstio }}
"enableHigressIstio": "true"
{{- end }}
{{- if .Values.gateway.podAnnotations }}
{{- toYaml .Values.gateway.podAnnotations | nindent 6 }}
{{- end }}
@@ -268,11 +265,7 @@ template:
{{- end }}
- name: higress-ca-root-cert
configMap:
{{- if .Values.global.enableHigressIstio }}
name: istio-ca-root-cert
{{- else }}
name: higress-ca-root-cert
{{- end }}
- name: config
configMap:
name: higress-config

View File

@@ -20,11 +20,7 @@
# When processing a leaf namespace Istio will search for declarations in that namespace first
# and if none are found it will search in the root namespace. Any matching declaration found in the root namespace
# is processed as if it were declared in the leaf namespace.
{{- if .Values.global.enableHigressIstio }}
rootNamespace: {{ .Values.meshConfig.rootNamespace | default .Values.global.istioNamespace }}
{{- else }}
rootNamespace: {{ .Release.Namespace }}
{{- end }}
configSources:
- address: "xds://127.0.0.1:15051"
@@ -85,12 +81,8 @@
discoveryAddress: {{ printf "istiod.%s.svc" .Release.Namespace }}:15012
{{- end }}
{{- else }}
{{- if .Values.global.enableHigressIstio }}
discoveryAddress: {{ printf "istiod.%s.svc" .Values.global.istioNamespace }}:15012
{{- else }}
discoveryAddress: {{ include "controller.name" . }}.{{.Release.Namespace}}.svc:15012
{{- end }}
{{- end }}
proxyStatsMatcher:
inclusionRegexps:
- ".*"

View File

@@ -96,7 +96,6 @@ spec:
volumeMounts:
- name: log
mountPath: /var/log
{{- if not .Values.global.enableHigressIstio }}
- name: discovery
image: "{{ .Values.pilot.hub | default .Values.global.hub }}/{{ .Values.pilot.image | default "pilot" }}:{{ .Values.pilot.tag | default .Chart.AppVersion }}"
{{- if .Values.global.imagePullPolicy }}
@@ -229,10 +228,8 @@ spec:
value: "false"
- name: PILOT_ENABLE_GATEWAY_API_DEPLOYMENT_CONTROLLER
value: "false"
{{- if not .Values.global.enableHigressIstio }}
- name: CUSTOM_CA_CERT_NAME
value: "higress-ca-root-cert"
{{- end }}
{{- if not (or .Values.global.local .Values.global.kind) }}
resources:
{{- if .Values.pilot.resources }}
@@ -269,7 +266,6 @@ spec:
- name: extracacerts
mountPath: /cacerts
{{- end }}
{{- end }}
{{- with .Values.controller.nodeSelector }}
nodeSelector:
{{- toYaml . | nindent 8 }}
@@ -285,7 +281,6 @@ spec:
volumes:
- name: log
emptyDir: {}
{{- if not .Values.global.enableHigressIstio }}
- name: config
configMap:
name: higress-config
@@ -317,4 +312,3 @@ spec:
configMap:
name: pilot-jwks-extra-cacerts{{- if not (eq .Values.revision "") }}-{{ .Values.revision }}{{- end }}
{{- end }}
{{- end }}

View File

@@ -9,7 +9,6 @@ spec:
type: {{ .Values.controller.service.type }}
ports:
{{- toYaml .Values.controller.ports | nindent 4 }}
{{- if not .Values.global.enableHigressIstio }}
- port: 15010
name: grpc-xds # plaintext
protocol: TCP
@@ -23,6 +22,5 @@ spec:
- port: 15014
name: http-monitoring # prometheus stats
protocol: TCP
{{- end }}
selector:
{{- include "controller.selectorLabels" . | nindent 4 }}

View File

@@ -40,8 +40,6 @@ global:
enableIstioAPI: true
# -- If true, Higress Controller will monitor Gateway API resources as well
enableGatewayAPI: false
# Deprecated
enableHigressIstio: false
# -- Used to locate istiod.
istioNamespace: istio-system
# -- enable pod disruption budget for the control plane, which is used to

View File

@@ -1,9 +1,9 @@
dependencies:
- name: higress-core
repository: file://../core
version: 2.0.3
version: 2.0.4
- name: higress-console
repository: https://higress.io/helm-charts/
version: 1.4.5
digest: sha256:74b772113264168483961f5d0424459fd7359adc509a4b50400229581d7cddbf
generated: "2024-11-08T14:06:51.871719+08:00"
version: 1.4.6
digest: sha256:ec570ac7ae8a6de976e7ffafaadae4a33beeabfb4b13debe63e0cfa100e2eb8c
generated: "2024-12-06T11:34:04.628976+08:00"

View File

@@ -1,5 +1,5 @@
apiVersion: v2
appVersion: 2.0.3
appVersion: 2.0.4
description: Helm chart for deploying Higress gateways
icon: https://higress.io/img/higress_logo_small.png
home: http://higress.io/
@@ -12,9 +12,9 @@ sources:
dependencies:
- name: higress-core
repository: "file://../core"
version: 2.0.3
version: 2.0.4
- name: higress-console
repository: "https://higress.io/helm-charts/"
version: 1.4.5
version: 1.4.6
type: application
version: 2.0.3
version: 2.0.4

View File

@@ -159,7 +159,6 @@ The command removes all the Kubernetes components associated with the chart and
| global.disableAlpnH2 | bool | `false` | Whether to disable HTTP/2 in ALPN |
| global.enableGatewayAPI | bool | `false` | If true, Higress Controller will monitor Gateway API resources as well |
| global.enableH3 | bool | `false` | |
| global.enableHigressIstio | bool | `false` | |
| global.enableIPv6 | bool | `false` | |
| global.enableIstioAPI | bool | `true` | If true, Higress Controller will monitor istio resources as well |
| global.enableProxyProtocol | bool | `false` | |

View File

@@ -41,11 +41,11 @@ import (
"istio.io/istio/pkg/config/schema/kind"
"istio.io/istio/pkg/keepalive"
istiokube "istio.io/istio/pkg/kube"
"istio.io/istio/pkg/log"
"istio.io/istio/pkg/security"
"istio.io/istio/security/pkg/server/ca/authenticate"
"istio.io/istio/security/pkg/server/ca/authenticate/kubeauth"
"istio.io/pkg/ledger"
"istio.io/pkg/log"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/cache"

View File

@@ -173,7 +173,7 @@ func (s *CertMgr) Reconcile(ctx context.Context, oldConfig *Config, newConfig *C
s.cache.Start()
// sync domains
s.configMgr.SetConfig(newConfig)
CertLog.Infof("certMgr start to manageSync domains:+v%", newDomains)
CertLog.Infof("certMgr start to manageSync domains: %+v", newDomains)
s.manageSync(context.Background(), newDomains)
CertLog.Infof("certMgr manageSync domains done")
} else {

View File

@@ -14,6 +14,6 @@
package cert
import "istio.io/pkg/log"
import "istio.io/istio/pkg/log"
var CertLog = log.RegisterScope("cert", "Higress Cert process.", 0)
var CertLog = log.RegisterScope("cert", "Higress Cert process.")

View File

@@ -25,7 +25,7 @@ import (
"istio.io/istio/pkg/config/constants"
"istio.io/istio/pkg/env"
"istio.io/istio/pkg/keepalive"
"istio.io/pkg/log"
"istio.io/istio/pkg/log"
)
var (

View File

@@ -303,21 +303,21 @@ func (m *IngressConfig) listFromIngressControllers(typ config.GroupVersionKind,
common.SortIngressByCreationTime(configs)
wrapperConfigs := m.createWrapperConfigs(configs)
IngressLog.Infof("resource type %s, configs number %d", typ, len(wrapperConfigs))
var result []config.Config
switch typ {
case gvk.Gateway:
return m.convertGateways(wrapperConfigs)
result = m.convertGateways(wrapperConfigs)
case gvk.VirtualService:
return m.convertVirtualService(wrapperConfigs)
result = m.convertVirtualService(wrapperConfigs)
case gvk.DestinationRule:
return m.convertDestinationRule(wrapperConfigs)
result = m.convertDestinationRule(wrapperConfigs)
case gvk.ServiceEntry:
return m.convertServiceEntry(wrapperConfigs)
result = m.convertServiceEntry(wrapperConfigs)
case gvk.WasmPlugin:
return m.convertWasmPlugin(wrapperConfigs)
result = m.convertWasmPlugin(wrapperConfigs)
}
return nil
IngressLog.Infof("resource type %s, ingress number %d, convert configs number %d", typ, len(configs), len(result))
return result
}
func (m *IngressConfig) listFromGatewayControllers(typ config.GroupVersionKind, namespace string) []config.Config {
@@ -712,7 +712,6 @@ func (m *IngressConfig) convertDestinationRule(configs []common.WrapperConfig) [
if m.RegistryReconciler != nil {
drws := m.RegistryReconciler.GetAllDestinationRuleWrapper()
IngressLog.Infof("Found mcp destinationRules: %v", drws)
for _, destinationRuleWrapper := range drws {
serviceName := destinationRuleWrapper.ServiceKey.ServiceFQDN
dr, exist := destinationRules[serviceName]
@@ -906,6 +905,7 @@ func (m *IngressConfig) convertIstioWasmPlugin(obj *higressext.WasmPlugin) (*ext
StructValue: rule.Config,
}
validRule := false
var matchItems []*_struct.Value
// match ingress
for _, ing := range rule.Ingress {
@@ -916,6 +916,7 @@ func (m *IngressConfig) convertIstioWasmPlugin(obj *higressext.WasmPlugin) (*ext
})
}
if len(matchItems) > 0 {
validRule = true
v.StructValue.Fields["_match_route_"] = &_struct.Value{
Kind: &_struct.Value_ListValue{
ListValue: &_struct.ListValue{
@@ -923,12 +924,9 @@ func (m *IngressConfig) convertIstioWasmPlugin(obj *higressext.WasmPlugin) (*ext
},
},
}
ruleValues = append(ruleValues, &_struct.Value{
Kind: v,
})
continue
}
// match service
matchItems = nil
for _, service := range rule.Service {
matchItems = append(matchItems, &_struct.Value{
Kind: &_struct.Value_StringValue{
@@ -937,6 +935,7 @@ func (m *IngressConfig) convertIstioWasmPlugin(obj *higressext.WasmPlugin) (*ext
})
}
if len(matchItems) > 0 {
validRule = true
v.StructValue.Fields["_match_service_"] = &_struct.Value{
Kind: &_struct.Value_ListValue{
ListValue: &_struct.ListValue{
@@ -944,12 +943,9 @@ func (m *IngressConfig) convertIstioWasmPlugin(obj *higressext.WasmPlugin) (*ext
},
},
}
ruleValues = append(ruleValues, &_struct.Value{
Kind: v,
})
continue
}
// match domain
matchItems = nil
for _, domain := range rule.Domain {
matchItems = append(matchItems, &_struct.Value{
Kind: &_struct.Value_StringValue{
@@ -957,19 +953,23 @@ func (m *IngressConfig) convertIstioWasmPlugin(obj *higressext.WasmPlugin) (*ext
},
})
}
if len(matchItems) == 0 {
if len(matchItems) > 0 {
validRule = true
v.StructValue.Fields["_match_domain_"] = &_struct.Value{
Kind: &_struct.Value_ListValue{
ListValue: &_struct.ListValue{
Values: matchItems,
},
},
}
}
if validRule {
ruleValues = append(ruleValues, &_struct.Value{
Kind: v,
})
} else {
return nil, fmt.Errorf("invalid match rule has no match condition, rule:%v", rule)
}
v.StructValue.Fields["_match_domain_"] = &_struct.Value{
Kind: &_struct.Value_ListValue{
ListValue: &_struct.ListValue{
Values: matchItems,
},
},
}
ruleValues = append(ruleValues, &_struct.Value{
Kind: v,
})
}
if len(ruleValues) > 0 {
hasValidRule = true

View File

@@ -493,7 +493,7 @@ func (m *KIngressConfig) HasSynced() bool {
defer m.mutex.RUnlock()
for _, remoteIngressController := range m.remoteIngressControllers {
IngressLog.Info("In Kingress Synced.", remoteIngressController)
IngressLog.Info("In Kingress Synced.")
if !remoteIngressController.HasSynced() {
return false
}

View File

@@ -81,8 +81,6 @@ func (s *statusSyncer) runUpdateStatus() error {
return err
}
IngressLog.Debugf("found number %d of svc", len(svcList))
lbStatusList := common.GetLbStatusListV1Beta1(svcList)
if len(lbStatusList) == 0 {
return nil

View File

@@ -162,6 +162,7 @@ func (c *controller) onEvent(namespacedName types.NamespacedName) error {
delete(c.ingresses, namespacedName.String())
c.mutex.Unlock()
} else {
IngressLog.Warnf("ingressLister Get failed, ingress: %s, err: %v", namespacedName, err)
return err
}
}
@@ -171,7 +172,7 @@ func (c *controller) onEvent(namespacedName types.NamespacedName) error {
return nil
}
IngressLog.Debugf("ingress: %s, event: %s", namespacedName, event)
IngressLog.Infof("ingress: %s, event: %s", namespacedName, event)
// we should check need process only when event is not delete,
// if it is delete event, and previously processed, we need to process too.
@@ -181,7 +182,7 @@ func (c *controller) onEvent(namespacedName types.NamespacedName) error {
return err
}
if !shouldProcess {
IngressLog.Infof("no need process, ingress %s", namespacedName)
IngressLog.Infof("no need process, ingress: %s", namespacedName)
return nil
}
}
@@ -279,10 +280,17 @@ func (c *controller) List() []config.Config {
for _, raw := range c.ingressInformer.Informer.GetStore().List() {
ing, ok := raw.(*ingress.Ingress)
if !ok {
IngressLog.Warnf("get ingress from informer failed: %v", raw)
continue
}
if should, err := c.shouldProcessIngress(ing); !should || err != nil {
should, err := c.shouldProcessIngress(ing)
if err != nil {
IngressLog.Warnf("check should process ingress failed: %v", err)
continue
}
if !should {
IngressLog.Debugf("no need process ingress: %s/%s", ing.Namespace, ing.Name)
continue
}

View File

@@ -81,8 +81,6 @@ func (s *statusSyncer) runUpdateStatus() error {
return err
}
IngressLog.Debugf("found number %d of svc", len(svcList))
lbStatusList := common.GetLbStatusListV1(svcList)
if len(lbStatusList) == 0 {
return nil

View File

@@ -77,7 +77,6 @@ func (s *statusSyncer) runUpdateStatus() error {
return err
}
IngressLog.Debugf("found number %d of svc", len(svcList))
lbStatusList := common2.GetLbStatusList(svcList)
return s.updateStatus(lbStatusList)
}

View File

@@ -14,6 +14,6 @@
package log
import "istio.io/pkg/log"
import "istio.io/istio/pkg/log"
var IngressLog = log.RegisterScope("ingress", "Higress Ingress process.", 0)
var IngressLog = log.RegisterScope("ingress", "Higress Ingress process.")

View File

@@ -0,0 +1,68 @@
static_resources:
listeners:
- name: listener_0
address:
socket_address:
protocol: TCP
address: 0.0.0.0
port_value: 8080
filter_chains:
- filters:
- name: envoy.filters.network.http_connection_manager
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
stat_prefix: ingress_http
access_log:
- name: envoy.access_loggers.file
typed_config:
"@type": type.googleapis.com/envoy.extensions.access_loggers.file.v3.FileAccessLog
log_format:
text_format_source:
inline_string: "{\"custom_log\":\"%FILTER_STATE(wasm.custom_log:PLAIN)%\",\"ai_log\":\"%FILTER_STATE(wasm.ai_log:PLAIN)%\"}
"
path: /dev/stdout
route_config:
name: local_route
virtual_hosts:
- name: local_service
domains: ["*"]
routes:
- name: get
match:
prefix: "/get"
route:
cluster: httpbin
http_filters:
- name: test
typed_config:
"@type": type.googleapis.com/udpa.type.v1.TypedStruct
type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm
value:
config:
name: test
vm_config:
runtime: envoy.wasm.runtime.v8
code:
local:
filename: main.wasm
configuration:
"@type": "type.googleapis.com/google.protobuf.StringValue"
value: {}
- name: envoy.filters.http.router
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
clusters:
- name: httpbin
connect_timeout: 600s
type: STRICT_DNS
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: httpbin
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: httpbin.org
port_value: 80

View File

@@ -0,0 +1,20 @@
module github.com/alibaba/higress/plugins/wasm-go/extensions/custom-logs
go 1.18
replace github.com/alibaba/higress/plugins/wasm-go => ../..
require (
github.com/alibaba/higress/plugins/wasm-go v0.0.0
github.com/higress-group/proxy-wasm-go-sdk v1.0.0
)
require (
github.com/google/uuid v1.3.0 // indirect
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
github.com/magefile/mage v1.14.0 // indirect
github.com/tidwall/gjson v1.17.3 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/resp v0.1.1 // indirect
)

View File

@@ -0,0 +1,20 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -0,0 +1,67 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"math/rand"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)
func main() {
wrapper.SetCtx(
"custom-log",
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
)
}
type CustomLogConfig struct {
}
// Method 1: write custom log
func writeLog(ctx wrapper.HttpContext) {
ctx.SetUserAttribute("question", "当然可以。在Python中你可以创建一个函数来计算一系列数字的和。下面是一个简单的例子该函数接受一个数字列表作为输入并返回它们的总和。\n\n```python\ndef sum_of_numbers(numbers):\n \"\"\"\n 计算列表中所有数字的和。\n \n 参数:\n numbers (list of int or float): 一个包含数字的列表。\n \n 返回:\n int or float: 列表中所有数字的总和。\n \"\"\"\n total_sum = sum(numbers) # 使用Python内置的sum函数计算总和\n return total_sum\n\n# 示例使用\nnumbers_list = [1, 2, 3, 4, 5]\nprint(\"The sum is:\", sum_of_numbers(numbers_list)) # 输出The sum is: 15\n```\n\n在这段代码中我们定义了一个名为 `sum_of_numbers` 的函数,它接收一个参数 `numbers`这是一个包含整数或浮点数的列表。函数内部使用了Python的内置函数 `sum()` 来计算这些数字的总和,并将结果返回。\n\n你也可以手动实现求和逻辑而不是使用内置的 `sum()` 函数,如下所示:\n\n```python\ndef sum_of_numbers_manual(numbers):\n \"\"\"\n 手动计算列表中所有数字的和。\n \n 参数:\n numbers (list of int or float): 一个包含数字的列表。\n \n 返回:\n int or float: 列表中所有数字的总和。\n \"\"\"\n total_sum = 0\n for number in numbers:\n total_sum += number\n return total_sum\n\n# 示例使用\nnumbers_list = [1, 2, 3, 4, 5]\nprint(\"The sum is:\", sum_of_numbers_manual(numbers_list)) # 输出The sum is: 15\n```\n\n在这个版本中我们初始化 `total_sum` 为0然后遍历列表中的每个元素并将其加到 `total_sum` 上。最后返回这个累加的结果。这两种方法都可以达到相同的目的,但是使用内置函数通常更简洁且效率更高。")
ctx.SetUserAttribute("k2", 2213.22)
ctx.WriteUserAttributeToLog()
}
// Methods 2: write custom log with specific key
func writeLogWithKey(ctx wrapper.HttpContext, key string) {
ctx.SetUserAttribute("k2", 2213.22)
_ = ctx.WriteUserAttributeToLogWithKey(key)
ctx.SetUserAttribute("k2", 212939.22)
ctx.SetUserAttribute("k3", 123)
_ = ctx.WriteUserAttributeToLogWithKey(key)
}
// Methods 2: write custom log with specific key
func writeTraceAttribute(ctx wrapper.HttpContext) {
ctx.SetUserAttribute("question", "当然可以。在Python中你可以创建一个函数来计算一系列数字的和。下面是一个简单的例子该函数接受一个数字列表作为输入并返回它们的总和。\n\n```python\ndef sum_of_numbers(numbers):\n \"\"\"\n 计算列表中所有数字的和。\n \n 参数:\n numbers (list of int or float): 一个包含数字的列表。\n \n 返回:\n int or float: 列表中所有数字的总和。\n \"\"\"\n total_sum = sum(numbers) # 使用Python内置的sum函数计算总和\n return total_sum\n\n# 示例使用\nnumbers_list = [1, 2, 3, 4, 5]\nprint(\"The sum is:\", sum_of_numbers(numbers_list)) # 输出The sum is: 15\n```\n\n在这段代码中我们定义了一个名为 `sum_of_numbers` 的函数,它接收一个参数 `numbers`这是一个包含整数或浮点数的列表。函数内部使用了Python的内置函数 `sum()` 来计算这些数字的总和,并将结果返回。\n\n你也可以手动实现求和逻辑而不是使用内置的 `sum()` 函数,如下所示:\n\n```python\ndef sum_of_numbers_manual(numbers):\n \"\"\"\n 手动计算列表中所有数字的和。\n \n 参数:\n numbers (list of int or float): 一个包含数字的列表。\n \n 返回:\n int or float: 列表中所有数字的总和。\n \"\"\"\n total_sum = 0\n for number in numbers:\n total_sum += number\n return total_sum\n\n# 示例使用\nnumbers_list = [1, 2, 3, 4, 5]\nprint(\"The sum is:\", sum_of_numbers_manual(numbers_list)) # 输出The sum is: 15\n```\n\n在这个版本中我们初始化 `total_sum` 为0然后遍历列表中的每个元素并将其加到 `total_sum` 上。最后返回这个累加的结果。这两种方法都可以达到相同的目的,但是使用内置函数通常更简洁且效率更高。")
ctx.SetUserAttribute("k2", 2213.22)
ctx.WriteUserAttributeToTrace()
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config CustomLogConfig, log wrapper.Log) types.Action {
if rand.Intn(10)%3 == 1 {
writeLog(ctx)
} else if rand.Intn(10)%3 == 2 {
writeLogWithKey(ctx, "ai_log")
} else {
writeTraceAttribute(ctx)
}
return types.ActionContinue
}

View File

@@ -148,7 +148,15 @@ Groq 所对应的 `type` 为 `groq`。它并无特有的配置字段。
#### 文心一言Baidu
文心一言所对应的 `type``baidu`。它并无特有的配置字段
文心一言所对应的 `type``baidu`。它特有的配置字段如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|--------------------|-----------------|------|-----|-----------------------------------------------------------|
| `baiduAccessKeyAndSecret` | array of string | 必填 | - | Baidu 的 Access Key 和 Secret Key中间用 `:` 分隔,用于申请 apiToken。 |
| `baiduApiTokenServiceName` | string | 必填 | - | 请求刷新百度 apiToken 服务名称。 |
| `baiduApiTokenServiceHost` | string | 非必填 | - | 请求刷新百度 apiToken 服务域名,默认是 iam.bj.baidubce.com。 |
| `baiduApiTokenServicePort` | int64 | 非必填 | - | 请求刷新百度 apiToken 服务端口,默认是 443。 |
#### 360智脑

View File

@@ -86,6 +86,11 @@ func (c *PluginConfig) Complete(log wrapper.Log) error {
providerConfig := c.GetProviderConfig()
err = providerConfig.SetApiTokensFailover(log, c.activeProvider)
if handler, ok := c.activeProvider.(provider.TickFuncHandler); ok {
tickPeriod, tickFunc := handler.GetTickFunc(log)
wrapper.RegisteTickFunc(tickPeriod, tickFunc)
}
return err
}

View File

@@ -8,7 +8,7 @@ replace github.com/alibaba/higress/plugins/wasm-go => ../..
require (
github.com/alibaba/higress/plugins/wasm-go v0.0.0
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
github.com/higress-group/proxy-wasm-go-sdk v1.0.0
github.com/stretchr/testify v1.8.4
github.com/tidwall/gjson v1.17.3
)

View File

@@ -4,8 +4,8 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=

View File

@@ -114,7 +114,7 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
return action
}
_ = util.SendResponse(500, "ai-proxy.proc_req_headers_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process request headers: %v", err))
util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
return types.ActionContinue
}
@@ -136,7 +136,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
newBody, settingErr := pluginConfig.GetProviderConfig().ReplaceByCustomSettings(body)
if settingErr != nil {
_ = util.SendResponse(500, "ai-proxy.proc_req_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to rewrite request body by custom settings: %v", settingErr))
util.ErrorHandler(
"ai-proxy.proc_req_body_failed",
fmt.Errorf("failed to replace request body by custom settings: %v", settingErr),
)
return types.ActionContinue
}
@@ -146,7 +149,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
if err == nil {
return action
}
_ = util.SendResponse(500, "ai-proxy.proc_req_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process request body: %v", err))
util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
}
return types.ActionContinue
}
@@ -190,14 +193,14 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
action, err := handler.OnResponseHeaders(ctx, apiName, log)
if err == nil {
checkStream(&ctx, &log)
checkStream(&ctx, log)
return action
}
_ = util.SendResponse(500, "ai-proxy.proc_resp_headers_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process response headers: %v", err))
util.ErrorHandler("ai-proxy.proc_resp_headers_failed", fmt.Errorf("failed to process response headers: %v", err))
return types.ActionContinue
}
checkStream(&ctx, &log)
checkStream(&ctx, log)
_, needHandleBody := activeProvider.(provider.ResponseBodyHandler)
_, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
if !needHandleBody && !needHandleStreamingBody {
@@ -248,13 +251,13 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
if err == nil {
return action
}
_ = util.SendResponse(500, "ai-proxy.proc_resp_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process response body: %v", err))
util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
return types.ActionContinue
}
return types.ActionContinue
}
func checkStream(ctx *wrapper.HttpContext, log *wrapper.Log) {
func checkStream(ctx *wrapper.HttpContext, log wrapper.Log) {
contentType, err := proxywasm.GetHttpResponseHeader("Content-Type")
if err != nil || !strings.HasPrefix(contentType, "text/event-stream") {
if err != nil {

View File

@@ -69,7 +69,22 @@ func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
}
func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, m.serviceUrl.RequestURI())
u, e := url.Parse(ctx.Path())
if e == nil {
customApiVersion := u.Query().Get("api-version")
if customApiVersion == "" {
util.OverwriteRequestPathHeader(headers, m.serviceUrl.RequestURI())
} else {
q := m.serviceUrl.Query()
q.Set("api-version", customApiVersion)
newUrl := *m.serviceUrl
newUrl.RawQuery = q.Encode()
util.OverwriteRequestPathHeader(headers, newUrl.RequestURI())
}
} else {
log.Errorf("failed to parse request path: %v", e)
util.OverwriteRequestPathHeader(headers, m.serviceUrl.RequestURI())
}
util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host)
util.OverwriteRequestAuthorizationHeader(headers, "api-key "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")

View File

@@ -1,48 +1,53 @@
package provider
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"sort"
"strings"
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
// baiduProvider is the provider for baidu ernie bot service.
// baiduProvider is the provider for baidu service.
const (
baiduDomain = "aip.baidubce.com"
baiduChatCompletionPath = "/chat"
baiduDomain = "qianfan.baidubce.com"
baiduChatCompletionPath = "/v2/chat/completions"
baiduApiTokenDomain = "iam.bj.baidubce.com"
baiduApiTokenPort = 443
baiduApiTokenPath = "/v1/BCE-BEARER/token"
// refresh apiToken every 1 hour
baiduApiTokenRefreshInterval = 3600
// authorizationString expires in 30 minutes, authorizationString is used to generate apiToken
// the default expiration time of apiToken is 24 hours
baiduAuthorizationStringExpirationSeconds = 1800
bce_prefix = "x-bce-"
)
var baiduModelToPathSuffixMap = map[string]string{
"ERNIE-4.0-8K": "completions_pro",
"ERNIE-3.5-8K": "completions",
"ERNIE-3.5-128K": "ernie-3.5-128k",
"ERNIE-Speed-8K": "ernie_speed",
"ERNIE-Speed-128K": "ernie-speed-128k",
"ERNIE-Tiny-8K": "ernie-tiny-8k",
"ERNIE-Bot-8K": "ernie_bot_8k",
"BLOOMZ-7B": "bloomz_7b1",
}
type baiduProviderInitializer struct{}
type baiduProviderInitializer struct {
}
func (b *baiduProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
func (g *baiduProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.baiduAccessKeyAndSecret == nil || len(config.baiduAccessKeyAndSecret) == 0 {
return errors.New("no baiduAccessKeyAndSecret found in provider config")
}
if config.baiduApiTokenServiceName == "" {
return errors.New("no baiduApiTokenServiceName found in provider config")
}
if !config.failover.enabled {
config.useGlobalApiToken = true
}
return nil
}
func (b *baiduProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
func (g *baiduProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
return &baiduProvider{
config: config,
contextCache: createContextCache(&config),
@@ -54,234 +59,235 @@ type baiduProvider struct {
contextCache *contextCache
}
func (b *baiduProvider) GetProviderType() string {
func (g *baiduProvider) GetProviderType() string {
return providerTypeBaidu
}
func (b *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
b.config.handleRequestHeaders(b, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
func (b *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, baiduDomain)
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
}
func (b *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
return b.config.handleRequestBody(b, b.contextCache, ctx, apiName, body, log)
}
func (b *baiduProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
err := b.config.parseRequestAndMapModel(ctx, request, body, log)
if err != nil {
return nil, err
}
path := b.getRequestPath(ctx, request.Model)
util.OverwriteRequestPathHeader(headers, path)
baiduRequest := b.baiduTextGenRequest(request)
return json.Marshal(baiduRequest)
}
func (b *baiduProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
// 使用文心一言接口协议,跳过OnStreamingResponseBody()和OnResponseBody()
if b.config.protocol == protocolOriginal {
ctx.DontReadResponseBody()
return types.ActionContinue, nil
}
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
g.config.handleRequestHeaders(g, ctx, apiName, log)
return types.ActionContinue, nil
}
func (b *baiduProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
if isLastChunk || len(chunk) == 0 {
return nil, nil
func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
// sample event response:
// data: {"id":"as-vb0m37ti8y","object":"chat.completion","created":1709089502,"sentence_id":0,"is_end":false,"is_truncated":false,"result":"当然可以,","need_clear_history":false,"finish_reason":"normal","usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}
// sample end event response:
// data: {"id":"as-vb0m37ti8y","object":"chat.completion","created":1709089531,"sentence_id":20,"is_end":true,"is_truncated":false,"result":"","need_clear_history":false,"finish_reason":"normal","usage":{"prompt_tokens":5,"completion_tokens":420,"total_tokens":425}}
responseBuilder := &strings.Builder{}
lines := strings.Split(string(chunk), "\n")
for _, data := range lines {
if len(data) < 6 {
// ignore blank line or wrong format
continue
}
data = data[6:]
var baiduResponse baiduTextGenStreamResponse
if err := json.Unmarshal([]byte(data), &baiduResponse); err != nil {
log.Errorf("unable to unmarshal baidu response: %v", err)
continue
}
response := b.streamResponseBaidu2OpenAI(ctx, &baiduResponse)
responseBody, err := json.Marshal(response)
if err != nil {
log.Errorf("unable to marshal response: %v", err)
return nil, err
}
b.appendResponse(responseBuilder, string(responseBody))
}
modifiedResponseChunk := responseBuilder.String()
log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
return []byte(modifiedResponseChunk), nil
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
}
func (b *baiduProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
baiduResponse := &baiduTextGenResponse{}
if err := json.Unmarshal(body, baiduResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal baidu response: %v", err)
}
if baiduResponse.ErrorMsg != "" {
return types.ActionContinue, fmt.Errorf("baidu response error, error_code: %d, error_message: %s", baiduResponse.ErrorCode, baiduResponse.ErrorMsg)
}
response := b.responseBaidu2OpenAI(ctx, baiduResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log)
func (g *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, baiduChatCompletionPath)
util.OverwriteRequestHostHeader(headers, baiduDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}
type baiduTextGenRequest struct {
Model string `json:"model"`
Messages []chatMessage `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
PenaltyScore float64 `json:"penalty_score,omitempty"`
Stream bool `json:"stream,omitempty"`
System string `json:"system,omitempty"`
DisableSearch bool `json:"disable_search,omitempty"`
EnableCitation bool `json:"enable_citation,omitempty"`
MaxOutputTokens int `json:"max_output_tokens,omitempty"`
UserId string `json:"user_id,omitempty"`
}
func (b *baiduProvider) getRequestPath(ctx wrapper.HttpContext, baiduModel string) string {
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
suffix, ok := baiduModelToPathSuffixMap[baiduModel]
if !ok {
suffix = baiduModel
}
return fmt.Sprintf("/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/%s?access_token=%s", suffix, b.config.GetApiTokenInUse(ctx))
}
func (b *baiduProvider) setSystemContent(request *baiduTextGenRequest, content string) {
request.System = content
}
func (b *baiduProvider) baiduTextGenRequest(request *chatCompletionRequest) *baiduTextGenRequest {
baiduRequest := baiduTextGenRequest{
Messages: make([]chatMessage, 0, len(request.Messages)),
Temperature: request.Temperature,
TopP: request.TopP,
PenaltyScore: request.FrequencyPenalty,
Stream: request.Stream,
DisableSearch: false,
EnableCitation: false,
MaxOutputTokens: request.MaxTokens,
UserId: request.User,
}
for _, message := range request.Messages {
if message.Role == roleSystem {
baiduRequest.System = message.StringContent()
} else {
baiduRequest.Messages = append(baiduRequest.Messages, chatMessage{
Role: message.Role,
Content: message.Content,
})
}
}
return &baiduRequest
}
type baiduTextGenResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Result string `json:"result"`
IsTruncated bool `json:"is_truncated"`
NeedClearHistory bool `json:"need_clear_history"`
Usage baiduTextGenResponseUsage `json:"usage"`
baiduTextGenResponseError
}
type baiduTextGenResponseError struct {
ErrorCode int `json:"error_code"`
ErrorMsg string `json:"error_msg"`
}
type baiduTextGenStreamResponse struct {
baiduTextGenResponse
SentenceId int `json:"sentence_id"`
IsEnd bool `json:"is_end"`
}
type baiduTextGenResponseUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
func (b *baiduProvider) responseBaidu2OpenAI(ctx wrapper.HttpContext, response *baiduTextGenResponse) *chatCompletionResponse {
choice := chatCompletionChoice{
Index: 0,
Message: &chatMessage{Role: roleAssistant, Content: response.Result},
FinishReason: finishReasonStop,
}
return &chatCompletionResponse{
Id: response.Id,
Created: time.Now().UnixMilli() / 1000,
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
SystemFingerprint: "",
Object: objectChatCompletion,
Choices: []chatCompletionChoice{choice},
Usage: usage{
PromptTokens: response.Usage.PromptTokens,
CompletionTokens: response.Usage.CompletionTokens,
TotalTokens: response.Usage.TotalTokens,
},
}
}
func (b *baiduProvider) streamResponseBaidu2OpenAI(ctx wrapper.HttpContext, response *baiduTextGenStreamResponse) *chatCompletionResponse {
choice := chatCompletionChoice{
Index: 0,
Message: &chatMessage{Role: roleAssistant, Content: response.Result},
}
if response.IsEnd {
choice.FinishReason = finishReasonStop
}
return &chatCompletionResponse{
Id: response.Id,
Created: time.Now().UnixMilli() / 1000,
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
SystemFingerprint: "",
Object: objectChatCompletionChunk,
Choices: []chatCompletionChoice{choice},
Usage: usage{
PromptTokens: response.Usage.PromptTokens,
CompletionTokens: response.Usage.CompletionTokens,
TotalTokens: response.Usage.TotalTokens,
},
}
}
func (b *baiduProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}
func (b *baiduProvider) GetApiName(path string) ApiName {
func (g *baiduProvider) GetApiName(path string) ApiName {
if strings.Contains(path, baiduChatCompletionPath) {
return ApiNameChatCompletion
}
return ""
}
func generateAuthorizationString(accessKeyAndSecret string, expirationInSeconds int) string {
c := strings.Split(accessKeyAndSecret, ":")
credentials := BceCredentials{
AccessKeyId: c[0],
SecretAccessKey: c[1],
}
httpMethod := "GET"
path := baiduApiTokenPath
headers := map[string]string{"host": baiduApiTokenDomain}
timestamp := time.Now().Unix()
headersToSign := make([]string, 0, len(headers))
for k := range headers {
headersToSign = append(headersToSign, k)
}
return sign(credentials, httpMethod, path, headers, timestamp, expirationInSeconds, headersToSign)
}
// BceCredentials holds the access key and secret key
type BceCredentials struct {
AccessKeyId string
SecretAccessKey string
}
// normalizeString performs URI encoding according to RFC 3986
func normalizeString(inStr string, encodingSlash bool) string {
if inStr == "" {
return ""
}
var result strings.Builder
for _, ch := range []byte(inStr) {
if (ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z') ||
(ch >= '0' && ch <= '9') || ch == '.' || ch == '-' ||
ch == '_' || ch == '~' || (!encodingSlash && ch == '/') {
result.WriteByte(ch)
} else {
result.WriteString(fmt.Sprintf("%%%02X", ch))
}
}
return result.String()
}
// getCanonicalTime generates a timestamp in UTC format
func getCanonicalTime(timestamp int64) string {
if timestamp == 0 {
timestamp = time.Now().Unix()
}
t := time.Unix(timestamp, 0).UTC()
return t.Format("2006-01-02T15:04:05Z")
}
// getCanonicalUri generates a canonical URI
func getCanonicalUri(path string) string {
return normalizeString(path, false)
}
// getCanonicalHeaders generates canonical headers
func getCanonicalHeaders(headers map[string]string, headersToSign []string) string {
if len(headers) == 0 {
return ""
}
// If headersToSign is not specified, use default headers
if len(headersToSign) == 0 {
headersToSign = []string{"host", "content-md5", "content-length", "content-type"}
}
// Convert headersToSign to a map for easier lookup
headerMap := make(map[string]bool)
for _, header := range headersToSign {
headerMap[strings.ToLower(strings.TrimSpace(header))] = true
}
// Create a slice to hold the canonical headers
var canonicalHeaders []string
for k, v := range headers {
k = strings.ToLower(strings.TrimSpace(k))
v = strings.TrimSpace(v)
// Add headers that start with x-bce- or are in headersToSign
if strings.HasPrefix(k, bce_prefix) || headerMap[k] {
canonicalHeaders = append(canonicalHeaders,
fmt.Sprintf("%s:%s", normalizeString(k, true), normalizeString(v, true)))
}
}
// Sort the canonical headers
sort.Strings(canonicalHeaders)
return strings.Join(canonicalHeaders, "\n")
}
// sign generates the authorization string
func sign(credentials BceCredentials, httpMethod, path string, headers map[string]string,
timestamp int64, expirationInSeconds int,
headersToSign []string) string {
// Generate sign key
signKeyInfo := fmt.Sprintf("bce-auth-v1/%s/%s/%d",
credentials.AccessKeyId,
getCanonicalTime(timestamp),
expirationInSeconds)
// Generate sign key using HMAC-SHA256
h := hmac.New(sha256.New, []byte(credentials.SecretAccessKey))
h.Write([]byte(signKeyInfo))
signKey := hex.EncodeToString(h.Sum(nil))
// Generate canonical URI
canonicalUri := getCanonicalUri(path)
// Generate canonical headers
canonicalHeaders := getCanonicalHeaders(headers, headersToSign)
// Generate string to sign
stringToSign := strings.Join([]string{
httpMethod,
canonicalUri,
"",
canonicalHeaders,
}, "\n")
// Calculate final signature
h = hmac.New(sha256.New, []byte(signKey))
h.Write([]byte(stringToSign))
signature := hex.EncodeToString(h.Sum(nil))
// Generate final authorization string
if len(headersToSign) > 0 {
return fmt.Sprintf("%s/%s/%s", signKeyInfo, strings.Join(headersToSign, ";"), signature)
}
return fmt.Sprintf("%s//%s", signKeyInfo, signature)
}
// GetTickFunc Refresh apiToken (apiToken) periodically, the maximum apiToken expiration time is 24 hours
func (g *baiduProvider) GetTickFunc(log wrapper.Log) (tickPeriod int64, tickFunc func()) {
vmID := generateVMID()
return baiduApiTokenRefreshInterval * 1000, func() {
// Only the Wasm VM that successfully acquires the lease will refresh the apiToken
if g.config.tryAcquireOrRenewLease(vmID, log) {
log.Debugf("Successfully acquired or renewed lease for baidu apiToken refresh task, vmID: %v", vmID)
// Get the apiToken that is about to expire, will be removed after the new apiToken is obtained
oldApiTokens, _, err := getApiTokens(g.config.failover.ctxApiTokens)
if err != nil {
log.Errorf("Get old apiToken failed: %v", err)
return
}
log.Debugf("Old apiTokens: %v", oldApiTokens)
for _, accessKeyAndSecret := range g.config.baiduAccessKeyAndSecret {
authorizationString := generateAuthorizationString(accessKeyAndSecret, baiduAuthorizationStringExpirationSeconds)
log.Debugf("Generate authorizationString: %v", authorizationString)
g.generateNewApiToken(authorizationString, log)
}
// remove old old apiToken
for _, token := range oldApiTokens {
log.Debugf("Remove old apiToken: %v", token)
removeApiToken(g.config.failover.ctxApiTokens, token, log)
}
}
}
}
func (g *baiduProvider) generateNewApiToken(authorizationString string, log wrapper.Log) {
client := wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: g.config.baiduApiTokenServiceName,
Host: g.config.baiduApiTokenServiceHost,
Port: g.config.baiduApiTokenServicePort,
})
headers := [][2]string{
{"content-type", "application/json"},
{"Authorization", authorizationString},
}
var apiToken string
err := client.Get(baiduApiTokenPath, headers, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
if statusCode == 201 {
var response map[string]interface{}
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Errorf("Unmarshal response failed: %v", err)
} else {
apiToken = response["token"].(string)
addApiToken(g.config.failover.ctxApiTokens, apiToken, log)
}
} else {
log.Errorf("Get apiToken failed, status code: %d, response body: %s", statusCode, string(responseBody))
}
}, 30000)
if err != nil {
log.Errorf("Get apiToken failed: %v", err)
}
}

View File

@@ -139,7 +139,7 @@ func insertContext(provider Provider, content string, err error, body []byte, lo
typ := provider.GetProviderType()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.load_ctx_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.load_ctx_failed", typ), fmt.Errorf("failed to load context file: %v", err))
}
if inserter, ok := provider.(ContextInserter); ok {
@@ -149,10 +149,10 @@ func insertContext(provider Provider, content string, err error, body []byte, lo
}
if err != nil {
_ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.insert_ctx_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to insert context message: %v", err))
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.insert_ctx_failed", typ), fmt.Errorf("failed to insert context message: %v", err))
}
if err := replaceHttpJsonRequestBody(body, log); err != nil {
_ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.replace_request_body_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.replace_request_body_failed", typ), fmt.Errorf("failed to replace request body: %v", err))
}
}

View File

@@ -467,7 +467,7 @@ func (c *ProviderConfig) ResetApiTokenRequestFailureCount(apiTokenInUse string,
log.Errorf("failed to get failureApiTokenRequestCount: %v", err)
}
if _, ok := failureApiTokenRequestCount[apiTokenInUse]; ok {
log.Infof("reset apiToken %s request failure count", apiTokenInUse)
log.Infof("Reset apiToken %s request failure count", apiTokenInUse)
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiTokenInUse, log)
}
}
@@ -489,7 +489,7 @@ func modifyApiTokenRequestCount(key, apiToken string, op string, log wrapper.Log
apiTokenRequestCountByte, err := json.Marshal(apiTokenRequestCount)
if err != nil {
log.Errorf("failed to marshal apiTokenRequestCount: %v", err)
log.Errorf("Failed to marshal apiTokenRequestCount: %v", err)
}
if err := proxywasm.SetSharedData(key, apiTokenRequestCountByte, cas); err == nil {
@@ -551,7 +551,7 @@ func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) {
var apiToken string
if c.isFailoverEnabled() {
if c.isFailoverEnabled() || c.useGlobalApiToken {
// if enable apiToken failover, only use available apiToken
apiToken = c.GetGlobalRandomToken(log)
} else {

View File

@@ -172,7 +172,7 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.hunyuan.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
util.ErrorHandler("ai-proxy.hunyuan.load_ctx_failed", fmt.Errorf("failed to load context file: %v", err))
}
m.insertContextMessageIntoHunyuanRequest(request, content)
@@ -182,7 +182,7 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
_ = util.OverwriteRequestAuthorization(authorizedValueNew)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.hunyuan.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
util.ErrorHandler("ai-proxy.hunyuan.insert_ctx_failed", fmt.Errorf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
@@ -244,7 +244,7 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.hunyuan.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
util.ErrorHandler("ai-proxy.hunyuan.load_ctx_failed", fmt.Errorf("failed to load context file: %v", err))
return
}
insertContextMessage(request, content)
@@ -256,7 +256,7 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
_ = util.OverwriteRequestAuthorization(authorizedValueNew)
if err := replaceJsonRequestBody(hunyuanRequest, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.hunyuan.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
util.ErrorHandler("ai-proxy.hunyuan.insert_ctx_failed", fmt.Errorf("failed to replace request body: %v", err))
}
}, log)
if err == nil {

View File

@@ -141,14 +141,14 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.minimax.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
util.ErrorHandler("ai-proxy.minimax.load_ctx_failed", fmt.Errorf("failed to load context file: %v", err))
}
// 由于 minimaxChatCompletionV2格式和 OpenAI 一致)和 minimaxChatCompletionPro格式和 OpenAI 不一致)中 insertHttpContextMessage 的逻辑不同,无法做到同一个 provider 统一
// 因此对于 minimaxChatCompletionPro 需要手动处理 context 消息
// minimaxChatCompletionV2 交给默认的 defaultInsertHttpContextMessage 方法插入 context 消息
minimaxRequest := m.buildMinimaxChatCompletionV2Request(request, content)
if err := replaceJsonRequestBody(minimaxRequest, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.minimax.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace Request body: %v", err))
util.ErrorHandler("ai-proxy.minimax.insert_ctx_failed", fmt.Errorf("failed to replace Request body: %v", err))
}
}, log)
if err == nil {

View File

@@ -3,12 +3,15 @@ package provider
import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"net/http"
"github.com/tidwall/sjson"
)
// moonshotProvider is the provider for Moonshot AI service.
@@ -91,12 +94,12 @@ func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.moonshot.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
util.ErrorHandler("ai-proxy.moonshot.load_ctx_failed", fmt.Errorf("failed to load context file: %v", err))
return
}
err = m.performChatCompletion(ctx, content, request, log)
if err != nil {
_ = util.SendResponse(500, "ai-proxy.moonshot.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to perform chat completion: %v", err))
util.ErrorHandler("ai-proxy.moonshot.insert_ctx_failed", fmt.Errorf("failed to perform chat completion: %v", err))
}
}, log)
if err == nil {
@@ -149,3 +152,99 @@ func (m *moonshotProvider) sendRequest(method, path, body, apiKey string, callba
return errors.New("unsupported method: " + method)
}
}
func (m *moonshotProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
receivedBody := chunk
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
receivedBody = append(bufferedStreamingBody, chunk...)
}
eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1
defer func() {
if eventStartIndex >= 0 && eventStartIndex < len(receivedBody) {
// Just in case the received chunk is not a complete event.
ctx.SetContext(ctxKeyStreamingBody, receivedBody[eventStartIndex:])
} else {
ctx.SetContext(ctxKeyStreamingBody, nil)
}
}()
var responseBuilder strings.Builder
currentKey := ""
currentEvent := &streamEvent{}
i, length := 0, len(receivedBody)
for i = 0; i < length; i++ {
ch := receivedBody[i]
if ch != '\n' {
if lineStartIndex == -1 {
if eventStartIndex == -1 {
eventStartIndex = i
}
lineStartIndex = i
valueStartIndex = -1
}
if valueStartIndex == -1 {
if ch == ':' {
valueStartIndex = i + 1
currentKey = string(receivedBody[lineStartIndex:valueStartIndex])
}
} else if valueStartIndex == i && ch == ' ' {
// Skip leading spaces in data.
valueStartIndex = i + 1
}
continue
}
if lineStartIndex != -1 {
value := string(receivedBody[valueStartIndex:i])
currentEvent.setValue(currentKey, value)
} else {
// Extra new line. The current event is complete.
log.Debugf("processing event: %v", currentEvent)
m.convertStreamEvent(&responseBuilder, currentEvent, log)
// Reset event parsing state.
eventStartIndex = -1
currentEvent = &streamEvent{}
}
// Reset line parsing state.
lineStartIndex = -1
valueStartIndex = -1
currentKey = ""
}
modifiedResponseChunk := responseBuilder.String()
log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
return []byte(modifiedResponseChunk), nil
}
func (m *moonshotProvider) convertStreamEvent(responseBuilder *strings.Builder, event *streamEvent, log wrapper.Log) error {
if event.Data == streamEndDataValue {
m.appendStreamEvent(responseBuilder, event)
return nil
}
if gjson.Get(event.Data, "choices.0.usage").Exists() {
usageStr := gjson.Get(event.Data, "choices.0.usage").Raw
newData, err := sjson.Delete(event.Data, "choices.0.usage")
if err != nil {
log.Errorf("convert usage event error: %v", err)
return err
}
newData, err = sjson.SetRaw(newData, "usage", usageStr)
if err != nil {
log.Errorf("convert usage event error: %v", err)
return err
}
event.Data = newData
}
m.appendStreamEvent(responseBuilder, event)
return nil
}
func (m *moonshotProvider) appendStreamEvent(responseBuilder *strings.Builder, event *streamEvent) {
responseBuilder.WriteString(streamDataItemKey)
responseBuilder.WriteString(event.Data)
responseBuilder.WriteString("\n\n")
}

View File

@@ -134,7 +134,7 @@ type TransformRequestBodyHandler interface {
}
// TransformRequestBodyHeadersHandler allows to transform request headers based on the request body.
// Some providers (e.g. baidu, gemini) transform request headers (e.g., path) based on the request body (e.g., model).
// Some providers (e.g. gemini) transform request headers (e.g., path) based on the request body (e.g., model).
type TransformRequestBodyHeadersHandler interface {
TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error)
}
@@ -151,6 +151,12 @@ type ResponseBodyHandler interface {
OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
}
// TickFuncHandler allows the provider to execute a function periodically
// Use case: the maximum expiration time of baidu apiToken is 24 hours, need to refresh periodically
type TickFuncHandler interface {
GetTickFunc(log wrapper.Log) (tickPeriod int64, tickFunc func())
}
type ProviderConfig struct {
// @Title zh-CN ID
// @Description zh-CN AI服务提供商标识
@@ -182,6 +188,9 @@ type ProviderConfig struct {
// @Title zh-CN 启用通义千问搜索服务
// @Description zh-CN 仅适用于通义千问服务,表示是否启用通义千问的互联网搜索功能。
qwenEnableSearch bool `required:"false" yaml:"qwenEnableSearch" json:"qwenEnableSearch"`
// @Title zh-CN 通义千问服务域名
// @Description zh-CN 仅适用于通义千问服务,默认转发域名为 dashscope.aliyuncs.com, 当使用金融云服务时,可以设置为 dashscope-finance.aliyuncs.com
qwenDomain string `required:"false" yaml:"qwenDomain" json:"qwenDomain"`
// @Title zh-CN 开启通义千问兼容模式
// @Description zh-CN 启用通义千问兼容模式后,将调用千问的兼容模式接口,同时对请求/响应不做修改。
qwenEnableCompatible bool `required:"false" yaml:"qwenEnableCompatible" json:"qwenEnableCompatible"`
@@ -227,6 +236,17 @@ type ProviderConfig struct {
// @Title zh-CN 自定义大模型参数配置
// @Description zh-CN 用于填充或者覆盖大模型调用时的参数
customSettings []CustomSetting
// @Title zh-CN Baidu 的 Access Key 和 Secret Key中间用 : 分隔,用于申请 apiToken
baiduAccessKeyAndSecret []string `required:"false" yaml:"baiduAccessKeyAndSecret" json:"baiduAccessKeyAndSecret"`
// @Title zh-CN 请求刷新百度 apiToken 服务名称
baiduApiTokenServiceName string `required:"false" yaml:"baiduApiTokenServiceName" json:"baiduApiTokenServiceName"`
// @Title zh-CN 请求刷新百度 apiToken 服务域名
baiduApiTokenServiceHost string `required:"false" yaml:"baiduApiTokenServiceHost" json:"baiduApiTokenServiceHost"`
// @Title zh-CN 请求刷新百度 apiToken 服务端口
baiduApiTokenServicePort int64 `required:"false" yaml:"baiduApiTokenServicePort" json:"baiduApiTokenServicePort"`
// @Title zh-CN 是否使用全局的 apiToken
// @Description zh-CN 如果没有启用 apiToken failover但是 apiToken 的状态又需要在多个 Wasm VM 中同步时需要将该参数设置为 true例如 Baidu 的 apiToken 需要定时刷新
useGlobalApiToken bool `required:"false" yaml:"useGlobalApiToken" json:"useGlobalApiToken"`
}
func (c *ProviderConfig) GetId() string {
@@ -261,6 +281,10 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
}
c.qwenEnableSearch = json.Get("qwenEnableSearch").Bool()
c.qwenEnableCompatible = json.Get("qwenEnableCompatible").Bool()
c.qwenDomain = json.Get("qwenDomain").String()
if c.qwenDomain != "" {
// TODO: validate the domain, if not valid, set to default
}
c.ollamaServerHost = json.Get("ollamaServerHost").String()
c.ollamaServerPort = uint32(json.Get("ollamaServerPort").Uint())
c.modelMapping = make(map[string]string)
@@ -321,6 +345,19 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
if failoverJson.Exists() {
c.failover.FromJson(failoverJson)
}
for _, accessKeyAndSecret := range json.Get("baiduAccessKeyAndSecret").Array() {
c.baiduAccessKeyAndSecret = append(c.baiduAccessKeyAndSecret, accessKeyAndSecret.String())
}
c.baiduApiTokenServiceName = json.Get("baiduApiTokenServiceName").String()
c.baiduApiTokenServiceHost = json.Get("baiduApiTokenServiceHost").String()
if c.baiduApiTokenServiceHost == "" {
c.baiduApiTokenServiceHost = baiduApiTokenDomain
}
c.baiduApiTokenServicePort = json.Get("baiduApiTokenServicePort").Int()
if c.baiduApiTokenServicePort == 0 {
c.baiduApiTokenServicePort = baiduApiTokenPort
}
}
func (c *ProviderConfig) Validate() error {

View File

@@ -23,7 +23,7 @@ import (
const (
qwenResultFormatMessage = "message"
qwenDomain = "dashscope.aliyuncs.com"
qwenDefaultDomain = "dashscope.aliyuncs.com"
qwenChatCompletionPath = "/api/v1/services/aigc/text-generation/generation"
qwenTextEmbeddingPath = "/api/v1/services/embeddings/text-embedding/text-embedding"
qwenCompatiblePath = "/compatible-mode/v1/chat/completions"
@@ -64,7 +64,11 @@ type qwenProvider struct {
}
func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, qwenDomain)
if m.config.qwenDomain != "" {
util.OverwriteRequestHostHeader(headers, m.config.qwenDomain)
} else {
util.OverwriteRequestHostHeader(headers, qwenDefaultDomain)
}
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
if m.config.qwenEnableCompatible {
@@ -158,11 +162,11 @@ func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body
streaming := request.Stream
if streaming {
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
_ = proxywasm.ReplaceHttpRequestHeader("X-DashScope-SSE", "enable")
headers.Set("Accept", "text/event-stream")
headers.Set("X-DashScope-SSE", "enable")
} else {
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "*/*")
_ = proxywasm.RemoveHttpRequestHeader("X-DashScope-SSE")
headers.Set("Accept", "*/*")
headers.Del("X-DashScope-SSE")
}
return m.buildQwenTextGenerationRequest(ctx, request, streaming)

View File

@@ -13,8 +13,10 @@ const (
MimeTypeApplicationJson = "application/json"
)
func SendResponse(statusCode uint32, statusCodeDetails string, contentType, body string) error {
return proxywasm.SendHttpResponseWithDetail(statusCode, statusCodeDetails, CreateHeaders(HeaderContentType, contentType), []byte(body), -1)
type ErrorHandlerFunc func(statusCodeDetails string, err error) error
var ErrorHandler ErrorHandlerFunc = func(statusCodeDetails string, err error) error {
return proxywasm.SendHttpResponseWithDetail(500, statusCodeDetails, CreateHeaders(HeaderContentType, MimeTypeTextPlain), []byte(err.Error()), -1)
}
func CreateHeaders(kvs ...string) [][2]string {

View File

@@ -6,7 +6,7 @@ replace github.com/alibaba/higress/plugins/wasm-go => ../..
require (
github.com/alibaba/higress/plugins/wasm-go v1.3.6-0.20240522012622-fc6a6aad8906
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
github.com/higress-group/proxy-wasm-go-sdk v1.0.0
github.com/tidwall/gjson v1.17.3
)

View File

@@ -3,8 +3,8 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=

View File

@@ -55,6 +55,7 @@ const (
DefaultDenyMessage = "很抱歉,我无法回答您的问题"
AliyunUserAgent = "CIPFrom/AIGateway"
LengthLimit = 1800
)
type Response struct {
@@ -260,73 +261,92 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
log.Info("request content is empty. skip")
return types.ActionContinue
}
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
randomID, _ := generateHexID(16)
params := map[string]string{
"Format": "JSON",
"Version": "2022-03-02",
"SignatureMethod": "Hmac-SHA1",
"SignatureNonce": randomID,
"SignatureVersion": "1.0",
"Action": "TextModerationPlus",
"AccessKeyId": config.ak,
"Timestamp": timestamp,
"Service": config.requestCheckService,
"ServiceParameters": fmt.Sprintf(`{"content": "%s"}`, marshalStr(content, log)),
}
if config.token != "" {
params["SecurityToken"] = config.token
}
signature := getSign(params, config.sk+"&")
reqParams := url.Values{}
for k, v := range params {
reqParams.Add(k, v)
}
reqParams.Add("Signature", signature)
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
contentIndex := 0
sessionID, _ := generateHexID(20)
var singleCall func()
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
proxywasm.ResumeHttpRequest()
return
}
var response Response
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Error("failed to unmarshal aliyun content security response at request phase")
proxywasm.ResumeHttpRequest()
return
}
if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) {
if contentIndex >= len(content) {
proxywasm.ResumeHttpRequest()
return
}
var response Response
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Error("failed to unmarshal aliyun content security response at request phase")
proxywasm.ResumeHttpRequest()
return
}
if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) {
proxywasm.ResumeHttpRequest()
return
}
denyMessage := DefaultDenyMessage
if config.denyMessage != "" {
denyMessage = config.denyMessage
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
denyMessage = response.Data.Advice[0].Answer
}
marshalledDenyMessage := marshalStr(denyMessage, log)
if config.protocolOriginal {
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
singleCall()
}
ctx.DontReadResponseBody()
config.incrementCounter("ai_sec_request_deny", 1)
},
)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
return types.ActionContinue
return
}
denyMessage := DefaultDenyMessage
if config.denyMessage != "" {
denyMessage = config.denyMessage
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
denyMessage = response.Data.Advice[0].Answer
}
marshalledDenyMessage := marshalStr(denyMessage, log)
if config.protocolOriginal {
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if gjson.GetBytes(body, "stream").Bool() {
randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
ctx.DontReadResponseBody()
config.incrementCounter("ai_sec_request_deny", 1)
proxywasm.ResumeHttpRequest()
}
singleCall = func() {
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
randomID, _ := generateHexID(16)
var nextContentIndex int
if contentIndex+LengthLimit >= len(content) {
nextContentIndex = len(content)
} else {
nextContentIndex = contentIndex + LengthLimit
}
contentPiece := content[contentIndex:nextContentIndex]
contentIndex = nextContentIndex
log.Debugf("current content piece: %s", contentPiece)
params := map[string]string{
"Format": "JSON",
"Version": "2022-03-02",
"SignatureMethod": "Hmac-SHA1",
"SignatureNonce": randomID,
"SignatureVersion": "1.0",
"Action": "TextModerationPlus",
"AccessKeyId": config.ak,
"Timestamp": timestamp,
"Service": config.requestCheckService,
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, sessionID, marshalStr(contentPiece, log)),
}
if config.token != "" {
params["SecurityToken"] = config.token
}
signature := getSign(params, config.sk+"&")
reqParams := url.Values{}
for k, v := range params {
reqParams.Add(k, v)
}
reqParams.Add("Signature", signature)
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
proxywasm.ResumeHttpRequest()
}
}
singleCall()
return types.ActionPause
}
@@ -385,73 +405,94 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
log.Info("response content is empty. skip")
return types.ActionContinue
}
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
randomID, _ := generateHexID(16)
params := map[string]string{
"Format": "JSON",
"Version": "2022-03-02",
"SignatureMethod": "Hmac-SHA1",
"SignatureNonce": randomID,
"SignatureVersion": "1.0",
"Action": "TextModerationPlus",
"AccessKeyId": config.ak,
"Timestamp": timestamp,
"Service": config.responseCheckService,
"ServiceParameters": fmt.Sprintf(`{"content": "%s"}`, marshalStr(content, log)),
}
if config.token != "" {
params["SecurityToken"] = config.token
}
signature := getSign(params, config.sk+"&")
reqParams := url.Values{}
for k, v := range params {
reqParams.Add(k, v)
}
reqParams.Add("Signature", signature)
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
defer proxywasm.ResumeHttpResponse()
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
return
}
var response Response
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Error("failed to unmarshal aliyun content security response at response phase")
return
}
if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) {
return
}
denyMessage := DefaultDenyMessage
if config.denyMessage != "" {
denyMessage = config.denyMessage
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
denyMessage = response.Data.Advice[0].Answer
}
marshalledDenyMessage := marshalStr(denyMessage, log)
var jsonData []byte
if config.protocolOriginal {
jsonData = []byte(marshalledDenyMessage)
} else if isStreamingResponse {
randomID := generateRandomID()
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model))
contentIndex := 0
sessionID, _ := generateHexID(20)
var singleCall func()
callback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
log.Info(string(responseBody))
if statusCode != 200 || gjson.GetBytes(responseBody, "Code").Int() != 200 {
proxywasm.ResumeHttpResponse()
return
}
var response Response
err := json.Unmarshal(responseBody, &response)
if err != nil {
log.Error("failed to unmarshal aliyun content security response at response phase")
proxywasm.ResumeHttpResponse()
return
}
if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) {
if contentIndex >= len(content) {
proxywasm.ResumeHttpResponse()
} else {
randomID := generateRandomID()
jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage))
singleCall()
}
delete(hdsMap, "content-length")
hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)}
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
proxywasm.ReplaceHttpResponseBody(jsonData)
config.incrementCounter("ai_sec_response_deny", 1)
},
)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
return types.ActionContinue
return
}
denyMessage := DefaultDenyMessage
if config.denyMessage != "" {
denyMessage = config.denyMessage
} else if response.Data.Advice != nil && response.Data.Advice[0].Answer != "" {
denyMessage = response.Data.Advice[0].Answer
}
marshalledDenyMessage := marshalStr(denyMessage, log)
var jsonData []byte
if config.protocolOriginal {
jsonData = []byte(marshalledDenyMessage)
} else if isStreamingResponse {
randomID := generateRandomID()
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model))
} else {
randomID := generateRandomID()
jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage))
}
delete(hdsMap, "content-length")
hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)}
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
proxywasm.ReplaceHttpResponseBody(jsonData)
config.incrementCounter("ai_sec_response_deny", 1)
proxywasm.ResumeHttpResponse()
}
singleCall = func() {
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
randomID, _ := generateHexID(16)
var nextContentIndex int
if contentIndex+LengthLimit >= len(content) {
nextContentIndex = len(content)
} else {
nextContentIndex = contentIndex + LengthLimit
}
contentPiece := content[contentIndex:nextContentIndex]
contentIndex = nextContentIndex
log.Debugf("current content piece: %s", contentPiece)
params := map[string]string{
"Format": "JSON",
"Version": "2022-03-02",
"SignatureMethod": "Hmac-SHA1",
"SignatureNonce": randomID,
"SignatureVersion": "1.0",
"Action": "TextModerationPlus",
"AccessKeyId": config.ak,
"Timestamp": timestamp,
"Service": config.responseCheckService,
"ServiceParameters": fmt.Sprintf(`{"sessionId": "%s","content": "%s"}`, sessionID, marshalStr(contentPiece, log)),
}
if config.token != "" {
params["SecurityToken"] = config.token
}
signature := getSign(params, config.sk+"&")
reqParams := url.Values{}
for k, v := range params {
reqParams.Add(k, v)
}
reqParams.Add("Signature", signature)
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback)
if err != nil {
log.Errorf("failed call the safe check service: %v", err)
proxywasm.ResumeHttpResponse()
}
}
singleCall()
return types.ActionPause
}

View File

@@ -1,23 +1,25 @@
github.com/davecgh/go-spew v1.1.1+incompatible/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1+incompatible/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0+incompatible/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.0+incompatible/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM=
github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=

View File

@@ -3,6 +3,7 @@ package main
import (
"fmt"
"net/http"
"net/url"
"path"
"strings"
@@ -35,6 +36,12 @@ func parseConfig(json gjson.Result, grayConfig *config.GrayConfig, log wrapper.L
func onHttpRequestHeaders(ctx wrapper.HttpContext, grayConfig config.GrayConfig, log wrapper.Log) types.Action {
requestPath, _ := proxywasm.GetHttpRequestHeader(":path")
requestPath = path.Clean(requestPath)
parsedURL, err := url.Parse(requestPath)
if err == nil {
requestPath = parsedURL.Path
} else {
log.Errorf("parse request path %s failed: %v", requestPath, err)
}
enabledGray := util.IsGrayEnabled(grayConfig, requestPath)
ctx.SetContext(config.EnabledGray, enabledGray)

View File

@@ -31,11 +31,26 @@ const (
LogLevelCritical
)
type Log struct {
type Log interface {
Trace(msg string)
Tracef(format string, args ...interface{})
Debug(msg string)
Debugf(format string, args ...interface{})
Info(msg string)
Infof(format string, args ...interface{})
Warn(msg string)
Warnf(format string, args ...interface{})
Error(msg string)
Errorf(format string, args ...interface{})
Critical(msg string)
Criticalf(format string, args ...interface{})
}
type DefaultLog struct {
pluginName string
}
func (l Log) log(level LogLevel, msg string) {
func (l *DefaultLog) log(level LogLevel, msg string) {
requestIDRaw, _ := proxywasm.GetProperty([]string{"x_request_id"})
requestID := string(requestIDRaw)
if requestID == "" {
@@ -58,7 +73,7 @@ func (l Log) log(level LogLevel, msg string) {
}
}
func (l Log) logFormat(level LogLevel, format string, args ...interface{}) {
func (l *DefaultLog) logFormat(level LogLevel, format string, args ...interface{}) {
requestIDRaw, _ := proxywasm.GetProperty([]string{"x_request_id"})
requestID := string(requestIDRaw)
if requestID == "" {
@@ -81,50 +96,50 @@ func (l Log) logFormat(level LogLevel, format string, args ...interface{}) {
}
}
func (l Log) Trace(msg string) {
func (l *DefaultLog) Trace(msg string) {
l.log(LogLevelTrace, msg)
}
func (l Log) Tracef(format string, args ...interface{}) {
func (l *DefaultLog) Tracef(format string, args ...interface{}) {
l.logFormat(LogLevelTrace, format, args...)
}
func (l Log) Debug(msg string) {
func (l *DefaultLog) Debug(msg string) {
l.log(LogLevelDebug, msg)
}
func (l Log) Debugf(format string, args ...interface{}) {
func (l *DefaultLog) Debugf(format string, args ...interface{}) {
l.logFormat(LogLevelDebug, format, args...)
}
func (l Log) Info(msg string) {
func (l *DefaultLog) Info(msg string) {
l.log(LogLevelInfo, msg)
}
func (l Log) Infof(format string, args ...interface{}) {
func (l *DefaultLog) Infof(format string, args ...interface{}) {
l.logFormat(LogLevelInfo, format, args...)
}
func (l Log) Warn(msg string) {
func (l *DefaultLog) Warn(msg string) {
l.log(LogLevelWarn, msg)
}
func (l Log) Warnf(format string, args ...interface{}) {
func (l *DefaultLog) Warnf(format string, args ...interface{}) {
l.logFormat(LogLevelWarn, format, args...)
}
func (l Log) Error(msg string) {
func (l *DefaultLog) Error(msg string) {
l.log(LogLevelError, msg)
}
func (l Log) Errorf(format string, args ...interface{}) {
func (l *DefaultLog) Errorf(format string, args ...interface{}) {
l.logFormat(LogLevelError, format, args...)
}
func (l Log) Critical(msg string) {
func (l *DefaultLog) Critical(msg string) {
l.log(LogLevelCritical, msg)
}
func (l Log) Criticalf(format string, args ...interface{}) {
func (l *DefaultLog) Criticalf(format string, args ...interface{}) {
l.logFormat(LogLevelCritical, format, args...)
}

View File

@@ -15,6 +15,8 @@
package wrapper
import (
"encoding/json"
"fmt"
"strconv"
"time"
"unsafe"
@@ -26,6 +28,12 @@ import (
"github.com/alibaba/higress/plugins/wasm-go/pkg/matcher"
)
const (
CustomLogKey = "custom_log"
AILogKey = "ai_log"
TraceSpanTagPrefix = "trace_span_tag."
)
type HttpContext interface {
Scheme() string
Host() string
@@ -35,6 +43,14 @@ type HttpContext interface {
GetContext(key string) interface{}
GetBoolContext(key string, defaultValue bool) bool
GetStringContext(key, defaultValue string) string
GetUserAttribute(key string) interface{}
SetUserAttribute(key string, value interface{})
// You can call this function to set custom log
WriteUserAttributeToLog() error
// You can call this function to set custom log with your specific key
WriteUserAttributeToLogWithKey(key string) error
// You can call this function to set custom trace span attribute
WriteUserAttributeToTrace() error
// If the onHttpRequestBody handle is not set, the request body will not be read by default
DontReadRequestBody()
// If the onHttpResponseBody handle is not set, the request body will not be read by default
@@ -98,79 +114,157 @@ func RegisteTickFunc(tickPeriod int64, tickFunc func()) {
globalOnTickFuncs = append(globalOnTickFuncs, TickFuncEntry{0, tickPeriod, tickFunc})
}
func SetCtx[PluginConfig any](pluginName string, setFuncs ...SetPluginFunc[PluginConfig]) {
proxywasm.SetVMContext(NewCommonVmCtx(pluginName, setFuncs...))
func SetCtx[PluginConfig any](pluginName string, options ...CtxOption[PluginConfig]) {
proxywasm.SetVMContext(NewCommonVmCtx(pluginName, options...))
}
type SetPluginFunc[PluginConfig any] func(*CommonVmCtx[PluginConfig])
func ParseConfigBy[PluginConfig any](f ParseConfigFunc[PluginConfig]) SetPluginFunc[PluginConfig] {
return func(ctx *CommonVmCtx[PluginConfig]) {
ctx.parseConfig = f
}
func SetCtxWithOptions[PluginConfig any](pluginName string, options ...CtxOption[PluginConfig]) {
proxywasm.SetVMContext(NewCommonVmCtxWithOptions(pluginName, options...))
}
func ParseOverrideConfigBy[PluginConfig any](f ParseConfigFunc[PluginConfig], g ParseRuleConfigFunc[PluginConfig]) SetPluginFunc[PluginConfig] {
return func(ctx *CommonVmCtx[PluginConfig]) {
ctx.parseConfig = f
ctx.parseRuleConfig = g
}
type CtxOption[PluginConfig any] interface {
Apply(*CommonVmCtx[PluginConfig])
}
func ProcessRequestHeadersBy[PluginConfig any](f onHttpHeadersFunc[PluginConfig]) SetPluginFunc[PluginConfig] {
return func(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpRequestHeaders = f
}
type parseConfigOption[PluginConfig any] struct {
f ParseConfigFunc[PluginConfig]
}
func ProcessRequestBodyBy[PluginConfig any](f onHttpBodyFunc[PluginConfig]) SetPluginFunc[PluginConfig] {
return func(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpRequestBody = f
}
func (o parseConfigOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
ctx.parseConfig = o.f
}
func ProcessStreamingRequestBodyBy[PluginConfig any](f onHttpStreamingBodyFunc[PluginConfig]) SetPluginFunc[PluginConfig] {
return func(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpStreamingRequestBody = f
}
func ParseConfigBy[PluginConfig any](f ParseConfigFunc[PluginConfig]) CtxOption[PluginConfig] {
return parseConfigOption[PluginConfig]{f}
}
func ProcessResponseHeadersBy[PluginConfig any](f onHttpHeadersFunc[PluginConfig]) SetPluginFunc[PluginConfig] {
return func(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpResponseHeaders = f
}
type parseOverrideConfigOption[PluginConfig any] struct {
parseConfigF ParseConfigFunc[PluginConfig]
parseRuleConfigF ParseRuleConfigFunc[PluginConfig]
}
func ProcessResponseBodyBy[PluginConfig any](f onHttpBodyFunc[PluginConfig]) SetPluginFunc[PluginConfig] {
return func(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpResponseBody = f
}
func (o *parseOverrideConfigOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
ctx.parseConfig = o.parseConfigF
ctx.parseRuleConfig = o.parseRuleConfigF
}
func ProcessStreamingResponseBodyBy[PluginConfig any](f onHttpStreamingBodyFunc[PluginConfig]) SetPluginFunc[PluginConfig] {
return func(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpStreamingResponseBody = f
}
func ParseOverrideConfigBy[PluginConfig any](f ParseConfigFunc[PluginConfig], g ParseRuleConfigFunc[PluginConfig]) CtxOption[PluginConfig] {
return &parseOverrideConfigOption[PluginConfig]{f, g}
}
func ProcessStreamDoneBy[PluginConfig any](f onHttpStreamDoneFunc[PluginConfig]) SetPluginFunc[PluginConfig] {
return func(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpStreamDone = f
}
type onProcessRequestHeadersOption[PluginConfig any] struct {
f onHttpHeadersFunc[PluginConfig]
}
func (o *onProcessRequestHeadersOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpRequestHeaders = o.f
}
func ProcessRequestHeadersBy[PluginConfig any](f onHttpHeadersFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessRequestHeadersOption[PluginConfig]{f}
}
type onProcessRequestBodyOption[PluginConfig any] struct {
f onHttpBodyFunc[PluginConfig]
}
func (o *onProcessRequestBodyOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpRequestBody = o.f
}
func ProcessRequestBodyBy[PluginConfig any](f onHttpBodyFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessRequestBodyOption[PluginConfig]{f}
}
type onProcessStreamingRequestBodyOption[PluginConfig any] struct {
f onHttpStreamingBodyFunc[PluginConfig]
}
func (o *onProcessStreamingRequestBodyOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpStreamingRequestBody = o.f
}
func ProcessStreamingRequestBodyBy[PluginConfig any](f onHttpStreamingBodyFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessStreamingRequestBodyOption[PluginConfig]{f}
}
type onProcessResponseHeadersOption[PluginConfig any] struct {
f onHttpHeadersFunc[PluginConfig]
}
func (o *onProcessResponseHeadersOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpResponseHeaders = o.f
}
func ProcessResponseHeadersBy[PluginConfig any](f onHttpHeadersFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessResponseHeadersOption[PluginConfig]{f}
}
type onProcessResponseBodyOption[PluginConfig any] struct {
f onHttpBodyFunc[PluginConfig]
}
func (o *onProcessResponseBodyOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpResponseBody = o.f
}
func ProcessResponseBodyBy[PluginConfig any](f onHttpBodyFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessResponseBodyOption[PluginConfig]{f}
}
type onProcessStreamingResponseBodyOption[PluginConfig any] struct {
f onHttpStreamingBodyFunc[PluginConfig]
}
func (o *onProcessStreamingResponseBodyOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpStreamingResponseBody = o.f
}
func ProcessStreamingResponseBodyBy[PluginConfig any](f onHttpStreamingBodyFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessStreamingResponseBodyOption[PluginConfig]{f}
}
type onProcessStreamDoneOption[PluginConfig any] struct {
f onHttpStreamDoneFunc[PluginConfig]
}
func (o *onProcessStreamDoneOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
ctx.onHttpStreamDone = o.f
}
func ProcessStreamDoneBy[PluginConfig any](f onHttpStreamDoneFunc[PluginConfig]) CtxOption[PluginConfig] {
return &onProcessStreamDoneOption[PluginConfig]{f}
}
type logOption[PluginConfig any] struct {
logger Log
}
func (o *logOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
ctx.log = o.logger
}
func WithLogger[PluginConfig any](logger Log) CtxOption[PluginConfig] {
return &logOption[PluginConfig]{logger}
}
func parseEmptyPluginConfig[PluginConfig any](gjson.Result, *PluginConfig, Log) error {
return nil
}
func NewCommonVmCtx[PluginConfig any](pluginName string, setFuncs ...SetPluginFunc[PluginConfig]) *CommonVmCtx[PluginConfig] {
func NewCommonVmCtx[PluginConfig any](pluginName string, options ...CtxOption[PluginConfig]) *CommonVmCtx[PluginConfig] {
logger := &DefaultLog{pluginName}
opts := append([]CtxOption[PluginConfig]{WithLogger[PluginConfig](logger)}, options...)
return NewCommonVmCtxWithOptions(pluginName, opts...)
}
func NewCommonVmCtxWithOptions[PluginConfig any](pluginName string, options ...CtxOption[PluginConfig]) *CommonVmCtx[PluginConfig] {
ctx := &CommonVmCtx[PluginConfig]{
pluginName: pluginName,
log: Log{pluginName},
hasCustomConfig: true,
}
for _, set := range setFuncs {
set(ctx)
for _, opt := range options {
opt.Apply(ctx)
}
if ctx.parseConfig == nil {
var config PluginConfig
@@ -257,9 +351,10 @@ func (ctx *CommonPluginCtx[PluginConfig]) OnTick() {
func (ctx *CommonPluginCtx[PluginConfig]) NewHttpContext(contextID uint32) types.HttpContext {
httpCtx := &CommonHttpCtx[PluginConfig]{
plugin: ctx,
contextID: contextID,
userContext: map[string]interface{}{},
plugin: ctx,
contextID: contextID,
userContext: map[string]interface{}{},
userAttribute: map[string]interface{}{},
}
if ctx.vm.onHttpRequestBody != nil || ctx.vm.onHttpStreamingRequestBody != nil {
httpCtx.needRequestBody = true
@@ -289,6 +384,7 @@ type CommonHttpCtx[PluginConfig any] struct {
responseBodySize int
contextID uint32
userContext map[string]interface{}
userAttribute map[string]interface{}
}
func (ctx *CommonHttpCtx[PluginConfig]) SetContext(key string, value interface{}) {
@@ -299,6 +395,63 @@ func (ctx *CommonHttpCtx[PluginConfig]) GetContext(key string) interface{} {
return ctx.userContext[key]
}
func (ctx *CommonHttpCtx[PluginConfig]) SetUserAttribute(key string, value interface{}) {
ctx.userAttribute[key] = value
}
func (ctx *CommonHttpCtx[PluginConfig]) GetUserAttribute(key string) interface{} {
return ctx.userAttribute[key]
}
func (ctx *CommonHttpCtx[PluginConfig]) WriteUserAttributeToLog() error {
return ctx.WriteUserAttributeToLogWithKey(CustomLogKey)
}
func (ctx *CommonHttpCtx[PluginConfig]) WriteUserAttributeToLogWithKey(key string) error {
// e.g. {\"field1\":\"value1\",\"field2\":\"value2\"}
preMarshalledJsonLogStr, _ := proxywasm.GetProperty([]string{key})
newAttributeMap := map[string]interface{}{}
if string(preMarshalledJsonLogStr) != "" {
// e.g. {"field1":"value1","field2":"value2"}
preJsonLogStr := unmarshalStr(fmt.Sprintf(`"%s"`, string(preMarshalledJsonLogStr)))
err := json.Unmarshal([]byte(preJsonLogStr), &newAttributeMap)
if err != nil {
ctx.plugin.vm.log.Warnf("Unmarshal failed, will overwrite %s, pre value is: %s", key, string(preMarshalledJsonLogStr))
return err
}
}
// update customLog
for k, v := range ctx.userAttribute {
newAttributeMap[k] = v
}
// e.g. {"field1":"value1","field2":2,"field3":"value3"}
jsonStr, _ := json.Marshal(newAttributeMap)
// e.g. {\"field1\":\"value1\",\"field2\":2,\"field3\":\"value3\"}
marshalledJsonStr := marshalStr(string(jsonStr))
if err := proxywasm.SetProperty([]string{key}, []byte(marshalledJsonStr)); err != nil {
ctx.plugin.vm.log.Warnf("failed to set %s in filter state, raw is %s, err is %v", key, marshalledJsonStr, err)
return err
}
return nil
}
func (ctx *CommonHttpCtx[PluginConfig]) WriteUserAttributeToTrace() error {
for k, v := range ctx.userAttribute {
traceSpanTag := TraceSpanTagPrefix + k
traceSpanValue := fmt.Sprint(v)
var err error
if traceSpanValue != "" {
err = proxywasm.SetProperty([]string{traceSpanTag}, []byte(traceSpanValue))
} else {
err = fmt.Errorf("value of %s is empty", traceSpanTag)
}
if err != nil {
ctx.plugin.vm.log.Warnf("Failed to set trace attribute - %s: %s, error message: %v", traceSpanTag, traceSpanValue, err)
}
}
return nil
}
func (ctx *CommonHttpCtx[PluginConfig]) GetBoolContext(key string, defaultValue bool) bool {
if b, ok := ctx.userContext[key].(bool); ok {
return b

View File

@@ -0,0 +1,36 @@
package wrapper
import (
"encoding/json"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/tidwall/gjson"
)
func unmarshalStr(marshalledJsonStr string) string {
// e.g. "{\"field1\":\"value1\",\"field2\":\"value2\"}"
var jsonStr string
err := json.Unmarshal([]byte(marshalledJsonStr), &jsonStr)
if err != nil {
proxywasm.LogErrorf("failed to unmarshal json string, raw string is: %s, err is: %v", marshalledJsonStr, err)
return ""
}
// e.g. {"field1":"value1","field2":"value2"}
return jsonStr
}
func marshalStr(raw string) string {
// e.g. {"field1":"value1","field2":"value2"}
helper := map[string]string{
"placeholder": raw,
}
marshalledHelper, _ := json.Marshal(helper)
marshalledRaw := gjson.GetBytes(marshalledHelper, "placeholder").Raw
if len(marshalledRaw) >= 2 {
// e.g. {\"field1\":\"value1\",\"field2\":\"value2\"}
return marshalledRaw[1 : len(marshalledRaw)-1]
} else {
proxywasm.LogErrorf("failed to marshal json string, raw string is: %s", raw)
return ""
}
}