mirror of
https://github.com/alibaba/higress.git
synced 2026-06-07 03:37:28 +08:00
optimize plugin sdk (#1930)
This commit is contained in:
@@ -9,9 +9,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/google/uuid"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
)
|
||||
|
||||
@@ -56,35 +56,35 @@ func (g *geminiProvider) GetProviderType() string {
|
||||
return providerTypeGemini
|
||||
}
|
||||
|
||||
func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
|
||||
g.config.handleRequestHeaders(g, ctx, apiName, log)
|
||||
func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
|
||||
g.config.handleRequestHeaders(g, ctx, apiName)
|
||||
// Delay the header processing to allow changing streaming mode in OnRequestBody
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
|
||||
func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
|
||||
util.OverwriteRequestHostHeader(headers, geminiDomain)
|
||||
headers.Set(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx))
|
||||
}
|
||||
|
||||
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
|
||||
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
|
||||
if !g.config.isSupportedAPI(apiName) {
|
||||
return types.ActionContinue, errUnsupportedApiName
|
||||
}
|
||||
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
|
||||
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body)
|
||||
}
|
||||
|
||||
func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
|
||||
if apiName == ApiNameChatCompletion {
|
||||
return g.onChatCompletionRequestBody(ctx, body, headers, log)
|
||||
return g.onChatCompletionRequestBody(ctx, body, headers)
|
||||
} else {
|
||||
return g.onEmbeddingsRequestBody(ctx, body, headers, log)
|
||||
return g.onEmbeddingsRequestBody(ctx, body, headers)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
err := g.config.parseRequestAndMapModel(ctx, request, body, log)
|
||||
err := g.config.parseRequestAndMapModel(ctx, request, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -95,9 +95,9 @@ func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, bo
|
||||
return json.Marshal(geminiRequest)
|
||||
}
|
||||
|
||||
func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
|
||||
func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
|
||||
request := &embeddingsRequest{}
|
||||
if err := g.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
|
||||
if err := g.config.parseRequestAndMapModel(ctx, request, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
path := g.getRequestPath(ApiNameEmbeddings, request.Model, false)
|
||||
@@ -107,7 +107,7 @@ func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
|
||||
return json.Marshal(geminiRequest)
|
||||
}
|
||||
|
||||
func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
|
||||
func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
|
||||
log.Infof("chunk body:%s", string(chunk))
|
||||
if isLastChunk || len(chunk) == 0 {
|
||||
return nil, nil
|
||||
@@ -143,15 +143,15 @@ func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
|
||||
return []byte(modifiedResponseChunk), nil
|
||||
}
|
||||
|
||||
func (g *geminiProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
func (g *geminiProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
|
||||
if apiName == ApiNameChatCompletion {
|
||||
return g.onChatCompletionResponseBody(ctx, body, log)
|
||||
return g.onChatCompletionResponseBody(ctx, body)
|
||||
} else {
|
||||
return g.onEmbeddingsResponseBody(ctx, body, log)
|
||||
return g.onEmbeddingsResponseBody(ctx, body)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
||||
geminiResponse := &geminiChatResponse{}
|
||||
if err := json.Unmarshal(body, geminiResponse); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal gemini chat response: %v", err)
|
||||
@@ -164,7 +164,7 @@ func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, b
|
||||
return json.Marshal(response)
|
||||
}
|
||||
|
||||
func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
|
||||
geminiResponse := &geminiEmbeddingResponse{}
|
||||
if err := json.Unmarshal(body, geminiResponse); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal gemini embeddings response: %v", err)
|
||||
@@ -434,7 +434,7 @@ func (g *geminiProvider) buildToolCalls(candidate *geminiChatCandidate) []toolCa
|
||||
}
|
||||
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
|
||||
if err != nil {
|
||||
proxywasm.LogErrorf("get toolCalls from gemini response failed: " + err.Error())
|
||||
log.Errorf("get toolCalls from gemini response failed: " + err.Error())
|
||||
return toolCalls
|
||||
}
|
||||
toolCall := toolCall{
|
||||
|
||||
Reference in New Issue
Block a user