mirror of
https://github.com/alibaba/higress.git
synced 2026-02-06 15:10:54 +08:00
Add tool-search server (#3136)
Co-authored-by: 澄潭 <zty98751@alibaba-inc.com>
This commit is contained in:
@@ -53,7 +53,6 @@ require (
|
||||
github.com/cockroachdb/errors v1.9.1 // indirect
|
||||
github.com/cockroachdb/logtags v0.0.0-20211118104740-dabe8e521a4f // indirect
|
||||
github.com/cockroachdb/redact v1.1.3 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/deckarep/golang-set v1.7.1 // indirect
|
||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||
github.com/getsentry/sentry-go v0.12.0 // indirect
|
||||
|
||||
@@ -185,10 +185,7 @@ github.com/getsentry/sentry-go v0.12.0/go.mod h1:NSap0JBYWzHND8oMbyi0+XZhUalc1TB
|
||||
github.com/gin-contrib/sse v0.0.0-20190301062529-5545eab6dad3/go.mod h1:VJ0WA2NBN22VlZ2dKZQPAPnyWw5XTlK1KymzLKsr59s=
|
||||
github.com/gin-gonic/gin v1.4.0/go.mod h1:OW2EZn3DO8Ln9oIKOvM++LBO+5UPHJJDH72/q/3rZdM=
|
||||
github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98=
|
||||
github.com/go-errors/errors v1.0.1 h1:LUHzmkK3GUKUrL/1gfBUxAHzcev3apQlezX/+O7ma6w=
|
||||
github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q=
|
||||
github.com/go-faker/faker/v4 v4.1.0 h1:ffuWmpDrducIUOO0QSKSF5Q2dxAht+dhsT9FvVHhPEI=
|
||||
github.com/go-faker/faker/v4 v4.1.0/go.mod h1:uuNc0PSRxF8nMgjGrrrU4Nw5cF30Jc6Kd0/FUTTYbhg=
|
||||
github.com/go-faster/city v1.0.1 h1:4WAxSZ3V2Ws4QRDrscLEDcibJY8uf41H6AhXDrNDcGw=
|
||||
github.com/go-faster/city v1.0.1/go.mod h1:jKcUJId49qdW3L1qKHH/3wPeUstCVpVSXTM6vO3VcTw=
|
||||
github.com/go-faster/errors v0.7.1 h1:MkJTnDoEdi9pDabt1dpWf7AA8/BaSYZqibYyhZ20AYg=
|
||||
@@ -429,7 +426,6 @@ github.com/paulmach/protoscan v0.2.1/go.mod h1:SpcSwydNLrxUGSDvXvO0P7g7AuhJ7lcKf
|
||||
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
|
||||
github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ=
|
||||
github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
|
||||
github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/higress/higress-api"
|
||||
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/higress/higress-ops"
|
||||
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag"
|
||||
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/tool-search"
|
||||
mcp_session "github.com/alibaba/higress/plugins/golang-filter/mcp-session"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
xds "github.com/cncf/xds/go/xds/type/v3"
|
||||
|
||||
144
plugins/golang-filter/mcp-server/servers/tool-search/README.md
Normal file
144
plugins/golang-filter/mcp-server/servers/tool-search/README.md
Normal file
@@ -0,0 +1,144 @@
|
||||
# Tool Search MCP Server
|
||||
|
||||
这是一个基于 Higress Golang Filter 实现的 MCP Server,用于提供工具语义搜索功能。当前实现**仅支持向量语义搜索**(基于 Milvus 向量数据库),**不包含全文检索或混合搜索**。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- **向量语义搜索**:使用 OpenAI 兼容的 Embedding API 将用户查询转换为向量,并在 Milvus 中进行相似度检索
|
||||
- **工具元数据支持**:从数据库中读取完整的工具定义(JSON 格式),并动态拼接工具名称
|
||||
- **全量工具列表**:支持获取数据库中所有可用工具
|
||||
- **可配置 Embedding 模型**:支持自定义模型、维度及 API 端点(如 DashScope)
|
||||
- **Milvus 集成**:通过标准 gRPC 接口连接 Milvus 向量数据库
|
||||
|
||||
## 数据库要求(Milvus)
|
||||
|
||||
本服务依赖 **Milvus 向量数据库**,需预先创建集合(Collection),其 Schema 应包含以下字段:
|
||||
|
||||
| 字段名 | 类型 | 说明 |
|
||||
|--------------|-------------------|-------------------------|
|
||||
| `id` | VarChar(64) | 文档唯一 ID |
|
||||
| `content` | VarChar(64) | 工具描述文本 |
|
||||
| `metadata` | JSON | 完整的工具定义(必须包含 `name` 字段) |
|
||||
| `vector` | FloatVector(1024) | embedding 向量 |
|
||||
| `metadata` | Int64 | 创建时间 |
|
||||
|
||||
|
||||
## 配置参数
|
||||
|
||||
### 根级配置
|
||||
|
||||
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||
|--------------|--------|------|-----------------------------------------------------|------|
|
||||
| `vector` | object | 是 | - | 向量数据库配置(见下文) |
|
||||
| `embedding` | object | 是 | - | Embedding API 配置(见下文) |
|
||||
| `description`| string | 否 | `"Tool search server for semantic similarity search"` | MCP Server 描述信息 |
|
||||
|
||||
### Vector 配置(`vector` 对象)
|
||||
|
||||
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||
|-------------|--------|------|--------------------|------|
|
||||
| `type` | string | 是 | - | **必须为 `"milvus"`** |
|
||||
| `host` | string | 是 | - | Milvus 服务地址(如 `localhost`) |
|
||||
| `port` | int | 是 | - | Milvus gRPC 端口(如 `19530`) |
|
||||
| `database` | string | 否 | `"default"` | Milvus 数据库名 |
|
||||
| `tableName` | string | 否 | `"apig_mcp_tools"` | Milvus 集合名 |
|
||||
| `username` | string | 否 | - | 认证用户名(可选) |
|
||||
| `password` | string | 否 | - | 认证密码(可选) |
|
||||
|
||||
### Embedding 配置(`embedding` 对象)
|
||||
|
||||
| 参数 | 类型 | 必填 | 默认值 | 说明 |
|
||||
|--------------|--------|------|-----------------------------------------------------------|------|
|
||||
| `apiKey` | string | 是 | - | Embedding 服务的 API Key |
|
||||
| `baseURL` | string | 否 | `https://dashscope.aliyuncs.com/compatible-mode/v1` | OpenAI 兼容 API 的 Base URL |
|
||||
| `model` | string | 否 | `text-embedding-v4` | 使用的 Embedding 模型 |
|
||||
| `dimensions` | int | 否 | `1024` | 向量维度 |
|
||||
|
||||
## 配置示例
|
||||
|
||||
|
||||
Tool Search MCP Server 也可以作为 Higress 的一个模块进行配置。以下是一个在 Higress ConfigMap 中配置 Tool Search 的示例:
|
||||
|
||||
```yaml
|
||||
apiVersion: v1
|
||||
kind: ConfigMap
|
||||
metadata:
|
||||
name: higress-config
|
||||
namespace: higress-system
|
||||
data:
|
||||
higress: |
|
||||
mcpServer:
|
||||
enable: true
|
||||
sse_path_suffix: "/sse"
|
||||
redis:
|
||||
address: "<Redis IP>:6379"
|
||||
username: ""
|
||||
password: ""
|
||||
db: 0
|
||||
match_list:
|
||||
- path_rewrite_prefix: ""
|
||||
upstream_type: ""
|
||||
enable_path_rewrite: false
|
||||
match_rule_domain: "*"
|
||||
match_rule_path: "/mcp-servers/tool-search"
|
||||
match_rule_type: "prefix"
|
||||
servers:
|
||||
- path: "/mcp-servers/tool-search"
|
||||
name: "tool-search"
|
||||
type: "tool-search"
|
||||
config:
|
||||
vector:
|
||||
type: "milvus"
|
||||
host: "localhost"
|
||||
port: 19530
|
||||
database: "default"
|
||||
tableName: "apig_mcp_tools"
|
||||
username: "root"
|
||||
password: "Milvus"
|
||||
maxTools: 1000
|
||||
embedding:
|
||||
apiKey: "your-dashscope-api-key"
|
||||
baseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
model: "text-embedding-v4"
|
||||
dimensions: 1024
|
||||
description: "Higress 工具语义搜索服务"
|
||||
```
|
||||
|
||||
## 工具搜索接口
|
||||
|
||||
Tool Search MCP Server 提供以下 MCP 工具:
|
||||
|
||||
### x_higress_tool_search
|
||||
|
||||
基于语义相似度搜索最相关的工具。
|
||||
|
||||
**输入参数**:
|
||||
|
||||
| 参数名 | 类型 | 必填 | 说明 |
|
||||
|---------|--------|------|------|
|
||||
| `query` | string | 是 | 查询语句,用于与工具描述进行语义相似度比较 |
|
||||
| `topK` | int | 否 | 指定需要选择的工具数量,默认选择前10个工具 |
|
||||
|
||||
**输出格式**:
|
||||
|
||||
```
|
||||
{
|
||||
"tools": [
|
||||
{
|
||||
"name": "server_name___tool_name",
|
||||
"title": "Tool Title",
|
||||
"description": "Tool description",
|
||||
"inputSchema": {...},
|
||||
"outputSchema": {...}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## 搜索实现
|
||||
|
||||
通过向量相似度进行搜索,索引配置如下
|
||||
- 使用 HNSW 索引算法进行向量索引
|
||||
- 默认参数:M=8, efConstruction=64
|
||||
- 相似度度量方式:内积(IP)
|
||||
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"vector": {
|
||||
"type": "milvus",
|
||||
"vectorWeight": 0.5,
|
||||
"tableName": "apig_mcp_tools",
|
||||
"host": "localhost",
|
||||
"port": 19530,
|
||||
"database": "default",
|
||||
"username": "root",
|
||||
"password": "Milvus"
|
||||
},
|
||||
"embedding": {
|
||||
"apiKey": "your-dashscope-api-key",
|
||||
"baseURL": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"model": "text-embedding-v4",
|
||||
"dimensions": 1024
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package tool_search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
"github.com/openai/openai-go/v2"
|
||||
"github.com/openai/openai-go/v2/option"
|
||||
)
|
||||
|
||||
// EmbeddingClient handles vector embedding generation using OpenAI-compatible APIs
|
||||
type EmbeddingClient struct {
|
||||
client *openai.Client
|
||||
model string
|
||||
dimensions int
|
||||
}
|
||||
|
||||
// NewEmbeddingClient creates a new EmbeddingClient instance for OpenAI-compatible APIs
|
||||
func NewEmbeddingClient(apiKey, baseURL, model string, dimensions int) *EmbeddingClient {
|
||||
api.LogInfof("Creating EmbeddingClient with baseURL: %s, model: %s, dimensions: %d", baseURL, model, dimensions)
|
||||
|
||||
// Create client with timeout
|
||||
client := openai.NewClient(
|
||||
option.WithAPIKey(apiKey),
|
||||
option.WithBaseURL(baseURL),
|
||||
option.WithRequestTimeout(30*time.Second),
|
||||
)
|
||||
|
||||
return &EmbeddingClient{
|
||||
client: &client,
|
||||
model: model,
|
||||
dimensions: dimensions,
|
||||
}
|
||||
}
|
||||
|
||||
// GetEmbedding generates vector embedding for the given text
|
||||
func (e *EmbeddingClient) GetEmbedding(ctx context.Context, text string) ([]float32, error) {
|
||||
api.LogInfof("Generating embedding for text (length: %d)", len(text))
|
||||
api.LogDebugf("Using model: %s, dimensions: %d", e.model, e.dimensions)
|
||||
|
||||
// Add timeout to context if not already present
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
params := openai.EmbeddingNewParams{
|
||||
Model: e.model,
|
||||
Input: openai.EmbeddingNewParamsInputUnion{
|
||||
OfString: openai.String(text),
|
||||
},
|
||||
Dimensions: openai.Int(int64(e.dimensions)),
|
||||
EncodingFormat: openai.EmbeddingNewParamsEncodingFormatFloat,
|
||||
}
|
||||
|
||||
api.LogDebugf("Calling OpenAI-compatible API for embedding generation")
|
||||
embeddingResp, err := e.client.Embeddings.New(ctx, params)
|
||||
if err != nil {
|
||||
api.LogErrorf("OpenAI-compatible API call failed: %v", err)
|
||||
return nil, fmt.Errorf("failed to generate embedding: %w", err)
|
||||
}
|
||||
|
||||
if len(embeddingResp.Data) == 0 {
|
||||
api.LogErrorf("Empty embedding response from API")
|
||||
return nil, fmt.Errorf("empty embedding response")
|
||||
}
|
||||
|
||||
api.LogDebugf("Successfully received embedding from API")
|
||||
api.LogDebugf("Response data length: %d, embedding dimension: %d", len(embeddingResp.Data), len(embeddingResp.Data[0].Embedding))
|
||||
|
||||
// Convert []float64 to []float32
|
||||
embedding := make([]float32, len(embeddingResp.Data[0].Embedding))
|
||||
for i, v := range embeddingResp.Data[0].Embedding {
|
||||
embedding[i] = float32(v)
|
||||
}
|
||||
|
||||
api.LogInfof("Embedding conversion completed, final dimension: %d", len(embedding))
|
||||
return embedding, nil
|
||||
}
|
||||
204
plugins/golang-filter/mcp-server/servers/tool-search/milvus.go
Normal file
204
plugins/golang-filter/mcp-server/servers/tool-search/milvus.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package tool_search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/schema"
|
||||
"github.com/milvus-io/milvus-sdk-go/v2/client"
|
||||
"github.com/milvus-io/milvus-sdk-go/v2/entity"
|
||||
)
|
||||
|
||||
type MilvusVectorStoreProvider struct {
|
||||
client client.Client
|
||||
collection string
|
||||
dimensions int
|
||||
}
|
||||
|
||||
func NewMilvusVectorStoreProvider(cfg *config.VectorDBConfig, dimensions int) (*MilvusVectorStoreProvider, error) {
|
||||
connectParam := client.Config{
|
||||
Address: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
||||
}
|
||||
connectParam.DBName = cfg.Database
|
||||
|
||||
if cfg.Username != "" && cfg.Password != "" {
|
||||
connectParam.Username = cfg.Username
|
||||
connectParam.Password = cfg.Password
|
||||
}
|
||||
|
||||
milvusClient, err := client.NewClient(context.Background(), connectParam)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create milvus client: %w", err)
|
||||
}
|
||||
|
||||
return &MilvusVectorStoreProvider{
|
||||
client: milvusClient,
|
||||
collection: cfg.Collection,
|
||||
dimensions: dimensions,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *MilvusVectorStoreProvider) ListAllDocs(ctx context.Context, limit int) ([]schema.Document, error) {
|
||||
expr := ""
|
||||
|
||||
outputFields := []string{"id", "content", "metadata", "created_at"}
|
||||
|
||||
var queryResult []entity.Column
|
||||
var err error
|
||||
|
||||
if limit > 0 {
|
||||
queryResult, err = c.client.Query(
|
||||
ctx,
|
||||
c.collection,
|
||||
[]string{}, // partitions
|
||||
expr, // filter condition
|
||||
outputFields,
|
||||
client.WithLimit(int64(limit)),
|
||||
)
|
||||
} else {
|
||||
queryResult, err = c.client.Query(
|
||||
ctx,
|
||||
c.collection,
|
||||
[]string{}, // partitions
|
||||
expr, // filter condition
|
||||
outputFields,
|
||||
)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query all documents: %w", err)
|
||||
}
|
||||
|
||||
if len(queryResult) == 0 {
|
||||
return []schema.Document{}, nil
|
||||
}
|
||||
|
||||
rowCount := queryResult[0].Len()
|
||||
documents := make([]schema.Document, 0, rowCount)
|
||||
|
||||
for i := 0; i < rowCount; i++ {
|
||||
var (
|
||||
id string
|
||||
content string
|
||||
metadata map[string]interface{}
|
||||
createdAt int64
|
||||
)
|
||||
|
||||
for _, col := range queryResult {
|
||||
switch col.Name() {
|
||||
case "id":
|
||||
if v, err := col.(*entity.ColumnVarChar).Get(i); err == nil {
|
||||
id = v.(string)
|
||||
}
|
||||
case "content":
|
||||
if v, err := col.(*entity.ColumnVarChar).Get(i); err == nil {
|
||||
content = v.(string)
|
||||
}
|
||||
case "metadata":
|
||||
if v, err := col.(*entity.ColumnJSONBytes).Get(i); err == nil {
|
||||
if bytes, ok := v.([]byte); ok {
|
||||
_ = json.Unmarshal(bytes, &metadata)
|
||||
}
|
||||
}
|
||||
case "created_at":
|
||||
if v, err := col.(*entity.ColumnInt64).Get(i); err == nil {
|
||||
createdAt = v.(int64)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
doc := schema.Document{
|
||||
ID: id,
|
||||
Content: content,
|
||||
Metadata: metadata,
|
||||
CreatedAt: time.UnixMilli(createdAt),
|
||||
}
|
||||
documents = append(documents, doc)
|
||||
}
|
||||
|
||||
return documents, nil
|
||||
}
|
||||
|
||||
func (c *MilvusVectorStoreProvider) SearchDocs(ctx context.Context, vector []float32, options *schema.SearchOptions) ([]schema.SearchResult, error) {
|
||||
if options == nil {
|
||||
options = &schema.SearchOptions{TopK: 10}
|
||||
}
|
||||
|
||||
sp, err := entity.NewIndexHNSWSearchParam(16) // 默认 HNSW 搜索参数
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build search param: %w", err)
|
||||
}
|
||||
|
||||
outputFields := []string{"id", "content", "metadata"}
|
||||
searchResults, err := c.client.Search(
|
||||
ctx,
|
||||
c.collection,
|
||||
[]string{}, // partition names
|
||||
"", // filter expression
|
||||
outputFields, // output fields
|
||||
[]entity.Vector{entity.FloatVector(vector)},
|
||||
"vector", // anns_field
|
||||
entity.IP, // metric_type
|
||||
options.TopK,
|
||||
sp,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search documents: %w", err)
|
||||
}
|
||||
|
||||
var results []schema.SearchResult
|
||||
for _, result := range searchResults {
|
||||
for i := 0; i < result.ResultCount; i++ {
|
||||
id, _ := result.IDs.Get(i)
|
||||
score := result.Scores[i]
|
||||
|
||||
var content string
|
||||
var metadata map[string]interface{}
|
||||
for _, field := range result.Fields {
|
||||
switch field.Name() {
|
||||
case "content":
|
||||
if contentCol, ok := field.(*entity.ColumnVarChar); ok {
|
||||
if contentVal, err := contentCol.Get(i); err == nil {
|
||||
if contentStr, ok := contentVal.(string); ok {
|
||||
content = contentStr
|
||||
}
|
||||
}
|
||||
}
|
||||
case "metadata":
|
||||
if metaCol, ok := field.(*entity.ColumnJSONBytes); ok {
|
||||
if metaVal, err := metaCol.Get(i); err == nil {
|
||||
if metaBytes, ok := metaVal.([]byte); ok {
|
||||
if err := json.Unmarshal(metaBytes, &metadata); err != nil {
|
||||
metadata = make(map[string]interface{})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
searchResult := schema.SearchResult{
|
||||
Document: schema.Document{
|
||||
ID: fmt.Sprintf("%s", id),
|
||||
Content: content,
|
||||
Metadata: metadata,
|
||||
},
|
||||
Score: float64(score),
|
||||
}
|
||||
results = append(results, searchResult)
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (c *MilvusVectorStoreProvider) Close() error {
|
||||
if c.client != nil {
|
||||
return c.client.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
237
plugins/golang-filter/mcp-server/servers/tool-search/search.go
Normal file
237
plugins/golang-filter/mcp-server/servers/tool-search/search.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package tool_search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/schema"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
)
|
||||
|
||||
// SearchService handles tool search operations
|
||||
type SearchService struct {
|
||||
milvusProvider *MilvusVectorStoreProvider
|
||||
config *config.VectorDBConfig
|
||||
tableName string
|
||||
dimensions int
|
||||
maxTools int // 写死的最大工具数量,仅用于单测
|
||||
embeddingClient *EmbeddingClient
|
||||
}
|
||||
|
||||
// NewSearchService creates a new SearchService instance
|
||||
func NewSearchService(host string, port int, database, username, password, tableName string, embeddingClient *EmbeddingClient, dimensions int, maxTools int) *SearchService {
|
||||
// Create Milvus configuration
|
||||
cfg := &config.VectorDBConfig{
|
||||
Provider: "milvus",
|
||||
Host: host,
|
||||
Port: port,
|
||||
Database: database,
|
||||
Collection: tableName,
|
||||
Username: username,
|
||||
Password: password,
|
||||
}
|
||||
|
||||
// Create Milvus provider
|
||||
provider, err := NewMilvusVectorStoreProvider(cfg, dimensions)
|
||||
if err != nil {
|
||||
api.LogErrorf("Failed to create Milvus provider: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &SearchService{
|
||||
milvusProvider: provider,
|
||||
config: cfg,
|
||||
tableName: tableName,
|
||||
dimensions: dimensions,
|
||||
maxTools: maxTools, // 使用写死的值
|
||||
embeddingClient: embeddingClient,
|
||||
}
|
||||
}
|
||||
|
||||
// ToolSearchResult represents the result of a tool search
|
||||
type ToolSearchResult struct {
|
||||
Tools []ToolDefinition `json:"tools"`
|
||||
}
|
||||
|
||||
// ToolDefinition represents a tool definition in the search result
|
||||
type ToolDefinition map[string]interface{}
|
||||
|
||||
// SearchTools performs semantic search for tools
|
||||
func (s *SearchService) SearchTools(ctx context.Context, query string, topK int) (*ToolSearchResult, error) {
|
||||
api.LogInfof("Starting tool search for query: '%s', topK: %d", query, topK)
|
||||
|
||||
// Generate vector embedding for the query
|
||||
vector, err := s.embeddingClient.GetEmbedding(ctx, query)
|
||||
if err != nil {
|
||||
api.LogErrorf("Failed to generate embedding for query '%s': %v", query, err)
|
||||
return nil, fmt.Errorf("failed to generate embedding: %w", err)
|
||||
}
|
||||
|
||||
api.LogInfof("Embedding generated successfully, vector dimension: %d", len(vector))
|
||||
|
||||
// Perform vector search
|
||||
records, err := s.searchToolsInDB(query, vector, topK)
|
||||
if err != nil {
|
||||
api.LogErrorf("Failed to search tools: %v", err)
|
||||
return nil, fmt.Errorf("failed to search tools: %w", err)
|
||||
}
|
||||
|
||||
api.LogInfof("Vector search completed, found %d records", len(records))
|
||||
|
||||
return s.convertRecordsToResult(records), nil
|
||||
}
|
||||
|
||||
// convertRecordsToResult converts database records to tool search result
|
||||
func (s *SearchService) convertRecordsToResult(records []ToolRecord) *ToolSearchResult {
|
||||
api.LogInfof("Converting %d records to tool definitions", len(records))
|
||||
|
||||
tools := make([]ToolDefinition, 0, len(records))
|
||||
for i, record := range records {
|
||||
var tool ToolDefinition
|
||||
|
||||
// Use metadata if available
|
||||
if len(record.Metadata) > 0 {
|
||||
tool = record.Metadata
|
||||
api.LogDebugf("Successfully parsed metadata for tool %s", record.Name)
|
||||
} else {
|
||||
api.LogDebugf("No metadata found for tool %s, using basic definition", record.Name)
|
||||
// If no metadata, create a basic tool definition
|
||||
tool = ToolDefinition{
|
||||
"name": record.Name,
|
||||
"description": record.Content,
|
||||
}
|
||||
}
|
||||
|
||||
// Update the name to include server name
|
||||
tool["name"] = fmt.Sprintf("%s", record.Name)
|
||||
|
||||
tools = append(tools, tool)
|
||||
|
||||
api.LogDebugf("Tool %d: %s - %s", i+1, tool["name"], record.Content)
|
||||
}
|
||||
|
||||
api.LogInfof("Successfully converted %d tools", len(tools))
|
||||
return &ToolSearchResult{Tools: tools}
|
||||
}
|
||||
|
||||
// GetAllTools retrieves all available tools
|
||||
func (s *SearchService) GetAllTools() (*ToolSearchResult, error) {
|
||||
api.LogInfo("Retrieving all tools")
|
||||
records, err := s.getAllToolsFromDB()
|
||||
if err != nil {
|
||||
api.LogErrorf("Failed to get all tools: %v", err)
|
||||
return nil, fmt.Errorf("failed to get all tools: %w", err)
|
||||
}
|
||||
|
||||
api.LogInfof("Found %d tools in database", len(records))
|
||||
|
||||
// Convert records to tool definitions
|
||||
tools := make([]ToolDefinition, 0, len(records))
|
||||
for _, record := range records {
|
||||
var tool ToolDefinition
|
||||
|
||||
// Use metadata if available
|
||||
if len(record.Metadata) > 0 {
|
||||
tool = record.Metadata
|
||||
api.LogDebugf("Successfully parsed metadata for tool %s", record.Name)
|
||||
} else {
|
||||
api.LogDebugf("No metadata found for tool %s, using basic definition", record.Name)
|
||||
// If no metadata, create a basic tool definition
|
||||
tool = ToolDefinition{
|
||||
"name": record.Name,
|
||||
"description": record.Content,
|
||||
}
|
||||
}
|
||||
|
||||
// Update the name to include server name
|
||||
tool["name"] = fmt.Sprintf("%s", record.Name)
|
||||
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
|
||||
api.LogInfof("Successfully converted %d tools", len(tools))
|
||||
return &ToolSearchResult{Tools: tools}, nil
|
||||
}
|
||||
|
||||
// ToolRecord represents a tool record in the database
|
||||
type ToolRecord struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Content string `json:"content"`
|
||||
Metadata map[string]interface{} `json:"metadata"`
|
||||
}
|
||||
|
||||
func (s *SearchService) searchToolsInDB(query string, vector []float32, topK int) ([]ToolRecord, error) {
|
||||
api.LogInfof("Performing vector search for query: '%s', topK: %d", query, topK)
|
||||
|
||||
// For Milvus, we'll perform vector search directly
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Perform vector search
|
||||
searchOptions := &schema.SearchOptions{
|
||||
TopK: topK,
|
||||
}
|
||||
|
||||
results, err := s.milvusProvider.SearchDocs(ctx, vector, searchOptions)
|
||||
if err != nil {
|
||||
api.LogErrorf("Vector search failed: %v", err)
|
||||
return nil, fmt.Errorf("failed to perform vector search: %w", err)
|
||||
}
|
||||
|
||||
// Convert results to ToolRecords
|
||||
var records []ToolRecord
|
||||
for _, result := range results {
|
||||
doc := result.Document
|
||||
tool := ToolRecord{
|
||||
ID: doc.ID,
|
||||
Content: doc.Content,
|
||||
Metadata: doc.Metadata,
|
||||
}
|
||||
|
||||
if name, ok := doc.Metadata["name"].(string); ok {
|
||||
tool.Name = name
|
||||
}
|
||||
|
||||
records = append(records, tool)
|
||||
}
|
||||
|
||||
api.LogInfof("Vector search completed, found %d results", len(records))
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// getAllToolsFromDB retrieves all tools from the database
|
||||
func (s *SearchService) getAllToolsFromDB() ([]ToolRecord, error) {
|
||||
api.LogInfof("Executing GetAllTools query from collection: %s", s.tableName)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Retrieve all documents with limit
|
||||
docs, err := s.milvusProvider.ListAllDocs(ctx, s.maxTools)
|
||||
if err != nil {
|
||||
api.LogErrorf("Failed to list documents: %v", err)
|
||||
return nil, fmt.Errorf("failed to list documents: %w", err)
|
||||
}
|
||||
|
||||
// Convert documents to ToolRecords
|
||||
var tools []ToolRecord
|
||||
for _, doc := range docs {
|
||||
tool := ToolRecord{
|
||||
ID: doc.ID,
|
||||
Content: doc.Content,
|
||||
Metadata: doc.Metadata,
|
||||
}
|
||||
|
||||
if name, ok := doc.Metadata["name"].(string); ok {
|
||||
tool.Name = name
|
||||
}
|
||||
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
|
||||
api.LogInfof("GetAllTools query completed, found %d tools", len(tools))
|
||||
return tools, nil
|
||||
}
|
||||
196
plugins/golang-filter/mcp-server/servers/tool-search/server.go
Normal file
196
plugins/golang-filter/mcp-server/servers/tool-search/server.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package tool_search
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
Version = "1.0.0"
|
||||
|
||||
// 默认配置值
|
||||
defaultTableName = "apig_mcp_tools"
|
||||
defaultBaseURL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
defaultModel = "text-embedding-v4"
|
||||
defaultDimensions = 1024
|
||||
// 写死最大工具数量为1000,仅用于单测
|
||||
fixedMaxTools = 1000
|
||||
)
|
||||
|
||||
func init() {
|
||||
common.GlobalRegistry.RegisterServer("tool-search", &ToolSearchConfig{})
|
||||
}
|
||||
|
||||
type VectorConfig struct {
|
||||
Type string `json:"type"`
|
||||
VectorWeight float64 `json:"vectorWeight"`
|
||||
TableName string `json:"tableName"`
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
Database string `json:"database"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type EmbeddingConfig struct {
|
||||
APIKey string `json:"apiKey"`
|
||||
BaseURL string `json:"baseURL"`
|
||||
Model string `json:"model"`
|
||||
Dimensions int `json:"dimensions"`
|
||||
}
|
||||
|
||||
type ToolSearchConfig struct {
|
||||
Vector VectorConfig `json:"vector"`
|
||||
Embedding EmbeddingConfig `json:"embedding"`
|
||||
description string
|
||||
}
|
||||
|
||||
func (c *ToolSearchConfig) ParseConfig(config map[string]any) error {
|
||||
// Parse vector configuration
|
||||
vectorConfig, ok := config["vector"].(map[string]any)
|
||||
if !ok {
|
||||
return errors.New("missing vector configuration")
|
||||
}
|
||||
|
||||
if err := c.parseVectorConfig(vectorConfig); err != nil {
|
||||
return fmt.Errorf("failed to parse vector config: %w", err)
|
||||
}
|
||||
|
||||
// Parse embedding configuration
|
||||
embeddingConfig, ok := config["embedding"].(map[string]any)
|
||||
if !ok {
|
||||
return errors.New("missing embedding configuration")
|
||||
}
|
||||
|
||||
if err := c.parseEmbeddingConfig(embeddingConfig); err != nil {
|
||||
return fmt.Errorf("failed to parse embedding config: %w", err)
|
||||
}
|
||||
|
||||
// Optional description
|
||||
if description, ok := config["description"].(string); ok {
|
||||
c.description = description
|
||||
} else {
|
||||
c.description = "Tool search server for semantic similarity search"
|
||||
}
|
||||
|
||||
api.LogDebugf("ToolSearchConfig ParseConfig: %+v", config)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ToolSearchConfig) parseVectorConfig(config map[string]any) error {
|
||||
if vectorType, ok := config["type"].(string); ok {
|
||||
c.Vector.Type = vectorType
|
||||
} else {
|
||||
return errors.New("missing vector.type")
|
||||
}
|
||||
|
||||
if c.Vector.Type != "milvus" {
|
||||
return fmt.Errorf("unsupported vector.type: %s, only 'milvus' is supported", c.Vector.Type)
|
||||
}
|
||||
|
||||
if host, ok := config["host"].(string); ok {
|
||||
c.Vector.Host = host
|
||||
} else {
|
||||
return errors.New("missing vector.host")
|
||||
}
|
||||
|
||||
if port, ok := config["port"].(float64); ok {
|
||||
c.Vector.Port = int(port)
|
||||
} else if port, ok := config["port"].(int); ok {
|
||||
c.Vector.Port = port
|
||||
} else {
|
||||
return errors.New("missing vector.port")
|
||||
}
|
||||
|
||||
if database, ok := config["database"].(string); ok {
|
||||
c.Vector.Database = database
|
||||
} else {
|
||||
c.Vector.Database = "default" // 默认数据库
|
||||
}
|
||||
|
||||
if tableName, ok := config["tableName"].(string); ok {
|
||||
c.Vector.TableName = tableName
|
||||
} else {
|
||||
c.Vector.TableName = defaultTableName
|
||||
}
|
||||
|
||||
if username, ok := config["username"].(string); ok {
|
||||
c.Vector.Username = username
|
||||
}
|
||||
|
||||
if password, ok := config["password"].(string); ok {
|
||||
c.Vector.Password = password
|
||||
}
|
||||
|
||||
// 移除maxTools的解析逻辑
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ToolSearchConfig) parseEmbeddingConfig(config map[string]any) error {
|
||||
// Parse API key (required)
|
||||
if apiKey, ok := config["apiKey"].(string); ok {
|
||||
c.Embedding.APIKey = apiKey
|
||||
} else {
|
||||
return errors.New("missing embedding.apiKey")
|
||||
}
|
||||
|
||||
// Parse optional fields with defaults
|
||||
if baseURL, ok := config["baseURL"].(string); ok {
|
||||
c.Embedding.BaseURL = baseURL
|
||||
} else {
|
||||
c.Embedding.BaseURL = defaultBaseURL
|
||||
}
|
||||
|
||||
if model, ok := config["model"].(string); ok {
|
||||
c.Embedding.Model = model
|
||||
} else {
|
||||
c.Embedding.Model = defaultModel
|
||||
}
|
||||
|
||||
if dimensions, ok := config["dimensions"].(float64); ok {
|
||||
c.Embedding.Dimensions = int(dimensions)
|
||||
} else if dimensions, ok := config["dimensions"].(int); ok {
|
||||
c.Embedding.Dimensions = dimensions
|
||||
} else {
|
||||
c.Embedding.Dimensions = defaultDimensions
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ToolSearchConfig) NewServer(serverName string) (*common.MCPServer, error) {
|
||||
mcpServer := common.NewMCPServer(
|
||||
serverName,
|
||||
Version,
|
||||
common.WithInstructions(c.description),
|
||||
)
|
||||
|
||||
// Create embedding client
|
||||
embeddingClient := NewEmbeddingClient(c.Embedding.APIKey, c.Embedding.BaseURL, c.Embedding.Model, c.Embedding.Dimensions)
|
||||
|
||||
// Create search service,使用写死的fixedMaxTools值
|
||||
searchService := NewSearchService(
|
||||
c.Vector.Host,
|
||||
c.Vector.Port,
|
||||
c.Vector.Database,
|
||||
c.Vector.Username,
|
||||
c.Vector.Password,
|
||||
c.Vector.TableName,
|
||||
embeddingClient,
|
||||
c.Embedding.Dimensions,
|
||||
fixedMaxTools, // 使用写死的值
|
||||
)
|
||||
|
||||
// Add tool search tool
|
||||
mcpServer.AddTool(
|
||||
mcp.NewToolWithRawSchema("x_higress_tool_search", "Higress MCP Tools Searcher", GetToolSearchSchema()),
|
||||
HandleToolSearch(searchService),
|
||||
)
|
||||
|
||||
return mcpServer, nil
|
||||
}
|
||||
@@ -0,0 +1,198 @@
|
||||
package tool_search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// Mock implementation of CommonCAPI for testing
|
||||
type mockCommonCAPI struct {
|
||||
logs []string
|
||||
}
|
||||
|
||||
func (m *mockCommonCAPI) Log(level api.LogType, message string) {
|
||||
fmt.Printf("[%s] %s\n", level, message)
|
||||
m.logs = append(m.logs, message)
|
||||
}
|
||||
|
||||
func (m *mockCommonCAPI) LogLevel() api.LogType {
|
||||
return api.Debug
|
||||
}
|
||||
|
||||
// TestServer is used for local functional testing
|
||||
func TestServer(t *testing.T) {
|
||||
// Setup mock API for logging
|
||||
mockAPI := &mockCommonCAPI{}
|
||||
api.SetCommonCAPI(mockAPI)
|
||||
|
||||
// Load configuration from environment variables or use defaults
|
||||
config := map[string]any{
|
||||
"vector": map[string]any{
|
||||
"type": "milvus",
|
||||
"vectorWeight": 0.6,
|
||||
"tableName": getEnvOrDefault("TEST_TABLE_NAME", "apig_mcp_tools"),
|
||||
"host": getEnvOrDefault("TEST_MILVUS_HOST", "localhost"),
|
||||
"port": getEnvOrDefaultInt("TEST_MILVUS_PORT", 19530),
|
||||
"database": getEnvOrDefault("TEST_MILVUS_DATABASE", "default"),
|
||||
"username": getEnvOrDefault("TEST_MILVUS_USERNAME", "root"),
|
||||
"password": getEnvOrDefault("TEST_MILVUS_PASSWORD", "Milvus"),
|
||||
"maxTools": getEnvOrDefaultInt("TEST_MAX_TOOLS", 1000),
|
||||
},
|
||||
"embedding": map[string]any{
|
||||
"apiKey": getEnvOrDefault("TEST_API_KEY", "your-dashscope-api-key"),
|
||||
"baseURL": getEnvOrDefault("TEST_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
|
||||
"model": getEnvOrDefault("TEST_MODEL", "text-embedding-v4"),
|
||||
"dimensions": 1024,
|
||||
},
|
||||
"description": "Test MCP Tools Search Server",
|
||||
}
|
||||
|
||||
// Create configuration instance
|
||||
toolSearchConfig := &ToolSearchConfig{}
|
||||
if err := toolSearchConfig.ParseConfig(config); err != nil {
|
||||
t.Fatalf("Failed to parse config: %v", err)
|
||||
}
|
||||
|
||||
// Create MCP Server
|
||||
_, err := toolSearchConfig.NewServer("test-tool-search")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create server: %v", err)
|
||||
}
|
||||
|
||||
// Test database connection
|
||||
vectorConfig := config["vector"].(map[string]any)
|
||||
embeddingConfig := config["embedding"].(map[string]any)
|
||||
|
||||
// Test GetAllTools
|
||||
t.Logf("\n=== Testing GetAllTools ===")
|
||||
embeddingClient := NewEmbeddingClient(
|
||||
embeddingConfig["apiKey"].(string),
|
||||
embeddingConfig["baseURL"].(string),
|
||||
embeddingConfig["model"].(string),
|
||||
embeddingConfig["dimensions"].(int),
|
||||
)
|
||||
|
||||
searchService := NewSearchService(
|
||||
vectorConfig["host"].(string),
|
||||
vectorConfig["port"].(int),
|
||||
vectorConfig["database"].(string),
|
||||
vectorConfig["username"].(string),
|
||||
vectorConfig["password"].(string),
|
||||
vectorConfig["tableName"].(string),
|
||||
embeddingClient,
|
||||
embeddingConfig["dimensions"].(int),
|
||||
getEnvOrDefaultInt("TEST_MAX_TOOLS", 1000),
|
||||
)
|
||||
|
||||
allTools, err := searchService.GetAllTools()
|
||||
if err != nil {
|
||||
t.Logf("GetAllTools failed: %v", err)
|
||||
} else {
|
||||
t.Logf("Found %d tools:", len(allTools.Tools))
|
||||
for i, tool := range allTools.Tools {
|
||||
if i < 3 { // Show only first 3 tools
|
||||
toolJSON, _ := json.MarshalIndent(tool, "", " ")
|
||||
t.Logf("Tool %d: %s", i+1, string(toolJSON))
|
||||
}
|
||||
}
|
||||
if len(allTools.Tools) > 3 {
|
||||
t.Logf("... and %d more tools", len(allTools.Tools)-3)
|
||||
}
|
||||
}
|
||||
|
||||
// Test tool search with timing
|
||||
t.Logf("\n=== Testing Tool Search ===")
|
||||
testQueries := []string{
|
||||
"weather data",
|
||||
"database query",
|
||||
"file operations",
|
||||
"HTTP requests",
|
||||
"library documents",
|
||||
}
|
||||
|
||||
for _, query := range testQueries {
|
||||
t.Logf("\n--- Testing query: '%s' ---", query)
|
||||
|
||||
// Create MCP tool call request
|
||||
request := mcp.CallToolRequest{
|
||||
Params: struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]interface{} `json:"arguments,omitempty"`
|
||||
Meta *struct {
|
||||
ProgressToken mcp.ProgressToken `json:"progressToken,omitempty"`
|
||||
} `json:"_meta,omitempty"`
|
||||
}{
|
||||
Name: "x_higress_tool_search",
|
||||
Arguments: map[string]interface{}{
|
||||
"query": query,
|
||||
"topK": 3,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Get tool handler
|
||||
handler := HandleToolSearch(searchService)
|
||||
|
||||
// Execute search with timing
|
||||
start := time.Now()
|
||||
result, err := handler(context.Background(), request)
|
||||
duration := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Logf("Search failed: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Print results with timing information
|
||||
t.Logf("Search completed in %v", duration)
|
||||
if len(result.Content) > 0 {
|
||||
if textContent, ok := result.Content[0].(mcp.TextContent); ok {
|
||||
var toolsResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(textContent.Text), &toolsResult); err == nil {
|
||||
toolsJSON, _ := json.MarshalIndent(toolsResult, "", " ")
|
||||
t.Logf("Tools Result: %s", string(toolsJSON))
|
||||
} else {
|
||||
t.Logf("Text Content: %s", textContent.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test configuration validation
|
||||
t.Logf("\n=== Configuration Validation ===")
|
||||
t.Logf("Host: %s", vectorConfig["host"])
|
||||
t.Logf("Port: %d", vectorConfig["port"])
|
||||
t.Logf("Database: %s", vectorConfig["database"])
|
||||
t.Logf("Table Name: %s", vectorConfig["tableName"])
|
||||
t.Logf("Vector Weight: %f", vectorConfig["vectorWeight"])
|
||||
t.Logf("Text Weight: %f", 1.0-vectorConfig["vectorWeight"].(float64))
|
||||
t.Logf("Model: %s", embeddingConfig["model"])
|
||||
t.Logf("Dimensions: %d", embeddingConfig["dimensions"])
|
||||
t.Logf("API Base URL: %s", embeddingConfig["baseURL"])
|
||||
|
||||
t.Logf("\n=== Test completed ===")
|
||||
}
|
||||
|
||||
// Helper function to get environment variable or default value
|
||||
func getEnvOrDefault(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func getEnvOrDefaultInt(key string, defaultValue int) int {
|
||||
if valueStr := os.Getenv(key); valueStr != "" {
|
||||
if value, err := fmt.Sscanf(valueStr, "%d", &defaultValue); err == nil && value == 1 {
|
||||
return defaultValue
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
114
plugins/golang-filter/mcp-server/servers/tool-search/tools.go
Normal file
114
plugins/golang-filter/mcp-server/servers/tool-search/tools.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package tool_search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/envoyproxy/envoy/contrib/golang/common/go/api"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// HandleToolSearch handles the x_higress_tool_search tool
|
||||
func HandleToolSearch(searchService *SearchService) common.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
api.LogInfo("HandleToolSearch called")
|
||||
|
||||
arguments := request.Params.Arguments
|
||||
api.LogDebugf("Request arguments: %+v", arguments)
|
||||
|
||||
// Get query parameter
|
||||
query, ok := arguments["query"].(string)
|
||||
if !ok {
|
||||
api.LogErrorf("Invalid query argument type: %T", arguments["query"])
|
||||
return nil, fmt.Errorf("invalid query argument")
|
||||
}
|
||||
|
||||
// Validate query
|
||||
if query == "" {
|
||||
api.LogError("Empty query provided")
|
||||
return nil, fmt.Errorf("query cannot be empty")
|
||||
}
|
||||
|
||||
// Get topK parameter (optional, default to 10)
|
||||
topK := 10
|
||||
if topKVal, ok := arguments["topK"]; ok {
|
||||
switch v := topKVal.(type) {
|
||||
case float64:
|
||||
topK = int(v)
|
||||
case int:
|
||||
topK = v
|
||||
case int64:
|
||||
topK = int(v)
|
||||
default:
|
||||
api.LogWarnf("Invalid topK argument type: %T, using default: %d", topKVal, topK)
|
||||
}
|
||||
|
||||
// Validate topK range
|
||||
if topK <= 0 || topK > 100 {
|
||||
api.LogWarnf("Invalid topK value: %d, using default: 10", topK)
|
||||
topK = 10
|
||||
}
|
||||
}
|
||||
|
||||
api.LogInfof("Parsed parameters - query: '%s', topK: %d", query, topK)
|
||||
|
||||
// Add timeout to context
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Perform search
|
||||
result, err := searchService.SearchTools(ctx, query, topK)
|
||||
if err != nil {
|
||||
api.LogErrorf("Search failed: %v", err)
|
||||
return nil, fmt.Errorf("failed to search tools: %w", err)
|
||||
}
|
||||
|
||||
api.LogInfof("Search completed successfully, found %d tools", len(result.Tools))
|
||||
|
||||
// Build response
|
||||
response := map[string]interface{}{
|
||||
"tools": result.Tools,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
api.LogErrorf("Failed to marshal response: %v", err)
|
||||
return nil, fmt.Errorf("failed to marshal search results: %w", err)
|
||||
}
|
||||
|
||||
api.LogDebugf("Response marshaled successfully, JSON length: %d", len(jsonData))
|
||||
api.LogDebugf("Returning MCP CallToolResult")
|
||||
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{
|
||||
Type: "text",
|
||||
Text: string(jsonData),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetToolSearchSchema returns the schema for the tool search tool
|
||||
func GetToolSearchSchema() json.RawMessage {
|
||||
return json.RawMessage(`{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Query statement for semantic similarity comparison with tool descriptions"
|
||||
},
|
||||
"topK": {
|
||||
"type": "integer",
|
||||
"description": "Specify how many tools need to be selected, default is to select the top 10 tools.",
|
||||
"minimum": 1,
|
||||
"maximum": 100
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}`)
|
||||
}
|
||||
Reference in New Issue
Block a user