mirror of
https://github.com/alibaba/higress.git
synced 2026-02-27 06:00:51 +08:00
Compare commits
46 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
020b5f3984 | ||
|
|
9a12f0b593 | ||
|
|
7e74eeb333 | ||
|
|
fff5903007 | ||
|
|
a00b810be5 | ||
|
|
3e0a5f02a7 | ||
|
|
44c33617fa | ||
|
|
b2ffeff7b8 | ||
|
|
c0ddbccbfe | ||
|
|
16a18c6609 | ||
|
|
72b98ab6cf | ||
|
|
df20472f7b | ||
|
|
9186b5505d | ||
|
|
eaea782693 | ||
|
|
890a802481 | ||
|
|
bb69a1d50b | ||
|
|
5a023512fa | ||
|
|
47f0478ef5 | ||
|
|
c9fa8d15db | ||
|
|
0f1afcdcca | ||
|
|
19d1548971 | ||
|
|
24dca0455e | ||
|
|
be603af461 | ||
|
|
8796c6040f | ||
|
|
15edc79fb3 | ||
|
|
5822868f87 | ||
|
|
995bcc2168 | ||
|
|
a3310f1a3b | ||
|
|
0bb934073a | ||
|
|
247de6a349 | ||
|
|
79b3b23aab | ||
|
|
b9d6343efa | ||
|
|
0af00bef6b | ||
|
|
953b95cf92 | ||
|
|
a76808171f | ||
|
|
f7813df1d7 | ||
|
|
33ce18df5a | ||
|
|
a1bf1ff009 | ||
|
|
b69e3a8f30 | ||
|
|
5ee878198c | ||
|
|
943fda0a9c | ||
|
|
abc31169a2 | ||
|
|
5f65b4f5b0 | ||
|
|
645646fe22 | ||
|
|
4acb65cc67 | ||
|
|
e63a2e0251 |
@@ -122,7 +122,7 @@ jobs:
|
||||
set -e
|
||||
cd /workspace/plugins/wasm-go/extensions/${PLUGIN_NAME}
|
||||
go mod tidy
|
||||
GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o plugin.wasm main.go
|
||||
GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o plugin.wasm .
|
||||
tar czvf plugin.tar.gz plugin.wasm
|
||||
echo ${{ secrets.REGISTRY_PASSWORD }} | oras login -u ${{ secrets.REGISTRY_USERNAME }} --password-stdin ${{ env.IMAGE_REGISTRY_SERVICE }}
|
||||
oras push ${target_image} ${push_command}
|
||||
|
||||
378
.github/workflows/wasm-plugin-unit-test.yml
vendored
Normal file
378
.github/workflows/wasm-plugin-unit-test.yml
vendored
Normal file
@@ -0,0 +1,378 @@
|
||||
name: Wasm Plugin Unit Tests(GO)
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
paths:
|
||||
- 'plugins/wasm-go/extensions/**'
|
||||
- '.github/workflows/wasm-plugin-unit-test.yml'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
pull_request:
|
||||
branches: [ "*" ]
|
||||
paths:
|
||||
- 'plugins/wasm-go/extensions/**'
|
||||
- '.github/workflows/wasm-plugin-unit-test.yml'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
|
||||
env:
|
||||
GO111MODULE: on
|
||||
CGO_ENABLED: 0
|
||||
GOOS: linux
|
||||
GOARCH: amd64
|
||||
|
||||
jobs:
|
||||
detect-changed-plugins:
|
||||
name: Detect Changed Plugins
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
changed-plugins: ${{ steps.detect.outputs.plugins }}
|
||||
has-changes: ${{ steps.detect.outputs.has-changes }}
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # 获取完整历史用于比较
|
||||
|
||||
- name: Detect changed plugins
|
||||
id: detect
|
||||
run: |
|
||||
# 获取变更的文件列表
|
||||
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||
# PR模式:比较目标分支和源分支
|
||||
git fetch origin ${{ github.base_ref }}
|
||||
CHANGED_FILES=$(git diff --name-only origin/${{ github.base_ref }}...HEAD)
|
||||
else
|
||||
# Push模式:比较当前提交和上一个提交
|
||||
CHANGED_FILES=$(git diff --name-only HEAD~1 HEAD)
|
||||
fi
|
||||
|
||||
echo "Changed files:"
|
||||
echo "$CHANGED_FILES"
|
||||
|
||||
# 提取变更的插件名称
|
||||
CHANGED_PLUGINS=""
|
||||
for file in $CHANGED_FILES; do
|
||||
if [[ $file =~ ^plugins/wasm-go/extensions/([^/]+)/ ]]; then
|
||||
PLUGIN_NAME="${BASH_REMATCH[1]}"
|
||||
if [[ ! " $CHANGED_PLUGINS " =~ " $PLUGIN_NAME " ]]; then
|
||||
# 修复:只在非空时添加空格
|
||||
if [ -z "$CHANGED_PLUGINS" ]; then
|
||||
CHANGED_PLUGINS="$PLUGIN_NAME"
|
||||
else
|
||||
CHANGED_PLUGINS="$CHANGED_PLUGINS $PLUGIN_NAME"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
# 如果没有插件变更,不触发测试
|
||||
if [ -z "$CHANGED_PLUGINS" ]; then
|
||||
echo "No plugin changes detected, skipping tests"
|
||||
echo "has-changes=false" >> $GITHUB_OUTPUT
|
||||
echo "plugins=[]" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "Changed plugins: $CHANGED_PLUGINS"
|
||||
echo "has-changes=true" >> $GITHUB_OUTPUT
|
||||
# 将空格分隔转换为 JSON 数组格式
|
||||
PLUGINS_JSON=$(echo "$CHANGED_PLUGINS" | sed 's/ /","/g' | sed 's/^/["/' | sed 's/$/"]/')
|
||||
echo "PLUGINS_JSON: $PLUGINS_JSON"
|
||||
echo "plugins=$PLUGINS_JSON" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
test:
|
||||
name: Test Changed Plugins
|
||||
runs-on: ubuntu-latest
|
||||
needs: detect-changed-plugins
|
||||
if: needs.detect-changed-plugins.outputs.has-changes == 'true'
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
plugin: ${{ fromJSON(needs.detect-changed-plugins.outputs.changed-plugins) }}
|
||||
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Go 1.24
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: 1.24
|
||||
cache: true
|
||||
|
||||
- name: Install test tools
|
||||
run: |
|
||||
go install gotest.tools/gotestsum@latest
|
||||
# 移除gocov工具,直接使用Codecov
|
||||
|
||||
- name: Build WASM for ${{ matrix.plugin }}
|
||||
working-directory: plugins/wasm-go/extensions/${{ matrix.plugin }}
|
||||
run: |
|
||||
echo "Building WASM for ${{ matrix.plugin }}..."
|
||||
|
||||
# 检查是否存在main.go文件
|
||||
|
||||
export GOOS=wasip1
|
||||
export GOARCH=wasm
|
||||
|
||||
# 构建WASM文件,失败时直接退出
|
||||
if ! go build -buildmode=c-shared -o main.wasm ./; then
|
||||
echo "❌ WASM build failed for ${{ matrix.plugin }}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 验证WASM文件是否生成
|
||||
if [ ! -f "main.wasm" ]; then
|
||||
echo "❌ WASM file not generated for ${{ matrix.plugin }}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "✅ WASM build successful for ${{ matrix.plugin }}"
|
||||
|
||||
|
||||
- name: Set WASM_PATH environment variable
|
||||
run: |
|
||||
echo "WASM_PATH=$(pwd)/plugins/wasm-go/extensions/${{ matrix.plugin }}/main.wasm" >> $GITHUB_ENV
|
||||
|
||||
- name: Run tests with coverage for ${{ matrix.plugin }}
|
||||
working-directory: plugins/wasm-go/extensions/${{ matrix.plugin }}
|
||||
run: |
|
||||
# 检查是否存在main_test.go文件
|
||||
if [ -f "main_test.go" ]; then
|
||||
echo "Running tests for ${{ matrix.plugin }}..."
|
||||
|
||||
# 运行测试并生成覆盖率报告
|
||||
gotestsum --junitfile ../../../../test-results-${{ matrix.plugin }}.xml \
|
||||
--format standard-verbose \
|
||||
--jsonfile ../../../../test-output-${{ matrix.plugin }}.json \
|
||||
-- -coverprofile=coverage-${{ matrix.plugin }}.out -covermode=atomic -coverpkg=./... ./...
|
||||
|
||||
echo "✅ Tests completed for ${{ matrix.plugin }}"
|
||||
else
|
||||
echo "No tests found for ${{ matrix.plugin }}, skipping..."
|
||||
# 创建空的测试结果文件
|
||||
echo '<?xml version="1.0" encoding="UTF-8"?><testsuites><testsuite name="no-tests" tests="0" failures="0" errors="0" time="0"></testsuite></testsuites>' > ../../../../test-results-${{ matrix.plugin }}.xml
|
||||
fi
|
||||
|
||||
- name: Upload test results for ${{ matrix.plugin }}
|
||||
uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
name: test-results-${{ matrix.plugin }}
|
||||
path: |
|
||||
test-results-${{ matrix.plugin }}.xml
|
||||
test-output-${{ matrix.plugin }}.json
|
||||
retention-days: 30
|
||||
|
||||
- name: Upload coverage report for ${{ matrix.plugin }}
|
||||
uses: actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
name: coverage-${{ matrix.plugin }}
|
||||
path: plugins/wasm-go/extensions/${{ matrix.plugin }}/coverage-${{ matrix.plugin }}.out
|
||||
retention-days: 30
|
||||
|
||||
test-summary:
|
||||
name: Test Summary & Coverage
|
||||
runs-on: ubuntu-latest
|
||||
needs: [detect-changed-plugins, test]
|
||||
if: always() && needs.detect-changed-plugins.outputs.has-changes == 'true'
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go 1.24
|
||||
uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: 1.24
|
||||
cache: true
|
||||
|
||||
- name: Install required tools
|
||||
run: |
|
||||
go install github.com/wadey/gocovmerge@latest
|
||||
|
||||
- name: Download all test results
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: test-results-*
|
||||
merge-multiple: true
|
||||
path: ${{ github.workspace }}
|
||||
|
||||
- name: Download all coverage files
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: coverage-*
|
||||
merge-multiple: true
|
||||
path: ${{ github.workspace }}
|
||||
|
||||
|
||||
|
||||
- name: Generate comprehensive test summary
|
||||
run: |
|
||||
echo "## 🧪 Go Plugin Test Results" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
total_plugins=0
|
||||
passed_plugins=0
|
||||
failed_plugins=0
|
||||
total_tests=0
|
||||
total_failures=0
|
||||
total_errors=0
|
||||
|
||||
echo "### 📊 Test Results by Plugin" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
for result_file in test-results-*.xml; do
|
||||
if [ -f "$result_file" ]; then
|
||||
plugin_name=$(echo "$result_file" | sed 's/test-results-\(.*\)\.xml/\1/')
|
||||
total_plugins=$((total_plugins + 1))
|
||||
|
||||
# 解析XML获取测试结果
|
||||
if grep -q '<testsuite' "$result_file"; then
|
||||
# 使用grep解析XML属性,更稳定可靠
|
||||
tests=$(grep -o 'tests="[0-9]*"' "$result_file" | head -1 | grep -o '[0-9]*' || echo "0")
|
||||
failures=$(grep -o 'failures="[0-9]*"' "$result_file" | head -1 | grep -o '[0-9]*' || echo "0")
|
||||
errors=$(grep -o 'errors="[0-9]*"' "$result_file" | head -1 | grep -o '[0-9]*' || echo "0")
|
||||
time=$(grep -o 'time="[0-9.]*"' "$result_file" | head -1 | grep -o '[0-9.]*' || echo "0")
|
||||
|
||||
# 确保数值有效,避免bash算术运算错误
|
||||
tests=${tests:-0}
|
||||
failures=${failures:-0}
|
||||
errors=${errors:-0}
|
||||
|
||||
# 转换为整数进行算术运算
|
||||
total_tests=$((total_tests + tests))
|
||||
total_failures=$((total_failures + failures))
|
||||
total_errors=$((total_errors + errors))
|
||||
|
||||
if [ "$failures" = "0" ] && [ "$errors" = "0" ]; then
|
||||
echo "✅ **$plugin_name**: $tests tests passed in ${time}s" >> $GITHUB_STEP_SUMMARY
|
||||
passed_plugins=$((passed_plugins + 1))
|
||||
else
|
||||
echo "❌ **$plugin_name**: $tests tests, $failures failures, $errors errors in ${time}s" >> $GITHUB_STEP_SUMMARY
|
||||
failed_plugins=$((failed_plugins + 1))
|
||||
fi
|
||||
else
|
||||
echo "⚠️ **$plugin_name**: No tests found" >> $GITHUB_STEP_SUMMARY
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "### 📈 Coverage Report" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "📊 **Coverage reports are now available on Codecov**" >> $GITHUB_STEP_SUMMARY
|
||||
echo "🔗 **This Commit Coverage**: https://codecov.io/gh/${{ github.repository }}/commit/${{ github.sha }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
echo "### 🎯 Summary" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- **Total plugins**: $total_plugins" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- **Passed**: $passed_plugins ✅" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- **Failed**: $failed_plugins ❌" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- **Total tests**: $total_tests" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- **Total failures**: $total_failures" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- **Total errors**: $total_errors" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
# 如果有失败,显示详细信息
|
||||
if [ $total_failures -gt 0 ] || [ $total_errors -gt 0 ]; then
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "### ❌ Failed Tests Details" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "**Failed plugins**: $failed_plugins" >> $GITHUB_STEP_SUMMARY
|
||||
echo "**Total failures**: $total_failures" >> $GITHUB_STEP_SUMMARY
|
||||
echo "**Total errors**: $total_errors" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "📋 **View detailed logs**: [Click here](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }})" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
# 显示每个失败插件的详细信息
|
||||
echo "#### 📊 Failed Plugin Details" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
for result_file in test-results-*.xml; do
|
||||
if [ -f "$result_file" ]; then
|
||||
plugin_name=$(echo "$result_file" | sed 's/test-results-\(.*\)\.xml/\1/')
|
||||
|
||||
# 检查是否有失败
|
||||
failures=$(grep -o 'failures="[0-9]*"' "$result_file" | head -1 | grep -o '[0-9]*' || echo "0")
|
||||
errors=$(grep -o 'errors="[0-9]*"' "$result_file" | head -1 | grep -o '[0-9]*' || echo "0")
|
||||
|
||||
# 确保数值有效
|
||||
failures=${failures:-0}
|
||||
errors=${errors:-0}
|
||||
|
||||
if [ "$failures" -gt 0 ] || [ "$errors" -gt 0 ]; then
|
||||
echo "**$plugin_name**:" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Failures: $failures" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Errors: $errors" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- [View plugin logs](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }})" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
fi
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
||||
# - name: Merge coverage reports
|
||||
# run: |
|
||||
# echo "Merging coverage reports..."
|
||||
#
|
||||
# # 使用绝对路径查找,更可靠
|
||||
# coverage_files=$(find ${{ github.workspace }} -name "coverage-*")
|
||||
#
|
||||
# if [ -n "$coverage_files" ]; then
|
||||
# echo "Found coverage files:"
|
||||
# echo "$coverage_files"
|
||||
#
|
||||
# # 使用gocovmerge顺序合并
|
||||
# echo "Merging Go coverage files using gocovmerge sequential method..."
|
||||
#
|
||||
# # 将文件列表转换为数组
|
||||
# readarray -t coverage_array <<< "$coverage_files"
|
||||
# file_count=${#coverage_array[@]}
|
||||
#
|
||||
# echo "Total files to merge: $file_count"
|
||||
#
|
||||
# # 复制第一个文件作为基础
|
||||
# cp "${coverage_array[0]}" ${{ github.workspace }}/merged_coverage.out
|
||||
# echo "Starting with: ${coverage_array[0]}"
|
||||
#
|
||||
# # 如果有多个文件,逐个合并其他文件到最终目标
|
||||
# if [ $file_count -gt 1 ]; then
|
||||
# echo "Multiple files, merging sequentially with gocovmerge..."
|
||||
#
|
||||
# for ((i=1; i<file_count; i++)); do
|
||||
# current_file="${coverage_array[i]}"
|
||||
#
|
||||
# echo "Merging file $((i+1))/$file_count: $current_file"
|
||||
#
|
||||
# # 使用gocovmerge合并到最终目标文件
|
||||
# gocovmerge "${{ github.workspace }}/merged_coverage.out" "$current_file" > "${{ github.workspace }}/temp_merge.out"
|
||||
# mv "${{ github.workspace }}/temp_merge.out" "${{ github.workspace }}/merged_coverage.out"
|
||||
#
|
||||
# echo "Successfully merged with $current_file"
|
||||
# done
|
||||
# fi
|
||||
#
|
||||
# echo "Coverage reports merged successfully using gocovmerge sequential method"
|
||||
# echo "Merged file size: $(wc -c < ${{ github.workspace }}/merged_coverage.out) bytes"
|
||||
# else
|
||||
# echo "No coverage files found"
|
||||
# # 创建空的覆盖率文件
|
||||
# echo "mode: atomic" > ${{ github.workspace }}/merged_coverage.out
|
||||
# fi
|
||||
|
||||
# - name: Upload merged coverage to Codecov
|
||||
# uses: codecov/codecov-action@v4
|
||||
# if: always()
|
||||
# with:
|
||||
# file: ${{ github.workspace }}/merged_coverage.out
|
||||
# flags: wasm-go-plugins-tests
|
||||
# name: codecov-wasm-go-plugins
|
||||
# fail_ci_if_error: false
|
||||
# verbose: true
|
||||
@@ -3,7 +3,7 @@
|
||||
/istio @SpecialYang @johnlanni
|
||||
/pkg @SpecialYang @johnlanni @CH3CHO
|
||||
/plugins @johnlanni @CH3CHO @rinfx @erasernoob
|
||||
/plugins/wasm-go/extensions/ai-proxy @cr7258 @CH3CHO @rinfx @wydream
|
||||
/plugins/wasm-go/extensions/ai-proxy @rinfx @wydream @johnlanni
|
||||
/plugins/wasm-rust @007gzs @jizhuozhi
|
||||
/registry @Erica177 @2456868764 @johnlanni
|
||||
/test @Xunzhuo @2456868764 @CH3CHO
|
||||
|
||||
@@ -1 +1 @@
|
||||
higress-console: v2.1.6
|
||||
higress-console: v2.1.7
|
||||
@@ -137,6 +137,8 @@ endif
|
||||
# for now docker is limited to Linux compiles - why ?
|
||||
include docker/docker.mk
|
||||
|
||||
docker-build-amd64: docker.higress-amd64 ## Build and push amdd64 docker images to registry defined by $HUB and $TAG
|
||||
|
||||
docker-build: docker.higress ## Build and push docker images to registry defined by $HUB and $TAG
|
||||
|
||||
docker-buildx-push: clean-env docker.higress-buildx
|
||||
@@ -144,7 +146,7 @@ docker-buildx-push: clean-env docker.higress-buildx
|
||||
export PARENT_GIT_TAG:=$(shell cat VERSION)
|
||||
export PARENT_GIT_REVISION:=$(TAG)
|
||||
|
||||
export ENVOY_PACKAGE_URL_PATTERN?=https://github.com/higress-group/proxy/releases/download/v2.1.8/envoy-symbol-ARCH.tar.gz
|
||||
export ENVOY_PACKAGE_URL_PATTERN?=https://github.com/higress-group/proxy/releases/download/v2.1.9/envoy-symbol-ARCH.tar.gz
|
||||
|
||||
build-envoy: prebuild
|
||||
./tools/hack/build-envoy.sh
|
||||
@@ -192,7 +194,7 @@ install: pre-install
|
||||
helm install higress helm/higress -n higress-system --create-namespace --set 'global.local=true'
|
||||
|
||||
HIGRESS_LATEST_IMAGE_TAG ?= latest
|
||||
ENVOY_LATEST_IMAGE_TAG ?= latest
|
||||
ENVOY_LATEST_IMAGE_TAG ?= 48da465cfd0dc5c9ac851bd2b9743780dc82dd8a
|
||||
ISTIO_LATEST_IMAGE_TAG ?= latest
|
||||
|
||||
install-dev: pre-install
|
||||
|
||||
@@ -247,6 +247,23 @@ spec:
|
||||
properties:
|
||||
spec:
|
||||
properties:
|
||||
proxies:
|
||||
items:
|
||||
properties:
|
||||
connectTimeout:
|
||||
type: integer
|
||||
listenerPort:
|
||||
type: integer
|
||||
name:
|
||||
type: string
|
||||
serverAddress:
|
||||
type: string
|
||||
serverPort:
|
||||
type: integer
|
||||
type:
|
||||
type: string
|
||||
type: object
|
||||
type: array
|
||||
registries:
|
||||
items:
|
||||
properties:
|
||||
@@ -309,6 +326,8 @@ spec:
|
||||
type: integer
|
||||
protocol:
|
||||
type: string
|
||||
proxyName:
|
||||
type: string
|
||||
sni:
|
||||
type: string
|
||||
type:
|
||||
|
||||
@@ -65,6 +65,7 @@ type McpBridge struct {
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Registries []*RegistryConfig `protobuf:"bytes,1,rep,name=registries,proto3" json:"registries,omitempty"`
|
||||
Proxies []*ProxyConfig `protobuf:"bytes,2,rep,name=proxies,proto3" json:"proxies,omitempty"`
|
||||
}
|
||||
|
||||
func (x *McpBridge) Reset() {
|
||||
@@ -106,6 +107,13 @@ func (x *McpBridge) GetRegistries() []*RegistryConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *McpBridge) GetProxies() []*ProxyConfig {
|
||||
if x != nil {
|
||||
return x.Proxies
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type RegistryConfig struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
@@ -136,6 +144,7 @@ type RegistryConfig struct {
|
||||
EnableScopeMcpServers *wrappers.BoolValue `protobuf:"bytes,23,opt,name=enableScopeMcpServers,proto3" json:"enableScopeMcpServers,omitempty"`
|
||||
AllowMcpServers []string `protobuf:"bytes,24,rep,name=allowMcpServers,proto3" json:"allowMcpServers,omitempty"`
|
||||
Metadata map[string]*InnerMap `protobuf:"bytes,25,rep,name=metadata,proto3" json:"metadata,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
|
||||
ProxyName string `protobuf:"bytes,26,opt,name=proxyName,proto3" json:"proxyName,omitempty"`
|
||||
}
|
||||
|
||||
func (x *RegistryConfig) Reset() {
|
||||
@@ -345,6 +354,100 @@ func (x *RegistryConfig) GetMetadata() map[string]*InnerMap {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *RegistryConfig) GetProxyName() string {
|
||||
if x != nil {
|
||||
return x.ProxyName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type ProxyConfig struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Type string `protobuf:"bytes,1,opt,name=type,proto3" json:"type,omitempty"`
|
||||
Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"`
|
||||
ServerAddress string `protobuf:"bytes,3,opt,name=serverAddress,proto3" json:"serverAddress,omitempty"`
|
||||
ServerPort uint32 `protobuf:"varint,4,opt,name=serverPort,proto3" json:"serverPort,omitempty"`
|
||||
ListenerPort uint32 `protobuf:"varint,5,opt,name=listenerPort,proto3" json:"listenerPort,omitempty"`
|
||||
ConnectTimeout uint32 `protobuf:"varint,6,opt,name=connectTimeout,proto3" json:"connectTimeout,omitempty"`
|
||||
}
|
||||
|
||||
func (x *ProxyConfig) Reset() {
|
||||
*x = ProxyConfig{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_networking_v1_mcp_bridge_proto_msgTypes[2]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *ProxyConfig) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*ProxyConfig) ProtoMessage() {}
|
||||
|
||||
func (x *ProxyConfig) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_networking_v1_mcp_bridge_proto_msgTypes[2]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use ProxyConfig.ProtoReflect.Descriptor instead.
|
||||
func (*ProxyConfig) Descriptor() ([]byte, []int) {
|
||||
return file_networking_v1_mcp_bridge_proto_rawDescGZIP(), []int{2}
|
||||
}
|
||||
|
||||
func (x *ProxyConfig) GetType() string {
|
||||
if x != nil {
|
||||
return x.Type
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *ProxyConfig) GetName() string {
|
||||
if x != nil {
|
||||
return x.Name
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *ProxyConfig) GetServerAddress() string {
|
||||
if x != nil {
|
||||
return x.ServerAddress
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *ProxyConfig) GetServerPort() uint32 {
|
||||
if x != nil {
|
||||
return x.ServerPort
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *ProxyConfig) GetListenerPort() uint32 {
|
||||
if x != nil {
|
||||
return x.ListenerPort
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *ProxyConfig) GetConnectTimeout() uint32 {
|
||||
if x != nil {
|
||||
return x.ConnectTimeout
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
type InnerMap struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
@@ -356,7 +459,7 @@ type InnerMap struct {
|
||||
func (x *InnerMap) Reset() {
|
||||
*x = InnerMap{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_networking_v1_mcp_bridge_proto_msgTypes[2]
|
||||
mi := &file_networking_v1_mcp_bridge_proto_msgTypes[3]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -369,7 +472,7 @@ func (x *InnerMap) String() string {
|
||||
func (*InnerMap) ProtoMessage() {}
|
||||
|
||||
func (x *InnerMap) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_networking_v1_mcp_bridge_proto_msgTypes[2]
|
||||
mi := &file_networking_v1_mcp_bridge_proto_msgTypes[3]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -382,7 +485,7 @@ func (x *InnerMap) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use InnerMap.ProtoReflect.Descriptor instead.
|
||||
func (*InnerMap) Descriptor() ([]byte, []int) {
|
||||
return file_networking_v1_mcp_bridge_proto_rawDescGZIP(), []int{2}
|
||||
return file_networking_v1_mcp_bridge_proto_rawDescGZIP(), []int{3}
|
||||
}
|
||||
|
||||
func (x *InnerMap) GetInnerMap() map[string]string {
|
||||
@@ -404,100 +507,119 @@ var file_networking_v1_mcp_bridge_proto_rawDesc = []byte{
|
||||
0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x77, 0x72, 0x61, 0x70, 0x70, 0x65,
|
||||
0x72, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65,
|
||||
0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x73, 0x74, 0x72, 0x75, 0x63, 0x74,
|
||||
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x52, 0x0a, 0x09, 0x4d, 0x63, 0x70, 0x42, 0x72, 0x69,
|
||||
0x64, 0x67, 0x65, 0x12, 0x45, 0x0a, 0x0a, 0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x69, 0x65,
|
||||
0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x25, 0x2e, 0x68, 0x69, 0x67, 0x72, 0x65, 0x73,
|
||||
0x73, 0x2e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x69, 0x6e, 0x67, 0x2e, 0x76, 0x31, 0x2e,
|
||||
0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a,
|
||||
0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x69, 0x65, 0x73, 0x22, 0xa8, 0x09, 0x0a, 0x0e, 0x52,
|
||||
0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x17, 0x0a,
|
||||
0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x03, 0xe0, 0x41, 0x02,
|
||||
0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02,
|
||||
0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x6f,
|
||||
0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x42, 0x03, 0xe0, 0x41, 0x02, 0x52,
|
||||
0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x17, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18,
|
||||
0x04, 0x20, 0x01, 0x28, 0x0d, 0x42, 0x03, 0xe0, 0x41, 0x02, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74,
|
||||
0x12, 0x2e, 0x0a, 0x12, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73,
|
||||
0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x6e, 0x61,
|
||||
0x63, 0x6f, 0x73, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72,
|
||||
0x12, 0x26, 0x0a, 0x0e, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4b,
|
||||
0x65, 0x79, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x41,
|
||||
0x63, 0x63, 0x65, 0x73, 0x73, 0x4b, 0x65, 0x79, 0x12, 0x26, 0x0a, 0x0e, 0x6e, 0x61, 0x63, 0x6f,
|
||||
0x73, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x4b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09,
|
||||
0x52, 0x0e, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x4b, 0x65, 0x79,
|
||||
0x12, 0x2a, 0x0a, 0x10, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61,
|
||||
0x63, 0x65, 0x49, 0x64, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x10, 0x6e, 0x61, 0x63, 0x6f,
|
||||
0x73, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x49, 0x64, 0x12, 0x26, 0x0a, 0x0e,
|
||||
0x6e, 0x61, 0x63, 0x6f, 0x73, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, 0x09,
|
||||
0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x4e, 0x61, 0x6d, 0x65, 0x73,
|
||||
0x70, 0x61, 0x63, 0x65, 0x12, 0x20, 0x0a, 0x0b, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x47, 0x72, 0x6f,
|
||||
0x75, 0x70, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0b, 0x6e, 0x61, 0x63, 0x6f, 0x73,
|
||||
0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x52,
|
||||
0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x0b,
|
||||
0x20, 0x01, 0x28, 0x03, 0x52, 0x14, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x52, 0x65, 0x66, 0x72, 0x65,
|
||||
0x73, 0x68, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x28, 0x0a, 0x0f, 0x63, 0x6f,
|
||||
0x6e, 0x73, 0x75, 0x6c, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, 0x0c, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x0f, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x4e, 0x61, 0x6d, 0x65, 0x73,
|
||||
0x70, 0x61, 0x63, 0x65, 0x12, 0x26, 0x0a, 0x0e, 0x7a, 0x6b, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63,
|
||||
0x65, 0x73, 0x50, 0x61, 0x74, 0x68, 0x18, 0x0d, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0e, 0x7a, 0x6b,
|
||||
0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x50, 0x61, 0x74, 0x68, 0x12, 0x2a, 0x0a, 0x10,
|
||||
0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x44, 0x61, 0x74, 0x61, 0x63, 0x65, 0x6e, 0x74, 0x65, 0x72,
|
||||
0x18, 0x0e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x10, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x44, 0x61,
|
||||
0x74, 0x61, 0x63, 0x65, 0x6e, 0x74, 0x65, 0x72, 0x12, 0x2a, 0x0a, 0x10, 0x63, 0x6f, 0x6e, 0x73,
|
||||
0x75, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x54, 0x61, 0x67, 0x18, 0x0f, 0x20, 0x01,
|
||||
0x28, 0x09, 0x52, 0x10, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63,
|
||||
0x65, 0x54, 0x61, 0x67, 0x12, 0x34, 0x0a, 0x15, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x52, 0x65,
|
||||
0x66, 0x72, 0x65, 0x73, 0x68, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x10, 0x20,
|
||||
0x01, 0x28, 0x03, 0x52, 0x15, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x52, 0x65, 0x66, 0x72, 0x65,
|
||||
0x73, 0x68, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x26, 0x0a, 0x0e, 0x61, 0x75,
|
||||
0x74, 0x68, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x11, 0x20, 0x01,
|
||||
0x28, 0x09, 0x52, 0x0e, 0x61, 0x75, 0x74, 0x68, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x4e, 0x61,
|
||||
0x6d, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x12,
|
||||
0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x10,
|
||||
0x0a, 0x03, 0x73, 0x6e, 0x69, 0x18, 0x13, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x73, 0x6e, 0x69,
|
||||
0x12, 0x36, 0x0a, 0x16, 0x6d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x45, 0x78, 0x70,
|
||||
0x6f, 0x72, 0x74, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x14, 0x20, 0x03, 0x28, 0x09,
|
||||
0x52, 0x16, 0x6d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x45, 0x78, 0x70, 0x6f, 0x72,
|
||||
0x74, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x2a, 0x0a, 0x10, 0x6d, 0x63, 0x70, 0x53,
|
||||
0x65, 0x72, 0x76, 0x65, 0x72, 0x42, 0x61, 0x73, 0x65, 0x55, 0x72, 0x6c, 0x18, 0x15, 0x20, 0x01,
|
||||
0x28, 0x09, 0x52, 0x10, 0x6d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x42, 0x61, 0x73,
|
||||
0x65, 0x55, 0x72, 0x6c, 0x12, 0x44, 0x0a, 0x0f, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x4d, 0x43,
|
||||
0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x18, 0x16, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e,
|
||||
0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e,
|
||||
0x42, 0x6f, 0x6f, 0x6c, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x52, 0x0f, 0x65, 0x6e, 0x61, 0x62, 0x6c,
|
||||
0x65, 0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x50, 0x0a, 0x15, 0x65, 0x6e,
|
||||
0x61, 0x62, 0x6c, 0x65, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x4d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76,
|
||||
0x65, 0x72, 0x73, 0x18, 0x17, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67,
|
||||
0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x42, 0x6f, 0x6f, 0x6c,
|
||||
0x56, 0x61, 0x6c, 0x75, 0x65, 0x52, 0x15, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x63, 0x6f,
|
||||
0x70, 0x65, 0x4d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x28, 0x0a, 0x0f,
|
||||
0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x4d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18,
|
||||
0x18, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0f, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x4d, 0x63, 0x70, 0x53,
|
||||
0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x4f, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61,
|
||||
0x74, 0x61, 0x18, 0x19, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x33, 0x2e, 0x68, 0x69, 0x67, 0x72, 0x65,
|
||||
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x90, 0x01, 0x0a, 0x09, 0x4d, 0x63, 0x70, 0x42, 0x72,
|
||||
0x69, 0x64, 0x67, 0x65, 0x12, 0x45, 0x0a, 0x0a, 0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x69,
|
||||
0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x25, 0x2e, 0x68, 0x69, 0x67, 0x72, 0x65,
|
||||
0x73, 0x73, 0x2e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x69, 0x6e, 0x67, 0x2e, 0x76, 0x31,
|
||||
0x2e, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e,
|
||||
0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, 0x6d,
|
||||
0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x1a, 0x5c, 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, 0x64,
|
||||
0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18,
|
||||
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x35, 0x0a, 0x05, 0x76, 0x61,
|
||||
0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x68, 0x69, 0x67, 0x72,
|
||||
0x65, 0x73, 0x73, 0x2e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x69, 0x6e, 0x67, 0x2e, 0x76,
|
||||
0x31, 0x2e, 0x49, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61, 0x70, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75,
|
||||
0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x93, 0x01, 0x0a, 0x08, 0x49, 0x6e, 0x6e, 0x65, 0x72, 0x4d,
|
||||
0x61, 0x70, 0x12, 0x4a, 0x0a, 0x09, 0x69, 0x6e, 0x6e, 0x65, 0x72, 0x5f, 0x6d, 0x61, 0x70, 0x18,
|
||||
0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2d, 0x2e, 0x68, 0x69, 0x67, 0x72, 0x65, 0x73, 0x73, 0x2e,
|
||||
0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x69, 0x6e, 0x67, 0x2e, 0x76, 0x31, 0x2e, 0x49, 0x6e,
|
||||
0x6e, 0x65, 0x72, 0x4d, 0x61, 0x70, 0x2e, 0x49, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61, 0x70, 0x45,
|
||||
0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, 0x69, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61, 0x70, 0x1a, 0x3b,
|
||||
0x0a, 0x0d, 0x49, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61, 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12,
|
||||
0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65,
|
||||
0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09,
|
||||
0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x2e, 0x5a, 0x2c, 0x67,
|
||||
0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x61, 0x6c, 0x69, 0x62, 0x61, 0x62,
|
||||
0x61, 0x2f, 0x68, 0x69, 0x67, 0x72, 0x65, 0x73, 0x73, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x6e, 0x65,
|
||||
0x74, 0x77, 0x6f, 0x72, 0x6b, 0x69, 0x6e, 0x67, 0x2f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f,
|
||||
0x74, 0x6f, 0x33,
|
||||
0x2e, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52,
|
||||
0x0a, 0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x69, 0x65, 0x73, 0x12, 0x3c, 0x0a, 0x07, 0x70,
|
||||
0x72, 0x6f, 0x78, 0x69, 0x65, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x22, 0x2e, 0x68,
|
||||
0x69, 0x67, 0x72, 0x65, 0x73, 0x73, 0x2e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x69, 0x6e,
|
||||
0x67, 0x2e, 0x76, 0x31, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67,
|
||||
0x52, 0x07, 0x70, 0x72, 0x6f, 0x78, 0x69, 0x65, 0x73, 0x22, 0xc6, 0x09, 0x0a, 0x0e, 0x52, 0x65,
|
||||
0x67, 0x69, 0x73, 0x74, 0x72, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x17, 0x0a, 0x04,
|
||||
0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x03, 0xe0, 0x41, 0x02, 0x52,
|
||||
0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x6f, 0x6d,
|
||||
0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x42, 0x03, 0xe0, 0x41, 0x02, 0x52, 0x06,
|
||||
0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x17, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x04,
|
||||
0x20, 0x01, 0x28, 0x0d, 0x42, 0x03, 0xe0, 0x41, 0x02, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12,
|
||||
0x2e, 0x0a, 0x12, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x53,
|
||||
0x65, 0x72, 0x76, 0x65, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x6e, 0x61, 0x63,
|
||||
0x6f, 0x73, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12,
|
||||
0x26, 0x0a, 0x0e, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4b, 0x65,
|
||||
0x79, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x41, 0x63,
|
||||
0x63, 0x65, 0x73, 0x73, 0x4b, 0x65, 0x79, 0x12, 0x26, 0x0a, 0x0e, 0x6e, 0x61, 0x63, 0x6f, 0x73,
|
||||
0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x4b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52,
|
||||
0x0e, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x4b, 0x65, 0x79, 0x12,
|
||||
0x2a, 0x0a, 0x10, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63,
|
||||
0x65, 0x49, 0x64, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x10, 0x6e, 0x61, 0x63, 0x6f, 0x73,
|
||||
0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x49, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x6e,
|
||||
0x61, 0x63, 0x6f, 0x73, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, 0x09, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x0e, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70,
|
||||
0x61, 0x63, 0x65, 0x12, 0x20, 0x0a, 0x0b, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x47, 0x72, 0x6f, 0x75,
|
||||
0x70, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0b, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x47,
|
||||
0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x52, 0x65,
|
||||
0x66, 0x72, 0x65, 0x73, 0x68, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x0b, 0x20,
|
||||
0x01, 0x28, 0x03, 0x52, 0x14, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73,
|
||||
0x68, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x28, 0x0a, 0x0f, 0x63, 0x6f, 0x6e,
|
||||
0x73, 0x75, 0x6c, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, 0x0c, 0x20, 0x01,
|
||||
0x28, 0x09, 0x52, 0x0f, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70,
|
||||
0x61, 0x63, 0x65, 0x12, 0x26, 0x0a, 0x0e, 0x7a, 0x6b, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65,
|
||||
0x73, 0x50, 0x61, 0x74, 0x68, 0x18, 0x0d, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0e, 0x7a, 0x6b, 0x53,
|
||||
0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x50, 0x61, 0x74, 0x68, 0x12, 0x2a, 0x0a, 0x10, 0x63,
|
||||
0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x44, 0x61, 0x74, 0x61, 0x63, 0x65, 0x6e, 0x74, 0x65, 0x72, 0x18,
|
||||
0x0e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x10, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x44, 0x61, 0x74,
|
||||
0x61, 0x63, 0x65, 0x6e, 0x74, 0x65, 0x72, 0x12, 0x2a, 0x0a, 0x10, 0x63, 0x6f, 0x6e, 0x73, 0x75,
|
||||
0x6c, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x54, 0x61, 0x67, 0x18, 0x0f, 0x20, 0x01, 0x28,
|
||||
0x09, 0x52, 0x10, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65,
|
||||
0x54, 0x61, 0x67, 0x12, 0x34, 0x0a, 0x15, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x52, 0x65, 0x66,
|
||||
0x72, 0x65, 0x73, 0x68, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x10, 0x20, 0x01,
|
||||
0x28, 0x03, 0x52, 0x15, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73,
|
||||
0x68, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x26, 0x0a, 0x0e, 0x61, 0x75, 0x74,
|
||||
0x68, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x11, 0x20, 0x01, 0x28,
|
||||
0x09, 0x52, 0x0e, 0x61, 0x75, 0x74, 0x68, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x4e, 0x61, 0x6d,
|
||||
0x65, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x12, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x10, 0x0a,
|
||||
0x03, 0x73, 0x6e, 0x69, 0x18, 0x13, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x73, 0x6e, 0x69, 0x12,
|
||||
0x36, 0x0a, 0x16, 0x6d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x45, 0x78, 0x70, 0x6f,
|
||||
0x72, 0x74, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x14, 0x20, 0x03, 0x28, 0x09, 0x52,
|
||||
0x16, 0x6d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x45, 0x78, 0x70, 0x6f, 0x72, 0x74,
|
||||
0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x2a, 0x0a, 0x10, 0x6d, 0x63, 0x70, 0x53, 0x65,
|
||||
0x72, 0x76, 0x65, 0x72, 0x42, 0x61, 0x73, 0x65, 0x55, 0x72, 0x6c, 0x18, 0x15, 0x20, 0x01, 0x28,
|
||||
0x09, 0x52, 0x10, 0x6d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x42, 0x61, 0x73, 0x65,
|
||||
0x55, 0x72, 0x6c, 0x12, 0x44, 0x0a, 0x0f, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x4d, 0x43, 0x50,
|
||||
0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x18, 0x16, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67,
|
||||
0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x42,
|
||||
0x6f, 0x6f, 0x6c, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x52, 0x0f, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65,
|
||||
0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x50, 0x0a, 0x15, 0x65, 0x6e, 0x61,
|
||||
0x62, 0x6c, 0x65, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x4d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65,
|
||||
0x72, 0x73, 0x18, 0x17, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c,
|
||||
0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x42, 0x6f, 0x6f, 0x6c, 0x56,
|
||||
0x61, 0x6c, 0x75, 0x65, 0x52, 0x15, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x63, 0x6f, 0x70,
|
||||
0x65, 0x4d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x28, 0x0a, 0x0f, 0x61,
|
||||
0x6c, 0x6c, 0x6f, 0x77, 0x4d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x18,
|
||||
0x20, 0x03, 0x28, 0x09, 0x52, 0x0f, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x4d, 0x63, 0x70, 0x53, 0x65,
|
||||
0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x4f, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74,
|
||||
0x61, 0x18, 0x19, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x33, 0x2e, 0x68, 0x69, 0x67, 0x72, 0x65, 0x73,
|
||||
0x73, 0x2e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x69, 0x6e, 0x67, 0x2e, 0x76, 0x31, 0x2e,
|
||||
0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x4d,
|
||||
0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, 0x6d, 0x65,
|
||||
0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x1c, 0x0a, 0x09, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x4e,
|
||||
0x61, 0x6d, 0x65, 0x18, 0x1a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x70, 0x72, 0x6f, 0x78, 0x79,
|
||||
0x4e, 0x61, 0x6d, 0x65, 0x1a, 0x5c, 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61,
|
||||
0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01,
|
||||
0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x35, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65,
|
||||
0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x68, 0x69, 0x67, 0x72, 0x65, 0x73, 0x73,
|
||||
0x2e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x69, 0x6e, 0x67, 0x2e, 0x76, 0x31, 0x2e, 0x49,
|
||||
0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61, 0x70, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02,
|
||||
0x38, 0x01, 0x22, 0xdb, 0x01, 0x0a, 0x0b, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x43, 0x6f, 0x6e, 0x66,
|
||||
0x69, 0x67, 0x12, 0x17, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09,
|
||||
0x42, 0x03, 0xe0, 0x41, 0x02, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x17, 0x0a, 0x04, 0x6e,
|
||||
0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x03, 0xe0, 0x41, 0x02, 0x52, 0x04,
|
||||
0x6e, 0x61, 0x6d, 0x65, 0x12, 0x29, 0x0a, 0x0d, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64,
|
||||
0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x42, 0x03, 0xe0, 0x41, 0x02,
|
||||
0x52, 0x0d, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12,
|
||||
0x23, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20,
|
||||
0x01, 0x28, 0x0d, 0x42, 0x03, 0xe0, 0x41, 0x02, 0x52, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72,
|
||||
0x50, 0x6f, 0x72, 0x74, 0x12, 0x22, 0x0a, 0x0c, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x65, 0x72,
|
||||
0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0c, 0x6c, 0x69, 0x73, 0x74,
|
||||
0x65, 0x6e, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x6f, 0x6e, 0x6e,
|
||||
0x65, 0x63, 0x74, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0d,
|
||||
0x52, 0x0e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74,
|
||||
0x22, 0x93, 0x01, 0x0a, 0x08, 0x49, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61, 0x70, 0x12, 0x4a, 0x0a,
|
||||
0x09, 0x69, 0x6e, 0x6e, 0x65, 0x72, 0x5f, 0x6d, 0x61, 0x70, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b,
|
||||
0x32, 0x2d, 0x2e, 0x68, 0x69, 0x67, 0x72, 0x65, 0x73, 0x73, 0x2e, 0x6e, 0x65, 0x74, 0x77, 0x6f,
|
||||
0x72, 0x6b, 0x69, 0x6e, 0x67, 0x2e, 0x76, 0x31, 0x2e, 0x49, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61,
|
||||
0x70, 0x2e, 0x49, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61, 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52,
|
||||
0x08, 0x69, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61, 0x70, 0x1a, 0x3b, 0x0a, 0x0d, 0x49, 0x6e, 0x6e,
|
||||
0x65, 0x72, 0x4d, 0x61, 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65,
|
||||
0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05,
|
||||
0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c,
|
||||
0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x2e, 0x5a, 0x2c, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62,
|
||||
0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x61, 0x6c, 0x69, 0x62, 0x61, 0x62, 0x61, 0x2f, 0x68, 0x69, 0x67,
|
||||
0x72, 0x65, 0x73, 0x73, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b,
|
||||
0x69, 0x6e, 0x67, 0x2f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -512,27 +634,29 @@ func file_networking_v1_mcp_bridge_proto_rawDescGZIP() []byte {
|
||||
return file_networking_v1_mcp_bridge_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_networking_v1_mcp_bridge_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
|
||||
var file_networking_v1_mcp_bridge_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
|
||||
var file_networking_v1_mcp_bridge_proto_goTypes = []interface{}{
|
||||
(*McpBridge)(nil), // 0: higress.networking.v1.McpBridge
|
||||
(*RegistryConfig)(nil), // 1: higress.networking.v1.RegistryConfig
|
||||
(*InnerMap)(nil), // 2: higress.networking.v1.InnerMap
|
||||
nil, // 3: higress.networking.v1.RegistryConfig.MetadataEntry
|
||||
nil, // 4: higress.networking.v1.InnerMap.InnerMapEntry
|
||||
(*wrappers.BoolValue)(nil), // 5: google.protobuf.BoolValue
|
||||
(*ProxyConfig)(nil), // 2: higress.networking.v1.ProxyConfig
|
||||
(*InnerMap)(nil), // 3: higress.networking.v1.InnerMap
|
||||
nil, // 4: higress.networking.v1.RegistryConfig.MetadataEntry
|
||||
nil, // 5: higress.networking.v1.InnerMap.InnerMapEntry
|
||||
(*wrappers.BoolValue)(nil), // 6: google.protobuf.BoolValue
|
||||
}
|
||||
var file_networking_v1_mcp_bridge_proto_depIdxs = []int32{
|
||||
1, // 0: higress.networking.v1.McpBridge.registries:type_name -> higress.networking.v1.RegistryConfig
|
||||
5, // 1: higress.networking.v1.RegistryConfig.enableMCPServer:type_name -> google.protobuf.BoolValue
|
||||
5, // 2: higress.networking.v1.RegistryConfig.enableScopeMcpServers:type_name -> google.protobuf.BoolValue
|
||||
3, // 3: higress.networking.v1.RegistryConfig.metadata:type_name -> higress.networking.v1.RegistryConfig.MetadataEntry
|
||||
4, // 4: higress.networking.v1.InnerMap.inner_map:type_name -> higress.networking.v1.InnerMap.InnerMapEntry
|
||||
2, // 5: higress.networking.v1.RegistryConfig.MetadataEntry.value:type_name -> higress.networking.v1.InnerMap
|
||||
6, // [6:6] is the sub-list for method output_type
|
||||
6, // [6:6] is the sub-list for method input_type
|
||||
6, // [6:6] is the sub-list for extension type_name
|
||||
6, // [6:6] is the sub-list for extension extendee
|
||||
0, // [0:6] is the sub-list for field type_name
|
||||
2, // 1: higress.networking.v1.McpBridge.proxies:type_name -> higress.networking.v1.ProxyConfig
|
||||
6, // 2: higress.networking.v1.RegistryConfig.enableMCPServer:type_name -> google.protobuf.BoolValue
|
||||
6, // 3: higress.networking.v1.RegistryConfig.enableScopeMcpServers:type_name -> google.protobuf.BoolValue
|
||||
4, // 4: higress.networking.v1.RegistryConfig.metadata:type_name -> higress.networking.v1.RegistryConfig.MetadataEntry
|
||||
5, // 5: higress.networking.v1.InnerMap.inner_map:type_name -> higress.networking.v1.InnerMap.InnerMapEntry
|
||||
3, // 6: higress.networking.v1.RegistryConfig.MetadataEntry.value:type_name -> higress.networking.v1.InnerMap
|
||||
7, // [7:7] is the sub-list for method output_type
|
||||
7, // [7:7] is the sub-list for method input_type
|
||||
7, // [7:7] is the sub-list for extension type_name
|
||||
7, // [7:7] is the sub-list for extension extendee
|
||||
0, // [0:7] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_networking_v1_mcp_bridge_proto_init() }
|
||||
@@ -566,6 +690,18 @@ func file_networking_v1_mcp_bridge_proto_init() {
|
||||
}
|
||||
}
|
||||
file_networking_v1_mcp_bridge_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*ProxyConfig); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_networking_v1_mcp_bridge_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*InnerMap); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
@@ -584,7 +720,7 @@ func file_networking_v1_mcp_bridge_proto_init() {
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_networking_v1_mcp_bridge_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 5,
|
||||
NumMessages: 6,
|
||||
NumExtensions: 0,
|
||||
NumServices: 0,
|
||||
},
|
||||
|
||||
@@ -46,6 +46,7 @@ option go_package = "github.com/alibaba/higress/api/networking/v1";
|
||||
// -->
|
||||
message McpBridge {
|
||||
repeated RegistryConfig registries = 1;
|
||||
repeated ProxyConfig proxies = 2;
|
||||
}
|
||||
|
||||
message RegistryConfig {
|
||||
@@ -74,6 +75,16 @@ message RegistryConfig {
|
||||
google.protobuf.BoolValue enableScopeMcpServers = 23;
|
||||
repeated string allowMcpServers = 24;
|
||||
map<string, InnerMap> metadata = 25;
|
||||
string proxyName = 26;
|
||||
}
|
||||
|
||||
message ProxyConfig {
|
||||
string type = 1 [(google.api.field_behavior) = REQUIRED];
|
||||
string name = 2 [(google.api.field_behavior) = REQUIRED];
|
||||
string serverAddress = 3 [(google.api.field_behavior) = REQUIRED];
|
||||
uint32 serverPort = 4 [(google.api.field_behavior) = REQUIRED];
|
||||
uint32 listenerPort = 5;
|
||||
uint32 connectTimeout = 6;
|
||||
}
|
||||
|
||||
message InnerMap {
|
||||
|
||||
@@ -47,6 +47,27 @@ func (in *RegistryConfig) DeepCopyInterface() interface{} {
|
||||
return in.DeepCopy()
|
||||
}
|
||||
|
||||
// DeepCopyInto supports using ProxyConfig within kubernetes types, where deepcopy-gen is used.
|
||||
func (in *ProxyConfig) DeepCopyInto(out *ProxyConfig) {
|
||||
p := proto.Clone(in).(*ProxyConfig)
|
||||
*out = *p
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ProxyConfig. Required by controller-gen.
|
||||
func (in *ProxyConfig) DeepCopy() *ProxyConfig {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(ProxyConfig)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInterface is an autogenerated deepcopy function, copying the receiver, creating a new ProxyConfig. Required by controller-gen.
|
||||
func (in *ProxyConfig) DeepCopyInterface() interface{} {
|
||||
return in.DeepCopy()
|
||||
}
|
||||
|
||||
// DeepCopyInto supports using InnerMap within kubernetes types, where deepcopy-gen is used.
|
||||
func (in *InnerMap) DeepCopyInto(out *InnerMap) {
|
||||
p := proto.Clone(in).(*InnerMap)
|
||||
|
||||
@@ -28,6 +28,17 @@ func (this *RegistryConfig) UnmarshalJSON(b []byte) error {
|
||||
return McpBridgeUnmarshaler.Unmarshal(bytes.NewReader(b), this)
|
||||
}
|
||||
|
||||
// MarshalJSON is a custom marshaler for ProxyConfig
|
||||
func (this *ProxyConfig) MarshalJSON() ([]byte, error) {
|
||||
str, err := McpBridgeMarshaler.MarshalToString(this)
|
||||
return []byte(str), err
|
||||
}
|
||||
|
||||
// UnmarshalJSON is a custom unmarshaler for ProxyConfig
|
||||
func (this *ProxyConfig) UnmarshalJSON(b []byte) error {
|
||||
return McpBridgeUnmarshaler.Unmarshal(bytes.NewReader(b), this)
|
||||
}
|
||||
|
||||
// MarshalJSON is a custom marshaler for InnerMap
|
||||
func (this *InnerMap) MarshalJSON() ([]byte, error) {
|
||||
str, err := McpBridgeMarshaler.MarshalToString(this)
|
||||
|
||||
@@ -95,6 +95,6 @@ generate-k8s-client:
|
||||
|
||||
|
||||
.PHONY: clean-k8s-client
|
||||
clean-k8s-cliennt:
|
||||
clean-k8s-client:
|
||||
# remove generated code
|
||||
@rm -rf pkg/
|
||||
|
||||
@@ -6,11 +6,11 @@ ARG BASE_VERSION=latest
|
||||
|
||||
ARG HUB
|
||||
|
||||
ARG TARGETARCH
|
||||
|
||||
# The following section is used as base image if BASE_DISTRIBUTION=debug
|
||||
# This base image is provided by istio, see: https://github.com/istio/istio/blob/master/docker/Dockerfile.base
|
||||
FROM ${HUB}/base:${BASE_VERSION}
|
||||
|
||||
ARG TARGETARCH
|
||||
FROM ${HUB}/base:${BASE_VERSION}-${TARGETARCH}
|
||||
|
||||
COPY ${TARGETARCH:-amd64}/higress /usr/local/bin/higress
|
||||
|
||||
|
||||
@@ -17,6 +17,11 @@ docker.higress: $(OUT_LINUX)/higress
|
||||
docker.higress: docker/Dockerfile.higress
|
||||
$(HIGRESS_DOCKER_RULE)
|
||||
|
||||
docker.higress-amd64: BUILD_ARGS=--build-arg BASE_VERSION=${HIGRESS_BASE_VERSION} --build-arg HUB=${HUB}
|
||||
docker.higress-amd64: $(AMD64_OUT_LINUX)/higress
|
||||
docker.higress-amd64: docker/Dockerfile.higress
|
||||
$(HIGRESS_DOCKER_AMD64_RULE)
|
||||
|
||||
docker.higress-buildx: BUILD_ARGS=--build-arg BASE_VERSION=${HIGRESS_BASE_VERSION} --build-arg HUB=${HUB}
|
||||
docker.higress-buildx: $(AMD64_OUT_LINUX)/higress
|
||||
docker.higress-buildx: $(ARM64_OUT_LINUX)/higress
|
||||
@@ -40,3 +45,4 @@ IMG_URL ?= $(HUB)/$(IMG):$(TAG)
|
||||
|
||||
HIGRESS_DOCKER_BUILDX_RULE ?= $(foreach VARIANT,$(DOCKER_BUILD_VARIANTS), time (mkdir -p $(HIGRESS_DOCKER_BUILD_TOP)/$@ && TARGET_ARCH=$(TARGET_ARCH) ./docker/docker-copy.sh $^ $(HIGRESS_DOCKER_BUILD_TOP)/$@ && cd $(HIGRESS_DOCKER_BUILD_TOP)/$@ $(BUILD_PRE) && docker buildx create --name higress --node higress0 --platform linux/amd64,linux/arm64 --use && docker buildx build --no-cache --platform linux/amd64,linux/arm64 $(BUILD_ARGS) --build-arg BASE_DISTRIBUTION=$(call normalize-tag,$(VARIANT)) -t $(IMG_URL)$(call variant-tag,$(VARIANT)) -f Dockerfile.higress . --push ); )
|
||||
HIGRESS_DOCKER_RULE ?= $(foreach VARIANT,$(DOCKER_BUILD_VARIANTS), time (mkdir -p $(HIGRESS_DOCKER_BUILD_TOP)/$@ && TARGET_ARCH=$(TARGET_ARCH) ./docker/docker-copy.sh $^ $(HIGRESS_DOCKER_BUILD_TOP)/$@ && cd $(HIGRESS_DOCKER_BUILD_TOP)/$@ $(BUILD_PRE) && docker build $(BUILD_ARGS) --build-arg BASE_DISTRIBUTION=$(call normalize-tag,$(VARIANT)) -t $(IMG_URL)$(call variant-tag,$(VARIANT)) -f Dockerfile.higress . ); )
|
||||
HIGRESS_DOCKER_AMD64_RULE ?= $(foreach VARIANT,$(DOCKER_BUILD_VARIANTS), time (mkdir -p $(HIGRESS_DOCKER_BUILD_TOP)/$@ && TARGET_ARCH=amd64 ./docker/docker-copy.sh $^ $(HIGRESS_DOCKER_BUILD_TOP)/$@ && cd $(HIGRESS_DOCKER_BUILD_TOP)/$@ $(BUILD_PRE) && docker build $(BUILD_ARGS) --build-arg BASE_DISTRIBUTION=$(call normalize-tag,$(VARIANT)) --build-arg TARGETARCH=amd64 -t $(IMG_URL)$(call variant-tag,$(VARIANT)) -f Dockerfile.higress . ); )
|
||||
|
||||
Submodule envoy/envoy updated: e2707255f1...7f18940fbc
@@ -1,5 +1,5 @@
|
||||
apiVersion: v2
|
||||
appVersion: 2.1.6
|
||||
appVersion: 2.1.7
|
||||
description: Helm chart for deploying higress gateways
|
||||
icon: https://higress.io/img/higress_logo_small.png
|
||||
home: http://higress.io/
|
||||
@@ -15,4 +15,4 @@ dependencies:
|
||||
repository: "file://../redis"
|
||||
version: 0.0.1
|
||||
type: application
|
||||
version: 2.1.6
|
||||
version: 2.1.7
|
||||
|
||||
@@ -247,6 +247,23 @@ spec:
|
||||
properties:
|
||||
spec:
|
||||
properties:
|
||||
proxies:
|
||||
items:
|
||||
properties:
|
||||
connectTimeout:
|
||||
type: integer
|
||||
listenerPort:
|
||||
type: integer
|
||||
name:
|
||||
type: string
|
||||
serverAddress:
|
||||
type: string
|
||||
serverPort:
|
||||
type: integer
|
||||
type:
|
||||
type: string
|
||||
type: object
|
||||
type: array
|
||||
registries:
|
||||
items:
|
||||
properties:
|
||||
@@ -309,6 +326,8 @@ spec:
|
||||
type: integer
|
||||
protocol:
|
||||
type: string
|
||||
proxyName:
|
||||
type: string
|
||||
sni:
|
||||
type: string
|
||||
type:
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
dependencies:
|
||||
- name: higress-core
|
||||
repository: file://../core
|
||||
version: 2.1.6
|
||||
version: 2.1.7
|
||||
- name: higress-console
|
||||
repository: https://higress.io/helm-charts/
|
||||
version: 2.1.6
|
||||
digest: sha256:c5bebb3bd92bf799804443faf9ab69e88ed26815a709e58911859b504b3d04db
|
||||
generated: "2025-07-30T21:13:57.834398+08:00"
|
||||
version: 2.1.7
|
||||
digest: sha256:c5bc8ddcc56c66751217aee5c7a40da0a906bfa9fc5c671cc4ae6e456db6bc21
|
||||
generated: "2025-09-01T15:19:26.228634+08:00"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
apiVersion: v2
|
||||
appVersion: 2.1.6
|
||||
appVersion: 2.1.7
|
||||
description: Helm chart for deploying Higress gateways
|
||||
icon: https://higress.io/img/higress_logo_small.png
|
||||
home: http://higress.io/
|
||||
@@ -12,9 +12,9 @@ sources:
|
||||
dependencies:
|
||||
- name: higress-core
|
||||
repository: "file://../core"
|
||||
version: 2.1.6
|
||||
version: 2.1.7
|
||||
- name: higress-console
|
||||
repository: "https://higress.io/helm-charts/"
|
||||
version: 2.1.6
|
||||
version: 2.1.7
|
||||
type: application
|
||||
version: 2.1.6
|
||||
version: 2.1.7
|
||||
|
||||
@@ -704,9 +704,9 @@ func TestK8sObject_ResolveK8sConflict(t *testing.T) {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
newObj := tt.o1.ResolveK8sConflict()
|
||||
if !newObj.Equal(tt.o2) {
|
||||
newObjjson, _ := newObj.JSON()
|
||||
wantedObjjson, _ := tt.o2.JSON()
|
||||
t.Errorf("Got: %s, want: %s", string(newObjjson), string(wantedObjjson))
|
||||
newObjJson, _ := newObj.JSON()
|
||||
wantedObjJson, _ := tt.o2.JSON()
|
||||
t.Errorf("Got: %s, want: %s", string(newObjJson), string(wantedObjJson))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -65,7 +65,7 @@ func (o *K8sInstaller) Install() error {
|
||||
return err1
|
||||
}
|
||||
fmt.Fprintf(o.writer, "\n✔️ Wrote Profile in kubernetes configmap: \"%s\" \n", profileName)
|
||||
fmt.Fprintf(o.writer, "\n Use bellow kubectl command to edit profile for upgrade. \n")
|
||||
fmt.Fprintf(o.writer, "\n Use below kubectl command to edit profile for upgrade. \n")
|
||||
fmt.Fprintf(o.writer, " ================================================================================== \n")
|
||||
names := strings.Split(profileName, "/")
|
||||
fmt.Fprintf(o.writer, " kubectl edit configmap %s -n %s \n", names[1], names[0])
|
||||
|
||||
Submodule istio/proxy updated: d411a4f019...ced6d8167a
@@ -93,6 +93,15 @@ func (p Protocol) IsUnsupported() bool {
|
||||
return p == Unsupported
|
||||
}
|
||||
|
||||
func (p Protocol) IsSupportedByProxy() bool {
|
||||
switch p {
|
||||
case HTTPS:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (p Protocol) String() string {
|
||||
return string(p)
|
||||
}
|
||||
|
||||
59
pkg/common/proxy.go
Normal file
59
pkg/common/proxy.go
Normal file
@@ -0,0 +1,59 @@
|
||||
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package common
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ProxyType string
|
||||
|
||||
const (
|
||||
ProxyType_Unknown ProxyType = "Unknown"
|
||||
ProxyType_HTTP ProxyType = "HTTP"
|
||||
ProxyType_HTTPS ProxyType = "HTTPS"
|
||||
ProxyType_SOCKS4 ProxyType = "SOCKS4"
|
||||
ProxyType_SOCKS5 ProxyType = "SOCKS5"
|
||||
)
|
||||
|
||||
func ParseProxyType(s string) ProxyType {
|
||||
switch strings.ToLower(s) {
|
||||
case "http":
|
||||
return ProxyType_HTTP
|
||||
case "https":
|
||||
return ProxyType_HTTPS
|
||||
case "socks4":
|
||||
return ProxyType_SOCKS4
|
||||
case "socks5":
|
||||
return ProxyType_SOCKS5
|
||||
}
|
||||
return ProxyType_Unknown
|
||||
}
|
||||
|
||||
func (p ProxyType) GetTransportProtocol() Protocol {
|
||||
switch p {
|
||||
case ProxyType_HTTP:
|
||||
return HTTP
|
||||
case ProxyType_HTTPS:
|
||||
return HTTPS
|
||||
case ProxyType_SOCKS4, ProxyType_SOCKS5:
|
||||
return TCP
|
||||
}
|
||||
return Unsupported
|
||||
}
|
||||
|
||||
func (p ProxyType) String() string {
|
||||
return string(p)
|
||||
}
|
||||
@@ -69,7 +69,7 @@ import (
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/wasmplugin"
|
||||
. "github.com/alibaba/higress/pkg/ingress/log"
|
||||
"github.com/alibaba/higress/pkg/kube"
|
||||
"github.com/alibaba/higress/registry/memory"
|
||||
"github.com/alibaba/higress/registry"
|
||||
"github.com/alibaba/higress/registry/reconcile"
|
||||
)
|
||||
|
||||
@@ -340,10 +340,6 @@ func (m *IngressConfig) listFromIngressControllers(typ config.GroupVersionKind,
|
||||
}
|
||||
IngressLog.Infof("Append %d configmap EnvoyFilters", len(configmapEnvoyFilters))
|
||||
}
|
||||
if len(envoyFilters) == 0 {
|
||||
IngressLog.Infof("resource type %s, configs number %d", typ, len(m.cachedEnvoyFilters))
|
||||
return m.cachedEnvoyFilters
|
||||
}
|
||||
envoyFilters = append(envoyFilters, m.cachedEnvoyFilters...)
|
||||
IngressLog.Infof("resource type %s, configs number %d", typ, len(envoyFilters))
|
||||
return envoyFilters
|
||||
@@ -490,6 +486,22 @@ func (m *IngressConfig) convertVirtualService(configs []common.WrapperConfig) []
|
||||
VirtualServices: map[string]*common.WrapperVirtualService{},
|
||||
HTTPRoutes: map[string][]*common.WrapperHTTPRoute{},
|
||||
Route2Ingress: map[string]*common.WrapperConfigWithRuleKey{},
|
||||
ServiceWrappers: make(map[string]*common.ServiceWrapper),
|
||||
ProxyWrappers: make(map[string]*common.ProxyWrapper),
|
||||
}
|
||||
if m.RegistryReconciler != nil {
|
||||
for _, sew := range m.RegistryReconciler.GetAllServiceWrapper() {
|
||||
hosts := sew.ServiceEntry.Hosts
|
||||
if len(hosts) == 0 {
|
||||
continue
|
||||
}
|
||||
for _, host := range hosts {
|
||||
convertOptions.ServiceWrappers[host] = sew
|
||||
}
|
||||
}
|
||||
for _, pw := range m.RegistryReconciler.GetAllProxyWrapper() {
|
||||
convertOptions.ProxyWrappers[pw.ProxyName] = pw
|
||||
}
|
||||
}
|
||||
|
||||
// convert http route
|
||||
@@ -616,6 +628,7 @@ func (m *IngressConfig) convertEnvoyFilter(convertOptions *common.ConvertOptions
|
||||
mappings := map[string]*common.Rule{}
|
||||
|
||||
initHttp2RpcGlobalConfig := true
|
||||
initMcpSseGlobalFilter := true
|
||||
for _, routes := range convertOptions.HTTPRoutes {
|
||||
for _, route := range routes {
|
||||
if strings.HasSuffix(route.HTTPRoute.Name, "app-root") {
|
||||
@@ -635,6 +648,19 @@ func (m *IngressConfig) convertEnvoyFilter(convertOptions *common.ConvertOptions
|
||||
}
|
||||
}
|
||||
|
||||
loadBalance := route.WrapperConfig.AnnotationsConfig.LoadBalance
|
||||
if loadBalance != nil && loadBalance.McpSseStateful {
|
||||
IngressLog.Infof("Found MCP SSE stateful session for route %s", route.HTTPRoute.Name)
|
||||
envoyFilter, err := m.constructMcpSseStatefulSessionEnvoyFilter(route, m.namespace, initMcpSseGlobalFilter)
|
||||
if err != nil {
|
||||
IngressLog.Errorf("Construct MCP SSE stateful session EnvoyFilter error %v", err)
|
||||
} else {
|
||||
IngressLog.Infof("Append MCP SSE stateful session EnvoyFilter for route %s", route.HTTPRoute.Name)
|
||||
envoyFilters = append(envoyFilters, *envoyFilter)
|
||||
initMcpSseGlobalFilter = false
|
||||
}
|
||||
}
|
||||
|
||||
auth := route.WrapperConfig.AnnotationsConfig.Auth
|
||||
if auth == nil {
|
||||
continue
|
||||
@@ -669,6 +695,12 @@ func (m *IngressConfig) convertEnvoyFilter(convertOptions *common.ConvertOptions
|
||||
}
|
||||
}
|
||||
|
||||
if proxyEnvoyFilters := constructProxyEnvoyFilters(convertOptions.ProxyWrappers, convertOptions.ServiceWrappers, m.namespace); len(proxyEnvoyFilters) != 0 {
|
||||
for _, ef := range proxyEnvoyFilters {
|
||||
envoyFilters = append(envoyFilters, *ef)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO Support other envoy filters
|
||||
|
||||
IngressLog.Infof("Found %d number of envoyFilters", len(envoyFilters))
|
||||
@@ -1113,7 +1145,7 @@ func (m *IngressConfig) AddOrUpdateWasmPlugin(clusterNamespacedName util.Cluster
|
||||
Labels: map[string]string{constants.AlwaysPushLabel: "true"},
|
||||
}
|
||||
for _, f := range m.wasmPluginHandlers {
|
||||
IngressLog.Debug("WasmPlugin triggerd update")
|
||||
IngressLog.Debug("WasmPlugin triggered update")
|
||||
f(config.Config{Meta: metadata}, config.Config{Meta: metadata}, istiomodel.EventUpdate)
|
||||
}
|
||||
istioWasmPlugin, err := m.convertIstioWasmPlugin(&wasmPlugin.Spec)
|
||||
@@ -1155,7 +1187,7 @@ func (m *IngressConfig) DeleteWasmPlugin(clusterNamespacedName util.ClusterNames
|
||||
Labels: map[string]string{constants.AlwaysPushLabel: "true"},
|
||||
}
|
||||
for _, f := range m.wasmPluginHandlers {
|
||||
IngressLog.Debug("WasmPlugin triggerd update")
|
||||
IngressLog.Debug("WasmPlugin triggered update")
|
||||
f(config.Config{Meta: metadata}, config.Config{Meta: metadata}, istiomodel.EventDelete)
|
||||
}
|
||||
}
|
||||
@@ -1211,23 +1243,23 @@ func (m *IngressConfig) AddOrUpdateMcpBridge(clusterNamespacedName util.ClusterN
|
||||
}
|
||||
|
||||
for _, f := range m.serviceEntryHandlers {
|
||||
IngressLog.Debug("McpBridge triggerd serviceEntry update")
|
||||
IngressLog.Debug("McpBridge triggered serviceEntry update")
|
||||
f(config.Config{Meta: seMetadata}, config.Config{Meta: seMetadata}, istiomodel.EventUpdate)
|
||||
}
|
||||
for _, f := range m.destinationRuleHandlers {
|
||||
IngressLog.Debug("McpBridge triggerd destinationRule update")
|
||||
IngressLog.Debug("McpBridge triggered destinationRule update")
|
||||
f(config.Config{Meta: drMetadata}, config.Config{Meta: drMetadata}, istiomodel.EventUpdate)
|
||||
}
|
||||
for _, f := range m.virtualServiceHandlers {
|
||||
IngressLog.Debug("McpBridge triggerd virtualservice update")
|
||||
IngressLog.Debug("McpBridge triggered virtualservice update")
|
||||
f(config.Config{Meta: vsMetadata}, config.Config{Meta: vsMetadata}, istiomodel.EventUpdate)
|
||||
}
|
||||
for _, f := range m.wasmPluginHandlers {
|
||||
IngressLog.Debug("McpBridge triggerd wasmplugin update")
|
||||
IngressLog.Debug("McpBridge triggered wasmplugin update")
|
||||
f(config.Config{Meta: wasmMetadata}, config.Config{Meta: wasmMetadata}, istiomodel.EventUpdate)
|
||||
}
|
||||
for _, f := range m.envoyFilterHandlers {
|
||||
IngressLog.Debug("McpBridge triggerd envoyfilter update")
|
||||
IngressLog.Debug("McpBridge triggered envoyfilter update")
|
||||
f(config.Config{Meta: efMetadata}, config.Config{Meta: efMetadata}, istiomodel.EventUpdate)
|
||||
}
|
||||
}, m.localKubeClient, m.namespace, m.clusterId.String())
|
||||
@@ -1295,7 +1327,7 @@ func (m *IngressConfig) DeleteHttp2Rpc(clusterNamespacedName util.ClusterNamespa
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
if hit {
|
||||
IngressLog.Infof("Http2Rpc triggerd deleted event executed %s", clusterNamespacedName.Name)
|
||||
IngressLog.Infof("Http2Rpc triggered deleted event executed %s", clusterNamespacedName.Name)
|
||||
push := func(gvk config.GroupVersionKind) {
|
||||
m.XDSUpdater.ConfigUpdate(&istiomodel.PushRequest{
|
||||
Full: true,
|
||||
@@ -1493,7 +1525,7 @@ func (m *IngressConfig) constructHttp2RpcEnvoyFilter(http2rpcConfig *annotations
|
||||
return &config.Config{
|
||||
Meta: config.Meta{
|
||||
GroupVersionKind: gvk.EnvoyFilter,
|
||||
Name: common.CreateConvertedName(constants.IstioIngressGatewayName, http2rpcConfig.Name),
|
||||
Name: common.CreateConvertedName(constants.IstioIngressGatewayName, "http2rpc", http2rpcConfig.Name, "route", common.ConvertToDNSLabelValid(httpRoute.Name)),
|
||||
Namespace: namespace,
|
||||
},
|
||||
Spec: &networking.EnvoyFilter{
|
||||
@@ -1675,28 +1707,150 @@ func constructBasicAuthEnvoyFilter(rules *common.BasicAuthRules, namespace strin
|
||||
}, nil
|
||||
}
|
||||
|
||||
func QueryByName(serviceEntries []*memory.ServiceWrapper, serviceName string) (*memory.ServiceWrapper, error) {
|
||||
IngressLog.Infof("Found http2rpc serviceEntries %s", serviceEntries)
|
||||
for _, se := range serviceEntries {
|
||||
if se.ServiceName == serviceName {
|
||||
return se, nil
|
||||
func constructProxyEnvoyFilters(proxyWrappers map[string]*common.ProxyWrapper, serviceWrappers map[string]*common.ServiceWrapper, namespace string) []*config.Config {
|
||||
var envoyFilters []*config.Config
|
||||
for _, proxyWrapper := range proxyWrappers {
|
||||
envoyFilters = append(envoyFilters, &config.Config{
|
||||
Meta: config.Meta{
|
||||
GroupVersionKind: gvk.EnvoyFilter,
|
||||
Name: common.CreateConvertedName(constants.IstioIngressGatewayName, "proxy", proxyWrapper.ProxyName),
|
||||
Namespace: namespace,
|
||||
},
|
||||
Spec: proxyWrapper.EnvoyFilter,
|
||||
})
|
||||
}
|
||||
|
||||
// Create a cluster for each service that uses a proxy.
|
||||
var serviceProxyPatches []*networking.EnvoyFilter_EnvoyConfigObjectPatch
|
||||
for _, serviceWrapper := range serviceWrappers {
|
||||
proxyConfig := serviceWrapper.ProxyConfig
|
||||
if proxyConfig == nil || proxyConfig.ProxyName == "" {
|
||||
continue
|
||||
}
|
||||
IngressLog.Debugf("Found service %s using proxy %s", serviceWrapper.ServiceName, proxyConfig.ProxyName)
|
||||
if err := validateServiceWrapperForProxy(serviceWrapper); err != nil {
|
||||
IngressLog.Warnf("Service wrapper validation failed for proxy: %v", err)
|
||||
continue
|
||||
}
|
||||
proxyWrapper := proxyWrappers[proxyConfig.ProxyName]
|
||||
if proxyWrapper == nil {
|
||||
IngressLog.Warnf("Service %s has proxy config %s, but no corresponding proxy wrapper found", serviceWrapper.ServiceName, proxyConfig.ProxyName)
|
||||
continue
|
||||
}
|
||||
if !proxyConfig.UpstreamProtocol.IsSupportedByProxy() {
|
||||
IngressLog.Warnf("Proxy %s does not support upstream protocol %s, skipping EnvoyFilter construction for service %s")
|
||||
continue
|
||||
}
|
||||
if proxyWrapper.EnvoyFilter == nil {
|
||||
IngressLog.Warnf("Proxy %s has no EnvoyFilter generated, meaning not ready for use.", proxyConfig.ProxyName)
|
||||
continue
|
||||
}
|
||||
se := serviceWrapper.ServiceEntry
|
||||
if se == nil || len(se.Hosts) == 0 || len(se.Ports) == 0 {
|
||||
continue
|
||||
}
|
||||
for _, host := range se.Hosts {
|
||||
IngressLog.Debugf("Constructing EnvoyFilter for service %s using proxy %s", host, proxyConfig.ProxyName)
|
||||
for _, port := range se.Ports {
|
||||
if port == nil || port.Number <= 0 {
|
||||
continue
|
||||
}
|
||||
clusterName := fmt.Sprintf("outbound|%d||%s", port.Number, host)
|
||||
|
||||
// We need to delete the original cluster and add a new one pointing to the local proxy listener.
|
||||
serviceProxyPatches = append(serviceProxyPatches, &networking.EnvoyFilter_EnvoyConfigObjectPatch{
|
||||
ApplyTo: networking.EnvoyFilter_CLUSTER,
|
||||
Match: &networking.EnvoyFilter_EnvoyConfigObjectMatch{
|
||||
Context: networking.EnvoyFilter_GATEWAY,
|
||||
ObjectTypes: &networking.EnvoyFilter_EnvoyConfigObjectMatch_Cluster{
|
||||
Cluster: &networking.EnvoyFilter_ClusterMatch{
|
||||
Name: clusterName,
|
||||
},
|
||||
},
|
||||
},
|
||||
Patch: &networking.EnvoyFilter_Patch{
|
||||
Operation: networking.EnvoyFilter_Patch_REMOVE,
|
||||
},
|
||||
})
|
||||
|
||||
patchObj := map[string]interface{}{
|
||||
"name": clusterName,
|
||||
"type": "STATIC",
|
||||
"connect_timeout": "10s",
|
||||
"load_assignment": map[string]interface{}{
|
||||
"cluster_name": clusterName,
|
||||
"endpoints": []map[string]interface{}{
|
||||
{
|
||||
"lb_endpoints": []map[string]interface{}{
|
||||
{
|
||||
"endpoint": map[string]interface{}{
|
||||
"address": map[string]interface{}{
|
||||
"socket_address": map[string]interface{}{
|
||||
"address": "127.0.0.1",
|
||||
"port_value": proxyWrapper.ListenerPort,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if proxyConfig.UpstreamProtocol.IsHTTPS() {
|
||||
tlsTypedConfig := map[string]interface{}{
|
||||
"@type": "type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext",
|
||||
}
|
||||
if proxyConfig.UpstreamSni != "" {
|
||||
tlsTypedConfig["sni"] = proxyConfig.UpstreamSni
|
||||
}
|
||||
patchObj["transport_socket"] = map[string]interface{}{
|
||||
"name": "envoy.transport_sockets.tls",
|
||||
"typed_config": tlsTypedConfig,
|
||||
}
|
||||
}
|
||||
patchJson, _ := json.Marshal(patchObj)
|
||||
serviceProxyPatches = append(serviceProxyPatches, &networking.EnvoyFilter_EnvoyConfigObjectPatch{
|
||||
ApplyTo: networking.EnvoyFilter_CLUSTER,
|
||||
Match: &networking.EnvoyFilter_EnvoyConfigObjectMatch{
|
||||
Context: networking.EnvoyFilter_GATEWAY,
|
||||
},
|
||||
Patch: &networking.EnvoyFilter_Patch{
|
||||
Operation: networking.EnvoyFilter_Patch_ADD,
|
||||
Value: util.BuildPatchStruct(string(patchJson)),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("can't find ServiceEntry by serviceName:%v", serviceName)
|
||||
if len(serviceProxyPatches) != 0 {
|
||||
envoyFilters = append(envoyFilters, &config.Config{
|
||||
Meta: config.Meta{
|
||||
GroupVersionKind: gvk.EnvoyFilter,
|
||||
Name: common.CreateConvertedName(constants.IstioIngressGatewayName, "service-proxy"),
|
||||
Namespace: namespace,
|
||||
},
|
||||
Spec: &networking.EnvoyFilter{
|
||||
ConfigPatches: serviceProxyPatches,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return envoyFilters
|
||||
}
|
||||
|
||||
func QueryRpcServiceVersion(serviceEntry *memory.ServiceWrapper, serviceName string) (string, error) {
|
||||
IngressLog.Infof("Found http2rpc serviceEntry %s", serviceEntry)
|
||||
IngressLog.Infof("Found http2rpc ServiceEntry %s", serviceEntry.ServiceEntry)
|
||||
IngressLog.Infof("Found http2rpc WorkloadSelector %s", serviceEntry.ServiceEntry.WorkloadSelector)
|
||||
IngressLog.Infof("Found http2rpc Labels %s", serviceEntry.ServiceEntry.WorkloadSelector.Labels)
|
||||
labels := (*serviceEntry).ServiceEntry.WorkloadSelector.Labels
|
||||
for key, value := range labels {
|
||||
if key == "version" {
|
||||
return value, nil
|
||||
}
|
||||
func validateServiceWrapperForProxy(serviceWrapper *common.ServiceWrapper) error {
|
||||
registryType := registry.ServiceRegistryType(serviceWrapper.RegistryType)
|
||||
switch registryType {
|
||||
case registry.DNS:
|
||||
break
|
||||
default:
|
||||
return fmt.Errorf("service %s has proxy config %s, but registry type %s is not supported for proxying", serviceWrapper.ServiceName, serviceWrapper.ProxyConfig.ProxyName, registryType)
|
||||
}
|
||||
return "", fmt.Errorf("can't get RpcServiceVersion for serviceName:%v", serviceName)
|
||||
if len(serviceWrapper.ServiceEntry.Endpoints) > 1 {
|
||||
return fmt.Errorf("service %s has multiple endpoints, which is not supported for proxying with EnvoyFilter. Skipping EnvoyFilter construction", serviceWrapper.ServiceName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *IngressConfig) Run(stop <-chan struct{}) {
|
||||
@@ -1800,6 +1954,99 @@ func (m *IngressConfig) Delete(config.GroupVersionKind, string, string, *string)
|
||||
return common.ErrUnsupportedOp
|
||||
}
|
||||
|
||||
func (m *IngressConfig) constructMcpSseStatefulSessionEnvoyFilter(route *common.WrapperHTTPRoute, namespace string, initGlobalFilter bool) (*config.Config, error) {
|
||||
httpRoute := route.HTTPRoute
|
||||
|
||||
var configPatches []*networking.EnvoyFilter_EnvoyConfigObjectPatch
|
||||
|
||||
// Add global HTTP filter if this is the first route using MCP SSE stateful session
|
||||
if initGlobalFilter {
|
||||
configPatches = append(configPatches, &networking.EnvoyFilter_EnvoyConfigObjectPatch{
|
||||
ApplyTo: networking.EnvoyFilter_HTTP_FILTER,
|
||||
Match: &networking.EnvoyFilter_EnvoyConfigObjectMatch{
|
||||
Context: networking.EnvoyFilter_GATEWAY,
|
||||
ObjectTypes: &networking.EnvoyFilter_EnvoyConfigObjectMatch_Listener{
|
||||
Listener: &networking.EnvoyFilter_ListenerMatch{
|
||||
FilterChain: &networking.EnvoyFilter_ListenerMatch_FilterChainMatch{
|
||||
Filter: &networking.EnvoyFilter_ListenerMatch_FilterMatch{
|
||||
Name: "envoy.filters.network.http_connection_manager",
|
||||
SubFilter: &networking.EnvoyFilter_ListenerMatch_SubFilterMatch{
|
||||
Name: "envoy.filters.http.router",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Patch: &networking.EnvoyFilter_Patch{
|
||||
Operation: networking.EnvoyFilter_Patch_INSERT_BEFORE,
|
||||
Value: buildPatchStruct(`{
|
||||
"name": "envoy.filters.http.mcp_sse_stateful_session",
|
||||
"typed_config": {
|
||||
"@type": "type.googleapis.com/udpa.type.v1.TypedStruct",
|
||||
"type_url": "type.googleapis.com/envoy.extensions.filters.http.mcp_sse_stateful_session.v3alpha.McpSseStatefulSession"
|
||||
}
|
||||
}`),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Add route-specific configuration
|
||||
configPatches = append(configPatches, &networking.EnvoyFilter_EnvoyConfigObjectPatch{
|
||||
ApplyTo: networking.EnvoyFilter_HTTP_ROUTE,
|
||||
Match: &networking.EnvoyFilter_EnvoyConfigObjectMatch{
|
||||
Context: networking.EnvoyFilter_GATEWAY,
|
||||
ObjectTypes: &networking.EnvoyFilter_EnvoyConfigObjectMatch_RouteConfiguration{
|
||||
RouteConfiguration: &networking.EnvoyFilter_RouteConfigurationMatch{
|
||||
Vhost: &networking.EnvoyFilter_RouteConfigurationMatch_VirtualHostMatch{
|
||||
Route: &networking.EnvoyFilter_RouteConfigurationMatch_RouteMatch{
|
||||
Name: httpRoute.Name,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Patch: &networking.EnvoyFilter_Patch{
|
||||
Operation: networking.EnvoyFilter_Patch_MERGE,
|
||||
Value: buildPatchStruct(`{
|
||||
"typed_per_filter_config": {
|
||||
"envoy.filters.http.mcp_sse_stateful_session": {
|
||||
"@type": "type.googleapis.com/udpa.type.v1.TypedStruct",
|
||||
"type_url": "type.googleapis.com/envoy.extensions.filters.http.mcp_sse_stateful_session.v3alpha.McpSseStatefulSessionPerRoute",
|
||||
"value": {
|
||||
"mcp_sse_stateful_session": {
|
||||
"session_state": {
|
||||
"name": "envoy.http.mcp_sse_stateful_session.envelope",
|
||||
"typed_config": {
|
||||
"@type": "type.googleapis.com/udpa.type.v1.TypedStruct",
|
||||
"type_url": "type.googleapis.com/envoy.extensions.http.mcp_sse_stateful_session.envelope.v3alpha.EnvelopeSessionState",
|
||||
"value": {
|
||||
"param_name": "sessionId",
|
||||
"chunk_end_patterns": ["\r\n\r\n", "\n\n", "\r\r"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"strict": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`),
|
||||
},
|
||||
})
|
||||
|
||||
return &config.Config{
|
||||
Meta: config.Meta{
|
||||
GroupVersionKind: gvk.EnvoyFilter,
|
||||
Name: common.CreateConvertedName(constants.IstioIngressGatewayName, "mcp-lb-route", common.ConvertToDNSLabelValid(httpRoute.Name)),
|
||||
Namespace: namespace,
|
||||
},
|
||||
Spec: &networking.EnvoyFilter{
|
||||
ConfigPatches: configPatches,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *IngressConfig) notifyXDSFullUpdate(gvk config.GroupVersionKind, reason istiomodel.TriggerReason, updatedConfigName *util.ClusterNamespacedName) {
|
||||
var configsUpdated map[istiomodel.ConfigKey]struct{}
|
||||
if updatedConfigName != nil {
|
||||
|
||||
@@ -66,9 +66,10 @@ type consistentHashByCookie struct {
|
||||
}
|
||||
|
||||
type LoadBalanceConfig struct {
|
||||
simple networking.LoadBalancerSettings_SimpleLB
|
||||
other *consistentHashByOther
|
||||
cookie *consistentHashByCookie
|
||||
simple networking.LoadBalancerSettings_SimpleLB
|
||||
other *consistentHashByOther
|
||||
cookie *consistentHashByCookie
|
||||
McpSseStateful bool
|
||||
}
|
||||
|
||||
type loadBalance struct{}
|
||||
@@ -129,7 +130,11 @@ func (l loadBalance) Parse(annotations Annotations, config *Ingress, _ *GlobalCo
|
||||
} else {
|
||||
if lb, err := annotations.ParseStringASAP(loadBalanceAnnotation); err == nil {
|
||||
lb = strings.ToUpper(lb)
|
||||
loadBalanceConfig.simple = networking.LoadBalancerSettings_SimpleLB(networking.LoadBalancerSettings_SimpleLB_value[lb])
|
||||
if lb == "MCP-SSE" {
|
||||
loadBalanceConfig.McpSseStateful = true
|
||||
} else {
|
||||
loadBalanceConfig.simple = networking.LoadBalancerSettings_SimpleLB(networking.LoadBalancerSettings_SimpleLB_value[lb])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,9 +16,8 @@ package common
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/pkg/cert"
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/annotations"
|
||||
networking "istio.io/api/networking/v1alpha3"
|
||||
"istio.io/istio/pilot/pkg/model"
|
||||
"istio.io/istio/pkg/cluster"
|
||||
@@ -26,6 +25,10 @@ import (
|
||||
gatewaytool "istio.io/istio/pkg/config/gateway"
|
||||
listerv1 "k8s.io/client-go/listers/core/v1"
|
||||
"k8s.io/client-go/tools/cache"
|
||||
|
||||
"github.com/alibaba/higress/pkg/cert"
|
||||
"github.com/alibaba/higress/pkg/common"
|
||||
"github.com/alibaba/higress/pkg/ingress/kube/annotations"
|
||||
)
|
||||
|
||||
type ServiceKey struct {
|
||||
@@ -120,6 +123,68 @@ type WrapperDestinationRule struct {
|
||||
ServiceKey ServiceKey
|
||||
}
|
||||
|
||||
type ServiceProxyConfig struct {
|
||||
ProxyName string
|
||||
UpstreamProtocol common.Protocol
|
||||
UpstreamSni string
|
||||
}
|
||||
|
||||
type ServiceWrapper struct {
|
||||
ServiceName string
|
||||
ServiceEntry *networking.ServiceEntry
|
||||
DestinationRuleWrapper *WrapperDestinationRule
|
||||
Suffix string
|
||||
RegistryType string
|
||||
RegistryName string
|
||||
ProxyConfig *ServiceProxyConfig
|
||||
createTime time.Time
|
||||
}
|
||||
|
||||
func (sew *ServiceWrapper) DeepCopy() *ServiceWrapper {
|
||||
res := &ServiceWrapper{}
|
||||
*res = *sew
|
||||
res.ServiceEntry = sew.ServiceEntry.DeepCopy()
|
||||
|
||||
if sew.DestinationRuleWrapper != nil {
|
||||
res.DestinationRuleWrapper = sew.DestinationRuleWrapper
|
||||
res.DestinationRuleWrapper.DestinationRule = sew.DestinationRuleWrapper.DestinationRule.DeepCopy()
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (sew *ServiceWrapper) SetCreateTime(createTime time.Time) {
|
||||
sew.createTime = createTime
|
||||
}
|
||||
|
||||
func (sew *ServiceWrapper) GetCreateTime() time.Time {
|
||||
return sew.createTime
|
||||
}
|
||||
|
||||
type ProxyWrapper struct {
|
||||
ProxyName string
|
||||
ListenerPort uint32
|
||||
EnvoyFilter *networking.EnvoyFilter
|
||||
createTime time.Time
|
||||
}
|
||||
|
||||
func (pw *ProxyWrapper) DeepCopy() *ProxyWrapper {
|
||||
res := &ProxyWrapper{}
|
||||
*res = *pw
|
||||
|
||||
if pw.EnvoyFilter != nil {
|
||||
res.EnvoyFilter = pw.EnvoyFilter.DeepCopy()
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (pw *ProxyWrapper) SetCreateTime(createTime time.Time) {
|
||||
pw.createTime = createTime
|
||||
}
|
||||
|
||||
func (pw *ProxyWrapper) GetCreateTime() time.Time {
|
||||
return pw.createTime
|
||||
}
|
||||
|
||||
type IngressController interface {
|
||||
// RegisterEventHandler adds a handler to receive config update events for a
|
||||
// configuration type
|
||||
|
||||
@@ -169,6 +169,10 @@ type ConvertOptions struct {
|
||||
|
||||
Service2TrafficPolicy map[ServiceKey]*WrapperTrafficPolicy
|
||||
|
||||
ServiceWrappers map[string]*ServiceWrapper
|
||||
|
||||
ProxyWrappers map[string]*ProxyWrapper
|
||||
|
||||
HasDefaultBackend bool
|
||||
}
|
||||
|
||||
|
||||
@@ -146,7 +146,7 @@ func GetHost(annotations map[string]string) string {
|
||||
|
||||
// Istio requires that the name of the gateway must conform to the DNS label.
|
||||
// For details, you can view: https://github.com/istio/istio/blob/2d5c40ad5e9cceebe64106005aa38381097da2ba/pkg/config/validation/validation.go#L478
|
||||
func convertToDNSLabelValid(input string) string {
|
||||
func ConvertToDNSLabelValid(input string) string {
|
||||
hasher := md5.New()
|
||||
hasher.Write([]byte(input))
|
||||
hash := hasher.Sum(nil)
|
||||
@@ -156,7 +156,7 @@ func convertToDNSLabelValid(input string) string {
|
||||
|
||||
// CleanHost follow the format of mse-ops for host.
|
||||
func CleanHost(host string) string {
|
||||
return convertToDNSLabelValid(host)
|
||||
return ConvertToDNSLabelValid(host)
|
||||
}
|
||||
|
||||
func CreateConvertedName(items ...string) string {
|
||||
|
||||
@@ -158,7 +158,7 @@ func (c *ConfigmapMgr) AddOrUpdateHigressConfig(name util.ClusterNamespacedName)
|
||||
IngressLog.Infof("configmapMgr oldHigressConfig: %s", GetHigressConfigString(oldHigressConfig))
|
||||
IngressLog.Infof("configmapMgr newHigressConfig: %s", GetHigressConfigString(newHigressConfig))
|
||||
result, _ := c.CompareHigressConfig(oldHigressConfig, newHigressConfig)
|
||||
IngressLog.Infof("configmapMgr CompareHigressConfig reuslt is %d", result)
|
||||
IngressLog.Infof("configmapMgr CompareHigressConfig result is %d", result)
|
||||
|
||||
if result == ResultNothing {
|
||||
return
|
||||
@@ -177,7 +177,7 @@ func (c *ConfigmapMgr) AddOrUpdateHigressConfig(name util.ClusterNamespacedName)
|
||||
}
|
||||
}
|
||||
c.SetHigressConfig(newHigressConfig)
|
||||
IngressLog.Infof("configmapMgr higress config AddOrUpdate success, reuslt is %d", result)
|
||||
IngressLog.Infof("configmapMgr higress config AddOrUpdate success, result is %d", result)
|
||||
// Call updateConfig
|
||||
}
|
||||
|
||||
|
||||
@@ -509,6 +509,11 @@ func (m *McpServerController) constructMcpSessionStruct(mcp *McpServer) string {
|
||||
}
|
||||
|
||||
func (m *McpServerController) constructMcpServerStruct(mcp *McpServer) string {
|
||||
// if no servers, return empty string
|
||||
if mcp == nil || len(mcp.Servers) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Build servers configuration
|
||||
servers := "[]"
|
||||
if len(mcp.Servers) > 0 {
|
||||
|
||||
@@ -566,7 +566,7 @@ func TestMcpServerController_ConstructEnvoyFilters(t *testing.T) {
|
||||
MatchList: []*MatchRule{},
|
||||
Servers: []*SSEServer{},
|
||||
},
|
||||
wantConfigs: 2, // Both session and server filters
|
||||
wantConfigs: 1, // Only session filter when no servers configured
|
||||
wantErr: nil,
|
||||
},
|
||||
}
|
||||
@@ -744,24 +744,7 @@ func TestMcpServerController_constructMcpServerStruct(t *testing.T) {
|
||||
mcp: &McpServer{
|
||||
Servers: []*SSEServer{},
|
||||
},
|
||||
wantJSON: `{
|
||||
"name": "envoy.filters.http.golang",
|
||||
"typed_config": {
|
||||
"@type": "type.googleapis.com/udpa.type.v1.TypedStruct",
|
||||
"type_url": "type.googleapis.com/envoy.extensions.filters.http.golang.v3alpha.Config",
|
||||
"value": {
|
||||
"library_id": "mcp-server",
|
||||
"library_path": "/var/lib/istio/envoy/golang-filter.so",
|
||||
"plugin_name": "mcp-server",
|
||||
"plugin_config": {
|
||||
"@type": "type.googleapis.com/xds.type.v3.TypedStruct",
|
||||
"value": {
|
||||
"servers": []
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
wantJSON: "", // Return empty string when no servers configured
|
||||
},
|
||||
{
|
||||
name: "with servers",
|
||||
|
||||
@@ -286,7 +286,7 @@ func testConvertHTTPRoute(t *testing.T, c common.KIngressController) {
|
||||
expectNoError: true,
|
||||
},
|
||||
{
|
||||
description: "valid httpRoute convention, vaild ingress",
|
||||
description: "valid httpRoute convention, valid ingress",
|
||||
input: struct {
|
||||
options *common.ConvertOptions
|
||||
wrapperConfig *common.WrapperConfig
|
||||
|
||||
@@ -57,12 +57,12 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
|
||||
}
|
||||
|
||||
func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
|
||||
if !endStream {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
if f.message {
|
||||
for _, server := range f.config.servers {
|
||||
if f.path == server.BaseServer.GetMessageEndpoint() {
|
||||
if !endStream {
|
||||
return api.StopAndBuffer
|
||||
}
|
||||
// Create a response recorder to capture the response
|
||||
recorder := httptest.NewRecorder()
|
||||
// Call the handleMessage method of SSEServer with complete body
|
||||
|
||||
@@ -77,7 +77,7 @@ func (n *NacosMcpRegistry) refreshToolsListForGroup(group string, serviceMatcher
|
||||
serviceList := services.Doms
|
||||
pattern, err := regexp.Compile(serviceMatcher)
|
||||
if err != nil {
|
||||
api.LogErrorf("Match service error for patter %s", serviceMatcher)
|
||||
api.LogErrorf("Match service error for pattern %s", serviceMatcher)
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -149,6 +149,9 @@ func (f *filter) processMcpRequestHeadersForRestUpstream(header api.RequestHeade
|
||||
func (f *filter) processMcpRequestHeadersForSSEUpstream(header api.RequestHeaderMap, endStream bool) api.StatusType {
|
||||
// We don't need to process the request body for SSE upstream.
|
||||
f.skipRequestBody = true
|
||||
// Remove Accept-Encoding header to avoid gzip encoding,
|
||||
// which our response body handling logic doesn't support.
|
||||
header.Del("Accept-Encoding")
|
||||
return api.Continue
|
||||
}
|
||||
|
||||
|
||||
@@ -14,5 +14,5 @@ export {SetCtx,
|
||||
ProcessRequestHeadersBy,
|
||||
ProcessResponseBodyBy,
|
||||
ProcessResponseHeadersBy,
|
||||
Logger, RegisteTickFunc} from "./plugin_wrapper"
|
||||
Logger, RegisterTickFunc} from "./plugin_wrapper"
|
||||
export {ParseResult} from "./rule_matcher"
|
||||
@@ -156,7 +156,7 @@ class TickFuncEntry {
|
||||
|
||||
var globalOnTickFuncs = new Array<TickFuncEntry>();
|
||||
|
||||
export function RegisteTickFunc(tickPeriod: i64, tickFunc: () => void): void {
|
||||
export function RegisterTickFunc(tickPeriod: i64, tickFunc: () => void): void {
|
||||
globalOnTickFuncs.push(new TickFuncEntry(0, tickPeriod, tickFunc));
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
export * from "@higress/proxy-wasm-assemblyscript-sdk/assembly/proxy";
|
||||
import { SetCtx, HttpContext, ProcessRequestHeadersBy, Logger, ParseResult, ParseConfigBy, RegisteTickFunc, ProcessResponseHeadersBy } from "@higress/wasm-assemblyscript/assembly";
|
||||
import { SetCtx, HttpContext, ProcessRequestHeadersBy, Logger, ParseResult, ParseConfigBy, RegisterTickFunc, ProcessResponseHeadersBy } from "@higress/wasm-assemblyscript/assembly";
|
||||
import { FilterHeadersStatusValues, send_http_response, stream_context } from "@higress/proxy-wasm-assemblyscript-sdk/assembly"
|
||||
import { JSON } from "assemblyscript-json/assembly";
|
||||
class HelloWorldConfig {
|
||||
@@ -12,10 +12,10 @@ SetCtx<HelloWorldConfig>("hello-world",
|
||||
])
|
||||
|
||||
function parseConfig(json: JSON.Obj): ParseResult<HelloWorldConfig> {
|
||||
RegisteTickFunc(2000, () => {
|
||||
RegisterTickFunc(2000, () => {
|
||||
Logger.Debug("tick 2s");
|
||||
})
|
||||
RegisteTickFunc(5000, () => {
|
||||
RegisterTickFunc(5000, () => {
|
||||
Logger.Debug("tick 5s");
|
||||
})
|
||||
return new ParseResult<HelloWorldConfig>(new HelloWorldConfig(), true);
|
||||
|
||||
@@ -243,7 +243,7 @@ class RouteRuleMatcher {
|
||||
std::string route_name;
|
||||
getValue({"route_name"}, &route_name);
|
||||
std::string service_name;
|
||||
getValue({"service_name"}, &service_name);
|
||||
getValue({"cluster_name"}, &service_name);
|
||||
std::optional<std::reference_wrapper<PluginConfig>> match_config;
|
||||
std::optional<std::reference_wrapper<std::unordered_set<std::string>>>
|
||||
allow_set;
|
||||
|
||||
@@ -15,7 +15,7 @@ WORKDIR /workspace/extensions/$PLUGIN_NAME
|
||||
|
||||
RUN go mod tidy
|
||||
RUN \
|
||||
GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o /main.wasm ./
|
||||
GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o /main.wasm .
|
||||
|
||||
FROM scratch AS output
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ builder:
|
||||
@echo "image: ${BUILDER}"
|
||||
|
||||
local-build:
|
||||
cd extensions/${PLUGIN_NAME};GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o ./main.wasm ./
|
||||
cd extensions/${PLUGIN_NAME};GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o ./main.wasm .
|
||||
|
||||
@echo ""
|
||||
@echo "wasm: extensions/${PLUGIN_NAME}/main.wasm"
|
||||
|
||||
@@ -148,6 +148,49 @@ spec:
|
||||
|
||||
所有规则会按上面配置的顺序一次执行匹配,当有一个规则匹配时,就停止匹配,并选择匹配的配置执行插件逻辑。
|
||||
|
||||
## 单元测试
|
||||
|
||||
在开发wasm插件时,建议同时编写单元测试来验证插件功能。详细的单元测试编写指南请参考 [wasm plugin unit test](https://github.com/higress-group/wasm-go/blob/main/pkg/test/README.md)。
|
||||
|
||||
### 单元测试样例
|
||||
|
||||
```go
|
||||
func TestMyPlugin(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 1. 创建测试主机
|
||||
config := json.RawMessage(`{"key": "value"}`)
|
||||
host, status := test.NewTestHost(config)
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
defer host.Reset()
|
||||
|
||||
// 2. 设置请求头
|
||||
headers := [][2]string{
|
||||
{":method", "GET"},
|
||||
{":path", "/test"},
|
||||
{":authority", "test.com"},
|
||||
}
|
||||
|
||||
// 3. 调用插件请求头处理方法
|
||||
action := host.CallOnHttpRequestHeaders(headers)
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// 4. 模拟外部调用响应(如果需要)
|
||||
|
||||
// host.CallOnRedisCall(0, test.CreateRedisRespString("OK"))
|
||||
|
||||
// host.CallOnHttpCall([][2]string{{":status", "200"}}, []byte(`{"result": "success"}`))
|
||||
|
||||
// 5. 完成请求
|
||||
host.CompleteHttp()
|
||||
|
||||
// 6. 验证结果(如果插件里返回了响应)
|
||||
localResponse := host.GetLocalResponse()
|
||||
require.NotNil(t, localResponse)
|
||||
assert.Equal(t, uint32(200), localResponse.StatusCode)
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
## E2E测试
|
||||
|
||||
当你完成一个GO语言的插件功能时, 可以同时创建关联的e2e test cases, 并在本地对插件功能完成测试验证。
|
||||
|
||||
@@ -139,6 +139,52 @@ spec:
|
||||
The rules will be matched in the order of configuration. If one match is found, it will stop, and the matching configuration will take effect.
|
||||
|
||||
|
||||
## Unit Testing
|
||||
|
||||
When developing wasm plugins, it's recommended to write unit tests to verify plugin functionality. For detailed unit testing guidelines, please refer to [wasm plugin unit test](https://github.com/higress-group/wasm-go/blob/main/pkg/test/README.md).
|
||||
|
||||
### Unit Test Structure Example
|
||||
|
||||
```go
|
||||
func TestMyPlugin(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 1. Create test host
|
||||
config := json.RawMessage(`{"key": "value"}`)
|
||||
host, status := test.NewTestHost(config)
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
defer host.Reset()
|
||||
|
||||
// 2. Set request headers
|
||||
headers := [][2]string{
|
||||
{":method", "GET"},
|
||||
{":path", "/test"},
|
||||
{":authority", "test.com"},
|
||||
}
|
||||
|
||||
// 3. Call plugin request header processing method
|
||||
action := host.CallOnHttpRequestHeaders(headers)
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// 4. Simulate external call responses (if needed)
|
||||
|
||||
// host.CallOnRedisCall(0, test.CreateRedisRespString("OK"))
|
||||
|
||||
// host.CallOnHttpCall([][2]string{{":status", "200"}}, []byte(`{"result": "success"}`))
|
||||
|
||||
// 5. Complete request
|
||||
host.CompleteHttp()
|
||||
|
||||
// 6. Verify results (if the plugin returns a response)
|
||||
localResponse := host.GetLocalResponse()
|
||||
require.NotNil(t, localResponse)
|
||||
assert.Equal(t, uint32(200), localResponse.StatusCode)
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
This example shows the basic test structure including configuration parsing, request processing flow, and result verification.
|
||||
|
||||
|
||||
## E2E test
|
||||
|
||||
When you complete a GO plug-in function, you can create associated e2e test cases at the same time, and complete the test verification of the plug-in function locally.
|
||||
|
||||
@@ -5,15 +5,21 @@ go 1.24.1
|
||||
toolchain go1.24.4
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
|
||||
github.com/higress-group/wasm-go v0.0.0-20250628101008-bea7da01a545
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.7.2 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -2,14 +2,17 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v0.0.0-20250628101008-bea7da01a545 h1:qb/Rhhfm1gzr/stim/L0cKNo0MPatdo0Rd8iYOAPWE0=
|
||||
github.com/higress-group/wasm-go v0.0.0-20250628101008-bea7da01a545/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
|
||||
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
@@ -19,6 +22,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
|
||||
1835
plugins/wasm-go/extensions/ai-agent/main_test.go
Normal file
1835
plugins/wasm-go/extensions/ai-agent/main_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -8,14 +8,20 @@ toolchain go1.24.4
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
|
||||
github.com/higress-group/wasm-go v1.0.0
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/resp v0.1.1
|
||||
// github.com/weaviate/weaviate-go-client/v4 v4.15.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.7.2 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -2,14 +2,17 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw=
|
||||
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
|
||||
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
@@ -19,5 +22,9 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
1195
plugins/wasm-go/extensions/ai-cache/main_test.go
Normal file
1195
plugins/wasm-go/extensions/ai-cache/main_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -7,14 +7,20 @@ go 1.24.1
|
||||
toolchain go1.24.4
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
|
||||
github.com/higress-group/wasm-go v1.0.0
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/resp v0.1.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.7.2 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -2,14 +2,17 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw=
|
||||
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
|
||||
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
@@ -19,5 +22,9 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -1,10 +1,129 @@
|
||||
// Copyright (c) 2024 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// 测试配置:基本Redis配置
|
||||
var basicRedisConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"redis": map[string]interface{}{
|
||||
"serviceName": "redis.static",
|
||||
"servicePort": 6379,
|
||||
"timeout": 1000,
|
||||
"database": 0,
|
||||
},
|
||||
"questionFrom": map[string]interface{}{
|
||||
"requestBody": "messages.@reverse.0.content",
|
||||
},
|
||||
"answerValueFrom": map[string]interface{}{
|
||||
"responseBody": "choices.0.message.content",
|
||||
},
|
||||
"answerStreamValueFrom": map[string]interface{}{
|
||||
"responseBody": "choices.0.delta.content",
|
||||
},
|
||||
"cacheKeyPrefix": "higress-ai-history:",
|
||||
"identityHeader": "Authorization",
|
||||
"fillHistoryCnt": 3,
|
||||
"cacheTTL": 3600,
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:最小Redis配置(使用默认值)
|
||||
var minimalRedisConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"redis": map[string]interface{}{
|
||||
"serviceName": "redis.static",
|
||||
},
|
||||
"questionFrom": map[string]interface{}{
|
||||
"requestBody": "messages.@reverse.0.content",
|
||||
},
|
||||
"answerValueFrom": map[string]interface{}{
|
||||
"responseBody": "choices.0.message.content",
|
||||
},
|
||||
"answerStreamValueFrom": map[string]interface{}{
|
||||
"responseBody": "choices.0.delta.content",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:自定义Redis配置
|
||||
var customRedisConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"redis": map[string]interface{}{
|
||||
"serviceName": "custom-redis.dns",
|
||||
"servicePort": 6380,
|
||||
"username": "admin",
|
||||
"password": "password123",
|
||||
"timeout": 2000,
|
||||
"database": 1,
|
||||
},
|
||||
"questionFrom": map[string]interface{}{
|
||||
"requestBody": "query.text",
|
||||
},
|
||||
"answerValueFrom": map[string]interface{}{
|
||||
"responseBody": "response.content",
|
||||
},
|
||||
"answerStreamValueFrom": map[string]interface{}{
|
||||
"responseBody": "response.delta.content",
|
||||
},
|
||||
"cacheKeyPrefix": "custom-history:",
|
||||
"identityHeader": "X-User-ID",
|
||||
"fillHistoryCnt": 5,
|
||||
"cacheTTL": 7200,
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:带认证的Redis配置
|
||||
var authRedisConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"redis": map[string]interface{}{
|
||||
"serviceName": "auth-redis.static",
|
||||
"servicePort": 6379,
|
||||
"username": "user",
|
||||
"password": "pass",
|
||||
"timeout": 1500,
|
||||
"database": 2,
|
||||
},
|
||||
"questionFrom": map[string]interface{}{
|
||||
"requestBody": "messages.@reverse.0.content",
|
||||
},
|
||||
"answerValueFrom": map[string]interface{}{
|
||||
"responseBody": "choices.0.message.content",
|
||||
},
|
||||
"answerStreamValueFrom": map[string]interface{}{
|
||||
"responseBody": "choices.0.delta.content",
|
||||
},
|
||||
"cacheKeyPrefix": "auth-history:",
|
||||
"identityHeader": "X-Auth-Token",
|
||||
"fillHistoryCnt": 4,
|
||||
"cacheTTL": 1800,
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
func TestDistinctChat(t *testing.T) {
|
||||
type args struct {
|
||||
chat []ChatHistory
|
||||
@@ -34,3 +153,627 @@ func TestDistinctChat(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
// 测试基本Redis配置解析
|
||||
t.Run("basic redis config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicRedisConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
// 类型断言
|
||||
pluginConfig, ok := config.(*PluginConfig)
|
||||
require.True(t, ok, "config should be *PluginConfig")
|
||||
|
||||
// 验证Redis配置字段
|
||||
require.Equal(t, "redis.static", pluginConfig.RedisInfo.ServiceName)
|
||||
require.Equal(t, 6379, pluginConfig.RedisInfo.ServicePort)
|
||||
require.Equal(t, 1000, pluginConfig.RedisInfo.Timeout)
|
||||
require.Equal(t, 0, pluginConfig.RedisInfo.Database)
|
||||
require.Equal(t, "", pluginConfig.RedisInfo.Username)
|
||||
require.Equal(t, "", pluginConfig.RedisInfo.Password)
|
||||
|
||||
// 验证问题提取配置
|
||||
require.Equal(t, "messages.@reverse.0.content", pluginConfig.QuestionFrom.RequestBody)
|
||||
require.Equal(t, "", pluginConfig.QuestionFrom.ResponseBody)
|
||||
|
||||
// 验证答案提取配置
|
||||
require.Equal(t, "", pluginConfig.AnswerValueFrom.RequestBody)
|
||||
require.Equal(t, "choices.0.message.content", pluginConfig.AnswerValueFrom.ResponseBody)
|
||||
|
||||
// 验证流式答案提取配置
|
||||
require.Equal(t, "", pluginConfig.AnswerStreamValueFrom.RequestBody)
|
||||
require.Equal(t, "choices.0.delta.content", pluginConfig.AnswerStreamValueFrom.ResponseBody)
|
||||
|
||||
// 验证其他配置字段
|
||||
require.Equal(t, "higress-ai-history:", pluginConfig.CacheKeyPrefix)
|
||||
require.Equal(t, "Authorization", pluginConfig.IdentityHeader)
|
||||
require.Equal(t, 3, pluginConfig.FillHistoryCnt)
|
||||
require.Equal(t, 3600, pluginConfig.CacheTTL)
|
||||
})
|
||||
|
||||
// 测试最小Redis配置解析(使用默认值)
|
||||
t.Run("minimal redis config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(minimalRedisConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
// 类型断言
|
||||
pluginConfig, ok := config.(*PluginConfig)
|
||||
require.True(t, ok, "config should be *PluginConfig")
|
||||
|
||||
// 验证Redis配置字段(使用默认值)
|
||||
require.Equal(t, "redis.static", pluginConfig.RedisInfo.ServiceName)
|
||||
require.Equal(t, 80, pluginConfig.RedisInfo.ServicePort) // 对于.static服务,默认端口是80
|
||||
require.Equal(t, 1000, pluginConfig.RedisInfo.Timeout) // 默认超时
|
||||
require.Equal(t, 0, pluginConfig.RedisInfo.Database) // 默认数据库
|
||||
require.Equal(t, "", pluginConfig.RedisInfo.Username)
|
||||
require.Equal(t, "", pluginConfig.RedisInfo.Password)
|
||||
|
||||
// 验证问题提取配置(使用默认值)
|
||||
require.Equal(t, "messages.@reverse.0.content", pluginConfig.QuestionFrom.RequestBody)
|
||||
require.Equal(t, "", pluginConfig.QuestionFrom.ResponseBody)
|
||||
|
||||
// 验证答案提取配置(使用默认值)
|
||||
require.Equal(t, "", pluginConfig.AnswerValueFrom.RequestBody)
|
||||
require.Equal(t, "choices.0.message.content", pluginConfig.AnswerValueFrom.ResponseBody)
|
||||
|
||||
// 验证流式答案提取配置(使用默认值)
|
||||
require.Equal(t, "", pluginConfig.AnswerStreamValueFrom.RequestBody)
|
||||
require.Equal(t, "choices.0.delta.content", pluginConfig.AnswerStreamValueFrom.ResponseBody)
|
||||
|
||||
// 验证其他配置字段(使用默认值)
|
||||
require.Equal(t, "higress-ai-history:", pluginConfig.CacheKeyPrefix)
|
||||
require.Equal(t, "Authorization", pluginConfig.IdentityHeader)
|
||||
require.Equal(t, 3, pluginConfig.FillHistoryCnt)
|
||||
require.Equal(t, 0, pluginConfig.CacheTTL) // 默认永不过期
|
||||
})
|
||||
|
||||
// 测试自定义Redis配置解析
|
||||
t.Run("custom redis config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(customRedisConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
// 类型断言
|
||||
pluginConfig, ok := config.(*PluginConfig)
|
||||
require.True(t, ok, "config should be *PluginConfig")
|
||||
|
||||
// 验证Redis配置字段
|
||||
require.Equal(t, "custom-redis.dns", pluginConfig.RedisInfo.ServiceName)
|
||||
require.Equal(t, 6380, pluginConfig.RedisInfo.ServicePort)
|
||||
require.Equal(t, 2000, pluginConfig.RedisInfo.Timeout)
|
||||
require.Equal(t, 1, pluginConfig.RedisInfo.Database)
|
||||
require.Equal(t, "admin", pluginConfig.RedisInfo.Username)
|
||||
require.Equal(t, "password123", pluginConfig.RedisInfo.Password)
|
||||
|
||||
// 验证问题提取配置(插件硬编码,不从配置读取)
|
||||
require.Equal(t, "messages.@reverse.0.content", pluginConfig.QuestionFrom.RequestBody)
|
||||
require.Equal(t, "", pluginConfig.QuestionFrom.ResponseBody)
|
||||
|
||||
// 验证答案提取配置(插件硬编码,不从配置读取)
|
||||
require.Equal(t, "", pluginConfig.AnswerValueFrom.RequestBody)
|
||||
require.Equal(t, "choices.0.message.content", pluginConfig.AnswerValueFrom.ResponseBody)
|
||||
|
||||
// 验证流式答案提取配置(插件硬编码,不从配置读取)
|
||||
require.Equal(t, "", pluginConfig.AnswerStreamValueFrom.RequestBody)
|
||||
require.Equal(t, "choices.0.delta.content", pluginConfig.AnswerStreamValueFrom.ResponseBody)
|
||||
|
||||
// 验证其他配置字段
|
||||
require.Equal(t, "custom-history:", pluginConfig.CacheKeyPrefix)
|
||||
require.Equal(t, "X-User-ID", pluginConfig.IdentityHeader)
|
||||
require.Equal(t, 5, pluginConfig.FillHistoryCnt)
|
||||
require.Equal(t, 7200, pluginConfig.CacheTTL)
|
||||
})
|
||||
|
||||
// 测试带认证的Redis配置解析
|
||||
t.Run("auth redis config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(authRedisConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
// 类型断言
|
||||
pluginConfig, ok := config.(*PluginConfig)
|
||||
require.True(t, ok, "config should be *PluginConfig")
|
||||
|
||||
// 验证Redis配置字段
|
||||
require.Equal(t, "auth-redis.static", pluginConfig.RedisInfo.ServiceName)
|
||||
require.Equal(t, 6379, pluginConfig.RedisInfo.ServicePort)
|
||||
require.Equal(t, 1500, pluginConfig.RedisInfo.Timeout)
|
||||
require.Equal(t, 2, pluginConfig.RedisInfo.Database)
|
||||
require.Equal(t, "user", pluginConfig.RedisInfo.Username)
|
||||
require.Equal(t, "pass", pluginConfig.RedisInfo.Password)
|
||||
|
||||
// 验证问题提取配置
|
||||
require.Equal(t, "messages.@reverse.0.content", pluginConfig.QuestionFrom.RequestBody)
|
||||
require.Equal(t, "", pluginConfig.QuestionFrom.ResponseBody)
|
||||
|
||||
// 验证答案提取配置
|
||||
require.Equal(t, "", pluginConfig.AnswerValueFrom.RequestBody)
|
||||
require.Equal(t, "choices.0.message.content", pluginConfig.AnswerValueFrom.ResponseBody)
|
||||
|
||||
// 验证流式答案提取配置
|
||||
require.Equal(t, "", pluginConfig.AnswerStreamValueFrom.RequestBody)
|
||||
require.Equal(t, "choices.0.delta.content", pluginConfig.AnswerStreamValueFrom.ResponseBody)
|
||||
|
||||
// 验证其他配置字段
|
||||
require.Equal(t, "auth-history:", pluginConfig.CacheKeyPrefix)
|
||||
require.Equal(t, "X-Auth-Token", pluginConfig.IdentityHeader)
|
||||
require.Equal(t, 4, pluginConfig.FillHistoryCnt)
|
||||
require.Equal(t, 1800, pluginConfig.CacheTTL)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpRequestHeaders(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试JSON内容类型的请求头处理
|
||||
t.Run("JSON content type headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicRedisConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置JSON内容类型的请求头,包含身份标识
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
{"authorization", "Bearer user123"},
|
||||
})
|
||||
|
||||
// 应该返回HeaderStopIteration,因为需要读取请求体
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
})
|
||||
|
||||
// 测试非JSON内容类型的请求头处理
|
||||
t.Run("non-JSON content type headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicRedisConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置非JSON内容类型的请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "text/plain"},
|
||||
{"authorization", "Bearer user123"},
|
||||
})
|
||||
|
||||
// 应该返回ActionContinue,但不会读取请求体
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
|
||||
// 测试缺少身份标识的请求头处理
|
||||
t.Run("missing identity header", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicRedisConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置缺少身份标识的请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
// 应该返回ActionContinue,因为缺少身份标识
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
|
||||
// 测试自定义身份标识头
|
||||
t.Run("custom identity header", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(customRedisConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置自定义身份标识头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
{"x-user-id", "user456"},
|
||||
})
|
||||
|
||||
// 应该返回HeaderStopIteration
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpRequestBody(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试缓存命中的请求体处理
|
||||
t.Run("cache hit request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicRedisConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
{"authorization", "Bearer user123"},
|
||||
})
|
||||
|
||||
// 构造请求体
|
||||
requestBody := `{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好,请介绍一下自己"
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
// 调用请求体处理
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 应该返回ActionPause,因为需要等待Redis响应
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// 模拟Redis缓存命中响应
|
||||
cacheResponse := `[{"role":"user","content":"之前的问题"},{"role":"assistant","content":"之前的回答"}]`
|
||||
resp := test.CreateRedisRespString(cacheResponse)
|
||||
host.CallOnRedisCall(0, resp)
|
||||
|
||||
// 完成HTTP请求
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
// 测试缓存未命中的请求体处理
|
||||
t.Run("cache miss request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicRedisConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
{"authorization", "Bearer user123"},
|
||||
})
|
||||
|
||||
// 构造请求体
|
||||
requestBody := `{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "今天天气怎么样?"
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
// 调用请求体处理
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 应该返回ActionPause,因为需要等待Redis响应
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// 模拟Redis缓存未命中响应
|
||||
resp := test.CreateRedisRespNull()
|
||||
host.CallOnRedisCall(0, resp)
|
||||
|
||||
// 完成HTTP请求
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
// 测试流式请求的请求体处理
|
||||
t.Run("streaming request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicRedisConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
{"authorization", "Bearer user123"},
|
||||
})
|
||||
|
||||
// 构造流式请求体
|
||||
requestBody := `{
|
||||
"stream": true,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "请用流式方式回答"
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
// 调用请求体处理
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 应该返回ActionPause,因为需要等待Redis响应
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// 模拟Redis缓存未命中响应
|
||||
resp := test.CreateRedisRespNull()
|
||||
host.CallOnRedisCall(0, resp)
|
||||
|
||||
// 完成HTTP请求
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
// 测试查询历史请求的请求体处理
|
||||
t.Run("query history request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicRedisConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/ai-history/query?cnt=2"},
|
||||
{":method", "GET"},
|
||||
{"content-type", "application/json"},
|
||||
{"authorization", "Bearer user123"},
|
||||
})
|
||||
|
||||
// 构造请求体(需要包含messages字段,因为插件会尝试提取问题)
|
||||
requestBody := `{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "查询历史"
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
// 调用请求体处理
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 应该返回ActionPause,因为需要等待Redis响应
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// 模拟Redis缓存命中响应
|
||||
cacheResponse := `[{"role":"user","content":"问题1"},{"role":"assistant","content":"回答1"},{"role":"user","content":"问题2"},{"role":"assistant","content":"回答2"}]`
|
||||
resp := test.CreateRedisRespString(cacheResponse)
|
||||
host.CallOnRedisCall(0, resp)
|
||||
|
||||
// 完成HTTP请求
|
||||
host.CompleteHttp()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpResponseHeaders(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试流式响应头处理
|
||||
t.Run("streaming response headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicRedisConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 必须先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
{"authorization", "Bearer user123"},
|
||||
})
|
||||
|
||||
// 设置流式响应头
|
||||
action := host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"content-type", "text/event-stream"},
|
||||
})
|
||||
|
||||
// 应该返回ActionContinue
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
|
||||
// 测试非流式响应头处理
|
||||
t.Run("non-streaming response headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicRedisConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 必须先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
{"authorization", "Bearer user123"},
|
||||
})
|
||||
|
||||
// 设置非流式响应头
|
||||
action := host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
// 应该返回ActionContinue
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpStreamResponseBody(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试流式响应体处理 - 非流式模式
|
||||
t.Run("non-streaming mode", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicRedisConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
{"authorization", "Bearer user123"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "测试问题"
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
// 调用请求体处理,设置必要的上下文
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 模拟Redis缓存未命中,设置QuestionContextKey
|
||||
resp := test.CreateRedisRespNull()
|
||||
host.CallOnRedisCall(0, resp)
|
||||
|
||||
// 测试非流式响应体处理
|
||||
chunk := []byte(`{"choices":[{"message":{"content":"测试回答"}}]}`)
|
||||
action := host.CallOnHttpStreamingResponseBody(chunk, true)
|
||||
|
||||
// 应该返回ActionContinue
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
|
||||
// 测试流式响应体处理 - 流式模式
|
||||
t.Run("streaming mode", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicRedisConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
{"authorization", "Bearer user123"},
|
||||
})
|
||||
|
||||
// 设置流式请求体
|
||||
requestBody := `{
|
||||
"stream": true,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "测试流式问题"
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
// 调用请求体处理,设置必要的上下文
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 模拟Redis缓存未命中,设置QuestionContextKey
|
||||
resp := test.CreateRedisRespNull()
|
||||
host.CallOnRedisCall(0, resp)
|
||||
|
||||
// 设置流式响应头
|
||||
host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"content-type", "text/event-stream"},
|
||||
})
|
||||
|
||||
// 测试流式响应体处理 - 非最后一个chunk
|
||||
chunk1 := []byte("data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n")
|
||||
action1 := host.CallOnHttpStreamingResponseBody(chunk1, false)
|
||||
require.Equal(t, types.ActionContinue, action1)
|
||||
|
||||
// 测试流式响应体处理 - 最后一个chunk
|
||||
chunk2 := []byte("data: {\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n")
|
||||
action2 := host.CallOnHttpStreamingResponseBody(chunk2, true)
|
||||
require.Equal(t, types.ActionContinue, action2)
|
||||
})
|
||||
|
||||
// 测试查询历史路径的流式响应体处理
|
||||
t.Run("query history path", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicRedisConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/ai-history/query?cnt=2"},
|
||||
{":method", "GET"},
|
||||
{"content-type", "application/json"},
|
||||
{"authorization", "Bearer user123"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "查询历史"
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
// 调用请求体处理,设置必要的上下文
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 模拟Redis缓存命中,设置QuestionContextKey
|
||||
cacheResponse := `[{"role":"user","content":"问题1"},{"role":"assistant","content":"回答1"}]`
|
||||
resp := test.CreateRedisRespString(cacheResponse)
|
||||
host.CallOnRedisCall(0, resp)
|
||||
|
||||
// 测试查询历史路径的响应体处理
|
||||
chunk := []byte("test chunk")
|
||||
action := host.CallOnHttpStreamingResponseBody(chunk, true)
|
||||
|
||||
// 应该直接返回chunk,不进行处理
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
|
||||
// 测试没有QuestionContextKey的情况
|
||||
t.Run("no question context", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicRedisConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
{"authorization", "Bearer user123"},
|
||||
})
|
||||
|
||||
// 不调用请求体处理,所以没有QuestionContextKey
|
||||
|
||||
// 测试没有QuestionContextKey的响应体处理
|
||||
chunk := []byte("test chunk")
|
||||
action := host.CallOnHttpStreamingResponseBody(chunk, true)
|
||||
|
||||
// 应该返回ActionContinue
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,13 +5,21 @@ go 1.24.1
|
||||
toolchain go1.24.4
|
||||
|
||||
require (
|
||||
github.com/higress-group/wasm-go v1.0.0
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.7.2 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
|
||||
@@ -2,14 +2,16 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw=
|
||||
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
|
||||
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
@@ -22,5 +24,7 @@ github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
616
plugins/wasm-go/extensions/ai-image-reader/main_test.go
Normal file
616
plugins/wasm-go/extensions/ai-image-reader/main_test.go
Normal file
@@ -0,0 +1,616 @@
|
||||
// Copyright (c) 2024 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// 测试配置:基本DashScope OCR配置
|
||||
var basicDashScopeConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"type": "dashscope",
|
||||
"apiKey": "test-api-key-123",
|
||||
"serviceName": "ocr-service",
|
||||
"serviceHost": "dashscope.aliyuncs.com",
|
||||
"servicePort": 443,
|
||||
"timeout": 10000,
|
||||
"model": "qwen-vl-ocr",
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:最小DashScope配置(使用默认值)
|
||||
var minimalDashScopeConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"type": "dashscope",
|
||||
"apiKey": "minimal-api-key",
|
||||
"serviceName": "ocr-service",
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:自定义端口和超时配置
|
||||
var customPortTimeoutConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"type": "dashscope",
|
||||
"apiKey": "custom-api-key",
|
||||
"serviceName": "ocr-service",
|
||||
"serviceHost": "custom.dashscope.com",
|
||||
"servicePort": 8443,
|
||||
"timeout": 30000,
|
||||
"model": "qwen-vl-ocr",
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:自定义模型配置
|
||||
var customModelConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"type": "dashscope",
|
||||
"apiKey": "model-api-key",
|
||||
"serviceName": "ocr-service",
|
||||
"serviceHost": "dashscope.aliyuncs.com",
|
||||
"servicePort": 443,
|
||||
"timeout": 15000,
|
||||
"model": "custom-ocr-model",
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
// 测试基本DashScope配置解析
|
||||
t.Run("basic dashscope config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicDashScopeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试最小DashScope配置解析(使用默认值)
|
||||
t.Run("minimal dashscope config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(minimalDashScopeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试自定义端口和超时配置解析
|
||||
t.Run("custom port timeout config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(customPortTimeoutConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试自定义模型配置解析
|
||||
t.Run("custom model config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(customModelConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpRequestHeaders(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试JSON内容类型的请求头处理
|
||||
t.Run("JSON content type headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicDashScopeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置JSON内容类型的请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
// 应该返回ActionContinue,因为禁用了重路由但允许继续处理
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
|
||||
// 测试非JSON内容类型的请求头处理
|
||||
t.Run("non-JSON content type headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicDashScopeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置非JSON内容类型的请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "text/plain"},
|
||||
})
|
||||
|
||||
// 应该返回ActionContinue,但不会读取请求体
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
|
||||
// 测试缺少content-type的请求头处理
|
||||
t.Run("missing content type headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicDashScopeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置缺少content-type的请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
})
|
||||
|
||||
// 应该返回ActionContinue
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpRequestBody(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试包含单张图片的请求体处理
|
||||
t.Run("single image request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicDashScopeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
// 构造包含单张图片的请求体
|
||||
requestBody := `{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "这张图片里有什么?"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://example.com/image1.jpg"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
// 调用请求体处理
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 应该返回ActionPause,因为需要等待OCR响应
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// 模拟OCR服务响应
|
||||
ocrResponse := `{
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": "图片中包含一些文字内容"
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
// 模拟HTTP调用响应
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{"content-type", "application/json"},
|
||||
{":status", "200"},
|
||||
}, []byte(ocrResponse))
|
||||
|
||||
modifiedBody := host.GetRequestBody()
|
||||
require.NotNil(t, modifiedBody)
|
||||
require.Contains(t, string(modifiedBody), "图片中包含一些文字内容")
|
||||
|
||||
// 完成HTTP请求
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
// 测试包含多张图片的请求体处理
|
||||
t.Run("multiple images request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicDashScopeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
// 构造包含多张图片的请求体
|
||||
requestBody := `{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "这些图片里有什么?"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://example.com/image1.jpg"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://example.com/image2.jpg"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
// 调用请求体处理
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 应该返回ActionPause,因为需要等待OCR响应
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// 模拟第一张图片的OCR响应
|
||||
ocrResponse1 := `{
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": "第一张图片包含文字A"
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
// 模拟第二张图片的OCR响应
|
||||
ocrResponse2 := `{
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": "第二张图片包含文字B"
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
// 模拟第一个HTTP调用响应
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{"content-type", "application/json"},
|
||||
{":status", "200"},
|
||||
}, []byte(ocrResponse1))
|
||||
|
||||
// 模拟第二个HTTP调用响应
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{"content-type", "application/json"},
|
||||
{":status", "200"},
|
||||
}, []byte(ocrResponse2))
|
||||
|
||||
modifiedBody := host.GetRequestBody()
|
||||
require.NotNil(t, modifiedBody)
|
||||
require.Contains(t, string(modifiedBody), "第一张图片包含文字A")
|
||||
require.Contains(t, string(modifiedBody), "第二张图片包含文字B")
|
||||
|
||||
// 完成HTTP请求
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
// 测试不包含图片的请求体处理
|
||||
t.Run("no image request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicDashScopeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
// 构造不包含图片的请求体
|
||||
requestBody := `{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "你好,请介绍一下自己"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
// 调用请求体处理
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 应该返回ActionContinue,因为没有图片需要处理
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// 测试配置验证
|
||||
func TestConfigValidation(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
// 测试缺少type配置
|
||||
t.Run("missing type", func(t *testing.T) {
|
||||
invalidConfig := func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"apiKey": "test-api-key",
|
||||
"serviceName": "ocr-service",
|
||||
"serviceHost": "dashscope.aliyuncs.com",
|
||||
"servicePort": 443,
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
host, status := test.NewTestHost(invalidConfig)
|
||||
defer host.Reset()
|
||||
// 应该返回错误状态,因为缺少必需的type
|
||||
require.NotEqual(t, types.OnPluginStartStatusOK, status)
|
||||
})
|
||||
|
||||
// 测试缺少apiKey配置
|
||||
t.Run("missing apiKey", func(t *testing.T) {
|
||||
invalidConfig := func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"type": "dashscope",
|
||||
"serviceName": "ocr-service",
|
||||
"serviceHost": "dashscope.aliyuncs.com",
|
||||
"servicePort": 443,
|
||||
"timeout": 10000,
|
||||
"model": "qwen-vl-ocr",
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
host, status := test.NewTestHost(invalidConfig)
|
||||
defer host.Reset()
|
||||
// 应该返回错误状态,因为缺少必需的apiKey
|
||||
require.NotEqual(t, types.OnPluginStartStatusOK, status)
|
||||
})
|
||||
|
||||
// 测试缺少serviceName配置
|
||||
t.Run("missing serviceName", func(t *testing.T) {
|
||||
invalidConfig := func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"type": "dashscope",
|
||||
"apiKey": "test-api-key",
|
||||
"serviceHost": "dashscope.aliyuncs.com",
|
||||
"servicePort": 443,
|
||||
"timeout": 10000,
|
||||
"model": "qwen-vl-ocr",
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
host, status := test.NewTestHost(invalidConfig)
|
||||
defer host.Reset()
|
||||
// 应该返回错误状态,因为缺少必需的serviceName
|
||||
require.NotEqual(t, types.OnPluginStartStatusOK, status)
|
||||
})
|
||||
|
||||
// 测试未知的provider类型
|
||||
t.Run("unknown provider type", func(t *testing.T) {
|
||||
invalidConfig := func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"type": "unknown-provider",
|
||||
"apiKey": "test-api-key",
|
||||
"serviceName": "ocr-service",
|
||||
"serviceHost": "example.com",
|
||||
"servicePort": 443,
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
host, status := test.NewTestHost(invalidConfig)
|
||||
defer host.Reset()
|
||||
// 应该返回错误状态,因为provider类型未知
|
||||
require.NotEqual(t, types.OnPluginStartStatusOK, status)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// 测试边界情况
|
||||
func TestEdgeCases(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试空请求体
|
||||
t.Run("empty request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicDashScopeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
// 调用请求体处理 - 空请求体
|
||||
action := host.CallOnHttpRequestBody([]byte{})
|
||||
|
||||
// 应该返回ActionContinue,因为没有图片需要处理
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
|
||||
// 测试无效JSON请求体
|
||||
t.Run("invalid JSON request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicDashScopeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
// 调用请求体处理 - 无效JSON
|
||||
invalidJSON := []byte(`{"messages": [{"role": "user", "content": "test"}`)
|
||||
action := host.CallOnHttpRequestBody(invalidJSON)
|
||||
|
||||
// 应该返回ActionContinue,因为JSON解析失败
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
|
||||
// 测试OCR服务错误响应
|
||||
t.Run("OCR service error response", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicDashScopeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
// 构造包含图片的请求体
|
||||
requestBody := `{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "这张图片里有什么?"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://example.com/image1.jpg"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
// 调用请求体处理
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 应该返回ActionPause
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// 模拟OCR服务错误响应
|
||||
errorResponse := `{
|
||||
"error": "Service unavailable",
|
||||
"message": "OCR service is down"
|
||||
}`
|
||||
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{"content-type", "application/json"},
|
||||
{":status", "503"},
|
||||
}, []byte(errorResponse))
|
||||
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
// 测试OCR服务返回空结果
|
||||
t.Run("OCR service empty response", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicDashScopeConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
// 构造包含图片的请求体
|
||||
requestBody := `{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "这张图片里有什么?"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://example.com/image1.jpg"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
// 调用请求体处理
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 应该返回ActionPause
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// 模拟OCR服务返回空结果
|
||||
emptyResponse := `{
|
||||
"choices": []
|
||||
}`
|
||||
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{"content-type", "application/json"},
|
||||
{":status", "200"},
|
||||
}, []byte(emptyResponse))
|
||||
|
||||
host.CompleteHttp()
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -7,14 +7,20 @@ go 1.24.1
|
||||
toolchain go1.24.4
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
|
||||
github.com/higress-group/wasm-go v1.0.0
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.7.2 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -2,14 +2,17 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw=
|
||||
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
|
||||
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
@@ -19,5 +22,9 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
531
plugins/wasm-go/extensions/ai-intent/main_test.go
Normal file
531
plugins/wasm-go/extensions/ai-intent/main_test.go
Normal file
@@ -0,0 +1,531 @@
|
||||
// Copyright (c) 2024 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// 测试配置:基本意图识别配置
|
||||
var basicIntentConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"scene": map[string]interface{}{
|
||||
"category": "金融|电商|法律|Higress",
|
||||
"prompt": "你是一个智能类别识别助手,负责根据用户提出的问题和预设的类别,确定问题属于哪个预设的类别,并给出相应的类别。用户提出的问题为:'%s',预设的类别为'%s',直接返回一种具体类别,如果没有找到就返回'NotFound'。",
|
||||
},
|
||||
"llm": map[string]interface{}{
|
||||
"proxyServiceName": "ai-service",
|
||||
"proxyUrl": "http://ai.example.com/v1/chat/completions",
|
||||
"proxyModel": "qwen-long",
|
||||
"proxyPort": 80,
|
||||
"proxyDomain": "ai.example.com",
|
||||
"proxyTimeout": 10000,
|
||||
"proxyApiKey": "test-api-key",
|
||||
},
|
||||
"keyFrom": map[string]interface{}{
|
||||
"requestBody": "messages.@reverse.0.content",
|
||||
"responseBody": "choices.0.message.content",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:自定义提示词配置
|
||||
var customPromptConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"scene": map[string]interface{}{
|
||||
"category": "技术|产品|运营|设计",
|
||||
"prompt": "请分析以下问题属于哪个技术领域:%s,可选领域:%s,请直接返回领域名称。",
|
||||
},
|
||||
"llm": map[string]interface{}{
|
||||
"proxyServiceName": "ai-service",
|
||||
"proxyUrl": "https://ai.example.com/v1/chat/completions",
|
||||
"proxyModel": "gpt-3.5-turbo",
|
||||
"proxyPort": 443,
|
||||
"proxyDomain": "ai.example.com",
|
||||
"proxyTimeout": 15000,
|
||||
"proxyApiKey": "custom-api-key",
|
||||
},
|
||||
"keyFrom": map[string]interface{}{
|
||||
"requestBody": "query",
|
||||
"responseBody": "result",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:最小配置(使用默认值)
|
||||
var minimalConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"scene": map[string]interface{}{
|
||||
"category": "A|B|C",
|
||||
},
|
||||
"llm": map[string]interface{}{
|
||||
"proxyServiceName": "ai-service",
|
||||
"proxyUrl": "http://ai.example.com/v1/chat/completions",
|
||||
},
|
||||
"keyFrom": map[string]interface{}{},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:HTTPS配置
|
||||
var httpsConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"scene": map[string]interface{}{
|
||||
"category": "客服|销售|技术支持",
|
||||
},
|
||||
"llm": map[string]interface{}{
|
||||
"proxyServiceName": "ai-service",
|
||||
"proxyUrl": "https://ai.example.com:8443/v1/chat/completions",
|
||||
"proxyModel": "claude-3",
|
||||
"proxyTimeout": 20000,
|
||||
"proxyApiKey": "https-api-key",
|
||||
},
|
||||
"keyFrom": map[string]interface{}{
|
||||
"requestBody": "input.text",
|
||||
"responseBody": "output.classification",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
// 测试基本意图识别配置解析
|
||||
t.Run("basic intent config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicIntentConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试自定义提示词配置解析
|
||||
t.Run("custom prompt config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(customPromptConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试最小配置解析(使用默认值)
|
||||
t.Run("minimal config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(minimalConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试HTTPS配置解析
|
||||
t.Run("https config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(httpsConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpRequestHeaders(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试请求头处理
|
||||
t.Run("request headers processing", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicIntentConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
// 应该返回HeaderStopIteration,因为禁用了重路由
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpRequestBody(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试请求体处理 - 金融类问题
|
||||
t.Run("financial question processing", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicIntentConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
// 构造请求体 - 金融类问题
|
||||
requestBody := `{
|
||||
"messages": [
|
||||
{"role": "user", "content": "今天股市怎么样?"}
|
||||
]
|
||||
}`
|
||||
|
||||
// 调用请求体处理
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 应该返回ActionPause,因为需要等待LLM响应
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// 模拟LLM响应 - 返回"金融"类别
|
||||
llmResponse := `{
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": "金融"
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
// 模拟HTTP调用响应
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{"content-type", "application/json"},
|
||||
{":status", "200"},
|
||||
}, []byte(llmResponse))
|
||||
|
||||
// 验证插件是否正确处理了LLM响应
|
||||
// 插件应该将"金融"类别设置到Property中
|
||||
// 通过host.GetProperty验证意图类别是否被正确设置
|
||||
intentCategory, err := host.GetProperty([]string{"intent_category"})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "金融", string(intentCategory))
|
||||
|
||||
// 完成HTTP请求
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
// 测试请求体处理 - 电商类问题
|
||||
t.Run("ecommerce question processing", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicIntentConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
// 构造请求体 - 电商类问题
|
||||
requestBody := `{
|
||||
"messages": [
|
||||
{"role": "user", "content": "这个商品什么时候发货?"}
|
||||
]
|
||||
}`
|
||||
|
||||
// 调用请求体处理
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 应该返回ActionPause
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// 模拟LLM响应 - 返回"电商"类别
|
||||
llmResponse := `{
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": "电商"
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
// 模拟HTTP调用响应
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{"content-type", "application/json"},
|
||||
{":status", "200"},
|
||||
}, []byte(llmResponse))
|
||||
|
||||
// 验证插件是否正确处理了LLM响应
|
||||
// 插件应该将"电商"类别设置到Property中
|
||||
// 通过host.GetProperty验证意图类别是否被正确设置
|
||||
intentCategory, err := host.GetProperty([]string{"intent_category"})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "电商", string(intentCategory))
|
||||
|
||||
// 完成HTTP请求
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
// 测试请求体处理 - 未找到类别
|
||||
t.Run("category not found processing", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicIntentConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
// 构造请求体 - 不相关的问题
|
||||
requestBody := `{
|
||||
"messages": [
|
||||
{"role": "user", "content": "今天天气怎么样?"}
|
||||
]
|
||||
}`
|
||||
|
||||
// 调用请求体处理
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 应该返回ActionPause
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// 模拟LLM响应 - 返回"NotFound"
|
||||
llmResponse := `{
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": "NotFound"
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
// 模拟HTTP调用响应
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{"content-type", "application/json"},
|
||||
{":status", "200"},
|
||||
}, []byte(llmResponse))
|
||||
|
||||
_, err := host.GetProperty([]string{"intent_category"})
|
||||
// 应该返回错误,因为没有设置该Property
|
||||
require.Error(t, err)
|
||||
|
||||
// 完成HTTP请求
|
||||
host.CompleteHttp()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// 测试配置验证
|
||||
func TestConfigValidation(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
// 测试缺少scene.category配置
|
||||
t.Run("missing scene.category", func(t *testing.T) {
|
||||
invalidConfig := func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"scene": map[string]interface{}{
|
||||
"prompt": "test prompt",
|
||||
},
|
||||
"llm": map[string]interface{}{
|
||||
"proxyServiceName": "ai-service",
|
||||
"proxyUrl": "http://ai.example.com/v1/chat/completions",
|
||||
},
|
||||
"keyFrom": map[string]interface{}{},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
host, status := test.NewTestHost(invalidConfig)
|
||||
defer host.Reset()
|
||||
// 应该返回错误状态,因为缺少必需的scene.category
|
||||
require.NotEqual(t, types.OnPluginStartStatusOK, status)
|
||||
})
|
||||
|
||||
// 测试缺少llm.proxyServiceName配置
|
||||
t.Run("missing llm.proxyServiceName", func(t *testing.T) {
|
||||
invalidConfig := func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"scene": map[string]interface{}{
|
||||
"category": "A|B|C",
|
||||
},
|
||||
"llm": map[string]interface{}{
|
||||
"proxyUrl": "http://ai.example.com/v1/chat/completions",
|
||||
},
|
||||
"keyFrom": map[string]interface{}{},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
host, status := test.NewTestHost(invalidConfig)
|
||||
defer host.Reset()
|
||||
// 应该返回错误状态,因为缺少必需的llm.proxyServiceName
|
||||
require.NotEqual(t, types.OnPluginStartStatusOK, status)
|
||||
})
|
||||
|
||||
// 测试缺少llm.proxyUrl配置
|
||||
t.Run("missing llm.proxyUrl", func(t *testing.T) {
|
||||
invalidConfig := func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"scene": map[string]interface{}{
|
||||
"category": "A|B|C",
|
||||
},
|
||||
"llm": map[string]interface{}{
|
||||
"proxyServiceName": "ai-service",
|
||||
},
|
||||
"keyFrom": map[string]interface{}{},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
host, status := test.NewTestHost(invalidConfig)
|
||||
defer host.Reset()
|
||||
// 应该返回错误状态,因为缺少必需的llm.proxyUrl
|
||||
require.NotEqual(t, types.OnPluginStartStatusOK, status)
|
||||
})
|
||||
|
||||
// 测试缺少必需字段的配置
|
||||
t.Run("missing required fields", func(t *testing.T) {
|
||||
invalidConfig := func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"scene": map[string]interface{}{
|
||||
"category": "A|B|C",
|
||||
},
|
||||
"llm": map[string]interface{}{
|
||||
"proxyServiceName": "ai-service",
|
||||
// 故意不设置proxyUrl,这是必需的
|
||||
},
|
||||
"keyFrom": map[string]interface{}{},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
host, status := test.NewTestHost(invalidConfig)
|
||||
defer host.Reset()
|
||||
// 应该返回错误状态,因为缺少必需的proxyUrl
|
||||
require.NotEqual(t, types.OnPluginStartStatusOK, status)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// 测试边界情况
|
||||
func TestEdgeCases(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
|
||||
// 测试无效JSON请求体
|
||||
t.Run("invalid JSON request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicIntentConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
// 调用请求体处理 - 无效JSON
|
||||
invalidJSON := []byte(`{"messages": [{"role": "user", "content": "test"}`)
|
||||
action := host.CallOnHttpRequestBody(invalidJSON)
|
||||
|
||||
// 应该返回ActionPause,因为需要等待LLM响应
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// 模拟LLM响应
|
||||
llmResponse := `{
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": "NotFound"
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{"content-type", "application/json"},
|
||||
{":status", "200"},
|
||||
}, []byte(llmResponse))
|
||||
|
||||
// 验证插件是否正确处理了LLM响应
|
||||
// 由于返回"NotFound",插件不会设置任何意图类别到Property中
|
||||
// 验证没有设置意图类别Property
|
||||
_, err := host.GetProperty([]string{"intent_category"})
|
||||
// 应该返回错误,因为没有设置该Property
|
||||
require.Error(t, err)
|
||||
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
// 测试LLM服务错误响应
|
||||
t.Run("LLM service error response", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicIntentConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/api/chat"},
|
||||
{":method", "POST"},
|
||||
{"content-type", "application/json"},
|
||||
})
|
||||
|
||||
// 构造请求体
|
||||
requestBody := `{
|
||||
"messages": [
|
||||
{"role": "user", "content": "今天股市怎么样?"}
|
||||
]
|
||||
}`
|
||||
|
||||
// 调用请求体处理
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 应该返回ActionPause
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// 模拟LLM服务错误响应
|
||||
errorResponse := `{
|
||||
"error": "Service unavailable",
|
||||
"message": "LLM service is down"
|
||||
}`
|
||||
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{"content-type", "application/json"},
|
||||
{":status", "503"},
|
||||
}, []byte(errorResponse))
|
||||
|
||||
// 验证插件是否正确处理了LLM错误响应
|
||||
// 由于状态码不是200,插件不会设置任何意图类别到Property中
|
||||
host.CompleteHttp()
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -5,8 +5,17 @@ go 1.24.1
|
||||
toolchain go1.24.4
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
|
||||
github.com/higress-group/wasm-go v1.0.0
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
|
||||
github.com/stretchr/testify v1.9.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.7.2 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
|
||||
@@ -2,16 +2,19 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw=
|
||||
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/santhosh-tekuri/jsonschema v1.2.4 h1:hNhW8e7t+H1vgY+1QeEQpveR6D4+OwKPXCfD2aieJis=
|
||||
github.com/santhosh-tekuri/jsonschema v1.2.4/go.mod h1:TEAUOeZSmIxTTuHatJzrvARHiuO9LYd+cIxzgEHCQI4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
|
||||
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
@@ -21,5 +24,9 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
892
plugins/wasm-go/extensions/ai-json-resp/main_test.go
Normal file
892
plugins/wasm-go/extensions/ai-json-resp/main_test.go
Normal file
@@ -0,0 +1,892 @@
|
||||
// Copyright (c) 2024 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/santhosh-tekuri/jsonschema"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// 测试配置:基础配置
|
||||
var basicConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"serviceName": "ai-service",
|
||||
"serviceDomain": "api.openai.com",
|
||||
"servicePort": 443,
|
||||
"servicePath": "/v1/chat/completions",
|
||||
"apiKey": "sk-test123",
|
||||
"serviceTimeout": 30000,
|
||||
"maxRetry": 3,
|
||||
"contentPath": "choices.0.message.content",
|
||||
"enableContentDisposition": true,
|
||||
// 添加一个简单的JSON Schema,避免编译失败
|
||||
"jsonSchema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"content": map[string]interface{}{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:使用serviceUrl的配置
|
||||
var serviceUrlConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"serviceName": "ai-service",
|
||||
"serviceUrl": "https://api.openai.com/v1/chat/completions",
|
||||
"apiKey": "sk-test456",
|
||||
"serviceTimeout": 50000,
|
||||
"maxRetry": 5,
|
||||
"contentPath": "choices.0.message.content",
|
||||
"enableContentDisposition": false,
|
||||
// 添加一个简单的JSON Schema,避免编译失败
|
||||
"jsonSchema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"content": map[string]interface{}{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:包含JSON Schema的配置
|
||||
var jsonSchemaConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"serviceName": "ai-service",
|
||||
"serviceDomain": "api.openai.com",
|
||||
"servicePort": 443,
|
||||
"apiKey": "sk-test789",
|
||||
"jsonSchema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"name": map[string]interface{}{
|
||||
"type": "string",
|
||||
},
|
||||
"age": map[string]interface{}{
|
||||
"type": "integer",
|
||||
},
|
||||
},
|
||||
"required": []string{"name"},
|
||||
},
|
||||
"enableSwagger": true,
|
||||
"enableOas3": false,
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:启用OAS3的配置
|
||||
var oas3Config = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"serviceName": "ai-service",
|
||||
"serviceDomain": "api.openai.com",
|
||||
"servicePort": 443,
|
||||
"apiKey": "sk-test101",
|
||||
"jsonSchema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"title": map[string]interface{}{
|
||||
"type": "string",
|
||||
},
|
||||
"content": map[string]interface{}{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
"enableSwagger": false,
|
||||
"enableOas3": true,
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:无效的JSON Schema配置
|
||||
var invalidJsonSchemaConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"serviceName": "ai-service",
|
||||
"serviceDomain": "api.openai.com",
|
||||
"servicePort": 443,
|
||||
"apiKey": "sk-test303",
|
||||
"jsonSchema": "invalid-schema",
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:缺少必需字段的配置
|
||||
var missingRequiredConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"apiKey": "sk-test404",
|
||||
"serviceTimeout": 30000,
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
// 测试基础配置解析
|
||||
t.Run("basic config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
pluginConfig := config.(*PluginConfig)
|
||||
require.Equal(t, "ai-service", pluginConfig.serviceName)
|
||||
require.Equal(t, "api.openai.com", pluginConfig.serviceDomain)
|
||||
require.Equal(t, 443, pluginConfig.servicePort)
|
||||
require.Equal(t, "/v1/chat/completions", pluginConfig.servicePath)
|
||||
require.Equal(t, "sk-test123", pluginConfig.apiKey)
|
||||
require.Equal(t, 30000, pluginConfig.serviceTimeout)
|
||||
require.Equal(t, 3, pluginConfig.maxRetry)
|
||||
require.Equal(t, "choices.0.message.content", pluginConfig.contentPath)
|
||||
require.True(t, pluginConfig.enableContentDisposition)
|
||||
require.NotNil(t, pluginConfig.jsonSchema)
|
||||
require.Equal(t, jsonschema.Draft7, pluginConfig.draft)
|
||||
require.True(t, pluginConfig.enableJsonSchemaValidation)
|
||||
})
|
||||
|
||||
// 测试使用serviceUrl的配置解析
|
||||
t.Run("serviceUrl config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(serviceUrlConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
pluginConfig := config.(*PluginConfig)
|
||||
require.Equal(t, "ai-service", pluginConfig.serviceName)
|
||||
require.Equal(t, "api.openai.com", pluginConfig.serviceDomain)
|
||||
require.Equal(t, 443, pluginConfig.servicePort)
|
||||
require.Equal(t, "/v1/chat/completions", pluginConfig.servicePath)
|
||||
require.Equal(t, "sk-test456", pluginConfig.apiKey)
|
||||
require.Equal(t, 50000, pluginConfig.serviceTimeout)
|
||||
require.Equal(t, 5, pluginConfig.maxRetry)
|
||||
require.False(t, pluginConfig.enableContentDisposition)
|
||||
})
|
||||
|
||||
// 测试包含JSON Schema的配置解析
|
||||
t.Run("jsonSchema config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(jsonSchemaConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
pluginConfig := config.(*PluginConfig)
|
||||
require.NotNil(t, pluginConfig.jsonSchema)
|
||||
require.Equal(t, jsonschema.Draft4, pluginConfig.draft)
|
||||
require.True(t, pluginConfig.enableJsonSchemaValidation)
|
||||
require.NotNil(t, pluginConfig.compile)
|
||||
})
|
||||
|
||||
// 测试启用OAS3的配置解析
|
||||
t.Run("oas3 config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(oas3Config)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
pluginConfig := config.(*PluginConfig)
|
||||
require.Equal(t, jsonschema.Draft7, pluginConfig.draft)
|
||||
require.True(t, pluginConfig.enableJsonSchemaValidation)
|
||||
})
|
||||
|
||||
// 测试无效的JSON Schema配置
|
||||
t.Run("invalid jsonSchema config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(invalidJsonSchemaConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
pluginConfig := config.(*PluginConfig)
|
||||
// 根据插件的实际行为,无效的JSON Schema会导致编译失败
|
||||
require.Equal(t, uint32(JSON_SCHEMA_COMPILE_FAILED_CODE), pluginConfig.rejectStruct.RejectCode)
|
||||
})
|
||||
|
||||
// 测试缺少必需字段的配置
|
||||
t.Run("missing required config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(missingRequiredConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
config, _ := host.GetMatchConfig()
|
||||
require.NotNil(t, config)
|
||||
|
||||
pluginConfig := config.(*PluginConfig)
|
||||
// 根据插件的实际行为,缺少serviceDomain会导致JSON Schema编译失败
|
||||
require.Equal(t, uint32(JSON_SCHEMA_COMPILE_FAILED_CODE), pluginConfig.rejectStruct.RejectCode)
|
||||
require.Contains(t, pluginConfig.rejectStruct.RejectMsg, "Json Schema compile failed")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpRequestHeaders(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试正常请求头处理
|
||||
t.Run("normal request headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Authorization", "Bearer sk-user123"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"Content-Length", "100"},
|
||||
})
|
||||
|
||||
// 应该返回ActionContinue
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
|
||||
// 测试来自插件的请求头处理
|
||||
t.Run("request from this plugin", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置来自插件的请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{EXTEND_HEADER_KEY, "true"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 应该返回ActionContinue
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
|
||||
// 测试没有Authorization头的请求
|
||||
t.Run("no authorization header", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置没有Authorization的请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"Content-Length", "100"},
|
||||
})
|
||||
|
||||
// 应该返回ActionContinue
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
|
||||
// 测试配置错误的请求头处理
|
||||
t.Run("config error", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(missingRequiredConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 应该返回ActionPause
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpRequestBody(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试来自插件的请求(应该直接继续)
|
||||
t.Run("request from this plugin", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头,包含EXTEND_HEADER_KEY来标记请求来自插件
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
{EXTEND_HEADER_KEY, "true"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
body := `{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`
|
||||
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
// 应该返回ActionContinue,因为请求来自插件
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
|
||||
// 测试配置错误的请求体处理
|
||||
t.Run("config error", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(missingRequiredConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
body := `{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`
|
||||
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
// 应该返回ActionContinue,因为配置有错误
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
|
||||
// 测试正常请求体处理 - 成功响应
|
||||
t.Run("normal request with successful response", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
body := `{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is AI?"}
|
||||
]
|
||||
}`
|
||||
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
// 应该返回ActionPause,等待外部服务响应
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// 模拟外部服务返回成功响应
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{":status", "200"},
|
||||
{"content-type", "application/json"},
|
||||
}, []byte(`{
|
||||
"id": "chatcmpl-123",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "{\"definition\": \"AI is artificial intelligence\", \"examples\": [\"machine learning\", \"natural language processing\"]}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}`))
|
||||
|
||||
response := host.GetLocalResponse()
|
||||
require.NotNil(t, response)
|
||||
require.Contains(t, string(response.Data), "definition")
|
||||
require.Contains(t, string(response.Data), "examples")
|
||||
|
||||
// 完成HTTP请求
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
// 测试正常请求体处理 - 需要重试的响应
|
||||
t.Run("normal request with retry response", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
body := `{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is AI?"}
|
||||
]
|
||||
}`
|
||||
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
// 应该返回ActionPause,等待外部服务响应
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// 模拟外部服务返回需要重试的响应(content字段不是有效JSON)
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{":status", "200"},
|
||||
{"content-type", "application/json"},
|
||||
}, []byte(`{
|
||||
"id": "chatcmpl-123",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "AI is artificial intelligence. It includes machine learning and natural language processing."
|
||||
}
|
||||
}
|
||||
]
|
||||
}`))
|
||||
|
||||
// 由于content不是有效JSON,插件会进行重试
|
||||
// 模拟重试请求的响应
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{":status", "200"},
|
||||
{"content-type", "application/json"},
|
||||
}, []byte(`{
|
||||
"id": "chatcmpl-456",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "{\"definition\": \"AI is artificial intelligence\", \"examples\": [\"machine learning\", \"natural language processing\"]}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}`))
|
||||
|
||||
// 验证最终响应体是提取的JSON内容
|
||||
response := host.GetLocalResponse()
|
||||
require.NotNil(t, response)
|
||||
require.Contains(t, string(response.Data), "definition")
|
||||
require.Contains(t, string(response.Data), "examples")
|
||||
|
||||
// 完成HTTP请求
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
// 测试外部服务返回无效响应体
|
||||
t.Run("external service returns invalid response body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
body := `{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is AI?"}
|
||||
]
|
||||
}`
|
||||
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
// 应该返回ActionPause,等待外部服务响应
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// 模拟外部服务返回无效的响应体
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{":status", "200"},
|
||||
{"content-type", "application/json"},
|
||||
}, []byte(`invalid json response`))
|
||||
|
||||
// 验证响应体包含错误信息
|
||||
response := host.GetLocalResponse()
|
||||
require.NotNil(t, response)
|
||||
require.Contains(t, string(response.Data), "invalid json response")
|
||||
|
||||
// 完成HTTP请求
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
// 测试外部服务返回缺少content字段的响应
|
||||
t.Run("external service returns response without content field", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
body := `{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is AI?"}
|
||||
]
|
||||
}`
|
||||
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
// 应该返回ActionPause,等待外部服务响应
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// 模拟外部服务返回缺少content字段的响应
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{":status", "200"},
|
||||
{"content-type", "application/json"},
|
||||
}, []byte(`{
|
||||
"id": "chatcmpl-123",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant"
|
||||
}
|
||||
}
|
||||
]
|
||||
}`))
|
||||
|
||||
// 验证响应体包含错误信息
|
||||
response := host.GetLocalResponse()
|
||||
require.NotNil(t, response)
|
||||
require.Contains(t, string(response.Data), "response body does not contain the content")
|
||||
|
||||
// 完成HTTP请求
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
// 测试使用自定义servicePath的请求
|
||||
t.Run("request with custom service path", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(serviceUrlConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/custom/chat"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
body := `{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is AI?"}
|
||||
]
|
||||
}`
|
||||
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
// 应该返回ActionPause,等待外部服务响应
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// 模拟外部服务返回成功响应
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{":status", "200"},
|
||||
{"content-type", "application/json"},
|
||||
}, []byte(`{
|
||||
"id": "chatcmpl-123",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "{\"answer\": \"AI is artificial intelligence\"}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}`))
|
||||
|
||||
// 验证响应体是提取的JSON内容
|
||||
response := host.GetLocalResponse()
|
||||
require.NotNil(t, response)
|
||||
require.Contains(t, string(response.Data), "answer")
|
||||
|
||||
// 完成HTTP请求
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
// 测试达到最大重试次数的情况
|
||||
t.Run("max retry count exceeded", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
body := `{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is AI?"}
|
||||
]
|
||||
}`
|
||||
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
// 应该返回ActionPause,等待外部服务响应
|
||||
require.Equal(t, types.ActionPause, action)
|
||||
|
||||
// 模拟多次重试,每次都返回无效的content
|
||||
for i := 0; i < 4; i++ { // 超过最大重试次数3次
|
||||
host.CallOnHttpCall([][2]string{
|
||||
{":status", "200"},
|
||||
{"content-type", "application/json"},
|
||||
}, []byte(`{
|
||||
"id": "chatcmpl-123",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "AI is artificial intelligence"
|
||||
}
|
||||
}
|
||||
]
|
||||
}`))
|
||||
}
|
||||
|
||||
// 验证最终响应体包含重试次数超限的错误信息
|
||||
response := host.GetLocalResponse()
|
||||
require.NotNil(t, response)
|
||||
require.Contains(t, string(response.Data), "retry count exceeds max retry count")
|
||||
|
||||
// 完成HTTP请求
|
||||
host.CompleteHttp()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestRejectStruct(t *testing.T) {
|
||||
// 测试RejectStruct的GetBytes方法
|
||||
t.Run("GetBytes", func(t *testing.T) {
|
||||
reject := RejectStruct{
|
||||
RejectCode: 1001,
|
||||
RejectMsg: "Test error message",
|
||||
}
|
||||
|
||||
bytes := reject.GetBytes()
|
||||
require.NotNil(t, bytes)
|
||||
|
||||
// 验证JSON格式
|
||||
var result RejectStruct
|
||||
err := json.Unmarshal(bytes, &result)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, uint32(1001), result.RejectCode)
|
||||
require.Equal(t, "Test error message", result.RejectMsg)
|
||||
})
|
||||
|
||||
// 测试RejectStruct的GetShortMsg方法
|
||||
t.Run("GetShortMsg", func(t *testing.T) {
|
||||
reject := RejectStruct{
|
||||
RejectCode: 1001,
|
||||
RejectMsg: "Json Schema is not valid: invalid format",
|
||||
}
|
||||
|
||||
shortMsg := reject.GetShortMsg()
|
||||
require.Equal(t, "ai-json-resp.Json Schema is not valid", shortMsg)
|
||||
})
|
||||
|
||||
// 测试RejectStruct的GetShortMsg方法 - 没有冒号的情况
|
||||
t.Run("GetShortMsg no colon", func(t *testing.T) {
|
||||
reject := RejectStruct{
|
||||
RejectCode: 1001,
|
||||
RejectMsg: "Simple error message",
|
||||
}
|
||||
|
||||
shortMsg := reject.GetShortMsg()
|
||||
require.Equal(t, "ai-json-resp.Simple error message", shortMsg)
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateBody(t *testing.T) {
|
||||
// 创建测试配置
|
||||
config := &PluginConfig{
|
||||
contentPath: "choices.0.message.content",
|
||||
jsonSchema: nil, // 明确设置为nil,禁用JSON Schema验证
|
||||
enableJsonSchemaValidation: false, // 禁用JSON Schema验证
|
||||
}
|
||||
|
||||
// 测试有效的响应体
|
||||
t.Run("valid response body", func(t *testing.T) {
|
||||
validBody := []byte(`{
|
||||
"id": "chatcmpl-123",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello, how can I help you?"
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
err := config.ValidateBody(validBody)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
// 测试无效的JSON响应体
|
||||
t.Run("invalid JSON response body", func(t *testing.T) {
|
||||
invalidBody := []byte(`invalid json content`)
|
||||
|
||||
err := config.ValidateBody(invalidBody)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, uint32(SERVICE_UNAVAILABLE_CODE), config.rejectStruct.RejectCode)
|
||||
require.Contains(t, config.rejectStruct.RejectMsg, "service unavailable")
|
||||
})
|
||||
|
||||
// 测试缺少content字段的响应体
|
||||
t.Run("missing content field", func(t *testing.T) {
|
||||
missingContentBody := []byte(`{
|
||||
"id": "chatcmpl-123",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant"
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
err := config.ValidateBody(missingContentBody)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, uint32(SERVICE_UNAVAILABLE_CODE), config.rejectStruct.RejectCode)
|
||||
require.Contains(t, config.rejectStruct.RejectMsg, "response body does not contain the content")
|
||||
})
|
||||
|
||||
// 测试空的响应体
|
||||
t.Run("empty response body", func(t *testing.T) {
|
||||
emptyBody := []byte{}
|
||||
|
||||
err := config.ValidateBody(emptyBody)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, uint32(SERVICE_UNAVAILABLE_CODE), config.rejectStruct.RejectCode)
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractJson(t *testing.T) {
|
||||
// 创建测试配置
|
||||
config := &PluginConfig{
|
||||
jsonSchema: nil, // 明确设置为nil,禁用JSON Schema验证
|
||||
enableJsonSchemaValidation: false, // 禁用JSON Schema验证
|
||||
}
|
||||
|
||||
// 测试提取有效的JSON
|
||||
t.Run("extract valid JSON", func(t *testing.T) {
|
||||
content := `Here is the response: {"name": "John", "age": 30} and some other text`
|
||||
|
||||
jsonStr, err := config.ExtractJson(content)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, `{"name": "John", "age": 30}`, jsonStr)
|
||||
})
|
||||
|
||||
// 测试提取嵌套JSON
|
||||
t.Run("extract nested JSON", func(t *testing.T) {
|
||||
content := `Response: {"user": {"name": "John", "profile": {"age": 30, "city": "NYC"}}}`
|
||||
|
||||
jsonStr, err := config.ExtractJson(content)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, `{"user": {"name": "John", "profile": {"age": 30, "city": "NYC"}}}`, jsonStr)
|
||||
})
|
||||
|
||||
// 测试没有JSON的内容
|
||||
t.Run("no JSON in content", func(t *testing.T) {
|
||||
content := `This is just plain text without any JSON content`
|
||||
|
||||
_, err := config.ExtractJson(content)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "cannot find json in the response body")
|
||||
})
|
||||
|
||||
// 测试只有开始括号的内容
|
||||
t.Run("only opening brace", func(t *testing.T) {
|
||||
content := `Here is the start: { but no closing brace`
|
||||
|
||||
_, err := config.ExtractJson(content)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "cannot find json in the response body")
|
||||
})
|
||||
|
||||
// 测试只有结束括号的内容
|
||||
t.Run("only closing brace", func(t *testing.T) {
|
||||
content := `Here is the end: } but no opening brace`
|
||||
|
||||
_, err := config.ExtractJson(content)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "cannot find json in the response body")
|
||||
})
|
||||
|
||||
// 测试无效的JSON格式
|
||||
t.Run("invalid JSON format", func(t *testing.T) {
|
||||
content := `Here is invalid JSON: {"name": "John", "age": 30,}`
|
||||
|
||||
_, err := config.ExtractJson(content)
|
||||
require.Error(t, err)
|
||||
// ExtractJson会提取到{"name": "John", "age": 30,},但json.Unmarshal会失败
|
||||
// 因为JSON格式无效(末尾有多余的逗号)
|
||||
require.Contains(t, err.Error(), "invalid character '}' looking for beginning of object key string")
|
||||
})
|
||||
|
||||
// 测试多个JSON对象(应该提取第一个完整的)
|
||||
t.Run("multiple JSON objects", func(t *testing.T) {
|
||||
content := `First: {"name": "John"} Second: {"age": 30}`
|
||||
|
||||
_, err := config.ExtractJson(content)
|
||||
require.Error(t, err)
|
||||
// ExtractJson会提取到{"name": "John"} Second: {"age": 30}
|
||||
// 这不是有效的JSON,因为"Second: {"age": 30}"不是有效的JSON语法
|
||||
require.Contains(t, err.Error(), "invalid character 'S' after top-level value")
|
||||
})
|
||||
}
|
||||
@@ -5,8 +5,8 @@ go 1.24.1
|
||||
toolchain go1.24.3
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
|
||||
github.com/higress-group/wasm-go v1.0.1-0.20250628101008-bea7da01a545
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/prometheus/client_model v0.6.2
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
@@ -21,3 +21,5 @@ require (
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
google.golang.org/protobuf v1.36.6 // indirect
|
||||
)
|
||||
|
||||
require github.com/tidwall/sjson v1.2.5 // indirect
|
||||
|
||||
@@ -4,10 +4,10 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.1-0.20250628101008-bea7da01a545 h1:zPXEonKCAeLvXI1IpwGpIeVSvLY5AZ9h9uTJnOuiA3Q=
|
||||
github.com/higress-group/wasm-go v1.0.1-0.20250628101008-bea7da01a545/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
@@ -18,6 +18,7 @@ github.com/prometheus/common v0.64.0 h1:pdZeA+g617P7oGv1CzdTzyeShxAGrTBsolKNOLQP
|
||||
github.com/prometheus/common v0.64.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
@@ -27,6 +28,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=
|
||||
|
||||
@@ -5,11 +5,19 @@ go 1.24.1
|
||||
toolchain go1.24.4
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
|
||||
github.com/higress-group/wasm-go v1.0.0
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.7.2 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
|
||||
@@ -2,14 +2,16 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw=
|
||||
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
|
||||
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
@@ -22,5 +24,7 @@ github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -18,9 +18,9 @@ func main() {}
|
||||
func init() {
|
||||
wrapper.SetCtx(
|
||||
"ai-prompt-decorator",
|
||||
wrapper.ParseConfigBy(parseConfig),
|
||||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
||||
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
||||
wrapper.ParseConfig(parseConfig),
|
||||
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
|
||||
wrapper.ProcessRequestBody(onHttpRequestBody),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -34,11 +34,11 @@ type AIPromptDecoratorConfig struct {
|
||||
Append []Message `json:"append"`
|
||||
}
|
||||
|
||||
func parseConfig(jsonConfig gjson.Result, config *AIPromptDecoratorConfig, log log.Log) error {
|
||||
func parseConfig(jsonConfig gjson.Result, config *AIPromptDecoratorConfig) error {
|
||||
return json.Unmarshal([]byte(jsonConfig.Raw), config)
|
||||
}
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIPromptDecoratorConfig, log log.Log) types.Action {
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIPromptDecoratorConfig) types.Action {
|
||||
ctx.DisableReroute()
|
||||
proxywasm.RemoveHttpRequestHeader("content-length")
|
||||
return types.ActionContinue
|
||||
@@ -70,7 +70,7 @@ func decorateGeographicPrompt(entry *Message) (*Message, error) {
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AIPromptDecoratorConfig, body []byte, log log.Log) types.Action {
|
||||
func onHttpRequestBody(ctx wrapper.HttpContext, config AIPromptDecoratorConfig, body []byte) types.Action {
|
||||
messageJson := `{"messages":[]}`
|
||||
|
||||
for _, entry := range config.Prepend {
|
||||
|
||||
511
plugins/wasm-go/extensions/ai-prompt-decorator/main_test.go
Normal file
511
plugins/wasm-go/extensions/ai-prompt-decorator/main_test.go
Normal file
@@ -0,0 +1,511 @@
|
||||
// Copyright (c) 2024 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// 测试配置:基础装饰器配置
|
||||
var basicConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"prepend": []map[string]interface{}{
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant from ${geo-country}.",
|
||||
},
|
||||
},
|
||||
"append": []map[string]interface{}{
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Please provide context about ${geo-city}.",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:只有前置消息的配置
|
||||
var prependOnlyConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"prepend": []map[string]interface{}{
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are located in ${geo-province}, ${geo-country}.",
|
||||
},
|
||||
},
|
||||
"append": []map[string]interface{}{}, // 显式定义空的append字段
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:空配置
|
||||
var emptyConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"prepend": []map[string]interface{}{},
|
||||
"append": []map[string]interface{}{},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
// 测试基础装饰器配置解析
|
||||
t.Run("basic decorator config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
decoratorConfig := config.(*AIPromptDecoratorConfig)
|
||||
require.NotNil(t, decoratorConfig.Prepend)
|
||||
require.NotNil(t, decoratorConfig.Append)
|
||||
require.Len(t, decoratorConfig.Prepend, 1)
|
||||
require.Len(t, decoratorConfig.Append, 1)
|
||||
require.Equal(t, "system", decoratorConfig.Prepend[0].Role)
|
||||
require.Equal(t, "You are a helpful assistant from ${geo-country}.", decoratorConfig.Prepend[0].Content)
|
||||
require.Equal(t, "system", decoratorConfig.Append[0].Role)
|
||||
require.Equal(t, "Please provide context about ${geo-city}.", decoratorConfig.Append[0].Content)
|
||||
})
|
||||
|
||||
// 测试只有前置消息的配置解析
|
||||
t.Run("prepend only config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(prependOnlyConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
decoratorConfig := config.(*AIPromptDecoratorConfig)
|
||||
require.NotNil(t, decoratorConfig.Prepend)
|
||||
require.NotNil(t, decoratorConfig.Append)
|
||||
require.Len(t, decoratorConfig.Prepend, 1)
|
||||
require.Len(t, decoratorConfig.Append, 0)
|
||||
require.Equal(t, "system", decoratorConfig.Prepend[0].Role)
|
||||
require.Equal(t, "You are located in ${geo-province}, ${geo-country}.", decoratorConfig.Prepend[0].Content)
|
||||
})
|
||||
|
||||
// 测试空配置解析
|
||||
t.Run("empty config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(emptyConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
decoratorConfig := config.(*AIPromptDecoratorConfig)
|
||||
require.NotNil(t, decoratorConfig.Prepend)
|
||||
require.NotNil(t, decoratorConfig.Append)
|
||||
require.Len(t, decoratorConfig.Prepend, 0)
|
||||
require.Len(t, decoratorConfig.Append, 0)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpRequestHeaders(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试请求头处理
|
||||
t.Run("request headers processing", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-length", "100"},
|
||||
})
|
||||
|
||||
// 应该返回ActionContinue
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpRequestBody(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试基础消息装饰
|
||||
t.Run("basic message decoration", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
})
|
||||
|
||||
// 设置地理变量属性,供插件使用
|
||||
host.SetProperty([]string{"geo-country"}, []byte("China"))
|
||||
host.SetProperty([]string{"geo-province"}, []byte("Beijing"))
|
||||
host.SetProperty([]string{"geo-city"}, []byte("Beijing"))
|
||||
host.SetProperty([]string{"geo-isp"}, []byte("China Mobile"))
|
||||
|
||||
// 设置请求体,包含消息
|
||||
body := `{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"}
|
||||
]
|
||||
}`
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
|
||||
// 应该返回ActionContinue
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证消息装饰是否成功
|
||||
modifiedBody := host.GetRequestBody()
|
||||
require.NotEmpty(t, modifiedBody)
|
||||
|
||||
// 解析修改后的请求体
|
||||
var modifiedRequest map[string]interface{}
|
||||
err := json.Unmarshal(modifiedBody, &modifiedRequest)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 验证messages字段存在
|
||||
messages, exists := modifiedRequest["messages"].([]interface{})
|
||||
require.True(t, exists, "messages field should exist")
|
||||
require.NotNil(t, messages)
|
||||
|
||||
// 验证消息数量:前置消息(1) + 原始消息(1) + 后置消息(1) = 3
|
||||
require.Len(t, messages, 3, "should have 3 messages: prepend + original + append")
|
||||
|
||||
// 验证第一个消息是前置消息(地理变量已被替换)
|
||||
firstMessage := messages[0].(map[string]interface{})
|
||||
require.Equal(t, "system", firstMessage["role"])
|
||||
require.Equal(t, "You are a helpful assistant from China.", firstMessage["content"])
|
||||
|
||||
// 验证第二个消息是原始用户消息
|
||||
secondMessage := messages[1].(map[string]interface{})
|
||||
require.Equal(t, "user", secondMessage["role"])
|
||||
require.Equal(t, "Hello, how are you?", secondMessage["content"])
|
||||
|
||||
// 验证第三个消息是后置消息(地理变量已被替换)
|
||||
thirdMessage := messages[2].(map[string]interface{})
|
||||
require.Equal(t, "system", thirdMessage["role"])
|
||||
require.Equal(t, "Please provide context about Beijing.", thirdMessage["content"])
|
||||
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
// 测试只有前置消息的装饰
|
||||
t.Run("prepend only decoration", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(prependOnlyConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
})
|
||||
|
||||
// 设置地理变量属性,供插件使用
|
||||
host.SetProperty([]string{"geo-country"}, []byte("China"))
|
||||
host.SetProperty([]string{"geo-province"}, []byte("Shanghai"))
|
||||
host.SetProperty([]string{"geo-city"}, []byte("Shanghai"))
|
||||
host.SetProperty([]string{"geo-isp"}, []byte("China Telecom"))
|
||||
|
||||
// 设置请求体,包含消息
|
||||
body := `{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather like?"}
|
||||
]
|
||||
}`
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
|
||||
// 应该返回ActionContinue
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证消息装饰是否成功
|
||||
modifiedBody := host.GetRequestBody()
|
||||
require.NotEmpty(t, modifiedBody)
|
||||
|
||||
// 解析修改后的请求体
|
||||
var modifiedRequest map[string]interface{}
|
||||
err := json.Unmarshal(modifiedBody, &modifiedRequest)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 验证messages字段存在
|
||||
messages, exists := modifiedRequest["messages"].([]interface{})
|
||||
require.True(t, exists, "messages field should exist")
|
||||
require.NotNil(t, messages)
|
||||
|
||||
// 验证消息数量:前置消息(1) + 原始消息(1) = 2
|
||||
require.Len(t, messages, 2, "should have 2 messages: prepend + original")
|
||||
|
||||
// 验证第一个消息是前置消息(地理变量已被替换)
|
||||
firstMessage := messages[0].(map[string]interface{})
|
||||
require.Equal(t, "system", firstMessage["role"])
|
||||
require.Equal(t, "You are located in Shanghai, China.", firstMessage["content"])
|
||||
|
||||
// 验证第二个消息是原始用户消息
|
||||
secondMessage := messages[1].(map[string]interface{})
|
||||
require.Equal(t, "user", secondMessage["role"])
|
||||
require.Equal(t, "What's the weather like?", secondMessage["content"])
|
||||
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
// 测试空消息的情况
|
||||
t.Run("empty messages", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
})
|
||||
|
||||
// 设置请求体,不包含messages字段
|
||||
body := `{
|
||||
"model": "gpt-3.5-turbo"
|
||||
}`
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
|
||||
// 应该返回ActionContinue
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
// 测试多个消息的装饰
|
||||
t.Run("multiple messages decoration", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
})
|
||||
|
||||
// 设置地理变量属性,供插件使用
|
||||
host.SetProperty([]string{"geo-country"}, []byte("USA"))
|
||||
host.SetProperty([]string{"geo-province"}, []byte("California"))
|
||||
host.SetProperty([]string{"geo-city"}, []byte("San Francisco"))
|
||||
host.SetProperty([]string{"geo-isp"}, []byte("Comcast"))
|
||||
|
||||
// 设置请求体,包含多个消息
|
||||
body := `{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"}
|
||||
]
|
||||
}`
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
|
||||
// 应该返回ActionContinue
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证消息装饰是否成功
|
||||
modifiedBody := host.GetRequestBody()
|
||||
require.NotEmpty(t, modifiedBody)
|
||||
|
||||
// 解析修改后的请求体
|
||||
var modifiedRequest map[string]interface{}
|
||||
err := json.Unmarshal(modifiedBody, &modifiedRequest)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 验证messages字段存在
|
||||
messages, exists := modifiedRequest["messages"].([]interface{})
|
||||
require.True(t, exists, "messages field should exist")
|
||||
require.NotNil(t, messages)
|
||||
|
||||
// 验证消息数量:前置消息(1) + 原始消息(3) + 后置消息(1) = 5
|
||||
require.Len(t, messages, 5, "should have 5 messages: prepend + original(3) + append")
|
||||
|
||||
// 验证第一个消息是前置消息(地理变量已被替换)
|
||||
firstMessage := messages[0].(map[string]interface{})
|
||||
require.Equal(t, "system", firstMessage["role"])
|
||||
require.Equal(t, "You are a helpful assistant from USA.", firstMessage["content"])
|
||||
|
||||
// 验证原始消息保持顺序
|
||||
originalMessages := messages[1:4]
|
||||
require.Equal(t, "system", originalMessages[0].(map[string]interface{})["role"])
|
||||
require.Equal(t, "You are a helpful assistant", originalMessages[0].(map[string]interface{})["content"])
|
||||
require.Equal(t, "user", originalMessages[1].(map[string]interface{})["role"])
|
||||
require.Equal(t, "Hello", originalMessages[1].(map[string]interface{})["content"])
|
||||
require.Equal(t, "assistant", originalMessages[2].(map[string]interface{})["role"])
|
||||
require.Equal(t, "Hi there!", originalMessages[2].(map[string]interface{})["content"])
|
||||
|
||||
// 验证最后一个消息是后置消息(地理变量已被替换)
|
||||
lastMessage := messages[4].(map[string]interface{})
|
||||
require.Equal(t, "system", lastMessage["role"])
|
||||
require.Equal(t, "Please provide context about San Francisco.", lastMessage["content"])
|
||||
|
||||
host.CompleteHttp()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestStructs(t *testing.T) {
|
||||
// 测试Message结构体
|
||||
t.Run("Message struct", func(t *testing.T) {
|
||||
message := Message{
|
||||
Role: "system",
|
||||
Content: "You are a helpful assistant from ${geo-country}.",
|
||||
}
|
||||
require.Equal(t, "system", message.Role)
|
||||
require.Equal(t, "You are a helpful assistant from ${geo-country}.", message.Content)
|
||||
})
|
||||
|
||||
// 测试AIPromptDecoratorConfig结构体
|
||||
t.Run("AIPromptDecoratorConfig struct", func(t *testing.T) {
|
||||
config := &AIPromptDecoratorConfig{
|
||||
Prepend: []Message{
|
||||
{Role: "system", Content: "Prepend message"},
|
||||
},
|
||||
Append: []Message{
|
||||
{Role: "system", Content: "Append message"},
|
||||
},
|
||||
}
|
||||
require.NotNil(t, config.Prepend)
|
||||
require.NotNil(t, config.Append)
|
||||
require.Len(t, config.Prepend, 1)
|
||||
require.Len(t, config.Append, 1)
|
||||
require.Equal(t, "Prepend message", config.Prepend[0].Content)
|
||||
require.Equal(t, "Append message", config.Append[0].Content)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGeographicVariableReplacement(t *testing.T) {
|
||||
// 测试地理变量替换逻辑
|
||||
t.Run("geographic variable replacement", func(t *testing.T) {
|
||||
config := &AIPromptDecoratorConfig{
|
||||
Prepend: []Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "Location: ${geo-country}/${geo-province}/${geo-city}, ISP: ${geo-isp}",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 验证地理变量在内容中的存在
|
||||
content := config.Prepend[0].Content
|
||||
require.Contains(t, content, "${geo-country}")
|
||||
require.Contains(t, content, "${geo-province}")
|
||||
require.Contains(t, content, "${geo-city}")
|
||||
require.Contains(t, content, "${geo-isp}")
|
||||
|
||||
// 测试变量替换逻辑
|
||||
geoVariables := []string{"geo-country", "geo-province", "geo-city", "geo-isp"}
|
||||
for _, geo := range geoVariables {
|
||||
require.Contains(t, content, fmt.Sprintf("${%s}", geo))
|
||||
}
|
||||
})
|
||||
|
||||
// 测试混合内容的地理变量
|
||||
t.Run("mixed content geographic variables", func(t *testing.T) {
|
||||
config := &AIPromptDecoratorConfig{
|
||||
Append: []Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "User from ${geo-country} with ISP ${geo-isp}. Context: ${geo-province}, ${geo-city}",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
content := config.Append[0].Content
|
||||
require.Contains(t, content, "${geo-country}")
|
||||
require.Contains(t, content, "${geo-isp}")
|
||||
require.Contains(t, content, "${geo-province}")
|
||||
require.Contains(t, content, "${geo-city}")
|
||||
})
|
||||
}
|
||||
|
||||
func TestEdgeCases(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试空前置和后置消息
|
||||
t.Run("empty prepend and append", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(emptyConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
body := `{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Test message"}
|
||||
]
|
||||
}`
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
|
||||
// 应该返回ActionContinue
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
// 测试无效JSON请求体
|
||||
t.Run("invalid JSON body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
})
|
||||
|
||||
// 设置无效的请求体
|
||||
body := `{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
// Missing closing brace
|
||||
`
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
|
||||
// 应该返回ActionContinue
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
host.CompleteHttp()
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -5,14 +5,20 @@ go 1.24.1
|
||||
toolchain go1.24.4
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
|
||||
github.com/higress-group/wasm-go v1.0.0
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.7.2 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/resp v0.1.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -2,14 +2,17 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw=
|
||||
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
|
||||
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
@@ -19,5 +22,9 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
424
plugins/wasm-go/extensions/ai-prompt-template/main_test.go
Normal file
424
plugins/wasm-go/extensions/ai-prompt-template/main_test.go
Normal file
@@ -0,0 +1,424 @@
|
||||
// Copyright (c) 2024 Alibaba Group Holding Ltd.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// 测试配置:基础模板配置
|
||||
var basicConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"templates": []map[string]interface{}{
|
||||
{
|
||||
"name": "greeting",
|
||||
"template": "Hello {{name}}, welcome to {{company}}!",
|
||||
},
|
||||
{
|
||||
"name": "summary",
|
||||
"template": "Here is a summary of {{topic}}: {{content}}",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:单个模板配置
|
||||
var singleTemplateConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"templates": []map[string]interface{}{
|
||||
{
|
||||
"name": "simple",
|
||||
"template": "This is a {{adjective}} {{noun}}.",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:空模板配置
|
||||
var emptyTemplatesConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"templates": []map[string]interface{}{},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:复杂模板配置
|
||||
var complexTemplateConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"templates": []map[string]interface{}{
|
||||
{
|
||||
"name": "email",
|
||||
"template": "Dear {{recipient}},\n\n{{greeting}}\n\n{{body}}\n\nBest regards,\n{{sender}}",
|
||||
},
|
||||
{
|
||||
"name": "report",
|
||||
"template": "Report: {{title}}\nDate: {{date}}\nAuthor: {{author}}\n\n{{content}}\n\nConclusion: {{conclusion}}",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
func TestParseConfig(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
// 测试基础模板配置解析
|
||||
t.Run("basic templates config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
promptConfig := config.(*AIPromptTemplateConfig)
|
||||
require.NotNil(t, promptConfig.templates)
|
||||
require.Len(t, promptConfig.templates, 2)
|
||||
// 由于gjson.Get("template").Raw返回JSON原始值,包含引号
|
||||
require.Equal(t, "\"Hello {{name}}, welcome to {{company}}!\"", promptConfig.templates["greeting"])
|
||||
require.Equal(t, "\"Here is a summary of {{topic}}: {{content}}\"", promptConfig.templates["summary"])
|
||||
})
|
||||
|
||||
// 测试单个模板配置解析
|
||||
t.Run("single template config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(singleTemplateConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
promptConfig := config.(*AIPromptTemplateConfig)
|
||||
require.NotNil(t, promptConfig.templates)
|
||||
require.Len(t, promptConfig.templates, 1)
|
||||
// 由于gjson.Get("template").Raw返回JSON原始值,包含引号
|
||||
require.Equal(t, "\"This is a {{adjective}} {{noun}}.\"", promptConfig.templates["simple"])
|
||||
})
|
||||
|
||||
// 测试空模板配置解析
|
||||
t.Run("empty templates config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(emptyTemplatesConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
promptConfig := config.(*AIPromptTemplateConfig)
|
||||
require.NotNil(t, promptConfig.templates)
|
||||
require.Len(t, promptConfig.templates, 0)
|
||||
})
|
||||
|
||||
// 测试复杂模板配置解析
|
||||
t.Run("complex templates config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(complexTemplateConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
promptConfig := config.(*AIPromptTemplateConfig)
|
||||
require.NotNil(t, promptConfig.templates)
|
||||
require.Len(t, promptConfig.templates, 2)
|
||||
// 由于gjson.Get("template").Raw返回JSON原始值,包含引号和转义字符
|
||||
require.Equal(t, "\"Dear {{recipient}},\\n\\n{{greeting}}\\n\\n{{body}}\\n\\nBest regards,\\n{{sender}}\"", promptConfig.templates["email"])
|
||||
require.Equal(t, "\"Report: {{title}}\\nDate: {{date}}\\nAuthor: {{author}}\\n\\n{{content}}\\n\\nConclusion: {{conclusion}}\"", promptConfig.templates["report"])
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpRequestHeaders(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试启用模板的情况
|
||||
t.Run("template enabled", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头,启用模板
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"template-enable", "true"},
|
||||
{"content-length", "100"},
|
||||
})
|
||||
|
||||
// 应该返回ActionContinue
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
|
||||
// 测试禁用模板的情况
|
||||
t.Run("template disabled", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头,禁用模板
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"template-enable", "false"},
|
||||
{"content-length", "100"},
|
||||
})
|
||||
|
||||
// 应该返回ActionContinue
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
|
||||
// 测试没有template-enable头的情况
|
||||
t.Run("no template-enable header", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头,不包含template-enable
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"content-length", "100"},
|
||||
})
|
||||
|
||||
// 应该返回ActionContinue
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestOnHttpRequestBody(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试基础模板替换
|
||||
t.Run("basic template replacement", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"template-enable", "true"},
|
||||
})
|
||||
|
||||
// 设置请求体,包含模板和属性
|
||||
body := `{
|
||||
"template": "greeting",
|
||||
"properties": {
|
||||
"name": "Alice",
|
||||
"company": "TechCorp"
|
||||
}
|
||||
}`
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
|
||||
// 应该返回ActionContinue
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
// 测试复杂模板替换
|
||||
t.Run("complex template replacement", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(complexTemplateConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"template-enable", "true"},
|
||||
})
|
||||
|
||||
// 设置请求体,包含复杂模板和属性
|
||||
body := `{
|
||||
"template": "email",
|
||||
"properties": {
|
||||
"recipient": "John Doe",
|
||||
"greeting": "I hope this email finds you well",
|
||||
"body": "Please find attached the quarterly report",
|
||||
"sender": "Jane Smith"
|
||||
}
|
||||
}`
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
|
||||
// 应该返回ActionContinue
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
// 测试没有模板的情况
|
||||
t.Run("no template in body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"template-enable", "true"},
|
||||
})
|
||||
|
||||
// 设置请求体,不包含模板
|
||||
body := `{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
|
||||
// 应该返回ActionContinue
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
// 测试没有属性的情况
|
||||
t.Run("no properties in body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"template-enable", "true"},
|
||||
})
|
||||
|
||||
// 设置请求体,包含模板但不包含属性
|
||||
body := `{
|
||||
"template": "greeting"
|
||||
}`
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
|
||||
// 应该返回ActionContinue
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
host.CompleteHttp()
|
||||
})
|
||||
|
||||
// 测试部分属性替换
|
||||
t.Run("partial properties replacement", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"template-enable", "true"},
|
||||
})
|
||||
|
||||
// 设置请求体,只包含部分属性
|
||||
body := `{
|
||||
"template": "greeting",
|
||||
"properties": {
|
||||
"name": "Bob"
|
||||
}
|
||||
}`
|
||||
action := host.CallOnHttpRequestBody([]byte(body))
|
||||
|
||||
// 应该返回ActionContinue
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
host.CompleteHttp()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestStructs(t *testing.T) {
|
||||
// 测试AIPromptTemplateConfig结构体
|
||||
t.Run("AIPromptTemplateConfig struct", func(t *testing.T) {
|
||||
config := &AIPromptTemplateConfig{
|
||||
templates: map[string]string{
|
||||
"test": "This is a {{test}} template",
|
||||
},
|
||||
}
|
||||
require.NotNil(t, config.templates)
|
||||
require.Len(t, config.templates, 1)
|
||||
require.Equal(t, "This is a {{test}} template", config.templates["test"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestTemplateReplacementLogic(t *testing.T) {
|
||||
// 测试模板变量替换逻辑
|
||||
t.Run("template variable replacement", func(t *testing.T) {
|
||||
config := &AIPromptTemplateConfig{
|
||||
templates: map[string]string{
|
||||
"greeting": "Hello {{name}}, welcome to {{company}}!",
|
||||
},
|
||||
}
|
||||
|
||||
// 模拟模板替换逻辑
|
||||
template := config.templates["greeting"]
|
||||
require.Equal(t, "Hello {{name}}, welcome to {{company}}!", template)
|
||||
|
||||
// 测试变量替换
|
||||
properties := map[string]string{
|
||||
"name": "Alice",
|
||||
"company": "TechCorp",
|
||||
}
|
||||
|
||||
for key, value := range properties {
|
||||
template = strings.ReplaceAll(template, fmt.Sprintf("{{%s}}", key), value)
|
||||
}
|
||||
|
||||
require.Equal(t, "Hello Alice, welcome to TechCorp!", template)
|
||||
})
|
||||
|
||||
// 测试嵌套变量替换
|
||||
t.Run("nested variable replacement", func(t *testing.T) {
|
||||
config := &AIPromptTemplateConfig{
|
||||
templates: map[string]string{
|
||||
"nested": "{{greeting}} {{name}}, {{message}}",
|
||||
},
|
||||
}
|
||||
|
||||
template := config.templates["nested"]
|
||||
require.Equal(t, "{{greeting}} {{name}}, {{message}}", template)
|
||||
|
||||
// 测试嵌套替换
|
||||
properties := map[string]string{
|
||||
"greeting": "Hello",
|
||||
"name": "World",
|
||||
"message": "welcome!",
|
||||
}
|
||||
|
||||
for key, value := range properties {
|
||||
template = strings.ReplaceAll(template, fmt.Sprintf("{{%s}}", key), value)
|
||||
}
|
||||
|
||||
require.Equal(t, "Hello World, welcome!", template)
|
||||
})
|
||||
}
|
||||
@@ -16,4 +16,3 @@
|
||||
!*/
|
||||
|
||||
/out
|
||||
/test
|
||||
|
||||
@@ -9,10 +9,21 @@ description: AI 代理插件配置参考
|
||||
`AI 代理`插件实现了基于 OpenAI API 契约的 AI 代理功能。目前支持 OpenAI、Azure OpenAI、月之暗面(Moonshot)和通义千问等 AI
|
||||
服务提供商。
|
||||
|
||||
> **注意:**
|
||||
**🚀 自动协议兼容 (Auto Protocol Compatibility)**
|
||||
|
||||
插件现在支持**自动协议检测**,无需配置即可同时兼容 OpenAI 和 Claude 两种协议格式:
|
||||
|
||||
- **OpenAI 协议**: 请求路径 `/v1/chat/completions`,使用标准的 OpenAI Messages API 格式
|
||||
- **Claude 协议**: 请求路径 `/v1/messages`,使用 Anthropic Claude Messages API 格式
|
||||
- **智能转换**: 自动检测请求协议,如果目标供应商不原生支持该协议,则自动进行协议转换
|
||||
- **零配置**: 用户无需设置 `protocol` 字段,插件自动处理
|
||||
|
||||
> **协议支持说明:**
|
||||
|
||||
> 请求路径后缀匹配 `/v1/chat/completions` 时,对应文生文场景,会用 OpenAI 的文生文协议解析请求 Body,再转换为对应 LLM 厂商的文生文协议
|
||||
|
||||
> 请求路径后缀匹配 `/v1/messages` 时,对应 Claude 文生文场景,会自动检测供应商能力:如果支持原生 Claude 协议则直接转发,否则先转换为 OpenAI 协议再转发给供应商
|
||||
|
||||
> 请求路径后缀匹配 `/v1/embeddings` 时,对应文本向量场景,会用 OpenAI 的文本向量协议解析请求 Body,再转换为对应 LLM 厂商的文本向量协议
|
||||
|
||||
## 运行属性
|
||||
@@ -158,6 +169,14 @@ DeepSeek 所对应的 `type` 为 `deepseek`。它并无特有的配置字段。
|
||||
|
||||
Groq 所对应的 `type` 为 `groq`。它并无特有的配置字段。
|
||||
|
||||
#### Grok
|
||||
|
||||
Grok 所对应的 `type` 为 `grok`。它并无特有的配置字段。
|
||||
|
||||
#### OpenRouter
|
||||
|
||||
OpenRouter 所对应的 `type` 为 `openrouter`。它并无特有的配置字段。
|
||||
|
||||
#### 文心一言(Baidu)
|
||||
|
||||
文心一言所对应的 `type` 为 `baidu`。它并无特有的配置字段。
|
||||
@@ -231,10 +250,11 @@ Cloudflare Workers AI 所对应的 `type` 为 `cloudflare`。它特有的配置
|
||||
|
||||
Gemini 所对应的 `type` 为 `gemini`。它特有的配置字段如下:
|
||||
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| --------------------- | ------------- | -------- | -------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| `geminiSafetySetting` | map of string | 非必填 | - | Gemini AI 内容过滤和安全级别设定。参考[Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) |
|
||||
| `apiVersion` | string | 非必填 | `v1beta` | 用于指定 API 的版本, 可选择 `v1` 或 `v1beta` 。 版本差异请参考[API versions explained](https://ai.google.dev/gemini-api/docs/api-versions)。 |
|
||||
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|
||||
| ---------------------- | ------------- | -------- | -------- | ------------------------------------------------------------ |
|
||||
| `geminiSafetySetting` | map of string | 非必填 | - | Gemini AI 内容过滤和安全级别设定。参考[Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) |
|
||||
| `apiVersion` | string | 非必填 | `v1beta` | 用于指定 API 的版本, 可选择 `v1` 或 `v1beta` 。 版本差异请参考[API versions explained](https://ai.google.dev/gemini-api/docs/api-versions)。 |
|
||||
| `geminiThinkingBudget` | number | 非必填 | - | gemini2.5系列的参数,0是不开启思考,-1动态调整,具体参数指可参考官网 |
|
||||
|
||||
#### DeepL
|
||||
|
||||
@@ -862,19 +882,167 @@ provider:
|
||||
}
|
||||
```
|
||||
|
||||
### 使用 OpenAI 协议代理 Claude 服务
|
||||
### 使用 OpenAI 协议代理 Grok 服务
|
||||
|
||||
**配置信息**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: claude
|
||||
type: grok
|
||||
apiTokens:
|
||||
- 'YOUR_GROK_API_TOKEN'
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that can answer questions and help with tasks."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is 101*3?"
|
||||
}
|
||||
],
|
||||
"model": "grok-4"
|
||||
}
|
||||
```
|
||||
|
||||
**响应示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "a3d1008e-4544-40d4-d075-11527e794e4a",
|
||||
"object": "chat.completion",
|
||||
"created": 1752854522,
|
||||
"model": "grok-4",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "101 multiplied by 3 is 303.",
|
||||
"refusal": null
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 32,
|
||||
"completion_tokens": 9,
|
||||
"total_tokens": 135,
|
||||
"prompt_tokens_details": {
|
||||
"text_tokens": 32,
|
||||
"audio_tokens": 0,
|
||||
"image_tokens": 0,
|
||||
"cached_tokens": 6
|
||||
},
|
||||
"completion_tokens_details": {
|
||||
"reasoning_tokens": 94,
|
||||
"audio_tokens": 0,
|
||||
"accepted_prediction_tokens": 0,
|
||||
"rejected_prediction_tokens": 0
|
||||
},
|
||||
"num_sources_used": 0
|
||||
},
|
||||
"system_fingerprint": "fp_3a7881249c"
|
||||
}
|
||||
```
|
||||
|
||||
### 使用 OpenAI 协议代理 OpenRouter 服务
|
||||
|
||||
**配置信息**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: openrouter
|
||||
apiTokens:
|
||||
- 'YOUR_OPENROUTER_API_TOKEN'
|
||||
modelMapping:
|
||||
'gpt-4': 'openai/gpt-4-turbo-preview'
|
||||
'gpt-3.5-turbo': 'openai/gpt-3.5-turbo'
|
||||
'claude-3': 'anthropic/claude-3-opus'
|
||||
'*': 'openai/gpt-3.5-turbo'
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好,你是谁?"
|
||||
}
|
||||
],
|
||||
"temperature": 0.7
|
||||
}
|
||||
```
|
||||
|
||||
**响应示例**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "gen-1234567890abcdef",
|
||||
"object": "chat.completion",
|
||||
"created": 1699123456,
|
||||
"model": "openai/gpt-4-turbo-preview",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "你好!我是一个AI助手,通过OpenRouter平台提供服务。我可以帮助回答问题、协助创作、进行对话等。有什么我可以帮助你的吗?"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 12,
|
||||
"completion_tokens": 46,
|
||||
"total_tokens": 58
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 使用自动协议兼容功能
|
||||
|
||||
插件现在支持自动协议检测,可以同时处理 OpenAI 和 Claude 两种协议格式的请求。
|
||||
|
||||
**配置信息**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: claude # 原生支持 Claude 协议的供应商
|
||||
apiTokens:
|
||||
- 'YOUR_CLAUDE_API_TOKEN'
|
||||
version: '2023-06-01'
|
||||
```
|
||||
|
||||
**请求示例**
|
||||
**OpenAI 协议请求示例**
|
||||
|
||||
URL: `http://your-domain/v1/chat/completions`
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "claude-3-opus-20240229",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好,你是谁?"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Claude 协议请求示例**
|
||||
|
||||
URL: `http://your-domain/v1/messages`
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -891,6 +1059,8 @@ provider:
|
||||
|
||||
**响应示例**
|
||||
|
||||
两种协议格式的请求都会返回相应格式的响应:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg_01Jt3GzyjuzymnxmZERJguLK",
|
||||
@@ -915,6 +1085,39 @@ provider:
|
||||
}
|
||||
```
|
||||
|
||||
### 使用智能协议转换
|
||||
|
||||
当目标供应商不原生支持 Claude 协议时,插件会自动进行协议转换:
|
||||
|
||||
**配置信息**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: qwen # 不原生支持 Claude 协议,会自动转换
|
||||
apiTokens:
|
||||
- 'YOUR_QWEN_API_TOKEN'
|
||||
modelMapping:
|
||||
'claude-3-opus-20240229': 'qwen-max'
|
||||
'*': 'qwen-turbo'
|
||||
```
|
||||
|
||||
**Claude 协议请求**
|
||||
|
||||
URL: `http://your-domain/v1/messages` (自动转换为 OpenAI 协议调用供应商)
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "claude-3-opus-20240229",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好,你是谁?"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### 使用 OpenAI 协议代理混元服务
|
||||
|
||||
**配置信息**
|
||||
|
||||
@@ -8,10 +8,21 @@ description: Reference for configuring the AI Proxy plugin
|
||||
|
||||
The `AI Proxy` plugin implements AI proxy functionality based on the OpenAI API contract. It currently supports AI service providers such as OpenAI, Azure OpenAI, Moonshot, and Qwen.
|
||||
|
||||
> **Note:**
|
||||
**🚀 Auto Protocol Compatibility**
|
||||
|
||||
The plugin now supports **automatic protocol detection**, allowing seamless compatibility with both OpenAI and Claude protocol formats without configuration:
|
||||
|
||||
- **OpenAI Protocol**: Request path `/v1/chat/completions`, using standard OpenAI Messages API format
|
||||
- **Claude Protocol**: Request path `/v1/messages`, using Anthropic Claude Messages API format
|
||||
- **Intelligent Conversion**: Automatically detects request protocol and performs conversion if the target provider doesn't natively support it
|
||||
- **Zero Configuration**: No need to set `protocol` field, the plugin handles everything automatically
|
||||
|
||||
> **Protocol Support:**
|
||||
|
||||
> When the request path suffix matches `/v1/chat/completions`, it corresponds to text-to-text scenarios. The request body will be parsed using OpenAI's text-to-text protocol and then converted to the corresponding LLM vendor's text-to-text protocol.
|
||||
|
||||
> When the request path suffix matches `/v1/messages`, it corresponds to Claude text-to-text scenarios. The plugin automatically detects provider capabilities: if native Claude protocol is supported, requests are forwarded directly; otherwise, they are converted to OpenAI protocol first.
|
||||
|
||||
> When the request path suffix matches `/v1/embeddings`, it corresponds to text vector scenarios. The request body will be parsed using OpenAI's text vector protocol and then converted to the corresponding LLM vendor's text vector protocol.
|
||||
|
||||
## Execution Properties
|
||||
@@ -35,7 +46,7 @@ Plugin execution priority: `100`
|
||||
| `apiTokens` | array of string | Optional | - | Tokens used for authentication when accessing AI services. If multiple tokens are configured, the plugin randomly selects one for each request. Some service providers only support configuring a single token. |
|
||||
| `timeout` | number | Optional | - | Timeout for accessing AI services, in milliseconds. The default value is 120000, which equals 2 minutes. Only used when retrieving context data. Won't affect the request forwarded to the LLM upstream. |
|
||||
| `modelMapping` | map of string | Optional | - | Mapping table for AI models, used to map model names in requests to names supported by the service provider.<br/>1. Supports prefix matching. For example, "gpt-3-\*" matches all model names starting with “gpt-3-”;<br/>2. Supports using "\*" as a key for a general fallback mapping;<br/>3. If the mapped target name is an empty string "", the original model name is preserved. |
|
||||
| `protocol` | string | Optional | - | API contract provided by the plugin. Currently supports the following values: openai (default, uses OpenAI's interface contract), original (uses the raw interface contract of the target service provider) |
|
||||
| `protocol` | string | Optional | - | API contract provided by the plugin. Currently supports the following values: openai (default, uses OpenAI's interface contract), original (uses the raw interface contract of the target service provider). **Note: Auto protocol detection is now supported, no need to configure this field to support both OpenAI and Claude protocols** |
|
||||
| `context` | object | Optional | - | Configuration for AI conversation context information |
|
||||
| `customSettings` | array of customSetting | Optional | - | Specifies overrides or fills parameters for AI requests |
|
||||
| `subPath` | string | Optional | - | If subPath is configured, the prefix will be removed from the request path before further processing. |
|
||||
@@ -129,6 +140,14 @@ For DeepSeek, the corresponding `type` is `deepseek`. It has no unique configura
|
||||
|
||||
For Groq, the corresponding `type` is `groq`. It has no unique configuration fields.
|
||||
|
||||
#### Grok
|
||||
|
||||
For Grok, the corresponding `type` is `grok`. It has no unique configuration fields.
|
||||
|
||||
#### OpenRouter
|
||||
|
||||
For OpenRouter, the corresponding `type` is `openrouter`. It has no unique configuration fields.
|
||||
|
||||
#### ERNIE Bot
|
||||
|
||||
For ERNIE Bot, the corresponding `type` is `baidu`. It has no unique configuration fields.
|
||||
@@ -200,6 +219,8 @@ For Gemini, the corresponding `type` is `gemini`. Its unique configuration field
|
||||
| Name | Data Type | Filling Requirements | Default Value | Description |
|
||||
|---------------------|----------|----------------------|---------------|---------------------------------------------------------------------------------------------------------|
|
||||
| `geminiSafetySetting` | map of string | Optional | - | Gemini AI content filtering and safety level settings. Refer to [Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings). |
|
||||
| `apiVersion` | string | 非必填 | `v1beta` | To specify the version of the API, you can choose either 'v1' or 'v1beta'. Version differences refer to https://ai.google.dev/gemini-api/docs/api-versions |
|
||||
| `geminiThinkingBudget` | number | 非必填 | - | The parameters of the gemini2.5 series: 0 indicates no thinking mode, -1 represents dynamic adjustment. For specific parameter references, please refer to the official website |
|
||||
|
||||
### DeepL
|
||||
|
||||
@@ -807,19 +828,167 @@ provider:
|
||||
}
|
||||
```
|
||||
|
||||
### Using OpenAI Protocol Proxy for Claude Service
|
||||
### Using OpenAI Protocol Proxy for Grok Service
|
||||
|
||||
**Configuration Information**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: claude
|
||||
type: grok
|
||||
apiTokens:
|
||||
- "YOUR_GROK_API_TOKEN"
|
||||
```
|
||||
|
||||
**Example Request**
|
||||
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant that can answer questions and help with tasks."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is 101*3?"
|
||||
}
|
||||
],
|
||||
"model": "grok-4"
|
||||
}
|
||||
```
|
||||
|
||||
**Example Response**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "a3d1008e-4544-40d4-d075-11527e794e4a",
|
||||
"object": "chat.completion",
|
||||
"created": 1752854522,
|
||||
"model": "grok-4",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "101 multiplied by 3 is 303.",
|
||||
"refusal": null
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 32,
|
||||
"completion_tokens": 9,
|
||||
"total_tokens": 135,
|
||||
"prompt_tokens_details": {
|
||||
"text_tokens": 32,
|
||||
"audio_tokens": 0,
|
||||
"image_tokens": 0,
|
||||
"cached_tokens": 6
|
||||
},
|
||||
"completion_tokens_details": {
|
||||
"reasoning_tokens": 94,
|
||||
"audio_tokens": 0,
|
||||
"accepted_prediction_tokens": 0,
|
||||
"rejected_prediction_tokens": 0
|
||||
},
|
||||
"num_sources_used": 0
|
||||
},
|
||||
"system_fingerprint": "fp_3a7881249c"
|
||||
}
|
||||
```
|
||||
|
||||
### Using OpenAI Protocol Proxy for OpenRouter Service
|
||||
|
||||
**Configuration Information**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: openrouter
|
||||
apiTokens:
|
||||
- 'YOUR_OPENROUTER_API_TOKEN'
|
||||
modelMapping:
|
||||
'gpt-4': 'openai/gpt-4-turbo-preview'
|
||||
'gpt-3.5-turbo': 'openai/gpt-3.5-turbo'
|
||||
'claude-3': 'anthropic/claude-3-opus'
|
||||
'*': 'openai/gpt-3.5-turbo'
|
||||
```
|
||||
|
||||
**Example Request**
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, who are you?"
|
||||
}
|
||||
],
|
||||
"temperature": 0.7
|
||||
}
|
||||
```
|
||||
|
||||
**Example Response**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "gen-1234567890abcdef",
|
||||
"object": "chat.completion",
|
||||
"created": 1699123456,
|
||||
"model": "openai/gpt-4-turbo-preview",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello! I am an AI assistant powered by OpenRouter. I can help answer questions, assist with creative tasks, engage in conversations, and more. How can I assist you today?"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 12,
|
||||
"completion_tokens": 35,
|
||||
"total_tokens": 47
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Using Auto Protocol Compatibility
|
||||
|
||||
The plugin now supports automatic protocol detection, capable of handling both OpenAI and Claude protocol format requests simultaneously.
|
||||
|
||||
**Configuration Information**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: claude # Provider with native Claude protocol support
|
||||
apiTokens:
|
||||
- "YOUR_CLAUDE_API_TOKEN"
|
||||
version: "2023-06-01"
|
||||
```
|
||||
|
||||
**Example Request**
|
||||
**OpenAI Protocol Request Example**
|
||||
|
||||
URL: `http://your-domain/v1/chat/completions`
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "claude-3-opus-20240229",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, who are you?"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**Claude Protocol Request Example**
|
||||
|
||||
URL: `http://your-domain/v1/messages`
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -836,6 +1005,8 @@ provider:
|
||||
|
||||
**Example Response**
|
||||
|
||||
Both protocol formats will return responses in their respective formats:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "msg_01Jt3GzyjuzymnxmZERJguLK",
|
||||
@@ -860,6 +1031,39 @@ provider:
|
||||
}
|
||||
```
|
||||
|
||||
### Using Intelligent Protocol Conversion
|
||||
|
||||
When the target provider doesn't natively support Claude protocol, the plugin automatically performs protocol conversion:
|
||||
|
||||
**Configuration Information**
|
||||
|
||||
```yaml
|
||||
provider:
|
||||
type: qwen # Doesn't natively support Claude protocol, auto-conversion applied
|
||||
apiTokens:
|
||||
- "YOUR_QWEN_API_TOKEN"
|
||||
modelMapping:
|
||||
'claude-3-opus-20240229': 'qwen-max'
|
||||
'*': 'qwen-turbo'
|
||||
```
|
||||
|
||||
**Claude Protocol Request**
|
||||
|
||||
URL: `http://your-domain/v1/messages` (automatically converted to OpenAI protocol for provider)
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "claude-3-opus-20240229",
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, who are you?"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Using OpenAI Protocol Proxy for Hunyuan Service
|
||||
|
||||
**Configuration Information**
|
||||
|
||||
3734
plugins/wasm-go/extensions/ai-proxy/claude-message-api.yaml
Normal file
3734
plugins/wasm-go/extensions/ai-proxy/claude-message-api.yaml
Normal file
File diff suppressed because it is too large
Load Diff
@@ -7,12 +7,14 @@ go 1.24.1
|
||||
toolchain go1.24.4
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
|
||||
github.com/higress-group/wasm-go v1.0.1
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
)
|
||||
|
||||
require github.com/tetratelabs/wazero v1.7.2 // indirect
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/uuid v1.6.0
|
||||
|
||||
@@ -2,16 +2,16 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw=
|
||||
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
|
||||
github.com/higress-group/wasm-go v1.0.1 h1:T1m++qTEANp8+jwE0sxltwtaTKmrHCkLOp1m9N+YeqY=
|
||||
github.com/higress-group/wasm-go v1.0.1/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
|
||||
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
|
||||
@@ -6,6 +6,7 @@ package main
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/config"
|
||||
@@ -31,6 +32,11 @@ const (
|
||||
ctxOriginalAuth = "original_auth"
|
||||
)
|
||||
|
||||
type pair[K, V any] struct {
|
||||
key K
|
||||
value V
|
||||
}
|
||||
|
||||
var (
|
||||
headersCtxKeyMapping = map[string]string{
|
||||
util.HeaderAuthority: ctxOriginalHost,
|
||||
@@ -42,6 +48,44 @@ var (
|
||||
util.HeaderPath: util.HeaderOriginalPath,
|
||||
util.HeaderAuthorization: util.HeaderOriginalAuth,
|
||||
}
|
||||
pathSuffixToApiName = []pair[string, provider.ApiName]{
|
||||
// OpenAI style
|
||||
{provider.PathOpenAIChatCompletions, provider.ApiNameChatCompletion},
|
||||
{provider.PathOpenAICompletions, provider.ApiNameCompletion},
|
||||
{provider.PathOpenAIEmbeddings, provider.ApiNameEmbeddings},
|
||||
{provider.PathOpenAIAudioSpeech, provider.ApiNameAudioSpeech},
|
||||
{provider.PathOpenAIImageGeneration, provider.ApiNameImageGeneration},
|
||||
{provider.PathOpenAIImageVariation, provider.ApiNameImageVariation},
|
||||
{provider.PathOpenAIImageEdit, provider.ApiNameImageEdit},
|
||||
{provider.PathOpenAIBatches, provider.ApiNameBatches},
|
||||
{provider.PathOpenAIFiles, provider.ApiNameFiles},
|
||||
{provider.PathOpenAIModels, provider.ApiNameModels},
|
||||
{provider.PathOpenAIFineTuningJobs, provider.ApiNameFineTuningJobs},
|
||||
{provider.PathOpenAIResponses, provider.ApiNameResponses},
|
||||
// Anthropic style
|
||||
{provider.PathAnthropicMessages, provider.ApiNameAnthropicMessages},
|
||||
{provider.PathAnthropicComplete, provider.ApiNameAnthropicComplete},
|
||||
// Cohere style
|
||||
{provider.PathCohereV1Rerank, provider.ApiNameCohereV1Rerank},
|
||||
}
|
||||
pathPatternToApiName = []pair[*regexp.Regexp, provider.ApiName]{
|
||||
// OpenAI style
|
||||
{util.RegRetrieveBatchPath, provider.ApiNameRetrieveBatch},
|
||||
{util.RegCancelBatchPath, provider.ApiNameCancelBatch},
|
||||
{util.RegRetrieveFilePath, provider.ApiNameRetrieveFile},
|
||||
{util.RegRetrieveFileContentPath, provider.ApiNameRetrieveFileContent},
|
||||
{util.RegRetrieveFineTuningJobPath, provider.ApiNameRetrieveFineTuningJob},
|
||||
{util.RegRetrieveFineTuningJobEventsPath, provider.ApiNameFineTuningJobEvents},
|
||||
{util.RegRetrieveFineTuningJobCheckpointsPath, provider.ApiNameFineTuningJobCheckpoints},
|
||||
{util.RegCancelFineTuningJobPath, provider.ApiNameCancelFineTuningJob},
|
||||
{util.RegResumeFineTuningJobPath, provider.ApiNameResumeFineTuningJob},
|
||||
{util.RegPauseFineTuningJobPath, provider.ApiNamePauseFineTuningJob},
|
||||
{util.RegFineTuningCheckpointPermissionPath, provider.ApiNameFineTuningCheckpointPermissions},
|
||||
{util.RegDeleteFineTuningCheckpointPermissionPath, provider.ApiNameDeleteFineTuningCheckpointPermission},
|
||||
// Gemini style
|
||||
{util.RegGeminiGenerateContent, provider.ApiNameGeminiGenerateContent},
|
||||
{util.RegGeminiStreamGenerateContent, provider.ApiNameGeminiStreamGenerateContent},
|
||||
}
|
||||
)
|
||||
|
||||
func main() {}
|
||||
@@ -97,6 +141,9 @@ func initContext(ctx wrapper.HttpContext) {
|
||||
value, _ := proxywasm.GetHttpRequestHeader(header)
|
||||
ctx.SetContext(ctxKey, value)
|
||||
}
|
||||
for _, originHeader := range headerToOriginalHeaderMapping {
|
||||
proxywasm.RemoveHttpRequestHeader(originHeader)
|
||||
}
|
||||
}
|
||||
|
||||
func saveContextsToHeaders(ctx wrapper.HttpContext) {
|
||||
@@ -127,6 +174,9 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
||||
|
||||
log.Debugf("[onHttpRequestHeader] provider=%s", activeProvider.GetProviderType())
|
||||
|
||||
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
|
||||
ctx.DisableReroute()
|
||||
|
||||
initContext(ctx)
|
||||
|
||||
rawPath := ctx.Path()
|
||||
@@ -144,6 +194,23 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-detect protocol based on request path and handle conversion if needed
|
||||
// If request is Claude format (/v1/messages) but provider doesn't support it natively,
|
||||
// convert to OpenAI format (/v1/chat/completions)
|
||||
if apiName == provider.ApiNameAnthropicMessages && !providerConfig.IsSupportedAPI(provider.ApiNameAnthropicMessages) {
|
||||
// Provider doesn't support Claude protocol natively, convert to OpenAI format
|
||||
newPath := strings.Replace(path.Path, provider.PathAnthropicMessages, provider.PathOpenAIChatCompletions, 1)
|
||||
_ = proxywasm.ReplaceHttpRequestHeader(":path", newPath)
|
||||
// Update apiName to match the new path
|
||||
apiName = provider.ApiNameChatCompletion
|
||||
// Mark that we need to convert response back to Claude format
|
||||
ctx.SetContext("needClaudeResponseConversion", true)
|
||||
log.Debugf("[Auto Protocol] Claude request detected, provider doesn't support natively, converted path from %s to %s, apiName: %s", path.Path, newPath, apiName)
|
||||
} else if apiName == provider.ApiNameAnthropicMessages {
|
||||
// Provider supports Claude protocol natively, no conversion needed
|
||||
log.Debugf("[Auto Protocol] Claude request detected, provider supports natively, keeping original path: %s, apiName: %s", path.Path, apiName)
|
||||
}
|
||||
|
||||
if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !strings.Contains(contentType, util.MimeTypeApplicationJson) {
|
||||
ctx.DontReadRequestBody()
|
||||
log.Debugf("[onHttpRequestHeader] unsupported content type: %s, will not process the request body", contentType)
|
||||
@@ -156,8 +223,6 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
|
||||
}
|
||||
|
||||
ctx.SetContext(provider.CtxKeyApiName, apiName)
|
||||
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
|
||||
ctx.DisableReroute()
|
||||
|
||||
// Always remove the Accept-Encoding header to prevent the LLM from sending compressed responses,
|
||||
// allowing plugins to inspect or modify the response correctly
|
||||
@@ -275,17 +340,20 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
|
||||
}
|
||||
util.ReplaceResponseHeaders(headers)
|
||||
|
||||
checkStream(ctx)
|
||||
_, needHandleBody := activeProvider.(provider.TransformResponseBodyHandler)
|
||||
var needHandleStreamingBody bool
|
||||
_, needHandleStreamingBody = activeProvider.(provider.StreamingResponseBodyHandler)
|
||||
if !needHandleStreamingBody {
|
||||
_, needHandleStreamingBody = activeProvider.(provider.StreamingEventHandler)
|
||||
}
|
||||
if !needHandleBody && !needHandleStreamingBody {
|
||||
|
||||
// Check if we need to read body for Claude response conversion
|
||||
needClaudeConversion, _ := ctx.GetContext("needClaudeResponseConversion").(bool)
|
||||
|
||||
if !needHandleBody && !needHandleStreamingBody && !needClaudeConversion {
|
||||
ctx.DontReadResponseBody()
|
||||
} else if !needHandleStreamingBody {
|
||||
ctx.BufferResponseBody()
|
||||
} else {
|
||||
checkStream(ctx)
|
||||
}
|
||||
|
||||
return types.ActionContinue
|
||||
@@ -306,7 +374,12 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
||||
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||
modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk)
|
||||
if err == nil && modifiedChunk != nil {
|
||||
return modifiedChunk
|
||||
// Convert to Claude format if needed
|
||||
claudeChunk, convertErr := convertStreamingResponseToClaude(ctx, modifiedChunk)
|
||||
if convertErr != nil {
|
||||
return modifiedChunk
|
||||
}
|
||||
return claudeChunk
|
||||
}
|
||||
return chunk
|
||||
}
|
||||
@@ -315,8 +388,8 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
||||
events := provider.ExtractStreamingEvents(ctx, chunk)
|
||||
log.Debugf("[onStreamingResponseBody] %d events received", len(events))
|
||||
if len(events) == 0 {
|
||||
// No events are extracted, return the original chunk
|
||||
return chunk
|
||||
// No events are extracted, return empty bytes slice
|
||||
return []byte("")
|
||||
}
|
||||
var responseBuilder strings.Builder
|
||||
for _, event := range events {
|
||||
@@ -332,7 +405,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
||||
log.Errorf("[onStreamingResponseBody] failed to process streaming event: %v\n%s", err, chunk)
|
||||
return chunk
|
||||
}
|
||||
if outputEvents == nil || len(outputEvents) == 0 {
|
||||
if len(outputEvents) == 0 {
|
||||
responseBuilder.WriteString(event.ToHttpString())
|
||||
} else {
|
||||
for _, outputEvent := range outputEvents {
|
||||
@@ -340,9 +413,40 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
|
||||
}
|
||||
}
|
||||
}
|
||||
return []byte(responseBuilder.String())
|
||||
|
||||
result := []byte(responseBuilder.String())
|
||||
|
||||
// Convert to Claude format if needed
|
||||
claudeChunk, convertErr := convertStreamingResponseToClaude(ctx, result)
|
||||
if convertErr != nil {
|
||||
return result
|
||||
}
|
||||
return claudeChunk
|
||||
}
|
||||
return chunk
|
||||
|
||||
// If provider doesn't implement any streaming handlers but we need Claude conversion
|
||||
// First extract complete events from the chunk
|
||||
events := provider.ExtractStreamingEvents(ctx, chunk)
|
||||
log.Debugf("[onStreamingResponseBody] %d events received (no handler)", len(events))
|
||||
if len(events) == 0 {
|
||||
// No events are extracted, return empty bytes slice
|
||||
return []byte("")
|
||||
}
|
||||
|
||||
// Build response from extracted events (without handler processing)
|
||||
var responseBuilder strings.Builder
|
||||
for _, event := range events {
|
||||
responseBuilder.WriteString(event.ToHttpString())
|
||||
}
|
||||
|
||||
result := []byte(responseBuilder.String())
|
||||
|
||||
// Convert to Claude format if needed
|
||||
claudeChunk, convertErr := convertStreamingResponseToClaude(ctx, result)
|
||||
if convertErr != nil {
|
||||
return result
|
||||
}
|
||||
return claudeChunk
|
||||
}
|
||||
|
||||
func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, body []byte) types.Action {
|
||||
@@ -355,20 +459,82 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
|
||||
|
||||
log.Debugf("[onHttpResponseBody] provider=%s", activeProvider.GetProviderType())
|
||||
|
||||
var finalBody []byte
|
||||
|
||||
if handler, ok := activeProvider.(provider.TransformResponseBodyHandler); ok {
|
||||
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||
body, err := handler.TransformResponseBody(ctx, apiName, body)
|
||||
transformedBody, err := handler.TransformResponseBody(ctx, apiName, body)
|
||||
if err != nil {
|
||||
_ = util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
|
||||
return types.ActionContinue
|
||||
}
|
||||
if err = provider.ReplaceResponseBody(body); err != nil {
|
||||
_ = util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err))
|
||||
}
|
||||
finalBody = transformedBody
|
||||
} else {
|
||||
finalBody = body
|
||||
}
|
||||
|
||||
// Convert to Claude format if needed (applies to both branches)
|
||||
convertedBody, err := convertResponseBodyToClaude(ctx, finalBody)
|
||||
if err != nil {
|
||||
_ = util.ErrorHandler("ai-proxy.convert_resp_to_claude_failed", err)
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
if err = provider.ReplaceResponseBody(convertedBody); err != nil {
|
||||
_ = util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err))
|
||||
}
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
// Helper function to check if Claude response conversion is needed
|
||||
func needsClaudeResponseConversion(ctx wrapper.HttpContext) bool {
|
||||
needClaudeConversion, _ := ctx.GetContext("needClaudeResponseConversion").(bool)
|
||||
return needClaudeConversion
|
||||
}
|
||||
|
||||
// Helper function to convert OpenAI streaming response to Claude format
|
||||
func convertStreamingResponseToClaude(ctx wrapper.HttpContext, data []byte) ([]byte, error) {
|
||||
if !needsClaudeResponseConversion(ctx) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// Get or create converter instance from context to maintain state
|
||||
const claudeConverterKey = "claudeConverter"
|
||||
var converter *provider.ClaudeToOpenAIConverter
|
||||
|
||||
if converterData := ctx.GetContext(claudeConverterKey); converterData != nil {
|
||||
if c, ok := converterData.(*provider.ClaudeToOpenAIConverter); ok {
|
||||
converter = c
|
||||
}
|
||||
}
|
||||
|
||||
if converter == nil {
|
||||
converter = &provider.ClaudeToOpenAIConverter{}
|
||||
ctx.SetContext(claudeConverterKey, converter)
|
||||
}
|
||||
|
||||
claudeChunk, err := converter.ConvertOpenAIStreamResponseToClaude(ctx, data)
|
||||
if err != nil {
|
||||
log.Errorf("failed to convert streaming response to claude format: %v", err)
|
||||
return data, err
|
||||
}
|
||||
return claudeChunk, nil
|
||||
}
|
||||
|
||||
// Helper function to convert OpenAI response body to Claude format
|
||||
func convertResponseBodyToClaude(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
||||
if !needsClaudeResponseConversion(ctx) {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
converter := &provider.ClaudeToOpenAIConverter{}
|
||||
convertedBody, err := converter.ConvertOpenAIResponseToClaude(ctx, body)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("failed to convert response to claude format: %v", err)
|
||||
}
|
||||
return convertedBody, nil
|
||||
}
|
||||
|
||||
func normalizeOpenAiRequestBody(body []byte) []byte {
|
||||
var err error
|
||||
// Default setting include_usage.
|
||||
@@ -393,99 +559,19 @@ func checkStream(ctx wrapper.HttpContext) {
|
||||
}
|
||||
|
||||
func getApiName(path string) provider.ApiName {
|
||||
// openai style
|
||||
if strings.HasSuffix(path, provider.PathOpenAIChatCompletions) {
|
||||
return provider.ApiNameChatCompletion
|
||||
}
|
||||
if strings.HasSuffix(path, provider.PathOpenAICompletions) {
|
||||
return provider.ApiNameCompletion
|
||||
}
|
||||
if strings.HasSuffix(path, provider.PathOpenAIEmbeddings) {
|
||||
return provider.ApiNameEmbeddings
|
||||
}
|
||||
if strings.HasSuffix(path, provider.PathOpenAIAudioSpeech) {
|
||||
return provider.ApiNameAudioSpeech
|
||||
}
|
||||
if strings.HasSuffix(path, provider.PathOpenAIImageGeneration) {
|
||||
return provider.ApiNameImageGeneration
|
||||
}
|
||||
if strings.HasSuffix(path, provider.PathOpenAIImageVariation) {
|
||||
return provider.ApiNameImageVariation
|
||||
}
|
||||
if strings.HasSuffix(path, provider.PathOpenAIImageEdit) {
|
||||
return provider.ApiNameImageEdit
|
||||
}
|
||||
if strings.HasSuffix(path, provider.PathOpenAIBatches) {
|
||||
return provider.ApiNameBatches
|
||||
}
|
||||
if util.RegRetrieveBatchPath.MatchString(path) {
|
||||
return provider.ApiNameRetrieveBatch
|
||||
}
|
||||
if util.RegCancelBatchPath.MatchString(path) {
|
||||
return provider.ApiNameCancelBatch
|
||||
}
|
||||
if strings.HasSuffix(path, provider.PathOpenAIFiles) {
|
||||
return provider.ApiNameFiles
|
||||
}
|
||||
if util.RegRetrieveFilePath.MatchString(path) {
|
||||
return provider.ApiNameRetrieveFile
|
||||
}
|
||||
if util.RegRetrieveFileContentPath.MatchString(path) {
|
||||
return provider.ApiNameRetrieveFileContent
|
||||
}
|
||||
if strings.HasSuffix(path, provider.PathOpenAIModels) {
|
||||
return provider.ApiNameModels
|
||||
}
|
||||
if strings.HasSuffix(path, provider.PathOpenAIFineTuningJobs) {
|
||||
return provider.ApiNameFineTuningJobs
|
||||
}
|
||||
if util.RegRetrieveFineTuningJobPath.MatchString(path) {
|
||||
return provider.ApiNameRetrieveFineTuningJob
|
||||
}
|
||||
if util.RegRetrieveFineTuningJobEventsPath.MatchString(path) {
|
||||
return provider.ApiNameFineTuningJobEvents
|
||||
}
|
||||
if util.RegRetrieveFineTuningJobCheckpointsPath.MatchString(path) {
|
||||
return provider.ApiNameFineTuningJobCheckpoints
|
||||
}
|
||||
if util.RegCancelFineTuningJobPath.MatchString(path) {
|
||||
return provider.ApiNameCancelFineTuningJob
|
||||
}
|
||||
if util.RegResumeFineTuningJobPath.MatchString(path) {
|
||||
return provider.ApiNameResumeFineTuningJob
|
||||
}
|
||||
if util.RegPauseFineTuningJobPath.MatchString(path) {
|
||||
return provider.ApiNamePauseFineTuningJob
|
||||
}
|
||||
if util.RegFineTuningCheckpointPermissionPath.MatchString(path) {
|
||||
return provider.ApiNameFineTuningCheckpointPermissions
|
||||
}
|
||||
if util.RegDeleteFineTuningCheckpointPermissionPath.MatchString(path) {
|
||||
return provider.ApiNameDeleteFineTuningCheckpointPermission
|
||||
}
|
||||
if strings.HasSuffix(path, provider.PathOpenAIResponses) {
|
||||
return provider.ApiNameResponses
|
||||
// Check path suffix matches first
|
||||
for _, p := range pathSuffixToApiName {
|
||||
if strings.HasSuffix(path, p.key) {
|
||||
return p.value
|
||||
}
|
||||
}
|
||||
|
||||
// Anthropic
|
||||
if strings.HasSuffix(path, provider.PathAnthropicMessages) {
|
||||
return provider.ApiNameAnthropicMessages
|
||||
}
|
||||
if strings.HasSuffix(path, provider.PathAnthropicComplete) {
|
||||
return provider.ApiNameAnthropicComplete
|
||||
// Check path pattern matches
|
||||
for _, p := range pathPatternToApiName {
|
||||
if p.key.MatchString(path) {
|
||||
return p.value
|
||||
}
|
||||
}
|
||||
|
||||
// Gemini
|
||||
if util.RegGeminiGenerateContent.MatchString(path) {
|
||||
return provider.ApiNameGeminiGenerateContent
|
||||
}
|
||||
if util.RegGeminiStreamGenerateContent.MatchString(path) {
|
||||
return provider.ApiNameGeminiStreamGenerateContent
|
||||
}
|
||||
|
||||
// cohere style
|
||||
if strings.HasSuffix(path, provider.PathCohereV1Rerank) {
|
||||
return provider.ApiNameCohereV1Rerank
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
105
plugins/wasm-go/extensions/ai-proxy/main_test.go
Normal file
105
plugins/wasm-go/extensions/ai-proxy/main_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/test"
|
||||
)
|
||||
|
||||
func Test_getApiName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
want provider.ApiName
|
||||
}{
|
||||
// OpenAI style
|
||||
{"openai chat completions", "/v1/chat/completions", provider.ApiNameChatCompletion},
|
||||
{"openai completions", "/v1/completions", provider.ApiNameCompletion},
|
||||
{"openai embeddings", "/v1/embeddings", provider.ApiNameEmbeddings},
|
||||
{"openai audio speech", "/v1/audio/speech", provider.ApiNameAudioSpeech},
|
||||
{"openai image generation", "/v1/images/generations", provider.ApiNameImageGeneration},
|
||||
{"openai image variation", "/v1/images/variations", provider.ApiNameImageVariation},
|
||||
{"openai image edit", "/v1/images/edits", provider.ApiNameImageEdit},
|
||||
{"openai batches", "/v1/batches", provider.ApiNameBatches},
|
||||
{"openai retrieve batch", "/v1/batches/batchid", provider.ApiNameRetrieveBatch},
|
||||
{"openai cancel batch", "/v1/batches/batchid/cancel", provider.ApiNameCancelBatch},
|
||||
{"openai files", "/v1/files", provider.ApiNameFiles},
|
||||
{"openai retrieve file", "/v1/files/fileid", provider.ApiNameRetrieveFile},
|
||||
{"openai retrieve file content", "/v1/files/fileid/content", provider.ApiNameRetrieveFileContent},
|
||||
{"openai models", "/v1/models", provider.ApiNameModels},
|
||||
{"openai fine tuning jobs", "/v1/fine_tuning/jobs", provider.ApiNameFineTuningJobs},
|
||||
{"openai retrieve fine tuning job", "/v1/fine_tuning/jobs/jobid", provider.ApiNameRetrieveFineTuningJob},
|
||||
{"openai fine tuning job events", "/v1/fine_tuning/jobs/jobid/events", provider.ApiNameFineTuningJobEvents},
|
||||
{"openai fine tuning job checkpoints", "/v1/fine_tuning/jobs/jobid/checkpoints", provider.ApiNameFineTuningJobCheckpoints},
|
||||
{"openai cancel fine tuning job", "/v1/fine_tuning/jobs/jobid/cancel", provider.ApiNameCancelFineTuningJob},
|
||||
{"openai resume fine tuning job", "/v1/fine_tuning/jobs/jobid/resume", provider.ApiNameResumeFineTuningJob},
|
||||
{"openai pause fine tuning job", "/v1/fine_tuning/jobs/jobid/pause", provider.ApiNamePauseFineTuningJob},
|
||||
{"openai fine tuning checkpoint permissions", "/v1/fine_tuning/checkpoints/checkpointid/permissions", provider.ApiNameFineTuningCheckpointPermissions},
|
||||
{"openai delete fine tuning checkpoint permission", "/v1/fine_tuning/checkpoints/checkpointid/permissions/permissionid", provider.ApiNameDeleteFineTuningCheckpointPermission},
|
||||
{"openai responses", "/v1/responses", provider.ApiNameResponses},
|
||||
// Anthropic
|
||||
{"anthropic messages", "/v1/messages", provider.ApiNameAnthropicMessages},
|
||||
{"anthropic complete", "/v1/complete", provider.ApiNameAnthropicComplete},
|
||||
// Gemini
|
||||
{"gemini generate content", "/v1beta/models/gemini-1.0-pro:generateContent", provider.ApiNameGeminiGenerateContent},
|
||||
{"gemini stream generate content", "/v1beta/models/gemini-1.0-pro:streamGenerateContent", provider.ApiNameGeminiStreamGenerateContent},
|
||||
// Cohere
|
||||
{"cohere rerank", "/v1/rerank", provider.ApiNameCohereV1Rerank},
|
||||
// Unknown
|
||||
{"unknown", "/v1/unknown", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := getApiName(tt.path)
|
||||
if got != tt.want {
|
||||
t.Errorf("getApiName(%q) = %v, want %v", tt.path, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAi360(t *testing.T) {
|
||||
test.RunAi360ParseConfigTests(t)
|
||||
test.RunAi360OnHttpRequestHeadersTests(t)
|
||||
test.RunAi360OnHttpRequestBodyTests(t)
|
||||
test.RunAi360OnHttpResponseHeadersTests(t)
|
||||
test.RunAi360OnHttpResponseBodyTests(t)
|
||||
test.RunAi360OnStreamingResponseBodyTests(t)
|
||||
}
|
||||
|
||||
func TestOpenAI(t *testing.T) {
|
||||
test.RunOpenAIParseConfigTests(t)
|
||||
test.RunOpenAIOnHttpRequestHeadersTests(t)
|
||||
test.RunOpenAIOnHttpRequestBodyTests(t)
|
||||
test.RunOpenAIOnHttpResponseHeadersTests(t)
|
||||
test.RunOpenAIOnHttpResponseBodyTests(t)
|
||||
test.RunOpenAIOnStreamingResponseBodyTests(t)
|
||||
}
|
||||
|
||||
func TestQwen(t *testing.T) {
|
||||
test.RunQwenParseConfigTests(t)
|
||||
test.RunQwenOnHttpRequestHeadersTests(t)
|
||||
test.RunQwenOnHttpRequestBodyTests(t)
|
||||
test.RunQwenOnHttpResponseHeadersTests(t)
|
||||
test.RunQwenOnHttpResponseBodyTests(t)
|
||||
test.RunQwenOnStreamingResponseBodyTests(t)
|
||||
}
|
||||
|
||||
func TestGemini(t *testing.T) {
|
||||
test.RunGeminiParseConfigTests(t)
|
||||
test.RunGeminiOnHttpRequestHeadersTests(t)
|
||||
test.RunGeminiOnHttpRequestBodyTests(t)
|
||||
test.RunGeminiOnHttpResponseHeadersTests(t)
|
||||
test.RunGeminiOnHttpResponseBodyTests(t)
|
||||
test.RunGeminiOnStreamingResponseBodyTests(t)
|
||||
test.RunGeminiGetImageURLTests(t)
|
||||
}
|
||||
|
||||
func TestAzure(t *testing.T) {
|
||||
test.RunAzureParseConfigTests(t)
|
||||
test.RunAzureOnHttpRequestHeadersTests(t)
|
||||
test.RunAzureOnHttpRequestBodyTests(t)
|
||||
test.RunAzureOnHttpResponseHeadersTests(t)
|
||||
test.RunAzureOnHttpResponseBodyTests(t)
|
||||
}
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
type azureServiceUrlType int
|
||||
|
||||
const (
|
||||
pathAzurePrefix = "/openai"
|
||||
pathAzureModelPlaceholder = "{model}"
|
||||
@@ -21,6 +23,12 @@ const (
|
||||
queryAzureApiVersion = "api-version"
|
||||
)
|
||||
|
||||
const (
|
||||
azureServiceUrlTypeFull azureServiceUrlType = iota
|
||||
azureServiceUrlTypeWithDeployment
|
||||
azureServiceUrlTypeDomainOnly
|
||||
)
|
||||
|
||||
var (
|
||||
azureModelIrrelevantApis = map[ApiName]bool{
|
||||
ApiNameModels: true,
|
||||
@@ -31,7 +39,7 @@ var (
|
||||
ApiNameRetrieveFile: true,
|
||||
ApiNameRetrieveFileContent: true,
|
||||
}
|
||||
regexAzureModelWithPath = regexp.MustCompile("/openai/deployments/(.+?)(/.*|$)")
|
||||
regexAzureModelWithPath = regexp.MustCompile("/openai/deployments/(.+?)(?:/(.*)|$)")
|
||||
)
|
||||
|
||||
// azureProvider is the provider for Azure OpenAI service.
|
||||
@@ -82,32 +90,44 @@ func (m *azureProviderInitializer) CreateProvider(config ProviderConfig) (Provid
|
||||
|
||||
modelSubMatch := regexAzureModelWithPath.FindStringSubmatch(serviceUrl.Path)
|
||||
defaultModel := "placeholder"
|
||||
var serviceUrlType azureServiceUrlType
|
||||
if modelSubMatch != nil {
|
||||
defaultModel = modelSubMatch[1]
|
||||
if modelSubMatch[2] != "" {
|
||||
serviceUrlType = azureServiceUrlTypeFull
|
||||
} else {
|
||||
serviceUrlType = azureServiceUrlTypeWithDeployment
|
||||
}
|
||||
log.Debugf("azureProvider: found default model from serviceUrl: %s", defaultModel)
|
||||
} else {
|
||||
serviceUrlType = azureServiceUrlTypeDomainOnly
|
||||
log.Debugf("azureProvider: no default model found in serviceUrl")
|
||||
}
|
||||
log.Debugf("azureProvider: serviceUrlType=%d", serviceUrlType)
|
||||
|
||||
config.setDefaultCapabilities(m.DefaultCapabilities())
|
||||
apiVersion := serviceUrl.Query().Get(queryAzureApiVersion)
|
||||
log.Debugf("azureProvider: using %s: %s", queryAzureApiVersion, apiVersion)
|
||||
return &azureProvider{
|
||||
config: config,
|
||||
serviceUrl: serviceUrl,
|
||||
apiVersion: apiVersion,
|
||||
defaultModel: defaultModel,
|
||||
contextCache: createContextCache(&config),
|
||||
config: config,
|
||||
serviceUrl: serviceUrl,
|
||||
serviceUrlType: serviceUrlType,
|
||||
serviceUrlFullPath: serviceUrl.Path + "?" + serviceUrl.RawQuery,
|
||||
apiVersion: apiVersion,
|
||||
defaultModel: defaultModel,
|
||||
contextCache: createContextCache(&config),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type azureProvider struct {
|
||||
config ProviderConfig
|
||||
|
||||
contextCache *contextCache
|
||||
serviceUrl *url.URL
|
||||
apiVersion string
|
||||
defaultModel string
|
||||
contextCache *contextCache
|
||||
serviceUrl *url.URL
|
||||
serviceUrlFullPath string
|
||||
serviceUrlType azureServiceUrlType
|
||||
apiVersion string
|
||||
defaultModel string
|
||||
}
|
||||
|
||||
func (m *azureProvider) GetProviderType() string {
|
||||
@@ -152,21 +172,31 @@ func (m *azureProvider) transformRequestPath(ctx wrapper.HttpContext, apiName Ap
|
||||
return originalPath
|
||||
}
|
||||
|
||||
if m.serviceUrlType == azureServiceUrlTypeFull {
|
||||
log.Debugf("azureProvider: use configured path %s", m.serviceUrlFullPath)
|
||||
return m.serviceUrlFullPath
|
||||
}
|
||||
|
||||
log.Debugf("azureProvider: original request path: %s", originalPath)
|
||||
path := util.MapRequestPathByCapability(string(apiName), originalPath, m.config.capabilities)
|
||||
log.Debugf("azureProvider: path: %s", path)
|
||||
if strings.Contains(path, pathAzureModelPlaceholder) {
|
||||
log.Debugf("azureProvider: path contains placeholder: %s", path)
|
||||
model := ctx.GetStringContext(ctxKeyFinalRequestModel, "")
|
||||
log.Debugf("azureProvider: model from context: %s", model)
|
||||
if model == "" {
|
||||
var model string
|
||||
if m.serviceUrlType == azureServiceUrlTypeWithDeployment {
|
||||
model = m.defaultModel
|
||||
log.Debugf("azureProvider: use default model: %s", model)
|
||||
} else {
|
||||
model = ctx.GetStringContext(ctxKeyFinalRequestModel, "")
|
||||
log.Debugf("azureProvider: model from context: %s", model)
|
||||
if model == "" {
|
||||
model = m.defaultModel
|
||||
log.Debugf("azureProvider: use default model: %s", model)
|
||||
}
|
||||
}
|
||||
path = strings.ReplaceAll(path, pathAzureModelPlaceholder, model)
|
||||
log.Debugf("azureProvider: model replaced path: %s", path)
|
||||
}
|
||||
path = fmt.Sprintf("%s?%s=%s", path, queryAzureApiVersion, m.apiVersion)
|
||||
path = path + "?" + m.serviceUrl.RawQuery
|
||||
log.Debugf("azureProvider: final path: %s", path)
|
||||
|
||||
return path
|
||||
|
||||
@@ -19,12 +19,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -99,8 +96,31 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
|
||||
if bedrockEvent.Role != nil {
|
||||
chatChoice.Delta.Role = *bedrockEvent.Role
|
||||
}
|
||||
if bedrockEvent.Start != nil {
|
||||
chatChoice.Delta.Content = nil
|
||||
chatChoice.Delta.ToolCalls = []toolCall{
|
||||
{
|
||||
Id: bedrockEvent.Start.ToolUse.ToolUseID,
|
||||
Type: "function",
|
||||
Function: functionCall{
|
||||
Name: bedrockEvent.Start.ToolUse.Name,
|
||||
Arguments: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
if bedrockEvent.Delta != nil {
|
||||
chatChoice.Delta = &chatMessage{Content: bedrockEvent.Delta.Text}
|
||||
if bedrockEvent.Delta.ToolUse != nil {
|
||||
chatChoice.Delta.ToolCalls = []toolCall{
|
||||
{
|
||||
Type: "function",
|
||||
Function: functionCall{
|
||||
Arguments: bedrockEvent.Delta.ToolUse.Input,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
if bedrockEvent.StopReason != nil {
|
||||
chatChoice.FinishReason = util.Ptr(stopReasonBedrock2OpenAI(*bedrockEvent.StopReason))
|
||||
@@ -591,29 +611,7 @@ func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
|
||||
return b.config.handleRequestBody(b, b.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) {
|
||||
request := &bedrockTextGenRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
|
||||
if len(request.System) > 0 {
|
||||
request.System = append(request.System, systemContentBlock{Text: content})
|
||||
} else {
|
||||
request.System = []systemContentBlock{{Text: content}}
|
||||
}
|
||||
|
||||
requestBytes, err := json.Marshal(request)
|
||||
b.setAuthHeaders(requestBytes, nil)
|
||||
return requestBytes, err
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
||||
if gjson.GetBytes(body, "model").Exists() {
|
||||
rawModel := gjson.GetBytes(body, "model").String()
|
||||
encodedModel := url.QueryEscape(rawModel)
|
||||
body, _ = sjson.SetBytes(body, "model", encodedModel)
|
||||
}
|
||||
switch apiName {
|
||||
case ApiNameChatCompletion:
|
||||
return b.onChatCompletionRequestBody(ctx, body, headers)
|
||||
@@ -651,7 +649,7 @@ func (b *bedrockProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext,
|
||||
return nil, err
|
||||
}
|
||||
headers.Set("Accept", "*/*")
|
||||
util.OverwriteRequestPathHeader(headers, fmt.Sprintf(bedrockInvokeModelPath, request.Model))
|
||||
b.overwriteRequestPathHeader(headers, bedrockInvokeModelPath, request.Model)
|
||||
return b.buildBedrockImageGenerationRequest(request, headers)
|
||||
}
|
||||
|
||||
@@ -675,7 +673,6 @@ func (b *bedrockProvider) buildBedrockImageGenerationRequest(origRequest *imageG
|
||||
Quality: origRequest.Quality,
|
||||
},
|
||||
}
|
||||
util.OverwriteRequestPathHeader(headers, fmt.Sprintf(bedrockInvokeModelPath, origRequest.Model))
|
||||
requestBytes, err := json.Marshal(request)
|
||||
b.setAuthHeaders(requestBytes, headers)
|
||||
return requestBytes, err
|
||||
@@ -714,9 +711,9 @@ func (b *bedrockProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, b
|
||||
streaming := request.Stream
|
||||
headers.Set("Accept", "*/*")
|
||||
if streaming {
|
||||
util.OverwriteRequestPathHeader(headers, fmt.Sprintf(bedrockStreamChatCompletionPath, request.Model))
|
||||
b.overwriteRequestPathHeader(headers, bedrockStreamChatCompletionPath, request.Model)
|
||||
} else {
|
||||
util.OverwriteRequestPathHeader(headers, fmt.Sprintf(bedrockChatCompletionPath, request.Model))
|
||||
b.overwriteRequestPathHeader(headers, bedrockChatCompletionPath, request.Model)
|
||||
}
|
||||
return b.buildBedrockTextGenerationRequest(request, headers)
|
||||
}
|
||||
@@ -726,9 +723,12 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
|
||||
systemMessages := make([]systemContentBlock, 0)
|
||||
|
||||
for _, msg := range origRequest.Messages {
|
||||
if msg.Role == roleSystem {
|
||||
switch msg.Role {
|
||||
case roleSystem:
|
||||
systemMessages = append(systemMessages, systemContentBlock{Text: msg.StringContent()})
|
||||
} else {
|
||||
case roleTool:
|
||||
messages = append(messages, chatToolMessage2BedrockMessage(msg))
|
||||
default:
|
||||
messages = append(messages, chatMessage2BedrockMessage(msg))
|
||||
}
|
||||
}
|
||||
@@ -747,6 +747,36 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
|
||||
},
|
||||
}
|
||||
|
||||
if origRequest.Tools != nil {
|
||||
request.ToolConfig = &bedrockToolConfig{}
|
||||
if origRequest.ToolChoice == nil {
|
||||
request.ToolConfig.ToolChoice.Auto = &struct{}{}
|
||||
} else if choice_type, ok := origRequest.ToolChoice.(string); ok {
|
||||
switch choice_type {
|
||||
case "required":
|
||||
request.ToolConfig.ToolChoice.Any = &struct{}{}
|
||||
case "auto":
|
||||
request.ToolConfig.ToolChoice.Auto = &struct{}{}
|
||||
case "none":
|
||||
request.ToolConfig.ToolChoice.Auto = &struct{}{}
|
||||
}
|
||||
} else if choice, ok := origRequest.ToolChoice.(toolChoice); ok {
|
||||
request.ToolConfig.ToolChoice.Tool = &bedrockToolSpecification{
|
||||
Name: choice.Function.Name,
|
||||
}
|
||||
}
|
||||
request.ToolConfig.Tools = []bedrockTool{}
|
||||
for _, tool := range origRequest.Tools {
|
||||
request.ToolConfig.Tools = append(request.ToolConfig.Tools, bedrockTool{
|
||||
ToolSpec: bedrockToolSpecification{
|
||||
InputSchema: bedrockToolInputSchemaJson{Json: tool.Function.Parameters},
|
||||
Name: tool.Function.Name,
|
||||
Description: tool.Function.Description,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for key, value := range b.config.bedrockAdditionalFields {
|
||||
request.AdditionalModelRequestFields[key] = value
|
||||
}
|
||||
@@ -761,16 +791,29 @@ func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, b
|
||||
if len(bedrockResponse.Output.Message.Content) > 0 {
|
||||
outputContent = bedrockResponse.Output.Message.Content[0].Text
|
||||
}
|
||||
choices := []chatCompletionChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Message: &chatMessage{
|
||||
Role: bedrockResponse.Output.Message.Role,
|
||||
Content: outputContent,
|
||||
},
|
||||
FinishReason: util.Ptr(stopReasonBedrock2OpenAI(bedrockResponse.StopReason)),
|
||||
choice := chatCompletionChoice{
|
||||
Index: 0,
|
||||
Message: &chatMessage{
|
||||
Role: bedrockResponse.Output.Message.Role,
|
||||
Content: outputContent,
|
||||
},
|
||||
FinishReason: util.Ptr(stopReasonBedrock2OpenAI(bedrockResponse.StopReason)),
|
||||
}
|
||||
choice.Message.ToolCalls = []toolCall{}
|
||||
for _, content := range bedrockResponse.Output.Message.Content {
|
||||
if content.ToolUse != nil {
|
||||
args, _ := json.Marshal(content.ToolUse.Input)
|
||||
choice.Message.ToolCalls = append(choice.Message.ToolCalls, toolCall{
|
||||
Id: content.ToolUse.ToolUseId,
|
||||
Type: "function",
|
||||
Function: functionCall{
|
||||
Name: content.ToolUse.Name,
|
||||
Arguments: string(args),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
choices := []chatCompletionChoice{choice}
|
||||
requestId := ctx.GetStringContext(requestIdHeader, "")
|
||||
modelId, _ := url.QueryUnescape(ctx.GetStringContext(ctxKeyFinalRequestModel, ""))
|
||||
return &chatCompletionResponse{
|
||||
@@ -788,6 +831,17 @@ func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, b
|
||||
}
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) overwriteRequestPathHeader(headers http.Header, format, model string) {
|
||||
modelInPath := model
|
||||
// Just in case the model name has already been URL-escaped, we shouldn't escape it again.
|
||||
if !strings.ContainsRune(model, '%') {
|
||||
modelInPath = url.QueryEscape(model)
|
||||
}
|
||||
path := fmt.Sprintf(format, modelInPath)
|
||||
log.Debugf("overwriting bedrock request path: %s", path)
|
||||
util.OverwriteRequestPathHeader(headers, path)
|
||||
}
|
||||
|
||||
func stopReasonBedrock2OpenAI(reason string) string {
|
||||
switch reason {
|
||||
case "end_turn":
|
||||
@@ -796,6 +850,8 @@ func stopReasonBedrock2OpenAI(reason string) string {
|
||||
return finishReasonStop
|
||||
case "max_tokens":
|
||||
return finishReasonLength
|
||||
case "tool_use":
|
||||
return finishReasonToolCall
|
||||
default:
|
||||
return reason
|
||||
}
|
||||
@@ -807,20 +863,48 @@ type bedrockTextGenRequest struct {
|
||||
InferenceConfig bedrockInferenceConfig `json:"inferenceConfig,omitempty"`
|
||||
AdditionalModelRequestFields map[string]interface{} `json:"additionalModelRequestFields,omitempty"`
|
||||
PerformanceConfig PerformanceConfiguration `json:"performanceConfig,omitempty"`
|
||||
ToolConfig *bedrockToolConfig `json:"toolConfig,omitempty"`
|
||||
}
|
||||
|
||||
type bedrockToolConfig struct {
|
||||
Tools []bedrockTool `json:"tools,omitempty"`
|
||||
ToolChoice bedrockToolChoice `json:"toolChoice,omitempty"`
|
||||
}
|
||||
|
||||
type PerformanceConfiguration struct {
|
||||
Latency string `json:"latency,omitempty"`
|
||||
}
|
||||
|
||||
type bedrockTool struct {
|
||||
ToolSpec bedrockToolSpecification `json:"toolSpec,omitempty"`
|
||||
}
|
||||
|
||||
type bedrockToolChoice struct {
|
||||
Any *struct{} `json:"any,omitempty"`
|
||||
Auto *struct{} `json:"auto,omitempty"`
|
||||
Tool *bedrockToolSpecification `json:"tool,omitempty"`
|
||||
}
|
||||
|
||||
type bedrockToolSpecification struct {
|
||||
InputSchema bedrockToolInputSchemaJson `json:"inputSchema,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
type bedrockToolInputSchemaJson struct {
|
||||
Json map[string]interface{} `json:"json,omitempty"`
|
||||
}
|
||||
|
||||
type bedrockMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content []bedrockMessageContent `json:"content"`
|
||||
}
|
||||
|
||||
type bedrockMessageContent struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Image *imageBlock `json:"image,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Image *imageBlock `json:"image,omitempty"`
|
||||
ToolResult *toolResultBlock `json:"toolResult,omitempty"`
|
||||
ToolUse *toolUseBlock `json:"toolUse,omitempty"`
|
||||
}
|
||||
|
||||
type systemContentBlock struct {
|
||||
@@ -836,6 +920,22 @@ type imageSource struct {
|
||||
Bytes string `json:"bytes,omitempty"`
|
||||
}
|
||||
|
||||
type toolResultBlock struct {
|
||||
ToolUseId string `json:"toolUseId"`
|
||||
Content []toolResultContentBlock `json:"content"`
|
||||
Status string `json:"status,omitempty"`
|
||||
}
|
||||
|
||||
type toolResultContentBlock struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type toolUseBlock struct {
|
||||
Input map[string]interface{} `json:"input"`
|
||||
Name string `json:"name"`
|
||||
ToolUseId string `json:"toolUseId"`
|
||||
}
|
||||
|
||||
type bedrockInferenceConfig struct {
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
MaxTokens int `json:"maxTokens,omitempty"`
|
||||
@@ -859,13 +959,19 @@ type converseOutputMemberMessage struct {
|
||||
}
|
||||
|
||||
type message struct {
|
||||
Content []contentBlockMemberText `json:"content"`
|
||||
|
||||
Role string `json:"role"`
|
||||
Content []contentBlock `json:"content"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
type contentBlockMemberText struct {
|
||||
Text string `json:"text"`
|
||||
type contentBlock struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
ToolUse *bedrockToolUse `json:"toolUse,omitempty"`
|
||||
}
|
||||
|
||||
type bedrockToolUse struct {
|
||||
Name string `json:"name"`
|
||||
ToolUseId string `json:"toolUseId"`
|
||||
Input map[string]interface{} `json:"input"`
|
||||
}
|
||||
|
||||
type tokenUsage struct {
|
||||
@@ -876,9 +982,53 @@ type tokenUsage struct {
|
||||
TotalTokens int `json:"totalTokens"`
|
||||
}
|
||||
|
||||
func chatToolMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
|
||||
toolResultContent := &toolResultBlock{}
|
||||
toolResultContent.ToolUseId = chatMessage.ToolCallId
|
||||
if text, ok := chatMessage.Content.(string); ok {
|
||||
toolResultContent.Content = []toolResultContentBlock{
|
||||
{
|
||||
Text: text,
|
||||
},
|
||||
}
|
||||
openaiContent := chatMessage.ParseContent()
|
||||
for _, part := range openaiContent {
|
||||
var content bedrockMessageContent
|
||||
if part.Type == contentTypeText {
|
||||
content.Text = part.Text
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Warnf("only text content is supported, current content is %v", chatMessage.Content)
|
||||
}
|
||||
return bedrockMessage{
|
||||
Role: roleUser,
|
||||
Content: []bedrockMessageContent{
|
||||
{
|
||||
ToolResult: toolResultContent,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func chatMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
|
||||
if chatMessage.IsStringContent() {
|
||||
return bedrockMessage{
|
||||
var result bedrockMessage
|
||||
if len(chatMessage.ToolCalls) > 0 {
|
||||
result = bedrockMessage{
|
||||
Role: chatMessage.Role,
|
||||
Content: []bedrockMessageContent{{}},
|
||||
}
|
||||
params := map[string]interface{}{}
|
||||
json.Unmarshal([]byte(chatMessage.ToolCalls[0].Function.Arguments), ¶ms)
|
||||
result.Content[0].ToolUse = &toolUseBlock{
|
||||
Input: params,
|
||||
Name: chatMessage.ToolCalls[0].Function.Name,
|
||||
ToolUseId: chatMessage.ToolCalls[0].Id,
|
||||
}
|
||||
} else if chatMessage.IsStringContent() {
|
||||
result = bedrockMessage{
|
||||
Role: chatMessage.Role,
|
||||
Content: []bedrockMessageContent{{Text: chatMessage.StringContent()}},
|
||||
}
|
||||
@@ -895,29 +1045,22 @@ func chatMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
|
||||
}
|
||||
contents = append(contents, content)
|
||||
}
|
||||
return bedrockMessage{
|
||||
result = bedrockMessage{
|
||||
Role: chatMessage.Role,
|
||||
Content: contents,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) setAuthHeaders(body []byte, headers http.Header) {
|
||||
t := time.Now().UTC()
|
||||
amzDate := t.Format("20060102T150405Z")
|
||||
dateStamp := t.Format("20060102")
|
||||
path, _ := proxywasm.GetHttpRequestHeader(":path")
|
||||
if headers != nil {
|
||||
path = headers.Get(":path")
|
||||
}
|
||||
path := headers.Get(":path")
|
||||
signature := b.generateSignature(path, amzDate, dateStamp, body)
|
||||
if headers != nil {
|
||||
headers.Set("X-Amz-Date", amzDate)
|
||||
headers.Set("Authorization", fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/aws4_request, SignedHeaders=%s, Signature=%s", b.config.awsAccessKey, dateStamp, b.config.awsRegion, awsService, bedrockSignedHeaders, signature))
|
||||
} else {
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("X-Amz-Date", amzDate)
|
||||
_ = proxywasm.ReplaceHttpRequestHeader("Authorization", fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/aws4_request, SignedHeaders=%s, Signature=%s", b.config.awsAccessKey, dateStamp, b.config.awsRegion, awsService, bedrockSignedHeaders, signature))
|
||||
}
|
||||
headers.Set("X-Amz-Date", amzDate)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/aws4_request, SignedHeaders=%s, Signature=%s", b.config.awsAccessKey, dateStamp, b.config.awsRegion, awsService, bedrockSignedHeaders, signature))
|
||||
}
|
||||
|
||||
func (b *bedrockProvider) generateSignature(path, amzDate, dateStamp string, body []byte) string {
|
||||
|
||||
@@ -36,8 +36,18 @@ type claudeToolChoice struct {
|
||||
}
|
||||
|
||||
type claudeChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
Role string `json:"role"`
|
||||
Content claudeChatMessageContentWr `json:"content"`
|
||||
}
|
||||
|
||||
// claudeChatMessageContentWr wraps the content to handle both string and array formats
|
||||
type claudeChatMessageContentWr struct {
|
||||
// StringValue holds simple text content
|
||||
StringValue string
|
||||
// ArrayValue holds multi-modal content
|
||||
ArrayValue []claudeChatMessageContent
|
||||
// IsString indicates whether this is a simple string or array
|
||||
IsString bool
|
||||
}
|
||||
|
||||
type claudeChatMessageContentSource struct {
|
||||
@@ -49,23 +59,154 @@ type claudeChatMessageContentSource struct {
|
||||
}
|
||||
|
||||
type claudeChatMessageContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Source *claudeChatMessageContentSource `json:"source,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Source *claudeChatMessageContentSource `json:"source,omitempty"`
|
||||
CacheControl map[string]interface{} `json:"cache_control,omitempty"`
|
||||
// Tool use fields
|
||||
Id string `json:"id,omitempty"` // For tool_use
|
||||
Name string `json:"name,omitempty"` // For tool_use
|
||||
Input map[string]interface{} `json:"input,omitempty"` // For tool_use
|
||||
// Tool result fields
|
||||
ToolUseId string `json:"tool_use_id,omitempty"` // For tool_result
|
||||
Content string `json:"content,omitempty"` // For tool_result
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshaling for claudeChatMessageContentWr
|
||||
func (ccw *claudeChatMessageContentWr) UnmarshalJSON(data []byte) error {
|
||||
// Try to unmarshal as string first
|
||||
var stringValue string
|
||||
if err := json.Unmarshal(data, &stringValue); err == nil {
|
||||
ccw.StringValue = stringValue
|
||||
ccw.IsString = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to unmarshal as array of content blocks
|
||||
var arrayValue []claudeChatMessageContent
|
||||
if err := json.Unmarshal(data, &arrayValue); err == nil {
|
||||
ccw.ArrayValue = arrayValue
|
||||
ccw.IsString = false
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("content field must be either a string or an array of content blocks")
|
||||
}
|
||||
|
||||
// MarshalJSON implements custom JSON marshaling for claudeChatMessageContentWr
|
||||
func (ccw claudeChatMessageContentWr) MarshalJSON() ([]byte, error) {
|
||||
if ccw.IsString {
|
||||
return json.Marshal(ccw.StringValue)
|
||||
}
|
||||
return json.Marshal(ccw.ArrayValue)
|
||||
}
|
||||
|
||||
// GetStringValue returns the string representation if it's a string, empty string otherwise
|
||||
func (ccw claudeChatMessageContentWr) GetStringValue() string {
|
||||
if ccw.IsString {
|
||||
return ccw.StringValue
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetArrayValue returns the array representation if it's an array, empty slice otherwise
|
||||
func (ccw claudeChatMessageContentWr) GetArrayValue() []claudeChatMessageContent {
|
||||
if !ccw.IsString {
|
||||
return ccw.ArrayValue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewStringContent creates a new wrapper for string content
|
||||
func NewStringContent(content string) claudeChatMessageContentWr {
|
||||
return claudeChatMessageContentWr{
|
||||
StringValue: content,
|
||||
IsString: true,
|
||||
}
|
||||
}
|
||||
|
||||
// NewArrayContent creates a new wrapper for array content
|
||||
func NewArrayContent(content []claudeChatMessageContent) claudeChatMessageContentWr {
|
||||
return claudeChatMessageContentWr{
|
||||
ArrayValue: content,
|
||||
IsString: false,
|
||||
}
|
||||
}
|
||||
|
||||
// claudeSystemPrompt represents the system field which can be either a string or an array of text blocks
|
||||
type claudeSystemPrompt struct {
|
||||
// Will be set to the string value if system is a simple string
|
||||
StringValue string
|
||||
// Will be set to the array value if system is an array of text blocks
|
||||
ArrayValue []claudeTextGenContent
|
||||
// Indicates which type this represents
|
||||
IsArray bool
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements custom JSON unmarshaling for claudeSystemPrompt
|
||||
func (csp *claudeSystemPrompt) UnmarshalJSON(data []byte) error {
|
||||
// Try to unmarshal as string first
|
||||
var stringValue string
|
||||
if err := json.Unmarshal(data, &stringValue); err == nil {
|
||||
csp.StringValue = stringValue
|
||||
csp.IsArray = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to unmarshal as array of text blocks
|
||||
var arrayValue []claudeTextGenContent
|
||||
if err := json.Unmarshal(data, &arrayValue); err == nil {
|
||||
csp.ArrayValue = arrayValue
|
||||
csp.IsArray = true
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("system field must be either a string or an array of text blocks")
|
||||
}
|
||||
|
||||
// MarshalJSON implements custom JSON marshaling for claudeSystemPrompt
|
||||
func (csp claudeSystemPrompt) MarshalJSON() ([]byte, error) {
|
||||
if csp.IsArray {
|
||||
return json.Marshal(csp.ArrayValue)
|
||||
}
|
||||
return json.Marshal(csp.StringValue)
|
||||
}
|
||||
|
||||
// String returns the string representation of the system prompt
|
||||
func (csp claudeSystemPrompt) String() string {
|
||||
if csp.IsArray {
|
||||
// Concatenate all text blocks
|
||||
var parts []string
|
||||
for _, block := range csp.ArrayValue {
|
||||
if block.Text != "" {
|
||||
parts = append(parts, block.Text)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
return csp.StringValue
|
||||
}
|
||||
|
||||
// claudeThinkingConfig represents the thinking configuration for Claude
|
||||
type claudeThinkingConfig struct {
|
||||
Type string `json:"type"`
|
||||
BudgetTokens int `json:"budget_tokens,omitempty"`
|
||||
}
|
||||
|
||||
type claudeTextGenRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []claudeChatMessage `json:"messages"`
|
||||
System string `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
ToolChoice *claudeToolChoice `json:"tool_choice,omitempty"`
|
||||
Tools []claudeTool `json:"tools,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Messages []claudeChatMessage `json:"messages"`
|
||||
System claudeSystemPrompt `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
ToolChoice *claudeToolChoice `json:"tool_choice,omitempty"`
|
||||
Tools []claudeTool `json:"tools,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Thinking *claudeThinkingConfig `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
type claudeTextGenResponse struct {
|
||||
@@ -81,8 +222,13 @@ type claudeTextGenResponse struct {
|
||||
}
|
||||
|
||||
type claudeTextGenContent struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Id string `json:"id,omitempty"` // For tool_use
|
||||
Name string `json:"name,omitempty"` // For tool_use
|
||||
Input map[string]interface{} `json:"input,omitempty"` // For tool_use
|
||||
Signature string `json:"signature,omitempty"` // For thinking
|
||||
Thinking string `json:"thinking,omitempty"` // For thinking
|
||||
}
|
||||
|
||||
type claudeTextGenUsage struct {
|
||||
@@ -99,7 +245,7 @@ type claudeTextGenError struct {
|
||||
type claudeTextGenStreamResponse struct {
|
||||
Type string `json:"type"`
|
||||
Message *claudeTextGenResponse `json:"message,omitempty"`
|
||||
Index int `json:"index,omitempty"`
|
||||
Index *int `json:"index,omitempty"`
|
||||
ContentBlock *claudeTextGenContent `json:"content_block,omitempty"`
|
||||
Delta *claudeTextGenDelta `json:"delta,omitempty"`
|
||||
Usage *claudeTextGenUsage `json:"usage,omitempty"`
|
||||
@@ -107,13 +253,13 @@ type claudeTextGenStreamResponse struct {
|
||||
|
||||
type claudeTextGenDelta struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
StopReason *string `json:"stop_reason"`
|
||||
StopSequence *string `json:"stop_sequence"`
|
||||
Text string `json:"text,omitempty"`
|
||||
StopReason *string `json:"stop_reason,omitempty"`
|
||||
StopSequence *string `json:"stop_sequence,omitempty"`
|
||||
}
|
||||
|
||||
func (c *claudeProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
if config.apiTokens == nil || len(config.apiTokens) == 0 {
|
||||
if len(config.apiTokens) == 0 {
|
||||
return errors.New("no apiToken found in provider config")
|
||||
}
|
||||
return nil
|
||||
@@ -255,7 +401,10 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe
|
||||
|
||||
for _, message := range origRequest.Messages {
|
||||
if message.Role == roleSystem {
|
||||
claudeRequest.System = message.StringContent()
|
||||
claudeRequest.System = claudeSystemPrompt{
|
||||
StringValue: message.StringContent(),
|
||||
IsArray: false,
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -263,7 +412,7 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe
|
||||
Role: message.Role,
|
||||
}
|
||||
if message.IsStringContent() {
|
||||
claudeMessage.Content = message.StringContent()
|
||||
claudeMessage.Content = NewStringContent(message.StringContent())
|
||||
} else {
|
||||
chatMessageContents := make([]claudeChatMessageContent, 0)
|
||||
for _, messageContent := range message.ParseContent() {
|
||||
@@ -310,7 +459,7 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe
|
||||
continue
|
||||
}
|
||||
}
|
||||
claudeMessage.Content = chatMessageContents
|
||||
claudeMessage.Content = NewArrayContent(chatMessageContents)
|
||||
}
|
||||
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
|
||||
}
|
||||
@@ -342,19 +491,25 @@ func (c *claudeProvider) responseClaude2OpenAI(ctx wrapper.HttpContext, origResp
|
||||
FinishReason: util.Ptr(stopReasonClaude2OpenAI(origResponse.StopReason)),
|
||||
}
|
||||
|
||||
return &chatCompletionResponse{
|
||||
response := &chatCompletionResponse{
|
||||
Id: origResponse.Id,
|
||||
Created: time.Now().UnixMilli() / 1000,
|
||||
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
|
||||
SystemFingerprint: "",
|
||||
Object: objectChatCompletion,
|
||||
Choices: []chatCompletionChoice{choice},
|
||||
Usage: &usage{
|
||||
}
|
||||
|
||||
// Include usage information if available
|
||||
if origResponse.Usage.InputTokens > 0 || origResponse.Usage.OutputTokens > 0 {
|
||||
response.Usage = &usage{
|
||||
PromptTokens: origResponse.Usage.InputTokens,
|
||||
CompletionTokens: origResponse.Usage.OutputTokens,
|
||||
TotalTokens: origResponse.Usage.InputTokens + origResponse.Usage.OutputTokens,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
func stopReasonClaude2OpenAI(reason *string) string {
|
||||
@@ -376,31 +531,47 @@ func stopReasonClaude2OpenAI(reason *string) string {
|
||||
func (c *claudeProvider) streamResponseClaude2OpenAI(ctx wrapper.HttpContext, origResponse *claudeTextGenStreamResponse) *chatCompletionResponse {
|
||||
switch origResponse.Type {
|
||||
case "message_start":
|
||||
c.messageId = origResponse.Message.Id
|
||||
c.usage = usage{
|
||||
PromptTokens: origResponse.Message.Usage.InputTokens,
|
||||
CompletionTokens: origResponse.Message.Usage.OutputTokens,
|
||||
if origResponse.Message != nil {
|
||||
c.messageId = origResponse.Message.Id
|
||||
c.usage = usage{
|
||||
PromptTokens: origResponse.Message.Usage.InputTokens,
|
||||
CompletionTokens: origResponse.Message.Usage.OutputTokens,
|
||||
}
|
||||
c.serviceTier = origResponse.Message.Usage.ServiceTier
|
||||
}
|
||||
var index int
|
||||
if origResponse.Index != nil {
|
||||
index = *origResponse.Index
|
||||
}
|
||||
c.serviceTier = origResponse.Message.Usage.ServiceTier
|
||||
choice := chatCompletionChoice{
|
||||
Index: origResponse.Index,
|
||||
Index: index,
|
||||
Delta: &chatMessage{Role: roleAssistant, Content: ""},
|
||||
}
|
||||
return c.createChatCompletionResponse(ctx, origResponse, choice)
|
||||
|
||||
case "content_block_delta":
|
||||
var index int
|
||||
if origResponse.Index != nil {
|
||||
index = *origResponse.Index
|
||||
}
|
||||
choice := chatCompletionChoice{
|
||||
Index: origResponse.Index,
|
||||
Index: index,
|
||||
Delta: &chatMessage{Content: origResponse.Delta.Text},
|
||||
}
|
||||
return c.createChatCompletionResponse(ctx, origResponse, choice)
|
||||
|
||||
case "message_delta":
|
||||
c.usage.CompletionTokens += origResponse.Usage.OutputTokens
|
||||
c.usage.TotalTokens = c.usage.PromptTokens + c.usage.CompletionTokens
|
||||
if origResponse.Usage != nil {
|
||||
c.usage.CompletionTokens += origResponse.Usage.OutputTokens
|
||||
c.usage.TotalTokens = c.usage.PromptTokens + c.usage.CompletionTokens
|
||||
}
|
||||
|
||||
var index int
|
||||
if origResponse.Index != nil {
|
||||
index = *origResponse.Index
|
||||
}
|
||||
choice := chatCompletionChoice{
|
||||
Index: origResponse.Index,
|
||||
Index: index,
|
||||
Delta: &chatMessage{},
|
||||
FinishReason: util.Ptr(stopReasonClaude2OpenAI(origResponse.Delta.StopReason)),
|
||||
}
|
||||
@@ -449,10 +620,17 @@ func (c *claudeProvider) insertHttpContextMessage(body []byte, content string, o
|
||||
return nil, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
|
||||
if request.System == "" {
|
||||
request.System = content
|
||||
systemStr := request.System.String()
|
||||
if systemStr == "" {
|
||||
request.System = claudeSystemPrompt{
|
||||
StringValue: content,
|
||||
IsArray: false,
|
||||
}
|
||||
} else {
|
||||
request.System = content + "\n" + request.System
|
||||
request.System = claudeSystemPrompt{
|
||||
StringValue: content + "\n" + systemStr,
|
||||
IsArray: false,
|
||||
}
|
||||
}
|
||||
|
||||
return json.Marshal(request)
|
||||
|
||||
824
plugins/wasm-go/extensions/ai-proxy/provider/claude_to_openai.go
Normal file
824
plugins/wasm-go/extensions/ai-proxy/provider/claude_to_openai.go
Normal file
@@ -0,0 +1,824 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
// ClaudeToOpenAIConverter converts Claude protocol requests to OpenAI protocol
|
||||
type ClaudeToOpenAIConverter struct {
|
||||
// State tracking for streaming conversion
|
||||
messageStartSent bool
|
||||
messageStopSent bool
|
||||
messageId string
|
||||
// Cache stop_reason until we get usage info
|
||||
pendingStopReason *string
|
||||
// Content block tracking with dynamic index allocation
|
||||
nextContentIndex int
|
||||
thinkingBlockIndex int
|
||||
thinkingBlockStarted bool
|
||||
thinkingBlockStopped bool
|
||||
textBlockIndex int
|
||||
textBlockStarted bool
|
||||
textBlockStopped bool
|
||||
toolBlockIndex int
|
||||
toolBlockStarted bool
|
||||
toolBlockStopped bool
|
||||
// Tool call state tracking
|
||||
toolCallStates map[string]*toolCallState
|
||||
}
|
||||
|
||||
// contentConversionResult represents the result of converting Claude content to OpenAI format
|
||||
type contentConversionResult struct {
|
||||
textParts []string
|
||||
toolCalls []toolCall
|
||||
toolResults []claudeChatMessageContent
|
||||
openaiContents []chatMessageContent
|
||||
hasNonTextContent bool
|
||||
}
|
||||
|
||||
// toolCallState tracks the state of a tool call during streaming
|
||||
type toolCallState struct {
|
||||
id string
|
||||
name string
|
||||
argumentsBuffer string
|
||||
isComplete bool
|
||||
}
|
||||
|
||||
// ConvertClaudeRequestToOpenAI converts a Claude chat completion request to OpenAI format
|
||||
func (c *ClaudeToOpenAIConverter) ConvertClaudeRequestToOpenAI(body []byte) ([]byte, error) {
|
||||
log.Debugf("[Claude->OpenAI] Original Claude request body: %s", string(body))
|
||||
|
||||
var claudeRequest claudeTextGenRequest
|
||||
if err := json.Unmarshal(body, &claudeRequest); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal claude request: %v", err)
|
||||
}
|
||||
|
||||
// Convert Claude request to OpenAI format
|
||||
openaiRequest := chatCompletionRequest{
|
||||
Model: claudeRequest.Model,
|
||||
Stream: claudeRequest.Stream,
|
||||
Temperature: claudeRequest.Temperature,
|
||||
TopP: claudeRequest.TopP,
|
||||
MaxTokens: claudeRequest.MaxTokens,
|
||||
Stop: claudeRequest.StopSequences,
|
||||
}
|
||||
|
||||
// Convert messages from Claude format to OpenAI format
|
||||
for _, claudeMsg := range claudeRequest.Messages {
|
||||
// Handle different content types using the type-safe wrapper
|
||||
if claudeMsg.Content.IsString {
|
||||
// Simple text content
|
||||
openaiMsg := chatMessage{
|
||||
Role: claudeMsg.Role,
|
||||
Content: claudeMsg.Content.GetStringValue(),
|
||||
}
|
||||
openaiRequest.Messages = append(openaiRequest.Messages, openaiMsg)
|
||||
} else {
|
||||
// Multi-modal content - process with convertContentArray
|
||||
conversionResult := c.convertContentArray(claudeMsg.Content.GetArrayValue())
|
||||
|
||||
// Handle tool calls if present
|
||||
if len(conversionResult.toolCalls) > 0 {
|
||||
// Use tool_calls format (current OpenAI standard)
|
||||
openaiMsg := chatMessage{
|
||||
Role: claudeMsg.Role,
|
||||
ToolCalls: conversionResult.toolCalls,
|
||||
}
|
||||
|
||||
// Add text content if present, otherwise set to null
|
||||
if len(conversionResult.textParts) > 0 {
|
||||
openaiMsg.Content = strings.Join(conversionResult.textParts, "\n\n")
|
||||
} else {
|
||||
openaiMsg.Content = nil
|
||||
}
|
||||
|
||||
openaiRequest.Messages = append(openaiRequest.Messages, openaiMsg)
|
||||
}
|
||||
|
||||
// Handle tool results if present
|
||||
if len(conversionResult.toolResults) > 0 {
|
||||
for _, toolResult := range conversionResult.toolResults {
|
||||
toolMsg := chatMessage{
|
||||
Role: "tool",
|
||||
Content: toolResult.Content,
|
||||
ToolCallId: toolResult.ToolUseId,
|
||||
}
|
||||
openaiRequest.Messages = append(openaiRequest.Messages, toolMsg)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle regular content if no tool calls or tool results
|
||||
if len(conversionResult.toolCalls) == 0 && len(conversionResult.toolResults) == 0 {
|
||||
var content interface{}
|
||||
if !conversionResult.hasNonTextContent && len(conversionResult.textParts) > 0 {
|
||||
// Simple text content
|
||||
content = strings.Join(conversionResult.textParts, "\n\n")
|
||||
} else {
|
||||
// Multi-modal content or empty content
|
||||
content = conversionResult.openaiContents
|
||||
}
|
||||
|
||||
openaiMsg := chatMessage{
|
||||
Role: claudeMsg.Role,
|
||||
Content: content,
|
||||
}
|
||||
openaiRequest.Messages = append(openaiRequest.Messages, openaiMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle system message - Claude has separate system field
|
||||
systemStr := claudeRequest.System.String()
|
||||
if systemStr != "" {
|
||||
systemMsg := chatMessage{
|
||||
Role: roleSystem,
|
||||
Content: systemStr,
|
||||
}
|
||||
// Insert system message at the beginning
|
||||
openaiRequest.Messages = append([]chatMessage{systemMsg}, openaiRequest.Messages...)
|
||||
}
|
||||
|
||||
// Convert tools if present
|
||||
for _, claudeTool := range claudeRequest.Tools {
|
||||
openaiTool := tool{
|
||||
Type: "function",
|
||||
Function: function{
|
||||
Name: claudeTool.Name,
|
||||
Description: claudeTool.Description,
|
||||
Parameters: claudeTool.InputSchema,
|
||||
},
|
||||
}
|
||||
openaiRequest.Tools = append(openaiRequest.Tools, openaiTool)
|
||||
}
|
||||
|
||||
// Convert tool choice if present
|
||||
if claudeRequest.ToolChoice != nil {
|
||||
if claudeRequest.ToolChoice.Type == "tool" && claudeRequest.ToolChoice.Name != "" {
|
||||
openaiRequest.ToolChoice = &toolChoice{
|
||||
Type: "function",
|
||||
Function: function{
|
||||
Name: claudeRequest.ToolChoice.Name,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// For other types like "auto", "none", etc.
|
||||
openaiRequest.ToolChoice = claudeRequest.ToolChoice.Type
|
||||
}
|
||||
|
||||
// Handle parallel tool calls
|
||||
openaiRequest.ParallelToolCalls = !claudeRequest.ToolChoice.DisableParallelToolUse
|
||||
}
|
||||
|
||||
// Convert thinking configuration if present
|
||||
if claudeRequest.Thinking != nil {
|
||||
log.Debugf("[Claude->OpenAI] Found thinking config: type=%s, budget_tokens=%d",
|
||||
claudeRequest.Thinking.Type, claudeRequest.Thinking.BudgetTokens)
|
||||
|
||||
if claudeRequest.Thinking.Type == "enabled" {
|
||||
openaiRequest.ReasoningMaxTokens = claudeRequest.Thinking.BudgetTokens
|
||||
|
||||
// Set ReasoningEffort based on budget_tokens
|
||||
// low: <4096, medium: >=4096 and <16384, high: >=16384
|
||||
if claudeRequest.Thinking.BudgetTokens < 4096 {
|
||||
openaiRequest.ReasoningEffort = "low"
|
||||
} else if claudeRequest.Thinking.BudgetTokens < 16384 {
|
||||
openaiRequest.ReasoningEffort = "medium"
|
||||
} else {
|
||||
openaiRequest.ReasoningEffort = "high"
|
||||
}
|
||||
|
||||
log.Debugf("[Claude->OpenAI] Converted thinking config: budget_tokens=%d, reasoning_effort=%s, reasoning_max_tokens=%d",
|
||||
claudeRequest.Thinking.BudgetTokens, openaiRequest.ReasoningEffort, openaiRequest.ReasoningMaxTokens)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("[Claude->OpenAI] No thinking config found")
|
||||
}
|
||||
|
||||
result, err := json.Marshal(openaiRequest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to marshal openai request: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("[Claude->OpenAI] Converted OpenAI request body: %s", string(result))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ConvertOpenAIResponseToClaude converts an OpenAI response back to Claude format
|
||||
func (c *ClaudeToOpenAIConverter) ConvertOpenAIResponseToClaude(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
||||
log.Debugf("[OpenAI->Claude] Original OpenAI response body: %s", string(body))
|
||||
|
||||
var openaiResponse chatCompletionResponse
|
||||
if err := json.Unmarshal(body, &openaiResponse); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal openai response: %v", err)
|
||||
}
|
||||
|
||||
// Convert OpenAI response to Claude format
|
||||
claudeResponse := claudeTextGenResponse{
|
||||
Id: openaiResponse.Id,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: openaiResponse.Model,
|
||||
}
|
||||
|
||||
// Only include usage if it's available
|
||||
if openaiResponse.Usage != nil {
|
||||
claudeResponse.Usage = claudeTextGenUsage{
|
||||
InputTokens: openaiResponse.Usage.PromptTokens,
|
||||
OutputTokens: openaiResponse.Usage.CompletionTokens,
|
||||
}
|
||||
}
|
||||
|
||||
// Convert the first choice content
|
||||
if len(openaiResponse.Choices) > 0 {
|
||||
choice := openaiResponse.Choices[0]
|
||||
if choice.Message != nil {
|
||||
var contents []claudeTextGenContent
|
||||
|
||||
// Add reasoning content (thinking) if present - check both reasoning and reasoning_content fields
|
||||
var reasoningText string
|
||||
if choice.Message.Reasoning != "" {
|
||||
reasoningText = choice.Message.Reasoning
|
||||
} else if choice.Message.ReasoningContent != "" {
|
||||
reasoningText = choice.Message.ReasoningContent
|
||||
}
|
||||
|
||||
if reasoningText != "" {
|
||||
contents = append(contents, claudeTextGenContent{
|
||||
Type: "thinking",
|
||||
Signature: "", // OpenAI doesn't provide signature, use empty string
|
||||
Thinking: reasoningText,
|
||||
})
|
||||
log.Debugf("[OpenAI->Claude] Added thinking content: %s", reasoningText)
|
||||
}
|
||||
|
||||
// Add text content if present
|
||||
if choice.Message.StringContent() != "" {
|
||||
contents = append(contents, claudeTextGenContent{
|
||||
Type: "text",
|
||||
Text: choice.Message.StringContent(),
|
||||
})
|
||||
}
|
||||
|
||||
// Add tool calls if present
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
for _, toolCall := range choice.Message.ToolCalls {
|
||||
if !toolCall.Function.IsEmpty() {
|
||||
// Parse arguments from JSON string to map
|
||||
var input map[string]interface{}
|
||||
if toolCall.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &input); err != nil {
|
||||
log.Errorf("Failed to parse tool call arguments: %v", err)
|
||||
input = map[string]interface{}{}
|
||||
}
|
||||
} else {
|
||||
input = map[string]interface{}{}
|
||||
}
|
||||
|
||||
contents = append(contents, claudeTextGenContent{
|
||||
Type: "tool_use",
|
||||
Id: toolCall.Id,
|
||||
Name: toolCall.Function.Name,
|
||||
Input: input,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
claudeResponse.Content = contents
|
||||
}
|
||||
|
||||
// Convert finish reason
|
||||
if choice.FinishReason != nil {
|
||||
claudeFinishReason := openAIFinishReasonToClaude(*choice.FinishReason)
|
||||
claudeResponse.StopReason = &claudeFinishReason
|
||||
}
|
||||
}
|
||||
|
||||
result, err := json.Marshal(claudeResponse)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to marshal claude response: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("[OpenAI->Claude] Converted Claude response body: %s", string(result))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ConvertOpenAIStreamResponseToClaude converts OpenAI streaming response to Claude format
|
||||
func (c *ClaudeToOpenAIConverter) ConvertOpenAIStreamResponseToClaude(ctx wrapper.HttpContext, chunk []byte) ([]byte, error) {
|
||||
log.Debugf("[OpenAI->Claude] Original OpenAI streaming chunk: %s", string(chunk))
|
||||
|
||||
// Initialize tool call states if needed
|
||||
if c.toolCallStates == nil {
|
||||
c.toolCallStates = make(map[string]*toolCallState)
|
||||
}
|
||||
|
||||
// For streaming responses, we need to handle the Server-Sent Events format
|
||||
lines := strings.Split(string(chunk), "\n")
|
||||
var result strings.Builder
|
||||
|
||||
for _, line := range lines {
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
|
||||
// Handle [DONE] messages
|
||||
if data == "[DONE]" {
|
||||
log.Debugf("[OpenAI->Claude] Processing [DONE] message, finalizing stream")
|
||||
|
||||
// Send final content_block_stop events for any active blocks
|
||||
if c.thinkingBlockStarted && !c.thinkingBlockStopped {
|
||||
c.thinkingBlockStopped = true
|
||||
log.Debugf("[OpenAI->Claude] Sending final thinking content_block_stop event at index %d", c.thinkingBlockIndex)
|
||||
stopEvent := &claudeTextGenStreamResponse{
|
||||
Type: "content_block_stop",
|
||||
Index: &c.thinkingBlockIndex,
|
||||
}
|
||||
stopData, _ := json.Marshal(stopEvent)
|
||||
result.WriteString(fmt.Sprintf("data: %s\n\n", stopData))
|
||||
}
|
||||
if c.textBlockStarted && !c.textBlockStopped {
|
||||
c.textBlockStopped = true
|
||||
log.Debugf("[OpenAI->Claude] Sending final text content_block_stop event at index %d", c.textBlockIndex)
|
||||
stopEvent := &claudeTextGenStreamResponse{
|
||||
Type: "content_block_stop",
|
||||
Index: &c.textBlockIndex,
|
||||
}
|
||||
stopData, _ := json.Marshal(stopEvent)
|
||||
result.WriteString(fmt.Sprintf("data: %s\n\n", stopData))
|
||||
}
|
||||
if c.toolBlockStarted && !c.toolBlockStopped {
|
||||
c.toolBlockStopped = true
|
||||
log.Debugf("[OpenAI->Claude] Sending final tool content_block_stop event at index %d", c.toolBlockIndex)
|
||||
stopEvent := &claudeTextGenStreamResponse{
|
||||
Type: "content_block_stop",
|
||||
Index: &c.toolBlockIndex,
|
||||
}
|
||||
stopData, _ := json.Marshal(stopEvent)
|
||||
result.WriteString(fmt.Sprintf("data: %s\n\n", stopData))
|
||||
}
|
||||
|
||||
// If we have a pending stop_reason but no usage, send message_delta with just stop_reason
|
||||
if c.pendingStopReason != nil {
|
||||
log.Debugf("[OpenAI->Claude] Sending final message_delta with pending stop_reason: %s", *c.pendingStopReason)
|
||||
messageDelta := &claudeTextGenStreamResponse{
|
||||
Type: "message_delta",
|
||||
Delta: &claudeTextGenDelta{
|
||||
Type: "message_delta",
|
||||
StopReason: c.pendingStopReason,
|
||||
},
|
||||
}
|
||||
stopData, _ := json.Marshal(messageDelta)
|
||||
result.WriteString(fmt.Sprintf("data: %s\n\n", stopData))
|
||||
c.pendingStopReason = nil
|
||||
}
|
||||
|
||||
if c.messageStartSent && !c.messageStopSent {
|
||||
c.messageStopSent = true
|
||||
log.Debugf("[OpenAI->Claude] Sending final message_stop event")
|
||||
messageStopEvent := &claudeTextGenStreamResponse{
|
||||
Type: "message_stop",
|
||||
}
|
||||
stopData, _ := json.Marshal(messageStopEvent)
|
||||
result.WriteString(fmt.Sprintf("data: %s\n\n", stopData))
|
||||
}
|
||||
|
||||
// Reset all state for next request
|
||||
c.messageStartSent = false
|
||||
c.messageStopSent = false
|
||||
c.messageId = ""
|
||||
c.pendingStopReason = nil
|
||||
c.nextContentIndex = 0
|
||||
c.thinkingBlockIndex = -1
|
||||
c.thinkingBlockStarted = false
|
||||
c.thinkingBlockStopped = false
|
||||
c.textBlockIndex = -1
|
||||
c.textBlockStarted = false
|
||||
c.textBlockStopped = false
|
||||
c.toolBlockIndex = -1
|
||||
c.toolBlockStarted = false
|
||||
c.toolBlockStopped = false
|
||||
c.toolCallStates = make(map[string]*toolCallState)
|
||||
log.Debugf("[OpenAI->Claude] Reset converter state for next request")
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
var openaiStreamResponse chatCompletionResponse
|
||||
if err := json.Unmarshal([]byte(data), &openaiStreamResponse); err != nil {
|
||||
log.Debugf("unable to unmarshal openai stream response: %v, data: %s", err, data)
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert to Claude streaming format
|
||||
claudeStreamResponses := c.buildClaudeStreamResponse(ctx, &openaiStreamResponse)
|
||||
log.Debugf("[OpenAI->Claude] Generated %d Claude stream events from OpenAI chunk", len(claudeStreamResponses))
|
||||
|
||||
for i, claudeStreamResponse := range claudeStreamResponses {
|
||||
responseData, err := json.Marshal(claudeStreamResponse)
|
||||
if err != nil {
|
||||
log.Errorf("unable to marshal claude stream response: %v", err)
|
||||
continue
|
||||
}
|
||||
log.Debugf("[OpenAI->Claude] Stream event [%d/%d]: %s", i+1, len(claudeStreamResponses), string(responseData))
|
||||
result.WriteString(fmt.Sprintf("data: %s\n\n", responseData))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
claudeChunk := []byte(result.String())
|
||||
log.Debugf("[OpenAI->Claude] Converted Claude streaming chunk: %s", string(claudeChunk))
|
||||
return claudeChunk, nil
|
||||
}
|
||||
|
||||
// buildClaudeStreamResponse builds Claude streaming responses from OpenAI streaming response
|
||||
func (c *ClaudeToOpenAIConverter) buildClaudeStreamResponse(ctx wrapper.HttpContext, openaiResponse *chatCompletionResponse) []*claudeTextGenStreamResponse {
|
||||
if len(openaiResponse.Choices) == 0 {
|
||||
log.Debugf("[OpenAI->Claude] No choices in OpenAI response, skipping")
|
||||
return nil
|
||||
}
|
||||
|
||||
choice := openaiResponse.Choices[0]
|
||||
var responses []*claudeTextGenStreamResponse
|
||||
|
||||
// Log what we're processing
|
||||
hasRole := choice.Delta != nil && choice.Delta.Role != ""
|
||||
hasContent := choice.Delta != nil && choice.Delta.Content != ""
|
||||
hasFinishReason := choice.FinishReason != nil
|
||||
hasUsage := openaiResponse.Usage != nil
|
||||
|
||||
log.Debugf("[OpenAI->Claude] Processing OpenAI chunk - Role: %v, Content: %v, FinishReason: %v, Usage: %v",
|
||||
hasRole, hasContent, hasFinishReason, hasUsage)
|
||||
|
||||
// Handle message start (only once)
|
||||
// Note: OpenRouter may send multiple messages with role but empty content at the start
|
||||
// We only send message_start for the first one
|
||||
if choice.Delta != nil && choice.Delta.Role != "" && !c.messageStartSent {
|
||||
c.messageId = openaiResponse.Id
|
||||
c.messageStartSent = true
|
||||
|
||||
message := &claudeTextGenResponse{
|
||||
Id: openaiResponse.Id,
|
||||
Type: "message",
|
||||
Role: "assistant",
|
||||
Model: openaiResponse.Model,
|
||||
Content: []claudeTextGenContent{},
|
||||
}
|
||||
|
||||
// Only include usage if it's available
|
||||
if openaiResponse.Usage != nil {
|
||||
message.Usage = claudeTextGenUsage{
|
||||
InputTokens: openaiResponse.Usage.PromptTokens,
|
||||
OutputTokens: 0,
|
||||
}
|
||||
}
|
||||
|
||||
responses = append(responses, &claudeTextGenStreamResponse{
|
||||
Type: "message_start",
|
||||
Message: message,
|
||||
})
|
||||
|
||||
log.Debugf("[OpenAI->Claude] Generated message_start event for id: %s", openaiResponse.Id)
|
||||
} else if choice.Delta != nil && choice.Delta.Role != "" && c.messageStartSent {
|
||||
// Skip duplicate role messages from OpenRouter
|
||||
log.Debugf("[OpenAI->Claude] Skipping duplicate role message for id: %s", openaiResponse.Id)
|
||||
}
|
||||
|
||||
// Handle reasoning content (thinking) first - check both reasoning and reasoning_content fields
|
||||
var reasoningText string
|
||||
if choice.Delta != nil {
|
||||
if choice.Delta.Reasoning != "" {
|
||||
reasoningText = choice.Delta.Reasoning
|
||||
} else if choice.Delta.ReasoningContent != "" {
|
||||
reasoningText = choice.Delta.ReasoningContent
|
||||
}
|
||||
}
|
||||
|
||||
if reasoningText != "" {
|
||||
log.Debugf("[OpenAI->Claude] Processing reasoning content delta: %s", reasoningText)
|
||||
|
||||
// Send content_block_start for thinking only once with dynamic index
|
||||
if !c.thinkingBlockStarted {
|
||||
c.thinkingBlockIndex = c.nextContentIndex
|
||||
c.nextContentIndex++
|
||||
c.thinkingBlockStarted = true
|
||||
log.Debugf("[OpenAI->Claude] Generated content_block_start event for thinking at index %d", c.thinkingBlockIndex)
|
||||
responses = append(responses, &claudeTextGenStreamResponse{
|
||||
Type: "content_block_start",
|
||||
Index: &c.thinkingBlockIndex,
|
||||
ContentBlock: &claudeTextGenContent{
|
||||
Type: "thinking",
|
||||
Signature: "", // OpenAI doesn't provide signature
|
||||
Thinking: "",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Send content_block_delta for thinking
|
||||
log.Debugf("[OpenAI->Claude] Generated content_block_delta event with thinking: %s", reasoningText)
|
||||
responses = append(responses, &claudeTextGenStreamResponse{
|
||||
Type: "content_block_delta",
|
||||
Index: &c.thinkingBlockIndex,
|
||||
Delta: &claudeTextGenDelta{
|
||||
Type: "thinking_delta", // Use thinking_delta for reasoning content
|
||||
Text: reasoningText,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Handle content
|
||||
if choice.Delta != nil && choice.Delta.Content != nil && choice.Delta.Content != "" {
|
||||
deltaContent, ok := choice.Delta.Content.(string)
|
||||
if !ok {
|
||||
log.Debugf("[OpenAI->Claude] Content is not a string: %T", choice.Delta.Content)
|
||||
return responses
|
||||
}
|
||||
|
||||
log.Debugf("[OpenAI->Claude] Processing content delta: %s", deltaContent)
|
||||
|
||||
// Close thinking content block if it's still open
|
||||
if c.thinkingBlockStarted && !c.thinkingBlockStopped {
|
||||
c.thinkingBlockStopped = true
|
||||
log.Debugf("[OpenAI->Claude] Closing thinking content block before text")
|
||||
responses = append(responses, &claudeTextGenStreamResponse{
|
||||
Type: "content_block_stop",
|
||||
Index: &c.thinkingBlockIndex,
|
||||
})
|
||||
}
|
||||
|
||||
// Send content_block_start only once for text content with dynamic index
|
||||
if !c.textBlockStarted {
|
||||
c.textBlockIndex = c.nextContentIndex
|
||||
c.nextContentIndex++
|
||||
c.textBlockStarted = true
|
||||
log.Debugf("[OpenAI->Claude] Generated content_block_start event for text at index %d", c.textBlockIndex)
|
||||
responses = append(responses, &claudeTextGenStreamResponse{
|
||||
Type: "content_block_start",
|
||||
Index: &c.textBlockIndex,
|
||||
ContentBlock: &claudeTextGenContent{
|
||||
Type: "text",
|
||||
Text: "",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Send content_block_delta
|
||||
log.Debugf("[OpenAI->Claude] Generated content_block_delta event with text: %s", deltaContent)
|
||||
responses = append(responses, &claudeTextGenStreamResponse{
|
||||
Type: "content_block_delta",
|
||||
Index: &c.textBlockIndex,
|
||||
Delta: &claudeTextGenDelta{
|
||||
Type: "text_delta",
|
||||
Text: deltaContent,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Handle tool calls in streaming response
|
||||
if choice.Delta != nil && len(choice.Delta.ToolCalls) > 0 {
|
||||
for _, toolCall := range choice.Delta.ToolCalls {
|
||||
if !toolCall.Function.IsEmpty() {
|
||||
log.Debugf("[OpenAI->Claude] Processing tool call delta")
|
||||
|
||||
// Get or create tool call state
|
||||
state := c.toolCallStates[toolCall.Id]
|
||||
if state == nil {
|
||||
state = &toolCallState{
|
||||
id: toolCall.Id,
|
||||
name: toolCall.Function.Name,
|
||||
argumentsBuffer: "",
|
||||
isComplete: false,
|
||||
}
|
||||
c.toolCallStates[toolCall.Id] = state
|
||||
log.Debugf("[OpenAI->Claude] Created new tool call state for id: %s, name: %s", toolCall.Id, toolCall.Function.Name)
|
||||
}
|
||||
|
||||
// Accumulate arguments
|
||||
if toolCall.Function.Arguments != "" {
|
||||
state.argumentsBuffer += toolCall.Function.Arguments
|
||||
log.Debugf("[OpenAI->Claude] Accumulated tool arguments: %s", state.argumentsBuffer)
|
||||
}
|
||||
|
||||
// Try to parse accumulated arguments as JSON to check if complete
|
||||
var input map[string]interface{}
|
||||
if state.argumentsBuffer != "" {
|
||||
if err := json.Unmarshal([]byte(state.argumentsBuffer), &input); err == nil {
|
||||
// Successfully parsed - arguments are complete
|
||||
if !state.isComplete {
|
||||
state.isComplete = true
|
||||
log.Debugf("[OpenAI->Claude] Tool call arguments complete for %s: %s", state.name, state.argumentsBuffer)
|
||||
|
||||
// Close thinking content block if it's still open
|
||||
if c.thinkingBlockStarted && !c.thinkingBlockStopped {
|
||||
c.thinkingBlockStopped = true
|
||||
log.Debugf("[OpenAI->Claude] Closing thinking content block before tool use")
|
||||
responses = append(responses, &claudeTextGenStreamResponse{
|
||||
Type: "content_block_stop",
|
||||
Index: &c.thinkingBlockIndex,
|
||||
})
|
||||
}
|
||||
|
||||
// Close text content block if it's still open
|
||||
if c.textBlockStarted && !c.textBlockStopped {
|
||||
c.textBlockStopped = true
|
||||
log.Debugf("[OpenAI->Claude] Closing text content block before tool use")
|
||||
responses = append(responses, &claudeTextGenStreamResponse{
|
||||
Type: "content_block_stop",
|
||||
Index: &c.textBlockIndex,
|
||||
})
|
||||
}
|
||||
|
||||
// Send content_block_start for tool_use only when we have complete arguments with dynamic index
|
||||
if !c.toolBlockStarted {
|
||||
c.toolBlockIndex = c.nextContentIndex
|
||||
c.nextContentIndex++
|
||||
c.toolBlockStarted = true
|
||||
log.Debugf("[OpenAI->Claude] Generated content_block_start event for tool_use at index %d", c.toolBlockIndex)
|
||||
responses = append(responses, &claudeTextGenStreamResponse{
|
||||
Type: "content_block_start",
|
||||
Index: &c.toolBlockIndex,
|
||||
ContentBlock: &claudeTextGenContent{
|
||||
Type: "tool_use",
|
||||
Id: toolCall.Id,
|
||||
Name: state.name,
|
||||
Input: input,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Still accumulating arguments
|
||||
log.Debugf("[OpenAI->Claude] Tool arguments not yet complete, continuing to accumulate: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle finish reason
|
||||
if choice.FinishReason != nil {
|
||||
claudeFinishReason := openAIFinishReasonToClaude(*choice.FinishReason)
|
||||
log.Debugf("[OpenAI->Claude] Processing finish_reason: %s -> %s", *choice.FinishReason, claudeFinishReason)
|
||||
|
||||
// Send content_block_stop for any active content blocks
|
||||
if c.thinkingBlockStarted && !c.thinkingBlockStopped {
|
||||
c.thinkingBlockStopped = true
|
||||
log.Debugf("[OpenAI->Claude] Generated thinking content_block_stop event at index %d", c.thinkingBlockIndex)
|
||||
responses = append(responses, &claudeTextGenStreamResponse{
|
||||
Type: "content_block_stop",
|
||||
Index: &c.thinkingBlockIndex,
|
||||
})
|
||||
}
|
||||
if c.textBlockStarted && !c.textBlockStopped {
|
||||
c.textBlockStopped = true
|
||||
log.Debugf("[OpenAI->Claude] Generated text content_block_stop event at index %d", c.textBlockIndex)
|
||||
responses = append(responses, &claudeTextGenStreamResponse{
|
||||
Type: "content_block_stop",
|
||||
Index: &c.textBlockIndex,
|
||||
})
|
||||
}
|
||||
if c.toolBlockStarted && !c.toolBlockStopped {
|
||||
c.toolBlockStopped = true
|
||||
log.Debugf("[OpenAI->Claude] Generated tool content_block_stop event at index %d", c.toolBlockIndex)
|
||||
responses = append(responses, &claudeTextGenStreamResponse{
|
||||
Type: "content_block_stop",
|
||||
Index: &c.toolBlockIndex,
|
||||
})
|
||||
}
|
||||
|
||||
// Cache stop_reason until we get usage info (Claude protocol requires them together)
|
||||
c.pendingStopReason = &claudeFinishReason
|
||||
log.Debugf("[OpenAI->Claude] Cached stop_reason: %s, waiting for usage", claudeFinishReason)
|
||||
}
|
||||
|
||||
// Handle usage information
|
||||
if openaiResponse.Usage != nil && choice.FinishReason == nil {
|
||||
log.Debugf("[OpenAI->Claude] Processing usage info - input: %d, output: %d",
|
||||
openaiResponse.Usage.PromptTokens, openaiResponse.Usage.CompletionTokens)
|
||||
|
||||
// Send message_delta with both stop_reason and usage (Claude protocol requirement)
|
||||
messageDelta := &claudeTextGenStreamResponse{
|
||||
Type: "message_delta",
|
||||
Delta: &claudeTextGenDelta{
|
||||
Type: "message_delta",
|
||||
},
|
||||
Usage: &claudeTextGenUsage{
|
||||
InputTokens: openaiResponse.Usage.PromptTokens,
|
||||
OutputTokens: openaiResponse.Usage.CompletionTokens,
|
||||
},
|
||||
}
|
||||
|
||||
// Include cached stop_reason if available
|
||||
if c.pendingStopReason != nil {
|
||||
log.Debugf("[OpenAI->Claude] Combining cached stop_reason %s with usage", *c.pendingStopReason)
|
||||
messageDelta.Delta.StopReason = c.pendingStopReason
|
||||
c.pendingStopReason = nil // Clear cache
|
||||
}
|
||||
|
||||
log.Debugf("[OpenAI->Claude] Generated message_delta event with usage and stop_reason")
|
||||
responses = append(responses, messageDelta)
|
||||
|
||||
// Send message_stop after combined message_delta
|
||||
if !c.messageStopSent {
|
||||
c.messageStopSent = true
|
||||
log.Debugf("[OpenAI->Claude] Generated message_stop event")
|
||||
responses = append(responses, &claudeTextGenStreamResponse{
|
||||
Type: "message_stop",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return responses
|
||||
}
|
||||
|
||||
// openAIFinishReasonToClaude converts OpenAI finish reason to Claude format
|
||||
func openAIFinishReasonToClaude(reason string) string {
|
||||
switch reason {
|
||||
case finishReasonStop:
|
||||
return "end_turn"
|
||||
case finishReasonLength:
|
||||
return "max_tokens"
|
||||
case finishReasonToolCall:
|
||||
return "tool_use"
|
||||
default:
|
||||
return reason
|
||||
}
|
||||
}
|
||||
|
||||
// convertContentArray converts an array of Claude content to OpenAI content format
|
||||
func (c *ClaudeToOpenAIConverter) convertContentArray(claudeContents []claudeChatMessageContent) *contentConversionResult {
|
||||
result := &contentConversionResult{
|
||||
textParts: []string{},
|
||||
toolCalls: []toolCall{},
|
||||
toolResults: []claudeChatMessageContent{},
|
||||
openaiContents: []chatMessageContent{},
|
||||
hasNonTextContent: false,
|
||||
}
|
||||
|
||||
for _, claudeContent := range claudeContents {
|
||||
switch claudeContent.Type {
|
||||
case "text":
|
||||
if claudeContent.Text != "" {
|
||||
result.textParts = append(result.textParts, claudeContent.Text)
|
||||
result.openaiContents = append(result.openaiContents, chatMessageContent{
|
||||
Type: contentTypeText,
|
||||
Text: claudeContent.Text,
|
||||
})
|
||||
}
|
||||
case "image":
|
||||
result.hasNonTextContent = true
|
||||
if claudeContent.Source != nil {
|
||||
if claudeContent.Source.Type == "base64" {
|
||||
// Convert base64 image to OpenAI format
|
||||
dataUrl := fmt.Sprintf("data:%s;base64,%s", claudeContent.Source.MediaType, claudeContent.Source.Data)
|
||||
result.openaiContents = append(result.openaiContents, chatMessageContent{
|
||||
Type: contentTypeImageUrl,
|
||||
ImageUrl: &chatMessageContentImageUrl{
|
||||
Url: dataUrl,
|
||||
},
|
||||
})
|
||||
} else if claudeContent.Source.Type == "url" {
|
||||
result.openaiContents = append(result.openaiContents, chatMessageContent{
|
||||
Type: contentTypeImageUrl,
|
||||
ImageUrl: &chatMessageContentImageUrl{
|
||||
Url: claudeContent.Source.Url,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
case "tool_use":
|
||||
result.hasNonTextContent = true
|
||||
// Convert Claude tool_use to OpenAI tool_calls format
|
||||
if claudeContent.Id != "" && claudeContent.Name != "" {
|
||||
// Convert input to JSON string for OpenAI format
|
||||
var argumentsStr string
|
||||
if claudeContent.Input != nil {
|
||||
if argBytes, err := json.Marshal(claudeContent.Input); err == nil {
|
||||
argumentsStr = string(argBytes)
|
||||
}
|
||||
}
|
||||
|
||||
toolCall := toolCall{
|
||||
Id: claudeContent.Id,
|
||||
Type: "function",
|
||||
Function: functionCall{
|
||||
Name: claudeContent.Name,
|
||||
Arguments: argumentsStr,
|
||||
},
|
||||
}
|
||||
result.toolCalls = append(result.toolCalls, toolCall)
|
||||
log.Debugf("[Claude->OpenAI] Converted tool_use to tool_call: %s", claudeContent.Name)
|
||||
}
|
||||
case "tool_result":
|
||||
result.hasNonTextContent = true
|
||||
// Store tool results for processing
|
||||
result.toolResults = append(result.toolResults, claudeContent)
|
||||
log.Debugf("[Claude->OpenAI] Found tool_result for tool_use_id: %s", claudeContent.ToolUseId)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,727 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Mock logger for testing
|
||||
type mockLogger struct{}
|
||||
|
||||
func (m *mockLogger) Trace(msg string) {}
|
||||
func (m *mockLogger) Tracef(format string, args ...interface{}) {}
|
||||
func (m *mockLogger) Debug(msg string) {}
|
||||
func (m *mockLogger) Debugf(format string, args ...interface{}) {}
|
||||
func (m *mockLogger) Info(msg string) {}
|
||||
func (m *mockLogger) Infof(format string, args ...interface{}) {}
|
||||
func (m *mockLogger) Warn(msg string) {}
|
||||
func (m *mockLogger) Warnf(format string, args ...interface{}) {}
|
||||
func (m *mockLogger) Error(msg string) {}
|
||||
func (m *mockLogger) Errorf(format string, args ...interface{}) {}
|
||||
func (m *mockLogger) Critical(msg string) {}
|
||||
func (m *mockLogger) Criticalf(format string, args ...interface{}) {}
|
||||
func (m *mockLogger) ResetID(pluginID string) {}
|
||||
|
||||
func init() {
|
||||
// Initialize mock logger for testing
|
||||
log.SetPluginLog(&mockLogger{})
|
||||
}
|
||||
|
||||
func TestClaudeToOpenAIConverter_ConvertClaudeRequestToOpenAI(t *testing.T) {
|
||||
converter := &ClaudeToOpenAIConverter{}
|
||||
|
||||
t.Run("convert_multiple_text_content_blocks", func(t *testing.T) {
|
||||
// Test case for the bug fix: multiple text content blocks should be merged into a single string
|
||||
claudeRequest := `{
|
||||
"max_tokens": 32000,
|
||||
"messages": [{
|
||||
"content": [{
|
||||
"text": "<system-reminder>\nThis is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware. If you are working on tasks that would benefit from a todo list please use the TodoWrite tool to create one. If not, please feel free to ignore. Again do not mention this message to the user.</system-reminder>",
|
||||
"type": "text"
|
||||
}, {
|
||||
"text": "<system-reminder>\nyyy</system-reminder>",
|
||||
"type": "text"
|
||||
}, {
|
||||
"cache_control": {
|
||||
"type": "ephemeral"
|
||||
},
|
||||
"text": "你是谁",
|
||||
"type": "text"
|
||||
}],
|
||||
"role": "user"
|
||||
}],
|
||||
"metadata": {
|
||||
"user_id": "user_dd3c52c1d698a4486bdef490197846b7c1f7e553202dae5763f330c35aeb9823_account__session_b2e14122-0ac6-4959-9c5d-b49ae01ccb7c"
|
||||
},
|
||||
"model": "anthropic/claude-sonnet-4",
|
||||
"stream": true,
|
||||
"system": [{
|
||||
"cache_control": {
|
||||
"type": "ephemeral"
|
||||
},
|
||||
"text": "xxx",
|
||||
"type": "text"
|
||||
}, {
|
||||
"cache_control": {
|
||||
"type": "ephemeral"
|
||||
},
|
||||
"text": "yyy",
|
||||
"type": "text"
|
||||
}],
|
||||
"temperature": 1,
|
||||
"stream_options": {
|
||||
"include_usage": true
|
||||
}
|
||||
}`
|
||||
|
||||
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse the result to verify the conversion
|
||||
var openaiRequest chatCompletionRequest
|
||||
err = json.Unmarshal(result, &openaiRequest)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify basic fields are converted correctly
|
||||
assert.Equal(t, "anthropic/claude-sonnet-4", openaiRequest.Model)
|
||||
assert.Equal(t, true, openaiRequest.Stream)
|
||||
assert.Equal(t, 1.0, openaiRequest.Temperature)
|
||||
assert.Equal(t, 32000, openaiRequest.MaxTokens)
|
||||
|
||||
// Verify messages structure
|
||||
require.Len(t, openaiRequest.Messages, 2)
|
||||
|
||||
// First message should be system message (converted from Claude's system field)
|
||||
systemMsg := openaiRequest.Messages[0]
|
||||
assert.Equal(t, roleSystem, systemMsg.Role)
|
||||
assert.Equal(t, "xxx\nyyy", systemMsg.Content) // Claude system uses single \n
|
||||
|
||||
// Second message should be user message with merged text content
|
||||
userMsg := openaiRequest.Messages[1]
|
||||
assert.Equal(t, "user", userMsg.Role)
|
||||
|
||||
// The key fix: multiple text blocks should be merged into a single string
|
||||
expectedContent := "<system-reminder>\nThis is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware. If you are working on tasks that would benefit from a todo list please use the TodoWrite tool to create one. If not, please feel free to ignore. Again do not mention this message to the user.</system-reminder>\n\n<system-reminder>\nyyy</system-reminder>\n\n你是谁"
|
||||
assert.Equal(t, expectedContent, userMsg.Content)
|
||||
})
|
||||
|
||||
t.Run("convert_mixed_content_with_image", func(t *testing.T) {
|
||||
// Test case with mixed text and image content (should remain as array)
|
||||
claudeRequest := `{
|
||||
"model": "claude-3-sonnet-20240229",
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "What's in this image?"
|
||||
}, {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
|
||||
}
|
||||
}]
|
||||
}],
|
||||
"max_tokens": 1000
|
||||
}`
|
||||
|
||||
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
|
||||
require.NoError(t, err)
|
||||
|
||||
var openaiRequest chatCompletionRequest
|
||||
err = json.Unmarshal(result, &openaiRequest)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should have one user message
|
||||
require.Len(t, openaiRequest.Messages, 1)
|
||||
userMsg := openaiRequest.Messages[0]
|
||||
assert.Equal(t, "user", userMsg.Role)
|
||||
|
||||
// Content should be an array (mixed content) - after JSON marshaling/unmarshaling it becomes []interface{}
|
||||
contentArray, ok := userMsg.Content.([]interface{})
|
||||
require.True(t, ok, "Content should be an array for mixed content")
|
||||
require.Len(t, contentArray, 2)
|
||||
|
||||
// First element should be text
|
||||
firstElement, ok := contentArray[0].(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, contentTypeText, firstElement["type"])
|
||||
assert.Equal(t, "What's in this image?", firstElement["text"])
|
||||
|
||||
// Second element should be image
|
||||
secondElement, ok := contentArray[1].(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, contentTypeImageUrl, secondElement["type"])
|
||||
assert.NotNil(t, secondElement["image_url"])
|
||||
imageUrl, ok := secondElement["image_url"].(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
assert.Contains(t, imageUrl["url"], "data:image/jpeg;base64,")
|
||||
})
|
||||
|
||||
t.Run("convert_simple_string_content", func(t *testing.T) {
|
||||
// Test case with simple string content
|
||||
claudeRequest := `{
|
||||
"model": "claude-3-sonnet-20240229",
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": "Hello, how are you?"
|
||||
}],
|
||||
"max_tokens": 1000
|
||||
}`
|
||||
|
||||
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
|
||||
require.NoError(t, err)
|
||||
|
||||
var openaiRequest chatCompletionRequest
|
||||
err = json.Unmarshal(result, &openaiRequest)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, openaiRequest.Messages, 1)
|
||||
userMsg := openaiRequest.Messages[0]
|
||||
assert.Equal(t, "user", userMsg.Role)
|
||||
assert.Equal(t, "Hello, how are you?", userMsg.Content)
|
||||
})
|
||||
|
||||
t.Run("convert_empty_content_array", func(t *testing.T) {
|
||||
// Test case with empty content array
|
||||
claudeRequest := `{
|
||||
"model": "claude-3-sonnet-20240229",
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": []
|
||||
}],
|
||||
"max_tokens": 1000
|
||||
}`
|
||||
|
||||
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
|
||||
require.NoError(t, err)
|
||||
|
||||
var openaiRequest chatCompletionRequest
|
||||
err = json.Unmarshal(result, &openaiRequest)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, openaiRequest.Messages, 1)
|
||||
userMsg := openaiRequest.Messages[0]
|
||||
assert.Equal(t, "user", userMsg.Role)
|
||||
|
||||
// Empty array should result in empty array, not string - after JSON marshaling/unmarshaling becomes []interface{}
|
||||
if userMsg.Content != nil {
|
||||
contentArray, ok := userMsg.Content.([]interface{})
|
||||
require.True(t, ok, "Empty content should be an array")
|
||||
assert.Empty(t, contentArray)
|
||||
} else {
|
||||
// null is also acceptable for empty content
|
||||
assert.Nil(t, userMsg.Content)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("convert_tool_use_to_tool_calls", func(t *testing.T) {
|
||||
// Test Claude tool_use conversion to OpenAI tool_calls format
|
||||
claudeRequest := `{
|
||||
"model": "anthropic/claude-sonnet-4",
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "I'll help you search for information."
|
||||
}, {
|
||||
"type": "tool_use",
|
||||
"id": "toolu_01D7FLrfh4GYq7yT1ULFeyMV",
|
||||
"name": "web_search",
|
||||
"input": {
|
||||
"query": "Claude AI capabilities",
|
||||
"max_results": 5
|
||||
}
|
||||
}]
|
||||
}],
|
||||
"max_tokens": 1000
|
||||
}`
|
||||
|
||||
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
|
||||
require.NoError(t, err)
|
||||
|
||||
var openaiRequest chatCompletionRequest
|
||||
err = json.Unmarshal(result, &openaiRequest)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should have one assistant message with tool_calls
|
||||
require.Len(t, openaiRequest.Messages, 1)
|
||||
assistantMsg := openaiRequest.Messages[0]
|
||||
assert.Equal(t, "assistant", assistantMsg.Role)
|
||||
assert.Equal(t, "I'll help you search for information.", assistantMsg.Content)
|
||||
|
||||
// Verify tool_calls format
|
||||
require.NotNil(t, assistantMsg.ToolCalls)
|
||||
require.Len(t, assistantMsg.ToolCalls, 1)
|
||||
|
||||
toolCall := assistantMsg.ToolCalls[0]
|
||||
assert.Equal(t, "toolu_01D7FLrfh4GYq7yT1ULFeyMV", toolCall.Id)
|
||||
assert.Equal(t, "function", toolCall.Type)
|
||||
assert.Equal(t, "web_search", toolCall.Function.Name)
|
||||
|
||||
// Verify arguments are properly JSON encoded
|
||||
var args map[string]interface{}
|
||||
err = json.Unmarshal([]byte(toolCall.Function.Arguments), &args)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Claude AI capabilities", args["query"])
|
||||
assert.Equal(t, float64(5), args["max_results"])
|
||||
})
|
||||
|
||||
t.Run("convert_tool_result_to_tool_message", func(t *testing.T) {
|
||||
// Test Claude tool_result conversion to OpenAI tool message format
|
||||
claudeRequest := `{
|
||||
"model": "anthropic/claude-sonnet-4",
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_01D7FLrfh4GYq7yT1ULFeyMV",
|
||||
"content": "Search results: Claude is an AI assistant created by Anthropic."
|
||||
}]
|
||||
}],
|
||||
"max_tokens": 1000
|
||||
}`
|
||||
|
||||
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
|
||||
require.NoError(t, err)
|
||||
|
||||
var openaiRequest chatCompletionRequest
|
||||
err = json.Unmarshal(result, &openaiRequest)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should have one tool message
|
||||
require.Len(t, openaiRequest.Messages, 1)
|
||||
toolMsg := openaiRequest.Messages[0]
|
||||
assert.Equal(t, "tool", toolMsg.Role)
|
||||
assert.Equal(t, "Search results: Claude is an AI assistant created by Anthropic.", toolMsg.Content)
|
||||
assert.Equal(t, "toolu_01D7FLrfh4GYq7yT1ULFeyMV", toolMsg.ToolCallId)
|
||||
})
|
||||
|
||||
t.Run("convert_multiple_tool_calls", func(t *testing.T) {
|
||||
// Test multiple tool_use in single message
|
||||
claudeRequest := `{
|
||||
"model": "anthropic/claude-sonnet-4",
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_search",
|
||||
"name": "web_search",
|
||||
"input": {"query": "weather"}
|
||||
}, {
|
||||
"type": "tool_use",
|
||||
"id": "toolu_calc",
|
||||
"name": "calculate",
|
||||
"input": {"expression": "2+2"}
|
||||
}]
|
||||
}],
|
||||
"max_tokens": 1000
|
||||
}`
|
||||
|
||||
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
|
||||
require.NoError(t, err)
|
||||
|
||||
var openaiRequest chatCompletionRequest
|
||||
err = json.Unmarshal(result, &openaiRequest)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should have one assistant message with multiple tool_calls
|
||||
require.Len(t, openaiRequest.Messages, 1)
|
||||
assistantMsg := openaiRequest.Messages[0]
|
||||
assert.Equal(t, "assistant", assistantMsg.Role)
|
||||
assert.Nil(t, assistantMsg.Content) // No text content, so should be null
|
||||
|
||||
// Verify multiple tool_calls
|
||||
require.NotNil(t, assistantMsg.ToolCalls)
|
||||
require.Len(t, assistantMsg.ToolCalls, 2)
|
||||
|
||||
// First tool call
|
||||
assert.Equal(t, "toolu_search", assistantMsg.ToolCalls[0].Id)
|
||||
assert.Equal(t, "web_search", assistantMsg.ToolCalls[0].Function.Name)
|
||||
|
||||
// Second tool call
|
||||
assert.Equal(t, "toolu_calc", assistantMsg.ToolCalls[1].Id)
|
||||
assert.Equal(t, "calculate", assistantMsg.ToolCalls[1].Function.Name)
|
||||
})
|
||||
|
||||
t.Run("convert_multiple_tool_results", func(t *testing.T) {
|
||||
// Test multiple tool_result messages
|
||||
claudeRequest := `{
|
||||
"model": "anthropic/claude-sonnet-4",
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_search",
|
||||
"content": "Weather: 25°C sunny"
|
||||
}, {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_calc",
|
||||
"content": "Result: 4"
|
||||
}]
|
||||
}],
|
||||
"max_tokens": 1000
|
||||
}`
|
||||
|
||||
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
|
||||
require.NoError(t, err)
|
||||
|
||||
var openaiRequest chatCompletionRequest
|
||||
err = json.Unmarshal(result, &openaiRequest)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should have two tool messages
|
||||
require.Len(t, openaiRequest.Messages, 2)
|
||||
|
||||
// First tool result
|
||||
toolMsg1 := openaiRequest.Messages[0]
|
||||
assert.Equal(t, "tool", toolMsg1.Role)
|
||||
assert.Equal(t, "Weather: 25°C sunny", toolMsg1.Content)
|
||||
assert.Equal(t, "toolu_search", toolMsg1.ToolCallId)
|
||||
|
||||
// Second tool result
|
||||
toolMsg2 := openaiRequest.Messages[1]
|
||||
assert.Equal(t, "tool", toolMsg2.Role)
|
||||
assert.Equal(t, "Result: 4", toolMsg2.Content)
|
||||
assert.Equal(t, "toolu_calc", toolMsg2.ToolCallId)
|
||||
})
|
||||
|
||||
t.Run("convert_mixed_text_and_tool_use", func(t *testing.T) {
|
||||
// Test message with both text and tool_use
|
||||
claudeRequest := `{
|
||||
"model": "anthropic/claude-sonnet-4",
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "Let me search for that information and do a calculation."
|
||||
}, {
|
||||
"type": "tool_use",
|
||||
"id": "toolu_search123",
|
||||
"name": "search_database",
|
||||
"input": {"table": "users", "limit": 10}
|
||||
}]
|
||||
}],
|
||||
"max_tokens": 1000
|
||||
}`
|
||||
|
||||
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
|
||||
require.NoError(t, err)
|
||||
|
||||
var openaiRequest chatCompletionRequest
|
||||
err = json.Unmarshal(result, &openaiRequest)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should have one assistant message with both content and tool_calls
|
||||
require.Len(t, openaiRequest.Messages, 1)
|
||||
assistantMsg := openaiRequest.Messages[0]
|
||||
assert.Equal(t, "assistant", assistantMsg.Role)
|
||||
assert.Equal(t, "Let me search for that information and do a calculation.", assistantMsg.Content)
|
||||
|
||||
// Should have tool_calls
|
||||
require.NotNil(t, assistantMsg.ToolCalls)
|
||||
require.Len(t, assistantMsg.ToolCalls, 1)
|
||||
assert.Equal(t, "toolu_search123", assistantMsg.ToolCalls[0].Id)
|
||||
assert.Equal(t, "search_database", assistantMsg.ToolCalls[0].Function.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeToOpenAIConverter_ConvertOpenAIResponseToClaude(t *testing.T) {
|
||||
converter := &ClaudeToOpenAIConverter{}
|
||||
|
||||
t.Run("convert_tool_calls_response", func(t *testing.T) {
|
||||
// Test OpenAI response with tool calls conversion to Claude format
|
||||
openaiResponse := `{
|
||||
"id": "gen-1756214072-tVFkPBV6lxee00IqNAC5",
|
||||
"provider": "Google",
|
||||
"model": "anthropic/claude-sonnet-4",
|
||||
"object": "chat.completion",
|
||||
"created": 1756214072,
|
||||
"choices": [{
|
||||
"logprobs": null,
|
||||
"finish_reason": "tool_calls",
|
||||
"native_finish_reason": "tool_calls",
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I'll analyze the README file to understand this project's purpose.",
|
||||
"refusal": null,
|
||||
"reasoning": null,
|
||||
"tool_calls": [{
|
||||
"id": "toolu_vrtx_017ijjgx8hpigatPzzPW59Wq",
|
||||
"index": 0,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "Read",
|
||||
"arguments": "{\"file_path\": \"/Users/zhangty/git/higress/README.md\"}"
|
||||
}
|
||||
}]
|
||||
}
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 14923,
|
||||
"completion_tokens": 81,
|
||||
"total_tokens": 15004
|
||||
}
|
||||
}`
|
||||
|
||||
result, err := converter.ConvertOpenAIResponseToClaude(nil, []byte(openaiResponse))
|
||||
require.NoError(t, err)
|
||||
|
||||
var claudeResponse claudeTextGenResponse
|
||||
err = json.Unmarshal(result, &claudeResponse)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify basic response fields
|
||||
assert.Equal(t, "gen-1756214072-tVFkPBV6lxee00IqNAC5", claudeResponse.Id)
|
||||
assert.Equal(t, "message", claudeResponse.Type)
|
||||
assert.Equal(t, "assistant", claudeResponse.Role)
|
||||
assert.Equal(t, "anthropic/claude-sonnet-4", claudeResponse.Model)
|
||||
assert.Equal(t, "tool_use", *claudeResponse.StopReason)
|
||||
|
||||
// Verify usage
|
||||
assert.Equal(t, 14923, claudeResponse.Usage.InputTokens)
|
||||
assert.Equal(t, 81, claudeResponse.Usage.OutputTokens)
|
||||
|
||||
// Verify content array has both text and tool_use
|
||||
require.Len(t, claudeResponse.Content, 2)
|
||||
|
||||
// First content should be text
|
||||
textContent := claudeResponse.Content[0]
|
||||
assert.Equal(t, "text", textContent.Type)
|
||||
assert.Equal(t, "I'll analyze the README file to understand this project's purpose.", textContent.Text)
|
||||
|
||||
// Second content should be tool_use
|
||||
toolContent := claudeResponse.Content[1]
|
||||
assert.Equal(t, "tool_use", toolContent.Type)
|
||||
assert.Equal(t, "toolu_vrtx_017ijjgx8hpigatPzzPW59Wq", toolContent.Id)
|
||||
assert.Equal(t, "Read", toolContent.Name)
|
||||
|
||||
// Verify tool arguments
|
||||
require.NotNil(t, toolContent.Input)
|
||||
assert.Equal(t, "/Users/zhangty/git/higress/README.md", toolContent.Input["file_path"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestClaudeToOpenAIConverter_ConvertThinkingConfig(t *testing.T) {
|
||||
converter := &ClaudeToOpenAIConverter{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
claudeRequest string
|
||||
expectedMaxTokens int
|
||||
expectedEffort string
|
||||
expectThinkingConfig bool
|
||||
}{
|
||||
{
|
||||
name: "thinking_enabled_low",
|
||||
claudeRequest: `{
|
||||
"model": "claude-sonnet-4",
|
||||
"max_tokens": 1000,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"thinking": {"type": "enabled", "budget_tokens": 2048}
|
||||
}`,
|
||||
expectedMaxTokens: 2048,
|
||||
expectedEffort: "low",
|
||||
expectThinkingConfig: true,
|
||||
},
|
||||
{
|
||||
name: "thinking_enabled_medium",
|
||||
claudeRequest: `{
|
||||
"model": "claude-sonnet-4",
|
||||
"max_tokens": 1000,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"thinking": {"type": "enabled", "budget_tokens": 8192}
|
||||
}`,
|
||||
expectedMaxTokens: 8192,
|
||||
expectedEffort: "medium",
|
||||
expectThinkingConfig: true,
|
||||
},
|
||||
{
|
||||
name: "thinking_enabled_high",
|
||||
claudeRequest: `{
|
||||
"model": "claude-sonnet-4",
|
||||
"max_tokens": 1000,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"thinking": {"type": "enabled", "budget_tokens": 20480}
|
||||
}`,
|
||||
expectedMaxTokens: 20480,
|
||||
expectedEffort: "high",
|
||||
expectThinkingConfig: true,
|
||||
},
|
||||
{
|
||||
name: "thinking_disabled",
|
||||
claudeRequest: `{
|
||||
"model": "claude-sonnet-4",
|
||||
"max_tokens": 1000,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"thinking": {"type": "disabled"}
|
||||
}`,
|
||||
expectedMaxTokens: 0,
|
||||
expectedEffort: "",
|
||||
expectThinkingConfig: false,
|
||||
},
|
||||
{
|
||||
name: "no_thinking",
|
||||
claudeRequest: `{
|
||||
"model": "claude-sonnet-4",
|
||||
"max_tokens": 1000,
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
}`,
|
||||
expectedMaxTokens: 0,
|
||||
expectedEffort: "",
|
||||
expectThinkingConfig: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(tt.claudeRequest))
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
|
||||
var openaiRequest chatCompletionRequest
|
||||
err = json.Unmarshal(result, &openaiRequest)
|
||||
assert.NoError(t, err)
|
||||
|
||||
if tt.expectThinkingConfig {
|
||||
assert.Equal(t, tt.expectedMaxTokens, openaiRequest.ReasoningMaxTokens)
|
||||
assert.Equal(t, tt.expectedEffort, openaiRequest.ReasoningEffort)
|
||||
} else {
|
||||
assert.Equal(t, 0, openaiRequest.ReasoningMaxTokens)
|
||||
assert.Equal(t, "", openaiRequest.ReasoningEffort)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeToOpenAIConverter_ConvertReasoningResponseToClaude(t *testing.T) {
|
||||
converter := &ClaudeToOpenAIConverter{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
openaiResponse string
|
||||
expectThinking bool
|
||||
expectedText string
|
||||
}{
|
||||
{
|
||||
name: "response_with_reasoning_content",
|
||||
openaiResponse: `{
|
||||
"id": "chatcmpl-test123",
|
||||
"object": "chat.completion",
|
||||
"created": 1699999999,
|
||||
"model": "gpt-4o",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Based on my analysis, the answer is 42.",
|
||||
"reasoning_content": "Let me think about this step by step:\n1. The question asks about the meaning of life\n2. According to Douglas Adams, the answer is 42\n3. Therefore, 42 is the correct answer"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 20,
|
||||
"total_tokens": 30
|
||||
}
|
||||
}`,
|
||||
expectThinking: true,
|
||||
expectedText: "Based on my analysis, the answer is 42.",
|
||||
},
|
||||
{
|
||||
name: "response_with_reasoning_field",
|
||||
openaiResponse: `{
|
||||
"id": "chatcmpl-test789",
|
||||
"object": "chat.completion",
|
||||
"created": 1699999999,
|
||||
"model": "gpt-4o",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Based on my analysis, the answer is 42.",
|
||||
"reasoning": "Let me think about this step by step:\n1. The question asks about the meaning of life\n2. According to Douglas Adams, the answer is 42\n3. Therefore, 42 is the correct answer"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 20,
|
||||
"total_tokens": 30
|
||||
}
|
||||
}`,
|
||||
expectThinking: true,
|
||||
expectedText: "Based on my analysis, the answer is 42.",
|
||||
},
|
||||
{
|
||||
name: "response_without_reasoning_content",
|
||||
openaiResponse: `{
|
||||
"id": "chatcmpl-test456",
|
||||
"object": "chat.completion",
|
||||
"created": 1699999999,
|
||||
"model": "gpt-4o",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "The answer is 42."
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 5,
|
||||
"completion_tokens": 10,
|
||||
"total_tokens": 15
|
||||
}
|
||||
}`,
|
||||
expectThinking: false,
|
||||
expectedText: "The answer is 42.",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := converter.ConvertOpenAIResponseToClaude(nil, []byte(tt.openaiResponse))
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
|
||||
var claudeResponse claudeTextGenResponse
|
||||
err = json.Unmarshal(result, &claudeResponse)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify response structure
|
||||
assert.Equal(t, "message", claudeResponse.Type)
|
||||
assert.Equal(t, "assistant", claudeResponse.Role)
|
||||
assert.NotEmpty(t, claudeResponse.Id) // ID should be present
|
||||
|
||||
if tt.expectThinking {
|
||||
// Should have both thinking and text content
|
||||
assert.Len(t, claudeResponse.Content, 2)
|
||||
|
||||
// First should be thinking
|
||||
thinkingContent := claudeResponse.Content[0]
|
||||
assert.Equal(t, "thinking", thinkingContent.Type)
|
||||
assert.Equal(t, "", thinkingContent.Signature) // OpenAI doesn't provide signature
|
||||
assert.Contains(t, thinkingContent.Thinking, "Let me think about this step by step")
|
||||
|
||||
// Second should be text
|
||||
textContent := claudeResponse.Content[1]
|
||||
assert.Equal(t, "text", textContent.Type)
|
||||
assert.Equal(t, tt.expectedText, textContent.Text)
|
||||
} else {
|
||||
// Should only have text content
|
||||
assert.Len(t, claudeResponse.Content, 1)
|
||||
|
||||
textContent := claudeResponse.Content[0]
|
||||
assert.Equal(t, "text", textContent.Type)
|
||||
assert.Equal(t, tt.expectedText, textContent.Text)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -5,24 +5,21 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
)
|
||||
|
||||
// deepseekProvider is the provider for deepseek Ai service.
|
||||
|
||||
const (
|
||||
deepseekDomain = "api.deepseek.com"
|
||||
// TODO: docs: https://api-docs.deepseek.com/api/create-chat-completion
|
||||
// accourding to the docs, the path should be /chat/completions, need to be verified
|
||||
deepseekChatCompletionPath = "/v1/chat/completions"
|
||||
deepseekDomain = "api.deepseek.com"
|
||||
deepseekAnthropicMessagesPath = "/anthropic/v1/messages"
|
||||
)
|
||||
|
||||
type deepseekProviderInitializer struct {
|
||||
}
|
||||
type deepseekProviderInitializer struct{}
|
||||
|
||||
func (m *deepseekProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
if config.apiTokens == nil || len(config.apiTokens) == 0 {
|
||||
if len(config.apiTokens) == 0 {
|
||||
return errors.New("no apiToken found in provider config")
|
||||
}
|
||||
return nil
|
||||
@@ -30,7 +27,9 @@ func (m *deepseekProviderInitializer) ValidateConfig(config *ProviderConfig) err
|
||||
|
||||
func (m *deepseekProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): deepseekChatCompletionPath,
|
||||
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
|
||||
string(ApiNameModels): PathOpenAIModels,
|
||||
string(ApiNameAnthropicMessages): deepseekAnthropicMessagesPath,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -146,7 +146,7 @@ func (c *ProviderConfig) SetApiTokensFailover(activeProvider Provider) error {
|
||||
return fmt.Errorf("failed to init apiTokens: %v", err)
|
||||
}
|
||||
|
||||
wrapper.RegisteTickFunc(c.failover.healthCheckInterval, func() {
|
||||
wrapper.RegisterTickFunc(c.failover.healthCheckInterval, func() {
|
||||
// Only the Wasm VM that successfully acquires the lease will perform health check
|
||||
if c.isFailoverEnabled() && c.tryAcquireOrRenewLease(vmID) {
|
||||
log.Debugf("Successfully acquired or renewed lease for %v: %v", vmID, c.GetType())
|
||||
|
||||
@@ -1,15 +1,21 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/google/uuid"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
@@ -28,6 +34,12 @@ const (
|
||||
geminiImageGenerationPath = "predict"
|
||||
)
|
||||
|
||||
var geminiThinkingModels = map[string]bool{
|
||||
"gemini-2.5-pro": true,
|
||||
"gemini-2.5-flash": true,
|
||||
"gemini-2.5-flash-lite": true,
|
||||
}
|
||||
|
||||
type geminiProviderInitializer struct{}
|
||||
|
||||
func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
@@ -53,12 +65,17 @@ func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provi
|
||||
return &geminiProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
client: wrapper.NewClusterClient(wrapper.RouteCluster{
|
||||
Host: geminiDomain,
|
||||
}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type geminiProvider struct {
|
||||
config ProviderConfig
|
||||
contextCache *contextCache
|
||||
|
||||
client wrapper.HttpClient
|
||||
}
|
||||
|
||||
func (g *geminiProvider) GetProviderType() string {
|
||||
@@ -77,11 +94,47 @@ func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "")
|
||||
}
|
||||
|
||||
// to support the multimodal for gemini, we can't reuse the config's handleRequestBody
|
||||
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !g.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body)
|
||||
|
||||
if g.config.firstByteTimeout != 0 && g.config.isStreamingAPI(apiName, body) {
|
||||
err := proxywasm.ReplaceHttpRequestHeader("x-envoy-upstream-rq-first-byte-timeout-ms",
|
||||
strconv.FormatUint(uint64(g.config.firstByteTimeout), 10))
|
||||
if err != nil {
|
||||
log.Errorf("failed to set timeout header: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if g.config.IsOriginal() {
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
headers := util.GetRequestHeaders()
|
||||
request, err := g.TransformRequestBodyHeaders(ctx, apiName, body, headers)
|
||||
if err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
util.ReplaceRequestHeaders(headers)
|
||||
|
||||
if apiName == ApiNameChatCompletion {
|
||||
if g.config.context != nil {
|
||||
err = g.contextCache.GetContextFromFile(ctx, g, body)
|
||||
if err == nil {
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
}
|
||||
|
||||
if action, err := g.processImageURL(ctx, request); err != nil {
|
||||
return action, err
|
||||
} else {
|
||||
return action, replaceRequestBody(request)
|
||||
}
|
||||
|
||||
}
|
||||
return types.ActionContinue, replaceRequestBody(request)
|
||||
}
|
||||
|
||||
func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
||||
@@ -365,6 +418,7 @@ func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest)
|
||||
Threshold: threshold,
|
||||
})
|
||||
}
|
||||
|
||||
geminiRequest := geminiGenerationContentRequest{
|
||||
Contents: make([]geminiChatContent, 0, len(request.Messages)),
|
||||
SafetySettings: safetySettings,
|
||||
@@ -379,6 +433,13 @@ func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest)
|
||||
},
|
||||
}
|
||||
|
||||
if geminiThinkingModels[request.Model] {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &geminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
ThinkingBudget: g.config.geminiThinkingBudget,
|
||||
}
|
||||
}
|
||||
|
||||
if request.Tools != nil {
|
||||
functions := make([]function, 0, len(request.Tools))
|
||||
for _, tool := range request.Tools {
|
||||
@@ -393,12 +454,21 @@ func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest)
|
||||
// shouldAddDummyModelMessage := false
|
||||
for _, message := range request.Messages {
|
||||
content := geminiChatContent{
|
||||
Role: message.Role,
|
||||
Parts: []geminiPart{
|
||||
{
|
||||
Text: message.StringContent(),
|
||||
},
|
||||
},
|
||||
Role: message.Role,
|
||||
Parts: []geminiPart{},
|
||||
}
|
||||
|
||||
for _, c := range message.ParseContent() {
|
||||
switch c.Type {
|
||||
case contentTypeText:
|
||||
content.Parts = append(content.Parts, geminiPart{
|
||||
Text: c.Text,
|
||||
})
|
||||
case contentTypeImageUrl:
|
||||
content.Parts = append(content.Parts, g.handleContentTypeImageUrl(c.ImageUrl))
|
||||
default:
|
||||
log.Debugf("currently gemini did not support this type: %s", c.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// there's no assistant role in gemini and API shall vomit if role is not user or model
|
||||
@@ -417,6 +487,176 @@ func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest)
|
||||
return &geminiRequest
|
||||
}
|
||||
|
||||
func (g *geminiProvider) countImageUrl(request *geminiGenerationContentRequest) int {
|
||||
totalImages := 0
|
||||
for _, c := range request.Contents {
|
||||
for _, p := range c.Parts {
|
||||
if p.InlineData != nil && g.isUrl(p.InlineData.Data) {
|
||||
totalImages += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
return totalImages
|
||||
}
|
||||
|
||||
func (g *geminiProvider) processImageURL(ctx wrapper.HttpContext, body []byte) (types.Action, error) {
|
||||
request := &geminiGenerationContentRequest{}
|
||||
err := json.Unmarshal(body, request)
|
||||
if err != nil {
|
||||
log.Errorf("failed to unmarshal geminiGenerationRequest while handle multi modal")
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
var totalImages int
|
||||
if totalImages = g.countImageUrl(request); totalImages == 0 {
|
||||
// there are no images return directly
|
||||
return types.ActionContinue, replaceRequestBody(body)
|
||||
}
|
||||
|
||||
if err := g.processImageURLWithCallback(ctx, body, totalImages, func(body []byte, err error) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed to get image while handle multi modal: %v", err)
|
||||
util.ErrorHandler("ai-proxy.gemini.fetch_image_failed", err)
|
||||
return
|
||||
}
|
||||
// replace the request
|
||||
if err := replaceRequestBody(body); err != nil {
|
||||
util.ErrorHandler("ai-proxy.gemini.replace_request_body_failed", err)
|
||||
}
|
||||
}); err != nil {
|
||||
return types.ActionContinue, err
|
||||
}
|
||||
|
||||
return types.ActionPause, nil
|
||||
}
|
||||
|
||||
func (g *geminiProvider) processImageURLWithCallback(ctx wrapper.HttpContext, body []byte, totalImages int, callback func([]byte, error)) error {
|
||||
request := &geminiGenerationContentRequest{}
|
||||
err := json.Unmarshal(body, request)
|
||||
if err != nil {
|
||||
log.Errorf("failed to unmarshal geminiGenerationRequest while handle multi modal: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
pending := totalImages
|
||||
var callbackErr []error
|
||||
|
||||
for ci, c := range request.Contents {
|
||||
for pi := range c.Parts {
|
||||
p := &request.Contents[ci].Parts[pi]
|
||||
if p.InlineData != nil && g.isUrl(p.InlineData.Data) {
|
||||
g.getImageInlineDataWithCallback(p.InlineData.Data, func(gid *geminiInlineData, err error) {
|
||||
if err != nil {
|
||||
log.Errorf("image %s fetch failed: %v", p.InlineData.Data, err)
|
||||
callbackErr = append(callbackErr, err)
|
||||
} else {
|
||||
*p.InlineData = *gid
|
||||
}
|
||||
|
||||
pending -= 1
|
||||
if pending == 0 {
|
||||
body, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
log.Errorf("failed to marshal request while processImageURL: %v", err)
|
||||
callbackErr = append(callbackErr, err)
|
||||
}
|
||||
callback(body, errors.Join(callbackErr...))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *geminiProvider) handleContentTypeImageUrl(c *chatMessageContentImageUrl) (part geminiPart) {
|
||||
if g.isUrl(c.Url) {
|
||||
part.InlineData = &geminiInlineData{
|
||||
Data: c.Url,
|
||||
}
|
||||
return
|
||||
}
|
||||
part.InlineData = g.baseStr2InlineData(c.Url)
|
||||
return
|
||||
}
|
||||
|
||||
func (g *geminiProvider) isUrl(raw string) bool {
|
||||
u, err := url.Parse(raw)
|
||||
return err == nil && (u.Scheme == "http" || u.Scheme == "https")
|
||||
}
|
||||
|
||||
func (g *geminiProvider) baseStr2InlineData(baseStr string) *geminiInlineData {
|
||||
if strings.HasPrefix(baseStr, "data:") {
|
||||
p := strings.SplitN(baseStr, ";", 2)
|
||||
if len(p) != 2 {
|
||||
log.Errorf("invalid base64 string: %s", p)
|
||||
return nil
|
||||
}
|
||||
|
||||
mime := strings.TrimPrefix(p[0], "data:")
|
||||
baseData := strings.TrimPrefix(p[1], "base64,")
|
||||
return &geminiInlineData{
|
||||
MimeType: mime,
|
||||
Data: baseData,
|
||||
}
|
||||
}
|
||||
log.Errorf("invalid base64 string: %s", baseStr)
|
||||
return &geminiInlineData{
|
||||
MimeType: "",
|
||||
Data: "",
|
||||
}
|
||||
}
|
||||
|
||||
func (g *geminiProvider) getImageInlineDataWithCallback(raw string, callback func(*geminiInlineData, error)) {
|
||||
|
||||
responseCallback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
if statusCode != http.StatusOK {
|
||||
callback(nil, fmt.Errorf("get %s failed, status: %v", raw, statusCode))
|
||||
return
|
||||
}
|
||||
resReader := bytes.NewReader(responseBody)
|
||||
const maxSize = 100 << 20
|
||||
data, err := io.ReadAll(io.LimitReader(resReader, maxSize+1))
|
||||
if err != nil {
|
||||
callback(nil, fmt.Errorf("read %v response data failed: %v", raw, err))
|
||||
return
|
||||
}
|
||||
if len(data) > maxSize {
|
||||
callback(nil, fmt.Errorf("%v exceed max image size 100MB", raw))
|
||||
return
|
||||
}
|
||||
|
||||
mimeType := http.DetectContentType(data)
|
||||
base64Data := base64.StdEncoding.EncodeToString(data)
|
||||
|
||||
callback(&geminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: base64Data,
|
||||
}, nil)
|
||||
}
|
||||
|
||||
timeout := (time.Second * 30).Milliseconds()
|
||||
|
||||
headers := [][2]string{
|
||||
{"Accept", "image/*"},
|
||||
{"User-Agent", "Mozilla/5.0 (compatible; AI-Proxy/1.0)"},
|
||||
{"Referer", "https://www.google.com/"},
|
||||
}
|
||||
if g.client == nil {
|
||||
log.Error("client is nil")
|
||||
return
|
||||
}
|
||||
err := g.client.Get(raw, headers, responseCallback, uint32(timeout))
|
||||
if err != nil {
|
||||
log.Errorf("failed to get image %s data", raw)
|
||||
callback(nil, fmt.Errorf("failed to get image %s", raw))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (g *geminiProvider) setSystemContent(request *geminiGenerationContentRequest, content string) {
|
||||
systemContents := []geminiChatContent{{
|
||||
Role: roleUser,
|
||||
|
||||
75
plugins/wasm-go/extensions/ai-proxy/provider/grok.go
Normal file
75
plugins/wasm-go/extensions/ai-proxy/provider/grok.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
// grokProvider is the provider for Grok service.
|
||||
const (
|
||||
grokDomain = "api.x.ai"
|
||||
grokChatCompletionPath = "/v1/chat/completions"
|
||||
)
|
||||
|
||||
type grokProviderInitializer struct{}
|
||||
|
||||
func (g *grokProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
if config.apiTokens == nil || len(config.apiTokens) == 0 {
|
||||
return errors.New("no apiToken found in provider config")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *grokProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): grokChatCompletionPath,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *grokProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(g.DefaultCapabilities())
|
||||
return &grokProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type grokProvider struct {
|
||||
config ProviderConfig
|
||||
contextCache *contextCache
|
||||
}
|
||||
|
||||
func (g *grokProvider) GetProviderType() string {
|
||||
return providerTypeGrok
|
||||
}
|
||||
|
||||
func (g *grokProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
g.config.handleRequestHeaders(g, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *grokProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !g.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (g *grokProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), g.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, grokDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
func (g *grokProvider) GetApiName(path string) ApiName {
|
||||
if strings.Contains(path, grokChatCompletionPath) {
|
||||
return ApiNameChatCompletion
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -8,10 +8,10 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/log"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
@@ -29,7 +29,12 @@ const (
|
||||
reasoningEndTag = "</think>"
|
||||
)
|
||||
|
||||
type NonOpenAIStyleOptions struct {
|
||||
ReasoningMaxTokens int `json:"reasoning_max_tokens,omitempty"`
|
||||
}
|
||||
|
||||
type chatCompletionRequest struct {
|
||||
NonOpenAIStyleOptions
|
||||
Messages []chatMessage `json:"messages"`
|
||||
Model string `json:"model"`
|
||||
Store bool `json:"store,omitempty"`
|
||||
@@ -169,8 +174,11 @@ type chatMessage struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content any `json:"content,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
Reasoning string `json:"reasoning,omitempty"` // For streaming responses
|
||||
ToolCalls []toolCall `json:"tool_calls,omitempty"`
|
||||
FunctionCall *functionCall `json:"function_call,omitempty"` // For legacy OpenAI format
|
||||
Refusal string `json:"refusal,omitempty"`
|
||||
ToolCallId string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
func (m *chatMessage) handleNonStreamingReasoningContent(reasoningContentMode string) {
|
||||
@@ -377,14 +385,14 @@ func (m *chatMessage) ParseContent() []chatMessageContent {
|
||||
}
|
||||
|
||||
type toolCall struct {
|
||||
Index int `json:"index"`
|
||||
Id string `json:"id"`
|
||||
Index int `json:"index,omitempty"`
|
||||
Id string `json:"id,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Function functionCall `json:"function"`
|
||||
}
|
||||
|
||||
type functionCall struct {
|
||||
Id string `json:"id"`
|
||||
Id string `json:"id,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
117
plugins/wasm-go/extensions/ai-proxy/provider/openrouter.go
Normal file
117
plugins/wasm-go/extensions/ai-proxy/provider/openrouter.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/wrapper"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// openrouterProvider is the provider for OpenRouter service.
|
||||
const (
|
||||
openrouterDomain = "openrouter.ai"
|
||||
openrouterChatCompletionPath = "/api/v1/chat/completions"
|
||||
openrouterCompletionPath = "/api/v1/completions"
|
||||
)
|
||||
|
||||
type openrouterProviderInitializer struct{}
|
||||
|
||||
func (o *openrouterProviderInitializer) ValidateConfig(config *ProviderConfig) error {
|
||||
if len(config.apiTokens) == 0 {
|
||||
return errors.New("no apiToken found in provider config")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *openrouterProviderInitializer) DefaultCapabilities() map[string]string {
|
||||
return map[string]string{
|
||||
string(ApiNameChatCompletion): openrouterChatCompletionPath,
|
||||
string(ApiNameCompletion): openrouterCompletionPath,
|
||||
}
|
||||
}
|
||||
|
||||
func (o *openrouterProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
|
||||
config.setDefaultCapabilities(o.DefaultCapabilities())
|
||||
return &openrouterProvider{
|
||||
config: config,
|
||||
contextCache: createContextCache(&config),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type openrouterProvider struct {
|
||||
config ProviderConfig
|
||||
contextCache *contextCache
|
||||
}
|
||||
|
||||
func (o *openrouterProvider) GetProviderType() string {
|
||||
return providerTypeOpenRouter
|
||||
}
|
||||
|
||||
func (o *openrouterProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
o.config.handleRequestHeaders(o, ctx, apiName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *openrouterProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !o.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return o.config.handleRequestBody(o, o.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (o *openrouterProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), o.config.capabilities)
|
||||
util.OverwriteRequestHostHeader(headers, openrouterDomain)
|
||||
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+o.config.GetApiTokenInUse(ctx))
|
||||
headers.Del("Content-Length")
|
||||
}
|
||||
|
||||
func (o *openrouterProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
if apiName != ApiNameChatCompletion {
|
||||
return o.config.defaultTransformRequestBody(ctx, apiName, body)
|
||||
}
|
||||
|
||||
// Check if ReasoningMaxTokens exists in the request body
|
||||
reasoningMaxTokens := gjson.GetBytes(body, "reasoning_max_tokens")
|
||||
if !reasoningMaxTokens.Exists() || reasoningMaxTokens.Int() == 0 {
|
||||
// No reasoning_max_tokens, use default transformation
|
||||
return o.config.defaultTransformRequestBody(ctx, apiName, body)
|
||||
}
|
||||
|
||||
// Clear reasoning_effort field if it exists
|
||||
modifiedBody, err := sjson.DeleteBytes(body, "reasoning_effort")
|
||||
if err != nil {
|
||||
// If delete fails, continue with original body
|
||||
modifiedBody = body
|
||||
}
|
||||
|
||||
// Set reasoning.max_tokens to the value of reasoning_max_tokens
|
||||
modifiedBody, err = sjson.SetBytes(modifiedBody, "reasoning.max_tokens", reasoningMaxTokens.Int())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Remove the original reasoning_max_tokens field
|
||||
modifiedBody, err = sjson.DeleteBytes(modifiedBody, "reasoning_max_tokens")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Apply default model mapping
|
||||
return o.config.defaultTransformRequestBody(ctx, apiName, modifiedBody)
|
||||
}
|
||||
|
||||
func (o *openrouterProvider) GetApiName(path string) ApiName {
|
||||
if strings.Contains(path, openrouterChatCompletionPath) {
|
||||
return ApiNameChatCompletion
|
||||
}
|
||||
if strings.Contains(path, openrouterCompletionPath) {
|
||||
return ApiNameCompletion
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package provider
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"path"
|
||||
@@ -107,6 +108,7 @@ const (
|
||||
providerTypeQwen = "qwen"
|
||||
providerTypeOpenAI = "openai"
|
||||
providerTypeGroq = "groq"
|
||||
providerTypeGrok = "grok"
|
||||
providerTypeBaichuan = "baichuan"
|
||||
providerTypeYi = "yi"
|
||||
providerTypeDeepSeek = "deepseek"
|
||||
@@ -129,6 +131,7 @@ const (
|
||||
providerTypeDify = "dify"
|
||||
providerTypeBedrock = "bedrock"
|
||||
providerTypeVertex = "vertex"
|
||||
providerTypeOpenRouter = "openrouter"
|
||||
|
||||
protocolOpenAI = "openai"
|
||||
protocolOriginal = "original"
|
||||
@@ -136,9 +139,11 @@ const (
|
||||
roleSystem = "system"
|
||||
roleAssistant = "assistant"
|
||||
roleUser = "user"
|
||||
roleTool = "tool"
|
||||
|
||||
finishReasonStop = "stop"
|
||||
finishReasonLength = "length"
|
||||
finishReasonStop = "stop"
|
||||
finishReasonLength = "length"
|
||||
finishReasonToolCall = "tool_calls"
|
||||
|
||||
ctxKeyIncrementalStreaming = "incrementalStreaming"
|
||||
ctxKeyApiKey = "apiKey"
|
||||
@@ -182,6 +187,7 @@ var (
|
||||
providerTypeQwen: &qwenProviderInitializer{},
|
||||
providerTypeOpenAI: &openaiProviderInitializer{},
|
||||
providerTypeGroq: &groqProviderInitializer{},
|
||||
providerTypeGrok: &grokProviderInitializer{},
|
||||
providerTypeBaichuan: &baichuanProviderInitializer{},
|
||||
providerTypeYi: &yiProviderInitializer{},
|
||||
providerTypeDeepSeek: &deepseekProviderInitializer{},
|
||||
@@ -204,6 +210,7 @@ var (
|
||||
providerTypeDify: &difyProviderInitializer{},
|
||||
providerTypeBedrock: &bedrockProviderInitializer{},
|
||||
providerTypeVertex: &vertexProviderInitializer{},
|
||||
providerTypeOpenRouter: &openrouterProviderInitializer{},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -344,6 +351,9 @@ type ProviderConfig struct {
|
||||
// @Title zh-CN Gemini AI内容过滤和安全级别设定
|
||||
// @Description zh-CN 仅适用于 Gemini AI 服务。参考:https://ai.google.dev/gemini-api/docs/safety-settings
|
||||
geminiSafetySetting map[string]string `required:"false" yaml:"geminiSafetySetting" json:"geminiSafetySetting"`
|
||||
// @Title zh-CN Gemini Thinking Budget 配置
|
||||
// @Description zh-CN 仅适用于 Gemini AI 服务,用于控制思考预算
|
||||
geminiThinkingBudget int64 `required:"false" yaml:"geminiThinkingBudget" json:"geminiThinkingBudget"`
|
||||
// @Title zh-CN Vertex AI访问区域
|
||||
// @Description zh-CN 仅适用于Vertex AI服务。如需查看支持的区域的完整列表,请参阅https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations?hl=zh-cn#available-regions
|
||||
vertexRegion string `required:"false" yaml:"vertexRegion" json:"vertexRegion"`
|
||||
@@ -472,6 +482,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
c.geminiSafetySetting[k] = v.String()
|
||||
}
|
||||
}
|
||||
c.geminiThinkingBudget = json.Get("geminiThinkingBudget").Int()
|
||||
c.vertexRegion = json.Get("vertexRegion").String()
|
||||
c.vertexProjectId = json.Get("vertexProjectId").String()
|
||||
c.vertexAuthKey = json.Get("vertexAuthKey").String()
|
||||
@@ -514,10 +525,9 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
c.reasoningContentMode = strings.ToLower(c.reasoningContentMode)
|
||||
switch c.reasoningContentMode {
|
||||
case reasoningBehaviorPassThrough, reasoningBehaviorIgnore, reasoningBehaviorConcat:
|
||||
break
|
||||
// valid values, no action needed
|
||||
default:
|
||||
c.reasoningContentMode = reasoningBehaviorPassThrough
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
@@ -824,6 +834,10 @@ func (c *ProviderConfig) isSupportedAPI(apiName ApiName) bool {
|
||||
return exist
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) IsSupportedAPI(apiName ApiName) bool {
|
||||
return c.isSupportedAPI(apiName)
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) setDefaultCapabilities(capabilities map[string]string) {
|
||||
for capability, path := range capabilities {
|
||||
c.capabilities[capability] = path
|
||||
@@ -847,8 +861,22 @@ func (c *ProviderConfig) handleRequestBody(
|
||||
return types.ActionContinue, nil
|
||||
}
|
||||
|
||||
// use openai protocol
|
||||
var err error
|
||||
|
||||
// handle claude protocol input - auto-detect based on conversion marker
|
||||
// If main.go detected a Claude request that needs conversion, convert the body
|
||||
needClaudeConversion, _ := ctx.GetContext("needClaudeResponseConversion").(bool)
|
||||
if needClaudeConversion {
|
||||
// Convert Claude protocol to OpenAI protocol
|
||||
converter := &ClaudeToOpenAIConverter{}
|
||||
body, err = converter.ConvertClaudeRequestToOpenAI(body)
|
||||
if err != nil {
|
||||
return types.ActionContinue, fmt.Errorf("failed to convert claude request to openai: %v", err)
|
||||
}
|
||||
log.Debugf("[Auto Protocol] converted Claude request body to OpenAI format")
|
||||
}
|
||||
|
||||
// use openai protocol (either original openai or converted from claude)
|
||||
if handler, ok := provider.(TransformRequestBodyHandler); ok {
|
||||
body, err = handler.TransformRequestBody(ctx, apiName, body)
|
||||
} else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {
|
||||
|
||||
718
plugins/wasm-go/extensions/ai-proxy/test/ai360.go
Normal file
718
plugins/wasm-go/extensions/ai-proxy/test/ai360.go
Normal file
@@ -0,0 +1,718 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// 测试配置:基本ai360配置
|
||||
var basicAi360Config = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "ai360",
|
||||
"apiTokens": []string{"sk-ai360-test123456789"},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "360GPT_S2_V9",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:ai360多模型配置
|
||||
var ai360MultiModelConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "ai360",
|
||||
"apiTokens": []string{"sk-ai360-multi-model"},
|
||||
"modelMapping": map[string]string{
|
||||
"gpt-3.5-turbo": "360GPT_S2_V9",
|
||||
"gpt-4": "360GPT_S2_V9",
|
||||
"text-embedding-ada-002": "360Embedding_Text_V1",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:无效ai360配置(缺少apiToken)
|
||||
var invalidAi360Config = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "ai360",
|
||||
// 缺少apiTokens
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:ai360自定义域名配置
|
||||
var ai360CustomDomainConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "ai360",
|
||||
"apiTokens": []string{"sk-ai360-custom-domain"},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "360GPT_S2_V9",
|
||||
},
|
||||
"openaiCustomUrl": "https://custom.ai360.cn/v1",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:ai360完整配置(包含failover等字段)
|
||||
var completeAi360Config = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "ai360",
|
||||
"apiTokens": []string{"sk-ai360-complete"},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "360GPT_S2_V9",
|
||||
},
|
||||
"failover": map[string]interface{}{
|
||||
"enabled": false,
|
||||
},
|
||||
"retryOnFailure": map[string]interface{}{
|
||||
"enabled": false,
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
func RunAi360ParseConfigTests(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
// 测试基本ai360配置解析
|
||||
t.Run("basic ai360 config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicAi360Config)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试ai360多模型配置解析
|
||||
t.Run("ai360 multi model config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(ai360MultiModelConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试无效ai360配置(缺少apiToken)
|
||||
t.Run("invalid ai360 config - missing api token", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(invalidAi360Config)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||||
})
|
||||
|
||||
// 测试ai360自定义域名配置解析
|
||||
t.Run("ai360 custom domain config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(ai360CustomDomainConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试ai360完整配置解析
|
||||
t.Run("ai360 complete config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(completeAi360Config)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunAi360OnHttpRequestHeadersTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试ai360请求头处理(聊天完成接口)
|
||||
t.Run("ai360 chat completion request headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicAi360Config)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 应该返回HeaderStopIteration,因为需要处理请求体
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 验证请求头是否被正确处理
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// 验证Host是否被改为ai360域名
|
||||
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
|
||||
require.True(t, hasHost, "Host header should exist")
|
||||
require.Equal(t, "api.360.cn", hostValue, "Host should be changed to ai360 domain")
|
||||
|
||||
// 验证Authorization是否被设置
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.Contains(t, authValue, "sk-ai360-test123456789", "Authorization should contain ai360 API token")
|
||||
|
||||
// 验证Path是否被正确处理
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
// ai360应该支持聊天完成接口,路径可能被转换
|
||||
require.Contains(t, pathValue, "/v1/chat/completions", "Path should contain chat completions endpoint")
|
||||
|
||||
// 检查是否有相关的处理日志
|
||||
debugLogs := host.GetDebugLogs()
|
||||
hasAi360Logs := false
|
||||
for _, log := range debugLogs {
|
||||
if strings.Contains(log, "ai360") {
|
||||
hasAi360Logs = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasAi360Logs, "Should have ai360 processing logs")
|
||||
})
|
||||
|
||||
// 测试ai360请求头处理(嵌入接口)
|
||||
t.Run("ai360 embeddings request headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicAi360Config)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/embeddings"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 验证嵌入接口的请求头处理
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// 验证Host转换
|
||||
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
|
||||
require.True(t, hasHost)
|
||||
require.Equal(t, "api.360.cn", hostValue)
|
||||
|
||||
// 验证Path转换
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath)
|
||||
require.Contains(t, pathValue, "/v1/embeddings", "Path should contain embeddings endpoint")
|
||||
|
||||
// 验证Authorization设置
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist for embeddings")
|
||||
require.Contains(t, authValue, "sk-ai360-test123456789", "Authorization should contain ai360 API token")
|
||||
})
|
||||
|
||||
// 测试ai360请求头处理(不支持的接口)
|
||||
t.Run("ai360 unsupported api request headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicAi360Config)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/generations"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 验证不支持的接口处理
|
||||
// 即使是不支持的接口,基本的请求头转换仍然应该执行
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// Host仍然应该被转换
|
||||
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
|
||||
require.True(t, hasHost)
|
||||
require.Equal(t, "api.360.cn", hostValue)
|
||||
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunAi360OnHttpRequestBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试ai360请求体处理(聊天完成接口)
|
||||
t.Run("ai360 chat completion request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicAi360Config)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"test"}]}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证请求体是否被正确处理
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
// 验证模型名称是否被正确映射
|
||||
// ai360 provider会将模型名称从gpt-3.5-turbo映射为360GPT_S2_V9
|
||||
require.Contains(t, string(processedBody), "360GPT_S2_V9", "Model name should be mapped to ai360 format")
|
||||
|
||||
// 检查是否有相关的处理日志
|
||||
debugLogs := host.GetDebugLogs()
|
||||
infoLogs := host.GetInfoLogs()
|
||||
|
||||
// 验证是否有ai360相关的处理日志
|
||||
hasAi360Logs := false
|
||||
for _, log := range debugLogs {
|
||||
if strings.Contains(log, "ai360") {
|
||||
hasAi360Logs = true
|
||||
break
|
||||
}
|
||||
}
|
||||
for _, log := range infoLogs {
|
||||
if strings.Contains(log, "ai360") {
|
||||
hasAi360Logs = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasAi360Logs, "Should have ai360 processing logs")
|
||||
})
|
||||
|
||||
// 测试ai360请求体处理(嵌入接口)
|
||||
t.Run("ai360 embeddings request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicAi360Config)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/embeddings"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"text-embedding-ada-002","input":"test text"}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证嵌入接口的请求体处理
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
// 验证模型名称映射
|
||||
// ai360 provider会将模型名称从text-embedding-ada-002映射为360GPT_S2_V9
|
||||
require.Contains(t, string(processedBody), "360GPT_S2_V9", "Model name should be mapped to ai360 format")
|
||||
|
||||
// 检查处理日志
|
||||
debugLogs := host.GetDebugLogs()
|
||||
hasEmbeddingLogs := false
|
||||
for _, log := range debugLogs {
|
||||
if strings.Contains(log, "embeddings") || strings.Contains(log, "ai360") {
|
||||
hasEmbeddingLogs = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasEmbeddingLogs, "Should have embedding processing logs")
|
||||
})
|
||||
|
||||
// 测试ai360请求体处理(不支持的接口)
|
||||
t.Run("ai360 unsupported api request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicAi360Config)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/generations"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"dall-e-3","prompt":"test image"}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证不支持的接口处理
|
||||
|
||||
// 验证请求体没有被意外修改
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
require.Contains(t, string(processedBody), "dall-e-3", "Request body should not be modified for unsupported APIs")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunAi360OnHttpResponseHeadersTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试ai360响应头处理(聊天完成接口)
|
||||
t.Run("ai360 chat completion response headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicAi360Config)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"test"}]}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 设置响应头
|
||||
responseHeaders := [][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"X-Request-Id", "req-123"},
|
||||
}
|
||||
action := host.CallOnHttpResponseHeaders(responseHeaders)
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证响应头是否被正确处理
|
||||
processedResponseHeaders := host.GetResponseHeaders()
|
||||
require.NotNil(t, processedResponseHeaders)
|
||||
|
||||
// 验证状态码
|
||||
statusValue, hasStatus := test.GetHeaderValue(processedResponseHeaders, ":status")
|
||||
require.True(t, hasStatus, "Status header should exist")
|
||||
require.Equal(t, "200", statusValue, "Status should be 200")
|
||||
|
||||
// 验证Content-Type
|
||||
contentTypeValue, hasContentType := test.GetHeaderValue(processedResponseHeaders, "Content-Type")
|
||||
require.True(t, hasContentType, "Content-Type header should exist")
|
||||
require.Equal(t, "application/json", contentTypeValue, "Content-Type should be application/json")
|
||||
|
||||
// 检查是否有相关的处理日志
|
||||
debugLogs := host.GetDebugLogs()
|
||||
hasResponseLogs := false
|
||||
for _, log := range debugLogs {
|
||||
if strings.Contains(log, "response") || strings.Contains(log, "ai360") {
|
||||
hasResponseLogs = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasResponseLogs, "Should have response processing logs")
|
||||
})
|
||||
|
||||
// 测试ai360响应头处理(嵌入接口)
|
||||
t.Run("ai360 embeddings response headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicAi360Config)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/embeddings"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"text-embedding-ada-002","input":"test text"}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 设置响应头
|
||||
responseHeaders := [][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"X-Embedding-Model", "360Embedding_Text_V1"},
|
||||
}
|
||||
action := host.CallOnHttpResponseHeaders(responseHeaders)
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证响应头处理
|
||||
processedResponseHeaders := host.GetResponseHeaders()
|
||||
require.NotNil(t, processedResponseHeaders)
|
||||
|
||||
// 验证嵌入模型信息
|
||||
modelValue, hasModel := test.GetHeaderValue(processedResponseHeaders, "X-Embedding-Model")
|
||||
require.True(t, hasModel, "Embedding model header should exist")
|
||||
require.Equal(t, "360Embedding_Text_V1", modelValue, "Embedding model should match configuration")
|
||||
})
|
||||
|
||||
// 测试ai360响应头处理(错误响应)
|
||||
t.Run("ai360 error response headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicAi360Config)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"test"}]}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 设置错误响应头
|
||||
errorResponseHeaders := [][2]string{
|
||||
{":status", "429"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"Retry-After", "60"},
|
||||
}
|
||||
action := host.CallOnHttpResponseHeaders(errorResponseHeaders)
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证错误响应头处理
|
||||
processedResponseHeaders := host.GetResponseHeaders()
|
||||
require.NotNil(t, processedResponseHeaders)
|
||||
|
||||
// 验证错误状态码
|
||||
statusValue, hasStatus := test.GetHeaderValue(processedResponseHeaders, ":status")
|
||||
require.True(t, hasStatus, "Status header should exist")
|
||||
require.Equal(t, "429", statusValue, "Status should be 429 (Too Many Requests)")
|
||||
|
||||
// 验证重试信息
|
||||
retryValue, hasRetry := test.GetHeaderValue(processedResponseHeaders, "Retry-After")
|
||||
require.True(t, hasRetry, "Retry-After header should exist")
|
||||
require.Equal(t, "60", retryValue, "Retry-After should be 60 seconds")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunAi360OnHttpResponseBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试ai360响应体处理(聊天完成接口)
|
||||
t.Run("ai360 chat completion response body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicAi360Config)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"test"}]}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 设置响应头
|
||||
responseHeaders := [][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
}
|
||||
host.CallOnHttpResponseHeaders(responseHeaders)
|
||||
|
||||
// 设置响应体
|
||||
responseBody := `{
|
||||
"id": "chatcmpl-123",
|
||||
"object": "chat.completion",
|
||||
"created": 1677652288,
|
||||
"model": "gpt-3.5-turbo",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I help you today?"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 9,
|
||||
"completion_tokens": 12,
|
||||
"total_tokens": 21
|
||||
}
|
||||
}`
|
||||
action := host.CallOnHttpResponseBody([]byte(responseBody))
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证响应体是否被正确处理
|
||||
processedResponseBody := host.GetResponseBody()
|
||||
require.NotNil(t, processedResponseBody)
|
||||
|
||||
// 验证响应体内容
|
||||
responseStr := string(processedResponseBody)
|
||||
require.Contains(t, responseStr, "chat.completion", "Response should contain chat completion object")
|
||||
require.Contains(t, responseStr, "assistant", "Response should contain assistant role")
|
||||
require.Contains(t, responseStr, "usage", "Response should contain usage information")
|
||||
|
||||
// 检查是否有相关的处理日志
|
||||
debugLogs := host.GetDebugLogs()
|
||||
hasResponseBodyLogs := false
|
||||
for _, log := range debugLogs {
|
||||
if strings.Contains(log, "response") || strings.Contains(log, "body") || strings.Contains(log, "ai360") {
|
||||
hasResponseBodyLogs = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasResponseBodyLogs, "Should have response body processing logs")
|
||||
})
|
||||
|
||||
// 测试ai360响应体处理(嵌入接口)
|
||||
t.Run("ai360 embeddings response body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicAi360Config)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/embeddings"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"text-embedding-ada-002","input":"test text"}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 设置响应头
|
||||
responseHeaders := [][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
}
|
||||
host.CallOnHttpResponseHeaders(responseHeaders)
|
||||
|
||||
// 设置响应体
|
||||
responseBody := `{
|
||||
"object": "list",
|
||||
"data": [{
|
||||
"object": "embedding",
|
||||
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"index": 0
|
||||
}],
|
||||
"model": "text-embedding-ada-002",
|
||||
"usage": {
|
||||
"prompt_tokens": 5,
|
||||
"total_tokens": 5
|
||||
}
|
||||
}`
|
||||
action := host.CallOnHttpResponseBody([]byte(responseBody))
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证响应体处理
|
||||
processedResponseBody := host.GetResponseBody()
|
||||
require.NotNil(t, processedResponseBody)
|
||||
|
||||
// 验证嵌入响应内容
|
||||
responseStr := string(processedResponseBody)
|
||||
require.Contains(t, responseStr, "embedding", "Response should contain embedding object")
|
||||
require.Contains(t, responseStr, "0.1", "Response should contain embedding vector")
|
||||
require.Contains(t, responseStr, "text-embedding-ada-002", "Response should contain model name")
|
||||
})
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func RunAi360OnStreamingResponseBodyTests(t *testing.T) {
|
||||
// 测试ai360响应体处理(流式响应)
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
t.Run("ai360 streaming response body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicAi360Config)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置流式请求体
|
||||
requestBody := `{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"test"}],"stream":true}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 设置流式响应头
|
||||
responseHeaders := [][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "text/event-stream"},
|
||||
}
|
||||
host.CallOnHttpResponseHeaders(responseHeaders)
|
||||
|
||||
// 模拟流式响应体
|
||||
chunk1 := `data: {"id":"chatcmpl-123","object":"chat.completion.chunk","choices":[{"delta":{"role":"assistant"},"index":0}]}
|
||||
|
||||
`
|
||||
chunk2 := `data: {"id":"chatcmpl-123","object":"chat.completion.chunk","choices":[{"delta":{"content":"Hello"},"index":0}]}
|
||||
|
||||
`
|
||||
chunk3 := `data: {"id":"chatcmpl-123","object":"chat.completion.chunk","choices":[{"delta":{"content":"!"},"index":0}]}
|
||||
|
||||
`
|
||||
chunk4 := `data: [DONE]
|
||||
|
||||
`
|
||||
|
||||
// 处理流式响应体
|
||||
action1 := host.CallOnHttpStreamingResponseBody([]byte(chunk1), false)
|
||||
require.Equal(t, types.ActionContinue, action1)
|
||||
|
||||
action2 := host.CallOnHttpStreamingResponseBody([]byte(chunk2), false)
|
||||
require.Equal(t, types.ActionContinue, action2)
|
||||
|
||||
action3 := host.CallOnHttpStreamingResponseBody([]byte(chunk3), false)
|
||||
require.Equal(t, types.ActionContinue, action3)
|
||||
|
||||
action4 := host.CallOnHttpStreamingResponseBody([]byte(chunk4), true)
|
||||
require.Equal(t, types.ActionContinue, action4)
|
||||
|
||||
// 验证流式响应处理
|
||||
// 注意:流式响应可能不会在GetResponseBody中累积,需要检查日志或其他方式验证
|
||||
debugLogs := host.GetDebugLogs()
|
||||
hasStreamingLogs := false
|
||||
for _, log := range debugLogs {
|
||||
if strings.Contains(log, "streaming") || strings.Contains(log, "chunk") || strings.Contains(log, "ai360") {
|
||||
hasStreamingLogs = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasStreamingLogs, "Should have streaming response processing logs")
|
||||
})
|
||||
})
|
||||
}
|
||||
600
plugins/wasm-go/extensions/ai-proxy/test/azure.go
Normal file
600
plugins/wasm-go/extensions/ai-proxy/test/azure.go
Normal file
@@ -0,0 +1,600 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// 测试配置:基本Azure OpenAI配置
|
||||
var basicAzureConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "azure",
|
||||
"apiTokens": []string{
|
||||
"sk-azure-test123456789",
|
||||
},
|
||||
"azureServiceUrl": "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-02-15-preview",
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-3.5-turbo",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:Azure OpenAI完整路径配置
|
||||
var azureFullPathConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "azure",
|
||||
"apiTokens": []string{
|
||||
"sk-azure-fullpath",
|
||||
},
|
||||
"azureServiceUrl": "https://fullpath-resource.openai.azure.com/openai/deployments/fullpath-deployment/chat/completions?api-version=2024-02-15-preview",
|
||||
"modelMapping": map[string]string{
|
||||
"gpt-3.5-turbo": "gpt-3.5-turbo",
|
||||
"gpt-4": "gpt-4",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:Azure OpenAI仅部署配置
|
||||
var azureDeploymentOnlyConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "azure",
|
||||
"apiTokens": []string{
|
||||
"sk-azure-deployment",
|
||||
},
|
||||
"azureServiceUrl": "https://deployment-resource.openai.azure.com/openai/deployments/deployment-only?api-version=2024-02-15-preview",
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-3.5-turbo",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:Azure OpenAI仅域名配置
|
||||
var azureDomainOnlyConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "azure",
|
||||
"apiTokens": []string{
|
||||
"sk-azure-domain",
|
||||
},
|
||||
"azureServiceUrl": "https://domain-resource.openai.azure.com?api-version=2024-02-15-preview",
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-3.5-turbo",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:Azure OpenAI多模型配置
|
||||
var azureMultiModelConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "azure",
|
||||
"apiTokens": []string{
|
||||
"sk-azure-multi",
|
||||
},
|
||||
"azureServiceUrl": "https://multi-resource.openai.azure.com/openai/deployments/multi-deployment?api-version=2024-02-15-preview",
|
||||
"modelMapping": map[string]string{
|
||||
"gpt-3.5-turbo": "gpt-3.5-turbo",
|
||||
"gpt-4": "gpt-4",
|
||||
"text-embedding-ada-002": "text-embedding-ada-002",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:Azure OpenAI无效配置(缺少azureServiceUrl)
|
||||
var azureInvalidConfigMissingUrl = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "azure",
|
||||
"apiTokens": []string{
|
||||
"sk-azure-invalid",
|
||||
},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-3.5-turbo",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:Azure OpenAI无效配置(缺少api-version)
|
||||
var azureInvalidConfigMissingApiVersion = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "azure",
|
||||
"apiTokens": []string{
|
||||
"sk-azure-invalid",
|
||||
},
|
||||
"azureServiceUrl": "https://invalid-resource.openai.azure.com/openai/deployments/invalid-deployment/chat/completions",
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-3.5-turbo",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:Azure OpenAI无效配置(缺少apiToken)
|
||||
var azureInvalidConfigMissingToken = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "azure",
|
||||
"azureServiceUrl": "https://invalid-resource.openai.azure.com/openai/deployments/invalid-deployment/chat/completions?api-version=2024-02-15-preview",
|
||||
"modelMapping": map[string]interface{}{
|
||||
"*": "gpt-3.5-turbo",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
func RunAzureParseConfigTests(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
// 测试基本Azure OpenAI配置解析
|
||||
t.Run("basic azure config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicAzureConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试Azure OpenAI完整路径配置解析
|
||||
t.Run("azure full path config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(azureFullPathConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试Azure OpenAI仅部署配置解析
|
||||
t.Run("azure deployment only config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(azureDeploymentOnlyConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试Azure OpenAI仅域名配置解析
|
||||
t.Run("azure domain only config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(azureDomainOnlyConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试Azure OpenAI多模型配置解析
|
||||
t.Run("azure multi model config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(azureMultiModelConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试Azure OpenAI无效配置(缺少azureServiceUrl)
|
||||
t.Run("azure invalid config missing url", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(azureInvalidConfigMissingUrl)
|
||||
defer host.Reset()
|
||||
// 应该失败,因为缺少azureServiceUrl
|
||||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||||
})
|
||||
|
||||
// 测试Azure OpenAI无效配置(缺少api-version)
|
||||
t.Run("azure invalid config missing api version", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(azureInvalidConfigMissingApiVersion)
|
||||
defer host.Reset()
|
||||
// 应该失败,因为缺少api-version
|
||||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||||
})
|
||||
|
||||
// 测试Azure OpenAI无效配置(缺少apiToken)
|
||||
t.Run("azure invalid config missing token", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(azureInvalidConfigMissingToken)
|
||||
defer host.Reset()
|
||||
// 应该失败,因为缺少apiToken
|
||||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunAzureOnHttpRequestHeadersTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试Azure OpenAI请求头处理(聊天完成接口)
|
||||
t.Run("azure chat completion request headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicAzureConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 应该返回HeaderStopIteration,因为需要处理请求体
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 验证请求头是否被正确处理
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// 验证Host是否被改为Azure服务域名
|
||||
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
|
||||
require.True(t, hasHost, "Host header should exist")
|
||||
require.Equal(t, "test-resource.openai.azure.com", hostValue, "Host should be changed to Azure service domain")
|
||||
|
||||
// 验证api-key是否被设置
|
||||
apiKeyValue, hasApiKey := test.GetHeaderValue(requestHeaders, "api-key")
|
||||
require.True(t, hasApiKey, "api-key header should exist")
|
||||
require.Equal(t, "sk-azure-test123456789", apiKeyValue, "api-key should contain Azure API token")
|
||||
|
||||
// 验证Path是否被正确处理
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
require.Contains(t, pathValue, "/openai/deployments/test-deployment/chat/completions", "Path should contain Azure deployment path")
|
||||
|
||||
// 验证Content-Length是否被删除
|
||||
_, hasContentLength := test.GetHeaderValue(requestHeaders, "Content-Length")
|
||||
require.False(t, hasContentLength, "Content-Length header should be deleted")
|
||||
|
||||
// 检查是否有相关的处理日志
|
||||
debugLogs := host.GetDebugLogs()
|
||||
hasAzureLogs := false
|
||||
for _, log := range debugLogs {
|
||||
if strings.Contains(log, "azureProvider") {
|
||||
hasAzureLogs = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, hasAzureLogs, "Should have Azure provider debug logs")
|
||||
})
|
||||
|
||||
// 测试Azure OpenAI请求头处理(完整路径配置)
|
||||
t.Run("azure full path request headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(azureFullPathConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 验证请求头是否被正确处理
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// 验证Host是否被改为Azure服务域名
|
||||
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
|
||||
require.True(t, hasHost, "Host header should exist")
|
||||
require.Equal(t, "fullpath-resource.openai.azure.com", hostValue, "Host should be changed to Azure service domain")
|
||||
|
||||
// 验证api-key是否被设置
|
||||
apiKeyValue, hasApiKey := test.GetHeaderValue(requestHeaders, "api-key")
|
||||
require.True(t, hasApiKey, "api-key header should exist")
|
||||
require.Equal(t, "sk-azure-fullpath", apiKeyValue, "api-key should contain Azure API token")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunAzureOnHttpRequestBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试Azure OpenAI请求体处理(聊天完成接口)
|
||||
t.Run("azure chat completion request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicAzureConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, how are you?"
|
||||
}
|
||||
],
|
||||
"temperature": 0.7
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证请求体是否被正确处理
|
||||
transformedBody := host.GetRequestBody()
|
||||
require.NotNil(t, transformedBody)
|
||||
|
||||
// 验证模型映射是否生效
|
||||
var bodyMap map[string]interface{}
|
||||
err := json.Unmarshal(transformedBody, &bodyMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
model, exists := bodyMap["model"]
|
||||
require.True(t, exists, "Model should exist in request body")
|
||||
require.Equal(t, "gpt-3.5-turbo", model, "Model should be mapped correctly")
|
||||
|
||||
// 验证请求路径是否被正确转换
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
require.Contains(t, pathValue, "/openai/deployments/test-deployment/chat/completions", "Path should contain Azure deployment path")
|
||||
require.Contains(t, pathValue, "api-version=2024-02-15-preview", "Path should contain API version")
|
||||
})
|
||||
|
||||
// 测试Azure OpenAI请求体处理(不同模型)
|
||||
t.Run("azure different model request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(azureMultiModelConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{
|
||||
"model": "gpt-4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Explain quantum computing"
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证请求体是否被正确处理
|
||||
transformedBody := host.GetRequestBody()
|
||||
require.NotNil(t, transformedBody)
|
||||
|
||||
var bodyMap map[string]interface{}
|
||||
err := json.Unmarshal(transformedBody, &bodyMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
model, exists := bodyMap["model"]
|
||||
require.True(t, exists, "Model should exist in request body")
|
||||
require.Equal(t, "gpt-4", model, "Model should be mapped correctly")
|
||||
})
|
||||
|
||||
// 测试Azure OpenAI请求体处理(仅部署配置)
|
||||
t.Run("azure deployment only request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(azureDeploymentOnlyConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Test message"
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证请求路径是否使用默认部署
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
require.Contains(t, pathValue, "/openai/deployments/deployment-only/chat/completions", "Path should use default deployment")
|
||||
})
|
||||
|
||||
// 测试Azure OpenAI请求体处理(仅域名配置)
|
||||
t.Run("azure domain only request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(azureDomainOnlyConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Test message"
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证请求路径是否使用模型占位符
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
require.Contains(t, pathValue, "/openai/deployments/gpt-3.5-turbo/chat/completions", "Path should use model from request body")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunAzureOnHttpResponseHeadersTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试Azure OpenAI响应头处理
|
||||
t.Run("azure response headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicAzureConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello"
|
||||
}
|
||||
]
|
||||
}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 处理响应头
|
||||
action = host.CallOnHttpResponseHeaders([][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证响应头是否被正确处理
|
||||
responseHeaders := host.GetResponseHeaders()
|
||||
require.NotNil(t, responseHeaders)
|
||||
|
||||
// 验证状态码
|
||||
statusValue, hasStatus := test.GetHeaderValue(responseHeaders, ":status")
|
||||
require.True(t, hasStatus, "Status header should exist")
|
||||
require.Equal(t, "200", statusValue, "Status should be 200")
|
||||
|
||||
// 验证Content-Type
|
||||
contentTypeValue, hasContentType := test.GetHeaderValue(responseHeaders, "Content-Type")
|
||||
require.True(t, hasContentType, "Content-Type header should exist")
|
||||
require.Equal(t, "application/json", contentTypeValue, "Content-Type should be application/json")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunAzureOnHttpResponseBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试Azure OpenAI响应体处理
|
||||
t.Run("azure response body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicAzureConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello"
|
||||
}
|
||||
]
|
||||
}`
|
||||
action = host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 处理响应体
|
||||
responseBody := `{
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": "Hello! How can I help you?"
|
||||
}
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
action = host.CallOnHttpResponseBody([]byte(responseBody))
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证响应体是否被正确处理
|
||||
transformedResponseBody := host.GetResponseBody()
|
||||
require.NotNil(t, transformedResponseBody)
|
||||
|
||||
// 验证响应体内容
|
||||
var responseMap map[string]interface{}
|
||||
err := json.Unmarshal(transformedResponseBody, &responseMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
choices, exists := responseMap["choices"]
|
||||
require.True(t, exists, "Choices should exist in response body")
|
||||
require.NotNil(t, choices, "Choices should not be nil")
|
||||
})
|
||||
})
|
||||
}
|
||||
1335
plugins/wasm-go/extensions/ai-proxy/test/gemini.go
Normal file
1335
plugins/wasm-go/extensions/ai-proxy/test/gemini.go
Normal file
File diff suppressed because it is too large
Load Diff
866
plugins/wasm-go/extensions/ai-proxy/test/openai.go
Normal file
866
plugins/wasm-go/extensions/ai-proxy/test/openai.go
Normal file
@@ -0,0 +1,866 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/higress-group/wasm-go/pkg/test"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// 测试配置:基本OpenAI配置
|
||||
var basicOpenAIConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "openai",
|
||||
"apiTokens": []string{"sk-openai-test123456789"},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-3.5-turbo",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:OpenAI多模型配置
|
||||
var openAIMultiModelConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "openai",
|
||||
"apiTokens": []string{"sk-openai-multi-model"},
|
||||
"modelMapping": map[string]string{
|
||||
"gpt-3.5-turbo": "gpt-3.5-turbo",
|
||||
"gpt-4": "gpt-4",
|
||||
"text-embedding-ada-002": "text-embedding-ada-002",
|
||||
"dall-e-3": "dall-e-3",
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:OpenAI自定义域名配置(直接路径)
|
||||
var openAICustomDomainDirectPathConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "openai",
|
||||
"apiTokens": []string{"sk-openai-custom-domain"},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-3.5-turbo",
|
||||
},
|
||||
"openaiCustomUrl": "https://custom.openai.com/v1",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:OpenAI自定义域名配置(间接路径)
|
||||
var openAICustomDomainIndirectPathConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "openai",
|
||||
"apiTokens": []string{"sk-openai-custom-domain"},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-3.5-turbo",
|
||||
},
|
||||
"openaiCustomUrl": "https://custom.openai.com/api",
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
// 测试配置:OpenAI完整配置(包含responseJsonSchema等字段)
|
||||
var completeOpenAIConfig = func() json.RawMessage {
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"provider": map[string]interface{}{
|
||||
"type": "openai",
|
||||
"apiTokens": []string{"sk-openai-complete"},
|
||||
"modelMapping": map[string]string{
|
||||
"*": "gpt-3.5-turbo",
|
||||
},
|
||||
"responseJsonSchema": map[string]interface{}{
|
||||
"type": "json_object",
|
||||
},
|
||||
"failover": map[string]interface{}{
|
||||
"enabled": false,
|
||||
},
|
||||
"retryOnFailure": map[string]interface{}{
|
||||
"enabled": false,
|
||||
},
|
||||
},
|
||||
})
|
||||
return data
|
||||
}()
|
||||
|
||||
func RunOpenAIParseConfigTests(t *testing.T) {
|
||||
test.RunGoTest(t, func(t *testing.T) {
|
||||
// 测试基本OpenAI配置解析
|
||||
t.Run("basic openai config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试OpenAI多模型配置解析
|
||||
t.Run("openai multi model config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(openAIMultiModelConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试OpenAI自定义域名配置(直接路径)
|
||||
t.Run("openai custom domain direct path config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(openAICustomDomainDirectPathConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试OpenAI自定义域名配置(间接路径)
|
||||
t.Run("openai custom domain indirect path config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(openAICustomDomainIndirectPathConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
|
||||
// 测试OpenAI完整配置解析
|
||||
t.Run("openai complete config", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(completeOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
config, err := host.GetMatchConfig()
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunOpenAIOnHttpRequestHeadersTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试OpenAI请求头处理(聊天完成接口)
|
||||
t.Run("openai chat completion request headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 应该返回HeaderStopIteration,因为需要处理请求体
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 验证请求头是否被正确处理
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// 验证Host是否被改为OpenAI默认域名
|
||||
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
|
||||
require.True(t, hasHost, "Host header should exist")
|
||||
require.Equal(t, "api.openai.com", hostValue, "Host should be changed to OpenAI default domain")
|
||||
|
||||
// 验证Authorization是否被设置
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist")
|
||||
require.Contains(t, authValue, "sk-openai-test123456789", "Authorization should contain OpenAI API token")
|
||||
|
||||
// 验证Path是否被正确处理
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath, "Path header should exist")
|
||||
require.Contains(t, pathValue, "/v1/chat/completions", "Path should contain chat completions endpoint")
|
||||
|
||||
// 验证Content-Length是否被删除
|
||||
_, hasContentLength := test.GetHeaderValue(requestHeaders, "Content-Length")
|
||||
require.False(t, hasContentLength, "Content-Length header should be deleted")
|
||||
|
||||
// 检查是否有相关的处理日志
|
||||
debugLogs := host.GetDebugLogs()
|
||||
hasOpenAILogs := false
|
||||
for _, log := range debugLogs {
|
||||
if strings.Contains(log, "openai") {
|
||||
hasOpenAILogs = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasOpenAILogs, "Should have OpenAI processing logs")
|
||||
})
|
||||
|
||||
// 测试OpenAI请求头处理(嵌入接口)
|
||||
t.Run("openai embeddings request headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/embeddings"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 验证嵌入接口的请求头处理
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// 验证Host转换
|
||||
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
|
||||
require.True(t, hasHost)
|
||||
require.Equal(t, "api.openai.com", hostValue)
|
||||
|
||||
// 验证Path转换
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath)
|
||||
require.Contains(t, pathValue, "/v1/embeddings", "Path should contain embeddings endpoint")
|
||||
|
||||
// 验证Authorization设置
|
||||
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
|
||||
require.True(t, hasAuth, "Authorization header should exist for embeddings")
|
||||
require.Contains(t, authValue, "sk-openai-test123456789", "Authorization should contain OpenAI API token")
|
||||
})
|
||||
|
||||
// 测试OpenAI请求头处理(图像生成接口)
|
||||
t.Run("openai image generation request headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/generations"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 验证图像生成接口的请求头处理
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// 验证Host转换
|
||||
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
|
||||
require.True(t, hasHost)
|
||||
require.Equal(t, "api.openai.com", hostValue)
|
||||
|
||||
// 验证Path转换
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath)
|
||||
require.Contains(t, pathValue, "/v1/images/generations", "Path should contain image generations endpoint")
|
||||
})
|
||||
|
||||
// 测试OpenAI自定义域名请求头处理
|
||||
t.Run("openai custom domain request headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(openAICustomDomainDirectPathConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 设置请求头
|
||||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
require.Equal(t, types.HeaderStopIteration, action)
|
||||
|
||||
// 验证自定义域名的请求头处理
|
||||
requestHeaders := host.GetRequestHeaders()
|
||||
require.NotNil(t, requestHeaders)
|
||||
|
||||
// 验证Host是否被改为自定义域名
|
||||
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
|
||||
require.True(t, hasHost)
|
||||
require.Equal(t, "custom.openai.com", hostValue, "Host should be changed to custom domain")
|
||||
|
||||
// 验证Path是否被正确处理
|
||||
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
|
||||
require.True(t, hasPath)
|
||||
// 对于直接路径,应该保持原有路径
|
||||
require.Contains(t, pathValue, "/v1/chat/completions", "Path should be preserved for direct custom path")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunOpenAIOnHttpRequestBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试OpenAI请求体处理(聊天完成接口)
|
||||
t.Run("openai chat completion request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"test"}]}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证请求体是否被正确处理
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
// 验证模型名称是否被正确映射
|
||||
require.Contains(t, string(processedBody), "gpt-3.5-turbo", "Original model name should be preserved or mapped")
|
||||
|
||||
// 检查是否有相关的处理日志
|
||||
debugLogs := host.GetDebugLogs()
|
||||
infoLogs := host.GetInfoLogs()
|
||||
|
||||
// 验证是否有OpenAI相关的处理日志
|
||||
hasOpenAILogs := false
|
||||
for _, log := range debugLogs {
|
||||
if strings.Contains(log, "openai") {
|
||||
hasOpenAILogs = true
|
||||
break
|
||||
}
|
||||
}
|
||||
for _, log := range infoLogs {
|
||||
if strings.Contains(log, "openai") {
|
||||
hasOpenAILogs = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasOpenAILogs, "Should have OpenAI processing logs")
|
||||
})
|
||||
|
||||
// 测试OpenAI请求体处理(嵌入接口)
|
||||
t.Run("openai embeddings request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/embeddings"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"text-embedding-ada-002","input":"test text"}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证嵌入接口的请求体处理
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
// 验证模型名称映射
|
||||
// 由于使用了通配符映射 "*": "gpt-3.5-turbo",text-embedding-ada-002 会被映射为 gpt-3.5-turbo
|
||||
require.Contains(t, string(processedBody), "gpt-3.5-turbo", "Model name should be mapped via wildcard")
|
||||
|
||||
// 检查处理日志
|
||||
debugLogs := host.GetDebugLogs()
|
||||
hasEmbeddingLogs := false
|
||||
for _, log := range debugLogs {
|
||||
if strings.Contains(log, "embeddings") || strings.Contains(log, "openai") {
|
||||
hasEmbeddingLogs = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasEmbeddingLogs, "Should have embedding processing logs")
|
||||
})
|
||||
|
||||
// 测试OpenAI请求体处理(图像生成接口)
|
||||
t.Run("openai image generation request body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/generations"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"dall-e-3","prompt":"test image"}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证图像生成接口的请求体处理
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
// 验证模型名称映射
|
||||
// 由于使用了通配符映射 "*": "gpt-3.5-turbo",dall-e-3 会被映射为 gpt-3.5-turbo
|
||||
require.Contains(t, string(processedBody), "gpt-3.5-turbo", "Model name should be mapped via wildcard")
|
||||
|
||||
// 检查处理日志
|
||||
debugLogs := host.GetDebugLogs()
|
||||
hasImageLogs := false
|
||||
for _, log := range debugLogs {
|
||||
if strings.Contains(log, "image") || strings.Contains(log, "openai") {
|
||||
hasImageLogs = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasImageLogs, "Should have image generation processing logs")
|
||||
})
|
||||
|
||||
// 测试OpenAI请求体处理(带responseJsonSchema配置)
|
||||
t.Run("openai request body with responseJsonSchema", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(completeOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"test"}]}`
|
||||
action := host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证请求体是否被正确处理
|
||||
processedBody := host.GetRequestBody()
|
||||
require.NotNil(t, processedBody)
|
||||
|
||||
// 验证responseJsonSchema是否被应用
|
||||
// 注意:由于test框架的限制,我们可能需要检查日志或其他方式来验证处理结果
|
||||
require.Contains(t, string(processedBody), "gpt-3.5-turbo", "Model name should be preserved")
|
||||
|
||||
// 检查是否有相关的处理日志
|
||||
debugLogs := host.GetDebugLogs()
|
||||
hasSchemaLogs := false
|
||||
for _, log := range debugLogs {
|
||||
if strings.Contains(log, "response format") || strings.Contains(log, "openai") {
|
||||
hasSchemaLogs = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasSchemaLogs, "Should have response format processing logs")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunOpenAIOnHttpResponseHeadersTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试OpenAI响应头处理(聊天完成接口)
|
||||
t.Run("openai chat completion response headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"test"}]}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 设置响应头
|
||||
responseHeaders := [][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"X-Request-Id", "req-123"},
|
||||
}
|
||||
action := host.CallOnHttpResponseHeaders(responseHeaders)
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证响应头是否被正确处理
|
||||
processedResponseHeaders := host.GetResponseHeaders()
|
||||
require.NotNil(t, processedResponseHeaders)
|
||||
|
||||
// 验证状态码
|
||||
statusValue, hasStatus := test.GetHeaderValue(processedResponseHeaders, ":status")
|
||||
require.True(t, hasStatus, "Status header should exist")
|
||||
require.Equal(t, "200", statusValue, "Status should be 200")
|
||||
|
||||
// 验证Content-Type
|
||||
contentTypeValue, hasContentType := test.GetHeaderValue(processedResponseHeaders, "Content-Type")
|
||||
require.True(t, hasContentType, "Content-Type header should exist")
|
||||
require.Equal(t, "application/json", contentTypeValue, "Content-Type should be application/json")
|
||||
|
||||
// 检查是否有相关的处理日志
|
||||
debugLogs := host.GetDebugLogs()
|
||||
hasResponseLogs := false
|
||||
for _, log := range debugLogs {
|
||||
if strings.Contains(log, "response") || strings.Contains(log, "openai") {
|
||||
hasResponseLogs = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasResponseLogs, "Should have response processing logs")
|
||||
})
|
||||
|
||||
// 测试OpenAI响应头处理(嵌入接口)
|
||||
t.Run("openai embeddings response headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/embeddings"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"text-embedding-ada-002","input":"test text"}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 设置响应头
|
||||
responseHeaders := [][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"X-Embedding-Model", "text-embedding-ada-002"},
|
||||
}
|
||||
action := host.CallOnHttpResponseHeaders(responseHeaders)
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证响应头处理
|
||||
processedResponseHeaders := host.GetResponseHeaders()
|
||||
require.NotNil(t, processedResponseHeaders)
|
||||
|
||||
// 验证嵌入模型信息
|
||||
modelValue, hasModel := test.GetHeaderValue(processedResponseHeaders, "X-Embedding-Model")
|
||||
require.True(t, hasModel, "Embedding model header should exist")
|
||||
require.Equal(t, "text-embedding-ada-002", modelValue, "Embedding model should match configuration")
|
||||
})
|
||||
|
||||
// 测试OpenAI响应头处理(错误响应)
|
||||
t.Run("openai error response headers", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"test"}]}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 设置错误响应头
|
||||
errorResponseHeaders := [][2]string{
|
||||
{":status", "429"},
|
||||
{"Content-Type", "application/json"},
|
||||
{"Retry-After", "60"},
|
||||
}
|
||||
action := host.CallOnHttpResponseHeaders(errorResponseHeaders)
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证错误响应头处理
|
||||
processedResponseHeaders := host.GetResponseHeaders()
|
||||
require.NotNil(t, processedResponseHeaders)
|
||||
|
||||
// 验证错误状态码
|
||||
statusValue, hasStatus := test.GetHeaderValue(processedResponseHeaders, ":status")
|
||||
require.True(t, hasStatus, "Status header should exist")
|
||||
require.Equal(t, "429", statusValue, "Status should be 429 (Too Many Requests)")
|
||||
|
||||
// 验证重试信息
|
||||
retryValue, hasRetry := test.GetHeaderValue(processedResponseHeaders, "Retry-After")
|
||||
require.True(t, hasRetry, "Retry-After header should exist")
|
||||
require.Equal(t, "60", retryValue, "Retry-After should be 60 seconds")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunOpenAIOnHttpResponseBodyTests(t *testing.T) {
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
// 测试OpenAI响应体处理(聊天完成接口)
|
||||
t.Run("openai chat completion response body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"test"}]}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 设置响应头
|
||||
responseHeaders := [][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
}
|
||||
host.CallOnHttpResponseHeaders(responseHeaders)
|
||||
|
||||
// 设置响应体
|
||||
responseBody := `{
|
||||
"id": "chatcmpl-123",
|
||||
"object": "chat.completion",
|
||||
"created": 1677652288,
|
||||
"model": "gpt-3.5-turbo",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I help you today?"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 9,
|
||||
"completion_tokens": 12,
|
||||
"total_tokens": 21
|
||||
}
|
||||
}`
|
||||
action := host.CallOnHttpResponseBody([]byte(responseBody))
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证响应体是否被正确处理
|
||||
processedResponseBody := host.GetResponseBody()
|
||||
require.NotNil(t, processedResponseBody)
|
||||
|
||||
// 验证响应体内容
|
||||
responseStr := string(processedResponseBody)
|
||||
require.Contains(t, responseStr, "chat.completion", "Response should contain chat completion object")
|
||||
require.Contains(t, responseStr, "assistant", "Response should contain assistant role")
|
||||
require.Contains(t, responseStr, "usage", "Response should contain usage information")
|
||||
|
||||
// 检查是否有相关的处理日志
|
||||
debugLogs := host.GetDebugLogs()
|
||||
hasResponseBodyLogs := false
|
||||
for _, log := range debugLogs {
|
||||
if strings.Contains(log, "response") || strings.Contains(log, "body") || strings.Contains(log, "openai") {
|
||||
hasResponseBodyLogs = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasResponseBodyLogs, "Should have response body processing logs")
|
||||
})
|
||||
|
||||
// 测试OpenAI响应体处理(嵌入接口)
|
||||
t.Run("openai embeddings response body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/embeddings"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"text-embedding-ada-002","input":"test text"}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 设置响应头
|
||||
responseHeaders := [][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
}
|
||||
host.CallOnHttpResponseHeaders(responseHeaders)
|
||||
|
||||
// 设置响应体
|
||||
responseBody := `{
|
||||
"object": "list",
|
||||
"data": [{
|
||||
"object": "embedding",
|
||||
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"index": 0
|
||||
}],
|
||||
"model": "text-embedding-ada-002",
|
||||
"usage": {
|
||||
"prompt_tokens": 5,
|
||||
"total_tokens": 5
|
||||
}
|
||||
}`
|
||||
action := host.CallOnHttpResponseBody([]byte(responseBody))
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证响应体处理
|
||||
processedResponseBody := host.GetResponseBody()
|
||||
require.NotNil(t, processedResponseBody)
|
||||
|
||||
// 验证嵌入响应内容
|
||||
responseStr := string(processedResponseBody)
|
||||
require.Contains(t, responseStr, "embedding", "Response should contain embedding object")
|
||||
require.Contains(t, responseStr, "0.1", "Response should contain embedding vector")
|
||||
require.Contains(t, responseStr, "text-embedding-ada-002", "Response should contain model name")
|
||||
})
|
||||
|
||||
// 测试OpenAI响应体处理(图像生成接口)
|
||||
t.Run("openai image generation response body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/images/generations"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置请求体
|
||||
requestBody := `{"model":"dall-e-3","prompt":"test image"}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 设置响应头
|
||||
responseHeaders := [][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "application/json"},
|
||||
}
|
||||
host.CallOnHttpResponseHeaders(responseHeaders)
|
||||
|
||||
// 设置响应体
|
||||
responseBody := `{
|
||||
"created": 1677652288,
|
||||
"data": [{
|
||||
"url": "https://example.com/image1.png",
|
||||
"revised_prompt": "test image"
|
||||
}]
|
||||
}`
|
||||
action := host.CallOnHttpResponseBody([]byte(responseBody))
|
||||
|
||||
require.Equal(t, types.ActionContinue, action)
|
||||
|
||||
// 验证响应体处理
|
||||
processedResponseBody := host.GetResponseBody()
|
||||
require.NotNil(t, processedResponseBody)
|
||||
|
||||
// 验证图像生成响应内容
|
||||
responseStr := string(processedResponseBody)
|
||||
require.Contains(t, responseStr, "data", "Response should contain data array")
|
||||
require.Contains(t, responseStr, "url", "Response should contain image URL")
|
||||
require.Contains(t, responseStr, "revised_prompt", "Response should contain revised prompt")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func RunOpenAIOnStreamingResponseBodyTests(t *testing.T) {
|
||||
// 测试OpenAI响应体处理(流式响应)
|
||||
test.RunTest(t, func(t *testing.T) {
|
||||
t.Run("openai streaming response body", func(t *testing.T) {
|
||||
host, status := test.NewTestHost(basicOpenAIConfig)
|
||||
defer host.Reset()
|
||||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||||
|
||||
// 先设置请求头
|
||||
host.CallOnHttpRequestHeaders([][2]string{
|
||||
{":authority", "example.com"},
|
||||
{":path", "/v1/chat/completions"},
|
||||
{":method", "POST"},
|
||||
{"Content-Type", "application/json"},
|
||||
})
|
||||
|
||||
// 设置流式请求体
|
||||
requestBody := `{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"test"}],"stream":true}`
|
||||
host.CallOnHttpRequestBody([]byte(requestBody))
|
||||
|
||||
// 设置流式响应头
|
||||
responseHeaders := [][2]string{
|
||||
{":status", "200"},
|
||||
{"Content-Type", "text/event-stream"},
|
||||
}
|
||||
host.CallOnHttpResponseHeaders(responseHeaders)
|
||||
|
||||
// 模拟流式响应体
|
||||
chunk1 := `data: {"id":"chatcmpl-123","object":"chat.completion.chunk","choices":[{"delta":{"role":"assistant"},"index":0}]}
|
||||
|
||||
`
|
||||
chunk2 := `data: {"id":"chatcmpl-123","object":"chat.completion.chunk","choices":[{"delta":{"content":"Hello"},"index":0}]}
|
||||
|
||||
`
|
||||
chunk3 := `data: {"id":"chatcmpl-123","object":"chat.completion.chunk","choices":[{"delta":{"content":"!"},"index":0}]}
|
||||
|
||||
`
|
||||
chunk4 := `data: [DONE]
|
||||
|
||||
`
|
||||
|
||||
// 处理流式响应体
|
||||
action1 := host.CallOnHttpStreamingResponseBody([]byte(chunk1), false)
|
||||
require.Equal(t, types.ActionContinue, action1)
|
||||
|
||||
action2 := host.CallOnHttpStreamingResponseBody([]byte(chunk2), false)
|
||||
require.Equal(t, types.ActionContinue, action2)
|
||||
|
||||
action3 := host.CallOnHttpStreamingResponseBody([]byte(chunk3), false)
|
||||
require.Equal(t, types.ActionContinue, action3)
|
||||
|
||||
action4 := host.CallOnHttpStreamingResponseBody([]byte(chunk4), true)
|
||||
require.Equal(t, types.ActionContinue, action4)
|
||||
|
||||
// 验证流式响应处理
|
||||
// 注意:流式响应可能不会在GetResponseBody中累积,需要检查日志或其他方式验证
|
||||
debugLogs := host.GetDebugLogs()
|
||||
hasStreamingLogs := false
|
||||
for _, log := range debugLogs {
|
||||
if strings.Contains(log, "streaming") || strings.Contains(log, "chunk") || strings.Contains(log, "openai") {
|
||||
hasStreamingLogs = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(t, hasStreamingLogs, "Should have streaming response processing logs")
|
||||
})
|
||||
})
|
||||
}
|
||||
1213
plugins/wasm-go/extensions/ai-proxy/test/qwen.go
Normal file
1213
plugins/wasm-go/extensions/ai-proxy/test/qwen.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -59,18 +59,10 @@ func OverwriteRequestPath(path string) error {
|
||||
}
|
||||
|
||||
func OverwriteRequestAuthorization(credential string) error {
|
||||
if exist, _ := proxywasm.GetHttpRequestHeader(HeaderOriginalAuth); exist == "" {
|
||||
if originAuth, err := proxywasm.GetHttpRequestHeader(HeaderAuthorization); err == nil {
|
||||
_ = proxywasm.AddHttpRequestHeader(HeaderOriginalPath, originAuth)
|
||||
}
|
||||
}
|
||||
return proxywasm.ReplaceHttpRequestHeader(HeaderAuthorization, credential)
|
||||
}
|
||||
|
||||
func OverwriteRequestHostHeader(headers http.Header, host string) {
|
||||
if originHost, err := proxywasm.GetHttpRequestHeader(HeaderAuthority); err == nil {
|
||||
headers.Set(HeaderOriginalHost, originHost)
|
||||
}
|
||||
headers.Set(HeaderAuthority, host)
|
||||
}
|
||||
|
||||
@@ -175,11 +167,6 @@ func SetOriginalRequestAuth(auth string) {
|
||||
}
|
||||
|
||||
func OverwriteRequestAuthorizationHeader(headers http.Header, credential string) {
|
||||
if exist := headers.Get(HeaderOriginalAuth); exist == "" {
|
||||
if originAuth := headers.Get(HeaderAuthorization); originAuth != "" {
|
||||
headers.Set(HeaderOriginalAuth, originAuth)
|
||||
}
|
||||
}
|
||||
headers.Set(HeaderAuthorization, credential)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,14 +5,20 @@ go 1.24.1
|
||||
toolchain go1.24.4
|
||||
|
||||
require (
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
|
||||
github.com/higress-group/wasm-go v1.0.1
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/resp v0.1.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/tetratelabs/wazero v1.7.2 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -2,14 +2,17 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.1 h1:T1m++qTEANp8+jwE0sxltwtaTKmrHCkLOp1m9N+YeqY=
|
||||
github.com/higress-group/wasm-go v1.0.1/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
|
||||
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
|
||||
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
|
||||
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
@@ -19,5 +22,9 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
|
||||
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user