agent.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
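"""
Trend Hunter agent - conservative trend following.

Core logic:
1. Enter long on a golden cross of the fast/slow moving averages (a smaller
   position is also taken while an up-trend is already established).
2. ADX > 20 is intended to confirm the trend. NOTE: ADX is currently only
   computed and recorded in the signal metadata; the threshold is not
   enforced as an entry filter.
3. Price above the slow moving average (MA20 by default) confirms direction.
4. Ecosystem filter: trade only in the Summer/Spring macro regimes.
"""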
  9. from datetime import datetime, timedelta
  10. from typing import Dict, Any, Optional
  11. import numpy as np
  12. import pandas as pd
  13. from agents.base import AgentBase, AgentSignal, SignalDirection, SignalStrength
  14. from core.ecosystem import MacroRegime, UnifiedEcosystem
  15. class TrendHunterAgent(AgentBase):
  16. """
  17. 趋势猎手 - 稳健版本
  18. """
  19. def __init__(
  20. self,
  21. config: Optional[Dict[str, Any]] = None,
  22. ma_fast: int = 5, # 5日均线
  23. ma_slow: int = 20, # 20日均线
  24. adx_period: int = 14,
  25. adx_threshold: float = 20.0, # ADX门槛
  26. ):
  27. super().__init__(
  28. name="trend_hunter",
  29. config=config,
  30. max_position=1.0,
  31. min_confidence=0.55
  32. )
  33. self.ma_fast = ma_fast
  34. self.ma_slow = ma_slow
  35. self.adx_period = adx_period
  36. self.adx_threshold = adx_threshold
  37. self.preferred_regimes = [MacroRegime.SUMMER, MacroRegime.SPRING]
  38. def generate_signal(
  39. self,
  40. price_data: pd.DataFrame,
  41. ecosystem: Optional[UnifiedEcosystem] = None
  42. ) -> Optional[AgentSignal]:
  43. """生成交易信号"""
  44. if len(price_data) < self.ma_slow + 10:
  45. return None
  46. close = price_data['close']
  47. high = price_data['high']
  48. low = price_data['low']
  49. # 计算指标
  50. ma_fast = close.rolling(self.ma_fast).mean()
  51. ma_slow = close.rolling(self.ma_slow).mean()
  52. adx = self._calculate_adx(price_data)
  53. current_price = close.iloc[-1]
  54. current_ma_fast = ma_fast.iloc[-1]
  55. current_ma_slow = ma_slow.iloc[-1]
  56. prev_ma_fast = ma_fast.iloc[-2]
  57. prev_ma_slow = ma_slow.iloc[-2]
  58. # 金叉/死叉判断
  59. golden_cross = (prev_ma_fast <= prev_ma_slow) and (current_ma_fast > current_ma_slow)
  60. death_cross = (prev_ma_fast >= prev_ma_slow) and (current_ma_fast < current_ma_slow)
  61. # 趋势强度
  62. ma_diff = (current_ma_fast - current_ma_slow) / current_ma_slow
  63. direction = SignalDirection.NEUTRAL
  64. confidence = 0.0
  65. position_size = 0.0
  66. # 只在Summer/Spring交易
  67. if ecosystem and ecosystem.macro.regime in self.preferred_regimes:
  68. health = ecosystem.meso.health_score / 100
  69. # 入场:简化条件 - 只要在均线上方 + 健康度>50
  70. if current_price > current_ma_slow and current_ma_fast > current_ma_slow and health > 0.5:
  71. # 金叉时高仓位,否则维持仓位
  72. if golden_cross:
  73. direction = SignalDirection.LONG
  74. confidence = min(1.0, 0.7 * health)
  75. position_size = confidence
  76. else:
  77. direction = SignalDirection.LONG
  78. confidence = 0.5 * health
  79. position_size = confidence * 0.6
  80. # 出场:死叉或跌破慢线
  81. elif death_cross or current_price < current_ma_slow * 0.98: # 允许2%回撤
  82. direction = SignalDirection.NEUTRAL
  83. confidence = 0.0
  84. position_size = 0.0
  85. # 非目标生态,观望
  86. else:
  87. direction = SignalDirection.NEUTRAL
  88. confidence = 0.0
  89. position_size = 0.0
  90. # 无信号时返回None
  91. if direction == SignalDirection.NEUTRAL and position_size == 0:
  92. return None
  93. # 确定强度
  94. if confidence > 0.75:
  95. strength = SignalStrength.STRONG
  96. elif confidence > 0.6:
  97. strength = SignalStrength.MEDIUM
  98. else:
  99. strength = SignalStrength.WEAK
  100. signal = AgentSignal(
  101. agent_name=self.name,
  102. direction=direction,
  103. strength=strength,
  104. confidence=confidence,
  105. suggested_position=position_size,
  106. expected_return=self.get_expected_return(price_data, ecosystem),
  107. win_probability=self.get_win_probability(price_data, ecosystem),
  108. timestamp=datetime.now(),
  109. valid_until=datetime.now() + timedelta(hours=24),
  110. metadata={
  111. "current_price": current_price,
  112. "ma_fast": current_ma_fast,
  113. "ma_slow": current_ma_slow,
  114. "adx": adx,
  115. "golden_cross": golden_cross,
  116. "death_cross": death_cross,
  117. "ma_diff": ma_diff,
  118. "health": ecosystem.meso.health_score if ecosystem else 0
  119. }
  120. )
  121. self.current_signal = signal
  122. self.signal_history.append(signal)
  123. return signal
  124. def _calculate_adx(self, data: pd.DataFrame) -> float:
  125. """计算ADX指标"""
  126. if len(data) < self.adx_period * 2:
  127. return 20.0
  128. high = data['high']
  129. low = data['low']
  130. close = data['close']
  131. # +DM和-DM
  132. plus_dm = high.diff()
  133. minus_dm = -low.diff()
  134. plus_dm = plus_dm.clip(lower=0)
  135. minus_dm = minus_dm.clip(lower=0)
  136. # TR
  137. tr1 = high - low
  138. tr2 = (high - close.shift(1)).abs()
  139. tr3 = (low - close.shift(1)).abs()
  140. tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
  141. # ATR
  142. atr = tr.rolling(self.adx_period).mean()
  143. # +DI和-DI
  144. plus_di = 100 * (plus_dm.rolling(self.adx_period).mean() / atr)
  145. minus_di = 100 * (minus_dm.rolling(self.adx_period).mean() / atr)
  146. # DX和ADX
  147. dx = 100 * (plus_di - minus_di).abs() / (plus_di + minus_di)
  148. adx = dx.rolling(self.adx_period).mean()
  149. return adx.iloc[-1] if not pd.isna(adx.iloc[-1]) else 20.0
  150. def get_expected_return(self, price_data: pd.DataFrame, ecosystem=None) -> float:
  151. """预期收益"""
  152. if ecosystem and ecosystem.macro.regime == MacroRegime.SUMMER:
  153. return 0.35
  154. elif ecosystem and ecosystem.macro.regime == MacroRegime.SPRING:
  155. return 0.25
  156. return 0.15
  157. def get_win_probability(self, price_data: pd.DataFrame, ecosystem=None) -> float:
  158. """胜率估算"""
  159. base_prob = 0.55
  160. if ecosystem:
  161. if ecosystem.macro.regime == MacroRegime.SUMMER:
  162. base_prob += 0.12
  163. elif ecosystem.macro.regime == MacroRegime.SPRING:
  164. base_prob += 0.08
  165. health = ecosystem.meso.health_score / 100
  166. base_prob += (health - 0.5) * 0.1
  167. return min(0.80, base_prob)
  168. def _calculate_position_size(
  169. self,
  170. ecosystem: Optional[UnifiedEcosystem],
  171. confidence: float
  172. ) -> float:
  173. """计算仓位"""
  174. return self.max_position * confidence