From d5a9ff3a989c707291c4cea87a4e4cb737e715b7 Mon Sep 17 00:00:00 2001 From: Kent Dong Date: Tue, 16 Jul 2024 18:38:43 +0800 Subject: [PATCH] fix: Fix possible type-casting related panics in ai-proxy plugin (#1127) --- plugins/wasm-go/extensions/ai-proxy/main.go | 8 ++++---- .../extensions/ai-proxy/provider/baidu.go | 4 ++-- .../extensions/ai-proxy/provider/claude.go | 4 ++-- .../extensions/ai-proxy/provider/hunyuan.go | 4 ++-- .../extensions/ai-proxy/provider/minimax.go | 6 +++--- .../wasm-go/extensions/ai-proxy/provider/qwen.go | 9 +++------ plugins/wasm-go/pkg/wrapper/plugin_wrapper.go | 16 ++++++++++++++++ 7 files changed, 32 insertions(+), 19 deletions(-) diff --git a/plugins/wasm-go/extensions/ai-proxy/main.go b/plugins/wasm-go/extensions/ai-proxy/main.go index da81e570e..9be121dcf 100644 --- a/plugins/wasm-go/extensions/ai-proxy/main.go +++ b/plugins/wasm-go/extensions/ai-proxy/main.go @@ -99,7 +99,7 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig log.Debugf("[onHttpRequestBody] provider=%s", activeProvider.GetProviderType()) if handler, ok := activeProvider.(provider.RequestBodyHandler); ok { - apiName := ctx.GetContext(ctxKeyApiName).(provider.ApiName) + apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName) action, err := handler.OnRequestBody(ctx, apiName, body, log) if err == nil { return action @@ -139,7 +139,7 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginCo } if handler, ok := activeProvider.(provider.ResponseHeadersHandler); ok { - apiName := ctx.GetContext(ctxKeyApiName).(provider.ApiName) + apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName) action, err := handler.OnResponseHeaders(ctx, apiName, log) if err == nil { return action @@ -171,7 +171,7 @@ func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.Plugin log.Debugf("isLastChunk=%v chunk: %s", isLastChunk, string(chunk)) if handler, ok := activeProvider.(provider.StreamingResponseBodyHandler); ok { - apiName := ctx.GetContext(ctxKeyApiName).(provider.ApiName) + apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName) modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk, log) if err == nil && modifiedChunk != nil { return modifiedChunk @@ -193,7 +193,7 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi //log.Debugf("response body: %s", string(body)) if handler, ok := activeProvider.(provider.ResponseBodyHandler); ok { - apiName := ctx.GetContext(ctxKeyApiName).(provider.ApiName) + apiName, _ := ctx.GetContext(ctxKeyApiName).(provider.ApiName) action, err := handler.OnResponseBody(ctx, apiName, body, log) if err == nil { return action diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go index 2b8e3972d..eca234a70 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/baidu.go @@ -298,7 +298,7 @@ func (b *baiduProvider) responseBaidu2OpenAI(ctx wrapper.HttpContext, response * return &chatCompletionResponse{ Id: response.Id, Created: time.Now().UnixMilli() / 1000, - Model: ctx.GetContext(ctxKeyFinalRequestModel).(string), + Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""), SystemFingerprint: "", Object: objectChatCompletion, Choices: []chatCompletionChoice{choice}, @@ -321,7 +321,7 @@ func (b *baiduProvider) streamResponseBaidu2OpenAI(ctx wrapper.HttpContext, resp return &chatCompletionResponse{ Id: response.Id, Created: time.Now().UnixMilli() / 1000, - Model: ctx.GetContext(ctxKeyFinalRequestModel).(string), + Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""), SystemFingerprint: "", Object: objectChatCompletion, Choices: []chatCompletionChoice{choice}, diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go index ad9c51f8f..439cbadbb 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/claude.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/claude.go @@ -292,7 +292,7 @@ func (c *claudeProvider) responseClaude2OpenAI(ctx wrapper.HttpContext, origResp return &chatCompletionResponse{ Id: origResponse.Id, Created: time.Now().UnixMilli() / 1000, - Model: ctx.GetContext(ctxKeyFinalRequestModel).(string), + Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""), SystemFingerprint: "", Object: objectChatCompletion, Choices: []chatCompletionChoice{choice}, @@ -356,7 +356,7 @@ func createChatCompletionResponse(ctx wrapper.HttpContext, response *claudeTextG return &chatCompletionResponse{ Id: response.Message.Id, Created: time.Now().UnixMilli() / 1000, - Model: ctx.GetContext(ctxKeyFinalRequestModel).(string), + Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""), Object: objectChatCompletionChunk, Choices: []chatCompletionChoice{choice}, } diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go index 2c30ad5bf..c1fc736d6 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go @@ -351,7 +351,7 @@ func (m *hunyuanProvider) convertChunkFromHunyuanToOpenAI(ctx wrapper.HttpContex openAIFormattedChunk := &chatCompletionResponse{ Id: hunyuanFormattedChunk.Id, Created: time.Now().UnixMilli() / 1000, - Model: ctx.GetContext(ctxKeyFinalRequestModel).(string), + Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""), SystemFingerprint: "", Object: objectChatCompletionChunk, Usage: usage{ @@ -470,7 +470,7 @@ func (m *hunyuanProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, h return &chatCompletionResponse{ Id: hunyuanResponse.Response.Id, Created: time.Now().UnixMilli() / 1000, - Model: ctx.GetContext(ctxKeyFinalRequestModel).(string), + Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""), SystemFingerprint: "", Object: objectChatCompletion, Choices: choices, diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go index 44720f7fb..4a6d840f9 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/minimax.go @@ -222,9 +222,9 @@ func (m *minimaxProvider) OnResponseHeaders(ctx wrapper.HttpContext, apiName Api return types.ActionContinue, nil } // 模型对应接口为ChatCompletion v2,跳过OnStreamingResponseBody()和OnResponseBody() - model := ctx.GetContext(ctxKeyFinalRequestModel) - if model != nil { - _, ok := chatCompletionProModels[model.(string)] + model := ctx.GetStringContext(ctxKeyFinalRequestModel, "") + if model != "" { + _, ok := chatCompletionProModels[model] if !ok { ctx.DontReadResponseBody() return types.ActionContinue, nil diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go index fa213493e..53efa965e 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/qwen.go @@ -229,10 +229,7 @@ func (m *qwenProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name Api receivedBody = append(bufferedStreamingBody, chunk...) } - incrementalStreaming, err := ctx.GetContext(ctxKeyIncrementalStreaming).(bool) - if !err { - incrementalStreaming = false - } + incrementalStreaming := ctx.GetBoolContext(ctxKeyIncrementalStreaming, false) eventStartIndex, lineStartIndex, valueStartIndex := -1, -1, -1 @@ -387,7 +384,7 @@ func (m *qwenProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, qwen return &chatCompletionResponse{ Id: qwenResponse.RequestId, Created: time.Now().UnixMilli() / 1000, - Model: ctx.GetContext(ctxKeyFinalRequestModel).(string), + Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""), SystemFingerprint: "", Object: objectChatCompletion, Choices: choices, @@ -403,7 +400,7 @@ func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpCont baseMessage := chatCompletionResponse{ Id: qwenResponse.RequestId, Created: time.Now().UnixMilli() / 1000, - Model: ctx.GetContext(ctxKeyFinalRequestModel).(string), + Model: ctx.GetStringContext(ctxKeyFinalRequestModel, ""), Choices: make([]chatCompletionChoice, 0), SystemFingerprint: "", Object: objectChatCompletionChunk, diff --git a/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go b/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go index 1b1210fe5..dfdcf9e4d 100644 --- a/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go +++ b/plugins/wasm-go/pkg/wrapper/plugin_wrapper.go @@ -38,6 +38,8 @@ type HttpContext interface { Method() string SetContext(key string, value interface{}) GetContext(key string) interface{} + GetBoolContext(key string, defaultValue bool) bool + GetStringContext(key, defaultValue string) string // If the onHttpRequestBody handle is not set, the request body will not be read by default DontReadRequestBody() // If the onHttpResponseBody handle is not set, the request body will not be read by default @@ -297,6 +299,20 @@ func (ctx *CommonHttpCtx[PluginConfig]) GetContext(key string) interface{} { return ctx.userContext[key] } +func (ctx *CommonHttpCtx[PluginConfig]) GetBoolContext(key string, defaultValue bool) bool { + if b, ok := ctx.userContext[key].(bool); ok { + return b + } + return defaultValue +} + +func (ctx *CommonHttpCtx[PluginConfig]) GetStringContext(key, defaultValue string) string { + if s, ok := ctx.userContext[key].(string); ok { + return s + } + return defaultValue +} + func (ctx *CommonHttpCtx[PluginConfig]) Scheme() string { proxywasm.SetEffectiveContext(ctx.contextID) return GetRequestScheme()