backtest.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. #!/usr/bin/env python3
  2. """
  3. 双均线策略回测入口
  4. 使用示例:
  5. python backtest.py --symbol AAPL --start 2020-01-01 --end 2023-12-31
  6. python backtest.py --symbol BTC-USD --short 10 --long 30
  7. """
  8. import argparse
  9. import sys
  10. from datetime import datetime, timedelta
  11. from pathlib import Path
  12. import numpy as np
  13. import pandas as pd
  14. import yfinance as yf
  15. from dual_ma_strategy import DualMAStrategy, SignalType
  16. def fetch_data(
  17. symbol: str,
  18. start_date: str,
  19. end_date: str,
  20. interval: str = '1d'
  21. ) -> pd.DataFrame:
  22. """
  23. 从Yahoo Finance获取历史数据
  24. 参数:
  25. symbol: 股票/加密货币代码
  26. start_date: 开始日期 (YYYY-MM-DD)
  27. end_date: 结束日期 (YYYY-MM-DD)
  28. interval: 数据周期 (1d, 1h, 1m等)
  29. 返回:
  30. 包含OHLCV数据的DataFrame
  31. """
  32. print(f"正在下载 {symbol} 的数据 ({start_date} ~ {end_date})...")
  33. try:
  34. ticker = yf.Ticker(symbol)
  35. df = ticker.history(start=start_date, end=end_date, interval=interval)
  36. if df.empty:
  37. raise ValueError(f"未获取到 {symbol} 的数据")
  38. # 标准化列名
  39. df.columns = [c.lower().replace(' ', '_') for c in df.columns]
  40. # 确保必要的列存在
  41. required_cols = ['open', 'high', 'low', 'close', 'volume']
  42. for col in required_cols:
  43. if col not in df.columns:
  44. raise ValueError(f"数据缺少必要的列: {col}")
  45. print(f"成功获取 {len(df)} 条数据")
  46. return df
  47. except Exception as e:
  48. print(f"数据获取失败: {e}")
  49. sys.exit(1)
  50. def generate_sample_data(
  51. start_date: str,
  52. end_date: str,
  53. n_points: int = 500
  54. ) -> pd.DataFrame:
  55. """
  56. 生成模拟价格数据(用于测试,无需网络)
  57. 参数:
  58. start_date: 开始日期
  59. end_date: 结束日期
  60. n_points: 数据点数量
  61. 返回:
  62. 模拟的OHLCV数据
  63. """
  64. print("使用模拟数据进行测试...")
  65. # 生成日期范围
  66. dates = pd.date_range(start=start_date, periods=n_points, freq='D')
  67. # 生成随机游走价格
  68. np.random.seed(42) # 固定随机种子,保证可重复
  69. returns = np.random.normal(0.001, 0.02, n_points) # 正态分布收益率
  70. price = 100 * np.exp(np.cumsum(returns)) # 几何布朗运动
  71. # 生成OHLCV数据
  72. df = pd.DataFrame({
  73. 'open': price * (1 + np.random.normal(0, 0.005, n_points)),
  74. 'high': price * (1 + np.abs(np.random.normal(0, 0.01, n_points))),
  75. 'low': price * (1 - np.abs(np.random.normal(0, 0.01, n_points))),
  76. 'close': price,
  77. 'volume': np.random.randint(1000000, 10000000, n_points)
  78. }, index=dates)
  79. # 确保high是最高,low是最低
  80. df['high'] = df[['open', 'close', 'high']].max(axis=1)
  81. df['low'] = df[['open', 'close', 'low']].min(axis=1)
  82. print(f"生成了 {len(df)} 条模拟数据")
  83. return df
  84. def run_backtest(
  85. df: pd.DataFrame,
  86. short_window: int,
  87. long_window: int,
  88. initial_capital: float,
  89. verbose: bool = True
  90. ) -> dict:
  91. """
  92. 执行回测
  93. 参数:
  94. df: OHLCV数据
  95. short_window: 短期均线周期
  96. long_window: 长期均线周期
  97. initial_capital: 初始资金
  98. verbose: 是否打印详细信息
  99. 返回:
  100. 回测结果字典
  101. """
  102. print("\n" + "=" * 60)
  103. print("开始回测")
  104. print(f"策略参数: 短期MA={short_window}, 长期MA={long_window}")
  105. print(f"初始资金: {initial_capital:,.2f}")
  106. print("=" * 60 + "\n")
  107. # 初始化策略
  108. strategy = DualMAStrategy(
  109. short_window=short_window,
  110. long_window=long_window,
  111. initial_capital=initial_capital
  112. )
  113. # 生成交易信号
  114. df_with_signals = strategy.generate_signals(df)
  115. # 遍历每根K线执行策略
  116. for timestamp, row in df_with_signals.iterrows():
  117. # 跳过无效数据
  118. if pd.isna(row['short_ma']) or pd.isna(row['long_ma']):
  119. continue
  120. signal = strategy.on_bar(timestamp, row)
  121. if signal and verbose:
  122. emoji = "🟢" if signal.signal_type == SignalType.BUY else "🔴"
  123. print(f"\n{emoji} 信号触发 [{signal.timestamp}]")
  124. print(f" 类型: {'买入' if signal.signal_type == SignalType.BUY else '卖出'}")
  125. print(f" 价格: {signal.price:.2f}")
  126. print(f" 短期MA: {signal.short_ma:.2f}")
  127. print(f" 长期MA: {signal.long_ma:.2f}")
  128. print(f" 原因: {signal.reason}")
  129. # 最后如果有持仓,强制平仓
  130. if strategy.position is not None:
  131. print(f"\n回测结束,强制平仓...")
  132. last_price = df['close'].iloc[-1]
  133. last_time = df.index[-1]
  134. strategy._close_position(last_time, last_price)
  135. # 获取绩效汇总
  136. performance = strategy.get_performance_summary()
  137. return {
  138. 'strategy': strategy,
  139. 'performance': performance,
  140. 'df': df_with_signals,
  141. 'signals': strategy.signals,
  142. 'trades': strategy.trades
  143. }
  144. def print_results(results: dict):
  145. """打印回测结果"""
  146. perf = results['performance']
  147. trades = results['trades']
  148. print("\n" + "=" * 60)
  149. print("回测结果汇总")
  150. print("=" * 60)
  151. print(f"\n【交易统计】")
  152. print(f" 总交易次数: {perf['total_trades']}")
  153. print(f" 盈利次数: {perf['winning_trades']}")
  154. print(f" 亏损次数: {perf['losing_trades']}")
  155. print(f" 胜率: {perf['win_rate']:.2f}%")
  156. print(f"\n【盈亏统计】")
  157. print(f" 总盈亏: {perf['total_pnl']:+.2f}")
  158. print(f" 平均盈亏: {perf['avg_pnl']:+.2f}")
  159. print(f" 最大盈利: {perf['max_pnl']:+.2f}")
  160. print(f" 最大亏损: {perf['min_pnl']:+.2f}")
  161. print(f" 平均盈利: {perf['avg_win']:+.2f}")
  162. print(f" 平均亏损: {perf['avg_loss']:+.2f}")
  163. print(f" 盈亏比: {perf['profit_factor']:.2f}")
  164. print(f"\n【收益表现】")
  165. print(f" 初始资金: {results['strategy'].initial_capital:,.2f}")
  166. print(f" 最终权益: {perf['final_equity']:,.2f}")
  167. print(f" 总收益率: {perf['total_return_pct']:+.2f}%")
  168. print(f"\n【交易明细】")
  169. if trades:
  170. print(f"{'序号':<6}{'入场时间':<22}{'出场时间':<22}{'方向':<6}{'入场价':<10}{'出场价':<10}{'盈亏':<12}{'盈亏%':<10}")
  171. print("-" * 110)
  172. for i, trade in enumerate(trades, 1):
  173. side_str = "多" if trade['side'] == 1 else "空"
  174. pnl_str = f"{trade['pnl']:+.2f}"
  175. pnl_pct_str = f"{trade['pnl_pct']:+.2f}%"
  176. print(f"{i:<6}{str(trade['entry_time']):<22}{str(trade['exit_time']):<22}"
  177. f"{side_str:<6}{trade['entry_price']:<10.2f}{trade['exit_price']:<10.2f}"
  178. f"{pnl_str:<12}{pnl_pct_str:<10}")
  179. def plot_results(results: dict, output_path: str = None):
  180. """
  181. 绘制回测结果图表(需要matplotlib)
  182. 参数:
  183. results: 回测结果
  184. output_path: 图表保存路径(可选)
  185. """
  186. try:
  187. import matplotlib.pyplot as plt
  188. from matplotlib.patches import Rectangle
  189. except ImportError:
  190. print("\n提示: 安装 matplotlib 可生成可视化图表")
  191. print(" pip install matplotlib")
  192. return
  193. df = results['df']
  194. signals = results['signals']
  195. trades = results['trades']
  196. fig, axes = plt.subplots(3, 1, figsize=(14, 10), sharex=True,
  197. gridspec_kw={'height_ratios': [3, 1, 1]})
  198. # 图1: 价格与均线
  199. ax1 = axes[0]
  200. ax1.plot(df.index, df['close'], label='收盘价', linewidth=1.5, color='black', alpha=0.7)
  201. ax1.plot(df.index, df['short_ma'], label=f"短期MA({results['strategy'].short_window})",
  202. linewidth=1, color='orange')
  203. ax1.plot(df.index, df['long_ma'], label=f"长期MA({results['strategy'].long_window})",
  204. linewidth=1, color='blue')
  205. # 标记买卖点
  206. for signal in signals:
  207. if signal.signal_type == SignalType.BUY:
  208. ax1.scatter(signal.timestamp, signal.price, marker='^', color='green',
  209. s=100, zorder=5, label='买入' if signal == signals[0] else "")
  210. else:
  211. ax1.scatter(signal.timestamp, signal.price, marker='v', color='red',
  212. s=100, zorder=5, label='卖出' if signal == signals[0] else "")
  213. ax1.set_ylabel('价格', fontsize=11)
  214. ax1.set_title('双均线策略回测结果', fontsize=13, fontweight='bold')
  215. ax1.legend(loc='upper left', fontsize=9)
  216. ax1.grid(True, alpha=0.3)
  217. # 图2: 持仓状态
  218. ax2 = axes[1]
  219. position_series = pd.Series(index=df.index, data=0.0)
  220. for trade in trades:
  221. mask = (df.index >= trade['entry_time']) & (df.index <= trade['exit_time'])
  222. position_series[mask] = 1.0 if trade['side'] == 1 else -1.0
  223. ax2.fill_between(df.index, position_series, alpha=0.3,
  224. where=position_series > 0, color='green', label='多头持仓')
  225. ax2.fill_between(df.index, position_series, alpha=0.3,
  226. where=position_series < 0, color='red', label='空头持仓')
  227. ax2.set_ylabel('持仓', fontsize=11)
  228. ax2.set_ylim(-1.5, 1.5)
  229. ax2.legend(loc='upper left', fontsize=9)
  230. ax2.grid(True, alpha=0.3)
  231. # 图3: 累计盈亏
  232. ax3 = axes[2]
  233. cumulative_pnl = []
  234. running_pnl = 0
  235. trade_idx = 0
  236. for timestamp in df.index:
  237. # 检查是否有交易在这一天结束
  238. for trade in trades:
  239. if trade['exit_time'] == timestamp:
  240. running_pnl += trade['pnl']
  241. cumulative_pnl.append(running_pnl)
  242. ax3.plot(df.index, cumulative_pnl, color='purple', linewidth=1.5)
  243. ax3.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
  244. ax3.fill_between(df.index, cumulative_pnl, 0, alpha=0.3,
  245. where=[p >= 0 for p in cumulative_pnl], color='green')
  246. ax3.fill_between(df.index, cumulative_pnl, 0, alpha=0.3,
  247. where=[p < 0 for p in cumulative_pnl], color='red')
  248. ax3.set_ylabel('累计盈亏', fontsize=11)
  249. ax3.set_xlabel('日期', fontsize=11)
  250. ax3.grid(True, alpha=0.3)
  251. plt.tight_layout()
  252. if output_path:
  253. plt.savefig(output_path, dpi=150, bbox_inches='tight')
  254. print(f"\n图表已保存至: {output_path}")
  255. else:
  256. plt.show()
  257. def main():
  258. """主函数"""
  259. parser = argparse.ArgumentParser(
  260. description='双均线策略回测工具',
  261. formatter_class=argparse.RawDescriptionHelpFormatter,
  262. epilog="""
  263. 使用示例:
  264. # 使用真实数据回测
  265. python backtest.py --symbol AAPL --start 2020-01-01 --end 2023-12-31
  266. # 自定义均线参数
  267. python backtest.py --symbol BTC-USD --short 10 --long 50 --capital 50000
  268. # 使用模拟数据(无需网络)
  269. python backtest.py --demo
  270. # 保存图表
  271. python backtest.py --symbol TSLA --plot result.png
  272. """
  273. )
  274. # 数据参数
  275. parser.add_argument('--symbol', '-s', type=str, default='AAPL',
  276. help='股票代码 (默认: AAPL)')
  277. parser.add_argument('--start', type=str,
  278. default=(datetime.now() - timedelta(days=3*365)).strftime('%Y-%m-%d'),
  279. help='开始日期 (YYYY-MM-DD)')
  280. parser.add_argument('--end', type=str,
  281. default=datetime.now().strftime('%Y-%m-%d'),
  282. help='结束日期 (YYYY-MM-DD)')
  283. parser.add_argument('--interval', '-i', type=str, default='1d',
  284. help='数据周期: 1d, 1wk, 1mo (默认: 1d)')
  285. # 策略参数
  286. parser.add_argument('--short', type=int, default=5,
  287. help='短期均线周期 (默认: 5)')
  288. parser.add_argument('--long', type=int, default=20,
  289. help='长期均线周期 (默认: 20)')
  290. parser.add_argument('--capital', '-c', type=float, default=100000,
  291. help='初始资金 (默认: 100000)')
  292. # 其他选项
  293. parser.add_argument('--demo', action='store_true',
  294. help='使用模拟数据进行测试')
  295. parser.add_argument('--plot', type=str, metavar='PATH',
  296. help='保存图表到指定路径')
  297. parser.add_argument('--quiet', '-q', action='store_true',
  298. help='安静模式,只输出汇总结果')
  299. args = parser.parse_args()
  300. # 获取数据
  301. if args.demo:
  302. df = generate_sample_data(args.start, args.end)
  303. else:
  304. df = fetch_data(args.symbol, args.start, args.end, args.interval)
  305. # 执行回测
  306. results = run_backtest(
  307. df=df,
  308. short_window=args.short,
  309. long_window=args.long,
  310. initial_capital=args.capital,
  311. verbose=not args.quiet
  312. )
  313. # 打印结果
  314. print_results(results)
  315. # 绘制图表
  316. if args.plot:
  317. plot_results(results, args.plot)
  318. elif not args.quiet:
  319. # 询问是否显示图表
  320. plot_results(results)
  321. if __name__ == '__main__':
  322. main()