hmm_diagnosis.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. HMM模型诊断脚本
  5. 验证市场环境识别器的效果
  6. """
  7. import numpy as np
  8. import pandas as pd
  9. import warnings
  10. warnings.filterwarnings('ignore')
  11. import sys
  12. from pathlib import Path
  13. PROJECT_DIR = Path(__file__).resolve().parent
  14. if str(PROJECT_DIR) not in sys.path:
  15. sys.path.insert(0, str(PROJECT_DIR))
  16. from market_regime_hmm import MarketRegimeHMM, extract_features
  17. print("="*70)
  18. print("HMM模型诊断报告")
  19. print("="*70)
  20. # 1. 生成带标签的测试数据
  21. print("\n[1] 生成测试数据...")
  22. np.random.seed(42)
  23. n_days = 800
  24. # 创建有明确状态特征的数据
  25. segments = []
  26. true_states = []
  27. for i in range(8):
  28. state = i % 3
  29. seg_prices = []
  30. price = 1000 + i * 100
  31. for day in range(100):
  32. if state == 0: # 震荡: 零均值,中等波动
  33. ret = np.random.normal(0, 0.015)
  34. elif state == 1: # 趋势: 正漂移,低波动
  35. ret = np.random.normal(0.001, 0.010)
  36. else: # 反转: 前半段单边,后半段反向,形成真正的拐点
  37. if day < 50:
  38. direction = 1 if (i % 2 == 0) else -1
  39. ret = np.random.normal(direction * 0.0018, 0.018)
  40. else:
  41. direction = -1 if (i % 2 == 0) else 1
  42. ret = np.random.normal(direction * 0.0018, 0.018)
  43. price *= (1 + ret)
  44. seg_prices.append(price)
  45. true_states.append(state)
  46. segments.extend(seg_prices)
  47. # 为反转段补充一个更符合定义的说明
  48. print(" 反转段定义: 前50天单边运行,后50天反向运行")
  49. dates = pd.date_range('2020-01-01', periods=n_days, freq='B')
  50. df = pd.DataFrame({
  51. 'open': np.array(segments) + np.random.normal(0, 2, n_days),
  52. 'high': np.array(segments) + np.abs(np.random.normal(5, 2, n_days)),
  53. 'low': np.array(segments) - np.abs(np.random.normal(5, 2, n_days)),
  54. 'close': segments,
  55. 'volume': np.random.randint(1000000, 5000000, n_days),
  56. 'true_state': true_states
  57. }, index=dates)
  58. print(f"数据天数: {n_days}")
  59. print(f"真实状态分布:")
  60. for i in range(3):
  61. count = sum(1 for s in true_states if s == i)
  62. print(f" 状态{i}: {count}天 ({count/n_days*100:.1f}%)")
  63. # 2. 特征提取
  64. print("\n[2] 特征提取...")
  65. features = extract_features(df)
  66. feature_cols = ['ret_std_5', 'momentum_10', 'vol_ratio', 'volume_change', 'intraday_trend']
  67. X = features[feature_cols].dropna()
  68. print(f"特征维度: {X.shape}")
  69. # 3. 训练模型
  70. print("\n[3] 训练HMM模型...")
  71. hmm = MarketRegimeHMM(n_components=3, n_iter=100)
  72. hmm.fit(X)
  73. # 4. 预测状态
  74. states, probs = hmm.predict(X)
  75. df_aligned = df.iloc[-len(states):].copy()
  76. df_aligned['predicted_state'] = states
  77. df_aligned['return'] = df_aligned['close'].pct_change()
  78. # 5. 诊断分析
  79. print("\n" + "="*70)
  80. print("诊断结果")
  81. print("="*70)
  82. # 5.1 转移矩阵对比
  83. print("\n[5.1] 转移矩阵对比")
  84. print("\n先验矩阵 (设定):")
  85. prior = np.array([
  86. [0.85, 0.10, 0.05],
  87. [0.15, 0.80, 0.05],
  88. [0.20, 0.10, 0.70]
  89. ])
  90. print(prior.round(3))
  91. print("\n学习到的矩阵:")
  92. learned = hmm.model.transmat_
  93. print(learned.round(3))
  94. print("\n差异:")
  95. diff = np.abs(learned - prior)
  96. print(diff.round(3))
  97. print(f"平均绝对差异: {diff.mean():.3f}")
  98. # 5.2 状态分布对比
  99. print("\n[5.2] 状态分布对比")
  100. print(f"{'状态':<10} {'真实占比':<15} {'预测占比':<15} {'差异':<10}")
  101. print("-"*50)
  102. for i in range(3):
  103. true_pct = sum(1 for s in true_states if s == i) / n_days * 100
  104. pred_pct = sum(1 for s in states if s == i) / len(states) * 100
  105. diff_pct = abs(true_pct - pred_pct)
  106. print(f"状态{i:<5} {true_pct:>6.1f}%{' '*8} {pred_pct:>6.1f}%{' '*8} {diff_pct:>5.1f}%")
  107. # 5.3 状态特征验证
  108. print("\n[5.3] 各状态的价格行为特征")
  109. print(f"{'状态':<8} {'收益率均值':<12} {'收益率标准差':<15} {'样本数':<10}")
  110. print("-"*50)
  111. for i in range(3):
  112. mask = states == i
  113. if mask.any():
  114. rets = df_aligned.loc[mask, 'return'].dropna()
  115. mean_ret = rets.mean() * 100
  116. std_ret = rets.std() * 100
  117. count = mask.sum()
  118. print(f"状态{i:<5} {mean_ret:>+8.3f}%{' '*4} {std_ret:>8.3f}%{' '*6} {count:>5}天")
  119. # 5.4 预期 vs 实际
  120. print("\n[5.4] 状态定义验证")
  121. state_names = ['震荡', '趋势', '反转']
  122. expected = {
  123. 0: {'vol': '中高', 'ret': '接近0'},
  124. 1: {'vol': '低', 'ret': '单边正/负漂移'},
  125. 2: {'vol': '较高', 'ret': '阶段内先同向后反向'}
  126. }
  127. for i in range(3):
  128. mask = states == i
  129. if mask.any():
  130. rets = df_aligned.loc[mask, 'return'].dropna()
  131. mean_ret = rets.mean() * 100
  132. std_ret = rets.std() * 100
  133. print(f"\n状态{i} ({state_names[i]}):")
  134. print(f" 预期: 波动{expected[i]['vol']}, 收益{expected[i]['ret']}")
  135. print(f" 实际: 波动{std_ret:.2f}%, 收益{mean_ret:+.3f}%")
  136. # 简单判断
  137. if i == 0 and abs(mean_ret) < 0.1 and std_ret > 1.0:
  138. print(" ✓ 符合震荡特征")
  139. elif i == 1 and mean_ret > 0.05 and std_ret < 1.5:
  140. print(" ✓ 符合趋势特征")
  141. elif i == 2 and std_ret > 1.8:
  142. print(" ✓ 符合反转特征")
  143. else:
  144. print(" ✗ 特征不匹配")
  145. # 5.5 准确率估算
  146. print("\n[5.5] 状态识别准确率估算")
  147. # 基于特征匹配度估算
  148. matches = 0
  149. for i in range(len(states) - 1):
  150. true_seg = i // 100
  151. if states[i] == true_states[i]:
  152. matches += 1
  153. accuracy = matches / len(states) * 100
  154. print(f"与生成标签匹配率: {accuracy:.1f}%")
  155. if accuracy >= 72:
  156. print("✓ 达到目标准确率 (>72%)")
  157. else:
  158. print("✗ 未达到目标准确率,需要优化")
  159. print("\n" + "="*70)
  160. print("诊断结论")
  161. print("="*70)
  162. print(f"1. 转移矩阵与先验差异: {'可接受' if diff.mean() < 0.3 else '较大'}")
  163. print(f"2. 状态识别准确率: {accuracy:.1f}%")
  164. print(f"3. 状态特征一致性: 见上文分析")
  165. print("\n建议:")
  166. if diff.mean() > 0.3:
  167. print("- 转移矩阵与先验差异较大,建议检查数据特征或调整模型参数")
  168. if accuracy < 72:
  169. print("- 准确率不足,建议增加特征维度或使用更长的训练数据")
  170. print("="*70)