#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""generate_regime_chart.py

Generate the complete market-regime identification chart for CYB50
(2024-2025 per the original description) and print a console report.
"""
import sys

# The classifier helpers are imported from this workspace directory, so it
# has to be on sys.path before the cyb50_market_classifier import below.
sys.path.insert(0, '/root/.openclaw/workspace/market-regime-identifier')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from cyb50_market_classifier import fetch_cyb50_data, calculate_features, define_market_regime
import pickle
import warnings

# NOTE(review): this silences every warning for the whole run; narrow the
# filter if a specific warning was the actual target.
warnings.filterwarnings('ignore')

# Font setup. DejaVu Sans does not provide CJK glyphs, so on-chart text in
# this script is kept in English; Chinese is only used for console output.
plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
# Render the minus sign as ASCII '-' instead of the Unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False

print("="*70)
print("生成2024年至今市场状态识别图表")
print("="*70)
  22. # 获取数据
  23. df = fetch_cyb50_data('2024-01-01', '2026-03-06')
  24. if df is None:
  25. exit(1)
  26. print(f"\n数据范围: {df.index[0].date()} ~ {df.index[-1].date()}")
  27. # 计算特征和标签
  28. features = calculate_features(df)
  29. labels = define_market_regime(df, lookback=10)
  30. # 训练模型
  31. valid_idx = ~np.isnan(labels)
  32. X = features[valid_idx]
  33. y = labels[valid_idx]
  34. from sklearn.ensemble import RandomForestClassifier
  35. clf = RandomForestClassifier(
  36. n_estimators=100,
  37. max_depth=10,
  38. min_samples_split=20,
  39. min_samples_leaf=10,
  40. random_state=42,
  41. class_weight='balanced'
  42. )
  43. clf.fit(X, y)
  44. # 预测所有数据
  45. states = clf.predict(X)
  46. probs = clf.predict_proba(X)
  47. # 对齐数据
  48. df_aligned = df.iloc[-len(states):].copy()
  49. df_aligned['state'] = states
  50. df_aligned['state_prob'] = [p[s] for s, p in zip(states, probs)]
  51. df_aligned['prob_ranging'] = probs[:, 0] # 震荡概率
  52. df_aligned['prob_trend'] = probs[:, 1] # 趋势概率
  53. df_aligned['prob_reversal'] = probs[:, 2] # 反转概率
  54. # 生成图表
  55. fig, axes = plt.subplots(3, 1, figsize=(16, 12))
  56. state_names = ['Ranging', 'Trend', 'Reversal']
  57. colors = ['#2196F3', '#4CAF50', '#FF5722'] # 蓝、绿、橙
  58. # 图1: 价格走势 + 状态标记
  59. ax1 = axes[0]
  60. for i, (name, color) in enumerate(zip(state_names, colors)):
  61. mask = df_aligned['state'] == i
  62. if mask.any():
  63. ax1.scatter(df_aligned.index[mask], df_aligned['close'][mask],
  64. c=color, label=name, alpha=0.7, s=30)
  65. ax1.plot(df_aligned.index, df_aligned['close'], 'k-', alpha=0.3, linewidth=0.5)
  66. ax1.set_ylabel('Price', fontsize=12)
  67. ax1.set_title('CYB50 Market Regime Identification 2024-2025', fontsize=14, fontweight='bold')
  68. ax1.legend(loc='upper left')
  69. ax1.grid(True, alpha=0.3)
  70. # 添加关键点位标注
  71. for idx, row in df_aligned.iterrows():
  72. if idx.month == 1 and idx.day == 2: # 年初
  73. ax1.annotate(f'{row["close"]:.0f}',
  74. xy=(idx, row['close']),
  75. xytext=(10, 10), textcoords='offset points',
  76. fontsize=8, alpha=0.7)
  77. # 图2: 状态概率时间序列
  78. ax2 = axes[1]
  79. ax2.fill_between(df_aligned.index, 0, df_aligned['prob_ranging'],
  80. alpha=0.5, label='Ranging', color=colors[0])
  81. ax2.fill_between(df_aligned.index, df_aligned['prob_ranging'],
  82. df_aligned['prob_ranging'] + df_aligned['prob_trend'],
  83. alpha=0.5, label='Trend', color=colors[1])
  84. ax2.fill_between(df_aligned.index,
  85. df_aligned['prob_ranging'] + df_aligned['prob_trend'], 1,
  86. alpha=0.5, label='Reversal', color=colors[2])
  87. ax2.set_ylabel('Probability', fontsize=12)
  88. ax2.set_title('State Probability Over Time', fontsize=12)
  89. ax2.legend(loc='upper left')
  90. ax2.grid(True, alpha=0.3)
  91. ax2.set_ylim(0, 1)
  92. # 图3: 状态分布统计
  93. ax3 = axes[2]
  94. state_counts = df_aligned['state'].value_counts().sort_index()
  95. bars = ax3.bar(range(3), state_counts.values, color=colors, alpha=0.7)
  96. ax3.set_xticks(range(3))
  97. ax3.set_xticklabels(state_names)
  98. ax3.set_ylabel('Days', fontsize=12)
  99. ax3.set_title('State Distribution 2024-2025', fontsize=12)
  100. # 添加数值标签
  101. for i, (bar, count) in enumerate(zip(bars, state_counts.values)):
  102. pct = count / len(df_aligned) * 100
  103. ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 5,
  104. f'{count}d\n({pct:.1f}%)',
  105. ha='center', va='bottom', fontsize=10)
  106. plt.tight_layout()
  107. plt.savefig('/root/.openclaw/workspace/market-regime-identifier/cyb50_regime_2024_2025.png',
  108. dpi=150, bbox_inches='tight')
  109. print("\n✓ 图表已保存: cyb50_regime_2024_2025.png")
  110. # 生成详细报告
  111. print("\n" + "="*70)
  112. print("2024-2025年详细识别结果")
  113. print("="*70)
  114. # 按月份统计
  115. print("\n【月度统计】")
  116. print(f"{'月份':<10} {'总天数':<8} {'震荡':<8} {'趋势':<8} {'反转':<8} {'主要状态':<10}")
  117. print("-"*70)
  118. for year in [2024, 2025]:
  119. for month in range(1, 13):
  120. mask = (df_aligned.index.year == year) & (df_aligned.index.month == month)
  121. if not mask.any():
  122. continue
  123. month_data = df_aligned[mask]
  124. total = len(month_data)
  125. ranging = (month_data['state'] == 0).sum()
  126. trend = (month_data['state'] == 1).sum()
  127. reversal = (month_data['state'] == 2).sum()
  128. main_state = state_names[month_data['state'].mode()[0]]
  129. print(f"{year}-{month:02d} {total:<8} {ranging:<8} {trend:<8} {reversal:<8} {main_state:<10}")
  130. # 关键点位
  131. print("\n【关键点位标注】")
  132. print(f"{'日期':<12} {'收盘价':<10} {'状态':<10} {'置信度':<10} {'说明':<20}")
  133. print("-"*70)
  134. # 每月第一个交易日
  135. for year in [2024, 2025]:
  136. for month in range(1, 13):
  137. mask = (df_aligned.index.year == year) & (df_aligned.index.month == month)
  138. if not mask.any():
  139. continue
  140. month_data = df_aligned[mask]
  141. first_day = month_data.iloc[0]
  142. date_str = month_data.index[0].strftime('%Y-%m-%d')
  143. price = first_day['close']
  144. state = state_names[int(first_day['state'])]
  145. prob = first_day['state_prob']
  146. # 简单说明
  147. if first_day['state'] == 0:
  148. desc = 'Consolidation'
  149. elif first_day['state'] == 1:
  150. if month_data['close'].iloc[-1] > price:
  151. desc = 'Uptrend'
  152. else:
  153. desc = 'Downtrend'
  154. else:
  155. desc = 'Reversal'
  156. print(f"{date_str:<12} {price:<10.2f} {state:<10} {prob:<10.2%} {desc:<20}")
  157. print("\n" + "="*70)
  158. print("✓ 报告生成完成!")
  159. print("="*70)