cyb50_high_perf.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 创业板50指数真实数据回测 - 高收益优化版
  5. """
  6. import pandas as pd
  7. import numpy as np
  8. import matplotlib
  9. matplotlib.use('Agg')
  10. import matplotlib.pyplot as plt
  11. import warnings
  12. warnings.filterwarnings('ignore')
  13. # 加载真实数据
  14. def load_real_data():
  15. """加载创业板50指数真实数据 - cyb50_baostock.csv"""
  16. df = pd.read_csv('cyb50_baostock.csv')
  17. df['date'] = pd.to_datetime(df['date'])
  18. df = df.set_index('date').sort_index()
  19. # 转换数据类型
  20. for col in ['open', 'high', 'low', 'close', 'volume']:
  21. df[col] = pd.to_numeric(df[col], errors='coerce')
  22. print(f"真实数据加载成功: {df.index[0].date()} ~ {df.index[-1].date()}")
  23. return df
  24. # ==================== 高性能策略 ====================
  25. class HighPerformanceStrategy:
  26. """
  27. 高收益策略:趋势跟踪 + 动量加速 + 智能止盈
  28. """
  29. def __init__(self, params=None):
  30. self.params = params or {
  31. 'fast_ma': 5, # 超短均线,快速响应
  32. 'slow_ma': 20, # 月均线
  33. 'trend_ma': 60, # 季均线
  34. 'momentum_period': 10,
  35. 'volatility_period': 20,
  36. 'max_position': 1.0,
  37. 'profit_take': 0.15, # 15%止盈
  38. 'trailing_stop': 0.08, # 8%移动止损
  39. }
  40. # 确保所有参数都有默认值
  41. default_params = {
  42. 'fast_ma': 5,
  43. 'slow_ma': 20,
  44. 'trend_ma': 60,
  45. 'momentum_period': 10,
  46. 'volatility_period': 20,
  47. 'max_position': 1.0,
  48. 'profit_take': 0.15,
  49. 'trailing_stop': 0.08,
  50. }
  51. if params:
  52. for key, val in default_params.items():
  53. if key not in params:
  54. self.params[key] = val
  55. self.position = 0
  56. self.entry_price = None
  57. self.max_price = None
  58. def generate_signal(self, data):
  59. """生成交易信号"""
  60. close = data['close']
  61. p = self.params
  62. # 计算指标
  63. ma_fast = close.rolling(p['fast_ma']).mean().iloc[-1]
  64. ma_slow = close.rolling(p['slow_ma']).mean().iloc[-1]
  65. ma_trend = close.rolling(p['trend_ma']).mean().iloc[-1]
  66. # 动量
  67. momentum = (close.iloc[-1] / close.iloc[-p['momentum_period']] - 1) * 100
  68. # 波动率
  69. returns = close.pct_change()
  70. vol = returns.rolling(p['volatility_period']).std().iloc[-1] * np.sqrt(252) * 100
  71. curr_price = close.iloc[-1]
  72. # 趋势强度
  73. trend_strong = (curr_price > ma_fast) and (ma_fast > ma_slow) and (ma_slow > ma_trend)
  74. trend_weak = (curr_price < ma_fast) and (ma_fast < ma_slow)
  75. # 信号生成
  76. if trend_strong and momentum > 2:
  77. # 强势上涨,满仓
  78. target_pos = p['max_position']
  79. state = "STRONG_UP"
  80. elif trend_strong and momentum > 0:
  81. # 趋势向上但动量一般,80%仓位
  82. target_pos = p['max_position'] * 0.8
  83. state = "UP"
  84. elif trend_weak or momentum < -3:
  85. # 趋势转弱,空仓
  86. target_pos = 0
  87. state = "DOWN"
  88. else:
  89. # 震荡,50%仓位
  90. target_pos = p['max_position'] * 0.5
  91. state = "OSCILLATE"
  92. # 移动止盈
  93. if self.position > 0 and self.max_price:
  94. current_return = (curr_price - self.entry_price) / self.entry_price
  95. # 更新最高价
  96. if curr_price > self.max_price:
  97. self.max_price = curr_price
  98. # 移动止损:从最高点回撤8%离场
  99. drawdown_from_peak = (curr_price - self.max_price) / self.max_price
  100. if drawdown_from_peak < -p['trailing_stop']:
  101. target_pos = 0
  102. state = "TRAILING_STOP"
  103. # 固定止盈15%
  104. elif current_return > p['profit_take']:
  105. target_pos = 0.5 # 减半仓,锁定利润
  106. state = "PROFIT_TAKE"
  107. # 更新状态
  108. if target_pos > 0 and self.position == 0:
  109. self.entry_price = curr_price
  110. self.max_price = curr_price
  111. elif target_pos == 0:
  112. self.entry_price = None
  113. self.max_price = None
  114. self.position = target_pos
  115. return target_pos, state
  116. # ==================== 回测引擎 ====================
  117. def backtest(data, strategy, start_date=None, end_date=None, warmup=60):
  118. """回测引擎"""
  119. if start_date:
  120. data = data[data.index >= start_date]
  121. if end_date:
  122. data = data[data.index <= end_date]
  123. results = []
  124. nav = 1.0
  125. for i in range(warmup, len(data)):
  126. curr_data = data.iloc[:i+1]
  127. position, state = strategy.generate_signal(curr_data)
  128. if i > warmup:
  129. daily_return = data['close'].iloc[i] / data['close'].iloc[i-1] - 1
  130. strategy_return = daily_return * results[-1]['position'] if results else 0
  131. nav *= (1 + strategy_return)
  132. results.append({
  133. 'date': data.index[i],
  134. 'position': position,
  135. 'nav': nav,
  136. 'state': state,
  137. 'close': data['close'].iloc[i]
  138. })
  139. df = pd.DataFrame(results).set_index('date')
  140. df['index_nav'] = df['close'] / df['close'].iloc[0]
  141. metrics = calculate_metrics(df['nav'], df['index_nav'])
  142. return df, metrics
  143. def calculate_metrics(strategy_nav, index_nav):
  144. """计算绩效指标"""
  145. s_returns = strategy_nav.pct_change().dropna()
  146. total_return = strategy_nav.iloc[-1] - 1
  147. days = len(strategy_nav)
  148. annual_return = (1 + total_return) ** (252 / days) - 1
  149. index_return = index_nav.iloc[-1] - 1
  150. index_annual = (1 + index_return) ** (252 / days) - 1
  151. running_max = strategy_nav.expanding().max()
  152. max_dd = ((strategy_nav - running_max) / running_max).min()
  153. volatility = s_returns.std() * np.sqrt(252)
  154. sharpe = (annual_return - 0.03) / volatility if volatility > 0 else 0
  155. calmar = annual_return / abs(max_dd) if max_dd != 0 else 0
  156. win_rate = (s_returns > 0).mean()
  157. return {
  158. 'annual_return': annual_return,
  159. 'index_annual': index_annual,
  160. 'excess_annual': annual_return - index_annual,
  161. 'max_drawdown': max_dd,
  162. 'sharpe': sharpe,
  163. 'calmar': calmar,
  164. 'win_rate': win_rate,
  165. 'total_return': total_return,
  166. 'index_return': index_return
  167. }
  168. def plot_results(results, title, filename):
  169. """绘制回测图表"""
  170. fig, axes = plt.subplots(3, 1, figsize=(14, 10))
  171. ax1 = axes[0]
  172. ax1.plot(results.index, results['nav'], label='Strategy', linewidth=2, color='red')
  173. ax1.plot(results.index, results['index_nav'], label='Index', linewidth=1, color='gray', alpha=0.7)
  174. ax1.set_title(f'{title}', fontsize=14)
  175. ax1.set_ylabel('NAV')
  176. ax1.legend()
  177. ax1.grid(True, alpha=0.3)
  178. ax2 = axes[1]
  179. colors = {'STRONG_UP': 'green', 'UP': 'lightgreen', 'DOWN': 'red',
  180. 'OSCILLATE': 'yellow', 'TRAILING_STOP': 'orange', 'PROFIT_TAKE': 'blue'}
  181. pos_colors = [colors.get(s, 'gray') for s in results['state']]
  182. ax2.fill_between(results.index, 0, results['position'], alpha=0.5, color='green')
  183. ax2.set_ylabel('Position')
  184. ax2.set_ylim(0, 1.1)
  185. ax2.grid(True, alpha=0.3)
  186. ax3 = axes[2]
  187. running_max = results['nav'].expanding().max()
  188. drawdown = (results['nav'] - running_max) / running_max
  189. ax3.fill_between(results.index, drawdown, 0, alpha=0.3, color='red')
  190. ax3.set_ylabel('Drawdown')
  191. ax3.set_xlabel('Date')
  192. ax3.grid(True, alpha=0.3)
  193. plt.tight_layout()
  194. plt.savefig(filename, dpi=150, bbox_inches='tight')
  195. print(f" 图表已保存: {filename}")
  196. # ==================== 主程序 ====================
  197. def main():
  198. print("="*70)
  199. print("创业板50指数高收益策略回测")
  200. print("="*70)
  201. # 加载真实数据
  202. print("\n[1] 加载真实数据...")
  203. data = load_real_data()
  204. print(f" 数据区间: {data.index[0].date()} ~ {data.index[-1].date()}")
  205. # 训练阶段
  206. print("\n[2] 训练阶段 (2018-2023) - 优化参数...")
  207. # 测试多组参数,找最优
  208. best_params = None
  209. best_score = -999
  210. test_configs = [
  211. {'fast_ma': 5, 'slow_ma': 20, 'profit_take': 0.15, 'trailing_stop': 0.08},
  212. {'fast_ma': 3, 'slow_ma': 15, 'profit_take': 0.12, 'trailing_stop': 0.06},
  213. {'fast_ma': 10, 'slow_ma': 30, 'profit_take': 0.20, 'trailing_stop': 0.10},
  214. ]
  215. for cfg in test_configs:
  216. strategy = HighPerformanceStrategy(cfg)
  217. results, metrics = backtest(data, strategy, start_date='2018-01-01', end_date='2023-12-31')
  218. # 评分:收益优先
  219. score = metrics['annual_return'] * 0.5 + metrics['calmar'] * 0.3 + metrics['sharpe'] * 0.2
  220. print(f"\n 参数: {cfg}")
  221. print(f" 年化: {metrics['annual_return']*100:.1f}%, 回撤: {metrics['max_drawdown']*100:.1f}%, 评分: {score:.2f}")
  222. if score > best_score and metrics['max_drawdown'] > -0.40:
  223. best_score = score
  224. best_params = cfg
  225. print(f"\n 最优参数: {best_params}")
  226. # 用最优参数重新回测训练集
  227. strategy = HighPerformanceStrategy(best_params)
  228. train_results, train_metrics = backtest(data, strategy, start_date='2018-01-01', end_date='2023-12-31')
  229. print(f"\n 训练集最终表现:")
  230. print(f" ┌─────────────────────────────────────┐")
  231. print(f" │ 策略年化收益: {train_metrics['annual_return']*100:>8.2f}% │")
  232. print(f" │ 指数年化收益: {train_metrics['index_annual']*100:>8.2f}% │")
  233. print(f" │ 超额收益: {train_metrics['excess_annual']*100:>8.2f}% │")
  234. print(f" │ 最大回撤: {train_metrics['max_drawdown']*100:>8.2f}% │")
  235. print(f" │ 夏普比率: {train_metrics['sharpe']:>8.2f} │")
  236. print(f" │ 卡玛比率: {train_metrics['calmar']:>8.2f} │")
  237. print(f" │ 胜率: {train_metrics['win_rate']*100:>8.1f}% │")
  238. print(f" └─────────────────────────────────────┘")
  239. plot_results(train_results, "Training Set (2018-2023)", "train_high_perf.png")
  240. # 验证阶段
  241. print(f"\n[3] 验证阶段 (2024-2025) - 样本外测试...")
  242. strategy_val = HighPerformanceStrategy(best_params)
  243. val_results, val_metrics = backtest(data, strategy_val, start_date='2024-01-01', end_date='2025-12-31')
  244. print(f"\n 验证集最终表现:")
  245. print(f" ┌─────────────────────────────────────┐")
  246. print(f" │ 策略年化收益: {val_metrics['annual_return']*100:>8.2f}% │")
  247. print(f" │ 指数年化收益: {val_metrics['index_annual']*100:>8.2f}% │")
  248. print(f" │ 超额收益: {val_metrics['excess_annual']*100:>8.2f}% │")
  249. print(f" │ 最大回撤: {val_metrics['max_drawdown']*100:>8.2f}% │")
  250. print(f" │ 夏普比率: {val_metrics['sharpe']:>8.2f} │")
  251. print(f" │ 卡玛比率: {val_metrics['calmar']:>8.2f} │")
  252. print(f" └─────────────────────────────────────┘")
  253. plot_results(val_results, "Validation Set (2024-2025)", "val_high_perf.png")
  254. # 过拟合检测
  255. print(f"\n[4] 过拟合检测:")
  256. return_decay = (train_metrics['annual_return'] - val_metrics['annual_return']) / train_metrics['annual_return'] if train_metrics['annual_return'] != 0 else 0
  257. print(f" 年化收益衰减: {return_decay*100:.1f}%")
  258. if return_decay > 0.5:
  259. print(" ⚠️ 策略在验证集表现下降明显")
  260. else:
  261. print(" ✓ 策略稳健性良好")
  262. print("\n" + "="*70)
  263. print("回测完成!")
  264. print("="*70)
  265. if __name__ == "__main__":
  266. main()