mirror of
https://github.com/alibaba/higress.git
synced 2026-03-01 23:20:52 +08:00
538 lines
17 KiB
Go
538 lines
17 KiB
Go
package main
|
||
|
||
import (
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"net/http"
|
||
"net/url"
|
||
"regexp"
|
||
"strings"
|
||
|
||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-agent/dashscope"
|
||
prompttpl "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-agent/promptTpl"
|
||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||
"github.com/tidwall/gjson"
|
||
)
|
||
|
||
// ToolCallsCount is the per-request context key that counts how many times
// the tool-calling loop has recursed.
const ToolCallsCount = "ToolCallsCount"

// StreamContextKey marks (when set) that the caller originally asked for a
// streaming response; the agent loop runs non-streamed and re-streams at the end.
const StreamContextKey = "Stream"

// Regex patterns for parsing ReAct-style output from the LLM.
const ActionPattern = `Action:\s*(.*?)[.\n]`
const ActionInputPattern = `Action Input:\s*(.*)`
const FinalAnswerPattern = `Final Answer:(.*)`
// main registers the ai-agent plugin with the host: the config parser plus
// handlers for request/response headers and bodies.
func main() {
	wrapper.SetCtx(
		"ai-agent",
		wrapper.ParseConfigBy(parseConfig),
		wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
		wrapper.ProcessRequestBodyBy(onHttpRequestBody),
		wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
		wrapper.ProcessResponseBodyBy(onHttpResponseBody),
	)
}
func parseConfig(gjson gjson.Result, c *PluginConfig, log wrapper.Log) error {
|
||
initResponsePromptTpl(gjson, c)
|
||
|
||
err := initAPIs(gjson, c)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
initReActPromptTpl(gjson, c)
|
||
|
||
initLLMClient(gjson, c)
|
||
|
||
initJsonResp(gjson, c)
|
||
|
||
return nil
|
||
}
|
||
|
||
// onHttpRequestHeaders performs no header processing; all agent logic runs
// on the request body.
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
	return types.ActionContinue
}
func firstReq(ctx wrapper.HttpContext, config PluginConfig, prompt string, rawRequest Request, log wrapper.Log) types.Action {
|
||
log.Debugf("[onHttpRequestBody] firstreq:%s", prompt)
|
||
|
||
var userMessage Message
|
||
userMessage.Role = "user"
|
||
userMessage.Content = prompt
|
||
|
||
newMessages := []Message{userMessage}
|
||
rawRequest.Messages = newMessages
|
||
if rawRequest.Stream {
|
||
ctx.SetContext(StreamContextKey, struct{}{})
|
||
rawRequest.Stream = false
|
||
}
|
||
|
||
//replace old message and resume request qwen
|
||
newbody, err := json.Marshal(rawRequest)
|
||
if err != nil {
|
||
return types.ActionContinue
|
||
} else {
|
||
log.Debugf("[onHttpRequestBody] newRequestBody: %s", string(newbody))
|
||
err := proxywasm.ReplaceHttpRequestBody(newbody)
|
||
if err != nil {
|
||
log.Debugf("failed replace err: %s", err.Error())
|
||
proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnResponseTemplate, "替换失败"+err.Error())), -1)
|
||
}
|
||
log.Debug("[onHttpRequestBody] replace request success")
|
||
return types.ActionContinue
|
||
}
|
||
}
|
||
|
||
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
|
||
log.Debug("onHttpRequestBody start")
|
||
defer log.Debug("onHttpRequestBody end")
|
||
|
||
//拿到请求
|
||
var rawRequest Request
|
||
err := json.Unmarshal(body, &rawRequest)
|
||
if err != nil {
|
||
log.Debugf("[onHttpRequestBody] body json umarshal err: %s", err.Error())
|
||
return types.ActionContinue
|
||
}
|
||
log.Debugf("onHttpRequestBody rawRequest: %v", rawRequest)
|
||
|
||
//获取用户query
|
||
var query string
|
||
var history string
|
||
messageLength := len(rawRequest.Messages)
|
||
log.Debugf("[onHttpRequestBody] messageLength: %s", messageLength)
|
||
if messageLength > 0 {
|
||
query = rawRequest.Messages[messageLength-1].Content
|
||
log.Debugf("[onHttpRequestBody] query: %s", query)
|
||
if messageLength >= 3 {
|
||
for i := 0; i < messageLength-1; i += 2 {
|
||
history += "human: " + rawRequest.Messages[i].Content + "\nAI: " + rawRequest.Messages[i+1].Content
|
||
}
|
||
} else {
|
||
history = ""
|
||
}
|
||
} else {
|
||
return types.ActionContinue
|
||
}
|
||
|
||
if query == "" {
|
||
log.Debug("parse query from request body failed")
|
||
return types.ActionContinue
|
||
}
|
||
|
||
//拼装agent prompt模板
|
||
tool_desc := make([]string, 0)
|
||
tool_names := make([]string, 0)
|
||
for _, apisParam := range config.APIsParam {
|
||
for _, tool_param := range apisParam.ToolsParam {
|
||
tool_desc = append(tool_desc, fmt.Sprintf(prompttpl.TOOL_DESC, tool_param.ToolName, tool_param.Description, tool_param.Description, tool_param.Description, tool_param.Parameter), "\n")
|
||
tool_names = append(tool_names, tool_param.ToolName)
|
||
}
|
||
}
|
||
|
||
var prompt string
|
||
if config.PromptTemplate.Language == "CH" {
|
||
prompt = fmt.Sprintf(prompttpl.CH_Template,
|
||
tool_desc,
|
||
tool_names,
|
||
config.PromptTemplate.CHTemplate.Question,
|
||
config.PromptTemplate.CHTemplate.Thought1,
|
||
config.PromptTemplate.CHTemplate.Observation,
|
||
config.PromptTemplate.CHTemplate.Thought2,
|
||
history,
|
||
query)
|
||
} else {
|
||
prompt = fmt.Sprintf(prompttpl.EN_Template,
|
||
tool_desc,
|
||
tool_names,
|
||
config.PromptTemplate.ENTemplate.Question,
|
||
config.PromptTemplate.ENTemplate.Thought1,
|
||
config.PromptTemplate.ENTemplate.Observation,
|
||
config.PromptTemplate.ENTemplate.Thought2,
|
||
history,
|
||
query)
|
||
}
|
||
|
||
ctx.SetContext(ToolCallsCount, 0)
|
||
|
||
//清理历史对话记录
|
||
dashscope.MessageStore.Clear()
|
||
|
||
//将请求加入到历史对话存储器中
|
||
dashscope.MessageStore.AddForUser(prompt)
|
||
|
||
//开始第一次请求
|
||
ret := firstReq(ctx, config, prompt, rawRequest, log)
|
||
|
||
return ret
|
||
}
|
||
|
||
// onHttpResponseHeaders performs no header processing; the agent loop is
// driven entirely from the response body handler.
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
	log.Debug("onHttpResponseHeaders start")
	defer log.Debug("onHttpResponseHeaders end")

	return types.ActionContinue
}
// extractJson pulls the first-"{"-to-last-"}" substring out of bodyStr and
// verifies that it parses as a JSON object. It returns the substring on
// success, or an error when no brace pair exists or the candidate is not
// valid JSON.
func extractJson(bodyStr string) (string, error) {
	// Locate the outermost braces bounding a JSON-object candidate.
	first := strings.Index(bodyStr, "{")
	last := strings.LastIndex(bodyStr, "}")
	if first == -1 || first > last {
		return "", errors.New("cannot find json in the response body")
	}

	candidate := bodyStr[first : last+1]

	// Validate the candidate by attempting to unmarshal it.
	var parsed map[string]interface{}
	if err := json.Unmarshal([]byte(candidate), &parsed); err != nil {
		return "", err
	}
	return candidate, nil
}
func jsonFormat(llmClient wrapper.HttpClient, llmInfo LLMInfo, jsonSchema map[string]interface{}, assistantMessage Message, actionInput string, headers [][2]string, streamMode bool, rawResponse Response, log wrapper.Log) string {
|
||
prompt := fmt.Sprintf(prompttpl.Json_Resp_Template, jsonSchema, actionInput)
|
||
|
||
messages := []dashscope.Message{{Role: "user", Content: prompt}}
|
||
|
||
completion := dashscope.Completion{
|
||
Model: llmInfo.Model,
|
||
Messages: messages,
|
||
}
|
||
|
||
completionSerialized, _ := json.Marshal(completion)
|
||
var content string
|
||
err := llmClient.Post(
|
||
llmInfo.Path,
|
||
headers,
|
||
completionSerialized,
|
||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||
//得到gpt的返回结果
|
||
var responseCompletion dashscope.CompletionResponse
|
||
_ = json.Unmarshal(responseBody, &responseCompletion)
|
||
log.Infof("[jsonFormat] content: %s", responseCompletion.Choices[0].Message.Content)
|
||
content = responseCompletion.Choices[0].Message.Content
|
||
jsonStr, err := extractJson(content)
|
||
if err != nil {
|
||
log.Debugf("[onHttpRequestBody] extractJson err: %s", err.Error())
|
||
jsonStr = content
|
||
}
|
||
|
||
if streamMode {
|
||
stream(jsonStr, rawResponse, log)
|
||
} else {
|
||
noneStream(assistantMessage, jsonStr, rawResponse, log)
|
||
}
|
||
}, uint32(llmInfo.MaxExecutionTime))
|
||
if err != nil {
|
||
log.Debugf("[onHttpRequestBody] completion err: %s", err.Error())
|
||
proxywasm.ResumeHttpRequest()
|
||
}
|
||
return content
|
||
}
|
||
|
||
func noneStream(assistantMessage Message, actionInput string, rawResponse Response, log wrapper.Log) {
|
||
assistantMessage.Role = "assistant"
|
||
assistantMessage.Content = actionInput
|
||
rawResponse.Choices[0].Message = assistantMessage
|
||
newbody, err := json.Marshal(rawResponse)
|
||
if err != nil {
|
||
proxywasm.ResumeHttpResponse()
|
||
return
|
||
} else {
|
||
proxywasm.ReplaceHttpResponseBody(newbody)
|
||
|
||
log.Debug("[onHttpResponseBody] replace response success")
|
||
proxywasm.ResumeHttpResponse()
|
||
}
|
||
}
|
||
|
||
func stream(actionInput string, rawResponse Response, log wrapper.Log) {
|
||
headers := [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}
|
||
proxywasm.ReplaceHttpResponseHeaders(headers)
|
||
// Remove quotes from actionInput
|
||
actionInput = strings.Trim(actionInput, "\"")
|
||
returnStreamResponseTemplate := `data:{"id":"%s","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"%s","object":"chat.completion","usage":{"prompt_tokens":%d,"completion_tokens":%d,"total_tokens":%d}}` + "\n\ndata:[DONE]\n\n"
|
||
newbody := fmt.Sprintf(returnStreamResponseTemplate, rawResponse.ID, actionInput, rawResponse.Model, rawResponse.Usage.PromptTokens, rawResponse.Usage.CompletionTokens, rawResponse.Usage.TotalTokens)
|
||
log.Infof("[onHttpResponseBody] newResponseBody: ", newbody)
|
||
proxywasm.ReplaceHttpResponseBody([]byte(newbody))
|
||
|
||
log.Debug("[onHttpResponseBody] replace response success")
|
||
proxywasm.ResumeHttpResponse()
|
||
}
|
||
|
||
func toolsCallResult(ctx wrapper.HttpContext, llmClient wrapper.HttpClient, llmInfo LLMInfo, jsonResp JsonResp, aPIsParam []APIsParam, aPIClient []wrapper.HttpClient, content string, rawResponse Response, log wrapper.Log, statusCode int, responseBody []byte) {
|
||
if statusCode != http.StatusOK {
|
||
log.Debugf("statusCode: %d", statusCode)
|
||
}
|
||
log.Info("========function result========")
|
||
log.Infof(string(responseBody))
|
||
|
||
observation := "Observation: " + string(responseBody)
|
||
|
||
dashscope.MessageStore.AddForUser(observation)
|
||
|
||
completion := dashscope.Completion{
|
||
Model: llmInfo.Model,
|
||
Messages: dashscope.MessageStore,
|
||
MaxTokens: llmInfo.MaxTokens,
|
||
}
|
||
|
||
headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + llmInfo.APIKey}}
|
||
completionSerialized, _ := json.Marshal(completion)
|
||
err := llmClient.Post(
|
||
llmInfo.Path,
|
||
headers,
|
||
completionSerialized,
|
||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||
//得到gpt的返回结果
|
||
var responseCompletion dashscope.CompletionResponse
|
||
_ = json.Unmarshal(responseBody, &responseCompletion)
|
||
log.Infof("[toolsCall] content: %s", responseCompletion.Choices[0].Message.Content)
|
||
|
||
if responseCompletion.Choices[0].Message.Content != "" {
|
||
retType, actionInput := toolsCall(ctx, llmClient, llmInfo, jsonResp, aPIsParam, aPIClient, responseCompletion.Choices[0].Message.Content, rawResponse, log)
|
||
if retType == types.ActionContinue {
|
||
//得到了Final Answer
|
||
var assistantMessage Message
|
||
var streamMode bool
|
||
if ctx.GetContext(StreamContextKey) == nil {
|
||
streamMode = false
|
||
if jsonResp.Enable {
|
||
jsonFormat(llmClient, llmInfo, jsonResp.JsonSchema, assistantMessage, actionInput, headers, streamMode, rawResponse, log)
|
||
} else {
|
||
noneStream(assistantMessage, actionInput, rawResponse, log)
|
||
}
|
||
} else {
|
||
streamMode = true
|
||
if jsonResp.Enable {
|
||
jsonFormat(llmClient, llmInfo, jsonResp.JsonSchema, assistantMessage, actionInput, headers, streamMode, rawResponse, log)
|
||
} else {
|
||
stream(actionInput, rawResponse, log)
|
||
}
|
||
}
|
||
}
|
||
} else {
|
||
proxywasm.ResumeHttpRequest()
|
||
}
|
||
}, uint32(llmInfo.MaxExecutionTime))
|
||
if err != nil {
|
||
log.Debugf("[onHttpRequestBody] completion err: %s", err.Error())
|
||
proxywasm.ResumeHttpRequest()
|
||
}
|
||
}
|
||
|
||
func outputParser(response string, log wrapper.Log) (string, string) {
|
||
log.Debugf("Raw response:%s", response)
|
||
|
||
start := strings.Index(response, "```")
|
||
end := strings.LastIndex(response, "```")
|
||
|
||
var jsonStr string
|
||
if start != -1 && end != -1 {
|
||
jsonStr = strings.TrimSpace(response[start+3 : end])
|
||
} else {
|
||
jsonStr = response
|
||
}
|
||
|
||
log.Debugf("Extracted JSON string:%s", jsonStr)
|
||
|
||
var action map[string]interface{}
|
||
err := json.Unmarshal([]byte(jsonStr), &action)
|
||
if err == nil {
|
||
var actionName, actionInput string
|
||
for key, value := range action {
|
||
if strings.Contains(strings.ToLower(key), "input") {
|
||
actionInput = fmt.Sprintf("%v", value)
|
||
} else {
|
||
actionName = fmt.Sprintf("%v", value)
|
||
}
|
||
}
|
||
if actionName != "" && actionInput != "" {
|
||
return actionName, actionInput
|
||
}
|
||
}
|
||
log.Debugf("json parse err: %s", err.Error())
|
||
// Fallback to regex parsing if JSON unmarshaling fails
|
||
pattern := `\{\s*"action":\s*"([^"]+)",\s*"action_input":\s*((?:\{[^}]+\})|"[^"]+")\s*\}`
|
||
re := regexp.MustCompile(pattern)
|
||
match := re.FindStringSubmatch(jsonStr)
|
||
|
||
if len(match) == 3 {
|
||
action := match[1]
|
||
actionInput := match[2]
|
||
log.Debugf("Parsed action:%s, action_input:%s", action, actionInput)
|
||
return action, actionInput
|
||
}
|
||
|
||
log.Debug("No valid action and action_input found in the response")
|
||
return "", ""
|
||
}
|
||
|
||
func toolsCall(ctx wrapper.HttpContext, llmClient wrapper.HttpClient, llmInfo LLMInfo, jsonResp JsonResp, aPIsParam []APIsParam, aPIClient []wrapper.HttpClient, content string, rawResponse Response, log wrapper.Log) (types.Action, string) {
|
||
dashscope.MessageStore.AddForAssistant(content)
|
||
|
||
action, actionInput := outputParser(content, log)
|
||
|
||
//得到最终答案
|
||
if action == "Final Answer" {
|
||
return types.ActionContinue, actionInput
|
||
}
|
||
count := ctx.GetContext(ToolCallsCount).(int)
|
||
count++
|
||
log.Debugf("toolCallsCount:%d, config.LLMInfo.MaxIterations=%d", count, llmInfo.MaxIterations)
|
||
//函数递归调用次数,达到了预设的循环次数,强制结束
|
||
if int64(count) > llmInfo.MaxIterations {
|
||
ctx.SetContext(ToolCallsCount, 0)
|
||
return types.ActionContinue, ""
|
||
} else {
|
||
ctx.SetContext(ToolCallsCount, count)
|
||
}
|
||
|
||
//没得到最终答案
|
||
|
||
var urlStr string
|
||
var headers [][2]string
|
||
var apiClient wrapper.HttpClient
|
||
var method string
|
||
var reqBody []byte
|
||
var maxExecutionTime int64
|
||
|
||
for i, apisParam := range aPIsParam {
|
||
maxExecutionTime = apisParam.MaxExecutionTime
|
||
for _, tools_param := range apisParam.ToolsParam {
|
||
if action == tools_param.ToolName {
|
||
log.Infof("calls %s", tools_param.ToolName)
|
||
log.Infof("actionInput: %s", actionInput)
|
||
|
||
//将大模型需要的参数反序列化
|
||
var data map[string]interface{}
|
||
if err := json.Unmarshal([]byte(actionInput), &data); err != nil {
|
||
log.Debugf("Error: %s", err.Error())
|
||
return types.ActionContinue, ""
|
||
}
|
||
|
||
method = tools_param.Method
|
||
|
||
// 组装 URL 和请求体
|
||
urlStr = apisParam.URL + tools_param.Path
|
||
|
||
// 解析URL模板以查找路径参数
|
||
urlParts := strings.Split(urlStr, "/")
|
||
for i, part := range urlParts {
|
||
if strings.Contains(part, "{") && strings.Contains(part, "}") {
|
||
for _, param := range tools_param.ParamName {
|
||
paramNameInPath := part[1 : len(part)-1]
|
||
if paramNameInPath == param {
|
||
if value, ok := data[param]; ok {
|
||
// 删除已经使用过的
|
||
delete(data, param)
|
||
// 替换模板中的占位符
|
||
urlParts[i] = url.QueryEscape(value.(string))
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 重新组合URL
|
||
urlStr = strings.Join(urlParts, "/")
|
||
|
||
queryParams := make([][2]string, 0)
|
||
if method == "GET" {
|
||
for _, param := range tools_param.ParamName {
|
||
if value, ok := data[param]; ok {
|
||
queryParams = append(queryParams, [2]string{param, fmt.Sprintf("%v", value)})
|
||
}
|
||
}
|
||
} else if method == "POST" {
|
||
var err error
|
||
reqBody, err = json.Marshal(data)
|
||
if err != nil {
|
||
log.Debugf("Error marshaling JSON: %s", err.Error())
|
||
return types.ActionContinue, ""
|
||
}
|
||
}
|
||
|
||
// 组装 headers 和 key
|
||
headers = [][2]string{{"Content-Type", "application/json"}}
|
||
if apisParam.APIKey.Name != "" {
|
||
if apisParam.APIKey.In == "query" {
|
||
queryParams = append(queryParams, [2]string{apisParam.APIKey.Name, apisParam.APIKey.Value})
|
||
} else if apisParam.APIKey.In == "header" {
|
||
headers = append(headers, [2]string{"Authorization", apisParam.APIKey.Name + " " + apisParam.APIKey.Value})
|
||
}
|
||
}
|
||
|
||
if len(queryParams) > 0 {
|
||
// 将 key 拼接到 url 后面
|
||
urlStr += "?"
|
||
for i, param := range queryParams {
|
||
if i != 0 {
|
||
urlStr += "&"
|
||
}
|
||
urlStr += url.QueryEscape(param[0]) + "=" + url.QueryEscape(param[1])
|
||
}
|
||
}
|
||
|
||
log.Debugf("url: %s", urlStr)
|
||
|
||
apiClient = aPIClient[i]
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
if apiClient != nil {
|
||
err := apiClient.Call(
|
||
method,
|
||
urlStr,
|
||
headers,
|
||
reqBody,
|
||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||
toolsCallResult(ctx, llmClient, llmInfo, jsonResp, aPIsParam, aPIClient, content, rawResponse, log, statusCode, responseBody)
|
||
}, uint32(maxExecutionTime))
|
||
if err != nil {
|
||
log.Debugf("tool calls error: %s", err.Error())
|
||
proxywasm.ResumeHttpRequest()
|
||
}
|
||
} else {
|
||
return types.ActionContinue, ""
|
||
}
|
||
|
||
return types.ActionPause, ""
|
||
}
|
||
|
||
// 从response接收到firstreq的大模型返回
|
||
func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
|
||
log.Debugf("onHttpResponseBody start")
|
||
defer log.Debugf("onHttpResponseBody end")
|
||
|
||
//初始化接收gpt返回内容的结构体
|
||
var rawResponse Response
|
||
err := json.Unmarshal(body, &rawResponse)
|
||
if err != nil {
|
||
log.Debugf("[onHttpResponseBody] body to json err: %s", err.Error())
|
||
return types.ActionContinue
|
||
}
|
||
log.Infof("first content: %s", rawResponse.Choices[0].Message.Content)
|
||
//如果gpt返回的内容不是空的
|
||
if rawResponse.Choices[0].Message.Content != "" {
|
||
//进入agent的循环思考,工具调用的过程中
|
||
retType, _ := toolsCall(ctx, config.LLMClient, config.LLMInfo, config.JsonResp, config.APIsParam, config.APIClient, rawResponse.Choices[0].Message.Content, rawResponse, log)
|
||
return retType
|
||
} else {
|
||
return types.ActionContinue
|
||
}
|
||
}
|