mirror of
https://github.com/alibaba/higress.git
synced 2026-03-17 08:50:46 +08:00
feat: add rag mcp server (#2930)
This commit is contained in:
@@ -46,14 +46,17 @@ require (
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/clbanning/mxj/v2 v2.5.5 // indirect
|
||||
github.com/deckarep/golang-set v1.7.1 // indirect
|
||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||
github.com/golang/mock v1.6.0 // indirect
|
||||
github.com/golang/protobuf v1.5.3 // indirect
|
||||
github.com/jmespath/go-jmespath v0.4.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
|
||||
github.com/milvus-io/milvus-sdk-go/v2 v2.4.2 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/orcaman/concurrent-map v0.0.0-20210501183033-44dafcb38ecc // indirect
|
||||
github.com/pkoukk/tiktoken-go v0.1.8 // indirect
|
||||
github.com/prometheus/client_golang v1.14.0 // indirect
|
||||
github.com/prometheus/client_model v0.4.0 // indirect
|
||||
github.com/prometheus/common v0.37.0 // indirect
|
||||
|
||||
@@ -136,6 +136,8 @@ github.com/deckarep/golang-set v1.7.1 h1:SCQV0S6gTtp6itiFrTqI+pfmJ4LN85S1YzhDf9r
|
||||
github.com/deckarep/golang-set v1.7.1/go.mod h1:93vsz/8Wt4joVM7c2AVqh+YRMiUSc14yDtF28KmMOgQ=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
|
||||
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
|
||||
@@ -290,6 +292,8 @@ github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxU
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
|
||||
github.com/milvus-io/milvus-sdk-go/v2 v2.4.2 h1:Xqf+S7iicElwYoS2Zly8Nf/zKHuZsNy1xQajfdtygVY=
|
||||
github.com/milvus-io/milvus-sdk-go/v2 v2.4.2/go.mod h1:ulO1YUXKH0PGg50q27grw048GDY9ayB4FPmh7D+FFTA=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@@ -318,6 +322,8 @@ github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo=
|
||||
github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/registry/nacos"
|
||||
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/gorm"
|
||||
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/higress/higress-api"
|
||||
_ "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag"
|
||||
mcp_session "github.com/alibaba/higress/plugins/golang-filter/mcp-session"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
xds "github.com/cncf/xds/go/xds/type/v3"
|
||||
|
||||
173
plugins/golang-filter/mcp-server/servers/rag/README.md
Normal file
173
plugins/golang-filter/mcp-server/servers/rag/README.md
Normal file
@@ -0,0 +1,173 @@
|
||||
# Higress RAG MCP Server
|
||||
|
||||
这是一个 Model Context Protocol (MCP) 服务器,提供知识管理和检索功能。
|
||||
|
||||
该 MCP 服务器提供以下工具:
|
||||
|
||||
## MCP Tools
|
||||
|
||||
### 知识管理
|
||||
- `create-chunks-from-text` - 从 Text 创建知识 (p1)
|
||||
|
||||
### 块管理
|
||||
- `list-chunks` - 列出知识块
|
||||
- `delete-chunk` - 删除知识块
|
||||
|
||||
### 搜索
|
||||
- `search` - 搜索
|
||||
|
||||
### 聊天功能
|
||||
- `chat` - 发送聊天消息
|
||||
|
||||
## 配置说明
|
||||
|
||||
### 配置结构
|
||||
|
||||
```yaml
|
||||
rag:
|
||||
# RAG系统基础配置
|
||||
splitter:
|
||||
type: "recursive" # 递归分块器 recursive 和 nosplitter
|
||||
chunk_size: 500
|
||||
chunk_overlap: 50
|
||||
top_k: 5 # 搜索返回的知识块数量
|
||||
threshold: 0.5 # 搜索阈值
|
||||
|
||||
llm:
|
||||
provider: "openai" # openai
|
||||
api_key: "your-llm-api-key"
|
||||
base_url: "https://api.openai.com/v1" # 可选
|
||||
model: "gpt-3.5-turbo" # LLM模型
|
||||
max_tokens: 2048 # 最大令牌数
|
||||
temperature: 0.5 # 温度参数
|
||||
|
||||
embedding:
|
||||
provider: "openai" # openai, dashscope
|
||||
api_key: "your-embedding-api-key"
|
||||
base_url: "https://api.openai.com/v1" # 可选
|
||||
model: "text-embedding-ada-002" # 嵌入模型
|
||||
|
||||
vectordb:
|
||||
provider: "milvus" # milvus
|
||||
host: "localhost"
|
||||
port: 19530
|
||||
database: "default"
|
||||
collection: "test_collection"
|
||||
username: "" # 可选
|
||||
password: "" # 可选
|
||||
|
||||
```
|
||||
### higress-config 配置样例
|
||||
|
||||
```yaml
|
||||
apiVersion: v1
|
||||
kind: ConfigMap
|
||||
metadata:
|
||||
name: higress-config
|
||||
namespace: higress-system
|
||||
data:
|
||||
higress: |
|
||||
mcpServer:
|
||||
enable: true
|
||||
sse_path_suffix: "/sse"
|
||||
redis:
|
||||
address: "<Redis IP>:6379"
|
||||
username: ""
|
||||
password: ""
|
||||
db: 0
|
||||
match_list:
|
||||
- path_rewrite_prefix: ""
|
||||
upstream_type: ""
|
||||
enable_path_rewrite: false
|
||||
match_rule_domain: ""
|
||||
match_rule_path: "/mcp-servers/rag"
|
||||
match_rule_type: "prefix"
|
||||
servers:
|
||||
- path: "/mcp-servers/rag"
|
||||
name: "rag"
|
||||
type: "rag"
|
||||
config:
|
||||
rag:
|
||||
splitter:
|
||||
provider: recursive
|
||||
chunk_size: 500
|
||||
chunk_overlap: 50
|
||||
top_k: 10
|
||||
threshold: 0.5
|
||||
llm:
|
||||
provider: openai
|
||||
api_key: sk-XXX
|
||||
base_url: https://openrouter.ai/api/v1
|
||||
model: openai/gpt-4o
|
||||
temperature: 0.5
|
||||
max_tokens: 2048
|
||||
embedding:
|
||||
provider: dashscope
|
||||
api_key: sk-xxx
|
||||
model: text-embedding-v4
|
||||
vectordb:
|
||||
provider: milvus
|
||||
host: <milvus IP>
|
||||
port: 19530
|
||||
database: default
|
||||
collection: test_collection
|
||||
```
|
||||
|
||||
### 支持的提供商
|
||||
#### Embedding
|
||||
- **OpenAI**
|
||||
- **DashScope**
|
||||
|
||||
#### Vector Database
|
||||
- **Milvus**
|
||||
|
||||
#### LLM
|
||||
- **OpenAI**
|
||||
|
||||
|
||||
## Milvus 安装
|
||||
|
||||
### Docker 配置
|
||||
配置 Docker Desktop 镜像加速器
|
||||
编辑 daemon.json 配置,加上镜像加速器,例如:
|
||||
```
|
||||
{
|
||||
"registry-mirrors": [
|
||||
"https://docker.m.daocloud.io",
|
||||
"https://mirror.ccs.tencentyun.com",
|
||||
"https://hub-mirror.c.163.com"
|
||||
],
|
||||
"dns": ["8.8.8.8", "1.1.1.1"]
|
||||
}
|
||||
```
|
||||
|
||||
### 安装 milvus
|
||||
|
||||
```
|
||||
v2.6.0
|
||||
Download the configuration file
|
||||
wget https://github.com/milvus-io/milvus/releases/download/v2.6.0/milvus-standalone-docker-compose.yml -O docker-compose.yml
|
||||
|
||||
v2.4
|
||||
$ wget https://github.com/milvus-io/milvus/releases/download/v2.4.23/milvus-standalone-docker-compose.yml -O docker-compose.yml
|
||||
|
||||
# Start Milvus
|
||||
$ sudo docker compose up -d
|
||||
|
||||
Creating milvus-etcd ... done
|
||||
Creating milvus-minio ... done
|
||||
Creating milvus-standalone ... done
|
||||
```
|
||||
|
||||
### 安装 attu
|
||||
|
||||
Attu 是 Milvus 的可视化管理工具,用于查看和管理 Milvus 中的数据。
|
||||
|
||||
```
|
||||
docker run -p 8000:3000 -e MILVUS_URL=http://<本机 IP>:19530 zilliz/attu:v2.6
|
||||
Open your browser and navigate to http://localhost:8000
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HTTPClient handles HTTP API connections and operations
|
||||
type HTTPClient struct {
|
||||
baseURL string
|
||||
headers map[string]string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewHTTPClient creates a new HTTP client with base URL and optional headers
|
||||
func NewHTTPClient(baseURL string, headers map[string]string) *HTTPClient {
|
||||
client := &HTTPClient{
|
||||
baseURL: baseURL,
|
||||
headers: make(map[string]string),
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
// Copy headers to avoid external modification
|
||||
if headers != nil {
|
||||
for k, v := range headers {
|
||||
client.headers[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
// SetHeader sets a header for all requests
|
||||
func (c *HTTPClient) SetHeader(key, value string) {
|
||||
c.headers[key] = value
|
||||
}
|
||||
|
||||
// SetHeaders sets multiple headers for all requests
|
||||
func (c *HTTPClient) SetHeaders(headers map[string]string) {
|
||||
for k, v := range headers {
|
||||
c.headers[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveHeader removes a header
|
||||
func (c *HTTPClient) RemoveHeader(key string) {
|
||||
delete(c.headers, key)
|
||||
}
|
||||
|
||||
// Get performs a GET request
|
||||
func (c *HTTPClient) Get(path string) ([]byte, error) {
|
||||
return c.request("GET", path, nil)
|
||||
}
|
||||
|
||||
// Post performs a POST request
|
||||
func (c *HTTPClient) Post(path string, data interface{}) ([]byte, error) {
|
||||
return c.request("POST", path, data)
|
||||
}
|
||||
|
||||
// Put performs a PUT request
|
||||
func (c *HTTPClient) Put(path string, data interface{}) ([]byte, error) {
|
||||
return c.request("PUT", path, data)
|
||||
}
|
||||
|
||||
// Delete performs a DELETE request
|
||||
func (c *HTTPClient) Delete(path string) ([]byte, error) {
|
||||
return c.request("DELETE", path, nil)
|
||||
}
|
||||
|
||||
// Patch performs a PATCH request
|
||||
func (c *HTTPClient) Patch(path string, data interface{}) ([]byte, error) {
|
||||
return c.request("PATCH", path, data)
|
||||
}
|
||||
|
||||
// RequestWithHeaders performs a request with additional headers for this request only
|
||||
func (c *HTTPClient) RequestWithHeaders(method, path string, data interface{}, additionalHeaders map[string]string) ([]byte, error) {
|
||||
return c.requestWithHeaders(method, path, data, additionalHeaders)
|
||||
}
|
||||
|
||||
func (c *HTTPClient) request(method, path string, data interface{}) ([]byte, error) {
|
||||
return c.requestWithHeaders(method, path, data, nil)
|
||||
}
|
||||
|
||||
func (c *HTTPClient) requestWithHeaders(method, path string, data interface{}, additionalHeaders map[string]string) ([]byte, error) {
|
||||
url := c.baseURL + path
|
||||
|
||||
var body io.Reader
|
||||
if data != nil {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request data: %w", err)
|
||||
}
|
||||
body = bytes.NewBuffer(jsonData)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Set default headers
|
||||
for k, v := range c.headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
// Set additional headers for this request
|
||||
if additionalHeaders != nil {
|
||||
for k, v := range additionalHeaders {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
// Set Content-Type for requests with body
|
||||
if data != nil && req.Header.Get("Content-Type") == "" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("HTTP error %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
return respBody, nil
|
||||
}
|
||||
|
||||
// SetTimeout sets the HTTP client timeout
|
||||
func (c *HTTPClient) SetTimeout(timeout time.Duration) {
|
||||
c.httpClient.Timeout = timeout
|
||||
}
|
||||
|
||||
// GetBaseURL returns the base URL
|
||||
func (c *HTTPClient) GetBaseURL() string {
|
||||
return c.baseURL
|
||||
}
|
||||
|
||||
// SetBaseURL sets the base URL
|
||||
func (c *HTTPClient) SetBaseURL(baseURL string) {
|
||||
c.baseURL = baseURL
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package config
|
||||
|
||||
// Config represents the main configuration structure for the MCP server
|
||||
type Config struct {
|
||||
RAG RAGConfig `json:"rag" yaml:"rag"`
|
||||
LLM LLMConfig `json:"llm" yaml:"llm"`
|
||||
Embedding EmbeddingConfig `json:"embedding" yaml:"embedding"`
|
||||
VectorDB VectorDBConfig `json:"vectordb" yaml:"vectordb"`
|
||||
}
|
||||
|
||||
// RAGConfig contains basic configuration for the RAG system
|
||||
type RAGConfig struct {
|
||||
Splitter SplitterConfig `json:"splitter" yaml:"splitter"`
|
||||
Threshold float64 `json:"threshold,omitempty" yaml:"threshold,omitempty"`
|
||||
TopK int `json:"top_k,omitempty" yaml:"top_k,omitempty"`
|
||||
}
|
||||
|
||||
// SplitterConfig defines document splitter configuration
|
||||
type SplitterConfig struct {
|
||||
Provider string `json:"provider" yaml:"provider"` // Available options: recursive, character, token
|
||||
ChunkSize int `json:"chunk_size,omitempty" yaml:"chunk_size,omitempty"`
|
||||
ChunkOverlap int `json:"chunk_overlap,omitempty" yaml:"chunk_overlap,omitempty"`
|
||||
}
|
||||
|
||||
// LLMConfig defines configuration for Large Language Models
|
||||
type LLMConfig struct {
|
||||
Provider string `json:"provider" yaml:"provider"` // Available options: openai, dashscope, qwen
|
||||
APIKey string `json:"api_key,omitempty" yaml:"api_key"`
|
||||
BaseURL string `json:"base_url,omitempty" yaml:"base_url,omitempty"`
|
||||
Model string `json:"model" yaml:"model"`
|
||||
Temperature float64 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty" yaml:"max_tokens,omitempty"`
|
||||
}
|
||||
|
||||
// EmbeddingConfig defines configuration for embedding models
|
||||
type EmbeddingConfig struct {
|
||||
Provider string `json:"provider" yaml:"provider"` // Available options: openai, dashscope
|
||||
APIKey string `json:"api_key,omitempty" yaml:"api_key,omitempty"`
|
||||
BaseURL string `json:"base_url,omitempty" yaml:"base_url,omitempty"`
|
||||
Model string `json:"model,omitempty" yaml:"model,omitempty"`
|
||||
Dimension int `json:"dimension,omitempty" yaml:"dimension,omitempty"`
|
||||
}
|
||||
|
||||
// VectorDBConfig defines configuration for vector databases
|
||||
type VectorDBConfig struct {
|
||||
Provider string `json:"provider" yaml:"provider"` // Available options: milvus, qdrant, chroma
|
||||
Host string `json:"host,omitempty" yaml:"host,omitempty"`
|
||||
Port int `json:"port,omitempty" yaml:"port,omitempty"`
|
||||
Database string `json:"database,omitempty" yaml:"database,omitempty"`
|
||||
Collection string `json:"collection,omitempty" yaml:"collection,omitempty"`
|
||||
Username string `json:"username,omitempty" yaml:"username,omitempty"`
|
||||
Password string `json:"password,omitempty" yaml:"password,omitempty"`
|
||||
}
|
||||
@@ -0,0 +1,169 @@
|
||||
package embedding
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/common"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
||||
)
|
||||
|
||||
const (
|
||||
DASHSCOPE_DOMAIN = "dashscope.aliyuncs.com"
|
||||
DASHSCOPE_PORT = 443
|
||||
DASHSCOPE_DEFAULT_MODEL_NAME = "text-embedding-v4"
|
||||
DASHSCOPE_ENDPOINT = "/api/v1/services/embeddings/text-embedding/text-embedding"
|
||||
)
|
||||
|
||||
var dashScopeConfig dashScopeProviderConfig
|
||||
|
||||
type dashScopeProviderInitializer struct {
|
||||
}
|
||||
type dashScopeProviderConfig struct {
|
||||
apiKey string
|
||||
model string
|
||||
}
|
||||
|
||||
func (c *dashScopeProviderInitializer) InitConfig(config config.EmbeddingConfig) {
|
||||
dashScopeConfig.apiKey = config.APIKey
|
||||
dashScopeConfig.model = config.Model
|
||||
}
|
||||
|
||||
func (c *dashScopeProviderInitializer) ValidateConfig() error {
|
||||
if dashScopeConfig.apiKey == "" {
|
||||
return errors.New("[DashScope] apiKey is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *dashScopeProviderInitializer) CreateProvider(config config.EmbeddingConfig) (Provider, error) {
|
||||
c.InitConfig(config)
|
||||
err := c.ValidateConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
headers := map[string]string{
|
||||
"Authorization": "Bearer " + config.APIKey,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
httpClient := common.NewHTTPClient(fmt.Sprintf("https://%s", DASHSCOPE_DOMAIN), headers)
|
||||
|
||||
return &DashScopeProvider{
|
||||
config: dashScopeConfig,
|
||||
client: httpClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *DashScopeProvider) GetProviderType() string {
|
||||
return PROVIDER_TYPE_DASHSCOPE
|
||||
}
|
||||
|
||||
type Embedding struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
TextIndex int `json:"text_index"`
|
||||
}
|
||||
|
||||
type Input struct {
|
||||
Texts []string `json:"texts"`
|
||||
}
|
||||
|
||||
type Params struct {
|
||||
TextType string `json:"text_type"`
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
RequestID string `json:"request_id"`
|
||||
Output Output `json:"output"`
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type Output struct {
|
||||
Embeddings []Embedding `json:"embeddings"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input Input `json:"input"`
|
||||
Parameters Params `json:"parameters"`
|
||||
}
|
||||
|
||||
type Document struct {
|
||||
Vector []float64 `json:"vector"`
|
||||
Fields map[string]string `json:"fields"`
|
||||
}
|
||||
|
||||
type DashScopeProvider struct {
|
||||
config dashScopeProviderConfig
|
||||
client *common.HTTPClient
|
||||
}
|
||||
|
||||
func (d *DashScopeProvider) constructRequestData(texts []string) (EmbeddingRequest, error) {
|
||||
model := d.config.model
|
||||
if model == "" {
|
||||
model = DASHSCOPE_DEFAULT_MODEL_NAME
|
||||
}
|
||||
|
||||
if dashScopeConfig.apiKey == "" {
|
||||
return EmbeddingRequest{}, errors.New("dashScopeKey is empty")
|
||||
}
|
||||
|
||||
data := EmbeddingRequest{
|
||||
Model: model,
|
||||
Input: Input{
|
||||
Texts: texts,
|
||||
},
|
||||
Parameters: Params{
|
||||
TextType: "query",
|
||||
},
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
ID string `json:"id"`
|
||||
Vector []float32 `json:"vector,omitempty"`
|
||||
Fields map[string]interface{} `json:"fields"`
|
||||
Score float64 `json:"score"`
|
||||
}
|
||||
|
||||
func (d *DashScopeProvider) parseTextEmbedding(responseBody []byte) (*Response, error) {
|
||||
var resp Response
|
||||
err := json.Unmarshal(responseBody, &resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (d *DashScopeProvider) GetEmbedding(
|
||||
ctx context.Context,
|
||||
queryString string) ([]float32, error) {
|
||||
requestData, err := d.constructRequestData([]string{queryString})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to construct request data: %v", err)
|
||||
}
|
||||
|
||||
responseBody, err := d.client.Post(DASHSCOPE_ENDPOINT, requestData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %v", err)
|
||||
}
|
||||
|
||||
embeddingResp, err := d.parseTextEmbedding(responseBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if len(embeddingResp.Output.Embeddings) == 0 {
|
||||
return nil, errors.New("no embedding found in response")
|
||||
}
|
||||
|
||||
return embeddingResp.Output.Embeddings[0].Embedding, nil
|
||||
}
|
||||
161
plugins/golang-filter/mcp-server/servers/rag/embedding/openai.go
Normal file
161
plugins/golang-filter/mcp-server/servers/rag/embedding/openai.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package embedding
|
||||
|
||||
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"strings"

	"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/common"
	"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
)
|
||||
|
||||
const (
|
||||
OPENAI_DOMAIN = "api.openai.com"
|
||||
OPENAI_PORT = 443
|
||||
OPENAI_DEFAULT_MODEL_NAME = "text-embedding-3-small"
|
||||
OPENAI_ENDPOINT = "/v1/embeddings"
|
||||
)
|
||||
|
||||
type openAIProviderInitializer struct {
|
||||
}
|
||||
|
||||
var openAIConfig openAIProviderConfig
|
||||
|
||||
type openAIProviderConfig struct {
|
||||
baseUrl string
|
||||
apiKey string
|
||||
model string
|
||||
}
|
||||
|
||||
func (c *openAIProviderInitializer) InitConfig(config config.EmbeddingConfig) {
|
||||
openAIConfig.apiKey = config.APIKey
|
||||
openAIConfig.model = config.Model
|
||||
openAIConfig.baseUrl = config.BaseURL
|
||||
}
|
||||
|
||||
func (c *openAIProviderInitializer) ValidateConfig() error {
|
||||
if openAIConfig.apiKey == "" {
|
||||
return errors.New("[openAI] apiKey is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *openAIProviderInitializer) CreateProvider(config config.EmbeddingConfig) (Provider, error) {
|
||||
c.InitConfig(config)
|
||||
err := c.ValidateConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if openAIConfig.model == "" {
|
||||
openAIConfig.model = OPENAI_DEFAULT_MODEL_NAME
|
||||
}
|
||||
|
||||
if openAIConfig.baseUrl == "" {
|
||||
openAIConfig.baseUrl = fmt.Sprintf("https://%s", OPENAI_DOMAIN)
|
||||
}
|
||||
|
||||
headers := map[string]string{
|
||||
"Authorization": "Bearer " + config.APIKey,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
httpClient := common.NewHTTPClient(openAIConfig.baseUrl, headers)
|
||||
|
||||
return &OpenAIProvider{
|
||||
config: openAIConfig,
|
||||
client: httpClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (o *OpenAIProvider) GetProviderType() string {
|
||||
return PROVIDER_TYPE_OPENAI
|
||||
}
|
||||
|
||||
type OpenAIResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []OpenAIResult `json:"data"`
|
||||
Model string `json:"model"`
|
||||
Error *OpenAIError `json:"error"`
|
||||
}
|
||||
|
||||
type OpenAIResult struct {
|
||||
Object string `json:"object"`
|
||||
Embedding []float32 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type OpenAIError struct {
|
||||
Message string `json:"prompt_tokens"`
|
||||
Type string `json:"type"`
|
||||
Code string `json:"code"`
|
||||
Param string `json:"param"`
|
||||
}
|
||||
|
||||
type OpenAIEmbeddingRequest struct {
|
||||
Input string `json:"input"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
type OpenAIProvider struct {
|
||||
config openAIProviderConfig
|
||||
client *common.HTTPClient
|
||||
}
|
||||
|
||||
func (o *OpenAIProvider) constructRequestData(text string) (OpenAIEmbeddingRequest, error) {
|
||||
if text == "" {
|
||||
return OpenAIEmbeddingRequest{}, errors.New("queryString text cannot be empty")
|
||||
}
|
||||
|
||||
if openAIConfig.apiKey == "" {
|
||||
return OpenAIEmbeddingRequest{}, errors.New("openAI apiKey is empty")
|
||||
}
|
||||
|
||||
model := o.config.model
|
||||
if model == "" {
|
||||
model = OPENAI_DEFAULT_MODEL_NAME
|
||||
}
|
||||
|
||||
data := OpenAIEmbeddingRequest{
|
||||
Input: text,
|
||||
Model: model,
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (o *OpenAIProvider) parseTextEmbedding(responseBody []byte) (*OpenAIResponse, error) {
|
||||
var resp OpenAIResponse
|
||||
err := json.Unmarshal(responseBody, &resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (o *OpenAIProvider) GetEmbedding(ctx context.Context, queryString string) ([]float32, error) {
|
||||
requestData, err := o.constructRequestData(queryString)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to construct request data: %v", err)
|
||||
}
|
||||
|
||||
responseBody, err := o.client.Post(OPENAI_ENDPOINT, requestData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %v", err)
|
||||
}
|
||||
|
||||
resp, err := o.parseTextEmbedding(responseBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Error != nil {
|
||||
return nil, fmt.Errorf("OpenAI API error: %s - %s", resp.Error.Type, resp.Error.Message)
|
||||
}
|
||||
|
||||
if len(resp.Data) == 0 {
|
||||
return nil, errors.New("no embedding found in response")
|
||||
}
|
||||
|
||||
return resp.Data[0].Embedding, nil
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package embedding
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
||||
)
|
||||
|
||||
// Provider type constants for different embedding services
|
||||
const (
|
||||
// DashScope embedding service
|
||||
PROVIDER_TYPE_DASHSCOPE = "dashscope"
|
||||
// TextIn embedding service
|
||||
PROVIDER_TYPE_TEXTIN = "textin"
|
||||
// Cohere embedding service
|
||||
PROVIDER_TYPE_COHERE = "cohere"
|
||||
// OpenAI embedding service
|
||||
PROVIDER_TYPE_OPENAI = "openai"
|
||||
// Ollama embedding service
|
||||
PROVIDER_TYPE_OLLAMA = "ollama"
|
||||
// HuggingFace embedding service
|
||||
PROVIDER_TYPE_HUGGINGFACE = "huggingface"
|
||||
// XFYun embedding service
|
||||
PROVIDER_TYPE_XFYUN = "xfyun"
|
||||
// Azure embedding service
|
||||
PROVIDER_TYPE_AZURE = "azure"
|
||||
)
|
||||
|
||||
// Factory interface for creating Provider instances
|
||||
type providerInitializer interface {
|
||||
// Creates a new Provider with the given configuration
|
||||
CreateProvider(config.EmbeddingConfig) (Provider, error)
|
||||
}
|
||||
|
||||
// Maps provider types to their initializers
|
||||
var (
|
||||
providerInitializers = map[string]providerInitializer{
|
||||
PROVIDER_TYPE_DASHSCOPE: &dashScopeProviderInitializer{},
|
||||
PROVIDER_TYPE_OPENAI: &openAIProviderInitializer{},
|
||||
}
|
||||
)
|
||||
|
||||
// Provider defines the interface for embedding services
|
||||
type Provider interface {
|
||||
// Returns the provider type identifier
|
||||
GetProviderType() string
|
||||
// Generates embedding vector for the input text
|
||||
// Returns a float32 array representing the embedding vector
|
||||
GetEmbedding(ctx context.Context, queryString string) ([]float32, error)
|
||||
}
|
||||
|
||||
// Creates a new embedding Provider based on the configuration
|
||||
// Returns error if provider type is not supported
|
||||
func NewEmbeddingProvider(config config.EmbeddingConfig) (Provider, error) {
|
||||
initializer, ok := providerInitializers[config.Provider]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no initializer found for provider type: %s", config.Provider)
|
||||
}
|
||||
return initializer.CreateProvider(config)
|
||||
}
|
||||
136
plugins/golang-filter/mcp-server/servers/rag/llm/openai.go
Normal file
136
plugins/golang-filter/mcp-server/servers/rag/llm/openai.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/common"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
||||
)
|
||||
|
||||
const (
|
||||
OPENAI_CHAT_ENDPOINT = "/chat/completions"
|
||||
OPENAI_DEFAULT_MODEL = "gpt-3.5-turbo"
|
||||
)
|
||||
|
||||
// openAI specific configuration captured after initialization.
|
||||
type openAIProviderConfig struct {
|
||||
apiKey string
|
||||
baseURL string
|
||||
model string
|
||||
maxTokens int
|
||||
temperature float64
|
||||
}
|
||||
|
||||
type openAIProviderInitializer struct{}
|
||||
|
||||
var openAIConfig openAIProviderConfig
|
||||
|
||||
func (i *openAIProviderInitializer) initConfig(c config.LLMConfig) {
|
||||
openAIConfig.apiKey = c.APIKey
|
||||
openAIConfig.baseURL = c.BaseURL
|
||||
openAIConfig.model = c.Model
|
||||
if openAIConfig.model == "" {
|
||||
openAIConfig.model = OPENAI_DEFAULT_MODEL
|
||||
}
|
||||
if openAIConfig.baseURL == "" {
|
||||
openAIConfig.baseURL = "https://api.openai.com/v1" // default public endpoint
|
||||
}
|
||||
openAIConfig.maxTokens = c.MaxTokens
|
||||
openAIConfig.temperature = c.Temperature
|
||||
}
|
||||
|
||||
func (i *openAIProviderInitializer) validateConfig() error {
|
||||
if openAIConfig.apiKey == "" {
|
||||
return errors.New("[openai llm] apiKey is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *openAIProviderInitializer) CreateProvider(cfg config.LLMConfig) (Provider, error) {
|
||||
i.initConfig(cfg)
|
||||
if err := i.validateConfig(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
headers := map[string]string{
|
||||
"Authorization": "Bearer " + openAIConfig.apiKey,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
client := common.NewHTTPClient(openAIConfig.baseURL, headers)
|
||||
return &OpenAIProvider{client: client, cfg: openAIConfig}, nil
|
||||
}
|
||||
|
||||
type OpenAIProvider struct {
|
||||
client *common.HTTPClient
|
||||
cfg openAIProviderConfig
|
||||
}
|
||||
|
||||
// openAIChatCompletionRequest is the JSON body sent to the chat-completions
// endpoint. Temperature and MaxTokens are omitted when zero.
type openAIChatCompletionRequest struct {
	Model       string              `json:"model"`
	Messages    []openAIChatMessage `json:"messages"`
	Temperature float64             `json:"temperature,omitempty"`
	MaxTokens   int                 `json:"max_tokens,omitempty"`
}

// openAIChatMessage is a single chat turn ("user", "assistant", ...).
type openAIChatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// openAIChatCompletionResponse is the JSON body returned by the
// chat-completions endpoint; Error is populated on API-level failures.
type openAIChatCompletionResponse struct {
	ID      string                               `json:"id"`
	Object  string                               `json:"object"`
	Choices []openAIChatCompletionResponseChoice `json:"choices"`
	Error   *openAIError                         `json:"error,omitempty"`
}

// openAIChatCompletionResponseChoice is one candidate completion.
type openAIChatCompletionResponseChoice struct {
	Index        int               `json:"index"`
	Message      openAIChatMessage `json:"message"`
	FinishReason string            `json:"finish_reason"`
}

// openAIError mirrors the error object embedded in OpenAI API responses.
type openAIError struct {
	Message string `json:"message"`
	Type    string `json:"type"`
	Code    string `json:"code"`
	Param   string `json:"param"`
}
|
||||
|
||||
// GenerateCompletion implements Provider interface.
|
||||
func (o *OpenAIProvider) GenerateCompletion(ctx context.Context, prompt string) (string, error) {
|
||||
req := openAIChatCompletionRequest{
|
||||
Model: o.cfg.model,
|
||||
Messages: []openAIChatMessage{
|
||||
{Role: "user", Content: prompt},
|
||||
},
|
||||
Temperature: o.cfg.temperature,
|
||||
MaxTokens: o.cfg.maxTokens,
|
||||
}
|
||||
|
||||
body, err := o.client.Post(OPENAI_CHAT_ENDPOINT, req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("openai llm post error: %w", err)
|
||||
}
|
||||
|
||||
var resp openAIChatCompletionResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return "", fmt.Errorf("openai llm unmarshal error: %w", err)
|
||||
}
|
||||
|
||||
if resp.Error != nil {
|
||||
return "", fmt.Errorf("openai llm api error: %s - %s", resp.Error.Type, resp.Error.Message)
|
||||
}
|
||||
|
||||
if len(resp.Choices) == 0 {
|
||||
return "", errors.New("openai llm: empty choices")
|
||||
}
|
||||
|
||||
return resp.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
func (o *OpenAIProvider) GetProviderType() string {
|
||||
return PROVIDER_TYPE_OPENAI
|
||||
}
|
||||
24
plugins/golang-filter/mcp-server/servers/rag/llm/prompt.go
Normal file
24
plugins/golang-filter/mcp-server/servers/rag/llm/prompt.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
const RAGPromptTemplate = `You are a professional knowledge Q&A assistant. Your task is to provide accurate, complete, and strictly relevant answers based on the user's question and retrieved context.
|
||||
|
||||
Retrieved relevant context (may be empty, multiple segments separated by line breaks):
|
||||
{contexts}
|
||||
|
||||
User question:
|
||||
{query}
|
||||
|
||||
Requirements:
|
||||
1. If the context provides sufficient information, answer directly based on the context. You may use domain knowledge to supplement, but do not fabricate facts beyond the context.
|
||||
2. If the context is insufficient or unrelated to the question, respond with: "I am unable to answer this question."
|
||||
3. Your response must correctly answer the user's question and must not contain any irrelevant or unrelated content.`
|
||||
|
||||
func BuildPrompt(query string, contexts []string, join string) string {
|
||||
rendered := strings.ReplaceAll(RAGPromptTemplate, "{query}", query)
|
||||
rendered = strings.ReplaceAll(rendered, "{contexts}", strings.Join(contexts, join))
|
||||
return rendered
|
||||
}
|
||||
53
plugins/golang-filter/mcp-server/servers/rag/llm/provider.go
Normal file
53
plugins/golang-filter/mcp-server/servers/rag/llm/provider.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
||||
)
|
||||
|
||||
const (
	// OpenAI LLM provider
	PROVIDER_TYPE_OPENAI = "openai"
	// More providers can be added (e.g., Qwen)
)

// Provider defines the interface for LLM providers with a prompt-response
// pattern. Extensible for future chat-style and streaming features.
type Provider interface {
	// GetProviderType returns the provider type used for registration and lookup.
	GetProviderType() string

	// GenerateCompletion generates a text response for the given prompt.
	//
	// ctx: for cancellation and timeout
	// prompt: input text
	// Returns the generated response and an error, if any.
	GenerateCompletion(ctx context.Context, prompt string) (string, error)
}

// providerInitializer is the factory interface for creating Provider instances.
type providerInitializer interface {
	// CreateProvider creates a Provider from the given config.
	CreateProvider(config.LLMConfig) (Provider, error)
}

// providerInitializers maps provider type names to their initializers; new
// providers register here.
var (
	providerInitializers = map[string]providerInitializer{
		PROVIDER_TYPE_OPENAI: &openAIProviderInitializer{},
	}
)
|
||||
|
||||
// Creates Provider instance based on config
|
||||
//
|
||||
// cfg: Provider config
|
||||
// Returns: Provider instance and error if any
|
||||
func NewLLMProvider(cfg config.LLMConfig) (Provider, error) {
|
||||
initializer, ok := providerInitializers[cfg.Provider]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no initializer found for llm provider type: %s", cfg.Provider)
|
||||
}
|
||||
return initializer.CreateProvider(cfg)
|
||||
}
|
||||
158
plugins/golang-filter/mcp-server/servers/rag/rag_client.go
Normal file
158
plugins/golang-filter/mcp-server/servers/rag/rag_client.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/embedding"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/llm"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/schema"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/textsplitter"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/vectordb"
|
||||
"github.com/distribution/distribution/v3/uuid"
|
||||
)
|
||||
|
||||
const (
	// Row-count caps applied when listing without pagination.
	MAX_LIST_KNOWLEDGE_ROW_COUNT = 1000
	MAX_LIST_DOCUMENT_ROW_COUNT  = 1000
)

// RAGClient represents the RAG (Retrieval-Augmented Generation) client. It
// bundles the text splitter, embedding provider, LLM provider and vector
// store behind one ingestion/search/chat facade.
type RAGClient struct {
	config            *config.Config
	vectordbProvider  vectordb.VectorStoreProvider
	embeddingProvider embedding.Provider
	textSplitter      textsplitter.TextSplitter
	llmProvider       llm.Provider
}
|
||||
|
||||
// NewRAGClient creates a new RAG client instance
|
||||
func NewRAGClient(config *config.Config) (*RAGClient, error) {
|
||||
ragclient := &RAGClient{
|
||||
config: config,
|
||||
}
|
||||
textSplitter, err := textsplitter.NewTextSplitter(&config.RAG.Splitter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create text splitter failed, err: %w", err)
|
||||
}
|
||||
ragclient.textSplitter = textSplitter
|
||||
|
||||
embeddingProvider, err := embedding.NewEmbeddingProvider(ragclient.config.Embedding)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create embedding provider failed, err: %w", err)
|
||||
}
|
||||
ragclient.embeddingProvider = embeddingProvider
|
||||
|
||||
llmProvider, err := llm.NewLLMProvider(ragclient.config.LLM)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create llm provider failed, err: %w", err)
|
||||
}
|
||||
ragclient.llmProvider = llmProvider
|
||||
|
||||
demoVector, err := embeddingProvider.GetEmbedding(context.Background(), "initialization")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create init embedding failed, err: %w", err)
|
||||
}
|
||||
dim := len(demoVector)
|
||||
|
||||
provider, err := vectordb.NewVectorDBProvider(&ragclient.config.VectorDB, dim)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create vector store provider failed, err: %w", err)
|
||||
}
|
||||
ragclient.vectordbProvider = provider
|
||||
|
||||
return ragclient, nil
|
||||
}
|
||||
|
||||
// ListChunks lists document chunks by knowledge ID, returns in ascending order of DocumentIndex
|
||||
func (r *RAGClient) ListChunks() ([]schema.Document, error) {
|
||||
docs, err := r.vectordbProvider.ListDocs(context.Background(), MAX_LIST_DOCUMENT_ROW_COUNT)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list chunks failed, err: %w", err)
|
||||
}
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
// DeleteChunk deletes a specific document chunk
|
||||
func (r *RAGClient) DeleteChunk(id string) error {
|
||||
if err := r.vectordbProvider.DeleteDocs(context.Background(), []string{id}); err != nil {
|
||||
return fmt.Errorf("delete chunk failed, err: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RAGClient) CreateChunkFromText(text string, title string) ([]schema.Document, error) {
|
||||
|
||||
docs, err := textsplitter.CreateDocuments(r.textSplitter, []string{text}, make([]map[string]any, 0))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create documents failed, err: %w", err)
|
||||
}
|
||||
|
||||
results := make([]schema.Document, 0, len(docs))
|
||||
|
||||
for chunkIndex, doc := range docs {
|
||||
doc.ID = uuid.Generate().String()
|
||||
doc.Metadata["chunk_index"] = chunkIndex
|
||||
doc.Metadata["chunk_title"] = title
|
||||
doc.Metadata["chunk_size"] = len(doc.Content)
|
||||
// Generate embedding for the document
|
||||
embedding, err := r.embeddingProvider.GetEmbedding(context.Background(), doc.Content)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create embedding failed, err: %w", err)
|
||||
}
|
||||
doc.Vector = embedding
|
||||
doc.CreatedAt = time.Now()
|
||||
results = append(results, doc)
|
||||
}
|
||||
|
||||
if err := r.vectordbProvider.AddDoc(context.Background(), results); err != nil {
|
||||
return nil, fmt.Errorf("add documents failed, err: %w", err)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// SearchChunks searches for document chunks
|
||||
func (r *RAGClient) SearchChunks(query string, topK int, threshold float64) ([]schema.SearchResult, error) {
|
||||
|
||||
vector, err := r.embeddingProvider.GetEmbedding(context.Background(), query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create embedding failed, err: %w", err)
|
||||
}
|
||||
options := &schema.SearchOptions{
|
||||
TopK: topK,
|
||||
Threshold: threshold,
|
||||
}
|
||||
docs, err := r.vectordbProvider.SearchDocs(context.Background(), vector, options)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search chunks failed, err: %w", err)
|
||||
}
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
// Chat generates a response using LLM
|
||||
func (r *RAGClient) Chat(query string) (string, error) {
|
||||
if r.llmProvider == nil {
|
||||
return "", fmt.Errorf("llm provider not initialized")
|
||||
}
|
||||
|
||||
docs, err := r.SearchChunks(query, r.config.RAG.TopK, r.config.RAG.Threshold)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("search chunks failed, err: %w", err)
|
||||
}
|
||||
|
||||
contexts := make([]string, 0, len(docs))
|
||||
for _, doc := range docs {
|
||||
contexts = append(contexts, strings.ReplaceAll(doc.Document.Content, "\n", " "))
|
||||
}
|
||||
|
||||
prompt := llm.BuildPrompt(query, contexts, "\n\n")
|
||||
resp, err := r.llmProvider.GenerateCompletion(context.Background(), prompt)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate completion failed, err: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
150
plugins/golang-filter/mcp-server/servers/rag/rag_client_test.go
Normal file
150
plugins/golang-filter/mcp-server/servers/rag/rag_client_test.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
||||
)
|
||||
|
||||
// getRAGClient builds a RAGClient from a hard-coded test configuration.
//
// NOTE(review): this requires live external services — an OpenRouter-style
// LLM endpoint, the DashScope embedding API and a Milvus instance on
// localhost:19530 — and the API keys below are placeholders ("sk-xxxx").
// The tests in this file are therefore integration tests and will fail in an
// isolated environment; consider gating them behind an environment variable.
func getRAGClient() (*RAGClient, error) {
	config := &config.Config{
		RAG: config.RAGConfig{
			Splitter: config.SplitterConfig{
				Provider:     "recursive",
				ChunkSize:    200,
				ChunkOverlap: 20,
			},
			Threshold: 0.5,
			TopK:      10,
		},

		LLM: config.LLMConfig{
			Provider: "openai",
			APIKey:   "sk-xxxx", // placeholder — must be replaced to run
			BaseURL:  "https://openrouter.ai/api/v1",
			Model:    "openai/gpt-4o",
		},

		Embedding: config.EmbeddingConfig{
			Provider: "dashscope",
			APIKey:   "sk-xxxx", // placeholder — must be replaced to run
			Model:    "text-embedding-v4",
		},

		VectorDB: config.VectorDBConfig{
			Provider:   "milvus",
			Host:       "localhost",
			Port:       19530,
			Database:   "default",
			Collection: "test_collection",
		},
	}

	ragClient, err := NewRAGClient(config)
	if err != nil {
		return nil, err
	}

	return ragClient, nil

}
||||
|
||||
// TestNewRAGClient verifies that the full client (splitter, embedding, LLM,
// vector store) can be constructed against the live test services.
func TestNewRAGClient(t *testing.T) {
	_, err := getRAGClient()
	if err != nil {
		t.Errorf("getRAGClient() error = %v", err)
		return
	}
}

// TestRAGClient_CreateChunkFromText ingests a paragraph that fits in a single
// 200-char chunk and expects exactly one stored document.
func TestRAGClient_CreateChunkFromText(t *testing.T) {
	ragClient, err := getRAGClient()
	if err != nil {
		t.Errorf("getRAGClient() error = %v", err)
		return
	}
	text := "The multi-agent interaction technology competition based on the openKylin desktop environment aims to promote the development of agent applications on the openKylin open-source OS, using the Kirin AI inference framework and the UKUI desktop environment. These applications should have autonomous planning and decision-making capabilities, access to system resources, and the ability to call system and desktop environment interfaces and tools, with memory functions. They should also be able to collaborate with other agent applications. The competition aims to deeply explore the integration of operating systems and AI and help enhance the international competitiveness of domestic open-source operating systems."
	chunkName := "test_chunk3"
	docs, err := ragClient.CreateChunkFromText(text, chunkName)
	if err != nil {
		t.Errorf("CreateChunkFromText() error = %v", err)
		return
	}
	// NOTE(review): the text above is longer than the 200-char chunk size, so
	// "want 1" presumably relies on the splitter's length function — confirm.
	if len(docs) != 1 {
		t.Errorf("CreateChunkFromText() docs len = %d, want 1", len(docs))
		return
	}

}

// TestRAGClient_ListChunks assumes the collection already contains data
// (e.g. from a prior CreateChunkFromText run) — the tests in this file are
// order-dependent.
func TestRAGClient_ListChunks(t *testing.T) {
	ragClient, err := getRAGClient()
	if err != nil {
		t.Errorf("getRAGClient() error = %v", err)
		return
	}

	docs, err := ragClient.ListChunks()
	if err != nil {
		t.Errorf("ListChunks() error = %v", err)
		return
	}
	if len(docs) == 0 {
		t.Errorf("ListChunks() docs len = %d, want > 0", len(docs))
		return
	}
}

// TestRAGClient_DeleteChunk deletes a fixed chunk id.
// NOTE(review): the id is a hard-coded UUID from a previous manual run and
// will not exist in a fresh database — verify whether DeleteDocs treats a
// missing id as success.
func TestRAGClient_DeleteChunk(t *testing.T) {
	ragClient, err := getRAGClient()
	if err != nil {
		t.Errorf("getRAGClient() error = %v", err)
		return
	}

	chunk_id := "63ee25d7-41b9-4455-8066-075ca5c803b2"
	err = ragClient.DeleteChunk(chunk_id)
	if err != nil {
		t.Errorf("DeleteChunk() error = %v", err)
		return
	}
}

// TestRAGClient_SearchChunks runs a semantic search and expects exactly topk
// hits; this requires at least two matching chunks already in the store.
func TestRAGClient_SearchChunks(t *testing.T) {
	ragClient, err := getRAGClient()
	if err != nil {
		t.Errorf("getRAGClient() error = %v", err)
		return
	}
	topk := 2
	threshold := 0.5
	query := "multi-agent"
	docs, err := ragClient.SearchChunks(query, topk, threshold)
	if err != nil {
		t.Errorf("SearchChunks() error = %v", err)
		return
	}
	if len(docs) != topk {
		t.Errorf("SearchChunks() docs len = %d, want %d", len(docs), topk)
		return
	}

}

// TestRAGClient_Chat exercises the full retrieve-then-generate path and only
// checks for a non-empty LLM response.
func TestRAGClient_Chat(t *testing.T) {
	ragClient, err := getRAGClient()
	if err != nil {
		t.Errorf("getRAGClient() error = %v", err)
		return
	}
	query := "what is the competition about?"
	resp, err := ragClient.Chat(query)
	if err != nil {
		t.Errorf("Chat() error = %v", err)
		return
	}
	if resp == "" {
		t.Errorf("Chat() resp = %s, want not empty", resp)
		return
	}
}
|
||||
@@ -0,0 +1,46 @@
|
||||
package schema
|
||||
|
||||
import "time"
|
||||
|
||||
const (
	// Default collection names used by vector store providers.
	DEFAULT_KNOWLEDGE_COLLECTION = "knowledge"
	DEFAULT_DOCUMENT_COLLECTION  = "document"
)

// Document represents a document chunk with its vector embedding and metadata.
// The vector is excluded from JSON output (it can be large and is internal).
type Document struct {
	ID        string                 `json:"id"`
	Content   string                 `json:"content"`
	Vector    []float32              `json:"-"` // embedding; never serialized
	Metadata  map[string]interface{} `json:"metadata"`
	CreatedAt time.Time              `json:"created_at"`
}

// SearchResult pairs a retrieved document with its similarity score.
type SearchResult struct {
	Document Document `json:"document"`
	Score    float64  `json:"score"`
}

// Knowledge represents a knowledge entity with associated documents.
type Knowledge struct {
	ID        string `json:"id"`
	Name      string `json:"name"`
	SourceURL string `json:"source_url"`
	Status    string `json:"status"`
	FileSize  int64  `json:"file_size"`
	ChunkCount int   `json:"chunk_count"`
	// NOTE(review): "Multimodel" is likely a misspelling of "multimodal";
	// renaming would change the JSON contract, so it is kept as-is.
	EnableMultimodel bool                   `json:"enable_multimodel"`
	Metadata         map[string]interface{} `json:"metadata"`
	Documents        []Document             `json:"-"` // populated lazily; not serialized
	CreatedAt        time.Time              `json:"created_at"`
	UpdatedAt        time.Time              `json:"updated_at"`
	CompletedAt      time.Time              `json:"completed_at,omitempty"`
}

// SearchOptions contains options for a vector search.
type SearchOptions struct {
	TopK      int                    `json:"top_k"`
	Threshold float64                `json:"threshold"`
	Filters   map[string]interface{} `json:"filters,omitempty"`
}
|
||||
198
plugins/golang-filter/mcp-server/servers/rag/server.go
Normal file
198
plugins/golang-filter/mcp-server/servers/rag/server.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
const Version = "1.0.0"
|
||||
|
||||
type RAGConfig struct {
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
func init() {
|
||||
common.GlobalRegistry.RegisterServer("rag", &RAGConfig{
|
||||
config: &config.Config{
|
||||
RAG: config.RAGConfig{
|
||||
Splitter: config.SplitterConfig{
|
||||
Provider: "recursive",
|
||||
ChunkSize: 500,
|
||||
ChunkOverlap: 50,
|
||||
},
|
||||
Threshold: 0.5,
|
||||
TopK: 10,
|
||||
},
|
||||
LLM: config.LLMConfig{
|
||||
Provider: "openai",
|
||||
APIKey: "",
|
||||
BaseURL: "",
|
||||
Model: "gpt-4o",
|
||||
Temperature: 0.5,
|
||||
MaxTokens: 2048,
|
||||
},
|
||||
Embedding: config.EmbeddingConfig{
|
||||
Provider: "dashscope",
|
||||
APIKey: "",
|
||||
BaseURL: "",
|
||||
Model: "text-embedding-v4",
|
||||
Dimension: 1024,
|
||||
},
|
||||
VectorDB: config.VectorDBConfig{
|
||||
Provider: "milvus",
|
||||
Host: "localhost",
|
||||
Port: 6379,
|
||||
Database: "default",
|
||||
Collection: "rag",
|
||||
Username: "",
|
||||
Password: "",
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (c *RAGConfig) ParseConfig(config map[string]any) error {
|
||||
// Parse RAG configuration
|
||||
if ragConfig, ok := config["rag"].(map[string]any); ok {
|
||||
if splitter, exists := ragConfig["splitter"].(map[string]any); exists {
|
||||
if splitterType, exists := splitter["provider"].(string); exists {
|
||||
c.config.RAG.Splitter.Provider = splitterType
|
||||
}
|
||||
if chunkSize, exists := splitter["chunk_size"].(float64); exists {
|
||||
c.config.RAG.Splitter.ChunkSize = int(chunkSize)
|
||||
}
|
||||
if chunkOverlap, exists := splitter["chunk_overlap"].(float64); exists {
|
||||
c.config.RAG.Splitter.ChunkOverlap = int(chunkOverlap)
|
||||
}
|
||||
}
|
||||
if threshold, exists := ragConfig["threshold"].(float64); exists {
|
||||
c.config.RAG.Threshold = threshold
|
||||
}
|
||||
if topK, exists := ragConfig["top_k"].(float64); exists {
|
||||
c.config.RAG.TopK = int(topK)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse Embedding configuration
|
||||
if embeddingConfig, ok := config["embedding"].(map[string]any); ok {
|
||||
if provider, exists := embeddingConfig["provider"].(string); exists {
|
||||
c.config.Embedding.Provider = provider
|
||||
} else {
|
||||
return errors.New("missing embedding provider")
|
||||
}
|
||||
|
||||
if apiKey, exists := embeddingConfig["api_key"].(string); exists {
|
||||
c.config.Embedding.APIKey = apiKey
|
||||
}
|
||||
if baseURL, exists := embeddingConfig["base_url"].(string); exists {
|
||||
c.config.Embedding.BaseURL = baseURL
|
||||
}
|
||||
if model, exists := embeddingConfig["model"].(string); exists {
|
||||
c.config.Embedding.Model = model
|
||||
}
|
||||
if dimension, exists := embeddingConfig["dimension"].(float64); exists {
|
||||
c.config.Embedding.Dimension = int(dimension)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse llm configuration
|
||||
if llmConfig, ok := config["llm"].(map[string]any); ok {
|
||||
if provider, exists := llmConfig["provider"].(string); exists {
|
||||
c.config.LLM.Provider = provider
|
||||
} else {
|
||||
return errors.New("missing llm provider")
|
||||
}
|
||||
if apiKey, exists := llmConfig["api_key"].(string); exists {
|
||||
c.config.LLM.APIKey = apiKey
|
||||
}
|
||||
if baseURL, exists := llmConfig["base_url"].(string); exists {
|
||||
c.config.LLM.BaseURL = baseURL
|
||||
}
|
||||
if model, exists := llmConfig["model"].(string); exists {
|
||||
c.config.LLM.Model = model
|
||||
}
|
||||
if temperature, exists := llmConfig["temperature"].(float64); exists {
|
||||
c.config.LLM.Temperature = temperature
|
||||
}
|
||||
if maxTokens, exists := llmConfig["max_tokens"].(float64); exists {
|
||||
c.config.LLM.MaxTokens = int(maxTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse VectorDB configuration
|
||||
if vectordbConfig, ok := config["vectordb"].(map[string]any); ok {
|
||||
if provider, exists := vectordbConfig["provider"].(string); exists {
|
||||
c.config.VectorDB.Provider = provider
|
||||
} else {
|
||||
return errors.New("missing vectordb provider")
|
||||
}
|
||||
if host, exists := vectordbConfig["host"].(string); exists {
|
||||
c.config.VectorDB.Host = host
|
||||
}
|
||||
if port, exists := vectordbConfig["port"].(float64); exists {
|
||||
c.config.VectorDB.Port = int(port)
|
||||
}
|
||||
if dbName, exists := vectordbConfig["database"].(string); exists {
|
||||
c.config.VectorDB.Database = dbName
|
||||
}
|
||||
if collection, exists := vectordbConfig["collection"].(string); exists {
|
||||
c.config.VectorDB.Collection = collection
|
||||
}
|
||||
if username, exists := vectordbConfig["username"].(string); exists {
|
||||
c.config.VectorDB.Username = username
|
||||
}
|
||||
if password, exists := vectordbConfig["password"].(string); exists {
|
||||
c.config.VectorDB.Password = password
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewServer creates the MCP server, initializes the RAG client from the
// parsed configuration, and registers the knowledge-management, search and
// chat tools. Construction fails if any backing service (embedding, LLM,
// vector store) cannot be reached during client initialization.
func (c *RAGConfig) NewServer(serverName string) (*common.MCPServer, error) {
	mcpServer := common.NewMCPServer(
		serverName,
		Version,
		common.WithInstructions("This is a RAG (Retrieval-Augmented Generation) server for knowledge management and intelligent Q&A"),
	)

	// Initialize RAG client with configuration
	ragClient, err := NewRAGClient(c.config)
	if err != nil {
		return nil, fmt.Errorf("create rag client failed, err: %w", err)
	}

	// Knowledge Base Management Tools
	mcpServer.AddTool(
		mcp.NewToolWithRawSchema("create-chunks-from-text", "Process and segment input text into semantic chunks for knowledge base ingestion", GetCreateChunkFromTextSchema()),
		HandleCreateChunkFromText(ragClient),
	)

	// Chunk Management Tools
	mcpServer.AddTool(
		mcp.NewToolWithRawSchema("list-chunks", "Retrieve and display all knowledge chunks in the database", GetListChunksSchema()),
		HandleListChunks(ragClient),
	)
	mcpServer.AddTool(
		mcp.NewToolWithRawSchema("delete-chunk", "Remove a specific knowledge chunk from the database using its unique identifier", GetDeleteChunkSchema()),
		HandleDeleteChunk(ragClient),
	)

	// Semantic Search Tool
	mcpServer.AddTool(
		mcp.NewToolWithRawSchema("search-chunks", "Perform semantic search across knowledge chunks using natural language query", GetSearchSchema()),
		HandleSearch(ragClient),
	)

	// Intelligent Q&A Tool
	mcpServer.AddTool(
		mcp.NewToolWithRawSchema("chat", "Generate contextually relevant responses using RAG system with LLM integration", GetChatSchema()),
		HandleChat(ragClient),
	)

	return mcpServer, nil
}
|
||||
54
plugins/golang-filter/mcp-server/servers/rag/server_test.go
Normal file
54
plugins/golang-filter/mcp-server/servers/rag/server_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestRAGConfig_ParseConfig(t *testing.T) {
|
||||
config := &config.Config{
|
||||
RAG: config.RAGConfig{
|
||||
Splitter: config.SplitterConfig{
|
||||
Provider: "nosplitter",
|
||||
ChunkSize: 500,
|
||||
ChunkOverlap: 50,
|
||||
},
|
||||
Threshold: 0.5,
|
||||
TopK: 5,
|
||||
},
|
||||
LLM: config.LLMConfig{
|
||||
Provider: "openai",
|
||||
APIKey: "sk-XXX",
|
||||
BaseURL: "https://openrouter.ai/api/v1",
|
||||
Model: "openai/gpt-4o",
|
||||
Temperature: 0.5,
|
||||
MaxTokens: 2048,
|
||||
},
|
||||
Embedding: config.EmbeddingConfig{
|
||||
Provider: "dashscope",
|
||||
APIKey: "sk-XXX",
|
||||
BaseURL: "",
|
||||
Model: "text-embedding-v4",
|
||||
Dimension: 1024,
|
||||
},
|
||||
VectorDB: config.VectorDBConfig{
|
||||
Provider: "milvus",
|
||||
Host: "localhost",
|
||||
Port: 19530,
|
||||
Database: "default",
|
||||
Collection: "test_rag",
|
||||
Username: "",
|
||||
Password: "",
|
||||
},
|
||||
}
|
||||
// 把 config 输出 yaml 格式
|
||||
yaml, err := yaml.Marshal(config)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal config failed, err: %v", err)
|
||||
}
|
||||
t.Logf("config yaml: %s", string(yaml))
|
||||
fmt.Printf("\n%s", string(yaml))
|
||||
}
|
||||
@@ -0,0 +1,165 @@
|
||||
package textsplitter
|
||||
|
||||
import "unicode/utf8"
|
||||
|
||||
const (
	// Defaults for the token-based splitter.
	// nolint:gosec
	_defaultTokenModelName    = "gpt-3.5-turbo"
	_defaultTokenEncoding     = "cl100k_base"
	_defaultTokenChunkSize    = 512
	_defaultTokenChunkOverlap = 100
)

// Options is a struct that contains options for a text splitter.
type Options struct {
	ChunkSize            int              // target chunk size, measured by LenFunc
	ChunkOverlap         int              // content carried over between consecutive chunks
	Separators           []string         // split candidates, tried in order
	KeepSeparator        bool             // keep separators in the output chunks
	LenFunc              func(string) int // length measure (runes by default)
	ModelName            string           // tokenizer model (token-based splitters)
	EncodingName         string           // tokenizer encoding (token-based splitters)
	AllowedSpecial       []string         // special tokens permitted by the tokenizer
	DisallowedSpecial    []string         // special tokens rejected by the tokenizer
	SecondSplitter       TextSplitter     // fallback splitter for oversized fragments
	CodeBlocks           bool             // include fenced/indented code blocks (markdown)
	ReferenceLinks       bool             // resolve reference-style links (markdown)
	KeepHeadingHierarchy bool             // Persist hierarchy of markdown headers in each chunk
	JoinTableRows        bool             // join markdown table rows up to the chunk size
}

// DefaultOptions returns the default options for all text splitters:
// 512-rune chunks with 100-rune overlap, paragraph/line/word separators,
// separators dropped from output, and all special tokens disallowed.
func DefaultOptions() Options {
	return Options{
		ChunkSize:    _defaultTokenChunkSize,
		ChunkOverlap: _defaultTokenChunkOverlap,
		Separators:   []string{"\n\n", "\n", " ", ""},

		KeepSeparator: false,
		LenFunc:       utf8.RuneCountInString,

		ModelName:    _defaultTokenModelName,
		EncodingName: _defaultTokenEncoding,

		AllowedSpecial:    []string{},
		DisallowedSpecial: []string{"all"},

		KeepHeadingHierarchy: false,
	}
}
|
||||
|
||||
// Option is a function that can be used to set options for a text splitter.
type Option func(*Options)

// WithChunkSize sets the chunk size for a text splitter.
func WithChunkSize(chunkSize int) Option {
	return func(o *Options) {
		o.ChunkSize = chunkSize
	}
}

// WithChunkOverlap sets the chunk overlap for a text splitter.
func WithChunkOverlap(chunkOverlap int) Option {
	return func(o *Options) {
		o.ChunkOverlap = chunkOverlap
	}
}

// WithSeparators sets the separators for a text splitter.
func WithSeparators(separators []string) Option {
	return func(o *Options) {
		o.Separators = separators
	}
}

// WithLenFunc sets the length function used to measure chunk size
// (rune count by default).
func WithLenFunc(lenFunc func(string) int) Option {
	return func(o *Options) {
		o.LenFunc = lenFunc
	}
}

// WithModelName sets the tokenizer model name for a token-based text splitter.
func WithModelName(modelName string) Option {
	return func(o *Options) {
		o.ModelName = modelName
	}
}

// WithEncodingName sets the tokenizer encoding name for a token-based text splitter.
func WithEncodingName(encodingName string) Option {
	return func(o *Options) {
		o.EncodingName = encodingName
	}
}

// WithAllowedSpecial sets the allowed special tokens for a text splitter.
func WithAllowedSpecial(allowedSpecial []string) Option {
	return func(o *Options) {
		o.AllowedSpecial = allowedSpecial
	}
}

// WithDisallowedSpecial sets the disallowed special tokens for a text splitter.
func WithDisallowedSpecial(disallowedSpecial []string) Option {
	return func(o *Options) {
		o.DisallowedSpecial = disallowedSpecial
	}
}

// WithSecondSplitter sets the fallback splitter applied to fragments that are
// still too large after the primary split.
func WithSecondSplitter(secondSplitter TextSplitter) Option {
	return func(o *Options) {
		o.SecondSplitter = secondSplitter
	}
}

// WithCodeBlocks sets whether indented and fenced codeblocks should be included
// in the output.
func WithCodeBlocks(renderCode bool) Option {
	return func(o *Options) {
		o.CodeBlocks = renderCode
	}
}

// WithReferenceLinks sets whether reference links (i.e. `[text][label]`)
// should be patched with the url and title from their definition. Note that
// by default reference definitions are dropped from the output.
//
// Caution: this also affects how other inline elements are rendered, e.g. all
// emphasis will use `*` even when another character (e.g. `_`) was used in the
// input.
func WithReferenceLinks(referenceLinks bool) Option {
	return func(o *Options) {
		o.ReferenceLinks = referenceLinks
	}
}

// WithKeepSeparator sets whether the separators should be kept in the resulting
// split text or not. When it is set to True, the separators are included in the
// resulting split text. When it is set to False, the separators are not included
// in the resulting split text. The purpose of having this parameter is to provide
// flexibility in how text splitting is handled. Default to False if not specified.
func WithKeepSeparator(keepSeparator bool) Option {
	return func(o *Options) {
		o.KeepSeparator = keepSeparator
	}
}

// WithHeadingHierarchy sets whether the hierarchy of headings in a document should
// be persisted in the resulting chunks. When it is set to true, each chunk gets prepended
// with a list of all parent headings in the hierarchy up to this point.
// The purpose of having this parameter is to allow for returning more relevant chunks during
// similarity search. Default to False if not specified.
func WithHeadingHierarchy(trackHeadingHierarchy bool) Option {
	return func(o *Options) {
		o.KeepHeadingHierarchy = trackHeadingHierarchy
	}
}

// WithJoinTableRows sets whether tables should be split by row or not. When it is set to True,
// table rows are joined until the chunksize. When it is set to False (the default), tables are
// split by row.
//
// The default behavior is to split tables by row, so that each row is in a separate chunk.
func WithJoinTableRows(join bool) Option {
	return func(o *Options) {
		o.JoinTableRows = join
	}
}
|
||||
@@ -0,0 +1,106 @@
|
||||
package textsplitter
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// RecursiveCharacter is a text splitter that will split texts recursively by different
// characters.
type RecursiveCharacter struct {
	// Separators are the candidate split strings, tried in order.
	Separators []string
	// ChunkSize is the target maximum chunk length, as measured by LenFunc.
	ChunkSize int
	// ChunkOverlap is the amount of overlap carried between consecutive chunks.
	ChunkOverlap int
	// LenFunc measures text length (e.g. bytes, runes, or model tokens).
	LenFunc func(string) int
	// KeepSeparator controls whether separators are preserved in the output.
	KeepSeparator bool
}
|
||||
|
||||
// NewRecursiveCharacter creates a new recursive character splitter with default values. By
|
||||
// default, the separators used are "\n\n", "\n", " " and "". The chunk size is set to 4000
|
||||
// and chunk overlap is set to 200.
|
||||
func NewRecursiveCharacter(opts ...Option) RecursiveCharacter {
|
||||
options := DefaultOptions()
|
||||
for _, o := range opts {
|
||||
o(&options)
|
||||
}
|
||||
|
||||
s := RecursiveCharacter{
|
||||
Separators: options.Separators,
|
||||
ChunkSize: options.ChunkSize,
|
||||
ChunkOverlap: options.ChunkOverlap,
|
||||
LenFunc: options.LenFunc,
|
||||
KeepSeparator: options.KeepSeparator,
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// SplitText splits a text into multiple text.
|
||||
func (s RecursiveCharacter) SplitText(text string) ([]string, error) {
|
||||
return s.splitText(text, s.Separators)
|
||||
}
|
||||
|
||||
// addSeparatorInSplits adds the separator in each of splits.
|
||||
func (s RecursiveCharacter) addSeparatorInSplits(splits []string, separator string) []string {
|
||||
splitsWithSeparator := make([]string, 0, len(splits))
|
||||
for i, s := range splits {
|
||||
if i > 0 {
|
||||
s = separator + s
|
||||
}
|
||||
splitsWithSeparator = append(splitsWithSeparator, s)
|
||||
}
|
||||
return splitsWithSeparator
|
||||
}
|
||||
|
||||
func (s RecursiveCharacter) splitText(text string, separators []string) ([]string, error) {
|
||||
finalChunks := make([]string, 0)
|
||||
|
||||
// Find the appropriate separator.
|
||||
separator := separators[len(separators)-1]
|
||||
newSeparators := []string{}
|
||||
for i, c := range separators {
|
||||
if c == "" || strings.Contains(text, c) {
|
||||
separator = c
|
||||
newSeparators = separators[i+1:]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
splits := strings.Split(text, separator)
|
||||
if s.KeepSeparator {
|
||||
splits = s.addSeparatorInSplits(splits, separator)
|
||||
separator = ""
|
||||
}
|
||||
goodSplits := make([]string, 0)
|
||||
|
||||
// Merge the splits, recursively splitting larger texts.
|
||||
for _, split := range splits {
|
||||
if s.LenFunc(split) < s.ChunkSize {
|
||||
goodSplits = append(goodSplits, split)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(goodSplits) > 0 {
|
||||
mergedText := mergeSplits(goodSplits, separator, s.ChunkSize, s.ChunkOverlap, s.LenFunc)
|
||||
|
||||
finalChunks = append(finalChunks, mergedText...)
|
||||
goodSplits = make([]string, 0)
|
||||
}
|
||||
|
||||
if len(newSeparators) == 0 {
|
||||
finalChunks = append(finalChunks, split)
|
||||
} else {
|
||||
otherInfo, err := s.splitText(split, newSeparators)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
finalChunks = append(finalChunks, otherInfo...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(goodSplits) > 0 {
|
||||
mergedText := mergeSplits(goodSplits, separator, s.ChunkSize, s.ChunkOverlap, s.LenFunc)
|
||||
finalChunks = append(finalChunks, mergedText...)
|
||||
}
|
||||
|
||||
return finalChunks, nil
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
package textsplitter
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/schema"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
//nolint:dupword,funlen
func TestRecursiveCharacterSplitter(t *testing.T) {
	// Token encoder used only by the last case, which measures length in
	// model tokens instead of bytes.
	tokenEncoder, _ := tiktoken.GetEncoding("cl100k_base")

	// t.Parallel()
	type testCase struct {
		text          string
		chunkOverlap  int
		chunkSize     int
		separators    []string
		expectedDocs  []schema.Document
		keepSeparator bool
		LenFunc       func(string) int
	}
	testCases := []testCase{
		// Multi-byte (Chinese) text split on newlines.
		{
			text:         "哈里森\n很高兴遇见你\n欢迎来中国",
			chunkOverlap: 0,
			chunkSize:    10,
			separators:   []string{"\n\n", "\n", " "},
			expectedDocs: []schema.Document{
				{Content: "哈里森\n很高兴遇见你", Metadata: map[string]any{}},
				{Content: "欢迎来中国", Metadata: map[string]any{}},
			},
		},
		// Separator "$" never occurs; only "\n" is used.
		{
			text:         "Hi, Harrison. \nI am glad to meet you",
			chunkOverlap: 1,
			chunkSize:    20,
			separators:   []string{"\n", "$"},
			expectedDocs: []schema.Document{
				{Content: "Hi, Harrison.", Metadata: map[string]any{}},
				{Content: "I am glad to meet you", Metadata: map[string]any{}},
			},
		},
		// Paragraph separator "\n\n" takes priority over "\n".
		{
			text:         "Hi.\nI'm Harrison.\n\nHow?\na\nbHi.\nI'm Harrison.\n\nHow?\na\nb",
			chunkOverlap: 1,
			chunkSize:    40,
			separators:   []string{"\n\n", "\n", " ", ""},
			expectedDocs: []schema.Document{
				{Content: "Hi.\nI'm Harrison.", Metadata: map[string]any{}},
				{Content: "How?\na\nbHi.\nI'm Harrison.\n\nHow?\na\nb", Metadata: map[string]any{}},
			},
		},
		// Text shorter than the chunk size stays a single document.
		{
			text:         "name: Harrison\nage: 30",
			chunkOverlap: 1,
			chunkSize:    40,
			separators:   []string{"\n\n", "\n", " ", ""},
			expectedDocs: []schema.Document{
				{Content: "name: Harrison\nage: 30", Metadata: map[string]any{}},
			},
		},
		// Raw-string input with a blank-line paragraph break.
		{
			text: `name: Harrison
age: 30

name: Joe
age: 32`,
			chunkOverlap: 1,
			chunkSize:    40,
			separators:   []string{"\n\n", "\n", " ", ""},
			expectedDocs: []schema.Document{
				{Content: "name: Harrison\nage: 30", Metadata: map[string]any{}},
				{Content: "name: Joe\nage: 32", Metadata: map[string]any{}},
			},
		},
		// Tiny chunk size forces recursion down to word and character level.
		{
			text: `Hi.
I'm Harrison.

How? Are? You?
Okay then f f f f.
This is a weird text to write, but gotta test the splittingggg some how.

Bye!

-H.`,
			chunkOverlap: 1,
			chunkSize:    10,
			separators:   []string{"\n\n", "\n", " ", ""},
			expectedDocs: []schema.Document{
				{Content: "Hi.", Metadata: map[string]any{}},
				{Content: "I'm", Metadata: map[string]any{}},
				{Content: "Harrison.", Metadata: map[string]any{}},
				{Content: "How? Are?", Metadata: map[string]any{}},
				{Content: "You?", Metadata: map[string]any{}},
				{Content: "Okay then", Metadata: map[string]any{}},
				{Content: "f f f f.", Metadata: map[string]any{}},
				{Content: "This is a", Metadata: map[string]any{}},
				{Content: "a weird", Metadata: map[string]any{}},
				{Content: "text to", Metadata: map[string]any{}},
				{Content: "write, but", Metadata: map[string]any{}},
				{Content: "gotta test", Metadata: map[string]any{}},
				{Content: "the", Metadata: map[string]any{}},
				{Content: "splittingg", Metadata: map[string]any{}},
				{Content: "ggg", Metadata: map[string]any{}},
				{Content: "some how.", Metadata: map[string]any{}},
				{Content: "Bye!\n\n-H.", Metadata: map[string]any{}},
			},
		},
		// keepSeparator preserves the "\n" in the second chunk.
		{
			text:          "Hi, Harrison. \nI am glad to meet you",
			chunkOverlap:  0,
			chunkSize:     10,
			separators:    []string{"\n", "$"},
			keepSeparator: true,
			expectedDocs: []schema.Document{
				{Content: "Hi, Harrison. ", Metadata: map[string]any{}},
				{Content: "\nI am glad to meet you", Metadata: map[string]any{}},
			},
		},
		// Custom LenFunc: chunk size measured in cl100k_base tokens.
		{
			text:          strings.Repeat("The quick brown fox jumped over the lazy dog. ", 2),
			chunkOverlap:  0,
			chunkSize:     10,
			separators:    []string{" "},
			keepSeparator: true,
			LenFunc:       func(s string) int { return len(tokenEncoder.Encode(s, nil, nil)) },
			expectedDocs: []schema.Document{
				{Content: "The quick brown fox jumped over the lazy dog.", Metadata: map[string]any{}},
				{Content: "The quick brown fox jumped over the lazy dog.", Metadata: map[string]any{}},
			},
		},
	}
	// One splitter is reused; each case overwrites its configuration.
	splitter := NewRecursiveCharacter()
	for _, tc := range testCases {
		splitter.ChunkOverlap = tc.chunkOverlap
		splitter.ChunkSize = tc.chunkSize
		splitter.Separators = tc.separators
		splitter.KeepSeparator = tc.keepSeparator
		if tc.LenFunc != nil {
			splitter.LenFunc = tc.LenFunc
		}

		docs, err := CreateDocuments(splitter, []string{tc.text}, nil)
		require.NoError(t, err)
		assert.Equal(t, tc.expectedDocs, docs)
	}
}
|
||||
@@ -0,0 +1,133 @@
|
||||
package textsplitter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/schema"
|
||||
)
|
||||
|
||||
// ErrMismatchMetadatasAndText is returned when the number of texts and metadatas
// given to CreateDocuments does not match. The function will not error if the
// length of the metadatas slice is zero (metadata is then omitted entirely).
var ErrMismatchMetadatasAndText = errors.New("number of texts and metadatas does not match")
|
||||
|
||||
// SplitDocuments splits documents using a textsplitter.
|
||||
func SplitDocuments(textSplitter TextSplitter, documents []schema.Document) ([]schema.Document, error) {
|
||||
texts := make([]string, 0)
|
||||
metadatas := make([]map[string]any, 0)
|
||||
for _, document := range documents {
|
||||
texts = append(texts, document.Content)
|
||||
metadatas = append(metadatas, document.Metadata)
|
||||
}
|
||||
|
||||
return CreateDocuments(textSplitter, texts, metadatas)
|
||||
}
|
||||
|
||||
// CreateDocuments creates documents from texts and metadatas with a text splitter. If
|
||||
// the length of the metadatas is zero, the result documents will contain no metadata.
|
||||
// Otherwise, the numbers of texts and metadatas must match.
|
||||
func CreateDocuments(textSplitter TextSplitter, texts []string, metadatas []map[string]any) ([]schema.Document, error) {
|
||||
if len(metadatas) == 0 {
|
||||
metadatas = make([]map[string]any, len(texts))
|
||||
}
|
||||
|
||||
if len(texts) != len(metadatas) {
|
||||
return nil, ErrMismatchMetadatasAndText
|
||||
}
|
||||
|
||||
documents := make([]schema.Document, 0)
|
||||
|
||||
for i := 0; i < len(texts); i++ {
|
||||
chunks, err := textSplitter.SplitText(texts[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, chunk := range chunks {
|
||||
// Copy the document metadata
|
||||
curMetadata := make(map[string]any, len(metadatas[i]))
|
||||
for key, value := range metadatas[i] {
|
||||
curMetadata[key] = value
|
||||
}
|
||||
|
||||
documents = append(documents, schema.Document{
|
||||
Content: chunk,
|
||||
Metadata: curMetadata,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return documents, nil
|
||||
}
|
||||
|
||||
// joinDocs combines the document pieces with the separator used to split
// them, trimming surrounding whitespace from the result.
func joinDocs(docs []string, separator string) string {
	joined := strings.Join(docs, separator)
	return strings.TrimSpace(joined)
}
|
||||
|
||||
// mergeSplits merges smaller splits into splits that are closer to the chunkSize.
// It greedily accumulates splits into currentDoc; when adding the next split
// would exceed chunkSize, the accumulated chunk is emitted and old entries are
// popped from the front until the remainder fits within chunkOverlap, so
// consecutive chunks share trailing context.
func mergeSplits(splits []string, separator string, chunkSize int, chunkOverlap int, lenFunc func(string) int) []string { //nolint:cyclop
	docs := make([]string, 0)
	currentDoc := make([]string, 0)
	// total tracks the running length of currentDoc, including one separator
	// between each pair of entries.
	total := 0

	for _, split := range splits {
		totalWithSplit := total + lenFunc(split)
		if len(currentDoc) != 0 {
			// A separator will be inserted before this split when joining.
			totalWithSplit += lenFunc(separator)
		}

		maybePrintWarning(total, chunkSize)
		if totalWithSplit > chunkSize && len(currentDoc) > 0 {
			// Emit the accumulated chunk (joinDocs trims whitespace; an
			// all-whitespace chunk is dropped).
			doc := joinDocs(currentDoc, separator)
			if doc != "" {
				docs = append(docs, doc)
			}

			// Pop from the front until what remains fits the overlap budget.
			for len(currentDoc) > 0 && shouldPop(chunkOverlap, chunkSize, total, lenFunc(split), lenFunc(separator), len(currentDoc)) {
				total -= lenFunc(currentDoc[0]) //nolint:gosec
				if len(currentDoc) > 1 {
					total -= lenFunc(separator)
				}
				currentDoc = currentDoc[1:] //nolint:gosec
			}
		}

		currentDoc = append(currentDoc, split)
		total += lenFunc(split)
		if len(currentDoc) > 1 {
			total += lenFunc(separator)
		}
	}

	// Flush whatever is left in the accumulator.
	doc := joinDocs(currentDoc, separator)
	if doc != "" {
		docs = append(docs, doc)
	}

	return docs
}
|
||||
|
||||
// maybePrintWarning logs a warning when the accumulated chunk length already
// exceeds the configured chunk size (can happen when a single split is larger
// than chunkSize).
func maybePrintWarning(total, chunkSize int) {
	if total > chunkSize {
		log.Printf(
			// Fixed typo: "longer then" -> "longer than".
			"[WARN] created a chunk with size of %v, which is longer than the specified %v\n",
			total,
			chunkSize,
		)
	}
}
|
||||
|
||||
// shouldPop reports whether the oldest entry of the current chunk should be
// dropped while rebalancing for overlap. Keep popping if:
//   - the chunk is larger than the chunk overlap
//   - or if there are any chunks and the length is long
func shouldPop(chunkOverlap, chunkSize, total, splitLen, separatorLen, currentDocLen int) bool {
	if currentDocLen == 0 {
		return false
	}
	// A separator is only counted once at least two docs are present.
	const docsNeededToAddSep = 2
	sepLen := separatorLen
	if currentDocLen < docsNeededToAddSep {
		sepLen = 0
	}
	overflows := total+splitLen+sepLen > chunkSize && total > 0
	return total > chunkOverlap || overflows
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package textsplitter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
||||
)
|
||||
|
||||
// TextSplitter is the standard interface for splitting texts.
type TextSplitter interface {
	// SplitText splits text into chunks.
	SplitText(text string) ([]string, error)
}
|
||||
|
||||
// NoSplitterCharacter is a TextSplitter that performs no splitting: the input
// text is returned unchanged as a single chunk.
type NoSplitterCharacter struct{}

// SplitText returns the whole text as the only chunk; it never fails.
func (s NoSplitterCharacter) SplitText(text string) ([]string, error) {
	return []string{text}, nil
}
|
||||
|
||||
func NewTextSplitter(cfg *config.SplitterConfig) (TextSplitter, error) {
|
||||
switch cfg.Provider {
|
||||
case "recursive":
|
||||
return NewRecursiveCharacter(WithChunkSize(cfg.ChunkSize), WithChunkOverlap(cfg.ChunkOverlap), WithSeparators([]string{"\n\n", "\n", ".", "。", "?", "!", ";"})), nil
|
||||
case "nosplitter":
|
||||
return NoSplitterCharacter{}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown text splitter type: %s", cfg.Provider)
|
||||
}
|
||||
}
|
||||
355
plugins/golang-filter/mcp-server/servers/rag/tools.go
Normal file
355
plugins/golang-filter/mcp-server/servers/rag/tools.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-session/common"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// HandleCreateChunkFromText handles the creation of knowledge chunks from text input
|
||||
func HandleCreateChunkFromText(ragClient *RAGClient) common.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
arguments := request.Params.Arguments
|
||||
text, ok1 := arguments["text"].(string)
|
||||
title, ok2 := arguments["title"].(string)
|
||||
if !ok1 {
|
||||
return nil, fmt.Errorf("invalid text argument")
|
||||
}
|
||||
if !ok2 {
|
||||
return nil, fmt.Errorf("invalid title argument")
|
||||
}
|
||||
// Create knowledge chunks
|
||||
docs, err := ragClient.CreateChunkFromText(text, title)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create chunk failed, err: %w", err)
|
||||
}
|
||||
|
||||
result := map[string]interface{}{
|
||||
"success": true,
|
||||
"message": fmt.Sprintf("chunks created from text, title: %s", title),
|
||||
"data": docs,
|
||||
}
|
||||
|
||||
return buildCallToolResult(result)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleListChunks handles the listing of knowledge chunks
|
||||
func HandleListChunks(ragClient *RAGClient) common.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
chunks, err := ragClient.ListChunks()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list chunks failed, err: %w", err)
|
||||
}
|
||||
return buildCallToolResult(chunks)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleDeleteChunk handles the deletion of a knowledge chunk
|
||||
func HandleDeleteChunk(ragClient *RAGClient) common.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
arguments := request.Params.Arguments
|
||||
id, ok := arguments["id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid id argument")
|
||||
}
|
||||
|
||||
if err := ragClient.DeleteChunk(id); err != nil {
|
||||
return nil, fmt.Errorf("delete chunk failed, err: %w", err)
|
||||
}
|
||||
|
||||
result := map[string]interface{}{
|
||||
"success": true,
|
||||
"message": fmt.Sprintf("chunk deleted, id: %s", id),
|
||||
}
|
||||
|
||||
return buildCallToolResult(result)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleCreateSession handles the creation of a chat session
|
||||
func HandleCreateSession(ragClient *RAGClient) common.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
// TODO: Implement chat session creation logic
|
||||
result := map[string]interface{}{
|
||||
"session_id": "session-1",
|
||||
"created_at": "2024-01-01T00:00:00Z",
|
||||
}
|
||||
|
||||
return buildCallToolResult(result)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleGetSession handles retrieving session details
|
||||
func HandleGetSession(ragClient *RAGClient) common.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
arguments := request.Params.Arguments
|
||||
sessionId, ok := arguments["session_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid session_id argument")
|
||||
}
|
||||
|
||||
// TODO: Implement session details retrieval logic
|
||||
result := map[string]interface{}{
|
||||
"session_id": sessionId,
|
||||
"messages": []interface{}{},
|
||||
}
|
||||
|
||||
return buildCallToolResult(result)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleListSessions handles listing all sessions
|
||||
func HandleListSessions(ragClient *RAGClient) common.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
// TODO: Implement session listing logic
|
||||
result := map[string]interface{}{
|
||||
"sessions": []interface{}{},
|
||||
"total": 0,
|
||||
}
|
||||
|
||||
return buildCallToolResult(result)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleDeleteSession handles the deletion of a session
|
||||
func HandleDeleteSession(ragClient *RAGClient) common.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
arguments := request.Params.Arguments
|
||||
sessionId, ok := arguments["session_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid session_id argument")
|
||||
}
|
||||
|
||||
// TODO: Implement session deletion logic
|
||||
result := map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "Session deleted",
|
||||
"session_id": sessionId,
|
||||
}
|
||||
|
||||
return buildCallToolResult(result)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleSearch handles semantic search functionality
|
||||
func HandleSearch(ragClient *RAGClient) common.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
arguments := request.Params.Arguments
|
||||
query, ok := arguments["query"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid query argument")
|
||||
}
|
||||
topK, ok := arguments["topk"].(int)
|
||||
if !ok {
|
||||
topK = ragClient.config.RAG.TopK
|
||||
}
|
||||
|
||||
threshold, ok := arguments["threshold"].(float64)
|
||||
if !ok {
|
||||
threshold = ragClient.config.RAG.Threshold
|
||||
}
|
||||
|
||||
searchResult, err := ragClient.SearchChunks(query, int(topK), threshold)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search chunks failed, err: %w", err)
|
||||
}
|
||||
return buildCallToolResult(searchResult)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleChat handles chat interactions using LLM
|
||||
func HandleChat(ragClient *RAGClient) common.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
arguments := request.Params.Arguments
|
||||
query, ok := arguments["query"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid query argument")
|
||||
}
|
||||
// Generate response using RAGClient's LLM
|
||||
reply, err := ragClient.Chat(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("chat failed, err: %w", err)
|
||||
}
|
||||
|
||||
return buildCallToolResult(reply)
|
||||
}
|
||||
}
|
||||
|
||||
// buildCallToolResult builds the call tool result
|
||||
func buildCallToolResult(results any) (*mcp.CallToolResult, error) {
|
||||
jsonData, err := json.Marshal(results)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal results: %w", err)
|
||||
}
|
||||
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{
|
||||
Type: "text",
|
||||
Text: string(jsonData),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Schema functions
//
// Each Get*Schema function below returns the static JSON Schema describing
// the arguments of the corresponding MCP tool. The schemas are plain JSON
// object literals; required fields are listed in "required".

// GetCreateChunkFromTextSchema returns the schema for create chunk from text tool.
func GetCreateChunkFromTextSchema() json.RawMessage {
	return json.RawMessage(`{
		"type": "object",
		"properties": {
			"text": {
				"type": "string",
				"description": "The text content to create chunks from"
			},
			"title": {
				"type": "string",
				"description": "The title of text content"
			}
		},
		"required": ["text", "title"]
	}`)
}

// GetListKnowledgeSchema returns the schema for list knowledge tool (no arguments).
func GetListKnowledgeSchema() json.RawMessage {
	return json.RawMessage(`{
		"type": "object",
		"properties": {}
	}`)
}

// GetGetKnowledgeSchema returns the schema for get knowledge tool.
func GetGetKnowledgeSchema() json.RawMessage {
	return json.RawMessage(`{
		"type": "object",
		"properties": {
			"id": {
				"type": "string",
				"description": "The knowledge ID"
			}
		},
		"required": ["id"]
	}`)
}

// GetDeleteKnowledgeSchema returns the schema for delete knowledge tool.
func GetDeleteKnowledgeSchema() json.RawMessage {
	return json.RawMessage(`{
		"type": "object",
		"properties": {
			"id": {
				"type": "string",
				"description": "The knowledge ID to delete"
			}
		},
		"required": ["id"]
	}`)
}

// GetListChunksSchema returns the schema for list chunks tool (no arguments).
func GetListChunksSchema() json.RawMessage {
	return json.RawMessage(`{
		"type": "object",
		"properties": {}
	}`)
}

// GetDeleteChunkSchema returns the schema for delete chunk tool.
func GetDeleteChunkSchema() json.RawMessage {
	return json.RawMessage(`{
		"type": "object",
		"properties": {
			"id": {
				"type": "string",
				"description": "The chunk ID to delete"
			}
		},
		"required": ["id"]
	}`)
}

// GetCreateSessionSchema returns the schema for create session tool (no arguments).
func GetCreateSessionSchema() json.RawMessage {
	return json.RawMessage(`{
		"type": "object",
		"properties": {}
	}`)
}

// GetGetSessionSchema returns the schema for get session tool.
func GetGetSessionSchema() json.RawMessage {
	return json.RawMessage(`{
		"type": "object",
		"properties": {
			"session_id": {
				"type": "string",
				"description": "The session ID"
			}
		},
		"required": ["session_id"]
	}`)
}

// GetListSessionsSchema returns the schema for list sessions tool (no arguments).
func GetListSessionsSchema() json.RawMessage {
	return json.RawMessage(`{
		"type": "object",
		"properties": {}
	}`)
}

// GetDeleteSessionSchema returns the schema for delete session tool.
func GetDeleteSessionSchema() json.RawMessage {
	return json.RawMessage(`{
		"type": "object",
		"properties": {
			"session_id": {
				"type": "string",
				"description": "The session ID to delete"
			}
		},
		"required": ["session_id"]
	}`)
}

// GetSearchSchema returns the schema for search tool.
func GetSearchSchema() json.RawMessage {
	return json.RawMessage(`{
		"type": "object",
		"properties": {
			"query": {
				"type": "string",
				"description": "The search query"
			},
			"topk": {
				"type": "integer",
				"description": "The number of top results to return (optional, default 10)"
			},
			"threshold": {
				"type": "number",
				"description": "The relevance score threshold for filtering results (optional, default 0.5)"
			}
		},
		"required": ["query"]
	}`)
}

// GetChatSchema returns the schema for chat tool.
func GetChatSchema() json.RawMessage {
	return json.RawMessage(`{
		"type": "object",
		"properties": {
			"query": {
				"type": "string",
				"description": "User query"
			}
		},
		"required": ["query"]
	}`)
}
|
||||
495
plugins/golang-filter/mcp-server/servers/rag/vectordb/milvus.go
Normal file
495
plugins/golang-filter/mcp-server/servers/rag/vectordb/milvus.go
Normal file
@@ -0,0 +1,495 @@
|
||||
package vectordb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/schema"
|
||||
"github.com/milvus-io/milvus-sdk-go/v2/client"
|
||||
"github.com/milvus-io/milvus-sdk-go/v2/entity"
|
||||
)
|
||||
|
||||
const (
	// MILVUS_DUMMY_DIM is a placeholder vector dimension; its use is not
	// visible in this file — TODO confirm against callers.
	MILVUS_DUMMY_DIM = 8
	// MILVUS_PROVIDER_TYPE is the provider name expected in
	// config.VectorDBConfig.Provider for this backend.
	MILVUS_PROVIDER_TYPE = "milvus"
)
|
||||
|
||||
// milvusProviderInitializer initializes the Milvus vector store provider:
// it fills config defaults, validates them, and constructs the provider.
type milvusProviderInitializer struct{}
|
||||
// InitConfig initializes the configuration with default values if not set
|
||||
func (m *milvusProviderInitializer) InitConfig(cfg *config.VectorDBConfig) error {
|
||||
if cfg.Provider != MILVUS_PROVIDER_TYPE {
|
||||
return fmt.Errorf("provider type mismatch: expected %s, got %s", MILVUS_PROVIDER_TYPE, cfg.Provider)
|
||||
}
|
||||
|
||||
// Set default values
|
||||
if cfg.Host == "" {
|
||||
cfg.Host = "localhost"
|
||||
}
|
||||
if cfg.Port == 0 {
|
||||
cfg.Port = 19530
|
||||
}
|
||||
if cfg.Database == "" {
|
||||
cfg.Database = "default"
|
||||
}
|
||||
|
||||
if cfg.Collection == "" {
|
||||
cfg.Collection = schema.DEFAULT_DOCUMENT_COLLECTION
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateConfig validates the configuration parameters
|
||||
func (m *milvusProviderInitializer) ValidateConfig(cfg *config.VectorDBConfig) error {
|
||||
if cfg.Host == "" {
|
||||
return fmt.Errorf("milvus host is required")
|
||||
}
|
||||
if cfg.Port <= 0 {
|
||||
return fmt.Errorf("milvus port must be positive")
|
||||
}
|
||||
|
||||
if cfg.Database == "" {
|
||||
return fmt.Errorf("milvus database is required")
|
||||
}
|
||||
|
||||
if cfg.Collection == "" {
|
||||
return fmt.Errorf("milvus document collection is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateProvider creates a new Milvus vector store provider instance
|
||||
func (m *milvusProviderInitializer) CreateProvider(cfg *config.VectorDBConfig, dim int) (VectorStoreProvider, error) {
|
||||
if err := m.InitConfig(cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := m.ValidateConfig(cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
provider, err := NewMilvusProvider(cfg, dim)
|
||||
return provider, err
|
||||
}
|
||||
|
||||
// MilvusProvider implements the vector store provider interface for Milvus.
type MilvusProvider struct {
	// client is the connected Milvus SDK client.
	client client.Client
	// config holds the connection and collection settings.
	config *config.VectorDBConfig
	// Collection is the name of the document collection used by this provider.
	Collection string
}
|
||||
|
||||
// NewMilvusProvider creates a new instance of MilvusProvider
|
||||
func NewMilvusProvider(cfg *config.VectorDBConfig, dim int) (VectorStoreProvider, error) {
|
||||
// Create Milvus client
|
||||
connectParam := client.Config{
|
||||
Address: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
||||
}
|
||||
|
||||
connectParam.DBName = cfg.Database
|
||||
// Add authentication if credentials are provided
|
||||
if cfg.Username != "" && cfg.Password != "" {
|
||||
connectParam.Username = cfg.Username
|
||||
connectParam.Password = cfg.Password
|
||||
}
|
||||
|
||||
milvusClient, err := client.NewClient(context.Background(), connectParam)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create milvus client: %w", err)
|
||||
}
|
||||
|
||||
provider := &MilvusProvider{
|
||||
client: milvusClient,
|
||||
config: cfg,
|
||||
Collection: cfg.Collection,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if err := provider.CreateCollection(ctx, dim); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// CreateCollection creates a new collection with the specified dimension
|
||||
func (m *MilvusProvider) CreateCollection(ctx context.Context, dim int) error {
|
||||
// Check if collection exists
|
||||
document_exists, err := m.client.HasCollection(ctx, m.Collection)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check %s collection existence: %w", m.Collection, err)
|
||||
}
|
||||
|
||||
if !document_exists {
|
||||
fmt.Printf("create collection %s\n", m.Collection)
|
||||
// Create schema
|
||||
schema := entity.NewSchema().
|
||||
WithName(m.Collection).
|
||||
WithDescription("Knowledge document collection").
|
||||
WithAutoID(false).
|
||||
WithDynamicFieldEnabled(false)
|
||||
|
||||
// Add fields based on schema.Document structure
|
||||
// Primary key field - ID
|
||||
pkField := entity.NewField().
|
||||
WithName("id").
|
||||
WithDataType(entity.FieldTypeVarChar).
|
||||
WithMaxLength(256).
|
||||
WithIsPrimaryKey(true).
|
||||
WithIsAutoID(false)
|
||||
schema.WithField(pkField)
|
||||
|
||||
// Content field
|
||||
contentField := entity.NewField().
|
||||
WithName("content").
|
||||
WithDataType(entity.FieldTypeVarChar).
|
||||
WithMaxLength(8192)
|
||||
schema.WithField(contentField)
|
||||
|
||||
// Vector field
|
||||
vectorField := entity.NewField().
|
||||
WithName("vector").
|
||||
WithDataType(entity.FieldTypeFloatVector).
|
||||
WithDim(int64(dim))
|
||||
schema.WithField(vectorField)
|
||||
|
||||
// Metadata field
|
||||
metadataField := entity.NewField().
|
||||
WithName("metadata").
|
||||
WithDataType(entity.FieldTypeJSON)
|
||||
schema.WithField(metadataField)
|
||||
|
||||
// CreatedAt field (stored as Unix timestamp)
|
||||
createdAtField := entity.NewField().
|
||||
WithName("created_at").
|
||||
WithDataType(entity.FieldTypeInt64)
|
||||
schema.WithField(createdAtField)
|
||||
|
||||
// Create collection
|
||||
err = m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create collection: %w", err)
|
||||
}
|
||||
|
||||
// Create vector index
|
||||
vectorIndex, err := entity.NewIndexHNSW(entity.IP, 8, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create vector index: %w", err)
|
||||
}
|
||||
|
||||
err = m.client.CreateIndex(ctx, m.Collection, "vector", vectorIndex, false, client.WithIndexName("vector_index"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create vector index: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Load collection
|
||||
err = m.client.LoadCollection(ctx, m.Collection, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load document collection: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DropCollection removes the collection from the database
|
||||
func (m *MilvusProvider) DropCollection(ctx context.Context) error {
|
||||
// Check if collection exists
|
||||
exists, err := m.client.HasCollection(ctx, m.Collection)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check %s collection existence: %w", m.Collection, err)
|
||||
}
|
||||
if !exists {
|
||||
return fmt.Errorf("collection %s does not exist", m.Collection)
|
||||
}
|
||||
// Drop collection
|
||||
err = m.client.DropCollection(ctx, m.Collection)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to drop collection: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddDoc adds documents to the vector database
|
||||
func (m *MilvusProvider) AddDoc(ctx context.Context, docs []schema.Document) error {
|
||||
if len(docs) == 0 {
|
||||
return nil
|
||||
}
|
||||
// Prepare data
|
||||
ids := make([]string, len(docs))
|
||||
contents := make([]string, len(docs))
|
||||
vectors := make([][]float32, len(docs))
|
||||
metadatas := make([][]byte, len(docs))
|
||||
createdAts := make([]int64, len(docs))
|
||||
|
||||
for i, doc := range docs {
|
||||
ids[i] = doc.ID
|
||||
contents[i] = doc.Content
|
||||
|
||||
// Convert vector type
|
||||
vectorFloat32 := make([]float32, len(doc.Vector))
|
||||
for j, v := range doc.Vector {
|
||||
vectorFloat32[j] = float32(v)
|
||||
}
|
||||
vectors[i] = vectorFloat32
|
||||
|
||||
// Serialize metadata
|
||||
metadataBytes, err := json.Marshal(doc.Metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal metadata for doc %s: %w", doc.ID, err)
|
||||
}
|
||||
metadatas[i] = metadataBytes
|
||||
|
||||
createdAts[i] = doc.CreatedAt.UnixMilli()
|
||||
}
|
||||
|
||||
// Build insert data
|
||||
columns := []entity.Column{
|
||||
entity.NewColumnVarChar("id", ids),
|
||||
entity.NewColumnVarChar("content", contents),
|
||||
entity.NewColumnFloatVector("vector", len(vectors[0]), vectors),
|
||||
entity.NewColumnJSONBytes("metadata", metadatas),
|
||||
entity.NewColumnInt64("created_at", createdAts),
|
||||
}
|
||||
|
||||
// Insert data
|
||||
_, err := m.client.Insert(ctx, m.Collection, "", columns...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert documents: %w", err)
|
||||
}
|
||||
|
||||
// Flush data
|
||||
err = m.client.Flush(ctx, m.Collection, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to flush collection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteDoc deletes a document by its ID
|
||||
func (m *MilvusProvider) DeleteDoc(ctx context.Context, id string) error {
|
||||
// Build delete expression
|
||||
expr := fmt.Sprintf(`id == "%s"`, id)
|
||||
// Delete data
|
||||
err := m.client.Delete(ctx, m.Collection, "", expr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete documents for id %s: %w", id, err)
|
||||
}
|
||||
|
||||
// Flush data
|
||||
err = m.client.Flush(ctx, m.Collection, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to flush collection after delete: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateDoc updates documents by first deleting existing ones and then adding new ones
|
||||
func (m *MilvusProvider) UpdateDoc(ctx context.Context, docs []schema.Document) error {
|
||||
// Delete existing documents
|
||||
ids := make([]string, len(docs))
|
||||
for i, doc := range docs {
|
||||
ids[i] = doc.ID
|
||||
}
|
||||
if err := m.DeleteDocs(ctx, ids); err != nil {
|
||||
return fmt.Errorf("failed to delete existing documents: %w", err)
|
||||
}
|
||||
// Add new documents
|
||||
if err := m.AddDoc(ctx, docs); err != nil {
|
||||
return fmt.Errorf("failed to add new documents: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SearchDocs performs similarity search for documents
|
||||
func (m *MilvusProvider) SearchDocs(ctx context.Context, vector []float32, options *schema.SearchOptions) ([]schema.SearchResult, error) {
|
||||
if options == nil {
|
||||
options = &schema.SearchOptions{TopK: 10}
|
||||
}
|
||||
// Build search parameters
|
||||
sp, _ := entity.NewIndexHNSWSearchParam(16)
|
||||
// Build filter expression
|
||||
expr := ""
|
||||
searchResults, err := m.client.Search(
|
||||
ctx,
|
||||
m.Collection,
|
||||
[]string{}, // partition names
|
||||
expr, // filter expression
|
||||
[]string{"id", "content", "metadata", "created_at"}, // output fields
|
||||
[]entity.Vector{entity.FloatVector(vector)},
|
||||
"vector", // anns_field
|
||||
entity.IP, // metric_type
|
||||
options.TopK,
|
||||
sp,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search documents: %w", err)
|
||||
}
|
||||
|
||||
// Parse results
|
||||
var results []schema.SearchResult
|
||||
for _, result := range searchResults {
|
||||
for i := 0; i < result.ResultCount; i++ {
|
||||
id, _ := result.IDs.Get(i)
|
||||
score := result.Scores[i]
|
||||
// Get field data
|
||||
var content string
|
||||
var metadata map[string]interface{}
|
||||
|
||||
for _, field := range result.Fields {
|
||||
switch field.Name() {
|
||||
case "content":
|
||||
if contentCol, ok := field.(*entity.ColumnVarChar); ok {
|
||||
if contentVal, err := contentCol.Get(i); err == nil {
|
||||
if contentStr, ok := contentVal.(string); ok {
|
||||
content = contentStr
|
||||
}
|
||||
}
|
||||
}
|
||||
case "metadata":
|
||||
if metaCol, ok := field.(*entity.ColumnJSONBytes); ok {
|
||||
if metaVal, err := metaCol.Get(i); err == nil {
|
||||
if metaBytes, ok := metaVal.([]byte); ok {
|
||||
if err := json.Unmarshal(metaBytes, &metadata); err != nil {
|
||||
metadata = make(map[string]interface{})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
searchResult := schema.SearchResult{
|
||||
Document: schema.Document{
|
||||
ID: fmt.Sprintf("%s", id),
|
||||
Content: content,
|
||||
Metadata: metadata,
|
||||
},
|
||||
Score: float64(score),
|
||||
}
|
||||
results = append(results, searchResult)
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// DeleteDocs deletes multiple documents by their IDs
|
||||
func (m *MilvusProvider) DeleteDocs(ctx context.Context, ids []string) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build delete expression
|
||||
// Milvus expects string values to be quoted within the expression, otherwise the parser will
|
||||
// treat the hyphen inside UUID as a minus operator and raise a parse error.
|
||||
quotedIDs := make([]string, len(ids))
|
||||
for i, id := range ids {
|
||||
quotedIDs[i] = fmt.Sprintf("\"%s\"", id)
|
||||
}
|
||||
expr := fmt.Sprintf("id in [%s]", strings.Join(quotedIDs, ","))
|
||||
|
||||
// Delete data
|
||||
err := m.client.Delete(ctx, m.Collection, "", expr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete documents: %w", err)
|
||||
}
|
||||
// Flush data
|
||||
err = m.client.Flush(ctx, m.Collection, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to flush collection after delete: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListDocs retrieves all documents with optional limit
|
||||
func (m *MilvusProvider) ListDocs(ctx context.Context, limit int) ([]schema.Document, error) {
|
||||
// Build query expression
|
||||
expr := ""
|
||||
// Query all relevant documents
|
||||
queryResult, err := m.client.Query(
|
||||
ctx,
|
||||
m.Collection,
|
||||
[]string{}, // partitions
|
||||
expr, // filter condition
|
||||
[]string{"id", "content", "metadata", "created_at"},
|
||||
client.WithOffset(0), client.WithLimit(int64(limit)),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query documents: %w", err)
|
||||
}
|
||||
|
||||
if len(queryResult) == 0 {
|
||||
return []schema.Document{}, nil
|
||||
}
|
||||
|
||||
rowCount := queryResult[0].Len()
|
||||
documents := make([]schema.Document, 0, rowCount)
|
||||
|
||||
// Parse query results
|
||||
for i := 0; i < rowCount; i++ {
|
||||
var (
|
||||
id string
|
||||
content string
|
||||
metadata map[string]interface{}
|
||||
createdAt int64
|
||||
)
|
||||
|
||||
for _, col := range queryResult {
|
||||
switch col.Name() {
|
||||
case "id":
|
||||
if v, err := col.(*entity.ColumnVarChar).Get(i); err == nil {
|
||||
id = v.(string)
|
||||
}
|
||||
case "content":
|
||||
if v, err := col.(*entity.ColumnVarChar).Get(i); err == nil {
|
||||
content = v.(string)
|
||||
}
|
||||
case "metadata":
|
||||
if v, err := col.(*entity.ColumnJSONBytes).Get(i); err == nil {
|
||||
if bytes, ok := v.([]byte); ok {
|
||||
_ = json.Unmarshal(bytes, &metadata)
|
||||
}
|
||||
}
|
||||
case "created_at":
|
||||
if v, err := col.(*entity.ColumnInt64).Get(i); err == nil {
|
||||
createdAt = v.(int64)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
doc := schema.Document{
|
||||
ID: id,
|
||||
Content: content,
|
||||
Metadata: metadata,
|
||||
CreatedAt: time.UnixMilli(createdAt),
|
||||
}
|
||||
documents = append(documents, doc)
|
||||
}
|
||||
return documents, nil
|
||||
}
|
||||
|
||||
// GetProviderType returns the provider type identifier for this backend.
func (m *MilvusProvider) GetProviderType() string {
	// Constant identifier; no per-instance state is involved.
	return MILVUS_PROVIDER_TYPE
}
|
||||
|
||||
// Close closes the connection to the Milvus server
|
||||
func (m *MilvusProvider) Close() error {
|
||||
if m.client != nil {
|
||||
return m.client.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// joinStrings concatenates elems, inserting sep between consecutive
// elements; it is a thin wrapper over the standard-library join.
func joinStrings(elems []string, sep string) string {
	joined := strings.Join(elems, sep)
	return joined
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package vectordb
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
||||
)
|
||||
|
||||
func TestNewMilvusProvider(t *testing.T) {
|
||||
_, err := getMilvusProvider()
|
||||
if err != nil {
|
||||
t.Fatalf("expected error when connecting to unavailable Milvus server, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func getMilvusProvider() (VectorStoreProvider, error) {
|
||||
cfg := &config.VectorDBConfig{
|
||||
Provider: PROVIDER_TYPE_MILVUS,
|
||||
Host: "127.0.0.1",
|
||||
Port: 19530, // unlikely to be used
|
||||
Database: "default",
|
||||
Collection: "knowledge_test",
|
||||
}
|
||||
|
||||
provider, err := NewMilvusProvider(cfg, 128)
|
||||
return provider, err
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
package vectordb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/schema"
|
||||
)
|
||||
|
||||
// Provider types constants
// Identifiers for the vector database backends the factory can be asked
// for. Only PROVIDER_TYPE_MILVUS currently has an initializer registered
// in vectorDBProviderInitializers; the others are reserved for future use.
const (
	PROVIDER_TYPE_CHROMA = "chroma"
	PROVIDER_TYPE_PINECONE = "pinecone"
	PROVIDER_TYPE_WEAVIATE = "weaviate"
	PROVIDER_TYPE_QDRANT = "qdrant"
	PROVIDER_TYPE_MILVUS = "milvus"
	PROVIDER_TYPE_FAISS = "faiss"
	PROVIDER_TYPE_ELASTICSEARCH = "elasticsearch"
)
|
||||
|
||||
// VectorStoreProvider defines the operations a vector store backend must
// support: collection lifecycle, document CRUD, and similarity search.
type VectorStoreProvider interface {
	// CreateCollection creates the backing collection with the given vector dimension.
	CreateCollection(ctx context.Context, dim int) error

	// DropCollection drops the backing collection.
	DropCollection(ctx context.Context) error

	// AddDoc adds documents to the vector store.
	AddDoc(ctx context.Context, docs []schema.Document) error

	// DeleteDoc deletes a single document by its ID.
	DeleteDoc(ctx context.Context, id string) error

	// UpdateDoc updates documents in the vector store.
	UpdateDoc(ctx context.Context, docs []schema.Document) error

	// SearchDocs searches for documents similar to the given vector.
	SearchDocs(ctx context.Context, vector []float32, options *schema.SearchOptions) ([]schema.SearchResult, error)

	// DeleteDocs deletes multiple documents by their IDs.
	DeleteDocs(ctx context.Context, ids []string) error

	// ListDocs lists up to limit documents from the vector store.
	ListDocs(ctx context.Context, limit int) ([]schema.Document, error)

	// GetProviderType returns the type identifier of the provider.
	GetProviderType() string
}
|
||||
|
||||
// VectorDBProviderInitializer defines the interface for vector database
// provider initializers; one implementation is registered per backend in
// vectorDBProviderInitializers.
type VectorDBProviderInitializer interface {
	// CreateProvider creates a new vector database provider instance for
	// the given configuration and embedding dimension.
	CreateProvider(cfg *config.VectorDBConfig, dim int) (VectorStoreProvider, error)
}
|
||||
|
||||
var (
	// vectorDBProviderInitializers maps a provider type identifier to the
	// initializer that constructs it. Milvus is the only backend wired up
	// so far; NewVectorDBProvider rejects any other identifier.
	vectorDBProviderInitializers = map[string]VectorDBProviderInitializer{
		PROVIDER_TYPE_MILVUS: &milvusProviderInitializer{},
	}
)
|
||||
|
||||
// CreateVectorDBProvider creates a vector database provider instance
|
||||
func NewVectorDBProvider(cfg *config.VectorDBConfig, dim int) (VectorStoreProvider, error) {
|
||||
initializer, exists := vectorDBProviderInitializers[cfg.Provider]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("unknown vector database provider: %s", cfg.Provider)
|
||||
}
|
||||
// Create provider
|
||||
return initializer.CreateProvider(cfg, dim)
|
||||
}
|
||||
Reference in New Issue
Block a user