add vectordb mapping (#2968)

This commit is contained in:
Jun
2025-10-06 15:08:13 +08:00
committed by GitHub
parent 45a11734bd
commit aebe354055
14 changed files with 1188 additions and 564 deletions

View File

@@ -55,16 +55,21 @@ require (
github.com/milvus-io/milvus-sdk-go/v2 v2.4.2 // indirect github.com/milvus-io/milvus-sdk-go/v2 v2.4.2 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/openai/openai-go/v2 v2.7.0 // indirect
github.com/orcaman/concurrent-map v0.0.0-20210501183033-44dafcb38ecc // indirect github.com/orcaman/concurrent-map v0.0.0-20210501183033-44dafcb38ecc // indirect
github.com/pkoukk/tiktoken-go v0.1.8 // indirect github.com/pkoukk/tiktoken-go v0.1.8 // indirect
github.com/prometheus/client_golang v1.14.0 // indirect github.com/prometheus/client_golang v1.14.0 // indirect
github.com/prometheus/client_model v0.4.0 // indirect github.com/prometheus/client_model v0.4.0 // indirect
github.com/prometheus/common v0.37.0 // indirect github.com/prometheus/common v0.37.0 // indirect
github.com/prometheus/procfs v0.8.0 // indirect github.com/prometheus/procfs v0.8.0 // indirect
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.2.0 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/tjfoc/gmsm v1.4.1 // indirect github.com/tjfoc/gmsm v1.4.1 // indirect
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect go.uber.org/zap v1.27.0 // indirect
golang.org/x/net v0.33.0 // indirect golang.org/x/net v0.34.0 // indirect
golang.org/x/time v0.3.0 // indirect golang.org/x/time v0.3.0 // indirect
google.golang.org/grpc v1.59.0 // indirect google.golang.org/grpc v1.59.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect
@@ -99,9 +104,9 @@ require (
github.com/shopspring/decimal v1.4.0 // indirect github.com/shopspring/decimal v1.4.0 // indirect
go.opentelemetry.io/otel v1.26.0 // indirect go.opentelemetry.io/otel v1.26.0 // indirect
go.opentelemetry.io/otel/trace v1.26.0 // indirect go.opentelemetry.io/otel/trace v1.26.0 // indirect
golang.org/x/crypto v0.31.0 // indirect golang.org/x/crypto v0.32.0 // indirect
golang.org/x/sync v0.10.0 // indirect golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.28.0 // indirect golang.org/x/sys v0.29.0 // indirect
golang.org/x/text v0.21.0 // indirect golang.org/x/text v0.21.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20230822172742-b8732ec3820d // indirect google.golang.org/genproto/googleapis/api v0.0.0-20230822172742-b8732ec3820d // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 // indirect

View File

@@ -311,6 +311,8 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/gomega v1.24.2 h1:J/tulyYK6JwBldPViHJReihxxZ+22FHs0piGjQAvoUE= github.com/onsi/gomega v1.24.2 h1:J/tulyYK6JwBldPViHJReihxxZ+22FHs0piGjQAvoUE=
github.com/onsi/gomega v1.24.2/go.mod h1:gs3J10IS7Z7r7eXRoNJIrNqU4ToQukCJhFtKrWgHWnk= github.com/onsi/gomega v1.24.2/go.mod h1:gs3J10IS7Z7r7eXRoNJIrNqU4ToQukCJhFtKrWgHWnk=
github.com/openai/openai-go/v2 v2.7.0 h1:/8MSFCXcasin7AyuWQ2au6FraXL71gzAs+VfbMv+J3k=
github.com/openai/openai-go/v2 v2.7.0/go.mod h1:jrJs23apqJKKbT+pqtFgNKpRju/KP9zpUTZhz3GElQE=
github.com/orcaman/concurrent-map v0.0.0-20210501183033-44dafcb38ecc h1:Ak86L+yDSOzKFa7WM5bf5itSOo1e3Xh8bm5YCMUXIjQ= github.com/orcaman/concurrent-map v0.0.0-20210501183033-44dafcb38ecc h1:Ak86L+yDSOzKFa7WM5bf5itSOo1e3Xh8bm5YCMUXIjQ=
github.com/orcaman/concurrent-map v0.0.0-20210501183033-44dafcb38ecc/go.mod h1:Lu3tH6HLW3feq74c2GC+jIMS/K2CFcDWnWD9XkenwhI= github.com/orcaman/concurrent-map v0.0.0-20210501183033-44dafcb38ecc/go.mod h1:Lu3tH6HLW3feq74c2GC+jIMS/K2CFcDWnWD9XkenwhI=
github.com/paulmach/orb v0.11.1 h1:3koVegMC4X/WeiXYz9iswopaTwMem53NzTJuTF20JzU= github.com/paulmach/orb v0.11.1 h1:3koVegMC4X/WeiXYz9iswopaTwMem53NzTJuTF20JzU=
@@ -377,7 +379,18 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM=
github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/tjfoc/gmsm v1.3.2/go.mod h1:HaUcFuY0auTiaHB9MHFGCPx5IaLhTUd2atbCFBQXn9w= github.com/tjfoc/gmsm v1.3.2/go.mod h1:HaUcFuY0auTiaHB9MHFGCPx5IaLhTUd2atbCFBQXn9w=
github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho= github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho=
github.com/tjfoc/gmsm v1.4.1/go.mod h1:j4INPkHWMrhJb38G+J6W4Tw0AbuN8Thu3PbdVYhVcTE= github.com/tjfoc/gmsm v1.4.1/go.mod h1:j4INPkHWMrhJb38G+J6W4Tw0AbuN8Thu3PbdVYhVcTE=
@@ -426,6 +439,8 @@ golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDf
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -506,6 +521,8 @@ golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@@ -581,6 +598,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=

View File

@@ -84,10 +84,11 @@ Higress RAG MCP Server 提供以下工具,根据配置不同,可用工具也
| llm.max_tokens | integer | 可选 | 2048 | 最大令牌数 | | llm.max_tokens | integer | 可选 | 2048 | 最大令牌数 |
| llm.temperature | float | 可选 | 0.5 | 温度参数 | | llm.temperature | float | 可选 | 0.5 | 温度参数 |
| **embedding** | object | 必填 | - | 嵌入配置(所有工具必需) | | **embedding** | object | 必填 | - | 嵌入配置(所有工具必需) |
| embedding.provider | string | 必填 | dashscope | 嵌入提供商openai或dashscope | | embedding.provider | string | 必填 | openai | 嵌入提供商:支持openai协议的任意供应商 |
| embedding.api_key | string | 必填 | - | 嵌入API密钥 | | embedding.api_key | string | 必填 | - | 嵌入API密钥 |
| embedding.base_url | string | 可选 | | 嵌入API基础URL | | embedding.base_url | string | 可选 | | 嵌入API基础URL |
| embedding.model | string | 必填 | text-embedding-v4 | 嵌入模型名称 | | embedding.model | string | 必填 | text-embedding-ada-002 | 嵌入模型名称 |
| embedding.dimensions | integer | 可选 | 1536 | 嵌入维度 |
| **vectordb** | object | 必填 | - | 向量数据库配置(所有工具必需) | | **vectordb** | object | 必填 | - | 向量数据库配置(所有工具必需) |
| vectordb.provider | string | 必填 | milvus | 向量数据库提供商 | | vectordb.provider | string | 必填 | milvus | 向量数据库提供商 |
| vectordb.host | string | 必填 | localhost | 数据库主机地址 | | vectordb.host | string | 必填 | localhost | 数据库主机地址 |
@@ -96,6 +97,17 @@ Higress RAG MCP Server 提供以下工具,根据配置不同,可用工具也
| vectordb.collection | string | 必填 | test_collection | 集合名称 | | vectordb.collection | string | 必填 | test_collection | 集合名称 |
| vectordb.username | string | 可选 | - | 数据库用户名 | | vectordb.username | string | 可选 | - | 数据库用户名 |
| vectordb.password | string | 可选 | - | 数据库密码 | | vectordb.password | string | 可选 | - | 数据库密码 |
| **vectordb.mapping** | object | 可选 | - | 字段映射配置 |
| vectordb.mapping.fields | array | 可选 | - | 字段映射列表 |
| vectordb.mapping.fields[].standard_name | string | 必填 | - | 标准字段名称(如 id, content, vector 等) |
| vectordb.mapping.fields[].raw_name | string | 必填 | - | 原始字段名称(数据库中的实际字段名) |
| vectordb.mapping.fields[].properties | object | 可选 | - | 字段属性(如 auto_id, max_length 等) |
| vectordb.mapping.index | object | 可选 | - | 索引配置 |
| vectordb.mapping.index.index_type | string | 必填 | - | 索引类型(如 FLAT, IVF_FLAT, HNSW 等) |
| vectordb.mapping.index.params | object | 可选 | - | 索引参数(根据索引类型不同而异) |
| vectordb.mapping.search | object | 可选 | - | 搜索配置 |
| vectordb.mapping.search.metric_type | string | 可选 | L2 | 度量类型(如 L2, IP, COSINE 等) |
| vectordb.mapping.search.params | object | 可选 | - | 搜索参数(如 nprobe, ef_search 等)
### higress-config 配置样例 ### higress-config 配置样例
@@ -143,27 +155,54 @@ data:
temperature: 0.5 temperature: 0.5
max_tokens: 2048 max_tokens: 2048
embedding: embedding:
provider: dashscope provider: openai
base_url: https://dashscope.aliyuncs.com/compatible-mode/v1
api_key: sk-xxx api_key: sk-xxx
model: text-embedding-v4 model: text-embedding-v4
dimensions: 1536
vectordb: vectordb:
provider: milvus provider: milvus
host: <milvus IP> host: localhost
port: 19530 port: 19530
database: default database: default
collection: test_collection collection: test_rag
``` mapping:
fields:
- standard_name: id
raw_name: id
properties:
auto_id: false
max_length: 256
- standard_name: content
raw_name: content
properties:
max_length: 8192
- standard_name: vector
raw_name: vector
- standard_name: metadata
raw_name: metadata
- standard_name: created_at
raw_name: created_at
index:
index_type: HNSW
params:
M: 4
efConstruction: 32
search:
metric_type: IP
params:
ef: 32
```
### 支持的提供商 ### 支持的提供商
#### Embedding #### Embedding
- **OpenAI** - **OpenAI 兼容**
- **DashScope**
#### Vector Database #### Vector Database
- **Milvus** - **Milvus**
#### LLM #### LLM
- **OpenAI** - **OpenAI 兼容**
## 如何测试数据集的效果 ## 如何测试数据集的效果

View File

@@ -1,5 +1,7 @@
package config package config
import "fmt"
// Config represents the main configuration structure for the MCP server // Config represents the main configuration structure for the MCP server
type Config struct { type Config struct {
RAG RAGConfig `json:"rag" yaml:"rag"` RAG RAGConfig `json:"rag" yaml:"rag"`
@@ -38,7 +40,7 @@ type EmbeddingConfig struct {
APIKey string `json:"api_key,omitempty" yaml:"api_key,omitempty"` APIKey string `json:"api_key,omitempty" yaml:"api_key,omitempty"`
BaseURL string `json:"base_url,omitempty" yaml:"base_url,omitempty"` BaseURL string `json:"base_url,omitempty" yaml:"base_url,omitempty"`
Model string `json:"model,omitempty" yaml:"model,omitempty"` Model string `json:"model,omitempty" yaml:"model,omitempty"`
Dimension int `json:"dimension,omitempty" yaml:"dimension,omitempty"` Dimensions int `json:"dimensions,omitempty" yaml:"dimension,omitempty"`
} }
// VectorDBConfig defines configuration for vector databases // VectorDBConfig defines configuration for vector databases
@@ -50,4 +52,132 @@ type VectorDBConfig struct {
Collection string `json:"collection,omitempty" yaml:"collection,omitempty"` Collection string `json:"collection,omitempty" yaml:"collection,omitempty"`
Username string `json:"username,omitempty" yaml:"username,omitempty"` Username string `json:"username,omitempty" yaml:"username,omitempty"`
Password string `json:"password,omitempty" yaml:"password,omitempty"` Password string `json:"password,omitempty" yaml:"password,omitempty"`
Mapping MappingConfig `json:"mapping,omitempty" yaml:"mapping,omitempty"`
}
// MappingConfig defines field mapping configuration for vector databases
type MappingConfig struct {
Fields []FieldMapping `json:"fields,omitempty" yaml:"fields,omitempty"`
Index IndexConfig `json:"index,omitempty" yaml:"index,omitempty"`
Search SearchConfig `json:"search,omitempty" yaml:"search,omitempty"`
}
// // CollectionMapping defines field mapping for collection
// type CollectionMapping struct {
// Fields []FieldMapping `json:"fields,omitempty" yaml:"fields,omitempty"`
// }
// FieldMapping defines mapping for a single field
type FieldMapping struct {
StandardName string `json:"standard_name" yaml:"standard_name"`
RawName string `json:"raw_name" yaml:"raw_name"`
Properties map[string]interface{} `json:"properties,omitempty" yaml:"properties,omitempty"`
}
func (f FieldMapping) IsPrimaryKey() bool {
return f.StandardName == "id"
}
func (f FieldMapping) IsAutoID() bool {
if f.Properties == nil {
return false
}
autoID, ok := f.Properties["auto_id"].(bool)
if !ok {
return false
}
return autoID
}
func (f FieldMapping) IsVectorField() bool {
return f.StandardName == "vector"
}
func (f FieldMapping) MaxLength() int {
if f.Properties == nil {
return 0
}
maxLength, ok := f.Properties["max_length"].(int)
if !ok {
return 256
}
return maxLength
}
// IndexConfig defines configuration for index parameters
type IndexConfig struct {
// Index type, e.g., IVF_FLAT, IVF_SQ8, HNSW, etc.
IndexType string `json:"index_type" yaml:"index_type"`
// Index parameter configuration
Params map[string]interface{} `json:"params" yaml:"params"`
}
func (i IndexConfig) ParamsString(key string) (string, error) {
if mVal, ok := i.Params[key].(string); ok {
return mVal, nil
}
return "", fmt.Errorf("params %s not found", key)
}
func (i IndexConfig) ParamsInt64(key string) (int64, error) {
if mVal, ok := i.Params[key].(int64); ok {
return mVal, nil
}
if mVal, ok := i.Params[key].(int); ok {
return int64(mVal), nil
}
return 0, fmt.Errorf("params %s not found", key)
}
func (i IndexConfig) ParamsFloat64(key string) (float64, error) {
if mVal, ok := i.Params[key].(float64); ok {
return mVal, nil
}
if mVal, ok := i.Params[key].(float32); ok {
return float64(mVal), nil
}
return 0, fmt.Errorf("params %s not found", key)
}
func (i IndexConfig) ParamsBool(key string) (bool, error) {
if mVal, ok := i.Params[key].(bool); ok {
return mVal, nil
}
return false, fmt.Errorf("params %s not found", key)
}
// SearchConfig defines configuration for search parameters
type SearchConfig struct {
// Metric type, e.g., L2, IP, etc.
MetricType string `json:"metric_type,omitempty" yaml:"metric_type,omitempty"`
// Search parameter configuration
Params map[string]interface{} `json:"params" yaml:"params"`
}
func (i SearchConfig) ParamsString(key string) (string, error) {
if mVal, ok := i.Params[key].(string); ok {
return mVal, nil
}
return "", fmt.Errorf("params %s not found", key)
}
func (i SearchConfig) ParamsInt64(key string) (int64, error) {
if mVal, ok := i.Params[key].(int64); ok {
return mVal, nil
}
return 0, fmt.Errorf("params %s not found", key)
}
func (i SearchConfig) ParamsFloat64(key string) (float64, error) {
if mVal, ok := i.Params[key].(float64); ok {
return mVal, nil
}
return 0, fmt.Errorf("params %s not found", key)
}
func (i SearchConfig) ParamsBool(key string) (bool, error) {
if mVal, ok := i.Params[key].(bool); ok {
return mVal, nil
}
return false, fmt.Errorf("params %s not found", key)
} }

View File

@@ -1,169 +0,0 @@
package embedding
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/common"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
)
const (
DASHSCOPE_DOMAIN = "dashscope.aliyuncs.com"
DASHSCOPE_PORT = 443
DASHSCOPE_DEFAULT_MODEL_NAME = "text-embedding-v4"
DASHSCOPE_ENDPOINT = "/api/v1/services/embeddings/text-embedding/text-embedding"
)
var dashScopeConfig dashScopeProviderConfig
type dashScopeProviderInitializer struct {
}
type dashScopeProviderConfig struct {
apiKey string
model string
}
func (c *dashScopeProviderInitializer) InitConfig(config config.EmbeddingConfig) {
dashScopeConfig.apiKey = config.APIKey
dashScopeConfig.model = config.Model
}
func (c *dashScopeProviderInitializer) ValidateConfig() error {
if dashScopeConfig.apiKey == "" {
return errors.New("[DashScope] apiKey is required")
}
return nil
}
func (c *dashScopeProviderInitializer) CreateProvider(config config.EmbeddingConfig) (Provider, error) {
c.InitConfig(config)
err := c.ValidateConfig()
if err != nil {
return nil, err
}
headers := map[string]string{
"Authorization": "Bearer " + config.APIKey,
"Content-Type": "application/json",
}
httpClient := common.NewHTTPClient(fmt.Sprintf("https://%s", DASHSCOPE_DOMAIN), headers)
return &DashScopeProvider{
config: dashScopeConfig,
client: httpClient,
}, nil
}
func (d *DashScopeProvider) GetProviderType() string {
return PROVIDER_TYPE_DASHSCOPE
}
type Embedding struct {
Embedding []float32 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type Input struct {
Texts []string `json:"texts"`
}
type Params struct {
TextType string `json:"text_type"`
}
type Response struct {
RequestID string `json:"request_id"`
Output Output `json:"output"`
Usage Usage `json:"usage"`
}
type Output struct {
Embeddings []Embedding `json:"embeddings"`
}
type Usage struct {
TotalTokens int `json:"total_tokens"`
}
type EmbeddingRequest struct {
Model string `json:"model"`
Input Input `json:"input"`
Parameters Params `json:"parameters"`
}
type Document struct {
Vector []float64 `json:"vector"`
Fields map[string]string `json:"fields"`
}
type DashScopeProvider struct {
config dashScopeProviderConfig
client *common.HTTPClient
}
func (d *DashScopeProvider) constructRequestData(texts []string) (EmbeddingRequest, error) {
model := d.config.model
if model == "" {
model = DASHSCOPE_DEFAULT_MODEL_NAME
}
if dashScopeConfig.apiKey == "" {
return EmbeddingRequest{}, errors.New("dashScopeKey is empty")
}
data := EmbeddingRequest{
Model: model,
Input: Input{
Texts: texts,
},
Parameters: Params{
TextType: "query",
},
}
return data, nil
}
type Result struct {
ID string `json:"id"`
Vector []float32 `json:"vector,omitempty"`
Fields map[string]interface{} `json:"fields"`
Score float64 `json:"score"`
}
func (d *DashScopeProvider) parseTextEmbedding(responseBody []byte) (*Response, error) {
var resp Response
err := json.Unmarshal(responseBody, &resp)
if err != nil {
return nil, err
}
return &resp, nil
}
func (d *DashScopeProvider) GetEmbedding(
ctx context.Context,
queryString string) ([]float32, error) {
requestData, err := d.constructRequestData([]string{queryString})
if err != nil {
return nil, fmt.Errorf("failed to construct request data: %v", err)
}
responseBody, err := d.client.Post(DASHSCOPE_ENDPOINT, requestData)
if err != nil {
return nil, fmt.Errorf("failed to send request: %v", err)
}
embeddingResp, err := d.parseTextEmbedding(responseBody)
if err != nil {
return nil, fmt.Errorf("failed to parse response: %v", err)
}
if len(embeddingResp.Output.Embeddings) == 0 {
return nil, errors.New("no embedding found in response")
}
return embeddingResp.Output.Embeddings[0].Embedding, nil
}

View File

@@ -2,160 +2,93 @@ package embedding
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/common"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config" "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
"github.com/openai/openai-go/v2"
"github.com/openai/openai-go/v2/option"
) )
const ( const (
OPENAI_DOMAIN = "api.openai.com" OPENAI_DEFAULT_MODEL_NAME = "text-embedding-ada-002"
OPENAI_PORT = 443
OPENAI_DEFAULT_MODEL_NAME = "text-embedding-3-small"
OPENAI_ENDPOINT = "/v1/embeddings"
) )
type openAIProviderInitializer struct { type openAIProviderInitializer struct {
} }
var openAIConfig openAIProviderConfig func (c *openAIProviderInitializer) validateConfig(config *config.EmbeddingConfig) error {
if config.APIKey == "" {
type openAIProviderConfig struct { return errors.New("[openai embbeding] apiKey is required")
baseUrl string }
apiKey string if config.Model == "" {
model string config.Model = OPENAI_DEFAULT_MODEL_NAME
}
if config.Dimensions <= 0 {
config.Dimensions = 1536
} }
func (c *openAIProviderInitializer) InitConfig(config config.EmbeddingConfig) {
openAIConfig.apiKey = config.APIKey
openAIConfig.model = config.Model
openAIConfig.baseUrl = config.BaseURL
}
func (c *openAIProviderInitializer) ValidateConfig() error {
if openAIConfig.apiKey == "" {
return errors.New("[openAI] apiKey is required")
}
return nil return nil
} }
func (c *openAIProviderInitializer) CreateProvider(config config.EmbeddingConfig) (Provider, error) { func (c *openAIProviderInitializer) CreateProvider(config config.EmbeddingConfig) (Provider, error) {
c.InitConfig(config) if err := c.validateConfig(&config); err != nil {
err := c.ValidateConfig()
if err != nil {
return nil, err return nil, err
} }
// 创建 OpenAI 客户端
var clientOptions []option.RequestOption
clientOptions = append(clientOptions, option.WithAPIKey(config.APIKey))
if openAIConfig.model == "" { // 如果设置了自定义 baseURL则使用它
openAIConfig.model = OPENAI_DEFAULT_MODEL_NAME if config.BaseURL != "" {
clientOptions = append(clientOptions, option.WithBaseURL(config.BaseURL))
} }
// 创建 OpenAI 客户端
if openAIConfig.baseUrl == "" { client := openai.NewClient(clientOptions...)
openAIConfig.baseUrl = fmt.Sprintf("https://%s", OPENAI_DOMAIN)
}
headers := map[string]string{
"Authorization": "Bearer " + config.APIKey,
"Content-Type": "application/json",
}
httpClient := common.NewHTTPClient(openAIConfig.baseUrl, headers)
return &OpenAIProvider{ return &OpenAIProvider{
config: openAIConfig, client: &client,
client: httpClient, model: config.Model,
dimensions: config.Dimensions,
}, nil }, nil
} }
func (o *OpenAIProvider) GetProviderType() string { // EmbeddingClient handles vector embedding generation using OpenAI-compatible APIs
type OpenAIProvider struct {
client *openai.Client
model string
dimensions int
}
func (e *OpenAIProvider) GetProviderType() string {
return PROVIDER_TYPE_OPENAI return PROVIDER_TYPE_OPENAI
} }
type OpenAIResponse struct { // GetEmbedding generates vector embedding for the given text
Object string `json:"object"` func (e *OpenAIProvider) GetEmbedding(ctx context.Context, text string) ([]float32, error) {
Data []OpenAIResult `json:"data"` params := openai.EmbeddingNewParams{
Model string `json:"model"` Model: e.model,
Error *OpenAIError `json:"error"` Input: openai.EmbeddingNewParamsInputUnion{
OfString: openai.String(text),
},
Dimensions: openai.Int(int64(e.dimensions)),
EncodingFormat: openai.EmbeddingNewParamsEncodingFormatFloat,
} }
type OpenAIResult struct { embeddingResp, err := e.client.Embeddings.New(ctx, params)
Object string `json:"object"`
Embedding []float32 `json:"embedding"`
Index int `json:"index"`
}
type OpenAIError struct {
Message string `json:"prompt_tokens"`
Type string `json:"type"`
Code string `json:"code"`
Param string `json:"param"`
}
type OpenAIEmbeddingRequest struct {
Input string `json:"input"`
Model string `json:"model"`
}
type OpenAIProvider struct {
config openAIProviderConfig
client *common.HTTPClient
}
func (o *OpenAIProvider) constructRequestData(text string) (OpenAIEmbeddingRequest, error) {
if text == "" {
return OpenAIEmbeddingRequest{}, errors.New("queryString text cannot be empty")
}
if openAIConfig.apiKey == "" {
return OpenAIEmbeddingRequest{}, errors.New("openAI apiKey is empty")
}
model := o.config.model
if model == "" {
model = OPENAI_DEFAULT_MODEL_NAME
}
data := OpenAIEmbeddingRequest{
Input: text,
Model: model,
}
return data, nil
}
func (o *OpenAIProvider) parseTextEmbedding(responseBody []byte) (*OpenAIResponse, error) {
var resp OpenAIResponse
err := json.Unmarshal(responseBody, &resp)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to generate embedding: %w", err)
}
return &resp, nil
} }
func (o *OpenAIProvider) GetEmbedding(ctx context.Context, queryString string) ([]float32, error) { if len(embeddingResp.Data) == 0 {
requestData, err := o.constructRequestData(queryString) return nil, fmt.Errorf("empty embedding response")
if err != nil {
return nil, fmt.Errorf("failed to construct request data: %v", err)
} }
responseBody, err := o.client.Post(OPENAI_ENDPOINT, requestData) // Convert []float64 to []float32
if err != nil { embedding := make([]float32, len(embeddingResp.Data[0].Embedding))
return nil, fmt.Errorf("failed to send request: %v", err) for i, v := range embeddingResp.Data[0].Embedding {
embedding[i] = float32(v)
} }
resp, err := o.parseTextEmbedding(responseBody) return embedding, nil
if err != nil {
return nil, fmt.Errorf("failed to parse response: %v", err)
}
if resp.Error != nil {
return nil, fmt.Errorf("OpenAI API error: %s - %s", resp.Error.Type, resp.Error.Message)
}
if len(resp.Data) == 0 {
return nil, errors.New("no embedding found in response")
}
return resp.Data[0].Embedding, nil
} }

