mirror of
https://github.com/alibaba/higress.git
synced 2026-05-24 04:37:25 +08:00
AI proxy plugin (#420)
This commit is contained in:
128
plugins/wasm-go/extensions/chatgpt-proxy/main.go
Normal file
128
plugins/wasm-go/extensions/chatgpt-proxy/main.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
|
||||
"github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm"
|
||||
"github.com/tetratelabs/proxy-wasm-go-sdk/proxywasm/types"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func main() {
|
||||
wrapper.SetCtx(
|
||||
"chatgpt-proxy",
|
||||
wrapper.ParseConfigBy(parseConfig),
|
||||
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
|
||||
)
|
||||
}
|
||||
|
||||
type MyConfig struct {
|
||||
Model string
|
||||
ApiKey string
|
||||
PromptParam string
|
||||
ChatgptPath string
|
||||
HumainId string
|
||||
AIId string
|
||||
client wrapper.HttpClient
|
||||
}
|
||||
|
||||
func parseConfig(json gjson.Result, config *MyConfig, log wrapper.Log) error {
|
||||
chatgptUri := json.Get("chatgptUri").String()
|
||||
var chatgptHost string
|
||||
if chatgptUri == "" {
|
||||
config.ChatgptPath = "/v1/completions"
|
||||
chatgptHost = "api.openai.com"
|
||||
} else {
|
||||
cp, err := url.Parse(chatgptUri)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config.ChatgptPath = cp.Path
|
||||
chatgptHost = cp.Host
|
||||
}
|
||||
if config.ChatgptPath == "" {
|
||||
return errors.New("not found path in chatgptUri")
|
||||
}
|
||||
if chatgptHost == "" {
|
||||
return errors.New("not found host in chatgptUri")
|
||||
}
|
||||
config.client = wrapper.NewClusterClient(wrapper.RouteCluster{
|
||||
Host: chatgptHost,
|
||||
})
|
||||
config.Model = json.Get("model").String()
|
||||
if config.Model == "" {
|
||||
config.Model = "text-davinci-003"
|
||||
}
|
||||
config.ApiKey = json.Get("apiKey").String()
|
||||
if config.ApiKey == "" {
|
||||
return errors.New("no apiKey found in config")
|
||||
}
|
||||
config.PromptParam = json.Get("promptParam").String()
|
||||
if config.PromptParam == "" {
|
||||
config.PromptParam = "prompt"
|
||||
}
|
||||
config.HumainId = json.Get("HumainId").String()
|
||||
if config.HumainId == "" {
|
||||
config.HumainId = "Humain:"
|
||||
}
|
||||
config.AIId = json.Get("AIId").String()
|
||||
if config.AIId == "" {
|
||||
config.AIId = "AI:"
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const bodyTemplate string = `
|
||||
{
|
||||
"model":"%s",
|
||||
"prompt":"%s",
|
||||
"temperature":0.9,
|
||||
"max_tokens": 150,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0.0,
|
||||
"presence_penalty": 0.6,
|
||||
"stop": [" %s", " %s"]
|
||||
}
|
||||
`
|
||||
|
||||
func onHttpRequestHeaders(ctx wrapper.HttpContext, config MyConfig, log wrapper.Log) types.Action {
|
||||
pairs := strings.SplitN(ctx.Path(), "?", 2)
|
||||
|
||||
if len(pairs) < 2 {
|
||||
proxywasm.SendHttpResponse(400, nil, []byte("1-need prompt param"), -1)
|
||||
return types.ActionContinue
|
||||
}
|
||||
querys, err := url.ParseQuery(pairs[1])
|
||||
if err != nil {
|
||||
proxywasm.SendHttpResponse(400, nil, []byte("2-need prompt param"), -1)
|
||||
return types.ActionContinue
|
||||
}
|
||||
var prompt []string
|
||||
var ok bool
|
||||
if prompt, ok = querys[config.PromptParam]; !ok || len(prompt) == 0 {
|
||||
proxywasm.SendHttpResponse(400, nil, []byte("3-need prompt param"), -1)
|
||||
return types.ActionContinue
|
||||
}
|
||||
body := fmt.Sprintf(bodyTemplate, config.Model, prompt[0], config.HumainId, config.AIId)
|
||||
err = config.client.Post(config.ChatgptPath, [][2]string{
|
||||
{"Content-Type", "application/json"},
|
||||
{"Authorization", "Bearer " + config.ApiKey},
|
||||
}, []byte(body),
|
||||
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
|
||||
var headers [][2]string
|
||||
for key, value := range responseHeaders {
|
||||
headers = append(headers, [2]string{key, value[0]})
|
||||
}
|
||||
proxywasm.SendHttpResponse(uint32(statusCode), headers, responseBody, -1)
|
||||
}, 10000)
|
||||
if err != nil {
|
||||
proxywasm.SendHttpResponse(500, nil, []byte("Internel Error: "+err.Error()), -1)
|
||||
return types.ActionContinue
|
||||
}
|
||||
return types.ActionPause
|
||||
}
|
||||
Reference in New Issue
Block a user