mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 12:47:28 +08:00
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
// 用于统计函数的递归调用次数
|
||||
const ToolCallsCount = "ToolCallsCount"
|
||||
const StreamContextKey = "Stream"
|
||||
|
||||
// react的正则规则
|
||||
const ActionPattern = `Action:\s*(.*?)[.\n]`
|
||||
@@ -53,7 +54,7 @@ func onHttpRequestHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrap
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func firstReq(config PluginConfig, prompt string, rawRequest Request, log wrapper.Log) types.Action {
|
||||
func firstReq(ctx wrapper.HttpContext, config PluginConfig, prompt string, rawRequest Request, log wrapper.Log) types.Action {
|
||||
log.Debugf("[onHttpRequestBody] firstreq:%s", prompt)
|
||||
|
||||
var userMessage Message
|
||||
@@ -62,13 +63,17 @@ func firstReq(config PluginConfig, prompt string, rawRequest Request, log wrappe
|
||||
|
||||
newMessages := []Message{userMessage}
|
||||
rawRequest.Messages = newMessages
|
||||
if rawRequest.Stream {
|
||||
ctx.SetContext(StreamContextKey, struct{}{})
|
||||
rawRequest.Stream = false
|
||||
}
|
||||
|
||||
//replace old message and resume request qwen
|
||||
newbody, err := json.Marshal(rawRequest)
|
||||
if err != nil {
|
||||
return types.ActionContinue
|
||||
} else {
|
||||
log.Debugf("[onHttpRequestBody] newRequestBody: ", string(newbody))
|
||||
log.Debugf("[onHttpRequestBody] newRequestBody: %s", string(newbody))
|
||||
err := proxywasm.ReplaceHttpRequestBody(newbody)
|
||||
if err != nil {
|
||||
log.Debug("替换失败")
|
||||
@@ -87,18 +92,26 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte
|
||||
var rawRequest Request
|
||||
err := json.Unmarshal(body, &rawRequest)
|
||||
if err != nil {
|
||||
log.Debugf("[onHttpRequestBody] body json umarshal err: ", err.Error())
|
||||
log.Debugf("[onHttpRequestBody] body json umarshal err: %s", err.Error())
|
||||
return types.ActionContinue
|
||||
}
|
||||
log.Debugf("onHttpRequestBody rawRequest: %v", rawRequest)
|
||||
|
||||
//获取用户query
|
||||
var query string
|
||||
var history string
|
||||
messageLength := len(rawRequest.Messages)
|
||||
log.Debugf("[onHttpRequestBody] messageLength: %s\n", messageLength)
|
||||
log.Debugf("[onHttpRequestBody] messageLength: %s", messageLength)
|
||||
if messageLength > 0 {
|
||||
query = rawRequest.Messages[messageLength-1].Content
|
||||
log.Debugf("[onHttpRequestBody] query: %s\n", query)
|
||||
log.Debugf("[onHttpRequestBody] query: %s", query)
|
||||
if messageLength >= 3 {
|
||||
for i := 0; i < messageLength-1; i += 2 {
|
||||
history += "human: " + rawRequest.Messages[i].Content + "\nAI: " + rawRequest.Messages[i+1].Content
|
||||
}
|
||||
} else {
|
||||
history = ""
|
||||
}
|
||||
} else {
|
||||
return types.ActionContinue
|
||||
}
|
||||
@@ -111,8 +124,8 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte
|
||||
//拼装agent prompt模板
|
||||
tool_desc := make([]string, 0)
|
||||
tool_names := make([]string, 0)
|
||||
for _, apiParam := range config.APIParam {
|
||||
for _, tool_param := range apiParam.Tool_Param {
|
||||
for _, apisParam := range config.APIsParam {
|
||||
for _, tool_param := range apisParam.ToolsParam {
|
||||
tool_desc = append(tool_desc, fmt.Sprintf(prompttpl.TOOL_DESC, tool_param.ToolName, tool_param.Description, tool_param.Description, tool_param.Description, tool_param.Parameter), "\n")
|
||||
tool_names = append(tool_names, tool_param.ToolName)
|
||||
}
|
||||
@@ -122,26 +135,22 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte
|
||||
if config.PromptTemplate.Language == "CH" {
|
||||
prompt = fmt.Sprintf(prompttpl.CH_Template,
|
||||
tool_desc,
|
||||
tool_names,
|
||||
config.PromptTemplate.CHTemplate.Question,
|
||||
config.PromptTemplate.CHTemplate.Thought1,
|
||||
tool_names,
|
||||
config.PromptTemplate.CHTemplate.ActionInput,
|
||||
config.PromptTemplate.CHTemplate.Observation,
|
||||
config.PromptTemplate.CHTemplate.Thought2,
|
||||
config.PromptTemplate.CHTemplate.FinalAnswer,
|
||||
config.PromptTemplate.CHTemplate.Begin,
|
||||
history,
|
||||
query)
|
||||
} else {
|
||||
prompt = fmt.Sprintf(prompttpl.EN_Template,
|
||||
tool_desc,
|
||||
tool_names,
|
||||
config.PromptTemplate.ENTemplate.Question,
|
||||
config.PromptTemplate.ENTemplate.Thought1,
|
||||
tool_names,
|
||||
config.PromptTemplate.ENTemplate.ActionInput,
|
||||
config.PromptTemplate.ENTemplate.Observation,
|
||||
config.PromptTemplate.ENTemplate.Thought2,
|
||||
config.PromptTemplate.ENTemplate.FinalAnswer,
|
||||
config.PromptTemplate.ENTemplate.Begin,
|
||||
history,
|
||||
query)
|
||||
}
|
||||
|
||||
@@ -154,7 +163,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte
|
||||
dashscope.MessageStore.AddForUser(prompt)
|
||||
|
||||
//开始第一次请求
|
||||
ret := firstReq(config, prompt, rawRequest, log)
|
||||
ret := firstReq(ctx, config, prompt, rawRequest, log)
|
||||
|
||||
return ret
|
||||
}
|
||||
@@ -168,7 +177,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wra
|
||||
|
||||
func toolsCallResult(ctx wrapper.HttpContext, config PluginConfig, content string, rawResponse Response, log wrapper.Log, statusCode int, responseBody []byte) {
|
||||
if statusCode != http.StatusOK {
|
||||
log.Debugf("statusCode: %d\n", statusCode)
|
||||
log.Debugf("statusCode: %d", statusCode)
|
||||
}
|
||||
log.Info("========函数返回结果========")
|
||||
log.Infof(string(responseBody))
|
||||
@@ -193,30 +202,36 @@ func toolsCallResult(ctx wrapper.HttpContext, config PluginConfig, content strin
|
||||
//得到gpt的返回结果
|
||||
var responseCompletion dashscope.CompletionResponse
|
||||
_ = json.Unmarshal(responseBody, &responseCompletion)
|
||||
log.Infof("[toolsCall] content: %s\n", responseCompletion.Choices[0].Message.Content)
|
||||
log.Infof("[toolsCall] content: %s", responseCompletion.Choices[0].Message.Content)
|
||||
|
||||
if responseCompletion.Choices[0].Message.Content != "" {
|
||||
retType := toolsCall(ctx, config, responseCompletion.Choices[0].Message.Content, rawResponse, log)
|
||||
retType, actionInput := toolsCall(ctx, config, responseCompletion.Choices[0].Message.Content, rawResponse, log)
|
||||
if retType == types.ActionContinue {
|
||||
//得到了Final Answer
|
||||
var assistantMessage Message
|
||||
assistantMessage.Role = "assistant"
|
||||
startIndex := strings.Index(responseCompletion.Choices[0].Message.Content, "Final Answer:")
|
||||
if startIndex != -1 {
|
||||
startIndex += len("Final Answer:") // 移动到"Final Answer:"之后的位置
|
||||
extractedText := responseCompletion.Choices[0].Message.Content[startIndex:]
|
||||
assistantMessage.Content = extractedText
|
||||
}
|
||||
if ctx.GetContext(StreamContextKey) == nil {
|
||||
assistantMessage.Role = "assistant"
|
||||
assistantMessage.Content = actionInput
|
||||
rawResponse.Choices[0].Message = assistantMessage
|
||||
newbody, err := json.Marshal(rawResponse)
|
||||
if err != nil {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
return
|
||||
} else {
|
||||
proxywasm.ReplaceHttpResponseBody(newbody)
|
||||
|
||||
rawResponse.Choices[0].Message = assistantMessage
|
||||
|
||||
newbody, err := json.Marshal(rawResponse)
|
||||
if err != nil {
|
||||
proxywasm.ResumeHttpResponse()
|
||||
return
|
||||
log.Debug("[onHttpResponseBody] response替换成功")
|
||||
proxywasm.ResumeHttpResponse()
|
||||
}
|
||||
} else {
|
||||
log.Infof("[onHttpResponseBody] newResponseBody: ", string(newbody))
|
||||
proxywasm.ReplaceHttpResponseBody(newbody)
|
||||
headers := [][2]string{{"content-type", "text/event-stream; charset=utf-8"}}
|
||||
proxywasm.ReplaceHttpResponseHeaders(headers)
|
||||
// Remove quotes from actionInput
|
||||
actionInput = strings.Trim(actionInput, "\"")
|
||||
returnStreamResponseTemplate := `data:{"id":"%s","choices":[{"index":0,"delta":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"%s","object":"chat.completion","usage":{"prompt_tokens":%d,"completion_tokens":%d,"total_tokens":%d}}` + "\n\ndata:[DONE]\n\n"
|
||||
newbody := fmt.Sprintf(returnStreamResponseTemplate, rawResponse.ID, actionInput, rawResponse.Model, rawResponse.Usage.PromptTokens, rawResponse.Usage.CompletionTokens, rawResponse.Usage.TotalTokens)
|
||||
log.Infof("[onHttpResponseBody] newResponseBody: ", newbody)
|
||||
proxywasm.ReplaceHttpResponseBody([]byte(newbody))
|
||||
|
||||
log.Debug("[onHttpResponseBody] response替换成功")
|
||||
proxywasm.ResumeHttpResponse()
|
||||
@@ -232,121 +247,156 @@ func toolsCallResult(ctx wrapper.HttpContext, config PluginConfig, content strin
|
||||
}
|
||||
}
|
||||
|
||||
func toolsCall(ctx wrapper.HttpContext, config PluginConfig, content string, rawResponse Response, log wrapper.Log) types.Action {
|
||||
func outputParser(response string, log wrapper.Log) (string, string) {
|
||||
log.Debugf("Raw response:%s", response)
|
||||
|
||||
start := strings.Index(response, "```")
|
||||
end := strings.LastIndex(response, "```")
|
||||
|
||||
var jsonStr string
|
||||
if start != -1 && end != -1 {
|
||||
jsonStr = strings.TrimSpace(response[start+3 : end])
|
||||
} else {
|
||||
jsonStr = response
|
||||
}
|
||||
|
||||
log.Debugf("Extracted JSON string:%s", jsonStr)
|
||||
|
||||
var action map[string]interface{}
|
||||
err := json.Unmarshal([]byte(jsonStr), &action)
|
||||
if err == nil {
|
||||
var actionName, actionInput string
|
||||
for key, value := range action {
|
||||
if strings.Contains(strings.ToLower(key), "input") {
|
||||
actionInput = fmt.Sprintf("%v", value)
|
||||
} else {
|
||||
actionName = fmt.Sprintf("%v", value)
|
||||
}
|
||||
}
|
||||
if actionName != "" && actionInput != "" {
|
||||
return actionName, actionInput
|
||||
}
|
||||
}
|
||||
log.Debugf("json parse err: %s", err.Error())
|
||||
// Fallback to regex parsing if JSON unmarshaling fails
|
||||
pattern := `\{\s*"action":\s*"([^"]+)",\s*"action_input":\s*((?:\{[^}]+\})|"[^"]+")\s*\}`
|
||||
re := regexp.MustCompile(pattern)
|
||||
match := re.FindStringSubmatch(jsonStr)
|
||||
|
||||
if len(match) == 3 {
|
||||
action := match[1]
|
||||
actionInput := match[2]
|
||||
log.Debugf("Parsed action:%s, action_input:%s", action, actionInput)
|
||||
return action, actionInput
|
||||
}
|
||||
|
||||
log.Debug("No valid action and action_input found in the response")
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func toolsCall(ctx wrapper.HttpContext, config PluginConfig, content string, rawResponse Response, log wrapper.Log) (types.Action, string) {
|
||||
dashscope.MessageStore.AddForAssistant(content)
|
||||
|
||||
action, actionInput := outputParser(content, log)
|
||||
|
||||
//得到最终答案
|
||||
regexPattern := regexp.MustCompile(FinalAnswerPattern)
|
||||
finalAnswer := regexPattern.FindStringSubmatch(content)
|
||||
if len(finalAnswer) > 1 {
|
||||
return types.ActionContinue
|
||||
if action == "Final Answer" {
|
||||
return types.ActionContinue, actionInput
|
||||
}
|
||||
count := ctx.GetContext(ToolCallsCount).(int)
|
||||
count++
|
||||
log.Debugf("toolCallsCount:%d, config.LLMInfo.MaxIterations=%d\n", count, config.LLMInfo.MaxIterations)
|
||||
log.Debugf("toolCallsCount:%d, config.LLMInfo.MaxIterations=%d", count, config.LLMInfo.MaxIterations)
|
||||
//函数递归调用次数,达到了预设的循环次数,强制结束
|
||||
if int64(count) > config.LLMInfo.MaxIterations {
|
||||
ctx.SetContext(ToolCallsCount, 0)
|
||||
return types.ActionContinue
|
||||
return types.ActionContinue, ""
|
||||
} else {
|
||||
ctx.SetContext(ToolCallsCount, count)
|
||||
}
|
||||
|
||||
//没得到最终答案
|
||||
regexAction := regexp.MustCompile(ActionPattern)
|
||||
regexActionInput := regexp.MustCompile(ActionInputPattern)
|
||||
|
||||
action := regexAction.FindStringSubmatch(content)
|
||||
actionInput := regexActionInput.FindStringSubmatch(content)
|
||||
var url string
|
||||
var headers [][2]string
|
||||
var apiClient wrapper.HttpClient
|
||||
var method string
|
||||
var reqBody []byte
|
||||
var key string
|
||||
var maxExecutionTime int64
|
||||
|
||||
if len(action) > 1 && len(actionInput) > 1 {
|
||||
var url string
|
||||
var headers [][2]string
|
||||
var apiClient wrapper.HttpClient
|
||||
var method string
|
||||
var reqBody []byte
|
||||
var key string
|
||||
for i, apisParam := range config.APIsParam {
|
||||
maxExecutionTime = apisParam.MaxExecutionTime
|
||||
for _, tools_param := range apisParam.ToolsParam {
|
||||
if action == tools_param.ToolName {
|
||||
log.Infof("calls %s", tools_param.ToolName)
|
||||
log.Infof("actionInput: %s", actionInput)
|
||||
|
||||
for i, apiParam := range config.APIParam {
|
||||
for _, tool_param := range apiParam.Tool_Param {
|
||||
if action[1] == tool_param.ToolName {
|
||||
log.Infof("calls %s\n", tool_param.ToolName)
|
||||
log.Infof("actionInput[1]: %s", actionInput[1])
|
||||
|
||||
//将大模型需要的参数反序列化
|
||||
var data map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(actionInput[1]), &data); err != nil {
|
||||
log.Debugf("Error: %s\n", err.Error())
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
method = tool_param.Method
|
||||
|
||||
//key or header组装
|
||||
if apiParam.APIKey.Name != "" {
|
||||
if apiParam.APIKey.In == "query" { //query类型的key要放到url中
|
||||
headers = nil
|
||||
key = "?" + apiParam.APIKey.Name + "=" + apiParam.APIKey.Value
|
||||
} else if apiParam.APIKey.In == "header" { //header类型的key放在header中
|
||||
headers = [][2]string{{"Content-Type", "application/json"}, {"Authorization", apiParam.APIKey.Name + " " + apiParam.APIKey.Value}}
|
||||
}
|
||||
}
|
||||
|
||||
if method == "GET" {
|
||||
//query组装
|
||||
var args string
|
||||
for i, param := range tool_param.ParamName { //从参数列表中取出参数
|
||||
if i == 0 && apiParam.APIKey.In != "query" {
|
||||
args = "?" + param + "=%s"
|
||||
args = fmt.Sprintf(args, data[param])
|
||||
} else {
|
||||
args = args + "&" + param + "=%s"
|
||||
args = fmt.Sprintf(args, data[param])
|
||||
}
|
||||
}
|
||||
|
||||
//url组装
|
||||
url = apiParam.URL + tool_param.Path + key + args
|
||||
} else if method == "POST" {
|
||||
reqBody = nil
|
||||
//json参数组装
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
log.Debugf("Error: %s\n", err.Error())
|
||||
return types.ActionContinue
|
||||
}
|
||||
reqBody = jsonData
|
||||
|
||||
//url组装
|
||||
url = apiParam.URL + tool_param.Path + key
|
||||
}
|
||||
|
||||
log.Infof("url: %s\n", url)
|
||||
|
||||
apiClient = config.APIClient[i]
|
||||
break
|
||||
//将大模型需要的参数反序列化
|
||||
var data map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(actionInput), &data); err != nil {
|
||||
log.Debugf("Error: %s", err.Error())
|
||||
return types.ActionContinue, ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if apiClient != nil {
|
||||
err := apiClient.Call(
|
||||
method,
|
||||
url,
|
||||
headers,
|
||||
reqBody,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
toolsCallResult(ctx, config, content, rawResponse, log, statusCode, responseBody)
|
||||
}, 50000)
|
||||
if err != nil {
|
||||
log.Debugf("tool calls error: %s\n", err.Error())
|
||||
proxywasm.ResumeHttpRequest()
|
||||
method = tools_param.Method
|
||||
|
||||
// 组装 headers 和 key
|
||||
headers = [][2]string{{"Content-Type", "application/json"}}
|
||||
if apisParam.APIKey.Name != "" {
|
||||
if apisParam.APIKey.In == "query" {
|
||||
key = "?" + apisParam.APIKey.Name + "=" + apisParam.APIKey.Value
|
||||
} else if apisParam.APIKey.In == "header" {
|
||||
headers = append(headers, [2]string{"Authorization", apisParam.APIKey.Name + " " + apisParam.APIKey.Value})
|
||||
}
|
||||
}
|
||||
|
||||
// 组装 URL 和请求体
|
||||
url = apisParam.URL + tools_param.Path + key
|
||||
if method == "GET" {
|
||||
queryParams := make([]string, 0, len(tools_param.ParamName))
|
||||
for _, param := range tools_param.ParamName {
|
||||
if value, ok := data[param]; ok {
|
||||
queryParams = append(queryParams, fmt.Sprintf("%s=%v", param, value))
|
||||
}
|
||||
}
|
||||
if len(queryParams) > 0 {
|
||||
url += "&" + strings.Join(queryParams, "&")
|
||||
}
|
||||
} else if method == "POST" {
|
||||
var err error
|
||||
reqBody, err = json.Marshal(data)
|
||||
if err != nil {
|
||||
log.Debugf("Error marshaling JSON: %s", err.Error())
|
||||
return types.ActionContinue, ""
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("url: %s", url)
|
||||
|
||||
apiClient = config.APIClient[i]
|
||||
break
|
||||
}
|
||||
} else {
|
||||
return types.ActionContinue
|
||||
}
|
||||
}
|
||||
return types.ActionPause
|
||||
|
||||
if apiClient != nil {
|
||||
err := apiClient.Call(
|
||||
method,
|
||||
url,
|
||||
headers,
|
||||
reqBody,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
toolsCallResult(ctx, config, content, rawResponse, log, statusCode, responseBody)
|
||||
}, uint32(maxExecutionTime))
|
||||
if err != nil {
|
||||
log.Debugf("tool calls error: %s", err.Error())
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
} else {
|
||||
return types.ActionContinue, ""
|
||||
}
|
||||
|
||||
return types.ActionPause, ""
|
||||
}
|
||||
|
||||
// 从response接收到firstreq的大模型返回
|
||||
@@ -361,11 +411,12 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, body []byt
|
||||
log.Debugf("[onHttpResponseBody] body to json err: %s", err.Error())
|
||||
return types.ActionContinue
|
||||
}
|
||||
log.Infof("first content: %s\n", rawResponse.Choices[0].Message.Content)
|
||||
log.Infof("first content: %s", rawResponse.Choices[0].Message.Content)
|
||||
//如果gpt返回的内容不是空的
|
||||
if rawResponse.Choices[0].Message.Content != "" {
|
||||
//进入agent的循环思考,工具调用的过程中
|
||||
return toolsCall(ctx, config, rawResponse.Choices[0].Message.Content, rawResponse, log)
|
||||
retType, _ := toolsCall(ctx, config, rawResponse.Choices[0].Message.Content, rawResponse, log)
|
||||
return retType
|
||||
} else {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user