Compare commits


46 Commits

Author SHA1 Message Date
EricaLiu
020b5f3984 Add security schema for nacos mcp (#2847) 2025-09-01 15:23:04 +08:00
澄潭
9a12f0b593 rel: Release v2.1.7 (#2849) 2025-09-01 15:22:17 +08:00
韩贤涛
7e74eeb333 feat(wasm-plugin): add tests and docs for hmac-auth-apisix (#2842) 2025-09-01 14:07:57 +08:00
澄潭
fff5903007 feat: add MCP SSE stateful session load balancer support (#2818) 2025-08-28 20:52:07 +08:00
Jingze
a00b810be5 feat(wasm-go): add wasm go plugin unit test and ci workflow (#2809) 2025-08-28 20:02:03 +08:00
澄潭
3e0a5f02a7 feat(wasm-plugin): add jsonrpc-converter plugin (#2805) 2025-08-28 19:28:37 +08:00
澄潭
44c33617fa feat(ai-proxy): add OpenRouter provider support (#2823) 2025-08-28 19:26:21 +08:00
韩贤涛
b2ffeff7b8 feat(wasm-plugin): add hmac-auth-apisix plugin (#2815) 2025-08-26 21:10:43 +08:00
Asnowww
c0ddbccbfe chore: fix typos (#2816) 2025-08-26 14:41:05 +08:00
澄潭
16a18c6609 feat(ai-proxy): add auto protocol compatibility for OpenAI and Claude APIs (#2810) 2025-08-25 14:13:51 +08:00
Xijun Dai
72b98ab6cf feat(ai-proxy): add anthropic/v1/messages and openai/v1/models support for DeepSeek (#2808)
Signed-off-by: Xijun Dai <daixijun1990@gmail.com>
2025-08-21 17:40:13 +08:00
xingpiaoliang
df20472f7b fix(wasm-go-build): correct the build command (#2799) 2025-08-21 09:47:30 +08:00
co63oc
9186b5505d chore: fix typos (#2770) 2025-08-20 10:38:03 +08:00
co63oc
eaea782693 fix RegisteTickFunc (#2787) 2025-08-19 20:53:53 +08:00
rinfx
890a802481 feat(ai-proxy): bedrock support tool use (#2730) 2025-08-19 16:54:50 +08:00
github-actions[bot]
bb69a1d50b Update CRD file in the helm folder (#2769)
Co-authored-by: CH3CHO <2909796+CH3CHO@users.noreply.github.com>
2025-08-19 16:54:27 +08:00
zat366
5a023512fa feat(mcp-server): update the dependency github.com/higress-group/wasm-go to support MCP response images (#2788) 2025-08-19 16:52:14 +08:00
Kent Dong
47f0478ef5 fix: Remove "accept-encoding" header for mcp-sse upstreams (#2786) 2025-08-19 15:53:49 +08:00
Kent Dong
c9fa8d15db chore: Restructure the path-to-api-name mapping logic in ai-proxy (#2773) 2025-08-18 19:05:23 +08:00
Kent Dong
0f1afcdcca fix(ai-proxy): Do not change the configured components of Azure URL (#2782) 2025-08-18 16:27:25 +08:00
StarryNight
19d1548971 update ai-prompt-decorator to new plugin wrapper api (#2777) 2025-08-18 11:01:35 +08:00
Kent Dong
24dca0455e fix: Fix bugs in the bedrock model name escaping logic (#2663)
Co-authored-by: 澄潭 <zty98751@alibaba-inc.com>
2025-08-15 17:40:13 +08:00
澄潭
be603af461 Update CODEOWNERS 2025-08-15 16:53:31 +08:00
澄潭
8796c6040f Update CODEOWNERS 2025-08-15 16:52:45 +08:00
Kent Dong
15edc79fb3 fix: Fix the malfunction of _match_service_ rules in C++ Wasm plugins (#2723) 2025-08-15 16:47:30 +08:00
Kent Dong
5822868f87 feat: Support adding a proxy server in between when forwarding requests to upstream (#2710)
Co-authored-by: 澄潭 <zty98751@alibaba-inc.com>
2025-08-15 16:15:42 +08:00
澄潭
995bcc2168 feat(transformer): Add split and retain strategy for dedupe (#2761) 2025-08-15 15:21:13 +08:00
Kent Dong
a3310f1a3b fix: Allow duplicated items in the IP list of ip-restriction config (#2755)
Co-authored-by: 澄潭 <zty98751@alibaba-inc.com>
2025-08-15 13:39:26 +08:00
Jingze
0bb934073a fix(golang-filter): fix mcp server contruct envoy filter unit test (#2757) 2025-08-13 17:43:15 +08:00
Jingze
247de6a349 fix(golang-filter): fix bug of stop and buffer in decode data (#2754) 2025-08-13 11:37:01 +08:00
co63oc
79b3b23aab Fix typos (#2628)
Signed-off-by: co63oc <co63oc@users.noreply.github.com>
Co-authored-by: 澄潭 <zty98751@alibaba-inc.com>
Co-authored-by: xingpiaoliang <erasernoobx@outlook.com>
2025-08-12 15:58:32 +08:00
Jingze
b9d6343efa fix(ip-restriction): fix bug of set ip_source_type (#2743) 2025-08-11 18:47:48 +08:00
xingpiaoliang
0af00bef6b feat(ai-proxy): gemini model multimodal support (#2698) 2025-08-11 15:54:34 +08:00
Kent Dong
953b95cf92 chore: Update some log levels in ai-statistics (#2740) 2025-08-11 13:59:01 +08:00
aias00
a76808171f feat: improve ai statistic plugin (#2671) 2025-08-11 13:43:00 +08:00
rinfx
f7813df1d7 add value length limit for ai statistics, truncate when over limit (#2729) 2025-08-11 11:12:03 +08:00
WeixinX
33ce18df5a feat(wasm-go): add field reroute to disable route reselection (#2739) 2025-08-11 09:37:35 +08:00
aias00
a1bf1ff009 feat(provider): add support for Grok provider in AI proxy (#2713)
Co-authored-by: 韩贤涛 <601803023@qq.com>
2025-08-07 17:22:47 +08:00
澄潭
b69e3a8f30 Deprecate the use of slash as a concatenation character for mcp server and tool, to avoid non-compliance with function naming conventions. (#2711) 2025-08-07 15:18:31 +08:00
NOBODY
5ee878198c feat(ai-proxy): gemini model thinking support (#2712) 2025-08-06 06:50:46 +08:00
rinfx
943fda0a9c AI security streaming (#2696) 2025-08-04 20:47:18 +08:00
WeixinX
abc31169a2 fix(wasm-go): transformer performs an add op when the replace key does not exist (#2706) 2025-08-04 20:38:29 +08:00
韩贤涛
5f65b4f5b0 feat: Rust WASM supports Redis database configuration option (#2704) 2025-08-03 12:56:27 +08:00
澄潭
645646fe22 Fix the issue where AI route fallback does not work when using Bedrock. (#2653) 2025-07-31 20:16:16 +08:00
澄潭
4acb65cc67 Update README_ZH.md 2025-07-31 14:26:07 +08:00
github-actions[bot]
e63a2e0251 Add release notes (#2693)
Co-authored-by: 澄潭 <zty98751@alibaba-inc.com>
2025-07-31 14:23:54 +08:00
267 changed files with 41745 additions and 1238 deletions


@@ -122,7 +122,7 @@ jobs:
set -e
cd /workspace/plugins/wasm-go/extensions/${PLUGIN_NAME}
go mod tidy
GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o plugin.wasm main.go
GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o plugin.wasm .
tar czvf plugin.tar.gz plugin.wasm
echo ${{ secrets.REGISTRY_PASSWORD }} | oras login -u ${{ secrets.REGISTRY_USERNAME }} --password-stdin ${{ env.IMAGE_REGISTRY_SERVICE }}
oras push ${target_image} ${push_command}
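The build-command fix above (building `.` instead of `main.go`) can be reproduced locally; a minimal sketch of the corrected invocation, with an illustrative plugin name:

```shell
# Sketch of the corrected local build, mirroring the workflow change above.
# Passing "." instead of "main.go" compiles the whole package, so plugins whose
# code is split across several Go files still build.
# PLUGIN_NAME is illustrative.
PLUGIN_NAME="hmac-auth-apisix"
BUILD_CMD="GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o plugin.wasm ."
echo "cd plugins/wasm-go/extensions/${PLUGIN_NAME} && ${BUILD_CMD}"
```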


@@ -0,0 +1,378 @@
name: Wasm Plugin Unit Tests(GO)
on:
push:
branches: [ main ]
paths:
- 'plugins/wasm-go/extensions/**'
- '.github/workflows/wasm-plugin-unit-test.yml'
- 'go.mod'
- 'go.sum'
pull_request:
branches: [ "*" ]
paths:
- 'plugins/wasm-go/extensions/**'
- '.github/workflows/wasm-plugin-unit-test.yml'
- 'go.mod'
- 'go.sum'
env:
GO111MODULE: on
CGO_ENABLED: 0
GOOS: linux
GOARCH: amd64
jobs:
detect-changed-plugins:
name: Detect Changed Plugins
runs-on: ubuntu-latest
outputs:
changed-plugins: ${{ steps.detect.outputs.plugins }}
has-changes: ${{ steps.detect.outputs.has-changes }}
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0 # fetch full history for comparison
- name: Detect changed plugins
id: detect
run: |
# Get the list of changed files
if [ "${{ github.event_name }}" = "pull_request" ]; then
# PR mode: compare the base branch with the source branch
git fetch origin ${{ github.base_ref }}
CHANGED_FILES=$(git diff --name-only origin/${{ github.base_ref }}...HEAD)
else
# Push mode: compare the current commit with the previous one
CHANGED_FILES=$(git diff --name-only HEAD~1 HEAD)
fi
echo "Changed files:"
echo "$CHANGED_FILES"
# Extract the names of the changed plugins
CHANGED_PLUGINS=""
for file in $CHANGED_FILES; do
if [[ $file =~ ^plugins/wasm-go/extensions/([^/]+)/ ]]; then
PLUGIN_NAME="${BASH_REMATCH[1]}"
if [[ ! " $CHANGED_PLUGINS " =~ " $PLUGIN_NAME " ]]; then
# Fix: only add a space separator when the list is non-empty
if [ -z "$CHANGED_PLUGINS" ]; then
CHANGED_PLUGINS="$PLUGIN_NAME"
else
CHANGED_PLUGINS="$CHANGED_PLUGINS $PLUGIN_NAME"
fi
fi
fi
done
# If no plugins changed, do not trigger the tests
if [ -z "$CHANGED_PLUGINS" ]; then
echo "No plugin changes detected, skipping tests"
echo "has-changes=false" >> $GITHUB_OUTPUT
echo "plugins=[]" >> $GITHUB_OUTPUT
else
echo "Changed plugins: $CHANGED_PLUGINS"
echo "has-changes=true" >> $GITHUB_OUTPUT
# Convert the space-separated list to a JSON array
PLUGINS_JSON=$(echo "$CHANGED_PLUGINS" | sed 's/ /","/g' | sed 's/^/["/' | sed 's/$/"]/')
echo "PLUGINS_JSON: $PLUGINS_JSON"
echo "plugins=$PLUGINS_JSON" >> $GITHUB_OUTPUT
fi
test:
name: Test Changed Plugins
runs-on: ubuntu-latest
needs: detect-changed-plugins
if: needs.detect-changed-plugins.outputs.has-changes == 'true'
strategy:
fail-fast: false
matrix:
plugin: ${{ fromJSON(needs.detect-changed-plugins.outputs.changed-plugins) }}
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Go 1.24
uses: actions/setup-go@v4
with:
go-version: 1.24
cache: true
- name: Install test tools
run: |
go install gotest.tools/gotestsum@latest
# gocov tooling removed; Codecov is used directly
- name: Build WASM for ${{ matrix.plugin }}
working-directory: plugins/wasm-go/extensions/${{ matrix.plugin }}
run: |
echo "Building WASM for ${{ matrix.plugin }}..."
# Check whether a main.go file exists
export GOOS=wasip1
export GOARCH=wasm
# Exit immediately if the WASM build fails
if ! go build -buildmode=c-shared -o main.wasm ./; then
echo "❌ WASM build failed for ${{ matrix.plugin }}"
exit 1
fi
# Verify that the WASM file was generated
if [ ! -f "main.wasm" ]; then
echo "❌ WASM file not generated for ${{ matrix.plugin }}"
exit 1
fi
echo "✅ WASM build successful for ${{ matrix.plugin }}"
- name: Set WASM_PATH environment variable
run: |
echo "WASM_PATH=$(pwd)/plugins/wasm-go/extensions/${{ matrix.plugin }}/main.wasm" >> $GITHUB_ENV
- name: Run tests with coverage for ${{ matrix.plugin }}
working-directory: plugins/wasm-go/extensions/${{ matrix.plugin }}
run: |
# Check whether a main_test.go file exists
if [ -f "main_test.go" ]; then
echo "Running tests for ${{ matrix.plugin }}..."
# Run the tests and generate a coverage report
gotestsum --junitfile ../../../../test-results-${{ matrix.plugin }}.xml \
--format standard-verbose \
--jsonfile ../../../../test-output-${{ matrix.plugin }}.json \
-- -coverprofile=coverage-${{ matrix.plugin }}.out -covermode=atomic -coverpkg=./... ./...
echo "✅ Tests completed for ${{ matrix.plugin }}"
else
echo "No tests found for ${{ matrix.plugin }}, skipping..."
# Create an empty test-result file
echo '<?xml version="1.0" encoding="UTF-8"?><testsuites><testsuite name="no-tests" tests="0" failures="0" errors="0" time="0"></testsuite></testsuites>' > ../../../../test-results-${{ matrix.plugin }}.xml
fi
- name: Upload test results for ${{ matrix.plugin }}
uses: actions/upload-artifact@v4
if: always()
with:
name: test-results-${{ matrix.plugin }}
path: |
test-results-${{ matrix.plugin }}.xml
test-output-${{ matrix.plugin }}.json
retention-days: 30
- name: Upload coverage report for ${{ matrix.plugin }}
uses: actions/upload-artifact@v4
if: always()
with:
name: coverage-${{ matrix.plugin }}
path: plugins/wasm-go/extensions/${{ matrix.plugin }}/coverage-${{ matrix.plugin }}.out
retention-days: 30
test-summary:
name: Test Summary & Coverage
runs-on: ubuntu-latest
needs: [detect-changed-plugins, test]
if: always() && needs.detect-changed-plugins.outputs.has-changes == 'true'
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Go 1.24
uses: actions/setup-go@v4
with:
go-version: 1.24
cache: true
- name: Install required tools
run: |
go install github.com/wadey/gocovmerge@latest
- name: Download all test results
uses: actions/download-artifact@v4
with:
pattern: test-results-*
merge-multiple: true
path: ${{ github.workspace }}
- name: Download all coverage files
uses: actions/download-artifact@v4
with:
pattern: coverage-*
merge-multiple: true
path: ${{ github.workspace }}
- name: Generate comprehensive test summary
run: |
echo "## 🧪 Go Plugin Test Results" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
total_plugins=0
passed_plugins=0
failed_plugins=0
total_tests=0
total_failures=0
total_errors=0
echo "### 📊 Test Results by Plugin" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
for result_file in test-results-*.xml; do
if [ -f "$result_file" ]; then
plugin_name=$(echo "$result_file" | sed 's/test-results-\(.*\)\.xml/\1/')
total_plugins=$((total_plugins + 1))
# Parse the XML to extract test results
if grep -q '<testsuite' "$result_file"; then
# Parsing XML attributes with grep is more stable and reliable
tests=$(grep -o 'tests="[0-9]*"' "$result_file" | head -1 | grep -o '[0-9]*' || echo "0")
failures=$(grep -o 'failures="[0-9]*"' "$result_file" | head -1 | grep -o '[0-9]*' || echo "0")
errors=$(grep -o 'errors="[0-9]*"' "$result_file" | head -1 | grep -o '[0-9]*' || echo "0")
time=$(grep -o 'time="[0-9.]*"' "$result_file" | head -1 | grep -o '[0-9.]*' || echo "0")
# Ensure the values are valid to avoid bash arithmetic errors
tests=${tests:-0}
failures=${failures:-0}
errors=${errors:-0}
# Convert to integers for arithmetic
total_tests=$((total_tests + tests))
total_failures=$((total_failures + failures))
total_errors=$((total_errors + errors))
if [ "$failures" = "0" ] && [ "$errors" = "0" ]; then
echo "✅ **$plugin_name**: $tests tests passed in ${time}s" >> $GITHUB_STEP_SUMMARY
passed_plugins=$((passed_plugins + 1))
else
echo "❌ **$plugin_name**: $tests tests, $failures failures, $errors errors in ${time}s" >> $GITHUB_STEP_SUMMARY
failed_plugins=$((failed_plugins + 1))
fi
else
echo "⚠️ **$plugin_name**: No tests found" >> $GITHUB_STEP_SUMMARY
fi
fi
done
echo "" >> $GITHUB_STEP_SUMMARY
echo "### 📈 Coverage Report" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "📊 **Coverage reports are now available on Codecov**" >> $GITHUB_STEP_SUMMARY
echo "🔗 **This Commit Coverage**: https://codecov.io/gh/${{ github.repository }}/commit/${{ github.sha }}" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "### 🎯 Summary" >> $GITHUB_STEP_SUMMARY
echo "- **Total plugins**: $total_plugins" >> $GITHUB_STEP_SUMMARY
echo "- **Passed**: $passed_plugins ✅" >> $GITHUB_STEP_SUMMARY
echo "- **Failed**: $failed_plugins ❌" >> $GITHUB_STEP_SUMMARY
echo "- **Total tests**: $total_tests" >> $GITHUB_STEP_SUMMARY
echo "- **Total failures**: $total_failures" >> $GITHUB_STEP_SUMMARY
echo "- **Total errors**: $total_errors" >> $GITHUB_STEP_SUMMARY
# If there are failures, show the details
if [ $total_failures -gt 0 ] || [ $total_errors -gt 0 ]; then
echo "" >> $GITHUB_STEP_SUMMARY
echo "### ❌ Failed Tests Details" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "**Failed plugins**: $failed_plugins" >> $GITHUB_STEP_SUMMARY
echo "**Total failures**: $total_failures" >> $GITHUB_STEP_SUMMARY
echo "**Total errors**: $total_errors" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "📋 **View detailed logs**: [Click here](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }})" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
# Show details for each failed plugin
echo "#### 📊 Failed Plugin Details" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
for result_file in test-results-*.xml; do
if [ -f "$result_file" ]; then
plugin_name=$(echo "$result_file" | sed 's/test-results-\(.*\)\.xml/\1/')
# Check for failures
failures=$(grep -o 'failures="[0-9]*"' "$result_file" | head -1 | grep -o '[0-9]*' || echo "0")
errors=$(grep -o 'errors="[0-9]*"' "$result_file" | head -1 | grep -o '[0-9]*' || echo "0")
# Ensure the values are valid
failures=${failures:-0}
errors=${errors:-0}
if [ "$failures" -gt 0 ] || [ "$errors" -gt 0 ]; then
echo "**$plugin_name**:" >> $GITHUB_STEP_SUMMARY
echo "- Failures: $failures" >> $GITHUB_STEP_SUMMARY
echo "- Errors: $errors" >> $GITHUB_STEP_SUMMARY
echo "- [View plugin logs](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }})" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
fi
fi
done
fi
# - name: Merge coverage reports
# run: |
# echo "Merging coverage reports..."
#
# # Search with absolute paths; more reliable
# coverage_files=$(find ${{ github.workspace }} -name "coverage-*")
#
# if [ -n "$coverage_files" ]; then
# echo "Found coverage files:"
# echo "$coverage_files"
#
# # Merge sequentially with gocovmerge
# echo "Merging Go coverage files using gocovmerge sequential method..."
#
# # Convert the file list into an array
# readarray -t coverage_array <<< "$coverage_files"
# file_count=${#coverage_array[@]}
#
# echo "Total files to merge: $file_count"
#
# # Copy the first file as the base
# cp "${coverage_array[0]}" ${{ github.workspace }}/merged_coverage.out
# echo "Starting with: ${coverage_array[0]}"
#
# # If there are multiple files, merge each remaining file into the final target
# if [ $file_count -gt 1 ]; then
# echo "Multiple files, merging sequentially with gocovmerge..."
#
# for ((i=1; i<file_count; i++)); do
# current_file="${coverage_array[i]}"
#
# echo "Merging file $((i+1))/$file_count: $current_file"
#
# # Merge into the final target file with gocovmerge
# gocovmerge "${{ github.workspace }}/merged_coverage.out" "$current_file" > "${{ github.workspace }}/temp_merge.out"
# mv "${{ github.workspace }}/temp_merge.out" "${{ github.workspace }}/merged_coverage.out"
#
# echo "Successfully merged with $current_file"
# done
# fi
#
# echo "Coverage reports merged successfully using gocovmerge sequential method"
# echo "Merged file size: $(wc -c < ${{ github.workspace }}/merged_coverage.out) bytes"
# else
# echo "No coverage files found"
# # Create an empty coverage file
# echo "mode: atomic" > ${{ github.workspace }}/merged_coverage.out
# fi
# - name: Upload merged coverage to Codecov
# uses: codecov/codecov-action@v4
# if: always()
# with:
# file: ${{ github.workspace }}/merged_coverage.out
# flags: wasm-go-plugins-tests
# name: codecov-wasm-go-plugins
# fail_ci_if_error: false
# verbose: true
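The changed-plugin detection in the workflow above boils down to extracting directory names under `plugins/wasm-go/extensions/`, deduplicating them, and emitting a JSON array for the job matrix. A minimal bash sketch of that transformation (the sample file list is illustrative; in CI it comes from `git diff --name-only`):

```shell
# Sample input; in the workflow this comes from `git diff --name-only`.
CHANGED_FILES="plugins/wasm-go/extensions/hmac-auth-apisix/main.go
plugins/wasm-go/extensions/ai-proxy/provider/grok.go
docs/README.md"

CHANGED_PLUGINS=""
for file in $CHANGED_FILES; do
  # Keep only files under a plugin directory; capture the plugin name
  if [[ $file =~ ^plugins/wasm-go/extensions/([^/]+)/ ]]; then
    PLUGIN_NAME="${BASH_REMATCH[1]}"
    # Deduplicate, adding a space separator only when the list is non-empty
    if [[ ! " $CHANGED_PLUGINS " =~ " $PLUGIN_NAME " ]]; then
      if [ -z "$CHANGED_PLUGINS" ]; then
        CHANGED_PLUGINS="$PLUGIN_NAME"
      else
        CHANGED_PLUGINS="$CHANGED_PLUGINS $PLUGIN_NAME"
      fi
    fi
  fi
done

# Convert the space-separated list to a JSON array for `fromJSON` in the matrix
PLUGINS_JSON=$(echo "$CHANGED_PLUGINS" | sed 's/ /","/g' | sed 's/^/["/' | sed 's/$/"]/')
echo "$PLUGINS_JSON"   # ["hmac-auth-apisix","ai-proxy"]
```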


@@ -3,7 +3,7 @@
/istio @SpecialYang @johnlanni
/pkg @SpecialYang @johnlanni @CH3CHO
/plugins @johnlanni @CH3CHO @rinfx @erasernoob
/plugins/wasm-go/extensions/ai-proxy @cr7258 @CH3CHO @rinfx @wydream
/plugins/wasm-go/extensions/ai-proxy @rinfx @wydream @johnlanni
/plugins/wasm-rust @007gzs @jizhuozhi
/registry @Erica177 @2456868764 @johnlanni
/test @Xunzhuo @2456868764 @CH3CHO


@@ -1 +1 @@
higress-console: v2.1.6
higress-console: v2.1.7


@@ -137,6 +137,8 @@ endif
# for now docker is limited to Linux compiles - why ?
include docker/docker.mk
docker-build-amd64: docker.higress-amd64 ## Build and push amdd64 docker images to registry defined by $HUB and $TAG
docker-build: docker.higress ## Build and push docker images to registry defined by $HUB and $TAG
docker-buildx-push: clean-env docker.higress-buildx
@@ -144,7 +146,7 @@ docker-buildx-push: clean-env docker.higress-buildx
export PARENT_GIT_TAG:=$(shell cat VERSION)
export PARENT_GIT_REVISION:=$(TAG)
export ENVOY_PACKAGE_URL_PATTERN?=https://github.com/higress-group/proxy/releases/download/v2.1.8/envoy-symbol-ARCH.tar.gz
export ENVOY_PACKAGE_URL_PATTERN?=https://github.com/higress-group/proxy/releases/download/v2.1.9/envoy-symbol-ARCH.tar.gz
build-envoy: prebuild
./tools/hack/build-envoy.sh
@@ -192,7 +194,7 @@ install: pre-install
helm install higress helm/higress -n higress-system --create-namespace --set 'global.local=true'
HIGRESS_LATEST_IMAGE_TAG ?= latest
ENVOY_LATEST_IMAGE_TAG ?= latest
ENVOY_LATEST_IMAGE_TAG ?= 48da465cfd0dc5c9ac851bd2b9743780dc82dd8a
ISTIO_LATEST_IMAGE_TAG ?= latest
install-dev: pre-install


@@ -1 +1 @@
v2.1.6
v2.1.7


@@ -247,6 +247,23 @@ spec:
properties:
spec:
properties:
proxies:
items:
properties:
connectTimeout:
type: integer
listenerPort:
type: integer
name:
type: string
serverAddress:
type: string
serverPort:
type: integer
type:
type: string
type: object
type: array
registries:
items:
properties:
@@ -309,6 +326,8 @@ spec:
type: integer
protocol:
type: string
proxyName:
type: string
sni:
type: string
type:
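The new `proxies` array and the `proxyName` string added to the registry entry in the CRD schema above suggest usage along these lines. This is a hypothetical `McpBridge` resource sketched from the schema; names, values, and the `nacos2` registry type are illustrative:

```yaml
apiVersion: networking.higress.io/v1
kind: McpBridge
metadata:
  name: default
  namespace: higress-system
spec:
  proxies:
    - name: corp-proxy          # referenced below via proxyName
      type: http
      serverAddress: proxy.corp.internal
      serverPort: 3128
      listenerPort: 10000
      connectTimeout: 5000
  registries:
    - type: nacos2
      name: my-nacos
      domain: nacos.corp.internal
      port: 8848
      proxyName: corp-proxy     # forward registry traffic through the proxy above
```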


@@ -65,6 +65,7 @@ type McpBridge struct {
unknownFields protoimpl.UnknownFields
Registries []*RegistryConfig `protobuf:"bytes,1,rep,name=registries,proto3" json:"registries,omitempty"`
Proxies []*ProxyConfig `protobuf:"bytes,2,rep,name=proxies,proto3" json:"proxies,omitempty"`
}
func (x *McpBridge) Reset() {
@@ -106,6 +107,13 @@ func (x *McpBridge) GetRegistries() []*RegistryConfig {
return nil
}
func (x *McpBridge) GetProxies() []*ProxyConfig {
if x != nil {
return x.Proxies
}
return nil
}
type RegistryConfig struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
@@ -136,6 +144,7 @@ type RegistryConfig struct {
EnableScopeMcpServers *wrappers.BoolValue `protobuf:"bytes,23,opt,name=enableScopeMcpServers,proto3" json:"enableScopeMcpServers,omitempty"`
AllowMcpServers []string `protobuf:"bytes,24,rep,name=allowMcpServers,proto3" json:"allowMcpServers,omitempty"`
Metadata map[string]*InnerMap `protobuf:"bytes,25,rep,name=metadata,proto3" json:"metadata,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
ProxyName string `protobuf:"bytes,26,opt,name=proxyName,proto3" json:"proxyName,omitempty"`
}
func (x *RegistryConfig) Reset() {
@@ -345,6 +354,100 @@ func (x *RegistryConfig) GetMetadata() map[string]*InnerMap {
return nil
}
func (x *RegistryConfig) GetProxyName() string {
if x != nil {
return x.ProxyName
}
return ""
}
type ProxyConfig struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Type string `protobuf:"bytes,1,opt,name=type,proto3" json:"type,omitempty"`
Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"`
ServerAddress string `protobuf:"bytes,3,opt,name=serverAddress,proto3" json:"serverAddress,omitempty"`
ServerPort uint32 `protobuf:"varint,4,opt,name=serverPort,proto3" json:"serverPort,omitempty"`
ListenerPort uint32 `protobuf:"varint,5,opt,name=listenerPort,proto3" json:"listenerPort,omitempty"`
ConnectTimeout uint32 `protobuf:"varint,6,opt,name=connectTimeout,proto3" json:"connectTimeout,omitempty"`
}
func (x *ProxyConfig) Reset() {
*x = ProxyConfig{}
if protoimpl.UnsafeEnabled {
mi := &file_networking_v1_mcp_bridge_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *ProxyConfig) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ProxyConfig) ProtoMessage() {}
func (x *ProxyConfig) ProtoReflect() protoreflect.Message {
mi := &file_networking_v1_mcp_bridge_proto_msgTypes[2]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ProxyConfig.ProtoReflect.Descriptor instead.
func (*ProxyConfig) Descriptor() ([]byte, []int) {
return file_networking_v1_mcp_bridge_proto_rawDescGZIP(), []int{2}
}
func (x *ProxyConfig) GetType() string {
if x != nil {
return x.Type
}
return ""
}
func (x *ProxyConfig) GetName() string {
if x != nil {
return x.Name
}
return ""
}
func (x *ProxyConfig) GetServerAddress() string {
if x != nil {
return x.ServerAddress
}
return ""
}
func (x *ProxyConfig) GetServerPort() uint32 {
if x != nil {
return x.ServerPort
}
return 0
}
func (x *ProxyConfig) GetListenerPort() uint32 {
if x != nil {
return x.ListenerPort
}
return 0
}
func (x *ProxyConfig) GetConnectTimeout() uint32 {
if x != nil {
return x.ConnectTimeout
}
return 0
}
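The generated accessors above imply a message shaped roughly as follows. This is a sketch reconstructed from the `protobuf:"..."` struct tags in the diff, not the actual `.proto` source:

```protobuf
message ProxyConfig {
  string type = 1;            // proxy type
  string name = 2;            // referenced by RegistryConfig.proxyName
  string serverAddress = 3;
  uint32 serverPort = 4;
  uint32 listenerPort = 5;
  uint32 connectTimeout = 6;
}
```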
type InnerMap struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
@@ -356,7 +459,7 @@ type InnerMap struct {
func (x *InnerMap) Reset() {
*x = InnerMap{}
if protoimpl.UnsafeEnabled {
mi := &file_networking_v1_mcp_bridge_proto_msgTypes[2]
mi := &file_networking_v1_mcp_bridge_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -369,7 +472,7 @@ func (x *InnerMap) String() string {
func (*InnerMap) ProtoMessage() {}
func (x *InnerMap) ProtoReflect() protoreflect.Message {
mi := &file_networking_v1_mcp_bridge_proto_msgTypes[2]
mi := &file_networking_v1_mcp_bridge_proto_msgTypes[3]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -382,7 +485,7 @@ func (x *InnerMap) ProtoReflect() protoreflect.Message {
// Deprecated: Use InnerMap.ProtoReflect.Descriptor instead.
func (*InnerMap) Descriptor() ([]byte, []int) {
return file_networking_v1_mcp_bridge_proto_rawDescGZIP(), []int{2}
return file_networking_v1_mcp_bridge_proto_rawDescGZIP(), []int{3}
}
func (x *InnerMap) GetInnerMap() map[string]string {
@@ -404,100 +507,119 @@ var file_networking_v1_mcp_bridge_proto_rawDesc = []byte{
0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x77, 0x72, 0x61, 0x70, 0x70, 0x65,
0x72, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65,
0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x73, 0x74, 0x72, 0x75, 0x63, 0x74,
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x52, 0x0a, 0x09, 0x4d, 0x63, 0x70, 0x42, 0x72, 0x69,
0x64, 0x67, 0x65, 0x12, 0x45, 0x0a, 0x0a, 0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x69, 0x65,
0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x25, 0x2e, 0x68, 0x69, 0x67, 0x72, 0x65, 0x73,
0x73, 0x2e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x69, 0x6e, 0x67, 0x2e, 0x76, 0x31, 0x2e,
0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a,
0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x69, 0x65, 0x73, 0x22, 0xa8, 0x09, 0x0a, 0x0e, 0x52,
0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x17, 0x0a,
0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x03, 0xe0, 0x41, 0x02,
0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02,
0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x6f,
0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x42, 0x03, 0xe0, 0x41, 0x02, 0x52,
0x06, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x17, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18,
0x04, 0x20, 0x01, 0x28, 0x0d, 0x42, 0x03, 0xe0, 0x41, 0x02, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74,
0x12, 0x2e, 0x0a, 0x12, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73,
0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x6e, 0x61,
0x63, 0x6f, 0x73, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72,
0x12, 0x26, 0x0a, 0x0e, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4b,
0x65, 0x79, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x41,
0x63, 0x63, 0x65, 0x73, 0x73, 0x4b, 0x65, 0x79, 0x12, 0x26, 0x0a, 0x0e, 0x6e, 0x61, 0x63, 0x6f,
0x73, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x4b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09,
0x52, 0x0e, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x4b, 0x65, 0x79,
0x12, 0x2a, 0x0a, 0x10, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61,
0x63, 0x65, 0x49, 0x64, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x10, 0x6e, 0x61, 0x63, 0x6f,
0x73, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x49, 0x64, 0x12, 0x26, 0x0a, 0x0e,
0x6e, 0x61, 0x63, 0x6f, 0x73, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, 0x09,
0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x4e, 0x61, 0x6d, 0x65, 0x73,
0x70, 0x61, 0x63, 0x65, 0x12, 0x20, 0x0a, 0x0b, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x47, 0x72, 0x6f,
0x75, 0x70, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0b, 0x6e, 0x61, 0x63, 0x6f, 0x73,
0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x52,
0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x0b,
0x20, 0x01, 0x28, 0x03, 0x52, 0x14, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x52, 0x65, 0x66, 0x72, 0x65,
0x73, 0x68, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x28, 0x0a, 0x0f, 0x63, 0x6f,
0x6e, 0x73, 0x75, 0x6c, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, 0x0c, 0x20,
0x01, 0x28, 0x09, 0x52, 0x0f, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x4e, 0x61, 0x6d, 0x65, 0x73,
0x70, 0x61, 0x63, 0x65, 0x12, 0x26, 0x0a, 0x0e, 0x7a, 0x6b, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63,
0x65, 0x73, 0x50, 0x61, 0x74, 0x68, 0x18, 0x0d, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0e, 0x7a, 0x6b,
0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x50, 0x61, 0x74, 0x68, 0x12, 0x2a, 0x0a, 0x10,
0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x44, 0x61, 0x74, 0x61, 0x63, 0x65, 0x6e, 0x74, 0x65, 0x72,
0x18, 0x0e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x10, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x44, 0x61,
0x74, 0x61, 0x63, 0x65, 0x6e, 0x74, 0x65, 0x72, 0x12, 0x2a, 0x0a, 0x10, 0x63, 0x6f, 0x6e, 0x73,
0x75, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x54, 0x61, 0x67, 0x18, 0x0f, 0x20, 0x01,
0x28, 0x09, 0x52, 0x10, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63,
0x65, 0x54, 0x61, 0x67, 0x12, 0x34, 0x0a, 0x15, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x52, 0x65,
0x66, 0x72, 0x65, 0x73, 0x68, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x10, 0x20,
0x01, 0x28, 0x03, 0x52, 0x15, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x52, 0x65, 0x66, 0x72, 0x65,
0x73, 0x68, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x26, 0x0a, 0x0e, 0x61, 0x75,
0x74, 0x68, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x11, 0x20, 0x01,
0x28, 0x09, 0x52, 0x0e, 0x61, 0x75, 0x74, 0x68, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x4e, 0x61,
0x6d, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x12,
0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x10,
0x0a, 0x03, 0x73, 0x6e, 0x69, 0x18, 0x13, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x73, 0x6e, 0x69,
0x12, 0x36, 0x0a, 0x16, 0x6d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x45, 0x78, 0x70,
0x6f, 0x72, 0x74, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x14, 0x20, 0x03, 0x28, 0x09,
0x52, 0x16, 0x6d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x45, 0x78, 0x70, 0x6f, 0x72,
0x74, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x2a, 0x0a, 0x10, 0x6d, 0x63, 0x70, 0x53,
0x65, 0x72, 0x76, 0x65, 0x72, 0x42, 0x61, 0x73, 0x65, 0x55, 0x72, 0x6c, 0x18, 0x15, 0x20, 0x01,
0x28, 0x09, 0x52, 0x10, 0x6d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x42, 0x61, 0x73,
0x65, 0x55, 0x72, 0x6c, 0x12, 0x44, 0x0a, 0x0f, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x4d, 0x43,
0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x18, 0x16, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e,
0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e,
0x42, 0x6f, 0x6f, 0x6c, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x52, 0x0f, 0x65, 0x6e, 0x61, 0x62, 0x6c,
0x65, 0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x50, 0x0a, 0x15, 0x65, 0x6e,
0x61, 0x62, 0x6c, 0x65, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x4d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76,
0x65, 0x72, 0x73, 0x18, 0x17, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67,
0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x42, 0x6f, 0x6f, 0x6c,
0x56, 0x61, 0x6c, 0x75, 0x65, 0x52, 0x15, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x63, 0x6f,
0x70, 0x65, 0x4d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x28, 0x0a, 0x0f,
0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x4d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18,
0x18, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0f, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x4d, 0x63, 0x70, 0x53,
0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x4f, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61,
0x74, 0x61, 0x18, 0x19, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x33, 0x2e, 0x68, 0x69, 0x67, 0x72, 0x65,
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x90, 0x01, 0x0a, 0x09, 0x4d, 0x63, 0x70, 0x42, 0x72,
0x69, 0x64, 0x67, 0x65, 0x12, 0x45, 0x0a, 0x0a, 0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x69,
0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x25, 0x2e, 0x68, 0x69, 0x67, 0x72, 0x65,
0x73, 0x73, 0x2e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x69, 0x6e, 0x67, 0x2e, 0x76, 0x31,
0x2e, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e,
0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, 0x6d,
0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x1a, 0x5c, 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, 0x64,
0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18,
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x35, 0x0a, 0x05, 0x76, 0x61,
0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x68, 0x69, 0x67, 0x72,
0x65, 0x73, 0x73, 0x2e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x69, 0x6e, 0x67, 0x2e, 0x76,
0x31, 0x2e, 0x49, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61, 0x70, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75,
0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x93, 0x01, 0x0a, 0x08, 0x49, 0x6e, 0x6e, 0x65, 0x72, 0x4d,
0x61, 0x70, 0x12, 0x4a, 0x0a, 0x09, 0x69, 0x6e, 0x6e, 0x65, 0x72, 0x5f, 0x6d, 0x61, 0x70, 0x18,
0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2d, 0x2e, 0x68, 0x69, 0x67, 0x72, 0x65, 0x73, 0x73, 0x2e,
0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x69, 0x6e, 0x67, 0x2e, 0x76, 0x31, 0x2e, 0x49, 0x6e,
0x6e, 0x65, 0x72, 0x4d, 0x61, 0x70, 0x2e, 0x49, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61, 0x70, 0x45,
0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, 0x69, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61, 0x70, 0x1a, 0x3b,
0x0a, 0x0d, 0x49, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61, 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12,
0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65,
0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09,
0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x2e, 0x5a, 0x2c, 0x67,
0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x61, 0x6c, 0x69, 0x62, 0x61, 0x62,
0x61, 0x2f, 0x68, 0x69, 0x67, 0x72, 0x65, 0x73, 0x73, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x6e, 0x65,
0x74, 0x77, 0x6f, 0x72, 0x6b, 0x69, 0x6e, 0x67, 0x2f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x33,
0x2e, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52,
0x0a, 0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x69, 0x65, 0x73, 0x12, 0x3c, 0x0a, 0x07, 0x70,
0x72, 0x6f, 0x78, 0x69, 0x65, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x22, 0x2e, 0x68,
0x69, 0x67, 0x72, 0x65, 0x73, 0x73, 0x2e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x69, 0x6e,
0x67, 0x2e, 0x76, 0x31, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67,
0x52, 0x07, 0x70, 0x72, 0x6f, 0x78, 0x69, 0x65, 0x73, 0x22, 0xc6, 0x09, 0x0a, 0x0e, 0x52, 0x65,
0x67, 0x69, 0x73, 0x74, 0x72, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x17, 0x0a, 0x04,
0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x03, 0xe0, 0x41, 0x02, 0x52,
0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20,
0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x6f, 0x6d,
0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x42, 0x03, 0xe0, 0x41, 0x02, 0x52, 0x06,
0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x17, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x04,
0x20, 0x01, 0x28, 0x0d, 0x42, 0x03, 0xe0, 0x41, 0x02, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12,
0x2e, 0x0a, 0x12, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x53,
0x65, 0x72, 0x76, 0x65, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x6e, 0x61, 0x63,
0x6f, 0x73, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12,
0x26, 0x0a, 0x0e, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x4b, 0x65,
0x79, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x41, 0x63,
0x63, 0x65, 0x73, 0x73, 0x4b, 0x65, 0x79, 0x12, 0x26, 0x0a, 0x0e, 0x6e, 0x61, 0x63, 0x6f, 0x73,
0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x4b, 0x65, 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52,
0x0e, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x4b, 0x65, 0x79, 0x12,
0x2a, 0x0a, 0x10, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63,
0x65, 0x49, 0x64, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x10, 0x6e, 0x61, 0x63, 0x6f, 0x73,
0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x49, 0x64, 0x12, 0x26, 0x0a, 0x0e, 0x6e,
0x61, 0x63, 0x6f, 0x73, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, 0x09, 0x20,
0x01, 0x28, 0x09, 0x52, 0x0e, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70,
0x61, 0x63, 0x65, 0x12, 0x20, 0x0a, 0x0b, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x47, 0x72, 0x6f, 0x75,
0x70, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0b, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x47,
0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x52, 0x65,
0x66, 0x72, 0x65, 0x73, 0x68, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x0b, 0x20,
0x01, 0x28, 0x03, 0x52, 0x14, 0x6e, 0x61, 0x63, 0x6f, 0x73, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73,
0x68, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x28, 0x0a, 0x0f, 0x63, 0x6f, 0x6e,
0x73, 0x75, 0x6c, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, 0x0c, 0x20, 0x01,
0x28, 0x09, 0x52, 0x0f, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70,
0x61, 0x63, 0x65, 0x12, 0x26, 0x0a, 0x0e, 0x7a, 0x6b, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65,
0x73, 0x50, 0x61, 0x74, 0x68, 0x18, 0x0d, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0e, 0x7a, 0x6b, 0x53,
0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x50, 0x61, 0x74, 0x68, 0x12, 0x2a, 0x0a, 0x10, 0x63,
0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x44, 0x61, 0x74, 0x61, 0x63, 0x65, 0x6e, 0x74, 0x65, 0x72, 0x18,
0x0e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x10, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x44, 0x61, 0x74,
0x61, 0x63, 0x65, 0x6e, 0x74, 0x65, 0x72, 0x12, 0x2a, 0x0a, 0x10, 0x63, 0x6f, 0x6e, 0x73, 0x75,
0x6c, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x54, 0x61, 0x67, 0x18, 0x0f, 0x20, 0x01, 0x28,
0x09, 0x52, 0x10, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65,
0x54, 0x61, 0x67, 0x12, 0x34, 0x0a, 0x15, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x52, 0x65, 0x66,
0x72, 0x65, 0x73, 0x68, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x10, 0x20, 0x01,
0x28, 0x03, 0x52, 0x15, 0x63, 0x6f, 0x6e, 0x73, 0x75, 0x6c, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73,
0x68, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x26, 0x0a, 0x0e, 0x61, 0x75, 0x74,
0x68, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x11, 0x20, 0x01, 0x28,
0x09, 0x52, 0x0e, 0x61, 0x75, 0x74, 0x68, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x4e, 0x61, 0x6d,
0x65, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x12, 0x20,
0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x10, 0x0a,
0x03, 0x73, 0x6e, 0x69, 0x18, 0x13, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x73, 0x6e, 0x69, 0x12,
0x36, 0x0a, 0x16, 0x6d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x45, 0x78, 0x70, 0x6f,
0x72, 0x74, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x14, 0x20, 0x03, 0x28, 0x09, 0x52,
0x16, 0x6d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x45, 0x78, 0x70, 0x6f, 0x72, 0x74,
0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x2a, 0x0a, 0x10, 0x6d, 0x63, 0x70, 0x53, 0x65,
0x72, 0x76, 0x65, 0x72, 0x42, 0x61, 0x73, 0x65, 0x55, 0x72, 0x6c, 0x18, 0x15, 0x20, 0x01, 0x28,
0x09, 0x52, 0x10, 0x6d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x42, 0x61, 0x73, 0x65,
0x55, 0x72, 0x6c, 0x12, 0x44, 0x0a, 0x0f, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x4d, 0x43, 0x50,
0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x18, 0x16, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67,
0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x42,
0x6f, 0x6f, 0x6c, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x52, 0x0f, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65,
0x4d, 0x43, 0x50, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x50, 0x0a, 0x15, 0x65, 0x6e, 0x61,
0x62, 0x6c, 0x65, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x4d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65,
0x72, 0x73, 0x18, 0x17, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c,
0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x42, 0x6f, 0x6f, 0x6c, 0x56,
0x61, 0x6c, 0x75, 0x65, 0x52, 0x15, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x53, 0x63, 0x6f, 0x70,
0x65, 0x4d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x28, 0x0a, 0x0f, 0x61,
0x6c, 0x6c, 0x6f, 0x77, 0x4d, 0x63, 0x70, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x18,
0x20, 0x03, 0x28, 0x09, 0x52, 0x0f, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x4d, 0x63, 0x70, 0x53, 0x65,
0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x4f, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74,
0x61, 0x18, 0x19, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x33, 0x2e, 0x68, 0x69, 0x67, 0x72, 0x65, 0x73,
0x73, 0x2e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x69, 0x6e, 0x67, 0x2e, 0x76, 0x31, 0x2e,
0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x72, 0x79, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x4d,
0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, 0x6d, 0x65,
0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x12, 0x1c, 0x0a, 0x09, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x4e,
0x61, 0x6d, 0x65, 0x18, 0x1a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x70, 0x72, 0x6f, 0x78, 0x79,
0x4e, 0x61, 0x6d, 0x65, 0x1a, 0x5c, 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61,
0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01,
0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x35, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65,
0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x68, 0x69, 0x67, 0x72, 0x65, 0x73, 0x73,
0x2e, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x69, 0x6e, 0x67, 0x2e, 0x76, 0x31, 0x2e, 0x49,
0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61, 0x70, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02,
0x38, 0x01, 0x22, 0xdb, 0x01, 0x0a, 0x0b, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x43, 0x6f, 0x6e, 0x66,
0x69, 0x67, 0x12, 0x17, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09,
0x42, 0x03, 0xe0, 0x41, 0x02, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x17, 0x0a, 0x04, 0x6e,
0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x03, 0xe0, 0x41, 0x02, 0x52, 0x04,
0x6e, 0x61, 0x6d, 0x65, 0x12, 0x29, 0x0a, 0x0d, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64,
0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x42, 0x03, 0xe0, 0x41, 0x02,
0x52, 0x0d, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12,
0x23, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20,
0x01, 0x28, 0x0d, 0x42, 0x03, 0xe0, 0x41, 0x02, 0x52, 0x0a, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72,
0x50, 0x6f, 0x72, 0x74, 0x12, 0x22, 0x0a, 0x0c, 0x6c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x65, 0x72,
0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0c, 0x6c, 0x69, 0x73, 0x74,
0x65, 0x6e, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x6f, 0x6e, 0x6e,
0x65, 0x63, 0x74, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0d,
0x52, 0x0e, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74,
0x22, 0x93, 0x01, 0x0a, 0x08, 0x49, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61, 0x70, 0x12, 0x4a, 0x0a,
0x09, 0x69, 0x6e, 0x6e, 0x65, 0x72, 0x5f, 0x6d, 0x61, 0x70, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b,
0x32, 0x2d, 0x2e, 0x68, 0x69, 0x67, 0x72, 0x65, 0x73, 0x73, 0x2e, 0x6e, 0x65, 0x74, 0x77, 0x6f,
0x72, 0x6b, 0x69, 0x6e, 0x67, 0x2e, 0x76, 0x31, 0x2e, 0x49, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61,
0x70, 0x2e, 0x49, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61, 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52,
0x08, 0x69, 0x6e, 0x6e, 0x65, 0x72, 0x4d, 0x61, 0x70, 0x1a, 0x3b, 0x0a, 0x0d, 0x49, 0x6e, 0x6e,
0x65, 0x72, 0x4d, 0x61, 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65,
0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05,
0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c,
0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x2e, 0x5a, 0x2c, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62,
0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x61, 0x6c, 0x69, 0x62, 0x61, 0x62, 0x61, 0x2f, 0x68, 0x69, 0x67,
0x72, 0x65, 0x73, 0x73, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b,
0x69, 0x6e, 0x67, 0x2f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
@@ -512,27 +634,29 @@ func file_networking_v1_mcp_bridge_proto_rawDescGZIP() []byte {
return file_networking_v1_mcp_bridge_proto_rawDescData
}
var file_networking_v1_mcp_bridge_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
var file_networking_v1_mcp_bridge_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
var file_networking_v1_mcp_bridge_proto_goTypes = []interface{}{
(*McpBridge)(nil), // 0: higress.networking.v1.McpBridge
(*RegistryConfig)(nil), // 1: higress.networking.v1.RegistryConfig
(*InnerMap)(nil), // 2: higress.networking.v1.InnerMap
nil, // 3: higress.networking.v1.RegistryConfig.MetadataEntry
nil, // 4: higress.networking.v1.InnerMap.InnerMapEntry
(*wrappers.BoolValue)(nil), // 5: google.protobuf.BoolValue
(*ProxyConfig)(nil), // 2: higress.networking.v1.ProxyConfig
(*InnerMap)(nil), // 3: higress.networking.v1.InnerMap
nil, // 4: higress.networking.v1.RegistryConfig.MetadataEntry
nil, // 5: higress.networking.v1.InnerMap.InnerMapEntry
(*wrappers.BoolValue)(nil), // 6: google.protobuf.BoolValue
}
var file_networking_v1_mcp_bridge_proto_depIdxs = []int32{
1, // 0: higress.networking.v1.McpBridge.registries:type_name -> higress.networking.v1.RegistryConfig
5, // 1: higress.networking.v1.RegistryConfig.enableMCPServer:type_name -> google.protobuf.BoolValue
5, // 2: higress.networking.v1.RegistryConfig.enableScopeMcpServers:type_name -> google.protobuf.BoolValue
3, // 3: higress.networking.v1.RegistryConfig.metadata:type_name -> higress.networking.v1.RegistryConfig.MetadataEntry
4, // 4: higress.networking.v1.InnerMap.inner_map:type_name -> higress.networking.v1.InnerMap.InnerMapEntry
2, // 5: higress.networking.v1.RegistryConfig.MetadataEntry.value:type_name -> higress.networking.v1.InnerMap
6, // [6:6] is the sub-list for method output_type
6, // [6:6] is the sub-list for method input_type
6, // [6:6] is the sub-list for extension type_name
6, // [6:6] is the sub-list for extension extendee
0, // [0:6] is the sub-list for field type_name
2, // 1: higress.networking.v1.McpBridge.proxies:type_name -> higress.networking.v1.ProxyConfig
6, // 2: higress.networking.v1.RegistryConfig.enableMCPServer:type_name -> google.protobuf.BoolValue
6, // 3: higress.networking.v1.RegistryConfig.enableScopeMcpServers:type_name -> google.protobuf.BoolValue
4, // 4: higress.networking.v1.RegistryConfig.metadata:type_name -> higress.networking.v1.RegistryConfig.MetadataEntry
5, // 5: higress.networking.v1.InnerMap.inner_map:type_name -> higress.networking.v1.InnerMap.InnerMapEntry
3, // 6: higress.networking.v1.RegistryConfig.MetadataEntry.value:type_name -> higress.networking.v1.InnerMap
7, // [7:7] is the sub-list for method output_type
7, // [7:7] is the sub-list for method input_type
7, // [7:7] is the sub-list for extension type_name
7, // [7:7] is the sub-list for extension extendee
0, // [0:7] is the sub-list for field type_name
}
func init() { file_networking_v1_mcp_bridge_proto_init() }
@@ -566,6 +690,18 @@ func file_networking_v1_mcp_bridge_proto_init() {
}
}
file_networking_v1_mcp_bridge_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ProxyConfig); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_networking_v1_mcp_bridge_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*InnerMap); i {
case 0:
return &v.state
@@ -584,7 +720,7 @@ func file_networking_v1_mcp_bridge_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_networking_v1_mcp_bridge_proto_rawDesc,
NumEnums: 0,
NumMessages: 5,
NumMessages: 6,
NumExtensions: 0,
NumServices: 0,
},


@@ -46,6 +46,7 @@ option go_package = "github.com/alibaba/higress/api/networking/v1";
// -->
message McpBridge {
repeated RegistryConfig registries = 1;
repeated ProxyConfig proxies = 2;
}
message RegistryConfig {
@@ -74,6 +75,16 @@ message RegistryConfig {
google.protobuf.BoolValue enableScopeMcpServers = 23;
repeated string allowMcpServers = 24;
map<string, InnerMap> metadata = 25;
string proxyName = 26;
}
message ProxyConfig {
string type = 1 [(google.api.field_behavior) = REQUIRED];
string name = 2 [(google.api.field_behavior) = REQUIRED];
string serverAddress = 3 [(google.api.field_behavior) = REQUIRED];
uint32 serverPort = 4 [(google.api.field_behavior) = REQUIRED];
uint32 listenerPort = 5;
uint32 connectTimeout = 6;
}
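For context, a hypothetical McpBridge resource wiring a registry to a proxy through the new `proxies` and `proxyName` fields might look like the following. Field names come from the proto above; the apiVersion, metadata, and all values are illustrative assumptions, not taken from this PR:

```yaml
# Illustrative only: connects the "remote-nacos" registry through "corp-proxy".
apiVersion: networking.higress.io/v1
kind: McpBridge
metadata:
  name: default
spec:
  proxies:
    - type: https          # proxy type; matching is case-insensitive in ParseProxyType
      name: corp-proxy
      serverAddress: 10.0.0.8
      serverPort: 3128
      connectTimeout: 10
  registries:
    - type: nacos
      name: remote-nacos
      domain: nacos.example.com
      port: 8848
      proxyName: corp-proxy  # references proxies[].name above
```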
message InnerMap {


@@ -47,6 +47,27 @@ func (in *RegistryConfig) DeepCopyInterface() interface{} {
return in.DeepCopy()
}
// DeepCopyInto supports using ProxyConfig within kubernetes types, where deepcopy-gen is used.
func (in *ProxyConfig) DeepCopyInto(out *ProxyConfig) {
p := proto.Clone(in).(*ProxyConfig)
*out = *p
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ProxyConfig. Required by controller-gen.
func (in *ProxyConfig) DeepCopy() *ProxyConfig {
if in == nil {
return nil
}
out := new(ProxyConfig)
in.DeepCopyInto(out)
return out
}
// DeepCopyInterface is an autogenerated deepcopy function, copying the receiver, creating a new ProxyConfig. Required by controller-gen.
func (in *ProxyConfig) DeepCopyInterface() interface{} {
return in.DeepCopy()
}
// DeepCopyInto supports using InnerMap within kubernetes types, where deepcopy-gen is used.
func (in *InnerMap) DeepCopyInto(out *InnerMap) {
p := proto.Clone(in).(*InnerMap)


@@ -28,6 +28,17 @@ func (this *RegistryConfig) UnmarshalJSON(b []byte) error {
return McpBridgeUnmarshaler.Unmarshal(bytes.NewReader(b), this)
}
// MarshalJSON is a custom marshaler for ProxyConfig
func (this *ProxyConfig) MarshalJSON() ([]byte, error) {
str, err := McpBridgeMarshaler.MarshalToString(this)
return []byte(str), err
}
// UnmarshalJSON is a custom unmarshaler for ProxyConfig
func (this *ProxyConfig) UnmarshalJSON(b []byte) error {
return McpBridgeUnmarshaler.Unmarshal(bytes.NewReader(b), this)
}
// MarshalJSON is a custom marshaler for InnerMap
func (this *InnerMap) MarshalJSON() ([]byte, error) {
str, err := McpBridgeMarshaler.MarshalToString(this)


@@ -95,6 +95,6 @@ generate-k8s-client:
.PHONY: clean-k8s-client
clean-k8s-cliennt:
clean-k8s-client:
# remove generated code
@rm -rf pkg/


@@ -6,11 +6,11 @@ ARG BASE_VERSION=latest
ARG HUB
ARG TARGETARCH
# The following section is used as base image if BASE_DISTRIBUTION=debug
# This base image is provided by istio, see: https://github.com/istio/istio/blob/master/docker/Dockerfile.base
FROM ${HUB}/base:${BASE_VERSION}
ARG TARGETARCH
FROM ${HUB}/base:${BASE_VERSION}-${TARGETARCH}
COPY ${TARGETARCH:-amd64}/higress /usr/local/bin/higress


@@ -17,6 +17,11 @@ docker.higress: $(OUT_LINUX)/higress
docker.higress: docker/Dockerfile.higress
$(HIGRESS_DOCKER_RULE)
docker.higress-amd64: BUILD_ARGS=--build-arg BASE_VERSION=${HIGRESS_BASE_VERSION} --build-arg HUB=${HUB}
docker.higress-amd64: $(AMD64_OUT_LINUX)/higress
docker.higress-amd64: docker/Dockerfile.higress
$(HIGRESS_DOCKER_AMD64_RULE)
docker.higress-buildx: BUILD_ARGS=--build-arg BASE_VERSION=${HIGRESS_BASE_VERSION} --build-arg HUB=${HUB}
docker.higress-buildx: $(AMD64_OUT_LINUX)/higress
docker.higress-buildx: $(ARM64_OUT_LINUX)/higress
@@ -40,3 +45,4 @@ IMG_URL ?= $(HUB)/$(IMG):$(TAG)
HIGRESS_DOCKER_BUILDX_RULE ?= $(foreach VARIANT,$(DOCKER_BUILD_VARIANTS), time (mkdir -p $(HIGRESS_DOCKER_BUILD_TOP)/$@ && TARGET_ARCH=$(TARGET_ARCH) ./docker/docker-copy.sh $^ $(HIGRESS_DOCKER_BUILD_TOP)/$@ && cd $(HIGRESS_DOCKER_BUILD_TOP)/$@ $(BUILD_PRE) && docker buildx create --name higress --node higress0 --platform linux/amd64,linux/arm64 --use && docker buildx build --no-cache --platform linux/amd64,linux/arm64 $(BUILD_ARGS) --build-arg BASE_DISTRIBUTION=$(call normalize-tag,$(VARIANT)) -t $(IMG_URL)$(call variant-tag,$(VARIANT)) -f Dockerfile.higress . --push ); )
HIGRESS_DOCKER_RULE ?= $(foreach VARIANT,$(DOCKER_BUILD_VARIANTS), time (mkdir -p $(HIGRESS_DOCKER_BUILD_TOP)/$@ && TARGET_ARCH=$(TARGET_ARCH) ./docker/docker-copy.sh $^ $(HIGRESS_DOCKER_BUILD_TOP)/$@ && cd $(HIGRESS_DOCKER_BUILD_TOP)/$@ $(BUILD_PRE) && docker build $(BUILD_ARGS) --build-arg BASE_DISTRIBUTION=$(call normalize-tag,$(VARIANT)) -t $(IMG_URL)$(call variant-tag,$(VARIANT)) -f Dockerfile.higress . ); )
HIGRESS_DOCKER_AMD64_RULE ?= $(foreach VARIANT,$(DOCKER_BUILD_VARIANTS), time (mkdir -p $(HIGRESS_DOCKER_BUILD_TOP)/$@ && TARGET_ARCH=amd64 ./docker/docker-copy.sh $^ $(HIGRESS_DOCKER_BUILD_TOP)/$@ && cd $(HIGRESS_DOCKER_BUILD_TOP)/$@ $(BUILD_PRE) && docker build $(BUILD_ARGS) --build-arg BASE_DISTRIBUTION=$(call normalize-tag,$(VARIANT)) --build-arg TARGETARCH=amd64 -t $(IMG_URL)$(call variant-tag,$(VARIANT)) -f Dockerfile.higress . ); )


@@ -1,5 +1,5 @@
apiVersion: v2
appVersion: 2.1.6
appVersion: 2.1.7
description: Helm chart for deploying higress gateways
icon: https://higress.io/img/higress_logo_small.png
home: http://higress.io/
@@ -15,4 +15,4 @@ dependencies:
repository: "file://../redis"
version: 0.0.1
type: application
version: 2.1.6
version: 2.1.7


@@ -247,6 +247,23 @@ spec:
properties:
spec:
properties:
proxies:
items:
properties:
connectTimeout:
type: integer
listenerPort:
type: integer
name:
type: string
serverAddress:
type: string
serverPort:
type: integer
type:
type: string
type: object
type: array
registries:
items:
properties:
@@ -309,6 +326,8 @@ spec:
type: integer
protocol:
type: string
proxyName:
type: string
sni:
type: string
type:


@@ -1,9 +1,9 @@
dependencies:
- name: higress-core
repository: file://../core
version: 2.1.6
version: 2.1.7
- name: higress-console
repository: https://higress.io/helm-charts/
version: 2.1.6
digest: sha256:c5bebb3bd92bf799804443faf9ab69e88ed26815a709e58911859b504b3d04db
generated: "2025-07-30T21:13:57.834398+08:00"
version: 2.1.7
digest: sha256:c5bc8ddcc56c66751217aee5c7a40da0a906bfa9fc5c671cc4ae6e456db6bc21
generated: "2025-09-01T15:19:26.228634+08:00"


@@ -1,5 +1,5 @@
apiVersion: v2
appVersion: 2.1.6
appVersion: 2.1.7
description: Helm chart for deploying Higress gateways
icon: https://higress.io/img/higress_logo_small.png
home: http://higress.io/
@@ -12,9 +12,9 @@ sources:
dependencies:
- name: higress-core
repository: "file://../core"
version: 2.1.6
version: 2.1.7
- name: higress-console
repository: "https://higress.io/helm-charts/"
version: 2.1.6
version: 2.1.7
type: application
version: 2.1.6
version: 2.1.7


@@ -704,9 +704,9 @@ func TestK8sObject_ResolveK8sConflict(t *testing.T) {
t.Run(tt.desc, func(t *testing.T) {
newObj := tt.o1.ResolveK8sConflict()
if !newObj.Equal(tt.o2) {
newObjjson, _ := newObj.JSON()
wantedObjjson, _ := tt.o2.JSON()
t.Errorf("Got: %s, want: %s", string(newObjjson), string(wantedObjjson))
newObjJson, _ := newObj.JSON()
wantedObjJson, _ := tt.o2.JSON()
t.Errorf("Got: %s, want: %s", string(newObjJson), string(wantedObjJson))
}
})
}


@@ -65,7 +65,7 @@ func (o *K8sInstaller) Install() error {
return err1
}
fmt.Fprintf(o.writer, "\n✔ Wrote Profile in kubernetes configmap: \"%s\" \n", profileName)
fmt.Fprintf(o.writer, "\n Use bellow kubectl command to edit profile for upgrade. \n")
fmt.Fprintf(o.writer, "\n Use below kubectl command to edit profile for upgrade. \n")
fmt.Fprintf(o.writer, " ================================================================================== \n")
names := strings.Split(profileName, "/")
fmt.Fprintf(o.writer, " kubectl edit configmap %s -n %s \n", names[1], names[0])


@@ -93,6 +93,15 @@ func (p Protocol) IsUnsupported() bool {
return p == Unsupported
}
func (p Protocol) IsSupportedByProxy() bool {
switch p {
case HTTPS:
return true
default:
return false
}
}
func (p Protocol) String() string {
return string(p)
}

pkg/common/proxy.go (new file)

@@ -0,0 +1,59 @@
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package common
import (
"strings"
)
type ProxyType string
const (
ProxyType_Unknown ProxyType = "Unknown"
ProxyType_HTTP ProxyType = "HTTP"
ProxyType_HTTPS ProxyType = "HTTPS"
ProxyType_SOCKS4 ProxyType = "SOCKS4"
ProxyType_SOCKS5 ProxyType = "SOCKS5"
)
func ParseProxyType(s string) ProxyType {
switch strings.ToLower(s) {
case "http":
return ProxyType_HTTP
case "https":
return ProxyType_HTTPS
case "socks4":
return ProxyType_SOCKS4
case "socks5":
return ProxyType_SOCKS5
}
return ProxyType_Unknown
}
func (p ProxyType) GetTransportProtocol() Protocol {
switch p {
case ProxyType_HTTP:
return HTTP
case ProxyType_HTTPS:
return HTTPS
case ProxyType_SOCKS4, ProxyType_SOCKS5:
return TCP
}
return Unsupported
}
func (p ProxyType) String() string {
return string(p)
}


@@ -69,7 +69,7 @@ import (
"github.com/alibaba/higress/pkg/ingress/kube/wasmplugin"
. "github.com/alibaba/higress/pkg/ingress/log"
"github.com/alibaba/higress/pkg/kube"
"github.com/alibaba/higress/registry/memory"
"github.com/alibaba/higress/registry"
"github.com/alibaba/higress/registry/reconcile"
)
@@ -340,10 +340,6 @@ func (m *IngressConfig) listFromIngressControllers(typ config.GroupVersionKind,
}
IngressLog.Infof("Append %d configmap EnvoyFilters", len(configmapEnvoyFilters))
}
if len(envoyFilters) == 0 {
IngressLog.Infof("resource type %s, configs number %d", typ, len(m.cachedEnvoyFilters))
return m.cachedEnvoyFilters
}
envoyFilters = append(envoyFilters, m.cachedEnvoyFilters...)
IngressLog.Infof("resource type %s, configs number %d", typ, len(envoyFilters))
return envoyFilters
@@ -490,6 +486,22 @@ func (m *IngressConfig) convertVirtualService(configs []common.WrapperConfig) []
VirtualServices: map[string]*common.WrapperVirtualService{},
HTTPRoutes: map[string][]*common.WrapperHTTPRoute{},
Route2Ingress: map[string]*common.WrapperConfigWithRuleKey{},
ServiceWrappers: make(map[string]*common.ServiceWrapper),
ProxyWrappers: make(map[string]*common.ProxyWrapper),
}
if m.RegistryReconciler != nil {
for _, sew := range m.RegistryReconciler.GetAllServiceWrapper() {
hosts := sew.ServiceEntry.Hosts
if len(hosts) == 0 {
continue
}
for _, host := range hosts {
convertOptions.ServiceWrappers[host] = sew
}
}
for _, pw := range m.RegistryReconciler.GetAllProxyWrapper() {
convertOptions.ProxyWrappers[pw.ProxyName] = pw
}
}
// convert http route
@@ -616,6 +628,7 @@ func (m *IngressConfig) convertEnvoyFilter(convertOptions *common.ConvertOptions
mappings := map[string]*common.Rule{}
initHttp2RpcGlobalConfig := true
initMcpSseGlobalFilter := true
for _, routes := range convertOptions.HTTPRoutes {
for _, route := range routes {
if strings.HasSuffix(route.HTTPRoute.Name, "app-root") {
@@ -635,6 +648,19 @@ func (m *IngressConfig) convertEnvoyFilter(convertOptions *common.ConvertOptions
}
}
loadBalance := route.WrapperConfig.AnnotationsConfig.LoadBalance
if loadBalance != nil && loadBalance.McpSseStateful {
IngressLog.Infof("Found MCP SSE stateful session for route %s", route.HTTPRoute.Name)
envoyFilter, err := m.constructMcpSseStatefulSessionEnvoyFilter(route, m.namespace, initMcpSseGlobalFilter)
if err != nil {
IngressLog.Errorf("Construct MCP SSE stateful session EnvoyFilter error %v", err)
} else {
IngressLog.Infof("Append MCP SSE stateful session EnvoyFilter for route %s", route.HTTPRoute.Name)
envoyFilters = append(envoyFilters, *envoyFilter)
initMcpSseGlobalFilter = false
}
}
auth := route.WrapperConfig.AnnotationsConfig.Auth
if auth == nil {
continue
@@ -669,6 +695,12 @@ func (m *IngressConfig) convertEnvoyFilter(convertOptions *common.ConvertOptions
}
}
if proxyEnvoyFilters := constructProxyEnvoyFilters(convertOptions.ProxyWrappers, convertOptions.ServiceWrappers, m.namespace); len(proxyEnvoyFilters) != 0 {
for _, ef := range proxyEnvoyFilters {
envoyFilters = append(envoyFilters, *ef)
}
}
// TODO Support other envoy filters
IngressLog.Infof("Found %d number of envoyFilters", len(envoyFilters))
@@ -1113,7 +1145,7 @@ func (m *IngressConfig) AddOrUpdateWasmPlugin(clusterNamespacedName util.Cluster
Labels: map[string]string{constants.AlwaysPushLabel: "true"},
}
for _, f := range m.wasmPluginHandlers {
IngressLog.Debug("WasmPlugin triggerd update")
IngressLog.Debug("WasmPlugin triggered update")
f(config.Config{Meta: metadata}, config.Config{Meta: metadata}, istiomodel.EventUpdate)
}
istioWasmPlugin, err := m.convertIstioWasmPlugin(&wasmPlugin.Spec)
@@ -1155,7 +1187,7 @@ func (m *IngressConfig) DeleteWasmPlugin(clusterNamespacedName util.ClusterNames
Labels: map[string]string{constants.AlwaysPushLabel: "true"},
}
for _, f := range m.wasmPluginHandlers {
IngressLog.Debug("WasmPlugin triggerd update")
IngressLog.Debug("WasmPlugin triggered update")
f(config.Config{Meta: metadata}, config.Config{Meta: metadata}, istiomodel.EventDelete)
}
}
@@ -1211,23 +1243,23 @@ func (m *IngressConfig) AddOrUpdateMcpBridge(clusterNamespacedName util.ClusterN
}
for _, f := range m.serviceEntryHandlers {
IngressLog.Debug("McpBridge triggerd serviceEntry update")
IngressLog.Debug("McpBridge triggered serviceEntry update")
f(config.Config{Meta: seMetadata}, config.Config{Meta: seMetadata}, istiomodel.EventUpdate)
}
for _, f := range m.destinationRuleHandlers {
IngressLog.Debug("McpBridge triggerd destinationRule update")
IngressLog.Debug("McpBridge triggered destinationRule update")
f(config.Config{Meta: drMetadata}, config.Config{Meta: drMetadata}, istiomodel.EventUpdate)
}
for _, f := range m.virtualServiceHandlers {
IngressLog.Debug("McpBridge triggerd virtualservice update")
IngressLog.Debug("McpBridge triggered virtualservice update")
f(config.Config{Meta: vsMetadata}, config.Config{Meta: vsMetadata}, istiomodel.EventUpdate)
}
for _, f := range m.wasmPluginHandlers {
IngressLog.Debug("McpBridge triggerd wasmplugin update")
IngressLog.Debug("McpBridge triggered wasmplugin update")
f(config.Config{Meta: wasmMetadata}, config.Config{Meta: wasmMetadata}, istiomodel.EventUpdate)
}
for _, f := range m.envoyFilterHandlers {
IngressLog.Debug("McpBridge triggerd envoyfilter update")
IngressLog.Debug("McpBridge triggered envoyfilter update")
f(config.Config{Meta: efMetadata}, config.Config{Meta: efMetadata}, istiomodel.EventUpdate)
}
}, m.localKubeClient, m.namespace, m.clusterId.String())
@@ -1295,7 +1327,7 @@ func (m *IngressConfig) DeleteHttp2Rpc(clusterNamespacedName util.ClusterNamespa
}
m.mutex.Unlock()
if hit {
IngressLog.Infof("Http2Rpc triggerd deleted event executed %s", clusterNamespacedName.Name)
IngressLog.Infof("Http2Rpc triggered deleted event executed %s", clusterNamespacedName.Name)
push := func(gvk config.GroupVersionKind) {
m.XDSUpdater.ConfigUpdate(&istiomodel.PushRequest{
Full: true,
@@ -1493,7 +1525,7 @@ func (m *IngressConfig) constructHttp2RpcEnvoyFilter(http2rpcConfig *annotations
return &config.Config{
Meta: config.Meta{
GroupVersionKind: gvk.EnvoyFilter,
Name: common.CreateConvertedName(constants.IstioIngressGatewayName, http2rpcConfig.Name),
Name: common.CreateConvertedName(constants.IstioIngressGatewayName, "http2rpc", http2rpcConfig.Name, "route", common.ConvertToDNSLabelValid(httpRoute.Name)),
Namespace: namespace,
},
Spec: &networking.EnvoyFilter{
@@ -1675,28 +1707,150 @@ func constructBasicAuthEnvoyFilter(rules *common.BasicAuthRules, namespace strin
}, nil
}
func QueryByName(serviceEntries []*memory.ServiceWrapper, serviceName string) (*memory.ServiceWrapper, error) {
IngressLog.Infof("Found http2rpc serviceEntries %s", serviceEntries)
for _, se := range serviceEntries {
if se.ServiceName == serviceName {
return se, nil
func constructProxyEnvoyFilters(proxyWrappers map[string]*common.ProxyWrapper, serviceWrappers map[string]*common.ServiceWrapper, namespace string) []*config.Config {
var envoyFilters []*config.Config
for _, proxyWrapper := range proxyWrappers {
envoyFilters = append(envoyFilters, &config.Config{
Meta: config.Meta{
GroupVersionKind: gvk.EnvoyFilter,
Name: common.CreateConvertedName(constants.IstioIngressGatewayName, "proxy", proxyWrapper.ProxyName),
Namespace: namespace,
},
Spec: proxyWrapper.EnvoyFilter,
})
}
// Create a cluster for each service that uses a proxy.
var serviceProxyPatches []*networking.EnvoyFilter_EnvoyConfigObjectPatch
for _, serviceWrapper := range serviceWrappers {
proxyConfig := serviceWrapper.ProxyConfig
if proxyConfig == nil || proxyConfig.ProxyName == "" {
continue
}
IngressLog.Debugf("Found service %s using proxy %s", serviceWrapper.ServiceName, proxyConfig.ProxyName)
if err := validateServiceWrapperForProxy(serviceWrapper); err != nil {
IngressLog.Warnf("Service wrapper validation failed for proxy: %v", err)
continue
}
proxyWrapper := proxyWrappers[proxyConfig.ProxyName]
if proxyWrapper == nil {
IngressLog.Warnf("Service %s has proxy config %s, but no corresponding proxy wrapper found", serviceWrapper.ServiceName, proxyConfig.ProxyName)
continue
}
if !proxyConfig.UpstreamProtocol.IsSupportedByProxy() {
IngressLog.Warnf("Proxy %s does not support upstream protocol %s, skipping EnvoyFilter construction for service %s", proxyConfig.ProxyName, proxyConfig.UpstreamProtocol, serviceWrapper.ServiceName)
continue
}
if proxyWrapper.EnvoyFilter == nil {
IngressLog.Warnf("Proxy %s has no EnvoyFilter generated and is not ready for use", proxyConfig.ProxyName)
continue
}
se := serviceWrapper.ServiceEntry
if se == nil || len(se.Hosts) == 0 || len(se.Ports) == 0 {
continue
}
for _, host := range se.Hosts {
IngressLog.Debugf("Constructing EnvoyFilter for service %s using proxy %s", host, proxyConfig.ProxyName)
for _, port := range se.Ports {
if port == nil || port.Number <= 0 {
continue
}
clusterName := fmt.Sprintf("outbound|%d||%s", port.Number, host)
// We need to delete the original cluster and add a new one pointing to the local proxy listener.
serviceProxyPatches = append(serviceProxyPatches, &networking.EnvoyFilter_EnvoyConfigObjectPatch{
ApplyTo: networking.EnvoyFilter_CLUSTER,
Match: &networking.EnvoyFilter_EnvoyConfigObjectMatch{
Context: networking.EnvoyFilter_GATEWAY,
ObjectTypes: &networking.EnvoyFilter_EnvoyConfigObjectMatch_Cluster{
Cluster: &networking.EnvoyFilter_ClusterMatch{
Name: clusterName,
},
},
},
Patch: &networking.EnvoyFilter_Patch{
Operation: networking.EnvoyFilter_Patch_REMOVE,
},
})
patchObj := map[string]interface{}{
"name": clusterName,
"type": "STATIC",
"connect_timeout": "10s",
"load_assignment": map[string]interface{}{
"cluster_name": clusterName,
"endpoints": []map[string]interface{}{
{
"lb_endpoints": []map[string]interface{}{
{
"endpoint": map[string]interface{}{
"address": map[string]interface{}{
"socket_address": map[string]interface{}{
"address": "127.0.0.1",
"port_value": proxyWrapper.ListenerPort,
},
},
},
},
},
},
},
},
}
if proxyConfig.UpstreamProtocol.IsHTTPS() {
tlsTypedConfig := map[string]interface{}{
"@type": "type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext",
}
if proxyConfig.UpstreamSni != "" {
tlsTypedConfig["sni"] = proxyConfig.UpstreamSni
}
patchObj["transport_socket"] = map[string]interface{}{
"name": "envoy.transport_sockets.tls",
"typed_config": tlsTypedConfig,
}
}
patchJson, _ := json.Marshal(patchObj)
serviceProxyPatches = append(serviceProxyPatches, &networking.EnvoyFilter_EnvoyConfigObjectPatch{
ApplyTo: networking.EnvoyFilter_CLUSTER,
Match: &networking.EnvoyFilter_EnvoyConfigObjectMatch{
Context: networking.EnvoyFilter_GATEWAY,
},
Patch: &networking.EnvoyFilter_Patch{
Operation: networking.EnvoyFilter_Patch_ADD,
Value: util.BuildPatchStruct(string(patchJson)),
},
})
}
}
}
return nil, fmt.Errorf("can't find ServiceEntry by serviceName:%v", serviceName)
if len(serviceProxyPatches) != 0 {
envoyFilters = append(envoyFilters, &config.Config{
Meta: config.Meta{
GroupVersionKind: gvk.EnvoyFilter,
Name: common.CreateConvertedName(constants.IstioIngressGatewayName, "service-proxy"),
Namespace: namespace,
},
Spec: &networking.EnvoyFilter{
ConfigPatches: serviceProxyPatches,
},
})
}
return envoyFilters
}
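For reference, the cluster patch object built above, rendered as Envoy YAML. The host, port, and listener port below are illustrative placeholders (a hypothetical HTTPS service `example.com:443` whose proxy listens locally on port 10000), not values from this diff; the `transport_socket` section only appears when the upstream protocol is HTTPS:

```yaml
name: outbound|443||example.com
type: STATIC
connect_timeout: 10s
load_assignment:
  cluster_name: outbound|443||example.com
  endpoints:
    - lb_endpoints:
        - endpoint:
            address:
              socket_address:
                address: 127.0.0.1
                port_value: 10000
transport_socket:
  name: envoy.transport_sockets.tls
  typed_config:
    "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
    sni: example.com
```

The REMOVE patch deletes the original outbound cluster first, so this ADD patch fully replaces it with a static endpoint pointing at the local proxy listener.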
func QueryRpcServiceVersion(serviceEntry *memory.ServiceWrapper, serviceName string) (string, error) {
IngressLog.Infof("Found http2rpc serviceEntry %s", serviceEntry)
IngressLog.Infof("Found http2rpc ServiceEntry %s", serviceEntry.ServiceEntry)
IngressLog.Infof("Found http2rpc WorkloadSelector %s", serviceEntry.ServiceEntry.WorkloadSelector)
IngressLog.Infof("Found http2rpc Labels %s", serviceEntry.ServiceEntry.WorkloadSelector.Labels)
labels := (*serviceEntry).ServiceEntry.WorkloadSelector.Labels
for key, value := range labels {
if key == "version" {
return value, nil
}
func validateServiceWrapperForProxy(serviceWrapper *common.ServiceWrapper) error {
registryType := registry.ServiceRegistryType(serviceWrapper.RegistryType)
switch registryType {
case registry.DNS:
break
default:
return fmt.Errorf("service %s has proxy config %s, but registry type %s is not supported for proxying", serviceWrapper.ServiceName, serviceWrapper.ProxyConfig.ProxyName, registryType)
}
return "", fmt.Errorf("can't get RpcServiceVersion for serviceName:%v", serviceName)
if len(serviceWrapper.ServiceEntry.Endpoints) > 1 {
return fmt.Errorf("service %s has multiple endpoints, which is not supported for proxying with EnvoyFilter", serviceWrapper.ServiceName)
}
return nil
}
func (m *IngressConfig) Run(stop <-chan struct{}) {
@@ -1800,6 +1954,99 @@ func (m *IngressConfig) Delete(config.GroupVersionKind, string, string, *string)
return common.ErrUnsupportedOp
}
func (m *IngressConfig) constructMcpSseStatefulSessionEnvoyFilter(route *common.WrapperHTTPRoute, namespace string, initGlobalFilter bool) (*config.Config, error) {
httpRoute := route.HTTPRoute
var configPatches []*networking.EnvoyFilter_EnvoyConfigObjectPatch
// Add global HTTP filter if this is the first route using MCP SSE stateful session
if initGlobalFilter {
configPatches = append(configPatches, &networking.EnvoyFilter_EnvoyConfigObjectPatch{
ApplyTo: networking.EnvoyFilter_HTTP_FILTER,
Match: &networking.EnvoyFilter_EnvoyConfigObjectMatch{
Context: networking.EnvoyFilter_GATEWAY,
ObjectTypes: &networking.EnvoyFilter_EnvoyConfigObjectMatch_Listener{
Listener: &networking.EnvoyFilter_ListenerMatch{
FilterChain: &networking.EnvoyFilter_ListenerMatch_FilterChainMatch{
Filter: &networking.EnvoyFilter_ListenerMatch_FilterMatch{
Name: "envoy.filters.network.http_connection_manager",
SubFilter: &networking.EnvoyFilter_ListenerMatch_SubFilterMatch{
Name: "envoy.filters.http.router",
},
},
},
},
},
},
Patch: &networking.EnvoyFilter_Patch{
Operation: networking.EnvoyFilter_Patch_INSERT_BEFORE,
Value: buildPatchStruct(`{
"name": "envoy.filters.http.mcp_sse_stateful_session",
"typed_config": {
"@type": "type.googleapis.com/udpa.type.v1.TypedStruct",
"type_url": "type.googleapis.com/envoy.extensions.filters.http.mcp_sse_stateful_session.v3alpha.McpSseStatefulSession"
}
}`),
},
})
}
// Add route-specific configuration
configPatches = append(configPatches, &networking.EnvoyFilter_EnvoyConfigObjectPatch{
ApplyTo: networking.EnvoyFilter_HTTP_ROUTE,
Match: &networking.EnvoyFilter_EnvoyConfigObjectMatch{
Context: networking.EnvoyFilter_GATEWAY,
ObjectTypes: &networking.EnvoyFilter_EnvoyConfigObjectMatch_RouteConfiguration{
RouteConfiguration: &networking.EnvoyFilter_RouteConfigurationMatch{
Vhost: &networking.EnvoyFilter_RouteConfigurationMatch_VirtualHostMatch{
Route: &networking.EnvoyFilter_RouteConfigurationMatch_RouteMatch{
Name: httpRoute.Name,
},
},
},
},
},
Patch: &networking.EnvoyFilter_Patch{
Operation: networking.EnvoyFilter_Patch_MERGE,
Value: buildPatchStruct(`{
"typed_per_filter_config": {
"envoy.filters.http.mcp_sse_stateful_session": {
"@type": "type.googleapis.com/udpa.type.v1.TypedStruct",
"type_url": "type.googleapis.com/envoy.extensions.filters.http.mcp_sse_stateful_session.v3alpha.McpSseStatefulSessionPerRoute",
"value": {
"mcp_sse_stateful_session": {
"session_state": {
"name": "envoy.http.mcp_sse_stateful_session.envelope",
"typed_config": {
"@type": "type.googleapis.com/udpa.type.v1.TypedStruct",
"type_url": "type.googleapis.com/envoy.extensions.http.mcp_sse_stateful_session.envelope.v3alpha.EnvelopeSessionState",
"value": {
"param_name": "sessionId",
"chunk_end_patterns": ["\r\n\r\n", "\n\n", "\r\r"]
}
}
},
"strict": true
}
}
}
}
}`),
},
})
return &config.Config{
Meta: config.Meta{
GroupVersionKind: gvk.EnvoyFilter,
Name: common.CreateConvertedName(constants.IstioIngressGatewayName, "mcp-lb-route", common.ConvertToDNSLabelValid(httpRoute.Name)),
Namespace: namespace,
},
Spec: &networking.EnvoyFilter{
ConfigPatches: configPatches,
},
}, nil
}
func (m *IngressConfig) notifyXDSFullUpdate(gvk config.GroupVersionKind, reason istiomodel.TriggerReason, updatedConfigName *util.ClusterNamespacedName) {
var configsUpdated map[istiomodel.ConfigKey]struct{}
if updatedConfigName != nil {


@@ -66,9 +66,10 @@ type consistentHashByCookie struct {
}
type LoadBalanceConfig struct {
simple networking.LoadBalancerSettings_SimpleLB
other *consistentHashByOther
cookie *consistentHashByCookie
simple networking.LoadBalancerSettings_SimpleLB
other *consistentHashByOther
cookie *consistentHashByCookie
McpSseStateful bool
}
type loadBalance struct{}
@@ -129,7 +130,11 @@ func (l loadBalance) Parse(annotations Annotations, config *Ingress, _ *GlobalCo
} else {
if lb, err := annotations.ParseStringASAP(loadBalanceAnnotation); err == nil {
lb = strings.ToUpper(lb)
loadBalanceConfig.simple = networking.LoadBalancerSettings_SimpleLB(networking.LoadBalancerSettings_SimpleLB_value[lb])
if lb == "MCP-SSE" {
loadBalanceConfig.McpSseStateful = true
} else {
loadBalanceConfig.simple = networking.LoadBalancerSettings_SimpleLB(networking.LoadBalancerSettings_SimpleLB_value[lb])
}
}
}
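A hedged usage sketch for the new branch above: assuming Higress's standard `higress.io/load-balance` annotation key (the key itself is not shown in this diff), an Ingress could opt into the MCP SSE stateful session load balancer like this:

```yaml
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: mcp-server-route
  annotations:
    # The value is upper-cased before comparison, so "mcp-sse" matches
    # the "MCP-SSE" branch and sets McpSseStateful instead of a SimpleLB.
    higress.io/load-balance: "mcp-sse"
spec:
  rules:
    - host: mcp.example.com
      http:
        paths:
          - path: /sse
            pathType: Prefix
            backend:
              service:
                name: mcp-server
                port:
                  number: 8080
```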


@@ -16,9 +16,8 @@ package common
import (
"strings"
"time"
"github.com/alibaba/higress/pkg/cert"
"github.com/alibaba/higress/pkg/ingress/kube/annotations"
networking "istio.io/api/networking/v1alpha3"
"istio.io/istio/pilot/pkg/model"
"istio.io/istio/pkg/cluster"
@@ -26,6 +25,10 @@ import (
gatewaytool "istio.io/istio/pkg/config/gateway"
listerv1 "k8s.io/client-go/listers/core/v1"
"k8s.io/client-go/tools/cache"
"github.com/alibaba/higress/pkg/cert"
"github.com/alibaba/higress/pkg/common"
"github.com/alibaba/higress/pkg/ingress/kube/annotations"
)
type ServiceKey struct {
@@ -120,6 +123,68 @@ type WrapperDestinationRule struct {
ServiceKey ServiceKey
}
type ServiceProxyConfig struct {
ProxyName string
UpstreamProtocol common.Protocol
UpstreamSni string
}
type ServiceWrapper struct {
ServiceName string
ServiceEntry *networking.ServiceEntry
DestinationRuleWrapper *WrapperDestinationRule
Suffix string
RegistryType string
RegistryName string
ProxyConfig *ServiceProxyConfig
createTime time.Time
}
func (sew *ServiceWrapper) DeepCopy() *ServiceWrapper {
res := &ServiceWrapper{}
*res = *sew
res.ServiceEntry = sew.ServiceEntry.DeepCopy()
if sew.DestinationRuleWrapper != nil {
res.DestinationRuleWrapper = sew.DestinationRuleWrapper
res.DestinationRuleWrapper.DestinationRule = sew.DestinationRuleWrapper.DestinationRule.DeepCopy()
}
return res
}
func (sew *ServiceWrapper) SetCreateTime(createTime time.Time) {
sew.createTime = createTime
}
func (sew *ServiceWrapper) GetCreateTime() time.Time {
return sew.createTime
}
type ProxyWrapper struct {
ProxyName string
ListenerPort uint32
EnvoyFilter *networking.EnvoyFilter
createTime time.Time
}
func (pw *ProxyWrapper) DeepCopy() *ProxyWrapper {
res := &ProxyWrapper{}
*res = *pw
if pw.EnvoyFilter != nil {
res.EnvoyFilter = pw.EnvoyFilter.DeepCopy()
}
return res
}
func (pw *ProxyWrapper) SetCreateTime(createTime time.Time) {
pw.createTime = createTime
}
func (pw *ProxyWrapper) GetCreateTime() time.Time {
return pw.createTime
}
type IngressController interface {
// RegisterEventHandler adds a handler to receive config update events for a
// configuration type


@@ -169,6 +169,10 @@ type ConvertOptions struct {
Service2TrafficPolicy map[ServiceKey]*WrapperTrafficPolicy
ServiceWrappers map[string]*ServiceWrapper
ProxyWrappers map[string]*ProxyWrapper
HasDefaultBackend bool
}


@@ -146,7 +146,7 @@ func GetHost(annotations map[string]string) string {
// Istio requires that the name of the gateway must conform to the DNS label.
// For details, you can view: https://github.com/istio/istio/blob/2d5c40ad5e9cceebe64106005aa38381097da2ba/pkg/config/validation/validation.go#L478
func convertToDNSLabelValid(input string) string {
func ConvertToDNSLabelValid(input string) string {
hasher := md5.New()
hasher.Write([]byte(input))
hash := hasher.Sum(nil)
@@ -156,7 +156,7 @@ func convertToDNSLabelValid(input string) string {
// CleanHost follow the format of mse-ops for host.
func CleanHost(host string) string {
return convertToDNSLabelValid(host)
return ConvertToDNSLabelValid(host)
}
func CreateConvertedName(items ...string) string {


@@ -158,7 +158,7 @@ func (c *ConfigmapMgr) AddOrUpdateHigressConfig(name util.ClusterNamespacedName)
IngressLog.Infof("configmapMgr oldHigressConfig: %s", GetHigressConfigString(oldHigressConfig))
IngressLog.Infof("configmapMgr newHigressConfig: %s", GetHigressConfigString(newHigressConfig))
result, _ := c.CompareHigressConfig(oldHigressConfig, newHigressConfig)
IngressLog.Infof("configmapMgr CompareHigressConfig reuslt is %d", result)
IngressLog.Infof("configmapMgr CompareHigressConfig result is %d", result)
if result == ResultNothing {
return
@@ -177,7 +177,7 @@ func (c *ConfigmapMgr) AddOrUpdateHigressConfig(name util.ClusterNamespacedName)
}
}
c.SetHigressConfig(newHigressConfig)
IngressLog.Infof("configmapMgr higress config AddOrUpdate success, reuslt is %d", result)
IngressLog.Infof("configmapMgr higress config AddOrUpdate success, result is %d", result)
// Call updateConfig
}


@@ -509,6 +509,11 @@ func (m *McpServerController) constructMcpSessionStruct(mcp *McpServer) string {
}
func (m *McpServerController) constructMcpServerStruct(mcp *McpServer) string {
// if no servers, return empty string
if mcp == nil || len(mcp.Servers) == 0 {
return ""
}
// Build servers configuration
servers := "[]"
if len(mcp.Servers) > 0 {


@@ -566,7 +566,7 @@ func TestMcpServerController_ConstructEnvoyFilters(t *testing.T) {
MatchList: []*MatchRule{},
Servers: []*SSEServer{},
},
wantConfigs: 2, // Both session and server filters
wantConfigs: 1, // Only session filter when no servers configured
wantErr: nil,
},
}
@@ -744,24 +744,7 @@ func TestMcpServerController_constructMcpServerStruct(t *testing.T) {
mcp: &McpServer{
Servers: []*SSEServer{},
},
wantJSON: `{
"name": "envoy.filters.http.golang",
"typed_config": {
"@type": "type.googleapis.com/udpa.type.v1.TypedStruct",
"type_url": "type.googleapis.com/envoy.extensions.filters.http.golang.v3alpha.Config",
"value": {
"library_id": "mcp-server",
"library_path": "/var/lib/istio/envoy/golang-filter.so",
"plugin_name": "mcp-server",
"plugin_config": {
"@type": "type.googleapis.com/xds.type.v3.TypedStruct",
"value": {
"servers": []
}
}
}
}
}`,
wantJSON: "", // Return empty string when no servers configured
},
{
name: "with servers",


@@ -286,7 +286,7 @@ func testConvertHTTPRoute(t *testing.T, c common.KIngressController) {
expectNoError: true,
},
{
description: "valid httpRoute convention, vaild ingress",
description: "valid httpRoute convention, valid ingress",
input: struct {
options *common.ConvertOptions
wrapperConfig *common.WrapperConfig


@@ -57,12 +57,12 @@ func (f *filter) DecodeHeaders(header api.RequestHeaderMap, endStream bool) api.
}
func (f *filter) DecodeData(buffer api.BufferInstance, endStream bool) api.StatusType {
if !endStream {
return api.StopAndBuffer
}
if f.message {
for _, server := range f.config.servers {
if f.path == server.BaseServer.GetMessageEndpoint() {
if !endStream {
return api.StopAndBuffer
}
// Create a response recorder to capture the response
recorder := httptest.NewRecorder()
// Call the handleMessage method of SSEServer with complete body


@@ -77,7 +77,7 @@ func (n *NacosMcpRegistry) refreshToolsListForGroup(group string, serviceMatcher
serviceList := services.Doms
pattern, err := regexp.Compile(serviceMatcher)
if err != nil {
api.LogErrorf("Match service error for patter %s", serviceMatcher)
api.LogErrorf("Match service error for pattern %s", serviceMatcher)
return false
}


@@ -149,6 +149,9 @@ func (f *filter) processMcpRequestHeadersForRestUpstream(header api.RequestHeade
func (f *filter) processMcpRequestHeadersForSSEUpstream(header api.RequestHeaderMap, endStream bool) api.StatusType {
// We don't need to process the request body for SSE upstream.
f.skipRequestBody = true
// Remove Accept-Encoding header to avoid gzip encoding,
// which our response body handling logic doesn't support.
header.Del("Accept-Encoding")
return api.Continue
}


@@ -14,5 +14,5 @@ export {SetCtx,
ProcessRequestHeadersBy,
ProcessResponseBodyBy,
ProcessResponseHeadersBy,
Logger, RegisteTickFunc} from "./plugin_wrapper"
Logger, RegisterTickFunc} from "./plugin_wrapper"
export {ParseResult} from "./rule_matcher"


@@ -156,7 +156,7 @@ class TickFuncEntry {
var globalOnTickFuncs = new Array<TickFuncEntry>();
export function RegisteTickFunc(tickPeriod: i64, tickFunc: () => void): void {
export function RegisterTickFunc(tickPeriod: i64, tickFunc: () => void): void {
globalOnTickFuncs.push(new TickFuncEntry(0, tickPeriod, tickFunc));
}


@@ -1,5 +1,5 @@
export * from "@higress/proxy-wasm-assemblyscript-sdk/assembly/proxy";
import { SetCtx, HttpContext, ProcessRequestHeadersBy, Logger, ParseResult, ParseConfigBy, RegisteTickFunc, ProcessResponseHeadersBy } from "@higress/wasm-assemblyscript/assembly";
import { SetCtx, HttpContext, ProcessRequestHeadersBy, Logger, ParseResult, ParseConfigBy, RegisterTickFunc, ProcessResponseHeadersBy } from "@higress/wasm-assemblyscript/assembly";
import { FilterHeadersStatusValues, send_http_response, stream_context } from "@higress/proxy-wasm-assemblyscript-sdk/assembly"
import { JSON } from "assemblyscript-json/assembly";
class HelloWorldConfig {
@@ -12,10 +12,10 @@ SetCtx<HelloWorldConfig>("hello-world",
])
function parseConfig(json: JSON.Obj): ParseResult<HelloWorldConfig> {
RegisteTickFunc(2000, () => {
RegisterTickFunc(2000, () => {
Logger.Debug("tick 2s");
})
RegisteTickFunc(5000, () => {
RegisterTickFunc(5000, () => {
Logger.Debug("tick 5s");
})
return new ParseResult<HelloWorldConfig>(new HelloWorldConfig(), true);


@@ -243,7 +243,7 @@ class RouteRuleMatcher {
std::string route_name;
getValue({"route_name"}, &route_name);
std::string service_name;
getValue({"service_name"}, &service_name);
getValue({"cluster_name"}, &service_name);
std::optional<std::reference_wrapper<PluginConfig>> match_config;
std::optional<std::reference_wrapper<std::unordered_set<std::string>>>
allow_set;


@@ -15,7 +15,7 @@ WORKDIR /workspace/extensions/$PLUGIN_NAME
RUN go mod tidy
RUN \
GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o /main.wasm ./
GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o /main.wasm .
FROM scratch AS output


@@ -60,7 +60,7 @@ builder:
@echo "image: ${BUILDER}"
local-build:
cd extensions/${PLUGIN_NAME};GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o ./main.wasm ./
cd extensions/${PLUGIN_NAME};GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o ./main.wasm .
@echo ""
@echo "wasm: extensions/${PLUGIN_NAME}/main.wasm"


@@ -148,6 +148,49 @@ spec:
All rules are matched in sequence in the order configured above; once a rule matches, matching stops and the matched configuration is used to run the plugin logic.
## Unit Testing
When developing wasm plugins, it is recommended to write unit tests alongside the code to verify plugin functionality. For a detailed guide on writing unit tests, see [wasm plugin unit test](https://github.com/higress-group/wasm-go/blob/main/pkg/test/README.md).
### Unit Test Example
```go
func TestMyPlugin(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// 1. Create the test host
config := json.RawMessage(`{"key": "value"}`)
host, status := test.NewTestHost(config)
require.Equal(t, types.OnPluginStartStatusOK, status)
defer host.Reset()
// 2. Set request headers
headers := [][2]string{
{":method", "GET"},
{":path", "/test"},
{":authority", "test.com"},
}
// 3. Invoke the plugin's request header handler
action := host.CallOnHttpRequestHeaders(headers)
require.Equal(t, types.ActionPause, action)
// 4. Simulate external call responses (if needed)
// host.CallOnRedisCall(0, test.CreateRedisRespString("OK"))
// host.CallOnHttpCall([][2]string{{":status", "200"}}, []byte(`{"result": "success"}`))
// 5. Complete the request
host.CompleteHttp()
// 6. Verify the result (if the plugin returned a local response)
localResponse := host.GetLocalResponse()
require.NotNil(t, localResponse)
assert.Equal(t, uint32(200), localResponse.StatusCode)
})
}
```
## E2E Testing
When you finish a Go plugin feature, you can create the associated e2e test cases at the same time and verify the plugin's functionality locally.


@@ -139,6 +139,52 @@ spec:
The rules are matched in the order they are configured. Once a rule matches, matching stops and the matched configuration takes effect.
## Unit Testing
When developing wasm plugins, it's recommended to write unit tests to verify plugin functionality. For detailed unit testing guidelines, please refer to [wasm plugin unit test](https://github.com/higress-group/wasm-go/blob/main/pkg/test/README.md).
### Unit Test Structure Example
```go
func TestMyPlugin(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// 1. Create test host
config := json.RawMessage(`{"key": "value"}`)
host, status := test.NewTestHost(config)
require.Equal(t, types.OnPluginStartStatusOK, status)
defer host.Reset()
// 2. Set request headers
headers := [][2]string{
{":method", "GET"},
{":path", "/test"},
{":authority", "test.com"},
}
// 3. Call plugin request header processing method
action := host.CallOnHttpRequestHeaders(headers)
require.Equal(t, types.ActionPause, action)
// 4. Simulate external call responses (if needed)
// host.CallOnRedisCall(0, test.CreateRedisRespString("OK"))
// host.CallOnHttpCall([][2]string{{":status", "200"}}, []byte(`{"result": "success"}`))
// 5. Complete request
host.CompleteHttp()
// 6. Verify results (if the plugin returns a response)
localResponse := host.GetLocalResponse()
require.NotNil(t, localResponse)
assert.Equal(t, uint32(200), localResponse.StatusCode)
})
}
```
This example shows the basic test structure including configuration parsing, request processing flow, and result verification.
## E2E test
When you finish a Go plugin feature, you can create the associated e2e test cases at the same time and verify the plugin's functionality locally.


@@ -5,15 +5,21 @@ go 1.24.1
toolchain go1.24.4
require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
github.com/higress-group/wasm-go v0.0.0-20250628101008-bea7da01a545
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
gopkg.in/yaml.v2 v2.4.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/resp v0.1.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)


@@ -2,14 +2,17 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v0.0.0-20250628101008-bea7da01a545 h1:qb/Rhhfm1gzr/stim/L0cKNo0MPatdo0Rd8iYOAPWE0=
github.com/higress-group/wasm-go v0.0.0-20250628101008-bea7da01a545/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
@@ -19,6 +22,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=

File diff suppressed because it is too large.

@@ -8,14 +8,20 @@ toolchain go1.24.4
require (
github.com/google/uuid v1.6.0
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
github.com/higress-group/wasm-go v1.0.0
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/resp v0.1.1
// github.com/weaviate/weaviate-go-client/v4 v4.15.1
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)


@@ -2,14 +2,17 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw=
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
@@ -19,5 +22,9 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

File diff suppressed because it is too large.

@@ -7,14 +7,20 @@ go 1.24.1
toolchain go1.24.4
require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
github.com/higress-group/wasm-go v1.0.0
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/resp v0.1.1
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)


@@ -2,14 +2,17 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw=
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
@@ -19,5 +22,9 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -1,10 +1,129 @@
// Copyright (c) 2024 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"encoding/json"
"reflect"
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
)
// Test config: basic Redis configuration
var basicRedisConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"redis": map[string]interface{}{
"serviceName": "redis.static",
"servicePort": 6379,
"timeout": 1000,
"database": 0,
},
"questionFrom": map[string]interface{}{
"requestBody": "messages.@reverse.0.content",
},
"answerValueFrom": map[string]interface{}{
"responseBody": "choices.0.message.content",
},
"answerStreamValueFrom": map[string]interface{}{
"responseBody": "choices.0.delta.content",
},
"cacheKeyPrefix": "higress-ai-history:",
"identityHeader": "Authorization",
"fillHistoryCnt": 3,
"cacheTTL": 3600,
})
return data
}()
// Test config: minimal Redis configuration (uses default values)
var minimalRedisConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"redis": map[string]interface{}{
"serviceName": "redis.static",
},
"questionFrom": map[string]interface{}{
"requestBody": "messages.@reverse.0.content",
},
"answerValueFrom": map[string]interface{}{
"responseBody": "choices.0.message.content",
},
"answerStreamValueFrom": map[string]interface{}{
"responseBody": "choices.0.delta.content",
},
})
return data
}()
// Test config: custom Redis configuration
var customRedisConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"redis": map[string]interface{}{
"serviceName": "custom-redis.dns",
"servicePort": 6380,
"username": "admin",
"password": "password123",
"timeout": 2000,
"database": 1,
},
"questionFrom": map[string]interface{}{
"requestBody": "query.text",
},
"answerValueFrom": map[string]interface{}{
"responseBody": "response.content",
},
"answerStreamValueFrom": map[string]interface{}{
"responseBody": "response.delta.content",
},
"cacheKeyPrefix": "custom-history:",
"identityHeader": "X-User-ID",
"fillHistoryCnt": 5,
"cacheTTL": 7200,
})
return data
}()
// Test config: Redis configuration with authentication
var authRedisConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"redis": map[string]interface{}{
"serviceName": "auth-redis.static",
"servicePort": 6379,
"username": "user",
"password": "pass",
"timeout": 1500,
"database": 2,
},
"questionFrom": map[string]interface{}{
"requestBody": "messages.@reverse.0.content",
},
"answerValueFrom": map[string]interface{}{
"responseBody": "choices.0.message.content",
},
"answerStreamValueFrom": map[string]interface{}{
"responseBody": "choices.0.delta.content",
},
"cacheKeyPrefix": "auth-history:",
"identityHeader": "X-Auth-Token",
"fillHistoryCnt": 4,
"cacheTTL": 1800,
})
return data
}()
func TestDistinctChat(t *testing.T) {
type args struct {
chat []ChatHistory
@@ -34,3 +153,627 @@ func TestDistinctChat(t *testing.T) {
})
}
}
func TestParseConfig(t *testing.T) {
test.RunGoTest(t, func(t *testing.T) {
// Test parsing of the basic Redis configuration
t.Run("basic redis config", func(t *testing.T) {
host, status := test.NewTestHost(basicRedisConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
// Type assertion
pluginConfig, ok := config.(*PluginConfig)
require.True(t, ok, "config should be *PluginConfig")
// Verify Redis config fields
require.Equal(t, "redis.static", pluginConfig.RedisInfo.ServiceName)
require.Equal(t, 6379, pluginConfig.RedisInfo.ServicePort)
require.Equal(t, 1000, pluginConfig.RedisInfo.Timeout)
require.Equal(t, 0, pluginConfig.RedisInfo.Database)
require.Equal(t, "", pluginConfig.RedisInfo.Username)
require.Equal(t, "", pluginConfig.RedisInfo.Password)
// Verify question extraction config
require.Equal(t, "messages.@reverse.0.content", pluginConfig.QuestionFrom.RequestBody)
require.Equal(t, "", pluginConfig.QuestionFrom.ResponseBody)
// Verify answer extraction config
require.Equal(t, "", pluginConfig.AnswerValueFrom.RequestBody)
require.Equal(t, "choices.0.message.content", pluginConfig.AnswerValueFrom.ResponseBody)
// Verify streaming answer extraction config
require.Equal(t, "", pluginConfig.AnswerStreamValueFrom.RequestBody)
require.Equal(t, "choices.0.delta.content", pluginConfig.AnswerStreamValueFrom.ResponseBody)
// Verify other config fields
require.Equal(t, "higress-ai-history:", pluginConfig.CacheKeyPrefix)
require.Equal(t, "Authorization", pluginConfig.IdentityHeader)
require.Equal(t, 3, pluginConfig.FillHistoryCnt)
require.Equal(t, 3600, pluginConfig.CacheTTL)
})
// Test parsing of the minimal Redis configuration (uses default values)
t.Run("minimal redis config", func(t *testing.T) {
host, status := test.NewTestHost(minimalRedisConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
// Type assertion
pluginConfig, ok := config.(*PluginConfig)
require.True(t, ok, "config should be *PluginConfig")
// Verify Redis config fields (default values)
require.Equal(t, "redis.static", pluginConfig.RedisInfo.ServiceName)
require.Equal(t, 80, pluginConfig.RedisInfo.ServicePort) // default port for .static services is 80
require.Equal(t, 1000, pluginConfig.RedisInfo.Timeout) // default timeout
require.Equal(t, 0, pluginConfig.RedisInfo.Database) // default database
require.Equal(t, "", pluginConfig.RedisInfo.Username)
require.Equal(t, "", pluginConfig.RedisInfo.Password)
// Verify question extraction config (default values)
require.Equal(t, "messages.@reverse.0.content", pluginConfig.QuestionFrom.RequestBody)
require.Equal(t, "", pluginConfig.QuestionFrom.ResponseBody)
// Verify answer extraction config (default values)
require.Equal(t, "", pluginConfig.AnswerValueFrom.RequestBody)
require.Equal(t, "choices.0.message.content", pluginConfig.AnswerValueFrom.ResponseBody)
// Verify streaming answer extraction config (default values)
require.Equal(t, "", pluginConfig.AnswerStreamValueFrom.RequestBody)
require.Equal(t, "choices.0.delta.content", pluginConfig.AnswerStreamValueFrom.ResponseBody)
// Verify other config fields (default values)
require.Equal(t, "higress-ai-history:", pluginConfig.CacheKeyPrefix)
require.Equal(t, "Authorization", pluginConfig.IdentityHeader)
require.Equal(t, 3, pluginConfig.FillHistoryCnt)
require.Equal(t, 0, pluginConfig.CacheTTL) // default: never expires
})
// Test parsing of the custom Redis configuration
t.Run("custom redis config", func(t *testing.T) {
host, status := test.NewTestHost(customRedisConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
// Type assertion
pluginConfig, ok := config.(*PluginConfig)
require.True(t, ok, "config should be *PluginConfig")
// Verify Redis config fields
require.Equal(t, "custom-redis.dns", pluginConfig.RedisInfo.ServiceName)
require.Equal(t, 6380, pluginConfig.RedisInfo.ServicePort)
require.Equal(t, 2000, pluginConfig.RedisInfo.Timeout)
require.Equal(t, 1, pluginConfig.RedisInfo.Database)
require.Equal(t, "admin", pluginConfig.RedisInfo.Username)
require.Equal(t, "password123", pluginConfig.RedisInfo.Password)
// Verify question extraction config (hardcoded in the plugin, not read from config)
require.Equal(t, "messages.@reverse.0.content", pluginConfig.QuestionFrom.RequestBody)
require.Equal(t, "", pluginConfig.QuestionFrom.ResponseBody)
// Verify answer extraction config (hardcoded in the plugin, not read from config)
require.Equal(t, "", pluginConfig.AnswerValueFrom.RequestBody)
require.Equal(t, "choices.0.message.content", pluginConfig.AnswerValueFrom.ResponseBody)
// Verify streaming answer extraction config (hardcoded in the plugin, not read from config)
require.Equal(t, "", pluginConfig.AnswerStreamValueFrom.RequestBody)
require.Equal(t, "choices.0.delta.content", pluginConfig.AnswerStreamValueFrom.ResponseBody)
// Verify other config fields
require.Equal(t, "custom-history:", pluginConfig.CacheKeyPrefix)
require.Equal(t, "X-User-ID", pluginConfig.IdentityHeader)
require.Equal(t, 5, pluginConfig.FillHistoryCnt)
require.Equal(t, 7200, pluginConfig.CacheTTL)
})
// Test parsing of the authenticated Redis configuration
t.Run("auth redis config", func(t *testing.T) {
host, status := test.NewTestHost(authRedisConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
// Type assertion
pluginConfig, ok := config.(*PluginConfig)
require.True(t, ok, "config should be *PluginConfig")
// Verify Redis config fields
require.Equal(t, "auth-redis.static", pluginConfig.RedisInfo.ServiceName)
require.Equal(t, 6379, pluginConfig.RedisInfo.ServicePort)
require.Equal(t, 1500, pluginConfig.RedisInfo.Timeout)
require.Equal(t, 2, pluginConfig.RedisInfo.Database)
require.Equal(t, "user", pluginConfig.RedisInfo.Username)
require.Equal(t, "pass", pluginConfig.RedisInfo.Password)
// Verify question extraction config
require.Equal(t, "messages.@reverse.0.content", pluginConfig.QuestionFrom.RequestBody)
require.Equal(t, "", pluginConfig.QuestionFrom.ResponseBody)
// Verify answer extraction config
require.Equal(t, "", pluginConfig.AnswerValueFrom.RequestBody)
require.Equal(t, "choices.0.message.content", pluginConfig.AnswerValueFrom.ResponseBody)
// Verify streaming answer extraction config
require.Equal(t, "", pluginConfig.AnswerStreamValueFrom.RequestBody)
require.Equal(t, "choices.0.delta.content", pluginConfig.AnswerStreamValueFrom.ResponseBody)
// Verify other config fields
require.Equal(t, "auth-history:", pluginConfig.CacheKeyPrefix)
require.Equal(t, "X-Auth-Token", pluginConfig.IdentityHeader)
require.Equal(t, 4, pluginConfig.FillHistoryCnt)
require.Equal(t, 1800, pluginConfig.CacheTTL)
})
})
}
func TestOnHttpRequestHeaders(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test request header handling for JSON content type
t.Run("JSON content type headers", func(t *testing.T) {
host, status := test.NewTestHost(basicRedisConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers with JSON content type, including the identity header
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
{"authorization", "Bearer user123"},
})
// Should return HeaderStopIteration because the request body needs to be read
require.Equal(t, types.HeaderStopIteration, action)
})
// Test request header handling for non-JSON content type
t.Run("non-JSON content type headers", func(t *testing.T) {
host, status := test.NewTestHost(basicRedisConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers with a non-JSON content type
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "text/plain"},
{"authorization", "Bearer user123"},
})
// Should return ActionContinue; the request body will not be read
require.Equal(t, types.ActionContinue, action)
})
// Test request header handling when the identity header is missing
t.Run("missing identity header", func(t *testing.T) {
host, status := test.NewTestHost(basicRedisConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers without the identity header
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
})
// Should return ActionContinue because the identity header is missing
require.Equal(t, types.ActionContinue, action)
})
// Test a custom identity header
t.Run("custom identity header", func(t *testing.T) {
host, status := test.NewTestHost(customRedisConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set the custom identity header
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
{"x-user-id", "user456"},
})
// Should return HeaderStopIteration
require.Equal(t, types.HeaderStopIteration, action)
})
})
}
func TestOnHttpRequestBody(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test request body handling on a cache hit
t.Run("cache hit request body", func(t *testing.T) {
host, status := test.NewTestHost(basicRedisConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
{"authorization", "Bearer user123"},
})
// Build the request body
requestBody := `{
"messages": [
{
"role": "user",
"content": "你好,请介绍一下自己"
}
]
}`
// Invoke request body handling
action := host.CallOnHttpRequestBody([]byte(requestBody))
// Should return ActionPause because we must wait for the Redis response
require.Equal(t, types.ActionPause, action)
// Simulate a Redis cache hit response
cacheResponse := `[{"role":"user","content":"之前的问题"},{"role":"assistant","content":"之前的回答"}]`
resp := test.CreateRedisRespString(cacheResponse)
host.CallOnRedisCall(0, resp)
// Complete the HTTP request
host.CompleteHttp()
})
// Test request body handling on a cache miss
t.Run("cache miss request body", func(t *testing.T) {
host, status := test.NewTestHost(basicRedisConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
{"authorization", "Bearer user123"},
})
// Build the request body
requestBody := `{
"messages": [
{
"role": "user",
"content": "今天天气怎么样?"
}
]
}`
// Invoke request body handling
action := host.CallOnHttpRequestBody([]byte(requestBody))
// Should return ActionPause because we must wait for the Redis response
require.Equal(t, types.ActionPause, action)
// Simulate a Redis cache miss response
resp := test.CreateRedisRespNull()
host.CallOnRedisCall(0, resp)
// Complete the HTTP request
host.CompleteHttp()
})
// Test request body handling for a streaming request
t.Run("streaming request body", func(t *testing.T) {
host, status := test.NewTestHost(basicRedisConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
{"authorization", "Bearer user123"},
})
// Build a streaming request body
requestBody := `{
"stream": true,
"messages": [
{
"role": "user",
"content": "请用流式方式回答"
}
]
}`
// Invoke request body handling
action := host.CallOnHttpRequestBody([]byte(requestBody))
// Should return ActionPause because we must wait for the Redis response
require.Equal(t, types.ActionPause, action)
// Simulate a Redis cache miss response
resp := test.CreateRedisRespNull()
host.CallOnRedisCall(0, resp)
// Complete the HTTP request
host.CompleteHttp()
})
// Test request body handling for a history query request
t.Run("query history request body", func(t *testing.T) {
host, status := test.NewTestHost(basicRedisConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/ai-history/query?cnt=2"},
{":method", "GET"},
{"content-type", "application/json"},
{"authorization", "Bearer user123"},
})
// Build the request body; it must include a messages field because the plugin will try to extract the question
requestBody := `{
"messages": [
{
"role": "user",
"content": "查询历史"
}
]
}`
// Invoke request body handling
action := host.CallOnHttpRequestBody([]byte(requestBody))
// Should return ActionPause because we must wait for the Redis response
require.Equal(t, types.ActionPause, action)
// Simulate a Redis cache hit response
cacheResponse := `[{"role":"user","content":"问题1"},{"role":"assistant","content":"回答1"},{"role":"user","content":"问题2"},{"role":"assistant","content":"回答2"}]`
resp := test.CreateRedisRespString(cacheResponse)
host.CallOnRedisCall(0, resp)
// Complete the HTTP request
host.CompleteHttp()
})
})
}
func TestOnHttpResponseHeaders(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test streaming response header handling
t.Run("streaming response headers", func(t *testing.T) {
host, status := test.NewTestHost(basicRedisConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Request headers must be set first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
{"authorization", "Bearer user123"},
})
// Set streaming response headers
action := host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "text/event-stream"},
})
// Should return ActionContinue
require.Equal(t, types.ActionContinue, action)
})
// Test non-streaming response header handling
t.Run("non-streaming response headers", func(t *testing.T) {
host, status := test.NewTestHost(basicRedisConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Request headers must be set first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
{"authorization", "Bearer user123"},
})
// Set non-streaming response headers
action := host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "application/json"},
})
// Should return ActionContinue
require.Equal(t, types.ActionContinue, action)
})
})
}
func TestOnHttpStreamResponseBody(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test streaming response body handling - non-streaming mode
t.Run("non-streaming mode", func(t *testing.T) {
host, status := test.NewTestHost(basicRedisConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
{"authorization", "Bearer user123"},
})
// Set the request body
requestBody := `{
"messages": [
{
"role": "user",
"content": "测试问题"
}
]
}`
// Invoke request body handling to set up the required context
host.CallOnHttpRequestBody([]byte(requestBody))
// Simulate a Redis cache miss, which sets QuestionContextKey
resp := test.CreateRedisRespNull()
host.CallOnRedisCall(0, resp)
// Test non-streaming response body handling
chunk := []byte(`{"choices":[{"message":{"content":"测试回答"}}]}`)
action := host.CallOnHttpStreamingResponseBody(chunk, true)
// Should return ActionContinue
require.Equal(t, types.ActionContinue, action)
})
// Test streaming response body handling - streaming mode
t.Run("streaming mode", func(t *testing.T) {
host, status := test.NewTestHost(basicRedisConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
{"authorization", "Bearer user123"},
})
// Set a streaming request body
requestBody := `{
"stream": true,
"messages": [
{
"role": "user",
"content": "测试流式问题"
}
]
}`
// Invoke request body handling to set up the required context
host.CallOnHttpRequestBody([]byte(requestBody))
// Simulate a Redis cache miss, which sets QuestionContextKey
resp := test.CreateRedisRespNull()
host.CallOnRedisCall(0, resp)
// Set streaming response headers
host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"content-type", "text/event-stream"},
})
// Test streaming response body handling - not the last chunk
chunk1 := []byte("data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n")
action1 := host.CallOnHttpStreamingResponseBody(chunk1, false)
require.Equal(t, types.ActionContinue, action1)
// Test streaming response body handling - the last chunk
chunk2 := []byte("data: {\"choices\":[{\"delta\":{\"content\":\" World\"}}]}\n\n")
action2 := host.CallOnHttpStreamingResponseBody(chunk2, true)
require.Equal(t, types.ActionContinue, action2)
})
// Test streaming response body handling on the history query path
t.Run("query history path", func(t *testing.T) {
host, status := test.NewTestHost(basicRedisConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/ai-history/query?cnt=2"},
{":method", "GET"},
{"content-type", "application/json"},
{"authorization", "Bearer user123"},
})
// Set the request body
requestBody := `{
"messages": [
{
"role": "user",
"content": "查询历史"
}
]
}`
// Invoke request body handling to set up the required context
host.CallOnHttpRequestBody([]byte(requestBody))
// Simulate a Redis cache hit, which sets QuestionContextKey
cacheResponse := `[{"role":"user","content":"问题1"},{"role":"assistant","content":"回答1"}]`
resp := test.CreateRedisRespString(cacheResponse)
host.CallOnRedisCall(0, resp)
// Test response body handling on the history query path
chunk := []byte("test chunk")
action := host.CallOnHttpStreamingResponseBody(chunk, true)
// The chunk should be returned directly, without processing
require.Equal(t, types.ActionContinue, action)
})
// Test the case where QuestionContextKey is absent
t.Run("no question context", func(t *testing.T) {
host, status := test.NewTestHost(basicRedisConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
{"authorization", "Bearer user123"},
})
// Do not invoke request body handling, so QuestionContextKey is absent
// Test response body handling without QuestionContextKey
chunk := []byte("test chunk")
action := host.CallOnHttpStreamingResponseBody(chunk, true)
// Should return ActionContinue
require.Equal(t, types.ActionContinue, action)
})
})
}

View File

@@ -5,13 +5,21 @@ go 1.24.1
toolchain go1.24.4
require (
github.com/higress-group/wasm-go v1.0.0
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
require (
github.com/google/uuid v1.6.0 // indirect
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/resp v0.1.1 // indirect

View File

@@ -2,14 +2,16 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw=
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
@@ -22,5 +24,7 @@ github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -0,0 +1,616 @@
// Copyright (c) 2024 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"encoding/json"
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
)
// Test config: basic DashScope OCR configuration
var basicDashScopeConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"type": "dashscope",
"apiKey": "test-api-key-123",
"serviceName": "ocr-service",
"serviceHost": "dashscope.aliyuncs.com",
"servicePort": 443,
"timeout": 10000,
"model": "qwen-vl-ocr",
})
return data
}()
// Test config: minimal DashScope configuration (uses default values)
var minimalDashScopeConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"type": "dashscope",
"apiKey": "minimal-api-key",
"serviceName": "ocr-service",
})
return data
}()
// Test config: custom port and timeout configuration
var customPortTimeoutConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"type": "dashscope",
"apiKey": "custom-api-key",
"serviceName": "ocr-service",
"serviceHost": "custom.dashscope.com",
"servicePort": 8443,
"timeout": 30000,
"model": "qwen-vl-ocr",
})
return data
}()
// Test config: custom model configuration
var customModelConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"type": "dashscope",
"apiKey": "model-api-key",
"serviceName": "ocr-service",
"serviceHost": "dashscope.aliyuncs.com",
"servicePort": 443,
"timeout": 15000,
"model": "custom-ocr-model",
})
return data
}()
func TestParseConfig(t *testing.T) {
test.RunGoTest(t, func(t *testing.T) {
// Test parsing of the basic DashScope configuration
t.Run("basic dashscope config", func(t *testing.T) {
host, status := test.NewTestHost(basicDashScopeConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// Test parsing of the minimal DashScope configuration (uses default values)
t.Run("minimal dashscope config", func(t *testing.T) {
host, status := test.NewTestHost(minimalDashScopeConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// Test parsing of the custom port and timeout configuration
t.Run("custom port timeout config", func(t *testing.T) {
host, status := test.NewTestHost(customPortTimeoutConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// Test parsing of the custom model configuration
t.Run("custom model config", func(t *testing.T) {
host, status := test.NewTestHost(customModelConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
})
}
func TestOnHttpRequestHeaders(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test request header handling for JSON content type
t.Run("JSON content type headers", func(t *testing.T) {
host, status := test.NewTestHost(basicDashScopeConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers with JSON content type
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
})
// Should return ActionContinue because re-routing is disabled but processing is allowed to continue
require.Equal(t, types.ActionContinue, action)
})
// Test request header handling for non-JSON content type
t.Run("non-JSON content type headers", func(t *testing.T) {
host, status := test.NewTestHost(basicDashScopeConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers with a non-JSON content type
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "text/plain"},
})
// Should return ActionContinue; the request body will not be read
require.Equal(t, types.ActionContinue, action)
})
// Test request header handling when content-type is missing
t.Run("missing content type headers", func(t *testing.T) {
host, status := test.NewTestHost(basicDashScopeConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers without a content-type
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
})
// Should return ActionContinue
require.Equal(t, types.ActionContinue, action)
})
})
}
func TestOnHttpRequestBody(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test request body handling with a single image
t.Run("single image request body", func(t *testing.T) {
host, status := test.NewTestHost(basicDashScopeConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
})
// Build a request body containing a single image
requestBody := `{
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "这张图片里有什么?"
},
{
"type": "image_url",
"image_url": {
"url": "https://example.com/image1.jpg"
}
}
]
}
]
}`
// Invoke request body handling
action := host.CallOnHttpRequestBody([]byte(requestBody))
// Should return ActionPause because we must wait for the OCR response
require.Equal(t, types.ActionPause, action)
// Simulate the OCR service response
ocrResponse := `{
"choices": [
{
"message": {
"content": "图片中包含一些文字内容"
}
}
]
}`
// Simulate the HTTP call response
host.CallOnHttpCall([][2]string{
{"content-type", "application/json"},
{":status", "200"},
}, []byte(ocrResponse))
modifiedBody := host.GetRequestBody()
require.NotNil(t, modifiedBody)
require.Contains(t, string(modifiedBody), "图片中包含一些文字内容")
// Complete the HTTP request
host.CompleteHttp()
})
// Test request body handling with multiple images
t.Run("multiple images request body", func(t *testing.T) {
host, status := test.NewTestHost(basicDashScopeConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
})
// Build a request body containing multiple images
requestBody := `{
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "这些图片里有什么?"
},
{
"type": "image_url",
"image_url": {
"url": "https://example.com/image1.jpg"
}
},
{
"type": "image_url",
"image_url": {
"url": "https://example.com/image2.jpg"
}
}
]
}
]
}`
// Invoke request body handling
action := host.CallOnHttpRequestBody([]byte(requestBody))
// Should return ActionPause because the plugin waits for the OCR responses
require.Equal(t, types.ActionPause, action)
// Mock the OCR response for the first image
ocrResponse1 := `{
"choices": [
{
"message": {
"content": "第一张图片包含文字A"
}
}
]
}`
// Mock the OCR response for the second image
ocrResponse2 := `{
"choices": [
{
"message": {
"content": "第二张图片包含文字B"
}
}
]
}`
// Mock the first HTTP call response
host.CallOnHttpCall([][2]string{
{"content-type", "application/json"},
{":status", "200"},
}, []byte(ocrResponse1))
// Mock the second HTTP call response
host.CallOnHttpCall([][2]string{
{"content-type", "application/json"},
{":status", "200"},
}, []byte(ocrResponse2))
modifiedBody := host.GetRequestBody()
require.NotNil(t, modifiedBody)
require.Contains(t, string(modifiedBody), "第一张图片包含文字A")
require.Contains(t, string(modifiedBody), "第二张图片包含文字B")
// Complete the HTTP request
host.CompleteHttp()
})
// Test request body handling with no images
t.Run("no image request body", func(t *testing.T) {
host, status := test.NewTestHost(basicDashScopeConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
})
// Build a request body with no images
requestBody := `{
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "你好,请介绍一下自己"
}
]
}
]
}`
// Invoke request body handling
action := host.CallOnHttpRequestBody([]byte(requestBody))
// Should return ActionContinue because there are no images to process
require.Equal(t, types.ActionContinue, action)
})
})
}
// Test config validation
func TestConfigValidation(t *testing.T) {
test.RunGoTest(t, func(t *testing.T) {
// Test config missing the type field
t.Run("missing type", func(t *testing.T) {
invalidConfig := func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"apiKey": "test-api-key",
"serviceName": "ocr-service",
"serviceHost": "dashscope.aliyuncs.com",
"servicePort": 443,
})
return data
}()
host, status := test.NewTestHost(invalidConfig)
defer host.Reset()
// Should return an error status because the required type field is missing
require.NotEqual(t, types.OnPluginStartStatusOK, status)
})
// Test config missing the apiKey field
t.Run("missing apiKey", func(t *testing.T) {
invalidConfig := func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"type": "dashscope",
"serviceName": "ocr-service",
"serviceHost": "dashscope.aliyuncs.com",
"servicePort": 443,
"timeout": 10000,
"model": "qwen-vl-ocr",
})
return data
}()
host, status := test.NewTestHost(invalidConfig)
defer host.Reset()
// Should return an error status because the required apiKey field is missing
require.NotEqual(t, types.OnPluginStartStatusOK, status)
})
// Test config missing the serviceName field
t.Run("missing serviceName", func(t *testing.T) {
invalidConfig := func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"type": "dashscope",
"apiKey": "test-api-key",
"serviceHost": "dashscope.aliyuncs.com",
"servicePort": 443,
"timeout": 10000,
"model": "qwen-vl-ocr",
})
return data
}()
host, status := test.NewTestHost(invalidConfig)
defer host.Reset()
// Should return an error status because the required serviceName field is missing
require.NotEqual(t, types.OnPluginStartStatusOK, status)
})
// Test an unknown provider type
t.Run("unknown provider type", func(t *testing.T) {
invalidConfig := func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"type": "unknown-provider",
"apiKey": "test-api-key",
"serviceName": "ocr-service",
"serviceHost": "example.com",
"servicePort": 443,
})
return data
}()
host, status := test.NewTestHost(invalidConfig)
defer host.Reset()
// Should return an error status because the provider type is unknown
require.NotEqual(t, types.OnPluginStartStatusOK, status)
})
})
}
// Test edge cases
func TestEdgeCases(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test an empty request body
t.Run("empty request body", func(t *testing.T) {
host, status := test.NewTestHost(basicDashScopeConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
})
// Invoke request body handling with an empty body
action := host.CallOnHttpRequestBody([]byte{})
// Should return ActionContinue because there are no images to process
require.Equal(t, types.ActionContinue, action)
})
// Test an invalid JSON request body
t.Run("invalid JSON request body", func(t *testing.T) {
host, status := test.NewTestHost(basicDashScopeConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
})
// Invoke request body handling with invalid JSON
invalidJSON := []byte(`{"messages": [{"role": "user", "content": "test"}`)
action := host.CallOnHttpRequestBody(invalidJSON)
// Should return ActionContinue because JSON parsing fails
require.Equal(t, types.ActionContinue, action)
})
// Test an OCR service error response
t.Run("OCR service error response", func(t *testing.T) {
host, status := test.NewTestHost(basicDashScopeConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
})
// Build a request body containing an image
requestBody := `{
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "这张图片里有什么?"
},
{
"type": "image_url",
"image_url": {
"url": "https://example.com/image1.jpg"
}
}
]
}
]
}`
// Invoke request body handling
action := host.CallOnHttpRequestBody([]byte(requestBody))
// Should return ActionPause
require.Equal(t, types.ActionPause, action)
// Mock an OCR service error response
errorResponse := `{
"error": "Service unavailable",
"message": "OCR service is down"
}`
host.CallOnHttpCall([][2]string{
{"content-type", "application/json"},
{":status", "503"},
}, []byte(errorResponse))
host.CompleteHttp()
})
// Test an empty result from the OCR service
t.Run("OCR service empty response", func(t *testing.T) {
host, status := test.NewTestHost(basicDashScopeConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
})
// Build a request body containing an image
requestBody := `{
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "这张图片里有什么?"
},
{
"type": "image_url",
"image_url": {
"url": "https://example.com/image1.jpg"
}
}
]
}
]
}`
// Invoke request body handling
action := host.CallOnHttpRequestBody([]byte(requestBody))
// Should return ActionPause
require.Equal(t, types.ActionPause, action)
// Mock an empty result from the OCR service
emptyResponse := `{
"choices": []
}`
host.CallOnHttpCall([][2]string{
{"content-type", "application/json"},
{":status", "200"},
}, []byte(emptyResponse))
host.CompleteHttp()
})
})
}

View File

@@ -7,14 +7,20 @@ go 1.24.1
toolchain go1.24.4
require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
github.com/higress-group/wasm-go v1.0.0
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/resp v0.1.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -2,14 +2,17 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw=
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
@@ -19,5 +22,9 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -0,0 +1,531 @@
// Copyright (c) 2024 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"encoding/json"
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
)
// Test config: basic intent recognition
var basicIntentConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"scene": map[string]interface{}{
"category": "金融|电商|法律|Higress",
"prompt": "你是一个智能类别识别助手,负责根据用户提出的问题和预设的类别,确定问题属于哪个预设的类别,并给出相应的类别。用户提出的问题为:'%s',预设的类别为'%s',直接返回一种具体类别,如果没有找到就返回'NotFound'。",
},
"llm": map[string]interface{}{
"proxyServiceName": "ai-service",
"proxyUrl": "http://ai.example.com/v1/chat/completions",
"proxyModel": "qwen-long",
"proxyPort": 80,
"proxyDomain": "ai.example.com",
"proxyTimeout": 10000,
"proxyApiKey": "test-api-key",
},
"keyFrom": map[string]interface{}{
"requestBody": "messages.@reverse.0.content",
"responseBody": "choices.0.message.content",
},
})
return data
}()
// Test config: custom prompt
var customPromptConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"scene": map[string]interface{}{
"category": "技术|产品|运营|设计",
"prompt": "请分析以下问题属于哪个技术领域:%s可选领域%s请直接返回领域名称。",
},
"llm": map[string]interface{}{
"proxyServiceName": "ai-service",
"proxyUrl": "https://ai.example.com/v1/chat/completions",
"proxyModel": "gpt-3.5-turbo",
"proxyPort": 443,
"proxyDomain": "ai.example.com",
"proxyTimeout": 15000,
"proxyApiKey": "custom-api-key",
},
"keyFrom": map[string]interface{}{
"requestBody": "query",
"responseBody": "result",
},
})
return data
}()
// Test config: minimal config (using defaults)
var minimalConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"scene": map[string]interface{}{
"category": "A|B|C",
},
"llm": map[string]interface{}{
"proxyServiceName": "ai-service",
"proxyUrl": "http://ai.example.com/v1/chat/completions",
},
"keyFrom": map[string]interface{}{},
})
return data
}()
// Test config: HTTPS config
var httpsConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"scene": map[string]interface{}{
"category": "客服|销售|技术支持",
},
"llm": map[string]interface{}{
"proxyServiceName": "ai-service",
"proxyUrl": "https://ai.example.com:8443/v1/chat/completions",
"proxyModel": "claude-3",
"proxyTimeout": 20000,
"proxyApiKey": "https-api-key",
},
"keyFrom": map[string]interface{}{
"requestBody": "input.text",
"responseBody": "output.classification",
},
})
return data
}()
func TestParseConfig(t *testing.T) {
test.RunGoTest(t, func(t *testing.T) {
// Test parsing of the basic intent-recognition config
t.Run("basic intent config", func(t *testing.T) {
host, status := test.NewTestHost(basicIntentConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// Test parsing of the custom prompt config
t.Run("custom prompt config", func(t *testing.T) {
host, status := test.NewTestHost(customPromptConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// Test parsing of the minimal config (using defaults)
t.Run("minimal config", func(t *testing.T) {
host, status := test.NewTestHost(minimalConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// Test parsing of the HTTPS config
t.Run("https config", func(t *testing.T) {
host, status := test.NewTestHost(httpsConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
})
}
func TestOnHttpRequestHeaders(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test request header processing
t.Run("request headers processing", func(t *testing.T) {
host, status := test.NewTestHost(basicIntentConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
})
// Should return HeaderStopIteration because re-routing is disabled
require.Equal(t, types.HeaderStopIteration, action)
})
})
}
func TestOnHttpRequestBody(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test request body handling: finance question
t.Run("financial question processing", func(t *testing.T) {
host, status := test.NewTestHost(basicIntentConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
})
// Build a request body with a finance question
requestBody := `{
"messages": [
{"role": "user", "content": "今天股市怎么样?"}
]
}`
// Invoke request body handling
action := host.CallOnHttpRequestBody([]byte(requestBody))
// Should return ActionPause because the plugin waits for the LLM response
require.Equal(t, types.ActionPause, action)
// Mock the LLM response returning the "金融" (finance) category
llmResponse := `{
"choices": [
{
"message": {
"content": "金融"
}
}
]
}`
// Mock the HTTP call response
host.CallOnHttpCall([][2]string{
{"content-type", "application/json"},
{":status", "200"},
}, []byte(llmResponse))
// Verify the plugin handled the LLM response correctly:
// it should store the "金融" category in a property,
// checked here via host.GetProperty
intentCategory, err := host.GetProperty([]string{"intent_category"})
require.NoError(t, err)
require.Equal(t, "金融", string(intentCategory))
// Complete the HTTP request
host.CompleteHttp()
})
// Test request body handling: e-commerce question
t.Run("ecommerce question processing", func(t *testing.T) {
host, status := test.NewTestHost(basicIntentConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
})
// Build a request body with an e-commerce question
requestBody := `{
"messages": [
{"role": "user", "content": "这个商品什么时候发货?"}
]
}`
// Invoke request body handling
action := host.CallOnHttpRequestBody([]byte(requestBody))
// Should return ActionPause
require.Equal(t, types.ActionPause, action)
// Mock the LLM response returning the "电商" (e-commerce) category
llmResponse := `{
"choices": [
{
"message": {
"content": "电商"
}
}
]
}`
// Mock the HTTP call response
host.CallOnHttpCall([][2]string{
{"content-type", "application/json"},
{":status", "200"},
}, []byte(llmResponse))
// Verify the plugin handled the LLM response correctly:
// it should store the "电商" category in a property,
// checked here via host.GetProperty
intentCategory, err := host.GetProperty([]string{"intent_category"})
require.NoError(t, err)
require.Equal(t, "电商", string(intentCategory))
// Complete the HTTP request
host.CompleteHttp()
})
// Test request body handling: category not found
t.Run("category not found processing", func(t *testing.T) {
host, status := test.NewTestHost(basicIntentConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
})
// Build a request body with an unrelated question
requestBody := `{
"messages": [
{"role": "user", "content": "今天天气怎么样?"}
]
}`
// Invoke request body handling
action := host.CallOnHttpRequestBody([]byte(requestBody))
// Should return ActionPause
require.Equal(t, types.ActionPause, action)
// Mock the LLM response returning "NotFound"
llmResponse := `{
"choices": [
{
"message": {
"content": "NotFound"
}
}
]
}`
// Mock the HTTP call response
host.CallOnHttpCall([][2]string{
{"content-type", "application/json"},
{":status", "200"},
}, []byte(llmResponse))
_, err := host.GetProperty([]string{"intent_category"})
// Should return an error because the property was never set
require.Error(t, err)
// Complete the HTTP request
host.CompleteHttp()
})
})
}
// Test config validation
func TestConfigValidation(t *testing.T) {
test.RunGoTest(t, func(t *testing.T) {
// Test config missing scene.category
t.Run("missing scene.category", func(t *testing.T) {
invalidConfig := func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"scene": map[string]interface{}{
"prompt": "test prompt",
},
"llm": map[string]interface{}{
"proxyServiceName": "ai-service",
"proxyUrl": "http://ai.example.com/v1/chat/completions",
},
"keyFrom": map[string]interface{}{},
})
return data
}()
host, status := test.NewTestHost(invalidConfig)
defer host.Reset()
// Should return an error status because the required scene.category is missing
require.NotEqual(t, types.OnPluginStartStatusOK, status)
})
// Test config missing llm.proxyServiceName
t.Run("missing llm.proxyServiceName", func(t *testing.T) {
invalidConfig := func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"scene": map[string]interface{}{
"category": "A|B|C",
},
"llm": map[string]interface{}{
"proxyUrl": "http://ai.example.com/v1/chat/completions",
},
"keyFrom": map[string]interface{}{},
})
return data
}()
host, status := test.NewTestHost(invalidConfig)
defer host.Reset()
// Should return an error status because the required llm.proxyServiceName is missing
require.NotEqual(t, types.OnPluginStartStatusOK, status)
})
// Test config missing llm.proxyUrl
t.Run("missing llm.proxyUrl", func(t *testing.T) {
invalidConfig := func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"scene": map[string]interface{}{
"category": "A|B|C",
},
"llm": map[string]interface{}{
"proxyServiceName": "ai-service",
},
"keyFrom": map[string]interface{}{},
})
return data
}()
host, status := test.NewTestHost(invalidConfig)
defer host.Reset()
// Should return an error status because the required llm.proxyUrl is missing
require.NotEqual(t, types.OnPluginStartStatusOK, status)
})
// Test config missing required fields
t.Run("missing required fields", func(t *testing.T) {
invalidConfig := func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"scene": map[string]interface{}{
"category": "A|B|C",
},
"llm": map[string]interface{}{
"proxyServiceName": "ai-service",
// Deliberately omit proxyUrl, which is required
},
"keyFrom": map[string]interface{}{},
})
return data
}()
host, status := test.NewTestHost(invalidConfig)
defer host.Reset()
// Should return an error status because the required proxyUrl is missing
require.NotEqual(t, types.OnPluginStartStatusOK, status)
})
})
}
// Test edge cases
func TestEdgeCases(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test an invalid JSON request body
t.Run("invalid JSON request body", func(t *testing.T) {
host, status := test.NewTestHost(basicIntentConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
})
// Invoke request body handling with invalid JSON
invalidJSON := []byte(`{"messages": [{"role": "user", "content": "test"}`)
action := host.CallOnHttpRequestBody(invalidJSON)
// Should return ActionPause because the plugin waits for the LLM response
require.Equal(t, types.ActionPause, action)
// Mock the LLM response
llmResponse := `{
"choices": [
{
"message": {
"content": "NotFound"
}
}
]
}`
host.CallOnHttpCall([][2]string{
{"content-type", "application/json"},
{":status", "200"},
}, []byte(llmResponse))
// Verify the plugin handled the LLM response correctly:
// since it returned "NotFound", no intent category property is set;
// verify the property is absent
_, err := host.GetProperty([]string{"intent_category"})
// Should return an error because the property was never set
require.Error(t, err)
host.CompleteHttp()
})
// Test an LLM service error response
t.Run("LLM service error response", func(t *testing.T) {
host, status := test.NewTestHost(basicIntentConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/api/chat"},
{":method", "POST"},
{"content-type", "application/json"},
})
// Build the request body
requestBody := `{
"messages": [
{"role": "user", "content": "今天股市怎么样?"}
]
}`
// Invoke request body handling
action := host.CallOnHttpRequestBody([]byte(requestBody))
// Should return ActionPause
require.Equal(t, types.ActionPause, action)
// Mock an LLM service error response
errorResponse := `{
"error": "Service unavailable",
"message": "LLM service is down"
}`
host.CallOnHttpCall([][2]string{
{"content-type", "application/json"},
{":status", "503"},
}, []byte(errorResponse))
// Verify the plugin handled the LLM error response correctly:
// since the status code is not 200, no intent category property is set
host.CompleteHttp()
})
})
}

View File

@@ -5,8 +5,17 @@ go 1.24.1
toolchain go1.24.4
require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
github.com/higress-group/wasm-go v1.0.0
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
github.com/stretchr/testify v1.9.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
require (

View File

@@ -2,16 +2,19 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw=
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/santhosh-tekuri/jsonschema v1.2.4 h1:hNhW8e7t+H1vgY+1QeEQpveR6D4+OwKPXCfD2aieJis=
github.com/santhosh-tekuri/jsonschema v1.2.4/go.mod h1:TEAUOeZSmIxTTuHatJzrvARHiuO9LYd+cIxzgEHCQI4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
@@ -21,5 +24,9 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -0,0 +1,892 @@
// Copyright (c) 2024 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"encoding/json"
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/santhosh-tekuri/jsonschema"
"github.com/stretchr/testify/require"
)
// Test config: basic config
var basicConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"serviceName": "ai-service",
"serviceDomain": "api.openai.com",
"servicePort": 443,
"servicePath": "/v1/chat/completions",
"apiKey": "sk-test123",
"serviceTimeout": 30000,
"maxRetry": 3,
"contentPath": "choices.0.message.content",
"enableContentDisposition": true,
// Add a simple JSON Schema to avoid compilation failure
"jsonSchema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"content": map[string]interface{}{
"type": "string",
},
},
},
})
return data
}()
// Test config: config using serviceUrl
var serviceUrlConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"serviceName": "ai-service",
"serviceUrl": "https://api.openai.com/v1/chat/completions",
"apiKey": "sk-test456",
"serviceTimeout": 50000,
"maxRetry": 5,
"contentPath": "choices.0.message.content",
"enableContentDisposition": false,
// Add a simple JSON Schema to avoid compilation failure
"jsonSchema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"content": map[string]interface{}{
"type": "string",
},
},
},
})
return data
}()
// Test config: config with a JSON Schema
var jsonSchemaConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"serviceName": "ai-service",
"serviceDomain": "api.openai.com",
"servicePort": 443,
"apiKey": "sk-test789",
"jsonSchema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"name": map[string]interface{}{
"type": "string",
},
"age": map[string]interface{}{
"type": "integer",
},
},
"required": []string{"name"},
},
"enableSwagger": true,
"enableOas3": false,
})
return data
}()
// Test config: config with OAS3 enabled
var oas3Config = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"serviceName": "ai-service",
"serviceDomain": "api.openai.com",
"servicePort": 443,
"apiKey": "sk-test101",
"jsonSchema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"title": map[string]interface{}{
"type": "string",
},
"content": map[string]interface{}{
"type": "string",
},
},
},
"enableSwagger": false,
"enableOas3": true,
})
return data
}()
// Test config: invalid JSON Schema config
var invalidJsonSchemaConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"serviceName": "ai-service",
"serviceDomain": "api.openai.com",
"servicePort": 443,
"apiKey": "sk-test303",
"jsonSchema": "invalid-schema",
})
return data
}()
// Test config: config missing required fields
var missingRequiredConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"apiKey": "sk-test404",
"serviceTimeout": 30000,
})
return data
}()
func TestParseConfig(t *testing.T) {
test.RunGoTest(t, func(t *testing.T) {
// Test parsing of the basic config
t.Run("basic config", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
pluginConfig := config.(*PluginConfig)
require.Equal(t, "ai-service", pluginConfig.serviceName)
require.Equal(t, "api.openai.com", pluginConfig.serviceDomain)
require.Equal(t, 443, pluginConfig.servicePort)
require.Equal(t, "/v1/chat/completions", pluginConfig.servicePath)
require.Equal(t, "sk-test123", pluginConfig.apiKey)
require.Equal(t, 30000, pluginConfig.serviceTimeout)
require.Equal(t, 3, pluginConfig.maxRetry)
require.Equal(t, "choices.0.message.content", pluginConfig.contentPath)
require.True(t, pluginConfig.enableContentDisposition)
require.NotNil(t, pluginConfig.jsonSchema)
require.Equal(t, jsonschema.Draft7, pluginConfig.draft)
require.True(t, pluginConfig.enableJsonSchemaValidation)
})
// Test parsing of the serviceUrl config
t.Run("serviceUrl config", func(t *testing.T) {
host, status := test.NewTestHost(serviceUrlConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
pluginConfig := config.(*PluginConfig)
require.Equal(t, "ai-service", pluginConfig.serviceName)
require.Equal(t, "api.openai.com", pluginConfig.serviceDomain)
require.Equal(t, 443, pluginConfig.servicePort)
require.Equal(t, "/v1/chat/completions", pluginConfig.servicePath)
require.Equal(t, "sk-test456", pluginConfig.apiKey)
require.Equal(t, 50000, pluginConfig.serviceTimeout)
require.Equal(t, 5, pluginConfig.maxRetry)
require.False(t, pluginConfig.enableContentDisposition)
})
// Test parsing of the JSON Schema config
t.Run("jsonSchema config", func(t *testing.T) {
host, status := test.NewTestHost(jsonSchemaConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
pluginConfig := config.(*PluginConfig)
require.NotNil(t, pluginConfig.jsonSchema)
require.Equal(t, jsonschema.Draft4, pluginConfig.draft)
require.True(t, pluginConfig.enableJsonSchemaValidation)
require.NotNil(t, pluginConfig.compile)
})
// Test config parsing with OAS3 enabled
t.Run("oas3 config", func(t *testing.T) {
host, status := test.NewTestHost(oas3Config)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
pluginConfig := config.(*PluginConfig)
require.Equal(t, jsonschema.Draft7, pluginConfig.draft)
require.True(t, pluginConfig.enableJsonSchemaValidation)
})
// Test an invalid JSON Schema config
t.Run("invalid jsonSchema config", func(t *testing.T) {
host, status := test.NewTestHost(invalidJsonSchemaConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
pluginConfig := config.(*PluginConfig)
// Per the plugin's actual behavior, an invalid JSON Schema causes compilation to fail
require.Equal(t, uint32(JSON_SCHEMA_COMPILE_FAILED_CODE), pluginConfig.rejectStruct.RejectCode)
})
// Test a config missing required fields
t.Run("missing required config", func(t *testing.T) {
host, status := test.NewTestHost(missingRequiredConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, _ := host.GetMatchConfig()
require.NotNil(t, config)
pluginConfig := config.(*PluginConfig)
// Per the plugin's actual behavior, a missing serviceDomain causes JSON Schema compilation to fail
require.Equal(t, uint32(JSON_SCHEMA_COMPILE_FAILED_CODE), pluginConfig.rejectStruct.RejectCode)
require.Contains(t, pluginConfig.rejectStruct.RejectMsg, "Json Schema compile failed")
})
})
}
func TestOnHttpRequestHeaders(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test normal request header processing
t.Run("normal request headers", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Authorization", "Bearer sk-user123"},
{"Content-Type", "application/json"},
{"Content-Length", "100"},
})
// Should return ActionContinue
require.Equal(t, types.ActionContinue, action)
})
// Test header processing for a request issued by this plugin
t.Run("request from this plugin", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers marking the request as issued by this plugin
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{EXTEND_HEADER_KEY, "true"},
{"Content-Type", "application/json"},
})
// Should return ActionContinue
require.Equal(t, types.ActionContinue, action)
})
// Test a request without an Authorization header
t.Run("no authorization header", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers without Authorization
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
{"Content-Length", "100"},
})
// Should return ActionContinue
require.Equal(t, types.ActionContinue, action)
})
// Test header processing with a broken config
t.Run("config error", func(t *testing.T) {
host, status := test.NewTestHost(missingRequiredConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Should return ActionPause
require.Equal(t, types.ActionPause, action)
})
})
}
func TestOnHttpRequestBody(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test a request issued by this plugin (should continue directly)
t.Run("request from this plugin", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers with EXTEND_HEADER_KEY to mark the request as issued by this plugin
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
{EXTEND_HEADER_KEY, "true"},
})
// Set the request body
body := `{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Hello"}
]
}`
action := host.CallOnHttpRequestBody([]byte(body))
// Should return ActionContinue because the request came from this plugin
require.Equal(t, types.ActionContinue, action)
})
// Test body processing with a broken config
t.Run("config error", func(t *testing.T) {
host, status := test.NewTestHost(missingRequiredConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
body := `{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Hello"}
]
}`
action := host.CallOnHttpRequestBody([]byte(body))
// Should return ActionContinue because the config is invalid
require.Equal(t, types.ActionContinue, action)
})
// Test normal body processing - successful response
t.Run("normal request with successful response", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
body := `{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "What is AI?"}
]
}`
action := host.CallOnHttpRequestBody([]byte(body))
// Should return ActionPause while waiting for the external service response
require.Equal(t, types.ActionPause, action)
// Simulate a successful response from the external service
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(`{
"id": "chatcmpl-123",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "{\"definition\": \"AI is artificial intelligence\", \"examples\": [\"machine learning\", \"natural language processing\"]}"
}
}
]
}`))
response := host.GetLocalResponse()
require.NotNil(t, response)
require.Contains(t, string(response.Data), "definition")
require.Contains(t, string(response.Data), "examples")
// Complete the HTTP request
host.CompleteHttp()
})
// Test normal body processing - response requiring a retry
t.Run("normal request with retry response", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
body := `{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "What is AI?"}
]
}`
action := host.CallOnHttpRequestBody([]byte(body))
// Should return ActionPause while waiting for the external service response
require.Equal(t, types.ActionPause, action)
// Simulate a response that triggers a retry: the content field is not valid JSON
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(`{
"id": "chatcmpl-123",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "AI is artificial intelligence. It includes machine learning and natural language processing."
}
}
]
}`))
// Since content is not valid JSON, the plugin retries
// Simulate the response to the retried request
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(`{
"id": "chatcmpl-456",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "{\"definition\": \"AI is artificial intelligence\", \"examples\": [\"machine learning\", \"natural language processing\"]}"
}
}
]
}`))
// Verify the final response body is the extracted JSON content
response := host.GetLocalResponse()
require.NotNil(t, response)
require.Contains(t, string(response.Data), "definition")
require.Contains(t, string(response.Data), "examples")
// Complete the HTTP request
host.CompleteHttp()
})
// Test the external service returning an invalid response body
t.Run("external service returns invalid response body", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
body := `{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "What is AI?"}
]
}`
action := host.CallOnHttpRequestBody([]byte(body))
// Should return ActionPause while waiting for the external service response
require.Equal(t, types.ActionPause, action)
// Simulate the external service returning an invalid response body
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(`invalid json response`))
// Verify the response body contains the error message
response := host.GetLocalResponse()
require.NotNil(t, response)
require.Contains(t, string(response.Data), "invalid json response")
// Complete the HTTP request
host.CompleteHttp()
})
// Test the external service returning a response without a content field
t.Run("external service returns response without content field", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
body := `{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "What is AI?"}
]
}`
action := host.CallOnHttpRequestBody([]byte(body))
// Should return ActionPause while waiting for the external service response
require.Equal(t, types.ActionPause, action)
// Simulate the external service returning a response without a content field
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(`{
"id": "chatcmpl-123",
"choices": [
{
"index": 0,
"message": {
"role": "assistant"
}
}
]
}`))
// Verify the response body contains the error message
response := host.GetLocalResponse()
require.NotNil(t, response)
require.Contains(t, string(response.Data), "response body does not contain the content")
// Complete the HTTP request
host.CompleteHttp()
})
// Test a request with a custom servicePath
t.Run("request with custom service path", func(t *testing.T) {
host, status := test.NewTestHost(serviceUrlConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/custom/chat"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
body := `{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "What is AI?"}
]
}`
action := host.CallOnHttpRequestBody([]byte(body))
// Should return ActionPause while waiting for the external service response
require.Equal(t, types.ActionPause, action)
// Simulate a successful response from the external service
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(`{
"id": "chatcmpl-123",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "{\"answer\": \"AI is artificial intelligence\"}"
}
}
]
}`))
// Verify the response body is the extracted JSON content
response := host.GetLocalResponse()
require.NotNil(t, response)
require.Contains(t, string(response.Data), "answer")
// Complete the HTTP request
host.CompleteHttp()
})
// Test exceeding the maximum retry count
t.Run("max retry count exceeded", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
body := `{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "What is AI?"}
]
}`
action := host.CallOnHttpRequestBody([]byte(body))
// Should return ActionPause while waiting for the external service response
require.Equal(t, types.ActionPause, action)
// Simulate repeated retries, each returning invalid content
for i := 0; i < 4; i++ { // exceeds the max retry count of 3
host.CallOnHttpCall([][2]string{
{":status", "200"},
{"content-type", "application/json"},
}, []byte(`{
"id": "chatcmpl-123",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "AI is artificial intelligence"
}
}
]
}`))
}
// Verify the final response body contains the retry-limit error
response := host.GetLocalResponse()
require.NotNil(t, response)
require.Contains(t, string(response.Data), "retry count exceeds max retry count")
// Complete the HTTP request
host.CompleteHttp()
})
})
}
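The retry flow these cases exercise follows a simple decision rule: a valid body ends the loop, an invalid body retries until the configured maximum, and exhaustion produces the error asserted above. The sketch below is a simplified stand-in for the plugin's wasm HTTP-call machinery; the function name and signature are assumptions, not the actual source.

```go
package main

import "fmt"

// nextAction sketches the decision made after each external-service response:
// valid content stops retrying, invalid content retries while attempts remain,
// and hitting the limit yields the "retry count exceeds max retry count" error.
func nextAction(contentValid bool, retryCount, maxRetry int) (retry bool, errMsg string) {
	switch {
	case contentValid:
		return false, ""
	case retryCount >= maxRetry:
		return false, "retry count exceeds max retry count"
	default:
		return true, ""
	}
}

func main() {
	for i := 0; i <= 3; i++ {
		retry, errMsg := nextAction(false, i, 3)
		fmt.Println(i, retry, errMsg)
	}
}
```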
func TestRejectStruct(t *testing.T) {
// Test RejectStruct's GetBytes method
t.Run("GetBytes", func(t *testing.T) {
reject := RejectStruct{
RejectCode: 1001,
RejectMsg: "Test error message",
}
bytes := reject.GetBytes()
require.NotNil(t, bytes)
// Verify the JSON format
var result RejectStruct
err := json.Unmarshal(bytes, &result)
require.NoError(t, err)
require.Equal(t, uint32(1001), result.RejectCode)
require.Equal(t, "Test error message", result.RejectMsg)
})
// Test RejectStruct's GetShortMsg method
t.Run("GetShortMsg", func(t *testing.T) {
reject := RejectStruct{
RejectCode: 1001,
RejectMsg: "Json Schema is not valid: invalid format",
}
shortMsg := reject.GetShortMsg()
require.Equal(t, "ai-json-resp.Json Schema is not valid", shortMsg)
})
// Test RejectStruct's GetShortMsg method - message without a colon
t.Run("GetShortMsg no colon", func(t *testing.T) {
reject := RejectStruct{
RejectCode: 1001,
RejectMsg: "Simple error message",
}
shortMsg := reject.GetShortMsg()
require.Equal(t, "ai-json-resp.Simple error message", shortMsg)
})
}
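Taken together, the assertions above pin down RejectStruct's observable behavior: GetBytes is a JSON serialization, and GetShortMsg prefixes the plugin name and truncates at the first colon. A minimal sketch consistent with those expectations (field layout and the exact truncation rule are inferred from the expected values, not taken from the actual source):

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// RejectStruct carries the error code and message the plugin returns to the client.
type RejectStruct struct {
	RejectCode uint32
	RejectMsg  string
}

// GetBytes serializes the struct as JSON.
func (r RejectStruct) GetBytes() []byte {
	b, _ := json.Marshal(r)
	return b
}

// GetShortMsg prefixes the plugin name and keeps only the text before the
// first colon, matching "ai-json-resp.Json Schema is not valid" above.
func (r RejectStruct) GetShortMsg() string {
	msg := r.RejectMsg
	if i := strings.Index(msg, ":"); i >= 0 {
		msg = msg[:i]
	}
	return "ai-json-resp." + msg
}

func main() {
	r := RejectStruct{RejectCode: 1001, RejectMsg: "Json Schema is not valid: invalid format"}
	fmt.Println(r.GetShortMsg())
	fmt.Println(string(r.GetBytes()))
}
```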
func TestValidateBody(t *testing.T) {
// Create a test config
config := &PluginConfig{
contentPath: "choices.0.message.content",
jsonSchema: nil, // explicitly nil to disable JSON Schema validation
enableJsonSchemaValidation: false, // disable JSON Schema validation
}
// Test a valid response body
t.Run("valid response body", func(t *testing.T) {
validBody := []byte(`{
"id": "chatcmpl-123",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello, how can I help you?"
}
}
]
}`)
err := config.ValidateBody(validBody)
require.NoError(t, err)
})
// Test an invalid JSON response body
t.Run("invalid JSON response body", func(t *testing.T) {
invalidBody := []byte(`invalid json content`)
err := config.ValidateBody(invalidBody)
require.Error(t, err)
require.Equal(t, uint32(SERVICE_UNAVAILABLE_CODE), config.rejectStruct.RejectCode)
require.Contains(t, config.rejectStruct.RejectMsg, "service unavailable")
})
// Test a response body missing the content field
t.Run("missing content field", func(t *testing.T) {
missingContentBody := []byte(`{
"id": "chatcmpl-123",
"choices": [
{
"index": 0,
"message": {
"role": "assistant"
}
}
]
}`)
err := config.ValidateBody(missingContentBody)
require.Error(t, err)
require.Equal(t, uint32(SERVICE_UNAVAILABLE_CODE), config.rejectStruct.RejectCode)
require.Contains(t, config.rejectStruct.RejectMsg, "response body does not contain the content")
})
// Test an empty response body
t.Run("empty response body", func(t *testing.T) {
emptyBody := []byte{}
err := config.ValidateBody(emptyBody)
require.Error(t, err)
require.Equal(t, uint32(SERVICE_UNAVAILABLE_CODE), config.rejectStruct.RejectCode)
})
}
func TestExtractJson(t *testing.T) {
// Create a test config
config := &PluginConfig{
jsonSchema: nil, // explicitly nil to disable JSON Schema validation
enableJsonSchemaValidation: false, // disable JSON Schema validation
}
// Test extracting valid JSON
t.Run("extract valid JSON", func(t *testing.T) {
content := `Here is the response: {"name": "John", "age": 30} and some other text`
jsonStr, err := config.ExtractJson(content)
require.NoError(t, err)
require.Equal(t, `{"name": "John", "age": 30}`, jsonStr)
})
// Test extracting nested JSON
t.Run("extract nested JSON", func(t *testing.T) {
content := `Response: {"user": {"name": "John", "profile": {"age": 30, "city": "NYC"}}}`
jsonStr, err := config.ExtractJson(content)
require.NoError(t, err)
require.Equal(t, `{"user": {"name": "John", "profile": {"age": 30, "city": "NYC"}}}`, jsonStr)
})
// Test content without any JSON
t.Run("no JSON in content", func(t *testing.T) {
content := `This is just plain text without any JSON content`
_, err := config.ExtractJson(content)
require.Error(t, err)
require.Contains(t, err.Error(), "cannot find json in the response body")
})
// Test content with only an opening brace
t.Run("only opening brace", func(t *testing.T) {
content := `Here is the start: { but no closing brace`
_, err := config.ExtractJson(content)
require.Error(t, err)
require.Contains(t, err.Error(), "cannot find json in the response body")
})
// Test content with only a closing brace
t.Run("only closing brace", func(t *testing.T) {
content := `Here is the end: } but no opening brace`
_, err := config.ExtractJson(content)
require.Error(t, err)
require.Contains(t, err.Error(), "cannot find json in the response body")
})
// Test an invalid JSON format
t.Run("invalid JSON format", func(t *testing.T) {
content := `Here is invalid JSON: {"name": "John", "age": 30,}`
_, err := config.ExtractJson(content)
require.Error(t, err)
// ExtractJson extracts {"name": "John", "age": 30,}, but json.Unmarshal fails
// because the JSON is invalid (trailing comma)
require.Contains(t, err.Error(), "invalid character '}' looking for beginning of object key string")
})
// Test multiple JSON objects: extraction spans from the first brace to the last
t.Run("multiple JSON objects", func(t *testing.T) {
content := `First: {"name": "John"} Second: {"age": 30}`
_, err := config.ExtractJson(content)
require.Error(t, err)
// ExtractJson extracts {"name": "John"} Second: {"age": 30},
// which is not valid JSON because `Second: {"age": 30}` is not valid JSON syntax
require.Contains(t, err.Error(), "invalid character 'S' after top-level value")
})
}
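The extraction behavior exercised above (slice from the first `{` to the last `}`, then a strict JSON parse) can be sketched as a standalone function. `extractJSON` below is a stand-in for the plugin's `ExtractJson` method, reconstructed from the test expectations rather than the actual source:

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"strings"
)

// extractJSON slices the candidate from the first '{' to the last '}' and
// validates it with a strict parse, mirroring the test expectations above.
func extractJSON(content string) (string, error) {
	start := strings.Index(content, "{")
	end := strings.LastIndex(content, "}")
	if start == -1 || end == -1 || start > end {
		return "", errors.New("cannot find json in the response body")
	}
	candidate := content[start : end+1]
	var v interface{}
	// json.Unmarshal rejects trailing commas and trailing garbage, which is
	// why the "invalid JSON format" and "multiple JSON objects" cases fail.
	if err := json.Unmarshal([]byte(candidate), &v); err != nil {
		return "", err
	}
	return candidate, nil
}

func main() {
	out, err := extractJSON(`Here is the response: {"name": "John", "age": 30} and some other text`)
	fmt.Println(out, err)
}
```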

View File

@@ -5,8 +5,8 @@ go 1.24.1
toolchain go1.24.3
require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
github.com/higress-group/wasm-go v1.0.1-0.20250628101008-bea7da01a545
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/prometheus/client_model v0.6.2
github.com/tidwall/gjson v1.18.0
@@ -21,3 +21,5 @@ require (
github.com/tidwall/pretty v1.2.1 // indirect
google.golang.org/protobuf v1.36.6 // indirect
)
require github.com/tidwall/sjson v1.2.5 // indirect

View File

@@ -4,10 +4,10 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.1-0.20250628101008-bea7da01a545 h1:zPXEonKCAeLvXI1IpwGpIeVSvLY5AZ9h9uTJnOuiA3Q=
github.com/higress-group/wasm-go v1.0.1-0.20250628101008-bea7da01a545/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
@@ -18,6 +18,7 @@ github.com/prometheus/common v0.64.0 h1:pdZeA+g617P7oGv1CzdTzyeShxAGrTBsolKNOLQP
github.com/prometheus/common v0.64.0/go.mod h1:0gZns+BLRQ3V6NdaerOhMbwwRbNh9hkGINtQAsP5GS8=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
@@ -27,6 +28,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY=

View File

@@ -5,11 +5,19 @@ go 1.24.1
toolchain go1.24.4
require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
github.com/higress-group/wasm-go v1.0.0
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
require (
github.com/google/uuid v1.6.0 // indirect
github.com/tidwall/match v1.1.1 // indirect

View File

@@ -2,14 +2,16 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw=
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
@@ -22,5 +24,7 @@ github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -18,9 +18,9 @@ func main() {}
func init() {
wrapper.SetCtx(
"ai-prompt-decorator",
wrapper.ParseConfigBy(parseConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
wrapper.ParseConfig(parseConfig),
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
wrapper.ProcessRequestBody(onHttpRequestBody),
)
}
@@ -34,11 +34,11 @@ type AIPromptDecoratorConfig struct {
Append []Message `json:"append"`
}
func parseConfig(jsonConfig gjson.Result, config *AIPromptDecoratorConfig, log log.Log) error {
func parseConfig(jsonConfig gjson.Result, config *AIPromptDecoratorConfig) error {
return json.Unmarshal([]byte(jsonConfig.Raw), config)
}
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIPromptDecoratorConfig, log log.Log) types.Action {
func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIPromptDecoratorConfig) types.Action {
ctx.DisableReroute()
proxywasm.RemoveHttpRequestHeader("content-length")
return types.ActionContinue
@@ -70,7 +70,7 @@ func decorateGeographicPrompt(entry *Message) (*Message, error) {
return entry, nil
}
func onHttpRequestBody(ctx wrapper.HttpContext, config AIPromptDecoratorConfig, body []byte, log log.Log) types.Action {
func onHttpRequestBody(ctx wrapper.HttpContext, config AIPromptDecoratorConfig, body []byte) types.Action {
messageJson := `{"messages":[]}`
for _, entry := range config.Prepend {

View File

@@ -0,0 +1,511 @@
// Copyright (c) 2024 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"encoding/json"
"fmt"
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
)
// Test config: basic decorator config
var basicConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"prepend": []map[string]interface{}{
{
"role": "system",
"content": "You are a helpful assistant from ${geo-country}.",
},
},
"append": []map[string]interface{}{
{
"role": "system",
"content": "Please provide context about ${geo-city}.",
},
},
})
return data
}()
// Test config: prepend-only config
var prependOnlyConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"prepend": []map[string]interface{}{
{
"role": "system",
"content": "You are located in ${geo-province}, ${geo-country}.",
},
},
"append": []map[string]interface{}{}, // 显式定义空的append字段
})
return data
}()
// Test config: empty config
var emptyConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"prepend": []map[string]interface{}{},
"append": []map[string]interface{}{},
})
return data
}()
func TestParseConfig(t *testing.T) {
test.RunGoTest(t, func(t *testing.T) {
// Test basic decorator config parsing
t.Run("basic decorator config", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
decoratorConfig := config.(*AIPromptDecoratorConfig)
require.NotNil(t, decoratorConfig.Prepend)
require.NotNil(t, decoratorConfig.Append)
require.Len(t, decoratorConfig.Prepend, 1)
require.Len(t, decoratorConfig.Append, 1)
require.Equal(t, "system", decoratorConfig.Prepend[0].Role)
require.Equal(t, "You are a helpful assistant from ${geo-country}.", decoratorConfig.Prepend[0].Content)
require.Equal(t, "system", decoratorConfig.Append[0].Role)
require.Equal(t, "Please provide context about ${geo-city}.", decoratorConfig.Append[0].Content)
})
// Test prepend-only config parsing
t.Run("prepend only config", func(t *testing.T) {
host, status := test.NewTestHost(prependOnlyConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
decoratorConfig := config.(*AIPromptDecoratorConfig)
require.NotNil(t, decoratorConfig.Prepend)
require.NotNil(t, decoratorConfig.Append)
require.Len(t, decoratorConfig.Prepend, 1)
require.Len(t, decoratorConfig.Append, 0)
require.Equal(t, "system", decoratorConfig.Prepend[0].Role)
require.Equal(t, "You are located in ${geo-province}, ${geo-country}.", decoratorConfig.Prepend[0].Content)
})
// Test empty config parsing
t.Run("empty config", func(t *testing.T) {
host, status := test.NewTestHost(emptyConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
decoratorConfig := config.(*AIPromptDecoratorConfig)
require.NotNil(t, decoratorConfig.Prepend)
require.NotNil(t, decoratorConfig.Append)
require.Len(t, decoratorConfig.Prepend, 0)
require.Len(t, decoratorConfig.Append, 0)
})
})
}
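The `${geo-*}` placeholders in these configs are resolved from host properties (geo-country, geo-province, geo-city, geo-isp) at request time, as the body tests below demonstrate. A simplified stand-in for that substitution step, assuming a plain map in place of the plugin's proxywasm property lookups:

```go
package main

import (
	"fmt"
	"strings"
)

// replaceGeoVars substitutes ${geo-*} placeholders with values from a map;
// in the plugin these values come from host properties such as geo-country.
func replaceGeoVars(content string, geo map[string]string) string {
	for k, v := range geo {
		content = strings.ReplaceAll(content, "${"+k+"}", v)
	}
	return content
}

func main() {
	geo := map[string]string{"geo-country": "China", "geo-province": "Shanghai"}
	fmt.Println(replaceGeoVars("You are located in ${geo-province}, ${geo-country}.", geo))
}
```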
func TestOnHttpRequestHeaders(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test request header processing
t.Run("request headers processing", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"content-length", "100"},
})
// Should return ActionContinue
require.Equal(t, types.ActionContinue, action)
})
})
}
func TestOnHttpRequestBody(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test basic message decoration
t.Run("basic message decoration", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// Set geo properties for the plugin to use
host.SetProperty([]string{"geo-country"}, []byte("China"))
host.SetProperty([]string{"geo-province"}, []byte("Beijing"))
host.SetProperty([]string{"geo-city"}, []byte("Beijing"))
host.SetProperty([]string{"geo-isp"}, []byte("China Mobile"))
// Set the request body containing messages
body := `{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Hello, how are you?"}
]
}`
action := host.CallOnHttpRequestBody([]byte(body))
// Should return ActionContinue
require.Equal(t, types.ActionContinue, action)
// Verify message decoration succeeded
modifiedBody := host.GetRequestBody()
require.NotEmpty(t, modifiedBody)
// Parse the modified request body
var modifiedRequest map[string]interface{}
err := json.Unmarshal(modifiedBody, &modifiedRequest)
require.NoError(t, err)
// Verify the messages field exists
messages, exists := modifiedRequest["messages"].([]interface{})
require.True(t, exists, "messages field should exist")
require.NotNil(t, messages)
// Verify message count: prepend(1) + original(1) + append(1) = 3
require.Len(t, messages, 3, "should have 3 messages: prepend + original + append")
// Verify the first message is the prepend message (geo variables replaced)
firstMessage := messages[0].(map[string]interface{})
require.Equal(t, "system", firstMessage["role"])
require.Equal(t, "You are a helpful assistant from China.", firstMessage["content"])
// Verify the second message is the original user message
secondMessage := messages[1].(map[string]interface{})
require.Equal(t, "user", secondMessage["role"])
require.Equal(t, "Hello, how are you?", secondMessage["content"])
// Verify the third message is the append message (geo variables replaced)
thirdMessage := messages[2].(map[string]interface{})
require.Equal(t, "system", thirdMessage["role"])
require.Equal(t, "Please provide context about Beijing.", thirdMessage["content"])
host.CompleteHttp()
})
// Test prepend-only decoration
t.Run("prepend only decoration", func(t *testing.T) {
host, status := test.NewTestHost(prependOnlyConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// Set geo attributes for the plugin to use
host.SetProperty([]string{"geo-country"}, []byte("China"))
host.SetProperty([]string{"geo-province"}, []byte("Shanghai"))
host.SetProperty([]string{"geo-city"}, []byte("Shanghai"))
host.SetProperty([]string{"geo-isp"}, []byte("China Telecom"))
// Set a request body containing messages
body := `{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "What's the weather like?"}
]
}`
action := host.CallOnHttpRequestBody([]byte(body))
// Should return ActionContinue
require.Equal(t, types.ActionContinue, action)
// Verify message decoration succeeded
modifiedBody := host.GetRequestBody()
require.NotEmpty(t, modifiedBody)
// Parse the modified request body
var modifiedRequest map[string]interface{}
err := json.Unmarshal(modifiedBody, &modifiedRequest)
require.NoError(t, err)
// Verify the messages field exists
messages, exists := modifiedRequest["messages"].([]interface{})
require.True(t, exists, "messages field should exist")
require.NotNil(t, messages)
// Verify message count: prepend(1) + original(1) = 2
require.Len(t, messages, 2, "should have 2 messages: prepend + original")
// Verify the first message is the prepended one (geo variables replaced)
firstMessage := messages[0].(map[string]interface{})
require.Equal(t, "system", firstMessage["role"])
require.Equal(t, "You are located in Shanghai, China.", firstMessage["content"])
// Verify the second message is the original user message
secondMessage := messages[1].(map[string]interface{})
require.Equal(t, "user", secondMessage["role"])
require.Equal(t, "What's the weather like?", secondMessage["content"])
host.CompleteHttp()
})
// Test the empty-messages case
t.Run("empty messages", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// Set a request body without a messages field
body := `{
"model": "gpt-3.5-turbo"
}`
action := host.CallOnHttpRequestBody([]byte(body))
// Should return ActionContinue
require.Equal(t, types.ActionContinue, action)
host.CompleteHttp()
})
// Test decoration with multiple messages
t.Run("multiple messages decoration", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// Set geo attributes for the plugin to use
host.SetProperty([]string{"geo-country"}, []byte("USA"))
host.SetProperty([]string{"geo-province"}, []byte("California"))
host.SetProperty([]string{"geo-city"}, []byte("San Francisco"))
host.SetProperty([]string{"geo-isp"}, []byte("Comcast"))
// Set a request body containing multiple messages
body := `{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"}
]
}`
action := host.CallOnHttpRequestBody([]byte(body))
// Should return ActionContinue
require.Equal(t, types.ActionContinue, action)
// Verify message decoration succeeded
modifiedBody := host.GetRequestBody()
require.NotEmpty(t, modifiedBody)
// Parse the modified request body
var modifiedRequest map[string]interface{}
err := json.Unmarshal(modifiedBody, &modifiedRequest)
require.NoError(t, err)
// Verify the messages field exists
messages, exists := modifiedRequest["messages"].([]interface{})
require.True(t, exists, "messages field should exist")
require.NotNil(t, messages)
// Verify message count: prepend(1) + original(3) + append(1) = 5
require.Len(t, messages, 5, "should have 5 messages: prepend + original(3) + append")
// Verify the first message is the prepended one (geo variables replaced)
firstMessage := messages[0].(map[string]interface{})
require.Equal(t, "system", firstMessage["role"])
require.Equal(t, "You are a helpful assistant from USA.", firstMessage["content"])
// Verify the original messages keep their order
originalMessages := messages[1:4]
require.Equal(t, "system", originalMessages[0].(map[string]interface{})["role"])
require.Equal(t, "You are a helpful assistant", originalMessages[0].(map[string]interface{})["content"])
require.Equal(t, "user", originalMessages[1].(map[string]interface{})["role"])
require.Equal(t, "Hello", originalMessages[1].(map[string]interface{})["content"])
require.Equal(t, "assistant", originalMessages[2].(map[string]interface{})["role"])
require.Equal(t, "Hi there!", originalMessages[2].(map[string]interface{})["content"])
// Verify the last message is the appended one (geo variables replaced)
lastMessage := messages[4].(map[string]interface{})
require.Equal(t, "system", lastMessage["role"])
require.Equal(t, "Please provide context about San Francisco.", lastMessage["content"])
host.CompleteHttp()
})
})
}
func TestStructs(t *testing.T) {
// Test the Message struct
t.Run("Message struct", func(t *testing.T) {
message := Message{
Role: "system",
Content: "You are a helpful assistant from ${geo-country}.",
}
require.Equal(t, "system", message.Role)
require.Equal(t, "You are a helpful assistant from ${geo-country}.", message.Content)
})
// Test the AIPromptDecoratorConfig struct
t.Run("AIPromptDecoratorConfig struct", func(t *testing.T) {
config := &AIPromptDecoratorConfig{
Prepend: []Message{
{Role: "system", Content: "Prepend message"},
},
Append: []Message{
{Role: "system", Content: "Append message"},
},
}
require.NotNil(t, config.Prepend)
require.NotNil(t, config.Append)
require.Len(t, config.Prepend, 1)
require.Len(t, config.Append, 1)
require.Equal(t, "Prepend message", config.Prepend[0].Content)
require.Equal(t, "Append message", config.Append[0].Content)
})
}
func TestGeographicVariableReplacement(t *testing.T) {
// Test the geographic variable replacement logic
t.Run("geographic variable replacement", func(t *testing.T) {
config := &AIPromptDecoratorConfig{
Prepend: []Message{
{
Role: "system",
Content: "Location: ${geo-country}/${geo-province}/${geo-city}, ISP: ${geo-isp}",
},
},
}
// Verify the geo variables are present in the content
content := config.Prepend[0].Content
require.Contains(t, content, "${geo-country}")
require.Contains(t, content, "${geo-province}")
require.Contains(t, content, "${geo-city}")
require.Contains(t, content, "${geo-isp}")
// Test the variable replacement logic
geoVariables := []string{"geo-country", "geo-province", "geo-city", "geo-isp"}
for _, geo := range geoVariables {
require.Contains(t, content, fmt.Sprintf("${%s}", geo))
}
})
// Test geographic variables in mixed content
t.Run("mixed content geographic variables", func(t *testing.T) {
config := &AIPromptDecoratorConfig{
Append: []Message{
{
Role: "system",
Content: "User from ${geo-country} with ISP ${geo-isp}. Context: ${geo-province}, ${geo-city}",
},
},
}
content := config.Append[0].Content
require.Contains(t, content, "${geo-country}")
require.Contains(t, content, "${geo-isp}")
require.Contains(t, content, "${geo-province}")
require.Contains(t, content, "${geo-city}")
})
}
func TestEdgeCases(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test empty prepend and append messages
t.Run("empty prepend and append", func(t *testing.T) {
host, status := test.NewTestHost(emptyConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// Set the request body
body := `{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Test message"}
]
}`
action := host.CallOnHttpRequestBody([]byte(body))
// Should return ActionContinue
require.Equal(t, types.ActionContinue, action)
host.CompleteHttp()
})
// Test an invalid JSON request body
t.Run("invalid JSON body", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
})
// Set an invalid request body
body := `{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Hello"}
]
// Missing closing brace
`
action := host.CallOnHttpRequestBody([]byte(body))
// Should return ActionContinue
require.Equal(t, types.ActionContinue, action)
host.CompleteHttp()
})
})
}


@@ -5,14 +5,20 @@ go 1.24.1
toolchain go1.24.4
require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
github.com/higress-group/wasm-go v1.0.0
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/resp v0.1.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)


@@ -2,14 +2,17 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw=
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
@@ -19,5 +22,9 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=


@@ -0,0 +1,424 @@
// Copyright (c) 2024 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"encoding/json"
"fmt"
"strings"
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
)
// Test config: basic templates
var basicConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"templates": []map[string]interface{}{
{
"name": "greeting",
"template": "Hello {{name}}, welcome to {{company}}!",
},
{
"name": "summary",
"template": "Here is a summary of {{topic}}: {{content}}",
},
},
})
return data
}()
// Test config: a single template
var singleTemplateConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"templates": []map[string]interface{}{
{
"name": "simple",
"template": "This is a {{adjective}} {{noun}}.",
},
},
})
return data
}()
// Test config: empty templates
var emptyTemplatesConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"templates": []map[string]interface{}{},
})
return data
}()
// Test config: complex templates
var complexTemplateConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"templates": []map[string]interface{}{
{
"name": "email",
"template": "Dear {{recipient}},\n\n{{greeting}}\n\n{{body}}\n\nBest regards,\n{{sender}}",
},
{
"name": "report",
"template": "Report: {{title}}\nDate: {{date}}\nAuthor: {{author}}\n\n{{content}}\n\nConclusion: {{conclusion}}",
},
},
})
return data
}()
func TestParseConfig(t *testing.T) {
test.RunGoTest(t, func(t *testing.T) {
// Test parsing the basic templates config
t.Run("basic templates config", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
promptConfig := config.(*AIPromptTemplateConfig)
require.NotNil(t, promptConfig.templates)
require.Len(t, promptConfig.templates, 2)
// gjson.Get("template").Raw returns the raw JSON value, which includes the quotes
require.Equal(t, "\"Hello {{name}}, welcome to {{company}}!\"", promptConfig.templates["greeting"])
require.Equal(t, "\"Here is a summary of {{topic}}: {{content}}\"", promptConfig.templates["summary"])
})
// Test parsing the single template config
t.Run("single template config", func(t *testing.T) {
host, status := test.NewTestHost(singleTemplateConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
promptConfig := config.(*AIPromptTemplateConfig)
require.NotNil(t, promptConfig.templates)
require.Len(t, promptConfig.templates, 1)
// gjson.Get("template").Raw returns the raw JSON value, which includes the quotes
require.Equal(t, "\"This is a {{adjective}} {{noun}}.\"", promptConfig.templates["simple"])
})
// Test parsing the empty templates config
t.Run("empty templates config", func(t *testing.T) {
host, status := test.NewTestHost(emptyTemplatesConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
promptConfig := config.(*AIPromptTemplateConfig)
require.NotNil(t, promptConfig.templates)
require.Len(t, promptConfig.templates, 0)
})
// Test parsing the complex templates config
t.Run("complex templates config", func(t *testing.T) {
host, status := test.NewTestHost(complexTemplateConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
promptConfig := config.(*AIPromptTemplateConfig)
require.NotNil(t, promptConfig.templates)
require.Len(t, promptConfig.templates, 2)
// gjson.Get("template").Raw returns the raw JSON value, which includes the quotes and escape characters
require.Equal(t, "\"Dear {{recipient}},\\n\\n{{greeting}}\\n\\n{{body}}\\n\\nBest regards,\\n{{sender}}\"", promptConfig.templates["email"])
require.Equal(t, "\"Report: {{title}}\\nDate: {{date}}\\nAuthor: {{author}}\\n\\n{{content}}\\n\\nConclusion: {{conclusion}}\"", promptConfig.templates["report"])
})
})
}
func TestOnHttpRequestHeaders(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test with templating enabled
t.Run("template enabled", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers with templating enabled
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"template-enable", "true"},
{"content-length", "100"},
})
// Should return ActionContinue
require.Equal(t, types.ActionContinue, action)
})
// Test with templating disabled
t.Run("template disabled", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers with templating disabled
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"template-enable", "false"},
{"content-length", "100"},
})
// Should return ActionContinue
require.Equal(t, types.ActionContinue, action)
})
// Test without the template-enable header
t.Run("no template-enable header", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers without template-enable
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"content-length", "100"},
})
// Should return ActionContinue
require.Equal(t, types.ActionContinue, action)
})
})
}
func TestOnHttpRequestBody(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test basic template replacement
t.Run("basic template replacement", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"template-enable", "true"},
})
// Set a request body containing a template and properties
body := `{
"template": "greeting",
"properties": {
"name": "Alice",
"company": "TechCorp"
}
}`
action := host.CallOnHttpRequestBody([]byte(body))
// Should return ActionContinue
require.Equal(t, types.ActionContinue, action)
host.CompleteHttp()
})
// Test complex template replacement
t.Run("complex template replacement", func(t *testing.T) {
host, status := test.NewTestHost(complexTemplateConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"template-enable", "true"},
})
// Set a request body containing a complex template and properties
body := `{
"template": "email",
"properties": {
"recipient": "John Doe",
"greeting": "I hope this email finds you well",
"body": "Please find attached the quarterly report",
"sender": "Jane Smith"
}
}`
action := host.CallOnHttpRequestBody([]byte(body))
// Should return ActionContinue
require.Equal(t, types.ActionContinue, action)
host.CompleteHttp()
})
// Test a body without a template
t.Run("no template in body", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"template-enable", "true"},
})
// Set a request body without a template
body := `{
"messages": [
{"role": "user", "content": "Hello"}
]
}`
action := host.CallOnHttpRequestBody([]byte(body))
// Should return ActionContinue
require.Equal(t, types.ActionContinue, action)
host.CompleteHttp()
})
// Test a body without properties
t.Run("no properties in body", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"template-enable", "true"},
})
// Set a request body with a template but no properties
body := `{
"template": "greeting"
}`
action := host.CallOnHttpRequestBody([]byte(body))
// Should return ActionContinue
require.Equal(t, types.ActionContinue, action)
host.CompleteHttp()
})
// Test partial properties replacement
t.Run("partial properties replacement", func(t *testing.T) {
host, status := test.NewTestHost(basicConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"template-enable", "true"},
})
// Set a request body with only some of the properties
body := `{
"template": "greeting",
"properties": {
"name": "Bob"
}
}`
action := host.CallOnHttpRequestBody([]byte(body))
// Should return ActionContinue
require.Equal(t, types.ActionContinue, action)
host.CompleteHttp()
})
})
}
func TestStructs(t *testing.T) {
// Test the AIPromptTemplateConfig struct
t.Run("AIPromptTemplateConfig struct", func(t *testing.T) {
config := &AIPromptTemplateConfig{
templates: map[string]string{
"test": "This is a {{test}} template",
},
}
require.NotNil(t, config.templates)
require.Len(t, config.templates, 1)
require.Equal(t, "This is a {{test}} template", config.templates["test"])
})
}
func TestTemplateReplacementLogic(t *testing.T) {
// Test the template variable replacement logic
t.Run("template variable replacement", func(t *testing.T) {
config := &AIPromptTemplateConfig{
templates: map[string]string{
"greeting": "Hello {{name}}, welcome to {{company}}!",
},
}
// Simulate the template replacement logic
template := config.templates["greeting"]
require.Equal(t, "Hello {{name}}, welcome to {{company}}!", template)
// Test variable replacement
properties := map[string]string{
"name": "Alice",
"company": "TechCorp",
}
for key, value := range properties {
template = strings.ReplaceAll(template, fmt.Sprintf("{{%s}}", key), value)
}
require.Equal(t, "Hello Alice, welcome to TechCorp!", template)
})
// Test nested variable replacement
t.Run("nested variable replacement", func(t *testing.T) {
config := &AIPromptTemplateConfig{
templates: map[string]string{
"nested": "{{greeting}} {{name}}, {{message}}",
},
}
template := config.templates["nested"]
require.Equal(t, "{{greeting}} {{name}}, {{message}}", template)
// Test nested replacement
properties := map[string]string{
"greeting": "Hello",
"name": "World",
"message": "welcome!",
}
for key, value := range properties {
template = strings.ReplaceAll(template, fmt.Sprintf("{{%s}}", key), value)
}
require.Equal(t, "Hello World, welcome!", template)
})
}


@@ -16,4 +16,3 @@
!*/
/out
/test


@@ -9,10 +9,21 @@ description: AI 代理插件配置参考
`AI 代理`插件实现了基于 OpenAI API 契约的 AI 代理功能。目前支持 OpenAI、Azure OpenAI、月之暗面(Moonshot)和通义千问等 AI
服务提供商。
> **注意:**
**🚀 自动协议兼容 (Auto Protocol Compatibility)**
插件现在支持**自动协议检测**,无需配置即可同时兼容 OpenAI 和 Claude 两种协议格式:
- **OpenAI 协议**: 请求路径 `/v1/chat/completions`,使用标准的 OpenAI Messages API 格式
- **Claude 协议**: 请求路径 `/v1/messages`,使用 Anthropic Claude Messages API 格式
- **智能转换**: 自动检测请求协议,如果目标供应商不原生支持该协议,则自动进行协议转换
- **零配置**: 用户无需设置 `protocol` 字段,插件自动处理
> **协议支持说明:**
> 请求路径后缀匹配 `/v1/chat/completions` 时,对应文生文场景,会用 OpenAI 的文生文协议解析请求 Body,再转换为对应 LLM 厂商的文生文协议
> 请求路径后缀匹配 `/v1/messages` 时,对应 Claude 文生文场景,会自动检测供应商能力:如果支持原生 Claude 协议则直接转发,否则先转换为 OpenAI 协议再转发给供应商
> 请求路径后缀匹配 `/v1/embeddings` 时,对应文本向量场景,会用 OpenAI 的文本向量协议解析请求 Body,再转换为对应 LLM 厂商的文本向量协议
## 运行属性
@@ -158,6 +169,14 @@ DeepSeek 所对应的 `type` 为 `deepseek`。它并无特有的配置字段。
Groq 所对应的 `type` 为 `groq`。它并无特有的配置字段。
#### Grok
Grok 所对应的 `type` 为 `grok`。它并无特有的配置字段。
#### OpenRouter
OpenRouter 所对应的 `type` 为 `openrouter`。它并无特有的配置字段。
#### 文心一言Baidu
文心一言所对应的 `type` 为 `baidu`。它并无特有的配置字段。
@@ -231,10 +250,11 @@ Cloudflare Workers AI 所对应的 `type` 为 `cloudflare`。它特有的配置
Gemini 所对应的 `type` 为 `gemini`。它特有的配置字段如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
| --------------------- | ------------- | -------- | -------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
| `geminiSafetySetting` | map of string | 非必填 | - | Gemini AI 内容过滤和安全级别设定。参考[Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) |
| `apiVersion` | string | 非必填 | `v1beta` | 用于指定 API 的版本, 可选择 `v1``v1beta` 。 版本差异请参考[API versions explained](https://ai.google.dev/gemini-api/docs/api-versions)。 |
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
| ---------------------- | ------------- | -------- | -------- | ------------------------------------------------------------ |
| `geminiSafetySetting` | map of string | 非必填 | - | Gemini AI 内容过滤和安全级别设定。参考[Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) |
| `apiVersion` | string | 非必填 | `v1beta` | 用于指定 API 的版本, 可选择 `v1``v1beta` 。 版本差异请参考[API versions explained](https://ai.google.dev/gemini-api/docs/api-versions)。 |
| `geminiThinkingBudget` | number | 非必填 | - | gemini2.5 系列的参数:0 表示不开启思考,-1 表示动态调整,具体参数取值可参考官网 |
#### DeepL
@@ -862,19 +882,167 @@ provider:
}
```
### 使用 OpenAI 协议代理 Claude 服务
### 使用 OpenAI 协议代理 Grok 服务
**配置信息**
```yaml
provider:
type: claude
type: grok
apiTokens:
- 'YOUR_GROK_API_TOKEN'
```
**请求示例**
```json
{
"messages": [
{
"role": "system",
"content": "You are a helpful assistant that can answer questions and help with tasks."
},
{
"role": "user",
"content": "What is 101*3?"
}
],
"model": "grok-4"
}
```
**响应示例**
```json
{
"id": "a3d1008e-4544-40d4-d075-11527e794e4a",
"object": "chat.completion",
"created": 1752854522,
"model": "grok-4",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "101 multiplied by 3 is 303.",
"refusal": null
},
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 32,
"completion_tokens": 9,
"total_tokens": 135,
"prompt_tokens_details": {
"text_tokens": 32,
"audio_tokens": 0,
"image_tokens": 0,
"cached_tokens": 6
},
"completion_tokens_details": {
"reasoning_tokens": 94,
"audio_tokens": 0,
"accepted_prediction_tokens": 0,
"rejected_prediction_tokens": 0
},
"num_sources_used": 0
},
"system_fingerprint": "fp_3a7881249c"
}
```
### 使用 OpenAI 协议代理 OpenRouter 服务
**配置信息**
```yaml
provider:
type: openrouter
apiTokens:
- 'YOUR_OPENROUTER_API_TOKEN'
modelMapping:
'gpt-4': 'openai/gpt-4-turbo-preview'
'gpt-3.5-turbo': 'openai/gpt-3.5-turbo'
'claude-3': 'anthropic/claude-3-opus'
'*': 'openai/gpt-3.5-turbo'
```
**请求示例**
```json
{
"model": "gpt-4",
"messages": [
{
"role": "user",
"content": "你好,你是谁?"
}
],
"temperature": 0.7
}
```
**响应示例**
```json
{
"id": "gen-1234567890abcdef",
"object": "chat.completion",
"created": 1699123456,
"model": "openai/gpt-4-turbo-preview",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "你好我是一个AI助手通过OpenRouter平台提供服务。我可以帮助回答问题、协助创作、进行对话等。有什么我可以帮助你的吗"
},
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 12,
"completion_tokens": 46,
"total_tokens": 58
}
}
```
### 使用自动协议兼容功能
插件现在支持自动协议检测,可以同时处理 OpenAI 和 Claude 两种协议格式的请求。
**配置信息**
```yaml
provider:
type: claude # 原生支持 Claude 协议的供应商
apiTokens:
- 'YOUR_CLAUDE_API_TOKEN'
version: '2023-06-01'
```
**请求示例**
**OpenAI 协议请求示例**
URL: `http://your-domain/v1/chat/completions`
```json
{
"model": "claude-3-opus-20240229",
"max_tokens": 1024,
"messages": [
{
"role": "user",
"content": "你好,你是谁?"
}
]
}
```
**Claude 协议请求示例**
URL: `http://your-domain/v1/messages`
```json
{
@@ -891,6 +1059,8 @@ provider:
**响应示例**
两种协议格式的请求都会返回相应格式的响应:
```json
{
"id": "msg_01Jt3GzyjuzymnxmZERJguLK",
@@ -915,6 +1085,39 @@ provider:
}
```
### 使用智能协议转换
当目标供应商不原生支持 Claude 协议时,插件会自动进行协议转换:
**配置信息**
```yaml
provider:
type: qwen # 不原生支持 Claude 协议,会自动转换
apiTokens:
- 'YOUR_QWEN_API_TOKEN'
modelMapping:
'claude-3-opus-20240229': 'qwen-max'
'*': 'qwen-turbo'
```
**Claude 协议请求**
URL: `http://your-domain/v1/messages` (自动转换为 OpenAI 协议调用供应商)
```json
{
"model": "claude-3-opus-20240229",
"max_tokens": 1024,
"messages": [
{
"role": "user",
"content": "你好,你是谁?"
}
]
}
```
### 使用 OpenAI 协议代理混元服务
**配置信息**


@@ -8,10 +8,21 @@ description: Reference for configuring the AI Proxy plugin
The `AI Proxy` plugin implements AI proxy functionality based on the OpenAI API contract. It currently supports AI service providers such as OpenAI, Azure OpenAI, Moonshot, and Qwen.
> **Note:**
**🚀 Auto Protocol Compatibility**
The plugin now supports **automatic protocol detection**, allowing seamless compatibility with both OpenAI and Claude protocol formats without configuration:
- **OpenAI Protocol**: Request path `/v1/chat/completions`, using standard OpenAI Messages API format
- **Claude Protocol**: Request path `/v1/messages`, using Anthropic Claude Messages API format
- **Intelligent Conversion**: Automatically detects request protocol and performs conversion if the target provider doesn't natively support it
- **Zero Configuration**: No need to set `protocol` field, the plugin handles everything automatically
> **Protocol Support:**
> When the request path suffix matches `/v1/chat/completions`, it corresponds to text-to-text scenarios. The request body will be parsed using OpenAI's text-to-text protocol and then converted to the corresponding LLM vendor's text-to-text protocol.
> When the request path suffix matches `/v1/messages`, it corresponds to Claude text-to-text scenarios. The plugin automatically detects provider capabilities: if native Claude protocol is supported, requests are forwarded directly; otherwise, they are converted to OpenAI protocol first.
> When the request path suffix matches `/v1/embeddings`, it corresponds to text vector scenarios. The request body will be parsed using OpenAI's text vector protocol and then converted to the corresponding LLM vendor's text vector protocol.
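The suffix matching described above can be sketched as a small dispatch function (a simplified illustration; the `Protocol` type and `detectProtocol` function are hypothetical names, not the plugin's actual internals):

```go
package main

import (
	"fmt"
	"strings"
)

// Protocol identifies the API contract inferred from the request path.
type Protocol int

const (
	ProtocolOpenAIChat Protocol = iota
	ProtocolClaudeMessages
	ProtocolOpenAIEmbeddings
	ProtocolUnknown
)

// detectProtocol mirrors the path-suffix rules described above.
func detectProtocol(path string) Protocol {
	// Strip any query string before matching the suffix.
	if i := strings.Index(path, "?"); i >= 0 {
		path = path[:i]
	}
	switch {
	case strings.HasSuffix(path, "/v1/chat/completions"):
		return ProtocolOpenAIChat
	case strings.HasSuffix(path, "/v1/messages"):
		return ProtocolClaudeMessages
	case strings.HasSuffix(path, "/v1/embeddings"):
		return ProtocolOpenAIEmbeddings
	default:
		return ProtocolUnknown
	}
}

func main() {
	fmt.Println(detectProtocol("/api/v1/messages") == ProtocolClaudeMessages) // true
}
```

Because only the suffix is inspected, a route prefix such as `/api` does not affect detection.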
## Execution Properties
@@ -35,7 +46,7 @@ Plugin execution priority: `100`
| `apiTokens` | array of string | Optional | - | Tokens used for authentication when accessing AI services. If multiple tokens are configured, the plugin randomly selects one for each request. Some service providers only support configuring a single token. |
| `timeout` | number | Optional | - | Timeout for accessing AI services, in milliseconds. The default value is 120000, which equals 2 minutes. Only used when retrieving context data. Won't affect the request forwarded to the LLM upstream. |
| `modelMapping` | map of string | Optional | - | Mapping table for AI models, used to map model names in requests to names supported by the service provider.<br/>1. Supports prefix matching. For example, "gpt-3-\*" matches all model names starting with “gpt-3-”;<br/>2. Supports using "\*" as a key for a general fallback mapping;<br/>3. If the mapped target name is an empty string "", the original model name is preserved. |
| `protocol` | string | Optional | - | API contract provided by the plugin. Currently supports the following values: openai (default, uses OpenAI's interface contract), original (uses the raw interface contract of the target service provider) |
| `protocol` | string | Optional | - | API contract provided by the plugin. Currently supports the following values: openai (default, uses OpenAI's interface contract), original (uses the raw interface contract of the target service provider). **Note: Auto protocol detection is now supported, no need to configure this field to support both OpenAI and Claude protocols** |
| `context` | object | Optional | - | Configuration for AI conversation context information |
| `customSettings` | array of customSetting | Optional | - | Specifies overrides or fills parameters for AI requests |
| `subPath` | string | Optional | - | If subPath is configured, the prefix will be removed from the request path before further processing. |
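The `modelMapping` resolution rules listed above (exact match, prefix match with a trailing `*`, a `*` key as a generic fallback, and an empty-string target preserving the original name) can be sketched as follows. This is a simplified illustration under the documented rules; `resolveModel` is a hypothetical helper, not the plugin's actual implementation:

```go
package main

import (
	"fmt"
	"strings"
)

// resolveModel applies the mapping rules in order:
// 1. exact match; 2. prefix match via a trailing "*";
// 3. "*" as a generic fallback. An empty target keeps the original name.
func resolveModel(mapping map[string]string, model string) string {
	apply := func(target string) string {
		if target == "" {
			return model // empty string preserves the original model name
		}
		return target
	}
	if target, ok := mapping[model]; ok {
		return apply(target)
	}
	for pattern, target := range mapping {
		if strings.HasSuffix(pattern, "*") && pattern != "*" &&
			strings.HasPrefix(model, strings.TrimSuffix(pattern, "*")) {
			return apply(target)
		}
	}
	if target, ok := mapping["*"]; ok {
		return apply(target)
	}
	return model
}

func main() {
	m := map[string]string{"gpt-4": "qwen-max", "gpt-3-*": "qwen-turbo", "*": "qwen-plus"}
	fmt.Println(resolveModel(m, "gpt-3-small")) // qwen-turbo
}
```

Note that if several prefix patterns could match the same model name, Go's map iteration order makes the winner nondeterministic in this sketch; a real implementation would need a deterministic tie-break.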
@@ -129,6 +140,14 @@ For DeepSeek, the corresponding `type` is `deepseek`. It has no unique configura
For Groq, the corresponding `type` is `groq`. It has no unique configuration fields.
#### Grok
For Grok, the corresponding `type` is `grok`. It has no unique configuration fields.
#### OpenRouter
For OpenRouter, the corresponding `type` is `openrouter`. It has no unique configuration fields.
#### ERNIE Bot
For ERNIE Bot, the corresponding `type` is `baidu`. It has no unique configuration fields.
@@ -200,6 +219,8 @@ For Gemini, the corresponding `type` is `gemini`. Its unique configuration field
| Name | Data Type | Filling Requirements | Default Value | Description |
|---------------------|----------|----------------------|---------------|---------------------------------------------------------------------------------------------------------|
| `geminiSafetySetting` | map of string | Optional | - | Gemini AI content filtering and safety level settings. Refer to [Safety settings](https://ai.google.dev/gemini-api/docs/safety-settings). |
| `apiVersion` | string | Optional | `v1beta` | Specifies the API version; either `v1` or `v1beta`. For the differences between versions, see https://ai.google.dev/gemini-api/docs/api-versions |
| `geminiThinkingBudget` | number | Optional | - | Thinking budget parameter for Gemini 2.5 series models: 0 disables thinking mode, -1 enables dynamic adjustment. For specific parameter ranges, refer to the official documentation |
### DeepL
@@ -807,19 +828,167 @@ provider:
}
```
### Using OpenAI Protocol Proxy for Claude Service
### Using OpenAI Protocol Proxy for Grok Service
**Configuration Information**
```yaml
provider:
type: claude
type: grok
apiTokens:
- "YOUR_GROK_API_TOKEN"
```
**Example Request**
```json
{
"messages": [
{
"role": "system",
"content": "You are a helpful assistant that can answer questions and help with tasks."
},
{
"role": "user",
"content": "What is 101*3?"
}
],
"model": "grok-4"
}
```
**Example Response**
```json
{
"id": "a3d1008e-4544-40d4-d075-11527e794e4a",
"object": "chat.completion",
"created": 1752854522,
"model": "grok-4",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "101 multiplied by 3 is 303.",
"refusal": null
},
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 32,
"completion_tokens": 9,
"total_tokens": 135,
"prompt_tokens_details": {
"text_tokens": 32,
"audio_tokens": 0,
"image_tokens": 0,
"cached_tokens": 6
},
"completion_tokens_details": {
"reasoning_tokens": 94,
"audio_tokens": 0,
"accepted_prediction_tokens": 0,
"rejected_prediction_tokens": 0
},
"num_sources_used": 0
},
"system_fingerprint": "fp_3a7881249c"
}
```
### Using OpenAI Protocol Proxy for OpenRouter Service
**Configuration Information**
```yaml
provider:
type: openrouter
apiTokens:
- 'YOUR_OPENROUTER_API_TOKEN'
modelMapping:
'gpt-4': 'openai/gpt-4-turbo-preview'
'gpt-3.5-turbo': 'openai/gpt-3.5-turbo'
'claude-3': 'anthropic/claude-3-opus'
'*': 'openai/gpt-3.5-turbo'
```
**Example Request**
```json
{
"model": "gpt-4",
"messages": [
{
"role": "user",
"content": "Hello, who are you?"
}
],
"temperature": 0.7
}
```
**Example Response**
```json
{
"id": "gen-1234567890abcdef",
"object": "chat.completion",
"created": 1699123456,
"model": "openai/gpt-4-turbo-preview",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello! I am an AI assistant powered by OpenRouter. I can help answer questions, assist with creative tasks, engage in conversations, and more. How can I assist you today?"
},
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 12,
"completion_tokens": 35,
"total_tokens": 47
}
}
```
### Using Auto Protocol Compatibility
The plugin now supports automatic protocol detection and can handle requests in both the OpenAI and Claude protocol formats.
**Configuration Information**
```yaml
provider:
type: claude # Provider with native Claude protocol support
apiTokens:
- "YOUR_CLAUDE_API_TOKEN"
version: "2023-06-01"
```
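Detection keys off the request path: `/v1/chat/completions` is treated as OpenAI-style and `/v1/messages` as Claude-style. A minimal sketch of that dispatch, using a hypothetical `detectProtocol` helper (path constants are taken from the examples in this document; the real plugin recognizes many more endpoints):

```go
package main

import (
	"fmt"
	"strings"
)

// detectProtocol illustrates path-based protocol detection: the suffix of the
// request path decides which protocol the client is speaking.
func detectProtocol(path string) string {
	switch {
	case strings.HasSuffix(path, "/v1/chat/completions"):
		return "openai"
	case strings.HasSuffix(path, "/v1/messages"):
		return "claude"
	default:
		return "unknown"
	}
}

func main() {
	fmt.Println(detectProtocol("/v1/chat/completions")) // openai
	fmt.Println(detectProtocol("/v1/messages"))         // claude
}
```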
**Example Request**
**OpenAI Protocol Request Example**
URL: `http://your-domain/v1/chat/completions`
```json
{
"model": "claude-3-opus-20240229",
"max_tokens": 1024,
"messages": [
{
"role": "user",
"content": "Hello, who are you?"
}
]
}
```
**Claude Protocol Request Example**
URL: `http://your-domain/v1/messages`
```json
{
@@ -836,6 +1005,8 @@ provider:
**Example Response**
Each protocol returns responses in its own format:
```json
{
"id": "msg_01Jt3GzyjuzymnxmZERJguLK",
@@ -860,6 +1031,39 @@ provider:
}
```
### Using Intelligent Protocol Conversion
When the target provider doesn't natively support the Claude protocol, the plugin performs protocol conversion automatically:
**Configuration Information**
```yaml
provider:
type: qwen # Doesn't natively support Claude protocol, auto-conversion applied
apiTokens:
- "YOUR_QWEN_API_TOKEN"
modelMapping:
'claude-3-opus-20240229': 'qwen-max'
'*': 'qwen-turbo'
```
**Claude Protocol Request**
URL: `http://your-domain/v1/messages` (automatically converted to OpenAI protocol for provider)
```json
{
"model": "claude-3-opus-20240229",
"max_tokens": 1024,
"messages": [
{
"role": "user",
"content": "Hello, who are you?"
}
]
}
```
### Using OpenAI Protocol Proxy for Hunyuan Service
**Configuration Information**

@@ -7,12 +7,14 @@ go 1.24.1
toolchain go1.24.4
require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
github.com/higress-group/wasm-go v1.0.1
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
)
require github.com/tetratelabs/wazero v1.7.2 // indirect
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/uuid v1.6.0

@@ -2,16 +2,16 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.0 h1:4Ik5n3FsJ5+r13KLQl2ky+8NuAE8dfWQwoKxXYD2KAw=
github.com/higress-group/wasm-go v1.0.0/go.mod h1:ODBV27sjmhIW8Cqv3R74EUcTnbdkE69bmXBQFuRkY1M=
github.com/higress-group/wasm-go v1.0.1 h1:T1m++qTEANp8+jwE0sxltwtaTKmrHCkLOp1m9N+YeqY=
github.com/higress-group/wasm-go v1.0.1/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=

@@ -6,6 +6,7 @@ package main
import (
"fmt"
"net/url"
"regexp"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/config"
@@ -31,6 +32,11 @@ const (
ctxOriginalAuth = "original_auth"
)
type pair[K, V any] struct {
key K
value V
}
var (
headersCtxKeyMapping = map[string]string{
util.HeaderAuthority: ctxOriginalHost,
@@ -42,6 +48,44 @@ var (
util.HeaderPath: util.HeaderOriginalPath,
util.HeaderAuthorization: util.HeaderOriginalAuth,
}
pathSuffixToApiName = []pair[string, provider.ApiName]{
// OpenAI style
{provider.PathOpenAIChatCompletions, provider.ApiNameChatCompletion},
{provider.PathOpenAICompletions, provider.ApiNameCompletion},
{provider.PathOpenAIEmbeddings, provider.ApiNameEmbeddings},
{provider.PathOpenAIAudioSpeech, provider.ApiNameAudioSpeech},
{provider.PathOpenAIImageGeneration, provider.ApiNameImageGeneration},
{provider.PathOpenAIImageVariation, provider.ApiNameImageVariation},
{provider.PathOpenAIImageEdit, provider.ApiNameImageEdit},
{provider.PathOpenAIBatches, provider.ApiNameBatches},
{provider.PathOpenAIFiles, provider.ApiNameFiles},
{provider.PathOpenAIModels, provider.ApiNameModels},
{provider.PathOpenAIFineTuningJobs, provider.ApiNameFineTuningJobs},
{provider.PathOpenAIResponses, provider.ApiNameResponses},
// Anthropic style
{provider.PathAnthropicMessages, provider.ApiNameAnthropicMessages},
{provider.PathAnthropicComplete, provider.ApiNameAnthropicComplete},
// Cohere style
{provider.PathCohereV1Rerank, provider.ApiNameCohereV1Rerank},
}
pathPatternToApiName = []pair[*regexp.Regexp, provider.ApiName]{
// OpenAI style
{util.RegRetrieveBatchPath, provider.ApiNameRetrieveBatch},
{util.RegCancelBatchPath, provider.ApiNameCancelBatch},
{util.RegRetrieveFilePath, provider.ApiNameRetrieveFile},
{util.RegRetrieveFileContentPath, provider.ApiNameRetrieveFileContent},
{util.RegRetrieveFineTuningJobPath, provider.ApiNameRetrieveFineTuningJob},
{util.RegRetrieveFineTuningJobEventsPath, provider.ApiNameFineTuningJobEvents},
{util.RegRetrieveFineTuningJobCheckpointsPath, provider.ApiNameFineTuningJobCheckpoints},
{util.RegCancelFineTuningJobPath, provider.ApiNameCancelFineTuningJob},
{util.RegResumeFineTuningJobPath, provider.ApiNameResumeFineTuningJob},
{util.RegPauseFineTuningJobPath, provider.ApiNamePauseFineTuningJob},
{util.RegFineTuningCheckpointPermissionPath, provider.ApiNameFineTuningCheckpointPermissions},
{util.RegDeleteFineTuningCheckpointPermissionPath, provider.ApiNameDeleteFineTuningCheckpointPermission},
// Gemini style
{util.RegGeminiGenerateContent, provider.ApiNameGeminiGenerateContent},
{util.RegGeminiStreamGenerateContent, provider.ApiNameGeminiStreamGenerateContent},
}
)
func main() {}
@@ -97,6 +141,9 @@ func initContext(ctx wrapper.HttpContext) {
value, _ := proxywasm.GetHttpRequestHeader(header)
ctx.SetContext(ctxKey, value)
}
for _, originHeader := range headerToOriginalHeaderMapping {
proxywasm.RemoveHttpRequestHeader(originHeader)
}
}
func saveContextsToHeaders(ctx wrapper.HttpContext) {
@@ -127,6 +174,9 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
log.Debugf("[onHttpRequestHeader] provider=%s", activeProvider.GetProviderType())
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
ctx.DisableReroute()
initContext(ctx)
rawPath := ctx.Path()
@@ -144,6 +194,23 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
}
}
// Auto-detect protocol based on request path and handle conversion if needed
// If request is Claude format (/v1/messages) but provider doesn't support it natively,
// convert to OpenAI format (/v1/chat/completions)
if apiName == provider.ApiNameAnthropicMessages && !providerConfig.IsSupportedAPI(provider.ApiNameAnthropicMessages) {
// Provider doesn't support Claude protocol natively, convert to OpenAI format
newPath := strings.Replace(path.Path, provider.PathAnthropicMessages, provider.PathOpenAIChatCompletions, 1)
_ = proxywasm.ReplaceHttpRequestHeader(":path", newPath)
// Update apiName to match the new path
apiName = provider.ApiNameChatCompletion
// Mark that we need to convert response back to Claude format
ctx.SetContext("needClaudeResponseConversion", true)
log.Debugf("[Auto Protocol] Claude request detected, provider doesn't support natively, converted path from %s to %s, apiName: %s", path.Path, newPath, apiName)
} else if apiName == provider.ApiNameAnthropicMessages {
// Provider supports Claude protocol natively, no conversion needed
log.Debugf("[Auto Protocol] Claude request detected, provider supports natively, keeping original path: %s, apiName: %s", path.Path, apiName)
}
if contentType, _ := proxywasm.GetHttpRequestHeader(util.HeaderContentType); contentType != "" && !strings.Contains(contentType, util.MimeTypeApplicationJson) {
ctx.DontReadRequestBody()
log.Debugf("[onHttpRequestHeader] unsupported content type: %s, will not process the request body", contentType)
@@ -156,8 +223,6 @@ func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConf
}
ctx.SetContext(provider.CtxKeyApiName, apiName)
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
ctx.DisableReroute()
// Always remove the Accept-Encoding header to prevent the LLM from sending compressed responses,
// allowing plugins to inspect or modify the response correctly
@@ -275,17 +340,20 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo
}
util.ReplaceResponseHeaders(headers)
checkStream(ctx)
_, needHandleBody := activeProvider.(provider.TransformResponseBodyHandler)
var needHandleStreamingBody bool
_, needHandleStreamingBody = activeProvider.(provider.StreamingResponseBodyHandler)
if !needHandleStreamingBody {
_, needHandleStreamingBody = activeProvider.(provider.StreamingEventHandler)
}
if !needHandleBody && !needHandleStreamingBody {
// Check if we need to read body for Claude response conversion
needClaudeConversion, _ := ctx.GetContext("needClaudeResponseConversion").(bool)
if !needHandleBody && !needHandleStreamingBody && !needClaudeConversion {
ctx.DontReadResponseBody()
} else if !needHandleStreamingBody {
ctx.BufferResponseBody()
} else {
checkStream(ctx)
}
return types.ActionContinue
@@ -306,7 +374,12 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk)
if err == nil && modifiedChunk != nil {
return modifiedChunk
// Convert to Claude format if needed
claudeChunk, convertErr := convertStreamingResponseToClaude(ctx, modifiedChunk)
if convertErr != nil {
return modifiedChunk
}
return claudeChunk
}
return chunk
}
@@ -315,8 +388,8 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
events := provider.ExtractStreamingEvents(ctx, chunk)
log.Debugf("[onStreamingResponseBody] %d events received", len(events))
if len(events) == 0 {
// No events are extracted, return the original chunk
return chunk
// No events are extracted, return empty bytes slice
return []byte("")
}
var responseBuilder strings.Builder
for _, event := range events {
@@ -332,7 +405,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
log.Errorf("[onStreamingResponseBody] failed to process streaming event: %v\n%s", err, chunk)
return chunk
}
if outputEvents == nil || len(outputEvents) == 0 {
if len(outputEvents) == 0 {
responseBuilder.WriteString(event.ToHttpString())
} else {
for _, outputEvent := range outputEvents {
@@ -340,9 +413,40 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin
}
}
}
return []byte(responseBuilder.String())
result := []byte(responseBuilder.String())
// Convert to Claude format if needed
claudeChunk, convertErr := convertStreamingResponseToClaude(ctx, result)
if convertErr != nil {
return result
}
return claudeChunk
}
return chunk
// If provider doesn't implement any streaming handlers but we need Claude conversion
// First extract complete events from the chunk
events := provider.ExtractStreamingEvents(ctx, chunk)
log.Debugf("[onStreamingResponseBody] %d events received (no handler)", len(events))
if len(events) == 0 {
// No events are extracted, return empty bytes slice
return []byte("")
}
// Build response from extracted events (without handler processing)
var responseBuilder strings.Builder
for _, event := range events {
responseBuilder.WriteString(event.ToHttpString())
}
result := []byte(responseBuilder.String())
// Convert to Claude format if needed
claudeChunk, convertErr := convertStreamingResponseToClaude(ctx, result)
if convertErr != nil {
return result
}
return claudeChunk
}
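// Returning an empty byte slice when no complete event has been extracted
// assumes incomplete data stays buffered until a full event boundary arrives
// in a later chunk. A minimal standalone illustration of that buffering,
// assuming events are delimited by a blank line as in SSE (the plugin's
// ExtractStreamingEvents keeps equivalent state in the HTTP context):

```go
package main

import (
	"fmt"
	"strings"
)

// eventBuffer accumulates streamed bytes and yields only complete events
// (terminated by "\n\n"), keeping any trailing partial event for the next
// chunk. Illustrative only.
type eventBuffer struct {
	pending string
}

func (b *eventBuffer) feed(chunk string) []string {
	b.pending += chunk
	var events []string
	for {
		i := strings.Index(b.pending, "\n\n")
		if i < 0 {
			// Incomplete event stays buffered; the caller emits nothing,
			// mirroring the empty-slice return above.
			return events
		}
		events = append(events, b.pending[:i])
		b.pending = b.pending[i+2:]
	}
}

func main() {
	var b eventBuffer
	fmt.Println(len(b.feed("data: a")))         // 0: no complete event yet
	fmt.Println(len(b.feed("\n\ndata: b\n\n"))) // 2: both events complete
}
```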
func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, body []byte) types.Action {
@@ -355,20 +459,82 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
log.Debugf("[onHttpResponseBody] provider=%s", activeProvider.GetProviderType())
var finalBody []byte
if handler, ok := activeProvider.(provider.TransformResponseBodyHandler); ok {
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
body, err := handler.TransformResponseBody(ctx, apiName, body)
transformedBody, err := handler.TransformResponseBody(ctx, apiName, body)
if err != nil {
_ = util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
return types.ActionContinue
}
if err = provider.ReplaceResponseBody(body); err != nil {
_ = util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err))
}
finalBody = transformedBody
} else {
finalBody = body
}
// Convert to Claude format if needed (applies to both branches)
convertedBody, err := convertResponseBodyToClaude(ctx, finalBody)
if err != nil {
_ = util.ErrorHandler("ai-proxy.convert_resp_to_claude_failed", err)
return types.ActionContinue
}
if err = provider.ReplaceResponseBody(convertedBody); err != nil {
_ = util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err))
}
return types.ActionContinue
}
// Helper function to check if Claude response conversion is needed
func needsClaudeResponseConversion(ctx wrapper.HttpContext) bool {
needClaudeConversion, _ := ctx.GetContext("needClaudeResponseConversion").(bool)
return needClaudeConversion
}
// Helper function to convert OpenAI streaming response to Claude format
func convertStreamingResponseToClaude(ctx wrapper.HttpContext, data []byte) ([]byte, error) {
if !needsClaudeResponseConversion(ctx) {
return data, nil
}
// Get or create converter instance from context to maintain state
const claudeConverterKey = "claudeConverter"
var converter *provider.ClaudeToOpenAIConverter
if converterData := ctx.GetContext(claudeConverterKey); converterData != nil {
if c, ok := converterData.(*provider.ClaudeToOpenAIConverter); ok {
converter = c
}
}
if converter == nil {
converter = &provider.ClaudeToOpenAIConverter{}
ctx.SetContext(claudeConverterKey, converter)
}
claudeChunk, err := converter.ConvertOpenAIStreamResponseToClaude(ctx, data)
if err != nil {
log.Errorf("failed to convert streaming response to claude format: %v", err)
return data, err
}
return claudeChunk, nil
}
// Helper function to convert OpenAI response body to Claude format
func convertResponseBodyToClaude(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
if !needsClaudeResponseConversion(ctx) {
return body, nil
}
converter := &provider.ClaudeToOpenAIConverter{}
convertedBody, err := converter.ConvertOpenAIResponseToClaude(ctx, body)
if err != nil {
return body, fmt.Errorf("failed to convert response to claude format: %v", err)
}
return convertedBody, nil
}
func normalizeOpenAiRequestBody(body []byte) []byte {
var err error
// Default setting include_usage.
@@ -393,99 +559,19 @@ func checkStream(ctx wrapper.HttpContext) {
}
func getApiName(path string) provider.ApiName {
// openai style
if strings.HasSuffix(path, provider.PathOpenAIChatCompletions) {
return provider.ApiNameChatCompletion
}
if strings.HasSuffix(path, provider.PathOpenAICompletions) {
return provider.ApiNameCompletion
}
if strings.HasSuffix(path, provider.PathOpenAIEmbeddings) {
return provider.ApiNameEmbeddings
}
if strings.HasSuffix(path, provider.PathOpenAIAudioSpeech) {
return provider.ApiNameAudioSpeech
}
if strings.HasSuffix(path, provider.PathOpenAIImageGeneration) {
return provider.ApiNameImageGeneration
}
if strings.HasSuffix(path, provider.PathOpenAIImageVariation) {
return provider.ApiNameImageVariation
}
if strings.HasSuffix(path, provider.PathOpenAIImageEdit) {
return provider.ApiNameImageEdit
}
if strings.HasSuffix(path, provider.PathOpenAIBatches) {
return provider.ApiNameBatches
}
if util.RegRetrieveBatchPath.MatchString(path) {
return provider.ApiNameRetrieveBatch
}
if util.RegCancelBatchPath.MatchString(path) {
return provider.ApiNameCancelBatch
}
if strings.HasSuffix(path, provider.PathOpenAIFiles) {
return provider.ApiNameFiles
}
if util.RegRetrieveFilePath.MatchString(path) {
return provider.ApiNameRetrieveFile
}
if util.RegRetrieveFileContentPath.MatchString(path) {
return provider.ApiNameRetrieveFileContent
}
if strings.HasSuffix(path, provider.PathOpenAIModels) {
return provider.ApiNameModels
}
if strings.HasSuffix(path, provider.PathOpenAIFineTuningJobs) {
return provider.ApiNameFineTuningJobs
}
if util.RegRetrieveFineTuningJobPath.MatchString(path) {
return provider.ApiNameRetrieveFineTuningJob
}
if util.RegRetrieveFineTuningJobEventsPath.MatchString(path) {
return provider.ApiNameFineTuningJobEvents
}
if util.RegRetrieveFineTuningJobCheckpointsPath.MatchString(path) {
return provider.ApiNameFineTuningJobCheckpoints
}
if util.RegCancelFineTuningJobPath.MatchString(path) {
return provider.ApiNameCancelFineTuningJob
}
if util.RegResumeFineTuningJobPath.MatchString(path) {
return provider.ApiNameResumeFineTuningJob
}
if util.RegPauseFineTuningJobPath.MatchString(path) {
return provider.ApiNamePauseFineTuningJob
}
if util.RegFineTuningCheckpointPermissionPath.MatchString(path) {
return provider.ApiNameFineTuningCheckpointPermissions
}
if util.RegDeleteFineTuningCheckpointPermissionPath.MatchString(path) {
return provider.ApiNameDeleteFineTuningCheckpointPermission
}
if strings.HasSuffix(path, provider.PathOpenAIResponses) {
return provider.ApiNameResponses
// Check path suffix matches first
for _, p := range pathSuffixToApiName {
if strings.HasSuffix(path, p.key) {
return p.value
}
}
// Anthropic
if strings.HasSuffix(path, provider.PathAnthropicMessages) {
return provider.ApiNameAnthropicMessages
}
if strings.HasSuffix(path, provider.PathAnthropicComplete) {
return provider.ApiNameAnthropicComplete
// Check path pattern matches
for _, p := range pathPatternToApiName {
if p.key.MatchString(path) {
return p.value
}
}
// Gemini
if util.RegGeminiGenerateContent.MatchString(path) {
return provider.ApiNameGeminiGenerateContent
}
if util.RegGeminiStreamGenerateContent.MatchString(path) {
return provider.ApiNameGeminiStreamGenerateContent
}
// cohere style
if strings.HasSuffix(path, provider.PathCohereV1Rerank) {
return provider.ApiNameCohereV1Rerank
}
return ""
}

@@ -0,0 +1,105 @@
package main
import (
"testing"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/test"
)
func Test_getApiName(t *testing.T) {
tests := []struct {
name string
path string
want provider.ApiName
}{
// OpenAI style
{"openai chat completions", "/v1/chat/completions", provider.ApiNameChatCompletion},
{"openai completions", "/v1/completions", provider.ApiNameCompletion},
{"openai embeddings", "/v1/embeddings", provider.ApiNameEmbeddings},
{"openai audio speech", "/v1/audio/speech", provider.ApiNameAudioSpeech},
{"openai image generation", "/v1/images/generations", provider.ApiNameImageGeneration},
{"openai image variation", "/v1/images/variations", provider.ApiNameImageVariation},
{"openai image edit", "/v1/images/edits", provider.ApiNameImageEdit},
{"openai batches", "/v1/batches", provider.ApiNameBatches},
{"openai retrieve batch", "/v1/batches/batchid", provider.ApiNameRetrieveBatch},
{"openai cancel batch", "/v1/batches/batchid/cancel", provider.ApiNameCancelBatch},
{"openai files", "/v1/files", provider.ApiNameFiles},
{"openai retrieve file", "/v1/files/fileid", provider.ApiNameRetrieveFile},
{"openai retrieve file content", "/v1/files/fileid/content", provider.ApiNameRetrieveFileContent},
{"openai models", "/v1/models", provider.ApiNameModels},
{"openai fine tuning jobs", "/v1/fine_tuning/jobs", provider.ApiNameFineTuningJobs},
{"openai retrieve fine tuning job", "/v1/fine_tuning/jobs/jobid", provider.ApiNameRetrieveFineTuningJob},
{"openai fine tuning job events", "/v1/fine_tuning/jobs/jobid/events", provider.ApiNameFineTuningJobEvents},
{"openai fine tuning job checkpoints", "/v1/fine_tuning/jobs/jobid/checkpoints", provider.ApiNameFineTuningJobCheckpoints},
{"openai cancel fine tuning job", "/v1/fine_tuning/jobs/jobid/cancel", provider.ApiNameCancelFineTuningJob},
{"openai resume fine tuning job", "/v1/fine_tuning/jobs/jobid/resume", provider.ApiNameResumeFineTuningJob},
{"openai pause fine tuning job", "/v1/fine_tuning/jobs/jobid/pause", provider.ApiNamePauseFineTuningJob},
{"openai fine tuning checkpoint permissions", "/v1/fine_tuning/checkpoints/checkpointid/permissions", provider.ApiNameFineTuningCheckpointPermissions},
{"openai delete fine tuning checkpoint permission", "/v1/fine_tuning/checkpoints/checkpointid/permissions/permissionid", provider.ApiNameDeleteFineTuningCheckpointPermission},
{"openai responses", "/v1/responses", provider.ApiNameResponses},
// Anthropic
{"anthropic messages", "/v1/messages", provider.ApiNameAnthropicMessages},
{"anthropic complete", "/v1/complete", provider.ApiNameAnthropicComplete},
// Gemini
{"gemini generate content", "/v1beta/models/gemini-1.0-pro:generateContent", provider.ApiNameGeminiGenerateContent},
{"gemini stream generate content", "/v1beta/models/gemini-1.0-pro:streamGenerateContent", provider.ApiNameGeminiStreamGenerateContent},
// Cohere
{"cohere rerank", "/v1/rerank", provider.ApiNameCohereV1Rerank},
// Unknown
{"unknown", "/v1/unknown", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := getApiName(tt.path)
if got != tt.want {
t.Errorf("getApiName(%q) = %v, want %v", tt.path, got, tt.want)
}
})
}
}
func TestAi360(t *testing.T) {
test.RunAi360ParseConfigTests(t)
test.RunAi360OnHttpRequestHeadersTests(t)
test.RunAi360OnHttpRequestBodyTests(t)
test.RunAi360OnHttpResponseHeadersTests(t)
test.RunAi360OnHttpResponseBodyTests(t)
test.RunAi360OnStreamingResponseBodyTests(t)
}
func TestOpenAI(t *testing.T) {
test.RunOpenAIParseConfigTests(t)
test.RunOpenAIOnHttpRequestHeadersTests(t)
test.RunOpenAIOnHttpRequestBodyTests(t)
test.RunOpenAIOnHttpResponseHeadersTests(t)
test.RunOpenAIOnHttpResponseBodyTests(t)
test.RunOpenAIOnStreamingResponseBodyTests(t)
}
func TestQwen(t *testing.T) {
test.RunQwenParseConfigTests(t)
test.RunQwenOnHttpRequestHeadersTests(t)
test.RunQwenOnHttpRequestBodyTests(t)
test.RunQwenOnHttpResponseHeadersTests(t)
test.RunQwenOnHttpResponseBodyTests(t)
test.RunQwenOnStreamingResponseBodyTests(t)
}
func TestGemini(t *testing.T) {
test.RunGeminiParseConfigTests(t)
test.RunGeminiOnHttpRequestHeadersTests(t)
test.RunGeminiOnHttpRequestBodyTests(t)
test.RunGeminiOnHttpResponseHeadersTests(t)
test.RunGeminiOnHttpResponseBodyTests(t)
test.RunGeminiOnStreamingResponseBodyTests(t)
test.RunGeminiGetImageURLTests(t)
}
func TestAzure(t *testing.T) {
test.RunAzureParseConfigTests(t)
test.RunAzureOnHttpRequestHeadersTests(t)
test.RunAzureOnHttpRequestBodyTests(t)
test.RunAzureOnHttpResponseHeadersTests(t)
test.RunAzureOnHttpResponseBodyTests(t)
}

@@ -14,6 +14,8 @@ import (
"github.com/higress-group/wasm-go/pkg/wrapper"
)
type azureServiceUrlType int
const (
pathAzurePrefix = "/openai"
pathAzureModelPlaceholder = "{model}"
@@ -21,6 +23,12 @@ const (
queryAzureApiVersion = "api-version"
)
const (
azureServiceUrlTypeFull azureServiceUrlType = iota
azureServiceUrlTypeWithDeployment
azureServiceUrlTypeDomainOnly
)
var (
azureModelIrrelevantApis = map[ApiName]bool{
ApiNameModels: true,
@@ -31,7 +39,7 @@ var (
ApiNameRetrieveFile: true,
ApiNameRetrieveFileContent: true,
}
regexAzureModelWithPath = regexp.MustCompile("/openai/deployments/(.+?)(/.*|$)")
regexAzureModelWithPath = regexp.MustCompile("/openai/deployments/(.+?)(?:/(.*)|$)")
)
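// The switch to a non-capturing outer group changes submatch semantics: the
// second submatch now holds the path after the deployment name without its
// leading slash, and stays empty when the service URL ends at the deployment
// name. A quick standalone check of both behaviors:

```go
package main

import (
	"fmt"
	"regexp"
)

// The updated pattern from the diff above.
var re = regexp.MustCompile("/openai/deployments/(.+?)(?:/(.*)|$)")

func main() {
	// Full path after the deployment: submatch 2 is the trailing path
	// without its leading slash.
	full := re.FindStringSubmatch("/openai/deployments/gpt4o/chat/completions")
	fmt.Println(full[1], full[2]) // gpt4o chat/completions

	// URL ends at the deployment name: submatch 2 is empty.
	deploymentOnly := re.FindStringSubmatch("/openai/deployments/gpt4o")
	fmt.Println(deploymentOnly[1], deploymentOnly[2] == "") // gpt4o true
}
```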
// azureProvider is the provider for Azure OpenAI service.
@@ -82,32 +90,44 @@ func (m *azureProviderInitializer) CreateProvider(config ProviderConfig) (Provid
modelSubMatch := regexAzureModelWithPath.FindStringSubmatch(serviceUrl.Path)
defaultModel := "placeholder"
var serviceUrlType azureServiceUrlType
if modelSubMatch != nil {
defaultModel = modelSubMatch[1]
if modelSubMatch[2] != "" {
serviceUrlType = azureServiceUrlTypeFull
} else {
serviceUrlType = azureServiceUrlTypeWithDeployment
}
log.Debugf("azureProvider: found default model from serviceUrl: %s", defaultModel)
} else {
serviceUrlType = azureServiceUrlTypeDomainOnly
log.Debugf("azureProvider: no default model found in serviceUrl")
}
log.Debugf("azureProvider: serviceUrlType=%d", serviceUrlType)
config.setDefaultCapabilities(m.DefaultCapabilities())
apiVersion := serviceUrl.Query().Get(queryAzureApiVersion)
log.Debugf("azureProvider: using %s: %s", queryAzureApiVersion, apiVersion)
return &azureProvider{
config: config,
serviceUrl: serviceUrl,
apiVersion: apiVersion,
defaultModel: defaultModel,
contextCache: createContextCache(&config),
config: config,
serviceUrl: serviceUrl,
serviceUrlType: serviceUrlType,
serviceUrlFullPath: serviceUrl.Path + "?" + serviceUrl.RawQuery,
apiVersion: apiVersion,
defaultModel: defaultModel,
contextCache: createContextCache(&config),
}, nil
}
type azureProvider struct {
config ProviderConfig
contextCache *contextCache
serviceUrl *url.URL
apiVersion string
defaultModel string
contextCache *contextCache
serviceUrl *url.URL
serviceUrlFullPath string
serviceUrlType azureServiceUrlType
apiVersion string
defaultModel string
}
func (m *azureProvider) GetProviderType() string {
@@ -152,21 +172,31 @@ func (m *azureProvider) transformRequestPath(ctx wrapper.HttpContext, apiName Ap
return originalPath
}
if m.serviceUrlType == azureServiceUrlTypeFull {
log.Debugf("azureProvider: use configured path %s", m.serviceUrlFullPath)
return m.serviceUrlFullPath
}
log.Debugf("azureProvider: original request path: %s", originalPath)
path := util.MapRequestPathByCapability(string(apiName), originalPath, m.config.capabilities)
log.Debugf("azureProvider: path: %s", path)
if strings.Contains(path, pathAzureModelPlaceholder) {
log.Debugf("azureProvider: path contains placeholder: %s", path)
model := ctx.GetStringContext(ctxKeyFinalRequestModel, "")
log.Debugf("azureProvider: model from context: %s", model)
if model == "" {
var model string
if m.serviceUrlType == azureServiceUrlTypeWithDeployment {
model = m.defaultModel
log.Debugf("azureProvider: use default model: %s", model)
} else {
model = ctx.GetStringContext(ctxKeyFinalRequestModel, "")
log.Debugf("azureProvider: model from context: %s", model)
if model == "" {
model = m.defaultModel
log.Debugf("azureProvider: use default model: %s", model)
}
}
path = strings.ReplaceAll(path, pathAzureModelPlaceholder, model)
log.Debugf("azureProvider: model replaced path: %s", path)
}
path = fmt.Sprintf("%s?%s=%s", path, queryAzureApiVersion, m.apiVersion)
path = path + "?" + m.serviceUrl.RawQuery
log.Debugf("azureProvider: final path: %s", path)
return path

@@ -19,12 +19,9 @@ import (
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
@@ -99,8 +96,31 @@ func (b *bedrockProvider) convertEventFromBedrockToOpenAI(ctx wrapper.HttpContex
if bedrockEvent.Role != nil {
chatChoice.Delta.Role = *bedrockEvent.Role
}
if bedrockEvent.Start != nil {
chatChoice.Delta.Content = nil
chatChoice.Delta.ToolCalls = []toolCall{
{
Id: bedrockEvent.Start.ToolUse.ToolUseID,
Type: "function",
Function: functionCall{
Name: bedrockEvent.Start.ToolUse.Name,
Arguments: "",
},
},
}
}
if bedrockEvent.Delta != nil {
chatChoice.Delta = &chatMessage{Content: bedrockEvent.Delta.Text}
if bedrockEvent.Delta.ToolUse != nil {
chatChoice.Delta.ToolCalls = []toolCall{
{
Type: "function",
Function: functionCall{
Arguments: bedrockEvent.Delta.ToolUse.Input,
},
},
}
}
}
if bedrockEvent.StopReason != nil {
chatChoice.FinishReason = util.Ptr(stopReasonBedrock2OpenAI(*bedrockEvent.StopReason))
@@ -591,29 +611,7 @@ func (b *bedrockProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
return b.config.handleRequestBody(b, b.contextCache, ctx, apiName, body)
}
func (b *bedrockProvider) insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error) {
request := &bedrockTextGenRequest{}
if err := json.Unmarshal(body, request); err != nil {
return nil, fmt.Errorf("unable to unmarshal request: %v", err)
}
if len(request.System) > 0 {
request.System = append(request.System, systemContentBlock{Text: content})
} else {
request.System = []systemContentBlock{{Text: content}}
}
requestBytes, err := json.Marshal(request)
return requestBytes, err
}
func (b *bedrockProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
if gjson.GetBytes(body, "model").Exists() {
rawModel := gjson.GetBytes(body, "model").String()
encodedModel := url.QueryEscape(rawModel)
body, _ = sjson.SetBytes(body, "model", encodedModel)
}
switch apiName {
case ApiNameChatCompletion:
return b.onChatCompletionRequestBody(ctx, body, headers)
@@ -651,7 +649,7 @@ func (b *bedrockProvider) onImageGenerationRequestBody(ctx wrapper.HttpContext,
return nil, err
}
headers.Set("Accept", "*/*")
util.OverwriteRequestPathHeader(headers, fmt.Sprintf(bedrockInvokeModelPath, request.Model))
b.overwriteRequestPathHeader(headers, bedrockInvokeModelPath, request.Model)
return b.buildBedrockImageGenerationRequest(request, headers)
}
@@ -675,7 +673,6 @@ func (b *bedrockProvider) buildBedrockImageGenerationRequest(origRequest *imageG
Quality: origRequest.Quality,
},
}
util.OverwriteRequestPathHeader(headers, fmt.Sprintf(bedrockInvokeModelPath, origRequest.Model))
requestBytes, err := json.Marshal(request)
b.setAuthHeaders(requestBytes, headers)
return requestBytes, err
@@ -714,9 +711,9 @@ func (b *bedrockProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, b
streaming := request.Stream
headers.Set("Accept", "*/*")
if streaming {
util.OverwriteRequestPathHeader(headers, fmt.Sprintf(bedrockStreamChatCompletionPath, request.Model))
b.overwriteRequestPathHeader(headers, bedrockStreamChatCompletionPath, request.Model)
} else {
util.OverwriteRequestPathHeader(headers, fmt.Sprintf(bedrockChatCompletionPath, request.Model))
b.overwriteRequestPathHeader(headers, bedrockChatCompletionPath, request.Model)
}
return b.buildBedrockTextGenerationRequest(request, headers)
}
@@ -726,9 +723,12 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
systemMessages := make([]systemContentBlock, 0)
for _, msg := range origRequest.Messages {
switch msg.Role {
case roleSystem:
systemMessages = append(systemMessages, systemContentBlock{Text: msg.StringContent()})
case roleTool:
messages = append(messages, chatToolMessage2BedrockMessage(msg))
default:
messages = append(messages, chatMessage2BedrockMessage(msg))
}
}
@@ -747,6 +747,36 @@ func (b *bedrockProvider) buildBedrockTextGenerationRequest(origRequest *chatCom
},
}
if origRequest.Tools != nil {
request.ToolConfig = &bedrockToolConfig{}
if origRequest.ToolChoice == nil {
request.ToolConfig.ToolChoice.Auto = &struct{}{}
} else if choiceType, ok := origRequest.ToolChoice.(string); ok {
switch choiceType {
case "required":
request.ToolConfig.ToolChoice.Any = &struct{}{}
case "auto":
request.ToolConfig.ToolChoice.Auto = &struct{}{}
case "none":
// Bedrock's Converse toolChoice only supports auto/any/tool, so "none" falls back to auto.
request.ToolConfig.ToolChoice.Auto = &struct{}{}
}
} else if choice, ok := origRequest.ToolChoice.(toolChoice); ok {
request.ToolConfig.ToolChoice.Tool = &bedrockToolSpecification{
Name: choice.Function.Name,
}
}
request.ToolConfig.Tools = []bedrockTool{}
for _, tool := range origRequest.Tools {
request.ToolConfig.Tools = append(request.ToolConfig.Tools, bedrockTool{
ToolSpec: bedrockToolSpecification{
InputSchema: bedrockToolInputSchemaJson{Json: tool.Function.Parameters},
Name: tool.Function.Name,
Description: tool.Function.Description,
},
})
}
}
for key, value := range b.config.bedrockAdditionalFields {
request.AdditionalModelRequestFields[key] = value
}
@@ -761,16 +791,29 @@ func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, b
if len(bedrockResponse.Output.Message.Content) > 0 {
outputContent = bedrockResponse.Output.Message.Content[0].Text
}
choice := chatCompletionChoice{
Index: 0,
Message: &chatMessage{
Role: bedrockResponse.Output.Message.Role,
Content: outputContent,
},
FinishReason: util.Ptr(stopReasonBedrock2OpenAI(bedrockResponse.StopReason)),
}
choice.Message.ToolCalls = []toolCall{}
for _, content := range bedrockResponse.Output.Message.Content {
if content.ToolUse != nil {
args, _ := json.Marshal(content.ToolUse.Input)
choice.Message.ToolCalls = append(choice.Message.ToolCalls, toolCall{
Id: content.ToolUse.ToolUseId,
Type: "function",
Function: functionCall{
Name: content.ToolUse.Name,
Arguments: string(args),
},
})
}
}
choices := []chatCompletionChoice{choice}
requestId := ctx.GetStringContext(requestIdHeader, "")
modelId, _ := url.QueryUnescape(ctx.GetStringContext(ctxKeyFinalRequestModel, ""))
return &chatCompletionResponse{
@@ -788,6 +831,17 @@ func (b *bedrockProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, b
}
}
func (b *bedrockProvider) overwriteRequestPathHeader(headers http.Header, format, model string) {
modelInPath := model
// Just in case the model name has already been URL-escaped, we shouldn't escape it again.
if !strings.ContainsRune(model, '%') {
modelInPath = url.QueryEscape(model)
}
path := fmt.Sprintf(format, modelInPath)
log.Debugf("overwriting bedrock request path: %s", path)
util.OverwriteRequestPathHeader(headers, path)
}
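The escape-once guard above can be exercised in isolation. This is an illustrative sketch, not part of the plugin; `escapeModelOnce` is a hypothetical name:

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// escapeModelOnce mirrors the guard in overwriteRequestPathHeader: a model
// name that already contains '%' is assumed to be URL-escaped and is left
// untouched, so already-escaped model IDs are not escaped twice.
func escapeModelOnce(model string) string {
	if strings.ContainsRune(model, '%') {
		return model
	}
	return url.QueryEscape(model)
}

func main() {
	fmt.Println(escapeModelOnce("us.anthropic/claude-3"))   // escaped once
	fmt.Println(escapeModelOnce("us.anthropic%2Fclaude-3")) // left as-is
}
```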
func stopReasonBedrock2OpenAI(reason string) string {
switch reason {
case "end_turn":
@@ -796,6 +850,8 @@ func stopReasonBedrock2OpenAI(reason string) string {
return finishReasonStop
case "max_tokens":
return finishReasonLength
case "tool_use":
return finishReasonToolCall
default:
return reason
}
@@ -807,20 +863,48 @@ type bedrockTextGenRequest struct {
InferenceConfig bedrockInferenceConfig `json:"inferenceConfig,omitempty"`
AdditionalModelRequestFields map[string]interface{} `json:"additionalModelRequestFields,omitempty"`
PerformanceConfig PerformanceConfiguration `json:"performanceConfig,omitempty"`
ToolConfig *bedrockToolConfig `json:"toolConfig,omitempty"`
}
type bedrockToolConfig struct {
Tools []bedrockTool `json:"tools,omitempty"`
ToolChoice bedrockToolChoice `json:"toolChoice,omitempty"`
}
type PerformanceConfiguration struct {
Latency string `json:"latency,omitempty"`
}
type bedrockTool struct {
ToolSpec bedrockToolSpecification `json:"toolSpec,omitempty"`
}
type bedrockToolChoice struct {
Any *struct{} `json:"any,omitempty"`
Auto *struct{} `json:"auto,omitempty"`
Tool *bedrockToolSpecification `json:"tool,omitempty"`
}
type bedrockToolSpecification struct {
InputSchema bedrockToolInputSchemaJson `json:"inputSchema,omitempty"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
}
type bedrockToolInputSchemaJson struct {
Json map[string]interface{} `json:"json,omitempty"`
}
type bedrockMessage struct {
Role string `json:"role"`
Content []bedrockMessageContent `json:"content"`
}
type bedrockMessageContent struct {
Text string `json:"text,omitempty"`
Image *imageBlock `json:"image,omitempty"`
ToolResult *toolResultBlock `json:"toolResult,omitempty"`
ToolUse *toolUseBlock `json:"toolUse,omitempty"`
}
type systemContentBlock struct {
@@ -836,6 +920,22 @@ type imageSource struct {
Bytes string `json:"bytes,omitempty"`
}
type toolResultBlock struct {
ToolUseId string `json:"toolUseId"`
Content []toolResultContentBlock `json:"content"`
Status string `json:"status,omitempty"`
}
type toolResultContentBlock struct {
Text string `json:"text"`
}
type toolUseBlock struct {
Input map[string]interface{} `json:"input"`
Name string `json:"name"`
ToolUseId string `json:"toolUseId"`
}
type bedrockInferenceConfig struct {
StopSequences []string `json:"stopSequences,omitempty"`
MaxTokens int `json:"maxTokens,omitempty"`
@@ -859,13 +959,19 @@ type converseOutputMemberMessage struct {
}
type message struct {
Content []contentBlock `json:"content"`
Role string `json:"role"`
}
type contentBlock struct {
Text string `json:"text,omitempty"`
ToolUse *bedrockToolUse `json:"toolUse,omitempty"`
}
type bedrockToolUse struct {
Name string `json:"name"`
ToolUseId string `json:"toolUseId"`
Input map[string]interface{} `json:"input"`
}
type tokenUsage struct {
@@ -876,9 +982,53 @@ type tokenUsage struct {
TotalTokens int `json:"totalTokens"`
}
func chatToolMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
toolResultContent := &toolResultBlock{}
toolResultContent.ToolUseId = chatMessage.ToolCallId
if text, ok := chatMessage.Content.(string); ok {
toolResultContent.Content = []toolResultContentBlock{
{
Text: text,
},
}
} else {
log.Warnf("only text content is supported, current content is %v", chatMessage.Content)
}
return bedrockMessage{
Role: roleUser,
Content: []bedrockMessageContent{
{
ToolResult: toolResultContent,
},
},
}
}
func chatMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
var result bedrockMessage
if len(chatMessage.ToolCalls) > 0 {
result = bedrockMessage{
Role: chatMessage.Role,
Content: []bedrockMessageContent{{}},
}
params := map[string]interface{}{}
json.Unmarshal([]byte(chatMessage.ToolCalls[0].Function.Arguments), &params)
result.Content[0].ToolUse = &toolUseBlock{
Input: params,
Name: chatMessage.ToolCalls[0].Function.Name,
ToolUseId: chatMessage.ToolCalls[0].Id,
}
} else if chatMessage.IsStringContent() {
result = bedrockMessage{
Role: chatMessage.Role,
Content: []bedrockMessageContent{{Text: chatMessage.StringContent()}},
}
@@ -895,29 +1045,22 @@ func chatMessage2BedrockMessage(chatMessage chatMessage) bedrockMessage {
}
contents = append(contents, content)
}
result = bedrockMessage{
Role: chatMessage.Role,
Content: contents,
}
}
return result
}
func (b *bedrockProvider) setAuthHeaders(body []byte, headers http.Header) {
t := time.Now().UTC()
amzDate := t.Format("20060102T150405Z")
dateStamp := t.Format("20060102")
path := headers.Get(":path")
signature := b.generateSignature(path, amzDate, dateStamp, body)
headers.Set("X-Amz-Date", amzDate)
util.OverwriteRequestAuthorizationHeader(headers, fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s/%s/%s/aws4_request, SignedHeaders=%s, Signature=%s", b.config.awsAccessKey, dateStamp, b.config.awsRegion, awsService, bedrockSignedHeaders, signature))
}
func (b *bedrockProvider) generateSignature(path, amzDate, dateStamp string, body []byte) string {

View File

@@ -36,8 +36,18 @@ type claudeToolChoice struct {
}
type claudeChatMessage struct {
Role string `json:"role"`
Content claudeChatMessageContentWr `json:"content"`
}
// claudeChatMessageContentWr wraps the content to handle both string and array formats
type claudeChatMessageContentWr struct {
// StringValue holds simple text content
StringValue string
// ArrayValue holds multi-modal content
ArrayValue []claudeChatMessageContent
// IsString indicates whether this is a simple string or array
IsString bool
}
type claudeChatMessageContentSource struct {
@@ -49,23 +59,154 @@ type claudeChatMessageContentSource struct {
}
type claudeChatMessageContent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Source *claudeChatMessageContentSource `json:"source,omitempty"`
CacheControl map[string]interface{} `json:"cache_control,omitempty"`
// Tool use fields
Id string `json:"id,omitempty"` // For tool_use
Name string `json:"name,omitempty"` // For tool_use
Input map[string]interface{} `json:"input,omitempty"` // For tool_use
// Tool result fields
ToolUseId string `json:"tool_use_id,omitempty"` // For tool_result
Content string `json:"content,omitempty"` // For tool_result
}
// UnmarshalJSON implements custom JSON unmarshaling for claudeChatMessageContentWr
func (ccw *claudeChatMessageContentWr) UnmarshalJSON(data []byte) error {
// Try to unmarshal as string first
var stringValue string
if err := json.Unmarshal(data, &stringValue); err == nil {
ccw.StringValue = stringValue
ccw.IsString = true
return nil
}
// Try to unmarshal as array of content blocks
var arrayValue []claudeChatMessageContent
if err := json.Unmarshal(data, &arrayValue); err == nil {
ccw.ArrayValue = arrayValue
ccw.IsString = false
return nil
}
return fmt.Errorf("content field must be either a string or an array of content blocks")
}
// MarshalJSON implements custom JSON marshaling for claudeChatMessageContentWr
func (ccw claudeChatMessageContentWr) MarshalJSON() ([]byte, error) {
if ccw.IsString {
return json.Marshal(ccw.StringValue)
}
return json.Marshal(ccw.ArrayValue)
}
// GetStringValue returns the string representation if it's a string, empty string otherwise
func (ccw claudeChatMessageContentWr) GetStringValue() string {
if ccw.IsString {
return ccw.StringValue
}
return ""
}
// GetArrayValue returns the array representation if it's an array, empty slice otherwise
func (ccw claudeChatMessageContentWr) GetArrayValue() []claudeChatMessageContent {
if !ccw.IsString {
return ccw.ArrayValue
}
return nil
}
// NewStringContent creates a new wrapper for string content
func NewStringContent(content string) claudeChatMessageContentWr {
return claudeChatMessageContentWr{
StringValue: content,
IsString: true,
}
}
// NewArrayContent creates a new wrapper for array content
func NewArrayContent(content []claudeChatMessageContent) claudeChatMessageContentWr {
return claudeChatMessageContentWr{
ArrayValue: content,
IsString: false,
}
}
// claudeSystemPrompt represents the system field which can be either a string or an array of text blocks
type claudeSystemPrompt struct {
// Will be set to the string value if system is a simple string
StringValue string
// Will be set to the array value if system is an array of text blocks
ArrayValue []claudeTextGenContent
// Indicates which type this represents
IsArray bool
}
// UnmarshalJSON implements custom JSON unmarshaling for claudeSystemPrompt
func (csp *claudeSystemPrompt) UnmarshalJSON(data []byte) error {
// Try to unmarshal as string first
var stringValue string
if err := json.Unmarshal(data, &stringValue); err == nil {
csp.StringValue = stringValue
csp.IsArray = false
return nil
}
// Try to unmarshal as array of text blocks
var arrayValue []claudeTextGenContent
if err := json.Unmarshal(data, &arrayValue); err == nil {
csp.ArrayValue = arrayValue
csp.IsArray = true
return nil
}
return fmt.Errorf("system field must be either a string or an array of text blocks")
}
// MarshalJSON implements custom JSON marshaling for claudeSystemPrompt
func (csp claudeSystemPrompt) MarshalJSON() ([]byte, error) {
if csp.IsArray {
return json.Marshal(csp.ArrayValue)
}
return json.Marshal(csp.StringValue)
}
// String returns the string representation of the system prompt
func (csp claudeSystemPrompt) String() string {
if csp.IsArray {
// Concatenate all text blocks
var parts []string
for _, block := range csp.ArrayValue {
if block.Text != "" {
parts = append(parts, block.Text)
}
}
return strings.Join(parts, "\n")
}
return csp.StringValue
}
// claudeThinkingConfig represents the thinking configuration for Claude
type claudeThinkingConfig struct {
Type string `json:"type"`
BudgetTokens int `json:"budget_tokens,omitempty"`
}
type claudeTextGenRequest struct {
Model string `json:"model"`
Messages []claudeChatMessage `json:"messages"`
System claudeSystemPrompt `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
ToolChoice *claudeToolChoice `json:"tool_choice,omitempty"`
Tools []claudeTool `json:"tools,omitempty"`
ServiceTier string `json:"service_tier,omitempty"`
Thinking *claudeThinkingConfig `json:"thinking,omitempty"`
}
type claudeTextGenResponse struct {
@@ -81,8 +222,13 @@ type claudeTextGenResponse struct {
}
type claudeTextGenContent struct {
Type string `json:"type,omitempty"`
Text string `json:"text,omitempty"`
Id string `json:"id,omitempty"` // For tool_use
Name string `json:"name,omitempty"` // For tool_use
Input map[string]interface{} `json:"input,omitempty"` // For tool_use
Signature string `json:"signature,omitempty"` // For thinking
Thinking string `json:"thinking,omitempty"` // For thinking
}
type claudeTextGenUsage struct {
@@ -99,7 +245,7 @@ type claudeTextGenError struct {
type claudeTextGenStreamResponse struct {
Type string `json:"type"`
Message *claudeTextGenResponse `json:"message,omitempty"`
Index *int `json:"index,omitempty"`
ContentBlock *claudeTextGenContent `json:"content_block,omitempty"`
Delta *claudeTextGenDelta `json:"delta,omitempty"`
Usage *claudeTextGenUsage `json:"usage,omitempty"`
@@ -107,13 +253,13 @@ type claudeTextGenStreamResponse struct {
type claudeTextGenDelta struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
StopReason *string `json:"stop_reason,omitempty"`
StopSequence *string `json:"stop_sequence,omitempty"`
}
func (c *claudeProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
@@ -255,7 +401,10 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe
for _, message := range origRequest.Messages {
if message.Role == roleSystem {
claudeRequest.System = claudeSystemPrompt{
StringValue: message.StringContent(),
IsArray: false,
}
continue
}
@@ -263,7 +412,7 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe
Role: message.Role,
}
if message.IsStringContent() {
claudeMessage.Content = NewStringContent(message.StringContent())
} else {
chatMessageContents := make([]claudeChatMessageContent, 0)
for _, messageContent := range message.ParseContent() {
@@ -310,7 +459,7 @@ func (c *claudeProvider) buildClaudeTextGenRequest(origRequest *chatCompletionRe
continue
}
}
claudeMessage.Content = NewArrayContent(chatMessageContents)
}
claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
}
@@ -342,19 +491,25 @@ func (c *claudeProvider) responseClaude2OpenAI(ctx wrapper.HttpContext, origResp
FinishReason: util.Ptr(stopReasonClaude2OpenAI(origResponse.StopReason)),
}
response := &chatCompletionResponse{
Id: origResponse.Id,
Created: time.Now().UnixMilli() / 1000,
Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""),
SystemFingerprint: "",
Object: objectChatCompletion,
Choices: []chatCompletionChoice{choice},
}
// Include usage information if available
if origResponse.Usage.InputTokens > 0 || origResponse.Usage.OutputTokens > 0 {
response.Usage = &usage{
PromptTokens:     origResponse.Usage.InputTokens,
CompletionTokens: origResponse.Usage.OutputTokens,
TotalTokens:      origResponse.Usage.InputTokens + origResponse.Usage.OutputTokens,
}
}
return response
}
func stopReasonClaude2OpenAI(reason *string) string {
@@ -376,31 +531,47 @@ func stopReasonClaude2OpenAI(reason *string) string {
func (c *claudeProvider) streamResponseClaude2OpenAI(ctx wrapper.HttpContext, origResponse *claudeTextGenStreamResponse) *chatCompletionResponse {
switch origResponse.Type {
case "message_start":
if origResponse.Message != nil {
c.messageId = origResponse.Message.Id
c.usage = usage{
PromptTokens: origResponse.Message.Usage.InputTokens,
CompletionTokens: origResponse.Message.Usage.OutputTokens,
}
c.serviceTier = origResponse.Message.Usage.ServiceTier
}
var index int
if origResponse.Index != nil {
index = *origResponse.Index
}
choice := chatCompletionChoice{
Index: index,
Delta: &chatMessage{Role: roleAssistant, Content: ""},
}
return c.createChatCompletionResponse(ctx, origResponse, choice)
case "content_block_delta":
var index int
if origResponse.Index != nil {
index = *origResponse.Index
}
choice := chatCompletionChoice{
Index: index,
Delta: &chatMessage{Content: origResponse.Delta.Text},
}
return c.createChatCompletionResponse(ctx, origResponse, choice)
case "message_delta":
if origResponse.Usage != nil {
c.usage.CompletionTokens += origResponse.Usage.OutputTokens
c.usage.TotalTokens = c.usage.PromptTokens + c.usage.CompletionTokens
}
var index int
if origResponse.Index != nil {
index = *origResponse.Index
}
choice := chatCompletionChoice{
Index: index,
Delta: &chatMessage{},
FinishReason: util.Ptr(stopReasonClaude2OpenAI(origResponse.Delta.StopReason)),
}
@@ -449,10 +620,17 @@ func (c *claudeProvider) insertHttpContextMessage(body []byte, content string, o
return nil, fmt.Errorf("unable to unmarshal request: %v", err)
}
systemStr := request.System.String()
if systemStr == "" {
request.System = claudeSystemPrompt{
StringValue: content,
IsArray: false,
}
} else {
request.System = claudeSystemPrompt{
StringValue: content + "\n" + systemStr,
IsArray: false,
}
}
return json.Marshal(request)

View File

@@ -0,0 +1,824 @@
package provider
import (
"encoding/json"
"fmt"
"strings"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
// ClaudeToOpenAIConverter converts Claude protocol requests to OpenAI protocol
type ClaudeToOpenAIConverter struct {
// State tracking for streaming conversion
messageStartSent bool
messageStopSent bool
messageId string
// Cache stop_reason until we get usage info
pendingStopReason *string
// Content block tracking with dynamic index allocation
nextContentIndex int
thinkingBlockIndex int
thinkingBlockStarted bool
thinkingBlockStopped bool
textBlockIndex int
textBlockStarted bool
textBlockStopped bool
toolBlockIndex int
toolBlockStarted bool
toolBlockStopped bool
// Tool call state tracking
toolCallStates map[string]*toolCallState
}
// contentConversionResult represents the result of converting Claude content to OpenAI format
type contentConversionResult struct {
textParts []string
toolCalls []toolCall
toolResults []claudeChatMessageContent
openaiContents []chatMessageContent
hasNonTextContent bool
}
// toolCallState tracks the state of a tool call during streaming
type toolCallState struct {
id string
name string
argumentsBuffer string
isComplete bool
}
// ConvertClaudeRequestToOpenAI converts a Claude chat completion request to OpenAI format
func (c *ClaudeToOpenAIConverter) ConvertClaudeRequestToOpenAI(body []byte) ([]byte, error) {
log.Debugf("[Claude->OpenAI] Original Claude request body: %s", string(body))
var claudeRequest claudeTextGenRequest
if err := json.Unmarshal(body, &claudeRequest); err != nil {
return nil, fmt.Errorf("unable to unmarshal claude request: %v", err)
}
// Convert Claude request to OpenAI format
openaiRequest := chatCompletionRequest{
Model: claudeRequest.Model,
Stream: claudeRequest.Stream,
Temperature: claudeRequest.Temperature,
TopP: claudeRequest.TopP,
MaxTokens: claudeRequest.MaxTokens,
Stop: claudeRequest.StopSequences,
}
// Convert messages from Claude format to OpenAI format
for _, claudeMsg := range claudeRequest.Messages {
// Handle different content types using the type-safe wrapper
if claudeMsg.Content.IsString {
// Simple text content
openaiMsg := chatMessage{
Role: claudeMsg.Role,
Content: claudeMsg.Content.GetStringValue(),
}
openaiRequest.Messages = append(openaiRequest.Messages, openaiMsg)
} else {
// Multi-modal content - process with convertContentArray
conversionResult := c.convertContentArray(claudeMsg.Content.GetArrayValue())
// Handle tool calls if present
if len(conversionResult.toolCalls) > 0 {
// Use tool_calls format (current OpenAI standard)
openaiMsg := chatMessage{
Role: claudeMsg.Role,
ToolCalls: conversionResult.toolCalls,
}
// Add text content if present, otherwise set to null
if len(conversionResult.textParts) > 0 {
openaiMsg.Content = strings.Join(conversionResult.textParts, "\n\n")
} else {
openaiMsg.Content = nil
}
openaiRequest.Messages = append(openaiRequest.Messages, openaiMsg)
}
// Handle tool results if present
if len(conversionResult.toolResults) > 0 {
for _, toolResult := range conversionResult.toolResults {
toolMsg := chatMessage{
Role: "tool",
Content: toolResult.Content,
ToolCallId: toolResult.ToolUseId,
}
openaiRequest.Messages = append(openaiRequest.Messages, toolMsg)
}
}
// Handle regular content if no tool calls or tool results
if len(conversionResult.toolCalls) == 0 && len(conversionResult.toolResults) == 0 {
var content interface{}
if !conversionResult.hasNonTextContent && len(conversionResult.textParts) > 0 {
// Simple text content
content = strings.Join(conversionResult.textParts, "\n\n")
} else {
// Multi-modal content or empty content
content = conversionResult.openaiContents
}
openaiMsg := chatMessage{
Role: claudeMsg.Role,
Content: content,
}
openaiRequest.Messages = append(openaiRequest.Messages, openaiMsg)
}
}
}
// Handle system message - Claude has separate system field
systemStr := claudeRequest.System.String()
if systemStr != "" {
systemMsg := chatMessage{
Role: roleSystem,
Content: systemStr,
}
// Insert system message at the beginning
openaiRequest.Messages = append([]chatMessage{systemMsg}, openaiRequest.Messages...)
}
// Convert tools if present
for _, claudeTool := range claudeRequest.Tools {
openaiTool := tool{
Type: "function",
Function: function{
Name: claudeTool.Name,
Description: claudeTool.Description,
Parameters: claudeTool.InputSchema,
},
}
openaiRequest.Tools = append(openaiRequest.Tools, openaiTool)
}
// Convert tool choice if present
if claudeRequest.ToolChoice != nil {
if claudeRequest.ToolChoice.Type == "tool" && claudeRequest.ToolChoice.Name != "" {
openaiRequest.ToolChoice = &toolChoice{
Type: "function",
Function: function{
Name: claudeRequest.ToolChoice.Name,
},
}
} else {
// For other types like "auto", "none", etc.
openaiRequest.ToolChoice = claudeRequest.ToolChoice.Type
}
// Handle parallel tool calls
openaiRequest.ParallelToolCalls = !claudeRequest.ToolChoice.DisableParallelToolUse
}
// Convert thinking configuration if present
if claudeRequest.Thinking != nil {
log.Debugf("[Claude->OpenAI] Found thinking config: type=%s, budget_tokens=%d",
claudeRequest.Thinking.Type, claudeRequest.Thinking.BudgetTokens)
if claudeRequest.Thinking.Type == "enabled" {
openaiRequest.ReasoningMaxTokens = claudeRequest.Thinking.BudgetTokens
// Set ReasoningEffort based on budget_tokens
// low: <4096, medium: >=4096 and <16384, high: >=16384
if claudeRequest.Thinking.BudgetTokens < 4096 {
openaiRequest.ReasoningEffort = "low"
} else if claudeRequest.Thinking.BudgetTokens < 16384 {
openaiRequest.ReasoningEffort = "medium"
} else {
openaiRequest.ReasoningEffort = "high"
}
log.Debugf("[Claude->OpenAI] Converted thinking config: budget_tokens=%d, reasoning_effort=%s, reasoning_max_tokens=%d",
claudeRequest.Thinking.BudgetTokens, openaiRequest.ReasoningEffort, openaiRequest.ReasoningMaxTokens)
}
} else {
log.Debugf("[Claude->OpenAI] No thinking config found")
}
result, err := json.Marshal(openaiRequest)
if err != nil {
return nil, fmt.Errorf("unable to marshal openai request: %v", err)
}
log.Debugf("[Claude->OpenAI] Converted OpenAI request body: %s", string(result))
return result, nil
}
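The `thinking.budget_tokens` tiering inside ConvertClaudeRequestToOpenAI reduces to a small helper; this is a sketch of the mapping only (`reasoningEffort` is an illustrative name):

```go
package main

import "fmt"

// reasoningEffort reproduces the tiering used when converting Claude's
// thinking.budget_tokens to an OpenAI-style reasoning_effort hint:
// <4096 => low, >=4096 and <16384 => medium, >=16384 => high.
func reasoningEffort(budgetTokens int) string {
	switch {
	case budgetTokens < 4096:
		return "low"
	case budgetTokens < 16384:
		return "medium"
	default:
		return "high"
	}
}

func main() {
	for _, b := range []int{1024, 4096, 16384} {
		fmt.Println(b, reasoningEffort(b))
	}
}
```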
// ConvertOpenAIResponseToClaude converts an OpenAI response back to Claude format
func (c *ClaudeToOpenAIConverter) ConvertOpenAIResponseToClaude(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
log.Debugf("[OpenAI->Claude] Original OpenAI response body: %s", string(body))
var openaiResponse chatCompletionResponse
if err := json.Unmarshal(body, &openaiResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal openai response: %v", err)
}
// Convert OpenAI response to Claude format
claudeResponse := claudeTextGenResponse{
Id: openaiResponse.Id,
Type: "message",
Role: "assistant",
Model: openaiResponse.Model,
}
// Only include usage if it's available
if openaiResponse.Usage != nil {
claudeResponse.Usage = claudeTextGenUsage{
InputTokens: openaiResponse.Usage.PromptTokens,
OutputTokens: openaiResponse.Usage.CompletionTokens,
}
}
// Convert the first choice content
if len(openaiResponse.Choices) > 0 {
choice := openaiResponse.Choices[0]
if choice.Message != nil {
var contents []claudeTextGenContent
// Add reasoning content (thinking) if present - check both reasoning and reasoning_content fields
var reasoningText string
if choice.Message.Reasoning != "" {
reasoningText = choice.Message.Reasoning
} else if choice.Message.ReasoningContent != "" {
reasoningText = choice.Message.ReasoningContent
}
if reasoningText != "" {
contents = append(contents, claudeTextGenContent{
Type: "thinking",
Signature: "", // OpenAI doesn't provide signature, use empty string
Thinking: reasoningText,
})
log.Debugf("[OpenAI->Claude] Added thinking content: %s", reasoningText)
}
// Add text content if present
if choice.Message.StringContent() != "" {
contents = append(contents, claudeTextGenContent{
Type: "text",
Text: choice.Message.StringContent(),
})
}
// Add tool calls if present
if len(choice.Message.ToolCalls) > 0 {
for _, toolCall := range choice.Message.ToolCalls {
if !toolCall.Function.IsEmpty() {
// Parse arguments from JSON string to map
var input map[string]interface{}
if toolCall.Function.Arguments != "" {
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &input); err != nil {
log.Errorf("Failed to parse tool call arguments: %v", err)
input = map[string]interface{}{}
}
} else {
input = map[string]interface{}{}
}
contents = append(contents, claudeTextGenContent{
Type: "tool_use",
Id: toolCall.Id,
Name: toolCall.Function.Name,
Input: input,
})
}
}
}
claudeResponse.Content = contents
}
// Convert finish reason
if choice.FinishReason != nil {
claudeFinishReason := openAIFinishReasonToClaude(*choice.FinishReason)
claudeResponse.StopReason = &claudeFinishReason
}
}
result, err := json.Marshal(claudeResponse)
if err != nil {
return nil, fmt.Errorf("unable to marshal claude response: %v", err)
}
log.Debugf("[OpenAI->Claude] Converted Claude response body: %s", string(result))
return result, nil
}
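The tool-call handling above parses the OpenAI JSON argument string into a Claude `input` map, falling back to an empty map when arguments are missing or malformed. A standalone sketch of that fallback (hypothetical helper name):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// parseToolInput converts an OpenAI tool call's JSON argument string into
// the map used for a Claude tool_use "input" field; on empty or invalid
// JSON it returns an empty map rather than failing the whole conversion.
func parseToolInput(arguments string) map[string]interface{} {
	input := map[string]interface{}{}
	if arguments == "" {
		return input
	}
	if err := json.Unmarshal([]byte(arguments), &input); err != nil {
		return map[string]interface{}{}
	}
	return input
}

func main() {
	fmt.Println(parseToolInput(`{"file_path":"README.md"}`)["file_path"])
}
```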
// ConvertOpenAIStreamResponseToClaude converts OpenAI streaming response to Claude format
func (c *ClaudeToOpenAIConverter) ConvertOpenAIStreamResponseToClaude(ctx wrapper.HttpContext, chunk []byte) ([]byte, error) {
log.Debugf("[OpenAI->Claude] Original OpenAI streaming chunk: %s", string(chunk))
// Initialize tool call states if needed
if c.toolCallStates == nil {
c.toolCallStates = make(map[string]*toolCallState)
}
// For streaming responses, we need to handle the Server-Sent Events format
lines := strings.Split(string(chunk), "\n")
var result strings.Builder
for _, line := range lines {
if strings.HasPrefix(line, "data: ") {
data := strings.TrimPrefix(line, "data: ")
// Handle [DONE] messages
if data == "[DONE]" {
log.Debugf("[OpenAI->Claude] Processing [DONE] message, finalizing stream")
// Send final content_block_stop events for any active blocks
if c.thinkingBlockStarted && !c.thinkingBlockStopped {
c.thinkingBlockStopped = true
log.Debugf("[OpenAI->Claude] Sending final thinking content_block_stop event at index %d", c.thinkingBlockIndex)
stopEvent := &claudeTextGenStreamResponse{
Type: "content_block_stop",
Index: &c.thinkingBlockIndex,
}
stopData, _ := json.Marshal(stopEvent)
result.WriteString(fmt.Sprintf("data: %s\n\n", stopData))
}
if c.textBlockStarted && !c.textBlockStopped {
c.textBlockStopped = true
log.Debugf("[OpenAI->Claude] Sending final text content_block_stop event at index %d", c.textBlockIndex)
stopEvent := &claudeTextGenStreamResponse{
Type: "content_block_stop",
Index: &c.textBlockIndex,
}
stopData, _ := json.Marshal(stopEvent)
result.WriteString(fmt.Sprintf("data: %s\n\n", stopData))
}
if c.toolBlockStarted && !c.toolBlockStopped {
c.toolBlockStopped = true
log.Debugf("[OpenAI->Claude] Sending final tool content_block_stop event at index %d", c.toolBlockIndex)
stopEvent := &claudeTextGenStreamResponse{
Type: "content_block_stop",
Index: &c.toolBlockIndex,
}
stopData, _ := json.Marshal(stopEvent)
result.WriteString(fmt.Sprintf("data: %s\n\n", stopData))
}
// If we have a pending stop_reason but no usage, send message_delta with just stop_reason
if c.pendingStopReason != nil {
log.Debugf("[OpenAI->Claude] Sending final message_delta with pending stop_reason: %s", *c.pendingStopReason)
messageDelta := &claudeTextGenStreamResponse{
Type: "message_delta",
Delta: &claudeTextGenDelta{
Type: "message_delta",
StopReason: c.pendingStopReason,
},
}
stopData, _ := json.Marshal(messageDelta)
result.WriteString(fmt.Sprintf("data: %s\n\n", stopData))
c.pendingStopReason = nil
}
if c.messageStartSent && !c.messageStopSent {
c.messageStopSent = true
log.Debugf("[OpenAI->Claude] Sending final message_stop event")
messageStopEvent := &claudeTextGenStreamResponse{
Type: "message_stop",
}
stopData, _ := json.Marshal(messageStopEvent)
result.WriteString(fmt.Sprintf("data: %s\n\n", stopData))
}
// Reset all state for next request
c.messageStartSent = false
c.messageStopSent = false
c.messageId = ""
c.pendingStopReason = nil
c.nextContentIndex = 0
c.thinkingBlockIndex = -1
c.thinkingBlockStarted = false
c.thinkingBlockStopped = false
c.textBlockIndex = -1
c.textBlockStarted = false
c.textBlockStopped = false
c.toolBlockIndex = -1
c.toolBlockStarted = false
c.toolBlockStopped = false
c.toolCallStates = make(map[string]*toolCallState)
log.Debugf("[OpenAI->Claude] Reset converter state for next request")
continue
}
var openaiStreamResponse chatCompletionResponse
if err := json.Unmarshal([]byte(data), &openaiStreamResponse); err != nil {
log.Debugf("unable to unmarshal openai stream response: %v, data: %s", err, data)
continue
}
// Convert to Claude streaming format
claudeStreamResponses := c.buildClaudeStreamResponse(ctx, &openaiStreamResponse)
log.Debugf("[OpenAI->Claude] Generated %d Claude stream events from OpenAI chunk", len(claudeStreamResponses))
for i, claudeStreamResponse := range claudeStreamResponses {
responseData, err := json.Marshal(claudeStreamResponse)
if err != nil {
log.Errorf("unable to marshal claude stream response: %v", err)
continue
}
log.Debugf("[OpenAI->Claude] Stream event [%d/%d]: %s", i+1, len(claudeStreamResponses), string(responseData))
result.WriteString(fmt.Sprintf("data: %s\n\n", responseData))
}
}
}
claudeChunk := []byte(result.String())
log.Debugf("[OpenAI->Claude] Converted Claude streaming chunk: %s", string(claudeChunk))
return claudeChunk, nil
}
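The stream converter above walks each chunk line by line and only acts on `data: ` lines, treating `[DONE]` as the end-of-stream sentinel. A minimal sketch of that SSE framing (hypothetical standalone helper; the real code then unmarshals each payload):

```go
package main

import (
	"fmt"
	"strings"
)

// extractPayloads pulls JSON payloads out of a Server-Sent Events chunk
// using the same "data: " prefix convention, and reports whether the
// "[DONE]" sentinel was seen.
func extractPayloads(chunk string) (payloads []string, done bool) {
	for _, line := range strings.Split(chunk, "\n") {
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		data := strings.TrimPrefix(line, "data: ")
		if data == "[DONE]" {
			done = true
			continue
		}
		payloads = append(payloads, data)
	}
	return payloads, done
}

func main() {
	p, done := extractPayloads("data: {\"id\":\"x\"}\n\ndata: [DONE]\n\n")
	fmt.Println(len(p), done)
}
```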
// buildClaudeStreamResponse builds Claude streaming responses from OpenAI streaming response
func (c *ClaudeToOpenAIConverter) buildClaudeStreamResponse(ctx wrapper.HttpContext, openaiResponse *chatCompletionResponse) []*claudeTextGenStreamResponse {
if len(openaiResponse.Choices) == 0 {
log.Debugf("[OpenAI->Claude] No choices in OpenAI response, skipping")
return nil
}
choice := openaiResponse.Choices[0]
var responses []*claudeTextGenStreamResponse
// Log what we're processing
hasRole := choice.Delta != nil && choice.Delta.Role != ""
hasContent := choice.Delta != nil && choice.Delta.Content != ""
hasFinishReason := choice.FinishReason != nil
hasUsage := openaiResponse.Usage != nil
log.Debugf("[OpenAI->Claude] Processing OpenAI chunk - Role: %v, Content: %v, FinishReason: %v, Usage: %v",
hasRole, hasContent, hasFinishReason, hasUsage)
// Handle message start (only once)
// Note: OpenRouter may send multiple messages with role but empty content at the start
// We only send message_start for the first one
if choice.Delta != nil && choice.Delta.Role != "" && !c.messageStartSent {
c.messageId = openaiResponse.Id
c.messageStartSent = true
message := &claudeTextGenResponse{
Id: openaiResponse.Id,
Type: "message",
Role: "assistant",
Model: openaiResponse.Model,
Content: []claudeTextGenContent{},
}
// Only include usage if it's available
if openaiResponse.Usage != nil {
message.Usage = claudeTextGenUsage{
InputTokens: openaiResponse.Usage.PromptTokens,
OutputTokens: 0,
}
}
responses = append(responses, &claudeTextGenStreamResponse{
Type: "message_start",
Message: message,
})
log.Debugf("[OpenAI->Claude] Generated message_start event for id: %s", openaiResponse.Id)
} else if choice.Delta != nil && choice.Delta.Role != "" && c.messageStartSent {
// Skip duplicate role messages from OpenRouter
log.Debugf("[OpenAI->Claude] Skipping duplicate role message for id: %s", openaiResponse.Id)
}
// Handle reasoning content (thinking) first - check both reasoning and reasoning_content fields
var reasoningText string
if choice.Delta != nil {
if choice.Delta.Reasoning != "" {
reasoningText = choice.Delta.Reasoning
} else if choice.Delta.ReasoningContent != "" {
reasoningText = choice.Delta.ReasoningContent
}
}
if reasoningText != "" {
log.Debugf("[OpenAI->Claude] Processing reasoning content delta: %s", reasoningText)
// Send content_block_start for thinking only once with dynamic index
if !c.thinkingBlockStarted {
c.thinkingBlockIndex = c.nextContentIndex
c.nextContentIndex++
c.thinkingBlockStarted = true
log.Debugf("[OpenAI->Claude] Generated content_block_start event for thinking at index %d", c.thinkingBlockIndex)
responses = append(responses, &claudeTextGenStreamResponse{
Type: "content_block_start",
Index: &c.thinkingBlockIndex,
ContentBlock: &claudeTextGenContent{
Type: "thinking",
Signature: "", // OpenAI doesn't provide signature
Thinking: "",
},
})
}
// Send content_block_delta for thinking
log.Debugf("[OpenAI->Claude] Generated content_block_delta event with thinking: %s", reasoningText)
responses = append(responses, &claudeTextGenStreamResponse{
Type: "content_block_delta",
Index: &c.thinkingBlockIndex,
Delta: &claudeTextGenDelta{
Type: "thinking_delta", // Use thinking_delta for reasoning content
Text: reasoningText,
},
})
}
// Handle content
if choice.Delta != nil && choice.Delta.Content != nil && choice.Delta.Content != "" {
deltaContent, ok := choice.Delta.Content.(string)
if !ok {
log.Debugf("[OpenAI->Claude] Content is not a string: %T", choice.Delta.Content)
return responses
}
log.Debugf("[OpenAI->Claude] Processing content delta: %s", deltaContent)
// Close thinking content block if it's still open
if c.thinkingBlockStarted && !c.thinkingBlockStopped {
c.thinkingBlockStopped = true
log.Debugf("[OpenAI->Claude] Closing thinking content block before text")
responses = append(responses, &claudeTextGenStreamResponse{
Type: "content_block_stop",
Index: &c.thinkingBlockIndex,
})
}
// Send content_block_start only once for text content with dynamic index
if !c.textBlockStarted {
c.textBlockIndex = c.nextContentIndex
c.nextContentIndex++
c.textBlockStarted = true
log.Debugf("[OpenAI->Claude] Generated content_block_start event for text at index %d", c.textBlockIndex)
responses = append(responses, &claudeTextGenStreamResponse{
Type: "content_block_start",
Index: &c.textBlockIndex,
ContentBlock: &claudeTextGenContent{
Type: "text",
Text: "",
},
})
}
// Send content_block_delta
log.Debugf("[OpenAI->Claude] Generated content_block_delta event with text: %s", deltaContent)
responses = append(responses, &claudeTextGenStreamResponse{
Type: "content_block_delta",
Index: &c.textBlockIndex,
Delta: &claudeTextGenDelta{
Type: "text_delta",
Text: deltaContent,
},
})
}
// Handle tool calls in streaming response
if choice.Delta != nil && len(choice.Delta.ToolCalls) > 0 {
for _, toolCall := range choice.Delta.ToolCalls {
if !toolCall.Function.IsEmpty() {
log.Debugf("[OpenAI->Claude] Processing tool call delta")
// Get or create tool call state
state := c.toolCallStates[toolCall.Id]
if state == nil {
state = &toolCallState{
id: toolCall.Id,
name: toolCall.Function.Name,
argumentsBuffer: "",
isComplete: false,
}
c.toolCallStates[toolCall.Id] = state
log.Debugf("[OpenAI->Claude] Created new tool call state for id: %s, name: %s", toolCall.Id, toolCall.Function.Name)
}
// Accumulate arguments
if toolCall.Function.Arguments != "" {
state.argumentsBuffer += toolCall.Function.Arguments
log.Debugf("[OpenAI->Claude] Accumulated tool arguments: %s", state.argumentsBuffer)
}
// Try to parse accumulated arguments as JSON to check if complete
var input map[string]interface{}
if state.argumentsBuffer != "" {
if err := json.Unmarshal([]byte(state.argumentsBuffer), &input); err == nil {
// Successfully parsed - arguments are complete
if !state.isComplete {
state.isComplete = true
log.Debugf("[OpenAI->Claude] Tool call arguments complete for %s: %s", state.name, state.argumentsBuffer)
// Close thinking content block if it's still open
if c.thinkingBlockStarted && !c.thinkingBlockStopped {
c.thinkingBlockStopped = true
log.Debugf("[OpenAI->Claude] Closing thinking content block before tool use")
responses = append(responses, &claudeTextGenStreamResponse{
Type: "content_block_stop",
Index: &c.thinkingBlockIndex,
})
}
// Close text content block if it's still open
if c.textBlockStarted && !c.textBlockStopped {
c.textBlockStopped = true
log.Debugf("[OpenAI->Claude] Closing text content block before tool use")
responses = append(responses, &claudeTextGenStreamResponse{
Type: "content_block_stop",
Index: &c.textBlockIndex,
})
}
// Send content_block_start for tool_use only when we have complete arguments with dynamic index
if !c.toolBlockStarted {
c.toolBlockIndex = c.nextContentIndex
c.nextContentIndex++
c.toolBlockStarted = true
log.Debugf("[OpenAI->Claude] Generated content_block_start event for tool_use at index %d", c.toolBlockIndex)
responses = append(responses, &claudeTextGenStreamResponse{
Type: "content_block_start",
Index: &c.toolBlockIndex,
ContentBlock: &claudeTextGenContent{
Type: "tool_use",
Id: toolCall.Id,
Name: state.name,
Input: input,
},
})
}
}
} else {
// Still accumulating arguments
log.Debugf("[OpenAI->Claude] Tool arguments not yet complete, continuing to accumulate: %v", err)
}
}
}
}
}
// Handle finish reason
if choice.FinishReason != nil {
claudeFinishReason := openAIFinishReasonToClaude(*choice.FinishReason)
log.Debugf("[OpenAI->Claude] Processing finish_reason: %s -> %s", *choice.FinishReason, claudeFinishReason)
// Send content_block_stop for any active content blocks
if c.thinkingBlockStarted && !c.thinkingBlockStopped {
c.thinkingBlockStopped = true
log.Debugf("[OpenAI->Claude] Generated thinking content_block_stop event at index %d", c.thinkingBlockIndex)
responses = append(responses, &claudeTextGenStreamResponse{
Type: "content_block_stop",
Index: &c.thinkingBlockIndex,
})
}
if c.textBlockStarted && !c.textBlockStopped {
c.textBlockStopped = true
log.Debugf("[OpenAI->Claude] Generated text content_block_stop event at index %d", c.textBlockIndex)
responses = append(responses, &claudeTextGenStreamResponse{
Type: "content_block_stop",
Index: &c.textBlockIndex,
})
}
if c.toolBlockStarted && !c.toolBlockStopped {
c.toolBlockStopped = true
log.Debugf("[OpenAI->Claude] Generated tool content_block_stop event at index %d", c.toolBlockIndex)
responses = append(responses, &claudeTextGenStreamResponse{
Type: "content_block_stop",
Index: &c.toolBlockIndex,
})
}
// Cache stop_reason until we get usage info (Claude protocol requires them together)
c.pendingStopReason = &claudeFinishReason
log.Debugf("[OpenAI->Claude] Cached stop_reason: %s, waiting for usage", claudeFinishReason)
}
// Handle usage information
if openaiResponse.Usage != nil && choice.FinishReason == nil {
log.Debugf("[OpenAI->Claude] Processing usage info - input: %d, output: %d",
openaiResponse.Usage.PromptTokens, openaiResponse.Usage.CompletionTokens)
// Send message_delta with both stop_reason and usage (Claude protocol requirement)
messageDelta := &claudeTextGenStreamResponse{
Type: "message_delta",
Delta: &claudeTextGenDelta{
Type: "message_delta",
},
Usage: &claudeTextGenUsage{
InputTokens: openaiResponse.Usage.PromptTokens,
OutputTokens: openaiResponse.Usage.CompletionTokens,
},
}
// Include cached stop_reason if available
if c.pendingStopReason != nil {
log.Debugf("[OpenAI->Claude] Combining cached stop_reason %s with usage", *c.pendingStopReason)
messageDelta.Delta.StopReason = c.pendingStopReason
c.pendingStopReason = nil // Clear cache
}
log.Debugf("[OpenAI->Claude] Generated message_delta event with usage and stop_reason")
responses = append(responses, messageDelta)
// Send message_stop after combined message_delta
if !c.messageStopSent {
c.messageStopSent = true
log.Debugf("[OpenAI->Claude] Generated message_stop event")
responses = append(responses, &claudeTextGenStreamResponse{
Type: "message_stop",
})
}
}
return responses
}
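The tool-call path above buffers streamed argument fragments and treats the first successful JSON parse as the point where the `tool_use` block can be emitted. That accumulation idea, as a standalone sketch (hypothetical function, assumed well-formed fragment order):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// accumulateArgs concatenates streamed tool-call argument fragments and
// returns the parsed map as soon as the accumulated buffer is valid JSON.
func accumulateArgs(fragments []string) (map[string]interface{}, bool) {
	var buf string
	for _, f := range fragments {
		buf += f
		var input map[string]interface{}
		if err := json.Unmarshal([]byte(buf), &input); err == nil {
			return input, true
		}
	}
	return nil, false
}

func main() {
	input, ok := accumulateArgs([]string{`{"query":`, `"weather"`, `}`})
	fmt.Println(ok, input["query"])
}
```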
// openAIFinishReasonToClaude converts OpenAI finish reason to Claude format
func openAIFinishReasonToClaude(reason string) string {
switch reason {
case finishReasonStop:
return "end_turn"
case finishReasonLength:
return "max_tokens"
case finishReasonToolCall:
return "tool_use"
default:
return reason
}
}
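Assuming the `finishReason*` constants hold the usual OpenAI strings (`"stop"`, `"length"`, `"tool_calls"`), the mapping above behaves like this self-contained sketch:

```go
package main

import "fmt"

// toClaudeStopReason maps an OpenAI finish_reason to the corresponding
// Claude stop_reason, passing unknown values through unchanged.
func toClaudeStopReason(reason string) string {
	switch reason {
	case "stop":
		return "end_turn"
	case "length":
		return "max_tokens"
	case "tool_calls":
		return "tool_use"
	default:
		return reason
	}
}

func main() {
	fmt.Println(toClaudeStopReason("stop"), toClaudeStopReason("content_filter"))
}
```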
// convertContentArray converts an array of Claude content to OpenAI content format
func (c *ClaudeToOpenAIConverter) convertContentArray(claudeContents []claudeChatMessageContent) *contentConversionResult {
result := &contentConversionResult{
textParts: []string{},
toolCalls: []toolCall{},
toolResults: []claudeChatMessageContent{},
openaiContents: []chatMessageContent{},
hasNonTextContent: false,
}
for _, claudeContent := range claudeContents {
switch claudeContent.Type {
case "text":
if claudeContent.Text != "" {
result.textParts = append(result.textParts, claudeContent.Text)
result.openaiContents = append(result.openaiContents, chatMessageContent{
Type: contentTypeText,
Text: claudeContent.Text,
})
}
case "image":
result.hasNonTextContent = true
if claudeContent.Source != nil {
if claudeContent.Source.Type == "base64" {
// Convert base64 image to OpenAI format
dataUrl := fmt.Sprintf("data:%s;base64,%s", claudeContent.Source.MediaType, claudeContent.Source.Data)
result.openaiContents = append(result.openaiContents, chatMessageContent{
Type: contentTypeImageUrl,
ImageUrl: &chatMessageContentImageUrl{
Url: dataUrl,
},
})
} else if claudeContent.Source.Type == "url" {
result.openaiContents = append(result.openaiContents, chatMessageContent{
Type: contentTypeImageUrl,
ImageUrl: &chatMessageContentImageUrl{
Url: claudeContent.Source.Url,
},
})
}
}
case "tool_use":
result.hasNonTextContent = true
// Convert Claude tool_use to OpenAI tool_calls format
if claudeContent.Id != "" && claudeContent.Name != "" {
// Convert input to JSON string for OpenAI format
var argumentsStr string
if claudeContent.Input != nil {
if argBytes, err := json.Marshal(claudeContent.Input); err == nil {
argumentsStr = string(argBytes)
}
}
toolCall := toolCall{
Id: claudeContent.Id,
Type: "function",
Function: functionCall{
Name: claudeContent.Name,
Arguments: argumentsStr,
},
}
result.toolCalls = append(result.toolCalls, toolCall)
log.Debugf("[Claude->OpenAI] Converted tool_use to tool_call: %s", claudeContent.Name)
}
case "tool_result":
result.hasNonTextContent = true
// Store tool results for processing
result.toolResults = append(result.toolResults, claudeContent)
log.Debugf("[Claude->OpenAI] Found tool_result for tool_use_id: %s", claudeContent.ToolUseId)
}
}
return result
}
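The base64 image branch above rewrites a Claude image source into the `data:` URL form that the OpenAI `image_url` content type expects. A one-line sketch of that construction (hypothetical helper name):

```go
package main

import "fmt"

// dataURL builds the data URL used for OpenAI image_url content from a
// Claude base64 image source's media type and payload.
func dataURL(mediaType, b64 string) string {
	return fmt.Sprintf("data:%s;base64,%s", mediaType, b64)
}

func main() {
	fmt.Println(dataURL("image/jpeg", "AAAA"))
}
```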


@@ -0,0 +1,727 @@
package provider
import (
"encoding/json"
"testing"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Mock logger for testing
type mockLogger struct{}
func (m *mockLogger) Trace(msg string) {}
func (m *mockLogger) Tracef(format string, args ...interface{}) {}
func (m *mockLogger) Debug(msg string) {}
func (m *mockLogger) Debugf(format string, args ...interface{}) {}
func (m *mockLogger) Info(msg string) {}
func (m *mockLogger) Infof(format string, args ...interface{}) {}
func (m *mockLogger) Warn(msg string) {}
func (m *mockLogger) Warnf(format string, args ...interface{}) {}
func (m *mockLogger) Error(msg string) {}
func (m *mockLogger) Errorf(format string, args ...interface{}) {}
func (m *mockLogger) Critical(msg string) {}
func (m *mockLogger) Criticalf(format string, args ...interface{}) {}
func (m *mockLogger) ResetID(pluginID string) {}
func init() {
// Initialize mock logger for testing
log.SetPluginLog(&mockLogger{})
}
func TestClaudeToOpenAIConverter_ConvertClaudeRequestToOpenAI(t *testing.T) {
converter := &ClaudeToOpenAIConverter{}
t.Run("convert_multiple_text_content_blocks", func(t *testing.T) {
// Test case for the bug fix: multiple text content blocks should be merged into a single string
claudeRequest := `{
"max_tokens": 32000,
"messages": [{
"content": [{
"text": "<system-reminder>\nThis is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware. If you are working on tasks that would benefit from a todo list please use the TodoWrite tool to create one. If not, please feel free to ignore. Again do not mention this message to the user.</system-reminder>",
"type": "text"
}, {
"text": "<system-reminder>\nyyy</system-reminder>",
"type": "text"
}, {
"cache_control": {
"type": "ephemeral"
},
"text": "你是谁",
"type": "text"
}],
"role": "user"
}],
"metadata": {
"user_id": "user_dd3c52c1d698a4486bdef490197846b7c1f7e553202dae5763f330c35aeb9823_account__session_b2e14122-0ac6-4959-9c5d-b49ae01ccb7c"
},
"model": "anthropic/claude-sonnet-4",
"stream": true,
"system": [{
"cache_control": {
"type": "ephemeral"
},
"text": "xxx",
"type": "text"
}, {
"cache_control": {
"type": "ephemeral"
},
"text": "yyy",
"type": "text"
}],
"temperature": 1,
"stream_options": {
"include_usage": true
}
}`
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
require.NoError(t, err)
// Parse the result to verify the conversion
var openaiRequest chatCompletionRequest
err = json.Unmarshal(result, &openaiRequest)
require.NoError(t, err)
// Verify basic fields are converted correctly
assert.Equal(t, "anthropic/claude-sonnet-4", openaiRequest.Model)
assert.Equal(t, true, openaiRequest.Stream)
assert.Equal(t, 1.0, openaiRequest.Temperature)
assert.Equal(t, 32000, openaiRequest.MaxTokens)
// Verify messages structure
require.Len(t, openaiRequest.Messages, 2)
// First message should be system message (converted from Claude's system field)
systemMsg := openaiRequest.Messages[0]
assert.Equal(t, roleSystem, systemMsg.Role)
assert.Equal(t, "xxx\nyyy", systemMsg.Content) // Claude system uses single \n
// Second message should be user message with merged text content
userMsg := openaiRequest.Messages[1]
assert.Equal(t, "user", userMsg.Role)
// The key fix: multiple text blocks should be merged into a single string
expectedContent := "<system-reminder>\nThis is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware. If you are working on tasks that would benefit from a todo list please use the TodoWrite tool to create one. If not, please feel free to ignore. Again do not mention this message to the user.</system-reminder>\n\n<system-reminder>\nyyy</system-reminder>\n\n你是谁"
assert.Equal(t, expectedContent, userMsg.Content)
})
t.Run("convert_mixed_content_with_image", func(t *testing.T) {
// Test case with mixed text and image content (should remain as array)
claudeRequest := `{
"model": "claude-3-sonnet-20240229",
"messages": [{
"role": "user",
"content": [{
"type": "text",
"text": "What's in this image?"
}, {
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
}
}]
}],
"max_tokens": 1000
}`
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
require.NoError(t, err)
var openaiRequest chatCompletionRequest
err = json.Unmarshal(result, &openaiRequest)
require.NoError(t, err)
// Should have one user message
require.Len(t, openaiRequest.Messages, 1)
userMsg := openaiRequest.Messages[0]
assert.Equal(t, "user", userMsg.Role)
// Content should be an array (mixed content) - after JSON marshaling/unmarshaling it becomes []interface{}
contentArray, ok := userMsg.Content.([]interface{})
require.True(t, ok, "Content should be an array for mixed content")
require.Len(t, contentArray, 2)
// First element should be text
firstElement, ok := contentArray[0].(map[string]interface{})
require.True(t, ok)
assert.Equal(t, contentTypeText, firstElement["type"])
assert.Equal(t, "What's in this image?", firstElement["text"])
// Second element should be image
secondElement, ok := contentArray[1].(map[string]interface{})
require.True(t, ok)
assert.Equal(t, contentTypeImageUrl, secondElement["type"])
assert.NotNil(t, secondElement["image_url"])
imageUrl, ok := secondElement["image_url"].(map[string]interface{})
require.True(t, ok)
assert.Contains(t, imageUrl["url"], "data:image/jpeg;base64,")
})
t.Run("convert_simple_string_content", func(t *testing.T) {
// Test case with simple string content
claudeRequest := `{
"model": "claude-3-sonnet-20240229",
"messages": [{
"role": "user",
"content": "Hello, how are you?"
}],
"max_tokens": 1000
}`
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
require.NoError(t, err)
var openaiRequest chatCompletionRequest
err = json.Unmarshal(result, &openaiRequest)
require.NoError(t, err)
require.Len(t, openaiRequest.Messages, 1)
userMsg := openaiRequest.Messages[0]
assert.Equal(t, "user", userMsg.Role)
assert.Equal(t, "Hello, how are you?", userMsg.Content)
})
t.Run("convert_empty_content_array", func(t *testing.T) {
// Test case with empty content array
claudeRequest := `{
"model": "claude-3-sonnet-20240229",
"messages": [{
"role": "user",
"content": []
}],
"max_tokens": 1000
}`
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
require.NoError(t, err)
var openaiRequest chatCompletionRequest
err = json.Unmarshal(result, &openaiRequest)
require.NoError(t, err)
require.Len(t, openaiRequest.Messages, 1)
userMsg := openaiRequest.Messages[0]
assert.Equal(t, "user", userMsg.Role)
// Empty content should remain an array (not a string); after a JSON marshal/unmarshal round trip it becomes []interface{}
if userMsg.Content != nil {
contentArray, ok := userMsg.Content.([]interface{})
require.True(t, ok, "Empty content should be an array")
assert.Empty(t, contentArray)
} else {
// null is also acceptable for empty content
assert.Nil(t, userMsg.Content)
}
})
t.Run("convert_tool_use_to_tool_calls", func(t *testing.T) {
// Test Claude tool_use conversion to OpenAI tool_calls format
claudeRequest := `{
"model": "anthropic/claude-sonnet-4",
"messages": [{
"role": "assistant",
"content": [{
"type": "text",
"text": "I'll help you search for information."
}, {
"type": "tool_use",
"id": "toolu_01D7FLrfh4GYq7yT1ULFeyMV",
"name": "web_search",
"input": {
"query": "Claude AI capabilities",
"max_results": 5
}
}]
}],
"max_tokens": 1000
}`
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
require.NoError(t, err)
var openaiRequest chatCompletionRequest
err = json.Unmarshal(result, &openaiRequest)
require.NoError(t, err)
// Should have one assistant message with tool_calls
require.Len(t, openaiRequest.Messages, 1)
assistantMsg := openaiRequest.Messages[0]
assert.Equal(t, "assistant", assistantMsg.Role)
assert.Equal(t, "I'll help you search for information.", assistantMsg.Content)
// Verify tool_calls format
require.NotNil(t, assistantMsg.ToolCalls)
require.Len(t, assistantMsg.ToolCalls, 1)
toolCall := assistantMsg.ToolCalls[0]
assert.Equal(t, "toolu_01D7FLrfh4GYq7yT1ULFeyMV", toolCall.Id)
assert.Equal(t, "function", toolCall.Type)
assert.Equal(t, "web_search", toolCall.Function.Name)
// Verify arguments are properly JSON encoded
var args map[string]interface{}
err = json.Unmarshal([]byte(toolCall.Function.Arguments), &args)
require.NoError(t, err)
assert.Equal(t, "Claude AI capabilities", args["query"])
assert.Equal(t, float64(5), args["max_results"])
})
t.Run("convert_tool_result_to_tool_message", func(t *testing.T) {
// Test Claude tool_result conversion to OpenAI tool message format
claudeRequest := `{
"model": "anthropic/claude-sonnet-4",
"messages": [{
"role": "user",
"content": [{
"type": "tool_result",
"tool_use_id": "toolu_01D7FLrfh4GYq7yT1ULFeyMV",
"content": "Search results: Claude is an AI assistant created by Anthropic."
}]
}],
"max_tokens": 1000
}`
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
require.NoError(t, err)
var openaiRequest chatCompletionRequest
err = json.Unmarshal(result, &openaiRequest)
require.NoError(t, err)
// Should have one tool message
require.Len(t, openaiRequest.Messages, 1)
toolMsg := openaiRequest.Messages[0]
assert.Equal(t, "tool", toolMsg.Role)
assert.Equal(t, "Search results: Claude is an AI assistant created by Anthropic.", toolMsg.Content)
assert.Equal(t, "toolu_01D7FLrfh4GYq7yT1ULFeyMV", toolMsg.ToolCallId)
})
t.Run("convert_multiple_tool_calls", func(t *testing.T) {
// Test multiple tool_use in single message
claudeRequest := `{
"model": "anthropic/claude-sonnet-4",
"messages": [{
"role": "assistant",
"content": [{
"type": "tool_use",
"id": "toolu_search",
"name": "web_search",
"input": {"query": "weather"}
}, {
"type": "tool_use",
"id": "toolu_calc",
"name": "calculate",
"input": {"expression": "2+2"}
}]
}],
"max_tokens": 1000
}`
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
require.NoError(t, err)
var openaiRequest chatCompletionRequest
err = json.Unmarshal(result, &openaiRequest)
require.NoError(t, err)
// Should have one assistant message with multiple tool_calls
require.Len(t, openaiRequest.Messages, 1)
assistantMsg := openaiRequest.Messages[0]
assert.Equal(t, "assistant", assistantMsg.Role)
assert.Nil(t, assistantMsg.Content) // No text content, so should be null
// Verify multiple tool_calls
require.NotNil(t, assistantMsg.ToolCalls)
require.Len(t, assistantMsg.ToolCalls, 2)
// First tool call
assert.Equal(t, "toolu_search", assistantMsg.ToolCalls[0].Id)
assert.Equal(t, "web_search", assistantMsg.ToolCalls[0].Function.Name)
// Second tool call
assert.Equal(t, "toolu_calc", assistantMsg.ToolCalls[1].Id)
assert.Equal(t, "calculate", assistantMsg.ToolCalls[1].Function.Name)
})
t.Run("convert_multiple_tool_results", func(t *testing.T) {
// Test multiple tool_result messages
claudeRequest := `{
"model": "anthropic/claude-sonnet-4",
"messages": [{
"role": "user",
"content": [{
"type": "tool_result",
"tool_use_id": "toolu_search",
"content": "Weather: 25°C sunny"
}, {
"type": "tool_result",
"tool_use_id": "toolu_calc",
"content": "Result: 4"
}]
}],
"max_tokens": 1000
}`
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
require.NoError(t, err)
var openaiRequest chatCompletionRequest
err = json.Unmarshal(result, &openaiRequest)
require.NoError(t, err)
// Should have two tool messages
require.Len(t, openaiRequest.Messages, 2)
// First tool result
toolMsg1 := openaiRequest.Messages[0]
assert.Equal(t, "tool", toolMsg1.Role)
assert.Equal(t, "Weather: 25°C sunny", toolMsg1.Content)
assert.Equal(t, "toolu_search", toolMsg1.ToolCallId)
// Second tool result
toolMsg2 := openaiRequest.Messages[1]
assert.Equal(t, "tool", toolMsg2.Role)
assert.Equal(t, "Result: 4", toolMsg2.Content)
assert.Equal(t, "toolu_calc", toolMsg2.ToolCallId)
})
t.Run("convert_mixed_text_and_tool_use", func(t *testing.T) {
// Test message with both text and tool_use
claudeRequest := `{
"model": "anthropic/claude-sonnet-4",
"messages": [{
"role": "assistant",
"content": [{
"type": "text",
"text": "Let me search for that information and do a calculation."
}, {
"type": "tool_use",
"id": "toolu_search123",
"name": "search_database",
"input": {"table": "users", "limit": 10}
}]
}],
"max_tokens": 1000
}`
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(claudeRequest))
require.NoError(t, err)
var openaiRequest chatCompletionRequest
err = json.Unmarshal(result, &openaiRequest)
require.NoError(t, err)
// Should have one assistant message with both content and tool_calls
require.Len(t, openaiRequest.Messages, 1)
assistantMsg := openaiRequest.Messages[0]
assert.Equal(t, "assistant", assistantMsg.Role)
assert.Equal(t, "Let me search for that information and do a calculation.", assistantMsg.Content)
// Should have tool_calls
require.NotNil(t, assistantMsg.ToolCalls)
require.Len(t, assistantMsg.ToolCalls, 1)
assert.Equal(t, "toolu_search123", assistantMsg.ToolCalls[0].Id)
assert.Equal(t, "search_database", assistantMsg.ToolCalls[0].Function.Name)
})
}
func TestClaudeToOpenAIConverter_ConvertOpenAIResponseToClaude(t *testing.T) {
converter := &ClaudeToOpenAIConverter{}
t.Run("convert_tool_calls_response", func(t *testing.T) {
// Test OpenAI response with tool calls conversion to Claude format
openaiResponse := `{
"id": "gen-1756214072-tVFkPBV6lxee00IqNAC5",
"provider": "Google",
"model": "anthropic/claude-sonnet-4",
"object": "chat.completion",
"created": 1756214072,
"choices": [{
"logprobs": null,
"finish_reason": "tool_calls",
"native_finish_reason": "tool_calls",
"index": 0,
"message": {
"role": "assistant",
"content": "I'll analyze the README file to understand this project's purpose.",
"refusal": null,
"reasoning": null,
"tool_calls": [{
"id": "toolu_vrtx_017ijjgx8hpigatPzzPW59Wq",
"index": 0,
"type": "function",
"function": {
"name": "Read",
"arguments": "{\"file_path\": \"/Users/zhangty/git/higress/README.md\"}"
}
}]
}
}],
"usage": {
"prompt_tokens": 14923,
"completion_tokens": 81,
"total_tokens": 15004
}
}`
result, err := converter.ConvertOpenAIResponseToClaude(nil, []byte(openaiResponse))
require.NoError(t, err)
var claudeResponse claudeTextGenResponse
err = json.Unmarshal(result, &claudeResponse)
require.NoError(t, err)
// Verify basic response fields
assert.Equal(t, "gen-1756214072-tVFkPBV6lxee00IqNAC5", claudeResponse.Id)
assert.Equal(t, "message", claudeResponse.Type)
assert.Equal(t, "assistant", claudeResponse.Role)
assert.Equal(t, "anthropic/claude-sonnet-4", claudeResponse.Model)
assert.Equal(t, "tool_use", *claudeResponse.StopReason)
// Verify usage
assert.Equal(t, 14923, claudeResponse.Usage.InputTokens)
assert.Equal(t, 81, claudeResponse.Usage.OutputTokens)
// Verify content array has both text and tool_use
require.Len(t, claudeResponse.Content, 2)
// First content should be text
textContent := claudeResponse.Content[0]
assert.Equal(t, "text", textContent.Type)
assert.Equal(t, "I'll analyze the README file to understand this project's purpose.", textContent.Text)
// Second content should be tool_use
toolContent := claudeResponse.Content[1]
assert.Equal(t, "tool_use", toolContent.Type)
assert.Equal(t, "toolu_vrtx_017ijjgx8hpigatPzzPW59Wq", toolContent.Id)
assert.Equal(t, "Read", toolContent.Name)
// Verify tool arguments
require.NotNil(t, toolContent.Input)
assert.Equal(t, "/Users/zhangty/git/higress/README.md", toolContent.Input["file_path"])
})
}
func TestClaudeToOpenAIConverter_ConvertThinkingConfig(t *testing.T) {
converter := &ClaudeToOpenAIConverter{}
tests := []struct {
name string
claudeRequest string
expectedMaxTokens int
expectedEffort string
expectThinkingConfig bool
}{
{
name: "thinking_enabled_low",
claudeRequest: `{
"model": "claude-sonnet-4",
"max_tokens": 1000,
"messages": [{"role": "user", "content": "Hello"}],
"thinking": {"type": "enabled", "budget_tokens": 2048}
}`,
expectedMaxTokens: 2048,
expectedEffort: "low",
expectThinkingConfig: true,
},
{
name: "thinking_enabled_medium",
claudeRequest: `{
"model": "claude-sonnet-4",
"max_tokens": 1000,
"messages": [{"role": "user", "content": "Hello"}],
"thinking": {"type": "enabled", "budget_tokens": 8192}
}`,
expectedMaxTokens: 8192,
expectedEffort: "medium",
expectThinkingConfig: true,
},
{
name: "thinking_enabled_high",
claudeRequest: `{
"model": "claude-sonnet-4",
"max_tokens": 1000,
"messages": [{"role": "user", "content": "Hello"}],
"thinking": {"type": "enabled", "budget_tokens": 20480}
}`,
expectedMaxTokens: 20480,
expectedEffort: "high",
expectThinkingConfig: true,
},
{
name: "thinking_disabled",
claudeRequest: `{
"model": "claude-sonnet-4",
"max_tokens": 1000,
"messages": [{"role": "user", "content": "Hello"}],
"thinking": {"type": "disabled"}
}`,
expectedMaxTokens: 0,
expectedEffort: "",
expectThinkingConfig: false,
},
{
name: "no_thinking",
claudeRequest: `{
"model": "claude-sonnet-4",
"max_tokens": 1000,
"messages": [{"role": "user", "content": "Hello"}]
}`,
expectedMaxTokens: 0,
expectedEffort: "",
expectThinkingConfig: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := converter.ConvertClaudeRequestToOpenAI([]byte(tt.claudeRequest))
assert.NoError(t, err)
assert.NotNil(t, result)
var openaiRequest chatCompletionRequest
err = json.Unmarshal(result, &openaiRequest)
assert.NoError(t, err)
if tt.expectThinkingConfig {
assert.Equal(t, tt.expectedMaxTokens, openaiRequest.ReasoningMaxTokens)
assert.Equal(t, tt.expectedEffort, openaiRequest.ReasoningEffort)
} else {
assert.Equal(t, 0, openaiRequest.ReasoningMaxTokens)
assert.Equal(t, "", openaiRequest.ReasoningEffort)
}
})
}
}
func TestClaudeToOpenAIConverter_ConvertReasoningResponseToClaude(t *testing.T) {
converter := &ClaudeToOpenAIConverter{}
tests := []struct {
name string
openaiResponse string
expectThinking bool
expectedText string
}{
{
name: "response_with_reasoning_content",
openaiResponse: `{
"id": "chatcmpl-test123",
"object": "chat.completion",
"created": 1699999999,
"model": "gpt-4o",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Based on my analysis, the answer is 42.",
"reasoning_content": "Let me think about this step by step:\n1. The question asks about the meaning of life\n2. According to Douglas Adams, the answer is 42\n3. Therefore, 42 is the correct answer"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30
}
}`,
expectThinking: true,
expectedText: "Based on my analysis, the answer is 42.",
},
{
name: "response_with_reasoning_field",
openaiResponse: `{
"id": "chatcmpl-test789",
"object": "chat.completion",
"created": 1699999999,
"model": "gpt-4o",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Based on my analysis, the answer is 42.",
"reasoning": "Let me think about this step by step:\n1. The question asks about the meaning of life\n2. According to Douglas Adams, the answer is 42\n3. Therefore, 42 is the correct answer"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30
}
}`,
expectThinking: true,
expectedText: "Based on my analysis, the answer is 42.",
},
{
name: "response_without_reasoning_content",
openaiResponse: `{
"id": "chatcmpl-test456",
"object": "chat.completion",
"created": 1699999999,
"model": "gpt-4o",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "The answer is 42."
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 5,
"completion_tokens": 10,
"total_tokens": 15
}
}`,
expectThinking: false,
expectedText: "The answer is 42.",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := converter.ConvertOpenAIResponseToClaude(nil, []byte(tt.openaiResponse))
assert.NoError(t, err)
assert.NotNil(t, result)
var claudeResponse claudeTextGenResponse
err = json.Unmarshal(result, &claudeResponse)
assert.NoError(t, err)
// Verify response structure
assert.Equal(t, "message", claudeResponse.Type)
assert.Equal(t, "assistant", claudeResponse.Role)
assert.NotEmpty(t, claudeResponse.Id) // ID should be present
if tt.expectThinking {
// Should have both thinking and text content
assert.Len(t, claudeResponse.Content, 2)
// First should be thinking
thinkingContent := claudeResponse.Content[0]
assert.Equal(t, "thinking", thinkingContent.Type)
assert.Equal(t, "", thinkingContent.Signature) // OpenAI doesn't provide signature
assert.Contains(t, thinkingContent.Thinking, "Let me think about this step by step")
// Second should be text
textContent := claudeResponse.Content[1]
assert.Equal(t, "text", textContent.Type)
assert.Equal(t, tt.expectedText, textContent.Text)
} else {
// Should only have text content
assert.Len(t, claudeResponse.Content, 1)
textContent := claudeResponse.Content[0]
assert.Equal(t, "text", textContent.Type)
assert.Equal(t, tt.expectedText, textContent.Text)
}
})
}
}

View File

@@ -5,24 +5,21 @@ import (
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/wrapper"
)
// deepseekProvider is the provider for deepseek Ai service.
const (
deepseekDomain = "api.deepseek.com"
// TODO: docs: https://api-docs.deepseek.com/api/create-chat-completion
// according to the docs, the path should be /chat/completions; needs to be verified
deepseekChatCompletionPath = "/v1/chat/completions"
deepseekDomain = "api.deepseek.com"
deepseekAnthropicMessagesPath = "/anthropic/v1/messages"
)
type deepseekProviderInitializer struct {
}
type deepseekProviderInitializer struct{}
func (m *deepseekProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if config.apiTokens == nil || len(config.apiTokens) == 0 {
if len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
@@ -30,7 +27,9 @@ func (m *deepseekProviderInitializer) ValidateConfig(config *ProviderConfig) err
func (m *deepseekProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): deepseekChatCompletionPath,
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
string(ApiNameModels): PathOpenAIModels,
string(ApiNameAnthropicMessages): deepseekAnthropicMessagesPath,
}
}

View File

@@ -146,7 +146,7 @@ func (c *ProviderConfig) SetApiTokensFailover(activeProvider Provider) error {
return fmt.Errorf("failed to init apiTokens: %v", err)
}
wrapper.RegisteTickFunc(c.failover.healthCheckInterval, func() {
wrapper.RegisterTickFunc(c.failover.healthCheckInterval, func() {
// Only the Wasm VM that successfully acquires the lease will perform health check
if c.isFailoverEnabled() && c.tryAcquireOrRenewLease(vmID) {
log.Debugf("Successfully acquired or renewed lease for %v: %v", vmID, c.GetType())

View File

@@ -1,15 +1,21 @@
package provider
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/google/uuid"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
@@ -28,6 +34,12 @@ const (
geminiImageGenerationPath = "predict"
)
var geminiThinkingModels = map[string]bool{
"gemini-2.5-pro": true,
"gemini-2.5-flash": true,
"gemini-2.5-flash-lite": true,
}
type geminiProviderInitializer struct{}
func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error {
@@ -53,12 +65,17 @@ func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provi
return &geminiProvider{
config: config,
contextCache: createContextCache(&config),
client: wrapper.NewClusterClient(wrapper.RouteCluster{
Host: geminiDomain,
}),
}, nil
}
type geminiProvider struct {
config ProviderConfig
contextCache *contextCache
client wrapper.HttpClient
}
func (g *geminiProvider) GetProviderType() string {
@@ -77,11 +94,47 @@ func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
util.OverwriteRequestAuthorizationHeader(headers, "")
}
// to support multimodal input for gemini, we can't reuse the config's handleRequestBody
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !g.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body)
if g.config.firstByteTimeout != 0 && g.config.isStreamingAPI(apiName, body) {
err := proxywasm.ReplaceHttpRequestHeader("x-envoy-upstream-rq-first-byte-timeout-ms",
strconv.FormatUint(uint64(g.config.firstByteTimeout), 10))
if err != nil {
log.Errorf("failed to set timeout header: %v", err)
}
}
if g.config.IsOriginal() {
return types.ActionContinue, nil
}
headers := util.GetRequestHeaders()
request, err := g.TransformRequestBodyHeaders(ctx, apiName, body, headers)
if err != nil {
return types.ActionContinue, err
}
util.ReplaceRequestHeaders(headers)
if apiName == ApiNameChatCompletion {
if g.config.context != nil {
err = g.contextCache.GetContextFromFile(ctx, g, body)
if err == nil {
return types.ActionPause, nil
}
}
if action, err := g.processImageURL(ctx, request); err != nil {
return action, err
} else {
return action, replaceRequestBody(request)
}
}
return types.ActionContinue, replaceRequestBody(request)
}
func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
@@ -365,6 +418,7 @@ func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest)
Threshold: threshold,
})
}
geminiRequest := geminiGenerationContentRequest{
Contents: make([]geminiChatContent, 0, len(request.Messages)),
SafetySettings: safetySettings,
@@ -379,6 +433,13 @@ func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest)
},
}
if geminiThinkingModels[request.Model] {
geminiRequest.GenerationConfig.ThinkingConfig = &geminiThinkingConfig{
IncludeThoughts: true,
ThinkingBudget: g.config.geminiThinkingBudget,
}
}
if request.Tools != nil {
functions := make([]function, 0, len(request.Tools))
for _, tool := range request.Tools {
@@ -393,12 +454,21 @@ func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest)
// shouldAddDummyModelMessage := false
for _, message := range request.Messages {
content := geminiChatContent{
Role: message.Role,
Parts: []geminiPart{
{
Text: message.StringContent(),
},
},
Role: message.Role,
Parts: []geminiPart{},
}
for _, c := range message.ParseContent() {
switch c.Type {
case contentTypeText:
content.Parts = append(content.Parts, geminiPart{
Text: c.Text,
})
case contentTypeImageUrl:
content.Parts = append(content.Parts, g.handleContentTypeImageUrl(c.ImageUrl))
default:
log.Debugf("gemini does not currently support this content type: %s", c.Type)
}
}
// gemini has no assistant role; the API rejects any role other than user or model
@@ -417,6 +487,176 @@ func (g *geminiProvider) buildGeminiChatRequest(request *chatCompletionRequest)
return &geminiRequest
}
func (g *geminiProvider) countImageUrl(request *geminiGenerationContentRequest) int {
totalImages := 0
for _, c := range request.Contents {
for _, p := range c.Parts {
if p.InlineData != nil && g.isUrl(p.InlineData.Data) {
totalImages += 1
}
}
}
return totalImages
}
func (g *geminiProvider) processImageURL(ctx wrapper.HttpContext, body []byte) (types.Action, error) {
request := &geminiGenerationContentRequest{}
err := json.Unmarshal(body, request)
if err != nil {
log.Errorf("failed to unmarshal geminiGenerationContentRequest while handling multimodal content: %v", err)
return types.ActionContinue, err
}
var totalImages int
if totalImages = g.countImageUrl(request); totalImages == 0 {
// no images, return directly
return types.ActionContinue, replaceRequestBody(body)
}
if err := g.processImageURLWithCallback(ctx, body, totalImages, func(body []byte, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to get image while handle multi modal: %v", err)
util.ErrorHandler("ai-proxy.gemini.fetch_image_failed", err)
return
}
// replace the request
if err := replaceRequestBody(body); err != nil {
util.ErrorHandler("ai-proxy.gemini.replace_request_body_failed", err)
}
}); err != nil {
return types.ActionContinue, err
}
return types.ActionPause, nil
}
func (g *geminiProvider) processImageURLWithCallback(ctx wrapper.HttpContext, body []byte, totalImages int, callback func([]byte, error)) error {
request := &geminiGenerationContentRequest{}
err := json.Unmarshal(body, request)
if err != nil {
log.Errorf("failed to unmarshal geminiGenerationContentRequest while handling multimodal content: %v", err)
return err
}
pending := totalImages
var callbackErr []error
for ci, c := range request.Contents {
for pi := range c.Parts {
p := &request.Contents[ci].Parts[pi]
if p.InlineData != nil && g.isUrl(p.InlineData.Data) {
g.getImageInlineDataWithCallback(p.InlineData.Data, func(gid *geminiInlineData, err error) {
if err != nil {
log.Errorf("image %s fetch failed: %v", p.InlineData.Data, err)
callbackErr = append(callbackErr, err)
} else {
*p.InlineData = *gid
}
pending -= 1
if pending == 0 {
body, err := json.Marshal(request)
if err != nil {
log.Errorf("failed to marshal request while processImageURL: %v", err)
callbackErr = append(callbackErr, err)
}
callback(body, errors.Join(callbackErr...))
}
})
}
}
}
return nil
}
func (g *geminiProvider) handleContentTypeImageUrl(c *chatMessageContentImageUrl) (part geminiPart) {
if g.isUrl(c.Url) {
part.InlineData = &geminiInlineData{
Data: c.Url,
}
return
}
part.InlineData = g.baseStr2InlineData(c.Url)
return
}
func (g *geminiProvider) isUrl(raw string) bool {
u, err := url.Parse(raw)
return err == nil && (u.Scheme == "http" || u.Scheme == "https")
}
func (g *geminiProvider) baseStr2InlineData(baseStr string) *geminiInlineData {
if strings.HasPrefix(baseStr, "data:") {
p := strings.SplitN(baseStr, ";", 2)
if len(p) != 2 {
log.Errorf("invalid data URL: %s", baseStr)
return nil
}
mime := strings.TrimPrefix(p[0], "data:")
baseData := strings.TrimPrefix(p[1], "base64,")
return &geminiInlineData{
MimeType: mime,
Data: baseData,
}
}
log.Errorf("invalid base64 string: %s", baseStr)
return &geminiInlineData{
MimeType: "",
Data: "",
}
}
func (g *geminiProvider) getImageInlineDataWithCallback(raw string, callback func(*geminiInlineData, error)) {
responseCallback := func(statusCode int, responseHeaders http.Header, responseBody []byte) {
if statusCode != http.StatusOK {
callback(nil, fmt.Errorf("get %s failed, status: %v", raw, statusCode))
return
}
resReader := bytes.NewReader(responseBody)
const maxSize = 100 << 20
data, err := io.ReadAll(io.LimitReader(resReader, maxSize+1))
if err != nil {
callback(nil, fmt.Errorf("read %v response data failed: %v", raw, err))
return
}
if len(data) > maxSize {
callback(nil, fmt.Errorf("%v exceeds max image size of 100MB", raw))
return
}
mimeType := http.DetectContentType(data)
base64Data := base64.StdEncoding.EncodeToString(data)
callback(&geminiInlineData{
MimeType: mimeType,
Data: base64Data,
}, nil)
}
timeout := (time.Second * 30).Milliseconds()
headers := [][2]string{
{"Accept", "image/*"},
{"User-Agent", "Mozilla/5.0 (compatible; AI-Proxy/1.0)"},
{"Referer", "https://www.google.com/"},
}
if g.client == nil {
log.Error("client is nil")
// invoke the callback so the caller's pending counter is not left hanging
callback(nil, errors.New("http client is not initialized"))
return
}
err := g.client.Get(raw, headers, responseCallback, uint32(timeout))
if err != nil {
log.Errorf("failed to get image %s: %v", raw, err)
callback(nil, fmt.Errorf("failed to get image %s: %v", raw, err))
return
}
}
func (g *geminiProvider) setSystemContent(request *geminiGenerationContentRequest, content string) {
systemContents := []geminiChatContent{{
Role: roleUser,

View File

@@ -0,0 +1,75 @@
package provider
import (
"errors"
"net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
// grokProvider is the provider for Grok service.
const (
grokDomain = "api.x.ai"
grokChatCompletionPath = "/v1/chat/completions"
)
type grokProviderInitializer struct{}
func (g *grokProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
}
func (g *grokProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): grokChatCompletionPath,
}
}
func (g *grokProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(g.DefaultCapabilities())
return &grokProvider{
config: config,
contextCache: createContextCache(&config),
}, nil
}
type grokProvider struct {
config ProviderConfig
contextCache *contextCache
}
func (g *grokProvider) GetProviderType() string {
return providerTypeGrok
}
func (g *grokProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
g.config.handleRequestHeaders(g, ctx, apiName)
return nil
}
func (g *grokProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !g.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body)
}
func (g *grokProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), g.config.capabilities)
util.OverwriteRequestHostHeader(headers, grokDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}
func (g *grokProvider) GetApiName(path string) ApiName {
if strings.Contains(path, grokChatCompletionPath) {
return ApiNameChatCompletion
}
return ""
}

View File

@@ -8,10 +8,10 @@ import (
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/log"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)

View File

@@ -29,7 +29,12 @@ const (
reasoningEndTag = "</think>"
)
type NonOpenAIStyleOptions struct {
ReasoningMaxTokens int `json:"reasoning_max_tokens,omitempty"`
}
type chatCompletionRequest struct {
NonOpenAIStyleOptions
Messages []chatMessage `json:"messages"`
Model string `json:"model"`
Store bool `json:"store,omitempty"`
@@ -169,8 +174,11 @@ type chatMessage struct {
Role string `json:"role,omitempty"`
Content any `json:"content,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
Reasoning string `json:"reasoning,omitempty"` // For streaming responses
ToolCalls []toolCall `json:"tool_calls,omitempty"`
FunctionCall *functionCall `json:"function_call,omitempty"` // For legacy OpenAI format
Refusal string `json:"refusal,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
}
func (m *chatMessage) handleNonStreamingReasoningContent(reasoningContentMode string) {
@@ -377,14 +385,14 @@ func (m *chatMessage) ParseContent() []chatMessageContent {
}
type toolCall struct {
Index int `json:"index"`
Id string `json:"id"`
Index int `json:"index,omitempty"`
Id string `json:"id,omitempty"`
Type string `json:"type"`
Function functionCall `json:"function"`
}
type functionCall struct {
Id string `json:"id"`
Id string `json:"id,omitempty"`
Name string `json:"name"`
Arguments string `json:"arguments"`
}

View File

@@ -0,0 +1,117 @@
package provider
import (
"errors"
"net/http"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/wrapper"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// openrouterProvider is the provider for OpenRouter service.
const (
openrouterDomain = "openrouter.ai"
openrouterChatCompletionPath = "/api/v1/chat/completions"
openrouterCompletionPath = "/api/v1/completions"
)
type openrouterProviderInitializer struct{}
func (o *openrouterProviderInitializer) ValidateConfig(config *ProviderConfig) error {
if len(config.apiTokens) == 0 {
return errors.New("no apiToken found in provider config")
}
return nil
}
func (o *openrouterProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
string(ApiNameChatCompletion): openrouterChatCompletionPath,
string(ApiNameCompletion): openrouterCompletionPath,
}
}
func (o *openrouterProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
config.setDefaultCapabilities(o.DefaultCapabilities())
return &openrouterProvider{
config: config,
contextCache: createContextCache(&config),
}, nil
}
type openrouterProvider struct {
config ProviderConfig
contextCache *contextCache
}
func (o *openrouterProvider) GetProviderType() string {
return providerTypeOpenRouter
}
func (o *openrouterProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
o.config.handleRequestHeaders(o, ctx, apiName)
return nil
}
func (o *openrouterProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !o.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return o.config.handleRequestBody(o, o.contextCache, ctx, apiName, body)
}
func (o *openrouterProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), o.config.capabilities)
util.OverwriteRequestHostHeader(headers, openrouterDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+o.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}
func (o *openrouterProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return o.config.defaultTransformRequestBody(ctx, apiName, body)
}
// Check if ReasoningMaxTokens exists in the request body
reasoningMaxTokens := gjson.GetBytes(body, "reasoning_max_tokens")
if !reasoningMaxTokens.Exists() || reasoningMaxTokens.Int() == 0 {
// No reasoning_max_tokens, use default transformation
return o.config.defaultTransformRequestBody(ctx, apiName, body)
}
// Clear reasoning_effort field if it exists
modifiedBody, err := sjson.DeleteBytes(body, "reasoning_effort")
if err != nil {
// If delete fails, continue with original body
modifiedBody = body
}
// Set reasoning.max_tokens to the value of reasoning_max_tokens
modifiedBody, err = sjson.SetBytes(modifiedBody, "reasoning.max_tokens", reasoningMaxTokens.Int())
if err != nil {
return nil, err
}
// Remove the original reasoning_max_tokens field
modifiedBody, err = sjson.DeleteBytes(modifiedBody, "reasoning_max_tokens")
if err != nil {
return nil, err
}
// Apply default model mapping
return o.config.defaultTransformRequestBody(ctx, apiName, modifiedBody)
}
func (o *openrouterProvider) GetApiName(path string) ApiName {
if strings.Contains(path, openrouterChatCompletionPath) {
return ApiNameChatCompletion
}
if strings.Contains(path, openrouterCompletionPath) {
return ApiNameCompletion
}
return ""
}

View File

@@ -3,6 +3,7 @@ package provider
import (
"bytes"
"errors"
"fmt"
"math/rand"
"net/http"
"path"
@@ -107,6 +108,7 @@ const (
providerTypeQwen = "qwen"
providerTypeOpenAI = "openai"
providerTypeGroq = "groq"
providerTypeGrok = "grok"
providerTypeBaichuan = "baichuan"
providerTypeYi = "yi"
providerTypeDeepSeek = "deepseek"
@@ -129,6 +131,7 @@ const (
providerTypeDify = "dify"
providerTypeBedrock = "bedrock"
providerTypeVertex = "vertex"
providerTypeOpenRouter = "openrouter"
protocolOpenAI = "openai"
protocolOriginal = "original"
@@ -136,9 +139,11 @@ const (
roleSystem = "system"
roleAssistant = "assistant"
roleUser = "user"
roleTool = "tool"
finishReasonStop = "stop"
finishReasonLength = "length"
finishReasonStop = "stop"
finishReasonLength = "length"
finishReasonToolCall = "tool_calls"
ctxKeyIncrementalStreaming = "incrementalStreaming"
ctxKeyApiKey = "apiKey"
@@ -182,6 +187,7 @@ var (
providerTypeQwen: &qwenProviderInitializer{},
providerTypeOpenAI: &openaiProviderInitializer{},
providerTypeGroq: &groqProviderInitializer{},
providerTypeGrok: &grokProviderInitializer{},
providerTypeBaichuan: &baichuanProviderInitializer{},
providerTypeYi: &yiProviderInitializer{},
providerTypeDeepSeek: &deepseekProviderInitializer{},
@@ -204,6 +210,7 @@ var (
providerTypeDify: &difyProviderInitializer{},
providerTypeBedrock: &bedrockProviderInitializer{},
providerTypeVertex: &vertexProviderInitializer{},
providerTypeOpenRouter: &openrouterProviderInitializer{},
}
)
@@ -344,6 +351,9 @@ type ProviderConfig struct {
// @Title zh-CN Gemini AI内容过滤和安全级别设定
// @Description zh-CN 仅适用于 Gemini AI 服务。参考https://ai.google.dev/gemini-api/docs/safety-settings
geminiSafetySetting map[string]string `required:"false" yaml:"geminiSafetySetting" json:"geminiSafetySetting"`
// @Title zh-CN Gemini Thinking Budget 配置
// @Description zh-CN 仅适用于 Gemini AI 服务,用于控制思考预算
geminiThinkingBudget int64 `required:"false" yaml:"geminiThinkingBudget" json:"geminiThinkingBudget"`
// @Title zh-CN Vertex AI访问区域
// @Description zh-CN 仅适用于Vertex AI服务。如需查看支持的区域的完整列表请参阅https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations?hl=zh-cn#available-regions
vertexRegion string `required:"false" yaml:"vertexRegion" json:"vertexRegion"`
@@ -472,6 +482,7 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
c.geminiSafetySetting[k] = v.String()
}
}
c.geminiThinkingBudget = json.Get("geminiThinkingBudget").Int()
c.vertexRegion = json.Get("vertexRegion").String()
c.vertexProjectId = json.Get("vertexProjectId").String()
c.vertexAuthKey = json.Get("vertexAuthKey").String()
@@ -514,10 +525,9 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
c.reasoningContentMode = strings.ToLower(c.reasoningContentMode)
switch c.reasoningContentMode {
case reasoningBehaviorPassThrough, reasoningBehaviorIgnore, reasoningBehaviorConcat:
break
// valid values, no action needed
default:
c.reasoningContentMode = reasoningBehaviorPassThrough
break
}
}
@@ -824,6 +834,10 @@ func (c *ProviderConfig) isSupportedAPI(apiName ApiName) bool {
return exist
}
func (c *ProviderConfig) IsSupportedAPI(apiName ApiName) bool {
return c.isSupportedAPI(apiName)
}
func (c *ProviderConfig) setDefaultCapabilities(capabilities map[string]string) {
for capability, path := range capabilities {
c.capabilities[capability] = path
@@ -847,8 +861,22 @@ func (c *ProviderConfig) handleRequestBody(
return types.ActionContinue, nil
}
// use openai protocol
var err error
// handle claude protocol input - auto-detect based on conversion marker
// If main.go detected a Claude request that needs conversion, convert the body
needClaudeConversion, _ := ctx.GetContext("needClaudeResponseConversion").(bool)
if needClaudeConversion {
// Convert Claude protocol to OpenAI protocol
converter := &ClaudeToOpenAIConverter{}
body, err = converter.ConvertClaudeRequestToOpenAI(body)
if err != nil {
return types.ActionContinue, fmt.Errorf("failed to convert claude request to openai: %v", err)
}
log.Debugf("[Auto Protocol] converted Claude request body to OpenAI format")
}
// use openai protocol (either original openai or converted from claude)
if handler, ok := provider.(TransformRequestBodyHandler); ok {
body, err = handler.TransformRequestBody(ctx, apiName, body)
} else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {

View File

@@ -0,0 +1,718 @@
package test
import (
"encoding/json"
"strings"
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
)
// Test config: basic ai360 config
var basicAi360Config = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "ai360",
"apiTokens": []string{"sk-ai360-test123456789"},
"modelMapping": map[string]string{
"*": "360GPT_S2_V9",
},
},
})
return data
}()
// Test config: ai360 multi-model config
var ai360MultiModelConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "ai360",
"apiTokens": []string{"sk-ai360-multi-model"},
"modelMapping": map[string]string{
"gpt-3.5-turbo": "360GPT_S2_V9",
"gpt-4": "360GPT_S2_V9",
"text-embedding-ada-002": "360Embedding_Text_V1",
},
},
})
return data
}()
// Test config: invalid ai360 config (missing apiToken)
var invalidAi360Config = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "ai360",
// apiTokens is missing
},
})
return data
}()
// Test config: ai360 custom domain config
var ai360CustomDomainConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "ai360",
"apiTokens": []string{"sk-ai360-custom-domain"},
"modelMapping": map[string]string{
"*": "360GPT_S2_V9",
},
"openaiCustomUrl": "https://custom.ai360.cn/v1",
},
})
return data
}()
// Test config: complete ai360 config (including failover and related fields)
var completeAi360Config = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "ai360",
"apiTokens": []string{"sk-ai360-complete"},
"modelMapping": map[string]string{
"*": "360GPT_S2_V9",
},
"failover": map[string]interface{}{
"enabled": false,
},
"retryOnFailure": map[string]interface{}{
"enabled": false,
},
},
})
return data
}()
func RunAi360ParseConfigTests(t *testing.T) {
test.RunGoTest(t, func(t *testing.T) {
// Test parsing of the basic ai360 config
t.Run("basic ai360 config", func(t *testing.T) {
host, status := test.NewTestHost(basicAi360Config)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// Test parsing of the ai360 multi-model config
t.Run("ai360 multi model config", func(t *testing.T) {
host, status := test.NewTestHost(ai360MultiModelConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// Test the invalid ai360 config (missing apiToken)
t.Run("invalid ai360 config - missing api token", func(t *testing.T) {
host, status := test.NewTestHost(invalidAi360Config)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusFailed, status)
})
// Test parsing of the ai360 custom-domain config
t.Run("ai360 custom domain config", func(t *testing.T) {
host, status := test.NewTestHost(ai360CustomDomainConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// Test parsing of the complete ai360 config
t.Run("ai360 complete config", func(t *testing.T) {
host, status := test.NewTestHost(completeAi360Config)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
})
}
func RunAi360OnHttpRequestHeadersTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test ai360 request-header handling: chat completion endpoint
t.Run("ai360 chat completion request headers", func(t *testing.T) {
host, status := test.NewTestHost(basicAi360Config)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set the request headers
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Should return HeaderStopIteration because the request body still needs processing
require.Equal(t, types.HeaderStopIteration, action)
// Verify the request headers were processed correctly
requestHeaders := host.GetRequestHeaders()
require.NotNil(t, requestHeaders)
// Verify the Host was rewritten to the ai360 domain
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
require.True(t, hasHost, "Host header should exist")
require.Equal(t, "api.360.cn", hostValue, "Host should be changed to ai360 domain")
// Verify the Authorization header was set
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
require.True(t, hasAuth, "Authorization header should exist")
require.Contains(t, authValue, "sk-ai360-test123456789", "Authorization should contain ai360 API token")
// Verify the Path was handled correctly
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
require.True(t, hasPath, "Path header should exist")
// ai360 should support the chat completion endpoint; the path may be rewritten
require.Contains(t, pathValue, "/v1/chat/completions", "Path should contain chat completions endpoint")
// Check for related processing logs
debugLogs := host.GetDebugLogs()
hasAi360Logs := false
for _, log := range debugLogs {
if strings.Contains(log, "ai360") {
hasAi360Logs = true
break
}
}
require.True(t, hasAi360Logs, "Should have ai360 processing logs")
})
// Test ai360 request-header handling: embeddings endpoint
t.Run("ai360 embeddings request headers", func(t *testing.T) {
host, status := test.NewTestHost(basicAi360Config)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set the request headers
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/embeddings"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
// Verify request-header handling for the embeddings endpoint
requestHeaders := host.GetRequestHeaders()
require.NotNil(t, requestHeaders)
// Verify the Host rewrite
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
require.True(t, hasHost)
require.Equal(t, "api.360.cn", hostValue)
// Verify the Path rewrite
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
require.True(t, hasPath)
require.Contains(t, pathValue, "/v1/embeddings", "Path should contain embeddings endpoint")
// Verify the Authorization header was set
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
require.True(t, hasAuth, "Authorization header should exist for embeddings")
require.Contains(t, authValue, "sk-ai360-test123456789", "Authorization should contain ai360 API token")
})
// Test ai360 request-header handling: unsupported endpoint
t.Run("ai360 unsupported api request headers", func(t *testing.T) {
host, status := test.NewTestHost(basicAi360Config)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set the request headers
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/generations"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
// Verify handling of an unsupported endpoint
// Even for unsupported endpoints, the basic header rewrites should still run
requestHeaders := host.GetRequestHeaders()
require.NotNil(t, requestHeaders)
// The Host should still be rewritten
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
require.True(t, hasHost)
require.Equal(t, "api.360.cn", hostValue)
})
})
}
func RunAi360OnHttpRequestBodyTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test ai360 request-body handling: chat completion endpoint
t.Run("ai360 chat completion request body", func(t *testing.T) {
host, status := test.NewTestHost(basicAi360Config)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set the request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
requestBody := `{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"test"}]}`
action := host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
// Verify the request body was processed correctly
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
// Verify the model name was mapped correctly
// The ai360 provider maps gpt-3.5-turbo to 360GPT_S2_V9
require.Contains(t, string(processedBody), "360GPT_S2_V9", "Model name should be mapped to ai360 format")
// Check for related processing logs
debugLogs := host.GetDebugLogs()
infoLogs := host.GetInfoLogs()
// Verify there are ai360-related processing logs
hasAi360Logs := false
for _, log := range debugLogs {
if strings.Contains(log, "ai360") {
hasAi360Logs = true
break
}
}
for _, log := range infoLogs {
if strings.Contains(log, "ai360") {
hasAi360Logs = true
break
}
}
require.True(t, hasAi360Logs, "Should have ai360 processing logs")
})
// Test ai360 request-body handling: embeddings endpoint
t.Run("ai360 embeddings request body", func(t *testing.T) {
host, status := test.NewTestHost(basicAi360Config)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set the request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/embeddings"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
requestBody := `{"model":"text-embedding-ada-002","input":"test text"}`
action := host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
// Verify request-body handling for the embeddings endpoint
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
// Verify the model mapping
// Via the wildcard mapping, the ai360 provider maps text-embedding-ada-002 to 360GPT_S2_V9
require.Contains(t, string(processedBody), "360GPT_S2_V9", "Model name should be mapped to ai360 format")
// Check the processing logs
debugLogs := host.GetDebugLogs()
hasEmbeddingLogs := false
for _, log := range debugLogs {
if strings.Contains(log, "embeddings") || strings.Contains(log, "ai360") {
hasEmbeddingLogs = true
break
}
}
require.True(t, hasEmbeddingLogs, "Should have embedding processing logs")
})
// Test ai360 request-body handling: unsupported endpoint
t.Run("ai360 unsupported api request body", func(t *testing.T) {
host, status := test.NewTestHost(basicAi360Config)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set the request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/generations"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
requestBody := `{"model":"dall-e-3","prompt":"test image"}`
action := host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
// Verify handling of an unsupported endpoint
// Verify the request body was not modified unexpectedly
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
require.Contains(t, string(processedBody), "dall-e-3", "Request body should not be modified for unsupported APIs")
})
})
}
func RunAi360OnHttpResponseHeadersTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test ai360 response-header handling: chat completion endpoint
t.Run("ai360 chat completion response headers", func(t *testing.T) {
host, status := test.NewTestHost(basicAi360Config)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set the request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
requestBody := `{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"test"}]}`
host.CallOnHttpRequestBody([]byte(requestBody))
// Set the response headers
responseHeaders := [][2]string{
{":status", "200"},
{"Content-Type", "application/json"},
{"X-Request-Id", "req-123"},
}
action := host.CallOnHttpResponseHeaders(responseHeaders)
require.Equal(t, types.ActionContinue, action)
// Verify the response headers were processed correctly
processedResponseHeaders := host.GetResponseHeaders()
require.NotNil(t, processedResponseHeaders)
// Verify the status code
statusValue, hasStatus := test.GetHeaderValue(processedResponseHeaders, ":status")
require.True(t, hasStatus, "Status header should exist")
require.Equal(t, "200", statusValue, "Status should be 200")
// Verify the Content-Type
contentTypeValue, hasContentType := test.GetHeaderValue(processedResponseHeaders, "Content-Type")
require.True(t, hasContentType, "Content-Type header should exist")
require.Equal(t, "application/json", contentTypeValue, "Content-Type should be application/json")
// Check for related processing logs
debugLogs := host.GetDebugLogs()
hasResponseLogs := false
for _, log := range debugLogs {
if strings.Contains(log, "response") || strings.Contains(log, "ai360") {
hasResponseLogs = true
break
}
}
require.True(t, hasResponseLogs, "Should have response processing logs")
})
// Test ai360 response-header handling: embeddings endpoint
t.Run("ai360 embeddings response headers", func(t *testing.T) {
host, status := test.NewTestHost(basicAi360Config)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set the request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/embeddings"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
requestBody := `{"model":"text-embedding-ada-002","input":"test text"}`
host.CallOnHttpRequestBody([]byte(requestBody))
// Set the response headers
responseHeaders := [][2]string{
{":status", "200"},
{"Content-Type", "application/json"},
{"X-Embedding-Model", "360Embedding_Text_V1"},
}
action := host.CallOnHttpResponseHeaders(responseHeaders)
require.Equal(t, types.ActionContinue, action)
// Verify the response-header handling
processedResponseHeaders := host.GetResponseHeaders()
require.NotNil(t, processedResponseHeaders)
// Verify the embedding model information
modelValue, hasModel := test.GetHeaderValue(processedResponseHeaders, "X-Embedding-Model")
require.True(t, hasModel, "Embedding model header should exist")
require.Equal(t, "360Embedding_Text_V1", modelValue, "Embedding model should match configuration")
})
// Test ai360 response-header handling: error response
t.Run("ai360 error response headers", func(t *testing.T) {
host, status := test.NewTestHost(basicAi360Config)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set the request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
requestBody := `{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"test"}]}`
host.CallOnHttpRequestBody([]byte(requestBody))
// Set the error response headers
errorResponseHeaders := [][2]string{
{":status", "429"},
{"Content-Type", "application/json"},
{"Retry-After", "60"},
}
action := host.CallOnHttpResponseHeaders(errorResponseHeaders)
require.Equal(t, types.ActionContinue, action)
// Verify the error response-header handling
processedResponseHeaders := host.GetResponseHeaders()
require.NotNil(t, processedResponseHeaders)
// Verify the error status code
statusValue, hasStatus := test.GetHeaderValue(processedResponseHeaders, ":status")
require.True(t, hasStatus, "Status header should exist")
require.Equal(t, "429", statusValue, "Status should be 429 (Too Many Requests)")
// Verify the retry information
retryValue, hasRetry := test.GetHeaderValue(processedResponseHeaders, "Retry-After")
require.True(t, hasRetry, "Retry-After header should exist")
require.Equal(t, "60", retryValue, "Retry-After should be 60 seconds")
})
})
}
func RunAi360OnHttpResponseBodyTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test ai360 response-body handling: chat completion endpoint
t.Run("ai360 chat completion response body", func(t *testing.T) {
host, status := test.NewTestHost(basicAi360Config)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set the request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
requestBody := `{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"test"}]}`
host.CallOnHttpRequestBody([]byte(requestBody))
// Set the response headers
responseHeaders := [][2]string{
{":status", "200"},
{"Content-Type", "application/json"},
}
host.CallOnHttpResponseHeaders(responseHeaders)
// Set the response body
responseBody := `{
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": "gpt-3.5-turbo",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello! How can I help you today?"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 9,
"completion_tokens": 12,
"total_tokens": 21
}
}`
action := host.CallOnHttpResponseBody([]byte(responseBody))
require.Equal(t, types.ActionContinue, action)
// Verify the response body was processed correctly
processedResponseBody := host.GetResponseBody()
require.NotNil(t, processedResponseBody)
// Verify the response body content
responseStr := string(processedResponseBody)
require.Contains(t, responseStr, "chat.completion", "Response should contain chat completion object")
require.Contains(t, responseStr, "assistant", "Response should contain assistant role")
require.Contains(t, responseStr, "usage", "Response should contain usage information")
// Check for related processing logs
debugLogs := host.GetDebugLogs()
hasResponseBodyLogs := false
for _, log := range debugLogs {
if strings.Contains(log, "response") || strings.Contains(log, "body") || strings.Contains(log, "ai360") {
hasResponseBodyLogs = true
break
}
}
require.True(t, hasResponseBodyLogs, "Should have response body processing logs")
})
// Test ai360 response-body handling: embeddings endpoint
t.Run("ai360 embeddings response body", func(t *testing.T) {
host, status := test.NewTestHost(basicAi360Config)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set the request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/embeddings"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
requestBody := `{"model":"text-embedding-ada-002","input":"test text"}`
host.CallOnHttpRequestBody([]byte(requestBody))
// Set the response headers
responseHeaders := [][2]string{
{":status", "200"},
{"Content-Type", "application/json"},
}
host.CallOnHttpResponseHeaders(responseHeaders)
// Set the response body
responseBody := `{
"object": "list",
"data": [{
"object": "embedding",
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": 0
}],
"model": "text-embedding-ada-002",
"usage": {
"prompt_tokens": 5,
"total_tokens": 5
}
}`
action := host.CallOnHttpResponseBody([]byte(responseBody))
require.Equal(t, types.ActionContinue, action)
// Verify the response-body handling
processedResponseBody := host.GetResponseBody()
require.NotNil(t, processedResponseBody)
// Verify the embeddings response content
responseStr := string(processedResponseBody)
require.Contains(t, responseStr, "embedding", "Response should contain embedding object")
require.Contains(t, responseStr, "0.1", "Response should contain embedding vector")
require.Contains(t, responseStr, "text-embedding-ada-002", "Response should contain model name")
})
})
}
func RunAi360OnStreamingResponseBodyTests(t *testing.T) {
// Test ai360 streaming response-body handling
test.RunTest(t, func(t *testing.T) {
t.Run("ai360 streaming response body", func(t *testing.T) {
host, status := test.NewTestHost(basicAi360Config)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set the request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the streaming request body
requestBody := `{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"test"}],"stream":true}`
host.CallOnHttpRequestBody([]byte(requestBody))
// Set the streaming response headers
responseHeaders := [][2]string{
{":status", "200"},
{"Content-Type", "text/event-stream"},
}
host.CallOnHttpResponseHeaders(responseHeaders)
// Simulate the streaming response body
chunk1 := `data: {"id":"chatcmpl-123","object":"chat.completion.chunk","choices":[{"delta":{"role":"assistant"},"index":0}]}
`
chunk2 := `data: {"id":"chatcmpl-123","object":"chat.completion.chunk","choices":[{"delta":{"content":"Hello"},"index":0}]}
`
chunk3 := `data: {"id":"chatcmpl-123","object":"chat.completion.chunk","choices":[{"delta":{"content":"!"},"index":0}]}
`
chunk4 := `data: [DONE]
`
// Process the streaming response body chunks
action1 := host.CallOnHttpStreamingResponseBody([]byte(chunk1), false)
require.Equal(t, types.ActionContinue, action1)
action2 := host.CallOnHttpStreamingResponseBody([]byte(chunk2), false)
require.Equal(t, types.ActionContinue, action2)
action3 := host.CallOnHttpStreamingResponseBody([]byte(chunk3), false)
require.Equal(t, types.ActionContinue, action3)
action4 := host.CallOnHttpStreamingResponseBody([]byte(chunk4), true)
require.Equal(t, types.ActionContinue, action4)
// Verify the streaming response handling
// Note: streaming responses may not accumulate in GetResponseBody; verify via logs or other means
debugLogs := host.GetDebugLogs()
hasStreamingLogs := false
for _, log := range debugLogs {
if strings.Contains(log, "streaming") || strings.Contains(log, "chunk") || strings.Contains(log, "ai360") {
hasStreamingLogs = true
break
}
}
require.True(t, hasStreamingLogs, "Should have streaming response processing logs")
})
})
}
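The streaming test above feeds `data:`-prefixed SSE chunks one at a time. As a rough sketch of what a streaming handler has to do with each chunk, assuming the OpenAI chunk JSON shape shown in the test (the `extractContent` helper is illustrative, not the plugin's code):

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// chunkDelta mirrors the subset of the OpenAI streaming chunk
// format that carries incremental text.
type chunkDelta struct {
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
	} `json:"choices"`
}

// extractContent pulls the incremental text out of one SSE line,
// returning "" for the terminating [DONE] marker or non-data lines.
func extractContent(line string) string {
	trimmed := strings.TrimSpace(line)
	payload := strings.TrimPrefix(trimmed, "data: ")
	if payload == "" || payload == "[DONE]" || payload == trimmed {
		return ""
	}
	var c chunkDelta
	if err := json.Unmarshal([]byte(payload), &c); err != nil {
		return ""
	}
	var sb strings.Builder
	for _, choice := range c.Choices {
		sb.WriteString(choice.Delta.Content)
	}
	return sb.String()
}

func main() {
	chunks := []string{
		`data: {"choices":[{"delta":{"content":"Hello"},"index":0}]}`,
		`data: {"choices":[{"delta":{"content":"!"},"index":0}]}`,
		`data: [DONE]`,
	}
	var full strings.Builder
	for _, c := range chunks {
		full.WriteString(extractContent(c))
	}
	fmt.Println(full.String())
}
```

Reassembling the deltas this way is also why the test asserts via logs rather than GetResponseBody: each chunk is forwarded as it arrives instead of being buffered into one body.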


@@ -0,0 +1,600 @@
package test
import (
"encoding/json"
"strings"
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test config: basic Azure OpenAI config
var basicAzureConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "azure",
"apiTokens": []string{
"sk-azure-test123456789",
},
"azureServiceUrl": "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-02-15-preview",
"modelMapping": map[string]string{
"*": "gpt-3.5-turbo",
},
},
})
return data
}()
// Test config: Azure OpenAI full-path config
var azureFullPathConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "azure",
"apiTokens": []string{
"sk-azure-fullpath",
},
"azureServiceUrl": "https://fullpath-resource.openai.azure.com/openai/deployments/fullpath-deployment/chat/completions?api-version=2024-02-15-preview",
"modelMapping": map[string]string{
"gpt-3.5-turbo": "gpt-3.5-turbo",
"gpt-4": "gpt-4",
},
},
})
return data
}()
// Test config: Azure OpenAI deployment-only config
var azureDeploymentOnlyConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "azure",
"apiTokens": []string{
"sk-azure-deployment",
},
"azureServiceUrl": "https://deployment-resource.openai.azure.com/openai/deployments/deployment-only?api-version=2024-02-15-preview",
"modelMapping": map[string]string{
"*": "gpt-3.5-turbo",
},
},
})
return data
}()
// Test config: Azure OpenAI domain-only config
var azureDomainOnlyConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "azure",
"apiTokens": []string{
"sk-azure-domain",
},
"azureServiceUrl": "https://domain-resource.openai.azure.com?api-version=2024-02-15-preview",
"modelMapping": map[string]string{
"*": "gpt-3.5-turbo",
},
},
})
return data
}()
// Test config: Azure OpenAI multi-model config
var azureMultiModelConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "azure",
"apiTokens": []string{
"sk-azure-multi",
},
"azureServiceUrl": "https://multi-resource.openai.azure.com/openai/deployments/multi-deployment?api-version=2024-02-15-preview",
"modelMapping": map[string]string{
"gpt-3.5-turbo": "gpt-3.5-turbo",
"gpt-4": "gpt-4",
"text-embedding-ada-002": "text-embedding-ada-002",
},
},
})
return data
}()
// Test config: invalid Azure OpenAI config (missing azureServiceUrl)
var azureInvalidConfigMissingUrl = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "azure",
"apiTokens": []string{
"sk-azure-invalid",
},
"modelMapping": map[string]string{
"*": "gpt-3.5-turbo",
},
},
})
return data
}()
// Test config: invalid Azure OpenAI config (missing api-version)
var azureInvalidConfigMissingApiVersion = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "azure",
"apiTokens": []string{
"sk-azure-invalid",
},
"azureServiceUrl": "https://invalid-resource.openai.azure.com/openai/deployments/invalid-deployment/chat/completions",
"modelMapping": map[string]string{
"*": "gpt-3.5-turbo",
},
},
})
return data
}()
// Test config: invalid Azure OpenAI config (missing apiToken)
var azureInvalidConfigMissingToken = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "azure",
"azureServiceUrl": "https://invalid-resource.openai.azure.com/openai/deployments/invalid-deployment/chat/completions?api-version=2024-02-15-preview",
"modelMapping": map[string]interface{}{
"*": "gpt-3.5-turbo",
},
},
})
return data
}()
func RunAzureParseConfigTests(t *testing.T) {
test.RunGoTest(t, func(t *testing.T) {
// Test parsing of the basic Azure OpenAI config
t.Run("basic azure config", func(t *testing.T) {
host, status := test.NewTestHost(basicAzureConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// Test parsing of the Azure OpenAI full-path config
t.Run("azure full path config", func(t *testing.T) {
host, status := test.NewTestHost(azureFullPathConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// Test parsing of the Azure OpenAI deployment-only config
t.Run("azure deployment only config", func(t *testing.T) {
host, status := test.NewTestHost(azureDeploymentOnlyConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// Test parsing of the Azure OpenAI domain-only config
t.Run("azure domain only config", func(t *testing.T) {
host, status := test.NewTestHost(azureDomainOnlyConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// Test parsing of the Azure OpenAI multi-model config
t.Run("azure multi model config", func(t *testing.T) {
host, status := test.NewTestHost(azureMultiModelConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// Test the invalid Azure OpenAI config (missing azureServiceUrl)
t.Run("azure invalid config missing url", func(t *testing.T) {
host, status := test.NewTestHost(azureInvalidConfigMissingUrl)
defer host.Reset()
// Should fail because azureServiceUrl is missing
require.Equal(t, types.OnPluginStartStatusFailed, status)
})
// Test the invalid Azure OpenAI config (missing api-version)
t.Run("azure invalid config missing api version", func(t *testing.T) {
host, status := test.NewTestHost(azureInvalidConfigMissingApiVersion)
defer host.Reset()
// Should fail because api-version is missing
require.Equal(t, types.OnPluginStartStatusFailed, status)
})
// Test the invalid Azure OpenAI config (missing apiToken)
t.Run("azure invalid config missing token", func(t *testing.T) {
host, status := test.NewTestHost(azureInvalidConfigMissingToken)
defer host.Reset()
// Should fail because apiToken is missing
require.Equal(t, types.OnPluginStartStatusFailed, status)
})
})
}
func RunAzureOnHttpRequestHeadersTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test Azure OpenAI request-header handling: chat completion endpoint
t.Run("azure chat completion request headers", func(t *testing.T) {
host, status := test.NewTestHost(basicAzureConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set the request headers
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Should return HeaderStopIteration because the request body still needs processing
require.Equal(t, types.HeaderStopIteration, action)
// Verify the request headers were processed correctly
requestHeaders := host.GetRequestHeaders()
require.NotNil(t, requestHeaders)
// Verify the Host was rewritten to the Azure service domain
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
require.True(t, hasHost, "Host header should exist")
require.Equal(t, "test-resource.openai.azure.com", hostValue, "Host should be changed to Azure service domain")
// Verify the api-key header was set
apiKeyValue, hasApiKey := test.GetHeaderValue(requestHeaders, "api-key")
require.True(t, hasApiKey, "api-key header should exist")
require.Equal(t, "sk-azure-test123456789", apiKeyValue, "api-key should contain Azure API token")
// Verify the Path was handled correctly
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
require.True(t, hasPath, "Path header should exist")
require.Contains(t, pathValue, "/openai/deployments/test-deployment/chat/completions", "Path should contain Azure deployment path")
// Verify Content-Length was removed
_, hasContentLength := test.GetHeaderValue(requestHeaders, "Content-Length")
require.False(t, hasContentLength, "Content-Length header should be deleted")
// Check for related processing logs
debugLogs := host.GetDebugLogs()
hasAzureLogs := false
for _, log := range debugLogs {
if strings.Contains(log, "azureProvider") {
hasAzureLogs = true
break
}
}
assert.True(t, hasAzureLogs, "Should have Azure provider debug logs")
})
// Test Azure OpenAI request-header handling: full-path config
t.Run("azure full path request headers", func(t *testing.T) {
host, status := test.NewTestHost(azureFullPathConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set the request headers
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
// Verify the request headers were processed correctly
requestHeaders := host.GetRequestHeaders()
require.NotNil(t, requestHeaders)
// Verify the Host was rewritten to the Azure service domain
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
require.True(t, hasHost, "Host header should exist")
require.Equal(t, "fullpath-resource.openai.azure.com", hostValue, "Host should be changed to Azure service domain")
// Verify the api-key header was set
apiKeyValue, hasApiKey := test.GetHeaderValue(requestHeaders, "api-key")
require.True(t, hasApiKey, "api-key header should exist")
require.Equal(t, "sk-azure-fullpath", apiKeyValue, "api-key should contain Azure API token")
})
})
}
func RunAzureOnHttpRequestBodyTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test Azure OpenAI request-body handling: chat completion endpoint
t.Run("azure chat completion request body", func(t *testing.T) {
host, status := test.NewTestHost(basicAzureConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set the request headers
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
// Set the request body
requestBody := `{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "Hello, how are you?"
}
],
"temperature": 0.7
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
// Verify the request body was processed correctly
transformedBody := host.GetRequestBody()
require.NotNil(t, transformedBody)
// Verify the model mapping took effect
var bodyMap map[string]interface{}
err := json.Unmarshal(transformedBody, &bodyMap)
require.NoError(t, err)
model, exists := bodyMap["model"]
require.True(t, exists, "Model should exist in request body")
require.Equal(t, "gpt-3.5-turbo", model, "Model should be mapped correctly")
// Verify the request path was rewritten correctly
requestHeaders := host.GetRequestHeaders()
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
require.True(t, hasPath, "Path header should exist")
require.Contains(t, pathValue, "/openai/deployments/test-deployment/chat/completions", "Path should contain Azure deployment path")
require.Contains(t, pathValue, "api-version=2024-02-15-preview", "Path should contain API version")
})
// Test Azure OpenAI request-body handling: a different model
t.Run("azure different model request body", func(t *testing.T) {
host, status := test.NewTestHost(azureMultiModelConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set the request headers
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
// Set the request body
requestBody := `{
"model": "gpt-4",
"messages": [
{
"role": "user",
"content": "Explain quantum computing"
}
]
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
// Verify the request body was processed correctly
transformedBody := host.GetRequestBody()
require.NotNil(t, transformedBody)
var bodyMap map[string]interface{}
err := json.Unmarshal(transformedBody, &bodyMap)
require.NoError(t, err)
model, exists := bodyMap["model"]
require.True(t, exists, "Model should exist in request body")
require.Equal(t, "gpt-4", model, "Model should be mapped correctly")
})
// Test Azure OpenAI request-body handling: deployment-only config
t.Run("azure deployment only request body", func(t *testing.T) {
host, status := test.NewTestHost(azureDeploymentOnlyConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set the request headers
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
// Set the request body
requestBody := `{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "Test message"
}
]
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
// Verify the request path uses the default deployment
requestHeaders := host.GetRequestHeaders()
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
require.True(t, hasPath, "Path header should exist")
require.Contains(t, pathValue, "/openai/deployments/deployment-only/chat/completions", "Path should use default deployment")
})
// Test Azure OpenAI request-body handling: domain-only config
t.Run("azure domain only request body", func(t *testing.T) {
host, status := test.NewTestHost(azureDomainOnlyConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set the request headers
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
// Set the request body
requestBody := `{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "Test message"
}
]
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
// Verify the request path uses the model from the request body
requestHeaders := host.GetRequestHeaders()
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
require.True(t, hasPath, "Path header should exist")
require.Contains(t, pathValue, "/openai/deployments/gpt-3.5-turbo/chat/completions", "Path should use model from request body")
})
})
}
func RunAzureOnHttpResponseHeadersTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test Azure OpenAI response header handling
t.Run("azure response headers", func(t *testing.T) {
host, status := test.NewTestHost(basicAzureConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
// Set the request body
requestBody := `{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "Hello"
}
]
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
// Process response headers
action = host.CallOnHttpResponseHeaders([][2]string{
{":status", "200"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.ActionContinue, action)
// Verify the response headers were processed correctly
responseHeaders := host.GetResponseHeaders()
require.NotNil(t, responseHeaders)
// Verify the status code
statusValue, hasStatus := test.GetHeaderValue(responseHeaders, ":status")
require.True(t, hasStatus, "Status header should exist")
require.Equal(t, "200", statusValue, "Status should be 200")
// Verify Content-Type
contentTypeValue, hasContentType := test.GetHeaderValue(responseHeaders, "Content-Type")
require.True(t, hasContentType, "Content-Type header should exist")
require.Equal(t, "application/json", contentTypeValue, "Content-Type should be application/json")
})
})
}
func RunAzureOnHttpResponseBodyTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test Azure OpenAI response body handling
t.Run("azure response body", func(t *testing.T) {
host, status := test.NewTestHost(basicAzureConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
// Set the request body
requestBody := `{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "Hello"
}
]
}`
action = host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
// Process the response body
responseBody := `{
"choices": [
{
"message": {
"content": "Hello! How can I help you?"
}
}
]
}`
action = host.CallOnHttpResponseBody([]byte(responseBody))
require.Equal(t, types.ActionContinue, action)
// Verify the response body was processed correctly
transformedResponseBody := host.GetResponseBody()
require.NotNil(t, transformedResponseBody)
// Verify response body content
var responseMap map[string]interface{}
err := json.Unmarshal(transformedResponseBody, &responseMap)
require.NoError(t, err)
choices, exists := responseMap["choices"]
require.True(t, exists, "Choices should exist in response body")
require.NotNil(t, choices, "Choices should not be nil")
})
})
}
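
The Azure tests above assert that `/v1/chat/completions` is rewritten to an `/openai/deployments/<deployment>/...` path, with the deployment taken either from a configured default or from the model in the request body. The sketch below illustrates that rewrite under those assumptions; `azurePath` is a hypothetical helper, not the plugin's actual code.

```go
package main

import (
	"fmt"
	"strings"
)

// azurePath sketches the path rewriting the Azure tests assert on:
// /v1/<endpoint> becomes /openai/deployments/<deployment>/<endpoint>.
// Illustrative only; the real plugin derives the deployment from its config
// or from the "model" field of the request body.
func azurePath(origPath, deployment string) string {
	suffix := strings.TrimPrefix(origPath, "/v1/")
	return fmt.Sprintf("/openai/deployments/%s/%s", deployment, suffix)
}

func main() {
	// Deployment-only config: a fixed deployment name is used.
	fmt.Println(azurePath("/v1/chat/completions", "deployment-only"))
	// Domain-only config: the model from the request body fills the slot.
	fmt.Println(azurePath("/v1/chat/completions", "gpt-3.5-turbo"))
}
```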


@@ -0,0 +1,866 @@
package test
import (
"encoding/json"
"strings"
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
)
// Test config: basic OpenAI configuration
var basicOpenAIConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "openai",
"apiTokens": []string{"sk-openai-test123456789"},
"modelMapping": map[string]string{
"*": "gpt-3.5-turbo",
},
},
})
return data
}()
// Test config: OpenAI multi-model configuration
var openAIMultiModelConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "openai",
"apiTokens": []string{"sk-openai-multi-model"},
"modelMapping": map[string]string{
"gpt-3.5-turbo": "gpt-3.5-turbo",
"gpt-4": "gpt-4",
"text-embedding-ada-002": "text-embedding-ada-002",
"dall-e-3": "dall-e-3",
},
},
})
return data
}()
// Test config: OpenAI custom domain (direct path)
var openAICustomDomainDirectPathConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "openai",
"apiTokens": []string{"sk-openai-custom-domain"},
"modelMapping": map[string]string{
"*": "gpt-3.5-turbo",
},
"openaiCustomUrl": "https://custom.openai.com/v1",
},
})
return data
}()
// Test config: OpenAI custom domain (indirect path)
var openAICustomDomainIndirectPathConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "openai",
"apiTokens": []string{"sk-openai-custom-domain"},
"modelMapping": map[string]string{
"*": "gpt-3.5-turbo",
},
"openaiCustomUrl": "https://custom.openai.com/api",
},
})
return data
}()
// Test config: complete OpenAI configuration, including responseJsonSchema and related fields
var completeOpenAIConfig = func() json.RawMessage {
data, _ := json.Marshal(map[string]interface{}{
"provider": map[string]interface{}{
"type": "openai",
"apiTokens": []string{"sk-openai-complete"},
"modelMapping": map[string]string{
"*": "gpt-3.5-turbo",
},
"responseJsonSchema": map[string]interface{}{
"type": "json_object",
},
"failover": map[string]interface{}{
"enabled": false,
},
"retryOnFailure": map[string]interface{}{
"enabled": false,
},
},
})
return data
}()
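
Several configs above rely on the `modelMapping` field, where `"*"` acts as a wildcard that rewrites any incoming model name. The following sketch illustrates those semantics as the tests exercise them; `mapModel` is a hypothetical reimplementation for illustration, not the plugin's actual code.

```go
package main

import "fmt"

// mapModel sketches the modelMapping semantics the tests exercise: an exact
// key wins, otherwise the "*" entry acts as a catch-all; with no match at
// all, the model passes through unchanged. Illustrative only.
func mapModel(mapping map[string]string, model string) string {
	if mapped, ok := mapping[model]; ok {
		return mapped
	}
	if mapped, ok := mapping["*"]; ok {
		return mapped
	}
	return model
}

func main() {
	wildcard := map[string]string{"*": "gpt-3.5-turbo"}
	// With the wildcard mapping, any model is rewritten to gpt-3.5-turbo.
	fmt.Println(mapModel(wildcard, "text-embedding-ada-002"))
	// Without a wildcard, unmatched models pass through unchanged.
	exact := map[string]string{"gpt-4": "gpt-4"}
	fmt.Println(mapModel(exact, "dall-e-3"))
}
```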
func RunOpenAIParseConfigTests(t *testing.T) {
test.RunGoTest(t, func(t *testing.T) {
// Test parsing of the basic OpenAI config
t.Run("basic openai config", func(t *testing.T) {
host, status := test.NewTestHost(basicOpenAIConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// Test parsing of the multi-model config
t.Run("openai multi model config", func(t *testing.T) {
host, status := test.NewTestHost(openAIMultiModelConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// Test the custom domain config (direct path)
t.Run("openai custom domain direct path config", func(t *testing.T) {
host, status := test.NewTestHost(openAICustomDomainDirectPathConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// Test the custom domain config (indirect path)
t.Run("openai custom domain indirect path config", func(t *testing.T) {
host, status := test.NewTestHost(openAICustomDomainIndirectPathConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
// Test parsing of the complete config
t.Run("openai complete config", func(t *testing.T) {
host, status := test.NewTestHost(completeOpenAIConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
config, err := host.GetMatchConfig()
require.NoError(t, err)
require.NotNil(t, config)
})
})
}
func RunOpenAIOnHttpRequestHeadersTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test OpenAI request header handling (chat completions endpoint)
t.Run("openai chat completion request headers", func(t *testing.T) {
host, status := test.NewTestHost(basicOpenAIConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Should return HeaderStopIteration because the request body still needs processing
require.Equal(t, types.HeaderStopIteration, action)
// Verify the request headers were processed correctly
requestHeaders := host.GetRequestHeaders()
require.NotNil(t, requestHeaders)
// Verify Host was rewritten to the default OpenAI domain
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
require.True(t, hasHost, "Host header should exist")
require.Equal(t, "api.openai.com", hostValue, "Host should be changed to OpenAI default domain")
// Verify Authorization was set
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
require.True(t, hasAuth, "Authorization header should exist")
require.Contains(t, authValue, "sk-openai-test123456789", "Authorization should contain OpenAI API token")
// Verify the path was handled correctly
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
require.True(t, hasPath, "Path header should exist")
require.Contains(t, pathValue, "/v1/chat/completions", "Path should contain chat completions endpoint")
// Verify Content-Length was removed
_, hasContentLength := test.GetHeaderValue(requestHeaders, "Content-Length")
require.False(t, hasContentLength, "Content-Length header should be deleted")
// Check for related processing logs
debugLogs := host.GetDebugLogs()
hasOpenAILogs := false
for _, log := range debugLogs {
if strings.Contains(log, "openai") {
hasOpenAILogs = true
break
}
}
require.True(t, hasOpenAILogs, "Should have OpenAI processing logs")
})
// Test OpenAI request header handling (embeddings endpoint)
t.Run("openai embeddings request headers", func(t *testing.T) {
host, status := test.NewTestHost(basicOpenAIConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/embeddings"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
// Verify request header handling for the embeddings endpoint
requestHeaders := host.GetRequestHeaders()
require.NotNil(t, requestHeaders)
// Verify the Host rewrite
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
require.True(t, hasHost)
require.Equal(t, "api.openai.com", hostValue)
// Verify the path rewrite
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
require.True(t, hasPath)
require.Contains(t, pathValue, "/v1/embeddings", "Path should contain embeddings endpoint")
// Verify Authorization was set
authValue, hasAuth := test.GetHeaderValue(requestHeaders, "Authorization")
require.True(t, hasAuth, "Authorization header should exist for embeddings")
require.Contains(t, authValue, "sk-openai-test123456789", "Authorization should contain OpenAI API token")
})
// Test OpenAI request header handling (image generation endpoint)
t.Run("openai image generation request headers", func(t *testing.T) {
host, status := test.NewTestHost(basicOpenAIConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/generations"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
// Verify request header handling for the image generation endpoint
requestHeaders := host.GetRequestHeaders()
require.NotNil(t, requestHeaders)
// Verify the Host rewrite
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
require.True(t, hasHost)
require.Equal(t, "api.openai.com", hostValue)
// Verify the path rewrite
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
require.True(t, hasPath)
require.Contains(t, pathValue, "/v1/images/generations", "Path should contain image generations endpoint")
})
// Test OpenAI request header handling with a custom domain
t.Run("openai custom domain request headers", func(t *testing.T) {
host, status := test.NewTestHost(openAICustomDomainDirectPathConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers
action := host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
require.Equal(t, types.HeaderStopIteration, action)
// Verify request header handling for the custom domain
requestHeaders := host.GetRequestHeaders()
require.NotNil(t, requestHeaders)
// Verify Host was rewritten to the custom domain
hostValue, hasHost := test.GetHeaderValue(requestHeaders, ":authority")
require.True(t, hasHost)
require.Equal(t, "custom.openai.com", hostValue, "Host should be changed to custom domain")
// Verify the path was handled correctly
pathValue, hasPath := test.GetHeaderValue(requestHeaders, ":path")
require.True(t, hasPath)
// For a direct custom path, the original path should be preserved
require.Contains(t, pathValue, "/v1/chat/completions", "Path should be preserved for direct custom path")
})
})
}
func RunOpenAIOnHttpRequestBodyTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test OpenAI request body handling (chat completions endpoint)
t.Run("openai chat completion request body", func(t *testing.T) {
host, status := test.NewTestHost(basicOpenAIConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
requestBody := `{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"test"}]}`
action := host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
// Verify the request body was processed correctly
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
// Verify the model name was mapped correctly
require.Contains(t, string(processedBody), "gpt-3.5-turbo", "Original model name should be preserved or mapped")
// Check for related processing logs
debugLogs := host.GetDebugLogs()
infoLogs := host.GetInfoLogs()
// Verify there are OpenAI-related processing logs
hasOpenAILogs := false
for _, log := range debugLogs {
if strings.Contains(log, "openai") {
hasOpenAILogs = true
break
}
}
for _, log := range infoLogs {
if strings.Contains(log, "openai") {
hasOpenAILogs = true
break
}
}
require.True(t, hasOpenAILogs, "Should have OpenAI processing logs")
})
// Test OpenAI request body handling (embeddings endpoint)
t.Run("openai embeddings request body", func(t *testing.T) {
host, status := test.NewTestHost(basicOpenAIConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/embeddings"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
requestBody := `{"model":"text-embedding-ada-002","input":"test text"}`
action := host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
// Verify request body handling for the embeddings endpoint
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
// Verify model name mapping
// With the wildcard mapping "*": "gpt-3.5-turbo", text-embedding-ada-002 is mapped to gpt-3.5-turbo
require.Contains(t, string(processedBody), "gpt-3.5-turbo", "Model name should be mapped via wildcard")
// Check processing logs
debugLogs := host.GetDebugLogs()
hasEmbeddingLogs := false
for _, log := range debugLogs {
if strings.Contains(log, "embeddings") || strings.Contains(log, "openai") {
hasEmbeddingLogs = true
break
}
}
require.True(t, hasEmbeddingLogs, "Should have embedding processing logs")
})
// Test OpenAI request body handling (image generation endpoint)
t.Run("openai image generation request body", func(t *testing.T) {
host, status := test.NewTestHost(basicOpenAIConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/generations"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
requestBody := `{"model":"dall-e-3","prompt":"test image"}`
action := host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
// Verify request body handling for the image generation endpoint
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
// Verify model name mapping
// With the wildcard mapping "*": "gpt-3.5-turbo", dall-e-3 is mapped to gpt-3.5-turbo
require.Contains(t, string(processedBody), "gpt-3.5-turbo", "Model name should be mapped via wildcard")
// Check processing logs
debugLogs := host.GetDebugLogs()
hasImageLogs := false
for _, log := range debugLogs {
if strings.Contains(log, "image") || strings.Contains(log, "openai") {
hasImageLogs = true
break
}
}
require.True(t, hasImageLogs, "Should have image generation processing logs")
})
// Test OpenAI request body handling with a responseJsonSchema config
t.Run("openai request body with responseJsonSchema", func(t *testing.T) {
host, status := test.NewTestHost(completeOpenAIConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
requestBody := `{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"test"}]}`
action := host.CallOnHttpRequestBody([]byte(requestBody))
require.Equal(t, types.ActionContinue, action)
// Verify the request body was processed correctly
processedBody := host.GetRequestBody()
require.NotNil(t, processedBody)
// Verify responseJsonSchema was applied
// Note: due to test framework limitations, we may need to check logs or other signals to verify the result
require.Contains(t, string(processedBody), "gpt-3.5-turbo", "Model name should be preserved")
// Check for related processing logs
debugLogs := host.GetDebugLogs()
hasSchemaLogs := false
for _, log := range debugLogs {
if strings.Contains(log, "response format") || strings.Contains(log, "openai") {
hasSchemaLogs = true
break
}
}
require.True(t, hasSchemaLogs, "Should have response format processing logs")
})
})
}
func RunOpenAIOnHttpResponseHeadersTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test OpenAI response header handling (chat completions endpoint)
t.Run("openai chat completion response headers", func(t *testing.T) {
host, status := test.NewTestHost(basicOpenAIConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
requestBody := `{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"test"}]}`
host.CallOnHttpRequestBody([]byte(requestBody))
// Set response headers
responseHeaders := [][2]string{
{":status", "200"},
{"Content-Type", "application/json"},
{"X-Request-Id", "req-123"},
}
action := host.CallOnHttpResponseHeaders(responseHeaders)
require.Equal(t, types.ActionContinue, action)
// Verify the response headers were processed correctly
processedResponseHeaders := host.GetResponseHeaders()
require.NotNil(t, processedResponseHeaders)
// Verify the status code
statusValue, hasStatus := test.GetHeaderValue(processedResponseHeaders, ":status")
require.True(t, hasStatus, "Status header should exist")
require.Equal(t, "200", statusValue, "Status should be 200")
// Verify Content-Type
contentTypeValue, hasContentType := test.GetHeaderValue(processedResponseHeaders, "Content-Type")
require.True(t, hasContentType, "Content-Type header should exist")
require.Equal(t, "application/json", contentTypeValue, "Content-Type should be application/json")
// Check for related processing logs
debugLogs := host.GetDebugLogs()
hasResponseLogs := false
for _, log := range debugLogs {
if strings.Contains(log, "response") || strings.Contains(log, "openai") {
hasResponseLogs = true
break
}
}
require.True(t, hasResponseLogs, "Should have response processing logs")
})
// Test OpenAI response header handling (embeddings endpoint)
t.Run("openai embeddings response headers", func(t *testing.T) {
host, status := test.NewTestHost(basicOpenAIConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/embeddings"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
requestBody := `{"model":"text-embedding-ada-002","input":"test text"}`
host.CallOnHttpRequestBody([]byte(requestBody))
// Set response headers
responseHeaders := [][2]string{
{":status", "200"},
{"Content-Type", "application/json"},
{"X-Embedding-Model", "text-embedding-ada-002"},
}
action := host.CallOnHttpResponseHeaders(responseHeaders)
require.Equal(t, types.ActionContinue, action)
// Verify response header handling
processedResponseHeaders := host.GetResponseHeaders()
require.NotNil(t, processedResponseHeaders)
// Verify the embedding model information
modelValue, hasModel := test.GetHeaderValue(processedResponseHeaders, "X-Embedding-Model")
require.True(t, hasModel, "Embedding model header should exist")
require.Equal(t, "text-embedding-ada-002", modelValue, "Embedding model should match configuration")
})
// Test OpenAI response header handling (error response)
t.Run("openai error response headers", func(t *testing.T) {
host, status := test.NewTestHost(basicOpenAIConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
requestBody := `{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"test"}]}`
host.CallOnHttpRequestBody([]byte(requestBody))
// Set error response headers
errorResponseHeaders := [][2]string{
{":status", "429"},
{"Content-Type", "application/json"},
{"Retry-After", "60"},
}
action := host.CallOnHttpResponseHeaders(errorResponseHeaders)
require.Equal(t, types.ActionContinue, action)
// Verify error response header handling
processedResponseHeaders := host.GetResponseHeaders()
require.NotNil(t, processedResponseHeaders)
// Verify the error status code
statusValue, hasStatus := test.GetHeaderValue(processedResponseHeaders, ":status")
require.True(t, hasStatus, "Status header should exist")
require.Equal(t, "429", statusValue, "Status should be 429 (Too Many Requests)")
// Verify retry information
retryValue, hasRetry := test.GetHeaderValue(processedResponseHeaders, "Retry-After")
require.True(t, hasRetry, "Retry-After header should exist")
require.Equal(t, "60", retryValue, "Retry-After should be 60 seconds")
})
})
}
func RunOpenAIOnHttpResponseBodyTests(t *testing.T) {
test.RunTest(t, func(t *testing.T) {
// Test OpenAI response body handling (chat completions endpoint)
t.Run("openai chat completion response body", func(t *testing.T) {
host, status := test.NewTestHost(basicOpenAIConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
requestBody := `{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"test"}]}`
host.CallOnHttpRequestBody([]byte(requestBody))
// Set response headers
responseHeaders := [][2]string{
{":status", "200"},
{"Content-Type", "application/json"},
}
host.CallOnHttpResponseHeaders(responseHeaders)
// Set the response body
responseBody := `{
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1677652288,
"model": "gpt-3.5-turbo",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello! How can I help you today?"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 9,
"completion_tokens": 12,
"total_tokens": 21
}
}`
action := host.CallOnHttpResponseBody([]byte(responseBody))
require.Equal(t, types.ActionContinue, action)
// Verify the response body was processed correctly
processedResponseBody := host.GetResponseBody()
require.NotNil(t, processedResponseBody)
// Verify response body content
responseStr := string(processedResponseBody)
require.Contains(t, responseStr, "chat.completion", "Response should contain chat completion object")
require.Contains(t, responseStr, "assistant", "Response should contain assistant role")
require.Contains(t, responseStr, "usage", "Response should contain usage information")
// Check for related processing logs
debugLogs := host.GetDebugLogs()
hasResponseBodyLogs := false
for _, log := range debugLogs {
if strings.Contains(log, "response") || strings.Contains(log, "body") || strings.Contains(log, "openai") {
hasResponseBodyLogs = true
break
}
}
require.True(t, hasResponseBodyLogs, "Should have response body processing logs")
})
// Test OpenAI response body handling (embeddings endpoint)
t.Run("openai embeddings response body", func(t *testing.T) {
host, status := test.NewTestHost(basicOpenAIConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/embeddings"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
requestBody := `{"model":"text-embedding-ada-002","input":"test text"}`
host.CallOnHttpRequestBody([]byte(requestBody))
// Set response headers
responseHeaders := [][2]string{
{":status", "200"},
{"Content-Type", "application/json"},
}
host.CallOnHttpResponseHeaders(responseHeaders)
// Set the response body
responseBody := `{
"object": "list",
"data": [{
"object": "embedding",
"embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": 0
}],
"model": "text-embedding-ada-002",
"usage": {
"prompt_tokens": 5,
"total_tokens": 5
}
}`
action := host.CallOnHttpResponseBody([]byte(responseBody))
require.Equal(t, types.ActionContinue, action)
// Verify response body handling
processedResponseBody := host.GetResponseBody()
require.NotNil(t, processedResponseBody)
// Verify embedding response content
responseStr := string(processedResponseBody)
require.Contains(t, responseStr, "embedding", "Response should contain embedding object")
require.Contains(t, responseStr, "0.1", "Response should contain embedding vector")
require.Contains(t, responseStr, "text-embedding-ada-002", "Response should contain model name")
})
// Test OpenAI response body handling (image generation endpoint)
t.Run("openai image generation response body", func(t *testing.T) {
host, status := test.NewTestHost(basicOpenAIConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/images/generations"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the request body
requestBody := `{"model":"dall-e-3","prompt":"test image"}`
host.CallOnHttpRequestBody([]byte(requestBody))
// Set response headers
responseHeaders := [][2]string{
{":status", "200"},
{"Content-Type", "application/json"},
}
host.CallOnHttpResponseHeaders(responseHeaders)
// Set the response body
responseBody := `{
"created": 1677652288,
"data": [{
"url": "https://example.com/image1.png",
"revised_prompt": "test image"
}]
}`
action := host.CallOnHttpResponseBody([]byte(responseBody))
require.Equal(t, types.ActionContinue, action)
// Verify response body handling
processedResponseBody := host.GetResponseBody()
require.NotNil(t, processedResponseBody)
// Verify image generation response content
responseStr := string(processedResponseBody)
require.Contains(t, responseStr, "data", "Response should contain data array")
require.Contains(t, responseStr, "url", "Response should contain image URL")
require.Contains(t, responseStr, "revised_prompt", "Response should contain revised prompt")
})
})
}
func RunOpenAIOnStreamingResponseBodyTests(t *testing.T) {
// Test OpenAI response body handling (streaming response)
test.RunTest(t, func(t *testing.T) {
t.Run("openai streaming response body", func(t *testing.T) {
host, status := test.NewTestHost(basicOpenAIConfig)
defer host.Reset()
require.Equal(t, types.OnPluginStartStatusOK, status)
// Set request headers first
host.CallOnHttpRequestHeaders([][2]string{
{":authority", "example.com"},
{":path", "/v1/chat/completions"},
{":method", "POST"},
{"Content-Type", "application/json"},
})
// Set the streaming request body
requestBody := `{"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"test"}],"stream":true}`
host.CallOnHttpRequestBody([]byte(requestBody))
// Set streaming response headers
responseHeaders := [][2]string{
{":status", "200"},
{"Content-Type", "text/event-stream"},
}
host.CallOnHttpResponseHeaders(responseHeaders)
// Simulate streaming response body chunks
chunk1 := `data: {"id":"chatcmpl-123","object":"chat.completion.chunk","choices":[{"delta":{"role":"assistant"},"index":0}]}
`
chunk2 := `data: {"id":"chatcmpl-123","object":"chat.completion.chunk","choices":[{"delta":{"content":"Hello"},"index":0}]}
`
chunk3 := `data: {"id":"chatcmpl-123","object":"chat.completion.chunk","choices":[{"delta":{"content":"!"},"index":0}]}
`
chunk4 := `data: [DONE]
`
// Process the streaming response body
action1 := host.CallOnHttpStreamingResponseBody([]byte(chunk1), false)
require.Equal(t, types.ActionContinue, action1)
action2 := host.CallOnHttpStreamingResponseBody([]byte(chunk2), false)
require.Equal(t, types.ActionContinue, action2)
action3 := host.CallOnHttpStreamingResponseBody([]byte(chunk3), false)
require.Equal(t, types.ActionContinue, action3)
action4 := host.CallOnHttpStreamingResponseBody([]byte(chunk4), true)
require.Equal(t, types.ActionContinue, action4)
// Verify streaming response handling
// Note: streaming responses may not accumulate in GetResponseBody; check logs or other signals to verify
debugLogs := host.GetDebugLogs()
hasStreamingLogs := false
for _, log := range debugLogs {
if strings.Contains(log, "streaming") || strings.Contains(log, "chunk") || strings.Contains(log, "openai") {
hasStreamingLogs = true
break
}
}
require.True(t, hasStreamingLogs, "Should have streaming response processing logs")
})
})
}
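
The streaming test above feeds chunks in the server-sent events framing that OpenAI-style chat completions use: each event is a `data: <json>` line, and `data: [DONE]` terminates the stream. The sketch below illustrates that framing; `parseSSELine` is a hypothetical helper for illustration, not the plugin's actual parser.

```go
package main

import (
	"fmt"
	"strings"
)

// parseSSELine sketches the SSE framing used by the streaming chunks in the
// test above: strip the "data: " prefix to get the JSON payload, and treat
// "[DONE]" as the end-of-stream sentinel. Illustrative only.
func parseSSELine(line string) (payload string, done bool) {
	payload = strings.TrimPrefix(strings.TrimSpace(line), "data: ")
	return payload, payload == "[DONE]"
}

func main() {
	p, done := parseSSELine(`data: {"object":"chat.completion.chunk"}` + "\n")
	fmt.Println(p, done)
	_, done = parseSSELine("data: [DONE]\n")
	fmt.Println(done)
}
```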


@@ -59,18 +59,10 @@ func OverwriteRequestPath(path string) error {
}
func OverwriteRequestAuthorization(credential string) error {
if exist, _ := proxywasm.GetHttpRequestHeader(HeaderOriginalAuth); exist == "" {
if originAuth, err := proxywasm.GetHttpRequestHeader(HeaderAuthorization); err == nil {
_ = proxywasm.AddHttpRequestHeader(HeaderOriginalPath, originAuth)
}
}
return proxywasm.ReplaceHttpRequestHeader(HeaderAuthorization, credential)
}
func OverwriteRequestHostHeader(headers http.Header, host string) {
if originHost, err := proxywasm.GetHttpRequestHeader(HeaderAuthority); err == nil {
headers.Set(HeaderOriginalHost, originHost)
}
headers.Set(HeaderAuthority, host)
}
@@ -175,11 +167,6 @@ func SetOriginalRequestAuth(auth string) {
}
func OverwriteRequestAuthorizationHeader(headers http.Header, credential string) {
if exist := headers.Get(HeaderOriginalAuth); exist == "" {
if originAuth := headers.Get(HeaderAuthorization); originAuth != "" {
headers.Set(HeaderOriginalAuth, originAuth)
}
}
headers.Set(HeaderAuthorization, credential)
}


@@ -5,14 +5,20 @@ go 1.24.1
toolchain go1.24.4
require (
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80
github.com/higress-group/wasm-go v1.0.1
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8
github.com/stretchr/testify v1.9.0
github.com/tidwall/gjson v1.18.0
github.com/tidwall/resp v0.1.1
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tetratelabs/wazero v1.7.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)


@@ -2,14 +2,17 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80 h1:xqmtTZI0JQ2O+Lg9/CE6c+Tw9KD6FnvWw8EpLVuuvfg=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250611100342-5654e89a7a80/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.1 h1:T1m++qTEANp8+jwE0sxltwtaTKmrHCkLOp1m9N+YeqY=
github.com/higress-group/wasm-go v1.0.1/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0 h1:YGdj8KBzVjabU3STUfwMZghB+VlX6YLfJtLbrsWaOD0=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20250822030947-8345453fddd0/go.mod h1:tRI2LfMudSkKHhyv1uex3BWzcice2s/l8Ah8axporfA=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8 h1:rs+AH1wfZy4swzuAyiRXT7xPUm8gycXt9Gwy0tqOq0o=
github.com/higress-group/wasm-go v1.0.2-0.20250821081215-b573359becf8/go.mod h1:9k7L730huS/q4V5iH9WLDgf5ZUHEtfhM/uXcegKDG/M=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=
github.com/tetratelabs/wazero v1.7.2/go.mod h1:ytl6Zuh20R/eROuyDaGPkp82O9C/DJfXAwJfQ3X6/7Y=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
@@ -19,5 +22,9 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
