cc-slim/src/cc_slim/engine.py

390 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import json
import os
import platform
import sys
import tomllib
from dataclasses import dataclass
from json import JSONDecodeError
from pathlib import Path
from typing import Any, Iterator
from anthropic import Anthropic
from openai import OpenAI
from cc_slim.memory import MemoryStore
from cc_slim.session import SessionStore
from cc_slim.tools import Tool
@dataclass
class Config:
    """Fully resolved settings for one agent session.

    Instances are normally produced by ``resolve_config`` rather than built
    by hand.
    """

    # Backend selector; the agent compares it against "openai" / "anthropic".
    provider: str
    # Model identifier forwarded to the provider SDK on every request.
    model: str
    # Credential passed to the SDK client constructor.
    api_key: str
    # Optional endpoint override; ``None`` keeps the SDK default.
    base_url: str | None
    # Upper bound on model/tool round trips for a single user input.
    max_turns: int = 12
def resolve_config(workspace: Path, cli: dict[str, Any]) -> Config:
    """Merge CLI options, environment variables and ``.cc-slim.toml`` into a Config.

    For every setting the first usable value wins, in this order: CLI option,
    ``CC_SLIM_*`` environment variable, config file entry, built-in default.
    The API key additionally falls back to the provider-specific
    ``OPENAI_API_KEY`` / ``ANTHROPIC_API_KEY`` variable before the config file.

    Args:
        workspace: Directory that may contain a ``.cc-slim.toml`` file.
        cli: Parsed command-line options; missing or ``None`` entries are skipped.

    Returns:
        The resolved configuration with string fields stripped and the
        provider lower-cased.

    Raises:
        ValueError: If no API key could be found, or ``max_turns`` is not
            convertible to ``int``.
    """
    file_cfg = _load_file_config(workspace / ".cc-slim.toml")

    # Normalise the provider *before* it drives any provider-dependent default.
    # Otherwise a value such as "OpenAI" or " openai " would select the
    # Anthropic default model and API-key variable while the returned Config
    # still reports "openai".
    provider = str(
        _pick(
            cli.get("provider"),
            os.getenv("CC_SLIM_PROVIDER"),
            file_cfg.get("provider"),
            "openai",
        )
    ).strip().lower()
    is_openai = provider == "openai"

    model = _pick(
        cli.get("model"),
        os.getenv("CC_SLIM_MODEL"),
        file_cfg.get("model"),
        "gpt-4.1-mini" if is_openai else "claude-3-5-haiku-latest",
    )
    # _pick returns None when every candidate is missing or blank, which the
    # check below treats as "no key".
    api_key = _pick(
        cli.get("api_key"),
        os.getenv("CC_SLIM_API_KEY"),
        os.getenv("OPENAI_API_KEY" if is_openai else "ANTHROPIC_API_KEY"),
        file_cfg.get("api_key"),
    )
    base_url = _pick(
        cli.get("base_url"),
        os.getenv("CC_SLIM_BASE_URL"),
        file_cfg.get("base_url"),
        None,
    )
    max_turns_raw = _pick(
        cli.get("max_turns"),
        os.getenv("CC_SLIM_MAX_TURNS"),
        file_cfg.get("max_turns"),
        12,
    )

    if not api_key:
        raise ValueError("缺少 API key请通过 CLI、环境变量或 .cc-slim.toml 提供。")

    return Config(
        provider=provider,
        model=str(model).strip(),
        api_key=str(api_key).strip(),
        base_url=str(base_url).strip() if base_url else None,
        max_turns=int(max_turns_raw),
    )
