mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 04:37:31 +08:00
[ai-json-resp] Extract JSON from LLM, Validate with Schema, Ensure Valid JSON, Auto-Retry (#1236)
202
plugins/wasm-go/extensions/ai-json-resp/README.md
Normal file
@@ -0,0 +1,202 @@
## Introduction

**Note**

> The data plane's proxy wasm version must be 0.2.100 or later.
>
> When compiling, include the version tag, for example: `tinygo build -o main.wasm -scheduler=none -target=wasi -gc=custom -tags="custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100" ./`

An LLM response structuring plugin. It structures the AI response according to a default or user-configured JSON Schema so that downstream plugins can process it. Note that only `non-streaming responses` are currently supported.

### Configuration

| Name | Type | Requirement | Default | **Description** |
| --- | --- | --- | --- | --- |
| serviceName | str | required | - | Name of the AI service, or of a gateway service that supports AI-Proxy |
| serviceDomain | str | optional | - | Domain or IP address of the AI service or AI-Proxy gateway service |
| servicePath | str | optional | '/v1/chat/completions' | Base path of the AI service or AI-Proxy gateway service |
| serviceUrl | str | optional | - | URL of the AI service or AI-Proxy gateway service. The plugin extracts the domain and path from it and uses them to fill in serviceDomain or servicePath when they are not configured |
| servicePort | int | optional | 443 | Port of the gateway service |
| serviceTimeout | int | optional | 50000 | Default request timeout, in milliseconds |
| maxRetry | int | optional | 3 | Number of retries when a valid structured result cannot be extracted from the answer |
| contentPath | str | optional | "choices.0.message.content" | GJSON path used to extract the result from the LLM answer |
| jsonSchema | str (json) | optional | - | JSON Schema the response is validated against. If empty, the plugin only checks that the response is valid JSON and returns it |
| enableSwagger | bool | optional | false | Whether to validate using the Swagger specification |
| enableOas3 | bool | optional | true | Whether to validate using the OAS3 specification |
| enableContentDisposition | bool | optional | true | Whether to add the Content-Disposition header. When enabled, the response carries `Content-Disposition: attachment; filename="response.json"` |

> For performance reasons, the maximum supported JSON Schema depth is 6 by default. A JSON Schema deeper than this is not used to validate the response; the plugin only checks whether the response is valid JSON.
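
The depth limit above can be pictured with a small sketch. The function names `schemaDepth` and `schemaTooDeep` are hypothetical, and the counting rule (each nested object or array adds one level) is an assumption; the plugin's own `GetMaxDepth` is not shown here and may count differently.

```go
package main

// schemaDepth returns the nesting depth of a decoded JSON value.
// Scalars count as 0; every enclosing object or array adds 1.
// Assumption: this counting rule is illustrative, not the plugin's.
func schemaDepth(v interface{}) int {
	deepest := 0
	switch node := v.(type) {
	case map[string]interface{}:
		for _, child := range node {
			if d := schemaDepth(child); d > deepest {
				deepest = d
			}
		}
		return deepest + 1
	case []interface{}:
		for _, child := range node {
			if d := schemaDepth(child); d > deepest {
				deepest = d
			}
		}
		return deepest + 1
	default:
		return 0
	}
}

// schemaTooDeep reports whether the schema is nested deeper than limit.
// A caller could use it to fall back to plain JSON validity checking.
func schemaTooDeep(schema map[string]interface{}, limit int) bool {
	return schemaDepth(schema) > limit
}
```

With a limit of 6, a schema for which `schemaTooDeep` returns true would be skipped for validation, and only the JSON syntax of the response would be checked.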


### Request and Response Parameters

- **Request parameters**: The plugin accepts requests in the OpenAI format, containing the `model` and `messages` fields. `model` is the AI model name and `messages` is the list of conversation messages. Each message has a `role` (the message role) and a `content` (the message content).
  ```json
  {
    "model": "gpt-4",
    "messages": [
      {"role": "user", "content": "give me a api doc for add the variable x to x+5"}
    ]
  }
  ```
  For other request parameters, refer to the documentation of the configured AI service or gateway service.
- **Response**:
  - A `JSON response` that satisfies the constraints of the configured JSON Schema.
  - If no JSON Schema is configured, a valid `JSON response`.
  - If an internal error occurs, `{ "Code": 10XX, "Msg": "error message" }`.

## Request Example

```bash
curl -X POST "http://localhost:8001/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -d '{
  "model": "gpt-4",
  "messages": [
    {"role": "user", "content": "give me a api doc for add the variable x to x+5"}
  ]
}'
```
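
For callers written in Go, the same request body can be produced with `encoding/json`. This is a minimal sketch: the type names `chatMessage` and `chatRequest` and the helper `buildRequestBody` are hypothetical stand-ins that carry only the two fields described above.

```go
package main

import "encoding/json"

// chatMessage is one entry of the OpenAI-style "messages" list.
type chatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// chatRequest holds the two fields the README describes: model and messages.
type chatRequest struct {
	Model    string        `json:"model"`
	Messages []chatMessage `json:"messages"`
}

// buildRequestBody serializes a single-turn user request.
func buildRequestBody(model, userContent string) ([]byte, error) {
	return json.Marshal(chatRequest{
		Model:    model,
		Messages: []chatMessage{{Role: "user", Content: userContent}},
	})
}
```

The returned bytes can be sent as the body of a POST request with `Content-Type: application/json`, exactly as in the curl call.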

## Response Examples

### Normal Response

Normally the plugin returns JSON data that has been validated against the JSON Schema. If no JSON Schema is configured, it returns valid JSON data that conforms to the JSON standard.

```json
{
  "apiVersion": "1.0",
  "request": {
    "endpoint": "/add_to_five",
    "method": "POST",
    "port": 8080,
    "headers": {
      "Content-Type": "application/json"
    },
    "body": {
      "x": 7
    }
  }
}
```
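
LLM answers often wrap the JSON in explanatory text, so the object has to be cut out before it can be validated. The sketch below takes everything between the first `{` and the last `}` and checks that it parses, which is the same simple approach `ExtractJson` in `main.go` uses. The function name `extractJSONObject` is hypothetical.

```go
package main

import (
	"encoding/json"
	"errors"
	"strings"
)

// extractJSONObject returns the text between the first "{" and the last "}"
// of content, provided that text is valid JSON.
func extractJSONObject(content string) (string, error) {
	start := strings.Index(content, "{")
	end := strings.LastIndex(content, "}")
	if start == -1 || end == -1 || start > end {
		return "", errors.New("no JSON object found")
	}
	candidate := content[start : end+1]
	if !json.Valid([]byte(candidate)) {
		return "", errors.New("extracted text is not valid JSON")
	}
	return candidate, nil
}
```

This heuristic is deliberately simple: it fails when the answer contains several separate objects or stray braces in the surrounding text, which is one reason a retry step is useful.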

### Error Response

When an error occurs, the response status code is `500` and the body is a JSON error message with two fields: the error code `Code` and the error message `Msg`.

```json
{
  "Code": 1006,
  "Msg": "retry count exceeds max retry count"
}
```
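
A client can tell the two cases apart by checking the status code and decoding the body. The following is a sketch under the assumption that the error body always has the `Code` and `Msg` fields shown above; `pluginError` and `decodePluginError` are hypothetical names.

```go
package main

import "encoding/json"

// pluginError mirrors the error body { "Code": ..., "Msg": ... }.
type pluginError struct {
	Code int    `json:"Code"`
	Msg  string `json:"Msg"`
}

// decodePluginError returns the decoded error and true when status is 500
// and the body carries a non-zero Code; otherwise it returns false.
func decodePluginError(status int, body []byte) (pluginError, bool) {
	var e pluginError
	if status != 500 {
		return e, false
	}
	if err := json.Unmarshal(body, &e); err != nil || e.Code == 0 {
		return pluginError{}, false
	}
	return e, true
}
```

When the second return value is false and the status is 200, the body can be treated as the structured JSON result.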

### Error Codes

| Error code | Description |
| --- | --- |
| 1001 | The configured JSON Schema is not valid JSON |
| 1002 | The configured JSON Schema failed to compile: it is not a valid JSON Schema, or its depth exceeds jsonSchemaMaxDepth while rejectOnDepthExceeded is true |
| 1003 | No valid JSON could be extracted from the response |
| 1004 | The response is an empty string |
| 1005 | The response does not conform to the JSON Schema |
| 1006 | The retry count exceeded the maximum limit |
| 1007 | The response content could not be obtained; the upstream service may be misconfigured or the contentPath used to extract the content may be wrong |
| 1008 | serviceDomain is empty; note that serviceDomain and serviceUrl must not both be empty |
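
## Service Configuration

The plugin needs an upstream service to be configured so that it can retry automatically when a response fails validation. Two kinds of upstream are supported: an `AI service that supports the OpenAI API` or a `local gateway service`.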
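
The retry behaviour can be sketched as a bounded loop. This is a simplified, synchronous illustration: the plugin itself issues asynchronous HTTP calls and recurses from the callback, and the names `refineWithRetry` and `validJSON` are hypothetical.

```go
package main

import (
	"encoding/json"
	"errors"
)

// validJSON is a stand-in validator: it only checks JSON syntax.
func validJSON(s string) error {
	if !json.Valid([]byte(s)) {
		return errors.New("answer is not valid JSON")
	}
	return nil
}

// refineWithRetry calls ask until validate accepts the answer or
// maxRetry attempts have been made. The last validation error is
// appended to the final error so the caller can see why it failed.
func refineWithRetry(maxRetry int, ask func(attempt int) string, validate func(string) error) (string, error) {
	var lastErr error
	for attempt := 1; attempt <= maxRetry; attempt++ {
		answer := ask(attempt)
		if lastErr = validate(answer); lastErr == nil {
			return answer, nil
		}
	}
	if lastErr == nil {
		lastErr = errors.New("no attempt was made")
	}
	return "", errors.New("retry count exceeds max retry count: " + lastErr.Error())
}
```

Here `ask` stands for one call to the upstream service, typically with the previous answer and the validation error added to the conversation so the model can correct itself.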

### AI Service That Supports the OpenAI API

Taking qwen as an example, the basic configuration is as follows.

Configuration in YAML format:

```yaml
serviceName: qwen
serviceDomain: dashscope.aliyuncs.com
apiKey: [Your API Key]
servicePath: /compatible-mode/v1/chat/completions
jsonSchema:
  title: ReasoningSchema
  type: object
  properties:
    reasoning_steps:
      type: array
      items:
        type: string
      description: The reasoning steps leading to the final conclusion.
    answer:
      type: string
      description: The final answer, taking into account the reasoning steps.
  required:
    - reasoning_steps
    - answer
  additionalProperties: false
```

Configuration in JSON format:

```json
{
  "serviceName": "qwen",
  "serviceUrl": "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
  "apiKey": "[Your API Key]",
  "jsonSchema": {
    "title": "ActionItemsSchema",
    "type": "object",
    "properties": {
      "action_items": {
        "type": "array",
        "items": {
          "type": "object",
          "properties": {
            "description": {
              "type": "string",
              "description": "Description of the action item."
            },
            "due_date": {
              "type": ["string", "null"],
              "description": "Due date for the action item, can be null if not specified."
            },
            "owner": {
              "type": ["string", "null"],
              "description": "Owner responsible for the action item, can be null if not specified."
            }
          },
          "required": ["description", "due_date", "owner"],
          "additionalProperties": false
        },
        "description": "List of action items from the meeting."
      }
    },
    "required": ["action_items"],
    "additionalProperties": false
  }
}
```

### Local Gateway Service

To reuse services that are already configured, the plugin also supports a local gateway service. For example, if the gateway already has an [AI-proxy service](../ai-proxy/README.md) configured, the plugin can be set up as follows:

1. Create a service with the fixed IP 127.0.0.1, for example localservice.static

```yaml
- name: outbound|10000||localservice.static
  connect_timeout: 30s
  type: LOGICAL_DNS
  dns_lookup_family: V4_ONLY
  lb_policy: ROUND_ROBIN
  load_assignment:
    cluster_name: outbound|8001||localservice.static
    endpoints:
      - lb_endpoints:
          - endpoint:
              address:
                socket_address:
                  address: 127.0.0.1
                  port_value: 10000
```

2. Add the service configuration for localservice.static to the plugin configuration

```yaml
serviceName: localservice
serviceDomain: 127.0.0.1
servicePort: 10000
```

3. The request path, headers, and other information are extracted automatically

The plugin automatically extracts the path, headers, and other information from the incoming request, which avoids configuring the AI service a second time.
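
When headers are reused for the upstream call, the caller's credentials are swapped for the configured key. The sketch below drops any existing `Authorization` header and appends one built from the configured API key, which follows what `onHttpRequestHeaders` in `main.go` does. The helper name `withAuthorization` and the case-insensitive header match are assumptions added for illustration.

```go
package main

import "strings"

// withAuthorization returns a copy of headers without any Authorization
// entry and, when apiKey is non-empty, with "Bearer <apiKey>" appended.
// The input slice is left untouched.
func withAuthorization(headers [][2]string, apiKey string) [][2]string {
	out := make([][2]string, 0, len(headers)+1)
	for _, h := range headers {
		// Header names are compared case-insensitively (an assumption
		// made here so that "authorization" is handled as well).
		if strings.EqualFold(h[0], "Authorization") {
			continue
		}
		out = append(out, h)
	}
	if apiKey != "" {
		out = append(out, [2]string{"Authorization", "Bearer " + apiKey})
	}
	return out
}
```

Building a new slice avoids mutating the header list while iterating over it, and keeps the original request headers available to the rest of the handler.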
21
plugins/wasm-go/extensions/ai-json-resp/go.mod
Normal file
@@ -0,0 +1,21 @@
module github.com/alibaba/higress/plugins/wasm-go/extensions/hello-world

go 1.18

replace github.com/alibaba/higress/plugins/wasm-go => ../..

require (
	github.com/alibaba/higress/plugins/wasm-go v1.4.2
	github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
)

require (
	github.com/google/uuid v1.3.0 // indirect
	github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
	github.com/magefile/mage v1.14.0 // indirect
	github.com/santhosh-tekuri/jsonschema v1.2.4 // indirect
	github.com/tidwall/gjson v1.14.3 // indirect
	github.com/tidwall/match v1.1.1 // indirect
	github.com/tidwall/pretty v1.2.0 // indirect
	github.com/tidwall/resp v0.1.1 // indirect
)
26
plugins/wasm-go/extensions/ai-json-resp/go.sum
Normal file
@@ -0,0 +1,26 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240318034951-d5306e367c43/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/santhosh-tekuri/jsonschema v1.2.4 h1:hNhW8e7t+H1vgY+1QeEQpveR6D4+OwKPXCfD2aieJis=
github.com/santhosh-tekuri/jsonschema v1.2.4/go.mod h1:TEAUOeZSmIxTTuHatJzrvARHiuO9LYd+cIxzgEHCQI4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
573
plugins/wasm-go/extensions/ai-json-resp/main.go
Normal file
@@ -0,0 +1,573 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
	"encoding/json"
	"errors"
	"net/http"
	"strconv"
	"strings"

	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
	"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
	"github.com/santhosh-tekuri/jsonschema"
	"github.com/tidwall/gjson"

	"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)

const (
	DEFAULT_SCHEMA                    = "defaultSchema"
	HTTP_STATUS_OK                    = uint32(200)
	HTTP_STATUS_INTERNAL_SERVER_ERROR = uint32(500)
	FROM_THIS_PLUGIN_KEY              = "fromThisPlugin"
	EXTEND_HEADER_KEY                 = "X-HIGRESS-AI-JSON-RESP"

	JSON_SCHEMA_INVALID_CODE          = 1001
	JSON_SCHEMA_COMPILE_FAILED_CODE   = 1002
	CANNOT_FIND_JSON_IN_RESPONSE_CODE = 1003
	CONTENT_IS_EMPTY_CODE             = 1004
	JSON_MISMATCH_SCHEMA_CODE         = 1005
	REACH_MAX_RETRY_COUNT_CODE        = 1006
	SERVICE_UNAVAILABLE_CODE          = 1007
	SERVICE_CONFIG_INVALID_CODE       = 1008
)

type RejectStruct struct {
	RejectCode uint32 `json:"Code"`
	RejectMsg  string `json:"Msg"`
}

func (r RejectStruct) GetBytes() []byte {
	jsonData, _ := json.Marshal(r)
	return jsonData
}

func (r RejectStruct) GetShortMsg() string {
	return "ai-json-resp." + strings.Split(r.RejectMsg, ":")[0]
}

type PluginConfig struct {
	// @Title en-US Service name
	// @Description en-US Name of the service to call (the gateway or another AI service)
	serviceName string `required:"true" json:"serviceName" yaml:"serviceName"`
	// @Title en-US Service domain
	// @Description en-US Domain of the service to call
	serviceDomain string `required:"false" json:"serviceDomain" yaml:"serviceDomain"`
	// @Title en-US Service port
	// @Description en-US Port of the service to call
	servicePort int `required:"false" json:"servicePort" yaml:"servicePort"`
	// @Title en-US Service URL
	// @Description en-US URL of the service to call; used to fill serviceDomain and servicePath when they are not set
	serviceUrl string `required:"false" json:"serviceUrl" yaml:"serviceUrl"`
	// @Title en-US API Key
	// @Description en-US API key of the service; required when calling an AI service directly
	apiKey string `required:"false" json:"apiKey" yaml:"apiKey"`
	// @Title en-US Request endpoint
	// @Description en-US Endpoint of the service to call, "/v1/chat/completions" by default
	servicePath string `required:"false" json:"servicePath" yaml:"servicePath"`
	// @Title en-US Service timeout
	// @Description en-US Timeout for requests to the service
	serviceTimeout int `required:"false" json:"serviceTimeout" yaml:"serviceTimeout"`
	// @Title en-US Maximum retries
	// @Description en-US Maximum number of retries when calling the service
	maxRetry int `required:"false" json:"maxRetry" yaml:"maxRetry"`
	// @Title en-US Content path
	// @Description en-US GJSON path used to extract the JSON from the AI service response
	contentPath string `required:"false" json:"contentPath" yaml:"contentPath"`
	// @Title en-US Json Schema
	// @Description en-US Json Schema used to validate the response JSON; if empty, only JSON validity is checked
	jsonSchema map[string]interface{} `required:"false" json:"jsonSchema" yaml:"jsonSchema"`
	// @Title en-US Enable swagger
	// @Description en-US Whether to use swagger for Json Schema validation
	enableSwagger bool `required:"false" json:"enableSwagger" yaml:"enableSwagger"`
	// @Title en-US Enable oas3
	// @Description en-US Whether to use oas3 for Json Schema validation
	enableOas3 bool `required:"false" json:"enableOas3" yaml:"enableOas3"`
	// @Title en-US Enable Content-Disposition
	// @Description en-US Whether to add Content-Disposition: attachment; filename="response.json" to the response headers
	enableContentDisposition bool `required:"false" json:"enableContentDisposition" yaml:"enableContentDisposition"`

	serviceClient              wrapper.HttpClient
	draft                      *jsonschema.Draft
	compiler                   *jsonschema.Compiler
	compile                    *jsonschema.Schema
	rejectStruct               RejectStruct
	jsonSchemaMaxDepth         int
	enableJsonSchemaValidation bool
}

func main() {
	wrapper.SetCtx(
		"ai-json-resp",
		wrapper.ParseConfigBy(parseConfig),
		wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
		wrapper.ProcessRequestBodyBy(onHttpRequestBody),
	)
}

type RequestContext struct {
	Path            string
	ReqHeaders      [][2]string
	ReqBody         []byte
	RespHeader      [][2]string
	RespBody        []byte
	HistoryMessages []chatMessage
}

// parseUrl splits a URL into its host part and its path part.
// The scheme (http:// or https://) is dropped; the path keeps its leading "/".
func parseUrl(url string) (string, string) {
	if url == "" {
		return "", ""
	}
	url = strings.TrimPrefix(url, "http://")
	url = strings.TrimPrefix(url, "https://")
	index := strings.Index(url, "/")
	if index == -1 {
		return url, ""
	}
	return url[:index], url[index:]
}

func parseConfig(result gjson.Result, config *PluginConfig, log wrapper.Log) error {
	config.serviceName = result.Get("serviceName").String()
	config.serviceUrl = result.Get("serviceUrl").String()
	config.serviceDomain = result.Get("serviceDomain").String()
	config.servicePath = result.Get("servicePath").String()
	config.servicePort = int(result.Get("servicePort").Int())
	if config.serviceUrl != "" {
		domain, url := parseUrl(config.serviceUrl)
		log.Debugf("serviceUrl: %s, the parsed domain: %s, the parsed url: %s", config.serviceUrl, domain, url)
		if config.serviceDomain == "" {
			config.serviceDomain = domain
		}
		if config.servicePath == "" {
			config.servicePath = url
		}
	}
	if config.servicePort == 0 {
		config.servicePort = 443
	}
	config.serviceTimeout = int(result.Get("serviceTimeout").Int())
	config.apiKey = result.Get("apiKey").String()
	config.rejectStruct = RejectStruct{HTTP_STATUS_OK, ""}
	if config.serviceTimeout == 0 {
		config.serviceTimeout = 50000
	}
	config.maxRetry = int(result.Get("maxRetry").Int())
	if config.maxRetry == 0 {
		config.maxRetry = 3
	}
	config.contentPath = result.Get("contentPath").String()
	if config.contentPath == "" {
		config.contentPath = "choices.0.message.content"
	}

	if jsonSchemaValue := result.Get("jsonSchema"); jsonSchemaValue.Exists() {
		if schemaValue, ok := jsonSchemaValue.Value().(map[string]interface{}); ok {
			config.jsonSchema = schemaValue
		} else {
			config.rejectStruct = RejectStruct{JSON_SCHEMA_INVALID_CODE, "Json Schema is not valid"}
		}
	} else {
		config.jsonSchema = nil
	}

	if config.serviceDomain == "" {
		// an empty domain is a service configuration error (1008), not a schema error
		config.rejectStruct = RejectStruct{SERVICE_CONFIG_INVALID_CODE, "service domain is empty"}
	}

	config.serviceClient = wrapper.NewClusterClient(wrapper.DnsCluster{
		ServiceName: config.serviceName,
		Port:        int64(config.servicePort),
		Domain:      config.serviceDomain,
	})

	enableSwagger := result.Get("enableSwagger").Bool()
	enableOas3 := result.Get("enableOas3").Bool()

	// set draft version
	if enableSwagger {
		config.draft = jsonschema.Draft4
	}
	if enableOas3 {
		config.draft = jsonschema.Draft7
	}
	if !enableSwagger && !enableOas3 {
		config.draft = jsonschema.Draft7
	}

	// create compiler
	compiler := jsonschema.NewCompiler()
	compiler.Draft = config.draft
	config.compiler = compiler

	// set max depth of json schema
	config.jsonSchemaMaxDepth = 6

	enableContentDispositionValue := result.Get("enableContentDisposition")
	if !enableContentDispositionValue.Exists() {
		config.enableContentDisposition = true
	} else {
		config.enableContentDisposition = enableContentDispositionValue.Bool()
	}

	config.enableJsonSchemaValidation = true

	jsonSchemaBytes, err := json.Marshal(config.jsonSchema)
	if err != nil {
		config.rejectStruct = RejectStruct{JSON_SCHEMA_INVALID_CODE, "Json Schema marshal failed"}
		return err
	}

	maxDepth := GetMaxDepth(config.jsonSchema)
	log.Debugf("max depth of json schema: %d", maxDepth)
	if maxDepth > config.jsonSchemaMaxDepth {
		config.enableJsonSchemaValidation = false
		log.Infof("Json Schema depth exceeded: %d from %d , Json Schema validation will not be used.", maxDepth, config.jsonSchemaMaxDepth)
	}

	// compile only when a schema is configured; without one the plugin
	// just checks that the response is valid JSON
	if config.jsonSchema != nil && config.enableJsonSchemaValidation {
		jsonSchemaStr := string(jsonSchemaBytes)
		config.compiler.AddResource(DEFAULT_SCHEMA, strings.NewReader(jsonSchemaStr))
		// Test if the Json Schema is valid
		compile, err := config.compiler.Compile(DEFAULT_SCHEMA)
		if err != nil {
			log.Infof("Json Schema compile failed: %v", err)
			config.rejectStruct = RejectStruct{JSON_SCHEMA_COMPILE_FAILED_CODE, "Json Schema compile failed: " + err.Error()}
			config.compile = nil
		} else {
			config.compile = compile
		}
	}

	return nil
}

func (r *RequestContext) assembleReqBody(config PluginConfig) []byte {
	var reqBodystrut chatCompletionRequest
	json.Unmarshal(r.ReqBody, &reqBodystrut)
	content := gjson.ParseBytes(r.RespBody).Get(config.contentPath).String()
	jsonSchemaBytes, _ := json.Marshal(config.jsonSchema)
	jsonSchemaStr := string(jsonSchemaBytes)

	askQuestion := "Given the Json Schema: " + jsonSchemaStr + ", please help me convert the following content to a pure json: " + content
	askQuestion += "\n Do not respond other content except the pure json!!!!"

	reqBodystrut.Messages = append(r.HistoryMessages, []chatMessage{
		{
			Role:    "user",
			Content: askQuestion,
		},
	}...)

	reqBody, _ := json.Marshal(reqBodystrut)
	return reqBody
}

func (r *RequestContext) SaveBodyToHistMsg(log wrapper.Log, reqBody []byte, respBody []byte) {
	r.RespBody = respBody
	lastUserMessage := ""
	lastSystemMessage := ""

	var reqBodystrut chatCompletionRequest
	err := json.Unmarshal(reqBody, &reqBodystrut)
	if err != nil {
		log.Debugf("unmarshal reqBody failed: %v", err)
	} else {
		if len(reqBodystrut.Messages) != 0 {
			lastUserMessage = reqBodystrut.Messages[len(reqBodystrut.Messages)-1].Content
		}
	}

	var respBodystrut chatCompletionResponse
	err = json.Unmarshal(respBody, &respBodystrut)
	if err != nil {
		log.Debugf("unmarshal respBody failed: %v", err)
	} else {
		if len(respBodystrut.Choices) != 0 {
			lastSystemMessage = respBodystrut.Choices[len(respBodystrut.Choices)-1].Message.Content
		}
	}

	if lastUserMessage != "" {
		r.HistoryMessages = append(r.HistoryMessages, chatMessage{
			Role:    "user",
			Content: lastUserMessage,
		})
	}

	if lastSystemMessage != "" {
		r.HistoryMessages = append(r.HistoryMessages, chatMessage{
			Role:    "system",
			Content: lastSystemMessage,
		})
	}
}

func (r *RequestContext) SaveStrToHistMsg(log wrapper.Log, errMsg string) {
	r.HistoryMessages = append(r.HistoryMessages, chatMessage{
		Role:    "system",
		Content: errMsg,
	})
}

func (c *PluginConfig) ValidateBody(body []byte) error {
	var respJsonStrct chatCompletionResponse
	err := json.Unmarshal(body, &respJsonStrct)
	if err != nil {
		c.rejectStruct = RejectStruct{SERVICE_UNAVAILABLE_CODE, "service unavailable: " + string(body)}
		return errors.New(c.rejectStruct.RejectMsg)
	}
	content := gjson.ParseBytes(body).Get(c.contentPath)
	if !content.Exists() {
		c.rejectStruct = RejectStruct{SERVICE_UNAVAILABLE_CODE, "response body does not contain the content: " + string(body)}
		return errors.New(c.rejectStruct.RejectMsg)
	}
	return nil
}

func (c *PluginConfig) ValidateJson(body []byte, log wrapper.Log) (string, error) {
	content := gjson.ParseBytes(body).Get(c.contentPath).String()
	// first extract json from response body
	if content == "" {
		log.Infof("response body does not contain the content")
		c.rejectStruct = RejectStruct{CONTENT_IS_EMPTY_CODE, "response body does not contain the content"}
		return "", errors.New(c.rejectStruct.RejectMsg)
	}
	jsonStr, err := c.ExtractJson(content)

	if err != nil {
		log.Infof("response body does not contain the valid json: %v", err.Error())
		c.rejectStruct = RejectStruct{CANNOT_FIND_JSON_IN_RESPONSE_CODE, "response body does not contain the valid json: " + err.Error()}
		return "", errors.New(c.rejectStruct.RejectMsg)
	}

	if c.jsonSchema != nil && c.enableJsonSchemaValidation {
		compile, err := c.compiler.Compile(DEFAULT_SCHEMA)
		if err != nil {
			log.Infof("Json Schema compile failed: %v", err)
			c.rejectStruct = RejectStruct{JSON_SCHEMA_COMPILE_FAILED_CODE, "Json Schema compile failed: " + err.Error()}
			c.compile = nil
			// return here: validating with a nil schema would dereference a nil pointer
			return "", errors.New(c.rejectStruct.RejectMsg)
		}
		c.compile = compile

		// validate the json
		err = c.compile.Validate(strings.NewReader(jsonStr))
		if err != nil {
			log.Infof("response body does not match the Json Schema: %v", err)
			c.rejectStruct = RejectStruct{JSON_MISMATCH_SCHEMA_CODE, "response body does not match the Json Schema: " + err.Error()}
			return "", errors.New(c.rejectStruct.RejectMsg)
		}
	}
	c.rejectStruct = RejectStruct{HTTP_STATUS_OK, ""}
	return jsonStr, nil
}

func (c *PluginConfig) ExtractJson(bodyStr string) (string, error) {
	// simply extract json from response body string
	startIndex := strings.Index(bodyStr, "{")
	endIndex := strings.LastIndex(bodyStr, "}") + 1

	// if not found (LastIndex returns -1 when "}" is absent, so endIndex is 0)
	if startIndex == -1 || endIndex == 0 || startIndex >= endIndex {
		return "", errors.New("cannot find json in the response body")
	}

	jsonStr := bodyStr[startIndex:endIndex]

	// attempt to parse the JSON
	var result map[string]interface{}
	err := json.Unmarshal([]byte(jsonStr), &result)
	if err != nil {
		return "", err
	}
	return jsonStr, nil
}

func sendResponse(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log, body []byte) {
	log.Infof("Final send: Code %d, Message %s, Body: %s", config.rejectStruct.RejectCode, config.rejectStruct.RejectMsg, string(body))
	header := [][2]string{
		{"Content-Type", "application/json"},
	}
	if body != nil && config.enableContentDisposition {
		header = append(header, [2]string{"Content-Disposition", "attachment; filename=\"response.json\""})
	}
	if config.rejectStruct.RejectCode != HTTP_STATUS_OK {
		proxywasm.SendHttpResponseWithDetail(HTTP_STATUS_INTERNAL_SERVER_ERROR, config.rejectStruct.GetShortMsg(), nil, config.rejectStruct.GetBytes(), -1)
	} else {
		proxywasm.SendHttpResponse(HTTP_STATUS_OK, header, body, -1)
	}
}

func recursiveRefineJson(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log, retryCount int, requestContext *RequestContext) {
	// if retry count exceeds max retry count, return the response
	if retryCount >= config.maxRetry {
		log.Debugf("retry count exceeds max retry count")
		// report more useful error by appending the last of previous error message
		config.rejectStruct = RejectStruct{REACH_MAX_RETRY_COUNT_CODE, "retry count exceeds max retry count: " + config.rejectStruct.RejectMsg}
		sendResponse(ctx, config, log, nil)
		return
	}

	// recursively refine json
	config.serviceClient.Post(requestContext.Path, requestContext.ReqHeaders, requestContext.assembleReqBody(config),
		func(statusCode int, responseHeaders http.Header, responseBody []byte) {
			err := config.ValidateBody(responseBody)
			if err != nil {
				sendResponse(ctx, config, log, nil)
				return
			}
			retryCount++
			requestContext.SaveBodyToHistMsg(log, requestContext.assembleReqBody(config), responseBody)
			log.Debugf("[retry request %d/%d] resp code: %d", retryCount, config.maxRetry, statusCode)
			validateJson, err := config.ValidateJson(responseBody, log)
			if err == nil {
				sendResponse(ctx, config, log, []byte(validateJson))
			} else {
				requestContext.SaveStrToHistMsg(log, err.Error())
				recursiveRefineJson(ctx, config, log, retryCount, requestContext)
			}
		}, uint32(config.serviceTimeout))
}

func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
	if config.rejectStruct.RejectCode != HTTP_STATUS_OK {
		sendResponse(ctx, config, log, nil)
		return types.ActionPause
	}

	// verify if the request is from this plugin
	extendHeaderValue, err := proxywasm.GetHttpRequestHeader(EXTEND_HEADER_KEY)
	if err == nil {
		fromThisPlugin, convErr := strconv.ParseBool(extendHeaderValue)
		if convErr != nil {
			log.Debugf("failed to parse header value as bool: %v", convErr)
			ctx.SetContext(FROM_THIS_PLUGIN_KEY, false)
		}
		if fromThisPlugin {
			ctx.SetContext(FROM_THIS_PLUGIN_KEY, true)
			return types.ActionContinue
		}
	} else {
		ctx.SetContext(FROM_THIS_PLUGIN_KEY, false)
	}

	path, err := proxywasm.GetHttpRequestHeader(":path")
	if err != nil {
		log.Infof("get request path failed: %v", err)
		path = ""
	} else {
		ctx.SetContext("path", path)
	}

	headers, err := proxywasm.GetHttpRequestHeaders()
	if err != nil {
		log.Infof("get request header failed: %v", err)
	}

	apiKey, err := proxywasm.GetHttpRequestHeader("Authorization")
	if err != nil {
		log.Infof("get request header failed: %v", err)
		apiKey = ""
	}
	if apiKey != "" {
		// remove the Authorization header
		proxywasm.RemoveHttpRequestHeader("Authorization")
		// remove the Authorization header from the headers
		for i, header := range headers {
			if header[0] == "Authorization" {
				headers = append(headers[:i], headers[i+1:]...)
				break
			}
		}
	}
	if config.apiKey != "" {
		// do not write the API key itself to the log
		log.Debugf("add Authorization header from the configured apiKey")
		headers = append(headers, [2]string{"Authorization", "Bearer " + config.apiKey})
	}
	ctx.SetContext("headers", headers)

	return types.ActionContinue
}
|
||||||
|
|
||||||
|
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
	// If the request was issued by this plugin, let it through to the AI service.
	fromThisPlugin, ok := ctx.GetContext(FROM_THIS_PLUGIN_KEY).(bool)
	if ok && fromThisPlugin {
		log.Debugf("detected buffer_request, sending request to AI service")
		return types.ActionContinue
	}

	var headers [][2]string
	if h, ok := ctx.GetContext("headers").([][2]string); ok {
		headers = append(h, [2]string{EXTEND_HEADER_KEY, "true"})
	} else {
		log.Debugf("cannot get headers from context, use default headers")
		headers = [][2]string{
			{"Content-Type", "application/json"},
			{EXTEND_HEADER_KEY, "true"},
		}
	}

	// If there is any error in the config, return the response directly.
	if config.rejectStruct.RejectCode != HTTP_STATUS_OK {
		sendResponse(ctx, config, log, nil)
		return types.ActionContinue
	}

	// Resolve the upstream path: the path saved in the header phase, else the
	// default; config.servicePath takes precedence when it is set.
	path, ok := ctx.GetContext("path").(string)
	if ok {
		log.Debugf("use path: %s", path)
	} else {
		log.Debugf("cannot get path from context, use default path")
		path = "/v1/chat/completions"
	}

	if config.servicePath != "" {
		log.Debugf("use base path: %s", config.servicePath)
		path = config.servicePath
	}

	requestContext := &RequestContext{
		Path:       path,
		ReqHeaders: headers,
		ReqBody:    body,
	}

	config.serviceClient.Post(requestContext.Path, requestContext.ReqHeaders, requestContext.ReqBody,
		func(statusCode int, responseHeaders http.Header, responseBody []byte) {
			err := config.ValidateBody(responseBody)
			if err != nil {
				sendResponse(ctx, config, log, nil)
				return
			}
			requestContext.SaveBodyToHistMsg(log, body, responseBody)
			log.Debugf("[first request] resp code: %d", statusCode)
			validateJson, err := config.ValidateJson(responseBody, log)
			if err == nil {
				sendResponse(ctx, config, log, []byte(validateJson))
				return
			}
			// The first answer failed validation: record the error and start retrying.
			requestContext.SaveStrToHistMsg(log, err.Error())
			recursiveRefineJson(ctx, config, log, 0, requestContext)
		}, uint32(config.serviceTimeout))

	return types.ActionPause
}
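The retry flow implemented by onHttpRequestBody and recursiveRefineJson above can be sketched in a synchronous form. This is only an illustration under stated assumptions: `askFunc` and `refineJSON` are hypothetical names, the real plugin calls the service asynchronously through `config.serviceClient.Post` with a callback, and validity is reduced here to `json.Valid` instead of the plugin's schema validation. The sketch assumes one initial attempt followed by at most `maxRetry` retries, with each failure appended to the history so the next attempt can see it.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

// askFunc stands in for the call to the AI service (hypothetical; not part of
// the plugin, which uses an asynchronous HTTP client instead).
type askFunc func(history []string) string

// refineJSON asks once, then retries up to maxRetry more times. Every invalid
// answer and its error message are appended to the history before retrying.
// It returns the valid answer and the number of retries that were needed.
func refineJSON(ask askFunc, prompt string, maxRetry int) (string, int, error) {
	history := []string{prompt}
	lastErr := errors.New("no attempt made")
	for attempt := 0; attempt <= maxRetry; attempt++ {
		answer := ask(history)
		if json.Valid([]byte(answer)) {
			return answer, attempt, nil
		}
		lastErr = fmt.Errorf("invalid JSON: %q", answer)
		history = append(history, answer, lastErr.Error())
	}
	return "", maxRetry, fmt.Errorf("retry count exceeds max retry count: %w", lastErr)
}

func main() {
	// A fake service that only answers with valid JSON on its third call.
	calls := 0
	fake := func(history []string) string {
		calls++
		if calls < 3 {
			return "not json"
		}
		return `{"ok": true}`
	}
	out, retries, err := refineJSON(fake, "reply in JSON", 3)
	fmt.Println(out, retries, err) // prints: {"ok": true} 2 <nil>
}
```

Keeping the last error in the final message mirrors what the plugin does when it appends the previous RejectMsg to the "retry count exceeds max retry count" response.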
180
plugins/wasm-go/extensions/ai-json-resp/model.go
Normal file
@@ -0,0 +1,180 @@
// Adapted from https://github.com/alibaba/higress/blob/main/plugins/wasm-go/extensions/ai-proxy/provider/model.go
package main

import "strings"

const (
	streamEventIdItemKey        = "id:"
	streamEventNameItemKey      = "event:"
	streamBuiltInItemKey        = ":"
	streamHttpStatusValuePrefix = "HTTP_STATUS/"
	streamDataItemKey           = "data:"
	streamEndDataValue          = "[DONE]"
)

type chatCompletionRequest struct {
	Model            string                 `json:"model"`
	Messages         []chatMessage          `json:"messages"`
	MaxTokens        int                    `json:"max_tokens,omitempty"`
	FrequencyPenalty float64                `json:"frequency_penalty,omitempty"`
	N                int                    `json:"n,omitempty"`
	PresencePenalty  float64                `json:"presence_penalty,omitempty"`
	Seed             int                    `json:"seed,omitempty"`
	Stream           bool                   `json:"stream,omitempty"`
	StreamOptions    *streamOptions         `json:"stream_options,omitempty"`
	Temperature      float64                `json:"temperature,omitempty"`
	TopP             float64                `json:"top_p,omitempty"`
	Tools            []tool                 `json:"tools,omitempty"`
	ToolChoice       *toolChoice            `json:"tool_choice,omitempty"`
	User             string                 `json:"user,omitempty"`
	Stop             []string               `json:"stop,omitempty"`
	ResponseFormat   map[string]interface{} `json:"response_format,omitempty"`
}

type streamOptions struct {
	IncludeUsage bool `json:"include_usage,omitempty"`
}

type tool struct {
	Type     string   `json:"type"`
	Function function `json:"function"`
}

type function struct {
	Description string                 `json:"description,omitempty"`
	Name        string                 `json:"name"`
	Parameters  map[string]interface{} `json:"parameters,omitempty"`
}

type toolChoice struct {
	Type     string   `json:"type"`
	Function function `json:"function"`
}

type chatCompletionResponse struct {
	Id                string                 `json:"id,omitempty"`
	Choices           []chatCompletionChoice `json:"choices"`
	Created           int64                  `json:"created,omitempty"`
	Model             string                 `json:"model,omitempty"`
	SystemFingerprint string                 `json:"system_fingerprint,omitempty"`
	Object            string                 `json:"object,omitempty"`
	Usage             usage                  `json:"usage,omitempty"`
}

type chatCompletionChoice struct {
	Index        int          `json:"index"`
	Message      *chatMessage `json:"message,omitempty"`
	Delta        *chatMessage `json:"delta,omitempty"`
	FinishReason string       `json:"finish_reason,omitempty"`
}

type usage struct {
	PromptTokens     int `json:"prompt_tokens,omitempty"`
	CompletionTokens int `json:"completion_tokens,omitempty"`
	TotalTokens      int `json:"total_tokens,omitempty"`
}

type chatMessage struct {
	Name      string     `json:"name,omitempty"`
	Role      string     `json:"role,omitempty"`
	Content   string     `json:"content,omitempty"`
	ToolCalls []toolCall `json:"tool_calls,omitempty"`
}

// IsEmpty reports whether the message has neither content nor any
// non-empty tool call.
func (m *chatMessage) IsEmpty() bool {
	if m.Content != "" {
		return false
	}
	for _, toolCall := range m.ToolCalls {
		if !toolCall.Function.IsEmpty() {
			return false
		}
	}
	return true
}

type toolCall struct {
	Index    int          `json:"index"`
	Id       string       `json:"id"`
	Type     string       `json:"type"`
	Function functionCall `json:"function"`
}

type functionCall struct {
	Id        string `json:"id"`
	Name      string `json:"name"`
	Arguments string `json:"arguments"`
}

// IsEmpty reports whether the call has neither a name nor arguments.
func (m *functionCall) IsEmpty() bool {
	return m.Name == "" && m.Arguments == ""
}

type streamEvent struct {
	Id         string `json:"id"`
	Event      string `json:"event"`
	Data       string `json:"data"`
	HttpStatus string `json:"http_status"`
}

// setValue stores value in the field selected by the stream item key.
// Built-in items (key ":") are only used to carry the HTTP status.
func (e *streamEvent) setValue(key, value string) {
	switch key {
	case streamEventIdItemKey:
		e.Id = value
	case streamEventNameItemKey:
		e.Event = value
	case streamDataItemKey:
		e.Data = value
	case streamBuiltInItemKey:
		if strings.HasPrefix(value, streamHttpStatusValuePrefix) {
			e.HttpStatus = value[len(streamHttpStatusValuePrefix):]
		}
	}
}

type embeddingsRequest struct {
	Input          interface{} `json:"input"`
	Model          string      `json:"model"`
	EncodingFormat string      `json:"encoding_format,omitempty"`
	Dimensions     int         `json:"dimensions,omitempty"`
	User           string      `json:"user,omitempty"`
}

type embeddingsResponse struct {
	Object string      `json:"object"`
	Data   []embedding `json:"data"`
	Model  string      `json:"model"`
	Usage  usage       `json:"usage"`
}

type embedding struct {
	Object    string    `json:"object"`
	Index     int       `json:"index"`
	Embedding []float64 `json:"embedding"`
}

// ParseInput normalizes Input, which may be a single string or a list, into a
// slice of strings. Non-string list items are skipped; any other input type
// yields nil.
func (r embeddingsRequest) ParseInput() []string {
	if r.Input == nil {
		return nil
	}
	var input []string
	switch v := r.Input.(type) {
	case string:
		input = []string{v}
	case []any:
		input = make([]string, 0, len(v))
		for _, item := range v {
			if str, ok := item.(string); ok {
				input = append(input, str)
			}
		}
	}
	return input
}
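ParseInput in model.go above handles both a `string` and a `[]any` because `encoding/json` decodes a JSON string into a Go `string` and a JSON array into `[]interface{}` when the target field is an `interface{}`. The following is a minimal sketch of that behavior, assuming a trimmed copy of `embeddingsRequest` and a hypothetical helper named `parse`; neither the helper nor the sample bodies come from the plugin.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed copy of embeddingsRequest, kept only so the sketch is self-contained.
type embeddingsRequest struct {
	Input interface{} `json:"input"`
	Model string      `json:"model"`
}

// ParseInput mirrors the method in model.go: a string becomes a one-element
// slice, a list keeps only its string items, anything else yields nil.
func (r embeddingsRequest) ParseInput() []string {
	switch v := r.Input.(type) {
	case string:
		return []string{v}
	case []any:
		input := make([]string, 0, len(v))
		for _, item := range v {
			if str, ok := item.(string); ok {
				input = append(input, str)
			}
		}
		return input
	}
	return nil
}

// parse is a hypothetical helper: decode a request body, then normalize input.
func parse(body string) ([]string, error) {
	var req embeddingsRequest
	if err := json.Unmarshal([]byte(body), &req); err != nil {
		return nil, err
	}
	return req.ParseInput(), nil
}

func main() {
	for _, body := range []string{
		`{"model": "m", "input": "hello"}`,
		`{"model": "m", "input": ["a", "b", 1]}`,
	} {
		got, err := parse(body)
		fmt.Println(got, err)
	}
	// prints:
	// [hello] <nil>
	// [a b] <nil>
}
```

Note that the numeric item in the second body is dropped silently, so a caller that needs strict validation would have to compare the result length with the original list.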
33
plugins/wasm-go/extensions/ai-json-resp/util.go
Normal file
@@ -0,0 +1,33 @@
package main

// GetMaxDepth returns the maximum nesting depth of a decoded JSON value.
// A scalar, nil, or empty container has depth 1; each level of nesting
// inside a map or slice adds one. The traversal uses an explicit stack
// instead of recursion.
func GetMaxDepth(data interface{}) int {
	type item struct {
		value interface{}
		depth int
	}

	maxDepth := 0
	stack := []item{{value: data, depth: 1}}

	for len(stack) > 0 {
		// Pop the most recently pushed item.
		currentItem := stack[len(stack)-1]
		stack = stack[:len(stack)-1]

		if currentItem.depth > maxDepth {
			maxDepth = currentItem.depth
		}

		// Push the children of maps and slices one level deeper.
		switch v := currentItem.value.(type) {
		case map[string]interface{}:
			for _, value := range v {
				stack = append(stack, item{value: value, depth: currentItem.depth + 1})
			}
		case []interface{}:
			for _, value := range v {
				stack = append(stack, item{value: value, depth: currentItem.depth + 1})
			}
		}
	}

	return maxDepth
}
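To make the depth convention of GetMaxDepth above concrete, here is a small usage sketch. It assumes the value comes from `json.Unmarshal` into an `interface{}`, which produces exactly the `map[string]interface{}` and `[]interface{}` types the function switches on. The helper `depthOf` and the sample documents are hypothetical and only serve the illustration.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// GetMaxDepth is copied from util.go so the sketch is self-contained.
func GetMaxDepth(data interface{}) int {
	type item struct {
		value interface{}
		depth int
	}
	maxDepth := 0
	stack := []item{{value: data, depth: 1}}
	for len(stack) > 0 {
		currentItem := stack[len(stack)-1]
		stack = stack[:len(stack)-1]
		if currentItem.depth > maxDepth {
			maxDepth = currentItem.depth
		}
		switch v := currentItem.value.(type) {
		case map[string]interface{}:
			for _, value := range v {
				stack = append(stack, item{value: value, depth: currentItem.depth + 1})
			}
		case []interface{}:
			for _, value := range v {
				stack = append(stack, item{value: value, depth: currentItem.depth + 1})
			}
		}
	}
	return maxDepth
}

// depthOf is a hypothetical helper: decode a JSON document and measure it.
func depthOf(doc string) (int, error) {
	var v interface{}
	if err := json.Unmarshal([]byte(doc), &v); err != nil {
		return 0, err
	}
	return GetMaxDepth(v), nil
}

func main() {
	for _, doc := range []string{`42`, `{}`, `{"a": 1}`, `{"a": {"b": [1, 2]}}`} {
		d, err := depthOf(doc)
		fmt.Println(doc, d, err)
	}
	// prints:
	// 42 1 <nil>
	// {} 1 <nil>
	// {"a": 1} 2 <nil>
	// {"a": {"b": [1, 2]}} 4 <nil>
}
```

Leaf values count as a level of their own, which is why the last document reports 4 rather than 3.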