mirror of
https://github.com/alibaba/higress.git
synced 2026-03-12 04:30:49 +08:00
set include_usage by default for all model providers (#1818)
This commit is contained in:
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -140,16 +141,14 @@ func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig
|
||||
|
||||
if handler, ok := activeProvider.(provider.RequestBodyHandler); ok {
|
||||
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
||||
|
||||
newBody, settingErr := pluginConfig.GetProviderConfig().ReplaceByCustomSettings(body)
|
||||
providerConfig := pluginConfig.GetProviderConfig()
|
||||
newBody, settingErr := providerConfig.ReplaceByCustomSettings(body)
|
||||
if settingErr != nil {
|
||||
_ = util.ErrorHandler(
|
||||
"ai-proxy.proc_req_body_failed",
|
||||
fmt.Errorf("failed to replace request body by custom settings: %v", settingErr),
|
||||
)
|
||||
return types.ActionContinue
|
||||
log.Errorf("failed to replace request body by custom settings: %v", settingErr)
|
||||
}
|
||||
if providerConfig.IsOpenAIProtocol() {
|
||||
newBody = normalizeOpenAiRequestBody(newBody, log)
|
||||
}
|
||||
|
||||
log.Debugf("[onHttpRequestBody] newBody=%s", newBody)
|
||||
body = newBody
|
||||
action, err := handler.OnRequestBody(ctx, apiName, body, log)
|
||||
@@ -297,6 +296,18 @@ func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfi
|
||||
return types.ActionContinue
|
||||
}
|
||||
|
||||
func normalizeOpenAiRequestBody(body []byte, log wrapper.Log) []byte {
|
||||
var err error
|
||||
// Default setting include_usage.
|
||||
if gjson.GetBytes(body, "stream").Bool() {
|
||||
body, err = sjson.SetBytes(body, "stream_options.include_usage", true)
|
||||
if err != nil {
|
||||
log.Errorf("set include_usage failed, err:%s", err)
|
||||
}
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func checkStream(ctx wrapper.HttpContext, log wrapper.Log) {
|
||||
contentType, err := proxywasm.GetHttpResponseHeader("Content-Type")
|
||||
if err != nil || !strings.HasPrefix(contentType, "text/event-stream") {
|
||||
|
||||
@@ -127,21 +127,14 @@ func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
|
||||
}
|
||||
|
||||
func (m *openaiProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if m.config.responseJsonSchema != nil {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := decodeChatCompletionRequest(body, request); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Debugf("[ai-proxy] set response format to %s", m.config.responseJsonSchema)
|
||||
request.ResponseFormat = m.config.responseJsonSchema
|
||||
body, _ = json.Marshal(request)
|
||||
}
|
||||
if request.Stream {
|
||||
// For stream requests, we need to include usage in the response.
|
||||
if request.StreamOptions == nil {
|
||||
request.StreamOptions = &streamOptions{IncludeUsage: true}
|
||||
} else if !request.StreamOptions.IncludeUsage {
|
||||
request.StreamOptions.IncludeUsage = true
|
||||
}
|
||||
}
|
||||
return json.Marshal(request)
|
||||
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
|
||||
}
|
||||
|
||||
@@ -292,6 +292,10 @@ func (c *ProviderConfig) GetProtocol() string {
|
||||
return c.protocol
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) IsOpenAIProtocol() bool {
|
||||
return c.protocol == protocolOpenAI
|
||||
}
|
||||
|
||||
func (c *ProviderConfig) FromJson(json gjson.Result) {
|
||||
c.id = json.Get("id").String()
|
||||
c.typ = json.Get("type").String()
|
||||
|
||||
Reference in New Issue
Block a user