diff --git a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go index 90a419d46..8aeb72b0b 100644 --- a/plugins/wasm-go/extensions/ai-proxy/provider/provider.go +++ b/plugins/wasm-go/extensions/ai-proxy/provider/provider.go @@ -7,6 +7,8 @@ import ( "net/http" "path" "regexp" + "strconv" + "strings" "github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util" @@ -380,6 +382,9 @@ type ProviderConfig struct { basePath string `required:"false" yaml:"basePath" json:"basePath"` // @Title zh-CN basePathHandling用于指定basePath的处理方式,可选值:removePrefix、prepend basePathHandling basePathHandling `required:"false" yaml:"basePathHandling" json:"basePathHandling"` + // @Title zh-CN 首包超时 + // @Description zh-CN 流式请求中收到上游服务第一个响应包的超时时间,单位为毫秒。默认值为 0,表示不开启首包超时 + firstByteTimeout uint32 `required:"false" yaml:"firstByteTimeout" json:"firstByteTimeout"` } func (c *ProviderConfig) GetId() string { @@ -409,6 +414,8 @@ func (c *ProviderConfig) FromJson(json gjson.Result) { if c.timeout == 0 { c.timeout = defaultTimeout } + // first byte timeout + c.firstByteTimeout = uint32(json.Get("firstByteTimeout").Uint()) c.openaiCustomUrl = json.Get("openaiCustomUrl").String() c.moonshotFileId = json.Get("moonshotFileId").String() c.azureServiceUrl = json.Get("azureServiceUrl").String() @@ -825,6 +832,15 @@ func (c *ProviderConfig) setDefaultCapabilities(capabilities map[string]string) func (c *ProviderConfig) handleRequestBody( provider Provider, contextCache *contextCache, ctx wrapper.HttpContext, apiName ApiName, body []byte, ) (types.Action, error) { + // add the first byte timeout header to the request + if c.firstByteTimeout != 0 && c.isStreamingAPI(apiName, body) { + err := proxywasm.ReplaceHttpRequestHeader("x-envoy-upstream-rq-first-byte-timeout-ms", strconv.FormatUint(uint64(c.firstByteTimeout), 10)) + if err != nil { + log.Errorf("failed to set x-envoy-upstream-rq-first-byte-timeout-ms header: %v", err) + } + log.Debugf("[firstByteTimeout] %d", c.firstByteTimeout) + } + // use original protocol if c.IsOriginal() { return types.ActionContinue, nil @@ -903,6 +919,24 @@ func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext } } +func (c *ProviderConfig) isStreamingAPI(apiName ApiName, body []byte) bool { + stream := false + switch apiName { + case ApiNameCompletion, + ApiNameChatCompletion, + ApiNameImageGeneration, + ApiNameImageEdit, + ApiNameResponses, + ApiNameQwenAsyncAIGC, + ApiNameAnthropicMessages, + ApiNameAnthropicComplete: + stream = gjson.GetBytes(body, "stream").Bool() + case ApiNameGeminiStreamGenerateContent: + stream = true + } + return stream +} + func (c *ProviderConfig) needToProcessRequestBody(apiName ApiName) bool { switch apiName { case ApiNameChatCompletion,