mirror of
https://github.com/alibaba/higress.git
synced 2026-06-10 05:07:30 +08:00
update: Add support for post tools, add round limits, per-round token… (#1230)
Co-authored-by: Kent Dong <ch3cho@qq.com>
This commit is contained in:
@@ -5,10 +5,10 @@ description: AI Agent插件配置参考
|
|||||||
---
|
---
|
||||||
|
|
||||||
## 功能说明
|
## 功能说明
|
||||||
一个可定制化的 API AI Agent,目前第一版本只支持配置 http method 类型为 GET 的 API,且只支持非流式模式。agent流程图如下:
|
一个可定制化的 API AI Agent,支持配置 http method 类型为 GET 与 POST 的 API,目前只支持非流式模式。
|
||||||
|
agent流程图如下:
|
||||||

|

|
||||||
|
|
||||||
由于 Agent 是多轮对话场景,需要维护历史对话记录,本版本目前是在内存中维护历史对话记录,因此只支持单机。后续会支持通过 redis 存储回话记录
|
|
||||||
|
|
||||||
## 配置字段
|
## 配置字段
|
||||||
|
|
||||||
@@ -22,13 +22,16 @@ description: AI Agent插件配置参考
|
|||||||
`llm`的配置字段说明如下:
|
`llm`的配置字段说明如下:
|
||||||
|
|
||||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||||
|-----------------|-----------|---------|--------|-----------------------------------|
|
|--------------------|-----------|---------|--------|-----------------------------------|
|
||||||
| `apiKey` | string | 必填 | - | 用于在访问大模型服务时进行认证的令牌。|
|
| `apiKey` | string | 必填 | - | 用于在访问大模型服务时进行认证的令牌。|
|
||||||
| `serviceName` | string | 必填 | - | 大模型服务名 |
|
| `serviceName` | string | 必填 | - | 大模型服务名 |
|
||||||
| `servicePort` | int | 必填 | - | 大模型服务端口 |
|
| `servicePort` | int | 必填 | - | 大模型服务端口 |
|
||||||
| `domain` | string | 必填 | - | 访问大模型服务时域名 |
|
| `domain` | string | 必填 | - | 访问大模型服务时域名 |
|
||||||
| `path` | string | 必填 | - | 访问大模型服务时路径 |
|
| `path` | string | 必填 | - | 访问大模型服务时路径 |
|
||||||
| `model` | string | 必填 | - | 访问大模型服务时模型名 |
|
| `model` | string | 必填 | - | 访问大模型服务时模型名 |
|
||||||
|
| `maxIterations` | int | 必填 | 15 | 结束执行循环前的最大步数 |
|
||||||
|
| `maxExecutionTime` | int | 必填 | 50000 | 每一次请求大模型的超时时间,单位毫秒 |
|
||||||
|
| `maxTokens` | int | 必填 | 1000 | 每一次请求大模型的输出token限制 |
|
||||||
|
|
||||||
`apis`的配置字段说明如下:
|
`apis`的配置字段说明如下:
|
||||||
|
|
||||||
@@ -86,6 +89,7 @@ llm:
|
|||||||
servicePort: 443
|
servicePort: 443
|
||||||
path: /compatible-mode/v1/chat/completions
|
path: /compatible-mode/v1/chat/completions
|
||||||
model: qwen-max-0403
|
model: qwen-max-0403
|
||||||
|
maxIterations: 2
|
||||||
promptTemplate:
|
promptTemplate:
|
||||||
language: CH
|
language: CH
|
||||||
apis:
|
apis:
|
||||||
@@ -196,11 +200,91 @@ apis:
|
|||||||
deprecated: false
|
deprecated: false
|
||||||
components:
|
components:
|
||||||
schemas: {}
|
schemas: {}
|
||||||
|
- apiProvider:
|
||||||
|
apiKey:
|
||||||
|
in: "header"
|
||||||
|
name: "DeepL-Auth-Key"
|
||||||
|
value: "73xxxxxxxxxxxxxxx:fx"
|
||||||
|
domain: "api-free.deepl.com"
|
||||||
|
serviceName: "deepl.dns"
|
||||||
|
servicePort: 443
|
||||||
|
api: |
|
||||||
|
openapi: 3.1.0
|
||||||
|
info:
|
||||||
|
title: DeepL API Documentation
|
||||||
|
description: The DeepL API provides programmatic access to DeepL’s machine translation technology.
|
||||||
|
version: v1.0.0
|
||||||
|
servers:
|
||||||
|
- url: https://api-free.deepl.com/v2
|
||||||
|
paths:
|
||||||
|
/translate:
|
||||||
|
post:
|
||||||
|
summary: Request Translation
|
||||||
|
operationId: translateText
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
type: object
|
||||||
|
required:
|
||||||
|
- text
|
||||||
|
- target_lang
|
||||||
|
properties:
|
||||||
|
text:
|
||||||
|
description: |
|
||||||
|
Text to be translated. Only UTF-8-encoded plain text is supported. The parameter may be specified
|
||||||
|
up to 50 times in a single request. Translations are returned in the same order as they are requested.
|
||||||
|
type: array
|
||||||
|
maxItems: 50
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
example: Hello, World!
|
||||||
|
target_lang:
|
||||||
|
description: The language into which the text should be translated.
|
||||||
|
type: string
|
||||||
|
enum:
|
||||||
|
- BG
|
||||||
|
- CS
|
||||||
|
- DA
|
||||||
|
- DE
|
||||||
|
- EL
|
||||||
|
- EN-GB
|
||||||
|
- EN-US
|
||||||
|
- ES
|
||||||
|
- ET
|
||||||
|
- FI
|
||||||
|
- FR
|
||||||
|
- HU
|
||||||
|
- ID
|
||||||
|
- IT
|
||||||
|
- JA
|
||||||
|
- KO
|
||||||
|
- LT
|
||||||
|
- LV
|
||||||
|
- NB
|
||||||
|
- NL
|
||||||
|
- PL
|
||||||
|
- PT-BR
|
||||||
|
- PT-PT
|
||||||
|
- RO
|
||||||
|
- RU
|
||||||
|
- SK
|
||||||
|
- SL
|
||||||
|
- SV
|
||||||
|
- TR
|
||||||
|
- UK
|
||||||
|
- ZH
|
||||||
|
- ZH-HANS
|
||||||
|
example: DE
|
||||||
|
components:
|
||||||
|
schemas: {}
|
||||||
```
|
```
|
||||||
|
|
||||||
本示例配置了两个服务,一个是高德地图,另一个是心知天气,两个服务都需要现在Higress的服务中以DNS域名的方式配置好,并确保健康。
|
本示例配置了三个服务,演示了get与post两种类型的工具。其中get类型的工具包括高德地图与心知天气,post类型的工具是deepl翻译。三个服务都需要现在Higress的服务中以DNS域名的方式配置好,并确保健康。
|
||||||
高德地图提供了两个工具,分别是获取指定地点的坐标,以及搜索坐标附近的感兴趣的地点。文档:https://lbs.amap.com/api/webservice/guide/api-advanced/newpoisearch
|
高德地图提供了两个工具,分别是获取指定地点的坐标,以及搜索坐标附近的感兴趣的地点。文档:https://lbs.amap.com/api/webservice/guide/api-advanced/newpoisearch
|
||||||
心知天气提供了一个工具,用于获取指定城市的实时天气情况,支持中文,英文,日语返回,以及摄氏度和华氏度的表示。文档:https://seniverse.yuque.com/hyper_data/api_v3/nyiu3t
|
心知天气提供了一个工具,用于获取指定城市的实时天气情况,支持中文,英文,日语返回,以及摄氏度和华氏度的表示。文档:https://seniverse.yuque.com/hyper_data/api_v3/nyiu3t
|
||||||
|
deepl提供了一个工具,用于翻译给定的句子,支持多语言。。文档:https://developers.deepl.com/docs/v/zh/api-reference/translate?fallback=true
|
||||||
|
|
||||||
|
|
||||||
以下为测试用例,为了效果的稳定性,建议保持大模型版本的稳定,本例子中使用的qwen-max-0403:
|
以下为测试用例,为了效果的稳定性,建议保持大模型版本的稳定,本例子中使用的qwen-max-0403:
|
||||||
@@ -249,3 +333,18 @@ curl 'http://<这里换成网关公网IP>/api/openai/v1/chat/completions' \
|
|||||||
```json
|
```json
|
||||||
{"id":"ebd6ea91-8e38-9e14-9a5b-90178d2edea4","choices":[{"index":0,"message":{"role":"assistant","content":" 济南市の現在の天気は雨曇りで、気温は88°Fです。この情報は2024年8月9日15時12分(東京時間)に更新されました。"},"finish_reason":"stop"}],"created":1723187991,"model":"qwen-max-0403","object":"chat.completion","usage":{"prompt_tokens":890,"completion_tokens":56,"total_tokens":946}}
|
{"id":"ebd6ea91-8e38-9e14-9a5b-90178d2edea4","choices":[{"index":0,"message":{"role":"assistant","content":" 济南市の現在の天気は雨曇りで、気温は88°Fです。この情報は2024年8月9日15時12分(東京時間)に更新されました。"},"finish_reason":"stop"}],"created":1723187991,"model":"qwen-max-0403","object":"chat.completion","usage":{"prompt_tokens":890,"completion_tokens":56,"total_tokens":946}}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**请求示例**
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl 'http://<这里换成网关公网IP>/api/openai/v1/chat/completions' \
|
||||||
|
-H 'Accept: application/json, text/event-stream' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
--data-raw '{"model":"qwen","frequency_penalty":0,"max_tokens":800,"stream":false,"messages":[{"role":"user","content":"帮我用德语翻译以下句子:九头蛇万岁!"}],"presence_penalty":0,"temperature":0,"top_p":0}'
|
||||||
|
```
|
||||||
|
|
||||||
|
**响应示例**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{"id":"65dcf12c-61ff-9e68-bffa-44fc9e6070d5","choices":[{"index":0,"message":{"role":"assistant","content":" “九头蛇万岁!”的德语翻译为“Hoch lebe Hydra!”。"},"finish_reason":"stop"}],"created":1724043865,"model":"qwen-max-0403","object":"chat.completion","usage":{"prompt_tokens":908,"completion_tokens":52,"total_tokens":960}}
|
||||||
|
```
|
||||||
@@ -52,11 +52,11 @@ type Tool_Param struct {
|
|||||||
Method string `yaml:"method"`
|
Method string `yaml:"method"`
|
||||||
ParamName []string `yaml:"paramName"`
|
ParamName []string `yaml:"paramName"`
|
||||||
Parameter string `yaml:"parameter"`
|
Parameter string `yaml:"parameter"`
|
||||||
Desciption string `yaml:"description"`
|
Description string `yaml:"description"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// 用于存放拆解出来的api相关信息
|
// 用于存放拆解出来的api相关信息
|
||||||
type API_Param struct {
|
type APIParam struct {
|
||||||
APIKey APIKey `yaml:"apiKey"`
|
APIKey APIKey `yaml:"apiKey"`
|
||||||
URL string `yaml:"url"`
|
URL string `yaml:"url"`
|
||||||
Tool_Param []Tool_Param `yaml:"tool_Param"`
|
Tool_Param []Tool_Param `yaml:"tool_Param"`
|
||||||
@@ -72,6 +72,7 @@ type Server struct {
|
|||||||
URL string `yaml:"url"`
|
URL string `yaml:"url"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 给OpenAPI的get方法用的
|
||||||
type Parameter struct {
|
type Parameter struct {
|
||||||
Name string `yaml:"name"`
|
Name string `yaml:"name"`
|
||||||
In string `yaml:"in"`
|
In string `yaml:"in"`
|
||||||
@@ -84,9 +85,41 @@ type Parameter struct {
|
|||||||
} `yaml:"schema"`
|
} `yaml:"schema"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Items struct {
|
||||||
|
Type string `yaml:"type"`
|
||||||
|
Example string `yaml:"example"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Property struct {
|
||||||
|
Description string `yaml:"description"`
|
||||||
|
Type string `yaml:"type"`
|
||||||
|
Enum []string `yaml:"enum,omitempty"`
|
||||||
|
Items *Items `yaml:"items,omitempty"`
|
||||||
|
MaxItems int `yaml:"maxItems,omitempty"`
|
||||||
|
Example string `yaml:"example,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Schema struct {
|
||||||
|
Type string `yaml:"type"`
|
||||||
|
Required []string `yaml:"required"`
|
||||||
|
Properties map[string]Property `yaml:"properties"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MediaType struct {
|
||||||
|
Schema Schema `yaml:"schema"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// 给OpenAPI的post方法用的
|
||||||
|
type RequestBody struct {
|
||||||
|
Required bool `yaml:"required"`
|
||||||
|
Content map[string]MediaType `yaml:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
type PathItem struct {
|
type PathItem struct {
|
||||||
Description string `yaml:"description"`
|
Description string `yaml:"description"`
|
||||||
|
Summary string `yaml:"summary"`
|
||||||
OperationID string `yaml:"operationId"`
|
OperationID string `yaml:"operationId"`
|
||||||
|
RequestBody RequestBody `yaml:"requestBody"`
|
||||||
Parameters []Parameter `yaml:"parameters"`
|
Parameters []Parameter `yaml:"parameters"`
|
||||||
Deprecated bool `yaml:"deprecated"`
|
Deprecated bool `yaml:"deprecated"`
|
||||||
}
|
}
|
||||||
@@ -166,6 +199,15 @@ type LLMInfo struct {
|
|||||||
// @Title zh-CN 大模型服务的模型名称
|
// @Title zh-CN 大模型服务的模型名称
|
||||||
// @Description zh-CN 大模型服务的模型名称,如"qwen-max-0403"
|
// @Description zh-CN 大模型服务的模型名称,如"qwen-max-0403"
|
||||||
Model string `required:"true" yaml:"model" json:"model"`
|
Model string `required:"true" yaml:"model" json:"model"`
|
||||||
|
// @Title zh-CN 结束执行循环前的最大步数
|
||||||
|
// @Description zh-CN 结束执行循环前的最大步数,比如2,设置为0,可能会无限循环,直到超时退出,默认15
|
||||||
|
MaxIterations int64 `yaml:"maxIterations" json:"maxIterations"`
|
||||||
|
// @Title zh-CN 每一次请求大模型的超时时间
|
||||||
|
// @Description zh-CN 每一次请求大模型的超时时间,单位毫秒,默认50000
|
||||||
|
MaxExecutionTime int64 `yaml:"maxExecutionTime" json:"maxExecutionTime"`
|
||||||
|
// @Title zh-CN
|
||||||
|
// @Description zh-CN 每一次请求大模型的输出token限制,默认1000
|
||||||
|
MaxTokens int64 `yaml:"maxToken" json:"maxTokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PluginConfig struct {
|
type PluginConfig struct {
|
||||||
@@ -180,7 +222,7 @@ type PluginConfig struct {
|
|||||||
// @Description zh-CN 用于存储llm使用信息
|
// @Description zh-CN 用于存储llm使用信息
|
||||||
LLMInfo LLMInfo `required:"true" yaml:"llm" json:"llm"`
|
LLMInfo LLMInfo `required:"true" yaml:"llm" json:"llm"`
|
||||||
LLMClient wrapper.HttpClient `yaml:"-" json:"-"`
|
LLMClient wrapper.HttpClient `yaml:"-" json:"-"`
|
||||||
API_Param []API_Param `yaml:"-" json:"-"`
|
APIParam []APIParam `yaml:"-" json:"-"`
|
||||||
PromptTemplate PromptTemplate `yaml:"promptTemplate" json:"promptTemplate"`
|
PromptTemplate PromptTemplate `yaml:"promptTemplate" json:"promptTemplate"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -188,7 +230,7 @@ func initResponsePromptTpl(gjson gjson.Result, c *PluginConfig) {
|
|||||||
//设置回复模板
|
//设置回复模板
|
||||||
c.ReturnResponseTemplate = gjson.Get("returnResponseTemplate").String()
|
c.ReturnResponseTemplate = gjson.Get("returnResponseTemplate").String()
|
||||||
if c.ReturnResponseTemplate == "" {
|
if c.ReturnResponseTemplate == "" {
|
||||||
c.ReturnResponseTemplate = `{"id":"from-cache","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
|
c.ReturnResponseTemplate = `{"id":"error","choices":[{"index":0,"message":{"role":"assistant","content":"%s"},"finish_reason":"stop"}],"model":"gpt-4o","object":"chat.completion","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -241,22 +283,23 @@ func initAPIs(gjson gjson.Result, c *PluginConfig) error {
|
|||||||
return errors.New("api is required")
|
return errors.New("api is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
var apiStrcut API
|
var apiStruct API
|
||||||
err := yaml.Unmarshal([]byte(api.String()), &apiStrcut)
|
err := yaml.Unmarshal([]byte(api.String()), &apiStruct)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var allTool_param []Tool_Param
|
var allTool_param []Tool_Param
|
||||||
//拆除服务下面的每个api的path
|
//拆除服务下面的每个api的path
|
||||||
for path, pathmap := range apiStrcut.Paths {
|
for path, pathmap := range apiStruct.Paths {
|
||||||
//拆解出每个api对应的参数
|
//拆解出每个api对应的参数
|
||||||
for method, submap := range pathmap {
|
for method, submap := range pathmap {
|
||||||
//把参数列表存起来
|
//把参数列表存起来
|
||||||
var param Tool_Param
|
var param Tool_Param
|
||||||
param.Path = path
|
param.Path = path
|
||||||
param.Method = method
|
|
||||||
param.ToolName = submap.OperationID
|
param.ToolName = submap.OperationID
|
||||||
|
if method == "get" {
|
||||||
|
param.Method = "GET"
|
||||||
paramName := make([]string, 0)
|
paramName := make([]string, 0)
|
||||||
for _, parammeter := range submap.Parameters {
|
for _, parammeter := range submap.Parameters {
|
||||||
paramName = append(paramName, parammeter.Name)
|
paramName = append(paramName, parammeter.Name)
|
||||||
@@ -264,17 +307,25 @@ func initAPIs(gjson gjson.Result, c *PluginConfig) error {
|
|||||||
param.ParamName = paramName
|
param.ParamName = paramName
|
||||||
out, _ := json.Marshal(submap.Parameters)
|
out, _ := json.Marshal(submap.Parameters)
|
||||||
param.Parameter = string(out)
|
param.Parameter = string(out)
|
||||||
param.Desciption = submap.Description
|
param.Description = submap.Description
|
||||||
|
} else if method == "post" {
|
||||||
|
param.Method = "POST"
|
||||||
|
schema := submap.RequestBody.Content["application/json"].Schema
|
||||||
|
param.ParamName = schema.Required
|
||||||
|
param.Description = submap.Summary
|
||||||
|
out, _ := json.Marshal(schema.Properties)
|
||||||
|
param.Parameter = string(out)
|
||||||
|
}
|
||||||
allTool_param = append(allTool_param, param)
|
allTool_param = append(allTool_param, param)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
api_param := API_Param{
|
apiParam := APIParam{
|
||||||
APIKey: APIKey{In: apiKeyIn, Name: apiKeyName.String(), Value: apiKeyValue.String()},
|
APIKey: APIKey{In: apiKeyIn, Name: apiKeyName.String(), Value: apiKeyValue.String()},
|
||||||
URL: apiStrcut.Servers[0].URL,
|
URL: apiStruct.Servers[0].URL,
|
||||||
Tool_Param: allTool_param,
|
Tool_Param: allTool_param,
|
||||||
}
|
}
|
||||||
|
|
||||||
c.API_Param = append(c.API_Param, api_param)
|
c.APIParam = append(c.APIParam, apiParam)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -352,6 +403,18 @@ func initLLMClient(gjson gjson.Result, c *PluginConfig) {
|
|||||||
c.LLMInfo.Domin = gjson.Get("llm.domain").String()
|
c.LLMInfo.Domin = gjson.Get("llm.domain").String()
|
||||||
c.LLMInfo.Path = gjson.Get("llm.path").String()
|
c.LLMInfo.Path = gjson.Get("llm.path").String()
|
||||||
c.LLMInfo.Model = gjson.Get("llm.model").String()
|
c.LLMInfo.Model = gjson.Get("llm.model").String()
|
||||||
|
c.LLMInfo.MaxIterations = gjson.Get("llm.maxIterations").Int()
|
||||||
|
if c.LLMInfo.MaxIterations == 0 {
|
||||||
|
c.LLMInfo.MaxIterations = 15
|
||||||
|
}
|
||||||
|
c.LLMInfo.MaxExecutionTime = gjson.Get("llm.maxExecutionTime").Int()
|
||||||
|
if c.LLMInfo.MaxExecutionTime == 0 {
|
||||||
|
c.LLMInfo.MaxExecutionTime = 50000
|
||||||
|
}
|
||||||
|
c.LLMInfo.MaxTokens = gjson.Get("llm.maxTokens").Int()
|
||||||
|
if c.LLMInfo.MaxTokens == 0 {
|
||||||
|
c.LLMInfo.MaxTokens = 1000
|
||||||
|
}
|
||||||
|
|
||||||
c.LLMClient = wrapper.NewClusterClient(wrapper.FQDNCluster{
|
c.LLMClient = wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||||
FQDN: c.LLMInfo.ServiceName,
|
FQDN: c.LLMInfo.ServiceName,
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ type Usage struct {
|
|||||||
type Completion struct {
|
type Completion struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Messages []Message `json:"messages"`
|
Messages []Message `json:"messages"`
|
||||||
|
MaxTokens int64 `json:"max_tokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
@@ -67,9 +68,3 @@ type CompletionUsage struct {
|
|||||||
CompletionTokens int `json:"completion_tokens"`
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
TotalTokens int `json:"total_tokens"`
|
TotalTokens int `json:"total_tokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Content struct {
|
|
||||||
CH_Question string `json:"ch_question"`
|
|
||||||
Core string `json:"core"`
|
|
||||||
// EN_Question string `json:"en_question"`
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -15,6 +15,14 @@ import (
|
|||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// 用于统计函数的递归调用次数
|
||||||
|
const ToolCallsCount = "ToolCallsCount"
|
||||||
|
|
||||||
|
// react的正则规则
|
||||||
|
const ActionPattern = `Action:\s*(.*?)[.\n]`
|
||||||
|
const ActionInputPattern = `Action Input:\s*(.*)`
|
||||||
|
const FinalAnswerPattern = `Final Answer:(.*)`
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
wrapper.SetCtx(
|
wrapper.SetCtx(
|
||||||
"ai-agent",
|
"ai-agent",
|
||||||
@@ -103,9 +111,9 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte
|
|||||||
//拼装agent prompt模板
|
//拼装agent prompt模板
|
||||||
tool_desc := make([]string, 0)
|
tool_desc := make([]string, 0)
|
||||||
tool_names := make([]string, 0)
|
tool_names := make([]string, 0)
|
||||||
for _, api_param := range config.API_Param {
|
for _, apiParam := range config.APIParam {
|
||||||
for _, tool_param := range api_param.Tool_Param {
|
for _, tool_param := range apiParam.Tool_Param {
|
||||||
tool_desc = append(tool_desc, fmt.Sprintf(prompttpl.TOOL_DESC, tool_param.ToolName, tool_param.Desciption, tool_param.Desciption, tool_param.Desciption, tool_param.Parameter), "\n")
|
tool_desc = append(tool_desc, fmt.Sprintf(prompttpl.TOOL_DESC, tool_param.ToolName, tool_param.Description, tool_param.Description, tool_param.Description, tool_param.Parameter), "\n")
|
||||||
tool_names = append(tool_names, tool_param.ToolName)
|
tool_names = append(tool_names, tool_param.ToolName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -119,6 +127,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte
|
|||||||
tool_names,
|
tool_names,
|
||||||
config.PromptTemplate.CHTemplate.ActionInput,
|
config.PromptTemplate.CHTemplate.ActionInput,
|
||||||
config.PromptTemplate.CHTemplate.Observation,
|
config.PromptTemplate.CHTemplate.Observation,
|
||||||
|
config.PromptTemplate.CHTemplate.Thought2,
|
||||||
config.PromptTemplate.CHTemplate.FinalAnswer,
|
config.PromptTemplate.CHTemplate.FinalAnswer,
|
||||||
config.PromptTemplate.CHTemplate.Begin,
|
config.PromptTemplate.CHTemplate.Begin,
|
||||||
query)
|
query)
|
||||||
@@ -130,11 +139,17 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte
|
|||||||
tool_names,
|
tool_names,
|
||||||
config.PromptTemplate.ENTemplate.ActionInput,
|
config.PromptTemplate.ENTemplate.ActionInput,
|
||||||
config.PromptTemplate.ENTemplate.Observation,
|
config.PromptTemplate.ENTemplate.Observation,
|
||||||
|
config.PromptTemplate.ENTemplate.Thought2,
|
||||||
config.PromptTemplate.ENTemplate.FinalAnswer,
|
config.PromptTemplate.ENTemplate.FinalAnswer,
|
||||||
config.PromptTemplate.ENTemplate.Begin,
|
config.PromptTemplate.ENTemplate.Begin,
|
||||||
query)
|
query)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx.SetContext(ToolCallsCount, 0)
|
||||||
|
|
||||||
|
//清理历史对话记录
|
||||||
|
dashscope.MessageStore.Clear()
|
||||||
|
|
||||||
//将请求加入到历史对话存储器中
|
//将请求加入到历史对话存储器中
|
||||||
dashscope.MessageStore.AddForUser(prompt)
|
dashscope.MessageStore.AddForUser(prompt)
|
||||||
|
|
||||||
@@ -145,97 +160,27 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config PluginConfig, body []byte
|
|||||||
}
|
}
|
||||||
|
|
||||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
|
func onHttpResponseHeaders(ctx wrapper.HttpContext, config PluginConfig, log wrapper.Log) types.Action {
|
||||||
|
log.Debug("onHttpResponseHeaders start")
|
||||||
|
defer log.Debug("onHttpResponseHeaders end")
|
||||||
|
|
||||||
return types.ActionContinue
|
return types.ActionContinue
|
||||||
}
|
}
|
||||||
|
|
||||||
func toolsCall(config PluginConfig, content string, rawResponse Response, log wrapper.Log) types.Action {
|
func toolsCallResult(ctx wrapper.HttpContext, config PluginConfig, content string, rawResponse Response, log wrapper.Log, statusCode int, responseBody []byte) {
|
||||||
dashscope.MessageStore.AddForAssistant(content)
|
|
||||||
|
|
||||||
//得到最终答案
|
|
||||||
regexPattern := regexp.MustCompile(`Final Answer:(.*)`)
|
|
||||||
finalAnswer := regexPattern.FindStringSubmatch(content)
|
|
||||||
if len(finalAnswer) > 1 {
|
|
||||||
return types.ActionContinue
|
|
||||||
}
|
|
||||||
|
|
||||||
//没得到最终答案
|
|
||||||
regexAction := regexp.MustCompile(`Action:\s*(.*?)[.\n]`)
|
|
||||||
regexActionInput := regexp.MustCompile(`Action Input:\s*(.*)`)
|
|
||||||
|
|
||||||
action := regexAction.FindStringSubmatch(content)
|
|
||||||
actionInput := regexActionInput.FindStringSubmatch(content)
|
|
||||||
|
|
||||||
if len(action) > 1 && len(actionInput) > 1 {
|
|
||||||
var url string
|
|
||||||
var headers [][2]string
|
|
||||||
var apiClient wrapper.HttpClient
|
|
||||||
var method string
|
|
||||||
|
|
||||||
for i, api_param := range config.API_Param {
|
|
||||||
for _, tool_param := range api_param.Tool_Param {
|
|
||||||
if action[1] == tool_param.ToolName {
|
|
||||||
log.Infof("calls %s\n", tool_param.ToolName)
|
|
||||||
log.Infof("actionInput[1]: %s", actionInput[1])
|
|
||||||
|
|
||||||
//将大模型需要的参数反序列化
|
|
||||||
var data map[string]interface{}
|
|
||||||
if err := json.Unmarshal([]byte(actionInput[1]), &data); err != nil {
|
|
||||||
log.Debugf("Error: %s\n", err.Error())
|
|
||||||
return types.ActionContinue
|
|
||||||
}
|
|
||||||
|
|
||||||
var args string
|
|
||||||
for i, param := range tool_param.ParamName { //从参数列表中取出参数
|
|
||||||
if i == 0 {
|
|
||||||
args = "?" + param + "=%s"
|
|
||||||
args = fmt.Sprintf(args, data[param])
|
|
||||||
} else {
|
|
||||||
args = args + "&" + param + "=%s"
|
|
||||||
args = fmt.Sprintf(args, data[param])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
url = api_param.URL + tool_param.Path + args
|
|
||||||
|
|
||||||
if api_param.APIKey.Name != "" {
|
|
||||||
if api_param.APIKey.In == "query" {
|
|
||||||
headers = nil
|
|
||||||
key := "&" + api_param.APIKey.Name + "=" + api_param.APIKey.Value
|
|
||||||
url += key
|
|
||||||
} else if api_param.APIKey.In == "header" {
|
|
||||||
headers = [][2]string{{"Content-Type", "application/json"}, {"Authorization", api_param.APIKey.Name + " " + api_param.APIKey.Value}}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("url: %s\n", url)
|
|
||||||
|
|
||||||
method = tool_param.Method
|
|
||||||
|
|
||||||
apiClient = config.APIClient[i]
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if method == "get" {
|
|
||||||
//调用工具
|
|
||||||
err := apiClient.Get(
|
|
||||||
url,
|
|
||||||
headers,
|
|
||||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
|
||||||
if statusCode != http.StatusOK {
|
if statusCode != http.StatusOK {
|
||||||
log.Debugf("statusCode: %d\n", statusCode)
|
log.Debugf("statusCode: %d\n", statusCode)
|
||||||
}
|
}
|
||||||
log.Info("========函数返回结果========")
|
log.Info("========函数返回结果========")
|
||||||
log.Infof(string(responseBody))
|
log.Infof(string(responseBody))
|
||||||
|
|
||||||
Observation := "Observation: " + string(responseBody)
|
observation := "Observation: " + string(responseBody)
|
||||||
|
|
||||||
dashscope.MessageStore.AddForUser(Observation)
|
dashscope.MessageStore.AddForUser(observation)
|
||||||
|
|
||||||
completion := dashscope.Completion{
|
completion := dashscope.Completion{
|
||||||
Model: config.LLMInfo.Model,
|
Model: config.LLMInfo.Model,
|
||||||
Messages: dashscope.MessageStore,
|
Messages: dashscope.MessageStore,
|
||||||
|
MaxTokens: config.LLMInfo.MaxTokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.LLMInfo.APIKey}}
|
headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.LLMInfo.APIKey}}
|
||||||
@@ -248,10 +193,10 @@ func toolsCall(config PluginConfig, content string, rawResponse Response, log wr
|
|||||||
//得到gpt的返回结果
|
//得到gpt的返回结果
|
||||||
var responseCompletion dashscope.CompletionResponse
|
var responseCompletion dashscope.CompletionResponse
|
||||||
_ = json.Unmarshal(responseBody, &responseCompletion)
|
_ = json.Unmarshal(responseBody, &responseCompletion)
|
||||||
log.Infof("[toolsCall] content: ", responseCompletion.Choices[0].Message.Content)
|
log.Infof("[toolsCall] content: %s\n", responseCompletion.Choices[0].Message.Content)
|
||||||
|
|
||||||
if responseCompletion.Choices[0].Message.Content != "" {
|
if responseCompletion.Choices[0].Message.Content != "" {
|
||||||
retType := toolsCall(config, responseCompletion.Choices[0].Message.Content, rawResponse, log)
|
retType := toolsCall(ctx, config, responseCompletion.Choices[0].Message.Content, rawResponse, log)
|
||||||
if retType == types.ActionContinue {
|
if retType == types.ActionContinue {
|
||||||
//得到了Final Answer
|
//得到了Final Answer
|
||||||
var assistantMessage Message
|
var assistantMessage Message
|
||||||
@@ -262,7 +207,7 @@ func toolsCall(config PluginConfig, content string, rawResponse Response, log wr
|
|||||||
extractedText := responseCompletion.Choices[0].Message.Content[startIndex:]
|
extractedText := responseCompletion.Choices[0].Message.Content[startIndex:]
|
||||||
assistantMessage.Content = extractedText
|
assistantMessage.Content = extractedText
|
||||||
}
|
}
|
||||||
//assistantMessage.Content = responseCompletion.Choices[0].Message.Content
|
|
||||||
rawResponse.Choices[0].Message = assistantMessage
|
rawResponse.Choices[0].Message = assistantMessage
|
||||||
|
|
||||||
newbody, err := json.Marshal(rawResponse)
|
newbody, err := json.Marshal(rawResponse)
|
||||||
@@ -280,11 +225,118 @@ func toolsCall(config PluginConfig, content string, rawResponse Response, log wr
|
|||||||
} else {
|
} else {
|
||||||
proxywasm.ResumeHttpRequest()
|
proxywasm.ResumeHttpRequest()
|
||||||
}
|
}
|
||||||
}, 50000)
|
}, uint32(config.LLMInfo.MaxExecutionTime))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("[onHttpRequestBody] completion err: %s", err.Error())
|
log.Debugf("[onHttpRequestBody] completion err: %s", err.Error())
|
||||||
proxywasm.ResumeHttpRequest()
|
proxywasm.ResumeHttpRequest()
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toolsCall(ctx wrapper.HttpContext, config PluginConfig, content string, rawResponse Response, log wrapper.Log) types.Action {
|
||||||
|
dashscope.MessageStore.AddForAssistant(content)
|
||||||
|
|
||||||
|
//得到最终答案
|
||||||
|
regexPattern := regexp.MustCompile(FinalAnswerPattern)
|
||||||
|
finalAnswer := regexPattern.FindStringSubmatch(content)
|
||||||
|
if len(finalAnswer) > 1 {
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
count := ctx.GetContext(ToolCallsCount).(int)
|
||||||
|
count++
|
||||||
|
log.Debugf("toolCallsCount:%d, config.LLMInfo.MaxIterations=%d\n", count, config.LLMInfo.MaxIterations)
|
||||||
|
//函数递归调用次数,达到了预设的循环次数,强制结束
|
||||||
|
if int64(count) > config.LLMInfo.MaxIterations {
|
||||||
|
ctx.SetContext(ToolCallsCount, 0)
|
||||||
|
return types.ActionContinue
|
||||||
|
} else {
|
||||||
|
ctx.SetContext(ToolCallsCount, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
//没得到最终答案
|
||||||
|
regexAction := regexp.MustCompile(ActionPattern)
|
||||||
|
regexActionInput := regexp.MustCompile(ActionInputPattern)
|
||||||
|
|
||||||
|
action := regexAction.FindStringSubmatch(content)
|
||||||
|
actionInput := regexActionInput.FindStringSubmatch(content)
|
||||||
|
|
||||||
|
if len(action) > 1 && len(actionInput) > 1 {
|
||||||
|
var url string
|
||||||
|
var headers [][2]string
|
||||||
|
var apiClient wrapper.HttpClient
|
||||||
|
var method string
|
||||||
|
var reqBody []byte
|
||||||
|
var key string
|
||||||
|
|
||||||
|
for i, apiParam := range config.APIParam {
|
||||||
|
for _, tool_param := range apiParam.Tool_Param {
|
||||||
|
if action[1] == tool_param.ToolName {
|
||||||
|
log.Infof("calls %s\n", tool_param.ToolName)
|
||||||
|
log.Infof("actionInput[1]: %s", actionInput[1])
|
||||||
|
|
||||||
|
//将大模型需要的参数反序列化
|
||||||
|
var data map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(actionInput[1]), &data); err != nil {
|
||||||
|
log.Debugf("Error: %s\n", err.Error())
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
|
||||||
|
method = tool_param.Method
|
||||||
|
|
||||||
|
//key or header组装
|
||||||
|
if apiParam.APIKey.Name != "" {
|
||||||
|
if apiParam.APIKey.In == "query" { //query类型的key要放到url中
|
||||||
|
headers = nil
|
||||||
|
key = "?" + apiParam.APIKey.Name + "=" + apiParam.APIKey.Value
|
||||||
|
} else if apiParam.APIKey.In == "header" { //header类型的key放在header中
|
||||||
|
headers = [][2]string{{"Content-Type", "application/json"}, {"Authorization", apiParam.APIKey.Name + " " + apiParam.APIKey.Value}}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if method == "GET" {
|
||||||
|
//query组装
|
||||||
|
var args string
|
||||||
|
for i, param := range tool_param.ParamName { //从参数列表中取出参数
|
||||||
|
if i == 0 && apiParam.APIKey.In != "query" {
|
||||||
|
args = "?" + param + "=%s"
|
||||||
|
args = fmt.Sprintf(args, data[param])
|
||||||
|
} else {
|
||||||
|
args = args + "&" + param + "=%s"
|
||||||
|
args = fmt.Sprintf(args, data[param])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//url组装
|
||||||
|
url = apiParam.URL + tool_param.Path + key + args
|
||||||
|
} else if method == "POST" {
|
||||||
|
reqBody = nil
|
||||||
|
//json参数组装
|
||||||
|
jsonData, err := json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("Error: %s\n", err.Error())
|
||||||
|
return types.ActionContinue
|
||||||
|
}
|
||||||
|
reqBody = jsonData
|
||||||
|
|
||||||
|
//url组装
|
||||||
|
url = apiParam.URL + tool_param.Path + key
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("url: %s\n", url)
|
||||||
|
|
||||||
|
apiClient = config.APIClient[i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if apiClient != nil {
|
||||||
|
err := apiClient.Call(
|
||||||
|
method,
|
||||||
|
url,
|
||||||
|
headers,
|
||||||
|
reqBody,
|
||||||
|
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||||
|
toolsCallResult(ctx, config, content, rawResponse, log, statusCode, responseBody)
|
||||||
}, 50000)
|
}, 50000)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("tool calls error: %s\n", err.Error())
|
log.Debugf("tool calls error: %s\n", err.Error())
|
||||||
@@ -293,7 +345,6 @@ func toolsCall(config PluginConfig, content string, rawResponse Response, log wr
|
|||||||
} else {
|
} else {
|
||||||
return types.ActionContinue
|
return types.ActionContinue
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
return types.ActionPause
|
return types.ActionPause
|
||||||
}
|
}
|
||||||
@@ -314,7 +365,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config PluginConfig, body []byt
|
|||||||
//如果gpt返回的内容不是空的
|
//如果gpt返回的内容不是空的
|
||||||
if rawResponse.Choices[0].Message.Content != "" {
|
if rawResponse.Choices[0].Message.Content != "" {
|
||||||
//进入agent的循环思考,工具调用的过程中
|
//进入agent的循环思考,工具调用的过程中
|
||||||
return toolsCall(config, rawResponse.Choices[0].Message.Content, rawResponse, log)
|
return toolsCall(ctx, config, rawResponse.Choices[0].Message.Content, rawResponse, log)
|
||||||
} else {
|
} else {
|
||||||
return types.ActionContinue
|
return types.ActionContinue
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user