View File

@@ -36,7 +36,6 @@ type providerInitializer interface {
// Maps provider types to their initializers // Maps provider types to their initializers
var ( var (
providerInitializers = map[string]providerInitializer{ providerInitializers = map[string]providerInitializer{
PROVIDER_TYPE_DASHSCOPE: &dashScopeProviderInitializer{},
PROVIDER_TYPE_OPENAI: &openAIProviderInitializer{}, PROVIDER_TYPE_OPENAI: &openAIProviderInitializer{},
} }
) )

View File

@@ -2,133 +2,105 @@ package llm
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/common"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config" "github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
"github.com/openai/openai-go/v2"
"github.com/openai/openai-go/v2/option"
"github.com/openai/openai-go/v2/packages/param"
) )
const ( const (
OPENAI_CHAT_ENDPOINT = "/chat/completions"
OPENAI_DEFAULT_MODEL = "gpt-4o" OPENAI_DEFAULT_MODEL = "gpt-4o"
) )
// openAI specific configuration captured after initialization. type OpenAIProvider struct {
type openAIProviderConfig struct { client *openai.Client
apiKey string
baseURL string
model string model string
maxTokens int
temperature float64 temperature float64
maxTokens int
} }
type openAIProviderInitializer struct{} type openAIProviderInitializer struct{}
var openAIConfig openAIProviderConfig func (i *openAIProviderInitializer) validateConfig(cfg *config.LLMConfig) error {
if cfg.APIKey == "" {
func (i *openAIProviderInitializer) initConfig(c config.LLMConfig) {
openAIConfig.apiKey = c.APIKey
openAIConfig.baseURL = c.BaseURL
openAIConfig.model = c.Model
if openAIConfig.model == "" {
openAIConfig.model = OPENAI_DEFAULT_MODEL
}
if openAIConfig.baseURL == "" {
openAIConfig.baseURL = "https://api.openai.com/v1" // default public endpoint
}
openAIConfig.maxTokens = c.MaxTokens
openAIConfig.temperature = c.Temperature
}
func (i *openAIProviderInitializer) validateConfig() error {
if openAIConfig.apiKey == "" {
return errors.New("[openai llm] apiKey is required") return errors.New("[openai llm] apiKey is required")
} }
if cfg.Model == "" {
cfg.Model = OPENAI_DEFAULT_MODEL
}
if cfg.Temperature <= 0 || cfg.Temperature > 2 {
cfg.Temperature = 0.5
}
if cfg.MaxTokens <= 0 {
cfg.MaxTokens = 2048
}
return nil return nil
} }
func (i *openAIProviderInitializer) CreateProvider(cfg config.LLMConfig) (Provider, error) { func (i *openAIProviderInitializer) CreateProvider(cfg config.LLMConfig) (Provider, error) {
i.initConfig(cfg) if err := i.validateConfig(&cfg); err != nil {
if err := i.validateConfig(); err != nil {
return nil, err return nil, err
} }
headers := map[string]string{ // Create OpenAI client
"Authorization": "Bearer " + openAIConfig.apiKey, var clientOptions []option.RequestOption
"Content-Type": "application/json", clientOptions = append(clientOptions, option.WithAPIKey(cfg.APIKey))
}
client := common.NewHTTPClient(openAIConfig.baseURL, headers) // If a custom baseURL is set, use it
return &OpenAIProvider{client: client, cfg: openAIConfig}, nil if cfg.BaseURL != "" {
clientOptions = append(clientOptions, option.WithBaseURL(cfg.BaseURL))
} }
type OpenAIProvider struct { // Create OpenAI client
client *common.HTTPClient client := openai.NewClient(clientOptions...)
cfg openAIProviderConfig
}
type openAIChatCompletionRequest struct { return &OpenAIProvider{
Model string `json:"model"` client: &client,
Messages []openAIChatMessage `json:"messages"` model: cfg.Model,
Temperature float64 `json:"temperature,omitempty"` temperature: cfg.Temperature,
MaxTokens int `json:"max_tokens,omitempty"` maxTokens: cfg.MaxTokens,
} }, nil
type openAIChatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type openAIChatCompletionResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Choices []openAIChatCompletionResponseChoice `json:"choices"`
Error *openAIError `json:"error,omitempty"`
}
type openAIChatCompletionResponseChoice struct {
Index int `json:"index"`
Message openAIChatMessage `json:"message"`
FinishReason string `json:"finish_reason"`
}
type openAIError struct {
Message string `json:"message"`
Type string `json:"type"`
Code string `json:"code"`
Param string `json:"param"`
} }
// GenerateCompletion implements Provider interface. // GenerateCompletion implements Provider interface.
func (o *OpenAIProvider) GenerateCompletion(ctx context.Context, prompt string) (string, error) { func (o *OpenAIProvider) GenerateCompletion(ctx context.Context, prompt string) (string, error) {
req := openAIChatCompletionRequest{ // Create chat request
Model: o.cfg.model, params := openai.ChatCompletionNewParams{
Messages: []openAIChatMessage{ Model: o.model,
{Role: "user", Content: prompt}, Messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage(prompt),
}, },
Temperature: o.cfg.temperature,
MaxTokens: o.cfg.maxTokens,
} }
body, err := o.client.Post(OPENAI_CHAT_ENDPOINT, req) // Set optional parameters
if o.temperature > 0 {
temperature := float64(o.temperature)
params.Temperature = param.Opt[float64]{Value: temperature}
}
if o.maxTokens > 0 {
maxTokens := int64(o.maxTokens)
params.MaxTokens = param.Opt[int64]{Value: maxTokens}
}
// Send request
response, err := o.client.Chat.Completions.New(ctx, params)
if err != nil { if err != nil {
return "", fmt.Errorf("openai llm post error: %w", err) // Handle error
return "", fmt.Errorf("openai llm error: %w", err)
} }
var resp openAIChatCompletionResponse // Check response
if err := json.Unmarshal(body, &resp); err != nil { if len(response.Choices) == 0 {
return "", fmt.Errorf("openai llm unmarshal error: %w", err)
}
if resp.Error != nil {
return "", fmt.Errorf("openai llm api error: %s - %s", resp.Error.Type, resp.Error.Message)
}
if len(resp.Choices) == 0 {
return "", errors.New("openai llm: empty choices") return "", errors.New("openai llm: empty choices")
} }
return resp.Choices[0].Message.Content, nil // Return generated content
return response.Choices[0].Message.Content, nil
} }
func (o *OpenAIProvider) GetProviderType() string { func (o *OpenAIProvider) GetProviderType() string {

View File

@@ -56,18 +56,12 @@ func NewRAGClient(config *config.Config) (*RAGClient, error) {
ragclient.llmProvider = llmProvider ragclient.llmProvider = llmProvider
} }
demoVector, err := embeddingProvider.GetEmbedding(context.Background(), "initialization") dim := ragclient.config.Embedding.Dimensions
if err != nil {
return nil, fmt.Errorf("create init embedding failed, err: %w", err)
}
dim := len(demoVector)
provider, err := vectordb.NewVectorDBProvider(&ragclient.config.VectorDB, dim) provider, err := vectordb.NewVectorDBProvider(&ragclient.config.VectorDB, dim)
if err != nil { if err != nil {
return nil, fmt.Errorf("create vector store provider failed, err: %w", err) return nil, fmt.Errorf("create vector store provider failed, err: %w", err)
} }
ragclient.vectordbProvider = provider ragclient.vectordbProvider = provider
return ragclient, nil return ragclient, nil
} }

