diff --git a/src/cc_slim/engine.py b/src/cc_slim/engine.py index da55fdd..76427a1 100644 --- a/src/cc_slim/engine.py +++ b/src/cc_slim/engine.py @@ -4,8 +4,9 @@ import json import os import tomllib from dataclasses import dataclass +from json import JSONDecodeError from pathlib import Path -from typing import Any +from typing import Any, Iterator from anthropic import Anthropic from openai import OpenAI @@ -62,17 +63,37 @@ class Agent: self.client = self._build_client() def reply(self, user_input: str) -> str: + parts: list[str] = [] + for event in self.stream_reply(user_input): + if event["type"] == "text": + parts.append(event["content"]) + elif event["type"] == "error": + raise RuntimeError(event["message"]) + return "".join(parts).strip() or "(empty response)" + + def stream_reply(self, user_input: str) -> Iterator[dict[str, Any]]: self.history.append({"role": "user", "content": user_input}) for _ in range(self.config.max_turns): - result = self._call_model() + try: + if self.config.provider == "openai": + result = yield from self._stream_openai() + else: + result = yield from self._stream_anthropic() + except Exception as exc: + yield {"type": "error", "message": str(exc)} + return + self.history.append(result["assistant"]) if not result["tool_calls"]: - return result["text"].strip() or "(empty response)" + yield {"type": "done"} + return for call in result["tool_calls"]: + yield {"type": "tool_call", "name": call["name"], "input": call["input"]} tool_output = self._run_tool(call["name"], call["input"]) + yield {"type": "tool_result", "name": call["name"], "output": tool_output} self.history.append( { "role": "tool", @@ -82,7 +103,7 @@ class Agent: } ) - return "已达到最大工具循环轮数,停止执行。" + yield {"type": "error", "message": "已达到最大工具循环轮数,停止执行。"} def _build_client(self) -> Any: if self.config.provider == "openai": @@ -110,11 +131,6 @@ class Agent: return "\n\n".join(part.strip() for part in parts if part.strip()) - def _call_model(self) -> dict[str, Any]: - if self.config.provider == "openai": - return self._call_openai() - return self._call_anthropic() - def _call_openai(self) -> dict[str, Any]: response = self.client.chat.completions.create( model=self.config.model, @@ -160,6 +176,51 @@ class Agent: assistant = {"role": "assistant", "content": content_blocks, "tool_calls": tool_calls} return {"assistant": assistant, "tool_calls": tool_calls, "text": "\n".join(text_parts)} + def _stream_openai(self) -> Iterator[dict[str, Any]]: + stream = self.client.chat.completions.create( + model=self.config.model, + messages=self._openai_messages(), + tools=self._openai_tools(), + tool_choice="auto", + stream=True, + ) + + text_parts: list[str] = [] + tool_call_parts: dict[int, dict[str, Any]] = {} + + for chunk in stream: + choice = chunk.choices[0] if chunk.choices else None + if choice is None: + continue + + delta = choice.delta + if delta.content: + text_parts.append(delta.content) + yield {"type": "text", "content": delta.content} + + for tool_delta in delta.tool_calls or []: + slot = tool_call_parts.setdefault( + tool_delta.index, + {"id": "", "name": "", "arguments": ""}, + ) + if tool_delta.id: + slot["id"] = tool_delta.id + if tool_delta.function and tool_delta.function.name: + slot["name"] = tool_delta.function.name + if tool_delta.function and tool_delta.function.arguments: + slot["arguments"] += tool_delta.function.arguments + + tool_calls = [self._finalize_openai_tool_call(part) for _, part in sorted(tool_call_parts.items())] + text = "".join(text_parts) + assistant = {"role": "assistant", "content": text, "tool_calls": tool_calls} + return {"assistant": assistant, "tool_calls": tool_calls, "text": text} + + def _stream_anthropic(self) -> Iterator[dict[str, Any]]: + result = self._call_anthropic() + if result["text"]: + yield {"type": "text", "content": result["text"]} + return result + def _run_tool(self, name: str, payload: dict[str, Any]) -> str: tool = self.tools.get(name) if not tool: @@ -248,6 +309,18 @@ class Agent: ) return messages + def _finalize_openai_tool_call(self, part: dict[str, Any]) -> dict[str, Any]: + arguments = part.get("arguments", "") or "{}" + try: + payload = json.loads(arguments) + except JSONDecodeError as exc: + raise ValueError(f"工具参数解析失败: {exc}") from exc + return { + "id": part.get("id") or f"call_{part.get('name', 'tool')}", + "name": part.get("name") or "", + "input": payload, + } + def _load_file_config(path: Path) -> dict[str, Any]: if not path.exists(): diff --git a/src/cc_slim/main.py b/src/cc_slim/main.py index f92cf2c..d83adcf 100644 --- a/src/cc_slim/main.py +++ b/src/cc_slim/main.py @@ -13,6 +13,29 @@ app = typer.Typer(add_completion=False, no_args_is_help=False) console = Console() +def render_stream(agent: Agent, user_input: str) -> None: + printed_text = False + for event in agent.stream_reply(user_input): + if event["type"] == "text": + console.print(event["content"], end="") + printed_text = True + elif event["type"] == "tool_call": + if printed_text: + console.print() + printed_text = False + console.print(f"[cyan]->[/cyan] {event['name']}({event['input']})") + elif event["type"] == "tool_result": + console.print(f"[green]✓[/green] {event['name']} done") + elif event["type"] == "error": + if printed_text: + console.print() + console.print(f"[red]error:[/red] {event['message']}") + return + elif event["type"] == "done": + console.print() + return + + @app.command() def run( prompt: Optional[str] = typer.Argument(None, help="单次执行的用户输入"), @@ -37,7 +60,7 @@ def run( agent = Agent(config=config, tools=build_default_tools(root), workspace=root) if prompt: - console.print(agent.reply(prompt)) + render_stream(agent, prompt) return console.print("[bold cyan]cc-slim[/bold cyan] REPL,输入 exit 或 quit 退出。") @@ -54,7 +77,7 @@ def run( continue try: - console.print(agent.reply(user_input)) + render_stream(agent, user_input) except Exception as exc: # pragma: no cover console.print(f"[red]error:[/red] {exc}")