[ai-json-resp] Extract JSON from LLM, Validate with Schema, Ensure Valid JSON, Auto-Retry (#1236)

This commit is contained in:
Yang Beining
2024-09-03 04:10:33 +01:00
committed by GitHub
parent 7b2b522160
commit ffc0c0976f
6 changed files with 1035 additions and 0 deletions


@@ -0,0 +1,202 @@
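## Introduction
**Note**
> The data plane's proxy-wasm version must be 0.2.100 or later.
>
> When compiling, add the version tag, for example: `tinygo build -o main.wasm -scheduler=none -target=wasi -gc=custom -tags="custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100" ./`

The LLM response structuring plugin structures the AI's response according to a default or user-configured JSON Schema so that downstream plugins can process it. Note that only `non-streaming responses` are currently supported.
### Configuration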
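| Name | Type | Requirement | Default | **Description** |
| --- | --- | --- | --- | --- |
| serviceName | str | required | - | Name of the AI service, or of a gateway service that supports AI-Proxy |
| serviceDomain | str | optional | - | Domain or IP address of the AI service, or of a gateway service that supports AI-Proxy |
| servicePath | str | optional | '/v1/chat/completions' | Base path of the AI service, or of a gateway service that supports AI-Proxy |
| serviceUrl | str | optional | - | URL of the AI service, or of a gateway service that supports AI-Proxy. The plugin extracts the domain and path from it and uses them to fill in serviceDomain or servicePath when they are not configured |
| servicePort | int | optional | 443 | Port of the service |
| serviceTimeout | int | optional | 50000 | Default request timeout, in milliseconds |
| maxRetry | int | optional | 3 | Number of retries when a correctly formatted answer cannot be extracted |
| contentPath | str | optional | "choices.0.message.content" | GJSON path used to extract the result from the LLM response |
| jsonSchema | str (json) | optional | - | JSON Schema that the response is validated against. If empty, the plugin only verifies that the response is valid JSON and returns it |
| enableSwagger | bool | optional | false | Whether to validate using the Swagger specification |
| enableOas3 | bool | optional | true | Whether to validate using the OAS3 specification |
| enableContentDisposition | bool | optional | true | Whether to add the Content-Disposition header. If enabled, `Content-Disposition: attachment; filename="response.json"` is added to the response headers |
> For performance reasons, the maximum supported JSON Schema depth is 6 by default. A JSON Schema deeper than this is not used for validation; the plugin only checks whether the response is valid JSON.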
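
The depth limit in the note above can be illustrated with a small sketch. It counts the root as level 1 and lets every nested object, array, or scalar add a level, which is the convention the `GetMaxDepth` helper in this commit appears to follow. The names `schemaDepth` and `depthOf` are hypothetical and chosen only for illustration.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// schemaDepth returns the deepest nesting level of a decoded JSON value.
// The value itself counts as one level; objects and arrays add the depth
// of their deepest child on top of that.
func schemaDepth(v interface{}) int {
	deepest := 0
	switch t := v.(type) {
	case map[string]interface{}:
		for _, child := range t {
			if d := schemaDepth(child); d > deepest {
				deepest = d
			}
		}
	case []interface{}:
		for _, child := range t {
			if d := schemaDepth(child); d > deepest {
				deepest = d
			}
		}
	}
	return deepest + 1
}

// depthOf decodes raw JSON and reports its nesting depth.
func depthOf(raw string) (int, error) {
	var v interface{}
	if err := json.Unmarshal([]byte(raw), &v); err != nil {
		return 0, err
	}
	return schemaDepth(v), nil
}

func main() {
	const maxDepth = 6 // limit quoted in the note above
	schema := `{"type":"object","properties":{"answer":{"type":"string"}}}`
	d, err := depthOf(schema)
	if err != nil {
		fmt.Println("invalid schema:", err)
		return
	}
	// A schema deeper than maxDepth would only get the "is it valid JSON" check.
	fmt.Printf("depth=%d validate=%v\n", d, d <= maxDepth)
}
```

Under this counting, even a modest schema uses several levels, so schemas with nested arrays of objects can reach the limit quickly.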
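### Request and response parameters
- **Request parameters**: The plugin accepts requests in the OpenAI format, containing the `model` and `messages` fields. `model` is the name of the AI model and `messages` is the list of conversation messages. Each message has a `role` field (the message role) and a `content` field (the message text).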
```json
{
  "model": "gpt-4",
  "messages": [
    {"role": "user", "content": "give me a api doc for add the variable x to x+5"}
  ]
}
```
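For other request parameters, refer to the documentation of the configured AI service or gateway service.
- **Response parameters**:
  - A `JSON response` that satisfies the constraints of the configured JSON Schema.
  - If no JSON Schema is configured, a valid `JSON response`.
  - If an internal error occurs, `{ "Code": 10XX, "Msg": "error message" }`.
## Request example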
```bash
curl -X POST "http://localhost:8001/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "gpt-4",
    "messages": [
      {"role": "user", "content": "give me a api doc for add the variable x to x+5"}
    ]
  }'
```
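
For callers that build the request in code rather than with curl, the same body can be produced as in the sketch below. The types `message` and `chatRequest` and the function `buildChatRequest` are hypothetical names; only the `model`, `messages`, `role`, and `content` fields come from the request format described above. Sending the body over HTTP is left out.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// message mirrors one entry of the OpenAI-style "messages" list.
type message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// chatRequest carries only the two fields the README requires.
type chatRequest struct {
	Model    string    `json:"model"`
	Messages []message `json:"messages"`
}

// buildChatRequest serializes a single-turn user prompt for the given model.
func buildChatRequest(model, prompt string) ([]byte, error) {
	return json.Marshal(chatRequest{
		Model:    model,
		Messages: []message{{Role: "user", Content: prompt}},
	})
}

func main() {
	body, err := buildChatRequest("gpt-4", "give me a api doc for add the variable x to x+5")
	if err != nil {
		fmt.Println("marshal failed:", err)
		return
	}
	fmt.Println(string(body))
}
```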
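## Response examples
### Normal response
Under normal conditions, the plugin returns JSON that has been validated against the JSON Schema. If no JSON Schema is configured, it returns any syntactically valid JSON.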
```json
{
  "apiVersion": "1.0",
  "request": {
    "endpoint": "/add_to_five",
    "method": "POST",
    "port": 8080,
    "headers": {
      "Content-Type": "application/json"
    },
    "body": {
      "x": 7
    }
  }
}
```
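
LLM replies often wrap the JSON in extra prose. A minimal sketch of the extraction idea, assuming the same heuristic as the `ExtractJson` method in this commit (take everything from the first `{` to the last `}` and check that it parses), might look like this. The function name `extractJSON` is hypothetical.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"strings"
)

// extractJSON returns the substring between the first "{" and the last "}"
// of text, provided that substring parses as a JSON object.
func extractJSON(text string) (string, error) {
	start := strings.Index(text, "{")
	end := strings.LastIndex(text, "}")
	if start == -1 || end == -1 || start > end {
		return "", errors.New("no JSON object found")
	}
	candidate := text[start : end+1]
	var obj map[string]interface{}
	if err := json.Unmarshal([]byte(candidate), &obj); err != nil {
		return "", err
	}
	return candidate, nil
}

func main() {
	reply := "Sure! Here is the document: {\"x\": 7} Let me know if you need more."
	out, err := extractJSON(reply)
	fmt.Println(out, err)
}
```

The heuristic is deliberately simple: a reply containing two separate JSON objects would fail to parse, which in the plugin leads to a retry rather than a partial result.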
### Error response
When an error occurs, the HTTP status code is `500` and the body is a JSON error message with two fields: the error code `Code` and the error message `Msg`.
```json
{
  "Code": 1006,
  "Msg": "retry count exceed max retry count"
}
```
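
A client can tell a plugin error apart from a normal payload by decoding the two fields shown above. The sketch below is one way to do that; `pluginError` and `parsePluginError` are hypothetical names, and the accepted range 1001 to 1008 is taken from the error code table in this README.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// pluginError mirrors the error body: {"Code": 10XX, "Msg": "..."}.
type pluginError struct {
	Code int    `json:"Code"`
	Msg  string `json:"Msg"`
}

// parsePluginError decodes body and reports whether it looks like one of the
// plugin's documented errors (codes 1001 through 1008).
func parsePluginError(body []byte) (pluginError, bool) {
	var e pluginError
	if err := json.Unmarshal(body, &e); err != nil {
		return pluginError{}, false
	}
	return e, e.Code >= 1001 && e.Code <= 1008
}

func main() {
	body := []byte(`{"Code": 1006, "Msg": "retry count exceed max retry count"}`)
	if e, ok := parsePluginError(body); ok {
		fmt.Printf("plugin error %d: %s\n", e.Code, e.Msg)
	}
}
```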
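### Error codes
| Error code | Description |
| --- | --- |
| 1001 | The configured JSON Schema is not valid JSON |
| 1002 | The configured JSON Schema failed to compile: it is not a valid JSON Schema, or its depth exceeds jsonSchemaMaxDepth while rejectOnDepthExceeded is true |
| 1003 | No valid JSON could be extracted from the response |
| 1004 | The response is an empty string |
| 1005 | The response does not conform to the JSON Schema |
| 1006 | The retry count exceeded the maximum |
| 1007 | The response content could not be obtained; the upstream service may be misconfigured, or the contentPath used to extract the content may be wrong |
| 1008 | serviceDomain is empty; note that serviceDomain and serviceUrl must not both be empty |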
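## Service configuration
The plugin needs an upstream service so that it can retry automatically when a response fails extraction or validation. Two kinds of upstream are supported: an `AI service exposing the OpenAI API` or a `local gateway service`.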
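
The retry behaviour can be sketched as a loop: ask once, and while the reply is not a JSON object, ask the model to convert its previous reply, up to `maxRetry` times. In the sketch below `askFunc` is a hypothetical stand-in for the call to the upstream service, and `askForJSON` and `isJSONObject` are illustrative names; the real plugin works with asynchronous callbacks and also validates against the schema.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

// askFunc is a hypothetical stand-in for one call to the upstream AI service.
type askFunc func(prompt string) string

// isJSONObject reports whether s parses as a JSON object (not null, not a scalar).
func isJSONObject(s string) bool {
	var obj map[string]interface{}
	return json.Unmarshal([]byte(s), &obj) == nil && obj != nil
}

// askForJSON sends the prompt once and then retries up to maxRetry times,
// each time asking the model to turn its previous reply into pure JSON.
func askForJSON(ask askFunc, prompt string, maxRetry int) (string, error) {
	reply := ask(prompt)
	for retry := 0; !isJSONObject(reply); retry++ {
		if retry >= maxRetry {
			return "", errors.New("retry count exceeds max retry count")
		}
		reply = ask("Please convert the following content to a pure json: " + reply)
	}
	return reply, nil
}

func main() {
	// A fake model that answers in prose first and in JSON on the second call.
	replies := []string{"x plus five is twelve", `{"x": 7}`}
	calls := 0
	fake := func(string) string {
		r := replies[calls%len(replies)]
		calls++
		return r
	}
	out, err := askForJSON(fake, "add 5 to x", 3)
	fmt.Println(out, err, calls)
}
```

With `maxRetry` set to 3, the upstream service is called at most four times per request: the original call plus three retries.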
### AI service exposing the OpenAI API
Taking qwen as an example, the basic configuration is as follows.

YAML format configuration:
```yaml
serviceName: qwen
serviceDomain: dashscope.aliyuncs.com
apiKey: [Your API Key]
servicePath: /compatible-mode/v1/chat/completions
jsonSchema:
  title: ReasoningSchema
  type: object
  properties:
    reasoning_steps:
      type: array
      items:
        type: string
      description: The reasoning steps leading to the final conclusion.
    answer:
      type: string
      description: The final answer, taking into account the reasoning steps.
  required:
    - reasoning_steps
    - answer
  additionalProperties: false
```
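
To make the constraints of `ReasoningSchema` concrete, the sketch below checks a response by hand using only the standard library. It is not a JSON Schema validator and is not how the plugin validates (the plugin compiles the schema with a JSON Schema library); it only approximates `required` with pointer fields and `additionalProperties: false` with `DisallowUnknownFields`. The names `reasoning` and `checkReasoning` are hypothetical.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"strings"
)

// reasoning uses pointer fields so that a missing key can be told apart
// from an empty value.
type reasoning struct {
	ReasoningSteps *[]string `json:"reasoning_steps"`
	Answer         *string   `json:"answer"`
}

// checkReasoning approximates the ReasoningSchema constraints:
// both fields are required and no other fields are allowed.
func checkReasoning(raw string) error {
	dec := json.NewDecoder(strings.NewReader(raw))
	dec.DisallowUnknownFields() // approximates additionalProperties: false
	var r reasoning
	if err := dec.Decode(&r); err != nil {
		return err
	}
	if r.ReasoningSteps == nil || r.Answer == nil {
		return errors.New("missing required field")
	}
	return nil
}

func main() {
	good := `{"reasoning_steps": ["x is 7", "7 + 5 = 12"], "answer": "12"}`
	bad := `{"answer": "12"}`
	fmt.Println(checkReasoning(good))
	fmt.Println(checkReasoning(bad))
}
```

One known gap in this approximation: Go matches JSON keys case-insensitively, so a key such as `Answer` would pass here but fail real schema validation.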
JSON format configuration:
```json
{
  "serviceName": "qwen",
  "serviceUrl": "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
  "apiKey": "[Your API Key]",
  "jsonSchema": {
    "title": "ActionItemsSchema",
    "type": "object",
    "properties": {
      "action_items": {
        "type": "array",
        "items": {
          "type": "object",
          "properties": {
            "description": {
              "type": "string",
              "description": "Description of the action item."
            },
            "due_date": {
              "type": ["string", "null"],
              "description": "Due date for the action item, can be null if not specified."
            },
            "owner": {
              "type": ["string", "null"],
              "description": "Owner responsible for the action item, can be null if not specified."
            }
          },
          "required": ["description", "due_date", "owner"],
          "additionalProperties": false
        },
        "description": "List of action items from the meeting."
      }
    },
    "required": ["action_items"],
    "additionalProperties": false
  }
}
```
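
The JSON configuration above sets only `serviceUrl`, so the domain and path have to be derived from it. The sketch below shows that derivation with the standard `net/url` package, which is a swapped-in technique: the plugin's own `parseUrl` trims the scheme and splits on the first `/` instead. The function name `resolveEndpoint` is hypothetical, and the sketch assumes the URL carries a scheme.

```go
package main

import (
	"fmt"
	"net/url"
)

// resolveEndpoint fills in domain and path from serviceURL, but only for
// values that were not configured explicitly (explicit settings win).
func resolveEndpoint(serviceURL, domain, path string) (string, string, error) {
	u, err := url.Parse(serviceURL)
	if err != nil {
		return "", "", err
	}
	if domain == "" {
		domain = u.Host
	}
	if path == "" {
		path = u.Path
	}
	return domain, path, nil
}

func main() {
	d, p, err := resolveEndpoint(
		"https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions", "", "")
	fmt.Println(d, p, err)
}
```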
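### Local gateway service
To reuse services that are already configured, the plugin can also target a local gateway service. For example, if the gateway already has the [AI-proxy service](../ai-proxy/README.md) configured, you can configure the plugin as follows:
1. Create a service with the fixed IP 127.0.0.1, for example localservice.static: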
```yaml
- name: outbound|10000||localservice.static
  connect_timeout: 30s
  type: LOGICAL_DNS
  dns_lookup_family: V4_ONLY
  lb_policy: ROUND_ROBIN
  load_assignment:
    cluster_name: outbound|8001||localservice.static
    endpoints:
      - lb_endpoints:
          - endpoint:
              address:
                socket_address:
                  address: 127.0.0.1
                  port_value: 10000
```
2. Add the localservice.static service to the plugin configuration:
```yaml
serviceName: localservice
serviceDomain: 127.0.0.1
servicePort: 10000
```
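3. Request information such as the Path and Headers is extracted automatically.
The plugin automatically extracts the Path, Headers, and related information from the incoming request, which avoids configuring the AI service a second time.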
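
How the forwarded headers are assembled can be sketched as follows, based on the request-header handling in this commit's `main.go`: the client's `Authorization` header is dropped, the configured API key (if any) is added, and a marker header is set so the plugin can recognize its own requests. The function name `forwardHeaders` is hypothetical, and the case-insensitive match is an assumption about how header names arrive.

```go
package main

import (
	"fmt"
	"strings"
)

// markerHeader matches the EXTEND_HEADER_KEY constant used by the plugin.
const markerHeader = "X-HIGRESS-AI-JSON-RESP"

// forwardHeaders copies the incoming headers, replaces the client's
// Authorization header with the configured key, and appends the marker.
func forwardHeaders(in [][2]string, apiKey string) [][2]string {
	out := make([][2]string, 0, len(in)+2)
	for _, h := range in {
		if strings.EqualFold(h[0], "Authorization") {
			continue // never forward the client's credentials upstream
		}
		out = append(out, h)
	}
	if apiKey != "" {
		out = append(out, [2]string{"Authorization", "Bearer " + apiKey})
	}
	return append(out, [2]string{markerHeader, "true"})
}

func main() {
	in := [][2]string{
		{":path", "/v1/chat/completions"},
		{"authorization", "Bearer client-token"},
	}
	for _, h := range forwardHeaders(in, "configured-key") {
		fmt.Printf("%s: %s\n", h[0], h[1])
	}
}
```

Building a new slice instead of editing the incoming one in place keeps the original request headers untouched.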


@@ -0,0 +1,21 @@
module github.com/alibaba/higress/plugins/wasm-go/extensions/ai-json-resp
go 1.18
replace github.com/alibaba/higress/plugins/wasm-go => ../..
require (
github.com/alibaba/higress/plugins/wasm-go v1.4.2
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
)
require (
github.com/google/uuid v1.3.0 // indirect
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
github.com/magefile/mage v1.14.0 // indirect
github.com/santhosh-tekuri/jsonschema v1.2.4 // indirect
github.com/tidwall/gjson v1.14.3 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/resp v0.1.1 // indirect
)


@@ -0,0 +1,26 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240318034951-d5306e367c43/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/santhosh-tekuri/jsonschema v1.2.4 h1:hNhW8e7t+H1vgY+1QeEQpveR6D4+OwKPXCfD2aieJis=
github.com/santhosh-tekuri/jsonschema v1.2.4/go.mod h1:TEAUOeZSmIxTTuHatJzrvARHiuO9LYd+cIxzgEHCQI4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=


@@ -0,0 +1,573 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"encoding/json"
"errors"
"net/http"
"strconv"
"strings"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/santhosh-tekuri/jsonschema"
"github.com/tidwall/gjson"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)
const (
DEFAULT_SCHEMA = "defaultSchema"
HTTP_STATUS_OK = uint32(200)
HTTP_STATUS_INTERNAL_SERVER_ERROR = uint32(500)
FROM_THIS_PLUGIN_KEY = "fromThisPlugin"
EXTEND_HEADER_KEY = "X-HIGRESS-AI-JSON-RESP"
JSON_SCHEMA_INVALID_CODE = 1001
JSON_SCHEMA_COMPILE_FAILED_CODE = 1002
CANNOT_FIND_JSON_IN_RESPONSE_CODE = 1003
CONTENT_IS_EMPTY_CODE = 1004
JSON_MISMATCH_SCHEMA_CODE = 1005
REACH_MAX_RETRY_COUNT_CODE = 1006
SERVICE_UNAVAILABLE_CODE = 1007
SERVICE_CONFIG_INVALID_CODE = 1008
)
type RejectStruct struct {
RejectCode uint32 `json:"Code"`
RejectMsg string `json:"Msg"`
}
func (r RejectStruct) GetBytes() []byte {
jsonData, _ := json.Marshal(r)
return jsonData
}
func (r RejectStruct) GetShortMsg() string {
return "ai-json-resp." + strings.Split(r.RejectMsg, ":")[0]
}
type PluginConfig struct {
// @Title zh-CN 服务名称
// @Description zh-CN 用以请求服务的名称(网关或其他AI服务)
serviceName string `required:"true" json:"serviceName" yaml:"serviceName"`
// @Title zh-CN 服务域名
// @Description zh-CN 用以请求服务的域名
serviceDomain string `required:"false" json:"serviceDomain" yaml:"serviceDomain"`
// @Title zh-CN 服务端口
// @Description zh-CN 用以请求服务的端口
servicePort int `required:"false" json:"servicePort" yaml:"servicePort"`
// @Title zh-CN 服务URL
// @Description zh-CN 用以请求服务的URL若提供则会覆盖serviceDomain和servicePort
serviceUrl string `required:"false" json:"serviceUrl" yaml:"serviceUrl"`
// @Title zh-CN API Key
// @Description zh-CN 若使用AI服务需要填写请求服务的API Key
apiKey string `required:"false" json:"apiKey" yaml:"apiKey"`
// @Title zh-CN 请求端点
// @Description zh-CN 用以请求服务的端点, 默认为"/v1/chat/completions"
servicePath string `required:"false" json:"servicePath" yaml:"servicePath"`
// @Title zh-CN 服务超时时间
// @Description zh-CN 用以请求服务的超时时间
serviceTimeout int `required:"false" json:"serviceTimeout" yaml:"serviceTimeout"`
// @Title zh-CN 最大重试次数
// @Description zh-CN 用以请求服务的最大重试次数
maxRetry int `required:"false" json:"maxRetry" yaml:"maxRetry"`
// @Title zh-CN 内容路径
// @Description zh-CN 从AI服务返回的响应中提取json的gpath路径
contentPath string `required:"false" json:"contentPath" yaml:"contentPath"`
// @Title zh-CN Json Schema
// @Description zh-CN 用以验证响应json的Json Schema, 为空则只验证返回的响应是否为合法json
jsonSchema map[string]interface{} `required:"false" json:"jsonSchema" yaml:"jsonSchema"`
// @Title zh-CN 是否启用swagger
// @Description zh-CN 是否启用swagger进行Json Schema验证
enableSwagger bool `required:"false" json:"enableSwagger" yaml:"enableSwagger"`
// @Title zh-CN 是否启用oas3
// @Description zh-CN 是否启用oas3进行Json Schema验证
enableOas3 bool `required:"false" json:"enableOas3" yaml:"enableOas3"`
// @Title zh-CN 是否启用Content-Disposition
// @Description zh-CN 是否启用Content-Disposition, 若启用则会在响应头中添加Content-Disposition: attachment; filename="response.json"
enableContentDisposition bool `required:"false" json:"enableContentDisposition" yaml:"enableContentDisposition"`
serviceClient wrapper.HttpClient
draft *jsonschema.Draft
compiler *jsonschema.Compiler
compile *jsonschema.Schema
rejectStruct RejectStruct
jsonSchemaMaxDepth int
enableJsonSchemaValidation bool
}
func main() {
wrapper.SetCtx(
"ai-json-resp",
wrapper.ParseConfigBy(parseConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
)
}
type RequestContext struct {
Path string
ReqHeaders [][2]string
ReqBody []byte
RespHeader [][2]string
RespBody []byte
HistoryMessages []chatMessage
}
func parseUrl(url string) (string, string) {
if url == "" {
return "", ""
}
url = strings.TrimPrefix(url, "http://")
url = strings.TrimPrefix(url, "https://")
index := strings.Index(url, "/")
if index == -1 {
return url, ""
}
return url[:index], url[index:]
}
func parseConfig(result gjson.Result, config *PluginConfig, log wrapper.Log) error {
config.serviceName = result.Get("serviceName").String()
config.serviceUrl = result.Get("serviceUrl").String()
config.serviceDomain = result.Get("serviceDomain").String()
config.servicePath = result.Get("servicePath").String()
config.servicePort = int(result.Get("servicePort").Int())
if config.serviceUrl != "" {
domain, url := parseUrl(config.serviceUrl)
log.Debugf("serviceUrl: %s, the parsed domain: %s, the parsed url: %s", config.serviceUrl, domain, url)
if config.serviceDomain == "" {
config.serviceDomain = domain
}
if config.servicePath == "" {
config.servicePath = url
}
}
if config.servicePort == 0 {
config.servicePort = 443
}
config.serviceTimeout = int(result.Get("serviceTimeout").Int())
config.apiKey = result.Get("apiKey").String()
config.rejectStruct = RejectStruct{HTTP_STATUS_OK, ""}
if config.serviceTimeout == 0 {
config.serviceTimeout = 50000
}
config.maxRetry = int(result.Get("maxRetry").Int())
if config.maxRetry == 0 {
config.maxRetry = 3
}
config.contentPath = result.Get("contentPath").String()
if config.contentPath == "" {
config.contentPath = "choices.0.message.content"
}
if jsonSchemaValue := result.Get("jsonSchema"); jsonSchemaValue.Exists() {
if schemaValue, ok := jsonSchemaValue.Value().(map[string]interface{}); ok {
config.jsonSchema = schemaValue
} else {
config.rejectStruct = RejectStruct{JSON_SCHEMA_INVALID_CODE, "Json Schema is not valid"}
}
} else {
config.jsonSchema = nil
}
if config.serviceDomain == "" {
config.rejectStruct = RejectStruct{SERVICE_CONFIG_INVALID_CODE, "service domain is empty"}
}
config.serviceClient = wrapper.NewClusterClient(wrapper.DnsCluster{
ServiceName: config.serviceName,
Port: int64(config.servicePort),
Domain: config.serviceDomain,
})
enableSwagger := result.Get("enableSwagger").Bool()
enableOas3 := result.Get("enableOas3").Bool()
// set draft version
if enableSwagger {
config.draft = jsonschema.Draft4
}
if enableOas3 {
config.draft = jsonschema.Draft7
}
if !enableSwagger && !enableOas3 {
config.draft = jsonschema.Draft7
}
// create compiler
compiler := jsonschema.NewCompiler()
compiler.Draft = config.draft
config.compiler = compiler
// set max depth of json schema
config.jsonSchemaMaxDepth = 6
enableContentDispositionValue := result.Get("enableContentDisposition")
if !enableContentDispositionValue.Exists() {
config.enableContentDisposition = true
} else {
config.enableContentDisposition = enableContentDispositionValue.Bool()
}
config.enableJsonSchemaValidation = true
jsonSchemaBytes, err := json.Marshal(config.jsonSchema)
if err != nil {
config.rejectStruct = RejectStruct{JSON_SCHEMA_INVALID_CODE, "Json Schema marshal failed"}
return err
}
maxDepth := GetMaxDepth(config.jsonSchema)
log.Debugf("max depth of json schema: %d", maxDepth)
if maxDepth > config.jsonSchemaMaxDepth {
config.enableJsonSchemaValidation = false
log.Infof("Json Schema depth exceeded: %d from %d , Json Schema validation will not be used.", maxDepth, config.jsonSchemaMaxDepth)
}
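// Only compile when a schema is configured; a nil schema marshals to "null".
if config.jsonSchema != nil && config.enableJsonSchemaValidation {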
jsonSchemaStr := string(jsonSchemaBytes)
config.compiler.AddResource(DEFAULT_SCHEMA, strings.NewReader(jsonSchemaStr))
// Test if the Json Schema is valid
compile, err := config.compiler.Compile(DEFAULT_SCHEMA)
if err != nil {
log.Infof("Json Schema compile failed: %v", err)
config.rejectStruct = RejectStruct{JSON_SCHEMA_COMPILE_FAILED_CODE, "Json Schema compile failed: " + err.Error()}
config.compile = nil
} else {
config.compile = compile
}
}
return nil
}
func (r *RequestContext) assembleReqBody(config PluginConfig) []byte {
var reqBodystrut chatCompletionRequest
json.Unmarshal(r.ReqBody, &reqBodystrut)
content := gjson.ParseBytes(r.RespBody).Get(config.contentPath).String()
jsonSchemaBytes, _ := json.Marshal(config.jsonSchema)
jsonSchemaStr := string(jsonSchemaBytes)
askQuestion := "Given the Json Schema: " + jsonSchemaStr + ", please help me convert the following content to a pure json: " + content
askQuestion += "\n Do not respond other content except the pure json!!!!"
reqBodystrut.Messages = append(r.HistoryMessages, []chatMessage{
{
Role: "user",
Content: askQuestion,
},
}...)
reqBody, _ := json.Marshal(reqBodystrut)
return reqBody
}
func (r *RequestContext) SaveBodyToHistMsg(log wrapper.Log, reqBody []byte, respBody []byte) {
r.RespBody = respBody
lastUserMessage := ""
lastSystemMessage := ""
var reqBodystrut chatCompletionRequest
err := json.Unmarshal(reqBody, &reqBodystrut)
if err != nil {
log.Debugf("unmarshal reqBody failed: %v", err)
} else {
if len(reqBodystrut.Messages) != 0 {
lastUserMessage = reqBodystrut.Messages[len(reqBodystrut.Messages)-1].Content
}
}
var respBodystrut chatCompletionResponse
err = json.Unmarshal(respBody, &respBodystrut)
if err != nil {
log.Debugf("unmarshal respBody failed: %v", err)
} else {
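// Message is a pointer and may be nil, so guard before dereferencing it.
if n := len(respBodystrut.Choices); n != 0 && respBodystrut.Choices[n-1].Message != nil {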
lastSystemMessage = respBodystrut.Choices[len(respBodystrut.Choices)-1].Message.Content
}
}
if lastUserMessage != "" {
r.HistoryMessages = append(r.HistoryMessages, chatMessage{
Role: "user",
Content: lastUserMessage,
})
}
if lastSystemMessage != "" {
r.HistoryMessages = append(r.HistoryMessages, chatMessage{
Role: "system",
Content: lastSystemMessage,
})
}
}
func (r *RequestContext) SaveStrToHistMsg(log wrapper.Log, errMsg string) {
r.HistoryMessages = append(r.HistoryMessages, chatMessage{
Role: "system",
Content: errMsg,
})
}
func (c *PluginConfig) ValidateBody(body []byte) error {
var respJsonStrct chatCompletionResponse
err := json.Unmarshal(body, &respJsonStrct)
if err != nil {
c.rejectStruct = RejectStruct{SERVICE_UNAVAILABLE_CODE, "service unavailable: " + string(body)}
return errors.New(c.rejectStruct.RejectMsg)
}
content := gjson.ParseBytes(body).Get(c.contentPath)
if !content.Exists() {
c.rejectStruct = RejectStruct{SERVICE_UNAVAILABLE_CODE, "response body does not contain the content: " + string(body)}
return errors.New(c.rejectStruct.RejectMsg)
}
return nil
}
func (c *PluginConfig) ValidateJson(body []byte, log wrapper.Log) (string, error) {
content := gjson.ParseBytes(body).Get(c.contentPath).String()
// first extract json from response body
if content == "" {
log.Infof("response body does not contain the content")
c.rejectStruct = RejectStruct{CONTENT_IS_EMPTY_CODE, "response body does not contain the content"}
return "", errors.New(c.rejectStruct.RejectMsg)
}
jsonStr, err := c.ExtractJson(content)
if err != nil {
log.Infof("response body does not contain the valid json: %v", err.Error())
c.rejectStruct = RejectStruct{CANNOT_FIND_JSON_IN_RESPONSE_CODE, "response body does not contain the valid json: " + err.Error()}
return "", errors.New(c.rejectStruct.RejectMsg)
}
if c.jsonSchema != nil && c.enableJsonSchemaValidation {
compile, err := c.compiler.Compile(DEFAULT_SCHEMA)
if err != nil {
log.Infof("Json Schema compile failed: %v", err)
c.rejectStruct = RejectStruct{JSON_SCHEMA_COMPILE_FAILED_CODE, "Json Schema compile failed: " + err.Error()}
c.compile = nil
// Return early: calling Validate on a nil schema would panic.
return "", errors.New(c.rejectStruct.RejectMsg)
}
c.compile = compile
// validate the json
err = c.compile.Validate(strings.NewReader(jsonStr))
if err != nil {
log.Infof("response body does not match the Json Schema: %v", err)
c.rejectStruct = RejectStruct{JSON_MISMATCH_SCHEMA_CODE, "response body does not match the Json Schema: " + err.Error()}
return "", errors.New(c.rejectStruct.RejectMsg)
}
}
c.rejectStruct = RejectStruct{HTTP_STATUS_OK, ""}
return jsonStr, nil
}
func (c *PluginConfig) ExtractJson(bodyStr string) (string, error) {
// simply extract json from response body string
startIndex := strings.Index(bodyStr, "{")
endIndex := strings.LastIndex(bodyStr, "}")
// fail if either brace is missing or the last "}" precedes the first "{"
if startIndex == -1 || endIndex == -1 || startIndex > endIndex {
return "", errors.New("cannot find json in the response body")
}
jsonStr := bodyStr[startIndex : endIndex+1]
// attempt to parse the JSON
var result map[string]interface{}
err := json.Unmarshal([]byte(jsonStr), &result)
if err != nil {
return "", err
}
return jsonStr, nil
}
func sendResponse(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log, body []byte) {
log.Infof("Final send: Code %d, Message %s, Body: %s", config.rejectStruct.RejectCode, config.rejectStruct.RejectMsg, string(body))
header := [][2]string{
{"Content-Type", "application/json"},
}
if body != nil && config.enableContentDisposition {
header = append(header, [2]string{"Content-Disposition", "attachment; filename=\"response.json\""})
}
if config.rejectStruct.RejectCode != HTTP_STATUS_OK {
proxywasm.SendHttpResponseWithDetail(HTTP_STATUS_INTERNAL_SERVER_ERROR, config.rejectStruct.GetShortMsg(), nil, config.rejectStruct.GetBytes(), -1)
} else {
proxywasm.SendHttpResponse(HTTP_STATUS_OK, header, body, -1)
}
}
func recursiveRefineJson(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log, retryCount int, requestContext *RequestContext) {
// if retry count exceeds max retry count, return the response
if retryCount >= config.maxRetry {
log.Debugf("retry count exceeds max retry count")
// report more useful error by appending the last of previous error message
config.rejectStruct = RejectStruct{REACH_MAX_RETRY_COUNT_CODE, "retry count exceeds max retry count: " + config.rejectStruct.RejectMsg}
sendResponse(ctx, config, log, nil)
return
}
// recursively refine json
config.serviceClient.Post(requestContext.Path, requestContext.ReqHeaders, requestContext.assembleReqBody(config),
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
err := config.ValidateBody(responseBody)
if err != nil {
sendResponse(ctx, config, log, nil)
return
}
retryCount++
requestContext.SaveBodyToHistMsg(log, requestContext.assembleReqBody(config), responseBody)
log.Debugf("[retry request %d/%d] resp code: %d", retryCount, config.maxRetry, statusCode)
validateJson, err := config.ValidateJson(responseBody, log)
if err == nil {
sendResponse(ctx, config, log, []byte(validateJson))
} else {
requestContext.SaveStrToHistMsg(log, err.Error())
recursiveRefineJson(ctx, config, log, retryCount, requestContext)
}
}, uint32(config.serviceTimeout))
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
if config.rejectStruct.RejectCode != HTTP_STATUS_OK {
sendResponse(ctx, config, log, nil)
return types.ActionPause
}
// verify if the request is from this plugin
extendHeaderValue, err := proxywasm.GetHttpRequestHeader(EXTEND_HEADER_KEY)
if err == nil {
fromThisPlugin, convErr := strconv.ParseBool(extendHeaderValue)
if convErr != nil {
log.Debugf("failed to parse header value as bool: %v", convErr)
ctx.SetContext(FROM_THIS_PLUGIN_KEY, false)
}
if fromThisPlugin {
ctx.SetContext(FROM_THIS_PLUGIN_KEY, true)
return types.ActionContinue
}
} else {
ctx.SetContext(FROM_THIS_PLUGIN_KEY, false)
}
path, err := proxywasm.GetHttpRequestHeader(":path")
if err != nil {
log.Infof("get request path failed: %v", err)
path = ""
} else {
ctx.SetContext("path", path)
}
headers, err := proxywasm.GetHttpRequestHeaders()
if err != nil {
log.Infof("get request header failed: %v", err)
}
apiKey, err := proxywasm.GetHttpRequestHeader("Authorization")
if err != nil {
log.Infof("get request header failed: %v", err)
apiKey = ""
}
if apiKey != "" {
// remove the Authorization header
proxywasm.RemoveHttpRequestHeader("Authorization")
// remove the Authorization header from the headers
for i, header := range headers {
// header names may arrive lowercased, so compare case-insensitively
if strings.EqualFold(header[0], "Authorization") {
headers = append(headers[:i], headers[i+1:]...)
break
}
}
}
if config.apiKey != "" {
// do not write the API key itself to the log
log.Debugf("add Authorization header from the configured apiKey")
headers = append(headers, [2]string{"Authorization", "Bearer " + config.apiKey})
}
ctx.SetContext("headers", headers)
return types.ActionContinue
}
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
// if the request is from this plugin, continue the request
fromThisPlugin, ok := ctx.GetContext(FROM_THIS_PLUGIN_KEY).(bool)
if ok && fromThisPlugin {
log.Debugf("detected buffer_request, sending request to AI service")
return types.ActionContinue
}
var headers [][2]string
if h, ok := ctx.GetContext("headers").([][2]string); ok {
headers = append(h, [2]string{EXTEND_HEADER_KEY, "true"})
} else {
log.Debugf("cannot get headers from context, use default headers")
headers = [][2]string{
{"Content-Type", "application/json"},
{EXTEND_HEADER_KEY, "true"},
}
}
// if there is any error in the config, return the response directly
if config.rejectStruct.RejectCode != HTTP_STATUS_OK {
sendResponse(ctx, config, log, nil)
return types.ActionContinue
}
var path string
// use a separate name in the condition so the outer path is not shadowed
if p, ok := ctx.GetContext("path").(string); ok {
path = p
log.Debugf("use path: %s", path)
} else {
log.Debugf("cannot get path from context, use default path")
path = "/v1/chat/completions"
}
if config.servicePath != "" {
log.Debugf("use base path: %s", config.servicePath)
path = config.servicePath
}
requestContext := &RequestContext{
Path: path,
ReqHeaders: headers,
ReqBody: body,
}
config.serviceClient.Post(requestContext.Path, requestContext.ReqHeaders, requestContext.ReqBody,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
err := config.ValidateBody(responseBody)
if err != nil {
sendResponse(ctx, config, log, nil)
return
}
requestContext.SaveBodyToHistMsg(log, body, responseBody)
log.Debugf("[first request] resp code: %d", statusCode)
validateJson, err := config.ValidateJson(responseBody, log)
if err == nil {
sendResponse(ctx, config, log, []byte(validateJson))
return
} else {
retryCount := 0
requestContext.SaveStrToHistMsg(log, err.Error())
recursiveRefineJson(ctx, config, log, retryCount, requestContext)
}
}, uint32(config.serviceTimeout))
return types.ActionPause
}


@@ -0,0 +1,180 @@
// adapted from https://github.com/alibaba/higress/blob/main/plugins/wasm-go/extensions/ai-proxy/provider/model.go
package main
import "strings"
const (
streamEventIdItemKey = "id:"
streamEventNameItemKey = "event:"
streamBuiltInItemKey = ":"
streamHttpStatusValuePrefix = "HTTP_STATUS/"
streamDataItemKey = "data:"
streamEndDataValue = "[DONE]"
)
type chatCompletionRequest struct {
Model string `json:"model"`
Messages []chatMessage `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
N int `json:"n,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Seed int `json:"seed,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *streamOptions `json:"stream_options,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Tools []tool `json:"tools,omitempty"`
ToolChoice *toolChoice `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
Stop []string `json:"stop,omitempty"`
ResponseFormat map[string]interface{} `json:"response_format,omitempty"`
}
type streamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"`
}
type tool struct {
Type string `json:"type"`
Function function `json:"function"`
}
type function struct {
Description string `json:"description,omitempty"`
Name string `json:"name"`
Parameters map[string]interface{} `json:"parameters,omitempty"`
}
type toolChoice struct {
Type string `json:"type"`
Function function `json:"function"`
}
type chatCompletionResponse struct {
Id string `json:"id,omitempty"`
Choices []chatCompletionChoice `json:"choices"`
Created int64 `json:"created,omitempty"`
Model string `json:"model,omitempty"`
SystemFingerprint string `json:"system_fingerprint,omitempty"`
Object string `json:"object,omitempty"`
Usage usage `json:"usage,omitempty"`
}
type chatCompletionChoice struct {
Index int `json:"index"`
Message *chatMessage `json:"message,omitempty"`
Delta *chatMessage `json:"delta,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
}
type usage struct {
PromptTokens int `json:"prompt_tokens,omitempty"`
CompletionTokens int `json:"completion_tokens,omitempty"`
TotalTokens int `json:"total_tokens,omitempty"`
}
type chatMessage struct {
Name string `json:"name,omitempty"`
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
ToolCalls []toolCall `json:"tool_calls,omitempty"`
}
func (m *chatMessage) IsEmpty() bool {
if m.Content != "" {
return false
}
if len(m.ToolCalls) != 0 {
nonEmpty := false
for _, toolCall := range m.ToolCalls {
if !toolCall.Function.IsEmpty() {
nonEmpty = true
break
}
}
if nonEmpty {
return false
}
}
return true
}
type toolCall struct {
Index int `json:"index"`
Id string `json:"id"`
Type string `json:"type"`
Function functionCall `json:"function"`
}
type functionCall struct {
Id string `json:"id"`
Name string `json:"name"`
Arguments string `json:"arguments"`
}
func (m *functionCall) IsEmpty() bool {
return m.Name == "" && m.Arguments == ""
}
type streamEvent struct {
Id string `json:"id"`
Event string `json:"event"`
Data string `json:"data"`
HttpStatus string `json:"http_status"`
}
func (e *streamEvent) setValue(key, value string) {
switch key {
case streamEventIdItemKey:
e.Id = value
case streamEventNameItemKey:
e.Event = value
case streamDataItemKey:
e.Data = value
case streamBuiltInItemKey:
if strings.HasPrefix(value, streamHttpStatusValuePrefix) {
e.HttpStatus = value[len(streamHttpStatusValuePrefix):]
}
}
}
type embeddingsRequest struct {
Input interface{} `json:"input"`
Model string `json:"model"`
EncodingFormat string `json:"encoding_format,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
User string `json:"user,omitempty"`
}
type embeddingsResponse struct {
Object string `json:"object"`
Data []embedding `json:"data"`
Model string `json:"model"`
Usage usage `json:"usage"`
}
type embedding struct {
Object string `json:"object"`
Index int `json:"index"`
Embedding []float64 `json:"embedding"`
}
func (r embeddingsRequest) ParseInput() []string {
if r.Input == nil {
return nil
}
var input []string
// bind the asserted value once instead of repeating the type assertion
switch v := r.Input.(type) {
case string:
input = []string{v}
case []any:
input = make([]string, 0, len(v))
for _, item := range v {
if str, ok := item.(string); ok {
input = append(input, str)
}
}
}
return input
}


@@ -0,0 +1,33 @@
package main
func GetMaxDepth(data interface{}) int {
type item struct {
value interface{}
depth int
}
maxDepth := 0
stack := []item{{value: data, depth: 1}}
for len(stack) > 0 {
currentItem := stack[len(stack)-1]
stack = stack[:len(stack)-1]
if currentItem.depth > maxDepth {
maxDepth = currentItem.depth
}
switch v := currentItem.value.(type) {
case map[string]interface{}:
for _, value := range v {
stack = append(stack, item{value: value, depth: currentItem.depth + 1})
}
case []interface{}:
for _, value := range v {
stack = append(stack, item{value: value, depth: currentItem.depth + 1})
}
}
}
return maxDepth
}