optimize plugin sdk (#1930)

This commit is contained in:
澄潭
2025-03-22 22:46:37 +08:00
committed by GitHub
parent 1812a6b0a9
commit 45fbc8b084
117 changed files with 1036 additions and 766 deletions

View File

@@ -48,20 +48,20 @@ func (m *ai360Provider) GetProviderType() string {
return providerTypeAi360
}
func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log)
func (m *ai360Provider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
m.config.handleRequestHeaders(m, ctx, apiName)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return nil
}
func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (m *ai360Provider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
}
func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (m *ai360Provider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestHostHeader(headers, ai360Domain)
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))

View File

@@ -8,6 +8,7 @@ import (
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -70,16 +71,16 @@ func (m *azureProvider) GetProviderType() string {
return providerTypeAzure
}
func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log)
func (m *azureProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
m.config.handleRequestHeaders(m, ctx, apiName)
return nil
}
func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
func (m *azureProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
}
func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (m *azureProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
finalRequestUrl := *m.serviceUrl
if u, e := url.Parse(ctx.Path()); e == nil {
if len(u.Query()) != 0 {

View File

@@ -49,19 +49,19 @@ func (m *baichuanProvider) GetProviderType() string {
return providerTypeBaichuan
}
func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log)
func (m *baichuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
m.config.handleRequestHeaders(m, ctx, apiName)
return nil
}
func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (m *baichuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
}
func (m *baichuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (m *baichuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, baichuanDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))

View File

@@ -50,19 +50,19 @@ func (g *baiduProvider) GetProviderType() string {
return providerTypeBaidu
}
func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
g.config.handleRequestHeaders(g, ctx, apiName, log)
func (g *baiduProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
g.config.handleRequestHeaders(g, ctx, apiName)
return nil
}
func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (g *baiduProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !g.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body)
}
func (g *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (g *baiduProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), g.config.capabilities)
util.OverwriteRequestHostHeader(headers, baiduDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx))

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -112,12 +113,12 @@ func (c *claudeProvider) GetProviderType() string {
return providerTypeClaude
}
func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
c.config.handleRequestHeaders(c, ctx, apiName, log)
func (c *claudeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
c.config.handleRequestHeaders(c, ctx, apiName)
return nil
}
func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), c.config.capabilities)
util.OverwriteRequestHostHeader(headers, claudeDomain)
@@ -130,26 +131,26 @@ func (c *claudeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
headers.Set("anthropic-version", c.config.claudeVersion)
}
func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !c.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body)
}
func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return c.config.defaultTransformRequestBody(ctx, apiName, body, log)
return c.config.defaultTransformRequestBody(ctx, apiName, body)
}
request := &chatCompletionRequest{}
if err := c.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
if err := c.config.parseRequestAndMapModel(ctx, request, body); err != nil {
return nil, err
}
claudeRequest := c.buildClaudeTextGenRequest(request)
return json.Marshal(claudeRequest)
}
func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return body, nil
}
@@ -164,7 +165,7 @@ func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName
return json.Marshal(response)
}
func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
if isLastChunk || len(chunk) == 0 {
return nil, nil
}
@@ -185,7 +186,7 @@ func (c *claudeProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
log.Errorf("unable to unmarshal claude response: %v", err)
continue
}
response := c.streamResponseClaude2OpenAI(ctx, &claudeResponse, log)
response := c.streamResponseClaude2OpenAI(ctx, &claudeResponse)
if response != nil {
responseBody, err := json.Marshal(response)
if err != nil {
@@ -266,7 +267,7 @@ func stopReasonClaude2OpenAI(reason *string) string {
}
}
func (c *claudeProvider) streamResponseClaude2OpenAI(ctx wrapper.HttpContext, origResponse *claudeTextGenStreamResponse, log wrapper.Log) *chatCompletionResponse {
func (c *claudeProvider) streamResponseClaude2OpenAI(ctx wrapper.HttpContext, origResponse *claudeTextGenStreamResponse) *chatCompletionResponse {
switch origResponse.Type {
case "message_start":
choice := chatCompletionChoice{

View File

@@ -48,19 +48,19 @@ func (c *cloudflareProvider) GetProviderType() string {
return providerTypeCloudflare
}
func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
c.config.handleRequestHeaders(c, ctx, apiName, log)
func (c *cloudflareProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
c.config.handleRequestHeaders(c, ctx, apiName)
return nil
}
func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (c *cloudflareProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !c.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body, log)
return c.config.handleRequestBody(c, c.contextCache, ctx, apiName, body)
}
func (c *cloudflareProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (c *cloudflareProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestPathHeader(headers, strings.Replace(cloudflareChatCompletionPath, "{account_id}", c.config.cloudflareAccountId, 1))
util.OverwriteRequestHostHeader(headers, cloudflareDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+c.config.GetApiTokenInUse(ctx))

View File

@@ -65,16 +65,16 @@ func (m *cohereProvider) GetProviderType() string {
return providerTypeCohere
}
func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log)
func (m *cohereProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
m.config.handleRequestHeaders(m, ctx, apiName)
return nil
}
func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (m *cohereProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
}
func (m *cohereProvider) buildCohereRequest(origin *chatCompletionRequest) *cohereTextGenRequest {
@@ -96,19 +96,19 @@ func (m *cohereProvider) buildCohereRequest(origin *chatCompletionRequest) *cohe
}
}
func (m *cohereProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (m *cohereProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, cohereDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}
func (m *cohereProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
func (m *cohereProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
return m.config.defaultTransformRequestBody(ctx, apiName, body)
}
request := &chatCompletionRequest{}
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
if err := m.config.parseRequestAndMapModel(ctx, request, body); err != nil {
return nil, err
}

View File

@@ -8,6 +8,7 @@ import (
"net/url"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/tidwall/gjson"
@@ -64,7 +65,7 @@ type ContextInserter interface {
insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error)
}
func (c *contextCache) GetContent(callback func(string, error), log wrapper.Log) error {
func (c *contextCache) GetContent(callback func(string, error)) error {
if callback == nil {
return errors.New("callback is nil")
}
@@ -106,27 +107,27 @@ func createContextCache(providerConfig *ProviderConfig) *contextCache {
}
}
func (c *contextCache) GetContextFromFile(ctx wrapper.HttpContext, provider Provider, body []byte, log wrapper.Log) error {
func (c *contextCache) GetContextFromFile(ctx wrapper.HttpContext, provider Provider, body []byte) error {
if c.loaded {
log.Debugf("context file loaded from cache")
insertContext(provider, c.content, nil, body, log)
insertContext(provider, c.content, nil, body)
return nil
}
log.Infof("loading context file from %s", c.fileUrl.String())
return c.client.Get(c.fileUrl.Path, nil, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
if statusCode != http.StatusOK {
insertContext(provider, "", fmt.Errorf("failed to load context file, status: %d", statusCode), nil, log)
insertContext(provider, "", fmt.Errorf("failed to load context file, status: %d", statusCode), nil)
return
}
c.content = string(responseBody)
c.loaded = true
log.Debugf("content: %s", c.content)
insertContext(provider, c.content, nil, body, log)
insertContext(provider, c.content, nil, body)
}, c.timeout)
}
func insertContext(provider Provider, content string, err error, body []byte, log wrapper.Log) {
func insertContext(provider Provider, content string, err error, body []byte) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
@@ -146,7 +147,7 @@ func insertContext(provider Provider, content string, err error, body []byte, lo
if err != nil {
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.insert_ctx_failed", typ), fmt.Errorf("failed to insert context message: %v", err))
}
if err := replaceRequestBody(body, log); err != nil {
if err := replaceRequestBody(body); err != nil {
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.replace_request_body_failed", typ), fmt.Errorf("failed to replace request body: %v", err))
}
}

View File

@@ -42,12 +42,12 @@ func (m *cozeProvider) GetProviderType() string {
return providerTypeCoze
}
func (m *cozeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log)
func (m *cozeProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
m.config.handleRequestHeaders(m, ctx, apiName)
return nil
}
func (m *cozeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (m *cozeProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestHostHeader(headers, cozeDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")

View File

@@ -82,12 +82,12 @@ func (d *deeplProvider) GetProviderType() string {
return providerTypeDeepl
}
func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
d.config.handleRequestHeaders(d, ctx, apiName, log)
func (d *deeplProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
d.config.handleRequestHeaders(d, ctx, apiName)
return nil
}
func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
if apiName != "" {
util.OverwriteRequestPathHeader(headers, deeplChatCompletionPath)
}
@@ -96,14 +96,14 @@ func (d *deeplProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
util.OverwriteRequestAuthorizationHeader(headers, "DeepL-Auth-Key "+d.config.GetApiTokenInUse(ctx))
}
func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (d *deeplProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !d.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log)
return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body)
}
func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return nil, err
@@ -119,7 +119,7 @@ func (d *deeplProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, api
return json.Marshal(baiduRequest)
}
func (d *deeplProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
func (d *deeplProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return body, nil
}

View File

@@ -51,19 +51,19 @@ func (m *deepseekProvider) GetProviderType() string {
return providerTypeDeepSeek
}
func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log)
func (m *deepseekProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
m.config.handleRequestHeaders(m, ctx, apiName)
return nil
}
func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (m *deepseekProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
}
func (m *deepseekProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (m *deepseekProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, deepseekDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
@@ -50,12 +51,12 @@ func (d *difyProvider) GetProviderType() string {
return providerTypeDify
}
func (d *difyProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
d.config.handleRequestHeaders(d, ctx, apiName, log)
func (d *difyProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
d.config.handleRequestHeaders(d, ctx, apiName)
return nil
}
func (d *difyProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (d *difyProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
if d.config.difyApiUrl != "" {
log.Debugf("use local host: %s", d.config.difyApiUrl)
util.OverwriteRequestHostHeader(headers, d.config.difyApiUrl)
@@ -73,19 +74,19 @@ func (d *difyProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+d.config.GetApiTokenInUse(ctx))
}
func (d *difyProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (d *difyProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}
return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body, log)
return d.config.handleRequestBody(d, d.contextCache, ctx, apiName, body)
}
func (d *difyProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
func (d *difyProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return d.config.defaultTransformRequestBody(ctx, apiName, body, log)
return d.config.defaultTransformRequestBody(ctx, apiName, body)
}
request := &chatCompletionRequest{}
err := d.config.parseRequestAndMapModel(ctx, request, body, log)
err := d.config.parseRequestAndMapModel(ctx, request, body)
if err != nil {
return nil, err
}
@@ -95,7 +96,7 @@ func (d *difyProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiN
return json.Marshal(difyRequest)
}
func (d *difyProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
func (d *difyProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return body, nil
}
@@ -146,7 +147,7 @@ func (d *difyProvider) responseDify2OpenAI(ctx wrapper.HttpContext, response *Di
}
}
func (d *difyProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
func (d *difyProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
if isLastChunk || len(chunk) == 0 {
return nil, nil
}

View File

@@ -49,19 +49,19 @@ func (m *doubaoProvider) GetProviderType() string {
return providerTypeDoubao
}
func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log)
func (m *doubaoProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
m.config.handleRequestHeaders(m, ctx, apiName)
return nil
}
func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (m *doubaoProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
}
func (m *doubaoProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (m *doubaoProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, doubaoDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))

View File

@@ -10,6 +10,7 @@ import (
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/google/uuid"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
@@ -125,11 +126,11 @@ func (c *ProviderConfig) initVariable() {
c.failover.ctxVmLease = provider + "-" + id + "-vmLease"
}
func parseConfig(json gjson.Result, config *any, log wrapper.Log) error {
func parseConfig(json gjson.Result, config *any) error {
return nil
}
func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log, activeProvider Provider) error {
func (c *ProviderConfig) SetApiTokensFailover(activeProvider Provider) error {
c.initVariable()
// Reset shared data in case plugin configuration is updated
log.Debugf("ai-proxy plugin configuration is updated, reset shared data")
@@ -147,7 +148,7 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log, activeProvider Pr
wrapper.RegisteTickFunc(c.failover.healthCheckInterval, func() {
// Only the Wasm VM that successfully acquires the lease will perform health check
if c.isFailoverEnabled() && c.tryAcquireOrRenewLease(vmID, log) {
if c.isFailoverEnabled() && c.tryAcquireOrRenewLease(vmID) {
log.Debugf("Successfully acquired or renewed lease for %v: %v", vmID, c.GetType())
unavailableTokens, _, err := getApiTokens(c.failover.ctxUnavailableApiTokens)
if err != nil {
@@ -157,7 +158,7 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log, activeProvider Pr
if len(unavailableTokens) > 0 {
for _, apiToken := range unavailableTokens {
log.Debugf("Perform health check for unavailable apiTokens: %s", strings.Join(unavailableTokens, ", "))
healthCheckEndpoint, headers, body := c.generateRequestHeadersAndBody(log)
healthCheckEndpoint, headers, body := c.generateRequestHeadersAndBody()
healthCheckClient = wrapper.NewClusterClient(wrapper.TargetCluster{
Cluster: healthCheckEndpoint.Cluster,
})
@@ -165,7 +166,7 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log, activeProvider Pr
ctx := createHttpContext()
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
modifiedHeaders, modifiedBody, err := c.transformRequestHeadersAndBody(ctx, activeProvider, headers, body, log)
modifiedHeaders, modifiedBody, err := c.transformRequestHeadersAndBody(ctx, activeProvider, headers, body)
if err != nil {
log.Errorf("Failed to transform request headers and body: %v", err)
}
@@ -173,7 +174,7 @@ func (c *ProviderConfig) SetApiTokensFailover(log wrapper.Log, activeProvider Pr
// The apiToken for ChatCompletion and Embeddings can be the same, so we only need to health check ChatCompletion
err = healthCheckClient.Post(generateUrl(modifiedHeaders), util.HeaderToSlice(modifiedHeaders), modifiedBody, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
if statusCode == 200 {
c.handleAvailableApiToken(apiToken, log)
c.handleAvailableApiToken(apiToken)
}
}, uint32(c.failover.healthCheckTimeout))
if err != nil {
@@ -191,19 +192,19 @@ func generateUrl(header http.Header) string {
return fmt.Sprintf("https://%s%s", header.Get(":authority"), header.Get(":path"))
}
func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, headers [][2]string, body []byte, log wrapper.Log) (http.Header, []byte, error) {
func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, headers [][2]string, body []byte) (http.Header, []byte, error) {
modifiedHeaders := util.SliceToHeader(headers)
if handler, ok := activeProvider.(TransformRequestHeadersHandler); ok {
handler.TransformRequestHeaders(ctx, ApiNameChatCompletion, modifiedHeaders, log)
handler.TransformRequestHeaders(ctx, ApiNameChatCompletion, modifiedHeaders)
}
var err error
if handler, ok := activeProvider.(TransformRequestBodyHandler); ok {
body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body, log)
body, err = handler.TransformRequestBody(ctx, ApiNameChatCompletion, body)
} else if handler, ok := activeProvider.(TransformRequestBodyHeadersHandler); ok {
body, err = handler.TransformRequestBodyHeaders(ctx, ApiNameChatCompletion, body, modifiedHeaders, log)
body, err = handler.TransformRequestBodyHeaders(ctx, ApiNameChatCompletion, body, modifiedHeaders)
} else {
body, err = c.defaultTransformRequestBody(ctx, ApiNameChatCompletion, body, log)
body, err = c.defaultTransformRequestBody(ctx, ApiNameChatCompletion, body)
}
if err != nil {
return nil, nil, fmt.Errorf("failed to transform request body: %v", err)
@@ -213,14 +214,14 @@ func (c *ProviderConfig) transformRequestHeadersAndBody(ctx wrapper.HttpContext,
}
func createHttpContext() *wrapper.CommonHttpCtx[any] {
setParseConfig := wrapper.ParseConfigBy[any](parseConfig)
setParseConfig := wrapper.ParseConfig[any](parseConfig)
vmCtx := wrapper.NewCommonVmCtx[any]("health-check", setParseConfig)
pluginCtx := vmCtx.NewPluginContext(rand.Uint32())
ctx := pluginCtx.NewHttpContext(rand.Uint32()).(*wrapper.CommonHttpCtx[any])
return ctx
}
func (c *ProviderConfig) generateRequestHeadersAndBody(log wrapper.Log) (HealthCheckEndpoint, [][2]string, []byte) {
func (c *ProviderConfig) generateRequestHeadersAndBody() (HealthCheckEndpoint, [][2]string, []byte) {
data, _, err := proxywasm.GetSharedData(c.failover.ctxHealthCheckEndpoint)
if err != nil {
log.Errorf("Failed to get request host and path: %v", err)
@@ -248,20 +249,20 @@ func (c *ProviderConfig) generateRequestHeadersAndBody(log wrapper.Log) (HealthC
return healthCheckEndpoint, headers, body
}
func (c *ProviderConfig) tryAcquireOrRenewLease(vmID string, log wrapper.Log) bool {
func (c *ProviderConfig) tryAcquireOrRenewLease(vmID string) bool {
now := time.Now().Unix()
data, cas, err := proxywasm.GetSharedData(c.failover.ctxVmLease)
if err != nil {
if errors.Is(err, types.ErrorStatusNotFound) {
return c.setLease(vmID, now, cas, log)
return c.setLease(vmID, now, cas)
} else {
log.Errorf("Failed to get lease: %v", err)
return false
}
}
if data == nil {
return c.setLease(vmID, now, cas, log)
return c.setLease(vmID, now, cas)
}
var lease Lease
@@ -275,13 +276,13 @@ func (c *ProviderConfig) tryAcquireOrRenewLease(vmID string, log wrapper.Log) bo
if lease.VMID == vmID || now-lease.Timestamp > 60 {
lease.VMID = vmID
lease.Timestamp = now
return c.setLease(vmID, now, cas, log)
return c.setLease(vmID, now, cas)
}
return false
}
func (c *ProviderConfig) setLease(vmID string, timestamp int64, cas uint32, log wrapper.Log) bool {
func (c *ProviderConfig) setLease(vmID string, timestamp int64, cas uint32) bool {
lease := Lease{
VMID: vmID,
Timestamp: timestamp,
@@ -305,7 +306,7 @@ func generateVMID() string {
// When number of request successes exceeds the threshold during health check,
// add the apiToken back to the available list and remove it from the unavailable list
func (c *ProviderConfig) handleAvailableApiToken(apiToken string, log wrapper.Log) {
func (c *ProviderConfig) handleAvailableApiToken(apiToken string) {
successApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount)
if err != nil {
log.Errorf("Failed to get successApiTokenRequestCount: %v", err)
@@ -315,18 +316,18 @@ func (c *ProviderConfig) handleAvailableApiToken(apiToken string, log wrapper.Lo
successCount := successApiTokenRequestCount[apiToken] + 1
if successCount >= c.failover.successThreshold {
log.Infof("healthcheck after failover: apiToken %s is available now, add it back to the apiTokens list", apiToken)
removeApiToken(c.failover.ctxUnavailableApiTokens, apiToken, log)
addApiToken(c.failover.ctxApiTokens, apiToken, log)
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken, log)
removeApiToken(c.failover.ctxUnavailableApiTokens, apiToken)
addApiToken(c.failover.ctxApiTokens, apiToken)
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken)
} else {
log.Debugf("apiToken %s is still unavailable, the number of health check passed: %d, continue to health check...", apiToken, successCount)
addApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken, log)
addApiTokenRequestCount(c.failover.ctxApiTokenRequestSuccessCount, apiToken)
}
}
// When number of request failures exceeds the threshold,
// remove the apiToken from the available list and add it to the unavailable list
func (c *ProviderConfig) handleUnavailableApiToken(ctx wrapper.HttpContext, apiToken string, log wrapper.Log) {
func (c *ProviderConfig) handleUnavailableApiToken(ctx wrapper.HttpContext, apiToken string) {
failureApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount)
if err != nil {
log.Errorf("Failed to get failureApiTokenRequestCount: %v", err)
@@ -346,26 +347,26 @@ func (c *ProviderConfig) handleUnavailableApiToken(ctx wrapper.HttpContext, apiT
failureCount := failureApiTokenRequestCount[apiToken] + 1
if failureCount >= c.failover.failureThreshold {
log.Infof("failover: apiToken %s is unavailable now, remove it from apiTokens list", apiToken)
removeApiToken(c.failover.ctxApiTokens, apiToken, log)
addApiToken(c.failover.ctxUnavailableApiTokens, apiToken, log)
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken, log)
removeApiToken(c.failover.ctxApiTokens, apiToken)
addApiToken(c.failover.ctxUnavailableApiTokens, apiToken)
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken)
// Set the request host and path to shared data in case they are needed in apiToken health check
c.setHealthCheckEndpoint(ctx, log)
c.setHealthCheckEndpoint(ctx)
} else {
log.Debugf("apiToken %s is still available as it has not reached the failure threshold, the number of failed request: %d", apiToken, failureCount)
addApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken, log)
addApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiToken)
}
}
func addApiToken(key, apiToken string, log wrapper.Log) {
modifyApiToken(key, apiToken, addApiTokenOperation, log)
func addApiToken(key, apiToken string) {
modifyApiToken(key, apiToken, addApiTokenOperation)
}
func removeApiToken(key, apiToken string, log wrapper.Log) {
modifyApiToken(key, apiToken, removeApiTokenOperation, log)
func removeApiToken(key, apiToken string) {
modifyApiToken(key, apiToken, removeApiTokenOperation)
}
func modifyApiToken(key, apiToken, op string, log wrapper.Log) {
func modifyApiToken(key, apiToken, op string) {
for attempt := 1; attempt <= casMaxRetries; attempt++ {
apiTokens, cas, err := getApiTokens(key)
if err != nil {
@@ -468,15 +469,15 @@ func getApiTokenRequestCount(key string) (map[string]int64, uint32, error) {
return apiTokens, cas, nil
}
func addApiTokenRequestCount(key, apiToken string, log wrapper.Log) {
modifyApiTokenRequestCount(key, apiToken, addApiTokenRequestCountOperation, log)
func addApiTokenRequestCount(key, apiToken string) {
modifyApiTokenRequestCount(key, apiToken, addApiTokenRequestCountOperation)
}
func resetApiTokenRequestCount(key, apiToken string, log wrapper.Log) {
modifyApiTokenRequestCount(key, apiToken, resetApiTokenRequestCountOperation, log)
func resetApiTokenRequestCount(key, apiToken string) {
modifyApiTokenRequestCount(key, apiToken, resetApiTokenRequestCountOperation)
}
func (c *ProviderConfig) ResetApiTokenRequestFailureCount(apiTokenInUse string, log wrapper.Log) {
func (c *ProviderConfig) ResetApiTokenRequestFailureCount(apiTokenInUse string) {
if c.isFailoverEnabled() {
failureApiTokenRequestCount, _, err := getApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount)
if err != nil {
@@ -484,12 +485,12 @@ func (c *ProviderConfig) ResetApiTokenRequestFailureCount(apiTokenInUse string,
}
if _, ok := failureApiTokenRequestCount[apiTokenInUse]; ok {
log.Infof("Reset apiToken %s request failure count", apiTokenInUse)
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiTokenInUse, log)
resetApiTokenRequestCount(c.failover.ctxApiTokenRequestFailureCount, apiTokenInUse)
}
}
}
func modifyApiTokenRequestCount(key, apiToken string, op string, log wrapper.Log) {
func modifyApiTokenRequestCount(key, apiToken string, op string) {
for attempt := 1; attempt <= casMaxRetries; attempt++ {
apiTokenRequestCount, cas, err := getApiTokenRequestCount(key)
if err != nil {
@@ -524,7 +525,7 @@ func (c *ProviderConfig) initApiTokens() error {
return setApiTokens(c.failover.ctxApiTokens, c.apiTokens, 0)
}
func (c *ProviderConfig) GetGlobalRandomToken(log wrapper.Log) string {
func (c *ProviderConfig) GetGlobalRandomToken() string {
apiTokens, _, err := getApiTokens(c.failover.ctxApiTokens)
unavailableApiTokens, _, err := getApiTokens(c.failover.ctxUnavailableApiTokens)
log.Debugf("apiTokens: %v, unavailableApiTokens: %v", apiTokens, unavailableApiTokens)
@@ -550,7 +551,7 @@ func (c *ProviderConfig) GetAvailableApiToken(ctx wrapper.HttpContext) []string
}
// SetAvailableApiTokens set available apiTokens of current request in the context, will be used in the retryOnFailure
func (c *ProviderConfig) SetAvailableApiTokens(ctx wrapper.HttpContext, log wrapper.Log) {
func (c *ProviderConfig) SetAvailableApiTokens(ctx wrapper.HttpContext) {
var apiTokens []string
if c.isFailoverEnabled() {
apiTokens, _, _ = getApiTokens(c.failover.ctxApiTokens)
@@ -572,14 +573,14 @@ func (c *ProviderConfig) resetSharedData() {
_ = proxywasm.SetSharedData(c.failover.ctxApiTokenRequestFailureCount, nil, 0)
}
func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, apiTokens []string, status string, log wrapper.Log) types.Action {
func (c *ProviderConfig) OnRequestFailed(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, apiTokens []string, status string) types.Action {
if c.isFailoverEnabled() && util.MatchStatus(status, c.failover.failoverOnStatus) {
log.Warnf("apiToken:%s need failover, error status:%s", apiTokenInUse, status)
c.handleUnavailableApiToken(ctx, apiTokenInUse, log)
c.handleUnavailableApiToken(ctx, apiTokenInUse)
}
if c.IsRetryOnFailureEnabled() && util.MatchStatus(status, c.retryOnFailure.retryOnStatus) {
log.Warnf("need retry, notice that retry response will be bufferd, error status:%s", status)
err := c.retryFailedRequest(activeProvider, ctx, apiTokenInUse, apiTokens, log)
err := c.retryFailedRequest(activeProvider, ctx, apiTokenInUse, apiTokens)
if err != nil {
log.Errorf("retryFailedRequest failed, err:%v", err)
return types.ActionContinue
@@ -598,11 +599,11 @@ func (c *ProviderConfig) GetApiTokenInUse(ctx wrapper.HttpContext) string {
return token
}
func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.Log) {
func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext) {
var apiToken string
// if enable apiToken failover, only use available apiToken from global apiTokens list
if c.isFailoverEnabled() {
apiToken = c.GetGlobalRandomToken(log)
apiToken = c.GetGlobalRandomToken()
} else {
apiToken = c.GetRandomToken()
}
@@ -610,7 +611,7 @@ func (c *ProviderConfig) SetApiTokenInUse(ctx wrapper.HttpContext, log wrapper.L
ctx.SetContext(c.failover.ctxApiTokenInUse, apiToken)
}
func (c *ProviderConfig) setHealthCheckEndpoint(ctx wrapper.HttpContext, log wrapper.Log) {
func (c *ProviderConfig) setHealthCheckEndpoint(ctx wrapper.HttpContext) {
cluster, err := proxywasm.GetProperty([]string{"cluster_name"})
if err != nil {
log.Errorf("Failed to get cluster_name: %v", err)

View File

@@ -9,9 +9,9 @@ import (
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/google/uuid"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -56,35 +56,35 @@ func (g *geminiProvider) GetProviderType() string {
return providerTypeGemini
}
func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
g.config.handleRequestHeaders(g, ctx, apiName, log)
func (g *geminiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
g.config.handleRequestHeaders(g, ctx, apiName)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return nil
}
func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (g *geminiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestHostHeader(headers, geminiDomain)
headers.Set(geminiApiKeyHeader, g.config.GetApiTokenInUse(ctx))
}
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (g *geminiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !g.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body)
}
func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
func (g *geminiProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
if apiName == ApiNameChatCompletion {
return g.onChatCompletionRequestBody(ctx, body, headers, log)
return g.onChatCompletionRequestBody(ctx, body, headers)
} else {
return g.onEmbeddingsRequestBody(ctx, body, headers, log)
return g.onEmbeddingsRequestBody(ctx, body, headers)
}
}
func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
request := &chatCompletionRequest{}
err := g.config.parseRequestAndMapModel(ctx, request, body, log)
err := g.config.parseRequestAndMapModel(ctx, request, body)
if err != nil {
return nil, err
}
@@ -95,9 +95,9 @@ func (g *geminiProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, bo
return json.Marshal(geminiRequest)
}
func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
request := &embeddingsRequest{}
if err := g.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
if err := g.config.parseRequestAndMapModel(ctx, request, body); err != nil {
return nil, err
}
path := g.getRequestPath(ApiNameEmbeddings, request.Model, false)
@@ -107,7 +107,7 @@ func (g *geminiProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body [
return json.Marshal(geminiRequest)
}
func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
log.Infof("chunk body:%s", string(chunk))
if isLastChunk || len(chunk) == 0 {
return nil, nil
@@ -143,15 +143,15 @@ func (g *geminiProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name A
return []byte(modifiedResponseChunk), nil
}
func (g *geminiProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
func (g *geminiProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
if apiName == ApiNameChatCompletion {
return g.onChatCompletionResponseBody(ctx, body, log)
return g.onChatCompletionResponseBody(ctx, body)
} else {
return g.onEmbeddingsResponseBody(ctx, body, log)
return g.onEmbeddingsResponseBody(ctx, body)
}
}
func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
geminiResponse := &geminiChatResponse{}
if err := json.Unmarshal(body, geminiResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal gemini chat response: %v", err)
@@ -164,7 +164,7 @@ func (g *geminiProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, b
return json.Marshal(response)
}
func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
func (g *geminiProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
geminiResponse := &geminiEmbeddingResponse{}
if err := json.Unmarshal(body, geminiResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal gemini embeddings response: %v", err)
@@ -434,7 +434,7 @@ func (g *geminiProvider) buildToolCalls(candidate *geminiChatCandidate) []toolCa
}
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
if err != nil {
proxywasm.LogErrorf("get toolCalls from gemini response failed: " + err.Error())
log.Errorf("get toolCalls from gemini response failed: " + err.Error())
return toolCalls
}
toolCall := toolCall{

View File

@@ -51,20 +51,20 @@ func (m *githubProvider) GetProviderType() string {
return providerTypeGithub
}
func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log)
func (m *githubProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
m.config.handleRequestHeaders(m, ctx, apiName)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return nil
}
func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (m *githubProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
}
func (m *githubProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (m *githubProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestHostHeader(headers, githubDomain)
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestAuthorizationHeader(headers, m.config.GetApiTokenInUse(ctx))

View File

@@ -48,19 +48,19 @@ func (g *groqProvider) GetProviderType() string {
return providerTypeGroq
}
func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
g.config.handleRequestHeaders(g, ctx, apiName, log)
func (g *groqProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
g.config.handleRequestHeaders(g, ctx, apiName)
return nil
}
func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (g *groqProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !g.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body, log)
return g.config.handleRequestBody(g, g.contextCache, ctx, apiName, body)
}
func (g *groqProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (g *groqProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), g.config.capabilities)
util.OverwriteRequestHostHeader(headers, groqDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+g.config.GetApiTokenInUse(ctx))

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
@@ -135,13 +136,13 @@ func (m *hunyuanProvider) useOpenAICompatibleAPI() bool {
return len(m.config.hunyuanAuthId) == 0 && len(m.config.hunyuanAuthKey) == 0
}
func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log)
func (m *hunyuanProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
m.config.handleRequestHeaders(m, ctx, apiName)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return nil
}
func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
if m.useOpenAICompatibleAPI() {
util.OverwriteRequestHostHeader(headers, hunyuanOpenAiDomain)
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
@@ -156,7 +157,7 @@ func (m *hunyuanProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNa
}
// hunyuan 的 OnRequestBody 逻辑中包含了对 headers 签名的逻辑,并且插入 context 以后还要重新计算签名,因此无法复用 handleRequestBody 方法
func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
@@ -185,7 +186,7 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
// 若无配置文件,直接返回
if m.config.context == nil {
return types.ActionContinue, replaceJsonRequestBody(request, log)
return types.ActionContinue, replaceJsonRequestBody(request)
}
err := m.contextCache.GetContent(func(content string, err error) {
log.Debugf("#debug nash5# ctx file loaded! callback start, content is: %s", content)
@@ -204,17 +205,17 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
authorizedValueNew := GetTC3Authorizationcode(m.config.hunyuanAuthId, m.config.hunyuanAuthKey, timestamp, hunyuanDomain, hunyuanChatCompletionTCAction, string(hunyuanBody))
_ = util.OverwriteRequestAuthorization(authorizedValueNew)
if err := replaceJsonRequestBody(request, log); err != nil {
if err := replaceJsonRequestBody(request); err != nil {
util.ErrorHandler("ai-proxy.hunyuan.insert_ctx_failed", fmt.Errorf("failed to replace request body: %v", err))
}
}, log)
})
if err == nil {
log.Debugf("#debug nash5# ctx file load success!")
return types.ActionPause, nil
}
log.Debugf("#debug nash5# ctx file load failed!")
return types.ActionContinue, replaceJsonRequestBody(request, log)
return types.ActionContinue, replaceJsonRequestBody(request)
}
// 使用open ai接口协议
@@ -228,7 +229,7 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
return types.ActionContinue, errors.New("missing model in chat completion request")
}
ctx.SetContext(ctxKeyOriginalRequestModel, model) // 设置原始请求的model以便返回值使用
mappedModel := getMappedModel(model, m.config.modelMapping, log)
mappedModel := getMappedModel(model, m.config.modelMapping)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
@@ -258,7 +259,7 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
string(body),
)
_ = util.OverwriteRequestAuthorization(authorizedValueNew)
return types.ActionContinue, replaceJsonRequestBody(hunyuanRequest, log)
return types.ActionContinue, replaceJsonRequestBody(hunyuanRequest)
}
err := m.contextCache.GetContent(func(content string, err error) {
@@ -278,10 +279,10 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
authorizedValueNew := GetTC3Authorizationcode(m.config.hunyuanAuthId, m.config.hunyuanAuthKey, timestamp, hunyuanDomain, hunyuanChatCompletionTCAction, string(hunyuanBody))
_ = util.OverwriteRequestAuthorization(authorizedValueNew)
if err := replaceJsonRequestBody(hunyuanRequest, log); err != nil {
if err := replaceJsonRequestBody(hunyuanRequest); err != nil {
util.ErrorHandler("ai-proxy.hunyuan.insert_ctx_failed", fmt.Errorf("failed to replace request body: %v", err))
}
}, log)
})
if err == nil {
return types.ActionPause, nil
}
@@ -289,12 +290,12 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
}
// hunyuan 的 TransformRequestBodyHeaders 方法只在 failover 健康检查的时候会调用
func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
if m.useOpenAICompatibleAPI() {
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
return m.config.defaultTransformRequestBody(ctx, apiName, body)
}
request := &chatCompletionRequest{}
err := m.config.parseRequestAndMapModel(ctx, request, body, log)
err := m.config.parseRequestAndMapModel(ctx, request, body)
if err != nil {
return nil, err
}
@@ -317,7 +318,7 @@ func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, a
return json.Marshal(hunyuanRequest)
}
func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
if m.config.IsOriginal() || m.useOpenAICompatibleAPI() || name != ApiNameChatCompletion {
return chunk, nil
}
@@ -364,7 +365,7 @@ func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
newBufferedBody = newBufferedBody[newEventPivot+2:] // 跳过结束标识
// 转换并追加到输出缓冲区
convertedData, _ := m.convertChunkFromHunyuanToOpenAI(ctx, eventData, log)
convertedData, _ := m.convertChunkFromHunyuanToOpenAI(ctx, eventData)
// log.Debugf("@@@ >>> converted one chunk: %s", string(convertedData))
outputBuffer = append(outputBuffer, convertedData...)
}
@@ -376,7 +377,7 @@ func (m *hunyuanProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
return outputBuffer, nil
}
func (m *hunyuanProvider) convertChunkFromHunyuanToOpenAI(ctx wrapper.HttpContext, hunyuanChunk []byte, log wrapper.Log) ([]byte, error) {
func (m *hunyuanProvider) convertChunkFromHunyuanToOpenAI(ctx wrapper.HttpContext, hunyuanChunk []byte) ([]byte, error) {
// 将hunyuan的chunk转为openai的chunk
hunyuanFormattedChunk := &hunyuanTextGenDetailedResponseNonStreaming{}
if err := json.Unmarshal(hunyuanChunk, hunyuanFormattedChunk); err != nil {
@@ -433,7 +434,7 @@ func (m *hunyuanProvider) convertChunkFromHunyuanToOpenAI(ctx wrapper.HttpContex
return []byte(openAIChunk.String()), nil
}
func (m *hunyuanProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
func (m *hunyuanProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
if m.config.IsOriginal() || m.useOpenAICompatibleAPI() {
return body, nil
}

View File

@@ -8,6 +8,7 @@ import (
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
@@ -73,45 +74,45 @@ func (m *minimaxProvider) GetProviderType() string {
return providerTypeMinimax
}
func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log)
func (m *minimaxProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
m.config.handleRequestHeaders(m, ctx, apiName)
// Delay the header processing to allow changing streaming mode in OnRequestBody
return nil
}
func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (m *minimaxProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestHostHeader(headers, minimaxDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
}
func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (m *minimaxProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
if minimaxApiTypePro == m.config.minimaxApiType {
// Use chat completion Pro API.
return m.handleRequestBodyByChatCompletionPro(body, log)
return m.handleRequestBodyByChatCompletionPro(body)
} else {
// Use chat completion V2 API.
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
}
}
// handleRequestBodyByChatCompletionPro processes the request body using the chat completion Pro API.
func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log wrapper.Log) (types.Action, error) {
func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte) (types.Action, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
// Map the model and rewrite the request path.
request.Model = getMappedModel(request.Model, m.config.modelMapping, log)
request.Model = getMappedModel(request.Model, m.config.modelMapping)
_ = util.OverwriteRequestPath(fmt.Sprintf("%s?GroupId=%s", minimaxChatCompletionProPath, m.config.minimaxGroupId))
if m.config.context == nil {
minimaxRequest := m.buildMinimaxChatCompletionProRequest(request, "")
return types.ActionContinue, replaceJsonRequestBody(minimaxRequest, log)
return types.ActionContinue, replaceJsonRequestBody(minimaxRequest)
}
err := m.contextCache.GetContent(func(content string, err error) {
@@ -126,30 +127,30 @@ func (m *minimaxProvider) handleRequestBodyByChatCompletionPro(body []byte, log
// For minimaxChatCompletionPro, we need to manually handle context messages.
// minimaxChatCompletionV2 uses the default defaultInsertHttpContextMessage method to insert context messages.
minimaxRequest := m.buildMinimaxChatCompletionProRequest(request, content)
if err := replaceJsonRequestBody(minimaxRequest, log); err != nil {
if err := replaceJsonRequestBody(minimaxRequest); err != nil {
util.ErrorHandler("ai-proxy.minimax.insert_ctx_failed", fmt.Errorf("failed to replace Request body: %v", err))
}
}, log)
})
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}
func (m *minimaxProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
return m.handleRequestBodyByChatCompletionV2(body, headers, log)
func (m *minimaxProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
return m.handleRequestBodyByChatCompletionV2(body, headers)
}
// handleRequestBodyByChatCompletionV2 processes the request body using the chat completion V2 API.
func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
func (m *minimaxProvider) handleRequestBodyByChatCompletionV2(body []byte, headers http.Header) ([]byte, error) {
util.OverwriteRequestPathHeader(headers, minimaxChatCompletionV2Path)
rawModel := gjson.GetBytes(body, "model").String()
mappedModel := getMappedModel(rawModel, m.config.modelMapping, log)
mappedModel := getMappedModel(rawModel, m.config.modelMapping)
return sjson.SetBytes(body, "model", mappedModel)
}
func (m *minimaxProvider) TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (m *minimaxProvider) TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
// Skip OnStreamingResponseBody() and OnResponseBody() when using the original protocol
// or when the model corresponds to the chat completion V2 interface.
if m.config.protocol == protocolOriginal || minimaxApiTypePro != m.config.minimaxApiType {
@@ -160,7 +161,7 @@ func (m *minimaxProvider) TransformResponseHeaders(ctx wrapper.HttpContext, apiN
}
// OnStreamingResponseBody handles streaming response chunks from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API.
func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
if isLastChunk || len(chunk) == 0 {
return nil, nil
}
@@ -199,7 +200,7 @@ func (m *minimaxProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name
}
// TransformResponseBody handles the final response body from the Minimax service only for requests using the OpenAI protocol and corresponding to the chat completion Pro API.
func (m *minimaxProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
func (m *minimaxProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return body, nil
}

View File

@@ -47,19 +47,19 @@ func (m *mistralProvider) GetProviderType() string {
return providerTypeMistral
}
func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log)
func (m *mistralProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
m.config.handleRequestHeaders(m, ctx, apiName)
return nil
}
func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (m *mistralProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
}
func (m *mistralProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (m *mistralProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestHostHeader(headers, mistralDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")

View File

@@ -4,8 +4,8 @@ import (
"fmt"
"strings"
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
)
const (
@@ -172,7 +172,7 @@ func (m *chatMessage) handleStreamingReasoningContent(ctx wrapper.HttpContext, r
if contentPushed {
if m.ReasoningContent != "" {
// This shouldn't happen, but if it does, we can add a log here.
proxywasm.LogWarnf("[ai-proxy] Content already pushed, but reasoning content is not empty: %v", m)
log.Warnf("[ai-proxy] Content already pushed, but reasoning content is not empty: %v", m)
}
return
}

View File

@@ -7,6 +7,7 @@ import (
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
@@ -63,12 +64,12 @@ func (m *moonshotProvider) GetProviderType() string {
return providerTypeMoonshot
}
func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log)
func (m *moonshotProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
m.config.handleRequestHeaders(m, ctx, apiName)
return nil
}
func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, moonshotDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
@@ -77,7 +78,7 @@ func (m *moonshotProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiN
// moonshot 有自己获取 context 的配置moonshotFileId因此无法复用 handleRequestBody 方法
// moonshot 的 body 没有修改无须实现TransformRequestBody使用默认的 defaultTransformRequestBody 方法
func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
@@ -87,12 +88,12 @@ func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam
}
request := &chatCompletionRequest{}
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
if err := m.config.parseRequestAndMapModel(ctx, request, body); err != nil {
return types.ActionContinue, err
}
if m.config.moonshotFileId == "" && m.contextCache == nil {
return types.ActionContinue, replaceJsonRequestBody(request, log)
return types.ActionContinue, replaceJsonRequestBody(request)
}
apiKey := m.config.GetOrSetTokenWithContext(ctx)
@@ -105,23 +106,23 @@ func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam
_ = util.ErrorHandler("ai-proxy.moonshot.load_ctx_failed", fmt.Errorf("failed to load context file: %v", err))
return
}
err = m.performChatCompletion(ctx, content, request, log)
err = m.performChatCompletion(ctx, content, request)
if err != nil {
_ = util.ErrorHandler("ai-proxy.moonshot.insert_ctx_failed", fmt.Errorf("failed to perform chat completion: %v", err))
}
}, log)
})
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}
func (m *moonshotProvider) performChatCompletion(ctx wrapper.HttpContext, fileContent string, request *chatCompletionRequest, log wrapper.Log) error {
func (m *moonshotProvider) performChatCompletion(ctx wrapper.HttpContext, fileContent string, request *chatCompletionRequest) error {
insertContextMessage(request, fileContent)
return replaceJsonRequestBody(request, log)
return replaceJsonRequestBody(request)
}
func (m *moonshotProvider) getContextContent(apiKey string, callback func(string, error), log wrapper.Log) error {
func (m *moonshotProvider) getContextContent(apiKey string, callback func(string, error)) error {
if m.config.moonshotFileId != "" {
if m.fileContent != "" {
callback(m.fileContent, nil)
@@ -142,7 +143,7 @@ func (m *moonshotProvider) getContextContent(apiKey string, callback func(string
}
if m.contextCache != nil {
return m.contextCache.GetContent(callback, log)
return m.contextCache.GetContent(callback)
}
return errors.New("both moonshotFileId and context are not configured")
@@ -161,7 +162,7 @@ func (m *moonshotProvider) sendRequest(method, path, body, apiKey string, callba
}
}
func (m *moonshotProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent, log wrapper.Log) ([]StreamEvent, error) {
func (m *moonshotProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent) ([]StreamEvent, error) {
if name != ApiNameChatCompletion {
return nil, nil
}

View File

@@ -54,19 +54,19 @@ func (m *ollamaProvider) GetProviderType() string {
return providerTypeOllama
}
func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log)
func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
m.config.handleRequestHeaders(m, ctx, apiName)
return nil
}
func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, nil
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
}
func (m *ollamaProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (m *ollamaProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, m.serviceDomain)
headers.Del("Content-Length")

View File

@@ -7,8 +7,8 @@ import (
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -69,7 +69,7 @@ func (m *openaiProviderInitializer) CreateProvider(config ProviderConfig) (Provi
}
}
config.setDefaultCapabilities(capabilities)
proxywasm.LogDebugf("ai-proxy: openai provider customDomain:%s, customPath:%s, isDirectCustomPath:%v, capabilities:%v",
log.Debugf("ai-proxy: openai provider customDomain:%s, customPath:%s, isDirectCustomPath:%v, capabilities:%v",
pairs[0], customPath, isDirectCustomPath, capabilities)
return &openaiProvider{
config: config,
@@ -92,12 +92,12 @@ func (m *openaiProvider) GetProviderType() string {
return providerTypeOpenAI
}
func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log)
func (m *openaiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
m.config.handleRequestHeaders(m, ctx, apiName)
return nil
}
func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
if m.customPath != "" {
if m.isDirectCustomPath || apiName == "" {
util.OverwriteRequestPathHeader(headers, m.customPath)
@@ -118,15 +118,15 @@ func (m *openaiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
headers.Del("Content-Length")
}
func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if apiName != ApiNameChatCompletion {
// We don't need to process the request body for other APIs.
return types.ActionContinue, nil
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
}
func (m *openaiProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
func (m *openaiProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
if m.config.responseJsonSchema != nil {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
@@ -136,5 +136,5 @@ func (m *openaiProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName A
request.ResponseFormat = m.config.responseJsonSchema
body, _ = json.Marshal(request)
}
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
return m.config.defaultTransformRequestBody(ctx, apiName, body)
}

View File

@@ -8,6 +8,7 @@ import (
"strings"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
@@ -145,19 +146,19 @@ type Provider interface {
}
type RequestHeadersHandler interface {
OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error
OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error
}
type RequestBodyHandler interface {
OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error)
OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error)
}
type StreamingResponseBodyHandler interface {
OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error)
OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error)
}
type StreamingEventHandler interface {
OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent, log wrapper.Log) ([]StreamEvent, error)
OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent) ([]StreamEvent, error)
}
type ApiNameHandler interface {
@@ -165,25 +166,25 @@ type ApiNameHandler interface {
}
type TransformRequestHeadersHandler interface {
TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header)
}
type TransformRequestBodyHandler interface {
TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error)
}
// TransformRequestBodyHeadersHandler allows to transform request headers based on the request body.
// Some providers (e.g. gemini) transform request headers (e.g., path) based on the request body (e.g., model).
type TransformRequestBodyHeadersHandler interface {
TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error)
TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error)
}
type TransformResponseHeadersHandler interface {
TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log)
TransformResponseHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header)
}
type TransformResponseBodyHandler interface {
TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error)
TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error)
}
type ProviderConfig struct {
@@ -496,7 +497,7 @@ func CreateProvider(pc ProviderConfig) (Provider, error) {
return initializer.CreateProvider(pc)
}
func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, request interface{}, body []byte, log wrapper.Log) error {
func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, request interface{}, body []byte) error {
switch req := request.(type) {
case *chatCompletionRequest:
if err := decodeChatCompletionRequest(body, req); err != nil {
@@ -511,18 +512,18 @@ func (c *ProviderConfig) parseRequestAndMapModel(ctx wrapper.HttpContext, reques
ctx.SetContext(ctxKeyIsStreaming, false)
}
return c.setRequestModel(ctx, req, log)
return c.setRequestModel(ctx, req)
case *embeddingsRequest:
if err := decodeEmbeddingsRequest(body, req); err != nil {
return err
}
return c.setRequestModel(ctx, req, log)
return c.setRequestModel(ctx, req)
default:
return errors.New("unsupported request type")
}
}
func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interface{}, log wrapper.Log) error {
func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interface{}) error {
var model *string
switch req := request.(type) {
@@ -534,16 +535,16 @@ func (c *ProviderConfig) setRequestModel(ctx wrapper.HttpContext, request interf
return errors.New("unsupported request type")
}
return c.mapModel(ctx, model, log)
return c.mapModel(ctx, model)
}
func (c *ProviderConfig) mapModel(ctx wrapper.HttpContext, model *string, log wrapper.Log) error {
func (c *ProviderConfig) mapModel(ctx wrapper.HttpContext, model *string) error {
if *model == "" {
return errors.New("missing model in request")
}
ctx.SetContext(ctxKeyOriginalRequestModel, *model)
mappedModel := getMappedModel(*model, c.modelMapping, log)
mappedModel := getMappedModel(*model, c.modelMapping)
if mappedModel == "" {
return errors.New("model becomes empty after applying the configured mapping")
}
@@ -553,15 +554,15 @@ func (c *ProviderConfig) mapModel(ctx wrapper.HttpContext, model *string, log wr
return nil
}
func getMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string {
mappedModel := doGetMappedModel(model, modelMapping, log)
func getMappedModel(model string, modelMapping map[string]string) string {
mappedModel := doGetMappedModel(model, modelMapping)
if len(mappedModel) != 0 {
return mappedModel
}
return model
}
func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string {
func doGetMappedModel(model string, modelMapping map[string]string) string {
if len(modelMapping) == 0 {
return ""
}
@@ -590,7 +591,7 @@ func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.
return ""
}
func ExtractStreamingEvents(ctx wrapper.HttpContext, chunk []byte, log wrapper.Log) []StreamEvent {
func ExtractStreamingEvents(ctx wrapper.HttpContext, chunk []byte) []StreamEvent {
body := chunk
if bufferedStreamingBody, has := ctx.GetContext(ctxKeyStreamingBody).([]byte); has {
body = append(bufferedStreamingBody, chunk...)
@@ -679,8 +680,7 @@ func (c *ProviderConfig) setDefaultCapabilities(capabilities map[string]string)
}
func (c *ProviderConfig) handleRequestBody(
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log,
) (types.Action, error) {
provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
// use original protocol
if c.IsOriginal() {
return types.ActionContinue, nil
@@ -689,13 +689,13 @@ func (c *ProviderConfig) handleRequestBody(
// use openai protocol
var err error
if handler, ok := provider.(TransformRequestBodyHandler); ok {
body, err = handler.TransformRequestBody(ctx, apiName, body, log)
body, err = handler.TransformRequestBody(ctx, apiName, body)
} else if handler, ok := provider.(TransformRequestBodyHeadersHandler); ok {
headers := util.GetOriginalRequestHeaders()
body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers, log)
body, err = handler.TransformRequestBodyHeaders(ctx, apiName, body, headers)
util.ReplaceRequestHeaders(headers)
} else {
body, err = c.defaultTransformRequestBody(ctx, apiName, body, log)
body, err = c.defaultTransformRequestBody(ctx, apiName, body)
}
if err != nil {
@@ -704,28 +704,28 @@ func (c *ProviderConfig) handleRequestBody(
if apiName == ApiNameChatCompletion {
if c.context == nil {
return types.ActionContinue, replaceRequestBody(body, log)
return types.ActionContinue, replaceRequestBody(body)
}
err = contextCache.GetContextFromFile(ctx, provider, body, log)
err = contextCache.GetContextFromFile(ctx, provider, body)
if err == nil {
return types.ActionPause, nil
}
return types.ActionContinue, err
}
return types.ActionContinue, replaceRequestBody(body, log)
return types.ActionContinue, replaceRequestBody(body)
}
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) {
func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.HttpContext, apiName ApiName) {
headers := util.GetOriginalRequestHeaders()
if handler, ok := provider.(TransformRequestHeadersHandler); ok {
handler.TransformRequestHeaders(ctx, apiName, headers, log)
handler.TransformRequestHeaders(ctx, apiName, headers)
util.ReplaceRequestHeaders(headers)
}
}
// defaultTransformRequestBody 默认的请求体转换方法只做模型映射用slog替换模型名称不用序列化和反序列化提高性能
func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
switch apiName {
case ApiNameChatCompletion:
stream := gjson.GetBytes(body, "stream").Bool()
@@ -738,7 +738,7 @@ func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, ap
}
model := gjson.GetBytes(body, "model").String()
ctx.SetContext(ctxKeyOriginalRequestModel, model)
return sjson.SetBytes(body, "model", getMappedModel(model, c.modelMapping, log))
return sjson.SetBytes(body, "model", getMappedModel(model, c.modelMapping))
}
func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) {

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
@@ -79,7 +80,7 @@ type qwenProvider struct {
contextCache *contextCache
}
func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
if m.config.qwenDomain != "" {
util.OverwriteRequestHostHeader(headers, m.config.qwenDomain)
} else {
@@ -92,11 +93,11 @@ func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
}
}
func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header) ([]byte, error) {
if m.config.qwenEnableCompatible {
if gjson.GetBytes(body, "model").Exists() {
rawModel := gjson.GetBytes(body, "model").String()
mappedModel := getMappedModel(rawModel, m.config.modelMapping, log)
mappedModel := getMappedModel(rawModel, m.config.modelMapping)
newBody, err := sjson.SetBytes(body, "model", mappedModel)
if err != nil {
log.Errorf("Replace model error: %v", err)
@@ -108,11 +109,11 @@ func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiN
}
switch apiName {
case ApiNameChatCompletion:
return m.onChatCompletionRequestBody(ctx, body, headers, log)
return m.onChatCompletionRequestBody(ctx, body, headers)
case ApiNameEmbeddings:
return m.onEmbeddingsRequestBody(ctx, body, log)
return m.onEmbeddingsRequestBody(ctx, body)
default:
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
return m.config.defaultTransformRequestBody(ctx, apiName, body)
}
}
@@ -120,8 +121,8 @@ func (m *qwenProvider) GetProviderType() string {
return providerTypeQwen
}
func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log)
func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
m.config.handleRequestHeaders(m, ctx, apiName)
if m.config.protocol == protocolOriginal {
ctx.DontReadRequestBody()
@@ -131,16 +132,16 @@ func (m *qwenProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName
return nil
}
func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (m *qwenProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
}
func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body []byte, headers http.Header) ([]byte, error) {
request := &chatCompletionRequest{}
err := m.config.parseRequestAndMapModel(ctx, request, body, log)
err := m.config.parseRequestAndMapModel(ctx, request, body)
if err != nil {
return nil, err
}
@@ -162,9 +163,9 @@ func (m *qwenProvider) onChatCompletionRequestBody(ctx wrapper.HttpContext, body
return m.buildQwenTextGenerationRequest(ctx, request, streaming)
}
func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
request := &embeddingsRequest{}
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
if err := m.config.parseRequestAndMapModel(ctx, request, body); err != nil {
return nil, err
}
@@ -175,7 +176,7 @@ func (m *qwenProvider) onEmbeddingsRequestBody(ctx wrapper.HttpContext, body []b
return json.Marshal(qwenRequest)
}
func (m *qwenProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent, log wrapper.Log) ([]StreamEvent, error) {
func (m *qwenProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, event StreamEvent) ([]StreamEvent, error) {
if m.config.qwenEnableCompatible || name != ApiNameChatCompletion {
return nil, nil
}
@@ -189,7 +190,7 @@ func (m *qwenProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, e
}
var outputEvents []StreamEvent
responses := m.buildChatCompletionStreamingResponse(ctx, qwenResponse, incrementalStreaming, log)
responses := m.buildChatCompletionStreamingResponse(ctx, qwenResponse, incrementalStreaming)
for _, response := range responses {
responseBody, err := json.Marshal(response)
if err != nil {
@@ -203,15 +204,15 @@ func (m *qwenProvider) OnStreamingEvent(ctx wrapper.HttpContext, name ApiName, e
return outputEvents, nil
}
func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
if m.config.qwenEnableCompatible {
return body, nil
}
if apiName == ApiNameChatCompletion {
return m.onChatCompletionResponseBody(ctx, body, log)
return m.onChatCompletionResponseBody(ctx, body)
}
if apiName == ApiNameEmbeddings {
return m.onEmbeddingsResponseBody(ctx, body, log)
return m.onEmbeddingsResponseBody(ctx, body)
}
if m.config.isSupportedAPI(apiName) {
return body, nil
@@ -219,7 +220,7 @@ func (m *qwenProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName Ap
return nil, errUnsupportedApiName
}
func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
qwenResponse := &qwenTextGenResponse{}
if err := json.Unmarshal(body, qwenResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal Qwen response: %v", err)
@@ -228,7 +229,7 @@ func (m *qwenProvider) onChatCompletionResponseBody(ctx wrapper.HttpContext, bod
return json.Marshal(response)
}
func (m *qwenProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte, log wrapper.Log) ([]byte, error) {
func (m *qwenProvider) onEmbeddingsResponseBody(ctx wrapper.HttpContext, body []byte) ([]byte, error) {
qwenResponse := &qwenTextEmbeddingResponse{}
if err := json.Unmarshal(body, qwenResponse); err != nil {
return nil, fmt.Errorf("unable to unmarshal Qwen response: %v", err)
@@ -308,7 +309,7 @@ func (m *qwenProvider) buildChatCompletionResponse(ctx wrapper.HttpContext, qwen
}
}
func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse, incrementalStreaming bool, log wrapper.Log) []*chatCompletionResponse {
func (m *qwenProvider) buildChatCompletionStreamingResponse(ctx wrapper.HttpContext, qwenResponse *qwenTextGenResponse, incrementalStreaming bool) []*chatCompletionResponse {
baseMessage := chatCompletionResponse{
Id: qwenResponse.RequestId,
Created: time.Now().UnixMilli() / 1000,

View File

@@ -3,7 +3,8 @@ package provider
import (
"encoding/json"
"fmt"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
)
@@ -24,7 +25,7 @@ func decodeEmbeddingsRequest(body []byte, request *embeddingsRequest) error {
return nil
}
func replaceJsonRequestBody(request interface{}, log wrapper.Log) error {
func replaceJsonRequestBody(request interface{}) error {
body, err := json.Marshal(request)
if err != nil {
return fmt.Errorf("unable to marshal request: %v", err)
@@ -37,7 +38,7 @@ func replaceJsonRequestBody(request interface{}, log wrapper.Log) error {
return err
}
func replaceRequestBody(body []byte, log wrapper.Log) error {
func replaceRequestBody(body []byte) error {
log.Debugf("request body: %s", string(body))
err := proxywasm.ReplaceHttpRequestBody(body)
if err != nil {
@@ -65,7 +66,7 @@ func insertContextMessage(request *chatCompletionRequest, content string) {
}
}
func ReplaceResponseBody(body []byte, log wrapper.Log) error {
func ReplaceResponseBody(body []byte) error {
log.Debugf("response body: %s", string(body))
err := proxywasm.ReplaceHttpResponseBody(body)
if err != nil {

View File

@@ -7,6 +7,7 @@ import (
"net/http"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/tidwall/gjson"
@@ -50,24 +51,24 @@ func (c *ProviderConfig) IsRetryOnFailureEnabled() bool {
return c.retryOnFailure.enabled
}
func (c *ProviderConfig) retryFailedRequest(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, apiTokens []string, log wrapper.Log) error {
func (c *ProviderConfig) retryFailedRequest(activeProvider Provider, ctx wrapper.HttpContext, apiTokenInUse string, apiTokens []string) error {
log.Infof("Retry failed request: provider=%s", activeProvider.GetProviderType())
retryClient := createRetryClient()
apiName, _ := ctx.GetContext(CtxKeyApiName).(ApiName)
ctx.SetContext(ctxRetryCount, 1)
return c.sendRetryRequest(ctx, apiName, activeProvider, retryClient, apiTokenInUse, apiTokens, log)
return c.sendRetryRequest(ctx, apiName, activeProvider, retryClient, apiTokenInUse, apiTokens)
}
func (c *ProviderConfig) transformResponseHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, apiName ApiName, headers http.Header, body []byte, log wrapper.Log) ([][2]string, []byte) {
func (c *ProviderConfig) transformResponseHeadersAndBody(ctx wrapper.HttpContext, activeProvider Provider, apiName ApiName, headers http.Header, body []byte) ([][2]string, []byte) {
if handler, ok := activeProvider.(TransformResponseHeadersHandler); ok {
handler.TransformResponseHeaders(ctx, apiName, headers, log)
handler.TransformResponseHeaders(ctx, apiName, headers)
} else {
c.DefaultTransformResponseHeaders(ctx, headers)
}
if handler, ok := activeProvider.(TransformResponseBodyHandler); ok {
var err error
body, err = handler.TransformResponseBody(ctx, apiName, body, log)
body, err = handler.TransformResponseBody(ctx, apiName, body)
if err != nil {
log.Errorf("Failed to transform response body: %v", err)
}
@@ -77,7 +78,7 @@ func (c *ProviderConfig) transformResponseHeadersAndBody(ctx wrapper.HttpContext
}
func (c *ProviderConfig) retryCall(
ctx wrapper.HttpContext, log wrapper.Log, activeProvider Provider,
ctx wrapper.HttpContext, activeProvider Provider,
apiName ApiName, statusCode int, responseHeaders http.Header, responseBody []byte,
retryClient *wrapper.ClusterClient[wrapper.RouteCluster],
apiTokenInUse string, apiTokens []string) {
@@ -87,7 +88,7 @@ func (c *ProviderConfig) retryCall(
if statusCode == 200 {
log.Infof("Retry request succeeded")
headers, body := c.transformResponseHeadersAndBody(ctx, activeProvider, apiName, responseHeaders, responseBody, log)
headers, body := c.transformResponseHeadersAndBody(ctx, activeProvider, apiName, responseHeaders, responseBody)
proxywasm.SendHttpResponse(200, headers, body, -1)
return
} else {
@@ -97,7 +98,7 @@ func (c *ProviderConfig) retryCall(
retryCount++
if retryCount <= int(c.retryOnFailure.maxRetries) {
ctx.SetContext(ctxRetryCount, retryCount)
err := c.sendRetryRequest(ctx, apiName, activeProvider, retryClient, apiTokenInUse, apiTokens, log)
err := c.sendRetryRequest(ctx, apiName, activeProvider, retryClient, apiTokenInUse, apiTokens)
if err != nil {
log.Errorf("sendRetryRequest failed, err:%v", err)
proxywasm.ResumeHttpResponse()
@@ -113,10 +114,10 @@ func (c *ProviderConfig) retryCall(
func (c *ProviderConfig) sendRetryRequest(
ctx wrapper.HttpContext, apiName ApiName, activeProvider Provider,
retryClient *wrapper.ClusterClient[wrapper.RouteCluster],
apiTokenInUse string, apiTokens []string, log wrapper.Log) error {
apiTokenInUse string, apiTokens []string) error {
// Remove last failed token from retry apiTokens list
apiTokens = removeApiTokenFromRetryList(apiTokens, apiTokenInUse, log)
apiTokens = removeApiTokenFromRetryList(apiTokens, apiTokenInUse)
if len(apiTokens) == 0 {
return errors.New("No more apiTokens to retry")
}
@@ -130,14 +131,14 @@ func (c *ProviderConfig) sendRetryRequest(
{"content-type", "application/json"},
{":authority", ctx.GetStringContext(CtxRequestHost, "")},
{":path", ctx.GetStringContext(CtxRequestPath, "")},
}, requestBody, log)
}, requestBody)
if err != nil {
return fmt.Errorf("sendRetryRequest failed to transform request headers and body: %v", err)
}
err = retryClient.Post(generateUrl(modifiedHeaders), util.HeaderToSlice(modifiedHeaders), modifiedBody,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
c.retryCall(ctx, log, activeProvider, apiName, statusCode, responseHeaders, responseBody, retryClient, apiTokenInUse, apiTokens)
c.retryCall(ctx, activeProvider, apiName, statusCode, responseHeaders, responseBody, retryClient, apiTokenInUse, apiTokens)
}, uint32(c.retryOnFailure.retryTimeout))
if err != nil {
return fmt.Errorf("Failed to send retry request: %v", err)
@@ -150,7 +151,7 @@ func createRetryClient() *wrapper.ClusterClient[wrapper.RouteCluster] {
return retryClient
}
func removeApiTokenFromRetryList(apiTokens []string, removedApiToken string, log wrapper.Log) []string {
func removeApiTokenFromRetryList(apiTokens []string, removedApiToken string) []string {
var availableApiTokens []string
for _, s := range apiTokens {
if s != removedApiToken {

View File

@@ -8,6 +8,7 @@ import (
"time"
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)
@@ -73,19 +74,19 @@ func (p *sparkProvider) GetProviderType() string {
return providerTypeSpark
}
func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
p.config.handleRequestHeaders(p, ctx, apiName, log)
func (p *sparkProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
p.config.handleRequestHeaders(p, ctx, apiName)
return nil
}
func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (p *sparkProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !p.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body, log)
return p.config.handleRequestBody(p, p.contextCache, ctx, apiName, body)
}
func (p *sparkProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
func (p *sparkProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) ([]byte, error) {
if apiName != ApiNameChatCompletion {
return body, nil
}
@@ -100,7 +101,7 @@ func (p *sparkProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName A
return json.Marshal(response)
}
func (p *sparkProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool, log wrapper.Log) ([]byte, error) {
func (p *sparkProvider) OnStreamingResponseBody(ctx wrapper.HttpContext, name ApiName, chunk []byte, isLastChunk bool) ([]byte, error) {
if isLastChunk || len(chunk) == 0 {
return nil, nil
}
@@ -177,7 +178,7 @@ func (p *sparkProvider) appendResponse(responseBuilder *strings.Builder, respons
responseBuilder.WriteString(fmt.Sprintf("%s %s\n\n", streamDataItemKey, responseBody))
}
func (p *sparkProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (p *sparkProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), p.config.capabilities)
util.OverwriteRequestHostHeader(headers, sparkHost)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+p.config.GetApiTokenInUse(ctx))

View File

@@ -48,19 +48,19 @@ func (m *stepfunProvider) GetProviderType() string {
return providerTypeStepfun
}
func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log)
func (m *stepfunProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
m.config.handleRequestHeaders(m, ctx, apiName)
return nil
}
func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (m *stepfunProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
}
func (m *stepfunProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (m *stepfunProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, stepfunDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))

View File

@@ -47,19 +47,19 @@ func (m *togetherAIProvider) GetProviderType() string {
return providerTypeTogetherAI
}
func (m *togetherAIProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log)
func (m *togetherAIProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
m.config.handleRequestHeaders(m, ctx, apiName)
return nil
}
func (m *togetherAIProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (m *togetherAIProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
}
func (m *togetherAIProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (m *togetherAIProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, togetherAIDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))

View File

@@ -47,19 +47,19 @@ func (m *yiProvider) GetProviderType() string {
return providerTypeYi
}
func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log)
func (m *yiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
m.config.handleRequestHeaders(m, ctx, apiName)
return nil
}
func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (m *yiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
}
func (m *yiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (m *yiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, yiDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))

View File

@@ -49,19 +49,19 @@ func (m *zhipuAiProvider) GetProviderType() string {
return providerTypeZhipuAi
}
func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) error {
m.config.handleRequestHeaders(m, ctx, apiName, log)
func (m *zhipuAiProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName) error {
m.config.handleRequestHeaders(m, ctx, apiName)
return nil
}
func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
func (m *zhipuAiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte) (types.Action, error) {
if !m.config.isSupportedAPI(apiName) {
return types.ActionContinue, errUnsupportedApiName
}
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body, log)
return m.config.handleRequestBody(m, m.contextCache, ctx, apiName, body)
}
func (m *zhipuAiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
func (m *zhipuAiProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header) {
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, zhipuAiDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))