feat: add rag mcp server (#2930)

This commit is contained in:
Jun
2025-09-21 14:48:22 +08:00
committed by GitHub
parent fc65104437
commit 8b8c8b242b
26 changed files with 3141 additions and 0 deletions

View File

@@ -46,14 +46,17 @@ require (
github.com/buger/jsonparser v1.1.1 // indirect
github.com/clbanning/mxj/v2 v2.5.5 // indirect
github.com/deckarep/golang-set v1.7.1 // indirect
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/golang/mock v1.6.0 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
github.com/milvus-io/milvus-sdk-go/v2 v2.4.2 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/orcaman/concurrent-map v0.0.0-20210501183033-44dafcb38ecc // indirect
github.com/pkoukk/tiktoken-go v0.1.8 // indirect
github.com/prometheus/client_golang v1.14.0 // indirect
github.com/prometheus/client_model v0.4.0 // indirect
github.com/prometheus/common v0.37.0 // indirect

View File

@@ -136,6 +136,8 @@ github.com/deckarep/golang-set v1.7.1 h1:SCQV0S6gTtp6itiFrTqI+pfmJ4LN85S1YzhDf9r
github.com/deckarep/golang-set v1.7.1/go.mod h1:93vsz/8Wt4joVM7c2AVqh+YRMiUSc14yDtF28KmMOgQ=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
@@ -290,6 +292,8 @@ github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxU
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
github.com/milvus-io/milvus-sdk-go/v2 v2.4.2 h1:Xqf+S7iicElwYoS2Zly8Nf/zKHuZsNy1xQajfdtygVY=
github.com/milvus-io/milvus-sdk-go/v2 v2.4.2/go.mod h1:ulO1YUXKH0PGg50q27grw048GDY9ayB4FPmh7D+FFTA=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -318,6 +322,8 @@ github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo=
github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=

View File

@@ -6,6 +6,7 @@ import (
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/registry/nacos"
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/gorm"
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/higress/higress-api"
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag"
mcp_session "github.com/alibaba/higress/plugins/golang-filter/mcp-session"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
xds "github.com/cncf/xds/go/xds/type/v3"

View File

@@ -0,0 +1,173 @@
# Higress RAG MCP Server
这是一个 Model Context Protocol (MCP) 服务器,提供知识管理和检索功能。
该 MCP 服务器提供以下工具:
## MCP Tools
### 知识管理
- `create-chunks-from-text` - 从 Text 创建知识 (p1)
### 块管理
- `list-chunks` - 列出知识块
- `delete-chunk` - 删除知识块
### 搜索
- `search` - 搜索
### 聊天功能
- `chat` - 发送聊天消息
## 配置说明
### 配置结构
```yaml
rag:
# RAG系统基础配置
splitter:
provider: "recursive" # 分块器类型，可选 recursive 或 nosplitter
chunk_size: 500
chunk_overlap: 50
top_k: 5 # 搜索返回的知识块数量
threshold: 0.5 # 搜索阈值
llm:
provider: "openai" # openai
api_key: "your-llm-api-key"
base_url: "https://api.openai.com/v1" # 可选
model: "gpt-3.5-turbo" # LLM模型
max_tokens: 2048 # 最大令牌数
temperature: 0.5 # 温度参数
embedding:
provider: "openai" # openai, dashscope
api_key: "your-embedding-api-key"
base_url: "https://api.openai.com/v1" # 可选
model: "text-embedding-ada-002" # 嵌入模型
vectordb:
provider: "milvus" # milvus
host: "localhost"
port: 19530
database: "default"
collection: "test_collection"
username: "" # 可选
password: "" # 可选
```
### higress-config 配置样例
```yaml
apiVersion: v1
kind: ConfigMap
metadata:
name: higress-config
namespace: higress-system
data:
higress: |
mcpServer:
enable: true
sse_path_suffix: "/sse"
redis:
address: "<Redis IP>:6379"
username: ""
password: ""
db: 0
match_list:
- path_rewrite_prefix: ""
upstream_type: ""
enable_path_rewrite: false
match_rule_domain: ""
match_rule_path: "/mcp-servers/rag"
match_rule_type: "prefix"
servers:
- path: "/mcp-servers/rag"
name: "rag"
type: "rag"
config:
rag:
splitter:
provider: recursive
chunk_size: 500
chunk_overlap: 50
top_k: 10
threshold: 0.5
llm:
provider: openai
api_key: sk-XXX
base_url: https://openrouter.ai/api/v1
model: openai/gpt-4o
temperature: 0.5
max_tokens: 2048
embedding:
provider: dashscope
api_key: sk-xxx
model: text-embedding-v4
vectordb:
provider: milvus
host: <milvus IP>
port: 19530
database: default
collection: test_collection
```
### 支持的提供商
#### Embedding
- **OpenAI**
- **DashScope**
#### Vector Database
- **Milvus**
#### LLM
- **OpenAI**
## Milvus 安装
### Docker 配置
配置 Docker Desktop 镜像加速器
编辑 daemon.json 配置,加上镜像加速器,例如:
```
{
"registry-mirrors": [
"https://docker.m.daocloud.io",
"https://mirror.ccs.tencentyun.com",
"https://hub-mirror.c.163.com"
],
"dns": ["8.8.8.8", "1.1.1.1"]
}
```
### 安装 milvus
```
# 下载配置文件（以下两个版本任选其一）
# v2.6.0
$ wget https://github.com/milvus-io/milvus/releases/download/v2.6.0/milvus-standalone-docker-compose.yml -O docker-compose.yml
# v2.4
$ wget https://github.com/milvus-io/milvus/releases/download/v2.4.23/milvus-standalone-docker-compose.yml -O docker-compose.yml
# Start Milvus
$ sudo docker compose up -d
Creating milvus-etcd ... done
Creating milvus-minio ... done
Creating milvus-standalone ... done
```
### 安装 attu
Attu 是 Milvus 的可视化管理工具,用于查看和管理 Milvus 中的数据。
```
docker run -p 8000:3000 -e MILVUS_URL=http://<本机 IP>:19530 zilliz/attu:v2.6
# 启动后在浏览器中访问 http://localhost:8000
```

View File

@@ -0,0 +1,159 @@
package common
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
// HTTPClient is a small JSON-oriented HTTP client bound to a base URL and a
// set of default headers that are sent with every request.
//
// The header map is mutated by SetHeader, SetHeaders and RemoveHeader without
// any locking, so headers should be configured before the client is shared
// between goroutines.
type HTTPClient struct {
	baseURL    string
	headers    map[string]string
	httpClient *http.Client
}

// maxErrorBodyBytes bounds how much of a non-2xx response body is copied into
// the error returned to the caller.
const maxErrorBodyBytes = 4 << 10

// NewHTTPClient creates a client for baseURL with a 30 second default timeout.
// The optional headers are copied, so later changes to the caller's map do not
// affect the client. A nil map is allowed.
func NewHTTPClient(baseURL string, headers map[string]string) *HTTPClient {
	client := &HTTPClient{
		baseURL:    baseURL,
		headers:    make(map[string]string, len(headers)),
		httpClient: &http.Client{Timeout: 30 * time.Second},
	}
	// Ranging over a nil map is a no-op, so no nil check is needed.
	for k, v := range headers {
		client.headers[k] = v
	}
	return client
}

// SetHeader sets a default header that is sent with every request.
func (c *HTTPClient) SetHeader(key, value string) {
	c.headers[key] = value
}

// SetHeaders merges headers into the default headers, overwriting existing keys.
func (c *HTTPClient) SetHeaders(headers map[string]string) {
	for k, v := range headers {
		c.headers[k] = v
	}
}

// RemoveHeader removes a default header; removing an unknown key is a no-op.
func (c *HTTPClient) RemoveHeader(key string) {
	delete(c.headers, key)
}

// Get performs a GET request against baseURL+path and returns the response body.
func (c *HTTPClient) Get(path string) ([]byte, error) {
	return c.request("GET", path, nil)
}

// Post performs a POST request with data encoded as JSON.
func (c *HTTPClient) Post(path string, data interface{}) ([]byte, error) {
	return c.request("POST", path, data)
}

// Put performs a PUT request with data encoded as JSON.
func (c *HTTPClient) Put(path string, data interface{}) ([]byte, error) {
	return c.request("PUT", path, data)
}

// Delete performs a DELETE request without a body.
func (c *HTTPClient) Delete(path string) ([]byte, error) {
	return c.request("DELETE", path, nil)
}

// Patch performs a PATCH request with data encoded as JSON.
func (c *HTTPClient) Patch(path string, data interface{}) ([]byte, error) {
	return c.request("PATCH", path, data)
}

// RequestWithHeaders performs a request with additional headers that apply to
// this request only; they override default headers with the same name.
func (c *HTTPClient) RequestWithHeaders(method, path string, data interface{}, additionalHeaders map[string]string) ([]byte, error) {
	return c.requestWithHeaders(method, path, data, additionalHeaders)
}

// request performs a request using only the default headers.
func (c *HTTPClient) request(method, path string, data interface{}) ([]byte, error) {
	return c.requestWithHeaders(method, path, data, nil)
}

// requestWithHeaders builds and sends one request and returns the raw response
// body.
//
// A non-nil data value is marshalled to JSON and sent as the body; in that case
// Content-Type defaults to application/json unless a header already set it.
// Any status outside 2xx is returned as an error that includes a bounded
// excerpt of the response body, so upstream error details are not lost.
//
// The overall deadline is governed solely by the http.Client timeout (see
// SetTimeout). A second, hard-coded context deadline used to cap every request
// at 30 seconds regardless of SetTimeout; it has been removed.
func (c *HTTPClient) requestWithHeaders(method, path string, data interface{}, additionalHeaders map[string]string) ([]byte, error) {
	target := c.baseURL + path

	var body io.Reader
	if data != nil {
		jsonData, err := json.Marshal(data)
		if err != nil {
			return nil, fmt.Errorf("failed to marshal request data: %w", err)
		}
		body = bytes.NewReader(jsonData)
	}

	req, err := http.NewRequestWithContext(context.Background(), method, target, body)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	// Default headers first, then per-request headers so the latter win.
	for k, v := range c.headers {
		req.Header.Set(k, v)
	}
	for k, v := range additionalHeaders {
		req.Header.Set(k, v)
	}
	if data != nil && req.Header.Get("Content-Type") == "" {
		req.Header.Set("Content-Type", "application/json")
	}

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Best effort: the status code is the primary signal, so a failure
		// while reading the error body is deliberately ignored.
		excerpt, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
		if msg := bytes.TrimSpace(excerpt); len(msg) > 0 {
			return nil, fmt.Errorf("HTTP error %d: %s", resp.StatusCode, msg)
		}
		return nil, fmt.Errorf("HTTP error %d", resp.StatusCode)
	}

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}
	return respBody, nil
}

// SetTimeout sets the overall timeout applied to each request; zero disables
// the timeout, following http.Client semantics.
func (c *HTTPClient) SetTimeout(timeout time.Duration) {
	c.httpClient.Timeout = timeout
}

// GetBaseURL returns the base URL that request paths are appended to.
func (c *HTTPClient) GetBaseURL() string {
	return c.baseURL
}

// SetBaseURL replaces the base URL used for subsequent requests.
func (c *HTTPClient) SetBaseURL(baseURL string) {
	c.baseURL = baseURL
}

View File

@@ -0,0 +1,53 @@
package config
// Config is the root configuration of the RAG MCP server. Each field is one
// configuration section, decoded from the JSON/YAML key named in its struct tag.
type Config struct {
// RAG holds the splitter settings and the retrieval parameters.
RAG RAGConfig `json:"rag" yaml:"rag"`
// LLM configures the chat-completion provider.
LLM LLMConfig `json:"llm" yaml:"llm"`
// Embedding configures the embedding provider.
Embedding EmbeddingConfig `json:"embedding" yaml:"embedding"`
// VectorDB configures the vector store connection.
VectorDB VectorDBConfig `json:"vectordb" yaml:"vectordb"`
}
// RAGConfig contains basic configuration for the RAG system: how text is split
// into chunks and the default retrieval parameters.
type RAGConfig struct {
// Splitter selects and tunes the text splitter.
Splitter SplitterConfig `json:"splitter" yaml:"splitter"`
// Threshold is the search threshold handed to the vector store search.
// Omitted from the encoded output when zero.
Threshold float64 `json:"threshold,omitempty" yaml:"threshold,omitempty"`
// TopK is the number of chunks requested from the vector store search.
// Omitted from the encoded output when zero.
TopK int `json:"top_k,omitempty" yaml:"top_k,omitempty"`
}
// SplitterConfig defines document splitter configuration.
type SplitterConfig struct {
Provider string `json:"provider" yaml:"provider"` // Splitter implementation name; the README documents "recursive" and "nosplitter" (NOTE(review): confirm against the textsplitter package)
// ChunkSize and ChunkOverlap are handed to the splitter.
// NOTE(review): the unit (characters vs. tokens) is not visible here — confirm.
ChunkSize int `json:"chunk_size,omitempty" yaml:"chunk_size,omitempty"`
ChunkOverlap int `json:"chunk_overlap,omitempty" yaml:"chunk_overlap,omitempty"`
}
// LLMConfig defines configuration for Large Language Models.
type LLMConfig struct {
Provider string `json:"provider" yaml:"provider"` // Provider key looked up in the llm package registry; "openai" is the documented option
// APIKey is the credential for the provider's API.
APIKey string `json:"api_key,omitempty" yaml:"api_key"`
// BaseURL overrides the provider's default API endpoint when non-empty.
BaseURL string `json:"base_url,omitempty" yaml:"base_url,omitempty"`
// Model is the model name sent with each completion request.
Model string `json:"model" yaml:"model"`
// Temperature and MaxTokens are optional sampling parameters.
Temperature float64 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
MaxTokens int `json:"max_tokens,omitempty" yaml:"max_tokens,omitempty"`
}
// EmbeddingConfig defines configuration for embedding models.
type EmbeddingConfig struct {
Provider string `json:"provider" yaml:"provider"` // Available options: openai, dashscope
// APIKey is the credential for the embedding API.
APIKey string `json:"api_key,omitempty" yaml:"api_key,omitempty"`
// BaseURL overrides the provider's default endpoint when non-empty.
BaseURL string `json:"base_url,omitempty" yaml:"base_url,omitempty"`
// Model is the embedding model name; providers fall back to their own default when empty.
Model string `json:"model,omitempty" yaml:"model,omitempty"`
// Dimension is the embedding vector size.
// NOTE(review): no reader of this field is visible in this chunk — confirm whether it is still used.
Dimension int `json:"dimension,omitempty" yaml:"dimension,omitempty"`
}
// VectorDBConfig defines configuration for vector databases.
type VectorDBConfig struct {
Provider string `json:"provider" yaml:"provider"` // Vector store backend; the README documents "milvus" (NOTE(review): confirm other values against the vectordb package)
// Host and Port locate the vector database server.
Host string `json:"host,omitempty" yaml:"host,omitempty"`
Port int `json:"port,omitempty" yaml:"port,omitempty"`
// Database and Collection select where documents are stored.
Database string `json:"database,omitempty" yaml:"database,omitempty"`
Collection string `json:"collection,omitempty" yaml:"collection,omitempty"`
// Username and Password are optional credentials.
Username string `json:"username,omitempty" yaml:"username,omitempty"`
Password string `json:"password,omitempty" yaml:"password,omitempty"`
}

View File

@@ -0,0 +1,169 @@
package embedding
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/common"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
)
// Defaults for the DashScope text-embedding API.
const (
// DASHSCOPE_DOMAIN is the host used to build the client's HTTPS base URL.
DASHSCOPE_DOMAIN = "dashscope.aliyuncs.com"
// DASHSCOPE_PORT is not referenced in this file.
// NOTE(review): presumably informational only — confirm it can be dropped.
DASHSCOPE_PORT = 443
// DASHSCOPE_DEFAULT_MODEL_NAME is used when the provider config has no model set.
DASHSCOPE_DEFAULT_MODEL_NAME = "text-embedding-v4"
// DASHSCOPE_ENDPOINT is the request path posted to for embeddings.
DASHSCOPE_ENDPOINT = "/api/v1/services/embeddings/text-embedding/text-embedding"
)
// dashScopeConfig is package-level state written by InitConfig and read by
// ValidateConfig and CreateProvider. Every CreateProvider call overwrites it.
var dashScopeConfig dashScopeProviderConfig
// dashScopeProviderInitializer creates DashScopeProvider instances; it is the
// initializer registered for the DashScope provider type.
type dashScopeProviderInitializer struct {
}
// dashScopeProviderConfig is the subset of the embedding configuration that
// the DashScope provider uses.
type dashScopeProviderConfig struct {
apiKey string
model string
}
// InitConfig replaces the package-level dashScopeConfig with the API key and
// model taken from config.
func (c *dashScopeProviderInitializer) InitConfig(config config.EmbeddingConfig) {
	dashScopeConfig = dashScopeProviderConfig{
		apiKey: config.APIKey,
		model:  config.Model,
	}
}
// ValidateConfig reports an error when the package-level dashScopeConfig has
// no API key; it must run after InitConfig.
func (c *dashScopeProviderInitializer) ValidateConfig() error {
	if dashScopeConfig.apiKey != "" {
		return nil
	}
	return errors.New("[DashScope] apiKey is required")
}
// CreateProvider loads config into the package-level dashScopeConfig,
// validates it, and returns a DashScopeProvider holding a copy of that config
// and an HTTP client preconfigured with the bearer token and JSON content type.
func (c *dashScopeProviderInitializer) CreateProvider(config config.EmbeddingConfig) (Provider, error) {
	c.InitConfig(config)
	if err := c.ValidateConfig(); err != nil {
		return nil, err
	}

	provider := &DashScopeProvider{config: dashScopeConfig}
	provider.client = common.NewHTTPClient(
		"https://"+DASHSCOPE_DOMAIN,
		map[string]string{
			"Authorization": "Bearer " + config.APIKey,
			"Content-Type":  "application/json",
		},
	)
	return provider, nil
}
// GetProviderType returns PROVIDER_TYPE_DASHSCOPE, the identifier of this
// embedding provider.
func (d *DashScopeProvider) GetProviderType() string {
return PROVIDER_TYPE_DASHSCOPE
}
// Embedding is one entry of the "embeddings" array in a DashScope response.
type Embedding struct {
Embedding []float32 `json:"embedding"`
// TextIndex is decoded from "text_index".
// NOTE(review): presumably the position of the source text in the request — confirm against the DashScope API.
TextIndex int `json:"text_index"`
}
// Input is the "input" object of an embedding request and carries the texts to embed.
type Input struct {
Texts []string `json:"texts"`
}
// Params is the "parameters" object of an embedding request.
type Params struct {
TextType string `json:"text_type"`
}
// Response is the decoded body of a DashScope embedding response.
type Response struct {
RequestID string `json:"request_id"`
Output Output `json:"output"`
Usage Usage `json:"usage"`
}
// Output wraps the embeddings returned in a Response.
type Output struct {
Embeddings []Embedding `json:"embeddings"`
}
// Usage carries the token count reported in a Response.
type Usage struct {
TotalTokens int `json:"total_tokens"`
}
// EmbeddingRequest is the JSON body posted to DASHSCOPE_ENDPOINT.
type EmbeddingRequest struct {
Model string `json:"model"`
Input Input `json:"input"`
Parameters Params `json:"parameters"`
}
// Document pairs a vector with string fields.
// NOTE(review): not referenced anywhere in this file — confirm whether it is still needed.
type Document struct {
Vector []float64 `json:"vector"`
Fields map[string]string `json:"fields"`
}
// DashScopeProvider computes embeddings through the DashScope HTTP API.
type DashScopeProvider struct {
// config is the copy of the provider configuration taken when the provider was created.
config dashScopeProviderConfig
// client sends requests to the DashScope base URL with the auth headers preset.
client *common.HTTPClient
}
// constructRequestData builds the embedding request body for texts.
//
// The model falls back to DASHSCOPE_DEFAULT_MODEL_NAME when the provider has
// none configured. An error is returned when the provider has no API key.
func (d *DashScopeProvider) constructRequestData(texts []string) (EmbeddingRequest, error) {
	// Check the key captured by this provider rather than the package-level
	// dashScopeConfig: that global is overwritten by every later
	// CreateProvider call, including ones that fail validation, which would
	// otherwise break an already-created provider.
	if d.config.apiKey == "" {
		return EmbeddingRequest{}, errors.New("dashScopeKey is empty")
	}

	model := d.config.model
	if model == "" {
		model = DASHSCOPE_DEFAULT_MODEL_NAME
	}

	return EmbeddingRequest{
		Model: model,
		Input: Input{Texts: texts},
		// NOTE(review): text_type is always "query", for stored chunks as
		// well as search queries — confirm this is intended.
		Parameters: Params{TextType: "query"},
	}, nil
}
// Result describes a scored record with an optional vector and arbitrary fields.
// NOTE(review): not referenced anywhere in this file — confirm whether it is still needed.
type Result struct {
ID string `json:"id"`
Vector []float32 `json:"vector,omitempty"`
Fields map[string]interface{} `json:"fields"`
Score float64 `json:"score"`
}
// parseTextEmbedding decodes a DashScope response body into a Response.
// The JSON decoding error is returned unchanged.
func (d *DashScopeProvider) parseTextEmbedding(responseBody []byte) (*Response, error) {
	parsed := new(Response)
	if err := json.Unmarshal(responseBody, parsed); err != nil {
		return nil, err
	}
	return parsed, nil
}
// GetEmbedding requests an embedding for queryString from DashScope and
// returns the first vector of the response.
//
// ctx is currently not forwarded to the HTTP call: the Post method used here
// takes no context, so cancelling ctx has no effect on the request.
//
// Errors from each stage are wrapped with %w so callers can inspect the
// underlying cause with errors.Is / errors.As. An error is also returned when
// the response contains no embeddings.
func (d *DashScopeProvider) GetEmbedding(
	ctx context.Context,
	queryString string) ([]float32, error) {
	requestData, err := d.constructRequestData([]string{queryString})
	if err != nil {
		return nil, fmt.Errorf("failed to construct request data: %w", err)
	}

	responseBody, err := d.client.Post(DASHSCOPE_ENDPOINT, requestData)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}

	embeddingResp, err := d.parseTextEmbedding(responseBody)
	if err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	if len(embeddingResp.Output.Embeddings) == 0 {
		return nil, errors.New("no embedding found in response")
	}
	return embeddingResp.Output.Embeddings[0].Embedding, nil
}

