regime_detection.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. """
  2. 创业板50指数市场状态识别 (Regime Detection)
  3. 基于波动率和趋势强度识别不同市场状态
  4. """
  5. import pandas as pd
  6. import numpy as np
  7. from enum import Enum
  8. class RegimeType(Enum):
  9. """市场状态类型"""
  10. STRONG_BULL = "强趋势上涨" # 高波动+上涨趋势
  11. WEAK_BULL = "弱趋势上涨" # 低波动+上涨趋势
  12. STRONG_BEAR = "强趋势下跌" # 高波动+下跌趋势
  13. WEAK_BEAR = "弱趋势下跌" # 低波动+下跌趋势
  14. CONSOLIDATION = "震荡整理" # 无明显趋势
  15. UNKNOWN = "未知"
  16. class RegimeDetector:
  17. """
  18. 基于波动率和趋势的市场状态识别器
  19. 创业板50特性:
  20. - 成长风格, 波动率高于主板
  21. - 趋势性强但反转快
  22. - 适合波动率+趋势双因子识别
  23. """
  24. def __init__(self,
  25. vol_short=20, # 短期波动率窗口
  26. vol_long=60, # 长期波动率窗口
  27. trend_window=20, # 趋势判断窗口
  28. vol_percentile=60, # 波动率分位数阈值
  29. trend_threshold=0.05): # 趋势强度阈值
  30. self.vol_short = vol_short
  31. self.vol_long = vol_long
  32. self.trend_window = trend_window
  33. self.vol_percentile = vol_percentile
  34. self.trend_threshold = trend_threshold
  35. def calculate_volatility(self, prices):
  36. """计算年化波动率"""
  37. returns = prices.pct_change().dropna()
  38. vol_short = returns.rolling(self.vol_short).std() * np.sqrt(252)
  39. vol_long = returns.rolling(self.vol_long).std() * np.sqrt(252)
  40. return vol_short, vol_long
  41. def calculate_trend(self, prices):
  42. """计算趋势强度和方向"""
  43. # 使用均线斜率判断趋势
  44. ma = prices.rolling(self.trend_window).mean()
  45. # 价格相对均线的偏离
  46. deviation = (prices - ma) / ma
  47. # 趋势强度: 斜率方向 + 持续性
  48. trend_strength = deviation.rolling(self.trend_window).mean()
  49. return trend_strength
  50. def detect_regime(self, prices):
  51. """
  52. 识别当前市场状态
  53. 返回: DataFrame with regime info
  54. """
  55. df = pd.DataFrame(index=prices.index)
  56. df['close'] = prices
  57. # 计算波动率
  58. vol_short, vol_long = self.calculate_volatility(prices)
  59. df['vol_short'] = vol_short
  60. df['vol_long'] = vol_long
  61. # 波动率分位数 (基于长期历史)
  62. df['vol_percentile'] = vol_short.rolling(252).apply(
  63. lambda x: pd.Series(x).rank(pct=True).iloc[-1] if len(x) > 0 else 0.5
  64. )
  65. # 趋势强度
  66. df['trend'] = self.calculate_trend(prices)
  67. # 趋势方向 (使用短期动量)
  68. df['momentum'] = prices.pct_change(self.trend_window)
  69. # 识别状态
  70. df['regime'] = RegimeType.UNKNOWN.value
  71. # 高波动
  72. high_vol = df['vol_percentile'] > self.vol_percentile / 100
  73. # 强趋势
  74. strong_trend_up = df['trend'] > self.trend_threshold
  75. strong_trend_down = df['trend'] < -self.trend_threshold
  76. # 强趋势上涨 (高波动+上涨)
  77. mask = high_vol & strong_trend_up
  78. df.loc[mask, 'regime'] = RegimeType.STRONG_BULL.value
  79. # 弱趋势上涨 (低波动+上涨)
  80. mask = (~high_vol) & strong_trend_up
  81. df.loc[mask, 'regime'] = RegimeType.WEAK_BULL.value
  82. # 强趋势下跌 (高波动+下跌)
  83. mask = high_vol & strong_trend_down
  84. df.loc[mask, 'regime'] = RegimeType.STRONG_BEAR.value
  85. # 弱趋势下跌 (低波动+下跌)
  86. mask = (~high_vol) & strong_trend_down
  87. df.loc[mask, 'regime'] = RegimeType.WEAK_BEAR.value
  88. # 震荡 (无明显趋势)
  89. mask = (~strong_trend_up) & (~strong_trend_down)
  90. df.loc[mask, 'regime'] = RegimeType.CONSOLIDATION.value
  91. return df
  92. def get_regime_stats(self, df):
  93. """统计各状态占比和表现"""
  94. stats = []
  95. for regime in df['regime'].unique():
  96. if pd.isna(regime):
  97. continue
  98. mask = df['regime'] == regime
  99. regime_data = df[mask]
  100. # 计算该状态下的收益统计
  101. returns = regime_data['close'].pct_change().dropna()
  102. stats.append({
  103. 'regime': regime,
  104. 'days': len(regime_data),
  105. 'pct': len(regime_data) / len(df) * 100,
  106. 'avg_return': returns.mean() * 100 if len(returns) > 0 else 0,
  107. 'volatility': returns.std() * np.sqrt(252) * 100 if len(returns) > 0 else 0,
  108. 'sharpe': (returns.mean() / returns.std() * np.sqrt(252)) if len(returns) > 0 and returns.std() > 0 else 0,
  109. 'max_return': returns.max() * 100 if len(returns) > 0 else 0,
  110. 'min_return': returns.min() * 100 if len(returns) > 0 else 0,
  111. })
  112. return pd.DataFrame(stats)
  113. def analyze_chinext50_regimes(csv_path="chinext50.csv"):
  114. """分析创业板50的历史状态分布"""
  115. df = pd.read_csv(csv_path, parse_dates=['datetime'], index_col='datetime')
  116. detector = RegimeDetector(
  117. vol_short=20,
  118. vol_long=60,
  119. trend_window=20,
  120. vol_percentile=60,
  121. trend_threshold=0.03 # 创业板波动大,阈值放宽
  122. )
  123. regimes = detector.detect_regime(df['close'])
  124. stats = detector.get_regime_stats(regimes)
  125. print("=" * 60)
  126. print("创业板50指数市场状态分析")
  127. print("=" * 60)
  128. print(f"\n数据区间: {regimes.index[0].date()} 至 {regimes.index[-1].date()}")
  129. print(f"总交易日: {len(regimes)}")
  130. print("\n各状态分布:")
  131. print(stats.to_string(index=False))
  132. # 保存结果
  133. regimes.to_csv("regimes.csv")
  134. stats.to_csv("regime_stats.csv", index=False)
  135. print("\n详细数据已保存: regimes.csv")
  136. print("统计结果已保存: regime_stats.csv")
  137. return regimes, stats
# Script entry point: run the analysis with its default arguments when the
# file is executed directly rather than imported.
if __name__ == "__main__":
    analyze_chinext50_regimes()