diff --git a/plugins/golang-filter/go.mod b/plugins/golang-filter/go.mod index 3c05f7b18..b828a7112 100644 --- a/plugins/golang-filter/go.mod +++ b/plugins/golang-filter/go.mod @@ -55,16 +55,21 @@ require ( github.com/milvus-io/milvus-sdk-go/v2 v2.4.2 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/openai/openai-go/v2 v2.7.0 // indirect github.com/orcaman/concurrent-map v0.0.0-20210501183033-44dafcb38ecc // indirect github.com/pkoukk/tiktoken-go v0.1.8 // indirect github.com/prometheus/client_golang v1.14.0 // indirect github.com/prometheus/client_model v0.4.0 // indirect github.com/prometheus/common v0.37.0 // indirect github.com/prometheus/procfs v0.8.0 // indirect + github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/match v1.2.0 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect github.com/tjfoc/gmsm v1.4.1 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect - golang.org/x/net v0.33.0 // indirect + golang.org/x/net v0.34.0 // indirect golang.org/x/time v0.3.0 // indirect google.golang.org/grpc v1.59.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect @@ -99,9 +104,9 @@ require ( github.com/shopspring/decimal v1.4.0 // indirect go.opentelemetry.io/otel v1.26.0 // indirect go.opentelemetry.io/otel/trace v1.26.0 // indirect - golang.org/x/crypto v0.31.0 // indirect + golang.org/x/crypto v0.32.0 // indirect golang.org/x/sync v0.10.0 // indirect - golang.org/x/sys v0.28.0 // indirect + golang.org/x/sys v0.29.0 // indirect golang.org/x/text v0.21.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20230822172742-b8732ec3820d // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 // indirect diff --git a/plugins/golang-filter/go.sum b/plugins/golang-filter/go.sum index 97af86899..e578092c1 100644 --- 
a/plugins/golang-filter/go.sum +++ b/plugins/golang-filter/go.sum @@ -311,6 +311,8 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/gomega v1.24.2 h1:J/tulyYK6JwBldPViHJReihxxZ+22FHs0piGjQAvoUE= github.com/onsi/gomega v1.24.2/go.mod h1:gs3J10IS7Z7r7eXRoNJIrNqU4ToQukCJhFtKrWgHWnk= +github.com/openai/openai-go/v2 v2.7.0 h1:/8MSFCXcasin7AyuWQ2au6FraXL71gzAs+VfbMv+J3k= +github.com/openai/openai-go/v2 v2.7.0/go.mod h1:jrJs23apqJKKbT+pqtFgNKpRju/KP9zpUTZhz3GElQE= github.com/orcaman/concurrent-map v0.0.0-20210501183033-44dafcb38ecc h1:Ak86L+yDSOzKFa7WM5bf5itSOo1e3Xh8bm5YCMUXIjQ= github.com/orcaman/concurrent-map v0.0.0-20210501183033-44dafcb38ecc/go.mod h1:Lu3tH6HLW3feq74c2GC+jIMS/K2CFcDWnWD9XkenwhI= github.com/paulmach/orb v0.11.1 h1:3koVegMC4X/WeiXYz9iswopaTwMem53NzTJuTF20JzU= @@ -377,7 +379,18 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM= +github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= 
+github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tjfoc/gmsm v1.3.2/go.mod h1:HaUcFuY0auTiaHB9MHFGCPx5IaLhTUd2atbCFBQXn9w= github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho= github.com/tjfoc/gmsm v1.4.1/go.mod h1:j4INPkHWMrhJb38G+J6W4Tw0AbuN8Thu3PbdVYhVcTE= @@ -426,6 +439,8 @@ golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDf golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= +golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -506,6 +521,8 @@ golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= +golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod 
h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -581,6 +598,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= diff --git a/plugins/golang-filter/mcp-server/servers/rag/README.md b/plugins/golang-filter/mcp-server/servers/rag/README.md index bca469d6c..7cc1374b3 100644 --- a/plugins/golang-filter/mcp-server/servers/rag/README.md +++ b/plugins/golang-filter/mcp-server/servers/rag/README.md @@ -84,10 +84,11 @@ Higress RAG MCP Server 提供以下工具,根据配置不同,可用工具也 | llm.max_tokens | integer | 可选 | 2048 | 最大令牌数 | | llm.temperature | float | 可选 | 0.5 | 温度参数 | | **embedding** | object | 必填 | - | 嵌入配置(所有工具必需) | -| embedding.provider | string | 必填 | dashscope | 嵌入提供商:openai或dashscope | +| embedding.provider | string | 必填 | openai | 嵌入提供商:支持openai协议的任意供应商 | | embedding.api_key | string | 必填 | - | 嵌入API密钥 | | embedding.base_url | string | 可选 | | 嵌入API基础URL | -| embedding.model | string | 必填 | text-embedding-v4 | 嵌入模型名称 | +| embedding.model | string | 必填 | text-embedding-ada-002 | 嵌入模型名称 | +| embedding.dimensions | integer | 可选 | 1536 | 嵌入维度 | | **vectordb** | 
object | 必填 | - | 向量数据库配置(所有工具必需) | | vectordb.provider | string | 必填 | milvus | 向量数据库提供商 | | vectordb.host | string | 必填 | localhost | 数据库主机地址 | @@ -96,6 +97,17 @@ Higress RAG MCP Server 提供以下工具,根据配置不同,可用工具也 | vectordb.collection | string | 必填 | test_collection | 集合名称 | | vectordb.username | string | 可选 | - | 数据库用户名 | | vectordb.password | string | 可选 | - | 数据库密码 | +| **vectordb.mapping** | object | 可选 | - | 字段映射配置 | +| vectordb.mapping.fields | array | 可选 | - | 字段映射列表 | +| vectordb.mapping.fields[].standard_name | string | 必填 | - | 标准字段名称(如 id, content, vector 等) | +| vectordb.mapping.fields[].raw_name | string | 必填 | - | 原始字段名称(数据库中的实际字段名) | +| vectordb.mapping.fields[].properties | object | 可选 | - | 字段属性(如 auto_id, max_length 等) | +| vectordb.mapping.index | object | 可选 | - | 索引配置 | +| vectordb.mapping.index.index_type | string | 必填 | - | 索引类型(如 FLAT, IVF_FLAT, HNSW 等) | +| vectordb.mapping.index.params | object | 可选 | - | 索引参数(根据索引类型不同而异) | +| vectordb.mapping.search | object | 可选 | - | 搜索配置 | +| vectordb.mapping.search.metric_type | string | 可选 | L2 | 度量类型(如 L2, IP, COSINE 等) | +| vectordb.mapping.search.params | object | 可选 | - | 搜索参数(如 nprobe, ef_search 等) | ### higress-config 配置样例 @@ -143,27 +155,54 @@ data: temperature: 0.5 max_tokens: 2048 embedding: - provider: dashscope + provider: openai + base_url: https://dashscope.aliyuncs.com/compatible-mode/v1 api_key: sk-xxx model: text-embedding-v4 + dimensions: 1536 vectordb: provider: milvus - host: + host: localhost port: 19530 database: default - collection: test_collection -``` + collection: test_rag + mapping: + fields: + - standard_name: id + raw_name: id + properties: + auto_id: false + max_length: 256 + - standard_name: content + raw_name: content + properties: + max_length: 8192 + - standard_name: vector + raw_name: vector + - standard_name: metadata + raw_name: metadata + - standard_name: created_at + raw_name: created_at + index: + index_type: HNSW + params: + M: 4 + efConstruction: 32 + search: + 
metric_type: IP + params: + ef: 32 +``` ### 支持的提供商 #### Embedding -- **OpenAI** -- **DashScope** +- **OpenAI 兼容** #### Vector Database - **Milvus** #### LLM -- **OpenAI** +- **OpenAI 兼容** ## 如何测试数据集的效果 diff --git a/plugins/golang-filter/mcp-server/servers/rag/config/config.go b/plugins/golang-filter/mcp-server/servers/rag/config/config.go index 56a6f917e..4a883c966 100644 --- a/plugins/golang-filter/mcp-server/servers/rag/config/config.go +++ b/plugins/golang-filter/mcp-server/servers/rag/config/config.go @@ -1,5 +1,7 @@ package config +import "fmt" + // Config represents the main configuration structure for the MCP server type Config struct { RAG RAGConfig `json:"rag" yaml:"rag"` @@ -34,20 +36,148 @@ type LLMConfig struct { // EmbeddingConfig defines configuration for embedding models type EmbeddingConfig struct { - Provider string `json:"provider" yaml:"provider"` // Available options: openai, dashscope - APIKey string `json:"api_key,omitempty" yaml:"api_key,omitempty"` - BaseURL string `json:"base_url,omitempty" yaml:"base_url,omitempty"` - Model string `json:"model,omitempty" yaml:"model,omitempty"` - Dimension int `json:"dimension,omitempty" yaml:"dimension,omitempty"` + Provider string `json:"provider" yaml:"provider"` // Available options: openai (any OpenAI-compatible endpoint via base_url) + APIKey string `json:"api_key,omitempty" yaml:"api_key,omitempty"` + BaseURL string `json:"base_url,omitempty" yaml:"base_url,omitempty"` + Model string `json:"model,omitempty" yaml:"model,omitempty"` + Dimensions int `json:"dimensions,omitempty" yaml:"dimensions,omitempty"` } // VectorDBConfig defines configuration for vector databases type VectorDBConfig struct { - Provider string `json:"provider" yaml:"provider"` // Available options: milvus, qdrant, chroma - Host string `json:"host,omitempty" yaml:"host,omitempty"` - Port int `json:"port,omitempty" yaml:"port,omitempty"` - Database string `json:"database,omitempty" yaml:"database,omitempty"` - Collection string `json:"collection,omitempty" 
yaml:"collection,omitempty"` - Username string `json:"username,omitempty" yaml:"username,omitempty"` - Password string `json:"password,omitempty" yaml:"password,omitempty"` + Provider string `json:"provider" yaml:"provider"` // Available options: milvus, qdrant, chroma + Host string `json:"host,omitempty" yaml:"host,omitempty"` + Port int `json:"port,omitempty" yaml:"port,omitempty"` + Database string `json:"database,omitempty" yaml:"database,omitempty"` + Collection string `json:"collection,omitempty" yaml:"collection,omitempty"` + Username string `json:"username,omitempty" yaml:"username,omitempty"` + Password string `json:"password,omitempty" yaml:"password,omitempty"` + Mapping MappingConfig `json:"mapping,omitempty" yaml:"mapping,omitempty"` +} + +// MappingConfig defines field mapping configuration for vector databases +type MappingConfig struct { + Fields []FieldMapping `json:"fields,omitempty" yaml:"fields,omitempty"` + Index IndexConfig `json:"index,omitempty" yaml:"index,omitempty"` + Search SearchConfig `json:"search,omitempty" yaml:"search,omitempty"` +} + +// // CollectionMapping defines field mapping for collection +// type CollectionMapping struct { +// Fields []FieldMapping `json:"fields,omitempty" yaml:"fields,omitempty"` +// } + +// FieldMapping defines mapping for a single field +type FieldMapping struct { + StandardName string `json:"standard_name" yaml:"standard_name"` + RawName string `json:"raw_name" yaml:"raw_name"` + Properties map[string]interface{} `json:"properties,omitempty" yaml:"properties,omitempty"` +} + +func (f FieldMapping) IsPrimaryKey() bool { + return f.StandardName == "id" +} + +func (f FieldMapping) IsAutoID() bool { + if f.Properties == nil { + return false + } + autoID, ok := f.Properties["auto_id"].(bool) + if !ok { + return false + } + return autoID +} + +func (f FieldMapping) IsVectorField() bool { + return f.StandardName == "vector" +} + +func (f FieldMapping) MaxLength() int { + if f.Properties == nil { + return 0 + } 
+ maxLength, ok := f.Properties["max_length"].(int) + if !ok { + return 256 + } + return maxLength +} + +// IndexConfig defines configuration for index parameters +type IndexConfig struct { + // Index type, e.g., IVF_FLAT, IVF_SQ8, HNSW, etc. + IndexType string `json:"index_type" yaml:"index_type"` + // Index parameter configuration + Params map[string]interface{} `json:"params" yaml:"params"` +} + +func (i IndexConfig) ParamsString(key string) (string, error) { + if mVal, ok := i.Params[key].(string); ok { + return mVal, nil + } + return "", fmt.Errorf("params %s not found", key) +} + +func (i IndexConfig) ParamsInt64(key string) (int64, error) { + if mVal, ok := i.Params[key].(int64); ok { + return mVal, nil + } + if mVal, ok := i.Params[key].(int); ok { + return int64(mVal), nil + } + return 0, fmt.Errorf("params %s not found", key) +} + +func (i IndexConfig) ParamsFloat64(key string) (float64, error) { + if mVal, ok := i.Params[key].(float64); ok { + return mVal, nil + } + if mVal, ok := i.Params[key].(float32); ok { + return float64(mVal), nil + } + return 0, fmt.Errorf("params %s not found", key) +} + +func (i IndexConfig) ParamsBool(key string) (bool, error) { + if mVal, ok := i.Params[key].(bool); ok { + return mVal, nil + } + return false, fmt.Errorf("params %s not found", key) +} + +// SearchConfig defines configuration for search parameters +type SearchConfig struct { + // Metric type, e.g., L2, IP, etc. 
+ MetricType string `json:"metric_type,omitempty" yaml:"metric_type,omitempty"` + // Search parameter configuration + Params map[string]interface{} `json:"params" yaml:"params"` +} + +func (i SearchConfig) ParamsString(key string) (string, error) { + if mVal, ok := i.Params[key].(string); ok { + return mVal, nil + } + return "", fmt.Errorf("params %s not found", key) +} + +func (i SearchConfig) ParamsInt64(key string) (int64, error) { + if mVal, ok := i.Params[key].(int64); ok { + return mVal, nil + } + return 0, fmt.Errorf("params %s not found", key) +} + +func (i SearchConfig) ParamsFloat64(key string) (float64, error) { + if mVal, ok := i.Params[key].(float64); ok { + return mVal, nil + } + return 0, fmt.Errorf("params %s not found", key) +} + +func (i SearchConfig) ParamsBool(key string) (bool, error) { + if mVal, ok := i.Params[key].(bool); ok { + return mVal, nil + } + return false, fmt.Errorf("params %s not found", key) } diff --git a/plugins/golang-filter/mcp-server/servers/rag/embedding/dashscope.go b/plugins/golang-filter/mcp-server/servers/rag/embedding/dashscope.go deleted file mode 100644 index 1216699cc..000000000 --- a/plugins/golang-filter/mcp-server/servers/rag/embedding/dashscope.go +++ /dev/null @@ -1,169 +0,0 @@ -package embedding - -import ( - "context" - "encoding/json" - "errors" - "fmt" - - "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/common" - "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config" -) - -const ( - DASHSCOPE_DOMAIN = "dashscope.aliyuncs.com" - DASHSCOPE_PORT = 443 - DASHSCOPE_DEFAULT_MODEL_NAME = "text-embedding-v4" - DASHSCOPE_ENDPOINT = "/api/v1/services/embeddings/text-embedding/text-embedding" -) - -var dashScopeConfig dashScopeProviderConfig - -type dashScopeProviderInitializer struct { -} -type dashScopeProviderConfig struct { - apiKey string - model string -} - -func (c *dashScopeProviderInitializer) InitConfig(config config.EmbeddingConfig) { - 
dashScopeConfig.apiKey = config.APIKey - dashScopeConfig.model = config.Model -} - -func (c *dashScopeProviderInitializer) ValidateConfig() error { - if dashScopeConfig.apiKey == "" { - return errors.New("[DashScope] apiKey is required") - } - return nil -} - -func (c *dashScopeProviderInitializer) CreateProvider(config config.EmbeddingConfig) (Provider, error) { - c.InitConfig(config) - err := c.ValidateConfig() - if err != nil { - return nil, err - } - - headers := map[string]string{ - "Authorization": "Bearer " + config.APIKey, - "Content-Type": "application/json", - } - httpClient := common.NewHTTPClient(fmt.Sprintf("https://%s", DASHSCOPE_DOMAIN), headers) - - return &DashScopeProvider{ - config: dashScopeConfig, - client: httpClient, - }, nil -} - -func (d *DashScopeProvider) GetProviderType() string { - return PROVIDER_TYPE_DASHSCOPE -} - -type Embedding struct { - Embedding []float32 `json:"embedding"` - TextIndex int `json:"text_index"` -} - -type Input struct { - Texts []string `json:"texts"` -} - -type Params struct { - TextType string `json:"text_type"` -} - -type Response struct { - RequestID string `json:"request_id"` - Output Output `json:"output"` - Usage Usage `json:"usage"` -} - -type Output struct { - Embeddings []Embedding `json:"embeddings"` -} - -type Usage struct { - TotalTokens int `json:"total_tokens"` -} - -type EmbeddingRequest struct { - Model string `json:"model"` - Input Input `json:"input"` - Parameters Params `json:"parameters"` -} - -type Document struct { - Vector []float64 `json:"vector"` - Fields map[string]string `json:"fields"` -} - -type DashScopeProvider struct { - config dashScopeProviderConfig - client *common.HTTPClient -} - -func (d *DashScopeProvider) constructRequestData(texts []string) (EmbeddingRequest, error) { - model := d.config.model - if model == "" { - model = DASHSCOPE_DEFAULT_MODEL_NAME - } - - if dashScopeConfig.apiKey == "" { - return EmbeddingRequest{}, errors.New("dashScopeKey is empty") - } - - data := 
EmbeddingRequest{ - Model: model, - Input: Input{ - Texts: texts, - }, - Parameters: Params{ - TextType: "query", - }, - } - - return data, nil -} - -type Result struct { - ID string `json:"id"` - Vector []float32 `json:"vector,omitempty"` - Fields map[string]interface{} `json:"fields"` - Score float64 `json:"score"` -} - -func (d *DashScopeProvider) parseTextEmbedding(responseBody []byte) (*Response, error) { - var resp Response - err := json.Unmarshal(responseBody, &resp) - if err != nil { - return nil, err - } - return &resp, nil -} - -func (d *DashScopeProvider) GetEmbedding( - ctx context.Context, - queryString string) ([]float32, error) { - requestData, err := d.constructRequestData([]string{queryString}) - if err != nil { - return nil, fmt.Errorf("failed to construct request data: %v", err) - } - - responseBody, err := d.client.Post(DASHSCOPE_ENDPOINT, requestData) - if err != nil { - return nil, fmt.Errorf("failed to send request: %v", err) - } - - embeddingResp, err := d.parseTextEmbedding(responseBody) - if err != nil { - return nil, fmt.Errorf("failed to parse response: %v", err) - } - - if len(embeddingResp.Output.Embeddings) == 0 { - return nil, errors.New("no embedding found in response") - } - - return embeddingResp.Output.Embeddings[0].Embedding, nil -} diff --git a/plugins/golang-filter/mcp-server/servers/rag/embedding/openai.go b/plugins/golang-filter/mcp-server/servers/rag/embedding/openai.go index 2f4ec2c20..c59b2deb6 100644 --- a/plugins/golang-filter/mcp-server/servers/rag/embedding/openai.go +++ b/plugins/golang-filter/mcp-server/servers/rag/embedding/openai.go @@ -2,160 +2,93 @@ package embedding import ( "context" - "encoding/json" "errors" "fmt" - "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/common" "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config" + "github.com/openai/openai-go/v2" + "github.com/openai/openai-go/v2/option" ) const ( - OPENAI_DOMAIN = "api.openai.com" - OPENAI_PORT 
= 443 - OPENAI_DEFAULT_MODEL_NAME = "text-embedding-3-small" - OPENAI_ENDPOINT = "/v1/embeddings" + OPENAI_DEFAULT_MODEL_NAME = "text-embedding-ada-002" ) type openAIProviderInitializer struct { } -var openAIConfig openAIProviderConfig - -type openAIProviderConfig struct { - baseUrl string - apiKey string - model string -} - -func (c *openAIProviderInitializer) InitConfig(config config.EmbeddingConfig) { - openAIConfig.apiKey = config.APIKey - openAIConfig.model = config.Model - openAIConfig.baseUrl = config.BaseURL -} - -func (c *openAIProviderInitializer) ValidateConfig() error { - if openAIConfig.apiKey == "" { - return errors.New("[openAI] apiKey is required") +func (c *openAIProviderInitializer) validateConfig(config *config.EmbeddingConfig) error { + if config.APIKey == "" { + return errors.New("[openai embedding] apiKey is required") } + if config.Model == "" { + config.Model = OPENAI_DEFAULT_MODEL_NAME + } + if config.Dimensions <= 0 { + config.Dimensions = 1536 + } + return nil } func (c *openAIProviderInitializer) CreateProvider(config config.EmbeddingConfig) (Provider, error) { - c.InitConfig(config) - err := c.ValidateConfig() - if err != nil { + if err := c.validateConfig(&config); err != nil { return nil, err } + // Build the OpenAI client options + var clientOptions []option.RequestOption + clientOptions = append(clientOptions, option.WithAPIKey(config.APIKey)) - if openAIConfig.model == "" { - openAIConfig.model = OPENAI_DEFAULT_MODEL_NAME + // If a custom baseURL is set, use it + if config.BaseURL != "" { + clientOptions = append(clientOptions, option.WithBaseURL(config.BaseURL)) } - - if openAIConfig.baseUrl == "" { - openAIConfig.baseUrl = fmt.Sprintf("https://%s", OPENAI_DOMAIN) - } - - headers := map[string]string{ - "Authorization": "Bearer " + config.APIKey, - "Content-Type": "application/json", - } - httpClient := common.NewHTTPClient(openAIConfig.baseUrl, headers) + // Create the OpenAI client + client := openai.NewClient(clientOptions...) 
return &OpenAIProvider{ - config: openAIConfig, - client: httpClient, + client: &client, + model: config.Model, + dimensions: config.Dimensions, }, nil } -func (o *OpenAIProvider) GetProviderType() string { +// OpenAIProvider handles vector embedding generation using OpenAI-compatible APIs +type OpenAIProvider struct { + client *openai.Client + model string + dimensions int +} + +func (e *OpenAIProvider) GetProviderType() string { return PROVIDER_TYPE_OPENAI } -type OpenAIResponse struct { - Object string `json:"object"` - Data []OpenAIResult `json:"data"` - Model string `json:"model"` - Error *OpenAIError `json:"error"` -} - -type OpenAIResult struct { - Object string `json:"object"` - Embedding []float32 `json:"embedding"` - Index int `json:"index"` -} - -type OpenAIError struct { - Message string `json:"prompt_tokens"` - Type string `json:"type"` - Code string `json:"code"` - Param string `json:"param"` -} - -type OpenAIEmbeddingRequest struct { - Input string `json:"input"` - Model string `json:"model"` -} - -type OpenAIProvider struct { - config openAIProviderConfig - client *common.HTTPClient -} - -func (o *OpenAIProvider) constructRequestData(text string) (OpenAIEmbeddingRequest, error) { - if text == "" { - return OpenAIEmbeddingRequest{}, errors.New("queryString text cannot be empty") +// GetEmbedding generates vector embedding for the given text +func (e *OpenAIProvider) GetEmbedding(ctx context.Context, text string) ([]float32, error) { + params := openai.EmbeddingNewParams{ + Model: e.model, + Input: openai.EmbeddingNewParamsInputUnion{ + OfString: openai.String(text), + }, + Dimensions: openai.Int(int64(e.dimensions)), + EncodingFormat: openai.EmbeddingNewParamsEncodingFormatFloat, } - if openAIConfig.apiKey == "" { - return OpenAIEmbeddingRequest{}, errors.New("openAI apiKey is empty") - } - - model := o.config.model - if model == "" { - model = OPENAI_DEFAULT_MODEL_NAME - } - - data := OpenAIEmbeddingRequest{ - Input: text, - Model: model, - } - - 
return data, nil -} - -func (o *OpenAIProvider) parseTextEmbedding(responseBody []byte) (*OpenAIResponse, error) { - var resp OpenAIResponse - err := json.Unmarshal(responseBody, &resp) + embeddingResp, err := e.client.Embeddings.New(ctx, params) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to generate embedding: %w", err) } - return &resp, nil -} - -func (o *OpenAIProvider) GetEmbedding(ctx context.Context, queryString string) ([]float32, error) { - requestData, err := o.constructRequestData(queryString) - if err != nil { - return nil, fmt.Errorf("failed to construct request data: %v", err) - } - - responseBody, err := o.client.Post(OPENAI_ENDPOINT, requestData) - if err != nil { - return nil, fmt.Errorf("failed to send request: %v", err) - } - - resp, err := o.parseTextEmbedding(responseBody) - if err != nil { - return nil, fmt.Errorf("failed to parse response: %v", err) - } - - if resp.Error != nil { - return nil, fmt.Errorf("OpenAI API error: %s - %s", resp.Error.Type, resp.Error.Message) - } - - if len(resp.Data) == 0 { - return nil, errors.New("no embedding found in response") - } - - return resp.Data[0].Embedding, nil + + if len(embeddingResp.Data) == 0 { + return nil, fmt.Errorf("empty embedding response") + } + + // Convert []float64 to []float32 + embedding := make([]float32, len(embeddingResp.Data[0].Embedding)) + for i, v := range embeddingResp.Data[0].Embedding { + embedding[i] = float32(v) + } + + return embedding, nil } diff --git a/plugins/golang-filter/mcp-server/servers/rag/embedding/provider.go b/plugins/golang-filter/mcp-server/servers/rag/embedding/provider.go index 5377f0436..e8c5698bf 100644 --- a/plugins/golang-filter/mcp-server/servers/rag/embedding/provider.go +++ b/plugins/golang-filter/mcp-server/servers/rag/embedding/provider.go @@ -10,21 +10,21 @@ import ( // Provider type constants for different embedding services const ( // DashScope embedding service - PROVIDER_TYPE_DASHSCOPE = "dashscope" + 
PROVIDER_TYPE_DASHSCOPE = "dashscope" // TextIn embedding service - PROVIDER_TYPE_TEXTIN = "textin" + PROVIDER_TYPE_TEXTIN = "textin" // Cohere embedding service - PROVIDER_TYPE_COHERE = "cohere" + PROVIDER_TYPE_COHERE = "cohere" // OpenAI embedding service - PROVIDER_TYPE_OPENAI = "openai" + PROVIDER_TYPE_OPENAI = "openai" // Ollama embedding service - PROVIDER_TYPE_OLLAMA = "ollama" + PROVIDER_TYPE_OLLAMA = "ollama" // HuggingFace embedding service PROVIDER_TYPE_HUGGINGFACE = "huggingface" // XFYun embedding service - PROVIDER_TYPE_XFYUN = "xfyun" + PROVIDER_TYPE_XFYUN = "xfyun" // Azure embedding service - PROVIDER_TYPE_AZURE = "azure" + PROVIDER_TYPE_AZURE = "azure" ) // Factory interface for creating Provider instances @@ -36,8 +36,7 @@ type providerInitializer interface { // Maps provider types to their initializers var ( providerInitializers = map[string]providerInitializer{ - PROVIDER_TYPE_DASHSCOPE: &dashScopeProviderInitializer{}, - PROVIDER_TYPE_OPENAI: &openAIProviderInitializer{}, + PROVIDER_TYPE_OPENAI: &openAIProviderInitializer{}, } ) diff --git a/plugins/golang-filter/mcp-server/servers/rag/llm/openai.go b/plugins/golang-filter/mcp-server/servers/rag/llm/openai.go index 901015ffb..f30822186 100644 --- a/plugins/golang-filter/mcp-server/servers/rag/llm/openai.go +++ b/plugins/golang-filter/mcp-server/servers/rag/llm/openai.go @@ -2,133 +2,105 @@ package llm import ( "context" - "encoding/json" "errors" "fmt" - "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/common" "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config" + "github.com/openai/openai-go/v2" + "github.com/openai/openai-go/v2/option" + "github.com/openai/openai-go/v2/packages/param" ) const ( - OPENAI_CHAT_ENDPOINT = "/chat/completions" OPENAI_DEFAULT_MODEL = "gpt-4o" ) -// openAI specific configuration captured after initialization. 
-type openAIProviderConfig struct { - apiKey string - baseURL string +type OpenAIProvider struct { + client *openai.Client model string - maxTokens int temperature float64 + maxTokens int } type openAIProviderInitializer struct{} -var openAIConfig openAIProviderConfig - -func (i *openAIProviderInitializer) initConfig(c config.LLMConfig) { - openAIConfig.apiKey = c.APIKey - openAIConfig.baseURL = c.BaseURL - openAIConfig.model = c.Model - if openAIConfig.model == "" { - openAIConfig.model = OPENAI_DEFAULT_MODEL - } - if openAIConfig.baseURL == "" { - openAIConfig.baseURL = "https://api.openai.com/v1" // default public endpoint - } - openAIConfig.maxTokens = c.MaxTokens - openAIConfig.temperature = c.Temperature -} - -func (i *openAIProviderInitializer) validateConfig() error { - if openAIConfig.apiKey == "" { +func (i *openAIProviderInitializer) validateConfig(cfg *config.LLMConfig) error { + if cfg.APIKey == "" { return errors.New("[openai llm] apiKey is required") } + if cfg.Model == "" { + cfg.Model = OPENAI_DEFAULT_MODEL + } + + if cfg.Temperature <= 0 || cfg.Temperature > 2 { + cfg.Temperature = 0.5 + } + + if cfg.MaxTokens <= 0 { + cfg.MaxTokens = 2048 + } return nil } func (i *openAIProviderInitializer) CreateProvider(cfg config.LLMConfig) (Provider, error) { - i.initConfig(cfg) - if err := i.validateConfig(); err != nil { + if err := i.validateConfig(&cfg); err != nil { return nil, err } - headers := map[string]string{ - "Authorization": "Bearer " + openAIConfig.apiKey, - "Content-Type": "application/json", + // Create OpenAI client + var clientOptions []option.RequestOption + clientOptions = append(clientOptions, option.WithAPIKey(cfg.APIKey)) + + // If a custom baseURL is set, use it + if cfg.BaseURL != "" { + clientOptions = append(clientOptions, option.WithBaseURL(cfg.BaseURL)) } - client := common.NewHTTPClient(openAIConfig.baseURL, headers) - return &OpenAIProvider{client: client, cfg: openAIConfig}, nil -} -type OpenAIProvider struct { - client 
*common.HTTPClient - cfg openAIProviderConfig -} + // Create OpenAI client + client := openai.NewClient(clientOptions...) -type openAIChatCompletionRequest struct { - Model string `json:"model"` - Messages []openAIChatMessage `json:"messages"` - Temperature float64 `json:"temperature,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` -} - -type openAIChatMessage struct { - Role string `json:"role"` - Content string `json:"content"` -} - -type openAIChatCompletionResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Choices []openAIChatCompletionResponseChoice `json:"choices"` - Error *openAIError `json:"error,omitempty"` -} - -type openAIChatCompletionResponseChoice struct { - Index int `json:"index"` - Message openAIChatMessage `json:"message"` - FinishReason string `json:"finish_reason"` -} - -type openAIError struct { - Message string `json:"message"` - Type string `json:"type"` - Code string `json:"code"` - Param string `json:"param"` + return &OpenAIProvider{ + client: &client, + model: cfg.Model, + temperature: cfg.Temperature, + maxTokens: cfg.MaxTokens, + }, nil } // GenerateCompletion implements Provider interface. 
func (o *OpenAIProvider) GenerateCompletion(ctx context.Context, prompt string) (string, error) { - req := openAIChatCompletionRequest{ - Model: o.cfg.model, - Messages: []openAIChatMessage{ - {Role: "user", Content: prompt}, + // Create chat request + params := openai.ChatCompletionNewParams{ + Model: o.model, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(prompt), }, - Temperature: o.cfg.temperature, - MaxTokens: o.cfg.maxTokens, } - body, err := o.client.Post(OPENAI_CHAT_ENDPOINT, req) + // Set optional parameters + if o.temperature > 0 { + temperature := float64(o.temperature) + params.Temperature = param.Opt[float64]{Value: temperature} + } + + if o.maxTokens > 0 { + maxTokens := int64(o.maxTokens) + params.MaxTokens = param.Opt[int64]{Value: maxTokens} + } + + // Send request + response, err := o.client.Chat.Completions.New(ctx, params) if err != nil { - return "", fmt.Errorf("openai llm post error: %w", err) + // Handle error + return "", fmt.Errorf("openai llm error: %w", err) } - var resp openAIChatCompletionResponse - if err := json.Unmarshal(body, &resp); err != nil { - return "", fmt.Errorf("openai llm unmarshal error: %w", err) - } - - if resp.Error != nil { - return "", fmt.Errorf("openai llm api error: %s - %s", resp.Error.Type, resp.Error.Message) - } - - if len(resp.Choices) == 0 { + // Check response + if len(response.Choices) == 0 { return "", errors.New("openai llm: empty choices") } - return resp.Choices[0].Message.Content, nil + // Return generated content + return response.Choices[0].Message.Content, nil } func (o *OpenAIProvider) GetProviderType() string { diff --git a/plugins/golang-filter/mcp-server/servers/rag/rag_client.go b/plugins/golang-filter/mcp-server/servers/rag/rag_client.go index 4ed826c8f..a1494c2af 100644 --- a/plugins/golang-filter/mcp-server/servers/rag/rag_client.go +++ b/plugins/golang-filter/mcp-server/servers/rag/rag_client.go @@ -56,18 +56,12 @@ func NewRAGClient(config *config.Config) 
(*RAGClient, error) { ragclient.llmProvider = llmProvider } - demoVector, err := embeddingProvider.GetEmbedding(context.Background(), "initialization") - if err != nil { - return nil, fmt.Errorf("create init embedding failed, err: %w", err) - } - dim := len(demoVector) - + dim := ragclient.config.Embedding.Dimensions provider, err := vectordb.NewVectorDBProvider(&ragclient.config.VectorDB, dim) if err != nil { return nil, fmt.Errorf("create vector store provider failed, err: %w", err) } ragclient.vectordbProvider = provider - return ragclient, nil } diff --git a/plugins/golang-filter/mcp-server/servers/rag/rag_client_test.go b/plugins/golang-filter/mcp-server/servers/rag/rag_client_test.go index 07e3b5094..26bad2638 100644 --- a/plugins/golang-filter/mcp-server/servers/rag/rag_client_test.go +++ b/plugins/golang-filter/mcp-server/servers/rag/rag_client_test.go @@ -22,15 +22,17 @@ func getRAGClient() (*RAGClient, error) { LLM: config.LLMConfig{ Provider: "openai", - APIKey: "sk-xxxx", + APIKey: "sk-xxx", BaseURL: "https://openrouter.ai/api/v1", Model: "openai/gpt-4o", }, Embedding: config.EmbeddingConfig{ - Provider: "dashscope", - APIKey: "sk-xxxx", - Model: "text-embedding-v4", + Provider: "openai", + BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1", + APIKey: "sk-xxxx", + Model: "text-embedding-v4", + Dimensions: 1536, }, VectorDB: config.VectorDBConfig{ @@ -38,7 +40,49 @@ func getRAGClient() (*RAGClient, error) { Host: "localhost", Port: 19530, Database: "default", - Collection: "test_collection", + Collection: "test_collection3", + Mapping: config.MappingConfig{ + Fields: []config.FieldMapping{ + { + StandardName: "id", + RawName: "pk", + Properties: map[string]interface{}{ + "max_length": 256, + "auto_id": false, + }, + }, + { + StandardName: "content", + RawName: "page_content", + Properties: map[string]interface{}{ + "max_length": 8192, + }, + }, + { + StandardName: "vector", + RawName: "page_vector", + Properties: make(map[string]interface{}), + 
}, + { + StandardName: "metadata", + RawName: "metadata", + Properties: make(map[string]interface{}), + }, + { + StandardName: "created_at", + RawName: "created_at", + Properties: make(map[string]interface{}), + }, + }, + Index: config.IndexConfig{ + IndexType: "IVF_FLAT", + Params: map[string]interface{}{"nlist": 64}, + }, + Search: config.SearchConfig{ + MetricType: "COSINE", + Params: map[string]interface{}{"nprobe": 32}, + }, + }, }, } @@ -48,7 +92,6 @@ func getRAGClient() (*RAGClient, error) { } return ragClient, nil - } func TestNewRAGClient(t *testing.T) { @@ -104,7 +147,7 @@ func TestRAGClient_DeleteChunk(t *testing.T) { return } - chunk_id := "63ee25d7-41b9-4455-8066-075ca5c803b2" + chunk_id := "2a06679c-a8ea-46dc-bf1c-7e7b164a73c8" err = ragClient.DeleteChunk(chunk_id) if err != nil { t.Errorf("DeleteChunk() error = %v", err) diff --git a/plugins/golang-filter/mcp-server/servers/rag/server.go b/plugins/golang-filter/mcp-server/servers/rag/server.go index 4e054d692..8088ab453 100644 --- a/plugins/golang-filter/mcp-server/servers/rag/server.go +++ b/plugins/golang-filter/mcp-server/servers/rag/server.go @@ -36,11 +36,11 @@ func init() { MaxTokens: 2048, }, Embedding: config.EmbeddingConfig{ - Provider: "dashscope", - APIKey: "", - BaseURL: "", - Model: "text-embedding-v4", - Dimension: 1024, + Provider: "openai", + APIKey: "", + BaseURL: "", + Model: "text-embedding-ada-002", + Dimensions: 1536, }, VectorDB: config.VectorDBConfig{ Provider: "milvus", @@ -50,14 +50,56 @@ func init() { Collection: "rag", Username: "", Password: "", + Mapping: config.MappingConfig{ + Fields: []config.FieldMapping{ + { + StandardName: "id", + RawName: "id", + Properties: map[string]interface{}{ + "max_length": 256, + "auto_id": false, + }, + }, + { + StandardName: "content", + RawName: "content", + Properties: map[string]interface{}{ + "max_length": 8192, + }, + }, + { + StandardName: "vector", + RawName: "vector", + Properties: make(map[string]interface{}), + }, + { + 
StandardName: "metadata", + RawName: "metadata", + Properties: make(map[string]interface{}), + }, + { + StandardName: "created_at", + RawName: "created_at", + Properties: make(map[string]interface{}), + }, + }, + Index: config.IndexConfig{ + IndexType: "HNSW", + Params: map[string]interface{}{"M": 8, "efConstruction": 64}, + }, + Search: config.SearchConfig{ + MetricType: "IP", + Params: make(map[string]interface{}), + }, + }, }, }, }) } -func (c *RAGConfig) ParseConfig(config map[string]any) error { +func (c *RAGConfig) ParseConfig(cfg map[string]any) error { // Parse RAG configuration - if ragConfig, ok := config["rag"].(map[string]any); ok { + if ragConfig, ok := cfg["rag"].(map[string]any); ok { if splitter, exists := ragConfig["splitter"].(map[string]any); exists { if splitterType, exists := splitter["provider"].(string); exists { c.config.RAG.Splitter.Provider = splitterType @@ -78,7 +120,7 @@ func (c *RAGConfig) ParseConfig(config map[string]any) error { } // Parse Embedding configuration - if embeddingConfig, ok := config["embedding"].(map[string]any); ok { + if embeddingConfig, ok := cfg["embedding"].(map[string]any); ok { if provider, exists := embeddingConfig["provider"].(string); exists { c.config.Embedding.Provider = provider } else { @@ -94,13 +136,13 @@ func (c *RAGConfig) ParseConfig(config map[string]any) error { if model, exists := embeddingConfig["model"].(string); exists { c.config.Embedding.Model = model } - if dimension, exists := embeddingConfig["dimension"].(float64); exists { - c.config.Embedding.Dimension = int(dimension) + if dimensions, exists := embeddingConfig["dimensions"].(float64); exists { + c.config.Embedding.Dimensions = int(dimensions) } } // Parse llm configuration - if llmConfig, ok := config["llm"].(map[string]any); ok { + if llmConfig, ok := cfg["llm"].(map[string]any); ok { if provider, exists := llmConfig["provider"].(string); exists { c.config.LLM.Provider = provider } @@ -122,7 +164,7 @@ func (c *RAGConfig) 
ParseConfig(config map[string]any) error { } // Parse VectorDB configuration - if vectordbConfig, ok := config["vectordb"].(map[string]any); ok { + if vectordbConfig, ok := cfg["vectordb"].(map[string]any); ok { if provider, exists := vectordbConfig["provider"].(string); exists { c.config.VectorDB.Provider = provider } else { @@ -146,8 +188,59 @@ func (c *RAGConfig) ParseConfig(config map[string]any) error { if password, exists := vectordbConfig["password"].(string); exists { c.config.VectorDB.Password = password } - } + // Parse mapping here + if mapping, exists := vectordbConfig["mapping"].(map[string]any); exists { + // Parse field mappings + if fields, ok := mapping["fields"].([]any); ok { + c.config.VectorDB.Mapping.Fields = []config.FieldMapping{} + for _, field := range fields { + if fieldMap, ok := field.(map[string]any); ok { + fieldMapping := config.FieldMapping{ + Properties: make(map[string]interface{}), + } + if standardName, ok := fieldMap["standard_name"].(string); ok { + fieldMapping.StandardName = standardName + } + + if rawName, ok := fieldMap["raw_name"].(string); ok { + fieldMapping.RawName = rawName + } + // Parse properties + if properties, ok := fieldMap["properties"].(map[string]any); ok { + for key, value := range properties { + fieldMapping.Properties[key] = value + } + } + c.config.VectorDB.Mapping.Fields = append(c.config.VectorDB.Mapping.Fields, fieldMapping) + } + } + } + + // Parse index configuration + if index, ok := mapping["index"].(map[string]any); ok { + if indexType, ok := index["index_type"].(string); ok { + c.config.VectorDB.Mapping.Index.IndexType = indexType + } + + // Parse index parameters + if params, ok := index["params"].(map[string]any); ok { + c.config.VectorDB.Mapping.Index.Params = params + } + } + + // Parse search configuration + if search, ok := mapping["search"].(map[string]any); ok { + if metricType, ok := search["metric_type"].(string); ok { + c.config.VectorDB.Mapping.Search.MetricType = metricType + } + // 
Parse search parameters + if params, ok := search["params"].(map[string]any); ok { + c.config.VectorDB.Mapping.Search.Params = params + } + } + } + } return nil } diff --git a/plugins/golang-filter/mcp-server/servers/rag/server_test.go b/plugins/golang-filter/mcp-server/servers/rag/server_test.go index 9d1e4fd50..c6184e0b3 100644 --- a/plugins/golang-filter/mcp-server/servers/rag/server_test.go +++ b/plugins/golang-filter/mcp-server/servers/rag/server_test.go @@ -28,11 +28,11 @@ func TestRAGConfig_ParseConfig(t *testing.T) { MaxTokens: 2048, }, Embedding: config.EmbeddingConfig{ - Provider: "dashscope", - APIKey: "sk-XXX", - BaseURL: "", - Model: "text-embedding-v4", - Dimension: 1024, + Provider: "dashscope", + APIKey: "sk-XXX", + BaseURL: "", + Model: "text-embedding-v4", + Dimensions: 1024, }, VectorDB: config.VectorDBConfig{ Provider: "milvus", @@ -42,6 +42,48 @@ func TestRAGConfig_ParseConfig(t *testing.T) { Collection: "test_rag", Username: "", Password: "", + Mapping: config.MappingConfig{ + Fields: []config.FieldMapping{ + { + StandardName: "id", + RawName: "id", + Properties: map[string]interface{}{ + "max_length": 256, + "auto_id": false, + }, + }, + { + StandardName: "content", + RawName: "content", + Properties: map[string]interface{}{ + "max_length": 8192, + }, + }, + { + StandardName: "vector", + RawName: "vector", + Properties: make(map[string]interface{}), + }, + { + StandardName: "metadata", + RawName: "metadata", + Properties: make(map[string]interface{}), + }, + { + StandardName: "created_at", + RawName: "created_at", + Properties: make(map[string]interface{}), + }, + }, + Index: config.IndexConfig{ + IndexType: "HNSW", + Params: map[string]interface{}{"M": 4, "efConstruction": 32}, + }, + Search: config.SearchConfig{ + MetricType: "IP", + Params: map[string]interface{}{"ef": 32}, + }, + }, }, } // 把 config 输出 yaml 格式 diff --git a/plugins/golang-filter/mcp-server/servers/rag/vectordb/mapper.go 
b/plugins/golang-filter/mcp-server/servers/rag/vectordb/mapper.go new file mode 100644 index 000000000..55b456924 --- /dev/null +++ b/plugins/golang-filter/mcp-server/servers/rag/vectordb/mapper.go @@ -0,0 +1,182 @@ +package vectordb + +import ( + "errors" + "fmt" + + "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config" +) + +// Error definitions +var ( + ErrFieldNotFound = errors.New("field not found") + ErrInvalidFieldType = errors.New("invalid field type") + ErrInvalidIndexType = errors.New("invalid index type") + ErrInvalidMetricType = errors.New("invalid metric type") + ErrInvalidSearchParams = errors.New("invalid search parameters") + ErrCollectionNotFound = errors.New("collection not found") + ErrUnsupportedOperation = errors.New("unsupported operation") +) + +// VectorDBMapper interface for vector database mapping +type VectorDBMapper interface { + // ParseMapping parses the mapping configuration + ParseMapping(provider string, cfg config.MappingConfig) error + + // GetIndexConfig returns the index configuration + GetIndexConfig() (config.IndexConfig, error) + + // GetSearchConfig returns the search configuration + GetSearchConfig() (config.SearchConfig, error) + + // Get all raw field names + GetRawAllFieldNames() ([]string, error) + + // GetIDField returns the ID field mapping + GetIDField() (*config.FieldMapping, error) + + // GetVectorField returns the vector field mapping + GetVectorField() (*config.FieldMapping, error) + + // Get raw field name by standard field name + GetRawField(standardFieldName string) (*config.FieldMapping, error) + + // Get field mapping by raw field name + GetField(rawFieldName string) (*config.FieldMapping, error) + + // Get all field mappings + GetFieldMappings() ([]config.FieldMapping, error) +} + +// DefaultVectorDBMapper is the default implementation of VectorDBMapper interface +type DefaultVectorDBMapper struct { + // Mapping configuration + mappingConfig config.MappingConfig + // Map from 
standard field name to field mapping + standardFieldMap map[string]*config.FieldMapping + // Map from raw field name to field mapping + rawFieldMap map[string]*config.FieldMapping +} + +// NewDefaultVectorDBMapper creates a new default vector database mapper +func NewDefaultVectorDBMapper(provider string, mappingConfig config.MappingConfig) (*DefaultVectorDBMapper, error) { + mapper := &DefaultVectorDBMapper{ + standardFieldMap: make(map[string]*config.FieldMapping), + rawFieldMap: make(map[string]*config.FieldMapping), + } + if err := mapper.ParseMapping(provider, mappingConfig); err != nil { + return nil, err + } + return mapper, nil +} + +// ParseMapping parses the mapping configuration +func (m *DefaultVectorDBMapper) ParseMapping(provider string, cfg config.MappingConfig) error { + m.mappingConfig = cfg + // Clear existing mappings + m.standardFieldMap = make(map[string]*config.FieldMapping) + m.rawFieldMap = make(map[string]*config.FieldMapping) + // fill default field mappings + if len(cfg.Fields) == 0 { + defaultFields := []config.FieldMapping{ + { + StandardName: "id", + RawName: "id", + Properties: map[string]interface{}{ + "max_length": 256, + "auto_id": false, + }, + }, + { + StandardName: "content", + RawName: "content", + Properties: map[string]interface{}{ + "max_length": 8192, + }, + }, + { + StandardName: "vector", + RawName: "vector", + }, + { + StandardName: "metadata", + RawName: "metadata", + }, + { + StandardName: "created_at", + RawName: "created_at", + }, + } + cfg.Fields = defaultFields + } + + // Parse field mappings + for i, field := range cfg.Fields { + // Save pointer for future reference + fieldPtr := &cfg.Fields[i] + m.standardFieldMap[field.StandardName] = fieldPtr + m.rawFieldMap[field.RawName] = fieldPtr + } + + // Check fields, must include id, content, vector fields + requiredFields := []string{"id", "content", "vector"} + for _, fieldName := range requiredFields { + if _, err := m.GetRawField(fieldName); err != nil { + return 
fmt.Errorf("[vector db mapper] required field %s not found or not varchar type", fieldName) + } + } + + return nil +} + +// GetIndexConfig gets the index configuration +func (m *DefaultVectorDBMapper) GetIndexConfig() (config.IndexConfig, error) { + return m.mappingConfig.Index, nil +} + +// GetSearchConfig gets the search configuration +func (m *DefaultVectorDBMapper) GetSearchConfig() (config.SearchConfig, error) { + return m.mappingConfig.Search, nil +} + +// GetRawAllFieldNames gets all raw field names +func (m *DefaultVectorDBMapper) GetRawAllFieldNames() ([]string, error) { + fieldNames := make([]string, 0, len(m.rawFieldMap)) + for name := range m.rawFieldMap { + fieldNames = append(fieldNames, name) + } + return fieldNames, nil +} + +// GetIDField gets the ID field +func (m *DefaultVectorDBMapper) GetIDField() (*config.FieldMapping, error) { + return m.GetRawField("id") +} + +// GetVectorField gets the vector field +func (m *DefaultVectorDBMapper) GetVectorField() (*config.FieldMapping, error) { + return m.GetRawField("vector") +} + +// GetRawField gets the raw field mapping by standard field name +func (m *DefaultVectorDBMapper) GetRawField(standardFieldName string) (*config.FieldMapping, error) { + field, exists := m.standardFieldMap[standardFieldName] + if !exists { + return nil, fmt.Errorf("%w: standard field %s not found", ErrFieldNotFound, standardFieldName) + } + return field, nil +} + +// GetField gets the field mapping by raw field name +func (m *DefaultVectorDBMapper) GetField(rawFieldName string) (*config.FieldMapping, error) { + field, exists := m.rawFieldMap[rawFieldName] + if !exists { + return nil, fmt.Errorf("%w: raw field %s not found", ErrFieldNotFound, rawFieldName) + } + return field, nil +} + +// GetFieldMappings gets all field mappings +func (m *DefaultVectorDBMapper) GetFieldMappings() ([]config.FieldMapping, error) { + return m.mappingConfig.Fields, nil +} diff --git a/plugins/golang-filter/mcp-server/servers/rag/vectordb/milvus.go 
b/plugins/golang-filter/mcp-server/servers/rag/vectordb/milvus.go index 7e6e8b198..66c8706f1 100644 --- a/plugins/golang-filter/mcp-server/servers/rag/vectordb/milvus.go +++ b/plugins/golang-filter/mcp-server/servers/rag/vectordb/milvus.go @@ -80,16 +80,17 @@ func (m *milvusProviderInitializer) CreateProvider(cfg *config.VectorDBConfig, d type MilvusProvider struct { client client.Client config *config.VectorDBConfig - Collection string + collection string + mapper VectorDBMapper + dimensions int } // NewMilvusProvider creates a new instance of MilvusProvider -func NewMilvusProvider(cfg *config.VectorDBConfig, dim int) (VectorStoreProvider, error) { +func NewMilvusProvider(cfg *config.VectorDBConfig, dimensions int) (VectorStoreProvider, error) { // Create Milvus client connectParam := client.Config{ Address: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), } - connectParam.DBName = cfg.Database // Add authentication if credentials are provided if cfg.Username != "" && cfg.Password != "" { @@ -102,92 +103,301 @@ func NewMilvusProvider(cfg *config.VectorDBConfig, dim int) (VectorStoreProvider return nil, fmt.Errorf("failed to create milvus client: %w", err) } + mapper, err := NewDefaultVectorDBMapper(MILVUS_PROVIDER_TYPE, cfg.Mapping) + if err != nil { + return nil, fmt.Errorf("failed to create default vector db mapper: %w", err) + } + provider := &MilvusProvider{ client: milvusClient, config: cfg, - Collection: cfg.Collection, + collection: cfg.Collection, + mapper: mapper, + dimensions: dimensions, } - ctx := context.Background() - if err := provider.CreateCollection(ctx, dim); err != nil { + if err := provider.CreateCollection(ctx, dimensions); err != nil { return nil, err } return provider, nil } +func (m *MilvusProvider) buildSchema() (*entity.Schema, error) { + // Create Milvus collection Schema + idField, _ := m.mapper.GetIDField() + isIDAuto := idField.IsAutoID() + schema := entity.NewSchema(). + WithName(m.collection). 
+ WithDescription("Knowledge document collection"). + WithAutoID(isIDAuto). + WithDynamicFieldEnabled(false) + // Add fields + var fieldEntity *entity.Field + fieldMappings, _ := m.mapper.GetFieldMappings() + for _, field := range fieldMappings { + fieldEntity = nil + maxLength := field.MaxLength() + switch field.StandardName { + case "id": + isIDAuto := field.IsAutoID() + fieldEntity = entity.NewField(). + WithName(field.RawName). + WithDataType(entity.FieldTypeVarChar). + WithMaxLength(int64(maxLength)). + WithIsPrimaryKey(true) + if isIDAuto { + fieldEntity.WithIsAutoID(true) + } + schema.WithField(fieldEntity) + case "content": + fieldEntity = entity.NewField(). + WithName(field.RawName). + WithDataType(entity.FieldTypeVarChar). + WithMaxLength(int64(maxLength)) + schema.WithField(fieldEntity) + case "vector": + fieldEntity = entity.NewField(). + WithName(field.RawName). + WithDataType(entity.FieldTypeFloatVector). + WithDim(int64(m.dimensions)) + schema.WithField(fieldEntity) + case "metadata": + fieldEntity = entity.NewField(). + WithName(field.RawName). + WithDataType(entity.FieldTypeJSON) + schema.WithField(fieldEntity) + case "created_at": + fieldEntity = entity.NewField(). + WithName(field.RawName). 
+ WithDataType(entity.FieldTypeInt64) + schema.WithField(fieldEntity) + } + } + return schema, nil +} + +func (m *MilvusProvider) GetMetricType(metricType string) entity.MetricType { + switch strings.ToUpper(metricType) { + case "L2": + return entity.L2 + case "IP": + return entity.IP + case "COSINE": + return entity.COSINE + case "HAMMING": + return entity.HAMMING + case "JACCARD": + return entity.JACCARD + case "TANIMOTO": + return entity.TANIMOTO + case "SUBSTRUCTURE": + return entity.SUBSTRUCTURE + case "SUPERSTRUCTURE": + return entity.SUPERSTRUCTURE + default: + return entity.IP + } +} + +func (m *MilvusProvider) buildVectorIndex() (entity.Index, error) { + // Map index type + indexConfig, _ := m.mapper.GetIndexConfig() + searchConfig, _ := m.mapper.GetSearchConfig() + // Map index parameters + milvusIndexType := strings.ToUpper(indexConfig.IndexType) + if milvusIndexType == "" { + milvusIndexType = "HNSW" + } + metricType := m.GetMetricType(searchConfig.MetricType) + switch milvusIndexType { + case "FLAT": + // FLAT index doesn't need additional parameters + index, err := entity.NewIndexFlat(metricType) + if err != nil { + return nil, fmt.Errorf("failed to create FLAT index: %w", err) + } + return index, nil + + case "BIN_FLAT": + // BIN_FLAT index doesn't need additional parameters + nlist := 128 + if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil { + nlist = int(nlistVal) + } + index, err := entity.NewIndexBinFlat(metricType, nlist) + if err != nil { + return nil, fmt.Errorf("failed to create BIN_FLAT index: %w", err) + } + return index, nil + + case "IVF_FLAT": + // Default parameters + nlist := 128 + if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil { + nlist = int(nlistVal) + } + index, err := entity.NewIndexIvfFlat(metricType, nlist) + if err != nil { + return nil, fmt.Errorf("failed to create IVF_FLAT index: %w", err) + } + return index, nil + + case "BIN_IVF_FLAT": + // Default parameters + nlist := 128 + if nlistVal, 
err := indexConfig.ParamsInt64("nlist"); err == nil { + nlist = int(nlistVal) + } + index, err := entity.NewIndexBinIvfFlat(metricType, nlist) + if err != nil { + return nil, fmt.Errorf("failed to create BIN_IVF_FLAT index: %w", err) + } + return index, nil + + case "IVF_SQ8": + // Default parameters + nlist := 128 + if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil { + nlist = int(nlistVal) + } + index, err := entity.NewIndexIvfSQ8(metricType, nlist) + if err != nil { + return nil, fmt.Errorf("failed to create IVF_SQ8 index: %w", err) + } + return index, nil + + case "IVF_PQ": + // Default parameters + nlist := 128 + m := 4 + nbits := 8 + + if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil { + nlist = int(nlistVal) + } + if mVal, err := indexConfig.ParamsFloat64("m"); err == nil { + m = int(mVal) + } + if nbitsVal, err := indexConfig.ParamsInt64("nbits"); err == nil { + nbits = int(nbitsVal) + } + + index, err := entity.NewIndexIvfPQ(metricType, nlist, m, nbits) + if err != nil { + return nil, fmt.Errorf("failed to create IVF_PQ index: %w", err) + } + return index, nil + + case "HNSW": + // Default parameters + m := 8 + efConstruction := 64 + if mVal, err := indexConfig.ParamsInt64("M"); err == nil { + m = int(mVal) + } + if efConstructionVal, err := indexConfig.ParamsInt64("efConstruction"); err == nil { + efConstruction = int(efConstructionVal) + } + index, err := entity.NewIndexHNSW(metricType, m, efConstruction) + if err != nil { + return nil, fmt.Errorf("failed to create HNSW index: %w", err) + } + return index, nil + + case "IVF_HNSW": + // Default parameters + nlist := 128 + m := 8 + efConstruction := 64 + + if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil { + nlist = int(nlistVal) + } + if mVal, err := indexConfig.ParamsInt64("M"); err == nil { + m = int(mVal) + } + + if efConstructionVal, err := indexConfig.ParamsInt64("efConstruction"); err == nil { + efConstruction = int(efConstructionVal) + } + + index, err 
:= entity.NewIndexIvfHNSW(metricType, nlist, m, efConstruction) + if err != nil { + return nil, fmt.Errorf("failed to create IVF_HNSW index: %w", err) + } + return index, nil + + case "DISKANN": + // DISKANN index parameters + index, err := entity.NewIndexDISKANN(metricType) + if err != nil { + return nil, fmt.Errorf("failed to create DISKANN index: %w", err) + } + return index, nil + + case "SCANN": + // SCANN index parameters + nlist := 128 + with_raw_data := false + if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil { + nlist = int(nlistVal) + } + if with_raw_dataVal, err := indexConfig.ParamsBool("with_raw_data"); err == nil { + with_raw_data = with_raw_dataVal + } + index, err := entity.NewIndexSCANN(metricType, nlist, with_raw_data) + if err != nil { + return nil, fmt.Errorf("failed to create SCANN index: %w", err) + } + return index, nil + + case "AUTOINDEX": + // Auto index + index, err := entity.NewIndexAUTOINDEX(metricType) + if err != nil { + return nil, fmt.Errorf("failed to create AUTOINDEX index: %w", err) + } + return index, nil + + default: + return nil, fmt.Errorf("unsupported index type: %s", milvusIndexType) + } +} + // CreateCollection creates a new collection with the specified dimension func (m *MilvusProvider) CreateCollection(ctx context.Context, dim int) error { // Check if collection exists - document_exists, err := m.client.HasCollection(ctx, m.Collection) + document_exists, err := m.client.HasCollection(ctx, m.collection) if err != nil { - return fmt.Errorf("failed to check %s collection existence: %w", m.Collection, err) + return fmt.Errorf("failed to check %s collection existence: %w", m.collection, err) } if !document_exists { - fmt.Printf("create collection %s\n", m.Collection) + fmt.Printf("create collection %s\n", m.collection) // Create schema - schema := entity.NewSchema(). - WithName(m.Collection). - WithDescription("Knowledge document collection"). - WithAutoID(false). 
- WithDynamicFieldEnabled(false) - - // Add fields based on schema.Document structure - // Primary key field - ID - pkField := entity.NewField(). - WithName("id"). - WithDataType(entity.FieldTypeVarChar). - WithMaxLength(256). - WithIsPrimaryKey(true). - WithIsAutoID(false) - schema.WithField(pkField) - - // Content field - contentField := entity.NewField(). - WithName("content"). - WithDataType(entity.FieldTypeVarChar). - WithMaxLength(8192) - schema.WithField(contentField) - - // Vector field - vectorField := entity.NewField(). - WithName("vector"). - WithDataType(entity.FieldTypeFloatVector). - WithDim(int64(dim)) - schema.WithField(vectorField) - - // Metadata field - metadataField := entity.NewField(). - WithName("metadata"). - WithDataType(entity.FieldTypeJSON) - schema.WithField(metadataField) - - // CreatedAt field (stored as Unix timestamp) - createdAtField := entity.NewField(). - WithName("created_at"). - WithDataType(entity.FieldTypeInt64) - schema.WithField(createdAtField) - + schema, err := m.buildSchema() + if err != nil { + return fmt.Errorf("failed to build schema: %w", err) + } // Create collection err = m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber) if err != nil { return fmt.Errorf("failed to create collection: %w", err) } - // Create vector index - vectorIndex, err := entity.NewIndexHNSW(entity.IP, 8, 64) + vectorIndex, err := m.buildVectorIndex() + vectorField, _ := m.mapper.GetVectorField() if err != nil { return fmt.Errorf("failed to create vector index: %w", err) } - err = m.client.CreateIndex(ctx, m.Collection, "vector", vectorIndex, false, client.WithIndexName("vector_index")) + err = m.client.CreateIndex(ctx, m.collection, vectorField.RawName, vectorIndex, false, client.WithIndexName("vector_index")) if err != nil { return fmt.Errorf("failed to create vector index: %w", err) } } - // Load collection - err = m.client.LoadCollection(ctx, m.Collection, false) + err = m.client.LoadCollection(ctx, m.collection, false) if err 
!= nil { return fmt.Errorf("failed to load document collection: %w", err) } @@ -197,15 +407,15 @@ func (m *MilvusProvider) CreateCollection(ctx context.Context, dim int) error { // DropCollection removes the collection from the database func (m *MilvusProvider) DropCollection(ctx context.Context) error { // Check if collection exists - exists, err := m.client.HasCollection(ctx, m.Collection) + exists, err := m.client.HasCollection(ctx, m.collection) if err != nil { - return fmt.Errorf("failed to check %s collection existence: %w", m.Collection, err) + return fmt.Errorf("failed to check %s collection existence: %w", m.collection, err) } if !exists { - return fmt.Errorf("collection %s does not exist", m.Collection) + return fmt.Errorf("collection %s does not exist", m.collection) } // Drop collection - err = m.client.DropCollection(ctx, m.Collection) + err = m.client.DropCollection(ctx, m.collection) if err != nil { return fmt.Errorf("failed to drop collection: %w", err) } @@ -217,51 +427,71 @@ func (m *MilvusProvider) AddDoc(ctx context.Context, docs []schema.Document) err if len(docs) == 0 { return nil } - // Prepare data - ids := make([]string, len(docs)) - contents := make([]string, len(docs)) - vectors := make([][]float32, len(docs)) - metadatas := make([][]byte, len(docs)) - createdAts := make([]int64, len(docs)) - for i, doc := range docs { - ids[i] = doc.ID - contents[i] = doc.Content - - // Convert vector type - vectorFloat32 := make([]float32, len(doc.Vector)) - for j, v := range doc.Vector { - vectorFloat32[j] = float32(v) - } - vectors[i] = vectorFloat32 - - // Serialize metadata - metadataBytes, err := json.Marshal(doc.Metadata) - if err != nil { - return fmt.Errorf("failed to marshal metadata for doc %s: %w", doc.ID, err) - } - metadatas[i] = metadataBytes - - createdAts[i] = doc.CreatedAt.UnixMilli() + // Get field mappings + fieldMappings, err := m.mapper.GetFieldMappings() + if err != nil { + return fmt.Errorf("failed to get field mappings: %w", err) 
} + // Prepare data and columns + columns := make([]entity.Column, 0, len(fieldMappings)) + // Create corresponding column data for each field + for _, field := range fieldMappings { + // Skip ID field if configured as auto ID + if field.IsPrimaryKey() && field.IsAutoID() { + continue + } + switch field.StandardName { + case "id": + // Handle string type fields + values := make([]string, len(docs)) + for i, doc := range docs { + values[i] = doc.ID + } + columns = append(columns, entity.NewColumnVarChar(field.RawName, values)) + case "content": + values := make([]string, len(docs)) + for i, doc := range docs { + values[i] = doc.Content + } + columns = append(columns, entity.NewColumnVarChar(field.RawName, values)) - // Build insert data - columns := []entity.Column{ - entity.NewColumnVarChar("id", ids), - entity.NewColumnVarChar("content", contents), - entity.NewColumnFloatVector("vector", len(vectors[0]), vectors), - entity.NewColumnJSONBytes("metadata", metadatas), - entity.NewColumnInt64("created_at", createdAts), + case "vector": + // Handle vector fields + vectors := make([][]float32, len(docs)) + for i, doc := range docs { + vectors[i] = doc.Vector + } + columns = append(columns, entity.NewColumnFloatVector(field.RawName, len(vectors[0]), vectors)) + case "metadata": + // Handle JSON type fields (like metadata) + values := make([][]byte, len(docs)) + for i, doc := range docs { + // Serialize metadata + metadataBytes, err := json.Marshal(doc.Metadata) + if err != nil { + return fmt.Errorf("failed to marshal metadata for doc %s: %w", doc.ID, err) + } + values[i] = metadataBytes + } + columns = append(columns, entity.NewColumnJSONBytes(field.RawName, values)) + case "created_at": + // Handle integer type fields + values := make([]int64, len(docs)) + for i, doc := range docs { + values[i] = doc.CreatedAt.UnixMilli() + } + columns = append(columns, entity.NewColumnInt64(field.RawName, values)) + } } - // Insert data - _, err := m.client.Insert(ctx, m.Collection, 
"", columns...) + _, err = m.client.Insert(ctx, m.collection, "", columns...) if err != nil { return fmt.Errorf("failed to insert documents: %w", err) } // Flush data - err = m.client.Flush(ctx, m.Collection, false) + err = m.client.Flush(ctx, m.collection, false) if err != nil { return fmt.Errorf("failed to flush collection: %w", err) } @@ -271,16 +501,19 @@ func (m *MilvusProvider) AddDoc(ctx context.Context, docs []schema.Document) err // DeleteDoc deletes a document by its ID func (m *MilvusProvider) DeleteDoc(ctx context.Context, id string) error { - // Build delete expression - expr := fmt.Sprintf(`id == "%s"`, id) + // Get ID field + idField, _ := m.mapper.GetIDField() + // Build delete expression using the RawName of ID field + expr := fmt.Sprintf(`%s == "%s"`, idField.RawName, id) + // Delete data - err := m.client.Delete(ctx, m.Collection, "", expr) + err := m.client.Delete(ctx, m.collection, "", expr) if err != nil { return fmt.Errorf("failed to delete documents for id %s: %w", id, err) } // Flush data - err = m.client.Flush(ctx, m.Collection, false) + err = m.client.Flush(ctx, m.collection, false) if err != nil { return fmt.Errorf("failed to flush collection after delete: %w", err) } @@ -306,24 +539,127 @@ func (m *MilvusProvider) UpdateDoc(ctx context.Context, docs []schema.Document) return nil } +func (m *MilvusProvider) buildSearchParam() (entity.SearchParam, error) { + // Get index configuration + indexConfig, err := m.mapper.GetIndexConfig() + if err != nil { + return nil, fmt.Errorf("failed to get index config: %w", err) + } + + // Get search configuration + searchConfig, err := m.mapper.GetSearchConfig() + if err != nil { + return nil, fmt.Errorf("failed to get search config: %w", err) + } + + // Choose appropriate search parameters based on index type + milvusIndexType := strings.ToUpper(indexConfig.IndexType) + if milvusIndexType == "" { + milvusIndexType = "HNSW" // Default to HNSW index + } + + switch milvusIndexType { + case "FLAT": + // 
FLAT and BIN_FLAT indices don't need additional search parameters + return entity.NewIndexFlatSearchParam() + + case "BIN_FLAT", "IVF_FLAT", "BIN_IVF_FLAT", "IVF_SQ8": + // Search parameters for IVF series indices + nprobe := 16 // Default value + if nprobeVal, err := searchConfig.ParamsFloat64("nprobe"); err == nil { + nprobe = int(nprobeVal) + } + return entity.NewIndexIvfFlatSearchParam(nprobe) + + case "IVF_PQ": + // Search parameters for IVF_PQ index + nprobe := 16 // Default value + if nprobeVal, err := searchConfig.ParamsFloat64("nprobe"); err == nil { + nprobe = int(nprobeVal) + } + return entity.NewIndexIvfPQSearchParam(nprobe) + + case "HNSW": + // Search parameters for HNSW index + efSearch := 16 // Default value + if efSearchVal, err := searchConfig.ParamsFloat64("ef"); err == nil { + efSearch = int(efSearchVal) + } + return entity.NewIndexHNSWSearchParam(efSearch) + + case "IVF_HNSW": + // Search parameters for IVF_HNSW index + nprobe := 16 // Default value + efSearch := 64 // Default value + if nprobeVal, err := searchConfig.ParamsFloat64("nprobe"); err == nil { + nprobe = int(nprobeVal) + } + if efSearchVal, err := searchConfig.ParamsFloat64("ef"); err == nil { + efSearch = int(efSearchVal) + } + return entity.NewIndexIvfHNSWSearchParam(nprobe, efSearch) + + case "SCANN": + // Search parameters for SCANN index + nprobe := 16 // Default value + reorder_k := 64 + if nprobeVal, err := searchConfig.ParamsFloat64("nprobe"); err == nil { + nprobe = int(nprobeVal) + } + if reorderKVal, err := searchConfig.ParamsInt64("reorder_k"); err == nil { + reorder_k = int(reorderKVal) + } + return entity.NewIndexSCANNSearchParam(nprobe, reorder_k) + + case "DISKANN": + // Search parameters for DISKANN index + search_list := 100 // Default value + if searchListVal, err := searchConfig.ParamsInt64("search_list"); err == nil { + search_list = int(searchListVal) + } + return entity.NewIndexDISKANNSearchParam(search_list) + + case "AUTOINDEX": + level := 8 + if levelVal, 
err := searchConfig.ParamsInt64("level"); err == nil { + level = int(levelVal) + } + // Search parameters for AUTOINDEX index + return entity.NewIndexAUTOINDEXSearchParam(level) + default: + // Default to using HNSW search parameters + return entity.NewIndexHNSWSearchParam(16) + } +} + // SearchDocs performs similarity search for documents func (m *MilvusProvider) SearchDocs(ctx context.Context, vector []float32, options *schema.SearchOptions) ([]schema.SearchResult, error) { if options == nil { options = &schema.SearchOptions{TopK: 10} } + // Build search parameters - sp, _ := entity.NewIndexHNSWSearchParam(16) + sp, err := m.buildSearchParam() + if err != nil { + return nil, fmt.Errorf("failed to build search param: %w", err) + } + + outputFields, _ := m.mapper.GetRawAllFieldNames() + vectorField, _ := m.mapper.GetVectorField() + searchConfig, _ := m.mapper.GetSearchConfig() + metricType := m.GetMetricType(searchConfig.MetricType) + // Build filter expression expr := "" searchResults, err := m.client.Search( ctx, - m.Collection, - []string{}, // partition names - expr, // filter expression - []string{"id", "content", "metadata", "created_at"}, // output fields + m.collection, + []string{}, // partition names + expr, // filter expression + outputFields, // output fields []entity.Vector{entity.FloatVector(vector)}, - "vector", // anns_field - entity.IP, // metric_type + vectorField.RawName, // anns_field + metricType, // metric_type options.TopK, sp, ) @@ -341,9 +677,13 @@ func (m *MilvusProvider) SearchDocs(ctx context.Context, vector []float32, optio // Get field data var content string var metadata map[string]interface{} - for _, field := range result.Fields { - switch field.Name() { + fieldMapping, err := m.mapper.GetField(field.Name()) + if err != nil { + continue + } + fieldName := strings.ToLower(fieldMapping.StandardName) + switch fieldName { case "content": if contentCol, ok := field.(*entity.ColumnVarChar); ok { if contentVal, err := contentCol.Get(i); 
err == nil { @@ -364,7 +704,6 @@ func (m *MilvusProvider) SearchDocs(ctx context.Context, vector []float32, optio } } } - searchResult := schema.SearchResult{ Document: schema.Document{ ID: fmt.Sprintf("%s", id), @@ -392,15 +731,17 @@ func (m *MilvusProvider) DeleteDocs(ctx context.Context, ids []string) error { for i, id := range ids { quotedIDs[i] = fmt.Sprintf("\"%s\"", id) } - expr := fmt.Sprintf("id in [%s]", strings.Join(quotedIDs, ",")) + + idField, _ := m.mapper.GetIDField() + expr := fmt.Sprintf("%s in [%s]", idField.RawName, strings.Join(quotedIDs, ",")) // Delete data - err := m.client.Delete(ctx, m.Collection, "", expr) + err := m.client.Delete(ctx, m.collection, "", expr) if err != nil { return fmt.Errorf("failed to delete documents: %w", err) } // Flush data - err = m.client.Flush(ctx, m.Collection, false) + err = m.client.Flush(ctx, m.collection, false) if err != nil { return fmt.Errorf("failed to flush collection after delete: %w", err) } @@ -413,12 +754,13 @@ func (m *MilvusProvider) ListDocs(ctx context.Context, limit int) ([]schema.Docu // Build query expression expr := "" // Query all relevant documents + outputFields, _ := m.mapper.GetRawAllFieldNames() queryResult, err := m.client.Query( ctx, - m.Collection, + m.collection, []string{}, // partitions expr, // filter condition - []string{"id", "content", "metadata", "created_at"}, + outputFields, client.WithOffset(0), client.WithLimit(int64(limit)), ) @@ -443,7 +785,12 @@ func (m *MilvusProvider) ListDocs(ctx context.Context, limit int) ([]schema.Docu ) for _, col := range queryResult { - switch col.Name() { + fieldMapping, err := m.mapper.GetField(col.Name()) + if err != nil { + continue + } + fieldName := strings.ToLower(fieldMapping.StandardName) + switch fieldName { case "id": if v, err := col.(*entity.ColumnVarChar).Get(i); err == nil { id = v.(string) @@ -488,8 +835,3 @@ func (m *MilvusProvider) Close() error { } return nil } - -// joinStrings joins a slice of strings with the given 
separator -func joinStrings(elems []string, sep string) string { - return strings.Join(elems, sep) -}