mirror of
https://github.com/alibaba/higress.git
synced 2026-02-27 22:20:57 +08:00
Compare commits
91 Commits
cr7258-pat
...
v2.1.5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
36bcb595d6 | ||
|
|
783a8db512 | ||
|
|
44566f5259 | ||
|
|
73ba9238bd | ||
|
|
41a1455874 | ||
|
|
9d68ccbf35 | ||
|
|
db7dbb24a2 | ||
|
|
9a0cf9b762 | ||
|
|
bb786c9618 | ||
|
|
ef49d2f5f6 | ||
|
|
864bf5af39 | ||
|
|
527e922d50 | ||
|
|
1fe5eb6e13 | ||
|
|
87185baff2 | ||
|
|
76ada0b844 | ||
|
|
f4d3fec228 | ||
|
|
e94ac43dd1 | ||
|
|
dd29267fd7 | ||
|
|
01a9161153 | ||
|
|
ceb8b557dc | ||
|
|
753022e093 | ||
|
|
04cbbfc7e8 | ||
|
|
db66df39c4 | ||
|
|
dad6278a6d | ||
|
|
272d693df3 | ||
|
|
69bc800198 | ||
|
|
1daaa4b880 | ||
|
|
6e31a7b67c | ||
|
|
91f070906a | ||
|
|
e3aeddcc24 | ||
|
|
926913f0e7 | ||
|
|
c471bb2003 | ||
|
|
0b9256617e | ||
|
|
2670ecbf8e | ||
|
|
7040e4bd34 | ||
|
|
de8a4d0b03 | ||
|
|
b33a3a4d2e | ||
|
|
087cb48fc5 | ||
|
|
95f32002d2 | ||
|
|
fb8dd819e9 | ||
|
|
86934b3203 | ||
|
|
38068ee43d | ||
|
|
d81573e0d2 | ||
|
|
312b80f91d | ||
|
|
e42e6eeee6 | ||
|
|
9f5067d22f | ||
|
|
6af9587372 | ||
|
|
5812c1e734 | ||
|
|
bafbe7972d | ||
|
|
f3fbf7d6c8 | ||
|
|
1666dfb01c | ||
|
|
d2f09fe8c5 | ||
|
|
69d877c116 | ||
|
|
5bc0058779 | ||
|
|
d4e114b152 | ||
|
|
e674c780c6 | ||
|
|
26cd6837d5 | ||
|
|
5674d91a10 | ||
|
|
c78b4aaba3 | ||
|
|
0e4e8da9c1 | ||
|
|
c9ec8a12bb | ||
|
|
7484bcea62 | ||
|
|
896780b60e | ||
|
|
7b1ae49cd4 | ||
|
|
ee26baf054 | ||
|
|
33fc47cefb | ||
|
|
19946d46ca | ||
|
|
52d0212698 | ||
|
|
a73c33f1da | ||
|
|
69b755a10d | ||
|
|
52464c0e06 | ||
|
|
d7d5d1c571 | ||
|
|
ea948ee818 | ||
|
|
767f51adce | ||
|
|
168cb04c61 | ||
|
|
323aabf72b | ||
|
|
b8d75598ed | ||
|
|
b37649a62f | ||
|
|
76f76a70ab | ||
|
|
647c961f51 | ||
|
|
5a5a72a9f8 | ||
|
|
ffcf5df28a | ||
|
|
ec83623614 | ||
|
|
bf5be07d74 | ||
|
|
f6bb5d7729 | ||
|
|
031ae21caa | ||
|
|
fa3c5ea0fc | ||
|
|
93436db13c | ||
|
|
be2c6f8a4a | ||
|
|
c768973e47 | ||
|
|
8ec65ed377 |
123
.github/workflows/helm-docs.yaml
vendored
123
.github/workflows/helm-docs.yaml
vendored
@@ -39,126 +39,3 @@ jobs:
|
||||
fi
|
||||
git diff --exit-code
|
||||
rm -f ./helm-docs
|
||||
|
||||
translate-readme:
|
||||
needs: helm
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y jq
|
||||
|
||||
- name: Compare README.md
|
||||
id: compare_readme
|
||||
run: |
|
||||
cd ./helm/higress
|
||||
|
||||
BASE_BRANCH=${GITHUB_BASE_REF:-main}
|
||||
git fetch origin $BASE_BRANCH
|
||||
|
||||
if git diff --quiet origin/$BASE_BRANCH -- README.md; then
|
||||
echo "README.md has no local changes compared to $BASE_BRANCH. Skipping translation."
|
||||
echo "skip_translation=true" >> $GITHUB_ENV
|
||||
else
|
||||
echo "README.md has local changes compared to $BASE_BRANCH. Proceeding with translation."
|
||||
echo "skip_translation=false" >> $GITHUB_ENV
|
||||
echo "--------- diff ---------"
|
||||
git diff origin/$BASE_BRANCH -- README.md
|
||||
echo "------------------------"
|
||||
fi
|
||||
|
||||
- name: Translate README.md to Chinese
|
||||
if: env.skip_translation == 'false'
|
||||
env:
|
||||
API_URL: ${{ secrets.HIGRESS_OPENAI_API_URL }}
|
||||
API_KEY: ${{ secrets.HIGRESS_OPENAI_API_KEY }}
|
||||
API_MODEL: ${{ secrets.HIGRESS_OPENAI_API_MODEL }}
|
||||
run: |
|
||||
cat << 'EOF' > translate_readme.py
|
||||
import os
|
||||
import json
|
||||
import requests
|
||||
|
||||
API_URL = os.environ["API_URL"]
|
||||
API_KEY = os.environ["API_KEY"]
|
||||
API_MODEL = os.environ["API_MODEL"]
|
||||
README_PATH = "./helm/higress/README.md"
|
||||
OUTPUT_PATH = "./helm/higress/README.zh.md"
|
||||
|
||||
def stream_translation(api_url, api_key, payload):
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
}
|
||||
response = requests.post(api_url, headers=headers, json=payload, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(OUTPUT_PATH, "w", encoding="utf-8") as out_file:
|
||||
for line in response.iter_lines(decode_unicode=True):
|
||||
if line.strip() == "" or not line.startswith("data: "):
|
||||
continue
|
||||
data = line[6:]
|
||||
if data.strip() == "[DONE]":
|
||||
break
|
||||
try:
|
||||
chunk = json.loads(data)
|
||||
content = chunk["choices"][0]["delta"].get("content", "")
|
||||
if content:
|
||||
out_file.write(content)
|
||||
except Exception as e:
|
||||
print("Error parsing chunk:", e)
|
||||
|
||||
def main():
|
||||
if not os.path.exists(README_PATH):
|
||||
print("README.md not found!")
|
||||
return
|
||||
|
||||
with open(README_PATH, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
payload = {
|
||||
"model": API_MODEL,
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a translation assistant that translates English Markdown text to Chinese. Preserve original Markdown formatting and line breaks."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": content
|
||||
}
|
||||
],
|
||||
"temperature": 0.3,
|
||||
"stream": True
|
||||
}
|
||||
|
||||
print("Streaming translation started...")
|
||||
stream_translation(API_URL, API_KEY, payload)
|
||||
print(f"Translation completed and saved to {OUTPUT_PATH}.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
EOF
|
||||
|
||||
python3 translate_readme.py
|
||||
rm -rf translate_readme.py
|
||||
|
||||
- name: Create Pull Request
|
||||
if: env.skip_translation == 'false'
|
||||
uses: peter-evans/create-pull-request@v7
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
commit-message: "Update helm translated README.zh.md"
|
||||
branch: update-helm-readme-zh
|
||||
title: "Update helm translated README.zh.md"
|
||||
body: |
|
||||
This PR updates the translated README.zh.md file.
|
||||
|
||||
- Automatically generated by GitHub Actions
|
||||
labels: translation, automated
|
||||
base: main
|
||||
2
.github/workflows/release-crd.yaml
vendored
2
.github/workflows/release-crd.yaml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
cat helm/core/crds/customresourcedefinitions.gen.yaml helm/core/crds/istio-envoyfilter.yaml > crd.yaml
|
||||
|
||||
- name: Upload hgctl packages to the GitHub release
|
||||
uses: softprops/action-gh-release@v2
|
||||
uses: softprops/action-gh-release@da05d552573ad5aba039eaac05058a918a7bf631
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: |
|
||||
|
||||
6
.github/workflows/release-hgctl.yaml
vendored
6
.github/workflows/release-hgctl.yaml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
zip -q -r hgctl_${{ env.HGCTL_VERSION }}_windows_arm64.zip out/windows_arm64/
|
||||
|
||||
- name: Upload hgctl packages to the GitHub release
|
||||
uses: softprops/action-gh-release@v2
|
||||
uses: softprops/action-gh-release@da05d552573ad5aba039eaac05058a918a7bf631
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: |
|
||||
@@ -51,7 +51,7 @@ jobs:
|
||||
tar -zcvf hgctl_${{ env.HGCTL_VERSION }}_darwin_arm64.tar.gz out/darwin_arm64/
|
||||
|
||||
- name: Upload hgctl packages to the GitHub release
|
||||
uses: softprops/action-gh-release@v2
|
||||
uses: softprops/action-gh-release@da05d552573ad5aba039eaac05058a918a7bf631
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: |
|
||||
@@ -73,7 +73,7 @@ jobs:
|
||||
tar -zcvf hgctl_${{ env.HGCTL_VERSION }}_darwin_amd64.tar.gz out/darwin_amd64/
|
||||
|
||||
- name: Upload hgctl packages to the GitHub release
|
||||
uses: softprops/action-gh-release@v2
|
||||
uses: softprops/action-gh-release@da05d552573ad5aba039eaac05058a918a7bf631
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
with:
|
||||
files: |
|
||||
|
||||
36
.github/workflows/sync-crds.yaml
vendored
Normal file
36
.github/workflows/sync-crds.yaml
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
name: "Sync CRDs to Helm Chart"
|
||||
|
||||
on:
|
||||
workflow_dispatch: ~
|
||||
push:
|
||||
branches: [ main ]
|
||||
paths:
|
||||
- 'api/kubernetes/customresourcedefinitions.gen.yaml'
|
||||
|
||||
jobs:
|
||||
sync-crds:
|
||||
name: Sync CRDs
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
- name: Copy the CRD YAML File to Helm Folder
|
||||
run: |
|
||||
cp api/kubernetes/customresourcedefinitions.gen.yaml helm/core/crds/customresourcedefinitions.gen.yaml
|
||||
|
||||
- name: Create Pull Request
|
||||
uses: peter-evans/create-pull-request@v7
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
commit-message: "Update CRD file in the helm folder"
|
||||
branch: sync-crds
|
||||
title: "Update CRD file in the helm folder"
|
||||
body: |
|
||||
This PR updates CRD file in the helm folder.
|
||||
|
||||
- Automatically copied by GitHub Actions
|
||||
labels: crds, automated
|
||||
base: main
|
||||
131
.github/workflows/translate-readme.yaml
vendored
Normal file
131
.github/workflows/translate-readme.yaml
vendored
Normal file
@@ -0,0 +1,131 @@
|
||||
name: "Helm Docs"
|
||||
|
||||
on:
|
||||
workflow_dispatch: ~
|
||||
push:
|
||||
branches: [ main ]
|
||||
paths:
|
||||
- 'helm/higress/README.md'
|
||||
|
||||
jobs:
|
||||
translate-readme:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y jq
|
||||
|
||||
- name: Compare README.md
|
||||
id: compare_readme
|
||||
run: |
|
||||
cd ./helm/higress
|
||||
|
||||
BASE_BRANCH=${GITHUB_BASE_REF:-main}
|
||||
git fetch origin $BASE_BRANCH
|
||||
|
||||
if git diff --quiet origin/$BASE_BRANCH -- README.md; then
|
||||
echo "README.md has no local changes compared to $BASE_BRANCH. Skipping translation."
|
||||
echo "skip_translation=true" >> $GITHUB_ENV
|
||||
else
|
||||
echo "README.md has local changes compared to $BASE_BRANCH. Proceeding with translation."
|
||||
echo "skip_translation=false" >> $GITHUB_ENV
|
||||
echo "--------- diff ---------"
|
||||
git diff origin/$BASE_BRANCH -- README.md
|
||||
echo "------------------------"
|
||||
fi
|
||||
|
||||
- name: Translate README.md to Chinese
|
||||
if: env.skip_translation == 'false'
|
||||
env:
|
||||
API_URL: ${{ secrets.HIGRESS_OPENAI_API_URL }}
|
||||
API_KEY: ${{ secrets.HIGRESS_OPENAI_API_KEY }}
|
||||
API_MODEL: ${{ secrets.HIGRESS_OPENAI_API_MODEL }}
|
||||
run: |
|
||||
cat << 'EOF' > translate_readme.py
|
||||
import os
|
||||
import json
|
||||
import requests
|
||||
|
||||
API_URL = os.environ["API_URL"]
|
||||
API_KEY = os.environ["API_KEY"]
|
||||
API_MODEL = os.environ["API_MODEL"]
|
||||
README_PATH = "./helm/higress/README.md"
|
||||
OUTPUT_PATH = "./helm/higress/README.zh.md"
|
||||
|
||||
def stream_translation(api_url, api_key, payload):
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
}
|
||||
response = requests.post(api_url, headers=headers, json=payload, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(OUTPUT_PATH, "w", encoding="utf-8") as out_file:
|
||||
for line in response.iter_lines(decode_unicode=True):
|
||||
if line.strip() == "" or not line.startswith("data: "):
|
||||
continue
|
||||
data = line[6:]
|
||||
if data.strip() == "[DONE]":
|
||||
break
|
||||
try:
|
||||
chunk = json.loads(data)
|
||||
content = chunk["choices"][0]["delta"].get("content", "")
|
||||
if content:
|
||||
out_file.write(content)
|
||||
except Exception as e:
|
||||
print("Error parsing chunk:", e)
|
||||
|
||||
def main():
|
||||
if not os.path.exists(README_PATH):
|
||||
print("README.md not found!")
|
||||
return
|
||||
|
||||
with open(README_PATH, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
payload = {
|
||||
"model": API_MODEL,
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a translation assistant that translates English Markdown text to Chinese. Preserve original Markdown formatting and line breaks."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": content
|
||||
}
|
||||
],
|
||||
"temperature": 0.3,
|
||||
"stream": True
|
||||
}
|
||||
|
||||
print("Streaming translation started...")
|
||||
stream_translation(API_URL, API_KEY, payload)
|
||||
print(f"Translation completed and saved to {OUTPUT_PATH}.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
EOF
|
||||
|
||||
python3 translate_readme.py
|
||||
rm -rf translate_readme.py
|
||||
|
||||
- name: Create Pull Request
|
||||
if: env.skip_translation == 'false'
|
||||
uses: peter-evans/create-pull-request@v7
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
commit-message: "Update helm translated README.zh.md"
|
||||
branch: update-helm-readme-zh
|
||||
title: "Update helm translated README.zh.md"
|
||||
body: |
|
||||
This PR updates the translated README.zh.md file.
|
||||
|
||||
- Automatically generated by GitHub Actions
|
||||
labels: translation, automated
|
||||
base: main
|
||||
@@ -33,6 +33,7 @@ header:
|
||||
- 'hgctl/cmd/hgctl/config/testdata/config'
|
||||
- 'hgctl/pkg/manifests'
|
||||
- 'pkg/ingress/kube/gateway/istio/testdata'
|
||||
- 'release-notes/**'
|
||||
|
||||
comment: on-failure
|
||||
dependency:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
/istio @SpecialYang @johnlanni
|
||||
/pkg @SpecialYang @johnlanni @CH3CHO
|
||||
/plugins @johnlanni @CH3CHO @rinfx
|
||||
/plugins/wasm-go/extensions/ai-proxy @cr7258 @CH3CHO @rinfx
|
||||
/plugins/wasm-go/extensions/ai-proxy @cr7258 @CH3CHO @rinfx @wydream
|
||||
/plugins/wasm-rust @007gzs @jizhuozhi
|
||||
/registry @NameHaibinZhang @2456868764 @johnlanni
|
||||
/test @Xunzhuo @2456868764 @CH3CHO
|
||||
|
||||
@@ -144,7 +144,7 @@ docker-buildx-push: clean-env docker.higress-buildx
|
||||
export PARENT_GIT_TAG:=$(shell cat VERSION)
|
||||
export PARENT_GIT_REVISION:=$(TAG)
|
||||
|
||||
export ENVOY_PACKAGE_URL_PATTERN?=https://github.com/higress-group/proxy/releases/download/v2.1.5/envoy-symbol-ARCH.tar.gz
|
||||
export ENVOY_PACKAGE_URL_PATTERN?=https://github.com/higress-group/proxy/releases/download/v2.1.7/envoy-symbol-ARCH.tar.gz
|
||||
|
||||
build-envoy: prebuild
|
||||
./tools/hack/build-envoy.sh
|
||||
@@ -191,8 +191,9 @@ install: pre-install
|
||||
cd helm/higress; helm dependency build
|
||||
helm install higress helm/higress -n higress-system --create-namespace --set 'global.local=true'
|
||||
|
||||
ENVOY_LATEST_IMAGE_TAG ?= 958467a353d411ae3f06e03b096bfd342cddb2c6
|
||||
ISTIO_LATEST_IMAGE_TAG ?= d9c728d3b01f64855e012b08d136e306f1160397
|
||||
HIGRESS_LATEST_IMAGE_TAG ?= latest
|
||||
ENVOY_LATEST_IMAGE_TAG ?= latest
|
||||
ISTIO_LATEST_IMAGE_TAG ?= latest
|
||||
|
||||
install-dev: pre-install
|
||||
helm install higress helm/core -n higress-system --create-namespace --set 'controller.tag=$(TAG)' --set 'gateway.replicas=1' --set 'pilot.tag=$(ISTIO_LATEST_IMAGE_TAG)' --set 'gateway.tag=$(ENVOY_LATEST_IMAGE_TAG)' --set 'global.local=true'
|
||||
@@ -268,10 +269,26 @@ higress-conformance-test-clean: $(tools/kind) delete-cluster
|
||||
.PHONY: higress-wasmplugin-test-prepare
|
||||
higress-wasmplugin-test-prepare: $(tools/kind) delete-cluster create-cluster docker-build kube-load-image install-dev-wasmplugin
|
||||
|
||||
# higress-wasmplugin-test-prepare-skip-docker-build prepares the environment for higress wasmplugin tests without build higress docker image.
|
||||
.PHONY: higress-wasmplugin-test-prepare-skip-docker-build
|
||||
higress-wasmplugin-test-prepare-skip-docker-build: $(tools/kind) delete-cluster create-cluster prebuild
|
||||
@export TAG="$(HIGRESS_LATEST_IMAGE_TAG)" && \
|
||||
$(MAKE) kube-load-image && \
|
||||
$(MAKE) install-dev-wasmplugin
|
||||
|
||||
# higress-wasmplugin-test runs ingress wasmplugin tests.
|
||||
.PHONY: higress-wasmplugin-test
|
||||
higress-wasmplugin-test: $(tools/kind) delete-cluster create-cluster docker-build kube-load-image install-dev-wasmplugin run-higress-e2e-test-wasmplugin delete-cluster
|
||||
|
||||
# higress-wasmplugin-test-skip-docker-build runs ingress wasmplugin tests without build higress docker image
|
||||
.PHONY: higress-wasmplugin-test-skip-docker-build
|
||||
higress-wasmplugin-test-skip-docker-build: $(tools/kind) delete-cluster create-cluster prebuild
|
||||
@export TAG="$(HIGRESS_LATEST_IMAGE_TAG)" && \
|
||||
$(MAKE) kube-load-image && \
|
||||
$(MAKE) install-dev-wasmplugin && \
|
||||
$(MAKE) run-higress-e2e-test-wasmplugin && \
|
||||
$(MAKE) delete-cluster
|
||||
|
||||
# higress-wasmplugin-test-clean cleans the environment for higress wasmplugin tests.
|
||||
.PHONY: higress-wasmplugin-test-clean
|
||||
higress-wasmplugin-test-clean: $(tools/kind) delete-cluster
|
||||
@@ -290,8 +307,12 @@ delete-cluster: $(tools/kind) ## Delete kind cluster.
|
||||
# dubbo-provider-demo和nacos-standlone-rc3的镜像已经上传到阿里云镜像库,第一次需要先拉到本地
|
||||
# docker pull registry.cn-hangzhou.aliyuncs.com/hinsteny/dubbo-provider-demo:0.0.1
|
||||
# docker pull registry.cn-hangzhou.aliyuncs.com/hinsteny/nacos-standlone-rc3:1.0.0-RC3
|
||||
# If TAG is HIGRESS_LATEST_IMAGE_TAG, means we skip building higress docker image, so we need to pull the image first.
|
||||
.PHONY: kube-load-image
|
||||
kube-load-image: $(tools/kind) ## Install the Higress image to a kind cluster using the provided $IMAGE and $TAG.
|
||||
@if [ "$(TAG)" = "$(HIGRESS_LATEST_IMAGE_TAG)" ]; then \
|
||||
tools/hack/docker-pull-image.sh higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/higress $(TAG); \
|
||||
fi
|
||||
tools/hack/kind-load-image.sh higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/higress $(TAG)
|
||||
tools/hack/docker-pull-image.sh higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/pilot $(ISTIO_LATEST_IMAGE_TAG)
|
||||
tools/hack/docker-pull-image.sh higress-registry.cn-hangzhou.cr.aliyuncs.com/higress/gateway $(ENVOY_LATEST_IMAGE_TAG)
|
||||
|
||||
10
README.md
10
README.md
@@ -10,7 +10,7 @@
|
||||
|
||||
[](https://github.com/alibaba/higress/actions)
|
||||
[](https://www.apache.org/licenses/LICENSE-2.0.html)
|
||||
[](https://discord.gg/reymxYM5)
|
||||
[](https://discord.gg/tSbww9VDaM)
|
||||
|
||||
<a href="https://trendshift.io/repositories/10918" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10918" alt="alibaba%2Fhigress | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a> <a href="https://www.producthunt.com/posts/higress?embed=true&utm_source=badge-featured&utm_medium=badge&utm_souce=badge-higress" target="_blank"><img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?post_id=951287&theme=light&t=1745492822283" alt="Higress - Global APIs as MCP powered by AI Gateway | Product Hunt" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
English | <a href="README_ZH.md">中文<a/> | <a href="README_JP.md">日本語<a/>
|
||||
</p>
|
||||
|
||||
## Test What is Higress?
|
||||
## What is Higress?
|
||||
|
||||
Higress is a cloud-native API gateway based on Istio and Envoy, which can be extended with Wasm plugins written in Go/Rust/JS. It provides dozens of ready-to-use general-purpose plugins and an out-of-the-box console (try the [demo here](http://demo.higress.io/)).
|
||||
|
||||
@@ -69,6 +69,10 @@ Port descriptions:
|
||||
|
||||
> All Higress Docker images use Higress's own image repository and are not affected by Docker Hub rate limits.
|
||||
> In addition, the submission and updates of the images are protected by a security scanning mechanism (powered by Alibaba Cloud ACR), making them very secure for use in production environments.
|
||||
>
|
||||
> If you experience a timeout when pulling image from `higress-registry.cn-hangzhou.cr.aliyuncs.com`, you can try replacing it with the following docker registry mirror source:
|
||||
>
|
||||
> **Southeast Asia**: `higress-registry.ap-southeast-7.cr.aliyuncs.com`
|
||||
|
||||
For other installation methods such as Helm deployment under K8s, please refer to the official [Quick Start documentation](https://higress.io/en-us/docs/user/quickstart).
|
||||
|
||||
@@ -143,7 +147,7 @@ For other installation methods such as Helm deployment under K8s, please refer t
|
||||
|
||||
Join our Discord community! This is where you can connect with developers and other enthusiastic users of Higress.
|
||||
|
||||
[](https://discord.gg/reymxYM5)
|
||||
[](https://discord.gg/tSbww9VDaM)
|
||||
|
||||
|
||||
### Thanks
|
||||
|
||||
@@ -250,6 +250,10 @@ spec:
|
||||
registries:
|
||||
items:
|
||||
properties:
|
||||
allowMcpServers:
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
authSecretName:
|
||||
type: string
|
||||
consulDatacenter:
|
||||
@@ -265,12 +269,23 @@ spec:
|
||||
type: string
|
||||
enableMCPServer:
|
||||
type: boolean
|
||||
enableScopeMcpServers:
|
||||
type: boolean
|
||||
mcpServerBaseUrl:
|
||||
type: string
|
||||
mcpServerExportDomains:
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
metadata:
|
||||
additionalProperties:
|
||||
properties:
|
||||
innerMap:
|
||||
additionalProperties:
|
||||
type: string
|
||||
type: object
|
||||
type: object
|
||||
type: object
|
||||
nacosAccessKey:
|
||||
type: string
|
||||
nacosAddressServer:
|
||||
|
||||
@@ -111,28 +111,31 @@ type RegistryConfig struct {
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Type string `protobuf:"bytes,1,opt,name=type,proto3" json:"type,omitempty"`
|
||||
Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"`
|
||||
Domain string `protobuf:"bytes,3,opt,name=domain,proto3" json:"domain,omitempty"`
|
||||
Port uint32 `protobuf:"varint,4,opt,name=port,proto3" json:"port,omitempty"`
|
||||
NacosAddressServer string `protobuf:"bytes,5,opt,name=nacosAddressServer,proto3" json:"nacosAddressServer,omitempty"`
|
||||
NacosAccessKey string `protobuf:"bytes,6,opt,name=nacosAccessKey,proto3" json:"nacosAccessKey,omitempty"`
|
||||
NacosSecretKey string `protobuf:"bytes,7,opt,name=nacosSecretKey,proto3" json:"nacosSecretKey,omitempty"`
|
||||
NacosNamespaceId string `protobuf:"bytes,8,opt,name=nacosNamespaceId,proto3" json:"nacosNamespaceId,omitempty"`
|
||||
NacosNamespace string `protobuf:"bytes,9,opt,name=nacosNamespace,proto3" json:"nacosNamespace,omitempty"`
|
||||
NacosGroups []string `protobuf:"bytes,10,rep,name=nacosGroups,proto3" json:"nacosGroups,omitempty"`
|
||||
NacosRefreshInterval int64 `protobuf:"varint,11,opt,name=nacosRefreshInterval,proto3" json:"nacosRefreshInterval,omitempty"`
|
||||
ConsulNamespace string `protobuf:"bytes,12,opt,name=consulNamespace,proto3" json:"consulNamespace,omitempty"`
|
||||
ZkServicesPath []string `protobuf:"bytes,13,rep,name=zkServicesPath,proto3" json:"zkServicesPath,omitempty"`
|
||||
ConsulDatacenter string `protobuf:"bytes,14,opt,name=consulDatacenter,proto3" json:"consulDatacenter,omitempty"`
|
||||
ConsulServiceTag string `protobuf:"bytes,15,opt,name=consulServiceTag,proto3" json:"consulServiceTag,omitempty"`
|
||||
ConsulRefreshInterval int64 `protobuf:"varint,16,opt,name=consulRefreshInterval,proto3" json:"consulRefreshInterval,omitempty"`
|
||||
AuthSecretName string `protobuf:"bytes,17,opt,name=authSecretName,proto3" json:"authSecretName,omitempty"`
|
||||
Protocol string `protobuf:"bytes,18,opt,name=protocol,proto3" json:"protocol,omitempty"`
|
||||
Sni string `protobuf:"bytes,19,opt,name=sni,proto3" json:"sni,omitempty"`
|
||||
McpServerExportDomains []string `protobuf:"bytes,20,rep,name=mcpServerExportDomains,proto3" json:"mcpServerExportDomains,omitempty"`
|
||||
McpServerBaseUrl string `protobuf:"bytes,21,opt,name=mcpServerBaseUrl,proto3" json:"mcpServerBaseUrl,omitempty"`
|
||||
EnableMCPServer *wrappers.BoolValue `protobuf:"bytes,22,opt,name=enableMCPServer,proto3" json:"enableMCPServer,omitempty"`
|
||||
Type string `protobuf:"bytes,1,opt,name=type,proto3" json:"type,omitempty"`
|
||||
Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"`
|
||||
Domain string `protobuf:"bytes,3,opt,name=domain,proto3" json:"domain,omitempty"`
|
||||
Port uint32 `protobuf:"varint,4,opt,name=port,proto3" json:"port,omitempty"`
|
||||
NacosAddressServer string `protobuf:"bytes,5,opt,name=nacosAddressServer,proto3" json:"nacosAddressServer,omitempty"`
|
||||
NacosAccessKey string `protobuf:"bytes,6,opt,name=nacosAccessKey,proto3" json:"nacosAccessKey,omitempty"`
|
||||
NacosSecretKey string `protobuf:"bytes,7,opt,name=nacosSecretKey,proto3" json:"nacosSecretKey,omitempty"`
|
||||
NacosNamespaceId string `protobuf:"bytes,8,opt,name=nacosNamespaceId,proto3" json:"nacosNamespaceId,omitempty"`
|
||||
NacosNamespace string `protobuf:"bytes,9,opt,name=nacosNamespace,proto3" json:"nacosNamespace,omitempty"`
|
||||
NacosGroups []string `protobuf:"bytes,10,rep,name=nacosGroups,proto3" json:"nacosGroups,omitempty"`
|
||||
NacosRefreshInterval int64 `protobuf:"varint,11,opt,name=nacosRefreshInterval,proto3" json:"nacosRefreshInterval,omitempty"`
|
||||
ConsulNamespace string `protobuf:"bytes,12,opt,name=consulNamespace,proto3" json:"consulNamespace,omitempty"`
|
||||
ZkServicesPath []string `protobuf:"bytes,13,rep,name=zkServicesPath,proto3" json:"zkServicesPath,omitempty"`
|
||||
ConsulDatacenter string `protobuf:"bytes,14,opt,name=consulDatacenter,proto3" json:"consulDatacenter,omitempty"`
|
||||
ConsulServiceTag string `protobuf:"bytes,15,opt,name=consulServiceTag,proto3" json:"consulServiceTag,omitempty"`
|
||||
ConsulRefreshInterval int64 `protobuf:"varint,16,opt,name=consulRefreshInterval,proto3" json:"consulRefreshInterval,omitempty"`
|
||||
AuthSecretName string `protobuf:"bytes,17,opt,name=authSecretName,proto3" json:"authSecretName,omitempty"`
|
||||
Protocol string `protobuf:"bytes,18,opt,name=protocol,proto3" json:"protocol,omitempty"`
|
||||
Sni string `protobuf:"bytes,19,opt,name=sni,proto3" json:"sni,omitempty"`
|
||||
McpServerExportDomains []string `protobuf:"bytes,20,rep,name=mcpServerExportDomains,proto3" json:"mcpServerExportDomains,omitempty"`
|
||||
McpServerBaseUrl string `protobuf:"bytes,21,opt,name=mcpServerBaseUrl,proto3" json:"mcpServerBaseUrl,omitempty"`
|
||||
EnableMCPServer *wrappers.BoolValue `protobuf:"bytes,22,opt,name=enableMCPServer,proto3" json:"enableMCPServer,omitempty"`
|
||||
EnableScopeMcpServers *wrappers.BoolValue `protobuf:"bytes,23,opt,name=enableScopeMcpServers,proto3" json:"enableScopeMcpServers,omitempty"`
|
||||
AllowMcpServers []string `protobuf:"bytes,24,rep,name=allowMcpServers,proto3" json:"allowMcpServers,omitempty"`
|
||||
Metadata map[string]*InnerMap `protobuf:"bytes,25,rep,name=metadata,proto3" json:"metadata,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
|
||||
}
|
||||
|
||||
func (x *RegistryConfig) Reset() {
|
||||
@@ -321,6 +324,74 @@ func (x *RegistryConfig) GetEnableMCPServer() *wrappers.BoolValue {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *RegistryConfig) GetEnableScopeMcpServers() *wrappers.BoolValue {
|
||||
if x != nil {
|
||||
return x.EnableScopeMcpServers
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *RegistryConfig) GetAllowMcpServers() []string {
|
||||
if x != nil {
|
||||
return x.AllowMcpServers
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *RegistryConfig) GetMetadata() map[string]*InnerMap {
|
||||
if x != nil {
|
||||
return x.Metadata
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type InnerMap struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
InnerMap map[string]string `protobuf:"bytes,1,rep,name=inner_map,json=innerMap,proto3" json:"inner_map,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
|
||||
}
|
||||
|
||||
func (x *InnerMap) Reset() {
|
||||
*x = InnerMap{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_networking_v1_mcp_bridge_proto_msgTypes[2]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *InnerMap) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*InnerMap) ProtoMessage() {}
|
||||
|
||||
func (x *InnerMap) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_networking_v1_mcp_bridge_proto_msgTypes[2]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use InnerMap.ProtoReflect.Descriptor instead.
|
||||
func (*InnerMap) Descriptor() ([]byte, []int) {
|
||||
return file_networking_v1_mcp_bridge_proto_rawDescGZIP(), []int{2}
|
||||
}
|
||||
|
||||
func (x *InnerMap) GetInnerMap() map[string]string {
|
||||
if x != nil {
|
||||
return x.InnerMap
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var File_networking_v1_mcp_bridge_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_networking_v1_mcp_bridge_proto_rawDesc = []byte{
|
||||
@@ -338,7 +409,7 @@ var file_networking_v1_mcp_bridge_proto_rawDesc = []byte{
|
||||
0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x25, 0x2e, 0x68, 0x69, 0x67, 0x72, 0x65, 0x73,
|
||||
0x73, 0x2e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x69, 0x6e, 0x67, 0x2e, 0x76, 0x31, 0x2e,
|
||||
0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a,
|
||||
0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x69, 0x65, 0x73, 0x22, 0xfd, 0x06, 0x0a, 0x0e, 0x52,
|
||||
0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x69, 0x65, 0x73, 0x22, 0xa8, 0x09, 0x0a, 0x0e, 0x52,
|
||||
0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x17, 0x0a,
|
||||
0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x03, 0xe0, 0x41, 0x02,
|
||||
0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02,
|
||||
@@ -394,11 +465,39 @@ var file_networking_v1_mcp_bridge_proto_rawDesc = []byte{
|
||||
0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x18, 0x16, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e,
|
||||
0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e,
|
||||
0x42, 0x6f, 0x6f, 0x6c, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x52, 0x0f, 0x65, 0x6e, 0x61, 0x62, 0x6c,
|
||||
0x65, 0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x42, 0x2e, 0x5a, 0x2c, 0x67, 0x69,
|
||||
0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x61, 0x6c, 0x69, 0x62, 0x61, 0x62, 0x61,
|
||||
0x2f, 0x68, 0x69, 0x67, 0x72, 0x65, 0x73, 0x73, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x6e, 0x65, 0x74,
|
||||
0x77, 0x6f, 0x72, 0x6b, 0x69, 0x6e, 0x67, 0x2f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x33,
|
||||
0x65, 0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x50, 0x0a, 0x15, 0x65, 0x6e,
|
||||
0x61, 0x62, 0x6c, 0x65, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x4d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76,
|
||||
0x65, 0x72, 0x73, 0x18, 0x17, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67,
|
||||
0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x42, 0x6f, 0x6f, 0x6c,
|
||||
0x56, 0x61, 0x6c, 0x75, 0x65, 0x52, 0x15, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x63, 0x6f,
|
||||
0x70, 0x65, 0x4d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x28, 0x0a, 0x0f,
|
||||
0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x4d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18,
|
||||
0x18, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0f, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x4d, 0x63, 0x70, 0x53,
|
||||
0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x4f, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61,
|
||||
0x74, 0x61, 0x18, 0x19, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x33, 0x2e, 0x68, 0x69, 0x67, 0x72, 0x65,
|
||||
0x73, 0x73, 0x2e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x69, 0x6e, 0x67, 0x2e, 0x76, 0x31,
|
||||
0x2e, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e,
|
||||
0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, 0x6d,
|
||||
0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x1a, 0x5c, 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, 0x64,
|
||||
0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18,
|
||||
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x35, 0x0a, 0x05, 0x76, 0x61,
|
||||
0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x68, 0x69, 0x67, 0x72,
|
||||
0x65, 0x73, 0x73, 0x2e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x69, 0x6e, 0x67, 0x2e, 0x76,
|
||||
0x31, 0x2e, 0x49, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61, 0x70, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75,
|
||||
0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x93, 0x01, 0x0a, 0x08, 0x49, 0x6e, 0x6e, 0x65, 0x72, 0x4d,
|
||||
0x61, 0x70, 0x12, 0x4a, 0x0a, 0x09, 0x69, 0x6e, 0x6e, 0x65, 0x72, 0x5f, 0x6d, 0x61, 0x70, 0x18,
|
||||
0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2d, 0x2e, 0x68, 0x69, 0x67, 0x72, 0x65, 0x73, 0x73, 0x2e,
|
||||
0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x69, 0x6e, 0x67, 0x2e, 0x76, 0x31, 0x2e, 0x49, 0x6e,
|
||||
0x6e, 0x65, 0x72, 0x4d, 0x61, 0x70, 0x2e, 0x49, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61, 0x70, 0x45,
|
||||
0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, 0x69, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61, 0x70, 0x1a, 0x3b,
|
||||
0x0a, 0x0d, 0x49, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61, 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12,
|
||||
0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65,
|
||||
0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09,
|
||||
0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x2e, 0x5a, 0x2c, 0x67,
|
||||
0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x61, 0x6c, 0x69, 0x62, 0x61, 0x62,
|
||||
0x61, 0x2f, 0x68, 0x69, 0x67, 0x72, 0x65, 0x73, 0x73, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x6e, 0x65,
|
||||
0x74, 0x77, 0x6f, 0x72, 0x6b, 0x69, 0x6e, 0x67, 0x2f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f,
|
||||
0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -413,20 +512,27 @@ func file_networking_v1_mcp_bridge_proto_rawDescGZIP() []byte {
|
||||
return file_networking_v1_mcp_bridge_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_networking_v1_mcp_bridge_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
|
||||
var file_networking_v1_mcp_bridge_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
|
||||
var file_networking_v1_mcp_bridge_proto_goTypes = []interface{}{
|
||||
(*McpBridge)(nil), // 0: higress.networking.v1.McpBridge
|
||||
(*RegistryConfig)(nil), // 1: higress.networking.v1.RegistryConfig
|
||||
(*wrappers.BoolValue)(nil), // 2: google.protobuf.BoolValue
|
||||
(*InnerMap)(nil), // 2: higress.networking.v1.InnerMap
|
||||
nil, // 3: higress.networking.v1.RegistryConfig.MetadataEntry
|
||||
nil, // 4: higress.networking.v1.InnerMap.InnerMapEntry
|
||||
(*wrappers.BoolValue)(nil), // 5: google.protobuf.BoolValue
|
||||
}
|
||||
var file_networking_v1_mcp_bridge_proto_depIdxs = []int32{
|
||||
1, // 0: higress.networking.v1.McpBridge.registries:type_name -> higress.networking.v1.RegistryConfig
|
||||
2, // 1: higress.networking.v1.RegistryConfig.enableMCPServer:type_name -> google.protobuf.BoolValue
|
||||
2, // [2:2] is the sub-list for method output_type
|
||||
2, // [2:2] is the sub-list for method input_type
|
||||
2, // [2:2] is the sub-list for extension type_name
|
||||
2, // [2:2] is the sub-list for extension extendee
|
||||
0, // [0:2] is the sub-list for field type_name
|
||||
5, // 1: higress.networking.v1.RegistryConfig.enableMCPServer:type_name -> google.protobuf.BoolValue
|
||||
5, // 2: higress.networking.v1.RegistryConfig.enableScopeMcpServers:type_name -> google.protobuf.BoolValue
|
||||
3, // 3: higress.networking.v1.RegistryConfig.metadata:type_name -> higress.networking.v1.RegistryConfig.MetadataEntry
|
||||
4, // 4: higress.networking.v1.InnerMap.inner_map:type_name -> higress.networking.v1.InnerMap.InnerMapEntry
|
||||
2, // 5: higress.networking.v1.RegistryConfig.MetadataEntry.value:type_name -> higress.networking.v1.InnerMap
|
||||
6, // [6:6] is the sub-list for method output_type
|
||||
6, // [6:6] is the sub-list for method input_type
|
||||
6, // [6:6] is the sub-list for extension type_name
|
||||
6, // [6:6] is the sub-list for extension extendee
|
||||
0, // [0:6] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_networking_v1_mcp_bridge_proto_init() }
|
||||
@@ -459,6 +565,18 @@ func file_networking_v1_mcp_bridge_proto_init() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_networking_v1_mcp_bridge_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*InnerMap); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
@@ -466,7 +584,7 @@ func file_networking_v1_mcp_bridge_proto_init() {
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_networking_v1_mcp_bridge_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 2,
|
||||
NumMessages: 5,
|
||||
NumExtensions: 0,
|
||||
NumServices: 0,
|
||||
},
|
||||
|
||||
@@ -71,4 +71,11 @@ message RegistryConfig {
|
||||
repeated string mcpServerExportDomains = 20;
|
||||
string mcpServerBaseUrl = 21;
|
||||
google.protobuf.BoolValue enableMCPServer = 22;
|
||||
google.protobuf.BoolValue enableScopeMcpServers = 23;
|
||||
repeated string allowMcpServers = 24;
|
||||
map<string, InnerMap> metadata = 25;
|
||||
}
|
||||
|
||||
message InnerMap {
|
||||
map<string, string> inner_map = 1;
|
||||
}
|
||||
@@ -46,3 +46,24 @@ func (in *RegistryConfig) DeepCopy() *RegistryConfig {
|
||||
func (in *RegistryConfig) DeepCopyInterface() interface{} {
|
||||
return in.DeepCopy()
|
||||
}
|
||||
|
||||
// DeepCopyInto supports using InnerMap within kubernetes types, where deepcopy-gen is used.
|
||||
func (in *InnerMap) DeepCopyInto(out *InnerMap) {
|
||||
p := proto.Clone(in).(*InnerMap)
|
||||
*out = *p
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new InnerMap. Required by controller-gen.
|
||||
func (in *InnerMap) DeepCopy() *InnerMap {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(InnerMap)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInterface is an autogenerated deepcopy function, copying the receiver, creating a new InnerMap. Required by controller-gen.
|
||||
func (in *InnerMap) DeepCopyInterface() interface{} {
|
||||
return in.DeepCopy()
|
||||
}
|
||||
|
||||
@@ -28,6 +28,17 @@ func (this *RegistryConfig) UnmarshalJSON(b []byte) error {
|
||||
return McpBridgeUnmarshaler.Unmarshal(bytes.NewReader(b), this)
|
||||
}
|
||||
|
||||
// MarshalJSON is a custom marshaler for InnerMap
|
||||
func (this *InnerMap) MarshalJSON() ([]byte, error) {
|
||||
str, err := McpBridgeMarshaler.MarshalToString(this)
|
||||
return []byte(str), err
|
||||
}
|
||||
|
||||
// UnmarshalJSON is a custom unmarshaler for InnerMap
|
||||
func (this *InnerMap) UnmarshalJSON(b []byte) error {
|
||||
return McpBridgeUnmarshaler.Unmarshal(bytes.NewReader(b), this)
|
||||
}
|
||||
|
||||
var (
|
||||
McpBridgeMarshaler = &jsonpb.Marshaler{}
|
||||
McpBridgeUnmarshaler = &jsonpb.Unmarshaler{AllowUnknownFields: true}
|
||||
|
||||
Submodule envoy/envoy updated: 17cf01d9f6...583feb54ce
4
go.mod
4
go.mod
@@ -31,7 +31,7 @@ require (
|
||||
github.com/hudl/fargo v1.4.0
|
||||
github.com/mholt/acmez v1.2.0
|
||||
github.com/nacos-group/nacos-sdk-go v1.0.8
|
||||
github.com/nacos-group/nacos-sdk-go/v2 v2.1.2
|
||||
github.com/nacos-group/nacos-sdk-go/v2 v2.3.2
|
||||
github.com/onsi/gomega v1.27.10
|
||||
github.com/spf13/cobra v1.8.0
|
||||
github.com/spf13/pflag v1.0.5
|
||||
@@ -202,6 +202,7 @@ require (
|
||||
github.com/spaolacci/murmur3 v1.1.0 // indirect
|
||||
github.com/spf13/cast v1.5.1 // indirect
|
||||
github.com/stoewer/go-strcase v1.3.0 // indirect
|
||||
github.com/stretchr/objx v0.5.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.7.3 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
@@ -274,6 +275,5 @@ replace github.com/caddyserver/certmagic => github.com/2456868764/certmagic v1.0
|
||||
|
||||
replace (
|
||||
github.com/dubbogo/gost => github.com/johnlanni/gost v1.11.23-0.20220713132522-0967a24036c6
|
||||
github.com/nacos-group/nacos-sdk-go/v2 => github.com/luoxiner/nacos-sdk-go/v2 v2.2.9-60
|
||||
golang.org/x/exp => golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1
|
||||
)
|
||||
|
||||
4
go.sum
4
go.sum
@@ -1434,8 +1434,6 @@ github.com/liggitt/tabwriter v0.0.0-20181228230101-89fcab3d43de h1:9TO3cAIGXtEhn
|
||||
github.com/liggitt/tabwriter v0.0.0-20181228230101-89fcab3d43de/go.mod h1:zAbeS9B/r2mtpb6U+EI2rYA5OAXxsYw6wTamcNW+zcE=
|
||||
github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM=
|
||||
github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4=
|
||||
github.com/luoxiner/nacos-sdk-go/v2 v2.2.9-60 h1:FA/azfz2nSkMc1XR8LeqhcAiA/2/sOMcyBGYCTUc+Cs=
|
||||
github.com/luoxiner/nacos-sdk-go/v2 v2.2.9-60/go.mod h1:9FKXl6FqOiVmm72i8kADtbeK71egyG9y3uRDBg41tpQ=
|
||||
github.com/lyft/protoc-gen-star v0.6.1/go.mod h1:TGAoBVkt8w7MPG72TrKIu85MIdXwDuzJYeZuUPFPNwA=
|
||||
github.com/lyft/protoc-gen-star/v2 v2.0.1/go.mod h1:RcCdONR2ScXaYnQC5tUzxzlpA3WVYF7/opLeUgcQs/o=
|
||||
github.com/lyft/protoc-gen-star/v2 v2.0.3/go.mod h1:amey7yeodaJhXSbf/TlLvWiqQfLOSpEk//mLlc+axEk=
|
||||
@@ -1525,6 +1523,8 @@ github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRW
|
||||
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
|
||||
github.com/nacos-group/nacos-sdk-go v1.0.8 h1:8pEm05Cdav9sQgJSv5kyvlgfz0SzFUUGI3pWX6SiSnM=
|
||||
github.com/nacos-group/nacos-sdk-go v1.0.8/go.mod h1:hlAPn3UdzlxIlSILAyOXKxjFSvDJ9oLzTJ9hLAK1KzA=
|
||||
github.com/nacos-group/nacos-sdk-go/v2 v2.3.2 h1:9QB2nCJzT5wkTVlxNYl3XL/7+G6p2USMi2gQh/ouQQo=
|
||||
github.com/nacos-group/nacos-sdk-go/v2 v2.3.2/go.mod h1:9FKXl6FqOiVmm72i8kADtbeK71egyG9y3uRDBg41tpQ=
|
||||
github.com/nats-io/jwt v0.3.0/go.mod h1:fRYCDE99xlTsqUzISS1Bi75UBJ6ljOJQOAAu5VglpSg=
|
||||
github.com/nats-io/jwt v0.3.2/go.mod h1:/euKqTS1ZD+zzjYrY7pseZrTtWQSjujC7xjPc8wL6eU=
|
||||
github.com/nats-io/nats-server/v2 v2.1.2/go.mod h1:Afk+wRZqkMQs/p45uXdrVLuab3gwv3Z8C4HTBu8GD/k=
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
apiVersion: v2
|
||||
appVersion: 2.1.3
|
||||
appVersion: 2.1.5
|
||||
description: Helm chart for deploying higress gateways
|
||||
icon: https://higress.io/img/higress_logo_small.png
|
||||
home: http://higress.io/
|
||||
@@ -15,4 +15,4 @@ dependencies:
|
||||
repository: "file://../redis"
|
||||
version: 0.0.1
|
||||
type: application
|
||||
version: 2.1.3
|
||||
version: 2.1.5
|
||||
|
||||
@@ -250,6 +250,10 @@ spec:
|
||||
registries:
|
||||
items:
|
||||
properties:
|
||||
allowMcpServers:
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
authSecretName:
|
||||
type: string
|
||||
consulDatacenter:
|
||||
@@ -263,6 +267,25 @@ spec:
|
||||
type: string
|
||||
domain:
|
||||
type: string
|
||||
enableMCPServer:
|
||||
type: boolean
|
||||
enableScopeMcpServers:
|
||||
type: boolean
|
||||
mcpServerBaseUrl:
|
||||
type: string
|
||||
mcpServerExportDomains:
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
metadata:
|
||||
additionalProperties:
|
||||
properties:
|
||||
innerMap:
|
||||
additionalProperties:
|
||||
type: string
|
||||
type: object
|
||||
type: object
|
||||
type: object
|
||||
nacosAccessKey:
|
||||
type: string
|
||||
nacosAddressServer:
|
||||
|
||||
@@ -113,3 +113,36 @@ kind: VMPodScrape
|
||||
{{- fail "unexpected gateway.metrics.provider" -}}
|
||||
{{- end -}}
|
||||
{{- end -}}
|
||||
|
||||
{{- define "pluginServer.name" -}}
|
||||
{{- .Values.pluginServer.name | default "higress-plugin-server" -}}
|
||||
{{- end }}
|
||||
|
||||
{{- define "pluginServer.chart" -}}
|
||||
{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
|
||||
{{- end }}
|
||||
|
||||
{{- define "pluginServer.labels" -}}
|
||||
helm.sh/chart: {{ include "pluginServer.chart" . }}
|
||||
{{ include "pluginServer.selectorLabels" . }}
|
||||
{{- if .Chart.AppVersion }}
|
||||
app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
|
||||
{{- end }}
|
||||
app.kubernetes.io/managed-by: {{ .Release.Service }}
|
||||
app.kubernetes.io/name: {{ include "pluginServer.name" . }}
|
||||
{{- end }}
|
||||
|
||||
{{- define "pluginServer.selectorLabels" -}}
|
||||
{{- if hasKey .Values.pluginServer.labels "app" }}
|
||||
{{- with .Values.pluginServer.labels.app }}app: {{.|quote}}
|
||||
{{- end}}
|
||||
{{- else }}app: {{ include "pluginServer.name" . }}
|
||||
{{- end }}
|
||||
{{- if hasKey .Values.pluginServer.labels "higress" }}
|
||||
{{- with .Values.pluginServer.labels.higress }}
|
||||
higress: {{.|quote}}
|
||||
{{- end}}
|
||||
{{- else }}
|
||||
higress: {{ include "pluginServer.name" . }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
|
||||
@@ -9,9 +9,7 @@
|
||||
accessLogFile: "/dev/stdout"
|
||||
{{- end }}
|
||||
ingressControllerMode: "OFF"
|
||||
accessLogFormat: '{"ai_log":"%FILTER_STATE(wasm.ai_log:PLAIN)%","authority":"%REQ(X-ENVOY-ORIGINAL-HOST?:AUTHORITY)%","bytes_received":"%BYTES_RECEIVED%","bytes_sent":"%BYTES_SENT%","downstream_local_address":"%DOWNSTREAM_LOCAL_ADDRESS%","downstream_remote_address":"%DOWNSTREAM_REMOTE_ADDRESS%","duration":"%DURATION%","istio_policy_status":"%DYNAMIC_METADATA(istio.mixer:status)%","method":"%REQ(:METHOD)%","path":"%REQ(X-ENVOY-ORIGINAL-PATH?:PATH)%","protocol":"%PROTOCOL%","request_id":"%REQ(X-REQUEST-ID)%","requested_server_name":"%REQUESTED_SERVER_NAME%","response_code":"%RESPONSE_CODE%","response_flags":"%RESPONSE_FLAGS%","route_name":"%ROUTE_NAME%","start_time":"%START_TIME%","trace_id":"%REQ(X-B3-TRACEID)%","upstream_cluster":"%UPSTREAM_CLUSTER%","upstream_host":"%UPSTREAM_HOST%","upstream_local_address":"%UPSTREAM_LOCAL_ADDRESS%","upstream_service_time":"%RESP(X-ENVOY-UPSTREAM-SERVICE-TIME)%","upstream_transport_failure_reason":"%UPSTREAM_TRANSPORT_FAILURE_REASON%","user_agent":"%REQ(USER-AGENT)%","x_forwarded_for":"%REQ(X-FORWARDED-FOR)%","response_code_details":"%RESPONSE_CODE_DETAILS%"}
|
||||
|
||||
'
|
||||
accessLogFormat: '{"ai_log":"%FILTER_STATE(wasm.ai_log:PLAIN)%","authority":"%REQ(X-ENVOY-ORIGINAL-HOST?:AUTHORITY)%","bytes_received":"%BYTES_RECEIVED%","bytes_sent":"%BYTES_SENT%","downstream_local_address":"%DOWNSTREAM_LOCAL_ADDRESS%","downstream_remote_address":"%DOWNSTREAM_REMOTE_ADDRESS%","duration":"%DURATION%","istio_policy_status":"%DYNAMIC_METADATA(istio.mixer:status)%","method":"%REQ(:METHOD)%","path":"%REQ(X-ENVOY-ORIGINAL-PATH?:PATH)%","protocol":"%PROTOCOL%","request_id":"%REQ(X-REQUEST-ID)%","requested_server_name":"%REQUESTED_SERVER_NAME%","response_code":"%RESPONSE_CODE%","response_flags":"%RESPONSE_FLAGS%","route_name":"%ROUTE_NAME%","start_time":"%START_TIME%","trace_id":"%REQ(X-B3-TRACEID)%","upstream_cluster":"%UPSTREAM_CLUSTER%","upstream_host":"%UPSTREAM_HOST%","upstream_local_address":"%UPSTREAM_LOCAL_ADDRESS%","upstream_service_time":"%RESP(X-ENVOY-UPSTREAM-SERVICE-TIME)%","upstream_transport_failure_reason":"%UPSTREAM_TRANSPORT_FAILURE_REASON%","user_agent":"%REQ(USER-AGENT)%","x_forwarded_for":"%REQ(X-FORWARDED-FOR)%","response_code_details":"%RESPONSE_CODE_DETAILS%"}'
|
||||
dnsRefreshRate: 200s
|
||||
enableAutoMtls: false
|
||||
enablePrometheusMerge: false
|
||||
@@ -99,7 +97,7 @@ metadata:
|
||||
name: higress-config
|
||||
namespace: {{ .Release.Namespace }}
|
||||
labels:
|
||||
{{- include "gateway.labels" . | nindent 4 }}
|
||||
{{- include "gateway.labels" . | nindent 4 }}
|
||||
data:
|
||||
higress: |-
|
||||
{{- $existingConfig := lookup "v1" "ConfigMap" .Release.Namespace "higress-config" }}
|
||||
@@ -126,7 +124,7 @@ data:
|
||||
{{- else }}
|
||||
networks: {}
|
||||
{{- end }}
|
||||
|
||||
|
||||
mesh: |-
|
||||
{{- if .Values.meshConfig }}
|
||||
{{ $mesh | toYaml | indent 4 }}
|
||||
|
||||
@@ -6,4 +6,8 @@ metadata:
|
||||
namespace: {{ .Release.Namespace }}
|
||||
labels:
|
||||
{{- include "controller.labels" . | nindent 4 }}
|
||||
{{- with .Values.controller.serviceAccount.annotations }}
|
||||
annotations:
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
|
||||
39
helm/core/templates/plugin-server-deployment.yaml
Normal file
39
helm/core/templates/plugin-server-deployment.yaml
Normal file
@@ -0,0 +1,39 @@
|
||||
{{- if .Values.global.enablePluginServer }}
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: {{ include "pluginServer.name" . }}
|
||||
namespace: {{ .Release.Namespace }}
|
||||
spec:
|
||||
replicas: {{ .Values.pluginServer.replicas }}
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- include "pluginServer.selectorLabels" . | nindent 6 }}
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
{{- with .Values.pluginServer.podLabels }}
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- include "pluginServer.selectorLabels" . | nindent 8 }}
|
||||
spec:
|
||||
{{- with .Values.pluginServer.imagePullSecrets }}
|
||||
imagePullSecrets:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
containers:
|
||||
- name: {{ .Chart.Name }}
|
||||
image: {{ .Values.pluginServer.hub | default .Values.global.hub }}/{{ .Values.pluginServer.image | default "plugin-server" }}:{{ .Values.pluginServer.tag | default "1.0.0" }}
|
||||
{{- if .Values.global.imagePullPolicy }}
|
||||
imagePullPolicy: {{ .Values.global.imagePullPolicy }}
|
||||
{{- end }}
|
||||
ports:
|
||||
- containerPort: 8080
|
||||
resources:
|
||||
requests:
|
||||
cpu: {{ .Values.pluginServer.resources.requests.cpu }}
|
||||
memory: {{ .Values.pluginServer.resources.requests.memory }}
|
||||
limits:
|
||||
cpu: {{ .Values.pluginServer.resources.limits.cpu }}
|
||||
memory: {{ .Values.pluginServer.resources.limits.memory }}
|
||||
{{- end }}
|
||||
16
helm/core/templates/plugin-server-service.yaml
Normal file
16
helm/core/templates/plugin-server-service.yaml
Normal file
@@ -0,0 +1,16 @@
|
||||
{{- if .Values.global.enablePluginServer }}
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{ include "pluginServer.name" . }}
|
||||
namespace: {{ .Release.Namespace }}
|
||||
labels:
|
||||
{{- include "pluginServer.labels" . | nindent 4 }}
|
||||
spec:
|
||||
ports:
|
||||
- protocol: TCP
|
||||
port: {{ .Values.pluginServer.service.port }}
|
||||
targetPort: 8080
|
||||
selector:
|
||||
{{- include "pluginServer.selectorLabels" . | nindent 4 }}
|
||||
{{- end }}
|
||||
@@ -11,6 +11,7 @@ global:
|
||||
enableSRDS: true
|
||||
# -- Whether to enable Redis(redis-stack-server) for Higress, default is false.
|
||||
enableRedis: false
|
||||
enablePluginServer: false
|
||||
onDemandRDS: false
|
||||
hostRDSMergeSubset: false
|
||||
onlyPushRouteCluster: true
|
||||
@@ -580,8 +581,7 @@ controller:
|
||||
# -- Labels to apply to the pod
|
||||
podLabels: {}
|
||||
|
||||
podSecurityContext:
|
||||
{}
|
||||
podSecurityContext: {}
|
||||
# fsGroup: 2000
|
||||
|
||||
ports:
|
||||
@@ -708,13 +708,13 @@ tracing:
|
||||
enable: false
|
||||
sampling: 100
|
||||
timeout: 500
|
||||
skywalking:
|
||||
# access_token: ""
|
||||
service: ""
|
||||
port: 11800
|
||||
# skywalking:
|
||||
# access_token: ""
|
||||
# service: ""
|
||||
# port: 11800
|
||||
# zipkin:
|
||||
# service: ""
|
||||
# port: 9411
|
||||
# service: ""
|
||||
# port: 9411
|
||||
|
||||
# -- Downstream config settings
|
||||
downstream:
|
||||
@@ -767,4 +767,31 @@ redis:
|
||||
accessModes:
|
||||
- ReadWriteOnce
|
||||
# -- Persistent Volume size
|
||||
size: 1Gi
|
||||
size: 1Gi
|
||||
|
||||
pluginServer:
|
||||
name: "higress-plugin-server"
|
||||
# -- Number of Higress Plugin Server pods, 2 recommended for high availability
|
||||
replicas: 2
|
||||
image: plugin-server
|
||||
|
||||
hub: higress-registry.cn-hangzhou.cr.aliyuncs.com/higress
|
||||
tag: ""
|
||||
|
||||
imagePullSecrets: []
|
||||
|
||||
labels: {}
|
||||
# -- Labels to apply to the pod
|
||||
podLabels: {}
|
||||
|
||||
# Plugin-server Service configuration
|
||||
service:
|
||||
port: 80 # Container target port (usually fixed)
|
||||
|
||||
resources:
|
||||
requests:
|
||||
cpu: 200m
|
||||
memory: 128Mi
|
||||
limits:
|
||||
cpu: 500m
|
||||
memory: 256Mi
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
dependencies:
|
||||
- name: higress-core
|
||||
repository: file://../core
|
||||
version: 2.1.3
|
||||
version: 2.1.5
|
||||
- name: higress-console
|
||||
repository: https://higress.io/helm-charts/
|
||||
version: 2.1.3
|
||||
digest: sha256:c7307d5398c3c1178758c5372bd1aa4cb8dee7beeab3832d3e9ce0a04d1adc23
|
||||
generated: "2025-05-09T15:29:50.616179+08:00"
|
||||
version: 2.1.5
|
||||
digest: sha256:1c7c8003686b2df2c67427054006aef21c92ab1ff86d2e5f5587daf02ebc7d61
|
||||
generated: "2025-07-02T17:38:10.089494+08:00"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
apiVersion: v2
|
||||
appVersion: 2.1.3
|
||||
appVersion: 2.1.5
|
||||
description: Helm chart for deploying Higress gateways
|
||||
icon: https://higress.io/img/higress_logo_small.png
|
||||
home: http://higress.io/
|
||||
@@ -12,9 +12,9 @@ sources:
|
||||
dependencies:
|
||||
- name: higress-core
|
||||
repository: "file://../core"
|
||||
version: 2.1.3
|
||||
version: 2.1.5
|
||||
- name: higress-console
|
||||
repository: "https://higress.io/helm-charts/"
|
||||
version: 2.1.3
|
||||
version: 2.1.5
|
||||
type: application
|
||||
version: 2.1.3
|
||||
version: 2.1.5
|
||||
|
||||
@@ -165,6 +165,7 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| global.enableIPv6 | bool | `false` | |
|
||||
| global.enableIstioAPI | bool | `true` | If true, Higress Controller will monitor istio resources as well |
|
||||
| global.enableLDSCache | bool | `false` | |
|
||||
| global.enablePluginServer | bool | `false` | |
|
||||
| global.enableProxyProtocol | bool | `false` | |
|
||||
| global.enablePushAllMCPClusters | bool | `true` | |
|
||||
| global.enableRedis | bool | `false` | Whether to enable Redis(redis-stack-server) for Higress, default is false. |
|
||||
@@ -273,6 +274,19 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| pilot.serviceAnnotations | object | `{}` | |
|
||||
| pilot.tag | string | `""` | |
|
||||
| pilot.traceSampling | float | `1` | |
|
||||
| pluginServer.hub | string | `"higress-registry.cn-hangzhou.cr.aliyuncs.com/higress"` | |
|
||||
| pluginServer.image | string | `"plugin-server"` | |
|
||||
| pluginServer.imagePullSecrets | list | `[]` | |
|
||||
| pluginServer.labels | object | `{}` | |
|
||||
| pluginServer.name | string | `"higress-plugin-server"` | |
|
||||
| pluginServer.podLabels | object | `{}` | Labels to apply to the pod |
|
||||
| pluginServer.replicas | int | `2` | Number of Higress Plugin Server pods, 2 recommended for high availability |
|
||||
| pluginServer.resources.limits.cpu | string | `"500m"` | |
|
||||
| pluginServer.resources.limits.memory | string | `"256Mi"` | |
|
||||
| pluginServer.resources.requests.cpu | string | `"200m"` | |
|
||||
| pluginServer.resources.requests.memory | string | `"128Mi"` | |
|
||||
| pluginServer.service.port | int | `80` | |
|
||||
| pluginServer.tag | string | `""` | |
|
||||
| redis.redis.affinity | object | `{}` | Affinity for Redis |
|
||||
| redis.redis.image | string | `"redis-stack-server"` | Specify the image |
|
||||
| redis.redis.name | string | `"redis-stack-server"` | |
|
||||
@@ -292,7 +306,5 @@ The command removes all the Kubernetes components associated with the chart and
|
||||
| revision | string | `""` | |
|
||||
| tracing.enable | bool | `false` | |
|
||||
| tracing.sampling | int | `100` | |
|
||||
| tracing.skywalking.port | int | `11800` | |
|
||||
| tracing.skywalking.service | string | `""` | |
|
||||
| tracing.timeout | int | `500` | |
|
||||
| upstream | object | `{"connectionBufferLimits":10485760,"idleTimeout":10}` | Upstream config settings |
|
||||
Submodule istio/istio updated: 1492505b14...1c41de65e9
@@ -63,6 +63,7 @@ import (
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/ingress"
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/ingressv1"
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/mcpbridge"
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/mcpserver"
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/secret"
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/util"
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/wasmplugin"
|
||||
@@ -158,6 +159,8 @@ type IngressConfig struct {
|
||||
|
||||
// secretConfigMgr manages secret dependencies
|
||||
secretConfigMgr *SecretConfigMgr
|
||||
|
||||
mcpServerCache mcpserver.McpServerCache
|
||||
}
|
||||
|
||||
// getSecretValue implements the getValue function for secret references
|
||||
@@ -224,6 +227,7 @@ func NewIngressConfig(localKubeClient kube.Client, xdsUpdater istiomodel.XDSUpda
|
||||
|
||||
higressConfigController := configmap.NewController(localKubeClient, clusterId, namespace)
|
||||
config.configmapMgr = configmap.NewConfigmapMgr(xdsUpdater, namespace, higressConfigController, higressConfigController.Lister())
|
||||
config.configmapMgr.RegisterMcpServerProvider(&config.mcpServerCache)
|
||||
|
||||
httpsConfigMgr, _ := cert.NewConfigMgr(namespace, localKubeClient.Kube())
|
||||
config.httpsConfigMgr = httpsConfigMgr
|
||||
@@ -421,6 +425,10 @@ func (m *IngressConfig) createWrapperConfigs(configs []config.Config) []common.W
|
||||
m.watchedSecretSet = globalContext.WatchedSecrets
|
||||
m.mutex.Unlock()
|
||||
|
||||
if m.mcpServerCache.SetMcpServers(globalContext.McpServers) {
|
||||
m.notifyXDSFullUpdate(mcpserver.GvkMcpServer, "mcp-server-annotation-change", nil)
|
||||
}
|
||||
|
||||
return wrapperConfigs
|
||||
}
|
||||
|
||||
@@ -590,7 +598,7 @@ func (m *IngressConfig) convertVirtualService(configs []common.WrapperConfig) []
|
||||
Spec: vs,
|
||||
})
|
||||
}
|
||||
// add vs from naco3 for mcp server
|
||||
// add vs from nacos3 for mcp server
|
||||
if m.RegistryReconciler != nil {
|
||||
allConfigsFromMcp := m.RegistryReconciler.GetAllConfigs(gvk.VirtualService)
|
||||
for _, cfg := range allConfigsFromMcp {
|
||||
@@ -794,23 +802,38 @@ func (m *IngressConfig) convertDestinationRule(configs []common.WrapperConfig) [
|
||||
if !exist {
|
||||
destinationRules[serviceName] = destinationRuleWrapper
|
||||
} else if dr.DestinationRule.TrafficPolicy != nil {
|
||||
if dr.DestinationRule.TrafficPolicy.LoadBalancer == nil &&
|
||||
destinationRuleWrapper.DestinationRule.TrafficPolicy.LoadBalancer != nil {
|
||||
dr.DestinationRule.TrafficPolicy.LoadBalancer = destinationRuleWrapper.DestinationRule.TrafficPolicy.LoadBalancer
|
||||
}
|
||||
portTrafficPolicy := destinationRuleWrapper.DestinationRule.TrafficPolicy.PortLevelSettings[0]
|
||||
portUpdated := false
|
||||
for _, policy := range dr.DestinationRule.TrafficPolicy.PortLevelSettings {
|
||||
if policy.Port.Number == portTrafficPolicy.Port.Number {
|
||||
policy.Tls = portTrafficPolicy.Tls
|
||||
portUpdated = true
|
||||
break
|
||||
// if the service is referenced by an sse type mcp server, an source ip based consistent hashing policy needs to be configured
|
||||
// consistent hashing policy will be generated by mcp server watcher, then if service do not have LoadBalancer settings, it will be merged
|
||||
if destinationRuleWrapper.DestinationRule.TrafficPolicy != nil && destinationRuleWrapper.DestinationRule.TrafficPolicy.LoadBalancer != nil {
|
||||
if dr.DestinationRule.TrafficPolicy.LoadBalancer == nil {
|
||||
dr.DestinationRule.TrafficPolicy.LoadBalancer = destinationRuleWrapper.DestinationRule.TrafficPolicy.LoadBalancer
|
||||
} else if dr.DestinationRule.TrafficPolicy.LoadBalancer.LbPolicy == nil {
|
||||
dr.DestinationRule.TrafficPolicy.LoadBalancer.LbPolicy = destinationRuleWrapper.DestinationRule.TrafficPolicy.LoadBalancer.LbPolicy
|
||||
}
|
||||
}
|
||||
if portUpdated {
|
||||
continue
|
||||
// if the service is referenced by an https type mcp server, an client side simple mode tls policy needs to be configured
|
||||
// simple mode tls policy will be generated by mcp server watcher, then if service do not have tls settings, it will be merged
|
||||
if dr.DestinationRule.TrafficPolicy.Tls == nil && destinationRuleWrapper.DestinationRule.TrafficPolicy != nil &&
|
||||
destinationRuleWrapper.DestinationRule.TrafficPolicy.Tls != nil {
|
||||
dr.DestinationRule.TrafficPolicy.Tls = destinationRuleWrapper.DestinationRule.TrafficPolicy.Tls
|
||||
}
|
||||
// Directly inherit or override the port policy (if it exists)
|
||||
if len(destinationRuleWrapper.DestinationRule.TrafficPolicy.PortLevelSettings) > 0 {
|
||||
portTrafficPolicy := destinationRuleWrapper.DestinationRule.TrafficPolicy.PortLevelSettings[0]
|
||||
portUpdated := false
|
||||
for _, policy := range dr.DestinationRule.TrafficPolicy.PortLevelSettings {
|
||||
if policy.Port.Number == portTrafficPolicy.Port.Number {
|
||||
policy.Tls = portTrafficPolicy.Tls
|
||||
policy.LoadBalancer = portTrafficPolicy.LoadBalancer
|
||||
portUpdated = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if portUpdated {
|
||||
continue
|
||||
}
|
||||
dr.DestinationRule.TrafficPolicy.PortLevelSettings = append(dr.DestinationRule.TrafficPolicy.PortLevelSettings, portTrafficPolicy)
|
||||
}
|
||||
dr.DestinationRule.TrafficPolicy.PortLevelSettings = append(dr.DestinationRule.TrafficPolicy.PortLevelSettings, portTrafficPolicy)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1208,9 +1231,9 @@ func (m *IngressConfig) AddOrUpdateMcpBridge(clusterNamespacedName util.ClusterN
|
||||
f(config.Config{Meta: efMetadata}, config.Config{Meta: efMetadata}, istiomodel.EventUpdate)
|
||||
}
|
||||
}, m.localKubeClient, m.namespace, m.clusterId.String())
|
||||
m.configmapMgr.RegisterMcpServerProvider(m.RegistryReconciler)
|
||||
}
|
||||
reconciler := m.RegistryReconciler
|
||||
m.configmapMgr.SetMcpReconciler(m.RegistryReconciler)
|
||||
err = reconciler.Reconcile(mcpbridge)
|
||||
if err != nil {
|
||||
IngressLog.Errorf("Mcpbridge reconcile failed, err:%v", err)
|
||||
@@ -1776,3 +1799,19 @@ func (m *IngressConfig) Patch(config.Config, config.PatchFunc) (string, error) {
|
||||
func (m *IngressConfig) Delete(config.GroupVersionKind, string, string, *string) error {
|
||||
return common.ErrUnsupportedOp
|
||||
}
|
||||
|
||||
func (m *IngressConfig) notifyXDSFullUpdate(gvk config.GroupVersionKind, reason istiomodel.TriggerReason, updatedConfigName *util.ClusterNamespacedName) {
|
||||
var configsUpdated map[istiomodel.ConfigKey]struct{}
|
||||
if updatedConfigName != nil {
|
||||
configsUpdated = map[istiomodel.ConfigKey]struct{}{{
|
||||
Kind: kind.MustFromGVK(gvk),
|
||||
Name: updatedConfigName.Name,
|
||||
Namespace: updatedConfigName.Namespace,
|
||||
}: {}}
|
||||
}
|
||||
m.XDSUpdater.ConfigUpdate(&istiomodel.PushRequest{
|
||||
Full: true,
|
||||
ConfigsUpdated: configsUpdated,
|
||||
Reason: istiomodel.NewReasonStats(reason),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -21,6 +21,8 @@ import (
|
||||
"istio.io/istio/pkg/cluster"
|
||||
"istio.io/istio/pkg/util/sets"
|
||||
listersv1 "k8s.io/client-go/listers/core/v1"
|
||||
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/mcpserver"
|
||||
)
|
||||
|
||||
type GlobalContext struct {
|
||||
@@ -30,6 +32,8 @@ type GlobalContext struct {
|
||||
ClusterSecretLister map[cluster.ID]listersv1.SecretLister
|
||||
|
||||
ClusterServiceList map[cluster.ID]listersv1.ServiceLister
|
||||
|
||||
McpServers []*mcpserver.McpServer
|
||||
}
|
||||
|
||||
type Meta struct {
|
||||
@@ -169,6 +173,7 @@ func NewAnnotationHandlerManager() AnnotationHandler {
|
||||
match{},
|
||||
headerControl{},
|
||||
http2rpc{},
|
||||
mcpServer{},
|
||||
},
|
||||
gatewayHandlers: []GatewayHandler{
|
||||
downstreamTLS{},
|
||||
|
||||
94
pkg/ingress/kube/annotations/mcpserver.go
Normal file
94
pkg/ingress/kube/annotations/mcpserver.go
Normal file
@@ -0,0 +1,94 @@
|
||||
// Copyright (c) 2023 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package annotations
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/mcpserver"
|
||||
"github.com/alibaba/higress/pkg/ingress/log"
|
||||
)
|
||||
|
||||
const (
|
||||
enableMcpServer = "mcp-server"
|
||||
mcpServerMatchRuleDomains = "mcp-server-match-rule-domains"
|
||||
mcpServerMatchRuleType = "mcp-server-match-rule-type"
|
||||
mcpServerMatchRuleValue = "mcp-server-match-rule-value"
|
||||
mcpServerUpstreamType = "mcp-server-upstream-type"
|
||||
mcpServerEnablePathRewrite = "mcp-server-enable-path-rewrite"
|
||||
mcpServerPathRewritePrefix = "mcp-server-path-rewrite-prefix"
|
||||
)
|
||||
|
||||
// help to conform mcpServer implements method of Parse
|
||||
var (
|
||||
_ Parser = &mcpServer{}
|
||||
)
|
||||
|
||||
type mcpServer struct{}
|
||||
|
||||
func (a mcpServer) Parse(annotations Annotations, config *Ingress, globalContext *GlobalContext) error {
|
||||
if globalContext == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ingressKey := config.Namespace + "/" + config.Name
|
||||
|
||||
enabled, _ := annotations.ParseBoolASAP(enableMcpServer)
|
||||
if !enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
var matchRuleDomains []string
|
||||
rawMatchRuleDomains, _ := annotations.ParseStringASAP(mcpServerMatchRuleDomains)
|
||||
if rawMatchRuleDomains == "" || rawMatchRuleDomains == "*" {
|
||||
// Match all domains. Leave an empty slice.
|
||||
} else if strings.Contains(rawMatchRuleDomains, ",") {
|
||||
matchRuleDomains = strings.Split(rawMatchRuleDomains, ",")
|
||||
} else {
|
||||
matchRuleDomains = []string{rawMatchRuleDomains}
|
||||
}
|
||||
|
||||
matchRuleType, _ := annotations.ParseStringASAP(mcpServerMatchRuleType)
|
||||
if matchRuleType == "" {
|
||||
log.IngressLog.Errorf("ingress %s: mcp-server-match-rule-path-type is empty", ingressKey)
|
||||
return nil
|
||||
} else if !mcpserver.ValidPathMatchTypes[matchRuleType] {
|
||||
log.IngressLog.Errorf("ingress %s: mcp-server-match-rule-path-type %s is not supported", ingressKey, matchRuleType)
|
||||
return nil
|
||||
}
|
||||
|
||||
matchRuleValue, _ := annotations.ParseStringASAP(mcpServerMatchRuleValue)
|
||||
|
||||
upstreamType, _ := annotations.ParseStringASAP(mcpServerUpstreamType)
|
||||
if upstreamType != "" && !mcpserver.ValidUpstreamTypes[upstreamType] {
|
||||
log.IngressLog.Errorf("mcp-server-upstream-type %s is not supported", upstreamType)
|
||||
return nil
|
||||
}
|
||||
|
||||
enablePathRewrite, _ := annotations.ParseBoolASAP(mcpServerEnablePathRewrite)
|
||||
pathRewritePrefix, _ := annotations.ParseStringASAP(mcpServerPathRewritePrefix)
|
||||
|
||||
globalContext.McpServers = append(globalContext.McpServers, &mcpserver.McpServer{
|
||||
Name: ingressKey,
|
||||
Domains: matchRuleDomains,
|
||||
PathMatchType: matchRuleType,
|
||||
PathMatchValue: matchRuleValue,
|
||||
UpstreamType: upstreamType,
|
||||
EnablePathRewrite: enablePathRewrite,
|
||||
PathRewritePrefix: pathRewritePrefix,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
257
pkg/ingress/kube/annotations/mcpserver_test.go
Normal file
257
pkg/ingress/kube/annotations/mcpserver_test.go
Normal file
@@ -0,0 +1,257 @@
|
||||
// Copyright (c) 2025 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package annotations
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/mcpserver"
|
||||
)
|
||||
|
||||
func TestMCPServer_Parse(t *testing.T) {
|
||||
parser := mcpServer{}
|
||||
testCases := []struct {
|
||||
skip bool
|
||||
input Annotations
|
||||
expect *mcpserver.McpServer
|
||||
}{
|
||||
{
|
||||
// No annotation
|
||||
input: Annotations{},
|
||||
expect: nil,
|
||||
},
|
||||
{
|
||||
// Not enabled
|
||||
input: Annotations{
|
||||
buildHigressAnnotationKey(enableMcpServer): "false",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleDomains): "www.foo.com",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleType): "prefix",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleValue): "/mcp",
|
||||
buildHigressAnnotationKey(mcpServerUpstreamType): "rest",
|
||||
buildHigressAnnotationKey(mcpServerEnablePathRewrite): "",
|
||||
buildHigressAnnotationKey(mcpServerPathRewritePrefix): "",
|
||||
},
|
||||
expect: nil,
|
||||
},
|
||||
{
|
||||
// Enabled but no match rule type
|
||||
input: Annotations{
|
||||
buildHigressAnnotationKey(enableMcpServer): "true",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleDomains): "www.foo.com",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleValue): "/mcp",
|
||||
buildHigressAnnotationKey(mcpServerUpstreamType): "rest",
|
||||
buildHigressAnnotationKey(mcpServerEnablePathRewrite): "",
|
||||
buildHigressAnnotationKey(mcpServerPathRewritePrefix): "",
|
||||
},
|
||||
expect: nil,
|
||||
},
|
||||
{
|
||||
// Enabled but empty match rule type
|
||||
input: Annotations{
|
||||
buildHigressAnnotationKey(enableMcpServer): "true",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleDomains): "www.foo.com",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleType): "",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleValue): "/mcp",
|
||||
buildHigressAnnotationKey(mcpServerUpstreamType): "rest",
|
||||
buildHigressAnnotationKey(mcpServerEnablePathRewrite): "",
|
||||
buildHigressAnnotationKey(mcpServerPathRewritePrefix): "",
|
||||
},
|
||||
expect: nil,
|
||||
},
|
||||
{
|
||||
// Enabled but bad match rule type
|
||||
input: Annotations{
|
||||
buildHigressAnnotationKey(enableMcpServer): "true",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleDomains): "www.foo.com",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleType): "bad-type",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleValue): "/mcp",
|
||||
buildHigressAnnotationKey(mcpServerUpstreamType): "rest",
|
||||
buildHigressAnnotationKey(mcpServerEnablePathRewrite): "",
|
||||
buildHigressAnnotationKey(mcpServerPathRewritePrefix): "",
|
||||
},
|
||||
expect: nil,
|
||||
},
|
||||
{
|
||||
// Enabled but bad upstream type
|
||||
input: Annotations{
|
||||
buildHigressAnnotationKey(enableMcpServer): "true",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleDomains): "www.foo.com",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleType): "prefix",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleValue): "/mcp",
|
||||
buildHigressAnnotationKey(mcpServerUpstreamType): "bad-type",
|
||||
buildHigressAnnotationKey(mcpServerEnablePathRewrite): "",
|
||||
buildHigressAnnotationKey(mcpServerPathRewritePrefix): "",
|
||||
},
|
||||
expect: nil,
|
||||
},
|
||||
{
|
||||
// Enabled and rewrite not enabled
|
||||
input: Annotations{
|
||||
buildHigressAnnotationKey(enableMcpServer): "true",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleDomains): "www.foo.com",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleType): "prefix",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleValue): "/mcp",
|
||||
buildHigressAnnotationKey(mcpServerUpstreamType): "rest",
|
||||
buildHigressAnnotationKey(mcpServerEnablePathRewrite): "false",
|
||||
buildHigressAnnotationKey(mcpServerPathRewritePrefix): "/",
|
||||
},
|
||||
expect: &mcpserver.McpServer{
|
||||
Name: "default/route",
|
||||
Domains: []string{"www.foo.com"},
|
||||
PathMatchType: "prefix",
|
||||
PathMatchValue: "/mcp",
|
||||
UpstreamType: "rest",
|
||||
EnablePathRewrite: false,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
},
|
||||
{
|
||||
// Enabled and rewrite not enabled and empty domain
|
||||
input: Annotations{
|
||||
buildHigressAnnotationKey(enableMcpServer): "true",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleDomains): "",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleType): "prefix",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleValue): "/mcp",
|
||||
buildHigressAnnotationKey(mcpServerUpstreamType): "rest",
|
||||
buildHigressAnnotationKey(mcpServerEnablePathRewrite): "false",
|
||||
buildHigressAnnotationKey(mcpServerPathRewritePrefix): "/",
|
||||
},
|
||||
expect: &mcpserver.McpServer{
|
||||
Name: "default/route",
|
||||
Domains: nil,
|
||||
PathMatchType: "prefix",
|
||||
PathMatchValue: "/mcp",
|
||||
UpstreamType: "rest",
|
||||
EnablePathRewrite: false,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
},
|
||||
{
|
||||
// Enabled and rewrite not enabled and wildcard domain
|
||||
input: Annotations{
|
||||
buildHigressAnnotationKey(enableMcpServer): "true",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleDomains): "*",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleType): "prefix",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleValue): "/mcp",
|
||||
buildHigressAnnotationKey(mcpServerUpstreamType): "rest",
|
||||
buildHigressAnnotationKey(mcpServerEnablePathRewrite): "false",
|
||||
buildHigressAnnotationKey(mcpServerPathRewritePrefix): "/",
|
||||
},
|
||||
expect: &mcpserver.McpServer{
|
||||
Name: "default/route",
|
||||
Domains: nil,
|
||||
PathMatchType: "prefix",
|
||||
PathMatchValue: "/mcp",
|
||||
UpstreamType: "rest",
|
||||
EnablePathRewrite: false,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
},
|
||||
{
|
||||
// Enabled and rewrite enabled with root
|
||||
input: Annotations{
|
||||
buildHigressAnnotationKey(enableMcpServer): "true",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleDomains): "www.foo.com",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleType): "prefix",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleValue): "/mcp",
|
||||
buildHigressAnnotationKey(mcpServerUpstreamType): "rest",
|
||||
buildHigressAnnotationKey(mcpServerEnablePathRewrite): "true",
|
||||
buildHigressAnnotationKey(mcpServerPathRewritePrefix): "/",
|
||||
},
|
||||
expect: &mcpserver.McpServer{
|
||||
Name: "default/route",
|
||||
Domains: []string{"www.foo.com"},
|
||||
PathMatchType: "prefix",
|
||||
PathMatchValue: "/mcp",
|
||||
UpstreamType: "rest",
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
},
|
||||
{
|
||||
// Enabled and rewrite enabled with root
|
||||
input: Annotations{
|
||||
buildHigressAnnotationKey(enableMcpServer): "true",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleDomains): "www.foo.com",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleType): "prefix",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleValue): "/mcp",
|
||||
buildHigressAnnotationKey(mcpServerUpstreamType): "rest",
|
||||
buildHigressAnnotationKey(mcpServerEnablePathRewrite): "true",
|
||||
buildHigressAnnotationKey(mcpServerPathRewritePrefix): "/mcp-api",
|
||||
},
|
||||
expect: &mcpserver.McpServer{
|
||||
Name: "default/route",
|
||||
Domains: []string{"www.foo.com"},
|
||||
PathMatchType: "prefix",
|
||||
PathMatchValue: "/mcp",
|
||||
UpstreamType: "rest",
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/mcp-api",
|
||||
},
|
||||
},
|
||||
{
|
||||
// Enabled and multiple domains
|
||||
input: Annotations{
|
||||
buildHigressAnnotationKey(enableMcpServer): "true",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleDomains): "www.foo.com,www.bar.com",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleType): "exact",
|
||||
buildHigressAnnotationKey(mcpServerMatchRuleValue): "/mcp",
|
||||
buildHigressAnnotationKey(mcpServerUpstreamType): "sse",
|
||||
buildHigressAnnotationKey(mcpServerEnablePathRewrite): "true",
|
||||
buildHigressAnnotationKey(mcpServerPathRewritePrefix): "/",
|
||||
},
|
||||
expect: &mcpserver.McpServer{
|
||||
Name: "default/route",
|
||||
Domains: []string{"www.foo.com", "www.bar.com"},
|
||||
PathMatchType: "exact",
|
||||
PathMatchValue: "/mcp",
|
||||
UpstreamType: "sse",
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range testCases {
|
||||
if tt.skip {
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("", func(t *testing.T) {
|
||||
config := &Ingress{Meta: Meta{
|
||||
Namespace: "default",
|
||||
Name: "route",
|
||||
}}
|
||||
globalContext := &GlobalContext{}
|
||||
_ = parser.Parse(tt.input, config, globalContext)
|
||||
if tt.expect == nil {
|
||||
if len(globalContext.McpServers) != 0 {
|
||||
t.Fatalf("globalContext.McpServers is not empty: %v", globalContext.McpServers)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if len(globalContext.McpServers) != 1 {
|
||||
t.Fatalf("globalContext.McpServers length is not 1: %v", globalContext.McpServers)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expect, globalContext.McpServers[0]); diff != "" {
|
||||
t.Fatalf("TestMCPServer_Parse() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"reflect"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/alibaba/higress/registry/reconcile"
|
||||
"istio.io/istio/pilot/pkg/model"
|
||||
"istio.io/istio/pkg/cluster"
|
||||
"istio.io/istio/pkg/config"
|
||||
@@ -33,6 +32,7 @@ import (
|
||||
"sigs.k8s.io/yaml"
|
||||
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/controller"
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/mcpserver"
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/util"
|
||||
. "github.com/alibaba/higress/pkg/ingress/log"
|
||||
)
|
||||
@@ -59,7 +59,6 @@ type ItemController interface {
|
||||
ValidHigressConfig(higressConfig *HigressConfig) error
|
||||
ConstructEnvoyFilters() ([]*config.Config, error)
|
||||
RegisterItemEventHandler(eventHandler ItemEventHandler)
|
||||
RegisterMcpReconciler(reconciler *reconcile.Reconciler)
|
||||
}
|
||||
|
||||
type ConfigmapMgr struct {
|
||||
@@ -113,9 +112,11 @@ func (c *ConfigmapMgr) GetHigressConfig() *HigressConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ConfigmapMgr) SetMcpReconciler(reconciler *reconcile.Reconciler) {
|
||||
func (c *ConfigmapMgr) RegisterMcpServerProvider(provider mcpserver.McpServerProvider) {
|
||||
for _, itemController := range c.ItemControllers {
|
||||
itemController.RegisterMcpReconciler(reconciler)
|
||||
if mcpRouteProviderAware, ok := itemController.(mcpserver.McpRouteProviderAware); ok {
|
||||
mcpRouteProviderAware.RegisterMcpServerProvider(provider)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/util"
|
||||
. "github.com/alibaba/higress/pkg/ingress/log"
|
||||
"github.com/alibaba/higress/registry/reconcile"
|
||||
networking "istio.io/api/networking/v1alpha3"
|
||||
"istio.io/istio/pkg/config"
|
||||
"istio.io/istio/pkg/config/schema/gvk"
|
||||
@@ -377,9 +376,6 @@ func (g *GlobalOptionController) RegisterItemEventHandler(eventHandler ItemEvent
|
||||
g.eventHandler = eventHandler
|
||||
}
|
||||
|
||||
func (g *GlobalOptionController) RegisterMcpReconciler(reconciler *reconcile.Reconciler) {
|
||||
}
|
||||
|
||||
// generateDownstreamEnvoyFilter generates the downstream envoy filter.
|
||||
func (g *GlobalOptionController) generateDownstreamEnvoyFilter(downstreamValueStruct string, bufferLimitStruct string, routeTimeoutStruct string, namespace string) []*networking.EnvoyFilter_EnvoyConfigObjectPatch {
|
||||
var downstreamConfig []*networking.EnvoyFilter_EnvoyConfigObjectPatch
|
||||
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/util"
|
||||
. "github.com/alibaba/higress/pkg/ingress/log"
|
||||
"github.com/alibaba/higress/registry/reconcile"
|
||||
networking "istio.io/api/networking/v1alpha3"
|
||||
"istio.io/istio/pkg/config"
|
||||
"istio.io/istio/pkg/config/schema/gvk"
|
||||
@@ -292,9 +291,6 @@ func (g *GzipController) RegisterItemEventHandler(eventHandler ItemEventHandler)
|
||||
g.eventHandler = eventHandler
|
||||
}
|
||||
|
||||
func (g *GzipController) RegisterMcpReconciler(reconciler *reconcile.Reconciler) {
|
||||
}
|
||||
|
||||
func (g *GzipController) constructGzipStruct(gzip *Gzip, namespace string) string {
|
||||
gzipConfig := ""
|
||||
contentType := ""
|
||||
|
||||
@@ -22,12 +22,13 @@ import (
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/util"
|
||||
. "github.com/alibaba/higress/pkg/ingress/log"
|
||||
"github.com/alibaba/higress/registry/reconcile"
|
||||
networking "istio.io/api/networking/v1alpha3"
|
||||
"istio.io/istio/pkg/config"
|
||||
"istio.io/istio/pkg/config/schema/gvk"
|
||||
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/mcpserver"
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/util"
|
||||
. "github.com/alibaba/higress/pkg/ingress/log"
|
||||
)
|
||||
|
||||
// RedisConfig defines the configuration for Redis connection
|
||||
@@ -232,18 +233,19 @@ func deepCopyMcpServer(mcp *McpServer) (*McpServer, error) {
|
||||
}
|
||||
|
||||
type McpServerController struct {
|
||||
Namespace string
|
||||
mcpServer atomic.Value
|
||||
Name string
|
||||
eventHandler ItemEventHandler
|
||||
reconciler *reconcile.Reconciler
|
||||
Namespace string
|
||||
mcpServer atomic.Value
|
||||
Name string
|
||||
eventHandler ItemEventHandler
|
||||
mcpServerProviders map[mcpserver.McpServerProvider]bool
|
||||
}
|
||||
|
||||
func NewMcpServerController(namespace string) *McpServerController {
|
||||
mcpController := &McpServerController{
|
||||
Namespace: namespace,
|
||||
mcpServer: atomic.Value{},
|
||||
Name: "mcpServer",
|
||||
Namespace: namespace,
|
||||
Name: "mcpServer",
|
||||
mcpServer: atomic.Value{},
|
||||
mcpServerProviders: make(map[mcpserver.McpServerProvider]bool),
|
||||
}
|
||||
mcpController.SetMcpServer(NewDefaultMcpServer())
|
||||
return mcpController
|
||||
@@ -310,8 +312,11 @@ func (m *McpServerController) RegisterItemEventHandler(eventHandler ItemEventHan
|
||||
m.eventHandler = eventHandler
|
||||
}
|
||||
|
||||
func (m *McpServerController) RegisterMcpReconciler(reconciler *reconcile.Reconciler) {
|
||||
m.reconciler = reconciler
|
||||
func (m *McpServerController) RegisterMcpServerProvider(provider mcpserver.McpServerProvider) {
|
||||
if m.mcpServerProviders == nil {
|
||||
m.mcpServerProviders = make(map[mcpserver.McpServerProvider]bool)
|
||||
}
|
||||
m.mcpServerProviders[provider] = true
|
||||
}
|
||||
|
||||
func (m *McpServerController) ConstructEnvoyFilters() ([]*config.Config, error) {
|
||||
@@ -406,10 +411,36 @@ func (m *McpServerController) ConstructEnvoyFilters() ([]*config.Config, error)
|
||||
|
||||
func (m *McpServerController) constructMcpSessionStruct(mcp *McpServer) string {
|
||||
// Build match_list configuration
|
||||
matchList := "[]"
|
||||
var matchConfigs []string
|
||||
if len(mcp.MatchList) > 0 {
|
||||
for _, rule := range mcp.MatchList {
|
||||
var matchList []*MatchRule
|
||||
matchList = append(matchList, mcp.MatchList...)
|
||||
for provider, _ := range m.mcpServerProviders {
|
||||
servers := provider.GetMcpServers()
|
||||
if len(servers) == 0 {
|
||||
continue
|
||||
}
|
||||
for _, server := range servers {
|
||||
matchRuleDomain := ""
|
||||
if len(server.Domains) != 0 {
|
||||
if len(server.Domains) > 1 {
|
||||
matchRuleDomain = fmt.Sprintf("(%s)", strings.Join(server.Domains, "|"))
|
||||
} else {
|
||||
matchRuleDomain = server.Domains[0]
|
||||
}
|
||||
}
|
||||
matchList = append(matchList, &MatchRule{
|
||||
MatchRuleDomain: matchRuleDomain,
|
||||
MatchRuleType: server.PathMatchType,
|
||||
MatchRulePath: server.PathMatchValue,
|
||||
UpstreamType: server.UpstreamType,
|
||||
EnablePathRewrite: server.EnablePathRewrite,
|
||||
PathRewritePrefix: server.PathRewritePrefix,
|
||||
})
|
||||
}
|
||||
}
|
||||
matchListConfig := "[]"
|
||||
if len(matchList) > 0 {
|
||||
matchConfigs := make([]string, 0, len(matchList))
|
||||
for _, rule := range matchList {
|
||||
matchConfigs = append(matchConfigs, fmt.Sprintf(`{
|
||||
"match_rule_domain": "%s",
|
||||
"match_rule_path": "%s",
|
||||
@@ -419,28 +450,9 @@ func (m *McpServerController) constructMcpSessionStruct(mcp *McpServer) string {
|
||||
"path_rewrite_prefix": "%s"
|
||||
}`, rule.MatchRuleDomain, rule.MatchRulePath, rule.MatchRuleType, rule.UpstreamType, rule.EnablePathRewrite, rule.PathRewritePrefix))
|
||||
}
|
||||
matchListConfig = fmt.Sprintf("[%s]", strings.Join(matchConfigs, ","))
|
||||
}
|
||||
|
||||
if m.reconciler != nil {
|
||||
vsFromMcp := m.reconciler.GetAllConfigs(gvk.VirtualService)
|
||||
for _, c := range vsFromMcp {
|
||||
vs := c.Spec.(*networking.VirtualService)
|
||||
var host string
|
||||
if len(vs.Hosts) > 1 {
|
||||
host = fmt.Sprintf("(%s)", strings.Join(vs.Hosts, "|"))
|
||||
} else {
|
||||
host = vs.Hosts[0]
|
||||
}
|
||||
path := vs.Http[0].Match[0].Uri.GetPrefix()
|
||||
matchConfigs = append(matchConfigs, fmt.Sprintf(`{
|
||||
"match_rule_domain": "%s",
|
||||
"match_rule_path": "%s",
|
||||
"match_rule_type": "prefix"
|
||||
}`, host, path))
|
||||
}
|
||||
}
|
||||
matchList = fmt.Sprintf("[%s]", strings.Join(matchConfigs, ","))
|
||||
|
||||
// Build redis configuration
|
||||
redisConfig := "null"
|
||||
if mcp.Redis != nil {
|
||||
@@ -492,7 +504,7 @@ func (m *McpServerController) constructMcpSessionStruct(mcp *McpServer) string {
|
||||
redisConfig,
|
||||
rateLimitConfig,
|
||||
mcp.SSEPathSuffix,
|
||||
matchList,
|
||||
matchListConfig,
|
||||
mcp.EnableUserLevelServer)
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"reflect"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/alibaba/higress/registry/reconcile"
|
||||
"istio.io/istio/pkg/config"
|
||||
"istio.io/istio/pkg/config/schema/gvk"
|
||||
|
||||
@@ -238,9 +237,6 @@ func (t *TracingController) RegisterItemEventHandler(eventHandler ItemEventHandl
|
||||
t.eventHandler = eventHandler
|
||||
}
|
||||
|
||||
func (t *TracingController) RegisterMcpReconciler(reconciler *reconcile.Reconciler) {
|
||||
}
|
||||
|
||||
func (t *TracingController) ConstructEnvoyFilters() ([]*config.Config, error) {
|
||||
configs := make([]*config.Config, 0)
|
||||
tracing := t.GetTracing()
|
||||
|
||||
60
pkg/ingress/kube/mcpserver/model.go
Normal file
60
pkg/ingress/kube/mcpserver/model.go
Normal file
@@ -0,0 +1,60 @@
|
||||
// Copyright (c) 2025 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mcpserver
|
||||
|
||||
import (
|
||||
"istio.io/istio/pkg/config"
|
||||
)
|
||||
|
||||
var (
|
||||
GvkMcpServer = config.GroupVersionKind{Group: "networking.higress.io", Version: "v1alpha1", Kind: "McpServer"}
|
||||
)
|
||||
|
||||
const (
|
||||
UpstreamTypeRest string = "rest"
|
||||
UpstreamTypeSSE string = "sse"
|
||||
UpstreamTypeStreamable string = "streamable"
|
||||
|
||||
ExactMatchType string = "exact"
|
||||
PrefixMatchType string = "prefix"
|
||||
SuffixMatchType string = "suffix"
|
||||
ContainsMatchType string = "contains"
|
||||
RegexMatchType string = "regex"
|
||||
)
|
||||
|
||||
var (
|
||||
ValidUpstreamTypes = map[string]bool{
|
||||
UpstreamTypeRest: true,
|
||||
UpstreamTypeSSE: true,
|
||||
UpstreamTypeStreamable: true,
|
||||
}
|
||||
ValidPathMatchTypes = map[string]bool{
|
||||
ExactMatchType: true,
|
||||
PrefixMatchType: true,
|
||||
SuffixMatchType: true,
|
||||
ContainsMatchType: true,
|
||||
RegexMatchType: true,
|
||||
}
|
||||
)
|
||||
|
||||
type McpServer struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Domains []string `json:"domains,omitempty"`
|
||||
PathMatchType string `json:"path_match_type,omitempty"`
|
||||
PathMatchValue string `json:"path_match_value,omitempty"`
|
||||
UpstreamType string `json:"upstream_type,omitempty"`
|
||||
EnablePathRewrite bool `json:"enable_path_rewrite,omitempty"`
|
||||
PathRewritePrefix string `json:"path_rewrite_prefix,omitempty"`
|
||||
}
|
||||
70
pkg/ingress/kube/mcpserver/provider.go
Normal file
70
pkg/ingress/kube/mcpserver/provider.go
Normal file
@@ -0,0 +1,70 @@
|
||||
// Copyright (c) 2025 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mcpserver
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type McpServerProvider interface {
|
||||
GetMcpServers() []*McpServer
|
||||
}
|
||||
|
||||
type McpRouteProviderAware interface {
|
||||
RegisterMcpServerProvider(provider McpServerProvider)
|
||||
}
|
||||
|
||||
type McpServerCache struct {
|
||||
mcpServers []*McpServer
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
func (c *McpServerCache) GetMcpServers() []*McpServer {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
return c.mcpServers
|
||||
}
|
||||
|
||||
// SetMcpServers sets the mcp servers and returns true if the cached list is changed
|
||||
func (c *McpServerCache) SetMcpServers(mcpServers []*McpServer) bool {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
sortedMcpServers := make([]*McpServer, 0, len(mcpServers))
|
||||
sortedMcpServers = append(sortedMcpServers, mcpServers...)
|
||||
// Sort the mcp servers by PathMatchValue in descending order
|
||||
slices.SortFunc(sortedMcpServers, func(a, b *McpServer) int {
|
||||
return strings.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
if len(c.mcpServers) == len(sortedMcpServers) {
|
||||
changed := false
|
||||
for i := range c.mcpServers {
|
||||
if !reflect.DeepEqual(c.mcpServers[i], sortedMcpServers[i]) {
|
||||
changed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !changed {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
c.mcpServers = sortedMcpServers
|
||||
return true
|
||||
}
|
||||
654
pkg/ingress/kube/mcpserver/provider_test.go
Normal file
654
pkg/ingress/kube/mcpserver/provider_test.go
Normal file
@@ -0,0 +1,654 @@
|
||||
// Copyright (c) 2025 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mcpserver
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestMcpServerCache_GetSet(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
skip bool
|
||||
init []*McpServer
|
||||
input []*McpServer
|
||||
expect []*McpServer
|
||||
changed bool
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
init: nil,
|
||||
input: nil,
|
||||
changed: false,
|
||||
expect: nil,
|
||||
},
|
||||
{
|
||||
name: "nil to non-nil",
|
||||
init: nil,
|
||||
input: []*McpServer{
|
||||
{
|
||||
Name: "test2",
|
||||
Domains: []string{"www.foo.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test2",
|
||||
UpstreamType: UpstreamTypeSSE,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/test",
|
||||
},
|
||||
{
|
||||
Name: "test3",
|
||||
Domains: []string{"www.bar.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test3",
|
||||
UpstreamType: UpstreamTypeStreamable,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
{
|
||||
Name: "test1",
|
||||
Domains: nil,
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test1",
|
||||
UpstreamType: UpstreamTypeRest,
|
||||
EnablePathRewrite: false,
|
||||
PathRewritePrefix: "",
|
||||
},
|
||||
},
|
||||
changed: true,
|
||||
expect: []*McpServer{
|
||||
{
|
||||
Name: "test1",
|
||||
Domains: nil,
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test1",
|
||||
UpstreamType: UpstreamTypeRest,
|
||||
EnablePathRewrite: false,
|
||||
PathRewritePrefix: "",
|
||||
},
|
||||
{
|
||||
Name: "test2",
|
||||
Domains: []string{"www.foo.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test2",
|
||||
UpstreamType: UpstreamTypeSSE,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/test",
|
||||
},
|
||||
{
|
||||
Name: "test3",
|
||||
Domains: []string{"www.bar.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test3",
|
||||
UpstreamType: UpstreamTypeStreamable,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non-nil to non-nil (length increase)",
|
||||
init: []*McpServer{
|
||||
{
|
||||
Name: "test1",
|
||||
Domains: nil,
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test1",
|
||||
UpstreamType: UpstreamTypeRest,
|
||||
EnablePathRewrite: false,
|
||||
PathRewritePrefix: "",
|
||||
},
|
||||
{
|
||||
Name: "test2",
|
||||
Domains: []string{"www.foo.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test2",
|
||||
UpstreamType: UpstreamTypeSSE,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/test",
|
||||
},
|
||||
},
|
||||
input: []*McpServer{
|
||||
{
|
||||
Name: "test2",
|
||||
Domains: []string{"www.foo.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test2",
|
||||
UpstreamType: UpstreamTypeSSE,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/test",
|
||||
},
|
||||
{
|
||||
Name: "test3",
|
||||
Domains: []string{"www.bar.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test3",
|
||||
UpstreamType: UpstreamTypeStreamable,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
{
|
||||
Name: "test1",
|
||||
Domains: nil,
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test1",
|
||||
UpstreamType: UpstreamTypeRest,
|
||||
EnablePathRewrite: false,
|
||||
PathRewritePrefix: "",
|
||||
},
|
||||
},
|
||||
changed: true,
|
||||
expect: []*McpServer{
|
||||
{
|
||||
Name: "test1",
|
||||
Domains: nil,
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test1",
|
||||
UpstreamType: UpstreamTypeRest,
|
||||
EnablePathRewrite: false,
|
||||
PathRewritePrefix: "",
|
||||
},
|
||||
{
|
||||
Name: "test2",
|
||||
Domains: []string{"www.foo.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test2",
|
||||
UpstreamType: UpstreamTypeSSE,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/test",
|
||||
},
|
||||
{
|
||||
Name: "test3",
|
||||
Domains: []string{"www.bar.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test3",
|
||||
UpstreamType: UpstreamTypeStreamable,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non-nil to non-nil (length decrease)",
|
||||
init: []*McpServer{
|
||||
{
|
||||
Name: "test2",
|
||||
Domains: []string{"www.foo.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test2",
|
||||
UpstreamType: UpstreamTypeSSE,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/test",
|
||||
},
|
||||
{
|
||||
Name: "test3",
|
||||
Domains: []string{"www.bar.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test3",
|
||||
UpstreamType: UpstreamTypeStreamable,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
{
|
||||
Name: "test1",
|
||||
Domains: nil,
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test1",
|
||||
UpstreamType: UpstreamTypeRest,
|
||||
EnablePathRewrite: false,
|
||||
PathRewritePrefix: "",
|
||||
},
|
||||
},
|
||||
input: []*McpServer{
|
||||
{
|
||||
Name: "test3",
|
||||
Domains: []string{"www.bar.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test3",
|
||||
UpstreamType: UpstreamTypeStreamable,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
{
|
||||
Name: "test2",
|
||||
Domains: []string{"www.foo.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test2",
|
||||
UpstreamType: UpstreamTypeSSE,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/test",
|
||||
},
|
||||
},
|
||||
changed: true,
|
||||
expect: []*McpServer{
|
||||
{
|
||||
Name: "test2",
|
||||
Domains: []string{"www.foo.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test2",
|
||||
UpstreamType: UpstreamTypeSSE,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/test",
|
||||
},
|
||||
{
|
||||
Name: "test3",
|
||||
Domains: []string{"www.bar.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test3",
|
||||
UpstreamType: UpstreamTypeStreamable,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non-nil to non-nil (length unchanged + name field changed)",
|
||||
init: []*McpServer{
|
||||
{
|
||||
Name: "test2",
|
||||
Domains: []string{"www.foo.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test2",
|
||||
UpstreamType: UpstreamTypeSSE,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/test",
|
||||
},
|
||||
{
|
||||
Name: "test3",
|
||||
Domains: []string{"www.bar.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test3",
|
||||
UpstreamType: UpstreamTypeStreamable,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
{
|
||||
Name: "test1",
|
||||
Domains: nil,
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test1",
|
||||
UpstreamType: UpstreamTypeRest,
|
||||
EnablePathRewrite: false,
|
||||
PathRewritePrefix: "",
|
||||
},
|
||||
},
|
||||
input: []*McpServer{
|
||||
{
|
||||
Name: "test2",
|
||||
Domains: []string{"www.foo.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test2",
|
||||
UpstreamType: UpstreamTypeSSE,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/test",
|
||||
},
|
||||
{
|
||||
Name: "test3-1",
|
||||
Domains: []string{"www.bar.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test3",
|
||||
UpstreamType: UpstreamTypeStreamable,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
{
|
||||
Name: "test1",
|
||||
Domains: nil,
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test1",
|
||||
UpstreamType: UpstreamTypeRest,
|
||||
EnablePathRewrite: false,
|
||||
PathRewritePrefix: "",
|
||||
},
|
||||
},
|
||||
changed: true,
|
||||
expect: []*McpServer{
|
||||
{
|
||||
Name: "test1",
|
||||
Domains: nil,
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test1",
|
||||
UpstreamType: UpstreamTypeRest,
|
||||
EnablePathRewrite: false,
|
||||
PathRewritePrefix: "",
|
||||
},
|
||||
{
|
||||
Name: "test2",
|
||||
Domains: []string{"www.foo.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test2",
|
||||
UpstreamType: UpstreamTypeSSE,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/test",
|
||||
},
|
||||
{
|
||||
Name: "test3-1",
|
||||
Domains: []string{"www.bar.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test3",
|
||||
UpstreamType: UpstreamTypeStreamable,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non-nil to non-nil (length unchanged + non-name field changed)",
|
||||
init: []*McpServer{
|
||||
{
|
||||
Name: "test2",
|
||||
Domains: []string{"www.foo.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test2",
|
||||
UpstreamType: UpstreamTypeSSE,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/test",
|
||||
},
|
||||
{
|
||||
Name: "test3",
|
||||
Domains: []string{"www.bar.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test3",
|
||||
UpstreamType: UpstreamTypeStreamable,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
{
|
||||
Name: "test1",
|
||||
Domains: nil,
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test1",
|
||||
UpstreamType: UpstreamTypeRest,
|
||||
EnablePathRewrite: false,
|
||||
PathRewritePrefix: "",
|
||||
},
|
||||
},
|
||||
input: []*McpServer{
|
||||
{
|
||||
Name: "test2",
|
||||
Domains: []string{"www.foo.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test2",
|
||||
UpstreamType: UpstreamTypeSSE,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/test",
|
||||
},
|
||||
{
|
||||
Name: "test3",
|
||||
Domains: []string{"www.bar-2.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test4",
|
||||
UpstreamType: UpstreamTypeStreamable,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
{
|
||||
Name: "test1",
|
||||
Domains: nil,
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test1",
|
||||
UpstreamType: UpstreamTypeRest,
|
||||
EnablePathRewrite: false,
|
||||
PathRewritePrefix: "",
|
||||
},
|
||||
},
|
||||
changed: true,
|
||||
expect: []*McpServer{
|
||||
{
|
||||
Name: "test1",
|
||||
Domains: nil,
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test1",
|
||||
UpstreamType: UpstreamTypeRest,
|
||||
EnablePathRewrite: false,
|
||||
PathRewritePrefix: "",
|
||||
},
|
||||
{
|
||||
Name: "test2",
|
||||
Domains: []string{"www.foo.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test2",
|
||||
UpstreamType: UpstreamTypeSSE,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/test",
|
||||
},
|
||||
{
|
||||
Name: "test3",
|
||||
Domains: []string{"www.bar-2.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test4",
|
||||
UpstreamType: UpstreamTypeStreamable,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non-nil to non-nil (content unchanged + order unchanged)",
|
||||
init: []*McpServer{
|
||||
{
|
||||
Name: "test2",
|
||||
Domains: []string{"www.foo.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test2",
|
||||
UpstreamType: UpstreamTypeSSE,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/test",
|
||||
},
|
||||
{
|
||||
Name: "test3",
|
||||
Domains: []string{"www.bar.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test3",
|
||||
UpstreamType: UpstreamTypeStreamable,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
{
|
||||
Name: "test1",
|
||||
Domains: nil,
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test1",
|
||||
UpstreamType: UpstreamTypeRest,
|
||||
EnablePathRewrite: false,
|
||||
PathRewritePrefix: "",
|
||||
},
|
||||
},
|
||||
input: []*McpServer{
|
||||
{
|
||||
Name: "test2",
|
||||
Domains: []string{"www.foo.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test2",
|
||||
UpstreamType: UpstreamTypeSSE,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/test",
|
||||
},
|
||||
{
|
||||
Name: "test3",
|
||||
Domains: []string{"www.bar.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test3",
|
||||
UpstreamType: UpstreamTypeStreamable,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
{
|
||||
Name: "test1",
|
||||
Domains: nil,
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test1",
|
||||
UpstreamType: UpstreamTypeRest,
|
||||
EnablePathRewrite: false,
|
||||
PathRewritePrefix: "",
|
||||
},
|
||||
},
|
||||
changed: false,
|
||||
expect: []*McpServer{
|
||||
{
|
||||
Name: "test1",
|
||||
Domains: nil,
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test1",
|
||||
UpstreamType: UpstreamTypeRest,
|
||||
EnablePathRewrite: false,
|
||||
PathRewritePrefix: "",
|
||||
},
|
||||
{
|
||||
Name: "test2",
|
||||
Domains: []string{"www.foo.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test2",
|
||||
UpstreamType: UpstreamTypeSSE,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/test",
|
||||
},
|
||||
{
|
||||
Name: "test3",
|
||||
Domains: []string{"www.bar.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test3",
|
||||
UpstreamType: UpstreamTypeStreamable,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non-nil to non-nil (content unchanged + order changed)",
|
||||
init: []*McpServer{
|
||||
{
|
||||
Name: "test2",
|
||||
Domains: []string{"www.foo.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test2",
|
||||
UpstreamType: UpstreamTypeSSE,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/test",
|
||||
},
|
||||
{
|
||||
Name: "test3",
|
||||
Domains: []string{"www.bar.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test3",
|
||||
UpstreamType: UpstreamTypeStreamable,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
{
|
||||
Name: "test1",
|
||||
Domains: nil,
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test1",
|
||||
UpstreamType: UpstreamTypeRest,
|
||||
EnablePathRewrite: false,
|
||||
PathRewritePrefix: "",
|
||||
},
|
||||
},
|
||||
input: []*McpServer{
|
||||
{
|
||||
Name: "test3",
|
||||
Domains: []string{"www.bar.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test3",
|
||||
UpstreamType: UpstreamTypeStreamable,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
{
|
||||
Name: "test1",
|
||||
Domains: nil,
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test1",
|
||||
UpstreamType: UpstreamTypeRest,
|
||||
EnablePathRewrite: false,
|
||||
PathRewritePrefix: "",
|
||||
},
|
||||
{
|
||||
Name: "test2",
|
||||
Domains: []string{"www.foo.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test2",
|
||||
UpstreamType: UpstreamTypeSSE,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/test",
|
||||
},
|
||||
},
|
||||
changed: false,
|
||||
expect: []*McpServer{
|
||||
{
|
||||
Name: "test1",
|
||||
Domains: nil,
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test1",
|
||||
UpstreamType: UpstreamTypeRest,
|
||||
EnablePathRewrite: false,
|
||||
PathRewritePrefix: "",
|
||||
},
|
||||
{
|
||||
Name: "test2",
|
||||
Domains: []string{"www.foo.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test2",
|
||||
UpstreamType: UpstreamTypeSSE,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/test",
|
||||
},
|
||||
{
|
||||
Name: "test3",
|
||||
Domains: []string{"www.bar.com"},
|
||||
PathMatchType: ExactMatchType,
|
||||
PathMatchValue: "/mcp/test3",
|
||||
UpstreamType: UpstreamTypeStreamable,
|
||||
EnablePathRewrite: true,
|
||||
PathRewritePrefix: "/",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range testCases {
|
||||
if tt.skip {
|
||||
continue
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
provider := &McpServerCache{}
|
||||
|
||||
if provider.GetMcpServers() != nil {
|
||||
t.Fatalf("GetMcpServers doesn't return nil before testing.")
|
||||
}
|
||||
|
||||
_ = provider.SetMcpServers(tt.init)
|
||||
|
||||
changed := provider.SetMcpServers(tt.input)
|
||||
if changed != tt.changed {
|
||||
t.Fatalf("actual changed %t != expect changed %t", changed, tt.changed)
|
||||
return
|
||||
}
|
||||
|
||||
actual := provider.GetMcpServers()
|
||||
|
||||
if len(actual) != len(tt.expect) {
|
||||
t.Fatalf("actual length %d != expect length %d", len(actual), len(tt.expect))
|
||||
}
|
||||
for i := range actual {
|
||||
if diff := cmp.Diff(tt.expect[i], actual[i]); diff != "" {
|
||||
t.Fatalf("TestMcpServerCache_GetSet() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -25,6 +25,14 @@ type DBClient struct {
|
||||
panicCount int32 // Add panic counter
|
||||
}
|
||||
|
||||
// supports database types
|
||||
const (
|
||||
MYSQL = "mysql"
|
||||
POSTGRES = "postgres"
|
||||
CLICKHOUSE = "clickhouse"
|
||||
SQLITE = "sqlite"
|
||||
)
|
||||
|
||||
// NewDBClient creates a new DBClient instance and establishes a connection to the database
|
||||
func NewDBClient(dsn string, dbType string, stop chan struct{}) *DBClient {
|
||||
client := &DBClient{
|
||||
@@ -53,13 +61,13 @@ func (c *DBClient) connect() error {
|
||||
}
|
||||
|
||||
switch c.dbType {
|
||||
case "postgres":
|
||||
case POSTGRES:
|
||||
db, err = gorm.Open(postgres.Open(c.dsn), &gormConfig)
|
||||
case "clickhouse":
|
||||
case CLICKHOUSE:
|
||||
db, err = gorm.Open(clickhouse.Open(c.dsn), &gormConfig)
|
||||
case "mysql":
|
||||
case MYSQL:
|
||||
db, err = gorm.Open(mysql.Open(c.dsn), &gormConfig)
|
||||
case "sqlite":
|
||||
case SQLITE:
|
||||
db, err = gorm.Open(sqlite.Open(c.dsn), &gormConfig)
|
||||
default:
|
||||
return fmt.Errorf("unsupported database type %s", c.dbType)
|
||||
@@ -125,25 +133,166 @@ func (c *DBClient) reconnectLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteSQL executes a raw SQL query and returns the result as a slice of maps
|
||||
func (c *DBClient) ExecuteSQL(query string, args ...interface{}) ([]map[string]interface{}, error) {
|
||||
func (c *DBClient) reconnectIfDbEmpty() error {
|
||||
if c.db == nil {
|
||||
// Trigger reconnection
|
||||
select {
|
||||
case c.reconnect <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return nil, fmt.Errorf("database is not connected, attempting to reconnect")
|
||||
return fmt.Errorf("database is not connected, attempting to reconnect")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
rows, err := c.db.Raw(query, args...).Rows()
|
||||
func (c *DBClient) handleSQLError(err error) error {
|
||||
if err != nil {
|
||||
// If execution fails, connection might be lost, trigger reconnection
|
||||
select {
|
||||
case c.reconnect <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return nil, fmt.Errorf("failed to execute SQL query: %w", err)
|
||||
return fmt.Errorf("failed to execute SQL: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DescribeTable Get the structure of a specific table.
|
||||
func (c *DBClient) DescribeTable(table string) ([]map[string]interface{}, error) {
|
||||
var sql string
|
||||
var args []string
|
||||
switch c.dbType {
|
||||
case MYSQL:
|
||||
sql = `
|
||||
select
|
||||
column_name,
|
||||
column_type,
|
||||
is_nullable,
|
||||
column_key,
|
||||
column_default,
|
||||
extra,
|
||||
column_comment
|
||||
from information_schema.columns
|
||||
where table_schema = database() and table_name = ?
|
||||
`
|
||||
args = []string{table}
|
||||
|
||||
case POSTGRES:
|
||||
sql = `
|
||||
select
|
||||
column_name,
|
||||
data_type as column_type,
|
||||
is_nullable,
|
||||
case
|
||||
when column_default like 'nextval%%' then 'auto_increment'
|
||||
when column_default is not null then 'default'
|
||||
else ''
|
||||
end as column_key,
|
||||
column_default,
|
||||
case
|
||||
when column_default like 'nextval%%' then 'auto_increment'
|
||||
else ''
|
||||
end as extra,
|
||||
col_description((select oid from pg_class where relname = ?), ordinal_position) as column_comment
|
||||
from information_schema.columns
|
||||
where table_name = ?
|
||||
`
|
||||
args = []string{table, table}
|
||||
|
||||
case CLICKHOUSE:
|
||||
sql = `
|
||||
select
|
||||
name as column_name,
|
||||
type as column_type,
|
||||
if(is_nullable, 'YES', 'NO') as is_nullable,
|
||||
default_kind as column_key,
|
||||
default_expression as column_default,
|
||||
default_kind as extra,
|
||||
comment as column_comment
|
||||
from system.columns
|
||||
where database = currentDatabase() and table = ?
|
||||
`
|
||||
args = []string{table}
|
||||
|
||||
case SQLITE:
|
||||
sql = `
|
||||
select
|
||||
name as column_name,
|
||||
type as column_type,
|
||||
not (notnull = 1) as is_nullable,
|
||||
pk as column_key,
|
||||
dflt_value as column_default,
|
||||
'' as extra,
|
||||
'' as column_comment
|
||||
from pragma_table_info(?)
|
||||
`
|
||||
args = []string{table}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported database type: %s", c.dbType)
|
||||
}
|
||||
|
||||
return c.Query(sql, args)
|
||||
}
|
||||
|
||||
// ListTables List all tables in the connected database.
|
||||
func (c *DBClient) ListTables() ([]string, error) {
|
||||
var sql string
|
||||
switch c.dbType {
|
||||
case MYSQL:
|
||||
sql = "show tables"
|
||||
case POSTGRES:
|
||||
sql = "select tablename from pg_tables where schemaname = 'public'"
|
||||
case CLICKHOUSE:
|
||||
sql = "select name from system.tables where database = currentDatabase()"
|
||||
case SQLITE:
|
||||
sql = "select name from sqlite_master where type='table'"
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported database type: %s", c.dbType)
|
||||
}
|
||||
|
||||
rows, err := c.db.Raw(sql).Rows()
|
||||
if err := c.handleSQLError(err); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tables []string
|
||||
for rows.Next() {
|
||||
var table string
|
||||
if err := rows.Scan(&table); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan table name: %w", err)
|
||||
}
|
||||
tables = append(tables, table)
|
||||
}
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
// Execute executes an INSERT, UPDATE, or DELETE raw SQL and returns the rows affected
|
||||
func (c *DBClient) Execute(sql string, args ...interface{}) (int64, error) {
|
||||
if err := c.reconnectIfDbEmpty(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
tx := c.db.Exec(sql, args...)
|
||||
if err := c.handleSQLError(tx.Error); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer tx.Commit()
|
||||
|
||||
return tx.RowsAffected, nil
|
||||
}
|
||||
|
||||
// Query executes a raw SQL query and returns the result as a slice of maps
|
||||
func (c *DBClient) Query(sql string, args ...interface{}) ([]map[string]interface{}, error) {
|
||||
if err := c.reconnectIfDbEmpty(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := c.db.Raw(sql, args...).Rows()
|
||||
if err := c.handleSQLError(err); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
|
||||
@@ -49,11 +49,24 @@ func (c *DBConfig) NewServer(serverName string) (*common.MCPServer, error) {
|
||||
)
|
||||
|
||||
dbClient := NewDBClient(c.dsn, c.dbType, mcpServer.GetDestoryChannel())
|
||||
descriptionSuffix := fmt.Sprintf("in database %s. Database description: %s", c.dbType, c.description)
|
||||
// Add query tool
|
||||
mcpServer.AddTool(
|
||||
mcp.NewToolWithRawSchema("query", fmt.Sprintf("Run a read-only SQL query in database %s. Database description: %s", c.dbType, c.description), GetQueryToolSchema()),
|
||||
mcp.NewToolWithRawSchema("query", fmt.Sprintf("Run a read-only SQL query %s", descriptionSuffix), GetQueryToolSchema()),
|
||||
HandleQueryTool(dbClient),
|
||||
)
|
||||
mcpServer.AddTool(
|
||||
mcp.NewToolWithRawSchema("execute", fmt.Sprintf("Execute an insert, update, or delete SQL %s", descriptionSuffix), GetExecuteToolSchema()),
|
||||
HandleExecuteTool(dbClient),
|
||||
)
|
||||
mcpServer.AddTool(
|
||||
mcp.NewToolWithRawSchema("list tables", fmt.Sprintf("List all tables %s", descriptionSuffix), GetListTablesToolSchema()),
|
||||
HandleListTablesTool(dbClient),
|
||||
)
|
||||
mcpServer.AddTool(
|
||||
mcp.NewToolWithRawSchema("describe table", fmt.Sprintf("Get the structure of a specific table %s", descriptionSuffix), GetDescribeTableToolSchema()),
|
||||
HandleDescribeTableTool(dbClient),
|
||||
)
|
||||
|
||||
return mcpServer, nil
|
||||
}
|
||||
|
||||
@@ -18,27 +18,80 @@ func HandleQueryTool(dbClient *DBClient) common.ToolHandlerFunc {
|
||||
return nil, fmt.Errorf("invalid message argument")
|
||||
}
|
||||
|
||||
results, err := dbClient.ExecuteSQL(message)
|
||||
results, err := dbClient.Query(message)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute SQL query: %w", err)
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(results)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal SQL results: %w", err)
|
||||
return buildCallToolResult(results)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleExecuteTool handles SQL INSERT, UPDATE, or DELETE execution
|
||||
func HandleExecuteTool(dbClient *DBClient) common.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
arguments := request.Params.Arguments
|
||||
message, ok := arguments["sql"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid message argument")
|
||||
}
|
||||
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{
|
||||
Type: "text",
|
||||
Text: string(jsonData),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
results, err := dbClient.Execute(message)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute SQL query: %w", err)
|
||||
}
|
||||
|
||||
return buildCallToolResult(results)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleListTablesTool handles list all tables
|
||||
func HandleListTablesTool(dbClient *DBClient) common.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
results, err := dbClient.ListTables()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute SQL query: %w", err)
|
||||
}
|
||||
|
||||
return buildCallToolResult(results)
|
||||
}
|
||||
}
|
||||
|
||||
// HandleDescribeTableTool handles describe table
|
||||
func HandleDescribeTableTool(dbClient *DBClient) common.ToolHandlerFunc {
|
||||
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
arguments := request.Params.Arguments
|
||||
message, ok := arguments["table"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid message argument")
|
||||
}
|
||||
|
||||
results, err := dbClient.DescribeTable(message)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute SQL query: %w", err)
|
||||
}
|
||||
|
||||
return buildCallToolResult(results)
|
||||
}
|
||||
}
|
||||
|
||||
// buildCallToolResult builds the call tool result
|
||||
func buildCallToolResult(results any) (*mcp.CallToolResult, error) {
|
||||
jsonData, err := json.Marshal(results)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal SQL results: %w", err)
|
||||
}
|
||||
|
||||
return &mcp.CallToolResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{
|
||||
Type: "text",
|
||||
Text: string(jsonData),
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetQueryToolSchema returns the schema for query tool
|
||||
func GetQueryToolSchema() json.RawMessage {
|
||||
return json.RawMessage(`
|
||||
@@ -53,3 +106,44 @@ func GetQueryToolSchema() json.RawMessage {
|
||||
}
|
||||
`)
|
||||
}
|
||||
|
||||
// GetExecuteToolSchema returns the schema for execute tool
|
||||
func GetExecuteToolSchema() json.RawMessage {
|
||||
return json.RawMessage(`
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sql": {
|
||||
"type": "string",
|
||||
"description": "The sql to execute"
|
||||
}
|
||||
}
|
||||
}
|
||||
`)
|
||||
}
|
||||
|
||||
// GetDescribeTableToolSchema returns the schema for DescribeTable tool
|
||||
func GetDescribeTableToolSchema() json.RawMessage {
|
||||
return json.RawMessage(`
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"table": {
|
||||
"type": "string",
|
||||
"description": "table name"
|
||||
}
|
||||
}
|
||||
}
|
||||
`)
|
||||
}
|
||||
|
||||
// GetListTablesToolSchema returns the schema for ListTables tool
|
||||
func GetListTablesToolSchema() json.RawMessage {
|
||||
return json.RawMessage(`
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
}
|
||||
}
|
||||
`)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -94,13 +95,15 @@ func (s *SSEServer) HandleSSE(cb api.FilterCallbackHandler, stopChan chan struct
|
||||
defer s.sessions.Delete(sessionID)
|
||||
|
||||
channel := GetSSEChannelName(sessionID)
|
||||
u, err := url.Parse(s.baseURL + s.messageEndpoint)
|
||||
if err != nil {
|
||||
api.LogErrorf("Failed to parse base URL: %v", err)
|
||||
}
|
||||
|
||||
messageEndpoint := fmt.Sprintf(
|
||||
"%s%s?sessionId=%s",
|
||||
s.baseURL,
|
||||
s.messageEndpoint,
|
||||
sessionID,
|
||||
)
|
||||
q := u.Query()
|
||||
q.Set("sessionId", sessionID)
|
||||
u.RawQuery = q.Encode()
|
||||
messageEndpoint := u.String()
|
||||
|
||||
// go func() {
|
||||
// for {
|
||||
@@ -126,7 +129,7 @@ func (s *SSEServer) HandleSSE(cb api.FilterCallbackHandler, stopChan chan struct
|
||||
// }
|
||||
// }()
|
||||
|
||||
err := s.redisClient.Subscribe(channel, stopChan, func(message string) {
|
||||
err = s.redisClient.Subscribe(channel, stopChan, func(message string) {
|
||||
defer cb.EncoderFilterCallbacks().RecoverPanic()
|
||||
api.LogDebugf("SSE Send message: %s", message)
|
||||
cb.EncoderFilterCallbacks().InjectData([]byte(message))
|
||||
@@ -136,7 +139,7 @@ func (s *SSEServer) HandleSSE(cb api.FilterCallbackHandler, stopChan chan struct
|
||||
}
|
||||
|
||||
// Send the initial endpoint event
|
||||
initialEvent := fmt.Sprintf("event: endpoint\ndata: %s\r\n\r\n", messageEndpoint)
|
||||
initialEvent := fmt.Sprintf("event: endpoint\ndata: %s\n\n", messageEndpoint)
|
||||
err = s.redisClient.Publish(channel, initialEvent)
|
||||
if err != nil {
|
||||
api.LogErrorf("Failed to send initial event: %v", err)
|
||||
@@ -210,7 +213,7 @@ func (s *SSEServer) HandleMessage(w http.ResponseWriter, r *http.Request, body j
|
||||
var status int
|
||||
// Only send response if there is one (not for notifications)
|
||||
if response != nil {
|
||||
if sessionID != ""{
|
||||
if sessionID != "" {
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
status = http.StatusAccepted
|
||||
} else {
|
||||
|
||||
@@ -129,9 +129,15 @@ func (f *filter) processMcpRequestHeadersForRestUpstream(header api.RequestHeade
|
||||
if method != http.MethodGet {
|
||||
f.callbacks.DecoderFilterCallbacks().SendLocalReply(http.StatusMethodNotAllowed, "Method not allowed", nil, 0, "")
|
||||
} else {
|
||||
// to support the query param in Message Endpoint
|
||||
trimmed := strings.TrimSuffix(requestUrl.Path, GlobalSSEPathSuffix)
|
||||
if rq := requestUrl.RawQuery; rq != "" {
|
||||
trimmed += "?" + rq
|
||||
}
|
||||
|
||||
f.config.defaultServer = common.NewSSEServer(common.NewMCPServer(DefaultServerName, Version),
|
||||
common.WithSSEEndpoint(GlobalSSEPathSuffix),
|
||||
common.WithMessageEndpoint(strings.TrimSuffix(requestUrl.Path, GlobalSSEPathSuffix)),
|
||||
common.WithMessageEndpoint(trimmed),
|
||||
common.WithRedisClient(f.config.redisClient))
|
||||
f.serverName = f.config.defaultServer.GetServerName()
|
||||
body := "SSE connection create"
|
||||
@@ -143,50 +149,9 @@ func (f *filter) processMcpRequestHeadersForRestUpstream(header api.RequestHeade
|
||||
func (f *filter) processMcpRequestHeadersForSSEUpstream(header api.RequestHeaderMap, endStream bool) api.StatusType {
|
||||
// We don't need to process the request body for SSE upstream.
|
||||
f.skipRequestBody = true
|
||||
f.rewritePathForSSEUpstream(header)
|
||||
return api.Continue
|
||||
}
|
||||
|
||||
func (f *filter) rewritePathForSSEUpstream(header api.RequestHeaderMap) {
|
||||
matchedRule := f.matchedRule
|
||||
if !matchedRule.EnablePathRewrite || matchedRule.MatchRuleType != common.PrefixMatch {
|
||||
// No rewrite required, so we don't need to process the response body, either.
|
||||
f.skipResponseBody = true
|
||||
return
|
||||
}
|
||||
|
||||
path := f.req.URL.Path
|
||||
if !strings.HasPrefix(path, matchedRule.MatchRulePath) {
|
||||
api.LogWarnf("Unexpected: Path %s does not match the configured prefix %s", path, matchedRule.MatchRulePath)
|
||||
return
|
||||
}
|
||||
|
||||
rewrittenPath := path[len(matchedRule.MatchRulePath):]
|
||||
|
||||
if rewrittenPath == "" {
|
||||
rewrittenPath = matchedRule.PathRewritePrefix
|
||||
} else {
|
||||
rewritePrefixHasTrailingSlash := strings.HasSuffix(matchedRule.PathRewritePrefix, "/")
|
||||
pathSuffixHasLeadingSlash := strings.HasPrefix(rewrittenPath, "/")
|
||||
if rewritePrefixHasTrailingSlash != pathSuffixHasLeadingSlash {
|
||||
// One has, the other doesn't have.
|
||||
rewrittenPath = matchedRule.PathRewritePrefix + rewrittenPath
|
||||
} else if pathSuffixHasLeadingSlash {
|
||||
// Both have.
|
||||
rewrittenPath = matchedRule.PathRewritePrefix + rewrittenPath[1:]
|
||||
} else {
|
||||
// Neither have.
|
||||
rewrittenPath = matchedRule.PathRewritePrefix + "/" + rewrittenPath
|
||||
}
|
||||
}
|
||||
|
||||
if f.req.URL.RawQuery != "" {
|
||||
rewrittenPath = rewrittenPath + "?" + f.req.URL.RawQuery
|
||||
}
|
||||
|
||||
header.SetPath(rewrittenPath)
|
||||
}
|
||||
|
||||
// DecodeData might be called multiple times during handling the request body.
|
||||
// The endStream is true when handling the last piece of the body.
|
||||
func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
|
||||
@@ -322,20 +287,7 @@ func (f *filter) encodeDataFromSSEUpstream(buffer api.BufferInstance, endStream
|
||||
bufferBytes := buffer.Bytes()
|
||||
bufferData := string(bufferBytes)
|
||||
|
||||
err, lineBreak := f.findSSELineBreak(bufferData)
|
||||
if err != nil {
|
||||
api.LogWarnf("Failed to find line break in SSE data: %v", err)
|
||||
f.needProcess = false
|
||||
return api.Continue
|
||||
}
|
||||
if lineBreak == "" {
|
||||
// Have not found any line break. Need to buffer and check again.
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
|
||||
api.LogDebugf("Line break sequence: %v", []byte(lineBreak))
|
||||
|
||||
err, endpointUrl := f.findEndpointUrl(bufferData, lineBreak)
|
||||
err, endpointUrl := f.findEndpointUrl(bufferData)
|
||||
if err != nil {
|
||||
api.LogWarnf("Failed to find endpoint URL in SSE data: %v", err)
|
||||
f.needProcess = false
|
||||
@@ -412,7 +364,7 @@ func (f *filter) rewriteEndpointUrl(endpointUrl string) (bool, string) {
|
||||
return true, endpointUrl
|
||||
}
|
||||
|
||||
func (f *filter) findSSELineBreak(bufferData string) (error, string) {
|
||||
func (f *filter) findNextLineBreak(bufferData string) (error, string) {
|
||||
// See https://html.spec.whatwg.org/multipage/server-sent-events.html
|
||||
crIndex := strings.IndexAny(bufferData, "\r")
|
||||
lfIndex := strings.IndexAny(bufferData, "\n")
|
||||
@@ -422,11 +374,20 @@ func (f *filter) findSSELineBreak(bufferData string) (error, string) {
|
||||
}
|
||||
lineBreak := ""
|
||||
if crIndex != -1 && lfIndex != -1 {
|
||||
if crIndex+1 != lfIndex {
|
||||
// Found both line breaks, but they are not adjacent. Skip body processing.
|
||||
return errors.New("found non-adjacent CR and LF"), ""
|
||||
if crIndex < lfIndex {
|
||||
if crIndex+1 == lfIndex {
|
||||
lineBreak = "\r\n"
|
||||
} else {
|
||||
lineBreak = "\r"
|
||||
}
|
||||
} else {
|
||||
if crIndex == lfIndex+1 {
|
||||
// Found unexpected "\n\r". Skip body processing.
|
||||
return errors.New("found unexpected LF+CR"), ""
|
||||
} else {
|
||||
lineBreak = "\n"
|
||||
}
|
||||
}
|
||||
lineBreak = "\r\n"
|
||||
} else if crIndex != -1 {
|
||||
lineBreak = "\r"
|
||||
} else {
|
||||
@@ -435,12 +396,21 @@ func (f *filter) findSSELineBreak(bufferData string) (error, string) {
|
||||
return nil, lineBreak
|
||||
}
|
||||
|
||||
func (f *filter) findEndpointUrl(bufferData, lineBreak string) (error, string) {
|
||||
func (f *filter) findEndpointUrl(bufferData string) (error, string) {
|
||||
eventIndex := strings.Index(bufferData, "event:")
|
||||
if eventIndex == -1 {
|
||||
return nil, ""
|
||||
}
|
||||
bufferData = bufferData[eventIndex:]
|
||||
err, lineBreak := f.findNextLineBreak(bufferData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find endpoint URL in SSE data: %v", err), ""
|
||||
}
|
||||
if lineBreak == "" {
|
||||
// No line break found, which means the data is not enough.
|
||||
return nil, ""
|
||||
}
|
||||
api.LogDebugf("event line break sequence: %v", []byte(lineBreak))
|
||||
eventEndIndex := strings.Index(bufferData, lineBreak)
|
||||
if eventEndIndex == -1 {
|
||||
return nil, ""
|
||||
@@ -450,6 +420,15 @@ func (f *filter) findEndpointUrl(bufferData, lineBreak string) (error, string) {
|
||||
return fmt.Errorf("the initial event [%s] is not an endpoint event. Skip processing", eventName), ""
|
||||
}
|
||||
bufferData = bufferData[eventEndIndex+len(lineBreak):]
|
||||
err, lineBreak = f.findNextLineBreak(bufferData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find endpoint URL in SSE data: %v", err), ""
|
||||
}
|
||||
if lineBreak == "" {
|
||||
// No line break found, which means the data is not enough.
|
||||
return nil, ""
|
||||
}
|
||||
api.LogDebugf("data line break sequence: %v", []byte(lineBreak))
|
||||
dataEndIndex := strings.Index(bufferData, lineBreak)
|
||||
if dataEndIndex == -1 {
|
||||
// Data received not enough.
|
||||
|
||||
@@ -4,4 +4,6 @@ build:gcc --cxxopt=-std=c++17
|
||||
build:clang --action_env=CC=clang --action_env=CXX=clang++
|
||||
build:clang --action_env=BAZEL_COMPILER=clang
|
||||
build:clang --linkopt=-fuse-ld=lld
|
||||
build:clang --cxxopt=-std=c++17
|
||||
build:clang --cxxopt=-std=c++17
|
||||
|
||||
build --incompatible_use_platforms_repo_for_constraints=false
|
||||
@@ -1 +1 @@
|
||||
5.4.0
|
||||
6.0.0
|
||||
@@ -1,6 +1,13 @@
|
||||
workspace(name = "istio_ecosystem_wasm_extensions")
|
||||
|
||||
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||
|
||||
http_archive(
|
||||
name = "platforms",
|
||||
url = "https://github.com/bazelbuild/platforms/releases/download/0.0.9/platforms-0.0.9.tar.gz",
|
||||
sha256 = "5eda539c841265031c2f82d8ae7a3a6490bd62176e0c038fc469eabf91f6149b",
|
||||
)
|
||||
|
||||
load("//bazel:third_party.bzl", "wasm_extension_dependency")
|
||||
|
||||
wasm_extension_dependency()
|
||||
@@ -16,9 +23,9 @@ load("@io_bazel_rules_docker//repositories:deps.bzl", container_deps = "deps")
|
||||
|
||||
container_deps()
|
||||
|
||||
PROXY_WASM_CPP_SDK_SHA = "eaec483b5b3c7bcb89fd208b5a1fa5d79d626f61"
|
||||
PROXY_WASM_CPP_SDK_SHA = "0ceca8c81dddc4c9875cf0cb997454764905658c"
|
||||
|
||||
PROXY_WASM_CPP_SDK_SHA256 = "1140bc8114d75db56a6ca6b18423d4df50d988d40b4cec929a1eb246cf5a4a3d"
|
||||
PROXY_WASM_CPP_SDK_SHA256 = "cb010b242d49fb02b39124421b6acb69bd4ece64fb6299ba3f98f3b36eef7004"
|
||||
|
||||
http_archive(
|
||||
name = "proxy_wasm_cpp_sdk",
|
||||
|
||||
@@ -202,7 +202,7 @@ bool PluginRootContext::parsePluginConfig(const json& configuration,
|
||||
}
|
||||
item = consumer.find("keys");
|
||||
if (item == consumer.end()) {
|
||||
LOG_WARN("not found keys configuration for consumer " + c.name + ", will use global configuration to extract keys");
|
||||
LOG_DEBUG("not found keys configuration for consumer " + c.name + ", will use global configuration to extract keys");
|
||||
need_global_keys = true;
|
||||
} else {
|
||||
c.keys = std::vector<std::string>{OriginalAuthKey};
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
|
||||
| `modelKey` | string | 选填 | model | 请求body中model参数的位置 |
|
||||
| `modelMapping` | map of string | 选填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。<br/>1. 支持前缀匹配。例如用 "gpt-3-*" 匹配所有名称以“gpt-3-”开头的模型;<br/>2. 支持使用 "*" 为键来配置通用兜底映射关系;<br/>3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 |
|
||||
| `enableOnPathSuffix` | array of string | 选填 | ["/v1/chat/completions"] | 只对这些特定路径后缀的请求生效 |
|
||||
| `enableOnPathSuffix` | array of string | 选填 | ["/completions","/embeddings","/images/generations","/audio/speech","/fine_tuning/jobs","/moderations","/image-synthesis","/video-synthesis"] | 只对这些特定路径后缀的请求生效|
|
||||
|
||||
|
||||
## 效果说明
|
||||
|
||||
@@ -7,7 +7,7 @@ The `model-mapper` plugin implements the functionality of routing based on the m
|
||||
| ----------- | --------------- | ----------------------- | ------ | ------------------------------------------- |
|
||||
| `modelKey` | string | Optional | model | The location of the model parameter in the request body. |
|
||||
| `modelMapping` | map of string | Optional | - | AI model mapping table, used to map the model names in the request to the model names supported by the service provider.<br/>1. Supports prefix matching. For example, use "gpt-3-*" to match all models whose names start with “gpt-3-”;<br/>2. Supports using "*" as the key to configure a generic fallback mapping relationship;<br/>3. If the target name in the mapping is an empty string "", it means to keep the original model name. |
|
||||
| `enableOnPathSuffix` | array of string | Optional | ["/v1/chat/completions"] | Only applies to requests with these specific path suffixes. |
|
||||
| `enableOnPathSuffix` | array of string | Optional | ["/completions","/embeddings","/images/generations","/audio/speech","/fine_tuning/jobs","/moderations","/image-synthesis","/video-synthesis"] | Only applies to requests with these specific path suffixes. |
|
||||
|
||||
## Runtime Properties
|
||||
|
||||
|
||||
@@ -43,7 +43,8 @@ struct ModelMapperConfigRule {
|
||||
std::string default_model_mapping_;
|
||||
std::vector<std::string> enable_on_path_suffix_ = {
|
||||
"/completions", "/embeddings", "/images/generations",
|
||||
"/audio/speech", "/fine_tuning/jobs", "/moderations"};
|
||||
"/audio/speech", "/fine_tuning/jobs", "/moderations",
|
||||
"/image-synthesis", "/video-synthesis"};
|
||||
};
|
||||
|
||||
// PluginRootContext is the root context for all streams processed by the
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
| `modelKey` | string | 选填 | model | 请求body中model参数的位置 |
|
||||
| `addProviderHeader` | string | 选填 | - | 从model参数中解析出的provider名字放到哪个请求header中 |
|
||||
| `modelToHeader` | string | 选填 | - | 直接将model参数放到哪个请求header中 |
|
||||
| `enableOnPathSuffix` | array of string | 选填 | ["/completions","/embeddings","/images/generations","/audio/speech","/fine_tuning/jobs","/moderations"] | 只对这些特定路径后缀的请求生效,可以配置为 "*" 以匹配所有路径 |
|
||||
| `enableOnPathSuffix` | array of string | 选填 | ["/completions","/embeddings","/images/generations","/audio/speech","/fine_tuning/jobs","/moderations","/image-synthesis","/video-synthesis"] | 只对这些特定路径后缀的请求生效,可以配置为 "*" 以匹配所有路径 |
|
||||
|
||||
## 运行属性
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ The `model-router` plugin implements routing functionality based on the model pa
|
||||
| `modelKey` | string | Optional | model | Location of the model parameter in the request body |
|
||||
| `addProviderHeader` | string | Optional | - | Which request header to add the provider name parsed from the model parameter |
|
||||
| `modelToHeader` | string | Optional | - | Which request header to directly add the model parameter to |
|
||||
| `enableOnPathSuffix` | array of string | Optional | ["/v1/chat/completions"] | Only effective for requests with these specific path suffixes, can be configured as "*" to match all paths |
|
||||
| `enableOnPathSuffix` | array of string | Optional | ["/completions","/embeddings","/images/generations","/audio/speech","/fine_tuning/jobs","/moderations","/image-synthesis","/video-synthesis"] | Only effective for requests with these specific path suffixes, can be configured as "*" to match all paths |
|
||||
|
||||
## Runtime Properties
|
||||
|
||||
|
||||
@@ -49,7 +49,8 @@ struct ModelRouterConfigRule {
|
||||
std::string model_to_header_;
|
||||
std::vector<std::string> enable_on_path_suffix_ = {
|
||||
"/completions", "/embeddings", "/images/generations",
|
||||
"/audio/speech", "/fine_tuning/jobs", "/moderations"};
|
||||
"/audio/speech", "/fine_tuning/jobs", "/moderations",
|
||||
"/image-synthesis", "/video-synthesis"};
|
||||
};
|
||||
|
||||
class PluginContext;
|
||||
|
||||
@@ -90,6 +90,8 @@ func (c *PluginConfig) FromJson(json gjson.Result, log wrapper.Log) {
|
||||
|
||||
if json.Get("enableSemanticCache").Exists() {
|
||||
c.EnableSemanticCache = json.Get("enableSemanticCache").Bool()
|
||||
} else if c.GetVectorProvider() == nil {
|
||||
c.EnableSemanticCache = false // set value to false when no vector provider
|
||||
} else {
|
||||
c.EnableSemanticCache = true // set default value to true
|
||||
}
|
||||
|
||||
98
plugins/wasm-go/extensions/ai-image-reader/README.md
Normal file
98
plugins/wasm-go/extensions/ai-image-reader/README.md
Normal file
@@ -0,0 +1,98 @@
|
||||
---
|
||||
title: AI IMAGE READER
|
||||
keywords: [ AI网关, AI IMAGE READER ]
|
||||
description: AI IMAGE READER 插件配置参考
|
||||
|
||||
---
|
||||
|
||||
## 功能说明
|
||||
|
||||
通过对接OCR服务实现AI-IMAGE-READER,目前支持阿里云模型服务灵积(dashscope)的qwen-vl-ocr模型提供OCR服务,流程如图所示:
|
||||
|
||||
<img src=".\ai-image-reader.png">
|
||||
|
||||
## 运行属性
|
||||
|
||||
插件执行阶段:`默认阶段`
|
||||
插件执行优先级:`400`
|
||||
|
||||
|
||||
## 配置说明
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ------------- | -------- | -------- | ------ | -------------------------------------- |
|
||||
| `apiKey` | string | 必填 | - | 用于在访问OCR服务时进行认证的令牌。 |
|
||||
| `type` | string | 必填 | - | 后端OCR服务提供商类型(例如dashscope) |
|
||||
| `serviceHost` | string | 必填 | - | 后端OCR服务域名 |
|
||||
| `serviceName` | string | 必填 | - | 后端OCR服务名 |
|
||||
| `servicePort` | int | 必填 | - | 后端OCR服务端口 |
|
||||
| `model` | string | 必填 | - | 后端OCR服务模型名称(例如qwen-vl-ocr) |
|
||||
| `timeout` | int | 选填 | 10000 | API调用超时时间(毫秒) |
|
||||
|
||||
## 示例
|
||||
|
||||
```yaml
|
||||
"apiKey": "YOUR_API_KEY",
|
||||
"type": "dashscope",
|
||||
"model": "qwen-vl-ocr",
|
||||
"timeout": 10000,
|
||||
"serviceHost": "dashscope.aliyuncs.com",
|
||||
"serviceName": "dashscope",
|
||||
"servicePort": "443"
|
||||
```
|
||||
|
||||
请求遵循openai api协议规范:
|
||||
|
||||
URL传递图片:
|
||||
|
||||
```
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20241108/ctdzex/biaozhun.jpg",
|
||||
},
|
||||
},
|
||||
],
|
||||
}],
|
||||
```
|
||||
|
||||
Base64编码传递图片:
|
||||
|
||||
```
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{ "type": "text", "text": "what's in this image?" },
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
```
|
||||
|
||||
以下为使用ai-image-reader进行增强的例子,原始请求为:
|
||||
|
||||
```
|
||||
图片内容是什么?
|
||||
```
|
||||
|
||||
未经过ai-image-reader插件处理LLM返回的结果为:
|
||||
|
||||
```
|
||||
对不起,作为一个文本AI助手,我无法查看图片内容。您可以描述一下图片的内容,我可以尽力帮助您识别。
|
||||
```
|
||||
|
||||
经过ai-image-reader插件处理后LLM返回的结果为:
|
||||
|
||||
```
|
||||
非常感谢您分享的图片内容!根据您提供的文字信息,学习编写shell脚本对Linux系统管理员来说是非常有益的。通过自动化系统管理任务,可以提高效率并减少手动操作的时间。对于家用Linux爱好者来说,了解如何在命令行下操作也是很重要的,因为在某些情况下,命令行操作可能更为便捷和高效。在本书中,您将学习如何运用shell脚本处理系统管理任务,以及如何在Linux命令行下进行操作。希望这本书能够帮助您更好地理解和应用Linux系统管理和操作的知识!如果您有任何其他问题或需要进一步帮助,请随时告诉我。
|
||||
```
|
||||
94
plugins/wasm-go/extensions/ai-image-reader/README_EN.md
Normal file
94
plugins/wasm-go/extensions/ai-image-reader/README_EN.md
Normal file
@@ -0,0 +1,94 @@
|
||||
---
|
||||
title: AI IMAGE READER
|
||||
keywords: [ AI GATEWAY, AI IMAGE READER ]
|
||||
description: AI IMAGE READER Plugin Configuration Reference
|
||||
---
|
||||
|
||||
## Function Description
|
||||
|
||||
By integrating with OCR services to implement AI-IMAGE-READER, currently, it supports Alibaba Cloud's qwen-vl-ocr model under Dashscope for OCR services, and the process is shown in the figure below:<img src=".\ai-image-reader-en.png">
|
||||
|
||||
## Running Attributes
|
||||
|
||||
Plugin execution phase:`Default Phase`
|
||||
Plugin execution priority:`400`
|
||||
|
||||
## Configuration Description
|
||||
|
||||
| Name | Data Type | Requirement | Default Value | Description |
|
||||
| ------------- | --------- | ----------- | ------------- | ------------------------------------------------------------ |
|
||||
| `apiKey` | string | Required | - | Token for authenticating access to OCR services. |
|
||||
| `type` | string | Required | - | Provider type of the backend OCR service type(e.g. dashscope). |
|
||||
| `serviceHost` | string | Required | - | Host of the backend OCR service. |
|
||||
| `serviceName` | string | Required | - | Name of the backend OCR service. |
|
||||
| `servicePort` | int | Required | - | Port of the backend OCR service. |
|
||||
| `model` | string | Required | - | Model name of the backend OCR service (e.g., qwen-vl-ocr). |
|
||||
| `timeout` | int | Required | 10000 | API call timeout duration (milliseconds). |
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
"apiKey": "YOUR_API_KEY",
|
||||
"type": "dashscope",
|
||||
"model": "qwen-vl-ocr",
|
||||
"timeout": 10000,
|
||||
"serviceHost": "dashscope.aliyuncs.com",
|
||||
"serviceName": "dashscope",
|
||||
"servicePort": "443"
|
||||
```
|
||||
|
||||
Request to follow the OpenAI API protocol specifications:
|
||||
|
||||
Pass images via URL:
|
||||
|
||||
```
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://help-static-aliyun-doc.aliyuncs.com/file-manage-files/zh-CN/20241108/ctdzex/biaozhun.jpg",
|
||||
},
|
||||
},
|
||||
],
|
||||
}],
|
||||
```
|
||||
|
||||
Pass images via Base64:
|
||||
|
||||
```
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{ "type": "text", "text": "what's in this image?" },
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
```
|
||||
|
||||
The following is an example of using ai-image-reader for enhancement. The original request was:
|
||||
|
||||
```
|
||||
What is the content of the image?
|
||||
```
|
||||
|
||||
The result returned by the LLM without processing from the ai-image-reader plugin is:
|
||||
|
||||
```
|
||||
Sorry, as a text-based AI assistant, I cannot view image content. You can describe the content of the image, and I will do my best to help you identify it.
|
||||
```
|
||||
|
||||
The result returned by the LLM after processing by the ai-image-reader plugin is:
|
||||
|
||||
```
|
||||
Thank you for sharing the image! Mastering shell scripting is highly beneficial for Linux system administrators as it automates tasks, boosts efficiency, and cuts down manual work. For home Linux users, command-line skills are equally important for quick and efficient operations. This book will teach you to handle system management tasks with shell scripts and operate in the Linux command line. Hope it aids your Linux system management learning! Feel free to ask if you have more questions.
|
||||
```
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 30 KiB |
BIN
plugins/wasm-go/extensions/ai-image-reader/ai-image-reader.png
Normal file
BIN
plugins/wasm-go/extensions/ai-image-reader/ai-image-reader.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 24 KiB |
177
plugins/wasm-go/extensions/ai-image-reader/dashscope.go
Normal file
177
plugins/wasm-go/extensions/ai-image-reader/dashscope.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
|
||||
DashscopeDomain = "dashscope.aliyuncs.com"
|
||||
DashscopePort = 443
|
||||
DashscopeDefaultModelName = "qwen-vl-ocr"
|
||||
DashscopeEndpoint = "/compatible-mode/v1/chat/completions"
|
||||
MinPixels = 3136
|
||||
MaxPixels = 1003520
|
||||
)
|
||||
|
||||
type OcrReq struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []chatMessage `json:"messages,omitempty"`
|
||||
}
|
||||
|
||||
type OcrResp struct {
|
||||
Choices []chatCompletionChoice `json:"choices"`
|
||||
}
|
||||
|
||||
type chatCompletionChoice struct {
|
||||
Message *chatMessageContent `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
type chatMessageContent struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
type chatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content []content `json:"content"`
|
||||
}
|
||||
|
||||
type imageURL struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type content struct {
|
||||
Type string `json:"type"`
|
||||
ImageUrl imageURL `json:"image_url,omitempty"`
|
||||
MinPixels int `json:"min_pixels,omitempty"`
|
||||
MaxPixels int `json:"max_pixels,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
var dashScopeConfig dashScopeProviderConfig
|
||||
|
||||
type dashScopeProviderInitializer struct {
|
||||
}
|
||||
|
||||
func (d *dashScopeProviderInitializer) InitConfig(json gjson.Result) {
|
||||
dashScopeConfig.apiKey = json.Get("apiKey").String()
|
||||
}
|
||||
|
||||
func (d *dashScopeProviderInitializer) ValidateConfig() error {
|
||||
if dashScopeConfig.apiKey == "" {
|
||||
return errors.New("[DashScope] apiKey is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *dashScopeProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) {
|
||||
if c.servicePort == 0 {
|
||||
c.servicePort = DashscopePort
|
||||
}
|
||||
if c.serviceHost == "" {
|
||||
c.serviceHost = DashscopeDomain
|
||||
}
|
||||
return &DSProvider{
|
||||
config: c,
|
||||
client: wrapper.NewClusterClient(wrapper.FQDNCluster{
|
||||
FQDN: c.serviceName,
|
||||
Host: c.serviceHost,
|
||||
Port: int64(c.servicePort),
|
||||
}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type dashScopeProviderConfig struct {
|
||||
// @Title zh-CN 文字识别服务 API Key
|
||||
// @Description zh-CN 文字识别服务 API Key
|
||||
apiKey string
|
||||
}
|
||||
|
||||
type DSProvider struct {
|
||||
config ProviderConfig
|
||||
client wrapper.HttpClient
|
||||
}
|
||||
|
||||
func (d *DSProvider) GetProviderType() string {
|
||||
return ProviderTypeDashscope
|
||||
}
|
||||
|
||||
func (d *DSProvider) CallArgs(imageUrl string) CallArgs {
|
||||
model := d.config.model
|
||||
if model == "" {
|
||||
model = DashscopeDefaultModelName
|
||||
}
|
||||
reqBody := OcrReq{
|
||||
Model: model,
|
||||
Messages: []chatMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []content{
|
||||
{
|
||||
Type: "image_url",
|
||||
ImageUrl: imageURL{
|
||||
URL: imageUrl,
|
||||
},
|
||||
MinPixels: MinPixels,
|
||||
MaxPixels: MaxPixels,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
return CallArgs{
|
||||
Method: http.MethodPost,
|
||||
Url: DashscopeEndpoint,
|
||||
Headers: [][2]string{
|
||||
{"Content-Type", "application/json"},
|
||||
{"Authorization", fmt.Sprintf("Bearer %s", dashScopeConfig.apiKey)},
|
||||
},
|
||||
Body: body,
|
||||
TimeoutMillisecond: d.config.timeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DSProvider) parseOcrResponse(responseBody []byte) (*OcrResp, error) {
|
||||
var resp OcrResp
|
||||
err := json.Unmarshal(responseBody, &resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (d *DSProvider) DoOCR(
|
||||
imageUrl string,
|
||||
callback func(imageContent string, err error)) error {
|
||||
args := d.CallArgs(imageUrl)
|
||||
err := d.client.Call(args.Method, args.Url, args.Headers, args.Body,
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
if statusCode != http.StatusOK {
|
||||
err := errors.New("failed to do ocr due to status code: " + strconv.Itoa(statusCode))
|
||||
callback("", err)
|
||||
return
|
||||
}
|
||||
log.Debugf("do ocr response: %d, %s", statusCode, responseBody)
|
||||
resp, err := d.parseOcrResponse(responseBody)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to parse response: %v", err)
|
||||
callback("", err)
|
||||
return
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
err = errors.New("no ocr response found")
|
||||
callback("", err)
|
||||
return
|
||||
}
|
||||
callback(resp.Choices[0].Message.Content, nil)
|
||||
}, args.TimeoutMillisecond)
|
||||
return err
|
||||
}
|
||||
19
plugins/wasm-go/extensions/ai-image-reader/go.mod
Normal file
19
plugins/wasm-go/extensions/ai-image-reader/go.mod
Normal file
@@ -0,0 +1,19 @@
|
||||
module ai-image-reader
|
||||
|
||||
go 1.19
|
||||
|
||||
require (
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.4.4-0.20250621002302-e94ac43dd15c
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.1 // indirect
|
||||
github.com/magefile/mage v1.14.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
)
|
||||
25
plugins/wasm-go/extensions/ai-image-reader/go.sum
Normal file
25
plugins/wasm-go/extensions/ai-image-reader/go.sum
Normal file
@@ -0,0 +1,25 @@
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.4.4-0.20250621002302-e94ac43dd15c h1:YGKECMrlahN6dyEaM/S5NEU4IJoFzWKsHQyawov6ep8=
|
||||
github.com/alibaba/higress/plugins/wasm-go v1.4.4-0.20250621002302-e94ac43dd15c/go.mod h1:E2xVWrIovU3rZi4HGlMfcYf+c/UVh3aCtpcJlNjpxYc=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
|
||||
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.1 h1:f9X4I5Y6jK3GrdsWn/lCTI1z5Lu5GOMazqQohAC3Vzk=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v1.0.1/go.mod h1:iiSyFbo+rAtbtGt/bsefv8GU57h9CCLYGJA74/tF5/0=
|
||||
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
|
||||
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
139
plugins/wasm-go/extensions/ai-image-reader/main.go
Normal file
139
plugins/wasm-go/extensions/ai-image-reader/main.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMaxBodyBytes uint32 = 100 * 1024 * 1024
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
promptTemplate string
|
||||
ocrProvider Provider
|
||||
ocrProviderConfig *ProviderConfig
|
||||
}
|
||||
|
||||
func main() {
|
||||
wrapper.SetCtx(
|
||||
"ai-image-reader",
|
||||
wrapper.ParseConfig(parseConfig),
|
||||
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
|
||||
wrapper.ProcessRequestBody(onHttpRequestBody),
|
||||
)
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *Config) error {
|
||||
config.promptTemplate = `# 用户发送的图片解析得到的文字内容如下:
|
||||
{image_content}
|
||||
在回答时,请注意以下几点:
|
||||
- 请你回答问题时结合用户图片的文字内容回答。
|
||||
- 除非用户要求,否则你回答的语言需要和用户提问的语言保持一致。
|
||||
|
||||
# 用户消息为:
|
||||
{question}`
|
||||
config.ocrProviderConfig = &ProviderConfig{}
|
||||
config.ocrProviderConfig.FromJson(json)
|
||||
if err := config.ocrProviderConfig.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
var err error
|
||||
config.ocrProvider, err = CreateProvider(*config.ocrProviderConfig)
|
||||
if err != nil {
|
||||
return errors.New("create ocr provider failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config) types.Action {
|
||||
contentType, _ := proxywasm.GetHttpRequestHeader("content-type")
|
||||
if contentType == "" {
|
||||
return types.ActionContinue
|
||||
}
|
||||
if !strings.Contains(contentType, "application/json") {
|
||||
log.Warnf("content is not json, can't process: %s", contentType)
|
||||
ctx.DontReadRequestBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
ctx.SetRequestBodyBufferLimit(DefaultMaxBodyBytes)
|
||||
_ = proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte) types.Action {
|
||||
var queryIndex int
|
||||
var query string
|
||||
messages := gjson.GetBytes(body, "messages").Array()
|
||||
var imageUrls []string
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Get("role").String() == "user" {
|
||||
queryIndex = i
|
||||
content := messages[i].Get("content").Array()
|
||||
for j := len(content) - 1; j >= 0; j-- {
|
||||
contentType := content[j].Get("type").String()
|
||||
if contentType == "image_url" {
|
||||
imageUrls = append(imageUrls, content[j].Get("image_url.url").String())
|
||||
} else if contentType == "text" {
|
||||
query = content[j].Get("text").String()
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(imageUrls) == 0 {
|
||||
return types.ActionContinue
|
||||
}
|
||||
return executeReadImage(imageUrls, config, query, queryIndex, body)
|
||||
}
|
||||
|
||||
func executeReadImage(imageUrls []string, config Config, query string, queryIndex int, body []byte) types.Action {
|
||||
var imageContents []string
|
||||
var totalImages int
|
||||
var finished int
|
||||
for _, imageUrl := range imageUrls {
|
||||
err := config.ocrProvider.DoOCR(imageUrl, func(imageContent string, err error) {
|
||||
defer func() {
|
||||
finished++
|
||||
if totalImages == finished {
|
||||
var processedContents []string
|
||||
for idx := len(imageContents) - 1; idx >= 0; idx-- {
|
||||
processedContents = append(processedContents, fmt.Sprintf("第%d张图片内容为 %s", totalImages-idx, imageContents[idx]))
|
||||
}
|
||||
imageSummary := fmt.Sprintf("总共有 %d 张图片。\n", totalImages)
|
||||
prompt := strings.Replace(config.promptTemplate, "{image_content}", imageSummary+strings.Join(processedContents, "\n"), 1)
|
||||
prompt = strings.Replace(prompt, "{question}", query, 1)
|
||||
modifiedBody, err := sjson.SetBytes(body, fmt.Sprintf("messages.%d.content", queryIndex), prompt)
|
||||
if err != nil {
|
||||
log.Errorf("modify request message content failed, err:%v, body:%s", err, body)
|
||||
} else {
|
||||
log.Debugf("modified body:%s", modifiedBody)
|
||||
proxywasm.ReplaceHttpRequestBody(modifiedBody)
|
||||
}
|
||||
proxywasm.ResumeHttpRequest()
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
log.Errorf("do ocr failed, err:%v", err)
|
||||
return
|
||||
}
|
||||
imageContents = append(imageContents, imageContent)
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("ocr call failed, err:%v", err)
|
||||
continue
|
||||
}
|
||||
totalImages++
|
||||
}
|
||||
if totalImages > 0 {
|
||||
return types.ActionPause
|
||||
}
|
||||
return types.ActionContinue
|
||||
}
|
||||
109
plugins/wasm-go/extensions/ai-image-reader/provider.go
Normal file
109
plugins/wasm-go/extensions/ai-image-reader/provider.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
const (
|
||||
ProviderTypeDashscope = "dashscope"
|
||||
)
|
||||
|
||||
type providerInitializer interface {
|
||||
InitConfig(json gjson.Result)
|
||||
ValidateConfig() error
|
||||
CreateProvider(ProviderConfig) (Provider, error)
|
||||
}
|
||||
|
||||
var (
|
||||
providerInitializers = map[string]providerInitializer{
|
||||
ProviderTypeDashscope: &dashScopeProviderInitializer{},
|
||||
}
|
||||
)
|
||||
|
||||
type ProviderConfig struct {
|
||||
// @Title zh-CN 文字识别服务提供者类型
|
||||
// @Description zh-CN 文字识别服务提供者类型,例如 DashScope
|
||||
typ string
|
||||
// @Title zh-CN DashScope 文字识别服务名称
|
||||
// @Description zh-CN 文字识别服务名称
|
||||
serviceName string
|
||||
// @Title zh-CN 文字识别服务域名
|
||||
// @Description zh-CN 文字识别服务域名
|
||||
serviceHost string
|
||||
// @Title zh-CN 文字识别服务端口
|
||||
// @Description zh-CN 文字识别服务端口
|
||||
servicePort int64
|
||||
// @Title zh-CN 文字识别服务超时时间
|
||||
// @Description zh-CN 文字识别服务超时时间
|
||||
timeout uint32
|
||||
// @Title zh-CN 文字识别服务使用的模型
|
||||
// @Description zh-CN 用于文字识别的模型名称, 在 DashScope 中默认为 "qwen-vl-ocr"
|
||||
model string
|
||||
|
||||
initializer providerInitializer
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
c.typ = json.Get("type").String()
|
||||
i, has := providerInitializers[c.typ]
|
||||
if has {
|
||||
i.InitConfig(json)
|
||||
c.initializer = i
|
||||
}
|
||||
c.serviceName = json.Get("serviceName").String()
|
||||
c.serviceHost = json.Get("serviceHost").String()
|
||||
c.servicePort = json.Get("servicePort").Int()
|
||||
c.timeout = uint32(json.Get("timeout").Int())
|
||||
c.model = json.Get("model").String()
|
||||
if c.timeout == 0 {
|
||||
c.timeout = 10000
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) Validate() error {
|
||||
if c.typ == "" {
|
||||
return errors.New("ocr service provider type is required")
|
||||
}
|
||||
if c.serviceName == "" {
|
||||
return errors.New("ocr service name is required")
|
||||
}
|
||||
if c.typ == "" {
|
||||
return errors.New("ocr service type is required")
|
||||
}
|
||||
if c.initializer == nil {
|
||||
return errors.New("unknown ocr service provider type: " + c.typ)
|
||||
}
|
||||
if err := c.initializer.ValidateConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetProviderType() string {
|
||||
return c.typ
|
||||
}
|
||||
|
||||
func CreateProvider(pc ProviderConfig) (Provider, error) {
|
||||
initializer, has := providerInitializers[pc.typ]
|
||||
if !has {
|
||||
return nil, errors.New("unknown provider type: " + pc.typ)
|
||||
}
|
||||
return initializer.CreateProvider(pc)
|
||||
}
|
||||
|
||||
type CallArgs struct {
|
||||
Method string
|
||||
Url string
|
||||
Headers [][2]string
|
||||
Body []byte
|
||||
TimeoutMillisecond uint32
|
||||
}
|
||||
|
||||
type Provider interface {
|
||||
GetProviderType() string
|
||||
CallArgs(imageUrl string) CallArgs
|
||||
DoOCR(
|
||||
imageUrl string,
|
||||
callback func(imageContent string, err error)) error
|
||||
}
|
||||
1
plugins/wasm-go/extensions/ai-load-balancer/.gitignore
vendored
Normal file
1
plugins/wasm-go/extensions/ai-load-balancer/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
test/
|
||||
174
plugins/wasm-go/extensions/ai-load-balancer/README.md
Normal file
174
plugins/wasm-go/extensions/ai-load-balancer/README.md
Normal file
@@ -0,0 +1,174 @@
|
||||
---
|
||||
title: AI负载均衡
|
||||
keywords: [higress, llm, load balance]
|
||||
description: 针对LLM服务的负载均衡策略
|
||||
---
|
||||
|
||||
# 功能说明
|
||||
|
||||
**注意**:
|
||||
- Higress网关版本需要>=v2.1.5
|
||||
|
||||
对LLM服务提供热插拔的负载均衡策略,如果关闭插件,负载均衡策略会退化为服务本身的负载均衡策略(轮训、本地最小请求数、随机、一致性hash等)。
|
||||
|
||||
配置如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `lb_policy` | string | 必填 | | 负载均衡策略类型 |
|
||||
| `lb_config` | object | 必填 | | 当前负载均衡策略类型的配置 |
|
||||
|
||||
目前支持的负载均衡策略包括:
|
||||
- `global_least_request`: 基于redis实现的全局最小请求数负载均衡
|
||||
- `prefix_cache`: 基于 prompt 前缀匹配选择后端节点,如果通过前缀匹配无法匹配到节点,则通过全局最小请求数进行服务节点的选择
|
||||
- `least_busy`: [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md) 的 wasm 实现
|
||||
|
||||
# 全局最小请求数
|
||||
## 功能说明
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant C as Client
|
||||
participant H as Higress
|
||||
participant R as Redis
|
||||
participant H1 as Host1
|
||||
participant H2 as Host2
|
||||
|
||||
C ->> H: 发起请求
|
||||
H ->> R: 获取 host ongoing 请求数
|
||||
R ->> H: 返回结果
|
||||
H ->> R: 根据结果选择当前请求数最小的host,计数+1
|
||||
R ->> H: 返回结果
|
||||
H ->> H1: 绕过service原本的负载均衡策略,转发请求到对应host
|
||||
H1 ->> H: 返回响应
|
||||
H ->> R: host计数-1
|
||||
H ->> C: 返回响应
|
||||
```
|
||||
|
||||
## 配置说明
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `serviceFQDN` | string | 必填 | | redis 服务的FQDN,例如: `redis.dns` |
|
||||
| `servicePort` | int | 必填 | | redis 服务的port |
|
||||
| `username` | string | 必填 | | redis 用户名 |
|
||||
| `password` | string | 选填 | 空 | redis 密码 |
|
||||
| `timeout` | int | 选填 | 3000ms | redis 请求超时时间 |
|
||||
| `database` | int | 选填 | 0 | redis 数据库序号 |
|
||||
|
||||
## 配置示例
|
||||
|
||||
```yaml
|
||||
lb_policy: global_least_request
|
||||
lb_config:
|
||||
serviceFQDN: redis.static
|
||||
servicePort: 6379
|
||||
username: default
|
||||
password: '123456'
|
||||
```
|
||||
|
||||
# 前缀匹配
|
||||
## 功能说明
|
||||
根据 prompt 前缀匹配选择 pod,以复用 KV Cache,如果通过前缀匹配无法匹配到节点,则通过全局最小请求数进行服务节点的选择
|
||||
|
||||
例如以下请求被路由到了pod 1
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "qwen-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
那么后续具有相同前缀的请求也会被路由到 pod 1
|
||||
```json
|
||||
{
|
||||
"model": "qwen-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hi! How can I assist you today? 😊"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "write a short story aboud 100 words"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## 配置说明
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `serviceFQDN` | string | 必填 | | redis 服务的FQDN,例如: `redis.dns` |
|
||||
| `servicePort` | int | 必填 | | redis 服务的port |
|
||||
| `username` | string | 必填 | | redis 用户名 |
|
||||
| `password` | string | 选填 | 空 | redis 密码 |
|
||||
| `timeout` | int | 选填 | 3000ms | redis 请求超时时间 |
|
||||
| `database` | int | 选填 | 0 | redis 数据库序号 |
|
||||
| `redisKeyTTL` | int | 选填 | 1800ms | prompt 前缀对应的key的ttl |
|
||||
|
||||
## 配置示例
|
||||
|
||||
```yaml
|
||||
lb_policy: prefix_cache
|
||||
lb_config:
|
||||
serviceFQDN: redis.static
|
||||
servicePort: 6379
|
||||
username: default
|
||||
password: '123456'
|
||||
```
|
||||
|
||||
# 最小负载
|
||||
## 功能说明
|
||||
[gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md) 的 wasm 实现
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant C as Client
|
||||
participant H as Higress
|
||||
participant H1 as Host1
|
||||
participant H2 as Host2
|
||||
|
||||
loop 定期拉取metrics
|
||||
H ->> H1: /metrics
|
||||
H1 ->> H: vllm metrics
|
||||
H ->> H2: /metrics
|
||||
H2 ->> H: vllm metrics
|
||||
end
|
||||
|
||||
C ->> H: 发起请求
|
||||
H ->> H1: 根据vllm metrics选择合适的pod,绕过服务原始的lb policy直接转发
|
||||
H1 ->> H: 返回响应
|
||||
H ->> C: 返回响应
|
||||
```
|
||||
|
||||
<!-- pod选取流程图如下:
|
||||
|
||||
 -->
|
||||
|
||||
## 配置说明
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `criticalModels` | []string | 选填 | | critical的模型列表 |
|
||||
|
||||
## 配置示例
|
||||
|
||||
```yaml
|
||||
lb_policy: least_busy
|
||||
lb_config:
|
||||
criticalModels:
|
||||
- meta-llama/Llama-2-7b-hf
|
||||
- sql-lora
|
||||
```
|
||||
177
plugins/wasm-go/extensions/ai-load-balancer/README_EN.md
Normal file
177
plugins/wasm-go/extensions/ai-load-balancer/README_EN.md
Normal file
@@ -0,0 +1,177 @@
|
||||
---
|
||||
title: AI Load Balance
|
||||
keywords: [higress, llm, load balance]
|
||||
description: LLM-oriented load balance policies
|
||||
---
|
||||
|
||||
# Introduction
|
||||
|
||||
**Attention**:
|
||||
- Version of Higress should >= v2.1.5
|
||||
|
||||
This plug-in provides the llm-oriented load balancing capability in a hot-swappable manner. If the plugin is closed, the load balancing strategy will degenerate into the load balancing strategy of the service itself (round robin, local minimum request number, random, consistent hash, etc.).
|
||||
|
||||
The configuration is:
|
||||
|
||||
| Name | Type | Required | default | description |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `lb_policy` | string | required | | load balance type |
|
||||
| `lb_config` | object | required | | configuration for the current load balance type |
|
||||
|
||||
Current supported load balance policies are:
|
||||
|
||||
- `global_least_request`: global least request based on redis
|
||||
- `prefix_cache`: Select the backend node based on the prompt prefix match. If the node cannot be matched by prefix matching, the service node is selected based on the global minimum number of requests.
|
||||
- `least_busy`: implementation for [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md)
|
||||
|
||||
# Global Least Request
|
||||
## Introduction
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant C as Client
|
||||
participant H as Higress
|
||||
participant R as Redis
|
||||
participant H1 as Host1
|
||||
participant H2 as Host2
|
||||
|
||||
C ->> H: Send request
|
||||
H ->> R: Get host ongoing request number
|
||||
R ->> H: Return result
|
||||
H ->> R: According to the result, select the host with the smallest number of current requests, host rq count +1.
|
||||
R ->> H: Return result
|
||||
H ->> H1: Bypass the service's original load balancing strategy and forward the request to the corresponding host
|
||||
H1 ->> H: Return result
|
||||
H ->> R: host rq count -1
|
||||
H ->> C: Receive response
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
| Name | Type | required | default | description |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `serviceFQDN` | string | required | | redis FQDN, e.g. `redis.dns` |
|
||||
| `servicePort` | int | required | | redis port |
|
||||
| `username` | string | required | | redis username |
|
||||
| `password` | string | optional | `` | redis password |
|
||||
| `timeout` | int | optional | 3000ms | redis request timeout |
|
||||
| `database` | int | optional | 0 | redis database number |
|
||||
|
||||
## Configuration Example
|
||||
|
||||
```yaml
|
||||
lb_policy: global_least_request
|
||||
lb_config:
|
||||
serviceFQDN: redis.static
|
||||
servicePort: 6379
|
||||
username: default
|
||||
password: '123456'
|
||||
```
|
||||
|
||||
# Prefix Cache
|
||||
## Introduction
|
||||
Select pods based on the prompt prefix match to reuse KV Cache. If no node can be matched by prefix match, select the service node based on the global minimum number of requests.
|
||||
|
||||
For example, the following request is routed to pod 1:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "qwen-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Then subsequent requests with the same prefix will also be routed to pod 1:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "qwen-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hi! How can I assist you today? 😊"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "write a short story aboud 100 words"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
| Name | Type | required | default | description |
|
||||
|--------------------|-----------------|-----------------------|-------------|---------------------------------|
|
||||
| `serviceFQDN` | string | required | | redis FQDN, e.g. `redis.dns` |
|
||||
| `servicePort` | int | required | | redis port |
|
||||
| `username` | string | required | | redis username |
|
||||
| `password` | string | optional | `` | redis password |
|
||||
| `timeout` | int | optional | 3000ms | redis request timeout |
|
||||
| `database` | int | optional | 0 | redis database number |
|
||||
| `redisKeyTTL` | int | optional | 1800ms | prompt prefix key's ttl |
|
||||
|
||||
## Configuration Example
|
||||
|
||||
```yaml
|
||||
lb_policy: prefix_cache
|
||||
lb_config:
|
||||
serviceFQDN: redis.static
|
||||
servicePort: 6379
|
||||
username: default
|
||||
password: '123456'
|
||||
```
|
||||
|
||||
# Least Busy
|
||||
## Introduction
|
||||
|
||||
wasm implementation for [gateway-api-inference-extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/README.md)
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant C as Client
|
||||
participant H as Higress
|
||||
participant H1 as Host1
|
||||
participant H2 as Host2
|
||||
|
||||
loop fetch metrics periodically
|
||||
H ->> H1: /metrics
|
||||
H1 ->> H: vllm metrics
|
||||
H ->> H2: /metrics
|
||||
H2 ->> H: vllm metrics
|
||||
end
|
||||
|
||||
C ->> H: request
|
||||
H ->> H1: select pod according to vllm metrics, bypassing original service load balance policy
|
||||
H1 ->> H: response
|
||||
H ->> C: response
|
||||
```
|
||||
|
||||
<!-- flowchart for pod selection:
|
||||
|
||||
 -->
|
||||
|
||||
## Configuration
|
||||
|
||||
| Name | Type | Required | default | description |
|
||||
|--------------------|-----------------|------------------|-------------|-------------------------------------|
|
||||
| `criticalModels` | []string | required | | critical model names |
|
||||
|
||||
## Configuration Example
|
||||
|
||||
```yaml
|
||||
lb_policy: least_busy
|
||||
lb_config:
|
||||
criticalModels:
|
||||
- meta-llama/Llama-2-7b-hf
|
||||
- sql-lora
|
||||
```
|
||||
@@ -0,0 +1,178 @@
|
||||
package global_least_request
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/utils"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/resp"
|
||||
)
|
||||
|
||||
const (
|
||||
RedisKeyFormat = "higress:global_least_request_table:%s:%s"
|
||||
RedisLua = `local seed = KEYS[1]
|
||||
local hset_key = KEYS[2]
|
||||
local current_target = KEYS[3]
|
||||
local current_count = 0
|
||||
|
||||
math.randomseed(seed)
|
||||
|
||||
local function randomBool()
|
||||
return math.random() >= 0.5
|
||||
end
|
||||
|
||||
local function is_healthy(addr)
|
||||
for i = 4, #KEYS do
|
||||
if addr == KEYS[i] then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
if redis.call('HEXISTS', hset_key, current_target) ~= 0 then
|
||||
current_count = redis.call('HGET', hset_key, current_target)
|
||||
local hash = redis.call('HGETALL', hset_key)
|
||||
for i = 1, #hash, 2 do
|
||||
local addr = hash[i]
|
||||
local count = hash[i+1]
|
||||
if is_healthy(addr) then
|
||||
if count < current_count then
|
||||
current_target = addr
|
||||
current_count = count
|
||||
elseif count == current_count and randomBool() then
|
||||
current_target = addr
|
||||
current_count = count
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
redis.call("HINCRBY", hset_key, current_target, 1)
|
||||
|
||||
return current_target`
|
||||
)
|
||||
|
||||
type GlobalLeastRequestLoadBalancer struct {
|
||||
redisClient wrapper.RedisClient
|
||||
}
|
||||
|
||||
func NewGlobalLeastRequestLoadBalancer(json gjson.Result) (GlobalLeastRequestLoadBalancer, error) {
|
||||
lb := GlobalLeastRequestLoadBalancer{}
|
||||
serviceFQDN := json.Get("serviceFQDN").String()
|
||||
servicePort := json.Get("servicePort").Int()
|
||||
if serviceFQDN == "" || servicePort == 0 {
|
||||
log.Errorf("invalid redis service, serviceFQDN: %s, servicePort: %d", serviceFQDN, servicePort)
|
||||
return lb, errors.New("invalid redis service config")
|
||||
}
|
||||
lb.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{
|
||||
FQDN: serviceFQDN,
|
||||
Port: servicePort,
|
||||
})
|
||||
username := json.Get("username").String()
|
||||
password := json.Get("password").String()
|
||||
timeout := json.Get("timeout").Int()
|
||||
if timeout == 0 {
|
||||
timeout = 3000
|
||||
}
|
||||
// database default is 0
|
||||
database := json.Get("database").Int()
|
||||
return lb, lb.redisClient.Init(username, password, int64(timeout), wrapper.WithDataBase(int(database)))
|
||||
}
|
||||
|
||||
func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
// If return types.ActionContinue, SetUpstreamOverrideHost will not take effect
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func (lb GlobalLeastRequestLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
routeName, err := utils.GetRouteName()
|
||||
if err != nil || routeName == "" {
|
||||
ctx.SetContext("error", true)
|
||||
return types.ActionContinue
|
||||
} else {
|
||||
ctx.SetContext("routeName", routeName)
|
||||
}
|
||||
clusterName, err := utils.GetClusterName()
|
||||
if err != nil || clusterName == "" {
|
||||
ctx.SetContext("error", true)
|
||||
return types.ActionContinue
|
||||
} else {
|
||||
ctx.SetContext("clusterName", clusterName)
|
||||
}
|
||||
hostInfos, err := proxywasm.GetUpstreamHosts()
|
||||
if err != nil {
|
||||
ctx.SetContext("error", true)
|
||||
return types.ActionContinue
|
||||
}
|
||||
// Only healthy host can be selected
|
||||
healthyHostArray := []string{}
|
||||
for _, hostInfo := range hostInfos {
|
||||
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
|
||||
healthyHostArray = append(healthyHostArray, hostInfo[0])
|
||||
}
|
||||
}
|
||||
if len(healthyHostArray) == 0 {
|
||||
ctx.SetContext("error", true)
|
||||
return types.ActionContinue
|
||||
}
|
||||
randomIndex := rand.Intn(len(healthyHostArray))
|
||||
hostSelected := healthyHostArray[randomIndex]
|
||||
keys := []interface{}{time.Now().Unix(), fmt.Sprintf(RedisKeyFormat, routeName, clusterName), hostSelected}
|
||||
for _, v := range healthyHostArray {
|
||||
keys = append(keys, v)
|
||||
}
|
||||
err = lb.redisClient.Eval(RedisLua, len(keys), keys, []interface{}{}, func(response resp.Value) {
|
||||
if err := response.Error(); err != nil {
|
||||
log.Errorf("HGetAll failed: %+v", err)
|
||||
ctx.SetContext("error", true)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
return
|
||||
}
|
||||
hostSelected = response.String()
|
||||
if err := proxywasm.SetUpstreamOverrideHost([]byte(hostSelected)); err != nil {
|
||||
ctx.SetContext("error", true)
|
||||
log.Errorf("override upstream host failed, fallback to default lb policy, error informations: %+v", err)
|
||||
}
|
||||
log.Debugf("host_selected: %s", hostSelected)
|
||||
ctx.SetContext("host_selected", hostSelected)
|
||||
proxywasm.ResumeHttpRequest()
|
||||
})
|
||||
if err != nil {
|
||||
ctx.SetContext("error", true)
|
||||
return types.ActionContinue
|
||||
}
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func (lb GlobalLeastRequestLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb GlobalLeastRequestLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
|
||||
if endOfStream {
|
||||
isErr, _ := ctx.GetContext("error").(bool)
|
||||
if !isErr {
|
||||
routeName, _ := ctx.GetContext("routeName").(string)
|
||||
clusterName, _ := ctx.GetContext("clusterName").(string)
|
||||
host_selected, _ := ctx.GetContext("host_selected").(string)
|
||||
if host_selected == "" {
|
||||
log.Errorf("get host_selected failed")
|
||||
} else {
|
||||
lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func (lb GlobalLeastRequestLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
23
plugins/wasm-go/extensions/ai-load-balancer/go.mod
Normal file
23
plugins/wasm-go/extensions/ai-load-balancer/go.mod
Normal file
@@ -0,0 +1,23 @@
|
||||
module github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer
|
||||
|
||||
go 1.24.1
|
||||
|
||||
toolchain go1.24.3
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
|
||||
github.com/higress-group/wasm-go v1.0.1-0.20250628101008-bea7da01a545
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/prometheus/client_model v0.6.2
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/resp v0.1.1
|
||||
go.uber.org/multierr v1.11.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/prometheus/common v0.64.0
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
google.golang.org/protobuf v1.36.6 // indirect
|
||||
)
|
||||
35
plugins/wasm-go/extensions/ai-load-balancer/go.sum
Normal file
35
plugins/wasm-go/extensions/ai-load-balancer/go.sum
Normal file
@@ -0,0 +1,35 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.1-0.20250628101008-bea7da01a545 h1:zPXEonKCAeLvXI1IpwGpIeVSvLY5AZ9h9uTJnOuiA3Q=
|
||||
github.com/higress-group/wasm-go v1.0.1-0.20250628101008-bea7da01a545/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
||||
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
||||
github.com/prometheus/common v0.64.0 h1:pdZeA+g617P7oGv1CzdTzyeShxAGrTBsolKNOLQPGO4=
|
||||
github.com/prometheus/common v0.64.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
|
||||
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
@@ -0,0 +1,68 @@
|
||||
/*
|
||||
Copyright 2025 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package backend
|
||||
|
||||
import "fmt"
|
||||
|
||||
type PodSet map[Pod]bool
|
||||
|
||||
type Pod struct {
|
||||
Name string
|
||||
Address string
|
||||
}
|
||||
|
||||
func (p Pod) String() string {
|
||||
return p.Name + ":" + p.Address
|
||||
}
|
||||
|
||||
type Metrics struct {
|
||||
// ActiveModels is a set of models(including LoRA adapters) that are currently cached to GPU.
|
||||
ActiveModels map[string]int
|
||||
// MaxActiveModels is the maximum number of models that can be loaded to GPU.
|
||||
MaxActiveModels int
|
||||
RunningQueueSize int
|
||||
WaitingQueueSize int
|
||||
KVCacheUsagePercent float64
|
||||
KvCacheMaxTokenCapacity int
|
||||
}
|
||||
|
||||
type PodMetrics struct {
|
||||
Pod
|
||||
Metrics
|
||||
}
|
||||
|
||||
func (pm *PodMetrics) String() string {
|
||||
return fmt.Sprintf("Pod: %+v; Metrics: %+v", pm.Pod, pm.Metrics)
|
||||
}
|
||||
|
||||
func (pm *PodMetrics) Clone() *PodMetrics {
|
||||
cm := make(map[string]int, len(pm.ActiveModels))
|
||||
for k, v := range pm.ActiveModels {
|
||||
cm[k] = v
|
||||
}
|
||||
clone := &PodMetrics{
|
||||
Pod: pm.Pod,
|
||||
Metrics: Metrics{
|
||||
ActiveModels: cm,
|
||||
RunningQueueSize: pm.RunningQueueSize,
|
||||
WaitingQueueSize: pm.WaitingQueueSize,
|
||||
KVCacheUsagePercent: pm.KVCacheUsagePercent,
|
||||
KvCacheMaxTokenCapacity: pm.KvCacheMaxTokenCapacity,
|
||||
},
|
||||
}
|
||||
return clone
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
/*
|
||||
Copyright 2025 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package vllm provides vllm specific pod metrics implementation.
|
||||
package vllm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend"
|
||||
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
"go.uber.org/multierr"
|
||||
)
|
||||
|
||||
const (
|
||||
LoraRequestInfoMetricName = "vllm:lora_requests_info"
|
||||
LoraRequestInfoRunningAdaptersMetricName = "running_lora_adapters"
|
||||
LoraRequestInfoMaxAdaptersMetricName = "max_lora"
|
||||
// TODO: Replace these with the num_tokens_running/waiting below once we add those to the fork.
|
||||
RunningQueueSizeMetricName = "vllm:num_requests_running"
|
||||
WaitingQueueSizeMetricName = "vllm:num_requests_waiting"
|
||||
/* TODO: Uncomment this once the following are added to the fork.
|
||||
RunningQueueSizeMetricName = "vllm:num_tokens_running"
|
||||
WaitingQueueSizeMetricName = "vllm:num_tokens_waiting"
|
||||
*/
|
||||
KVCacheUsagePercentMetricName = "vllm:gpu_cache_usage_perc"
|
||||
KvCacheMaxTokenCapacityMetricName = "vllm:gpu_cache_max_token_capacity"
|
||||
)
|
||||
|
||||
// promToPodMetrics updates internal pod metrics with scraped prometheus metrics.
|
||||
// A combined error is returned if errors occur in one or more metric processing.
|
||||
// it returns a new PodMetrics pointer which can be used to atomically update the pod metrics map.
|
||||
func PromToPodMetrics(
|
||||
metricFamilies map[string]*dto.MetricFamily,
|
||||
existing *backend.PodMetrics,
|
||||
) (*backend.PodMetrics, error) {
|
||||
var errs error
|
||||
updated := existing.Clone()
|
||||
runningQueueSize, err := getLatestMetric(metricFamilies, RunningQueueSizeMetricName)
|
||||
errs = multierr.Append(errs, err)
|
||||
if err == nil {
|
||||
updated.RunningQueueSize = int(runningQueueSize.GetGauge().GetValue())
|
||||
}
|
||||
waitingQueueSize, err := getLatestMetric(metricFamilies, WaitingQueueSizeMetricName)
|
||||
errs = multierr.Append(errs, err)
|
||||
if err == nil {
|
||||
updated.WaitingQueueSize = int(waitingQueueSize.GetGauge().GetValue())
|
||||
}
|
||||
cachePercent, err := getLatestMetric(metricFamilies, KVCacheUsagePercentMetricName)
|
||||
errs = multierr.Append(errs, err)
|
||||
if err == nil {
|
||||
updated.KVCacheUsagePercent = cachePercent.GetGauge().GetValue()
|
||||
}
|
||||
|
||||
loraMetrics, _, err := getLatestLoraMetric(metricFamilies)
|
||||
errs = multierr.Append(errs, err)
|
||||
/* TODO: uncomment once this is available in vllm.
|
||||
kvCap, _, err := getGaugeLatestValue(metricFamilies, KvCacheMaxTokenCapacityMetricName)
|
||||
errs = multierr.Append(errs, err)
|
||||
if err != nil {
|
||||
updated.KvCacheMaxTokenCapacity = int(kvCap)
|
||||
}
|
||||
*/
|
||||
|
||||
if loraMetrics != nil {
|
||||
updated.ActiveModels = make(map[string]int)
|
||||
for _, label := range loraMetrics.GetLabel() {
|
||||
if label.GetName() == LoraRequestInfoRunningAdaptersMetricName {
|
||||
if label.GetValue() != "" {
|
||||
adapterList := strings.Split(label.GetValue(), ",")
|
||||
for _, adapter := range adapterList {
|
||||
updated.ActiveModels[adapter] = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
if label.GetName() == LoraRequestInfoMaxAdaptersMetricName {
|
||||
if label.GetValue() != "" {
|
||||
updated.MaxActiveModels, err = strconv.Atoi(label.GetValue())
|
||||
if err != nil {
|
||||
errs = multierr.Append(errs, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return updated, errs
|
||||
}
|
||||
|
||||
// getLatestLoraMetric gets latest lora metric series in gauge metric family `vllm:lora_requests_info`
|
||||
// reason its specially fetched is because each label key value pair permutation generates new series
|
||||
// and only most recent is useful. The value of each series is the creation timestamp so we can
|
||||
// retrieve the latest by sorting the value.
|
||||
func getLatestLoraMetric(metricFamilies map[string]*dto.MetricFamily) (*dto.Metric, time.Time, error) {
|
||||
loraRequests, ok := metricFamilies[LoraRequestInfoMetricName]
|
||||
if !ok {
|
||||
// klog.Warningf("metric family %q not found", LoraRequestInfoMetricName)
|
||||
return nil, time.Time{}, fmt.Errorf("metric family %q not found", LoraRequestInfoMetricName)
|
||||
}
|
||||
var latestTs float64
|
||||
var latest *dto.Metric
|
||||
for _, m := range loraRequests.GetMetric() {
|
||||
if m.GetGauge().GetValue() > latestTs {
|
||||
latestTs = m.GetGauge().GetValue()
|
||||
latest = m
|
||||
}
|
||||
}
|
||||
return latest, time.Unix(0, int64(latestTs*1000)), nil
|
||||
}
|
||||
|
||||
// getLatestMetric gets the latest metric of a family. This should be used to get the latest Gauge metric.
|
||||
// Since vllm doesn't set the timestamp in metric, this metric essentially gets the first metric.
|
||||
func getLatestMetric(metricFamilies map[string]*dto.MetricFamily, metricName string) (*dto.Metric, error) {
|
||||
mf, ok := metricFamilies[metricName]
|
||||
if !ok {
|
||||
// klog.Warningf("metric family %q not found", metricName)
|
||||
return nil, fmt.Errorf("metric family %q not found", metricName)
|
||||
}
|
||||
if len(mf.GetMetric()) == 0 {
|
||||
return nil, fmt.Errorf("no metrics available for %q", metricName)
|
||||
}
|
||||
var latestTs int64
|
||||
var latest *dto.Metric
|
||||
for _, m := range mf.GetMetric() {
|
||||
if m.GetTimestampMs() >= latestTs {
|
||||
latestTs = m.GetTimestampMs()
|
||||
latest = m
|
||||
}
|
||||
}
|
||||
// klog.V(logutil.TRACE).Infof("Got metric value %+v for metric %v", latest, metricName)
|
||||
return latest, nil
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package least_busy
|
||||
|
||||
import (
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/scheduling"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type LeastBusyLoadBalancer struct {
|
||||
criticalModels map[string]struct{}
|
||||
}
|
||||
|
||||
func NewLeastBusyLoadBalancer(json gjson.Result) (LeastBusyLoadBalancer, error) {
|
||||
lb := LeastBusyLoadBalancer{}
|
||||
lb.criticalModels = make(map[string]struct{})
|
||||
for _, model := range json.Get("criticalModels").Array() {
|
||||
lb.criticalModels[model.String()] = struct{}{}
|
||||
}
|
||||
return lb, nil
|
||||
}
|
||||
|
||||
// Callbacks which are called in request path
|
||||
func (lb LeastBusyLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
// If return types.ActionContinue, SetUpstreamOverrideHost will not take effect
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func (lb LeastBusyLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
requestModel := gjson.GetBytes(body, "model")
|
||||
if !requestModel.Exists() {
|
||||
return types.ActionContinue
|
||||
}
|
||||
_, isCritical := lb.criticalModels[requestModel.String()]
|
||||
llmReq := &scheduling.LLMRequest{
|
||||
Model: requestModel.String(),
|
||||
Critical: isCritical,
|
||||
}
|
||||
hostInfos, err := proxywasm.GetUpstreamHosts()
|
||||
if err != nil {
|
||||
return types.ActionContinue
|
||||
}
|
||||
hostMetrics := make(map[string]string)
|
||||
for _, hostInfo := range hostInfos {
|
||||
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
|
||||
hostMetrics[hostInfo[0]] = gjson.Get(hostInfo[1], "metrics").String()
|
||||
}
|
||||
}
|
||||
scheduler, err := scheduling.GetScheduler(hostMetrics)
|
||||
if err != nil {
|
||||
log.Debugf("initial scheduler failed: %v", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
targetPod, err := scheduler.Schedule(llmReq)
|
||||
log.Debugf("targetPod: %+v", targetPod.Address)
|
||||
if err != nil {
|
||||
log.Debugf("pod select failed: %v", err)
|
||||
proxywasm.SendHttpResponseWithDetail(429, "limited resources", nil, []byte("limited resources"), 0)
|
||||
} else {
|
||||
proxywasm.SetUpstreamOverrideHost([]byte(targetPod.Address))
|
||||
}
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb LeastBusyLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
ctx.DontReadResponseBody()
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb LeastBusyLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
func (lb LeastBusyLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
@@ -0,0 +1,203 @@
|
||||
/*
|
||||
Copyright 2025 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package scheduling
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
)
|
||||
|
||||
type Filter interface {
|
||||
Name() string
|
||||
Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)
|
||||
}
|
||||
|
||||
// filter applies current filterFunc, and then recursively applies next filters depending success or
|
||||
// failure of the current filterFunc.
|
||||
// It can be used to construct a flow chart algorithm.
|
||||
type filter struct {
|
||||
name string
|
||||
filter filterFunc
|
||||
// nextOnSuccess filter will be applied after successfully applying the current filter.
|
||||
// The filtered results will be passed to the next filter.
|
||||
nextOnSuccess *filter
|
||||
// nextOnFailure filter will be applied if current filter fails.
|
||||
// The original input will be passed to the next filter.
|
||||
nextOnFailure *filter
|
||||
// nextOnSuccessOrFailure is a convenience field to configure the next filter regardless of the
|
||||
// success or failure of the current filter.
|
||||
// NOTE: When using nextOnSuccessOrFailure, both nextOnSuccess and nextOnFailure SHOULD be nil.
|
||||
// However if that's not the case, nextOnSuccess and nextOnFailure will be used, instead of
|
||||
// nextOnSuccessOrFailure, in the success and failure scenarios, respectively.
|
||||
nextOnSuccessOrFailure *filter
|
||||
|
||||
// callbacks api.FilterCallbackHandler
|
||||
}
|
||||
|
||||
func (f *filter) Name() string {
|
||||
if f == nil {
|
||||
return "nil"
|
||||
}
|
||||
return f.name
|
||||
}
|
||||
|
||||
func (f *filter) Filter(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
|
||||
proxywasm.LogDebugf("Running filter %q on request %v with %v pods", f.name, req, len(pods))
|
||||
filtered, err := f.filter(req, pods)
|
||||
|
||||
next := f.nextOnSuccessOrFailure
|
||||
if err == nil && len(filtered) > 0 {
|
||||
if f.nextOnSuccess == nil && f.nextOnSuccessOrFailure == nil {
|
||||
// No succeeding filters to run, return.
|
||||
return filtered, err
|
||||
}
|
||||
if f.nextOnSuccess != nil {
|
||||
next = f.nextOnSuccess
|
||||
}
|
||||
// On success, pass the filtered result to the next filter.
|
||||
return next.Filter(req, filtered)
|
||||
} else {
|
||||
if f.nextOnFailure == nil && f.nextOnSuccessOrFailure == nil {
|
||||
// No succeeding filters to run, return.
|
||||
return filtered, err
|
||||
}
|
||||
if f.nextOnFailure != nil {
|
||||
next = f.nextOnFailure
|
||||
}
|
||||
// On failure, pass the initial set of pods to the next filter.
|
||||
return next.Filter(req, pods)
|
||||
}
|
||||
}
|
||||
|
||||
// filterFunc filters a set of input pods to a subset.
|
||||
type filterFunc func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error)
|
||||
|
||||
// toFilterFunc is a helper function to convert a per pod filter func to the FilterFunc.
|
||||
func toFilterFunc(pp podPredicate) filterFunc {
|
||||
return func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
|
||||
filtered := []*backend.PodMetrics{}
|
||||
for _, pod := range pods {
|
||||
pass := pp(req, pod)
|
||||
if pass {
|
||||
filtered = append(filtered, pod)
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return nil, errors.New("no pods left")
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
}
|
||||
|
||||
// leastQueuingFilterFunc finds the max and min queue size of all pods, divides the whole range
|
||||
// (max-min) by the number of pods, and finds the pods that fall into the first range.
|
||||
// The intuition is that if there are multiple pods that share similar queue size in the low range,
|
||||
// we should consider them all instead of the absolute minimum one. This worked better than picking
|
||||
// the least one as it gives more choices for the next filter, which on aggregate gave better
|
||||
// results.
|
||||
// TODO: Compare this strategy with other strategies such as top K.
|
||||
func leastQueuingFilterFunc(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
|
||||
min := math.MaxInt
|
||||
max := 0
|
||||
filtered := []*backend.PodMetrics{}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.WaitingQueueSize <= min {
|
||||
min = pod.WaitingQueueSize
|
||||
}
|
||||
if pod.WaitingQueueSize >= max {
|
||||
max = pod.WaitingQueueSize
|
||||
}
|
||||
}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.WaitingQueueSize >= min && pod.WaitingQueueSize <= min+(max-min)/len(pods) {
|
||||
filtered = append(filtered, pod)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
func lowQueueingPodPredicate(_ *LLMRequest, pod *backend.PodMetrics) bool {
|
||||
return pod.WaitingQueueSize < queueingThresholdLoRA
|
||||
}
|
||||
|
||||
// leastKVCacheFilterFunc finds the max and min KV cache of all pods, divides the whole range
|
||||
// (max-min) by the number of pods, and finds the pods that fall into the first range.
|
||||
// The intuition is that if there are multiple pods that share similar KV cache in the low range, we
|
||||
// should consider them all instead of the absolute minimum one. This worked better than picking the
|
||||
// least one as it gives more choices for the next filter, which on aggregate gave better results.
|
||||
// TODO: Compare this strategy with other strategies such as top K.
|
||||
func leastKVCacheFilterFunc(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
|
||||
min := math.MaxFloat64
|
||||
var max float64 = 0
|
||||
filtered := []*backend.PodMetrics{}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.KVCacheUsagePercent <= min {
|
||||
min = pod.KVCacheUsagePercent
|
||||
}
|
||||
if pod.KVCacheUsagePercent >= max {
|
||||
max = pod.KVCacheUsagePercent
|
||||
}
|
||||
}
|
||||
|
||||
for _, pod := range pods {
|
||||
if pod.KVCacheUsagePercent >= min && pod.KVCacheUsagePercent <= min+(max-min)/float64(len(pods)) {
|
||||
filtered = append(filtered, pod)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
// podPredicate is a filter function to check whether a pod is desired.
|
||||
type podPredicate func(req *LLMRequest, pod *backend.PodMetrics) bool
|
||||
|
||||
// We consider serving an adapter low cost it the adapter is active in the model server, or the
|
||||
// model server has room to load the adapter. The lowLoRACostPredicate ensures weak affinity by
|
||||
// spreading the load of a LoRA adapter across multiple pods, avoiding "pinning" all requests to
|
||||
// a single pod. This gave good performance in our initial benchmarking results in the scenario
|
||||
// where # of lora slots > # of lora adapters.
|
||||
func lowLoRACostPredicate(req *LLMRequest, pod *backend.PodMetrics) bool {
|
||||
_, ok := pod.ActiveModels[req.Model]
|
||||
return ok || len(pod.ActiveModels) < pod.MaxActiveModels
|
||||
}
|
||||
|
||||
// loRAAffinityPredicate is a filter function to check whether a pod has affinity to the lora requested.
|
||||
func loRAAffinityPredicate(req *LLMRequest, pod *backend.PodMetrics) bool {
|
||||
_, ok := pod.ActiveModels[req.Model]
|
||||
return ok
|
||||
}
|
||||
|
||||
// canAcceptNewLoraPredicate is a filter function to check whether a pod has room to load the adapter.
|
||||
func canAcceptNewLoraPredicate(req *LLMRequest, pod *backend.PodMetrics) bool {
|
||||
return len(pod.ActiveModels) < pod.MaxActiveModels
|
||||
}
|
||||
|
||||
func criticalRequestPredicate(req *LLMRequest, pod *backend.PodMetrics) bool {
|
||||
return req.Critical
|
||||
}
|
||||
|
||||
func noQueueAndLessThanKVCacheThresholdPredicate(queueThreshold int, kvCacheThreshold float64) podPredicate {
|
||||
return func(req *LLMRequest, pod *backend.PodMetrics) bool {
|
||||
return pod.WaitingQueueSize <= queueThreshold && pod.KVCacheUsagePercent <= kvCacheThreshold
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
/*
|
||||
Copyright 2025 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package scheduling implements request scheduling algorithms.
|
||||
package scheduling
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy/backend/vllm"
|
||||
|
||||
"github.com/prometheus/common/expfmt"
|
||||
)
|
||||
|
||||
const (
|
||||
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable.
|
||||
kvCacheThreshold = 0.8
|
||||
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable.
|
||||
queueThresholdCritical = 5
|
||||
// TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16) Make this configurable.
|
||||
// the threshold for queued requests to be considered low below which we can prioritize LoRA affinity.
|
||||
// The value of 50 is arrived heuristicically based on experiments.
|
||||
queueingThresholdLoRA = 50
|
||||
)
|
||||
|
||||
var (
|
||||
defaultFilter = &filter{
|
||||
name: "critical request",
|
||||
filter: toFilterFunc(criticalRequestPredicate),
|
||||
nextOnSuccess: lowLatencyFilter,
|
||||
nextOnFailure: sheddableRequestFilter,
|
||||
}
|
||||
|
||||
// queueLoRAAndKVCacheFilter applied least queue -> low cost lora -> least KV Cache filter
|
||||
queueLoRAAndKVCacheFilter = &filter{
|
||||
name: "least queuing",
|
||||
filter: leastQueuingFilterFunc,
|
||||
nextOnSuccessOrFailure: &filter{
|
||||
name: "low cost LoRA",
|
||||
filter: toFilterFunc(lowLoRACostPredicate),
|
||||
nextOnSuccessOrFailure: &filter{
|
||||
name: "least KV cache percent",
|
||||
filter: leastKVCacheFilterFunc,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// queueAndKVCacheFilter applies least queue followed by least KV Cache filter
|
||||
queueAndKVCacheFilter = &filter{
|
||||
name: "least queuing",
|
||||
filter: leastQueuingFilterFunc,
|
||||
nextOnSuccessOrFailure: &filter{
|
||||
name: "least KV cache percent",
|
||||
filter: leastKVCacheFilterFunc,
|
||||
},
|
||||
}
|
||||
|
||||
lowLatencyFilter = &filter{
|
||||
name: "low queueing filter",
|
||||
filter: toFilterFunc((lowQueueingPodPredicate)),
|
||||
nextOnSuccess: &filter{
|
||||
name: "affinity LoRA",
|
||||
filter: toFilterFunc(loRAAffinityPredicate),
|
||||
nextOnSuccess: queueAndKVCacheFilter,
|
||||
nextOnFailure: &filter{
|
||||
name: "can accept LoRA Adapter",
|
||||
filter: toFilterFunc(canAcceptNewLoraPredicate),
|
||||
nextOnSuccessOrFailure: queueAndKVCacheFilter,
|
||||
},
|
||||
},
|
||||
nextOnFailure: queueLoRAAndKVCacheFilter,
|
||||
}
|
||||
|
||||
sheddableRequestFilter = &filter{
|
||||
// When there is at least one model server that's not queuing requests, and still has KV
|
||||
// cache below a certain threshold, we consider this model server has capacity to handle
|
||||
// a sheddable request without impacting critical requests.
|
||||
name: "has capacity for sheddable requests",
|
||||
filter: toFilterFunc(noQueueAndLessThanKVCacheThresholdPredicate(queueThresholdCritical, kvCacheThreshold)),
|
||||
nextOnSuccess: queueLoRAAndKVCacheFilter,
|
||||
// If all pods are queuing or running above the KVCache threshold, we drop the sheddable
|
||||
// request to make room for critical requests.
|
||||
nextOnFailure: &filter{
|
||||
name: "drop request",
|
||||
filter: func(req *LLMRequest, pods []*backend.PodMetrics) ([]*backend.PodMetrics, error) {
|
||||
// api.LogDebugf("Dropping request %v", req)
|
||||
return []*backend.PodMetrics{}, errors.New("dropping request due to limited backend resources")
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func NewScheduler(pm []*backend.PodMetrics) *Scheduler {
|
||||
|
||||
return &Scheduler{
|
||||
podMetrics: pm,
|
||||
filter: defaultFilter,
|
||||
}
|
||||
}
|
||||
|
||||
type Scheduler struct {
|
||||
podMetrics []*backend.PodMetrics
|
||||
filter Filter
|
||||
}
|
||||
|
||||
// Schedule finds the target pod based on metrics and the requested lora adapter.
|
||||
func (s *Scheduler) Schedule(req *LLMRequest) (targetPod backend.Pod, err error) {
|
||||
pods, err := s.filter.Filter(req, s.podMetrics)
|
||||
if err != nil || len(pods) == 0 {
|
||||
return backend.Pod{}, fmt.Errorf("failed to apply filter, resulted %v pods: %w", len(pods), err)
|
||||
}
|
||||
i := rand.Intn(len(pods))
|
||||
return pods[i].Pod, nil
|
||||
}
|
||||
|
||||
func GetScheduler(hostMetrics map[string]string) (*Scheduler, error) {
|
||||
if len(hostMetrics) == 0 {
|
||||
return nil, errors.New("backend is not support llm scheduling")
|
||||
}
|
||||
var pms []*backend.PodMetrics
|
||||
for addr, metric := range hostMetrics {
|
||||
parser := expfmt.TextParser{}
|
||||
metricFamilies, err := parser.TextToMetricFamilies(strings.NewReader(metric))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pm := &backend.PodMetrics{
|
||||
Pod: backend.Pod{
|
||||
Name: addr,
|
||||
Address: addr,
|
||||
},
|
||||
Metrics: backend.Metrics{},
|
||||
}
|
||||
pm, err = vllm.PromToPodMetrics(metricFamilies, pm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pms = append(pms, pm)
|
||||
}
|
||||
return NewScheduler(pms), nil
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package scheduling
|
||||
|
||||
// LLMRequest is a structured representation of the fields we parse out of the LLMRequest body.
|
||||
type LLMRequest struct {
|
||||
Model string
|
||||
Critical bool
|
||||
}
|
||||
82
plugins/wasm-go/extensions/ai-load-balancer/main.go
Normal file
82
plugins/wasm-go/extensions/ai-load-balancer/main.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
global_least_request "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/global_least_request"
|
||||
least_busy "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/least_busy"
|
||||
prefix_cache "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/prefix_cache"
|
||||
)
|
||||
|
||||
func main() {}
|
||||
|
||||
func init() {
|
||||
wrapper.SetCtx(
|
||||
"ai-load-balancer",
|
||||
wrapper.ParseConfig(parseConfig),
|
||||
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
|
||||
wrapper.ProcessRequestBody(onHttpRequestBody),
|
||||
wrapper.ProcessResponseHeaders(onHttpResponseHeaders),
|
||||
wrapper.ProcessStreamingResponseBody(onHttpStreamingResponseBody),
|
||||
wrapper.ProcessResponseBody(onHttpResponseBody),
|
||||
)
|
||||
}
|
||||
|
||||
type LoadBalancer interface {
|
||||
HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action
|
||||
HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action
|
||||
HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action
|
||||
HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte
|
||||
HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
policy string
|
||||
lb LoadBalancer
|
||||
}
|
||||
|
||||
const (
|
||||
LeastBusyLoadBalancerPolicy = "least_busy"
|
||||
GlobalLeastRequestLoadBalancerPolicy = "global_least_request"
|
||||
PrefixCache = "prefix_cache"
|
||||
)
|
||||
|
||||
func parseConfig(json gjson.Result, config *Config) error {
|
||||
config.policy = json.Get("lb_policy").String()
|
||||
var err error
|
||||
switch config.policy {
|
||||
case LeastBusyLoadBalancerPolicy:
|
||||
config.lb, err = least_busy.NewLeastBusyLoadBalancer(json.Get("lb_config"))
|
||||
case GlobalLeastRequestLoadBalancerPolicy:
|
||||
config.lb, err = global_least_request.NewGlobalLeastRequestLoadBalancer(json.Get("lb_config"))
|
||||
case PrefixCache:
|
||||
config.lb, err = prefix_cache.NewPrefixCacheLoadBalancer(json.Get("lb_config"))
|
||||
default:
|
||||
err = fmt.Errorf("lb_policy %s is not supported", config.policy)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config Config) types.Action {
|
||||
return config.lb.HandleHttpRequestHeaders(ctx)
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config Config, body []byte) types.Action {
|
||||
return config.lb.HandleHttpRequestBody(ctx, body)
|
||||
}
|
||||
|
||||
func onHttpResponseHeaders(ctx wrapper.HttpContext, config Config) types.Action {
|
||||
return config.lb.HandleHttpResponseHeaders(ctx)
|
||||
}
|
||||
|
||||
func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config Config, data []byte, endOfStream bool) []byte {
|
||||
return config.lb.HandleHttpStreamingResponseBody(ctx, data, endOfStream)
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, config Config, body []byte) types.Action {
|
||||
return config.lb.HandleHttpResponseBody(ctx, body)
|
||||
}
|
||||
@@ -0,0 +1,302 @@
|
||||
package prefix_cache
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-load-balancer/utils"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/resp"
|
||||
)
|
||||
|
||||
const (
|
||||
RedisKeyFormat = "higress:global_least_request_table:%s:%s"
|
||||
RedisLua = `-- hex string => bytes
|
||||
local function hex_to_bytes(hex)
|
||||
local bytes = {}
|
||||
for i = 1, #hex, 2 do
|
||||
local byte_str = hex:sub(i, i+1)
|
||||
local byte_val = tonumber(byte_str, 16)
|
||||
table.insert(bytes, byte_val)
|
||||
end
|
||||
return bytes
|
||||
end
|
||||
|
||||
-- bytes => hex string
|
||||
local function bytes_to_hex(bytes)
|
||||
local result = ""
|
||||
for _, byte in ipairs(bytes) do
|
||||
result = result .. string.format("%02X", byte)
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
-- byte XOR
|
||||
local function byte_xor(a, b)
|
||||
local result = 0
|
||||
for i = 0, 7 do
|
||||
local bit_val = 2^i
|
||||
if ((a % (bit_val * 2)) >= bit_val) ~= ((b % (bit_val * 2)) >= bit_val) then
|
||||
result = result + bit_val
|
||||
end
|
||||
end
|
||||
return result
|
||||
end
|
||||
|
||||
-- hex string XOR
|
||||
local function hex_xor(a, b)
|
||||
if #a ~= #b then
|
||||
error("Hex strings must be of equal length, first is " .. a .. " second is " .. b)
|
||||
end
|
||||
|
||||
local a_bytes = hex_to_bytes(a)
|
||||
local b_bytes = hex_to_bytes(b)
|
||||
|
||||
local result_bytes = {}
|
||||
for i = 1, #a_bytes do
|
||||
table.insert(result_bytes, byte_xor(a_bytes[i], b_bytes[i]))
|
||||
end
|
||||
|
||||
return bytes_to_hex(result_bytes)
|
||||
end
|
||||
|
||||
-- check host whether healthy
|
||||
local function is_healthy(addr)
|
||||
for i = 4, #KEYS do
|
||||
if addr == KEYS[i] then
|
||||
return true
|
||||
end
|
||||
end
|
||||
return false
|
||||
end
|
||||
|
||||
local target = ""
|
||||
local key = ""
|
||||
local current_key = ""
|
||||
local count = #ARGV
|
||||
local ttl = KEYS[1]
|
||||
local hset_key = KEYS[2]
|
||||
local default_target = KEYS[3]
|
||||
|
||||
if count == 0 then
|
||||
return target
|
||||
end
|
||||
|
||||
-- find longest prefix
|
||||
local index = 1
|
||||
while index <= count do
|
||||
if current_key == "" then
|
||||
current_key = ARGV[index]
|
||||
else
|
||||
current_key = hex_xor(current_key, ARGV[index])
|
||||
end
|
||||
if redis.call("EXISTS", current_key) == 1 then
|
||||
key = current_key
|
||||
local tmp_target = redis.call("GET", key)
|
||||
if not is_healthy(tmp_target) then
|
||||
break
|
||||
end
|
||||
target = tmp_target
|
||||
-- update ttl for exist keys
|
||||
redis.call("EXPIRE", key, ttl)
|
||||
index = index + 1
|
||||
else
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
-- global least request
|
||||
if target == "" then
|
||||
index = 1
|
||||
local current_count = 0
|
||||
target = default_target
|
||||
if redis.call('HEXISTS', hset_key, target) ~= 0 then
|
||||
current_count = redis.call('HGET', hset_key, target)
|
||||
local hash = redis.call('HGETALL', hset_key)
|
||||
for i = 1, #hash, 2 do
|
||||
local addr = hash[i]
|
||||
local count = hash[i+1]
|
||||
if count < current_count and is_healthy(addr) then
|
||||
target = addr
|
||||
current_count = count
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- update request count
|
||||
redis.call("HINCRBY", hset_key, target, 1)
|
||||
|
||||
-- add tree-path
|
||||
while index <= count do
|
||||
if key == "" then
|
||||
key = ARGV[index]
|
||||
else
|
||||
key = hex_xor(key, ARGV[index])
|
||||
end
|
||||
redis.call("SET", key, target)
|
||||
redis.call("EXPIRE", key, ttl)
|
||||
index = index + 1
|
||||
end
|
||||
|
||||
return target`
|
||||
)
|
||||
|
||||
type PrefixCacheLoadBalancer struct {
|
||||
redisClient wrapper.RedisClient
|
||||
redisKeyTTL int
|
||||
}
|
||||
|
||||
func NewPrefixCacheLoadBalancer(json gjson.Result) (PrefixCacheLoadBalancer, error) {
|
||||
lb := PrefixCacheLoadBalancer{}
|
||||
serviceFQDN := json.Get("serviceFQDN").String()
|
||||
servicePort := json.Get("servicePort").Int()
|
||||
if serviceFQDN == "" || servicePort == 0 {
|
||||
log.Errorf("invalid redis service, serviceFQDN: %s, servicePort: %d", serviceFQDN, servicePort)
|
||||
return lb, errors.New("invalid redis service config")
|
||||
}
|
||||
lb.redisClient = wrapper.NewRedisClusterClient(wrapper.FQDNCluster{
|
||||
FQDN: serviceFQDN,
|
||||
Port: servicePort,
|
||||
})
|
||||
username := json.Get("username").String()
|
||||
password := json.Get("password").String()
|
||||
timeout := json.Get("timeout").Int()
|
||||
if timeout == 0 {
|
||||
timeout = 3000
|
||||
}
|
||||
// database default is 0
|
||||
database := json.Get("database").Int()
|
||||
if json.Get("redisKeyTTL").Int() == 0 {
|
||||
lb.redisKeyTTL = int(json.Get("redisKeyTTL").Int())
|
||||
} else {
|
||||
lb.redisKeyTTL = 1800
|
||||
}
|
||||
return lb, lb.redisClient.Init(username, password, int64(timeout), wrapper.WithDataBase(int(database)))
|
||||
}
|
||||
|
||||
func (lb PrefixCacheLoadBalancer) HandleHttpRequestHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
// If return types.ActionContinue, SetUpstreamOverrideHost will not take effect
|
||||
return types.HeaderStopIteration
|
||||
}
|
||||
|
||||
func (lb PrefixCacheLoadBalancer) HandleHttpRequestBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
var err error
|
||||
routeName, err := utils.GetRouteName()
|
||||
if err != nil || routeName == "" {
|
||||
ctx.SetContext("error", true)
|
||||
return types.ActionContinue
|
||||
} else {
|
||||
ctx.SetContext("routeName", routeName)
|
||||
}
|
||||
clusterName, err := utils.GetClusterName()
|
||||
if err != nil || clusterName == "" {
|
||||
ctx.SetContext("error", true)
|
||||
return types.ActionContinue
|
||||
} else {
|
||||
ctx.SetContext("clusterName", clusterName)
|
||||
}
|
||||
hostInfos, err := proxywasm.GetUpstreamHosts()
|
||||
if err != nil {
|
||||
ctx.SetContext("error", true)
|
||||
log.Error("get upstream cluster endpoints failed")
|
||||
return types.ActionContinue
|
||||
}
|
||||
healthyHosts := []string{}
|
||||
for _, hostInfo := range hostInfos {
|
||||
if gjson.Get(hostInfo[1], "health_status").String() == "Healthy" {
|
||||
healthyHosts = append(healthyHosts, hostInfo[0])
|
||||
}
|
||||
}
|
||||
if len(healthyHosts) == 0 {
|
||||
log.Info("upstream cluster has no healthy endpoints")
|
||||
return types.ActionContinue
|
||||
}
|
||||
defaultHost := healthyHosts[rand.Intn(len(healthyHosts))]
|
||||
params := []interface{}{}
|
||||
rawStr := ""
|
||||
messages := gjson.GetBytes(body, "messages").Array()
|
||||
for index, obj := range messages {
|
||||
if !obj.Get("role").Exists() || !obj.Get("content").Exists() {
|
||||
ctx.SetContext("error", true)
|
||||
log.Info("cannot extract role or content from request body, skip llm load balancing")
|
||||
return types.ActionContinue
|
||||
}
|
||||
role := obj.Get("role").String()
|
||||
content := obj.Get("content").String()
|
||||
rawStr += role + ":" + content
|
||||
if role == "user" || index == len(messages)-1 {
|
||||
sha1Str := computeSHA1(rawStr)
|
||||
params = append(params, sha1Str)
|
||||
rawStr = ""
|
||||
}
|
||||
}
|
||||
if len(params) == 0 {
|
||||
return types.ActionContinue
|
||||
}
|
||||
keys := []interface{}{lb.redisKeyTTL, fmt.Sprintf(RedisKeyFormat, routeName, clusterName), defaultHost}
|
||||
for _, v := range healthyHosts {
|
||||
keys = append(keys, v)
|
||||
}
|
||||
err = lb.redisClient.Eval(RedisLua, len(keys), keys, params, func(response resp.Value) {
|
||||
defer proxywasm.ResumeHttpRequest()
|
||||
if err := response.Error(); err != nil {
|
||||
ctx.SetContext("error", true)
|
||||
log.Errorf("Redis eval failed: %+v", err)
|
||||
return
|
||||
}
|
||||
hostSelected := response.String()
|
||||
if err := proxywasm.SetUpstreamOverrideHost([]byte(hostSelected)); err != nil {
|
||||
ctx.SetContext("error", true)
|
||||
log.Errorf("override upstream host failed, fallback to default lb policy, error informations: %+v", err)
|
||||
}
|
||||
log.Debugf("host_selected: %s", hostSelected)
|
||||
ctx.SetContext("host_selected", hostSelected)
|
||||
})
|
||||
if err != nil {
|
||||
ctx.SetContext("error", true)
|
||||
return types.ActionContinue
|
||||
}
|
||||
return types.ActionPause
|
||||
}
|
||||
|
||||
func (lb PrefixCacheLoadBalancer) HandleHttpResponseHeaders(ctx wrapper.HttpContext) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func (lb PrefixCacheLoadBalancer) HandleHttpStreamingResponseBody(ctx wrapper.HttpContext, data []byte, endOfStream bool) []byte {
|
||||
if endOfStream {
|
||||
isErr, _ := ctx.GetContext("error").(bool)
|
||||
if !isErr {
|
||||
routeName, _ := ctx.GetContext("routeName").(string)
|
||||
clusterName, _ := ctx.GetContext("clusterName").(string)
|
||||
host_selected, _ := ctx.GetContext("host_selected").(string)
|
||||
if host_selected == "" {
|
||||
log.Errorf("get host_selected failed")
|
||||
} else {
|
||||
lb.redisClient.HIncrBy(fmt.Sprintf(RedisKeyFormat, routeName, clusterName), host_selected, -1, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func (lb PrefixCacheLoadBalancer) HandleHttpResponseBody(ctx wrapper.HttpContext, body []byte) types.Action {
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func computeSHA1(data string) string {
|
||||
hasher := sha1.New()
|
||||
hasher.Write([]byte(data))
|
||||
return strings.ToUpper(hex.EncodeToString(hasher.Sum(nil)))
|
||||
}
|
||||
19
plugins/wasm-go/extensions/ai-load-balancer/utils/utils.go
Normal file
19
plugins/wasm-go/extensions/ai-load-balancer/utils/utils.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package utils
|
||||
|
||||
import "github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
|
||||
func GetRouteName() (string, error) {
|
||||
if raw, err := proxywasm.GetProperty([]string{"route_name"}); err != nil {
|
||||
return "", err
|
||||
} else {
|
||||
return string(raw), nil
|
||||
}
|
||||
}
|
||||
|
||||
func GetClusterName() (string, error) {
|
||||
if raw, err := proxywasm.GetProperty([]string{"cluster_name"}); err != nil {
|
||||
return "", err
|
||||
} else {
|
||||
return string(raw), nil
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
---
|
||||
title: AI 代理
|
||||
keywords: [ AI网关, AI代理 ]
|
||||
keywords: [AI网关, AI代理]
|
||||
description: AI 代理插件配置参考
|
||||
---
|
||||
|
||||
@@ -20,53 +20,50 @@ description: AI 代理插件配置参考
|
||||
插件执行阶段:`默认阶段`
|
||||
插件执行优先级:`100`
|
||||
|
||||
|
||||
## 配置字段
|
||||
|
||||
### 基本配置
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|------------|--------|------|-----|------------------|
|
||||
| `provider` | object | 必填 | - | 配置目标 AI 服务提供商的信息 |
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ---------- | -------- | -------- | ------ | ---------------------------- |
|
||||
| `provider` | object | 必填 | - | 配置目标 AI 服务提供商的信息 |
|
||||
|
||||
`provider`的配置字段说明如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|------------------| --------------- | -------- | ------ |-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `type` | string | 必填 | - | AI 服务提供商名称 |
|
||||
| `apiTokens` | array of string | 非必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 |
|
||||
| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟。此项配置目前仅用于获取上下文信息,并不影响实际转发大模型请求。 |
|
||||
| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。<br/>1. 支持前缀匹配。例如用 "gpt-3-\*" 匹配所有名称以“gpt-3-”开头的模型;<br/>2. 支持使用 "\*" 为键来配置通用兜底映射关系;<br/>3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。 |
|
||||
| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) |
|
||||
| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
|
||||
| `customSettings` | array of customSetting | 非必填 | - | 为AI请求指定覆盖或者填充参数 |
|
||||
| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 |
|
||||
| `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 |
|
||||
| `reasoningContentMode` | string | 非必填 | - | 如何处理大模型服务返回的推理内容。目前支持以下取值:passthrough(正常输出推理内容)、ignore(不输出推理内容)、concat(将推理内容拼接在常规输出内容之前)。默认为 passthrough。仅支持通义千问服务。 |
|
||||
| `capabilities` | map of string | 非必填 | - | 部分provider的部分ai能力原生兼容openai/v1格式,不需要重写,可以直接转发,通过此配置项指定来开启转发, key表示的是采用的厂商协议能力,values表示的真实的厂商该能力的api path, 厂商协议能力当前支持: openai/v1/chatcompletions, openai/v1/embeddings, openai/v1/imagegeneration, openai/v1/audiospeech, cohere/v1/rerank |
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ---------------------- | ---------------------- | -------- | ------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `type` | string | 必填 | - | AI 服务提供商名称 |
|
||||
| `apiTokens` | array of string | 非必填 | - | 用于在访问 AI 服务时进行认证的令牌。如果配置了多个 token,插件会在请求时随机进行选择。部分服务提供商只支持配置一个 token。 |
|
||||
| `timeout` | number | 非必填 | - | 访问 AI 服务的超时时间。单位为毫秒。默认值为 120000,即 2 分钟。此项配置目前仅用于获取上下文信息,并不影响实际转发大模型请求。 |
|
||||
| `modelMapping` | map of string | 非必填 | - | AI 模型映射表,用于将请求中的模型名称映射为服务提供商支持模型名称。<br/>1. 支持前缀匹配。例如用 "gpt-3-\*" 匹配所有名称以“gpt-3-”开头的模型;<br/>2. 支持使用 "\*" 为键来配置通用兜底映射关系;<br/>3. 如果映射的目标名称为空字符串 "",则表示保留原模型名称。<br/>4. 支持以 `~` 前缀使用正则匹配。例如用 "~gpt(.\*)" 匹配所有以 "gpt" 开头的模型并支持在目标模型中使用 capture group 引用匹配到的内容。示例: "~gpt(.\*): openai/gpt\$1" |
|
||||
| `protocol` | string | 非必填 | - | 插件对外提供的 API 接口契约。目前支持以下取值:openai(默认值,使用 OpenAI 的接口契约)、original(使用目标服务提供商的原始接口契约) |
|
||||
| `context` | object | 非必填 | - | 配置 AI 对话上下文信息 |
|
||||
| `customSettings` | array of customSetting | 非必填 | - | 为 AI 请求指定覆盖或者填充参数 |
|
||||
| `failover` | object | 非必填 | - | 配置 apiToken 的 failover 策略,当 apiToken 不可用时,将其移出 apiToken 列表,待健康检测通过后重新添加回 apiToken 列表 |
|
||||
| `retryOnFailure` | object | 非必填 | - | 当请求失败时立即进行重试 |
|
||||
| `reasoningContentMode` | string | 非必填 | - | 如何处理大模型服务返回的推理内容。目前支持以下取值:passthrough(正常输出推理内容)、ignore(不输出推理内容)、concat(将推理内容拼接在常规输出内容之前)。默认为 passthrough。仅支持通义千问服务。 |
|
||||
| `capabilities` | map of string | 非必填 | - | 部分 provider 的部分 ai 能力原生兼容 openai/v1 格式,不需要重写,可以直接转发,通过此配置项指定来开启转发, key 表示的是采用的厂商协议能力,values 表示的真实的厂商该能力的 api path, 厂商协议能力当前支持: openai/v1/chatcompletions, openai/v1/embeddings, openai/v1/imagegeneration, openai/v1/audiospeech, cohere/v1/rerank |
|
||||
| `subPath` | string | 非必填 | - | 如果配置了subPath,将会先移除请求path中该前缀,再进行后续处理 |
|
||||
|
||||
`context`的配置字段说明如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|---------------|--------|------|-----|----------------------------------|
|
||||
| `fileUrl` | string | 必填 | - | 保存 AI 对话上下文的文件 URL。仅支持纯文本类型的文件内容 |
|
||||
| `serviceName` | string | 必填 | - | URL 所对应的 Higress 后端服务完整名称 |
|
||||
| `servicePort` | number | 必填 | - | URL 所对应的 Higress 后端服务访问端口 |
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ------------- | -------- | -------- | ------ | -------------------------------------------------------- |
|
||||
| `fileUrl` | string | 必填 | - | 保存 AI 对话上下文的文件 URL。仅支持纯文本类型的文件内容 |
|
||||
| `serviceName` | string | 必填 | - | URL 所对应的 Higress 后端服务完整名称 |
|
||||
| `servicePort` | number | 必填 | - | URL 所对应的 Higress 后端服务访问端口 |
|
||||
|
||||
`customSettings`的配置字段说明如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ----------- | --------------------- | -------- | ------ | ---------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `name` | string | 必填 | - | 想要设置的参数的名称,例如`max_tokens` |
|
||||
| `value` | string/int/float/bool | 必填 | - | 想要设置的参数的值,例如0 |
|
||||
| `value` | string/int/float/bool | 必填 | - | 想要设置的参数的值,例如 0 |
|
||||
| `mode` | string | 非必填 | "auto" | 参数设置的模式,可以设置为"auto"或者"raw",如果为"auto"则会自动根据协议对参数名做改写,如果为"raw"则不会有任何改写和限制检查 |
|
||||
| `overwrite` | bool | 非必填 | true | 如果为false则只在用户没有设置这个参数时填充参数,否则会直接覆盖用户原有的参数设置 |
|
||||
|
||||
|
||||
custom-setting会遵循如下表格,根据`name`和协议来替换对应的字段,用户需要填写表格中`settingName`列中存在的值。例如用户将`name`设置为`max_tokens`,在openai协议中会替换`max_tokens`,在gemini中会替换`maxOutputTokens`。
|
||||
`none`表示该协议不支持此参数。如果`name`不在此表格中或者对应协议不支持此参数,同时没有设置raw模式,则配置不会生效。
|
||||
| `overwrite` | bool | 非必填 | true | 如果为 false 则只在用户没有设置这个参数时填充参数,否则会直接覆盖用户原有的参数设置 |
|
||||
|
||||
custom-setting 会遵循如下表格,根据`name`和协议来替换对应的字段,用户需要填写表格中`settingName`列中存在的值。例如用户将`name`设置为`max_tokens`,在 openai 协议中会替换`max_tokens`,在 gemini 中会替换`maxOutputTokens`。
|
||||
`none`表示该协议不支持此参数。如果`name`不在此表格中或者对应协议不支持此参数,同时没有设置 raw 模式,则配置不会生效。
|
||||
|
||||
| settingName | openai | baidu | spark | qwen | gemini | hunyuan | claude | minimax |
|
||||
| ----------- | ----------- | ----------------- | ----------- | ----------- | --------------- | ----------- | ----------- | ------------------ |
|
||||
@@ -76,32 +73,31 @@ custom-setting会遵循如下表格,根据`name`和协议来替换对应的字
|
||||
| top_k | none | none | top_k | none | topK | none | top_k | none |
|
||||
| seed | seed | none | none | seed | none | none | none | none |
|
||||
|
||||
如果启用了raw模式,custom-setting会直接用输入的`name`和`value`去更改请求中的json内容,而不对参数名称做任何限制和修改。
|
||||
对于大多数协议,custom-setting都会在json内容的根路径修改或者填充参数。对于`qwen`协议,ai-proxy会在json的`parameters`子路径下做配置。对于`gemini`协议,则会在`generation_config`子路径下做配置。
|
||||
如果启用了 raw 模式,custom-setting 会直接用输入的`name`和`value`去更改请求中的 json 内容,而不对参数名称做任何限制和修改。
|
||||
对于大多数协议,custom-setting 都会在 json 内容的根路径修改或者填充参数。对于`qwen`协议,ai-proxy 会在 json 的`parameters`子路径下做配置。对于`gemini`协议,则会在`generation_config`子路径下做配置。
|
||||
|
||||
`failover` 的配置字段说明如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|------------------|--------|-----------------|-------|-----------------------------------|
|
||||
| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 |
|
||||
| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) |
|
||||
| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) |
|
||||
| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 |
|
||||
| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 |
|
||||
| healthCheckModel | string | 启用 failover 时必填 | | 健康检测使用的模型 |
|
||||
| failoverOnStatus | array of string | 非必填 | ["4.*", "5.*"] | 需要进行 failover 的原始请求的状态码,支持正则表达式匹配 |
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ------------------- | --------------- | -------------------- | -------------- | -------------------------------------------------------- |
|
||||
| enabled | bool | 非必填 | false | 是否启用 apiToken 的 failover 机制 |
|
||||
| failureThreshold | int | 非必填 | 3 | 触发 failover 连续请求失败的阈值(次数) |
|
||||
| successThreshold | int | 非必填 | 1 | 健康检测的成功阈值(次数) |
|
||||
| healthCheckInterval | int | 非必填 | 5000 | 健康检测的间隔时间,单位毫秒 |
|
||||
| healthCheckTimeout | int | 非必填 | 5000 | 健康检测的超时时间,单位毫秒 |
|
||||
| healthCheckModel | string | 启用 failover 时必填 | | 健康检测使用的模型 |
|
||||
| failoverOnStatus | array of string | 非必填 | ["4.*", "5.*"] | 需要进行 failover 的原始请求的状态码,支持正则表达式匹配 |
|
||||
|
||||
`retryOnFailure` 的配置字段说明如下:
|
||||
|
||||
目前仅支持对非流式请求进行重试。
|
||||
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|------------------|--------|--------|-------|---------------------------|
|
||||
| enabled | bool | 非必填 | false | 是否启用失败请求重试 |
|
||||
| maxRetries | int | 非必填 | 1 | 最大重试次数 |
|
||||
| retryTimeout | int | 非必填 | 30000 | 重试超时时间,单位毫秒 |
|
||||
| retryOnStatus | array of string | 非必填 | ["4.*", "5.*"] | 需要进行重试的原始请求的状态码,支持正则表达式匹配 |
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ------------- | --------------- | -------- | -------------- | -------------------------------------------------- |
|
||||
| enabled | bool | 非必填 | false | 是否启用失败请求重试 |
|
||||
| maxRetries | int | 非必填 | 1 | 最大重试次数 |
|
||||
| retryTimeout | int | 非必填 | 30000 | 重试超时时间,单位毫秒 |
|
||||
| retryOnStatus | array of string | 非必填 | ["4.*", "5.*"] | 需要进行重试的原始请求的状态码,支持正则表达式匹配 |
|
||||
|
||||
### 提供商特有配置
|
||||
|
||||
@@ -109,19 +105,18 @@ custom-setting会遵循如下表格,根据`name`和协议来替换对应的字
|
||||
|
||||
OpenAI 所对应的 `type` 为 `openai`。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|-------------------|----------|----------|--------|-------------------------------------------------------------------------------|
|
||||
| `openaiCustomUrl` | string | 非必填 | - | 基于OpenAI协议的自定义后端URL,例如: www.example.com/myai/v1/chat/completions |
|
||||
| `responseJsonSchema` | object | 非必填 | - | 预先定义OpenAI响应需满足的Json Schema, 注意目前仅特定的几种模型支持该用法|
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| -------------------- | -------- | -------- | ------ | ---------------------------------------------------------------------------------- |
|
||||
| `openaiCustomUrl` | string | 非必填 | - | 基于 OpenAI 协议的自定义后端 URL,例如: <www.example.com/myai/v1/chat/completions> |
|
||||
| `responseJsonSchema` | object | 非必填 | - | 预先定义 OpenAI 响应需满足的 Json Schema, 注意目前仅特定的几种模型支持该用法 |
|
||||
|
||||
#### Azure OpenAI
|
||||
|
||||
Azure OpenAI 所对应的 `type` 为 `azure`。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|-------------------|--------|------|-----|----------------------------------------------|
|
||||
| `azureServiceUrl` | string | 必填 | - | Azure OpenAI 服务的 URL,须包含 `api-version` 查询参数。 |
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ----------------- | -------- | -------- | ------ | -------------------------------------------------------- |
|
||||
| `azureServiceUrl` | string | 必填 | - | Azure OpenAI 服务的 URL,须包含 `api-version` 查询参数。 |
|
||||
|
||||
**注意:** Azure OpenAI 只支持配置一个 API Token。
|
||||
|
||||
@@ -129,19 +124,19 @@ Azure OpenAI 所对应的 `type` 为 `azure`。它特有的配置字段如下:
|
||||
|
||||
月之暗面所对应的 `type` 为 `moonshot`。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|------------------|--------|------|-----|-------------------------------------------------------------|
|
||||
| `moonshotFileId` | string | 非必填 | - | 通过文件接口上传至月之暗面的文件 ID,其内容将被用做 AI 对话的上下文。不可与 `context` 字段同时配置。 |
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ---------------- | -------- | -------- | ------ | ---------------------------------------------------------------------------------------------------- |
|
||||
| `moonshotFileId` | string | 非必填 | - | 通过文件接口上传至月之暗面的文件 ID,其内容将被用做 AI 对话的上下文。不可与 `context` 字段同时配置。 |
|
||||
|
||||
#### 通义千问(Qwen)
|
||||
|
||||
通义千问所对应的 `type` 为 `qwen`。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ---------------------- | --------------- | -------- | ------ | ------------------------------------------------------------ |
|
||||
| `qwenEnableSearch` | boolean | 非必填 | - | 是否启用通义千问内置的互联网搜索功能。 |
|
||||
| `qwenFileIds` | array of string | 非必填 | - | 通过文件接口上传至Dashscope的文件 ID,其内容将被用做 AI 对话的上下文。不可与 `context` 字段同时配置。 |
|
||||
| `qwenEnableCompatible` | boolean | 非必填 | false | 开启通义千问兼容模式。启用通义千问兼容模式后,将调用千问的兼容模式接口,同时对请求/响应不做修改。 |
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ---------------------- | --------------- | -------- | ------ | ------------------------------------------------------------------------------------------------------- |
|
||||
| `qwenEnableSearch` | boolean | 非必填 | - | 是否启用通义千问内置的互联网搜索功能。 |
|
||||
| `qwenFileIds` | array of string | 非必填 | - | 通过文件接口上传至 Dashscope 的文件 ID,其内容将被用做 AI 对话的上下文。不可与 `context` 字段同时配置。 |
|
||||
| `qwenEnableCompatible` | boolean | 非必填 | false | 开启通义千问兼容模式。启用通义千问兼容模式后,将调用千问的兼容模式接口,同时对请求/响应不做修改。 |
|
||||
|
||||
#### 百川智能 (Baichuan AI)
|
||||
|
||||
@@ -151,13 +146,13 @@ Azure OpenAI 所对应的 `type` 为 `azure`。它特有的配置字段如下:
|
||||
|
||||
零一万物所对应的 `type` 为 `yi`。它并无特有的配置字段。
|
||||
|
||||
#### 智谱AI(Zhipu AI)
|
||||
#### 智谱 AI(Zhipu AI)
|
||||
|
||||
智谱AI所对应的 `type` 为 `zhipuai`。它并无特有的配置字段。
|
||||
智谱 AI 所对应的 `type` 为 `zhipuai`。它并无特有的配置字段。
|
||||
|
||||
#### DeepSeek(DeepSeek)
|
||||
|
||||
DeepSeek所对应的 `type` 为 `deepseek`。它并无特有的配置字段。
|
||||
DeepSeek 所对应的 `type` 为 `deepseek`。它并无特有的配置字段。
|
||||
|
||||
#### Groq
|
||||
|
||||
@@ -167,13 +162,13 @@ Groq 所对应的 `type` 为 `groq`。它并无特有的配置字段。
|
||||
|
||||
文心一言所对应的 `type` 为 `baidu`。它并无特有的配置字段。
|
||||
|
||||
#### 360智脑
|
||||
#### 360 智脑
|
||||
|
||||
360智脑所对应的 `type` 为 `ai360`。它并无特有的配置字段。
|
||||
360 智脑所对应的 `type` 为 `ai360`。它并无特有的配置字段。
|
||||
|
||||
#### GitHub模型
|
||||
#### GitHub 模型
|
||||
|
||||
GitHub模型所对应的 `type` 为 `github`。它并无特有的配置字段。
|
||||
GitHub 模型所对应的 `type` 为 `github`。它并无特有的配置字段。
|
||||
|
||||
#### Mistral
|
||||
|
||||
@@ -181,38 +176,38 @@ Mistral 所对应的 `type` 为 `mistral`。它并无特有的配置字段。
|
||||
|
||||
#### MiniMax
|
||||
|
||||
MiniMax所对应的 `type` 为 `minimax`。它特有的配置字段如下:
|
||||
MiniMax 所对应的 `type` 为 `minimax`。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ---------------- | -------- | ------------------------------ | ------ |----------------------------------------------------------------|
|
||||
| `minimaxApiType` | string | v2 和 pro 中选填一项 | v2 | v2 代表 ChatCompletion v2 API,pro 代表 ChatCompletion Pro API |
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ---------------- | -------- | ------------------------------ | ------ | ----------------------------------------------------------------------- |
|
||||
| `minimaxApiType` | string | v2 和 pro 中选填一项 | v2 | v2 代表 ChatCompletion v2 API,pro 代表 ChatCompletion Pro API |
|
||||
| `minimaxGroupId` | string | `minimaxApiType` 为 pro 时必填 | - | `minimaxApiType` 为 pro 时使用 ChatCompletion Pro API,需要设置 groupID |
|
||||
|
||||
#### Anthropic Claude
|
||||
|
||||
Anthropic Claude 所对应的 `type` 为 `claude`。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|-----------|--------|------|-----|----------------------------------|
|
||||
| `claudeVersion` | string | 可选 | - | Claude 服务的 API 版本,默认为 2023-06-01 |
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| --------------- | -------- | -------- | ------ | ----------------------------------------- |
|
||||
| `claudeVersion` | string | 可选 | - | Claude 服务的 API 版本,默认为 2023-06-01 |
|
||||
|
||||
#### Ollama
|
||||
|
||||
Ollama 所对应的 `type` 为 `ollama`。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|-------------------|--------|------|-----|----------------------------------------------|
|
||||
| `ollamaServerHost` | string | 必填 | - | Ollama 服务器的主机地址 |
|
||||
| `ollamaServerPort` | number | 必填 | - | Ollama 服务器的端口号,默认为11434 |
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ------------------ | -------- | -------- | ------ | ----------------------------------- |
|
||||
| `ollamaServerHost` | string | 必填 | - | Ollama 服务器的主机地址 |
|
||||
| `ollamaServerPort` | number | 必填 | - | Ollama 服务器的端口号,默认为 11434 |
|
||||
|
||||
#### 混元
|
||||
|
||||
混元所对应的 `type` 为 `hunyuan`。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|-------------------|--------|------|-----|----------------------------------------------|
|
||||
| `hunyuanAuthId` | string | 必填 | - | 混元用于v3版本认证的id |
|
||||
| `hunyuanAuthKey` | string | 必填 | - | 混元用于v3版本认证的key |
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ---------------- | -------- | -------- | ------ | -------------------------- |
|
||||
| `hunyuanAuthId` | string | 必填 | - | 混元用于 v3 版本认证的 id |
|
||||
| `hunyuanAuthKey` | string | 必填 | - | 混元用于 v3 版本认证的 key |
|
||||
|
||||
#### 阶跃星辰 (Stepfun)
|
||||
|
||||
@@ -222,23 +217,24 @@ Ollama 所对应的 `type` 为 `ollama`。它特有的配置字段如下:
|
||||
|
||||
Cloudflare Workers AI 所对应的 `type` 为 `cloudflare`。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|-------------------|--------|------|-----|----------------------------------------------------------------------------------------------------------------------------|
|
||||
| `cloudflareAccountId` | string | 必填 | - | [Cloudflare Account ID](https://developers.cloudflare.com/workers-ai/get-started/rest-api/#1-get-api-token-and-account-id) |
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| --------------------- | -------- | -------- | ------ | -------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `cloudflareAccountId` | string | 必填 | - | [Cloudflare Account ID](https://developers.cloudflare.com/workers-ai/get-started/rest-api/#1-get-api-token-and-account-id) |
|
||||
|
||||
#### 星火 (Spark)
|
||||
|
||||
星火所对应的 `type` 为 `spark`。它并无特有的配置字段。
|
||||
|
||||
讯飞星火认知大模型的`apiTokens`字段值为`APIKey:APISecret`。即填入自己的APIKey与APISecret,并以`:`分隔。
|
||||
讯飞星火认知大模型的`apiTokens`字段值为`APIKey:APISecret`。即填入自己的 APIKey 与 APISecret,并以`:`分隔。
|
||||
|
||||
#### Gemini
|
||||
|
||||
Gemini 所对应的 `type` 为 `gemini`。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| --------------------- | -------- | -------- |-----|-------------------------------------------------------------------------------------------------|
|
||||
| `geminiSafetySetting` | map of string | 非必填 | - | Gemini AI内容过滤和安全级别设定。参考[Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) |
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| --------------------- | ------------- | -------- | -------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| `geminiSafetySetting` | map of string | 非必填 | - | Gemini AI 内容过滤和安全级别设定。参考[Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) |
|
||||
| `apiVersion` | string | 非必填 | `v1beta` | 用于指定 API 的版本, 可选择 `v1` 或 `v1beta` 。 版本差异请参考[API versions explained](https://ai.google.dev/gemini-api/docs/api-versions)。 |
|
||||
|
||||
#### DeepL
|
||||
|
||||
@@ -253,18 +249,43 @@ DeepL 所对应的 `type` 为 `deepl`。它特有的配置字段如下:
|
||||
Cohere 所对应的 `type` 为 `cohere`。它并无特有的配置字段。
|
||||
|
||||
#### Together-AI
|
||||
|
||||
Together-AI 所对应的 `type` 为 `together-ai`。它并无特有的配置字段。
|
||||
|
||||
#### Dify
|
||||
|
||||
Dify 所对应的 `type` 为 `dify`。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| -- | -------- |------| ------ | ---------------------------- |
|
||||
| `difyApiUrl` | string | 非必填 | - | dify私有化部署的url |
|
||||
| `botType` | string | 非必填 | - | dify的应用类型,Chat/Completion/Agent/Workflow |
|
||||
| `inputVariable` | string | 非必填 | - | dify中应用类型为workflow时需要设置输入变量,当botType为workflow时一起使用 |
|
||||
| `outputVariable` | string | 非必填 | - | dify中应用类型为workflow时需要设置输出变量,当botType为workflow时一起使用 |
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ---------------- | -------- | -------- | ------ | -------------------------------------------------------------------------------- |
|
||||
| `difyApiUrl` | string | 非必填 | - | dify 私有化部署的 url |
|
||||
| `botType` | string | 非必填 | - | dify 的应用类型,Chat/Completion/Agent/Workflow |
|
||||
| `inputVariable` | string | 非必填 | - | dify 中应用类型为 workflow 时需要设置输入变量,当 botType 为 workflow 时一起使用 |
|
||||
| `outputVariable` | string | 非必填 | - | dify 中应用类型为 workflow 时需要设置输出变量,当 botType 为 workflow 时一起使用 |
|
||||
|
||||
#### Google Vertex AI
|
||||
|
||||
Google Vertex AI 所对应的 type 为 vertex。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|-----------------------------|---------------|--------|--------|-------------------------------------------------------------------------------|
|
||||
| `vertexAuthKey` | string | 必填 | - | 用于认证的 Google Service Account JSON Key,格式为 PEM 编码的 PKCS#8 私钥和 client_email 等信息 |
|
||||
| `vertexRegion` | string | 必填 | - | Google Cloud 区域(如 us-central1, europe-west4 等),用于构建 Vertex API 地址 |
|
||||
| `vertexProjectId` | string | 必填 | - | Google Cloud 项目 ID,用于标识目标 GCP 项目 |
|
||||
| `vertexAuthServiceName` | string | 必填 | - | 用于 OAuth2 认证的服务名称,该服务为了访问oauth2.googleapis.com |
|
||||
| `geminiSafetySetting` | map of string | 非必填 | - | Gemini AI 内容过滤和安全级别设定。参考[Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) |
|
||||
| `vertexTokenRefreshAhead` | number | 非必填 | - | Vertex access token刷新提前时间(单位秒) |
|
||||
|
||||
#### AWS Bedrock
|
||||
|
||||
AWS Bedrock 所对应的 type 为 bedrock。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
|---------------------------|--------|------|-----|------------------------------|
|
||||
| `awsAccessKey` | string | 必填 | - | AWS Access Key,用于身份认证 |
|
||||
| `awsSecretKey` | string | 必填 | - | AWS Secret Access Key,用于身份认证 |
|
||||
| `awsRegion` | string | 必填 | - | AWS 区域,例如:us-east-1 |
|
||||
| `bedrockAdditionalFields` | map | 非必填 | - | Bedrock 额外模型请求参数 |
|
||||
|
||||
## 用法示例
|
||||
|
||||
@@ -376,20 +397,20 @@ provider:
|
||||
provider:
|
||||
type: qwen
|
||||
apiTokens:
|
||||
- "YOUR_QWEN_API_TOKEN"
|
||||
- 'YOUR_QWEN_API_TOKEN'
|
||||
modelMapping:
|
||||
'gpt-3': "qwen-turbo"
|
||||
'gpt-35-turbo': "qwen-plus"
|
||||
'gpt-4-turbo': "qwen-max"
|
||||
'gpt-4-*': "qwen-max"
|
||||
'gpt-4o': "qwen-vl-plus"
|
||||
'gpt-3': 'qwen-turbo'
|
||||
'gpt-35-turbo': 'qwen-plus'
|
||||
'gpt-4-turbo': 'qwen-max'
|
||||
'gpt-4-*': 'qwen-max'
|
||||
'gpt-4o': 'qwen-vl-plus'
|
||||
'text-embedding-v1': 'text-embedding-v1'
|
||||
'*': "qwen-turbo"
|
||||
'*': 'qwen-turbo'
|
||||
```
|
||||
|
||||
**AI 对话请求示例**
|
||||
|
||||
URL: http://your-domain/v1/chat/completions
|
||||
URL: <http://your-domain/v1/chat/completions>
|
||||
|
||||
请求示例:
|
||||
|
||||
@@ -434,7 +455,7 @@ URL: http://your-domain/v1/chat/completions
|
||||
|
||||
**多模态模型 API 请求示例(适用于 `qwen-vl-plus` 和 `qwen-vl-max` 模型)**
|
||||
|
||||
URL: http://your-domain/v1/chat/completions
|
||||
URL: <http://your-domain/v1/chat/completions>
|
||||
|
||||
请求示例:
|
||||
|
||||
@@ -493,7 +514,7 @@ URL: http://your-domain/v1/chat/completions
|
||||
|
||||
**文本向量请求示例**
|
||||
|
||||
URL: http://your-domain/v1/embeddings
|
||||
URL: <http://your-domain/v1/embeddings>
|
||||
|
||||
请求示例:
|
||||
|
||||
@@ -606,12 +627,12 @@ provider:
|
||||
provider:
|
||||
type: qwen
|
||||
apiTokens:
|
||||
- "YOUR_QWEN_API_TOKEN"
|
||||
- 'YOUR_QWEN_API_TOKEN'
|
||||
modelMapping:
|
||||
"*": "qwen-long" # 通义千问的文件上下文只能在 qwen-long 模型下使用
|
||||
'*': 'qwen-long' # 通义千问的文件上下文只能在 qwen-long 模型下使用
|
||||
qwenFileIds:
|
||||
- "file-fe-xxx"
|
||||
- "file-fe-yyy"
|
||||
- 'file-fe-xxx'
|
||||
- 'file-fe-yyy'
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
@@ -653,7 +674,7 @@ provider:
|
||||
}
|
||||
```
|
||||
|
||||
### 使用original协议代理百炼智能体应用
|
||||
### 使用 original 协议代理百炼智能体应用
|
||||
|
||||
**配置信息**
|
||||
|
||||
@@ -661,17 +682,18 @@ provider:
|
||||
provider:
|
||||
type: qwen
|
||||
apiTokens:
|
||||
- "YOUR_DASHSCOPE_API_TOKEN"
|
||||
- 'YOUR_DASHSCOPE_API_TOKEN'
|
||||
protocol: original
|
||||
```
|
||||
|
||||
**请求实例**
|
||||
|
||||
```json
|
||||
{
|
||||
"input": {
|
||||
"prompt": "介绍一下Dubbo"
|
||||
},
|
||||
"parameters": {},
|
||||
"parameters": {},
|
||||
"debug": {}
|
||||
}
|
||||
```
|
||||
@@ -789,7 +811,7 @@ provider:
|
||||
provider:
|
||||
type: groq
|
||||
apiTokens:
|
||||
- "YOUR_GROQ_API_TOKEN"
|
||||
- 'YOUR_GROQ_API_TOKEN'
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
@@ -848,8 +870,8 @@ provider:
|
||||
provider:
|
||||
type: claude
|
||||
apiTokens:
|
||||
- "YOUR_CLAUDE_API_TOKEN"
|
||||
version: "2023-06-01"
|
||||
- 'YOUR_CLAUDE_API_TOKEN'
|
||||
version: '2023-06-01'
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
@@ -899,14 +921,14 @@ provider:
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: "hunyuan"
|
||||
hunyuanAuthKey: "<YOUR AUTH KEY>"
|
||||
type: 'hunyuan'
|
||||
hunyuanAuthKey: '<YOUR AUTH KEY>'
|
||||
apiTokens:
|
||||
- ""
|
||||
hunyuanAuthId: "<YOUR AUTH ID>"
|
||||
- ''
|
||||
hunyuanAuthId: '<YOUR AUTH ID>'
|
||||
timeout: 1200000
|
||||
modelMapping:
|
||||
"*": "hunyuan-lite"
|
||||
'*': 'hunyuan-lite'
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
@@ -967,10 +989,10 @@ curl --location 'http://<your higress domain>/v1/chat/completions' \
|
||||
provider:
|
||||
type: baidu
|
||||
apiTokens:
|
||||
- "YOUR_BAIDU_API_TOKEN"
|
||||
- 'YOUR_BAIDU_API_TOKEN'
|
||||
modelMapping:
|
||||
'gpt-3': "ERNIE-4.0"
|
||||
'*': "ERNIE-4.0"
|
||||
'gpt-3': 'ERNIE-4.0'
|
||||
'*': 'ERNIE-4.0'
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
@@ -1014,7 +1036,7 @@ provider:
|
||||
}
|
||||
```
|
||||
|
||||
### 使用 OpenAI 协议代理MiniMax服务
|
||||
### 使用 OpenAI 协议代理 MiniMax 服务
|
||||
|
||||
**配置信息**
|
||||
|
||||
@@ -1022,11 +1044,11 @@ provider:
|
||||
provider:
|
||||
type: minimax
|
||||
apiTokens:
|
||||
- "YOUR_MINIMAX_API_TOKEN"
|
||||
- 'YOUR_MINIMAX_API_TOKEN'
|
||||
modelMapping:
|
||||
"gpt-3": "abab6.5s-chat"
|
||||
"gpt-4": "abab6.5g-chat"
|
||||
"*": "abab6.5t-chat"
|
||||
'gpt-3': 'abab6.5s-chat'
|
||||
'gpt-4': 'abab6.5g-chat'
|
||||
'*': 'abab6.5t-chat'
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
@@ -1090,12 +1112,12 @@ provider:
|
||||
provider:
|
||||
type: github
|
||||
apiTokens:
|
||||
- "YOUR_GITHUB_ACCESS_TOKEN"
|
||||
- 'YOUR_GITHUB_ACCESS_TOKEN'
|
||||
modelMapping:
|
||||
"gpt-4o": "gpt-4o"
|
||||
"gpt-4": "Phi-3.5-MoE-instruct"
|
||||
"gpt-3.5": "cohere-command-r-08-2024"
|
||||
"text-embedding-3-large": "text-embedding-3-large"
|
||||
'gpt-4o': 'gpt-4o'
|
||||
'gpt-4': 'Phi-3.5-MoE-instruct'
|
||||
'gpt-3.5': 'cohere-command-r-08-2024'
|
||||
'text-embedding-3-large': 'text-embedding-3-large'
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
@@ -1121,6 +1143,7 @@ provider:
|
||||
```
|
||||
|
||||
**响应示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"choices": [
|
||||
@@ -1183,7 +1206,7 @@ provider:
|
||||
}
|
||||
```
|
||||
|
||||
### 使用 OpenAI 协议代理360智脑服务
|
||||
### 使用 OpenAI 协议代理 360 智脑服务
|
||||
|
||||
**配置信息**
|
||||
|
||||
@@ -1191,13 +1214,13 @@ provider:
|
||||
provider:
|
||||
type: ai360
|
||||
apiTokens:
|
||||
- "YOUR_360_API_TOKEN"
|
||||
- 'YOUR_360_API_TOKEN'
|
||||
modelMapping:
|
||||
"gpt-4o": "360gpt-turbo-responsibility-8k"
|
||||
"gpt-4": "360gpt2-pro"
|
||||
"gpt-3.5": "360gpt-turbo"
|
||||
"text-embedding-3-small": "embedding_s1_v1.2"
|
||||
"*": "360gpt-pro"
|
||||
'gpt-4o': '360gpt-turbo-responsibility-8k'
|
||||
'gpt-4': '360gpt2-pro'
|
||||
'gpt-3.5': '360gpt-turbo'
|
||||
'text-embedding-3-small': 'embedding_s1_v1.2'
|
||||
'*': '360gpt-pro'
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
@@ -1257,14 +1280,14 @@ provider:
|
||||
|
||||
**文本向量请求示例**
|
||||
|
||||
URL: http://your-domain/v1/embeddings
|
||||
URL: <http://your-domain/v1/embeddings>
|
||||
|
||||
请求示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"input":["你好"],
|
||||
"model":"text-embedding-3-small"
|
||||
"input": ["你好"],
|
||||
"model": "text-embedding-3-small"
|
||||
}
|
||||
```
|
||||
|
||||
@@ -1305,10 +1328,10 @@ URL: http://your-domain/v1/embeddings
|
||||
provider:
|
||||
type: cloudflare
|
||||
apiTokens:
|
||||
- "YOUR_WORKERS_AI_API_TOKEN"
|
||||
cloudflareAccountId: "YOUR_CLOUDFLARE_ACCOUNT_ID"
|
||||
- 'YOUR_WORKERS_AI_API_TOKEN'
|
||||
cloudflareAccountId: 'YOUR_CLOUDFLARE_ACCOUNT_ID'
|
||||
modelMapping:
|
||||
"*": "@cf/meta/llama-3-8b-instruct"
|
||||
'*': '@cf/meta/llama-3-8b-instruct'
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
@@ -1348,7 +1371,7 @@ provider:
|
||||
}
|
||||
```
|
||||
|
||||
### 使用 OpenAI 协议代理Spark服务
|
||||
### 使用 OpenAI 协议代理 Spark 服务
|
||||
|
||||
**配置信息**
|
||||
|
||||
@@ -1356,11 +1379,11 @@ provider:
|
||||
provider:
|
||||
type: spark
|
||||
apiTokens:
|
||||
- "APIKey:APISecret"
|
||||
- 'APIKey:APISecret'
|
||||
modelMapping:
|
||||
"gpt-4o": "generalv3.5"
|
||||
"gpt-4": "generalv3"
|
||||
"*": "general"
|
||||
'gpt-4o': 'generalv3.5'
|
||||
'gpt-4': 'generalv3'
|
||||
'*': 'general'
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
@@ -1407,7 +1430,7 @@ provider:
|
||||
}
|
||||
```
|
||||
|
||||
### 使用 OpenAI 协议代理 gemini 服务
|
||||
### 使用 OpenAI 协议代理 Gemini 服务
|
||||
|
||||
**配置信息**
|
||||
|
||||
@@ -1474,8 +1497,8 @@ provider:
|
||||
provider:
|
||||
type: deepl
|
||||
apiTokens:
|
||||
- "YOUR_DEEPL_API_TOKEN"
|
||||
targetLang: "ZH"
|
||||
- 'YOUR_DEEPL_API_TOKEN'
|
||||
targetLang: 'ZH'
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
@@ -1500,6 +1523,7 @@ provider:
|
||||
```
|
||||
|
||||
**响应示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"choices": [
|
||||
@@ -1522,16 +1546,18 @@ provider:
|
||||
### 使用 OpenAI 协议代理 Together-AI 服务
|
||||
|
||||
**配置信息**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: together-ai
|
||||
apiTokens:
|
||||
- "YOUR_TOGETHER_AI_API_TOKEN"
|
||||
- 'YOUR_TOGETHER_AI_API_TOKEN'
|
||||
modelMapping:
|
||||
"*": "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
||||
'*': 'Qwen/Qwen2.5-72B-Instruct-Turbo'
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "Qwen/Qwen2.5-72B-Instruct-Turbo",
|
||||
@@ -1545,6 +1571,7 @@ provider:
|
||||
```
|
||||
|
||||
**响应示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "8f5809d54b73efac",
|
||||
@@ -1576,16 +1603,18 @@ provider:
|
||||
### 使用 OpenAI 协议代理 Dify 服务
|
||||
|
||||
**配置信息**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: dify
|
||||
apiTokens:
|
||||
- "YOUR_DIFY_API_TOKEN"
|
||||
- 'YOUR_DIFY_API_TOKEN'
|
||||
modelMapping:
|
||||
"*": "dify"
|
||||
'*': 'dify'
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4-turbo",
|
||||
@@ -1600,6 +1629,7 @@ provider:
|
||||
```
|
||||
|
||||
**响应示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "e33fc636-f9e8-4fae-8d5e-fbd0acb09401",
|
||||
@@ -1624,6 +1654,123 @@ provider:
|
||||
}
|
||||
```
|
||||
|
||||
### 使用 OpenAI 协议代理 Google Vertex 服务
|
||||
|
||||
**配置信息**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: vertex
|
||||
vertexAuthKey: |
|
||||
{
|
||||
"type": "service_account",
|
||||
"project_id": "your-project-id",
|
||||
"private_key_id": "your-private-key-id",
|
||||
"private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n",
|
||||
"client_email": "your-service-account@your-project.iam.gserviceaccount.com",
|
||||
"token_uri": "https://oauth2.googleapis.com/token"
|
||||
}
|
||||
vertexRegion: us-central1
|
||||
vertexProjectId: your-project-id
|
||||
vertexAuthServiceName: your-auth-service-name
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gemini-2.0-flash-001",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好,你是谁?"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
**响应示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "chatcmpl-0000000000000",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "你好!我是 Vertex AI 提供的 Gemini 模型,由 Google 开发的人工智能助手。我可以回答问题、提供信息和帮助完成各种任务。有什么我可以帮您的吗?"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"created": 1729986750,
|
||||
"model": "gemini-2.0-flash-001",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"prompt_tokens": 15,
|
||||
"completion_tokens": 43,
|
||||
"total_tokens": 58
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 使用 OpenAI 协议代理 AWS Bedrock 服务
|
||||
|
||||
**配置信息**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: bedrock
|
||||
awsAccessKey: "YOUR_AWS_ACCESS_KEY_ID"
|
||||
awsSecretKey: "YOUR_AWS_SECRET_ACCESS_KEY"
|
||||
awsRegion: "YOUR_AWS_REGION"
|
||||
bedrockAdditionalFields:
|
||||
top_k: 200
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好,你是谁?"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
**响应示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "dc5812e2-6a62-49d6-829e-5c327b15e4e2",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "你好!我是Claude,一个由Anthropic开发的AI助手。很高兴认识你!我的目标是以诚实、有益且有意义的方式与人类交流。我会尽力提供准确和有帮助的信息,同时保持诚实和正直。请问我今天能为你做些什么呢?"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"created": 1749657608,
|
||||
"model": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"prompt_tokens": 16,
|
||||
"completion_tokens": 101,
|
||||
"total_tokens": 117
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## 完整配置示例
|
||||
|
||||
@@ -1643,7 +1790,7 @@ spec:
|
||||
provider:
|
||||
type: groq
|
||||
apiTokens:
|
||||
- "YOUR_API_TOKEN"
|
||||
- 'YOUR_API_TOKEN'
|
||||
ingress:
|
||||
- groq
|
||||
url: oci://higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/ai-proxy:1.0.0
|
||||
@@ -1655,7 +1802,7 @@ metadata:
|
||||
higress.io/backend-protocol: HTTPS
|
||||
higress.io/destination: groq.dns
|
||||
higress.io/proxy-ssl-name: api.groq.com
|
||||
higress.io/proxy-ssl-server-name: "on"
|
||||
higress.io/proxy-ssl-server-name: 'on'
|
||||
labels:
|
||||
higress.io/resource-definer: higress
|
||||
name: groq
|
||||
@@ -1716,7 +1863,7 @@ services:
|
||||
networks:
|
||||
- higress-net
|
||||
ports:
|
||||
- "10000:10000"
|
||||
- '10000:10000'
|
||||
volumes:
|
||||
- ./envoy.yaml:/etc/envoy/envoy.yaml
|
||||
- ./plugin.wasm:/etc/envoy/plugin.wasm
|
||||
@@ -1745,7 +1892,7 @@ static_resources:
|
||||
- filters:
|
||||
- name: envoy.filters.network.http_connection_manager
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
|
||||
'@type': type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
|
||||
scheme_header_transformation:
|
||||
scheme_to_overwrite: https
|
||||
stat_prefix: ingress_http
|
||||
@@ -1753,23 +1900,23 @@ static_resources:
|
||||
access_log:
|
||||
- name: envoy.access_loggers.stdout
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.access_loggers.stream.v3.StdoutAccessLog
|
||||
'@type': type.googleapis.com/envoy.extensions.access_loggers.stream.v3.StdoutAccessLog
|
||||
# Modify as required
|
||||
route_config:
|
||||
name: local_route
|
||||
virtual_hosts:
|
||||
- name: local_service
|
||||
domains: [ "*" ]
|
||||
domains: ['*']
|
||||
routes:
|
||||
- match:
|
||||
prefix: "/"
|
||||
prefix: '/'
|
||||
route:
|
||||
cluster: claude
|
||||
timeout: 300s
|
||||
http_filters:
|
||||
- name: claude
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/udpa.type.v1.TypedStruct
|
||||
'@type': type.googleapis.com/udpa.type.v1.TypedStruct
|
||||
type_url: type.googleapis.com/envoy.extensions.filters.http.wasm.v3.Wasm
|
||||
value:
|
||||
config:
|
||||
@@ -1780,7 +1927,7 @@ static_resources:
|
||||
local:
|
||||
filename: /etc/envoy/plugin.wasm
|
||||
configuration:
|
||||
"@type": "type.googleapis.com/google.protobuf.StringValue"
|
||||
'@type': 'type.googleapis.com/google.protobuf.StringValue'
|
||||
value: | # 插件配置
|
||||
{
|
||||
"provider": {
|
||||
@@ -1809,8 +1956,8 @@ static_resources:
|
||||
transport_socket:
|
||||
name: envoy.transport_sockets.tls
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
|
||||
"sni": "api.anthropic.com"
|
||||
'@type': type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
|
||||
'sni': 'api.anthropic.com'
|
||||
```
|
||||
|
||||
访问示例:
|
||||
|
||||
@@ -29,15 +29,16 @@ Plugin execution priority: `100`
|
||||
|
||||
**Details for the `provider` configuration fields:**
|
||||
|
||||
| Name | Data Type | Requirement | Default | Description |
|
||||
| -------------- | --------------- | -------- | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `type` | string | Required | - | Name of the AI service provider |
|
||||
| `apiTokens` | array of string | Optional | - | Tokens used for authentication when accessing AI services. If multiple tokens are configured, the plugin randomly selects one for each request. Some service providers only support configuring a single token. |
|
||||
| `timeout` | number | Optional | - | Timeout for accessing AI services, in milliseconds. The default value is 120000, which equals 2 minutes. Only used when retrieving context data. Won't affect the request forwarded to the LLM upstream. |
|
||||
| `modelMapping` | map of string | Optional | - | Mapping table for AI models, used to map model names in requests to names supported by the service provider.<br/>1. Supports prefix matching. For example, "gpt-3-\*" matches all model names starting with “gpt-3-”;<br/>2. Supports using "\*" as a key for a general fallback mapping;<br/>3. If the mapped target name is an empty string "", the original model name is preserved. |
|
||||
| `protocol` | string | Optional | - | API contract provided by the plugin. Currently supports the following values: openai (default, uses OpenAI's interface contract), original (uses the raw interface contract of the target service provider) |
|
||||
| `context` | object | Optional | - | Configuration for AI conversation context information |
|
||||
| `customSettings` | array of customSetting | Optional | - | Specifies overrides or fills parameters for AI requests |
|
||||
| Name | Data Type | Requirement | Default | Description |
|
||||
| -------------- | --------------- | -------- | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `type` | string | Required | - | Name of the AI service provider |
|
||||
| `apiTokens` | array of string | Optional | - | Tokens used for authentication when accessing AI services. If multiple tokens are configured, the plugin randomly selects one for each request. Some service providers only support configuring a single token. |
|
||||
| `timeout` | number | Optional | - | Timeout for accessing AI services, in milliseconds. The default value is 120000, which equals 2 minutes. Only used when retrieving context data. Won't affect the request forwarded to the LLM upstream. |
|
||||
| `modelMapping` | map of string | Optional | - | Mapping table for AI models, used to map model names in requests to names supported by the service provider.<br/>1. Supports prefix matching. For example, "gpt-3-\*" matches all model names starting with “gpt-3-”;<br/>2. Supports using "\*" as a key for a general fallback mapping;<br/>3. If the mapped target name is an empty string "", the original model name is preserved. |
|
||||
| `protocol` | string | Optional | - | API contract provided by the plugin. Currently supports the following values: openai (default, uses OpenAI's interface contract), original (uses the raw interface contract of the target service provider) |
|
||||
| `context` | object | Optional | - | Configuration for AI conversation context information |
|
||||
| `customSettings` | array of customSetting | Optional | - | Specifies overrides or fills parameters for AI requests |
|
||||
| `subPath` | string | Optional | - | If subPath is configured, the prefix will be removed from the request path before further processing. |
|
||||
|
||||
**Details for the `context` configuration fields:**
|
||||
|
||||
@@ -208,6 +209,29 @@ For DeepL, the corresponding `type` is `deepl`. Its unique configuration field i
|
||||
| ------------ | --------- | ----------- | ------- | ------------------------------------ |
|
||||
| `targetLang` | string | Required | - | The target language required by the DeepL translation service |
|
||||
|
||||
#### Google Vertex AI
|
||||
For Vertex, the corresponding `type` is `vertex`. Its unique configuration field is:
|
||||
|
||||
| Name | Data Type | Requirement | Default | Description |
|
||||
|-----------------------------|---------------|---------------| ------ |-------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `vertexAuthKey` | string | Required | - | Google Service Account JSON Key used for authentication. The format should be PEM encoded PKCS#8 private key along with client_email and other information |
|
||||
| `vertexRegion` | string | Required | - | Google Cloud region (e.g., us-central1, europe-west4) used to build the Vertex API address |
|
||||
| `vertexProjectId` | string | Required | - | Google Cloud Project ID, used to identify the target GCP project |
|
||||
| `vertexAuthServiceName` | string | Required | - | Service name for OAuth2 authentication, used to access oauth2.googleapis.com |
|
||||
| `vertexGeminiSafetySetting` | map of string | Optional | - | Gemini model content safety filtering settings. |
|
||||
| `vertexTokenRefreshAhead` | number | Optional | - | Vertex access token refresh ahead time in seconds |
|
||||
|
||||
#### AWS Bedrock
|
||||
|
||||
For AWS Bedrock, the corresponding `type` is `bedrock`. Its unique configuration field is:
|
||||
|
||||
| Name | Data Type | Requirement | Default | Description |
|
||||
|---------------------------|-----------|-------------|---------|---------------------------------------------------------|
|
||||
| `awsAccessKey` | string | Required | - | AWS Access Key used for authentication |
|
||||
| `awsSecretKey` | string | Required | - | AWS Secret Access Key used for authentication |
|
||||
| `awsRegion` | string | Required | - | AWS region, e.g., us-east-1 |
|
||||
| `bedrockAdditionalFields` | map | Optional | - | Additional inference parameters that the model supports |
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Using OpenAI Protocol Proxy for Azure OpenAI Service
|
||||
@@ -1411,6 +1435,115 @@ provider:
|
||||
}
|
||||
```
|
||||
|
||||
### Utilizing OpenAI Protocol Proxy for Google Vertex Services
|
||||
**Configuration Information**
|
||||
```yaml
|
||||
provider:
|
||||
type: vertex
|
||||
vertexAuthKey: |
|
||||
{
|
||||
"type": "service_account",
|
||||
"project_id": "your-project-id",
|
||||
"private_key_id": "your-private-key-id",
|
||||
"private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n",
|
||||
"client_email": "your-service-account@your-project.iam.gserviceaccount.com",
|
||||
"token_uri": "https://oauth2.googleapis.com/token"
|
||||
}
|
||||
vertexRegion: us-central1
|
||||
vertexProjectId: your-project-id
|
||||
vertexAuthServiceName: your-auth-service-name
|
||||
```
|
||||
|
||||
**Request Example**
|
||||
```json
|
||||
{
|
||||
"model": "gemini-2.0-flash-001",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Who are you?"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
**Response Example**
|
||||
```json
|
||||
{
|
||||
"id": "chatcmpl-0000000000000",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello! I am the Gemini model provided by Vertex AI, developed by Google. I can answer questions, provide information, and assist in completing various tasks. How can I help you today?"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"created": 1729986750,
|
||||
"model": "gemini-2.0-flash-001",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"prompt_tokens": 15,
|
||||
"completion_tokens": 43,
|
||||
"total_tokens": 58
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Utilizing OpenAI Protocol Proxy for AWS Bedrock Services
|
||||
**Configuration Information**
|
||||
```yaml
|
||||
provider:
|
||||
type: bedrock
|
||||
awsAccessKey: "YOUR_AWS_ACCESS_KEY_ID"
|
||||
awsSecretKey: "YOUR_AWS_SECRET_ACCESS_KEY"
|
||||
awsRegion: "YOUR_AWS_REGION"
|
||||
bedrockAdditionalFields:
|
||||
top_k: 200
|
||||
```
|
||||
|
||||
**Request Example**
|
||||
```json
|
||||
{
|
||||
"model": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "who are you"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
**Response Example**
|
||||
```json
|
||||
{
|
||||
"id": "d52da49d-daf3-49d9-a105-0b527481fe14",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I'm Claude, an AI created by Anthropic. I aim to be helpful, honest, and harmless. I won't pretend to be human, and I'll always try to be direct and truthful about what I am and what I can do."
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"created": 1749659050,
|
||||
"model": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-5-haiku-20241022-v1:0",
|
||||
"object": "chat.completion",
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 57,
|
||||
"total_tokens": 67
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Full Configuration Example
|
||||
|
||||
### Kubernetes Example
|
||||
|
||||
@@ -161,7 +161,8 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
|
||||
if settingErr != nil {
|
||||
log.Errorf("failed to replace request body by custom settings: %v", settingErr)
|
||||
}
|
||||
if providerConfig.IsOpenAIProtocol() {
|
||||
// 仅 /v1/chat/completions 和 /v1/completions 接口支持 stream_options 参数
|
||||
if providerConfig.IsOpenAIProtocol() && (apiName == provider.ApiNameChatCompletion || apiName == provider.ApiNameCompletion) {
|
||||
newBody = normalizeOpenAiRequestBody(newBody)
|
||||
}
|
||||
log.Debugf("[onHttpRequestBody] newBody=%s", newBody)
|
||||
@@ -315,7 +316,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
|
||||
func normalizeOpenAiRequestBody(body []byte) []byte {
|
||||
var err error
|
||||
// Default setting include_usage.
|
||||
if gjson.GetBytes(body, "stream").Bool() {
|
||||
if gjson.GetBytes(body, "stream").Bool() && (!gjson.GetBytes(body, "stream_options").Exists() || !gjson.GetBytes(body, "stream_options.include_usage").Exists()) {
|
||||
body, err = sjson.SetBytes(body, "stream_options.include_usage", true)
|
||||
if err != nil {
|
||||
log.Errorf("set include_usage failed, err:%s", err)
|
||||
@@ -352,15 +353,63 @@ func getApiName(path string) provider.ApiName {
|
||||
if strings.HasSuffix(path, "/v1/images/generations") {
|
||||
return provider.ApiNameImageGeneration
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/images/variations") {
|
||||
return provider.ApiNameImageVariation
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/images/edits") {
|
||||
return provider.ApiNameImageEdit
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/batches") {
|
||||
return provider.ApiNameBatches
|
||||
}
|
||||
if util.RegRetrieveBatchPath.MatchString(path) {
|
||||
return provider.ApiNameRetrieveBatch
|
||||
}
|
||||
if util.RegCancelBatchPath.MatchString(path) {
|
||||
return provider.ApiNameCancelBatch
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/files") {
|
||||
return provider.ApiNameFiles
|
||||
}
|
||||
if util.RegRetrieveFilePath.MatchString(path) {
|
||||
return provider.ApiNameRetrieveFile
|
||||
}
|
||||
if util.RegRetrieveFileContentPath.MatchString(path) {
|
||||
return provider.ApiNameRetrieveFileContent
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/models") {
|
||||
return provider.ApiNameModels
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/fine_tuning/jobs") {
|
||||
return provider.ApiNameFineTuningJobs
|
||||
}
|
||||
if util.RegRetrieveFineTuningJobPath.MatchString(path) {
|
||||
return provider.ApiNameRetrieveFineTuningJob
|
||||
}
|
||||
if util.RegRetrieveFineTuningJobEventsPath.MatchString(path) {
|
||||
return provider.ApiNameFineTuningJobEvents
|
||||
}
|
||||
if util.RegRetrieveFineTuningJobCheckpointsPath.MatchString(path) {
|
||||
return provider.ApiNameFineTuningJobCheckpoints
|
||||
}
|
||||
if util.RegCancelFineTuningJobPath.MatchString(path) {
|
||||
return provider.ApiNameCancelFineTuningJob
|
||||
}
|
||||
if util.RegResumeFineTuningJobPath.MatchString(path) {
|
||||
return provider.ApiNameResumeFineTuningJob
|
||||
}
|
||||
if util.RegPauseFineTuningJobPath.MatchString(path) {
|
||||
return provider.ApiNamePauseFineTuningJob
|
||||
}
|
||||
if util.RegFineTuningCheckpointPermissionPath.MatchString(path) {
|
||||
return provider.ApiNameFineTuningCheckpointPermissions
|
||||
}
|
||||
if util.RegDeleteFineTuningCheckpointPermissionPath.MatchString(path) {
|
||||
return provider.ApiNameDeleteFineTuningCheckpointPermission
|
||||
}
|
||||
if strings.HasSuffix(path, "/v1/responses") {
|
||||
return provider.ApiNameResponses
|
||||
}
|
||||
// cohere style
|
||||
if strings.HasSuffix(path, "/v1/rerank") {
|
||||
return provider.ApiNameCohereV1Rerank
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -22,6 +23,8 @@ import (
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -39,8 +42,7 @@ const (
|
||||
requestIdHeader = "X-Amzn-Requestid"
|
||||
)
|
||||
|
||||
type bedrockProviderInitializer struct {
|
||||
}
|
||||
type bedrockProviderInitializer struct{}
|
||||
|
||||
func (b *bedrockProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
if len(config.awsAccessKey) == 0 || len(config.awsSecretKey) == 0 {
|
||||
@@ -101,7 +103,7 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
|
||||
chatChoice.Delta = &chatMessage{Content: bedrockEvent.Delta.Text}
|
||||
}
|
||||
if bedrockEvent.StopReason != nil {
|
||||
chatChoice.FinishReason = stopReasonBedrock2OpenAI(*bedrockEvent.StopReason)
|
||||
chatChoice.FinishReason = util.Ptr(stopReasonBedrock2OpenAI(*bedrockEvent.StopReason))
|
||||
}
|
||||
choices = append(choices, chatChoice)
|
||||
requestId := ctx.GetStringContext(requestIdHeader, "")
|
||||
@@ -115,7 +117,7 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
|
||||
}
|
||||
if bedrockEvent.Usage != nil {
|
||||
openAIFormattedChunk.Choices = choices[:0]
|
||||
openAIFormattedChunk.Usage = usage{
|
||||
openAIFormattedChunk.Usage = &usage{
|
||||
CompletionTokens: bedrockEvent.Usage.OutputTokens,
|
||||
PromptTokens: bedrockEvent.Usage.InputTokens,
|
||||
TotalTokens: bedrockEvent.Usage.TotalTokens,
|
||||
@@ -607,6 +609,11 @@ func (b *bedrockProvider) insertHttpContextMessage(body []byte, content string,
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
||||
if gjson.GetBytes(body, "model").Exists() {
|
||||
rawModel := gjson.GetBytes(body, "model").String()
|
||||
encodedModel := url.QueryEscape(rawModel)
|
||||
body, _ = sjson.SetBytes(body, "model", encodedModel)
|
||||
}
|
||||
switch apiName {
|
||||
case ApiNameChatCompletion:
|
||||
return b.onChatCompletionRequestBody(ctx, body, headers)
|
||||
@@ -716,21 +723,34 @@ func (b *bedrockProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, b
|
||||
|
||||
func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCompletionRequest, headers http.Header) ([]byte, error) {
|
||||
messages := make([]bedrockMessage, 0, len(origRequest.Messages))
|
||||
for i := range origRequest.Messages {
|
||||
messages = append(messages, chatMessage2BedrockMessage(origRequest.Messages[i]))
|
||||
systemMessages := make([]systemContentBlock, 0)
|
||||
|
||||
for _, msg := range origRequest.Messages {
|
||||
if msg.Role == roleSystem {
|
||||
systemMessages = append(systemMessages, systemContentBlock{Text: msg.StringContent()})
|
||||
} else {
|
||||
messages = append(messages, chatMessage2BedrockMessage(msg))
|
||||
}
|
||||
}
|
||||
|
||||
request := &bedrockTextGenRequest{
|
||||
System: systemMessages,
|
||||
Messages: messages,
|
||||
InferenceConfig: bedrockInferenceConfig{
|
||||
MaxTokens: origRequest.MaxTokens,
|
||||
Temperature: origRequest.Temperature,
|
||||
TopP: origRequest.TopP,
|
||||
},
|
||||
AdditionalModelRequestFields: map[string]interface{}{},
|
||||
AdditionalModelRequestFields: make(map[string]interface{}),
|
||||
PerformanceConfig: PerformanceConfiguration{
|
||||
Latency: "standard",
|
||||
},
|
||||
}
|
||||
|
||||
for key, value := range b.config.bedrockAdditionalFields {
|
||||
request.AdditionalModelRequestFields[key] = value
|
||||
}
|
||||
|
||||
requestBytes, err := json.Marshal(request)
|
||||
b.setAuthHeaders(requestBytes, headers)
|
||||
return requestBytes, err
|
||||
@@ -748,18 +768,19 @@ func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, b
|
||||
Role: bedrockResponse.Output.Message.Role,
|
||||
Content: outputContent,
|
||||
},
|
||||
FinishReason: stopReasonBedrock2OpenAI(bedrockResponse.StopReason),
|
||||
FinishReason: util.Ptr(stopReasonBedrock2OpenAI(bedrockResponse.StopReason)),
|
||||
},
|
||||
}
|
||||
requestId := ctx.GetStringContext(requestIdHeader, "")
|
||||
modelId, _ := url.QueryUnescape(ctx.GetStringContext(ctxKeyFinalRequestModel, ""))
|
||||
return &chatCompletionResponse{
|
||||
Id: requestId,
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
|
||||
Model: modelId,
|
||||
SystemFingerprint: "",
|
||||
Object: objectChatCompletion,
|
||||
Choices: choices,
|
||||
Usage: usage{
|
||||
Usage: &usage{
|
||||
PromptTokens: bedrockResponse.Usage.InputTokens,
|
||||
CompletionTokens: bedrockResponse.Usage.OutputTokens,
|
||||
TotalTokens: bedrockResponse.Usage.TotalTokens,
|
||||
@@ -900,8 +921,8 @@ func (b *bedrockProvider) setAuthHeaders(body []byte, headers http.Header) {
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) generateSignature(path, amzDate, dateStamp string, body []byte) string {
|
||||
path = encodeSigV4Path(path)
|
||||
hashedPayload := sha256Hex(body)
|
||||
path = urlEncoding(path)
|
||||
|
||||
endpoint := fmt.Sprintf(bedrockDefaultDomain, b.config.awsRegion)
|
||||
canonicalHeaders := fmt.Sprintf("host:%s\nx-amz-date:%s\n", endpoint, amzDate)
|
||||
@@ -918,14 +939,15 @@ func (b *bedrockProvider) generateSignature(path, amzDate, dateStamp string, bod
|
||||
return signature
|
||||
}
|
||||
|
||||
func urlEncoding(rawStr string) string {
|
||||
encodedStr := strings.ReplaceAll(rawStr, ":", "%3A")
|
||||
encodedStr = strings.ReplaceAll(encodedStr, "+", "%2B")
|
||||
encodedStr = strings.ReplaceAll(encodedStr, "=", "%3D")
|
||||
encodedStr = strings.ReplaceAll(encodedStr, "&", "%26")
|
||||
encodedStr = strings.ReplaceAll(encodedStr, "$", "%24")
|
||||
encodedStr = strings.ReplaceAll(encodedStr, "@", "%40")
|
||||
return encodedStr
|
||||
func encodeSigV4Path(path string) string {
|
||||
segments := strings.Split(path, "/")
|
||||
for i, seg := range segments {
|
||||
if seg == "" {
|
||||
continue
|
||||
}
|
||||
segments[i] = url.PathEscape(seg)
|
||||
}
|
||||
return strings.Join(segments, "/")
|
||||
}
|
||||
|
||||
func getSignatureKey(key, dateStamp, region, service string) []byte {
|
||||
|
||||
@@ -19,22 +19,55 @@ const (
|
||||
claudeDomain = "api.anthropic.com"
|
||||
claudeChatCompletionPath = "/v1/messages"
|
||||
claudeCompletionPath = "/v1/complete"
|
||||
defaultVersion = "2023-06-01"
|
||||
defaultMaxTokens = 4096
|
||||
claudeDefaultVersion = "2023-06-01"
|
||||
claudeDefaultMaxTokens = 4096
|
||||
)
|
||||
|
||||
type claudeProviderInitializer struct{}
|
||||
|
||||
type claudeTool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema map[string]interface{} `json:"input_schema,omitempty"`
|
||||
}
|
||||
|
||||
type claudeToolChoice struct {
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name,omitempty"`
|
||||
DisableParallelToolUse bool `json:"disable_parallel_tool_use,omitempty"`
|
||||
}
|
||||
|
||||
type claudeChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
|
||||
type claudeChatMessageContentSource struct {
|
||||
Type string `json:"type"`
|
||||
MediaType string `json:"media_type,omitempty"`
|
||||
Data string `json:"data,omitempty"`
|
||||
Url string `json:"url,omitempty"`
|
||||
FileId string `json:"file_id,omitempty"`
|
||||
}
|
||||
|
||||
type claudeChatMessageContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Source *claudeChatMessageContentSource `json:"source,omitempty"`
|
||||
}
|
||||
type claudeTextGenRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []chatMessage `json:"messages"`
|
||||
System string `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Messages []claudeChatMessage `json:"messages"`
|
||||
System string `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
ToolChoice *claudeToolChoice `json:"tool_choice,omitempty"`
|
||||
Tools []claudeTool `json:"tools,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
}
|
||||
|
||||
type claudeTextGenResponse struct {
|
||||
@@ -50,13 +83,14 @@ type claudeTextGenResponse struct {
|
||||
}
|
||||
|
||||
type claudeTextGenContent struct {
|
||||
Type string `json:"type"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
type claudeTextGenUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokens int `json:"input_tokens,omitempty"`
|
||||
OutputTokens int `json:"output_tokens,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
}
|
||||
|
||||
type claudeTextGenError struct {
|
||||
@@ -65,12 +99,12 @@ type claudeTextGenError struct {
|
||||
}
|
||||
|
||||
type claudeTextGenStreamResponse struct {
|
||||
Type string `json:"type"`
|
||||
Message claudeTextGenResponse `json:"message"`
|
||||
Index int `json:"index"`
|
||||
ContentBlock *claudeTextGenContent `json:"content_block"`
|
||||
Delta *claudeTextGenDelta `json:"delta"`
|
||||
Usage claudeTextGenUsage `json:"usage"`
|
||||
Type string `json:"type"`
|
||||
Message *claudeTextGenResponse `json:"message,omitempty"`
|
||||
Index int `json:"index,omitempty"`
|
||||
ContentBlock *claudeTextGenContent `json:"content_block,omitempty"`
|
||||
Delta *claudeTextGenDelta `json:"delta,omitempty"`
|
||||
Usage *claudeTextGenUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
type claudeTextGenDelta struct {
|
||||
@@ -93,6 +127,7 @@ func (c *claudeProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
string(ApiNameCompletion): claudeCompletionPath,
|
||||
// docs: https://docs.anthropic.com/en/docs/build-with-claude/embeddings#voyage-http-api
|
||||
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
|
||||
string(ApiNameModels): PathOpenAIModels,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -107,6 +142,10 @@ func (c *claudeProviderInitializer) CreateProvider(config ProviderConfig) (Provi
|
||||
type claudeProvider struct {
|
||||
config ProviderConfig
|
||||
contextCache *contextCache
|
||||
|
||||
messageId string
|
||||
usage usage
|
||||
serviceTier string
|
||||
}
|
||||
|
||||
func (c *claudeProvider) GetProviderType() string {
|
||||
@@ -124,16 +163,16 @@ func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
|
||||
|
||||
headers.Set("x-api-key", c.config.GetApiTokenInUse(ctx))
|
||||
|
||||
if c.config.claudeVersion == "" {
|
||||
c.config.claudeVersion = defaultVersion
|
||||
if c.config.apiVersion == "" {
|
||||
c.config.apiVersion = claudeDefaultVersion
|
||||
}
|
||||
|
||||
headers.Set("anthropic-version", c.config.claudeVersion)
|
||||
headers.Set("anthropic-version", c.config.apiVersion)
|
||||
}
|
||||
|
||||
func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !c.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body)
|
||||
}
|
||||
@@ -205,14 +244,15 @@ func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
|
||||
func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRequest) *claudeTextGenRequest {
|
||||
claudeRequest := claudeTextGenRequest{
|
||||
Model: origRequest.Model,
|
||||
MaxTokens: origRequest.MaxTokens,
|
||||
MaxTokens: origRequest.getMaxTokens(),
|
||||
StopSequences: origRequest.Stop,
|
||||
Stream: origRequest.Stream,
|
||||
Temperature: origRequest.Temperature,
|
||||
TopP: origRequest.TopP,
|
||||
// ServiceTier: origRequest.ServiceTier,
|
||||
}
|
||||
if claudeRequest.MaxTokens == 0 {
|
||||
claudeRequest.MaxTokens = defaultMaxTokens
|
||||
claudeRequest.MaxTokens = claudeDefaultMaxTokens
|
||||
}
|
||||
|
||||
for _, message := range origRequest.Messages {
|
||||
@@ -220,12 +260,80 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe
|
||||
claudeRequest.System = message.StringContent()
|
||||
continue
|
||||
}
|
||||
claudeMessage := chatMessage{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
|
||||
claudeMessage := claudeChatMessage{
|
||||
Role: message.Role,
|
||||
}
|
||||
if message.IsStringContent() {
|
||||
claudeMessage.Content = message.StringContent()
|
||||
} else {
|
||||
chatMessageContents := make([]claudeChatMessageContent, 0)
|
||||
for _, messageContent := range message.ParseContent() {
|
||||
switch messageContent.Type {
|
||||
case contentTypeText:
|
||||
chatMessageContents = append(chatMessageContents, claudeChatMessageContent{
|
||||
Type: contentTypeText,
|
||||
Text: messageContent.Text,
|
||||
})
|
||||
case contentTypeImageUrl:
|
||||
if strings.HasPrefix(messageContent.ImageUrl.Url, "data:") {
|
||||
parts := strings.SplitN(messageContent.ImageUrl.Url, ";", 2)
|
||||
if len(parts) != 2 {
|
||||
log.Errorf("invalid image url format: %s", messageContent.ImageUrl.Url)
|
||||
continue
|
||||
}
|
||||
chatMessageContents = append(chatMessageContents, claudeChatMessageContent{
|
||||
Type: "image",
|
||||
Source: &claudeChatMessageContentSource{
|
||||
Type: "base64",
|
||||
MediaType: strings.TrimPrefix(parts[0], "data:"),
|
||||
Data: strings.TrimPrefix(parts[1], "base64,"),
|
||||
},
|
||||
})
|
||||
} else {
|
||||
chatMessageContents = append(chatMessageContents, claudeChatMessageContent{
|
||||
Type: "image",
|
||||
Source: &claudeChatMessageContentSource{
|
||||
Type: "url",
|
||||
Url: messageContent.ImageUrl.Url,
|
||||
},
|
||||
})
|
||||
}
|
||||
case contentTypeFile:
|
||||
chatMessageContents = append(chatMessageContents, claudeChatMessageContent{
|
||||
Type: "file",
|
||||
Source: &claudeChatMessageContentSource{
|
||||
Type: "url",
|
||||
FileId: messageContent.File.FileId,
|
||||
},
|
||||
})
|
||||
default:
|
||||
log.Errorf("Unsupported content type: %s", messageContent.Type)
|
||||
continue
|
||||
}
|
||||
}
|
||||
claudeMessage.Content = chatMessageContents
|
||||
}
|
||||
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
|
||||
}
|
||||
|
||||
for _, tool := range origRequest.Tools {
|
||||
claudeTool := claudeTool{
|
||||
Name: tool.Function.Name,
|
||||
Description: tool.Function.Description,
|
||||
InputSchema: tool.Function.Parameters,
|
||||
}
|
||||
claudeRequest.Tools = append(claudeRequest.Tools, claudeTool)
|
||||
}
|
||||
|
||||
if tc := origRequest.getToolChoiceObject(); tc != nil {
|
||||
claudeRequest.ToolChoice = &claudeToolChoice{
|
||||
Name: tc.Function.Name,
|
||||
Type: tc.Type,
|
||||
DisableParallelToolUse: !origRequest.ParallelToolCalls,
|
||||
}
|
||||
}
|
||||
|
||||
return &claudeRequest
|
||||
}
|
||||
|
||||
@@ -233,7 +341,7 @@ func (c *claudeProvider) responseClaude2OpenAI(ctx wrapper.HttpContext, origResp
|
||||
choice := chatCompletionChoice{
|
||||
Index: 0,
|
||||
Message: &chatMessage{Role: roleAssistant, Content: origResponse.Content[0].Text},
|
||||
FinishReason: stopReasonClaude2OpenAI(origResponse.StopReason),
|
||||
FinishReason: util.Ptr(stopReasonClaude2OpenAI(origResponse.StopReason)),
|
||||
}
|
||||
|
||||
return &chatCompletionResponse{
|
||||
@@ -243,7 +351,7 @@ func (c *claudeProvider) responseClaude2OpenAI(ctx wrapper.HttpContext, origResp
|
||||
SystemFingerprint: "",
|
||||
Object: objectChatCompletion,
|
||||
Choices: []chatCompletionChoice{choice},
|
||||
Usage: usage{
|
||||
Usage: &usage{
|
||||
PromptTokens: origResponse.Usage.InputTokens,
|
||||
CompletionTokens: origResponse.Usage.OutputTokens,
|
||||
TotalTokens: origResponse.Usage.InputTokens + origResponse.Usage.OutputTokens,
|
||||
@@ -270,27 +378,50 @@ func stopReasonClaude2OpenAI(reason *string) string {
|
||||
func (c *claudeProvider) streamResponseClaude2OpenAI(ctx wrapper.HttpContext, origResponse *claudeTextGenStreamResponse) *chatCompletionResponse {
|
||||
switch origResponse.Type {
|
||||
case "message_start":
|
||||
c.messageId = origResponse.Message.Id
|
||||
c.usage = usage{
|
||||
PromptTokens: origResponse.Message.Usage.InputTokens,
|
||||
CompletionTokens: origResponse.Message.Usage.OutputTokens,
|
||||
}
|
||||
c.serviceTier = origResponse.Message.Usage.ServiceTier
|
||||
choice := chatCompletionChoice{
|
||||
Index: 0,
|
||||
Index: origResponse.Index,
|
||||
Delta: &chatMessage{Role: roleAssistant, Content: ""},
|
||||
}
|
||||
return createChatCompletionResponse(ctx, origResponse, choice)
|
||||
return c.createChatCompletionResponse(ctx, origResponse, choice)
|
||||
|
||||
case "content_block_delta":
|
||||
choice := chatCompletionChoice{
|
||||
Index: 0,
|
||||
Index: origResponse.Index,
|
||||
Delta: &chatMessage{Content: origResponse.Delta.Text},
|
||||
}
|
||||
return createChatCompletionResponse(ctx, origResponse, choice)
|
||||
return c.createChatCompletionResponse(ctx, origResponse, choice)
|
||||
|
||||
case "message_delta":
|
||||
c.usage.CompletionTokens += origResponse.Usage.OutputTokens
|
||||
c.usage.TotalTokens = c.usage.PromptTokens + c.usage.CompletionTokens
|
||||
|
||||
choice := chatCompletionChoice{
|
||||
Index: 0,
|
||||
Index: origResponse.Index,
|
||||
Delta: &chatMessage{},
|
||||
FinishReason: stopReasonClaude2OpenAI(origResponse.Delta.StopReason),
|
||||
FinishReason: util.Ptr(stopReasonClaude2OpenAI(origResponse.Delta.StopReason)),
|
||||
}
|
||||
return createChatCompletionResponse(ctx, origResponse, choice)
|
||||
case "content_block_stop", "message_stop":
|
||||
return c.createChatCompletionResponse(ctx, origResponse, choice)
|
||||
case "message_stop":
|
||||
return &chatCompletionResponse{
|
||||
Id: c.messageId,
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
|
||||
Object: objectChatCompletionChunk,
|
||||
Choices: []chatCompletionChoice{},
|
||||
ServiceTier: c.serviceTier,
|
||||
Usage: &usage{
|
||||
PromptTokens: c.usage.PromptTokens,
|
||||
CompletionTokens: c.usage.CompletionTokens,
|
||||
TotalTokens: c.usage.TotalTokens,
|
||||
},
|
||||
}
|
||||
case "content_block_stop", "ping", "content_block_start":
|
||||
log.Debugf("skip processing response type: %s", origResponse.Type)
|
||||
return nil
|
||||
default:
|
||||
@@ -299,13 +430,14 @@ func (c *claudeProvider) streamResponseClaude2OpenAI(ctx wrapper.HttpContext, or
|
||||
}
|
||||
}
|
||||
|
||||
func createChatCompletionResponse(ctx wrapper.HttpContext, response *claudeTextGenStreamResponse, choice chatCompletionChoice) *chatCompletionResponse {
|
||||
func (c *claudeProvider) createChatCompletionResponse(ctx wrapper.HttpContext, response *claudeTextGenStreamResponse, choice chatCompletionChoice) *chatCompletionResponse {
|
||||
return &chatCompletionResponse{
|
||||
Id: response.Message.Id,
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
|
||||
Object: objectChatCompletionChunk,
|
||||
Choices: []chatCompletionChoice{choice},
|
||||
Id: c.messageId,
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
|
||||
Object: objectChatCompletionChunk,
|
||||
Choices: []chatCompletionChoice{choice},
|
||||
ServiceTier: c.serviceTier,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -332,5 +464,14 @@ func (c *claudeProvider) GetApiName(path string) ApiName {
|
||||
if strings.Contains(path, claudeChatCompletionPath) {
|
||||
return ApiNameChatCompletion
|
||||
}
|
||||
if strings.Contains(path, claudeCompletionPath) {
|
||||
return ApiNameCompletion
|
||||
}
|
||||
if strings.Contains(path, PathOpenAIModels) {
|
||||
return ApiNameModels
|
||||
}
|
||||
if strings.Contains(path, PathOpenAIEmbeddings) {
|
||||
return ApiNameEmbeddings
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -6,13 +6,13 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -116,34 +116,34 @@ func (d *difyProvider) responseDify2OpenAI(ctx wrapper.HttpContext, response *Di
|
||||
choice = chatCompletionChoice{
|
||||
Index: 0,
|
||||
Message: &chatMessage{Role: roleAssistant, Content: response.Answer},
|
||||
FinishReason: finishReasonStop,
|
||||
FinishReason: util.Ptr(finishReasonStop),
|
||||
}
|
||||
//response header中增加conversationId字段
|
||||
// response header中增加conversationId字段
|
||||
_ = proxywasm.ReplaceHttpResponseHeader("ConversationId", response.ConversationId)
|
||||
id = response.ConversationId
|
||||
case BotTypeCompletion:
|
||||
choice = chatCompletionChoice{
|
||||
Index: 0,
|
||||
Message: &chatMessage{Role: roleAssistant, Content: response.Answer},
|
||||
FinishReason: finishReasonStop,
|
||||
FinishReason: util.Ptr(finishReasonStop),
|
||||
}
|
||||
id = response.MessageId
|
||||
case BotTypeWorkflow:
|
||||
choice = chatCompletionChoice{
|
||||
Index: 0,
|
||||
Message: &chatMessage{Role: roleAssistant, Content: response.Data.Outputs[d.config.outputVariable]},
|
||||
FinishReason: finishReasonStop,
|
||||
FinishReason: util.Ptr(finishReasonStop),
|
||||
}
|
||||
id = response.Data.WorkflowId
|
||||
}
|
||||
return &chatCompletionResponse{
|
||||
Id: id,
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
Created: response.CreatedAt,
|
||||
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
|
||||
SystemFingerprint: "",
|
||||
Object: objectChatCompletion,
|
||||
Choices: []chatCompletionChoice{choice},
|
||||
Usage: response.MetaData.Usage,
|
||||
Usage: &response.MetaData.Usage,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,7 +188,7 @@ func (d *difyProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api
|
||||
func (d *difyProvider) streamResponseDify2OpenAI(ctx wrapper.HttpContext, response *DifyChunkChatResponse) *chatCompletionResponse {
|
||||
var choice chatCompletionChoice
|
||||
var id string
|
||||
var responseUsage usage
|
||||
var responseUsage *usage
|
||||
switch d.config.botType {
|
||||
case BotTypeChat, BotTypeAgent:
|
||||
choice = chatCompletionChoice{
|
||||
@@ -211,9 +211,9 @@ func (d *difyProvider) streamResponseDify2OpenAI(ctx wrapper.HttpContext, respon
|
||||
id = response.Data.WorkflowId
|
||||
}
|
||||
if response.Event == "message_end" || response.Event == "workflow_finished" {
|
||||
choice.FinishReason = finishReasonStop
|
||||
choice.FinishReason = util.Ptr(finishReasonStop)
|
||||
if response.Event == "message_end" {
|
||||
responseUsage = usage{
|
||||
responseUsage = &usage{
|
||||
PromptTokens: response.MetaData.Usage.PromptTokens,
|
||||
CompletionTokens: response.MetaData.Usage.CompletionTokens,
|
||||
TotalTokens: response.MetaData.Usage.TotalTokens,
|
||||
@@ -222,7 +222,7 @@ func (d *difyProvider) streamResponseDify2OpenAI(ctx wrapper.HttpContext, respon
|
||||
}
|
||||
return &chatCompletionResponse{
|
||||
Id: id,
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
Created: response.CreatedAt,
|
||||
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
|
||||
SystemFingerprint: "",
|
||||
Object: objectChatCompletionChunk,
|
||||
@@ -309,7 +309,7 @@ type DifyChatResponse struct {
|
||||
ConversationId string `json:"conversation_id"`
|
||||
MessageId string `json:"message_id"`
|
||||
Answer string `json:"answer"`
|
||||
CreateAt int64 `json:"create_at"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Data DifyData `json:"data"`
|
||||
MetaData DifyMetaData `json:"metadata"`
|
||||
}
|
||||
@@ -319,6 +319,7 @@ type DifyChunkChatResponse struct {
|
||||
ConversationId string `json:"conversation_id"`
|
||||
MessageId string `json:"message_id"`
|
||||
Answer string `json:"answer"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Data DifyData `json:"data"`
|
||||
MetaData DifyMetaData `json:"metadata"`
|
||||
}
|
||||
|
||||
@@ -6,14 +6,19 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
doubaoDomain = "ark.cn-beijing.volces.com"
|
||||
doubaoChatCompletionPath = "/api/v3/chat/completions"
|
||||
doubaoEmbeddingsPath = "/api/v3/embeddings"
|
||||
doubaoDomain = "ark.cn-beijing.volces.com"
|
||||
doubaoChatCompletionPath = "/api/v3/chat/completions"
|
||||
doubaoEmbeddingsPath = "/api/v3/embeddings"
|
||||
doubaoImageGenerationPath = "/api/v3/images/generations"
|
||||
doubaoResponsesPath = "/api/v3/responses"
|
||||
)
|
||||
|
||||
type doubaoProviderInitializer struct{}
|
||||
@@ -27,8 +32,10 @@ func (m *doubaoProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
||||
|
||||
func (m *doubaoProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): doubaoChatCompletionPath,
|
||||
string(ApiNameEmbeddings): doubaoEmbeddingsPath,
|
||||
string(ApiNameChatCompletion): doubaoChatCompletionPath,
|
||||
string(ApiNameEmbeddings): doubaoEmbeddingsPath,
|
||||
string(ApiNameImageGeneration): doubaoImageGenerationPath,
|
||||
string(ApiNameResponses): doubaoResponsesPath,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,6 +75,32 @@ func (m *doubaoProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
func (m *doubaoProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
var err error
|
||||
switch apiName {
|
||||
case ApiNameResponses:
|
||||
// 移除火山 responses 接口暂时不支持的参数
|
||||
// 参考: https://www.volcengine.com/docs/82379/1569618
|
||||
// TODO: 这里应该用 DTO 处理
|
||||
for _, param := range []string{"parallel_tool_calls", "tool_choice"} {
|
||||
body, err = sjson.DeleteBytes(body, param)
|
||||
if err != nil {
|
||||
log.Warnf("[doubao] failed to delete %s in request body, err: %v", param, err)
|
||||
}
|
||||
}
|
||||
case ApiNameImageGeneration:
|
||||
// 火山生图接口默认会带上水印,但 OpenAI 接口不支持此参数
|
||||
// 参考: https://www.volcengine.com/docs/82379/1541523
|
||||
if res := gjson.GetBytes(body, "watermark"); !res.Exists() {
|
||||
body, err = sjson.SetBytes(body, "watermark", false)
|
||||
if err != nil {
|
||||
log.Warnf("[doubao] failed to set watermark in request body, err: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return m.config.defaultTransformRequestBody(ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (m *doubaoProvider) GetApiName(path string) ApiName {
|
||||
if strings.Contains(path, doubaoChatCompletionPath) {
|
||||
return ApiNameChatCompletion
|
||||
@@ -75,5 +108,11 @@ func (m *doubaoProvider) GetApiName(path string) ApiName {
|
||||
if strings.Contains(path, doubaoEmbeddingsPath) {
|
||||
return ApiNameEmbeddings
|
||||
}
|
||||
if strings.Contains(path, doubaoImageGenerationPath) {
|
||||
return ApiNameImageGeneration
|
||||
}
|
||||
if strings.Contains(path, doubaoResponsesPath) {
|
||||
return ApiNameResponses
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -19,14 +19,16 @@ import (
|
||||
|
||||
const (
|
||||
geminiApiKeyHeader = "x-goog-api-key"
|
||||
geminiDefaultApiVersion = "v1beta" // 可选: v1, v1beta
|
||||
geminiDomain = "generativelanguage.googleapis.com"
|
||||
geminiChatCompletionPath = "generateContent"
|
||||
geminiChatCompletionStreamPath = "streamGenerateContent?alt=sse"
|
||||
geminiEmbeddingPath = "batchEmbedContents"
|
||||
geminiModelsPath = "models"
|
||||
geminiImageGenerationPath = "predict"
|
||||
)
|
||||
|
||||
type geminiProviderInitializer struct {
|
||||
}
|
||||
type geminiProviderInitializer struct{}
|
||||
|
||||
func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
if config.apiTokens == nil || len(config.apiTokens) == 0 {
|
||||
@@ -37,8 +39,10 @@ func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
||||
|
||||
func (g *geminiProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): "",
|
||||
string(ApiNameEmbeddings): "",
|
||||
string(ApiNameChatCompletion): "",
|
||||
string(ApiNameEmbeddings): "",
|
||||
string(ApiNameModels): "",
|
||||
string(ApiNameImageGeneration): "",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,11 +83,38 @@ func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
}
|
||||
|
||||
func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
||||
if apiName == ApiNameChatCompletion {
|
||||
switch apiName {
|
||||
case ApiNameChatCompletion:
|
||||
return g.onChatCompletionRequestBody(ctx, body, headers)
|
||||
} else {
|
||||
case ApiNameEmbeddings:
|
||||
return g.onEmbeddingsRequestBody(ctx, body, headers)
|
||||
case ApiNameImageGeneration:
|
||||
return g.onImageGenerationRequestBody(ctx, body, headers)
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func (g *geminiProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
request := &imageGenerationRequest{}
|
||||
if err := g.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
path := g.getRequestPath(ApiNameImageGeneration, request.Model, false)
|
||||
log.Debugf("request path:%s", path)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
geminiRequest := g.buildGeminiImageGenerationRequest(request)
|
||||
return json.Marshal(geminiRequest)
|
||||
}
|
||||
|
||||
func (g *geminiProvider) buildGeminiImageGenerationRequest(request *imageGenerationRequest) *geminiImageGenerationRequest {
|
||||
geminiRequest := &geminiImageGenerationRequest{
|
||||
Instances: []geminiImageGenerationInstance{{Prompt: request.Prompt}},
|
||||
Parameters: &geminiImageGenerationParameters{
|
||||
SampleCount: request.N,
|
||||
},
|
||||
}
|
||||
|
||||
return geminiRequest
|
||||
}
|
||||
|
||||
func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
@@ -112,7 +143,7 @@ func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
|
||||
}
|
||||
|
||||
func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
|
||||
log.Infof("chunk body:%s", string(chunk))
|
||||
log.Debugf("chunk body:%s", string(chunk))
|
||||
if isLastChunk || len(chunk) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -148,14 +179,43 @@ func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
|
||||
}
|
||||
|
||||
func (g *geminiProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
if apiName == ApiNameChatCompletion {
|
||||
switch apiName {
|
||||
case ApiNameChatCompletion:
|
||||
return g.onChatCompletionResponseBody(ctx, body)
|
||||
} else {
|
||||
case ApiNameEmbeddings:
|
||||
return g.onEmbeddingsResponseBody(ctx, body)
|
||||
case ApiNameImageGeneration:
|
||||
return g.onImageGenerationResponseBody(ctx, body)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (g *geminiProvider) onImageGenerationResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
||||
geminiResponse := &geminiImageGenerationResponse{}
|
||||
if err := json.Unmarshal(body, geminiResponse); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal gemini image generation response: %v", err)
|
||||
}
|
||||
response := g.buildImageGenerationResponse(ctx, geminiResponse)
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
func (g *geminiProvider) buildImageGenerationResponse(ctx wrapper.HttpContext, geminiResponse *geminiImageGenerationResponse) *imageGenerationResponse {
|
||||
data := make([]imageGenerationData, len(geminiResponse.Predictions))
|
||||
for i, prediction := range geminiResponse.Predictions {
|
||||
data[i] = imageGenerationData{
|
||||
B64: prediction.BytesBase64Encoded,
|
||||
}
|
||||
}
|
||||
response := &imageGenerationResponse{
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
Data: data,
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
||||
log.Debugf("chat completion response body:%s", string(body))
|
||||
geminiResponse := &geminiChatResponse{}
|
||||
if err := json.Unmarshal(body, geminiResponse); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal gemini chat response: %v", err)
|
||||
@@ -181,26 +241,37 @@ func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
func (g *geminiProvider) getRequestPath(apiName ApiName, geminiModel string, stream bool) string {
|
||||
func (g *geminiProvider) getRequestPath(apiName ApiName, model string, stream bool) string {
|
||||
action := ""
|
||||
if apiName == ApiNameEmbeddings {
|
||||
action = geminiEmbeddingPath
|
||||
} else if stream {
|
||||
action = geminiChatCompletionStreamPath
|
||||
} else {
|
||||
action = geminiChatCompletionPath
|
||||
if g.config.apiVersion == "" {
|
||||
g.config.apiVersion = geminiDefaultApiVersion
|
||||
}
|
||||
return fmt.Sprintf("/v1/models/%s:%s", geminiModel, action)
|
||||
switch apiName {
|
||||
case ApiNameModels:
|
||||
return fmt.Sprintf("/%s/%s", g.config.apiVersion, geminiModelsPath)
|
||||
case ApiNameEmbeddings:
|
||||
action = geminiEmbeddingPath
|
||||
case ApiNameChatCompletion:
|
||||
if stream {
|
||||
action = geminiChatCompletionStreamPath
|
||||
} else {
|
||||
action = geminiChatCompletionPath
|
||||
}
|
||||
case ApiNameImageGeneration:
|
||||
action = geminiImageGenerationPath
|
||||
}
|
||||
return fmt.Sprintf("/%s/models/%s:%s", g.config.apiVersion, model, action)
|
||||
}
|
||||
|
||||
type geminiChatRequest struct {
|
||||
type geminiGenerationContentRequest struct {
|
||||
// Model and Stream are only used when using the gemini original protocol
|
||||
Model string `json:"model,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Contents []geminiChatContent `json:"contents"`
|
||||
SafetySettings []geminiChatSafetySetting `json:"safety_settings,omitempty"`
|
||||
GenerationConfig geminiChatGenerationConfig `json:"generation_config,omitempty"`
|
||||
Tools []geminiChatTools `json:"tools,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Contents []geminiChatContent `json:"contents"`
|
||||
SystemInstruction *geminiChatContent `json:"system_instruction,omitempty"`
|
||||
SafetySettings []geminiChatSafetySetting `json:"safetySettings,omitempty"`
|
||||
GenerationConfig geminiChatGenerationConfig `json:"generationConfig,omitempty"`
|
||||
Tools []geminiChatTools `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
type geminiChatContent struct {
|
||||
@@ -213,13 +284,26 @@ type geminiChatSafetySetting struct {
|
||||
Threshold string `json:"threshold"`
|
||||
}
|
||||
|
||||
type geminiThinkingConfig struct {
|
||||
IncludeThoughts bool `json:"includeThoughts,omitempty"`
|
||||
ThinkingBudget int64 `json:"thinkingBudget,omitempty"`
|
||||
}
|
||||
|
||||
type geminiChatGenerationConfig struct {
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK float64 `json:"topK,omitempty"`
|
||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK int64 `json:"topK,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
Logprobs bool `json:"logprobs,omitempty"`
|
||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
PresencePenalty int64 `json:"presencePenalty,omitempty"`
|
||||
FrequencyPenalty int64 `json:"frequencyPenalty,omitempty"`
|
||||
ResponseModalities []string `json:"responseModalities,omitempty"`
|
||||
NegativePrompt string `json:"negativePrompt,omitempty"`
|
||||
ThinkingConfig *geminiThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
MediaResolution string `json:"mediaResolution,omitempty"`
|
||||
}
|
||||
|
||||
type geminiChatTools struct {
|
||||
@@ -242,25 +326,52 @@ type geminiFunctionCall struct {
|
||||
Arguments any `json:"args"`
|
||||
}
|
||||
|
||||
func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest) *geminiChatRequest {
|
||||
// geminiImageGenerationRequest is the request body for generate image using Imagen 3
|
||||
type geminiImageGenerationRequest struct {
|
||||
Instances []geminiImageGenerationInstance `json:"instances"`
|
||||
Parameters *geminiImageGenerationParameters `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type geminiImageGenerationInstance struct {
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
type geminiImageGenerationParameters struct {
|
||||
SampleCount int `json:"sampleCount,omitempty"`
|
||||
AspectRatio string `json:"aspectRatio,omitempty"`
|
||||
}
|
||||
|
||||
type geminiImageGenerationPrediction struct {
|
||||
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
||||
MimeType string `json:"mimeType"`
|
||||
}
|
||||
|
||||
type geminiImageGenerationResponse struct {
|
||||
Predictions []geminiImageGenerationPrediction `json:"predictions"`
|
||||
}
|
||||
|
||||
func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest) *geminiGenerationContentRequest {
|
||||
var safetySettings []geminiChatSafetySetting
|
||||
{
|
||||
}
|
||||
for category, threshold := range g.config.geminiSafetySetting {
|
||||
safetySettings = append(safetySettings, geminiChatSafetySetting{
|
||||
Category: category,
|
||||
Threshold: threshold,
|
||||
})
|
||||
}
|
||||
geminiRequest := geminiChatRequest{
|
||||
geminiRequest := geminiGenerationContentRequest{
|
||||
Contents: make([]geminiChatContent, 0, len(request.Messages)),
|
||||
SafetySettings: safetySettings,
|
||||
GenerationConfig: geminiChatGenerationConfig{
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
MaxOutputTokens: request.MaxTokens,
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
MaxOutputTokens: request.MaxTokens,
|
||||
PresencePenalty: int64(request.PresencePenalty),
|
||||
FrequencyPenalty: int64(request.FrequencyPenalty),
|
||||
Logprobs: request.Logprobs,
|
||||
ResponseModalities: request.Modalities,
|
||||
},
|
||||
}
|
||||
|
||||
if request.Tools != nil {
|
||||
functions := make([]function, 0, len(request.Tools))
|
||||
for _, tool := range request.Tools {
|
||||
@@ -272,7 +383,7 @@ func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest)
|
||||
},
|
||||
}
|
||||
}
|
||||
shouldAddDummyModelMessage := false
|
||||
// shouldAddDummyModelMessage := false
|
||||
for _, message := range request.Messages {
|
||||
content := geminiChatContent{
|
||||
Role: message.Role,
|
||||
@@ -284,32 +395,22 @@ func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest)
|
||||
}
|
||||
|
||||
// there's no assistant role in gemini and API shall vomit if role is not user or model
|
||||
if content.Role == roleAssistant {
|
||||
switch content.Role {
|
||||
case roleSystem:
|
||||
content.Role = ""
|
||||
geminiRequest.SystemInstruction = &content
|
||||
continue
|
||||
case roleAssistant:
|
||||
content.Role = "model"
|
||||
} else if content.Role == roleSystem { // converting system prompt to prompt from user for the same reason
|
||||
content.Role = roleUser
|
||||
shouldAddDummyModelMessage = true
|
||||
}
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, content)
|
||||
|
||||
// if a system message is the last message, we need to add a dummy model message to make gemini happy
|
||||
if shouldAddDummyModelMessage {
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, geminiChatContent{
|
||||
Role: "model",
|
||||
Parts: []geminiPart{
|
||||
{
|
||||
Text: "Okay",
|
||||
},
|
||||
},
|
||||
})
|
||||
shouldAddDummyModelMessage = false
|
||||
}
|
||||
}
|
||||
|
||||
return &geminiRequest
|
||||
}
|
||||
|
||||
func (g *geminiProvider) setSystemContent(request *geminiChatRequest, content string) {
|
||||
func (g *geminiProvider) setSystemContent(request *geminiGenerationContentRequest, content string) {
|
||||
systemContents := []geminiChatContent{{
|
||||
Role: roleUser,
|
||||
Parts: []geminiPart{
|
||||
@@ -399,32 +500,34 @@ func (g *geminiProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, re
|
||||
Object: objectChatCompletion,
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
|
||||
Choices: make([]chatCompletionChoice, 0, len(response.Candidates)),
|
||||
Usage: usage{
|
||||
Usage: &usage{
|
||||
PromptTokens: response.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: response.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: response.UsageMetadata.TotalTokenCount,
|
||||
},
|
||||
}
|
||||
for i, candidate := range response.Candidates {
|
||||
choice := chatCompletionChoice{
|
||||
Index: i,
|
||||
Message: &chatMessage{
|
||||
Role: roleAssistant,
|
||||
},
|
||||
FinishReason: finishReasonStop,
|
||||
}
|
||||
if len(candidate.Content.Parts) > 0 {
|
||||
if candidate.Content.Parts[0].FunctionCall != nil {
|
||||
choice.Message.ToolCalls = g.buildToolCalls(&candidate)
|
||||
} else {
|
||||
choice.Message.Content = candidate.Content.Parts[0].Text
|
||||
choiceIndex := 0
|
||||
for _, candidate := range response.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
choice := chatCompletionChoice{
|
||||
Index: choiceIndex,
|
||||
Message: &chatMessage{
|
||||
Role: roleAssistant,
|
||||
},
|
||||
FinishReason: util.Ptr(finishReasonStop),
|
||||
}
|
||||
} else {
|
||||
choice.Message.Content = ""
|
||||
choice.FinishReason = candidate.FinishReason
|
||||
if part.FunctionCall != nil {
|
||||
choice.Message.ToolCalls = g.buildToolCalls(&candidate)
|
||||
} else if part.InlineData != nil {
|
||||
choice.Message.Content = part.InlineData.Data
|
||||
} else {
|
||||
choice.Message.Content = part.Text
|
||||
}
|
||||
|
||||
choice.FinishReason = util.Ptr(strings.ToLower(candidate.FinishReason))
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
choiceIndex += 1
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
@@ -457,6 +560,9 @@ func (g *geminiProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpConte
|
||||
var choice chatCompletionChoice
|
||||
if len(geminiResp.Candidates) > 0 && len(geminiResp.Candidates[0].Content.Parts) > 0 {
|
||||
choice.Delta = &chatMessage{Content: geminiResp.Candidates[0].Content.Parts[0].Text}
|
||||
if geminiResp.Candidates[0].FinishReason != "" {
|
||||
choice.FinishReason = util.Ptr(strings.ToLower(geminiResp.Candidates[0].FinishReason))
|
||||
}
|
||||
}
|
||||
streamResponse := chatCompletionResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", uuid.New().String()),
|
||||
@@ -464,7 +570,7 @@ func (g *geminiProvider) buildChatCompletionStreamResponse(ctx wrapper.HttpConte
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
|
||||
Choices: []chatCompletionChoice{choice},
|
||||
Usage: usage{
|
||||
Usage: &usage{
|
||||
PromptTokens: geminiResp.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiResp.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: geminiResp.UsageMetadata.TotalTokenCount,
|
||||
@@ -512,5 +618,8 @@ func (g *geminiProvider) GetApiName(path string) ApiName {
|
||||
if strings.Contains(path, geminiEmbeddingPath) {
|
||||
return ApiNameEmbeddings
|
||||
}
|
||||
if strings.Contains(path, geminiImageGenerationPath) {
|
||||
return ApiNameImageGeneration
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -42,13 +42,10 @@ const (
|
||||
hunyuanAuthIdLen = 36
|
||||
|
||||
// docs: https://cloud.tencent.com/document/product/1729/111007
|
||||
hunyuanOpenAiDomain = "api.hunyuan.cloud.tencent.com"
|
||||
hunyuanOpenAiRequestPath = "/v1/chat/completions"
|
||||
hunyuanOpenAiEmbeddings = "/v1/embeddings"
|
||||
hunyuanOpenAiDomain = "api.hunyuan.cloud.tencent.com"
|
||||
)
|
||||
|
||||
type hunyuanProviderInitializer struct {
|
||||
}
|
||||
type hunyuanProviderInitializer struct{}
|
||||
|
||||
// ref: https://console.cloud.tencent.com/api/explorer?Product=hunyuan&Version=2023-09-01&Action=ChatCompletions
|
||||
type hunyuanTextGenRequest struct {
|
||||
@@ -105,8 +102,8 @@ func (m *hunyuanProviderInitializer) ValidateConfig(config *ProviderConfig) erro
|
||||
|
||||
func (m *hunyuanProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): hunyuanOpenAiRequestPath,
|
||||
string(ApiNameEmbeddings): hunyuanOpenAiEmbeddings,
|
||||
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
|
||||
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -324,7 +321,7 @@ func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
|
||||
}
|
||||
|
||||
// hunyuan的流式返回:
|
||||
//data: {"Note":"以上内容为AI生成,不代表开发者立场,请勿删除或修改本标记","Choices":[{"Delta":{"Role":"assistant","Content":"有助于"},"FinishReason":""}],"Created":1716359713,"Id":"086b6b19-8b2c-4def-a65c-db6a7bc86acd","Usage":{"PromptTokens":7,"CompletionTokens":145,"TotalTokens":152}}
|
||||
// data: {"Note":"以上内容为AI生成,不代表开发者立场,请勿删除或修改本标记","Choices":[{"Delta":{"Role":"assistant","Content":"有助于"},"FinishReason":""}],"Created":1716359713,"Id":"086b6b19-8b2c-4def-a65c-db6a7bc86acd","Usage":{"PromptTokens":7,"CompletionTokens":145,"TotalTokens":152}}
|
||||
|
||||
// openai的流式返回
|
||||
// data: {"id": "chatcmpl-7QyqpwdfhqwajicIEznoc6Q47XAyW", "object": "chat.completion.chunk", "created": 1677664795, "model": "gpt-3.5-turbo-0613", "choices": [{"delta": {"content": "The "}, "index": 0, "finish_reason": null}]}
|
||||
@@ -338,7 +335,7 @@ func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
|
||||
}
|
||||
|
||||
// 初始化处理下标,以及将要返回的处理过的chunks
|
||||
var newEventPivot = -1
|
||||
newEventPivot := -1
|
||||
var outputBuffer []byte
|
||||
|
||||
// 从buffer区取出若干完整的chunk,将其转为openAI格式后返回
|
||||
@@ -390,7 +387,7 @@ func (m *hunyuanProvider) convertChunkFromHunyuanToOpenAI(ctx wrapper.HttpContex
|
||||
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
|
||||
SystemFingerprint: "",
|
||||
Object: objectChatCompletionChunk,
|
||||
Usage: usage{
|
||||
Usage: &usage{
|
||||
PromptTokens: hunyuanFormattedChunk.Usage.PromptTokens,
|
||||
CompletionTokens: hunyuanFormattedChunk.Usage.CompletionTokens,
|
||||
TotalTokens: hunyuanFormattedChunk.Usage.TotalTokens,
|
||||
@@ -403,7 +400,7 @@ func (m *hunyuanProvider) convertChunkFromHunyuanToOpenAI(ctx wrapper.HttpContex
|
||||
if hunyuanFormattedChunk.Choices[0].FinishReason == hunyuanStreamEndMark {
|
||||
// log.Debugf("@@@ --- 最后chunk: ")
|
||||
openAIFormattedChunk.Choices = append(openAIFormattedChunk.Choices, chatCompletionChoice{
|
||||
FinishReason: hunyuanFormattedChunk.Choices[0].FinishReason,
|
||||
FinishReason: util.Ptr(hunyuanFormattedChunk.Choices[0].FinishReason),
|
||||
})
|
||||
} else {
|
||||
deltaMsg := chatMessage{
|
||||
@@ -451,7 +448,6 @@ func (m *hunyuanProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName
|
||||
}
|
||||
|
||||
func (m *hunyuanProvider) insertContextMessageIntoHunyuanRequest(request *hunyuanTextGenRequest, content string) {
|
||||
|
||||
fileMessage := hunyuanChatMessage{
|
||||
Role: roleSystem,
|
||||
Content: content,
|
||||
@@ -499,7 +495,7 @@ func (m *hunyuanProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, h
|
||||
Content: choice.Message.Content,
|
||||
ToolCalls: nil,
|
||||
},
|
||||
FinishReason: choice.FinishReason,
|
||||
FinishReason: util.Ptr(choice.FinishReason),
|
||||
})
|
||||
}
|
||||
return &chatCompletionResponse{
|
||||
@@ -509,7 +505,7 @@ func (m *hunyuanProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, h
|
||||
SystemFingerprint: "",
|
||||
Object: objectChatCompletion,
|
||||
Choices: choices,
|
||||
Usage: usage{
|
||||
Usage: &usage{
|
||||
PromptTokens: hunyuanResponse.Response.Usage.PromptTokens,
|
||||
CompletionTokens: hunyuanResponse.Response.Usage.CompletionTokens,
|
||||
TotalTokens: hunyuanResponse.Response.Usage.TotalTokens,
|
||||
|
||||
@@ -36,8 +36,7 @@ const (
|
||||
defaultSenderName string = "小明"
|
||||
)
|
||||
|
||||
type minimaxProviderInitializer struct {
|
||||
}
|
||||
type minimaxProviderInitializer struct{}
|
||||
|
||||
func (m *minimaxProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
// If using the chat completion Pro API, a group ID must be set.
|
||||
@@ -368,7 +367,7 @@ func (m *minimaxProvider) responseProToOpenAI(response *minimaxChatCompletionPro
|
||||
Content: message.Text,
|
||||
}
|
||||
choices = append(choices, chatCompletionChoice{
|
||||
FinishReason: choice.FinishReason,
|
||||
FinishReason: util.Ptr(choice.FinishReason),
|
||||
Index: messageIndex,
|
||||
Message: message,
|
||||
})
|
||||
@@ -381,7 +380,7 @@ func (m *minimaxProvider) responseProToOpenAI(response *minimaxChatCompletionPro
|
||||
Created: response.Created,
|
||||
Model: response.Model,
|
||||
Choices: choices,
|
||||
Usage: usage{
|
||||
Usage: &usage{
|
||||
TotalTokens: int(response.Usage.TotalTokens),
|
||||
PromptTokens: int(response.Usage.PromptTokens),
|
||||
CompletionTokens: int(response.Usage.CompletionTokens),
|
||||
|
||||
@@ -20,8 +20,10 @@ const (
|
||||
|
||||
httpStatus200 = "200"
|
||||
|
||||
contentTypeText = "text"
|
||||
contentTypeImageUrl = "image_url"
|
||||
contentTypeText = "text"
|
||||
contentTypeImageUrl = "image_url"
|
||||
contentTypeInputAudio = "input_audio"
|
||||
contentTypeFile = "file"
|
||||
|
||||
reasoningStartTag = "<think>"
|
||||
reasoningEndTag = "</think>"
|
||||
@@ -53,11 +55,40 @@ type chatCompletionRequest struct {
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Tools []tool `json:"tools,omitempty"`
|
||||
ToolChoice *toolChoice `json:"tool_choice,omitempty"`
|
||||
ToolChoice interface{} `json:"tool_choice,omitempty"`
|
||||
ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
func (c *chatCompletionRequest) getMaxTokens() int {
|
||||
if c.MaxCompletionTokens > 0 {
|
||||
return c.MaxCompletionTokens
|
||||
}
|
||||
return c.MaxTokens
|
||||
}
|
||||
|
||||
func (c *chatCompletionRequest) getToolChoiceString() string {
|
||||
if c.ToolChoice == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if tc, ok := c.ToolChoice.(string); ok {
|
||||
return tc
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *chatCompletionRequest) getToolChoiceObject() *toolChoice {
|
||||
if c.ToolChoice == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if tc, ok := c.ToolChoice.(*toolChoice); ok {
|
||||
return tc
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type CompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
@@ -107,15 +138,15 @@ type chatCompletionResponse struct {
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
SystemFingerprint string `json:"system_fingerprint,omitempty"`
|
||||
Object string `json:"object,omitempty"`
|
||||
Usage usage `json:"usage,omitempty"`
|
||||
Usage *usage `json:"usage"`
|
||||
}
|
||||
|
||||
type chatCompletionChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message *chatMessage `json:"message,omitempty"`
|
||||
Delta *chatMessage `json:"delta,omitempty"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
Logprobs map[string]interface{} `json:"logprobs,omitempty"`
|
||||
FinishReason *string `json:"finish_reason"`
|
||||
Logprobs map[string]interface{} `json:"logprobs"`
|
||||
}
|
||||
|
||||
type usage struct {
|
||||
@@ -200,13 +231,26 @@ func (m *chatMessage) handleStreamingReasoningContent(ctx wrapper.HttpContext, r
|
||||
}
|
||||
}
|
||||
|
||||
type messageContent struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Text string `json:"text"`
|
||||
ImageUrl *imageUrl `json:"image_url,omitempty"`
|
||||
type chatMessageContent struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Text string `json:"text"`
|
||||
ImageUrl *chatMessageContentImageUrl `json:"image_url,omitempty"`
|
||||
File *chatMessageContentFile `json:"file,omitempty"`
|
||||
InputAudio *chatMessageContentAudio `json:"input_audio,omitempty"`
|
||||
}
|
||||
|
||||
type imageUrl struct {
|
||||
type chatMessageContentAudio struct {
|
||||
Data string `json:"data"`
|
||||
Format string `json:"format"`
|
||||
}
|
||||
|
||||
type chatMessageContentFile struct {
|
||||
FileData string `json:"file_data,omitempty"`
|
||||
FileId string `json:"file_id,omitempty"`
|
||||
FileName string `json:"file_name,omitempty"`
|
||||
}
|
||||
|
||||
type chatMessageContentImageUrl struct {
|
||||
Url string `json:"url,omitempty"`
|
||||
Detail string `json:"detail,omitempty"`
|
||||
}
|
||||
@@ -266,11 +310,11 @@ func (m *chatMessage) StringContent() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *chatMessage) ParseContent() []messageContent {
|
||||
var contentList []messageContent
|
||||
func (m *chatMessage) ParseContent() []chatMessageContent {
|
||||
var contentList []chatMessageContent
|
||||
content, ok := m.Content.(string)
|
||||
if ok {
|
||||
contentList = append(contentList, messageContent{
|
||||
contentList = append(contentList, chatMessageContent{
|
||||
Type: contentTypeText,
|
||||
Text: content,
|
||||
})
|
||||
@@ -286,18 +330,43 @@ func (m *chatMessage) ParseContent() []messageContent {
|
||||
switch contentMap["type"] {
|
||||
case contentTypeText:
|
||||
if subStr, ok := contentMap[contentTypeText].(string); ok {
|
||||
contentList = append(contentList, messageContent{
|
||||
contentList = append(contentList, chatMessageContent{
|
||||
Type: contentTypeText,
|
||||
Text: subStr,
|
||||
})
|
||||
}
|
||||
case contentTypeImageUrl:
|
||||
if subObj, ok := contentMap[contentTypeImageUrl].(map[string]any); ok {
|
||||
contentList = append(contentList, messageContent{
|
||||
msg := chatMessageContent{
|
||||
Type: contentTypeImageUrl,
|
||||
ImageUrl: &imageUrl{
|
||||
ImageUrl: &chatMessageContentImageUrl{
|
||||
Url: subObj["url"].(string),
|
||||
},
|
||||
}
|
||||
if detail, ok := subObj["detail"].(string); ok {
|
||||
msg.ImageUrl.Detail = detail
|
||||
}
|
||||
contentList = append(contentList, msg)
|
||||
}
|
||||
case contentTypeInputAudio:
|
||||
if subObj, ok := contentMap[contentTypeInputAudio].(map[string]any); ok {
|
||||
contentList = append(contentList, chatMessageContent{
|
||||
Type: contentTypeInputAudio,
|
||||
InputAudio: &chatMessageContentAudio{
|
||||
Data: subObj["data"].(string),
|
||||
Format: subObj["format"].(string),
|
||||
},
|
||||
})
|
||||
}
|
||||
case contentTypeFile:
|
||||
if subObj, ok := contentMap[contentTypeFile].(map[string]any); ok {
|
||||
contentList = append(contentList, chatMessageContent{
|
||||
Type: contentTypeFile,
|
||||
File: &chatMessageContentFile{
|
||||
FileId: subObj["file_id"].(string),
|
||||
// FileName: subObj["file_name"].(string),
|
||||
// FileData: subObj["file_data"].(string),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,13 +18,10 @@ import (
|
||||
// moonshotProvider is the provider for Moonshot AI service.
|
||||
|
||||
const (
|
||||
moonshotDomain = "api.moonshot.cn"
|
||||
moonshotChatCompletionPath = "/v1/chat/completions"
|
||||
moonshotModelsPath = "/v1/models"
|
||||
moonshotDomain = "api.moonshot.cn"
|
||||
)
|
||||
|
||||
type moonshotProviderInitializer struct {
|
||||
}
|
||||
type moonshotProviderInitializer struct{}
|
||||
|
||||
func (m *moonshotProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
if config.moonshotFileId != "" && config.context != nil {
|
||||
@@ -38,8 +35,8 @@ func (m *moonshotProviderInitializer) ValidateConfig(config *ProviderConfig) err
|
||||
|
||||
func (m *moonshotProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): moonshotChatCompletionPath,
|
||||
string(ApiNameModels): moonshotModelsPath,
|
||||
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
|
||||
string(ApiNameModels): PathOpenAIModels,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,17 +15,10 @@ import (
|
||||
// openaiProvider is the provider for OpenAI service.
|
||||
|
||||
const (
|
||||
defaultOpenaiDomain = "api.openai.com"
|
||||
defaultOpenaiChatCompletionPath = "/v1/chat/completions"
|
||||
defaultOpenaiCompletionPath = "/v1/completions"
|
||||
defaultOpenaiEmbeddingsPath = "/v1/embeddings"
|
||||
defaultOpenaiAudioSpeech = "/v1/audio/speech"
|
||||
defaultOpenaiImageGeneration = "/v1/images/generations"
|
||||
defaultOpenaiModels = "/v1/models"
|
||||
defaultOpenaiDomain = "api.openai.com"
|
||||
)
|
||||
|
||||
type openaiProviderInitializer struct {
|
||||
}
|
||||
type openaiProviderInitializer struct{}
|
||||
|
||||
func (m *openaiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
return nil
|
||||
@@ -33,20 +26,45 @@ func (m *openaiProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
||||
|
||||
func (m *openaiProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameCompletion): defaultOpenaiCompletionPath,
|
||||
string(ApiNameChatCompletion): defaultOpenaiChatCompletionPath,
|
||||
string(ApiNameEmbeddings): defaultOpenaiEmbeddingsPath,
|
||||
string(ApiNameImageGeneration): defaultOpenaiImageGeneration,
|
||||
string(ApiNameAudioSpeech): defaultOpenaiAudioSpeech,
|
||||
string(ApiNameModels): defaultOpenaiModels,
|
||||
string(ApiNameCompletion): PathOpenAICompletions,
|
||||
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
|
||||
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
|
||||
string(ApiNameImageGeneration): PathOpenAIImageGeneration,
|
||||
string(ApiNameImageEdit): PathOpenAIImageEdit,
|
||||
string(ApiNameImageVariation): PathOpenAIImageVariation,
|
||||
string(ApiNameAudioSpeech): PathOpenAIAudioSpeech,
|
||||
string(ApiNameModels): PathOpenAIModels,
|
||||
string(ApiNameFiles): PathOpenAIFiles,
|
||||
string(ApiNameRetrieveFile): PathOpenAIRetrieveFile,
|
||||
string(ApiNameRetrieveFileContent): PathOpenAIRetrieveFileContent,
|
||||
string(ApiNameBatches): PathOpenAIBatches,
|
||||
string(ApiNameRetrieveBatch): PathOpenAIRetrieveBatch,
|
||||
string(ApiNameCancelBatch): PathOpenAICancelBatch,
|
||||
string(ApiNameResponses): PathOpenAIResponses,
|
||||
string(ApiNameFineTuningJobs): PathOpenAIFineTuningJobs,
|
||||
string(ApiNameRetrieveFineTuningJob): PathOpenAIRetrieveFineTuningJob,
|
||||
string(ApiNameFineTuningJobEvents): PathOpenAIFineTuningJobEvents,
|
||||
string(ApiNameFineTuningJobCheckpoints): PathOpenAIFineTuningJobCheckpoints,
|
||||
string(ApiNameCancelFineTuningJob): PathOpenAICancelFineTuningJob,
|
||||
string(ApiNameResumeFineTuningJob): PathOpenAIResumeFineTuningJob,
|
||||
string(ApiNamePauseFineTuningJob): PathOpenAIPauseFineTuningJob,
|
||||
string(ApiNameFineTuningCheckpointPermissions): PathOpenAIFineTuningCheckpointPermissions,
|
||||
string(ApiNameDeleteFineTuningCheckpointPermission): PathOpenAIFineDeleteTuningCheckpointPermission,
|
||||
}
|
||||
}
|
||||
|
||||
// isDirectPath checks if the path is a known standard OpenAI interface path.
|
||||
func isDirectPath(path string) bool {
|
||||
return strings.HasSuffix(path, "/completions") ||
|
||||
strings.HasSuffix(path, "/embeddings") ||
|
||||
strings.HasSuffix(path, "/audio/speech") ||
|
||||
strings.HasSuffix(path, "/images/generations")
|
||||
strings.HasSuffix(path, "/images/generations") ||
|
||||
strings.HasSuffix(path, "/images/variations") ||
|
||||
strings.HasSuffix(path, "/images/edits") ||
|
||||
strings.HasSuffix(path, "/models") ||
|
||||
strings.HasSuffix(path, "/responses") ||
|
||||
strings.HasSuffix(path, "/fine_tuning/jobs") ||
|
||||
strings.HasSuffix(path, "/fine_tuning/checkpoints")
|
||||
}
|
||||
|
||||
func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
@@ -100,15 +118,12 @@ func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiNa
|
||||
}
|
||||
|
||||
func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
if m.customPath != "" {
|
||||
if m.isDirectCustomPath || apiName == "" {
|
||||
util.OverwriteRequestPathHeader(headers, m.customPath)
|
||||
} else {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
}
|
||||
} else {
|
||||
if m.isDirectCustomPath {
|
||||
util.OverwriteRequestPathHeader(headers, m.customPath)
|
||||
} else if apiName != "" {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
|
||||
}
|
||||
|
||||
if m.customDomain != "" {
|
||||
util.OverwriteRequestHostHeader(headers, m.customDomain)
|
||||
} else {
|
||||
@@ -121,7 +136,7 @@ func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
|
||||
}
|
||||
|
||||
func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
if !m.config.needToProcessRequestBody(apiName) {
|
||||
// We don't need to process the request body for other APIs.
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
@@ -16,32 +17,70 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
type ApiName string
|
||||
type Pointcut string
|
||||
type (
|
||||
ApiName string
|
||||
Pointcut string
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
// ApiName 格式 {vendor}/{version}/{apitype}
|
||||
// 表示遵循 厂商/版本/接口类型 的格式
|
||||
// 目前openai是事实意义上的标准,但是也有其他厂商存在其他任务的一些可能的标准,比如cohere的rerank
|
||||
ApiNameCompletion ApiName = "openai/v1/completions"
|
||||
ApiNameChatCompletion ApiName = "openai/v1/chatcompletions"
|
||||
ApiNameEmbeddings ApiName = "openai/v1/embeddings"
|
||||
ApiNameImageGeneration ApiName = "openai/v1/imagegeneration"
|
||||
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"
|
||||
ApiNameFiles ApiName = "openai/v1/files"
|
||||
ApiNameBatches ApiName = "openai/v1/batches"
|
||||
ApiNameModels ApiName = "openai/v1/models"
|
||||
ApiNameCompletion ApiName = "openai/v1/completions"
|
||||
ApiNameChatCompletion ApiName = "openai/v1/chatcompletions"
|
||||
ApiNameEmbeddings ApiName = "openai/v1/embeddings"
|
||||
ApiNameImageGeneration ApiName = "openai/v1/imagegeneration"
|
||||
ApiNameImageEdit ApiName = "openai/v1/imageedit"
|
||||
ApiNameImageVariation ApiName = "openai/v1/imagevariation"
|
||||
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"
|
||||
ApiNameFiles ApiName = "openai/v1/files"
|
||||
ApiNameRetrieveFile ApiName = "openai/v1/retrievefile"
|
||||
ApiNameRetrieveFileContent ApiName = "openai/v1/retrievefilecontent"
|
||||
ApiNameBatches ApiName = "openai/v1/batches"
|
||||
ApiNameRetrieveBatch ApiName = "openai/v1/retrievebatch"
|
||||
ApiNameCancelBatch ApiName = "openai/v1/cancelbatch"
|
||||
ApiNameModels ApiName = "openai/v1/models"
|
||||
ApiNameResponses ApiName = "openai/v1/responses"
|
||||
ApiNameFineTuningJobs ApiName = "openai/v1/fine-tuningjobs"
|
||||
ApiNameRetrieveFineTuningJob ApiName = "openai/v1/retrievefine-tuningjob"
|
||||
ApiNameFineTuningJobEvents ApiName = "openai/v1/fine-tuningjobsevents"
|
||||
ApiNameFineTuningJobCheckpoints ApiName = "openai/v1/fine-tuningjobcheckpoints"
|
||||
ApiNameCancelFineTuningJob ApiName = "openai/v1/cancelfine-tuningjob"
|
||||
ApiNameResumeFineTuningJob ApiName = "openai/v1/resumefine-tuningjob"
|
||||
ApiNamePauseFineTuningJob ApiName = "openai/v1/pausefine-tuningjob"
|
||||
ApiNameFineTuningCheckpointPermissions ApiName = "openai/v1/fine-tuningjobcheckpointpermissions"
|
||||
ApiNameDeleteFineTuningCheckpointPermission ApiName = "openai/v1/deletefine-tuningjobcheckpointpermission"
|
||||
|
||||
PathOpenAICompletions = "/v1/completions"
|
||||
PathOpenAIChatCompletions = "/v1/chat/completions"
|
||||
PathOpenAIEmbeddings = "/v1/embeddings"
|
||||
PathOpenAIFiles = "/v1/files"
|
||||
PathOpenAIBatches = "/v1/batches"
|
||||
PathOpenAIModels = "/v1/models"
|
||||
PathOpenAICompletions = "/v1/completions"
|
||||
PathOpenAIChatCompletions = "/v1/chat/completions"
|
||||
PathOpenAIEmbeddings = "/v1/embeddings"
|
||||
PathOpenAIFiles = "/v1/files"
|
||||
PathOpenAIRetrieveFile = "/v1/files/{file_id}"
|
||||
PathOpenAIRetrieveFileContent = "/v1/files/{file_id}/content"
|
||||
PathOpenAIBatches = "/v1/batches"
|
||||
PathOpenAIRetrieveBatch = "/v1/batches/{batch_id}"
|
||||
PathOpenAICancelBatch = "/v1/batches/{batch_id}/cancel"
|
||||
PathOpenAIModels = "/v1/models"
|
||||
PathOpenAIImageGeneration = "/v1/images/generations"
|
||||
PathOpenAIImageEdit = "/v1/images/edits"
|
||||
PathOpenAIImageVariation = "/v1/images/variations"
|
||||
PathOpenAIAudioSpeech = "/v1/audio/speech"
|
||||
PathOpenAIResponses = "/v1/responses"
|
||||
PathOpenAIFineTuningJobs = "/v1/fine_tuning/jobs"
|
||||
PathOpenAIRetrieveFineTuningJob = "/v1/fine_tuning/jobs/{fine_tuning_job_id}"
|
||||
PathOpenAIFineTuningJobEvents = "/v1/fine_tuning/jobs/{fine_tuning_job_id}/events"
|
||||
PathOpenAIFineTuningJobCheckpoints = "/v1/fine_tuning/jobs/{fine_tuning_job_id}/checkpoints"
|
||||
PathOpenAICancelFineTuningJob = "/v1/fine_tuning/jobs/{fine_tuning_job_id}/cancel"
|
||||
PathOpenAIResumeFineTuningJob = "/v1/fine_tuning/jobs/{fine_tuning_job_id}/resume"
|
||||
PathOpenAIPauseFineTuningJob = "/v1/fine_tuning/jobs/{fine_tuning_job_id}/pause"
|
||||
PathOpenAIFineTuningCheckpointPermissions = "/v1/fine_tuning/checkpoints/{fine_tuned_model_checkpoint}/permissions"
|
||||
PathOpenAIFineDeleteTuningCheckpointPermission = "/v1/fine_tuning/checkpoints/{fine_tuned_model_checkpoint}/permissions/{permission_id}"
|
||||
|
||||
// TODO: 以下是一些非标准的API名称,需要进一步确认是否支持
|
||||
ApiNameCohereV1Rerank ApiName = "cohere/v1/rerank"
|
||||
ApiNameQwenAsyncAIGC ApiName = "api/v1/services/aigc"
|
||||
ApiNameQwenAsyncTask ApiName = "api/v1/tasks/"
|
||||
|
||||
providerTypeMoonshot = "moonshot"
|
||||
providerTypeAzure = "azure"
|
||||
@@ -71,6 +110,7 @@ const (
|
||||
providerTypeTogetherAI = "together-ai"
|
||||
providerTypeDify = "dify"
|
||||
providerTypeBedrock = "bedrock"
|
||||
providerTypeVertex = "vertex"
|
||||
|
||||
protocolOpenAI = "openai"
|
||||
protocolOriginal = "original"
|
||||
@@ -142,6 +182,7 @@ var (
|
||||
providerTypeTogetherAI: &togetherAIProviderInitializer{},
|
||||
providerTypeDify: &difyProviderInitializer{},
|
||||
providerTypeBedrock: &bedrockProviderInitializer{},
|
||||
providerTypeVertex: &vertexProviderInitializer{},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -255,6 +296,9 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN Amazon Bedrock Region
|
||||
// @Description zh-CN 仅适用于Amazon Bedrock服务访问
|
||||
awsRegion string `required:"false" yaml:"awsRegion" json:"awsRegion"`
|
||||
// @Title zh-CN Amazon Bedrock 额外模型请求参数
|
||||
// @Description zh-CN 仅适用于Amazon Bedrock服务,用于设置模型特定的推理参数
|
||||
bedrockAdditionalFields map[string]interface{} `required:"false" yaml:"bedrockAdditionalFields" json:"bedrockAdditionalFields"`
|
||||
// @Title zh-CN minimax API type
|
||||
// @Description zh-CN 仅适用于 minimax 服务。minimax API 类型,v2 和 pro 中选填一项,默认值为 v2
|
||||
minimaxApiType string `required:"false" yaml:"minimaxApiType" json:"minimaxApiType"`
|
||||
@@ -271,14 +315,29 @@ type ProviderConfig struct {
|
||||
// @Description zh-CN 配置一个外部获取对话上下文的文件来源,用于在AI请求中补充对话上下文
|
||||
context *ContextConfig `required:"false" yaml:"context" json:"context"`
|
||||
// @Title zh-CN 版本
|
||||
// @Description zh-CN 请求AI服务的版本,目前仅适用于Claude AI服务
|
||||
claudeVersion string `required:"false" yaml:"version" json:"version"`
|
||||
// @Description zh-CN 请求AI服务的版本,目前仅适用于 Gemini 和 Claude AI服务
|
||||
apiVersion string `required:"false" yaml:"apiVersion" json:"apiVersion"`
|
||||
// @Title zh-CN Cloudflare Account ID
|
||||
// @Description zh-CN 仅适用于 Cloudflare Workers AI 服务。参考:https://developers.cloudflare.com/workers-ai/get-started/rest-api/#2-run-a-model-via-api
|
||||
cloudflareAccountId string `required:"false" yaml:"cloudflareAccountId" json:"cloudflareAccountId"`
|
||||
// @Title zh-CN Gemini AI内容过滤和安全级别设定
|
||||
// @Description zh-CN 仅适用于 Gemini AI 服务。参考:https://ai.google.dev/gemini-api/docs/safety-settings
|
||||
geminiSafetySetting map[string]string `required:"false" yaml:"geminiSafetySetting" json:"geminiSafetySetting"`
|
||||
// @Title zh-CN Vertex AI访问区域
|
||||
// @Description zh-CN 仅适用于Vertex AI服务。如需查看支持的区域的完整列表,请参阅https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations?hl=zh-cn#available-regions
|
||||
vertexRegion string `required:"false" yaml:"vertexRegion" json:"vertexRegion"`
|
||||
// @Title zh-CN Vertex AI项目Id
|
||||
// @Description zh-CN 仅适用于Vertex AI服务。创建和管理项目请参阅https://cloud.google.com/resource-manager/docs/creating-managing-projects?hl=zh-cn#identifiers
|
||||
vertexProjectId string `required:"false" yaml:"vertexProjectId" json:"vertexProjectId"`
|
||||
// @Title zh-CN Vertex 认证秘钥
|
||||
// @Description zh-CN 用于Google服务账号认证的完整JSON密钥文件内容,获取可参考https://cloud.google.com/iam/docs/keys-create-delete?hl=zh-cn#iam-service-account-keys-create-console
|
||||
vertexAuthKey string `required:"false" yaml:"vertexAuthKey" json:"vertexAuthKey"`
|
||||
// @Title zh-CN Vertex 认证服务名
|
||||
// @Description zh-CN 用于Google服务账号认证的服务,DNS类型的服务名
|
||||
vertexAuthServiceName string `required:"false" yaml:"vertexAuthServiceName" json:"vertexAuthServiceName"`
|
||||
// @Title zh-CN Vertex token刷新提前时间
|
||||
// @Description zh-CN 用于Google服务账号认证,access token过期时间判定提前刷新,单位为秒,默认值为60秒
|
||||
vertexTokenRefreshAhead int64 `required:"false" yaml:"vertexTokenRefreshAhead" json:"vertexTokenRefreshAhead"`
|
||||
// @Title zh-CN 翻译服务需指定的目标语种
|
||||
// @Description zh-CN 翻译结果的语种,目前仅适用于DeepL服务。
|
||||
targetLang string `required:"false" yaml:"targetLang" json:"targetLang"`
|
||||
@@ -299,6 +358,8 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN 额外支持的ai能力
|
||||
// @Description zh-CN 开放的ai能力和urlpath映射,例如: {"openai/v1/chatcompletions": "/v1/chat/completions"}
|
||||
capabilities map[string]string
|
||||
// @Title zh-CN 如果配置了subPath,将会先移除请求path中该前缀,再进行后续处理
|
||||
subPath string `required:"false" yaml:"subPath" json:"subPath"`
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) GetId() string {
|
||||
@@ -356,21 +417,41 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
c.context = &ContextConfig{}
|
||||
c.context.FromJson(contextJson)
|
||||
}
|
||||
c.claudeVersion = json.Get("claudeVersion").String()
|
||||
|
||||
// 这里获取 claudeVersion 字段,与结构体中定义 yaml/json 的 tag 不一致
|
||||
c.apiVersion = json.Get("claudeVersion").String()
|
||||
if c.apiVersion == "" {
|
||||
// 增加获取 version 字段,用于适配其他模型的配置,并保持与结构体中定义的 tag 一致
|
||||
c.apiVersion = json.Get("apiVersion").String()
|
||||
}
|
||||
c.hunyuanAuthId = json.Get("hunyuanAuthId").String()
|
||||
c.hunyuanAuthKey = json.Get("hunyuanAuthKey").String()
|
||||
c.awsAccessKey = json.Get("awsAccessKey").String()
|
||||
c.awsSecretKey = json.Get("awsSecretKey").String()
|
||||
c.awsRegion = json.Get("awsRegion").String()
|
||||
if c.typ == providerTypeBedrock {
|
||||
c.bedrockAdditionalFields = make(map[string]interface{})
|
||||
for k, v := range json.Get("bedrockAdditionalFields").Map() {
|
||||
c.bedrockAdditionalFields[k] = v.Value()
|
||||
}
|
||||
}
|
||||
c.minimaxApiType = json.Get("minimaxApiType").String()
|
||||
c.minimaxGroupId = json.Get("minimaxGroupId").String()
|
||||
c.cloudflareAccountId = json.Get("cloudflareAccountId").String()
|
||||
if c.typ == providerTypeGemini {
|
||||
if c.typ == providerTypeGemini || c.typ == providerTypeVertex {
|
||||
c.geminiSafetySetting = make(map[string]string)
|
||||
for k, v := range json.Get("geminiSafetySetting").Map() {
|
||||
c.geminiSafetySetting[k] = v.String()
|
||||
}
|
||||
}
|
||||
c.vertexRegion = json.Get("vertexRegion").String()
|
||||
c.vertexProjectId = json.Get("vertexProjectId").String()
|
||||
c.vertexAuthKey = json.Get("vertexAuthKey").String()
|
||||
c.vertexAuthServiceName = json.Get("vertexAuthServiceName").String()
|
||||
c.vertexTokenRefreshAhead = json.Get("vertexTokenRefreshAhead").Int()
|
||||
if c.vertexTokenRefreshAhead == 0 {
|
||||
c.vertexTokenRefreshAhead = 60
|
||||
}
|
||||
c.targetLang = json.Get("targetLang").String()
|
||||
|
||||
if schemaValue, ok := json.Get("responseJsonSchema").Value().(map[string]interface{}); ok {
|
||||
@@ -439,11 +520,14 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
case string(ApiNameChatCompletion),
|
||||
string(ApiNameEmbeddings),
|
||||
string(ApiNameImageGeneration),
|
||||
string(ApiNameImageVariation),
|
||||
string(ApiNameImageEdit),
|
||||
string(ApiNameAudioSpeech),
|
||||
string(ApiNameCohereV1Rerank):
|
||||
c.capabilities[capability] = pathJson.String()
|
||||
}
|
||||
}
|
||||
c.subPath = json.Get("subPath").String()
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) Validate() error {
|
||||
@@ -596,13 +680,25 @@ func doGetMappedModel(model string, modelMapping map[string]string) string {
|
||||
}
|
||||
|
||||
for k, v := range modelMapping {
|
||||
if k == wildcard || !strings.HasSuffix(k, wildcard) {
|
||||
if k == wildcard {
|
||||
continue
|
||||
}
|
||||
k = strings.TrimSuffix(k, wildcard)
|
||||
if strings.HasPrefix(model, k) {
|
||||
log.Debugf("model [%s] is mapped to [%s] via prefix [%s]", model, v, k)
|
||||
return v
|
||||
if strings.HasSuffix(k, wildcard) {
|
||||
k = strings.TrimSuffix(k, wildcard)
|
||||
if strings.HasPrefix(model, k) {
|
||||
log.Debugf("model [%s] is mapped to [%s] via prefix [%s]", model, v, k)
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(k, "~") {
|
||||
k = strings.TrimPrefix(k, "~")
|
||||
re := regexp.MustCompile(k)
|
||||
if re.MatchString(model) {
|
||||
v = re.ReplaceAllString(model, v)
|
||||
log.Debugf("model [%s] is mapped to [%s] via regex [%s]", model, v, k)
|
||||
return v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -703,7 +799,8 @@ func (c *ProviderConfig) setDefaultCapabilities(capabilities map[string]string)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) handleRequestBody(
|
||||
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte,
|
||||
) (types.Action, error) {
|
||||
// use original protocol
|
||||
if c.IsOriginal() {
|
||||
return types.ActionContinue, nil
|
||||
@@ -741,10 +838,17 @@ func (c *ProviderConfig) handleRequestBody(
|
||||
|
||||
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName) {
|
||||
headers := util.GetOriginalRequestHeaders()
|
||||
originPath := headers.Get(":path")
|
||||
if c.subPath != "" {
|
||||
headers.Set(":path", strings.TrimPrefix(originPath, c.subPath))
|
||||
}
|
||||
if handler, ok := provider.(TransformRequestHeadersHandler); ok {
|
||||
handler.TransformRequestHeaders(ctx, apiName, headers)
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
}
|
||||
if headers.Get(":path") != originPath {
|
||||
headers.Set("X-ENVOY-ORIGINAL-PATH", originPath)
|
||||
}
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
}
|
||||
|
||||
// defaultTransformRequestBody 默认的请求体转换方法,只做模型映射,用slog替换模型名称,不用序列化和反序列化,提高性能
|
||||
@@ -771,3 +875,19 @@ func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) needToProcessRequestBody(apiName ApiName) bool {
|
||||
switch apiName {
|
||||
case ApiNameChatCompletion,
|
||||
ApiNameCompletion,
|
||||
ApiNameEmbeddings,
|
||||
ApiNameImageGeneration,
|
||||
ApiNameImageEdit,
|
||||
ApiNameImageVariation,
|
||||
ApiNameAudioSpeech,
|
||||
ApiNameFineTuningJobs,
|
||||
ApiNameResponses:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -23,13 +23,22 @@ import (
|
||||
const (
|
||||
qwenResultFormatMessage = "message"
|
||||
|
||||
qwenDefaultDomain = "dashscope.aliyuncs.com"
|
||||
qwenChatCompletionPath = "/api/v1/services/aigc/text-generation/generation"
|
||||
qwenTextEmbeddingPath = "/api/v1/services/embeddings/text-embedding/text-embedding"
|
||||
qwenChatCompatiblePath = "/compatible-mode/v1/chat/completions"
|
||||
qwenTextEmbeddingCompatiblePath = "/compatible-mode/v1/embeddings"
|
||||
qwenBailianPath = "/api/v1/apps"
|
||||
qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation"
|
||||
qwenDefaultDomain = "dashscope.aliyuncs.com"
|
||||
qwenChatCompletionPath = "/api/v1/services/aigc/text-generation/generation"
|
||||
qwenTextEmbeddingPath = "/api/v1/services/embeddings/text-embedding/text-embedding"
|
||||
qwenCompatibleChatCompletionPath = "/compatible-mode/v1/chat/completions"
|
||||
qwenCompatibleCompletionsPath = "/compatible-mode/v1/completions"
|
||||
qwenCompatibleTextEmbeddingPath = "/compatible-mode/v1/embeddings"
|
||||
qwenCompatibleFilesPath = "/compatible-mode/v1/files"
|
||||
qwenCompatibleRetrieveFilePath = "/compatible-mode/v1/files/{file_id}"
|
||||
qwenCompatibleRetrieveFileContentPath = "/compatible-mode/v1/files/{file_id}/content"
|
||||
qwenCompatibleBatchesPath = "/compatible-mode/v1/batches"
|
||||
qwenCompatibleRetrieveBatchPath = "/compatible-mode/v1/batches/{batch_id}"
|
||||
qwenBailianPath = "/api/v1/apps"
|
||||
qwenMultimodalGenerationPath = "/api/v1/services/aigc/multimodal-generation/generation"
|
||||
|
||||
qwenAsyncAIGCPath = "/api/v1/services/aigc/"
|
||||
qwenAsyncTaskPath = "/api/v1/tasks/"
|
||||
|
||||
qwenTopPMin = 0.000001
|
||||
qwenTopPMax = 0.999999
|
||||
@@ -40,8 +49,7 @@ const (
|
||||
qwenVlModelPrefixName = "qwen-vl"
|
||||
)
|
||||
|
||||
type qwenProviderInitializer struct {
|
||||
}
|
||||
type qwenProviderInitializer struct{}
|
||||
|
||||
func (m *qwenProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
if len(config.qwenFileIds) != 0 && config.context != nil {
|
||||
@@ -56,13 +64,21 @@ func (m *qwenProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
func (m *qwenProviderInitializer) DefaultCapabilities(qwenEnableCompatible bool) map[string]string {
|
||||
if qwenEnableCompatible {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): qwenChatCompatiblePath,
|
||||
string(ApiNameEmbeddings): qwenTextEmbeddingCompatiblePath,
|
||||
string(ApiNameChatCompletion): qwenCompatibleChatCompletionPath,
|
||||
string(ApiNameEmbeddings): qwenCompatibleTextEmbeddingPath,
|
||||
string(ApiNameCompletion): qwenCompatibleCompletionsPath,
|
||||
string(ApiNameFiles): qwenCompatibleFilesPath,
|
||||
string(ApiNameRetrieveFile): qwenCompatibleRetrieveFilePath,
|
||||
string(ApiNameRetrieveFileContent): qwenCompatibleRetrieveFileContentPath,
|
||||
string(ApiNameBatches): qwenCompatibleBatchesPath,
|
||||
string(ApiNameRetrieveBatch): qwenCompatibleRetrieveBatchPath,
|
||||
}
|
||||
} else {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): qwenChatCompletionPath,
|
||||
string(ApiNameEmbeddings): qwenTextEmbeddingPath,
|
||||
string(ApiNameQwenAsyncAIGC): qwenAsyncAIGCPath,
|
||||
string(ApiNameQwenAsyncTask): qwenAsyncTaskPath,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -291,7 +307,7 @@ func (m *qwenProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, qwen
|
||||
message := qwenMessageToChatMessage(qwenChoice.Message, m.config.reasoningContentMode)
|
||||
choices = append(choices, chatCompletionChoice{
|
||||
Message: &message,
|
||||
FinishReason: qwenChoice.FinishReason,
|
||||
FinishReason: util.Ptr(qwenChoice.FinishReason),
|
||||
})
|
||||
}
|
||||
return &chatCompletionResponse{
|
||||
@@ -301,7 +317,7 @@ func (m *qwenProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, qwen
|
||||
SystemFingerprint: "",
|
||||
Object: objectChatCompletion,
|
||||
Choices: choices,
|
||||
Usage: usage{
|
||||
Usage: &usage{
|
||||
PromptTokens: qwenResponse.Usage.InputTokens,
|
||||
CompletionTokens: qwenResponse.Usage.OutputTokens,
|
||||
TotalTokens: qwenResponse.Usage.TotalTokens,
|
||||
@@ -402,11 +418,11 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont
|
||||
|
||||
if finished {
|
||||
finishResponse := *&baseMessage
|
||||
finishResponse.Choices = append(finishResponse.Choices, chatCompletionChoice{Delta: &chatMessage{}, FinishReason: qwenChoice.FinishReason})
|
||||
finishResponse.Choices = append(finishResponse.Choices, chatCompletionChoice{Delta: &chatMessage{}, FinishReason: util.Ptr(qwenChoice.FinishReason)})
|
||||
|
||||
usageResponse := *&baseMessage
|
||||
usageResponse.Choices = []chatCompletionChoice{{Delta: &chatMessage{}}}
|
||||
usageResponse.Usage = usage{
|
||||
usageResponse.Usage = &usage{
|
||||
PromptTokens: qwenResponse.Usage.InputTokens,
|
||||
CompletionTokens: qwenResponse.Usage.OutputTokens,
|
||||
TotalTokens: qwenResponse.Usage.TotalTokens,
|
||||
@@ -673,11 +689,15 @@ func (m *qwenProvider) GetApiName(path string) ApiName {
|
||||
case strings.Contains(path, qwenChatCompletionPath),
|
||||
strings.Contains(path, qwenMultimodalGenerationPath),
|
||||
strings.Contains(path, qwenBailianPath),
|
||||
strings.Contains(path, qwenChatCompatiblePath):
|
||||
strings.Contains(path, qwenCompatibleChatCompletionPath):
|
||||
return ApiNameChatCompletion
|
||||
case strings.Contains(path, qwenTextEmbeddingPath),
|
||||
strings.Contains(path, qwenTextEmbeddingCompatiblePath):
|
||||
strings.Contains(path, qwenCompatibleTextEmbeddingPath):
|
||||
return ApiNameEmbeddings
|
||||
case strings.Contains(path, qwenAsyncAIGCPath):
|
||||
return ApiNameQwenAsyncAIGC
|
||||
case strings.Contains(path, qwenAsyncTaskPath):
|
||||
return ApiNameQwenAsyncTask
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -15,12 +15,10 @@ import (
|
||||
|
||||
// sparkProvider is the provider for SparkLLM AI service.
|
||||
const (
|
||||
sparkHost = "spark-api-open.xf-yun.com"
|
||||
sparkChatCompletionPath = "/v1/chat/completions"
|
||||
sparkHost = "spark-api-open.xf-yun.com"
|
||||
)
|
||||
|
||||
type sparkProviderInitializer struct {
|
||||
}
|
||||
type sparkProviderInitializer struct{}
|
||||
|
||||
type sparkProvider struct {
|
||||
config ProviderConfig
|
||||
@@ -58,7 +56,7 @@ func (i *sparkProviderInitializer) ValidateConfig(config *ProviderConfig) error
|
||||
|
||||
func (i *sparkProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): sparkChatCompletionPath,
|
||||
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,7 +150,7 @@ func (p *sparkProvider) responseSpark2OpenAI(ctx wrapper.HttpContext, response *
|
||||
Object: objectChatCompletion,
|
||||
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
|
||||
Choices: choices,
|
||||
Usage: response.Usage,
|
||||
Usage: &response.Usage,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,7 +168,7 @@ func (p *sparkProvider) streamResponseSpark2OpenAI(ctx wrapper.HttpContext, resp
|
||||
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
|
||||
Object: objectChatCompletion,
|
||||
Choices: choices,
|
||||
Usage: response.Usage,
|
||||
Usage: &response.Usage,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,12 +10,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
stepfunDomain = "api.stepfun.com"
|
||||
stepfunChatCompletionPath = "/v1/chat/completions"
|
||||
stepfunDomain = "api.stepfun.com"
|
||||
)
|
||||
|
||||
type stepfunProviderInitializer struct {
|
||||
}
|
||||
type stepfunProviderInitializer struct{}
|
||||
|
||||
func (m *stepfunProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
if config.apiTokens == nil || len(config.apiTokens) == 0 {
|
||||
@@ -27,7 +25,7 @@ func (m *stepfunProviderInitializer) ValidateConfig(config *ProviderConfig) erro
|
||||
func (m *stepfunProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
// stepfun的chat接口path和OpenAI的chat接口一样
|
||||
string(ApiNameChatCompletion): stepfunChatCompletionPath,
|
||||
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -11,8 +11,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
togetherAIDomain = "api.together.xyz"
|
||||
togetherAICompletionPath = "/v1/chat/completions"
|
||||
togetherAIDomain = "api.together.xyz"
|
||||
)
|
||||
|
||||
type togetherAIProviderInitializer struct{}
|
||||
@@ -26,7 +25,7 @@ func (m *togetherAIProviderInitializer) ValidateConfig(config *ProviderConfig) e
|
||||
|
||||
func (m *togetherAIProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): togetherAICompletionPath,
|
||||
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,7 +66,7 @@ func (m *togetherAIProvider) TransformRequestHeaders(ctx wrapper.HttpContext, ap
|
||||
}
|
||||
|
||||
func (m *togetherAIProvider) GetApiName(path string) ApiName {
|
||||
if strings.Contains(path, togetherAICompletionPath) {
|
||||
if strings.Contains(path, PathOpenAIChatCompletions) {
|
||||
return ApiNameChatCompletion
|
||||
}
|
||||
return ""
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user