fix: potential use of mismatched tokens (#1092)

Co-authored-by: Kent Dong <ch3cho@qq.com>
This commit is contained in:
pepesi
2024-08-26 15:40:55 +08:00
committed by GitHub
parent 496346fe95
commit f5b8341f7f
2 changed files with 16 additions and 6 deletions

View File

@@ -89,7 +89,8 @@ func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam
return types.ActionContinue, replaceJsonRequestBody(request, log)
}
err := m.getContextContent(func(content string, err error) {
apiKey := m.config.GetOrSetTokenWithContext(ctx)
err := m.getContextContent(apiKey, func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
@@ -114,13 +115,13 @@ func (m *moonshotProvider) performChatCompletion(ctx wrapper.HttpContext, fileCo
return replaceJsonRequestBody(request, log)
}
func (m *moonshotProvider) getContextContent(callback func(string, error), log wrapper.Log) error {
func (m *moonshotProvider) getContextContent(apiKey string, callback func(string, error), log wrapper.Log) error {
if m.config.moonshotFileId != "" {
if m.fileContent != "" {
callback(m.fileContent, nil)
return nil
}
return m.sendRequest(http.MethodGet, "/v1/files/"+m.config.moonshotFileId+"/content", "",
return m.sendRequest(http.MethodGet, "/v1/files/"+m.config.moonshotFileId+"/content", "", apiKey,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
responseString := string(responseBody)
if statusCode != http.StatusOK {
@@ -141,13 +142,13 @@ func (m *moonshotProvider) getContextContent(callback func(string, error), log w
return errors.New("both moonshotFileId and context are not configured")
}
func (m *moonshotProvider) sendRequest(method, path string, body string, callback wrapper.ResponseCallback) error {
func (m *moonshotProvider) sendRequest(method, path, body, apiKey string, callback wrapper.ResponseCallback) error {
switch method {
case http.MethodGet:
headers := util.CreateHeaders("Authorization", "Bearer "+m.config.GetRandomToken())
headers := util.CreateHeaders("Authorization", "Bearer "+apiKey)
return m.client.Get(path, headers, callback, m.config.timeout)
case http.MethodPost:
headers := util.CreateHeaders("Authorization", "Bearer "+m.config.GetRandomToken(), "Content-Type", "application/json")
headers := util.CreateHeaders("Authorization", "Bearer "+apiKey, "Content-Type", "application/json")
return m.client.Post(path, headers, []byte(body), callback, m.config.timeout)
default:
return errors.New("unsupported method: " + method)

View File

@@ -290,6 +290,15 @@ func (c *ProviderConfig) Validate() error {
return nil
}
func (c *ProviderConfig) GetOrSetTokenWithContext(ctx wrapper.HttpContext) string {
ctxApiKey := ctx.GetContext(ctxKeyApiName)
if ctxApiKey == nil {
ctxApiKey = c.GetRandomToken()
ctx.SetContext(ctxKeyApiName, ctxApiKey)
}
return ctxApiKey.(string)
}
func (c *ProviderConfig) GetRandomToken() string {
apiTokens := c.apiTokens
count := len(apiTokens)