add vectordb mapping (#2968)

This commit is contained in:
Jun
2025-10-06 15:08:13 +08:00
committed by GitHub
parent 45a11734bd
commit aebe354055
14 changed files with 1188 additions and 564 deletions

View File

@@ -84,10 +84,11 @@ Higress RAG MCP Server 提供以下工具,根据配置不同,可用工具也
| llm.max_tokens | integer | 可选 | 2048 | 最大令牌数 |
| llm.temperature | float | 可选 | 0.5 | 温度参数 |
| **embedding** | object | 必填 | - | 嵌入配置(所有工具必需) |
| embedding.provider | string | 必填 | dashscope | 嵌入提供商openai或dashscope |
| embedding.provider | string | 必填 | openai | 嵌入提供商:支持openai协议的任意供应商 |
| embedding.api_key | string | 必填 | - | 嵌入API密钥 |
| embedding.base_url | string | 可选 | | 嵌入API基础URL |
| embedding.model | string | 必填 | text-embedding-v4 | 嵌入模型名称 |
| embedding.model | string | 必填 | text-embedding-ada-002 | 嵌入模型名称 |
| embedding.dimensions | integer | 可选 | 1536 | 嵌入维度 |
| **vectordb** | object | 必填 | - | 向量数据库配置(所有工具必需) |
| vectordb.provider | string | 必填 | milvus | 向量数据库提供商 |
| vectordb.host | string | 必填 | localhost | 数据库主机地址 |
@@ -96,6 +97,17 @@ Higress RAG MCP Server 提供以下工具,根据配置不同,可用工具也
| vectordb.collection | string | 必填 | test_collection | 集合名称 |
| vectordb.username | string | 可选 | - | 数据库用户名 |
| vectordb.password | string | 可选 | - | 数据库密码 |
| **vectordb.mapping** | object | 可选 | - | 字段映射配置 |
| vectordb.mapping.fields | array | 可选 | - | 字段映射列表 |
| vectordb.mapping.fields[].standard_name | string | 必填 | - | 标准字段名称(如 id, content, vector 等) |
| vectordb.mapping.fields[].raw_name | string | 必填 | - | 原始字段名称(数据库中的实际字段名) |
| vectordb.mapping.fields[].properties | object | 可选 | - | 字段属性(如 auto_id, max_length 等) |
| vectordb.mapping.index | object | 可选 | - | 索引配置 |
| vectordb.mapping.index.index_type | string | 必填 | - | 索引类型(如 FLAT, IVF_FLAT, HNSW 等) |
| vectordb.mapping.index.params | object | 可选 | - | 索引参数(根据索引类型不同而异) |
| vectordb.mapping.search | object | 可选 | - | 搜索配置 |
| vectordb.mapping.search.metric_type | string | 可选 | L2 | 度量类型(如 L2, IP, COSINE 等) |
| vectordb.mapping.search.params | object | 可选 | - | 搜索参数(如 nprobe, ef_search 等) |
### higress-config 配置样例
@@ -143,27 +155,54 @@ data:
temperature: 0.5
max_tokens: 2048
embedding:
provider: dashscope
provider: openai
base_url: https://dashscope.aliyuncs.com/compatible-mode/v1
api_key: sk-xxx
model: text-embedding-v4
dimensions: 1536
vectordb:
provider: milvus
host: <milvus IP>
host: localhost
port: 19530
database: default
collection: test_collection
```
collection: test_rag
mapping:
fields:
- standard_name: id
raw_name: id
properties:
auto_id: false
max_length: 256
- standard_name: content
raw_name: content
properties:
max_length: 8192
- standard_name: vector
raw_name: vector
- standard_name: metadata
raw_name: metadata
- standard_name: created_at
raw_name: created_at
index:
index_type: HNSW
params:
M: 4
efConstruction: 32
search:
metric_type: IP
params:
ef: 32
```
### 支持的提供商
#### Embedding
- **OpenAI**
- **DashScope**
- **OpenAI 兼容**
#### Vector Database
- **Milvus**
#### LLM
- **OpenAI**
- **OpenAI 兼容**
## 如何测试数据集的效果

View File

@@ -1,5 +1,7 @@
package config
import "fmt"
// Config represents the main configuration structure for the MCP server
type Config struct {
RAG RAGConfig `json:"rag" yaml:"rag"`
@@ -34,20 +36,148 @@ type LLMConfig struct {
// EmbeddingConfig defines configuration for embedding models.
type EmbeddingConfig struct {
	// Provider selects the embedding backend.
	// Available options: openai (any OpenAI-compatible endpoint, selected via BaseURL).
	Provider string `json:"provider" yaml:"provider"`
	// APIKey is the credential sent to the embedding service.
	APIKey string `json:"api_key,omitempty" yaml:"api_key,omitempty"`
	// BaseURL optionally overrides the service endpoint.
	BaseURL string `json:"base_url,omitempty" yaml:"base_url,omitempty"`
	// Model is the embedding model name.
	Model string `json:"model,omitempty" yaml:"model,omitempty"`
	// Dimensions is the size of the produced embedding vectors. The yaml key
	// matches the json key ("dimensions") so a config round-tripped through
	// YAML is read back under the same name.
	Dimensions int `json:"dimensions,omitempty" yaml:"dimensions,omitempty"`
}
// VectorDBConfig defines configuration for vector databases
type VectorDBConfig struct {
Provider string `json:"provider" yaml:"provider"` // Available options: milvus, qdrant, chroma
Host string `json:"host,omitempty" yaml:"host,omitempty"`
Port int `json:"port,omitempty" yaml:"port,omitempty"`
Database string `json:"database,omitempty" yaml:"database,omitempty"`
Collection string `json:"collection,omitempty" yaml:"collection,omitempty"`
Username string `json:"username,omitempty" yaml:"username,omitempty"`
Password string `json:"password,omitempty" yaml:"password,omitempty"`
Provider string `json:"provider" yaml:"provider"` // Available options: milvus, qdrant, chroma
Host string `json:"host,omitempty" yaml:"host,omitempty"`
Port int `json:"port,omitempty" yaml:"port,omitempty"`
Database string `json:"database,omitempty" yaml:"database,omitempty"`
Collection string `json:"collection,omitempty" yaml:"collection,omitempty"`
Username string `json:"username,omitempty" yaml:"username,omitempty"`
Password string `json:"password,omitempty" yaml:"password,omitempty"`
Mapping MappingConfig `json:"mapping,omitempty" yaml:"mapping,omitempty"`
}
// MappingConfig defines field mapping configuration for vector databases.
// It groups the per-field name mappings with the index and search settings.
type MappingConfig struct {
// Fields lists the standard-name to raw-name field mappings.
Fields []FieldMapping `json:"fields,omitempty" yaml:"fields,omitempty"`
// Index holds the index type and its parameters.
Index IndexConfig `json:"index,omitempty" yaml:"index,omitempty"`
// Search holds the metric type and search parameters.
Search SearchConfig `json:"search,omitempty" yaml:"search,omitempty"`
}
// // CollectionMapping defines field mapping for collection
// type CollectionMapping struct {
// Fields []FieldMapping `json:"fields,omitempty" yaml:"fields,omitempty"`
// }
// FieldMapping defines mapping for a single field: it ties a standard
// (logical) field name to the raw field name used in the database.
type FieldMapping struct {
	// StandardName is the logical field name, e.g. "id", "content", "vector".
	StandardName string `json:"standard_name" yaml:"standard_name"`
	// RawName is the actual field name in the database.
	RawName string `json:"raw_name" yaml:"raw_name"`
	// Properties holds optional field attributes such as "auto_id" and "max_length".
	Properties map[string]interface{} `json:"properties,omitempty" yaml:"properties,omitempty"`
}

// IsPrimaryKey reports whether this mapping describes the standard "id" field.
func (f FieldMapping) IsPrimaryKey() bool {
	return f.StandardName == "id"
}

// IsAutoID reports whether the "auto_id" property is present and set to true.
// A missing property, a nil Properties map, or a non-bool value all yield false.
func (f FieldMapping) IsAutoID() bool {
	// Reading from a nil map is safe and yields the zero value.
	autoID, _ := f.Properties["auto_id"].(bool)
	return autoID
}

// IsVectorField reports whether this mapping describes the standard "vector" field.
func (f FieldMapping) IsVectorField() bool {
	return f.StandardName == "vector"
}

// MaxLength returns the "max_length" property of the field.
//
// It returns 0 when no Properties map was supplied at all, and falls back to
// 256 when the map exists but "max_length" is missing or is not an integral
// number. Besides int, the value may be int32, int64 or a whole-number
// float64; the latter is what encoding/json produces when decoding into
// map[string]any, so configuration parsed from JSON is honoured instead of
// silently replaced by the fallback.
func (f FieldMapping) MaxLength() int {
	if f.Properties == nil {
		return 0
	}
	const fallback = 256
	switch v := f.Properties["max_length"].(type) {
	case int:
		return v
	case int32:
		return int(v)
	case int64:
		return int(v)
	case float64:
		// Accept only whole numbers that fit comfortably in an int.
		if v >= -(1<<31) && v < 1<<31 && float64(int(v)) == v {
			return int(v)
		}
	}
	return fallback
}
// IndexConfig defines configuration for index parameters.
type IndexConfig struct {
	// IndexType is the index type, e.g. IVF_FLAT, IVF_SQ8, HNSW.
	IndexType string `json:"index_type" yaml:"index_type"`
	// Params holds the index parameters keyed by name.
	Params map[string]interface{} `json:"params" yaml:"params"`
}

// ParamsString returns the string parameter stored under key.
// It returns an error when the key is absent or the value is not a string.
func (i IndexConfig) ParamsString(key string) (string, error) {
	raw, ok := i.Params[key]
	if !ok {
		return "", fmt.Errorf("params %s not found", key)
	}
	s, ok := raw.(string)
	if !ok {
		return "", fmt.Errorf("params %s is not a string (got %T)", key, raw)
	}
	return s, nil
}

// ParamsInt64 returns the integer parameter stored under key.
//
// int, int32 and int64 values are accepted, as are float64 values that hold a
// whole number: encoding/json decodes every JSON number into float64 when the
// target is map[string]any, so parameters such as {"M": 8} arrive as 8.0.
// It returns an error when the key is absent or the value is not integral.
func (i IndexConfig) ParamsInt64(key string) (int64, error) {
	raw, ok := i.Params[key]
	if !ok {
		return 0, fmt.Errorf("params %s not found", key)
	}
	switch v := raw.(type) {
	case int64:
		return v, nil
	case int:
		return int64(v), nil
	case int32:
		return int64(v), nil
	case float64:
		// Range check first: converting an out-of-range float to int64 is
		// implementation-defined. NaN fails every comparison and is rejected.
		if v >= -(1<<63) && v < 1<<63 && float64(int64(v)) == v {
			return int64(v), nil
		}
	}
	return 0, fmt.Errorf("params %s is not an integer (got %T %v)", key, raw, raw)
}

// ParamsFloat64 returns the floating-point parameter stored under key.
// float64 and float32 are accepted, and integer values (int, int32, int64)
// are widened so that a parameter written without a decimal point still works.
// It returns an error when the key is absent or the value is not numeric.
func (i IndexConfig) ParamsFloat64(key string) (float64, error) {
	raw, ok := i.Params[key]
	if !ok {
		return 0, fmt.Errorf("params %s not found", key)
	}
	switch v := raw.(type) {
	case float64:
		return v, nil
	case float32:
		return float64(v), nil
	case int:
		return float64(v), nil
	case int32:
		return float64(v), nil
	case int64:
		return float64(v), nil
	}
	return 0, fmt.Errorf("params %s is not a number (got %T)", key, raw)
}

// ParamsBool returns the boolean parameter stored under key.
// It returns an error when the key is absent or the value is not a bool.
func (i IndexConfig) ParamsBool(key string) (bool, error) {
	raw, ok := i.Params[key]
	if !ok {
		return false, fmt.Errorf("params %s not found", key)
	}
	b, ok := raw.(bool)
	if !ok {
		return false, fmt.Errorf("params %s is not a bool (got %T)", key, raw)
	}
	return b, nil
}
// SearchConfig defines configuration for search parameters.
type SearchConfig struct {
	// MetricType is the distance metric, e.g. L2, IP.
	MetricType string `json:"metric_type,omitempty" yaml:"metric_type,omitempty"`
	// Params holds the search parameters keyed by name.
	Params map[string]interface{} `json:"params" yaml:"params"`
}

// ParamsString returns the string parameter stored under key.
// It returns an error when the key is absent or the value is not a string.
func (s SearchConfig) ParamsString(key string) (string, error) {
	raw, ok := s.Params[key]
	if !ok {
		return "", fmt.Errorf("params %s not found", key)
	}
	str, ok := raw.(string)
	if !ok {
		return "", fmt.Errorf("params %s is not a string (got %T)", key, raw)
	}
	return str, nil
}

// ParamsInt64 returns the integer parameter stored under key.
//
// int, int32 and int64 values are accepted, as are float64 values that hold a
// whole number: encoding/json decodes every JSON number into float64 when the
// target is map[string]any, and Go literals such as {"nprobe": 32} are plain
// int. It returns an error when the key is absent or the value is not integral.
func (s SearchConfig) ParamsInt64(key string) (int64, error) {
	raw, ok := s.Params[key]
	if !ok {
		return 0, fmt.Errorf("params %s not found", key)
	}
	switch v := raw.(type) {
	case int64:
		return v, nil
	case int:
		return int64(v), nil
	case int32:
		return int64(v), nil
	case float64:
		// Range check first: converting an out-of-range float to int64 is
		// implementation-defined. NaN fails every comparison and is rejected.
		if v >= -(1<<63) && v < 1<<63 && float64(int64(v)) == v {
			return int64(v), nil
		}
	}
	return 0, fmt.Errorf("params %s is not an integer (got %T %v)", key, raw, raw)
}

// ParamsFloat64 returns the floating-point parameter stored under key.
// float64 and float32 are accepted, and integer values (int, int32, int64)
// are widened so that a parameter written without a decimal point still works.
// It returns an error when the key is absent or the value is not numeric.
func (s SearchConfig) ParamsFloat64(key string) (float64, error) {
	raw, ok := s.Params[key]
	if !ok {
		return 0, fmt.Errorf("params %s not found", key)
	}
	switch v := raw.(type) {
	case float64:
		return v, nil
	case float32:
		return float64(v), nil
	case int:
		return float64(v), nil
	case int32:
		return float64(v), nil
	case int64:
		return float64(v), nil
	}
	return 0, fmt.Errorf("params %s is not a number (got %T)", key, raw)
}

// ParamsBool returns the boolean parameter stored under key.
// It returns an error when the key is absent or the value is not a bool.
func (s SearchConfig) ParamsBool(key string) (bool, error) {
	raw, ok := s.Params[key]
	if !ok {
		return false, fmt.Errorf("params %s not found", key)
	}
	b, ok := raw.(bool)
	if !ok {
		return false, fmt.Errorf("params %s is not a bool (got %T)", key, raw)
	}
	return b, nil
}

View File

@@ -1,169 +0,0 @@
package embedding
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/common"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
)
const (
DASHSCOPE_DOMAIN = "dashscope.aliyuncs.com"
DASHSCOPE_PORT = 443
DASHSCOPE_DEFAULT_MODEL_NAME = "text-embedding-v4"
DASHSCOPE_ENDPOINT = "/api/v1/services/embeddings/text-embedding/text-embedding"
)
var dashScopeConfig dashScopeProviderConfig
type dashScopeProviderInitializer struct {
}
type dashScopeProviderConfig struct {
apiKey string
model string
}
func (c *dashScopeProviderInitializer) InitConfig(config config.EmbeddingConfig) {
dashScopeConfig.apiKey = config.APIKey
dashScopeConfig.model = config.Model
}
func (c *dashScopeProviderInitializer) ValidateConfig() error {
if dashScopeConfig.apiKey == "" {
return errors.New("[DashScope] apiKey is required")
}
return nil
}
func (c *dashScopeProviderInitializer) CreateProvider(config config.EmbeddingConfig) (Provider, error) {
c.InitConfig(config)
err := c.ValidateConfig()
if err != nil {
return nil, err
}
headers := map[string]string{
"Authorization": "Bearer " + config.APIKey,
"Content-Type": "application/json",
}
httpClient := common.NewHTTPClient(fmt.Sprintf("https://%s", DASHSCOPE_DOMAIN), headers)
return &DashScopeProvider{
config: dashScopeConfig,
client: httpClient,
}, nil
}
func (d *DashScopeProvider) GetProviderType() string {
return PROVIDER_TYPE_DASHSCOPE
}
type Embedding struct {
Embedding []float32 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type Input struct {
Texts []string `json:"texts"`
}
type Params struct {
TextType string `json:"text_type"`
}
type Response struct {
RequestID string `json:"request_id"`
Output Output `json:"output"`
Usage Usage `json:"usage"`
}
type Output struct {
Embeddings []Embedding `json:"embeddings"`
}
type Usage struct {
TotalTokens int `json:"total_tokens"`
}
type EmbeddingRequest struct {
Model string `json:"model"`
Input Input `json:"input"`
Parameters Params `json:"parameters"`
}
type Document struct {
Vector []float64 `json:"vector"`
Fields map[string]string `json:"fields"`
}
type DashScopeProvider struct {
config dashScopeProviderConfig
client *common.HTTPClient
}
func (d *DashScopeProvider) constructRequestData(texts []string) (EmbeddingRequest, error) {
model := d.config.model
if model == "" {
model = DASHSCOPE_DEFAULT_MODEL_NAME
}
if dashScopeConfig.apiKey == "" {
return EmbeddingRequest{}, errors.New("dashScopeKey is empty")
}
data := EmbeddingRequest{
Model: model,
Input: Input{
Texts: texts,
},
Parameters: Params{
TextType: "query",
},
}
return data, nil
}
type Result struct {
ID string `json:"id"`
Vector []float32 `json:"vector,omitempty"`
Fields map[string]interface{} `json:"fields"`
Score float64 `json:"score"`
}
func (d *DashScopeProvider) parseTextEmbedding(responseBody []byte) (*Response, error) {
var resp Response
err := json.Unmarshal(responseBody, &resp)
if err != nil {
return nil, err
}
return &resp, nil
}
func (d *DashScopeProvider) GetEmbedding(
ctx context.Context,
queryString string) ([]float32, error) {
requestData, err := d.constructRequestData([]string{queryString})
if err != nil {
return nil, fmt.Errorf("failed to construct request data: %v", err)
}
responseBody, err := d.client.Post(DASHSCOPE_ENDPOINT, requestData)
if err != nil {
return nil, fmt.Errorf("failed to send request: %v", err)
}
embeddingResp, err := d.parseTextEmbedding(responseBody)
if err != nil {
return nil, fmt.Errorf("failed to parse response: %v", err)
}
if len(embeddingResp.Output.Embeddings) == 0 {
return nil, errors.New("no embedding found in response")
}
return embeddingResp.Output.Embeddings[0].Embedding, nil
}

View File

@@ -2,160 +2,93 @@ package embedding
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/common"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
"github.com/openai/openai-go/v2"
"github.com/openai/openai-go/v2/option"
)
const (
OPENAI_DOMAIN = "api.openai.com"
OPENAI_PORT = 443
OPENAI_DEFAULT_MODEL_NAME = "text-embedding-3-small"
OPENAI_ENDPOINT = "/v1/embeddings"
OPENAI_DEFAULT_MODEL_NAME = "text-embedding-ada-002"
)
type openAIProviderInitializer struct {
}
var openAIConfig openAIProviderConfig
type openAIProviderConfig struct {
baseUrl string
apiKey string
model string
}
func (c *openAIProviderInitializer) InitConfig(config config.EmbeddingConfig) {
openAIConfig.apiKey = config.APIKey
openAIConfig.model = config.Model
openAIConfig.baseUrl = config.BaseURL
}
func (c *openAIProviderInitializer) ValidateConfig() error {
if openAIConfig.apiKey == "" {
return errors.New("[openAI] apiKey is required")
func (c *openAIProviderInitializer) validateConfig(config *config.EmbeddingConfig) error {
if config.APIKey == "" {
return errors.New("[openai embbeding] apiKey is required")
}
if config.Model == "" {
config.Model = OPENAI_DEFAULT_MODEL_NAME
}
if config.Dimensions <= 0 {
config.Dimensions = 1536
}
return nil
}
func (c *openAIProviderInitializer) CreateProvider(config config.EmbeddingConfig) (Provider, error) {
c.InitConfig(config)
err := c.ValidateConfig()
if err != nil {
if err := c.validateConfig(&config); err != nil {
return nil, err
}
// 创建 OpenAI 客户端
var clientOptions []option.RequestOption
clientOptions = append(clientOptions, option.WithAPIKey(config.APIKey))
if openAIConfig.model == "" {
openAIConfig.model = OPENAI_DEFAULT_MODEL_NAME
// 如果设置了自定义 baseURL则使用它
if config.BaseURL != "" {
clientOptions = append(clientOptions, option.WithBaseURL(config.BaseURL))
}
if openAIConfig.baseUrl == "" {
openAIConfig.baseUrl = fmt.Sprintf("https://%s", OPENAI_DOMAIN)
}
headers := map[string]string{
"Authorization": "Bearer " + config.APIKey,
"Content-Type": "application/json",
}
httpClient := common.NewHTTPClient(openAIConfig.baseUrl, headers)
// 创建 OpenAI 客户端
client := openai.NewClient(clientOptions...)
return &OpenAIProvider{
config: openAIConfig,
client: httpClient,
client: &client,
model: config.Model,
dimensions: config.Dimensions,
}, nil
}
func (o *OpenAIProvider) GetProviderType() string {
// OpenAIProvider handles vector embedding generation using OpenAI-compatible APIs
type OpenAIProvider struct {
// client is the OpenAI SDK client used to issue embedding requests.
client *openai.Client
// model is the embedding model name sent with each request.
model string
// dimensions is the embedding size requested from the API.
dimensions int
}
func (e *OpenAIProvider) GetProviderType() string {
return PROVIDER_TYPE_OPENAI
}
type OpenAIResponse struct {
Object string `json:"object"`
Data []OpenAIResult `json:"data"`
Model string `json:"model"`
Error *OpenAIError `json:"error"`
}
type OpenAIResult struct {
Object string `json:"object"`
Embedding []float32 `json:"embedding"`
Index int `json:"index"`
}
type OpenAIError struct {
Message string `json:"prompt_tokens"`
Type string `json:"type"`
Code string `json:"code"`
Param string `json:"param"`
}
type OpenAIEmbeddingRequest struct {
Input string `json:"input"`
Model string `json:"model"`
}
type OpenAIProvider struct {
config openAIProviderConfig
client *common.HTTPClient
}
func (o *OpenAIProvider) constructRequestData(text string) (OpenAIEmbeddingRequest, error) {
if text == "" {
return OpenAIEmbeddingRequest{}, errors.New("queryString text cannot be empty")
// GetEmbedding generates vector embedding for the given text
func (e *OpenAIProvider) GetEmbedding(ctx context.Context, text string) ([]float32, error) {
params := openai.EmbeddingNewParams{
Model: e.model,
Input: openai.EmbeddingNewParamsInputUnion{
OfString: openai.String(text),
},
Dimensions: openai.Int(int64(e.dimensions)),
EncodingFormat: openai.EmbeddingNewParamsEncodingFormatFloat,
}
if openAIConfig.apiKey == "" {
return OpenAIEmbeddingRequest{}, errors.New("openAI apiKey is empty")
}
model := o.config.model
if model == "" {
model = OPENAI_DEFAULT_MODEL_NAME
}
data := OpenAIEmbeddingRequest{
Input: text,
Model: model,
}
return data, nil
}
func (o *OpenAIProvider) parseTextEmbedding(responseBody []byte) (*OpenAIResponse, error) {
var resp OpenAIResponse
err := json.Unmarshal(responseBody, &resp)
embeddingResp, err := e.client.Embeddings.New(ctx, params)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to generate embedding: %w", err)
}
return &resp, nil
}
func (o *OpenAIProvider) GetEmbedding(ctx context.Context, queryString string) ([]float32, error) {
requestData, err := o.constructRequestData(queryString)
if err != nil {
return nil, fmt.Errorf("failed to construct request data: %v", err)
}
responseBody, err := o.client.Post(OPENAI_ENDPOINT, requestData)
if err != nil {
return nil, fmt.Errorf("failed to send request: %v", err)
}
resp, err := o.parseTextEmbedding(responseBody)
if err != nil {
return nil, fmt.Errorf("failed to parse response: %v", err)
}
if resp.Error != nil {
return nil, fmt.Errorf("OpenAI API error: %s - %s", resp.Error.Type, resp.Error.Message)
}
if len(resp.Data) == 0 {
return nil, errors.New("no embedding found in response")
}
return resp.Data[0].Embedding, nil
if len(embeddingResp.Data) == 0 {
return nil, fmt.Errorf("empty embedding response")
}
// Convert []float64 to []float32
embedding := make([]float32, len(embeddingResp.Data[0].Embedding))
for i, v := range embeddingResp.Data[0].Embedding {
embedding[i] = float32(v)
}
return embedding, nil
}

View File

@@ -10,21 +10,21 @@ import (
// Provider type constants for different embedding services
const (
// DashScope embedding service
PROVIDER_TYPE_DASHSCOPE = "dashscope"
PROVIDER_TYPE_DASHSCOPE = "dashscope"
// TextIn embedding service
PROVIDER_TYPE_TEXTIN = "textin"
PROVIDER_TYPE_TEXTIN = "textin"
// Cohere embedding service
PROVIDER_TYPE_COHERE = "cohere"
PROVIDER_TYPE_COHERE = "cohere"
// OpenAI embedding service
PROVIDER_TYPE_OPENAI = "openai"
PROVIDER_TYPE_OPENAI = "openai"
// Ollama embedding service
PROVIDER_TYPE_OLLAMA = "ollama"
PROVIDER_TYPE_OLLAMA = "ollama"
// HuggingFace embedding service
PROVIDER_TYPE_HUGGINGFACE = "huggingface"
// XFYun embedding service
PROVIDER_TYPE_XFYUN = "xfyun"
PROVIDER_TYPE_XFYUN = "xfyun"
// Azure embedding service
PROVIDER_TYPE_AZURE = "azure"
PROVIDER_TYPE_AZURE = "azure"
)
// Factory interface for creating Provider instances
@@ -36,8 +36,7 @@ type providerInitializer interface {
// Maps provider types to their initializers
var (
providerInitializers = map[string]providerInitializer{
PROVIDER_TYPE_DASHSCOPE: &dashScopeProviderInitializer{},
PROVIDER_TYPE_OPENAI: &openAIProviderInitializer{},
PROVIDER_TYPE_OPENAI: &openAIProviderInitializer{},
}
)

View File

@@ -2,133 +2,105 @@ package llm
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/common"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
"github.com/openai/openai-go/v2"
"github.com/openai/openai-go/v2/option"
"github.com/openai/openai-go/v2/packages/param"
)
const (
OPENAI_CHAT_ENDPOINT = "/chat/completions"
OPENAI_DEFAULT_MODEL = "gpt-4o"
)
// openAI specific configuration captured after initialization.
type openAIProviderConfig struct {
apiKey string
baseURL string
type OpenAIProvider struct {
client *openai.Client
model string
maxTokens int
temperature float64
maxTokens int
}
type openAIProviderInitializer struct{}
var openAIConfig openAIProviderConfig
func (i *openAIProviderInitializer) initConfig(c config.LLMConfig) {
openAIConfig.apiKey = c.APIKey
openAIConfig.baseURL = c.BaseURL
openAIConfig.model = c.Model
if openAIConfig.model == "" {
openAIConfig.model = OPENAI_DEFAULT_MODEL
}
if openAIConfig.baseURL == "" {
openAIConfig.baseURL = "https://api.openai.com/v1" // default public endpoint
}
openAIConfig.maxTokens = c.MaxTokens
openAIConfig.temperature = c.Temperature
}
func (i *openAIProviderInitializer) validateConfig() error {
if openAIConfig.apiKey == "" {
func (i *openAIProviderInitializer) validateConfig(cfg *config.LLMConfig) error {
if cfg.APIKey == "" {
return errors.New("[openai llm] apiKey is required")
}
if cfg.Model == "" {
cfg.Model = OPENAI_DEFAULT_MODEL
}
if cfg.Temperature <= 0 || cfg.Temperature > 2 {
cfg.Temperature = 0.5
}
if cfg.MaxTokens <= 0 {
cfg.MaxTokens = 2048
}
return nil
}
func (i *openAIProviderInitializer) CreateProvider(cfg config.LLMConfig) (Provider, error) {
i.initConfig(cfg)
if err := i.validateConfig(); err != nil {
if err := i.validateConfig(&cfg); err != nil {
return nil, err
}
headers := map[string]string{
"Authorization": "Bearer " + openAIConfig.apiKey,
"Content-Type": "application/json",
// Create OpenAI client
var clientOptions []option.RequestOption
clientOptions = append(clientOptions, option.WithAPIKey(cfg.APIKey))
// If a custom baseURL is set, use it
if cfg.BaseURL != "" {
clientOptions = append(clientOptions, option.WithBaseURL(cfg.BaseURL))
}
client := common.NewHTTPClient(openAIConfig.baseURL, headers)
return &OpenAIProvider{client: client, cfg: openAIConfig}, nil
}
type OpenAIProvider struct {
client *common.HTTPClient
cfg openAIProviderConfig
}
// Create OpenAI client
client := openai.NewClient(clientOptions...)
type openAIChatCompletionRequest struct {
Model string `json:"model"`
Messages []openAIChatMessage `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
}
type openAIChatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type openAIChatCompletionResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Choices []openAIChatCompletionResponseChoice `json:"choices"`
Error *openAIError `json:"error,omitempty"`
}
type openAIChatCompletionResponseChoice struct {
Index int `json:"index"`
Message openAIChatMessage `json:"message"`
FinishReason string `json:"finish_reason"`
}
type openAIError struct {
Message string `json:"message"`
Type string `json:"type"`
Code string `json:"code"`
Param string `json:"param"`
return &OpenAIProvider{
client: &client,
model: cfg.Model,
temperature: cfg.Temperature,
maxTokens: cfg.MaxTokens,
}, nil
}
// GenerateCompletion implements Provider interface.
func (o *OpenAIProvider) GenerateCompletion(ctx context.Context, prompt string) (string, error) {
req := openAIChatCompletionRequest{
Model: o.cfg.model,
Messages: []openAIChatMessage{
{Role: "user", Content: prompt},
// Create chat request
params := openai.ChatCompletionNewParams{
Model: o.model,
Messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage(prompt),
},
Temperature: o.cfg.temperature,
MaxTokens: o.cfg.maxTokens,
}
body, err := o.client.Post(OPENAI_CHAT_ENDPOINT, req)
// Set optional parameters
if o.temperature > 0 {
temperature := float64(o.temperature)
params.Temperature = param.Opt[float64]{Value: temperature}
}
if o.maxTokens > 0 {
maxTokens := int64(o.maxTokens)
params.MaxTokens = param.Opt[int64]{Value: maxTokens}
}
// Send request
response, err := o.client.Chat.Completions.New(ctx, params)
if err != nil {
return "", fmt.Errorf("openai llm post error: %w", err)
// Handle error
return "", fmt.Errorf("openai llm error: %w", err)
}
var resp openAIChatCompletionResponse
if err := json.Unmarshal(body, &resp); err != nil {
return "", fmt.Errorf("openai llm unmarshal error: %w", err)
}
if resp.Error != nil {
return "", fmt.Errorf("openai llm api error: %s - %s", resp.Error.Type, resp.Error.Message)
}
if len(resp.Choices) == 0 {
// Check response
if len(response.Choices) == 0 {
return "", errors.New("openai llm: empty choices")
}
return resp.Choices[0].Message.Content, nil
// Return generated content
return response.Choices[0].Message.Content, nil
}
func (o *OpenAIProvider) GetProviderType() string {

View File

@@ -56,18 +56,12 @@ func NewRAGClient(config *config.Config) (*RAGClient, error) {
ragclient.llmProvider = llmProvider
}
demoVector, err := embeddingProvider.GetEmbedding(context.Background(), "initialization")
if err != nil {
return nil, fmt.Errorf("create init embedding failed, err: %w", err)
}
dim := len(demoVector)
dim := ragclient.config.Embedding.Dimensions
provider, err := vectordb.NewVectorDBProvider(&ragclient.config.VectorDB, dim)
if err != nil {
return nil, fmt.Errorf("create vector store provider failed, err: %w", err)
}
ragclient.vectordbProvider = provider
return ragclient, nil
}

View File

@@ -22,15 +22,17 @@ func getRAGClient() (*RAGClient, error) {
LLM: config.LLMConfig{
Provider: "openai",
APIKey: "sk-xxxx",
APIKey: "sk-xxx",
BaseURL: "https://openrouter.ai/api/v1",
Model: "openai/gpt-4o",
},
Embedding: config.EmbeddingConfig{
Provider: "dashscope",
APIKey: "sk-xxxx",
Model: "text-embedding-v4",
Provider: "openai",
BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1",
APIKey: "sk-xxxx",
Model: "text-embedding-v4",
Dimensions: 1536,
},
VectorDB: config.VectorDBConfig{
@@ -38,7 +40,49 @@ func getRAGClient() (*RAGClient, error) {
Host: "localhost",
Port: 19530,
Database: "default",
Collection: "test_collection",
Collection: "test_collection3",
Mapping: config.MappingConfig{
Fields: []config.FieldMapping{
{
StandardName: "id",
RawName: "pk",
Properties: map[string]interface{}{
"max_length": 256,
"auto_id": false,
},
},
{
StandardName: "content",
RawName: "page_content",
Properties: map[string]interface{}{
"max_length": 8192,
},
},
{
StandardName: "vector",
RawName: "page_vector",
Properties: make(map[string]interface{}),
},
{
StandardName: "metadata",
RawName: "metadata",
Properties: make(map[string]interface{}),
},
{
StandardName: "created_at",
RawName: "created_at",
Properties: make(map[string]interface{}),
},
},
Index: config.IndexConfig{
IndexType: "IVF_FLAT",
Params: map[string]interface{}{"nlist": 64},
},
Search: config.SearchConfig{
MetricType: "COSINE",
Params: map[string]interface{}{"nprobe": 32},
},
},
},
}
@@ -48,7 +92,6 @@ func getRAGClient() (*RAGClient, error) {
}
return ragClient, nil
}
func TestNewRAGClient(t *testing.T) {
@@ -104,7 +147,7 @@ func TestRAGClient_DeleteChunk(t *testing.T) {
return
}
chunk_id := "63ee25d7-41b9-4455-8066-075ca5c803b2"
chunk_id := "2a06679c-a8ea-46dc-bf1c-7e7b164a73c8"
err = ragClient.DeleteChunk(chunk_id)
if err != nil {
t.Errorf("DeleteChunk() error = %v", err)

View File

@@ -36,11 +36,11 @@ func init() {
MaxTokens: 2048,
},
Embedding: config.EmbeddingConfig{
Provider: "dashscope",
APIKey: "",
BaseURL: "",
Model: "text-embedding-v4",
Dimension: 1024,
Provider: "openai",
APIKey: "",
BaseURL: "",
Model: "text-embedding-ada-002",
Dimensions: 1536,
},
VectorDB: config.VectorDBConfig{
Provider: "milvus",
@@ -50,14 +50,56 @@ func init() {
Collection: "rag",
Username: "",
Password: "",
Mapping: config.MappingConfig{
Fields: []config.FieldMapping{
{
StandardName: "id",
RawName: "id",
Properties: map[string]interface{}{
"max_length": 256,
"auto_id": false,
},
},
{
StandardName: "content",
RawName: "content",
Properties: map[string]interface{}{
"max_length": 8192,
},
},
{
StandardName: "vector",
RawName: "vector",
Properties: make(map[string]interface{}),
},
{
StandardName: "metadata",
RawName: "metadata",
Properties: make(map[string]interface{}),
},
{
StandardName: "created_at",
RawName: "created_at",
Properties: make(map[string]interface{}),
},
},
Index: config.IndexConfig{
IndexType: "HNSW",
Params: map[string]interface{}{"M": 8, "efConstruction": 64},
},
Search: config.SearchConfig{
MetricType: "IP",
Params: make(map[string]interface{}),
},
},
},
},
})
}
func (c *RAGConfig) ParseConfig(config map[string]any) error {
func (c *RAGConfig) ParseConfig(cfg map[string]any) error {
// Parse RAG configuration
if ragConfig, ok := config["rag"].(map[string]any); ok {
if ragConfig, ok := cfg["rag"].(map[string]any); ok {
if splitter, exists := ragConfig["splitter"].(map[string]any); exists {
if splitterType, exists := splitter["provider"].(string); exists {
c.config.RAG.Splitter.Provider = splitterType
@@ -78,7 +120,7 @@ func (c *RAGConfig) ParseConfig(config map[string]any) error {
}
// Parse Embedding configuration
if embeddingConfig, ok := config["embedding"].(map[string]any); ok {
if embeddingConfig, ok := cfg["embedding"].(map[string]any); ok {
if provider, exists := embeddingConfig["provider"].(string); exists {
c.config.Embedding.Provider = provider
} else {
@@ -94,13 +136,13 @@ func (c *RAGConfig) ParseConfig(config map[string]any) error {
if model, exists := embeddingConfig["model"].(string); exists {
c.config.Embedding.Model = model
}
if dimension, exists := embeddingConfig["dimension"].(float64); exists {
c.config.Embedding.Dimension = int(dimension)
if dimensions, exists := embeddingConfig["dimensions"].(float64); exists {
c.config.Embedding.Dimensions = int(dimensions)
}
}
// Parse llm configuration
if llmConfig, ok := config["llm"].(map[string]any); ok {
if llmConfig, ok := cfg["llm"].(map[string]any); ok {
if provider, exists := llmConfig["provider"].(string); exists {
c.config.LLM.Provider = provider
}
@@ -122,7 +164,7 @@ func (c *RAGConfig) ParseConfig(config map[string]any) error {
}
// Parse VectorDB configuration
if vectordbConfig, ok := config["vectordb"].(map[string]any); ok {
if vectordbConfig, ok := cfg["vectordb"].(map[string]any); ok {
if provider, exists := vectordbConfig["provider"].(string); exists {
c.config.VectorDB.Provider = provider
} else {
@@ -146,8 +188,59 @@ func (c *RAGConfig) ParseConfig(config map[string]any) error {
if password, exists := vectordbConfig["password"].(string); exists {
c.config.VectorDB.Password = password
}
}
// Parse mapping here
if mapping, exists := vectordbConfig["mapping"].(map[string]any); exists {
// Parse field mappings
if fields, ok := mapping["fields"].([]any); ok {
c.config.VectorDB.Mapping.Fields = []config.FieldMapping{}
for _, field := range fields {
if fieldMap, ok := field.(map[string]any); ok {
fieldMapping := config.FieldMapping{
Properties: make(map[string]interface{}),
}
if standardName, ok := fieldMap["standard_name"].(string); ok {
fieldMapping.StandardName = standardName
}
if rawName, ok := fieldMap["raw_name"].(string); ok {
fieldMapping.RawName = rawName
}
// Parse properties
if properties, ok := fieldMap["properties"].(map[string]any); ok {
for key, value := range properties {
fieldMapping.Properties[key] = value
}
}
c.config.VectorDB.Mapping.Fields = append(c.config.VectorDB.Mapping.Fields, fieldMapping)
}
}
}
// Parse index configuration
if index, ok := mapping["index"].(map[string]any); ok {
if indexType, ok := index["index_type"].(string); ok {
c.config.VectorDB.Mapping.Index.IndexType = indexType
}
// Parse index parameters
if params, ok := index["params"].(map[string]any); ok {
c.config.VectorDB.Mapping.Index.Params = params
}
}
// Parse search configuration
if search, ok := mapping["search"].(map[string]any); ok {
if metricType, ok := search["metric_type"].(string); ok {
c.config.VectorDB.Mapping.Search.MetricType = metricType
}
// Parse search parameters
if params, ok := search["params"].(map[string]any); ok {
c.config.VectorDB.Mapping.Search.Params = params
}
}
}
}
return nil
}

View File

@@ -28,11 +28,11 @@ func TestRAGConfig_ParseConfig(t *testing.T) {
MaxTokens: 2048,
},
Embedding: config.EmbeddingConfig{
Provider: "dashscope",
APIKey: "sk-XXX",
BaseURL: "",
Model: "text-embedding-v4",
Dimension: 1024,
Provider: "dashscope",
APIKey: "sk-XXX",
BaseURL: "",
Model: "text-embedding-v4",
Dimensions: 1024,
},
VectorDB: config.VectorDBConfig{
Provider: "milvus",
@@ -42,6 +42,48 @@ func TestRAGConfig_ParseConfig(t *testing.T) {
Collection: "test_rag",
Username: "",
Password: "",
Mapping: config.MappingConfig{
Fields: []config.FieldMapping{
{
StandardName: "id",
RawName: "id",
Properties: map[string]interface{}{
"max_length": 256,
"auto_id": false,
},
},
{
StandardName: "content",
RawName: "content",
Properties: map[string]interface{}{
"max_length": 8192,
},
},
{
StandardName: "vector",
RawName: "vector",
Properties: make(map[string]interface{}),
},
{
StandardName: "metadata",
RawName: "metadata",
Properties: make(map[string]interface{}),
},
{
StandardName: "created_at",
RawName: "created_at",
Properties: make(map[string]interface{}),
},
},
Index: config.IndexConfig{
IndexType: "HNSW",
Params: map[string]interface{}{"M": 4, "efConstruction": 32},
},
Search: config.SearchConfig{
MetricType: "IP",
Params: map[string]interface{}{"ef": 32},
},
},
},
}
// 把 config 输出 yaml 格式

View File

@@ -0,0 +1,182 @@
package vectordb
import (
"errors"
"fmt"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
)
// Sentinel errors for vector database mapping.
// GetRawField and GetField wrap ErrFieldNotFound with %w, so callers can
// detect it with errors.Is.
var (
// ErrFieldNotFound reports that a lookup by standard or raw field name found
// no mapping.
ErrFieldNotFound = errors.New("field not found")
// NOTE(review): the sentinels below are declared here but not returned by
// any code in this file; presumably they are reserved for provider-side
// validation — confirm against callers.
ErrInvalidFieldType = errors.New("invalid field type")
ErrInvalidIndexType = errors.New("invalid index type")
ErrInvalidMetricType = errors.New("invalid metric type")
ErrInvalidSearchParams = errors.New("invalid search parameters")
ErrCollectionNotFound = errors.New("collection not found")
ErrUnsupportedOperation = errors.New("unsupported operation")
)
// VectorDBMapper translates between the standard field names used by the RAG
// code (for example "id", "content", "vector") and the raw field names
// configured for a concrete vector database, and exposes the configured index
// and search settings.
type VectorDBMapper interface {
// ParseMapping loads cfg into the mapper for the given provider and returns
// an error when the mapping is not usable.
ParseMapping(provider string, cfg config.MappingConfig) error
// GetIndexConfig returns the index section of the mapping configuration.
GetIndexConfig() (config.IndexConfig, error)
// GetSearchConfig returns the search section of the mapping configuration.
GetSearchConfig() (config.SearchConfig, error)
// GetRawAllFieldNames returns the raw (database-side) names of all mapped
// fields.
GetRawAllFieldNames() ([]string, error)
// GetIDField returns the mapping of the standard "id" field.
GetIDField() (*config.FieldMapping, error)
// GetVectorField returns the mapping of the standard "vector" field.
GetVectorField() (*config.FieldMapping, error)
// GetRawField looks a mapping up by its standard field name; the result
// carries the raw name to use against the database.
GetRawField(standardFieldName string) (*config.FieldMapping, error)
// GetField looks a mapping up by its raw (database-side) field name.
GetField(rawFieldName string) (*config.FieldMapping, error)
// GetFieldMappings returns every configured field mapping.
GetFieldMappings() ([]config.FieldMapping, error)
}
// DefaultVectorDBMapper is the default VectorDBMapper implementation. It keeps
// the parsed configuration together with two lookup tables, one keyed by
// standard field name and one keyed by raw field name. All three are replaced
// on every ParseMapping call.
type DefaultVectorDBMapper struct {
// mappingConfig is the configuration handed to the last ParseMapping call.
mappingConfig config.MappingConfig
// standardFieldMap indexes field mappings by FieldMapping.StandardName.
standardFieldMap map[string]*config.FieldMapping
// rawFieldMap indexes the same field mappings by FieldMapping.RawName.
rawFieldMap map[string]*config.FieldMapping
}
// NewDefaultVectorDBMapper builds a DefaultVectorDBMapper and immediately
// parses mappingConfig into it. If parsing fails, the error from ParseMapping
// is returned unchanged and no mapper is returned.
func NewDefaultVectorDBMapper(provider string, mappingConfig config.MappingConfig) (*DefaultVectorDBMapper, error) {
	mapper := new(DefaultVectorDBMapper)
	mapper.standardFieldMap = map[string]*config.FieldMapping{}
	mapper.rawFieldMap = map[string]*config.FieldMapping{}

	err := mapper.ParseMapping(provider, mappingConfig)
	if err != nil {
		return nil, err
	}
	return mapper, nil
}
// ParseMapping loads cfg into the mapper, replacing any previously parsed
// state (stored configuration and both lookup tables).
//
// When cfg.Fields is empty, a default field set is used: id, content, vector,
// metadata and created_at, each with a raw name equal to its standard name.
// The resulting mapping must contain the standard fields "id", "content" and
// "vector"; otherwise an error is returned. The provider argument is not
// consulted by this implementation.
func (m *DefaultVectorDBMapper) ParseMapping(provider string, cfg config.MappingConfig) error {
	// Apply the defaults BEFORE storing cfg. cfg is a by-value copy, so filling
	// cfg.Fields after `m.mappingConfig = cfg` would leave the stored
	// configuration without fields and GetFieldMappings would return nothing.
	if len(cfg.Fields) == 0 {
		cfg.Fields = []config.FieldMapping{
			{
				StandardName: "id",
				RawName:      "id",
				Properties: map[string]interface{}{
					"max_length": 256,
					"auto_id":    false,
				},
			},
			{
				StandardName: "content",
				RawName:      "content",
				Properties: map[string]interface{}{
					"max_length": 8192,
				},
			},
			{
				StandardName: "vector",
				RawName:      "vector",
			},
			{
				StandardName: "metadata",
				RawName:      "metadata",
			},
			{
				StandardName: "created_at",
				RawName:      "created_at",
			},
		}
	}
	m.mappingConfig = cfg

	// Rebuild both lookup tables from scratch.
	m.standardFieldMap = make(map[string]*config.FieldMapping, len(cfg.Fields))
	m.rawFieldMap = make(map[string]*config.FieldMapping, len(cfg.Fields))
	for i := range cfg.Fields {
		// Take the address of the slice element rather than of a loop copy so
		// both tables and the stored configuration share one FieldMapping.
		field := &cfg.Fields[i]
		m.standardFieldMap[field.StandardName] = field
		m.rawFieldMap[field.RawName] = field
	}

	// The id, content and vector fields are mandatory.
	for _, fieldName := range []string{"id", "content", "vector"} {
		if _, err := m.GetRawField(fieldName); err != nil {
			return fmt.Errorf("[vector db mapper] required field %s not found or not varchar type", fieldName)
		}
	}
	return nil
}
// GetIndexConfig returns the index section of the configuration stored by the
// last ParseMapping call. The returned error is always nil.
func (m *DefaultVectorDBMapper) GetIndexConfig() (config.IndexConfig, error) {
	indexCfg := m.mappingConfig.Index
	return indexCfg, nil
}
// GetSearchConfig returns the search section of the configuration stored by
// the last ParseMapping call. The returned error is always nil.
func (m *DefaultVectorDBMapper) GetSearchConfig() (config.SearchConfig, error) {
	searchCfg := m.mappingConfig.Search
	return searchCfg, nil
}
// GetRawAllFieldNames returns the raw names of all mapped fields. The names
// come from iterating a map, so their order is unspecified and may differ
// between calls. The returned error is always nil.
func (m *DefaultVectorDBMapper) GetRawAllFieldNames() ([]string, error) {
	names := make([]string, len(m.rawFieldMap))
	next := 0
	for rawName := range m.rawFieldMap {
		names[next] = rawName
		next++
	}
	return names, nil
}
// GetIDField returns the mapping registered under the standard name "id".
// It delegates to GetRawField and shares its error behavior.
func (m *DefaultVectorDBMapper) GetIDField() (*config.FieldMapping, error) {
	const idStandardName = "id"
	return m.GetRawField(idStandardName)
}
// GetVectorField returns the mapping registered under the standard name
// "vector". It delegates to GetRawField and shares its error behavior.
func (m *DefaultVectorDBMapper) GetVectorField() (*config.FieldMapping, error) {
	const vectorStandardName = "vector"
	return m.GetRawField(vectorStandardName)
}
// GetRawField looks up a field mapping by its standard field name. When no
// such mapping exists it returns an error wrapping ErrFieldNotFound.
func (m *DefaultVectorDBMapper) GetRawField(standardFieldName string) (*config.FieldMapping, error) {
	if mapping, found := m.standardFieldMap[standardFieldName]; found {
		return mapping, nil
	}
	return nil, fmt.Errorf("%w: standard field %s not found", ErrFieldNotFound, standardFieldName)
}
// GetField looks up a field mapping by its raw (database-side) field name.
// When no such mapping exists it returns an error wrapping ErrFieldNotFound.
func (m *DefaultVectorDBMapper) GetField(rawFieldName string) (*config.FieldMapping, error) {
	if mapping, found := m.rawFieldMap[rawFieldName]; found {
		return mapping, nil
	}
	return nil, fmt.Errorf("%w: raw field %s not found", ErrFieldNotFound, rawFieldName)
}
// GetFieldMappings returns the field list of the stored mapping
// configuration. The slice is returned as stored, not copied. The returned
// error is always nil.
func (m *DefaultVectorDBMapper) GetFieldMappings() ([]config.FieldMapping, error) {
	fields := m.mappingConfig.Fields
	return fields, nil
}

View File

@@ -80,16 +80,17 @@ func (m *milvusProviderInitializer) CreateProvider(cfg *config.VectorDBConfig, d
type MilvusProvider struct {
client client.Client
config *config.VectorDBConfig
Collection string
collection string
mapper VectorDBMapper
dimensions int
}
// NewMilvusProvider creates a new instance of MilvusProvider
func NewMilvusProvider(cfg *config.VectorDBConfig, dim int) (VectorStoreProvider, error) {
func NewMilvusProvider(cfg *config.VectorDBConfig, dimensions int) (VectorStoreProvider, error) {
// Create Milvus client
connectParam := client.Config{
Address: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
}
connectParam.DBName = cfg.Database
// Add authentication if credentials are provided
if cfg.Username != "" && cfg.Password != "" {
@@ -102,92 +103,301 @@ func NewMilvusProvider(cfg *config.VectorDBConfig, dim int) (VectorStoreProvider
return nil, fmt.Errorf("failed to create milvus client: %w", err)
}
mapper, err := NewDefaultVectorDBMapper(MILVUS_PROVIDER_TYPE, cfg.Mapping)
if err != nil {
return nil, fmt.Errorf("failed to create default vector db mapper: %w", err)
}
provider := &MilvusProvider{
client: milvusClient,
config: cfg,
Collection: cfg.Collection,
collection: cfg.Collection,
mapper: mapper,
dimensions: dimensions,
}
ctx := context.Background()
if err := provider.CreateCollection(ctx, dim); err != nil {
if err := provider.CreateCollection(ctx, dimensions); err != nil {
return nil, err
}
return provider, nil
}
// buildSchema builds the Milvus collection schema from the mapper's field
// mappings.
//
// Each mapping is translated by its standard name:
//   - "id":         VarChar primary key, max length from the mapping; marked
//     auto-ID when the mapping says so
//   - "content":    VarChar, max length from the mapping
//   - "vector":     FloatVector with m.dimensions dimensions
//   - "metadata":   JSON
//   - "created_at": Int64
//
// Mappings with any other standard name are skipped. The field is named after
// the mapping's RawName. An error is returned when the mapper cannot supply
// the ID field or the field mappings.
func (m *MilvusProvider) buildSchema() (*entity.Schema, error) {
	// Do not discard mapper errors: idField is dereferenced below and would
	// panic if the lookup had failed.
	idField, err := m.mapper.GetIDField()
	if err != nil {
		return nil, fmt.Errorf("failed to get id field: %w", err)
	}
	fieldMappings, err := m.mapper.GetFieldMappings()
	if err != nil {
		return nil, fmt.Errorf("failed to get field mappings: %w", err)
	}

	schema := entity.NewSchema().
		WithName(m.collection).
		WithDescription("Knowledge document collection").
		WithAutoID(idField.IsAutoID()).
		WithDynamicFieldEnabled(false)

	for _, field := range fieldMappings {
		var fieldEntity *entity.Field
		switch field.StandardName {
		case "id":
			fieldEntity = entity.NewField().
				WithName(field.RawName).
				WithDataType(entity.FieldTypeVarChar).
				WithMaxLength(int64(field.MaxLength())).
				WithIsPrimaryKey(true)
			if field.IsAutoID() {
				fieldEntity.WithIsAutoID(true)
			}
		case "content":
			fieldEntity = entity.NewField().
				WithName(field.RawName).
				WithDataType(entity.FieldTypeVarChar).
				WithMaxLength(int64(field.MaxLength()))
		case "vector":
			fieldEntity = entity.NewField().
				WithName(field.RawName).
				WithDataType(entity.FieldTypeFloatVector).
				WithDim(int64(m.dimensions))
		case "metadata":
			fieldEntity = entity.NewField().
				WithName(field.RawName).
				WithDataType(entity.FieldTypeJSON)
		case "created_at":
			fieldEntity = entity.NewField().
				WithName(field.RawName).
				WithDataType(entity.FieldTypeInt64)
		default:
			// Unknown standard names have no schema representation.
			continue
		}
		schema.WithField(fieldEntity)
	}
	return schema, nil
}
// GetMetricType converts a metric type name into the corresponding Milvus
// metric type. Matching is case-insensitive. Names that are not recognized,
// including the empty string, fall back to entity.IP.
func (m *MilvusProvider) GetMetricType(metricType string) entity.MetricType {
	known := map[string]entity.MetricType{
		"L2":             entity.L2,
		"IP":             entity.IP,
		"COSINE":         entity.COSINE,
		"HAMMING":        entity.HAMMING,
		"JACCARD":        entity.JACCARD,
		"TANIMOTO":       entity.TANIMOTO,
		"SUBSTRUCTURE":   entity.SUBSTRUCTURE,
		"SUPERSTRUCTURE": entity.SUPERSTRUCTURE,
	}
	if resolved, ok := known[strings.ToUpper(metricType)]; ok {
		return resolved
	}
	return entity.IP
}
// buildVectorIndex creates the Milvus index for the vector field from the
// mapper's index and search configuration.
//
// The index type name is matched case-insensitively and defaults to "HNSW"
// when empty; the metric type is resolved through GetMetricType. Index
// parameters missing from the configuration (or not readable through the
// Params* accessors) fall back to these defaults:
//
//	nlist=128, IVF_PQ m=4, nbits=8, HNSW/IVF_HNSW M=8, efConstruction=64,
//	SCANN with_raw_data=false
//
// It returns an error when the mapper configuration cannot be read, when the
// index type is not supported, or when the Milvus SDK rejects the parameters.
func (m *MilvusProvider) buildVectorIndex() (entity.Index, error) {
	indexConfig, err := m.mapper.GetIndexConfig()
	if err != nil {
		return nil, fmt.Errorf("failed to get index config: %w", err)
	}
	searchConfig, err := m.mapper.GetSearchConfig()
	if err != nil {
		return nil, fmt.Errorf("failed to get search config: %w", err)
	}

	indexType := strings.ToUpper(indexConfig.IndexType)
	if indexType == "" {
		indexType = "HNSW"
	}
	metricType := m.GetMetricType(searchConfig.MetricType)

	// intParam reads an integer index parameter, using fallback when the key
	// cannot be read as an integer.
	intParam := func(key string, fallback int) int {
		if v, perr := indexConfig.ParamsInt64(key); perr == nil {
			return int(v)
		}
		return fallback
	}

	var index entity.Index
	switch indexType {
	case "FLAT":
		// FLAT takes no index parameters.
		index, err = entity.NewIndexFlat(metricType)
	case "BIN_FLAT":
		index, err = entity.NewIndexBinFlat(metricType, intParam("nlist", 128))
	case "IVF_FLAT":
		index, err = entity.NewIndexIvfFlat(metricType, intParam("nlist", 128))
	case "BIN_IVF_FLAT":
		index, err = entity.NewIndexBinIvfFlat(metricType, intParam("nlist", 128))
	case "IVF_SQ8":
		index, err = entity.NewIndexIvfSQ8(metricType, intParam("nlist", 128))
	case "IVF_PQ":
		nlist := intParam("nlist", 128)
		// NOTE(review): "m" is read through ParamsFloat64 while every other
		// integer parameter uses ParamsInt64; kept as-is — confirm whether
		// the difference is intentional.
		pqM := 4
		if v, perr := indexConfig.ParamsFloat64("m"); perr == nil {
			pqM = int(v)
		}
		nbits := intParam("nbits", 8)
		index, err = entity.NewIndexIvfPQ(metricType, nlist, pqM, nbits)
	case "HNSW":
		index, err = entity.NewIndexHNSW(metricType,
			intParam("M", 8), intParam("efConstruction", 64))
	case "IVF_HNSW":
		index, err = entity.NewIndexIvfHNSW(metricType,
			intParam("nlist", 128), intParam("M", 8), intParam("efConstruction", 64))
	case "DISKANN":
		// DISKANN takes no index parameters here.
		index, err = entity.NewIndexDISKANN(metricType)
	case "SCANN":
		nlist := intParam("nlist", 128)
		withRawData := false
		if v, perr := indexConfig.ParamsBool("with_raw_data"); perr == nil {
			withRawData = v
		}
		index, err = entity.NewIndexSCANN(metricType, nlist, withRawData)
	case "AUTOINDEX":
		index, err = entity.NewIndexAUTOINDEX(metricType)
	default:
		return nil, fmt.Errorf("unsupported index type: %s", indexType)
	}

	if err != nil {
		// Return a literal nil so callers never see a non-nil interface that
		// wraps a nil concrete index pointer.
		return nil, fmt.Errorf("failed to create %s index: %w", indexType, err)
	}
	return index, nil
}
// CreateCollection creates a new collection with the specified dimension
func (m *MilvusProvider) CreateCollection(ctx context.Context, dim int) error {
// Check if collection exists
document_exists, err := m.client.HasCollection(ctx, m.Collection)
document_exists, err := m.client.HasCollection(ctx, m.collection)
if err != nil {
return fmt.Errorf("failed to check %s collection existence: %w", m.Collection, err)
return fmt.Errorf("failed to check %s collection existence: %w", m.collection, err)
}
if !document_exists {
fmt.Printf("create collection %s\n", m.Collection)
fmt.Printf("create collection %s\n", m.collection)
// Create schema
schema := entity.NewSchema().
WithName(m.Collection).
WithDescription("Knowledge document collection").
WithAutoID(false).
WithDynamicFieldEnabled(false)
// Add fields based on schema.Document structure
// Primary key field - ID
pkField := entity.NewField().
WithName("id").
WithDataType(entity.FieldTypeVarChar).
WithMaxLength(256).
WithIsPrimaryKey(true).
WithIsAutoID(false)
schema.WithField(pkField)
// Content field
contentField := entity.NewField().
WithName("content").
WithDataType(entity.FieldTypeVarChar).
WithMaxLength(8192)
schema.WithField(contentField)
// Vector field
vectorField := entity.NewField().
WithName("vector").
WithDataType(entity.FieldTypeFloatVector).
WithDim(int64(dim))
schema.WithField(vectorField)
// Metadata field
metadataField := entity.NewField().
WithName("metadata").
WithDataType(entity.FieldTypeJSON)
schema.WithField(metadataField)
// CreatedAt field (stored as Unix timestamp)
createdAtField := entity.NewField().
WithName("created_at").
WithDataType(entity.FieldTypeInt64)
schema.WithField(createdAtField)
schema, err := m.buildSchema()
if err != nil {
return fmt.Errorf("failed to build schema: %w", err)
}
// Create collection
err = m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber)
if err != nil {
return fmt.Errorf("failed to create collection: %w", err)
}
// Create vector index
vectorIndex, err := entity.NewIndexHNSW(entity.IP, 8, 64)
vectorIndex, err := m.buildVectorIndex()
vectorField, _ := m.mapper.GetVectorField()
if err != nil {
return fmt.Errorf("failed to create vector index: %w", err)
}
err = m.client.CreateIndex(ctx, m.Collection, "vector", vectorIndex, false, client.WithIndexName("vector_index"))
err = m.client.CreateIndex(ctx, m.collection, vectorField.RawName, vectorIndex, false, client.WithIndexName("vector_index"))
if err != nil {
return fmt.Errorf("failed to create vector index: %w", err)
}
}
// Load collection
err = m.client.LoadCollection(ctx, m.Collection, false)
err = m.client.LoadCollection(ctx, m.collection, false)
if err != nil {
return fmt.Errorf("failed to load document collection: %w", err)
}
@@ -197,15 +407,15 @@ func (m *MilvusProvider) CreateCollection(ctx context.Context, dim int) error {
// DropCollection removes the collection from the database
func (m *MilvusProvider) DropCollection(ctx context.Context) error {
// Check if collection exists
exists, err := m.client.HasCollection(ctx, m.Collection)
exists, err := m.client.HasCollection(ctx, m.collection)
if err != nil {
return fmt.Errorf("failed to check %s collection existence: %w", m.Collection, err)
return fmt.Errorf("failed to check %s collection existence: %w", m.collection, err)
}
if !exists {
return fmt.Errorf("collection %s does not exist", m.Collection)
return fmt.Errorf("collection %s does not exist", m.collection)
}
// Drop collection
err = m.client.DropCollection(ctx, m.Collection)
err = m.client.DropCollection(ctx, m.collection)
if err != nil {
return fmt.Errorf("failed to drop collection: %w", err)
}
@@ -217,51 +427,71 @@ func (m *MilvusProvider) AddDoc(ctx context.Context, docs []schema.Document) err
if len(docs) == 0 {
return nil
}
// Prepare data
ids := make([]string, len(docs))
contents := make([]string, len(docs))
vectors := make([][]float32, len(docs))
metadatas := make([][]byte, len(docs))
createdAts := make([]int64, len(docs))
for i, doc := range docs {
ids[i] = doc.ID
contents[i] = doc.Content
// Convert vector type
vectorFloat32 := make([]float32, len(doc.Vector))
for j, v := range doc.Vector {
vectorFloat32[j] = float32(v)
}
vectors[i] = vectorFloat32
// Serialize metadata
metadataBytes, err := json.Marshal(doc.Metadata)
if err != nil {
return fmt.Errorf("failed to marshal metadata for doc %s: %w", doc.ID, err)
}
metadatas[i] = metadataBytes
createdAts[i] = doc.CreatedAt.UnixMilli()
// Get field mappings
fieldMappings, err := m.mapper.GetFieldMappings()
if err != nil {
return fmt.Errorf("failed to get field mappings: %w", err)
}
// Prepare data and columns
columns := make([]entity.Column, 0, len(fieldMappings))
// Create corresponding column data for each field
for _, field := range fieldMappings {
// Skip ID field if configured as auto ID
if field.IsPrimaryKey() && field.IsAutoID() {
continue
}
switch field.StandardName {
case "id":
// Handle string type fields
values := make([]string, len(docs))
for i, doc := range docs {
values[i] = doc.ID
}
columns = append(columns, entity.NewColumnVarChar(field.RawName, values))
case "content":
values := make([]string, len(docs))
for i, doc := range docs {
values[i] = doc.Content
}
columns = append(columns, entity.NewColumnVarChar(field.RawName, values))
// Build insert data
columns := []entity.Column{
entity.NewColumnVarChar("id", ids),
entity.NewColumnVarChar("content", contents),
entity.NewColumnFloatVector("vector", len(vectors[0]), vectors),
entity.NewColumnJSONBytes("metadata", metadatas),
entity.NewColumnInt64("created_at", createdAts),
case "vector":
// Handle vector fields
vectors := make([][]float32, len(docs))
for i, doc := range docs {
vectors[i] = doc.Vector
}
columns = append(columns, entity.NewColumnFloatVector(field.RawName, len(vectors[0]), vectors))
case "metadata":
// Handle JSON type fields (like metadata)
values := make([][]byte, len(docs))
for i, doc := range docs {
// Serialize metadata
metadataBytes, err := json.Marshal(doc.Metadata)
if err != nil {
return fmt.Errorf("failed to marshal metadata for doc %s: %w", doc.ID, err)
}
values[i] = metadataBytes
}
columns = append(columns, entity.NewColumnJSONBytes(field.RawName, values))
case "created_at":
// Handle integer type fields
values := make([]int64, len(docs))
for i, doc := range docs {
values[i] = doc.CreatedAt.UnixMilli()
}
columns = append(columns, entity.NewColumnInt64(field.RawName, values))
}
}
// Insert data
_, err := m.client.Insert(ctx, m.Collection, "", columns...)
_, err = m.client.Insert(ctx, m.collection, "", columns...)
if err != nil {
return fmt.Errorf("failed to insert documents: %w", err)
}
// Flush data
err = m.client.Flush(ctx, m.Collection, false)
err = m.client.Flush(ctx, m.collection, false)
if err != nil {
return fmt.Errorf("failed to flush collection: %w", err)
}
@@ -271,16 +501,19 @@ func (m *MilvusProvider) AddDoc(ctx context.Context, docs []schema.Document) err
// DeleteDoc deletes a document by its ID
func (m *MilvusProvider) DeleteDoc(ctx context.Context, id string) error {
// Build delete expression
expr := fmt.Sprintf(`id == "%s"`, id)
// Get ID field
idField, _ := m.mapper.GetIDField()
// Build delete expression using the RawName of ID field
expr := fmt.Sprintf(`%s == "%s"`, idField.RawName, id)
// Delete data
err := m.client.Delete(ctx, m.Collection, "", expr)
err := m.client.Delete(ctx, m.collection, "", expr)
if err != nil {
return fmt.Errorf("failed to delete documents for id %s: %w", id, err)
}
// Flush data
err = m.client.Flush(ctx, m.Collection, false)
err = m.client.Flush(ctx, m.collection, false)
if err != nil {
return fmt.Errorf("failed to flush collection after delete: %w", err)
}
@@ -306,24 +539,127 @@ func (m *MilvusProvider) UpdateDoc(ctx context.Context, docs []schema.Document)
return nil
}
// buildSearchParam creates the Milvus search parameters that match the
// configured index type.
//
// The index type name is matched case-insensitively and defaults to "HNSW"
// when empty. Search parameters that cannot be read from the search
// configuration fall back to: nprobe=16, ef=16 (HNSW) or 64 (IVF_HNSW),
// reorder_k=64, search_list=100, level=8. An index type that is not listed
// below gets HNSW search parameters with ef=16.
//
// It returns an error when the mapper configuration cannot be read; otherwise
// it returns whatever the Milvus SDK constructor returns.
func (m *MilvusProvider) buildSearchParam() (entity.SearchParam, error) {
	indexCfg, err := m.mapper.GetIndexConfig()
	if err != nil {
		return nil, fmt.Errorf("failed to get index config: %w", err)
	}
	searchCfg, err := m.mapper.GetSearchConfig()
	if err != nil {
		return nil, fmt.Errorf("failed to get search config: %w", err)
	}

	// fromFloat reads a search parameter through ParamsFloat64 and truncates
	// it to int; fallback is used when the key cannot be read.
	fromFloat := func(key string, fallback int) int {
		if v, perr := searchCfg.ParamsFloat64(key); perr == nil {
			return int(v)
		}
		return fallback
	}
	// fromInt reads a search parameter through ParamsInt64; fallback is used
	// when the key cannot be read.
	fromInt := func(key string, fallback int) int {
		if v, perr := searchCfg.ParamsInt64(key); perr == nil {
			return int(v)
		}
		return fallback
	}

	indexType := strings.ToUpper(indexCfg.IndexType)
	if indexType == "" {
		// No index type configured: assume HNSW.
		indexType = "HNSW"
	}

	switch indexType {
	case "FLAT":
		// FLAT needs no search parameters.
		return entity.NewIndexFlatSearchParam()
	case "BIN_FLAT", "IVF_FLAT", "BIN_IVF_FLAT", "IVF_SQ8":
		return entity.NewIndexIvfFlatSearchParam(fromFloat("nprobe", 16))
	case "IVF_PQ":
		return entity.NewIndexIvfPQSearchParam(fromFloat("nprobe", 16))
	case "HNSW":
		return entity.NewIndexHNSWSearchParam(fromFloat("ef", 16))
	case "IVF_HNSW":
		nprobe := fromFloat("nprobe", 16)
		ef := fromFloat("ef", 64)
		return entity.NewIndexIvfHNSWSearchParam(nprobe, ef)
	case "SCANN":
		nprobe := fromFloat("nprobe", 16)
		reorderK := fromInt("reorder_k", 64)
		return entity.NewIndexSCANNSearchParam(nprobe, reorderK)
	case "DISKANN":
		return entity.NewIndexDISKANNSearchParam(fromInt("search_list", 100))
	case "AUTOINDEX":
		return entity.NewIndexAUTOINDEXSearchParam(fromInt("level", 8))
	}

	// Any other index type: use HNSW search parameters.
	return entity.NewIndexHNSWSearchParam(16)
}
// SearchDocs performs similarity search for documents
func (m *MilvusProvider) SearchDocs(ctx context.Context, vector []float32, options *schema.SearchOptions) ([]schema.SearchResult, error) {
if options == nil {
options = &schema.SearchOptions{TopK: 10}
}
// Build search parameters
sp, _ := entity.NewIndexHNSWSearchParam(16)
sp, err := m.buildSearchParam()
if err != nil {
return nil, fmt.Errorf("failed to build search param: %w", err)
}
outputFields, _ := m.mapper.GetRawAllFieldNames()
vectorField, _ := m.mapper.GetVectorField()
searchConfig, _ := m.mapper.GetSearchConfig()
metricType := m.GetMetricType(searchConfig.MetricType)
// Build filter expression
expr := ""
searchResults, err := m.client.Search(
ctx,
m.Collection,
[]string{}, // partition names
expr, // filter expression
[]string{"id", "content", "metadata", "created_at"}, // output fields
m.collection,
[]string{}, // partition names
expr, // filter expression
outputFields, // output fields
[]entity.Vector{entity.FloatVector(vector)},
"vector", // anns_field
entity.IP, // metric_type
vectorField.RawName, // anns_field
metricType, // metric_type
options.TopK,
sp,
)
@@ -341,9 +677,13 @@ func (m *MilvusProvider) SearchDocs(ctx context.Context, vector []float32, optio
// Get field data
var content string
var metadata map[string]interface{}
for _, field := range result.Fields {
switch field.Name() {
fieldMapping, err := m.mapper.GetField(field.Name())
if err != nil {
continue
}
fieldName := strings.ToLower(fieldMapping.StandardName)
switch fieldName {
case "content":
if contentCol, ok := field.(*entity.ColumnVarChar); ok {
if contentVal, err := contentCol.Get(i); err == nil {
@@ -364,7 +704,6 @@ func (m *MilvusProvider) SearchDocs(ctx context.Context, vector []float32, optio
}
}
}
searchResult := schema.SearchResult{
Document: schema.Document{
ID: fmt.Sprintf("%s", id),
@@ -392,15 +731,17 @@ func (m *MilvusProvider) DeleteDocs(ctx context.Context, ids []string) error {
for i, id := range ids {
quotedIDs[i] = fmt.Sprintf("\"%s\"", id)
}
expr := fmt.Sprintf("id in [%s]", strings.Join(quotedIDs, ","))
idField, _ := m.mapper.GetIDField()
expr := fmt.Sprintf("%s in [%s]", idField.RawName, strings.Join(quotedIDs, ","))
// Delete data
err := m.client.Delete(ctx, m.Collection, "", expr)
err := m.client.Delete(ctx, m.collection, "", expr)
if err != nil {
return fmt.Errorf("failed to delete documents: %w", err)
}
// Flush data
err = m.client.Flush(ctx, m.Collection, false)
err = m.client.Flush(ctx, m.collection, false)
if err != nil {
return fmt.Errorf("failed to flush collection after delete: %w", err)
}
@@ -413,12 +754,13 @@ func (m *MilvusProvider) ListDocs(ctx context.Context, limit int) ([]schema.Docu
// Build query expression
expr := ""
// Query all relevant documents
outputFields, _ := m.mapper.GetRawAllFieldNames()
queryResult, err := m.client.Query(
ctx,
m.Collection,
m.collection,
[]string{}, // partitions
expr, // filter condition
[]string{"id", "content", "metadata", "created_at"},
outputFields,
client.WithOffset(0), client.WithLimit(int64(limit)),
)
@@ -443,7 +785,12 @@ func (m *MilvusProvider) ListDocs(ctx context.Context, limit int) ([]schema.Docu
)
for _, col := range queryResult {
switch col.Name() {
fieldMapping, err := m.mapper.GetField(col.Name())
if err != nil {
continue
}
fieldName := strings.ToLower(fieldMapping.StandardName)
switch fieldName {
case "id":
if v, err := col.(*entity.ColumnVarChar).Get(i); err == nil {
id = v.(string)
@@ -488,8 +835,3 @@ func (m *MilvusProvider) Close() error {
}
return nil
}
// joinStrings joins a slice of strings with the given separator
func joinStrings(elems []string, sep string) string {
return strings.Join(elems, sep)
}