mirror of
https://github.com/alibaba/higress.git
synced 2026-06-09 04:37:31 +08:00
feat: implement apiToken failover mechanism (#1256)
This commit is contained in:
@@ -1,12 +1,15 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
@@ -57,6 +60,10 @@ type contextCache struct {
|
||||
content string
|
||||
}
|
||||
|
||||
type ContextInserter interface {
|
||||
insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error)
|
||||
}
|
||||
|
||||
func (c *contextCache) GetContent(callback func(string, error), log wrapper.Log) error {
|
||||
if callback == nil {
|
||||
return errors.New("callback is nil")
|
||||
@@ -98,3 +105,79 @@ func createContextCache(providerConfig *ProviderConfig) *contextCache {
|
||||
timeout: providerConfig.timeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *contextCache) GetContextFromFile(ctx wrapper.HttpContext, provider Provider, body []byte, log wrapper.Log) error {
|
||||
// get context will overwrite the original request host and path
|
||||
// save the original request host and path in case they are needed for apiToken health check
|
||||
ctx.SetContext(ctxRequestHost, wrapper.GetRequestHost())
|
||||
ctx.SetContext(ctxRequestPath, wrapper.GetRequestPath())
|
||||
|
||||
if c.loaded {
|
||||
log.Debugf("context file loaded from cache")
|
||||
insertContext(provider, c.content, nil, body, log)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("loading context file from %s", c.fileUrl.String())
|
||||
return c.client.Get(c.fileUrl.Path, nil, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
if statusCode != http.StatusOK {
|
||||
insertContext(provider, "", fmt.Errorf("failed to load context file, status: %d", statusCode), nil, log)
|
||||
return
|
||||
}
|
||||
c.content = string(responseBody)
|
||||
c.loaded = true
|
||||
log.Debugf("content: %s", c.content)
|
||||
insertContext(provider, c.content, nil, body, log)
|
||||
}, c.timeout)
|
||||
}
|
||||
|
||||
func insertContext(provider Provider, content string, err error, body []byte, log wrapper.Log) {
|
||||
defer func() {
|
||||
_ = proxywasm.ResumeHttpRequest()
|
||||
}()
|
||||
|
||||
typ := provider.GetProviderType()
|
||||
if err != nil {
|
||||
log.Errorf("failed to load context file: %v", err)
|
||||
_ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.load_ctx_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
|
||||
}
|
||||
|
||||
if inserter, ok := provider.(ContextInserter); ok {
|
||||
body, err = inserter.insertHttpContextMessage(body, content, false)
|
||||
} else {
|
||||
body, err = defaultInsertHttpContextMessage(body, content)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
_ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.insert_ctx_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to insert context message: %v", err))
|
||||
}
|
||||
if err := replaceHttpJsonRequestBody(body, log); err != nil {
|
||||
_ = util.SendResponse(500, fmt.Sprintf("ai-proxy.%s.replace_request_body_failed", typ), util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
func defaultInsertHttpContextMessage(body []byte, content string) ([]byte, error) {
|
||||
request := &chatCompletionRequest{}
|
||||
if err := json.Unmarshal(body, request); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal request: %v", err)
|
||||
}
|
||||
|
||||
fileMessage := chatMessage{
|
||||
Role: roleSystem,
|
||||
Content: content,
|
||||
}
|
||||
var firstNonSystemMessageIndex int
|
||||
for i, message := range request.Messages {
|
||||
if message.Role != roleSystem {
|
||||
firstNonSystemMessageIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if firstNonSystemMessageIndex == 0 {
|
||||
request.Messages = append([]chatMessage{fileMessage}, request.Messages...)
|
||||
} else {
|
||||
request.Messages = append(request.Messages[:firstNonSystemMessageIndex], append([]chatMessage{fileMessage}, request.Messages[firstNonSystemMessageIndex:]...)...)
|
||||
}
|
||||
|
||||
return json.Marshal(request)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user