mirror of
https://github.com/alibaba/higress.git
synced 2026-03-02 23:51:11 +08:00
391 lines
12 KiB
Go
391 lines
12 KiB
Go
// Copyright (c) 2024 Alibaba Group Holding Ltd.
|
||
//
|
||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||
// you may not use this file except in compliance with the License.
|
||
// You may obtain a copy of the License at
|
||
//
|
||
// http://www.apache.org/licenses/LICENSE-2.0
|
||
//
|
||
// Unless required by applicable law or agreed to in writing, software
|
||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
// See the License for the specific language governing permissions and
|
||
// limitations under the License.
|
||
|
||
package main
|
||
|
||
import (
|
||
"encoding/json"
|
||
"testing"
|
||
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||
"github.com/higress-group/wasm-go/pkg/test"
|
||
"github.com/stretchr/testify/require"
|
||
)
|
||
|
||
// basicConfig is the baseline plugin configuration used by most tests: an API
// key, "prompt" as the query parameter that carries the user prompt, and the
// model name. TestParseConfig expects plugin start to succeed with it.
var basicConfig = func() json.RawMessage {
	data, err := json.Marshal(map[string]interface{}{
		"apiKey":      "sk-test123456789",
		"promptParam": "prompt",
		"model":       "text-davinci-003",
	})
	if err != nil {
		// Fail loudly at package init rather than silently handing a nil
		// config to the test host.
		panic(err)
	}
	return data
}()
|
||
|
||
// customModelConfig overrides both the model ("curie") and the prompt query
// parameter name ("text"). TestParseConfig expects plugin start to succeed.
var customModelConfig = func() json.RawMessage {
	data, err := json.Marshal(map[string]interface{}{
		"apiKey":      "sk-test123456789",
		"promptParam": "text",
		"model":       "curie",
	})
	if err != nil {
		// Fail loudly at package init rather than silently handing a nil
		// config to the test host.
		panic(err)
	}
	return data
}()
|
||
|
||
// customPromptParamConfig reads the prompt from the "question" query
// parameter instead of "prompt"; the rest matches basicConfig.
var customPromptParamConfig = func() json.RawMessage {
	data, err := json.Marshal(map[string]interface{}{
		"apiKey":      "sk-test123456789",
		"promptParam": "question",
		"model":       "text-davinci-003",
	})
	if err != nil {
		// Fail loudly at package init rather than silently handing a nil
		// config to the test host.
		panic(err)
	}
	return data
}()
|
||
|
||
// customUriConfig points the plugin at a custom ChatGPT-compatible endpoint
// via "chatgptUri"; the rest matches basicConfig. TestParseConfig expects
// plugin start to succeed.
var customUriConfig = func() json.RawMessage {
	data, err := json.Marshal(map[string]interface{}{
		"apiKey":      "sk-test123456789",
		"promptParam": "prompt",
		"model":       "text-davinci-003",
		"chatgptUri":  "https://custom-ai.example.com/v1/chat/completions",
	})
	if err != nil {
		// Fail loudly at package init rather than silently handing a nil
		// config to the test host.
		panic(err)
	}
	return data
}()
|
||
|
||
// customIdsConfig overrides the human and AI speaker prefixes ("User:" and
// "Assistant:"); the rest matches basicConfig.
//
// NOTE(review): the key is spelled "HumainId" (sic). Presumably this matches
// the key the plugin's config parser reads; confirm before "fixing" it here.
var customIdsConfig = func() json.RawMessage {
	data, err := json.Marshal(map[string]interface{}{
		"apiKey":      "sk-test123456789",
		"promptParam": "prompt",
		"model":       "text-davinci-003",
		"HumainId":    "User:",
		"AIId":        "Assistant:",
	})
	if err != nil {
		// Fail loudly at package init rather than silently handing a nil
		// config to the test host.
		panic(err)
	}
	return data
}()
|
||
|
||
// invalidConfig omits "apiKey". TestParseConfig expects plugin start to fail
// with it.
var invalidConfig = func() json.RawMessage {
	data, err := json.Marshal(map[string]interface{}{
		"promptParam": "prompt",
		"model":       "text-davinci-003",
	})
	if err != nil {
		// Fail loudly at package init rather than silently handing a nil
		// config to the test host.
		panic(err)
	}
	return data
}()
|
||
|
||
// invalidUriConfig sets a malformed "chatgptUri" (nothing before "://").
// TestParseConfig expects plugin start to fail with it.
var invalidUriConfig = func() json.RawMessage {
	data, err := json.Marshal(map[string]interface{}{
		"apiKey":      "sk-test123456789",
		"promptParam": "prompt",
		"model":       "text-davinci-003",
		"chatgptUri":  "://invalid-uri",
	})
	if err != nil {
		// Fail loudly at package init rather than silently handing a nil
		// config to the test host.
		panic(err)
	}
	return data
}()
|
||
|
||
func TestParseConfig(t *testing.T) {
|
||
test.RunGoTest(t, func(t *testing.T) {
|
||
// 测试基本配置解析
|
||
t.Run("basic config", func(t *testing.T) {
|
||
host, status := test.NewTestHost(basicConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
config, err := host.GetMatchConfig()
|
||
require.NoError(t, err)
|
||
require.NotNil(t, config)
|
||
})
|
||
|
||
// 测试自定义模型配置解析
|
||
t.Run("custom model config", func(t *testing.T) {
|
||
host, status := test.NewTestHost(customModelConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
config, err := host.GetMatchConfig()
|
||
require.NoError(t, err)
|
||
require.NotNil(t, config)
|
||
})
|
||
|
||
// 测试自定义提示参数配置解析
|
||
t.Run("custom prompt param config", func(t *testing.T) {
|
||
host, status := test.NewTestHost(customPromptParamConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
config, err := host.GetMatchConfig()
|
||
require.NoError(t, err)
|
||
require.NotNil(t, config)
|
||
})
|
||
|
||
// 测试自定义 URI 配置解析
|
||
t.Run("custom uri config", func(t *testing.T) {
|
||
host, status := test.NewTestHost(customUriConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
config, err := host.GetMatchConfig()
|
||
require.NoError(t, err)
|
||
require.NotNil(t, config)
|
||
})
|
||
|
||
// 测试自定义 ID 配置解析
|
||
t.Run("custom ids config", func(t *testing.T) {
|
||
host, status := test.NewTestHost(customIdsConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
config, err := host.GetMatchConfig()
|
||
require.NoError(t, err)
|
||
require.NotNil(t, config)
|
||
})
|
||
|
||
// 测试无效配置(缺少 API Key)
|
||
t.Run("invalid config - missing api key", func(t *testing.T) {
|
||
host, status := test.NewTestHost(invalidConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||
})
|
||
|
||
// 测试无效 URI 配置
|
||
t.Run("invalid config - invalid uri", func(t *testing.T) {
|
||
host, status := test.NewTestHost(invalidUriConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusFailed, status)
|
||
})
|
||
})
|
||
}
|
||
|
||
func TestOnHttpRequestHeaders(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
// 测试基本请求头处理(带查询参数)
|
||
t.Run("basic request headers with query params", func(t *testing.T) {
|
||
host, status := test.NewTestHost(basicConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 设置请求头,包含查询参数
|
||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/api/chat?prompt=Hello, how are you?"},
|
||
{":method", "GET"},
|
||
})
|
||
|
||
// 由于需要调用外部 AI 服务,应该返回 ActionPause
|
||
require.Equal(t, types.ActionPause, action)
|
||
|
||
// 模拟外部 AI 服务响应
|
||
host.CallOnHttpCall([][2]string{
|
||
{"Content-Type", "application/json"},
|
||
{":status", "200"},
|
||
}, []byte(`{"choices":[{"text":"I'm doing well, thank you for asking!"}]}`))
|
||
|
||
response := host.GetLocalResponse()
|
||
require.Equal(t, uint32(200), response.StatusCode)
|
||
require.Equal(t, `{"choices":[{"text":"I'm doing well, thank you for asking!"}]}`, string(response.Data))
|
||
|
||
host.CompleteHttp()
|
||
})
|
||
|
||
// 测试自定义提示参数请求头处理
|
||
t.Run("custom prompt param request headers", func(t *testing.T) {
|
||
host, status := test.NewTestHost(customPromptParamConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 设置请求头,使用自定义提示参数
|
||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/api/chat?question=What is the weather like?"},
|
||
{":method", "GET"},
|
||
})
|
||
|
||
// 由于需要调用外部 AI 服务,应该返回 ActionPause
|
||
require.Equal(t, types.ActionPause, action)
|
||
|
||
// 模拟外部 AI 服务响应
|
||
host.CallOnHttpCall([][2]string{
|
||
{"Content-Type", "application/json"},
|
||
{":status", "200"},
|
||
}, []byte(`{"choices":[{"text":"I don't have access to real-time weather information."}]}`))
|
||
|
||
response := host.GetLocalResponse()
|
||
require.Equal(t, uint32(200), response.StatusCode)
|
||
require.Equal(t, `{"choices":[{"text":"I don't have access to real-time weather information."}]}`, string(response.Data))
|
||
|
||
host.CompleteHttp()
|
||
})
|
||
|
||
// 测试缺少查询参数的情况
|
||
t.Run("missing query params", func(t *testing.T) {
|
||
host, status := test.NewTestHost(basicConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 设置请求头,不包含查询参数
|
||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/api/chat"},
|
||
{":method", "GET"},
|
||
})
|
||
|
||
// 应该返回 ActionContinue,因为缺少查询参数
|
||
require.Equal(t, types.ActionContinue, action)
|
||
|
||
host.CompleteHttp()
|
||
})
|
||
|
||
// 测试缺少提示参数的情况
|
||
t.Run("missing prompt param", func(t *testing.T) {
|
||
host, status := test.NewTestHost(basicConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 设置请求头,包含查询参数但不包含提示参数
|
||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/api/chat?other=value"},
|
||
{":method", "GET"},
|
||
})
|
||
|
||
// 应该返回 ActionContinue,因为缺少提示参数
|
||
require.Equal(t, types.ActionContinue, action)
|
||
|
||
host.CompleteHttp()
|
||
})
|
||
|
||
// 测试空提示参数的情况
|
||
t.Run("empty prompt param", func(t *testing.T) {
|
||
host, status := test.NewTestHost(basicConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 设置请求头,包含空的提示参数
|
||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/api/chat?prompt="},
|
||
{":method", "GET"},
|
||
})
|
||
|
||
// 由于需要调用外部 AI 服务,应该返回 ActionPause
|
||
require.Equal(t, types.ActionPause, action)
|
||
|
||
// 模拟外部 AI 服务响应
|
||
host.CallOnHttpCall([][2]string{
|
||
{"Content-Type", "application/json"},
|
||
{":status", "200"},
|
||
}, []byte(`{"choices":[{"text":"Empty prompt response"}]}`))
|
||
|
||
response := host.GetLocalResponse()
|
||
require.Equal(t, uint32(200), response.StatusCode)
|
||
require.Equal(t, `{"choices":[{"text":"Empty prompt response"}]}`, string(response.Data))
|
||
|
||
host.CompleteHttp()
|
||
})
|
||
|
||
// 测试外部服务调用成功的情况
|
||
t.Run("external service call success", func(t *testing.T) {
|
||
host, status := test.NewTestHost(basicConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 设置请求头
|
||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/api/chat?prompt=Tell me a joke"},
|
||
{":method", "GET"},
|
||
})
|
||
|
||
// 由于需要调用外部 AI 服务,应该返回 ActionPause
|
||
require.Equal(t, types.ActionPause, action)
|
||
|
||
// 模拟外部 AI 服务成功响应
|
||
host.CallOnHttpCall([][2]string{
|
||
{"Content-Type", "application/json"},
|
||
{":status", "200"},
|
||
}, []byte(`{"choices":[{"text":"Why don't scientists trust atoms? Because they make up everything!"}]}`))
|
||
|
||
response := host.GetLocalResponse()
|
||
require.Equal(t, uint32(200), response.StatusCode)
|
||
require.Equal(t, `{"choices":[{"text":"Why don't scientists trust atoms? Because they make up everything!"}]}`, string(response.Data))
|
||
|
||
host.CompleteHttp()
|
||
})
|
||
|
||
// 测试外部服务调用失败的情况
|
||
t.Run("external service call failure", func(t *testing.T) {
|
||
host, status := test.NewTestHost(basicConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 设置请求头
|
||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/api/chat?prompt=Hello"},
|
||
{":method", "GET"},
|
||
})
|
||
|
||
// 由于需要调用外部 AI 服务,应该返回 ActionPause
|
||
require.Equal(t, types.ActionPause, action)
|
||
|
||
// 模拟外部 AI 服务失败响应
|
||
host.CallOnHttpCall([][2]string{
|
||
{"Content-Type", "application/json"},
|
||
{":status", "429"},
|
||
}, []byte(`{"error":"Rate limit exceeded"}`))
|
||
|
||
response := host.GetLocalResponse()
|
||
require.Equal(t, uint32(429), response.StatusCode)
|
||
require.Equal(t, `{"error":"Rate limit exceeded"}`, string(response.Data))
|
||
|
||
host.CompleteHttp()
|
||
})
|
||
})
|
||
}
|
||
|
||
func TestCompleteFlow(t *testing.T) {
|
||
test.RunTest(t, func(t *testing.T) {
|
||
t.Run("complete chatgpt proxy flow", func(t *testing.T) {
|
||
host, status := test.NewTestHost(basicConfig)
|
||
defer host.Reset()
|
||
require.Equal(t, types.OnPluginStartStatusOK, status)
|
||
|
||
// 1. 处理请求头
|
||
action := host.CallOnHttpRequestHeaders([][2]string{
|
||
{":authority", "example.com"},
|
||
{":path", "/api/chat?prompt=What is artificial intelligence?"},
|
||
{":method", "GET"},
|
||
})
|
||
|
||
// 由于需要调用外部 AI 服务,应该返回 ActionPause
|
||
require.Equal(t, types.ActionPause, action)
|
||
|
||
// 2. 模拟外部 AI 服务响应
|
||
host.CallOnHttpCall([][2]string{
|
||
{"Content-Type", "application/json"},
|
||
{":status", "200"},
|
||
}, []byte(`{"choices":[{"text":"Artificial Intelligence (AI) is a branch of computer science that aims to create systems capable of performing tasks that typically require human intelligence."}]}`))
|
||
|
||
response := host.GetLocalResponse()
|
||
require.Equal(t, uint32(200), response.StatusCode)
|
||
require.Equal(t, `{"choices":[{"text":"Artificial Intelligence (AI) is a branch of computer science that aims to create systems capable of performing tasks that typically require human intelligence."}]}`, string(response.Data))
|
||
|
||
// 3. 完成请求
|
||
host.CompleteHttp()
|
||
})
|
||
})
|
||
}
|