Files
higress/plugins/wasm-go/extensions/chatgpt-proxy/main_test.go

391 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Copyright (c) 2024 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"encoding/json"
"testing"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/higress-group/wasm-go/pkg/test"
"github.com/stretchr/testify/require"
)
// basicConfig is the baseline plugin configuration used by most tests:
// an API key, the query parameter that carries the prompt ("prompt"),
// and the model name.
var basicConfig = func() json.RawMessage {
	data, err := json.Marshal(map[string]string{
		"apiKey":      "sk-test123456789",
		"promptParam": "prompt",
		"model":       "text-davinci-003",
	})
	if err != nil {
		// Marshalling a map of plain strings cannot fail; panic loudly if it
		// ever does rather than feeding a nil config into the tests.
		panic("marshal basicConfig: " + err.Error())
	}
	return data
}()
// customModelConfig overrides the model ("curie") and uses "text" as the
// prompt query parameter.
var customModelConfig = func() json.RawMessage {
	data, err := json.Marshal(map[string]string{
		"apiKey":      "sk-test123456789",
		"promptParam": "text",
		"model":       "curie",
	})
	if err != nil {
		// Marshalling a map of plain strings cannot fail; surface it if it does.
		panic("marshal customModelConfig: " + err.Error())
	}
	return data
}()
// customPromptParamConfig reads the prompt from the "question" query
// parameter instead of the "prompt" parameter used by basicConfig.
var customPromptParamConfig = func() json.RawMessage {
	data, err := json.Marshal(map[string]string{
		"apiKey":      "sk-test123456789",
		"promptParam": "question",
		"model":       "text-davinci-003",
	})
	if err != nil {
		// Marshalling a map of plain strings cannot fail; surface it if it does.
		panic("marshal customPromptParamConfig: " + err.Error())
	}
	return data
}()
// customUriConfig points the plugin at a custom ChatGPT-compatible endpoint
// via the "chatgptUri" setting.
var customUriConfig = func() json.RawMessage {
	data, err := json.Marshal(map[string]string{
		"apiKey":      "sk-test123456789",
		"promptParam": "prompt",
		"model":       "text-davinci-003",
		"chatgptUri":  "https://custom-ai.example.com/v1/chat/completions",
	})
	if err != nil {
		// Marshalling a map of plain strings cannot fail; surface it if it does.
		panic("marshal customUriConfig: " + err.Error())
	}
	return data
}()
// customIdsConfig sets custom speaker labels for the human ("User:") and
// AI ("Assistant:") turns.
//
// NOTE(review): the key is spelled "HumainId", not "HumanId". It is kept
// as-is on the assumption that it matches the key the plugin's config parser
// reads. Confirm this; otherwise this fixture does not actually exercise a
// custom human ID.
var customIdsConfig = func() json.RawMessage {
	data, err := json.Marshal(map[string]string{
		"apiKey":      "sk-test123456789",
		"promptParam": "prompt",
		"model":       "text-davinci-003",
		"HumainId":    "User:",
		"AIId":        "Assistant:",
	})
	if err != nil {
		// Marshalling a map of plain strings cannot fail; surface it if it does.
		panic("marshal customIdsConfig: " + err.Error())
	}
	return data
}()
// invalidConfig omits the required "apiKey"; the plugin is expected to fail
// to start with it.
var invalidConfig = func() json.RawMessage {
	data, err := json.Marshal(map[string]string{
		"promptParam": "prompt",
		"model":       "text-davinci-003",
	})
	if err != nil {
		// Marshalling a map of plain strings cannot fail; surface it if it does.
		panic("marshal invalidConfig: " + err.Error())
	}
	return data
}()
// invalidUriConfig supplies a "chatgptUri" with no scheme ("://invalid-uri");
// the plugin is expected to fail to start with it.
var invalidUriConfig = func() json.RawMessage {
	data, err := json.Marshal(map[string]string{
		"apiKey":      "sk-test123456789",
		"promptParam": "prompt",
		"model":       "text-davinci-003",
		"chatgptUri":  "://invalid-uri",
	})
	if err != nil {
		// Marshalling a map of plain strings cannot fail; surface it if it does.
		panic("marshal invalidUriConfig: " + err.Error())
	}
	return data
}()
// TestParseConfig checks that each configuration fixture is either accepted
// (plugin start succeeds and a match config is available) or rejected at
// plugin start.
//
// This test uses test.RunGoTest, while the request-flow tests use
// test.RunTest. NOTE(review): presumably because GetMatchConfig is only
// usable in Go mode — confirm before switching harnesses.
func TestParseConfig(t *testing.T) {
	test.RunGoTest(t, func(t *testing.T) {
		accepted := []struct {
			name   string
			config json.RawMessage
		}{
			{name: "basic config", config: basicConfig},
			{name: "custom model config", config: customModelConfig},
			{name: "custom prompt param config", config: customPromptParamConfig},
			{name: "custom uri config", config: customUriConfig},
			{name: "custom ids config", config: customIdsConfig},
		}
		for _, tc := range accepted {
			t.Run(tc.name, func(t *testing.T) {
				host, status := test.NewTestHost(tc.config)
				defer host.Reset()
				require.Equal(t, types.OnPluginStartStatusOK, status)

				config, err := host.GetMatchConfig()
				require.NoError(t, err)
				require.NotNil(t, config)
			})
		}

		rejected := []struct {
			name   string
			config json.RawMessage
		}{
			{name: "invalid config - missing api key", config: invalidConfig},
			{name: "invalid config - invalid uri", config: invalidUriConfig},
		}
		for _, tc := range rejected {
			t.Run(tc.name, func(t *testing.T) {
				host, status := test.NewTestHost(tc.config)
				defer host.Reset()
				require.Equal(t, types.OnPluginStartStatusFailed, status)
			})
		}
	})
}
// TestOnHttpRequestHeaders drives the request-header phase with different
// query strings and checks the action the plugin returns.
//
// When the configured prompt parameter is present (even with an empty value),
// the plugin is expected to pause the request while it calls the external AI
// service. The mocked service reply is then expected to come back verbatim as
// the local response, with the same status code and body. When the query
// string or the prompt parameter is missing, the plugin is expected to return
// ActionContinue.
func TestOnHttpRequestHeaders(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		cases := []struct {
			name   string
			config json.RawMessage
			path   string
			// wantAction is the action expected from the request-header phase.
			wantAction types.Action
			// upstreamStatus and upstreamBody form the mocked AI service
			// reply. They are only used when wantAction is ActionPause.
			upstreamStatus string
			upstreamBody   string
			// wantStatus is the expected status code of the local response.
			wantStatus uint32
		}{
			{
				name:           "basic request headers with query params",
				config:         basicConfig,
				path:           "/api/chat?prompt=Hello, how are you?",
				wantAction:     types.ActionPause,
				upstreamStatus: "200",
				upstreamBody:   `{"choices":[{"text":"I'm doing well, thank you for asking!"}]}`,
				wantStatus:     200,
			},
			{
				name:           "custom prompt param request headers",
				config:         customPromptParamConfig,
				path:           "/api/chat?question=What is the weather like?",
				wantAction:     types.ActionPause,
				upstreamStatus: "200",
				upstreamBody:   `{"choices":[{"text":"I don't have access to real-time weather information."}]}`,
				wantStatus:     200,
			},
			{
				// No query string at all.
				// NOTE(review): the plugin may also send a local error response
				// here; that is not asserted.
				name:       "missing query params",
				config:     basicConfig,
				path:       "/api/chat",
				wantAction: types.ActionContinue,
			},
			{
				// A query string without the configured prompt parameter.
				name:       "missing prompt param",
				config:     basicConfig,
				path:       "/api/chat?other=value",
				wantAction: types.ActionContinue,
			},
			{
				// An empty prompt value still triggers the external call.
				name:           "empty prompt param",
				config:         basicConfig,
				path:           "/api/chat?prompt=",
				wantAction:     types.ActionPause,
				upstreamStatus: "200",
				upstreamBody:   `{"choices":[{"text":"Empty prompt response"}]}`,
				wantStatus:     200,
			},
			{
				name:           "external service call success",
				config:         basicConfig,
				path:           "/api/chat?prompt=Tell me a joke",
				wantAction:     types.ActionPause,
				upstreamStatus: "200",
				upstreamBody:   `{"choices":[{"text":"Why don't scientists trust atoms? Because they make up everything!"}]}`,
				wantStatus:     200,
			},
			{
				// A failing upstream status must be propagated unchanged.
				name:           "external service call failure",
				config:         basicConfig,
				path:           "/api/chat?prompt=Hello",
				wantAction:     types.ActionPause,
				upstreamStatus: "429",
				upstreamBody:   `{"error":"Rate limit exceeded"}`,
				wantStatus:     429,
			},
		}

		for _, tc := range cases {
			t.Run(tc.name, func(t *testing.T) {
				host, status := test.NewTestHost(tc.config)
				defer host.Reset()
				require.Equal(t, types.OnPluginStartStatusOK, status)

				action := host.CallOnHttpRequestHeaders([][2]string{
					{":authority", "example.com"},
					{":path", tc.path},
					{":method", "GET"},
				})
				require.Equal(t, tc.wantAction, action)

				if tc.wantAction == types.ActionPause {
					// Simulate the external AI service answering the outbound call.
					host.CallOnHttpCall([][2]string{
						{"Content-Type", "application/json"},
						{":status", tc.upstreamStatus},
					}, []byte(tc.upstreamBody))

					response := host.GetLocalResponse()
					// Fail with a clear message rather than a nil-pointer panic
					// if no local response was sent.
					require.NotNil(t, response)
					require.Equal(t, tc.wantStatus, response.StatusCode)
					require.Equal(t, tc.upstreamBody, string(response.Data))
				}
				host.CompleteHttp()
			})
		}
	})
}
// TestCompleteFlow walks one request through the full proxy flow:
//  1. the request-header phase pauses while the external AI service is called;
//  2. the mocked AI reply is expected to be returned as the local response,
//     with the same status and body;
//  3. the HTTP exchange is completed.
func TestCompleteFlow(t *testing.T) {
	test.RunTest(t, func(t *testing.T) {
		t.Run("complete chatgpt proxy flow", func(t *testing.T) {
			host, status := test.NewTestHost(basicConfig)
			defer host.Reset()
			require.Equal(t, types.OnPluginStartStatusOK, status)

			// 1. Request headers: the external call is expected to pause the request.
			action := host.CallOnHttpRequestHeaders([][2]string{
				{":authority", "example.com"},
				{":path", "/api/chat?prompt=What is artificial intelligence?"},
				{":method", "GET"},
			})
			require.Equal(t, types.ActionPause, action)

			// 2. Mock the external AI service reply and check it is relayed.
			const upstreamBody = `{"choices":[{"text":"Artificial Intelligence (AI) is a branch of computer science that aims to create systems capable of performing tasks that typically require human intelligence."}]}`
			host.CallOnHttpCall([][2]string{
				{"Content-Type", "application/json"},
				{":status", "200"},
			}, []byte(upstreamBody))

			response := host.GetLocalResponse()
			// Fail with a clear message rather than a nil-pointer panic if no
			// local response was sent.
			require.NotNil(t, response)
			require.Equal(t, uint32(200), response.StatusCode)
			require.Equal(t, upstreamBody, string(response.Data))

			// 3. Finish the request.
			host.CompleteHttp()
		})
	})
}