mirror of
https://github.com/alibaba/higress.git
synced 2026-03-09 19:20:51 +08:00
260 lines
6.6 KiB
Go
260 lines
6.6 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"io"
|
|
"mime"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"net/textproto"
|
|
"strings"
|
|
|
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
|
"github.com/higress-group/wasm-go/pkg/log"
|
|
"github.com/higress-group/wasm-go/pkg/wrapper"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
)
|
|
|
|
const (
|
|
DefaultMaxBodyBytes = 100 * 1024 * 1024 // 100MB
|
|
)
|
|
|
|
func main() {}
|
|
|
|
func init() {
|
|
wrapper.SetCtx(
|
|
"model-router",
|
|
wrapper.ParseConfig(parseConfig),
|
|
wrapper.ProcessRequestHeaders(onHttpRequestHeaders),
|
|
wrapper.ProcessRequestBody(onHttpRequestBody),
|
|
wrapper.WithRebuildAfterRequests[ModelRouterConfig](1000),
|
|
wrapper.WithRebuildMaxMemBytes[ModelRouterConfig](200*1024*1024),
|
|
)
|
|
}
|
|
|
|
type ModelRouterConfig struct {
|
|
modelKey string
|
|
addProviderHeader string
|
|
modelToHeader string
|
|
enableOnPathSuffix []string
|
|
}
|
|
|
|
func parseConfig(json gjson.Result, config *ModelRouterConfig) error {
|
|
config.modelKey = json.Get("modelKey").String()
|
|
if config.modelKey == "" {
|
|
config.modelKey = "model"
|
|
}
|
|
config.addProviderHeader = json.Get("addProviderHeader").String()
|
|
config.modelToHeader = json.Get("modelToHeader").String()
|
|
|
|
enableOnPathSuffix := json.Get("enableOnPathSuffix")
|
|
if enableOnPathSuffix.Exists() && enableOnPathSuffix.IsArray() {
|
|
for _, item := range enableOnPathSuffix.Array() {
|
|
config.enableOnPathSuffix = append(config.enableOnPathSuffix, item.String())
|
|
}
|
|
} else {
|
|
// Default suffixes if not provided
|
|
config.enableOnPathSuffix = []string{
|
|
"/completions",
|
|
"/embeddings",
|
|
"/images/generations",
|
|
"/audio/speech",
|
|
"/fine_tuning/jobs",
|
|
"/moderations",
|
|
"/image-synthesis",
|
|
"/video-synthesis",
|
|
"/rerank",
|
|
"/messages",
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func onHttpRequestHeaders(ctx wrapper.HttpContext, config ModelRouterConfig) types.Action {
|
|
path, err := proxywasm.GetHttpRequestHeader(":path")
|
|
if err != nil {
|
|
return types.ActionContinue
|
|
}
|
|
|
|
// Remove query parameters for suffix check
|
|
if idx := strings.Index(path, "?"); idx != -1 {
|
|
path = path[:idx]
|
|
}
|
|
|
|
enable := false
|
|
for _, suffix := range config.enableOnPathSuffix {
|
|
if suffix == "*" || strings.HasSuffix(path, suffix) {
|
|
enable = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !enable || !ctx.HasRequestBody() {
|
|
ctx.DontReadRequestBody()
|
|
return types.ActionContinue
|
|
}
|
|
|
|
// Prepare for body processing
|
|
proxywasm.RemoveHttpRequestHeader("content-length")
|
|
// 100MB buffer limit
|
|
ctx.SetRequestBodyBufferLimit(DefaultMaxBodyBytes)
|
|
|
|
return types.HeaderStopIteration
|
|
}
|
|
|
|
func onHttpRequestBody(ctx wrapper.HttpContext, config ModelRouterConfig, body []byte) types.Action {
|
|
contentType, err := proxywasm.GetHttpRequestHeader("content-type")
|
|
if err != nil {
|
|
return types.ActionContinue
|
|
}
|
|
|
|
if strings.Contains(contentType, "application/json") {
|
|
return handleJsonBody(ctx, config, body)
|
|
} else if strings.Contains(contentType, "multipart/form-data") {
|
|
return handleMultipartBody(ctx, config, body, contentType)
|
|
}
|
|
|
|
return types.ActionContinue
|
|
}
|
|
|
|
func handleJsonBody(ctx wrapper.HttpContext, config ModelRouterConfig, body []byte) types.Action {
|
|
if !json.Valid(body) {
|
|
log.Error("invalid json body")
|
|
return types.ActionContinue
|
|
}
|
|
modelValue := gjson.GetBytes(body, config.modelKey).String()
|
|
if modelValue == "" {
|
|
return types.ActionContinue
|
|
}
|
|
|
|
if config.modelToHeader != "" {
|
|
_ = proxywasm.ReplaceHttpRequestHeader(config.modelToHeader, modelValue)
|
|
}
|
|
|
|
if config.addProviderHeader != "" {
|
|
parts := strings.SplitN(modelValue, "/", 2)
|
|
if len(parts) == 2 {
|
|
provider := parts[0]
|
|
model := parts[1]
|
|
_ = proxywasm.ReplaceHttpRequestHeader(config.addProviderHeader, provider)
|
|
|
|
newBody, err := sjson.SetBytes(body, config.modelKey, model)
|
|
if err != nil {
|
|
log.Errorf("failed to update model in json body: %v", err)
|
|
return types.ActionContinue
|
|
}
|
|
_ = proxywasm.ReplaceHttpRequestBody(newBody)
|
|
log.Debugf("model route to provider: %s, model: %s", provider, model)
|
|
} else {
|
|
log.Debugf("model route to provider not work, model: %s", modelValue)
|
|
}
|
|
}
|
|
|
|
return types.ActionContinue
|
|
}
|
|
|
|
func handleMultipartBody(ctx wrapper.HttpContext, config ModelRouterConfig, body []byte, contentType string) types.Action {
|
|
_, params, err := mime.ParseMediaType(contentType)
|
|
if err != nil {
|
|
log.Errorf("failed to parse content type: %v", err)
|
|
return types.ActionContinue
|
|
}
|
|
boundary, ok := params["boundary"]
|
|
if !ok {
|
|
log.Errorf("no boundary in content type")
|
|
return types.ActionContinue
|
|
}
|
|
|
|
reader := multipart.NewReader(bytes.NewReader(body), boundary)
|
|
var newBody bytes.Buffer
|
|
writer := multipart.NewWriter(&newBody)
|
|
writer.SetBoundary(boundary)
|
|
|
|
modified := false
|
|
|
|
for {
|
|
part, err := reader.NextPart()
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if err != nil {
|
|
log.Errorf("failed to read multipart part: %v", err)
|
|
return types.ActionContinue
|
|
}
|
|
|
|
// Read part content
|
|
partContent, err := io.ReadAll(part)
|
|
if err != nil {
|
|
log.Errorf("failed to read part content: %v", err)
|
|
return types.ActionContinue
|
|
}
|
|
|
|
formName := part.FormName()
|
|
if formName == config.modelKey {
|
|
modelValue := string(partContent)
|
|
|
|
if config.modelToHeader != "" {
|
|
_ = proxywasm.ReplaceHttpRequestHeader(config.modelToHeader, modelValue)
|
|
}
|
|
|
|
if config.addProviderHeader != "" {
|
|
parts := strings.SplitN(modelValue, "/", 2)
|
|
if len(parts) == 2 {
|
|
provider := parts[0]
|
|
model := parts[1]
|
|
_ = proxywasm.ReplaceHttpRequestHeader(config.addProviderHeader, provider)
|
|
|
|
// Write modified part
|
|
h := make(http.Header)
|
|
for k, v := range part.Header {
|
|
h[k] = v
|
|
}
|
|
|
|
pw, err := writer.CreatePart(textproto.MIMEHeader(h))
|
|
if err != nil {
|
|
log.Errorf("failed to create part: %v", err)
|
|
return types.ActionContinue
|
|
}
|
|
_, err = pw.Write([]byte(model))
|
|
if err != nil {
|
|
log.Errorf("failed to write part content: %v", err)
|
|
return types.ActionContinue
|
|
}
|
|
modified = true
|
|
log.Debugf("model route to provider: %s, model: %s", provider, model)
|
|
continue
|
|
} else {
|
|
log.Debugf("model route to provider not work, model: %s", modelValue)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Write original part
|
|
h := make(http.Header)
|
|
for k, v := range part.Header {
|
|
h[k] = v
|
|
}
|
|
pw, err := writer.CreatePart(textproto.MIMEHeader(h))
|
|
if err != nil {
|
|
log.Errorf("failed to create part: %v", err)
|
|
return types.ActionContinue
|
|
}
|
|
_, err = pw.Write(partContent)
|
|
if err != nil {
|
|
log.Errorf("failed to write part content: %v", err)
|
|
return types.ActionContinue
|
|
}
|
|
}
|
|
|
|
writer.Close()
|
|
|
|
if modified {
|
|
_ = proxywasm.ReplaceHttpRequestBody(newBody.Bytes())
|
|
}
|
|
|
|
return types.ActionContinue
|
|
}
|