View File

@@ -0,0 +1,161 @@
package embedding
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/common"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
)
// Defaults for the OpenAI embeddings API.
const (
// OPENAI_DOMAIN is the host used to build the default HTTPS base URL.
OPENAI_DOMAIN = "api.openai.com"
// OPENAI_PORT is not referenced in this file.
// NOTE(review): presumably informational only — confirm it can be dropped.
OPENAI_PORT = 443
// OPENAI_DEFAULT_MODEL_NAME is used when no model is configured.
OPENAI_DEFAULT_MODEL_NAME = "text-embedding-3-small"
// OPENAI_ENDPOINT is the request path posted to for embeddings.
// NOTE(review): it already contains "/v1", so a configured base_url that also
// ends in "/v1" (as in the README sample) would presumably yield
// ".../v1/v1/embeddings" — confirm how the client joins base URL and path.
OPENAI_ENDPOINT = "/v1/embeddings"
)
// openAIProviderInitializer creates OpenAIProvider instances; it is the
// initializer registered for the OpenAI provider type.
type openAIProviderInitializer struct {
}
// openAIConfig is package-level state written by InitConfig and CreateProvider.
// Every CreateProvider call overwrites it.
var openAIConfig openAIProviderConfig
// openAIProviderConfig is the subset of the embedding configuration that the
// OpenAI provider uses.
type openAIProviderConfig struct {
baseUrl string
apiKey string
model string
}
// InitConfig replaces the package-level openAIConfig with the base URL, API
// key and model taken from config.
func (c *openAIProviderInitializer) InitConfig(config config.EmbeddingConfig) {
	openAIConfig = openAIProviderConfig{
		baseUrl: config.BaseURL,
		apiKey:  config.APIKey,
		model:   config.Model,
	}
}
// ValidateConfig reports an error when the package-level openAIConfig has no
// API key; it must run after InitConfig.
func (c *openAIProviderInitializer) ValidateConfig() error {
	if openAIConfig.apiKey != "" {
		return nil
	}
	return errors.New("[openAI] apiKey is required")
}
// CreateProvider loads config into the package-level openAIConfig, validates
// it, fills in the default model and base URL when they are empty, and returns
// an OpenAIProvider with a copy of that config and an HTTP client carrying the
// bearer token and JSON content type.
func (c *openAIProviderInitializer) CreateProvider(config config.EmbeddingConfig) (Provider, error) {
	c.InitConfig(config)
	if err := c.ValidateConfig(); err != nil {
		return nil, err
	}

	// Defaults are written back to the package-level config so the provider's
	// copy below already contains them.
	if openAIConfig.model == "" {
		openAIConfig.model = OPENAI_DEFAULT_MODEL_NAME
	}
	if openAIConfig.baseUrl == "" {
		openAIConfig.baseUrl = "https://" + OPENAI_DOMAIN
	}

	authHeaders := map[string]string{
		"Authorization": "Bearer " + config.APIKey,
		"Content-Type":  "application/json",
	}
	return &OpenAIProvider{
		config: openAIConfig,
		client: common.NewHTTPClient(openAIConfig.baseUrl, authHeaders),
	}, nil
}
// GetProviderType returns PROVIDER_TYPE_OPENAI, the identifier of this
// embedding provider.
func (o *OpenAIProvider) GetProviderType() string {
return PROVIDER_TYPE_OPENAI
}
// OpenAIResponse is the decoded body of an embeddings response.
type OpenAIResponse struct {
Object string `json:"object"`
// Data holds the returned embeddings; GetEmbedding uses the first entry.
Data []OpenAIResult `json:"data"`
Model string `json:"model"`
// Error is non-nil when the body contains an "error" object.
Error *OpenAIError `json:"error"`
}
// OpenAIResult is one embedding entry of OpenAIResponse.Data.
type OpenAIResult struct {
Object string `json:"object"`
Embedding []float32 `json:"embedding"`
Index int `json:"index"`
}
// OpenAIError mirrors the "error" object of an OpenAI API response.
//
// Message is decoded from the "message" key. It was previously tagged
// "prompt_tokens", which left Message empty in every reported API error.
type OpenAIError struct {
	Message string `json:"message"`
	Type    string `json:"type"`
	Code    string `json:"code"`
	Param   string `json:"param"`
}
// OpenAIEmbeddingRequest is the JSON body posted to OPENAI_ENDPOINT; it embeds
// a single input string with the given model.
type OpenAIEmbeddingRequest struct {
Input string `json:"input"`
Model string `json:"model"`
}
// OpenAIProvider computes embeddings through the OpenAI HTTP API.
type OpenAIProvider struct {
// config is the copy of the provider configuration taken when the provider was created.
config openAIProviderConfig
// client sends requests to the configured base URL with the auth headers preset.
client *common.HTTPClient
}
// constructRequestData builds the embeddings request body for text.
//
// It returns an error when text is empty or when the provider has no API key.
// The model falls back to OPENAI_DEFAULT_MODEL_NAME when none is configured.
func (o *OpenAIProvider) constructRequestData(text string) (OpenAIEmbeddingRequest, error) {
	if text == "" {
		return OpenAIEmbeddingRequest{}, errors.New("queryString text cannot be empty")
	}
	// Check the key captured by this provider rather than the package-level
	// openAIConfig: that global is overwritten by every later CreateProvider
	// call, including ones that fail validation, which would otherwise break
	// an already-created provider.
	if o.config.apiKey == "" {
		return OpenAIEmbeddingRequest{}, errors.New("openAI apiKey is empty")
	}

	model := o.config.model
	if model == "" {
		model = OPENAI_DEFAULT_MODEL_NAME
	}
	return OpenAIEmbeddingRequest{Input: text, Model: model}, nil
}
// parseTextEmbedding decodes an embeddings response body into an
// OpenAIResponse. The JSON decoding error is returned unchanged.
func (o *OpenAIProvider) parseTextEmbedding(responseBody []byte) (*OpenAIResponse, error) {
	parsed := new(OpenAIResponse)
	if err := json.Unmarshal(responseBody, parsed); err != nil {
		return nil, err
	}
	return parsed, nil
}
// GetEmbedding requests an embedding for queryString and returns the first
// vector of the response.
//
// ctx is currently not forwarded to the HTTP call: the Post method used here
// takes no context, so cancelling ctx has no effect on the request.
//
// Errors from each stage are wrapped with %w so callers can inspect the
// underlying cause with errors.Is / errors.As. An error is also returned when
// the response carries an "error" object or contains no embeddings.
func (o *OpenAIProvider) GetEmbedding(ctx context.Context, queryString string) ([]float32, error) {
	requestData, err := o.constructRequestData(queryString)
	if err != nil {
		return nil, fmt.Errorf("failed to construct request data: %w", err)
	}

	responseBody, err := o.client.Post(OPENAI_ENDPOINT, requestData)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %w", err)
	}

	resp, err := o.parseTextEmbedding(responseBody)
	if err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	if resp.Error != nil {
		return nil, fmt.Errorf("OpenAI API error: %s - %s", resp.Error.Type, resp.Error.Message)
	}
	if len(resp.Data) == 0 {
		return nil, errors.New("no embedding found in response")
	}
	return resp.Data[0].Embedding, nil
}

View File

