Files
higress/plugins/wasm-go/extensions/ai-json-resp/main.go
2025-03-26 20:27:53 +08:00

574 lines
19 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"encoding/json"
"errors"
"net/http"
"strconv"
"strings"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/santhosh-tekuri/jsonschema"
"github.com/tidwall/gjson"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)
// Identifiers shared between the request phases of the plugin.
const (
	// DEFAULT_SCHEMA is the resource name under which the configured
	// Json Schema is registered with the schema compiler.
	DEFAULT_SCHEMA = "defaultSchema"
	// FROM_THIS_PLUGIN_KEY is the context key that records whether the
	// current request carries the plugin's marker header.
	FROM_THIS_PLUGIN_KEY = "fromThisPlugin"
	// EXTEND_HEADER_KEY is the marker header added to requests that the
	// plugin itself sends to the AI service.
	EXTEND_HEADER_KEY = "X-HIGRESS-AI-JSON-RESP"
)

// HTTP status codes used when answering the client. HTTP_STATUS_OK also
// doubles as the "no rejection" value of RejectStruct.RejectCode.
const (
	HTTP_STATUS_OK                    uint32 = 200
	HTTP_STATUS_INTERNAL_SERVER_ERROR uint32 = 500
)

// Plugin-specific rejection codes reported in the "Code" field of the
// error body.
const (
	JSON_SCHEMA_INVALID_CODE          = 1001
	JSON_SCHEMA_COMPILE_FAILED_CODE   = 1002
	CANNOT_FIND_JSON_IN_RESPONSE_CODE = 1003
	CONTENT_IS_EMPTY_CODE             = 1004
	JSON_MISMATCH_SCHEMA_CODE         = 1005
	REACH_MAX_RETRY_COUNT_CODE        = 1006
	SERVICE_UNAVAILABLE_CODE          = 1007
	SERVICE_CONFIG_INVALID_CODE       = 1008
)
// RejectStruct describes the outcome reported to the client. A RejectCode
// equal to HTTP_STATUS_OK is used throughout the plugin to mean "not
// rejected"; any other value is an error code paired with a message.
type RejectStruct struct {
	RejectCode uint32 `json:"Code"`
	RejectMsg  string `json:"Msg"`
}

// GetBytes returns the JSON encoding of r ({"Code":...,"Msg":...}).
func (r RejectStruct) GetBytes() []byte {
	encoded, err := json.Marshal(r)
	if err != nil {
		// Marshal yields no bytes on failure; mirror that for the caller.
		return nil
	}
	return encoded
}

