agent.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. """
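  2. Mean-reversion agent (均值回归者) - simplified, efficient version.
  3. Core logic (default parameter values shown; all thresholds are constructor parameters):
  4. 1. RSI below 35 (oversold) AND price more than 5% below its 20-period moving average = full entry
  5. 2. Price more than 7.5% (1.5x the 5% threshold) below the moving average without the RSI condition = light probe entry
  6. 3. RSI above 65 (overbought) or price reverting to the moving average = exit
  7. 4. New entries only in Spring/Winter regimes; full position 60% (probe 30%)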
  8. """
  9. from datetime import datetime, timedelta
  10. from typing import Dict, Any, Optional
  11. import numpy as np
  12. import pandas as pd
  13. from agents.base import AgentBase, AgentSignal, SignalDirection, SignalStrength
  14. from core.ecosystem import MacroRegime, UnifiedEcosystem
  15. class MeanReversionAgent(AgentBase):
  16. """
  17. 均值回归者 - Spring/Winter专用
  18. """
  19. def __init__(
  20. self,
  21. config: Optional[Dict[str, Any]] = None,
  22. rsi_period: int = 14,
  23. rsi_oversold: float = 35.0, # RSI超卖线
  24. rsi_overbought: float = 65.0, # RSI超买线
  25. ma_period: int = 20,
  26. deviation_threshold: float = 0.05 # 偏离5%
  27. ):
  28. super().__init__(
  29. name="mean_reversion",
  30. config=config,
  31. max_position=0.6, # 最大60%仓位
  32. min_confidence=0.5
  33. )
  34. self.rsi_period = rsi_period
  35. self.rsi_oversold = rsi_oversold
  36. self.rsi_overbought = rsi_overbought
  37. self.ma_period = ma_period
  38. self.deviation_threshold = deviation_threshold
  39. self.preferred_regimes = [MacroRegime.SPRING, MacroRegime.WINTER]
  40. self.in_position = False
  41. self.entry_price = None
  42. def generate_signal(
  43. self,
  44. price_data: pd.DataFrame,
  45. ecosystem: Optional[UnifiedEcosystem] = None
  46. ) -> Optional[AgentSignal]:
  47. """生成交易信号 - RSI超卖反弹"""
  48. if len(price_data) < max(self.rsi_period, self.ma_period) + 5:
  49. return None
  50. close = price_data['close']
  51. current_price = close.iloc[-1]
  52. # 计算RSI
  53. rsi = self._calculate_rsi(close)
  54. # 计算均线偏离
  55. ma = close.rolling(self.ma_period).mean()
  56. ma_deviation = (current_price - ma.iloc[-1]) / ma.iloc[-1]
  57. direction = SignalDirection.NEUTRAL
  58. confidence = 0.0
  59. position_size = 0.0
  60. signal_type = "neutral"
  61. # 持仓中,检查出场
  62. if self.in_position:
  63. # RSI超买或回归均线,出场
  64. if rsi > self.rsi_overbought or abs(ma_deviation) < 0.02:
  65. direction = SignalDirection.NEUTRAL
  66. confidence = 0.0
  67. position_size = 0.0
  68. signal_type = "exit_rsi_mean"
  69. self._reset_state()
  70. else:
  71. # 持仓中
  72. direction = SignalDirection.LONG
  73. confidence = 0.6
  74. position_size = 0.6
  75. signal_type = "hold"
  76. # 空仓中,检查入场(仅Spring/Winter)
  77. else:
  78. # 生态过滤
  79. if ecosystem and ecosystem.macro.regime not in self.preferred_regimes:
  80. return None
  81. # RSI超卖 + 价格偏离均线 = 买入
  82. if rsi < self.rsi_oversold and ma_deviation < -self.deviation_threshold:
  83. direction = SignalDirection.LONG
  84. confidence = min(1.0, (self.rsi_oversold - rsi) / 20 + 0.6)
  85. position_size = 0.6
  86. signal_type = "entry_rsi_oversold"
  87. self.in_position = True
  88. self.entry_price = current_price
  89. # 仅偏离无RSI = 轻仓试探
  90. elif ma_deviation < -self.deviation_threshold * 1.5:
  91. direction = SignalDirection.LONG
  92. confidence = 0.5
  93. position_size = 0.3
  94. signal_type = "entry_deviation"
  95. self.in_position = True
  96. self.entry_price = current_price
  97. # 无信号
  98. if direction == SignalDirection.NEUTRAL and position_size == 0:
  99. return None
  100. return AgentSignal(
  101. agent_name=self.name,
  102. direction=direction,
  103. strength=SignalStrength.STRONG if confidence > 0.7 else SignalStrength.MEDIUM,
  104. confidence=confidence,
  105. suggested_position=position_size,
  106. expected_return=0.12,
  107. win_probability=0.55,
  108. timestamp=datetime.now(),
  109. valid_until=datetime.now() + timedelta(hours=24),
  110. metadata={
  111. "signal_type": signal_type,
  112. "rsi": rsi,
  113. "ma_deviation": ma_deviation,
  114. "current_price": current_price,
  115. "regime": ecosystem.macro.regime.value if ecosystem else "unknown"
  116. }
  117. )
  118. def _calculate_rsi(self, prices: pd.Series) -> float:
  119. """计算RSI"""
  120. if len(prices) < self.rsi_period + 1:
  121. return 50.0
  122. deltas = prices.diff()
  123. gains = deltas.clip(lower=0)
  124. losses = (-deltas).clip(lower=0)
  125. avg_gain = gains.rolling(self.rsi_period).mean()
  126. avg_loss = losses.rolling(self.rsi_period).mean()
  127. rs = avg_gain.iloc[-1] / avg_loss.iloc[-1] if avg_loss.iloc[-1] != 0 else 0
  128. rsi = 100 - (100 / (1 + rs))
  129. return rsi
  130. def _reset_state(self):
  131. """重置状态"""
  132. self.in_position = False
  133. self.entry_price = None
  134. def get_expected_return(self, price_data: pd.DataFrame, ecosystem=None) -> float:
  135. return 0.12
  136. def get_win_probability(self, price_data: pd.DataFrame, ecosystem=None) -> float:
  137. return 0.55
  138. def calculate_utility(self, price_data, ecosystem, lambda_risk=0.3, alpha_recent=0.3):
  139. """Spring/Winter高优先级"""
  140. if ecosystem and ecosystem.macro.regime == MacroRegime.SPRING:
  141. return 0.48 # Spring最高优先级(略低于Summer的breakout)
  142. elif ecosystem and ecosystem.macro.regime == MacroRegime.WINTER:
  143. return 0.40 # Winter高优先级
  144. elif ecosystem and ecosystem.macro.regime == MacroRegime.SUMMER:
  145. return 0.05 # Summer几乎不交易
  146. return 0.20 # 其他中等优先级