cc-slim/src/cc_slim/engine.py
2026-04-13 18:19:45 +08:00

532 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import json
import os
import platform
import sys
import tomllib
from dataclasses import dataclass
from json import JSONDecodeError
from pathlib import Path
from typing import Any, Iterator
from anthropic import Anthropic
from langfuse.openai import OpenAI
from cc_slim.memory import MemoryStore
from cc_slim.permissions import PermissionChecker
from cc_slim.session import SessionStore
from cc_slim.tools import Tool
@dataclass
class Config:
    """Resolved runtime settings for one agent.

    The first four fields are required; the Langfuse fields and the turn
    limit fall back to the defaults shown below.
    """

    # Backend identifier; the agent branches on "openai" / "anthropic".
    provider: str
    # Model name passed to the backend on every request.
    model: str
    # Credential handed to the backend client.
    api_key: str
    # Optional endpoint override for the backend client.
    base_url: str | None

    # Optional Langfuse settings, exported to the environment by the agent
    # before the OpenAI client is created.
    langfuse_public_key: str | None = None
    langfuse_secret_key: str | None = None
    langfuse_base_url: str | None = None

    # Upper bound on model round-trips within a single reply.
    max_turns: int = 12
def resolve_config(config_root: Path, cli: dict[str, Any]) -> Config:
    """Merge CLI options, environment variables and ``.cc-slim.toml`` into a Config.

    For every setting the first non-blank source wins:

    * provider / model / base_url / max_turns: CLI, then ``CC_SLIM_*``
      environment variable, then config file, then a built-in default.
    * api_key: CLI, ``CC_SLIM_API_KEY``, the vendor variable
      (``OPENAI_API_KEY`` or ``ANTHROPIC_API_KEY``), then config file.
    * Langfuse settings: ``LANGFUSE_*`` environment variable, then the
      same upper-case key in the config file.

    Args:
        config_root: Directory that may contain ``.cc-slim.toml``.
        cli: Mapping of command-line overrides; missing keys are treated
            as "not provided".

    Returns:
        The fully resolved configuration with string values stripped.

    Raises:
        ValueError: If no API key can be found, or if the resolved
            ``max_turns`` cannot be converted to ``int``.
    """
    file_cfg = _load_file_config(config_root / ".cc-slim.toml")

    # Normalise the provider *before* it drives any decision. It selects the
    # default model and the vendor API-key variable, so "OpenAI" or
    # " openai " must already compare equal to "openai" at this point.
    provider = str(
        _pick(
            cli.get("provider"),
            os.getenv("CC_SLIM_PROVIDER"),
            file_cfg.get("provider"),
            "openai",
        )
    ).strip().lower()
    is_openai = provider == "openai"

    model = _pick(
        cli.get("model"),
        os.getenv("CC_SLIM_MODEL"),
        file_cfg.get("model"),
        "gpt-4.1-mini" if is_openai else "claude-3-5-haiku-latest",
    )
    api_key = _pick(
        cli.get("api_key"),
        os.getenv("CC_SLIM_API_KEY"),
        os.getenv("OPENAI_API_KEY" if is_openai else "ANTHROPIC_API_KEY"),
        file_cfg.get("api_key"),
        "",
    )
    base_url = _pick(
        cli.get("base_url"),
        os.getenv("CC_SLIM_BASE_URL"),
        file_cfg.get("base_url"),
        None,
    )
    max_turns_raw = _pick(
        cli.get("max_turns"),
        os.getenv("CC_SLIM_MAX_TURNS"),
        file_cfg.get("max_turns"),
        12,
    )

    def langfuse_setting(key: str) -> Any:
        # Langfuse values have no CLI flag: environment first, then file.
        return _pick(os.getenv(key), file_cfg.get(key), None)

    def stripped_or_none(value: Any) -> str | None:
        # Optional settings collapse to None when absent or falsy.
        return str(value).strip() if value else None

    if not api_key:
        raise ValueError("缺少 API key请通过 CLI、环境变量或 .cc-slim.toml 提供。")

    return Config(
        provider=provider,
        model=str(model).strip(),
        api_key=str(api_key).strip(),
        base_url=stripped_or_none(base_url),
        langfuse_public_key=stripped_or_none(langfuse_setting("LANGFUSE_PUBLIC_KEY")),
        langfuse_secret_key=stripped_or_none(langfuse_setting("LANGFUSE_SECRET_KEY")),
        langfuse_base_url=stripped_or_none(langfuse_setting("LANGFUSE_BASE_URL")),
        max_turns=int(max_turns_raw),
    )
