220 lines
6.1 KiB
Python
220 lines
6.1 KiB
Python
"""
|
||
配置管理模块
|
||
提供统一的配置加载和访问接口,支持大模型配置、系统配置等
|
||
"""
|
||
import json
|
||
import os
|
||
from pathlib import Path
|
||
from typing import Dict, Any, Optional
|
||
|
||
|
||
class Config:
|
||
"""配置管理类,单例模式"""
|
||
|
||
_instance = None
|
||
_config_data: Dict[str, Any] = {}
|
||
|
||
def __new__(cls):
|
||
if cls._instance is None:
|
||
cls._instance = super().__new__(cls)
|
||
return cls._instance
|
||
|
||
def __init__(self):
|
||
if not self._config_data:
|
||
self.load_config()
|
||
|
||
def load_config(self, config_path: Optional[str] = None):
|
||
"""
|
||
加载配置文件
|
||
|
||
Args:
|
||
config_path: 配置文件路径,默认为项目根目录下的 config.json
|
||
"""
|
||
if config_path is None:
|
||
# 获取项目根目录(core 目录的上级目录)
|
||
project_root = Path(__file__).parent.parent
|
||
config_path = project_root / "config.json"
|
||
|
||
if not os.path.exists(config_path):
|
||
raise FileNotFoundError(f"配置文件不存在: {config_path}")
|
||
|
||
with open(config_path, 'r', encoding='utf-8') as f:
|
||
self._config_data = json.load(f)
|
||
|
||
def get(self, key: str, default: Any = None) -> Any:
|
||
"""
|
||
获取配置项(支持点号分隔的多级key)
|
||
|
||
Args:
|
||
key: 配置项键,如 "llm.model_name" 或 "llm"
|
||
default: 默认值
|
||
|
||
Returns:
|
||
配置值
|
||
"""
|
||
keys = key.split('.')
|
||
value = self._config_data
|
||
|
||
for k in keys:
|
||
if isinstance(value, dict):
|
||
value = value.get(k)
|
||
if value is None:
|
||
return default
|
||
else:
|
||
return default
|
||
|
||
return value
|
||
|
||
@property
|
||
def llm_config(self) -> Dict[str, Any]:
|
||
"""获取大模型配置"""
|
||
return self.get('llm', {})
|
||
|
||
@property
|
||
def model_name(self) -> str:
|
||
"""获取模型名称"""
|
||
return self.get('llm.model_name', 'gpt-4')
|
||
|
||
@property
|
||
def api_base(self) -> str:
|
||
"""获取API地址"""
|
||
return self.get('llm.api_base', 'https://api.openai.com/v1')
|
||
|
||
@property
|
||
def api_key(self) -> str:
|
||
"""获取API Key"""
|
||
return self.get('llm.api_key', '')
|
||
|
||
@property
|
||
def temperature(self) -> float:
|
||
"""获取温度参数"""
|
||
return self.get('llm.temperature', 0.7)
|
||
|
||
@property
|
||
def max_tokens(self) -> int:
|
||
"""获取最大token数"""
|
||
return self.get('llm.max_tokens', 2000)
|
||
|
||
@property
|
||
def timeout(self) -> int:
|
||
"""获取超时时间"""
|
||
return self.get('llm.timeout', 60)
|
||
|
||
def update(self, key: str, value: Any):
|
||
"""
|
||
更新配置项
|
||
|
||
Args:
|
||
key: 配置项键,支持点号分隔的多级key
|
||
value: 新值
|
||
"""
|
||
keys = key.split('.')
|
||
config = self._config_data
|
||
|
||
for k in keys[:-1]:
|
||
if k not in config:
|
||
config[k] = {}
|
||
config = config[k]
|
||
|
||
config[keys[-1]] = value
|
||
|
||
def save_config(self, config_path: Optional[str] = None):
|
||
"""
|
||
保存配置到文件
|
||
|
||
Args:
|
||
config_path: 配置文件路径,默认为项目根目录下的 config.json
|
||
"""
|
||
if config_path is None:
|
||
project_root = Path(__file__).parent.parent
|
||
config_path = project_root / "config.json"
|
||
|
||
with open(config_path, 'w', encoding='utf-8') as f:
|
||
json.dump(self._config_data, f, ensure_ascii=False, indent=2)
|
||
|
||
|
||
# 全局配置实例
|
||
config = Config()
|
||
|
||
|
||
def get_llm_client():
|
||
"""
|
||
获取配置好的LLM客户端(OpenAI SDK)
|
||
|
||
Returns:
|
||
OpenAI客户端实例
|
||
"""
|
||
try:
|
||
from openai import OpenAI
|
||
|
||
client = OpenAI(
|
||
api_key=config.api_key,
|
||
base_url=config.api_base,
|
||
timeout=config.timeout
|
||
)
|
||
return client
|
||
except ImportError:
|
||
print("警告: openai 库未安装,无法创建LLM客户端")
|
||
return None
|
||
except Exception as e:
|
||
print(f"创建LLM客户端失败: {e}")
|
||
return None
|
||
|
||
|
||
def llm_call(messages: list, **kwargs) -> Optional[str]:
|
||
"""
|
||
统一的LLM调用接口
|
||
|
||
Args:
|
||
messages: 消息列表,格式为 [{"role": "user", "content": "..."}]
|
||
**kwargs: 其他参数(会覆盖配置文件中的默认值)
|
||
|
||
Returns:
|
||
LLM返回的文本内容,失败返回None
|
||
"""
|
||
client = get_llm_client()
|
||
if client is None:
|
||
return None
|
||
|
||
try:
|
||
# 合并配置和传入的参数
|
||
params = {
|
||
'model': config.model_name,
|
||
'temperature': config.temperature,
|
||
'max_tokens': config.max_tokens,
|
||
}
|
||
params.update(kwargs)
|
||
|
||
response = client.chat.completions.create(
|
||
messages=messages,
|
||
**params
|
||
)
|
||
return response.choices[0].message.content
|
||
except Exception as e:
|
||
print(f"LLM调用失败: {e}")
|
||
return None
|
||
|
||
|
||
if __name__ == "__main__":
|
||
# 测试配置加载
|
||
print("=== 配置管理模块测试 ===")
|
||
print(f"模型名称: {config.model_name}")
|
||
print(f"API地址: {config.api_base}")
|
||
print(f"API Key: {config.api_key[:10]}..." if config.api_key else "API Key: 未配置")
|
||
print(f"温度: {config.temperature}")
|
||
print(f"最大tokens: {config.max_tokens}")
|
||
|
||
print("\n完整LLM配置:")
|
||
print(json.dumps(config.llm_config, indent=2, ensure_ascii=False))
|
||
|
||
# 测试配置更新
|
||
print("\n测试配置更新...")
|
||
config.update('llm.temperature', 0.5)
|
||
print(f"更新后的温度: {config.temperature}")
|
||
print()
|
||
# 测试获取不存在的配置
|
||
print(f"\n获取不存在的配置: {config.get('nonexistent.key', 'default_value')}")
|
||
# 测试调用
|
||
respodse = llm_call([{"role": "user", "content": "你好,用100字介绍一下你自己。"}])
|
||
print(f"\nLLM调用结果:\n{respodse}")
|