cyb50_simple.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 创业板50指数点位量化交易策略回测 - 简化版(快速演示)
  5. 训练集:2018-2023 | 验证集:2024-2025
  6. """
  7. import pandas as pd
  8. import numpy as np
  9. import matplotlib
  10. matplotlib.use('Agg')
  11. import matplotlib.pyplot as plt
  12. from datetime import datetime
  13. import warnings
  14. warnings.filterwarnings('ignore')
  15. # 设置图表
  16. plt.rcParams['font.size'] = 10
  17. # ==================== 1. 数据加载 ====================
  18. def load_real_data():
  19. """加载创业板50指数真实数据 - cyb50_baostock.csv"""
  20. df = pd.read_csv('cyb50_baostock.csv')
  21. df['date'] = pd.to_datetime(df['date'])
  22. df = df.set_index('date').sort_index()
  23. # 转换数据类型
  24. for col in ['open', 'high', 'low', 'close', 'volume']:
  25. df[col] = pd.to_numeric(df[col], errors='coerce')
  26. print(f"真实数据加载成功: {df.index[0].date()} ~ {df.index[-1].date()}")
  27. return df
  28. # ==================== 2. 策略类 ====================
  29. class SimpleCYBStrategy:
  30. """简化版策略:双均线 + 波动率控制"""
  31. def __init__(self, params=None):
  32. self.params = params or {
  33. 'fast_ma': 20, # 快线
  34. 'slow_ma': 60, # 慢线
  35. 'volatility_period': 20,
  36. 'max_position': 1.0,
  37. 'stop_loss': 0.10,
  38. }
  39. self.position = 0
  40. self.entry_price = None
  41. def generate_signal(self, data):
  42. """生成交易信号"""
  43. close = data['close']
  44. high = data['high']
  45. low = data['low']
  46. p = self.params
  47. # 计算均线
  48. ma_fast = close.rolling(p['fast_ma']).mean().iloc[-1]
  49. ma_slow = close.rolling(p['slow_ma']).mean().iloc[-1]
  50. # 计算波动率(20日ATR/价格)
  51. tr = pd.concat([
  52. high - low,
  53. abs(high - close.shift(1)),
  54. abs(low - close.shift(1))
  55. ], axis=1).max(axis=1)
  56. atr = tr.rolling(p['volatility_period']).mean().iloc[-1]
  57. vol_pct = atr / close.iloc[-1] * 100
  58. # 趋势判断
  59. trend_up = (close.iloc[-1] > ma_fast) and (ma_fast > ma_slow)
  60. trend_down = (close.iloc[-1] < ma_fast) and (ma_fast < ma_slow)
  61. # 仓位决策
  62. if trend_up and vol_pct < 4:
  63. target_pos = p['max_position'] # 满仓
  64. state = "BULL"
  65. elif trend_down or vol_pct > 6:
  66. target_pos = 0 # 空仓
  67. state = "BEAR" if trend_down else "HIGH_VOL"
  68. else:
  69. target_pos = p['max_position'] * 0.5 # 半仓
  70. state = "OSCILLATE"
  71. # 止损检查
  72. if self.position > 0 and self.entry_price:
  73. drawdown = (close.iloc[-1] - self.entry_price) / self.entry_price
  74. if drawdown < -p['stop_loss']:
  75. target_pos = 0
  76. # 更新
  77. if target_pos > 0 and self.position == 0:
  78. self.entry_price = close.iloc[-1]
  79. if target_pos == 0:
  80. self.entry_price = None
  81. self.position = target_pos
  82. return target_pos, state
  83. # ==================== 3. 回测引擎 ====================
  84. def backtest(data, strategy, start_date=None, end_date=None, warmup=60):
  85. """回测引擎"""
  86. if start_date:
  87. data = data[data.index >= start_date]
  88. if end_date:
  89. data = data[data.index <= end_date]
  90. results = []
  91. nav = 1.0
  92. for i in range(warmup, len(data)):
  93. curr_data = data.iloc[:i+1]
  94. # 获取信号
  95. position, state = strategy.generate_signal(curr_data)
  96. # 计算收益
  97. if i > warmup:
  98. daily_return = data['close'].iloc[i] / data['close'].iloc[i-1] - 1
  99. strategy_return = daily_return * results[-1]['position'] if results else 0
  100. nav *= (1 + strategy_return)
  101. results.append({
  102. 'date': data.index[i],
  103. 'position': position,
  104. 'nav': nav,
  105. 'state': state,
  106. 'close': data['close'].iloc[i]
  107. })
  108. df = pd.DataFrame(results).set_index('date')
  109. df['index_nav'] = df['close'] / df['close'].iloc[0]
  110. # 计算指标
  111. metrics = calculate_metrics(df['nav'], df['index_nav'])
  112. return df, metrics
  113. def calculate_metrics(strategy_nav, index_nav):
  114. """计算绩效指标"""
  115. s_returns = strategy_nav.pct_change().dropna()
  116. total_return = strategy_nav.iloc[-1] - 1
  117. days = len(strategy_nav)
  118. annual_return = (1 + total_return) ** (252 / days) - 1
  119. index_return = index_nav.iloc[-1] - 1
  120. index_annual = (1 + index_return) ** (252 / days) - 1
  121. # 最大回撤
  122. running_max = strategy_nav.expanding().max()
  123. max_dd = ((strategy_nav - running_max) / running_max).min()
  124. # 波动率和夏普
  125. volatility = s_returns.std() * np.sqrt(252)
  126. sharpe = (annual_return - 0.03) / volatility if volatility > 0 else 0
  127. # 卡玛
  128. calmar = annual_return / abs(max_dd) if max_dd != 0 else 0
  129. # 胜率
  130. win_rate = (s_returns > 0).mean()
  131. return {
  132. 'annual_return': annual_return,
  133. 'index_annual': index_annual,
  134. 'excess_annual': annual_return - index_annual,
  135. 'max_drawdown': max_dd,
  136. 'sharpe': sharpe,
  137. 'calmar': calmar,
  138. 'win_rate': win_rate,
  139. 'total_return': total_return,
  140. 'index_return': index_return
  141. }
  142. # ==================== 4. 可视化 ====================
  143. def plot_results(results, title, filename):
  144. """绘制回测图表"""
  145. fig, axes = plt.subplots(3, 1, figsize=(12, 9))
  146. # 净值曲线
  147. ax1 = axes[0]
  148. ax1.plot(results.index, results['nav'], label='Strategy', linewidth=2, color='blue')
  149. ax1.plot(results.index, results['index_nav'], label='Index', linewidth=1, color='gray', alpha=0.7)
  150. ax1.set_title(f'{title} - NAV Comparison')
  151. ax1.set_ylabel('NAV')
  152. ax1.legend()
  153. ax1.grid(True, alpha=0.3)
  154. # 仓位
  155. ax2 = axes[1]
  156. ax2.fill_between(results.index, 0, results['position'], alpha=0.3, color='green')
  157. ax2.set_ylabel('Position')
  158. ax2.set_ylim(0, 1.1)
  159. ax2.grid(True, alpha=0.3)
  160. # 回撤
  161. ax3 = axes[2]
  162. running_max = results['nav'].expanding().max()
  163. drawdown = (results['nav'] - running_max) / running_max
  164. ax3.fill_between(results.index, drawdown, 0, alpha=0.3, color='red')
  165. ax3.set_ylabel('Drawdown')
  166. ax3.set_xlabel('Date')
  167. ax3.grid(True, alpha=0.3)
  168. plt.tight_layout()
  169. plt.savefig(filename, dpi=150, bbox_inches='tight')
  170. print(f" 图表已保存: {filename}")
  171. # ==================== 5. 主程序 ====================
  172. def main():
  173. print("="*70)
  174. print("创业板50指数量化交易策略回测 - 简化版")
  175. print("="*70)
  176. # 加载真实数据
  177. print("\n[1] 加载真实数据...")
  178. data = load_real_data()
  179. print(f" 数据区间: {data.index[0].date()} ~ {data.index[-1].date()}")
  180. print(f" 共 {len(data)} 个交易日")
  181. # 训练阶段
  182. print("\n[2] 训练阶段 (2018-2023)...")
  183. strategy = SimpleCYBStrategy()
  184. train_results, train_metrics = backtest(data, strategy,
  185. start_date='2018-01-01',
  186. end_date='2023-12-31')
  187. print(f"\n 训练集表现:")
  188. print(f" - 策略年化收益: {train_metrics['annual_return']*100:>7.2f}%")
  189. print(f" - 指数年化收益: {train_metrics['index_annual']*100:>7.2f}%")
  190. print(f" - 超额收益: {train_metrics['excess_annual']*100:>7.2f}%")
  191. print(f" - 最大回撤: {train_metrics['max_drawdown']*100:>7.2f}%")
  192. print(f" - 夏普比率: {train_metrics['sharpe']:>7.2f}")
  193. print(f" - 卡玛比率: {train_metrics['calmar']:>7.2f}")
  194. print(f" - 胜率: {train_metrics['win_rate']*100:>7.1f}%")
  195. plot_results(train_results, "Training Set (2018-2023)", "train_results.png")
  196. # 验证阶段
  197. print("\n[3] 验证阶段 (2024-2025)...")
  198. strategy_val = SimpleCYBStrategy() # 使用相同参数
  199. val_results, val_metrics = backtest(data, strategy_val,
  200. start_date='2024-01-01',
  201. end_date='2025-12-31')
  202. print(f"\n 验证集表现:")
  203. print(f" - 策略年化收益: {val_metrics['annual_return']*100:>7.2f}%")
  204. print(f" - 指数年化收益: {val_metrics['index_annual']*100:>7.2f}%")
  205. print(f" - 超额收益: {val_metrics['excess_annual']*100:>7.2f}%")
  206. print(f" - 最大回撤: {val_metrics['max_drawdown']*100:>7.2f}%")
  207. print(f" - 夏普比率: {val_metrics['sharpe']:>7.2f}")
  208. print(f" - 卡玛比率: {val_metrics['calmar']:>7.2f}")
  209. plot_results(val_results, "Validation Set (2024-2025)", "val_results.png")
  210. # 过拟合检测
  211. print("\n[4] 过拟合检测:")
  212. sharpe_decay = (train_metrics['sharpe'] - val_metrics['sharpe']) / train_metrics['sharpe'] if train_metrics['sharpe'] != 0 else 0
  213. print(f" 夏普比率衰减: {sharpe_decay*100:.1f}%")
  214. if sharpe_decay > 0.5:
  215. print(" ⚠️ 警告:可能存在严重过拟合")
  216. elif sharpe_decay > 0.3:
  217. print(" ⚠️ 注意:轻度过拟合")
  218. else:
  219. print(" ✓ 无过拟合,策略稳健")
  220. # 总结
  221. print("\n" + "="*70)
  222. print("回测完成")
  223. print("="*70)
  224. print(f"\n输出文件:")
  225. print(f" - train_results.png (训练集图表)")
  226. print(f" - val_results.png (验证集图表)")
  227. if __name__ == "__main__":
  228. main()