cyb50_simple.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 创业板50指数点位量化交易策略回测 - 简化版(快速演示)
  5. 训练集:2018-2023 | 验证集:2024-2025
  6. """
  7. import pandas as pd
  8. import numpy as np
  9. import matplotlib
  10. matplotlib.use('Agg')
  11. import matplotlib.pyplot as plt
  12. from datetime import datetime
  13. import warnings
  14. warnings.filterwarnings('ignore')
  15. # 设置图表
  16. plt.rcParams['font.size'] = 10
  17. # ==================== 1. 数据生成 ====================
  18. def generate_index_data(start='2017-01-01', end='2025-12-31', seed=42):
  19. """生成模拟创业板50指数数据"""
  20. np.random.seed(seed)
  21. dates = pd.date_range(start=start, end=end, freq='D')
  22. dates = dates[dates.dayofweek < 5] # 去掉周末
  23. n = len(dates)
  24. # 生成收益率序列(带趋势和波动)
  25. returns = np.random.normal(0.0002, 0.015, n)
  26. # 2019-2021牛市
  27. bull = (dates >= '2019-01-01') & (dates <= '2021-12-31')
  28. returns[bull] += 0.0008
  29. # 2022熊市
  30. bear = (dates >= '2022-01-01') & (dates <= '2022-12-31')
  31. returns[bear] -= 0.001
  32. # 2024-2025震荡
  33. osc = dates >= '2024-01-01'
  34. returns[osc] = np.random.normal(0, 0.012, sum(osc))
  35. # 计算价格
  36. price = 1800 * np.exp(np.cumsum(returns))
  37. df = pd.DataFrame(index=dates)
  38. df['close'] = price
  39. df['open'] = price * (1 + np.random.normal(0, 0.005, n))
  40. df['high'] = df[['open', 'close']].max(axis=1) * (1 + np.abs(np.random.normal(0, 0.008, n)))
  41. df['low'] = df[['open', 'close']].min(axis=1) * (1 - np.abs(np.random.normal(0, 0.008, n)))
  42. return df
  43. # ==================== 2. 策略类 ====================
  44. class SimpleCYBStrategy:
  45. """简化版策略:双均线 + 波动率控制"""
  46. def __init__(self, params=None):
  47. self.params = params or {
  48. 'fast_ma': 20, # 快线
  49. 'slow_ma': 60, # 慢线
  50. 'volatility_period': 20,
  51. 'max_position': 1.0,
  52. 'stop_loss': 0.10,
  53. }
  54. self.position = 0
  55. self.entry_price = None
  56. def generate_signal(self, data):
  57. """生成交易信号"""
  58. close = data['close']
  59. high = data['high']
  60. low = data['low']
  61. p = self.params
  62. # 计算均线
  63. ma_fast = close.rolling(p['fast_ma']).mean().iloc[-1]
  64. ma_slow = close.rolling(p['slow_ma']).mean().iloc[-1]
  65. # 计算波动率(20日ATR/价格)
  66. tr = pd.concat([
  67. high - low,
  68. abs(high - close.shift(1)),
  69. abs(low - close.shift(1))
  70. ], axis=1).max(axis=1)
  71. atr = tr.rolling(p['volatility_period']).mean().iloc[-1]
  72. vol_pct = atr / close.iloc[-1] * 100
  73. # 趋势判断
  74. trend_up = (close.iloc[-1] > ma_fast) and (ma_fast > ma_slow)
  75. trend_down = (close.iloc[-1] < ma_fast) and (ma_fast < ma_slow)
  76. # 仓位决策
  77. if trend_up and vol_pct < 4:
  78. target_pos = p['max_position'] # 满仓
  79. state = "BULL"
  80. elif trend_down or vol_pct > 6:
  81. target_pos = 0 # 空仓
  82. state = "BEAR" if trend_down else "HIGH_VOL"
  83. else:
  84. target_pos = p['max_position'] * 0.5 # 半仓
  85. state = "OSCILLATE"
  86. # 止损检查
  87. if self.position > 0 and self.entry_price:
  88. drawdown = (close.iloc[-1] - self.entry_price) / self.entry_price
  89. if drawdown < -p['stop_loss']:
  90. target_pos = 0
  91. # 更新
  92. if target_pos > 0 and self.position == 0:
  93. self.entry_price = close.iloc[-1]
  94. if target_pos == 0:
  95. self.entry_price = None
  96. self.position = target_pos
  97. return target_pos, state
  98. # ==================== 3. 回测引擎 ====================
  99. def backtest(data, strategy, start_date=None, end_date=None, warmup=60):
  100. """回测引擎"""
  101. if start_date:
  102. data = data[data.index >= start_date]
  103. if end_date:
  104. data = data[data.index <= end_date]
  105. results = []
  106. nav = 1.0
  107. for i in range(warmup, len(data)):
  108. curr_data = data.iloc[:i+1]
  109. # 获取信号
  110. position, state = strategy.generate_signal(curr_data)
  111. # 计算收益
  112. if i > warmup:
  113. daily_return = data['close'].iloc[i] / data['close'].iloc[i-1] - 1
  114. strategy_return = daily_return * results[-1]['position'] if results else 0
  115. nav *= (1 + strategy_return)
  116. results.append({
  117. 'date': data.index[i],
  118. 'position': position,
  119. 'nav': nav,
  120. 'state': state,
  121. 'close': data['close'].iloc[i]
  122. })
  123. df = pd.DataFrame(results).set_index('date')
  124. df['index_nav'] = df['close'] / df['close'].iloc[0]
  125. # 计算指标
  126. metrics = calculate_metrics(df['nav'], df['index_nav'])
  127. return df, metrics
  128. def calculate_metrics(strategy_nav, index_nav):
  129. """计算绩效指标"""
  130. s_returns = strategy_nav.pct_change().dropna()
  131. total_return = strategy_nav.iloc[-1] - 1
  132. days = len(strategy_nav)
  133. annual_return = (1 + total_return) ** (252 / days) - 1
  134. index_return = index_nav.iloc[-1] - 1
  135. index_annual = (1 + index_return) ** (252 / days) - 1
  136. # 最大回撤
  137. running_max = strategy_nav.expanding().max()
  138. max_dd = ((strategy_nav - running_max) / running_max).min()
  139. # 波动率和夏普
  140. volatility = s_returns.std() * np.sqrt(252)
  141. sharpe = (annual_return - 0.03) / volatility if volatility > 0 else 0
  142. # 卡玛
  143. calmar = annual_return / abs(max_dd) if max_dd != 0 else 0
  144. # 胜率
  145. win_rate = (s_returns > 0).mean()
  146. return {
  147. 'annual_return': annual_return,
  148. 'index_annual': index_annual,
  149. 'excess_annual': annual_return - index_annual,
  150. 'max_drawdown': max_dd,
  151. 'sharpe': sharpe,
  152. 'calmar': calmar,
  153. 'win_rate': win_rate,
  154. 'total_return': total_return,
  155. 'index_return': index_return
  156. }
  157. # ==================== 4. 可视化 ====================
  158. def plot_results(results, title, filename):
  159. """绘制回测图表"""
  160. fig, axes = plt.subplots(3, 1, figsize=(12, 9))
  161. # 净值曲线
  162. ax1 = axes[0]
  163. ax1.plot(results.index, results['nav'], label='Strategy', linewidth=2, color='blue')
  164. ax1.plot(results.index, results['index_nav'], label='Index', linewidth=1, color='gray', alpha=0.7)
  165. ax1.set_title(f'{title} - NAV Comparison')
  166. ax1.set_ylabel('NAV')
  167. ax1.legend()
  168. ax1.grid(True, alpha=0.3)
  169. # 仓位
  170. ax2 = axes[1]
  171. ax2.fill_between(results.index, 0, results['position'], alpha=0.3, color='green')
  172. ax2.set_ylabel('Position')
  173. ax2.set_ylim(0, 1.1)
  174. ax2.grid(True, alpha=0.3)
  175. # 回撤
  176. ax3 = axes[2]
  177. running_max = results['nav'].expanding().max()
  178. drawdown = (results['nav'] - running_max) / running_max
  179. ax3.fill_between(results.index, drawdown, 0, alpha=0.3, color='red')
  180. ax3.set_ylabel('Drawdown')
  181. ax3.set_xlabel('Date')
  182. ax3.grid(True, alpha=0.3)
  183. plt.tight_layout()
  184. plt.savefig(filename, dpi=150, bbox_inches='tight')
  185. print(f" 图表已保存: {filename}")
  186. # ==================== 5. 主程序 ====================
  187. def main():
  188. print("="*70)
  189. print("创业板50指数量化交易策略回测 - 简化版")
  190. print("="*70)
  191. # 生成数据
  192. print("\n[1] 生成模拟数据...")
  193. data = generate_index_data()
  194. print(f" 数据区间: {data.index[0].date()} ~ {data.index[-1].date()}")
  195. print(f" 共 {len(data)} 个交易日")
  196. # 训练阶段
  197. print("\n[2] 训练阶段 (2018-2023)...")
  198. strategy = SimpleCYBStrategy()
  199. train_results, train_metrics = backtest(data, strategy,
  200. start_date='2018-01-01',
  201. end_date='2023-12-31')
  202. print(f"\n 训练集表现:")
  203. print(f" - 策略年化收益: {train_metrics['annual_return']*100:>7.2f}%")
  204. print(f" - 指数年化收益: {train_metrics['index_annual']*100:>7.2f}%")
  205. print(f" - 超额收益: {train_metrics['excess_annual']*100:>7.2f}%")
  206. print(f" - 最大回撤: {train_metrics['max_drawdown']*100:>7.2f}%")
  207. print(f" - 夏普比率: {train_metrics['sharpe']:>7.2f}")
  208. print(f" - 卡玛比率: {train_metrics['calmar']:>7.2f}")
  209. print(f" - 胜率: {train_metrics['win_rate']*100:>7.1f}%")
  210. plot_results(train_results, "Training Set (2018-2023)", "train_results.png")
  211. # 验证阶段
  212. print("\n[3] 验证阶段 (2024-2025)...")
  213. strategy_val = SimpleCYBStrategy() # 使用相同参数
  214. val_results, val_metrics = backtest(data, strategy_val,
  215. start_date='2024-01-01',
  216. end_date='2025-12-31')
  217. print(f"\n 验证集表现:")
  218. print(f" - 策略年化收益: {val_metrics['annual_return']*100:>7.2f}%")
  219. print(f" - 指数年化收益: {val_metrics['index_annual']*100:>7.2f}%")
  220. print(f" - 超额收益: {val_metrics['excess_annual']*100:>7.2f}%")
  221. print(f" - 最大回撤: {val_metrics['max_drawdown']*100:>7.2f}%")
  222. print(f" - 夏普比率: {val_metrics['sharpe']:>7.2f}")
  223. print(f" - 卡玛比率: {val_metrics['calmar']:>7.2f}")
  224. plot_results(val_results, "Validation Set (2024-2025)", "val_results.png")
  225. # 过拟合检测
  226. print("\n[4] 过拟合检测:")
  227. sharpe_decay = (train_metrics['sharpe'] - val_metrics['sharpe']) / train_metrics['sharpe'] if train_metrics['sharpe'] != 0 else 0
  228. print(f" 夏普比率衰减: {sharpe_decay*100:.1f}%")
  229. if sharpe_decay > 0.5:
  230. print(" ⚠️ 警告:可能存在严重过拟合")
  231. elif sharpe_decay > 0.3:
  232. print(" ⚠️ 注意:轻度过拟合")
  233. else:
  234. print(" ✓ 无过拟合,策略稳健")
  235. # 总结
  236. print("\n" + "="*70)
  237. print("回测完成")
  238. print("="*70)
  239. print(f"\n输出文件:")
  240. print(f" - train_results.png (训练集图表)")
  241. print(f" - val_results.png (验证集图表)")
  242. if __name__ == "__main__":
  243. main()