update: Add support for post tools, add round limits, per-round token… (#1230)

Co-authored-by: Kent Dong <ch3cho@qq.com>
This commit is contained in:
xingyunyang01
2024-08-22 16:33:42 +08:00
committed by GitHub
parent 29fcd330d5
commit 895f17f8d8
4 changed files with 345 additions and 137 deletions

View File

@@ -15,6 +15,14 @@ import (
"github.com/tidwall/gjson"
)
// 用于统计函数的递归调用次数
const ToolCallsCount = "ToolCallsCount"
// react的正则规则
const ActionPattern = `Action:\s*(.*?)[.\n]`
const ActionInputPattern = `Action Input:\s*(.*)`
const FinalAnswerPattern = `Final Answer:(.*)`
func main() {
wrapper.SetCtx(
"ai-agent",
@@ -103,9 +111,9 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte
//拼装agent prompt模板
tool_desc := make([]string, 0)
tool_names := make([]string, 0)
for _, api_param := range config.API_Param {
for _, tool_param := range api_param.Tool_Param {
tool_desc = append(tool_desc, fmt.Sprintf(prompttpl.TOOL_DESC, tool_param.ToolName, tool_param.Desciption, tool_param.Desciption, tool_param.Desciption, tool_param.Parameter), "\n")
for _, apiParam := range config.APIParam {
for _, tool_param := range apiParam.Tool_Param {
tool_desc = append(tool_desc, fmt.Sprintf(prompttpl.TOOL_DESC, tool_param.ToolName, tool_param.Description, tool_param.Description, tool_param.Description, tool_param.Parameter), "\n")
tool_names = append(tool_names, tool_param.ToolName)
}
}
@@ -119,6 +127,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte
tool_names,
config.PromptTemplate.CHTemplate.ActionInput,
config.PromptTemplate.CHTemplate.Observation,
config.PromptTemplate.CHTemplate.Thought2,
config.PromptTemplate.CHTemplate.FinalAnswer,
config.PromptTemplate.CHTemplate.Begin,
query)
@@ -130,11 +139,17 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte
tool_names,
config.PromptTemplate.ENTemplate.ActionInput,
config.PromptTemplate.ENTemplate.Observation,
config.PromptTemplate.ENTemplate.Thought2,
config.PromptTemplate.ENTemplate.FinalAnswer,
config.PromptTemplate.ENTemplate.Begin,
query)
}
ctx.SetContext(ToolCallsCount, 0)
//清理历史对话记录
dashscope.MessageStore.Clear()
//将请求加入到历史对话存储器中
dashscope.MessageStore.AddForUser(prompt)
@@ -145,22 +160,101 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte
}
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
log.Debug("onHttpResponseHeaders start")
defer log.Debug("onHttpResponseHeaders end")
return types.ActionContinue
}
func toolsCall(config PluginConfig, content string, rawResponse Response, log wrapper.Log) types.Action {
func toolsCallResult(ctx wrapper.HttpContext, config PluginConfig, content string, rawResponse Response, log wrapper.Log, statusCode int, responseBody []byte) {
if statusCode != http.StatusOK {
log.Debugf("statusCode: %d\n", statusCode)
}
log.Info("========函数返回结果========")
log.Infof(string(responseBody))
observation := "Observation: " + string(responseBody)
dashscope.MessageStore.AddForUser(observation)
completion := dashscope.Completion{
Model: config.LLMInfo.Model,
Messages: dashscope.MessageStore,
MaxTokens: config.LLMInfo.MaxTokens,
}
headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.LLMInfo.APIKey}}
completionSerialized, _ := json.Marshal(completion)
err := config.LLMClient.Post(
config.LLMInfo.Path,
headers,
completionSerialized,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
//得到gpt的返回结果
var responseCompletion dashscope.CompletionResponse
_ = json.Unmarshal(responseBody, &responseCompletion)
log.Infof("[toolsCall] content: %s\n", responseCompletion.Choices[0].Message.Content)
if responseCompletion.Choices[0].Message.Content != "" {
retType := toolsCall(ctx, config, responseCompletion.Choices[0].Message.Content, rawResponse, log)
if retType == types.ActionContinue {
//得到了Final Answer
var assistantMessage Message
assistantMessage.Role = "assistant"
startIndex := strings.Index(responseCompletion.Choices[0].Message.Content, "Final Answer:")
if startIndex != -1 {
startIndex += len("Final Answer:") // 移动到"Final Answer:"之后的位置
extractedText := responseCompletion.Choices[0].Message.Content[startIndex:]
assistantMessage.Content = extractedText
}
rawResponse.Choices[0].Message = assistantMessage
newbody, err := json.Marshal(rawResponse)
if err != nil {
proxywasm.ResumeHttpResponse()
return
} else {
log.Infof("[onHttpResponseBody] newResponseBody: ", string(newbody))
proxywasm.ReplaceHttpResponseBody(newbody)
log.Debug("[onHttpResponseBody] response替换成功")
proxywasm.ResumeHttpResponse()
}
}
} else {
proxywasm.ResumeHttpRequest()
}
}, uint32(config.LLMInfo.MaxExecutionTime))
if err != nil {
log.Debugf("[onHttpRequestBody] completion err: %s", err.Error())
proxywasm.ResumeHttpRequest()
}
}
func toolsCall(ctx wrapper.HttpContext, config PluginConfig, content string, rawResponse Response, log wrapper.Log) types.Action {
dashscope.MessageStore.AddForAssistant(content)
//得到最终答案
regexPattern := regexp.MustCompile(`Final Answer:(.*)`)
regexPattern := regexp.MustCompile(FinalAnswerPattern)
finalAnswer := regexPattern.FindStringSubmatch(content)
if len(finalAnswer) > 1 {
return types.ActionContinue
}
count := ctx.GetContext(ToolCallsCount).(int)
count++
log.Debugf("toolCallsCount:%d, config.LLMInfo.MaxIterations=%d\n", count, config.LLMInfo.MaxIterations)
//函数递归调用次数,达到了预设的循环次数,强制结束
if int64(count) > config.LLMInfo.MaxIterations {
ctx.SetContext(ToolCallsCount, 0)
return types.ActionContinue
} else {
ctx.SetContext(ToolCallsCount, count)
}
//没得到最终答案
regexAction := regexp.MustCompile(`Action:\s*(.*?)[.\n]`)
regexActionInput := regexp.MustCompile(`Action Input:\s*(.*)`)
regexAction := regexp.MustCompile(ActionPattern)
regexActionInput := regexp.MustCompile(ActionInputPattern)
action := regexAction.FindStringSubmatch(content)
actionInput := regexActionInput.FindStringSubmatch(content)
@@ -170,9 +264,11 @@ func toolsCall(config PluginConfig, content string, rawResponse Response, log wr
var headers [][2]string
var apiClient wrapper.HttpClient
var method string
var reqBody []byte
var key string
for i, api_param := range config.API_Param {
for _, tool_param := range api_param.Tool_Param {
for i, apiParam := range config.APIParam {
for _, tool_param := range apiParam.Tool_Param {
if action[1] == tool_param.ToolName {
log.Infof("calls %s\n", tool_param.ToolName)
log.Infof("actionInput[1]: %s", actionInput[1])
@@ -184,107 +280,63 @@ func toolsCall(config PluginConfig, content string, rawResponse Response, log wr
return types.ActionContinue
}
var args string
for i, param := range tool_param.ParamName { //从参数列表中取出参数
if i == 0 {
args = "?" + param + "=%s"
args = fmt.Sprintf(args, data[param])
} else {
args = args + "&" + param + "=%s"
args = fmt.Sprintf(args, data[param])
method = tool_param.Method
//key or header组装
if apiParam.APIKey.Name != "" {
if apiParam.APIKey.In == "query" { //query类型的key要放到url中
headers = nil
key = "?" + apiParam.APIKey.Name + "=" + apiParam.APIKey.Value
} else if apiParam.APIKey.In == "header" { //header类型的key放在header中
headers = [][2]string{{"Content-Type", "application/json"}, {"Authorization", apiParam.APIKey.Name + " " + apiParam.APIKey.Value}}
}
}
url = api_param.URL + tool_param.Path + args
if api_param.APIKey.Name != "" {
if api_param.APIKey.In == "query" {
headers = nil
key := "&" + api_param.APIKey.Name + "=" + api_param.APIKey.Value
url += key
} else if api_param.APIKey.In == "header" {
headers = [][2]string{{"Content-Type", "application/json"}, {"Authorization", api_param.APIKey.Name + " " + api_param.APIKey.Value}}
if method == "GET" {
//query组装
var args string
for i, param := range tool_param.ParamName { //从参数列表中取出参数
if i == 0 && apiParam.APIKey.In != "query" {
args = "?" + param + "=%s"
args = fmt.Sprintf(args, data[param])
} else {
args = args + "&" + param + "=%s"
args = fmt.Sprintf(args, data[param])
}
}
//url组装
url = apiParam.URL + tool_param.Path + key + args
} else if method == "POST" {
reqBody = nil
//json参数组装
jsonData, err := json.Marshal(data)
if err != nil {
log.Debugf("Error: %s\n", err.Error())
return types.ActionContinue
}
reqBody = jsonData
//url组装
url = apiParam.URL + tool_param.Path + key
}
log.Infof("url: %s\n", url)
method = tool_param.Method
apiClient = config.APIClient[i]
break
}
}
}
if method == "get" {
//调用工具
err := apiClient.Get(
if apiClient != nil {
err := apiClient.Call(
method,
url,
headers,
reqBody,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
if statusCode != http.StatusOK {
log.Debugf("statusCode: %d\n", statusCode)
}
log.Info("========函数返回结果========")
log.Infof(string(responseBody))
Observation := "Observation: " + string(responseBody)
dashscope.MessageStore.AddForUser(Observation)
completion := dashscope.Completion{
Model: config.LLMInfo.Model,
Messages: dashscope.MessageStore,
}
headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.LLMInfo.APIKey}}
completionSerialized, _ := json.Marshal(completion)
err := config.LLMClient.Post(
config.LLMInfo.Path,
headers,
completionSerialized,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
//得到gpt的返回结果
var responseCompletion dashscope.CompletionResponse
_ = json.Unmarshal(responseBody, &responseCompletion)
log.Infof("[toolsCall] content: ", responseCompletion.Choices[0].Message.Content)
if responseCompletion.Choices[0].Message.Content != "" {
retType := toolsCall(config, responseCompletion.Choices[0].Message.Content, rawResponse, log)
if retType == types.ActionContinue {
//得到了Final Answer
var assistantMessage Message
assistantMessage.Role = "assistant"
startIndex := strings.Index(responseCompletion.Choices[0].Message.Content, "Final Answer:")
if startIndex != -1 {
startIndex += len("Final Answer:") // 移动到"Final Answer:"之后的位置
extractedText := responseCompletion.Choices[0].Message.Content[startIndex:]
assistantMessage.Content = extractedText
}
//assistantMessage.Content = responseCompletion.Choices[0].Message.Content
rawResponse.Choices[0].Message = assistantMessage
newbody, err := json.Marshal(rawResponse)
if err != nil {
proxywasm.ResumeHttpResponse()
return
} else {
log.Infof("[onHttpResponseBody] newResponseBody: ", string(newbody))
proxywasm.ReplaceHttpResponseBody(newbody)
log.Debug("[onHttpResponseBody] response替换成功")
proxywasm.ResumeHttpResponse()
}
}
} else {
proxywasm.ResumeHttpRequest()
}
}, 50000)
if err != nil {
log.Debugf("[onHttpRequestBody] completion err: %s", err.Error())
proxywasm.ResumeHttpRequest()
}
toolsCallResult(ctx, config, content, rawResponse, log, statusCode, responseBody)
}, 50000)
if err != nil {
log.Debugf("tool calls error: %s\n", err.Error())
@@ -293,7 +345,6 @@ func toolsCall(config PluginConfig, content string, rawResponse Response, log wr
} else {
return types.ActionContinue
}
}
return types.ActionPause
}
@@ -314,7 +365,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, body []byt
//如果gpt返回的内容不是空的
if rawResponse.Choices[0].Message.Content != "" {
//进入agent的循环思考工具调用的过程中
return toolsCall(config, rawResponse.Choices[0].Message.Content, rawResponse, log)
return toolsCall(ctx, config, rawResponse.Choices[0].Message.Content, rawResponse, log)
} else {
return types.ActionContinue
}