backtest.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. import backtrader as bt
  2. import pandas as pd
  3. from datetime import datetime
  4. class SmaCrossStrategy(bt.Strategy):
  5. """双均线交叉策略 - 创业板50示例"""
  6. params = (
  7. ('fast', 20),
  8. ('slow', 60),
  9. ('printlog', False),
  10. )
  11. def __init__(self):
  12. self.dataclose = self.datas[0].close
  13. self.order = None
  14. self.buyprice = None
  15. self.buycomm = None
  16. # 双均线
  17. self.sma_fast = bt.indicators.SimpleMovingAverage(
  18. self.datas[0], period=self.params.fast)
  19. self.sma_slow = bt.indicators.SimpleMovingAverage(
  20. self.datas[0], period=self.params.slow)
  21. # 交叉信号
  22. self.crossover = bt.indicators.CrossOver(self.sma_fast, self.sma_slow)
  23. def notify_order(self, order):
  24. if order.status in [order.Submitted, order.Accepted]:
  25. return
  26. if order.status in [order.Completed]:
  27. if order.isbuy():
  28. if self.params.printlog:
  29. self.log(f'BUY EXECUTED, Price: {order.executed.price:.2f}, '
  30. f'Cost: {order.executed.value:.2f}, '
  31. f'Comm: {order.executed.comm:.2f}')
  32. self.buyprice = order.executed.price
  33. self.buycomm = order.executed.comm
  34. else:
  35. if self.params.printlog:
  36. self.log(f'SELL EXECUTED, Price: {order.executed.price:.2f}, '
  37. f'Cost: {order.executed.value:.2f}, '
  38. f'Comm: {order.executed.comm:.2f}')
  39. elif order.status in [order.Canceled, order.Margin, order.Rejected]:
  40. if self.params.printlog:
  41. self.log('Order Canceled/Margin/Rejected')
  42. self.order = None
  43. def notify_trade(self, trade):
  44. if not trade.isclosed:
  45. return
  46. if self.params.printlog:
  47. self.log(f'OPERATION PROFIT, GROSS: {trade.pnl:.2f}, NET: {trade.pnlcomm:.2f}')
  48. def next(self):
  49. if self.order:
  50. return
  51. # 金叉买入
  52. if self.crossover > 0:
  53. if not self.position:
  54. self.order = self.buy()
  55. if self.params.printlog:
  56. self.log(f'BUY CREATE, {self.dataclose[0]:.2f}')
  57. # 死叉卖出
  58. elif self.crossover < 0:
  59. if self.position:
  60. self.order = self.sell()
  61. if self.params.printlog:
  62. self.log(f'SELL CREATE, {self.dataclose[0]:.2f}')
  63. def log(self, txt, dt=None):
  64. dt = dt or self.datas[0].datetime.date(0)
  65. print(f'{dt.isoformat()} {txt}')
  66. def stop(self):
  67. # 最终收益
  68. roi = (self.broker.getvalue() / self.broker.startingcash - 1) * 100
  69. print(f'\n=== 最终收益: {roi:.2f}% ===')
  70. print(f'初始资金: {self.broker.startingcash:.2f}')
  71. print(f'最终资金: {self.broker.getvalue():.2f}')
  72. def run_backtest(csv_file="chinext50.csv", cash=100000.0, commission=0.001):
  73. """运行回测"""
  74. cerebro = bt.Cerebro()
  75. # 数据
  76. df = pd.read_csv(csv_file, parse_dates=['datetime'], index_col='datetime')
  77. data = bt.feeds.PandasData(dataname=df)
  78. cerebro.adddata(data)
  79. # 策略
  80. cerebro.addstrategy(SmaCrossStrategy, fast=20, slow=60, printlog=True)
  81. # 资金与手续费
  82. cerebro.broker.setcash(cash)
  83. cerebro.broker.setcommission(commission=commission)
  84. # 添加分析器
  85. cerebro.addanalyzer(bt.analyzers.SharpeRatio, _name='sharpe', riskfreerate=0.02)
  86. cerebro.addanalyzer(bt.analyzers.DrawDown, _name='drawdown')
  87. cerebro.addanalyzer(bt.analyzers.Returns, _name='returns')
  88. print(f'初始资金: {cerebro.broker.getvalue():.2f}')
  89. # 运行
  90. results = cerebro.run()
  91. strat = results[0]
  92. # 输出指标
  93. print(f'\n=== 回测指标 ===')
  94. print(f"年化收益: {strat.analyzers.returns.get_analysis()['rnorm100']:.2f}%")
  95. print(f"夏普比率: {strat.analyzers.sharpe.get_analysis()['sharperatio']:.3f}")
  96. print(f"最大回撤: {strat.analyzers.drawdown.get_analysis()['max']['drawdown']:.2f}%")
  97. return cerebro, strat
  98. if __name__ == "__main__":
  99. # 先运行 fetch_data.py 获取数据
  100. cerebro, strat = run_backtest()
  101. # cerebro.plot() # 如需图表,取消注释