#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Backtesting framework for a quantitative timing strategy on the ChiNext 50
(CYB50) index level.

Training set: 2017-2023 | Validation set: 2024-2025
"""
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
import warnings

# NOTE(review): this silences every warning process-wide, including pandas
# deprecation notices; narrow the filter if warnings ever need to be seen.
warnings.filterwarnings('ignore')

# Plot font setup. DejaVu Sans is the configured family, so chart titles and
# axis labels in this module are written in English; the second setting makes
# the minus sign on axis ticks render as a plain ASCII hyphen.
plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
  16. # ==================== 1. 数据加载 ====================
  17. def load_real_data():
  18. """加载真实数据 - cyb50_baostock.csv"""
  19. df = pd.read_csv('cyb50_baostock.csv')
  20. df['date'] = pd.to_datetime(df['date'])
  21. df = df.set_index('date').sort_index()
  22. # 转换数据类型
  23. for col in ['open', 'high', 'low', 'close', 'volume']:
  24. df[col] = pd.to_numeric(df[col], errors='coerce')
  25. print(f"真实数据加载成功: {df.index[0].date()} ~ {df.index[-1].date()}")
  26. return df
  27. # ==================== 2. 技术指标计算 ====================
  28. def calculate_atr(high, low, close, period=20):
  29. """计算ATR(平均真实波幅)"""
  30. tr1 = high - low
  31. tr2 = abs(high - close.shift(1))
  32. tr3 = abs(low - close.shift(1))
  33. tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
  34. atr = tr.rolling(window=period).mean()
  35. return atr
  36. def calculate_rsrs(high, low, n=20, m=250):
  37. """
  38. 计算RSRS指标(阻力支撑相对强度)- 优化版
  39. 返回: (rsrs_score, r_squared)
  40. """
  41. # 使用rolling计算斜率和R²
  42. def rolling_beta(x):
  43. if len(x) < n or np.std(x[:n//2]) == 0:
  44. return np.nan
  45. low_vals = x[:n//2]
  46. high_vals = x[n//2:]
  47. if np.std(low_vals) == 0:
  48. return 0
  49. beta = np.corrcoef(low_vals, high_vals)[0,1] * np.std(high_vals) / np.std(low_vals)
  50. return beta
  51. # 简化计算:使用 rolling.apply
  52. slopes = pd.Series(index=high.index, dtype=float)
  53. r2s = pd.Series(index=high.index, dtype=float)
  54. for i in range(n-1, len(high)):
  55. low_window = low.iloc[i-n+1:i+1].values
  56. high_window = high.iloc[i-n+1:i+1].values
  57. if np.std(low_window) > 0:
  58. beta = np.corrcoef(low_window, high_window)[0,1] * np.std(high_window) / np.std(low_window)
  59. # R²
  60. y_pred = np.mean(high_window) + beta * (low_window - np.mean(low_window))
  61. ss_res = np.sum((high_window - y_pred) ** 2)
  62. ss_tot = np.sum((high_window - np.mean(high_window)) ** 2)
  63. r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
  64. slopes.iloc[i] = beta
  65. r2s.iloc[i] = r2
  66. # 计算标准分(滚动M日)
  67. rsrs = pd.Series(index=high.index, dtype=float)
  68. for i in range(m+n-2, len(slopes)):
  69. slope_window = slopes.iloc[i-m+1:i+1]
  70. if slope_window.std() > 0:
  71. zscore = (slopes.iloc[i] - slope_window.mean()) / slope_window.std()
  72. rsrs.iloc[i] = zscore * r2s.iloc[i]
  73. return rsrs, r2s
  74. def calculate_rsi(close, period=14):
  75. """计算RSI指标"""
  76. delta = close.diff()
  77. gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
  78. loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
  79. rs = gain / loss
  80. rsi = 100 - (100 / (1 + rs))
  81. return rsi
  82. # ==================== 3. 市场状态判断 ====================
  83. def classify_state(close, ma20, ma60, ma120, atr_percent, rsrs, close_series):
  84. """
  85. 判断市场状态:BULL(牛市)/ BEAR(熊市)/ OSCILLATE(震荡)
  86. close_series: 用于计算涨跌幅的收盘价序列
  87. """
  88. # 趋势得分
  89. trend_score = 0
  90. if close > ma20: trend_score += 1
  91. if ma20 > ma60: trend_score += 1
  92. if ma60 > ma120: trend_score += 1
  93. # 波动率
  94. high_volatility = atr_percent > 5
  95. # 熔断检测(极端情况)
  96. daily_return = close_series.pct_change().iloc[-1] if len(close_series) > 1 else 0
  97. crash = daily_return < -0.07
  98. if crash or (trend_score <= 1 and high_volatility and close < ma60):
  99. return "BEAR"
  100. elif trend_score >= 2 and not high_volatility:
  101. return "BULL"
  102. else:
  103. return "OSCILLATE"
  104. # ==================== 4. 策略核心 ====================
  105. class CYB50Strategy:
  106. """创业板50指数交易策略"""
  107. def __init__(self, params=None):
  108. # 默认参数
  109. self.params = params or {
  110. 'rsrs_n': 20,
  111. 'rsrs_m': 250,
  112. 'bull_buy': 0.5,
  113. 'bull_sell': -0.7,
  114. 'bear_buy': 1.5,
  115. 'bull_max': 1.0,
  116. 'osc_max': 0.6,
  117. 'stop_loss': 0.10,
  118. 'min_change': 0.20
  119. }
  120. self.current_position = 0
  121. self.state_history = []
  122. self.entry_price = None
  123. def generate_signal(self, data):
  124. """生成交易信号"""
  125. close = data['close']
  126. high = data['high']
  127. low = data['low']
  128. # 计算指标
  129. rsrs, r2 = calculate_rsrs(high, low,
  130. self.params['rsrs_n'],
  131. self.params['rsrs_m'])
  132. ma20 = close.rolling(20).mean()
  133. ma60 = close.rolling(60).mean()
  134. ma120 = close.rolling(120).mean()
  135. atr = calculate_atr(high, low, close, 20)
  136. atr_percent = atr / close * 100
  137. # 获取当前值
  138. curr_rsrs = rsrs.iloc[-1]
  139. curr_close = close.iloc[-1]
  140. curr_ma20 = ma20.iloc[-1]
  141. curr_ma60 = ma60.iloc[-1]
  142. curr_ma120 = ma120.iloc[-1]
  143. curr_atr_pct = atr_percent.iloc[-1]
  144. # 检查是否有足够数据
  145. if pd.isna(curr_rsrs):
  146. return 0, "INIT"
  147. # 判断市场状态
  148. state = classify_state(curr_close, curr_ma20, curr_ma60,
  149. curr_ma120, curr_atr_pct, curr_rsrs, close)
  150. # 状态防抖(连续3日确认,极端情况除外)
  151. self.state_history.append(state)
  152. if len(self.state_history) >= 3:
  153. # 熔断检测:单日大跌立即转熊
  154. daily_return = close.pct_change().iloc[-1]
  155. if daily_return < -0.07:
  156. state = "BEAR"
  157. elif len(self.state_history) >= 3:
  158. # 正常防抖
  159. recent_states = self.state_history[-3:]
  160. if len(set(recent_states)) > 1:
  161. state = self.state_history[-2] if len(self.state_history) >= 2 else state
  162. # 根据状态确定仓位
  163. target_pos = self._calculate_position(state, curr_rsrs, curr_atr_pct)
  164. # 止损检查
  165. if self.entry_price is not None and self.current_position > 0:
  166. current_drawdown = (curr_close - self.entry_price) / self.entry_price
  167. if current_drawdown < -self.params['stop_loss']:
  168. target_pos = 0
  169. self.entry_price = None
  170. # 最小调仓幅度过滤
  171. if abs(target_pos - self.current_position) < self.params['min_change']:
  172. target_pos = self.current_position
  173. # 更新入场价
  174. if target_pos > 0 and self.current_position == 0:
  175. self.entry_price = curr_close
  176. elif target_pos == 0:
  177. self.entry_price = None
  178. self.current_position = target_pos
  179. return target_pos, state
  180. def _calculate_position(self, state, rsrs, atr_percent):
  181. """根据状态和指标计算目标仓位"""
  182. p = self.params
  183. if state == "BULL":
  184. if rsrs > p['bull_buy']:
  185. pos = p['bull_max']
  186. elif rsrs < p['bull_sell']:
  187. pos = 0
  188. else:
  189. pos = p['bull_max'] * 0.5
  190. elif state == "BEAR":
  191. # 熊市:空仓为主
  192. if rsrs < -p['bear_buy']:
  193. pos = 0.1 # 极端超卖,10%仓位博反弹
  194. else:
  195. pos = 0
  196. else: # OSCILLATE
  197. if rsrs > 0.7:
  198. pos = p['osc_max']
  199. elif rsrs < -0.7:
  200. pos = 0
  201. else:
  202. pos = p['osc_max'] * 0.5
  203. # 波动率调整
  204. if atr_percent > 5:
  205. pos *= 0.6
  206. return np.clip(pos, 0, 1)
  207. # ==================== 5. 回测引擎 ====================
  208. def backtest(data, strategy, initial_capital=1000000, start_date=None, end_date=None):
  209. """
  210. 回测引擎
  211. """
  212. # 数据切片
  213. if start_date:
  214. data = data[data.index >= start_date]
  215. if end_date:
  216. data = data[data.index <= end_date]
  217. dates = []
  218. positions = []
  219. navs = []
  220. states = []
  221. capital = initial_capital
  222. current_nav = 1.0
  223. # 跳过前250日(warm-up)
  224. start_idx = 250
  225. for i in range(start_idx, len(data)):
  226. curr_data = data.iloc[:i+1]
  227. curr_date = data.index[i]
  228. # 获取信号
  229. position, state = strategy.generate_signal(curr_data)
  230. # 计算当日收益(使用前一日仓位)
  231. if i > start_idx:
  232. daily_return = data['close'].iloc[i] / data['close'].iloc[i-1] - 1
  233. prev_position = positions[-1] if positions else 0
  234. strategy_return = daily_return * prev_position
  235. current_nav *= (1 + strategy_return)
  236. dates.append(curr_date)
  237. positions.append(position)
  238. navs.append(current_nav)
  239. states.append(state)
  240. # 构建结果DataFrame
  241. results = pd.DataFrame({
  242. 'date': dates,
  243. 'position': positions,
  244. 'nav': navs,
  245. 'state': states
  246. }).set_index('date')
  247. # 计算指数基准
  248. index_data = data.iloc[start_idx:].copy()
  249. results['index_close'] = index_data['close']
  250. results['index_nav'] = results['index_close'] / results['index_close'].iloc[0]
  251. # 计算指标
  252. metrics = calculate_metrics(results['nav'], results['index_nav'])
  253. return results, metrics
  254. def calculate_metrics(strategy_nav, index_nav):
  255. """计算回测指标"""
  256. # 收益率
  257. total_return = strategy_nav.iloc[-1] - 1
  258. days = len(strategy_nav)
  259. annual_return = (1 + total_return) ** (252 / days) - 1
  260. # 指数收益
  261. index_return = index_nav.iloc[-1] - 1
  262. index_annual = (1 + index_return) ** (252 / days) - 1
  263. # 最大回撤
  264. running_max = strategy_nav.expanding().max()
  265. drawdown = (strategy_nav - running_max) / running_max
  266. max_drawdown = drawdown.min()
  267. # 波动率
  268. daily_returns = strategy_nav.pct_change().dropna()
  269. volatility = daily_returns.std() * np.sqrt(252)
  270. # 夏普比率(假设无风险利率3%)
  271. excess_return = annual_return - 0.03
  272. sharpe = excess_return / volatility if volatility > 0 else 0
  273. # 卡玛比率
  274. calmar = annual_return / abs(max_drawdown) if max_drawdown != 0 else 0
  275. # 胜率
  276. positive_days = (daily_returns > 0).sum()
  277. total_days = len(daily_returns)
  278. win_rate = positive_days / total_days
  279. # Beta
  280. index_returns = index_nav.pct_change().dropna()
  281. covariance = daily_returns.cov(index_returns)
  282. variance = index_returns.var()
  283. beta = covariance / variance if variance > 0 else 1
  284. # 年化超额收益
  285. excess_annual = annual_return - index_annual
  286. return {
  287. 'total_return': total_return,
  288. 'annual_return': annual_return,
  289. 'index_return': index_return,
  290. 'index_annual': index_annual,
  291. 'excess_annual': excess_annual,
  292. 'max_drawdown': max_drawdown,
  293. 'volatility': volatility,
  294. 'sharpe': sharpe,
  295. 'calmar': calmar,
  296. 'win_rate': win_rate,
  297. 'beta': beta,
  298. 'trading_days': days
  299. }
  300. # ==================== 6. 参数优化 ====================
  301. def grid_search(data, param_grid):
  302. """网格搜索最优参数 - 简化版"""
  303. best_score = -999
  304. best_params = None
  305. best_metrics = None
  306. # 只测试1组参数(加速演示)
  307. test_params = [
  308. {'rsrs_n': 20, 'rsrs_m': 250, 'bull_buy': 0.5, 'bull_sell': -0.7,
  309. 'bear_buy': 1.5, 'bull_max': 1.0, 'osc_max': 0.6, 'stop_loss': 0.10, 'min_change': 0.20},
  310. ]
  311. for params in test_params:
  312. print(f"测试参数: {params}")
  313. strategy = CYB50Strategy(params)
  314. results, metrics = backtest(data, strategy, start_date='2018-02-01', end_date='2023-12-31')
  315. # 综合评分
  316. score = metrics['sharpe'] * 0.4 + metrics['calmar'] * 0.4 + metrics['excess_annual'] * 2
  317. print(f" 年化: {metrics['annual_return']*100:.1f}%, 回撤: {metrics['max_drawdown']*100:.1f}%, 夏普: {metrics['sharpe']:.2f}, 评分: {score:.2f}")
  318. if score > best_score and metrics['max_drawdown'] > -0.40:
  319. best_score = score
  320. best_params = params
  321. best_metrics = metrics
  322. return best_params, best_metrics
  323. # ==================== 7. 可视化 ====================
  324. def plot_results(results, title="Backtest Results"):
  325. """绘制回测结果"""
  326. fig, axes = plt.subplots(3, 1, figsize=(12, 10))
  327. # 净值曲线
  328. ax1 = axes[0]
  329. ax1.plot(results.index, results['nav'], label='Strategy', linewidth=2)
  330. ax1.plot(results.index, results['index_nav'], label='Index', linewidth=1, alpha=0.7)
  331. ax1.set_title(f'{title} - NAV')
  332. ax1.set_ylabel('NAV')
  333. ax1.legend()
  334. ax1.grid(True, alpha=0.3)
  335. # 仓位变化
  336. ax2 = axes[1]
  337. ax2.fill_between(results.index, 0, results['position'], alpha=0.3, label='Position')
  338. ax2.set_ylabel('Position')
  339. ax2.set_ylim(0, 1.1)
  340. ax2.legend()
  341. ax2.grid(True, alpha=0.3)
  342. # 回撤
  343. ax3 = axes[2]
  344. running_max = results['nav'].expanding().max()
  345. drawdown = (results['nav'] - running_max) / running_max
  346. ax3.fill_between(results.index, drawdown, 0, alpha=0.3, color='red')
  347. ax3.set_ylabel('Drawdown')
  348. ax3.set_xlabel('Date')
  349. ax3.grid(True, alpha=0.3)
  350. plt.tight_layout()
  351. return fig
  352. # ==================== 8. 主程序 ====================
  353. def main():
  354. print("="*60)
  355. print("创业板50指数量化交易策略回测")
  356. print("="*60)
  357. # 加载真实数据
  358. print("\n[1] 加载真实数据...")
  359. data = load_real_data()
  360. print(f"数据区间: {data.index[0]} ~ {data.index[-1]}, 共{len(data)}个交易日")
  361. # 划分训练集和验证集
  362. train_end = '2023-12-31'
  363. val_start = '2024-01-01'
  364. # 训练阶段(参数优化)
  365. print("\n[2] 训练阶段:参数优化 (2018-2023)...")
  366. best_params, train_metrics = grid_search(data, None)
  367. print(f"\n最优参数:")
  368. for k, v in best_params.items():
  369. print(f" {k}: {v}")
  370. print(f"\n训练集表现 (2018-2023):")
  371. print(f" 策略年化收益: {train_metrics['annual_return']*100:.2f}%")
  372. print(f" 指数年化收益: {train_metrics['index_annual']*100:.2f}%")
  373. print(f" 超额收益: {train_metrics['excess_annual']*100:.2f}%")
  374. print(f" 最大回撤: {train_metrics['max_drawdown']*100:.2f}%")
  375. print(f" 夏普比率: {train_metrics['sharpe']:.2f}")
  376. print(f" 卡玛比率: {train_metrics['calmar']:.2f}")
  377. print(f" 胜率: {train_metrics['win_rate']*100:.1f}%")
  378. print(f" Beta: {train_metrics['beta']:.2f}")
  379. # 使用最优参数回测训练集(获取完整结果)
  380. strategy = CYB50Strategy(best_params)
  381. train_results, _ = backtest(data, strategy, start_date='2018-02-01', end_date=train_end)
  382. # 验证阶段(样本外)
  383. print(f"\n[3] 验证阶段:样本外测试 (2024-2025)...")
  384. strategy_val = CYB50Strategy(best_params)
  385. val_results, val_metrics = backtest(data, strategy_val, start_date=val_start, end_date='2025-12-31')
  386. print(f"\n验证集表现 (2024-2025):")
  387. print(f" 策略年化收益: {val_metrics['annual_return']*100:.2f}%")
  388. print(f" 指数年化收益: {val_metrics['index_annual']*100:.2f}%")
  389. print(f" 超额收益: {val_metrics['excess_annual']*100:.2f}%")
  390. print(f" 最大回撤: {val_metrics['max_drawdown']*100:.2f}%")
  391. print(f" 夏普比率: {val_metrics['sharpe']:.2f}")
  392. print(f" 卡玛比率: {val_metrics['calmar']:.2f}")
  393. # 过拟合检测
  394. sharpe_decay = (train_metrics['sharpe'] - val_metrics['sharpe']) / train_metrics['sharpe'] if train_metrics['sharpe'] != 0 else 0
  395. print(f"\n[4] 过拟合检测:")
  396. print(f" 夏普比率衰减: {sharpe_decay*100:.1f}%")
  397. if sharpe_decay > 0.5:
  398. print(" ⚠️ 警告:可能存在严重过拟合")
  399. elif sharpe_decay > 0.3:
  400. print(" ⚠️ 注意:轻度过拟合,建议简化参数")
  401. else:
  402. print(" ✓ 无过拟合,策略稳健")
  403. # 保存结果
  404. print(f"\n[5] 保存结果...")
  405. train_results.to_csv('train_results.csv')
  406. val_results.to_csv('val_results.csv')
  407. print(" 训练集结果: train_results.csv")
  408. print(" 验证集结果: val_results.csv")
  409. # 绘图
  410. print(f"\n[6] 生成图表...")
  411. fig1 = plot_results(train_results, "Training Set (2018-2023)")
  412. fig1.savefig('train_backtest.png', dpi=150, bbox_inches='tight')
  413. print(" 训练集图表: train_backtest.png")
  414. fig2 = plot_results(val_results, "Validation Set (2024-2025)")
  415. fig2.savefig('val_backtest.png', dpi=150, bbox_inches='tight')
  416. print(" 验证集图表: val_backtest.png")
  417. print("\n" + "="*60)
  418. print("回测完成")
  419. print("="*60)
  420. if __name__ == "__main__":
  421. main()