add 流式输出
This commit is contained in:
parent
54191f8458
commit
1d66019529
@ -4,8 +4,9 @@ import json
|
||||
import os
|
||||
import tomllib
|
||||
from dataclasses import dataclass
|
||||
from json import JSONDecodeError
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Iterator
|
||||
|
||||
from anthropic import Anthropic
|
||||
from openai import OpenAI
|
||||
@ -62,17 +63,37 @@ class Agent:
|
||||
self.client = self._build_client()
|
||||
|
||||
def reply(self, user_input: str) -> str:
|
||||
parts: list[str] = []
|
||||
for event in self.stream_reply(user_input):
|
||||
if event["type"] == "text":
|
||||
parts.append(event["content"])
|
||||
elif event["type"] == "error":
|
||||
raise RuntimeError(event["message"])
|
||||
return "".join(parts).strip() or "(empty response)"
|
||||
|
||||
def stream_reply(self, user_input: str) -> Iterator[dict[str, Any]]:
|
||||
self.history.append({"role": "user", "content": user_input})
|
||||
|
||||
for _ in range(self.config.max_turns):
|
||||
result = self._call_model()
|
||||
try:
|
||||
if self.config.provider == "openai":
|
||||
result = yield from self._stream_openai()
|
||||
else:
|
||||
result = yield from self._stream_anthropic()
|
||||
except Exception as exc:
|
||||
yield {"type": "error", "message": str(exc)}
|
||||
return
|
||||
|
||||
self.history.append(result["assistant"])
|
||||
|
||||
if not result["tool_calls"]:
|
||||
return result["text"].strip() or "(empty response)"
|
||||
yield {"type": "done"}
|
||||
return
|
||||
|
||||
for call in result["tool_calls"]:
|
||||
yield {"type": "tool_call", "name": call["name"], "input": call["input"]}
|
||||
tool_output = self._run_tool(call["name"], call["input"])
|
||||
yield {"type": "tool_result", "name": call["name"], "output": tool_output}
|
||||
self.history.append(
|
||||
{
|
||||
"role": "tool",
|
||||
@ -82,7 +103,7 @@ class Agent:
|
||||
}
|
||||
)
|
||||
|
||||
return "已达到最大工具循环轮数,停止执行。"
|
||||
yield {"type": "error", "message": "已达到最大工具循环轮数,停止执行。"}
|
||||
|
||||
def _build_client(self) -> Any:
|
||||
if self.config.provider == "openai":
|
||||
@ -110,11 +131,6 @@ class Agent:
|
||||
|
||||
return "\n\n".join(part.strip() for part in parts if part.strip())
|
||||
|
||||
def _call_model(self) -> dict[str, Any]:
|
||||
if self.config.provider == "openai":
|
||||
return self._call_openai()
|
||||
return self._call_anthropic()
|
||||
|
||||
def _call_openai(self) -> dict[str, Any]:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.config.model,
|
||||
@ -160,6 +176,51 @@ class Agent:
|
||||
assistant = {"role": "assistant", "content": content_blocks, "tool_calls": tool_calls}
|
||||
return {"assistant": assistant, "tool_calls": tool_calls, "text": "\n".join(text_parts)}
|
||||
|
||||
def _stream_openai(self) -> Iterator[dict[str, Any]]:
|
||||
stream = self.client.chat.completions.create(
|
||||
model=self.config.model,
|
||||
messages=self._openai_messages(),
|
||||
tools=self._openai_tools(),
|
||||
tool_choice="auto",
|
||||
stream=True,
|
||||
)
|
||||
|
||||
text_parts: list[str] = []
|
||||
tool_call_parts: dict[int, dict[str, Any]] = {}
|
||||
|
||||
for chunk in stream:
|
||||
choice = chunk.choices[0] if chunk.choices else None
|
||||
if choice is None:
|
||||
continue
|
||||
|
||||
delta = choice.delta
|
||||
if delta.content:
|
||||
text_parts.append(delta.content)
|
||||
yield {"type": "text", "content": delta.content}
|
||||
|
||||
for tool_delta in delta.tool_calls or []:
|
||||
slot = tool_call_parts.setdefault(
|
||||
tool_delta.index,
|
||||
{"id": "", "name": "", "arguments": ""},
|
||||
)
|
||||
if tool_delta.id:
|
||||
slot["id"] = tool_delta.id
|
||||
if tool_delta.function and tool_delta.function.name:
|
||||
slot["name"] = tool_delta.function.name
|
||||
if tool_delta.function and tool_delta.function.arguments:
|
||||
slot["arguments"] += tool_delta.function.arguments
|
||||
|
||||
tool_calls = [self._finalize_openai_tool_call(part) for _, part in sorted(tool_call_parts.items())]
|
||||
text = "".join(text_parts)
|
||||
assistant = {"role": "assistant", "content": text, "tool_calls": tool_calls}
|
||||
return {"assistant": assistant, "tool_calls": tool_calls, "text": text}
|
||||
|
||||
def _stream_anthropic(self) -> Iterator[dict[str, Any]]:
|
||||
result = self._call_anthropic()
|
||||
if result["text"]:
|
||||
yield {"type": "text", "content": result["text"]}
|
||||
return result
|
||||
|
||||
def _run_tool(self, name: str, payload: dict[str, Any]) -> str:
|
||||
tool = self.tools.get(name)
|
||||
if not tool:
|
||||
@ -248,6 +309,18 @@ class Agent:
|
||||
)
|
||||
return messages
|
||||
|
||||
def _finalize_openai_tool_call(self, part: dict[str, Any]) -> dict[str, Any]:
|
||||
arguments = part.get("arguments", "") or "{}"
|
||||
try:
|
||||
payload = json.loads(arguments)
|
||||
except JSONDecodeError as exc:
|
||||
raise ValueError(f"工具参数解析失败: {exc}") from exc
|
||||
return {
|
||||
"id": part.get("id") or f"call_{part.get('name', 'tool')}",
|
||||
"name": part.get("name") or "",
|
||||
"input": payload,
|
||||
}
|
||||
|
||||
|
||||
def _load_file_config(path: Path) -> dict[str, Any]:
|
||||
if not path.exists():
|
||||
|
||||
@ -13,6 +13,29 @@ app = typer.Typer(add_completion=False, no_args_is_help=False)
|
||||
console = Console()
|
||||
|
||||
|
||||
def render_stream(agent: Agent, user_input: str) -> None:
|
||||
printed_text = False
|
||||
for event in agent.stream_reply(user_input):
|
||||
if event["type"] == "text":
|
||||
console.print(event["content"], end="")
|
||||
printed_text = True
|
||||
elif event["type"] == "tool_call":
|
||||
if printed_text:
|
||||
console.print()
|
||||
printed_text = False
|
||||
console.print(f"[cyan]->[/cyan] {event['name']}({event['input']})")
|
||||
elif event["type"] == "tool_result":
|
||||
console.print(f"[green]✓[/green] {event['name']} done")
|
||||
elif event["type"] == "error":
|
||||
if printed_text:
|
||||
console.print()
|
||||
console.print(f"[red]error:[/red] {event['message']}")
|
||||
return
|
||||
elif event["type"] == "done":
|
||||
console.print()
|
||||
return
|
||||
|
||||
|
||||
@app.command()
|
||||
def run(
|
||||
prompt: Optional[str] = typer.Argument(None, help="单次执行的用户输入"),
|
||||
@ -37,7 +60,7 @@ def run(
|
||||
agent = Agent(config=config, tools=build_default_tools(root), workspace=root)
|
||||
|
||||
if prompt:
|
||||
console.print(agent.reply(prompt))
|
||||
render_stream(agent, prompt)
|
||||
return
|
||||
|
||||
console.print("[bold cyan]cc-slim[/bold cyan] REPL,输入 exit 或 quit 退出。")
|
||||
@ -54,7 +77,7 @@ def run(
|
||||
continue
|
||||
|
||||
try:
|
||||
console.print(agent.reply(user_input))
|
||||
render_stream(agent, user_input)
|
||||
except Exception as exc: # pragma: no cover
|
||||
console.print(f"[red]error:[/red] {exc}")
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user