Add tool-search server (#3136)

Co-authored-by: 澄潭 <zty98751@alibaba-inc.com>
This commit is contained in:
Wangzy
2025-12-22 09:46:31 +08:00
committed by jingze
parent fef8ecc822
commit 6f4ef33590
11 changed files with 1191 additions and 5 deletions

View File

@@ -53,7 +53,6 @@ require (
github.com/cockroachdb/errors v1.9.1 // indirect
github.com/cockroachdb/logtags v0.0.0-20211118104740-dabe8e521a4f // indirect
github.com/cockroachdb/redact v1.1.3 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/deckarep/golang-set v1.7.1 // indirect
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/getsentry/sentry-go v0.12.0 // indirect

View File

@@ -185,10 +185,7 @@ github.com/getsentry/sentry-go v0.12.0/go.mod h1:NSap0JBYWzHND8oMbyi0+XZhUalc1TB
github.com/gin-contrib/sse v0.0.0-20190301062529-5545eab6dad3/go.mod h1:VJ0WA2NBN22VlZ2dKZQPAPnyWw5XTlK1KymzLKsr59s=
github.com/gin-gonic/gin v1.4.0/go.mod h1:OW2EZn3DO8Ln9oIKOvM++LBO+5UPHJJDH72/q/3rZdM=
github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98=
github.com/go-errors/errors v1.0.1 h1:LUHzmkK3GUKUrL/1gfBUxAHzcev3apQlezX/+O7ma6w=
github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q=
github.com/go-faker/faker/v4 v4.1.0 h1:ffuWmpDrducIUOO0QSKSF5Q2dxAht+dhsT9FvVHhPEI=
github.com/go-faker/faker/v4 v4.1.0/go.mod h1:uuNc0PSRxF8nMgjGrrrU4Nw5cF30Jc6Kd0/FUTTYbhg=
github.com/go-faster/city v1.0.1 h1:4WAxSZ3V2Ws4QRDrscLEDcibJY8uf41H6AhXDrNDcGw=
github.com/go-faster/city v1.0.1/go.mod h1:jKcUJId49qdW3L1qKHH/3wPeUstCVpVSXTM6vO3VcTw=
github.com/go-faster/errors v0.7.1 h1:MkJTnDoEdi9pDabt1dpWf7AA8/BaSYZqibYyhZ20AYg=
@@ -429,7 +426,6 @@ github.com/paulmach/protoscan v0.2.1/go.mod h1:SpcSwydNLrxUGSDvXvO0P7g7AuhJ7lcKf
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ=
github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=

View File

@@ -8,6 +8,7 @@ import (
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/higress/higress-api"
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/higress/higress-ops"
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag"
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/tool-search"
mcp_session "github.com/alibaba/higress/plugins/golang-filter/mcp-session"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
xds "github.com/cncf/xds/go/xds/type/v3"

View File

@@ -0,0 +1,144 @@
# Tool Search MCP Server
这是一个基于 Higress Golang Filter 实现的 MCP Server,用于提供工具语义搜索功能。当前实现**仅支持向量语义搜索**(基于 Milvus 向量数据库),**不包含全文检索或混合搜索**。
## 功能特性
- **向量语义搜索**:使用 OpenAI 兼容的 Embedding API 将用户查询转换为向量,并在 Milvus 中进行相似度检索
- **工具元数据支持**:从数据库中读取完整的工具定义(JSON 格式),并动态拼接工具名称
- **全量工具列表**:支持获取数据库中所有可用工具
- **可配置 Embedding 模型**:支持自定义模型、维度及 API 端点(如 DashScope)
- **Milvus 集成**:通过标准 gRPC 接口连接 Milvus 向量数据库
## 数据库要求Milvus
本服务依赖 **Milvus 向量数据库**,需预先创建集合(Collection),其 Schema 应包含以下字段:
| 字段名 | 类型 | 说明 |
|--------------|-------------------|-------------------------|
| `id` | VarChar(64) | 文档唯一 ID |
| `content` | VarChar(64) | 工具描述文本 |
| `metadata` | JSON | 完整的工具定义(必须包含 `name` 字段) |
| `vector` | FloatVector(1024) | embedding 向量 |
| `created_at` | Int64 | 创建时间(Unix 毫秒时间戳) |
## 配置参数
### 根级配置
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|--------------|--------|------|-----------------------------------------------------|------|
| `vector` | object | 是 | - | 向量数据库配置(见下文) |
| `embedding` | object | 是 | - | Embedding API 配置(见下文) |
| `description`| string | 否 | `"Tool search server for semantic similarity search"` | MCP Server 描述信息 |
### Vector 配置(`vector` 对象)
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|-------------|--------|------|--------------------|------|
| `type` | string | 是 | - | **必须为 `"milvus"`** |
| `host` | string | 是 | - | Milvus 服务地址(如 `localhost`) |
| `port` | int | 是 | - | Milvus gRPC 端口(如 `19530`) |
| `database` | string | 否 | `"default"` | Milvus 数据库名 |
| `tableName` | string | 否 | `"apig_mcp_tools"` | Milvus 集合名 |
| `username` | string | 否 | - | 认证用户名(可选) |
| `password` | string | 否 | - | 认证密码(可选) |
### Embedding 配置(`embedding` 对象)
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|--------------|--------|------|-----------------------------------------------------------|------|
| `apiKey` | string | 是 | - | Embedding 服务的 API Key |
| `baseURL` | string | 否 | `https://dashscope.aliyuncs.com/compatible-mode/v1` | OpenAI 兼容 API 的 Base URL |
| `model` | string | 否 | `text-embedding-v4` | 使用的 Embedding 模型 |
| `dimensions` | int | 否 | `1024` | 向量维度 |
## 配置示例
Tool Search MCP Server 也可以作为 Higress 的一个模块进行配置。以下是一个在 Higress ConfigMap 中配置 Tool Search 的示例:
```yaml
apiVersion: v1
kind: ConfigMap
metadata:
name: higress-config
namespace: higress-system
data:
higress: |
mcpServer:
enable: true
sse_path_suffix: "/sse"
redis:
address: "<Redis IP>:6379"
username: ""
password: ""
db: 0
match_list:
- path_rewrite_prefix: ""
upstream_type: ""
enable_path_rewrite: false
match_rule_domain: "*"
match_rule_path: "/mcp-servers/tool-search"
match_rule_type: "prefix"
servers:
- path: "/mcp-servers/tool-search"
name: "tool-search"
type: "tool-search"
config:
vector:
type: "milvus"
host: "localhost"
port: 19530
database: "default"
tableName: "apig_mcp_tools"
username: "root"
password: "Milvus"
maxTools: 1000
embedding:
apiKey: "your-dashscope-api-key"
baseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1"
model: "text-embedding-v4"
dimensions: 1024
description: "Higress 工具语义搜索服务"
```
## 工具搜索接口
Tool Search MCP Server 提供以下 MCP 工具:
### x_higress_tool_search
基于语义相似度搜索最相关的工具。
**输入参数**:
| 参数名 | 类型 | 必填 | 说明 |
|---------|--------|------|------|
| `query` | string | 是 | 查询语句,用于与工具描述进行语义相似度比较 |
| `topK` | int | 否 | 指定需要选择的工具数量,默认选择前 10 个工具(取值范围 1-100) |
**输出格式**:
```
{
"tools": [
{
"name": "server_name___tool_name",
"title": "Tool Title",
"description": "Tool description",
"inputSchema": {...},
"outputSchema": {...}
}
]
}
```
## 搜索实现
通过向量相似度进行搜索,索引配置如下:
- 使用 HNSW 索引算法进行向量索引
- 默认参数:M=8, efConstruction=64
- 相似度度量方式:内积(IP)

View File

@@ -0,0 +1,18 @@
{
"vector": {
"type": "milvus",
"vectorWeight": 0.5,
"tableName": "apig_mcp_tools",
"host": "localhost",
"port": 19530,
"database": "default",
"username": "root",
"password": "Milvus"
},
"embedding": {
"apiKey": "your-dashscope-api-key",
"baseURL": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"model": "text-embedding-v4",
"dimensions": 1024
}
}

View File

@@ -0,0 +1,79 @@
package tool_search
import (
"context"
"fmt"
"time"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
"github.com/openai/openai-go/v2"
"github.com/openai/openai-go/v2/option"
)
// EmbeddingClient turns text into embedding vectors by calling an
// OpenAI-compatible embeddings endpoint.
type EmbeddingClient struct {
	// client is the OpenAI SDK client used to issue embedding requests.
	client *openai.Client
	// model is the embedding model name sent with every request.
	model string
	// dimensions is the vector size requested from the API.
	dimensions int
}
// NewEmbeddingClient builds an EmbeddingClient for the OpenAI-compatible
// endpoint at baseURL, authenticated with apiKey. Embeddings are requested
// from model with the given number of dimensions. The underlying SDK client is
// configured with a 30-second request timeout.
func NewEmbeddingClient(apiKey, baseURL, model string, dimensions int) *EmbeddingClient {
	api.LogInfof("Creating EmbeddingClient with baseURL: %s, model: %s, dimensions: %d", baseURL, model, dimensions)

	sdkClient := openai.NewClient(
		option.WithAPIKey(apiKey),
		option.WithBaseURL(baseURL),
		option.WithRequestTimeout(30*time.Second),
	)

	embedder := &EmbeddingClient{
		client:     &sdkClient,
		model:      model,
		dimensions: dimensions,
	}
	return embedder
}
// GetEmbedding requests an embedding for text from the configured model and
// returns it as a []float32.
//
// The call is always bounded by a 30-second timeout derived from ctx. An error
// is returned when the API call fails or when the response carries no
// embedding data. Only the first entry of the response is used.
func (e *EmbeddingClient) GetEmbedding(ctx context.Context, text string) ([]float32, error) {
	api.LogInfof("Generating embedding for text (length: %d)", len(text))
	api.LogDebugf("Using model: %s, dimensions: %d", e.model, e.dimensions)

	// Bound the upstream call regardless of any deadline already on ctx.
	reqCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()

	request := openai.EmbeddingNewParams{
		Model:          e.model,
		Input:          openai.EmbeddingNewParamsInputUnion{OfString: openai.String(text)},
		Dimensions:     openai.Int(int64(e.dimensions)),
		EncodingFormat: openai.EmbeddingNewParamsEncodingFormatFloat,
	}

	api.LogDebugf("Calling OpenAI-compatible API for embedding generation")
	resp, err := e.client.Embeddings.New(reqCtx, request)
	if err != nil {
		api.LogErrorf("OpenAI-compatible API call failed: %v", err)
		return nil, fmt.Errorf("failed to generate embedding: %w", err)
	}
	if len(resp.Data) == 0 {
		api.LogErrorf("Empty embedding response from API")
		return nil, fmt.Errorf("empty embedding response")
	}

	raw := resp.Data[0].Embedding
	api.LogDebugf("Successfully received embedding from API")
	api.LogDebugf("Response data length: %d, embedding dimension: %d", len(resp.Data), len(raw))

	// The SDK returns float64 values; the vector store works with float32.
	narrowed := make([]float32, len(raw))
	for idx := range raw {
		narrowed[idx] = float32(raw[idx])
	}

	api.LogInfof("Embedding conversion completed, final dimension: %d", len(narrowed))
	return narrowed, nil
}

View File

@@ -0,0 +1,204 @@
package tool_search
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/schema"
"github.com/milvus-io/milvus-sdk-go/v2/client"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
)
// MilvusVectorStoreProvider reads tool documents from a single Milvus
// collection, either by listing them or by vector similarity search.
type MilvusVectorStoreProvider struct {
	// client is the Milvus SDK client used for all queries.
	client client.Client
	// collection is the name of the Milvus collection that is queried.
	collection string
	// dimensions is the embedding size supplied at construction time; it is
	// stored but not read by the methods visible in this file.
	dimensions int
}
// NewMilvusVectorStoreProvider connects to the Milvus instance described by
// cfg (host, port, database) and returns a provider bound to cfg.Collection.
// Credentials are sent only when both a username and a password are set.
// An error is returned when the Milvus client cannot be created.
func NewMilvusVectorStoreProvider(cfg *config.VectorDBConfig, dimensions int) (*MilvusVectorStoreProvider, error) {
	connCfg := client.Config{
		Address: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
		DBName:  cfg.Database,
	}

	hasCredentials := cfg.Username != "" && cfg.Password != ""
	if hasCredentials {
		connCfg.Username = cfg.Username
		connCfg.Password = cfg.Password
	}

	conn, err := client.NewClient(context.Background(), connCfg)
	if err != nil {
		return nil, fmt.Errorf("failed to create milvus client: %w", err)
	}

	provider := &MilvusVectorStoreProvider{
		client:     conn,
		collection: cfg.Collection,
		dimensions: dimensions,
	}
	return provider, nil
}
// ListAllDocs returns the documents stored in the provider's collection,
// reading the id, content, metadata and created_at fields.
//
// When limit is positive it is passed to Milvus as the query limit; otherwise
// the query is issued without an explicit limit. An empty (non-nil) slice is
// returned when the query yields no columns.
//
// Columns or values of an unexpected type are skipped instead of panicking,
// leaving the corresponding document field at its zero value. Metadata that
// cannot be decoded as JSON is replaced by an empty map, matching SearchDocs.
func (c *MilvusVectorStoreProvider) ListAllDocs(ctx context.Context, limit int) ([]schema.Document, error) {
	expr := ""
	outputFields := []string{"id", "content", "metadata", "created_at"}

	var (
		queryResult []entity.Column
		err         error
	)
	if limit > 0 {
		queryResult, err = c.client.Query(
			ctx,
			c.collection,
			[]string{}, // partitions
			expr,       // filter condition
			outputFields,
			client.WithLimit(int64(limit)),
		)
	} else {
		queryResult, err = c.client.Query(
			ctx,
			c.collection,
			[]string{}, // partitions
			expr,       // filter condition
			outputFields,
		)
	}
	if err != nil {
		return nil, fmt.Errorf("failed to query all documents: %w", err)
	}
	if len(queryResult) == 0 {
		return []schema.Document{}, nil
	}

	rowCount := queryResult[0].Len()
	documents := make([]schema.Document, 0, rowCount)

	for i := 0; i < rowCount; i++ {
		var (
			id        string
			content   string
			metadata  map[string]interface{}
			createdAt int64
		)

		for _, col := range queryResult {
			switch col.Name() {
			case "id":
				if idCol, ok := col.(*entity.ColumnVarChar); ok {
					if v, err := idCol.Get(i); err == nil {
						if s, ok := v.(string); ok {
							id = s
						}
					}
				}
			case "content":
				if contentCol, ok := col.(*entity.ColumnVarChar); ok {
					if v, err := contentCol.Get(i); err == nil {
						if s, ok := v.(string); ok {
							content = s
						}
					}
				}
			case "metadata":
				if metaCol, ok := col.(*entity.ColumnJSONBytes); ok {
					if v, err := metaCol.Get(i); err == nil {
						if raw, ok := v.([]byte); ok {
							if err := json.Unmarshal(raw, &metadata); err != nil {
								// Drop partially decoded data rather than expose it.
								metadata = make(map[string]interface{})
							}
						}
					}
				}
			case "created_at":
				if tsCol, ok := col.(*entity.ColumnInt64); ok {
					if v, err := tsCol.Get(i); err == nil {
						if ts, ok := v.(int64); ok {
							createdAt = ts
						}
					}
				}
			}
		}

		documents = append(documents, schema.Document{
			ID:       id,
			Content:  content,
			Metadata: metadata,
			// created_at is interpreted as Unix milliseconds.
			CreatedAt: time.UnixMilli(createdAt),
		})
	}

	return documents, nil
}
// SearchDocs runs an inner-product (IP) similarity search for vector against
// the "vector" field of the provider's collection and returns the matching
// documents with their scores.
//
// options may be nil; a nil options value or a non-positive TopK falls back to
// a TopK of 10. The HNSW search parameter ef starts at 16 and is raised to
// TopK when TopK is larger, because Milvus documents ef as having to be at
// least top_k.
//
// Document IDs are rendered with their default formatting, so both string and
// integer primary keys produce a usable ID; an ID that cannot be read is left
// empty. Metadata that cannot be decoded as JSON is replaced by an empty map.
func (c *MilvusVectorStoreProvider) SearchDocs(ctx context.Context, vector []float32, options *schema.SearchOptions) ([]schema.SearchResult, error) {
	const (
		defaultTopK = 10
		minEf       = 16
	)

	topK := defaultTopK
	if options != nil && options.TopK > 0 {
		topK = options.TopK
	}

	// ef must cover the requested number of neighbours.
	ef := minEf
	if topK > ef {
		ef = topK
	}
	sp, err := entity.NewIndexHNSWSearchParam(ef)
	if err != nil {
		return nil, fmt.Errorf("failed to build search param: %w", err)
	}

	outputFields := []string{"id", "content", "metadata"}
	searchResults, err := c.client.Search(
		ctx,
		c.collection,
		[]string{},   // partition names
		"",           // filter expression
		outputFields, // output fields
		[]entity.Vector{entity.FloatVector(vector)},
		"vector",  // anns_field
		entity.IP, // metric_type
		topK,
		sp,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to search documents: %w", err)
	}

	var results []schema.SearchResult
	for _, result := range searchResults {
		for i := 0; i < result.ResultCount; i++ {
			var docID string
			if rawID, err := result.IDs.Get(i); err == nil {
				docID = fmt.Sprint(rawID)
			}
			score := result.Scores[i]

			var content string
			var metadata map[string]interface{}
			for _, field := range result.Fields {
				switch field.Name() {
				case "content":
					if contentCol, ok := field.(*entity.ColumnVarChar); ok {
						if contentVal, err := contentCol.Get(i); err == nil {
							if contentStr, ok := contentVal.(string); ok {
								content = contentStr
							}
						}
					}
				case "metadata":
					if metaCol, ok := field.(*entity.ColumnJSONBytes); ok {
						if metaVal, err := metaCol.Get(i); err == nil {
							if metaBytes, ok := metaVal.([]byte); ok {
								if err := json.Unmarshal(metaBytes, &metadata); err != nil {
									metadata = make(map[string]interface{})
								}
							}
						}
					}
				}
			}

			results = append(results, schema.SearchResult{
				Document: schema.Document{
					ID:       docID,
					Content:  content,
					Metadata: metadata,
				},
				Score: float64(score),
			})
		}
	}

	return results, nil
}
// Close releases the underlying Milvus client and returns whatever error the
// client's own Close reports. It is a no-op when no client is set.
func (c *MilvusVectorStoreProvider) Close() error {
	if c.client == nil {
		return nil
	}
	return c.client.Close()
}

View File

@@ -0,0 +1,237 @@
package tool_search
import (
"context"
"fmt"
"time"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/schema"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
)
// SearchService answers tool lookups by combining an embedding client (to
// vectorise queries) with a Milvus-backed document store.
type SearchService struct {
	// milvusProvider executes list and similarity queries against Milvus.
	milvusProvider *MilvusVectorStoreProvider
	// config is the Milvus connection configuration the provider was built from.
	config *config.VectorDBConfig
	// tableName is the Milvus collection name (used in log output).
	tableName string
	// dimensions is the embedding vector size.
	dimensions int
	// maxTools is the limit passed when listing all tools. Per the original
	// note it is a fixed value intended for unit tests only.
	maxTools int
	// embeddingClient converts query text into vectors.
	embeddingClient *EmbeddingClient
}
// NewSearchService connects to Milvus with the given connection settings and
// returns a SearchService that uses embeddingClient for query vectorisation.
//
// NOTE: when the Milvus provider cannot be created the failure is logged and
// nil is returned, so callers must check the result before using it.
func NewSearchService(host string, port int, database, username, password, tableName string, embeddingClient *EmbeddingClient, dimensions int, maxTools int) *SearchService {
	milvusCfg := &config.VectorDBConfig{
		Provider:   "milvus",
		Host:       host,
		Port:       port,
		Database:   database,
		Collection: tableName,
		Username:   username,
		Password:   password,
	}

	store, err := NewMilvusVectorStoreProvider(milvusCfg, dimensions)
	if err != nil {
		api.LogErrorf("Failed to create Milvus provider: %v", err)
		return nil
	}

	service := &SearchService{
		milvusProvider:  store,
		config:          milvusCfg,
		tableName:       tableName,
		dimensions:      dimensions,
		maxTools:        maxTools, // fixed value supplied by the caller
		embeddingClient: embeddingClient,
	}
	return service
}
// ToolSearchResult is the payload produced by a tool lookup: the list of tool
// definitions, serialised under the "tools" JSON key.
type ToolSearchResult struct {
	// Tools holds one definition per matching tool.
	Tools []ToolDefinition `json:"tools"`
}
// ToolDefinition is a free-form tool description keyed by field name (for
// example "name" and "description"), as returned in search results.
type ToolDefinition map[string]any
// SearchTools embeds query with the service's embedding client, runs a vector
// similarity search for the topK closest tools and returns them as tool
// definitions.
//
// An error is returned when the service (or one of its dependencies) is not
// initialised, when the embedding cannot be generated, or when the vector
// search fails. The nil check matters because NewSearchService returns nil
// when Milvus cannot be reached.
func (s *SearchService) SearchTools(ctx context.Context, query string, topK int) (*ToolSearchResult, error) {
	if s == nil || s.embeddingClient == nil || s.milvusProvider == nil {
		api.LogErrorf("Tool search service is not initialized")
		return nil, fmt.Errorf("tool search service is not initialized")
	}

	api.LogInfof("Starting tool search for query: '%s', topK: %d", query, topK)

	// Generate vector embedding for the query.
	vector, err := s.embeddingClient.GetEmbedding(ctx, query)
	if err != nil {
		api.LogErrorf("Failed to generate embedding for query '%s': %v", query, err)
		return nil, fmt.Errorf("failed to generate embedding: %w", err)
	}
	api.LogInfof("Embedding generated successfully, vector dimension: %d", len(vector))

	// Perform vector search.
	records, err := s.searchToolsInDB(query, vector, topK)
	if err != nil {
		api.LogErrorf("Failed to search tools: %v", err)
		return nil, fmt.Errorf("failed to search tools: %w", err)
	}
	api.LogInfof("Vector search completed, found %d records", len(records))

	return s.convertRecordsToResult(records), nil
}
// convertRecordsToResult maps database records onto tool definitions.
//
// A record with metadata uses that metadata map directly as its definition
// (the map is shared, not copied); a record without metadata gets a minimal
// definition built from its name and content. In both cases the "name" entry
// is then set to the record's Name.
func (s *SearchService) convertRecordsToResult(records []ToolRecord) *ToolSearchResult {
	api.LogInfof("Converting %d records to tool definitions", len(records))

	definitions := make([]ToolDefinition, 0, len(records))
	for idx := range records {
		rec := records[idx]

		var definition ToolDefinition
		if len(rec.Metadata) == 0 {
			api.LogDebugf("No metadata found for tool %s, using basic definition", rec.Name)
			definition = ToolDefinition{
				"name":        rec.Name,
				"description": rec.Content,
			}
		} else {
			definition = ToolDefinition(rec.Metadata)
			api.LogDebugf("Successfully parsed metadata for tool %s", rec.Name)
		}

		// The record name always wins over whatever the metadata carried.
		definition["name"] = rec.Name

		definitions = append(definitions, definition)
		api.LogDebugf("Tool %d: %s - %s", idx+1, definition["name"], rec.Content)
	}

	api.LogInfof("Successfully converted %d tools", len(definitions))
	return &ToolSearchResult{Tools: definitions}
}
// GetAllTools lists the tools stored in the database (bounded by the service's
// maxTools limit) and returns them as tool definitions.
//
// Record conversion is delegated to convertRecordsToResult so that listing and
// searching produce identically shaped definitions. An error is returned when
// the service is not initialised (NewSearchService returns nil on connection
// failure) or when the database query fails.
func (s *SearchService) GetAllTools() (*ToolSearchResult, error) {
	api.LogInfo("Retrieving all tools")

	if s == nil || s.milvusProvider == nil {
		api.LogErrorf("Tool search service is not initialized")
		return nil, fmt.Errorf("tool search service is not initialized")
	}

	records, err := s.getAllToolsFromDB()
	if err != nil {
		api.LogErrorf("Failed to get all tools: %v", err)
		return nil, fmt.Errorf("failed to get all tools: %w", err)
	}
	api.LogInfof("Found %d tools in database", len(records))

	return s.convertRecordsToResult(records), nil
}
// ToolRecord is one tool row as read from the vector database.
type ToolRecord struct {
	// ID is the document identifier.
	ID string `json:"id"`
	// Name is the tool name, taken from the metadata "name" entry when that
	// entry is a string.
	Name string `json:"name"`
	// Content is the stored text of the document.
	Content string `json:"content"`
	// Metadata is the decoded JSON metadata of the document.
	Metadata map[string]any `json:"metadata"`
}
// searchToolsInDB runs a vector similarity search for vector and converts the
// hits into ToolRecords. query is only used for logging.
//
// The search runs under its own 10-second timeout derived from
// context.Background().
// NOTE(review): because the signature has no context parameter, cancellation
// of the caller's context does not reach the Milvus call — confirm whether
// that is intended.
func (s *SearchService) searchToolsInDB(query string, vector []float32, topK int) ([]ToolRecord, error) {
	api.LogInfof("Performing vector search for query: '%s', topK: %d", query, topK)

	searchCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	hits, err := s.milvusProvider.SearchDocs(searchCtx, vector, &schema.SearchOptions{TopK: topK})
	if err != nil {
		api.LogErrorf("Vector search failed: %v", err)
		return nil, fmt.Errorf("failed to perform vector search: %w", err)
	}

	var records []ToolRecord
	for idx := range hits {
		doc := hits[idx].Document
		record := ToolRecord{
			ID:       doc.ID,
			Content:  doc.Content,
			Metadata: doc.Metadata,
		}
		// The tool name lives in the metadata; leave it empty when absent
		// or not a string.
		if toolName, isString := doc.Metadata["name"].(string); isString {
			record.Name = toolName
		}
		records = append(records, record)
	}

	api.LogInfof("Vector search completed, found %d results", len(records))
	return records, nil
}
// getAllToolsFromDB lists up to maxTools documents from the collection and
// converts them into ToolRecords. The query runs under a 10-second timeout
// derived from context.Background().
func (s *SearchService) getAllToolsFromDB() ([]ToolRecord, error) {
	api.LogInfof("Executing GetAllTools query from collection: %s", s.tableName)

	listCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	docs, err := s.milvusProvider.ListAllDocs(listCtx, s.maxTools)
	if err != nil {
		api.LogErrorf("Failed to list documents: %v", err)
		return nil, fmt.Errorf("failed to list documents: %w", err)
	}

	var records []ToolRecord
	for idx := range docs {
		doc := docs[idx]
		record := ToolRecord{
			ID:       doc.ID,
			Content:  doc.Content,
			Metadata: doc.Metadata,
		}
		// The tool name lives in the metadata; leave it empty when absent
		// or not a string.
		if toolName, isString := doc.Metadata["name"].(string); isString {
			record.Name = toolName
		}
		records = append(records, record)
	}

	api.LogInfof("GetAllTools query completed, found %d tools", len(records))
	return records, nil
}

View File

@@ -0,0 +1,196 @@
package tool_search
import (
"errors"
"fmt"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
"github.com/mark3labs/mcp-go/mcp"
)
const (
	// Version is the version string reported by the tool-search MCP server.
	Version = "1.0.0"

	// Defaults applied when the corresponding optional configuration fields
	// are absent.
	defaultTableName  = "apig_mcp_tools"
	defaultBaseURL    = "https://dashscope.aliyuncs.com/compatible-mode/v1"
	defaultModel      = "text-embedding-v4"
	defaultDimensions = 1024

	// fixedMaxTools is the hard-coded maximum number of tools (1000); per the
	// original note it is intended for unit tests only.
	fixedMaxTools = 1000
)
// init registers the tool-search server configuration with the global MCP
// server registry under the type name "tool-search".
func init() {
	const serverType = "tool-search"
	common.GlobalRegistry.RegisterServer(serverType, new(ToolSearchConfig))
}
// VectorConfig holds the vector database settings of the tool-search server.
type VectorConfig struct {
	// Type is the vector database kind.
	Type string `json:"type"`
	// VectorWeight is declared for the "vectorWeight" key.
	// NOTE(review): no parsing code visible in this file reads it — confirm
	// whether it is still needed.
	VectorWeight float64 `json:"vectorWeight"`
	// TableName is the collection that stores the tools.
	TableName string `json:"tableName"`
	// Host and Port locate the database server.
	Host string `json:"host"`
	Port int    `json:"port"`
	// Database is the database name on the server.
	Database string `json:"database"`
	// Username and Password are optional credentials.
	Username string `json:"username"`
	Password string `json:"password"`
}
// EmbeddingConfig holds the settings of the OpenAI-compatible embedding
// service used to vectorise queries.
type EmbeddingConfig struct {
	// APIKey authenticates requests to the embedding service.
	APIKey string `json:"apiKey"`
	// BaseURL is the root URL of the OpenAI-compatible API.
	BaseURL string `json:"baseURL"`
	// Model is the embedding model name.
	Model string `json:"model"`
	// Dimensions is the requested embedding vector size.
	Dimensions int `json:"dimensions"`
}
// ToolSearchConfig is the parsed configuration of a tool-search MCP server.
// It is the type registered with the global server registry, and it provides
// ParseConfig and NewServer.
type ToolSearchConfig struct {
	// Vector configures the vector database connection.
	Vector VectorConfig `json:"vector"`
	// Embedding configures the embedding service.
	Embedding EmbeddingConfig `json:"embedding"`
	// description is the server description passed as MCP instructions.
	description string
}
// ParseConfig populates c from the raw server configuration map.
//
// The "vector" and "embedding" entries are required and must be objects;
// "description" is optional and defaults to a generic description. An error is
// returned when a required section is missing or fails validation.
//
// Only non-sensitive fields are logged: the raw map contains the embedding API
// key and the database password, so it must not be written to the log.
func (c *ToolSearchConfig) ParseConfig(config map[string]any) error {
	// Parse vector configuration.
	vectorConfig, ok := config["vector"].(map[string]any)
	if !ok {
		return errors.New("missing vector configuration")
	}
	if err := c.parseVectorConfig(vectorConfig); err != nil {
		return fmt.Errorf("failed to parse vector config: %w", err)
	}

	// Parse embedding configuration.
	embeddingConfig, ok := config["embedding"].(map[string]any)
	if !ok {
		return errors.New("missing embedding configuration")
	}
	if err := c.parseEmbeddingConfig(embeddingConfig); err != nil {
		return fmt.Errorf("failed to parse embedding config: %w", err)
	}

	// Optional description.
	if description, ok := config["description"].(string); ok {
		c.description = description
	} else {
		c.description = "Tool search server for semantic similarity search"
	}

	api.LogDebugf(
		"ToolSearchConfig ParseConfig: vector={type:%s host:%s port:%d database:%s tableName:%s} embedding={baseURL:%s model:%s dimensions:%d}",
		c.Vector.Type, c.Vector.Host, c.Vector.Port, c.Vector.Database, c.Vector.TableName,
		c.Embedding.BaseURL, c.Embedding.Model, c.Embedding.Dimensions,
	)
	return nil
}
// parseVectorConfig validates the "vector" section and stores it in c.Vector.
//
// Required: type (must be "milvus"), host (non-empty string) and port (an
// integer in 1-65535, accepted as float64, int or int64 because the decoder
// that produced the map determines the numeric type). Optional: database
// (default "default"), tableName (default defaultTableName), username and
// password. Empty database or tableName strings are treated as absent.
func (c *ToolSearchConfig) parseVectorConfig(config map[string]any) error {
	vectorType, ok := config["type"].(string)
	if !ok {
		return errors.New("missing vector.type")
	}
	c.Vector.Type = vectorType
	if c.Vector.Type != "milvus" {
		return fmt.Errorf("unsupported vector.type: %s, only 'milvus' is supported", c.Vector.Type)
	}

	host, ok := config["host"].(string)
	if !ok || host == "" {
		return errors.New("missing vector.host")
	}
	c.Vector.Host = host

	switch port := config["port"].(type) {
	case float64:
		// Check the range before converting so the conversion is well defined,
		// then reject fractional values instead of truncating them silently.
		if port < 1 || port > 65535 || port != float64(int(port)) {
			return fmt.Errorf("invalid vector.port: %v", port)
		}
		c.Vector.Port = int(port)
	case int:
		c.Vector.Port = port
	case int64:
		if port < 1 || port > 65535 {
			return fmt.Errorf("invalid vector.port: %d", port)
		}
		c.Vector.Port = int(port)
	default:
		return errors.New("missing vector.port")
	}
	if c.Vector.Port < 1 || c.Vector.Port > 65535 {
		return fmt.Errorf("invalid vector.port: %d", c.Vector.Port)
	}

	c.Vector.Database = "default" // default database
	if database, ok := config["database"].(string); ok && database != "" {
		c.Vector.Database = database
	}

	c.Vector.TableName = defaultTableName
	if tableName, ok := config["tableName"].(string); ok && tableName != "" {
		c.Vector.TableName = tableName
	}

	if username, ok := config["username"].(string); ok {
		c.Vector.Username = username
	}
	if password, ok := config["password"].(string); ok {
		c.Vector.Password = password
	}

	// maxTools is intentionally not parsed; a fixed value is used instead.
	return nil
}
// parseEmbeddingConfig validates the "embedding" section and stores it in
// c.Embedding.
//
// apiKey is required and must be a non-empty string. baseURL and model fall
// back to defaultBaseURL and defaultModel when absent or empty. dimensions
// falls back to defaultDimensions when absent; when present it must be a
// positive integer (accepted as float64, int or int64).
func (c *ToolSearchConfig) parseEmbeddingConfig(config map[string]any) error {
	apiKey, ok := config["apiKey"].(string)
	if !ok || apiKey == "" {
		return errors.New("missing embedding.apiKey")
	}
	c.Embedding.APIKey = apiKey

	c.Embedding.BaseURL = defaultBaseURL
	if baseURL, ok := config["baseURL"].(string); ok && baseURL != "" {
		c.Embedding.BaseURL = baseURL
	}

	c.Embedding.Model = defaultModel
	if model, ok := config["model"].(string); ok && model != "" {
		c.Embedding.Model = model
	}

	switch dimensions := config["dimensions"].(type) {
	case float64:
		// Reject fractional or non-positive values instead of truncating them.
		if dimensions < 1 || dimensions > 1<<30 || dimensions != float64(int(dimensions)) {
			return fmt.Errorf("invalid embedding.dimensions: %v", dimensions)
		}
		c.Embedding.Dimensions = int(dimensions)
	case int:
		if dimensions < 1 {
			return fmt.Errorf("invalid embedding.dimensions: %d", dimensions)
		}
		c.Embedding.Dimensions = dimensions
	case int64:
		if dimensions < 1 || dimensions > 1<<30 {
			return fmt.Errorf("invalid embedding.dimensions: %d", dimensions)
		}
		c.Embedding.Dimensions = int(dimensions)
	default:
		c.Embedding.Dimensions = defaultDimensions
	}

	return nil
}
// NewServer builds the tool-search MCP server named serverName from the parsed
// configuration and registers the x_higress_tool_search tool on it.
//
// An error is returned when the search service cannot be created.
// NewSearchService reports a Milvus connection failure by returning nil, and
// registering a handler around a nil service would panic on the first request.
func (c *ToolSearchConfig) NewServer(serverName string) (*common.MCPServer, error) {
	mcpServer := common.NewMCPServer(
		serverName,
		Version,
		common.WithInstructions(c.description),
	)

	// Create embedding client.
	embeddingClient := NewEmbeddingClient(c.Embedding.APIKey, c.Embedding.BaseURL, c.Embedding.Model, c.Embedding.Dimensions)

	// Create search service using the fixed fixedMaxTools value.
	searchService := NewSearchService(
		c.Vector.Host,
		c.Vector.Port,
		c.Vector.Database,
		c.Vector.Username,
		c.Vector.Password,
		c.Vector.TableName,
		embeddingClient,
		c.Embedding.Dimensions,
		fixedMaxTools,
	)
	if searchService == nil {
		return nil, fmt.Errorf("failed to create search service for vector store %s:%d", c.Vector.Host, c.Vector.Port)
	}

	// Add tool search tool.
	mcpServer.AddTool(
		mcp.NewToolWithRawSchema("x_higress_tool_search", "Higress MCP Tools Searcher", GetToolSearchSchema()),
		HandleToolSearch(searchService),
	)

	return mcpServer, nil
}

View File

@@ -0,0 +1,198 @@
package tool_search
import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"strconv"
	"testing"
	"time"

	"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
	"github.com/mark3labs/mcp-go/mcp"
)
// mockCommonCAPI is a test stand-in for the Envoy common C API: it prints log
// messages and records them in memory.
type mockCommonCAPI struct {
	// logs collects every message passed to Log, in call order.
	logs []string
}
// Log writes the message, prefixed with its level, to standard output and
// appends it to the recorded logs.
func (m *mockCommonCAPI) Log(level api.LogType, message string) {
	fmt.Fprintf(os.Stdout, "[%s] %s\n", level, message)
	m.logs = append(m.logs, message)
}
// LogLevel reports the Debug level so that every log statement is emitted
// while the tests run.
func (m *mockCommonCAPI) LogLevel() api.LogType {
	level := api.Debug
	return level
}
// TestServer is a local functional test. It needs a reachable Milvus instance
// and an OpenAI-compatible embedding endpoint, configured through the TEST_*
// environment variables (with local defaults).
//
// The test is skipped when Milvus cannot be reached: NewSearchService returns
// nil in that case, and calling methods on the nil service would panic.
func TestServer(t *testing.T) {
	// Setup mock API for logging.
	mockAPI := &mockCommonCAPI{}
	api.SetCommonCAPI(mockAPI)

	// Load configuration from environment variables or use defaults.
	config := map[string]any{
		"vector": map[string]any{
			"type":         "milvus",
			"vectorWeight": 0.6,
			"tableName":    getEnvOrDefault("TEST_TABLE_NAME", "apig_mcp_tools"),
			"host":         getEnvOrDefault("TEST_MILVUS_HOST", "localhost"),
			"port":         getEnvOrDefaultInt("TEST_MILVUS_PORT", 19530),
			"database":     getEnvOrDefault("TEST_MILVUS_DATABASE", "default"),
			"username":     getEnvOrDefault("TEST_MILVUS_USERNAME", "root"),
			"password":     getEnvOrDefault("TEST_MILVUS_PASSWORD", "Milvus"),
			"maxTools":     getEnvOrDefaultInt("TEST_MAX_TOOLS", 1000),
		},
		"embedding": map[string]any{
			"apiKey":     getEnvOrDefault("TEST_API_KEY", "your-dashscope-api-key"),
			"baseURL":    getEnvOrDefault("TEST_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
			"model":      getEnvOrDefault("TEST_MODEL", "text-embedding-v4"),
			"dimensions": 1024,
		},
		"description": "Test MCP Tools Search Server",
	}

	// Create configuration instance.
	toolSearchConfig := &ToolSearchConfig{}
	if err := toolSearchConfig.ParseConfig(config); err != nil {
		t.Fatalf("Failed to parse config: %v", err)
	}

	vectorConfig := config["vector"].(map[string]any)
	embeddingConfig := config["embedding"].(map[string]any)

	// Probe the database connection first so the test can be skipped cleanly
	// when no Milvus instance is available.
	embeddingClient := NewEmbeddingClient(
		embeddingConfig["apiKey"].(string),
		embeddingConfig["baseURL"].(string),
		embeddingConfig["model"].(string),
		embeddingConfig["dimensions"].(int),
	)
	searchService := NewSearchService(
		vectorConfig["host"].(string),
		vectorConfig["port"].(int),
		vectorConfig["database"].(string),
		vectorConfig["username"].(string),
		vectorConfig["password"].(string),
		vectorConfig["tableName"].(string),
		embeddingClient,
		embeddingConfig["dimensions"].(int),
		getEnvOrDefaultInt("TEST_MAX_TOOLS", 1000),
	)
	if searchService == nil {
		t.Skip("Milvus is not reachable with the configured TEST_MILVUS_* settings; skipping functional test")
	}
	t.Cleanup(func() {
		if err := searchService.milvusProvider.Close(); err != nil {
			t.Logf("Failed to close Milvus client: %v", err)
		}
	})

	// Create MCP Server.
	if _, err := toolSearchConfig.NewServer("test-tool-search"); err != nil {
		t.Fatalf("Failed to create server: %v", err)
	}

	// Test GetAllTools.
	t.Logf("\n=== Testing GetAllTools ===")
	allTools, err := searchService.GetAllTools()
	if err != nil {
		t.Logf("GetAllTools failed: %v", err)
	} else {
		t.Logf("Found %d tools:", len(allTools.Tools))
		for i, tool := range allTools.Tools {
			if i < 3 { // Show only first 3 tools
				toolJSON, _ := json.MarshalIndent(tool, "", " ")
				t.Logf("Tool %d: %s", i+1, string(toolJSON))
			}
		}
		if len(allTools.Tools) > 3 {
			t.Logf("... and %d more tools", len(allTools.Tools)-3)
		}
	}

	// Test tool search with timing.
	t.Logf("\n=== Testing Tool Search ===")
	testQueries := []string{
		"weather data",
		"database query",
		"file operations",
		"HTTP requests",
		"library documents",
	}

	for _, query := range testQueries {
		t.Logf("\n--- Testing query: '%s' ---", query)

		// Create MCP tool call request.
		request := mcp.CallToolRequest{
			Params: struct {
				Name      string                 `json:"name"`
				Arguments map[string]interface{} `json:"arguments,omitempty"`
				Meta      *struct {
					ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"`
				} `json:"_meta,omitempty"`
			}{
				Name: "x_higress_tool_search",
				Arguments: map[string]interface{}{
					"query": query,
					"topK":  3,
				},
			},
		}

		// Get tool handler.
		handler := HandleToolSearch(searchService)

		// Execute search with timing.
		start := time.Now()
		result, err := handler(context.Background(), request)
		duration := time.Since(start)
		if err != nil {
			t.Logf("Search failed: %v", err)
			continue
		}

		// Print results with timing information.
		t.Logf("Search completed in %v", duration)
		if len(result.Content) > 0 {
			if textContent, ok := result.Content[0].(mcp.TextContent); ok {
				var toolsResult map[string]interface{}
				if err := json.Unmarshal([]byte(textContent.Text), &toolsResult); err == nil {
					toolsJSON, _ := json.MarshalIndent(toolsResult, "", " ")
					t.Logf("Tools Result: %s", string(toolsJSON))
				} else {
					t.Logf("Text Content: %s", textContent.Text)
				}
			}
		}
	}

	// Test configuration validation.
	t.Logf("\n=== Configuration Validation ===")
	t.Logf("Host: %s", vectorConfig["host"])
	t.Logf("Port: %d", vectorConfig["port"])
	t.Logf("Database: %s", vectorConfig["database"])
	t.Logf("Table Name: %s", vectorConfig["tableName"])
	t.Logf("Vector Weight: %f", vectorConfig["vectorWeight"])
	t.Logf("Text Weight: %f", 1.0-vectorConfig["vectorWeight"].(float64))
	t.Logf("Model: %s", embeddingConfig["model"])
	t.Logf("Dimensions: %d", embeddingConfig["dimensions"])
	t.Logf("API Base URL: %s", embeddingConfig["baseURL"])
	t.Logf("\n=== Test completed ===")
}
// getEnvOrDefault returns the value of the environment variable key, or
// defaultValue when the variable is unset or empty.
func getEnvOrDefault(key, defaultValue string) string {
	fromEnv := os.Getenv(key)
	if fromEnv == "" {
		return defaultValue
	}
	return fromEnv
}
// getEnvOrDefaultInt returns the environment variable key parsed as a base-10
// integer, or defaultValue when the variable is unset, empty, or not a valid
// integer. Values with trailing garbage such as "12abc" are rejected as a
// whole rather than partially parsed.
func getEnvOrDefaultInt(key string, defaultValue int) int {
	valueStr := os.Getenv(key)
	if valueStr == "" {
		return defaultValue
	}
	parsed, err := strconv.Atoi(valueStr)
	if err != nil {
		return defaultValue
	}
	return parsed
}

View File

@@ -0,0 +1,114 @@
package tool_search
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
"github.com/mark3labs/mcp-go/mcp"
)
// HandleToolSearch returns the MCP tool handler behind x_higress_tool_search.
//
// The handler reads two arguments from the request: "query" (required,
// non-empty string) and "topK" (optional number). topK defaults to 10; a value
// of an unsupported type, or one outside 1-100, is replaced by 10 with a
// warning. The search itself runs under a 30-second timeout, and the result is
// returned as a single text content item holding the JSON document
// {"tools": [...]}.
func HandleToolSearch(searchService *SearchService) common.ToolHandlerFunc {
	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		api.LogInfo("HandleToolSearch called")

		args := request.Params.Arguments
		api.LogDebugf("Request arguments: %+v", args)

		// query is mandatory and must be a non-empty string.
		query, isString := args["query"].(string)
		if !isString {
			api.LogErrorf("Invalid query argument type: %T", args["query"])
			return nil, fmt.Errorf("invalid query argument")
		}
		if query == "" {
			api.LogError("Empty query provided")
			return nil, fmt.Errorf("query cannot be empty")
		}

		// topK is optional; JSON decoding yields float64, direct callers may
		// pass int or int64.
		topK := 10
		if rawTopK, present := args["topK"]; present {
			switch n := rawTopK.(type) {
			case int:
				topK = n
			case int64:
				topK = int(n)
			case float64:
				topK = int(n)
			default:
				api.LogWarnf("Invalid topK argument type: %T, using default: %d", rawTopK, topK)
			}

			if outOfRange := topK <= 0 || topK > 100; outOfRange {
				api.LogWarnf("Invalid topK value: %d, using default: 10", topK)
				topK = 10
			}
		}

		api.LogInfof("Parsed parameters - query: '%s', topK: %d", query, topK)

		// Bound the whole search (embedding + vector lookup).
		searchCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
		defer cancel()

		result, err := searchService.SearchTools(searchCtx, query, topK)
		if err != nil {
			api.LogErrorf("Search failed: %v", err)
			return nil, fmt.Errorf("failed to search tools: %w", err)
		}
		api.LogInfof("Search completed successfully, found %d tools", len(result.Tools))

		payload, err := json.Marshal(map[string]interface{}{"tools": result.Tools})
		if err != nil {
			api.LogErrorf("Failed to marshal response: %v", err)
			return nil, fmt.Errorf("failed to marshal search results: %w", err)
		}
		api.LogDebugf("Response marshaled successfully, JSON length: %d", len(payload))

		api.LogDebugf("Returning MCP CallToolResult")
		textContent := mcp.TextContent{
			Type: "text",
			Text: string(payload),
		}
		return &mcp.CallToolResult{Content: []mcp.Content{textContent}}, nil
	}
}
// GetToolSearchSchema returns the JSON Schema describing the input of the
// x_higress_tool_search tool: a required string "query" and an optional
// integer "topK" between 1 and 100.
func GetToolSearchSchema() json.RawMessage {
	const inputSchema = `{
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Query statement for semantic similarity comparison with tool descriptions"
},
"topK": {
"type": "integer",
"description": "Specify how many tools need to be selected, default is to select the top 10 tools.",
"minimum": 1,
"maximum": 100
}
},
"required": ["query"]
}`
	return json.RawMessage(inputSchema)
}