| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393 |
- #!/usr/bin/env python3
- """
- 双均线策略回测入口
- 使用示例:
- python backtest.py --symbol AAPL --start 2020-01-01 --end 2023-12-31
- python backtest.py --symbol BTC-USD --short 10 --long 30
- """
- import argparse
- import sys
- from datetime import datetime, timedelta
- from pathlib import Path
- import numpy as np
- import pandas as pd
- import yfinance as yf
- from dual_ma_strategy import DualMAStrategy, SignalType
- def fetch_data(
- symbol: str,
- start_date: str,
- end_date: str,
- interval: str = '1d'
- ) -> pd.DataFrame:
- """
- 从Yahoo Finance获取历史数据
- 参数:
- symbol: 股票/加密货币代码
- start_date: 开始日期 (YYYY-MM-DD)
- end_date: 结束日期 (YYYY-MM-DD)
- interval: 数据周期 (1d, 1h, 1m等)
- 返回:
- 包含OHLCV数据的DataFrame
- """
- print(f"正在下载 {symbol} 的数据 ({start_date} ~ {end_date})...")
- try:
- ticker = yf.Ticker(symbol)
- df = ticker.history(start=start_date, end=end_date, interval=interval)
- if df.empty:
- raise ValueError(f"未获取到 {symbol} 的数据")
- # 标准化列名
- df.columns = [c.lower().replace(' ', '_') for c in df.columns]
- # 确保必要的列存在
- required_cols = ['open', 'high', 'low', 'close', 'volume']
- for col in required_cols:
- if col not in df.columns:
- raise ValueError(f"数据缺少必要的列: {col}")
- print(f"成功获取 {len(df)} 条数据")
- return df
- except Exception as e:
- print(f"数据获取失败: {e}")
- sys.exit(1)
- def generate_sample_data(
- start_date: str,
- end_date: str,
- n_points: int = 500
- ) -> pd.DataFrame:
- """
- 生成模拟价格数据(用于测试,无需网络)
- 参数:
- start_date: 开始日期
- end_date: 结束日期
- n_points: 数据点数量
- 返回:
- 模拟的OHLCV数据
- """
- print("使用模拟数据进行测试...")
- # 生成日期范围
- dates = pd.date_range(start=start_date, periods=n_points, freq='D')
- # 生成随机游走价格
- np.random.seed(42) # 固定随机种子,保证可重复
- returns = np.random.normal(0.001, 0.02, n_points) # 正态分布收益率
- price = 100 * np.exp(np.cumsum(returns)) # 几何布朗运动
- # 生成OHLCV数据
- df = pd.DataFrame({
- 'open': price * (1 + np.random.normal(0, 0.005, n_points)),
- 'high': price * (1 + np.abs(np.random.normal(0, 0.01, n_points))),
- 'low': price * (1 - np.abs(np.random.normal(0, 0.01, n_points))),
- 'close': price,
- 'volume': np.random.randint(1000000, 10000000, n_points)
- }, index=dates)
- # 确保high是最高,low是最低
- df['high'] = df[['open', 'close', 'high']].max(axis=1)
- df['low'] = df[['open', 'close', 'low']].min(axis=1)
- print(f"生成了 {len(df)} 条模拟数据")
- return df
- def run_backtest(
- df: pd.DataFrame,
- short_window: int,
- long_window: int,
- initial_capital: float,
- verbose: bool = True
- ) -> dict:
- """
- 执行回测
- 参数:
- df: OHLCV数据
- short_window: 短期均线周期
- long_window: 长期均线周期
- initial_capital: 初始资金
- verbose: 是否打印详细信息
- 返回:
- 回测结果字典
- """
- print("\n" + "=" * 60)
- print("开始回测")
- print(f"策略参数: 短期MA={short_window}, 长期MA={long_window}")
- print(f"初始资金: {initial_capital:,.2f}")
- print("=" * 60 + "\n")
- # 初始化策略
- strategy = DualMAStrategy(
- short_window=short_window,
- long_window=long_window,
- initial_capital=initial_capital
- )
- # 生成交易信号
- df_with_signals = strategy.generate_signals(df)
- # 遍历每根K线执行策略
- for timestamp, row in df_with_signals.iterrows():
- # 跳过无效数据
- if pd.isna(row['short_ma']) or pd.isna(row['long_ma']):
- continue
- signal = strategy.on_bar(timestamp, row)
- if signal and verbose:
- emoji = "🟢" if signal.signal_type == SignalType.BUY else "🔴"
- print(f"\n{emoji} 信号触发 [{signal.timestamp}]")
- print(f" 类型: {'买入' if signal.signal_type == SignalType.BUY else '卖出'}")
- print(f" 价格: {signal.price:.2f}")
- print(f" 短期MA: {signal.short_ma:.2f}")
- print(f" 长期MA: {signal.long_ma:.2f}")
- print(f" 原因: {signal.reason}")
- # 最后如果有持仓,强制平仓
- if strategy.position is not None:
- print(f"\n回测结束,强制平仓...")
- last_price = df['close'].iloc[-1]
- last_time = df.index[-1]
- strategy._close_position(last_time, last_price)
- # 获取绩效汇总
- performance = strategy.get_performance_summary()
- return {
- 'strategy': strategy,
- 'performance': performance,
- 'df': df_with_signals,
- 'signals': strategy.signals,
- 'trades': strategy.trades
- }
- def print_results(results: dict):
- """打印回测结果"""
- perf = results['performance']
- trades = results['trades']
- print("\n" + "=" * 60)
- print("回测结果汇总")
- print("=" * 60)
- print(f"\n【交易统计】")
- print(f" 总交易次数: {perf['total_trades']}")
- print(f" 盈利次数: {perf['winning_trades']}")
- print(f" 亏损次数: {perf['losing_trades']}")
- print(f" 胜率: {perf['win_rate']:.2f}%")
- print(f"\n【盈亏统计】")
- print(f" 总盈亏: {perf['total_pnl']:+.2f}")
- print(f" 平均盈亏: {perf['avg_pnl']:+.2f}")
- print(f" 最大盈利: {perf['max_pnl']:+.2f}")
- print(f" 最大亏损: {perf['min_pnl']:+.2f}")
- print(f" 平均盈利: {perf['avg_win']:+.2f}")
- print(f" 平均亏损: {perf['avg_loss']:+.2f}")
- print(f" 盈亏比: {perf['profit_factor']:.2f}")
- print(f"\n【收益表现】")
- print(f" 初始资金: {results['strategy'].initial_capital:,.2f}")
- print(f" 最终权益: {perf['final_equity']:,.2f}")
- print(f" 总收益率: {perf['total_return_pct']:+.2f}%")
- print(f"\n【交易明细】")
- if trades:
- print(f"{'序号':<6}{'入场时间':<22}{'出场时间':<22}{'方向':<6}{'入场价':<10}{'出场价':<10}{'盈亏':<12}{'盈亏%':<10}")
- print("-" * 110)
- for i, trade in enumerate(trades, 1):
- side_str = "多" if trade['side'] == 1 else "空"
- pnl_str = f"{trade['pnl']:+.2f}"
- pnl_pct_str = f"{trade['pnl_pct']:+.2f}%"
- print(f"{i:<6}{str(trade['entry_time']):<22}{str(trade['exit_time']):<22}"
- f"{side_str:<6}{trade['entry_price']:<10.2f}{trade['exit_price']:<10.2f}"
- f"{pnl_str:<12}{pnl_pct_str:<10}")
- def plot_results(results: dict, output_path: str = None):
- """
- 绘制回测结果图表(需要matplotlib)
- 参数:
- results: 回测结果
- output_path: 图表保存路径(可选)
- """
- try:
- import matplotlib.pyplot as plt
- from matplotlib.patches import Rectangle
- except ImportError:
- print("\n提示: 安装 matplotlib 可生成可视化图表")
- print(" pip install matplotlib")
- return
- df = results['df']
- signals = results['signals']
- trades = results['trades']
- fig, axes = plt.subplots(3, 1, figsize=(14, 10), sharex=True,
- gridspec_kw={'height_ratios': [3, 1, 1]})
- # 图1: 价格与均线
- ax1 = axes[0]
- ax1.plot(df.index, df['close'], label='收盘价', linewidth=1.5, color='black', alpha=0.7)
- ax1.plot(df.index, df['short_ma'], label=f"短期MA({results['strategy'].short_window})",
- linewidth=1, color='orange')
- ax1.plot(df.index, df['long_ma'], label=f"长期MA({results['strategy'].long_window})",
- linewidth=1, color='blue')
- # 标记买卖点
- for signal in signals:
- if signal.signal_type == SignalType.BUY:
- ax1.scatter(signal.timestamp, signal.price, marker='^', color='green',
- s=100, zorder=5, label='买入' if signal == signals[0] else "")
- else:
- ax1.scatter(signal.timestamp, signal.price, marker='v', color='red',
- s=100, zorder=5, label='卖出' if signal == signals[0] else "")
- ax1.set_ylabel('价格', fontsize=11)
- ax1.set_title('双均线策略回测结果', fontsize=13, fontweight='bold')
- ax1.legend(loc='upper left', fontsize=9)
- ax1.grid(True, alpha=0.3)
- # 图2: 持仓状态
- ax2 = axes[1]
- position_series = pd.Series(index=df.index, data=0.0)
- for trade in trades:
- mask = (df.index >= trade['entry_time']) & (df.index <= trade['exit_time'])
- position_series[mask] = 1.0 if trade['side'] == 1 else -1.0
- ax2.fill_between(df.index, position_series, alpha=0.3,
- where=position_series > 0, color='green', label='多头持仓')
- ax2.fill_between(df.index, position_series, alpha=0.3,
- where=position_series < 0, color='red', label='空头持仓')
- ax2.set_ylabel('持仓', fontsize=11)
- ax2.set_ylim(-1.5, 1.5)
- ax2.legend(loc='upper left', fontsize=9)
- ax2.grid(True, alpha=0.3)
- # 图3: 累计盈亏
- ax3 = axes[2]
- cumulative_pnl = []
- running_pnl = 0
- trade_idx = 0
- for timestamp in df.index:
- # 检查是否有交易在这一天结束
- for trade in trades:
- if trade['exit_time'] == timestamp:
- running_pnl += trade['pnl']
- cumulative_pnl.append(running_pnl)
- ax3.plot(df.index, cumulative_pnl, color='purple', linewidth=1.5)
- ax3.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
- ax3.fill_between(df.index, cumulative_pnl, 0, alpha=0.3,
- where=[p >= 0 for p in cumulative_pnl], color='green')
- ax3.fill_between(df.index, cumulative_pnl, 0, alpha=0.3,
- where=[p < 0 for p in cumulative_pnl], color='red')
- ax3.set_ylabel('累计盈亏', fontsize=11)
- ax3.set_xlabel('日期', fontsize=11)
- ax3.grid(True, alpha=0.3)
- plt.tight_layout()
- if output_path:
- plt.savefig(output_path, dpi=150, bbox_inches='tight')
- print(f"\n图表已保存至: {output_path}")
- else:
- plt.show()
- def main():
- """主函数"""
- parser = argparse.ArgumentParser(
- description='双均线策略回测工具',
- formatter_class=argparse.RawDescriptionHelpFormatter,
- epilog="""
- 使用示例:
- # 使用真实数据回测
- python backtest.py --symbol AAPL --start 2020-01-01 --end 2023-12-31
- # 自定义均线参数
- python backtest.py --symbol BTC-USD --short 10 --long 50 --capital 50000
- # 使用模拟数据(无需网络)
- python backtest.py --demo
- # 保存图表
- python backtest.py --symbol TSLA --plot result.png
- """
- )
- # 数据参数
- parser.add_argument('--symbol', '-s', type=str, default='AAPL',
- help='股票代码 (默认: AAPL)')
- parser.add_argument('--start', type=str,
- default=(datetime.now() - timedelta(days=3*365)).strftime('%Y-%m-%d'),
- help='开始日期 (YYYY-MM-DD)')
- parser.add_argument('--end', type=str,
- default=datetime.now().strftime('%Y-%m-%d'),
- help='结束日期 (YYYY-MM-DD)')
- parser.add_argument('--interval', '-i', type=str, default='1d',
- help='数据周期: 1d, 1wk, 1mo (默认: 1d)')
- # 策略参数
- parser.add_argument('--short', type=int, default=5,
- help='短期均线周期 (默认: 5)')
- parser.add_argument('--long', type=int, default=20,
- help='长期均线周期 (默认: 20)')
- parser.add_argument('--capital', '-c', type=float, default=100000,
- help='初始资金 (默认: 100000)')
- # 其他选项
- parser.add_argument('--demo', action='store_true',
- help='使用模拟数据进行测试')
- parser.add_argument('--plot', type=str, metavar='PATH',
- help='保存图表到指定路径')
- parser.add_argument('--quiet', '-q', action='store_true',
- help='安静模式,只输出汇总结果')
- args = parser.parse_args()
- # 获取数据
- if args.demo:
- df = generate_sample_data(args.start, args.end)
- else:
- df = fetch_data(args.symbol, args.start, args.end, args.interval)
- # 执行回测
- results = run_backtest(
- df=df,
- short_window=args.short,
- long_window=args.long,
- initial_capital=args.capital,
- verbose=not args.quiet
- )
- # 打印结果
- print_results(results)
- # 绘制图表
- if args.plot:
- plot_results(results, args.plot)
- elif not args.quiet:
- # 询问是否显示图表
- plot_results(results)
- if __name__ == '__main__':
- main()
|