class Agent:
    """Chat agent that lets a model call tools until it produces a final answer.

    The agent keeps a provider-neutral ``history`` list of message dicts
    (roles ``user``, ``assistant`` and ``tool``) and converts it to the
    OpenAI or Anthropic wire format on every request. Messages are also
    forwarded to an optional session store as they are recorded.
    """

    # Completion budget sent with every Anthropic request.
    _ANTHROPIC_MAX_TOKENS = 2048

    def __init__(
        self,
        config: Config,
        tools: list[Tool],
        workspace: Path,
        session_store: SessionStore | None = None,
        session_id: str | None = None,
        history: list[dict[str, Any]] | None = None,
        permission_checker: PermissionChecker | None = None,
        confirm_tool: Any | None = None,
    ) -> None:
        """Create an agent bound to one workspace and one backend client.

        Args:
            config: Resolved settings (provider, model, credentials, limits).
            tools: Tools the model may call; indexed by ``tool.name``.
            workspace: Directory used to build the system prompt.
            session_store: Optional sink that receives every recorded message.
            session_id: Session key used with ``session_store``.
            history: Earlier messages to resume from; copied, not aliased.
            permission_checker: Optional gate consulted before each tool call.
            confirm_tool: Optional callable ``(name, payload) -> truthy`` asked
                when the checker requires confirmation.

        Raises:
            ValueError: If ``config.provider`` is not a supported backend.
        """
        self.config = config
        self.tools = {tool.name: tool for tool in tools}
        self.workspace = workspace
        self.history: list[dict[str, Any]] = list(history or [])
        self.session_store = session_store
        self.session_id = session_id
        self.permission_checker = permission_checker
        self.confirm_tool = confirm_tool
        # The prompt is built after self.tools is set because the runtime
        # summary lists the tool names.
        self.system_prompt = self._build_system_prompt(workspace)
        self.client = self._build_client()

    def reply(self, user_input: str) -> str:
        """Run a full turn and return the concatenated assistant text.

        Raises:
            RuntimeError: If the underlying stream reports an error event.
        """
        parts: list[str] = []
        for event in self.stream_reply(user_input):
            if event["type"] == "text":
                parts.append(event["content"])
            elif event["type"] == "error":
                raise RuntimeError(event["message"])
        return "".join(parts).strip() or "(empty response)"

    def dream(self, memory_store: MemoryStore) -> str:
        """Distil the current session into long-term memory.

        Asks the model for updated memory markdown, hands it to
        ``memory_store.apply_dream`` and rebuilds the system prompt so the
        new memory is used from the next request on.

        Returns:
            A short status message for the user.
        """
        if not self.history:
            return "当前会话没有可整理的内容。"
        dream_markdown = self._run_dream_model(memory_store.read(), self._serialize_history_for_dream())
        memory_store.apply_dream(dream_markdown)
        self.system_prompt = self._build_system_prompt(self.workspace)
        return "已完成 dreammemory 已更新"

    def stream_reply(self, user_input: str) -> Iterator[dict[str, Any]]:
        """Process one user message, yielding events as the turn progresses.

        Event dicts carry a ``type`` of ``text``, ``tool_call``,
        ``tool_result``, ``done`` or ``error``. The loop alternates between
        a model request and the execution of any requested tools, and stops
        after ``config.max_turns`` model requests.
        """
        self._record({"role": "user", "content": user_input})
        for _ in range(self.config.max_turns):
            try:
                if self.config.provider == "openai":
                    result = yield from self._stream_openai()
                else:
                    result = yield from self._stream_anthropic()
            except Exception as exc:
                # Surface backend failures as an event instead of raising
                # through the consumer's iteration.
                yield {"type": "error", "message": str(exc)}
                return
            self._record(result["assistant"])
            if not result["tool_calls"]:
                yield {"type": "done"}
                return
            for call in result["tool_calls"]:
                name, payload = call["name"], call["input"]
                # A denial reason doubles as the tool output so the model
                # learns why the call did not run.
                output = self._check_tool_permission(name, payload)
                if output is None:
                    yield {"type": "tool_call", "name": name, "input": payload}
                    output = self._run_tool(name, payload)
                yield {"type": "tool_result", "name": name, "output": output}
                self._record(
                    {
                        "role": "tool",
                        "tool_call_id": call["id"],
                        "name": name,
                        "content": output,
                    }
                )
        yield {"type": "error", "message": "已达到最大工具循环轮数,停止执行。"}

    def _record(self, message: dict[str, Any]) -> None:
        """Append ``message`` to the history and forward it to the session store."""
        self.history.append(message)
        self._save_message(message)

    def _build_client(self) -> Any:
        """Instantiate the SDK client for the configured provider.

        Raises:
            ValueError: If the provider is neither ``openai`` nor ``anthropic``.
        """
        provider = self.config.provider
        if provider not in ("openai", "anthropic"):
            raise ValueError(f"不支持的 provider: {self.config.provider}")
        kwargs: dict[str, Any] = {"api_key": self.config.api_key}
        if self.config.base_url:
            kwargs["base_url"] = self.config.base_url
        if provider == "openai":
            # OpenAI here is the Langfuse wrapper, which reads its settings
            # from the environment, so export them before constructing it.
            self._configure_langfuse()
            return OpenAI(**kwargs)
        return Anthropic(**kwargs)

    def _build_system_prompt(self, workspace: Path) -> str:
        """Assemble the system prompt from built-in, runtime and workspace parts.

        Order: built-in prompt, runtime summary, ``AGENTS.md``, every
        ``SKILLS/*.md`` file, then stored memory. Blank parts are dropped and
        the rest are joined with blank lines.
        """
        parts: list[str] = [
            self._load_builtin_system_prompt(),
            self._build_runtime_summary(workspace),
        ]
        workspace_agents = self._load_workspace_agents(workspace)
        if workspace_agents:
            parts.append(workspace_agents)
        parts.extend(self._load_workspace_skills(workspace))
        memory_prompt = self._load_memory_prompt(workspace)
        if memory_prompt:
            parts.append(memory_prompt)
        return "\n\n".join(part.strip() for part in parts if part.strip())

    def _load_builtin_system_prompt(self) -> str:
        """Read ``system.md`` located next to this module."""
        return (Path(__file__).resolve().parent / "system.md").read_text(encoding="utf-8")

    def _load_workspace_agents(self, workspace: Path) -> str:
        """Return the text of ``<workspace>/AGENTS.md``, or ``""`` if there is no such file."""
        agents = workspace / "AGENTS.md"
        # is_file() rather than exists(): a directory with this name must not
        # make read_text() blow up during construction.
        if not agents.is_file():
            return ""
        return agents.read_text(encoding="utf-8")

    def _load_workspace_skills(self, workspace: Path) -> list[str]:
        """Return the contents of ``<workspace>/SKILLS/*.md`` sorted by file name."""
        skills_dir = workspace / "SKILLS"
        if not skills_dir.is_dir():
            return []
        return [
            path.read_text(encoding="utf-8")
            for path in sorted(skills_dir.glob("*.md"), key=lambda p: p.name)
            if path.is_file()
        ]

    def _load_memory_prompt(self, workspace: Path) -> str:
        """Return the stored memory for ``workspace``, or ``""`` when empty."""
        return MemoryStore(workspace).read() or ""

    def _serialize_history_for_dream(self) -> str:
        """Render the history as ``[role] content`` lines for the dream prompt.

        List-shaped content (Anthropic content blocks) is JSON-encoded; any
        other content is interpolated as-is.
        """
        lines: list[str] = []
        for item in self.history:
            role = item.get("role", "unknown")
            content = item.get("content", "")
            if isinstance(content, list):
                content = json.dumps(content, ensure_ascii=False)
            lines.append(f"[{role}] {content}")
        return "\n".join(lines)

    def _dream_prompt(self, current_memory: str, session_text: str) -> str:
        """Build the user prompt asking the model to rewrite long-term memory."""
        return (
            "你不是在做聊天总结,而是在从当前 session 中提炼长期有价值的信息。\n"
            "只保留未来仍值得记住的内容,忽略一次性任务细节、临时状态和短期噪声。\n"
            "请结合当前 memory 与当前 session输出可直接写回 memory 的 Markdown。\n"
            "只输出以下四个 sections不要输出代码块、解释或其他标题\n\n"
            "## User Memory\n"
            "## Project Memory\n"
            "## Constraints\n"
            "## Consolidated Facts\n\n"
            f"当前 memory:\n{current_memory}\n\n"
            f"当前 session:\n{session_text}"
        )

    def _run_dream_model(self, current_memory: str, session_text: str) -> str:
        """Send the dream prompt to the backend and return the model's text."""
        prompt = self._dream_prompt(current_memory, session_text)
        if self.config.provider == "openai":
            response = self.client.chat.completions.create(
                model=self.config.model,
                messages=[
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": prompt},
                ],
            )
            return response.choices[0].message.content or ""
        response = self.client.messages.create(
            model=self.config.model,
            system=self.system_prompt,
            max_tokens=self._ANTHROPIC_MAX_TOKENS,
            messages=[{"role": "user", "content": prompt}],
        )
        return "\n".join(block.text for block in response.content if block.type == "text")

    def _build_runtime_summary(self, workspace: Path) -> str:
        """Describe platform, shell, workspace and tools for the system prompt."""
        tool_names = ", ".join(self.tools.keys()) or "(none)"
        shell_name = self._detect_shell()
        return "\n".join(
            [
                "## 运行环境",
                f"- 平台: {platform.system() or 'Unknown'}",
                f"- sys.platform: {sys.platform}",
                f"- shell: {shell_name}",
                f"- workspace: {workspace}",
                "- 默认边界: workspace only",
                f"- 可用工具: {tool_names}",
                "- 行动时必须以以上运行环境信息为准,不要默认套用 Unix/Linux 命令习惯。",
                "- 默认只应在当前 workspace 内读写文件并执行与项目相关的操作,不要主动探索工作区外路径。",
            ]
        )

    def _detect_shell(self) -> str:
        """Return the shell from ``COMSPEC`` (Windows) or ``SHELL``, with a fallback label."""
        if os.name == "nt":
            return os.getenv("COMSPEC", "Windows shell (likely PowerShell or cmd.exe)")
        return os.getenv("SHELL", "unknown shell")

    def _configure_langfuse(self) -> None:
        """Export the configured Langfuse settings as environment variables.

        Unset settings leave any existing environment value untouched.
        """
        exports = {
            "LANGFUSE_PUBLIC_KEY": self.config.langfuse_public_key,
            "LANGFUSE_SECRET_KEY": self.config.langfuse_secret_key,
            "LANGFUSE_BASE_URL": self.config.langfuse_base_url,
        }
        for variable, value in exports.items():
            if value:
                os.environ[variable] = value

    def _call_openai(self) -> dict[str, Any]:
        """Perform one non-streaming OpenAI chat completion.

        Returns:
            ``{"assistant": <history message>, "tool_calls": [...], "text": str}``.

        Raises:
            ValueError: If a tool call carries arguments that are not valid JSON.
        """
        response = self.client.chat.completions.create(
            model=self.config.model,
            messages=self._openai_messages(),
            tools=self._openai_tools(),
            tool_choice="auto",
        )
        message = response.choices[0].message
        text = message.content or ""
        # Share the streaming path's parser so malformed arguments surface
        # with the same error in both code paths.
        tool_calls = [
            self._finalize_openai_tool_call(
                {
                    "id": call.id,
                    "name": call.function.name,
                    "arguments": call.function.arguments,
                }
            )
            for call in message.tool_calls or []
        ]
        assistant = {"role": "assistant", "content": text, "tool_calls": tool_calls}
        return {"assistant": assistant, "tool_calls": tool_calls, "text": text}

    def _call_anthropic(self) -> dict[str, Any]:
        """Perform one Anthropic messages request.

        The assistant history entry keeps the raw content blocks so that
        ``tool_use`` blocks can be replayed on the next request.

        Returns:
            ``{"assistant": <history message>, "tool_calls": [...], "text": str}``.
        """
        response = self.client.messages.create(
            model=self.config.model,
            system=self.system_prompt,
            max_tokens=self._ANTHROPIC_MAX_TOKENS,
            messages=self._anthropic_messages(),
            tools=self._anthropic_tools(),
        )
        text_parts: list[str] = []
        tool_calls: list[dict[str, Any]] = []
        content_blocks: list[dict[str, Any]] = []
        for block in response.content:
            if block.type == "text":
                text_parts.append(block.text)
                content_blocks.append({"type": "text", "text": block.text})
            elif block.type == "tool_use":
                payload = dict(block.input)
                tool_calls.append({"id": block.id, "name": block.name, "input": payload})
                content_blocks.append(
                    {"type": "tool_use", "id": block.id, "name": block.name, "input": payload}
                )
        assistant = {"role": "assistant", "content": content_blocks, "tool_calls": tool_calls}
        return {"assistant": assistant, "tool_calls": tool_calls, "text": "\n".join(text_parts)}

    def _stream_openai(self) -> Iterator[dict[str, Any]]:
        """Stream one OpenAI completion, yielding text events as they arrive.

        Tool-call fragments are accumulated per ``index`` and parsed once the
        stream ends. The generator's return value has the same shape as
        ``_call_openai``'s result.

        Raises:
            ValueError: If accumulated tool arguments are not valid JSON.
        """
        stream = self.client.chat.completions.create(
            model=self.config.model,
            messages=self._openai_messages(),
            tools=self._openai_tools(),
            tool_choice="auto",
            stream=True,
        )
        text_parts: list[str] = []
        tool_call_parts: dict[int, dict[str, Any]] = {}
        for chunk in stream:
            choice = chunk.choices[0] if chunk.choices else None
            if choice is None:
                continue
            delta = choice.delta
            if delta.content:
                text_parts.append(delta.content)
                yield {"type": "text", "content": delta.content}
            for tool_delta in delta.tool_calls or []:
                slot = tool_call_parts.setdefault(
                    tool_delta.index,
                    {"id": "", "name": "", "arguments": ""},
                )
                if tool_delta.id:
                    slot["id"] = tool_delta.id
                function = tool_delta.function
                if function and function.name:
                    slot["name"] = function.name
                if function and function.arguments:
                    # Arguments arrive as JSON fragments; concatenate them.
                    slot["arguments"] += function.arguments
        tool_calls = [
            self._finalize_openai_tool_call(part)
            for _, part in sorted(tool_call_parts.items())
        ]
        text = "".join(text_parts)
        assistant = {"role": "assistant", "content": text, "tool_calls": tool_calls}
        return {"assistant": assistant, "tool_calls": tool_calls, "text": text}

    def _stream_anthropic(self) -> Iterator[dict[str, Any]]:
        """Generator wrapper around ``_call_anthropic``.

        The request itself is not streamed: the complete text is yielded as
        a single event and the call result is returned.
        """
        result = self._call_anthropic()
        if result["text"]:
            yield {"type": "text", "content": result["text"]}
        return result

    def _check_tool_permission(self, name: str, payload: dict[str, Any]) -> str | None:
        """Return a denial reason for the call, or ``None`` if it may run.

        Without a permission checker everything is allowed. A call that
        requires confirmation is denied when no ``confirm_tool`` callback is
        set or the callback returns a falsy value.
        """
        checker = self.permission_checker
        if not checker:
            return None
        if checker.is_hard_blocked(name, payload):
            return checker.denial_reason(name, payload)
        if checker.requires_confirmation(name, payload):
            if not self.confirm_tool or not self.confirm_tool(name, payload):
                return checker.denial_reason(name, payload)
        return None

    def _run_tool(self, name: str, payload: dict[str, Any]) -> str:
        """Execute the named tool; failures are returned as text, not raised."""
        tool = self.tools.get(name)
        if not tool:
            return f"Tool not found: {name}"
        try:
            return tool.execute(payload)
        except Exception as exc:
            # Deliberately broad: the error text is fed back to the model so
            # it can react, instead of aborting the whole turn.
            return f"Tool error in {name}: {exc}"

    def _openai_tools(self) -> list[dict[str, Any]]:
        """Describe the registered tools in OpenAI function-tool format."""
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.input_schema,
                },
            }
            for tool in self.tools.values()
        ]

    def _anthropic_tools(self) -> list[dict[str, Any]]:
        """Describe the registered tools in Anthropic tool format."""
        return [
            {
                "name": tool.name,
                "description": tool.description,
                "input_schema": tool.input_schema,
            }
            for tool in self.tools.values()
        ]

    def _openai_messages(self) -> list[dict[str, Any]]:
        """Convert the history into OpenAI chat messages.

        The system prompt, when non-empty, is sent first. Entries with an
        unrecognised role are skipped.
        """
        messages: list[dict[str, Any]] = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        for item in self.history:
            role = item["role"]
            if role == "user":
                messages.append({"role": "user", "content": item["content"]})
            elif role == "assistant":
                payload: dict[str, Any] = {"role": "assistant", "content": item.get("content", "")}
                if item.get("tool_calls"):
                    payload["tool_calls"] = [
                        {
                            "id": call["id"],
                            "type": "function",
                            "function": {
                                "name": call["name"],
                                # OpenAI expects arguments as a JSON string.
                                "arguments": json.dumps(call["input"], ensure_ascii=False),
                            },
                        }
                        for call in item["tool_calls"]
                    ]
                messages.append(payload)
            elif role == "tool":
                messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": item["tool_call_id"],
                        "content": item["content"],
                    }
                )
        return messages

    def _anthropic_messages(self) -> list[dict[str, Any]]:
        """Convert the history into Anthropic messages.

        Tool outputs become ``tool_result`` blocks inside a user message.
        Consecutive tool outputs are grouped into one user message so that
        all results for one assistant turn travel together. Entries with an
        unrecognised role are skipped.
        """
        messages: list[dict[str, Any]] = []
        # Content list of the tool-result message currently being filled, or
        # None when the last emitted message was not such a batch. A private
        # list is used so history objects are never mutated.
        pending_results: list[dict[str, Any]] | None = None
        for item in self.history:
            role = item["role"]
            if role == "tool":
                block = {
                    "type": "tool_result",
                    "tool_use_id": item["tool_call_id"],
                    "content": item["content"],
                }
                if pending_results is None:
                    pending_results = [block]
                    messages.append({"role": "user", "content": pending_results})
                else:
                    pending_results.append(block)
            elif role == "user":
                messages.append({"role": "user", "content": item["content"]})
                pending_results = None
            elif role == "assistant":
                content = item.get("content", "")
                messages.append(
                    {
                        "role": "assistant",
                        "content": content if isinstance(content, list) else content or "",
                    }
                )
                pending_results = None
        return messages

    def _finalize_openai_tool_call(self, part: dict[str, Any]) -> dict[str, Any]:
        """Turn an accumulated ``{id, name, arguments}`` record into a tool call.

        Empty arguments are treated as ``{}`` and a missing id is replaced
        by ``call_<name>``.

        Raises:
            ValueError: If ``arguments`` is not valid JSON.
        """
        arguments = part.get("arguments", "") or "{}"
        try:
            payload = json.loads(arguments)
        except JSONDecodeError as exc:
            raise ValueError(f"工具参数解析失败: {exc}") from exc
        return {
            "id": part.get("id") or f"call_{part.get('name', 'tool')}",
            "name": part.get("name") or "",
            "input": payload,
        }

    def _save_message(self, message: dict[str, Any]) -> None:
        """Forward ``message`` to the session store when one is configured."""
        if self.session_store and self.session_id:
            self.session_store.append_message(self.session_id, message)
def _load_file_config(path: Path) -> dict[str, Any]:
if not path.exists():
return {}
data = tomllib.loads(path.read_text(encoding="utf-8"))
if "cc_slim" in data and isinstance(data["cc_slim"], dict):
return dict(data["cc_slim"])
return {k: v for k, v in data.items() if not isinstance(v, dict)}
def _pick(*values: Any) -> Any:
for value in values:
if value is None:
continue
if isinstance(value, str) and not value.strip():
continue
return value
return None