View File

@@ -22,15 +22,17 @@ func getRAGClient() (*RAGClient, error) {
LLM: config.LLMConfig{ LLM: config.LLMConfig{
Provider: "openai", Provider: "openai",
APIKey: "sk-xxxx", APIKey: "sk-xxx",
BaseURL: "https://openrouter.ai/api/v1", BaseURL: "https://openrouter.ai/api/v1",
Model: "openai/gpt-4o", Model: "openai/gpt-4o",
}, },
Embedding: config.EmbeddingConfig{ Embedding: config.EmbeddingConfig{
Provider: "dashscope", Provider: "openai",
BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1",
APIKey: "sk-xxxx", APIKey: "sk-xxxx",
Model: "text-embedding-v4", Model: "text-embedding-v4",
Dimensions: 1536,
}, },
VectorDB: config.VectorDBConfig{ VectorDB: config.VectorDBConfig{
@@ -38,7 +40,49 @@ func getRAGClient() (*RAGClient, error) {
Host: "localhost", Host: "localhost",
Port: 19530, Port: 19530,
Database: "default", Database: "default",
Collection: "test_collection", Collection: "test_collection3",
Mapping: config.MappingConfig{
Fields: []config.FieldMapping{
{
StandardName: "id",
RawName: "pk",
Properties: map[string]interface{}{
"max_length": 256,
"auto_id": false,
},
},
{
StandardName: "content",
RawName: "page_content",
Properties: map[string]interface{}{
"max_length": 8192,
},
},
{
StandardName: "vector",
RawName: "page_vector",
Properties: make(map[string]interface{}),
},
{
StandardName: "metadata",
RawName: "metadata",
Properties: make(map[string]interface{}),
},
{
StandardName: "created_at",
RawName: "created_at",
Properties: make(map[string]interface{}),
},
},
Index: config.IndexConfig{
IndexType: "IVF_FLAT",
Params: map[string]interface{}{"nlist": 64},
},
Search: config.SearchConfig{
MetricType: "COSINE",
Params: map[string]interface{}{"nprobe": 32},
},
},
}, },
} }
@@ -48,7 +92,6 @@ func getRAGClient() (*RAGClient, error) {
} }
return ragClient, nil return ragClient, nil
} }
func TestNewRAGClient(t *testing.T) { func TestNewRAGClient(t *testing.T) {
@@ -104,7 +147,7 @@ func TestRAGClient_DeleteChunk(t *testing.T) {
return return
} }
chunk_id := "63ee25d7-41b9-4455-8066-075ca5c803b2" chunk_id := "2a06679c-a8ea-46dc-bf1c-7e7b164a73c8"
err = ragClient.DeleteChunk(chunk_id) err = ragClient.DeleteChunk(chunk_id)
if err != nil { if err != nil {
t.Errorf("DeleteChunk() error = %v", err) t.Errorf("DeleteChunk() error = %v", err)

View File

@@ -36,11 +36,11 @@ func init() {
MaxTokens: 2048, MaxTokens: 2048,
}, },
Embedding: config.EmbeddingConfig{ Embedding: config.EmbeddingConfig{
Provider: "dashscope", Provider: "openai",
APIKey: "", APIKey: "",
BaseURL: "", BaseURL: "",
Model: "text-embedding-v4", Model: "text-embedding-ada-002",
Dimension: 1024, Dimensions: 1536,
}, },
VectorDB: config.VectorDBConfig{ VectorDB: config.VectorDBConfig{
Provider: "milvus", Provider: "milvus",
@@ -50,14 +50,56 @@ func init() {
Collection: "rag", Collection: "rag",
Username: "", Username: "",
Password: "", Password: "",
Mapping: config.MappingConfig{
Fields: []config.FieldMapping{
{
StandardName: "id",
RawName: "id",
Properties: map[string]interface{}{
"max_length": 256,
"auto_id": false,
},
},
{
StandardName: "content",
RawName: "content",
Properties: map[string]interface{}{
"max_length": 8192,
},
},
{
StandardName: "vector",
RawName: "vector",
Properties: make(map[string]interface{}),
},
{
StandardName: "metadata",
RawName: "metadata",
Properties: make(map[string]interface{}),
},
{
StandardName: "created_at",
RawName: "created_at",
Properties: make(map[string]interface{}),
},
},
Index: config.IndexConfig{
IndexType: "HNSW",
Params: map[string]interface{}{"M": 8, "efConstruction": 64},
},
Search: config.SearchConfig{
MetricType: "IP",
Params: make(map[string]interface{}),
},
},
}, },
}, },
}) })
} }
func (c *RAGConfig) ParseConfig(config map[string]any) error { func (c *RAGConfig) ParseConfig(cfg map[string]any) error {
// Parse RAG configuration // Parse RAG configuration
if ragConfig, ok := config["rag"].(map[string]any); ok { if ragConfig, ok := cfg["rag"].(map[string]any); ok {
if splitter, exists := ragConfig["splitter"].(map[string]any); exists { if splitter, exists := ragConfig["splitter"].(map[string]any); exists {
if splitterType, exists := splitter["provider"].(string); exists { if splitterType, exists := splitter["provider"].(string); exists {
c.config.RAG.Splitter.Provider = splitterType c.config.RAG.Splitter.Provider = splitterType
@@ -78,7 +120,7 @@ func (c *RAGConfig) ParseConfig(config map[string]any) error {
} }
// Parse Embedding configuration // Parse Embedding configuration
if embeddingConfig, ok := config["embedding"].(map[string]any); ok { if embeddingConfig, ok := cfg["embedding"].(map[string]any); ok {
if provider, exists := embeddingConfig["provider"].(string); exists { if provider, exists := embeddingConfig["provider"].(string); exists {
c.config.Embedding.Provider = provider c.config.Embedding.Provider = provider
} else { } else {
@@ -94,13 +136,13 @@ func (c *RAGConfig) ParseConfig(config map[string]any) error {
if model, exists := embeddingConfig["model"].(string); exists { if model, exists := embeddingConfig["model"].(string); exists {
c.config.Embedding.Model = model c.config.Embedding.Model = model
} }
if dimension, exists := embeddingConfig["dimension"].(float64); exists { if dimensions, exists := embeddingConfig["dimensions"].(float64); exists {
c.config.Embedding.Dimension = int(dimension) c.config.Embedding.Dimensions = int(dimensions)
} }
} }
// Parse llm configuration // Parse llm configuration
if llmConfig, ok := config["llm"].(map[string]any); ok { if llmConfig, ok := cfg["llm"].(map[string]any); ok {
if provider, exists := llmConfig["provider"].(string); exists { if provider, exists := llmConfig["provider"].(string); exists {
c.config.LLM.Provider = provider c.config.LLM.Provider = provider
} }
@@ -122,7 +164,7 @@ func (c *RAGConfig) ParseConfig(config map[string]any) error {
} }
// Parse VectorDB configuration // Parse VectorDB configuration
if vectordbConfig, ok := config["vectordb"].(map[string]any); ok { if vectordbConfig, ok := cfg["vectordb"].(map[string]any); ok {
if provider, exists := vectordbConfig["provider"].(string); exists { if provider, exists := vectordbConfig["provider"].(string); exists {
c.config.VectorDB.Provider = provider c.config.VectorDB.Provider = provider
} else { } else {
@@ -146,8 +188,59 @@ func (c *RAGConfig) ParseConfig(config map[string]any) error {
if password, exists := vectordbConfig["password"].(string); exists { if password, exists := vectordbConfig["password"].(string); exists {
c.config.VectorDB.Password = password c.config.VectorDB.Password = password
} }
// Parse mapping here
if mapping, exists := vectordbConfig["mapping"].(map[string]any); exists {
// Parse field mappings
if fields, ok := mapping["fields"].([]any); ok {
c.config.VectorDB.Mapping.Fields = []config.FieldMapping{}
for _, field := range fields {
if fieldMap, ok := field.(map[string]any); ok {
fieldMapping := config.FieldMapping{
Properties: make(map[string]interface{}),
}
if standardName, ok := fieldMap["standard_name"].(string); ok {
fieldMapping.StandardName = standardName
} }
if rawName, ok := fieldMap["raw_name"].(string); ok {
fieldMapping.RawName = rawName
}
// Parse properties
if properties, ok := fieldMap["properties"].(map[string]any); ok {
for key, value := range properties {
fieldMapping.Properties[key] = value
}
}
c.config.VectorDB.Mapping.Fields = append(c.config.VectorDB.Mapping.Fields, fieldMapping)
}
}
}
// Parse index configuration
if index, ok := mapping["index"].(map[string]any); ok {
if indexType, ok := index["index_type"].(string); ok {
c.config.VectorDB.Mapping.Index.IndexType = indexType
}
// Parse index parameters
if params, ok := index["params"].(map[string]any); ok {
c.config.VectorDB.Mapping.Index.Params = params
}
}
// Parse search configuration
if search, ok := mapping["search"].(map[string]any); ok {
if metricType, ok := search["metric_type"].(string); ok {
c.config.VectorDB.Mapping.Search.MetricType = metricType
}
// Parse search parameters
if params, ok := search["params"].(map[string]any); ok {
c.config.VectorDB.Mapping.Search.Params = params
}
}
}
}
return nil return nil
} }

View File

@@ -32,7 +32,7 @@ func TestRAGConfig_ParseConfig(t *testing.T) {
APIKey: "sk-XXX", APIKey: "sk-XXX",
BaseURL: "", BaseURL: "",
Model: "text-embedding-v4", Model: "text-embedding-v4",
Dimension: 1024, Dimensions: 1024,
}, },
VectorDB: config.VectorDBConfig{ VectorDB: config.VectorDBConfig{
Provider: "milvus", Provider: "milvus",
@@ -42,6 +42,48 @@ func TestRAGConfig_ParseConfig(t *testing.T) {
Collection: "test_rag", Collection: "test_rag",
Username: "", Username: "",
Password: "", Password: "",
Mapping: config.MappingConfig{
Fields: []config.FieldMapping{
{
StandardName: "id",
RawName: "id",
Properties: map[string]interface{}{
"max_length": 256,
"auto_id": false,
},
},
{
StandardName: "content",
RawName: "content",
Properties: map[string]interface{}{
"max_length": 8192,
},
},
{
StandardName: "vector",
RawName: "vector",
Properties: make(map[string]interface{}),
},
{
StandardName: "metadata",
RawName: "metadata",
Properties: make(map[string]interface{}),
},
{
StandardName: "created_at",
RawName: "created_at",
Properties: make(map[string]interface{}),
},
},
Index: config.IndexConfig{
IndexType: "HNSW",
Params: map[string]interface{}{"M": 4, "efConstruction": 32},
},
Search: config.SearchConfig{
MetricType: "IP",
Params: map[string]interface{}{"ef": 32},
},
},
}, },
} }
// 把 config 输出 yaml 格式 // 把 config 输出 yaml 格式

View File

@@ -0,0 +1,182 @@
package vectordb
import (
"errors"
"fmt"
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
)
// Error definitions
var (
ErrFieldNotFound = errors.New("field not found")
ErrInvalidFieldType = errors.New("invalid field type")
ErrInvalidIndexType = errors.New("invalid index type")
ErrInvalidMetricType = errors.New("invalid metric type")
ErrInvalidSearchParams = errors.New("invalid search parameters")
ErrCollectionNotFound = errors.New("collection not found")
ErrUnsupportedOperation = errors.New("unsupported operation")
)
// VectorDBMapper interface for vector database mapping
type VectorDBMapper interface {
// ParseMapping parses the mapping configuration
ParseMapping(provider string, cfg config.MappingConfig) error
// GetIndexConfig returns the index configuration
GetIndexConfig() (config.IndexConfig, error)
// GetSearchConfig returns the search configuration
GetSearchConfig() (config.SearchConfig, error)
// Get all raw field names
GetRawAllFieldNames() ([]string, error)
// GetIDField returns the ID field mapping
GetIDField() (*config.FieldMapping, error)
// GetVectorField returns the vector field mapping
GetVectorField() (*config.FieldMapping, error)
// Get raw field name by standard field name
GetRawField(standardFieldName string) (*config.FieldMapping, error)
// Get field mapping by raw field name
GetField(rawFieldName string) (*config.FieldMapping, error)
// Get all field mappings
GetFieldMappings() ([]config.FieldMapping, error)
}
// DefaultVectorDBMapper is the default implementation of VectorDBMapper interface
type DefaultVectorDBMapper struct {
// Mapping configuration
mappingConfig config.MappingConfig
// Map from standard field name to field mapping
standardFieldMap map[string]*config.FieldMapping
// Map from raw field name to field mapping
rawFieldMap map[string]*config.FieldMapping
}
// NewDefaultVectorDBMapper creates a new default vector database mapper
func NewDefaultVectorDBMapper(provider string, mappingConfig config.MappingConfig) (*DefaultVectorDBMapper, error) {
mapper := &DefaultVectorDBMapper{
standardFieldMap: make(map[string]*config.FieldMapping),
rawFieldMap: make(map[string]*config.FieldMapping),
}
if err := mapper.ParseMapping(provider, mappingConfig); err != nil {
return nil, err
}
return mapper, nil
}
// ParseMapping parses the mapping configuration
func (m *DefaultVectorDBMapper) ParseMapping(provider string, cfg config.MappingConfig) error {
m.mappingConfig = cfg
// Clear existing mappings
m.standardFieldMap = make(map[string]*config.FieldMapping)
m.rawFieldMap = make(map[string]*config.FieldMapping)
// fill default field mappings
if len(cfg.Fields) == 0 {
defaultFields := []config.FieldMapping{
{
StandardName: "id",
RawName: "id",
Properties: map[string]interface{}{
"max_length": 256,
"auto_id": false,
},
},
{
StandardName: "content",
RawName: "content",
Properties: map[string]interface{}{
"max_length": 8192,
},
},
{
StandardName: "vector",
RawName: "vector",
},
{
StandardName: "metadata",
RawName: "metadata",
},
{
StandardName: "created_at",
RawName: "created_at",
},
}
cfg.Fields = defaultFields
}
// Parse field mappings
for i, field := range cfg.Fields {
// Save pointer for future reference
fieldPtr := &cfg.Fields[i]
m.standardFieldMap[field.StandardName] = fieldPtr
m.rawFieldMap[field.RawName] = fieldPtr
}
// Check fields, must include id, content, vector fields
requiredFields := []string{"id", "content", "vector"}
for _, fieldName := range requiredFields {
if _, err := m.GetRawField(fieldName); err != nil {
return fmt.Errorf("[vector db mapper] required field %s not found or not varchar type", fieldName)
}
}
return nil
}
// GetIndexConfig gets the index configuration
func (m *DefaultVectorDBMapper) GetIndexConfig() (config.IndexConfig, error) {
return m.mappingConfig.Index, nil
}
// GetSearchConfig gets the search configuration
func (m *DefaultVectorDBMapper) GetSearchConfig() (config.SearchConfig, error) {
return m.mappingConfig.Search, nil
}
// GetRawAllFieldNames gets all raw field names
func (m *DefaultVectorDBMapper) GetRawAllFieldNames() ([]string, error) {
fieldNames := make([]string, 0, len(m.rawFieldMap))
for name := range m.rawFieldMap {
fieldNames = append(fieldNames, name)
}
return fieldNames, nil
}
// GetIDField gets the ID field
func (m *DefaultVectorDBMapper) GetIDField() (*config.FieldMapping, error) {
return m.GetRawField("id")
}
// GetVectorField gets the vector field
func (m *DefaultVectorDBMapper) GetVectorField() (*config.FieldMapping, error) {
return m.GetRawField("vector")
}
// GetRawField gets the raw field mapping by standard field name
func (m *DefaultVectorDBMapper) GetRawField(standardFieldName string) (*config.FieldMapping, error) {
field, exists := m.standardFieldMap[standardFieldName]
if !exists {
return nil, fmt.Errorf("%w: standard field %s not found", ErrFieldNotFound, standardFieldName)
}
return field, nil
}
// GetField gets the field mapping by raw field name
func (m *DefaultVectorDBMapper) GetField(rawFieldName string) (*config.FieldMapping, error) {
field, exists := m.rawFieldMap[rawFieldName]
if !exists {
return nil, fmt.Errorf("%w: raw field %s not found", ErrFieldNotFound, rawFieldName)
}
return field, nil
}
// GetFieldMappings gets all field mappings
func (m *DefaultVectorDBMapper) GetFieldMappings() ([]config.FieldMapping, error) {
return m.mappingConfig.Fields, nil
}

View File

@@ -80,16 +80,17 @@ func (m *milvusProviderInitializer) CreateProvider(cfg *config.VectorDBConfig, d
type MilvusProvider struct { type MilvusProvider struct {
client client.Client client client.Client
config *config.VectorDBConfig config *config.VectorDBConfig
Collection string collection string
mapper VectorDBMapper
dimensions int
} }
// NewMilvusProvider creates a new instance of MilvusProvider // NewMilvusProvider creates a new instance of MilvusProvider
func NewMilvusProvider(cfg *config.VectorDBConfig, dim int) (VectorStoreProvider, error) { func NewMilvusProvider(cfg *config.VectorDBConfig, dimensions int) (VectorStoreProvider, error) {
// Create Milvus client // Create Milvus client
connectParam := client.Config{ connectParam := client.Config{
Address: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), Address: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
} }
connectParam.DBName = cfg.Database connectParam.DBName = cfg.Database
// Add authentication if credentials are provided // Add authentication if credentials are provided
if cfg.Username != "" && cfg.Password != "" { if cfg.Username != "" && cfg.Password != "" {
@@ -102,92 +103,301 @@ func NewMilvusProvider(cfg *config.VectorDBConfig, dim int) (VectorStoreProvider
return nil, fmt.Errorf("failed to create milvus client: %w", err) return nil, fmt.Errorf("failed to create milvus client: %w", err)
} }
mapper, err := NewDefaultVectorDBMapper(MILVUS_PROVIDER_TYPE, cfg.Mapping)
if err != nil {
return nil, fmt.Errorf("failed to create default vector db mapper: %w", err)
}
provider := &MilvusProvider{ provider := &MilvusProvider{
client: milvusClient, client: milvusClient,
config: cfg, config: cfg,
Collection: cfg.Collection, collection: cfg.Collection,
mapper: mapper,
dimensions: dimensions,
} }
ctx := context.Background() ctx := context.Background()
if err := provider.CreateCollection(ctx, dim); err != nil { if err := provider.CreateCollection(ctx, dimensions); err != nil {
return nil, err return nil, err
} }
return provider, nil return provider, nil
} }
func (m *MilvusProvider) buildSchema() (*entity.Schema, error) {
// Create Milvus collection Schema
idField, _ := m.mapper.GetIDField()
isIDAuto := idField.IsAutoID()
schema := entity.NewSchema().
WithName(m.collection).
WithDescription("Knowledge document collection").
WithAutoID(isIDAuto).
WithDynamicFieldEnabled(false)
// Add fields
var fieldEntity *entity.Field
fieldMappings, _ := m.mapper.GetFieldMappings()
for _, field := range fieldMappings {
fieldEntity = nil
maxLength := field.MaxLength()
switch field.StandardName {
case "id":
isIDAuto := field.IsAutoID()
fieldEntity = entity.NewField().
WithName(field.RawName).
WithDataType(entity.FieldTypeVarChar).
WithMaxLength(int64(maxLength)).
WithIsPrimaryKey(true)
if isIDAuto {
fieldEntity.WithIsAutoID(true)
}
schema.WithField(fieldEntity)
case "content":
fieldEntity = entity.NewField().
WithName(field.RawName).
WithDataType(entity.FieldTypeVarChar).
WithMaxLength(int64(maxLength))
schema.WithField(fieldEntity)
case "vector":
fieldEntity = entity.NewField().
WithName(field.RawName).
WithDataType(entity.FieldTypeFloatVector).
WithDim(int64(m.dimensions))
schema.WithField(fieldEntity)
case "metadata":
fieldEntity = entity.NewField().
WithName(field.RawName).
WithDataType(entity.FieldTypeJSON)
schema.WithField(fieldEntity)
case "created_at":
fieldEntity = entity.NewField().
WithName(field.RawName).
WithDataType(entity.FieldTypeInt64)
schema.WithField(fieldEntity)
}
}
return schema, nil
}
func (m *MilvusProvider) GetMetricType(metricType string) entity.MetricType {
switch strings.ToUpper(metricType) {
case "L2":
return entity.L2
case "IP":
return entity.IP
case "COSINE":
return entity.COSINE
case "HAMMING":
return entity.HAMMING
case "JACCARD":
return entity.JACCARD
case "TANIMOTO":
return entity.TANIMOTO
case "SUBSTRUCTURE":
return entity.SUBSTRUCTURE
case "SUPERSTRUCTURE":
return entity.SUPERSTRUCTURE
default:
return entity.IP
}
}
func (m *MilvusProvider) buildVectorIndex() (entity.Index, error) {
// Map index type
indexConfig, _ := m.mapper.GetIndexConfig()
searchConfig, _ := m.mapper.GetSearchConfig()
// Map index parameters
milvusIndexType := strings.ToUpper(indexConfig.IndexType)
if milvusIndexType == "" {
milvusIndexType = "HNSW"
}
metricType := m.GetMetricType(searchConfig.MetricType)
switch milvusIndexType {
case "FLAT":
// FLAT index doesn't need additional parameters
index, err := entity.NewIndexFlat(metricType)
if err != nil {
return nil, fmt.Errorf("failed to create FLAT index: %w", err)
}
return index, nil
case "BIN_FLAT":
// BIN_FLAT index doesn't need additional parameters
nlist := 128
if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil {
nlist = int(nlistVal)
}
index, err := entity.NewIndexBinFlat(metricType, nlist)
if err != nil {
return nil, fmt.Errorf("failed to create BIN_FLAT index: %w", err)
}
return index, nil
case "IVF_FLAT":
// Default parameters
nlist := 128
if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil {
nlist = int(nlistVal)
}
index, err := entity.NewIndexIvfFlat(metricType, nlist)
if err != nil {
return nil, fmt.Errorf("failed to create IVF_FLAT index: %w", err)
}
return index, nil
case "BIN_IVF_FLAT":
// Default parameters
nlist := 128
if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil {
nlist = int(nlistVal)
}
index, err := entity.NewIndexBinIvfFlat(metricType, nlist)
if err != nil {
return nil, fmt.Errorf("failed to create BIN_IVF_FLAT index: %w", err)
}
return index, nil
case "IVF_SQ8":
// Default parameters
nlist := 128
if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil {
nlist = int(nlistVal)
}
index, err := entity.NewIndexIvfSQ8(metricType, nlist)
if err != nil {
return nil, fmt.Errorf("failed to create IVF_SQ8 index: %w", err)
}
return index, nil
case "IVF_PQ":
// Default parameters
nlist := 128
m := 4
nbits := 8
if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil {
nlist = int(nlistVal)
}
if mVal, err := indexConfig.ParamsFloat64("m"); err == nil {
m = int(mVal)
}
if nbitsVal, err := indexConfig.ParamsInt64("nbits"); err == nil {
nbits = int(nbitsVal)
}
index, err := entity.NewIndexIvfPQ(metricType, nlist, m, nbits)
if err != nil {
return nil, fmt.Errorf("failed to create IVF_PQ index: %w", err)
}
return index, nil
case "HNSW":
// Default parameters
m := 8
efConstruction := 64
if mVal, err := indexConfig.ParamsInt64("M"); err == nil {
m = int(mVal)
}
if efConstructionVal, err := indexConfig.ParamsInt64("efConstruction"); err == nil {
efConstruction = int(efConstructionVal)
}
index, err := entity.NewIndexHNSW(metricType, m, efConstruction)
if err != nil {
return nil, fmt.Errorf("failed to create HNSW index: %w", err)
}
return index, nil
case "IVF_HNSW":
// Default parameters
nlist := 128
m := 8
efConstruction := 64
if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil {
nlist = int(nlistVal)
}
if mVal, err := indexConfig.ParamsInt64("M"); err == nil {
m = int(mVal)
}
if efConstructionVal, err := indexConfig.ParamsInt64("efConstruction"); err == nil {
efConstruction = int(efConstructionVal)
}
index, err := entity.NewIndexIvfHNSW(metricType, nlist, m, efConstruction)
if err != nil {
return nil, fmt.Errorf("failed to create IVF_HNSW index: %w", err)
}
return index, nil
case "DISKANN":
// DISKANN index parameters
index, err := entity.NewIndexDISKANN(metricType)
if err != nil {
return nil, fmt.Errorf("failed to create DISKANN index: %w", err)
}
return index, nil
case "SCANN":
// SCANN index parameters
nlist := 128
with_raw_data := false
if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil {
nlist = int(nlistVal)
}
if with_raw_dataVal, err := indexConfig.ParamsBool("with_raw_data"); err == nil {
with_raw_data = with_raw_dataVal
}
index, err := entity.NewIndexSCANN(metricType, nlist, with_raw_data)
if err != nil {
return nil, fmt.Errorf("failed to create SCANN index: %w", err)
}
return index, nil
case "AUTOINDEX":
// Auto index
index, err := entity.NewIndexAUTOINDEX(metricType)
if err != nil {
return nil, fmt.Errorf("failed to create AUTOINDEX index: %w", err)
}
return index, nil
default:
return nil, fmt.Errorf("unsupported index type: %s", milvusIndexType)
}
}
// CreateCollection creates a new collection with the specified dimension // CreateCollection creates a new collection with the specified dimension
func (m *MilvusProvider) CreateCollection(ctx context.Context, dim int) error { func (m *MilvusProvider) CreateCollection(ctx context.Context, dim int) error {
// Check if collection exists // Check if collection exists
document_exists, err := m.client.HasCollection(ctx, m.Collection) document_exists, err := m.client.HasCollection(ctx, m.collection)
if err != nil { if err != nil {
return fmt.Errorf("failed to check %s collection existence: %w", m.Collection, err) return fmt.Errorf("failed to check %s collection existence: %w", m.collection, err)
} }
if !document_exists { if !document_exists {
fmt.Printf("create collection %s\n", m.Collection) fmt.Printf("create collection %s\n", m.collection)
// Create schema // Create schema
schema := entity.NewSchema(). schema, err := m.buildSchema()
WithName(m.Collection). if err != nil {
WithDescription("Knowledge document collection"). return fmt.Errorf("failed to build schema: %w", err)
WithAutoID(false). }
WithDynamicFieldEnabled(false)
// Add fields based on schema.Document structure
// Primary key field - ID
pkField := entity.NewField().
WithName("id").
WithDataType(entity.FieldTypeVarChar).
WithMaxLength(256).
WithIsPrimaryKey(true).
WithIsAutoID(false)
schema.WithField(pkField)
// Content field
contentField := entity.NewField().
WithName("content").
WithDataType(entity.FieldTypeVarChar).
WithMaxLength(8192)
schema.WithField(contentField)
// Vector field
vectorField := entity.NewField().
WithName("vector").
WithDataType(entity.FieldTypeFloatVector).
WithDim(int64(dim))
schema.WithField(vectorField)
// Metadata field
metadataField := entity.NewField().
WithName("metadata").
WithDataType(entity.FieldTypeJSON)
schema.WithField(metadataField)
// CreatedAt field (stored as Unix timestamp)
createdAtField := entity.NewField().
WithName("created_at").
WithDataType(entity.FieldTypeInt64)
schema.WithField(createdAtField)
// Create collection // Create collection
err = m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber) err = m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber)
if err != nil { if err != nil {
return fmt.Errorf("failed to create collection: %w", err) return fmt.Errorf("failed to create collection: %w", err)
} }
// Create vector index // Create vector index
vectorIndex, err := entity.NewIndexHNSW(entity.IP, 8, 64) vectorIndex, err := m.buildVectorIndex()
vectorField, _ := m.mapper.GetVectorField()
if err != nil { if err != nil {
return fmt.Errorf("failed to create vector index: %w", err) return fmt.Errorf("failed to create vector index: %w", err)
} }
err = m.client.CreateIndex(ctx, m.Collection, "vector", vectorIndex, false, client.WithIndexName("vector_index")) err = m.client.CreateIndex(ctx, m.collection, vectorField.RawName, vectorIndex, false, client.WithIndexName("vector_index"))
if err != nil { if err != nil {
return fmt.Errorf("failed to create vector index: %w", err) return fmt.Errorf("failed to create vector index: %w", err)
} }
} }
// Load collection // Load collection
err = m.client.LoadCollection(ctx, m.Collection, false) err = m.client.LoadCollection(ctx, m.collection, false)
if err != nil { if err != nil {
return fmt.Errorf("failed to load document collection: %w", err) return fmt.Errorf("failed to load document collection: %w", err)
} }
@@ -197,15 +407,15 @@ func (m *MilvusProvider) CreateCollection(ctx context.Context, dim int) error {
// DropCollection removes the collection from the database // DropCollection removes the collection from the database
func (m *MilvusProvider) DropCollection(ctx context.Context) error { func (m *MilvusProvider) DropCollection(ctx context.Context) error {
// Check if collection exists // Check if collection exists
exists, err := m.client.HasCollection(ctx, m.Collection) exists, err := m.client.HasCollection(ctx, m.collection)
if err != nil { if err != nil {
return fmt.Errorf("failed to check %s collection existence: %w", m.Collection, err) return fmt.Errorf("failed to check %s collection existence: %w", m.collection, err)
} }
if !exists { if !exists {
return fmt.Errorf("collection %s does not exist", m.Collection) return fmt.Errorf("collection %s does not exist", m.collection)
} }
// Drop collection // Drop collection
err = m.client.DropCollection(ctx, m.Collection) err = m.client.DropCollection(ctx, m.collection)
if err != nil { if err != nil {
return fmt.Errorf("failed to drop collection: %w", err) return fmt.Errorf("failed to drop collection: %w", err)
} }
@@ -217,51 +427,71 @@ func (m *MilvusProvider) AddDoc(ctx context.Context, docs []schema.Document) err
if len(docs) == 0 { if len(docs) == 0 {
return nil return nil
} }
// Prepare data
ids := make([]string, len(docs))
contents := make([]string, len(docs))
vectors := make([][]float32, len(docs))
metadatas := make([][]byte, len(docs))
createdAts := make([]int64, len(docs))
for i, doc := range docs { // Get field mappings
ids[i] = doc.ID fieldMappings, err := m.mapper.GetFieldMappings()
contents[i] = doc.Content if err != nil {
return fmt.Errorf("failed to get field mappings: %w", err)
// Convert vector type
vectorFloat32 := make([]float32, len(doc.Vector))
for j, v := range doc.Vector {
vectorFloat32[j] = float32(v)
} }
vectors[i] = vectorFloat32 // Prepare data and columns
columns := make([]entity.Column, 0, len(fieldMappings))
// Create corresponding column data for each field
for _, field := range fieldMappings {
// Skip ID field if configured as auto ID
if field.IsPrimaryKey() && field.IsAutoID() {
continue
}
switch field.StandardName {
case "id":
// Handle string type fields
values := make([]string, len(docs))
for i, doc := range docs {
values[i] = doc.ID
}
columns = append(columns, entity.NewColumnVarChar(field.RawName, values))
case "content":
values := make([]string, len(docs))
for i, doc := range docs {
values[i] = doc.Content
}
columns = append(columns, entity.NewColumnVarChar(field.RawName, values))
case "vector":
// Handle vector fields
vectors := make([][]float32, len(docs))
for i, doc := range docs {
vectors[i] = doc.Vector
}
columns = append(columns, entity.NewColumnFloatVector(field.RawName, len(vectors[0]), vectors))
case "metadata":
// Handle JSON type fields (like metadata)
values := make([][]byte, len(docs))
for i, doc := range docs {
// Serialize metadata // Serialize metadata
metadataBytes, err := json.Marshal(doc.Metadata) metadataBytes, err := json.Marshal(doc.Metadata)
if err != nil { if err != nil {
return fmt.Errorf("failed to marshal metadata for doc %s: %w", doc.ID, err) return fmt.Errorf("failed to marshal metadata for doc %s: %w", doc.ID, err)
} }
metadatas[i] = metadataBytes values[i] = metadataBytes
}
createdAts[i] = doc.CreatedAt.UnixMilli() columns = append(columns, entity.NewColumnJSONBytes(field.RawName, values))
case "created_at":
// Handle integer type fields
values := make([]int64, len(docs))
for i, doc := range docs {
values[i] = doc.CreatedAt.UnixMilli()
}
columns = append(columns, entity.NewColumnInt64(field.RawName, values))
} }
// Build insert data
columns := []entity.Column{
entity.NewColumnVarChar("id", ids),
entity.NewColumnVarChar("content", contents),
entity.NewColumnFloatVector("vector", len(vectors[0]), vectors),
entity.NewColumnJSONBytes("metadata", metadatas),
entity.NewColumnInt64("created_at", createdAts),
} }
// Insert data // Insert data
_, err := m.client.Insert(ctx, m.Collection, "", columns...) _, err = m.client.Insert(ctx, m.collection, "", columns...)
if err != nil { if err != nil {
return fmt.Errorf("failed to insert documents: %w", err) return fmt.Errorf("failed to insert documents: %w", err)
} }
// Flush data // Flush data
err = m.client.Flush(ctx, m.Collection, false) err = m.client.Flush(ctx, m.collection, false)
if err != nil { if err != nil {
return fmt.Errorf("failed to flush collection: %w", err) return fmt.Errorf("failed to flush collection: %w", err)
} }
@@ -271,16 +501,19 @@ func (m *MilvusProvider) AddDoc(ctx context.Context, docs []schema.Document) err
// DeleteDoc deletes a document by its ID // DeleteDoc deletes a document by its ID
func (m *MilvusProvider) DeleteDoc(ctx context.Context, id string) error { func (m *MilvusProvider) DeleteDoc(ctx context.Context, id string) error {
// Build delete expression // Get ID field
expr := fmt.Sprintf(`id == "%s"`, id) idField, _ := m.mapper.GetIDField()
// Build delete expression using the RawName of ID field
expr := fmt.Sprintf(`%s == "%s"`, idField.RawName, id)
// Delete data // Delete data
err := m.client.Delete(ctx, m.Collection, "", expr) err := m.client.Delete(ctx, m.collection, "", expr)
if err != nil { if err != nil {
return fmt.Errorf("failed to delete documents for id %s: %w", id, err) return fmt.Errorf("failed to delete documents for id %s: %w", id, err)
} }
// Flush data // Flush data
err = m.client.Flush(ctx, m.Collection, false) err = m.client.Flush(ctx, m.collection, false)
if err != nil { if err != nil {
return fmt.Errorf("failed to flush collection after delete: %w", err) return fmt.Errorf("failed to flush collection after delete: %w", err)
} }
@@ -306,24 +539,127 @@ func (m *MilvusProvider) UpdateDoc(ctx context.Context, docs []schema.Document)
return nil return nil
} }
func (m *MilvusProvider) buildSearchParam() (entity.SearchParam, error) {
// Get index configuration
indexConfig, err := m.mapper.GetIndexConfig()
if err != nil {
return nil, fmt.Errorf("failed to get index config: %w", err)
}
// Get search configuration
searchConfig, err := m.mapper.GetSearchConfig()
if err != nil {
return nil, fmt.Errorf("failed to get search config: %w", err)
}
// Choose appropriate search parameters based on index type
milvusIndexType := strings.ToUpper(indexConfig.IndexType)
if milvusIndexType == "" {
milvusIndexType = "HNSW" // Default to HNSW index
}
switch milvusIndexType {
case "FLAT":
// FLAT and BIN_FLAT indices don't need additional search parameters
return entity.NewIndexFlatSearchParam()
case "BIN_FLAT", "IVF_FLAT", "BIN_IVF_FLAT", "IVF_SQ8":
// Search parameters for IVF series indices
nprobe := 16 // Default value
if nprobeVal, err := searchConfig.ParamsFloat64("nprobe"); err == nil {
nprobe = int(nprobeVal)
}
return entity.NewIndexIvfFlatSearchParam(nprobe)
case "IVF_PQ":
// Search parameters for IVF_PQ index
nprobe := 16 // Default value
if nprobeVal, err := searchConfig.ParamsFloat64("nprobe"); err == nil {
nprobe = int(nprobeVal)
}
return entity.NewIndexIvfPQSearchParam(nprobe)
case "HNSW":
// Search parameters for HNSW index
efSearch := 16 // Default value
if efSearchVal, err := searchConfig.ParamsFloat64("ef"); err == nil {
efSearch = int(efSearchVal)
}
return entity.NewIndexHNSWSearchParam(efSearch)
case "IVF_HNSW":
// Search parameters for IVF_HNSW index
nprobe := 16 // Default value
efSearch := 64 // Default value
if nprobeVal, err := searchConfig.ParamsFloat64("nprobe"); err == nil {
nprobe = int(nprobeVal)
}
if efSearchVal, err := searchConfig.ParamsFloat64("ef"); err == nil {
efSearch = int(efSearchVal)
}
return entity.NewIndexIvfHNSWSearchParam(nprobe, efSearch)
case "SCANN":
// Search parameters for SCANN index
nprobe := 16 // Default value
reorder_k := 64
if nprobeVal, err := searchConfig.ParamsFloat64("nprobe"); err == nil {
nprobe = int(nprobeVal)
}
if reorderKVal, err := searchConfig.ParamsInt64("reorder_k"); err == nil {
reorder_k = int(reorderKVal)
}
return entity.NewIndexSCANNSearchParam(nprobe, reorder_k)
case "DISKANN":
// Search parameters for DISKANN index
search_list := 100 // Default value
if searchListVal, err := searchConfig.ParamsInt64("search_list"); err == nil {
search_list = int(searchListVal)
}
return entity.NewIndexDISKANNSearchParam(search_list)
case "AUTOINDEX":
level := 8
if levelVal, err := searchConfig.ParamsInt64("level"); err == nil {
level = int(levelVal)
}
// Search parameters for AUTOINDEX index
return entity.NewIndexAUTOINDEXSearchParam(level)
default:
// Default to using HNSW search parameters
return entity.NewIndexHNSWSearchParam(16)
}
}
// SearchDocs performs similarity search for documents // SearchDocs performs similarity search for documents
func (m *MilvusProvider) SearchDocs(ctx context.Context, vector []float32, options *schema.SearchOptions) ([]schema.SearchResult, error) { func (m *MilvusProvider) SearchDocs(ctx context.Context, vector []float32, options *schema.SearchOptions) ([]schema.SearchResult, error) {
if options == nil { if options == nil {
options = &schema.SearchOptions{TopK: 10} options = &schema.SearchOptions{TopK: 10}
} }
// Build search parameters // Build search parameters
sp, _ := entity.NewIndexHNSWSearchParam(16) sp, err := m.buildSearchParam()
if err != nil {
return nil, fmt.Errorf("failed to build search param: %w", err)
}
outputFields, _ := m.mapper.GetRawAllFieldNames()
vectorField, _ := m.mapper.GetVectorField()
searchConfig, _ := m.mapper.GetSearchConfig()
metricType := m.GetMetricType(searchConfig.MetricType)
// Build filter expression // Build filter expression
expr := "" expr := ""
searchResults, err := m.client.Search( searchResults, err := m.client.Search(
ctx, ctx,
m.Collection, m.collection,
[]string{}, // partition names []string{}, // partition names
expr, // filter expression expr, // filter expression
[]string{"id", "content", "metadata", "created_at"}, // output fields outputFields, // output fields
[]entity.Vector{entity.FloatVector(vector)}, []entity.Vector{entity.FloatVector(vector)},
"vector", // anns_field vectorField.RawName, // anns_field
entity.IP, // metric_type metricType, // metric_type
options.TopK, options.TopK,
sp, sp,
) )
@@ -341,9 +677,13 @@ func (m *MilvusProvider) SearchDocs(ctx context.Context, vector []float32, optio
// Get field data // Get field data
var content string var content string
var metadata map[string]interface{} var metadata map[string]interface{}
for _, field := range result.Fields { for _, field := range result.Fields {
switch field.Name() { fieldMapping, err := m.mapper.GetField(field.Name())
if err != nil {
continue
}
fieldName := strings.ToLower(fieldMapping.StandardName)
switch fieldName {
case "content": case "content":
if contentCol, ok := field.(*entity.ColumnVarChar); ok { if contentCol, ok := field.(*entity.ColumnVarChar); ok {
if contentVal, err := contentCol.Get(i); err == nil { if contentVal, err := contentCol.Get(i); err == nil {
@@ -364,7 +704,6 @@ func (m *MilvusProvider) SearchDocs(ctx context.Context, vector []float32, optio
} }
} }
} }
searchResult := schema.SearchResult{ searchResult := schema.SearchResult{
Document: schema.Document{ Document: schema.Document{
ID: fmt.Sprintf("%s", id), ID: fmt.Sprintf("%s", id),
@@ -392,15 +731,17 @@ func (m *MilvusProvider) DeleteDocs(ctx context.Context, ids []string) error {
for i, id := range ids { for i, id := range ids {
quotedIDs[i] = fmt.Sprintf("\"%s\"", id) quotedIDs[i] = fmt.Sprintf("\"%s\"", id)
} }
expr := fmt.Sprintf("id in [%s]", strings.Join(quotedIDs, ","))
idField, _ := m.mapper.GetIDField()
expr := fmt.Sprintf("%s in [%s]", idField.RawName, strings.Join(quotedIDs, ","))
// Delete data // Delete data
err := m.client.Delete(ctx, m.Collection, "", expr) err := m.client.Delete(ctx, m.collection, "", expr)
if err != nil { if err != nil {
return fmt.Errorf("failed to delete documents: %w", err) return fmt.Errorf("failed to delete documents: %w", err)
} }
// Flush data // Flush data
err = m.client.Flush(ctx, m.Collection, false) err = m.client.Flush(ctx, m.collection, false)
if err != nil { if err != nil {
return fmt.Errorf("failed to flush collection after delete: %w", err) return fmt.Errorf("failed to flush collection after delete: %w", err)
} }
@@ -413,12 +754,13 @@ func (m *MilvusProvider) ListDocs(ctx context.Context, limit int) ([]schema.Docu
// Build query expression // Build query expression
expr := "" expr := ""
// Query all relevant documents // Query all relevant documents
outputFields, _ := m.mapper.GetRawAllFieldNames()
queryResult, err := m.client.Query( queryResult, err := m.client.Query(
ctx, ctx,
m.Collection, m.collection,
[]string{}, // partitions []string{}, // partitions
expr, // filter condition expr, // filter condition
[]string{"id", "content", "metadata", "created_at"}, outputFields,
client.WithOffset(0), client.WithLimit(int64(limit)), client.WithOffset(0), client.WithLimit(int64(limit)),
) )
@@ -443,7 +785,12 @@ func (m *MilvusProvider) ListDocs(ctx context.Context, limit int) ([]schema.Docu
) )
for _, col := range queryResult { for _, col := range queryResult {
switch col.Name() { fieldMapping, err := m.mapper.GetField(col.Name())
if err != nil {
continue
}
fieldName := strings.ToLower(fieldMapping.StandardName)
switch fieldName {
case "id": case "id":
if v, err := col.(*entity.ColumnVarChar).Get(i); err == nil { if v, err := col.(*entity.ColumnVarChar).Get(i); err == nil {
id = v.(string) id = v.(string)
@@ -488,8 +835,3 @@ func (m *MilvusProvider) Close() error {
} }
return nil return nil
} }
// joinStrings joins a slice of strings with the given separator
func joinStrings(elems []string, sep string) string {
return strings.Join(elems, sep)
}