From 895f17f8d84ca3f70db865badf04dbe04064beec Mon Sep 17 00:00:00 2001 From: xingyunyang01 <94745901+xingyunyang01@users.noreply.github.com> Date: Thu, 22 Aug 2024 16:33:42 +0800 Subject: [PATCH] =?UTF-8?q?update:=20Add=20support=20for=20post=20tools,?= =?UTF-8?q?=20add=20round=20limits,=20per-round=20token=E2=80=A6=20(#1230)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Kent Dong --- plugins/wasm-go/extensions/ai-agent/README.md | 121 ++++++++- plugins/wasm-go/extensions/ai-agent/config.go | 109 ++++++-- .../extensions/ai-agent/dashscope/types.go | 11 +- plugins/wasm-go/extensions/ai-agent/main.go | 241 +++++++++++------- 4 files changed, 345 insertions(+), 137 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-agent/README.md b/plugins/wasm-go/extensions/ai-agent/README.md index b1fbcac39..3419fa7ec 100644 --- a/plugins/wasm-go/extensions/ai-agent/README.md +++ b/plugins/wasm-go/extensions/ai-agent/README.md @@ -5,10 +5,10 @@ description: AI Agent插件配置参考 --- ## 功能说明 -一个可定制化的 API AI Agent,目前第一版本只支持配置 http method 类型为 GET 的 API,且只支持非流式模式。agent流程图如下: +一个可定制化的 API AI Agent,支持配置 http method 类型为 GET 与 POST 的 API,目前只支持非流式模式。 +agent流程图如下: ![ai-agent](https://github.com/user-attachments/assets/b0761a0c-1afa-496c-a98e-bb9f38b340f8) -由于 Agent 是多轮对话场景,需要维护历史对话记录,本版本目前是在内存中维护历史对话记录,因此只支持单机。后续会支持通过 redis 存储回话记录 ## 配置字段 @@ -21,14 +21,17 @@ description: AI Agent插件配置参考 `llm`的配置字段说明如下: -| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | -|-----------------|-----------|---------|--------|-----------------------------------| -| `apiKey` | string | 必填 | - | 用于在访问大模型服务时进行认证的令牌。| -| `serviceName` | string | 必填 | - | 大模型服务名 | -| `servicePort` | int | 必填 | - | 大模型服务端口 | -| `domain` | string | 必填 | - | 访问大模型服务时域名 | -| `path` | string | 必填 | - | 访问大模型服务时路径 | -| `model` | string | 必填 | - | 访问大模型服务时模型名 | +| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 | +|--------------------|-----------|---------|--------|-----------------------------------| +| `apiKey` | string | 必填 | - | 用于在访问大模型服务时进行认证的令牌。| +| `serviceName` | string | 必填 | - | 大模型服务名 | +| `servicePort` | int | 必填 | - | 大模型服务端口 | +| `domain` | string | 必填 | - | 访问大模型服务时域名 | +| `path` | string | 必填 | - | 访问大模型服务时路径 | +| `model` | string | 必填 | - | 访问大模型服务时模型名 | +| `maxIterations` | int | 必填 | 15 | 结束执行循环前的最大步数 | +| `maxExecutionTime` | int | 必填 | 50000 | 每一次请求大模型的超时时间,单位毫秒 | +| `maxTokens` | int | 必填 | 1000 | 每一次请求大模型的输出token限制 | `apis`的配置字段说明如下: @@ -86,6 +89,7 @@ llm: servicePort: 443 path: /compatible-mode/v1/chat/completions model: qwen-max-0403 + maxIterations: 2 promptTemplate: language: CH apis: @@ -196,11 +200,91 @@ apis: deprecated: false components: schemas: {} +- apiProvider: + apiKey: + in: "header" + name: "DeepL-Auth-Key" + value: "73xxxxxxxxxxxxxxx:fx" + domain: "api-free.deepl.com" + serviceName: "deepl.dns" + servicePort: 443 + api: | + openapi: 3.1.0 + info: + title: DeepL API Documentation + description: The DeepL API provides programmatic access to DeepL’s machine translation technology. + version: v1.0.0 + servers: + - url: https://api-free.deepl.com/v2 + paths: + /translate: + post: + summary: Request Translation + operationId: translateText + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - text + - target_lang + properties: + text: + description: | + Text to be translated. Only UTF-8-encoded plain text is supported. The parameter may be specified + up to 50 times in a single request. Translations are returned in the same order as they are requested. + type: array + maxItems: 50 + items: + type: string + example: Hello, World! + target_lang: + description: The language into which the text should be translated. + type: string + enum: + - BG + - CS + - DA + - DE + - EL + - EN-GB + - EN-US + - ES + - ET + - FI + - FR + - HU + - ID + - IT + - JA + - KO + - LT + - LV + - NB + - NL + - PL + - PT-BR + - PT-PT + - RO + - RU + - SK + - SL + - SV + - TR + - UK + - ZH + - ZH-HANS + example: DE + components: + schemas: {} ``` -本示例配置了两个服务,一个是高德地图,另一个是心知天气,两个服务都需要现在Higress的服务中以DNS域名的方式配置好,并确保健康。 +本示例配置了三个服务,演示了get与post两种类型的工具。其中get类型的工具包括高德地图与心知天气,post类型的工具是deepl翻译。三个服务都需要现在Higress的服务中以DNS域名的方式配置好,并确保健康。 高德地图提供了两个工具,分别是获取指定地点的坐标,以及搜索坐标附近的感兴趣的地点。文档:https://lbs.amap.com/api/webservice/guide/api-advanced/newpoisearch 心知天气提供了一个工具,用于获取指定城市的实时天气情况,支持中文,英文,日语返回,以及摄氏度和华氏度的表示。文档:https://seniverse.yuque.com/hyper_data/api_v3/nyiu3t +deepl提供了一个工具,用于翻译给定的句子,支持多语言。。文档:https://developers.deepl.com/docs/v/zh/api-reference/translate?fallback=true 以下为测试用例,为了效果的稳定性,建议保持大模型版本的稳定,本例子中使用的qwen-max-0403: @@ -249,3 +333,18 @@ curl 'http://<这里换成网关公网IP>/api/openai/v1/chat/completions' \ ```json {"id":"ebd6ea91-8e38-9e14-9a5b-90178d2edea4","choices":[{"index":0,"message":{"role":"assistant","content":" 济南市の現在の天気は雨曇りで、気温は88°Fです。この情報は2024年8月9日15時12分(東京時間)に更新されました。"},"finish_reason":"stop"}],"created":1723187991,"model":"qwen-max-0403","object":"chat.completion","usage":{"prompt_tokens":890,"completion_tokens":56,"total_tokens":946}} ``` + +**请求示例** + +```shell +curl 'http://<这里换成网关公网IP>/api/openai/v1/chat/completions' \ +-H 'Accept: application/json, text/event-stream' \ +-H 'Content-Type: application/json' \ +--data-raw '{"model":"qwen","frequency_penalty":0,"max_tokens":800,"stream":false,"messages":[{"role":"user","content":"帮我用德语翻译以下句子:九头蛇万岁!"}],"presence_penalty":0,"temperature":0,"top_p":0}' +``` + +**响应示例** + +```json +{"id":"65dcf12c-61ff-9e68-bffa-44fc9e6070d5","choices":[{"index":0,"message":{"role":"assistant","content":" “九头蛇万岁!”的德语翻译为“Hoch lebe Hydra!”。"},"finish_reason":"stop"}],"created":1724043865,"model":"qwen-max-0403","object":"chat.completion","usage":{"prompt_tokens":908,"completion_tokens":52,"total_tokens":960}} +``` \ No newline at end of file diff --git a/plugins/wasm-go/extensions/ai-agent/config.go b/plugins/wasm-go/extensions/ai-agent/config.go index 9aa47ef1f..31445c17f 100644 --- a/plugins/wasm-go/extensions/ai-agent/config.go +++ b/plugins/wasm-go/extensions/ai-agent/config.go @@ -47,16 +47,16 @@ type Response struct { // 用于存放拆解出来的工具相关信息 type Tool_Param struct { - ToolName string `yaml:"toolName"` - Path string `yaml:"path"` - Method string `yaml:"method"` - ParamName []string `yaml:"paramName"` - Parameter string `yaml:"parameter"` - Desciption string `yaml:"description"` + ToolName string `yaml:"toolName"` + Path string `yaml:"path"` + Method string `yaml:"method"` + ParamName []string `yaml:"paramName"` + Parameter string `yaml:"parameter"` + Description string `yaml:"description"` } // 用于存放拆解出来的api相关信息 -type API_Param struct { +type APIParam struct { APIKey APIKey `yaml:"apiKey"` URL string `yaml:"url"` Tool_Param []Tool_Param `yaml:"tool_Param"` @@ -72,6 +72,7 @@ type Server struct { URL string `yaml:"url"` } +// 给OpenAPI的get方法用的 type Parameter struct { Name string `yaml:"name"` In string `yaml:"in"` @@ -84,9 +85,41 @@ type Parameter struct { } `yaml:"schema"` } +type Items struct { + Type string `yaml:"type"` + Example string `yaml:"example"` +} + +type Property struct { + Description string `yaml:"description"` + Type string `yaml:"type"` + Enum []string `yaml:"enum,omitempty"` + Items *Items `yaml:"items,omitempty"` + MaxItems int `yaml:"maxItems,omitempty"` + Example string `yaml:"example,omitempty"` +} + +type Schema struct { + Type string `yaml:"type"` + Required []string `yaml:"required"` + Properties map[string]Property `yaml:"properties"` +} + +type MediaType struct { + Schema Schema `yaml:"schema"` +} + +// 给OpenAPI的post方法用的 +type RequestBody struct { + Required bool `yaml:"required"` + Content map[string]MediaType `yaml:"content"` +} + type PathItem struct { Description string `yaml:"description"` + Summary string `yaml:"summary"` OperationID string `yaml:"operationId"` + RequestBody RequestBody `yaml:"requestBody"` Parameters []Parameter `yaml:"parameters"` Deprecated bool `yaml:"deprecated"` } @@ -166,6 +199,15 @@ type LLMInfo struct { // @Title zh-CN 大模型服务的模型名称 // @Description zh-CN 大模型服务的模型名称,如"qwen-max-0403" Model string `required:"true" yaml:"model" json:"model"` + // @Title zh-CN 结束执行循环前的最大步数 + // @Description zh-CN 结束执行循环前的最大步数,比如2,设置为0,可能会无限循环,直到超时退出,默认15 + MaxIterations int64 `yaml:"maxIterations" json:"maxIterations"` + // @Title zh-CN 每一次请求大模型的超时时间 + // @Description zh-CN 每一次请求大模型的超时时间,单位毫秒,默认50000 + MaxExecutionTime int64 `yaml:"maxExecutionTime" json:"maxExecutionTime"` + // @Title zh-CN + // @Description zh-CN 每一次请求大模型的输出token限制,默认1000 + MaxTokens int64 `yaml:"maxToken" json:"maxTokens"` } type PluginConfig struct { @@ -180,7 +222,7 @@ type PluginConfig struct { // @Description zh-CN 用于存储llm使用信息 LLMInfo LLMInfo `required:"true" yaml:"llm" json:"llm"` LLMClient wrapper.HttpClient `yaml:"-" json:"-"` - API_Param []API_Param `yaml:"-" json:"-"` + APIParam []APIParam `yaml:"-" json:"-"` PromptTemplate PromptTemplate `yaml:"promptTemplate" json:"promptTemplate"` } @@ -188,7 +230,7 @@ func initResponsePromptTpl(gjson gjson.Result, c *PluginConfig) { //设置回复模板 c.ReturnResponseTemplate = gjson.Get("returnResponseTemplate").String() if c.ReturnResponseTemplate == "" { - c.ReturnResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` + c.ReturnResponseTemplate = `{"id":"error","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}` } } @@ -241,40 +283,49 @@ func initAPIs(gjson gjson.Result, c *PluginConfig) error { return errors.New("api is required") } - var apiStrcut API - err := yaml.Unmarshal([]byte(api.String()), &apiStrcut) + var apiStruct API + err := yaml.Unmarshal([]byte(api.String()), &apiStruct) if err != nil { return err } var allTool_param []Tool_Param //拆除服务下面的每个api的path - for path, pathmap := range apiStrcut.Paths { + for path, pathmap := range apiStruct.Paths { //拆解出每个api对应的参数 for method, submap := range pathmap { //把参数列表存起来 var param Tool_Param param.Path = path - param.Method = method param.ToolName = submap.OperationID - paramName := make([]string, 0) - for _, parammeter := range submap.Parameters { - paramName = append(paramName, parammeter.Name) + if method == "get" { + param.Method = "GET" + paramName := make([]string, 0) + for _, parammeter := range submap.Parameters { + paramName = append(paramName, parammeter.Name) + } + param.ParamName = paramName + out, _ := json.Marshal(submap.Parameters) + param.Parameter = string(out) + param.Description = submap.Description + } else if method == "post" { + param.Method = "POST" + schema := submap.RequestBody.Content["application/json"].Schema + param.ParamName = schema.Required + param.Description = submap.Summary + out, _ := json.Marshal(schema.Properties) + param.Parameter = string(out) } - param.ParamName = paramName - out, _ := json.Marshal(submap.Parameters) - param.Parameter = string(out) - param.Desciption = submap.Description allTool_param = append(allTool_param, param) } } - api_param := API_Param{ + apiParam := APIParam{ APIKey: APIKey{In: apiKeyIn, Name: apiKeyName.String(), Value: apiKeyValue.String()}, - URL: apiStrcut.Servers[0].URL, + URL: apiStruct.Servers[0].URL, Tool_Param: allTool_param, } - c.API_Param = append(c.API_Param, api_param) + c.APIParam = append(c.APIParam, apiParam) } return nil } @@ -352,6 +403,18 @@ func initLLMClient(gjson gjson.Result, c *PluginConfig) { c.LLMInfo.Domin = gjson.Get("llm.domain").String() c.LLMInfo.Path = gjson.Get("llm.path").String() c.LLMInfo.Model = gjson.Get("llm.model").String() + c.LLMInfo.MaxIterations = gjson.Get("llm.maxIterations").Int() + if c.LLMInfo.MaxIterations == 0 { + c.LLMInfo.MaxIterations = 15 + } + c.LLMInfo.MaxExecutionTime = gjson.Get("llm.maxExecutionTime").Int() + if c.LLMInfo.MaxExecutionTime == 0 { + c.LLMInfo.MaxExecutionTime = 50000 + } + c.LLMInfo.MaxTokens = gjson.Get("llm.maxTokens").Int() + if c.LLMInfo.MaxTokens == 0 { + c.LLMInfo.MaxTokens = 1000 + } c.LLMClient = wrapper.NewClusterClient(wrapper.FQDNCluster{ FQDN: c.LLMInfo.ServiceName, diff --git a/plugins/wasm-go/extensions/ai-agent/dashscope/types.go b/plugins/wasm-go/extensions/ai-agent/dashscope/types.go index 7aef7272e..1a0bf1188 100644 --- a/plugins/wasm-go/extensions/ai-agent/dashscope/types.go +++ b/plugins/wasm-go/extensions/ai-agent/dashscope/types.go @@ -37,8 +37,9 @@ type Usage struct { // completion type Completion struct { - Model string `json:"model"` - Messages []Message `json:"messages"` + Model string `json:"model"` + Messages []Message `json:"messages"` + MaxTokens int64 `json:"max_tokens"` } type Message struct { @@ -67,9 +68,3 @@ type CompletionUsage struct { CompletionTokens int `json:"completion_tokens"` TotalTokens int `json:"total_tokens"` } - -type Content struct { - CH_Question string `json:"ch_question"` - Core string `json:"core"` - // EN_Question string `json:"en_question"` -} diff --git a/plugins/wasm-go/extensions/ai-agent/main.go b/plugins/wasm-go/extensions/ai-agent/main.go index 5f0f0fe22..4359e63fb 100644 --- a/plugins/wasm-go/extensions/ai-agent/main.go +++ b/plugins/wasm-go/extensions/ai-agent/main.go @@ -15,6 +15,14 @@ import ( "github.com/tidwall/gjson" ) +// 用于统计函数的递归调用次数 +const ToolCallsCount = "ToolCallsCount" + +// react的正则规则 +const ActionPattern = `Action:\s*(.*?)[.\n]` +const ActionInputPattern = `Action Input:\s*(.*)` +const FinalAnswerPattern = `Final Answer:(.*)` + func main() { wrapper.SetCtx( "ai-agent", @@ -103,9 +111,9 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte //拼装agent prompt模板 tool_desc := make([]string, 0) tool_names := make([]string, 0) - for _, api_param := range config.API_Param { - for _, tool_param := range api_param.Tool_Param { - tool_desc = append(tool_desc, fmt.Sprintf(prompttpl.TOOL_DESC, tool_param.ToolName, tool_param.Desciption, tool_param.Desciption, tool_param.Desciption, tool_param.Parameter), "\n") + for _, apiParam := range config.APIParam { + for _, tool_param := range apiParam.Tool_Param { + tool_desc = append(tool_desc, fmt.Sprintf(prompttpl.TOOL_DESC, tool_param.ToolName, tool_param.Description, tool_param.Description, tool_param.Description, tool_param.Parameter), "\n") tool_names = append(tool_names, tool_param.ToolName) } } @@ -119,6 +127,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte tool_names, config.PromptTemplate.CHTemplate.ActionInput, config.PromptTemplate.CHTemplate.Observation, + config.PromptTemplate.CHTemplate.Thought2, config.PromptTemplate.CHTemplate.FinalAnswer, config.PromptTemplate.CHTemplate.Begin, query) @@ -130,11 +139,17 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte tool_names, config.PromptTemplate.ENTemplate.ActionInput, config.PromptTemplate.ENTemplate.Observation, + config.PromptTemplate.ENTemplate.Thought2, config.PromptTemplate.ENTemplate.FinalAnswer, config.PromptTemplate.ENTemplate.Begin, query) } + ctx.SetContext(ToolCallsCount, 0) + + //清理历史对话记录 + dashscope.MessageStore.Clear() + //将请求加入到历史对话存储器中 dashscope.MessageStore.AddForUser(prompt) @@ -145,22 +160,101 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte } func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action { + log.Debug("onHttpResponseHeaders start") + defer log.Debug("onHttpResponseHeaders end") + return types.ActionContinue } -func toolsCall(config PluginConfig, content string, rawResponse Response, log wrapper.Log) types.Action { +func toolsCallResult(ctx wrapper.HttpContext, config PluginConfig, content string, rawResponse Response, log wrapper.Log, statusCode int, responseBody []byte) { + if statusCode != http.StatusOK { + log.Debugf("statusCode: %d\n", statusCode) + } + log.Info("========函数返回结果========") + log.Infof(string(responseBody)) + + observation := "Observation: " + string(responseBody) + + dashscope.MessageStore.AddForUser(observation) + + completion := dashscope.Completion{ + Model: config.LLMInfo.Model, + Messages: dashscope.MessageStore, + MaxTokens: config.LLMInfo.MaxTokens, + } + + headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.LLMInfo.APIKey}} + completionSerialized, _ := json.Marshal(completion) + err := config.LLMClient.Post( + config.LLMInfo.Path, + headers, + completionSerialized, + func(statusCode int, responseHeaders http.Header, responseBody []byte) { + //得到gpt的返回结果 + var responseCompletion dashscope.CompletionResponse + _ = json.Unmarshal(responseBody, &responseCompletion) + log.Infof("[toolsCall] content: %s\n", responseCompletion.Choices[0].Message.Content) + + if responseCompletion.Choices[0].Message.Content != "" { + retType := toolsCall(ctx, config, responseCompletion.Choices[0].Message.Content, rawResponse, log) + if retType == types.ActionContinue { + //得到了Final Answer + var assistantMessage Message + assistantMessage.Role = "assistant" + startIndex := strings.Index(responseCompletion.Choices[0].Message.Content, "Final Answer:") + if startIndex != -1 { + startIndex += len("Final Answer:") // 移动到"Final Answer:"之后的位置 + extractedText := responseCompletion.Choices[0].Message.Content[startIndex:] + assistantMessage.Content = extractedText + } + + rawResponse.Choices[0].Message = assistantMessage + + newbody, err := json.Marshal(rawResponse) + if err != nil { + proxywasm.ResumeHttpResponse() + return + } else { + log.Infof("[onHttpResponseBody] newResponseBody: ", string(newbody)) + proxywasm.ReplaceHttpResponseBody(newbody) + + log.Debug("[onHttpResponseBody] response替换成功") + proxywasm.ResumeHttpResponse() + } + } + } else { + proxywasm.ResumeHttpRequest() + } + }, uint32(config.LLMInfo.MaxExecutionTime)) + if err != nil { + log.Debugf("[onHttpRequestBody] completion err: %s", err.Error()) + proxywasm.ResumeHttpRequest() + } +} + +func toolsCall(ctx wrapper.HttpContext, config PluginConfig, content string, rawResponse Response, log wrapper.Log) types.Action { dashscope.MessageStore.AddForAssistant(content) //得到最终答案 - regexPattern := regexp.MustCompile(`Final Answer:(.*)`) + regexPattern := regexp.MustCompile(FinalAnswerPattern) finalAnswer := regexPattern.FindStringSubmatch(content) if len(finalAnswer) > 1 { return types.ActionContinue } + count := ctx.GetContext(ToolCallsCount).(int) + count++ + log.Debugf("toolCallsCount:%d, config.LLMInfo.MaxIterations=%d\n", count, config.LLMInfo.MaxIterations) + //函数递归调用次数,达到了预设的循环次数,强制结束 + if int64(count) > config.LLMInfo.MaxIterations { + ctx.SetContext(ToolCallsCount, 0) + return types.ActionContinue + } else { + ctx.SetContext(ToolCallsCount, count) + } //没得到最终答案 - regexAction := regexp.MustCompile(`Action:\s*(.*?)[.\n]`) - regexActionInput := regexp.MustCompile(`Action Input:\s*(.*)`) + regexAction := regexp.MustCompile(ActionPattern) + regexActionInput := regexp.MustCompile(ActionInputPattern) action := regexAction.FindStringSubmatch(content) actionInput := regexActionInput.FindStringSubmatch(content) @@ -170,9 +264,11 @@ func toolsCall(config PluginConfig, content string, rawResponse Response, log wr var headers [][2]string var apiClient wrapper.HttpClient var method string + var reqBody []byte + var key string - for i, api_param := range config.API_Param { - for _, tool_param := range api_param.Tool_Param { + for i, apiParam := range config.APIParam { + for _, tool_param := range apiParam.Tool_Param { if action[1] == tool_param.ToolName { log.Infof("calls %s\n", tool_param.ToolName) log.Infof("actionInput[1]: %s", actionInput[1]) @@ -184,107 +280,63 @@ func toolsCall(config PluginConfig, content string, rawResponse Response, log wr return types.ActionContinue } - var args string - for i, param := range tool_param.ParamName { //从参数列表中取出参数 - if i == 0 { - args = "?" + param + "=%s" - args = fmt.Sprintf(args, data[param]) - } else { - args = args + "&" + param + "=%s" - args = fmt.Sprintf(args, data[param]) + method = tool_param.Method + + //key or header组装 + if apiParam.APIKey.Name != "" { + if apiParam.APIKey.In == "query" { //query类型的key要放到url中 + headers = nil + key = "?" + apiParam.APIKey.Name + "=" + apiParam.APIKey.Value + } else if apiParam.APIKey.In == "header" { //header类型的key放在header中 + headers = [][2]string{{"Content-Type", "application/json"}, {"Authorization", apiParam.APIKey.Name + " " + apiParam.APIKey.Value}} } } - url = api_param.URL + tool_param.Path + args - - if api_param.APIKey.Name != "" { - if api_param.APIKey.In == "query" { - headers = nil - key := "&" + api_param.APIKey.Name + "=" + api_param.APIKey.Value - url += key - } else if api_param.APIKey.In == "header" { - headers = [][2]string{{"Content-Type", "application/json"}, {"Authorization", api_param.APIKey.Name + " " + api_param.APIKey.Value}} + if method == "GET" { + //query组装 + var args string + for i, param := range tool_param.ParamName { //从参数列表中取出参数 + if i == 0 && apiParam.APIKey.In != "query" { + args = "?" + param + "=%s" + args = fmt.Sprintf(args, data[param]) + } else { + args = args + "&" + param + "=%s" + args = fmt.Sprintf(args, data[param]) + } } + + //url组装 + url = apiParam.URL + tool_param.Path + key + args + } else if method == "POST" { + reqBody = nil + //json参数组装 + jsonData, err := json.Marshal(data) + if err != nil { + log.Debugf("Error: %s\n", err.Error()) + return types.ActionContinue + } + reqBody = jsonData + + //url组装 + url = apiParam.URL + tool_param.Path + key } log.Infof("url: %s\n", url) - method = tool_param.Method - apiClient = config.APIClient[i] break } } } - if method == "get" { - //调用工具 - err := apiClient.Get( + if apiClient != nil { + err := apiClient.Call( + method, url, headers, + reqBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) { - if statusCode != http.StatusOK { - log.Debugf("statusCode: %d\n", statusCode) - } - log.Info("========函数返回结果========") - log.Infof(string(responseBody)) - - Observation := "Observation: " + string(responseBody) - - dashscope.MessageStore.AddForUser(Observation) - - completion := dashscope.Completion{ - Model: config.LLMInfo.Model, - Messages: dashscope.MessageStore, - } - - headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.LLMInfo.APIKey}} - completionSerialized, _ := json.Marshal(completion) - err := config.LLMClient.Post( - config.LLMInfo.Path, - headers, - completionSerialized, - func(statusCode int, responseHeaders http.Header, responseBody []byte) { - //得到gpt的返回结果 - var responseCompletion dashscope.CompletionResponse - _ = json.Unmarshal(responseBody, &responseCompletion) - log.Infof("[toolsCall] content: ", responseCompletion.Choices[0].Message.Content) - - if responseCompletion.Choices[0].Message.Content != "" { - retType := toolsCall(config, responseCompletion.Choices[0].Message.Content, rawResponse, log) - if retType == types.ActionContinue { - //得到了Final Answer - var assistantMessage Message - assistantMessage.Role = "assistant" - startIndex := strings.Index(responseCompletion.Choices[0].Message.Content, "Final Answer:") - if startIndex != -1 { - startIndex += len("Final Answer:") // 移动到"Final Answer:"之后的位置 - extractedText := responseCompletion.Choices[0].Message.Content[startIndex:] - assistantMessage.Content = extractedText - } - //assistantMessage.Content = responseCompletion.Choices[0].Message.Content - rawResponse.Choices[0].Message = assistantMessage - - newbody, err := json.Marshal(rawResponse) - if err != nil { - proxywasm.ResumeHttpResponse() - return - } else { - log.Infof("[onHttpResponseBody] newResponseBody: ", string(newbody)) - proxywasm.ReplaceHttpResponseBody(newbody) - - log.Debug("[onHttpResponseBody] response替换成功") - proxywasm.ResumeHttpResponse() - } - } - } else { - proxywasm.ResumeHttpRequest() - } - }, 50000) - if err != nil { - log.Debugf("[onHttpRequestBody] completion err: %s", err.Error()) - proxywasm.ResumeHttpRequest() - } + toolsCallResult(ctx, config, content, rawResponse, log, statusCode, responseBody) }, 50000) if err != nil { log.Debugf("tool calls error: %s\n", err.Error()) @@ -293,7 +345,6 @@ func toolsCall(config PluginConfig, content string, rawResponse Response, log wr } else { return types.ActionContinue } - } return types.ActionPause } @@ -314,7 +365,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, body []byt //如果gpt返回的内容不是空的 if rawResponse.Choices[0].Message.Content != "" { //进入agent的循环思考,工具调用的过程中 - return toolsCall(config, rawResponse.Choices[0].Message.Content, rawResponse, log) + return toolsCall(ctx, config, rawResponse.Choices[0].Message.Content, rawResponse, log) } else { return types.ActionContinue }