mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 04:37:31 +08:00
AI Agent plugin adds JSON formatting output feature (#1374)
This commit is contained in:
@@ -2,8 +2,10 @@ package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
@@ -47,6 +49,8 @@ func parseConfig(gjson gjson.Result, c *PluginConfig, log wrapper.Log) error {
|
||||
|
||||
initLLMClient(gjson, c)
|
||||
|
||||
initJsonResp(gjson, c)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -76,10 +80,10 @@ func firstReq(ctx wrapper.HttpContext, config PluginConfig, prompt string, rawRe
|
||||
log.Debugf("[onHttpRequestBody] newRequestBody: %s", string(newbody))
|
||||
err := proxywasm.ReplaceHttpRequestBody(newbody)
|
||||
if err != nil {
|
||||
log.Debug("替换失败")
|
||||
log.Debugf("failed replace err: %s", err.Error())
|
||||
proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnResponseTemplate, "替换失败"+err.Error())), -1)
|
||||
}
|
||||
log.Debug("[onHttpRequestBody] request替换成功")
|
||||
log.Debug("[onHttpRequestBody] replace request success")
|
||||
return types.ActionContinue
|
||||
}
|
||||
}
|
||||
@@ -175,11 +179,103 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wra
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func toolsCallResult(ctx wrapper.HttpContext, config PluginConfig, content string, rawResponse Response, log wrapper.Log, statusCode int, responseBody []byte) {
|
||||
func extractJson(bodyStr string) (string, error) {
|
||||
// simply extract json from response body string
|
||||
startIndex := strings.Index(bodyStr, "{")
|
||||
endIndex := strings.LastIndex(bodyStr, "}") + 1
|
||||
|
||||
// if not found
|
||||
if startIndex == -1 || startIndex >= endIndex {
|
||||
return "", errors.New("cannot find json in the response body")
|
||||
}
|
||||
|
||||
jsonStr := bodyStr[startIndex:endIndex]
|
||||
|
||||
// attempt to parse the JSON
|
||||
var result map[string]interface{}
|
||||
err := json.Unmarshal([]byte(jsonStr), &result)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return jsonStr, nil
|
||||
}
|
||||
|
||||
func jsonFormat(llmClient wrapper.HttpClient, llmInfo LLMInfo, jsonSchema map[string]interface{}, assistantMessage Message, actionInput string, headers [][2]string, streamMode bool, rawResponse Response, log wrapper.Log) string {
|
||||
prompt := fmt.Sprintf(prompttpl.Json_Resp_Template, jsonSchema, actionInput)
|
||||
|
||||
messages := []dashscope.Message{{Role: "user", Content: prompt}}
|
||||
|
||||
completion := dashscope.Completion{
|
||||
Model: llmInfo.Model,
|
||||
Messages: messages,
|
||||
}
|
||||
|
||||
completionSerialized, _ := json.Marshal(completion)
|
||||
var content string
|
||||
err := llmClient.Post(
|
||||
llmInfo.Path,
|
||||
headers,
|
||||
completionSerialized,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
//得到gpt的返回结果
|
||||
var responseCompletion dashscope.CompletionResponse
|
||||
_ = json.Unmarshal(responseBody, &responseCompletion)
|
||||
log.Infof("[jsonFormat] content: %s", responseCompletion.Choices[0].Message.Content)
|
||||
content = responseCompletion.Choices[0].Message.Content
|
||||
jsonStr, err := extractJson(content)
|
||||
if err != nil {
|
||||
log.Debugf("[onHttpRequestBody] extractJson err: %s", err.Error())
|
||||
jsonStr = content
|
||||
}
|
||||
|
||||
if streamMode {
|
||||
stream(jsonStr, rawResponse, log)
|
||||
} else {
|
||||
noneStream(assistantMessage, jsonStr, rawResponse, log)
|
||||
}
|
||||
}, uint32(llmInfo.MaxExecutionTime))
|
||||
if err != nil {
|
||||
log.Debugf("[onHttpRequestBody] completion err: %s", err.Error())
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
func noneStream(assistantMessage Message, actionInput string, rawResponse Response, log wrapper.Log) {
|
||||
assistantMessage.Role = "assistant"
|
||||
assistantMessage.Content = actionInput
|
||||
rawResponse.Choices[0].Message = assistantMessage
|
||||
newbody, err := json.Marshal(rawResponse)
|
||||
if err != nil {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
return
|
||||
} else {
|
||||
proxywasm.ReplaceHttpResponseBody(newbody)
|
||||
|
||||
log.Debug("[onHttpResponseBody] replace response success")
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
}
|
||||
|
||||
func stream(actionInput string, rawResponse Response, log wrapper.Log) {
|
||||
headers := [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}
|
||||
proxywasm.ReplaceHttpResponseHeaders(headers)
|
||||
// Remove quotes from actionInput
|
||||
actionInput = strings.Trim(actionInput, "\"")
|
||||
returnStreamResponseTemplate := `data:{"id":"%s","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"%s","object":"chat.completion","usage":{"prompt_tokens":%d,"completion_tokens":%d,"total_tokens":%d}}` + "\n\ndata:[DONE]\n\n"
|
||||
newbody := fmt.Sprintf(returnStreamResponseTemplate, rawResponse.ID, actionInput, rawResponse.Model, rawResponse.Usage.PromptTokens, rawResponse.Usage.CompletionTokens, rawResponse.Usage.TotalTokens)
|
||||
log.Infof("[onHttpResponseBody] newResponseBody: ", newbody)
|
||||
proxywasm.ReplaceHttpResponseBody([]byte(newbody))
|
||||
|
||||
log.Debug("[onHttpResponseBody] replace response success")
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
|
||||
func toolsCallResult(ctx wrapper.HttpContext, llmClient wrapper.HttpClient, llmInfo LLMInfo, jsonResp JsonResp, aPIsParam []APIsParam, aPIClient []wrapper.HttpClient, content string, rawResponse Response, log wrapper.Log, statusCode int, responseBody []byte) {
|
||||
if statusCode != http.StatusOK {
|
||||
log.Debugf("statusCode: %d", statusCode)
|
||||
}
|
||||
log.Info("========函数返回结果========")
|
||||
log.Info("========function result========")
|
||||
log.Infof(string(responseBody))
|
||||
|
||||
observation := "Observation: " + string(responseBody)
|
||||
@@ -187,15 +283,15 @@ func toolsCallResult(ctx wrapper.HttpContext, config PluginConfig, content strin
|
||||
dashscope.MessageStore.AddForUser(observation)
|
||||
|
||||
completion := dashscope.Completion{
|
||||
Model: config.LLMInfo.Model,
|
||||
Model: llmInfo.Model,
|
||||
Messages: dashscope.MessageStore,
|
||||
MaxTokens: config.LLMInfo.MaxTokens,
|
||||
MaxTokens: llmInfo.MaxTokens,
|
||||
}
|
||||
|
||||
headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.LLMInfo.APIKey}}
|
||||
headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + llmInfo.APIKey}}
|
||||
completionSerialized, _ := json.Marshal(completion)
|
||||
err := config.LLMClient.Post(
|
||||
config.LLMInfo.Path,
|
||||
err := llmClient.Post(
|
||||
llmInfo.Path,
|
||||
headers,
|
||||
completionSerialized,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
@@ -205,42 +301,31 @@ func toolsCallResult(ctx wrapper.HttpContext, config PluginConfig, content strin
|
||||
log.Infof("[toolsCall] content: %s", responseCompletion.Choices[0].Message.Content)
|
||||
|
||||
if responseCompletion.Choices[0].Message.Content != "" {
|
||||
retType, actionInput := toolsCall(ctx, config, responseCompletion.Choices[0].Message.Content, rawResponse, log)
|
||||
retType, actionInput := toolsCall(ctx, llmClient, llmInfo, jsonResp, aPIsParam, aPIClient, responseCompletion.Choices[0].Message.Content, rawResponse, log)
|
||||
if retType == types.ActionContinue {
|
||||
//得到了Final Answer
|
||||
var assistantMessage Message
|
||||
var streamMode bool
|
||||
if ctx.GetContext(StreamContextKey) == nil {
|
||||
assistantMessage.Role = "assistant"
|
||||
assistantMessage.Content = actionInput
|
||||
rawResponse.Choices[0].Message = assistantMessage
|
||||
newbody, err := json.Marshal(rawResponse)
|
||||
if err != nil {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
return
|
||||
streamMode = false
|
||||
if jsonResp.Enable {
|
||||
jsonFormat(llmClient, llmInfo, jsonResp.JsonSchema, assistantMessage, actionInput, headers, streamMode, rawResponse, log)
|
||||
} else {
|
||||
proxywasm.ReplaceHttpResponseBody(newbody)
|
||||
|
||||
log.Debug("[onHttpResponseBody] response替换成功")
|
||||
proxywasm.ResumeHttpResponse()
|
||||
noneStream(assistantMessage, actionInput, rawResponse, log)
|
||||
}
|
||||
} else {
|
||||
headers := [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}
|
||||
proxywasm.ReplaceHttpResponseHeaders(headers)
|
||||
// Remove quotes from actionInput
|
||||
actionInput = strings.Trim(actionInput, "\"")
|
||||
returnStreamResponseTemplate := `data:{"id":"%s","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"%s","object":"chat.completion","usage":{"prompt_tokens":%d,"completion_tokens":%d,"total_tokens":%d}}` + "\n\ndata:[DONE]\n\n"
|
||||
newbody := fmt.Sprintf(returnStreamResponseTemplate, rawResponse.ID, actionInput, rawResponse.Model, rawResponse.Usage.PromptTokens, rawResponse.Usage.CompletionTokens, rawResponse.Usage.TotalTokens)
|
||||
log.Infof("[onHttpResponseBody] newResponseBody: ", newbody)
|
||||
proxywasm.ReplaceHttpResponseBody([]byte(newbody))
|
||||
|
||||
log.Debug("[onHttpResponseBody] response替换成功")
|
||||
proxywasm.ResumeHttpResponse()
|
||||
streamMode = true
|
||||
if jsonResp.Enable {
|
||||
jsonFormat(llmClient, llmInfo, jsonResp.JsonSchema, assistantMessage, actionInput, headers, streamMode, rawResponse, log)
|
||||
} else {
|
||||
stream(actionInput, rawResponse, log)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
}, uint32(config.LLMInfo.MaxExecutionTime))
|
||||
}, uint32(llmInfo.MaxExecutionTime))
|
||||
if err != nil {
|
||||
log.Debugf("[onHttpRequestBody] completion err: %s", err.Error())
|
||||
proxywasm.ResumeHttpRequest()
|
||||
@@ -294,7 +379,7 @@ func outputParser(response string, log wrapper.Log) (string, string) {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func toolsCall(ctx wrapper.HttpContext, config PluginConfig, content string, rawResponse Response, log wrapper.Log) (types.Action, string) {
|
||||
func toolsCall(ctx wrapper.HttpContext, llmClient wrapper.HttpClient, llmInfo LLMInfo, jsonResp JsonResp, aPIsParam []APIsParam, aPIClient []wrapper.HttpClient, content string, rawResponse Response, log wrapper.Log) (types.Action, string) {
|
||||
dashscope.MessageStore.AddForAssistant(content)
|
||||
|
||||
action, actionInput := outputParser(content, log)
|
||||
@@ -305,9 +390,9 @@ func toolsCall(ctx wrapper.HttpContext, config PluginConfig, content string, raw
|
||||
}
|
||||
count := ctx.GetContext(ToolCallsCount).(int)
|
||||
count++
|
||||
log.Debugf("toolCallsCount:%d, config.LLMInfo.MaxIterations=%d", count, config.LLMInfo.MaxIterations)
|
||||
log.Debugf("toolCallsCount:%d, config.LLMInfo.MaxIterations=%d", count, llmInfo.MaxIterations)
|
||||
//函数递归调用次数,达到了预设的循环次数,强制结束
|
||||
if int64(count) > config.LLMInfo.MaxIterations {
|
||||
if int64(count) > llmInfo.MaxIterations {
|
||||
ctx.SetContext(ToolCallsCount, 0)
|
||||
return types.ActionContinue, ""
|
||||
} else {
|
||||
@@ -316,15 +401,14 @@ func toolsCall(ctx wrapper.HttpContext, config PluginConfig, content string, raw
|
||||
|
||||
//没得到最终答案
|
||||
|
||||
var url string
|
||||
var urlStr string
|
||||
var headers [][2]string
|
||||
var apiClient wrapper.HttpClient
|
||||
var method string
|
||||
var reqBody []byte
|
||||
var key string
|
||||
var maxExecutionTime int64
|
||||
|
||||
for i, apisParam := range config.APIsParam {
|
||||
for i, apisParam := range aPIsParam {
|
||||
maxExecutionTime = apisParam.MaxExecutionTime
|
||||
for _, tools_param := range apisParam.ToolsParam {
|
||||
if action == tools_param.ToolName {
|
||||
@@ -340,28 +424,37 @@ func toolsCall(ctx wrapper.HttpContext, config PluginConfig, content string, raw
|
||||
|
||||
method = tools_param.Method
|
||||
|
||||
// 组装 headers 和 key
|
||||
headers = [][2]string{{"Content-Type", "application/json"}}
|
||||
if apisParam.APIKey.Name != "" {
|
||||
if apisParam.APIKey.In == "query" {
|
||||
key = "?" + apisParam.APIKey.Name + "=" + apisParam.APIKey.Value
|
||||
} else if apisParam.APIKey.In == "header" {
|
||||
headers = append(headers, [2]string{"Authorization", apisParam.APIKey.Name + " " + apisParam.APIKey.Value})
|
||||
// 组装 URL 和请求体
|
||||
urlStr = apisParam.URL + tools_param.Path
|
||||
|
||||
// 解析URL模板以查找路径参数
|
||||
urlParts := strings.Split(urlStr, "/")
|
||||
for i, part := range urlParts {
|
||||
if strings.Contains(part, "{") && strings.Contains(part, "}") {
|
||||
for _, param := range tools_param.ParamName {
|
||||
paramNameInPath := part[1 : len(part)-1]
|
||||
if paramNameInPath == param {
|
||||
if value, ok := data[param]; ok {
|
||||
// 删除已经使用过的
|
||||
delete(data, param)
|
||||
// 替换模板中的占位符
|
||||
urlParts[i] = url.QueryEscape(value.(string))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 组装 URL 和请求体
|
||||
url = apisParam.URL + tools_param.Path + key
|
||||
// 重新组合URL
|
||||
urlStr = strings.Join(urlParts, "/")
|
||||
|
||||
queryParams := make([][2]string, 0)
|
||||
if method == "GET" {
|
||||
queryParams := make([]string, 0, len(tools_param.ParamName))
|
||||
for _, param := range tools_param.ParamName {
|
||||
if value, ok := data[param]; ok {
|
||||
queryParams = append(queryParams, fmt.Sprintf("%s=%v", param, value))
|
||||
queryParams = append(queryParams, [2]string{param, fmt.Sprintf("%v", value)})
|
||||
}
|
||||
}
|
||||
if len(queryParams) > 0 {
|
||||
url += "&" + strings.Join(queryParams, "&")
|
||||
}
|
||||
} else if method == "POST" {
|
||||
var err error
|
||||
reqBody, err = json.Marshal(data)
|
||||
@@ -371,9 +464,30 @@ func toolsCall(ctx wrapper.HttpContext, config PluginConfig, content string, raw
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("url: %s", url)
|
||||
// 组装 headers 和 key
|
||||
headers = [][2]string{{"Content-Type", "application/json"}}
|
||||
if apisParam.APIKey.Name != "" {
|
||||
if apisParam.APIKey.In == "query" {
|
||||
queryParams = append(queryParams, [2]string{apisParam.APIKey.Name, apisParam.APIKey.Value})
|
||||
} else if apisParam.APIKey.In == "header" {
|
||||
headers = append(headers, [2]string{"Authorization", apisParam.APIKey.Name + " " + apisParam.APIKey.Value})
|
||||
}
|
||||
}
|
||||
|
||||
apiClient = config.APIClient[i]
|
||||
if len(queryParams) > 0 {
|
||||
// 将 key 拼接到 url 后面
|
||||
urlStr += "?"
|
||||
for i, param := range queryParams {
|
||||
if i != 0 {
|
||||
urlStr += "&"
|
||||
}
|
||||
urlStr += url.QueryEscape(param[0]) + "=" + url.QueryEscape(param[1])
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("url: %s", urlStr)
|
||||
|
||||
apiClient = aPIClient[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -382,11 +496,11 @@ func toolsCall(ctx wrapper.HttpContext, config PluginConfig, content string, raw
|
||||
if apiClient != nil {
|
||||
err := apiClient.Call(
|
||||
method,
|
||||
url,
|
||||
urlStr,
|
||||
headers,
|
||||
reqBody,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
toolsCallResult(ctx, config, content, rawResponse, log, statusCode, responseBody)
|
||||
toolsCallResult(ctx, llmClient, llmInfo, jsonResp, aPIsParam, aPIClient, content, rawResponse, log, statusCode, responseBody)
|
||||
}, uint32(maxExecutionTime))
|
||||
if err != nil {
|
||||
log.Debugf("tool calls error: %s", err.Error())
|
||||
@@ -415,7 +529,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, body []byt
|
||||
//如果gpt返回的内容不是空的
|
||||
if rawResponse.Choices[0].Message.Content != "" {
|
||||
//进入agent的循环思考,工具调用的过程中
|
||||
retType, _ := toolsCall(ctx, config, rawResponse.Choices[0].Message.Content, rawResponse, log)
|
||||
retType, _ := toolsCall(ctx, config.LLMClient, config.LLMInfo, config.JsonResp, config.APIsParam, config.APIClient, rawResponse.Choices[0].Message.Content, rawResponse, log)
|
||||
return retType
|
||||
} else {
|
||||
return types.ActionContinue
|
||||
|
||||
Reference in New Issue
Block a user