set include_usage by default for all model providers (#1818)

This commit is contained in:
澄潭
2025-02-26 16:49:16 +08:00
committed by GitHub
parent f6c48415d1
commit 1787553294
3 changed files with 29 additions and 21 deletions

View File

@@ -15,6 +15,7 @@ import (
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
@@ -140,16 +141,14 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
if handler, ok := activeProvider.(provider.RequestBodyHandler); ok {
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
newBody, settingErr := pluginConfig.GetProviderConfig().ReplaceByCustomSettings(body)
providerConfig := pluginConfig.GetProviderConfig()
newBody, settingErr := providerConfig.ReplaceByCustomSettings(body)
if settingErr != nil {
_ = util.ErrorHandler(
"ai-proxy.proc_req_body_failed",
fmt.Errorf("failed to replace request body by custom settings: %v", settingErr),
)
return types.ActionContinue
log.Errorf("failed to replace request body by custom settings: %v", settingErr)
}
if providerConfig.IsOpenAIProtocol() {
newBody = normalizeOpenAiRequestBody(newBody, log)
}
log.Debugf("[onHttpRequestBody] newBody=%s", newBody)
body = newBody
action, err := handler.OnRequestBody(ctx, apiName, body, log)
@@ -297,6 +296,18 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
return types.ActionContinue
}
func normalizeOpenAiRequestBody(body []byte, log wrapper.Log) []byte {
var err error
// Default setting include_usage.
if gjson.GetBytes(body, "stream").Bool() {
body, err = sjson.SetBytes(body, "stream_options.include_usage", true)
if err != nil {
log.Errorf("set include_usage failed, err:%s", err)
}
}
return body
}
func checkStream(ctx wrapper.HttpContext, log wrapper.Log) {
contentType, err := proxywasm.GetHttpResponseHeader("Content-Type")
if err != nil || !strings.HasPrefix(contentType, "text/event-stream") {

View File

@@ -127,21 +127,14 @@ func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
}
func (m *openaiProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return nil, err
}
if m.config.responseJsonSchema != nil {
request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
return nil, err
}
log.Debugf("[ai-proxy] set response format to %s", m.config.responseJsonSchema)
request.ResponseFormat = m.config.responseJsonSchema
body, _ = json.Marshal(request)
}
if request.Stream {
// For stream requests, we need to include usage in the response.
if request.StreamOptions == nil {
request.StreamOptions = &streamOptions{IncludeUsage: true}
} else if !request.StreamOptions.IncludeUsage {
request.StreamOptions.IncludeUsage = true
}
}
return json.Marshal(request)
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
}

View File

@@ -292,6 +292,10 @@ func (c *ProviderConfig) GetProtocol() string {
return c.protocol
}
func (c *ProviderConfig) IsOpenAIProtocol() bool {
return c.protocol == protocolOpenAI
}
func (c *ProviderConfig) FromJson(json gjson.Result) {
c.id = json.Get("id").String()
c.typ = json.Get("type").String()