add 流式输出

This commit is contained in:
hc 2026-04-10 19:45:20 +08:00
parent 54191f8458
commit 1d66019529
2 changed files with 107 additions and 11 deletions

View File

@ -4,8 +4,9 @@ import json
import os
import tomllib
from dataclasses import dataclass
from json import JSONDecodeError
from pathlib import Path
from typing import Any
from typing import Any, Iterator
from anthropic import Anthropic
from openai import OpenAI
@ -62,17 +63,37 @@ class Agent:
self.client = self._build_client()
def reply(self, user_input: str) -> str:
parts: list[str] = []
for event in self.stream_reply(user_input):
if event["type"] == "text":
parts.append(event["content"])
elif event["type"] == "error":
raise RuntimeError(event["message"])
return "".join(parts).strip() or "(empty response)"
def stream_reply(self, user_input: str) -> Iterator[dict[str, Any]]:
self.history.append({"role": "user", "content": user_input})
for _ in range(self.config.max_turns):
result = self._call_model()
try:
if self.config.provider == "openai":
result = yield from self._stream_openai()
else:
result = yield from self._stream_anthropic()
except Exception as exc:
yield {"type": "error", "message": str(exc)}
return
self.history.append(result["assistant"])
if not result["tool_calls"]:
return result["text"].strip() or "(empty response)"
yield {"type": "done"}
return
for call in result["tool_calls"]:
yield {"type": "tool_call", "name": call["name"], "input": call["input"]}
tool_output = self._run_tool(call["name"], call["input"])
yield {"type": "tool_result", "name": call["name"], "output": tool_output}
self.history.append(
{
"role": "tool",
@ -82,7 +103,7 @@ class Agent:
}
)
return "已达到最大工具循环轮数,停止执行。"
yield {"type": "error", "message": "已达到最大工具循环轮数,停止执行。"}
def _build_client(self) -> Any:
if self.config.provider == "openai":
@ -110,11 +131,6 @@ class Agent:
return "\n\n".join(part.strip() for part in parts if part.strip())
def _call_model(self) -> dict[str, Any]:
if self.config.provider == "openai":
return self._call_openai()
return self._call_anthropic()
def _call_openai(self) -> dict[str, Any]:
response = self.client.chat.completions.create(
model=self.config.model,
@ -160,6 +176,51 @@ class Agent:
assistant = {"role": "assistant", "content": content_blocks, "tool_calls": tool_calls}
return {"assistant": assistant, "tool_calls": tool_calls, "text": "\n".join(text_parts)}
def _stream_openai(self) -> Iterator[dict[str, Any]]:
stream = self.client.chat.completions.create(
model=self.config.model,
messages=self._openai_messages(),
tools=self._openai_tools(),
tool_choice="auto",
stream=True,
)
text_parts: list[str] = []
tool_call_parts: dict[int, dict[str, Any]] = {}
for chunk in stream:
choice = chunk.choices[0] if chunk.choices else None
if choice is None:
continue
delta = choice.delta
if delta.content:
text_parts.append(delta.content)
yield {"type": "text", "content": delta.content}
for tool_delta in delta.tool_calls or []:
slot = tool_call_parts.setdefault(
tool_delta.index,
{"id": "", "name": "", "arguments": ""},
)
if tool_delta.id:
slot["id"] = tool_delta.id
if tool_delta.function and tool_delta.function.name:
slot["name"] = tool_delta.function.name
if tool_delta.function and tool_delta.function.arguments:
slot["arguments"] += tool_delta.function.arguments
tool_calls = [self._finalize_openai_tool_call(part) for _, part in sorted(tool_call_parts.items())]
text = "".join(text_parts)
assistant = {"role": "assistant", "content": text, "tool_calls": tool_calls}
return {"assistant": assistant, "tool_calls": tool_calls, "text": text}
def _stream_anthropic(self) -> Iterator[dict[str, Any]]:
result = self._call_anthropic()
if result["text"]:
yield {"type": "text", "content": result["text"]}
return result
def _run_tool(self, name: str, payload: dict[str, Any]) -> str:
tool = self.tools.get(name)
if not tool:
@ -248,6 +309,18 @@ class Agent:
)
return messages
def _finalize_openai_tool_call(self, part: dict[str, Any]) -> dict[str, Any]:
arguments = part.get("arguments", "") or "{}"
try:
payload = json.loads(arguments)
except JSONDecodeError as exc:
raise ValueError(f"工具参数解析失败: {exc}") from exc
return {
"id": part.get("id") or f"call_{part.get('name', 'tool')}",
"name": part.get("name") or "",
"input": payload,
}
def _load_file_config(path: Path) -> dict[str, Any]:
if not path.exists():

View File

@ -13,6 +13,29 @@ app = typer.Typer(add_completion=False, no_args_is_help=False)
console = Console()
def render_stream(agent: Agent, user_input: str) -> None:
printed_text = False
for event in agent.stream_reply(user_input):
if event["type"] == "text":
console.print(event["content"], end="")
printed_text = True
elif event["type"] == "tool_call":
if printed_text:
console.print()
printed_text = False
console.print(f"[cyan]->[/cyan] {event['name']}({event['input']})")
elif event["type"] == "tool_result":
console.print(f"[green]✓[/green] {event['name']} done")
elif event["type"] == "error":
if printed_text:
console.print()
console.print(f"[red]error:[/red] {event['message']}")
return
elif event["type"] == "done":
console.print()
return
@app.command()
def run(
prompt: Optional[str] = typer.Argument(None, help="单次执行的用户输入"),
@ -37,7 +60,7 @@ def run(
agent = Agent(config=config, tools=build_default_tools(root), workspace=root)
if prompt:
console.print(agent.reply(prompt))
render_stream(agent, prompt)
return
console.print("[bold cyan]cc-slim[/bold cyan] REPL输入 exit 或 quit 退出。")
@ -54,7 +77,7 @@ def run(
continue
try:
console.print(agent.reply(user_input))
render_stream(agent, user_input)
except Exception as exc: # pragma: no cover
console.print(f"[red]error:[/red] {exc}")