Compare commits

..

51 Commits

Author SHA1 Message Date
澄潭
fd1eb54f25 Release 2.0.6 (#1686) 2025-01-17 15:22:43 +08:00
澄潭
c7550e2d49 Update deploy-to-oss.yaml 2025-01-17 15:10:40 +08:00
Se7en
ba74f4bbb9 fix: baidu api issue (#1685) 2025-01-16 21:42:43 +08:00
澄潭
9e418dafd9 release 2.0.6-rc.3 (#1680) 2025-01-15 20:47:20 +08:00
澄潭
95523a1bc7 Fix istio lds cache (#1679) 2025-01-15 20:44:13 +08:00
澄潭
dcd8466127 Update build-and-test-plugin.yaml 2025-01-15 20:19:58 +08:00
澄潭
cceae6ad2a update cpp wasm plugins (#1675) 2025-01-15 19:15:11 +08:00
zty98751
32f9a5ff32 fix istio commit 2025-01-15 15:29:44 +08:00
澄潭
6f95297b80 Release 2.0.6-rc.2 (#1671) 2025-01-14 20:10:53 +08:00
Kent Dong
95426d5ccf fix: Fix a typo in the README files of ai-statistics plugin (#1670) 2025-01-14 13:39:55 +08:00
澄潭
a05b6b1e9d add ai_log field (#1669) 2025-01-14 10:03:24 +08:00
Jun
d0628344da add higress architecture doc (#1662) 2025-01-14 09:48:32 +08:00
韩贤涛
a1bf315b13 fix: resolve blocking issue with minimax responses in ai-proxy (#1663) 2025-01-14 09:43:19 +08:00
mamba
b3d9123d59 [frontend-gray] 微前端灰度 场景,支持 IncludePathPrefixes字段 (#1666) 2025-01-13 16:24:51 +08:00
rinfx
817061c6cc remove dependency for ai-statistic (#1660) 2025-01-10 13:43:29 +08:00
rinfx
ea0d5e7564 Improve ai plugins (#1657)
Co-authored-by: Kent Dong <ch3cho@qq.com>
2025-01-09 22:04:51 +08:00
澄潭
2a89c3bb70 Optimize wasmplugin proto (#1656) 2025-01-09 13:19:46 +08:00
johnlanni
a570c72504 Update Chart.lock 2025-01-08 17:14:27 +08:00
澄潭
ab1316dfe1 rel: Release 2.0.6-rc.1 (#1653) 2025-01-08 17:08:09 +08:00
澄潭
e97448b71b Update metrics & enable lds cache (#1650) 2025-01-08 16:49:23 +08:00
澄潭
6820a06a99 fix tls version annotation (#1652) 2025-01-08 15:31:39 +08:00
澄潭
4733af849d Update README.md 2025-01-08 11:30:29 +08:00
yunmaoQu
1c2330e33b feat: add TLS version annotation support for per-rule configuration (#1592)
Co-authored-by: 澄潭 <zty98751@alibaba-inc.com>
2025-01-06 21:29:09 +08:00
澄潭
61fef0ecf8 Release 2.0.5 (#1646) 2025-01-06 19:42:18 +08:00
澄潭
d29b8d7ca8 fix ai proxy checkStream (#1645) 2025-01-06 15:30:02 +08:00
澄潭
2501895b66 ai-cache update body buffer limit size (#1644) 2025-01-06 14:53:29 +08:00
Kent Dong
187a7b5408 fix: Enlarge the default retry timeout in ai-proxy (#1640) 2025-01-03 11:19:40 +08:00
Jingze
00be491d02 feat: support github provider for oidc wasm plugin (#1639) 2025-01-02 10:01:54 +08:00
ayanami-desu
2d74c48e8a Add cohere embedding for ai-cache (#1572) 2024-12-27 17:48:44 +08:00
澄潭
6dc4d43df5 optimize ai cache (#1626) 2024-12-27 10:10:57 +08:00
rinfx
2a4e55d46f move oidcHandler from global to pluginconfig (#1601) 2024-12-26 19:15:20 +08:00
Se7en
579c986915 feat: retry failed request (#1590) 2024-12-26 18:30:50 +08:00
Kent Dong
380717ae3d fix: Make opa listen to all IPs (#1621) 2024-12-26 17:41:28 +08:00
Kent Dong
8f3723f554 feat: Support setting gateway.unprivilegedPortSupported manually (#1616) 2024-12-23 19:45:47 +08:00
VinciWu557
909cc0f088 feat: AI 代理 Wasm 插件接入 Together AI (#1617) 2024-12-23 15:39:56 +08:00
007gzs
4eaf204737 Enhance the capabilities of the AI Intent plugin (#1605) 2024-12-20 10:25:17 +08:00
澄潭
748bcb083a redis wrapper support lazy init and database options (#1602) 2024-12-19 16:22:56 +08:00
澄潭
39c007d045 optimize ai proxy (#1603) 2024-12-19 16:22:35 +08:00
rinfx
d74d327b68 bugfix: cannot parse content if one streaming body has multi chunks (#1606) 2024-12-19 16:21:57 +08:00
澄潭
be27726721 Update CODEOWNERS 2024-12-19 14:36:11 +08:00
澄潭
34cc1c0632 Update README.md 2024-12-18 17:02:28 +08:00
澄潭
5694475872 Update README.md 2024-12-18 16:59:03 +08:00
rinfx
2f5709a93e qwen bailian compatible bug fix (#1597) 2024-12-17 16:57:31 +08:00
StarryNight
2a200cdd42 AI proxy return unified status in header phase (#1588) 2024-12-16 18:41:38 +08:00
rinfx
ec39d56731 AI observability upgrade (#1587)
Co-authored-by: Kent Dong <ch3cho@qq.com>
2024-12-16 10:27:49 +08:00
韩贤涛
8544fa604d feat: support choosing chatCompletionV2 or chatCompletionPro API for minimax provider (#1593) 2024-12-15 15:12:00 +08:00
mirror
0ba63e5dd4 fix: default port of static service in ai-cache plugin (#1591) 2024-12-13 19:03:26 +08:00
mirror
441408c593 docs: fix typos in ai-quota document (#1589) 2024-12-13 08:56:43 +08:00
duxin40
be57960c22 Support OpenAI embedding. (#1542) 2024-12-11 11:42:51 +08:00
rinfx
f32020068a bugfix and extend ai log (#1576) 2024-12-09 20:39:13 +08:00
澄潭
1a8fce48f0 Update release-hgctl.yaml 2024-12-06 14:01:18 +08:00
112 changed files with 3305 additions and 1231 deletions

View File

@@ -6,11 +6,15 @@ on:
paths: paths:
- 'plugins/**' - 'plugins/**'
- 'test/**' - 'test/**'
- 'helm/**'
- 'Makefile.core.mk'
pull_request: pull_request:
branches: [ "*" ] branches: [ "*" ]
paths: paths:
- 'plugins/**' - 'plugins/**'
- 'test/**' - 'test/**'
- 'helm/**'
- 'Makefile.core.mk'
workflow_dispatch: ~ workflow_dispatch: ~
jobs: jobs:

View File

@@ -19,7 +19,7 @@ jobs:
- name: Download Helm Charts Index - name: Download Helm Charts Index
uses: doggycool/ossutil-github-action@master uses: doggycool/ossutil-github-action@master
with: with:
ossArgs: 'cp -r -u oss://higress-website-cn-hongkong/helm-charts/index.yaml ./artifact/' ossArgs: 'cp oss://higress-website-cn-hongkong/helm-charts/index.yaml ./artifact/'
accessKey: ${{ secrets.ACCESS_KEYID }} accessKey: ${{ secrets.ACCESS_KEYID }}
accessSecret: ${{ secrets.ACCESS_KEYSECRET }} accessSecret: ${{ secrets.ACCESS_KEYSECRET }}
endpoint: oss-cn-hongkong.aliyuncs.com endpoint: oss-cn-hongkong.aliyuncs.com

View File

@@ -58,7 +58,7 @@ jobs:
hgctl_${{ env.HGCTL_VERSION }}_darwin_arm64.tar.gz hgctl_${{ env.HGCTL_VERSION }}_darwin_arm64.tar.gz
release-hgctl-macos-amd64: release-hgctl-macos-amd64:
runs-on: macos-12 runs-on: macos-14
env: env:
HGCTL_VERSION: ${{github.ref_name}} HGCTL_VERSION: ${{github.ref_name}}
steps: steps:

View File

@@ -12,6 +12,7 @@ header:
- 'LICENSE' - 'LICENSE'
- 'api/**' - 'api/**'
- 'samples/**' - 'samples/**'
- 'docs/**'
- '.github/**' - '.github/**'
- '.licenserc.yaml' - '.licenserc.yaml'
- 'helm/**' - 'helm/**'

View File

@@ -2,7 +2,8 @@
/envoy @gengleilei @johnlanni /envoy @gengleilei @johnlanni
/istio @SpecialYang @johnlanni /istio @SpecialYang @johnlanni
/pkg @SpecialYang @johnlanni @CH3CHO /pkg @SpecialYang @johnlanni @CH3CHO
/plugins @johnlanni @WeixinX @CH3CHO /plugins @johnlanni @CH3CHO @rinfx
/plugins/wasm-go/extensions/ai-proxy @cr7258 @CH3CHO @rinfx
/plugins/wasm-rust @007gzs @jizhuozhi /plugins/wasm-rust @007gzs @jizhuozhi
/registry @NameHaibinZhang @2456868764 @johnlanni /registry @NameHaibinZhang @2456868764 @johnlanni
/test @Xunzhuo @2456868764 @CH3CHO /test @Xunzhuo @2456868764 @CH3CHO

View File

@@ -187,8 +187,8 @@ install: pre-install
cd helm/higress; helm dependency build cd helm/higress; helm dependency build
helm install higress helm/higress -n higress-system --create-namespace --set 'global.local=true' helm install higress helm/higress -n higress-system --create-namespace --set 'global.local=true'
ENVOY_LATEST_IMAGE_TAG ?= 2.0.3 ENVOY_LATEST_IMAGE_TAG ?= 958467a353d411ae3f06e03b096bfd342cddb2c6
ISTIO_LATEST_IMAGE_TAG ?= 8be82d2e4c280c29f4952fbeca1e2a79230b7836 ISTIO_LATEST_IMAGE_TAG ?= d9c728d3b01f64855e012b08d136e306f1160397
install-dev: pre-install install-dev: pre-install
helm install higress helm/core -n higress-system --create-namespace --set 'controller.tag=$(TAG)' --set 'gateway.replicas=1' --set 'pilot.tag=$(ISTIO_LATEST_IMAGE_TAG)' --set 'gateway.tag=$(ENVOY_LATEST_IMAGE_TAG)' --set 'global.local=true' helm install higress helm/core -n higress-system --create-namespace --set 'controller.tag=$(TAG)' --set 'gateway.replicas=1' --set 'pilot.tag=$(ISTIO_LATEST_IMAGE_TAG)' --set 'gateway.tag=$(ENVOY_LATEST_IMAGE_TAG)' --set 'global.local=true'
@@ -299,7 +299,7 @@ kube-load-image: $(tools/kind) ## Install the Higress image to a kind cluster us
tools/hack/docker-pull-image.sh higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/echo-server 1.3.0 tools/hack/docker-pull-image.sh higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/echo-server 1.3.0
tools/hack/docker-pull-image.sh higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/echo-server v1.0 tools/hack/docker-pull-image.sh higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/echo-server v1.0
tools/hack/docker-pull-image.sh higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/echo-body 1.0.0 tools/hack/docker-pull-image.sh higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/echo-body 1.0.0
tools/hack/docker-pull-image.sh openpolicyagent/opa latest tools/hack/docker-pull-image.sh openpolicyagent/opa 0.61.0
tools/hack/docker-pull-image.sh curlimages/curl latest tools/hack/docker-pull-image.sh curlimages/curl latest
tools/hack/docker-pull-image.sh registry.cn-hangzhou.aliyuncs.com/2456868764/httpbin 1.0.2 tools/hack/docker-pull-image.sh registry.cn-hangzhou.aliyuncs.com/2456868764/httpbin 1.0.2
tools/hack/docker-pull-image.sh registry.cn-hangzhou.aliyuncs.com/hinsteny/nacos-standlone-rc3 1.0.0-RC3 tools/hack/docker-pull-image.sh registry.cn-hangzhou.aliyuncs.com/hinsteny/nacos-standlone-rc3 1.0.0-RC3
@@ -312,7 +312,7 @@ kube-load-image: $(tools/kind) ## Install the Higress image to a kind cluster us
tools/hack/kind-load-image.sh higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/echo-server 1.3.0 tools/hack/kind-load-image.sh higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/echo-server 1.3.0
tools/hack/kind-load-image.sh higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/echo-server v1.0 tools/hack/kind-load-image.sh higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/echo-server v1.0
tools/hack/kind-load-image.sh higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/echo-body 1.0.0 tools/hack/kind-load-image.sh higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/echo-body 1.0.0
tools/hack/kind-load-image.sh openpolicyagent/opa latest tools/hack/kind-load-image.sh openpolicyagent/opa 0.61.0
tools/hack/kind-load-image.sh curlimages/curl latest tools/hack/kind-load-image.sh curlimages/curl latest
tools/hack/kind-load-image.sh registry.cn-hangzhou.aliyuncs.com/2456868764/httpbin 1.0.2 tools/hack/kind-load-image.sh registry.cn-hangzhou.aliyuncs.com/2456868764/httpbin 1.0.2
tools/hack/kind-load-image.sh registry.cn-hangzhou.aliyuncs.com/hinsteny/nacos-standlone-rc3 1.0.0-RC3 tools/hack/kind-load-image.sh registry.cn-hangzhou.aliyuncs.com/hinsteny/nacos-standlone-rc3 1.0.0-RC3

View File

@@ -6,9 +6,14 @@
</h1> </h1>
<h4 align="center"> AI Native API Gateway </h4> <h4 align="center"> AI Native API Gateway </h4>
<div align="center">
[![Build Status](https://github.com/alibaba/higress/actions/workflows/build-and-test.yaml/badge.svg?branch=main)](https://github.com/alibaba/higress/actions) [![Build Status](https://github.com/alibaba/higress/actions/workflows/build-and-test.yaml/badge.svg?branch=main)](https://github.com/alibaba/higress/actions)
[![license](https://img.shields.io/github/license/alibaba/higress.svg)](https://www.apache.org/licenses/LICENSE-2.0.html) [![license](https://img.shields.io/github/license/alibaba/higress.svg)](https://www.apache.org/licenses/LICENSE-2.0.html)
<a href="https://trendshift.io/repositories/10918" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10918" alt="alibaba%2Fhigress | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
</div>
[**官网**](https://higress.cn/) &nbsp; | [**官网**](https://higress.cn/) &nbsp; |
&nbsp; [**文档**](https://higress.cn/docs/latest/overview/what-is-higress/) &nbsp; | &nbsp; [**文档**](https://higress.cn/docs/latest/overview/what-is-higress/) &nbsp; |
&nbsp; [**博客**](https://higress.cn/blog/) &nbsp; | &nbsp; [**博客**](https://higress.cn/blog/) &nbsp; |
@@ -17,6 +22,7 @@
&nbsp; [**AI插件**](https://higress.cn/plugin/) &nbsp; &nbsp; [**AI插件**](https://higress.cn/plugin/) &nbsp;
<p> <p>
<a href="README_EN.md"> English <a/>| 中文 | <a href="README_JP.md"> 日本語 <a/> <a href="README_EN.md"> English <a/>| 中文 | <a href="README_JP.md"> 日本語 <a/>
</p> </p>
@@ -180,7 +186,7 @@ K8s 下使用 Helm 部署等其他安装方式可以参考官网 [Quick Start
### 交流群 ### 交流群
![image](https://img.alicdn.com/imgextra/i2/O1CN01BkopaB22ZsvamFftE_!!6000000007135-0-tps-720-405.jpg) ![image](https://img.alicdn.com/imgextra/i2/O1CN01fZefEP1aPWkzG3A19_!!6000000003322-0-tps-720-405.jpg)
### 技术分享 ### 技术分享

View File

@@ -1 +1 @@
v2.0.4 v2.0.6

View File

@@ -341,7 +341,7 @@ type WasmPlugin struct {
// Extended by Higress, matching rules take effect // Extended by Higress, matching rules take effect
MatchRules []*MatchRule `protobuf:"bytes,102,rep,name=match_rules,json=matchRules,proto3" json:"match_rules,omitempty"` MatchRules []*MatchRule `protobuf:"bytes,102,rep,name=match_rules,json=matchRules,proto3" json:"match_rules,omitempty"`
// disable the default config // disable the default config
DefaultConfigDisable bool `protobuf:"varint,103,opt,name=default_config_disable,json=defaultConfigDisable,proto3" json:"default_config_disable,omitempty"` DefaultConfigDisable *wrappers.BoolValue `protobuf:"bytes,103,opt,name=default_config_disable,json=defaultConfigDisable,proto3" json:"default_config_disable,omitempty"`
} }
func (x *WasmPlugin) Reset() { func (x *WasmPlugin) Reset() {
@@ -467,11 +467,11 @@ func (x *WasmPlugin) GetMatchRules() []*MatchRule {
return nil return nil
} }
func (x *WasmPlugin) GetDefaultConfigDisable() bool { func (x *WasmPlugin) GetDefaultConfigDisable() *wrappers.BoolValue {
if x != nil { if x != nil {
return x.DefaultConfigDisable return x.DefaultConfigDisable
} }
return false return nil
} }
// Extended by Higress // Extended by Higress
@@ -480,11 +480,11 @@ type MatchRule struct {
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
Ingress []string `protobuf:"bytes,1,rep,name=ingress,proto3" json:"ingress,omitempty"` Ingress []string `protobuf:"bytes,1,rep,name=ingress,proto3" json:"ingress,omitempty"`
Domain []string `protobuf:"bytes,2,rep,name=domain,proto3" json:"domain,omitempty"` Domain []string `protobuf:"bytes,2,rep,name=domain,proto3" json:"domain,omitempty"`
Config *_struct.Struct `protobuf:"bytes,3,opt,name=config,proto3" json:"config,omitempty"` Config *_struct.Struct `protobuf:"bytes,3,opt,name=config,proto3" json:"config,omitempty"`
ConfigDisable bool `protobuf:"varint,4,opt,name=config_disable,json=configDisable,proto3" json:"config_disable,omitempty"` ConfigDisable *wrappers.BoolValue `protobuf:"bytes,4,opt,name=config_disable,json=configDisable,proto3" json:"config_disable,omitempty"`
Service []string `protobuf:"bytes,5,rep,name=service,proto3" json:"service,omitempty"` Service []string `protobuf:"bytes,5,rep,name=service,proto3" json:"service,omitempty"`
} }
func (x *MatchRule) Reset() { func (x *MatchRule) Reset() {
@@ -540,11 +540,11 @@ func (x *MatchRule) GetConfig() *_struct.Struct {
return nil return nil
} }
func (x *MatchRule) GetConfigDisable() bool { func (x *MatchRule) GetConfigDisable() *wrappers.BoolValue {
if x != nil { if x != nil {
return x.ConfigDisable return x.ConfigDisable
} }
return false return nil
} }
func (x *MatchRule) GetService() []string { func (x *MatchRule) GetService() []string {
@@ -686,7 +686,7 @@ var file_extensions_v1alpha1_wasmplugin_proto_rawDesc = []byte{
0x6f, 0x62, 0x75, 0x66, 0x2f, 0x77, 0x72, 0x61, 0x70, 0x70, 0x65, 0x72, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x77, 0x72, 0x61, 0x70, 0x70, 0x65, 0x72, 0x73, 0x2e, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x1a, 0x1c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x74, 0x6f, 0x1a, 0x1c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x62, 0x75, 0x66, 0x2f, 0x73, 0x74, 0x72, 0x75, 0x63, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x73, 0x74, 0x72, 0x75, 0x63, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x22, 0x8d, 0x06, 0x0a, 0x0a, 0x57, 0x61, 0x73, 0x6d, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x6f, 0x22, 0xa9, 0x06, 0x0a, 0x0a, 0x57, 0x61, 0x73, 0x6d, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e,
0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75,
0x72, 0x6c, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x68, 0x61, 0x32, 0x35, 0x36, 0x18, 0x03, 0x20, 0x01, 0x72, 0x6c, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x68, 0x61, 0x32, 0x35, 0x36, 0x18, 0x03, 0x20, 0x01,
0x28, 0x09, 0x52, 0x06, 0x73, 0x68, 0x61, 0x32, 0x35, 0x36, 0x12, 0x53, 0x0a, 0x11, 0x69, 0x6d, 0x28, 0x09, 0x52, 0x06, 0x73, 0x68, 0x61, 0x32, 0x35, 0x36, 0x12, 0x53, 0x0a, 0x11, 0x69, 0x6d,
@@ -731,52 +731,55 @@ var file_extensions_v1alpha1_wasmplugin_proto_rawDesc = []byte{
0x73, 0x18, 0x66, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x26, 0x2e, 0x68, 0x69, 0x67, 0x72, 0x65, 0x73, 0x73, 0x18, 0x66, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x26, 0x2e, 0x68, 0x69, 0x67, 0x72, 0x65, 0x73,
0x73, 0x2e, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x76, 0x31, 0x61, 0x73, 0x2e, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x76, 0x31, 0x61,
0x6c, 0x70, 0x68, 0x61, 0x31, 0x2e, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x6c, 0x70, 0x68, 0x61, 0x31, 0x2e, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x52, 0x75, 0x6c, 0x65, 0x52,
0x0a, 0x6d, 0x61, 0x74, 0x63, 0x68, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x34, 0x0a, 0x16, 0x64, 0x0a, 0x6d, 0x61, 0x74, 0x63, 0x68, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x50, 0x0a, 0x16, 0x64,
0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, 0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x5f, 0x64, 0x69, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, 0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x5f, 0x64, 0x69,
0x73, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x67, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x64, 0x65, 0x66, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x67, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f,
0x61, 0x75, 0x6c, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x42, 0x6f,
0x65, 0x22, 0xaf, 0x01, 0x0a, 0x09, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x6f, 0x6c, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x52, 0x14, 0x64, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74,
0x18, 0x0a, 0x07, 0x69, 0x6e, 0x67, 0x72, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x22, 0xcb, 0x01,
0x52, 0x07, 0x69, 0x6e, 0x67, 0x72, 0x65, 0x73, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x0a, 0x09, 0x4d, 0x61, 0x74, 0x63, 0x68, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x69,
0x61, 0x69, 0x6e, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x67, 0x72, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x69, 0x6e,
0x6e, 0x12, 0x2f, 0x0a, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x67, 0x72, 0x65, 0x73, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18,
0x0b, 0x32, 0x17, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x2f, 0x0a,
0x62, 0x75, 0x66, 0x2e, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x52, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e,
0x69, 0x67, 0x12, 0x25, 0x0a, 0x0e, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x5f, 0x64, 0x69, 0x73, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e,
0x61, 0x62, 0x6c, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x63, 0x6f, 0x6e, 0x66, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x52, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x41,
0x69, 0x67, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x0a, 0x0e, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x5f, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65,
0x76, 0x69, 0x63, 0x65, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e,
0x69, 0x63, 0x65, 0x22, 0x41, 0x0a, 0x08, 0x56, 0x6d, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x42, 0x6f, 0x6f, 0x6c, 0x56, 0x61, 0x6c,
0x35, 0x0a, 0x03, 0x65, 0x6e, 0x76, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x23, 0x2e, 0x68, 0x75, 0x65, 0x52, 0x0d, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c,
0x69, 0x67, 0x72, 0x65, 0x73, 0x73, 0x2e, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x18, 0x05, 0x20, 0x03,
0x73, 0x2e, 0x76, 0x31, 0x61, 0x6c, 0x70, 0x68, 0x61, 0x31, 0x2e, 0x45, 0x6e, 0x76, 0x56, 0x61, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x22, 0x41, 0x0a, 0x08, 0x56,
0x72, 0x52, 0x03, 0x65, 0x6e, 0x76, 0x22, 0x7e, 0x0a, 0x06, 0x45, 0x6e, 0x76, 0x56, 0x61, 0x72, 0x6d, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x35, 0x0a, 0x03, 0x65, 0x6e, 0x76, 0x18, 0x01,
0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x23, 0x2e, 0x68, 0x69, 0x67, 0x72, 0x65, 0x73, 0x73, 0x2e, 0x65,
0x6e, 0x61, 0x6d, 0x65, 0x12, 0x4a, 0x0a, 0x0a, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x5f, 0x66, 0x72, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x76, 0x31, 0x61, 0x6c, 0x70, 0x68,
0x6f, 0x6d, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2b, 0x2e, 0x68, 0x69, 0x67, 0x72, 0x65, 0x61, 0x31, 0x2e, 0x45, 0x6e, 0x76, 0x56, 0x61, 0x72, 0x52, 0x03, 0x65, 0x6e, 0x76, 0x22, 0x7e,
0x73, 0x73, 0x2e, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x76, 0x31, 0x0a, 0x06, 0x45, 0x6e, 0x76, 0x56, 0x61, 0x72, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65,
0x61, 0x6c, 0x70, 0x68, 0x61, 0x31, 0x2e, 0x45, 0x6e, 0x76, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x53, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x4a, 0x0a, 0x0a,
0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x09, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x46, 0x72, 0x6f, 0x6d, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x5f, 0x66, 0x72, 0x6f, 0x6d, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e,
0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x32, 0x2b, 0x2e, 0x68, 0x69, 0x67, 0x72, 0x65, 0x73, 0x73, 0x2e, 0x65, 0x78, 0x74, 0x65, 0x6e,
0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x2a, 0x45, 0x0a, 0x0b, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x76, 0x31, 0x61, 0x6c, 0x70, 0x68, 0x61, 0x31, 0x2e, 0x45,
0x50, 0x68, 0x61, 0x73, 0x65, 0x12, 0x15, 0x0a, 0x11, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x6e, 0x76, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x53, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x09, 0x76,
0x46, 0x49, 0x45, 0x44, 0x5f, 0x50, 0x48, 0x41, 0x53, 0x45, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x61, 0x6c, 0x75, 0x65, 0x46, 0x72, 0x6f, 0x6d, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75,
0x41, 0x55, 0x54, 0x48, 0x4e, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x41, 0x55, 0x54, 0x48, 0x5a, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x2a, 0x45,
0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x53, 0x54, 0x41, 0x54, 0x53, 0x10, 0x03, 0x2a, 0x42, 0x0a, 0x0a, 0x0b, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x50, 0x68, 0x61, 0x73, 0x65, 0x12, 0x15, 0x0a,
0x0a, 0x50, 0x75, 0x6c, 0x6c, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x12, 0x16, 0x0a, 0x12, 0x55, 0x11, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x5f, 0x50, 0x48, 0x41,
0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x5f, 0x50, 0x4f, 0x4c, 0x49, 0x43, 0x53, 0x45, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x41, 0x55, 0x54, 0x48, 0x4e, 0x10, 0x01, 0x12,
0x59, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x49, 0x66, 0x4e, 0x6f, 0x74, 0x50, 0x72, 0x65, 0x73, 0x09, 0x0a, 0x05, 0x41, 0x55, 0x54, 0x48, 0x5a, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x53, 0x54,
0x65, 0x6e, 0x74, 0x10, 0x01, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x6c, 0x77, 0x61, 0x79, 0x73, 0x10, 0x41, 0x54, 0x53, 0x10, 0x03, 0x2a, 0x42, 0x0a, 0x0a, 0x50, 0x75, 0x6c, 0x6c, 0x50, 0x6f, 0x6c,
0x02, 0x2a, 0x26, 0x0a, 0x0e, 0x45, 0x6e, 0x76, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x53, 0x6f, 0x75, 0x69, 0x63, 0x79, 0x12, 0x16, 0x0a, 0x12, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49,
0x72, 0x63, 0x65, 0x12, 0x0a, 0x0a, 0x06, 0x49, 0x4e, 0x4c, 0x49, 0x4e, 0x45, 0x10, 0x00, 0x12, 0x45, 0x44, 0x5f, 0x50, 0x4f, 0x4c, 0x49, 0x43, 0x59, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x49,
0x08, 0x0a, 0x04, 0x48, 0x4f, 0x53, 0x54, 0x10, 0x01, 0x2a, 0x2d, 0x0a, 0x0c, 0x46, 0x61, 0x69, 0x66, 0x4e, 0x6f, 0x74, 0x50, 0x72, 0x65, 0x73, 0x65, 0x6e, 0x74, 0x10, 0x01, 0x12, 0x0a, 0x0a,
0x6c, 0x53, 0x74, 0x72, 0x61, 0x74, 0x65, 0x67, 0x79, 0x12, 0x0e, 0x0a, 0x0a, 0x46, 0x41, 0x49, 0x06, 0x41, 0x6c, 0x77, 0x61, 0x79, 0x73, 0x10, 0x02, 0x2a, 0x26, 0x0a, 0x0e, 0x45, 0x6e, 0x76,
0x4c, 0x5f, 0x43, 0x4c, 0x4f, 0x53, 0x45, 0x10, 0x00, 0x12, 0x0d, 0x0a, 0x09, 0x46, 0x41, 0x49, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x53, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x12, 0x0a, 0x0a, 0x06, 0x49,
0x4c, 0x5f, 0x4f, 0x50, 0x45, 0x4e, 0x10, 0x01, 0x42, 0x34, 0x5a, 0x32, 0x67, 0x69, 0x74, 0x68, 0x4e, 0x4c, 0x49, 0x4e, 0x45, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x48, 0x4f, 0x53, 0x54, 0x10,
0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x61, 0x6c, 0x69, 0x62, 0x61, 0x62, 0x61, 0x2f, 0x68, 0x01, 0x2a, 0x2d, 0x0a, 0x0c, 0x46, 0x61, 0x69, 0x6c, 0x53, 0x74, 0x72, 0x61, 0x74, 0x65, 0x67,
0x69, 0x67, 0x72, 0x65, 0x73, 0x73, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x79, 0x12, 0x0e, 0x0a, 0x0a, 0x46, 0x41, 0x49, 0x4c, 0x5f, 0x43, 0x4c, 0x4f, 0x53, 0x45, 0x10,
0x73, 0x69, 0x6f, 0x6e, 0x73, 0x2f, 0x76, 0x31, 0x61, 0x6c, 0x70, 0x68, 0x61, 0x31, 0x62, 0x06, 0x00, 0x12, 0x0d, 0x0a, 0x09, 0x46, 0x41, 0x49, 0x4c, 0x5f, 0x4f, 0x50, 0x45, 0x4e, 0x10, 0x01,
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, 0x42, 0x34, 0x5a, 0x32, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x61,
0x6c, 0x69, 0x62, 0x61, 0x62, 0x61, 0x2f, 0x68, 0x69, 0x67, 0x72, 0x65, 0x73, 0x73, 0x2f, 0x61,
0x70, 0x69, 0x2f, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x2f, 0x76, 0x31,
0x61, 0x6c, 0x70, 0x68, 0x61, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
} }
var ( var (
@@ -804,6 +807,7 @@ var file_extensions_v1alpha1_wasmplugin_proto_goTypes = []interface{}{
(*EnvVar)(nil), // 7: higress.extensions.v1alpha1.EnvVar (*EnvVar)(nil), // 7: higress.extensions.v1alpha1.EnvVar
(*_struct.Struct)(nil), // 8: google.protobuf.Struct (*_struct.Struct)(nil), // 8: google.protobuf.Struct
(*wrappers.Int32Value)(nil), // 9: google.protobuf.Int32Value (*wrappers.Int32Value)(nil), // 9: google.protobuf.Int32Value
(*wrappers.BoolValue)(nil), // 10: google.protobuf.BoolValue
} }
var file_extensions_v1alpha1_wasmplugin_proto_depIdxs = []int32{ var file_extensions_v1alpha1_wasmplugin_proto_depIdxs = []int32{
1, // 0: higress.extensions.v1alpha1.WasmPlugin.image_pull_policy:type_name -> higress.extensions.v1alpha1.PullPolicy 1, // 0: higress.extensions.v1alpha1.WasmPlugin.image_pull_policy:type_name -> higress.extensions.v1alpha1.PullPolicy
@@ -814,14 +818,16 @@ var file_extensions_v1alpha1_wasmplugin_proto_depIdxs = []int32{
6, // 5: higress.extensions.v1alpha1.WasmPlugin.vm_config:type_name -> higress.extensions.v1alpha1.VmConfig 6, // 5: higress.extensions.v1alpha1.WasmPlugin.vm_config:type_name -> higress.extensions.v1alpha1.VmConfig
8, // 6: higress.extensions.v1alpha1.WasmPlugin.default_config:type_name -> google.protobuf.Struct 8, // 6: higress.extensions.v1alpha1.WasmPlugin.default_config:type_name -> google.protobuf.Struct
5, // 7: higress.extensions.v1alpha1.WasmPlugin.match_rules:type_name -> higress.extensions.v1alpha1.MatchRule 5, // 7: higress.extensions.v1alpha1.WasmPlugin.match_rules:type_name -> higress.extensions.v1alpha1.MatchRule
8, // 8: higress.extensions.v1alpha1.MatchRule.config:type_name -> google.protobuf.Struct 10, // 8: higress.extensions.v1alpha1.WasmPlugin.default_config_disable:type_name -> google.protobuf.BoolValue
7, // 9: higress.extensions.v1alpha1.VmConfig.env:type_name -> higress.extensions.v1alpha1.EnvVar 8, // 9: higress.extensions.v1alpha1.MatchRule.config:type_name -> google.protobuf.Struct
2, // 10: higress.extensions.v1alpha1.EnvVar.value_from:type_name -> higress.extensions.v1alpha1.EnvValueSource 10, // 10: higress.extensions.v1alpha1.MatchRule.config_disable:type_name -> google.protobuf.BoolValue
11, // [11:11] is the sub-list for method output_type 7, // 11: higress.extensions.v1alpha1.VmConfig.env:type_name -> higress.extensions.v1alpha1.EnvVar
11, // [11:11] is the sub-list for method input_type 2, // 12: higress.extensions.v1alpha1.EnvVar.value_from:type_name -> higress.extensions.v1alpha1.EnvValueSource
11, // [11:11] is the sub-list for extension type_name 13, // [13:13] is the sub-list for method output_type
11, // [11:11] is the sub-list for extension extendee 13, // [13:13] is the sub-list for method input_type
0, // [0:11] is the sub-list for field type_name 13, // [13:13] is the sub-list for extension type_name
13, // [13:13] is the sub-list for extension extendee
0, // [0:13] is the sub-list for field type_name
} }
func init() { file_extensions_v1alpha1_wasmplugin_proto_init() } func init() { file_extensions_v1alpha1_wasmplugin_proto_init() }

View File

@@ -112,7 +112,7 @@ message WasmPlugin {
// Extended by Higress, matching rules take effect // Extended by Higress, matching rules take effect
repeated MatchRule match_rules = 102; repeated MatchRule match_rules = 102;
// disable the default config // disable the default config
bool default_config_disable = 103; google.protobuf.BoolValue default_config_disable = 103;
} }
// Extended by Higress // Extended by Higress
@@ -120,7 +120,7 @@ message MatchRule {
repeated string ingress = 1; repeated string ingress = 1;
repeated string domain = 2; repeated string domain = 2;
google.protobuf.Struct config = 3; google.protobuf.Struct config = 3;
bool config_disable = 4; google.protobuf.BoolValue config_disable = 4;
repeated string service = 5; repeated string service = 5;
} }

143
docs/architecture.md Normal file
View File

@@ -0,0 +1,143 @@
# Higress 核心组件和原理
Higress 是基于 Envoy 和 Istio 进行二次定制化开发构建和功能增强,同时利用 Envoy 和 Istio 一些插件机制,实现了一个轻量级的网关服务。其包括 3 个核心组件Higress Controller控制器、Higress Gateway网关和 Higress Console控制台
下图概况了其核心工作流程:
![img](./images/img_02_01.png)
本章将重点介绍 Higress 的两个核心组件Higress Controller 和 Higress Gateway。
## 1 Higress Console
Higress Console 是 Higress 网关的管理控制台,主要功能是管理 Higress 网关的路由配置、插件配置等。
### 1.1 Higress Admin SDK
Higress Admin SDK 脱胎于 Higress Console。起初它作为 Higress Console 的一部分,为前端界面提供实际的功能支持。后来考虑到对接外部系统等需求,将配置管理的部分剥离出来,形成一个独立的逻辑组件,便于和各个系统进行对接。目前支持服务来源管理、服务管理、路由管理、域名管理、证书管理、插件管理等功能。
Higress Admin SDK 现在只提供 Java 版本,且要求 JDK 版本不低于 17。具体如何集成请参考 Higress 官方 BLOG [如何使用 Higress Admin SDK 进行配置管理](https://higress.io/zh-cn/blog/admin-sdk-intro)。
## 2 Higress Controller
Higress Controller控制器 是 Higress 的核心组件,其功能主要是实现 Higress 网关的服务发现、动态配置管理以及动态下发配置给数据面。Higress Controller 内部包含两个子组件Discovery 和 Higress Core。
### 2.1 Discovery 组件
Discovery 组件Istio Pilot-Discovery是 Istio 的核心组件负责服务发现、配置管理、证书签发、控制面和数据面之间的通讯和配置下发等。Discovery 内部结构比较复杂,本文只介绍 Discovery 配置管理和服务发现的基本原理,其核心功能的详细介绍可以参考赵化冰老师的 BLOG [Istio Pilot 组件介绍](https://www.zhaohuabing.com/post/2019-10-21-pilot-discovery-code-analysis/)。
Discovery 将 Kubernetes Service、Gateway API 配置等转换成 Istio 配置,然后将所有 Istio 配置合并转成符合 xDS 接口规范的数据结构,通过 GRPC 下发到数据面的 Envoy。其工作原理如下图
![img](./images/img_02_02.png)
#### 2.1.1 Config Controller
Discovery 为了更好管理 Istio 配置来源,提供 `Config Controller` 用于管理各种配置来源,目前支持 4 种类型的 `Config Controller`
- Kubernetes使用 Kubernetes 作为配置信息来源,该方式的直接依赖 Kubernetes 强大的 CRD 机制来存储配置信息,简单方便,是 Istio 最开始使用的配置信息存储方案, 其中包括 `Kubernetes Controller``Gateway API Controller` 两个实现。
- MCPMesh Configuration Protocol使用 Kubernetes 存储配置数据导致了 Istio 和 Kubernetes 的耦合,限制了 Istio 在非 Kubernetes 环境下的运用。为了解决该耦合Istio 社区提出了 MCP。
- Memory一个基于内存的 Config Controller 实现,主要用于测试。
- File一个基于文件的 Config Controller 实现,主要用于测试。
1. Istio 配置
Istio 配置包括:`Gateway``VirtualService``DestinationRule``ServiceEntry``EnvoyFilter``WasmPlugin``WorkloadEntry``WorkloadGroup` 等,可以参考 Istio 官方文档[流量管理](https://istio.io/latest/zh/docs/reference/config/networking/)了解更多配置信息。
2. Gateway API 配置
Gateway API 配置包括:`GatewayClass``Gateway``HttpRoute``TCPRoute``GRPCRoute` 等, 可以参考 Gateway API 官方文档 [Gateway API](https://gateway-api.sigs.k8s.io/api-types/gateway/) 了解更多配置信息。
3. MCP over xDS
Discovery 作为 MCP Client任何实现了 MCP 协议的 Server 都可以通过 MCP 协议向 Discovery 下发配置信息,从而消除了 Istio 和 Kubernetes 之间的耦合, 同时使 Istio 的配置信息处理更加灵活和可扩展。
同时 MCP 是一种基于 xDS 协议的配置管理协议Higress Core 通过实现 MCP 协议,使 Higress Core 成为 Discovery 的 Istio 配置来源。
4. Config Controller 来源配置
`higress-system` 命名空间中,名为 `higress-config` 的 Configmap 中,`mesh` 配置项包含一个 `configSources` 属性用于配置来源。其 Configmap 部分配置项如下:
```yaml
apiVersion: v1
kind: ConfigMap
metadata:
name: higress-config
namespace: higress-system
data:
mesh: |-
accessLogEncoding: TEXT
...
configSources:
- address: xds://127.0.0.1:15051
- address: k8s://
...
meshNetworks: "networks: {}"
```
#### 2.1.2 Service Controller
`Service Controller` 用于管理各种 `Service Registry`,提供服务发现数据,目前 Istio 支持的 `Service Registry` 主要包括:
- Kubernetes对接 Kubernetes Registry可以将 Kubernetes 中定义的 Service 和 Endpoint 采集到 Istio 中。
- Memory一个基于内存的 Service Controller 实现,主要用于测试。
### 2.2 Higress Core 组件
Higress Core 核心逻辑如下图:
![img](./images/img_02_03.png)
Higress Core 内部包含两个核心子组件: Ingress Config 和 Cert Server。
#### 2.2.1 Ingress Config
Ingress Config 包含 6 个控制器,各自负责不同的功能:
- Ingress Controller监听 Ingress 资源,将 Ingress 转换为 Istio 的 Gateway、VirtualService、DestinationRule 等资源。
- Gateway Controller监听 Gateway、VirtualService、DestinationRule 等资源。
- McpBridge Controller根据 McpBridge 的配置,将来自 Nacos、Eureka、Consul、Zookeeper 等外部注册中心或 DNS 的服务信息转换成 Istio ServiceEntry 资源。
- Http2Rpc Controller监听 Http2Rpc 资源,实现 HTTP 协议到 RPC 协议的转换。用户可以通过配置协议转换,将 RPC 服务以 HTTP 接口的形式暴露,从而使用 HTTP 请求调用 RPC 接口。
- WasmPlugin Controller监听 WasmPlugin 资源,将 Higress WasmPlugin 转化为 Istio WasmPlugin。Higress WasmPlugin 在 Istio WasmPlugin 的基础上进行了扩展,支持全局、路由、域名、服务级别的配置。
- ConfigmapMgr监听 Higress 的全局配置 `higress-config` ConfigMap可以根据 tracing、gzip 等配置构造 EnvoyFilter。
#### 2.2.2 Cert Server
Cert Server 管理 Secret 资源和证书自动签发。
## 3 Higress Gateway
Higress Gateway 内部包含两个子组件Pilot Agent 和 Envoy。Pilot Agent 主要负责 Envoy 的启动和配置,同时代理 Envoy xDS 请求到 Discovery。 Envoy 作为数据面,负责接收控制面的配置下发,并代理请求到业务服务。 Pilot Agent 和 Envoy 之间通讯协议是使用 xDS 协议, 通过 Unix Domain SocketUDS进行通信。
Envoy 核心架构如下图:
![img](./images/img_02_04.png)
### 1 Envoy 核心组件
- 下游Downstream:
下游是 Envoy 的客户端,它们负责发起请求并接收 Envoy 的响应。下游通常是最终用户的设备或服务,它们通过 Envoy 代理与后端服务进行通信。
- 上游Upstream:
上游是 Envoy 的后端服务器,它们接收 Envoy 代理的连接和请求。上游提供服务或数据,对来自下游客户端的请求进行处理并返回响应。
- 监听器Listener:
监听器是可以接受来自下游客户端连接的网络地址(如 IP 地址和端口Unix Domain Socket 等。Envoy 支持在单个进程中配置任意数量的监听器。监听器可以通过 `Listener Discovery ServiceLDS`来动态发现和更新。
- 路由Router:
路由器是 Envoy 中连接下游和上游的桥梁。它负责决定如何将监听器接收到的请求路由到适当的集群。路由器根据配置的路由规则如路径、HTTP 标头 等,来确定请求的目标集群,从而实现精确的流量控制和路由。路由器可以通过 `Route Discovery ServiceRDS`来动态发现和更新。
- 集群Cluster:
集群是一组逻辑上相似的服务提供者的集合。集群成员的选择由负载均衡策略决定,确保请求能够均匀或按需分配到不同的服务实例。集群可以通过 `Cluster Discovery ServiceCDS`来动态发现和更新。
- 端点Endpoint:
端点是上游集群中的具体服务实例,可以是 IP 地址和端口号的组合。端点可以通过 `Endpoint Discovery ServiceEDS`来动态发现和更新。
- SSL/TLS:
Envoy 可以通过 `Secret Discovery Service (SDS)` 动态获取监听器和集群所需的 TLS 证书、私钥以及信任的根证书和撤销机制等配置信息。
通过这些组件的协同工作Envoy 能够高效地处理网络请求,提供流量管理、负载均衡、服务发现和动态路由等关键功能。
要详细了解 Envoy 的工作原理,可以参考[Envoy 官方文档](https://www.envoyproxy.io/docs/envoy/latest/intro/intro),最佳的方式可以通过一个请求通过 [Envoy 代理的生命周期](https://www.envoyproxy.io/docs/envoy/latest/intro/life_of_a_request)事件的过程来理解 Envoy 的工作原理。
## 参考
- [1] [Istio Pilot 组件介绍](https://www.zhaohuabing.com/post/2019-10-21-pilot-discovery-code-analysis/)
- [2] [Istio 服务注册插件机制代码解析](https://www.zhaohuabing.com/post/2019-02-18-pilot-service-registry-code-analysis/)
- [3] [Istio Pilot代码深度解析](https://www.zhaohuabing.com/post/2019-10-21-pilot-discovery-code-analysis/)
- [4] [Envoy 官方文档](https://www.envoyproxy.io/docs/envoy/latest/intro/intro)

BIN
docs/images/img_02_01.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 119 KiB

BIN
docs/images/img_02_02.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 131 KiB

BIN
docs/images/img_02_03.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 167 KiB

BIN
docs/images/img_02_04.png Normal file
View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 127 KiB

View File

@@ -1,5 +1,5 @@
apiVersion: v2 apiVersion: v2
appVersion: 2.0.4 appVersion: 2.0.6
description: Helm chart for deploying higress gateways description: Helm chart for deploying higress gateways
icon: https://higress.io/img/higress_logo_small.png icon: https://higress.io/img/higress_logo_small.png
home: http://higress.io/ home: http://higress.io/
@@ -10,4 +10,4 @@ name: higress-core
sources: sources:
- http://github.com/alibaba/higress - http://github.com/alibaba/higress
type: application type: application
version: 2.0.4 version: 2.0.6

View File

@@ -9,7 +9,7 @@
accessLogFile: "/dev/stdout" accessLogFile: "/dev/stdout"
{{- end }} {{- end }}
ingressControllerMode: "OFF" ingressControllerMode: "OFF"
accessLogFormat: '{"authority":"%REQ(X-ENVOY-ORIGINAL-HOST?:AUTHORITY)%","bytes_received":"%BYTES_RECEIVED%","bytes_sent":"%BYTES_SENT%","downstream_local_address":"%DOWNSTREAM_LOCAL_ADDRESS%","downstream_remote_address":"%DOWNSTREAM_REMOTE_ADDRESS%","duration":"%DURATION%","istio_policy_status":"%DYNAMIC_METADATA(istio.mixer:status)%","method":"%REQ(:METHOD)%","path":"%REQ(X-ENVOY-ORIGINAL-PATH?:PATH)%","protocol":"%PROTOCOL%","request_id":"%REQ(X-REQUEST-ID)%","requested_server_name":"%REQUESTED_SERVER_NAME%","response_code":"%RESPONSE_CODE%","response_flags":"%RESPONSE_FLAGS%","route_name":"%ROUTE_NAME%","start_time":"%START_TIME%","trace_id":"%REQ(X-B3-TRACEID)%","upstream_cluster":"%UPSTREAM_CLUSTER%","upstream_host":"%UPSTREAM_HOST%","upstream_local_address":"%UPSTREAM_LOCAL_ADDRESS%","upstream_service_time":"%RESP(X-ENVOY-UPSTREAM-SERVICE-TIME)%","upstream_transport_failure_reason":"%UPSTREAM_TRANSPORT_FAILURE_REASON%","user_agent":"%REQ(USER-AGENT)%","x_forwarded_for":"%REQ(X-FORWARDED-FOR)%","response_code_details":"%RESPONSE_CODE_DETAILS%"} accessLogFormat: '{"ai_log":"%FILTER_STATE(wasm.ai_log:PLAIN)%","authority":"%REQ(X-ENVOY-ORIGINAL-HOST?:AUTHORITY)%","bytes_received":"%BYTES_RECEIVED%","bytes_sent":"%BYTES_SENT%","downstream_local_address":"%DOWNSTREAM_LOCAL_ADDRESS%","downstream_remote_address":"%DOWNSTREAM_REMOTE_ADDRESS%","duration":"%DURATION%","istio_policy_status":"%DYNAMIC_METADATA(istio.mixer:status)%","method":"%REQ(:METHOD)%","path":"%REQ(X-ENVOY-ORIGINAL-PATH?:PATH)%","protocol":"%PROTOCOL%","request_id":"%REQ(X-REQUEST-ID)%","requested_server_name":"%REQUESTED_SERVER_NAME%","response_code":"%RESPONSE_CODE%","response_flags":"%RESPONSE_FLAGS%","route_name":"%ROUTE_NAME%","start_time":"%START_TIME%","trace_id":"%REQ(X-B3-TRACEID)%","upstream_cluster":"%UPSTREAM_CLUSTER%","upstream_host":"%UPSTREAM_HOST%","upstream_local_address":"%UPSTREAM_LOCAL_ADDRESS%","upstream_service_time":"%RESP(X-ENVOY-UPSTREAM-SERVICE-TIME)%","upstream_transport_failure_reason":"%UPSTREAM_TRANSPORT_FAILURE_REASON%","user_agent":"%REQ(USER-AGENT)%","x_forwarded_for":"%REQ(X-FORWARDED-FOR)%","response_code_details":"%RESPONSE_CODE_DETAILS%"}
' '
dnsRefreshRate: 200s dnsRefreshRate: 200s

View File

@@ -136,6 +136,10 @@ spec:
periodSeconds: 3 periodSeconds: 3
timeoutSeconds: 5 timeoutSeconds: 5
env: env:
- name: ENABLE_PUSH_ALL_MCP_CLUSTERS
value: "{{ .Values.global.enablePushAllMCPClusters }}"
- name: PILOT_ENABLE_LDS_CACHE
value: "{{ .Values.global.enableLDSCache }}"
- name: PILOT_ENABLE_QUIC_LISTENERS - name: PILOT_ENABLE_QUIC_LISTENERS
value: "true" value: "true"
- name: VALIDATION_WEBHOOK_CONFIG_NAME - name: VALIDATION_WEBHOOK_CONFIG_NAME

View File

@@ -1,7 +1,8 @@
{{- if eq .Values.gateway.kind "DaemonSet" -}} {{- if eq .Values.gateway.kind "DaemonSet" -}}
{{- $o11y := .Values.global.o11y }} {{- $o11y := .Values.global.o11y }}
{{- $unprivilegedPortSupported := true }} {{- if eq .Values.gateway.unprivilegedPortSupported nil -}}
{{- range $index, $node := (lookup "v1" "Node" "default" "").items }} {{- $unprivilegedPortSupported := true }}
{{- range $index, $node := (lookup "v1" "Node" "default" "").items }}
{{- $kernelVersion := $node.status.nodeInfo.kernelVersion }} {{- $kernelVersion := $node.status.nodeInfo.kernelVersion }}
{{- if $kernelVersion }} {{- if $kernelVersion }}
{{- $kernelVersion = regexFind "^(\\d+\\.\\d+\\.\\d+)" $kernelVersion }} {{- $kernelVersion = regexFind "^(\\d+\\.\\d+\\.\\d+)" $kernelVersion }}
@@ -9,8 +10,9 @@
{{- $unprivilegedPortSupported = false }} {{- $unprivilegedPortSupported = false }}
{{- end }} {{- end }}
{{- end }} {{- end }}
{{- end -}}
{{- $_ := set .Values.gateway "unprivilegedPortSupported" $unprivilegedPortSupported -}}
{{- end -}} {{- end -}}
{{- $_ := set .Values.gateway "unprivilegedPortSupported" $unprivilegedPortSupported -}}
apiVersion: apps/v1 apiVersion: apps/v1
kind: DaemonSet kind: DaemonSet

View File

@@ -1,6 +1,7 @@
{{- if eq .Values.gateway.kind "Deployment" -}} {{- if eq .Values.gateway.kind "Deployment" -}}
{{- $unprivilegedPortSupported := true }} {{- if eq .Values.gateway.unprivilegedPortSupported nil -}}
{{- range $index, $node := (lookup "v1" "Node" "default" "").items }} {{- $unprivilegedPortSupported := true }}
{{- range $index, $node := (lookup "v1" "Node" "default" "").items }}
{{- $kernelVersion := $node.status.nodeInfo.kernelVersion }} {{- $kernelVersion := $node.status.nodeInfo.kernelVersion }}
{{- if $kernelVersion }} {{- if $kernelVersion }}
{{- $kernelVersion = regexFind "^(\\d+\\.\\d+\\.\\d+)" $kernelVersion }} {{- $kernelVersion = regexFind "^(\\d+\\.\\d+\\.\\d+)" $kernelVersion }}
@@ -8,8 +9,9 @@
{{- $unprivilegedPortSupported = false }} {{- $unprivilegedPortSupported = false }}
{{- end }} {{- end }}
{{- end }} {{- end }}
{{- end -}}
{{- $_ := set .Values.gateway "unprivilegedPortSupported" $unprivilegedPortSupported -}}
{{- end -}} {{- end -}}
{{- $_ := set .Values.gateway "unprivilegedPortSupported" $unprivilegedPortSupported -}}
apiVersion: apps/v1 apiVersion: apps/v1
kind: Deployment kind: Deployment

View File

@@ -3,7 +3,9 @@ global:
enableH3: false enableH3: false
enableIPv6: false enableIPv6: false
enableProxyProtocol: false enableProxyProtocol: false
liteMetrics: true enableLDSCache: true
enablePushAllMCPClusters: true
liteMetrics: false
xdsMaxRecvMsgSize: "104857600" xdsMaxRecvMsgSize: "104857600"
defaultUpstreamConcurrencyThreshold: 10000 defaultUpstreamConcurrencyThreshold: 10000
enableSRDS: true enableSRDS: true
@@ -465,6 +467,7 @@ gateway:
# On Kubernetes 1.22+, this only requires the `net.ipv4.ip_unprivileged_port_start` sysctl. # On Kubernetes 1.22+, this only requires the `net.ipv4.ip_unprivileged_port_start` sysctl.
securityContext: ~ securityContext: ~
containerSecurityContext: ~ containerSecurityContext: ~
unprivilegedPortSupported: ~
service: service:
# -- Type of service. Set to "None" to disable the service entirely # -- Type of service. Set to "None" to disable the service entirely

View File

@@ -1,9 +1,9 @@
dependencies: dependencies:
- name: higress-core - name: higress-core
repository: file://../core repository: file://../core
version: 2.0.4 version: 2.0.6
- name: higress-console - name: higress-console
repository: https://higress.io/helm-charts/ repository: https://higress.io/helm-charts/
version: 1.4.6 version: 2.0.2
digest: sha256:ec570ac7ae8a6de976e7ffafaadae4a33beeabfb4b13debe63e0cfa100e2eb8c digest: sha256:9c84a628df434c4bf23ec10d62ad7ddf4b15957f797b01bbaa492ede33d87003
generated: "2024-12-06T11:34:04.628976+08:00" generated: "2025-01-17T15:10:43.589701962+08:00"

View File

@@ -1,5 +1,5 @@
apiVersion: v2 apiVersion: v2
appVersion: 2.0.4 appVersion: 2.0.6
description: Helm chart for deploying Higress gateways description: Helm chart for deploying Higress gateways
icon: https://higress.io/img/higress_logo_small.png icon: https://higress.io/img/higress_logo_small.png
home: http://higress.io/ home: http://higress.io/
@@ -12,9 +12,9 @@ sources:
dependencies: dependencies:
- name: higress-core - name: higress-core
repository: "file://../core" repository: "file://../core"
version: 2.0.4 version: 2.0.6
- name: higress-console - name: higress-console
repository: "https://higress.io/helm-charts/" repository: "https://higress.io/helm-charts/"
version: 1.4.6 version: 2.0.2
type: application type: application
version: 2.0.4 version: 2.0.6

View File

@@ -149,6 +149,7 @@ The command removes all the Kubernetes components associated with the chart and
| gateway.serviceAccount.name | string | `""` | The name of the service account to use. If not set, the release name is used | | gateway.serviceAccount.name | string | `""` | The name of the service account to use. If not set, the release name is used |
| gateway.tag | string | `""` | | | gateway.tag | string | `""` | |
| gateway.tolerations | list | `[]` | | | gateway.tolerations | list | `[]` | |
| gateway.unprivilegedPortSupported | string | `nil` | |
| global.autoscalingv2API | bool | `true` | whether to use autoscaling/v2 template for HPA settings for internal usage only, not to be configured by users. | | global.autoscalingv2API | bool | `true` | whether to use autoscaling/v2 template for HPA settings for internal usage only, not to be configured by users. |
| global.caAddress | string | `""` | The customized CA address to retrieve certificates for the pods in the cluster. CSR clients such as the Istio Agent and ingress gateways can use this to specify the CA endpoint. If not set explicitly, default to the Istio discovery address. | | global.caAddress | string | `""` | The customized CA address to retrieve certificates for the pods in the cluster. CSR clients such as the Istio Agent and ingress gateways can use this to specify the CA endpoint. If not set explicitly, default to the Istio discovery address. |
| global.caName | string | `""` | The name of the CA for workload certificates. For example, when caName=GkeWorkloadCertificate, GKE workload certificates will be used as the certificates for workloads. The default value is "" and when caName="", the CA will be configured by other mechanisms (e.g., environmental variable CA_PROVIDER). | | global.caName | string | `""` | The name of the CA for workload certificates. For example, when caName=GkeWorkloadCertificate, GKE workload certificates will be used as the certificates for workloads. The default value is "" and when caName="", the CA will be configured by other mechanisms (e.g., environmental variable CA_PROVIDER). |
@@ -161,7 +162,9 @@ The command removes all the Kubernetes components associated with the chart and
| global.enableH3 | bool | `false` | | | global.enableH3 | bool | `false` | |
| global.enableIPv6 | bool | `false` | | | global.enableIPv6 | bool | `false` | |
| global.enableIstioAPI | bool | `true` | If true, Higress Controller will monitor istio resources as well | | global.enableIstioAPI | bool | `true` | If true, Higress Controller will monitor istio resources as well |
| global.enableLDSCache | bool | `true` | |
| global.enableProxyProtocol | bool | `false` | | | global.enableProxyProtocol | bool | `false` | |
| global.enablePushAllMCPClusters | bool | `true` | |
| global.enableSRDS | bool | `true` | | | global.enableSRDS | bool | `true` | |
| global.enableStatus | bool | `true` | If true, Higress Controller will update the status field of Ingress resources. When migrating from Nginx Ingress, in order to avoid status field of Ingress objects being overwritten, this parameter needs to be set to false, so Higress won't write the entry IP to the status field of the corresponding Ingress object. | | global.enableStatus | bool | `true` | If true, Higress Controller will update the status field of Ingress resources. When migrating from Nginx Ingress, in order to avoid status field of Ingress objects being overwritten, this parameter needs to be set to false, so Higress won't write the entry IP to the status field of the corresponding Ingress object. |
| global.externalIstiod | bool | `false` | Configure a remote cluster data plane controlled by an external istiod. When set to true, istiod is not deployed locally and only a subset of the other discovery charts are enabled. | | global.externalIstiod | bool | `false` | Configure a remote cluster data plane controlled by an external istiod. When set to true, istiod is not deployed locally and only a subset of the other discovery charts are enabled. |
@@ -174,7 +177,7 @@ The command removes all the Kubernetes components associated with the chart and
| global.istiod | object | `{"enableAnalysis":false}` | Enabled by default in master for maximising testing. | | global.istiod | object | `{"enableAnalysis":false}` | Enabled by default in master for maximising testing. |
| global.jwtPolicy | string | `"third-party-jwt"` | Configure the policy for validating JWT. Currently, two options are supported: "third-party-jwt" and "first-party-jwt". | | global.jwtPolicy | string | `"third-party-jwt"` | Configure the policy for validating JWT. Currently, two options are supported: "third-party-jwt" and "first-party-jwt". |
| global.kind | bool | `false` | | | global.kind | bool | `false` | |
| global.liteMetrics | bool | `true` | | | global.liteMetrics | bool | `false` | |
| global.local | bool | `false` | When deploying to a local cluster (e.g.: kind cluster), set this to true. | | global.local | bool | `false` | When deploying to a local cluster (e.g.: kind cluster), set this to true. |
| global.logAsJson | bool | `false` | | | global.logAsJson | bool | `false` | |
| global.logging | object | `{"level":"default:info"}` | Comma-separated minimum per-scope logging level of messages to output, in the form of <scope>:<level>,<scope>:<level> The control plane has different scopes depending on component, but can configure default log level across all components If empty, default scope and level will be used as configured in code | | global.logging | object | `{"level":"default:info"}` | Comma-separated minimum per-scope logging level of messages to output, in the form of <scope>:<level>,<scope>:<level> The control plane has different scopes depending on component, but can configure default log level across all components If empty, default scope and level will be used as configured in code |

View File

@@ -881,7 +881,7 @@ func (m *IngressConfig) convertIstioWasmPlugin(obj *higressext.WasmPlugin) (*ext
if result.PluginConfig != nil { if result.PluginConfig != nil {
return result, nil return result, nil
} }
if !obj.DefaultConfigDisable { if !isBoolValueTrue(obj.DefaultConfigDisable) {
result.PluginConfig = obj.DefaultConfig result.PluginConfig = obj.DefaultConfig
} }
hasValidRule := false hasValidRule := false
@@ -893,7 +893,7 @@ func (m *IngressConfig) convertIstioWasmPlugin(obj *higressext.WasmPlugin) (*ext
} }
var ruleValues []*_struct.Value var ruleValues []*_struct.Value
for _, rule := range obj.MatchRules { for _, rule := range obj.MatchRules {
if rule.ConfigDisable { if isBoolValueTrue(rule.ConfigDisable) {
continue continue
} }
if rule.Config == nil { if rule.Config == nil {
@@ -982,13 +982,17 @@ func (m *IngressConfig) convertIstioWasmPlugin(obj *higressext.WasmPlugin) (*ext
} }
} }
} }
if !hasValidRule && obj.DefaultConfigDisable { if !hasValidRule && isBoolValueTrue(obj.DefaultConfigDisable) {
return nil, nil return nil, nil
} }
return result, nil return result, nil
} }
func isBoolValueTrue(b *wrappers.BoolValue) bool {
return b != nil && b.Value
}
func (m *IngressConfig) AddOrUpdateWasmPlugin(clusterNamespacedName util.ClusterNamespacedName) { func (m *IngressConfig) AddOrUpdateWasmPlugin(clusterNamespacedName util.ClusterNamespacedName) {
if clusterNamespacedName.Namespace != m.namespace { if clusterNamespacedName.Namespace != m.namespace {
return return

View File

@@ -15,6 +15,7 @@
package annotations package annotations
import ( import (
"fmt"
"strings" "strings"
networking "istio.io/api/networking/v1alpha3" networking "istio.io/api/networking/v1alpha3"
@@ -27,9 +28,11 @@ import (
) )
const ( const (
authTLSSecret = "auth-tls-secret" authTLSSecret = "auth-tls-secret"
sslCipher = "ssl-cipher" sslCipher = "ssl-cipher"
gatewaySdsCaSuffix = "-cacert" gatewaySdsCaSuffix = "-cacert"
annotationMinTLSVersion = "tls-min-protocol-version"
annotationMaxTLSVersion = "tls-max-protocol-version"
) )
var ( var (
@@ -41,6 +44,8 @@ type DownstreamTLSConfig struct {
CipherSuites []string CipherSuites []string
Mode networking.ServerTLSSettings_TLSmode Mode networking.ServerTLSSettings_TLSmode
CASecretName types.NamespacedName CASecretName types.NamespacedName
MinVersion string
MaxVersion string
} }
type downstreamTLS struct{} type downstreamTLS struct{}
@@ -82,6 +87,14 @@ func (d downstreamTLS) Parse(annotations Annotations, config *Ingress, _ *Global
downstreamTLSConfig.CipherSuites = validCipherSuite downstreamTLSConfig.CipherSuites = validCipherSuite
} }
if minVersion, err := annotations.ParseStringASAP(annotationMinTLSVersion); err == nil {
downstreamTLSConfig.MinVersion = minVersion
}
if maxVersion, err := annotations.ParseStringASAP(annotationMaxTLSVersion); err == nil {
downstreamTLSConfig.MaxVersion = maxVersion
}
return nil return nil
} }
@@ -107,11 +120,44 @@ func (d downstreamTLS) ApplyGateway(gateway *networking.Gateway, config *Ingress
if len(downstreamTLSConfig.CipherSuites) != 0 { if len(downstreamTLSConfig.CipherSuites) != 0 {
server.Tls.CipherSuites = downstreamTLSConfig.CipherSuites server.Tls.CipherSuites = downstreamTLSConfig.CipherSuites
} }
if downstreamTLSConfig.MinVersion != "" {
if version, err := convertTLSVersion(downstreamTLSConfig.MinVersion); err != nil {
IngressLog.Errorf("Invalid minimum TLS version: %v", err)
} else {
server.Tls.MinProtocolVersion = version
}
}
if downstreamTLSConfig.MaxVersion != "" {
if version, err := convertTLSVersion(downstreamTLSConfig.MaxVersion); err != nil {
IngressLog.Errorf("Invalid maximum TLS version: %v", err)
} else {
server.Tls.MaxProtocolVersion = version
}
}
} }
} }
} }
func needDownstreamTLS(annotations Annotations) bool { func needDownstreamTLS(annotations Annotations) bool {
return annotations.HasASAP(sslCipher) || return annotations.HasASAP(sslCipher) ||
annotations.HasASAP(authTLSSecret) annotations.HasASAP(authTLSSecret) ||
annotations.HasASAP(annotationMinTLSVersion) ||
annotations.HasASAP(annotationMaxTLSVersion)
}
func convertTLSVersion(version string) (networking.ServerTLSSettings_TLSProtocol, error) {
switch version {
case "TLSv1.0":
return networking.ServerTLSSettings_TLSV1_0, nil
case "TLSv1.1":
return networking.ServerTLSSettings_TLSV1_1, nil
case "TLSv1.2":
return networking.ServerTLSSettings_TLSV1_2, nil
case "TLSv1.3":
return networking.ServerTLSSettings_TLSV1_3, nil
}
return networking.ServerTLSSettings_TLS_AUTO, fmt.Errorf("invalid TLS version: %s. Valid values are: TLSv1.0, TLSv1.1, TLSv1.2, TLSv1.3", version)
} }

View File

@@ -26,11 +26,15 @@ var parser = downstreamTLS{}
func TestParse(t *testing.T) { func TestParse(t *testing.T) {
testCases := []struct { testCases := []struct {
name string
input map[string]string input map[string]string
expect *DownstreamTLSConfig expect *DownstreamTLSConfig
}{ }{
{},
{ {
name: "empty config",
},
{
name: "ssl cipher only",
input: map[string]string{ input: map[string]string{
buildNginxAnnotationKey(sslCipher): "ECDHE-RSA-AES256-GCM-SHA384:AES128-SHA", buildNginxAnnotationKey(sslCipher): "ECDHE-RSA-AES256-GCM-SHA384:AES128-SHA",
}, },
@@ -40,9 +44,24 @@ func TestParse(t *testing.T) {
}, },
}, },
{ {
name: "with TLS version config",
input: map[string]string{ input: map[string]string{
buildNginxAnnotationKey(authTLSSecret): "test", buildNginxAnnotationKey(annotationMinTLSVersion): "TLSv1.2",
buildNginxAnnotationKey(sslCipher): "ECDHE-RSA-AES256-GCM-SHA384:AES128-SHA", buildNginxAnnotationKey(annotationMaxTLSVersion): "TLSv1.3",
},
expect: &DownstreamTLSConfig{
Mode: networking.ServerTLSSettings_SIMPLE,
MinVersion: "TLSv1.2",
MaxVersion: "TLSv1.3",
},
},
{
name: "complete config",
input: map[string]string{
buildNginxAnnotationKey(authTLSSecret): "test",
buildNginxAnnotationKey(sslCipher): "ECDHE-RSA-AES256-GCM-SHA384:AES128-SHA",
buildNginxAnnotationKey(annotationMinTLSVersion): "TLSv1.2",
buildNginxAnnotationKey(annotationMaxTLSVersion): "TLSv1.3",
}, },
expect: &DownstreamTLSConfig{ expect: &DownstreamTLSConfig{
CASecretName: types.NamespacedName{ CASecretName: types.NamespacedName{
@@ -51,34 +70,79 @@ func TestParse(t *testing.T) {
}, },
Mode: networking.ServerTLSSettings_MUTUAL, Mode: networking.ServerTLSSettings_MUTUAL,
CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384", "AES128-SHA"}, CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384", "AES128-SHA"},
}, MinVersion: "TLSv1.2",
}, MaxVersion: "TLSv1.3",
{
input: map[string]string{
buildHigressAnnotationKey(authTLSSecret): "test/foo",
DefaultAnnotationsPrefix + "/" + sslCipher: "ECDHE-RSA-AES256-GCM-SHA384:AES128-SHA",
},
expect: &DownstreamTLSConfig{
CASecretName: types.NamespacedName{
Namespace: "test",
Name: "foo",
},
Mode: networking.ServerTLSSettings_MUTUAL,
CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384", "AES128-SHA"},
}, },
}, },
} }
for _, testCase := range testCases { for _, tc := range testCases {
t.Run("", func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
config := &Ingress{ config := &Ingress{
Meta: Meta{ Meta: Meta{
Namespace: "foo", Namespace: "foo",
}, },
} }
_ = parser.Parse(testCase.input, config, nil) err := parser.Parse(tc.input, config, nil)
if !reflect.DeepEqual(testCase.expect, config.DownstreamTLS) { if err != nil {
t.Fatalf("Should be equal") t.Fatalf("Parse failed: %v", err)
}
if !reflect.DeepEqual(tc.expect, config.DownstreamTLS) {
t.Fatalf("Parse result mismatch:\nExpect: %+v\nGot: %+v", tc.expect, config.DownstreamTLS)
}
})
}
}
func TestConvertTLSVersion(t *testing.T) {
testCases := []struct {
name string
version string
expect networking.ServerTLSSettings_TLSProtocol
wantErr bool
}{
{
name: "TLS 1.0",
version: "TLSv1.0",
expect: networking.ServerTLSSettings_TLSV1_0,
},
{
name: "TLS 1.1",
version: "TLSv1.1",
expect: networking.ServerTLSSettings_TLSV1_1,
},
{
name: "TLS 1.2",
version: "TLSv1.2",
expect: networking.ServerTLSSettings_TLSV1_2,
},
{
name: "TLS 1.3",
version: "TLSv1.3",
expect: networking.ServerTLSSettings_TLSV1_3,
},
{
name: "invalid version",
version: "invalid",
expect: networking.ServerTLSSettings_TLS_AUTO,
wantErr: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result, err := convertTLSVersion(tc.version)
if tc.wantErr {
if err == nil {
t.Error("Expected error but got none")
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if result != tc.expect {
t.Errorf("Expected %v but got %v", tc.expect, result)
}
} }
}) })
} }
@@ -86,11 +150,13 @@ func TestParse(t *testing.T) {
func TestApplyGateway(t *testing.T) { func TestApplyGateway(t *testing.T) {
testCases := []struct { testCases := []struct {
name string
input *networking.Gateway input *networking.Gateway
config *Ingress config *Ingress
expect *networking.Gateway expect *networking.Gateway
}{ }{
{ {
name: "apply TLS version",
input: &networking.Gateway{ input: &networking.Gateway{
Servers: []*networking.Server{ Servers: []*networking.Server{
{ {
@@ -105,7 +171,8 @@ func TestApplyGateway(t *testing.T) {
}, },
config: &Ingress{ config: &Ingress{
DownstreamTLS: &DownstreamTLSConfig{ DownstreamTLS: &DownstreamTLSConfig{
CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384"}, MinVersion: "TLSv1.2",
MaxVersion: "TLSv1.3",
}, },
}, },
expect: &networking.Gateway{ expect: &networking.Gateway{
@@ -115,14 +182,16 @@ func TestApplyGateway(t *testing.T) {
Protocol: "HTTPS", Protocol: "HTTPS",
}, },
Tls: &networking.ServerTLSSettings{ Tls: &networking.ServerTLSSettings{
Mode: networking.ServerTLSSettings_SIMPLE, Mode: networking.ServerTLSSettings_SIMPLE,
CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384"}, MinProtocolVersion: networking.ServerTLSSettings_TLSV1_2,
MaxProtocolVersion: networking.ServerTLSSettings_TLSV1_3,
}, },
}, },
}, },
}, },
}, },
{ {
name: "complete config",
input: &networking.Gateway{ input: &networking.Gateway{
Servers: []*networking.Server{ Servers: []*networking.Server{
{ {
@@ -144,24 +213,28 @@ func TestApplyGateway(t *testing.T) {
}, },
Mode: networking.ServerTLSSettings_MUTUAL, Mode: networking.ServerTLSSettings_MUTUAL,
CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384"}, CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384"},
MinVersion: "TLSv1.2",
MaxVersion: "TLSv1.3",
}, },
}, },
expect: &networking.Gateway{ expect: &networking.Gateway{
Servers: []*networking.Server{ Servers: []*networking.Server{
{ {Port: &networking.Port{
Port: &networking.Port{ Protocol: "HTTPS",
Protocol: "HTTPS", },
},
Tls: &networking.ServerTLSSettings{ Tls: &networking.ServerTLSSettings{
CredentialName: "kubernetes-ingress://cluster/foo/bar", CredentialName: "kubernetes-ingress://cluster/foo/bar",
Mode: networking.ServerTLSSettings_MUTUAL, Mode: networking.ServerTLSSettings_MUTUAL,
CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384"}, CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384"},
MinProtocolVersion: networking.ServerTLSSettings_TLSV1_2,
MaxProtocolVersion: networking.ServerTLSSettings_TLSV1_3,
}, },
}, },
}, },
}, },
}, },
{ {
name: "invalid TLS version",
input: &networking.Gateway{ input: &networking.Gateway{
Servers: []*networking.Server{ Servers: []*networking.Server{
{ {
@@ -169,20 +242,15 @@ func TestApplyGateway(t *testing.T) {
Protocol: "HTTPS", Protocol: "HTTPS",
}, },
Tls: &networking.ServerTLSSettings{ Tls: &networking.ServerTLSSettings{
Mode: networking.ServerTLSSettings_SIMPLE, Mode: networking.ServerTLSSettings_SIMPLE,
CredentialName: "kubernetes-ingress://cluster/foo/bar",
}, },
}, },
}, },
}, },
config: &Ingress{ config: &Ingress{
DownstreamTLS: &DownstreamTLSConfig{ DownstreamTLS: &DownstreamTLSConfig{
CASecretName: types.NamespacedName{ MinVersion: "invalid",
Namespace: "foo", MaxVersion: "invalid",
Name: "bar-cacert",
},
Mode: networking.ServerTLSSettings_MUTUAL,
CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384"},
}, },
}, },
expect: &networking.Gateway{ expect: &networking.Gateway{
@@ -192,48 +260,10 @@ func TestApplyGateway(t *testing.T) {
Protocol: "HTTPS", Protocol: "HTTPS",
}, },
Tls: &networking.ServerTLSSettings{ Tls: &networking.ServerTLSSettings{
CredentialName: "kubernetes-ingress://cluster/foo/bar", Mode: networking.ServerTLSSettings_SIMPLE,
Mode: networking.ServerTLSSettings_MUTUAL, // Invalid versions should default to TLS_AUTO
CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384"}, MinProtocolVersion: networking.ServerTLSSettings_TLS_AUTO,
}, MaxProtocolVersion: networking.ServerTLSSettings_TLS_AUTO,
},
},
},
},
{
input: &networking.Gateway{
Servers: []*networking.Server{
{
Port: &networking.Port{
Protocol: "HTTPS",
},
Tls: &networking.ServerTLSSettings{
Mode: networking.ServerTLSSettings_SIMPLE,
CredentialName: "kubernetes-ingress://cluster/foo/bar",
},
},
},
},
config: &Ingress{
DownstreamTLS: &DownstreamTLSConfig{
CASecretName: types.NamespacedName{
Namespace: "bar",
Name: "foo",
},
Mode: networking.ServerTLSSettings_MUTUAL,
CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384"},
},
},
expect: &networking.Gateway{
Servers: []*networking.Server{
{
Port: &networking.Port{
Protocol: "HTTPS",
},
Tls: &networking.ServerTLSSettings{
CredentialName: "kubernetes-ingress://cluster/foo/bar",
Mode: networking.ServerTLSSettings_SIMPLE,
CipherSuites: []string{"ECDHE-RSA-AES256-GCM-SHA384"},
}, },
}, },
}, },
@@ -241,11 +271,59 @@ func TestApplyGateway(t *testing.T) {
}, },
} }
for _, testCase := range testCases { for _, tc := range testCases {
t.Run("", func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
parser.ApplyGateway(testCase.input, testCase.config) parser.ApplyGateway(tc.input, tc.config)
if !reflect.DeepEqual(testCase.input, testCase.expect) { if !reflect.DeepEqual(tc.input, tc.expect) {
t.Fatalf("Should be equal") t.Fatalf("ApplyGateway result mismatch for %s:\nExpect: %+v\nGot: %+v",
tc.name, tc.expect, tc.input)
}
})
}
}
func TestNeedDownstreamTLS(t *testing.T) {
testCases := []struct {
name string
annotations map[string]string
expect bool
}{
{
name: "empty annotations",
annotations: map[string]string{},
expect: false,
},
{
name: "with ssl cipher",
annotations: map[string]string{
buildNginxAnnotationKey(sslCipher): "ECDHE-RSA-AES256-GCM-SHA384",
},
expect: true,
},
{
name: "with TLS version",
annotations: map[string]string{
buildNginxAnnotationKey(annotationMinTLSVersion): "TLSv1.2",
},
expect: true,
},
{
name: "with multiple TLS configs",
annotations: map[string]string{
buildNginxAnnotationKey(sslCipher): "ECDHE-RSA-AES256-GCM-SHA384",
buildNginxAnnotationKey(annotationMinTLSVersion): "TLSv1.2",
buildNginxAnnotationKey(annotationMaxTLSVersion): "TLSv1.3",
},
expect: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := needDownstreamTLS(tc.annotations)
if result != tc.expect {
t.Errorf("needDownstreamTLS() for %s = %v, want %v",
tc.name, result, tc.expect)
} }
}) })
} }

View File

@@ -1,17 +1,15 @@
## 功能说明 # 功能说明
`model-mapper`插件实现了基于LLM协议中的model参数路由的功能 `model-mapper`插件实现了基于LLM协议中的model参数路由的功能
## 配置字段 # 配置字段
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | | 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- | | ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
| `modelKey` | string | 选填 | model | 请求body中model参数的位置 | | `modelKey` | string | 选填 | model | 请求body中model参数的位置 |
| `modelMapping` | map of string | 选填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。<br/>1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;<br/>2. 支持使用 "*" 为键来配置通用兜底映射关系;<br/>3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 | | `modelMapping` | map of string | 选填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。<br/>1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;<br/>2. 支持使用 "*" 为键来配置通用兜底映射关系;<br/>3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 |
| `enableOnPathSuffix` | array of string | 选填 | ["/v1/chat/completions"] | 只对这些特定路径后缀的请求生效 ## 运行属性 | `enableOnPathSuffix` | array of string | 选填 | ["/v1/chat/completions"] | 只对这些特定路径后缀的请求生效 |
插件执行阶段:认证阶段
插件执行优先级800
|
## 效果说明 ## 效果说明
如下配置 如下配置

View File

@@ -1,7 +1,7 @@
## 功能说明 # 功能说明
`model-router`插件实现了基于LLM协议中的model参数路由的功能 `model-router`插件实现了基于LLM协议中的model参数路由的功能
## 配置字段 # 配置字段
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | | 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- | | ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |

View File

@@ -26,6 +26,7 @@ proxy_wasm_cc_binary(
"@com_google_absl//absl/time", "@com_google_absl//absl/time",
"//common:json_util", "//common:json_util",
"//common:http_util", "//common:http_util",
"//common:regex_util",
"//common:rule_util", "//common:rule_util",
], ],
) )
@@ -44,6 +45,7 @@ cc_library(
"//common:json_util", "//common:json_util",
"@proxy_wasm_cpp_host//:lib", "@proxy_wasm_cpp_host//:lib",
"//common:http_util_nullvm", "//common:http_util_nullvm",
"//common:regex_util",
"//common:rule_util_nullvm", "//common:rule_util_nullvm",
], ],
) )

View File

@@ -1,31 +1,22 @@
--- # 功能说明
title: 请求屏蔽
keywords: [higress,request block]
description: 请求屏蔽插件配置参考
---
## 功能说明
`request-block`插件实现了基于 URL、请求头等特征屏蔽 HTTP 请求,可以用于防护部分站点资源不对外部暴露 `request-block`插件实现了基于 URL、请求头等特征屏蔽 HTTP 请求,可以用于防护部分站点资源不对外部暴露
## 运行属性 # 配置字段
插件执行阶段:`鉴权阶段` | 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
插件执行优先级:`320` | -------- | -------- | -------- | -------- | -------- |
| block_urls | array of string | 选填,`block_urls`,`block_exact_urls`,`block_regexp_urls`,`block_headers`,`block_bodies` 中至少必填一项 | - | 配置用于匹配需要屏蔽 URL 的字符串 |
| block_exact_urls | array of string | 选填,`block_urls`,`block_exact_urls`,`block_regexp_urls`,`block_headers`,`block_bodies` 中至少必填一项 | - | 配置用于匹配需要精确屏蔽 URL 的字符串 |
| block_regexp_urls | array of string | 选填,`block_urls`,`block_exact_urls`,`block_regexp_urls`,`block_headers`,`block_bodies` 中至少必填一项 | - | 配置用于匹配需要屏蔽 URL 的正则表达式 |
| block_headers | array of string | 选填,`block_urls`,`block_exact_urls`,`block_regexp_urls`,`block_headers`,`block_bodies` 中至少必填一项 | - | 配置用于匹配需要屏蔽请求 Header 的字符串 |
| block_bodies | array of string | 选填,`block_urls`,`block_exact_urls`,`block_regexp_urls`,`block_headers`,`block_bodies` 中至少必填一项 | - | 配置用于匹配需要屏蔽请求 Body 的字符串 |
| blocked_code | number | 选填 | 403 | 配置请求被屏蔽时返回的 HTTP 状态码 |
| blocked_message | string | 选填 | - | 配置请求被屏蔽时返回的 HTTP 应答 Body |
| case_sensitive | bool | 选填 | true | 配置匹配时是否区分大小写,默认区分 |
## 配置字段 # 配置示例
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | ## 屏蔽请求 url 路径
| -------- | -------- | -------- | -------- | -------- |
| block_urls | array of string | 选填,`block_urls`,`block_headers`,`block_bodies` 中至少必填一项 | - | 配置用于匹配需要屏蔽 URL 的字符串 |
| block_headers | array of string | 选填,`block_urls`,`block_headers`,`block_bodies` 中至少必填一项 | - | 配置用于匹配需要屏蔽请求 Header 的字符串 |
| block_bodies | array of string | 选填,`block_urls`,`block_headers`,`block_bodies` 中至少必填一项 | - | 配置用于匹配需要屏蔽请求 Body 的字符串 |
| blocked_code | number | 选填 | 403 | 配置请求被屏蔽时返回的 HTTP 状态码 |
| blocked_message | string | 选填 | - | 配置请求被屏蔽时返回的 HTTP 应答 Body |
| case_sensitive | bool | 选填 | true | 配置匹配时是否区分大小写,默认区分 |
## 配置示例
### 屏蔽请求 url 路径
```yaml ```yaml
block_urls: block_urls:
- swagger.html - swagger.html
@@ -40,7 +31,36 @@ curl http://example.com?foo=Bar
curl http://exmaple.com/Swagger.html curl http://exmaple.com/Swagger.html
``` ```
### 屏蔽请求 header ## 屏蔽精确匹配的请求 url 路径
```yaml
block_exact_urls:
- /swagger.html?foo=bar
case_sensitive: false
```
根据该配置,下列请求将被禁止访问:
```bash
curl http://exmaple.com/Swagger.html?foo=Bar
```
## 屏蔽正则匹配的请求 url 路径
```yaml
block_exact_urls:
- .*swagger.*
case_sensitive: false
```
根据该配置,下列请求将被禁止访问:
```bash
curl http://exmaple.com/Swagger.html?foo=Bar
```
## 屏蔽请求 header
```yaml ```yaml
block_headers: block_headers:
- example-key - example-key
@@ -54,9 +74,9 @@ curl http://example.com -H 'example-key: 123'
curl http://exmaple.com -H 'my-header: example-value' curl http://exmaple.com -H 'my-header: example-value'
``` ```
### 屏蔽请求 body ## 屏蔽请求 body
```yaml ```yaml
block_bodies: block_bodys:
- "hello world" - "hello world"
case_sensitive: false case_sensitive: false
``` ```
@@ -68,8 +88,30 @@ curl http://example.com -d 'Hello World'
curl http://exmaple.com -d 'hello world' curl http://exmaple.com -d 'hello world'
``` ```
## 对特定路由或域名开启
```yaml
# 使用 _rules_ 字段进行细粒度规则配置
_rules_:
# 规则一:按路由名称匹配生效
- _match_route_:
- route-a
- route-b
block_bodys:
- "hello world"
# 规则二:按域名匹配生效
- _match_domain_:
- "*.example.com"
- test.com
block_urls:
- "swagger.html"
block_bodys:
- "hello world"
```
此例 `_match_route_` 中指定的 `route-a``route-b` 即在创建网关路由时填写的路由名称,当匹配到这两个路由时,将使用此段配置;
此例 `_match_domain_` 中指定的 `*.example.com``test.com` 用于匹配请求的域名,当发现域名匹配时,将使用此段配置;
配置的匹配生效顺序,将按照 `_rules_` 下规则的排列顺序,匹配第一个规则后生效对应配置,后续规则将被忽略。
## 请求 Body 大小限制 # 请求 Body 大小限制
当配置了 `block_bodies` 时,仅支持小于 32 MB 的请求 Body 进行匹配。若请求 Body 大于此限制,并且不存在匹配到的 `block_urls``block_headers` 项时,不会对该请求执行屏蔽操作 当配置了 `block_bodys` 时,仅支持小于 32 MB 的请求 Body 进行匹配。若请求 Body 大于此限制,并且不存在匹配到的 `block_urls``block_headers` 项时,不会对该请求执行屏蔽操作
当配置了 `block_bodies` 时,若请求 Body 超过全局配置 DownstreamConnectionBufferLimits将返回 `413 Payload Too Large` 当配置了 `block_bodys` 时,若请求 Body 超过全局配置 DownstreamConnectionBufferLimits将返回 `413 Payload Too Large`, 请在参数配置页调高此项。注意调高此参数配置后,网关内存使用将有显著增加。

View File

@@ -15,6 +15,7 @@
#include "extensions/request_block/plugin.h" #include "extensions/request_block/plugin.h"
#include <array> #include <array>
#include <memory>
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h" #include "absl/strings/str_join.h"
@@ -89,6 +90,48 @@ bool PluginRootContext::parsePluginConfig(const json& configuration,
LOG_WARN("failed to parse configuration for block_urls."); LOG_WARN("failed to parse configuration for block_urls.");
return false; return false;
} }
if (!JsonArrayIterate(
configuration, "block_exact_urls", [&](const json& item) -> bool {
auto url = JsonValueAs<std::string>(item);
if (url.second != Wasm::Common::JsonParserResultDetail::OK) {
LOG_WARN("cannot parse block_exact_urls");
return false;
}
if (rule.case_sensitive) {
rule.block_exact_urls.push_back(std::move(url.first.value()));
} else {
rule.block_exact_urls.push_back(
absl::AsciiStrToLower(url.first.value()));
}
return true;
})) {
LOG_WARN("failed to parse configuration for block_exact_urls.");
return false;
}
if (!JsonArrayIterate(
configuration, "block_regexp_urls", [&](const json& item) -> bool {
auto url = JsonValueAs<std::string>(item);
if (url.second != Wasm::Common::JsonParserResultDetail::OK) {
LOG_WARN("cannot parse block_regexp_urls");
return false;
}
std::string regex;
if (rule.case_sensitive) {
regex = url.first.value();
} else {
regex = absl::AsciiStrToLower(url.first.value());
}
auto re = std::make_unique<ReMatcher>(regex);
if (!re->error().empty()) {
LOG_WARN(re->error());
return false;
}
rule.block_regexp_urls.push_back(std::move(re));
return true;
})) {
LOG_WARN("failed to parse configuration for block_regexp_urls.");
return false;
}
if (!JsonArrayIterate( if (!JsonArrayIterate(
configuration, "block_headers", [&](const json& item) -> bool { configuration, "block_headers", [&](const json& item) -> bool {
auto header = JsonValueAs<std::string>(item); auto header = JsonValueAs<std::string>(item);
@@ -125,8 +168,28 @@ bool PluginRootContext::parsePluginConfig(const json& configuration,
LOG_WARN("failed to parse configuration for block_bodys."); LOG_WARN("failed to parse configuration for block_bodys.");
return false; return false;
} }
// compatiable
if (!JsonArrayIterate(
configuration, "block_bodies", [&](const json& item) -> bool {
auto body = JsonValueAs<std::string>(item);
if (body.second != Wasm::Common::JsonParserResultDetail::OK) {
LOG_WARN("cannot parse block_bodies");
return false;
}
if (rule.case_sensitive) {
rule.block_bodys.push_back(std::move(body.first.value()));
} else {
rule.block_bodys.push_back(
absl::AsciiStrToLower(body.first.value()));
}
return true;
})) {
LOG_WARN("failed to parse configuration for block_bodies.");
return false;
}
if (rule.block_bodys.empty() && rule.block_headers.empty() && if (rule.block_bodys.empty() && rule.block_headers.empty() &&
rule.block_urls.empty()) { rule.block_urls.empty() && rule.block_exact_urls.empty() &&
rule.block_regexp_urls.empty()) {
LOG_WARN("there is no block rules"); LOG_WARN("there is no block rules");
return false; return false;
} }
@@ -172,6 +235,18 @@ bool PluginRootContext::checkHeader(const RequestBlockConfigRule& rule,
urlstr = absl::AsciiStrToLower(request_url); urlstr = absl::AsciiStrToLower(request_url);
url = urlstr; url = urlstr;
} }
for (const auto& block_url : rule.block_exact_urls) {
if (url == block_url) {
sendLocalResponse(rule.blocked_code, "", rule.blocked_message, {});
return false;
}
}
for (const auto& block_url : rule.block_regexp_urls) {
if (block_url->match(url)) {
sendLocalResponse(rule.blocked_code, "", rule.blocked_message, {});
return false;
}
}
for (const auto& block_url : rule.block_urls) { for (const auto& block_url : rule.block_urls) {
if (absl::StrContains(url, block_url)) { if (absl::StrContains(url, block_url)) {
sendLocalResponse(rule.blocked_code, "", rule.blocked_message, {}); sendLocalResponse(rule.blocked_code, "", rule.blocked_message, {});

View File

@@ -22,6 +22,7 @@
#include <unordered_map> #include <unordered_map>
#include "common/http_util.h" #include "common/http_util.h"
#include "common/regex.h"
#include "common/route_rule_matcher.h" #include "common/route_rule_matcher.h"
#define ASSERT(_X) assert(_X) #define ASSERT(_X) assert(_X)
@@ -39,11 +40,16 @@ namespace request_block {
#endif #endif
using ReMatcher = Wasm::Common::Regex::CompiledGoogleReMatcher;
using ReMatcherPtr = std::unique_ptr<ReMatcher>;
struct RequestBlockConfigRule { struct RequestBlockConfigRule {
int blocked_code = 403; int blocked_code = 403;
std::string blocked_message; std::string blocked_message;
bool case_sensitive = true; bool case_sensitive = true;
std::vector<std::string> block_urls; std::vector<std::string> block_urls;
std::vector<std::string> block_exact_urls;
std::vector<ReMatcherPtr> block_regexp_urls;
std::vector<std::string> block_headers; std::vector<std::string> block_headers;
std::vector<std::string> block_bodys; std::vector<std::string> block_bodys;
}; };

View File

@@ -127,6 +127,8 @@ TEST_F(RequestBlockTest, CaseSensitive) {
std::string configuration = R"( std::string configuration = R"(
{ {
"block_urls": ["?foo=bar", "swagger.html"], "block_urls": ["?foo=bar", "swagger.html"],
"block_exact_urls": ["/hello.html?abc=123"],
"block_regexp_urls": [".*monkey.*"],
"block_headers": ["headerKey", "headerValue"], "block_headers": ["headerKey", "headerValue"],
"block_bodys": ["Hello World"] "block_bodys": ["Hello World"]
})"; })";
@@ -150,6 +152,22 @@ TEST_F(RequestBlockTest, CaseSensitive) {
EXPECT_EQ(context_->onRequestHeaders(0, false), EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration); FilterHeadersStatus::StopIteration);
path_ = "/hello.html?abc=123";
EXPECT_CALL(*mock_context_, sendLocalResponse(403, testing::_, testing::_,
testing::_, testing::_));
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
path_ = "/black/Monkey";
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::Continue);
path_ = "/black/monkey";
EXPECT_CALL(*mock_context_, sendLocalResponse(403, testing::_, testing::_,
testing::_, testing::_));
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
path_ = ""; path_ = "";
headers_ = {{"headerKey", "123"}}; headers_ = {{"headerKey", "123"}};
EXPECT_CALL(*mock_context_, sendLocalResponse(403, testing::_, testing::_, EXPECT_CALL(*mock_context_, sendLocalResponse(403, testing::_, testing::_,
@@ -188,6 +206,8 @@ TEST_F(RequestBlockTest, CaseInsensitive) {
"blocked_code": 404, "blocked_code": 404,
"block_urls": ["?foo=bar", "swagger.html"], "block_urls": ["?foo=bar", "swagger.html"],
"block_headers": ["headerKey", "headerValue"], "block_headers": ["headerKey", "headerValue"],
"block_exact_urls": ["/hello.html?abc=123"],
"block_regexp_urls": [".*monkey.*"],
"block_bodys": ["Hello World"] "block_bodys": ["Hello World"]
})"; })";
@@ -206,6 +226,24 @@ TEST_F(RequestBlockTest, CaseInsensitive) {
EXPECT_EQ(context_->onRequestHeaders(0, false), EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration); FilterHeadersStatus::StopIteration);
path_ = "/Hello.html?abc=123";
EXPECT_CALL(*mock_context_, sendLocalResponse(404, testing::_, testing::_,
testing::_, testing::_));
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
path_ = "/black/Monkey";
EXPECT_CALL(*mock_context_, sendLocalResponse(404, testing::_, testing::_,
testing::_, testing::_));
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
path_ = "/black/monkey";
EXPECT_CALL(*mock_context_, sendLocalResponse(404, testing::_, testing::_,
testing::_, testing::_));
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
path_ = ""; path_ = "";
headers_ = {{"headerkey", "123"}}; headers_ = {{"headerkey", "123"}};
EXPECT_CALL(*mock_context_, sendLocalResponse(404, testing::_, testing::_, EXPECT_CALL(*mock_context_, sendLocalResponse(404, testing::_, testing::_,
@@ -232,6 +270,26 @@ TEST_F(RequestBlockTest, CaseInsensitive) {
FilterDataStatus::StopIterationNoBuffer); FilterDataStatus::StopIterationNoBuffer);
} }
TEST_F(RequestBlockTest, Bodies) {
std::string configuration = R"(
{
"case_sensitive": false,
"blocked_code": 404,
"block_bodies": ["Hello World"]
})";
config_.set({configuration.data(), configuration.size()});
EXPECT_TRUE(root_context_->configure(configuration.size()));
body_.set("hello world");
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::Continue);
EXPECT_CALL(*mock_context_, sendLocalResponse(404, testing::_, testing::_,
testing::_, testing::_));
EXPECT_EQ(context_->onRequestBody(11, true),
FilterDataStatus::StopIterationNoBuffer);
}
} // namespace request_block } // namespace request_block
} // namespace null_plugin } // namespace null_plugin
} // namespace proxy_wasm } // namespace proxy_wasm

View File

@@ -2,6 +2,7 @@ package cache
import ( import (
"errors" "errors"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
@@ -62,7 +63,12 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
c.serviceName = json.Get("serviceName").String() c.serviceName = json.Get("serviceName").String()
c.servicePort = int(json.Get("servicePort").Int()) c.servicePort = int(json.Get("servicePort").Int())
if !json.Get("servicePort").Exists() { if !json.Get("servicePort").Exists() {
c.servicePort = 6379 if strings.HasSuffix(c.serviceName, ".static") {
// use default logic port which is 80 for static service
c.servicePort = 80
} else {
c.servicePort = 6379
}
} }
c.serviceHost = json.Get("serviceHost").String() c.serviceHost = json.Get("serviceHost").String()
c.username = json.Get("username").String() c.username = json.Get("username").String()

View File

@@ -79,11 +79,11 @@ func (c *PluginConfig) FromJson(json gjson.Result, log wrapper.Log) {
c.StreamResponseTemplate = json.Get("streamResponseTemplate").String() c.StreamResponseTemplate = json.Get("streamResponseTemplate").String()
if c.StreamResponseTemplate == "" { if c.StreamResponseTemplate == "" {
c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n" c.StreamResponseTemplate = `data:{"id":"from-cache","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"from-cache","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + "\n\ndata:[DONE]\n\n"
} }
c.ResponseTemplate = json.Get("responseTemplate").String() c.ResponseTemplate = json.Get("responseTemplate").String()
if c.ResponseTemplate == "" { if c.ResponseTemplate == "" {
c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` c.ResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"from-cache","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
} }
if json.Get("enableSemanticCache").Exists() { if json.Get("enableSemanticCache").Exists() {

View File

@@ -74,6 +74,9 @@ func processCacheHit(key string, response string, stream bool, ctx wrapper.HttpC
ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil) ctx.SetContext(CACHE_KEY_CONTEXT_KEY, nil)
ctx.SetUserAttribute("cache_status", "hit")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
if stream { if stream {
proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(c.StreamResponseTemplate, escapedResponse)), -1) proxywasm.SendHttpResponseWithDetail(200, "ai-cache.hit", [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}, []byte(fmt.Sprintf(c.StreamResponseTemplate, escapedResponse)), -1)
} else { } else {

View File

@@ -0,0 +1,158 @@
package embedding
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
const (
COHERE_DOMAIN = "api.cohere.com"
COHERE_PORT = 443
COHERE_DEFAULT_MODEL_NAME = "embed-english-v2.0"
COHERE_ENDPOINT = "/v2/embed"
)
type cohereProviderInitializer struct {
}
var cohereConfig cohereProviderConfig
type cohereProviderConfig struct {
// @Title zh-CN 文本特征提取服务 API Key
// @Description zh-CN 文本特征提取服务 API Key
apiKey string
}
func (c *cohereProviderInitializer) InitConfig(json gjson.Result) {
cohereConfig.apiKey = json.Get("apiKey").String()
}
func (c *cohereProviderInitializer) ValidateConfig() error {
if cohereConfig.apiKey == "" {
return errors.New("[Cohere] apiKey is required")
}
return nil
}
func (t *cohereProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) {
if c.servicePort == 0 {
c.servicePort = COHERE_PORT
}
if c.serviceHost == "" {
c.serviceHost = COHERE_DOMAIN
}
return &CohereProvider{
config: c,
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: c.serviceName,
Host: c.serviceHost,
Port: int64(c.servicePort),
}),
}, nil
}
type cohereResponse struct {
Embeddings cohereEmbeddings `json:"embeddings"`
}
type cohereEmbeddings struct {
FloatTypeEebedding [][]float64 `json:"float"`
}
type cohereEmbeddingRequest struct {
Texts []string `json:"texts"`
Model string `json:"model"`
InputType string `json:"input_type"`
EmbeddingTypes []string `json:"embedding_types"`
}
type CohereProvider struct {
config ProviderConfig
client wrapper.HttpClient
}
func (t *CohereProvider) GetProviderType() string {
return PROVIDER_TYPE_COHERE
}
func (t *CohereProvider) constructParameters(texts []string, log wrapper.Log) (string, [][2]string, []byte, error) {
model := t.config.model
if model == "" {
model = COHERE_DEFAULT_MODEL_NAME
}
data := cohereEmbeddingRequest{
Texts: texts,
Model: model,
InputType: "search_document",
EmbeddingTypes: []string{"float"},
}
requestBody, err := json.Marshal(data)
if err != nil {
log.Errorf("failed to marshal request data: %v", err)
return "", nil, nil, err
}
headers := [][2]string{
{"Authorization", fmt.Sprintf("BEARER %s", cohereConfig.apiKey)},
{"Content-Type", "application/json"},
}
return COHERE_ENDPOINT, headers, requestBody, nil
}
func (t *CohereProvider) parseTextEmbedding(responseBody []byte) (*cohereResponse, error) {
var resp cohereResponse
err := json.Unmarshal(responseBody, &resp)
if err != nil {
return nil, err
}
return &resp, nil
}
func (t *CohereProvider) GetEmbedding(
queryString string,
ctx wrapper.HttpContext,
log wrapper.Log,
callback func(emb []float64, err error)) error {
embUrl, embHeaders, embRequestBody, err := t.constructParameters([]string{queryString}, log)
if err != nil {
log.Errorf("failed to construct parameters: %v", err)
return err
}
var resp *cohereResponse
err = t.client.Post(embUrl, embHeaders, embRequestBody,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
if statusCode != http.StatusOK {
err = errors.New("failed to get embedding due to status code: " + strconv.Itoa(statusCode))
callback(nil, err)
return
}
log.Debugf("get embedding response: %d, %s", statusCode, responseBody)
resp, err = t.parseTextEmbedding(responseBody)
if err != nil {
err = fmt.Errorf("failed to parse response: %v", err)
callback(nil, err)
return
}
if len(resp.Embeddings.FloatTypeEebedding) == 0 {
err = errors.New("no embedding found in response")
callback(nil, err)
return
}
callback(resp.Embeddings.FloatTypeEebedding[0], nil)
}, t.config.timeout)
return err
}

View File

@@ -8,6 +8,7 @@ import (
"strconv" "strconv"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
) )
const ( const (
@@ -17,11 +18,22 @@ const (
DASHSCOPE_ENDPOINT = "/api/v1/services/embeddings/text-embedding/text-embedding" DASHSCOPE_ENDPOINT = "/api/v1/services/embeddings/text-embedding/text-embedding"
) )
var dashScopeConfig dashScopeProviderConfig
type dashScopeProviderInitializer struct { type dashScopeProviderInitializer struct {
} }
type dashScopeProviderConfig struct {
// @Title zh-CN 文本特征提取服务 API Key
// @Description zh-CN 文本特征提取服务 API Key
apiKey string
}
func (d *dashScopeProviderInitializer) ValidateConfig(config ProviderConfig) error { func (c *dashScopeProviderInitializer) InitConfig(json gjson.Result) {
if config.apiKey == "" { dashScopeConfig.apiKey = json.Get("apiKey").String()
}
func (c *dashScopeProviderInitializer) ValidateConfig() error {
if dashScopeConfig.apiKey == "" {
return errors.New("[DashScope] apiKey is required") return errors.New("[DashScope] apiKey is required")
} }
return nil return nil
@@ -114,14 +126,14 @@ func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (strin
return "", nil, nil, err return "", nil, nil, err
} }
if d.config.apiKey == "" { if dashScopeConfig.apiKey == "" {
err := errors.New("dashScopeKey is empty") err := errors.New("dashScopeKey is empty")
log.Errorf("failed to construct headers: %v", err) log.Errorf("failed to construct headers: %v", err)
return "", nil, nil, err return "", nil, nil, err
} }
headers := [][2]string{ headers := [][2]string{
{"Authorization", "Bearer " + d.config.apiKey}, {"Authorization", "Bearer " + dashScopeConfig.apiKey},
{"Content-Type", "application/json"}, {"Content-Type", "application/json"},
} }

View File

@@ -0,0 +1,170 @@
package embedding
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
const (
OPENAI_DOMAIN = "api.openai.com"
OPENAI_PORT = 443
OPENAI_DEFAULT_MODEL_NAME = "text-embedding-3-small"
OPENAI_ENDPOINT = "/v1/embeddings"
)
type openAIProviderInitializer struct {
}
var openAIConfig openAIProviderConfig
type openAIProviderConfig struct {
// @Title zh-CN 文本特征提取服务 API Key
// @Description zh-CN 文本特征提取服务 API Key
apiKey string
}
func (c *openAIProviderInitializer) InitConfig(json gjson.Result) {
openAIConfig.apiKey = json.Get("apiKey").String()
}
func (c *openAIProviderInitializer) ValidateConfig() error {
if openAIConfig.apiKey == "" {
return errors.New("[openAI] apiKey is required")
}
return nil
}
func (t *openAIProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) {
if c.servicePort == 0 {
c.servicePort = OPENAI_PORT
}
if c.serviceHost == "" {
c.serviceHost = OPENAI_DOMAIN
}
if c.model == "" {
c.model = OPENAI_DEFAULT_MODEL_NAME
}
return &OpenAIProvider{
config: c,
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: c.serviceName,
Host: c.serviceHost,
Port: c.servicePort,
}),
}, nil
}
func (t *OpenAIProvider) GetProviderType() string {
return PROVIDER_TYPE_OPENAI
}
type OpenAIResponse struct {
Object string `json:"object"`
Data []OpenAIResult `json:"data"`
Model string `json:"model"`
Error *OpenAIError `json:"error"`
}
type OpenAIResult struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
}
type OpenAIError struct {
Message string `json:"prompt_tokens"`
Type string `json:"type"`
Code string `json:"code"`
Param string `json:"param"`
}
type OpenAIEmbeddingRequest struct {
Input string `json:"input"`
Model string `json:"model"`
}
type OpenAIProvider struct {
config ProviderConfig
client wrapper.HttpClient
}
func (t *OpenAIProvider) constructParameters(text string, log wrapper.Log) (string, [][2]string, []byte, error) {
if text == "" {
err := errors.New("queryString text cannot be empty")
return "", nil, nil, err
}
data := OpenAIEmbeddingRequest{
Input: text,
Model: t.config.model,
}
requestBody, err := json.Marshal(data)
if err != nil {
log.Errorf("failed to marshal request data: %v", err)
return "", nil, nil, err
}
headers := [][2]string{
{"Authorization", fmt.Sprintf("Bearer %s", openAIConfig.apiKey)},
{"Content-Type", "application/json"},
}
return OPENAI_ENDPOINT, headers, requestBody, err
}
func (t *OpenAIProvider) parseTextEmbedding(responseBody []byte) (*OpenAIResponse, error) {
var resp OpenAIResponse
err := json.Unmarshal(responseBody, &resp)
if err != nil {
return nil, err
}
return &resp, nil
}
func (t *OpenAIProvider) GetEmbedding(
queryString string,
ctx wrapper.HttpContext,
log wrapper.Log,
callback func(emb []float64, err error)) error {
embUrl, embHeaders, embRequestBody, err := t.constructParameters(queryString, log)
if err != nil {
log.Errorf("failed to construct parameters: %v", err)
return err
}
var resp *OpenAIResponse
err = t.client.Post(embUrl, embHeaders, embRequestBody,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
if statusCode != http.StatusOK {
err = fmt.Errorf("failed to get embedding due to status code: %d, resp: %s", statusCode, responseBody)
callback(nil, err)
return
}
resp, err = t.parseTextEmbedding(responseBody)
if err != nil {
err = fmt.Errorf("failed to parse response: %v", err)
callback(nil, err)
return
}
log.Debugf("get embedding response: %d, %s", statusCode, responseBody)
if len(resp.Data) == 0 {
err = errors.New("no embedding found in response")
callback(nil, err)
return
}
callback(resp.Data[0].Embedding, nil)
}, t.config.timeout)
return err
}

View File

@@ -10,10 +10,13 @@ import (
const ( const (
PROVIDER_TYPE_DASHSCOPE = "dashscope" PROVIDER_TYPE_DASHSCOPE = "dashscope"
PROVIDER_TYPE_TEXTIN = "textin" PROVIDER_TYPE_TEXTIN = "textin"
PROVIDER_TYPE_COHERE = "cohere"
PROVIDER_TYPE_OPENAI = "openai"
) )
type providerInitializer interface { type providerInitializer interface {
ValidateConfig(ProviderConfig) error InitConfig(json gjson.Result)
ValidateConfig() error
CreateProvider(ProviderConfig) (Provider, error) CreateProvider(ProviderConfig) (Provider, error)
} }
@@ -21,6 +24,8 @@ var (
providerInitializers = map[string]providerInitializer{ providerInitializers = map[string]providerInitializer{
PROVIDER_TYPE_DASHSCOPE: &dashScopeProviderInitializer{}, PROVIDER_TYPE_DASHSCOPE: &dashScopeProviderInitializer{},
PROVIDER_TYPE_TEXTIN: &textInProviderInitializer{}, PROVIDER_TYPE_TEXTIN: &textInProviderInitializer{},
PROVIDER_TYPE_COHERE: &cohereProviderInitializer{},
PROVIDER_TYPE_OPENAI: &openAIProviderInitializer{},
} }
) )
@@ -37,35 +42,26 @@ type ProviderConfig struct {
// @Title zh-CN 文本特征提取服务端口 // @Title zh-CN 文本特征提取服务端口
// @Description zh-CN 文本特征提取服务端口 // @Description zh-CN 文本特征提取服务端口
servicePort int64 servicePort int64
// @Title zh-CN 文本特征提取服务 API Key
// @Description zh-CN 文本特征提取服务 API Key
apiKey string
//@Title zh-CN TextIn x-ti-app-id
// @Description zh-CN 仅适用于 TextIn 服务。参考 https://www.textin.com/document/acge_text_embedding
textinAppId string
//@Title zh-CN TextIn x-ti-secret-code
// @Description zh-CN 仅适用于 TextIn 服务。参考 https://www.textin.com/document/acge_text_embedding
textinSecretCode string
//@Title zh-CN TextIn request matryoshka_dim
// @Description zh-CN 仅适用于 TextIn 服务, 指定返回的向量维度。参考 https://www.textin.com/document/acge_text_embedding
textinMatryoshkaDim int
// @Title zh-CN 文本特征提取服务超时时间 // @Title zh-CN 文本特征提取服务超时时间
// @Description zh-CN 文本特征提取服务超时时间 // @Description zh-CN 文本特征提取服务超时时间
timeout uint32 timeout uint32
// @Title zh-CN 文本特征提取服务使用的模型 // @Title zh-CN 文本特征提取服务使用的模型
// @Description zh-CN 用于文本特征提取的模型名称, 在 DashScope 中默认为 "text-embedding-v1" // @Description zh-CN 用于文本特征提取的模型名称, 在 DashScope 中默认为 "text-embedding-v1"
model string model string
initializer providerInitializer
} }
func (c *ProviderConfig) FromJson(json gjson.Result) { func (c *ProviderConfig) FromJson(json gjson.Result) {
c.typ = json.Get("type").String() c.typ = json.Get("type").String()
i, has := providerInitializers[c.typ]
if has {
i.InitConfig(json)
c.initializer = i
}
c.serviceName = json.Get("serviceName").String() c.serviceName = json.Get("serviceName").String()
c.serviceHost = json.Get("serviceHost").String() c.serviceHost = json.Get("serviceHost").String()
c.servicePort = json.Get("servicePort").Int() c.servicePort = json.Get("servicePort").Int()
c.apiKey = json.Get("apiKey").String()
c.textinAppId = json.Get("textinAppId").String()
c.textinSecretCode = json.Get("textinSecretCode").String()
c.textinMatryoshkaDim = int(json.Get("textinMatryoshkaDim").Int())
c.timeout = uint32(json.Get("timeout").Int()) c.timeout = uint32(json.Get("timeout").Int())
c.model = json.Get("model").String() c.model = json.Get("model").String()
if c.timeout == 0 { if c.timeout == 0 {
@@ -80,11 +76,10 @@ func (c *ProviderConfig) Validate() error {
if c.typ == "" { if c.typ == "" {
return errors.New("embedding service type is required") return errors.New("embedding service type is required")
} }
initializer, has := providerInitializers[c.typ] if c.initializer == nil {
if !has {
return errors.New("unknown embedding service provider type: " + c.typ) return errors.New("unknown embedding service provider type: " + c.typ)
} }
if err := initializer.ValidateConfig(*c); err != nil { if err := c.initializer.ValidateConfig(); err != nil {
return err return err
} }
return nil return nil

View File

@@ -8,6 +8,7 @@ import (
"strconv" "strconv"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
) )
const ( const (
@@ -20,14 +21,34 @@ const (
type textInProviderInitializer struct { type textInProviderInitializer struct {
} }
func (t *textInProviderInitializer) ValidateConfig(config ProviderConfig) error { var textInConfig textInProviderConfig
if config.textinAppId == "" {
return errors.New("embedding service TextIn App ID is required") type textInProviderConfig struct {
//@Title zh-CN TextIn x-ti-app-id
// @Description zh-CN 仅适用于 TextIn 服务。参考 https://www.textin.com/document/acge_text_embedding
textinAppId string
//@Title zh-CN TextIn x-ti-secret-code
// @Description zh-CN 仅适用于 TextIn 服务。参考 https://www.textin.com/document/acge_text_embedding
textinSecretCode string
//@Title zh-CN TextIn request matryoshka_dim
// @Description zh-CN 仅适用于 TextIn 服务, 指定返回的向量维度。参考 https://www.textin.com/document/acge_text_embedding
textinMatryoshkaDim int
}
func (c *textInProviderInitializer) InitConfig(json gjson.Result) {
textInConfig.textinAppId = json.Get("textinAppId").String()
textInConfig.textinSecretCode = json.Get("textinSecretCode").String()
textInConfig.textinMatryoshkaDim = int(json.Get("textinMatryoshkaDim").Int())
}
func (c *textInProviderInitializer) ValidateConfig() error {
if textInConfig.textinAppId == "" {
return errors.New("textinAppId is required")
} }
if config.textinSecretCode == "" { if textInConfig.textinSecretCode == "" {
return errors.New("embedding service TextIn Secret Code is required") return errors.New("textinSecretCode is required")
} }
if config.textinMatryoshkaDim == 0 { if textInConfig.textinMatryoshkaDim == 0 {
return errors.New("embedding service TextIn Matryoshka Dim is required") return errors.New("embedding service TextIn Matryoshka Dim is required")
} }
return nil return nil
@@ -62,7 +83,7 @@ type TextInResponse struct {
} }
type TextInResult struct { type TextInResult struct {
Embeddings [][]float64 `json:"embedding"` Embeddings [][]float64 `json:"embedding"`
MatryoshkaDim int `json:"matryoshka_dim"` MatryoshkaDim int `json:"matryoshka_dim"`
} }
@@ -80,7 +101,7 @@ func (t *TIProvider) constructParameters(texts []string, log wrapper.Log) (strin
data := TextInEmbeddingRequest{ data := TextInEmbeddingRequest{
Input: texts, Input: texts,
MatryoshkaDim: t.config.textinMatryoshkaDim, MatryoshkaDim: textInConfig.textinMatryoshkaDim,
} }
requestBody, err := json.Marshal(data) requestBody, err := json.Marshal(data)
@@ -89,20 +110,20 @@ func (t *TIProvider) constructParameters(texts []string, log wrapper.Log) (strin
return "", nil, nil, err return "", nil, nil, err
} }
if t.config.textinAppId == "" { if textInConfig.textinAppId == "" {
err := errors.New("textinAppId is empty") err := errors.New("textinAppId is empty")
log.Errorf("failed to construct headers: %v", err) log.Errorf("failed to construct headers: %v", err)
return "", nil, nil, err return "", nil, nil, err
} }
if t.config.textinSecretCode == "" { if textInConfig.textinSecretCode == "" {
err := errors.New("textinSecretCode is empty") err := errors.New("textinSecretCode is empty")
log.Errorf("failed to construct headers: %v", err) log.Errorf("failed to construct headers: %v", err)
return "", nil, nil, err return "", nil, nil, err
} }
headers := [][2]string{ headers := [][2]string{
{"x-ti-app-id", t.config.textinAppId}, {"x-ti-app-id", textInConfig.textinAppId},
{"x-ti-secret-code", t.config.textinSecretCode}, {"x-ti-secret-code", textInConfig.textinSecretCode},
{"Content-Type", "application/json"}, {"Content-Type", "application/json"},
} }

View File

@@ -8,14 +8,14 @@ replace github.com/alibaba/higress/plugins/wasm-go => ../..
require ( require (
github.com/alibaba/higress/plugins/wasm-go v1.4.2 github.com/alibaba/higress/plugins/wasm-go v1.4.2
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f github.com/google/uuid v1.6.0
github.com/higress-group/proxy-wasm-go-sdk v1.0.0
github.com/tidwall/gjson v1.17.3 github.com/tidwall/gjson v1.17.3
github.com/tidwall/resp v0.1.1 github.com/tidwall/resp v0.1.1
// github.com/weaviate/weaviate-go-client/v4 v4.15.1 // github.com/weaviate/weaviate-go-client/v4 v4.15.1
) )
require ( require (
github.com/google/uuid v1.6.0 // indirect
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
github.com/magefile/mage v1.14.0 // indirect github.com/magefile/mage v1.14.0 // indirect
github.com/stretchr/testify v1.9.0 // indirect github.com/stretchr/testify v1.9.0 // indirect

View File

@@ -3,8 +3,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo= github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=

View File

@@ -22,6 +22,8 @@ const (
STREAM_CONTEXT_KEY = "stream" STREAM_CONTEXT_KEY = "stream"
SKIP_CACHE_HEADER = "x-higress-skip-ai-cache" SKIP_CACHE_HEADER = "x-higress-skip-ai-cache"
ERROR_PARTIAL_MESSAGE_KEY = "errorPartialMessage" ERROR_PARTIAL_MESSAGE_KEY = "errorPartialMessage"
DEFAULT_MAX_BODY_BYTES uint32 = 10 * 1024 * 1024
) )
func main() { func main() {
@@ -69,6 +71,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log wr
ctx.DontReadRequestBody() ctx.DontReadRequestBody()
return types.ActionContinue return types.ActionContinue
} }
ctx.SetRequestBodyBufferLimit(DEFAULT_MAX_BODY_BYTES)
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding") _ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
// The request has a body and requires delaying the header transmission until a cache miss occurs, // The request has a body and requires delaying the header transmission until a cache miss occurs,
// at which point the header should be sent. // at which point the header should be sent.
@@ -128,12 +131,20 @@ func onHttpRequestBody(ctx wrapper.HttpContext, c config.PluginConfig, body []by
func onHttpResponseHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) types.Action { func onHttpResponseHeaders(ctx wrapper.HttpContext, c config.PluginConfig, log wrapper.Log) types.Action {
skipCache := ctx.GetContext(SKIP_CACHE_HEADER) skipCache := ctx.GetContext(SKIP_CACHE_HEADER)
if skipCache != nil { if skipCache != nil {
ctx.SetUserAttribute("cache_status", "skip")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
ctx.DontReadResponseBody() ctx.DontReadResponseBody()
return types.ActionContinue return types.ActionContinue
} }
if ctx.GetContext(CACHE_KEY_CONTEXT_KEY) != nil {
ctx.SetUserAttribute("cache_status", "miss")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
contentType, _ := proxywasm.GetHttpResponseHeader("content-type") contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
if strings.Contains(contentType, "text/event-stream") { if strings.Contains(contentType, "text/event-stream") {
ctx.SetContext(STREAM_CONTEXT_KEY, struct{}{}) ctx.SetContext(STREAM_CONTEXT_KEY, struct{}{})
} else {
ctx.SetResponseBodyBufferLimit(DEFAULT_MAX_BODY_BYTES)
} }
if ctx.GetContext(ERROR_PARTIAL_MESSAGE_KEY) != nil { if ctx.GetContext(ERROR_PARTIAL_MESSAGE_KEY) != nil {
@@ -158,22 +169,26 @@ func onHttpResponseBody(ctx wrapper.HttpContext, c config.PluginConfig, chunk []
return chunk return chunk
} }
stream := ctx.GetContext(STREAM_CONTEXT_KEY)
var err error
if !isLastChunk { if !isLastChunk {
if err := handleNonLastChunk(ctx, c, chunk, log); err != nil { if stream == nil {
err = handleNonStreamChunk(ctx, c, chunk, log)
} else {
err = handleStreamChunk(ctx, c, unifySSEChunk(chunk), log)
}
if err != nil {
log.Errorf("[onHttpResponseBody] handle non last chunk failed, error: %v", err) log.Errorf("[onHttpResponseBody] handle non last chunk failed, error: %v", err)
// Set an empty struct in the context to indicate an error in processing the partial message // Set an empty struct in the context to indicate an error in processing the partial message
ctx.SetContext(ERROR_PARTIAL_MESSAGE_KEY, struct{}{}) ctx.SetContext(ERROR_PARTIAL_MESSAGE_KEY, struct{}{})
} }
return chunk return chunk
} }
stream := ctx.GetContext(STREAM_CONTEXT_KEY)
var value string var value string
var err error
if stream == nil { if stream == nil {
value, err = processNonStreamLastChunk(ctx, c, chunk, log) value, err = processNonStreamLastChunk(ctx, c, chunk, log)
} else { } else {
value, err = processStreamLastChunk(ctx, c, chunk, log) value, err = processStreamLastChunk(ctx, c, unifySSEChunk(chunk), log)
} }
if err != nil { if err != nil {

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"bytes"
"fmt" "fmt"
"strings" "strings"
@@ -9,17 +10,6 @@ import (
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
func handleNonLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error {
stream := ctx.GetContext(STREAM_CONTEXT_KEY)
err := error(nil)
if stream == nil {
err = handleNonStreamChunk(ctx, c, chunk, log)
} else {
err = handleStreamChunk(ctx, c, chunk, log)
}
return err
}
func handleNonStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error { func handleNonStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error {
tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY)
if tempContentI == nil { if tempContentI == nil {
@@ -32,6 +22,12 @@ func handleNonStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk
return nil return nil
} }
func unifySSEChunk(data []byte) []byte {
data = bytes.ReplaceAll(data, []byte("\r\n"), []byte("\n"))
data = bytes.ReplaceAll(data, []byte("\r"), []byte("\n"))
return data
}
func handleStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error { func handleStreamChunk(ctx wrapper.HttpContext, c config.PluginConfig, chunk []byte, log wrapper.Log) error {
var partialMessage []byte var partialMessage []byte
partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY) partialMessageI := ctx.GetContext(PARTIAL_MESSAGE_CONTEXT_KEY)
@@ -101,55 +97,54 @@ func processStreamLastChunk(ctx wrapper.HttpContext, c config.PluginConfig, chun
} }
func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessage string, log wrapper.Log) (string, error) { func processSSEMessage(ctx wrapper.HttpContext, c config.PluginConfig, sseMessage string, log wrapper.Log) (string, error) {
subMessages := strings.Split(sseMessage, "\n") content := ""
var message string for _, chunk := range strings.Split(sseMessage, "\n\n") {
for _, msg := range subMessages { log.Debugf("single sse message: %s", chunk)
if strings.HasPrefix(msg, "data:") { subMessages := strings.Split(chunk, "\n")
message = msg var message string
break for _, msg := range subMessages {
if strings.HasPrefix(msg, "data:") {
message = msg
break
}
} }
} if len(message) < 6 {
if len(message) < 6 { return content, fmt.Errorf("[processSSEMessage] invalid message: %s", message)
return "", fmt.Errorf("[processSSEMessage] invalid message: %s", message)
}
// skip the prefix "data:"
bodyJson := message[5:]
if strings.TrimSpace(bodyJson) == "[DONE]" {
return "", nil
}
// Extract values from JSON fields
responseBody := gjson.Get(bodyJson, c.CacheStreamValueFrom)
toolCalls := gjson.Get(bodyJson, c.CacheToolCallsFrom)
if toolCalls.Exists() {
// TODO: Temporarily store the tool_calls value in the context for processing
ctx.SetContext(TOOL_CALLS_CONTEXT_KEY, toolCalls.String())
}
// Check if the ResponseBody field exists
if !responseBody.Exists() {
if ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) != nil {
log.Debugf("[processSSEMessage] unable to extract content from message; cache content is not nil: %s", message)
return "", nil
} }
return "", fmt.Errorf("[processSSEMessage] unable to extract content from message; cache content is nil: %s", message)
} else {
tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY)
// If there is no content in the cache, initialize and set the content // skip the prefix "data:"
if tempContentI == nil { bodyJson := message[5:]
content := responseBody.String()
ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content) if strings.TrimSpace(bodyJson) == "[DONE]" {
return content, nil return content, nil
} }
// Update the content in the cache // Extract values from JSON fields
appendMsg := responseBody.String() responseBody := gjson.Get(bodyJson, c.CacheStreamValueFrom)
content := tempContentI.(string) + appendMsg toolCalls := gjson.Get(bodyJson, c.CacheToolCallsFrom)
ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content)
return content, nil if toolCalls.Exists() {
// TODO: Temporarily store the tool_calls value in the context for processing
ctx.SetContext(TOOL_CALLS_CONTEXT_KEY, toolCalls.String())
}
// Check if the ResponseBody field exists
if !responseBody.Exists() {
if ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY) != nil {
log.Debugf("[processSSEMessage] unable to extract content from message; cache content is not nil: %s", message)
return content, nil
}
return content, fmt.Errorf("[processSSEMessage] unable to extract content from message; cache content is nil: %s", message)
} else {
content += responseBody.String()
}
} }
tempContentI := ctx.GetContext(CACHE_CONTENT_CONTEXT_KEY)
// If there is no content in the cache, initialize and set the content
if tempContentI == nil {
ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, content)
} else {
ctx.SetContext(CACHE_CONTENT_CONTEXT_KEY, tempContentI.(string)+content)
}
return content, nil
} }

View File

@@ -3,15 +3,13 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0= github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=

View File

@@ -194,6 +194,12 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte
ctx.SetContext(StreamContextKey, struct{}{}) ctx.SetContext(StreamContextKey, struct{}{})
} }
identityKey := ctx.GetStringContext(IdentityKey, "") identityKey := ctx.GetStringContext(IdentityKey, "")
question := TrimQuote(bodyJson.Get(config.QuestionFrom.RequestBody).String())
if question == "" {
log.Debug("parse question from request body failed")
return types.ActionContinue
}
ctx.SetContext(QuestionContextKey, question)
err := config.redisClient.Get(config.CacheKeyPrefix+identityKey, func(response resp.Value) { err := config.redisClient.Get(config.CacheKeyPrefix+identityKey, func(response resp.Value) {
if err := response.Error(); err != nil { if err := response.Error(); err != nil {
log.Errorf("redis get failed, err:%v", err) log.Errorf("redis get failed, err:%v", err)
@@ -230,13 +236,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte
_ = proxywasm.SendHttpResponseWithDetail(200, "OK", [][2]string{{"content-type", "application/json; charset=utf-8"}}, res, -1) _ = proxywasm.SendHttpResponseWithDetail(200, "OK", [][2]string{{"content-type", "application/json; charset=utf-8"}}, res, -1)
return return
} }
question := TrimQuote(bodyJson.Get(config.QuestionFrom.RequestBody).String())
if question == "" {
log.Debug("parse question from request body failed")
_ = proxywasm.ResumeHttpRequest()
return
}
ctx.SetContext(QuestionContextKey, question)
fillHistoryCnt := getIntQueryParameter("fill_history_cnt", path, config.FillHistoryCnt) * 2 fillHistoryCnt := getIntQueryParameter("fill_history_cnt", path, config.FillHistoryCnt) * 2
currJson := bodyJson.Get("messages").String() currJson := bodyJson.Get("messages").String()
var currMessage []ChatHistory var currMessage []ChatHistory
@@ -317,38 +316,39 @@ func getIntQueryParameter(name string, path string, defaultValue int) int {
} }
func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage string, log wrapper.Log) string { func processSSEMessage(ctx wrapper.HttpContext, config PluginConfig, sseMessage string, log wrapper.Log) string {
subMessages := strings.Split(sseMessage, "\n") content := ""
var message string for _, chunk := range strings.Split(sseMessage, "\n\n") {
for _, msg := range subMessages { subMessages := strings.Split(chunk, "\n")
if strings.HasPrefix(msg, "data:") { var message string
message = msg for _, msg := range subMessages {
break if strings.HasPrefix(msg, "data:") {
message = msg
break
}
} }
} if len(message) < 6 {
if len(message) < 6 { log.Errorf("invalid message:%s", message)
log.Errorf("invalid message:%s", message)
return ""
}
// skip the prefix "data:"
bodyJson := message[5:]
if gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Exists() {
tempContentI := ctx.GetContext(AnswerContentContextKey)
if tempContentI == nil {
content := TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw)
ctx.SetContext(AnswerContentContextKey, content)
return content return content
} }
append := TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw) // skip the prefix "data:"
content := tempContentI.(string) + append bodyJson := message[5:]
ctx.SetContext(AnswerContentContextKey, content) if gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Exists() {
return content tempContentI := ctx.GetContext(AnswerContentContextKey)
} else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() { if tempContentI == nil {
// TODO: compatible with other providers content = TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw)
ctx.SetContext(ToolCallsContextKey, struct{}{}) ctx.SetContext(AnswerContentContextKey, content)
return "" } else {
append := TrimQuote(gjson.Get(bodyJson, config.AnswerStreamValueFrom.ResponseBody).Raw)
content = tempContentI.(string) + append
ctx.SetContext(AnswerContentContextKey, content)
}
} else if gjson.Get(bodyJson, "choices.0.delta.content.tool_calls").Exists() {
// TODO: compatible with other providers
ctx.SetContext(ToolCallsContextKey, struct{}{})
}
log.Debugf("unknown message:%s", bodyJson)
} }
log.Debugf("unknown message:%s", bodyJson) return content
return ""
} }
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action { func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {

View File

@@ -34,7 +34,7 @@ func parseConfig(json gjson.Result, config *AIPromptTemplateConfig, log wrapper.
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIPromptTemplateConfig, log wrapper.Log) types.Action { func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIPromptTemplateConfig, log wrapper.Log) types.Action {
templateEnable, _ := proxywasm.GetHttpRequestHeader("template-enable") templateEnable, _ := proxywasm.GetHttpRequestHeader("template-enable")
if templateEnable != "true" { if templateEnable == "false" {
ctx.DontReadRequestBody() ctx.DontReadRequestBody()
return types.ActionContinue return types.ActionContinue
} }

View File

@@ -41,6 +41,7 @@ description: AI 代理插件配置参考
| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 | | `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 | | `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 | | `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 |
| `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 |
`context`的配置字段说明如下: `context`的配置字段说明如下:
@@ -78,14 +79,22 @@ custom-setting会遵循如下表格根据`name`和协议来替换对应的字
`failover` 的配置字段说明如下: `failover` 的配置字段说明如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | | 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|------------------|--------|------|-------|-----------------------------| |------------------|--------|-----------------|-------|-----------------------------|
| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 | | enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 |
| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) | | failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) |
| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) | | successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) |
| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 | | healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 |
| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 | | healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 |
| healthCheckModel | string | 必填 | | 健康检测使用的模型 | | healthCheckModel | string | 启用 failover 时必填 | | 健康检测使用的模型 |
`retryOnFailure` 的配置字段说明如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|------------------|--------|-----------------|-------|-------------|
| enabled | bool | 非必填 | false | 是否启用失败请求重试 |
| maxRetries | int | 非必填 | 1 | 最大重试次数 |
| retryTimeout | int | 非必填 | 30000 | 重试超时时间,单位毫秒 |
### 提供商特有配置 ### 提供商特有配置
@@ -174,9 +183,10 @@ Mistral 所对应的 `type` 为 `mistral`。它并无特有的配置字段。
MiniMax所对应的 `type``minimax`。它特有的配置字段如下: MiniMax所对应的 `type``minimax`。它特有的配置字段如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | | 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
| ---------------- | -------- | ------------------------------------------------------------ | ------ | ------------------------------------------------------------ | | ---------------- | -------- | ------------------------------ | ------ |----------------------------------------------------------------|
| `minimaxGroupId` | string | 当使用`abab6.5-chat`, `abab6.5s-chat`, `abab5.5s-chat`, `abab5.5-chat`四种模型时必填 | - | 当使用`abab6.5-chat`, `abab6.5s-chat`, `abab5.5s-chat`, `abab5.5-chat`四种模型时会使用ChatCompletion Pro需要设置groupID | | `minimaxApiType` | string | v2 和 pro 中选填一项 | v2 | v2 代表 ChatCompletion v2 APIpro 代表 ChatCompletion Pro API |
| `minimaxGroupId` | string | `minimaxApiType` 为 pro 时必填 | - | `minimaxApiType` 为 pro 时使用 ChatCompletion Pro API需要设置 groupID |
#### Anthropic Claude #### Anthropic Claude
@@ -242,6 +252,9 @@ DeepL 所对应的 `type` 为 `deepl`。它特有的配置字段如下:
Cohere 所对应的 `type``cohere`。它并无特有的配置字段。 Cohere 所对应的 `type``cohere`。它并无特有的配置字段。
#### Together-AI
Together-AI 所对应的 `type``together-ai`。它并无特有的配置字段。
## 用法示例 ## 用法示例
### 使用 OpenAI 协议代理 Azure OpenAI 服务 ### 使用 OpenAI 协议代理 Azure OpenAI 服务
@@ -1000,17 +1013,16 @@ provider:
apiTokens: apiTokens:
- "YOUR_MINIMAX_API_TOKEN" - "YOUR_MINIMAX_API_TOKEN"
modelMapping: modelMapping:
"gpt-3": "abab6.5g-chat" "gpt-3": "abab6.5s-chat"
"gpt-4": "abab6.5-chat" "gpt-4": "abab6.5g-chat"
"*": "abab6.5g-chat" "*": "abab6.5t-chat"
minimaxGroupId: "YOUR_MINIMAX_GROUP_ID"
``` ```
**请求示例** **请求示例**
```json ```json
{ {
"model": "gpt-4-turbo", "model": "gpt-3",
"messages": [ "messages": [
{ {
"role": "user", "role": "user",
@@ -1025,27 +1037,33 @@ provider:
```json ```json
{ {
"id": "02b2251f8c6c09d68c1743f07c72afd7", "id": "03ac4fcfe1c6cc9c6a60f9d12046e2b4",
"choices": [ "choices": [
{ {
"finish_reason": "stop", "finish_reason": "stop",
"index": 0, "index": 0,
"message": { "message": {
"content": "你好我是MM智能助理一款由MiniMax自研的大型语言模型。我可以帮助你解答问题提供信息进行对话等。有什么可以帮助你的吗?", "content": "你好我是一个由MiniMax公司研发的大型语言模型名为MM智能助理。我可以帮助答问题提供信息进行对话和执行多种语言处理任务。如果你有任何问题或需要帮助,请随时告诉我!",
"role": "assistant" "role": "assistant",
"name": "MM智能助理",
"audio_content": ""
} }
} }
], ],
"created": 1717760544, "created": 1734155471,
"model": "abab6.5s-chat", "model": "abab6.5s-chat",
"object": "chat.completion", "object": "chat.completion",
"usage": { "usage": {
"total_tokens": 106 "total_tokens": 116,
"total_characters": 0,
"prompt_tokens": 70,
"completion_tokens": 46
}, },
"input_sensitive": false, "input_sensitive": false,
"output_sensitive": false, "output_sensitive": false,
"input_sensitive_type": 0, "input_sensitive_type": 0,
"output_sensitive_type": 0, "output_sensitive_type": 0,
"output_sensitive_int": 0,
"base_resp": { "base_resp": {
"status_code": 0, "status_code": 0,
"status_msg": "" "status_msg": ""
@@ -1490,6 +1508,61 @@ provider:
} }
``` ```
### 使用 OpenAI 协议代理 Together-AI 服务
**配置信息**
```yaml
provider:
type: together-ai
apiTokens:
- "YOUR_TOGETHER_AI_API_TOKEN"
modelMapping:
"*": "Qwen/Qwen2.5-72B-Instruct-Turbo"
```
**请求示例**
```json
{
"model": "Qwen/Qwen2.5-72B-Instruct-Turbo",
"messages": [
{
"role": "user",
"content": "Who are you?"
}
]
}
```
**响应示例**
```json
{
"id": "8f5809d54b73efac",
"object": "chat.completion",
"created": 1734785851,
"model": "Qwen/Qwen2.5-72B-Instruct-Turbo",
"prompt": [],
"choices": [
{
"finish_reason": "eos",
"seed": 12830868308626506000,
"logprobs": null,
"index": 0,
"message": {
"role": "assistant",
"content": "I am Qwen, a large language model created by Alibaba Cloud. I am designed to assist users in generating various types of text, such as articles, stories, poems, and more, as well as answering questions and providing information on a wide range of topics. How can I assist you today?",
"tool_calls": []
}
}
],
"usage": {
"prompt_tokens": 33,
"completion_tokens": 61,
"total_tokens": 94
}
}
```
## 完整配置示例 ## 完整配置示例
### Kubernetes 示例 ### Kubernetes 示例

View File

@@ -1356,6 +1356,60 @@ Here, `model` denotes the service tier of DeepL and can only be either `Free` or
} }
``` ```
### Utilizing OpenAI Protocol Proxy for Together-AI Services
**Configuration Information**
```yaml
provider:
type: together-ai
apiTokens:
- "YOUR_TOGETHER_AI_API_TOKEN"
modelMapping:
"*": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
```
**Request Example**
```json
{
"model": "Qwen/Qwen2.5-72B-Instruct-Turbo",
"messages": [
{
"role": "user",
"content": "Who are you?"
}
]
}
```
**Response Example**
```json
{
"id": "8f5809d54b73efac",
"object": "chat.completion",
"created": 1734785851,
"model": "Qwen/Qwen2.5-72B-Instruct-Turbo",
"prompt": [],
"choices": [
{
"finish_reason": "eos",
"seed": 12830868308626506000,
"logprobs": null,
"index": 0,
"message": {
"role": "assistant",
"content": "I am Qwen, a large language model created by Alibaba Cloud. I am designed to assist users in generating various types of text, such as articles, stories, poems, and more, as well as answering questions and providing information on a wide range of topics. How can I assist you today?",
"tool_calls": []
}
}
],
"usage": {
"prompt_tokens": 33,
"completion_tokens": 61,
"total_tokens": 94
}
}
```
## Full Configuration Example ## Full Configuration Example
### Kubernetes Example ### Kubernetes Example

View File

@@ -20,8 +20,6 @@ import (
const ( const (
pluginName = "ai-proxy" pluginName = "ai-proxy"
ctxKeyApiName = "apiName"
defaultMaxBodyBytes uint32 = 10 * 1024 * 1024 defaultMaxBodyBytes uint32 = 10 * 1024 * 1024
) )
@@ -89,29 +87,34 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
} }
if apiName == "" { if apiName == "" {
log.Debugf("[onHttpRequestHeader] unsupported path: %s", path.Path) log.Warnf("[onHttpRequestHeader] unsupported path: %s", path.Path)
// _ = util.SendResponse(404, "ai-proxy.unknown_api", util.MimeTypeTextPlain, "API not found: "+path.Path)
log.Debugf("[onHttpRequestHeader] no send response")
return types.ActionContinue return types.ActionContinue
} }
ctx.SetContext(ctxKeyApiName, apiName)
ctx.SetContext(provider.CtxKeyApiName, apiName)
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
ctx.DisableReroute()
_, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
if needHandleStreamingBody {
proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
}
if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok { if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok {
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
ctx.DisableReroute()
// Set the apiToken for the current request. // Set the apiToken for the current request.
providerConfig.SetApiTokenInUse(ctx, log) providerConfig.SetApiTokenInUse(ctx, log)
hasRequestBody := wrapper.HasRequestBody() hasRequestBody := wrapper.HasRequestBody()
action, err := handler.OnRequestHeaders(ctx, apiName, log) err := handler.OnRequestHeaders(ctx, apiName, log)
if err == nil { if err == nil {
if hasRequestBody { if hasRequestBody {
proxywasm.RemoveHttpRequestHeader("Content-Length")
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes) ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
// Always return types.HeaderStopIteration to support fallback routing, // Delay the header processing to allow changing in OnRequestBody
// as long as onHttpRequestBody can be called.
return types.HeaderStopIteration return types.HeaderStopIteration
} }
return action ctx.DontReadRequestBody()
return types.ActionContinue
} }
util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err)) util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
@@ -132,7 +135,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
log.Debugf("[onHttpRequestBody] provider=%s", activeProvider.GetProviderType()) log.Debugf("[onHttpRequestBody] provider=%s", activeProvider.GetProviderType())
if handler, ok := activeProvider.(provider.RequestBodyHandler); ok { if handler, ok := activeProvider.(provider.RequestBodyHandler); ok {
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName) apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
newBody, settingErr := pluginConfig.GetProviderConfig().ReplaceByCustomSettings(body) newBody, settingErr := pluginConfig.GetProviderConfig().ReplaceByCustomSettings(body)
if settingErr != nil { if settingErr != nil {
@@ -180,32 +183,25 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
log.Errorf("unable to load :status header from response: %v", err) log.Errorf("unable to load :status header from response: %v", err)
} }
ctx.DontReadResponseBody() ctx.DontReadResponseBody()
providerConfig.OnRequestFailed(ctx, apiTokenInUse, log) return providerConfig.OnRequestFailed(activeProvider, ctx, apiTokenInUse, log)
return types.ActionContinue
} }
// Reset ctxApiTokenRequestFailureCount if the request is successful, // Reset ctxApiTokenRequestFailureCount if the request is successful,
// the apiToken is removed only when the number of consecutive request failures exceeds the threshold. // the apiToken is removed only when the number of consecutive request failures exceeds the threshold.
providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse, log) providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse, log)
if handler, ok := activeProvider.(provider.ResponseHeadersHandler); ok { headers := util.GetOriginalResponseHeaders()
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName) if handler, ok := activeProvider.(provider.TransformResponseHeadersHandler); ok {
action, err := handler.OnResponseHeaders(ctx, apiName, log) apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
if err == nil { handler.TransformResponseHeaders(ctx, apiName, headers, log)
checkStream(&ctx, log) } else {
return action providerConfig.DefaultTransformResponseHeaders(ctx, headers)
}
util.ErrorHandler("ai-proxy.proc_resp_headers_failed", fmt.Errorf("failed to process response headers: %v", err))
return types.ActionContinue
} }
util.ReplaceResponseHeaders(headers)
checkStream(&ctx, log) checkStream(ctx, log)
_, needHandleBody := activeProvider.(provider.ResponseBodyHandler)
_, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler) _, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
if !needHandleBody && !needHandleStreamingBody { if !needHandleStreamingBody {
ctx.DontReadResponseBody()
} else if !needHandleStreamingBody {
ctx.BufferResponseBody() ctx.BufferResponseBody()
} }
@@ -224,7 +220,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
log.Debugf("isLastChunk=%v chunk: %s", isLastChunk, string(chunk)) log.Debugf("isLastChunk=%v chunk: %s", isLastChunk, string(chunk))
if handler, ok := activeProvider.(provider.StreamingResponseBodyHandler); ok { if handler, ok := activeProvider.(provider.StreamingResponseBodyHandler); ok {
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName) apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk, log) modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk, log)
if err == nil && modifiedChunk != nil { if err == nil && modifiedChunk != nil {
return modifiedChunk return modifiedChunk
@@ -243,27 +239,29 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
} }
log.Debugf("[onHttpResponseBody] provider=%s", activeProvider.GetProviderType()) log.Debugf("[onHttpResponseBody] provider=%s", activeProvider.GetProviderType())
//log.Debugf("response body: %s", string(body))
if handler, ok := activeProvider.(provider.ResponseBodyHandler); ok { if handler, ok := activeProvider.(provider.TransformResponseBodyHandler); ok {
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName) apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
action, err := handler.OnResponseBody(ctx, apiName, body, log) body, err := handler.TransformResponseBody(ctx, apiName, body, log)
if err == nil { if err != nil {
return action util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
return types.ActionContinue
}
if err = provider.ReplaceResponseBody(body, log); err != nil {
util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err))
} }
util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
return types.ActionContinue
} }
return types.ActionContinue return types.ActionContinue
} }
func checkStream(ctx *wrapper.HttpContext, log wrapper.Log) { func checkStream(ctx wrapper.HttpContext, log wrapper.Log) {
contentType, err := proxywasm.GetHttpResponseHeader("Content-Type") contentType, err := proxywasm.GetHttpResponseHeader("Content-Type")
if err != nil || !strings.HasPrefix(contentType, "text/event-stream") { if err != nil || !strings.HasPrefix(contentType, "text/event-stream") {
if err != nil { if err != nil {
log.Errorf("unable to load content-type header from response: %v", err) log.Errorf("unable to load content-type header from response: %v", err)
} }
(*ctx).BufferResponseBody() ctx.BufferResponseBody()
ctx.SetResponseBodyBufferLimit(defaultMaxBodyBytes)
} }
} }

View File

@@ -22,7 +22,7 @@ type ai360Provider struct {
contextCache *contextCache contextCache *contextCache
} }
func (m *ai360ProviderInitializer) ValidateConfig(config ProviderConfig) error { func (m *ai360ProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 { if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config") return errors.New("no apiToken found in provider config")
} }
@@ -40,13 +40,13 @@ func (m *ai360Provider) GetProviderType() string {
return providerTypeAi360 return providerTypeAi360
} }
func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody // Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil return nil
} }
func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -58,7 +58,5 @@ func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, ai360Domain) util.OverwriteRequestHostHeader(headers, ai360Domain)
util.OverwriteRequestAuthorizationHeader(headers, "Authorization "+m.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
} }

View File

@@ -15,7 +15,7 @@ import (
type azureProviderInitializer struct { type azureProviderInitializer struct {
} }
func (m *azureProviderInitializer) ValidateConfig(config ProviderConfig) error { func (m *azureProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.azureServiceUrl == "" { if config.azureServiceUrl == "" {
return errors.New("missing azureServiceUrl in provider config") return errors.New("missing azureServiceUrl in provider config")
} }
@@ -53,12 +53,12 @@ func (m *azureProvider) GetProviderType() string {
return providerTypeAzure return providerTypeAzure
} }
func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil return nil
} }
func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -86,6 +86,6 @@ func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
util.OverwriteRequestPathHeader(headers, m.serviceUrl.RequestURI()) util.OverwriteRequestPathHeader(headers, m.serviceUrl.RequestURI())
} }
util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host) util.OverwriteRequestHostHeader(headers, m.serviceUrl.Host)
util.OverwriteRequestAuthorizationHeader(headers, "api-key "+m.config.GetApiTokenInUse(ctx)) headers.Set("api-key", m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length") headers.Del("Content-Length")
} }

View File

@@ -19,7 +19,7 @@ const (
type baichuanProviderInitializer struct { type baichuanProviderInitializer struct {
} }
func (m *baichuanProviderInitializer) ValidateConfig(config ProviderConfig) error { func (m *baichuanProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 { if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config") return errors.New("no apiToken found in provider config")
} }
@@ -42,12 +42,12 @@ func (m *baichuanProvider) GetProviderType() string {
return providerTypeBaichuan return providerTypeBaichuan
} }
func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil return nil
} }
func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {

View File

@@ -34,16 +34,15 @@ const (
type baiduProviderInitializer struct{} type baiduProviderInitializer struct{}
func (g *baiduProviderInitializer) ValidateConfig(config ProviderConfig) error { func (g *baiduProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.baiduAccessKeyAndSecret == nil || len(config.baiduAccessKeyAndSecret) == 0 { if config.baiduAccessKeyAndSecret == nil || len(config.baiduAccessKeyAndSecret) == 0 {
return errors.New("no baiduAccessKeyAndSecret found in provider config") return errors.New("no baiduAccessKeyAndSecret found in provider config")
} }
if config.baiduApiTokenServiceName == "" { if config.baiduApiTokenServiceName == "" {
return errors.New("no baiduApiTokenServiceName found in provider config") return errors.New("no baiduApiTokenServiceName found in provider config")
} }
if !config.failover.enabled { // baidu use access key and access secret to refresh apiToken regularly, the apiToken should be accessed globally (via all Wasm VMs)
config.useGlobalApiToken = true config.useGlobalApiToken = true
}
return nil return nil
} }
@@ -63,12 +62,12 @@ func (g *baiduProvider) GetProviderType() string {
return providerTypeBaidu return providerTypeBaidu
} }
func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName return errUnsupportedApiName
} }
g.config.handleRequestHeaders(g, ctx, apiName, log) g.config.handleRequestHeaders(g, ctx, apiName, log)
return types.ActionContinue, nil return nil
} }
func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {

View File

@@ -10,7 +10,6 @@ import (
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
) )
@@ -79,7 +78,7 @@ type claudeTextGenDelta struct {
StopSequence *string `json:"stop_sequence"` StopSequence *string `json:"stop_sequence"`
} }
func (c *claudeProviderInitializer) ValidateConfig(config ProviderConfig) error { func (c *claudeProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 { if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config") return errors.New("no apiToken found in provider config")
} }
@@ -102,27 +101,25 @@ func (c *claudeProvider) GetProviderType() string {
return providerTypeClaude return providerTypeClaude
} }
func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName return errUnsupportedApiName
} }
c.config.handleRequestHeaders(c, ctx, apiName, log) c.config.handleRequestHeaders(c, ctx, apiName, log)
return types.ActionContinue, nil return nil
} }
func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, claudeChatCompletionPath) util.OverwriteRequestPathHeader(headers, claudeChatCompletionPath)
util.OverwriteRequestHostHeader(headers, claudeDomain) util.OverwriteRequestHostHeader(headers, claudeDomain)
headers.Add("x-api-key", c.config.GetApiTokenInUse(ctx)) headers.Set("x-api-key", c.config.GetApiTokenInUse(ctx))
if c.config.claudeVersion == "" { if c.config.claudeVersion == "" {
c.config.claudeVersion = defaultVersion c.config.claudeVersion = defaultVersion
} }
headers.Add("anthropic-version", c.config.claudeVersion) headers.Set("anthropic-version", c.config.claudeVersion)
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
} }
func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -141,27 +138,16 @@ func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName A
return json.Marshal(claudeRequest) return json.Marshal(claudeRequest)
} }
func (c *claudeProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
claudeResponse := &claudeTextGenResponse{} claudeResponse := &claudeTextGenResponse{}
if err := json.Unmarshal(body, claudeResponse); err != nil { if err := json.Unmarshal(body, claudeResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal claude response: %v", err) return nil, fmt.Errorf("unable to unmarshal claude response: %v", err)
} }
if claudeResponse.Error != nil { if claudeResponse.Error != nil {
return types.ActionContinue, fmt.Errorf("claude response error, error_type: %s, error_message: %s", claudeResponse.Error.Type, claudeResponse.Error.Message) return nil, fmt.Errorf("claude response error, error_type: %s, error_message: %s", claudeResponse.Error.Type, claudeResponse.Error.Message)
} }
response := c.responseClaude2OpenAI(ctx, claudeResponse) response := c.responseClaude2OpenAI(ctx, claudeResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log) return json.Marshal(response)
}
func (c *claudeProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
// use original protocol, skip OnStreamingResponseBody() and OnResponseBody()
if c.config.protocol == protocolOriginal {
ctx.DontReadResponseBody()
return types.ActionContinue, nil
}
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
} }
func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {

View File

@@ -19,7 +19,7 @@ const (
type cloudflareProviderInitializer struct { type cloudflareProviderInitializer struct {
} }
func (c *cloudflareProviderInitializer) ValidateConfig(config ProviderConfig) error { func (c *cloudflareProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 { if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config") return errors.New("no apiToken found in provider config")
} }
@@ -42,12 +42,12 @@ func (c *cloudflareProvider) GetProviderType() string {
return providerTypeCloudflare return providerTypeCloudflare
} }
func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName return errUnsupportedApiName
} }
c.config.handleRequestHeaders(c, ctx, apiName, log) c.config.handleRequestHeaders(c, ctx, apiName, log)
return types.ActionContinue, nil return nil
} }
func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -61,6 +61,4 @@ func (c *cloudflareProvider) TransformRequestHeaders(ctx wrapper.HttpContext, ap
util.OverwriteRequestPathHeader(headers, strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1)) util.OverwriteRequestPathHeader(headers, strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1))
util.OverwriteRequestHostHeader(headers, cloudflareDomain) util.OverwriteRequestHostHeader(headers, cloudflareDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+c.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+c.config.GetApiTokenInUse(ctx))
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
} }

View File

@@ -3,11 +3,12 @@ package provider
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
"strings"
) )
const ( const (
@@ -17,7 +18,7 @@ const (
type cohereProviderInitializer struct{} type cohereProviderInitializer struct{}
func (m *cohereProviderInitializer) ValidateConfig(config ProviderConfig) error { func (m *cohereProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 { if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config") return errors.New("no apiToken found in provider config")
} }
@@ -54,12 +55,12 @@ func (m *cohereProvider) GetProviderType() string {
return providerTypeCohere return providerTypeCohere
} }
func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil return nil
} }
func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {

View File

@@ -151,7 +151,7 @@ func insertContext(provider Provider, content string, err error, body []byte, lo
if err != nil { if err != nil {
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.insert_ctx_failed", typ), fmt.Errorf("failed to insert context message: %v", err)) util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.insert_ctx_failed", typ), fmt.Errorf("failed to insert context message: %v", err))
} }
if err := replaceHttpJsonRequestBody(body, log); err != nil { if err := replaceRequestBody(body, log); err != nil {
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.replace_request_body_failed", typ), fmt.Errorf("failed to replace request body: %v", err)) util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.replace_request_body_failed", typ), fmt.Errorf("failed to replace request body: %v", err))
} }
} }

View File

@@ -6,7 +6,6 @@ import (
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
) )
const ( const (
@@ -15,7 +14,7 @@ const (
type cozeProviderInitializer struct{} type cozeProviderInitializer struct{}
func (m *cozeProviderInitializer) ValidateConfig(config ProviderConfig) error { func (m *cozeProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 { if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config") return errors.New("no apiToken found in provider config")
} }
@@ -38,9 +37,9 @@ func (m *cozeProvider) GetProviderType() string {
return providerTypeCoze return providerTypeCoze
} }
func (m *cozeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (m *cozeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil return nil
} }
func (m *cozeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (m *cozeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {

View File

@@ -10,7 +10,6 @@ import (
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
) )
@@ -58,7 +57,7 @@ type deeplResponseTranslation struct {
Text string `json:"text"` Text string `json:"text"`
} }
func (d *deeplProviderInitializer) ValidateConfig(config ProviderConfig) error { func (d *deeplProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.targetLang == "" { if config.targetLang == "" {
return errors.New("missing targetLang in deepl provider config") return errors.New("missing targetLang in deepl provider config")
} }
@@ -76,19 +75,17 @@ func (d *deeplProvider) GetProviderType() string {
return providerTypeDeepl return providerTypeDeepl
} }
func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName return errUnsupportedApiName
} }
d.config.handleRequestHeaders(d, ctx, apiName, log) d.config.handleRequestHeaders(d, ctx, apiName, log)
return types.HeaderStopIteration, nil return nil
} }
func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, deeplChatCompletionPath) util.OverwriteRequestPathHeader(headers, deeplChatCompletionPath)
util.OverwriteRequestAuthorizationHeader(headers, "DeepL-Auth-Key "+d.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, "DeepL-Auth-Key "+d.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
headers.Del("Accept-Encoding")
} }
func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -114,18 +111,13 @@ func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, api
return json.Marshal(baiduRequest) return json.Marshal(baiduRequest)
} }
func (d *deeplProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (d *deeplProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
}
func (d *deeplProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
deeplResponse := &deeplResponse{} deeplResponse := &deeplResponse{}
if err := json.Unmarshal(body, deeplResponse); err != nil { if err := json.Unmarshal(body, deeplResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal deepl response: %v", err) return nil, fmt.Errorf("unable to unmarshal deepl response: %v", err)
} }
response := d.responseDeepl2OpenAI(ctx, deeplResponse) response := d.responseDeepl2OpenAI(ctx, deeplResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log) return json.Marshal(response)
} }
func (d *deeplProvider) responseDeepl2OpenAI(ctx wrapper.HttpContext, deeplResponse *deeplResponse) *chatCompletionResponse { func (d *deeplProvider) responseDeepl2OpenAI(ctx wrapper.HttpContext, deeplResponse *deeplResponse) *chatCompletionResponse {

View File

@@ -2,10 +2,11 @@ package provider
import ( import (
"errors" "errors"
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
) )
// deepseekProvider is the provider for deepseek Ai service. // deepseekProvider is the provider for deepseek Ai service.
@@ -18,7 +19,7 @@ const (
type deepseekProviderInitializer struct { type deepseekProviderInitializer struct {
} }
func (m *deepseekProviderInitializer) ValidateConfig(config ProviderConfig) error { func (m *deepseekProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 { if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config") return errors.New("no apiToken found in provider config")
} }
@@ -41,12 +42,12 @@ func (m *deepseekProvider) GetProviderType() string {
return providerTypeDeepSeek return providerTypeDeepSeek
} }
func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil return nil
} }
func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {

View File

@@ -2,11 +2,12 @@ package provider
import ( import (
"errors" "errors"
"net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
"strings"
) )
const ( const (
@@ -16,7 +17,7 @@ const (
type doubaoProviderInitializer struct{} type doubaoProviderInitializer struct{}
func (m *doubaoProviderInitializer) ValidateConfig(config ProviderConfig) error { func (m *doubaoProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 { if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config") return errors.New("no apiToken found in provider config")
} }
@@ -39,12 +40,12 @@ func (m *doubaoProvider) GetProviderType() string {
return providerTypeDoubao return providerTypeDoubao
} }
func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil return nil
} }
func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {

View File

@@ -4,14 +4,14 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/google/uuid"
"math/rand" "math/rand"
"net/http" "net/http"
"strings" "strings"
"time" "time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/google/uuid"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
@@ -19,7 +19,7 @@ import (
type failover struct { type failover struct {
// @Title zh-CN 是否启用 apiToken 的 failover 机制 // @Title zh-CN 是否启用 apiToken 的 failover 机制
enabled bool `required:"true" yaml:"enabled" json:"enabled"` enabled bool `required:"false" yaml:"enabled" json:"enabled"`
// @Title zh-CN 触发 failover 连续请求失败的阈值 // @Title zh-CN 触发 failover 连续请求失败的阈值
failureThreshold int64 `required:"false" yaml:"failureThreshold" json:"failureThreshold"` failureThreshold int64 `required:"false" yaml:"failureThreshold" json:"failureThreshold"`
// @Title zh-CN 健康检测的成功阈值 // @Title zh-CN 健康检测的成功阈值
@@ -29,7 +29,7 @@ type failover struct {
// @Title zh-CN 健康检测的超时时间,单位毫秒 // @Title zh-CN 健康检测的超时时间,单位毫秒
healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"` healthCheckTimeout int64 `required:"false" yaml:"healthCheckTimeout" json:"healthCheckTimeout"`
// @Title zh-CN 健康检测使用的模型 // @Title zh-CN 健康检测使用的模型
healthCheckModel string `required:"true" yaml:"healthCheckModel" json:"healthCheckModel"` healthCheckModel string `required:"false" yaml:"healthCheckModel" json:"healthCheckModel"`
// @Title zh-CN 本次请求使用的 apiToken // @Title zh-CN 本次请求使用的 apiToken
ctxApiTokenInUse string ctxApiTokenInUse string
// @Title zh-CN 记录 apiToken 请求失败的次数key 为 apiTokenvalue 为失败次数 // @Title zh-CN 记录 apiToken 请求失败的次数key 为 apiTokenvalue 为失败次数
@@ -184,9 +184,9 @@ func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext,
if handler, ok := activeProvider.(TransformRequestBodyHandler); ok { if handler, ok := activeProvider.(TransformRequestBodyHandler); ok {
body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body, log) body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body, log)
} else if handler, ok := activeProvider.(TransformRequestBodyHeadersHandler); ok { } else if handler, ok := activeProvider.(TransformRequestBodyHeadersHandler); ok {
headers := util.GetOriginalHttpHeaders() headers := util.GetOriginalRequestHeaders()
body, err = handler.TransformRequestBodyHeaders(ctx, ApiNameChatCompletion, body, originalHeaders, log) body, err = handler.TransformRequestBodyHeaders(ctx, ApiNameChatCompletion, body, originalHeaders, log)
util.ReplaceOriginalHttpHeaders(headers) util.ReplaceRequestHeaders(headers)
} else { } else {
body, err = c.defaultTransformRequestBody(ctx, ApiNameChatCompletion, body, log) body, err = c.defaultTransformRequestBody(ctx, ApiNameChatCompletion, body, log)
} }
@@ -539,25 +539,32 @@ func (c *ProviderConfig) resetSharedData() {
_ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0) _ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0)
} }
func (c *ProviderConfig) OnRequestFailed(ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) { func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, log wrapper.Log) types.Action {
if c.isFailoverEnabled() { if c.isFailoverEnabled() {
c.handleUnavailableApiToken(ctx, apiTokenInUse, log) c.handleUnavailableApiToken(ctx, apiTokenInUse, log)
} }
if c.isRetryOnFailureEnabled() && ctx.GetContext(ctxKeyIsStreaming) != nil && !ctx.GetContext(ctxKeyIsStreaming).(bool) {
c.retryFailedRequest(activeProvider, ctx, log)
return types.HeaderStopAllIterationAndWatermark
}
return types.ActionContinue
} }
func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string { func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
return ctx.GetContext(c.failover.ctxApiTokenInUse).(string) token, _ := ctx.GetContext(c.failover.ctxApiTokenInUse).(string)
return token
} }
func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) { func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) {
var apiToken string var apiToken string
if c.isFailoverEnabled() || c.useGlobalApiToken { if c.isFailoverEnabled() || c.useGlobalApiToken {
// if enable apiToken failover, only use available apiToken // if enable apiToken failover, only use available apiToken from global apiTokens list
// or the apiToken need to be accessed globally (via all Wasm VMs, e.g. baidu),
apiToken = c.GetGlobalRandomToken(log) apiToken = c.GetGlobalRandomToken(log)
} else { } else {
apiToken = c.GetRandomToken() apiToken = c.GetRandomToken()
} }
log.Debugf("[onHttpRequestHeader] use apiToken %s to send request", apiToken) log.Debugf("Use apiToken %s to send request", apiToken)
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken) ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
} }

View File

@@ -28,7 +28,7 @@ const (
type geminiProviderInitializer struct { type geminiProviderInitializer struct {
} }
func (g *geminiProviderInitializer) ValidateConfig(config ProviderConfig) error { func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 { if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config") return errors.New("no apiToken found in provider config")
} }
@@ -51,20 +51,18 @@ func (g *geminiProvider) GetProviderType() string {
return providerTypeGemini return providerTypeGemini
} }
func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName return errUnsupportedApiName
} }
g.config.handleRequestHeaders(g, ctx, apiName, log) g.config.handleRequestHeaders(g, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody // Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil return nil
} }
func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestHostHeader(headers, geminiDomain) util.OverwriteRequestHostHeader(headers, geminiDomain)
headers.Add(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx)) headers.Set(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx))
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
} }
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -107,16 +105,6 @@ func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
return json.Marshal(geminiRequest) return json.Marshal(geminiRequest)
} }
func (g *geminiProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
if g.config.protocol == protocolOriginal {
ctx.DontReadResponseBody()
return types.ActionContinue, nil
}
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
}
func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
log.Infof("chunk body:%s", string(chunk)) log.Infof("chunk body:%s", string(chunk))
if isLastChunk || len(chunk) == 0 { if isLastChunk || len(chunk) == 0 {
@@ -150,39 +138,38 @@ func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
return []byte(modifiedResponseChunk), nil return []byte(modifiedResponseChunk), nil
} }
func (g *geminiProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (g *geminiProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if apiName == ApiNameChatCompletion { if apiName == ApiNameChatCompletion {
return g.onChatCompletionResponseBody(ctx, body, log) return g.onChatCompletionResponseBody(ctx, body, log)
} else if apiName == ApiNameEmbeddings { } else {
return g.onEmbeddingsResponseBody(ctx, body, log) return g.onEmbeddingsResponseBody(ctx, body, log)
} }
return types.ActionContinue, errUnsupportedApiName
} }
func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
geminiResponse := &geminiChatResponse{} geminiResponse := &geminiChatResponse{}
if err := json.Unmarshal(body, geminiResponse); err != nil { if err := json.Unmarshal(body, geminiResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal gemini chat response: %v", err) return nil, fmt.Errorf("unable to unmarshal gemini chat response: %v", err)
} }
if geminiResponse.Error != nil { if geminiResponse.Error != nil {
return types.ActionContinue, fmt.Errorf("gemini chat completion response error, error_code: %d, error_status:%s, error_message: %s", return nil, fmt.Errorf("gemini chat completion response error, error_code: %d, error_status:%s, error_message: %s",
geminiResponse.Error.Code, geminiResponse.Error.Status, geminiResponse.Error.Message) geminiResponse.Error.Code, geminiResponse.Error.Status, geminiResponse.Error.Message)
} }
response := g.buildChatCompletionResponse(ctx, geminiResponse) response := g.buildChatCompletionResponse(ctx, geminiResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log) return json.Marshal(response)
} }
func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
geminiResponse := &geminiEmbeddingResponse{} geminiResponse := &geminiEmbeddingResponse{}
if err := json.Unmarshal(body, geminiResponse); err != nil { if err := json.Unmarshal(body, geminiResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal gemini embeddings response: %v", err) return nil, fmt.Errorf("unable to unmarshal gemini embeddings response: %v", err)
} }
if geminiResponse.Error != nil { if geminiResponse.Error != nil {
return types.ActionContinue, fmt.Errorf("gemini embeddings response error, error_code: %d, error_status:%s, error_message: %s", return nil, fmt.Errorf("gemini embeddings response error, error_code: %d, error_status:%s, error_message: %s",
geminiResponse.Error.Code, geminiResponse.Error.Status, geminiResponse.Error.Message) geminiResponse.Error.Code, geminiResponse.Error.Status, geminiResponse.Error.Message)
} }
response := g.buildEmbeddingsResponse(ctx, geminiResponse) response := g.buildEmbeddingsResponse(ctx, geminiResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log) return json.Marshal(response)
} }
func (g *geminiProvider) getRequestPath(apiName ApiName, geminiModel string, stream bool) string { func (g *geminiProvider) getRequestPath(apiName ApiName, geminiModel string, stream bool) string {

View File

@@ -2,11 +2,12 @@ package provider
import ( import (
"errors" "errors"
"net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
"strings"
) )
// githubProvider is the provider for GitHub OpenAI service. // githubProvider is the provider for GitHub OpenAI service.
@@ -24,7 +25,7 @@ type githubProvider struct {
contextCache *contextCache contextCache *contextCache
} }
func (m *githubProviderInitializer) ValidateConfig(config ProviderConfig) error { func (m *githubProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 { if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config") return errors.New("no apiToken found in provider config")
} }
@@ -42,13 +43,13 @@ func (m *githubProvider) GetProviderType() string {
return providerTypeGithub return providerTypeGithub
} }
func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody // Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil return nil
} }
func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -67,8 +68,6 @@ func (m *githubProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
util.OverwriteRequestPathHeader(headers, githubEmbeddingPath) util.OverwriteRequestPathHeader(headers, githubEmbeddingPath)
} }
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
} }
func (m *githubProvider) GetApiName(path string) ApiName { func (m *githubProvider) GetApiName(path string) ApiName {

View File

@@ -18,7 +18,7 @@ const (
type groqProviderInitializer struct{} type groqProviderInitializer struct{}
func (g *groqProviderInitializer) ValidateConfig(config ProviderConfig) error { func (g *groqProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 { if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config") return errors.New("no apiToken found in provider config")
} }
@@ -41,12 +41,12 @@ func (g *groqProvider) GetProviderType() string {
return providerTypeGroq return providerTypeGroq
} }
func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName return errUnsupportedApiName
} }
g.config.handleRequestHeaders(g, ctx, apiName, log) g.config.handleRequestHeaders(g, ctx, apiName, log)
return types.ActionContinue, nil return nil
} }
func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {

View File

@@ -85,7 +85,7 @@ type hunyuanChatMessage struct {
Content string `json:"Content,omitempty"` Content string `json:"Content,omitempty"`
} }
func (m *hunyuanProviderInitializer) ValidateConfig(config ProviderConfig) error { func (m *hunyuanProviderInitializer) ValidateConfig(config *ProviderConfig) error {
// 校验hunyuan id 和 key的合法性 // 校验hunyuan id 和 key的合法性
if len(config.hunyuanAuthId) != hunyuanAuthIdLen || len(config.hunyuanAuthKey) != hunyuanAuthKeyLen { if len(config.hunyuanAuthId) != hunyuanAuthIdLen || len(config.hunyuanAuthKey) != hunyuanAuthKeyLen {
return errors.New("hunyuanAuthId / hunyuanAuthKey is illegal in config file") return errors.New("hunyuanAuthId / hunyuanAuthKey is illegal in config file")
@@ -114,13 +114,13 @@ func (m *hunyuanProvider) GetProviderType() string {
return providerTypeHunyuan return providerTypeHunyuan
} }
func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody // Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil return nil
} }
func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
@@ -128,11 +128,8 @@ func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa
util.OverwriteRequestPathHeader(headers, hunyuanRequestPath) util.OverwriteRequestPathHeader(headers, hunyuanRequestPath)
// 添加 hunyuan 需要的自定义字段 // 添加 hunyuan 需要的自定义字段
headers.Add(actionKey, hunyuanChatCompletionTCAction) headers.Set(actionKey, hunyuanChatCompletionTCAction)
headers.Add(versionKey, versionValue) headers.Set(versionKey, versionValue)
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
} }
// hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法 // hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法
@@ -291,11 +288,6 @@ func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, a
return json.Marshal(hunyuanRequest) return json.Marshal(hunyuanRequest)
} }
func (m *hunyuanProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
}
func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
if m.config.protocol == protocolOriginal { if m.config.protocol == protocolOriginal {
return chunk, nil return chunk, nil
@@ -412,21 +404,14 @@ func (m *hunyuanProvider) convertChunkFromHunyuanToOpenAI(ctx wrapper.HttpContex
return []byte(openAIChunk.String()), nil return []byte(openAIChunk.String()), nil
} }
func (m *hunyuanProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *hunyuanProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
log.Debugf("#debug nash5# onRespBody's resp is: %s", string(body)) log.Debugf("#debug nash5# onRespBody's resp is: %s", string(body))
hunyuanResponse := &hunyuanTextGenResponseNonStreaming{} hunyuanResponse := &hunyuanTextGenResponseNonStreaming{}
if err := json.Unmarshal(body, hunyuanResponse); err != nil { if err := json.Unmarshal(body, hunyuanResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal hunyuan response: %v", err) return nil, fmt.Errorf("unable to unmarshal hunyuan response: %v", err)
} }
if m.config.protocol == protocolOriginal {
return types.ActionContinue, replaceJsonResponseBody(hunyuanResponse, log)
}
response := m.buildChatCompletionResponse(ctx, hunyuanResponse) response := m.buildChatCompletionResponse(ctx, hunyuanResponse)
return json.Marshal(response)
return types.ActionContinue, replaceJsonResponseBody(response, log)
} }
func (m *hunyuanProvider) insertContextMessageIntoHunyuanRequest(request *hunyuanTextGenRequest, content string) { func (m *hunyuanProvider) insertContextMessageIntoHunyuanRequest(request *hunyuanTextGenRequest, content string) {

View File

@@ -11,47 +11,37 @@ import (
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
) )
// minimaxProvider is the provider for minimax service. // minimaxProvider is the provider for minimax service.
const ( const (
minimaxDomain = "api.minimax.chat" minimaxApiTypeV2 = "v2" // minimaxApiTypeV2 represents chat completion V2 API.
// minimaxChatCompletionV2Path 接口请求响应格式与OpenAI相同 minimaxApiTypePro = "pro" // minimaxApiTypePro represents chat completion Pro API.
// 接口文档: https://platform.minimaxi.com/document/guides/chat-model/V2?id=65e0736ab2845de20908e2dd minimaxDomain = "api.minimax.chat"
// minimaxChatCompletionV2Path represents the API path for chat completion V2 API which has a response format similar to OpenAI's.
minimaxChatCompletionV2Path = "/v1/text/chatcompletion_v2" minimaxChatCompletionV2Path = "/v1/text/chatcompletion_v2"
// minimaxChatCompletionProPath 接口请求响应格式与OpenAI不同 // minimaxChatCompletionProPath represents the API path for chat completion Pro API which has a different response format from OpenAI's.
// 接口文档: https://platform.minimaxi.com/document/guides/chat-model/pro/api?id=6569c85948bc7b684b30377e
minimaxChatCompletionProPath = "/v1/text/chatcompletion_pro" minimaxChatCompletionProPath = "/v1/text/chatcompletion_pro"
senderTypeUser string = "USER" // 用户发送的内容 senderTypeUser string = "USER" // Content sent by the user.
senderTypeBot string = "BOT" // 模型生成的内容 senderTypeBot string = "BOT" // Content generated by the model.
// 默认机器人设置 // Default bot settings.
defaultBotName string = "MM智能助理" defaultBotName string = "MM智能助理"
defaultBotSettingContent string = "MM智能助理是一款由MiniMax自研的没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司一直致力于进行大模型相关的研究。" defaultBotSettingContent string = "MM智能助理是一款由MiniMax自研的没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司一直致力于进行大模型相关的研究。"
defaultSenderName string = "小明" defaultSenderName string = "小明"
) )
// chatCompletionProModels 这些模型对应接口为ChatCompletion Pro
var chatCompletionProModels = map[string]struct{}{
"abab6.5-chat": {},
"abab6.5s-chat": {},
"abab5.5s-chat": {},
"abab5.5-chat": {},
}
type minimaxProviderInitializer struct { type minimaxProviderInitializer struct {
} }
func (m *minimaxProviderInitializer) ValidateConfig(config ProviderConfig) error { func (m *minimaxProviderInitializer) ValidateConfig(config *ProviderConfig) error {
// 如果存在模型对应接口为ChatCompletion Pro必须配置minimaxGroupId // If using the chat completion Pro API, a group ID must be set.
if len(config.modelMapping) > 0 && config.minimaxGroupId == "" { if minimaxApiTypePro == config.minimaxApiType && config.minimaxGroupId == "" {
for _, minimaxModel := range config.modelMapping { return errors.New(fmt.Sprintf("missing minimaxGroupId in provider config when minimaxApiType is %s", minimaxApiTypePro))
if _, exists := chatCompletionProModels[minimaxModel]; exists {
return errors.New(fmt.Sprintf("missing minimaxGroupId in provider config when %s model is provided", minimaxModel))
}
}
} }
if config.apiTokens == nil || len(config.apiTokens) == 0 { if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config") return errors.New("no apiToken found in provider config")
@@ -75,13 +65,13 @@ func (m *minimaxProvider) GetProviderType() string {
return providerTypeMinimax return providerTypeMinimax
} }
func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
// Delay the header processing to allow changing streaming mode in OnRequestBody // Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil return nil
} }
func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
@@ -94,44 +84,28 @@ func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
if apiName != ApiNameChatCompletion { if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName return types.ActionContinue, errUnsupportedApiName
} }
// 解析并映射模型,设置上下文 if minimaxApiTypePro == m.config.minimaxApiType {
model, err := m.parseModel(body) // Use chat completion Pro API.
if err != nil {
return types.ActionContinue, err
}
ctx.SetContext(ctxKeyOriginalRequestModel, model)
mappedModel := getMappedModel(model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
_, ok := chatCompletionProModels[mappedModel]
if ok {
// 使用ChatCompletion Pro接口
return m.handleRequestBodyByChatCompletionPro(body, log) return m.handleRequestBodyByChatCompletionPro(body, log)
} else { } else {
// 使用ChatCompletion v2接口 // Use chat completion V2 API.
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log) return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
} }
} }
func (m *minimaxProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { // handleRequestBodyByChatCompletionPro processes the request body using the chat completion Pro API.
return m.handleRequestBodyByChatCompletionV2(body, headers, log)
}
// handleRequestBodyByChatCompletionPro 使用ChatCompletion Pro接口处理请求体
func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log wrapper.Log) (types.Action, error) { func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log wrapper.Log) (types.Action, error) {
request := &chatCompletionRequest{} request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil { if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err return types.ActionContinue, err
} }
// 映射模型重写requestPath // Map the model and rewrite the request path.
request.Model = getMappedModel(request.Model, m.config.modelMapping, log) request.Model = getMappedModel(request.Model, m.config.modelMapping, log)
_ = util.OverwriteRequestPath(fmt.Sprintf("%s?GroupId=%s", minimaxChatCompletionProPath, m.config.minimaxGroupId)) _ = util.OverwriteRequestPath(fmt.Sprintf("%s?GroupId=%s", minimaxChatCompletionProPath, m.config.minimaxGroupId))
if m.config.context == nil { if m.config.context == nil {
minimaxRequest := m.buildMinimaxChatCompletionV2Request(request, "") minimaxRequest := m.buildMinimaxChatCompletionProRequest(request, "")
return types.ActionContinue, replaceJsonRequestBody(minimaxRequest, log) return types.ActionContinue, replaceJsonRequestBody(minimaxRequest, log)
} }
@@ -143,10 +117,10 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log
log.Errorf("failed to load context file: %v", err) log.Errorf("failed to load context file: %v", err)
util.ErrorHandler("ai-proxy.minimax.load_ctx_failed", fmt.Errorf("failed to load context file: %v", err)) util.ErrorHandler("ai-proxy.minimax.load_ctx_failed", fmt.Errorf("failed to load context file: %v", err))
} }
// 由于 minimaxChatCompletionV2(格式和 OpenAI 一致)和 minimaxChatCompletionPro(格式和 OpenAI 不一致)中 insertHttpContextMessage 的逻辑不同,无法做到同一个 provider 统一 // Since minimaxChatCompletionV2 (format consistent with OpenAI) and minimaxChatCompletionPro (different format from OpenAI) have different logic for insertHttpContextMessage, we cannot unify them within one provider.
// 因此对于 minimaxChatCompletionPro 需要手动处理 context 消息 // For minimaxChatCompletionPro, we need to manually handle context messages.
// minimaxChatCompletionV2 交给默认的 defaultInsertHttpContextMessage 方法插入 context 消息 // minimaxChatCompletionV2 uses the default defaultInsertHttpContextMessage method to insert context messages.
minimaxRequest := m.buildMinimaxChatCompletionV2Request(request, content) minimaxRequest := m.buildMinimaxChatCompletionProRequest(request, content)
if err := replaceJsonRequestBody(minimaxRequest, log); err != nil { if err := replaceJsonRequestBody(minimaxRequest, log); err != nil {
util.ErrorHandler("ai-proxy.minimax.insert_ctx_failed", fmt.Errorf("failed to replace Request body: %v", err)) util.ErrorHandler("ai-proxy.minimax.insert_ctx_failed", fmt.Errorf("failed to replace Request body: %v", err))
} }
@@ -157,63 +131,53 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log
return types.ActionContinue, err return types.ActionContinue, err
} }
// handleRequestBodyByChatCompletionV2 使用ChatCompletion v2接口处理请求体 func (m *minimaxProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { return m.handleRequestBodyByChatCompletionV2(body, headers, log)
request := &chatCompletionRequest{} }
if err := decodeChatCompletionRequest(body, request); err != nil {
return nil, err
}
// 映射模型重写requestPath // handleRequestBodyByChatCompletionV2 processes the request body using the chat completion V2 API.
request.Model = getMappedModel(request.Model, m.config.modelMapping, log) func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
util.OverwriteRequestPathHeader(headers, minimaxChatCompletionV2Path) util.OverwriteRequestPathHeader(headers, minimaxChatCompletionV2Path)
return body, nil rawModel := gjson.GetBytes(body, "model").String()
mappedModel := getMappedModel(rawModel, m.config.modelMapping, log)
return sjson.SetBytes(body, "model", mappedModel)
} }
func (m *minimaxProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (m *minimaxProvider) TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
// 使用minimax接口协议,跳过OnStreamingResponseBody()OnResponseBody() // Skip OnStreamingResponseBody() and OnResponseBody() when using the original protocol
if m.config.protocol == protocolOriginal { // or when the model corresponds to the chat completion V2 interface.
if m.config.protocol == protocolOriginal || minimaxApiTypePro != m.config.minimaxApiType {
ctx.DontReadResponseBody() ctx.DontReadResponseBody()
return types.ActionContinue, nil } else {
headers.Del("Content-Length")
} }
// 模型对应接口为ChatCompletion v2,跳过OnStreamingResponseBody()和OnResponseBody()
model := ctx.GetStringContext(ctxKeyFinalRequestModel, "")
if model != "" {
_, ok := chatCompletionProModels[model]
if !ok {
ctx.DontReadResponseBody()
return types.ActionContinue, nil
}
}
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
} }
// OnStreamingResponseBody 只处理使用OpenAI协议 且 模型对应接口为ChatCompletion Pro的流式响应 // OnStreamingResponseBody handles streaming response chunks from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API.
func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
if isLastChunk || len(chunk) == 0 { if isLastChunk || len(chunk) == 0 {
return nil, nil return nil, nil
} }
// sample event response: // Sample event response:
// data: {"created":1689747645,"model":"abab6.5s-chat","reply":"","choices":[{"messages":[{"sender_type":"BOT","sender_name":"MM智能助理","text":"am from China."}]}],"output_sensitive":false} // data: {"created":1689747645,"model":"abab6.5s-chat","reply":"","choices":[{"messages":[{"sender_type":"BOT","sender_name":"MM智能助理","text":"am from China."}]}],"output_sensitive":false}
// sample end event response: // Sample end event response:
// data: {"created":1689747645,"model":"abab6.5s-chat","reply":"I am from China.","choices":[{"finish_reason":"stop","messages":[{"sender_type":"BOT","sender_name":"MM智能助理","text":"I am from China."}]}],"usage":{"total_tokens":187},"input_sensitive":false,"output_sensitive":false,"id":"0106b3bc9fd844a9f3de1aa06004e2ab","base_resp":{"status_code":0,"status_msg":""}} // data: {"created":1689747645,"model":"abab6.5s-chat","reply":"I am from China.","choices":[{"finish_reason":"stop","messages":[{"sender_type":"BOT","sender_name":"MM智能助理","text":"I am from China."}]}],"usage":{"total_tokens":187},"input_sensitive":false,"output_sensitive":false,"id":"0106b3bc9fd844a9f3de1aa06004e2ab","base_resp":{"status_code":0,"status_msg":""}}
responseBuilder := &strings.Builder{} responseBuilder := &strings.Builder{}
lines := strings.Split(string(chunk), "\n") lines := strings.Split(string(chunk), "\n")
for _, data := range lines { for _, data := range lines {
if len(data) < 6 { if len(data) < 6 {
// ignore blank line or wrong format // Ignore blank line or improperly formatted lines.
continue continue
} }
data = data[6:] data = data[6:]
var minimaxResp minimaxChatCompletionV2Resp var minimaxResp minimaxChatCompletionProResp
if err := json.Unmarshal([]byte(data), &minimaxResp); err != nil { if err := json.Unmarshal([]byte(data), &minimaxResp); err != nil {
log.Errorf("unable to unmarshal minimax response: %v", err) log.Errorf("unable to unmarshal minimax response: %v", err)
continue continue
} }
response := m.responseV2ToOpenAI(&minimaxResp) response := m.responseProToOpenAI(&minimaxResp)
responseBody, err := json.Marshal(response) responseBody, err := json.Marshal(response)
if err != nil { if err != nil {
log.Errorf("unable to marshal response: %v", err) log.Errorf("unable to marshal response: %v", err)
@@ -226,82 +190,82 @@ func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
return []byte(modifiedResponseChunk), nil return []byte(modifiedResponseChunk), nil
} }
// OnResponseBody 只处理使用OpenAI协议 且 模型对应接口为ChatCompletion Pro的流式响应 // TransformResponseBody handles the final response body from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API.
func (m *minimaxProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *minimaxProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
minimaxResp := &minimaxChatCompletionV2Resp{} minimaxResp := &minimaxChatCompletionProResp{}
if err := json.Unmarshal(body, minimaxResp); err != nil { if err := json.Unmarshal(body, minimaxResp); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal minimax response: %v", err) return nil, fmt.Errorf("unable to unmarshal minimax response: %v", err)
} }
if minimaxResp.BaseResp.StatusCode != 0 { if minimaxResp.BaseResp.StatusCode != 0 {
return types.ActionContinue, fmt.Errorf("minimax response error, error_code: %d, error_message: %s", minimaxResp.BaseResp.StatusCode, minimaxResp.BaseResp.StatusMsg) return nil, fmt.Errorf("minimax response error, error_code: %d, error_message: %s", minimaxResp.BaseResp.StatusCode, minimaxResp.BaseResp.StatusMsg)
} }
response := m.responseV2ToOpenAI(minimaxResp) response := m.responseProToOpenAI(minimaxResp)
return types.ActionContinue, replaceJsonResponseBody(response, log) return json.Marshal(response)
} }
// minimaxChatCompletionV2Request 表示ChatCompletion V2请求的结构体 // minimaxChatCompletionProRequest represents the structure of a chat completion Pro request.
type minimaxChatCompletionV2Request struct { type minimaxChatCompletionProRequest struct {
Model string `json:"model"` Model string `json:"model"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
TokensToGenerate int64 `json:"tokens_to_generate,omitempty"` TokensToGenerate int64 `json:"tokens_to_generate,omitempty"`
Temperature float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
MaskSensitiveInfo bool `json:"mask_sensitive_info"` // 是否开启隐私信息打码,默认true MaskSensitiveInfo bool `json:"mask_sensitive_info"` // Whether to mask sensitive information, defaults to true.
Messages []minimaxMessage `json:"messages"` Messages []minimaxMessage `json:"messages"`
BotSettings []minimaxBotSetting `json:"bot_setting"` BotSettings []minimaxBotSetting `json:"bot_setting"`
ReplyConstraints minimaxReplyConstraints `json:"reply_constraints"` ReplyConstraints minimaxReplyConstraints `json:"reply_constraints"`
} }
// minimaxMessage 表示对话中的消息 // minimaxMessage represents a message in the conversation.
type minimaxMessage struct { type minimaxMessage struct {
SenderType string `json:"sender_type"` SenderType string `json:"sender_type"`
SenderName string `json:"sender_name"` SenderName string `json:"sender_name"`
Text string `json:"text"` Text string `json:"text"`
} }
// minimaxBotSetting 表示机器人的设置 // minimaxBotSetting represents the bot's settings.
type minimaxBotSetting struct { type minimaxBotSetting struct {
BotName string `json:"bot_name"` BotName string `json:"bot_name"`
Content string `json:"content"` Content string `json:"content"`
} }
// minimaxReplyConstraints 表示模型回复要求 // minimaxReplyConstraints represents requirements for model replies.
type minimaxReplyConstraints struct { type minimaxReplyConstraints struct {
SenderType string `json:"sender_type"` SenderType string `json:"sender_type"`
SenderName string `json:"sender_name"` SenderName string `json:"sender_name"`
} }
// minimaxChatCompletionV2Resp Minimax Chat Completion V2响应结构体 // minimaxChatCompletionProResp represents the structure of a Minimax Chat Completion Pro response.
type minimaxChatCompletionV2Resp struct { type minimaxChatCompletionProResp struct {
Created int64 `json:"created"` Created int64 `json:"created"`
Model string `json:"model"` Model string `json:"model"`
Reply string `json:"reply"` Reply string `json:"reply"`
InputSensitive bool `json:"input_sensitive,omitempty"` InputSensitive bool `json:"input_sensitive,omitempty"`
InputSensitiveType int64 `json:"input_sensitive_type,omitempty"` OutputSensitive bool `json:"output_sensitive,omitempty"`
OutputSensitive bool `json:"output_sensitive,omitempty"` Choices []minimaxChoice `json:"choices,omitempty"`
OutputSensitiveType int64 `json:"output_sensitive_type,omitempty"` Usage minimaxUsage `json:"usage,omitempty"`
Choices []minimaxChoice `json:"choices,omitempty"` Id string `json:"id"`
Usage minimaxUsage `json:"usage,omitempty"` BaseResp minimaxBaseResp `json:"base_resp"`
Id string `json:"id"`
BaseResp minimaxBaseResp `json:"base_resp"`
} }
// minimaxBaseResp 包含错误状态码和详情 // minimaxBaseResp contains error status code and details.
type minimaxBaseResp struct { type minimaxBaseResp struct {
StatusCode int64 `json:"status_code"` StatusCode int64 `json:"status_code"`
StatusMsg string `json:"status_msg"` StatusMsg string `json:"status_msg"`
} }
// minimaxChoice 结果选项 // minimaxChoice represents a result option.
type minimaxChoice struct { type minimaxChoice struct {
Messages []minimaxMessage `json:"messages"` Messages []minimaxMessage `json:"messages"`
Index int64 `json:"index"` Index int64 `json:"index"`
FinishReason string `json:"finish_reason"` FinishReason string `json:"finish_reason"`
} }
// minimaxUsage 令牌使用情况 // minimaxUsage represents token usage statistics.
type minimaxUsage struct { type minimaxUsage struct {
TotalTokens int64 `json:"total_tokens"` TotalTokens int64 `json:"total_tokens"`
PromptTokens int64 `json:"prompt_tokens"`
CompletionTokens int64 `json:"completion_tokens"`
} }
func (m *minimaxProvider) parseModel(body []byte) (string, error) { func (m *minimaxProvider) parseModel(body []byte) (string, error) {
@@ -316,7 +280,7 @@ func (m *minimaxProvider) parseModel(body []byte) (string, error) {
return model, nil return model, nil
} }
func (m *minimaxProvider) setBotSettings(request *minimaxChatCompletionV2Request, botSettingContent string) { func (m *minimaxProvider) setBotSettings(request *minimaxChatCompletionProRequest, botSettingContent string) {
if len(request.BotSettings) == 0 { if len(request.BotSettings) == 0 {
request.BotSettings = []minimaxBotSetting{ request.BotSettings = []minimaxBotSetting{
{ {
@@ -338,7 +302,7 @@ func (m *minimaxProvider) setBotSettings(request *minimaxChatCompletionV2Request
} }
} }
func (m *minimaxProvider) buildMinimaxChatCompletionV2Request(request *chatCompletionRequest, botSettingContent string) *minimaxChatCompletionV2Request { func (m *minimaxProvider) buildMinimaxChatCompletionProRequest(request *chatCompletionRequest, botSettingContent string) *minimaxChatCompletionProRequest {
var messages []minimaxMessage var messages []minimaxMessage
var botSetting []minimaxBotSetting var botSetting []minimaxBotSetting
var botName string var botName string
@@ -377,7 +341,7 @@ func (m *minimaxProvider) buildMinimaxChatCompletionV2Request(request *chatCompl
SenderType: senderTypeBot, SenderType: senderTypeBot,
SenderName: determineName(botName, defaultBotName), SenderName: determineName(botName, defaultBotName),
} }
result := &minimaxChatCompletionV2Request{ result := &minimaxChatCompletionProRequest{
Model: request.Model, Model: request.Model,
Stream: request.Stream, Stream: request.Stream,
TokensToGenerate: int64(request.MaxTokens), TokensToGenerate: int64(request.MaxTokens),
@@ -393,7 +357,7 @@ func (m *minimaxProvider) buildMinimaxChatCompletionV2Request(request *chatCompl
return result return result
} }
func (m *minimaxProvider) responseV2ToOpenAI(response *minimaxChatCompletionV2Resp) *chatCompletionResponse { func (m *minimaxProvider) responseProToOpenAI(response *minimaxChatCompletionProResp) *chatCompletionResponse {
var choices []chatCompletionChoice var choices []chatCompletionChoice
messageIndex := 0 messageIndex := 0
for _, choice := range response.Choices { for _, choice := range response.Choices {
@@ -418,7 +382,9 @@ func (m *minimaxProvider) responseV2ToOpenAI(response *minimaxChatCompletionV2Re
Model: response.Model, Model: response.Model,
Choices: choices, Choices: choices,
Usage: usage{ Usage: usage{
TotalTokens: int(response.Usage.TotalTokens), TotalTokens: int(response.Usage.TotalTokens),
PromptTokens: int(response.Usage.PromptTokens),
CompletionTokens: int(response.Usage.CompletionTokens),
}, },
} }
} }

View File

@@ -2,10 +2,11 @@ package provider
import ( import (
"errors" "errors"
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
) )
const ( const (
@@ -14,7 +15,7 @@ const (
type mistralProviderInitializer struct{} type mistralProviderInitializer struct{}
func (m *mistralProviderInitializer) ValidateConfig(config ProviderConfig) error { func (m *mistralProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 { if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config") return errors.New("no apiToken found in provider config")
} }
@@ -37,12 +38,12 @@ func (m *mistralProvider) GetProviderType() string {
return providerTypeMistral return providerTypeMistral
} }
func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil return nil
} }
func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {

View File

@@ -24,7 +24,7 @@ const (
type moonshotProviderInitializer struct { type moonshotProviderInitializer struct {
} }
func (m *moonshotProviderInitializer) ValidateConfig(config ProviderConfig) error { func (m *moonshotProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.moonshotFileId != "" && config.context != nil { if config.moonshotFileId != "" && config.context != nil {
return errors.New("moonshotFileId and context cannot be configured at the same time") return errors.New("moonshotFileId and context cannot be configured at the same time")
} }
@@ -56,12 +56,12 @@ func (m *moonshotProvider) GetProviderType() string {
return providerTypeMoonshot return providerTypeMoonshot
} }
func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil return nil
} }
func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {

View File

@@ -3,10 +3,11 @@ package provider
import ( import (
"errors" "errors"
"fmt" "fmt"
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
) )
// ollamaProvider is the provider for Ollama service. // ollamaProvider is the provider for Ollama service.
@@ -18,7 +19,7 @@ const (
type ollamaProviderInitializer struct { type ollamaProviderInitializer struct {
} }
func (m *ollamaProviderInitializer) ValidateConfig(config ProviderConfig) error { func (m *ollamaProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.ollamaServerHost == "" { if config.ollamaServerHost == "" {
return errors.New("missing ollamaServerHost in provider config") return errors.New("missing ollamaServerHost in provider config")
} }
@@ -48,12 +49,12 @@ func (m *ollamaProvider) GetProviderType() string {
return providerTypeOllama return providerTypeOllama
} }
func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil return nil
} }
func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {

View File

@@ -22,7 +22,7 @@ const (
type openaiProviderInitializer struct { type openaiProviderInitializer struct {
} }
func (m *openaiProviderInitializer) ValidateConfig(config ProviderConfig) error { func (m *openaiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
return nil return nil
} }
@@ -57,9 +57,9 @@ func (m *openaiProvider) GetProviderType() string {
return providerTypeOpenAI return providerTypeOpenAI
} }
func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil return nil
} }
func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) { func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {

View File

@@ -46,6 +46,7 @@ const (
providerTypeCohere = "cohere" providerTypeCohere = "cohere"
providerTypeDoubao = "doubao" providerTypeDoubao = "doubao"
providerTypeCoze = "coze" providerTypeCoze = "coze"
providerTypeTogetherAI = "together-ai"
protocolOpenAI = "openai" protocolOpenAI = "openai"
protocolOriginal = "original" protocolOriginal = "original"
@@ -58,7 +59,9 @@ const (
finishReasonLength = "length" finishReasonLength = "length"
ctxKeyIncrementalStreaming = "incrementalStreaming" ctxKeyIncrementalStreaming = "incrementalStreaming"
ctxKeyApiName = "apiKey" ctxKeyApiKey = "apiKey"
CtxKeyApiName = "apiName"
ctxKeyIsStreaming = "isStreaming"
ctxKeyStreamingBody = "streamingBody" ctxKeyStreamingBody = "streamingBody"
ctxKeyOriginalRequestModel = "originalRequestModel" ctxKeyOriginalRequestModel = "originalRequestModel"
ctxKeyFinalRequestModel = "finalRequestModel" ctxKeyFinalRequestModel = "finalRequestModel"
@@ -73,7 +76,7 @@ const (
) )
type providerInitializer interface { type providerInitializer interface {
ValidateConfig(ProviderConfig) error ValidateConfig(*ProviderConfig) error
CreateProvider(ProviderConfig) (Provider, error) CreateProvider(ProviderConfig) (Provider, error)
} }
@@ -106,6 +109,7 @@ var (
providerTypeCohere: &cohereProviderInitializer{}, providerTypeCohere: &cohereProviderInitializer{},
providerTypeDoubao: &doubaoProviderInitializer{}, providerTypeDoubao: &doubaoProviderInitializer{},
providerTypeCoze: &cozeProviderInitializer{}, providerTypeCoze: &cozeProviderInitializer{},
providerTypeTogetherAI: &togetherAIProviderInitializer{},
} }
) )
@@ -113,22 +117,26 @@ type Provider interface {
GetProviderType() string GetProviderType() string
} }
type ApiNameHandler interface {
GetApiName(path string) ApiName
}
type RequestHeadersHandler interface { type RequestHeadersHandler interface {
OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error
}
type TransformRequestHeadersHandler interface {
TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
} }
type RequestBodyHandler interface { type RequestBodyHandler interface {
OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
} }
type StreamingResponseBodyHandler interface {
OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error)
}
type ApiNameHandler interface {
GetApiName(path string) ApiName
}
type TransformRequestHeadersHandler interface {
TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
}
type TransformRequestBodyHandler interface { type TransformRequestBodyHandler interface {
TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
} }
@@ -139,16 +147,12 @@ type TransformRequestBodyHeadersHandler interface {
TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error)
} }
type ResponseHeadersHandler interface { type TransformResponseHeadersHandler interface {
OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
} }
type StreamingResponseBodyHandler interface { type TransformResponseBodyHandler interface {
OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
}
type ResponseBodyHandler interface {
OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
} }
// TickFuncHandler allows the provider to execute a function periodically // TickFuncHandler allows the provider to execute a function periodically
@@ -173,6 +177,9 @@ type ProviderConfig struct {
// @Title zh-CN apiToken 故障切换 // @Title zh-CN apiToken 故障切换
// @Description zh-CN 当 apiToken 不可用时移出 apiTokens 列表,对移除的 apiToken 进行健康检查,当重新可用后加回 apiTokens 列表 // @Description zh-CN 当 apiToken 不可用时移出 apiTokens 列表,对移除的 apiToken 进行健康检查,当重新可用后加回 apiTokens 列表
failover *failover `required:"false" yaml:"failover" json:"failover"` failover *failover `required:"false" yaml:"failover" json:"failover"`
// @Title zh-CN 失败请求重试
// @Description zh-CN 对失败的请求立即进行重试
retryOnFailure *retryOnFailure `required:"false" yaml:"retryOnFailure" json:"retryOnFailure"`
// @Title zh-CN 基于OpenAI协议的自定义后端URL // @Title zh-CN 基于OpenAI协议的自定义后端URL
// @Description zh-CN 仅适用于支持 openai 协议的服务。 // @Description zh-CN 仅适用于支持 openai 协议的服务。
openaiCustomUrl string `required:"false" yaml:"openaiCustomUrl" json:"openaiCustomUrl"` openaiCustomUrl string `required:"false" yaml:"openaiCustomUrl" json:"openaiCustomUrl"`
@@ -206,8 +213,11 @@ type ProviderConfig struct {
// @Title zh-CN hunyuan api id for authorization // @Title zh-CN hunyuan api id for authorization
// @Description zh-CN 仅适用于Hun Yuan AI服务鉴权 // @Description zh-CN 仅适用于Hun Yuan AI服务鉴权
hunyuanAuthId string `required:"false" yaml:"hunyuanAuthId" json:"hunyuanAuthId"` hunyuanAuthId string `required:"false" yaml:"hunyuanAuthId" json:"hunyuanAuthId"`
// @Title zh-CN minimax API type
// @Description zh-CN 仅适用于 minimax 服务。minimax API 类型v2 和 pro 中选填一项,默认值为 v2
minimaxApiType string `required:"false" yaml:"minimaxApiType" json:"minimaxApiType"`
// @Title zh-CN minimax group id // @Title zh-CN minimax group id
// @Description zh-CN 仅适用于minimax使用ChatCompletion Pro接口的模型 // @Description zh-CN 仅适用于 minimax 服务。minimax API 类型为 pro 时必填
minimaxGroupId string `required:"false" yaml:"minimaxGroupId" json:"minimaxGroupId"` minimaxGroupId string `required:"false" yaml:"minimaxGroupId" json:"minimaxGroupId"`
// @Title zh-CN 模型名称映射表 // @Title zh-CN 模型名称映射表
// @Description zh-CN 用于将请求中的模型名称映射为目标AI服务商支持的模型名称。支持通过“*”来配置全局映射 // @Description zh-CN 用于将请求中的模型名称映射为目标AI服务商支持的模型名称。支持通过“*”来配置全局映射
@@ -303,6 +313,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
c.claudeVersion = json.Get("claudeVersion").String() c.claudeVersion = json.Get("claudeVersion").String()
c.hunyuanAuthId = json.Get("hunyuanAuthId").String() c.hunyuanAuthId = json.Get("hunyuanAuthId").String()
c.hunyuanAuthKey = json.Get("hunyuanAuthKey").String() c.hunyuanAuthKey = json.Get("hunyuanAuthKey").String()
c.minimaxApiType = json.Get("minimaxApiType").String()
c.minimaxGroupId = json.Get("minimaxGroupId").String() c.minimaxGroupId = json.Get("minimaxGroupId").String()
c.cloudflareAccountId = json.Get("cloudflareAccountId").String() c.cloudflareAccountId = json.Get("cloudflareAccountId").String()
if c.typ == providerTypeGemini { if c.typ == providerTypeGemini {
@@ -346,6 +357,14 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
c.failover.FromJson(failoverJson) c.failover.FromJson(failoverJson)
} }
retryOnFailureJson := json.Get("retryOnFailure")
c.retryOnFailure = &retryOnFailure{
enabled: false,
}
if retryOnFailureJson.Exists() {
c.retryOnFailure.FromJson(retryOnFailureJson)
}
for _, accessKeyAndSecret := range json.Get("baiduAccessKeyAndSecret").Array() { for _, accessKeyAndSecret := range json.Get("baiduAccessKeyAndSecret").Array() {
c.baiduAccessKeyAndSecret = append(c.baiduAccessKeyAndSecret, accessKeyAndSecret.String()) c.baiduAccessKeyAndSecret = append(c.baiduAccessKeyAndSecret, accessKeyAndSecret.String())
} }
@@ -386,17 +405,17 @@ func (c *ProviderConfig) Validate() error {
if !has { if !has {
return errors.New("unknown provider type: " + c.typ) return errors.New("unknown provider type: " + c.typ)
} }
if err := initializer.ValidateConfig(*c); err != nil { if err := initializer.ValidateConfig(c); err != nil {
return err return err
} }
return nil return nil
} }
func (c *ProviderConfig) GetOrSetTokenWithContext(ctx wrapper.HttpContext) string { func (c *ProviderConfig) GetOrSetTokenWithContext(ctx wrapper.HttpContext) string {
ctxApiKey := ctx.GetContext(ctxKeyApiName) ctxApiKey := ctx.GetContext(ctxKeyApiKey)
if ctxApiKey == nil { if ctxApiKey == nil {
ctxApiKey = c.GetRandomToken() ctxApiKey = c.GetRandomToken()
ctx.SetContext(ctxKeyApiName, ctxApiKey) ctx.SetContext(ctxKeyApiKey, ctxApiKey)
} }
return ctxApiKey.(string) return ctxApiKey.(string)
} }
@@ -440,6 +459,9 @@ func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, reques
streaming := req.Stream streaming := req.Stream
if streaming { if streaming {
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream") _ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
ctx.SetContext(ctxKeyIsStreaming, true)
} else {
ctx.SetContext(ctxKeyIsStreaming, false)
} }
return c.setRequestModel(ctx, req, log) return c.setRequestModel(ctx, req, log)
@@ -534,9 +556,9 @@ func (c *ProviderConfig) handleRequestBody(
if handler, ok := provider.(TransformRequestBodyHandler); ok { if handler, ok := provider.(TransformRequestBodyHandler); ok {
body, err = handler.TransformRequestBody(ctx, apiName, body, log) body, err = handler.TransformRequestBody(ctx, apiName, body, log)
} else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok { } else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {
headers := util.GetOriginalHttpHeaders() headers := util.GetOriginalRequestHeaders()
body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers, log) body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers, log)
util.ReplaceOriginalHttpHeaders(headers) util.ReplaceRequestHeaders(headers)
} else { } else {
body, err = c.defaultTransformRequestBody(ctx, apiName, body, log) body, err = c.defaultTransformRequestBody(ctx, apiName, body, log)
} }
@@ -545,9 +567,14 @@ func (c *ProviderConfig) handleRequestBody(
return types.ActionContinue, err return types.ActionContinue, err
} }
// If retryOnFailure is enabled, save the transformed body to the context in case of retry
if c.isRetryOnFailureEnabled() {
ctx.SetContext(ctxRequestBody, body)
}
if apiName == ApiNameChatCompletion { if apiName == ApiNameChatCompletion {
if c.context == nil { if c.context == nil {
return types.ActionContinue, replaceHttpJsonRequestBody(body, log) return types.ActionContinue, replaceRequestBody(body, log)
} }
err = contextCache.GetContextFromFile(ctx, provider, body, log) err = contextCache.GetContextFromFile(ctx, provider, body, log)
@@ -556,14 +583,14 @@ func (c *ProviderConfig) handleRequestBody(
} }
return types.ActionContinue, err return types.ActionContinue, err
} }
return types.ActionContinue, replaceHttpJsonRequestBody(body, log) return types.ActionContinue, replaceRequestBody(body, log)
} }
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) { func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) {
headers := util.GetOriginalRequestHeaders()
if handler, ok := provider.(TransformRequestHeadersHandler); ok { if handler, ok := provider.(TransformRequestHeadersHandler); ok {
originalHeaders := util.GetOriginalHttpHeaders() handler.TransformRequestHeaders(ctx, apiName, headers, log)
handler.TransformRequestHeaders(ctx, apiName, originalHeaders, log) util.ReplaceRequestHeaders(headers)
util.ReplaceOriginalHttpHeaders(originalHeaders)
} }
} }
@@ -579,3 +606,11 @@ func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, ap
} }
return json.Marshal(request) return json.Marshal(request)
} }
func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) {
if c.protocol == protocolOriginal {
ctx.DontReadResponseBody()
} else {
headers.Del("Content-Length")
}
}

View File

@@ -27,6 +27,7 @@ const (
qwenChatCompletionPath = "/api/v1/services/aigc/text-generation/generation" qwenChatCompletionPath = "/api/v1/services/aigc/text-generation/generation"
qwenTextEmbeddingPath = "/api/v1/services/embeddings/text-embedding/text-embedding" qwenTextEmbeddingPath = "/api/v1/services/embeddings/text-embedding/text-embedding"
qwenCompatiblePath = "/compatible-mode/v1/chat/completions" qwenCompatiblePath = "/compatible-mode/v1/chat/completions"
qwenBailianPath = "/api/v1/apps"
qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation" qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation"
qwenTopPMin = 0.000001 qwenTopPMin = 0.000001
@@ -41,7 +42,7 @@ const (
type qwenProviderInitializer struct { type qwenProviderInitializer struct {
} }
func (m *qwenProviderInitializer) ValidateConfig(config ProviderConfig) error { func (m *qwenProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if len(config.qwenFileIds) != 0 && config.context != nil { if len(config.qwenFileIds) != 0 && config.context != nil {
return errors.New("qwenFileIds and context cannot be configured at the same time") return errors.New("qwenFileIds and context cannot be configured at the same time")
} }
@@ -71,16 +72,14 @@ func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
} }
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
if m.config.qwenEnableCompatible { if m.config.IsOriginal() {
} else if m.config.qwenEnableCompatible {
util.OverwriteRequestPathHeader(headers, qwenCompatiblePath) util.OverwriteRequestPathHeader(headers, qwenCompatiblePath)
} else if apiName == ApiNameChatCompletion { } else if apiName == ApiNameChatCompletion {
util.OverwriteRequestPathHeader(headers, qwenChatCompletionPath) util.OverwriteRequestPathHeader(headers, qwenChatCompletionPath)
} else if apiName == ApiNameEmbeddings { } else if apiName == ApiNameEmbeddings {
util.OverwriteRequestPathHeader(headers, qwenTextEmbeddingPath) util.OverwriteRequestPathHeader(headers, qwenTextEmbeddingPath)
} }
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
} }
func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) { func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
@@ -95,20 +94,19 @@ func (m *qwenProvider) GetProviderType() string {
return providerTypeQwen return providerTypeQwen
} }
func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings { if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
if m.config.protocol == protocolOriginal { if m.config.protocol == protocolOriginal {
ctx.DontReadRequestBody() ctx.DontReadRequestBody()
return types.ActionContinue, nil return nil
} }
// Delay the header processing to allow changing streaming mode in OnRequestBody return nil
return types.HeaderStopIteration, nil
} }
func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -185,16 +183,6 @@ func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []b
return json.Marshal(qwenRequest) return json.Marshal(qwenRequest)
} }
func (m *qwenProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
if m.config.protocol == protocolOriginal {
ctx.DontReadResponseBody()
return types.ActionContinue, nil
}
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
}
func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
if m.config.qwenEnableCompatible || name != ApiNameChatCompletion { if m.config.qwenEnableCompatible || name != ApiNameChatCompletion {
return chunk, nil return chunk, nil
@@ -280,9 +268,9 @@ func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api
return []byte(modifiedResponseChunk), nil return []byte(modifiedResponseChunk), nil
} }
func (m *qwenProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
if m.config.qwenEnableCompatible { if m.config.qwenEnableCompatible {
return types.ActionContinue, nil return body, nil
} }
if apiName == ApiNameChatCompletion { if apiName == ApiNameChatCompletion {
return m.onChatCompletionResponseBody(ctx, body, log) return m.onChatCompletionResponseBody(ctx, body, log)
@@ -290,25 +278,25 @@ func (m *qwenProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName,
if apiName == ApiNameEmbeddings { if apiName == ApiNameEmbeddings {
return m.onEmbeddingsResponseBody(ctx, body, log) return m.onEmbeddingsResponseBody(ctx, body, log)
} }
return types.ActionContinue, errUnsupportedApiName return nil, errUnsupportedApiName
} }
func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
qwenResponse := &qwenTextGenResponse{} qwenResponse := &qwenTextGenResponse{}
if err := json.Unmarshal(body, qwenResponse); err != nil { if err := json.Unmarshal(body, qwenResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal Qwen response: %v", err) return nil, fmt.Errorf("unable to unmarshal Qwen response: %v", err)
} }
response := m.buildChatCompletionResponse(ctx, qwenResponse) response := m.buildChatCompletionResponse(ctx, qwenResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log) return json.Marshal(response)
} }
func (m *qwenProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) { func (m *qwenProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
qwenResponse := &qwenTextEmbeddingResponse{} qwenResponse := &qwenTextEmbeddingResponse{}
if err := json.Unmarshal(body, qwenResponse); err != nil { if err := json.Unmarshal(body, qwenResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal Qwen response: %v", err) return nil, fmt.Errorf("unable to unmarshal Qwen response: %v", err)
} }
response := m.buildEmbeddingsResponse(ctx, qwenResponse) response := m.buildEmbeddingsResponse(ctx, qwenResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log) return json.Marshal(response)
} }
func (m *qwenProvider) buildQwenTextGenerationRequest(ctx wrapper.HttpContext, origRequest *chatCompletionRequest, streaming bool) ([]byte, error) { func (m *qwenProvider) buildQwenTextGenerationRequest(ctx wrapper.HttpContext, origRequest *chatCompletionRequest, streaming bool) ([]byte, error) {
@@ -762,6 +750,7 @@ func (m *qwenProvider) GetApiName(path string) ApiName {
switch { switch {
case strings.Contains(path, qwenChatCompletionPath), case strings.Contains(path, qwenChatCompletionPath),
strings.Contains(path, qwenMultimodalGenerationPath), strings.Contains(path, qwenMultimodalGenerationPath),
strings.Contains(path, qwenBailianPath),
strings.Contains(path, qwenCompatiblePath): strings.Contains(path, qwenCompatiblePath):
return ApiNameChatCompletion return ApiNameChatCompletion
case strings.Contains(path, qwenTextEmbeddingPath): case strings.Contains(path, qwenTextEmbeddingPath):

View File

@@ -37,7 +37,7 @@ func replaceJsonRequestBody(request interface{}, log wrapper.Log) error {
return err return err
} }
func replaceHttpJsonRequestBody(body []byte, log wrapper.Log) error { func replaceRequestBody(body []byte, log wrapper.Log) error {
log.Debugf("request body: %s", string(body)) log.Debugf("request body: %s", string(body))
err := proxywasm.ReplaceHttpRequestBody(body) err := proxywasm.ReplaceHttpRequestBody(body)
if err != nil { if err != nil {
@@ -65,15 +65,11 @@ func insertContextMessage(request *chatCompletionRequest, content string) {
} }
} }
func replaceJsonResponseBody(response interface{}, log wrapper.Log) error { func ReplaceResponseBody(body []byte, log wrapper.Log) error {
body, err := json.Marshal(response)
if err != nil {
return fmt.Errorf("unable to marshal response: %v", err)
}
log.Debugf("response body: %s", string(body)) log.Debugf("response body: %s", string(body))
err = proxywasm.ReplaceHttpResponseBody(body) err := proxywasm.ReplaceHttpResponseBody(body)
if err != nil { if err != nil {
return fmt.Errorf("unable to replace the original response body: %v", err) return fmt.Errorf("unable to replace the original response body: %v", err)
} }
return err return nil
} }

View File

@@ -0,0 +1,141 @@
package provider
import (
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/tidwall/gjson"
"net/http"
)
const (
ctxRequestBody = "requestBody"
ctxRetryCount = "retryCount"
)
type retryOnFailure struct {
// @Title zh-CN 是否启用请求重试
enabled bool `required:"false" yaml:"enabled" json:"enabled"`
// @Title zh-CN 重试次数
maxRetries int64 `required:"false" yaml:"maxRetries" json:"maxRetries"`
// @Title zh-CN 重试超时时间
retryTimeout int64 `required:"false" yaml:"retryTimeout" json:"retryTimeout"`
}
func (r *retryOnFailure) FromJson(json gjson.Result) {
r.enabled = json.Get("enabled").Bool()
r.maxRetries = json.Get("maxRetries").Int()
if r.maxRetries == 0 {
r.maxRetries = 1
}
r.retryTimeout = json.Get("retryTimeout").Int()
if r.retryTimeout == 0 {
r.retryTimeout = 30 * 1000
}
}
func (c *ProviderConfig) isRetryOnFailureEnabled() bool {
return c.retryOnFailure.enabled
}
func (c *ProviderConfig) retryFailedRequest(activeProvider Provider, ctx wrapper.HttpContext, log wrapper.Log) {
log.Debugf("Retry failed request: provider=%s", activeProvider.GetProviderType())
retryClient := createRetryClient(ctx)
apiName, _ := ctx.GetContext(CtxKeyApiName).(ApiName)
ctx.SetContext(ctxRetryCount, 1)
c.sendRetryRequest(ctx, apiName, activeProvider, retryClient, log)
}
func (c *ProviderConfig) transformResponseHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, apiName ApiName, headers http.Header, body []byte, log wrapper.Log) ([][2]string, []byte) {
if handler, ok := activeProvider.(TransformResponseHeadersHandler); ok {
handler.TransformResponseHeaders(ctx, apiName, headers, log)
} else {
c.DefaultTransformResponseHeaders(ctx, headers)
}
if handler, ok := activeProvider.(TransformResponseBodyHandler); ok {
var err error
body, err = handler.TransformResponseBody(ctx, apiName, body, log)
if err != nil {
log.Errorf("Failed to transform response body: %v", err)
}
}
return util.HeaderToSlice(headers), body
}
func (c *ProviderConfig) retryCall(
ctx wrapper.HttpContext, log wrapper.Log, activeProvider Provider,
apiName ApiName, statusCode int, responseHeaders http.Header, responseBody []byte,
retryClient *wrapper.ClusterClient[wrapper.RouteCluster]) {
retryCount := ctx.GetContext(ctxRetryCount).(int)
log.Debugf("Sent retry request: %d/%d", retryCount, c.retryOnFailure.maxRetries)
if statusCode == 200 {
log.Debugf("Retry request succeeded")
headers, body := c.transformResponseHeadersAndBody(ctx, activeProvider, apiName, responseHeaders, responseBody, log)
proxywasm.SendHttpResponse(200, headers, body, -1)
} else {
log.Debugf("The retry request still failed, status: %d, responseHeaders: %v, responseBody: %s", statusCode, responseHeaders, string(responseBody))
}
retryCount++
if retryCount <= int(c.retryOnFailure.maxRetries) {
ctx.SetContext(ctxRetryCount, retryCount)
c.sendRetryRequest(ctx, apiName, activeProvider, retryClient, log)
} else {
log.Debugf("Reached the maximum retry count: %d", c.retryOnFailure.maxRetries)
proxywasm.ResumeHttpResponse()
}
}
func (c *ProviderConfig) sendRetryRequest(
ctx wrapper.HttpContext, apiName ApiName, activeProvider Provider,
retryClient *wrapper.ClusterClient[wrapper.RouteCluster], log wrapper.Log) {
requestHeaders, requestBody := c.getRetryRequestHeadersAndBody(ctx, activeProvider, apiName, log)
path := getRetryPath(ctx)
err := retryClient.Post(path, util.HeaderToSlice(requestHeaders), requestBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
c.retryCall(ctx, log, activeProvider, apiName, statusCode, responseHeaders, responseBody, retryClient)
}, uint32(c.retryOnFailure.retryTimeout))
if err != nil {
log.Errorf("Failed to send retry request: %v", err)
proxywasm.ResumeHttpResponse()
}
}
func createRetryClient(ctx wrapper.HttpContext) *wrapper.ClusterClient[wrapper.RouteCluster] {
host := wrapper.GetRequestHost()
if host == "" {
host = ctx.GetContext(ctxRequestHost).(string)
}
retryClient := wrapper.NewClusterClient(wrapper.RouteCluster{
Host: host,
})
return retryClient
}
func getRetryPath(ctx wrapper.HttpContext) string {
path := wrapper.GetRequestPath()
if path == "" {
path = ctx.GetContext(ctxRequestPath).(string)
}
return path
}
func (c *ProviderConfig) getRetryRequestHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, apiName ApiName, log wrapper.Log) (http.Header, []byte) {
// The retry request may be sent with different apiToken, so the header needs to be regenerated
c.SetApiTokenInUse(ctx, log)
requestHeaders := http.Header{
"Content-Type": []string{"application/json"},
}
if handler, ok := activeProvider.(TransformRequestHeadersHandler); ok {
handler.TransformRequestHeaders(ctx, apiName, requestHeaders, log)
}
requestBody := ctx.GetContext(ctxRequestBody).([]byte)
return requestHeaders, requestBody
}

View File

@@ -9,7 +9,6 @@ import (
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
) )
@@ -52,7 +51,7 @@ type sparkStreamResponse struct {
Created int64 `json:"created"` Created int64 `json:"created"`
} }
func (i *sparkProviderInitializer) ValidateConfig(config ProviderConfig) error { func (i *sparkProviderInitializer) ValidateConfig(config *ProviderConfig) error {
return nil return nil
} }
@@ -67,12 +66,12 @@ func (p *sparkProvider) GetProviderType() string {
return providerTypeSpark return providerTypeSpark
} }
func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName return errUnsupportedApiName
} }
p.config.handleRequestHeaders(p, ctx, apiName, log) p.config.handleRequestHeaders(p, ctx, apiName, log)
return types.ActionContinue, nil return nil
} }
func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
@@ -82,21 +81,16 @@ func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log) return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log)
} }
func (p *sparkProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (p *sparkProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
}
func (p *sparkProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
sparkResponse := &sparkResponse{} sparkResponse := &sparkResponse{}
if err := json.Unmarshal(body, sparkResponse); err != nil { if err := json.Unmarshal(body, sparkResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal spark response: %v", err) return nil, fmt.Errorf("unable to unmarshal spark response: %v", err)
} }
if sparkResponse.Code != 0 { if sparkResponse.Code != 0 {
return types.ActionContinue, fmt.Errorf("spark response error, error_code: %d, error_message: %s", sparkResponse.Code, sparkResponse.Message) return nil, fmt.Errorf("spark response error, error_code: %d, error_message: %s", sparkResponse.Code, sparkResponse.Message)
} }
response := p.responseSpark2OpenAI(ctx, sparkResponse) response := p.responseSpark2OpenAI(ctx, sparkResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log) return json.Marshal(response)
} }
func (p *sparkProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) { func (p *sparkProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
@@ -177,6 +171,4 @@ func (p *sparkProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
util.OverwriteRequestPathHeader(headers, sparkChatCompletionPath) util.OverwriteRequestPathHeader(headers, sparkChatCompletionPath)
util.OverwriteRequestHostHeader(headers, sparkHost) util.OverwriteRequestHostHeader(headers, sparkHost)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+p.config.GetApiTokenInUse(ctx)) util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+p.config.GetApiTokenInUse(ctx))
headers.Del("Accept-Encoding")
headers.Del("Content-Length")
} }

View File

@@ -2,10 +2,11 @@ package provider
import ( import (
"errors" "errors"
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper" "github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types" "github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
) )
const ( const (
@@ -16,7 +17,7 @@ const (
type stepfunProviderInitializer struct { type stepfunProviderInitializer struct {
} }
func (m *stepfunProviderInitializer) ValidateConfig(config ProviderConfig) error { func (m *stepfunProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 { if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config") return errors.New("no apiToken found in provider config")
} }
@@ -39,12 +40,12 @@ func (m *stepfunProvider) GetProviderType() string {
return providerTypeStepfun return providerTypeStepfun
} }
func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil return nil
} }
func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {

View File

@@ -0,0 +1,69 @@
package provider
import (
"errors"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"net/http"
"strings"
)
const (
togetherAIDomain = "api.together.xyz"
togetherAICompletionPath = "/v1/chat/completions"
)
type togetherAIProviderInitializer struct{}
func (m *togetherAIProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
}
func (m *togetherAIProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
return &togetherAIProvider{
config: config,
contextCache: createContextCache(&config),
}, nil
}
type togetherAIProvider struct {
config ProviderConfig
contextCache *contextCache
}
func (m *togetherAIProvider) GetProviderType() string {
return providerTypeTogetherAI
}
func (m *togetherAIProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion {
return errUnsupportedApiName
}
m.config.handleRequestHeaders(m, ctx, apiName, log)
return nil
}
func (m *togetherAIProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
}
func (m *togetherAIProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, togetherAICompletionPath)
util.OverwriteRequestHostHeader(headers, togetherAIDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}
func (m *togetherAIProvider) GetApiName(path string) ApiName {
if strings.Contains(path, togetherAICompletionPath) {
return ApiNameChatCompletion
}
return ""
}

View File

@@ -17,7 +17,7 @@ const (
type yiProviderInitializer struct { type yiProviderInitializer struct {
} }
func (m *yiProviderInitializer) ValidateConfig(config ProviderConfig) error { func (m *yiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 { if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config") return errors.New("no apiToken found in provider config")
} }
@@ -40,12 +40,12 @@ func (m *yiProvider) GetProviderType() string {
return providerTypeYi return providerTypeYi
} }
func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil return nil
} }
func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {

View File

@@ -17,7 +17,7 @@ const (
type zhipuAiProviderInitializer struct{} type zhipuAiProviderInitializer struct{}
func (m *zhipuAiProviderInitializer) ValidateConfig(config ProviderConfig) error { func (m *zhipuAiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 { if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config") return errors.New("no apiToken found in provider config")
} }
@@ -40,12 +40,12 @@ func (m *zhipuAiProvider) GetProviderType() string {
return providerTypeZhipuAi return providerTypeZhipuAi
} }
func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) { func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
if apiName != ApiNameChatCompletion { if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName return errUnsupportedApiName
} }
m.config.handleRequestHeaders(m, ctx, apiName, log) m.config.handleRequestHeaders(m, ctx, apiName, log)
return types.ActionContinue, nil return nil
} }
func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) { func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {

View File

@@ -86,12 +86,22 @@ func SliceToHeader(slice [][2]string) http.Header {
return header return header
} }
func GetOriginalHttpHeaders() http.Header { func GetOriginalRequestHeaders() http.Header {
originalHeaders, _ := proxywasm.GetHttpRequestHeaders() originalHeaders, _ := proxywasm.GetHttpRequestHeaders()
return SliceToHeader(originalHeaders) return SliceToHeader(originalHeaders)
} }
func ReplaceOriginalHttpHeaders(headers http.Header) { func GetOriginalResponseHeaders() http.Header {
originalHeaders, _ := proxywasm.GetHttpResponseHeaders()
return SliceToHeader(originalHeaders)
}
func ReplaceRequestHeaders(headers http.Header) {
modifiedHeaders := HeaderToSlice(headers) modifiedHeaders := HeaderToSlice(headers)
_ = proxywasm.ReplaceHttpRequestHeaders(modifiedHeaders) _ = proxywasm.ReplaceHttpRequestHeaders(modifiedHeaders)
} }
func ReplaceResponseHeaders(headers http.Header) {
modifiedHeaders := HeaderToSlice(headers)
_ = proxywasm.ReplaceHttpResponseHeaders(modifiedHeaders)
}

View File

@@ -6,9 +6,9 @@ description: AI 配额管理插件配置参考
## 功能说明 ## 功能说明
`ai-qutoa` 插件实现给特定 consumer 根据分配固定的 quota 进行 quota 策略限流,同时支持 quota 管理能力,包括查询 quota 、刷新 quota、增减 quota。 `ai-quota` 插件实现给特定 consumer 根据分配固定的 quota 进行 quota 策略限流,同时支持 quota 管理能力,包括查询 quota 、刷新 quota、增减 quota。
`ai-quota` 插件需要配合 认证插件比如 `key-auth``jwt-auth` 等插件获取认证身份的 consumer 名称,同时需要配合 `ai-statatistics` 插件获取 AI Token 统计信息。 `ai-quota` 插件需要配合 认证插件比如 `key-auth``jwt-auth` 等插件获取认证身份的 consumer 名称,同时需要配合 `ai-statistics` 插件获取 AI Token 统计信息。
## 运行属性 ## 运行属性

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"bytes"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -215,35 +216,51 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config QuotaConfig, da
if chatMode == ChatModeNone || chatMode == ChatModeAdmin { if chatMode == ChatModeNone || chatMode == ChatModeAdmin {
return data return data
} }
var inputToken, outputToken int64
var consumer string
if inputToken, outputToken, ok := getUsage(data); ok {
ctx.SetContext("input_token", inputToken)
ctx.SetContext("output_token", outputToken)
}
// chat completion mode // chat completion mode
if !endOfStream { if !endOfStream {
return data return data
} }
inputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.input_token"})
if err != nil { if ctx.GetContext("input_token") == nil || ctx.GetContext("output_token") == nil || ctx.GetContext("consumer") == nil {
return data return data
} }
outputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.output_token"})
if err != nil { inputToken = ctx.GetContext("input_token").(int64)
return data outputToken = ctx.GetContext("output_token").(int64)
} consumer = ctx.GetContext("consumer").(string)
inputToken, err := strconv.Atoi(string(inputTokenStr)) totalToken := int(inputToken + outputToken)
if err != nil { log.Debugf("update consumer:%s, totalToken:%d", consumer, totalToken)
return data config.redisClient.DecrBy(config.RedisKeyPrefix+consumer, totalToken, nil)
}
outputToken, err := strconv.Atoi(string(outputTokenStr))
if err != nil {
return data
}
consumer, ok := ctx.GetContext("consumer").(string)
if ok {
totalToken := int(inputToken + outputToken)
log.Debugf("update consumer:%s, totalToken:%d", consumer, totalToken)
config.redisClient.DecrBy(config.RedisKeyPrefix+consumer, totalToken, nil)
}
return data return data
} }
func getUsage(data []byte) (inputTokenUsage int64, outputTokenUsage int64, ok bool) {
chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
for _, chunk := range chunks {
// the feature strings are used to identify the usage data, like:
// {"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}}
if !bytes.Contains(chunk, []byte("prompt_tokens")) || !bytes.Contains(chunk, []byte("completion_tokens")) {
continue
}
inputTokenObj := gjson.GetBytes(chunk, "usage.prompt_tokens")
outputTokenObj := gjson.GetBytes(chunk, "usage.completion_tokens")
if inputTokenObj.Exists() && outputTokenObj.Exists() {
inputTokenUsage = inputTokenObj.Int()
outputTokenUsage = outputTokenObj.Int()
ok = true
return
}
}
return
}
func deniedNoKeyAuthData() types.Action { func deniedNoKeyAuthData() types.Action {
util.SendResponse(http.StatusUnauthorized, "ai-quota.no_key", "text/plain", "Request denied by ai quota check. No Key Authentication information found.") util.SendResponse(http.StatusUnauthorized, "ai-quota.no_key", "text/plain", "Request denied by ai quota check. No Key Authentication information found.")
return types.ActionContinue return types.ActionContinue

View File

@@ -31,6 +31,7 @@ description: 阿里云内容安全检测
| `denyMessage` | string | optional | openai格式的流式/非流式响应 | 指定内容非法时的响应内容 | | `denyMessage` | string | optional | openai格式的流式/非流式响应 | 指定内容非法时的响应内容 |
| `protocol` | string | optional | openai | 协议格式非openai协议填`original` | | `protocol` | string | optional | openai | 协议格式非openai协议填`original` |
| `riskLevelBar` | string | optional | high | 拦截风险等级,取值为 max, high, medium, low | | `riskLevelBar` | string | optional | high | 拦截风险等级,取值为 max, high, medium, low |
| `timeout` | int | optional | 2000 | 调用内容安全服务时的超时时间 |
补充说明一下 `denyMessage`,对非法请求的处理逻辑为: 补充说明一下 `denyMessage`,对非法请求的处理逻辑为:
- 如果配置了 `denyMessage`,返回内容为 `denyMessage` 配置内容格式为openai格式的流式/非流式响应 - 如果配置了 `denyMessage`,返回内容为 `denyMessage` 配置内容格式为openai格式的流式/非流式响应

View File

@@ -41,9 +41,9 @@ const (
LowRisk = "low" LowRisk = "low"
NoRisk = "none" NoRisk = "none"
OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"%s","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}]}` OpenAIResponseFormat = `{"id": "%s","object":"chat.completion","model":"from-security-guard","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"%s","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}` OpenAIStreamResponseChunk = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"logprobs":null,"finish_reason":null}]}`
OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"%s","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}` OpenAIStreamResponseEnd = `data:{"id":"%s","object":"chat.completion.chunk","model":"from-security-guard","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]` OpenAIStreamResponseFormat = OpenAIStreamResponseChunk + "\n\n" + OpenAIStreamResponseEnd + "\n\n" + `data: [DONE]`
DefaultRequestCheckService = "llm_query_moderation" DefaultRequestCheckService = "llm_query_moderation"
@@ -53,6 +53,7 @@ const (
DefaultStreamingResponseJsonPath = "choices.0.delta.content" DefaultStreamingResponseJsonPath = "choices.0.delta.content"
DefaultDenyCode = 200 DefaultDenyCode = 200
DefaultDenyMessage = "很抱歉,我无法回答您的问题" DefaultDenyMessage = "很抱歉,我无法回答您的问题"
DefaultTimeout = 2000
AliyunUserAgent = "CIPFrom/AIGateway" AliyunUserAgent = "CIPFrom/AIGateway"
LengthLimit = 1800 LengthLimit = 1800
@@ -100,6 +101,7 @@ type AISecurityConfig struct {
denyMessage string denyMessage string
protocolOriginal bool protocolOriginal bool
riskLevelBar string riskLevelBar string
timeout uint32
metrics map[string]proxywasm.MetricCounter metrics map[string]proxywasm.MetricCounter
} }
@@ -225,6 +227,11 @@ func parseConfig(json gjson.Result, config *AISecurityConfig, log wrapper.Log) e
} else { } else {
config.riskLevelBar = HighRisk config.riskLevelBar = HighRisk
} }
if obj := json.Get("timeout"); obj.Exists() {
config.timeout = uint32(obj.Int())
} else {
config.timeout = DefaultTimeout
}
config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{ config.client = wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: serviceName, FQDN: serviceName,
Port: servicePort, Port: servicePort,
@@ -253,9 +260,8 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action { func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
log.Debugf("checking request body...") log.Debugf("checking request body...")
startTime := time.Now().UnixMilli()
content := gjson.GetBytes(body, config.requestContentJsonPath).String() content := gjson.GetBytes(body, config.requestContentJsonPath).String()
model := gjson.GetBytes(body, "model").String()
ctx.SetContext("requestModel", model)
log.Debugf("Raw request content is: %s", content) log.Debugf("Raw request content is: %s", content)
if len(content) == 0 { if len(content) == 0 {
log.Info("request content is empty. skip") log.Info("request content is empty. skip")
@@ -279,6 +285,10 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
} }
if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) { if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) {
if contentIndex >= len(content) { if contentIndex >= len(content) {
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "request pass")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
proxywasm.ResumeHttpRequest() proxywasm.ResumeHttpRequest()
} else { } else {
singleCall() singleCall()
@@ -296,16 +306,23 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1) proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if gjson.GetBytes(body, "stream").Bool() { } else if gjson.GetBytes(body, "stream").Bool() {
randomID := generateRandomID() randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model)) jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1) proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else { } else {
randomID := generateRandomID() randomID := generateRandomID()
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage)) jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1) proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
} }
ctx.DontReadResponseBody() ctx.DontReadResponseBody()
config.incrementCounter("ai_sec_request_deny", 1) config.incrementCounter("ai_sec_request_deny", 1)
proxywasm.ResumeHttpRequest() endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_request_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "reqeust deny")
if response.Data.Advice != nil {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
} }
singleCall = func() { singleCall = func() {
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
@@ -340,7 +357,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
reqParams.Add(k, v) reqParams.Add(k, v)
} }
reqParams.Add("Signature", signature) reqParams.Add("Signature", signature)
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback) err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback, config.timeout)
if err != nil { if err != nil {
log.Errorf("failed call the safe check service: %v", err) log.Errorf("failed call the safe check service: %v", err)
proxywasm.ResumeHttpRequest() proxywasm.ResumeHttpRequest()
@@ -350,50 +367,26 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
return types.ActionPause return types.ActionPause
} }
func convertHeaders(hs [][2]string) map[string][]string {
ret := make(map[string][]string)
for _, h := range hs {
k, v := strings.ToLower(h[0]), h[1]
ret[k] = append(ret[k], v)
}
return ret
}
// headers: map[string][]string -> [][2]string
func reconvertHeaders(hs map[string][]string) [][2]string {
var ret [][2]string
for k, vs := range hs {
for _, v := range vs {
ret = append(ret, [2]string{k, v})
}
}
sort.SliceStable(ret, func(i, j int) bool {
return ret[i][0] < ret[j][0]
})
return ret
}
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action { func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log wrapper.Log) types.Action {
if !config.checkResponse { if !config.checkResponse {
log.Debugf("response checking is disabled") log.Debugf("response checking is disabled")
ctx.DontReadResponseBody() ctx.DontReadResponseBody()
return types.ActionContinue return types.ActionContinue
} }
headers, err := proxywasm.GetHttpResponseHeaders() statusCode, _ := proxywasm.GetHttpResponseHeader(":status")
if err != nil { if statusCode != "200" {
log.Warnf("failed to get response headers: %v", err) log.Debugf("response is not 200, skip response body check")
ctx.DontReadResponseBody()
return types.ActionContinue return types.ActionContinue
} }
hdsMap := convertHeaders(headers)
ctx.SetContext("headers", hdsMap)
return types.HeaderStopIteration return types.HeaderStopIteration
} }
func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action { func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body []byte, log wrapper.Log) types.Action {
log.Debugf("checking response body...") log.Debugf("checking response body...")
hdsMap := ctx.GetContext("headers").(map[string][]string) startTime := time.Now().UnixMilli()
isStreamingResponse := strings.Contains(strings.Join(hdsMap["content-type"], ";"), "event-stream") contentType, _ := proxywasm.GetHttpResponseHeader("content-type")
model := ctx.GetStringContext("requestModel", "unknown") isStreamingResponse := strings.Contains(contentType, "event-stream")
var content string var content string
if isStreamingResponse { if isStreamingResponse {
content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath) content = extractMessageFromStreamingBody(body, config.responseStreamContentJsonPath)
@@ -423,6 +416,10 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
} }
if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) { if riskLevelToInt(response.Data.RiskLevel) < riskLevelToInt(config.riskLevelBar) {
if contentIndex >= len(content) { if contentIndex >= len(content) {
endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "response pass")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
proxywasm.ResumeHttpResponse() proxywasm.ResumeHttpResponse()
} else { } else {
singleCall() singleCall()
@@ -436,22 +433,26 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
denyMessage = response.Data.Advice[0].Answer denyMessage = response.Data.Advice[0].Answer
} }
marshalledDenyMessage := marshalStr(denyMessage, log) marshalledDenyMessage := marshalStr(denyMessage, log)
var jsonData []byte
if config.protocolOriginal { if config.protocolOriginal {
jsonData = []byte(marshalledDenyMessage) proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if isStreamingResponse { } else if isStreamingResponse {
randomID := generateRandomID() randomID := generateRandomID()
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model)) jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, marshalledDenyMessage, randomID))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else { } else {
randomID := generateRandomID() randomID := generateRandomID()
jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage)) jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
} }
delete(hdsMap, "content-length")
hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)}
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
proxywasm.ReplaceHttpResponseBody(jsonData)
config.incrementCounter("ai_sec_response_deny", 1) config.incrementCounter("ai_sec_response_deny", 1)
proxywasm.ResumeHttpResponse() endTime := time.Now().UnixMilli()
ctx.SetUserAttribute("safecheck_response_rt", endTime-startTime)
ctx.SetUserAttribute("safecheck_status", "response deny")
if response.Data.Advice != nil {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
} }
singleCall = func() { singleCall = func() {
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z") timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
@@ -486,7 +487,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
reqParams.Add(k, v) reqParams.Add(k, v)
} }
reqParams.Add("Signature", signature) reqParams.Add("Signature", signature)
err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback) err := config.client.Post(fmt.Sprintf("/?%s", reqParams.Encode()), [][2]string{{"User-Agent", AliyunUserAgent}}, nil, callback, config.timeout)
if err != nil { if err != nil {
log.Errorf("failed call the safe check service: %v", err) log.Errorf("failed call the safe check service: %v", err)
proxywasm.ResumeHttpResponse() proxywasm.ResumeHttpResponse()

View File

@@ -38,7 +38,7 @@ Attribute 配置说明:
`value_source` 的各种取值含义如下: `value_source` 的各种取值含义如下:
- `fixed_value`:固定值 - `fixed_value`:固定值
- `requeset_header` attrribute 值通过 http 请求头获取value 配置为 header key - `request_header` attrribute 值通过 http 请求头获取value 配置为 header key
- `request_body` attrribute 值通过请求 body 获取value 配置格式为 gjson 的 jsonpath - `request_body` attrribute 值通过请求 body 获取value 配置格式为 gjson 的 jsonpath
- `response_header` attrribute 值通过 http 响应头获取value 配置为header key - `response_header` attrribute 值通过 http 响应头获取value 配置为header key
- `response_body` attrribute 值通过响应 body 获取value 配置格式为 gjson 的 jsonpath - `response_body` attrribute 值通过响应 body 获取value 配置格式为 gjson 的 jsonpath

View File

@@ -38,7 +38,7 @@ Attribute Configuration instructions:
The meanings of various values for `value_source` are as follows: The meanings of various values for `value_source` are as follows:
- `fixed_value`: fixed value - `fixed_value`: fixed value
- `requeset_header`: The attrribute is obtained through the http request header - `request_header`: The attrribute is obtained through the http request header
- `request_body`: The attrribute is obtained through the http request body - `request_body`: The attrribute is obtained through the http request body
- `response_header`: The attrribute is obtained through the http response header - `response_header`: The attrribute is obtained through the http response header
- `response_body`: The attrribute is obtained through the http response body - `response_body`: The attrribute is obtained through the http response body

View File

@@ -3,15 +3,13 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0= github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=

View File

@@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"strconv"
"strings" "strings"
"time" "time"
@@ -28,14 +27,16 @@ func main() {
} }
const ( const (
// Trace span prefix
TracePrefix = "trace_span_tag."
// Context consts // Context consts
StatisticsRequestStartTime = "ai-statistics-request-start-time" StatisticsRequestStartTime = "ai-statistics-request-start-time"
StatisticsFirstTokenTime = "ai-statistics-first-token-time" StatisticsFirstTokenTime = "ai-statistics-first-token-time"
CtxGeneralAtrribute = "attributes" CtxGeneralAtrribute = "attributes"
CtxLogAtrribute = "logAttributes" CtxLogAtrribute = "logAttributes"
CtxStreamingBodyBuffer = "streamingBodyBuffer" CtxStreamingBodyBuffer = "streamingBodyBuffer"
RouteName = "route"
ClusterName = "cluster"
APIName = "api"
ConsumerKey = "x-mse-consumer"
// Source Type // Source Type
FixedValue = "fixed_value" FixedValue = "fixed_value"
@@ -46,12 +47,14 @@ const (
ResponseBody = "response_body" ResponseBody = "response_body"
// Inner metric & log attributes name // Inner metric & log attributes name
Model = "model" Model = "model"
InputToken = "input_token" InputToken = "input_token"
OutputToken = "output_token" OutputToken = "output_token"
LLMFirstTokenDuration = "llm_first_token_duration" LLMFirstTokenDuration = "llm_first_token_duration"
LLMServiceDuration = "llm_service_duration" LLMServiceDuration = "llm_service_duration"
LLMDurationCount = "llm_duration_count" LLMDurationCount = "llm_duration_count"
LLMStreamDurationCount = "llm_stream_duration_count"
ResponseType = "response_type"
// Extract Rule // Extract Rule
RuleFirst = "first" RuleFirst = "first"
@@ -79,8 +82,8 @@ type AIStatisticsConfig struct {
shouldBufferStreamingBody bool shouldBufferStreamingBody bool
} }
func generateMetricName(route, cluster, model, metricName string) string { func generateMetricName(route, cluster, model, consumer, metricName string) string {
return fmt.Sprintf("route.%s.upstream.%s.model.%s.metric.%s", route, cluster, model, metricName) return fmt.Sprintf("route.%s.upstream.%s.model.%s.consumer.%s.metric.%s", route, cluster, model, consumer, metricName)
} }
func getRouteName() (string, error) { func getRouteName() (string, error) {
@@ -91,6 +94,19 @@ func getRouteName() (string, error) {
} }
} }
func getAPIName() (string, error) {
if raw, err := proxywasm.GetProperty([]string{"route_name"}); err != nil {
return "-", err
} else {
parts := strings.Split(string(raw), "@")
if len(parts) != 5 {
return "-", errors.New("not api type")
} else {
return strings.Join(parts[:3], "@"), nil
}
}
}
func getClusterName() (string, error) { func getClusterName() (string, error) {
if raw, err := proxywasm.GetProperty([]string{"cluster_name"}); err != nil { if raw, err := proxywasm.GetProperty([]string{"cluster_name"}); err != nil {
return "-", err return "-", err
@@ -100,6 +116,9 @@ func getClusterName() (string, error) {
} }
func (config *AIStatisticsConfig) incrementCounter(metricName string, inc uint64) { func (config *AIStatisticsConfig) incrementCounter(metricName string, inc uint64) {
if inc == 0 {
return
}
counter, ok := config.counterMetrics[metricName] counter, ok := config.counterMetrics[metricName]
if !ok { if !ok {
counter = proxywasm.DefineCounterMetric(metricName) counter = proxywasm.DefineCounterMetric(metricName)
@@ -133,9 +152,19 @@ func parseConfig(configJson gjson.Result, config *AIStatisticsConfig, log wrappe
} }
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) types.Action { func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) types.Action {
ctx.SetContext(CtxGeneralAtrribute, map[string]string{}) route, _ := getRouteName()
ctx.SetContext(CtxLogAtrribute, map[string]string{}) cluster, _ := getClusterName()
api, api_error := getAPIName()
if api_error == nil {
route = api
}
ctx.SetContext(RouteName, route)
ctx.SetContext(ClusterName, cluster)
ctx.SetUserAttribute(APIName, api)
ctx.SetContext(StatisticsRequestStartTime, time.Now().UnixMilli()) ctx.SetContext(StatisticsRequestStartTime, time.Now().UnixMilli())
if consumer, _ := proxywasm.GetHttpRequestHeader(ConsumerKey); consumer != "" {
ctx.SetContext(ConsumerKey, consumer)
}
// Set user defined log & span attributes which type is fixed_value // Set user defined log & span attributes which type is fixed_value
setAttributeBySource(ctx, config, FixedValue, nil, log) setAttributeBySource(ctx, config, FixedValue, nil, log)
@@ -149,6 +178,9 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIStatisticsConfig, lo
func onHttpRequestBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log wrapper.Log) types.Action { func onHttpRequestBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log wrapper.Log) types.Action {
// Set user defined log & span attributes. // Set user defined log & span attributes.
setAttributeBySource(ctx, config, RequestBody, body, log) setAttributeBySource(ctx, config, RequestBody, body, log)
// Write log
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
return types.ActionContinue return types.ActionContinue
} }
@@ -177,6 +209,8 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
ctx.SetContext(CtxStreamingBodyBuffer, streamingBodyBuffer) ctx.SetContext(CtxStreamingBodyBuffer, streamingBodyBuffer)
} }
ctx.SetUserAttribute(ResponseType, "stream")
// Get requestStartTime from http context // Get requestStartTime from http context
requestStartTime, ok := ctx.GetContext(StatisticsRequestStartTime).(int64) requestStartTime, ok := ctx.GetContext(StatisticsRequestStartTime).(int64)
if !ok { if !ok {
@@ -188,28 +222,19 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
if ctx.GetContext(StatisticsFirstTokenTime) == nil { if ctx.GetContext(StatisticsFirstTokenTime) == nil {
firstTokenTime := time.Now().UnixMilli() firstTokenTime := time.Now().UnixMilli()
ctx.SetContext(StatisticsFirstTokenTime, firstTokenTime) ctx.SetContext(StatisticsFirstTokenTime, firstTokenTime)
attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string) ctx.SetUserAttribute(LLMFirstTokenDuration, firstTokenTime-requestStartTime)
attributes[LLMFirstTokenDuration] = fmt.Sprint(firstTokenTime - requestStartTime)
ctx.SetContext(CtxGeneralAtrribute, attributes)
} }
// Set information about this request // Set information about this request
if model, inputToken, outputToken, ok := getUsage(data); ok { if model, inputToken, outputToken, ok := getUsage(data); ok {
attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string) ctx.SetUserAttribute(Model, model)
// Record Log Attributes ctx.SetUserAttribute(InputToken, inputToken)
attributes[Model] = model ctx.SetUserAttribute(OutputToken, outputToken)
attributes[InputToken] = fmt.Sprint(inputToken)
attributes[OutputToken] = fmt.Sprint(outputToken)
// Set attributes to http context
ctx.SetContext(CtxGeneralAtrribute, attributes)
} }
// If the end of the stream is reached, record metrics/logs/spans. // If the end of the stream is reached, record metrics/logs/spans.
if endOfStream { if endOfStream {
responseEndTime := time.Now().UnixMilli() responseEndTime := time.Now().UnixMilli()
attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string) ctx.SetUserAttribute(LLMServiceDuration, responseEndTime-requestStartTime)
attributes[LLMServiceDuration] = fmt.Sprint(responseEndTime - requestStartTime)
ctx.SetContext(CtxGeneralAtrribute, attributes)
// Set user defined log & span attributes. // Set user defined log & span attributes.
if config.shouldBufferStreamingBody { if config.shouldBufferStreamingBody {
@@ -220,11 +245,8 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
setAttributeBySource(ctx, config, ResponseStreamingBody, streamingBodyBuffer, log) setAttributeBySource(ctx, config, ResponseStreamingBody, streamingBodyBuffer, log)
} }
// Write inner filter states which can be used by other plugins such as ai-token-ratelimit
writeFilterStates(ctx, log)
// Write log // Write log
writeLog(ctx, log) ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
// Write metrics // Write metrics
writeMetric(ctx, config, log) writeMetric(ctx, config, log)
@@ -233,33 +255,26 @@ func onHttpStreamingBody(ctx wrapper.HttpContext, config AIStatisticsConfig, dat
} }
func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log wrapper.Log) types.Action { func onHttpResponseBody(ctx wrapper.HttpContext, config AIStatisticsConfig, body []byte, log wrapper.Log) types.Action {
// Get attributes from http context
attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string)
// Get requestStartTime from http context // Get requestStartTime from http context
requestStartTime, _ := ctx.GetContext(StatisticsRequestStartTime).(int64) requestStartTime, _ := ctx.GetContext(StatisticsRequestStartTime).(int64)
responseEndTime := time.Now().UnixMilli() responseEndTime := time.Now().UnixMilli()
attributes[LLMServiceDuration] = fmt.Sprint(responseEndTime - requestStartTime) ctx.SetUserAttribute(LLMServiceDuration, responseEndTime-requestStartTime)
ctx.SetUserAttribute(ResponseType, "normal")
// Set information about this request // Set information about this request
model, inputToken, outputToken, ok := getUsage(body) if model, inputToken, outputToken, ok := getUsage(body); ok {
if ok { ctx.SetUserAttribute(Model, model)
attributes[Model] = model ctx.SetUserAttribute(InputToken, inputToken)
attributes[InputToken] = fmt.Sprint(inputToken) ctx.SetUserAttribute(OutputToken, outputToken)
attributes[OutputToken] = fmt.Sprint(outputToken)
// Update attributes
ctx.SetContext(CtxGeneralAtrribute, attributes)
} }
// Set user defined log & span attributes. // Set user defined log & span attributes.
setAttributeBySource(ctx, config, ResponseBody, body, log) setAttributeBySource(ctx, config, ResponseBody, body, log)
// Write inner filter states which can be used by other plugins such as ai-token-ratelimit
writeFilterStates(ctx, log)
// Write log // Write log
writeLog(ctx, log) ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
// Write metrics // Write metrics
writeMetric(ctx, config, log) writeMetric(ctx, config, log)
@@ -294,67 +309,49 @@ func getUsage(data []byte) (model string, inputTokenUsage int64, outputTokenUsag
// fetches the tracing span value from the specified source. // fetches the tracing span value from the specified source.
func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, source string, body []byte, log wrapper.Log) { func setAttributeBySource(ctx wrapper.HttpContext, config AIStatisticsConfig, source string, body []byte, log wrapper.Log) {
attributes, ok := ctx.GetContext(CtxGeneralAtrribute).(map[string]string)
if !ok {
log.Error("failed to get attributes from http context")
return
}
for _, attribute := range config.attributes { for _, attribute := range config.attributes {
var key string
var value interface{}
if source == attribute.ValueSource { if source == attribute.ValueSource {
key = attribute.Key
switch source { switch source {
case FixedValue: case FixedValue:
log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, attribute.Value) value = attribute.Value
attributes[attribute.Key] = attribute.Value
case RequestHeader: case RequestHeader:
if value, err := proxywasm.GetHttpRequestHeader(attribute.Value); err == nil { value, _ = proxywasm.GetHttpRequestHeader(attribute.Value)
log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
attributes[attribute.Key] = value
}
case RequestBody: case RequestBody:
raw := gjson.GetBytes(body, attribute.Value).Raw value = gjson.GetBytes(body, attribute.Value).Value()
var value string
if len(raw) > 2 {
value = raw[1 : len(raw)-1]
}
log.Debugf("[attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
attributes[attribute.Key] = value
case ResponseHeader: case ResponseHeader:
if value, err := proxywasm.GetHttpResponseHeader(attribute.Value); err == nil { value, _ = proxywasm.GetHttpResponseHeader(attribute.Value)
log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
attributes[attribute.Key] = value
}
case ResponseStreamingBody: case ResponseStreamingBody:
value := extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule, log) value = extractStreamingBodyByJsonPath(body, attribute.Value, attribute.Rule, log)
log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
attributes[attribute.Key] = value
case ResponseBody: case ResponseBody:
value := gjson.GetBytes(body, attribute.Value).Raw value = gjson.GetBytes(body, attribute.Value).Value()
if len(value) > 2 && value[0] == '"' && value[len(value)-1] == '"' {
value = value[1 : len(value)-1]
}
log.Debugf("[log attribute] source type: %s, key: %s, value: %s", source, attribute.Key, value)
attributes[attribute.Key] = value
default: default:
} }
} log.Debugf("[attribute] source type: %s, key: %s, value: %+v", source, key, value)
if attribute.ApplyToLog { if attribute.ApplyToLog {
setLogAttribute(ctx, attribute.Key, attributes[attribute.Key], log) ctx.SetUserAttribute(key, value)
} }
if attribute.ApplyToSpan { // for metrics
setSpanAttribute(attribute.Key, attributes[attribute.Key], log) if key == Model || key == InputToken || key == OutputToken {
ctx.SetContext(key, value)
}
if attribute.ApplyToSpan {
setSpanAttribute(key, value, log)
}
} }
} }
ctx.SetContext(CtxGeneralAtrribute, attributes)
} }
func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log wrapper.Log) string { func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, log wrapper.Log) interface{} {
chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n")) chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
var value string var value interface{}
if rule == RuleFirst { if rule == RuleFirst {
for _, chunk := range chunks { for _, chunk := range chunks {
jsonObj := gjson.GetBytes(chunk, jsonPath) jsonObj := gjson.GetBytes(chunk, jsonPath)
if jsonObj.Exists() { if jsonObj.Exists() {
value = jsonObj.String() value = jsonObj.Value()
break break
} }
} }
@@ -362,140 +359,117 @@ func extractStreamingBodyByJsonPath(data []byte, jsonPath string, rule string, l
for _, chunk := range chunks { for _, chunk := range chunks {
jsonObj := gjson.GetBytes(chunk, jsonPath) jsonObj := gjson.GetBytes(chunk, jsonPath)
if jsonObj.Exists() { if jsonObj.Exists() {
value = jsonObj.String() value = jsonObj.Value()
} }
} }
} else if rule == RuleAppend { } else if rule == RuleAppend {
// extract llm response // extract llm response
var strValue string
for _, chunk := range chunks { for _, chunk := range chunks {
raw := gjson.GetBytes(chunk, jsonPath).Raw jsonObj := gjson.GetBytes(chunk, jsonPath)
if len(raw) > 2 && raw[0] == '"' && raw[len(raw)-1] == '"' { if jsonObj.Exists() {
value += raw[1 : len(raw)-1] strValue += jsonObj.String()
} }
} }
value = strValue
} else { } else {
log.Errorf("unsupported rule type: %s", rule) log.Errorf("unsupported rule type: %s", rule)
} }
return value return value
} }
func setFilterState(key, value string, log wrapper.Log) {
if value != "" {
if e := proxywasm.SetProperty([]string{key}, []byte(fmt.Sprint(value))); e != nil {
log.Errorf("failed to set %s in filter state: %v", key, e)
}
} else {
log.Debugf("failed to write filter state [%s], because it's value is empty")
}
}
// Set the tracing span with value. // Set the tracing span with value.
func setSpanAttribute(key, value string, log wrapper.Log) { func setSpanAttribute(key string, value interface{}, log wrapper.Log) {
if value != "" { if value != "" {
traceSpanTag := TracePrefix + key traceSpanTag := wrapper.TraceSpanTagPrefix + key
if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(value)); e != nil { if e := proxywasm.SetProperty([]string{traceSpanTag}, []byte(fmt.Sprint(value))); e != nil {
log.Errorf("failed to set %s in filter state: %v", traceSpanTag, e) log.Warnf("failed to set %s in filter state: %v", traceSpanTag, e)
} }
} else { } else {
log.Debugf("failed to write span attribute [%s], because it's value is empty") log.Debugf("failed to write span attribute [%s], because it's value is empty")
} }
} }
// fetches the tracing span value from the specified source.
func setLogAttribute(ctx wrapper.HttpContext, key string, value interface{}, log wrapper.Log) {
logAttributes, ok := ctx.GetContext(CtxLogAtrribute).(map[string]string)
if !ok {
log.Error("failed to get logAttributes from http context")
return
}
logAttributes[key] = fmt.Sprint(value)
ctx.SetContext(CtxLogAtrribute, logAttributes)
}
func writeFilterStates(ctx wrapper.HttpContext, log wrapper.Log) {
attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string)
setFilterState(Model, attributes[Model], log)
setFilterState(InputToken, attributes[InputToken], log)
setFilterState(OutputToken, attributes[OutputToken], log)
}
func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) { func writeMetric(ctx wrapper.HttpContext, config AIStatisticsConfig, log wrapper.Log) {
attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string) // Generate usage metrics
route, _ := getRouteName() var ok bool
cluster, _ := getClusterName() var route, cluster, model string
model, ok := attributes["model"] var inputToken, outputToken uint64
consumer := ctx.GetStringContext(ConsumerKey, "none")
route, ok = ctx.GetContext(RouteName).(string)
if !ok { if !ok {
log.Errorf("Get model failed") log.Warnf("RouteName typd assert failed, skip metric record")
return return
} }
if inputToken, ok := attributes[InputToken]; ok { cluster, ok = ctx.GetContext(ClusterName).(string)
inputTokenUint64, err := strconv.ParseUint(inputToken, 10, 0) if !ok {
if err != nil || inputTokenUint64 == 0 { log.Warnf("ClusterName typd assert failed, skip metric record")
log.Errorf("inputToken convert failed, value is %d, err msg is [%v]", inputTokenUint64, err) return
}
if ctx.GetUserAttribute(Model) == nil || ctx.GetUserAttribute(InputToken) == nil || ctx.GetUserAttribute(OutputToken) == nil {
log.Warnf("get usage information failed, skip metric record")
return
}
model, ok = ctx.GetUserAttribute(Model).(string)
if !ok {
log.Warnf("Model typd assert failed, skip metric record")
return
}
inputToken, ok = convertToUInt(ctx.GetUserAttribute(InputToken))
if !ok {
log.Warnf("InputToken typd assert failed, skip metric record")
return
}
outputToken, ok = convertToUInt(ctx.GetUserAttribute(OutputToken))
if !ok {
log.Warnf("OutputToken typd assert failed, skip metric record")
return
}
if inputToken == 0 || outputToken == 0 {
log.Warnf("inputToken and outputToken cannot equal to 0, skip metric record")
return
}
config.incrementCounter(generateMetricName(route, cluster, model, consumer, InputToken), inputToken)
config.incrementCounter(generateMetricName(route, cluster, model, consumer, OutputToken), outputToken)
// Generate duration metrics
var llmFirstTokenDuration, llmServiceDuration uint64
// Is stream response
if ctx.GetUserAttribute(LLMFirstTokenDuration) != nil {
llmFirstTokenDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMFirstTokenDuration))
if !ok {
log.Warnf("LLMFirstTokenDuration typd assert failed")
return return
} }
config.incrementCounter(generateMetricName(route, cluster, model, InputToken), inputTokenUint64) config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMFirstTokenDuration), llmFirstTokenDuration)
config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMStreamDurationCount), 1)
} }
if outputToken, ok := attributes[OutputToken]; ok { if ctx.GetUserAttribute(LLMServiceDuration) != nil {
outputTokenUint64, err := strconv.ParseUint(outputToken, 10, 0) llmServiceDuration, ok = convertToUInt(ctx.GetUserAttribute(LLMServiceDuration))
if err != nil || outputTokenUint64 == 0 { if !ok {
log.Errorf("outputToken convert failed, value is %d, err msg is [%v]", outputTokenUint64, err) log.Warnf("LLMServiceDuration typd assert failed")
return return
} }
config.incrementCounter(generateMetricName(route, cluster, model, OutputToken), outputTokenUint64) config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMServiceDuration), llmServiceDuration)
config.incrementCounter(generateMetricName(route, cluster, model, consumer, LLMDurationCount), 1)
} }
if llmFirstTokenDuration, ok := attributes[LLMFirstTokenDuration]; ok {
llmFirstTokenDurationUint64, err := strconv.ParseUint(llmFirstTokenDuration, 10, 0)
if err != nil || llmFirstTokenDurationUint64 == 0 {
log.Errorf("llmFirstTokenDuration convert failed, value is %d, err msg is [%v]", llmFirstTokenDurationUint64, err)
return
}
config.incrementCounter(generateMetricName(route, cluster, model, LLMFirstTokenDuration), llmFirstTokenDurationUint64)
}
if llmServiceDuration, ok := attributes[LLMServiceDuration]; ok {
llmServiceDurationUint64, err := strconv.ParseUint(llmServiceDuration, 10, 0)
if err != nil || llmServiceDurationUint64 == 0 {
log.Errorf("llmServiceDuration convert failed, value is %d, err msg is [%v]", llmServiceDurationUint64, err)
return
}
config.incrementCounter(generateMetricName(route, cluster, model, LLMServiceDuration), llmServiceDurationUint64)
}
config.incrementCounter(generateMetricName(route, cluster, model, LLMDurationCount), 1)
} }
func writeLog(ctx wrapper.HttpContext, log wrapper.Log) { func convertToUInt(val interface{}) (uint64, bool) {
attributes, _ := ctx.GetContext(CtxGeneralAtrribute).(map[string]string) switch v := val.(type) {
logAttributes, _ := ctx.GetContext(CtxLogAtrribute).(map[string]string) case float32:
// Set inner log fields return uint64(v), true
if attributes[Model] != "" { case float64:
logAttributes[Model] = attributes[Model] return uint64(v), true
} case int32:
if attributes[InputToken] != "" { return uint64(v), true
logAttributes[InputToken] = attributes[InputToken] case int64:
} return uint64(v), true
if attributes[OutputToken] != "" { case uint32:
logAttributes[OutputToken] = attributes[OutputToken] return uint64(v), true
} case uint64:
if attributes[LLMFirstTokenDuration] != "" { return v, true
logAttributes[LLMFirstTokenDuration] = attributes[LLMFirstTokenDuration] default:
} return 0, false
if attributes[LLMServiceDuration] != "" {
logAttributes[LLMServiceDuration] = attributes[LLMServiceDuration]
}
// Traverse log fields
items := []string{}
for k, v := range logAttributes {
items = append(items, fmt.Sprintf(`"%s":"%s"`, k, v))
}
aiLogField := fmt.Sprintf(`{%s}`, strings.Join(items, ","))
// log.Infof("ai request json log: %s", aiLogField)
jsonMap := map[string]string{
"ai_log": aiLogField,
}
serialized, _ := json.Marshal(jsonMap)
jsonLogRaw := gjson.GetBytes(serialized, "ai_log").Raw
jsonLog := jsonLogRaw[1 : len(jsonLogRaw)-1]
if err := proxywasm.SetProperty([]string{"ai_log"}, []byte(jsonLog)); err != nil {
log.Errorf("failed to set ai_log in filter state: %v", err)
} }
} }

