From 6f4ef335904c2f044ccd5b16bd725e769500077f Mon Sep 17 00:00:00 2001 From: Wangzy Date: Mon, 22 Dec 2025 09:46:31 +0800 Subject: [PATCH] Add tool-search server (#3136) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 澄潭 --- plugins/golang-filter/go.mod | 1 - plugins/golang-filter/go.sum | 4 - plugins/golang-filter/mcp-server/config.go | 1 + .../mcp-server/servers/tool-search/README.md | 144 +++++++++++ .../servers/tool-search/config-example.json | 18 ++ .../servers/tool-search/embedding.go | 79 ++++++ .../mcp-server/servers/tool-search/milvus.go | 204 +++++++++++++++ .../mcp-server/servers/tool-search/search.go | 237 ++++++++++++++++++ .../mcp-server/servers/tool-search/server.go | 196 +++++++++++++++ .../servers/tool-search/server_test.go | 198 +++++++++++++++ .../mcp-server/servers/tool-search/tools.go | 114 +++++++++ 11 files changed, 1191 insertions(+), 5 deletions(-) create mode 100644 plugins/golang-filter/mcp-server/servers/tool-search/README.md create mode 100644 plugins/golang-filter/mcp-server/servers/tool-search/config-example.json create mode 100644 plugins/golang-filter/mcp-server/servers/tool-search/embedding.go create mode 100644 plugins/golang-filter/mcp-server/servers/tool-search/milvus.go create mode 100644 plugins/golang-filter/mcp-server/servers/tool-search/search.go create mode 100644 plugins/golang-filter/mcp-server/servers/tool-search/server.go create mode 100644 plugins/golang-filter/mcp-server/servers/tool-search/server_test.go create mode 100644 plugins/golang-filter/mcp-server/servers/tool-search/tools.go diff --git a/plugins/golang-filter/go.mod b/plugins/golang-filter/go.mod index 3ba96a1da..039bc7cb5 100644 --- a/plugins/golang-filter/go.mod +++ b/plugins/golang-filter/go.mod @@ -53,7 +53,6 @@ require ( github.com/cockroachdb/errors v1.9.1 // indirect github.com/cockroachdb/logtags v0.0.0-20211118104740-dabe8e521a4f // indirect github.com/cockroachdb/redact v1.1.3 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect github.com/deckarep/golang-set v1.7.1 // indirect github.com/dlclark/regexp2 v1.11.5 // indirect github.com/getsentry/sentry-go v0.12.0 // indirect diff --git a/plugins/golang-filter/go.sum b/plugins/golang-filter/go.sum index e43f479af..b6e865652 100644 --- a/plugins/golang-filter/go.sum +++ b/plugins/golang-filter/go.sum @@ -185,10 +185,7 @@ github.com/getsentry/sentry-go v0.12.0/go.mod h1:NSap0JBYWzHND8oMbyi0+XZhUalc1TB github.com/gin-contrib/sse v0.0.0-20190301062529-5545eab6dad3/go.mod h1:VJ0WA2NBN22VlZ2dKZQPAPnyWw5XTlK1KymzLKsr59s= github.com/gin-gonic/gin v1.4.0/go.mod h1:OW2EZn3DO8Ln9oIKOvM++LBO+5UPHJJDH72/q/3rZdM= github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98= -github.com/go-errors/errors v1.0.1 h1:LUHzmkK3GUKUrL/1gfBUxAHzcev3apQlezX/+O7ma6w= github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= -github.com/go-faker/faker/v4 v4.1.0 h1:ffuWmpDrducIUOO0QSKSF5Q2dxAht+dhsT9FvVHhPEI= -github.com/go-faker/faker/v4 v4.1.0/go.mod h1:uuNc0PSRxF8nMgjGrrrU4Nw5cF30Jc6Kd0/FUTTYbhg= github.com/go-faster/city v1.0.1 h1:4WAxSZ3V2Ws4QRDrscLEDcibJY8uf41H6AhXDrNDcGw= github.com/go-faster/city v1.0.1/go.mod h1:jKcUJId49qdW3L1qKHH/3wPeUstCVpVSXTM6vO3VcTw= github.com/go-faster/errors v0.7.1 h1:MkJTnDoEdi9pDabt1dpWf7AA8/BaSYZqibYyhZ20AYg= @@ -429,7 +426,6 @@ github.com/paulmach/protoscan v0.2.1/go.mod h1:SpcSwydNLrxUGSDvXvO0P7g7AuhJ7lcKf github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= -github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= diff --git a/plugins/golang-filter/mcp-server/config.go b/plugins/golang-filter/mcp-server/config.go index 301ffa515..521ba2dc9 100644 --- a/plugins/golang-filter/mcp-server/config.go +++ b/plugins/golang-filter/mcp-server/config.go @@ -8,6 +8,7 @@ import ( _ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/higress/higress-api" _ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/higress/higress-ops" _ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag" + _ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/tool-search" mcp_session "github.com/alibaba/higress/plugins/golang-filter/mcp-session" "github.com/alibaba/higress/plugins/golang-filter/mcp-session/common" xds "github.com/cncf/xds/go/xds/type/v3" diff --git a/plugins/golang-filter/mcp-server/servers/tool-search/README.md b/plugins/golang-filter/mcp-server/servers/tool-search/README.md new file mode 100644 index 000000000..f264cabd9 --- /dev/null +++ b/plugins/golang-filter/mcp-server/servers/tool-search/README.md @@ -0,0 +1,144 @@ +# Tool Search MCP Server + +这是一个基于 Higress Golang Filter 实现的 MCP Server,用于提供工具语义搜索功能。当前实现**仅支持向量语义搜索**(基于 Milvus 向量数据库),**不包含全文检索或混合搜索**。 + +## 功能特性 + +- **向量语义搜索**:使用 OpenAI 兼容的 Embedding API 将用户查询转换为向量,并在 Milvus 中进行相似度检索 +- **工具元数据支持**:从数据库中读取完整的工具定义(JSON 格式),并动态拼接工具名称 +- **全量工具列表**:支持获取数据库中所有可用工具 +- **可配置 Embedding 模型**:支持自定义模型、维度及 API 端点(如 DashScope) +- **Milvus 集成**:通过标准 gRPC 接口连接 Milvus 向量数据库 + +## 数据库要求(Milvus) + +本服务依赖 **Milvus 向量数据库**,需预先创建集合(Collection),其 Schema 应包含以下字段: + +| 字段名 | 类型 | 说明 | +|--------------|-------------------|-------------------------| +| `id` | VarChar(64) | 文档唯一 ID | +| `content` | VarChar(64) | 工具描述文本 | +| `metadata` | JSON | 完整的工具定义(必须包含 `name` 字段) | +| `vector` | FloatVector(1024) | embedding 向量 | +| `metadata` | Int64 | 创建时间 | + + +## 配置参数 + +### 根级配置 + +| 参数 | 类型 | 必填 | 默认值 | 说明 | +|--------------|--------|------|-----------------------------------------------------|------| +| `vector` | object | 是 | - | 向量数据库配置(见下文) | +| `embedding` | object | 是 | - | Embedding API 配置(见下文) | +| `description`| string | 否 | `"Tool search server for semantic similarity search"` | MCP Server 描述信息 | + +### Vector 配置(`vector` 对象) + +| 参数 | 类型 | 必填 | 默认值 | 说明 | +|-------------|--------|------|--------------------|------| +| `type` | string | 是 | - | **必须为 `"milvus"`** | +| `host` | string | 是 | - | Milvus 服务地址(如 `localhost`) | +| `port` | int | 是 | - | Milvus gRPC 端口(如 `19530`) | +| `database` | string | 否 | `"default"` | Milvus 数据库名 | +| `tableName` | string | 否 | `"apig_mcp_tools"` | Milvus 集合名 | +| `username` | string | 否 | - | 认证用户名(可选) | +| `password` | string | 否 | - | 认证密码(可选) | + +### Embedding 配置(`embedding` 对象) + +| 参数 | 类型 | 必填 | 默认值 | 说明 | +|--------------|--------|------|-----------------------------------------------------------|------| +| `apiKey` | string | 是 | - | Embedding 服务的 API Key | +| `baseURL` | string | 否 | `https://dashscope.aliyuncs.com/compatible-mode/v1` | OpenAI 兼容 API 的 Base URL | +| `model` | string | 否 | `text-embedding-v4` | 使用的 Embedding 模型 | +| `dimensions` | int | 否 | `1024` | 向量维度 | + +## 配置示例 + + +Tool Search MCP Server 也可以作为 Higress 的一个模块进行配置。以下是一个在 Higress ConfigMap 中配置 Tool Search 的示例: + +```yaml +apiVersion: v1 +kind: ConfigMap +metadata: + name: higress-config + namespace: higress-system +data: + higress: | + mcpServer: + enable: true + sse_path_suffix: "/sse" + redis: + address: ":6379" + username: "" + password: "" + db: 0 + match_list: + - path_rewrite_prefix: "" + upstream_type: "" + enable_path_rewrite: false + match_rule_domain: "*" + match_rule_path: "/mcp-servers/tool-search" + match_rule_type: "prefix" + servers: + - path: "/mcp-servers/tool-search" + name: "tool-search" + type: "tool-search" + config: + vector: + type: "milvus" + host: "localhost" + port: 19530 + database: "default" + tableName: "apig_mcp_tools" + username: "root" + password: "Milvus" + maxTools: 1000 + embedding: + apiKey: "your-dashscope-api-key" + baseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1" + model: "text-embedding-v4" + dimensions: 1024 + description: "Higress 工具语义搜索服务" +``` + +## 工具搜索接口 + +Tool Search MCP Server 提供以下 MCP 工具: + +### x_higress_tool_search + +基于语义相似度搜索最相关的工具。 + +**输入参数**: + +| 参数名 | 类型 | 必填 | 说明 | +|---------|--------|------|------| +| `query` | string | 是 | 查询语句,用于与工具描述进行语义相似度比较 | +| `topK` | int | 否 | 指定需要选择的工具数量,默认选择前10个工具 | + +**输出格式**: + +``` +{ + "tools": [ + { + "name": "server_name___tool_name", + "title": "Tool Title", + "description": "Tool description", + "inputSchema": {...}, + "outputSchema": {...} + } + ] +} +``` + + +## 搜索实现 + +通过向量相似度进行搜索,索引配置如下 +- 使用 HNSW 索引算法进行向量索引 +- 默认参数:M=8, efConstruction=64 +- 相似度度量方式:内积(IP) \ No newline at end of file diff --git a/plugins/golang-filter/mcp-server/servers/tool-search/config-example.json b/plugins/golang-filter/mcp-server/servers/tool-search/config-example.json new file mode 100644 index 000000000..d29499b4b --- /dev/null +++ b/plugins/golang-filter/mcp-server/servers/tool-search/config-example.json @@ -0,0 +1,18 @@ +{ + "vector": { + "type": "milvus", + "vectorWeight": 0.5, + "tableName": "apig_mcp_tools", + "host": "localhost", + "port": 19530, + "database": "default", + "username": "root", + "password": "Milvus" + }, + "embedding": { + "apiKey": "your-dashscope-api-key", + "baseURL": "https://dashscope.aliyuncs.com/compatible-mode/v1", + "model": "text-embedding-v4", + "dimensions": 1024 + } +} \ No newline at end of file diff --git a/plugins/golang-filter/mcp-server/servers/tool-search/embedding.go b/plugins/golang-filter/mcp-server/servers/tool-search/embedding.go new file mode 100644 index 000000000..324afbf1a --- /dev/null +++ b/plugins/golang-filter/mcp-server/servers/tool-search/embedding.go @@ -0,0 +1,79 @@ +package tool_search + +import ( + "context" + "fmt" + "time" + + "github.com/envoyproxy/envoy/contrib/golang/common/go/api" + "github.com/openai/openai-go/v2" + "github.com/openai/openai-go/v2/option" +) + +// EmbeddingClient handles vector embedding generation using OpenAI-compatible APIs +type EmbeddingClient struct { + client *openai.Client + model string + dimensions int +} + +// NewEmbeddingClient creates a new EmbeddingClient instance for OpenAI-compatible APIs +func NewEmbeddingClient(apiKey, baseURL, model string, dimensions int) *EmbeddingClient { + api.LogInfof("Creating EmbeddingClient with baseURL: %s, model: %s, dimensions: %d", baseURL, model, dimensions) + + // Create client with timeout + client := openai.NewClient( + option.WithAPIKey(apiKey), + option.WithBaseURL(baseURL), + option.WithRequestTimeout(30*time.Second), + ) + + return &EmbeddingClient{ + client: &client, + model: model, + dimensions: dimensions, + } +} + +// GetEmbedding generates vector embedding for the given text +func (e *EmbeddingClient) GetEmbedding(ctx context.Context, text string) ([]float32, error) { + api.LogInfof("Generating embedding for text (length: %d)", len(text)) + api.LogDebugf("Using model: %s, dimensions: %d", e.model, e.dimensions) + + // Add timeout to context if not already present + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + params := openai.EmbeddingNewParams{ + Model: e.model, + Input: openai.EmbeddingNewParamsInputUnion{ + OfString: openai.String(text), + }, + Dimensions: openai.Int(int64(e.dimensions)), + EncodingFormat: openai.EmbeddingNewParamsEncodingFormatFloat, + } + + api.LogDebugf("Calling OpenAI-compatible API for embedding generation") + embeddingResp, err := e.client.Embeddings.New(ctx, params) + if err != nil { + api.LogErrorf("OpenAI-compatible API call failed: %v", err) + return nil, fmt.Errorf("failed to generate embedding: %w", err) + } + + if len(embeddingResp.Data) == 0 { + api.LogErrorf("Empty embedding response from API") + return nil, fmt.Errorf("empty embedding response") + } + + api.LogDebugf("Successfully received embedding from API") + api.LogDebugf("Response data length: %d, embedding dimension: %d", len(embeddingResp.Data), len(embeddingResp.Data[0].Embedding)) + + // Convert []float64 to []float32 + embedding := make([]float32, len(embeddingResp.Data[0].Embedding)) + for i, v := range embeddingResp.Data[0].Embedding { + embedding[i] = float32(v) + } + + api.LogInfof("Embedding conversion completed, final dimension: %d", len(embedding)) + return embedding, nil +} diff --git a/plugins/golang-filter/mcp-server/servers/tool-search/milvus.go b/plugins/golang-filter/mcp-server/servers/tool-search/milvus.go new file mode 100644 index 000000000..1cbf2e9f9 --- /dev/null +++ b/plugins/golang-filter/mcp-server/servers/tool-search/milvus.go @@ -0,0 +1,204 @@ +package tool_search + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config" + "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/schema" + "github.com/milvus-io/milvus-sdk-go/v2/client" + "github.com/milvus-io/milvus-sdk-go/v2/entity" +) + +type MilvusVectorStoreProvider struct { + client client.Client + collection string + dimensions int +} + +func NewMilvusVectorStoreProvider(cfg *config.VectorDBConfig, dimensions int) (*MilvusVectorStoreProvider, error) { + connectParam := client.Config{ + Address: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), + } + connectParam.DBName = cfg.Database + + if cfg.Username != "" && cfg.Password != "" { + connectParam.Username = cfg.Username + connectParam.Password = cfg.Password + } + + milvusClient, err := client.NewClient(context.Background(), connectParam) + if err != nil { + return nil, fmt.Errorf("failed to create milvus client: %w", err) + } + + return &MilvusVectorStoreProvider{ + client: milvusClient, + collection: cfg.Collection, + dimensions: dimensions, + }, nil +} + +func (c *MilvusVectorStoreProvider) ListAllDocs(ctx context.Context, limit int) ([]schema.Document, error) { + expr := "" + + outputFields := []string{"id", "content", "metadata", "created_at"} + + var queryResult []entity.Column + var err error + + if limit > 0 { + queryResult, err = c.client.Query( + ctx, + c.collection, + []string{}, // partitions + expr, // filter condition + outputFields, + client.WithLimit(int64(limit)), + ) + } else { + queryResult, err = c.client.Query( + ctx, + c.collection, + []string{}, // partitions + expr, // filter condition + outputFields, + ) + } + + if err != nil { + return nil, fmt.Errorf("failed to query all documents: %w", err) + } + + if len(queryResult) == 0 { + return []schema.Document{}, nil + } + + rowCount := queryResult[0].Len() + documents := make([]schema.Document, 0, rowCount) + + for i := 0; i < rowCount; i++ { + var ( + id string + content string + metadata map[string]interface{} + createdAt int64 + ) + + for _, col := range queryResult { + switch col.Name() { + case "id": + if v, err := col.(*entity.ColumnVarChar).Get(i); err == nil { + id = v.(string) + } + case "content": + if v, err := col.(*entity.ColumnVarChar).Get(i); err == nil { + content = v.(string) + } + case "metadata": + if v, err := col.(*entity.ColumnJSONBytes).Get(i); err == nil { + if bytes, ok := v.([]byte); ok { + _ = json.Unmarshal(bytes, &metadata) + } + } + case "created_at": + if v, err := col.(*entity.ColumnInt64).Get(i); err == nil { + createdAt = v.(int64) + } + } + } + + doc := schema.Document{ + ID: id, + Content: content, + Metadata: metadata, + CreatedAt: time.UnixMilli(createdAt), + } + documents = append(documents, doc) + } + + return documents, nil +} + +func (c *MilvusVectorStoreProvider) SearchDocs(ctx context.Context, vector []float32, options *schema.SearchOptions) ([]schema.SearchResult, error) { + if options == nil { + options = &schema.SearchOptions{TopK: 10} + } + + sp, err := entity.NewIndexHNSWSearchParam(16) // 默认 HNSW 搜索参数 + if err != nil { + return nil, fmt.Errorf("failed to build search param: %w", err) + } + + outputFields := []string{"id", "content", "metadata"} + searchResults, err := c.client.Search( + ctx, + c.collection, + []string{}, // partition names + "", // filter expression + outputFields, // output fields + []entity.Vector{entity.FloatVector(vector)}, + "vector", // anns_field + entity.IP, // metric_type + options.TopK, + sp, + ) + + if err != nil { + return nil, fmt.Errorf("failed to search documents: %w", err) + } + + var results []schema.SearchResult + for _, result := range searchResults { + for i := 0; i < result.ResultCount; i++ { + id, _ := result.IDs.Get(i) + score := result.Scores[i] + + var content string + var metadata map[string]interface{} + for _, field := range result.Fields { + switch field.Name() { + case "content": + if contentCol, ok := field.(*entity.ColumnVarChar); ok { + if contentVal, err := contentCol.Get(i); err == nil { + if contentStr, ok := contentVal.(string); ok { + content = contentStr + } + } + } + case "metadata": + if metaCol, ok := field.(*entity.ColumnJSONBytes); ok { + if metaVal, err := metaCol.Get(i); err == nil { + if metaBytes, ok := metaVal.([]byte); ok { + if err := json.Unmarshal(metaBytes, &metadata); err != nil { + metadata = make(map[string]interface{}) + } + } + } + } + } + } + + searchResult := schema.SearchResult{ + Document: schema.Document{ + ID: fmt.Sprintf("%s", id), + Content: content, + Metadata: metadata, + }, + Score: float64(score), + } + results = append(results, searchResult) + } + } + + return results, nil +} + +func (c *MilvusVectorStoreProvider) Close() error { + if c.client != nil { + return c.client.Close() + } + return nil +} diff --git a/plugins/golang-filter/mcp-server/servers/tool-search/search.go b/plugins/golang-filter/mcp-server/servers/tool-search/search.go new file mode 100644 index 000000000..bdcff9728 --- /dev/null +++ b/plugins/golang-filter/mcp-server/servers/tool-search/search.go @@ -0,0 +1,237 @@ +package tool_search + +import ( + "context" + "fmt" + "time" + + "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config" + "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/schema" + "github.com/envoyproxy/envoy/contrib/golang/common/go/api" +) + +// SearchService handles tool search operations +type SearchService struct { + milvusProvider *MilvusVectorStoreProvider + config *config.VectorDBConfig + tableName string + dimensions int + maxTools int // 写死的最大工具数量,仅用于单测 + embeddingClient *EmbeddingClient +} + +// NewSearchService creates a new SearchService instance +func NewSearchService(host string, port int, database, username, password, tableName string, embeddingClient *EmbeddingClient, dimensions int, maxTools int) *SearchService { + // Create Milvus configuration + cfg := &config.VectorDBConfig{ + Provider: "milvus", + Host: host, + Port: port, + Database: database, + Collection: tableName, + Username: username, + Password: password, + } + + // Create Milvus provider + provider, err := NewMilvusVectorStoreProvider(cfg, dimensions) + if err != nil { + api.LogErrorf("Failed to create Milvus provider: %v", err) + return nil + } + + return &SearchService{ + milvusProvider: provider, + config: cfg, + tableName: tableName, + dimensions: dimensions, + maxTools: maxTools, // 使用写死的值 + embeddingClient: embeddingClient, + } +} + +// ToolSearchResult represents the result of a tool search +type ToolSearchResult struct { + Tools []ToolDefinition `json:"tools"` +} + +// ToolDefinition represents a tool definition in the search result +type ToolDefinition map[string]interface{} + +// SearchTools performs semantic search for tools +func (s *SearchService) SearchTools(ctx context.Context, query string, topK int) (*ToolSearchResult, error) { + api.LogInfof("Starting tool search for query: '%s', topK: %d", query, topK) + + // Generate vector embedding for the query + vector, err := s.embeddingClient.GetEmbedding(ctx, query) + if err != nil { + api.LogErrorf("Failed to generate embedding for query '%s': %v", query, err) + return nil, fmt.Errorf("failed to generate embedding: %w", err) + } + + api.LogInfof("Embedding generated successfully, vector dimension: %d", len(vector)) + + // Perform vector search + records, err := s.searchToolsInDB(query, vector, topK) + if err != nil { + api.LogErrorf("Failed to search tools: %v", err) + return nil, fmt.Errorf("failed to search tools: %w", err) + } + + api.LogInfof("Vector search completed, found %d records", len(records)) + + return s.convertRecordsToResult(records), nil +} + +// convertRecordsToResult converts database records to tool search result +func (s *SearchService) convertRecordsToResult(records []ToolRecord) *ToolSearchResult { + api.LogInfof("Converting %d records to tool definitions", len(records)) + + tools := make([]ToolDefinition, 0, len(records)) + for i, record := range records { + var tool ToolDefinition + + // Use metadata if available + if len(record.Metadata) > 0 { + tool = record.Metadata + api.LogDebugf("Successfully parsed metadata for tool %s", record.Name) + } else { + api.LogDebugf("No metadata found for tool %s, using basic definition", record.Name) + // If no metadata, create a basic tool definition + tool = ToolDefinition{ + "name": record.Name, + "description": record.Content, + } + } + + // Update the name to include server name + tool["name"] = fmt.Sprintf("%s", record.Name) + + tools = append(tools, tool) + + api.LogDebugf("Tool %d: %s - %s", i+1, tool["name"], record.Content) + } + + api.LogInfof("Successfully converted %d tools", len(tools)) + return &ToolSearchResult{Tools: tools} +} + +// GetAllTools retrieves all available tools +func (s *SearchService) GetAllTools() (*ToolSearchResult, error) { + api.LogInfo("Retrieving all tools") + records, err := s.getAllToolsFromDB() + if err != nil { + api.LogErrorf("Failed to get all tools: %v", err) + return nil, fmt.Errorf("failed to get all tools: %w", err) + } + + api.LogInfof("Found %d tools in database", len(records)) + + // Convert records to tool definitions + tools := make([]ToolDefinition, 0, len(records)) + for _, record := range records { + var tool ToolDefinition + + // Use metadata if available + if len(record.Metadata) > 0 { + tool = record.Metadata + api.LogDebugf("Successfully parsed metadata for tool %s", record.Name) + } else { + api.LogDebugf("No metadata found for tool %s, using basic definition", record.Name) + // If no metadata, create a basic tool definition + tool = ToolDefinition{ + "name": record.Name, + "description": record.Content, + } + } + + // Update the name to include server name + tool["name"] = fmt.Sprintf("%s", record.Name) + + tools = append(tools, tool) + } + + api.LogInfof("Successfully converted %d tools", len(tools)) + return &ToolSearchResult{Tools: tools}, nil +} + +// ToolRecord represents a tool record in the database +type ToolRecord struct { + ID string `json:"id"` + Name string `json:"name"` + Content string `json:"content"` + Metadata map[string]interface{} `json:"metadata"` +} + +func (s *SearchService) searchToolsInDB(query string, vector []float32, topK int) ([]ToolRecord, error) { + api.LogInfof("Performing vector search for query: '%s', topK: %d", query, topK) + + // For Milvus, we'll perform vector search directly + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Perform vector search + searchOptions := &schema.SearchOptions{ + TopK: topK, + } + + results, err := s.milvusProvider.SearchDocs(ctx, vector, searchOptions) + if err != nil { + api.LogErrorf("Vector search failed: %v", err) + return nil, fmt.Errorf("failed to perform vector search: %w", err) + } + + // Convert results to ToolRecords + var records []ToolRecord + for _, result := range results { + doc := result.Document + tool := ToolRecord{ + ID: doc.ID, + Content: doc.Content, + Metadata: doc.Metadata, + } + + if name, ok := doc.Metadata["name"].(string); ok { + tool.Name = name + } + + records = append(records, tool) + } + + api.LogInfof("Vector search completed, found %d results", len(records)) + return records, nil +} + +// getAllToolsFromDB retrieves all tools from the database +func (s *SearchService) getAllToolsFromDB() ([]ToolRecord, error) { + api.LogInfof("Executing GetAllTools query from collection: %s", s.tableName) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Retrieve all documents with limit + docs, err := s.milvusProvider.ListAllDocs(ctx, s.maxTools) + if err != nil { + api.LogErrorf("Failed to list documents: %v", err) + return nil, fmt.Errorf("failed to list documents: %w", err) + } + + // Convert documents to ToolRecords + var tools []ToolRecord + for _, doc := range docs { + tool := ToolRecord{ + ID: doc.ID, + Content: doc.Content, + Metadata: doc.Metadata, + } + + if name, ok := doc.Metadata["name"].(string); ok { + tool.Name = name + } + + tools = append(tools, tool) + } + + api.LogInfof("GetAllTools query completed, found %d tools", len(tools)) + return tools, nil +} diff --git a/plugins/golang-filter/mcp-server/servers/tool-search/server.go b/plugins/golang-filter/mcp-server/servers/tool-search/server.go new file mode 100644 index 000000000..5b40f6af4 --- /dev/null +++ b/plugins/golang-filter/mcp-server/servers/tool-search/server.go @@ -0,0 +1,196 @@ +package tool_search + +import ( + "errors" + "fmt" + + "github.com/alibaba/higress/plugins/golang-filter/mcp-session/common" + "github.com/envoyproxy/envoy/contrib/golang/common/go/api" + "github.com/mark3labs/mcp-go/mcp" +) + +const ( + Version = "1.0.0" + + // 默认配置值 + defaultTableName = "apig_mcp_tools" + defaultBaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1" + defaultModel = "text-embedding-v4" + defaultDimensions = 1024 + // 写死最大工具数量为1000,仅用于单测 + fixedMaxTools = 1000 +) + +func init() { + common.GlobalRegistry.RegisterServer("tool-search", &ToolSearchConfig{}) +} + +type VectorConfig struct { + Type string `json:"type"` + VectorWeight float64 `json:"vectorWeight"` + TableName string `json:"tableName"` + Host string `json:"host"` + Port int `json:"port"` + Database string `json:"database"` + Username string `json:"username"` + Password string `json:"password"` +} + +type EmbeddingConfig struct { + APIKey string `json:"apiKey"` + BaseURL string `json:"baseURL"` + Model string `json:"model"` + Dimensions int `json:"dimensions"` +} + +type ToolSearchConfig struct { + Vector VectorConfig `json:"vector"` + Embedding EmbeddingConfig `json:"embedding"` + description string +} + +func (c *ToolSearchConfig) ParseConfig(config map[string]any) error { + // Parse vector configuration + vectorConfig, ok := config["vector"].(map[string]any) + if !ok { + return errors.New("missing vector configuration") + } + + if err := c.parseVectorConfig(vectorConfig); err != nil { + return fmt.Errorf("failed to parse vector config: %w", err) + } + + // Parse embedding configuration + embeddingConfig, ok := config["embedding"].(map[string]any) + if !ok { + return errors.New("missing embedding configuration") + } + + if err := c.parseEmbeddingConfig(embeddingConfig); err != nil { + return fmt.Errorf("failed to parse embedding config: %w", err) + } + + // Optional description + if description, ok := config["description"].(string); ok { + c.description = description + } else { + c.description = "Tool search server for semantic similarity search" + } + + api.LogDebugf("ToolSearchConfig ParseConfig: %+v", config) + return nil +} + +func (c *ToolSearchConfig) parseVectorConfig(config map[string]any) error { + if vectorType, ok := config["type"].(string); ok { + c.Vector.Type = vectorType + } else { + return errors.New("missing vector.type") + } + + if c.Vector.Type != "milvus" { + return fmt.Errorf("unsupported vector.type: %s, only 'milvus' is supported", c.Vector.Type) + } + + if host, ok := config["host"].(string); ok { + c.Vector.Host = host + } else { + return errors.New("missing vector.host") + } + + if port, ok := config["port"].(float64); ok { + c.Vector.Port = int(port) + } else if port, ok := config["port"].(int); ok { + c.Vector.Port = port + } else { + return errors.New("missing vector.port") + } + + if database, ok := config["database"].(string); ok { + c.Vector.Database = database + } else { + c.Vector.Database = "default" // 默认数据库 + } + + if tableName, ok := config["tableName"].(string); ok { + c.Vector.TableName = tableName + } else { + c.Vector.TableName = defaultTableName + } + + if username, ok := config["username"].(string); ok { + c.Vector.Username = username + } + + if password, ok := config["password"].(string); ok { + c.Vector.Password = password + } + + // 移除maxTools的解析逻辑 + + return nil +} + +func (c *ToolSearchConfig) parseEmbeddingConfig(config map[string]any) error { + // Parse API key (required) + if apiKey, ok := config["apiKey"].(string); ok { + c.Embedding.APIKey = apiKey + } else { + return errors.New("missing embedding.apiKey") + } + + // Parse optional fields with defaults + if baseURL, ok := config["baseURL"].(string); ok { + c.Embedding.BaseURL = baseURL + } else { + c.Embedding.BaseURL = defaultBaseURL + } + + if model, ok := config["model"].(string); ok { + c.Embedding.Model = model + } else { + c.Embedding.Model = defaultModel + } + + if dimensions, ok := config["dimensions"].(float64); ok { + c.Embedding.Dimensions = int(dimensions) + } else if dimensions, ok := config["dimensions"].(int); ok { + c.Embedding.Dimensions = dimensions + } else { + c.Embedding.Dimensions = defaultDimensions + } + + return nil +} + +func (c *ToolSearchConfig) NewServer(serverName string) (*common.MCPServer, error) { + mcpServer := common.NewMCPServer( + serverName, + Version, + common.WithInstructions(c.description), + ) + + // Create embedding client + embeddingClient := NewEmbeddingClient(c.Embedding.APIKey, c.Embedding.BaseURL, c.Embedding.Model, c.Embedding.Dimensions) + + // Create search service,使用写死的fixedMaxTools值 + searchService := NewSearchService( + c.Vector.Host, + c.Vector.Port, + c.Vector.Database, + c.Vector.Username, + c.Vector.Password, + c.Vector.TableName, + embeddingClient, + c.Embedding.Dimensions, + fixedMaxTools, // 使用写死的值 + ) + + // Add tool search tool + mcpServer.AddTool( + mcp.NewToolWithRawSchema("x_higress_tool_search", "Higress MCP Tools Searcher", GetToolSearchSchema()), + HandleToolSearch(searchService), + ) + + return mcpServer, nil +} diff --git a/plugins/golang-filter/mcp-server/servers/tool-search/server_test.go b/plugins/golang-filter/mcp-server/servers/tool-search/server_test.go new file mode 100644 index 000000000..df3786045 --- /dev/null +++ b/plugins/golang-filter/mcp-server/servers/tool-search/server_test.go @@ -0,0 +1,198 @@ +package tool_search + +import ( + "context" + "encoding/json" + "fmt" + "os" + "testing" + "time" + + "github.com/envoyproxy/envoy/contrib/golang/common/go/api" + "github.com/mark3labs/mcp-go/mcp" +) + +// Mock implementation of CommonCAPI for testing +type mockCommonCAPI struct { + logs []string +} + +func (m *mockCommonCAPI) Log(level api.LogType, message string) { + fmt.Printf("[%s] %s\n", level, message) + m.logs = append(m.logs, message) +} + +func (m *mockCommonCAPI) LogLevel() api.LogType { + return api.Debug +} + +// TestServer is used for local functional testing +func TestServer(t *testing.T) { + // Setup mock API for logging + mockAPI := &mockCommonCAPI{} + api.SetCommonCAPI(mockAPI) + + // Load configuration from environment variables or use defaults + config := map[string]any{ + "vector": map[string]any{ + "type": "milvus", + "vectorWeight": 0.6, + "tableName": getEnvOrDefault("TEST_TABLE_NAME", "apig_mcp_tools"), + "host": getEnvOrDefault("TEST_MILVUS_HOST", "localhost"), + "port": getEnvOrDefaultInt("TEST_MILVUS_PORT", 19530), + "database": getEnvOrDefault("TEST_MILVUS_DATABASE", "default"), + "username": getEnvOrDefault("TEST_MILVUS_USERNAME", "root"), + "password": getEnvOrDefault("TEST_MILVUS_PASSWORD", "Milvus"), + "maxTools": getEnvOrDefaultInt("TEST_MAX_TOOLS", 1000), + }, + "embedding": map[string]any{ + "apiKey": getEnvOrDefault("TEST_API_KEY", "your-dashscope-api-key"), + "baseURL": getEnvOrDefault("TEST_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"), + "model": getEnvOrDefault("TEST_MODEL", "text-embedding-v4"), + "dimensions": 1024, + }, + "description": "Test MCP Tools Search Server", + } + + // Create configuration instance + toolSearchConfig := &ToolSearchConfig{} + if err := toolSearchConfig.ParseConfig(config); err != nil { + t.Fatalf("Failed to parse config: %v", err) + } + + // Create MCP Server + _, err := toolSearchConfig.NewServer("test-tool-search") + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + // Test database connection + vectorConfig := config["vector"].(map[string]any) + embeddingConfig := config["embedding"].(map[string]any) + + // Test GetAllTools + t.Logf("\n=== Testing GetAllTools ===") + embeddingClient := NewEmbeddingClient( + embeddingConfig["apiKey"].(string), + embeddingConfig["baseURL"].(string), + embeddingConfig["model"].(string), + embeddingConfig["dimensions"].(int), + ) + + searchService := NewSearchService( + vectorConfig["host"].(string), + vectorConfig["port"].(int), + vectorConfig["database"].(string), + vectorConfig["username"].(string), + vectorConfig["password"].(string), + vectorConfig["tableName"].(string), + embeddingClient, + embeddingConfig["dimensions"].(int), + getEnvOrDefaultInt("TEST_MAX_TOOLS", 1000), + ) + + allTools, err := searchService.GetAllTools() + if err != nil { + t.Logf("GetAllTools failed: %v", err) + } else { + t.Logf("Found %d tools:", len(allTools.Tools)) + for i, tool := range allTools.Tools { + if i < 3 { // Show only first 3 tools + toolJSON, _ := json.MarshalIndent(tool, "", " ") + t.Logf("Tool %d: %s", i+1, string(toolJSON)) + } + } + if len(allTools.Tools) > 3 { + t.Logf("... and %d more tools", len(allTools.Tools)-3) + } + } + + // Test tool search with timing + t.Logf("\n=== Testing Tool Search ===") + testQueries := []string{ + "weather data", + "database query", + "file operations", + "HTTP requests", + "library documents", + } + + for _, query := range testQueries { + t.Logf("\n--- Testing query: '%s' ---", query) + + // Create MCP tool call request + request := mcp.CallToolRequest{ + Params: struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments,omitempty"` + Meta *struct { + ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"` + } `json:"_meta,omitempty"` + }{ + Name: "x_higress_tool_search", + Arguments: map[string]interface{}{ + "query": query, + "topK": 3, + }, + }, + } + + // Get tool handler + handler := HandleToolSearch(searchService) + + // Execute search with timing + start := time.Now() + result, err := handler(context.Background(), request) + duration := time.Since(start) + + if err != nil { + t.Logf("Search failed: %v", err) + continue + } + + // Print results with timing information + t.Logf("Search completed in %v", duration) + if len(result.Content) > 0 { + if textContent, ok := result.Content[0].(mcp.TextContent); ok { + var toolsResult map[string]interface{} + if err := json.Unmarshal([]byte(textContent.Text), &toolsResult); err == nil { + toolsJSON, _ := json.MarshalIndent(toolsResult, "", " ") + t.Logf("Tools Result: %s", string(toolsJSON)) + } else { + t.Logf("Text Content: %s", textContent.Text) + } + } + } + } + + // Test configuration validation + t.Logf("\n=== Configuration Validation ===") + t.Logf("Host: %s", vectorConfig["host"]) + t.Logf("Port: %d", vectorConfig["port"]) + t.Logf("Database: %s", vectorConfig["database"]) + t.Logf("Table Name: %s", vectorConfig["tableName"]) + t.Logf("Vector Weight: %f", vectorConfig["vectorWeight"]) + t.Logf("Text Weight: %f", 1.0-vectorConfig["vectorWeight"].(float64)) + t.Logf("Model: %s", embeddingConfig["model"]) + t.Logf("Dimensions: %d", embeddingConfig["dimensions"]) + t.Logf("API Base URL: %s", embeddingConfig["baseURL"]) + + t.Logf("\n=== Test completed ===") +} + +// Helper function to get environment variable or default value +func getEnvOrDefault(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue +} + +func getEnvOrDefaultInt(key string, defaultValue int) int { + if valueStr := os.Getenv(key); valueStr != "" { + if value, err := fmt.Sscanf(valueStr, "%d", &defaultValue); err == nil && value == 1 { + return defaultValue + } + } + return defaultValue +} diff --git a/plugins/golang-filter/mcp-server/servers/tool-search/tools.go b/plugins/golang-filter/mcp-server/servers/tool-search/tools.go new file mode 100644 index 000000000..bbab51aac --- /dev/null +++ b/plugins/golang-filter/mcp-server/servers/tool-search/tools.go @@ -0,0 +1,114 @@ +package tool_search + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/alibaba/higress/plugins/golang-filter/mcp-session/common" + "github.com/envoyproxy/envoy/contrib/golang/common/go/api" + "github.com/mark3labs/mcp-go/mcp" +) + +// HandleToolSearch handles the x_higress_tool_search tool +func HandleToolSearch(searchService *SearchService) common.ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + api.LogInfo("HandleToolSearch called") + + arguments := request.Params.Arguments + api.LogDebugf("Request arguments: %+v", arguments) + + // Get query parameter + query, ok := arguments["query"].(string) + if !ok { + api.LogErrorf("Invalid query argument type: %T", arguments["query"]) + return nil, fmt.Errorf("invalid query argument") + } + + // Validate query + if query == "" { + api.LogError("Empty query provided") + return nil, fmt.Errorf("query cannot be empty") + } + + // Get topK parameter (optional, default to 10) + topK := 10 + if topKVal, ok := arguments["topK"]; ok { + switch v := topKVal.(type) { + case float64: + topK = int(v) + case int: + topK = v + case int64: + topK = int(v) + default: + api.LogWarnf("Invalid topK argument type: %T, using default: %d", topKVal, topK) + } + + // Validate topK range + if topK <= 0 || topK > 100 { + api.LogWarnf("Invalid topK value: %d, using default: 10", topK) + topK = 10 + } + } + + api.LogInfof("Parsed parameters - query: '%s', topK: %d", query, topK) + + // Add timeout to context + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + // Perform search + result, err := searchService.SearchTools(ctx, query, topK) + if err != nil { + api.LogErrorf("Search failed: %v", err) + return nil, fmt.Errorf("failed to search tools: %w", err) + } + + api.LogInfof("Search completed successfully, found %d tools", len(result.Tools)) + + // Build response + response := map[string]interface{}{ + "tools": result.Tools, + } + + jsonData, err := json.Marshal(response) + if err != nil { + api.LogErrorf("Failed to marshal response: %v", err) + return nil, fmt.Errorf("failed to marshal search results: %w", err) + } + + api.LogDebugf("Response marshaled successfully, JSON length: %d", len(jsonData)) + api.LogDebugf("Returning MCP CallToolResult") + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: string(jsonData), + }, + }, + }, nil + } +} + +// GetToolSearchSchema returns the schema for the tool search tool +func GetToolSearchSchema() json.RawMessage { + return json.RawMessage(`{ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Query statement for semantic similarity comparison with tool descriptions" + }, + "topK": { + "type": "integer", + "description": "Specify how many tools need to be selected, default is to select the top 10 tools.", + "minimum": 1, + "maximum": 100 + } + }, + "required": ["query"] + }`) +}