cyb50_trend.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 创业板50指数 - 高收益趋势策略
  5. 使用真实价格特征,追求年化30%+收益
  6. """
  7. import pandas as pd
  8. import numpy as np
  9. import matplotlib
  10. matplotlib.use('Agg')
  11. import matplotlib.pyplot as plt
  12. import warnings
  13. warnings.filterwarnings('ignore')
  14. def load_real_data():
  15. """加载创业板50指数真实数据 - cyb50_baostock.csv"""
  16. df = pd.read_csv('cyb50_baostock.csv')
  17. df['date'] = pd.to_datetime(df['date'])
  18. df = df.set_index('date').sort_index()
  19. # 转换数据类型
  20. for col in ['open', 'high', 'low', 'close', 'volume']:
  21. df[col] = pd.to_numeric(df[col], errors='coerce')
  22. print(f"真实数据加载成功: {df.index[0].date()} ~ {df.index[-1].date()}")
  23. return df
  24. class TrendStrategy:
  25. """趋势跟踪策略 - 激进高收益版"""
  26. def __init__(self):
  27. self.pos = 0
  28. self.entry = 0
  29. self.peak = 0
  30. def signal(self, data):
  31. c = data['close'].values
  32. if len(c) < 60:
  33. return 0
  34. # 技术指标 - 更短周期,更敏感
  35. ma3 = np.mean(c[-3:])
  36. ma10 = np.mean(c[-10:])
  37. ma30 = np.mean(c[-30:])
  38. # 价格创10日新高(更敏感)
  39. highest_10 = np.max(c[-10:])
  40. lowest_10 = np.min(c[-10:])
  41. curr = c[-1]
  42. # 突破买入:创10日新高
  43. breakout = (curr >= highest_10 * 0.995) and (ma3 > ma10)
  44. # 卖出:跌破10日最低点
  45. sell = (curr <= lowest_10 * 1.005) or (ma3 < ma10 * 0.97)
  46. if breakout and self.pos == 0:
  47. return 1.0 # 满仓
  48. elif sell and self.pos > 0:
  49. return 0.0 # 清仓
  50. else:
  51. return self.pos
  52. def generate(self, data):
  53. new_pos = self.signal(data)
  54. curr_price = data['close'].iloc[-1]
  55. # 移动止损 - 更宽松的10%
  56. if self.pos > 0:
  57. if curr_price > self.peak:
  58. self.peak = curr_price
  59. if curr_price < self.peak * 0.90:
  60. new_pos = 0
  61. # 更新状态
  62. if new_pos > 0 and self.pos == 0:
  63. self.entry = curr_price
  64. self.peak = curr_price
  65. state = "BUY"
  66. elif new_pos == 0 and self.pos > 0:
  67. self.entry = 0
  68. self.peak = 0
  69. state = "SELL"
  70. elif new_pos > 0:
  71. state = "HOLD"
  72. else:
  73. state = "EMPTY"
  74. self.pos = new_pos
  75. return new_pos, state
  76. def backtest(data, strategy, start, end, warmup=60):
  77. data = data[(data.index >= start) & (data.index <= end)]
  78. nav = 1.0
  79. results = []
  80. for i in range(warmup, len(data)):
  81. curr = data.iloc[:i+1]
  82. pos, state = strategy.generate(curr)
  83. if i > warmup:
  84. ret = data['close'].iloc[i] / data['close'].iloc[i-1] - 1
  85. nav *= (1 + ret * results[-1]['pos'])
  86. results.append({
  87. 'date': data.index[i],
  88. 'pos': pos,
  89. 'nav': nav,
  90. 'state': state,
  91. 'price': data['close'].iloc[i]
  92. })
  93. df = pd.DataFrame(results).set_index('date')
  94. df['idx_nav'] = df['price'] / df['price'].iloc[0]
  95. return df
  96. def calc_metrics(nav, idx_nav):
  97. total = nav.iloc[-1] - 1
  98. days = len(nav)
  99. annual = (1 + total) ** (252/days) - 1
  100. idx_total = idx_nav.iloc[-1] - 1
  101. idx_annual = (1 + idx_total) ** (252/days) - 1
  102. running_max = nav.expanding().max()
  103. max_dd = ((nav - running_max) / running_max).min()
  104. vol = nav.pct_change().std() * np.sqrt(252)
  105. sharpe = (annual - 0.03) / vol if vol > 0 else 0
  106. calmar = annual / abs(max_dd) if max_dd != 0 else 0
  107. return {
  108. 'annual': annual, 'idx_annual': idx_annual,
  109. 'excess': annual - idx_annual, 'max_dd': max_dd,
  110. 'sharpe': sharpe, 'calmar': calmar,
  111. 'total': total, 'idx_total': idx_total
  112. }
  113. def plot(df, title, fn):
  114. fig, ax = plt.subplots(2, 1, figsize=(14, 8))
  115. ax[0].plot(df.index, df['nav'], 'r-', lw=2, label='Strategy')
  116. ax[0].plot(df.index, df['idx_nav'], 'gray', lw=1, alpha=0.6, label='Index')
  117. ax[0].set_title(title, fontsize=14)
  118. ax[0].legend()
  119. ax[0].grid(True, alpha=0.3)
  120. ax[1].fill_between(df.index, 0, df['pos'], alpha=0.5, color='green')
  121. ax[1].set_ylim(0, 1.1)
  122. ax[1].set_ylabel('Position')
  123. ax[1].grid(True, alpha=0.3)
  124. plt.tight_layout()
  125. plt.savefig(fn, dpi=150)
  126. print(f" 图表: {fn}")
  127. def main():
  128. print("="*60)
  129. print("创业板50 - 趋势突破策略")
  130. print("="*60)
  131. data = load_real_data()
  132. print(f"\n数据: {data.index[0].date()} ~ {data.index[-1].date()}")
  133. # 训练
  134. print("\n【训练集 2018-2023】")
  135. s = TrendStrategy()
  136. train = backtest(data, s, '2018-01-01', '2023-12-31')
  137. m = calc_metrics(train['nav'], train['idx_nav'])
  138. print(f" 策略收益: {m['total']*100:7.1f}% (年化 {m['annual']*100:5.1f}%)")
  139. print(f" 指数收益: {m['idx_total']*100:7.1f}% (年化 {m['idx_annual']*100:5.1f}%)")
  140. print(f" 超额收益: {m['excess']*100:7.1f}%")
  141. print(f" 最大回撤: {m['max_dd']*100:7.1f}%")
  142. print(f" 夏普比率: {m['sharpe']:7.2f}")
  143. print(f" 卡玛比率: {m['calmar']:7.2f}")
  144. plot(train, "Training 2018-2023", "train_trend.png")
  145. # 验证
  146. print("\n【验证集 2024-2025】")
  147. s2 = TrendStrategy()
  148. val = backtest(data, s2, '2024-01-01', '2025-12-31')
  149. m2 = calc_metrics(val['nav'], val['idx_nav'])
  150. print(f" 策略收益: {m2['total']*100:7.1f}% (年化 {m2['annual']*100:5.1f}%)")
  151. print(f" 指数收益: {m2['idx_total']*100:7.1f}% (年化 {m2['idx_annual']*100:5.1f}%)")
  152. print(f" 超额收益: {m2['excess']*100:7.1f}%")
  153. print(f" 最大回撤: {m2['max_dd']*100:7.1f}%")
  154. print(f" 夏普比率: {m2['sharpe']:7.2f}")
  155. plot(val, "Validation 2024-2025", "val_trend.png")
  156. # 评价
  157. print("\n【策略评价】")
  158. if m['annual'] > 0.30:
  159. print(" ✅ 训练集年化超30%,高收益潜力")
  160. elif m['annual'] > 0.15:
  161. print(" ✅ 训练集表现良好")
  162. else:
  163. print(" ⚠️ 训练集收益一般")
  164. if m2['annual'] > 0:
  165. print(" ✅ 验证集正收益")
  166. else:
  167. print(" ❌ 验证集亏损")
  168. print("\n" + "="*60)
  169. if __name__ == "__main__":
  170. main()