Compare commits

..

42 Commits

Author SHA1 Message Date
Jun
a2c2d1d521 fix: fallbackForInvalidSecret to return original secret (#1245) 2024-08-25 15:59:12 +08:00
Yang
a5a28aebf6 Add x-forwarded-xxx for ext-auth (#1244) 2024-08-23 14:49:08 +08:00
YeHaitao
1c10f36369 feat: support 360 ai model (#1243)
Co-authored-by: Kent Dong <ch3cho@qq.com>
2024-08-23 11:13:09 +08:00
韩贤涛
7054f01a36 feat: Adapt to the Qwen multimodal model generation API (#1221) 2024-08-22 18:42:16 +08:00
xingyunyang01
895f17f8d8 update: Add support for post tools, add round limits, per-round token… (#1230)
Co-authored-by: Kent Dong <ch3cho@qq.com>
2024-08-22 16:33:42 +08:00
Pxl
29fcd330d5 feat: support ai-proxy custom settings (#1219) 2024-08-22 13:59:32 +08:00
Yang Beining
0e58042fa6 Support Openai structure output api (#feat 1214) (#1217)
Co-authored-by: Kent Dong <ch3cho@qq.com>
2024-08-22 12:33:35 +08:00
brother-戎
bdbfad8a8a fix: fix up kingress controller NPE (#1235) 2024-08-22 09:59:55 +08:00
ran xuxin
4307f88645 extend ai-prompt-decorator plugin with client's geographic message from geo-ip plugin (#1228) 2024-08-20 16:14:21 +08:00
007gzs
25b085cb5e feat: ai敏感词拦截插件 (#1190) 2024-08-16 17:24:32 +08:00
urlyy
dcea483c61 Feat: Add Deepl support for plugins/ai-proxy (#1147) 2024-08-15 18:53:56 +08:00
rinfx
8fa1224cba support qwen compatible mode (#1205) 2024-08-15 18:52:49 +08:00
xingyunyang01
8f7c10ee5f feat: add ai-agent plugin (#1192) 2024-08-15 17:05:25 +08:00
澄潭
5a854b990b Update README.md 2024-08-15 09:53:02 +08:00
Jingze
dd11248e47 Update README.md (#1203) 2024-08-14 19:55:21 +08:00
mamba
ba98f3a7ad feat: 🎸 frontend-gray plugin support cdn type deploy (#1178)
Co-authored-by: Kent Dong <ch3cho@qq.com>
2024-08-14 15:41:32 +08:00
Jun
d31c978ed3 feat: add AI quota plugin (#1200) 2024-08-14 13:43:31 +08:00
Jingze
daa374d9a4 feat: support wasm-assemblyscript sdk (#1175) 2024-08-13 15:31:36 +08:00
澄潭
6b9dabb489 Update README.md 2024-08-12 19:41:10 +08:00
rinfx
6f04404edd crash bugfix (#1198) 2024-08-12 16:42:10 +08:00
韩贤涛
04a9104062 feat: support gemini ai model (#1173) 2024-08-09 09:55:40 +08:00
Se7en
564f8c770a fix: fix tracing configmap template to handle initial installation (#1191) 2024-08-09 08:29:51 +08:00
Se7en
fec2e9dfc9 feat: improve Skywalking and Zipkin integration (#1131) 2024-08-08 22:40:33 +08:00
Jingze
dc4ddb52ee fix bug of empty config plugin still start (#1189) 2024-08-08 18:04:47 +08:00
Jun
6f221ead53 feat:add service rule match for wasmplugin in control panel (#1166) 2024-08-08 18:04:33 +08:00
韩贤涛
53f8410843 feat: ext auth forward_auth endpoint_mode enhancement (#1180) 2024-08-08 18:01:51 +08:00
rinfx
a17ac9e4c6 Optimize ai-rag plugin (#1170) 2024-08-08 18:00:02 +08:00
澄潭
5e95f6f057 Update README.md 2024-08-08 17:14:18 +08:00
澄潭
94f29e56c0 Update README.md 2024-08-08 17:12:33 +08:00
澄潭
870157c576 Update README.md 2024-08-08 15:53:21 +08:00
urlyy
c78ef7011d Feat: Add Spark llm support for plugins/ai-proxy (#1139) 2024-08-08 15:16:58 +08:00
澄潭
dc0dcaaaee azure-openai support other type api (#1187) 2024-08-08 13:33:12 +08:00
EricaLiu
34f5722d93 fix: add support for nacos triple protocol (#1186) 2024-08-08 10:29:48 +08:00
澄潭
55fdddee2f optimize transformer plugin (#1183) 2024-08-08 09:46:11 +08:00
007gzs
980ffde244 Optimize WASM Rust SDK's body caching logic. (#1181) 2024-08-07 20:06:11 +08:00
澄潭
0a578c2a04 ai-proxy: support custom openai provider (#1176)
Co-authored-by: Kent Dong <ch3cho@qq.com>
2024-08-07 10:33:01 +08:00
澄潭
536a3069a8 Update README.md 2024-08-06 20:15:33 +08:00
韩贤涛
08c64ed467 fix:fix bug in ext-auth wasm plugin (#1152) 2024-08-05 11:04:31 +08:00
澄潭
cc74c0da93 replace regexp (#1169) 2024-07-31 17:48:38 +08:00
Kent Dong
210b97b06b fix: Use the official tinygo package to build Wasm go plugin builder (#1161) 2024-07-29 16:05:23 +08:00
007gzs
bccfbde621 fix PluginHttpWrapper 中 Context的回调未代理 . request-block case_sensitive 逻辑错误 (#1146) 2024-07-27 10:25:14 +08:00
澄潭
f1c6e78047 Update Makefile.core.mk 2024-07-26 14:06:38 +08:00
143 changed files with 74557 additions and 775 deletions

View File

@@ -49,6 +49,11 @@ jobs:
with:
go-version: 1.19
- name: Setup Rust
uses: actions-rs/toolchain@v1
with:
toolchain: stable
if: matrix.wasmPluginType == 'RUST'
- name: Setup Golang Caches
uses: actions/cache@v4
with:

View File

@@ -177,8 +177,8 @@ install: pre-install
cd helm/higress; helm dependency build
helm install higress helm/higress -n higress-system --create-namespace --set 'global.local=true'
ENVOY_LATEST_IMAGE_TAG ?= sha-63539ca
ISTIO_LATEST_IMAGE_TAG ?= sha-63539ca
ENVOY_LATEST_IMAGE_TAG ?= sha-59acb61
ISTIO_LATEST_IMAGE_TAG ?= sha-59acb61
install-dev: pre-install
helm install higress helm/core -n higress-system --create-namespace --set 'controller.tag=$(TAG)' --set 'gateway.replicas=1' --set 'pilot.tag=$(ISTIO_LATEST_IMAGE_TAG)' --set 'gateway.tag=$(ENVOY_LATEST_IMAGE_TAG)' --set 'global.local=true'

View File

@@ -1,17 +1,18 @@
<h1 align="center">
<img src="https://img.alicdn.com/imgextra/i2/O1CN01NwxLDd20nxfGBjxmZ_!!6000000006895-2-tps-960-290.png" alt="Higress" width="240" height="72.5">
<br>
Cloud Native API Gateway
AI Gateway
</h1>
<h4 align="center"> AI Native API Gateway </h4>
[![Build Status](https://github.com/alibaba/higress/actions/workflows/build-and-test.yaml/badge.svg?branch=main)](https://github.com/alibaba/higress/actions)
[![license](https://img.shields.io/github/license/alibaba/higress.svg)](https://www.apache.org/licenses/LICENSE-2.0.html)
[**官网**](https://higress.io/) &nbsp; |
&nbsp; [**文档**](https://higress.io/zh-cn/docs/overview/what-is-higress) &nbsp; |
&nbsp; [**博客**](https://higress.io/zh-cn/blog) &nbsp; |
&nbsp; [**开发指引**](https://higress.io/zh-cn/docs/developers/developers_dev) &nbsp; |
&nbsp; [**Higress 企业版**](https://www.aliyun.com/product/aliware/mse?spm=higress-website.topbar.0.0.0) &nbsp;
&nbsp; [**文档**](https://higress.io/docs/latest/user/quickstart/) &nbsp; |
&nbsp; [**博客**](https://higress.io/blog/) &nbsp; |
&nbsp; [**开发指引**](https://higress.io/docs/latest/dev/architecture/) &nbsp; |
&nbsp; [**AI插件**](https://higress.io/plugin/) &nbsp;
<p>
@@ -19,21 +20,54 @@
</p>
Higress 是基于阿里内部两年多的 Envoy Gateway 实践沉淀,以开源 [Istio](https://github.com/istio/istio) 与 [Envoy](https://github.com/envoyproxy/envoy) 为核心构建的云原生 API 网关。Higress 实现了安全防护网关、流量网关、微服务网关三层网关合一,可以显著降低网关的部署和运维成本。
Higress 是基于阿里内部多的 Envoy Gateway 实践沉淀,以开源 [Istio](https://github.com/istio/istio) 与 [Envoy](https://github.com/envoyproxy/envoy) 为核心构建的云原生 API 网关。
Higress 在阿里内部作为 AI 网关,承载了通义千问 APP、百炼大模型 API、机器学习 PAI 平台等 AI 业务的流量。
Higress 能够用统一的协议对接国内外所有 LLM 模型厂商,同时具备丰富的 AI 可观测、多模型负载均衡/fallback、AI token 流控、AI 缓存等能力:
![](https://img.alicdn.com/imgextra/i1/O1CN01fNnhCp1cV8mYPRFeS_!!6000000003605-0-tps-1080-608.jpg)
![arch](https://img.alicdn.com/imgextra/i1/O1CN01iO9ph825juHbOIg75_!!6000000007563-2-tps-2483-2024.png)
## Summary
- [**快速开始**](#快速开始)
- [**功能展示**](#功能展示)
- [**使用场景**](#使用场景)
- [**核心优势**](#核心优势)
- [**Quick Start**](https://higress.io/zh-cn/docs/user/quickstart)
- [**社区**](#社区)
## 快速开始
Higress 只需 Docker 即可启动,方便个人开发者在本地搭建学习,或者用于搭建简易站点:
```bash
# 创建一个工作目录
mkdir higress; cd higress
# 启动 higress配置文件会写到工作目录下
docker run -d --rm --name higress-ai -v ${PWD}:/data \
-p 8001:8001 -p 8080:8080 -p 8443:8443 \
higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/all-in-one:latest
```
监听端口说明如下:
- 8001 端口Higress UI 控制台入口
- 8080 端口:网关 HTTP 协议入口
- 8443 端口:网关 HTTPS 协议入口
**Higress 的所有 Docker 镜像都一直使用自己独享的仓库,不受 Docker Hub 境内不可访问的影响**
K8s 下使用 Helm 部署等其他安装方式可以参考官网 [Quick Start 文档](https://higress.io/docs/latest/user/quickstart/)。
## 使用场景
- **AI 网关**:
Higress 提供了一站式的 AI 插件集,可以增强依赖 AI 能力业务的稳定性、灵活性、可观测性,使得业务与 AI 的集成更加便捷和高效。
- **Kubernetes Ingress 网关**:
Higress 可以作为 K8s 集群的 Ingress 入口网关, 并且兼容了大量 K8s Nginx Ingress 的注解,可以从 K8s Nginx Ingress 快速平滑迁移到 Higress。
@@ -56,27 +90,36 @@ Higress 是基于阿里内部两年多的 Envoy Gateway 实践沉淀,以开源
脱胎于阿里巴巴2年多生产验证的内部产品支持每秒请求量达数十万级的大规模场景。
彻底摆脱 reload 引起的流量抖动,配置变更毫秒级生效且业务无感。
- **平滑演进**
彻底摆脱 Nginx reload 引起的流量抖动,配置变更毫秒级生效且业务无感。对 AI 业务等长连接场景特别友好。
支持 Nacos/Zookeeper/Eureka 等多种注册中心,可以不依赖 K8s Service 进行服务发现,支持非容器架构平滑演进到云原生架构。
- **流式处理**
支持从 Nginx Ingress Controller 平滑迁移,支持平滑过渡到 Gateway API支持业务架构平滑演进到 ServiceMesh
支持真正的完全流式处理请求/响应 BodyWasm 插件很方便地自定义处理 SSE Server-Sent Events等流式协议的报文
- **兼收并蓄**
兼容 Nginx Ingress Annotation 80%+ 的使用场景,且提供功能更丰富的 Higress Annotation 注解。
兼容 Ingress API/Gateway API/Istio API可以组合多种 CRD 实现流量精细化管理。
在 AI 业务等大带宽场景下,可以显著降低内存开销。
- **便于扩展**
提供 Wasm、Lua、进程外三种插件扩展机制支持多语言编写插件生效粒度支持全局级、域名级路由级
提供丰富的官方插件库,涵盖 AI、流量管理、安全防护等常用功能满足90%以上的业务场景需求
主打 Wasm 插件扩展,通过沙箱隔离确保内存安全,支持多种编程语言,允许插件版本独立升级,实现流量无损热更新网关逻辑。
- **安全易用**
基于 Ingress API 和 Gateway API 标准,提供开箱即用的 UI 控制台WAF 防护插件、IP/Cookie CC 防护插件开箱即用。
支持对接 Let's Encrypt 自动签发和续签免费证书,并且可以脱离 K8s 部署,一行 Docker 命令即可启动,方便个人开发者使用。
插件支持热更新,变更插件逻辑和配置都对流量无损。
## 功能展示
### AI 网关 Demo 展示
[从 OpenAI 到其他大模型30 秒完成迁移
](https://www.bilibili.com/video/BV1dT421a7w7/?spm_id_from=333.788.recommend_more_video.14)
### Higress UI 控制台
- **丰富的可观测**

View File

@@ -301,6 +301,7 @@ type MatchRule struct {
Domain []string `protobuf:"bytes,2,rep,name=domain,proto3" json:"domain,omitempty"`
Config *types.Struct `protobuf:"bytes,3,opt,name=config,proto3" json:"config,omitempty"`
ConfigDisable bool `protobuf:"varint,4,opt,name=config_disable,json=configDisable,proto3" json:"config_disable,omitempty"`
Service []string `protobuf:"bytes,5,rep,name=service,proto3" json:"service,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
@@ -367,6 +368,13 @@ func (m *MatchRule) GetConfigDisable() bool {
return false
}
func (m *MatchRule) GetService() []string {
if m != nil {
return m.Service
}
return nil
}
func init() {
proto.RegisterEnum("higress.extensions.v1alpha1.PluginPhase", PluginPhase_name, PluginPhase_value)
proto.RegisterEnum("higress.extensions.v1alpha1.PullPolicy", PullPolicy_name, PullPolicy_value)
@@ -377,46 +385,47 @@ func init() {
func init() { proto.RegisterFile("extensions/v1alpha1/wasm.proto", fileDescriptor_4d60b240916c4e18) }
var fileDescriptor_4d60b240916c4e18 = []byte{
// 619 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x7c, 0x94, 0xdd, 0x4e, 0xdb, 0x4c,
0x10, 0x86, 0x71, 0x02, 0x81, 0x4c, 0x80, 0xcf, 0xac, 0xbe, 0xd2, 0x15, 0x54, 0x69, 0x84, 0xd4,
0xd6, 0xe5, 0xc0, 0x16, 0xa1, 0x3f, 0x27, 0x15, 0x6a, 0x80, 0xb4, 0x44, 0x6d, 0x53, 0xcb, 0x86,
0x56, 0xe5, 0xc4, 0xda, 0x98, 0x8d, 0xb3, 0xea, 0xfa, 0x47, 0xde, 0x35, 0x34, 0x17, 0xd2, 0x7b,
0xea, 0x61, 0x2f, 0xa1, 0xe2, 0x2e, 0x7a, 0x56, 0x65, 0x6d, 0x43, 0x42, 0xab, 0x9c, 0xed, 0xce,
0x3c, 0x33, 0xf3, 0xbe, 0xe3, 0x95, 0xa1, 0x49, 0xbf, 0x49, 0x1a, 0x09, 0x16, 0x47, 0xc2, 0xba,
0xdc, 0x23, 0x3c, 0x19, 0x91, 0x3d, 0xeb, 0x8a, 0x88, 0xd0, 0x4c, 0xd2, 0x58, 0xc6, 0x68, 0x7b,
0xc4, 0x82, 0x94, 0x0a, 0x61, 0xde, 0x72, 0x66, 0xc9, 0x6d, 0x35, 0x83, 0x38, 0x0e, 0x38, 0xb5,
0x14, 0x3a, 0xc8, 0x86, 0xd6, 0x55, 0x4a, 0x92, 0x84, 0xa6, 0x22, 0x2f, 0xde, 0x7a, 0x70, 0x37,
0x2f, 0x64, 0x9a, 0xf9, 0x32, 0xcf, 0xee, 0xfc, 0x5e, 0x04, 0xf8, 0x4c, 0x44, 0x68, 0xf3, 0x2c,
0x60, 0x11, 0xd2, 0xa1, 0x9a, 0xa5, 0x1c, 0x57, 0x5a, 0x9a, 0x51, 0x77, 0x26, 0x47, 0xb4, 0x09,
0x35, 0x31, 0x22, 0xed, 0xe7, 0x2f, 0x70, 0x55, 0x05, 0x8b, 0x1b, 0x72, 0x61, 0x83, 0x85, 0x24,
0xa0, 0x5e, 0x92, 0x71, 0xee, 0x25, 0x31, 0x67, 0xfe, 0x18, 0x2f, 0xb6, 0x34, 0x63, 0xbd, 0xfd,
0xc4, 0x9c, 0xa3, 0xd7, 0xb4, 0x33, 0xce, 0x6d, 0x85, 0x3b, 0xff, 0xa9, 0x0e, 0xb7, 0x01, 0xb4,
0x3b, 0xd3, 0x54, 0x50, 0x3f, 0xa5, 0x12, 0x2f, 0xa9, 0xb9, 0xb7, 0xac, 0xab, 0xc2, 0xe8, 0x29,
0xe8, 0x97, 0x34, 0x65, 0x43, 0xe6, 0x13, 0xc9, 0xe2, 0xc8, 0xfb, 0x4a, 0xc7, 0xb8, 0x96, 0xa3,
0xd3, 0xf1, 0x77, 0x74, 0x8c, 0x5e, 0xc1, 0x5a, 0xa2, 0xfc, 0x79, 0x7e, 0x1c, 0x0d, 0x59, 0x80,
0x97, 0x5b, 0x9a, 0xd1, 0x68, 0xdf, 0x37, 0xf3, 0xd5, 0x98, 0xe5, 0x6a, 0x4c, 0x57, 0xad, 0xc6,
0x59, 0xcd, 0xe9, 0x23, 0x05, 0xa3, 0x87, 0xd0, 0x28, 0xaa, 0x23, 0x12, 0x52, 0xbc, 0xa2, 0x66,
0x40, 0x1e, 0xea, 0x93, 0x90, 0xa2, 0x03, 0x58, 0x4a, 0x46, 0x44, 0x50, 0x5c, 0x57, 0xf6, 0x8d,
0xf9, 0xf6, 0x55, 0x9d, 0x3d, 0xe1, 0x9d, 0xbc, 0x0c, 0xbd, 0x84, 0x95, 0x24, 0x65, 0x71, 0xca,
0xe4, 0x18, 0x83, 0x52, 0xb6, 0xfd, 0x97, 0xb2, 0x5e, 0x24, 0xf7, 0xdb, 0x9f, 0x08, 0xcf, 0xa8,
0x73, 0x03, 0xa3, 0x03, 0x58, 0xbf, 0xa0, 0x43, 0x92, 0x71, 0x59, 0x1a, 0xa3, 0xf3, 0x8d, 0xad,
0x15, 0x78, 0xe1, 0xec, 0x2d, 0x34, 0x42, 0x22, 0xfd, 0x91, 0x97, 0x66, 0x9c, 0x0a, 0x3c, 0x6c,
0x55, 0x8d, 0x46, 0xfb, 0xf1, 0x5c, 0xf9, 0x1f, 0x26, 0xbc, 0x93, 0x71, 0xea, 0x40, 0x58, 0x1e,
0x05, 0x7a, 0x06, 0x9b, 0xb3, 0x42, 0xbc, 0x0b, 0x26, 0xc8, 0x80, 0x53, 0x1c, 0xb4, 0x34, 0x63,
0xc5, 0xf9, 0x7f, 0x66, 0xee, 0x71, 0x9e, 0xdb, 0xf9, 0xae, 0x41, 0xfd, 0xa6, 0x1f, 0xc2, 0xb0,
0xcc, 0x22, 0x35, 0x18, 0x6b, 0xad, 0xaa, 0x51, 0x77, 0xca, 0xeb, 0xe4, 0x09, 0x5e, 0xc4, 0x21,
0x61, 0x11, 0xae, 0xa8, 0x44, 0x71, 0x43, 0x16, 0xd4, 0x0a, 0xdb, 0xd5, 0xf9, 0xb6, 0x0b, 0x0c,
0x3d, 0x82, 0xf5, 0x3b, 0xf2, 0x16, 0x95, 0xbc, 0x35, 0x7f, 0x5a, 0xd7, 0x6e, 0x17, 0x1a, 0x53,
0x5f, 0x09, 0xdd, 0x83, 0x8d, 0xb3, 0xbe, 0x6b, 0x77, 0x8f, 0x7a, 0x6f, 0x7a, 0xdd, 0x63, 0xcf,
0x3e, 0xe9, 0xb8, 0x5d, 0x7d, 0x01, 0xd5, 0x61, 0xa9, 0x73, 0x76, 0x7a, 0xd2, 0xd7, 0xb5, 0xf2,
0x78, 0xae, 0x57, 0x26, 0x47, 0xf7, 0xb4, 0x73, 0xea, 0xea, 0xd5, 0xdd, 0x43, 0x80, 0xa9, 0xa7,
0xbd, 0x09, 0x68, 0xa6, 0xcb, 0xc7, 0xf7, 0xbd, 0xa3, 0x2f, 0xfa, 0x02, 0xd2, 0x61, 0xb5, 0x37,
0xec, 0xc7, 0xd2, 0x4e, 0xa9, 0xa0, 0x91, 0xd4, 0x35, 0x04, 0x50, 0xeb, 0xf0, 0x2b, 0x32, 0x16,
0x7a, 0xe5, 0xf0, 0xf5, 0x8f, 0xeb, 0xa6, 0xf6, 0xf3, 0xba, 0xa9, 0xfd, 0xba, 0x6e, 0x6a, 0xe7,
0xed, 0x80, 0xc9, 0x51, 0x36, 0x30, 0xfd, 0x38, 0xb4, 0x08, 0x67, 0x03, 0x32, 0x20, 0x56, 0xf1,
0xb1, 0x2c, 0x92, 0x30, 0xeb, 0x1f, 0xbf, 0x91, 0x41, 0x4d, 0x2d, 0x63, 0xff, 0x4f, 0x00, 0x00,
0x00, 0xff, 0xff, 0xb9, 0xf2, 0x67, 0xbe, 0x64, 0x04, 0x00, 0x00,
// 631 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x7c, 0x94, 0xdd, 0x6e, 0xd3, 0x4c,
0x10, 0x86, 0xeb, 0xa4, 0x49, 0x9b, 0x49, 0xdb, 0xcf, 0x5d, 0x7d, 0x94, 0x55, 0x8b, 0x42, 0x54,
0x09, 0x30, 0x3d, 0xb0, 0xd5, 0x94, 0x9f, 0x13, 0x54, 0x91, 0xb6, 0x81, 0x46, 0x40, 0xb0, 0xec,
0x16, 0x44, 0x4f, 0xac, 0x8d, 0xbb, 0x71, 0x56, 0xac, 0x7f, 0xe4, 0x5d, 0xb7, 0xe4, 0xaa, 0xb8,
0x0d, 0x0e, 0xb9, 0x04, 0xd4, 0xbb, 0xe0, 0x0c, 0x65, 0xed, 0x34, 0x49, 0x41, 0x39, 0xdb, 0x9d,
0x79, 0x66, 0xe6, 0x7d, 0xc7, 0x2b, 0x43, 0x83, 0x7e, 0x93, 0x34, 0x12, 0x2c, 0x8e, 0x84, 0x75,
0xb5, 0x4f, 0x78, 0x32, 0x24, 0xfb, 0xd6, 0x35, 0x11, 0xa1, 0x99, 0xa4, 0xb1, 0x8c, 0xd1, 0xce,
0x90, 0x05, 0x29, 0x15, 0xc2, 0x9c, 0x72, 0xe6, 0x84, 0xdb, 0x6e, 0x04, 0x71, 0x1c, 0x70, 0x6a,
0x29, 0xb4, 0x9f, 0x0d, 0xac, 0xeb, 0x94, 0x24, 0x09, 0x4d, 0x45, 0x5e, 0xbc, 0xfd, 0xe0, 0x6e,
0x5e, 0xc8, 0x34, 0xf3, 0x65, 0x9e, 0xdd, 0xfd, 0xbd, 0x0c, 0xf0, 0x99, 0x88, 0xd0, 0xe6, 0x59,
0xc0, 0x22, 0xa4, 0x43, 0x39, 0x4b, 0x39, 0x2e, 0x35, 0x35, 0xa3, 0xe6, 0x8c, 0x8f, 0x68, 0x0b,
0xaa, 0x62, 0x48, 0x5a, 0xcf, 0x5f, 0xe0, 0xb2, 0x0a, 0x16, 0x37, 0xe4, 0xc2, 0x26, 0x0b, 0x49,
0x40, 0xbd, 0x24, 0xe3, 0xdc, 0x4b, 0x62, 0xce, 0xfc, 0x11, 0x5e, 0x6e, 0x6a, 0xc6, 0x46, 0xeb,
0x89, 0xb9, 0x40, 0xaf, 0x69, 0x67, 0x9c, 0xdb, 0x0a, 0x77, 0xfe, 0x53, 0x1d, 0xa6, 0x01, 0xb4,
0x37, 0xd7, 0x54, 0x50, 0x3f, 0xa5, 0x12, 0x57, 0xd4, 0xdc, 0x29, 0xeb, 0xaa, 0x30, 0x7a, 0x0a,
0xfa, 0x15, 0x4d, 0xd9, 0x80, 0xf9, 0x44, 0xb2, 0x38, 0xf2, 0xbe, 0xd2, 0x11, 0xae, 0xe6, 0xe8,
0x6c, 0xfc, 0x1d, 0x1d, 0xa1, 0x57, 0xb0, 0x9e, 0x28, 0x7f, 0x9e, 0x1f, 0x47, 0x03, 0x16, 0xe0,
0x95, 0xa6, 0x66, 0xd4, 0x5b, 0xf7, 0xcd, 0x7c, 0x35, 0xe6, 0x64, 0x35, 0xa6, 0xab, 0x56, 0xe3,
0xac, 0xe5, 0xf4, 0xb1, 0x82, 0xd1, 0x43, 0xa8, 0x17, 0xd5, 0x11, 0x09, 0x29, 0x5e, 0x55, 0x33,
0x20, 0x0f, 0xf5, 0x48, 0x48, 0xd1, 0x21, 0x54, 0x92, 0x21, 0x11, 0x14, 0xd7, 0x94, 0x7d, 0x63,
0xb1, 0x7d, 0x55, 0x67, 0x8f, 0x79, 0x27, 0x2f, 0x43, 0x2f, 0x61, 0x35, 0x49, 0x59, 0x9c, 0x32,
0x39, 0xc2, 0xa0, 0x94, 0xed, 0xfc, 0xa5, 0xac, 0x1b, 0xc9, 0x83, 0xd6, 0x27, 0xc2, 0x33, 0xea,
0xdc, 0xc2, 0xe8, 0x10, 0x36, 0x2e, 0xe9, 0x80, 0x64, 0x5c, 0x4e, 0x8c, 0xd1, 0xc5, 0xc6, 0xd6,
0x0b, 0xbc, 0x70, 0xf6, 0x16, 0xea, 0x21, 0x91, 0xfe, 0xd0, 0x4b, 0x33, 0x4e, 0x05, 0x1e, 0x34,
0xcb, 0x46, 0xbd, 0xf5, 0x78, 0xa1, 0xfc, 0x0f, 0x63, 0xde, 0xc9, 0x38, 0x75, 0x20, 0x9c, 0x1c,
0x05, 0x7a, 0x06, 0x5b, 0xf3, 0x42, 0xbc, 0x4b, 0x26, 0x48, 0x9f, 0x53, 0x1c, 0x34, 0x35, 0x63,
0xd5, 0xf9, 0x7f, 0x6e, 0xee, 0x49, 0x9e, 0xdb, 0xfd, 0xae, 0x41, 0xed, 0xb6, 0x1f, 0xc2, 0xb0,
0xc2, 0x22, 0x35, 0x18, 0x6b, 0xcd, 0xb2, 0x51, 0x73, 0x26, 0xd7, 0xf1, 0x13, 0xbc, 0x8c, 0x43,
0xc2, 0x22, 0x5c, 0x52, 0x89, 0xe2, 0x86, 0x2c, 0xa8, 0x16, 0xb6, 0xcb, 0x8b, 0x6d, 0x17, 0x18,
0x7a, 0x04, 0x1b, 0x77, 0xe4, 0x2d, 0x2b, 0x79, 0xeb, 0xfe, 0xac, 0xae, 0xb1, 0x12, 0x41, 0xd3,
0x2b, 0xe6, 0x53, 0x5c, 0xc9, 0x95, 0x14, 0xd7, 0xbd, 0x0e, 0xd4, 0x67, 0xbe, 0x1f, 0xba, 0x07,
0x9b, 0xe7, 0x3d, 0xd7, 0xee, 0x1c, 0x77, 0xdf, 0x74, 0x3b, 0x27, 0x9e, 0x7d, 0xda, 0x76, 0x3b,
0xfa, 0x12, 0xaa, 0x41, 0xa5, 0x7d, 0x7e, 0x76, 0xda, 0xd3, 0xb5, 0xc9, 0xf1, 0x42, 0x2f, 0x8d,
0x8f, 0xee, 0x59, 0xfb, 0xcc, 0xd5, 0xcb, 0x7b, 0x47, 0x00, 0x33, 0x8f, 0x7e, 0x0b, 0xd0, 0x5c,
0x97, 0x8f, 0xef, 0xbb, 0xc7, 0x5f, 0xf4, 0x25, 0xa4, 0xc3, 0x5a, 0x77, 0xd0, 0x8b, 0xa5, 0x9d,
0x52, 0x41, 0x23, 0xa9, 0x6b, 0x08, 0xa0, 0xda, 0xe6, 0xd7, 0x64, 0x24, 0xf4, 0xd2, 0xd1, 0xeb,
0x1f, 0x37, 0x0d, 0xed, 0xe7, 0x4d, 0x43, 0xfb, 0x75, 0xd3, 0xd0, 0x2e, 0x5a, 0x01, 0x93, 0xc3,
0xac, 0x6f, 0xfa, 0x71, 0x68, 0x11, 0xce, 0xfa, 0xa4, 0x4f, 0xac, 0xe2, 0x33, 0x5a, 0x24, 0x61,
0xd6, 0x3f, 0x7e, 0x30, 0xfd, 0xaa, 0x5a, 0xd3, 0xc1, 0x9f, 0x00, 0x00, 0x00, 0xff, 0xff, 0x0b,
0x3c, 0xc3, 0xcf, 0x7e, 0x04, 0x00, 0x00,
}
func (m *WasmPlugin) Marshal() (dAtA []byte, err error) {
@@ -581,6 +590,15 @@ func (m *MatchRule) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i -= len(m.XXX_unrecognized)
copy(dAtA[i:], m.XXX_unrecognized)
}
if len(m.Service) > 0 {
for iNdEx := len(m.Service) - 1; iNdEx >= 0; iNdEx-- {
i -= len(m.Service[iNdEx])
copy(dAtA[i:], m.Service[iNdEx])
i = encodeVarintWasm(dAtA, i, uint64(len(m.Service[iNdEx])))
i--
dAtA[i] = 0x2a
}
}
if m.ConfigDisable {
i--
if m.ConfigDisable {
@@ -719,6 +737,12 @@ func (m *MatchRule) Size() (n int) {
if m.ConfigDisable {
n += 2
}
if len(m.Service) > 0 {
for _, s := range m.Service {
l = len(s)
n += 1 + l + sovWasm(uint64(l))
}
}
if m.XXX_unrecognized != nil {
n += len(m.XXX_unrecognized)
}
@@ -1291,6 +1315,38 @@ func (m *MatchRule) Unmarshal(dAtA []byte) error {
}
}
m.ConfigDisable = bool(v != 0)
case 5:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Service", wireType)
}
var stringLen uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowWasm
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
stringLen |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
intStringLen := int(stringLen)
if intStringLen < 0 {
return ErrInvalidLengthWasm
}
postIndex := iNdEx + intStringLen
if postIndex < 0 {
return ErrInvalidLengthWasm
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Service = append(m.Service, string(dAtA[iNdEx:postIndex]))
iNdEx = postIndex
default:
iNdEx = preIndex
skippy, err := skipWasm(dAtA[iNdEx:])

View File

@@ -114,6 +114,7 @@ message MatchRule {
repeated string domain = 2;
google.protobuf.Struct config = 3;
bool config_disable = 4;
repeated string service = 5;
}
// The phase in the filter chain where the plugin will be injected.

View File

@@ -64,6 +64,10 @@ spec:
items:
type: string
type: array
service:
items:
type: string
type: array
type: object
type: array
phase:

View File

@@ -97,7 +97,7 @@ higress: {{ include "controller.name" . }}
{{- end }}
{{- define "skywalking.enabled" -}}
{{- if and .Values.skywalking.enabled .Values.skywalking.service.address }}
{{- if and (hasKey .Values "tracing") .Values.tracing.enable (hasKey .Values.tracing "skywalking") .Values.tracing.skywalking.service }}
true
{{- end }}
{{- end }}

View File

@@ -46,10 +46,6 @@
address: {{ .Values.global.tracer.lightstep.address }}
# Access Token used to communicate with the Satellite pool
accessToken: {{ .Values.global.tracer.lightstep.accessToken }}
{{- else if eq .Values.global.proxy.tracer "zipkin" }}
zipkin:
# Address of the Zipkin collector
address: {{ .Values.global.tracer.zipkin.address | default (print "zipkin." .Release.Namespace ":9411") }}
{{- else if eq .Values.global.proxy.tracer "datadog" }}
datadog:
# Address of the Datadog Agent
@@ -109,7 +105,17 @@ metadata:
labels:
{{- include "gateway.labels" . | nindent 4 }}
data:
higress: |-
{{- $existingConfig := lookup "v1" "ConfigMap" .Release.Namespace "higress-config" }}
{{- $existingData := dict }}
{{- if $existingConfig }}
{{- $existingData = index $existingConfig.data "higress" | default "{}" | fromYaml }}
{{- end }}
{{- $newData := dict }}
{{- if and (hasKey .Values "tracing") .Values.tracing.enable }}
{{- $_ := set $newData "tracing" .Values.tracing }}
{{- end }}
{{- toYaml (merge $existingData $newData) | nindent 4 }}
# Configuration file for the mesh networks to be used by the Split Horizon EDS.
meshNetworks: |-
{{- if .Values.global.meshNetworks }}
@@ -170,8 +176,8 @@ data:
"endpoint": {
"address": {
"socket_address": {
"address": "{{ .Values.skywalking.service.address }}",
"port_value": "{{ .Values.skywalking.service.port }}"
"address": "{{ .Values.tracing.skywalking.service }}",
"port_value": "{{ .Values.tracing.skywalking.port }}"
}
}
}

View File

@@ -178,9 +178,9 @@ global:
# Default port for Pilot agent health checks. A value of 0 will disable health checking.
statusPort: 15020
# Specify which tracer to use. One of: zipkin, lightstep, datadog, stackdriver.
# Specify which tracer to use. One of: lightstep, datadog, stackdriver.
# If using stackdriver tracer outside GCP, set env GOOGLE_APPLICATION_CREDENTIALS to the GCP credential file.
tracer: "zipkin"
tracer: ""
# Controls if sidecar is injected at the front of the container list and blocks the start of the other containers until the proxy is ready
holdApplicationUntilProxyStarts: false
@@ -330,12 +330,8 @@ global:
maxNumberOfAnnotations: 200
# The global default max number of attributes per span.
maxNumberOfAttributes: 200
zipkin:
# Host:Port for reporting trace data in zipkin format. If not specified, will default to
# zipkin service (port 9411) in the same namespace as the other istio components.
address: ""
# Use the Mesh Control Protocol (MCP) for configuring Istiod. Requires an MCP source.
useMCP: false
# Observability (o11y) configurations
@@ -668,9 +664,15 @@ pilot:
podLabels: {}
# Skywalking config settings
skywalking:
enabled: false
service:
address: ~
port: 11800
# Tracing config settings
tracing:
enable: false
sampling: 100
timeout: 500
skywalking:
# access_token: ""
service: ""
port: 11800
# zipkin:
# service: ""
# port: 9411

View File

@@ -32,7 +32,7 @@ func ParseProtocol(s string) Protocol {
return TCP
case "http":
return HTTP
case "grpc":
case "grpc", "triple", "tri":
return GRPC
case "dubbo":
return Dubbo

View File

@@ -841,6 +841,7 @@ func (m *IngressConfig) convertIstioWasmPlugin(obj *higressext.WasmPlugin) (*ext
StructValue: rule.Config,
}
var matchItems []*types.Value
// match ingress
for _, ing := range rule.Ingress {
matchItems = append(matchItems, &types.Value{
Kind: &types.Value_StringValue{
@@ -861,6 +862,7 @@ func (m *IngressConfig) convertIstioWasmPlugin(obj *higressext.WasmPlugin) (*ext
})
continue
}
// match domain
for _, domain := range rule.Domain {
matchItems = append(matchItems, &types.Value{
Kind: &types.Value_StringValue{
@@ -868,10 +870,31 @@ func (m *IngressConfig) convertIstioWasmPlugin(obj *higressext.WasmPlugin) (*ext
},
})
}
if len(matchItems) > 0 {
v.StructValue.Fields["_match_domain_"] = &types.Value{
Kind: &types.Value_ListValue{
ListValue: &types.ListValue{
Values: matchItems,
},
},
}
ruleValues = append(ruleValues, &types.Value{
Kind: v,
})
continue
}
// match service
for _, service := range rule.Service {
matchItems = append(matchItems, &types.Value{
Kind: &types.Value_StringValue{
StringValue: service,
},
})
}
if len(matchItems) == 0 {
return nil, fmt.Errorf("invalid match rule has no match condition, rule:%v", rule)
}
v.StructValue.Fields["_match_domain_"] = &types.Value{
v.StructValue.Fields["_match_service_"] = &types.Value{
Kind: &types.Value_ListValue{
ListValue: &types.ListValue{
Values: matchItems,

View File

@@ -431,11 +431,14 @@ func (c *controller) ConvertGateway(convertOptions *common.ConvertOptions, wrapp
if err != nil {
if k8serrors.IsNotFound(err) {
// If there is no matching secret, try to get it from configmap.
secretName = httpsCredentialConfig.MatchSecretNameByDomain(rule.Host)
secretNamespace = c.options.SystemNamespace
namespace, secret := cert.ParseTLSSecret(secretName)
if namespace != "" {
secretNamespace = namespace
matchSecretName := httpsCredentialConfig.MatchSecretNameByDomain(rule.Host)
if matchSecretName != "" {
namespace, secret := cert.ParseTLSSecret(matchSecretName)
if namespace == "" {
secretNamespace = c.options.SystemNamespace
} else {
secretNamespace = namespace
}
secretName = secret
}
}

View File

@@ -417,11 +417,14 @@ func (c *controller) ConvertGateway(convertOptions *common.ConvertOptions, wrapp
if err != nil {
if k8serrors.IsNotFound(err) {
// If there is no matching secret, try to get it from configmap.
secretName = httpsCredentialConfig.MatchSecretNameByDomain(rule.Host)
secretNamespace = c.options.SystemNamespace
namespace, secret := cert.ParseTLSSecret(secretName)
if namespace != "" {
secretNamespace = namespace
matchSecretName := httpsCredentialConfig.MatchSecretNameByDomain(rule.Host)
if matchSecretName != "" {
namespace, secret := cert.ParseTLSSecret(matchSecretName)
if namespace == "" {
secretNamespace = c.options.SystemNamespace
} else {
secretNamespace = namespace
}
secretName = secret
}
}

View File

@@ -163,7 +163,6 @@ func (c *controller) processNextWorkItem() bool {
func (c *controller) onEvent(namespacedName types.NamespacedName) error {
event := model.EventUpdate
ing, err := c.ingressLister.Ingresses(namespacedName.Namespace).Get(namespacedName.Name)
ing.Status.InitializeConditions()
if err != nil {
if kerrors.IsNotFound(err) {
event = model.EventDelete
@@ -181,6 +180,8 @@ func (c *controller) onEvent(namespacedName types.NamespacedName) error {
return nil
}
ing.Status.InitializeConditions()
// we should check need process only when event is not delete,
// if it is delete event, and previously processed, we need to process too.
if event != model.EventDelete {

View File

@@ -0,0 +1,53 @@
## 介绍
此 SDK 用于使用 AssemblyScript 语言开发 Higress 的 Wasm 插件。
### 如何使用SDK
创建一个新的 AssemblyScript 项目。
```
npm init
npm install --save-dev assemblyscript
npx asinit .
```
在asconfig.json文件中作为传递给asc编译器的选项之一包含"use": "abort=abort_proc_exit"。
```
{
"options": {
"use": "abort=abort_proc_exit"
}
}
```
`"@higress/wasm-assemblyscript": "^0.0.4"`添加到你的依赖项中,然后运行`npm install`
### 本地构建
```
npm run asbuild
```
构建结果将在`build`文件夹中。其中,`debug.wasm``release.wasm`是已编译的文件,在生产环境中建议使用`release.wasm`
注:如果需要插件带有 name section 信息需要带上`"debug": true`,编译参数解释详见[using-the-compiler](https://www.assemblyscript.org/compiler.html#using-the-compiler)。
```json
"release": {
"outFile": "build/release.wasm",
"textFile": "build/release.wat",
"sourceMap": true,
"optimizeLevel": 3,
"shrinkLevel": 0,
"converge": false,
"noAssert": false,
"debug": true
}
```
### AssemblyScript 限制
此 SDK 使用的 AssemblyScript 版本为`0.27.29`,参考[AssemblyScript Status](https://www.assemblyscript.org/status.html)该版本尚未支持闭包、异常、迭代器等特性并且JSON正则表达式等功能还尚未在标准库中实现暂时需要使用社区提供的实现。

View File

@@ -0,0 +1,23 @@
{
"targets": {
"debug": {
"outFile": "build/debug.wasm",
"textFile": "build/debug.wat",
"sourceMap": true,
"debug": true
},
"release": {
"outFile": "build/release.wasm",
"textFile": "build/release.wat",
"sourceMap": true,
"optimizeLevel": 3,
"shrinkLevel": 0,
"converge": false,
"noAssert": false
}
},
"options": {
"bindings": "esm",
"use": "abort=abort_proc_exit"
}
}

View File

@@ -0,0 +1,214 @@
import {
log,
LogLevelValues,
get_property,
WasmResultValues,
} from "@higress/proxy-wasm-assemblyscript-sdk/assembly";
import { getRequestHost } from "./request_wrapper";
export abstract class Cluster {
abstract clusterName(): string;
abstract hostName(): string;
}
export class RouteCluster extends Cluster {
host: string;
constructor(host: string = "") {
super();
this.host = host;
}
clusterName(): string {
let result = get_property("cluster_name");
if (result.status != WasmResultValues.Ok) {
log(LogLevelValues.error, "get route cluster failed");
return "";
}
return String.UTF8.decode(result.returnValue);
}
hostName(): string {
if (this.host != "") {
return this.host;
}
return getRequestHost();
}
}
export class K8sCluster extends Cluster {
serviceName: string;
namespace: string;
port: i64;
version: string;
host: string;
constructor(
serviceName: string,
namespace: string,
port: i64,
version: string = "",
host: string = ""
) {
super();
this.serviceName = serviceName;
this.namespace = namespace;
this.port = port;
this.version = version;
this.host = host;
}
clusterName(): string {
let namespace = this.namespace != "" ? this.namespace : "default";
return `outbound|${this.port}|${this.version}|${this.serviceName}.${namespace}.svc.cluster.local`;
}
hostName(): string {
if (this.host != "") {
return this.host;
}
return `${this.serviceName}.${this.namespace}.svc.cluster.local`;
}
}
export class NacosCluster extends Cluster {
serviceName: string;
group: string;
namespaceID: string;
port: i64;
isExtRegistry: boolean;
version: string;
host: string;
constructor(
serviceName: string,
namespaceID: string,
port: i64,
// use DEFAULT-GROUP by default
group: string = "DEFAULT-GROUP",
// set true if use edas/sae registry
isExtRegistry: boolean = false,
version: string = "",
host: string = ""
) {
super();
this.serviceName = serviceName;
this.group = group.replace("_", "-");
this.namespaceID = namespaceID;
this.port = port;
this.isExtRegistry = isExtRegistry;
this.version = version;
this.host = host;
}
clusterName(): string {
let tail = "nacos" + (this.isExtRegistry ? "-ext" : "");
return `outbound|${this.port}|${this.version}|${this.serviceName}.${this.group}.${this.namespaceID}.${tail}`;
}
hostName(): string {
if (this.host != "") {
return this.host;
}
return this.serviceName;
}
}
export class StaticIpCluster extends Cluster {
serviceName: string;
port: i64;
host: string;
constructor(serviceName: string, port: i64, host: string = "") {
super()
this.serviceName = serviceName;
this.port = port;
this.host = host;
}
clusterName(): string {
return `outbound|${this.port}||${this.serviceName}.static`;
}
hostName(): string {
if (this.host != "") {
return this.host;
}
return this.serviceName;
}
}
export class DnsCluster extends Cluster {
serviceName: string;
domain: string;
port: i64;
constructor(serviceName: string, domain: string, port: i64) {
super();
this.serviceName = serviceName;
this.domain = domain;
this.port = port;
}
clusterName(): string {
return `outbound|${this.port}||${this.serviceName}.dns`;
}
hostName(): string {
return this.domain;
}
}
export class ConsulCluster extends Cluster {
serviceName: string;
datacenter: string;
port: i64;
host: string;
constructor(
serviceName: string,
datacenter: string,
port: i64,
host: string = ""
) {
super();
this.serviceName = serviceName;
this.datacenter = datacenter;
this.port = port;
this.host = host;
}
clusterName(): string {
return `outbound|${this.port}||${this.serviceName}.${this.datacenter}.consul`;
}
hostName(): string {
if (this.host != "") {
return this.host;
}
return this.serviceName;
}
}
export class FQDNCluster extends Cluster {
fqdn: string;
host: string;
port: i64;
constructor(fqdn: string, port: i64, host: string = "") {
super();
this.fqdn = fqdn;
this.host = host;
this.port = port;
}
clusterName(): string {
return `outbound|${this.port}||${this.fqdn}`;
}
hostName(): string {
if (this.host != "") {
return this.host;
}
return this.fqdn;
}
}

View File

@@ -0,0 +1,120 @@
import {
Cluster
} from "./cluster_wrapper"
import {
log,
LogLevelValues,
Headers,
HeaderPair,
root_context,
BufferTypeValues,
get_buffer_bytes,
BaseContext,
stream_context,
WasmResultValues,
RootContext,
ResponseCallBack
} from "@higress/proxy-wasm-assemblyscript-sdk/assembly";
export interface HttpClient {
get(path: string, headers: Headers, cb: ResponseCallBack, timeoutMillisecond: u32): boolean;
head(path: string, headers: Headers, cb: ResponseCallBack, timeoutMillisecond: u32): boolean;
options(path: string, headers: Headers, cb: ResponseCallBack, timeoutMillisecond: u32): boolean;
post(path: string, headers: Headers, body: ArrayBuffer, cb: ResponseCallBack, timeoutMillisecond: u32): boolean;
put(path: string, headers: Headers, body: ArrayBuffer, cb: ResponseCallBack, timeoutMillisecond: u32): boolean;
patch(path: string, headers: Headers, body: ArrayBuffer, cb: ResponseCallBack, timeoutMillisecond: u32): boolean;
delete(path: string, headers: Headers, body: ArrayBuffer, cb: ResponseCallBack, timeoutMillisecond: u32): boolean;
connect(path: string, headers: Headers, body: ArrayBuffer, cb: ResponseCallBack, timeoutMillisecond: u32): boolean;
trace(path: string, headers: Headers, body: ArrayBuffer, cb: ResponseCallBack, timeoutMillisecond: u32): boolean;
}
const methodArrayBuffer: ArrayBuffer = String.UTF8.encode(":method");
const pathArrayBuffer: ArrayBuffer = String.UTF8.encode(":path");
const authorityArrayBuffer: ArrayBuffer = String.UTF8.encode(":authority");
const StatusBadGateway: i32 = 502;
export class ClusterClient {
cluster: Cluster;
constructor(cluster: Cluster) {
this.cluster = cluster;
}
private httpCall(method: string, path: string, headers: Headers, body: ArrayBuffer, callback: ResponseCallBack, timeoutMillisecond: u32 = 500): boolean {
if (root_context == null) {
log(LogLevelValues.error, "Root context is null");
return false;
}
for (let i: i32 = headers.length - 1; i >= 0; i--) {
const key = String.UTF8.decode(headers[i].key)
if ((key == ":method") || (key == ":path") || (key == ":authority")) {
headers.splice(i, 1);
}
}
headers.push(new HeaderPair(methodArrayBuffer, String.UTF8.encode(method)));
headers.push(new HeaderPair(pathArrayBuffer, String.UTF8.encode(path)));
headers.push(new HeaderPair(authorityArrayBuffer, String.UTF8.encode(this.cluster.hostName())));
const result = (root_context as RootContext).httpCall(this.cluster.clusterName(), headers, body, [], timeoutMillisecond, root_context as BaseContext, callback,
(_origin_context: BaseContext, _numHeaders: u32, body_size: usize, _trailers: u32, callback: ResponseCallBack): void => {
const respBody = get_buffer_bytes(BufferTypeValues.HttpCallResponseBody, 0, body_size as u32);
const respHeaders = stream_context.headers.http_callback.get_headers()
let code = StatusBadGateway;
let headers = new Array<HeaderPair>();
for (let i = 0; i < respHeaders.length; i++) {
const h = respHeaders[i];
if (String.UTF8.decode(h.key) == ":status") {
code = <i32>parseInt(String.UTF8.decode(h.value))
}
headers.push(new HeaderPair(h.key, h.value));
}
log(LogLevelValues.debug, `http call end, code: ${code}, body: ${String.UTF8.decode(respBody)}`)
callback(code, headers, respBody);
})
log(LogLevelValues.debug, `http call start, cluster: ${this.cluster.clusterName()}, method: ${method}, path: ${path}, body: ${String.UTF8.decode(body)}, timeout: ${timeoutMillisecond}`)
if (result != WasmResultValues.Ok) {
log(LogLevelValues.error, `http call failed, result: ${result}`)
return false
}
return true
}
get(path: string, headers: Headers, cb: ResponseCallBack, timeoutMillisecond: u32 = 500): boolean {
return this.httpCall("GET", path, headers, new ArrayBuffer(0), cb, timeoutMillisecond);
}
head(path: string, headers: Headers, cb: ResponseCallBack, timeoutMillisecond: u32 = 500): boolean {
return this.httpCall("HEAD", path, headers, new ArrayBuffer(0), cb, timeoutMillisecond);
}
options(path: string, headers: Headers, cb: ResponseCallBack, timeoutMillisecond: u32 = 500): boolean {
return this.httpCall("OPTIONS", path, headers, new ArrayBuffer(0), cb, timeoutMillisecond);
}
post(path: string, headers: Headers, body: ArrayBuffer, cb: ResponseCallBack, timeoutMillisecond: u32 = 500): boolean {
return this.httpCall("POST", path, headers, body, cb, timeoutMillisecond);
}
put(path: string, headers: Headers, body: ArrayBuffer, cb: ResponseCallBack, timeoutMillisecond: u32 = 500): boolean {
return this.httpCall("PUT", path, headers, body, cb, timeoutMillisecond);
}
patch(path: string, headers: Headers, body: ArrayBuffer, cb: ResponseCallBack, timeoutMillisecond: u32 = 500): boolean {
return this.httpCall("PATCH", path, headers, body, cb, timeoutMillisecond);
}
delete(path: string, headers: Headers, body: ArrayBuffer, cb: ResponseCallBack, timeoutMillisecond: u32 = 500): boolean {
return this.httpCall("DELETE", path, headers, body, cb, timeoutMillisecond);
}
connect(path: string, headers: Headers, body: ArrayBuffer, cb: ResponseCallBack, timeoutMillisecond: u32 = 500): boolean {
return this.httpCall("CONNECT", path, headers, body, cb, timeoutMillisecond);
}
trace(path: string, headers: Headers, body: ArrayBuffer, cb: ResponseCallBack, timeoutMillisecond: u32 = 500): boolean {
return this.httpCall("TRACE", path, headers, body, cb, timeoutMillisecond);
}
}

View File

@@ -0,0 +1,18 @@
export {RouteCluster,
K8sCluster,
NacosCluster,
ConsulCluster,
FQDNCluster,
StaticIpCluster} from "./cluster_wrapper"
export {HttpClient,
ClusterClient} from "./http_wrapper"
export {Log} from "./log_wrapper"
export {SetCtx,
HttpContext,
ParseConfigBy,
ProcessRequestBodyBy,
ProcessRequestHeadersBy,
ProcessResponseBodyBy,
ProcessResponseHeadersBy,
Logger, RegisteTickFunc} from "./plugin_wrapper"
export {ParseResult} from "./rule_matcher"

View File

@@ -0,0 +1,66 @@
import { log, LogLevelValues } from "@higress/proxy-wasm-assemblyscript-sdk/assembly";
enum LogLevel {
Trace = 0,
Debug,
Info,
Warn,
Error,
Critical,
}
export class Log {
private pluginName: string;
constructor(pluginName: string) {
this.pluginName = pluginName;
}
private log(level: LogLevel, msg: string): void {
let formattedMsg = `[${this.pluginName}] ${msg}`;
switch (level) {
case LogLevel.Trace:
log(LogLevelValues.trace, formattedMsg);
break;
case LogLevel.Debug:
log(LogLevelValues.debug, formattedMsg);
break;
case LogLevel.Info:
log(LogLevelValues.info, formattedMsg);
break;
case LogLevel.Warn:
log(LogLevelValues.warn, formattedMsg);
break;
case LogLevel.Error:
log(LogLevelValues.error, formattedMsg);
break;
case LogLevel.Critical:
log(LogLevelValues.critical, formattedMsg);
break;
}
}
public Trace(msg: string): void {
this.log(LogLevel.Trace, msg);
}
public Debug(msg: string): void {
this.log(LogLevel.Debug, msg);
}
public Info(msg: string): void {
this.log(LogLevel.Info, msg);
}
public Warn(msg: string): void {
this.log(LogLevel.Warn, msg);
}
public Error(msg: string): void {
this.log(LogLevel.Error, msg);
}
public Critical(msg: string): void {
this.log(LogLevel.Critical, msg);
}
}

View File

@@ -0,0 +1,445 @@
import { Log } from "./log_wrapper";
import {
Context,
FilterHeadersStatusValues,
RootContext,
setRootContext,
proxy_set_effective_context,
log,
LogLevelValues,
FilterDataStatusValues,
get_buffer_bytes,
BufferTypeValues,
set_tick_period_milliseconds,
get_current_time_nanoseconds
} from "@higress/proxy-wasm-assemblyscript-sdk/assembly";
import {
getRequestHost,
getRequestMethod,
getRequestPath,
getRequestScheme,
isBinaryRequestBody,
} from "./request_wrapper";
import { RuleMatcher, ParseResult } from "./rule_matcher";
import { JSON } from "assemblyscript-json/assembly";
export function SetCtx<PluginConfig>(
pluginName: string,
setFuncs: usize[] = []
): void {
const rootContextId = 1
setRootContext(new CommonRootCtx<PluginConfig>(rootContextId, pluginName, setFuncs));
}
export interface HttpContext {
Scheme(): string;
Host(): string;
Path(): string;
Method(): string;
SetContext(key: string, value: usize): void;
GetContext(key: string): usize;
DontReadRequestBody(): void;
DontReadResponseBody(): void;
}
type ParseConfigFunc<PluginConfig> = (
json: JSON.Obj,
) => ParseResult<PluginConfig>;
type OnHttpHeadersFunc<PluginConfig> = (
context: HttpContext,
config: PluginConfig,
) => FilterHeadersStatusValues;
type OnHttpBodyFunc<PluginConfig> = (
context: HttpContext,
config: PluginConfig,
body: ArrayBuffer,
) => FilterDataStatusValues;
export var Logger: Log = new Log("");
class CommonRootCtx<PluginConfig> extends RootContext {
pluginName: string;
hasCustomConfig: boolean;
ruleMatcher: RuleMatcher<PluginConfig>;
parseConfig: ParseConfigFunc<PluginConfig> | null;
onHttpRequestHeaders: OnHttpHeadersFunc<PluginConfig> | null;
onHttpRequestBody: OnHttpBodyFunc<PluginConfig> | null;
onHttpResponseHeaders: OnHttpHeadersFunc<PluginConfig> | null;
onHttpResponseBody: OnHttpBodyFunc<PluginConfig> | null;
onTickFuncs: Array<TickFuncEntry>;
constructor(context_id: u32, pluginName: string, setFuncs: usize[]) {
super(context_id);
this.pluginName = pluginName;
Logger = new Log(pluginName);
this.hasCustomConfig = true;
this.onHttpRequestHeaders = null;
this.onHttpRequestBody = null;
this.onHttpResponseHeaders = null;
this.onHttpResponseBody = null;
this.parseConfig = null;
this.ruleMatcher = new RuleMatcher<PluginConfig>();
this.onTickFuncs = new Array<TickFuncEntry>();
for (let i = 0; i < setFuncs.length; i++) {
changetype<Closure<PluginConfig>>(setFuncs[i]).lambdaFn(
setFuncs[i],
this
);
}
if (this.parseConfig == null) {
this.hasCustomConfig = false;
this.parseConfig = (json: JSON.Obj): ParseResult<PluginConfig> =>{ return new ParseResult<PluginConfig>(null, true); };
}
}
createContext(context_id: u32): Context {
return new CommonCtx<PluginConfig>(context_id, this);
}
onConfigure(configuration_size: u32): boolean {
super.onConfigure(configuration_size);
const data = this.getConfiguration();
let jsonData: JSON.Obj = new JSON.Obj();
if (data == "{}") {
if (this.hasCustomConfig) {
log(LogLevelValues.warn, "config is empty, but has ParseConfigFunc");
}
} else {
const parseData = JSON.parse(data);
if (parseData.isObj) {
jsonData = changetype<JSON.Obj>(JSON.parse(data));
} else {
log(LogLevelValues.error, "parse json data failed")
return false;
}
}
if (!this.ruleMatcher.parseRuleConfig(jsonData, this.parseConfig as ParseConfigFunc<PluginConfig>)) {
return false;
}
if (globalOnTickFuncs.length > 0) {
this.onTickFuncs = globalOnTickFuncs;
set_tick_period_milliseconds(100);
}
return true;
}
onTick(): void {
for (let i = 0; i < this.onTickFuncs.length; i++) {
const tickFuncEntry = this.onTickFuncs[i];
const now = getCurrentTimeMilliseconds();
if (tickFuncEntry.lastExecuted + tickFuncEntry.tickPeriod <= now) {
tickFuncEntry.tickFunc();
tickFuncEntry.lastExecuted = getCurrentTimeMilliseconds();
}
}
}
}
function getCurrentTimeMilliseconds(): u64 {
return get_current_time_nanoseconds() / 1000000;
}
class TickFuncEntry {
lastExecuted: u64;
tickPeriod: u64;
tickFunc: () => void;
constructor(lastExecuted: u64, tickPeriod: u64, tickFunc: () => void) {
this.lastExecuted = lastExecuted;
this.tickPeriod = tickPeriod;
this.tickFunc = tickFunc;
}
}
var globalOnTickFuncs = new Array<TickFuncEntry>();
export function RegisteTickFunc(tickPeriod: i64, tickFunc: () => void): void {
globalOnTickFuncs.push(new TickFuncEntry(0, tickPeriod, tickFunc));
}
class Closure<PluginConfig> {
lambdaFn: (closure: usize, ctx: CommonRootCtx<PluginConfig>) => void;
parseConfigFunc: ParseConfigFunc<PluginConfig> | null;
onHttpHeadersFunc: OnHttpHeadersFunc<PluginConfig> | null;
OnHttpBodyFunc: OnHttpBodyFunc<PluginConfig> | null;
constructor(
lambdaFn: (closure: usize, ctx: CommonRootCtx<PluginConfig>) => void
) {
this.lambdaFn = lambdaFn;
this.parseConfigFunc = null;
this.onHttpHeadersFunc = null;
this.OnHttpBodyFunc = null;
}
setParseConfigFunc(f: ParseConfigFunc<PluginConfig>): void {
this.parseConfigFunc = f;
}
setHttpHeadersFunc(f: OnHttpHeadersFunc<PluginConfig>): void {
this.onHttpHeadersFunc = f;
}
setHttpBodyFunc(f: OnHttpBodyFunc<PluginConfig>): void {
this.OnHttpBodyFunc = f;
}
}
export function ParseConfigBy<PluginConfig>(
f: ParseConfigFunc<PluginConfig>
): usize {
const lambdaFn = function (
closure: usize,
ctx: CommonRootCtx<PluginConfig>
): void {
const f = changetype<Closure<PluginConfig>>(closure).parseConfigFunc;
if (f != null) {
ctx.parseConfig = f;
}
};
const closure = new Closure<PluginConfig>(lambdaFn);
closure.setParseConfigFunc(f);
return changetype<usize>(closure);
}
export function ProcessRequestHeadersBy<PluginConfig>(
f: OnHttpHeadersFunc<PluginConfig>
): usize {
const lambdaFn = function (
closure: usize,
ctx: CommonRootCtx<PluginConfig>
): void {
const f = changetype<Closure<PluginConfig>>(closure).onHttpHeadersFunc;
if (f != null) {
ctx.onHttpRequestHeaders = f;
}
};
const closure = new Closure<PluginConfig>(lambdaFn);
closure.setHttpHeadersFunc(f);
return changetype<usize>(closure);
}
export function ProcessRequestBodyBy<PluginConfig>(
f: OnHttpBodyFunc<PluginConfig>
): usize {
const lambdaFn = function (
closure: usize,
ctx: CommonRootCtx<PluginConfig>
): void {
const f = changetype<Closure<PluginConfig>>(closure).OnHttpBodyFunc;
if (f != null) {
ctx.onHttpRequestBody = f;
}
};
const closure = new Closure<PluginConfig>(lambdaFn);
closure.setHttpBodyFunc(f);
return changetype<usize>(closure);
}
export function ProcessResponseHeadersBy<PluginConfig>(
f: OnHttpHeadersFunc<PluginConfig>
): usize {
const lambdaFn = function (
closure: usize,
ctx: CommonRootCtx<PluginConfig>
): void {
const f = changetype<Closure<PluginConfig>>(closure).onHttpHeadersFunc;
if (f != null) {
ctx.onHttpResponseHeaders = f;
}
};
const closure = new Closure<PluginConfig>(lambdaFn);
closure.setHttpHeadersFunc(f);
return changetype<usize>(closure);
}
export function ProcessResponseBodyBy<PluginConfig>(
f: OnHttpBodyFunc<PluginConfig>
): usize {
const lambdaFn = function (
closure: usize,
ctx: CommonRootCtx<PluginConfig>
): void {
const f = changetype<Closure<PluginConfig>>(closure).OnHttpBodyFunc;
if (f != null) {
ctx.onHttpResponseBody = f;
}
};
const closure = new Closure<PluginConfig>(lambdaFn);
closure.setHttpBodyFunc(f);
return changetype<usize>(closure);
}
class CommonCtx<PluginConfig> extends Context implements HttpContext {
commonRootCtx: CommonRootCtx<PluginConfig>;
config: PluginConfig |null;
needRequestBody: boolean;
needResponseBody: boolean;
requestBodySize: u32;
responseBodySize: u32;
contextID: u32;
userContext: Map<string, usize>;
constructor(context_id: u32, root_context: CommonRootCtx<PluginConfig>) {
super(context_id, root_context);
this.userContext = new Map<string, usize>();
this.commonRootCtx = root_context;
this.contextID = context_id;
this.requestBodySize = 0;
this.responseBodySize = 0;
this.config = null
if (this.commonRootCtx.onHttpRequestHeaders != null) {
this.needResponseBody = true;
} else {
this.needResponseBody = false;
}
if (this.commonRootCtx.onHttpRequestBody != null) {
this.needRequestBody = true;
} else {
this.needRequestBody = false;
}
}
SetContext(key: string, value: usize): void {
this.userContext.set(key, value);
}
GetContext(key: string): usize {
return this.userContext.get(key);
}
Scheme(): string {
proxy_set_effective_context(this.contextID);
return getRequestScheme();
}
Host(): string {
proxy_set_effective_context(this.contextID);
return getRequestHost();
}
Path(): string {
proxy_set_effective_context(this.contextID);
return getRequestPath();
}
Method(): string {
proxy_set_effective_context(this.contextID);
return getRequestMethod();
}
DontReadRequestBody(): void {
this.needRequestBody = false;
}
DontReadResponseBody(): void {
this.needResponseBody = false;
}
onRequestHeaders(_a: u32, _end_of_stream: boolean): FilterHeadersStatusValues {
const parseResult = this.commonRootCtx.ruleMatcher.getMatchConfig();
if (parseResult.success == false) {
log(LogLevelValues.error, "get match config failed");
return FilterHeadersStatusValues.Continue;
}
this.config = parseResult.pluginConfig;
if (isBinaryRequestBody()) {
this.needRequestBody = false;
}
if (this.commonRootCtx.onHttpRequestHeaders == null) {
return FilterHeadersStatusValues.Continue;
}
return this.commonRootCtx.onHttpRequestHeaders(
this,
this.config as PluginConfig
);
}
onRequestBody(
body_buffer_length: usize,
end_of_stream: boolean
): FilterDataStatusValues {
if (this.config == null || !this.needRequestBody) {
return FilterDataStatusValues.Continue;
}
if (this.commonRootCtx.onHttpRequestBody == null) {
return FilterDataStatusValues.Continue;
}
this.requestBodySize += body_buffer_length as u32;
if (!end_of_stream) {
return FilterDataStatusValues.StopIterationAndBuffer;
}
const body = get_buffer_bytes(
BufferTypeValues.HttpRequestBody,
0,
this.requestBodySize
);
return this.commonRootCtx.onHttpRequestBody(
this,
this.config as PluginConfig,
body
);
}
onResponseHeaders(_a: u32, _end_of_stream: bool): FilterHeadersStatusValues {
if (this.config == null) {
return FilterHeadersStatusValues.Continue;
}
if (isBinaryRequestBody()) {
this.needResponseBody = false;
}
if (this.commonRootCtx.onHttpResponseHeaders == null) {
return FilterHeadersStatusValues.Continue;
}
return this.commonRootCtx.onHttpResponseHeaders(
this,
this.config as PluginConfig
);
}
onResponseBody(
body_buffer_length: usize,
end_of_stream: bool
): FilterDataStatusValues {
if (this.config == null) {
return FilterDataStatusValues.Continue;
}
if (this.commonRootCtx.onHttpResponseBody == null) {
return FilterDataStatusValues.Continue;
}
if (!this.needResponseBody) {
return FilterDataStatusValues.Continue;
}
this.responseBodySize += body_buffer_length as u32;
if (!end_of_stream) {
return FilterDataStatusValues.StopIterationAndBuffer;
}
const body = get_buffer_bytes(
BufferTypeValues.HttpResponseBody,
0,
this.responseBodySize
);
return this.commonRootCtx.onHttpResponseBody(
this,
this.config as PluginConfig,
body
);
}
}

View File

@@ -0,0 +1,65 @@
import {
stream_context,
log,
LogLevelValues
} from "@higress/proxy-wasm-assemblyscript-sdk/assembly";
export function getRequestScheme(): string {
let scheme: string = stream_context.headers.request.get(":scheme");
if (scheme == "") {
log(LogLevelValues.error, "Parse request scheme failed");
}
return scheme;
}
export function getRequestHost(): string {
let host: string = stream_context.headers.request.get(":authority");
if (host == "") {
log(LogLevelValues.error, "Parse request host failed");
}
return host;
}
export function getRequestPath(): string {
let path: string = stream_context.headers.request.get(":path");
if (path == "") {
log(LogLevelValues.error, "Parse request path failed");
}
return path;
}
export function getRequestMethod(): string {
let method: string = stream_context.headers.request.get(":method");
if (method == "") {
log(LogLevelValues.error, "Parse request method failed");
}
return method;
}
export function isBinaryRequestBody(): boolean {
let contentType: string = stream_context.headers.request.get("content-type");
if (contentType != "" && (contentType.includes("octet-stream") || contentType.includes("grpc"))) {
return true;
}
let encoding: string = stream_context.headers.request.get("content-encoding");
if (encoding != "") {
return true;
}
return false;
}
export function isBinaryResponseBody(): boolean {
let contentType: string = stream_context.headers.response.get("content-type");
if (contentType != "" && (contentType.includes("octet-stream") || contentType.includes("grpc"))) {
return true;
}
let encoding: string = stream_context.headers.response.get("content-encoding");
if (encoding != "") {
return true;
}
return false;
}

View File

@@ -0,0 +1,346 @@
import { getRequestHost } from "./request_wrapper";
import {
get_property,
LogLevelValues,
log,
WasmResultValues,
} from "@higress/proxy-wasm-assemblyscript-sdk/assembly";
import { JSON } from "assemblyscript-json/assembly";
enum Category {
Route,
Host,
RoutePrefix,
Service
}
enum MatchType {
Prefix,
Exact,
Suffix,
}
const RULES_KEY: string = "_rules_";
const MATCH_ROUTE_KEY: string = "_match_route_";
const MATCH_DOMAIN_KEY: string = "_match_domain_";
const MATCH_SERVICE_KEY: string = "_match_service_";
const MATCH_ROUTE_PREFIX_KEY: string = "_match_route_prefix_"
class HostMatcher {
matchType: MatchType;
host: string;
constructor(matchType: MatchType, host: string) {
this.matchType = matchType;
this.host = host;
}
}
class RuleConfig<PluginConfig> {
category: Category;
routes!: Map<string, boolean>;
services!: Map<string, boolean>;
routePrefixs!: Map<string, boolean>;
hosts!: Array<HostMatcher>;
config: PluginConfig | null;
constructor() {
this.category = Category.Route;
this.config = null;
}
}
export class ParseResult<PluginConfig> {
pluginConfig: PluginConfig | null;
success: boolean;
constructor(pluginConfig: PluginConfig | null, success: boolean) {
this.pluginConfig = pluginConfig;
this.success = success;
}
}
export class RuleMatcher<PluginConfig> {
ruleConfig: Array<RuleConfig<PluginConfig>>;
globalConfig: PluginConfig | null;
hasGlobalConfig: boolean;
constructor() {
this.ruleConfig = new Array<RuleConfig<PluginConfig>>();
this.globalConfig = null;
this.hasGlobalConfig = false;
}
getMatchConfig(): ParseResult<PluginConfig> {
const host = getRequestHost();
if (host == "") {
return new ParseResult<PluginConfig>(null, false);
}
let result = get_property("route_name");
if (result.status != WasmResultValues.Ok && result.status != WasmResultValues.NotFound) {
return new ParseResult<PluginConfig>(null, false);
}
const routeName = String.UTF8.decode(result.returnValue);
result = get_property("cluster_name");
if (result.status != WasmResultValues.Ok && result.status != WasmResultValues.NotFound) {
return new ParseResult<PluginConfig>(null, false);
}
const serviceName = String.UTF8.decode(result.returnValue);
for (let i = 0; i < this.ruleConfig.length; i++) {
const rule = this.ruleConfig[i];
// category == Host
if (rule.category == Category.Host) {
if (this.hostMatch(rule, host)) {
log(LogLevelValues.debug, "getMatchConfig: match host " + host);
return new ParseResult<PluginConfig>(rule.config, true);
}
}
// category == Route
if (rule.category == Category.Route) {
if (rule.routes.has(routeName)) {
log(LogLevelValues.debug, "getMatchConfig: match route " + routeName);
return new ParseResult<PluginConfig>(rule.config, true);
}
}
// category == RoutePrefix
if (rule.category == Category.RoutePrefix) {
for (let i = 0; i < rule.routePrefixs.keys().length; i++) {
const routePrefix = rule.routePrefixs.keys()[i];
if (routeName.startsWith(routePrefix)) {
return new ParseResult<PluginConfig>(rule.config, true);
}
}
}
// category == Cluster
if (this.serviceMatch(rule, serviceName)) {
return new ParseResult<PluginConfig>(rule.config, true);
}
}
if (this.hasGlobalConfig) {
return new ParseResult<PluginConfig>(this.globalConfig, true);
}
return new ParseResult<PluginConfig>(null, false);
}
parseRuleConfig(
config: JSON.Obj,
parsePluginConfig: (json: JSON.Obj) => ParseResult<PluginConfig>
): boolean {
const obj = config;
let keyCount = obj.keys.length;
if (keyCount == 0) {
this.hasGlobalConfig = true;
const parseResult = parsePluginConfig(config);
if (parseResult.success) {
this.globalConfig = parseResult.pluginConfig;
return true;
} else {
return false;
}
}
let rules: JSON.Arr | null = null;
if (obj.has(RULES_KEY)) {
rules = obj.getArr(RULES_KEY);
keyCount--;
}
if (keyCount > 0) {
const parseResult = parsePluginConfig(config);
if (parseResult.success) {
this.globalConfig = parseResult.pluginConfig;
this.hasGlobalConfig = true;
}
}
if (!rules) {
if (this.hasGlobalConfig) {
return true;
}
log(LogLevelValues.error, "parse config failed, no valid rules; global config parse error");
return false;
}
const rulesArray = rules.valueOf();
for (let i = 0; i < rulesArray.length; i++) {
if (!rulesArray[i].isObj) {
log(LogLevelValues.error, "parse rule failed, rules must be an array of objects");
continue;
}
const ruleJson = changetype<JSON.Obj>(rulesArray[i]);
const rule = new RuleConfig<PluginConfig>();
const parseResult = parsePluginConfig(ruleJson);
if (parseResult.success) {
rule.config = parseResult.pluginConfig;
} else {
return false;
}
rule.routes = this.parseRouteMatchConfig(ruleJson);
rule.hosts = this.parseHostMatchConfig(ruleJson);
rule.services = this.parseServiceMatchConfig(ruleJson);
rule.routePrefixs = this.parseRoutePrefixMatchConfig(ruleJson);
const noRoute = rule.routes.size == 0;
const noHosts = rule.hosts.length == 0;
const noServices = rule.services.size == 0;
const noRoutePrefixs = rule.routePrefixs.size == 0;
if ((boolToInt(noRoute) + boolToInt(noHosts) + boolToInt(noServices) + boolToInt(noRoutePrefixs)) != 3) {
log(LogLevelValues.error, "there is only one of '_match_route_', '_match_domain_', '_match_service_' and '_match_route_prefix_' can present in configuration.");
return false;
}
if (!noRoute) {
rule.category = Category.Route;
} else if (!noHosts) {
rule.category = Category.Host;
} else if (!noServices) {
rule.category = Category.Service;
} else {
rule.category = Category.RoutePrefix;
}
this.ruleConfig.push(rule);
}
return true;
}
parseRouteMatchConfig(config: JSON.Obj): Map<string, boolean> {
const keys = config.getArr(MATCH_ROUTE_KEY);
const routes = new Map<string, boolean>();
if (keys) {
const array = keys.valueOf();
for (let i = 0; i < array.length; i++) {
const key = array[i].toString();
if (key != "") {
routes.set(key, true);
}
}
}
return routes;
}
parseRoutePrefixMatchConfig(config: JSON.Obj): Map<string, boolean> {
const keys = config.getArr(MATCH_ROUTE_PREFIX_KEY);
const routePrefixs = new Map<string, boolean>();
if (keys) {
const array = keys.valueOf();
for (let i = 0; i < array.length; i++) {
const key = array[i].toString();
if (key != "") {
routePrefixs.set(key, true);
}
}
}
return routePrefixs;
}
parseServiceMatchConfig(config: JSON.Obj): Map<string, boolean> {
const keys = config.getArr(MATCH_SERVICE_KEY);
const clusters = new Map<string, boolean>();
if (keys) {
const array = keys.valueOf();
for (let i = 0; i < array.length; i++) {
const key = array[i].toString();
if (key != "") {
clusters.set(key, true);
}
}
}
return clusters;
}
parseHostMatchConfig(config: JSON.Obj): Array<HostMatcher> {
const hostMatchers = new Array<HostMatcher>();
const keys = config.getArr(MATCH_DOMAIN_KEY);
if (keys !== null) {
const array = keys.valueOf();
for (let i = 0; i < array.length; i++) {
const item = array[i].toString(); // Assuming the array has string elements
let hostMatcher: HostMatcher;
if (item.startsWith("*")) {
hostMatcher = new HostMatcher(MatchType.Suffix, item.substr(1));
} else if (item.endsWith("*")) {
hostMatcher = new HostMatcher(
MatchType.Prefix,
item.substr(0, item.length - 1)
);
} else {
hostMatcher = new HostMatcher(MatchType.Exact, item);
}
hostMatchers.push(hostMatcher);
}
}
return hostMatchers;
}
stripPortFromHost(reqHost: string): string {
// Port removing code is inspired by
// https://github.com/envoyproxy/envoy/blob/v1.17.0/source/common/http/header_utility.cc#L219
let portStart: i32 = reqHost.lastIndexOf(":");
if (portStart != -1) {
// According to RFC3986 v6 address is always enclosed in "[]".
// section 3.2.2.
let v6EndIndex: i32 = reqHost.lastIndexOf("]");
if (v6EndIndex == -1 || v6EndIndex < portStart) {
if (portStart + 1 <= reqHost.length) {
return reqHost.substring(0, portStart);
}
}
}
return reqHost;
}
hostMatch(rule: RuleConfig<PluginConfig>, reqHost: string): boolean {
reqHost = this.stripPortFromHost(reqHost);
for (let i = 0; i < rule.hosts.length; i++) {
let hostMatch = rule.hosts[i];
switch (hostMatch.matchType) {
case MatchType.Suffix:
if (reqHost.endsWith(hostMatch.host)) {
return true;
}
break;
case MatchType.Prefix:
if (reqHost.startsWith(hostMatch.host)) {
return true;
}
break;
case MatchType.Exact:
if (reqHost == hostMatch.host) {
return true;
}
break;
default:
return false;
}
}
return false;
}
serviceMatch(rule: RuleConfig<PluginConfig>, serviceName: string): boolean {
const parts = serviceName.split('|');
if (parts.length != 4) {
return false;
}
const port = parts[1];
const fqdn = parts[3];
for (let i = 0; i < rule.services.keys().length; i++) {
let configServiceName = rule.services.keys()[i];
let colonIndex = configServiceName.lastIndexOf(':');
if (colonIndex != -1) {
let configFQDN = configServiceName.slice(0, colonIndex);
let configPort = configServiceName.slice(colonIndex + 1);
if (fqdn == configFQDN && port == configPort) return true;
} else if (fqdn == configServiceName) {
return true;
}
}
return false;
}
}
function boolToInt(value: boolean): i32 {
return value ? 1 : 0;
}

View File

@@ -0,0 +1,6 @@
{
"extends": "assemblyscript/std/assembly.json",
"include": [
"./**/*.ts"
]
}

View File

@@ -0,0 +1,80 @@
# 功能说明
`custom-response`插件支持配置自定义的响应,包括自定义 HTTP 应答状态码、HTTP 应答头,以及 HTTP 应答 Body。可以用于 Mock 响应,也可以用于判断特定状态码后给出自定义应答,例如在触发网关限流策略时实现自定义响应。
# 配置字段
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
| -------- | -------- | -------- | -------- | -------- |
| status_code | number | 选填 | 200 | 自定义 HTTP 应答状态码 |
| headers | array of string | 选填 | - | 自定义 HTTP 应答头key 和 value 用`=`分隔 |
| body | string | 选填 | - | 自定义 HTTP 应答 Body |
| enable_on_status | array of number | 选填 | - | 匹配原始状态码,生成自定义响应,不填写时,不判断原始状态码 |
# 配置示例
## Mock 应答场景
```yaml
status_code: 200
headers:
- Content-Type=application/json
- Hello=World
body: "{\"hello\":\"world\"}"
```
根据该配置,请求将返回自定义应答如下:
```text
HTTP/1.1 200 OK
Content-Type: application/json
Hello: World
Content-Length: 17
{"hello":"world"}
```
## 触发限流时自定义响应
```yaml
enable_on_status:
- 429
status_code: 302
headers:
- Location=https://example.com
```
触发网关限流时一般会返回 `429` 状态码,这时请求将返回自定义应答如下:
```text
HTTP/1.1 302 Found
Location: https://example.com
```
从而实现基于浏览器 302 重定向机制,将限流后的用户引导到其他页面,比如可以是一个 CDN 上的静态页面。
如果希望触发限流时,正常返回其他应答,参考 Mock 应答场景配置相应的字段即可。
## 对特定路由或域名开启
```yaml
# 使用 matchRules 字段进行细粒度规则配置
matchRules:
# 规则一:按 Ingress 名称匹配生效
- ingress:
- default/foo
- default/bar
body: "{\"hello\":\"world\"}"
# 规则二:按域名匹配生效
- domain:
- "*.example.com"
- test.com
enable_on_status:
- 429
status_code: 200
headers:
- Content-Type=application/json
body: "{\"errmsg\": \"rate limited\"}"
```
此例 `ingress` 中指定的 `default/foo``default/bar` 对应 default 命名空间下名为 foo 和 bar 的 Ingress当匹配到这两个 Ingress 时,将使用此段配置;
此例 `domain` 中指定的 `*.example.com``test.com` 用于匹配请求的域名,当发现域名匹配时,将使用此段配置;
配置的匹配生效顺序,将按照 `matchRules` 下规则的排列顺序,匹配第一个规则后生效对应配置,后续规则将被忽略。

View File

@@ -0,0 +1,24 @@
{
"targets": {
"debug": {
"outFile": "build/debug.wasm",
"textFile": "build/debug.wat",
"sourceMap": true,
"debug": true
},
"release": {
"outFile": "build/release.wasm",
"textFile": "build/release.wat",
"sourceMap": true,
"optimizeLevel": 3,
"shrinkLevel": 0,
"converge": false,
"noAssert": false,
"debug": true
}
},
"options": {
"bindings": "esm",
"use": "abort=abort_proc_exit"
}
}

View File

@@ -0,0 +1,96 @@
export * from "@higress/proxy-wasm-assemblyscript-sdk/assembly/proxy";
import { SetCtx, HttpContext, ProcessRequestHeadersBy, Logger, ParseConfigBy, ParseResult, ProcessResponseHeadersBy } from "@higress/wasm-assemblyscript/assembly";
import { FilterHeadersStatusValues, Headers, send_http_response, stream_context, HeaderPair } from "@higress/proxy-wasm-assemblyscript-sdk/assembly"
import { JSON } from "assemblyscript-json/assembly";
class CustomResponseConfig {
statusCode: u32;
headers: Headers;
body: ArrayBuffer;
enableOnStatus: Array<u32>;
contentType: string;
constructor() {
this.statusCode = 200;
this.headers = [];
this.body = new ArrayBuffer(0);
this.enableOnStatus = [];
this.contentType = "text/plain; charset=utf-8";
}
}
SetCtx<CustomResponseConfig>(
"custom-response",
[ParseConfigBy<CustomResponseConfig>(parseConfig),
ProcessRequestHeadersBy<CustomResponseConfig>(onHttpRequestHeaders),
ProcessResponseHeadersBy<CustomResponseConfig>(onHttpResponseHeaders),])
function parseConfig(json: JSON.Obj): ParseResult<CustomResponseConfig> {
let headersArray = json.getArr("headers");
let config = new CustomResponseConfig();
if (headersArray != null) {
for (let i = 0; i < headersArray.valueOf().length; i++) {
let header = headersArray._arr[i];
let jsonString = (<JSON.Str>header).toString()
let kv = jsonString.split("=")
if (kv.length == 2) {
let key = kv[0].trim();
let value = kv[1].trim();
if (key.toLowerCase() == "content-type") {
config.contentType = value;
} else if (key.toLowerCase() == "content-length") {
continue;
} else {
config.headers.push(new HeaderPair(String.UTF8.encode(key), String.UTF8.encode(value)));
}
} else {
Logger.Error("parse header failed");
return new ParseResult<CustomResponseConfig>(null, false);
}
}
}
let body = json.getString("body");
if (body != null) {
config.body = String.UTF8.encode(body.valueOf());
}
config.headers.push(new HeaderPair(String.UTF8.encode("content-type"), String.UTF8.encode(config.contentType)));
let statusCode = json.getInteger("statusCode");
if (statusCode != null) {
config.statusCode = statusCode.valueOf() as u32;
}
let enableOnStatus = json.getArr("enableOnStatus");
if (enableOnStatus != null) {
for (let i = 0; i < enableOnStatus.valueOf().length; i++) {
let status = enableOnStatus._arr[i];
if (status.isInteger) {
config.enableOnStatus.push((<JSON.Integer>status).valueOf() as u32);
}
}
}
return new ParseResult<CustomResponseConfig>(config, true);
}
function onHttpRequestHeaders(context: HttpContext, config: CustomResponseConfig): FilterHeadersStatusValues {
if (config.enableOnStatus.length != 0) {
return FilterHeadersStatusValues.Continue;
}
send_http_response(config.statusCode, "custom-response", config.body, config.headers);
return FilterHeadersStatusValues.StopIteration;
}
function onHttpResponseHeaders(context: HttpContext, config: CustomResponseConfig): FilterHeadersStatusValues {
let statusCodeStr = stream_context.headers.response.get(":status")
if (statusCodeStr == "") {
Logger.Error("get http response status code failed");
return FilterHeadersStatusValues.Continue;
}
let statusCode = parseInt(statusCodeStr);
for (let i = 0; i < config.enableOnStatus.length; i++) {
if (statusCode == config.enableOnStatus[i]) {
send_http_response(config.statusCode, "custom-response", config.body, config.headers);
}
}
return FilterHeadersStatusValues.Continue;
}

View File

@@ -0,0 +1,6 @@
{
"extends": "assemblyscript/std/assembly.json",
"include": [
"./**/*.ts"
]
}

View File

@@ -0,0 +1,68 @@
{
"name": "custom-response",
"version": "1.0.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "custom-response",
"version": "1.0.0",
"license": "ISC",
"devDependencies": {
"@higress/wasm-assemblyscript": "^0.0.4",
"assemblyscript": "^0.27.29",
"assemblyscript-json": "^1.1.0"
}
},
"node_modules/@higress/wasm-assemblyscript": {
"version": "0.0.4",
"resolved": "https://registry.npmjs.org/@higress/wasm-assemblyscript/-/wasm-assemblyscript-0.0.4.tgz",
"integrity": "sha512-F9m3fHBeM0OFWWHekTcmj3dVh7I4pbzf0oIioVdArD2oSUgpCZ8ur8E/9r7JR3WVwn2/l0A3LRSBOJTzQnHtMw==",
"dev": true
},
"node_modules/assemblyscript": {
"version": "0.27.29",
"resolved": "https://registry.npmmirror.com/assemblyscript/-/assemblyscript-0.27.29.tgz",
"integrity": "sha512-pH6udb7aE2F0t6cTh+0uCepmucykhMnAmm7k0kkAU3SY7LvpIngEBZWM6p5VCguu4EpmKGwEuZpZbEXzJ/frHQ==",
"dev": true,
"dependencies": {
"binaryen": "116.0.0-nightly.20240114",
"long": "^5.2.1"
},
"bin": {
"asc": "bin/asc.js",
"asinit": "bin/asinit.js"
},
"engines": {
"node": ">=16",
"npm": ">=7"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/assemblyscript"
}
},
"node_modules/assemblyscript-json": {
"version": "1.1.0",
"resolved": "https://registry.npmmirror.com/assemblyscript-json/-/assemblyscript-json-1.1.0.tgz",
"integrity": "sha512-UbE8ts8csTWQgd5TnSPN7MRV9NveuHv1bVnKmDLoo/tzjqxkmsZb3lu59Uk8H7SGoqdkDSEE049alx/nHnSdFw==",
"dev": true
},
"node_modules/binaryen": {
"version": "116.0.0-nightly.20240114",
"resolved": "https://registry.npmmirror.com/binaryen/-/binaryen-116.0.0-nightly.20240114.tgz",
"integrity": "sha512-0GZrojJnuhoe+hiwji7QFaL3tBlJoA+KFUN7ouYSDGZLSo9CKM8swQX8n/UcbR0d1VuZKU+nhogNzv423JEu5A==",
"dev": true,
"bin": {
"wasm-opt": "bin/wasm-opt",
"wasm2js": "bin/wasm2js"
}
},
"node_modules/long": {
"version": "5.2.3",
"resolved": "https://registry.npmmirror.com/long/-/long-5.2.3.tgz",
"integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==",
"dev": true
}
}
}

View File

@@ -0,0 +1,27 @@
{
"name": "custom-response",
"version": "1.0.0",
"main": "index.js",
"scripts": {
"test": "node tests",
"asbuild:debug": "asc assembly/index.ts --target debug",
"asbuild:release": "asc assembly/index.ts --target release",
"asbuild": "npm run asbuild:debug && npm run asbuild:release",
"start": "npx serve ."
},
"author": "",
"license": "ISC",
"description": "",
"devDependencies": {
"assemblyscript": "^0.27.29",
"assemblyscript-json": "^1.1.0",
"@higress/wasm-assemblyscript": "^0.0.4"
},
"type": "module",
"exports": {
".": {
"import": "./build/release.js",
"types": "./build/release.d.ts"
}
}
}

View File

@@ -0,0 +1,24 @@
{
"targets": {
"debug": {
"outFile": "build/debug.wasm",
"textFile": "build/debug.wat",
"sourceMap": true,
"debug": true
},
"release": {
"outFile": "build/release.wasm",
"textFile": "build/release.wat",
"sourceMap": true,
"optimizeLevel": 3,
"shrinkLevel": 0,
"converge": false,
"noAssert": false,
"debug": true
}
},
"options": {
"bindings": "esm",
"use": "abort=abort_proc_exit"
}
}

View File

@@ -0,0 +1,42 @@
export * from "@higress/proxy-wasm-assemblyscript-sdk/assembly/proxy";
import { SetCtx, HttpContext, ProcessRequestHeadersBy, Logger, ParseResult, ParseConfigBy, RegisteTickFunc, ProcessResponseHeadersBy } from "@higress/wasm-assemblyscript/assembly";
import { FilterHeadersStatusValues, send_http_response, stream_context } from "@higress/proxy-wasm-assemblyscript-sdk/assembly"
import { JSON } from "assemblyscript-json/assembly";
class HelloWorldConfig {
}
SetCtx<HelloWorldConfig>("hello-world",
[ParseConfigBy<HelloWorldConfig>(parseConfig),
ProcessRequestHeadersBy<HelloWorldConfig>(onHttpRequestHeaders),
ProcessResponseHeadersBy<HelloWorldConfig>(onHttpResponseHeaders)
])
function parseConfig(json: JSON.Obj): ParseResult<HelloWorldConfig> {
RegisteTickFunc(2000, () => {
Logger.Debug("tick 2s");
})
RegisteTickFunc(5000, () => {
Logger.Debug("tick 5s");
})
return new ParseResult<HelloWorldConfig>(new HelloWorldConfig(), true);
}
class TestContext{
value: string
constructor(value: string){
this.value = value
}
}
function onHttpRequestHeaders(context: HttpContext, config: HelloWorldConfig): FilterHeadersStatusValues {
stream_context.headers.request.add("hello", "world");
Logger.Debug("[hello-world] logger test");
context.SetContext("test-set-context", changetype<usize>(new TestContext("value")))
send_http_response(200, "hello-world", String.UTF8.encode("[wasm-assemblyscript]hello world"), []);
return FilterHeadersStatusValues.Continue;
}
function onHttpResponseHeaders(context: HttpContext, config: HelloWorldConfig): FilterHeadersStatusValues {
const str = changetype<TestContext>(context.GetContext("test-set-context")).value;
Logger.Debug("[hello-world] test-set-context: " + str);
return FilterHeadersStatusValues.Continue;
}

View File

@@ -0,0 +1,6 @@
{
"extends": "assemblyscript/std/assembly.json",
"include": [
"./**/*.ts"
]
}

View File

@@ -0,0 +1,68 @@
{
"name": "hello-world",
"version": "1.0.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "hello-world",
"version": "1.0.0",
"license": "ISC",
"devDependencies": {
"@higress/wasm-assemblyscript": "^0.0.4",
"assemblyscript": "^0.27.29",
"assemblyscript-json": "^1.1.0"
}
},
"node_modules/@higress/wasm-assemblyscript": {
"version": "0.0.4",
"resolved": "https://registry.npmjs.org/@higress/wasm-assemblyscript/-/wasm-assemblyscript-0.0.4.tgz",
"integrity": "sha512-F9m3fHBeM0OFWWHekTcmj3dVh7I4pbzf0oIioVdArD2oSUgpCZ8ur8E/9r7JR3WVwn2/l0A3LRSBOJTzQnHtMw==",
"dev": true
},
"node_modules/assemblyscript": {
"version": "0.27.29",
"resolved": "https://registry.npmmirror.com/assemblyscript/-/assemblyscript-0.27.29.tgz",
"integrity": "sha512-pH6udb7aE2F0t6cTh+0uCepmucykhMnAmm7k0kkAU3SY7LvpIngEBZWM6p5VCguu4EpmKGwEuZpZbEXzJ/frHQ==",
"dev": true,
"dependencies": {
"binaryen": "116.0.0-nightly.20240114",
"long": "^5.2.1"
},
"bin": {
"asc": "bin/asc.js",
"asinit": "bin/asinit.js"
},
"engines": {
"node": ">=16",
"npm": ">=7"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/assemblyscript"
}
},
"node_modules/assemblyscript-json": {
"version": "1.1.0",
"resolved": "https://registry.npmmirror.com/assemblyscript-json/-/assemblyscript-json-1.1.0.tgz",
"integrity": "sha512-UbE8ts8csTWQgd5TnSPN7MRV9NveuHv1bVnKmDLoo/tzjqxkmsZb3lu59Uk8H7SGoqdkDSEE049alx/nHnSdFw==",
"dev": true
},
"node_modules/binaryen": {
"version": "116.0.0-nightly.20240114",
"resolved": "https://registry.npmmirror.com/binaryen/-/binaryen-116.0.0-nightly.20240114.tgz",
"integrity": "sha512-0GZrojJnuhoe+hiwji7QFaL3tBlJoA+KFUN7ouYSDGZLSo9CKM8swQX8n/UcbR0d1VuZKU+nhogNzv423JEu5A==",
"dev": true,
"bin": {
"wasm-opt": "bin/wasm-opt",
"wasm2js": "bin/wasm2js"
}
},
"node_modules/long": {
"version": "5.2.3",
"resolved": "https://registry.npmmirror.com/long/-/long-5.2.3.tgz",
"integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==",
"dev": true
}
}
}

View File

@@ -0,0 +1,27 @@
{
"name": "hello-world",
"version": "1.0.0",
"main": "index.js",
"scripts": {
"test": "node tests",
"asbuild:debug": "asc assembly/index.ts --target debug",
"asbuild:release": "asc assembly/index.ts --target release",
"asbuild": "npm run asbuild:debug && npm run asbuild:release",
"start": "npx serve ."
},
"author": "",
"license": "ISC",
"description": "",
"devDependencies": {
"assemblyscript": "^0.27.29",
"assemblyscript-json": "^1.1.0",
"@higress/wasm-assemblyscript": "^0.0.4"
},
"type": "module",
"exports": {
".": {
"import": "./build/release.js",
"types": "./build/release.d.ts"
}
}
}

View File

@@ -0,0 +1,75 @@
{
"name": "@higress/wasm-assemblyscript",
"version": "0.0.4",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@higress/wasm-assemblyscript",
"version": "0.0.4",
"license": "Apache-2.0",
"devDependencies": {
"@higress/proxy-wasm-assemblyscript-sdk": "^0.0.2",
"as-uuid": "^0.0.4",
"assemblyscript": "^0.27.29",
"assemblyscript-json": "^1.1.0"
}
},
"node_modules/@higress/proxy-wasm-assemblyscript-sdk": {
"version": "0.0.2",
"resolved": "https://registry.npmmirror.com/@higress/proxy-wasm-assemblyscript-sdk/-/proxy-wasm-assemblyscript-sdk-0.0.2.tgz",
"integrity": "sha512-0J1tFJMTE6o37JpGJBLq0wc5kBC/fpbISrP+KFb4bAEeshu6daXzD2P3bAfJXmW+oZdY0WGptTGXWx8pf9Fk+g==",
"dev": true
},
"node_modules/as-uuid": {
"version": "0.0.4",
"resolved": "https://registry.npmmirror.com/as-uuid/-/as-uuid-0.0.4.tgz",
"integrity": "sha512-ZHNv0ETSzg5ZD0IWWJVyip/73LWtrWeMmvRi+16xbkpU/nZ0O8EegvgS7bgZ5xRqrUbc2NqZqHOWMOtPqbLrhg==",
"dev": true
},
"node_modules/assemblyscript": {
"version": "0.27.29",
"resolved": "https://registry.npmmirror.com/assemblyscript/-/assemblyscript-0.27.29.tgz",
"integrity": "sha512-pH6udb7aE2F0t6cTh+0uCepmucykhMnAmm7k0kkAU3SY7LvpIngEBZWM6p5VCguu4EpmKGwEuZpZbEXzJ/frHQ==",
"dev": true,
"dependencies": {
"binaryen": "116.0.0-nightly.20240114",
"long": "^5.2.1"
},
"bin": {
"asc": "bin/asc.js",
"asinit": "bin/asinit.js"
},
"engines": {
"node": ">=16",
"npm": ">=7"
},
"funding": {
"type": "opencollective",
"url": "https://opencollective.com/assemblyscript"
}
},
"node_modules/assemblyscript-json": {
"version": "1.1.0",
"resolved": "https://registry.npmmirror.com/assemblyscript-json/-/assemblyscript-json-1.1.0.tgz",
"integrity": "sha512-UbE8ts8csTWQgd5TnSPN7MRV9NveuHv1bVnKmDLoo/tzjqxkmsZb3lu59Uk8H7SGoqdkDSEE049alx/nHnSdFw==",
"dev": true
},
"node_modules/binaryen": {
"version": "116.0.0-nightly.20240114",
"resolved": "https://registry.npmmirror.com/binaryen/-/binaryen-116.0.0-nightly.20240114.tgz",
"integrity": "sha512-0GZrojJnuhoe+hiwji7QFaL3tBlJoA+KFUN7ouYSDGZLSo9CKM8swQX8n/UcbR0d1VuZKU+nhogNzv423JEu5A==",
"dev": true,
"bin": {
"wasm-opt": "bin/wasm-opt",
"wasm2js": "bin/wasm2js"
}
},
"node_modules/long": {
"version": "5.2.3",
"resolved": "https://registry.npmmirror.com/long/-/long-5.2.3.tgz",
"integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==",
"dev": true
}
}
}

View File

@@ -0,0 +1,37 @@
{
"name": "@higress/wasm-assemblyscript",
"version": "0.0.4",
"main": "assembly/index.ts",
"scripts": {
"test": "node tests",
"asbuild:debug": "asc assembly/index.ts --target debug",
"asbuild:release": "asc assembly/index.ts --target release",
"asbuild": "npm run asbuild:debug && npm run asbuild:release",
"start": "npx serve ."
},
"author": "jingze.dai",
"license": "Apache-2.0",
"description": "",
"devDependencies": {
"assemblyscript": "^0.27.29",
"as-uuid": "^0.0.4",
"assemblyscript-json": "^1.1.0",
"@higress/proxy-wasm-assemblyscript-sdk": "^0.0.2"
},
"type": "module",
"exports": {
".": {
"import": "./build/release.js",
"types": "./build/release.d.ts"
}
},
"files": [
"/assembly",
"package-lock.json",
"index.js"
],
"repository": {
"type": "git",
"url": "git+https://github.com/Jing-ze/wasm-assemblyscript.git"
}
}

View File

@@ -5,7 +5,7 @@ GO_VERSION ?= 1.19
TINYGO_VERSION ?= 0.28.1
ORAS_VERSION ?= 1.0.0
HIGRESS_VERSION ?= 1.0.0-rc
USE_HIGRESS_TINYGO ?= true
USE_HIGRESS_TINYGO ?= false
BUILDER ?= ${BUILDER_REGISTRY}wasm-go-builder:go${GO_VERSION}-tinygo${TINYGO_VERSION}-oras${ORAS_VERSION}
BUILD_TIME := $(shell date "+%Y%m%d-%H%M%S")
COMMIT_ID := $(shell git rev-parse --short HEAD 2>/dev/null)

View File

@@ -0,0 +1,350 @@
---
title: AI Agent
keywords: [ AI网关, AI Agent ]
description: AI Agent插件配置参考
---
## 功能说明
一个可定制化的 API AI Agent支持配置 http method 类型为 GET 与 POST 的 API目前只支持非流式模式。
agent流程图如下
![ai-agent](https://github.com/user-attachments/assets/b0761a0c-1afa-496c-a98e-bb9f38b340f8)
## 配置字段
### 基本配置
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|------------------|-----------|---------|--------|----------------------------|
| `llm` | object | 必填 | - | 配置 AI 服务提供商的信息 |
| `apis` | object | 必填 | - | 配置外部 API 服务提供商的信息 |
| `promptTemplate` | object | 非必填 | - | 配置 Agent ReAct 模板的信息 |
`llm`的配置字段说明如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|--------------------|-----------|---------|--------|-----------------------------------|
| `apiKey` | string | 必填 | - | 用于在访问大模型服务时进行认证的令牌。|
| `serviceName` | string | 必填 | - | 大模型服务名 |
| `servicePort` | int | 必填 | - | 大模型服务端口 |
| `domain` | string | 必填 | - | 访问大模型服务时域名 |
| `path` | string | 必填 | - | 访问大模型服务时路径 |
| `model` | string | 必填 | - | 访问大模型服务时模型名 |
| `maxIterations` | int | 必填 | 15 | 结束执行循环前的最大步数 |
| `maxExecutionTime` | int | 必填 | 50000 | 每一次请求大模型的超时时间,单位毫秒 |
| `maxTokens` | int | 必填 | 1000 | 每一次请求大模型的输出token限制 |
`apis`的配置字段说明如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|-----------------|-----------|---------|--------|-----------------------------------|
| `apiProvider` | object | 必填 | - | 外部 API 服务信息 |
| `api` | string | 必填 | - | 工具的 OpenAPI 文档 |
`apiProvider`的配置字段说明如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|-----------------|-----------|---------|--------|------------------------------------------|
| `apiKey` | object | 非必填 | - | 用于在访问外部 API 服务时进行认证的令牌。 |
| `serviceName` | string | 必填 | - | 访问外部 API 服务名 |
| `servicePort` | int | 必填 | - | 访问外部 API 服务端口 |
| `domain` | string | 必填 | - | 访访问外部 API 时域名 |
`apiKey`的配置字段说明如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|-------------------|---------|------------|--------|-------------------------------------------------------------------------------|
| `in` | string | 非必填 | header | 在访问外部 API 服务时进行认证的令牌是放在 header 中还是放在 query 中,默认是 header。
| `name` | string | 非必填 | - | 用于在访问外部 API 服务时进行认证的令牌的名称。 |
| `value` | string | 非必填 | - | 用于在访问外部 API 服务时进行认证的令牌的值。 |
`promptTemplate`的配置字段说明如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|-----------------|-----------|-----------|--------|--------------------------------------------|
| `language` | string | 非必填 | EN | Agent ReAct 模板的语言类型,包括 CH 和 EN 两种|
| `chTemplate` | object | 非必填 | - | Agent ReAct 中文模板 |
| `enTemplate` | object | 非必填 | - | Agent ReAct 英文模板 |
`chTemplate``enTemplate`的配置字段说明如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|-----------------|-----------|-----------|--------|---------------------------------------------|
| `question` | string | 非必填 | - | Agent ReAct 模板的 question 部分 |
| `thought1` | string | 非必填 | - | Agent ReAct 模板的 thought1 部分 |
| `actionInput` | string | 非必填 | - | Agent ReAct 模板的 actionInput 部分 |
| `observation` | string | 非必填 | - | Agent ReAct 模板的 observation 部分 |
| `thought2` | string | 非必填 | - | Agent ReAct 模板的 thought2 部分 |
| `finalAnswer` | string | 非必填 | - | Agent ReAct 模板的 finalAnswer 部分 |
| `begin` | string | 非必填 | - | Agent ReAct 模板的 begin 部分 |
## 用法示例
**配置信息**
```yaml
llm:
apiKey: xxxxxxxxxxxxxxxxxx
domain: dashscope.aliyuncs.com
serviceName: dashscope.dns
servicePort: 443
path: /compatible-mode/v1/chat/completions
model: qwen-max-0403
maxIterations: 2
promptTemplate:
language: CH
apis:
- apiProvider:
domain: restapi.amap.com
serviceName: geo.dns
servicePort: 80
apiKey:
in: query
name: key
value: xxxxxxxxxxxxxxx
api: |
openapi: 3.1.0
info:
title: 高德地图
description: 获取 POI 的相关信息
version: v1.0.0
servers:
- url: https://restapi.amap.com
paths:
/v5/place/text:
get:
description: 根据POI名称获得POI的经纬度坐标
operationId: get_location_coordinate
parameters:
- name: keywords
in: query
description: POI名称必须是中文
required: true
schema:
type: string
- name: region
in: query
description: POI所在的区域名必须是中文
required: true
schema:
type: string
deprecated: false
/v5/place/around:
get:
description: 搜索给定坐标附近的POI
operationId: search_nearby_pois
parameters:
- name: keywords
in: query
description: 目标POI的关键字
required: true
schema:
type: string
- name: location
in: query
description: 中心点的经度和纬度,用逗号隔开
required: true
schema:
type: string
deprecated: false
components:
schemas: {}
- apiProvider:
domain: api.seniverse.com
serviceName: seniverse.dns
servicePort: 80
apiKey:
in: query
name: key
value: xxxxxxxxxxxxxxx
api: |
openapi: 3.1.0
info:
title: 心知天气
description: 获取 天气预办相关信息
version: v1.0.0
servers:
- url: https://api.seniverse.com
paths:
/v3/weather/now.json:
get:
description: 获取指定城市的天气实况
operationId: get_weather_now
parameters:
- name: location
in: query
description: 所查询的城市
required: true
schema:
type: string
- name: language
in: query
description: 返回天气查询结果所使用的语言
required: true
schema:
type: string
default: zh-Hans
enum:
- zh-Hans
- en
- ja
- name: unit
in: query
description: 表示温度的的单位,有摄氏度和华氏度两种
required: true
schema:
type: string
default: c
enum:
- c
- f
deprecated: false
components:
schemas: {}
- apiProvider:
apiKey:
in: "header"
name: "DeepL-Auth-Key"
value: "73xxxxxxxxxxxxxxx:fx"
domain: "api-free.deepl.com"
serviceName: "deepl.dns"
servicePort: 443
api: |
openapi: 3.1.0
info:
title: DeepL API Documentation
description: The DeepL API provides programmatic access to DeepLs machine translation technology.
version: v1.0.0
servers:
- url: https://api-free.deepl.com/v2
paths:
/translate:
post:
summary: Request Translation
operationId: translateText
requestBody:
required: true
content:
application/json:
schema:
type: object
required:
- text
- target_lang
properties:
text:
description: |
Text to be translated. Only UTF-8-encoded plain text is supported. The parameter may be specified
up to 50 times in a single request. Translations are returned in the same order as they are requested.
type: array
maxItems: 50
items:
type: string
example: Hello, World!
target_lang:
description: The language into which the text should be translated.
type: string
enum:
- BG
- CS
- DA
- DE
- EL
- EN-GB
- EN-US
- ES
- ET
- FI
- FR
- HU
- ID
- IT
- JA
- KO
- LT
- LV
- NB
- NL
- PL
- PT-BR
- PT-PT
- RO
- RU
- SK
- SL
- SV
- TR
- UK
- ZH
- ZH-HANS
example: DE
components:
schemas: {}
```
本示例配置了三个服务演示了get与post两种类型的工具。其中get类型的工具包括高德地图与心知天气post类型的工具是deepl翻译。三个服务都需要现在Higress的服务中以DNS域名的方式配置好并确保健康。
高德地图提供了两个工具分别是获取指定地点的坐标以及搜索坐标附近的感兴趣的地点。文档https://lbs.amap.com/api/webservice/guide/api-advanced/newpoisearch
心知天气提供了一个工具用于获取指定城市的实时天气情况支持中文英文日语返回以及摄氏度和华氏度的表示。文档https://seniverse.yuque.com/hyper_data/api_v3/nyiu3t
deepl提供了一个工具用于翻译给定的句子支持多语言。。文档https://developers.deepl.com/docs/v/zh/api-reference/translate?fallback=true
以下为测试用例为了效果的稳定性建议保持大模型版本的稳定本例子中使用的qwen-max-0403
**请求示例**
```shell
curl 'http://<这里换成网关公网IP>/api/openai/v1/chat/completions' \
-H 'Accept: application/json, text/event-stream' \
-H 'Content-Type: application/json' \
--data-raw '{"model":"qwen","frequency_penalty":0,"max_tokens":800,"stream":false,"messages":[{"role":"user","content":"我想在济南市鑫盛大厦附近喝咖啡,给我推荐几个"}],"presence_penalty":0,"temperature":0,"top_p":0}'
```
**响应示例**
```json
{"id":"139487e7-96a0-9b13-91b4-290fb79ac992","choices":[{"index":0,"message":{"role":"assistant","content":" 在济南市鑫盛大厦附近,您可以选择以下咖啡店:\n1. luckin coffee 瑞幸咖啡(鑫盛大厦店)位于新泺大街1299号鑫盛大厦2号楼大堂\n2. 三庆齐盛广场挪瓦咖啡(三庆·齐盛广场店)位于新泺大街与颖秀路交叉口西南60米\n3. luckin coffee 瑞幸咖啡(三庆·齐盛广场店)位于颖秀路1267号\n4. 库迪咖啡(齐鲁软件园店)位于新泺大街三庆齐盛广场4号楼底商\n5. 库迪咖啡(美莲广场店)位于高新区新泺大街1166号美莲广场L117号以及其他一些选项。希望这些建议对您有所帮助"},"finish_reason":"stop"}],"created":1723172296,"model":"qwen-max-0403","object":"chat.completion","usage":{"prompt_tokens":886,"completion_tokens":50,"total_tokens":936}}
```
**请求示例**
```shell
curl 'http://<这里换成网关公网IP>/api/openai/v1/chat/completions' \
-H 'Accept: application/json, text/event-stream' \
-H 'Content-Type: application/json' \
--data-raw '{"model":"qwen","frequency_penalty":0,"max_tokens":800,"stream":false,"messages":[{"role":"user","content":"济南市现在的天气情况如何?"}],"presence_penalty":0,"temperature":0,"top_p":0}'
```
**响应示例**
```json
{"id":"ebd6ea91-8e38-9e14-9a5b-90178d2edea4","choices":[{"index":0,"message":{"role":"assistant","content":" 济南市现在的天气状况为阴天温度为31℃。此信息最后更新于2024年8月9日15时12分北京时间。"},"finish_reason":"stop"}],"created":1723187991,"model":"qwen-max-0403","object":"chat.completion","usage":{"prompt_tokens":890,"completion_tokens":56,"total_tokens":946}}
```
**请求示例**
```shell
curl 'http://<这里换成网关公网IP>/api/openai/v1/chat/completions' \
-H 'Accept: application/json, text/event-stream' \
-H 'Content-Type: application/json' \
--data-raw '{"model":"qwen","frequency_penalty":0,"max_tokens":800,"stream":false,"messages":[{"role":"user","content":"济南市现在的天气情况如何?用华氏度表示,用日语回答"}],"presence_penalty":0,"temperature":0,"top_p":0}'
```
**响应示例**
```json
{"id":"ebd6ea91-8e38-9e14-9a5b-90178d2edea4","choices":[{"index":0,"message":{"role":"assistant","content":" 济南市の現在の天気は雨曇りで、気温は88°Fです。この情報は2024年8月9日15時12分東京時間に更新されました。"},"finish_reason":"stop"}],"created":1723187991,"model":"qwen-max-0403","object":"chat.completion","usage":{"prompt_tokens":890,"completion_tokens":56,"total_tokens":946}}
```
**请求示例**
```shell
curl 'http://<这里换成网关公网IP>/api/openai/v1/chat/completions' \
-H 'Accept: application/json, text/event-stream' \
-H 'Content-Type: application/json' \
--data-raw '{"model":"qwen","frequency_penalty":0,"max_tokens":800,"stream":false,"messages":[{"role":"user","content":"帮我用德语翻译以下句子:九头蛇万岁!"}],"presence_penalty":0,"temperature":0,"top_p":0}'
```
**响应示例**
```json
{"id":"65dcf12c-61ff-9e68-bffa-44fc9e6070d5","choices":[{"index":0,"message":{"role":"assistant","content":" “九头蛇万岁!”的德语翻译为“Hoch lebe Hydra!”。"},"finish_reason":"stop"}],"created":1724043865,"model":"qwen-max-0403","object":"chat.completion","usage":{"prompt_tokens":908,"completion_tokens":52,"total_tokens":960}}
```

View File

@@ -0,0 +1,424 @@
package main
import (
"encoding/json"
"errors"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
"gopkg.in/yaml.v2"
)
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type Request struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
FrequencyPenalty float64 `json:"frequency_penalty"`
PresencePenalty float64 `json:"presence_penalty"`
Stream bool `json:"stream"`
Temperature float64 `json:"temperature"`
Topp int32 `json:"top_p"`
}
type Choice struct {
Index int `json:"index"`
Message Message `json:"message"`
FinishReason string `json:"finish_reason"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type Response struct {
ID string `json:"id"`
Choices []Choice `json:"choices"`
Created int64 `json:"created"`
Model string `json:"model"`
Object string `json:"object"`
Usage Usage `json:"usage"`
}
// 用于存放拆解出来的工具相关信息
type Tool_Param struct {
ToolName string `yaml:"toolName"`
Path string `yaml:"path"`
Method string `yaml:"method"`
ParamName []string `yaml:"paramName"`
Parameter string `yaml:"parameter"`
Description string `yaml:"description"`
}
// 用于存放拆解出来的api相关信息
type APIParam struct {
APIKey APIKey `yaml:"apiKey"`
URL string `yaml:"url"`
Tool_Param []Tool_Param `yaml:"tool_Param"`
}
type Info struct {
Title string `yaml:"title"`
Description string `yaml:"description"`
Version string `yaml:"version"`
}
type Server struct {
URL string `yaml:"url"`
}
// 给OpenAPI的get方法用的
type Parameter struct {
Name string `yaml:"name"`
In string `yaml:"in"`
Description string `yaml:"description"`
Required bool `yaml:"required"`
Schema struct {
Type string `yaml:"type"`
Default string `yaml:"default"`
Enum []string `yaml:"enum"`
} `yaml:"schema"`
}
type Items struct {
Type string `yaml:"type"`
Example string `yaml:"example"`
}
type Property struct {
Description string `yaml:"description"`
Type string `yaml:"type"`
Enum []string `yaml:"enum,omitempty"`
Items *Items `yaml:"items,omitempty"`
MaxItems int `yaml:"maxItems,omitempty"`
Example string `yaml:"example,omitempty"`
}
type Schema struct {
Type string `yaml:"type"`
Required []string `yaml:"required"`
Properties map[string]Property `yaml:"properties"`
}
type MediaType struct {
Schema Schema `yaml:"schema"`
}
// 给OpenAPI的post方法用的
type RequestBody struct {
Required bool `yaml:"required"`
Content map[string]MediaType `yaml:"content"`
}
type PathItem struct {
Description string `yaml:"description"`
Summary string `yaml:"summary"`
OperationID string `yaml:"operationId"`
RequestBody RequestBody `yaml:"requestBody"`
Parameters []Parameter `yaml:"parameters"`
Deprecated bool `yaml:"deprecated"`
}
type Paths map[string]map[string]PathItem
type Components struct {
Schemas map[string]interface{} `yaml:"schemas"`
}
type API struct {
OpenAPI string `yaml:"openapi"`
Info Info `yaml:"info"`
Servers []Server `yaml:"servers"`
Paths Paths `yaml:"paths"`
Components Components `yaml:"components"`
}
type APIKey struct {
In string `yaml:"in" json:"in"`
Name string `yaml:"name" json:"name"`
Value string `yaml:"value" json:"value"`
}
type APIProvider struct {
// @Title zh-CN 服务名称
// @Description zh-CN 带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local
ServiceName string `required:"true" yaml:"serviceName" json:"serviceName"`
// @Title zh-CN 服务端口
// @Description zh-CN 服务端口
ServicePort int64 `required:"true" yaml:"servicePort" json:"servicePort"`
// @Title zh-CN 服务域名
// @Description zh-CN 服务域名,例如 restapi.amap.com
Domin string `required:"true" yaml:"domain" json:"domain"`
// @Title zh-CN 通义千问大模型服务的key
// @Description zh-CN 通义千问大模型服务的key
APIKey APIKey `required:"true" yaml:"apiKey" json:"apiKey"`
}
type APIs struct {
APIProvider APIProvider `required:"true" yaml:"apiProvider" json:"apiProvider"`
API string `required:"true" yaml:"api" json:"api"`
}
type Template struct {
Question string `yaml:"question" json:"question"`
Thought1 string `yaml:"thought1" json:"thought1"`
ActionInput string `yaml:"actionInput" json:"actionInput"`
Observation string `yaml:"observation" json:"observation"`
Thought2 string `yaml:"thought2" json:"thought2"`
FinalAnswer string `yaml:"finalAnswer" json:"finalAnswer"`
Begin string `yaml:"begin" json:"begin"`
}
type PromptTemplate struct {
Language string `required:"true" yaml:"language" json:"language"`
CHTemplate Template `yaml:"chTemplate" json:"chTemplate"`
ENTemplate Template `yaml:"enTemplate" json:"enTemplate"`
}
type LLMInfo struct {
// @Title zh-CN 大模型服务名称
// @Description zh-CN 带服务类型的完整 FQDN 名称
ServiceName string `required:"true" yaml:"serviceName" json:"serviceName"`
// @Title zh-CN 大模型服务端口
// @Description zh-CN 服务端口
ServicePort int64 `required:"true" yaml:"servicePort" json:"servicePort"`
// @Title zh-CN 大模型服务域名
// @Description zh-CN 大模型服务域名,例如 dashscope.aliyuncs.com
Domin string `required:"true" yaml:"domin" json:"domin"`
// @Title zh-CN 大模型服务的key
// @Description zh-CN 大模型服务的key
APIKey string `required:"true" yaml:"apiKey" json:"apiKey"`
// @Title zh-CN 大模型服务的请求路径
// @Description zh-CN 大模型服务的请求路径,如"/compatible-mode/v1/chat/completions"
Path string `required:"true" yaml:"path" json:"path"`
// @Title zh-CN 大模型服务的模型名称
// @Description zh-CN 大模型服务的模型名称,如"qwen-max-0403"
Model string `required:"true" yaml:"model" json:"model"`
// @Title zh-CN 结束执行循环前的最大步数
// @Description zh-CN 结束执行循环前的最大步数比如2设置为0可能会无限循环直到超时退出默认15
MaxIterations int64 `yaml:"maxIterations" json:"maxIterations"`
// @Title zh-CN 每一次请求大模型的超时时间
// @Description zh-CN 每一次请求大模型的超时时间单位毫秒默认50000
MaxExecutionTime int64 `yaml:"maxExecutionTime" json:"maxExecutionTime"`
// @Title zh-CN
// @Description zh-CN 每一次请求大模型的输出token限制默认1000
MaxTokens int64 `yaml:"maxToken" json:"maxTokens"`
}
type PluginConfig struct {
// @Title zh-CN 返回 HTTP 响应的模版
// @Description zh-CN 用 %s 标记需要被 cache value 替换的部分
ReturnResponseTemplate string `required:"true" yaml:"returnResponseTemplate" json:"returnResponseTemplate"`
// @Title zh-CN 工具服务商以及工具信息
// @Description zh-CN 用于存储工具服务商以及工具信息
APIs []APIs `required:"true" yaml:"apis" json:"apis"`
APIClient []wrapper.HttpClient `yaml:"-" json:"-"`
// @Title zh-CN llm信息
// @Description zh-CN 用于存储llm使用信息
LLMInfo LLMInfo `required:"true" yaml:"llm" json:"llm"`
LLMClient wrapper.HttpClient `yaml:"-" json:"-"`
APIParam []APIParam `yaml:"-" json:"-"`
PromptTemplate PromptTemplate `yaml:"promptTemplate" json:"promptTemplate"`
}
func initResponsePromptTpl(gjson gjson.Result, c *PluginConfig) {
//设置回复模板
c.ReturnResponseTemplate = gjson.Get("returnResponseTemplate").String()
if c.ReturnResponseTemplate == "" {
c.ReturnResponseTemplate = `{"id":"error","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
}
}
func initAPIs(gjson gjson.Result, c *PluginConfig) error {
//从插件配置中获取apis信息
apis := gjson.Get("apis")
if !apis.Exists() {
return errors.New("apis is required")
}
if len(apis.Array()) == 0 {
return errors.New("apis cannot be empty")
}
for _, item := range apis.Array() {
serviceName := item.Get("apiProvider.serviceName")
if !serviceName.Exists() || serviceName.String() == "" {
return errors.New("apiProvider serviceName is required")
}
servicePort := item.Get("apiProvider.servicePort")
if !servicePort.Exists() || servicePort.Int() == 0 {
return errors.New("apiProvider servicePort is required")
}
domain := item.Get("apiProvider.domain")
if !domain.Exists() || domain.String() == "" {
return errors.New("apiProvider domain is required")
}
apiKeyIn := item.Get("apiProvider.apiKey.in").String()
if apiKeyIn != "query" {
apiKeyIn = "header"
}
apiKeyName := item.Get("apiProvider.apiKey.name")
apiKeyValue := item.Get("apiProvider.apiKey.value")
//根据多个toolsClientInfo的信息分别初始化toolsClient
apiClient := wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: serviceName.String(),
Port: servicePort.Int(),
Host: domain.String(),
})
c.APIClient = append(c.APIClient, apiClient)
api := item.Get("api")
if !api.Exists() || api.String() == "" {
return errors.New("api is required")
}
var apiStruct API
err := yaml.Unmarshal([]byte(api.String()), &apiStruct)
if err != nil {
return err
}
var allTool_param []Tool_Param
//拆除服务下面的每个api的path
for path, pathmap := range apiStruct.Paths {
//拆解出每个api对应的参数
for method, submap := range pathmap {
//把参数列表存起来
var param Tool_Param
param.Path = path
param.ToolName = submap.OperationID
if method == "get" {
param.Method = "GET"
paramName := make([]string, 0)
for _, parammeter := range submap.Parameters {
paramName = append(paramName, parammeter.Name)
}
param.ParamName = paramName
out, _ := json.Marshal(submap.Parameters)
param.Parameter = string(out)
param.Description = submap.Description
} else if method == "post" {
param.Method = "POST"
schema := submap.RequestBody.Content["application/json"].Schema
param.ParamName = schema.Required
param.Description = submap.Summary
out, _ := json.Marshal(schema.Properties)
param.Parameter = string(out)
}
allTool_param = append(allTool_param, param)
}
}
apiParam := APIParam{
APIKey: APIKey{In: apiKeyIn, Name: apiKeyName.String(), Value: apiKeyValue.String()},
URL: apiStruct.Servers[0].URL,
Tool_Param: allTool_param,
}
c.APIParam = append(c.APIParam, apiParam)
}
return nil
}
func initReActPromptTpl(gjson gjson.Result, c *PluginConfig) {
c.PromptTemplate.Language = gjson.Get("promptTemplate.language").String()
if c.PromptTemplate.Language != "EN" && c.PromptTemplate.Language != "CH" {
c.PromptTemplate.Language = "EN"
}
if c.PromptTemplate.Language == "EN" {
c.PromptTemplate.ENTemplate.Question = gjson.Get("promptTemplate.enTemplate.question").String()
if c.PromptTemplate.ENTemplate.Question == "" {
c.PromptTemplate.ENTemplate.Question = "the input question you must answer"
}
c.PromptTemplate.ENTemplate.Thought1 = gjson.Get("promptTemplate.enTemplate.thought1").String()
if c.PromptTemplate.ENTemplate.Thought1 == "" {
c.PromptTemplate.ENTemplate.Thought1 = "you should always think about what to do"
}
c.PromptTemplate.ENTemplate.ActionInput = gjson.Get("promptTemplate.enTemplate.actionInput").String()
if c.PromptTemplate.ENTemplate.ActionInput == "" {
c.PromptTemplate.ENTemplate.ActionInput = "the input to the action"
}
c.PromptTemplate.ENTemplate.Observation = gjson.Get("promptTemplate.enTemplate.observation").String()
if c.PromptTemplate.ENTemplate.Observation == "" {
c.PromptTemplate.ENTemplate.Observation = "the result of the action"
}
c.PromptTemplate.ENTemplate.Thought1 = gjson.Get("promptTemplate.enTemplate.thought2").String()
if c.PromptTemplate.ENTemplate.Thought1 == "" {
c.PromptTemplate.ENTemplate.Thought1 = "I now know the final answer"
}
c.PromptTemplate.ENTemplate.FinalAnswer = gjson.Get("promptTemplate.enTemplate.finalAnswer").String()
if c.PromptTemplate.ENTemplate.FinalAnswer == "" {
c.PromptTemplate.ENTemplate.FinalAnswer = "the final answer to the original input question, please give the most direct answer directly in Chinese, not English, and do not add extra content."
}
c.PromptTemplate.ENTemplate.Begin = gjson.Get("promptTemplate.enTemplate.begin").String()
if c.PromptTemplate.ENTemplate.Begin == "" {
c.PromptTemplate.ENTemplate.Begin = "Begin! Remember to speak as a pirate when giving your final answer. Use lots of \"Arg\"s"
}
} else if c.PromptTemplate.Language == "CH" {
c.PromptTemplate.CHTemplate.Question = gjson.Get("promptTemplate.chTemplate.question").String()
if c.PromptTemplate.CHTemplate.Question == "" {
c.PromptTemplate.CHTemplate.Question = "你需要回答的输入问题"
}
c.PromptTemplate.CHTemplate.Thought1 = gjson.Get("promptTemplate.chTemplate.thought1").String()
if c.PromptTemplate.CHTemplate.Thought1 == "" {
c.PromptTemplate.CHTemplate.Thought1 = "你应该总是思考该做什么"
}
c.PromptTemplate.CHTemplate.ActionInput = gjson.Get("promptTemplate.chTemplate.actionInput").String()
if c.PromptTemplate.CHTemplate.ActionInput == "" {
c.PromptTemplate.CHTemplate.ActionInput = "行动的输入必须出现在Action后"
}
c.PromptTemplate.CHTemplate.Observation = gjson.Get("promptTemplate.chTemplate.observation").String()
if c.PromptTemplate.CHTemplate.Observation == "" {
c.PromptTemplate.CHTemplate.Observation = "行动的结果"
}
c.PromptTemplate.CHTemplate.Thought1 = gjson.Get("promptTemplate.chTemplate.thought2").String()
if c.PromptTemplate.CHTemplate.Thought1 == "" {
c.PromptTemplate.CHTemplate.Thought1 = "我现在知道最终答案"
}
c.PromptTemplate.CHTemplate.FinalAnswer = gjson.Get("promptTemplate.chTemplate.finalAnswer").String()
if c.PromptTemplate.CHTemplate.FinalAnswer == "" {
c.PromptTemplate.CHTemplate.FinalAnswer = "对原始输入问题的最终答案"
}
c.PromptTemplate.CHTemplate.Begin = gjson.Get("promptTemplate.chTemplate.begin").String()
if c.PromptTemplate.CHTemplate.Begin == "" {
c.PromptTemplate.CHTemplate.Begin = "再次重申,不要修改以上模板的字段名称,开始吧!"
}
}
}
func initLLMClient(gjson gjson.Result, c *PluginConfig) {
c.LLMInfo.APIKey = gjson.Get("llm.apiKey").String()
c.LLMInfo.ServiceName = gjson.Get("llm.serviceName").String()
c.LLMInfo.ServicePort = gjson.Get("llm.servicePort").Int()
c.LLMInfo.Domin = gjson.Get("llm.domain").String()
c.LLMInfo.Path = gjson.Get("llm.path").String()
c.LLMInfo.Model = gjson.Get("llm.model").String()
c.LLMInfo.MaxIterations = gjson.Get("llm.maxIterations").Int()
if c.LLMInfo.MaxIterations == 0 {
c.LLMInfo.MaxIterations = 15
}
c.LLMInfo.MaxExecutionTime = gjson.Get("llm.maxExecutionTime").Int()
if c.LLMInfo.MaxExecutionTime == 0 {
c.LLMInfo.MaxExecutionTime = 50000
}
c.LLMInfo.MaxTokens = gjson.Get("llm.maxTokens").Int()
if c.LLMInfo.MaxTokens == 0 {
c.LLMInfo.MaxTokens = 1000
}
c.LLMClient = wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: c.LLMInfo.ServiceName,
Port: c.LLMInfo.ServicePort,
Host: c.LLMInfo.Domin,
})
}

View File

@@ -0,0 +1,46 @@
package dashscope
var MessageStore ChatMessages
func init() {
MessageStore = make(ChatMessages, 0)
MessageStore.Clear() //清理和初始化
}
type ChatMessages []Message
// 枚举出角色
const (
RoleUser = "user"
RoleAssistant = "assistant"
RoleSystem = "system"
)
func (cm *ChatMessages) Clear() {
*cm = make([]Message, 0) //重新初始化
}
// 添加角色和对应的prompt
func (cm *ChatMessages) AddFor(msg string, role string) {
*cm = append(*cm, Message{
Role: role,
Content: msg,
})
}
// 添加Assistant角色的prompt
func (cm *ChatMessages) AddForAssistant(msg string) {
cm.AddFor(msg, RoleAssistant)
}
// 添加System角色的prompt
func (cm *ChatMessages) AddForSystem(msg string) {
cm.AddFor(msg, RoleSystem)
}
// 添加User角色的prompt
func (cm *ChatMessages) AddForUser(msg string) {
cm.AddFor(msg, RoleUser)
}

View File

@@ -0,0 +1,70 @@
package dashscope
// DashScope embedding service: Request
type Request struct {
Model string `json:"model"`
Input Input `json:"input"`
Parameter Parameter `json:"parameters"`
}
type Input struct {
Texts []string `json:"texts"`
}
type Parameter struct {
TextType string `json:"text_type"`
}
// DashScope embedding service: Response
type Response struct {
Output Output `json:"output"`
Usage Usage `json:"usage"`
RequestID string `json:"request_id"`
}
type Output struct {
Embeddings []Embedding `json:"embeddings"`
}
type Embedding struct {
Embedding []float32 `json:"embedding"`
TextIndex int32 `json:"text_index"`
}
type Usage struct {
TotalTokens int32 `json:"total_tokens"`
}
// completion
type Completion struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
MaxTokens int64 `json:"max_tokens"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type CompletionResponse struct {
Choices []Choice `json:"choices"`
Object string `json:"object"`
Usage CompletionUsage `json:"usage"`
Created string `json:"created"`
SystemFingerprint string `json:"system_fingerprint"`
Model string `json:"model"`
ID string `json:"id"`
}
type Choice struct {
Message Message `json:"message"`
FinishReason string `json:"finish_reason"`
Index int `json:"index"`
}
type CompletionUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}

View File

@@ -0,0 +1,19 @@
module github.com/alibaba/higress/plugins/wasm-go/extensions/ai-agent
go 1.19
require (
github.com/alibaba/higress/plugins/wasm-go v1.4.2
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
github.com/tidwall/gjson v1.17.3
gopkg.in/yaml.v2 v2.4.0
)
require (
github.com/google/uuid v1.3.0 // indirect
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
github.com/magefile/mage v1.14.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/resp v0.1.1 // indirect
)

View File

@@ -0,0 +1,26 @@
github.com/alibaba/higress/plugins/wasm-go v1.4.2 h1:gH7OIGXm4wtW5Vo7L2deMPqF7OVWNESDHv1CaaTGu6s=
github.com/alibaba/higress/plugins/wasm-go v1.4.2/go.mod h1:359don/ahMxpfeLMzr29Cjwcu8IywTTDUzWlBPRNLHw=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -0,0 +1,372 @@
package main
import (
"encoding/json"
"fmt"
"net/http"
"regexp"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-agent/dashscope"
prompttpl "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-agent/promptTpl"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
)
// 用于统计函数的递归调用次数
const ToolCallsCount = "ToolCallsCount"
// react的正则规则
const ActionPattern = `Action:\s*(.*?)[.\n]`
const ActionInputPattern = `Action Input:\s*(.*)`
const FinalAnswerPattern = `Final Answer:(.*)`
func main() {
wrapper.SetCtx(
"ai-agent",
wrapper.ParseConfigBy(parseConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
wrapper.ProcessResponseBodyBy(onHttpResponseBody),
)
}
func parseConfig(gjson gjson.Result, c *PluginConfig, log wrapper.Log) error {
initResponsePromptTpl(gjson, c)
err := initAPIs(gjson, c)
if err != nil {
return err
}
initReActPromptTpl(gjson, c)
initLLMClient(gjson, c)
return nil
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
return types.ActionContinue
}
func firstReq(config PluginConfig, prompt string, rawRequest Request, log wrapper.Log) types.Action {
log.Debugf("[onHttpRequestBody] firstreq:%s", prompt)
var userMessage Message
userMessage.Role = "user"
userMessage.Content = prompt
newMessages := []Message{userMessage}
rawRequest.Messages = newMessages
//replace old message and resume request qwen
newbody, err := json.Marshal(rawRequest)
if err != nil {
return types.ActionContinue
} else {
log.Debugf("[onHttpRequestBody] newRequestBody: ", string(newbody))
err := proxywasm.ReplaceHttpRequestBody(newbody)
if err != nil {
log.Debug("替换失败")
proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnResponseTemplate, "替换失败"+err.Error())), -1)
}
log.Debug("[onHttpRequestBody] request替换成功")
return types.ActionContinue
}
}
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
log.Debug("onHttpRequestBody start")
defer log.Debug("onHttpRequestBody end")
//拿到请求
var rawRequest Request
err := json.Unmarshal(body, &rawRequest)
if err != nil {
log.Debugf("[onHttpRequestBody] body json umarshal err: ", err.Error())
return types.ActionContinue
}
log.Debugf("onHttpRequestBody rawRequest: %v", rawRequest)
//获取用户query
var query string
messageLength := len(rawRequest.Messages)
log.Debugf("[onHttpRequestBody] messageLength: %s\n", messageLength)
if messageLength > 0 {
query = rawRequest.Messages[messageLength-1].Content
log.Debugf("[onHttpRequestBody] query: %s\n", query)
} else {
return types.ActionContinue
}
if query == "" {
log.Debug("parse query from request body failed")
return types.ActionContinue
}
//拼装agent prompt模板
tool_desc := make([]string, 0)
tool_names := make([]string, 0)
for _, apiParam := range config.APIParam {
for _, tool_param := range apiParam.Tool_Param {
tool_desc = append(tool_desc, fmt.Sprintf(prompttpl.TOOL_DESC, tool_param.ToolName, tool_param.Description, tool_param.Description, tool_param.Description, tool_param.Parameter), "\n")
tool_names = append(tool_names, tool_param.ToolName)
}
}
var prompt string
if config.PromptTemplate.Language == "CH" {
prompt = fmt.Sprintf(prompttpl.CH_Template,
tool_desc,
config.PromptTemplate.CHTemplate.Question,
config.PromptTemplate.CHTemplate.Thought1,
tool_names,
config.PromptTemplate.CHTemplate.ActionInput,
config.PromptTemplate.CHTemplate.Observation,
config.PromptTemplate.CHTemplate.Thought2,
config.PromptTemplate.CHTemplate.FinalAnswer,
config.PromptTemplate.CHTemplate.Begin,
query)
} else {
prompt = fmt.Sprintf(prompttpl.EN_Template,
tool_desc,
config.PromptTemplate.ENTemplate.Question,
config.PromptTemplate.ENTemplate.Thought1,
tool_names,
config.PromptTemplate.ENTemplate.ActionInput,
config.PromptTemplate.ENTemplate.Observation,
config.PromptTemplate.ENTemplate.Thought2,
config.PromptTemplate.ENTemplate.FinalAnswer,
config.PromptTemplate.ENTemplate.Begin,
query)
}
ctx.SetContext(ToolCallsCount, 0)
//清理历史对话记录
dashscope.MessageStore.Clear()
//将请求加入到历史对话存储器中
dashscope.MessageStore.AddForUser(prompt)
//开始第一次请求
ret := firstReq(config, prompt, rawRequest, log)
return ret
}
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
log.Debug("onHttpResponseHeaders start")
defer log.Debug("onHttpResponseHeaders end")
return types.ActionContinue
}
func toolsCallResult(ctx wrapper.HttpContext, config PluginConfig, content string, rawResponse Response, log wrapper.Log, statusCode int, responseBody []byte) {
if statusCode != http.StatusOK {
log.Debugf("statusCode: %d\n", statusCode)
}
log.Info("========函数返回结果========")
log.Infof(string(responseBody))
observation := "Observation: " + string(responseBody)
dashscope.MessageStore.AddForUser(observation)
completion := dashscope.Completion{
Model: config.LLMInfo.Model,
Messages: dashscope.MessageStore,
MaxTokens: config.LLMInfo.MaxTokens,
}
headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.LLMInfo.APIKey}}
completionSerialized, _ := json.Marshal(completion)
err := config.LLMClient.Post(
config.LLMInfo.Path,
headers,
completionSerialized,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
//得到gpt的返回结果
var responseCompletion dashscope.CompletionResponse
_ = json.Unmarshal(responseBody, &responseCompletion)
log.Infof("[toolsCall] content: %s\n", responseCompletion.Choices[0].Message.Content)
if responseCompletion.Choices[0].Message.Content != "" {
retType := toolsCall(ctx, config, responseCompletion.Choices[0].Message.Content, rawResponse, log)
if retType == types.ActionContinue {
//得到了Final Answer
var assistantMessage Message
assistantMessage.Role = "assistant"
startIndex := strings.Index(responseCompletion.Choices[0].Message.Content, "Final Answer:")
if startIndex != -1 {
startIndex += len("Final Answer:") // 移动到"Final Answer:"之后的位置
extractedText := responseCompletion.Choices[0].Message.Content[startIndex:]
assistantMessage.Content = extractedText
}
rawResponse.Choices[0].Message = assistantMessage
newbody, err := json.Marshal(rawResponse)
if err != nil {
proxywasm.ResumeHttpResponse()
return
} else {
log.Infof("[onHttpResponseBody] newResponseBody: ", string(newbody))
proxywasm.ReplaceHttpResponseBody(newbody)
log.Debug("[onHttpResponseBody] response替换成功")
proxywasm.ResumeHttpResponse()
}
}
} else {
proxywasm.ResumeHttpRequest()
}
}, uint32(config.LLMInfo.MaxExecutionTime))
if err != nil {
log.Debugf("[onHttpRequestBody] completion err: %s", err.Error())
proxywasm.ResumeHttpRequest()
}
}
func toolsCall(ctx wrapper.HttpContext, config PluginConfig, content string, rawResponse Response, log wrapper.Log) types.Action {
dashscope.MessageStore.AddForAssistant(content)
//得到最终答案
regexPattern := regexp.MustCompile(FinalAnswerPattern)
finalAnswer := regexPattern.FindStringSubmatch(content)
if len(finalAnswer) > 1 {
return types.ActionContinue
}
count := ctx.GetContext(ToolCallsCount).(int)
count++
log.Debugf("toolCallsCount:%d, config.LLMInfo.MaxIterations=%d\n", count, config.LLMInfo.MaxIterations)
//函数递归调用次数,达到了预设的循环次数,强制结束
if int64(count) > config.LLMInfo.MaxIterations {
ctx.SetContext(ToolCallsCount, 0)
return types.ActionContinue
} else {
ctx.SetContext(ToolCallsCount, count)
}
//没得到最终答案
regexAction := regexp.MustCompile(ActionPattern)
regexActionInput := regexp.MustCompile(ActionInputPattern)
action := regexAction.FindStringSubmatch(content)
actionInput := regexActionInput.FindStringSubmatch(content)
if len(action) > 1 && len(actionInput) > 1 {
var url string
var headers [][2]string
var apiClient wrapper.HttpClient
var method string
var reqBody []byte
var key string
for i, apiParam := range config.APIParam {
for _, tool_param := range apiParam.Tool_Param {
if action[1] == tool_param.ToolName {
log.Infof("calls %s\n", tool_param.ToolName)
log.Infof("actionInput[1]: %s", actionInput[1])
//将大模型需要的参数反序列化
var data map[string]interface{}
if err := json.Unmarshal([]byte(actionInput[1]), &data); err != nil {
log.Debugf("Error: %s\n", err.Error())
return types.ActionContinue
}
method = tool_param.Method
//key or header组装
if apiParam.APIKey.Name != "" {
if apiParam.APIKey.In == "query" { //query类型的key要放到url中
headers = nil
key = "?" + apiParam.APIKey.Name + "=" + apiParam.APIKey.Value
} else if apiParam.APIKey.In == "header" { //header类型的key放在header中
headers = [][2]string{{"Content-Type", "application/json"}, {"Authorization", apiParam.APIKey.Name + " " + apiParam.APIKey.Value}}
}
}
if method == "GET" {
//query组装
var args string
for i, param := range tool_param.ParamName { //从参数列表中取出参数
if i == 0 && apiParam.APIKey.In != "query" {
args = "?" + param + "=%s"
args = fmt.Sprintf(args, data[param])
} else {
args = args + "&" + param + "=%s"
args = fmt.Sprintf(args, data[param])
}
}
//url组装
url = apiParam.URL + tool_param.Path + key + args
} else if method == "POST" {
reqBody = nil
//json参数组装
jsonData, err := json.Marshal(data)
if err != nil {
log.Debugf("Error: %s\n", err.Error())
return types.ActionContinue
}
reqBody = jsonData
//url组装
url = apiParam.URL + tool_param.Path + key
}
log.Infof("url: %s\n", url)
apiClient = config.APIClient[i]
break
}
}
}
if apiClient != nil {
err := apiClient.Call(
method,
url,
headers,
reqBody,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
toolsCallResult(ctx, config, content, rawResponse, log, statusCode, responseBody)
}, 50000)
if err != nil {
log.Debugf("tool calls error: %s\n", err.Error())
proxywasm.ResumeHttpRequest()
}
} else {
return types.ActionContinue
}
}
return types.ActionPause
}
// 从response接收到firstreq的大模型返回
func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
log.Debugf("onHttpResponseBody start")
defer log.Debugf("onHttpResponseBody end")
//初始化接收gpt返回内容的结构体
var rawResponse Response
err := json.Unmarshal(body, &rawResponse)
if err != nil {
log.Debugf("[onHttpResponseBody] body to json err: %s", err.Error())
return types.ActionContinue
}
log.Infof("first content: %s\n", rawResponse.Choices[0].Message.Content)
//如果gpt返回的内容不是空的
if rawResponse.Choices[0].Message.Content != "" {
//进入agent的循环思考工具调用的过程中
return toolsCall(ctx, config, rawResponse.Choices[0].Message.Content, rawResponse, log)
} else {
return types.ActionContinue
}
}

View File

@@ -0,0 +1,93 @@
package prompttpl
// input param
// {name_for_model}
// {description_for_model}
// {description_for_model}
// {description_for_model}
// {parameters}
const TOOL_DESC = `
%s: Call this tool to interact with the %s API. What is the %s API useful for? %s
Parameters:
%s
Format the arguments as a JSON object.`
/*
Answer the following questions as best you can, but speaking as a pirate might speak. You have access to the following tools:
%s
Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of %s
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question, please give the most direct answer directly in Chinese, not English, and do not add extra content.
Begin! Remember to speak as a pirate when giving your final answer. Use lots of "Arg"s
Question: %s
*/
const EN_Template = `
Answer the following questions as best you can, but speaking as a pirate might speak. You have access to the following tools:
%s
Use the following format:
Question: %s
Thought: %s
Action: the action to take, should be one of %s
Action Input: %s
Observation: %s
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: %s
Final Answer: %s
%s
Question: %s
`
/*
尽你所能回答以下问题。你可以使用以下工具:
%s
请使用以下格式其中Action字段后必须跟着Action Input字段并且不要将Action Input替换成Input或者tool等字段不能出现格式以外的字段名每个字段在每个轮次只出现一次
Question: 你需要回答的输入问题
Thought: 你应该总是思考该做什么
Action: 要采取的动作,动作只能是%s中的一个 ,一定不要加入其它内容
Action Input: 行动的输入必须出现在Action后。
Observation: 行动的结果
...这个Thought/Action/Action Input/Observation可以重复N次
Thought: 我现在知道最终答案
Final Answer: 对原始输入问题的最终答案
再次重申,不要修改以上模板的字段名称,开始吧!
Question: %s
*/
const CH_Template = `
尽你所能回答以下问题。你可以使用以下工具:
%s
请使用以下格式其中Action字段后必须跟着Action Input字段并且不要将Action Input替换成Input或者tool等字段不能出现格式以外的字段名每个字段在每个轮次只出现一次
Question: %s
Thought: %s
Action: 要采取的动作,动作只能是%s中的一个 ,一定不要加入其它内容
Action Input: %s
Observation: %s
...这个Thought/Action/Action Input/Observation可以重复N次
Thought: %s
Final Answer: %s
%s
Question: %s
`

View File

@@ -32,3 +32,15 @@ redis:
serviceName: my-redis.dns
timeout: 2000
```
## 进阶用法
当前默认的缓存 key 是基于 GJSON PATH 的表达式:`messages.@reverse.0.content` 提取,含义是把 messages 数组反转后取第一项的 content
GJSON PATH 支持条件判断语法,例如希望取最后一个 role 为 user 的 content 作为 key可以写成 `messages.@reverse.#(role=="user").content`
如果希望将所有 role 为 user 的 content 拼成一个数组作为 key可以写成`messages.@reverse.#(role=="user")#.content`
还可以支持管道语法,例如希望取到数第二个 role 为 user 的 content 作为 key可以写成`messages.@reverse.#(role=="user")#.content|1`
更多用法可以参考[官方文档](https://github.com/tidwall/gjson/blob/master/SYNTAX.md),可以使用 [GJSON Playground](https://gjson.dev/) 进行语法测试。

View File

@@ -66,4 +66,70 @@ curl http://localhost/test \
}
]
}
```
```
# 基于geo-ip插件的能力扩展AI提示词装饰器插件携带用户地理位置信息
如果需要在LLM的请求前后加入用户地理位置信息请确保同时开启geo-ip插件和AI提示词装饰器插件。并且在相同的请求处理阶段里geo-ip插件的优先级必须高于AI提示词装饰器插件。首先geo-ip插件会根据用户ip计算出用户的地理位置信息然后通过请求属性传递给后续插件。比如在默认阶段里geo-ip插件的priority配置1000ai-prompt-decorator插件的priority配置500。
geo-ip插件配置示例
```yaml
ipProtocal: "ipv4"
```
AI提示词装饰器插件的配置示例如下
```yaml
prepend:
- role: system
content: "提问用户当前的地理位置信息是,国家:${geo-country},省份:${geo-province}, 城市:${geo-city}"
append:
- role: user
content: "每次回答完问题,尝试进行反问"
```
使用以上配置发起请求:
```bash
curl http://localhost/test \
-H "content-type: application/json" \
-H "x-forwarded-for: 87.254.207.100,4.5.6.7" \
-d '{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "今天天气怎么样?"
}
]
}'
```
经过插件处理后,实际请求为:
```bash
curl http://localhost/test \
-H "content-type: application/json" \
-H "x-forwarded-for: 87.254.207.100,4.5.6.7" \
-d '{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "system",
"content": "提问用户当前的地理位置信息是,国家:中国,省份:北京, 城市:北京"
},
{
"role": "user",
"content": "今天天气怎么样?"
},
{
"role": "user",
"content": "每次回答完问题,尝试进行反问"
}
]
}'
```

View File

@@ -2,6 +2,8 @@ package main
import (
"encoding/json"
"fmt"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
@@ -38,10 +40,42 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIPromptDecoratorConfi
return types.ActionContinue
}
func replaceVariable(variable string, entry *Message) (*Message, error) {
key := fmt.Sprintf("${%s}", variable)
if strings.Contains(entry.Content, key) {
value, err := proxywasm.GetProperty([]string{variable})
if err != nil {
return nil, err
}
entry.Content = strings.ReplaceAll(entry.Content, key, string(value))
}
return entry, nil
}
func decorateGeographicPrompt(entry *Message) (*Message, error) {
geoArr := []string{"geo-country", "geo-province", "geo-city", "geo-isp"}
var err error
for _, geo := range geoArr {
entry, err = replaceVariable(geo, entry)
if err != nil {
return nil, err
}
}
return entry, nil
}
func onHttpRequestBody(ctx wrapper.HttpContext, config AIPromptDecoratorConfig, body []byte, log wrapper.Log) types.Action {
messageJson := `{"messages":[]}`
for _, entry := range config.Prepend {
entry, err := decorateGeographicPrompt(&entry)
if err != nil {
log.Errorf("Failed to decorate geographic prompt in prepend, error: %v", err)
return types.ActionContinue
}
msg, err := json.Marshal(entry)
if err != nil {
log.Errorf("Failed to add prepend message, error: %v", err)
@@ -60,6 +94,12 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIPromptDecoratorConfig,
}
for _, entry := range config.Append {
entry, err := decorateGeographicPrompt(&entry)
if err != nil {
log.Errorf("Failed to decorate geographic prompt in append, error: %v", err)
return types.ActionContinue
}
msg, err := json.Marshal(entry)
if err != nil {
log.Errorf("Failed to add prepend message, error: %v", err)

View File

@@ -1,6 +1,6 @@
---
title: AI 代理
keywords: [ higress,ai,proxy,rag ]
keywords: [ AI网关, AI代理 ]
description: AI 代理插件配置参考
---
@@ -9,6 +9,13 @@ description: AI 代理插件配置参考
`AI 代理`插件实现了基于 OpenAI API 契约的 AI 代理功能。目前支持 OpenAI、Azure OpenAI、月之暗面Moonshot和通义千问等 AI
服务提供商。
> **注意:**
> 请求路径后缀匹配 `/v1/chat/completions` 时,对应文生文场景,会用 OpenAI 的文生文协议解析请求 Body再转换为对应 LLM 厂商的文生文协议
> 请求路径后缀匹配 `/v1/embeddings` 时,对应文本向量场景,会用 OpenAI 的文本向量协议解析请求 Body再转换为对应 LLM 厂商的文本向量协议
## 配置字段
### 基本配置
@@ -19,14 +26,15 @@ description: AI 代理插件配置参考
`provider`的配置字段说明如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
| -------------- | --------------- | -------- | ------ |-------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `type` | string | 必填 | - | AI 服务提供商名称 |
| `apiTokens` | array of string | 必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 |
| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000即 2 分钟 |
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
| -------------- | --------------- | -------- | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `type` | string | 必填 | - | AI 服务提供商名称 |
| `apiTokens` | array of string | 必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 |
| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000即 2 分钟 |
| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。<br/>1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;<br/>2. 支持使用 "*" 为键来配置通用兜底映射关系;<br/>3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 |
| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值openai默认值使用 OpenAI 的接口契约、original使用目标服务提供商的原始接口契约 |
| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值openai默认值使用 OpenAI 的接口契约、original使用目标服务提供商的原始接口契约 |
| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
`context`的配置字段说明如下:
@@ -36,11 +44,44 @@ description: AI 代理插件配置参考
| `serviceName` | string | 必填 | - | URL 所对应的 Higress 后端服务完整名称 |
| `servicePort` | number | 必填 | - | URL 所对应的 Higress 后端服务访问端口 |
`customSettings`的配置字段说明如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
| ----------- | --------------------- | -------- | ------ | ---------------------------------------------------------------------------------------------------------------------------- |
| `name` | string | 必填 | - | 想要设置的参数的名称,例如`max_tokens` |
| `value` | string/int/float/bool | 必填 | - | 想要设置的参数的值例如0 |
| `mode` | string | 非必填 | "auto" | 参数设置的模式,可以设置为"auto"或者"raw",如果为"auto"则会自动根据协议对参数名做改写,如果为"raw"则不会有任何改写和限制检查 |
| `overwrite` | bool | 非必填 | true | 如果为false则只在用户没有设置这个参数时填充参数否则会直接覆盖用户原有的参数设置 |
custom-setting会遵循如下表格根据`name`和协议来替换对应的字段,用户需要填写表格中`settingName`列中存在的值。例如用户将`name`设置为`max_tokens`在openai协议中会替换`max_tokens`在gemini中会替换`maxOutputTokens`
`none`表示该协议不支持此参数。如果`name`不在此表格中或者对应协议不支持此参数同时没有设置raw模式则配置不会生效。
| settingName | openai | baidu | spark | qwen | gemini | hunyuan | claude | minimax |
| ----------- | ----------- | ----------------- | ----------- | ----------- | --------------- | ----------- | ----------- | ------------------ |
| max_tokens | max_tokens | max_output_tokens | max_tokens | max_tokens | maxOutputTokens | none | max_tokens | tokens_to_generate |
| temperature | temperature | temperature | temperature | temperature | temperature | Temperature | temperature | temperature |
| top_p | top_p | top_p | none | top_p | topP | TopP | top_p | top_p |
| top_k | none | none | top_k | none | topK | none | top_k | none |
| seed | seed | none | none | seed | none | none | none | none |
如果启用了raw模式custom-setting会直接用输入的`name``value`去更改请求中的json内容而不对参数名称做任何限制和修改。
对于大多数协议custom-setting都会在json内容的根路径修改或者填充参数。对于`qwen`协议ai-proxy会在json的`parameters`子路径下做配置。对于`gemini`协议,则会在`generation_config`子路径下做配置。
### 提供商特有配置
#### OpenAI
OpenAI 所对应的 `type``openai`。它并无特有的配置字段
OpenAI 所对应的 `type``openai`。它特有的配置字段如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|-------------------|----------|----------|--------|-------------------------------------------------------------------------------|
| `openaiCustomUrl` | string | 非必填 | - | 基于OpenAI协议的自定义后端URL例如: www.example.com/myai/v1/chat/completions |
| `responseJsonSchema` | object | 非必填 | - | 预先定义OpenAI响应需满足的Json Schema, 注意目前仅特定的几种模型支持该用法|
#### Azure OpenAI
@@ -93,6 +134,10 @@ Groq 所对应的 `type` 为 `groq`。它并无特有的配置字段。
文心一言所对应的 `type``baidu`。它并无特有的配置字段。
#### 360智脑
360智脑所对应的 `type``ai360`。它并无特有的配置字段。
#### MiniMax
MiniMax所对应的 `type``minimax`。它特有的配置字段如下:
@@ -139,6 +184,27 @@ Cloudflare Workers AI 所对应的 `type` 为 `cloudflare`。它特有的配置
|-------------------|--------|------|-----|----------------------------------------------------------------------------------------------------------------------------|
| `cloudflareAccountId` | string | 必填 | - | [Cloudflare Account ID](https://developers.cloudflare.com/workers-ai/get-started/rest-api/#1-get-api-token-and-account-id) |
#### 星火 (Spark)
星火所对应的 `type``spark`。它并无特有的配置字段。
讯飞星火认知大模型的`apiTokens`字段值为`APIKey:APISecret`。即填入自己的APIKey与APISecret并以`:`分隔。
#### Gemini
Gemini 所对应的 `type``gemini`。它特有的配置字段如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
| --------------------- | -------- | -------- |-----|-------------------------------------------------------------------------------------------------|
| `geminiSafetySetting` | map of string | 非必填 | - | Gemini AI内容过滤和安全级别设定。参考[Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) |
#### DeepL
DeepL 所对应的 `type``deepl`。它特有的配置字段如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
| ------------ | -------- | -------- | ------ | ---------------------------- |
| `targetLang` | string | 必填 | - | DeepL 翻译服务需要的目标语种 |
## 用法示例
@@ -256,6 +322,7 @@ provider:
'gpt-35-turbo': "qwen-plus"
'gpt-4-turbo': "qwen-max"
'gpt-4-*': "qwen-max"
'gpt-4o': "qwen-vl-plus"
'text-embedding-v1': 'text-embedding-v1'
'*': "qwen-turbo"
```
@@ -264,7 +331,111 @@ provider:
URL: http://your-domain/v1/chat/completions
请求
请求示例
```json
{
"model": "gpt-3",
"messages": [
{
"role": "user",
"content": "你好,你是谁?"
}
],
"temperature": 0.3
}
```
响应示例:
```json
{
"id": "c2518bd3-0f46-97d1-be34-bb5777cb3108",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "我是通义千问由阿里云开发的AI助手。我可以回答各种问题、提供信息和与用户进行对话。有什么我可以帮助你的吗"
},
"finish_reason": "stop"
}
],
"created": 1715175072,
"model": "qwen-turbo",
"object": "chat.completion",
"usage": {
"prompt_tokens": 24,
"completion_tokens": 33,
"total_tokens": 57
}
}
```
**多模态模型 API 请求示例(适用于 `qwen-vl-plus` 和 `qwen-vl-max` 模型)**
URL: http://your-domain/v1/chat/completions
请求示例:
```json
{
"model": "gpt-4o",
"messages": [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://dashscope.oss-cn-beijing.aliyuncs.com/images/dog_and_girl.jpeg"
}
},
{
"type": "text",
"text": "这个图片是哪里?"
}
]
}
],
"temperature": 0.3
}
```
响应示例:
```json
{
"id": "17c5955d-af9c-9f28-bbde-293a9c9a3515",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": [
{
"text": "这张照片显示的是一位女士和一只狗在海滩上。由于我无法获取具体的地理位置信息,所以不能确定这是哪个地方的海滩。但是从视觉内容来看,它可能是一个位于沿海地区的沙滩海岸线,并且有海浪拍打着岸边。这样的场景在全球许多美丽的海滨地区都可以找到。如果您需要更精确的信息,请提供更多的背景或细节描述。"
}
]
},
"finish_reason": "stop"
}
],
"created": 1723949230,
"model": "qwen-vl-plus",
"object": "chat.completion",
"usage": {
"prompt_tokens": 1279,
"completion_tokens": 78
}
}
```
**文本向量请求示例**
URL: http://your-domain/v1/embeddings
请求示例:
```json
{
@@ -273,7 +444,7 @@ URL: http://your-domain/v1/chat/completions
}
```
响应示例:
响应示例:
```json
{
@@ -305,47 +476,6 @@ URL: http://your-domain/v1/chat/completions
}
```
**请求示例**
URL: http://your-domain/v1/embeddings
示例请求内容:
```json
{
"model": "text-embedding-v1",
"input": [
"Hello world!"
]
}
```
示例响应内容:
```json
{
"id": "c2518bd3-0f46-97d1-be34-bb5777cb3108",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "我是通义千问由阿里云开发的AI助手。我可以回答各种问题、提供信息和与用户进行对话。有什么我可以帮助你的吗"
},
"finish_reason": "stop"
}
],
"created": 1715175072,
"model": "qwen-turbo",
"object": "chat.completion",
"usage": {
"prompt_tokens": 24,
"completion_tokens": 33,
"total_tokens": 57
}
}
```
### 使用通义千问配合纯文本上下文信息
使用通义千问服务,同时配置纯文本上下文信息。
@@ -814,6 +944,77 @@ provider:
}
```
### 使用 OpenAI 协议代理360智脑服务
**配置信息**
```yaml
provider:
type: ai360
apiTokens:
- "YOUR_MINIMAX_API_TOKEN"
modelMapping:
"gpt-4o": "360gpt-turbo-responsibility-8k"
"gpt-4": "360gpt2-pro"
"gpt-3.5": "360gpt-turbo"
"*": "360gpt-pro"
```
**请求示例**
```json
{
"model": "gpt-4o",
"messages": [
{
"role": "system",
"content": "你是一个专业的开发人员!"
},
{
"role": "user",
"content": "你好,你是谁?"
}
]
}
```
**响应示例**
```json
{
"choices": [
{
"message": {
"role": "assistant",
"content": "你好我是360智脑一个大型语言模型。我可以帮助回答各种问题、提供信息、进行对话等。有什么可以帮助你的吗"
},
"finish_reason": "",
"index": 0
}
],
"created": 1724257207,
"id": "5e5c94a2-d989-40b5-9965-5b971db941fe",
"model": "360gpt-turbo",
"object": "",
"usage": {
"completion_tokens": 33,
"prompt_tokens": 24,
"total_tokens": 57
},
"messages": [
{
"role": "system",
"content": "你是一个专业的开发人员!"
},
{
"role": "user",
"content": "你好,你是谁?"
}
],
"context": null
}
```
### 使用 OpenAI 协议代理 Cloudflare Workers AI 服务
**配置信息**
@@ -865,6 +1066,177 @@ provider:
}
```
### 使用 OpenAI 协议代理Spark服务
**配置信息**
```yaml
provider:
type: spark
apiTokens:
- "APIKey:APISecret"
modelMapping:
"gpt-4o": "generalv3.5"
"gpt-4": "generalv3"
"*": "general"
```
**请求示例**
```json
{
"model": "gpt-4o",
"messages": [
{
"role": "system",
"content": "你是一名专业的开发人员!"
},
{
"role": "user",
"content": "你好,你是谁?"
}
],
"stream": false
}
```
**响应示例**
```json
{
"id": "cha000c23c6@dx190ef0b4b96b8f2532",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "你好!我是一名专业的开发人员,擅长编程和解决技术问题。有什么我可以帮助你的吗?"
}
}
],
"created": 1721997415,
"model": "generalv3.5",
"object": "chat.completion",
"usage": {
"prompt_tokens": 10,
"completion_tokens": 19,
"total_tokens": 29
}
}
```
### 使用 OpenAI 协议代理 gemini 服务
**配置信息**
```yaml
provider:
type: gemini
apiTokens:
- "YOUR_GEMINI_API_TOKEN"
modelMapping:
"*": "gemini-pro"
geminiSafetySetting:
"HARM_CATEGORY_SEXUALLY_EXPLICIT" :"BLOCK_NONE"
"HARM_CATEGORY_HATE_SPEECH" :"BLOCK_NONE"
"HARM_CATEGORY_HARASSMENT" :"BLOCK_NONE"
"HARM_CATEGORY_DANGEROUS_CONTENT" :"BLOCK_NONE"
```
**请求示例**
```json
{
"model": "gpt-3.5",
"messages": [
{
"role": "user",
"content": "Who are you?"
}
],
"stream": false
}
```
**响应示例**
```json
{
"id": "chatcmpl-b010867c-0d3f-40ba-95fd-4e8030551aeb",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "I am a large multi-modal model, trained by Google. I am designed to provide information and answer questions to the best of my abilities."
},
"finish_reason": "stop"
}
],
"created": 1722756984,
"model": "gemini-pro",
"object": "chat.completion",
"usage": {
"prompt_tokens": 5,
"completion_tokens": 29,
"total_tokens": 34
}
}
```
### 使用 OpenAI 协议代理 DeepL 文本翻译服务
**配置信息**
```yaml
provider:
type: deepl
apiTokens:
- "YOUR_DEEPL_API_TOKEN"
targetLang: "ZH"
```
**请求示例**
此处 `model` 表示 DeepL 的服务类型,只能填 `Free``Pro``content` 中设置需要翻译的文本;在 `role: system``content` 中可以包含可能影响翻译但本身不会被翻译的上下文,例如翻译产品名称时,可以将产品描述作为上下文传递,这种额外的上下文可能会提高翻译的质量。
```json
{
"model": "Free",
"messages": [
{
"role": "system",
"content": "money"
},
{
"content": "sit by the bank"
},
{
"content": "a bank in China"
}
]
}
```
**响应示例**
```json
{
"choices": [
{
"index": 0,
"message": { "name": "EN", "role": "assistant", "content": "坐庄" }
},
{
"index": 1,
"message": { "name": "EN", "role": "assistant", "content": "中国银行" }
}
],
"created": 1722747752,
"model": "Free",
"object": "chat.completion",
"usage": {}
}
```
## 完整配置示例
### Kubernetes 示例

View File

@@ -50,3 +50,7 @@ func (c *PluginConfig) Complete() error {
func (c *PluginConfig) GetProvider() provider.Provider {
return c.provider
}
func (c *PluginConfig) GetProviderConfig() provider.ProviderConfig {
return c.providerConfig
}

View File

@@ -15,12 +15,13 @@ require (
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/google/uuid v1.3.0
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
github.com/magefile/mage v1.14.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/resp v0.1.1 // indirect
github.com/tidwall/sjson v1.2.5
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -12,6 +12,7 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
@@ -20,6 +21,8 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -75,15 +75,15 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
ctx.DisableReroute()
action, err := handler.OnRequestHeaders(ctx, apiName, log)
_, err := handler.OnRequestHeaders(ctx, apiName, log)
if err == nil {
if contentType, err := proxywasm.GetHttpRequestHeader("Content-Type"); err == nil && contentType != "" {
if wrapper.HasRequestBody() {
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
// Always return types.HeaderStopIteration to support fallback routing,
// as long as onHttpRequestBody can be called.
return types.HeaderStopIteration
}
return action
return types.ActionContinue
}
_ = util.SendResponse(500, "ai-proxy.proc_req_headers_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process request headers: %v", err))
return types.ActionContinue
@@ -104,12 +104,20 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
if handler, ok := activeProvider.(provider.RequestBodyHandler); ok {
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
newBody, settingErr := pluginConfig.GetProviderConfig().ReplaceByCustomSettings(body)
if settingErr != nil {
_ = util.SendResponse(500, "ai-proxy.proc_req_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to rewrite request body by custom settings: %v", settingErr))
return types.ActionContinue
}
log.Debugf("[onHttpRequestBody] newBody=%s", newBody)
body = newBody
action, err := handler.OnRequestBody(ctx, apiName, body, log)
if err == nil {
return action
}
_ = util.SendResponse(500, "ai-proxy.proc_req_body_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process request body: %v", err))
return types.ActionContinue
}
return types.ActionContinue
}
@@ -140,24 +148,18 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
return types.ActionContinue
}
contentType, err := proxywasm.GetHttpResponseHeader("Content-Type")
if err != nil || !strings.HasPrefix(contentType, "text/event-stream") {
if err != nil {
log.Errorf("unable to load content-type header from response: %v", err)
}
ctx.BufferResponseBody()
}
if handler, ok := activeProvider.(provider.ResponseHeadersHandler); ok {
apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName)
action, err := handler.OnResponseHeaders(ctx, apiName, log)
if err == nil {
checkStream(&ctx, &log)
return action
}
_ = util.SendResponse(500, "ai-proxy.proc_resp_headers_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to process response headers: %v", err))
return types.ActionContinue
}
checkStream(&ctx, &log)
_, needHandleBody := activeProvider.(provider.ResponseBodyHandler)
_, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
if !needHandleBody && !needHandleStreamingBody {
@@ -223,3 +225,13 @@ func getOpenAiApiName(path string) provider.ApiName {
}
return ""
}
func checkStream(ctx *wrapper.HttpContext, log *wrapper.Log) {
contentType, err := proxywasm.GetHttpResponseHeader("Content-Type")
if err != nil || !strings.HasPrefix(contentType, "text/event-stream") {
if err != nil {
log.Errorf("unable to load content-type header from response: %v", err)
}
(*ctx).BufferResponseBody()
}
}

View File

@@ -0,0 +1,74 @@
package provider
import (
"errors"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)
// ai360Provider is the provider for 360 OpenAI service.
const (
ai360Domain = "api.360.cn"
)
type ai360ProviderInitializer struct {
}
type ai360Provider struct {
config ProviderConfig
contextCache *contextCache
}
func (m *ai360ProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
}
func (m *ai360ProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
return &ai360Provider{
config: config,
contextCache: createContextCache(&config),
}, nil
}
func (m *ai360Provider) GetProviderType() string {
return providerTypeAi360
}
func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestHost(ai360Domain)
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
_ = proxywasm.ReplaceHttpRequestHeader("Authorization", m.config.GetRandomToken())
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
if request.Model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
// 映射模型
mappedModel := getMappedModel(request.Model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
request.Model = mappedModel
return types.ActionContinue, replaceJsonRequestBody(request, log)
}

View File

@@ -23,6 +23,9 @@ func (m *azureProviderInitializer) ValidateConfig(config ProviderConfig) error {
if _, err := url.Parse(config.azureServiceUrl); err != nil {
return fmt.Errorf("invalid azureServiceUrl: %w", err)
}
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
}
@@ -52,27 +55,32 @@ func (m *azureProvider) GetProviderType() string {
}
func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(m.serviceUrl.RequestURI())
_ = util.OverwriteRequestHost(m.serviceUrl.Host)
_ = proxywasm.ReplaceHttpRequestHeader("api-key", m.config.apiTokens[0])
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
if apiName == ApiNameChatCompletion {
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
} else {
ctx.DontReadRequestBody()
}
return types.ActionContinue, nil
}
func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
if m.contextCache == nil {
// We don't need to process the request body for other APIs.
return types.ActionContinue, nil
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
if m.contextCache == nil {
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.openai.set_include_usage_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
return types.ActionContinue, nil
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()

View File

@@ -1,6 +1,7 @@
package provider
import (
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
@@ -20,6 +21,9 @@ type baichuanProviderInitializer struct {
}
func (m *baichuanProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
}

View File

@@ -34,6 +34,9 @@ type baiduProviderInitializer struct {
}
func (b *baiduProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
}
@@ -80,7 +83,7 @@ func (b *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
return types.ActionContinue, errors.New("request model is empty")
}
// 根据模型重写requestPath
path := b.GetRequestPath(request.Model)
path := b.getRequestPath(request.Model)
_ = util.OverwriteRequestPath(path)
if b.config.context == nil {
@@ -123,7 +126,7 @@ func (b *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
}
request.Model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
path := b.GetRequestPath(mappedModel)
path := b.getRequestPath(mappedModel)
_ = util.OverwriteRequestPath(path)
if b.config.context == nil {
@@ -223,7 +226,7 @@ type baiduTextGenRequest struct {
UserId string `json:"user_id,omitempty"`
}
func (b *baiduProvider) GetRequestPath(baiduModel string) string {
func (b *baiduProvider) getRequestPath(baiduModel string) string {
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
suffix, ok := baiduModelToPathSuffixMap[baiduModel]
if !ok {
@@ -250,7 +253,7 @@ func (b *baiduProvider) baiduTextGenRequest(request *chatCompletionRequest) *bai
}
for _, message := range request.Messages {
if message.Role == roleSystem {
baiduRequest.System = message.Content
baiduRequest.System = message.StringContent()
} else {
baiduRequest.Messages = append(baiduRequest.Messages, chatMessage{
Role: message.Role,
@@ -323,7 +326,7 @@ func (b *baiduProvider) streamResponseBaidu2OpenAI(ctx wrapper.HttpContext, resp
Created: time.Now().UnixMilli() / 1000,
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
SystemFingerprint: "",
Object: objectChatCompletion,
Object: objectChatCompletionChunk,
Choices: []chatCompletionChoice{choice},
Usage: usage{
PromptTokens: response.Usage.PromptTokens,

View File

@@ -4,12 +4,13 @@ import (
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"strings"
"time"
)
// claudeProvider is the provider for Claude service.
@@ -78,6 +79,9 @@ type claudeTextGenDelta struct {
}
func (c *claudeProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
}
@@ -270,7 +274,7 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe
for _, message := range origRequest.Messages {
if message.Role == roleSystem {
claudeRequest.System = message.Content
claudeRequest.System = message.StringContent()
continue
}
claudeMessage := chatMessage{

View File

@@ -21,6 +21,9 @@ type cloudflareProviderInitializer struct {
}
func (c *cloudflareProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
}

View File

@@ -0,0 +1,137 @@
package provider
import (
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
nameMaxTokens = "max_tokens"
nameTemperature = "temperature"
nameTopP = "top_p"
nameTopK = "top_k"
nameSeed = "seed"
)
var maxTokensMapping = map[string]string{
"openai": "max_tokens",
"baidu": "max_output_tokens",
"spark": "max_tokens",
"qwen": "max_tokens",
"gemini": "maxOutputTokens",
"claude": "max_tokens",
"minimax": "tokens_to_generate",
}
var temperatureMapping = map[string]string{
"openai": "temperature",
"baidu": "temperature",
"spark": "temperature",
"qwen": "temperature",
"gemini": "temperature",
"hunyuan": "Temperature",
"claude": "temperature",
"minimax": "temperature",
}
var topPMapping = map[string]string{
"openai": "top_p",
"baidu": "top_p",
"qwen": "top_p",
"gemini": "topP",
"hunyuan": "TopP",
"claude": "top_p",
"minimax": "top_p",
}
var topKMapping = map[string]string{
"spark": "top_k",
"gemini": "topK",
"claude": "top_k",
}
var seedMapping = map[string]string{
"openai": "seed",
"qwen": "seed",
}
var settingMapping = map[string]map[string]string{
nameMaxTokens: maxTokensMapping,
nameTemperature: temperatureMapping,
nameTopP: topPMapping,
nameTopK: topKMapping,
nameSeed: seedMapping,
}
type CustomSetting struct {
// @Title zh-CN 参数名称
// @Description zh-CN 想要设置的参数的名称例如max_tokens
name string
// @Title zh-CN 参数值
// @Description zh-CN 想要设置的参数的值例如0
value string
// @Title zh-CN 设置模式
// @Description zh-CN 参数设置的模式,可以设置为"auto"或者"raw",如果为"auto"则会根据 /plugins/wasm-go/extensions/ai-proxy/README.md中关于custom-setting部分的表格自动按照协议对参数名做改写如果为"raw"则不会有任何改写和限制检查
mode string
// @Title zh-CN json edit 模式
// @Description zh-CN 如果为false则只在用户没有设置这个参数时填充参数否则会直接覆盖用户原有的参数设置
overwrite bool
}
func (c *CustomSetting) FromJson(json gjson.Result) {
c.name = json.Get("name").String()
c.value = json.Get("value").Raw
if obj := json.Get("mode"); obj.Exists() {
c.mode = obj.String()
} else {
c.mode = "auto"
}
if obj := json.Get("overwrite"); obj.Exists() {
c.overwrite = obj.Bool()
} else {
c.overwrite = true
}
}
func (c *CustomSetting) Validate() bool {
return c.name != ""
}
func (c *CustomSetting) setInvalid() {
c.name = "" // set empty to represent invalid
}
func (c *CustomSetting) AdjustWithProtocol(protocol string) {
if !(c.mode == "raw") {
mapping, ok := settingMapping[c.name]
if ok {
c.name, ok = mapping[protocol]
}
if !ok {
c.setInvalid()
return
}
}
if protocol == providerTypeQwen {
c.name = "parameters." + c.name
}
if protocol == providerTypeGemini {
c.name = "generation_config." + c.name
}
}
func ReplaceByCustomSettings(body []byte, settings []CustomSetting) ([]byte, error) {
var err error
strBody := string(body)
for _, setting := range settings {
if !setting.overwrite && gjson.Get(strBody, setting.name).Exists() {
continue
}
strBody, err = sjson.SetRaw(strBody, setting.name, setting.value)
if err != nil {
break
}
}
return []byte(strBody), err
}

View File

@@ -0,0 +1,176 @@
package provider
import (
"encoding/json"
"errors"
"fmt"
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
// deeplProvider is the provider for DeepL service.
const (
deeplHostPro = "api.deepl.com"
deeplHostFree = "api-free.deepl.com"
deeplChatCompletionPath = "/v2/translate"
)
type deeplProviderInitializer struct {
}
type deeplProvider struct {
config ProviderConfig
contextCache *contextCache
}
// spec reference: https://developers.deepl.com/docs/v/zh/api-reference/translate/openapi-spec-for-text-translation
type deeplRequest struct {
// "Model" parameter is used to distinguish which service to use
Model string `json:"model,omitempty"`
Text []string `json:"text"`
SourceLang string `json:"source_lang,omitempty"`
TargetLang string `json:"target_lang"`
Context string `json:"context,omitempty"`
SplitSentences string `json:"split_sentences,omitempty"`
PreserveFormatting bool `json:"preserve_formatting,omitempty"`
Formality string `json:"formality,omitempty"`
GlossaryId string `json:"glossary_id,omitempty"`
TagHandling string `json:"tag_handling,omitempty"`
OutlineDetection bool `json:"outline_detection,omitempty"`
NonSplittingTags []string `json:"non_splitting_tags,omitempty"`
SplittingTags []string `json:"splitting_tags,omitempty"`
IgnoreTags []string `json:"ignore_tags,omitempty"`
}
type deeplResponse struct {
Translations []deeplResponseTranslation `json:"translations,omitempty"`
Message string `json:"message,omitempty"`
}
type deeplResponseTranslation struct {
DetectedSourceLanguage string `json:"detected_source_language"`
Text string `json:"text"`
}
func (d *deeplProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.targetLang == "" {
return errors.New("missing targetLang in deepl provider config")
}
return nil
}
func (d *deeplProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
return &deeplProvider{
config: config,
contextCache: createContextCache(&config),
}, nil
}
func (d *deeplProvider) GetProviderType() string {
return providerTypeDeepl
}
func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestPath(deeplChatCompletionPath)
_ = util.OverwriteRequestAuthorization("DeepL-Auth-Key " + d.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
return types.HeaderStopIteration, nil
}
func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
if d.config.protocol == protocolOriginal {
request := &deeplRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if err := d.overwriteRequestHost(request.Model); err != nil {
return types.ActionContinue, err
}
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
return types.ActionContinue, replaceJsonRequestBody(request, log)
} else {
originRequest := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, originRequest); err != nil {
return types.ActionContinue, err
}
if err := d.overwriteRequestHost(originRequest.Model); err != nil {
return types.ActionContinue, err
}
ctx.SetContext(ctxKeyFinalRequestModel, originRequest.Model)
deeplRequest := &deeplRequest{
Text: make([]string, 0),
TargetLang: d.config.targetLang,
}
for _, msg := range originRequest.Messages {
if msg.Role == roleSystem {
deeplRequest.Context = msg.StringContent()
} else {
deeplRequest.Text = append(deeplRequest.Text, msg.StringContent())
}
}
return types.ActionContinue, replaceJsonRequestBody(deeplRequest, log)
}
}
func (d *deeplProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
}
func (d *deeplProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
deeplResponse := &deeplResponse{}
if err := json.Unmarshal(body, deeplResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal deepl response: %v", err)
}
response := d.responseDeepl2OpenAI(ctx, deeplResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log)
}
func (d *deeplProvider) responseDeepl2OpenAI(ctx wrapper.HttpContext, deeplResponse *deeplResponse) *chatCompletionResponse {
var choices []chatCompletionChoice
// Fail
if deeplResponse.Message != "" {
choices = make([]chatCompletionChoice, 1)
choices[0] = chatCompletionChoice{
Message: &chatMessage{Role: roleAssistant, Content: deeplResponse.Message},
Index: 0,
}
} else {
// Success
choices = make([]chatCompletionChoice, len(deeplResponse.Translations))
for idx, t := range deeplResponse.Translations {
choices[idx] = chatCompletionChoice{
Index: idx,
Message: &chatMessage{Role: roleAssistant, Content: t.Text, Name: t.DetectedSourceLanguage},
}
}
}
return &chatCompletionResponse{
Created: time.Now().UnixMilli() / 1000,
Object: objectChatCompletion,
Choices: choices,
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
}
}
func (d *deeplProvider) overwriteRequestHost(model string) error {
if model == "Pro" {
_ = util.OverwriteRequestHost(deeplHostPro)
} else if model == "Free" {
_ = util.OverwriteRequestHost(deeplHostFree)
} else {
return errors.New(`deepl model should be "Free" or "Pro"`)
}
return nil
}

View File

@@ -1,6 +1,7 @@
package provider
import (
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
@@ -20,6 +21,9 @@ type deepseekProviderInitializer struct {
}
func (m *deepseekProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
}

View File

@@ -0,0 +1,607 @@
package provider
import (
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/google/uuid"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
// geminiProvider is the provider for google gemini/gemini flash service.
const (
geminiApiKeyHeader = "x-goog-api-key"
geminiDomain = "generativelanguage.googleapis.com"
)
type geminiProviderInitializer struct {
}
func (g *geminiProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
}
func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
return &geminiProvider{
config: config,
contextCache: createContextCache(&config),
}, nil
}
type geminiProvider struct {
config ProviderConfig
contextCache *contextCache
}
func (g *geminiProvider) GetProviderType() string {
return providerTypeGemini
}
func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion && apiName != ApiNameEmbeddings {
return types.ActionContinue, errUnsupportedApiName
}
_ = proxywasm.ReplaceHttpRequestHeader(geminiApiKeyHeader, g.config.GetRandomToken())
_ = util.OverwriteRequestHost(geminiDomain)
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
// Delay the header processing to allow changing streaming mode in OnRequestBody
return types.HeaderStopIteration, nil
}
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName == ApiNameChatCompletion {
return g.onChatCompletionRequestBody(ctx, body, log)
} else if apiName == ApiNameEmbeddings {
return g.onEmbeddingsRequestBody(ctx, body, log)
}
return types.ActionContinue, errUnsupportedApiName
}
func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
// 使用gemini接口协议
if g.config.protocol == protocolOriginal {
request := &geminiChatRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.Model == "" {
return types.ActionContinue, errors.New("request model is empty")
}
// 根据模型重写requestPath
path := g.getRequestPath(ApiNameChatCompletion, request.Model, request.Stream)
_ = util.OverwriteRequestPath(path)
// 移除多余的model和stream字段
request = &geminiChatRequest{
Contents: request.Contents,
SafetySettings: request.SafetySettings,
GenerationConfig: request.GenerationConfig,
Tools: request.Tools,
}
if g.config.context == nil {
return types.ActionContinue, replaceJsonRequestBody(request, log)
}
err := g.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.gemini.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
g.setSystemContent(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.gemini.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
// 映射模型重写requestPath
model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
ctx.SetContext(ctxKeyOriginalRequestModel, model)
mappedModel := getMappedModel(model, g.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
path := g.getRequestPath(ApiNameChatCompletion, mappedModel, request.Stream)
_ = util.OverwriteRequestPath(path)
if g.config.context == nil {
geminiRequest := g.buildGeminiChatRequest(request)
return types.ActionContinue, replaceJsonRequestBody(geminiRequest, log)
}
err := g.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, "ai-proxy.gemini.load_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
geminiRequest := g.buildGeminiChatRequest(request)
if err := replaceJsonRequestBody(geminiRequest, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.gemini.insert_ctx_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}
func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
// 使用gemini接口协议
if g.config.protocol == protocolOriginal {
request := &geminiBatchEmbeddingRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.Model == "" {
return types.ActionContinue, errors.New("request model is empty")
}
// 根据模型重写requestPath
path := g.getRequestPath(ApiNameEmbeddings, request.Model, false)
_ = util.OverwriteRequestPath(path)
// 移除多余的model字段
request = &geminiBatchEmbeddingRequest{
Requests: request.Requests,
}
return types.ActionContinue, replaceJsonRequestBody(request, log)
}
request := &embeddingsRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
// 映射模型重写requestPath
model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in embeddings request")
}
ctx.SetContext(ctxKeyOriginalRequestModel, model)
mappedModel := getMappedModel(model, g.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
path := g.getRequestPath(ApiNameEmbeddings, mappedModel, false)
_ = util.OverwriteRequestPath(path)
geminiRequest := g.buildBatchEmbeddingRequest(request)
return types.ActionContinue, replaceJsonRequestBody(geminiRequest, log)
}
func (g *geminiProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
if g.config.protocol == protocolOriginal {
ctx.DontReadResponseBody()
return types.ActionContinue, nil
}
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
}
func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
log.Infof("chunk body:%s", string(chunk))
if isLastChunk || len(chunk) == 0 {
return nil, nil
}
// sample end event response:
// data: {"candidates": [{"content": {"parts": [{"text": "我是 Gemini一个大型多模态模型由 Google 训练。我的职责是尽我所能帮助您,并尽力提供全面且信息丰富的答复。"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 2,"candidatesTokenCount": 35,"totalTokenCount": 37}}
responseBuilder := &strings.Builder{}
lines := strings.Split(string(chunk), "\n")
for _, data := range lines {
if len(data) < 6 {
// ignore blank line or wrong format
continue
}
data = data[6:]
var geminiResp geminiChatResponse
if err := json.Unmarshal([]byte(data), &geminiResp); err != nil {
log.Errorf("unable to unmarshal gemini response: %v", err)
continue
}
response := g.buildChatCompletionStreamResponse(ctx, &geminiResp)
responseBody, err := json.Marshal(response)
if err != nil {
log.Errorf("unable to marshal response: %v", err)
return nil, err
}
g.appendResponse(responseBuilder, string(responseBody))
}
modifiedResponseChunk := responseBuilder.String()
log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
return []byte(modifiedResponseChunk), nil
}
func (g *geminiProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName == ApiNameChatCompletion {
return g.onChatCompletionResponseBody(ctx, body, log)
} else if apiName == ApiNameEmbeddings {
return g.onEmbeddingsResponseBody(ctx, body, log)
}
return types.ActionContinue, errUnsupportedApiName
}
func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
geminiResponse := &geminiChatResponse{}
if err := json.Unmarshal(body, geminiResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal gemini chat response: %v", err)
}
if geminiResponse.Error != nil {
return types.ActionContinue, fmt.Errorf("gemini chat completion response error, error_code: %d, error_status:%s, error_message: %s",
geminiResponse.Error.Code, geminiResponse.Error.Status, geminiResponse.Error.Message)
}
response := g.buildChatCompletionResponse(ctx, geminiResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log)
}
func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) (types.Action, error) {
geminiResponse := &geminiEmbeddingResponse{}
if err := json.Unmarshal(body, geminiResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal gemini embeddings response: %v", err)
}
if geminiResponse.Error != nil {
return types.ActionContinue, fmt.Errorf("gemini embeddings response error, error_code: %d, error_status:%s, error_message: %s",
geminiResponse.Error.Code, geminiResponse.Error.Status, geminiResponse.Error.Message)
}
response := g.buildEmbeddingsResponse(ctx, geminiResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log)
}
func (g *geminiProvider) getRequestPath(apiName ApiName, geminiModel string, stream bool) string {
action := ""
if apiName == ApiNameEmbeddings {
action = "batchEmbedContents"
} else if stream {
action = "streamGenerateContent?alt=sse"
} else {
action = "generateContent"
}
return fmt.Sprintf("/v1/models/%s:%s", geminiModel, action)
}
type geminiChatRequest struct {
// Model and Stream are only used when using the gemini original protocol
Model string `json:"model,omitempty"`
Stream bool `json:"stream,omitempty"`
Contents []geminiChatContent `json:"contents"`
SafetySettings []geminiChatSafetySetting `json:"safety_settings,omitempty"`
GenerationConfig geminiChatGenerationConfig `json:"generation_config,omitempty"`
Tools []geminiChatTools `json:"tools,omitempty"`
}
type geminiChatContent struct {
Role string `json:"role,omitempty"`
Parts []geminiPart `json:"parts"`
}
type geminiChatSafetySetting struct {
Category string `json:"category"`
Threshold string `json:"threshold"`
}
type geminiChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
}
type geminiChatTools struct {
FunctionDeclarations any `json:"function_declarations,omitempty"`
}
type geminiPart struct {
Text string `json:"text,omitempty"`
InlineData *geminiInlineData `json:"inlineData,omitempty"`
FunctionCall *geminiFunctionCall `json:"functionCall,omitempty"`
}
type geminiInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
}
type geminiFunctionCall struct {
FunctionName string `json:"name"`
Arguments any `json:"args"`
}
func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest) *geminiChatRequest {
var safetySettings []geminiChatSafetySetting
{
}
for category, threshold := range g.config.geminiSafetySetting {
safetySettings = append(safetySettings, geminiChatSafetySetting{
Category: category,
Threshold: threshold,
})
}
geminiRequest := geminiChatRequest{
Contents: make([]geminiChatContent, 0, len(request.Messages)),
SafetySettings: safetySettings,
GenerationConfig: geminiChatGenerationConfig{
Temperature: request.Temperature,
TopP: request.TopP,
MaxOutputTokens: request.MaxTokens,
},
}
if request.Tools != nil {
functions := make([]function, 0, len(request.Tools))
for _, tool := range request.Tools {
functions = append(functions, tool.Function)
}
geminiRequest.Tools = []geminiChatTools{
{
FunctionDeclarations: functions,
},
}
}
shouldAddDummyModelMessage := false
for _, message := range request.Messages {
content := geminiChatContent{
Role: message.Role,
Parts: []geminiPart{
{
Text: message.StringContent(),
},
},
}
// there's no assistant role in gemini and API shall vomit if role is not user or model
if content.Role == roleAssistant {
content.Role = "model"
} else if content.Role == roleSystem { // converting system prompt to prompt from user for the same reason
content.Role = roleUser
shouldAddDummyModelMessage = true
}
geminiRequest.Contents = append(geminiRequest.Contents, content)
// if a system message is the last message, we need to add a dummy model message to make gemini happy
if shouldAddDummyModelMessage {
geminiRequest.Contents = append(geminiRequest.Contents, geminiChatContent{
Role: "model",
Parts: []geminiPart{
{
Text: "Okay",
},
},
})
shouldAddDummyModelMessage = false
}
}
return &geminiRequest
}
func (g *geminiProvider) setSystemContent(request *geminiChatRequest, content string) {
systemContents := []geminiChatContent{{
Role: roleUser,
Parts: []geminiPart{
{
Text: content,
},
},
}}
request.Contents = append(systemContents, request.Contents...)
}
type geminiBatchEmbeddingRequest struct {
// Model are only used when using the gemini original protocol
Model string `json:"model,omitempty"`
Requests []geminiEmbeddingRequest `json:"requests"`
}
type geminiEmbeddingRequest struct {
Model string `json:"model"`
Content geminiChatContent `json:"content"`
TaskType string `json:"taskType,omitempty"`
Title string `json:"title,omitempty"`
OutputDimensionality int `json:"outputDimensionality,omitempty"`
}
func (g *geminiProvider) buildBatchEmbeddingRequest(request *embeddingsRequest) *geminiBatchEmbeddingRequest {
inputs := request.ParseInput()
requests := make([]geminiEmbeddingRequest, len(inputs))
model := fmt.Sprintf("models/%s", request.Model)
for i, input := range inputs {
requests[i] = geminiEmbeddingRequest{
Model: model,
Content: geminiChatContent{
Parts: []geminiPart{
{
Text: input,
},
},
},
}
}
return &geminiBatchEmbeddingRequest{
Requests: requests,
}
}
type geminiChatResponse struct {
Candidates []geminiChatCandidate `json:"candidates"`
PromptFeedback geminiChatPromptFeedback `json:"promptFeedback"`
UsageMetadata geminiUsageMetadata `json:"usageMetadata"`
Error *geminiResponseError `json:"error,omitempty"`
}
type geminiChatCandidate struct {
Content geminiChatContent `json:"content"`
FinishReason string `json:"finishReason"`
Index int64 `json:"index"`
SafetyRatings []geminiChatSafetyRating `json:"safetyRatings"`
}
type geminiChatPromptFeedback struct {
SafetyRatings []geminiChatSafetyRating `json:"safetyRatings"`
}
type geminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount,omitempty"`
CandidatesTokenCount int `json:"candidatesTokenCount,omitempty"`
TotalTokenCount int `json:"totalTokenCount,omitempty"`
}
type geminiResponseError struct {
Code int `json:"code,omitempty"`
Message string `json:"message,omitempty"`
Status string `json:"status,omitempty"`
}
type geminiChatSafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}
func (g *geminiProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, response *geminiChatResponse) *chatCompletionResponse {
fullTextResponse := chatCompletionResponse{
Id: fmt.Sprintf("chatcmpl-%s", uuid.New().String()),
Object: objectChatCompletion,
Created: time.Now().UnixMilli() / 1000,
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
Choices: make([]chatCompletionChoice, 0, len(response.Candidates)),
Usage: usage{
PromptTokens: response.UsageMetadata.PromptTokenCount,
CompletionTokens: response.UsageMetadata.CandidatesTokenCount,
TotalTokens: response.UsageMetadata.TotalTokenCount,
},
}
for i, candidate := range response.Candidates {
choice := chatCompletionChoice{
Index: i,
Message: &chatMessage{
Role: roleAssistant,
},
FinishReason: finishReasonStop,
}
if len(candidate.Content.Parts) > 0 {
if candidate.Content.Parts[0].FunctionCall != nil {
choice.Message.ToolCalls = g.buildToolCalls(&candidate)
} else {
choice.Message.Content = candidate.Content.Parts[0].Text
}
} else {
choice.Message.Content = ""
choice.FinishReason = candidate.FinishReason
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
func (g *geminiProvider) buildToolCalls(candidate *geminiChatCandidate) []toolCall {
var toolCalls []toolCall
item := candidate.Content.Parts[0]
if item.FunctionCall != nil {
return toolCalls
}
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
if err != nil {
proxywasm.LogErrorf("get toolCalls from gemini response failed: " + err.Error())
return toolCalls
}
toolCall := toolCall{
Id: fmt.Sprintf("call_%s", uuid.New().String()),
Type: "function",
Function: functionCall{
Arguments: string(argsBytes),
Name: item.FunctionCall.FunctionName,
},
}
toolCalls = append(toolCalls, toolCall)
return toolCalls
}
func (g *geminiProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpContext, geminiResp *geminiChatResponse) *chatCompletionResponse {
var choice chatCompletionChoice
if len(geminiResp.Candidates) > 0 && len(geminiResp.Candidates[0].Content.Parts) > 0 {
choice.Delta = &chatMessage{Content: geminiResp.Candidates[0].Content.Parts[0].Text}
}
streamResponse := chatCompletionResponse{
Id: fmt.Sprintf("chatcmpl-%s", uuid.New().String()),
Object: objectChatCompletionChunk,
Created: time.Now().UnixMilli() / 1000,
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
Choices: []chatCompletionChoice{choice},
Usage: usage{
PromptTokens: geminiResp.UsageMetadata.PromptTokenCount,
CompletionTokens: geminiResp.UsageMetadata.CandidatesTokenCount,
TotalTokens: geminiResp.UsageMetadata.TotalTokenCount,
},
}
return &streamResponse
}
type geminiEmbeddingResponse struct {
Embeddings []geminiEmbeddingData `json:"embeddings"`
Error *geminiResponseError `json:"error,omitempty"`
}
type geminiEmbeddingData struct {
Values []float64 `json:"values"`
}
func (g *geminiProvider) buildEmbeddingsResponse(ctx wrapper.HttpContext, geminiResp *geminiEmbeddingResponse) *embeddingsResponse {
response := embeddingsResponse{
Object: "list",
Data: make([]embedding, 0, len(geminiResp.Embeddings)),
Model: ctx.GetContext(ctxKeyFinalRequestModel).(string),
Usage: usage{
TotalTokens: 0,
},
}
for _, item := range geminiResp.Embeddings {
response.Data = append(response.Data, embedding{
Object: `embedding`,
Index: 0,
Embedding: item.Values,
})
}
return &response
}
func (g *geminiProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}

View File

@@ -1,6 +1,7 @@
package provider
import (
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
@@ -18,6 +19,9 @@ const (
type groqProviderInitializer struct{}
func (m *groqProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
}

View File

@@ -447,7 +447,7 @@ func convertMessagesFromOpenAIToHunyuan(openAIMessages []chatMessage) []hunyuanC
for _, msg := range openAIMessages {
hunyuanChatMessages = append(hunyuanChatMessages, hunyuanChatMessage{
Role: msg.Role,
Content: msg.Content,
Content: msg.StringContent(),
})
}

View File

@@ -52,6 +52,9 @@ func (m *minimaxProviderInitializer) ValidateConfig(config ProviderConfig) error
}
}
}
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
}
@@ -401,19 +404,19 @@ func (m *minimaxProvider) buildMinimaxChatCompletionV2Request(request *chatCompl
botName = determineName(message.Name, defaultBotName)
botSetting = append(botSetting, minimaxBotSetting{
BotName: botName,
Content: message.Content,
Content: message.StringContent(),
})
case roleAssistant:
messages = append(messages, minimaxMessage{
SenderType: senderTypeBot,
SenderName: determineName(message.Name, defaultBotName),
Text: message.Content,
Text: message.StringContent(),
})
case roleUser:
messages = append(messages, minimaxMessage{
SenderType: senderTypeUser,
SenderName: determineName(message.Name, defaultSenderName),
Text: message.Content,
Text: message.StringContent(),
})
}
}

View File

@@ -13,6 +13,9 @@ const (
eventResult = "result"
httpStatus200 = "200"
contentTypeText = "text"
contentTypeImageUrl = "image_url"
)
type chatCompletionRequest struct {
@@ -31,6 +34,7 @@ type chatCompletionRequest struct {
ToolChoice *toolChoice `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
Stop []string `json:"stop,omitempty"`
ResponseFormat map[string]interface{} `json:"response_format,omitempty"`
}
type streamOptions struct {
@@ -79,12 +83,27 @@ type usage struct {
type chatMessage struct {
Name string `json:"name,omitempty"`
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
Content any `json:"content,omitempty"`
ToolCalls []toolCall `json:"tool_calls,omitempty"`
}
type messageContent struct {
Type string `json:"type,omitempty"`
Text string `json:"text"`
ImageUrl *imageUrl `json:"image_url,omitempty"`
}
type imageUrl struct {
Url string `json:"url,omitempty"`
Detail string `json:"detail,omitempty"`
}
func (m *chatMessage) IsEmpty() bool {
if m.Content != "" {
if m.IsStringContent() && m.Content != "" {
return false
}
anyList, ok := m.Content.([]any)
if ok && len(anyList) > 0 {
return false
}
if len(m.ToolCalls) != 0 {
@@ -102,6 +121,76 @@ func (m *chatMessage) IsEmpty() bool {
return true
}
func (m *chatMessage) IsStringContent() bool {
_, ok := m.Content.(string)
return ok
}
func (m *chatMessage) StringContent() string {
content, ok := m.Content.(string)
if ok {
return content
}
contentList, ok := m.Content.([]any)
if ok {
var contentStr string
for _, contentItem := range contentList {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
if contentMap["type"] == contentTypeText {
if subStr, ok := contentMap[contentTypeText].(string); ok {
contentStr += subStr + "\n"
}
}
}
return contentStr
}
return ""
}
func (m *chatMessage) ParseContent() []messageContent {
var contentList []messageContent
content, ok := m.Content.(string)
if ok {
contentList = append(contentList, messageContent{
Type: contentTypeText,
Text: content,
})
return contentList
}
anyList, ok := m.Content.([]any)
if ok {
for _, contentItem := range anyList {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
switch contentMap["type"] {
case contentTypeText:
if subStr, ok := contentMap[contentTypeText].(string); ok {
contentList = append(contentList, messageContent{
Type: contentTypeText,
Text: subStr,
})
}
case contentTypeImageUrl:
if subObj, ok := contentMap[contentTypeImageUrl].(map[string]any); ok {
contentList = append(contentList, messageContent{
Type: contentTypeImageUrl,
ImageUrl: &imageUrl{
Url: subObj["url"].(string),
},
})
}
}
}
return contentList
}
return nil
}
type toolCall struct {
Index int `json:"index"`
Id string `json:"id"`
@@ -161,3 +250,22 @@ type embedding struct {
Index int `json:"index"`
Embedding []float64 `json:"embedding"`
}
func (r embeddingsRequest) ParseInput() []string {
if r.Input == nil {
return nil
}
var input []string
switch r.Input.(type) {
case string:
input = []string{r.Input.(string)}
case []any:
input = make([]string, 0, len(r.Input.([]any)))
for _, item := range r.Input.([]any) {
if str, ok := item.(string); ok {
input = append(input, str)
}
}
}
return input
}

View File

@@ -26,6 +26,9 @@ func (m *moonshotProviderInitializer) ValidateConfig(config ProviderConfig) erro
if config.moonshotFileId != "" && config.context != nil {
return errors.New("moonshotFileId and context cannot be configured at the same time")
}
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
}

View File

@@ -2,6 +2,7 @@ package provider
import (
"fmt"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
@@ -12,9 +13,9 @@ import (
// openaiProvider is the provider for OpenAI service.
const (
openaiDomain = "api.openai.com"
openaiChatCompletionPath = "/v1/chat/completions"
openaiEmbeddingsPath = "/v1/chat/embeddings"
defaultOpenaiDomain = "api.openai.com"
defaultOpenaiChatCompletionPath = "/v1/chat/completions"
defaultOpenaiEmbeddingsPath = "/v1/chat/embeddings"
)
type openaiProviderInitializer struct {
@@ -25,14 +26,29 @@ func (m *openaiProviderInitializer) ValidateConfig(config ProviderConfig) error
}
func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
if config.openaiCustomUrl == "" {
return &openaiProvider{
config: config,
contextCache: createContextCache(&config),
}, nil
}
customUrl := strings.TrimPrefix(strings.TrimPrefix(config.openaiCustomUrl, "http://"), "https://")
pairs := strings.SplitN(customUrl, "/", 2)
if len(pairs) != 2 {
return nil, fmt.Errorf("invalid openaiCustomUrl:%s", config.openaiCustomUrl)
}
return &openaiProvider{
config: config,
customDomain: pairs[0],
customPath: "/" + pairs[1],
contextCache: createContextCache(&config),
}, nil
}
type openaiProvider struct {
config ProviderConfig
customDomain string
customPath string
contextCache *contextCache
}
@@ -41,15 +57,25 @@ func (m *openaiProvider) GetProviderType() string {
}
func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
switch apiName {
case ApiNameChatCompletion:
_ = util.OverwriteRequestPath(openaiChatCompletionPath)
break
case ApiNameEmbeddings:
_ = util.OverwriteRequestPath(openaiEmbeddingsPath)
break
if m.customPath == "" {
switch apiName {
case ApiNameChatCompletion:
_ = util.OverwriteRequestPath(defaultOpenaiChatCompletionPath)
case ApiNameEmbeddings:
ctx.DontReadRequestBody()
_ = util.OverwriteRequestPath(defaultOpenaiEmbeddingsPath)
}
} else {
_ = util.OverwriteRequestPath(m.customPath)
}
if m.customDomain == "" {
_ = util.OverwriteRequestHost(defaultOpenaiDomain)
} else {
_ = util.OverwriteRequestHost(m.customDomain)
}
if len(m.config.apiTokens) > 0 {
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
}
_ = util.OverwriteRequestAuthorization("Bearer " + m.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
return types.ActionContinue, nil
}
@@ -63,26 +89,23 @@ func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
bodyAltered := false
if m.config.responseJsonSchema != nil {
log.Debugf("[ai-proxy] set response format to %s", m.config.responseJsonSchema)
request.ResponseFormat = m.config.responseJsonSchema
}
if request.Stream {
// For stream requests, we need to include usage in the response.
if request.StreamOptions == nil {
request.StreamOptions = &streamOptions{IncludeUsage: true}
bodyAltered = true
} else if !request.StreamOptions.IncludeUsage {
request.StreamOptions.IncludeUsage = true
bodyAltered = true
}
}
if m.contextCache == nil {
if bodyAltered {
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.openai.set_include_usage_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, "ai-proxy.openai.set_include_usage_failed", util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
return types.ActionContinue, nil
} else {
// If context cache is configured and body has been altered, the new body will be replaced when inserting the context data.
}
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {

View File

@@ -19,6 +19,7 @@ const (
providerTypeMoonshot = "moonshot"
providerTypeAzure = "azure"
providerTypeAi360 = "ai360"
providerTypeQwen = "qwen"
providerTypeOpenAI = "openai"
providerTypeGroq = "groq"
@@ -33,6 +34,9 @@ const (
providerTypeStepfun = "stepfun"
providerTypeMinimax = "minimax"
providerTypeCloudflare = "cloudflare"
providerTypeSpark = "spark"
providerTypeGemini = "gemini"
providerTypeDeepl = "deepl"
protocolOpenAI = "openai"
protocolOriginal = "original"
@@ -70,6 +74,7 @@ var (
providerInitializers = map[string]providerInitializer{
providerTypeMoonshot: &moonshotProviderInitializer{},
providerTypeAzure: &azureProviderInitializer{},
providerTypeAi360: &ai360ProviderInitializer{},
providerTypeQwen: &qwenProviderInitializer{},
providerTypeOpenAI: &openaiProviderInitializer{},
providerTypeGroq: &groqProviderInitializer{},
@@ -84,6 +89,9 @@ var (
providerTypeStepfun: &stepfunProviderInitializer{},
providerTypeMinimax: &minimaxProviderInitializer{},
providerTypeCloudflare: &cloudflareProviderInitializer{},
providerTypeSpark: &sparkProviderInitializer{},
providerTypeGemini: &geminiProviderInitializer{},
providerTypeDeepl: &deeplProviderInitializer{},
}
)
@@ -121,6 +129,9 @@ type ProviderConfig struct {
// @Title zh-CN 请求超时
// @Description zh-CN 请求AI服务的超时时间单位为毫秒。默认值为120000即2分钟
timeout uint32 `required:"false" yaml:"timeout" json:"timeout"`
// @Title zh-CN 基于OpenAI协议的自定义后端URL
// @Description zh-CN 仅适用于支持 openai 协议的服务。
openaiCustomUrl string `required:"false" yaml:"openaiCustomUrl" json:"openaiCustomUrl"`
// @Title zh-CN Moonshot File ID
// @Description zh-CN 仅适用于Moonshot AI服务。Moonshot AI服务的文件ID其内容用于补充AI请求上下文
moonshotFileId string `required:"false" yaml:"moonshotFileId" json:"moonshotFileId"`
@@ -133,6 +144,9 @@ type ProviderConfig struct {
// @Title zh-CN 启用通义千问搜索服务
// @Description zh-CN 仅适用于通义千问服务,表示是否启用通义千问的互联网搜索功能。
qwenEnableSearch bool `required:"false" yaml:"qwenEnableSearch" json:"qwenEnableSearch"`
// @Title zh-CN 开启通义千问兼容模式
// @Description zh-CN 启用通义千问兼容模式后,将调用千问的兼容模式接口,同时对请求/响应不做修改。
qwenEnableCompatible bool `required:"false" yaml:"qwenEnableCompatible" json:"qwenEnableCompatible"`
// @Title zh-CN Ollama Server IP/Domain
// @Description zh-CN 仅适用于 Ollama 服务。Ollama 服务器的主机地址。
ollamaServerHost string `required:"false" yaml:"ollamaServerHost" json:"ollamaServerHost"`
@@ -163,6 +177,18 @@ type ProviderConfig struct {
// @Title zh-CN Cloudflare Account ID
// @Description zh-CN 仅适用于 Cloudflare Workers AI 服务。参考https://developers.cloudflare.com/workers-ai/get-started/rest-api/#2-run-a-model-via-api
cloudflareAccountId string `required:"false" yaml:"cloudflareAccountId" json:"cloudflareAccountId"`
// @Title zh-CN Gemini AI内容过滤和安全级别设定
// @Description zh-CN 仅适用于 Gemini AI 服务。参考https://ai.google.dev/gemini-api/docs/safety-settings
geminiSafetySetting map[string]string `required:"false" yaml:"geminiSafetySetting" json:"geminiSafetySetting"`
// @Title zh-CN 翻译服务需指定的目标语种
// @Description zh-CN 翻译结果的语种目前仅适用于DeepL服务。
targetLang string `required:"false" yaml:"targetLang" json:"targetLang"`
// @Title zh-CN 指定服务返回的响应需满足的JSON Schema
// @Description zh-CN 目前仅适用于OpenAI部分模型服务。参考https://platform.openai.com/docs/guides/structured-outputs
responseJsonSchema map[string]interface{} `required:"false" yaml:"responseJsonSchema" json:"responseJsonSchema"`
// @Title zh-CN 自定义大模型参数配置
// @Description zh-CN 用于填充或者覆盖大模型调用时的参数
customSettings []CustomSetting
}
func (c *ProviderConfig) FromJson(json gjson.Result) {
@@ -175,6 +201,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
if c.timeout == 0 {
c.timeout = defaultTimeout
}
c.openaiCustomUrl = json.Get("openaiCustomUrl").String()
c.moonshotFileId = json.Get("moonshotFileId").String()
c.azureServiceUrl = json.Get("azureServiceUrl").String()
c.qwenFileIds = make([]string, 0)
@@ -182,6 +209,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
c.qwenFileIds = append(c.qwenFileIds, fileId.String())
}
c.qwenEnableSearch = json.Get("qwenEnableSearch").Bool()
c.qwenEnableCompatible = json.Get("qwenEnableCompatible").Bool()
c.ollamaServerHost = json.Get("ollamaServerHost").String()
c.ollamaServerPort = uint32(json.Get("ollamaServerPort").Uint())
c.modelMapping = make(map[string]string)
@@ -202,12 +230,41 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
c.hunyuanAuthKey = json.Get("hunyuanAuthKey").String()
c.minimaxGroupId = json.Get("minimaxGroupId").String()
c.cloudflareAccountId = json.Get("cloudflareAccountId").String()
if c.typ == providerTypeGemini {
c.geminiSafetySetting = make(map[string]string)
for k, v := range json.Get("geminiSafetySetting").Map() {
c.geminiSafetySetting[k] = v.String()
}
}
c.targetLang = json.Get("targetLang").String()
if schemaValue, ok := json.Get("responseJsonSchema").Value().(map[string]interface{}); ok {
c.responseJsonSchema = schemaValue
} else {
c.responseJsonSchema = nil
}
c.customSettings = make([]CustomSetting, 0)
customSettingsJson := json.Get("customSettings")
if customSettingsJson.Exists() {
protocol := protocolOpenAI
if c.protocol == protocolOriginal {
// use provider name to represent original protocol name
protocol = c.typ
}
for _, settingJson := range customSettingsJson.Array() {
setting := CustomSetting{}
setting.FromJson(settingJson)
// use protocol info to rewrite setting
setting.AdjustWithProtocol(protocol)
if setting.Validate() {
c.customSettings = append(c.customSettings, setting)
}
}
}
}
func (c *ProviderConfig) Validate() error {
if c.apiTokens == nil || len(c.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
if c.timeout < 0 {
return errors.New("invalid timeout in config")
}
@@ -290,3 +347,7 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.
return ""
}
func (c ProviderConfig) ReplaceByCustomSettings(body []byte) ([]byte, error) {
return ReplaceByCustomSettings(body, c.customSettings)
}

View File

@@ -13,6 +13,8 @@ import (
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// qwenProvider is the provider for Qwen service.
@@ -20,16 +22,19 @@ import (
const (
qwenResultFormatMessage = "message"
qwenDomain = "dashscope.aliyuncs.com"
qwenChatCompletionPath = "/api/v1/services/aigc/text-generation/generation"
qwenTextEmbeddingPath = "/api/v1/services/embeddings/text-embedding/text-embedding"
qwenDomain = "dashscope.aliyuncs.com"
qwenChatCompletionPath = "/api/v1/services/aigc/text-generation/generation"
qwenTextEmbeddingPath = "/api/v1/services/embeddings/text-embedding/text-embedding"
qwenCompatiblePath = "/compatible-mode/v1/chat/completions"
qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation"
qwenTopPMin = 0.000001
qwenTopPMax = 0.999999
qwenDummySystemMessageContent = "You are a helpful assistant."
qwenLongModelName = "qwen-long"
qwenLongModelName = "qwen-long"
qwenVlModelPrefixName = "qwen-vl"
)
type qwenProviderInitializer struct {
@@ -39,6 +44,9 @@ func (m *qwenProviderInitializer) ValidateConfig(config ProviderConfig) error {
if len(config.qwenFileIds) != 0 && config.context != nil {
return errors.New("qwenFileIds and context cannot be configured at the same time")
}
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
}
@@ -60,7 +68,9 @@ func (m *qwenProvider) GetProviderType() string {
}
func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
if apiName == ApiNameChatCompletion {
if m.config.qwenEnableCompatible {
_ = util.OverwriteRequestPath(qwenCompatiblePath)
} else if apiName == ApiNameChatCompletion {
_ = util.OverwriteRequestPath(qwenChatCompletionPath)
} else if apiName == ApiNameEmbeddings {
_ = util.OverwriteRequestPath(qwenTextEmbeddingPath)
@@ -82,6 +92,23 @@ func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName
}
func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if m.config.qwenEnableCompatible {
if gjson.GetBytes(body, "model").Exists() {
rawModel := gjson.GetBytes(body, "model").String()
mappedModel := getMappedModel(rawModel, m.config.modelMapping, log)
newBody, err := sjson.SetBytes(body, "model", mappedModel)
if err != nil {
log.Errorf("Replace model error: %v", err)
return types.ActionContinue, err
}
err = proxywasm.ReplaceHttpRequestBody(newBody)
if err != nil {
log.Errorf("Replace request body error: %v", err)
return types.ActionContinue, err
}
}
return types.ActionContinue, nil
}
if apiName == ApiNameChatCompletion {
return m.onChatCompletionRequestBody(ctx, body, log)
}
@@ -138,6 +165,10 @@ func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body
}
request.Model = mappedModel
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
// Use the qwen multimodal model generation API
if strings.HasPrefix(request.Model, qwenVlModelPrefixName) {
_ = util.OverwriteRequestPath(qwenMultimodalGenerationPath)
}
streaming := request.Stream
if streaming {
@@ -217,7 +248,7 @@ func (m *qwenProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiNam
}
func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
if name != ApiNameChatCompletion {
if m.config.qwenEnableCompatible || name != ApiNameChatCompletion {
return chunk, nil
}
@@ -302,6 +333,9 @@ func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api
}
func (m *qwenProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if m.config.qwenEnableCompatible {
return types.ActionContinue, nil
}
if apiName == ApiNameChatCompletion {
return m.onChatCompletionResponseBody(ctx, body, log)
}
@@ -422,8 +456,29 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont
if pushedMessage, ok := ctx.GetContext(ctxKeyPushedMessage).(qwenMessage); ok {
if message.Content == "" {
message.Content = pushedMessage.Content
} else if message.IsStringContent() {
deltaContentMessage.Content = util.StripPrefix(deltaContentMessage.StringContent(), pushedMessage.StringContent())
} else if strings.HasPrefix(baseMessage.Model, qwenVlModelPrefixName) {
// Use the Qwen multimodal model generation API
deltaContentList, ok := deltaContentMessage.Content.([]qwenVlMessageContent)
if !ok {
log.Warnf("unexpected deltaContentMessage content type: %T", deltaContentMessage.Content)
} else {
pushedContentList, ok := pushedMessage.Content.([]qwenVlMessageContent)
if !ok {
log.Warnf("unexpected pushedMessage content type: %T", pushedMessage.Content)
} else {
for i, content := range deltaContentList {
if i >= len(pushedContentList) {
break
}
pushedText := pushedContentList[i].Text
content.Text = util.StripPrefix(content.Text, pushedText)
deltaContentList[i] = content
}
}
}
}
deltaContentMessage.Content = util.StripPrefix(deltaContentMessage.Content, pushedMessage.Content)
if len(deltaToolCallsMessage.ToolCalls) > 0 && pushedMessage.ToolCalls != nil {
for i, tc := range deltaToolCallsMessage.ToolCalls {
if i >= len(pushedMessage.ToolCalls) {
@@ -529,7 +584,7 @@ func (m *qwenProvider) insertContextMessage(request *qwenTextGenRequest, content
if builder.Len() != 0 {
builder.WriteString("\n")
}
builder.WriteString(message.Content)
builder.WriteString(message.StringContent())
}
request.Input.Messages = append([]qwenMessage{{Role: roleSystem, Content: builder.String()}, fileMessage}, request.Input.Messages[firstNonSystemMessageIndex:]...)
return 1
@@ -634,10 +689,15 @@ type qwenUsage struct {
type qwenMessage struct {
Name string `json:"name,omitempty"`
Role string `json:"role"`
Content string `json:"content"`
Content any `json:"content"`
ToolCalls []toolCall `json:"tool_calls,omitempty"`
}
type qwenVlMessageContent struct {
Image string `json:"image,omitempty"`
Text string `json:"text,omitempty"`
}
type qwenTextEmbeddingRequest struct {
Model string `json:"model"`
Input qwenTextEmbeddingInput `json:"input"`
@@ -677,11 +737,58 @@ func qwenMessageToChatMessage(qwenMessage qwenMessage) chatMessage {
}
}
func (m *qwenMessage) IsStringContent() bool {
_, ok := m.Content.(string)
return ok
}
func (m *qwenMessage) StringContent() string {
content, ok := m.Content.(string)
if ok {
return content
}
contentList, ok := m.Content.([]any)
if ok {
var contentStr string
for _, contentItem := range contentList {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
if text, ok := contentMap["text"].(string); ok {
contentStr += text
}
}
return contentStr
}
return ""
}
func chatMessage2QwenMessage(chatMessage chatMessage) qwenMessage {
return qwenMessage{
Name: chatMessage.Name,
Role: chatMessage.Role,
Content: chatMessage.Content,
ToolCalls: chatMessage.ToolCalls,
if chatMessage.IsStringContent() {
return qwenMessage{
Name: chatMessage.Name,
Role: chatMessage.Role,
Content: chatMessage.StringContent(),
ToolCalls: chatMessage.ToolCalls,
}
} else {
var contents []qwenVlMessageContent
openaiContent := chatMessage.ParseContent()
for _, part := range openaiContent {
var content qwenVlMessageContent
if part.Type == contentTypeText {
content.Text = part.Text
} else if part.Type == contentTypeImageUrl {
content.Image = part.ImageUrl.Url
}
contents = append(contents, content)
}
return qwenMessage{
Name: chatMessage.Name,
Role: chatMessage.Role,
Content: contents,
ToolCalls: chatMessage.ToolCalls,
}
}
}

View File

@@ -2,7 +2,6 @@ package provider
import (
"encoding/json"
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
@@ -14,7 +13,7 @@ func decodeChatCompletionRequest(body []byte, request *chatCompletionRequest) er
return fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.Messages == nil || len(request.Messages) == 0 {
return errors.New("no message found in the request body")
return fmt.Errorf("no message found in the request body: %s", body)
}
return nil
}

View File

@@ -0,0 +1,207 @@
package provider
import (
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
// sparkProvider is the provider for SparkLLM AI service.
const (
sparkHost = "spark-api-open.xf-yun.com"
sparkChatCompletionPath = "/v1/chat/completions"
)
type sparkProviderInitializer struct {
}
type sparkProvider struct {
config ProviderConfig
contextCache *contextCache
}
type sparkRequest struct {
Model string `json:"model"`
Messages []chatMessage `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"`
TopK int `json:"top_k,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Tools []tool `json:"tools,omitempty"`
ToolChoice string `json:"tool_choice,omitempty"`
}
type sparkResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Sid string `json:"sid"`
Choices []chatCompletionChoice `json:"choices"`
Usage usage `json:"usage,omitempty"`
}
type sparkStreamResponse struct {
sparkResponse
Id string `json:"id"`
Created int64 `json:"created"`
}
func (i *sparkProviderInitializer) ValidateConfig(config ProviderConfig) error {
return nil
}
func (i *sparkProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
return &sparkProvider{
config: config,
contextCache: createContextCache(&config),
}, nil
}
func (p *sparkProvider) GetProviderType() string {
return providerTypeSpark
}
func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
_ = util.OverwriteRequestHost(sparkHost)
_ = util.OverwriteRequestPath(sparkChatCompletionPath)
_ = util.OverwriteRequestAuthorization("Bearer " + p.config.GetRandomToken())
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
return types.ActionContinue, nil
}
func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
// 使用Spark协议
if p.config.protocol == protocolOriginal {
request := &sparkRequest{}
if err := json.Unmarshal(body, request); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal request: %v", err)
}
if request.Model == "" {
return types.ActionContinue, errors.New("request model is empty")
}
// 目前星火在模型名称错误时也会调用generalv3这里还是按照输入的模型名称设置响应里的模型名称
ctx.SetContext(ctxKeyFinalRequestModel, request.Model)
return types.ActionContinue, replaceJsonRequestBody(request, log)
} else {
// 使用openai协议
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
if request.Model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
// 映射模型
mappedModel := getMappedModel(request.Model, p.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
ctx.SetContext(ctxKeyFinalRequestModel, mappedModel)
request.Model = mappedModel
return types.ActionContinue, replaceJsonRequestBody(request, log)
}
}
func (p *sparkProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
_ = proxywasm.RemoveHttpResponseHeader("Content-Length")
return types.ActionContinue, nil
}
func (p *sparkProvider) OnResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
sparkResponse := &sparkResponse{}
if err := json.Unmarshal(body, sparkResponse); err != nil {
return types.ActionContinue, fmt.Errorf("unable to unmarshal spark response: %v", err)
}
if sparkResponse.Code != 0 {
return types.ActionContinue, fmt.Errorf("spark response error, error_code: %d, error_message: %s", sparkResponse.Code, sparkResponse.Message)
}
response := p.responseSpark2OpenAI(ctx, sparkResponse)
return types.ActionContinue, replaceJsonResponseBody(response, log)
}
func (p *sparkProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
if isLastChunk || len(chunk) == 0 {
return nil, nil
}
responseBuilder := &strings.Builder{}
lines := strings.Split(string(chunk), "\n")
for _, data := range lines {
if len(data) < 6 {
// ignore blank line or wrong format
continue
}
data = data[6:]
// The final response is `data: [DONE]`
if data == "[DONE]" {
continue
}
var sparkResponse sparkStreamResponse
if err := json.Unmarshal([]byte(data), &sparkResponse); err != nil {
log.Errorf("unable to unmarshal spark response: %v", err)
continue
}
response := p.streamResponseSpark2OpenAI(ctx, &sparkResponse)
responseBody, err := json.Marshal(response)
if err != nil {
log.Errorf("unable to marshal response: %v", err)
return nil, err
}
p.appendResponse(responseBuilder, string(responseBody))
}
modifiedResponseChunk := responseBuilder.String()
log.Debugf("=== modified response chunk: %s", modifiedResponseChunk)
return []byte(modifiedResponseChunk), nil
}
func (p *sparkProvider) responseSpark2OpenAI(ctx wrapper.HttpContext, response *sparkResponse) *chatCompletionResponse {
choices := make([]chatCompletionChoice, len(response.Choices))
for idx, c := range response.Choices {
choices[idx] = chatCompletionChoice{
Index: c.Index,
Message: &chatMessage{Role: c.Message.Role, Content: c.Message.Content},
}
}
return &chatCompletionResponse{
Id: response.Sid,
Created: time.Now().UnixMilli() / 1000,
Object: objectChatCompletion,
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
Choices: choices,
Usage: response.Usage,
}
}
func (p *sparkProvider) streamResponseSpark2OpenAI(ctx wrapper.HttpContext, response *sparkStreamResponse) *chatCompletionResponse {
choices := make([]chatCompletionChoice, len(response.Choices))
for idx, c := range response.Choices {
choices[idx] = chatCompletionChoice{
Index: c.Index,
Delta: &chatMessage{Role: c.Delta.Role, Content: c.Delta.Content},
}
}
return &chatCompletionResponse{
Id: response.Sid,
Created: response.Created,
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
Object: objectChatCompletion,
Choices: choices,
Usage: response.Usage,
}
}
func (p *sparkProvider) appendResponse(responseBuilder *strings.Builder, responseBody string) {
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}

View File

@@ -1,6 +1,7 @@
package provider
import (
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
@@ -18,6 +19,9 @@ type stepfunProviderInitializer struct {
}
func (m *stepfunProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
}

View File

@@ -1,6 +1,7 @@
package provider
import (
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
@@ -18,6 +19,9 @@ type yiProviderInitializer struct {
}
func (m *yiProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
}

View File

@@ -1,6 +1,7 @@
package provider
import (
"errors"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
@@ -17,6 +18,9 @@ const (
type zhipuAiProviderInitializer struct{}
func (m *zhipuAiProviderInitializer) ValidateConfig(config ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
}

View File

@@ -0,0 +1,58 @@
# 功能说明
`ai-qutoa` 插件实现给特定 consumer 根据分配固定的 quota 进行 quota 策略限流,同时支持 quota 管理能力,包括查询 quota 、刷新 quota、增减 quota。
`ai-quota` 插件需要配合 认证插件比如 `key-auth``jwt-auth` 等插件获取认证身份的 consumer 名称,同时需要配合 `ai-statatistics` 插件获取 AI Token 统计信息。
# 配置说明
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|--------------------|-----------------|--------------------------------------| ---- |--------------------------------------------|
| `redis_key_prefix` | string | 选填 | chat_quota: | qutoa redis key 前缀 |
| `admin_consumer` | string | 必填 | | 管理 quota 管理身份的 consumer 名称 |
| `admin_path` | string | 选填 | /quota | 管理 quota 请求 path 前缀 |
| `redis` | object | 是 | | redis相关配置 |
`redis`中每一项的配置字段说明
| 配置项 | 类型 | 必填 | 默认值 | 说明 |
| ------------ | ------ | ---- | ---------------------------------------------------------- | --------------------------- |
| service_name | string | 必填 | - | redis 服务名称,带服务类型的完整 FQDN 名称,例如 my-redis.dns、redis.my-ns.svc.cluster.local |
| service_port | int | 否 | 服务类型为固定地址static service默认值为80其他为6379 | 输入redis服务的服务端口 |
| username | string | 否 | - | redis用户名 |
| password | string | 否 | - | redis密码 |
| timeout | int | 否 | 1000 | redis连接超时时间单位毫秒 |
# 配置示例
## 识别请求参数 apikey进行区别限流
```yaml
redis_key_prefix: "chat_quota:"
admin_consumer: consumer3
admin_path: /quota
redis:
service_name: redis-service.default.svc.cluster.local
service_port: 6379
timeout: 2000
```
## 刷新 quota
如果当前请求 url 的后缀符合 admin_path例如插件在 example.com/v1/chat/completions 这个路由上生效,那么更新 quota 可以通过
curl https://example.com/v1/chat/completions/quota/refresh -H "Authorization: Bearer credential3" -d "consumer=consumer1&quota=10000"
Redis 中 key 为 chat_quota:consumer1 的值就会被刷新为 10000
## 查询 quota
查询特定用户的 quota 可以通过 curl https://example.com/v1/chat/completions/quota?consumer=consumer1 -H "Authorization: Bearer credential3"
将返回: {"quota": 10000, "consumer": "consumer1"}
## 增减 quota
增减特定用户的 quota 可以通过 curl https://example.com/v1/chat/completions/quota/delta -d "consumer=consumer1&value=100" -H "Authorization: Bearer credential3"
这样 Redis 中 Key 为 chat_quota:consumer1 的值就会增加100可以支持负数则减去对应值。

View File

@@ -0,0 +1,20 @@
module github.com/alibaba/higress/plugins/wasm-go/extensions/ai-quota
go 1.19
//replace github.com/alibaba/higress/plugins/wasm-go => ../..
require (
github.com/alibaba/higress/plugins/wasm-go v1.4.3-0.20240808022948-34f5722d93de
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
github.com/tidwall/gjson v1.17.3
github.com/tidwall/resp v0.1.1
)
require (
github.com/google/uuid v1.3.0 // indirect
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
github.com/magefile/mage v1.14.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
)

View File

@@ -0,0 +1,22 @@
github.com/alibaba/higress/plugins/wasm-go v1.4.3-0.20240808022948-34f5722d93de h1:lDLqj7Hw41ox8VdsP7oCTPhjPa3+QJUCKApcLh2a45Y=
github.com/alibaba/higress/plugins/wasm-go v1.4.3-0.20240808022948-34f5722d93de/go.mod h1:359don/ahMxpfeLMzr29Cjwcu8IywTTDUzWlBPRNLHw=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94=
github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -0,0 +1,399 @@
package main
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-quota/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"github.com/tidwall/resp"
)
const (
pluginName = "ai-quota"
)
type ChatMode string
const (
ChatModeCompletion ChatMode = "completion"
ChatModeAdmin ChatMode = "admin"
ChatModeNone ChatMode = "none"
)
type AdminMode string
const (
AdminModeRefresh AdminMode = "refresh"
AdminModeQuery AdminMode = "query"
AdminModeDelta AdminMode = "delta"
AdminModeNone AdminMode = "none"
)
func main() {
wrapper.SetCtx(
pluginName,
wrapper.ParseConfigBy(parseConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
wrapper.ProcessStreamingResponseBodyBy(onHttpStreamingResponseBody),
)
}
type QuotaConfig struct {
redisInfo RedisInfo `yaml:"redis"`
RedisKeyPrefix string `yaml:"redis_key_prefix"`
AdminConsumer string `yaml:"admin_consumer"`
AdminPath string `yaml:"admin_path"`
credential2Name map[string]string `yaml:"-"`
redisClient wrapper.RedisClient
}
type Consumer struct {
Name string `yaml:"name"`
Credential string `yaml:"credential"`
}
type RedisInfo struct {
ServiceName string `required:"true" yaml:"service_name" json:"service_name"`
ServicePort int `required:"false" yaml:"service_port" json:"service_port"`
Username string `required:"false" yaml:"username" json:"username"`
Password string `required:"false" yaml:"password" json:"password"`
Timeout int `required:"false" yaml:"timeout" json:"timeout"`
}
func parseConfig(json gjson.Result, config *QuotaConfig, log wrapper.Log) error {
log.Debugf("parse config()")
// admin
config.AdminPath = json.Get("admin_path").String()
config.AdminConsumer = json.Get("admin_consumer").String()
if config.AdminPath == "" {
config.AdminPath = "/quota"
}
if config.AdminConsumer == "" {
return errors.New("missing admin_consumer in config")
}
// Redis
config.RedisKeyPrefix = json.Get("redis_key_prefix").String()
if config.RedisKeyPrefix == "" {
config.RedisKeyPrefix = "chat_quota:"
}
redisConfig := json.Get("redis")
if !redisConfig.Exists() {
return errors.New("missing redis in config")
}
serviceName := redisConfig.Get("service_name").String()
if serviceName == "" {
return errors.New("redis service name must not be empty")
}
servicePort := int(redisConfig.Get("service_port").Int())
if servicePort == 0 {
if strings.HasSuffix(serviceName, ".static") {
// use default logic port which is 80 for static service
servicePort = 80
} else {
servicePort = 6379
}
}
username := redisConfig.Get("username").String()
password := redisConfig.Get("password").String()
timeout := int(redisConfig.Get("timeout").Int())
if timeout == 0 {
timeout = 1000
}
config.redisInfo.ServiceName = serviceName
config.redisInfo.ServicePort = servicePort
config.redisInfo.Username = username
config.redisInfo.Password = password
config.redisInfo.Timeout = timeout
config.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{
FQDN: serviceName,
Port: int64(servicePort),
})
return config.redisClient.Init(username, password, int64(timeout))
}
func onHttpRequestHeaders(context wrapper.HttpContext, config QuotaConfig, log wrapper.Log) types.Action {
log.Debugf("onHttpRequestHeaders()")
// get tokens
consumer, err := proxywasm.GetHttpRequestHeader("x-mse-consumer")
if err != nil {
return deniedNoKeyAuthData()
}
if consumer == "" {
return deniedUnauthorizedConsumer()
}
rawPath := context.Path()
path, _ := url.Parse(rawPath)
chatMode, adminMode := getOperationMode(path.Path, config.AdminPath, log)
context.SetContext("chatMode", chatMode)
context.SetContext("adminMode", adminMode)
context.SetContext("consumer", consumer)
log.Debugf("chatMode:%s, adminMode:%s, consumer:%s", chatMode, adminMode, consumer)
if chatMode == ChatModeNone {
return types.ActionContinue
}
if chatMode == ChatModeAdmin {
// query quota
if adminMode == AdminModeQuery {
return queryQuota(context, config, consumer, path, log)
}
if adminMode == AdminModeRefresh || adminMode == AdminModeDelta {
context.BufferRequestBody()
return types.HeaderStopIteration
}
return types.ActionContinue
}
// there is no need to read request body when it is on chat completion mode
context.DontReadRequestBody()
// check quota here
config.redisClient.Get(config.RedisKeyPrefix+consumer, func(response resp.Value) {
isDenied := false
if err := response.Error(); err != nil {
isDenied = true
}
if response.IsNull() {
isDenied = true
}
if response.Integer() <= 0 {
isDenied = true
}
log.Debugf("get consumer:%s quota:%d isDenied:%t", consumer, response.Integer(), isDenied)
if isDenied {
util.SendResponse(http.StatusForbidden, "ai-quota.noquota", "text/plain", "Request denied by ai quota check, No quota left")
return
}
proxywasm.ResumeHttpRequest()
})
return types.HeaderStopAllIterationAndWatermark
}
func onHttpRequestBody(ctx wrapper.HttpContext, config QuotaConfig, body []byte, log wrapper.Log) types.Action {
log.Debugf("onHttpRequestBody()")
chatMode, ok := ctx.GetContext("chatMode").(ChatMode)
if !ok {
return types.ActionContinue
}
if chatMode == ChatModeNone || chatMode == ChatModeCompletion {
return types.ActionContinue
}
adminMode, ok := ctx.GetContext("adminMode").(AdminMode)
if !ok {
return types.ActionContinue
}
adminConsumer, ok := ctx.GetContext("consumer").(string)
if !ok {
return types.ActionContinue
}
if adminMode == AdminModeRefresh {
return refreshQuota(ctx, config, adminConsumer, string(body), log)
}
if adminMode == AdminModeDelta {
return deltaQuota(ctx, config, adminConsumer, string(body), log)
}
return types.ActionContinue
}
func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config QuotaConfig, data []byte, endOfStream bool, log wrapper.Log) []byte {
chatMode, ok := ctx.GetContext("chatMode").(ChatMode)
if !ok {
return data
}
if chatMode == ChatModeNone || chatMode == ChatModeAdmin {
return data
}
// chat completion mode
if !endOfStream {
return data
}
inputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.input_token"})
if err != nil {
return data
}
outputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.output_token"})
if err != nil {
return data
}
inputToken, err := strconv.Atoi(string(inputTokenStr))
if err != nil {
return data
}
outputToken, err := strconv.Atoi(string(outputTokenStr))
if err != nil {
return data
}
consumer, ok := ctx.GetContext("consumer").(string)
if ok {
totalToken := int(inputToken + outputToken)
log.Debugf("update consumer:%s, totalToken:%d", consumer, totalToken)
config.redisClient.DecrBy(config.RedisKeyPrefix+consumer, totalToken, nil)
}
return data
}
func deniedNoKeyAuthData() types.Action {
util.SendResponse(http.StatusUnauthorized, "ai-quota.no_key", "text/plain", "Request denied by ai quota check. No Key Authentication information found.")
return types.ActionContinue
}
func deniedUnauthorizedConsumer() types.Action {
util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. Unauthorized consumer.")
return types.ActionContinue
}
func getOperationMode(path string, adminPath string, log wrapper.Log) (ChatMode, AdminMode) {
fullAdminPath := "/v1/chat/completions" + adminPath
if strings.HasSuffix(path, fullAdminPath+"/refresh") {
return ChatModeAdmin, AdminModeRefresh
}
if strings.HasSuffix(path, fullAdminPath+"/delta") {
return ChatModeAdmin, AdminModeDelta
}
if strings.HasSuffix(path, fullAdminPath) {
return ChatModeAdmin, AdminModeQuery
}
if strings.HasSuffix(path, "/v1/chat/completions") {
return ChatModeCompletion, AdminModeNone
}
return ChatModeNone, AdminModeNone
}
func refreshQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, body string, log wrapper.Log) types.Action {
// check consumer
if adminConsumer != config.AdminConsumer {
util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. Unauthorized admin consumer.")
return types.ActionContinue
}
queryValues, _ := url.ParseQuery(body)
values := make(map[string]string, len(queryValues))
for k, v := range queryValues {
values[k] = v[0]
}
queryConsumer := values["consumer"]
quota, err := strconv.Atoi(values["quota"])
if queryConsumer == "" || err != nil {
util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. consumer can't be empty and quota must be integer.")
return types.ActionContinue
}
err2 := config.redisClient.Set(config.RedisKeyPrefix+queryConsumer, quota, func(response resp.Value) {
log.Debugf("Redis set key = %s quota = %d", config.RedisKeyPrefix+queryConsumer, quota)
if err := response.Error(); err != nil {
util.SendResponse(http.StatusServiceUnavailable, "ai-quota.error", "text/plain", fmt.Sprintf("redis error:%v", err))
return
}
util.SendResponse(http.StatusOK, "ai-quota.refreshquota", "text/plain", "refresh quota successful")
})
if err2 != nil {
util.SendResponse(http.StatusServiceUnavailable, "ai-quota.error", "text/plain", fmt.Sprintf("redis error:%v", err))
return types.ActionContinue
}
return types.ActionPause
}
func queryQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, url *url.URL, log wrapper.Log) types.Action {
// check consumer
if adminConsumer != config.AdminConsumer {
util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. Unauthorized admin consumer.")
return types.ActionContinue
}
// check url
queryValues := url.Query()
values := make(map[string]string, len(queryValues))
for k, v := range queryValues {
values[k] = v[0]
}
if values["consumer"] == "" {
util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. consumer can't be empty.")
return types.ActionContinue
}
queryConsumer := values["consumer"]
err := config.redisClient.Get(config.RedisKeyPrefix+queryConsumer, func(response resp.Value) {
quota := 0
if err := response.Error(); err != nil {
util.SendResponse(http.StatusServiceUnavailable, "ai-quota.error", "text/plain", fmt.Sprintf("redis error:%v", err))
return
} else if response.IsNull() {
quota = 0
} else {
quota = response.Integer()
}
result := struct {
Consumer string `json:"consumer"`
Quota int `json:"quota"`
}{
Consumer: queryConsumer,
Quota: quota,
}
body, _ := json.Marshal(result)
util.SendResponse(http.StatusOK, "ai-quota.queryquota", "application/json", string(body))
})
if err != nil {
util.SendResponse(http.StatusServiceUnavailable, "ai-quota.error", "text/plain", fmt.Sprintf("redis error:%v", err))
return types.ActionContinue
}
return types.ActionPause
}
func deltaQuota(ctx wrapper.HttpContext, config QuotaConfig, adminConsumer string, body string, log wrapper.Log) types.Action {
// check consumer
if adminConsumer != config.AdminConsumer {
util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. Unauthorized admin consumer.")
return types.ActionContinue
}
queryValues, _ := url.ParseQuery(body)
values := make(map[string]string, len(queryValues))
for k, v := range queryValues {
values[k] = v[0]
}
queryConsumer := values["consumer"]
value, err := strconv.Atoi(values["value"])
if queryConsumer == "" || err != nil {
util.SendResponse(http.StatusForbidden, "ai-quota.unauthorized", "text/plain", "Request denied by ai quota check. consumer can't be empty and value must be integer.")
return types.ActionContinue
}
if value >= 0 {
err := config.redisClient.IncrBy(config.RedisKeyPrefix+queryConsumer, value, func(response resp.Value) {
log.Debugf("Redis Incr key = %s value = %d", config.RedisKeyPrefix+queryConsumer, value)
if err := response.Error(); err != nil {
util.SendResponse(http.StatusServiceUnavailable, "ai-quota.error", "text/plain", fmt.Sprintf("redis error:%v", err))
return
}
util.SendResponse(http.StatusOK, "ai-quota.deltaquota", "text/plain", "delta quota successful")
})
if err != nil {
util.SendResponse(http.StatusServiceUnavailable, "ai-quota.error", "text/plain", fmt.Sprintf("redis error:%v", err))
return types.ActionContinue
}
} else {
err := config.redisClient.DecrBy(config.RedisKeyPrefix+queryConsumer, 0-value, func(response resp.Value) {
log.Debugf("Redis Decr key = %s value = %d", config.RedisKeyPrefix+queryConsumer, 0-value)
if err := response.Error(); err != nil {
util.SendResponse(http.StatusServiceUnavailable, "ai-quota.error", "text/plain", fmt.Sprintf("redis error:%v", err))
return
}
util.SendResponse(http.StatusOK, "ai-quota.deltaquota", "text/plain", "delta quota successful")
})
if err != nil {
util.SendResponse(http.StatusServiceUnavailable, "ai-quota.error", "text/plain", fmt.Sprintf("redis error:%v", err))
return types.ActionContinue
}
}
return types.ActionPause
}

View File

@@ -0,0 +1,61 @@
apiVersion: extensions.higress.io/v1alpha1
kind: WasmPlugin
metadata:
name: ai-quota
namespace: higress-system
spec:
defaultConfig: {}
defaultConfigDisable: true
matchRules:
- config:
redis_key_prefix: "chat_quota:"
admin_consumer: consumer3
admin_path: /quota
redis:
service_name: redis-service.default.svc.cluster.local
service_port: 6379
timeout: 2000
configDisable: false
ingress:
- qwen
phase: UNSPECIFIED_PHASE
priority: 280
url: oci://registry.cn-hangzhou.aliyuncs.com/2456868764/ai-quota:1.0.8
---
apiVersion: extensions.higress.io/v1alpha1
kind: WasmPlugin
metadata:
name: ai-statistics
namespace: higress-system
spec:
defaultConfig:
enable: true
defaultConfigDisable: false
phase: UNSPECIFIED_PHASE
priority: 250
url: oci://higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/ai-statistics:1.0.0
---
apiVersion: extensions.higress.io/v1alpha1
kind: WasmPlugin
metadata:
name: wasm-keyauth
namespace: higress-system
spec:
defaultConfig:
consumers:
- credential: "Bearer credential1"
name: consumer1
- credential: "Bearer credential2"
name: consumer2
- credential: "Bearer credential3"
name: consumer3
global_auth: true
keys:
- authorization
in_header: true
defaultConfigDisable: false
priority: 300
url: oci://higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/key-auth:1.0.0
imagePullPolicy: Always

View File

@@ -0,0 +1,22 @@
package util
import "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
const (
HeaderContentType = "Content-Type"
MimeTypeTextPlain = "text/plain"
MimeTypeApplicationJson = "application/json"
)
func SendResponse(statusCode uint32, statusCodeDetails string, contentType, body string) error {
return proxywasm.SendHttpResponseWithDetail(statusCode, statusCodeDetails, CreateHeaders(HeaderContentType, contentType), []byte(body), -1)
}
func CreateHeaders(kvs ...string) [][2]string {
headers := make([][2]string, 0, len(kvs)/2)
for i := 0; i < len(kvs); i += 2 {
headers = append(headers, [2]string{kvs[i], kvs[i+1]})
}
return headers
}

View File

@@ -1,34 +1,42 @@
# 简介
通过对接阿里云向量检索服务实现LLM-RAG流程如图所示
![](https://img.alicdn.com/imgextra/i1/O1CN01LuRVs41KhoeuzakeF_!!6000000001196-0-tps-1926-1316.jpg)
<img src="https://img.alicdn.com/imgextra/i1/O1CN01LuRVs41KhoeuzakeF_!!6000000001196-0-tps-1926-1316.jpg" width=600>
# 配置说明
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|----------------|-----------------|------|-----|----------------------------------------------------------------------------------|
| `dashscope.apiKey` | string | 必填 | - | 用于在访问通义千问服务时进行认证的令牌。 |
| `dashscope.serviceName` | string | 必填 | - | 通义千问服务名 |
| `dashscope.serviceFQDN` | string | 必填 | - | 通义千问服务名 |
| `dashscope.servicePort` | int | 必填 | - | 通义千问服务端口 |
| `dashscope.domain` | string | 必填 | - | 访问通义千问服务时域名 |
| `dashscope.serviceHost` | string | 必填 | - | 访问通义千问服务时域名 |
| `dashvector.apiKey` | string | 必填 | - | 用于在访问阿里云向量检索服务时进行认证的令牌。 |
| `dashvector.serviceName` | string | 必填 | - | 阿里云向量检索服务名 |
| `dashvector.serviceFQDN` | string | 必填 | - | 阿里云向量检索服务名 |
| `dashvector.servicePort` | int | 必填 | - | 阿里云向量检索服务端口 |
| `dashvector.domain` | string | 必填 | - | 访问阿里云向量检索服务时域名 |
| `dashvector.serviceHost` | string | 必填 | - | 访问阿里云向量检索服务时域名 |
| `dashvector.topk` | int | 必填 | - | 阿里云向量检索时获取向量数 |
| `dashvector.threshold` | float | 必填 | - | 向量距离阈值,高于该阈值的文档会被过滤掉 |
| `dashvector.field` | string | 必填 | - | 阿里云向量检索存储文档的字段名 |
插件开启后在使用链路追踪功能时会在span的attribute中添加rag检索到的文档id信息供排查问题使用。
# 示例
```yaml
dashscope:
apiKey: xxxxxxxxxxxxxxx
serviceName: dashscope
serviceFQDN: dashscope
servicePort: 443
domain: dashscope.aliyuncs.com
serviceHost: dashscope.aliyuncs.com
dashvector:
apiKey: xxxxxxxxxxxxxxxxxxxx
serviceName: dashvector
serviceFQDN: dashvector
servicePort: 443
domain: vrs-cn-xxxxxxxxxxxxxxx.dashvector.cn-hangzhou.aliyuncs.com
serviceHost: vrs-cn-xxxxxxxxxxxxxxx.dashvector.cn-hangzhou.aliyuncs.com
collection: xxxxxxxxxxxxxxx
topk: 1
threshold: 0.4
field: raw
```
[CEC-Corpus](https://github.com/shijiebei2009/CEC-Corpus) 数据集包含 332 篇突发事件的新闻报道的语料和标注数据,提取其原始的新闻稿文本,将其向量化后添加到阿里云向量检索服务。文本向量化的教程可以参考[《基于向量检索服务与灵积实现语义搜索》](https://help.aliyun.com/document_detail/2510234.html)。

View File

@@ -1,12 +1,9 @@
github.com/alibaba/higress/plugins/wasm-go v1.3.5 h1:VOLL3m442IHCSu8mR5AZ4sc6LVT9X0w1hdqDI7oB9jY=
github.com/alibaba/higress/plugins/wasm-go v1.3.5/go.mod h1:kr3V9Ntbspj1eSrX8rgjBsdMXkGupYEf+LM72caGPQc=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"strings"
"ai-rag/dashscope"
"ai-rag/dashvector"
@@ -20,6 +21,7 @@ func main() {
wrapper.ParseConfigBy(parseConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
)
}
@@ -29,6 +31,9 @@ type AIRagConfig struct {
DashVectorClient wrapper.HttpClient
DashVectorAPIKey string
DashVectorCollection string
DashVectorTopK int32
DashVectorThreshold float64
DashVectorField string
}
type Request struct {
@@ -47,29 +52,46 @@ type Message struct {
}
func parseConfig(json gjson.Result, config *AIRagConfig, log wrapper.Log) error {
checkList := []string{
"dashscope.apiKey",
"dashscope.serviceFQDN",
"dashscope.servicePort",
"dashscope.serviceHost",
"dashvector.apiKey",
"dashvector.collection",
"dashvector.serviceFQDN",
"dashvector.servicePort",
"dashvector.serviceHost",
"dashvector.topk",
"dashvector.threshold",
"dashvector.field",
}
for _, checkEntry := range checkList {
if !json.Get(checkEntry).Exists() {
return fmt.Errorf("%s not found in plugin config!", checkEntry)
}
}
config.DashScopeAPIKey = json.Get("dashscope.apiKey").String()
config.DashScopeClient = wrapper.NewClusterClient(wrapper.DnsCluster{
ServiceName: json.Get("dashscope.serviceName").String(),
Port: json.Get("dashscope.servicePort").Int(),
Domain: json.Get("dashscope.domain").String(),
config.DashScopeClient = wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: json.Get("dashscope.serviceFQDN").String(),
Port: json.Get("dashscope.servicePort").Int(),
Host: json.Get("dashscope.serviceHost").String(),
})
config.DashVectorAPIKey = json.Get("dashvector.apiKey").String()
config.DashVectorCollection = json.Get("dashvector.collection").String()
config.DashVectorClient = wrapper.NewClusterClient(wrapper.DnsCluster{
ServiceName: json.Get("dashvector.serviceName").String(),
Port: json.Get("dashvector.servicePort").Int(),
Domain: json.Get("dashvector.domain").String(),
config.DashVectorClient = wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: json.Get("dashvector.serviceFQDN").String(),
Port: json.Get("dashvector.servicePort").Int(),
Host: json.Get("dashvector.serviceHost").String(),
})
config.DashVectorTopK = int32(json.Get("dashvector.topk").Int())
config.DashVectorThreshold = json.Get("dashvector.threshold").Float()
config.DashVectorField = json.Get("dashvector.field").String()
return nil
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIRagConfig, log wrapper.Log) types.Action {
p, _ := proxywasm.GetHttpRequestHeader(":path")
if p != "/api/openai/v1/chat/completions" {
ctx.DontReadRequestBody()
return types.ActionContinue
}
proxywasm.RemoveHttpRequestHeader("content-length")
return types.ActionContinue
}
@@ -78,9 +100,12 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
var rawRequest Request
_ = json.Unmarshal(body, &rawRequest)
messageLength := len(rawRequest.Messages)
if messageLength == 0 {
return types.ActionContinue
}
rawContent := rawRequest.Messages[messageLength-1].Content
requestEmbedding := dashscope.Request{
Model: "text-embedding-v1",
Model: "text-embedding-v2",
Input: dashscope.Input{
Texts: []string{rawContent},
},
@@ -90,7 +115,6 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
}
headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.DashScopeAPIKey}}
reqEmbeddingSerialized, _ := json.Marshal(requestEmbedding)
// log.Info(string(reqEmbeddingSerialized))
config.DashScopeClient.Post(
"/api/v1/services/embeddings/text-embedding/text-embedding",
headers,
@@ -99,8 +123,8 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
var responseEmbedding dashscope.Response
_ = json.Unmarshal(responseBody, &responseEmbedding)
requestQuery := dashvector.Request{
TopK: 1,
OutputFileds: []string{"raw"},
TopK: config.DashVectorTopK,
OutputFileds: []string{config.DashVectorField},
Vector: responseEmbedding.Output.Embeddings[0].Embedding,
}
requestQuerySerialized, _ := json.Marshal(requestQuery)
@@ -111,11 +135,27 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
var response dashvector.Response
_ = json.Unmarshal(responseBody, &response)
doc := response.Output[0].Fields.Raw
rawRequest.Messages[messageLength-1].Content = fmt.Sprintf("%s\n以上是一些可能有帮助的参考信息你可以自行选择是否使用这些参考信息现在请回答以下问题\n%s", doc, rawContent)
newBody, _ := json.Marshal(rawRequest)
// log.Info(string(newBody))
proxywasm.ReplaceHttpRequestBody(newBody)
recallDocIds := []string{}
recallDocs := []string{}
for _, output := range response.Output {
log.Debugf("Score: %f, Doc: %s", output.Score, output.Fields.Raw)
if output.Score <= float32(config.DashVectorThreshold) {
recallDocs = append(recallDocs, output.Fields.Raw)
recallDocIds = append(recallDocIds, output.ID)
}
}
if len(recallDocs) > 0 {
rawRequest.Messages = rawRequest.Messages[:messageLength-1]
traceStr := strings.Join(recallDocIds, ", ")
proxywasm.SetProperty([]string{"trace_span_tag.rag_docs"}, []byte(traceStr))
for _, doc := range recallDocs {
rawRequest.Messages = append(rawRequest.Messages, Message{"user", doc})
}
rawRequest.Messages = append(rawRequest.Messages, Message{"user", fmt.Sprintf("现在,请回答以下问题:\n%s", rawContent)})
newBody, _ := json.Marshal(rawRequest)
proxywasm.ReplaceHttpRequestBody(newBody)
ctx.SetContext("x-envoy-rag-recall", true)
}
proxywasm.ResumeHttpRequest()
},
)
@@ -124,3 +164,13 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte,
)
return types.ActionPause
}
func onHttpResponseHeaders(ctx wrapper.HttpContext, config AIRagConfig, log wrapper.Log) types.Action {
recall, ok := ctx.GetContext("x-envoy-rag-recall").(bool)
if ok && recall {
proxywasm.AddHttpResponseHeader("x-envoy-rag-recall", "true")
} else {
proxywasm.AddHttpResponseHeader("x-envoy-rag-recall", "false")
}
return types.ActionContinue
}

View File

@@ -3,11 +3,12 @@ package main
import (
"errors"
"fmt"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
re "github.com/wasilibs/go-re2"
"github.com/zmap/go-iptree/iptree"
"strings"
)
// 限流规则项类型

View File

@@ -5,8 +5,7 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc h1:t2AT8zb6N/59Y78lyRWedVoVWHNRSCBh0oWCC+bluTQ=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=

View File

@@ -16,15 +16,16 @@ package main
import (
"fmt"
"net"
"net/url"
"strconv"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"github.com/tidwall/resp"
"net"
"net/url"
"strconv"
"strings"
)
func main() {
@@ -88,12 +89,10 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitCon
args := []interface{}{configItem.count, configItem.timeWindow}
// 执行限流逻辑
err := config.redisClient.Eval(FixedWindowScript, 1, keys, args, func(response resp.Value) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
resultArray := response.Array()
if len(resultArray) != 3 {
log.Errorf("redis response parse error, response: %v", response)
proxywasm.ResumeHttpRequest()
return
}
context := LimitContext{
@@ -106,6 +105,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config ClusterKeyRateLimitCon
rejected(config, context)
} else {
ctx.SetContext(LimitContextKey, context)
proxywasm.ResumeHttpRequest()
}
})
if err != nil {

View File

@@ -2,9 +2,10 @@ package main
import (
"fmt"
"github.com/zmap/go-iptree/iptree"
"sort"
"strings"
"github.com/zmap/go-iptree/iptree"
)
// parseIPNet 解析Ip段配置

View File

@@ -18,8 +18,9 @@ import (
"errors"
"fmt"
"net/url"
"regexp"
"strings"
regexp "github.com/wasilibs/go-re2"
)
const (

View File

@@ -9,16 +9,19 @@ require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
github.com/stretchr/testify v1.8.4
github.com/tidwall/gjson v1.14.4
github.com/wasilibs/go-re2 v1.6.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
github.com/magefile/mage v1.14.0 // indirect
github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/resp v0.1.1 // indirect
golang.org/x/sys v0.21.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -4,17 +4,16 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240318034951-d5306e367c43/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a h1:tdPcGgyiH0K+SbsJBBm2oPyEIOTAvLBwD9TuUwVtZho=
github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
@@ -23,6 +22,11 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/wasilibs/go-re2 v1.6.0 h1:CLlhDebt38wtl/zz4ww+hkXBMcxjrKFvTDXzFW2VOz8=
github.com/wasilibs/go-re2 v1.6.0/go.mod h1:prArCyErsypRBI/jFAFJEbzyHzjABKqkzlidF0SNA04=
github.com/wasilibs/nottinygc v0.4.0 h1:h1TJMihMC4neN6Zq+WKpLxgd9xCFMw7O9ETLwY2exJQ=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -17,8 +17,9 @@ package config
import (
"errors"
"net/url"
"regexp"
"strings"
regexp "github.com/wasilibs/go-re2"
)
const (

View File

@@ -9,16 +9,19 @@ require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
github.com/stretchr/testify v1.8.4
github.com/tidwall/gjson v1.14.4
github.com/wasilibs/go-re2 v1.6.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
github.com/magefile/mage v1.14.0 // indirect
github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/resp v0.1.1 // indirect
golang.org/x/sys v0.21.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -4,17 +4,16 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240318034951-d5306e367c43/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a h1:tdPcGgyiH0K+SbsJBBm2oPyEIOTAvLBwD9TuUwVtZho=
github.com/magefile/mage v1.15.1-0.20230912152418-9f54e0f83e2a/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
@@ -23,6 +22,11 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/wasilibs/go-re2 v1.6.0 h1:CLlhDebt38wtl/zz4ww+hkXBMcxjrKFvTDXzFW2VOz8=
github.com/wasilibs/go-re2 v1.6.0/go.mod h1:prArCyErsypRBI/jFAFJEbzyHzjABKqkzlidF0SNA04=
github.com/wasilibs/nottinygc v0.4.0 h1:h1TJMihMC4neN6Zq+WKpLxgd9xCFMw7O9ETLwY2exJQ=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -0,0 +1 @@
EXTRA_TAGS=proxy_wasm_version_0_2_100

View File

@@ -1,10 +1,16 @@
# 功能说明
---
title: 外部认证
keywords: [higress, auth]
description: Ext 认证插件实现了调用外部授权服务进行认证鉴权的功能。
---
## 功能说明
`ext-auth` 插件实现了向外部授权服务发送鉴权请求以检查客户端请求是否得到授权。该插件实现时参考了Envoy原生的[ext_authz filter](https://www.envoyproxy.io/docs/envoy/latest/configuration/http/http_filters/ext_authz_filter)实现了原生filter中对接HTTP服务的部分能力
# 配置字段
## 配置字段
| 名称 | 数据类型 | 必填 | 默认值 | 描述 |
| ------------------------------- | -------- | ---- | ------ |------------------------------------------------------------------------------------------------------------------------------------------------------------|
@@ -17,29 +23,30 @@
| 名称 | 数据类型 | 必填 | 默认值 | 描述 |
| ------------------------ | -------- | ---- | ------ | ------------------------------------- |
| `endpoint_mode` | string | 否 | envoy | `envoy` , `forward_auth` 中选填一项 |
| `endpoint` | object | 是 | - | 发送鉴权请求的 HTTP 服务信息 |
| `timeout` | int | 否 | 200 | `ext-auth` 服务连接超时时间,单位毫秒 |
| `timeout` | int | 否 | 1000 | `ext-auth` 服务连接超时时间,单位毫秒 |
| `authorization_request` | object | 否 | - | 发送鉴权请求配置 |
| `authorization_response` | object | 否 | - | 处理鉴权响应配置 |
| `authorization_response` | object | 否 | - | 处理鉴权响应配置 |
`endpoint`中每一项的配置字段说明
| 名称 | 数据类型 | 必填 | 默认值 | 描述 |
| ---------------- | -------- | ---- | ------ | --------------------------------------------------- |
| `service_source` | string | 是 | - | 类型为固定 ip 或者 dns输入授权服务的注册来源 |
| `service_name` | string | 是 | - | 输入授权服务的注册名称 |
| `service_port` | string | 是 | - | 输入授权服务的服务端口 |
| `service_domain` | string | 否 | - | 当类型为dns时必须填写输入 `ext-auth` 服务的domain |
| `request_method` | string | 否 | GET | 客户端向授权服务发送请求的HTTP Method |
| `path` | string | 是 | - | 输入授权服务的请求路径 |
| 名称 | 数据类型 | 必填 | 默认值 | 描述 |
| -------- | -------- | -- | ------ |-----------------------------------------------------------------------------------------|
| `service_name` | string | 必填 | - | 输入授权服务名称,带服务类型的完整 FQDN 名称,例如 `ext-auth.dns``ext-auth.my-ns.svc.cluster.local` |
| `service_port` | int | 否 | 80 | 输入授权服务的服务端口 |
| `path_prefix` | string | `endpoint_mode``envoy`时必填 | | `endpoint_mode``envoy` 时,客户端向授权服务发送请求的请求路径前缀 |
| `request_method` | string | 否 | GET | `endpoint_mode``forward_auth`客户端向授权服务发送请求的HTTP Method |
| `path` | string | `endpoint_mode``forward_auth`时必填 | - | `endpoint_mode``forward_auth` 时,客户端向授权服务发送请求的请求路径 |
`authorization_request`中每一项的配置字段说明
| 名称 | 数据类型 | 必填 | 默认值 | 描述 |
| ------------------- | ---------------------- | ---- | ------ | ------------------------------------------------------------ |
| `allowed_headers` | array of StringMatcher | 否 | - | 当设置后,具有相应匹配项的客户端请求头将添加到授权服务请求中的请求头中。除了用户自定义的头部匹配规则外,授权服务请求中会自动包含`Host`, `Method`, `Path`, `Content-Length``Authorization`这几个关键的HTTP头 |
| `headers_to_add` | `map[string]string` | 否 | - | 设置将包含在授权服务请求中的请求头列表。请注意,同名的客户端请求头将被覆盖 |
| `with_request_body` | bool | 否 | false | 缓冲客户端请求体并将其发送至鉴权请求中HTTP Method为GET、OPTIONS、HEAD请求时不生效 |
| 名称 | 数据类型 | 必填 | 默认值 | 描述 |
| ------------------------ | ---------------------- | ---- | ------ |---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `allowed_headers` | array of StringMatcher | 否 | - | 当设置后,具有相应匹配项的客户端请求头将添加到授权服务请求中的请求头中。除了用户自定义的头部匹配规则外,授权服务请求中会自动包含 `Authorization` 这个HTTP头 `endpoint_mode` `forward_auth` 时,会把原始请求的请求路径设置到 `X-Original-Uri` 原始请求的HTTP Method设置到 `X-Original-Method` |
| `headers_to_add` | `map[string]string` | 否 | - | 设置将包含在授权服务请求中的请求头列表。请注意,同名的客户端请求头将被覆盖 |
| `with_request_body` | bool | 否 | false | 缓冲客户端请求体并将其发送至鉴权请求中HTTP Method为GET、OPTIONS、HEAD请求时不生效 |
| `max_request_body_bytes` | int | 否 | 10MB | 设置在内存中保存客户端请求体的最大尺寸。当客户端请求体达到在此字段中设置的数值时将会返回HTTP 413状态码并且不会启动授权过程。注意这个设置会优先于 `failure_mode_allow` 的配置 |
`authorization_response`中每一项的配置字段说明
@@ -60,30 +67,35 @@
# 配置示例
## 配置示例
下面假设 `ext-auth` 服务在Kubernetes中serviceName为 `ext-auth`,端口 `8090`,路径为 `/auth`,命名空间为 `backend`
## 示例1
支持两种 `endpoint_mode`
- `endpoint_mode``envoy`鉴权请求会使用原始请求的HTTP Method和配置的 `path_prefix` 作为请求路径前缀拼接上原始的请求路径
- `endpoint_mode``forward_auth` 时,鉴权请求会使用配置的 `request_method` 作为HTTP Method和配置的 `path` 作为请求路径
### endpoint_mode为envoy时
#### 示例1
`ext-auth` 插件的配置:
```yaml
http_service:
endpoint_mode: envoy
endpoint:
service_name: ext-auth
namespace: backend
service_name: ext-auth.backend.svc.cluster.local
service_port: 8090
service_source: k8s
path: /auth
request_method: POST
timeout: 500
path_prefix: /auth
timeout: 1000
```
使用如下请求网关,当开启 `ext-auth` 插件后:
```shell
curl -i http://localhost:8082/users -X GET -H "foo: bar" -H "Authorization: xxx"
curl -X POST http://localhost:8082/users?apikey=9a342114-ba8a-11ec-b1bf-00163e1250b5 -X GET -H "foo: bar" -H "Authorization: xxx"
```
**请求 `ext-auth` 服务成功:**
@@ -91,7 +103,7 @@ curl -i http://localhost:8082/users -X GET -H "foo: bar" -H "Authorization: xxx"
`ext-auth` 服务将接收到如下的鉴权请求:
```
POST /auth HTTP/1.1
POST /auth/users?apikey=9a342114-ba8a-11ec-b1bf-00163e1250b5 HTTP/1.1
Host: ext-auth
Authorization: xxx
Content-Length: 0
@@ -116,8 +128,7 @@ content-length: 0
`ext-auth` 服务返回其他 HTTP 状态码时,将以返回的状态码拒绝客户端请求。如果配置了 `allowed_client_headers`,具有相应匹配项的响应头将添加到客户端的响应中
## 示例2
#### 示例2
`ext-auth` 插件的配置:
@@ -132,26 +143,24 @@ http_service:
allowed_upstream_headers:
- exact: x-user-id
- exact: x-auth-version
endpoint_mode: envoy
endpoint:
service_name: ext-auth
namespace: backend
service_name: ext-auth.backend.svc.cluster.local
service_port: 8090
service_source: k8s
path: /auth
request_method: POST
timeout: 500
path_prefix: /auth
timeout: 1000
```
使用如下请求网关,当开启 `ext-auth` 插件后:
```shell
curl -i http://localhost:8082/users -X GET -H "foo: bar" -H "Authorization: xxx" -H "X-Auth-Version: 1.0"
curl -X POST http://localhost:8082/users?apikey=9a342114-ba8a-11ec-b1bf-00163e1250b5 -X GET -H "foo: bar" -H "Authorization: xxx"
```
`ext-auth` 服务将接收到如下的鉴权请求:
```
POST /auth HTTP/1.1
POST /auth/users?apikey=9a342114-ba8a-11ec-b1bf-00163e1250b5 HTTP/1.1
Host: ext-auth
Authorization: xxx
X-Auth-Version: 1.0
@@ -160,3 +169,116 @@ Content-Length: 0
```
`ext-auth` 服务返回响应头中如果包含 `x-user-id``x-auth-version`网关调用upstream时的请求中会带上这两个请求头
### endpoint_mode为forward_auth时
#### 示例1
`ext-auth` 插件的配置:
```yaml
http_service:
endpoint_mode: forward_auth
endpoint:
service_name: ext-auth.backend.svc.cluster.local
service_port: 8090
path: /auth
request_method: POST
timeout: 1000
```
使用如下请求网关,当开启 `ext-auth` 插件后:
```shell
curl -i http://localhost:8082/users?apikey=9a342114-ba8a-11ec-b1bf-00163e1250b5 -X GET -H "foo: bar" -H "Authorization: xxx"
```
**请求 `ext-auth` 服务成功:**
`ext-auth` 服务将接收到如下的鉴权请求:
```
POST /auth HTTP/1.1
Host: ext-auth
Authorization: xxx
X-Original-Uri: /users?apikey=9a342114-ba8a-11ec-b1bf-00163e1250b5
X-Original-Method: GET
Content-Length: 0
```
**请求 `ext-auth` 服务失败:**
当调用 `ext-auth` 服务响应为 5xx 时客户端将接收到HTTP响应码403和 `ext-auth` 服务返回的全量响应头
假如 `ext-auth` 服务返回了 `x-auth-version: 1.0``x-auth-failed: true` 的响应头,会传递给客户端
```
HTTP/1.1 403 Forbidden
x-auth-version: 1.0
x-auth-failed: true
date: Tue, 16 Jul 2024 00:19:41 GMT
server: istio-envoy
content-length: 0
```
`ext-auth` 无法访问或状态码为 5xx 时,将以 `status_on_error` 配置的状态码拒绝客户端请求
`ext-auth` 服务返回其他 HTTP 状态码时,将以返回的状态码拒绝客户端请求。如果配置了 `allowed_client_headers`,具有相应匹配项的响应头将添加到客户端的响应中
#### 示例2
`ext-auth` 插件的配置:
```yaml
http_service:
authorization_request:
allowed_headers:
- exact: x-auth-version
headers_to_add:
x-envoy-header: true
authorization_response:
allowed_upstream_headers:
- exact: x-user-id
- exact: x-auth-version
endpoint_mode: forward_auth
endpoint:
service_name: ext-auth.backend.svc.cluster.local
service_port: 8090
path: /auth
request_method: POST
timeout: 1000
```
使用如下请求网关,当开启 `ext-auth` 插件后:
```shell
curl -i http://localhost:8082/users?apikey=9a342114-ba8a-11ec-b1bf-00163e1250b5 -X GET -H "foo: bar" -H "Authorization: xxx" -H "X-Auth-Version: 1.0"
```
`ext-auth` 服务将接收到如下的鉴权请求:
```
POST /auth HTTP/1.1
Host: ext-auth
Authorization: xxx
X-Original-Uri: /users?apikey=9a342114-ba8a-11ec-b1bf-00163e1250b5
X-Original-Method: GET
X-Auth-Version: 1.0
x-envoy-header: true
Content-Length: 0
```
`ext-auth` 服务返回响应头中如果包含 `x-user-id``x-auth-version`网关调用upstream时的请求中会带上这两个请求头
#### x-forwarded-* header
在endpoint_mode为forward_auth时higress会自动生成并发送以下header至鉴权服务。
| Header | 说明 |
|--------------------|-------------------------------------|
| x-forwarded-proto | 原始请求的scheme比如http/https |
| x-forwarded-method | 原始请求的方法比如get/post/delete/patch |
| x-forwarded-host | 原始请求的host |
| x-forwarded-uri | 原始请求的path包含路径参数比如/v1/app?test=true |
| x-forwarded-for | 原始请求的客户端IP地址 |

View File

@@ -2,18 +2,25 @@ package main
import (
"errors"
"ext-auth/expr"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
"net/http"
"strings"
"ext-auth/expr"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
)
const (
DefaultStatusOnError uint32 = http.StatusForbidden
DefaultHttpServiceTimeout uint32 = 200
DefaultHttpServiceTimeout uint32 = 1000
DefaultMaxRequestBodyBytes uint32 = 10 * 1024 * 1024
EndpointModeEnvoy = "envoy"
EndpointModeForwardAuth = "forward_auth"
)
type ExtAuthConfig struct {
@@ -24,8 +31,13 @@ type ExtAuthConfig struct {
}
type HttpService struct {
client wrapper.HttpClient
requestMethod string
endpointMode string
client wrapper.HttpClient
// pathPrefix is only used when endpoint_mode is envoy
pathPrefix string
// requestMethod is only used when endpoint_mode is forward_auth
requestMethod string
// path is only used when endpoint_mode is forward_auth
path string
timeout uint32
authorizationRequest AuthorizationRequest
@@ -34,10 +46,14 @@ type HttpService struct {
type AuthorizationRequest struct {
// allowedHeaders In addition to the users supplied matchers,
// Host, Method, Path, Content-Length, and Authorization are automatically included to the list.
allowedHeaders expr.Matcher
headersToAdd map[string]string
withRequestBody bool
// Authorization are automatically included to the list.
// When the endpoint_mode is set to forward_auth,
// the original request's path is set in the X-Original-Uri header,
// and the original request's HTTP method is set in the X-Original-Method header.
allowedHeaders expr.Matcher
headersToAdd map[string]string
withRequestBody bool
maxRequestBodyBytes uint32
}
type AuthorizationResponse struct {
@@ -50,7 +66,7 @@ func parseConfig(json gjson.Result, config *ExtAuthConfig, log wrapper.Log) erro
if !httpServiceConfig.Exists() {
return errors.New("missing http_service in config")
}
err := parseHttpServiceConfig(httpServiceConfig, config)
err := parseHttpServiceConfig(httpServiceConfig, config, log)
if err != nil {
return err
}
@@ -65,20 +81,19 @@ func parseConfig(json gjson.Result, config *ExtAuthConfig, log wrapper.Log) erro
config.failureModeAllowHeaderAdd = failureModeAllowHeaderAdd.Bool()
}
statusOnError := json.Get("status_on_error")
if statusOnError.Exists() {
config.statusOnError = uint32(statusOnError.Uint())
} else {
config.statusOnError = DefaultStatusOnError
statusOnError := uint32(json.Get("status_on_error").Uint())
if statusOnError == 0 {
statusOnError = DefaultStatusOnError
}
config.statusOnError = statusOnError
return nil
}
func parseHttpServiceConfig(json gjson.Result, config *ExtAuthConfig) error {
func parseHttpServiceConfig(json gjson.Result, config *ExtAuthConfig, log wrapper.Log) error {
var httpService HttpService
if err := parseEndpointConfig(json, &httpService); err != nil {
if err := parseEndpointConfig(json, &httpService, log); err != nil {
return err
}
@@ -101,64 +116,63 @@ func parseHttpServiceConfig(json gjson.Result, config *ExtAuthConfig) error {
return nil
}
func parseEndpointConfig(json gjson.Result, httpService *HttpService) error {
func parseEndpointConfig(json gjson.Result, httpService *HttpService, log wrapper.Log) error {
endpointMode := json.Get("endpoint_mode").String()
if endpointMode == "" {
endpointMode = EndpointModeEnvoy
} else if endpointMode != EndpointModeEnvoy && endpointMode != EndpointModeForwardAuth {
return errors.New(fmt.Sprintf("endpoint_mode %s is not supported", endpointMode))
}
httpService.endpointMode = endpointMode
endpointConfig := json.Get("endpoint")
if !endpointConfig.Exists() {
return errors.New("missing endpoint in config")
}
serviceSource := endpointConfig.Get("service_source").String()
serviceName := endpointConfig.Get("service_name").String()
if serviceName == "" {
return errors.New("endpoint service name must not be empty")
}
servicePort := endpointConfig.Get("service_port").Int()
if serviceName == "" || servicePort == 0 {
return errors.New("invalid service config")
}
switch serviceSource {
case "k8s":
namespace := json.Get("namespace").String()
httpService.client = wrapper.NewClusterClient(wrapper.K8sCluster{
ServiceName: serviceName,
Namespace: namespace,
Port: servicePort,
})
return nil
case "nacos":
namespace := json.Get("namespace").String()
httpService.client = wrapper.NewClusterClient(wrapper.NacosCluster{
ServiceName: serviceName,
NamespaceID: namespace,
Port: servicePort,
})
return nil
case "ip":
httpService.client = wrapper.NewClusterClient(wrapper.StaticIpCluster{
ServiceName: serviceName,
Port: servicePort,
})
case "dns":
domain := endpointConfig.Get("domain").String()
httpService.client = wrapper.NewClusterClient(wrapper.DnsCluster{
ServiceName: serviceName,
Port: servicePort,
Domain: domain,
})
default:
return errors.New("unknown service source: " + serviceSource)
if servicePort == 0 {
servicePort = 80
}
requestMethodConfig := endpointConfig.Get("request_method")
if !requestMethodConfig.Exists() {
httpService.requestMethod = http.MethodGet
} else {
httpService.requestMethod = strings.ToUpper(requestMethodConfig.String())
}
httpService.client = wrapper.NewClusterClient(wrapper.FQDNCluster{
FQDN: serviceName,
Port: servicePort,
})
pathConfig := endpointConfig.Get("path")
if !pathConfig.Exists() {
return errors.New("missing path in config")
}
httpService.path = pathConfig.String()
switch endpointMode {
case EndpointModeEnvoy:
pathPrefixConfig := endpointConfig.Get("path_prefix")
if !pathPrefixConfig.Exists() {
return errors.New("when endpoint_mode is envoy, endpoint path_prefix must not be empty")
}
httpService.pathPrefix = pathPrefixConfig.String()
if endpointConfig.Get("request_method").Exists() || endpointConfig.Get("path").Exists() {
log.Warn("when endpoint_mode is envoy, endpoint request_method and path will be ignored")
}
case EndpointModeForwardAuth:
requestMethodConfig := endpointConfig.Get("request_method")
if !requestMethodConfig.Exists() {
httpService.requestMethod = http.MethodGet
} else {
httpService.requestMethod = strings.ToUpper(requestMethodConfig.String())
}
pathConfig := endpointConfig.Get("path")
if !pathConfig.Exists() {
return errors.New("when endpoint_mode is forward_auth, endpoint path must not be empty")
}
httpService.path = pathConfig.String()
if endpointConfig.Get("path_prefix").Exists() {
log.Warn("when endpoint_mode is forward_auth, endpoint path_prefix will be ignored")
}
}
return nil
}
@@ -167,6 +181,15 @@ func parseAuthorizationRequestConfig(json gjson.Result, httpService *HttpService
if authorizationRequestConfig.Exists() {
var authorizationRequest AuthorizationRequest
allowedHeaders := authorizationRequestConfig.Get("allowed_headers")
if allowedHeaders.Exists() {
result, err := expr.BuildRepeatedStringMatcherIgnoreCase(allowedHeaders.Array())
if err != nil {
return err
}
authorizationRequest.allowedHeaders = result
}
headersToAdd := map[string]string{}
headersToAddConfig := authorizationRequestConfig.Get("headers_to_add")
if headersToAddConfig.Exists() {
@@ -186,14 +209,11 @@ func parseAuthorizationRequestConfig(json gjson.Result, httpService *HttpService
authorizationRequest.withRequestBody = withRequestBody.Bool()
}
allowedHeaders := authorizationRequestConfig.Get("allowed_headers")
if allowedHeaders.Exists() {
result, err := expr.BuildRepeatedStringMatcherIgnoreCase(allowedHeaders.Array())
if err != nil {
return err
}
authorizationRequest.allowedHeaders = result
maxRequestBodyBytes := uint32(authorizationRequestConfig.Get("max_request_body_bytes").Uint())
if maxRequestBodyBytes == 0 {
maxRequestBodyBytes = DefaultMaxRequestBodyBytes
}
authorizationRequest.maxRequestBodyBytes = maxRequestBodyBytes
httpService.authorizationRequest = authorizationRequest
}

View File

@@ -2,9 +2,10 @@ package expr
import (
"errors"
"github.com/tidwall/gjson"
"regexp"
"strings"
"github.com/tidwall/gjson"
regexp "github.com/wasilibs/go-re2"
)
const (

Some files were not shown because too many files have changed in this diff Show More