Support Openai structure output api (#feat 1214) (#1217)

Co-authored-by: Kent Dong <ch3cho@qq.com>
This commit is contained in:
Yang Beining
2024-08-22 12:33:35 +08:00
committed by GitHub
parent bdbfad8a8a
commit 0e58042fa6
4 changed files with 16 additions and 0 deletions

View File

@@ -52,6 +52,7 @@ OpenAI 所对应的 `type` 为 `openai`。它特有的配置字段如下:
| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|-------------------|----------|----------|--------|-------------------------------------------------------------------------------|
| `openaiCustomUrl` | string | 非必填 | - | 基于OpenAI协议的自定义后端URL例如: www.example.com/myai/v1/chat/completions |
| `responseJsonSchema` | object | 非必填 | - | 预先定义OpenAI响应需满足的Json Schema, 注意目前仅特定的几种模型支持该用法|
#### Azure OpenAI

View File

@@ -31,6 +31,7 @@ type chatCompletionRequest struct {
ToolChoice *toolChoice `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
Stop []string `json:"stop,omitempty"`
ResponseFormat map[string]interface{} `json:"response_format,omitempty"`
}
type streamOptions struct {

View File

@@ -89,6 +89,10 @@ func (m *openaiProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
if err := decodeChatCompletionRequest(body, request); err != nil {
return types.ActionContinue, err
}
if m.config.responseJsonSchema != nil {
log.Debugf("[ai-proxy] set response format to %s", m.config.responseJsonSchema)
request.ResponseFormat = m.config.responseJsonSchema
}
if request.Stream {
// For stream requests, we need to include usage in the response.
if request.StreamOptions == nil {

View File

@@ -181,6 +181,9 @@ type ProviderConfig struct {
// @Title zh-CN 翻译服务需指定的目标语种
// @Description zh-CN 翻译结果的语种目前仅适用于DeepL服务。
targetLang string `required:"false" yaml:"targetLang" json:"targetLang"`
// @Title zh-CN 指定服务返回的响应需满足的JSON Schema
// @Description zh-CN 目前仅适用于OpenAI部分模型服务。参考https://platform.openai.com/docs/guides/structured-outputs
responseJsonSchema map[string]interface{} `required:"false" yaml:"responseJsonSchema" json:"responseJsonSchema"`
}
func (c *ProviderConfig) FromJson(json gjson.Result) {
@@ -229,6 +232,13 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
}
}
c.targetLang = json.Get("targetLang").String()
if schemaValue, ok := json.Get("responseJsonSchema").Value().(map[string]interface{}); ok {
c.responseJsonSchema = schemaValue
} else {
c.responseJsonSchema = nil
}
}
func (c *ProviderConfig) Validate() error {