| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- 创业板50指数点位量化交易策略回测 - 简化版(快速演示)
- 训练集:2018-2023 | 验证集:2024-2025
- """
- import pandas as pd
- import numpy as np
- import matplotlib
- matplotlib.use('Agg')
- import matplotlib.pyplot as plt
- from datetime import datetime
- import warnings
- warnings.filterwarnings('ignore')
# Chart defaults: set matplotlib's global default font size for the figures below.
plt.rcParams.update({'font.size': 10})
- # ==================== 1. 数据加载 ====================
def load_real_data(path='cyb50_baostock.csv'):
    """Load ChiNext 50 (创业板50) index daily bars from a CSV file.

    The file must contain 'date', 'open', 'high', 'low', 'close' and
    'volume' columns. Price/volume columns are coerced to numbers; rows whose
    close cannot be parsed are dropped, because a NaN close would otherwise
    turn every downstream NAV value (strategy and index) into NaN.

    Args:
        path: CSV file path (defaults to 'cyb50_baostock.csv').

    Returns:
        DataFrame indexed by 'date' in ascending order.

    Raises:
        ValueError: if no row with a valid close remains.
    """
    df = pd.read_csv(path)
    df['date'] = pd.to_datetime(df['date'])
    df = df.set_index('date').sort_index()

    # Convert to numeric; unparsable values become NaN.
    for col in ['open', 'high', 'low', 'close', 'volume']:
        df[col] = pd.to_numeric(df[col], errors='coerce')

    # A NaN close breaks close-to-close returns, so drop such rows.
    df = df.dropna(subset=['close'])
    if df.empty:
        raise ValueError(f"no valid price rows in {path}")

    print(f"真实数据加载成功: {df.index[0].date()} ~ {df.index[-1].date()}")
    return df
- # ==================== 2. 策略类 ====================
class SimpleCYBStrategy:
    """Simplified strategy: dual moving-average trend filter + volatility control.

    Each call to ``generate_signal`` looks at the price history up to the
    current bar and returns a target position. The instance is stateful: it
    remembers the last target position and the entry price used by the
    stop-loss check, so use a fresh instance for each independent backtest.
    """

    # Default parameters; any subset can be overridden via ``params``.
    DEFAULT_PARAMS = {
        'fast_ma': 20,            # fast moving-average window (bars)
        'slow_ma': 60,            # slow moving-average window (bars)
        'volatility_period': 20,  # ATR averaging window (bars)
        'max_position': 1.0,      # position used in the BULL state
        'stop_loss': 0.10,        # go flat when close falls this fraction below entry
        'vol_low': 4.0,           # ATR/close (%) below which a full position is allowed
        'vol_high': 6.0,          # ATR/close (%) above which the strategy goes flat
    }

    def __init__(self, params=None):
        """Create a strategy.

        Args:
            params: optional dict overriding some or all of DEFAULT_PARAMS;
                missing keys fall back to the defaults.
        """
        self.params = {**self.DEFAULT_PARAMS, **(params or {})}
        self.position = 0
        self.entry_price = None

    def generate_signal(self, data):
        """Compute the target position for the last bar of ``data``.

        Args:
            data: DataFrame with 'close', 'high' and 'low' columns, ordered
                oldest to newest; the last row is the current bar.

        Returns:
            (target_position, state) where state is one of "WARMUP", "BULL",
            "BEAR", "HIGH_VOL", "OSCILLATE" or "STOP_LOSS".

        Side effects:
            Updates ``self.position`` and ``self.entry_price``.
        """
        close = data['close']
        high = data['high']
        low = data['low']
        p = self.params
        last_close = close.iloc[-1]

        # Moving averages (NaN until the window is filled).
        ma_fast = close.rolling(p['fast_ma']).mean().iloc[-1]
        ma_slow = close.rolling(p['slow_ma']).mean().iloc[-1]

        # Volatility: average true range as a percentage of the last close.
        prev_close = close.shift(1)
        tr = pd.concat([
            high - low,
            (high - prev_close).abs(),
            (low - prev_close).abs(),
        ], axis=1).max(axis=1)
        atr = tr.rolling(p['volatility_period']).mean().iloc[-1]
        vol_pct = atr / last_close * 100

        if pd.isna(ma_fast) or pd.isna(ma_slow) or pd.isna(vol_pct):
            # Not enough history: every NaN comparison is False, which would
            # otherwise fall through to a half position. Stay flat instead.
            target_pos = 0
            state = "WARMUP"
        else:
            trend_up = last_close > ma_fast and ma_fast > ma_slow
            trend_down = last_close < ma_fast and ma_fast < ma_slow

            if trend_up and vol_pct < p['vol_low']:
                target_pos = p['max_position']  # full position
                state = "BULL"
            elif trend_down or vol_pct > p['vol_high']:
                target_pos = 0  # flat
                state = "BEAR" if trend_down else "HIGH_VOL"
            else:
                target_pos = p['max_position'] * 0.5  # half position
                state = "OSCILLATE"

            # Stop-loss overrides the regime decision.
            if self.position > 0 and self.entry_price is not None:
                drawdown = (last_close - self.entry_price) / self.entry_price
                if drawdown < -p['stop_loss']:
                    target_pos = 0
                    state = "STOP_LOSS"

        # Entry price is recorded when entering from flat and cleared when flat.
        if target_pos > 0 and self.position == 0:
            self.entry_price = last_close
        if target_pos == 0:
            self.entry_price = None

        self.position = target_pos
        return target_pos, state
- # ==================== 3. 回测引擎 ====================
def backtest(data, strategy, start_date=None, end_date=None, warmup=60):
    """Run a daily close-to-close backtest of ``strategy`` over ``data``.

    The position decided at bar ``i`` (from data up to and including bar
    ``i``) earns the close-to-close return of bar ``i + 1``, so signals never
    see the return they are applied to. No transaction costs are modelled.

    Args:
        data: price DataFrame indexed by date with at least a 'close' column,
            plus whatever columns ``strategy.generate_signal`` reads.
        strategy: object with ``generate_signal(df) -> (position, state)``.
        start_date: optional inclusive lower date bound.
        end_date: optional inclusive upper date bound.
        warmup: number of leading bars in the selected range that are only
            used as indicator history before the first signal.

    Returns:
        (results, metrics): per-bar DataFrame indexed by date with columns
        'position', 'nav', 'state', 'close' and 'index_nav', and the dict
        returned by ``calculate_metrics``.

    Raises:
        ValueError: if ``warmup`` is negative or the selected range has no
            bars left after the warmup.
    """
    if warmup < 0:
        raise ValueError(f"warmup must be >= 0, got {warmup}")

    if start_date:
        data = data[data.index >= start_date]
    if end_date:
        data = data[data.index <= end_date]

    # Without this check an empty range surfaces later as a cryptic KeyError
    # from set_index('date') on an empty DataFrame.
    if len(data) <= warmup:
        raise ValueError(
            f"backtest needs more than warmup={warmup} bars, "
            f"got {len(data)} in the selected date range"
        )

    closes = data['close']
    results = []
    nav = 1.0
    prev_position = 0

    for i in range(warmup, len(data)):
        position, state = strategy.generate_signal(data.iloc[:i + 1])

        # Apply the previous bar's position to this bar's return.
        if i > warmup:
            daily_return = closes.iloc[i] / closes.iloc[i - 1] - 1
            nav *= 1 + daily_return * prev_position

        results.append({
            'date': data.index[i],
            'position': position,
            'nav': nav,
            'state': state,
            'close': closes.iloc[i],
        })
        prev_position = position

    df = pd.DataFrame(results).set_index('date')
    # Buy-and-hold benchmark normalised to 1.0 at the first evaluated bar.
    df['index_nav'] = df['close'] / df['close'].iloc[0]

    metrics = calculate_metrics(df['nav'], df['index_nav'])
    return df, metrics
def calculate_metrics(strategy_nav, index_nav, risk_free=0.03, periods_per_year=252):
    """Compute performance metrics of a strategy NAV against a benchmark NAV.

    Args:
        strategy_nav: strategy net-asset-value series, one value per bar.
        index_nav: benchmark NAV series aligned with ``strategy_nav``.
        risk_free: annual risk-free rate subtracted in the Sharpe ratio.
        periods_per_year: bars per year used for annualisation.

    Returns:
        dict with 'annual_return', 'index_annual', 'excess_annual',
        'max_drawdown', 'sharpe', 'calmar', 'win_rate', 'total_return' and
        'index_return'. Returns are measured relative to each series' first
        value.
    """
    s_returns = strategy_nav.pct_change().dropna()

    # N NAV points span N-1 return periods; annualise over the periods
    # (floored at 1 so a single-point series does not divide by zero).
    periods = max(len(strategy_nav) - 1, 1)
    year_fraction_inv = periods_per_year / periods

    total_return = strategy_nav.iloc[-1] / strategy_nav.iloc[0] - 1
    annual_return = (1 + total_return) ** year_fraction_inv - 1

    index_return = index_nav.iloc[-1] / index_nav.iloc[0] - 1
    index_annual = (1 + index_return) ** year_fraction_inv - 1

    # Maximum drawdown from the running peak.
    running_max = strategy_nav.expanding().max()
    max_dd = ((strategy_nav - running_max) / running_max).min()

    # Annualised volatility and Sharpe; NaN volatility (<2 returns) fails
    # the `> 0` test and yields 0.
    volatility = s_returns.std() * np.sqrt(periods_per_year)
    sharpe = (annual_return - risk_free) / volatility if volatility > 0 else 0

    # Calmar ratio
    calmar = annual_return / abs(max_dd) if max_dd != 0 else 0

    # Win rate over bars with non-zero return only: flat (zero-position)
    # bars are neither wins nor losses and would otherwise dilute the rate.
    active = s_returns[s_returns != 0]
    win_rate = float((active > 0).mean()) if len(active) else 0.0

    return {
        'annual_return': annual_return,
        'index_annual': index_annual,
        'excess_annual': annual_return - index_annual,
        'max_drawdown': max_dd,
        'sharpe': sharpe,
        'calmar': calmar,
        'win_rate': win_rate,
        'total_return': total_return,
        'index_return': index_return,
    }
- # ==================== 4. 可视化 ====================
def plot_results(results, title, filename):
    """Save a three-panel backtest chart (NAV vs index, position, drawdown).

    Args:
        results: DataFrame produced by ``backtest``, indexed by date, with
            'nav', 'index_nav' and 'position' columns.
        title: text prepended to the NAV panel title.
        filename: output image path passed to ``savefig``.
    """
    fig, (ax_nav, ax_pos, ax_dd) = plt.subplots(3, 1, figsize=(12, 9))
    try:
        # NAV curves
        ax_nav.plot(results.index, results['nav'], label='Strategy', linewidth=2, color='blue')
        ax_nav.plot(results.index, results['index_nav'], label='Index', linewidth=1, color='gray', alpha=0.7)
        ax_nav.set_title(f'{title} - NAV Comparison')
        ax_nav.set_ylabel('NAV')
        ax_nav.legend()
        ax_nav.grid(True, alpha=0.3)

        # Position; the upper limit grows if any position exceeds 1.0.
        ax_pos.fill_between(results.index, 0, results['position'], alpha=0.3, color='green')
        ax_pos.set_ylabel('Position')
        ax_pos.set_ylim(0, max(1.1, results['position'].max() * 1.1))
        ax_pos.grid(True, alpha=0.3)

        # Drawdown from the running NAV peak
        running_max = results['nav'].expanding().max()
        drawdown = (results['nav'] - running_max) / running_max
        ax_dd.fill_between(results.index, drawdown, 0, alpha=0.3, color='red')
        ax_dd.set_ylabel('Drawdown')
        ax_dd.set_xlabel('Date')
        ax_dd.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(filename, dpi=150, bbox_inches='tight')
    finally:
        # pyplot keeps every open figure alive; close it so repeated calls
        # do not accumulate figures.
        plt.close(fig)
    print(f" 图表已保存: {filename}")
- # ==================== 5. 主程序 ====================
def _print_metrics(label, metrics):
    """Print the metric summary for one backtest segment.

    Args:
        label: segment name shown in the header (e.g. "训练集").
        metrics: dict returned by ``calculate_metrics``.
    """
    print(f"\n {label}表现:")
    print(f" - 策略年化收益: {metrics['annual_return']*100:>7.2f}%")
    print(f" - 指数年化收益: {metrics['index_annual']*100:>7.2f}%")
    print(f" - 超额收益: {metrics['excess_annual']*100:>7.2f}%")
    print(f" - 最大回撤: {metrics['max_drawdown']*100:>7.2f}%")
    print(f" - 夏普比率: {metrics['sharpe']:>7.2f}")
    print(f" - 卡玛比率: {metrics['calmar']:>7.2f}")
    print(f" - 胜率: {metrics['win_rate']*100:>7.1f}%")


def main():
    """Run the training (2018-2023) and validation (2024-2025) backtests and report results."""
    print("="*70)
    print("创业板50指数量化交易策略回测 - 简化版")
    print("="*70)

    # Load real data
    print("\n[1] 加载真实数据...")
    data = load_real_data()
    print(f" 数据区间: {data.index[0].date()} ~ {data.index[-1].date()}")
    print(f" 共 {len(data)} 个交易日")

    # Training phase
    print("\n[2] 训练阶段 (2018-2023)...")
    strategy = SimpleCYBStrategy()
    train_results, train_metrics = backtest(data, strategy,
                                            start_date='2018-01-01',
                                            end_date='2023-12-31')
    _print_metrics("训练集", train_metrics)
    plot_results(train_results, "Training Set (2018-2023)", "train_results.png")

    # Validation phase: a fresh instance with the same (default) parameters,
    # so no position/entry-price state carries over from the training run.
    print("\n[3] 验证阶段 (2024-2025)...")
    strategy_val = SimpleCYBStrategy()
    val_results, val_metrics = backtest(data, strategy_val,
                                        start_date='2024-01-01',
                                        end_date='2025-12-31')
    _print_metrics("验证集", val_metrics)
    plot_results(val_results, "Validation Set (2024-2025)", "val_results.png")

    # Overfitting check
    print("\n[4] 过拟合检测:")
    train_sharpe = train_metrics['sharpe']
    val_sharpe = val_metrics['sharpe']
    if train_sharpe <= 0:
        # Relative decay is undefined at 0 and sign-flipped for a negative
        # training Sharpe (a worse validation Sharpe would look like "no decay").
        print(" ⚠️ 训练集夏普比率非正,无法评估过拟合")
    else:
        sharpe_decay = (train_sharpe - val_sharpe) / train_sharpe
        print(f" 夏普比率衰减: {sharpe_decay*100:.1f}%")

        if sharpe_decay > 0.5:
            print(" ⚠️ 警告:可能存在严重过拟合")
        elif sharpe_decay > 0.3:
            print(" ⚠️ 注意:轻度过拟合")
        else:
            print(" ✓ 无过拟合,策略稳健")

    # Summary
    print("\n" + "="*70)
    print("回测完成")
    print("="*70)
    print("\n输出文件:")
    print(" - train_results.png (训练集图表)")
    print(" - val_results.png (验证集图表)")


if __name__ == "__main__":
    main()
|