View File

@@ -5,8 +5,7 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew= github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg= github.com/higress-group/proxy-wasm-go-sdk v1.0.0 h1:BZRNf4R7jr9hwRivg/E29nkVaKEak5MWjBDhWjuHijU=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0= github.com/higress-group/proxy-wasm-go-sdk v1.0.0/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo= github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A= github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
@@ -14,8 +13,7 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tetratelabs/wazero v1.7.1 h1:QtSfd6KLc41DIMpDYlJdoMc6k7QTN246DM2+n2Y/Dx8= github.com/tetratelabs/wazero v1.7.1 h1:QtSfd6KLc41DIMpDYlJdoMc6k7QTN246DM2+n2Y/Dx8=
github.com/tetratelabs/wazero v1.7.1/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y= github.com/tetratelabs/wazero v1.7.1/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=

View File

@@ -15,6 +15,7 @@
package main package main
import ( import (
"bytes"
"fmt" "fmt"
"net" "net"
"net/url" "net/url"
@@ -61,9 +62,9 @@ const (
ConsumerHeader string = "x-mse-consumer" // LimitByConsumer从该request header获取consumer的名字 ConsumerHeader string = "x-mse-consumer" // LimitByConsumer从该request header获取consumer的名字
CookieHeader string = "cookie" CookieHeader string = "cookie"
RateLimitLimitHeader string = "X-RateLimit-Limit" // 限制的总请求数 RateLimitLimitHeader string = "X-TokenRateLimit-Limit" // 限制的总请求数
RateLimitRemainingHeader string = "X-RateLimit-Remaining" // 剩余还可以发送的请求数 RateLimitRemainingHeader string = "X-TokenRateLimit-Remaining" // 剩余还可以发送的请求数
RateLimitResetHeader string = "X-RateLimit-Reset" // 限流重置时间(触发限流时返回) RateLimitResetHeader string = "X-TokenRateLimit-Reset" // 限流重置时间(触发限流时返回)
) )
type LimitContext struct { type LimitContext struct {
@@ -124,6 +125,8 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitCon
} }
if context.remaining < 0 { if context.remaining < 0 {
// 触发限流 // 触发限流
ctx.SetUserAttribute("token_ratelimit_status", "limited")
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
rejected(config, context) rejected(config, context)
} else { } else {
proxywasm.ResumeHttpRequest() proxywasm.ResumeHttpRequest()
@@ -137,39 +140,49 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitCon
} }
func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, data []byte, endOfStream bool, log wrapper.Log) []byte { func onHttpStreamingBody(ctx wrapper.HttpContext, config ClusterKeyRateLimitConfig, data []byte, endOfStream bool, log wrapper.Log) []byte {
if !endOfStream { var inputToken, outputToken int64
return data if inputToken, outputToken, ok := getUsage(data); ok {
ctx.SetContext("input_token", inputToken)
ctx.SetContext("output_token", outputToken)
} }
inputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.input_token"}) if endOfStream {
if err != nil { if ctx.GetContext("input_token") == nil || ctx.GetContext("output_token") == nil {
return data return data
}
inputToken = ctx.GetContext("input_token").(int64)
outputToken = ctx.GetContext("output_token").(int64)
limitRedisContext, ok := ctx.GetContext(LimitRedisContextKey).(LimitRedisContext)
if !ok {
return data
}
keys := []interface{}{limitRedisContext.key}
args := []interface{}{limitRedisContext.count, limitRedisContext.window, inputToken + outputToken}
err := config.redisClient.Eval(ResponsePhaseFixedWindowScript, 1, keys, args, nil)
if err != nil {
log.Errorf("redis call failed: %v", err)
}
} }
outputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.output_token"}) return data
if err != nil { }
return data
}
inputToken, err := strconv.Atoi(string(inputTokenStr))
if err != nil {
return data
}
outputToken, err := strconv.Atoi(string(outputTokenStr))
if err != nil {
return data
}
limitRedisContext, ok := ctx.GetContext(LimitRedisContextKey).(LimitRedisContext)
if !ok {
return data
}
keys := []interface{}{limitRedisContext.key}
args := []interface{}{limitRedisContext.count, limitRedisContext.window, inputToken + outputToken}
err = config.redisClient.Eval(ResponsePhaseFixedWindowScript, 1, keys, args, nil) func getUsage(data []byte) (inputTokenUsage int64, outputTokenUsage int64, ok bool) {
if err != nil { chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n"))
log.Errorf("redis call failed: %v", err) for _, chunk := range chunks {
return data // the feature strings are used to identify the usage data, like:
} else { // {"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}}
return data if !bytes.Contains(chunk, []byte("prompt_tokens")) || !bytes.Contains(chunk, []byte("completion_tokens")) {
continue
}
inputTokenObj := gjson.GetBytes(chunk, "usage.prompt_tokens")
outputTokenObj := gjson.GetBytes(chunk, "usage.completion_tokens")
if inputTokenObj.Exists() && outputTokenObj.Exists() {
inputTokenUsage = inputTokenObj.Int()
outputTokenUsage = outputTokenObj.Int()
ok = true
return
}
} }
return
} }
func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem, log wrapper.Log) (string, *LimitRuleItem, *LimitConfigItem) { func checkRequestAgainstLimitRule(ctx wrapper.HttpContext, ruleItems []LimitRuleItem, log wrapper.Log) (string, *LimitRuleItem, *LimitConfigItem) {

View File

@@ -20,6 +20,7 @@ description: 前端灰度插件配置参考
| `localStorageGrayKey` | string | 非必填 | - | 使用JWT鉴权方式用户ID的唯一标识来自`localStorage`中,如果配置了当前参数,则`grayKey`失效 | | `localStorageGrayKey` | string | 非必填 | - | 使用JWT鉴权方式用户ID的唯一标识来自`localStorage`中,如果配置了当前参数,则`grayKey`失效 |
| `graySubKey` | string | 非必填 | - | 用户身份信息可能以JSON形式透出比如`userInfo:{ userCode:"001" }`,当前例子`graySubKey`取值为`userCode` | | `graySubKey` | string | 非必填 | - | 用户身份信息可能以JSON形式透出比如`userInfo:{ userCode:"001" }`,当前例子`graySubKey`取值为`userCode` |
| `userStickyMaxAge` | int | 非必填 | 172800 | 用户粘滞的时长:单位为秒,默认为`172800`2天时间 | | `userStickyMaxAge` | int | 非必填 | 172800 | 用户粘滞的时长:单位为秒,默认为`172800`2天时间 |
| `includePathPrefixes` | array of strings | 非必填 | - | 强制处理的路径。例如,在 微前端 场景下XHR 接口如: `/resource/xxx`本质是一个资源请求,需要走插件转发逻辑。 |
| `skippedPathPrefixes` | array of strings | 非必填 | - | 用于排除特定路径,避免当前插件处理这些请求。例如,在 rewrite 场景下XHR 接口请求 `/api/xxx` 如果经过插件转发逻辑,可能会导致非预期的结果。 | | `skippedPathPrefixes` | array of strings | 非必填 | - | 用于排除特定路径,避免当前插件处理这些请求。例如,在 rewrite 场景下XHR 接口请求 `/api/xxx` 如果经过插件转发逻辑,可能会导致非预期的结果。 |
| `skippedByHeaders` | map of string to string | 非必填 | - | 用于通过请求头过滤,指定哪些请求不被当前插件 | `skippedByHeaders` | map of string to string | 非必填 | - | 用于通过请求头过滤,指定哪些请求不被当前插件
处理。`skippedPathPrefixes` 的优先级高于当前配置且页面HTML请求不受本配置的影响。若本配置为空默认会判断`sec-fetch-mode=cors`以及`upgrade=websocket`两个header头进行过滤 | 处理。`skippedPathPrefixes` 的优先级高于当前配置且页面HTML请求不受本配置的影响。若本配置为空默认会判断`sec-fetch-mode=cors`以及`upgrade=websocket`两个header头进行过滤 |

View File

@@ -64,6 +64,7 @@ type GrayConfig struct {
BackendGrayTag string BackendGrayTag string
Injection *Injection Injection *Injection
SkippedPathPrefixes []string SkippedPathPrefixes []string
IncludePathPrefixes []string
SkippedByHeaders map[string]string SkippedByHeaders map[string]string
} }
@@ -97,6 +98,7 @@ func JsonToGrayConfig(json gjson.Result, grayConfig *GrayConfig) {
grayConfig.Html = json.Get("html").String() grayConfig.Html = json.Get("html").String()
grayConfig.SkippedPathPrefixes = convertToStringList(json.Get("skippedPathPrefixes").Array()) grayConfig.SkippedPathPrefixes = convertToStringList(json.Get("skippedPathPrefixes").Array())
grayConfig.SkippedByHeaders = convertToStringMap(json.Get("skippedByHeaders")) grayConfig.SkippedByHeaders = convertToStringMap(json.Get("skippedByHeaders"))
grayConfig.IncludePathPrefixes = convertToStringList(json.Get("includePathPrefixes").Array())
if grayConfig.UserStickyMaxAge == "" { if grayConfig.UserStickyMaxAge == "" {
// 默认值2天 // 默认值2天

View File

@@ -64,7 +64,13 @@ func IsRequestSkippedByHeaders(grayConfig config.GrayConfig) bool {
} }
func IsGrayEnabled(grayConfig config.GrayConfig, requestPath string) bool { func IsGrayEnabled(grayConfig config.GrayConfig, requestPath string) bool {
// 当前路径中前缀为 SkipedRoute则不走插件逻辑 for _, prefix := range grayConfig.IncludePathPrefixes {
if strings.HasPrefix(requestPath, prefix) {
return true
}
}
// 当前路径中前缀为 SkippedPathPrefixes则不走插件逻辑
for _, prefix := range grayConfig.SkippedPathPrefixes { for _, prefix := range grayConfig.SkippedPathPrefixes {
if strings.HasPrefix(requestPath, prefix) { if strings.HasPrefix(requestPath, prefix) {
return false return false

View File

@@ -32,6 +32,7 @@ description: OIDC 认证插件配置参考
| client_secret | string | the OAuth Client Secret | | | client_secret | string | the OAuth Client Secret | |
| provider | string | OAuth provider | oidc | | provider | string | OAuth provider | oidc |
| pass_authorization_header | bool | pass OIDC IDToken to upstream via Authorization Bearer header | true | | pass_authorization_header | bool | pass OIDC IDToken to upstream via Authorization Bearer header | true |
| pass_access_token | bool | pass OIDC Access Token to upstream via X-Forwarded-Access-Token header. | False |
| oidc_issuer_url | string | the OpenID Connect issuer URL, e.g. `"https://dev-o43xb1mz7ya7ach4.us.auth0.com"` | | | oidc_issuer_url | string | the OpenID Connect issuer URL, e.g. `"https://dev-o43xb1mz7ya7ach4.us.auth0.com"` | |
| oidc_verifier_request_timeout | uint32 | OIDC verifier discovery request timeout | 2000(ms) | | oidc_verifier_request_timeout | uint32 | OIDC verifier discovery request timeout | 2000(ms) |
| scope | string | OAuth scope specification | | | scope | string | OAuth scope specification | |
@@ -296,6 +297,55 @@ match_list:
![aliyun_result](https://gw.alicdn.com/imgextra/i3/O1CN015pGvi51eakt3pFS8Y_!!6000000003888-0-tps-3840-2160.jpg) ![aliyun_result](https://gw.alicdn.com/imgextra/i3/O1CN015pGvi51eakt3pFS8Y_!!6000000003888-0-tps-3840-2160.jpg)
### Github 配置示例
#### Step 1: 配置 Github OAuth应用
通过 https://github.com/settings/developers 创建OAuthApp
#### Step 2: Higress 配置服务来源
* 创建DNS类型服务来源地址为github.com
* 创建DNS类型服务来源地址为api.github.com用于验证OIDC流程中的access_token
![github_service](https://www.helloimg.com/i/2024/12/31/677398a2b34be.png)
#### Step 3: OIDC 服务 HTTPS 配置
参考Auth0的Step3对创建的两个DNS服务配置Ingress
#### Step 4: Wasm 插件配置
```yaml
redirect_url: 'http://foo.bar.com/oauth2/callback'
provider: github
oidc_issuer_url: 'https://github.com/'
pass_access_token: true
client_id: 'XXXXXXXXXXXXXXXX'
client_secret: 'XXXXXXXXXXXXXXXX'
scope: 'user repo'
cookie_secret: 'nqavJrGvRmQxWwGNptLdyUVKcBNZ2b18Guc1n_8DCfY='
service_name: 'github.dns'
service_port: 443
validate_service_name: 'api.dns'
validate_service_port: 443
match_type: 'whitelist'
match_list:
- match_rule_domain: '*.bar.com'
match_rule_path: '/headers'
match_rule_type: 'prefix'
```
#### 访问服务页面,未登陆的话进行跳转
![github_login](https://www.helloimg.com/i/2024/12/31/6773983f64b3c.png)
#### 登陆成功跳转到服务页面
配置了`pass_access_token=true`后会在`X-Forwarded-Access-Token`header头中携带access_token
![github_result](https://www.helloimg.com/i/2024/12/31/677398de64872.png)
### OIDC 流程图 ### OIDC 流程图
<p align="center"> <p align="center">
@@ -422,5 +472,4 @@ curl -X POST \
``` ```
4. 携带 Authorization 的标头对应 access_token 访问对应 API 4. 携带 Authorization 的标头对应 access_token 访问对应 API
5. 后端服务根据 access_token 获取用户授权信息并返回对应的 HTTP 响应 5. 后端服务根据 access_token 获取用户授权信息并返回对应的 HTTP 响应

View File

@@ -29,6 +29,7 @@ Plugin execution priority: `350`
| client_secret | string | The OAuth Client Secret | | | client_secret | string | The OAuth Client Secret | |
| provider | string | OAuth provider | oidc | | provider | string | OAuth provider | oidc |
| pass_authorization_header | bool | Pass OIDC IDToken to upstream via Authorization Bearer header | true | | pass_authorization_header | bool | Pass OIDC IDToken to upstream via Authorization Bearer header | true |
| pass_access_token | bool | pass OIDC Access Token to upstream via X-Forwarded-Access-Token header. | False |
| oidc_issuer_url | string | The OpenID Connect issuer URL, e.g. `"https://dev-o43xb1mz7ya7ach4.us.auth0.com"` | | | oidc_issuer_url | string | The OpenID Connect issuer URL, e.g. `"https://dev-o43xb1mz7ya7ach4.us.auth0.com"` | |
| oidc_verifier_request_timeout | uint32 | OIDC verifier discovery request timeout | 2000(ms) | | oidc_verifier_request_timeout | uint32 | OIDC verifier discovery request timeout | 2000(ms) |
| scope | string | OAuth scope specification | | | scope | string | OAuth scope specification | |
@@ -254,6 +255,54 @@ Directly login using a RAM user or click the main account login.
#### Successful Login Redirects to Service Page #### Successful Login Redirects to Service Page
![aliyun_result](https://gw.alicdn.com/imgextra/i3/O1CN015pGvi51eakt3pFS8Y_!!6000000003888-0-tps-3840-2160.jpg) ![aliyun_result](https://gw.alicdn.com/imgextra/i3/O1CN015pGvi51eakt3pFS8Y_!!6000000003888-0-tps-3840-2160.jpg)
### Github Configuration Example
#### Step 1: Configure Github OAuth App
Create a new OAuth App: https://github.com/settings/developers
#### Step 2: Higress Configure Service Source
* Create a DNS service with the source address set to github.com.
* Create a DNS service with the source address set to api.github.com (used to validate the access token in the OIDC flow).
![github_service](https://www.helloimg.com/i/2024/12/31/677398a2b34be.png)
#### Step 3: OIDC Service HTTPS Protocol
Configure Ingress for the two created DNS services by referring to Step 3 of Auth0.
#### Step 4: Wasm Plugin Configuration
```yaml
redirect_url: 'http://foo.bar.com/oauth2/callback'
provider: github
oidc_issuer_url: 'https://github.com/'
pass_access_token: true
client_id: 'XXXXXXXXXXXXXXXX'
client_secret: 'XXXXXXXXXXXXXXXX'
scope: 'user repo'
cookie_secret: 'nqavJrGvRmQxWwGNptLdyUVKcBNZ2b18Guc1n_8DCfY='
service_name: 'github.dns'
service_port: 443
validate_service_name: 'api.dns'
validate_service_port: 443
match_type: 'whitelist'
match_list:
- match_rule_domain: '*.bar.com'
match_rule_path: '/headers'
match_rule_type: 'prefix'
```
#### Access Service Page; Redirect if Not Logged In
![github_login](https://www.helloimg.com/i/2024/12/31/6773983f64b3c.png)
#### Successful Login Redirects to Service Page
With pass_access_token=true configured, the access_token will be included in the X-Forwarded-Access-Token header.
![github_result](https://www.helloimg.com/i/2024/12/31/677398de64872.png)
### OIDC Flow Diagram ### OIDC Flow Diagram
<p align="center"> <p align="center">
<img src="https://gw.alicdn.com/imgextra/i3/O1CN01TJSh9c1VwR61Q2nek_!!6000000002717-55-tps-1807-2098.svg" alt="oidc_process" width="600" /> <img src="https://gw.alicdn.com/imgextra/i3/O1CN01TJSh9c1VwR61Q2nek_!!6000000002717-55-tps-1807-2098.svg" alt="oidc_process" width="600" />

Some files were not shown because too many files have changed in this diff Show More