mirror of
https://github.com/alibaba/higress.git
synced 2026-03-02 15:40:54 +08:00
222 lines
6.8 KiB
Go
222 lines
6.8 KiB
Go
// Copyright (c) 2022 Alibaba Group Holding Ltd.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package wrapper
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
|
|
"github.com/alibaba/higress/plugins/wasm-go/pkg/log"
|
|
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
|
|
"github.com/invopop/jsonschema"
|
|
"github.com/tidwall/gjson"
|
|
)
|
|
|
|
type MCPTool[PluginConfig any] interface {
|
|
Create(params []byte) MCPTool[PluginConfig]
|
|
Call(context HttpContext, config PluginConfig) error
|
|
Description() string
|
|
InputSchema() map[string]any
|
|
}
|
|
|
|
// MCPTools maps a tool name to its implementation. The name is the one
// advertised in "tools/list" and matched against the "name" parameter of
// "tools/call".
type MCPTools[C any] map[string]MCPTool[C]
|
|
|
|
type addMCPToolOption[PluginConfig any] struct {
|
|
name string
|
|
tool MCPTool[PluginConfig]
|
|
}
|
|
|
|
func AddMCPTool[PluginConfig any](name string, tool MCPTool[PluginConfig]) CtxOption[PluginConfig] {
|
|
return &addMCPToolOption[PluginConfig]{
|
|
name: name,
|
|
tool: tool,
|
|
}
|
|
}
|
|
|
|
func (o *addMCPToolOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
|
|
ctx.isJsonRpcSever = true
|
|
ctx.handleJsonRpcMethod = true
|
|
if _, exist := ctx.mcpTools[o.name]; exist {
|
|
panic(fmt.Sprintf("Conflict! There is a tool with the same name:%s",
|
|
o.name))
|
|
}
|
|
ctx.mcpTools[o.name] = o.tool
|
|
}
|
|
|
|
func (ctx *CommonHttpCtx[PluginConfig]) OnMCPResponseSuccess(result map[string]any) {
|
|
ctx.OnJsonRpcResponseSuccess(result)
|
|
// TODO: support pub to redis when use POST + SSE
|
|
}
|
|
|
|
func (ctx *CommonHttpCtx[PluginConfig]) OnMCPResponseError(err error, code ...int) {
|
|
ctx.OnJsonRpcResponseError(err, code...)
|
|
// TODO: support pub to redis when use POST + SSE
|
|
}
|
|
|
|
func (ctx *CommonHttpCtx[PluginConfig]) OnMCPToolCallSuccess(content []map[string]any) {
|
|
ctx.OnMCPResponseSuccess(map[string]any{
|
|
"content": content,
|
|
"isError": false,
|
|
})
|
|
}
|
|
|
|
func (ctx *CommonHttpCtx[PluginConfig]) OnMCPToolCallError(err error) {
|
|
ctx.OnMCPResponseSuccess(map[string]any{
|
|
"content": []map[string]any{
|
|
{
|
|
"type": "text",
|
|
"text": err.Error(),
|
|
},
|
|
},
|
|
"isError": true,
|
|
})
|
|
}
|
|
|
|
func (ctx *CommonHttpCtx[PluginConfig]) SendMCPToolTextResult(result string) {
|
|
ctx.OnMCPToolCallSuccess([]map[string]any{
|
|
{
|
|
"type": "text",
|
|
"text": result,
|
|
},
|
|
})
|
|
}
|
|
|
|
func (ctx *CommonHttpCtx[PluginConfig]) registerMCPTools(mcpTools MCPTools[PluginConfig]) {
|
|
if !ctx.plugin.vm.isJsonRpcSever {
|
|
return
|
|
}
|
|
if !ctx.plugin.vm.handleJsonRpcMethod {
|
|
return
|
|
}
|
|
ctx.plugin.vm.jsonRpcMethodHandlers["tools/list"] = func(context HttpContext, config PluginConfig, id int64, params gjson.Result) error {
|
|
var tools []map[string]any
|
|
for name, tool := range mcpTools {
|
|
tools = append(tools, map[string]any{
|
|
"name": name,
|
|
"description": tool.Description(),
|
|
"inputSchema": tool.InputSchema(),
|
|
})
|
|
}
|
|
ctx.OnMCPResponseSuccess(map[string]any{
|
|
"tools": tools,
|
|
"nextCursor": "",
|
|
})
|
|
return nil
|
|
}
|
|
ctx.plugin.vm.jsonRpcMethodHandlers["tools/call"] = func(context HttpContext, config PluginConfig, id int64, params gjson.Result) error {
|
|
name := params.Get("name").String()
|
|
args := params.Get("arguments")
|
|
if tool, ok := mcpTools[name]; ok {
|
|
log.Debugf("mcp call tool[%s] with arguments[%s]", name, args.Raw)
|
|
toolInstance := tool.Create([]byte(args.Raw))
|
|
err := toolInstance.Call(context, config)
|
|
// TODO: validate the json schema through github.com/kaptinlin/jsonschema
|
|
if err != nil {
|
|
ctx.OnMCPToolCallError(err)
|
|
return nil
|
|
}
|
|
return nil
|
|
}
|
|
ctx.OnMCPResponseError(errors.New("Unknown tool: invalid_tool_name"), ErrInvalidParams)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
type mcpToolRequestFunc[PluginConfig any] func(context HttpContext, config PluginConfig, toolName string, toolArgs gjson.Result) types.Action
|
|
type mcpToolResponseFunc[PluginConfig any] func(context HttpContext, config PluginConfig, isError bool, content gjson.Result) types.Action
|
|
type jsonRpcErrorFunc[PluginConfig any] func(context HttpContext, config PluginConfig, errorCode int64, errorMessage string) types.Action
|
|
|
|
type mcpToolRequestOption[PluginConfig any] struct {
|
|
f mcpToolRequestFunc[PluginConfig]
|
|
}
|
|
|
|
func OnMCPToolRequest[PluginConfig any](f mcpToolRequestFunc[PluginConfig]) CtxOption[PluginConfig] {
|
|
return &mcpToolRequestOption[PluginConfig]{f}
|
|
}
|
|
|
|
func (o *mcpToolRequestOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
|
|
ctx.isJsonRpcSever = true
|
|
ctx.onMcpToolRequest = o.f
|
|
}
|
|
|
|
type mcpToolResponseOption[PluginConfig any] struct {
|
|
f mcpToolResponseFunc[PluginConfig]
|
|
}
|
|
|
|
func OnMCPToolResponse[PluginConfig any](f mcpToolResponseFunc[PluginConfig]) CtxOption[PluginConfig] {
|
|
return &mcpToolResponseOption[PluginConfig]{f}
|
|
}
|
|
|
|
func (o *mcpToolResponseOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
|
|
ctx.isJsonRpcSever = true
|
|
ctx.onMcpToolResponse = o.f
|
|
}
|
|
|
|
type jsonRpcErrorOption[PluginConfig any] struct {
|
|
f jsonRpcErrorFunc[PluginConfig]
|
|
}
|
|
|
|
func OnJsonRpcError[PluginConfig any](f jsonRpcErrorFunc[PluginConfig]) CtxOption[PluginConfig] {
|
|
return &jsonRpcErrorOption[PluginConfig]{f}
|
|
}
|
|
|
|
func (o *jsonRpcErrorOption[PluginConfig]) Apply(ctx *CommonVmCtx[PluginConfig]) {
|
|
ctx.isJsonRpcSever = true
|
|
ctx.onJsonRpcError = o.f
|
|
}
|
|
|
|
func (ctx *CommonHttpCtx[PluginConfig]) registerMCPToolProcessor() {
|
|
if !ctx.plugin.vm.isJsonRpcSever {
|
|
return
|
|
}
|
|
if ctx.plugin.vm.handleJsonRpcMethod {
|
|
return
|
|
}
|
|
if ctx.plugin.vm.onMcpToolRequest != nil {
|
|
ctx.plugin.vm.jsonRpcRequestHandler = func(context HttpContext, config PluginConfig, id int64, method string, params gjson.Result) types.Action {
|
|
toolName := params.Get("name").String()
|
|
toolArgs := params.Get("arguments")
|
|
return ctx.plugin.vm.onMcpToolRequest(context, config, toolName, toolArgs)
|
|
}
|
|
}
|
|
if ctx.plugin.vm.onMcpToolResponse != nil {
|
|
ctx.plugin.vm.jsonRpcResponseHandler = func(context HttpContext, config PluginConfig, id int64, result, error gjson.Result) types.Action {
|
|
if result.Exists() {
|
|
isError := result.Get("isError").Bool()
|
|
content := result.Get("content")
|
|
return ctx.plugin.vm.onMcpToolResponse(context, config, isError, content)
|
|
}
|
|
if error.Exists() && ctx.plugin.vm.onJsonRpcError != nil {
|
|
return ctx.plugin.vm.onJsonRpcError(context, config, error.Get("code").Int(), error.Get("message").String())
|
|
}
|
|
return types.ActionContinue
|
|
}
|
|
}
|
|
}
|
|
|
|
func ToInputSchema(v any) map[string]any {
|
|
t := reflect.TypeOf(v)
|
|
if t.Kind() == reflect.Ptr {
|
|
t = t.Elem()
|
|
}
|
|
inputSchema := jsonschema.Reflect(v).Definitions[t.Name()]
|
|
inputSchemaBytes, _ := json.Marshal(inputSchema)
|
|
var result map[string]any
|
|
json.Unmarshal(inputSchemaBytes, &result)
|
|
return result
|
|
}
|