optimize plugin sdk (#1930)

This commit is contained in:
澄潭
2025-03-22 22:46:37 +08:00
committed by GitHub
parent 1812a6b0a9
commit 45fbc8b084
117 changed files with 1036 additions and 766 deletions

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
@@ -79,7 +80,7 @@ type qwenProvider struct {
contextCache *contextCache
}
func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
if m.config.qwenDomain != "" {
util.OverwriteRequestHostHeader(headers, m.config.qwenDomain)
} else {
@@ -92,11 +93,11 @@ func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
}
}
func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
if m.config.qwenEnableCompatible {
if gjson.GetBytes(body, "model").Exists() {
rawModel := gjson.GetBytes(body, "model").String()
mappedModel := getMappedModel(rawModel, m.config.modelMapping, log)
mappedModel := getMappedModel(rawModel, m.config.modelMapping)
newBody, err := sjson.SetBytes(body, "model", mappedModel)
if err != nil {
log.Errorf("Replace model error: %v", err)
@@ -108,11 +109,11 @@ func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiN
}
switch apiName {
case ApiNameChatCompletion:
return m.onChatCompletionRequestBody(ctx, body, headers, log)
return m.onChatCompletionRequestBody(ctx, body, headers)
case ApiNameEmbeddings:
return m.onEmbeddingsRequestBody(ctx, body, log)
return m.onEmbeddingsRequestBody(ctx, body)
default:
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
return m.config.defaultTransformRequestBody(ctx, apiName, body)
}
}
@@ -120,8 +121,8 @@ func (m *qwenProvider) GetProviderType() string {
return providerTypeQwen
}
func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log)
func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
m.config.handleRequestHeaders(m, ctx, apiName)
if m.config.protocol == protocolOriginal {
ctx.DontReadRequestBody()
@@ -131,16 +132,16 @@ func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName
return nil
}
func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
}
func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
request := &chatCompletionRequest{}
err := m.config.parseRequestAndMapModel(ctx, request, body, log)
err := m.config.parseRequestAndMapModel(ctx, request, body)
if err != nil {
return nil, err
}
@@ -162,9 +163,9 @@ func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body
return m.buildQwenTextGenerationRequest(ctx, request, streaming)
}
func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
request := &embeddingsRequest{}
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
if err := m.config.parseRequestAndMapModel(ctx, request, body); err != nil {
return nil, err
}
@@ -175,7 +176,7 @@ func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []b
return json.Marshal(qwenRequest)
}
func (m *qwenProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent, log wrapper.Log) ([]StreamEvent, error) {
func (m *qwenProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent) ([]StreamEvent, error) {
if m.config.qwenEnableCompatible || name != ApiNameChatCompletion {
return nil, nil
}
@@ -189,7 +190,7 @@ func (m *qwenProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, e
}
var outputEvents []StreamEvent
responses := m.buildChatCompletionStreamingResponse(ctx, qwenResponse, incrementalStreaming, log)
responses := m.buildChatCompletionStreamingResponse(ctx, qwenResponse, incrementalStreaming)
for _, response := range responses {
responseBody, err := json.Marshal(response)
if err != nil {
@@ -203,15 +204,15 @@ func (m *qwenProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, e
return outputEvents, nil
}
func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
if m.config.qwenEnableCompatible {
return body, nil
}
if apiName == ApiNameChatCompletion {
return m.onChatCompletionResponseBody(ctx, body, log)
return m.onChatCompletionResponseBody(ctx, body)
}
if apiName == ApiNameEmbeddings {
return m.onEmbeddingsResponseBody(ctx, body, log)
return m.onEmbeddingsResponseBody(ctx, body)
}
if m.config.isSupportedAPI(apiName) {
return body, nil
@@ -219,7 +220,7 @@ func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName Ap
return nil, errUnsupportedApiName
}
func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
qwenResponse := &qwenTextGenResponse{}
if err := json.Unmarshal(body, qwenResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal Qwen response: %v", err)
@@ -228,7 +229,7 @@ func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, bod
return json.Marshal(response)
}
func (m *qwenProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
func (m *qwenProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
qwenResponse := &qwenTextEmbeddingResponse{}
if err := json.Unmarshal(body, qwenResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal Qwen response: %v", err)
@@ -308,7 +309,7 @@ func (m *qwenProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, qwen
}
}
func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse, incrementalStreaming bool, log wrapper.Log) []*chatCompletionResponse {
func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse, incrementalStreaming bool) []*chatCompletionResponse {
baseMessage := chatCompletionResponse{
Id: qwenResponse.RequestId,
Created: time.Now().UnixMilli() / 1000,