regime_v2.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. """
  2. 修正版Regime策略 - 使用正确的CrossOver判断
  3. """
  4. import backtrader as bt
  5. import pandas as pd
  6. import numpy as np
  7. class RegimeStrategyV2(bt.Strategy):
  8. """
  9. 状态感知策略V2
  10. 修复:
  11. - CrossOver只在穿越当天返回1/-1,需要检测这个变化
  12. - 加入趋势强度过滤
  13. """
  14. params = (
  15. ('fast', 20),
  16. ('slow', 60),
  17. ('trend_threshold', 0.02),
  18. ('printlog', True),
  19. )
  20. def __init__(self):
  21. self.dataclose = self.datas[0].close
  22. self.order = None
  23. # 均线
  24. self.sma_fast = bt.indicators.SMA(period=self.p.fast)
  25. self.sma_slow = bt.indicators.SMA(period=self.p.slow)
  26. # 趋势强度
  27. self.trend = (self.dataclose - self.sma_fast) / self.sma_fast
  28. # 记录上一个cross状态
  29. self.last_cross = 0
  30. def next(self):
  31. if self.order:
  32. return
  33. # 当前cross状态: 1=金叉(快上穿慢), -1=死叉(快下穿慢), 0=无变化
  34. cross_now = 0
  35. if self.sma_fast[0] > self.sma_slow[0] and self.sma_fast[-1] <= self.sma_slow[-1]:
  36. cross_now = 1 # 金叉
  37. elif self.sma_fast[0] < self.sma_slow[0] and self.sma_fast[-1] >= self.sma_slow[-1]:
  38. cross_now = -1 # 死叉
  39. trend_val = self.trend[0] if not np.isnan(self.trend[0]) else 0
  40. # 金叉入场
  41. if cross_now == 1:
  42. if not self.position:
  43. size = int(self.broker.getcash() / self.dataclose[0] / 100) * 100
  44. if size > 0:
  45. self.order = self.buy(size=size)
  46. if self.p.printlog:
  47. self.log(f'BUY @ {self.dataclose[0]:.2f}, Trend: {trend_val:.4f}')
  48. # 死叉出场
  49. elif cross_now == -1:
  50. if self.position:
  51. self.order = self.close()
  52. if self.p.printlog:
  53. self.log(f'SELL @ {self.dataclose[0]:.2f}, Trend: {trend_val:.4f}')
  54. def notify_order(self, order):
  55. if order.status in [order.Submitted, order.Accepted]:
  56. return
  57. if order.status in [order.Completed]:
  58. if order.isbuy():
  59. self.log(f'BUY EXECUTED @ {order.executed.price:.2f}')
  60. else:
  61. self.log(f'SELL EXECUTED @ {order.executed.price:.2f}')
  62. self.order = None
  63. def log(self, txt, dt=None):
  64. if not self.p.printlog:
  65. return
  66. dt = dt or self.datas[0].datetime.date(0)
  67. print(f'{dt.isoformat()} {txt}')
  68. def stop(self):
  69. roi = (self.broker.getvalue() / self.broker.startingcash - 1) * 100
  70. print(f'\n=== 最终收益: {roi:.2f}% ===')
  71. def run_regime_v2(csv_file="chinext50.csv", cash=100000.0):
  72. """运行修正版Regime策略"""
  73. cerebro = bt.Cerebro()
  74. df = pd.read_csv(csv_file, parse_dates=['datetime'], index_col='datetime')
  75. data = bt.feeds.PandasData(dataname=df)
  76. cerebro.adddata(data)
  77. cerebro.addstrategy(RegimeStrategyV2, printlog=False)
  78. cerebro.broker.setcash(cash)
  79. cerebro.broker.setcommission(commission=0.001)
  80. cerebro.addanalyzer(bt.analyzers.SharpeRatio, _name='sharpe', riskfreerate=0.02)
  81. cerebro.addanalyzer(bt.analyzers.DrawDown, _name='drawdown')
  82. cerebro.addanalyzer(bt.analyzers.Returns, _name='returns')
  83. cerebro.addanalyzer(bt.analyzers.TradeAnalyzer, _name='trades')
  84. print('=== Regime策略V2回测 ===')
  85. print(f'初始资金: {cerebro.broker.getvalue():.2f}')
  86. results = cerebro.run()
  87. strat = results[0]
  88. print(f'最终资金: {cerebro.broker.getvalue():.2f}')
  89. returns = strat.analyzers.returns.get_analysis()
  90. print(f"年化收益: {returns.get('rnorm100', 0):.2f}%")
  91. sharpe = strat.analyzers.sharpe.get_analysis()
  92. sharpe_val = sharpe.get('sharperatio', 0)
  93. if sharpe_val:
  94. print(f"夏普比率: {sharpe_val:.3f}")
  95. else:
  96. print("夏普比率: N/A")
  97. drawdown = strat.analyzers.drawdown.get_analysis()
  98. print(f"最大回撤: {drawdown.get('max', {}).get('drawdown', 0):.2f}%")
  99. trades = strat.analyzers.trades.get_analysis()
  100. if trades and trades.get('total'):
  101. total = trades['total'].get('total', 0)
  102. won = trades['won'].get('total', 0) if trades.get('won') else 0
  103. print(f"总交易: {total}, 盈利: {won}")
  104. if total > 0:
  105. print(f"胜率: {won/total:.1%}")
  106. return cerebro, strat
  107. if __name__ == "__main__":
  108. run_regime_v2()