optimize plugin sdk (#1930)

This commit is contained in:
澄潭
2025-03-22 22:46:37 +08:00
committed by GitHub
parent 1812a6b0a9
commit 45fbc8b084
117 changed files with 1036 additions and 766 deletions

View File

@@ -9,9 +9,9 @@ import (
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/google/uuid"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -56,35 +56,35 @@ func (g *geminiProvider) GetProviderType() string {
return providerTypeGemini
}
func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
g.config.handleRequestHeaders(g, ctx, apiName, log)
func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
g.config.handleRequestHeaders(g, ctx, apiName)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return nil
}
func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestHostHeader(headers, geminiDomain)
headers.Set(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx))
}
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !g.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body)
}
func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
if apiName == ApiNameChatCompletion {
return g.onChatCompletionRequestBody(ctx, body, headers, log)
return g.onChatCompletionRequestBody(ctx, body, headers)
} else {
return g.onEmbeddingsRequestBody(ctx, body, headers, log)
return g.onEmbeddingsRequestBody(ctx, body, headers)
}
}
func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
request := &chatCompletionRequest{}
err := g.config.parseRequestAndMapModel(ctx, request, body, log)
err := g.config.parseRequestAndMapModel(ctx, request, body)
if err != nil {
return nil, err
}
@@ -95,9 +95,9 @@ func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, bo
return json.Marshal(geminiRequest)
}
func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
request := &embeddingsRequest{}
if err := g.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
if err := g.config.parseRequestAndMapModel(ctx, request, body); err != nil {
return nil, err
}
path := g.getRequestPath(ApiNameEmbeddings, request.Model, false)
@@ -107,7 +107,7 @@ func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
return json.Marshal(geminiRequest)
}
func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
log.Infof("chunk body:%s", string(chunk))
if isLastChunk || len(chunk) == 0 {
return nil, nil
@@ -143,15 +143,15 @@ func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
return []byte(modifiedResponseChunk), nil
}
func (g *geminiProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
func (g *geminiProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
if apiName == ApiNameChatCompletion {
return g.onChatCompletionResponseBody(ctx, body, log)
return g.onChatCompletionResponseBody(ctx, body)
} else {
return g.onEmbeddingsResponseBody(ctx, body, log)
return g.onEmbeddingsResponseBody(ctx, body)
}
}
func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
geminiResponse := &geminiChatResponse{}
if err := json.Unmarshal(body, geminiResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal gemini chat response: %v", err)
@@ -164,7 +164,7 @@ func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, b
return json.Marshal(response)
}
func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
geminiResponse := &geminiEmbeddingResponse{}
if err := json.Unmarshal(body, geminiResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal gemini embeddings response: %v", err)
@@ -434,7 +434,7 @@ func (g *geminiProvider) buildToolCalls(candidate *geminiChatCandidate) []toolCa
}
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
if err != nil {
proxywasm.LogErrorf("get toolCalls from gemini response failed: " + err.Error())
log.Errorf("get toolCalls from gemini response failed: " + err.Error())
return toolCalls
}
toolCall := toolCall{