mirror of
https://github.com/alibaba/higress.git
synced 2026-02-06 23:21:08 +08:00
96 lines
2.6 KiB
Cheetah
96 lines
2.6 KiB
Cheetah
import asyncio
|
|
from typing import Any
|
|
import os
|
|
import sys
|
|
|
|
from agentscope.agent import ReActAgent
|
|
from agentscope.memory import InMemoryMemory
|
|
from agentscope.message import Msg
|
|
from agentscope.pipeline._functional import stream_printing_messages
|
|
from agentscope.agent import ReActAgent
|
|
from agentscope.model import DashScopeChatModel
|
|
from agentscope.formatter import DashScopeChatFormatter
|
|
|
|
from agentrun.integration.agentscope import model, sandbox_toolset, toolset
|
|
from agentrun.sandbox import TemplateType
|
|
from agentrun.server import AgentRequest, AgentRunServer
|
|
from agentrun.utils.log import logger
|
|
|
|
from agent import Agent
|
|
from toolkit import toolkit, init_toolkit_sync
|
|
|
|
# Put the "python" directory that sits next to this file at the front of the
# module search path.
# NOTE(review): this runs after `from agent import Agent` and
# `from toolkit import ...` above, so it cannot influence how those two
# modules are resolved — confirm that ordering is intended.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "python"))

# Chat model name; the literal below looks like a template placeholder that
# is substituted before this file is run.
MODEL_NAME = "{{ .ChatModel }}"
# Optional sandbox template name for the code-interpreter tools, taken from
# the AGENTRUN_SANDBOX_NAME environment variable (None when unset).
SANDBOX_NAME = os.getenv("AGENTRUN_SANDBOX_NAME")

# Abort module import when no model name was provided.
if not MODEL_NAME:
    raise ValueError("请将 MODEL_NAME 替换为您已经创建的模型名称")

# Skip the sandbox tools when the variable is missing/empty or still holds an
# unreplaced "<...>"-style placeholder value.
# NOTE(review): `code_interpreter_tools` is not referenced anywhere else in
# this file (the agent is built with `toolkit` instead) — confirm whether it
# is meant to be wired into the agent.
if not SANDBOX_NAME or SANDBOX_NAME.startswith("<"):
    code_interpreter_tools = []
    logger.warning("SANDBOX_NAME 未设置或未替换,跳过加载沙箱工具。")
else:
    code_interpreter_tools = sandbox_toolset(
        template_name=SANDBOX_NAME,
        template_type=TemplateType.CODE_INTERPRETER,
        sandbox_idle_timeout_seconds=300,
    )
|
|
|
|
def load_sys_prompt(prompt_file_name="prompt.md"):
    """Return the full text of a prompt file located beside this script.

    Args:
        prompt_file_name: File name resolved against the directory containing
            this file. Because ``os.path.join`` is used, an absolute path is
            taken as-is.

    Returns:
        The file's entire contents, decoded as UTF-8.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    with open(os.path.join(base_dir, prompt_file_name), encoding="utf-8") as prompt_file:
        contents = prompt_file.read()
    return contents
|
|
|
|
# Module-level agent instance; invoke_agent passes this same object to
# stream_printing_messages for every request.
# NOTE(review): a single InMemoryMemory() is created here, so conversation
# history appears to be shared across all requests/callers — confirm this is
# intended for a multi-user server.
agent = Agent(
    # Display name; the literal looks like a template placeholder filled in
    # before this file is run.
    name="{{ .AgentName }}",
    # Model object built by agentrun's `model` helper from MODEL_NAME.
    model=model(MODEL_NAME),  # type: ignore
    # System prompt read at import time from prompt.md beside this file.
    sys_prompt=load_sys_prompt(),
    # Tools come from the local `toolkit` module.
    toolkit=toolkit,
    memory=InMemoryMemory(),
    formatter=DashScopeChatFormatter(),
)
|
|
|
|
|
|
async def invoke_agent(request: AgentRequest):
    """Feed the newest request message to the shared agent and stream its text.

    Args:
        request: Incoming AgentRun request. Only the content of the LAST entry
            in ``request.messages`` is forwarded; earlier entries are not
            replayed because the module-level agent keeps its own memory.

    Yields:
        The non-empty text content of each message produced while the agent
        runs on the input.

    Raises:
        ValueError: If ``request.messages`` is empty.
        Exception: Any error from the agent is logged and then re-raised.
    """
    try:
        if not request.messages:
            raise ValueError("request.messages must contain at least one message")

        # Use the most recent message. Chat-completions style clients send the
        # conversation oldest-first, so messages[0] would replay the very first
        # turn instead of the new user input on every subsequent request.
        # NOTE(review): assumes the last entry is the user's turn — confirm
        # against the AgentRequest schema if assistant/tool entries can trail.
        content = request.messages[-1].content
        input_msg = Msg(
            name="user_message",
            content=content,  # type: ignore
            role="user",
        )

        # NOTE(review): if stream_printing_messages yields cumulative snapshots
        # of a streaming message (rather than deltas), each yield below repeats
        # the earlier prefix — confirm against the agentscope version in use.
        async for msg, _ in stream_printing_messages(
            agents=[agent],
            coroutine_task=agent(input_msg),
        ):
            text = msg.get_text_content()
            if text:
                yield text

    except Exception:
        logger.exception("调用出错")
        raise
|
|
|
|
|
|
def main():
    """Run toolkit initialisation, then start an AgentRunServer for invoke_agent."""
    init_toolkit_sync()
    server = AgentRunServer(invoke_agent=invoke_agent)
    server.start()
|
|
|
|
|
|
# Start the server only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()


# Example request (kept as an unused string literal): POST a streaming chat
# request to the locally running server on port 9000.
"""
curl 127.0.0.1:9000/openai/v1/chat/completions -XPOST \
    -H "content-type: application/json" \
    -d '{
        "messages": [{"role": "user", "content": "写一段代码,查询现在是几点?"}],
        "stream":true
    }'
"""
|