base.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. """
  2. 智能体基类
  3. 定义所有策略智能体的标准接口和通用功能
  4. """
  5. from abc import ABC, abstractmethod
  6. from dataclasses import dataclass, field
  7. from datetime import datetime
  8. from typing import Dict, List, Optional, Any, Tuple
  9. from enum import Enum
  10. import uuid
  11. import numpy as np
  12. import pandas as pd
  13. class SignalDirection(Enum):
  14. """信号方向"""
  15. LONG = "long"
  16. SHORT = "short"
  17. NEUTRAL = "neutral"
  18. class SignalStrength(Enum):
  19. """信号强度"""
  20. STRONG = "strong"
  21. MEDIUM = "medium"
  22. WEAK = "weak"
  23. @dataclass
  24. class AgentSignal:
  25. """智能体信号数据结构"""
  26. agent_name: str
  27. direction: SignalDirection
  28. strength: SignalStrength
  29. confidence: float # 0-1
  30. suggested_position: float # 0-1
  31. expected_return: float
  32. win_probability: float
  33. timestamp: datetime
  34. valid_until: Optional[datetime] = None
  35. metadata: Dict[str, Any] = field(default_factory=dict)
  36. def is_valid(self) -> bool:
  37. """检查信号是否有效"""
  38. if self.valid_until is None:
  39. return True
  40. return datetime.now() < self.valid_until
  41. @dataclass
  42. class AgentHealth:
  43. """智能体健康度数据结构"""
  44. overall_score: float # 0-100
  45. sharpe_ratio: float
  46. regime_adaptation: float # 0-100
  47. signal_stability: float # 0-100
  48. compute_efficiency: float # 0-100
  49. status: str # "green", "yellow", "icu", "archived"
  50. last_evaluation: datetime
  51. recommendations: List[str] = field(default_factory=list)
  52. class AgentBase(ABC):
  53. """
  54. 智能体基类
  55. 所有策略智能体必须继承此类并实现抽象方法
  56. """
  57. def __init__(
  58. self,
  59. name: str,
  60. config: Optional[Dict[str, Any]] = None,
  61. max_position: float = 1.0,
  62. min_confidence: float = 0.5
  63. ):
  64. self.name = name
  65. self.config = config or {}
  66. self.max_position = max_position
  67. self.min_confidence = min_confidence
  68. self.agent_id = str(uuid.uuid4())[:8]
  69. # 历史记录
  70. self.signal_history: List[AgentSignal] = []
  71. self.trade_history: List[Dict] = []
  72. self.health_history: List[AgentHealth] = []
  73. # 状态
  74. self.is_active = True
  75. self.current_signal: Optional[AgentSignal] = None
  76. self.performance_stats = {
  77. "total_trades": 0,
  78. "winning_trades": 0,
  79. "total_return": 0.0,
  80. "max_drawdown": 0.0
  81. }
  82. @abstractmethod
  83. def generate_signal(
  84. self,
  85. price_data: pd.DataFrame,
  86. ecosystem: Optional[Any] = None
  87. ) -> Optional[AgentSignal]:
  88. """
  89. 生成交易信号
  90. Args:
  91. price_data: 价格数据
  92. ecosystem: 当前市场生态(可选)
  93. Returns:
  94. AgentSignal: 交易信号,无信号时返回None
  95. """
  96. pass
  97. @abstractmethod
  98. def get_expected_return(
  99. self,
  100. price_data: pd.DataFrame,
  101. ecosystem: Optional[Any] = None
  102. ) -> float:
  103. """
  104. 计算预期收益
  105. Returns:
  106. float: 预期收益率(年化)
  107. """
  108. pass
  109. @abstractmethod
  110. def get_win_probability(
  111. self,
  112. price_data: pd.DataFrame,
  113. ecosystem: Optional[Any] = None
  114. ) -> float:
  115. """
  116. 计算胜率
  117. Returns:
  118. float: 胜率 (0-1)
  119. """
  120. pass
  121. def get_health_score(self) -> AgentHealth:
  122. """
  123. 计算健康度评分
  124. 默认实现,子类可覆盖
  125. """
  126. if len(self.trade_history) < 20:
  127. return AgentHealth(
  128. overall_score=50.0,
  129. sharpe_ratio=0.0,
  130. regime_adaptation=50.0,
  131. signal_stability=50.0,
  132. compute_efficiency=100.0,
  133. status="yellow",
  134. last_evaluation=datetime.now(),
  135. recommendations=["样本不足,需要更多交易数据"]
  136. )
  137. # 计算近期夏普比率
  138. recent_returns = [t.get("return", 0) for t in self.trade_history[-20:]]
  139. sharpe = self._calculate_sharpe(recent_returns)
  140. # 生态适应性(简化计算)
  141. regime_adaptation = self._calculate_regime_adaptation()
  142. # 信号稳定性
  143. signal_stability = self._calculate_signal_stability()
  144. # 计算效率(固定值,子类可覆盖)
  145. compute_efficiency = 95.0
  146. # 综合评分
  147. overall = (
  148. sharpe * 30 +
  149. regime_adaptation * 0.3 +
  150. signal_stability * 0.2 +
  151. compute_efficiency * 0.1
  152. )
  153. overall = min(100, max(0, overall + 50)) # 调整到0-100
  154. # 确定状态
  155. if overall >= 80:
  156. status = "green"
  157. recommendations = []
  158. elif overall >= 60:
  159. status = "yellow"
  160. recommendations = ["健康度一般,建议监控"]
  161. elif overall >= 30:
  162. status = "yellow"
  163. recommendations = ["健康度偏低,需要优化参数"]
  164. else:
  165. status = "icu"
  166. recommendations = ["健康度严重不足,建议暂停交易"]
  167. health = AgentHealth(
  168. overall_score=overall,
  169. sharpe_ratio=sharpe,
  170. regime_adaptation=regime_adaptation,
  171. signal_stability=signal_stability,
  172. compute_efficiency=compute_efficiency,
  173. status=status,
  174. last_evaluation=datetime.now(),
  175. recommendations=recommendations
  176. )
  177. self.health_history.append(health)
  178. return health
  179. def calculate_utility(
  180. self,
  181. price_data: pd.DataFrame,
  182. ecosystem: Any,
  183. lambda_risk: float = 0.5,
  184. alpha_recent: float = 0.3
  185. ) -> float:
  186. """
  187. 计算期望效用 E[U]
  188. E[U] = P(Win|Regime) * Expected_Return - λ * Risk_Penalty + α * Recent_Performance
  189. Args:
  190. price_data: 价格数据
  191. ecosystem: 市场生态
  192. lambda_risk: 风险惩罚系数
  193. alpha_recent: 近期表现系数
  194. Returns:
  195. float: 期望效用
  196. """
  197. # 基础参数
  198. p_win = self.get_win_probability(price_data, ecosystem)
  199. expected_return = self.get_expected_return(price_data, ecosystem)
  200. # 风险惩罚(使用最大回撤)
  201. risk_penalty = abs(self.performance_stats.get("max_drawdown", 0.1))
  202. # 近期表现(最近10笔交易的胜率)
  203. recent_trades = self.trade_history[-10:]
  204. if recent_trades:
  205. recent_win_rate = sum(1 for t in recent_trades if t.get("return", 0) > 0) / len(recent_trades)
  206. recent_performance = recent_win_rate * 0.1 # 归一化
  207. else:
  208. recent_performance = 0.0
  209. # 生态适配加成
  210. regime_match = self._check_regime_match(ecosystem)
  211. utility = (
  212. p_win * expected_return
  213. - lambda_risk * risk_penalty
  214. + alpha_recent * recent_performance
  215. ) * regime_match
  216. return utility
  217. def update_trade_result(self, trade_result: Dict):
  218. """更新交易结果"""
  219. self.trade_history.append(trade_result)
  220. self.performance_stats["total_trades"] += 1
  221. if trade_result.get("return", 0) > 0:
  222. self.performance_stats["winning_trades"] += 1
  223. self.performance_stats["total_return"] += trade_result.get("return", 0)
  224. # 更新最大回撤
  225. drawdown = trade_result.get("drawdown", 0)
  226. if drawdown < self.performance_stats["max_drawdown"]:
  227. self.performance_stats["max_drawdown"] = drawdown
  228. def activate(self):
  229. """激活智能体"""
  230. self.is_active = True
  231. def deactivate(self):
  232. """停用智能体"""
  233. self.is_active = False
  234. def get_performance_summary(self) -> Dict[str, Any]:
  235. """获取业绩摘要"""
  236. total = self.performance_stats["total_trades"]
  237. wins = self.performance_stats["winning_trades"]
  238. return {
  239. "agent_name": self.name,
  240. "agent_id": self.agent_id,
  241. "is_active": self.is_active,
  242. "total_trades": total,
  243. "win_rate": wins / total if total > 0 else 0.0,
  244. "total_return": self.performance_stats["total_return"],
  245. "max_drawdown": self.performance_stats["max_drawdown"],
  246. "current_signal": self.current_signal.direction.value if self.current_signal else None
  247. }
  248. # 辅助方法
  249. def _calculate_sharpe(self, returns: List[float], risk_free_rate: float = 0.03) -> float:
  250. """计算夏普比率"""
  251. if len(returns) < 2:
  252. return 0.0
  253. returns_array = np.array(returns)
  254. excess_returns = returns_array - risk_free_rate / 252
  255. std = np.std(excess_returns, ddof=1)
  256. if std == 0:
  257. return 0.0
  258. return np.mean(excess_returns) / std * np.sqrt(252)
  259. def _calculate_regime_adaptation(self) -> float:
  260. """计算生态适应性(简化实现)"""
  261. if len(self.trade_history) < 20:
  262. return 50.0
  263. # 假设有生态信息,计算不同生态下的表现差异
  264. # 简化:返回固定值,子类可覆盖
  265. return 60.0
  266. def _calculate_signal_stability(self) -> float:
  267. """计算信号稳定性"""
  268. if len(self.signal_history) < 10:
  269. return 50.0
  270. recent_signals = self.signal_history[-20:]
  271. if len(recent_signals) < 2:
  272. return 50.0
  273. # 计算信号方向变化的频率
  274. direction_changes = sum(
  275. 1 for i in range(1, len(recent_signals))
  276. if recent_signals[i].direction != recent_signals[i-1].direction
  277. )
  278. stability = 1 - (direction_changes / (len(recent_signals) - 1))
  279. return stability * 100
  280. def _check_regime_match(self, ecosystem: Any) -> float:
  281. """检查当前生态适配度"""
  282. # 默认返回1.0,子类可覆盖
  283. return 1.0
  284. def _validate_signal(
  285. self,
  286. direction: SignalDirection,
  287. confidence: float,
  288. ecosystem: Optional[Any] = None
  289. ) -> bool:
  290. """验证信号有效性"""
  291. if confidence < self.min_confidence:
  292. return False
  293. # 检查生态毒性
  294. if ecosystem and hasattr(ecosystem, 'micro'):
  295. if ecosystem.micro.flow_toxicity.value in ["high", "medium"]:
  296. if confidence < 0.8: # 有毒环境下需要更高置信度
  297. return False
  298. return True