regime_simple.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. """
  2. 简化的Regime策略 - 基于趋势强度动态调整仓位
  3. """
  4. import backtrader as bt
  5. import pandas as pd
  6. import numpy as np
  7. class SimpleRegimeStrategy(bt.Strategy):
  8. """
  9. 简化版状态感知策略
  10. 逻辑:
  11. - 计算20日趋势强度
  12. - 趋势强且向上: 满仓
  13. - 趋势弱或向下: 减仓/空仓
  14. - 均线金叉入场,死叉出场
  15. """
  16. params = (
  17. ('fast', 20),
  18. ('slow', 60),
  19. ('trend_threshold', 0.02),
  20. ('printlog', False),
  21. )
  22. def __init__(self):
  23. self.dataclose = self.datas[0].close
  24. self.order = None
  25. # 均线
  26. self.sma_fast = bt.indicators.SMA(period=self.p.fast)
  27. self.sma_slow = bt.indicators.SMA(period=self.p.slow)
  28. self.crossover = bt.indicators.CrossOver(self.sma_fast, self.sma_slow)
  29. # 趋势: 价格相对快均线的偏离
  30. self.trend = (self.dataclose - self.sma_fast) / self.sma_fast
  31. def next(self):
  32. if self.order:
  33. return
  34. # 当前趋势
  35. trend_val = self.trend[0] if not np.isnan(self.trend[0]) else 0
  36. # 金叉 + 趋势向上 = 买入
  37. if self.crossover > 0 and trend_val > -self.p.trend_threshold:
  38. if not self.position:
  39. size = int(self.broker.getcash() / self.dataclose[0] / 100) * 100
  40. if size > 0:
  41. self.order = self.buy(size=size)
  42. if self.p.printlog:
  43. self.log(f'BUY @ {self.dataclose[0]:.2f}, Trend: {trend_val:.4f}')
  44. # 死叉 或 趋势转弱 = 卖出
  45. elif self.crossover < 0 or (self.position and trend_val < -self.p.trend_threshold * 2):
  46. if self.position:
  47. self.order = self.close()
  48. if self.p.printlog:
  49. reason = 'Death Cross' if self.crossover < 0 else 'Weak Trend'
  50. self.log(f'SELL @ {self.dataclose[0]:.2f}, Reason: {reason}')
  51. def notify_order(self, order):
  52. if order.status in [order.Submitted, order.Accepted]:
  53. return
  54. if order.status in [order.Completed]:
  55. if order.isbuy():
  56. self.log(f'BUY EXECUTED @ {order.executed.price:.2f}')
  57. else:
  58. self.log(f'SELL EXECUTED @ {order.executed.price:.2f}')
  59. self.order = None
  60. def log(self, txt, dt=None):
  61. if not self.p.printlog:
  62. return
  63. dt = dt or self.datas[0].datetime.date(0)
  64. print(f'{dt.isoformat()} {txt}')
  65. def stop(self):
  66. roi = (self.broker.getvalue() / self.broker.startingcash - 1) * 100
  67. print(f'\n=== 最终收益: {roi:.2f}% ===')
  68. def run_simple_regime(csv_file="chinext50.csv", cash=100000.0):
  69. """运行简化Regime策略"""
  70. cerebro = bt.Cerebro()
  71. df = pd.read_csv(csv_file, parse_dates=['datetime'], index_col='datetime')
  72. data = bt.feeds.PandasData(dataname=df)
  73. cerebro.adddata(data)
  74. cerebro.addstrategy(SimpleRegimeStrategy, printlog=False)
  75. cerebro.broker.setcash(cash)
  76. cerebro.broker.setcommission(commission=0.001)
  77. cerebro.addanalyzer(bt.analyzers.SharpeRatio, _name='sharpe', riskfreerate=0.02)
  78. cerebro.addanalyzer(bt.analyzers.DrawDown, _name='drawdown')
  79. cerebro.addanalyzer(bt.analyzers.Returns, _name='returns')
  80. cerebro.addanalyzer(bt.analyzers.TradeAnalyzer, _name='trades')
  81. print('=== 简化Regime策略回测 ===')
  82. print(f'初始资金: {cerebro.broker.getvalue():.2f}')
  83. results = cerebro.run()
  84. strat = results[0]
  85. print(f'\n最终资金: {cerebro.broker.getvalue():.2f}')
  86. returns = strat.analyzers.returns.get_analysis()
  87. print(f"年化收益: {returns.get('rnorm100', 0):.2f}%")
  88. sharpe = strat.analyzers.sharpe.get_analysis()
  89. sharpe_val = sharpe.get('sharperatio', 0)
  90. print(f"夏普比率: {sharpe_val:.3f}" if sharpe_val else "夏普比率: N/A")
  91. drawdown = strat.analyzers.drawdown.get_analysis()
  92. print(f"最大回撤: {drawdown.get('max', {}).get('drawdown', 0):.2f}%")
  93. trades = strat.analyzers.trades.get_analysis()
  94. if trades.get('total'):
  95. total_trades = trades['total'].get('total', 0)
  96. won_trades = trades['won'].get('total', 0) if trades.get('won') else 0
  97. print(f"总交易次数: {total_trades}")
  98. print(f"盈利次数: {won_trades}")
  99. if total_trades > 0:
  100. print(f"胜率: {won_trades/total_trades:.1%}")
  101. return cerebro, strat
  102. if __name__ == "__main__":
  103. run_simple_regime()