// GetShortMsg returns "ai-json-resp." followed by the part of RejectMsg
// that precedes the first ':' (the whole message when it has no ':').
func (r RejectStruct) GetShortMsg() string {
	summary := r.RejectMsg
	if colon := strings.Index(summary, ":"); colon >= 0 {
		summary = summary[:colon]
	}
	return "ai-json-resp." + summary
}
// PluginConfig holds the user-facing settings of the ai-json-resp plugin
// followed by the runtime state that parseConfig derives from them.
type PluginConfig struct {
	// @Title en-US Service name
	// @Description en-US Name of the service to request (the gateway or another AI service)
	serviceName string `required:"true" json:"serviceName" yaml:"serviceName"`
	// @Title en-US Service domain
	// @Description en-US Domain of the service to request
	serviceDomain string `required:"false" json:"serviceDomain" yaml:"serviceDomain"`
	// @Title en-US Service port
	// @Description en-US Port of the service to request
	servicePort int `required:"false" json:"servicePort" yaml:"servicePort"`
	// @Title en-US Service URL
	// @Description en-US URL of the service to request; if provided it overrides serviceDomain and servicePort
	serviceUrl string `required:"false" json:"serviceUrl" yaml:"serviceUrl"`
	// @Title en-US API Key
	// @Description en-US API key of the requested service, required when an AI service is used
	apiKey string `required:"false" json:"apiKey" yaml:"apiKey"`
	// @Title en-US Request endpoint
	// @Description en-US Endpoint of the service to request, defaults to "/v1/chat/completions"
	servicePath string `required:"false" json:"servicePath" yaml:"servicePath"`
	// @Title en-US Service timeout
	// @Description en-US Timeout used when requesting the service
	serviceTimeout int `required:"false" json:"serviceTimeout" yaml:"serviceTimeout"`
	// @Title en-US Maximum retry count
	// @Description en-US Maximum number of retries when requesting the service
	maxRetry int `required:"false" json:"maxRetry" yaml:"maxRetry"`
	// @Title en-US Content path
	// @Description en-US gjson path used to extract the json from the response returned by the AI service
	contentPath string `required:"false" json:"contentPath" yaml:"contentPath"`
	// @Title en-US Json Schema
	// @Description en-US Json Schema used to validate the response json; if empty, only checks that the response is valid json
	jsonSchema map[string]interface{} `required:"false" json:"jsonSchema" yaml:"jsonSchema"`
	// @Title en-US Enable swagger
	// @Description en-US Whether to use swagger for Json Schema validation
	enableSwagger bool `required:"false" json:"enableSwagger" yaml:"enableSwagger"`
	// @Title en-US Enable oas3
	// @Description en-US Whether to use oas3 for Json Schema validation
	enableOas3 bool `required:"false" json:"enableOas3" yaml:"enableOas3"`
	// @Title en-US Enable Content-Disposition
	// @Description en-US Whether to enable Content-Disposition; if enabled, the response carries the header Content-Disposition: attachment; filename="response.json"
	enableContentDisposition bool `required:"false" json:"enableContentDisposition" yaml:"enableContentDisposition"`

	// serviceClient posts requests to the configured service.
	serviceClient wrapper.HttpClient
	// draft is the Json Schema draft handed to the compiler.
	draft *jsonschema.Draft
	// compiler holds the schema registered under DEFAULT_SCHEMA.
	compiler *jsonschema.Compiler
	// compile is the compiled schema; nil when compilation failed or was skipped.
	compile *jsonschema.Schema
	// rejectStruct is the current rejection state; a RejectCode equal to
	// HTTP_STATUS_OK means the request is not rejected.
	rejectStruct RejectStruct
	// jsonSchemaMaxDepth is the deepest schema nesting for which validation
	// stays enabled.
	jsonSchemaMaxDepth int
	// enableJsonSchemaValidation is false when the schema is deeper than
	// jsonSchemaMaxDepth.
	enableJsonSchemaValidation bool
}
// main registers the plugin with the wasm-go wrapper: configuration is
// parsed by parseConfig, and request headers and bodies are handled by
// onHttpRequestHeaders and onHttpRequestBody respectively.
func main() {
	const pluginName = "ai-json-resp"

	wrapper.SetCtx(
		pluginName,
		wrapper.ParseConfigBy(parseConfig),
		wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
		wrapper.ProcessRequestBodyBy(onHttpRequestBody),
	)
}
// RequestContext carries the state of one client request across the first
// call to the AI service and any follow-up retry calls.
type RequestContext struct {
// Path is the path the plugin posts to on the configured service.
Path string
// ReqHeaders are the headers sent with every call to the service.
ReqHeaders [][2]string
// ReqBody is the body of the original client request.
ReqBody []byte
// RespHeader is not assigned anywhere in the visible part of this file.
RespHeader [][2]string
// RespBody is the most recent service response, stored by SaveBodyToHistMsg.
RespBody []byte
// HistoryMessages accumulates the messages appended by SaveBodyToHistMsg
// and SaveStrToHistMsg; assembleReqBody sends them with each retry.
HistoryMessages []chatMessage
}
// parseUrl splits url into its host part and its path part.
//
// A leading "http://" and then a leading "https://" are stripped; the
// remainder is cut at the first '/'. The host is everything before that
// slash and the path is the slash and everything after it. When there is
// no slash the path is empty, and an empty url yields two empty strings.
func parseUrl(url string) (string, string) {
	rest := url
	for _, scheme := range []string{"http://", "https://"} {
		rest = strings.TrimPrefix(rest, scheme)
	}

	slash := strings.IndexByte(rest, '/')
	if slash < 0 {
		return rest, ""
	}
	return rest[:slash], rest[slash:]
}
// parseConfig fills config from the plugin's JSON configuration.
//
// Defaults: servicePort 443, serviceTimeout 50000, maxRetry 3, contentPath
// "choices.0.message.content", enableContentDisposition true, and Draft7
// unless only enableSwagger is set (Draft4). When serviceUrl is given, its
// host and path fill serviceDomain and servicePath if those are empty.
//
// Configuration problems do not make parseConfig fail; they are recorded
// in config.rejectStruct so that every request is answered with that
// error. The only returned error is a failure to marshal the schema.
func parseConfig(result gjson.Result, config *PluginConfig, log wrapper.Log) error {
	config.serviceName = result.Get("serviceName").String()
	config.serviceUrl = result.Get("serviceUrl").String()
	config.serviceDomain = result.Get("serviceDomain").String()
	config.servicePath = result.Get("servicePath").String()
	config.servicePort = int(result.Get("servicePort").Int())
	if config.serviceUrl != "" {
		domain, url := parseUrl(config.serviceUrl)
		log.Debugf("serviceUrl: %s, the parsed domain: %s, the parsed url: %s", config.serviceUrl, domain, url)
		if config.serviceDomain == "" {
			config.serviceDomain = domain
		}
		if config.servicePath == "" {
			config.servicePath = url
		}
	}
	if config.servicePort == 0 {
		config.servicePort = 443
	}

	config.apiKey = result.Get("apiKey").String()
	config.rejectStruct = RejectStruct{HTTP_STATUS_OK, ""}

	// Non-positive values fall back to the defaults: a negative timeout
	// would wrap around when later converted to uint32, and a negative
	// retry budget is meaningless.
	config.serviceTimeout = int(result.Get("serviceTimeout").Int())
	if config.serviceTimeout <= 0 {
		config.serviceTimeout = 50000
	}
	config.maxRetry = int(result.Get("maxRetry").Int())
	if config.maxRetry <= 0 {
		config.maxRetry = 3
	}

	config.contentPath = result.Get("contentPath").String()
	if config.contentPath == "" {
		config.contentPath = "choices.0.message.content"
	}

	config.jsonSchema = nil
	if jsonSchemaValue := result.Get("jsonSchema"); jsonSchemaValue.Exists() {
		if schemaValue, ok := jsonSchemaValue.Value().(map[string]interface{}); ok {
			config.jsonSchema = schemaValue
		} else {
			config.rejectStruct = RejectStruct{JSON_SCHEMA_INVALID_CODE, "Json Schema is not valid"}
		}
	}

	if config.serviceDomain == "" {
		// A missing domain is a service configuration problem, not a schema one.
		config.rejectStruct = RejectStruct{SERVICE_CONFIG_INVALID_CODE, "service domain is empty"}
	}
	config.serviceClient = wrapper.NewClusterClient(wrapper.DnsCluster{
		ServiceName: config.serviceName,
		Port:        int64(config.servicePort),
		Domain:      config.serviceDomain,
	})

	// Draft4 is selected only when swagger is enabled and oas3 is not;
	// every other combination uses Draft7.
	config.enableSwagger = result.Get("enableSwagger").Bool()
	config.enableOas3 = result.Get("enableOas3").Bool()
	config.draft = jsonschema.Draft7
	if config.enableSwagger && !config.enableOas3 {
		config.draft = jsonschema.Draft4
	}

	compiler := jsonschema.NewCompiler()
	compiler.Draft = config.draft
	config.compiler = compiler
	config.compile = nil

	// Schemas nested deeper than this are not validated.
	config.jsonSchemaMaxDepth = 6

	// Content-Disposition is on unless the option is explicitly provided.
	config.enableContentDisposition = true
	if v := result.Get("enableContentDisposition"); v.Exists() {
		config.enableContentDisposition = v.Bool()
	}

	config.enableJsonSchemaValidation = true
	jsonSchemaBytes, err := json.Marshal(config.jsonSchema)
	if err != nil {
		config.rejectStruct = RejectStruct{JSON_SCHEMA_INVALID_CODE, "Json Schema marshal failed"}
		return err
	}
	maxDepth := GetMaxDepth(config.jsonSchema)
	log.Debugf("max depth of json schema: %d", maxDepth)
	if maxDepth > config.jsonSchemaMaxDepth {
		config.enableJsonSchemaValidation = false
		log.Infof("Json Schema depth exceeded: %d from %d , Json Schema validation will not be used.", maxDepth, config.jsonSchemaMaxDepth)
	}

	// Without a schema there is nothing to compile: ValidateJson then only
	// checks that the response is valid json. Compiling the marshalled nil
	// map ("null") would instead be reported as a schema error.
	if config.jsonSchema == nil || !config.enableJsonSchemaValidation {
		return nil
	}

	// Register and compile once here so an invalid schema is detected at
	// configuration time.
	if err := config.compiler.AddResource(DEFAULT_SCHEMA, strings.NewReader(string(jsonSchemaBytes))); err != nil {
		log.Infof("Json Schema compile failed: %v", err)
		config.rejectStruct = RejectStruct{JSON_SCHEMA_COMPILE_FAILED_CODE, "Json Schema compile failed: " + err.Error()}
		return nil
	}
	compiled, err := config.compiler.Compile(DEFAULT_SCHEMA)
	if err != nil {
		log.Infof("Json Schema compile failed: %v", err)
		config.rejectStruct = RejectStruct{JSON_SCHEMA_COMPILE_FAILED_CODE, "Json Schema compile failed: " + err.Error()}
		return nil
	}
	config.compile = compiled
	return nil
}
// assembleReqBody builds the body of a retry request.
//
// It starts from the original request body, replaces its messages with the
// accumulated history plus one new user message, and returns the result as
// JSON. The new message quotes the configured Json Schema and the content
// found at config.contentPath in the last response, and asks for pure json.
// Decode and encode errors are ignored, as before.
func (r *RequestContext) assembleReqBody(config PluginConfig) []byte {
	var request chatCompletionRequest
	_ = json.Unmarshal(r.ReqBody, &request)

	previous := gjson.ParseBytes(r.RespBody).Get(config.contentPath).String()
	schemaJSON, _ := json.Marshal(config.jsonSchema)

	prompt := "Given the Json Schema: " + string(schemaJSON) +
		", please help me convert the following content to a pure json: " + previous +
		"\n Do not respond other content except the pure json!!!!"

	request.Messages = append(r.HistoryMessages, chatMessage{
		Role:    "user",
		Content: prompt,
	})

	encoded, _ := json.Marshal(request)
	return encoded
}
// SaveBodyToHistMsg stores respBody as the latest response and appends one
// request/response exchange to the history.
//
// The content of the last message in reqBody is appended with role "user",
// followed by the content of the last choice in respBody with role
// "system". A body that fails to decode is logged at debug level and
// contributes nothing; empty contents are skipped.
func (r *RequestContext) SaveBodyToHistMsg(log wrapper.Log, reqBody []byte, respBody []byte) {
	r.RespBody = respBody

	var userContent, replyContent string

	var request chatCompletionRequest
	if err := json.Unmarshal(reqBody, &request); err != nil {
		log.Debugf("unmarshal reqBody failed: %v", err)
	} else if n := len(request.Messages); n > 0 {
		userContent = request.Messages[n-1].Content
	}

	var response chatCompletionResponse
	if err := json.Unmarshal(respBody, &response); err != nil {
		log.Debugf("unmarshal respBody failed: %v", err)
	} else if n := len(response.Choices); n > 0 {
		replyContent = response.Choices[n-1].Message.Content
	}

	if userContent != "" {
		r.HistoryMessages = append(r.HistoryMessages, chatMessage{Role: "user", Content: userContent})
	}
	if replyContent != "" {
		r.HistoryMessages = append(r.HistoryMessages, chatMessage{Role: "system", Content: replyContent})
	}
}
// SaveStrToHistMsg appends errMsg to the history as a "system" message.
// The log parameter is not used.
func (r *RequestContext) SaveStrToHistMsg(log wrapper.Log, errMsg string) {
	feedback := chatMessage{
		Role:    "system",
		Content: errMsg,
	}
	r.HistoryMessages = append(r.HistoryMessages, feedback)
}
// ValidateBody checks that body looks like a usable service response.
//
// body must decode into a chatCompletionResponse and must contain a value
// at c.contentPath. On failure c.rejectStruct is set to
// SERVICE_UNAVAILABLE_CODE with a message that includes the raw body, and
// an error carrying that message is returned.
func (c *PluginConfig) ValidateBody(body []byte) error {
	var decoded chatCompletionResponse
	if err := json.Unmarshal(body, &decoded); err != nil {
		c.rejectStruct = RejectStruct{SERVICE_UNAVAILABLE_CODE, "service unavailable: " + string(body)}
		return errors.New(c.rejectStruct.RejectMsg)
	}

	if found := gjson.ParseBytes(body).Get(c.contentPath); found.Exists() {
		return nil
	}
	c.rejectStruct = RejectStruct{SERVICE_UNAVAILABLE_CODE, "response body does not contain the content: " + string(body)}
	return errors.New(c.rejectStruct.RejectMsg)
}
// ValidateJson extracts the json object from the content found at
// c.contentPath in body and, when a schema is configured and validation is
// enabled, validates it against that schema.
//
// On success it resets c.rejectStruct to HTTP_STATUS_OK and returns the
// extracted json. On failure it records the reason in c.rejectStruct and
// returns an error with the same message:
//   - CONTENT_IS_EMPTY_CODE: the content is empty;
//   - CANNOT_FIND_JSON_IN_RESPONSE_CODE: no valid json object was found;
//   - JSON_SCHEMA_COMPILE_FAILED_CODE: the schema could not be compiled;
//   - JSON_MISMATCH_SCHEMA_CODE: the json does not match the schema.
func (c *PluginConfig) ValidateJson(body []byte, log wrapper.Log) (string, error) {
	content := gjson.ParseBytes(body).Get(c.contentPath).String()
	// first extract json from response body
	if content == "" {
		log.Infof("response body does not contain the content")
		c.rejectStruct = RejectStruct{CONTENT_IS_EMPTY_CODE, "response body does not contain the content"}
		return "", errors.New(c.rejectStruct.RejectMsg)
	}
	jsonStr, err := c.ExtractJson(content)
	if err != nil {
		log.Infof("response body does not contain the valid json: %v", err.Error())
		c.rejectStruct = RejectStruct{CANNOT_FIND_JSON_IN_RESPONSE_CODE, "response body does not contain the valid json: " + err.Error()}
		return "", errors.New(c.rejectStruct.RejectMsg)
	}

	if c.jsonSchema != nil && c.enableJsonSchemaValidation {
		// Reuse the schema compiled at configuration time; compile only
		// when it is missing.
		if c.compile == nil {
			if c.compiler == nil {
				c.rejectStruct = RejectStruct{JSON_SCHEMA_COMPILE_FAILED_CODE, "Json Schema compile failed: compiler is not initialized"}
				return "", errors.New(c.rejectStruct.RejectMsg)
			}
			compiled, err := c.compiler.Compile(DEFAULT_SCHEMA)
			if err != nil {
				// Stop here: validating with a nil schema would
				// dereference a nil pointer.
				log.Infof("Json Schema compile failed: %v", err)
				c.rejectStruct = RejectStruct{JSON_SCHEMA_COMPILE_FAILED_CODE, "Json Schema compile failed: " + err.Error()}
				return "", errors.New(c.rejectStruct.RejectMsg)
			}
			c.compile = compiled
		}
		// validate the json
		if err := c.compile.Validate(strings.NewReader(jsonStr)); err != nil {
			log.Infof("response body does not match the Json Schema: %v", err)
			c.rejectStruct = RejectStruct{JSON_MISMATCH_SCHEMA_CODE, "response body does not match the Json Schema: " + err.Error()}
			return "", errors.New(c.rejectStruct.RejectMsg)
		}
	}

	c.rejectStruct = RejectStruct{HTTP_STATUS_OK, ""}
	return jsonStr, nil
}
// ExtractJson returns the json object embedded in bodyStr.
//
// The candidate is the text from the first '{' to the last '}' inclusive.
// An error is returned when no such span exists or when the span does not
// decode as a json object; otherwise the span is returned unchanged.
func (c *PluginConfig) ExtractJson(bodyStr string) (string, error) {
	open := strings.IndexByte(bodyStr, '{')
	closing := strings.LastIndexByte(bodyStr, '}')
	if open < 0 || closing < open {
		return "", errors.New("cannot find json in the response body")
	}

	candidate := bodyStr[open : closing+1]

	// Decode only to confirm the candidate is a well-formed json object.
	var decoded map[string]interface{}
	if err := json.Unmarshal([]byte(candidate), &decoded); err != nil {
		return "", err
	}
	return candidate, nil
}
// sendResponse answers the client directly.
//
// If config.rejectStruct holds a code other than HTTP_STATUS_OK, a 500
// response is sent with the rejection as its JSON body and the short
// message as the response detail. Otherwise body is sent with status 200
// and Content-Type application/json, plus a Content-Disposition header
// when body is non-nil and config.enableContentDisposition is set.
// The ctx parameter is not used.
func sendResponse(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log, body []byte) {
	reject := config.rejectStruct
	log.Infof("Final send: Code %d, Message %s, Body: %s", reject.RejectCode, reject.RejectMsg, string(body))

	if reject.RejectCode != HTTP_STATUS_OK {
		proxywasm.SendHttpResponseWithDetail(HTTP_STATUS_INTERNAL_SERVER_ERROR, reject.GetShortMsg(), nil, reject.GetBytes(), -1)
		return
	}

	header := [][2]string{
		{"Content-Type", "application/json"},
	}
	if config.enableContentDisposition && body != nil {
		header = append(header, [2]string{"Content-Disposition", "attachment; filename=\"response.json\""})
	}
	proxywasm.SendHttpResponse(HTTP_STATUS_OK, header, body, -1)
}
// recursiveRefineJson asks the service to turn its previous answer into
// json that satisfies the schema, retrying until the answer validates or
// config.maxRetry attempts have been made.
//
// retryCount is the number of retries already performed. Each attempt
// posts the body built by assembleReqBody; the callback records the
// exchange in requestContext, validates the answer and either replies to
// the client or recurses with the validation error added to the history.
// Every path ends in exactly one sendResponse call, including the case
// where the request cannot be dispatched at all.
func recursiveRefineJson(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log, retryCount int, requestContext *RequestContext) {
	// if retry count exceeds max retry count, return the response
	if retryCount >= config.maxRetry {
		log.Debugf("retry count exceeds max retry count")
		// report more useful error by appending the last of previous error message
		config.rejectStruct = RejectStruct{REACH_MAX_RETRY_COUNT_CODE, "retry count exceeds max retry count: " + config.rejectStruct.RejectMsg}
		sendResponse(ctx, config, log, nil)
		return
	}

	// Build the retry body once; the same bytes are sent and recorded.
	reqBody := requestContext.assembleReqBody(config)

	err := config.serviceClient.Post(requestContext.Path, requestContext.ReqHeaders, reqBody,
		func(statusCode int, responseHeaders http.Header, responseBody []byte) {
			if err := config.ValidateBody(responseBody); err != nil {
				sendResponse(ctx, config, log, nil)
				return
			}
			retryCount++
			requestContext.SaveBodyToHistMsg(log, reqBody, responseBody)
			log.Debugf("[retry request %d/%d] resp code: %d", retryCount, config.maxRetry, statusCode)

			validateJson, err := config.ValidateJson(responseBody, log)
			if err != nil {
				requestContext.SaveStrToHistMsg(log, err.Error())
				recursiveRefineJson(ctx, config, log, retryCount, requestContext)
				return
			}
			sendResponse(ctx, config, log, []byte(validateJson))
		}, uint32(config.serviceTimeout))
	if err != nil {
		// The callback never runs when dispatch fails, so answer here;
		// otherwise the paused client request would never complete.
		log.Infof("dispatch retry request failed: %v", err)
		config.rejectStruct = RejectStruct{SERVICE_UNAVAILABLE_CODE, "service unavailable: " + err.Error()}
		sendResponse(ctx, config, log, nil)
	}
}
// onHttpRequestHeaders runs in the request-header phase.
//
// It rejects the request when the configuration is invalid, lets requests
// marked with EXTEND_HEADER_KEY pass through untouched, and otherwise
// stores the request path and headers in ctx (keys "path" and "headers")
// for onHttpRequestBody. A client Authorization header is removed from
// the request and from the stored headers; when config.apiKey is set, a
// Bearer Authorization header built from it is added to the stored ones.
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
	if config.rejectStruct.RejectCode != HTTP_STATUS_OK {
		sendResponse(ctx, config, log, nil)
		return types.ActionPause
	}

	// verify if the request is from this plugin
	fromThisPlugin := false
	if extendHeaderValue, err := proxywasm.GetHttpRequestHeader(EXTEND_HEADER_KEY); err == nil {
		parsed, convErr := strconv.ParseBool(extendHeaderValue)
		if convErr != nil {
			log.Debugf("failed to parse header value as bool: %v", convErr)
		}
		// ParseBool returns false alongside an error.
		fromThisPlugin = parsed
	}
	ctx.SetContext(FROM_THIS_PLUGIN_KEY, fromThisPlugin)
	if fromThisPlugin {
		return types.ActionContinue
	}

	path, err := proxywasm.GetHttpRequestHeader(":path")
	if err != nil {
		log.Infof("get request path failed: %v", err)
	} else {
		ctx.SetContext("path", path)
	}

	headers, err := proxywasm.GetHttpRequestHeaders()
	if err != nil {
		log.Infof("get request header failed: %v", err)
	}

	apiKey, err := proxywasm.GetHttpRequestHeader("Authorization")
	if err != nil {
		log.Infof("get request header failed: %v", err)
		apiKey = ""
	}
	if apiKey != "" {
		// remove the Authorization header
		proxywasm.RemoveHttpRequestHeader("Authorization")
		// Drop it from the copied headers too. Header names may arrive
		// lower-cased, so compare case-insensitively and drop every match.
		kept := headers[:0]
		for _, header := range headers {
			if !strings.EqualFold(header[0], "Authorization") {
				kept = append(kept, header)
			}
		}
		headers = kept
	}
	if config.apiKey != "" {
		// The key is a secret and is deliberately kept out of the log.
		log.Debugf("add Authorization header from plugin config")
		headers = append(headers, [2]string{"Authorization", "Bearer " + config.apiKey})
	}
	ctx.SetContext("headers", headers)
	return types.ActionContinue
}
// onHttpRequestBody runs in the request-body phase.
//
// Requests marked as coming from this plugin continue unchanged. For all
// others the plugin pauses the request, posts body to the configured
// service itself and answers the client from the callback: directly when
// the first answer validates, otherwise through recursiveRefineJson.
//
// The target path is config.servicePath when set, otherwise the request
// path saved by onHttpRequestHeaders, otherwise "/v1/chat/completions".
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
	// if the request is from this plugin, continue the request
	fromThisPlugin, ok := ctx.GetContext(FROM_THIS_PLUGIN_KEY).(bool)
	if ok && fromThisPlugin {
		log.Debugf("detected buffer_request, sending request to AI service")
		return types.ActionContinue
	}

	var headers [][2]string
	if h, ok := ctx.GetContext("headers").([][2]string); ok {
		headers = append(h, [2]string{EXTEND_HEADER_KEY, "true"})
	} else {
		log.Debugf("cannot get headers from context, use default headers")
		headers = [][2]string{
			{"Content-Type", "application/json"},
			{EXTEND_HEADER_KEY, "true"},
		}
	}

	// if there is any error in the config, return the response directly
	if config.rejectStruct.RejectCode != HTTP_STATUS_OK {
		sendResponse(ctx, config, log, nil)
		return types.ActionContinue
	}

	// Declare path once: an `if path, ok := ...` form would shadow it and
	// discard both the context value and the default.
	path, ok := ctx.GetContext("path").(string)
	if ok {
		log.Debugf("use path: %s", path)
	} else {
		log.Debugf("cannot get path from context, use default path")
		path = "/v1/chat/completions"
	}
	if config.servicePath != "" {
		log.Debugf("use base path: %s", config.servicePath)
		path = config.servicePath
	}

	requestContext := &RequestContext{
		Path:       path,
		ReqHeaders: headers,
		ReqBody:    body,
	}

	err := config.serviceClient.Post(requestContext.Path, requestContext.ReqHeaders, requestContext.ReqBody,
		func(statusCode int, responseHeaders http.Header, responseBody []byte) {
			if err := config.ValidateBody(responseBody); err != nil {
				sendResponse(ctx, config, log, nil)
				return
			}
			requestContext.SaveBodyToHistMsg(log, body, responseBody)
			log.Debugf("[first request] resp code: %d", statusCode)

			validateJson, err := config.ValidateJson(responseBody, log)
			if err == nil {
				sendResponse(ctx, config, log, []byte(validateJson))
				return
			}
			requestContext.SaveStrToHistMsg(log, err.Error())
			recursiveRefineJson(ctx, config, log, 0, requestContext)
		}, uint32(config.serviceTimeout))
	if err != nil {
		// The callback never runs when dispatch fails, so answer here
		// instead of leaving the request paused forever.
		log.Infof("dispatch request to AI service failed: %v", err)
		config.rejectStruct = RejectStruct{SERVICE_UNAVAILABLE_CODE, "service unavailable: " + err.Error()}
		sendResponse(ctx, config, log, nil)
	}
	return types.ActionPause
}