hmm_diagnosis.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. HMM模型诊断脚本
  5. 验证市场环境识别器的效果
  6. """
  7. import numpy as np
  8. import pandas as pd
  9. import warnings
  10. warnings.filterwarnings('ignore')
  11. import sys
  12. sys.path.insert(0, '/root/.openclaw/workspace/market-regime-identifier')
  13. from market_regime_hmm import MarketRegimeHMM, extract_features
  14. print("="*70)
  15. print("HMM模型诊断报告")
  16. print("="*70)
  17. # 1. 生成带标签的测试数据
  18. print("\n[1] 生成测试数据...")
  19. np.random.seed(42)
  20. n_days = 800
  21. # 创建有明确状态特征的数据
  22. segments = []
  23. true_states = []
  24. for i in range(8):
  25. state = i % 3
  26. seg_prices = []
  27. price = 1000 + i * 100
  28. for day in range(100):
  29. if state == 0: # 震荡: 零均值,中等波动
  30. ret = np.random.normal(0, 0.015)
  31. elif state == 1: # 趋势: 正漂移,低波动
  32. ret = np.random.normal(0.001, 0.010)
  33. else: # 反转: 负漂移,高波动
  34. ret = np.random.normal(-0.001, 0.025)
  35. price *= (1 + ret)
  36. seg_prices.append(price)
  37. true_states.append(state)
  38. segments.extend(seg_prices)
  39. dates = pd.date_range('2020-01-01', periods=n_days, freq='B')
  40. df = pd.DataFrame({
  41. 'open': np.array(segments) + np.random.normal(0, 2, n_days),
  42. 'high': np.array(segments) + np.abs(np.random.normal(5, 2, n_days)),
  43. 'low': np.array(segments) - np.abs(np.random.normal(5, 2, n_days)),
  44. 'close': segments,
  45. 'volume': np.random.randint(1000000, 5000000, n_days),
  46. 'true_state': true_states
  47. }, index=dates)
  48. print(f"数据天数: {n_days}")
  49. print(f"真实状态分布:")
  50. for i in range(3):
  51. count = sum(1 for s in true_states if s == i)
  52. print(f" 状态{i}: {count}天 ({count/n_days*100:.1f}%)")
  53. # 2. 特征提取
  54. print("\n[2] 特征提取...")
  55. features = extract_features(df)
  56. feature_cols = ['ret_std_5', 'momentum_10', 'vol_ratio', 'volume_change', 'intraday_trend']
  57. X = features[feature_cols].dropna()
  58. print(f"特征维度: {X.shape}")
  59. # 3. 训练模型
  60. print("\n[3] 训练HMM模型...")
  61. hmm = MarketRegimeHMM(n_components=3, n_iter=100)
  62. hmm.fit(X)
  63. # 4. 预测状态
  64. states, probs = hmm.predict(X)
  65. df_aligned = df.iloc[-len(states):].copy()
  66. df_aligned['predicted_state'] = states
  67. df_aligned['return'] = df_aligned['close'].pct_change()
  68. # 5. 诊断分析
  69. print("\n" + "="*70)
  70. print("诊断结果")
  71. print("="*70)
  72. # 5.1 转移矩阵对比
  73. print("\n[5.1] 转移矩阵对比")
  74. print("\n先验矩阵 (设定):")
  75. prior = np.array([
  76. [0.85, 0.10, 0.05],
  77. [0.15, 0.80, 0.05],
  78. [0.20, 0.10, 0.70]
  79. ])
  80. print(prior.round(3))
  81. print("\n学习到的矩阵:")
  82. learned = hmm.model.transmat_
  83. print(learned.round(3))
  84. print("\n差异:")
  85. diff = np.abs(learned - prior)
  86. print(diff.round(3))
  87. print(f"平均绝对差异: {diff.mean():.3f}")
  88. # 5.2 状态分布对比
  89. print("\n[5.2] 状态分布对比")
  90. print(f"{'状态':<10} {'真实占比':<15} {'预测占比':<15} {'差异':<10}")
  91. print("-"*50)
  92. for i in range(3):
  93. true_pct = sum(1 for s in true_states if s == i) / n_days * 100
  94. pred_pct = sum(1 for s in states if s == i) / len(states) * 100
  95. diff_pct = abs(true_pct - pred_pct)
  96. print(f"状态{i:<5} {true_pct:>6.1f}%{' '*8} {pred_pct:>6.1f}%{' '*8} {diff_pct:>5.1f}%")
  97. # 5.3 状态特征验证
  98. print("\n[5.3] 各状态的价格行为特征")
  99. print(f"{'状态':<8} {'收益率均值':<12} {'收益率标准差':<15} {'样本数':<10}")
  100. print("-"*50)
  101. for i in range(3):
  102. mask = states == i
  103. if mask.any():
  104. rets = df_aligned.loc[mask, 'return'].dropna()
  105. mean_ret = rets.mean() * 100
  106. std_ret = rets.std() * 100
  107. count = mask.sum()
  108. print(f"状态{i:<5} {mean_ret:>+8.3f}%{' '*4} {std_ret:>8.3f}%{' '*6} {count:>5}天")
  109. # 5.4 预期 vs 实际
  110. print("\n[5.4] 状态定义验证")
  111. state_names = ['震荡', '趋势', '反转']
  112. expected = {
  113. 0: {'vol': '中高', 'ret': '接近0'},
  114. 1: {'vol': '低', 'ret': '正'},
  115. 2: {'vol': '高', 'ret': '负'}
  116. }
  117. for i in range(3):
  118. mask = states == i
  119. if mask.any():
  120. rets = df_aligned.loc[mask, 'return'].dropna()
  121. mean_ret = rets.mean() * 100
  122. std_ret = rets.std() * 100
  123. print(f"\n状态{i} ({state_names[i]}):")
  124. print(f" 预期: 波动{expected[i]['vol']}, 收益{expected[i]['ret']}")
  125. print(f" 实际: 波动{std_ret:.2f}%, 收益{mean_ret:+.3f}%")
  126. # 简单判断
  127. if i == 0 and abs(mean_ret) < 0.1 and std_ret > 1.0:
  128. print(" ✓ 符合震荡特征")
  129. elif i == 1 and mean_ret > 0.05 and std_ret < 1.5:
  130. print(" ✓ 符合趋势特征")
  131. elif i == 2 and std_ret > 1.8:
  132. print(" ✓ 符合反转特征")
  133. else:
  134. print(" ✗ 特征不匹配")
  135. # 5.5 准确率估算
  136. print("\n[5.5] 状态识别准确率估算")
  137. # 基于特征匹配度估算
  138. matches = 0
  139. for i in range(len(states) - 1):
  140. true_seg = i // 100
  141. if states[i] == true_states[i]:
  142. matches += 1
  143. accuracy = matches / len(states) * 100
  144. print(f"与生成标签匹配率: {accuracy:.1f}%")
  145. if accuracy >= 72:
  146. print("✓ 达到目标准确率 (>72%)")
  147. else:
  148. print("✗ 未达到目标准确率,需要优化")
  149. print("\n" + "="*70)
  150. print("诊断结论")
  151. print("="*70)
  152. print(f"1. 转移矩阵与先验差异: {'可接受' if diff.mean() < 0.3 else '较大'}")
  153. print(f"2. 状态识别准确率: {accuracy:.1f}%")
  154. print(f"3. 状态特征一致性: 见上文分析")
  155. print("\n建议:")
  156. if diff.mean() > 0.3:
  157. print("- 转移矩阵与先验差异较大,建议检查数据特征或调整模型参数")
  158. if accuracy < 72:
  159. print("- 准确率不足,建议增加特征维度或使用更长的训练数据")
  160. print("="*70)