mirror of
https://github.com/alibaba/higress.git
synced 2026-02-26 13:40:49 +08:00
180 lines
5.4 KiB
Go
180 lines
5.4 KiB
Go
package provider
|
||
|
||
import (
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"net/http"
|
||
"net/url"
|
||
|
||
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
|
||
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
|
||
"github.com/tidwall/gjson"
|
||
)
|
||
|
||
type ContextConfig struct {
|
||
// @Title zh-CN 文件URL
|
||
// @Description zh-CN 用于获取对话上下文的文件的URL。目前仅支持HTTP和HTTPS协议,纯文本格式文件
|
||
fileUrl string `required:"true" yaml:"url" json:"url"`
|
||
// @Title zh-CN 上游服务名称
|
||
// @Description zh-CN 文件服务所对应的网关内上游服务名称
|
||
serviceName string `required:"true" yaml:"serviceName" json:"serviceName"`
|
||
// @Title zh-CN 上游服务端口
|
||
// @Description zh-CN 文件服务所对应的网关内上游服务名称
|
||
servicePort int64 `required:"true" yaml:"servicePort" json:"servicePort"`
|
||
|
||
fileUrlObj *url.URL `yaml:"-"`
|
||
}
|
||
|
||
func (c *ContextConfig) FromJson(json gjson.Result) {
|
||
c.fileUrl = json.Get("fileUrl").String()
|
||
c.serviceName = json.Get("serviceName").String()
|
||
c.servicePort = json.Get("servicePort").Int()
|
||
}
|
||
|
||
func (c *ContextConfig) Validate() error {
|
||
if c.fileUrl == "" {
|
||
return errors.New("missing fileUrl in context config")
|
||
}
|
||
if fileUrlObj, err := url.Parse(c.fileUrl); err != nil {
|
||
return fmt.Errorf("invalid fileUrl in context config: %v", err)
|
||
} else {
|
||
c.fileUrlObj = fileUrlObj
|
||
}
|
||
if c.serviceName == "" {
|
||
return errors.New("missing serviceName in context config")
|
||
}
|
||
if c.servicePort == 0 {
|
||
return errors.New("missing servicePort in context config")
|
||
}
|
||
return nil
|
||
}
|
||
|
||
type contextCache struct {
|
||
client wrapper.HttpClient
|
||
fileUrl *url.URL
|
||
timeout uint32
|
||
|
||
loaded bool
|
||
content string
|
||
}
|
||
|
||
// ContextInserter can be implemented by a provider that needs its own way of
// placing the context file content into a request body. Providers that do not
// implement it get the default insertion logic instead.
type ContextInserter interface {
	// insertHttpContextMessage returns a copy of body with content inserted
	// as context, or an error when the body cannot be rewritten.
	// NOTE(review): the exact meaning of onlyOneSystemBeforeFile is defined
	// by each implementation — confirm against the providers.
	insertHttpContextMessage(body []byte, content string, onlyOneSystemBeforeFile bool) ([]byte, error)
}
|
||
|
||
func (c *contextCache) GetContent(callback func(string, error)) error {
|
||
if callback == nil {
|
||
return errors.New("callback is nil")
|
||
}
|
||
|
||
if c.loaded {
|
||
log.Debugf("context file loaded from cache")
|
||
callback(c.content, nil)
|
||
return nil
|
||
}
|
||
|
||
log.Infof("loading context file from %s", c.fileUrl.String())
|
||
return c.client.Get(c.fileUrl.Path, nil, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||
if statusCode != http.StatusOK {
|
||
callback("", fmt.Errorf("failed to load context file, status: %d", statusCode))
|
||
return
|
||
}
|
||
c.content = string(responseBody)
|
||
c.loaded = true
|
||
log.Debugf("content: %s", c.content)
|
||
callback(c.content, nil)
|
||
}, c.timeout)
|
||
}
|
||
|
||
func createContextCache(providerConfig *ProviderConfig) *contextCache {
|
||
contextConfig := providerConfig.context
|
||
if contextConfig == nil {
|
||
return nil
|
||
}
|
||
fileUrlObj, _ := url.Parse(contextConfig.fileUrl)
|
||
cluster := wrapper.FQDNCluster{
|
||
FQDN: contextConfig.serviceName,
|
||
Port: contextConfig.servicePort,
|
||
Host: fileUrlObj.Host,
|
||
}
|
||
return &contextCache{
|
||
client: wrapper.NewClusterClient(cluster),
|
||
fileUrl: fileUrlObj,
|
||
timeout: providerConfig.timeout,
|
||
}
|
||
}
|
||
|
||
func (c *contextCache) GetContextFromFile(ctx wrapper.HttpContext, provider Provider, body []byte) error {
|
||
if c.loaded {
|
||
log.Debugf("context file loaded from cache")
|
||
insertContext(provider, c.content, nil, body)
|
||
return nil
|
||
}
|
||
|
||
log.Infof("loading context file from %s", c.fileUrl.String())
|
||
return c.client.Get(c.fileUrl.Path, nil, func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||
if statusCode != http.StatusOK {
|
||
insertContext(provider, "", fmt.Errorf("failed to load context file, status: %d", statusCode), nil)
|
||
return
|
||
}
|
||
c.content = string(responseBody)
|
||
c.loaded = true
|
||
log.Debugf("content: %s", c.content)
|
||
insertContext(provider, c.content, nil, body)
|
||
}, c.timeout)
|
||
}
|
||
|
||
func insertContext(provider Provider, content string, err error, body []byte) {
|
||
defer func() {
|
||
_ = proxywasm.ResumeHttpRequest()
|
||
}()
|
||
|
||
typ := provider.GetProviderType()
|
||
if err != nil {
|
||
log.Errorf("failed to load context file: %v", err)
|
||
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.load_ctx_failed", typ), fmt.Errorf("failed to load context file: %v", err))
|
||
}
|
||
|
||
if inserter, ok := provider.(ContextInserter); ok {
|
||
body, err = inserter.insertHttpContextMessage(body, content, false)
|
||
} else {
|
||
body, err = defaultInsertHttpContextMessage(body, content)
|
||
}
|
||
|
||
if err != nil {
|
||
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.insert_ctx_failed", typ), fmt.Errorf("failed to insert context message: %v", err))
|
||
}
|
||
if err := replaceRequestBody(body); err != nil {
|
||
util.ErrorHandler(fmt.Sprintf("ai-proxy.%s.replace_request_body_failed", typ), fmt.Errorf("failed to replace request body: %v", err))
|
||
}
|
||
}
|
||
|
||
func defaultInsertHttpContextMessage(body []byte, content string) ([]byte, error) {
|
||
request := &chatCompletionRequest{}
|
||
if err := json.Unmarshal(body, request); err != nil {
|
||
return nil, fmt.Errorf("unable to unmarshal request: %v", err)
|
||
}
|
||
|
||
fileMessage := chatMessage{
|
||
Role: roleSystem,
|
||
Content: content,
|
||
}
|
||
var firstNonSystemMessageIndex int
|
||
for i, message := range request.Messages {
|
||
if message.Role != roleSystem {
|
||
firstNonSystemMessageIndex = i
|
||
break
|
||
}
|
||
}
|
||
if firstNonSystemMessageIndex == 0 {
|
||
request.Messages = append([]chatMessage{fileMessage}, request.Messages...)
|
||
} else {
|
||
request.Messages = append(request.Messages[:firstNonSystemMessageIndex], append([]chatMessage{fileMessage}, request.Messages[firstNonSystemMessageIndex:]...)...)
|
||
}
|
||
|
||
return json.Marshal(request)
|
||
}
|