mirror of
https://github.com/alibaba/higress.git
synced 2026-03-02 23:51:11 +08:00
277 lines
9.1 KiB
Go
// File generated by hgctl. Modify as required.
|
|
// See: https://higress.io/zh-cn/docs/user/wasm-go#2-%E7%BC%96%E5%86%99-maingo-%E6%96%87%E4%BB%B6
|
|
|
|
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"net/url"
|
|
"strings"
|
|
|
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/config"
|
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/provider"
|
|
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
|
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
|
"github.com/tidwall/gjson"
|
|
)
|
|
|
|
const (
	// pluginName is the identifier handed to wrapper.SetCtx when the plugin is registered.
	pluginName = "ai-proxy"

	// defaultMaxBodyBytes (10 MiB) is the buffer limit applied to request
	// bodies (SetRequestBodyBufferLimit) and to buffered, non-streaming
	// response bodies (SetResponseBodyBufferLimit).
	defaultMaxBodyBytes uint32 = 10 << 20
)
|
|
|
|
func main() {
|
|
wrapper.SetCtx(
|
|
pluginName,
|
|
wrapper.ParseOverrideConfigBy(parseGlobalConfig, parseOverrideRuleConfig),
|
|
wrapper.ProcessRequestHeadersBy(onHttpRequestHeader),
|
|
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
|
|
wrapper.ProcessResponseHeadersBy(onHttpResponseHeaders),
|
|
wrapper.ProcessStreamingResponseBodyBy(onStreamingResponseBody),
|
|
wrapper.ProcessResponseBodyBy(onHttpResponseBody),
|
|
)
|
|
}
|
|
|
|
func parseGlobalConfig(json gjson.Result, pluginConfig *config.PluginConfig, log wrapper.Log) error {
|
|
//log.Debugf("loading global config: %s", json.String())
|
|
|
|
pluginConfig.FromJson(json)
|
|
if err := pluginConfig.Validate(); err != nil {
|
|
return err
|
|
}
|
|
if err := pluginConfig.Complete(log); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func parseOverrideRuleConfig(json gjson.Result, global config.PluginConfig, pluginConfig *config.PluginConfig, log wrapper.Log) error {
|
|
//log.Debugf("loading override rule config: %s", json.String())
|
|
|
|
*pluginConfig = global
|
|
|
|
pluginConfig.FromJson(json)
|
|
if err := pluginConfig.Validate(); err != nil {
|
|
return err
|
|
}
|
|
if err := pluginConfig.Complete(log); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func onHttpRequestHeader(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, log wrapper.Log) types.Action {
|
|
activeProvider := pluginConfig.GetProvider()
|
|
|
|
if activeProvider == nil {
|
|
log.Debugf("[onHttpRequestHeader] no active provider, skip processing")
|
|
ctx.DontReadRequestBody()
|
|
return types.ActionContinue
|
|
}
|
|
|
|
log.Debugf("[onHttpRequestHeader] provider=%s", activeProvider.GetProviderType())
|
|
|
|
rawPath := ctx.Path()
|
|
path, _ := url.Parse(rawPath)
|
|
apiName := getOpenAiApiName(path.Path)
|
|
providerConfig := pluginConfig.GetProviderConfig()
|
|
if providerConfig.IsOriginal() {
|
|
if handler, ok := activeProvider.(provider.ApiNameHandler); ok {
|
|
apiName = handler.GetApiName(path.Path)
|
|
}
|
|
}
|
|
|
|
if apiName == "" {
|
|
log.Warnf("[onHttpRequestHeader] unsupported path: %s", path.Path)
|
|
return types.ActionContinue
|
|
}
|
|
|
|
ctx.SetContext(provider.CtxKeyApiName, apiName)
|
|
// Disable the route re-calculation since the plugin may modify some headers related to the chosen route.
|
|
ctx.DisableReroute()
|
|
|
|
_, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
|
|
if needHandleStreamingBody {
|
|
proxywasm.RemoveHttpRequestHeader("Accept-Encoding")
|
|
}
|
|
|
|
if handler, ok := activeProvider.(provider.RequestHeadersHandler); ok {
|
|
// Set the apiToken for the current request.
|
|
providerConfig.SetApiTokenInUse(ctx, log)
|
|
|
|
hasRequestBody := wrapper.HasRequestBody()
|
|
err := handler.OnRequestHeaders(ctx, apiName, log)
|
|
if err == nil {
|
|
if hasRequestBody {
|
|
proxywasm.RemoveHttpRequestHeader("Content-Length")
|
|
ctx.SetRequestBodyBufferLimit(defaultMaxBodyBytes)
|
|
// Delay the header processing to allow changing in OnRequestBody
|
|
return types.HeaderStopIteration
|
|
}
|
|
ctx.DontReadRequestBody()
|
|
return types.ActionContinue
|
|
}
|
|
|
|
util.ErrorHandler("ai-proxy.proc_req_headers_failed", fmt.Errorf("failed to process request headers: %v", err))
|
|
return types.ActionContinue
|
|
}
|
|
|
|
return types.ActionContinue
|
|
}
|
|
|
|
func onHttpRequestBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, body []byte, log wrapper.Log) types.Action {
|
|
activeProvider := pluginConfig.GetProvider()
|
|
|
|
if activeProvider == nil {
|
|
log.Debugf("[onHttpRequestBody] no active provider, skip processing")
|
|
return types.ActionContinue
|
|
}
|
|
|
|
log.Debugf("[onHttpRequestBody] provider=%s", activeProvider.GetProviderType())
|
|
|
|
if handler, ok := activeProvider.(provider.RequestBodyHandler); ok {
|
|
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
|
|
|
newBody, settingErr := pluginConfig.GetProviderConfig().ReplaceByCustomSettings(body)
|
|
if settingErr != nil {
|
|
util.ErrorHandler(
|
|
"ai-proxy.proc_req_body_failed",
|
|
fmt.Errorf("failed to replace request body by custom settings: %v", settingErr),
|
|
)
|
|
return types.ActionContinue
|
|
}
|
|
|
|
log.Debugf("[onHttpRequestBody] newBody=%s", newBody)
|
|
body = newBody
|
|
action, err := handler.OnRequestBody(ctx, apiName, body, log)
|
|
if err == nil {
|
|
return action
|
|
}
|
|
util.ErrorHandler("ai-proxy.proc_req_body_failed", fmt.Errorf("failed to process request body: %v", err))
|
|
}
|
|
return types.ActionContinue
|
|
}
|
|
|
|
func onHttpResponseHeaders(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, log wrapper.Log) types.Action {
|
|
if !wrapper.IsResponseFromUpstream() {
|
|
// Response is not coming from the upstream. Let it pass through.
|
|
ctx.DontReadResponseBody()
|
|
return types.ActionContinue
|
|
}
|
|
|
|
activeProvider := pluginConfig.GetProvider()
|
|
|
|
if activeProvider == nil {
|
|
log.Debugf("[onHttpResponseHeaders] no active provider, skip processing")
|
|
ctx.DontReadResponseBody()
|
|
return types.ActionContinue
|
|
}
|
|
|
|
log.Debugf("[onHttpResponseHeaders] provider=%s", activeProvider.GetProviderType())
|
|
|
|
providerConfig := pluginConfig.GetProviderConfig()
|
|
apiTokenInUse := providerConfig.GetApiTokenInUse(ctx)
|
|
|
|
status, err := proxywasm.GetHttpResponseHeader(":status")
|
|
if err != nil || status != "200" {
|
|
if err != nil {
|
|
log.Errorf("unable to load :status header from response: %v", err)
|
|
}
|
|
ctx.DontReadResponseBody()
|
|
return providerConfig.OnRequestFailed(activeProvider, ctx, apiTokenInUse, log)
|
|
}
|
|
|
|
// Reset ctxApiTokenRequestFailureCount if the request is successful,
|
|
// the apiToken is removed only when the number of consecutive request failures exceeds the threshold.
|
|
providerConfig.ResetApiTokenRequestFailureCount(apiTokenInUse, log)
|
|
|
|
headers := util.GetOriginalResponseHeaders()
|
|
if handler, ok := activeProvider.(provider.TransformResponseHeadersHandler); ok {
|
|
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
|
handler.TransformResponseHeaders(ctx, apiName, headers, log)
|
|
} else {
|
|
providerConfig.DefaultTransformResponseHeaders(ctx, headers)
|
|
}
|
|
util.ReplaceResponseHeaders(headers)
|
|
|
|
checkStream(ctx, log)
|
|
_, needHandleStreamingBody := activeProvider.(provider.StreamingResponseBodyHandler)
|
|
if !needHandleStreamingBody {
|
|
ctx.BufferResponseBody()
|
|
}
|
|
|
|
return types.ActionContinue
|
|
}
|
|
|
|
func onStreamingResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, chunk []byte, isLastChunk bool, log wrapper.Log) []byte {
|
|
activeProvider := pluginConfig.GetProvider()
|
|
|
|
if activeProvider == nil {
|
|
log.Debugf("[onStreamingResponseBody] no active provider, skip processing")
|
|
return chunk
|
|
}
|
|
|
|
log.Debugf("[onStreamingResponseBody] provider=%s", activeProvider.GetProviderType())
|
|
log.Debugf("isLastChunk=%v chunk: %s", isLastChunk, string(chunk))
|
|
|
|
if handler, ok := activeProvider.(provider.StreamingResponseBodyHandler); ok {
|
|
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
|
modifiedChunk, err := handler.OnStreamingResponseBody(ctx, apiName, chunk, isLastChunk, log)
|
|
if err == nil && modifiedChunk != nil {
|
|
return modifiedChunk
|
|
}
|
|
return chunk
|
|
}
|
|
return chunk
|
|
}
|
|
|
|
func onHttpResponseBody(ctx wrapper.HttpContext, pluginConfig config.PluginConfig, body []byte, log wrapper.Log) types.Action {
|
|
activeProvider := pluginConfig.GetProvider()
|
|
|
|
if activeProvider == nil {
|
|
log.Debugf("[onHttpResponseBody] no active provider, skip processing")
|
|
return types.ActionContinue
|
|
}
|
|
|
|
log.Debugf("[onHttpResponseBody] provider=%s", activeProvider.GetProviderType())
|
|
|
|
if handler, ok := activeProvider.(provider.TransformResponseBodyHandler); ok {
|
|
apiName, _ := ctx.GetContext(provider.CtxKeyApiName).(provider.ApiName)
|
|
body, err := handler.TransformResponseBody(ctx, apiName, body, log)
|
|
if err != nil {
|
|
util.ErrorHandler("ai-proxy.proc_resp_body_failed", fmt.Errorf("failed to process response body: %v", err))
|
|
return types.ActionContinue
|
|
}
|
|
if err = provider.ReplaceResponseBody(body, log); err != nil {
|
|
util.ErrorHandler("ai-proxy.replace_resp_body_failed", fmt.Errorf("failed to replace response body: %v", err))
|
|
}
|
|
}
|
|
return types.ActionContinue
|
|
}
|
|
|
|
func checkStream(ctx wrapper.HttpContext, log wrapper.Log) {
|
|
contentType, err := proxywasm.GetHttpResponseHeader("Content-Type")
|
|
if err != nil || !strings.HasPrefix(contentType, "text/event-stream") {
|
|
if err != nil {
|
|
log.Errorf("unable to load content-type header from response: %v", err)
|
|
}
|
|
ctx.BufferResponseBody()
|
|
ctx.SetResponseBodyBufferLimit(defaultMaxBodyBytes)
|
|
}
|
|
}
|
|
|
|
func getOpenAiApiName(path string) provider.ApiName {
|
|
if strings.HasSuffix(path, "/v1/chat/completions") {
|
|
return provider.ApiNameChatCompletion
|
|
}
|
|
if strings.HasSuffix(path, "/v1/embeddings") {
|
|
return provider.ApiNameEmbeddings
|
|
}
|
|
return ""
|
|
}
|