@@ -0,0 +1,61 @@
package embedding
import (
"context"
"fmt"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
)
// Provider type constants for different embedding services.
// These are the values matched against the configured embedding provider name.
// Only PROVIDER_TYPE_DASHSCOPE and PROVIDER_TYPE_OPENAI have an initializer
// registered in providerInitializers; the others are declared but not wired up
// in this file.
const (
// DashScope embedding service
PROVIDER_TYPE_DASHSCOPE = "dashscope"
// TextIn embedding service
PROVIDER_TYPE_TEXTIN = "textin"
// Cohere embedding service
PROVIDER_TYPE_COHERE = "cohere"
// OpenAI embedding service
PROVIDER_TYPE_OPENAI = "openai"
// Ollama embedding service
PROVIDER_TYPE_OLLAMA = "ollama"
// HuggingFace embedding service
PROVIDER_TYPE_HUGGINGFACE = "huggingface"
// XFYun embedding service
PROVIDER_TYPE_XFYUN = "xfyun"
// Azure embedding service
PROVIDER_TYPE_AZURE = "azure"
)
// providerInitializer is the factory interface for creating Provider instances.
type providerInitializer interface {
// CreateProvider builds a Provider from the given embedding configuration,
// or returns an error when the provider cannot be created.
CreateProvider(config.EmbeddingConfig) (Provider, error)
}
// providerInitializers maps provider type names to their initializers.
// NewEmbeddingProvider looks up the configured provider name here.
var (
providerInitializers = map[string]providerInitializer{
PROVIDER_TYPE_DASHSCOPE: &dashScopeProviderInitializer{},
PROVIDER_TYPE_OPENAI: &openAIProviderInitializer{},
}
)
// Provider defines the interface for embedding services.
type Provider interface {
// GetProviderType returns the provider type identifier.
GetProviderType() string
// GetEmbedding generates the embedding vector for queryString and returns
// it as a float32 slice.
GetEmbedding(ctx context.Context, queryString string) ([]float32, error)
}
// NewEmbeddingProvider returns the embedding Provider registered for
// config.Provider, built from config. It returns an error when no initializer
// is registered under that name, or when the initializer itself fails.
func NewEmbeddingProvider(config config.EmbeddingConfig) (Provider, error) {
	if initializer, found := providerInitializers[config.Provider]; found {
		return initializer.CreateProvider(config)
	}
	return nil, fmt.Errorf("no initializer found for provider type: %s", config.Provider)
}

View File

