coordinator.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. """
  2. 智能体协同机制
  3. 管理多智能体协同,处理信号冲突和叠加
  4. """
  5. from dataclasses import dataclass, field
  6. from typing import Dict, List, Optional, Tuple, Any
  7. from datetime import datetime
  8. from enum import Enum
  9. import uuid
  10. import pandas as pd
  11. from agents.base import AgentBase, AgentSignal, SignalDirection
  12. class ConflictResolutionMethod(Enum):
  13. """冲突解决方式"""
  14. CONFIDENCE_WEIGHTED = "confidence_weighted"
  15. HIGHER_WINS = "higher_wins"
  16. CANCEL = "cancel"
  17. @dataclass
  18. class CoordinatedSignal:
  19. """协同后的信号"""
  20. original_signals: List[AgentSignal]
  21. final_direction: SignalDirection
  22. final_strength: float # 0-1
  23. final_position: float # 0-1
  24. coordination_type: str # "conflict_resolved", "reinforced", "single"
  25. reasoning: str
  26. timestamp: datetime = field(default_factory=datetime.now)
  27. @dataclass
  28. class ReinforcementResult:
  29. """信号叠加增强结果"""
  30. is_reinforced: bool
  31. direction: SignalDirection
  32. boost_factor: float
  33. original_count: int
  34. avg_confidence: float
  35. class AgentCoordinator:
  36. """
  37. 智能体协同器
  38. 功能:
  39. 1. 信号冲突处理:当智能体生成反向信号时,比较置信度处理
  40. 2. 信号叠加增强:当多个智能体生成同向信号时,提升仓位上限
  41. """
  42. def __init__(
  43. self,
  44. conflict_resolution: ConflictResolutionMethod = ConflictResolutionMethod.CONFIDENCE_WEIGHTED,
  45. min_confidence_diff: float = 0.2,
  46. reinforcement_threshold: int = 3,
  47. position_boost: float = 0.2,
  48. stop_tighten: float = 0.5,
  49. max_position_boost: float = 0.3
  50. ):
  51. self.conflict_resolution = conflict_resolution
  52. self.min_confidence_diff = min_confidence_diff
  53. self.reinforcement_threshold = reinforcement_threshold
  54. self.position_boost = position_boost
  55. self.stop_tighten = stop_tighten
  56. self.max_position_boost = max_position_boost
  57. self.coordination_history: List[CoordinatedSignal] = []
  58. def coordinate(
  59. self,
  60. signals: Dict[str, AgentSignal],
  61. weights: Dict[str, float]
  62. ) -> Optional[CoordinatedSignal]:
  63. """
  64. 协同多个智能体的信号
  65. Args:
  66. signals: 各智能体的信号
  67. weights: 各智能体的权重
  68. Returns:
  69. CoordinatedSignal: 协同后的信号
  70. """
  71. if not signals:
  72. return None
  73. # 单信号直接返回
  74. if len(signals) == 1:
  75. signal = list(signals.values())[0]
  76. return CoordinatedSignal(
  77. original_signals=[signal],
  78. final_direction=signal.direction,
  79. final_strength=signal.confidence,
  80. final_position=signal.suggested_position,
  81. coordination_type="single",
  82. reasoning="单信号,无需协同"
  83. )
  84. # 分组:多头、空头、中性
  85. long_signals = []
  86. short_signals = []
  87. neutral_signals = []
  88. for name, signal in signals.items():
  89. if signal.direction == SignalDirection.LONG:
  90. long_signals.append((name, signal))
  91. elif signal.direction == SignalDirection.SHORT:
  92. short_signals.append((name, signal))
  93. else:
  94. neutral_signals.append((name, signal))
  95. # 情况1: 只有同向信号 - 叠加增强
  96. if long_signals and not short_signals:
  97. return self._reinforce_signals(long_signals, weights, SignalDirection.LONG)
  98. if short_signals and not long_signals:
  99. return self._reinforce_signals(short_signals, weights, SignalDirection.SHORT)
  100. # 情况2: 有反向信号 - 冲突处理
  101. if long_signals and short_signals:
  102. return self._resolve_conflict(long_signals, short_signals, weights)
  103. # 情况3: 只有中性信号
  104. if neutral_signals:
  105. avg_confidence = sum(s.confidence for _, s in neutral_signals) / len(neutral_signals)
  106. return CoordinatedSignal(
  107. original_signals=[s for _, s in neutral_signals],
  108. final_direction=SignalDirection.NEUTRAL,
  109. final_strength=avg_confidence,
  110. final_position=0.0,
  111. coordination_type="conflict_resolved",
  112. reasoning="所有智能体均输出中性信号,观望"
  113. )
  114. return None
  115. def _reinforce_signals(
  116. self,
  117. signals: List[Tuple[str, AgentSignal]],
  118. weights: Dict[str, float],
  119. direction: SignalDirection
  120. ) -> CoordinatedSignal:
  121. """
  122. 信号叠加增强
  123. 当多个智能体生成同向信号时:
  124. 1. 提升该方向仓位上限
  125. 2. 收紧止损
  126. """
  127. if len(signals) < self.reinforcement_threshold:
  128. # 未达到增强阈值,正常加权
  129. weighted_position = sum(
  130. signal.suggested_position * weights.get(name, 1/len(signals))
  131. for name, signal in signals
  132. )
  133. avg_confidence = sum(s.confidence for _, s in signals) / len(signals)
  134. coordinated = CoordinatedSignal(
  135. original_signals=[s for _, s in signals],
  136. final_direction=direction,
  137. final_strength=avg_confidence,
  138. final_position=min(1.0, weighted_position),
  139. coordination_type="reinforced",
  140. reasoning=f"{len(signals)}个智能体同向,未达增强阈值"
  141. )
  142. else:
  143. # 达到增强阈值,提升仓位
  144. base_position = sum(
  145. signal.suggested_position * weights.get(name, 1/len(signals))
  146. for name, signal in signals
  147. )
  148. # 计算增强因子
  149. boost = min(
  150. self.max_position_boost,
  151. self.position_boost * (len(signals) - self.reinforcement_threshold + 1)
  152. )
  153. enhanced_position = min(1.0, base_position * (1 + boost))
  154. avg_confidence = sum(s.confidence for _, s in signals) / len(signals)
  155. coordinated = CoordinatedSignal(
  156. original_signals=[s for _, s in signals],
  157. final_direction=direction,
  158. final_strength=min(1.0, avg_confidence * 1.1), # 置信度小幅提升
  159. final_position=enhanced_position,
  160. coordination_type="reinforced",
  161. reasoning=f"{len(signals)}个智能体同向,仓位提升{boost:.1%},止损收紧{self.stop_tighten:.1%}"
  162. )
  163. self.coordination_history.append(coordinated)
  164. return coordinated
  165. def _resolve_conflict(
  166. self,
  167. long_signals: List[Tuple[str, AgentSignal]],
  168. short_signals: List[Tuple[str, AgentSignal]],
  169. weights: Dict[str, float]
  170. ) -> CoordinatedSignal:
  171. """
  172. 信号冲突处理
  173. 当存在反向信号时:
  174. 1. 比较双方加权置信度
  175. 2. 根据配置的策略处理冲突
  176. """
  177. # 计算各方加权力量和
  178. long_strength = sum(
  179. signal.confidence * weights.get(name, 0.5)
  180. for name, signal in long_signals
  181. )
  182. short_strength = sum(
  183. signal.confidence * weights.get(name, 0.5)
  184. for name, signal in short_signals
  185. )
  186. # 计算数量
  187. long_count = len(long_signals)
  188. short_count = len(short_signals)
  189. all_signals = [s for _, s in long_signals] + [s for _, s in short_signals]
  190. # 冲突解决策略
  191. if self.conflict_resolution == ConflictResolutionMethod.CANCEL:
  192. # 取消所有信号,观望
  193. return CoordinatedSignal(
  194. original_signals=all_signals,
  195. final_direction=SignalDirection.NEUTRAL,
  196. final_strength=0.0,
  197. final_position=0.0,
  198. coordination_type="conflict_resolved",
  199. reasoning=f"信号冲突(多:{long_count} vs 空:{short_count}),取消所有信号观望"
  200. )
  201. elif self.conflict_resolution == ConflictResolutionMethod.HIGHER_WINS:
  202. # 较强方获胜
  203. if long_strength > short_strength:
  204. winner_signals = long_signals
  205. winner_direction = SignalDirection.LONG
  206. else:
  207. winner_signals = short_signals
  208. winner_direction = SignalDirection.SHORT
  209. avg_confidence = sum(s.confidence for _, s in winner_signals) / len(winner_signals)
  210. weighted_position = sum(
  211. signal.suggested_position * weights.get(name, 1/len(winner_signals))
  212. for name, signal in winner_signals
  213. )
  214. return CoordinatedSignal(
  215. original_signals=all_signals,
  216. final_direction=winner_direction,
  217. final_strength=avg_confidence,
  218. final_position=weighted_position,
  219. coordination_type="conflict_resolved",
  220. reasoning=f"信号冲突,{winner_direction.value}方获胜(力量比{long_strength:.2f}:{short_strength:.2f})"
  221. )
  222. else: # CONFIDENCE_WEIGHTED
  223. # 按置信度加权执行双方
  224. total_strength = long_strength + short_strength
  225. if total_strength == 0:
  226. return CoordinatedSignal(
  227. original_signals=all_signals,
  228. final_direction=SignalDirection.NEUTRAL,
  229. final_strength=0.0,
  230. final_position=0.0,
  231. coordination_type="conflict_resolved",
  232. reasoning="信号冲突,双方力量均为0,观望"
  233. )
  234. long_ratio = long_strength / total_strength
  235. short_ratio = short_strength / total_strength
  236. # 净仓位
  237. net_position = long_ratio - short_ratio
  238. if abs(net_position) < self.min_confidence_diff:
  239. return CoordinatedSignal(
  240. original_signals=all_signals,
  241. final_direction=SignalDirection.NEUTRAL,
  242. final_strength=abs(net_position),
  243. final_position=0.0,
  244. coordination_type="conflict_resolved",
  245. reasoning=f"信号冲突,净仓位{net_position:.2f}低于阈值,观望"
  246. )
  247. # 确定方向和仓位
  248. if net_position > 0:
  249. direction = SignalDirection.LONG
  250. position = min(1.0, net_position)
  251. else:
  252. direction = SignalDirection.SHORT
  253. position = min(1.0, abs(net_position))
  254. # 综合置信度
  255. avg_confidence = sum(s.confidence for s in all_signals) / len(all_signals)
  256. return CoordinatedSignal(
  257. original_signals=all_signals,
  258. final_direction=direction,
  259. final_strength=avg_confidence,
  260. final_position=position,
  261. coordination_type="conflict_resolved",
  262. reasoning=f"信号冲突,置信度加权结果:{direction.value},仓位{position:.2f}(多{long_ratio:.2f} vs 空{short_ratio:.2f})"
  263. )
  264. def get_coordination_summary(self) -> Dict[str, Any]:
  265. """获取协同历史摘要"""
  266. if not self.coordination_history:
  267. return {}
  268. total = len(self.coordination_history)
  269. type_counts = {}
  270. for coord in self.coordination_history:
  271. type_counts[coord.coordination_type] = type_counts.get(coord.coordination_type, 0) + 1
  272. return {
  273. "total_coordinations": total,
  274. "type_distribution": type_counts,
  275. "avg_final_position": sum(c.final_position for c in self.coordination_history) / total,
  276. "last_direction": self.coordination_history[-1].final_direction.value if self.coordination_history else None
  277. }