add 流式输出

This commit is contained in:
hc 2026-04-10 19:45:20 +08:00
parent 54191f8458
commit 1d66019529
2 changed files with 107 additions and 11 deletions

View File

@ -4,8 +4,9 @@ import json
import os import os
import tomllib import tomllib
from dataclasses import dataclass from dataclasses import dataclass
from json import JSONDecodeError
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, Iterator
from anthropic import Anthropic from anthropic import Anthropic
from openai import OpenAI from openai import OpenAI
@ -62,17 +63,37 @@ class Agent:
self.client = self._build_client() self.client = self._build_client()
def reply(self, user_input: str) -> str: def reply(self, user_input: str) -> str:
parts: list[str] = []
for event in self.stream_reply(user_input):
if event["type"] == "text":
parts.append(event["content"])
elif event["type"] == "error":
raise RuntimeError(event["message"])
return "".join(parts).strip() or "(empty response)"
def stream_reply(self, user_input: str) -> Iterator[dict[str, Any]]:
self.history.append({"role": "user", "content": user_input}) self.history.append({"role": "user", "content": user_input})
for _ in range(self.config.max_turns): for _ in range(self.config.max_turns):
result = self._call_model() try:
if self.config.provider == "openai":
result = yield from self._stream_openai()
else:
result = yield from self._stream_anthropic()
except Exception as exc:
yield {"type": "error", "message": str(exc)}
return
self.history.append(result["assistant"]) self.history.append(result["assistant"])
if not result["tool_calls"]: if not result["tool_calls"]:
return result["text"].strip() or "(empty response)" yield {"type": "done"}
return
for call in result["tool_calls"]: for call in result["tool_calls"]:
yield {"type": "tool_call", "name": call["name"], "input": call["input"]}
tool_output = self._run_tool(call["name"], call["input"]) tool_output = self._run_tool(call["name"], call["input"])
yield {"type": "tool_result", "name": call["name"], "output": tool_output}
self.history.append( self.history.append(
{ {
"role": "tool", "role": "tool",
@ -82,7 +103,7 @@ class Agent:
} }
) )
return "已达到最大工具循环轮数,停止执行。" yield {"type": "error", "message": "已达到最大工具循环轮数,停止执行。"}
def _build_client(self) -> Any: def _build_client(self) -> Any:
if self.config.provider == "openai": if self.config.provider == "openai":
@ -110,11 +131,6 @@ class Agent:
return "\n\n".join(part.strip() for part in parts if part.strip()) return "\n\n".join(part.strip() for part in parts if part.strip())
def _call_model(self) -> dict[str, Any]:
if self.config.provider == "openai":
return self._call_openai()
return self._call_anthropic()
def _call_openai(self) -> dict[str, Any]: def _call_openai(self) -> dict[str, Any]:
response = self.client.chat.completions.create( response = self.client.chat.completions.create(
model=self.config.model, model=self.config.model,
@ -160,6 +176,51 @@ class Agent:
assistant = {"role": "assistant", "content": content_blocks, "tool_calls": tool_calls} assistant = {"role": "assistant", "content": content_blocks, "tool_calls": tool_calls}
return {"assistant": assistant, "tool_calls": tool_calls, "text": "\n".join(text_parts)} return {"assistant": assistant, "tool_calls": tool_calls, "text": "\n".join(text_parts)}
def _stream_openai(self) -> Iterator[dict[str, Any]]:
stream = self.client.chat.completions.create(
model=self.config.model,
messages=self._openai_messages(),
tools=self._openai_tools(),
tool_choice="auto",
stream=True,
)
text_parts: list[str] = []
tool_call_parts: dict[int, dict[str, Any]] = {}
for chunk in stream:
choice = chunk.choices[0] if chunk.choices else None
if choice is None:
continue
delta = choice.delta
if delta.content:
text_parts.append(delta.content)
yield {"type": "text", "content": delta.content}
for tool_delta in delta.tool_calls or []:
slot = tool_call_parts.setdefault(
tool_delta.index,
{"id": "", "name": "", "arguments": ""},
)
if tool_delta.id:
slot["id"] = tool_delta.id
if tool_delta.function and tool_delta.function.name:
slot["name"] = tool_delta.function.name
if tool_delta.function and tool_delta.function.arguments:
slot["arguments"] += tool_delta.function.arguments
tool_calls = [self._finalize_openai_tool_call(part) for _, part in sorted(tool_call_parts.items())]
text = "".join(text_parts)
assistant = {"role": "assistant", "content": text, "tool_calls": tool_calls}
return {"assistant": assistant, "tool_calls": tool_calls, "text": text}
def _stream_anthropic(self) -> Iterator[dict[str, Any]]:
result = self._call_anthropic()
if result["text"]:
yield {"type": "text", "content": result["text"]}
return result
def _run_tool(self, name: str, payload: dict[str, Any]) -> str: def _run_tool(self, name: str, payload: dict[str, Any]) -> str:
tool = self.tools.get(name) tool = self.tools.get(name)
if not tool: if not tool:
@ -248,6 +309,18 @@ class Agent:
) )
return messages return messages
def _finalize_openai_tool_call(self, part: dict[str, Any]) -> dict[str, Any]:
arguments = part.get("arguments", "") or "{}"
try:
payload = json.loads(arguments)
except JSONDecodeError as exc:
raise ValueError(f"工具参数解析失败: {exc}") from exc
return {
"id": part.get("id") or f"call_{part.get('name', 'tool')}",
"name": part.get("name") or "",
"input": payload,
}
def _load_file_config(path: Path) -> dict[str, Any]: def _load_file_config(path: Path) -> dict[str, Any]:
if not path.exists(): if not path.exists():

View File

@ -13,6 +13,29 @@ app = typer.Typer(add_completion=False, no_args_is_help=False)
console = Console() console = Console()
def render_stream(agent: Agent, user_input: str) -> None:
printed_text = False
for event in agent.stream_reply(user_input):
if event["type"] == "text":
console.print(event["content"], end="")
printed_text = True
elif event["type"] == "tool_call":
if printed_text:
console.print()
printed_text = False
console.print(f"[cyan]->[/cyan] {event['name']}({event['input']})")
elif event["type"] == "tool_result":
console.print(f"[green]✓[/green] {event['name']} done")
elif event["type"] == "error":
if printed_text:
console.print()
console.print(f"[red]error:[/red] {event['message']}")
return
elif event["type"] == "done":
console.print()
return
@app.command() @app.command()
def run( def run(
prompt: Optional[str] = typer.Argument(None, help="单次执行的用户输入"), prompt: Optional[str] = typer.Argument(None, help="单次执行的用户输入"),
@ -37,7 +60,7 @@ def run(
agent = Agent(config=config, tools=build_default_tools(root), workspace=root) agent = Agent(config=config, tools=build_default_tools(root), workspace=root)
if prompt: if prompt:
console.print(agent.reply(prompt)) render_stream(agent, prompt)
return return
console.print("[bold cyan]cc-slim[/bold cyan] REPL输入 exit 或 quit 退出。") console.print("[bold cyan]cc-slim[/bold cyan] REPL输入 exit 或 quit 退出。")
@ -54,7 +77,7 @@ def run(
continue continue
try: try:
console.print(agent.reply(user_input)) render_stream(agent, user_input)
except Exception as exc: # pragma: no cover except Exception as exc: # pragma: no cover
console.print(f"[red]error:[/red] {exc}") console.print(f"[red]error:[/red] {exc}")