mirror of
https://github.com/alibaba/higress.git
synced 2026-06-26 02:35:02 +08:00
add vectordb mapping (#2968)
This commit is contained in:
@@ -55,16 +55,21 @@ require (
|
|||||||
github.com/milvus-io/milvus-sdk-go/v2 v2.4.2 // indirect
|
github.com/milvus-io/milvus-sdk-go/v2 v2.4.2 // indirect
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
|
github.com/openai/openai-go/v2 v2.7.0 // indirect
|
||||||
github.com/orcaman/concurrent-map v0.0.0-20210501183033-44dafcb38ecc // indirect
|
github.com/orcaman/concurrent-map v0.0.0-20210501183033-44dafcb38ecc // indirect
|
||||||
github.com/pkoukk/tiktoken-go v0.1.8 // indirect
|
github.com/pkoukk/tiktoken-go v0.1.8 // indirect
|
||||||
github.com/prometheus/client_golang v1.14.0 // indirect
|
github.com/prometheus/client_golang v1.14.0 // indirect
|
||||||
github.com/prometheus/client_model v0.4.0 // indirect
|
github.com/prometheus/client_model v0.4.0 // indirect
|
||||||
github.com/prometheus/common v0.37.0 // indirect
|
github.com/prometheus/common v0.37.0 // indirect
|
||||||
github.com/prometheus/procfs v0.8.0 // indirect
|
github.com/prometheus/procfs v0.8.0 // indirect
|
||||||
|
github.com/tidwall/gjson v1.18.0 // indirect
|
||||||
|
github.com/tidwall/match v1.2.0 // indirect
|
||||||
|
github.com/tidwall/pretty v1.2.1 // indirect
|
||||||
|
github.com/tidwall/sjson v1.2.5 // indirect
|
||||||
github.com/tjfoc/gmsm v1.4.1 // indirect
|
github.com/tjfoc/gmsm v1.4.1 // indirect
|
||||||
go.uber.org/multierr v1.11.0 // indirect
|
go.uber.org/multierr v1.11.0 // indirect
|
||||||
go.uber.org/zap v1.27.0 // indirect
|
go.uber.org/zap v1.27.0 // indirect
|
||||||
golang.org/x/net v0.33.0 // indirect
|
golang.org/x/net v0.34.0 // indirect
|
||||||
golang.org/x/time v0.3.0 // indirect
|
golang.org/x/time v0.3.0 // indirect
|
||||||
google.golang.org/grpc v1.59.0 // indirect
|
google.golang.org/grpc v1.59.0 // indirect
|
||||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||||
@@ -99,9 +104,9 @@ require (
|
|||||||
github.com/shopspring/decimal v1.4.0 // indirect
|
github.com/shopspring/decimal v1.4.0 // indirect
|
||||||
go.opentelemetry.io/otel v1.26.0 // indirect
|
go.opentelemetry.io/otel v1.26.0 // indirect
|
||||||
go.opentelemetry.io/otel/trace v1.26.0 // indirect
|
go.opentelemetry.io/otel/trace v1.26.0 // indirect
|
||||||
golang.org/x/crypto v0.31.0 // indirect
|
golang.org/x/crypto v0.32.0 // indirect
|
||||||
golang.org/x/sync v0.10.0 // indirect
|
golang.org/x/sync v0.10.0 // indirect
|
||||||
golang.org/x/sys v0.28.0 // indirect
|
golang.org/x/sys v0.29.0 // indirect
|
||||||
golang.org/x/text v0.21.0 // indirect
|
golang.org/x/text v0.21.0 // indirect
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20230822172742-b8732ec3820d // indirect
|
google.golang.org/genproto/googleapis/api v0.0.0-20230822172742-b8732ec3820d // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 // indirect
|
||||||
|
|||||||
@@ -311,6 +311,8 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
|||||||
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
||||||
github.com/onsi/gomega v1.24.2 h1:J/tulyYK6JwBldPViHJReihxxZ+22FHs0piGjQAvoUE=
|
github.com/onsi/gomega v1.24.2 h1:J/tulyYK6JwBldPViHJReihxxZ+22FHs0piGjQAvoUE=
|
||||||
github.com/onsi/gomega v1.24.2/go.mod h1:gs3J10IS7Z7r7eXRoNJIrNqU4ToQukCJhFtKrWgHWnk=
|
github.com/onsi/gomega v1.24.2/go.mod h1:gs3J10IS7Z7r7eXRoNJIrNqU4ToQukCJhFtKrWgHWnk=
|
||||||
|
github.com/openai/openai-go/v2 v2.7.0 h1:/8MSFCXcasin7AyuWQ2au6FraXL71gzAs+VfbMv+J3k=
|
||||||
|
github.com/openai/openai-go/v2 v2.7.0/go.mod h1:jrJs23apqJKKbT+pqtFgNKpRju/KP9zpUTZhz3GElQE=
|
||||||
github.com/orcaman/concurrent-map v0.0.0-20210501183033-44dafcb38ecc h1:Ak86L+yDSOzKFa7WM5bf5itSOo1e3Xh8bm5YCMUXIjQ=
|
github.com/orcaman/concurrent-map v0.0.0-20210501183033-44dafcb38ecc h1:Ak86L+yDSOzKFa7WM5bf5itSOo1e3Xh8bm5YCMUXIjQ=
|
||||||
github.com/orcaman/concurrent-map v0.0.0-20210501183033-44dafcb38ecc/go.mod h1:Lu3tH6HLW3feq74c2GC+jIMS/K2CFcDWnWD9XkenwhI=
|
github.com/orcaman/concurrent-map v0.0.0-20210501183033-44dafcb38ecc/go.mod h1:Lu3tH6HLW3feq74c2GC+jIMS/K2CFcDWnWD9XkenwhI=
|
||||||
github.com/paulmach/orb v0.11.1 h1:3koVegMC4X/WeiXYz9iswopaTwMem53NzTJuTF20JzU=
|
github.com/paulmach/orb v0.11.1 h1:3koVegMC4X/WeiXYz9iswopaTwMem53NzTJuTF20JzU=
|
||||||
@@ -377,7 +379,18 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
|||||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||||
|
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||||
|
github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM=
|
||||||
|
github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||||
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
|
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
|
||||||
|
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
|
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||||
|
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
|
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||||
|
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||||
github.com/tjfoc/gmsm v1.3.2/go.mod h1:HaUcFuY0auTiaHB9MHFGCPx5IaLhTUd2atbCFBQXn9w=
|
github.com/tjfoc/gmsm v1.3.2/go.mod h1:HaUcFuY0auTiaHB9MHFGCPx5IaLhTUd2atbCFBQXn9w=
|
||||||
github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho=
|
github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho=
|
||||||
github.com/tjfoc/gmsm v1.4.1/go.mod h1:j4INPkHWMrhJb38G+J6W4Tw0AbuN8Thu3PbdVYhVcTE=
|
github.com/tjfoc/gmsm v1.4.1/go.mod h1:j4INPkHWMrhJb38G+J6W4Tw0AbuN8Thu3PbdVYhVcTE=
|
||||||
@@ -426,6 +439,8 @@ golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDf
|
|||||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
|
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
|
||||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||||
|
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
|
||||||
|
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
|
||||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||||
@@ -506,6 +521,8 @@ golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
|||||||
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
|
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
|
||||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
||||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||||
|
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
|
||||||
|
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
|
||||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||||
@@ -581,6 +598,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
|||||||
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
|
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
|
||||||
|
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||||
|
|||||||
@@ -84,10 +84,11 @@ Higress RAG MCP Server 提供以下工具,根据配置不同,可用工具也
|
|||||||
| llm.max_tokens | integer | 可选 | 2048 | 最大令牌数 |
|
| llm.max_tokens | integer | 可选 | 2048 | 最大令牌数 |
|
||||||
| llm.temperature | float | 可选 | 0.5 | 温度参数 |
|
| llm.temperature | float | 可选 | 0.5 | 温度参数 |
|
||||||
| **embedding** | object | 必填 | - | 嵌入配置(所有工具必需) |
|
| **embedding** | object | 必填 | - | 嵌入配置(所有工具必需) |
|
||||||
| embedding.provider | string | 必填 | dashscope | 嵌入提供商:openai或dashscope |
|
| embedding.provider | string | 必填 | openai | 嵌入提供商:支持openai协议的任意供应商 |
|
||||||
| embedding.api_key | string | 必填 | - | 嵌入API密钥 |
|
| embedding.api_key | string | 必填 | - | 嵌入API密钥 |
|
||||||
| embedding.base_url | string | 可选 | | 嵌入API基础URL |
|
| embedding.base_url | string | 可选 | | 嵌入API基础URL |
|
||||||
| embedding.model | string | 必填 | text-embedding-v4 | 嵌入模型名称 |
|
| embedding.model | string | 必填 | text-embedding-ada-002 | 嵌入模型名称 |
|
||||||
|
| embedding.dimensions | integer | 可选 | 1536 | 嵌入维度 |
|
||||||
| **vectordb** | object | 必填 | - | 向量数据库配置(所有工具必需) |
|
| **vectordb** | object | 必填 | - | 向量数据库配置(所有工具必需) |
|
||||||
| vectordb.provider | string | 必填 | milvus | 向量数据库提供商 |
|
| vectordb.provider | string | 必填 | milvus | 向量数据库提供商 |
|
||||||
| vectordb.host | string | 必填 | localhost | 数据库主机地址 |
|
| vectordb.host | string | 必填 | localhost | 数据库主机地址 |
|
||||||
@@ -96,6 +97,17 @@ Higress RAG MCP Server 提供以下工具,根据配置不同,可用工具也
|
|||||||
| vectordb.collection | string | 必填 | test_collection | 集合名称 |
|
| vectordb.collection | string | 必填 | test_collection | 集合名称 |
|
||||||
| vectordb.username | string | 可选 | - | 数据库用户名 |
|
| vectordb.username | string | 可选 | - | 数据库用户名 |
|
||||||
| vectordb.password | string | 可选 | - | 数据库密码 |
|
| vectordb.password | string | 可选 | - | 数据库密码 |
|
||||||
|
| **vectordb.mapping** | object | 可选 | - | 字段映射配置 |
|
||||||
|
| vectordb.mapping.fields | array | 可选 | - | 字段映射列表 |
|
||||||
|
| vectordb.mapping.fields[].standard_name | string | 必填 | - | 标准字段名称(如 id, content, vector 等) |
|
||||||
|
| vectordb.mapping.fields[].raw_name | string | 必填 | - | 原始字段名称(数据库中的实际字段名) |
|
||||||
|
| vectordb.mapping.fields[].properties | object | 可选 | - | 字段属性(如 auto_id, max_length 等) |
|
||||||
|
| vectordb.mapping.index | object | 可选 | - | 索引配置 |
|
||||||
|
| vectordb.mapping.index.index_type | string | 必填 | - | 索引类型(如 FLAT, IVF_FLAT, HNSW 等) |
|
||||||
|
| vectordb.mapping.index.params | object | 可选 | - | 索引参数(根据索引类型不同而异) |
|
||||||
|
| vectordb.mapping.search | object | 可选 | - | 搜索配置 |
|
||||||
|
| vectordb.mapping.search.metric_type | string | 可选 | L2 | 度量类型(如 L2, IP, COSINE 等) |
|
||||||
|
| vectordb.mapping.search.params | object | 可选 | - | 搜索参数(如 nprobe, ef_search 等)
|
||||||
|
|
||||||
|
|
||||||
### higress-config 配置样例
|
### higress-config 配置样例
|
||||||
@@ -143,27 +155,54 @@ data:
|
|||||||
temperature: 0.5
|
temperature: 0.5
|
||||||
max_tokens: 2048
|
max_tokens: 2048
|
||||||
embedding:
|
embedding:
|
||||||
provider: dashscope
|
provider: openai
|
||||||
|
base_url: https://dashscope.aliyuncs.com/compatible-mode/v1
|
||||||
api_key: sk-xxx
|
api_key: sk-xxx
|
||||||
model: text-embedding-v4
|
model: text-embedding-v4
|
||||||
|
dimensions: 1536
|
||||||
vectordb:
|
vectordb:
|
||||||
provider: milvus
|
provider: milvus
|
||||||
host: <milvus IP>
|
host: localhost
|
||||||
port: 19530
|
port: 19530
|
||||||
database: default
|
database: default
|
||||||
collection: test_collection
|
collection: test_rag
|
||||||
```
|
mapping:
|
||||||
|
fields:
|
||||||
|
- standard_name: id
|
||||||
|
raw_name: id
|
||||||
|
properties:
|
||||||
|
auto_id: false
|
||||||
|
max_length: 256
|
||||||
|
- standard_name: content
|
||||||
|
raw_name: content
|
||||||
|
properties:
|
||||||
|
max_length: 8192
|
||||||
|
- standard_name: vector
|
||||||
|
raw_name: vector
|
||||||
|
- standard_name: metadata
|
||||||
|
raw_name: metadata
|
||||||
|
- standard_name: created_at
|
||||||
|
raw_name: created_at
|
||||||
|
index:
|
||||||
|
index_type: HNSW
|
||||||
|
params:
|
||||||
|
M: 4
|
||||||
|
efConstruction: 32
|
||||||
|
search:
|
||||||
|
metric_type: IP
|
||||||
|
params:
|
||||||
|
ef: 32
|
||||||
|
|
||||||
|
```
|
||||||
### 支持的提供商
|
### 支持的提供商
|
||||||
#### Embedding
|
#### Embedding
|
||||||
- **OpenAI**
|
- **OpenAI 兼容**
|
||||||
- **DashScope**
|
|
||||||
|
|
||||||
#### Vector Database
|
#### Vector Database
|
||||||
- **Milvus**
|
- **Milvus**
|
||||||
|
|
||||||
#### LLM
|
#### LLM
|
||||||
- **OpenAI**
|
- **OpenAI 兼容**
|
||||||
|
|
||||||
## 如何测试数据集的效果
|
## 如何测试数据集的效果
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
// Config represents the main configuration structure for the MCP server
|
// Config represents the main configuration structure for the MCP server
|
||||||
type Config struct {
|
type Config struct {
|
||||||
RAG RAGConfig `json:"rag" yaml:"rag"`
|
RAG RAGConfig `json:"rag" yaml:"rag"`
|
||||||
@@ -34,20 +36,148 @@ type LLMConfig struct {
|
|||||||
|
|
||||||
// EmbeddingConfig defines configuration for embedding models
|
// EmbeddingConfig defines configuration for embedding models
|
||||||
type EmbeddingConfig struct {
|
type EmbeddingConfig struct {
|
||||||
Provider string `json:"provider" yaml:"provider"` // Available options: openai, dashscope
|
Provider string `json:"provider" yaml:"provider"` // Available options: openai, dashscope
|
||||||
APIKey string `json:"api_key,omitempty" yaml:"api_key,omitempty"`
|
APIKey string `json:"api_key,omitempty" yaml:"api_key,omitempty"`
|
||||||
BaseURL string `json:"base_url,omitempty" yaml:"base_url,omitempty"`
|
BaseURL string `json:"base_url,omitempty" yaml:"base_url,omitempty"`
|
||||||
Model string `json:"model,omitempty" yaml:"model,omitempty"`
|
Model string `json:"model,omitempty" yaml:"model,omitempty"`
|
||||||
Dimension int `json:"dimension,omitempty" yaml:"dimension,omitempty"`
|
Dimensions int `json:"dimensions,omitempty" yaml:"dimension,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// VectorDBConfig defines configuration for vector databases
|
// VectorDBConfig defines configuration for vector databases
|
||||||
type VectorDBConfig struct {
|
type VectorDBConfig struct {
|
||||||
Provider string `json:"provider" yaml:"provider"` // Available options: milvus, qdrant, chroma
|
Provider string `json:"provider" yaml:"provider"` // Available options: milvus, qdrant, chroma
|
||||||
Host string `json:"host,omitempty" yaml:"host,omitempty"`
|
Host string `json:"host,omitempty" yaml:"host,omitempty"`
|
||||||
Port int `json:"port,omitempty" yaml:"port,omitempty"`
|
Port int `json:"port,omitempty" yaml:"port,omitempty"`
|
||||||
Database string `json:"database,omitempty" yaml:"database,omitempty"`
|
Database string `json:"database,omitempty" yaml:"database,omitempty"`
|
||||||
Collection string `json:"collection,omitempty" yaml:"collection,omitempty"`
|
Collection string `json:"collection,omitempty" yaml:"collection,omitempty"`
|
||||||
Username string `json:"username,omitempty" yaml:"username,omitempty"`
|
Username string `json:"username,omitempty" yaml:"username,omitempty"`
|
||||||
Password string `json:"password,omitempty" yaml:"password,omitempty"`
|
Password string `json:"password,omitempty" yaml:"password,omitempty"`
|
||||||
|
Mapping MappingConfig `json:"mapping,omitempty" yaml:"mapping,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MappingConfig defines field mapping configuration for vector databases
|
||||||
|
type MappingConfig struct {
|
||||||
|
Fields []FieldMapping `json:"fields,omitempty" yaml:"fields,omitempty"`
|
||||||
|
Index IndexConfig `json:"index,omitempty" yaml:"index,omitempty"`
|
||||||
|
Search SearchConfig `json:"search,omitempty" yaml:"search,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// // CollectionMapping defines field mapping for collection
|
||||||
|
// type CollectionMapping struct {
|
||||||
|
// Fields []FieldMapping `json:"fields,omitempty" yaml:"fields,omitempty"`
|
||||||
|
// }
|
||||||
|
|
||||||
|
// FieldMapping defines mapping for a single field
|
||||||
|
type FieldMapping struct {
|
||||||
|
StandardName string `json:"standard_name" yaml:"standard_name"`
|
||||||
|
RawName string `json:"raw_name" yaml:"raw_name"`
|
||||||
|
Properties map[string]interface{} `json:"properties,omitempty" yaml:"properties,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f FieldMapping) IsPrimaryKey() bool {
|
||||||
|
return f.StandardName == "id"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f FieldMapping) IsAutoID() bool {
|
||||||
|
if f.Properties == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
autoID, ok := f.Properties["auto_id"].(bool)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return autoID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f FieldMapping) IsVectorField() bool {
|
||||||
|
return f.StandardName == "vector"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f FieldMapping) MaxLength() int {
|
||||||
|
if f.Properties == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
maxLength, ok := f.Properties["max_length"].(int)
|
||||||
|
if !ok {
|
||||||
|
return 256
|
||||||
|
}
|
||||||
|
return maxLength
|
||||||
|
}
|
||||||
|
|
||||||
|
// IndexConfig defines configuration for index parameters
|
||||||
|
type IndexConfig struct {
|
||||||
|
// Index type, e.g., IVF_FLAT, IVF_SQ8, HNSW, etc.
|
||||||
|
IndexType string `json:"index_type" yaml:"index_type"`
|
||||||
|
// Index parameter configuration
|
||||||
|
Params map[string]interface{} `json:"params" yaml:"params"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i IndexConfig) ParamsString(key string) (string, error) {
|
||||||
|
if mVal, ok := i.Params[key].(string); ok {
|
||||||
|
return mVal, nil
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("params %s not found", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i IndexConfig) ParamsInt64(key string) (int64, error) {
|
||||||
|
if mVal, ok := i.Params[key].(int64); ok {
|
||||||
|
return mVal, nil
|
||||||
|
}
|
||||||
|
if mVal, ok := i.Params[key].(int); ok {
|
||||||
|
return int64(mVal), nil
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("params %s not found", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i IndexConfig) ParamsFloat64(key string) (float64, error) {
|
||||||
|
if mVal, ok := i.Params[key].(float64); ok {
|
||||||
|
return mVal, nil
|
||||||
|
}
|
||||||
|
if mVal, ok := i.Params[key].(float32); ok {
|
||||||
|
return float64(mVal), nil
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("params %s not found", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i IndexConfig) ParamsBool(key string) (bool, error) {
|
||||||
|
if mVal, ok := i.Params[key].(bool); ok {
|
||||||
|
return mVal, nil
|
||||||
|
}
|
||||||
|
return false, fmt.Errorf("params %s not found", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchConfig defines configuration for search parameters
|
||||||
|
type SearchConfig struct {
|
||||||
|
// Metric type, e.g., L2, IP, etc.
|
||||||
|
MetricType string `json:"metric_type,omitempty" yaml:"metric_type,omitempty"`
|
||||||
|
// Search parameter configuration
|
||||||
|
Params map[string]interface{} `json:"params" yaml:"params"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i SearchConfig) ParamsString(key string) (string, error) {
|
||||||
|
if mVal, ok := i.Params[key].(string); ok {
|
||||||
|
return mVal, nil
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("params %s not found", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i SearchConfig) ParamsInt64(key string) (int64, error) {
|
||||||
|
if mVal, ok := i.Params[key].(int64); ok {
|
||||||
|
return mVal, nil
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("params %s not found", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i SearchConfig) ParamsFloat64(key string) (float64, error) {
|
||||||
|
if mVal, ok := i.Params[key].(float64); ok {
|
||||||
|
return mVal, nil
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("params %s not found", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i SearchConfig) ParamsBool(key string) (bool, error) {
|
||||||
|
if mVal, ok := i.Params[key].(bool); ok {
|
||||||
|
return mVal, nil
|
||||||
|
}
|
||||||
|
return false, fmt.Errorf("params %s not found", key)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,169 +0,0 @@
|
|||||||
package embedding
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/common"
|
|
||||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
DASHSCOPE_DOMAIN = "dashscope.aliyuncs.com"
|
|
||||||
DASHSCOPE_PORT = 443
|
|
||||||
DASHSCOPE_DEFAULT_MODEL_NAME = "text-embedding-v4"
|
|
||||||
DASHSCOPE_ENDPOINT = "/api/v1/services/embeddings/text-embedding/text-embedding"
|
|
||||||
)
|
|
||||||
|
|
||||||
var dashScopeConfig dashScopeProviderConfig
|
|
||||||
|
|
||||||
type dashScopeProviderInitializer struct {
|
|
||||||
}
|
|
||||||
type dashScopeProviderConfig struct {
|
|
||||||
apiKey string
|
|
||||||
model string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *dashScopeProviderInitializer) InitConfig(config config.EmbeddingConfig) {
|
|
||||||
dashScopeConfig.apiKey = config.APIKey
|
|
||||||
dashScopeConfig.model = config.Model
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *dashScopeProviderInitializer) ValidateConfig() error {
|
|
||||||
if dashScopeConfig.apiKey == "" {
|
|
||||||
return errors.New("[DashScope] apiKey is required")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *dashScopeProviderInitializer) CreateProvider(config config.EmbeddingConfig) (Provider, error) {
|
|
||||||
c.InitConfig(config)
|
|
||||||
err := c.ValidateConfig()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
headers := map[string]string{
|
|
||||||
"Authorization": "Bearer " + config.APIKey,
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
httpClient := common.NewHTTPClient(fmt.Sprintf("https://%s", DASHSCOPE_DOMAIN), headers)
|
|
||||||
|
|
||||||
return &DashScopeProvider{
|
|
||||||
config: dashScopeConfig,
|
|
||||||
client: httpClient,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DashScopeProvider) GetProviderType() string {
|
|
||||||
return PROVIDER_TYPE_DASHSCOPE
|
|
||||||
}
|
|
||||||
|
|
||||||
type Embedding struct {
|
|
||||||
Embedding []float32 `json:"embedding"`
|
|
||||||
TextIndex int `json:"text_index"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Input struct {
|
|
||||||
Texts []string `json:"texts"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Params struct {
|
|
||||||
TextType string `json:"text_type"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Response struct {
|
|
||||||
RequestID string `json:"request_id"`
|
|
||||||
Output Output `json:"output"`
|
|
||||||
Usage Usage `json:"usage"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Output struct {
|
|
||||||
Embeddings []Embedding `json:"embeddings"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Usage struct {
|
|
||||||
TotalTokens int `json:"total_tokens"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type EmbeddingRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Input Input `json:"input"`
|
|
||||||
Parameters Params `json:"parameters"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Document struct {
|
|
||||||
Vector []float64 `json:"vector"`
|
|
||||||
Fields map[string]string `json:"fields"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type DashScopeProvider struct {
|
|
||||||
config dashScopeProviderConfig
|
|
||||||
client *common.HTTPClient
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DashScopeProvider) constructRequestData(texts []string) (EmbeddingRequest, error) {
|
|
||||||
model := d.config.model
|
|
||||||
if model == "" {
|
|
||||||
model = DASHSCOPE_DEFAULT_MODEL_NAME
|
|
||||||
}
|
|
||||||
|
|
||||||
if dashScopeConfig.apiKey == "" {
|
|
||||||
return EmbeddingRequest{}, errors.New("dashScopeKey is empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
data := EmbeddingRequest{
|
|
||||||
Model: model,
|
|
||||||
Input: Input{
|
|
||||||
Texts: texts,
|
|
||||||
},
|
|
||||||
Parameters: Params{
|
|
||||||
TextType: "query",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type Result struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Vector []float32 `json:"vector,omitempty"`
|
|
||||||
Fields map[string]interface{} `json:"fields"`
|
|
||||||
Score float64 `json:"score"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DashScopeProvider) parseTextEmbedding(responseBody []byte) (*Response, error) {
|
|
||||||
var resp Response
|
|
||||||
err := json.Unmarshal(responseBody, &resp)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DashScopeProvider) GetEmbedding(
|
|
||||||
ctx context.Context,
|
|
||||||
queryString string) ([]float32, error) {
|
|
||||||
requestData, err := d.constructRequestData([]string{queryString})
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to construct request data: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
responseBody, err := d.client.Post(DASHSCOPE_ENDPOINT, requestData)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to send request: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
embeddingResp, err := d.parseTextEmbedding(responseBody)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse response: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(embeddingResp.Output.Embeddings) == 0 {
|
|
||||||
return nil, errors.New("no embedding found in response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return embeddingResp.Output.Embeddings[0].Embedding, nil
|
|
||||||
}
|
|
||||||
@@ -2,160 +2,93 @@ package embedding
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/common"
|
|
||||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
||||||
|
"github.com/openai/openai-go/v2"
|
||||||
|
"github.com/openai/openai-go/v2/option"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
OPENAI_DOMAIN = "api.openai.com"
|
OPENAI_DEFAULT_MODEL_NAME = "text-embedding-ada-002"
|
||||||
OPENAI_PORT = 443
|
|
||||||
OPENAI_DEFAULT_MODEL_NAME = "text-embedding-3-small"
|
|
||||||
OPENAI_ENDPOINT = "/v1/embeddings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type openAIProviderInitializer struct {
|
type openAIProviderInitializer struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
var openAIConfig openAIProviderConfig
|
func (c *openAIProviderInitializer) validateConfig(config *config.EmbeddingConfig) error {
|
||||||
|
if config.APIKey == "" {
|
||||||
type openAIProviderConfig struct {
|
return errors.New("[openai embbeding] apiKey is required")
|
||||||
baseUrl string
|
|
||||||
apiKey string
|
|
||||||
model string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *openAIProviderInitializer) InitConfig(config config.EmbeddingConfig) {
|
|
||||||
openAIConfig.apiKey = config.APIKey
|
|
||||||
openAIConfig.model = config.Model
|
|
||||||
openAIConfig.baseUrl = config.BaseURL
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *openAIProviderInitializer) ValidateConfig() error {
|
|
||||||
if openAIConfig.apiKey == "" {
|
|
||||||
return errors.New("[openAI] apiKey is required")
|
|
||||||
}
|
}
|
||||||
|
if config.Model == "" {
|
||||||
|
config.Model = OPENAI_DEFAULT_MODEL_NAME
|
||||||
|
}
|
||||||
|
if config.Dimensions <= 0 {
|
||||||
|
config.Dimensions = 1536
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *openAIProviderInitializer) CreateProvider(config config.EmbeddingConfig) (Provider, error) {
|
func (c *openAIProviderInitializer) CreateProvider(config config.EmbeddingConfig) (Provider, error) {
|
||||||
c.InitConfig(config)
|
if err := c.validateConfig(&config); err != nil {
|
||||||
err := c.ValidateConfig()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
// 创建 OpenAI 客户端
|
||||||
|
var clientOptions []option.RequestOption
|
||||||
|
clientOptions = append(clientOptions, option.WithAPIKey(config.APIKey))
|
||||||
|
|
||||||
if openAIConfig.model == "" {
|
// 如果设置了自定义 baseURL,则使用它
|
||||||
openAIConfig.model = OPENAI_DEFAULT_MODEL_NAME
|
if config.BaseURL != "" {
|
||||||
|
clientOptions = append(clientOptions, option.WithBaseURL(config.BaseURL))
|
||||||
}
|
}
|
||||||
|
// 创建 OpenAI 客户端
|
||||||
if openAIConfig.baseUrl == "" {
|
client := openai.NewClient(clientOptions...)
|
||||||
openAIConfig.baseUrl = fmt.Sprintf("https://%s", OPENAI_DOMAIN)
|
|
||||||
}
|
|
||||||
|
|
||||||
headers := map[string]string{
|
|
||||||
"Authorization": "Bearer " + config.APIKey,
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
httpClient := common.NewHTTPClient(openAIConfig.baseUrl, headers)
|
|
||||||
|
|
||||||
return &OpenAIProvider{
|
return &OpenAIProvider{
|
||||||
config: openAIConfig,
|
client: &client,
|
||||||
client: httpClient,
|
model: config.Model,
|
||||||
|
dimensions: config.Dimensions,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *OpenAIProvider) GetProviderType() string {
|
// EmbeddingClient handles vector embedding generation using OpenAI-compatible APIs
|
||||||
|
type OpenAIProvider struct {
|
||||||
|
client *openai.Client
|
||||||
|
model string
|
||||||
|
dimensions int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *OpenAIProvider) GetProviderType() string {
|
||||||
return PROVIDER_TYPE_OPENAI
|
return PROVIDER_TYPE_OPENAI
|
||||||
}
|
}
|
||||||
|
|
||||||
type OpenAIResponse struct {
|
// GetEmbedding generates vector embedding for the given text
|
||||||
Object string `json:"object"`
|
func (e *OpenAIProvider) GetEmbedding(ctx context.Context, text string) ([]float32, error) {
|
||||||
Data []OpenAIResult `json:"data"`
|
params := openai.EmbeddingNewParams{
|
||||||
Model string `json:"model"`
|
Model: e.model,
|
||||||
Error *OpenAIError `json:"error"`
|
Input: openai.EmbeddingNewParamsInputUnion{
|
||||||
}
|
OfString: openai.String(text),
|
||||||
|
},
|
||||||
type OpenAIResult struct {
|
Dimensions: openai.Int(int64(e.dimensions)),
|
||||||
Object string `json:"object"`
|
EncodingFormat: openai.EmbeddingNewParamsEncodingFormatFloat,
|
||||||
Embedding []float32 `json:"embedding"`
|
|
||||||
Index int `json:"index"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OpenAIError struct {
|
|
||||||
Message string `json:"prompt_tokens"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
Code string `json:"code"`
|
|
||||||
Param string `json:"param"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OpenAIEmbeddingRequest struct {
|
|
||||||
Input string `json:"input"`
|
|
||||||
Model string `json:"model"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OpenAIProvider struct {
|
|
||||||
config openAIProviderConfig
|
|
||||||
client *common.HTTPClient
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *OpenAIProvider) constructRequestData(text string) (OpenAIEmbeddingRequest, error) {
|
|
||||||
if text == "" {
|
|
||||||
return OpenAIEmbeddingRequest{}, errors.New("queryString text cannot be empty")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if openAIConfig.apiKey == "" {
|
embeddingResp, err := e.client.Embeddings.New(ctx, params)
|
||||||
return OpenAIEmbeddingRequest{}, errors.New("openAI apiKey is empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
model := o.config.model
|
|
||||||
if model == "" {
|
|
||||||
model = OPENAI_DEFAULT_MODEL_NAME
|
|
||||||
}
|
|
||||||
|
|
||||||
data := OpenAIEmbeddingRequest{
|
|
||||||
Input: text,
|
|
||||||
Model: model,
|
|
||||||
}
|
|
||||||
|
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *OpenAIProvider) parseTextEmbedding(responseBody []byte) (*OpenAIResponse, error) {
|
|
||||||
var resp OpenAIResponse
|
|
||||||
err := json.Unmarshal(responseBody, &resp)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to generate embedding: %w", err)
|
||||||
}
|
}
|
||||||
return &resp, nil
|
|
||||||
}
|
if len(embeddingResp.Data) == 0 {
|
||||||
|
return nil, fmt.Errorf("empty embedding response")
|
||||||
func (o *OpenAIProvider) GetEmbedding(ctx context.Context, queryString string) ([]float32, error) {
|
}
|
||||||
requestData, err := o.constructRequestData(queryString)
|
|
||||||
if err != nil {
|
// Convert []float64 to []float32
|
||||||
return nil, fmt.Errorf("failed to construct request data: %v", err)
|
embedding := make([]float32, len(embeddingResp.Data[0].Embedding))
|
||||||
}
|
for i, v := range embeddingResp.Data[0].Embedding {
|
||||||
|
embedding[i] = float32(v)
|
||||||
responseBody, err := o.client.Post(OPENAI_ENDPOINT, requestData)
|
}
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to send request: %v", err)
|
return embedding, nil
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := o.parseTextEmbedding(responseBody)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse response: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.Error != nil {
|
|
||||||
return nil, fmt.Errorf("OpenAI API error: %s - %s", resp.Error.Type, resp.Error.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(resp.Data) == 0 {
|
|
||||||
return nil, errors.New("no embedding found in response")
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp.Data[0].Embedding, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,21 +10,21 @@ import (
|
|||||||
// Provider type constants for different embedding services
|
// Provider type constants for different embedding services
|
||||||
const (
|
const (
|
||||||
// DashScope embedding service
|
// DashScope embedding service
|
||||||
PROVIDER_TYPE_DASHSCOPE = "dashscope"
|
PROVIDER_TYPE_DASHSCOPE = "dashscope"
|
||||||
// TextIn embedding service
|
// TextIn embedding service
|
||||||
PROVIDER_TYPE_TEXTIN = "textin"
|
PROVIDER_TYPE_TEXTIN = "textin"
|
||||||
// Cohere embedding service
|
// Cohere embedding service
|
||||||
PROVIDER_TYPE_COHERE = "cohere"
|
PROVIDER_TYPE_COHERE = "cohere"
|
||||||
// OpenAI embedding service
|
// OpenAI embedding service
|
||||||
PROVIDER_TYPE_OPENAI = "openai"
|
PROVIDER_TYPE_OPENAI = "openai"
|
||||||
// Ollama embedding service
|
// Ollama embedding service
|
||||||
PROVIDER_TYPE_OLLAMA = "ollama"
|
PROVIDER_TYPE_OLLAMA = "ollama"
|
||||||
// HuggingFace embedding service
|
// HuggingFace embedding service
|
||||||
PROVIDER_TYPE_HUGGINGFACE = "huggingface"
|
PROVIDER_TYPE_HUGGINGFACE = "huggingface"
|
||||||
// XFYun embedding service
|
// XFYun embedding service
|
||||||
PROVIDER_TYPE_XFYUN = "xfyun"
|
PROVIDER_TYPE_XFYUN = "xfyun"
|
||||||
// Azure embedding service
|
// Azure embedding service
|
||||||
PROVIDER_TYPE_AZURE = "azure"
|
PROVIDER_TYPE_AZURE = "azure"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Factory interface for creating Provider instances
|
// Factory interface for creating Provider instances
|
||||||
@@ -36,8 +36,7 @@ type providerInitializer interface {
|
|||||||
// Maps provider types to their initializers
|
// Maps provider types to their initializers
|
||||||
var (
|
var (
|
||||||
providerInitializers = map[string]providerInitializer{
|
providerInitializers = map[string]providerInitializer{
|
||||||
PROVIDER_TYPE_DASHSCOPE: &dashScopeProviderInitializer{},
|
PROVIDER_TYPE_OPENAI: &openAIProviderInitializer{},
|
||||||
PROVIDER_TYPE_OPENAI: &openAIProviderInitializer{},
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -2,133 +2,105 @@ package llm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/common"
|
|
||||||
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
||||||
|
"github.com/openai/openai-go/v2"
|
||||||
|
"github.com/openai/openai-go/v2/option"
|
||||||
|
"github.com/openai/openai-go/v2/packages/param"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
OPENAI_CHAT_ENDPOINT = "/chat/completions"
|
|
||||||
OPENAI_DEFAULT_MODEL = "gpt-4o"
|
OPENAI_DEFAULT_MODEL = "gpt-4o"
|
||||||
)
|
)
|
||||||
|
|
||||||
// openAI specific configuration captured after initialization.
|
type OpenAIProvider struct {
|
||||||
type openAIProviderConfig struct {
|
client *openai.Client
|
||||||
apiKey string
|
|
||||||
baseURL string
|
|
||||||
model string
|
model string
|
||||||
maxTokens int
|
|
||||||
temperature float64
|
temperature float64
|
||||||
|
maxTokens int
|
||||||
}
|
}
|
||||||
|
|
||||||
type openAIProviderInitializer struct{}
|
type openAIProviderInitializer struct{}
|
||||||
|
|
||||||
var openAIConfig openAIProviderConfig
|
func (i *openAIProviderInitializer) validateConfig(cfg *config.LLMConfig) error {
|
||||||
|
if cfg.APIKey == "" {
|
||||||
func (i *openAIProviderInitializer) initConfig(c config.LLMConfig) {
|
|
||||||
openAIConfig.apiKey = c.APIKey
|
|
||||||
openAIConfig.baseURL = c.BaseURL
|
|
||||||
openAIConfig.model = c.Model
|
|
||||||
if openAIConfig.model == "" {
|
|
||||||
openAIConfig.model = OPENAI_DEFAULT_MODEL
|
|
||||||
}
|
|
||||||
if openAIConfig.baseURL == "" {
|
|
||||||
openAIConfig.baseURL = "https://api.openai.com/v1" // default public endpoint
|
|
||||||
}
|
|
||||||
openAIConfig.maxTokens = c.MaxTokens
|
|
||||||
openAIConfig.temperature = c.Temperature
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *openAIProviderInitializer) validateConfig() error {
|
|
||||||
if openAIConfig.apiKey == "" {
|
|
||||||
return errors.New("[openai llm] apiKey is required")
|
return errors.New("[openai llm] apiKey is required")
|
||||||
}
|
}
|
||||||
|
if cfg.Model == "" {
|
||||||
|
cfg.Model = OPENAI_DEFAULT_MODEL
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Temperature <= 0 || cfg.Temperature > 2 {
|
||||||
|
cfg.Temperature = 0.5
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.MaxTokens <= 0 {
|
||||||
|
cfg.MaxTokens = 2048
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *openAIProviderInitializer) CreateProvider(cfg config.LLMConfig) (Provider, error) {
|
func (i *openAIProviderInitializer) CreateProvider(cfg config.LLMConfig) (Provider, error) {
|
||||||
i.initConfig(cfg)
|
if err := i.validateConfig(&cfg); err != nil {
|
||||||
if err := i.validateConfig(); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
headers := map[string]string{
|
// Create OpenAI client
|
||||||
"Authorization": "Bearer " + openAIConfig.apiKey,
|
var clientOptions []option.RequestOption
|
||||||
"Content-Type": "application/json",
|
clientOptions = append(clientOptions, option.WithAPIKey(cfg.APIKey))
|
||||||
|
|
||||||
|
// If a custom baseURL is set, use it
|
||||||
|
if cfg.BaseURL != "" {
|
||||||
|
clientOptions = append(clientOptions, option.WithBaseURL(cfg.BaseURL))
|
||||||
}
|
}
|
||||||
client := common.NewHTTPClient(openAIConfig.baseURL, headers)
|
|
||||||
return &OpenAIProvider{client: client, cfg: openAIConfig}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type OpenAIProvider struct {
|
// Create OpenAI client
|
||||||
client *common.HTTPClient
|
client := openai.NewClient(clientOptions...)
|
||||||
cfg openAIProviderConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
type openAIChatCompletionRequest struct {
|
return &OpenAIProvider{
|
||||||
Model string `json:"model"`
|
client: &client,
|
||||||
Messages []openAIChatMessage `json:"messages"`
|
model: cfg.Model,
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
temperature: cfg.Temperature,
|
||||||
MaxTokens int `json:"max_tokens,omitempty"`
|
maxTokens: cfg.MaxTokens,
|
||||||
}
|
}, nil
|
||||||
|
|
||||||
type openAIChatMessage struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type openAIChatCompletionResponse struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
Choices []openAIChatCompletionResponseChoice `json:"choices"`
|
|
||||||
Error *openAIError `json:"error,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type openAIChatCompletionResponseChoice struct {
|
|
||||||
Index int `json:"index"`
|
|
||||||
Message openAIChatMessage `json:"message"`
|
|
||||||
FinishReason string `json:"finish_reason"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type openAIError struct {
|
|
||||||
Message string `json:"message"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
Code string `json:"code"`
|
|
||||||
Param string `json:"param"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateCompletion implements Provider interface.
|
// GenerateCompletion implements Provider interface.
|
||||||
func (o *OpenAIProvider) GenerateCompletion(ctx context.Context, prompt string) (string, error) {
|
func (o *OpenAIProvider) GenerateCompletion(ctx context.Context, prompt string) (string, error) {
|
||||||
req := openAIChatCompletionRequest{
|
// Create chat request
|
||||||
Model: o.cfg.model,
|
params := openai.ChatCompletionNewParams{
|
||||||
Messages: []openAIChatMessage{
|
Model: o.model,
|
||||||
{Role: "user", Content: prompt},
|
Messages: []openai.ChatCompletionMessageParamUnion{
|
||||||
|
openai.UserMessage(prompt),
|
||||||
},
|
},
|
||||||
Temperature: o.cfg.temperature,
|
|
||||||
MaxTokens: o.cfg.maxTokens,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := o.client.Post(OPENAI_CHAT_ENDPOINT, req)
|
// Set optional parameters
|
||||||
|
if o.temperature > 0 {
|
||||||
|
temperature := float64(o.temperature)
|
||||||
|
params.Temperature = param.Opt[float64]{Value: temperature}
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.maxTokens > 0 {
|
||||||
|
maxTokens := int64(o.maxTokens)
|
||||||
|
params.MaxTokens = param.Opt[int64]{Value: maxTokens}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send request
|
||||||
|
response, err := o.client.Chat.Completions.New(ctx, params)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("openai llm post error: %w", err)
|
// Handle error
|
||||||
|
return "", fmt.Errorf("openai llm error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var resp openAIChatCompletionResponse
|
// Check response
|
||||||
if err := json.Unmarshal(body, &resp); err != nil {
|
if len(response.Choices) == 0 {
|
||||||
return "", fmt.Errorf("openai llm unmarshal error: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.Error != nil {
|
|
||||||
return "", fmt.Errorf("openai llm api error: %s - %s", resp.Error.Type, resp.Error.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(resp.Choices) == 0 {
|
|
||||||
return "", errors.New("openai llm: empty choices")
|
return "", errors.New("openai llm: empty choices")
|
||||||
}
|
}
|
||||||
|
|
||||||
return resp.Choices[0].Message.Content, nil
|
// Return generated content
|
||||||
|
return response.Choices[0].Message.Content, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *OpenAIProvider) GetProviderType() string {
|
func (o *OpenAIProvider) GetProviderType() string {
|
||||||
|
|||||||
@@ -56,18 +56,12 @@ func NewRAGClient(config *config.Config) (*RAGClient, error) {
|
|||||||
ragclient.llmProvider = llmProvider
|
ragclient.llmProvider = llmProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
demoVector, err := embeddingProvider.GetEmbedding(context.Background(), "initialization")
|
dim := ragclient.config.Embedding.Dimensions
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create init embedding failed, err: %w", err)
|
|
||||||
}
|
|
||||||
dim := len(demoVector)
|
|
||||||
|
|
||||||
provider, err := vectordb.NewVectorDBProvider(&ragclient.config.VectorDB, dim)
|
provider, err := vectordb.NewVectorDBProvider(&ragclient.config.VectorDB, dim)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create vector store provider failed, err: %w", err)
|
return nil, fmt.Errorf("create vector store provider failed, err: %w", err)
|
||||||
}
|
}
|
||||||
ragclient.vectordbProvider = provider
|
ragclient.vectordbProvider = provider
|
||||||
|
|
||||||
return ragclient, nil
|
return ragclient, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -22,15 +22,17 @@ func getRAGClient() (*RAGClient, error) {
|
|||||||
|
|
||||||
LLM: config.LLMConfig{
|
LLM: config.LLMConfig{
|
||||||
Provider: "openai",
|
Provider: "openai",
|
||||||
APIKey: "sk-xxxx",
|
APIKey: "sk-xxx",
|
||||||
BaseURL: "https://openrouter.ai/api/v1",
|
BaseURL: "https://openrouter.ai/api/v1",
|
||||||
Model: "openai/gpt-4o",
|
Model: "openai/gpt-4o",
|
||||||
},
|
},
|
||||||
|
|
||||||
Embedding: config.EmbeddingConfig{
|
Embedding: config.EmbeddingConfig{
|
||||||
Provider: "dashscope",
|
Provider: "openai",
|
||||||
APIKey: "sk-xxxx",
|
BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||||
Model: "text-embedding-v4",
|
APIKey: "sk-xxxx",
|
||||||
|
Model: "text-embedding-v4",
|
||||||
|
Dimensions: 1536,
|
||||||
},
|
},
|
||||||
|
|
||||||
VectorDB: config.VectorDBConfig{
|
VectorDB: config.VectorDBConfig{
|
||||||
@@ -38,7 +40,49 @@ func getRAGClient() (*RAGClient, error) {
|
|||||||
Host: "localhost",
|
Host: "localhost",
|
||||||
Port: 19530,
|
Port: 19530,
|
||||||
Database: "default",
|
Database: "default",
|
||||||
Collection: "test_collection",
|
Collection: "test_collection3",
|
||||||
|
Mapping: config.MappingConfig{
|
||||||
|
Fields: []config.FieldMapping{
|
||||||
|
{
|
||||||
|
StandardName: "id",
|
||||||
|
RawName: "pk",
|
||||||
|
Properties: map[string]interface{}{
|
||||||
|
"max_length": 256,
|
||||||
|
"auto_id": false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
StandardName: "content",
|
||||||
|
RawName: "page_content",
|
||||||
|
Properties: map[string]interface{}{
|
||||||
|
"max_length": 8192,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
StandardName: "vector",
|
||||||
|
RawName: "page_vector",
|
||||||
|
Properties: make(map[string]interface{}),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
StandardName: "metadata",
|
||||||
|
RawName: "metadata",
|
||||||
|
Properties: make(map[string]interface{}),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
StandardName: "created_at",
|
||||||
|
RawName: "created_at",
|
||||||
|
Properties: make(map[string]interface{}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Index: config.IndexConfig{
|
||||||
|
IndexType: "IVF_FLAT",
|
||||||
|
Params: map[string]interface{}{"nlist": 64},
|
||||||
|
},
|
||||||
|
Search: config.SearchConfig{
|
||||||
|
MetricType: "COSINE",
|
||||||
|
Params: map[string]interface{}{"nprobe": 32},
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,7 +92,6 @@ func getRAGClient() (*RAGClient, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return ragClient, nil
|
return ragClient, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewRAGClient(t *testing.T) {
|
func TestNewRAGClient(t *testing.T) {
|
||||||
@@ -104,7 +147,7 @@ func TestRAGClient_DeleteChunk(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
chunk_id := "63ee25d7-41b9-4455-8066-075ca5c803b2"
|
chunk_id := "2a06679c-a8ea-46dc-bf1c-7e7b164a73c8"
|
||||||
err = ragClient.DeleteChunk(chunk_id)
|
err = ragClient.DeleteChunk(chunk_id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("DeleteChunk() error = %v", err)
|
t.Errorf("DeleteChunk() error = %v", err)
|
||||||
|
|||||||
@@ -36,11 +36,11 @@ func init() {
|
|||||||
MaxTokens: 2048,
|
MaxTokens: 2048,
|
||||||
},
|
},
|
||||||
Embedding: config.EmbeddingConfig{
|
Embedding: config.EmbeddingConfig{
|
||||||
Provider: "dashscope",
|
Provider: "openai",
|
||||||
APIKey: "",
|
APIKey: "",
|
||||||
BaseURL: "",
|
BaseURL: "",
|
||||||
Model: "text-embedding-v4",
|
Model: "text-embedding-ada-002",
|
||||||
Dimension: 1024,
|
Dimensions: 1536,
|
||||||
},
|
},
|
||||||
VectorDB: config.VectorDBConfig{
|
VectorDB: config.VectorDBConfig{
|
||||||
Provider: "milvus",
|
Provider: "milvus",
|
||||||
@@ -50,14 +50,56 @@ func init() {
|
|||||||
Collection: "rag",
|
Collection: "rag",
|
||||||
Username: "",
|
Username: "",
|
||||||
Password: "",
|
Password: "",
|
||||||
|
Mapping: config.MappingConfig{
|
||||||
|
Fields: []config.FieldMapping{
|
||||||
|
{
|
||||||
|
StandardName: "id",
|
||||||
|
RawName: "id",
|
||||||
|
Properties: map[string]interface{}{
|
||||||
|
"max_length": 256,
|
||||||
|
"auto_id": false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
StandardName: "content",
|
||||||
|
RawName: "content",
|
||||||
|
Properties: map[string]interface{}{
|
||||||
|
"max_length": 8192,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
StandardName: "vector",
|
||||||
|
RawName: "vector",
|
||||||
|
Properties: make(map[string]interface{}),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
StandardName: "metadata",
|
||||||
|
RawName: "metadata",
|
||||||
|
Properties: make(map[string]interface{}),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
StandardName: "created_at",
|
||||||
|
RawName: "created_at",
|
||||||
|
Properties: make(map[string]interface{}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Index: config.IndexConfig{
|
||||||
|
IndexType: "HNSW",
|
||||||
|
Params: map[string]interface{}{"M": 8, "efConstruction": 64},
|
||||||
|
},
|
||||||
|
Search: config.SearchConfig{
|
||||||
|
MetricType: "IP",
|
||||||
|
Params: make(map[string]interface{}),
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RAGConfig) ParseConfig(config map[string]any) error {
|
func (c *RAGConfig) ParseConfig(cfg map[string]any) error {
|
||||||
// Parse RAG configuration
|
// Parse RAG configuration
|
||||||
if ragConfig, ok := config["rag"].(map[string]any); ok {
|
if ragConfig, ok := cfg["rag"].(map[string]any); ok {
|
||||||
if splitter, exists := ragConfig["splitter"].(map[string]any); exists {
|
if splitter, exists := ragConfig["splitter"].(map[string]any); exists {
|
||||||
if splitterType, exists := splitter["provider"].(string); exists {
|
if splitterType, exists := splitter["provider"].(string); exists {
|
||||||
c.config.RAG.Splitter.Provider = splitterType
|
c.config.RAG.Splitter.Provider = splitterType
|
||||||
@@ -78,7 +120,7 @@ func (c *RAGConfig) ParseConfig(config map[string]any) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Parse Embedding configuration
|
// Parse Embedding configuration
|
||||||
if embeddingConfig, ok := config["embedding"].(map[string]any); ok {
|
if embeddingConfig, ok := cfg["embedding"].(map[string]any); ok {
|
||||||
if provider, exists := embeddingConfig["provider"].(string); exists {
|
if provider, exists := embeddingConfig["provider"].(string); exists {
|
||||||
c.config.Embedding.Provider = provider
|
c.config.Embedding.Provider = provider
|
||||||
} else {
|
} else {
|
||||||
@@ -94,13 +136,13 @@ func (c *RAGConfig) ParseConfig(config map[string]any) error {
|
|||||||
if model, exists := embeddingConfig["model"].(string); exists {
|
if model, exists := embeddingConfig["model"].(string); exists {
|
||||||
c.config.Embedding.Model = model
|
c.config.Embedding.Model = model
|
||||||
}
|
}
|
||||||
if dimension, exists := embeddingConfig["dimension"].(float64); exists {
|
if dimensions, exists := embeddingConfig["dimensions"].(float64); exists {
|
||||||
c.config.Embedding.Dimension = int(dimension)
|
c.config.Embedding.Dimensions = int(dimensions)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse llm configuration
|
// Parse llm configuration
|
||||||
if llmConfig, ok := config["llm"].(map[string]any); ok {
|
if llmConfig, ok := cfg["llm"].(map[string]any); ok {
|
||||||
if provider, exists := llmConfig["provider"].(string); exists {
|
if provider, exists := llmConfig["provider"].(string); exists {
|
||||||
c.config.LLM.Provider = provider
|
c.config.LLM.Provider = provider
|
||||||
}
|
}
|
||||||
@@ -122,7 +164,7 @@ func (c *RAGConfig) ParseConfig(config map[string]any) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Parse VectorDB configuration
|
// Parse VectorDB configuration
|
||||||
if vectordbConfig, ok := config["vectordb"].(map[string]any); ok {
|
if vectordbConfig, ok := cfg["vectordb"].(map[string]any); ok {
|
||||||
if provider, exists := vectordbConfig["provider"].(string); exists {
|
if provider, exists := vectordbConfig["provider"].(string); exists {
|
||||||
c.config.VectorDB.Provider = provider
|
c.config.VectorDB.Provider = provider
|
||||||
} else {
|
} else {
|
||||||
@@ -146,8 +188,59 @@ func (c *RAGConfig) ParseConfig(config map[string]any) error {
|
|||||||
if password, exists := vectordbConfig["password"].(string); exists {
|
if password, exists := vectordbConfig["password"].(string); exists {
|
||||||
c.config.VectorDB.Password = password
|
c.config.VectorDB.Password = password
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
// Parse mapping here
|
||||||
|
if mapping, exists := vectordbConfig["mapping"].(map[string]any); exists {
|
||||||
|
// Parse field mappings
|
||||||
|
if fields, ok := mapping["fields"].([]any); ok {
|
||||||
|
c.config.VectorDB.Mapping.Fields = []config.FieldMapping{}
|
||||||
|
for _, field := range fields {
|
||||||
|
if fieldMap, ok := field.(map[string]any); ok {
|
||||||
|
fieldMapping := config.FieldMapping{
|
||||||
|
Properties: make(map[string]interface{}),
|
||||||
|
}
|
||||||
|
if standardName, ok := fieldMap["standard_name"].(string); ok {
|
||||||
|
fieldMapping.StandardName = standardName
|
||||||
|
}
|
||||||
|
|
||||||
|
if rawName, ok := fieldMap["raw_name"].(string); ok {
|
||||||
|
fieldMapping.RawName = rawName
|
||||||
|
}
|
||||||
|
// Parse properties
|
||||||
|
if properties, ok := fieldMap["properties"].(map[string]any); ok {
|
||||||
|
for key, value := range properties {
|
||||||
|
fieldMapping.Properties[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.config.VectorDB.Mapping.Fields = append(c.config.VectorDB.Mapping.Fields, fieldMapping)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse index configuration
|
||||||
|
if index, ok := mapping["index"].(map[string]any); ok {
|
||||||
|
if indexType, ok := index["index_type"].(string); ok {
|
||||||
|
c.config.VectorDB.Mapping.Index.IndexType = indexType
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse index parameters
|
||||||
|
if params, ok := index["params"].(map[string]any); ok {
|
||||||
|
c.config.VectorDB.Mapping.Index.Params = params
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse search configuration
|
||||||
|
if search, ok := mapping["search"].(map[string]any); ok {
|
||||||
|
if metricType, ok := search["metric_type"].(string); ok {
|
||||||
|
c.config.VectorDB.Mapping.Search.MetricType = metricType
|
||||||
|
}
|
||||||
|
// Parse search parameters
|
||||||
|
if params, ok := search["params"].(map[string]any); ok {
|
||||||
|
c.config.VectorDB.Mapping.Search.Params = params
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -28,11 +28,11 @@ func TestRAGConfig_ParseConfig(t *testing.T) {
|
|||||||
MaxTokens: 2048,
|
MaxTokens: 2048,
|
||||||
},
|
},
|
||||||
Embedding: config.EmbeddingConfig{
|
Embedding: config.EmbeddingConfig{
|
||||||
Provider: "dashscope",
|
Provider: "dashscope",
|
||||||
APIKey: "sk-XXX",
|
APIKey: "sk-XXX",
|
||||||
BaseURL: "",
|
BaseURL: "",
|
||||||
Model: "text-embedding-v4",
|
Model: "text-embedding-v4",
|
||||||
Dimension: 1024,
|
Dimensions: 1024,
|
||||||
},
|
},
|
||||||
VectorDB: config.VectorDBConfig{
|
VectorDB: config.VectorDBConfig{
|
||||||
Provider: "milvus",
|
Provider: "milvus",
|
||||||
@@ -42,6 +42,48 @@ func TestRAGConfig_ParseConfig(t *testing.T) {
|
|||||||
Collection: "test_rag",
|
Collection: "test_rag",
|
||||||
Username: "",
|
Username: "",
|
||||||
Password: "",
|
Password: "",
|
||||||
|
Mapping: config.MappingConfig{
|
||||||
|
Fields: []config.FieldMapping{
|
||||||
|
{
|
||||||
|
StandardName: "id",
|
||||||
|
RawName: "id",
|
||||||
|
Properties: map[string]interface{}{
|
||||||
|
"max_length": 256,
|
||||||
|
"auto_id": false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
StandardName: "content",
|
||||||
|
RawName: "content",
|
||||||
|
Properties: map[string]interface{}{
|
||||||
|
"max_length": 8192,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
StandardName: "vector",
|
||||||
|
RawName: "vector",
|
||||||
|
Properties: make(map[string]interface{}),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
StandardName: "metadata",
|
||||||
|
RawName: "metadata",
|
||||||
|
Properties: make(map[string]interface{}),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
StandardName: "created_at",
|
||||||
|
RawName: "created_at",
|
||||||
|
Properties: make(map[string]interface{}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Index: config.IndexConfig{
|
||||||
|
IndexType: "HNSW",
|
||||||
|
Params: map[string]interface{}{"M": 4, "efConstruction": 32},
|
||||||
|
},
|
||||||
|
Search: config.SearchConfig{
|
||||||
|
MetricType: "IP",
|
||||||
|
Params: map[string]interface{}{"ef": 32},
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
// 把 config 输出 yaml 格式
|
// 把 config 输出 yaml 格式
|
||||||
|
|||||||
182
plugins/golang-filter/mcp-server/servers/rag/vectordb/mapper.go
Normal file
182
plugins/golang-filter/mcp-server/servers/rag/vectordb/mapper.go
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
package vectordb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/alibaba/higress/plugins/golang-filter/mcp-server/servers/rag/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Error definitions
|
||||||
|
var (
|
||||||
|
ErrFieldNotFound = errors.New("field not found")
|
||||||
|
ErrInvalidFieldType = errors.New("invalid field type")
|
||||||
|
ErrInvalidIndexType = errors.New("invalid index type")
|
||||||
|
ErrInvalidMetricType = errors.New("invalid metric type")
|
||||||
|
ErrInvalidSearchParams = errors.New("invalid search parameters")
|
||||||
|
ErrCollectionNotFound = errors.New("collection not found")
|
||||||
|
ErrUnsupportedOperation = errors.New("unsupported operation")
|
||||||
|
)
|
||||||
|
|
||||||
|
// VectorDBMapper interface for vector database mapping
|
||||||
|
type VectorDBMapper interface {
|
||||||
|
// ParseMapping parses the mapping configuration
|
||||||
|
ParseMapping(provider string, cfg config.MappingConfig) error
|
||||||
|
|
||||||
|
// GetIndexConfig returns the index configuration
|
||||||
|
GetIndexConfig() (config.IndexConfig, error)
|
||||||
|
|
||||||
|
// GetSearchConfig returns the search configuration
|
||||||
|
GetSearchConfig() (config.SearchConfig, error)
|
||||||
|
|
||||||
|
// Get all raw field names
|
||||||
|
GetRawAllFieldNames() ([]string, error)
|
||||||
|
|
||||||
|
// GetIDField returns the ID field mapping
|
||||||
|
GetIDField() (*config.FieldMapping, error)
|
||||||
|
|
||||||
|
// GetVectorField returns the vector field mapping
|
||||||
|
GetVectorField() (*config.FieldMapping, error)
|
||||||
|
|
||||||
|
// Get raw field name by standard field name
|
||||||
|
GetRawField(standardFieldName string) (*config.FieldMapping, error)
|
||||||
|
|
||||||
|
// Get field mapping by raw field name
|
||||||
|
GetField(rawFieldName string) (*config.FieldMapping, error)
|
||||||
|
|
||||||
|
// Get all field mappings
|
||||||
|
GetFieldMappings() ([]config.FieldMapping, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultVectorDBMapper is the default implementation of VectorDBMapper interface
|
||||||
|
type DefaultVectorDBMapper struct {
|
||||||
|
// Mapping configuration
|
||||||
|
mappingConfig config.MappingConfig
|
||||||
|
// Map from standard field name to field mapping
|
||||||
|
standardFieldMap map[string]*config.FieldMapping
|
||||||
|
// Map from raw field name to field mapping
|
||||||
|
rawFieldMap map[string]*config.FieldMapping
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultVectorDBMapper creates a new default vector database mapper
|
||||||
|
func NewDefaultVectorDBMapper(provider string, mappingConfig config.MappingConfig) (*DefaultVectorDBMapper, error) {
|
||||||
|
mapper := &DefaultVectorDBMapper{
|
||||||
|
standardFieldMap: make(map[string]*config.FieldMapping),
|
||||||
|
rawFieldMap: make(map[string]*config.FieldMapping),
|
||||||
|
}
|
||||||
|
if err := mapper.ParseMapping(provider, mappingConfig); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return mapper, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseMapping parses the mapping configuration
|
||||||
|
func (m *DefaultVectorDBMapper) ParseMapping(provider string, cfg config.MappingConfig) error {
|
||||||
|
m.mappingConfig = cfg
|
||||||
|
// Clear existing mappings
|
||||||
|
m.standardFieldMap = make(map[string]*config.FieldMapping)
|
||||||
|
m.rawFieldMap = make(map[string]*config.FieldMapping)
|
||||||
|
// fill default field mappings
|
||||||
|
if len(cfg.Fields) == 0 {
|
||||||
|
defaultFields := []config.FieldMapping{
|
||||||
|
{
|
||||||
|
StandardName: "id",
|
||||||
|
RawName: "id",
|
||||||
|
Properties: map[string]interface{}{
|
||||||
|
"max_length": 256,
|
||||||
|
"auto_id": false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
StandardName: "content",
|
||||||
|
RawName: "content",
|
||||||
|
Properties: map[string]interface{}{
|
||||||
|
"max_length": 8192,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
StandardName: "vector",
|
||||||
|
RawName: "vector",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
StandardName: "metadata",
|
||||||
|
RawName: "metadata",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
StandardName: "created_at",
|
||||||
|
RawName: "created_at",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cfg.Fields = defaultFields
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse field mappings
|
||||||
|
for i, field := range cfg.Fields {
|
||||||
|
// Save pointer for future reference
|
||||||
|
fieldPtr := &cfg.Fields[i]
|
||||||
|
m.standardFieldMap[field.StandardName] = fieldPtr
|
||||||
|
m.rawFieldMap[field.RawName] = fieldPtr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check fields, must include id, content, vector fields
|
||||||
|
requiredFields := []string{"id", "content", "vector"}
|
||||||
|
for _, fieldName := range requiredFields {
|
||||||
|
if _, err := m.GetRawField(fieldName); err != nil {
|
||||||
|
return fmt.Errorf("[vector db mapper] required field %s not found or not varchar type", fieldName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIndexConfig gets the index configuration
|
||||||
|
func (m *DefaultVectorDBMapper) GetIndexConfig() (config.IndexConfig, error) {
|
||||||
|
return m.mappingConfig.Index, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSearchConfig gets the search configuration
|
||||||
|
func (m *DefaultVectorDBMapper) GetSearchConfig() (config.SearchConfig, error) {
|
||||||
|
return m.mappingConfig.Search, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRawAllFieldNames gets all raw field names
|
||||||
|
func (m *DefaultVectorDBMapper) GetRawAllFieldNames() ([]string, error) {
|
||||||
|
fieldNames := make([]string, 0, len(m.rawFieldMap))
|
||||||
|
for name := range m.rawFieldMap {
|
||||||
|
fieldNames = append(fieldNames, name)
|
||||||
|
}
|
||||||
|
return fieldNames, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIDField gets the ID field
|
||||||
|
func (m *DefaultVectorDBMapper) GetIDField() (*config.FieldMapping, error) {
|
||||||
|
return m.GetRawField("id")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetVectorField gets the vector field
|
||||||
|
func (m *DefaultVectorDBMapper) GetVectorField() (*config.FieldMapping, error) {
|
||||||
|
return m.GetRawField("vector")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRawField gets the raw field mapping by standard field name
|
||||||
|
func (m *DefaultVectorDBMapper) GetRawField(standardFieldName string) (*config.FieldMapping, error) {
|
||||||
|
field, exists := m.standardFieldMap[standardFieldName]
|
||||||
|
if !exists {
|
||||||
|
return nil, fmt.Errorf("%w: standard field %s not found", ErrFieldNotFound, standardFieldName)
|
||||||
|
}
|
||||||
|
return field, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetField gets the field mapping by raw field name
|
||||||
|
func (m *DefaultVectorDBMapper) GetField(rawFieldName string) (*config.FieldMapping, error) {
|
||||||
|
field, exists := m.rawFieldMap[rawFieldName]
|
||||||
|
if !exists {
|
||||||
|
return nil, fmt.Errorf("%w: raw field %s not found", ErrFieldNotFound, rawFieldName)
|
||||||
|
}
|
||||||
|
return field, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFieldMappings gets all field mappings
|
||||||
|
func (m *DefaultVectorDBMapper) GetFieldMappings() ([]config.FieldMapping, error) {
|
||||||
|
return m.mappingConfig.Fields, nil
|
||||||
|
}
|
||||||
@@ -80,16 +80,17 @@ func (m *milvusProviderInitializer) CreateProvider(cfg *config.VectorDBConfig, d
|
|||||||
type MilvusProvider struct {
|
type MilvusProvider struct {
|
||||||
client client.Client
|
client client.Client
|
||||||
config *config.VectorDBConfig
|
config *config.VectorDBConfig
|
||||||
Collection string
|
collection string
|
||||||
|
mapper VectorDBMapper
|
||||||
|
dimensions int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMilvusProvider creates a new instance of MilvusProvider
|
// NewMilvusProvider creates a new instance of MilvusProvider
|
||||||
func NewMilvusProvider(cfg *config.VectorDBConfig, dim int) (VectorStoreProvider, error) {
|
func NewMilvusProvider(cfg *config.VectorDBConfig, dimensions int) (VectorStoreProvider, error) {
|
||||||
// Create Milvus client
|
// Create Milvus client
|
||||||
connectParam := client.Config{
|
connectParam := client.Config{
|
||||||
Address: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
Address: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
|
||||||
}
|
}
|
||||||
|
|
||||||
connectParam.DBName = cfg.Database
|
connectParam.DBName = cfg.Database
|
||||||
// Add authentication if credentials are provided
|
// Add authentication if credentials are provided
|
||||||
if cfg.Username != "" && cfg.Password != "" {
|
if cfg.Username != "" && cfg.Password != "" {
|
||||||
@@ -102,92 +103,301 @@ func NewMilvusProvider(cfg *config.VectorDBConfig, dim int) (VectorStoreProvider
|
|||||||
return nil, fmt.Errorf("failed to create milvus client: %w", err)
|
return nil, fmt.Errorf("failed to create milvus client: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mapper, err := NewDefaultVectorDBMapper(MILVUS_PROVIDER_TYPE, cfg.Mapping)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create default vector db mapper: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
provider := &MilvusProvider{
|
provider := &MilvusProvider{
|
||||||
client: milvusClient,
|
client: milvusClient,
|
||||||
config: cfg,
|
config: cfg,
|
||||||
Collection: cfg.Collection,
|
collection: cfg.Collection,
|
||||||
|
mapper: mapper,
|
||||||
|
dimensions: dimensions,
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
if err := provider.CreateCollection(ctx, dim); err != nil {
|
if err := provider.CreateCollection(ctx, dimensions); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return provider, nil
|
return provider, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MilvusProvider) buildSchema() (*entity.Schema, error) {
|
||||||
|
// Create Milvus collection Schema
|
||||||
|
idField, _ := m.mapper.GetIDField()
|
||||||
|
isIDAuto := idField.IsAutoID()
|
||||||
|
schema := entity.NewSchema().
|
||||||
|
WithName(m.collection).
|
||||||
|
WithDescription("Knowledge document collection").
|
||||||
|
WithAutoID(isIDAuto).
|
||||||
|
WithDynamicFieldEnabled(false)
|
||||||
|
// Add fields
|
||||||
|
var fieldEntity *entity.Field
|
||||||
|
fieldMappings, _ := m.mapper.GetFieldMappings()
|
||||||
|
for _, field := range fieldMappings {
|
||||||
|
fieldEntity = nil
|
||||||
|
maxLength := field.MaxLength()
|
||||||
|
switch field.StandardName {
|
||||||
|
case "id":
|
||||||
|
isIDAuto := field.IsAutoID()
|
||||||
|
fieldEntity = entity.NewField().
|
||||||
|
WithName(field.RawName).
|
||||||
|
WithDataType(entity.FieldTypeVarChar).
|
||||||
|
WithMaxLength(int64(maxLength)).
|
||||||
|
WithIsPrimaryKey(true)
|
||||||
|
if isIDAuto {
|
||||||
|
fieldEntity.WithIsAutoID(true)
|
||||||
|
}
|
||||||
|
schema.WithField(fieldEntity)
|
||||||
|
case "content":
|
||||||
|
fieldEntity = entity.NewField().
|
||||||
|
WithName(field.RawName).
|
||||||
|
WithDataType(entity.FieldTypeVarChar).
|
||||||
|
WithMaxLength(int64(maxLength))
|
||||||
|
schema.WithField(fieldEntity)
|
||||||
|
case "vector":
|
||||||
|
fieldEntity = entity.NewField().
|
||||||
|
WithName(field.RawName).
|
||||||
|
WithDataType(entity.FieldTypeFloatVector).
|
||||||
|
WithDim(int64(m.dimensions))
|
||||||
|
schema.WithField(fieldEntity)
|
||||||
|
case "metadata":
|
||||||
|
fieldEntity = entity.NewField().
|
||||||
|
WithName(field.RawName).
|
||||||
|
WithDataType(entity.FieldTypeJSON)
|
||||||
|
schema.WithField(fieldEntity)
|
||||||
|
case "created_at":
|
||||||
|
fieldEntity = entity.NewField().
|
||||||
|
WithName(field.RawName).
|
||||||
|
WithDataType(entity.FieldTypeInt64)
|
||||||
|
schema.WithField(fieldEntity)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return schema, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MilvusProvider) GetMetricType(metricType string) entity.MetricType {
|
||||||
|
switch strings.ToUpper(metricType) {
|
||||||
|
case "L2":
|
||||||
|
return entity.L2
|
||||||
|
case "IP":
|
||||||
|
return entity.IP
|
||||||
|
case "COSINE":
|
||||||
|
return entity.COSINE
|
||||||
|
case "HAMMING":
|
||||||
|
return entity.HAMMING
|
||||||
|
case "JACCARD":
|
||||||
|
return entity.JACCARD
|
||||||
|
case "TANIMOTO":
|
||||||
|
return entity.TANIMOTO
|
||||||
|
case "SUBSTRUCTURE":
|
||||||
|
return entity.SUBSTRUCTURE
|
||||||
|
case "SUPERSTRUCTURE":
|
||||||
|
return entity.SUPERSTRUCTURE
|
||||||
|
default:
|
||||||
|
return entity.IP
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MilvusProvider) buildVectorIndex() (entity.Index, error) {
|
||||||
|
// Map index type
|
||||||
|
indexConfig, _ := m.mapper.GetIndexConfig()
|
||||||
|
searchConfig, _ := m.mapper.GetSearchConfig()
|
||||||
|
// Map index parameters
|
||||||
|
milvusIndexType := strings.ToUpper(indexConfig.IndexType)
|
||||||
|
if milvusIndexType == "" {
|
||||||
|
milvusIndexType = "HNSW"
|
||||||
|
}
|
||||||
|
metricType := m.GetMetricType(searchConfig.MetricType)
|
||||||
|
switch milvusIndexType {
|
||||||
|
case "FLAT":
|
||||||
|
// FLAT index doesn't need additional parameters
|
||||||
|
index, err := entity.NewIndexFlat(metricType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create FLAT index: %w", err)
|
||||||
|
}
|
||||||
|
return index, nil
|
||||||
|
|
||||||
|
case "BIN_FLAT":
|
||||||
|
// BIN_FLAT index doesn't need additional parameters
|
||||||
|
nlist := 128
|
||||||
|
if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil {
|
||||||
|
nlist = int(nlistVal)
|
||||||
|
}
|
||||||
|
index, err := entity.NewIndexBinFlat(metricType, nlist)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create BIN_FLAT index: %w", err)
|
||||||
|
}
|
||||||
|
return index, nil
|
||||||
|
|
||||||
|
case "IVF_FLAT":
|
||||||
|
// Default parameters
|
||||||
|
nlist := 128
|
||||||
|
if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil {
|
||||||
|
nlist = int(nlistVal)
|
||||||
|
}
|
||||||
|
index, err := entity.NewIndexIvfFlat(metricType, nlist)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create IVF_FLAT index: %w", err)
|
||||||
|
}
|
||||||
|
return index, nil
|
||||||
|
|
||||||
|
case "BIN_IVF_FLAT":
|
||||||
|
// Default parameters
|
||||||
|
nlist := 128
|
||||||
|
if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil {
|
||||||
|
nlist = int(nlistVal)
|
||||||
|
}
|
||||||
|
index, err := entity.NewIndexBinIvfFlat(metricType, nlist)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create BIN_IVF_FLAT index: %w", err)
|
||||||
|
}
|
||||||
|
return index, nil
|
||||||
|
|
||||||
|
case "IVF_SQ8":
|
||||||
|
// Default parameters
|
||||||
|
nlist := 128
|
||||||
|
if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil {
|
||||||
|
nlist = int(nlistVal)
|
||||||
|
}
|
||||||
|
index, err := entity.NewIndexIvfSQ8(metricType, nlist)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create IVF_SQ8 index: %w", err)
|
||||||
|
}
|
||||||
|
return index, nil
|
||||||
|
|
||||||
|
case "IVF_PQ":
|
||||||
|
// Default parameters
|
||||||
|
nlist := 128
|
||||||
|
m := 4
|
||||||
|
nbits := 8
|
||||||
|
|
||||||
|
if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil {
|
||||||
|
nlist = int(nlistVal)
|
||||||
|
}
|
||||||
|
if mVal, err := indexConfig.ParamsFloat64("m"); err == nil {
|
||||||
|
m = int(mVal)
|
||||||
|
}
|
||||||
|
if nbitsVal, err := indexConfig.ParamsInt64("nbits"); err == nil {
|
||||||
|
nbits = int(nbitsVal)
|
||||||
|
}
|
||||||
|
|
||||||
|
index, err := entity.NewIndexIvfPQ(metricType, nlist, m, nbits)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create IVF_PQ index: %w", err)
|
||||||
|
}
|
||||||
|
return index, nil
|
||||||
|
|
||||||
|
case "HNSW":
|
||||||
|
// Default parameters
|
||||||
|
m := 8
|
||||||
|
efConstruction := 64
|
||||||
|
if mVal, err := indexConfig.ParamsInt64("M"); err == nil {
|
||||||
|
m = int(mVal)
|
||||||
|
}
|
||||||
|
if efConstructionVal, err := indexConfig.ParamsInt64("efConstruction"); err == nil {
|
||||||
|
efConstruction = int(efConstructionVal)
|
||||||
|
}
|
||||||
|
index, err := entity.NewIndexHNSW(metricType, m, efConstruction)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create HNSW index: %w", err)
|
||||||
|
}
|
||||||
|
return index, nil
|
||||||
|
|
||||||
|
case "IVF_HNSW":
|
||||||
|
// Default parameters
|
||||||
|
nlist := 128
|
||||||
|
m := 8
|
||||||
|
efConstruction := 64
|
||||||
|
|
||||||
|
if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil {
|
||||||
|
nlist = int(nlistVal)
|
||||||
|
}
|
||||||
|
if mVal, err := indexConfig.ParamsInt64("M"); err == nil {
|
||||||
|
m = int(mVal)
|
||||||
|
}
|
||||||
|
|
||||||
|
if efConstructionVal, err := indexConfig.ParamsInt64("efConstruction"); err == nil {
|
||||||
|
efConstruction = int(efConstructionVal)
|
||||||
|
}
|
||||||
|
|
||||||
|
index, err := entity.NewIndexIvfHNSW(metricType, nlist, m, efConstruction)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create IVF_HNSW index: %w", err)
|
||||||
|
}
|
||||||
|
return index, nil
|
||||||
|
|
||||||
|
case "DISKANN":
|
||||||
|
// DISKANN index parameters
|
||||||
|
index, err := entity.NewIndexDISKANN(metricType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create DISKANN index: %w", err)
|
||||||
|
}
|
||||||
|
return index, nil
|
||||||
|
|
||||||
|
case "SCANN":
|
||||||
|
// SCANN index parameters
|
||||||
|
nlist := 128
|
||||||
|
with_raw_data := false
|
||||||
|
if nlistVal, err := indexConfig.ParamsInt64("nlist"); err == nil {
|
||||||
|
nlist = int(nlistVal)
|
||||||
|
}
|
||||||
|
if with_raw_dataVal, err := indexConfig.ParamsBool("with_raw_data"); err == nil {
|
||||||
|
with_raw_data = with_raw_dataVal
|
||||||
|
}
|
||||||
|
index, err := entity.NewIndexSCANN(metricType, nlist, with_raw_data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create SCANN index: %w", err)
|
||||||
|
}
|
||||||
|
return index, nil
|
||||||
|
|
||||||
|
case "AUTOINDEX":
|
||||||
|
// Auto index
|
||||||
|
index, err := entity.NewIndexAUTOINDEX(metricType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create AUTOINDEX index: %w", err)
|
||||||
|
}
|
||||||
|
return index, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported index type: %s", milvusIndexType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// CreateCollection creates a new collection with the specified dimension
|
// CreateCollection creates a new collection with the specified dimension
|
||||||
func (m *MilvusProvider) CreateCollection(ctx context.Context, dim int) error {
|
func (m *MilvusProvider) CreateCollection(ctx context.Context, dim int) error {
|
||||||
// Check if collection exists
|
// Check if collection exists
|
||||||
document_exists, err := m.client.HasCollection(ctx, m.Collection)
|
document_exists, err := m.client.HasCollection(ctx, m.collection)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to check %s collection existence: %w", m.Collection, err)
|
return fmt.Errorf("failed to check %s collection existence: %w", m.collection, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !document_exists {
|
if !document_exists {
|
||||||
fmt.Printf("create collection %s\n", m.Collection)
|
fmt.Printf("create collection %s\n", m.collection)
|
||||||
// Create schema
|
// Create schema
|
||||||
schema := entity.NewSchema().
|
schema, err := m.buildSchema()
|
||||||
WithName(m.Collection).
|
if err != nil {
|
||||||
WithDescription("Knowledge document collection").
|
return fmt.Errorf("failed to build schema: %w", err)
|
||||||
WithAutoID(false).
|
}
|
||||||
WithDynamicFieldEnabled(false)
|
|
||||||
|
|
||||||
// Add fields based on schema.Document structure
|
|
||||||
// Primary key field - ID
|
|
||||||
pkField := entity.NewField().
|
|
||||||
WithName("id").
|
|
||||||
WithDataType(entity.FieldTypeVarChar).
|
|
||||||
WithMaxLength(256).
|
|
||||||
WithIsPrimaryKey(true).
|
|
||||||
WithIsAutoID(false)
|
|
||||||
schema.WithField(pkField)
|
|
||||||
|
|
||||||
// Content field
|
|
||||||
contentField := entity.NewField().
|
|
||||||
WithName("content").
|
|
||||||
WithDataType(entity.FieldTypeVarChar).
|
|
||||||
WithMaxLength(8192)
|
|
||||||
schema.WithField(contentField)
|
|
||||||
|
|
||||||
// Vector field
|
|
||||||
vectorField := entity.NewField().
|
|
||||||
WithName("vector").
|
|
||||||
WithDataType(entity.FieldTypeFloatVector).
|
|
||||||
WithDim(int64(dim))
|
|
||||||
schema.WithField(vectorField)
|
|
||||||
|
|
||||||
// Metadata field
|
|
||||||
metadataField := entity.NewField().
|
|
||||||
WithName("metadata").
|
|
||||||
WithDataType(entity.FieldTypeJSON)
|
|
||||||
schema.WithField(metadataField)
|
|
||||||
|
|
||||||
// CreatedAt field (stored as Unix timestamp)
|
|
||||||
createdAtField := entity.NewField().
|
|
||||||
WithName("created_at").
|
|
||||||
WithDataType(entity.FieldTypeInt64)
|
|
||||||
schema.WithField(createdAtField)
|
|
||||||
|
|
||||||
// Create collection
|
// Create collection
|
||||||
err = m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber)
|
err = m.client.CreateCollection(ctx, schema, entity.DefaultShardNumber)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create collection: %w", err)
|
return fmt.Errorf("failed to create collection: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create vector index
|
// Create vector index
|
||||||
vectorIndex, err := entity.NewIndexHNSW(entity.IP, 8, 64)
|
vectorIndex, err := m.buildVectorIndex()
|
||||||
|
vectorField, _ := m.mapper.GetVectorField()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create vector index: %w", err)
|
return fmt.Errorf("failed to create vector index: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.client.CreateIndex(ctx, m.Collection, "vector", vectorIndex, false, client.WithIndexName("vector_index"))
|
err = m.client.CreateIndex(ctx, m.collection, vectorField.RawName, vectorIndex, false, client.WithIndexName("vector_index"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create vector index: %w", err)
|
return fmt.Errorf("failed to create vector index: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load collection
|
// Load collection
|
||||||
err = m.client.LoadCollection(ctx, m.Collection, false)
|
err = m.client.LoadCollection(ctx, m.collection, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to load document collection: %w", err)
|
return fmt.Errorf("failed to load document collection: %w", err)
|
||||||
}
|
}
|
||||||
@@ -197,15 +407,15 @@ func (m *MilvusProvider) CreateCollection(ctx context.Context, dim int) error {
|
|||||||
// DropCollection removes the collection from the database
|
// DropCollection removes the collection from the database
|
||||||
func (m *MilvusProvider) DropCollection(ctx context.Context) error {
|
func (m *MilvusProvider) DropCollection(ctx context.Context) error {
|
||||||
// Check if collection exists
|
// Check if collection exists
|
||||||
exists, err := m.client.HasCollection(ctx, m.Collection)
|
exists, err := m.client.HasCollection(ctx, m.collection)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to check %s collection existence: %w", m.Collection, err)
|
return fmt.Errorf("failed to check %s collection existence: %w", m.collection, err)
|
||||||
}
|
}
|
||||||
if !exists {
|
if !exists {
|
||||||
return fmt.Errorf("collection %s does not exist", m.Collection)
|
return fmt.Errorf("collection %s does not exist", m.collection)
|
||||||
}
|
}
|
||||||
// Drop collection
|
// Drop collection
|
||||||
err = m.client.DropCollection(ctx, m.Collection)
|
err = m.client.DropCollection(ctx, m.collection)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to drop collection: %w", err)
|
return fmt.Errorf("failed to drop collection: %w", err)
|
||||||
}
|
}
|
||||||
@@ -217,51 +427,71 @@ func (m *MilvusProvider) AddDoc(ctx context.Context, docs []schema.Document) err
|
|||||||
if len(docs) == 0 {
|
if len(docs) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// Prepare data
|
|
||||||
ids := make([]string, len(docs))
|
|
||||||
contents := make([]string, len(docs))
|
|
||||||
vectors := make([][]float32, len(docs))
|
|
||||||
metadatas := make([][]byte, len(docs))
|
|
||||||
createdAts := make([]int64, len(docs))
|
|
||||||
|
|
||||||
for i, doc := range docs {
|
// Get field mappings
|
||||||
ids[i] = doc.ID
|
fieldMappings, err := m.mapper.GetFieldMappings()
|
||||||
contents[i] = doc.Content
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get field mappings: %w", err)
|
||||||
// Convert vector type
|
|
||||||
vectorFloat32 := make([]float32, len(doc.Vector))
|
|
||||||
for j, v := range doc.Vector {
|
|
||||||
vectorFloat32[j] = float32(v)
|
|
||||||
}
|
|
||||||
vectors[i] = vectorFloat32
|
|
||||||
|
|
||||||
// Serialize metadata
|
|
||||||
metadataBytes, err := json.Marshal(doc.Metadata)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to marshal metadata for doc %s: %w", doc.ID, err)
|
|
||||||
}
|
|
||||||
metadatas[i] = metadataBytes
|
|
||||||
|
|
||||||
createdAts[i] = doc.CreatedAt.UnixMilli()
|
|
||||||
}
|
}
|
||||||
|
// Prepare data and columns
|
||||||
|
columns := make([]entity.Column, 0, len(fieldMappings))
|
||||||
|
// Create corresponding column data for each field
|
||||||
|
for _, field := range fieldMappings {
|
||||||
|
// Skip ID field if configured as auto ID
|
||||||
|
if field.IsPrimaryKey() && field.IsAutoID() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch field.StandardName {
|
||||||
|
case "id":
|
||||||
|
// Handle string type fields
|
||||||
|
values := make([]string, len(docs))
|
||||||
|
for i, doc := range docs {
|
||||||
|
values[i] = doc.ID
|
||||||
|
}
|
||||||
|
columns = append(columns, entity.NewColumnVarChar(field.RawName, values))
|
||||||
|
case "content":
|
||||||
|
values := make([]string, len(docs))
|
||||||
|
for i, doc := range docs {
|
||||||
|
values[i] = doc.Content
|
||||||
|
}
|
||||||
|
columns = append(columns, entity.NewColumnVarChar(field.RawName, values))
|
||||||
|
|
||||||
// Build insert data
|
case "vector":
|
||||||
columns := []entity.Column{
|
// Handle vector fields
|
||||||
entity.NewColumnVarChar("id", ids),
|
vectors := make([][]float32, len(docs))
|
||||||
entity.NewColumnVarChar("content", contents),
|
for i, doc := range docs {
|
||||||
entity.NewColumnFloatVector("vector", len(vectors[0]), vectors),
|
vectors[i] = doc.Vector
|
||||||
entity.NewColumnJSONBytes("metadata", metadatas),
|
}
|
||||||
entity.NewColumnInt64("created_at", createdAts),
|
columns = append(columns, entity.NewColumnFloatVector(field.RawName, len(vectors[0]), vectors))
|
||||||
|
case "metadata":
|
||||||
|
// Handle JSON type fields (like metadata)
|
||||||
|
values := make([][]byte, len(docs))
|
||||||
|
for i, doc := range docs {
|
||||||
|
// Serialize metadata
|
||||||
|
metadataBytes, err := json.Marshal(doc.Metadata)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal metadata for doc %s: %w", doc.ID, err)
|
||||||
|
}
|
||||||
|
values[i] = metadataBytes
|
||||||
|
}
|
||||||
|
columns = append(columns, entity.NewColumnJSONBytes(field.RawName, values))
|
||||||
|
case "created_at":
|
||||||
|
// Handle integer type fields
|
||||||
|
values := make([]int64, len(docs))
|
||||||
|
for i, doc := range docs {
|
||||||
|
values[i] = doc.CreatedAt.UnixMilli()
|
||||||
|
}
|
||||||
|
columns = append(columns, entity.NewColumnInt64(field.RawName, values))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert data
|
// Insert data
|
||||||
_, err := m.client.Insert(ctx, m.Collection, "", columns...)
|
_, err = m.client.Insert(ctx, m.collection, "", columns...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to insert documents: %w", err)
|
return fmt.Errorf("failed to insert documents: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush data
|
// Flush data
|
||||||
err = m.client.Flush(ctx, m.Collection, false)
|
err = m.client.Flush(ctx, m.collection, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to flush collection: %w", err)
|
return fmt.Errorf("failed to flush collection: %w", err)
|
||||||
}
|
}
|
||||||
@@ -271,16 +501,19 @@ func (m *MilvusProvider) AddDoc(ctx context.Context, docs []schema.Document) err
|
|||||||
|
|
||||||
// DeleteDoc deletes a document by its ID
|
// DeleteDoc deletes a document by its ID
|
||||||
func (m *MilvusProvider) DeleteDoc(ctx context.Context, id string) error {
|
func (m *MilvusProvider) DeleteDoc(ctx context.Context, id string) error {
|
||||||
// Build delete expression
|
// Get ID field
|
||||||
expr := fmt.Sprintf(`id == "%s"`, id)
|
idField, _ := m.mapper.GetIDField()
|
||||||
|
// Build delete expression using the RawName of ID field
|
||||||
|
expr := fmt.Sprintf(`%s == "%s"`, idField.RawName, id)
|
||||||
|
|
||||||
// Delete data
|
// Delete data
|
||||||
err := m.client.Delete(ctx, m.Collection, "", expr)
|
err := m.client.Delete(ctx, m.collection, "", expr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to delete documents for id %s: %w", id, err)
|
return fmt.Errorf("failed to delete documents for id %s: %w", id, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush data
|
// Flush data
|
||||||
err = m.client.Flush(ctx, m.Collection, false)
|
err = m.client.Flush(ctx, m.collection, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to flush collection after delete: %w", err)
|
return fmt.Errorf("failed to flush collection after delete: %w", err)
|
||||||
}
|
}
|
||||||
@@ -306,24 +539,127 @@ func (m *MilvusProvider) UpdateDoc(ctx context.Context, docs []schema.Document)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MilvusProvider) buildSearchParam() (entity.SearchParam, error) {
|
||||||
|
// Get index configuration
|
||||||
|
indexConfig, err := m.mapper.GetIndexConfig()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get index config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get search configuration
|
||||||
|
searchConfig, err := m.mapper.GetSearchConfig()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get search config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Choose appropriate search parameters based on index type
|
||||||
|
milvusIndexType := strings.ToUpper(indexConfig.IndexType)
|
||||||
|
if milvusIndexType == "" {
|
||||||
|
milvusIndexType = "HNSW" // Default to HNSW index
|
||||||
|
}
|
||||||
|
|
||||||
|
switch milvusIndexType {
|
||||||
|
case "FLAT":
|
||||||
|
// FLAT and BIN_FLAT indices don't need additional search parameters
|
||||||
|
return entity.NewIndexFlatSearchParam()
|
||||||
|
|
||||||
|
case "BIN_FLAT", "IVF_FLAT", "BIN_IVF_FLAT", "IVF_SQ8":
|
||||||
|
// Search parameters for IVF series indices
|
||||||
|
nprobe := 16 // Default value
|
||||||
|
if nprobeVal, err := searchConfig.ParamsFloat64("nprobe"); err == nil {
|
||||||
|
nprobe = int(nprobeVal)
|
||||||
|
}
|
||||||
|
return entity.NewIndexIvfFlatSearchParam(nprobe)
|
||||||
|
|
||||||
|
case "IVF_PQ":
|
||||||
|
// Search parameters for IVF_PQ index
|
||||||
|
nprobe := 16 // Default value
|
||||||
|
if nprobeVal, err := searchConfig.ParamsFloat64("nprobe"); err == nil {
|
||||||
|
nprobe = int(nprobeVal)
|
||||||
|
}
|
||||||
|
return entity.NewIndexIvfPQSearchParam(nprobe)
|
||||||
|
|
||||||
|
case "HNSW":
|
||||||
|
// Search parameters for HNSW index
|
||||||
|
efSearch := 16 // Default value
|
||||||
|
if efSearchVal, err := searchConfig.ParamsFloat64("ef"); err == nil {
|
||||||
|
efSearch = int(efSearchVal)
|
||||||
|
}
|
||||||
|
return entity.NewIndexHNSWSearchParam(efSearch)
|
||||||
|
|
||||||
|
case "IVF_HNSW":
|
||||||
|
// Search parameters for IVF_HNSW index
|
||||||
|
nprobe := 16 // Default value
|
||||||
|
efSearch := 64 // Default value
|
||||||
|
if nprobeVal, err := searchConfig.ParamsFloat64("nprobe"); err == nil {
|
||||||
|
nprobe = int(nprobeVal)
|
||||||
|
}
|
||||||
|
if efSearchVal, err := searchConfig.ParamsFloat64("ef"); err == nil {
|
||||||
|
efSearch = int(efSearchVal)
|
||||||
|
}
|
||||||
|
return entity.NewIndexIvfHNSWSearchParam(nprobe, efSearch)
|
||||||
|
|
||||||
|
case "SCANN":
|
||||||
|
// Search parameters for SCANN index
|
||||||
|
nprobe := 16 // Default value
|
||||||
|
reorder_k := 64
|
||||||
|
if nprobeVal, err := searchConfig.ParamsFloat64("nprobe"); err == nil {
|
||||||
|
nprobe = int(nprobeVal)
|
||||||
|
}
|
||||||
|
if reorderKVal, err := searchConfig.ParamsInt64("reorder_k"); err == nil {
|
||||||
|
reorder_k = int(reorderKVal)
|
||||||
|
}
|
||||||
|
return entity.NewIndexSCANNSearchParam(nprobe, reorder_k)
|
||||||
|
|
||||||
|
case "DISKANN":
|
||||||
|
// Search parameters for DISKANN index
|
||||||
|
search_list := 100 // Default value
|
||||||
|
if searchListVal, err := searchConfig.ParamsInt64("search_list"); err == nil {
|
||||||
|
search_list = int(searchListVal)
|
||||||
|
}
|
||||||
|
return entity.NewIndexDISKANNSearchParam(search_list)
|
||||||
|
|
||||||
|
case "AUTOINDEX":
|
||||||
|
level := 8
|
||||||
|
if levelVal, err := searchConfig.ParamsInt64("level"); err == nil {
|
||||||
|
level = int(levelVal)
|
||||||
|
}
|
||||||
|
// Search parameters for AUTOINDEX index
|
||||||
|
return entity.NewIndexAUTOINDEXSearchParam(level)
|
||||||
|
default:
|
||||||
|
// Default to using HNSW search parameters
|
||||||
|
return entity.NewIndexHNSWSearchParam(16)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// SearchDocs performs similarity search for documents
|
// SearchDocs performs similarity search for documents
|
||||||
func (m *MilvusProvider) SearchDocs(ctx context.Context, vector []float32, options *schema.SearchOptions) ([]schema.SearchResult, error) {
|
func (m *MilvusProvider) SearchDocs(ctx context.Context, vector []float32, options *schema.SearchOptions) ([]schema.SearchResult, error) {
|
||||||
if options == nil {
|
if options == nil {
|
||||||
options = &schema.SearchOptions{TopK: 10}
|
options = &schema.SearchOptions{TopK: 10}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build search parameters
|
// Build search parameters
|
||||||
sp, _ := entity.NewIndexHNSWSearchParam(16)
|
sp, err := m.buildSearchParam()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to build search param: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
outputFields, _ := m.mapper.GetRawAllFieldNames()
|
||||||
|
vectorField, _ := m.mapper.GetVectorField()
|
||||||
|
searchConfig, _ := m.mapper.GetSearchConfig()
|
||||||
|
metricType := m.GetMetricType(searchConfig.MetricType)
|
||||||
|
|
||||||
// Build filter expression
|
// Build filter expression
|
||||||
expr := ""
|
expr := ""
|
||||||
searchResults, err := m.client.Search(
|
searchResults, err := m.client.Search(
|
||||||
ctx,
|
ctx,
|
||||||
m.Collection,
|
m.collection,
|
||||||
[]string{}, // partition names
|
[]string{}, // partition names
|
||||||
expr, // filter expression
|
expr, // filter expression
|
||||||
[]string{"id", "content", "metadata", "created_at"}, // output fields
|
outputFields, // output fields
|
||||||
[]entity.Vector{entity.FloatVector(vector)},
|
[]entity.Vector{entity.FloatVector(vector)},
|
||||||
"vector", // anns_field
|
vectorField.RawName, // anns_field
|
||||||
entity.IP, // metric_type
|
metricType, // metric_type
|
||||||
options.TopK,
|
options.TopK,
|
||||||
sp,
|
sp,
|
||||||
)
|
)
|
||||||
@@ -341,9 +677,13 @@ func (m *MilvusProvider) SearchDocs(ctx context.Context, vector []float32, optio
|
|||||||
// Get field data
|
// Get field data
|
||||||
var content string
|
var content string
|
||||||
var metadata map[string]interface{}
|
var metadata map[string]interface{}
|
||||||
|
|
||||||
for _, field := range result.Fields {
|
for _, field := range result.Fields {
|
||||||
switch field.Name() {
|
fieldMapping, err := m.mapper.GetField(field.Name())
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fieldName := strings.ToLower(fieldMapping.StandardName)
|
||||||
|
switch fieldName {
|
||||||
case "content":
|
case "content":
|
||||||
if contentCol, ok := field.(*entity.ColumnVarChar); ok {
|
if contentCol, ok := field.(*entity.ColumnVarChar); ok {
|
||||||
if contentVal, err := contentCol.Get(i); err == nil {
|
if contentVal, err := contentCol.Get(i); err == nil {
|
||||||
@@ -364,7 +704,6 @@ func (m *MilvusProvider) SearchDocs(ctx context.Context, vector []float32, optio
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
searchResult := schema.SearchResult{
|
searchResult := schema.SearchResult{
|
||||||
Document: schema.Document{
|
Document: schema.Document{
|
||||||
ID: fmt.Sprintf("%s", id),
|
ID: fmt.Sprintf("%s", id),
|
||||||
@@ -392,15 +731,17 @@ func (m *MilvusProvider) DeleteDocs(ctx context.Context, ids []string) error {
|
|||||||
for i, id := range ids {
|
for i, id := range ids {
|
||||||
quotedIDs[i] = fmt.Sprintf("\"%s\"", id)
|
quotedIDs[i] = fmt.Sprintf("\"%s\"", id)
|
||||||
}
|
}
|
||||||
expr := fmt.Sprintf("id in [%s]", strings.Join(quotedIDs, ","))
|
|
||||||
|
idField, _ := m.mapper.GetIDField()
|
||||||
|
expr := fmt.Sprintf("%s in [%s]", idField.RawName, strings.Join(quotedIDs, ","))
|
||||||
|
|
||||||
// Delete data
|
// Delete data
|
||||||
err := m.client.Delete(ctx, m.Collection, "", expr)
|
err := m.client.Delete(ctx, m.collection, "", expr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to delete documents: %w", err)
|
return fmt.Errorf("failed to delete documents: %w", err)
|
||||||
}
|
}
|
||||||
// Flush data
|
// Flush data
|
||||||
err = m.client.Flush(ctx, m.Collection, false)
|
err = m.client.Flush(ctx, m.collection, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to flush collection after delete: %w", err)
|
return fmt.Errorf("failed to flush collection after delete: %w", err)
|
||||||
}
|
}
|
||||||
@@ -413,12 +754,13 @@ func (m *MilvusProvider) ListDocs(ctx context.Context, limit int) ([]schema.Docu
|
|||||||
// Build query expression
|
// Build query expression
|
||||||
expr := ""
|
expr := ""
|
||||||
// Query all relevant documents
|
// Query all relevant documents
|
||||||
|
outputFields, _ := m.mapper.GetRawAllFieldNames()
|
||||||
queryResult, err := m.client.Query(
|
queryResult, err := m.client.Query(
|
||||||
ctx,
|
ctx,
|
||||||
m.Collection,
|
m.collection,
|
||||||
[]string{}, // partitions
|
[]string{}, // partitions
|
||||||
expr, // filter condition
|
expr, // filter condition
|
||||||
[]string{"id", "content", "metadata", "created_at"},
|
outputFields,
|
||||||
client.WithOffset(0), client.WithLimit(int64(limit)),
|
client.WithOffset(0), client.WithLimit(int64(limit)),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -443,7 +785,12 @@ func (m *MilvusProvider) ListDocs(ctx context.Context, limit int) ([]schema.Docu
|
|||||||
)
|
)
|
||||||
|
|
||||||
for _, col := range queryResult {
|
for _, col := range queryResult {
|
||||||
switch col.Name() {
|
fieldMapping, err := m.mapper.GetField(col.Name())
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fieldName := strings.ToLower(fieldMapping.StandardName)
|
||||||
|
switch fieldName {
|
||||||
case "id":
|
case "id":
|
||||||
if v, err := col.(*entity.ColumnVarChar).Get(i); err == nil {
|
if v, err := col.(*entity.ColumnVarChar).Get(i); err == nil {
|
||||||
id = v.(string)
|
id = v.(string)
|
||||||
@@ -488,8 +835,3 @@ func (m *MilvusProvider) Close() error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// joinStrings joins a slice of strings with the given separator
|
|
||||||
func joinStrings(elems []string, sep string) string {
|
|
||||||
return strings.Join(elems, sep)
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user