class Agent:
    """Tool-using chat agent backed by the OpenAI or the Anthropic SDK.

    The agent keeps a provider-neutral conversation history (``self.history``)
    made of ``user``, ``assistant`` and ``tool`` entries.  Assistant entries
    always carry a normalised ``tool_calls`` list (``id`` / ``name`` /
    ``input``); their ``content`` is a plain string when produced by OpenAI
    and a list of content blocks when produced by Anthropic.  Both message
    builders accept either shape, so a stored history can be replayed against
    the other provider.
    """

    def __init__(
        self,
        config: Config,
        tools: list[Tool],
        workspace: Path,
        session_store: SessionStore | None = None,
        session_id: str | None = None,
        history: list[dict[str, Any]] | None = None,
    ) -> None:
        """Create an agent.

        Args:
            config: Resolved provider/model/credential settings.
            tools: Tools the model may call; indexed by ``tool.name``.
            workspace: Directory used to assemble the system prompt.
            session_store: Optional store that receives every new message.
            session_id: Session key for ``session_store``; persistence only
                happens when both the store and the id are set.
            history: Earlier messages to resume from (copied, not aliased).

        Raises:
            ValueError: If ``config.provider`` is not a supported provider.
        """
        self.config = config
        self.tools = {tool.name: tool for tool in tools}
        self.history: list[dict[str, Any]] = list(history or [])
        self.session_store = session_store
        self.session_id = session_id
        # The system prompt lists the tool names, so self.tools must exist first.
        self.system_prompt = self._build_system_prompt(workspace)
        self.client = self._build_client()

    def reply(self, user_input: str) -> str:
        """Run a full turn and return the concatenated assistant text.

        Raises:
            RuntimeError: If the underlying stream emits an ``error`` event.
        """
        parts: list[str] = []
        for event in self.stream_reply(user_input):
            if event["type"] == "text":
                parts.append(event["content"])
            elif event["type"] == "error":
                raise RuntimeError(event["message"])
        return "".join(parts).strip() or "(empty response)"

    def stream_reply(self, user_input: str) -> Iterator[dict[str, Any]]:
        """Process one user input, yielding events as the turn progresses.

        Event types yielded: ``text`` (``content``), ``tool_call`` (``name``,
        ``input``), ``tool_result`` (``name``, ``output``), ``done`` and
        ``error`` (``message``).  The model is called at most
        ``config.max_turns`` times; each call that requests tools is followed
        by executing them and feeding the results back.
        """
        user_message = {"role": "user", "content": user_input}
        self.history.append(user_message)
        self._save_message(user_message)

        for _ in range(self.config.max_turns):
            try:
                if self.config.provider == "openai":
                    result = yield from self._stream_openai()
                else:
                    result = yield from self._stream_anthropic()
            except Exception as exc:
                # Provider/SDK failures are reported as an event so streaming
                # consumers do not have to wrap the iteration in try/except.
                yield {"type": "error", "message": str(exc)}
                return

            self.history.append(result["assistant"])
            self._save_message(result["assistant"])

            if not result["tool_calls"]:
                yield {"type": "done"}
                return

            for call in result["tool_calls"]:
                yield {"type": "tool_call", "name": call["name"], "input": call["input"]}
                tool_output = self._run_tool(call["name"], call["input"])
                yield {"type": "tool_result", "name": call["name"], "output": tool_output}
                self.history.append(
                    {
                        "role": "tool",
                        "tool_call_id": call["id"],
                        "name": call["name"],
                        "content": tool_output,
                    }
                )
                self._save_message(self.history[-1])

        # Every allowed round ended with more tool calls.
        yield {"type": "error", "message": "已达到最大工具循环轮数,停止执行。"}

    def _build_client(self) -> Any:
        """Instantiate the SDK client matching ``config.provider``.

        Raises:
            ValueError: If the provider is neither ``openai`` nor ``anthropic``.
        """
        provider = self.config.provider
        if provider not in ("openai", "anthropic"):
            raise ValueError(f"不支持的 provider: {self.config.provider}")

        kwargs: dict[str, Any] = {"api_key": self.config.api_key}
        if self.config.base_url:
            kwargs["base_url"] = self.config.base_url
        factory = OpenAI if provider == "openai" else Anthropic
        return factory(**kwargs)

    def _build_system_prompt(self, workspace: Path) -> str:
        """Assemble the system prompt from runtime facts and workspace files.

        Sections, in order: runtime summary, ``AGENTS.md`` (if present), every
        ``SKILLS/*.md`` sorted by file name, and the stored memory (if any).
        Blank sections are dropped and the rest joined by blank lines.
        """
        parts: list[str] = [self._build_runtime_summary(workspace)]

        agents = workspace / "AGENTS.md"
        if agents.exists():
            parts.append(agents.read_text(encoding="utf-8"))

        skills_dir = workspace / "SKILLS"
        if skills_dir.exists():
            for path in sorted(skills_dir.glob("*.md"), key=lambda p: p.name):
                parts.append(path.read_text(encoding="utf-8"))

        memory = MemoryStore(workspace).read()
        if memory:
            parts.append(f"# Memory\n\n{memory}")

        return "\n\n".join(part.strip() for part in parts if part.strip())

    def _build_runtime_summary(self, workspace: Path) -> str:
        """Describe platform, shell, workspace and tool names for the model."""
        tool_names = ", ".join(self.tools.keys()) or "(none)"
        shell_name = self._detect_shell()
        return "\n".join(
            [
                "## 运行环境",
                f"- 平台: {platform.system() or 'Unknown'}",
                f"- sys.platform: {sys.platform}",
                f"- shell: {shell_name}",
                f"- workspace: {workspace}",
                f"- 可用工具: {tool_names}",
                "- 行动时必须以以上运行环境信息为准,不要默认套用 Unix/Linux 命令习惯。",
            ]
        )

    def _detect_shell(self) -> str:
        """Return the shell from ``COMSPEC`` (Windows) or ``SHELL``, else a placeholder."""
        if os.name == "nt":
            return os.getenv("COMSPEC", "Windows shell (likely PowerShell or cmd.exe)")
        return os.getenv("SHELL", "unknown shell")

    def _openai_request_kwargs(self) -> dict[str, Any]:
        """Build the keyword arguments shared by all OpenAI chat requests.

        ``tools`` and ``tool_choice`` are only included when at least one tool
        is registered: the API rejects an empty ``tools`` array and a
        ``tool_choice`` that is sent without tools.
        """
        kwargs: dict[str, Any] = {
            "model": self.config.model,
            "messages": self._openai_messages(),
        }
        tools = self._openai_tools()
        if tools:
            kwargs["tools"] = tools
            kwargs["tool_choice"] = "auto"
        return kwargs

    def _call_openai(self) -> dict[str, Any]:
        """Perform one non-streaming OpenAI call.

        Returns:
            ``{"assistant": <history entry>, "tool_calls": [...], "text": str}``.
        """
        response = self.client.chat.completions.create(**self._openai_request_kwargs())
        message = response.choices[0].message
        text = message.content or ""
        tool_calls = [
            {
                "id": call.id,
                "name": call.function.name,
                "input": json.loads(call.function.arguments or "{}"),
            }
            for call in message.tool_calls or []
        ]
        assistant = {"role": "assistant", "content": text, "tool_calls": tool_calls}
        return {"assistant": assistant, "tool_calls": tool_calls, "text": text}

    def _call_anthropic(self) -> dict[str, Any]:
        """Perform one non-streaming Anthropic call.

        Returns:
            ``{"assistant": <history entry>, "tool_calls": [...], "text": str}``
            where the assistant ``content`` is the list of text / tool_use
            blocks exactly as they must be echoed back on the next request.
        """
        kwargs: dict[str, Any] = {
            "model": self.config.model,
            "system": self.system_prompt,
            "max_tokens": 2048,
            "messages": self._anthropic_messages(),
        }
        tools = self._anthropic_tools()
        if tools:
            # Omit the parameter instead of sending an empty list.
            kwargs["tools"] = tools
        response = self.client.messages.create(**kwargs)

        text_parts: list[str] = []
        tool_calls: list[dict[str, Any]] = []
        content_blocks: list[dict[str, Any]] = []
        for block in response.content:
            if block.type == "text":
                text_parts.append(block.text)
                content_blocks.append({"type": "text", "text": block.text})
            elif block.type == "tool_use":
                payload = dict(block.input)
                tool_calls.append({"id": block.id, "name": block.name, "input": payload})
                content_blocks.append(
                    {"type": "tool_use", "id": block.id, "name": block.name, "input": payload}
                )
        assistant = {"role": "assistant", "content": content_blocks, "tool_calls": tool_calls}
        return {"assistant": assistant, "tool_calls": tool_calls, "text": "\n".join(text_parts)}

    def _stream_openai(self) -> Iterator[dict[str, Any]]:
        """Stream one OpenAI completion.

        Yields ``text`` events as content deltas arrive.  Tool-call deltas are
        accumulated per ``index`` because the id, the name and the JSON
        arguments arrive in fragments.  The generator's return value (received
        via ``yield from``) has the same shape as ``_call_openai``'s result.

        Raises:
            ValueError: If the accumulated tool arguments are not valid JSON.
        """
        stream = self.client.chat.completions.create(
            **self._openai_request_kwargs(),
            stream=True,
        )
        text_parts: list[str] = []
        tool_call_parts: dict[int, dict[str, Any]] = {}
        for chunk in stream:
            choice = chunk.choices[0] if chunk.choices else None
            if choice is None:
                continue
            delta = choice.delta
            if delta.content:
                text_parts.append(delta.content)
                yield {"type": "text", "content": delta.content}
            for tool_delta in delta.tool_calls or []:
                slot = tool_call_parts.setdefault(
                    tool_delta.index,
                    {"id": "", "name": "", "arguments": ""},
                )
                if tool_delta.id:
                    slot["id"] = tool_delta.id
                if tool_delta.function and tool_delta.function.name:
                    slot["name"] = tool_delta.function.name
                if tool_delta.function and tool_delta.function.arguments:
                    slot["arguments"] += tool_delta.function.arguments

        # Sort by index so calls are executed in the order the model issued them.
        tool_calls = [
            self._finalize_openai_tool_call(part)
            for _, part in sorted(tool_call_parts.items())
        ]
        text = "".join(text_parts)
        assistant = {"role": "assistant", "content": text, "tool_calls": tool_calls}
        return {"assistant": assistant, "tool_calls": tool_calls, "text": text}

    def _stream_anthropic(self) -> Iterator[dict[str, Any]]:
        """Generator wrapper around ``_call_anthropic``.

        The request itself is not streamed: the complete text is emitted as a
        single ``text`` event and the call result is the generator's return
        value.
        """
        result = self._call_anthropic()
        if result["text"]:
            yield {"type": "text", "content": result["text"]}
        return result

    def _run_tool(self, name: str, payload: dict[str, Any]) -> str:
        """Execute a tool by name, turning lookup and runtime failures into text.

        Errors are returned rather than raised so they can be fed back to the
        model as the tool result.
        """
        tool = self.tools.get(name)
        if not tool:
            return f"Tool not found: {name}"
        try:
            return tool.execute(payload)
        except Exception as exc:
            return f"Tool error in {name}: {exc}"

    def _openai_tools(self) -> list[dict[str, Any]]:
        """Describe the registered tools in OpenAI function-calling format."""
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.input_schema,
                },
            }
            for tool in self.tools.values()
        ]

    def _anthropic_tools(self) -> list[dict[str, Any]]:
        """Describe the registered tools in Anthropic tool format."""
        return [
            {
                "name": tool.name,
                "description": tool.description,
                "input_schema": tool.input_schema,
            }
            for tool in self.tools.values()
        ]

    def _openai_messages(self) -> list[dict[str, Any]]:
        """Convert the history into OpenAI chat messages (system prompt first)."""
        messages: list[dict[str, Any]] = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        for item in self.history:
            role = item["role"]
            if role == "user":
                messages.append({"role": "user", "content": item["content"]})
            elif role == "assistant":
                content = item.get("content", "")
                if isinstance(content, list):
                    # Entry recorded by the Anthropic path: keep only the text
                    # blocks; the tool_use blocks are re-expressed below through
                    # the normalised "tool_calls" list.
                    content = "\n".join(
                        block.get("text", "")
                        for block in content
                        if isinstance(block, dict) and block.get("type") == "text"
                    )
                payload: dict[str, Any] = {"role": "assistant", "content": content}
                if item.get("tool_calls"):
                    payload["tool_calls"] = [
                        {
                            "id": call["id"],
                            "type": "function",
                            "function": {
                                "name": call["name"],
                                "arguments": json.dumps(call["input"], ensure_ascii=False),
                            },
                        }
                        for call in item["tool_calls"]
                    ]
                messages.append(payload)
            elif role == "tool":
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": item["tool_call_id"],
                        "content": item["content"],
                    }
                )
        return messages

    def _anthropic_messages(self) -> list[dict[str, Any]]:
        """Convert the history into Anthropic messages.

        Tool results become ``user`` messages holding a ``tool_result`` block.
        """
        messages: list[dict[str, Any]] = []
        for item in self.history:
            role = item["role"]
            if role == "user":
                messages.append({"role": "user", "content": item["content"]})
            elif role == "assistant":
                content = item.get("content", "")
                if not isinstance(content, list):
                    calls = item.get("tool_calls") or []
                    if calls:
                        # Entry recorded by the OpenAI path: rebuild the
                        # tool_use blocks so the tool_result that follows
                        # still has a matching tool_use id.
                        blocks: list[dict[str, Any]] = []
                        if content:
                            blocks.append({"type": "text", "text": content})
                        blocks.extend(
                            {
                                "type": "tool_use",
                                "id": call["id"],
                                "name": call["name"],
                                "input": call["input"],
                            }
                            for call in calls
                        )
                        content = blocks
                    else:
                        content = content or ""
                messages.append({"role": "assistant", "content": content})
            elif role == "tool":
                messages.append(
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "tool_result",
                                "tool_use_id": item["tool_call_id"],
                                "content": item["content"],
                            }
                        ],
                    }
                )
        return messages

    def _finalize_openai_tool_call(self, part: dict[str, Any]) -> dict[str, Any]:
        """Turn an accumulated streaming tool call into a normalised call dict.

        Args:
            part: Mapping with the accumulated ``id``, ``name`` and
                ``arguments`` (a JSON string, possibly empty).

        Returns:
            ``{"id": ..., "name": ..., "input": <parsed arguments>}``.  A
            missing id is replaced by ``call_<name>``.

        Raises:
            ValueError: If ``arguments`` is not valid JSON.
        """
        arguments = part.get("arguments", "") or "{}"
        try:
            payload = json.loads(arguments)
        except JSONDecodeError as exc:
            raise ValueError(f"工具参数解析失败: {exc}") from exc
        return {
            "id": part.get("id") or f"call_{part.get('name', 'tool')}",
            "name": part.get("name") or "",
            "input": payload,
        }

    def _save_message(self, message: dict[str, Any]) -> None:
        """Persist ``message`` when both a session store and a session id are set."""
        if self.session_store and self.session_id:
            self.session_store.append_message(self.session_id, message)
def _load_file_config(path: Path) -> dict[str, Any]:
if not path.exists():
return {}
data = tomllib.loads(path.read_text(encoding="utf-8"))
if "cc_slim" in data and isinstance(data["cc_slim"], dict):
return dict(data["cc_slim"])
return {k: v for k, v in data.items() if not isinstance(v, dict)}
def _pick(*values: Any) -> Any:
    """Return the first usable value among ``values``.

    ``None`` and strings that are empty or whitespace-only are skipped; any
    other value (including ``0`` and ``False``) counts as usable.  Returns
    ``None`` when nothing qualifies.
    """
    usable = (
        candidate
        for candidate in values
        if candidate is not None
        and not (isinstance(candidate, str) and not candidate.strip())
    )
    return next(usable, None)