debug_indicators.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334
  1. """
  2. Debug版本 - 检查指标计算
  3. """
  4. import backtrader as bt
  5. import pandas as pd
  6. import numpy as np
  7. class DebugStrategy(bt.Strategy):
  8. params = (('fast', 20), ('slow', 60))
  9. def __init__(self):
  10. self.dataclose = self.datas[0].close
  11. self.sma_fast = bt.indicators.SMA(period=self.p.fast)
  12. self.sma_slow = bt.indicators.SMA(period=self.p.slow)
  13. self.crossover = bt.indicators.CrossOver(self.sma_fast, self.sma_slow)
  14. def next(self):
  15. # 每100天打印一次检查
  16. if len(self) % 100 == 0:
  17. fast_val = self.sma_fast[0]
  18. slow_val = self.sma_slow[0]
  19. fast_str = f'{fast_val:.2f}' if not np.isnan(fast_val) else 'None'
  20. slow_str = f'{slow_val:.2f}' if not np.isnan(slow_val) else 'None'
  21. print(f'Day {len(self)}: Close={self.dataclose[0]:.2f}, FastMA={fast_str}, SlowMA={slow_str}, Cross={self.crossover[0]}')
  22. if __name__ == "__main__":
  23. cerebro = bt.Cerebro()
  24. df = pd.read_csv("chinext50.csv", parse_dates=['datetime'], index_col='datetime')
  25. data = bt.feeds.PandasData(dataname=df)
  26. cerebro.adddata(data)
  27. cerebro.addstrategy(DebugStrategy)
  28. cerebro.run()