@@ -0,0 +1,136 @@
package llm
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/common"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
)
// Defaults for the OpenAI chat-completions API.
const (
// OPENAI_CHAT_ENDPOINT is the request path posted to for completions.
OPENAI_CHAT_ENDPOINT = "/chat/completions"
// OPENAI_DEFAULT_MODEL is used when the LLM configuration names no model.
OPENAI_DEFAULT_MODEL = "gpt-3.5-turbo"
)
// openAIProviderConfig is the OpenAI-specific configuration captured after
// initialization, with defaults for model and base URL already applied.
type openAIProviderConfig struct {
apiKey string
baseURL string
model string
maxTokens int
temperature float64
}
// openAIProviderInitializer creates OpenAIProvider instances; it is the
// initializer registered for the OpenAI provider type.
type openAIProviderInitializer struct{}
// openAIConfig is package-level state written by initConfig and read by
// validateConfig and CreateProvider. Every CreateProvider call overwrites it.
var openAIConfig openAIProviderConfig
// initConfig replaces the package-level openAIConfig with the values from c,
// substituting OPENAI_DEFAULT_MODEL for an empty model and the public OpenAI
// endpoint for an empty base URL.
func (i *openAIProviderInitializer) initConfig(c config.LLMConfig) {
	resolved := openAIProviderConfig{
		apiKey:      c.APIKey,
		baseURL:     c.BaseURL,
		model:       c.Model,
		maxTokens:   c.MaxTokens,
		temperature: c.Temperature,
	}
	if resolved.model == "" {
		resolved.model = OPENAI_DEFAULT_MODEL
	}
	if resolved.baseURL == "" {
		resolved.baseURL = "https://api.openai.com/v1" // default public endpoint
	}
	openAIConfig = resolved
}
// validateConfig reports an error when the package-level openAIConfig has no
// API key; it must run after initConfig.
func (i *openAIProviderInitializer) validateConfig() error {
	if openAIConfig.apiKey != "" {
		return nil
	}
	return errors.New("[openai llm] apiKey is required")
}
// CreateProvider loads cfg into the package-level openAIConfig, validates it,
// and returns an OpenAIProvider holding a copy of that config and an HTTP
// client carrying the bearer token and JSON content type.
func (i *openAIProviderInitializer) CreateProvider(cfg config.LLMConfig) (Provider, error) {
	i.initConfig(cfg)
	err := i.validateConfig()
	if err != nil {
		return nil, err
	}

	provider := &OpenAIProvider{cfg: openAIConfig}
	provider.client = common.NewHTTPClient(openAIConfig.baseURL, map[string]string{
		"Authorization": "Bearer " + openAIConfig.apiKey,
		"Content-Type":  "application/json",
	})
	return provider, nil
}
// OpenAIProvider generates completions through an OpenAI-style
// chat-completions HTTP API.
type OpenAIProvider struct {
// client sends requests to the configured base URL with the auth headers preset.
client *common.HTTPClient
// cfg is the copy of the provider configuration taken when the provider was created.
cfg openAIProviderConfig
}
// openAIChatCompletionRequest is the JSON body posted to OPENAI_CHAT_ENDPOINT.
type openAIChatCompletionRequest struct {
Model string `json:"model"`
Messages []openAIChatMessage `json:"messages"`
// Temperature and MaxTokens use omitempty, so a zero value is left out of
// the request body rather than being sent as 0.
Temperature float64 `json:"temperature,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
}
// openAIChatMessage is one chat message, used in both requests and responses.
type openAIChatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
// openAIChatCompletionResponse is the decoded body of a completion response.
type openAIChatCompletionResponse struct {
ID string `json:"id"`
Object string `json:"object"`
// Choices holds the generated candidates; GenerateCompletion returns the first.
Choices []openAIChatCompletionResponseChoice `json:"choices"`
// Error is non-nil when the body contains an "error" object.
Error *openAIError `json:"error,omitempty"`
}
// openAIChatCompletionResponseChoice is one generated candidate.
type openAIChatCompletionResponseChoice struct {
Index int `json:"index"`
Message openAIChatMessage `json:"message"`
FinishReason string `json:"finish_reason"`
}
// openAIError mirrors the "error" object of an API response.
type openAIError struct {
Message string `json:"message"`
Type string `json:"type"`
Code string `json:"code"`
Param string `json:"param"`
}
// GenerateCompletion implements the Provider interface. It sends prompt as a
// single "user" message using the configured model, temperature and token
// limit, and returns the content of the first choice.
//
// ctx is accepted for interface compatibility but is not passed on to the
// HTTP call made here.
func (o *OpenAIProvider) GenerateCompletion(ctx context.Context, prompt string) (string, error) {
	payload := openAIChatCompletionRequest{
		Model:       o.cfg.model,
		Messages:    []openAIChatMessage{{Role: "user", Content: prompt}},
		Temperature: o.cfg.temperature,
		MaxTokens:   o.cfg.maxTokens,
	}

	raw, err := o.client.Post(OPENAI_CHAT_ENDPOINT, payload)
	if err != nil {
		return "", fmt.Errorf("openai llm post error: %w", err)
	}

	var completion openAIChatCompletionResponse
	if err = json.Unmarshal(raw, &completion); err != nil {
		return "", fmt.Errorf("openai llm unmarshal error: %w", err)
	}

	switch {
	case completion.Error != nil:
		return "", fmt.Errorf("openai llm api error: %s - %s", completion.Error.Type, completion.Error.Message)
	case len(completion.Choices) == 0:
		return "", errors.New("openai llm: empty choices")
	}
	return completion.Choices[0].Message.Content, nil
}
// GetProviderType returns PROVIDER_TYPE_OPENAI, the identifier of this LLM provider.
func (o *OpenAIProvider) GetProviderType() string {
return PROVIDER_TYPE_OPENAI
}

View File

@@ -0,0 +1,24 @@
package llm
import (
"strings"
)
// RAGPromptTemplate is the prompt sent to the LLM for retrieval-augmented
// answers. BuildPrompt substitutes the {contexts} and {query} placeholders.
const RAGPromptTemplate = `You are a professional knowledge Q&A assistant. Your task is to provide accurate, complete, and strictly relevant answers based on the user's question and retrieved context.
Retrieved relevant context (may be empty, multiple segments separated by line breaks):
{contexts}
User question:
{query}
Requirements:
1. If the context provides sufficient information, answer directly based on the context. You may use domain knowledge to supplement, but do not fabricate facts beyond the context.
2. If the context is insufficient or unrelated to the question, respond with: "I am unable to answer this question."
3. Your response must correctly answer the user's question and must not contain any irrelevant or unrelated content.`

// BuildPrompt renders RAGPromptTemplate with query in place of {query} and the
// contexts, joined by join, in place of {contexts}.
//
// Both placeholders are substituted in a single pass over the template, so
// placeholder-like text inside the inserted values is kept literally. The
// previous two-step ReplaceAll rescanned the already-inserted query, which
// caused a query containing "{contexts}" to be expanded with the retrieved
// context.
func BuildPrompt(query string, contexts []string, join string) string {
	return strings.NewReplacer(
		"{query}", query,
		"{contexts}", strings.Join(contexts, join),
	).Replace(RAGPromptTemplate)
}

View File

@@ -0,0 +1,53 @@
package llm
import (
"context"
"fmt"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
)
// Provider type constants; these are the values matched against the
// configured LLM provider name.
const (
// OpenAI LLM provider
PROVIDER_TYPE_OPENAI = "openai"
// More providers can be added (e.g., Qwen)
)
// Provider defines the interface for LLM providers with a prompt-response
// pattern. It is meant to be extended later with chat-style and streaming
// features.
type Provider interface {
// GetProviderType returns the provider type used for registration and lookup.
GetProviderType() string
// GenerateCompletion generates a text response for the given prompt.
//
// ctx: For cancellation and timeout
// prompt: Input text
// Returns: Generated response and error if any
GenerateCompletion(ctx context.Context, prompt string) (string, error)
}
// providerInitializer is the factory interface for creating Provider instances.
type providerInitializer interface {
// CreateProvider builds a Provider from the given LLM configuration, or
// returns an error when the provider cannot be created.
CreateProvider(config.LLMConfig) (Provider, error)
}
// providerInitializers maps provider type names to their initializers.
// NewLLMProvider looks up the configured provider name here; only the OpenAI
// provider is registered.
var (
providerInitializers = map[string]providerInitializer{
PROVIDER_TYPE_OPENAI: &openAIProviderInitializer{},
}
)
// NewLLMProvider returns the LLM Provider registered for cfg.Provider, built
// from cfg. It returns an error when no initializer is registered under that
// name, or when the initializer itself fails.
func NewLLMProvider(cfg config.LLMConfig) (Provider, error) {
	if initializer, found := providerInitializers[cfg.Provider]; found {
		return initializer.CreateProvider(cfg)
	}
	return nil, fmt.Errorf("no initializer found for llm provider type: %s", cfg.Provider)
}

View File

@@ -0,0 +1,158 @@
package rag
import (
"context"
"fmt"
"strings"
"time"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/embedding"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/llm"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/schema"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/textsplitter"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/vectordb"
"github.com/distribution/distribution/v3/uuid"
)
// Row-count values used when listing stored data.
const (
// MAX_LIST_KNOWLEDGE_ROW_COUNT is not referenced in the visible part of this file.
// NOTE(review): confirm whether it is still needed.
MAX_LIST_KNOWLEDGE_ROW_COUNT = 1000
// MAX_LIST_DOCUMENT_ROW_COUNT is the value ListChunks passes to the vector store's ListDocs.
MAX_LIST_DOCUMENT_ROW_COUNT = 1000
)
// RAGClient represents the RAG (Retrieval-Augmented Generation) client. It
// ties together text splitting, embedding, vector storage and the LLM.
type RAGClient struct {
// config is the configuration the client was created from.
config *config.Config
// vectordbProvider stores, lists, deletes and searches document chunks.
vectordbProvider vectordb.VectorStoreProvider
// embeddingProvider turns text into embedding vectors.
embeddingProvider embedding.Provider
// textSplitter splits input text into chunks.
textSplitter textsplitter.TextSplitter
// llmProvider generates the answer in Chat.
llmProvider llm.Provider
}
// NewRAGClient builds a RAGClient from config.
//
// It creates, in order, the text splitter, the embedding provider and the LLM
// provider. It then embeds a probe string to learn the embedding dimension and
// creates the vector store provider with that dimension. The first failing
// step aborts construction and its error is returned wrapped.
func NewRAGClient(config *config.Config) (*RAGClient, error) {
	splitter, err := textsplitter.NewTextSplitter(&config.RAG.Splitter)
	if err != nil {
		return nil, fmt.Errorf("create text splitter failed, err: %w", err)
	}

	embedder, err := embedding.NewEmbeddingProvider(config.Embedding)
	if err != nil {
		return nil, fmt.Errorf("create embedding provider failed, err: %w", err)
	}

	generator, err := llm.NewLLMProvider(config.LLM)
	if err != nil {
		return nil, fmt.Errorf("create llm provider failed, err: %w", err)
	}

	// The length of one real embedding is used as the vector dimension.
	probe, err := embedder.GetEmbedding(context.Background(), "initialization")
	if err != nil {
		return nil, fmt.Errorf("create init embedding failed, err: %w", err)
	}

	store, err := vectordb.NewVectorDBProvider(&config.VectorDB, len(probe))
	if err != nil {
		return nil, fmt.Errorf("create vector store provider failed, err: %w", err)
	}

	return &RAGClient{
		config:            config,
		vectordbProvider:  store,
		embeddingProvider: embedder,
		textSplitter:      splitter,
		llmProvider:       generator,
	}, nil
}
// ListChunks returns the document chunks reported by the vector store's
// ListDocs, called with MAX_LIST_DOCUMENT_ROW_COUNT. A store error is returned
// wrapped.
func (r *RAGClient) ListChunks() ([]schema.Document, error) {
	chunks, err := r.vectordbProvider.ListDocs(context.Background(), MAX_LIST_DOCUMENT_ROW_COUNT)
	if err == nil {
		return chunks, nil
	}
	return nil, fmt.Errorf("list chunks failed, err: %w", err)
}
// DeleteChunk asks the vector store to delete the chunk with the given id.
// A store error is returned wrapped.
func (r *RAGClient) DeleteChunk(id string) error {
	err := r.vectordbProvider.DeleteDocs(context.Background(), []string{id})
	if err == nil {
		return nil
	}
	return fmt.Errorf("delete chunk failed, err: %w", err)
}
// CreateChunkFromText splits text into chunks, embeds each chunk, stores all
// of them in the vector store and returns the stored documents.
//
// Every chunk gets a fresh UUID, a creation timestamp and the metadata keys
// "chunk_index" (position in the split result), "chunk_title" (title) and
// "chunk_size" (len of the chunk content, i.e. bytes). The first failing step
// aborts the operation; nothing is added to the store unless every chunk was
// embedded successfully.
func (r *RAGClient) CreateChunkFromText(text string, title string) ([]schema.Document, error) {
	docs, err := textsplitter.CreateDocuments(r.textSplitter, []string{text}, make([]map[string]any, 0))
	if err != nil {
		return nil, fmt.Errorf("create documents failed, err: %w", err)
	}

	results := make([]schema.Document, 0, len(docs))
	for chunkIndex, doc := range docs {
		// Writing to a nil map panics; make sure the metadata map exists
		// regardless of what the splitter returned.
		if doc.Metadata == nil {
			doc.Metadata = make(map[string]any)
		}
		doc.ID = uuid.Generate().String()
		doc.Metadata["chunk_index"] = chunkIndex
		doc.Metadata["chunk_title"] = title
		doc.Metadata["chunk_size"] = len(doc.Content)

		// Generate the embedding for this chunk. The local is named vector so
		// it does not shadow the imported embedding package.
		vector, err := r.embeddingProvider.GetEmbedding(context.Background(), doc.Content)
		if err != nil {
			return nil, fmt.Errorf("create embedding failed, err: %w", err)
		}
		doc.Vector = vector
		doc.CreatedAt = time.Now()
		results = append(results, doc)
	}

	if err := r.vectordbProvider.AddDoc(context.Background(), results); err != nil {
		return nil, fmt.Errorf("add documents failed, err: %w", err)
	}
	return results, nil
}
// SearchChunks embeds query and asks the vector database for matching chunks.
// topK and threshold are passed to the provider unchanged via SearchOptions.
// Embedding and search failures are returned wrapped.
func (r *RAGClient) SearchChunks(query string, topK int, threshold float64) ([]schema.SearchResult, error) {
	ctx := context.Background()

	queryVector, embedErr := r.embeddingProvider.GetEmbedding(ctx, query)
	if embedErr != nil {
		return nil, fmt.Errorf("create embedding failed, err: %w", embedErr)
	}

	searchOpts := schema.SearchOptions{TopK: topK, Threshold: threshold}
	hits, searchErr := r.vectordbProvider.SearchDocs(ctx, queryVector, &searchOpts)
	if searchErr != nil {
		return nil, fmt.Errorf("search chunks failed, err: %w", searchErr)
	}
	return hits, nil
}
// Chat answers query with the configured LLM, grounding the prompt in chunks
// retrieved through SearchChunks using the configured TopK and Threshold.
// Newlines inside each retrieved chunk are replaced by spaces before the
// chunks are handed to llm.BuildPrompt. It fails if no LLM provider is set,
// or if retrieval or completion fails.
func (r *RAGClient) Chat(query string) (string, error) {
	if r.llmProvider == nil {
		return "", fmt.Errorf("llm provider not initialized")
	}

	hits, err := r.SearchChunks(query, r.config.RAG.TopK, r.config.RAG.Threshold)
	if err != nil {
		return "", fmt.Errorf("search chunks failed, err: %w", err)
	}

	passages := make([]string, len(hits))
	for i := range hits {
		passages[i] = strings.ReplaceAll(hits[i].Document.Content, "\n", " ")
	}

	prompt := llm.BuildPrompt(query, passages, "\n\n")
	answer, err := r.llmProvider.GenerateCompletion(context.Background(), prompt)
	if err != nil {
		return "", fmt.Errorf("generate completion failed, err: %w", err)
	}
	return answer, nil
}

View File

@@ -0,0 +1,150 @@
package rag
import (
"testing"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
)
// getRAGClient builds a RAGClient from a fixed configuration shared by the
// tests in this file: a recursive splitter, an OpenAI-compatible LLM, a
// DashScope embedding model and a Milvus instance on localhost:19530. The API
// keys are placeholders.
func getRAGClient() (*RAGClient, error) {
	ragCfg := config.RAGConfig{
		Splitter: config.SplitterConfig{
			Provider:     "recursive",
			ChunkSize:    200,
			ChunkOverlap: 20,
		},
		Threshold: 0.5,
		TopK:      10,
	}
	llmCfg := config.LLMConfig{
		Provider: "openai",
		APIKey:   "sk-xxxx",
		BaseURL:  "https://openrouter.ai/api/v1",
		Model:    "openai/gpt-4o",
	}
	embeddingCfg := config.EmbeddingConfig{
		Provider: "dashscope",
		APIKey:   "sk-xxxx",
		Model:    "text-embedding-v4",
	}
	vectorDBCfg := config.VectorDBConfig{
		Provider:   "milvus",
		Host:       "localhost",
		Port:       19530,
		Database:   "default",
		Collection: "test_collection",
	}

	client, err := NewRAGClient(&config.Config{
		RAG:       ragCfg,
		LLM:       llmCfg,
		Embedding: embeddingCfg,
		VectorDB:  vectorDBCfg,
	})
	if err != nil {
		return nil, err
	}
	return client, nil
}
// TestNewRAGClient checks that a client can be constructed from the shared
// test configuration.
func TestNewRAGClient(t *testing.T) {
	if _, err := getRAGClient(); err != nil {
		t.Fatalf("getRAGClient() error = %v", err)
	}
}
// TestRAGClient_CreateChunkFromText ingests one paragraph and expects it to
// be stored as exactly one chunk.
func TestRAGClient_CreateChunkFromText(t *testing.T) {
	client, err := getRAGClient()
	if err != nil {
		t.Fatalf("getRAGClient() error = %v", err)
	}

	const (
		text      = "The multi-agent interaction technology competition based on the openKylin desktop environment aims to promote the development of agent applications on the openKylin open-source OS, using the Kirin AI inference framework and the UKUI desktop environment. These applications should have autonomous planning and decision-making capabilities, access to system resources, and the ability to call system and desktop environment interfaces and tools, with memory functions. They should also be able to collaborate with other agent applications. The competition aims to deeply explore the integration of operating systems and AI and help enhance the international competitiveness of domestic open-source operating systems."
		chunkName = "test_chunk3"
	)

	chunks, err := client.CreateChunkFromText(text, chunkName)
	if err != nil {
		t.Fatalf("CreateChunkFromText() error = %v", err)
	}
	if got := len(chunks); got != 1 {
		t.Fatalf("CreateChunkFromText() docs len = %d, want 1", got)
	}
}
// TestRAGClient_ListChunks expects the test collection to contain at least
// one chunk.
func TestRAGClient_ListChunks(t *testing.T) {
	client, err := getRAGClient()
	if err != nil {
		t.Fatalf("getRAGClient() error = %v", err)
	}

	chunks, err := client.ListChunks()
	if err != nil {
		t.Fatalf("ListChunks() error = %v", err)
	}
	if count := len(chunks); count == 0 {
		t.Fatalf("ListChunks() docs len = %d, want > 0", count)
	}
}
// TestRAGClient_DeleteChunk deletes a chunk by a hard-coded ID and only
// checks that the call returns no error.
func TestRAGClient_DeleteChunk(t *testing.T) {
	client, err := getRAGClient()
	if err != nil {
		t.Fatalf("getRAGClient() error = %v", err)
	}

	const chunkID = "63ee25d7-41b9-4455-8066-075ca5c803b2"
	if err := client.DeleteChunk(chunkID); err != nil {
		t.Fatalf("DeleteChunk() error = %v", err)
	}
}
// TestRAGClient_SearchChunks runs a similarity search and expects exactly
// topK results, so the collection must already hold enough matching chunks.
func TestRAGClient_SearchChunks(t *testing.T) {
	client, err := getRAGClient()
	if err != nil {
		t.Fatalf("getRAGClient() error = %v", err)
	}

	const (
		wantCount = 2
		minScore  = 0.5
		question  = "multi-agent"
	)

	hits, err := client.SearchChunks(question, wantCount, minScore)
	if err != nil {
		t.Fatalf("SearchChunks() error = %v", err)
	}
	if len(hits) != wantCount {
		t.Fatalf("SearchChunks() docs len = %d, want %d", len(hits), wantCount)
	}
}
// TestRAGClient_Chat asks one question through the full retrieval and
// completion path and expects a non-empty answer.
func TestRAGClient_Chat(t *testing.T) {
	client, err := getRAGClient()
	if err != nil {
		t.Fatalf("getRAGClient() error = %v", err)
	}

	answer, err := client.Chat("what is the competition about?")
	if err != nil {
		t.Fatalf("Chat() error = %v", err)
	}
	if answer == "" {
		t.Fatalf("Chat() resp = %s, want not empty", answer)
	}
}

View File

@@ -0,0 +1,46 @@
package schema
import "time"
// Default collection names. Their consumers are not visible in this file.
const (
// DEFAULT_KNOWLEDGE_COLLECTION is the default collection name for knowledge entries.
DEFAULT_KNOWLEDGE_COLLECTION = "knowledge"
// DEFAULT_DOCUMENT_COLLECTION is the default collection name for documents.
DEFAULT_DOCUMENT_COLLECTION = "document"
)
// Document is one piece of text together with its embedding and metadata.
type Document struct {
	// ID identifies the document.
	ID string `json:"id"`
	// Content is the document text.
	Content string `json:"content"`
	// Vector is the embedding of Content; it is never serialized to JSON.
	Vector []float32 `json:"-"`
	// Metadata carries arbitrary key/value attributes.
	Metadata map[string]any `json:"metadata"`
	// CreatedAt is the creation timestamp.
	CreatedAt time.Time `json:"created_at"`
}
// SearchResult pairs a document returned by a vector search with its score.
type SearchResult struct {
// Document is the matched document.
Document Document `json:"document"`
// Score is the score reported for the match; its scale and direction depend
// on the vector database provider (not visible here).
Score float64 `json:"score"`
}
// Knowledge describes a knowledge entity and the documents derived from it.
type Knowledge struct {
	// ID identifies the knowledge entity.
	ID string `json:"id"`
	// Name is the display name.
	Name string `json:"name"`
	// SourceURL records where the content came from.
	SourceURL string `json:"source_url"`
	// Status is a free-form processing status string.
	Status string `json:"status"`
	// FileSize is the size of the source file.
	FileSize int64 `json:"file_size"`
	// ChunkCount is the number of chunks.
	ChunkCount int `json:"chunk_count"`
	// EnableMultimodel is a feature flag; its consumers are not visible here.
	EnableMultimodel bool `json:"enable_multimodel"`
	// Metadata carries arbitrary key/value attributes.
	Metadata map[string]any `json:"metadata"`
	// Documents holds the associated documents; it is never serialized to JSON.
	Documents []Document `json:"-"`
	// CreatedAt is the creation timestamp.
	CreatedAt time.Time `json:"created_at"`
	// UpdatedAt is the last-update timestamp.
	UpdatedAt time.Time `json:"updated_at"`
	// CompletedAt is the completion timestamp.
	// NOTE(review): encoding/json never treats a struct value as empty, so
	// omitempty does not drop a zero time.Time here.
	CompletedAt time.Time `json:"completed_at,omitempty"`
}
// SearchOptions carries the parameters of a vector search.
type SearchOptions struct {
	// TopK is the requested number of results.
	TopK int `json:"top_k"`
	// Threshold is the score threshold handed to the provider.
	Threshold float64 `json:"threshold"`
	// Filters holds optional provider-specific filters; omitted from JSON when empty.
	Filters map[string]any `json:"filters,omitempty"`
}

View File

@@ -0,0 +1,198 @@
package rag
import (
"errors"
"fmt"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
"github.com/mark3labs/mcp-go/mcp"
)
// Version is the version string reported by the RAG MCP server.
const Version = "1.0.0"
// RAGConfig is the registry entry for the "rag" MCP server. It holds the
// configuration that ParseConfig fills in and NewServer consumes.
type RAGConfig struct {
// config is the effective RAG configuration (defaults plus parsed overrides).
config *config.Config
}
// init registers the "rag" server with the global MCP server registry and
// seeds it with default settings. ParseConfig overlays user-supplied values
// on top of these defaults.
func init() {
	common.GlobalRegistry.RegisterServer("rag", &RAGConfig{
		config: &config.Config{
			RAG: config.RAGConfig{
				Splitter: config.SplitterConfig{
					Provider:     "recursive",
					ChunkSize:    500,
					ChunkOverlap: 50,
				},
				Threshold: 0.5,
				TopK:      10,
			},
			LLM: config.LLMConfig{
				Provider:    "openai",
				APIKey:      "",
				BaseURL:     "",
				Model:       "gpt-4o",
				Temperature: 0.5,
				MaxTokens:   2048,
			},
			Embedding: config.EmbeddingConfig{
				Provider:  "dashscope",
				APIKey:    "",
				BaseURL:   "",
				Model:     "text-embedding-v4",
				Dimension: 1024,
			},
			VectorDB: config.VectorDBConfig{
				Provider: "milvus",
				Host:     "localhost",
				// Milvus listens on 19530 by default. The previous default of
				// 6379 is the Redis port and would never reach a stock Milvus
				// deployment when the user omits "port".
				Port:       19530,
				Database:   "default",
				Collection: "rag",
				Username:   "",
				Password:   "",
			},
		},
	})
}
// ParseConfig overlays the values found in config onto the receiver's current
// configuration. Four optional top-level sections are recognized: "rag",
// "embedding", "llm" and "vectordb". A key is applied only when it is present
// with the expected dynamic type (string, or float64 for numbers, which are
// truncated to int where the target field is an int); anything else is
// silently ignored.
//
// When an "embedding", "llm" or "vectordb" section is present it must contain
// a string "provider"; otherwise an error is returned. Sections applied
// before the failing one remain modified.
func (c *RAGConfig) ParseConfig(config map[string]any) error {
	// RAG section: splitter settings, similarity threshold and top-k.
	if ragSection, ok := config["rag"].(map[string]any); ok {
		if splitterSection, ok := ragSection["splitter"].(map[string]any); ok {
			if v, ok := splitterSection["provider"].(string); ok {
				c.config.RAG.Splitter.Provider = v
			}
			if v, ok := splitterSection["chunk_size"].(float64); ok {
				c.config.RAG.Splitter.ChunkSize = int(v)
			}
			if v, ok := splitterSection["chunk_overlap"].(float64); ok {
				c.config.RAG.Splitter.ChunkOverlap = int(v)
			}
		}
		if v, ok := ragSection["threshold"].(float64); ok {
			c.config.RAG.Threshold = v
		}
		if v, ok := ragSection["top_k"].(float64); ok {
			c.config.RAG.TopK = int(v)
		}
	}

	// Embedding section.
	if embeddingSection, ok := config["embedding"].(map[string]any); ok {
		provider, ok := embeddingSection["provider"].(string)
		if !ok {
			return errors.New("missing embedding provider")
		}
		c.config.Embedding.Provider = provider

		if v, ok := embeddingSection["api_key"].(string); ok {
			c.config.Embedding.APIKey = v
		}
		if v, ok := embeddingSection["base_url"].(string); ok {
			c.config.Embedding.BaseURL = v
		}
		if v, ok := embeddingSection["model"].(string); ok {
			c.config.Embedding.Model = v
		}
		if v, ok := embeddingSection["dimension"].(float64); ok {
			c.config.Embedding.Dimension = int(v)
		}
	}

	// LLM section.
	if llmSection, ok := config["llm"].(map[string]any); ok {
		provider, ok := llmSection["provider"].(string)
		if !ok {
			return errors.New("missing llm provider")
		}
		c.config.LLM.Provider = provider

		if v, ok := llmSection["api_key"].(string); ok {
			c.config.LLM.APIKey = v
		}
		if v, ok := llmSection["base_url"].(string); ok {
			c.config.LLM.BaseURL = v
		}
		if v, ok := llmSection["model"].(string); ok {
			c.config.LLM.Model = v
		}
		if v, ok := llmSection["temperature"].(float64); ok {
			c.config.LLM.Temperature = v
		}
		if v, ok := llmSection["max_tokens"].(float64); ok {
			c.config.LLM.MaxTokens = int(v)
		}
	}

	// Vector database section.
	if vectorDBSection, ok := config["vectordb"].(map[string]any); ok {
		provider, ok := vectorDBSection["provider"].(string)
		if !ok {
			return errors.New("missing vectordb provider")
		}
		c.config.VectorDB.Provider = provider

		if v, ok := vectorDBSection["host"].(string); ok {
			c.config.VectorDB.Host = v
		}
		if v, ok := vectorDBSection["port"].(float64); ok {
			c.config.VectorDB.Port = int(v)
		}
		if v, ok := vectorDBSection["database"].(string); ok {
			c.config.VectorDB.Database = v
		}
		if v, ok := vectorDBSection["collection"].(string); ok {
			c.config.VectorDB.Collection = v
		}
		if v, ok := vectorDBSection["username"].(string); ok {
			c.config.VectorDB.Username = v
		}
		if v, ok := vectorDBSection["password"].(string); ok {
			c.config.VectorDB.Password = v
		}
	}

	return nil
}
// NewServer creates the MCP server named serverName, builds a RAGClient from
// the receiver's configuration and registers the five RAG tools on the server:
// create-chunks-from-text, list-chunks, delete-chunk, search-chunks and chat.
// It fails if the RAG client cannot be created.
func (c *RAGConfig) NewServer(serverName string) (*common.MCPServer, error) {
	srv := common.NewMCPServer(
		serverName,
		Version,
		common.WithInstructions("This is a RAG (Retrieval-Augmented Generation) server for knowledge management and intelligent Q&A"),
	)

	// All tool handlers share one RAG client built from the parsed config.
	client, err := NewRAGClient(c.config)
	if err != nil {
		return nil, fmt.Errorf("create rag client failed, err: %w", err)
	}

	// Ingestion.
	createTool := mcp.NewToolWithRawSchema("create-chunks-from-text", "Process and segment input text into semantic chunks for knowledge base ingestion", GetCreateChunkFromTextSchema())
	srv.AddTool(createTool, HandleCreateChunkFromText(client))

	// Chunk management.
	listTool := mcp.NewToolWithRawSchema("list-chunks", "Retrieve and display all knowledge chunks in the database", GetListChunksSchema())
	srv.AddTool(listTool, HandleListChunks(client))

	deleteTool := mcp.NewToolWithRawSchema("delete-chunk", "Remove a specific knowledge chunk from the database using its unique identifier", GetDeleteChunkSchema())
	srv.AddTool(deleteTool, HandleDeleteChunk(client))

	// Semantic search.
	searchTool := mcp.NewToolWithRawSchema("search-chunks", "Perform semantic search across knowledge chunks using natural language query", GetSearchSchema())
	srv.AddTool(searchTool, HandleSearch(client))

	// Question answering.
	chatTool := mcp.NewToolWithRawSchema("chat", "Generate contextually relevant responses using RAG system with LLM integration", GetChatSchema())
	srv.AddTool(chatTool, HandleChat(client))

	return srv, nil
}

View File

@@ -0,0 +1,54 @@
package rag
import (
"fmt"
"testing"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
"gopkg.in/yaml.v3"
)
// TestRAGConfig_ParseConfig builds a sample configuration and prints it as
// YAML. It only checks that marshalling succeeds; ParseConfig itself is not
// called.
func TestRAGConfig_ParseConfig(t *testing.T) {
	cfg := &config.Config{
		RAG: config.RAGConfig{
			Splitter: config.SplitterConfig{
				Provider:     "nosplitter",
				ChunkSize:    500,
				ChunkOverlap: 50,
			},
			Threshold: 0.5,
			TopK:      5,
		},
		LLM: config.LLMConfig{
			Provider:    "openai",
			APIKey:      "sk-XXX",
			BaseURL:     "https://openrouter.ai/api/v1",
			Model:       "openai/gpt-4o",
			Temperature: 0.5,
			MaxTokens:   2048,
		},
		Embedding: config.EmbeddingConfig{
			Provider:  "dashscope",
			APIKey:    "sk-XXX",
			BaseURL:   "",
			Model:     "text-embedding-v4",
			Dimension: 1024,
		},
		VectorDB: config.VectorDBConfig{
			Provider:   "milvus",
			Host:       "localhost",
			Port:       19530,
			Database:   "default",
			Collection: "test_rag",
			Username:   "",
			Password:   "",
		},
	}

	// Render the config in YAML format.
	encoded, err := yaml.Marshal(cfg)
	if err != nil {
		t.Fatalf("marshal config failed, err: %v", err)
	}
	rendered := string(encoded)
	t.Logf("config yaml: %s", rendered)
	fmt.Printf("\n%s", rendered)
}

View File

@@ -0,0 +1,165 @@
package textsplitter
import "unicode/utf8"
// Default values used by DefaultOptions.
const (
// nolint:gosec
// _defaultTokenModelName is the default value of Options.ModelName.
_defaultTokenModelName = "gpt-3.5-turbo"
// _defaultTokenEncoding is the default value of Options.EncodingName.
_defaultTokenEncoding = "cl100k_base"
// _defaultTokenChunkSize is the default value of Options.ChunkSize.
_defaultTokenChunkSize = 512
// _defaultTokenChunkOverlap is the default value of Options.ChunkOverlap.
_defaultTokenChunkOverlap = 100
)
// Options is a struct that contains options for a text splitter.
// Not every splitter reads every field; RecursiveCharacter uses Separators,
// ChunkSize, ChunkOverlap, LenFunc and KeepSeparator.
type Options struct {
// ChunkSize is the target maximum chunk length, measured with LenFunc.
ChunkSize int
// ChunkOverlap is the overlap between neighbouring chunks, measured with LenFunc.
ChunkOverlap int
// Separators lists candidate separators in order of preference.
Separators []string
// KeepSeparator controls whether separators are kept in the split text.
KeepSeparator bool
// LenFunc measures the length of a string.
LenFunc func(string) int
// ModelName and EncodingName name a token model/encoding; their consumers
// are not visible in this file.
ModelName string
EncodingName string
// AllowedSpecial and DisallowedSpecial list special tokens; their consumers
// are not visible in this file.
AllowedSpecial []string
DisallowedSpecial []string
// SecondSplitter is an additional splitter; its consumer is not visible here.
SecondSplitter TextSplitter
// CodeBlocks controls whether indented and fenced code blocks are included in the output.
CodeBlocks bool
// ReferenceLinks controls whether reference links are patched with their definition.
ReferenceLinks bool
KeepHeadingHierarchy bool // Persist hierarchy of markdown headers in each chunk
// JoinTableRows controls whether table rows are joined up to the chunk size.
JoinTableRows bool
}
// DefaultOptions returns the baseline Options shared by all text splitters:
// chunk size 512, chunk overlap 100, separators "\n\n", "\n", " " and "",
// rune-count length, and the default token model and encoding names. Fields
// that are not listed keep their zero value.
func DefaultOptions() Options {
	var defaults Options

	defaults.ChunkSize = _defaultTokenChunkSize
	defaults.ChunkOverlap = _defaultTokenChunkOverlap
	defaults.Separators = []string{"\n\n", "\n", " ", ""}
	defaults.KeepSeparator = false
	defaults.LenFunc = utf8.RuneCountInString
	defaults.ModelName = _defaultTokenModelName
	defaults.EncodingName = _defaultTokenEncoding
	defaults.AllowedSpecial = []string{}
	defaults.DisallowedSpecial = []string{"all"}
	defaults.KeepHeadingHierarchy = false

	return defaults
}
// Option is a function that can be used to set options for a text splitter.
// Options are applied in order to an Options value, so later options win.
type Option func(*Options)
// WithChunkSize returns an Option that sets Options.ChunkSize.
func WithChunkSize(chunkSize int) Option {
	return func(target *Options) { target.ChunkSize = chunkSize }
}
// WithChunkOverlap returns an Option that sets Options.ChunkOverlap.
func WithChunkOverlap(chunkOverlap int) Option {
	return func(target *Options) { target.ChunkOverlap = chunkOverlap }
}
// WithSeparators returns an Option that sets Options.Separators. The slice is
// stored as given, not copied.
func WithSeparators(separators []string) Option {
	return func(target *Options) { target.Separators = separators }
}
// WithLenFunc returns an Option that sets Options.LenFunc, the function used
// to measure string length.
func WithLenFunc(lenFunc func(string) int) Option {
	return func(target *Options) { target.LenFunc = lenFunc }
}
// WithModelName returns an Option that sets Options.ModelName.
func WithModelName(modelName string) Option {
	return func(target *Options) { target.ModelName = modelName }
}
// WithEncodingName returns an Option that sets Options.EncodingName.
func WithEncodingName(encodingName string) Option {
	return func(target *Options) { target.EncodingName = encodingName }
}
// WithAllowedSpecial returns an Option that sets Options.AllowedSpecial, the
// list of allowed special tokens.
func WithAllowedSpecial(allowedSpecial []string) Option {
	return func(target *Options) { target.AllowedSpecial = allowedSpecial }
}
// WithDisallowedSpecial returns an Option that sets Options.DisallowedSpecial,
// the list of disallowed special tokens.
func WithDisallowedSpecial(disallowedSpecial []string) Option {
	return func(target *Options) { target.DisallowedSpecial = disallowedSpecial }
}
// WithSecondSplitter returns an Option that sets Options.SecondSplitter.
func WithSecondSplitter(secondSplitter TextSplitter) Option {
	return func(target *Options) { target.SecondSplitter = secondSplitter }
}
// WithCodeBlocks returns an Option that sets Options.CodeBlocks, which
// controls whether indented and fenced code blocks are included in the
// output.
func WithCodeBlocks(renderCode bool) Option {
	return func(target *Options) { target.CodeBlocks = renderCode }
}
// WithReferenceLinks returns an Option that sets Options.ReferenceLinks, which
// controls whether reference links (i.e. `[text][label]`) are patched with the
// url and title from their definition. By default reference definitions are
// dropped from the output.
//
// Caution: enabling this also affects how other inline elements are rendered,
// e.g. all emphasis will use `*` even when another character (e.g. `_`) was
// used in the input.
func WithReferenceLinks(referenceLinks bool) Option {
	return func(target *Options) { target.ReferenceLinks = referenceLinks }
}
// WithKeepSeparator returns an Option that sets Options.KeepSeparator. When
// true, separators are kept in the resulting split text; when false (the
// default), they are not.
func WithKeepSeparator(keepSeparator bool) Option {
	return func(target *Options) { target.KeepSeparator = keepSeparator }
}
// WithHeadingHierarchy returns an Option that sets
// Options.KeepHeadingHierarchy. When true, each chunk is prepended with the
// list of all parent headings up to that point, which is meant to make chunks
// more relevant during similarity search. Defaults to false.
func WithHeadingHierarchy(trackHeadingHierarchy bool) Option {
	return func(target *Options) { target.KeepHeadingHierarchy = trackHeadingHierarchy }
}
// WithJoinTableRows returns an Option that sets Options.JoinTableRows. When
// true, table rows are joined until the chunk size is reached. When false
// (the default), tables are split by row so that each row lands in its own
// chunk.
func WithJoinTableRows(join bool) Option {
	return func(target *Options) { target.JoinTableRows = join }
}

View File

@@ -0,0 +1,106 @@
package textsplitter
import (
"strings"
)
// RecursiveCharacter is a text splitter that will split texts recursively by different
// characters.
type RecursiveCharacter struct {
// Separators lists candidate separators in order of preference.
Separators []string
// ChunkSize is the target maximum chunk length, measured with LenFunc.
ChunkSize int
// ChunkOverlap is the overlap between neighbouring merged chunks.
ChunkOverlap int
// LenFunc measures the length of a string.
LenFunc func(string) int
// KeepSeparator, when true, prefixes the separator to every split after the first.
KeepSeparator bool
}
// NewRecursiveCharacter creates a recursive character splitter. It starts from
// DefaultOptions (separators "\n\n", "\n", " " and "", chunk size 512, chunk
// overlap 100) and then applies opts in order.
func NewRecursiveCharacter(opts ...Option) RecursiveCharacter {
	settings := DefaultOptions()
	for i := range opts {
		opts[i](&settings)
	}

	return RecursiveCharacter{
		Separators:    settings.Separators,
		ChunkSize:     settings.ChunkSize,
		ChunkOverlap:  settings.ChunkOverlap,
		LenFunc:       settings.LenFunc,
		KeepSeparator: settings.KeepSeparator,
	}
}
// SplitText splits text into chunks using the splitter's configured
// separators.
func (s RecursiveCharacter) SplitText(text string) ([]string, error) {
	chunks, err := s.splitText(text, s.Separators)
	return chunks, err
}
// addSeparatorInSplits returns a copy of splits in which every element except
// the first is prefixed with separator.
func (s RecursiveCharacter) addSeparatorInSplits(splits []string, separator string) []string {
	prefixed := make([]string, len(splits))
	for idx, piece := range splits {
		if idx == 0 {
			prefixed[idx] = piece
			continue
		}
		prefixed[idx] = separator + piece
	}
	return prefixed
}
// splitText splits text on the first separator in separators that is either
// empty or occurs in text, falling back to the last separator when none
// matches. Splits shorter than ChunkSize are merged back together with
// mergeSplits; longer splits are split again with the remaining separators,
// or emitted as-is when no separators remain.
//
// NOTE(review): separators must be non-empty; an empty slice makes the index
// expression below panic.
func (s RecursiveCharacter) splitText(text string, separators []string) ([]string, error) {
finalChunks := make([]string, 0)
// Find the appropriate separator.
separator := separators[len(separators)-1]
newSeparators := []string{}
for i, c := range separators {
if c == "" || strings.Contains(text, c) {
separator = c
// Only separators after the chosen one are tried on oversized splits.
newSeparators = separators[i+1:]
break
}
}
splits := strings.Split(text, separator)
if s.KeepSeparator {
// The separator is now part of each split, so merge with an empty joiner.
splits = s.addSeparatorInSplits(splits, separator)
separator = ""
}
goodSplits := make([]string, 0)
// Merge the splits, recursively splitting larger texts.
for _, split := range splits {
// Strictly-less-than: a split of exactly ChunkSize is treated as oversized.
if s.LenFunc(split) < s.ChunkSize {
goodSplits = append(goodSplits, split)
continue
}
// Flush the small splits collected so far before handling the oversized one,
// so chunk order follows the input order.
if len(goodSplits) > 0 {
mergedText := mergeSplits(goodSplits, separator, s.ChunkSize, s.ChunkOverlap, s.LenFunc)
finalChunks = append(finalChunks, mergedText...)
goodSplits = make([]string, 0)
}
if len(newSeparators) == 0 {
// Nothing finer to split on: emit the oversized split unchanged.
finalChunks = append(finalChunks, split)
} else {
otherInfo, err := s.splitText(split, newSeparators)
if err != nil {
return nil, err
}
finalChunks = append(finalChunks, otherInfo...)
}
}
// Flush any trailing small splits.
if len(goodSplits) > 0 {
mergedText := mergeSplits(goodSplits, separator, s.ChunkSize, s.ChunkOverlap, s.LenFunc)
finalChunks = append(finalChunks, mergedText...)
}
return finalChunks, nil
}

View File

@@ -0,0 +1,153 @@
package textsplitter
import (
"strings"
"testing"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/schema"
"github.com/pkoukk/tiktoken-go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestRecursiveCharacterSplitter runs a table of inputs through
// CreateDocuments with a RecursiveCharacter splitter and compares the
// resulting documents with the expected ones.
//
// A single splitter value is reused for all cases and reconfigured per case.
// LenFunc is only overwritten when a case supplies one, so a custom LenFunc
// stays in effect for any later case that does not set its own.
//
//nolint:dupword,funlen
func TestRecursiveCharacterSplitter(t *testing.T) {
// NOTE(review): the error from GetEncoding is discarded; a nil encoder would
// only surface when the token-based LenFunc below is called.
tokenEncoder, _ := tiktoken.GetEncoding("cl100k_base")
// t.Parallel()
type testCase struct {
text string
chunkOverlap int
chunkSize int
separators []string
expectedDocs []schema.Document
keepSeparator bool
LenFunc func(string) int
}
testCases := []testCase{
{
text: "哈里森\n很高兴遇见你\n欢迎来中国",
chunkOverlap: 0,
chunkSize: 10,
separators: []string{"\n\n", "\n", " "},
expectedDocs: []schema.Document{
{Content: "哈里森\n很高兴遇见你", Metadata: map[string]any{}},
{Content: "欢迎来中国", Metadata: map[string]any{}},
},
},
{
text: "Hi, Harrison. \nI am glad to meet you",
chunkOverlap: 1,
chunkSize: 20,
separators: []string{"\n", "$"},
expectedDocs: []schema.Document{
{Content: "Hi, Harrison.", Metadata: map[string]any{}},
{Content: "I am glad to meet you", Metadata: map[string]any{}},
},
},
{
text: "Hi.\nI'm Harrison.\n\nHow?\na\nbHi.\nI'm Harrison.\n\nHow?\na\nb",
chunkOverlap: 1,
chunkSize: 40,
separators: []string{"\n\n", "\n", " ", ""},
expectedDocs: []schema.Document{
{Content: "Hi.\nI'm Harrison.", Metadata: map[string]any{}},
{Content: "How?\na\nbHi.\nI'm Harrison.\n\nHow?\na\nb", Metadata: map[string]any{}},
},
},
{
text: "name: Harrison\nage: 30",
chunkOverlap: 1,
chunkSize: 40,
separators: []string{"\n\n", "\n", " ", ""},
expectedDocs: []schema.Document{
{Content: "name: Harrison\nage: 30", Metadata: map[string]any{}},
},
},
{
text: `name: Harrison
age: 30
name: Joe
age: 32`,
chunkOverlap: 1,
chunkSize: 40,
separators: []string{"\n\n", "\n", " ", ""},
expectedDocs: []schema.Document{
{Content: "name: Harrison\nage: 30", Metadata: map[string]any{}},
{Content: "name: Joe\nage: 32", Metadata: map[string]any{}},
},
},
{
text: `Hi.
I'm Harrison.
How? Are? You?
Okay then f f f f.
This is a weird text to write, but gotta test the splittingggg some how.
Bye!
-H.`,
chunkOverlap: 1,
chunkSize: 10,
separators: []string{"\n\n", "\n", " ", ""},
expectedDocs: []schema.Document{
{Content: "Hi.", Metadata: map[string]any{}},
{Content: "I'm", Metadata: map[string]any{}},
{Content: "Harrison.", Metadata: map[string]any{}},
{Content: "How? Are?", Metadata: map[string]any{}},
{Content: "You?", Metadata: map[string]any{}},
{Content: "Okay then", Metadata: map[string]any{}},
{Content: "f f f f.", Metadata: map[string]any{}},
{Content: "This is a", Metadata: map[string]any{}},
{Content: "a weird", Metadata: map[string]any{}},
{Content: "text to", Metadata: map[string]any{}},
{Content: "write, but", Metadata: map[string]any{}},
{Content: "gotta test", Metadata: map[string]any{}},
{Content: "the", Metadata: map[string]any{}},
{Content: "splittingg", Metadata: map[string]any{}},
{Content: "ggg", Metadata: map[string]any{}},
{Content: "some how.", Metadata: map[string]any{}},
{Content: "Bye!\n\n-H.", Metadata: map[string]any{}},
},
},
{
text: "Hi, Harrison. \nI am glad to meet you",
chunkOverlap: 0,
chunkSize: 10,
separators: []string{"\n", "$"},
keepSeparator: true,
expectedDocs: []schema.Document{
{Content: "Hi, Harrison. ", Metadata: map[string]any{}},
{Content: "\nI am glad to meet you", Metadata: map[string]any{}},
},
},
{
text: strings.Repeat("The quick brown fox jumped over the lazy dog. ", 2),
chunkOverlap: 0,
chunkSize: 10,
separators: []string{" "},
keepSeparator: true,
LenFunc: func(s string) int { return len(tokenEncoder.Encode(s, nil, nil)) },
expectedDocs: []schema.Document{
{Content: "The quick brown fox jumped over the lazy dog.", Metadata: map[string]any{}},
{Content: "The quick brown fox jumped over the lazy dog.", Metadata: map[string]any{}},
},
},
}
splitter := NewRecursiveCharacter()
for _, tc := range testCases {
// Reconfigure the shared splitter for this case.
splitter.ChunkOverlap = tc.chunkOverlap
splitter.ChunkSize = tc.chunkSize
splitter.Separators = tc.separators
splitter.KeepSeparator = tc.keepSeparator
if tc.LenFunc != nil {
splitter.LenFunc = tc.LenFunc
}
docs, err := CreateDocuments(splitter, []string{tc.text}, nil)
require.NoError(t, err)
assert.Equal(t, tc.expectedDocs, docs)
}
}

View File

@@ -0,0 +1,133 @@
package textsplitter
import (
"errors"
"log"
"strings"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/schema"
)
// ErrMismatchMetadatasAndText is returned by CreateDocuments when the number
// of texts and metadatas does not match. A metadatas slice of length zero is
// accepted for any number of texts and never triggers this error.
var ErrMismatchMetadatasAndText = errors.New("number of texts and metadatas does not match")
// SplitDocuments re-splits the content of every document with textSplitter.
// Each resulting chunk receives a copy of the metadata of the document it
// came from. It is a thin wrapper around CreateDocuments.
func SplitDocuments(textSplitter TextSplitter, documents []schema.Document) ([]schema.Document, error) {
	contents := make([]string, 0)
	metas := make([]map[string]any, 0)
	for i := range documents {
		contents = append(contents, documents[i].Content)
		metas = append(metas, documents[i].Metadata)
	}
	return CreateDocuments(textSplitter, contents, metas)
}
// CreateDocuments splits every text with textSplitter and returns one
// document per chunk, in input order. Each document gets its own copy of the
// metadata belonging to the text it came from.
//
// If metadatas is empty, every document gets an empty, non-nil metadata map.
// Otherwise len(metadatas) must equal len(texts), or
// ErrMismatchMetadatasAndText is returned. An error from the splitter aborts
// the whole call.
func CreateDocuments(textSplitter TextSplitter, texts []string, metadatas []map[string]any) ([]schema.Document, error) {
	if len(metadatas) == 0 {
		metadatas = make([]map[string]any, len(texts))
	}
	if len(metadatas) != len(texts) {
		return nil, ErrMismatchMetadatasAndText
	}

	result := make([]schema.Document, 0)
	for idx, text := range texts {
		pieces, err := textSplitter.SplitText(text)
		if err != nil {
			return nil, err
		}

		source := metadatas[idx]
		for _, piece := range pieces {
			// Give each chunk an independent metadata map so callers can
			// mutate one document without affecting its siblings.
			meta := make(map[string]any, len(source))
			for k, v := range source {
				meta[k] = v
			}
			result = append(result, schema.Document{Content: piece, Metadata: meta})
		}
	}
	return result, nil
}
// joinDocs combines docs with separator and trims leading and trailing
// whitespace from the joined result.
func joinDocs(docs []string, separator string) string {
	joined := strings.Join(docs, separator)
	return strings.TrimSpace(joined)
}
// mergeSplits merges smaller splits into splits that are closer to the chunkSize.
//
// It walks splits in order, accumulating them in a window (currentDoc) whose
// length, including separators between elements, is tracked in total. When
// adding the next split would exceed chunkSize, the window is emitted as one
// chunk and elements are dropped from its front while shouldPop says so; what
// remains becomes the overlap carried into the next chunk. Chunks that are
// empty after trimming are skipped.
func mergeSplits(splits []string, separator string, chunkSize int, chunkOverlap int, lenFunc func(string) int) []string { //nolint:cyclop
docs := make([]string, 0)
currentDoc := make([]string, 0)
total := 0
for _, split := range splits {
totalWithSplit := total + lenFunc(split)
if len(currentDoc) != 0 {
// A separator is only needed when the window already has content.
totalWithSplit += lenFunc(separator)
}
// Logs when the current window is already longer than chunkSize.
maybePrintWarning(total, chunkSize)
if totalWithSplit > chunkSize && len(currentDoc) > 0 {
// Emit the current window before it overflows.
doc := joinDocs(currentDoc, separator)
if doc != "" {
docs = append(docs, doc)
}
// Shrink the window from the front; the leftover is the overlap.
for len(currentDoc) > 0 && shouldPop(chunkOverlap, chunkSize, total, lenFunc(split), lenFunc(separator), len(currentDoc)) {
total -= lenFunc(currentDoc[0]) //nolint:gosec
if len(currentDoc) > 1 {
total -= lenFunc(separator)
}
currentDoc = currentDoc[1:] //nolint:gosec
}
}
currentDoc = append(currentDoc, split)
total += lenFunc(split)
if len(currentDoc) > 1 {
total += lenFunc(separator)
}
}
// Emit whatever is left in the window.
doc := joinDocs(currentDoc, separator)
if doc != "" {
docs = append(docs, doc)
}
return docs
}
// maybePrintWarning logs a warning through the standard logger when total
// exceeds chunkSize; otherwise it does nothing.
func maybePrintWarning(total, chunkSize int) {
	if total <= chunkSize {
		return
	}
	const format = "[WARN] created a chunk with size of %v, which is longer then the specified %v\n"
	log.Printf(format, total, chunkSize)
}
// shouldPop reports whether the front element of the current window should be
// dropped. It returns true when the window is non-empty and either:
//   - its total length is larger than the chunk overlap, or
//   - its total length is positive and adding the next split (plus a
//     separator, counted only when the window holds at least two elements)
//     would exceed the chunk size.
func shouldPop(chunkOverlap, chunkSize, total, splitLen, separatorLen, currentDocLen int) bool {
	const minDocsForSeparator = 2

	if currentDocLen <= 0 {
		return false
	}
	if total > chunkOverlap {
		return true
	}

	sepLen := separatorLen
	if currentDocLen < minDocsForSeparator {
		sepLen = 0
	}
	return total > 0 && total+splitLen+sepLen > chunkSize
}

View File

@@ -0,0 +1,30 @@
package textsplitter
import (
"fmt"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
)
// TextSplitter is the standard interface for splitting texts.
type TextSplitter interface {
// SplitText splits text into chunks, or returns an error if it cannot.
SplitText(text string) ([]string, error)
}
// NoSplitterCharacter is a TextSplitter that performs no splitting.
type NoSplitterCharacter struct{}

// SplitText returns text unchanged as the only element of the result. It
// never returns an error.
func (NoSplitterCharacter) SplitText(text string) ([]string, error) {
	whole := []string{text}
	return whole, nil
}
// NewTextSplitter builds the TextSplitter selected by cfg.Provider:
//   - "recursive": a RecursiveCharacter splitter using cfg.ChunkSize,
//     cfg.ChunkOverlap and a separator list covering paragraph breaks, line
//     breaks and ASCII/CJK sentence punctuation;
//   - "nosplitter": a NoSplitterCharacter.
//
// Any other provider yields an error. cfg must not be nil.
func NewTextSplitter(cfg *config.SplitterConfig) (TextSplitter, error) {
	provider := cfg.Provider
	if provider == "nosplitter" {
		return NoSplitterCharacter{}, nil
	}
	if provider != "recursive" {
		return nil, fmt.Errorf("unknown text splitter type: %s", cfg.Provider)
	}

	separators := []string{"\n\n", "\n", ".", "。", "?", "!", ""}
	splitter := NewRecursiveCharacter(
		WithChunkSize(cfg.ChunkSize),
		WithChunkOverlap(cfg.ChunkOverlap),
		WithSeparators(separators),
	)
	return splitter, nil
}

View File

@@ -0,0 +1,355 @@
package rag
import (
"context"
"encoding/json"
"fmt"
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
"github.com/mark3labs/mcp-go/mcp"
)
// HandleCreateChunkFromText returns the handler for the tool that ingests
// text. The handler requires string arguments "text" and "title", stores the
// resulting chunks through ragClient and replies with a success flag, a
// message and the created chunks.
func HandleCreateChunkFromText(ragClient *RAGClient) common.ToolHandlerFunc {
	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		args := request.Params.Arguments

		text, isText := args["text"].(string)
		if !isText {
			return nil, fmt.Errorf("invalid text argument")
		}
		title, isTitle := args["title"].(string)
		if !isTitle {
			return nil, fmt.Errorf("invalid title argument")
		}

		// Split, embed and store the text.
		chunks, err := ragClient.CreateChunkFromText(text, title)
		if err != nil {
			return nil, fmt.Errorf("create chunk failed, err: %w", err)
		}

		payload := map[string]any{
			"success": true,
			"message": fmt.Sprintf("chunks created from text, title: %s", title),
			"data":    chunks,
		}
		return buildCallToolResult(payload)
	}
}
// HandleListChunks returns the handler for the tool that lists stored chunks.
// The handler takes no arguments and replies with the chunk list itself.
func HandleListChunks(ragClient *RAGClient) common.ToolHandlerFunc {
	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		listed, listErr := ragClient.ListChunks()
		if listErr != nil {
			return nil, fmt.Errorf("list chunks failed, err: %w", listErr)
		}
		return buildCallToolResult(listed)
	}
}
// HandleDeleteChunk returns the handler for the tool that deletes one chunk.
// The handler requires a string argument "id" and replies with a success flag
// and a message naming the deleted id.
func HandleDeleteChunk(ragClient *RAGClient) common.ToolHandlerFunc {
	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		chunkID, isString := request.Params.Arguments["id"].(string)
		if !isString {
			return nil, fmt.Errorf("invalid id argument")
		}

		err := ragClient.DeleteChunk(chunkID)
		if err != nil {
			return nil, fmt.Errorf("delete chunk failed, err: %w", err)
		}

		return buildCallToolResult(map[string]any{
			"success": true,
			"message": fmt.Sprintf("chunk deleted, id: %s", chunkID),
		})
	}
}
// HandleCreateSession returns a placeholder handler for creating a chat
// session. It does not use ragClient and always replies with the same fixed
// session id and creation time.
func HandleCreateSession(ragClient *RAGClient) common.ToolHandlerFunc {
	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		// TODO: Implement chat session creation logic
		placeholder := map[string]any{
			"session_id": "session-1",
			"created_at": "2024-01-01T00:00:00Z",
		}
		return buildCallToolResult(placeholder)
	}
}
// HandleGetSession returns the tool handler for fetching a session's details.
//
// This is a placeholder: it only validates the string argument "session_id"
// and echoes it back with an empty message list.
func HandleGetSession(ragClient *RAGClient) common.ToolHandlerFunc {
	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		sessionID, isString := request.Params.Arguments["session_id"].(string)
		if !isString {
			return nil, fmt.Errorf("invalid session_id argument")
		}

		// TODO: Implement session details retrieval logic
		return buildCallToolResult(map[string]interface{}{
			"session_id": sessionID,
			"messages":   []interface{}{},
		})
	}
}
// HandleListSessions returns the tool handler for listing chat sessions.
//
// This is a placeholder: it always reports an empty session list with a total
// of zero.
func HandleListSessions(ragClient *RAGClient) common.ToolHandlerFunc {
	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		// TODO: Implement session listing logic
		noSessions := []interface{}{}
		return buildCallToolResult(map[string]interface{}{
			"sessions": noSessions,
			"total":    0,
		})
	}
}
// HandleDeleteSession returns the tool handler for deleting a chat session.
//
// This is a placeholder: it validates the string argument "session_id" and
// reports success without deleting anything.
func HandleDeleteSession(ragClient *RAGClient) common.ToolHandlerFunc {
	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		sessionID, isString := request.Params.Arguments["session_id"].(string)
		if !isString {
			return nil, fmt.Errorf("invalid session_id argument")
		}

		// TODO: Implement session deletion logic
		return buildCallToolResult(map[string]interface{}{
			"success":    true,
			"message":    "Session deleted",
			"session_id": sessionID,
		})
	}
}
// HandleSearch returns the tool handler that runs a semantic search over the
// stored chunks.
//
// Arguments:
//   - "query" (string, required): the search text.
//   - "topk" (number, optional): how many results to return; falls back to
//     ragClient.config.RAG.TopK when absent or not numeric.
//   - "threshold" (number, optional): score threshold; falls back to
//     ragClient.config.RAG.Threshold when absent or not numeric.
func HandleSearch(ragClient *RAGClient) common.ToolHandlerFunc {
	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		arguments := request.Params.Arguments
		query, ok := arguments["query"].(string)
		if !ok {
			return nil, fmt.Errorf("invalid query argument")
		}

		// JSON numbers decoded into an interface value arrive as float64, so
		// asserting to int alone would always fail and the caller's topk would
		// be silently replaced by the default. Integer types are accepted too
		// for callers that build the argument map in Go.
		topK := ragClient.config.RAG.TopK
		switch v := arguments["topk"].(type) {
		case float64:
			topK = int(v)
		case int:
			topK = v
		case int64:
			topK = int(v)
		}

		threshold, ok := arguments["threshold"].(float64)
		if !ok {
			threshold = ragClient.config.RAG.Threshold
		}

		searchResult, err := ragClient.SearchChunks(query, topK, threshold)
		if err != nil {
			return nil, fmt.Errorf("search chunks failed, err: %w", err)
		}
		return buildCallToolResult(searchResult)
	}
}
// HandleChat returns the tool handler that answers a user query through
// ragClient.Chat.
//
// The request must carry a string argument "query"; the reply returned by
// Chat is serialized as the tool result.
func HandleChat(ragClient *RAGClient) common.ToolHandlerFunc {
	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		question, isString := request.Params.Arguments["query"].(string)
		if !isString {
			return nil, fmt.Errorf("invalid query argument")
		}

		answer, chatErr := ragClient.Chat(question)
		if chatErr != nil {
			return nil, fmt.Errorf("chat failed, err: %w", chatErr)
		}
		return buildCallToolResult(answer)
	}
}
// buildCallToolResult JSON-encodes results and wraps the encoded text in a
// tool result holding a single text content item. It returns an error when
// results cannot be marshalled.
func buildCallToolResult(results any) (*mcp.CallToolResult, error) {
	payload, marshalErr := json.Marshal(results)
	if marshalErr != nil {
		return nil, fmt.Errorf("failed to marshal results: %w", marshalErr)
	}

	textItem := mcp.TextContent{
		Type: "text",
		Text: string(payload),
	}
	return &mcp.CallToolResult{Content: []mcp.Content{textItem}}, nil
}
// Schema functions
// GetCreateChunkFromTextSchema returns the JSON input schema of the
// create-chunk-from-text tool: an object with required string properties
// "text" and "title".
func GetCreateChunkFromTextSchema() json.RawMessage {
	const schemaJSON = `{
"type": "object",
"properties": {
"text": {
"type": "string",
"description": "The text content to create chunks from"
},
"title": {
"type": "string",
"description": "The title of text content"
}
},
"required": ["text", "title"]
}`
	return json.RawMessage(schemaJSON)
}
// GetListKnowledgeSchema returns the JSON input schema of the list-knowledge
// tool: an object that declares no properties.
func GetListKnowledgeSchema() json.RawMessage {
	const schemaJSON = `{
"type": "object",
"properties": {}
}`
	return json.RawMessage(schemaJSON)
}
// GetGetKnowledgeSchema returns the JSON input schema of the get-knowledge
// tool: an object with one required string property, "id".
func GetGetKnowledgeSchema() json.RawMessage {
	const schemaJSON = `{
"type": "object",
"properties": {
"id": {
"type": "string",
"description": "The knowledge ID"
}
},
"required": ["id"]
}`
	return json.RawMessage(schemaJSON)
}
// GetDeleteKnowledgeSchema returns the JSON input schema of the
// delete-knowledge tool: an object with one required string property, "id".
func GetDeleteKnowledgeSchema() json.RawMessage {
	const schemaJSON = `{
"type": "object",
"properties": {
"id": {
"type": "string",
"description": "The knowledge ID to delete"
}
},
"required": ["id"]
}`
	return json.RawMessage(schemaJSON)
}
// GetListChunksSchema returns the JSON input schema of the list-chunks tool:
// an object that declares no properties.
func GetListChunksSchema() json.RawMessage {
	const schemaJSON = `{
"type": "object",
"properties": {}
}`
	return json.RawMessage(schemaJSON)
}
// GetDeleteChunkSchema returns the JSON input schema of the delete-chunk tool:
// an object with one required string property, "id".
func GetDeleteChunkSchema() json.RawMessage {
	const schemaJSON = `{
"type": "object",
"properties": {
"id": {
"type": "string",
"description": "The chunk ID to delete"
}
},
"required": ["id"]
}`
	return json.RawMessage(schemaJSON)
}
// GetCreateSessionSchema returns the JSON input schema of the create-session
// tool: an object that declares no properties.
func GetCreateSessionSchema() json.RawMessage {
	const schemaJSON = `{
"type": "object",
"properties": {}
}`
	return json.RawMessage(schemaJSON)
}
// GetGetSessionSchema returns the JSON input schema of the get-session tool:
// an object with one required string property, "session_id".
func GetGetSessionSchema() json.RawMessage {
	const schemaJSON = `{
"type": "object",
"properties": {
"session_id": {
"type": "string",
"description": "The session ID"
}
},
"required": ["session_id"]
}`
	return json.RawMessage(schemaJSON)
}
// GetListSessionsSchema returns the JSON input schema of the list-sessions
// tool: an object that declares no properties.
func GetListSessionsSchema() json.RawMessage {
	const schemaJSON = `{
"type": "object",
"properties": {}
}`
	return json.RawMessage(schemaJSON)
}
// GetDeleteSessionSchema returns the JSON input schema of the delete-session
// tool: an object with one required string property, "session_id".
func GetDeleteSessionSchema() json.RawMessage {
	const schemaJSON = `{
"type": "object",
"properties": {
"session_id": {
"type": "string",
"description": "The session ID to delete"
}
},
"required": ["session_id"]
}`
	return json.RawMessage(schemaJSON)
}
// GetSearchSchema returns the JSON input schema of the search tool: an object
// with a required string "query" plus optional "topk" (integer) and
// "threshold" (number) properties.
func GetSearchSchema() json.RawMessage {
	const schemaJSON = `{
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The search query"
},
"topk": {
"type": "integer",
"description": "The number of top results to return (optional, default 10)"
},
"threshold": {
"type": "number",
"description": "The relevance score threshold for filtering results (optional, default 0.5)"
}
},
"required": ["query"]
}`
	return json.RawMessage(schemaJSON)
}
// GetChatSchema returns the JSON input schema of the chat tool: an object with
// one required string property, "query".
func GetChatSchema() json.RawMessage {
	const schemaJSON = `{
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "User query"
}
},
"required": ["query"]
}`
	return json.RawMessage(schemaJSON)
}

