mirror of
https://github.com/alibaba/higress.git
synced 2026-03-01 23:20:52 +08:00
[ai-json-resp] Extract JSON from LLM, Validate with Schema, Ensure Valid JSON, Auto-Retry (#1236)
This commit is contained in:
202
plugins/wasm-go/extensions/ai-json-resp/README.md
Normal file
202
plugins/wasm-go/extensions/ai-json-resp/README.md
Normal file
@@ -0,0 +1,202 @@
|
||||
## 简介
|
||||
|
||||
**Note**
|
||||
|
||||
> 需要数据面的proxy wasm版本大于等于0.2.100
|
||||
>
|
||||
|
||||
> 编译时,需要带上版本的tag,例如:tinygo build -o main.wasm -scheduler=none -target=wasi -gc=custom -tags="custommalloc nottinygc_finalizer proxy_wasm_version_0_2_100" ./
|
||||
|
||||
|
||||
LLM响应结构化插件,用于根据默认或用户配置的Json Schema对AI的响应进行结构化,以便后续插件处理。注意目前只支持 `非流式响应`。
|
||||
|
||||
|
||||
### 配置说明
|
||||
|
||||
| Name | Type | Requirement | Default | **Description** |
|
||||
| --- | --- | --- | --- | --- |
|
||||
| serviceName | str | required | - | AI服务或支持AI-Proxy的网关服务名称 |
|
||||
| serviceDomain | str | optional | - | AI服务或支持AI-Proxy的网关服务域名/IP地址 |
|
||||
| servicePath | str | optional | '/v1/chat/completions' | AI服务或支持AI-Proxy的网关服务基础路径 |
|
||||
| serviceUrl | str | optional | - | AI服务或支持 AI-Proxy 的网关服务URL, 插件将自动提取Domain 和 Path, 用于填充未配置的 serviceDomain 或 servicePath |
|
||||
| servicePort | int | optional | 443 | 网关服务端口 |
|
||||
| serviceTimeout | int | optional | 50000 | 默认请求超时时间 |
|
||||
| maxRetry | int | optional | 3 | 若回答无法正确提取格式化时重试次数 |
|
||||
| contentPath | str | optional | "choices.0.message.content” | 从LLM回答中提取响应结果的gpath路径 |
|
||||
| jsonSchema | str (json) | optional | - | 验证请求所参照的 jsonSchema, 为空只验证并返回合法Json格式响应 |
|
||||
| enableSwagger | bool | optional | false | 是否启用 Swagger 协议进行验证 |
|
||||
| enableOas3 | bool | optional | true | 是否启用 Oas3 协议进行验证 |
|
||||
| enableContentDisposition | bool | optional | true | 是否启用 Content-Disposition 头部, 若启用则会在响应头中添加 `Content-Disposition: attachment; filename="response.json"` |
|
||||
|
||||
> 出于性能考虑,默认支持的最大 Json Schema 深度为 6。超过此深度的 Json Schema 将不用于验证响应,插件只会检查返回的响应是否为合法的 Json 格式。
|
||||
|
||||
|
||||
### 请求和返回参数说明
|
||||
|
||||
- **请求参数**: 本插件请求格式为openai请求格式,包含`model`和`messages`字段,其中`model`为AI模型名称,`messages`为对话消息列表,每个消息包含`role`和`content`字段,`role`为消息角色,`content`为消息内容。
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{"role": "user", "content": "give me a api doc for add the variable x to x+5"}
|
||||
]
|
||||
}
|
||||
```
|
||||
其他请求参数需参考配置的ai服务或网关服务的相应文档。
|
||||
- **返回参数**:
|
||||
- 返回满足定义的Json Schema约束的 `Json格式响应`
|
||||
- 若未定义Json Schema,则返回合法的`Json格式响应`
|
||||
- 若出现内部错误,则返回 `{ "Code": 10XX, "Msg": "错误信息提示" }`。
|
||||
|
||||
## 请求示例
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:8001/v1/chat/completions" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{"role": "user", "content": "give me a api doc for add the variable x to x+5"}
|
||||
]
|
||||
}'
|
||||
|
||||
```
|
||||
|
||||
## 返回示例
|
||||
### 正常返回
|
||||
在正常情况下,系统应返回经过 JSON Schema 验证的 JSON 数据。如果未配置 JSON Schema,系统将返回符合 JSON 标准的合法 JSON 数据。
|
||||
```json
|
||||
{
|
||||
"apiVersion": "1.0",
|
||||
"request": {
|
||||
"endpoint": "/add_to_five",
|
||||
"method": "POST",
|
||||
"port": 8080,
|
||||
"headers": {
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
"body": {
|
||||
"x": 7
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 异常返回
|
||||
在发生错误时,返回状态码为 `500`,返回内容为 JSON 格式的错误信息。包含错误码 `Code` 和错误信息 `Msg` 两个字段。
|
||||
```json
|
||||
{
|
||||
"Code": 1006,
|
||||
"Msg": "retry count exceed max retry count"
|
||||
}
|
||||
```
|
||||
|
||||
### 错误码说明
|
||||
| 错误码 | 说明 |
|
||||
| --- | --- |
|
||||
| 1001 | 配置的Json Schema不是合法Json格式|
|
||||
| 1002 | 配置的Json Schema编译失败,不是合法的Json Schema 格式或深度超出 jsonSchemaMaxDepth 且 rejectOnDepthExceeded 为true|
|
||||
| 1003 | 无法在响应中提取合法的Json|
|
||||
| 1004 | 响应为空字符串|
|
||||
| 1005 | 响应不符合Json Schema定义|
|
||||
| 1006 | 重试次数超过最大限制|
|
||||
| 1007 | 无法获取响应内容,可能是上游服务配置错误或获取内容的ContentPath路径错误|
|
||||
| 1008 | serciveDomain为空, 请注意serviceDomian或serviceUrl不能同时为空|
|
||||
|
||||
## 服务配置说明
|
||||
本插件需要配置上游服务来支持出现异常时的自动重试机制, 支持的配置主要包括`支持openai接口的AI服务`或`本地网关服务`
|
||||
|
||||
### 支持openai接口的AI服务
|
||||
以qwen为例,基本配置如下:
|
||||
|
||||
Yaml格式配置如下
|
||||
```yaml
|
||||
serviceName: qwen
|
||||
serviceDomain: dashscope.aliyuncs.com
|
||||
apiKey: [Your API Key]
|
||||
servicePath: /compatible-mode/v1/chat/completions
|
||||
jsonSchema:
|
||||
title: ReasoningSchema
|
||||
type: object
|
||||
properties:
|
||||
reasoning_steps:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: The reasoning steps leading to the final conclusion.
|
||||
answer:
|
||||
type: string
|
||||
description: The final answer, taking into account the reasoning steps.
|
||||
required:
|
||||
- reasoning_steps
|
||||
- answer
|
||||
additionalProperties: false
|
||||
```
|
||||
|
||||
JSON 格式配置
|
||||
```json
|
||||
{
|
||||
"serviceName": "qwen",
|
||||
"serviceUrl": "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
|
||||
"apiKey": "[Your API Key]",
|
||||
"jsonSchema": {
|
||||
"title": "ActionItemsSchema",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action_items": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "Description of the action item."
|
||||
},
|
||||
"due_date": {
|
||||
"type": ["string", "null"],
|
||||
"description": "Due date for the action item, can be null if not specified."
|
||||
},
|
||||
"owner": {
|
||||
"type": ["string", "null"],
|
||||
"description": "Owner responsible for the action item, can be null if not specified."
|
||||
}
|
||||
},
|
||||
"required": ["description", "due_date", "owner"],
|
||||
"additionalProperties": false
|
||||
},
|
||||
"description": "List of action items from the meeting."
|
||||
}
|
||||
},
|
||||
"required": ["action_items"],
|
||||
"additionalProperties": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 本地网关服务
|
||||
为了能复用已经配置好的服务,本插件也支持配置本地网关服务。例如,若网关已经配置好了[AI-proxy服务](../ai-proxy/README.md),则可以直接配置如下:
|
||||
1. 创建一个固定IP为127.0.0.1的服务,例如localservice.static
|
||||
```yaml
|
||||
- name: outbound|10000||localservice.static
|
||||
connect_timeout: 30s
|
||||
type: LOGICAL_DNS
|
||||
dns_lookup_family: V4_ONLY
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: outbound|8001||localservice.static
|
||||
endpoints:
|
||||
- lb_endpoints:
|
||||
- endpoint:
|
||||
address:
|
||||
socket_address:
|
||||
address: 127.0.0.1
|
||||
port_value: 10000
|
||||
```
|
||||
2. 配置文件中添加localservice.static的服务配置
|
||||
```yaml
|
||||
serviceName: localservice
|
||||
serviceDomain: 127.0.0.1
|
||||
servicePort: 10000
|
||||
```
|
||||
3. 自动提取请求的Path,Header等信息
|
||||
插件会自动提取请求的Path,Header等信息,从而避免对AI服务的重复配置。
|
||||
21
plugins/wasm-go/extensions/ai-json-resp/go.mod
Normal file
21
plugins/wasm-go/extensions/ai-json-resp/go.mod
Normal file
@@ -0,0 +1,21 @@
|
||||
module github.com/alibaba/higress/plugins/wasm-go/extensions/hello-world
|
||||
|
||||
go 1.18
|
||||
|
||||
replace github.com/alibaba/higress/plugins/wasm-go => ../..
|
||||
|
||||
require (
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.4.2
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
|
||||
github.com/magefile/mage v1.14.0 // indirect
|
||||
github.com/santhosh-tekuri/jsonschema v1.2.4 // indirect
|
||||
github.com/tidwall/gjson v1.14.3 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
)
|
||||
26
plugins/wasm-go/extensions/ai-json-resp/go.sum
Normal file
26
plugins/wasm-go/extensions/ai-json-resp/go.sum
Normal file
@@ -0,0 +1,26 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240318034951-d5306e367c43/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240327114451-d6b7174a84fc/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f h1:ZIiIBRvIw62gA5MJhuwp1+2wWbqL9IGElQ499rUsYYg=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240711023527-ba358c48772f/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
|
||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/santhosh-tekuri/jsonschema v1.2.4 h1:hNhW8e7t+H1vgY+1QeEQpveR6D4+OwKPXCfD2aieJis=
|
||||
github.com/santhosh-tekuri/jsonschema v1.2.4/go.mod h1:TEAUOeZSmIxTTuHatJzrvARHiuO9LYd+cIxzgEHCQI4=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
|
||||
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
573
plugins/wasm-go/extensions/ai-json-resp/main.go
Normal file
573
plugins/wasm-go/extensions/ai-json-resp/main.go
Normal file
@@ -0,0 +1,573 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/santhosh-tekuri/jsonschema"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
const (
|
||||
DEFAULT_SCHEMA = "defaultSchema"
|
||||
HTTP_STATUS_OK = uint32(200)
|
||||
HTTP_STATUS_INTERNAL_SERVER_ERROR = uint32(500)
|
||||
FROM_THIS_PLUGIN_KEY = "fromThisPlugin"
|
||||
EXTEND_HEADER_KEY = "X-HIGRESS-AI-JSON-RESP"
|
||||
|
||||
JSON_SCHEMA_INVALID_CODE = 1001
|
||||
JSON_SCHEMA_COMPILE_FAILED_CODE = 1002
|
||||
CANNOT_FIND_JSON_IN_RESPONSE_CODE = 1003
|
||||
CONTENT_IS_EMPTY_CODE = 1004
|
||||
JSON_MISMATCH_SCHEMA_CODE = 1005
|
||||
REACH_MAX_RETRY_COUNT_CODE = 1006
|
||||
SERVICE_UNAVAILABLE_CODE = 1007
|
||||
SERVICE_CONFIG_INVALID_CODE = 1008
|
||||
)
|
||||
|
||||
type RejectStruct struct {
|
||||
RejectCode uint32 `json:"Code"`
|
||||
RejectMsg string `json:"Msg"`
|
||||
}
|
||||
|
||||
func (r RejectStruct) GetBytes() []byte {
|
||||
jsonData, _ := json.Marshal(r)
|
||||
return jsonData
|
||||
}
|
||||
|
||||
func (r RejectStruct) GetShortMsg() string {
|
||||
return "ai-json-resp." + strings.Split(r.RejectMsg, ":")[0]
|
||||
}
|
||||
|
||||
type PluginConfig struct {
|
||||
// @Title zh-CN 服务名称
|
||||
// @Description zh-CN 用以请求服务的名称(网关或其他AI服务)
|
||||
serviceName string `required:"true" json:"serviceName" yaml:"serviceName"`
|
||||
// @Title zh-CN 服务域名
|
||||
// @Description zh-CN 用以请求服务的域名
|
||||
serviceDomain string `required:"false" json:"serviceDomain" yaml:"serviceDomain"`
|
||||
// @Title zh-CN 服务端口
|
||||
// @Description zh-CN 用以请求服务的端口
|
||||
servicePort int `required:"false" json:"servicePort" yaml:"servicePort"`
|
||||
// @Title zh-CN 服务URL
|
||||
// @Description zh-CN 用以请求服务的URL,若提供则会覆盖serviceDomain和servicePort
|
||||
serviceUrl string `required:"false" json:"serviceUrl" yaml:"serviceUrl"`
|
||||
// @Title zh-CN API Key
|
||||
// @Description zh-CN 若使用AI服务,需要填写请求服务的API Key
|
||||
apiKey string `required:"false" json: "apiKey" yaml:"apiKey"`
|
||||
// @Title zh-CN 请求端点
|
||||
// @Description zh-CN 用以请求服务的端点, 默认为"/v1/chat/completions"
|
||||
servicePath string `required:"false" json: "servicePath" yaml:"servicePath"`
|
||||
// @Title zh-CN 服务超时时间
|
||||
// @Description zh-CN 用以请求服务的超时时间
|
||||
serviceTimeout int `required:"false" json:"serviceTimeout" yaml:"serviceTimeout"`
|
||||
// @Title zh-CN 最大重试次数
|
||||
// @Description zh-CN 用以请求服务的最大重试次数
|
||||
maxRetry int `required:"false" json:"maxRetry" yaml:"maxRetry"`
|
||||
// @Title zh-CN 内容路径
|
||||
// @Description zh-CN 从AI服务返回的响应中提取json的gpath路径
|
||||
contentPath string `required:"false" json:"contentPath" yaml:"contentPath"`
|
||||
// @Title zh-CN Json Schema
|
||||
// @Description zh-CN 用以验证响应json的Json Schema, 为空则只验证返回的响应是否为合法json
|
||||
jsonSchema map[string]interface{} `required:"false" json:"jsonSchema" yaml:"jsonSchema"`
|
||||
// @Title zh-CN 是否启用swagger
|
||||
// @Description zh-CN 是否启用swagger进行Json Schema验证
|
||||
enableSwagger bool `required:"false" json:"enableSwagger" yaml:"enableSwagger"`
|
||||
// @Title zh-CN 是否启用oas3
|
||||
// @Description zh-CN 是否启用oas3进行Json Schema验证
|
||||
enableOas3 bool `required:"false" json:"enableOas3" yaml:"enableOas3"`
|
||||
// @Title zh-CN 是否启用Content-Disposition
|
||||
// @Description zh-CN 是否启用Content-Disposition, 若启用则会在响应头中添加Content-Disposition: attachment; filename="response.json"
|
||||
enableContentDisposition bool `required:"false" json:"enableContentDisposition" yaml:"enableContentDisposition"`
|
||||
|
||||
serviceClient wrapper.HttpClient
|
||||
draft *jsonschema.Draft
|
||||
compiler *jsonschema.Compiler
|
||||
compile *jsonschema.Schema
|
||||
rejectStruct RejectStruct
|
||||
jsonSchemaMaxDepth int
|
||||
enableJsonSchemaValidation bool
|
||||
}
|
||||
|
||||
func main() {
|
||||
wrapper.SetCtx(
|
||||
"ai-json-resp",
|
||||
wrapper.ParseConfigBy(parseConfig),
|
||||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
||||
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
||||
)
|
||||
}
|
||||
|
||||
type RequestContext struct {
|
||||
Path string
|
||||
ReqHeaders [][2]string
|
||||
ReqBody []byte
|
||||
RespHeader [][2]string
|
||||
RespBody []byte
|
||||
HistoryMessages []chatMessage
|
||||
}
|
||||
|
||||
func parseUrl(url string) (string, string) {
|
||||
if url == "" {
|
||||
return "", ""
|
||||
}
|
||||
url = strings.TrimPrefix(url, "http://")
|
||||
url = strings.TrimPrefix(url, "https://")
|
||||
index := strings.Index(url, "/")
|
||||
if index == -1 {
|
||||
return url, ""
|
||||
}
|
||||
return url[:index], url[index:]
|
||||
}
|
||||
|
||||
func parseConfig(result gjson.Result, config *PluginConfig, log wrapper.Log) error {
|
||||
config.serviceName = result.Get("serviceName").String()
|
||||
config.serviceUrl = result.Get("serviceUrl").String()
|
||||
config.serviceDomain = result.Get("serviceDomain").String()
|
||||
config.servicePath = result.Get("servicePath").String()
|
||||
config.servicePort = int(result.Get("servicePort").Int())
|
||||
if config.serviceUrl != "" {
|
||||
domain, url := parseUrl(config.serviceUrl)
|
||||
log.Debugf("serviceUrl: %s, the parsed domain: %s, the parsed url: %s", config.serviceUrl, domain, url)
|
||||
if config.serviceDomain == "" {
|
||||
config.serviceDomain = domain
|
||||
}
|
||||
if config.servicePath == "" {
|
||||
config.servicePath = url
|
||||
}
|
||||
}
|
||||
if config.servicePort == 0 {
|
||||
config.servicePort = 443
|
||||
}
|
||||
config.serviceTimeout = int(result.Get("serviceTimeout").Int())
|
||||
config.apiKey = result.Get("apiKey").String()
|
||||
config.rejectStruct = RejectStruct{HTTP_STATUS_OK, ""}
|
||||
if config.serviceTimeout == 0 {
|
||||
config.serviceTimeout = 50000
|
||||
}
|
||||
config.maxRetry = int(result.Get("maxRetry").Int())
|
||||
if config.maxRetry == 0 {
|
||||
config.maxRetry = 3
|
||||
}
|
||||
config.contentPath = result.Get("contentPath").String()
|
||||
if config.contentPath == "" {
|
||||
config.contentPath = "choices.0.message.content"
|
||||
}
|
||||
|
||||
if jsonSchemaValue := result.Get("jsonSchema"); jsonSchemaValue.Exists() {
|
||||
if schemaValue, ok := jsonSchemaValue.Value().(map[string]interface{}); ok {
|
||||
config.jsonSchema = schemaValue
|
||||
|
||||
} else {
|
||||
config.rejectStruct = RejectStruct{JSON_SCHEMA_INVALID_CODE, "Json Schema is not valid"}
|
||||
}
|
||||
} else {
|
||||
config.jsonSchema = nil
|
||||
}
|
||||
|
||||
if config.serviceDomain == "" {
|
||||
config.rejectStruct = RejectStruct{JSON_SCHEMA_INVALID_CODE, "service domain is empty"}
|
||||
}
|
||||
|
||||
config.serviceClient = wrapper.NewClusterClient(wrapper.DnsCluster{
|
||||
ServiceName: config.serviceName,
|
||||
Port: int64(config.servicePort),
|
||||
Domain: config.serviceDomain,
|
||||
})
|
||||
|
||||
enableSwagger := result.Get("enableSwagger").Bool()
|
||||
enableOas3 := result.Get("enableOas3").Bool()
|
||||
|
||||
// set draft version
|
||||
if enableSwagger {
|
||||
config.draft = jsonschema.Draft4
|
||||
}
|
||||
if enableOas3 {
|
||||
config.draft = jsonschema.Draft7
|
||||
}
|
||||
if !enableSwagger && !enableOas3 {
|
||||
config.draft = jsonschema.Draft7
|
||||
}
|
||||
|
||||
// create compiler
|
||||
compiler := jsonschema.NewCompiler()
|
||||
compiler.Draft = config.draft
|
||||
config.compiler = compiler
|
||||
|
||||
// set max depth of json schema
|
||||
config.jsonSchemaMaxDepth = 6
|
||||
|
||||
enableContentDispositionValue := result.Get("enableContentDisposition")
|
||||
if !enableContentDispositionValue.Exists() {
|
||||
config.enableContentDisposition = true
|
||||
} else {
|
||||
config.enableContentDisposition = enableContentDispositionValue.Bool()
|
||||
}
|
||||
|
||||
config.enableJsonSchemaValidation = true
|
||||
|
||||
jsonSchemaBytes, err := json.Marshal(config.jsonSchema)
|
||||
if err != nil {
|
||||
config.rejectStruct = RejectStruct{JSON_SCHEMA_INVALID_CODE, "Json Schema marshal failed"}
|
||||
return err
|
||||
}
|
||||
|
||||
maxDepth := GetMaxDepth(config.jsonSchema)
|
||||
log.Debugf("max depth of json schema: %d", maxDepth)
|
||||
if maxDepth > config.jsonSchemaMaxDepth {
|
||||
config.enableJsonSchemaValidation = false
|
||||
log.Infof("Json Schema depth exceeded: %d from %d , Json Schema validation will not be used.", maxDepth, config.jsonSchemaMaxDepth)
|
||||
}
|
||||
|
||||
if config.enableJsonSchemaValidation {
|
||||
jsonSchemaStr := string(jsonSchemaBytes)
|
||||
config.compiler.AddResource(DEFAULT_SCHEMA, strings.NewReader(jsonSchemaStr))
|
||||
// Test if the Json Schema is valid
|
||||
compile, err := config.compiler.Compile(DEFAULT_SCHEMA)
|
||||
if err != nil {
|
||||
log.Infof("Json Schema compile failed: %v", err)
|
||||
config.rejectStruct = RejectStruct{JSON_SCHEMA_COMPILE_FAILED_CODE, "Json Schema compile failed: " + err.Error()}
|
||||
config.compile = nil
|
||||
} else {
|
||||
config.compile = compile
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RequestContext) assembleReqBody(config PluginConfig) []byte {
|
||||
var reqBodystrut chatCompletionRequest
|
||||
json.Unmarshal(r.ReqBody, &reqBodystrut)
|
||||
content := gjson.ParseBytes(r.RespBody).Get(config.contentPath).String()
|
||||
jsonSchemaBytes, _ := json.Marshal(config.jsonSchema)
|
||||
jsonSchemaStr := string(jsonSchemaBytes)
|
||||
|
||||
askQuestion := "Given the Json Schema: " + jsonSchemaStr + ", please help me convert the following content to a pure json: " + content
|
||||
askQuestion += "\n Do not respond other content except the pure json!!!!"
|
||||
|
||||
reqBodystrut.Messages = append(r.HistoryMessages, []chatMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: askQuestion,
|
||||
},
|
||||
}...)
|
||||
|
||||
reqBody, _ := json.Marshal(reqBodystrut)
|
||||
return reqBody
|
||||
}
|
||||
|
||||
func (r *RequestContext) SaveBodyToHistMsg(log wrapper.Log, reqBody []byte, respBody []byte) {
|
||||
r.RespBody = respBody
|
||||
lastUserMessage := ""
|
||||
lastSystemMessage := ""
|
||||
|
||||
var reqBodystrut chatCompletionRequest
|
||||
err := json.Unmarshal(reqBody, &reqBodystrut)
|
||||
if err != nil {
|
||||
log.Debugf("unmarshal reqBody failed: %v", err)
|
||||
} else {
|
||||
if len(reqBodystrut.Messages) != 0 {
|
||||
lastUserMessage = reqBodystrut.Messages[len(reqBodystrut.Messages)-1].Content
|
||||
}
|
||||
}
|
||||
|
||||
var respBodystrut chatCompletionResponse
|
||||
err = json.Unmarshal(respBody, &respBodystrut)
|
||||
if err != nil {
|
||||
log.Debugf("unmarshal respBody failed: %v", err)
|
||||
} else {
|
||||
if len(respBodystrut.Choices) != 0 {
|
||||
lastSystemMessage = respBodystrut.Choices[len(respBodystrut.Choices)-1].Message.Content
|
||||
}
|
||||
}
|
||||
|
||||
if lastUserMessage != "" {
|
||||
r.HistoryMessages = append(r.HistoryMessages, chatMessage{
|
||||
Role: "user",
|
||||
Content: lastUserMessage,
|
||||
})
|
||||
}
|
||||
|
||||
if lastSystemMessage != "" {
|
||||
r.HistoryMessages = append(r.HistoryMessages, chatMessage{
|
||||
Role: "system",
|
||||
Content: lastSystemMessage,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RequestContext) SaveStrToHistMsg(log wrapper.Log, errMsg string) {
|
||||
r.HistoryMessages = append(r.HistoryMessages, chatMessage{
|
||||
Role: "system",
|
||||
Content: errMsg,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *PluginConfig) ValidateBody(body []byte) error {
|
||||
var respJsonStrct chatCompletionResponse
|
||||
err := json.Unmarshal(body, &respJsonStrct)
|
||||
if err != nil {
|
||||
c.rejectStruct = RejectStruct{SERVICE_UNAVAILABLE_CODE, "service unavailable: " + string(body)}
|
||||
return errors.New(c.rejectStruct.RejectMsg)
|
||||
}
|
||||
content := gjson.ParseBytes(body).Get(c.contentPath)
|
||||
if !content.Exists() {
|
||||
c.rejectStruct = RejectStruct{SERVICE_UNAVAILABLE_CODE, "response body does not contain the content: " + string(body)}
|
||||
return errors.New(c.rejectStruct.RejectMsg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *PluginConfig) ValidateJson(body []byte, log wrapper.Log) (string, error) {
|
||||
content := gjson.ParseBytes(body).Get(c.contentPath).String()
|
||||
// first extract json from response body
|
||||
if content == "" {
|
||||
log.Infof("response body does not contain the content")
|
||||
c.rejectStruct = RejectStruct{CONTENT_IS_EMPTY_CODE, "response body does not contain the content"}
|
||||
return "", errors.New(c.rejectStruct.RejectMsg)
|
||||
}
|
||||
jsonStr, err := c.ExtractJson(content)
|
||||
|
||||
if err != nil {
|
||||
log.Infof("response body does not contain the valid json: %v", err.Error())
|
||||
c.rejectStruct = RejectStruct{CANNOT_FIND_JSON_IN_RESPONSE_CODE, "response body does not contain the valid json: " + err.Error()}
|
||||
return "", errors.New(c.rejectStruct.RejectMsg)
|
||||
}
|
||||
|
||||
if c.jsonSchema != nil && c.enableJsonSchemaValidation {
|
||||
compile, err := c.compiler.Compile(DEFAULT_SCHEMA)
|
||||
if err != nil {
|
||||
log.Infof("Json Schema compile failed: %v", err)
|
||||
c.rejectStruct = RejectStruct{JSON_SCHEMA_COMPILE_FAILED_CODE, "Json Schema compile failed: " + err.Error()}
|
||||
c.compile = nil
|
||||
} else {
|
||||
c.compile = compile
|
||||
}
|
||||
|
||||
// validate the json
|
||||
err = c.compile.Validate(strings.NewReader(jsonStr))
|
||||
if err != nil {
|
||||
log.Infof("response body does not match the Json Schema: %v", err)
|
||||
c.rejectStruct = RejectStruct{JSON_MISMATCH_SCHEMA_CODE, "response body does not match the Json Schema: " + err.Error()}
|
||||
return "", errors.New(c.rejectStruct.RejectMsg)
|
||||
}
|
||||
}
|
||||
c.rejectStruct = RejectStruct{HTTP_STATUS_OK, ""}
|
||||
return jsonStr, nil
|
||||
}
|
||||
|
||||
func (c *PluginConfig) ExtractJson(bodyStr string) (string, error) {
|
||||
// simply extract json from response body string
|
||||
startIndex := strings.Index(bodyStr, "{")
|
||||
endIndex := strings.LastIndex(bodyStr, "}") + 1
|
||||
|
||||
// if not found
|
||||
if startIndex == -1 || endIndex == -1 || startIndex >= endIndex {
|
||||
return "", errors.New("cannot find json in the response body")
|
||||
}
|
||||
|
||||
jsonStr := bodyStr[startIndex:endIndex]
|
||||
|
||||
// attempt to parse the JSON
|
||||
var result map[string]interface{}
|
||||
err := json.Unmarshal([]byte(jsonStr), &result)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return jsonStr, nil
|
||||
}
|
||||
|
||||
func sendResponse(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log, body []byte) {
|
||||
log.Infof("Final send: Code %d, Message %s, Body: %s", config.rejectStruct.RejectCode, config.rejectStruct.RejectMsg, string(body))
|
||||
header := [][2]string{
|
||||
{"Content-Type", "application/json"},
|
||||
}
|
||||
if body != nil && config.enableContentDisposition {
|
||||
header = append(header, [2]string{"Content-Disposition", "attachment; filename=\"response.json\""})
|
||||
}
|
||||
if config.rejectStruct.RejectCode != HTTP_STATUS_OK {
|
||||
proxywasm.SendHttpResponseWithDetail(HTTP_STATUS_INTERNAL_SERVER_ERROR, config.rejectStruct.GetShortMsg(), nil, config.rejectStruct.GetBytes(), -1)
|
||||
} else {
|
||||
proxywasm.SendHttpResponse(HTTP_STATUS_OK, header, body, -1)
|
||||
}
|
||||
}
|
||||
|
||||
func recursiveRefineJson(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log, retryCount int, requestContext *RequestContext) {
|
||||
// if retry count exceeds max retry count, return the response
|
||||
if retryCount >= config.maxRetry {
|
||||
log.Debugf("retry count exceeds max retry count")
|
||||
// report more useful error by appending the last of previous error message
|
||||
config.rejectStruct = RejectStruct{REACH_MAX_RETRY_COUNT_CODE, "retry count exceeds max retry count: " + config.rejectStruct.RejectMsg}
|
||||
sendResponse(ctx, config, log, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// recursively refine json
|
||||
config.serviceClient.Post(requestContext.Path, requestContext.ReqHeaders, requestContext.assembleReqBody(config),
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
err := config.ValidateBody(responseBody)
|
||||
if err != nil {
|
||||
sendResponse(ctx, config, log, nil)
|
||||
return
|
||||
}
|
||||
retryCount++
|
||||
requestContext.SaveBodyToHistMsg(log, requestContext.assembleReqBody(config), responseBody)
|
||||
log.Debugf("[retry request %d/%d] resp code: %d", retryCount, config.maxRetry, statusCode)
|
||||
validateJson, err := config.ValidateJson(responseBody, log)
|
||||
if err == nil {
|
||||
sendResponse(ctx, config, log, []byte(validateJson))
|
||||
} else {
|
||||
requestContext.SaveStrToHistMsg(log, err.Error())
|
||||
recursiveRefineJson(ctx, config, log, retryCount, requestContext)
|
||||
}
|
||||
}, uint32(config.serviceTimeout))
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
|
||||
if config.rejectStruct.RejectCode != HTTP_STATUS_OK {
|
||||
sendResponse(ctx, config, log, nil)
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
// verify if the request is from this plugin
|
||||
extendHeaderValue, err := proxywasm.GetHttpRequestHeader(EXTEND_HEADER_KEY)
|
||||
if err == nil {
|
||||
fromThisPlugin, convErr := strconv.ParseBool(extendHeaderValue)
|
||||
if convErr != nil {
|
||||
log.Debugf("failed to parse header value as bool: %v", convErr)
|
||||
ctx.SetContext(FROM_THIS_PLUGIN_KEY, false)
|
||||
}
|
||||
if fromThisPlugin {
|
||||
ctx.SetContext(FROM_THIS_PLUGIN_KEY, true)
|
||||
return types.ActionContinue
|
||||
}
|
||||
} else {
|
||||
ctx.SetContext(FROM_THIS_PLUGIN_KEY, false)
|
||||
}
|
||||
|
||||
path, err := proxywasm.GetHttpRequestHeader(":path")
|
||||
if err != nil {
|
||||
log.Infof("get request path failed: %v", err)
|
||||
path = ""
|
||||
} else {
|
||||
ctx.SetContext("path", path)
|
||||
}
|
||||
|
||||
headers, err := proxywasm.GetHttpRequestHeaders()
|
||||
if err != nil {
|
||||
log.Infof("get request header failed: %v", err)
|
||||
}
|
||||
|
||||
apiKey, err := proxywasm.GetHttpRequestHeader("Authorization")
|
||||
if err != nil {
|
||||
log.Infof("get request header failed: %v", err)
|
||||
apiKey = ""
|
||||
}
|
||||
if apiKey != "" {
|
||||
// remove the Authorization header
|
||||
proxywasm.RemoveHttpRequestHeader("Authorization")
|
||||
// remove the Authorization header from the headers
|
||||
for i, header := range headers {
|
||||
if header[0] == "Authorization" {
|
||||
headers = append(headers[:i], headers[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if config.apiKey != "" {
|
||||
log.Debugf("add Authorization header %s", "Bearer "+config.apiKey)
|
||||
headers = append(headers, [2]string{"Authorization", "Bearer " + config.apiKey})
|
||||
}
|
||||
ctx.SetContext("headers", headers)
|
||||
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
|
||||
// if the request is from this plugin, continue the request
|
||||
fromThisPlugin, ok := ctx.GetContext(FROM_THIS_PLUGIN_KEY).(bool)
|
||||
if ok && fromThisPlugin {
|
||||
log.Debugf("detected buffer_request, sending request to AI service")
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
var headers [][2]string
|
||||
if h, ok := ctx.GetContext("headers").([][2]string); ok {
|
||||
headers = append(h, [2]string{EXTEND_HEADER_KEY, "true"})
|
||||
} else {
|
||||
log.Debugf("cannot get headers from context, use default headers")
|
||||
headers = [][2]string{
|
||||
{"Content-Type", "application/json"},
|
||||
{EXTEND_HEADER_KEY, "true"},
|
||||
}
|
||||
}
|
||||
|
||||
// if there is any error in the config, return the response directly
|
||||
if config.rejectStruct.RejectCode != HTTP_STATUS_OK {
|
||||
sendResponse(ctx, config, log, nil)
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
var path string
|
||||
if path, ok := ctx.GetContext("path").(string); ok {
|
||||
log.Debugf("use path: %s", path)
|
||||
} else {
|
||||
log.Debugf("cannot get path from context, use default path")
|
||||
path = "/v1/chat/completions"
|
||||
}
|
||||
|
||||
if config.servicePath != "" {
|
||||
log.Debugf("use base path: %s", config.servicePath)
|
||||
path = config.servicePath
|
||||
}
|
||||
|
||||
requestContext := &RequestContext{
|
||||
Path: path,
|
||||
ReqHeaders: headers,
|
||||
ReqBody: body,
|
||||
}
|
||||
|
||||
config.serviceClient.Post(requestContext.Path, requestContext.ReqHeaders, requestContext.ReqBody,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
err := config.ValidateBody(responseBody)
|
||||
if err != nil {
|
||||
sendResponse(ctx, config, log, nil)
|
||||
return
|
||||
}
|
||||
requestContext.SaveBodyToHistMsg(log, body, responseBody)
|
||||
log.Debugf("[first request] resp code: %d", statusCode)
|
||||
validateJson, err := config.ValidateJson(responseBody, log)
|
||||
if err == nil {
|
||||
sendResponse(ctx, config, log, []byte(validateJson))
|
||||
return
|
||||
} else {
|
||||
retryCount := 0
|
||||
requestContext.SaveStrToHistMsg(log, err.Error())
|
||||
recursiveRefineJson(ctx, config, log, retryCount, requestContext)
|
||||
}
|
||||
}, uint32(config.serviceTimeout))
|
||||
|
||||
return types.ActionPause
|
||||
}
|
||||
180
plugins/wasm-go/extensions/ai-json-resp/model.go
Normal file
180
plugins/wasm-go/extensions/ai-json-resp/model.go
Normal file
@@ -0,0 +1,180 @@
|
||||
// adopt from https://github.com/alibaba/higress/blob/main/plugins/wasm-go/extensions/ai-proxy/provider/model.go
|
||||
package main
|
||||
|
||||
import "strings"
|
||||
|
||||
const (
|
||||
streamEventIdItemKey = "id:"
|
||||
streamEventNameItemKey = "event:"
|
||||
streamBuiltInItemKey = ":"
|
||||
streamHttpStatusValuePrefix = "HTTP_STATUS/"
|
||||
streamDataItemKey = "data:"
|
||||
streamEndDataValue = "[DONE]"
|
||||
)
|
||||
|
||||
type chatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []chatMessage `json:"messages"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
Seed int `json:"seed,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
StreamOptions *streamOptions `json:"stream_options,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Tools []tool `json:"tools,omitempty"`
|
||||
ToolChoice *toolChoice `json:"tool_choice,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
ResponseFormat map[string]interface{} `json:"response_format,omitempty"`
|
||||
}
|
||||
|
||||
type streamOptions struct {
|
||||
IncludeUsage bool `json:"include_usage,omitempty"`
|
||||
}
|
||||
|
||||
type tool struct {
|
||||
Type string `json:"type"`
|
||||
Function function `json:"function"`
|
||||
}
|
||||
|
||||
type function struct {
|
||||
Description string `json:"description,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Parameters map[string]interface{} `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type toolChoice struct {
|
||||
Type string `json:"type"`
|
||||
Function function `json:"function"`
|
||||
}
|
||||
|
||||
type chatCompletionResponse struct {
|
||||
Id string `json:"id,omitempty"`
|
||||
Choices []chatCompletionChoice `json:"choices"`
|
||||
Created int64 `json:"created,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||
Object string `json:"object,omitempty"`
|
||||
Usage usage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type chatCompletionChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message *chatMessage `json:"message,omitempty"`
|
||||
Delta *chatMessage `json:"delta,omitempty"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
|
||||
type usage struct {
|
||||
PromptTokens int `json:"prompt_tokens,omitempty"`
|
||||
CompletionTokens int `json:"completion_tokens,omitempty"`
|
||||
TotalTokens int `json:"total_tokens,omitempty"`
|
||||
}
|
||||
|
||||
type chatMessage struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
ToolCalls []toolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
func (m *chatMessage) IsEmpty() bool {
|
||||
if m.Content != "" {
|
||||
return false
|
||||
}
|
||||
if len(m.ToolCalls) != 0 {
|
||||
nonEmpty := false
|
||||
for _, toolCall := range m.ToolCalls {
|
||||
if !toolCall.Function.IsEmpty() {
|
||||
nonEmpty = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if nonEmpty {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type toolCall struct {
|
||||
Index int `json:"index"`
|
||||
Id string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function functionCall `json:"function"`
|
||||
}
|
||||
|
||||
type functionCall struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
func (m *functionCall) IsEmpty() bool {
|
||||
return m.Name == "" && m.Arguments == ""
|
||||
}
|
||||
|
||||
type streamEvent struct {
|
||||
Id string `json:"id"`
|
||||
Event string `json:"event"`
|
||||
Data string `json:"data"`
|
||||
HttpStatus string `json:"http_status"`
|
||||
}
|
||||
|
||||
func (e *streamEvent) setValue(key, value string) {
|
||||
switch key {
|
||||
case streamEventIdItemKey:
|
||||
e.Id = value
|
||||
case streamEventNameItemKey:
|
||||
e.Event = value
|
||||
case streamDataItemKey:
|
||||
e.Data = value
|
||||
case streamBuiltInItemKey:
|
||||
if strings.HasPrefix(value, streamHttpStatusValuePrefix) {
|
||||
e.HttpStatus = value[len(streamHttpStatusValuePrefix):]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type embeddingsRequest struct {
|
||||
Input interface{} `json:"input"`
|
||||
Model string `json:"model"`
|
||||
EncodingFormat string `json:"encoding_format,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
type embeddingsResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []embedding `json:"data"`
|
||||
Model string `json:"model"`
|
||||
Usage usage `json:"usage"`
|
||||
}
|
||||
|
||||
type embedding struct {
|
||||
Object string `json:"object"`
|
||||
Index int `json:"index"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
}
|
||||
|
||||
func (r embeddingsRequest) ParseInput() []string {
|
||||
if r.Input == nil {
|
||||
return nil
|
||||
}
|
||||
var input []string
|
||||
switch r.Input.(type) {
|
||||
case string:
|
||||
input = []string{r.Input.(string)}
|
||||
case []any:
|
||||
input = make([]string, 0, len(r.Input.([]any)))
|
||||
for _, item := range r.Input.([]any) {
|
||||
if str, ok := item.(string); ok {
|
||||
input = append(input, str)
|
||||
}
|
||||
}
|
||||
}
|
||||
return input
|
||||
}
|
||||
33
plugins/wasm-go/extensions/ai-json-resp/util.go
Normal file
33
plugins/wasm-go/extensions/ai-json-resp/util.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package main
|
||||
|
||||
func GetMaxDepth(data interface{}) int {
|
||||
type item struct {
|
||||
value interface{}
|
||||
depth int
|
||||
}
|
||||
|
||||
maxDepth := 0
|
||||
stack := []item{{value: data, depth: 1}}
|
||||
|
||||
for len(stack) > 0 {
|
||||
currentItem := stack[len(stack)-1]
|
||||
stack = stack[:len(stack)-1]
|
||||
|
||||
if currentItem.depth > maxDepth {
|
||||
maxDepth = currentItem.depth
|
||||
}
|
||||
|
||||
switch v := currentItem.value.(type) {
|
||||
case map[string]interface{}:
|
||||
for _, value := range v {
|
||||
stack = append(stack, item{value: value, depth: currentItem.depth + 1})
|
||||
}
|
||||
case []interface{}:
|
||||
for _, value := range v {
|
||||
stack = append(stack, item{value: value, depth: currentItem.depth + 1})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return maxDepth
|
||||
}
|
||||
Reference in New Issue
Block a user