mirror of
https://github.com/alibaba/higress.git
synced 2026-03-06 17:40:51 +08:00
322 lines
10 KiB
Go
322 lines
10 KiB
Go
package main
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"net/http"
|
||
"regexp"
|
||
"strings"
|
||
|
||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-agent/dashscope"
|
||
prompttpl "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-agent/promptTpl"
|
||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||
"github.com/tidwall/gjson"
|
||
)
|
||
|
||
func main() {
|
||
wrapper.SetCtx(
|
||
"ai-agent",
|
||
wrapper.ParseConfigBy(parseConfig),
|
||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
||
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
||
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
|
||
wrapper.ProcessResponseBodyBy(onHttpResponseBody),
|
||
)
|
||
}
|
||
|
||
func parseConfig(gjson gjson.Result, c *PluginConfig, log wrapper.Log) error {
|
||
initResponsePromptTpl(gjson, c)
|
||
|
||
err := initAPIs(gjson, c)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
initReActPromptTpl(gjson, c)
|
||
|
||
initLLMClient(gjson, c)
|
||
|
||
return nil
|
||
}
|
||
|
||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
|
||
return types.ActionContinue
|
||
}
|
||
|
||
func firstReq(config PluginConfig, prompt string, rawRequest Request, log wrapper.Log) types.Action {
|
||
log.Debugf("[onHttpRequestBody] firstreq:%s", prompt)
|
||
|
||
var userMessage Message
|
||
userMessage.Role = "user"
|
||
userMessage.Content = prompt
|
||
|
||
newMessages := []Message{userMessage}
|
||
rawRequest.Messages = newMessages
|
||
|
||
//replace old message and resume request qwen
|
||
newbody, err := json.Marshal(rawRequest)
|
||
if err != nil {
|
||
return types.ActionContinue
|
||
} else {
|
||
log.Debugf("[onHttpRequestBody] newRequestBody: ", string(newbody))
|
||
err := proxywasm.ReplaceHttpRequestBody(newbody)
|
||
if err != nil {
|
||
log.Debug("替换失败")
|
||
proxywasm.SendHttpResponse(200, [][2]string{{"content-type", "application/json; charset=utf-8"}}, []byte(fmt.Sprintf(config.ReturnResponseTemplate, "替换失败"+err.Error())), -1)
|
||
}
|
||
log.Debug("[onHttpRequestBody] request替换成功")
|
||
return types.ActionContinue
|
||
}
|
||
}
|
||
|
||
func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
|
||
log.Debug("onHttpRequestBody start")
|
||
defer log.Debug("onHttpRequestBody end")
|
||
|
||
//拿到请求
|
||
var rawRequest Request
|
||
err := json.Unmarshal(body, &rawRequest)
|
||
if err != nil {
|
||
log.Debugf("[onHttpRequestBody] body json umarshal err: ", err.Error())
|
||
return types.ActionContinue
|
||
}
|
||
log.Debugf("onHttpRequestBody rawRequest: %v", rawRequest)
|
||
|
||
//获取用户query
|
||
var query string
|
||
messageLength := len(rawRequest.Messages)
|
||
log.Debugf("[onHttpRequestBody] messageLength: %s\n", messageLength)
|
||
if messageLength > 0 {
|
||
query = rawRequest.Messages[messageLength-1].Content
|
||
log.Debugf("[onHttpRequestBody] query: %s\n", query)
|
||
} else {
|
||
return types.ActionContinue
|
||
}
|
||
|
||
if query == "" {
|
||
log.Debug("parse query from request body failed")
|
||
return types.ActionContinue
|
||
}
|
||
|
||
//拼装agent prompt模板
|
||
tool_desc := make([]string, 0)
|
||
tool_names := make([]string, 0)
|
||
for _, api_param := range config.API_Param {
|
||
for _, tool_param := range api_param.Tool_Param {
|
||
tool_desc = append(tool_desc, fmt.Sprintf(prompttpl.TOOL_DESC, tool_param.ToolName, tool_param.Desciption, tool_param.Desciption, tool_param.Desciption, tool_param.Parameter), "\n")
|
||
tool_names = append(tool_names, tool_param.ToolName)
|
||
}
|
||
}
|
||
|
||
var prompt string
|
||
if config.PromptTemplate.Language == "CH" {
|
||
prompt = fmt.Sprintf(prompttpl.CH_Template,
|
||
tool_desc,
|
||
config.PromptTemplate.CHTemplate.Question,
|
||
config.PromptTemplate.CHTemplate.Thought1,
|
||
tool_names,
|
||
config.PromptTemplate.CHTemplate.ActionInput,
|
||
config.PromptTemplate.CHTemplate.Observation,
|
||
config.PromptTemplate.CHTemplate.FinalAnswer,
|
||
config.PromptTemplate.CHTemplate.Begin,
|
||
query)
|
||
} else {
|
||
prompt = fmt.Sprintf(prompttpl.EN_Template,
|
||
tool_desc,
|
||
config.PromptTemplate.ENTemplate.Question,
|
||
config.PromptTemplate.ENTemplate.Thought1,
|
||
tool_names,
|
||
config.PromptTemplate.ENTemplate.ActionInput,
|
||
config.PromptTemplate.ENTemplate.Observation,
|
||
config.PromptTemplate.ENTemplate.FinalAnswer,
|
||
config.PromptTemplate.ENTemplate.Begin,
|
||
query)
|
||
}
|
||
|
||
//将请求加入到历史对话存储器中
|
||
dashscope.MessageStore.AddForUser(prompt)
|
||
|
||
//开始第一次请求
|
||
ret := firstReq(config, prompt, rawRequest, log)
|
||
|
||
return ret
|
||
}
|
||
|
||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
|
||
return types.ActionContinue
|
||
}
|
||
|
||
func toolsCall(config PluginConfig, content string, rawResponse Response, log wrapper.Log) types.Action {
|
||
dashscope.MessageStore.AddForAssistant(content)
|
||
|
||
//得到最终答案
|
||
regexPattern := regexp.MustCompile(`Final Answer:(.*)`)
|
||
finalAnswer := regexPattern.FindStringSubmatch(content)
|
||
if len(finalAnswer) > 1 {
|
||
return types.ActionContinue
|
||
}
|
||
|
||
//没得到最终答案
|
||
regexAction := regexp.MustCompile(`Action:\s*(.*?)[.\n]`)
|
||
regexActionInput := regexp.MustCompile(`Action Input:\s*(.*)`)
|
||
|
||
action := regexAction.FindStringSubmatch(content)
|
||
actionInput := regexActionInput.FindStringSubmatch(content)
|
||
|
||
if len(action) > 1 && len(actionInput) > 1 {
|
||
var url string
|
||
var headers [][2]string
|
||
var apiClient wrapper.HttpClient
|
||
var method string
|
||
|
||
for i, api_param := range config.API_Param {
|
||
for _, tool_param := range api_param.Tool_Param {
|
||
if action[1] == tool_param.ToolName {
|
||
log.Infof("calls %s\n", tool_param.ToolName)
|
||
log.Infof("actionInput[1]: %s", actionInput[1])
|
||
|
||
//将大模型需要的参数反序列化
|
||
var data map[string]interface{}
|
||
if err := json.Unmarshal([]byte(actionInput[1]), &data); err != nil {
|
||
log.Debugf("Error: %s\n", err.Error())
|
||
return types.ActionContinue
|
||
}
|
||
|
||
var args string
|
||
for i, param := range tool_param.ParamName { //从参数列表中取出参数
|
||
if i == 0 {
|
||
args = "?" + param + "=%s"
|
||
args = fmt.Sprintf(args, data[param])
|
||
} else {
|
||
args = args + "&" + param + "=%s"
|
||
args = fmt.Sprintf(args, data[param])
|
||
}
|
||
}
|
||
|
||
url = api_param.URL + tool_param.Path + args
|
||
|
||
if api_param.APIKey.Name != "" {
|
||
if api_param.APIKey.In == "query" {
|
||
headers = nil
|
||
key := "&" + api_param.APIKey.Name + "=" + api_param.APIKey.Value
|
||
url += key
|
||
} else if api_param.APIKey.In == "header" {
|
||
headers = [][2]string{{"Content-Type", "application/json"}, {"Authorization", api_param.APIKey.Name + " " + api_param.APIKey.Value}}
|
||
}
|
||
}
|
||
|
||
log.Infof("url: %s\n", url)
|
||
|
||
method = tool_param.Method
|
||
|
||
apiClient = config.APIClient[i]
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
if method == "get" {
|
||
//调用工具
|
||
err := apiClient.Get(
|
||
url,
|
||
headers,
|
||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||
if statusCode != http.StatusOK {
|
||
log.Debugf("statusCode: %d\n", statusCode)
|
||
}
|
||
log.Info("========函数返回结果========")
|
||
log.Infof(string(responseBody))
|
||
|
||
Observation := "Observation: " + string(responseBody)
|
||
|
||
dashscope.MessageStore.AddForUser(Observation)
|
||
|
||
completion := dashscope.Completion{
|
||
Model: config.LLMInfo.Model,
|
||
Messages: dashscope.MessageStore,
|
||
}
|
||
|
||
headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.LLMInfo.APIKey}}
|
||
completionSerialized, _ := json.Marshal(completion)
|
||
err := config.LLMClient.Post(
|
||
config.LLMInfo.Path,
|
||
headers,
|
||
completionSerialized,
|
||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||
//得到gpt的返回结果
|
||
var responseCompletion dashscope.CompletionResponse
|
||
_ = json.Unmarshal(responseBody, &responseCompletion)
|
||
log.Infof("[toolsCall] content: ", responseCompletion.Choices[0].Message.Content)
|
||
|
||
if responseCompletion.Choices[0].Message.Content != "" {
|
||
retType := toolsCall(config, responseCompletion.Choices[0].Message.Content, rawResponse, log)
|
||
if retType == types.ActionContinue {
|
||
//得到了Final Answer
|
||
var assistantMessage Message
|
||
assistantMessage.Role = "assistant"
|
||
startIndex := strings.Index(responseCompletion.Choices[0].Message.Content, "Final Answer:")
|
||
if startIndex != -1 {
|
||
startIndex += len("Final Answer:") // 移动到"Final Answer:"之后的位置
|
||
extractedText := responseCompletion.Choices[0].Message.Content[startIndex:]
|
||
assistantMessage.Content = extractedText
|
||
}
|
||
//assistantMessage.Content = responseCompletion.Choices[0].Message.Content
|
||
rawResponse.Choices[0].Message = assistantMessage
|
||
|
||
newbody, err := json.Marshal(rawResponse)
|
||
if err != nil {
|
||
proxywasm.ResumeHttpResponse()
|
||
return
|
||
} else {
|
||
log.Infof("[onHttpResponseBody] newResponseBody: ", string(newbody))
|
||
proxywasm.ReplaceHttpResponseBody(newbody)
|
||
|
||
log.Debug("[onHttpResponseBody] response替换成功")
|
||
proxywasm.ResumeHttpResponse()
|
||
}
|
||
}
|
||
} else {
|
||
proxywasm.ResumeHttpRequest()
|
||
}
|
||
}, 50000)
|
||
if err != nil {
|
||
log.Debugf("[onHttpRequestBody] completion err: %s", err.Error())
|
||
proxywasm.ResumeHttpRequest()
|
||
}
|
||
}, 50000)
|
||
if err != nil {
|
||
log.Debugf("tool calls error: %s\n", err.Error())
|
||
proxywasm.ResumeHttpRequest()
|
||
}
|
||
} else {
|
||
return types.ActionContinue
|
||
}
|
||
|
||
}
|
||
return types.ActionPause
|
||
}
|
||
|
||
// 从response接收到firstreq的大模型返回
|
||
func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, body []byte, log wrapper.Log) types.Action {
|
||
log.Debugf("onHttpResponseBody start")
|
||
defer log.Debugf("onHttpResponseBody end")
|
||
|
||
//初始化接收gpt返回内容的结构体
|
||
var rawResponse Response
|
||
err := json.Unmarshal(body, &rawResponse)
|
||
if err != nil {
|
||
log.Debugf("[onHttpResponseBody] body to json err: %s", err.Error())
|
||
return types.ActionContinue
|
||
}
|
||
log.Infof("first content: %s\n", rawResponse.Choices[0].Message.Content)
|
||
//如果gpt返回的内容不是空的
|
||
if rawResponse.Choices[0].Message.Content != "" {
|
||
//进入agent的循环思考,工具调用的过程中
|
||
return toolsCall(config, rawResponse.Choices[0].Message.Content, rawResponse, log)
|
||
} else {
|
||
return types.ActionContinue
|
||
}
|
||
}
|