View File

@@ -0,0 +1,495 @@
package vectordb
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/schema"
"github.com/milvus-io/milvus-sdk-go/v2/client"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
)
// Milvus provider constants.
const (
	// MILVUS_PROVIDER_TYPE is the provider identifier expected in the vector
	// database configuration and reported by GetProviderType.
	MILVUS_PROVIDER_TYPE = "milvus"
	// MILVUS_DUMMY_DIM is a small placeholder vector dimension.
	MILVUS_DUMMY_DIM = 8
)
// milvusProviderInitializer creates Milvus-backed vector store providers. It
// carries no state; its methods fill in configuration defaults, validate the
// configuration and construct the provider.
type milvusProviderInitializer struct{}
// InitConfig checks that cfg targets the Milvus provider and then fills in
// defaults for any connection setting left at its zero value: host
// "localhost", port 19530, database "default" and the default document
// collection. cfg is modified in place.
func (m *milvusProviderInitializer) InitConfig(cfg *config.VectorDBConfig) error {
	if provider := cfg.Provider; provider != MILVUS_PROVIDER_TYPE {
		return fmt.Errorf("provider type mismatch: expected %s, got %s", MILVUS_PROVIDER_TYPE, provider)
	}

	if len(cfg.Host) == 0 {
		cfg.Host = "localhost"
	}
	if cfg.Port == 0 {
		cfg.Port = 19530
	}
	if len(cfg.Database) == 0 {
		cfg.Database = "default"
	}
	if len(cfg.Collection) == 0 {
		cfg.Collection = schema.DEFAULT_DOCUMENT_COLLECTION
	}
	return nil
}
// ValidateConfig reports the first problem found in cfg: an empty host, a
// non-positive port, an empty database or an empty collection. It returns nil
// when all four settings are present.
func (m *milvusProviderInitializer) ValidateConfig(cfg *config.VectorDBConfig) error {
	switch {
	case cfg.Host == "":
		return fmt.Errorf("milvus host is required")
	case cfg.Port <= 0:
		return fmt.Errorf("milvus port must be positive")
	case cfg.Database == "":
		return fmt.Errorf("milvus database is required")
	case cfg.Collection == "":
		return fmt.Errorf("milvus document collection is required")
	}
	return nil
}
// CreateProvider applies defaults to cfg, validates it and then builds a
// Milvus provider for vectors of dimension dim. The first failing step's
// error is returned unchanged.
func (m *milvusProviderInitializer) CreateProvider(cfg *config.VectorDBConfig, dim int) (VectorStoreProvider, error) {
	for _, prepare := range []func(*config.VectorDBConfig) error{m.InitConfig, m.ValidateConfig} {
		if err := prepare(cfg); err != nil {
			return nil, err
		}
	}
	return NewMilvusProvider(cfg, dim)
}
// MilvusProvider implements the vector store provider interface for Milvus.
type MilvusProvider struct {
	// client is the Milvus SDK client used for every collection operation.
	client client.Client
	// config is the vector database configuration the provider was built from.
	config *config.VectorDBConfig
	// Collection is the name of the Milvus collection that documents are
	// written to and read from.
	Collection string
}
// NewMilvusProvider connects to the Milvus server described by cfg and makes
// sure the configured collection exists and is loaded for vectors of
// dimension dim.
//
// Username and password are sent only when both are non-empty. If preparing
// the collection fails, the freshly opened client is closed before the error
// is returned so the connection is not leaked.
func NewMilvusProvider(cfg *config.VectorDBConfig, dim int) (VectorStoreProvider, error) {
	ctx := context.Background()

	connectParam := client.Config{
		Address: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
		DBName:  cfg.Database,
	}
	// Add authentication if credentials are provided
	if cfg.Username != "" && cfg.Password != "" {
		connectParam.Username = cfg.Username
		connectParam.Password = cfg.Password
	}

	milvusClient, err := client.NewClient(ctx, connectParam)
	if err != nil {
		return nil, fmt.Errorf("failed to create milvus client: %w", err)
	}

	provider := &MilvusProvider{
		client:     milvusClient,
		config:     cfg,
		Collection: cfg.Collection,
	}

	if err := provider.CreateCollection(ctx, dim); err != nil {
		// The provider is never handed to the caller on this path, so nobody
		// else could release the connection. Closing is best effort: the
		// collection error is the one worth reporting.
		_ = milvusClient.Close()
		return nil, err
	}
	return provider, nil
}
// CreateCollection makes sure the provider's collection exists and is loaded.
//
// When the collection is missing it is created with the fields id (VarChar
// primary key), content (VarChar), vector (float vector of dimension dim),
// metadata (JSON) and created_at (Int64), and an HNSW index using the
// inner-product metric is built on the vector field. The collection is loaded
// whether or not it had to be created.
func (m *MilvusProvider) CreateCollection(ctx context.Context, dim int) error {
	exists, err := m.client.HasCollection(ctx, m.Collection)
	if err != nil {
		return fmt.Errorf("failed to check %s collection existence: %w", m.Collection, err)
	}

	if !exists {
		fmt.Printf("create collection %s\n", m.Collection)

		// Field definitions mirror the schema.Document structure.
		idField := entity.NewField().
			WithName("id").
			WithDataType(entity.FieldTypeVarChar).
			WithMaxLength(256).
			WithIsPrimaryKey(true).
			WithIsAutoID(false)
		contentField := entity.NewField().
			WithName("content").
			WithDataType(entity.FieldTypeVarChar).
			WithMaxLength(8192)
		vectorField := entity.NewField().
			WithName("vector").
			WithDataType(entity.FieldTypeFloatVector).
			WithDim(int64(dim))
		metadataField := entity.NewField().
			WithName("metadata").
			WithDataType(entity.FieldTypeJSON)
		// created_at holds a Unix timestamp.
		createdAtField := entity.NewField().
			WithName("created_at").
			WithDataType(entity.FieldTypeInt64)

		collSchema := entity.NewSchema().
			WithName(m.Collection).
			WithDescription("Knowledge document collection").
			WithAutoID(false).
			WithDynamicFieldEnabled(false).
			WithField(idField).
			WithField(contentField).
			WithField(vectorField).
			WithField(metadataField).
			WithField(createdAtField)

		if err := m.client.CreateCollection(ctx, collSchema, entity.DefaultShardNumber); err != nil {
			return fmt.Errorf("failed to create collection: %w", err)
		}

		hnswIndex, err := entity.NewIndexHNSW(entity.IP, 8, 64)
		if err != nil {
			return fmt.Errorf("failed to create vector index: %w", err)
		}
		if err := m.client.CreateIndex(ctx, m.Collection, "vector", hnswIndex, false, client.WithIndexName("vector_index")); err != nil {
			return fmt.Errorf("failed to create vector index: %w", err)
		}
	}

	if err := m.client.LoadCollection(ctx, m.Collection, false); err != nil {
		return fmt.Errorf("failed to load document collection: %w", err)
	}
	return nil
}
// DropCollection drops the provider's collection. It returns an error when
// the existence check fails, when the collection does not exist, or when the
// drop itself fails.
func (m *MilvusProvider) DropCollection(ctx context.Context) error {
	found, err := m.client.HasCollection(ctx, m.Collection)
	switch {
	case err != nil:
		return fmt.Errorf("failed to check %s collection existence: %w", m.Collection, err)
	case !found:
		return fmt.Errorf("collection %s does not exist", m.Collection)
	}

	if err := m.client.DropCollection(ctx, m.Collection); err != nil {
		return fmt.Errorf("failed to drop collection: %w", err)
	}
	return nil
}
// AddDoc inserts docs into the collection and flushes it. An empty docs slice
// is a no-op.
//
// Each document contributes its ID, content, vector (converted to float32),
// JSON-encoded metadata and creation time in Unix milliseconds. The vector
// column's dimension is taken from the first document's vector.
func (m *MilvusProvider) AddDoc(ctx context.Context, docs []schema.Document) error {
	total := len(docs)
	if total == 0 {
		return nil
	}

	var (
		idValues       = make([]string, total)
		contentValues  = make([]string, total)
		vectorValues   = make([][]float32, total)
		metadataValues = make([][]byte, total)
		createdValues  = make([]int64, total)
	)

	for i, doc := range docs {
		idValues[i] = doc.ID
		contentValues[i] = doc.Content

		converted := make([]float32, 0, len(doc.Vector))
		for _, component := range doc.Vector {
			converted = append(converted, float32(component))
		}
		vectorValues[i] = converted

		encoded, err := json.Marshal(doc.Metadata)
		if err != nil {
			return fmt.Errorf("failed to marshal metadata for doc %s: %w", doc.ID, err)
		}
		metadataValues[i] = encoded

		createdValues[i] = doc.CreatedAt.UnixMilli()
	}

	columns := []entity.Column{
		entity.NewColumnVarChar("id", idValues),
		entity.NewColumnVarChar("content", contentValues),
		entity.NewColumnFloatVector("vector", len(vectorValues[0]), vectorValues),
		entity.NewColumnJSONBytes("metadata", metadataValues),
		entity.NewColumnInt64("created_at", createdValues),
	}

	if _, err := m.client.Insert(ctx, m.Collection, "", columns...); err != nil {
		return fmt.Errorf("failed to insert documents: %w", err)
	}
	if err := m.client.Flush(ctx, m.Collection, false); err != nil {
		return fmt.Errorf("failed to flush collection: %w", err)
	}
	return nil
}
// DeleteDoc deletes the document whose primary key equals id and then flushes
// the collection.
//
// The id is embedded in a filter expression as a double-quoted string
// literal, so backslashes and double quotes inside it are escaped first;
// otherwise such an id would break the expression or change its meaning.
func (m *MilvusProvider) DeleteDoc(ctx context.Context, id string) error {
	// Build delete expression
	escapedID := strings.NewReplacer(`\`, `\\`, `"`, `\"`).Replace(id)
	expr := fmt.Sprintf(`id == "%s"`, escapedID)

	// Delete data
	if err := m.client.Delete(ctx, m.Collection, "", expr); err != nil {
		return fmt.Errorf("failed to delete documents for id %s: %w", id, err)
	}

	// Flush data
	if err := m.client.Flush(ctx, m.Collection, false); err != nil {
		return fmt.Errorf("failed to flush collection after delete: %w", err)
	}
	return nil
}
// UpdateDoc replaces docs in the collection: the documents with the same IDs
// are deleted first and docs are then inserted again. The two steps are
// separate calls, so a failure of the insert leaves the old documents removed.
func (m *MilvusProvider) UpdateDoc(ctx context.Context, docs []schema.Document) error {
	staleIDs := make([]string, 0, len(docs))
	for i := range docs {
		staleIDs = append(staleIDs, docs[i].ID)
	}

	err := m.DeleteDocs(ctx, staleIDs)
	if err != nil {
		return fmt.Errorf("failed to delete existing documents: %w", err)
	}

	err = m.AddDoc(ctx, docs)
	if err != nil {
		return fmt.Errorf("failed to add new documents: %w", err)
	}
	return nil
}
// SearchDocs runs an inner-product similarity search for vector and returns
// the matching documents with their scores.
//
// When options is nil, TopK defaults to 10. Each result carries the
// document's ID, content and metadata; metadata that cannot be decoded is
// reported as an empty map. An error is returned when the search parameters
// cannot be built, the search fails, or a result ID cannot be read.
func (m *MilvusProvider) SearchDocs(ctx context.Context, vector []float32, options *schema.SearchOptions) ([]schema.SearchResult, error) {
	if options == nil {
		options = &schema.SearchOptions{TopK: 10}
	}

	// Milvus documents the valid range of the HNSW ef search parameter as
	// starting at top_k, so a fixed ef of 16 is too small whenever more than
	// 16 results are requested. Keep 16 as the floor and grow with TopK.
	ef := 16
	if options.TopK > ef {
		ef = options.TopK
	}
	sp, err := entity.NewIndexHNSWSearchParam(ef)
	if err != nil {
		return nil, fmt.Errorf("failed to build search parameters: %w", err)
	}

	// No filter expression: search the whole collection.
	expr := ""
	searchResults, err := m.client.Search(
		ctx,
		m.Collection,
		[]string{}, // partition names
		expr,       // filter expression
		[]string{"id", "content", "metadata", "created_at"}, // output fields
		[]entity.Vector{entity.FloatVector(vector)},
		"vector",  // anns_field
		entity.IP, // metric_type
		options.TopK,
		sp,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to search documents: %w", err)
	}

	// Parse results
	var results []schema.SearchResult
	for _, result := range searchResults {
		for i := 0; i < result.ResultCount; i++ {
			id, err := result.IDs.Get(i)
			if err != nil {
				return nil, fmt.Errorf("failed to read id of search result %d: %w", i, err)
			}
			score := result.Scores[i]

			// Get field data
			var content string
			var metadata map[string]interface{}
			for _, field := range result.Fields {
				switch field.Name() {
				case "content":
					if contentCol, ok := field.(*entity.ColumnVarChar); ok {
						if contentVal, err := contentCol.Get(i); err == nil {
							if contentStr, ok := contentVal.(string); ok {
								content = contentStr
							}
						}
					}
				case "metadata":
					if metaCol, ok := field.(*entity.ColumnJSONBytes); ok {
						if metaVal, err := metaCol.Get(i); err == nil {
							if metaBytes, ok := metaVal.([]byte); ok {
								// Undecodable metadata is reported as an empty map
								// rather than failing the whole search.
								if err := json.Unmarshal(metaBytes, &metadata); err != nil {
									metadata = make(map[string]interface{})
								}
							}
						}
					}
				}
			}

			results = append(results, schema.SearchResult{
				Document: schema.Document{
					ID:       fmt.Sprintf("%s", id),
					Content:  content,
					Metadata: metadata,
				},
				Score: float64(score),
			})
		}
	}
	return results, nil
}
// DeleteDocs deletes every document whose primary key is in ids and then
// flushes the collection. An empty ids slice is a no-op.
//
// Each id is embedded in the filter expression as a double-quoted string
// literal. Quoting is required because an unquoted UUID's hyphens would be
// parsed as minus operators; backslashes and double quotes inside an id are
// escaped so they cannot break the expression or change its meaning.
func (m *MilvusProvider) DeleteDocs(ctx context.Context, ids []string) error {
	if len(ids) == 0 {
		return nil
	}

	// Build delete expression
	escaper := strings.NewReplacer(`\`, `\\`, `"`, `\"`)
	quotedIDs := make([]string, len(ids))
	for i, id := range ids {
		quotedIDs[i] = `"` + escaper.Replace(id) + `"`
	}
	expr := fmt.Sprintf("id in [%s]", strings.Join(quotedIDs, ","))

	// Delete data
	if err := m.client.Delete(ctx, m.Collection, "", expr); err != nil {
		return fmt.Errorf("failed to delete documents: %w", err)
	}

	// Flush data
	if err := m.client.Flush(ctx, m.Collection, false); err != nil {
		return fmt.Errorf("failed to flush collection after delete: %w", err)
	}
	return nil
}
// ListDocs queries the collection without a filter, starting at offset 0 and
// limited to limit rows, and converts the returned columns into documents.
//
// A column value that cannot be read or does not have the expected Go type
// (string for id and content, []byte for metadata, int64 for created_at)
// leaves the corresponding document field at its zero value instead of
// panicking.
func (m *MilvusProvider) ListDocs(ctx context.Context, limit int) ([]schema.Document, error) {
	// No filter expression: list the whole collection.
	expr := ""

	queryResult, err := m.client.Query(
		ctx,
		m.Collection,
		[]string{}, // partitions
		expr,       // filter condition
		[]string{"id", "content", "metadata", "created_at"},
		client.WithOffset(0), client.WithLimit(int64(limit)),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to query documents: %w", err)
	}
	if len(queryResult) == 0 {
		return []schema.Document{}, nil
	}

	rowCount := queryResult[0].Len()
	documents := make([]schema.Document, 0, rowCount)

	// Parse query results
	for i := 0; i < rowCount; i++ {
		var (
			id        string
			content   string
			metadata  map[string]interface{}
			createdAt int64
		)
		for _, col := range queryResult {
			value, err := col.Get(i)
			if err != nil {
				continue
			}
			// Checked assertions: an existing collection may have been created
			// with a different schema, and a mismatch must not crash the caller.
			switch col.Name() {
			case "id":
				if s, ok := value.(string); ok {
					id = s
				}
			case "content":
				if s, ok := value.(string); ok {
					content = s
				}
			case "metadata":
				if raw, ok := value.([]byte); ok {
					// Best effort: undecodable metadata is left nil.
					_ = json.Unmarshal(raw, &metadata)
				}
			case "created_at":
				if ms, ok := value.(int64); ok {
					createdAt = ms
				}
			}
		}

		documents = append(documents, schema.Document{
			ID:        id,
			Content:   content,
			Metadata:  metadata,
			CreatedAt: time.UnixMilli(createdAt),
		})
	}
	return documents, nil
}
// GetProviderType reports the identifier of this provider's backend, which is
// always MILVUS_PROVIDER_TYPE.
func (m *MilvusProvider) GetProviderType() string {
	const providerType = MILVUS_PROVIDER_TYPE
	return providerType
}
// Close releases the underlying Milvus client and returns its close error.
// A provider without a client closes successfully.
func (m *MilvusProvider) Close() error {
	if m.client == nil {
		return nil
	}
	return m.client.Close()
}
// joinStrings concatenates elems, placing sep between neighbouring elements.
// It is a thin wrapper around strings.Join.
func joinStrings(elems []string, sep string) string {
	joined := strings.Join(elems, sep)
	return joined
}

View File

@@ -0,0 +1,27 @@
package vectordb
import (
"testing"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
)
// TestNewMilvusProvider checks that a provider can be created against the
// Milvus instance configured in getMilvusProvider.
//
// NOTE(review): the previous failure message ("expected error ... got nil")
// contradicted the err != nil condition. The condition is unchanged and the
// message now reports the actual error; confirm that requiring a reachable
// Milvus server is the intended behaviour of this test.
func TestNewMilvusProvider(t *testing.T) {
	if _, err := getMilvusProvider(); err != nil {
		t.Fatalf("failed to create milvus provider: %v", err)
	}
}
// getMilvusProvider builds a Milvus provider for 128-dimensional vectors
// against 127.0.0.1:19530, database "default", collection "knowledge_test".
func getMilvusProvider() (VectorStoreProvider, error) {
	const testDim = 128
	testCfg := config.VectorDBConfig{
		Provider:   PROVIDER_TYPE_MILVUS,
		Host:       "127.0.0.1",
		Port:       19530,
		Database:   "default",
		Collection: "knowledge_test",
	}
	return NewMilvusProvider(&testCfg, testDim)
}

View File

@@ -0,0 +1,72 @@
package vectordb
import (
"context"
"fmt"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/schema"
)
// Identifiers of the vector database providers, as used in the Provider field
// of the vector database configuration.
const (
	PROVIDER_TYPE_CHROMA        = "chroma"
	PROVIDER_TYPE_ELASTICSEARCH = "elasticsearch"
	PROVIDER_TYPE_FAISS         = "faiss"
	PROVIDER_TYPE_MILVUS        = "milvus"
	PROVIDER_TYPE_PINECONE      = "pinecone"
	PROVIDER_TYPE_QDRANT        = "qdrant"
	PROVIDER_TYPE_WEAVIATE      = "weaviate"
)
// VectorStoreProvider defines the operations every vector store backend must
// implement.
type VectorStoreProvider interface {
	// CreateCollection creates the backing collection for vectors of
	// dimension dim.
	CreateCollection(ctx context.Context, dim int) error
	// DropCollection drops the backing collection.
	DropCollection(ctx context.Context) error
	// AddDoc adds documents to the vector store.
	AddDoc(ctx context.Context, docs []schema.Document) error
	// DeleteDoc deletes the document with the given ID from the vector store.
	DeleteDoc(ctx context.Context, id string) error
	// UpdateDoc updates documents in the vector store.
	UpdateDoc(ctx context.Context, docs []schema.Document) error
	// SearchDocs searches the vector store for documents similar to vector.
	SearchDocs(ctx context.Context, vector []float32, options *schema.SearchOptions) ([]schema.SearchResult, error)
	// DeleteDocs deletes the documents with the given IDs from the vector store.
	DeleteDocs(ctx context.Context, ids []string) error
	// ListDocs lists up to limit documents from the vector store.
	ListDocs(ctx context.Context, limit int) ([]schema.Document, error)
	// GetProviderType returns the type of the vector store provider.
	GetProviderType() string
}
// VectorDBProviderInitializer is implemented by each backend's factory; the
// provider registry looks one up by provider type and asks it to build the
// provider.
type VectorDBProviderInitializer interface {
	// CreateProvider creates a vector store provider from cfg for vectors of
	// dimension dim.
	CreateProvider(cfg *config.VectorDBConfig, dim int) (VectorStoreProvider, error)
}
// vectorDBProviderInitializers maps a provider type to the initializer that
// builds it. Only Milvus is registered.
var vectorDBProviderInitializers = map[string]VectorDBProviderInitializer{
	PROVIDER_TYPE_MILVUS: &milvusProviderInitializer{},
}
// NewVectorDBProvider builds the vector store provider selected by
// cfg.Provider for vectors of dimension dim. It returns an error when no
// initializer is registered for that provider type.
func NewVectorDBProvider(cfg *config.VectorDBConfig, dim int) (VectorStoreProvider, error) {
	if initializer, registered := vectorDBProviderInitializers[cfg.Provider]; registered {
		return initializer.CreateProvider(cfg, dim)
	}
	return nil, fmt.Errorf("unknown vector database provider: %s", cfg.Provider)
}