mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 04:37:31 +08:00
add vectordb mapping (#2968)
This commit is contained in:
@@ -84,10 +84,11 @@ Higress RAG MCP Server 提供以下工具,根据配置不同,可用工具也
|
||||
| llm.max_tokens | integer | 可选 | 2048 | 最大令牌数 |
|
||||
| llm.temperature | float | 可选 | 0.5 | 温度参数 |
|
||||
| **embedding** | object | 必填 | - | 嵌入配置(所有工具必需) |
|
||||
| embedding.provider | string | 必填 | dashscope | 嵌入提供商:openai或dashscope |
|
||||
| embedding.provider | string | 必填 | openai | 嵌入提供商:支持openai协议的任意供应商 |
|
||||
| embedding.api_key | string | 必填 | - | 嵌入API密钥 |
|
||||
| embedding.base_url | string | 可选 | | 嵌入API基础URL |
|
||||
| embedding.model | string | 必填 | text-embedding-v4 | 嵌入模型名称 |
|
||||
| embedding.model | string | 必填 | text-embedding-ada-002 | 嵌入模型名称 |
|
||||
| embedding.dimensions | integer | 可选 | 1536 | 嵌入维度 |
|
||||
| **vectordb** | object | 必填 | - | 向量数据库配置(所有工具必需) |
|
||||
| vectordb.provider | string | 必填 | milvus | 向量数据库提供商 |
|
||||
| vectordb.host | string | 必填 | localhost | 数据库主机地址 |
|
||||
@@ -96,6 +97,17 @@ Higress RAG MCP Server 提供以下工具,根据配置不同,可用工具也
|
||||
| vectordb.collection | string | 必填 | test_collection | 集合名称 |
|
||||
| vectordb.username | string | 可选 | - | 数据库用户名 |
|
||||
| vectordb.password | string | 可选 | - | 数据库密码 |
|
||||
| **vectordb.mapping** | object | 可选 | - | 字段映射配置 |
|
||||
| vectordb.mapping.fields | array | 可选 | - | 字段映射列表 |
|
||||
| vectordb.mapping.fields[].standard_name | string | 必填 | - | 标准字段名称(如 id, content, vector 等) |
|
||||
| vectordb.mapping.fields[].raw_name | string | 必填 | - | 原始字段名称(数据库中的实际字段名) |
|
||||
| vectordb.mapping.fields[].properties | object | 可选 | - | 字段属性(如 auto_id, max_length 等) |
|
||||
| vectordb.mapping.index | object | 可选 | - | 索引配置 |
|
||||
| vectordb.mapping.index.index_type | string | 必填 | - | 索引类型(如 FLAT, IVF_FLAT, HNSW 等) |
|
||||
| vectordb.mapping.index.params | object | 可选 | - | 索引参数(根据索引类型不同而异) |
|
||||
| vectordb.mapping.search | object | 可选 | - | 搜索配置 |
|
||||
| vectordb.mapping.search.metric_type | string | 可选 | L2 | 度量类型(如 L2, IP, COSINE 等) |
|
||||
| vectordb.mapping.search.params | object | 可选 | - | 搜索参数(如 nprobe, ef_search 等)
|
||||
|
||||
|
||||
### higress-config 配置样例
|
||||
@@ -143,27 +155,54 @@ data:
|
||||
temperature: 0.5
|
||||
max_tokens: 2048
|
||||
embedding:
|
||||
provider: dashscope
|
||||
provider: openai
|
||||
base_url: https://dashscope.aliyuncs.com/compatible-mode/v1
|
||||
api_key: sk-xxx
|
||||
model: text-embedding-v4
|
||||
dimensions: 1536
|
||||
vectordb:
|
||||
provider: milvus
|
||||
host: <milvus IP>
|
||||
host: localhost
|
||||
port: 19530
|
||||
database: default
|
||||
collection: test_collection
|
||||
```
|
||||
collection: test_rag
|
||||
mapping:
|
||||
fields:
|
||||
- standard_name: id
|
||||
raw_name: id
|
||||
properties:
|
||||
auto_id: false
|
||||
max_length: 256
|
||||
- standard_name: content
|
||||
raw_name: content
|
||||
properties:
|
||||
max_length: 8192
|
||||
- standard_name: vector
|
||||
raw_name: vector
|
||||
- standard_name: metadata
|
||||
raw_name: metadata
|
||||
- standard_name: created_at
|
||||
raw_name: created_at
|
||||
index:
|
||||
index_type: HNSW
|
||||
params:
|
||||
M: 4
|
||||
efConstruction: 32
|
||||
search:
|
||||
metric_type: IP
|
||||
params:
|
||||
ef: 32
|
||||
|
||||
```
|
||||
### 支持的提供商
|
||||
#### Embedding
|
||||
- **OpenAI**
|
||||
- **DashScope**
|
||||
- **OpenAI 兼容**
|
||||
|
||||
#### Vector Database
|
||||
- **Milvus**
|
||||
|
||||
#### LLM
|
||||
- **OpenAI**
|
||||
- **OpenAI 兼容**
|
||||
|
||||
## 如何测试数据集的效果
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package config
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Config represents the main configuration structure for the MCP server
|
||||
type Config struct {
|
||||
RAG RAGConfig `json:"rag" yaml:"rag"`
|
||||
@@ -34,20 +36,148 @@ type LLMConfig struct {
|
||||
|
||||
// EmbeddingConfig defines configuration for embedding models
|
||||
type EmbeddingConfig struct {
|
||||
Provider string `json:"provider" yaml:"provider"` // Available options: openai, dashscope
|
||||
APIKey string `json:"api_key,omitempty" yaml:"api_key,omitempty"`
|
||||
BaseURL string `json:"base_url,omitempty" yaml:"base_url,omitempty"`
|
||||
Model string `json:"model,omitempty" yaml:"model,omitempty"`
|
||||
Dimension int `json:"dimension,omitempty" yaml:"dimension,omitempty"`
|
||||
Provider string `json:"provider" yaml:"provider"` // Available options: openai, dashscope
|
||||
APIKey string `json:"api_key,omitempty" yaml:"api_key,omitempty"`
|
||||
BaseURL string `json:"base_url,omitempty" yaml:"base_url,omitempty"`
|
||||
Model string `json:"model,omitempty" yaml:"model,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty" yaml:"dimension,omitempty"`
|
||||
}
|
||||
|
||||
// VectorDBConfig defines configuration for vector databases
|
||||
type VectorDBConfig struct {
|
||||
Provider string `json:"provider" yaml:"provider"` // Available options: milvus, qdrant, chroma
|
||||
Host string `json:"host,omitempty" yaml:"host,omitempty"`
|
||||
Port int `json:"port,omitempty" yaml:"port,omitempty"`
|
||||
Database string `json:"database,omitempty" yaml:"database,omitempty"`
|
||||
Collection string `json:"collection,omitempty" yaml:"collection,omitempty"`
|
||||
Username string `json:"username,omitempty" yaml:"username,omitempty"`
|
||||
Password string `json:"password,omitempty" yaml:"password,omitempty"`
|
||||
Provider string `json:"provider" yaml:"provider"` // Available options: milvus, qdrant, chroma
|
||||
Host string `json:"host,omitempty" yaml:"host,omitempty"`
|
||||
Port int `json:"port,omitempty" yaml:"port,omitempty"`
|
||||
Database string `json:"database,omitempty" yaml:"database,omitempty"`
|
||||
Collection string `json:"collection,omitempty" yaml:"collection,omitempty"`
|
||||
Username string `json:"username,omitempty" yaml:"username,omitempty"`
|
||||
Password string `json:"password,omitempty" yaml:"password,omitempty"`
|
||||
Mapping MappingConfig `json:"mapping,omitempty" yaml:"mapping,omitempty"`
|
||||
}
|
||||
|
||||
// MappingConfig defines field mapping configuration for vector databases
|
||||
type MappingConfig struct {
|
||||
Fields []FieldMapping `json:"fields,omitempty" yaml:"fields,omitempty"`
|
||||
Index IndexConfig `json:"index,omitempty" yaml:"index,omitempty"`
|
||||
Search SearchConfig `json:"search,omitempty" yaml:"search,omitempty"`
|
||||
}
|
||||
|
||||
// // CollectionMapping defines field mapping for collection
|
||||
// type CollectionMapping struct {
|
||||
// Fields []FieldMapping `json:"fields,omitempty" yaml:"fields,omitempty"`
|
||||
// }
|
||||
|
||||
// FieldMapping defines mapping for a single field
|
||||
type FieldMapping struct {
|
||||
StandardName string `json:"standard_name" yaml:"standard_name"`
|
||||
RawName string `json:"raw_name" yaml:"raw_name"`
|
||||
Properties map[string]interface{} `json:"properties,omitempty" yaml:"properties,omitempty"`
|
||||
}
|
||||
|
||||
func (f FieldMapping) IsPrimaryKey() bool {
|
||||
return f.StandardName == "id"
|
||||
}
|
||||
|
||||
func (f FieldMapping) IsAutoID() bool {
|
||||
if f.Properties == nil {
|
||||
return false
|
||||
}
|
||||
autoID, ok := f.Properties["auto_id"].(bool)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return autoID
|
||||
}
|
||||
|
||||
func (f FieldMapping) IsVectorField() bool {
|
||||
return f.StandardName == "vector"
|
||||
}
|
||||
|
||||
func (f FieldMapping) MaxLength() int {
|
||||
if f.Properties == nil {
|
||||
return 0
|
||||
}
|
||||
maxLength, ok := f.Properties["max_length"].(int)
|
||||
if !ok {
|
||||
return 256
|
||||
}
|
||||
return maxLength
|
||||
}
|
||||
|
||||
// IndexConfig defines configuration for index parameters
|
||||
type IndexConfig struct {
|
||||
// Index type, e.g., IVF_FLAT, IVF_SQ8, HNSW, etc.
|
||||
IndexType string `json:"index_type" yaml:"index_type"`
|
||||
// Index parameter configuration
|
||||
Params map[string]interface{} `json:"params" yaml:"params"`
|
||||
}
|
||||
|
||||
func (i IndexConfig) ParamsString(key string) (string, error) {
|
||||
if mVal, ok := i.Params[key].(string); ok {
|
||||
return mVal, nil
|
||||
}
|
||||
return "", fmt.Errorf("params %s not found", key)
|
||||
}
|
||||
|
||||
func (i IndexConfig) ParamsInt64(key string) (int64, error) {
|
||||
if mVal, ok := i.Params[key].(int64); ok {
|
||||
return mVal, nil
|
||||
}
|
||||
if mVal, ok := i.Params[key].(int); ok {
|
||||
return int64(mVal), nil
|
||||
}
|
||||
return 0, fmt.Errorf("params %s not found", key)
|
||||
}
|
||||
|
||||
func (i IndexConfig) ParamsFloat64(key string) (float64, error) {
|
||||
if mVal, ok := i.Params[key].(float64); ok {
|
||||
return mVal, nil
|
||||
}
|
||||
if mVal, ok := i.Params[key].(float32); ok {
|
||||
return float64(mVal), nil
|
||||
}
|
||||
return 0, fmt.Errorf("params %s not found", key)
|
||||
}
|
||||
|
||||
func (i IndexConfig) ParamsBool(key string) (bool, error) {
|
||||
if mVal, ok := i.Params[key].(bool); ok {
|
||||
return mVal, nil
|
||||
}
|
||||
return false, fmt.Errorf("params %s not found", key)
|
||||
}
|
||||
|
||||
// SearchConfig defines configuration for search parameters
|
||||
type SearchConfig struct {
|
||||
// Metric type, e.g., L2, IP, etc.
|
||||
MetricType string `json:"metric_type,omitempty" yaml:"metric_type,omitempty"`
|
||||
// Search parameter configuration
|
||||
Params map[string]interface{} `json:"params" yaml:"params"`
|
||||
}
|
||||
|
||||
func (i SearchConfig) ParamsString(key string) (string, error) {
|
||||
if mVal, ok := i.Params[key].(string); ok {
|
||||
return mVal, nil
|
||||
}
|
||||
return "", fmt.Errorf("params %s not found", key)
|
||||
}
|
||||
|
||||
func (i SearchConfig) ParamsInt64(key string) (int64, error) {
|
||||
if mVal, ok := i.Params[key].(int64); ok {
|
||||
return mVal, nil
|
||||
}
|
||||
return 0, fmt.Errorf("params %s not found", key)
|
||||
}
|
||||
|
||||
func (i SearchConfig) ParamsFloat64(key string) (float64, error) {
|
||||
if mVal, ok := i.Params[key].(float64); ok {
|
||||
return mVal, nil
|
||||
}
|
||||
return 0, fmt.Errorf("params %s not found", key)
|
||||
}
|
||||
|
||||
func (i SearchConfig) ParamsBool(key string) (bool, error) {
|
||||
if mVal, ok := i.Params[key].(bool); ok {
|
||||
return mVal, nil
|
||||
}
|
||||
return false, fmt.Errorf("params %s not found", key)
|
||||
}
|
||||
|
||||
@@ -1,169 +0,0 @@
|
||||
package embedding
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/common"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
||||
)
|
||||
|
||||
const (
|
||||
DASHSCOPE_DOMAIN = "dashscope.aliyuncs.com"
|
||||
DASHSCOPE_PORT = 443
|
||||
DASHSCOPE_DEFAULT_MODEL_NAME = "text-embedding-v4"
|
||||
DASHSCOPE_ENDPOINT = "/api/v1/services/embeddings/text-embedding/text-embedding"
|
||||
)
|
||||
|
||||
var dashScopeConfig dashScopeProviderConfig
|
||||
|
||||
type dashScopeProviderInitializer struct {
|
||||
}
|
||||
type dashScopeProviderConfig struct {
|
||||
apiKey string
|
||||
model string
|
||||
}
|
||||
|
||||
func (c *dashScopeProviderInitializer) InitConfig(config config.EmbeddingConfig) {
|
||||
dashScopeConfig.apiKey = config.APIKey
|
||||
dashScopeConfig.model = config.Model
|
||||
}
|
||||
|
||||
func (c *dashScopeProviderInitializer) ValidateConfig() error {
|
||||
if dashScopeConfig.apiKey == "" {
|
||||
return errors.New("[DashScope] apiKey is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *dashScopeProviderInitializer) CreateProvider(config config.EmbeddingConfig) (Provider, error) {
|
||||
c.InitConfig(config)
|
||||
err := c.ValidateConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
headers := map[string]string{
|
||||
"Authorization": "Bearer " + config.APIKey,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
httpClient := common.NewHTTPClient(fmt.Sprintf("https://%s", DASHSCOPE_DOMAIN), headers)
|
||||
|
||||
return &DashScopeProvider{
|
||||
config: dashScopeConfig,
|
||||
client: httpClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *DashScopeProvider) GetProviderType() string {
|
||||
return PROVIDER_TYPE_DASHSCOPE
|
||||
}
|
||||
|
||||
type Embedding struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
TextIndex int `json:"text_index"`
|
||||
}
|
||||
|
||||
type Input struct {
|
||||
Texts []string `json:"texts"`
|
||||
}
|
||||
|
||||
type Params struct {
|
||||
TextType string `json:"text_type"`
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
RequestID string `json:"request_id"`
|
||||
Output Output `json:"output"`
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type Output struct {
|
||||
Embeddings []Embedding `json:"embeddings"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input Input `json:"input"`
|
||||
Parameters Params `json:"parameters"`
|
||||
}
|
||||
|
||||
type Document struct {
|
||||
Vector []float64 `json:"vector"`
|
||||
Fields map[string]string `json:"fields"`
|
||||
}
|
||||
|
||||
type DashScopeProvider struct {
|
||||
config dashScopeProviderConfig
|
||||
client *common.HTTPClient
|
||||
}
|
||||
|
||||
func (d *DashScopeProvider) constructRequestData(texts []string) (EmbeddingRequest, error) {
|
||||
model := d.config.model
|
||||
if model == "" {
|
||||
model = DASHSCOPE_DEFAULT_MODEL_NAME
|
||||
}
|
||||
|
||||
if dashScopeConfig.apiKey == "" {
|
||||
return EmbeddingRequest{}, errors.New("dashScopeKey is empty")
|
||||
}
|
||||
|
||||
data := EmbeddingRequest{
|
||||
Model: model,
|
||||
Input: Input{
|
||||
Texts: texts,
|
||||
},
|
||||
Parameters: Params{
|
||||
TextType: "query",
|
||||
},
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
ID string `json:"id"`
|
||||
Vector []float32 `json:"vector,omitempty"`
|
||||
Fields map[string]interface{} `json:"fields"`
|
||||
Score float64 `json:"score"`
|
||||
}
|
||||
|
||||
func (d *DashScopeProvider) parseTextEmbedding(responseBody []byte) (*Response, error) {
|
||||
var resp Response
|
||||
err := json.Unmarshal(responseBody, &resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (d *DashScopeProvider) GetEmbedding(
|
||||
ctx context.Context,
|
||||
queryString string) ([]float32, error) {
|
||||
requestData, err := d.constructRequestData([]string{queryString})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to construct request data: %v", err)
|
||||
}
|
||||
|
||||
responseBody, err := d.client.Post(DASHSCOPE_ENDPOINT, requestData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %v", err)
|
||||
}
|
||||
|
||||
embeddingResp, err := d.parseTextEmbedding(responseBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if len(embeddingResp.Output.Embeddings) == 0 {
|
||||
return nil, errors.New("no embedding found in response")
|
||||
}
|
||||
|
||||
return embeddingResp.Output.Embeddings[0].Embedding, nil
|
||||
}
|
||||
@@ -2,160 +2,93 @@ package embedding
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/common"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
||||
"github.com/openai/openai-go/v2"
|
||||
"github.com/openai/openai-go/v2/option"
|
||||
)
|
||||
|
||||
const (
|
||||
OPENAI_DOMAIN = "api.openai.com"
|
||||
OPENAI_PORT = 443
|
||||
OPENAI_DEFAULT_MODEL_NAME = "text-embedding-3-small"
|
||||
OPENAI_ENDPOINT = "/v1/embeddings"
|
||||
OPENAI_DEFAULT_MODEL_NAME = "text-embedding-ada-002"
|
||||
)
|
||||
|
||||
type openAIProviderInitializer struct {
|
||||
}
|
||||
|
||||
var openAIConfig openAIProviderConfig
|
||||
|
||||
type openAIProviderConfig struct {
|
||||
baseUrl string
|
||||
apiKey string
|
||||
model string
|
||||
}
|
||||
|
||||
func (c *openAIProviderInitializer) InitConfig(config config.EmbeddingConfig) {
|
||||
openAIConfig.apiKey = config.APIKey
|
||||
openAIConfig.model = config.Model
|
||||
openAIConfig.baseUrl = config.BaseURL
|
||||
}
|
||||
|
||||
func (c *openAIProviderInitializer) ValidateConfig() error {
|
||||
if openAIConfig.apiKey == "" {
|
||||
return errors.New("[openAI] apiKey is required")
|
||||
func (c *openAIProviderInitializer) validateConfig(config *config.EmbeddingConfig) error {
|
||||
if config.APIKey == "" {
|
||||
return errors.New("[openai embbeding] apiKey is required")
|
||||
}
|
||||
if config.Model == "" {
|
||||
config.Model = OPENAI_DEFAULT_MODEL_NAME
|
||||
}
|
||||
if config.Dimensions <= 0 {
|
||||
config.Dimensions = 1536
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *openAIProviderInitializer) CreateProvider(config config.EmbeddingConfig) (Provider, error) {
|
||||
c.InitConfig(config)
|
||||
err := c.ValidateConfig()
|
||||
if err != nil {
|
||||
if err := c.validateConfig(&config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 创建 OpenAI 客户端
|
||||
var clientOptions []option.RequestOption
|
||||
clientOptions = append(clientOptions, option.WithAPIKey(config.APIKey))
|
||||
|
||||
if openAIConfig.model == "" {
|
||||
openAIConfig.model = OPENAI_DEFAULT_MODEL_NAME
|
||||
// 如果设置了自定义 baseURL,则使用它
|
||||
if config.BaseURL != "" {
|
||||
clientOptions = append(clientOptions, option.WithBaseURL(config.BaseURL))
|
||||
}
|
||||
|
||||
if openAIConfig.baseUrl == "" {
|
||||
openAIConfig.baseUrl = fmt.Sprintf("https://%s", OPENAI_DOMAIN)
|
||||
}
|
||||
|
||||
headers := map[string]string{
|
||||
"Authorization": "Bearer " + config.APIKey,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
httpClient := common.NewHTTPClient(openAIConfig.baseUrl, headers)
|
||||
// 创建 OpenAI 客户端
|
||||
client := openai.NewClient(clientOptions...)
|
||||
|
||||
return &OpenAIProvider{
|
||||
config: openAIConfig,
|
||||
client: httpClient,
|
||||
client: &client,
|
||||
model: config.Model,
|
||||
dimensions: config.Dimensions,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (o *OpenAIProvider) GetProviderType() string {
|
||||
// EmbeddingClient handles vector embedding generation using OpenAI-compatible APIs
|
||||
type OpenAIProvider struct {
|
||||
client *openai.Client
|
||||
model string
|
||||
dimensions int
|
||||
}
|
||||
|
||||
func (e *OpenAIProvider) GetProviderType() string {
|
||||
return PROVIDER_TYPE_OPENAI
|
||||
}
|
||||
|
||||
type OpenAIResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []OpenAIResult `json:"data"`
|
||||
Model string `json:"model"`
|
||||
Error *OpenAIError `json:"error"`
|
||||
}
|
||||
|
||||
type OpenAIResult struct {
|
||||
Object string `json:"object"`
|
||||
Embedding []float32 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type OpenAIError struct {
|
||||
Message string `json:"prompt_tokens"`
|
||||
Type string `json:"type"`
|
||||
Code string `json:"code"`
|
||||
Param string `json:"param"`
|
||||
}
|
||||
|
||||
type OpenAIEmbeddingRequest struct {
|
||||
Input string `json:"input"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
type OpenAIProvider struct {
|
||||
config openAIProviderConfig
|
||||
client *common.HTTPClient
|
||||
}
|
||||
|
||||
func (o *OpenAIProvider) constructRequestData(text string) (OpenAIEmbeddingRequest, error) {
|
||||
if text == "" {
|
||||
return OpenAIEmbeddingRequest{}, errors.New("queryString text cannot be empty")
|
||||
// GetEmbedding generates vector embedding for the given text
|
||||
func (e *OpenAIProvider) GetEmbedding(ctx context.Context, text string) ([]float32, error) {
|
||||
params := openai.EmbeddingNewParams{
|
||||
Model: e.model,
|
||||
Input: openai.EmbeddingNewParamsInputUnion{
|
||||
OfString: openai.String(text),
|
||||
},
|
||||
Dimensions: openai.Int(int64(e.dimensions)),
|
||||
EncodingFormat: openai.EmbeddingNewParamsEncodingFormatFloat,
|
||||
}
|
||||
|
||||
if openAIConfig.apiKey == "" {
|
||||
return OpenAIEmbeddingRequest{}, errors.New("openAI apiKey is empty")
|
||||
}
|
||||
|
||||
model := o.config.model
|
||||
if model == "" {
|
||||
model = OPENAI_DEFAULT_MODEL_NAME
|
||||
}
|
||||
|
||||
data := OpenAIEmbeddingRequest{
|
||||
Input: text,
|
||||
Model: model,
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (o *OpenAIProvider) parseTextEmbedding(responseBody []byte) (*OpenAIResponse, error) {
|
||||
var resp OpenAIResponse
|
||||
err := json.Unmarshal(responseBody, &resp)
|
||||
embeddingResp, err := e.client.Embeddings.New(ctx, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to generate embedding: %w", err)
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (o *OpenAIProvider) GetEmbedding(ctx context.Context, queryString string) ([]float32, error) {
|
||||
requestData, err := o.constructRequestData(queryString)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to construct request data: %v", err)
|
||||
}
|
||||
|
||||
responseBody, err := o.client.Post(OPENAI_ENDPOINT, requestData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %v", err)
|
||||
}
|
||||
|
||||
resp, err := o.parseTextEmbedding(responseBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Error != nil {
|
||||
return nil, fmt.Errorf("OpenAI API error: %s - %s", resp.Error.Type, resp.Error.Message)
|
||||
}
|
||||
|
||||
if len(resp.Data) == 0 {
|
||||
return nil, errors.New("no embedding found in response")
|
||||
}
|
||||
|
||||
return resp.Data[0].Embedding, nil
|
||||
|
||||
if len(embeddingResp.Data) == 0 {
|
||||
return nil, fmt.Errorf("empty embedding response")
|
||||
}
|
||||
|
||||
// Convert []float64 to []float32
|
||||
embedding := make([]float32, len(embeddingResp.Data[0].Embedding))
|
||||
for i, v := range embeddingResp.Data[0].Embedding {
|
||||
embedding[i] = float32(v)
|
||||
}
|
||||
|
||||
return embedding, nil
|
||||
}
|
||||
|
||||
@@ -10,21 +10,21 @@ import (
|
||||
// Provider type constants for different embedding services
|
||||
const (
|
||||
// DashScope embedding service
|
||||
PROVIDER_TYPE_DASHSCOPE = "dashscope"
|
||||
PROVIDER_TYPE_DASHSCOPE = "dashscope"
|
||||
// TextIn embedding service
|
||||
PROVIDER_TYPE_TEXTIN = "textin"
|
||||
PROVIDER_TYPE_TEXTIN = "textin"
|
||||
// Cohere embedding service
|
||||
PROVIDER_TYPE_COHERE = "cohere"
|
||||
PROVIDER_TYPE_COHERE = "cohere"
|
||||
// OpenAI embedding service
|
||||
PROVIDER_TYPE_OPENAI = "openai"
|
||||
PROVIDER_TYPE_OPENAI = "openai"
|
||||
// Ollama embedding service
|
||||
PROVIDER_TYPE_OLLAMA = "ollama"
|
||||
PROVIDER_TYPE_OLLAMA = "ollama"
|
||||
// HuggingFace embedding service
|
||||
PROVIDER_TYPE_HUGGINGFACE = "huggingface"
|
||||
// XFYun embedding service
|
||||
PROVIDER_TYPE_XFYUN = "xfyun"
|
||||
PROVIDER_TYPE_XFYUN = "xfyun"
|
||||
// Azure embedding service
|
||||
PROVIDER_TYPE_AZURE = "azure"
|
||||
PROVIDER_TYPE_AZURE = "azure"
|
||||
)
|
||||
|
||||
// Factory interface for creating Provider instances
|
||||
@@ -36,8 +36,7 @@ type providerInitializer interface {
|
||||
// Maps provider types to their initializers
|
||||
var (
|
||||
providerInitializers = map[string]providerInitializer{
|
||||
PROVIDER_TYPE_DASHSCOPE: &dashScopeProviderInitializer{},
|
||||
PROVIDER_TYPE_OPENAI: &openAIProviderInitializer{},
|
||||
PROVIDER_TYPE_OPENAI: &openAIProviderInitializer{},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -2,133 +2,105 @@ package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/common"
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
||||
"github.com/openai/openai-go/v2"
|
||||
"github.com/openai/openai-go/v2/option"
|
||||
"github.com/openai/openai-go/v2/packages/param"
|
||||
)
|
||||
|
||||
const (
|
||||
OPENAI_CHAT_ENDPOINT = "/chat/completions"
|
||||
OPENAI_DEFAULT_MODEL = "gpt-4o"
|
||||
)
|
||||
|
||||
// openAI specific configuration captured after initialization.
|
||||
type openAIProviderConfig struct {
|
||||
apiKey string
|
||||
baseURL string
|
||||
type OpenAIProvider struct {
|
||||
client *openai.Client
|
||||
model string
|
||||
maxTokens int
|
||||
temperature float64
|
||||
maxTokens int
|
||||
}
|
||||
|
||||
type openAIProviderInitializer struct{}
|
||||
|
||||
var openAIConfig openAIProviderConfig
|
||||
|
||||
func (i *openAIProviderInitializer) initConfig(c config.LLMConfig) {
|
||||
openAIConfig.apiKey = c.APIKey
|
||||
openAIConfig.baseURL = c.BaseURL
|
||||
openAIConfig.model = c.Model
|
||||
if openAIConfig.model == "" {
|
||||
openAIConfig.model = OPENAI_DEFAULT_MODEL
|
||||
}
|
||||
if openAIConfig.baseURL == "" {
|
||||
openAIConfig.baseURL = "https://api.openai.com/v1" // default public endpoint
|
||||
}
|
||||
openAIConfig.maxTokens = c.MaxTokens
|
||||
openAIConfig.temperature = c.Temperature
|
||||
}
|
||||
|
||||
func (i *openAIProviderInitializer) validateConfig() error {
|
||||
if openAIConfig.apiKey == "" {
|
||||
func (i *openAIProviderInitializer) validateConfig(cfg *config.LLMConfig) error {
|
||||
if cfg.APIKey == "" {
|
||||
return errors.New("[openai llm] apiKey is required")
|
||||
}
|
||||
if cfg.Model == "" {
|
||||
cfg.Model = OPENAI_DEFAULT_MODEL
|
||||
}
|
||||
|
||||
if cfg.Temperature <= 0 || cfg.Temperature > 2 {
|
||||
cfg.Temperature = 0.5
|
||||
}
|
||||
|
||||
if cfg.MaxTokens <= 0 {
|
||||
cfg.MaxTokens = 2048
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *openAIProviderInitializer) CreateProvider(cfg config.LLMConfig) (Provider, error) {
|
||||
i.initConfig(cfg)
|
||||
if err := i.validateConfig(); err != nil {
|
||||
if err := i.validateConfig(&cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
headers := map[string]string{
|
||||
"Authorization": "Bearer " + openAIConfig.apiKey,
|
||||
"Content-Type": "application/json",
|
||||
// Create OpenAI client
|
||||
var clientOptions []option.RequestOption
|
||||
clientOptions = append(clientOptions, option.WithAPIKey(cfg.APIKey))
|
||||
|
||||
// If a custom baseURL is set, use it
|
||||
if cfg.BaseURL != "" {
|
||||
clientOptions = append(clientOptions, option.WithBaseURL(cfg.BaseURL))
|
||||
}
|
||||
client := common.NewHTTPClient(openAIConfig.baseURL, headers)
|
||||
return &OpenAIProvider{client: client, cfg: openAIConfig}, nil
|
||||
}
|
||||
|
||||
type OpenAIProvider struct {
|
||||
client *common.HTTPClient
|
||||
cfg openAIProviderConfig
|
||||
}
|
||||
// Create OpenAI client
|
||||
client := openai.NewClient(clientOptions...)
|
||||
|
||||
type openAIChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []openAIChatMessage `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
}
|
||||
|
||||
type openAIChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type openAIChatCompletionResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Choices []openAIChatCompletionResponseChoice `json:"choices"`
|
||||
Error *openAIError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type openAIChatCompletionResponseChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message openAIChatMessage `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type openAIError struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Code string `json:"code"`
|
||||
Param string `json:"param"`
|
||||
return &OpenAIProvider{
|
||||
client: &client,
|
||||
model: cfg.Model,
|
||||
temperature: cfg.Temperature,
|
||||
maxTokens: cfg.MaxTokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GenerateCompletion implements Provider interface.
|
||||
func (o *OpenAIProvider) GenerateCompletion(ctx context.Context, prompt string) (string, error) {
|
||||
req := openAIChatCompletionRequest{
|
||||
Model: o.cfg.model,
|
||||
Messages: []openAIChatMessage{
|
||||
{Role: "user", Content: prompt},
|
||||
// Create chat request
|
||||
params := openai.ChatCompletionNewParams{
|
||||
Model: o.model,
|
||||
Messages: []openai.ChatCompletionMessageParamUnion{
|
||||
openai.UserMessage(prompt),
|
||||
},
|
||||
Temperature: o.cfg.temperature,
|
||||
MaxTokens: o.cfg.maxTokens,
|
||||
}
|
||||
|
||||
body, err := o.client.Post(OPENAI_CHAT_ENDPOINT, req)
|
||||
// Set optional parameters
|
||||
if o.temperature > 0 {
|
||||
temperature := float64(o.temperature)
|
||||
params.Temperature = param.Opt[float64]{Value: temperature}
|
||||
}
|
||||
|
||||
if o.maxTokens > 0 {
|
||||
maxTokens := int64(o.maxTokens)
|
||||
params.MaxTokens = param.Opt[int64]{Value: maxTokens}
|
||||
}
|
||||
|
||||
// Send request
|
||||
response, err := o.client.Chat.Completions.New(ctx, params)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("openai llm post error: %w", err)
|
||||
// Handle error
|
||||
return "", fmt.Errorf("openai llm error: %w", err)
|
||||
}
|
||||
|
||||
var resp openAIChatCompletionResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
return "", fmt.Errorf("openai llm unmarshal error: %w", err)
|
||||
}
|
||||
|
||||
if resp.Error != nil {
|
||||
return "", fmt.Errorf("openai llm api error: %s - %s", resp.Error.Type, resp.Error.Message)
|
||||
}
|
||||
|
||||
if len(resp.Choices) == 0 {
|
||||
// Check response
|
||||
if len(response.Choices) == 0 {
|
||||
return "", errors.New("openai llm: empty choices")
|
||||
}
|
||||
|
||||
return resp.Choices[0].Message.Content, nil
|
||||
// Return generated content
|
||||
return response.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
func (o *OpenAIProvider) GetProviderType() string {
|
||||
|
||||
@@ -56,18 +56,12 @@ func NewRAGClient(config *config.Config) (*RAGClient, error) {
|
||||
ragclient.llmProvider = llmProvider
|
||||
}
|
||||
|
||||
demoVector, err := embeddingProvider.GetEmbedding(context.Background(), "initialization")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create init embedding failed, err: %w", err)
|
||||
}
|
||||
dim := len(demoVector)
|
||||
|
||||
dim := ragclient.config.Embedding.Dimensions
|
||||
provider, err := vectordb.NewVectorDBProvider(&ragclient.config.VectorDB, dim)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create vector store provider failed, err: %w", err)
|
||||
}
|
||||
ragclient.vectordbProvider = provider
|
||||
|
||||
return ragclient, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -22,15 +22,17 @@ func getRAGClient() (*RAGClient, error) {
|
||||
|
||||
LLM: config.LLMConfig{
|
||||
Provider: "openai",
|
||||
APIKey: "sk-xxxx",
|
||||
APIKey: "sk-xxx",
|
||||
BaseURL: "https://openrouter.ai/api/v1",
|
||||
Model: "openai/gpt-4o",
|
||||
},
|
||||
|
||||
Embedding: config.EmbeddingConfig{
|
||||
Provider: "dashscope",
|
||||
APIKey: "sk-xxxx",
|
||||
Model: "text-embedding-v4",
|
||||
Provider: "openai",
|
||||
BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
APIKey: "sk-xxxx",
|
||||
Model: "text-embedding-v4",
|
||||
Dimensions: 1536,
|
||||
},
|
||||
|
||||
VectorDB: config.VectorDBConfig{
|
||||
@@ -38,7 +40,49 @@ func getRAGClient() (*RAGClient, error) {
|
||||
Host: "localhost",
|
||||
Port: 19530,
|
||||
Database: "default",
|
||||
Collection: "test_collection",
|
||||
Collection: "test_collection3",
|
||||
Mapping: config.MappingConfig{
|
||||
Fields: []config.FieldMapping{
|
||||
{
|
||||
StandardName: "id",
|
||||
RawName: "pk",
|
||||
Properties: map[string]interface{}{
|
||||
"max_length": 256,
|
||||
"auto_id": false,
|
||||
},
|
||||
},
|
||||
{
|
||||
StandardName: "content",
|
||||
RawName: "page_content",
|
||||
Properties: map[string]interface{}{
|
||||
"max_length": 8192,
|
||||
},
|
||||
},
|
||||
{
|
||||
StandardName: "vector",
|
||||
RawName: "page_vector",
|
||||
Properties: make(map[string]interface{}),
|
||||
},
|
||||
{
|
||||
StandardName: "metadata",
|
||||
RawName: "metadata",
|
||||
Properties: make(map[string]interface{}),
|
||||
},
|
||||
{
|
||||
StandardName: "created_at",
|
||||
RawName: "created_at",
|
||||
Properties: make(map[string]interface{}),
|
||||
},
|
||||
},
|
||||
Index: config.IndexConfig{
|
||||
IndexType: "IVF_FLAT",
|
||||
Params: map[string]interface{}{"nlist": 64},
|
||||
},
|
||||
Search: config.SearchConfig{
|
||||
MetricType: "COSINE",
|
||||
Params: map[string]interface{}{"nprobe": 32},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -48,7 +92,6 @@ func getRAGClient() (*RAGClient, error) {
|
||||
}
|
||||
|
||||
return ragClient, nil
|
||||
|
||||
}
|
||||
|
||||
func TestNewRAGClient(t *testing.T) {
|
||||
@@ -104,7 +147,7 @@ func TestRAGClient_DeleteChunk(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
chunk_id := "63ee25d7-41b9-4455-8066-075ca5c803b2"
|
||||
chunk_id := "2a06679c-a8ea-46dc-bf1c-7e7b164a73c8"
|
||||
err = ragClient.DeleteChunk(chunk_id)
|
||||
if err != nil {
|
||||
t.Errorf("DeleteChunk() error = %v", err)
|
||||
|
||||
@@ -36,11 +36,11 @@ func init() {
|
||||
MaxTokens: 2048,
|
||||
},
|
||||
Embedding: config.EmbeddingConfig{
|
||||
Provider: "dashscope",
|
||||
APIKey: "",
|
||||
BaseURL: "",
|
||||
Model: "text-embedding-v4",
|
||||
Dimension: 1024,
|
||||
Provider: "openai",
|
||||
APIKey: "",
|
||||
BaseURL: "",
|
||||
Model: "text-embedding-ada-002",
|
||||
Dimensions: 1536,
|
||||
},
|
||||
VectorDB: config.VectorDBConfig{
|
||||
Provider: "milvus",
|
||||
@@ -50,14 +50,56 @@ func init() {
|
||||
Collection: "rag",
|
||||
Username: "",
|
||||
Password: "",
|
||||
Mapping: config.MappingConfig{
|
||||
Fields: []config.FieldMapping{
|
||||
{
|
||||
StandardName: "id",
|
||||
RawName: "id",
|
||||
Properties: map[string]interface{}{
|
||||
"max_length": 256,
|
||||
"auto_id": false,
|
||||
},
|
||||
},
|
||||
{
|
||||
StandardName: "content",
|
||||
RawName: "content",
|
||||
Properties: map[string]interface{}{
|
||||
"max_length": 8192,
|
||||
},
|
||||
},
|
||||
{
|
||||
StandardName: "vector",
|
||||
RawName: "vector",
|
||||
Properties: make(map[string]interface{}),
|
||||
},
|
||||
{
|
||||
StandardName: "metadata",
|
||||
RawName: "metadata",
|
||||
Properties: make(map[string]interface{}),
|
||||
},
|
||||
{
|
||||
StandardName: "created_at",
|
||||
RawName: "created_at",
|
||||
Properties: make(map[string]interface{}),
|
||||
},
|
||||
},
|
||||
Index: config.IndexConfig{
|
||||
IndexType: "HNSW",
|
||||
Params: map[string]interface{}{"M": 8, "efConstruction": 64},
|
||||
},
|
||||
Search: config.SearchConfig{
|
||||
MetricType: "IP",
|
||||
Params: make(map[string]interface{}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (c *RAGConfig) ParseConfig(config map[string]any) error {
|
||||
func (c *RAGConfig) ParseConfig(cfg map[string]any) error {
|
||||
// Parse RAG configuration
|
||||
if ragConfig, ok := config["rag"].(map[string]any); ok {
|
||||
if ragConfig, ok := cfg["rag"].(map[string]any); ok {
|
||||
if splitter, exists := ragConfig["splitter"].(map[string]any); exists {
|
||||
if splitterType, exists := splitter["provider"].(string); exists {
|
||||
c.config.RAG.Splitter.Provider = splitterType
|
||||
@@ -78,7 +120,7 @@ func (c *RAGConfig) ParseConfig(config map[string]any) error {
|
||||
}
|
||||
|
||||
// Parse Embedding configuration
|
||||
if embeddingConfig, ok := config["embedding"].(map[string]any); ok {
|
||||
if embeddingConfig, ok := cfg["embedding"].(map[string]any); ok {
|
||||
if provider, exists := embeddingConfig["provider"].(string); exists {
|
||||
c.config.Embedding.Provider = provider
|
||||
} else {
|
||||
@@ -94,13 +136,13 @@ func (c *RAGConfig) ParseConfig(config map[string]any) error {
|
||||
if model, exists := embeddingConfig["model"].(string); exists {
|
||||
c.config.Embedding.Model = model
|
||||
}
|
||||
if dimension, exists := embeddingConfig["dimension"].(float64); exists {
|
||||
c.config.Embedding.Dimension = int(dimension)
|
||||
if dimensions, exists := embeddingConfig["dimensions"].(float64); exists {
|
||||
c.config.Embedding.Dimensions = int(dimensions)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse llm configuration
|
||||
if llmConfig, ok := config["llm"].(map[string]any); ok {
|
||||
if llmConfig, ok := cfg["llm"].(map[string]any); ok {
|
||||
if provider, exists := llmConfig["provider"].(string); exists {
|
||||
c.config.LLM.Provider = provider
|
||||
}
|
||||
@@ -122,7 +164,7 @@ func (c *RAGConfig) ParseConfig(config map[string]any) error {
|
||||
}
|
||||
|
||||
// Parse VectorDB configuration
|
||||
if vectordbConfig, ok := config["vectordb"].(map[string]any); ok {
|
||||
if vectordbConfig, ok := cfg["vectordb"].(map[string]any); ok {
|
||||
if provider, exists := vectordbConfig["provider"].(string); exists {
|
||||
c.config.VectorDB.Provider = provider
|
||||
} else {
|
||||
@@ -146,8 +188,59 @@ func (c *RAGConfig) ParseConfig(config map[string]any) error {
|
||||
if password, exists := vectordbConfig["password"].(string); exists {
|
||||
c.config.VectorDB.Password = password
|
||||
}
|
||||
}
|
||||
|
||||
// Parse mapping here
|
||||
if mapping, exists := vectordbConfig["mapping"].(map[string]any); exists {
|
||||
// Parse field mappings
|
||||
if fields, ok := mapping["fields"].([]any); ok {
|
||||
c.config.VectorDB.Mapping.Fields = []config.FieldMapping{}
|
||||
for _, field := range fields {
|
||||
if fieldMap, ok := field.(map[string]any); ok {
|
||||
fieldMapping := config.FieldMapping{
|
||||
Properties: make(map[string]interface{}),
|
||||
}
|
||||
if standardName, ok := fieldMap["standard_name"].(string); ok {
|
||||
fieldMapping.StandardName = standardName
|
||||
}
|
||||
|
||||
if rawName, ok := fieldMap["raw_name"].(string); ok {
|
||||
fieldMapping.RawName = rawName
|
||||
}
|
||||
// Parse properties
|
||||
if properties, ok := fieldMap["properties"].(map[string]any); ok {
|
||||
for key, value := range properties {
|
||||
fieldMapping.Properties[key] = value
|
||||
}
|
||||
}
|
||||
c.config.VectorDB.Mapping.Fields = append(c.config.VectorDB.Mapping.Fields, fieldMapping)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse index configuration
|
||||
if index, ok := mapping["index"].(map[string]any); ok {
|
||||
if indexType, ok := index["index_type"].(string); ok {
|
||||
c.config.VectorDB.Mapping.Index.IndexType = indexType
|
||||
}
|
||||
|
||||
// Parse index parameters
|
||||
if params, ok := index["params"].(map[string]any); ok {
|
||||
c.config.VectorDB.Mapping.Index.Params = params
|
||||
}
|
||||
}
|
||||
|
||||
// Parse search configuration
|
||||
if search, ok := mapping["search"].(map[string]any); ok {
|
||||
if metricType, ok := search["metric_type"].(string); ok {
|
||||
c.config.VectorDB.Mapping.Search.MetricType = metricType
|
||||
}
|
||||
// Parse search parameters
|
||||
if params, ok := search["params"].(map[string]any); ok {
|
||||
c.config.VectorDB.Mapping.Search.Params = params
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -28,11 +28,11 @@ func TestRAGConfig_ParseConfig(t *testing.T) {
|
||||
MaxTokens: 2048,
|
||||
},
|
||||
Embedding: config.EmbeddingConfig{
|
||||
Provider: "dashscope",
|
||||
APIKey: "sk-XXX",
|
||||
BaseURL: "",
|
||||
Model: "text-embedding-v4",
|
||||
Dimension: 1024,
|
||||
Provider: "dashscope",
|
||||
APIKey: "sk-XXX",
|
||||
BaseURL: "",
|
||||
Model: "text-embedding-v4",
|
||||
Dimensions: 1024,
|
||||
},
|
||||
VectorDB: config.VectorDBConfig{
|
||||
Provider: "milvus",
|
||||
@@ -42,6 +42,48 @@ func TestRAGConfig_ParseConfig(t *testing.T) {
|
||||
Collection: "test_rag",
|
||||
Username: "",
|
||||
Password: "",
|
||||
Mapping: config.MappingConfig{
|
||||
Fields: []config.FieldMapping{
|
||||
{
|
||||
StandardName: "id",
|
||||
RawName: "id",
|
||||
Properties: map[string]interface{}{
|
||||
"max_length": 256,
|
||||
"auto_id": false,
|
||||
},
|
||||
},
|
||||
{
|
||||
StandardName: "content",
|
||||
RawName: "content",
|
||||
Properties: map[string]interface{}{
|
||||
"max_length": 8192,
|
||||
},
|
||||
},
|
||||
{
|
||||
StandardName: "vector",
|
||||
RawName: "vector",
|
||||
Properties: make(map[string]interface{}),
|
||||
},
|
||||
{
|
||||
StandardName: "metadata",
|
||||
RawName: "metadata",
|
||||
Properties: make(map[string]interface{}),
|
||||
},
|
||||
{
|
||||
StandardName: "created_at",
|
||||
RawName: "created_at",
|
||||
Properties: make(map[string]interface{}),
|
||||
},
|
||||
},
|
||||
Index: config.IndexConfig{
|
||||
IndexType: "HNSW",
|
||||
Params: map[string]interface{}{"M": 4, "efConstruction": 32},
|
||||
},
|
||||
Search: config.SearchConfig{
|
||||
MetricType: "IP",
|
||||
Params: map[string]interface{}{"ef": 32},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
// 把 config 输出 yaml 格式
|
||||
|
||||
182
plugins/golang-filter/mcp-server/servers/rag/vectordb/mapper.go
Normal file
182
plugins/golang-filter/mcp-server/servers/rag/vectordb/mapper.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package vectordb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
||||
)
|
||||
|
||||
// Error definitions
|
||||
var (
|
||||
ErrFieldNotFound = errors.New("field not found")
|
||||
ErrInvalidFieldType = errors.New("invalid field type")
|
||||
ErrInvalidIndexType = errors.New("invalid index type")
|
||||
ErrInvalidMetricType = errors.New("invalid metric type")
|
||||
ErrInvalidSearchParams = errors.New("invalid search parameters")
|
||||
ErrCollectionNotFound = errors.New("collection not found")
|
||||
ErrUnsupportedOperation = errors.New("unsupported operation")
|
||||
)
|
||||
|
||||
// VectorDBMapper interface for vector database mapping
|
||||
type VectorDBMapper interface {
|
||||
// ParseMapping parses the mapping configuration
|
||||
ParseMapping(provider string, cfg config.MappingConfig) error
|
||||
|
||||
// GetIndexConfig returns the index configuration
|
||||
GetIndexConfig() (config.IndexConfig, error)
|
||||
|
||||
// GetSearchConfig returns the search configuration
|
||||
GetSearchConfig() (config.SearchConfig, error)
|
||||
|
||||
// Get all raw field names
|
||||
GetRawAllFieldNames() ([]string, error)
|
||||
|
||||
// GetIDField returns the ID field mapping
|
||||
GetIDField() (*config.FieldMapping, error)
|
||||
|
||||
// GetVectorField returns the vector field mapping
|
||||
GetVectorField() (*config.FieldMapping, error)
|
||||
|
||||
// Get raw field name by standard field name
|
||||
GetRawField(standardFieldName string) (*config.FieldMapping, error)
|
||||
|
||||
// Get field mapping by raw field name
|
||||
GetField(rawFieldName string) (*config.FieldMapping, error)
|
||||
|
||||
// Get all field mappings
|
||||
GetFieldMappings() ([]config.FieldMapping, error)
|
||||
}
|
||||
|
||||
// DefaultVectorDBMapper is the default implementation of VectorDBMapper interface
|
||||
type DefaultVectorDBMapper struct {
|
||||
// Mapping configuration
|
||||
mappingConfig config.MappingConfig
|
||||
// Map from standard field name to field mapping
|
||||
standardFieldMap map[string]*config.FieldMapping
|
||||
// Map from raw field name to field mapping
|
||||
rawFieldMap map[string]*config.FieldMapping
|
||||
}
|
||||
|
||||
// NewDefaultVectorDBMapper creates a new default vector database mapper
|
||||
func NewDefaultVectorDBMapper(provider string, mappingConfig config.MappingConfig) (*DefaultVectorDBMapper, error) {
|
||||
mapper := &DefaultVectorDBMapper{
|
||||
standardFieldMap: make(map[string]*config.FieldMapping),
|
||||
rawFieldMap: make(map[string]*config.FieldMapping),
|
||||
}
|
||||
if err := mapper.ParseMapping(provider, mappingConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mapper, nil
|
||||
}
|
||||
|
||||
// ParseMapping parses the mapping configuration
|
||||
func (m *DefaultVectorDBMapper) ParseMapping(provider string, cfg config.MappingConfig) error {
|
||||
m.mappingConfig = cfg
|
||||
// Clear existing mappings
|
||||
m.standardFieldMap = make(map[string]*config.FieldMapping)
|
||||
m.rawFieldMap = make(map[string]*config.FieldMapping)
|
||||
// fill default field mappings
|
||||
if len(cfg.Fields) == 0 {
|
||||
defaultFields := []config.FieldMapping{
|
||||
{
|
||||
StandardName: "id",
|
||||
RawName: "id",
|
||||
Properties: map[string]interface{}{
|
||||
"max_length": 256,
|
||||
"auto_id": false,
|
||||
},
|
||||
},
|
||||
{
|
||||
StandardName: "content",
|
||||
RawName: "content",
|
||||
Properties: map[string]interface{}{
|
||||
"max_length": 8192,
|
||||
},
|
||||
},
|
||||
{
|
||||
StandardName: "vector",
|
||||
RawName: "vector",
|
||||
},
|
||||
{
|
||||
StandardName: "metadata",
|
||||
RawName: "metadata",
|
||||
},
|
||||
{
|
||||
StandardName: "created_at",
|
||||
RawName: "created_at",
|
||||
},
|
||||
}
|
||||
cfg.Fields = defaultFields
|
||||
}
|
||||
|
||||
// Parse field mappings
|
||||
for i, field := range cfg.Fields {
|
||||
// Save pointer for future reference
|
||||
fieldPtr := &cfg.Fields[i]
|
||||
m.standardFieldMap[field.StandardName] = fieldPtr
|
||||
m.rawFieldMap[field.RawName] = fieldPtr
|
||||
}
|
||||
|
||||
// Check fields, must include id, content, vector fields
|
||||
requiredFields := []string{"id", "content", "vector"}
|
||||
for _, fieldName := range requiredFields {
|
||||
if _, err := m.GetRawField(fieldName); err != nil {
|
||||
return fmt.Errorf("[vector db mapper] required field %s not found or not varchar type", fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetIndexConfig gets the index configuration
|
||||
func (m *DefaultVectorDBMapper) GetIndexConfig() (config.IndexConfig, error) {
|
||||
return m.mappingConfig.Index, nil
|
||||
}
|
||||
|
||||
// GetSearchConfig gets the search configuration
|
||||
func (m *DefaultVectorDBMapper) GetSearchConfig() (config.SearchConfig, error) {
|
||||
return m.mappingConfig.Search, nil
|
||||
}
|
||||
|
||||
// GetRawAllFieldNames gets all raw field names
|
||||
func (m *DefaultVectorDBMapper) GetRawAllFieldNames() ([]string, error) {
|
||||
fieldNames := make([]string, 0, len(m.rawFieldMap))
|
||||
for name := range m.rawFieldMap {
|
||||
fieldNames = append(fieldNames, name)
|
||||
}
|
||||
return fieldNames, nil
|
||||
}
|
||||
|
||||
// GetIDField gets the ID field
|
||||
func (m *DefaultVectorDBMapper) GetIDField() (*config.FieldMapping, error) {
|
||||
return m.GetRawField("id")
|
||||
}
|
||||
|
||||
// GetVectorField gets the vector field
|
||||
func (m *DefaultVectorDBMapper) GetVectorField() (*config.FieldMapping, error) {
|
||||
return m.GetRawField("vector")
|
||||
}
|
||||
|
||||
// GetRawField gets the raw field mapping by standard field name
|
||||
func (m *DefaultVectorDBMapper) GetRawField(standardFieldName string) (*config.FieldMapping, error) {
|
||||
field, exists := m.standardFieldMap[standardFieldName]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("%w: standard field %s not found", ErrFieldNotFound, standardFieldName)
|
||||
}
|
||||
return field, nil
|
||||
}
|
||||
|
||||
// GetField gets the field mapping by raw field name
|
||||
func (m *DefaultVectorDBMapper) GetField(rawFieldName string) (*config.FieldMapping, error) {
|
||||
field, exists := m.rawFieldMap[rawFieldName]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("%w: raw field %s not found", ErrFieldNotFound, rawFieldName)
|
||||
}
|
||||
return field, nil
|
||||
}
|
||||
|
||||
// GetFieldMappings gets all field mappings
|
||||
func (m *DefaultVectorDBMapper) GetFieldMappings() ([]config.FieldMapping, error) {
|
||||
return m.mappingConfig.Fields, nil
|
||||
}
|
||||
@@ -80,16 +80,17 @@ func (m *milvusProviderInitializer) CreateProvider(cfg *config.VectorDBConfig, d
|
||||
type MilvusProvider struct {
|
||||
client client.Client
|
||||
config *config.VectorDBConfig
|
||||
Collection string
|
||||
collection string
|
||||
mapper VectorDBMapper
|
||||
dimensions int
|
||||
}
|
||||
|
||||
// NewMilvusProvider creates a new instance of MilvusProvider
|
||||
func NewMilvusProvider(cfg *config.VectorDBConfig, dim int) (VectorStoreProvider, error) {
|
||||
func NewMilvusProvider(cfg *config.VectorDBConfig, dimensions int) (VectorStoreProvider, error) {
|
||||
// Create Milvus client
|
||||
connectParam := client.Config{
|
||||
Address: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
||||
}
|
||||
|
||||
connectParam.DBName = cfg.Database
|
||||
// Add authentication if credentials are provided
|
||||
if cfg.Username != "" && cfg.Password != "" {
|
||||
@@ -102,92 +103,301 @@ func NewMilvusProvider(cfg *config.VectorDBConfig, dim int) (VectorStoreProvider
|
||||
return nil, fmt.Errorf("failed to create milvus client: %w", err)
|
||||
}
|
||||
|
||||
mapper, err := NewDefaultVectorDBMapper(MILVUS_PROVIDER_TYPE, cfg.Mapping)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create default vector db mapper: %w", err)
|
||||
}
|
||||
|
||||
provider := &MilvusProvider{
|
||||
client: milvusClient,
|
||||
config: cfg,
|
||||
Collection: cfg.Collection,
|
||||
collection: cfg.Collection,
|
||||
mapper: mapper,
|
||||
dimensions: dimensions,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if err := provider.CreateCollection(ctx, dim); err != nil {
|
||||
if err := provider.CreateCollection(ctx, dimensions); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (m *MilvusProvider) buildSchema() (*entity.Schema, error) {
|
||||
// Create Milvus collection Schema
|
||||
idField, _ := m.mapper.GetIDField()
|
||||
isIDAuto := idField.IsAutoID()
|
||||
schema := entity.NewSchema().
|
||||
WithName(m.collection).
|
||||
WithDescription("Knowledge document collection").
|
||||
WithAutoID(isIDAuto).
|
||||
WithDynamicFieldEnabled(false)
|
||||
// Add fields
|
||||
var fieldEntity *entity.Field
|
||||
fieldMappings, _ := m.mapper.GetFieldMappings()
|
||||
for _, field := range fieldMappings {
|
||||
fieldEntity = nil
|
||||
maxLength := field.MaxLength()
|
||||
switch field.StandardName {
|
||||
case "id":
|
||||
isIDAuto := field.IsAutoID()
|
||||
fieldEntity = entity.NewField().
|
||||
WithName(field.RawName).
|
||||
WithDataType(entity.FieldTypeVarChar).
|
||||
WithMaxLength(int64(maxLength)).
|
||||
WithIsPrimaryKey(true)
|
||||
if isIDAuto {
|
||||
fieldEntity.WithIsAutoID(true)
|
||||
}
|
||||
schema.WithField(fieldEntity)
|
||||
case "content":
|
||||
fieldEntity = entity.NewField().
|
||||
WithName(field.RawName).
|
||||
WithDataType(entity.FieldTypeVarChar).
|
||||
WithMaxLength(int64(maxLength))
|
||||
schema.WithField(fieldEntity)
|
||||
case "vector":
|
||||
fieldEntity = entity.NewField().
|
||||
WithName(field.RawName).
|
||||
WithDataType(entity.FieldTypeFloatVector).
|
||||
WithDim(int64(m.dimensions))
|
||||
schema.WithField(fieldEntity)
|
||||
case "metadata":
|
||||
fieldEntity = entity.NewField().
|
||||
WithName(field.RawName).
|
||||
WithDataType(entity.FieldTypeJSON)
|
||||
schema.WithField(fieldEntity)
|
||||
case "created_at":
|
||||
fieldEntity = entity.NewField().
|
||||
WithName(field.RawName).
|
||||
WithDataType(entity.FieldTypeInt64)
|
||||
schema.WithField(fieldEntity)
|
||||
}
|
||||
}
|
||||
return schema, nil
|
||||
}
|
||||
|
||||
func (m *MilvusProvider) GetMetricType(metricType string) entity.MetricType {
|
||||
switch strings.ToUpper(metricType) {
|
||||
case "L2":
|
||||
return entity.L2
|
||||
case "IP":
|
||||
return entity.IP
|
||||
case "COSINE":
|
||||
return entity.COSINE
|
||||
case "HAMMING":
|
||||
return entity.HAMMING
|
||||
case "JACCARD":
|
||||
return entity.JACCARD
|
||||
case "TANIMOTO":
|
||||
return entity.TANIMOTO
|
||||
case "SUBSTRUCTURE":
|
||||
return entity.SUBSTRUCTURE
|
||||
case "SUPERSTRUCTURE":
|
||||
return entity.SUPERSTRUCTURE
|
||||
default:
|
||||
return entity.IP
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MilvusProvider) buildVectorIndex() (entity.Index, error) {
|
||||
// Map index type
|
||||
indexConfig, _ := m.mapper.GetIndexConfig()
|
||||
searchConfig, _ := m.mapper.GetSearchConfig()
|
||||
// Map index parameters
|
||||
milvusIndexType := strings.ToUpper(indexConfig.IndexType)
|
||||
if milvusIndexType == "" {
|
||||
milvusIndexType = "HNSW"
|
||||
}
|
||||
metricType := m.GetMetricType(searchConfig.MetricType)
|
||||
switch milvusIndexType {
|
||||
case "FLAT":
|
||||
// FLAT index doesn't need additional parameters
|
||||
index, err := entity.NewIndexFlat(metricType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create FLAT index: %w", err)
|
||||
}
|
||||
return index, nil
|
||||
|
||||
case "BIN_FLAT":
|
||||
// BIN_FLAT index doesn't need additional parameters
|
||||
nlist := 128
|
||||
if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil {
|
||||
nlist = int(nlistVal)
|
||||
}
|
||||
index, err := entity.NewIndexBinFlat(metricType, nlist)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create BIN_FLAT index: %w", err)
|
||||
}
|
||||
return index, nil
|
||||
|
||||
case "IVF_FLAT":
|
||||
// Default parameters
|
||||
nlist := 128
|
||||
if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil {
|
||||
nlist = int(nlistVal)
|
||||
}
|
||||
index, err := entity.NewIndexIvfFlat(metricType, nlist)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create IVF_FLAT index: %w", err)
|
||||
}
|
||||
return index, nil
|
||||
|
||||
case "BIN_IVF_FLAT":
|
||||
// Default parameters
|
||||
nlist := 128
|
||||
if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil {
|
||||
nlist = int(nlistVal)
|
||||
}
|
||||
index, err := entity.NewIndexBinIvfFlat(metricType, nlist)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create BIN_IVF_FLAT index: %w", err)
|
||||
}
|
||||
return index, nil
|
||||
|
||||
case "IVF_SQ8":
|
||||
// Default parameters
|
||||
nlist := 128
|
||||
if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil {
|
||||
nlist = int(nlistVal)
|
||||
}
|
||||
index, err := entity.NewIndexIvfSQ8(metricType, nlist)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create IVF_SQ8 index: %w", err)
|
||||
}
|
||||
return index, nil
|
||||
|
||||
case "IVF_PQ":
|
||||
// Default parameters
|
||||
nlist := 128
|
||||
m := 4
|
||||
nbits := 8
|
||||
|
||||
if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil {
|
||||
nlist = int(nlistVal)
|
||||
}
|
||||
if mVal, err := indexConfig.ParamsFloat64("m"); err == nil {
|
||||
m = int(mVal)
|
||||
}
|
||||
if nbitsVal, err := indexConfig.ParamsInt64("nbits"); err == nil {
|
||||
nbits = int(nbitsVal)
|
||||
}
|
||||
|
||||
index, err := entity.NewIndexIvfPQ(metricType, nlist, m, nbits)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create IVF_PQ index: %w", err)
|
||||
}
|
||||
return index, nil
|
||||
|
||||
case "HNSW":
|
||||
// Default parameters
|
||||
m := 8
|
||||
efConstruction := 64
|
||||
if mVal, err := indexConfig.ParamsInt64("M"); err == nil {
|
||||
m = int(mVal)
|
||||
}
|
||||
if efConstructionVal, err := indexConfig.ParamsInt64("efConstruction"); err == nil {
|
||||
efConstruction = int(efConstructionVal)
|
||||
}
|
||||
index, err := entity.NewIndexHNSW(metricType, m, efConstruction)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HNSW index: %w", err)
|
||||
}
|
||||
return index, nil
|
||||
|
||||
case "IVF_HNSW":
|
||||
// Default parameters
|
||||
nlist := 128
|
||||
m := 8
|
||||
efConstruction := 64
|
||||
|
||||
if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil {
|
||||
nlist = int(nlistVal)
|
||||
}
|
||||
if mVal, err := indexConfig.ParamsInt64("M"); err == nil {
|
||||
m = int(mVal)
|
||||
}
|
||||
|
||||
if efConstructionVal, err := indexConfig.ParamsInt64("efConstruction"); err == nil {
|
||||
efConstruction = int(efConstructionVal)
|
||||
}
|
||||
|
||||
index, err := entity.NewIndexIvfHNSW(metricType, nlist, m, efConstruction)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create IVF_HNSW index: %w", err)
|
||||
}
|
||||
return index, nil
|
||||
|
||||
case "DISKANN":
|
||||
// DISKANN index parameters
|
||||
index, err := entity.NewIndexDISKANN(metricType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create DISKANN index: %w", err)
|
||||
}
|
||||
return index, nil
|
||||
|
||||
case "SCANN":
|
||||
// SCANN index parameters
|
||||
nlist := 128
|
||||
with_raw_data := false
|
||||
if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil {
|
||||
nlist = int(nlistVal)
|
||||
}
|
||||
if with_raw_dataVal, err := indexConfig.ParamsBool("with_raw_data"); err == nil {
|
||||
with_raw_data = with_raw_dataVal
|
||||
}
|
||||
index, err := entity.NewIndexSCANN(metricType, nlist, with_raw_data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create SCANN index: %w", err)
|
||||
}
|
||||
return index, nil
|
||||
|
||||
case "AUTOINDEX":
|
||||
// Auto index
|
||||
index, err := entity.NewIndexAUTOINDEX(metricType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create AUTOINDEX index: %w", err)
|
||||
}
|
||||
return index, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported index type: %s", milvusIndexType)
|
||||
}
|
||||
}
|
||||
|
||||
// CreateCollection creates a new collection with the specified dimension
|
||||
func (m *MilvusProvider) CreateCollection(ctx context.Context, dim int) error {
|
||||
// Check if collection exists
|
||||
document_exists, err := m.client.HasCollection(ctx, m.Collection)
|
||||
document_exists, err := m.client.HasCollection(ctx, m.collection)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check %s collection existence: %w", m.Collection, err)
|
||||
return fmt.Errorf("failed to check %s collection existence: %w", m.collection, err)
|
||||
}
|
||||
|
||||
if !document_exists {
|
||||
fmt.Printf("create collection %s\n", m.Collection)
|
||||
fmt.Printf("create collection %s\n", m.collection)
|
||||
// Create schema
|
||||
schema := entity.NewSchema().
|
||||
WithName(m.Collection).
|
||||
WithDescription("Knowledge document collection").
|
||||
WithAutoID(false).
|
||||
WithDynamicFieldEnabled(false)
|
||||
|
||||
// Add fields based on schema.Document structure
|
||||
// Primary key field - ID
|
||||
pkField := entity.NewField().
|
||||
WithName("id").
|
||||
WithDataType(entity.FieldTypeVarChar).
|
||||
WithMaxLength(256).
|
||||
WithIsPrimaryKey(true).
|
||||
WithIsAutoID(false)
|
||||
schema.WithField(pkField)
|
||||
|
||||
// Content field
|
||||
contentField := entity.NewField().
|
||||
WithName("content").
|
||||
WithDataType(entity.FieldTypeVarChar).
|
||||
WithMaxLength(8192)
|
||||
schema.WithField(contentField)
|
||||
|
||||
// Vector field
|
||||
vectorField := entity.NewField().
|
||||
WithName("vector").
|
||||
WithDataType(entity.FieldTypeFloatVector).
|
||||
WithDim(int64(dim))
|
||||
schema.WithField(vectorField)
|
||||
|
||||
// Metadata field
|
||||
metadataField := entity.NewField().
|
||||
WithName("metadata").
|
||||
WithDataType(entity.FieldTypeJSON)
|
||||
schema.WithField(metadataField)
|
||||
|
||||
// CreatedAt field (stored as Unix timestamp)
|
||||
createdAtField := entity.NewField().
|
||||
WithName("created_at").
|
||||
WithDataType(entity.FieldTypeInt64)
|
||||
schema.WithField(createdAtField)
|
||||
|
||||
schema, err := m.buildSchema()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build schema: %w", err)
|
||||
}
|
||||
// Create collection
|
||||
err = m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create collection: %w", err)
|
||||
}
|
||||
|
||||
// Create vector index
|
||||
vectorIndex, err := entity.NewIndexHNSW(entity.IP, 8, 64)
|
||||
vectorIndex, err := m.buildVectorIndex()
|
||||
vectorField, _ := m.mapper.GetVectorField()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create vector index: %w", err)
|
||||
}
|
||||
|
||||
err = m.client.CreateIndex(ctx, m.Collection, "vector", vectorIndex, false, client.WithIndexName("vector_index"))
|
||||
err = m.client.CreateIndex(ctx, m.collection, vectorField.RawName, vectorIndex, false, client.WithIndexName("vector_index"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create vector index: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Load collection
|
||||
err = m.client.LoadCollection(ctx, m.Collection, false)
|
||||
err = m.client.LoadCollection(ctx, m.collection, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load document collection: %w", err)
|
||||
}
|
||||
@@ -197,15 +407,15 @@ func (m *MilvusProvider) CreateCollection(ctx context.Context, dim int) error {
|
||||
// DropCollection removes the collection from the database
|
||||
func (m *MilvusProvider) DropCollection(ctx context.Context) error {
|
||||
// Check if collection exists
|
||||
exists, err := m.client.HasCollection(ctx, m.Collection)
|
||||
exists, err := m.client.HasCollection(ctx, m.collection)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check %s collection existence: %w", m.Collection, err)
|
||||
return fmt.Errorf("failed to check %s collection existence: %w", m.collection, err)
|
||||
}
|
||||
if !exists {
|
||||
return fmt.Errorf("collection %s does not exist", m.Collection)
|
||||
return fmt.Errorf("collection %s does not exist", m.collection)
|
||||
}
|
||||
// Drop collection
|
||||
err = m.client.DropCollection(ctx, m.Collection)
|
||||
err = m.client.DropCollection(ctx, m.collection)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to drop collection: %w", err)
|
||||
}
|
||||
@@ -217,51 +427,71 @@ func (m *MilvusProvider) AddDoc(ctx context.Context, docs []schema.Document) err
|
||||
if len(docs) == 0 {
|
||||
return nil
|
||||
}
|
||||
// Prepare data
|
||||
ids := make([]string, len(docs))
|
||||
contents := make([]string, len(docs))
|
||||
vectors := make([][]float32, len(docs))
|
||||
metadatas := make([][]byte, len(docs))
|
||||
createdAts := make([]int64, len(docs))
|
||||
|
||||
for i, doc := range docs {
|
||||
ids[i] = doc.ID
|
||||
contents[i] = doc.Content
|
||||
|
||||
// Convert vector type
|
||||
vectorFloat32 := make([]float32, len(doc.Vector))
|
||||
for j, v := range doc.Vector {
|
||||
vectorFloat32[j] = float32(v)
|
||||
}
|
||||
vectors[i] = vectorFloat32
|
||||
|
||||
// Serialize metadata
|
||||
metadataBytes, err := json.Marshal(doc.Metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal metadata for doc %s: %w", doc.ID, err)
|
||||
}
|
||||
metadatas[i] = metadataBytes
|
||||
|
||||
createdAts[i] = doc.CreatedAt.UnixMilli()
|
||||
// Get field mappings
|
||||
fieldMappings, err := m.mapper.GetFieldMappings()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get field mappings: %w", err)
|
||||
}
|
||||
// Prepare data and columns
|
||||
columns := make([]entity.Column, 0, len(fieldMappings))
|
||||
// Create corresponding column data for each field
|
||||
for _, field := range fieldMappings {
|
||||
// Skip ID field if configured as auto ID
|
||||
if field.IsPrimaryKey() && field.IsAutoID() {
|
||||
continue
|
||||
}
|
||||
switch field.StandardName {
|
||||
case "id":
|
||||
// Handle string type fields
|
||||
values := make([]string, len(docs))
|
||||
for i, doc := range docs {
|
||||
values[i] = doc.ID
|
||||
}
|
||||
columns = append(columns, entity.NewColumnVarChar(field.RawName, values))
|
||||
case "content":
|
||||
values := make([]string, len(docs))
|
||||
for i, doc := range docs {
|
||||
values[i] = doc.Content
|
||||
}
|
||||
columns = append(columns, entity.NewColumnVarChar(field.RawName, values))
|
||||
|
||||
// Build insert data
|
||||
columns := []entity.Column{
|
||||
entity.NewColumnVarChar("id", ids),
|
||||
entity.NewColumnVarChar("content", contents),
|
||||
entity.NewColumnFloatVector("vector", len(vectors[0]), vectors),
|
||||
entity.NewColumnJSONBytes("metadata", metadatas),
|
||||
entity.NewColumnInt64("created_at", createdAts),
|
||||
case "vector":
|
||||
// Handle vector fields
|
||||
vectors := make([][]float32, len(docs))
|
||||
for i, doc := range docs {
|
||||
vectors[i] = doc.Vector
|
||||
}
|
||||
columns = append(columns, entity.NewColumnFloatVector(field.RawName, len(vectors[0]), vectors))
|
||||
case "metadata":
|
||||
// Handle JSON type fields (like metadata)
|
||||
values := make([][]byte, len(docs))
|
||||
for i, doc := range docs {
|
||||
// Serialize metadata
|
||||
metadataBytes, err := json.Marshal(doc.Metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal metadata for doc %s: %w", doc.ID, err)
|
||||
}
|
||||
values[i] = metadataBytes
|
||||
}
|
||||
columns = append(columns, entity.NewColumnJSONBytes(field.RawName, values))
|
||||
case "created_at":
|
||||
// Handle integer type fields
|
||||
values := make([]int64, len(docs))
|
||||
for i, doc := range docs {
|
||||
values[i] = doc.CreatedAt.UnixMilli()
|
||||
}
|
||||
columns = append(columns, entity.NewColumnInt64(field.RawName, values))
|
||||
}
|
||||
}
|
||||
|
||||
// Insert data
|
||||
_, err := m.client.Insert(ctx, m.Collection, "", columns...)
|
||||
_, err = m.client.Insert(ctx, m.collection, "", columns...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert documents: %w", err)
|
||||
}
|
||||
|
||||
// Flush data
|
||||
err = m.client.Flush(ctx, m.Collection, false)
|
||||
err = m.client.Flush(ctx, m.collection, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to flush collection: %w", err)
|
||||
}
|
||||
@@ -271,16 +501,19 @@ func (m *MilvusProvider) AddDoc(ctx context.Context, docs []schema.Document) err
|
||||
|
||||
// DeleteDoc deletes a document by its ID
|
||||
func (m *MilvusProvider) DeleteDoc(ctx context.Context, id string) error {
|
||||
// Build delete expression
|
||||
expr := fmt.Sprintf(`id == "%s"`, id)
|
||||
// Get ID field
|
||||
idField, _ := m.mapper.GetIDField()
|
||||
// Build delete expression using the RawName of ID field
|
||||
expr := fmt.Sprintf(`%s == "%s"`, idField.RawName, id)
|
||||
|
||||
// Delete data
|
||||
err := m.client.Delete(ctx, m.Collection, "", expr)
|
||||
err := m.client.Delete(ctx, m.collection, "", expr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete documents for id %s: %w", id, err)
|
||||
}
|
||||
|
||||
// Flush data
|
||||
err = m.client.Flush(ctx, m.Collection, false)
|
||||
err = m.client.Flush(ctx, m.collection, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to flush collection after delete: %w", err)
|
||||
}
|
||||
@@ -306,24 +539,127 @@ func (m *MilvusProvider) UpdateDoc(ctx context.Context, docs []schema.Document)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MilvusProvider) buildSearchParam() (entity.SearchParam, error) {
|
||||
// Get index configuration
|
||||
indexConfig, err := m.mapper.GetIndexConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get index config: %w", err)
|
||||
}
|
||||
|
||||
// Get search configuration
|
||||
searchConfig, err := m.mapper.GetSearchConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get search config: %w", err)
|
||||
}
|
||||
|
||||
// Choose appropriate search parameters based on index type
|
||||
milvusIndexType := strings.ToUpper(indexConfig.IndexType)
|
||||
if milvusIndexType == "" {
|
||||
milvusIndexType = "HNSW" // Default to HNSW index
|
||||
}
|
||||
|
||||
switch milvusIndexType {
|
||||
case "FLAT":
|
||||
// FLAT and BIN_FLAT indices don't need additional search parameters
|
||||
return entity.NewIndexFlatSearchParam()
|
||||
|
||||
case "BIN_FLAT", "IVF_FLAT", "BIN_IVF_FLAT", "IVF_SQ8":
|
||||
// Search parameters for IVF series indices
|
||||
nprobe := 16 // Default value
|
||||
if nprobeVal, err := searchConfig.ParamsFloat64("nprobe"); err == nil {
|
||||
nprobe = int(nprobeVal)
|
||||
}
|
||||
return entity.NewIndexIvfFlatSearchParam(nprobe)
|
||||
|
||||
case "IVF_PQ":
|
||||
// Search parameters for IVF_PQ index
|
||||
nprobe := 16 // Default value
|
||||
if nprobeVal, err := searchConfig.ParamsFloat64("nprobe"); err == nil {
|
||||
nprobe = int(nprobeVal)
|
||||
}
|
||||
return entity.NewIndexIvfPQSearchParam(nprobe)
|
||||
|
||||
case "HNSW":
|
||||
// Search parameters for HNSW index
|
||||
efSearch := 16 // Default value
|
||||
if efSearchVal, err := searchConfig.ParamsFloat64("ef"); err == nil {
|
||||
efSearch = int(efSearchVal)
|
||||
}
|
||||
return entity.NewIndexHNSWSearchParam(efSearch)
|
||||
|
||||
case "IVF_HNSW":
|
||||
// Search parameters for IVF_HNSW index
|
||||
nprobe := 16 // Default value
|
||||
efSearch := 64 // Default value
|
||||
if nprobeVal, err := searchConfig.ParamsFloat64("nprobe"); err == nil {
|
||||
nprobe = int(nprobeVal)
|
||||
}
|
||||
if efSearchVal, err := searchConfig.ParamsFloat64("ef"); err == nil {
|
||||
efSearch = int(efSearchVal)
|
||||
}
|
||||
return entity.NewIndexIvfHNSWSearchParam(nprobe, efSearch)
|
||||
|
||||
case "SCANN":
|
||||
// Search parameters for SCANN index
|
||||
nprobe := 16 // Default value
|
||||
reorder_k := 64
|
||||
if nprobeVal, err := searchConfig.ParamsFloat64("nprobe"); err == nil {
|
||||
nprobe = int(nprobeVal)
|
||||
}
|
||||
if reorderKVal, err := searchConfig.ParamsInt64("reorder_k"); err == nil {
|
||||
reorder_k = int(reorderKVal)
|
||||
}
|
||||
return entity.NewIndexSCANNSearchParam(nprobe, reorder_k)
|
||||
|
||||
case "DISKANN":
|
||||
// Search parameters for DISKANN index
|
||||
search_list := 100 // Default value
|
||||
if searchListVal, err := searchConfig.ParamsInt64("search_list"); err == nil {
|
||||
search_list = int(searchListVal)
|
||||
}
|
||||
return entity.NewIndexDISKANNSearchParam(search_list)
|
||||
|
||||
case "AUTOINDEX":
|
||||
level := 8
|
||||
if levelVal, err := searchConfig.ParamsInt64("level"); err == nil {
|
||||
level = int(levelVal)
|
||||
}
|
||||
// Search parameters for AUTOINDEX index
|
||||
return entity.NewIndexAUTOINDEXSearchParam(level)
|
||||
default:
|
||||
// Default to using HNSW search parameters
|
||||
return entity.NewIndexHNSWSearchParam(16)
|
||||
}
|
||||
}
|
||||
|
||||
// SearchDocs performs similarity search for documents
|
||||
func (m *MilvusProvider) SearchDocs(ctx context.Context, vector []float32, options *schema.SearchOptions) ([]schema.SearchResult, error) {
|
||||
if options == nil {
|
||||
options = &schema.SearchOptions{TopK: 10}
|
||||
}
|
||||
|
||||
// Build search parameters
|
||||
sp, _ := entity.NewIndexHNSWSearchParam(16)
|
||||
sp, err := m.buildSearchParam()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build search param: %w", err)
|
||||
}
|
||||
|
||||
outputFields, _ := m.mapper.GetRawAllFieldNames()
|
||||
vectorField, _ := m.mapper.GetVectorField()
|
||||
searchConfig, _ := m.mapper.GetSearchConfig()
|
||||
metricType := m.GetMetricType(searchConfig.MetricType)
|
||||
|
||||
// Build filter expression
|
||||
expr := ""
|
||||
searchResults, err := m.client.Search(
|
||||
ctx,
|
||||
m.Collection,
|
||||
[]string{}, // partition names
|
||||
expr, // filter expression
|
||||
[]string{"id", "content", "metadata", "created_at"}, // output fields
|
||||
m.collection,
|
||||
[]string{}, // partition names
|
||||
expr, // filter expression
|
||||
outputFields, // output fields
|
||||
[]entity.Vector{entity.FloatVector(vector)},
|
||||
"vector", // anns_field
|
||||
entity.IP, // metric_type
|
||||
vectorField.RawName, // anns_field
|
||||
metricType, // metric_type
|
||||
options.TopK,
|
||||
sp,
|
||||
)
|
||||
@@ -341,9 +677,13 @@ func (m *MilvusProvider) SearchDocs(ctx context.Context, vector []float32, optio
|
||||
// Get field data
|
||||
var content string
|
||||
var metadata map[string]interface{}
|
||||
|
||||
for _, field := range result.Fields {
|
||||
switch field.Name() {
|
||||
fieldMapping, err := m.mapper.GetField(field.Name())
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
fieldName := strings.ToLower(fieldMapping.StandardName)
|
||||
switch fieldName {
|
||||
case "content":
|
||||
if contentCol, ok := field.(*entity.ColumnVarChar); ok {
|
||||
if contentVal, err := contentCol.Get(i); err == nil {
|
||||
@@ -364,7 +704,6 @@ func (m *MilvusProvider) SearchDocs(ctx context.Context, vector []float32, optio
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
searchResult := schema.SearchResult{
|
||||
Document: schema.Document{
|
||||
ID: fmt.Sprintf("%s", id),
|
||||
@@ -392,15 +731,17 @@ func (m *MilvusProvider) DeleteDocs(ctx context.Context, ids []string) error {
|
||||
for i, id := range ids {
|
||||
quotedIDs[i] = fmt.Sprintf("\"%s\"", id)
|
||||
}
|
||||
expr := fmt.Sprintf("id in [%s]", strings.Join(quotedIDs, ","))
|
||||
|
||||
idField, _ := m.mapper.GetIDField()
|
||||
expr := fmt.Sprintf("%s in [%s]", idField.RawName, strings.Join(quotedIDs, ","))
|
||||
|
||||
// Delete data
|
||||
err := m.client.Delete(ctx, m.Collection, "", expr)
|
||||
err := m.client.Delete(ctx, m.collection, "", expr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete documents: %w", err)
|
||||
}
|
||||
// Flush data
|
||||
err = m.client.Flush(ctx, m.Collection, false)
|
||||
err = m.client.Flush(ctx, m.collection, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to flush collection after delete: %w", err)
|
||||
}
|
||||
@@ -413,12 +754,13 @@ func (m *MilvusProvider) ListDocs(ctx context.Context, limit int) ([]schema.Docu
|
||||
// Build query expression
|
||||
expr := ""
|
||||
// Query all relevant documents
|
||||
outputFields, _ := m.mapper.GetRawAllFieldNames()
|
||||
queryResult, err := m.client.Query(
|
||||
ctx,
|
||||
m.Collection,
|
||||
m.collection,
|
||||
[]string{}, // partitions
|
||||
expr, // filter condition
|
||||
[]string{"id", "content", "metadata", "created_at"},
|
||||
outputFields,
|
||||
client.WithOffset(0), client.WithLimit(int64(limit)),
|
||||
)
|
||||
|
||||
@@ -443,7 +785,12 @@ func (m *MilvusProvider) ListDocs(ctx context.Context, limit int) ([]schema.Docu
|
||||
)
|
||||
|
||||
for _, col := range queryResult {
|
||||
switch col.Name() {
|
||||
fieldMapping, err := m.mapper.GetField(col.Name())
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
fieldName := strings.ToLower(fieldMapping.StandardName)
|
||||
switch fieldName {
|
||||
case "id":
|
||||
if v, err := col.(*entity.ColumnVarChar).Get(i); err == nil {
|
||||
id = v.(string)
|
||||
@@ -488,8 +835,3 @@ func (m *MilvusProvider) Close() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// joinStrings joins a slice of strings with the given separator
|
||||
func joinStrings(elems []string, sep string) string {
|
||||
return strings.Join(elems, sep)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user