signal_engine.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
"""
Unified signal engine.

Provides one shared set of entry/exit signals for all three instruments.
An entry requires all three confirmations:

1. Trend: price > fast MA (default 20) > slow MA (default 60), and the slow MA
   is not falling (its recent slope must be above -0.1%).
2. Momentum: RSI within [rsi_lower, rsi_upper] (defaults 45-75), and the
   5-day return exceeds 50% of the 20-day return.
3. Volume: current volume >= volume_threshold (default 1.0) times the average
   volume of the last 20 bars (current bar included).

While long, the position is exited when price closes below the fast MA, when
RSI drops from above 60 to below 50 within one bar, or after a 10% retracement
from the highest price since entry.
"""
  8. from typing import Dict, Optional, Tuple
  9. from dataclasses import dataclass
  10. from datetime import datetime
  11. import pandas as pd
  12. import numpy as np
  13. @dataclass
  14. class SignalResult:
  15. """信号结果"""
  16. signal: str # "enter_long", "exit", "hold", "neutral"
  17. confidence: float # 0-1
  18. trend_confirmed: bool
  19. momentum_confirmed: bool
  20. volume_confirmed: bool
  21. # 详细数据
  22. rsi: float
  23. price_vs_20ma: float
  24. ma20_slope: float
  25. ma60_slope: float
  26. volume_ratio: float
  27. return_5d: float
  28. return_20d: float
  29. class UnifiedSignalEngine:
  30. """
  31. 统一信号引擎
  32. 所有品种使用同一套信号规则,避免过拟合
  33. """
  34. def __init__(
  35. self,
  36. rsi_period: int = 14,
  37. rsi_lower: float = 45, # 放宽至45
  38. rsi_upper: float = 75, # 放宽至75
  39. volume_threshold: float = 1.0, # 放宽至1.0(持平即可)
  40. ma_fast: int = 20,
  41. ma_slow: int = 60
  42. ):
  43. self.rsi_period = rsi_period
  44. self.rsi_lower = rsi_lower
  45. self.rsi_upper = rsi_upper
  46. self.volume_threshold = volume_threshold
  47. self.ma_fast = ma_fast
  48. self.ma_slow = ma_slow
  49. # 当前持仓状态
  50. self.in_position = False
  51. self.entry_price = None
  52. self.highest_price = None
  53. def generate_signal(
  54. self,
  55. df: pd.DataFrame,
  56. current_date: Optional[datetime] = None
  57. ) -> SignalResult:
  58. """
  59. 生成交易信号
  60. Args:
  61. df: 品种数据(OHLCV)
  62. current_date: 当前日期(回测用)
  63. Returns:
  64. SignalResult: 信号结果
  65. """
  66. # 获取数据窗口
  67. if current_date is not None:
  68. df = df[df.index <= current_date]
  69. if len(df) < self.ma_slow + 5:
  70. return self._create_neutral_result()
  71. close = df['close']
  72. volume = df['volume']
  73. # 计算指标
  74. ma20 = close.rolling(self.ma_fast).mean()
  75. ma60 = close.rolling(self.ma_slow).mean()
  76. rsi = self._calculate_rsi(close)
  77. current_price = close.iloc[-1]
  78. current_ma20 = ma20.iloc[-1]
  79. current_ma60 = ma60.iloc[-1]
  80. current_volume = volume.iloc[-1]
  81. avg_volume = volume.iloc[-20:].mean()
  82. # 1. 趋势确认
  83. trend_confirmed = self._check_trend(
  84. current_price, current_ma20, current_ma60, ma20, ma60
  85. )
  86. # 2. 动量确认
  87. momentum_confirmed, rsi_value = self._check_momentum(close)
  88. # 3. 量能确认
  89. volume_confirmed, volume_ratio = self._check_volume(current_volume, avg_volume)
  90. # 计算收益率
  91. return_5d = (close.iloc[-1] - close.iloc[-5]) / close.iloc[-5] if len(close) >= 5 else 0
  92. return_20d = (close.iloc[-1] - close.iloc[-20]) / close.iloc[-20] if len(close) >= 20 else 0
  93. # 生成信号
  94. if not self.in_position:
  95. # 空仓:检查入场条件
  96. if trend_confirmed and momentum_confirmed and volume_confirmed:
  97. signal = "enter_long"
  98. confidence = self._calculate_confidence(
  99. trend_confirmed, momentum_confirmed, volume_confirmed,
  100. rsi_value, volume_ratio
  101. )
  102. self.in_position = True
  103. self.entry_price = current_price
  104. self.highest_price = current_price
  105. else:
  106. signal = "neutral"
  107. confidence = 0.0
  108. else:
  109. # 持仓:检查出场条件
  110. self.highest_price = max(self.highest_price, current_price)
  111. # 更新最高价
  112. if current_price > self.highest_price:
  113. self.highest_price = current_price
  114. # 检查出场条件
  115. should_exit = self._check_exit(
  116. current_price, close, ma20, rsi_value
  117. )
  118. if should_exit:
  119. signal = "exit"
  120. confidence = 1.0
  121. self._reset_position()
  122. else:
  123. signal = "hold"
  124. confidence = 0.5
  125. return SignalResult(
  126. signal=signal,
  127. confidence=confidence,
  128. trend_confirmed=trend_confirmed,
  129. momentum_confirmed=momentum_confirmed,
  130. volume_confirmed=volume_confirmed,
  131. rsi=rsi_value,
  132. price_vs_20ma=(current_price - current_ma20) / current_ma20,
  133. ma20_slope=(ma20.iloc[-1] - ma20.iloc[-5]) / ma20.iloc[-5] if len(ma20) >= 5 else 0,
  134. ma60_slope=(ma60.iloc[-1] - ma60.iloc[-5]) / ma60.iloc[-5] if len(ma60) >= 5 else 0,
  135. volume_ratio=volume_ratio,
  136. return_5d=return_5d,
  137. return_20d=return_20d
  138. )
  139. def _check_trend(
  140. self,
  141. price: float,
  142. ma20: float,
  143. ma60: float,
  144. ma20_series: pd.Series,
  145. ma60_series: pd.Series
  146. ) -> bool:
  147. """检查趋势确认条件"""
  148. # 价格 > 20MA > 60MA
  149. price_above_ma = price > ma20 > ma60
  150. # 60MA斜率 > -0.001(趋势向上或走平)
  151. ma60_slope = (ma60_series.iloc[-1] - ma60_series.iloc[-5]) / ma60_series.iloc[-5] \
  152. if len(ma60_series) >= 5 else 0
  153. ma_slope_positive = ma60_slope > -0.001
  154. return price_above_ma and ma_slope_positive
  155. def _check_momentum(self, close: pd.Series) -> Tuple[bool, float]:
  156. """检查动量确认条件"""
  157. rsi = self._calculate_rsi(close)
  158. # RSI在50-70之间(强势但非超买)
  159. rsi_in_range = self.rsi_lower <= rsi <= self.rsi_upper
  160. # 5日涨幅 > 50% 20日涨幅(动能加速)
  161. if len(close) >= 20:
  162. return_5d = (close.iloc[-1] - close.iloc[-5]) / close.iloc[-5]
  163. return_20d = (close.iloc[-1] - close.iloc[-20]) / close.iloc[-20]
  164. momentum_accelerating = return_5d > return_20d * 0.5
  165. else:
  166. momentum_accelerating = False
  167. return rsi_in_range and momentum_accelerating, rsi
  168. def _check_volume(self, current_vol: float, avg_vol: float) -> Tuple[bool, float]:
  169. """检查量能确认条件"""
  170. if avg_vol == 0:
  171. return False, 0
  172. volume_ratio = current_vol / avg_vol
  173. return volume_ratio >= self.volume_threshold, volume_ratio
  174. def _check_exit(
  175. self,
  176. current_price: float,
  177. close: pd.Series,
  178. ma20: pd.Series,
  179. rsi: float
  180. ) -> bool:
  181. """检查出场条件"""
  182. # 1. 趋势反转:价格跌破20日均线
  183. if current_price < ma20.iloc[-1]:
  184. return True
  185. # 2. 动量衰竭:RSI从高位跌破50
  186. if len(close) >= 2:
  187. prev_rsi = self._calculate_rsi(close.iloc[:-1])
  188. if prev_rsi > 60 and rsi < 50:
  189. return True
  190. # 3. 移动止盈:从最高点回撤10%
  191. if self.highest_price and self.entry_price:
  192. drawdown_from_peak = (self.highest_price - current_price) / self.highest_price
  193. if drawdown_from_peak >= 0.10:
  194. return True
  195. return False
  196. def _calculate_confidence(
  197. self,
  198. trend: bool,
  199. momentum: bool,
  200. volume: bool,
  201. rsi: float,
  202. volume_ratio: float
  203. ) -> float:
  204. """计算信号置信度"""
  205. # 基础分
  206. base = 0.5
  207. # 三条件都满足
  208. if trend and momentum and volume:
  209. base += 0.3
  210. # RSI越强越好(但不超过70)
  211. if 55 <= rsi <= 65:
  212. base += 0.1
  213. # 量能越大越好
  214. if volume_ratio > 1.5:
  215. base += 0.1
  216. return min(1.0, base)
  217. def _calculate_rsi(self, prices: pd.Series) -> float:
  218. """计算RSI"""
  219. if len(prices) < self.rsi_period + 1:
  220. return 50.0
  221. deltas = prices.diff()
  222. gains = deltas.clip(lower=0)
  223. losses = (-deltas).clip(lower=0)
  224. avg_gain = gains.rolling(self.rsi_period).mean()
  225. avg_loss = losses.rolling(self.rsi_period).mean()
  226. rs = avg_gain.iloc[-1] / avg_loss.iloc[-1] if avg_loss.iloc[-1] != 0 else 0
  227. rsi = 100 - (100 / (1 + rs))
  228. return rsi
  229. def _create_neutral_result(self) -> SignalResult:
  230. """创建中性信号结果"""
  231. return SignalResult(
  232. signal="neutral",
  233. confidence=0.0,
  234. trend_confirmed=False,
  235. momentum_confirmed=False,
  236. volume_confirmed=False,
  237. rsi=50.0,
  238. price_vs_20ma=0.0,
  239. ma20_slope=0.0,
  240. ma60_slope=0.0,
  241. volume_ratio=1.0,
  242. return_5d=0.0,
  243. return_20d=0.0
  244. )
  245. def _reset_position(self):
  246. """重置持仓状态"""
  247. self.in_position = False
  248. self.entry_price = None
  249. self.highest_price = None
  250. def set_position_state(self, in_position: bool, entry_price: Optional[float] = None):
  251. """设置持仓状态(用于回测恢复)"""
  252. self.in_position = in_position
  253. self.entry_price = entry_price
  254. self.highest_price = entry_price