regime_strategy.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. """
  2. 市场状态感知策略 (Regime-Aware Strategy)
  3. 基于Regime Detection的动态仓位管理
  4. """
  5. import backtrader as bt
  6. import pandas as pd
  7. import numpy as np
  8. class RegimeAwareStrategy(bt.Strategy):
  9. """
  10. 创业板50状态感知策略
  11. 核心逻辑:
  12. - 强趋势上涨: 满仓持有
  13. - 弱趋势上涨: 半仓持有
  14. - 震荡整理: 轻仓或空仓
  15. - 弱趋势下跌: 空仓或轻仓做空(如允许)
  16. - 强趋势下跌: 空仓观望
  17. 入场信号: 20/60日均线金叉
  18. 出场信号: 20/60日均线死叉 或 状态恶化
  19. """
  20. params = (
  21. ('fast_ma', 20),
  22. ('slow_ma', 60),
  23. ('vol_short', 20),
  24. ('vol_long', 60),
  25. ('trend_threshold', 0.03),
  26. ('vol_percentile_threshold', 0.6),
  27. ('strong_bull_pct', 1.0), # 强趋势上涨仓位
  28. ('weak_bull_pct', 0.5), # 弱趋势上涨仓位
  29. ('consolidation_pct', 0.2), # 震荡仓位
  30. ('weak_bear_pct', 0.0), # 弱趋势下跌仓位
  31. ('strong_bear_pct', 0.0), # 强趋势下跌仓位
  32. ('printlog', True),
  33. )
  34. def __init__(self):
  35. self.dataclose = self.datas[0].close
  36. self.order = None
  37. # 双均线
  38. self.sma_fast = bt.indicators.SMA(period=self.p.fast_ma)
  39. self.sma_slow = bt.indicators.SMA(period=self.p.slow_ma)
  40. self.crossover = bt.indicators.CrossOver(self.sma_fast, self.sma_slow)
  41. # 波动率计算
  42. self.returns = bt.indicators.PctChange(self.dataclose, period=1)
  43. self.vol_short = bt.indicators.StdDev(self.returns, period=self.p.vol_short) * np.sqrt(252)
  44. self.vol_long = bt.indicators.StdDev(self.returns, period=self.p.vol_long) * np.sqrt(252)
  45. # 趋势强度
  46. self.ma_deviation = (self.dataclose - self.sma_fast) / self.sma_fast
  47. self.trend_strength = bt.indicators.SMA(self.ma_deviation, period=self.p.fast_ma)
  48. # 存储状态历史
  49. self.regime_history = []
  50. def get_current_regime(self):
  51. """判断当前市场状态"""
  52. # 获取当前值
  53. vol_pct = self.vol_short[0] / self.vol_long[0] if self.vol_long[0] != 0 else 1.0
  54. trend = self.trend_strength[0]
  55. # 高波动判断
  56. high_vol = vol_pct > 1.2 # 短期波动率高于长期20%
  57. # 趋势判断
  58. strong_up = trend > self.p.trend_threshold
  59. strong_down = trend < -self.p.trend_threshold
  60. if high_vol and strong_up:
  61. return 'strong_bull'
  62. elif (not high_vol) and strong_up:
  63. return 'weak_bull'
  64. elif high_vol and strong_down:
  65. return 'strong_bear'
  66. elif (not high_vol) and strong_down:
  67. return 'weak_bear'
  68. else:
  69. return 'consolidation'
  70. def get_target_position(self, regime):
  71. """根据状态确定目标仓位"""
  72. position_map = {
  73. 'strong_bull': self.p.strong_bull_pct,
  74. 'weak_bull': self.p.weak_bull_pct,
  75. 'consolidation': self.p.consolidation_pct,
  76. 'weak_bear': self.p.weak_bear_pct,
  77. 'strong_bear': self.p.strong_bear_pct,
  78. }
  79. return position_map.get(regime, 0)
  80. def next(self):
  81. # 当前状态
  82. regime = self.get_current_regime()
  83. self.regime_history.append(regime)
  84. # 目标仓位
  85. target_pct = self.get_target_position(regime)
  86. # 金叉信号
  87. golden_cross = self.crossover > 0
  88. # 死叉信号
  89. death_cross = self.crossover < 0
  90. # 当前仓位
  91. current_value = self.broker.getvalue()
  92. cash = self.broker.getcash()
  93. position_size = self.position.size if self.position else 0
  94. current_pct = (position_size * self.dataclose[0]) / current_value if current_value > 0 else 0
  95. # 交易逻辑
  96. if target_pct > 0 and position_size == 0 and golden_cross:
  97. # 入场: 有仓位空间 + 空仓 + 金叉
  98. size = int((current_value * target_pct) / self.dataclose[0] / 100) * 100
  99. if size > 0:
  100. self.order = self.buy(size=size)
  101. if self.p.printlog:
  102. self.log(f'BUY [{regime}], Size: {size}, Target: {target_pct:.0%}')
  103. elif position_size > 0 and (death_cross or target_pct == 0):
  104. # 出场: 有持仓 + (死叉 或 状态恶化到0仓位)
  105. self.order = self.close()
  106. if self.p.printlog:
  107. reason = 'Death Cross' if death_cross else f'Regime: {regime}'
  108. self.log(f'SELL [{reason}], Size: {position_size}')
  109. elif position_size > 0 and target_pct < current_pct * 0.8:
  110. # 减仓: 状态恶化但未到0
  111. new_size = int((current_value * target_pct) / self.dataclose[0] / 100) * 100
  112. if new_size < position_size:
  113. close_size = position_size - new_size
  114. self.order = self.sell(size=close_size)
  115. if self.p.printlog:
  116. self.log(f'REDUCE [{regime}], From {current_pct:.1%} to {target_pct:.0%}')
  117. def notify_order(self, order):
  118. if order.status in [order.Submitted, order.Accepted]:
  119. return
  120. if order.status in [order.Completed]:
  121. if order.isbuy():
  122. self.log(f'BUY EXECUTED @ {order.executed.price:.2f}')
  123. else:
  124. self.log(f'SELL EXECUTED @ {order.executed.price:.2f}')
  125. self.order = None
  126. def log(self, txt, dt=None):
  127. dt = dt or self.datas[0].datetime.date(0)
  128. print(f'{dt.isoformat()} {txt}')
  129. def stop(self):
  130. # 统计状态分布
  131. if self.regime_history:
  132. from collections import Counter
  133. regime_counts = Counter(self.regime_history)
  134. print('\n=== 状态分布 ===')
  135. for regime, count in regime_counts.most_common():
  136. print(f'{regime}: {count} 天 ({count/len(self.regime_history):.1%})')
  137. roi = (self.broker.getvalue() / self.broker.startingcash - 1) * 100
  138. print(f'\n=== 最终收益: {roi:.2f}% ===')
  139. def run_regime_backtest(csv_file="chinext50.csv", cash=100000.0):
  140. """运行状态感知策略回测"""
  141. cerebro = bt.Cerebro()
  142. # 数据
  143. df = pd.read_csv(csv_file, parse_dates=['datetime'], index_col='datetime')
  144. data = bt.feeds.PandasData(dataname=df)
  145. cerebro.adddata(data)
  146. # 策略
  147. cerebro.addstrategy(RegimeAwareStrategy)
  148. # 设置
  149. cerebro.broker.setcash(cash)
  150. cerebro.broker.setcommission(commission=0.001)
  151. # 分析器
  152. cerebro.addanalyzer(bt.analyzers.SharpeRatio, _name='sharpe', riskfreerate=0.02)
  153. cerebro.addanalyzer(bt.analyzers.DrawDown, _name='drawdown')
  154. cerebro.addanalyzer(bt.analyzers.Returns, _name='returns')
  155. print('=== 状态感知策略回测 ===')
  156. print(f'初始资金: {cerebro.broker.getvalue():.2f}')
  157. results = cerebro.run()
  158. strat = results[0]
  159. print(f'\n=== 回测指标 ===')
  160. returns = strat.analyzers.returns.get_analysis()
  161. print(f"年化收益: {returns.get('rnorm100', 0):.2f}%")
  162. sharpe = strat.analyzers.sharpe.get_analysis()
  163. sharpe_ratio = sharpe.get('sharperatio')
  164. sharpe_text = 'N/A' if sharpe_ratio is None else f"{sharpe_ratio:.3f}"
  165. print(f"夏普比率: {sharpe_text}")
  166. drawdown = strat.analyzers.drawdown.get_analysis()
  167. print(f"最大回撤: {drawdown.get('max', {}).get('drawdown', 0):.2f}%")
  168. return cerebro, strat
  169. if __name__ == "__main__":
  170. run_regime_backtest()