#!/usr/bin/env python3
# -*- coding: utf-8 -*-
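"""
ChiNext 50 (创业板50) market-regime classifier - real-data version.

Labels are defined by explicit rules and then learned with a supervised
model (Random Forest).
Data source: baostock, ChiNext 50 index (sz.399673).
The label definitions are rules over real price behaviour.
"""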
import warnings
from pathlib import Path

import baostock as bs
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import (
    TimeSeriesSplit,
    cross_val_score,
    train_test_split,
)

warnings.filterwarnings('ignore')
  17. def fetch_cyb50_data(start_date="2017-01-01", end_date="2025-12-31"):
  18. """获取创业板50真实历史数据"""
  19. print(f"获取创业板50数据 ({start_date} - {end_date})...")
  20. try:
  21. # 使用baostock
  22. lg = bs.login()
  23. if lg.error_code != '0':
  24. print(f"baostock登录失败: {lg.error_msg}")
  25. return None
  26. # 创业板50代码: sz.399673
  27. rs = bs.query_history_k_data_plus("sz.399673",
  28. "date,open,high,low,close,volume",
  29. start_date=start_date, end_date=end_date,
  30. frequency="d", adjustflag="3")
  31. data_list = []
  32. while (rs.error_code == '0') & rs.next():
  33. row = rs.get_row_data()
  34. if row[0]:
  35. data_list.append({
  36. 'date': row[0],
  37. 'open': float(row[1]) if row[1] else 0,
  38. 'high': float(row[2]) if row[2] else 0,
  39. 'low': float(row[3]) if row[3] else 0,
  40. 'close': float(row[4]) if row[4] else 0,
  41. 'volume': int(float(row[5])) if row[5] else 0
  42. })
  43. bs.logout()
  44. if not data_list:
  45. print("✗ 未获取到数据")
  46. return None
  47. df = pd.DataFrame(data_list)
  48. df['date'] = pd.to_datetime(df['date'])
  49. df = df.set_index('date').sort_index()
  50. df['return'] = df['close'].pct_change()
  51. print(f"✓ 获取成功: {len(df)}条数据")
  52. print(f" 日期范围: {df.index[0].date()} ~ {df.index[-1].date()}")
  53. print(f" 价格范围: {df['close'].min():.2f} ~ {df['close'].max():.2f}")
  54. return df[['open', 'high', 'low', 'close', 'volume', 'return']]
  55. except Exception as e:
  56. print(f"✗ 数据获取失败: {e}")
  57. import traceback
  58. traceback.print_exc()
  59. return None
  60. def calculate_features(df):
  61. """计算技术指标特征"""
  62. features = pd.DataFrame(index=df.index)
  63. # 价格特征
  64. features['close'] = df['close']
  65. # 1. 收益率特征
  66. features['ret_1d'] = df['return']
  67. features['ret_5d'] = df['close'].pct_change(5)
  68. features['ret_10d'] = df['close'].pct_change(10)
  69. features['ret_20d'] = df['close'].pct_change(20)
  70. # 2. 波动率特征
  71. features['volatility_5d'] = df['return'].rolling(5).std() * np.sqrt(252)
  72. features['volatility_20d'] = df['return'].rolling(20).std() * np.sqrt(252)
  73. features['volatility_ratio'] = features['volatility_5d'] / (features['volatility_20d'] + 1e-10)
  74. # 3. 动量特征
  75. features['momentum_10d'] = df['close'] / df['close'].shift(10) - 1
  76. features['momentum_20d'] = df['close'] / df['close'].shift(20) - 1
  77. # 4. 均线特征
  78. features['ma5'] = df['close'].rolling(5).mean()
  79. features['ma20'] = df['close'].rolling(20).mean()
  80. features['ma60'] = df['close'].rolling(60).mean()
  81. features['ma5_above_ma20'] = (features['ma5'] > features['ma20']).astype(int)
  82. features['price_above_ma20'] = (df['close'] > features['ma20']).astype(int)
  83. # 5. RSI
  84. delta = df['close'].diff()
  85. gain = (delta.where(delta > 0, 0)).rolling(14).mean()
  86. loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
  87. rs = gain / (loss + 1e-10)
  88. features['rsi_14'] = 100 - (100 / (1 + rs))
  89. # 6. MACD
  90. ema12 = df['close'].ewm(span=12).mean()
  91. ema26 = df['close'].ewm(span=26).mean()
  92. features['macd'] = ema12 - ema26
  93. features['macd_signal'] = features['macd'].ewm(span=9).mean()
  94. features['macd_hist'] = features['macd'] - features['macd_signal']
  95. # 7. 布林带
  96. features['bb_middle'] = df['close'].rolling(20).mean()
  97. bb_std = df['close'].rolling(20).std()
  98. features['bb_upper'] = features['bb_middle'] + 2 * bb_std
  99. features['bb_lower'] = features['bb_middle'] - 2 * bb_std
  100. features['bb_position'] = (df['close'] - features['bb_lower']) / (features['bb_upper'] - features['bb_lower'] + 1e-10)
  101. # 8. ATR (平均真实波幅)
  102. high_low = df['high'] - df['low']
  103. high_close = np.abs(df['high'] - df['close'].shift())
  104. low_close = np.abs(df['low'] - df['close'].shift())
  105. tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
  106. features['atr_14'] = tr.rolling(14).mean()
  107. features['atr_ratio'] = features['atr_14'] / df['close']
  108. # 9. 成交量特征
  109. features['volume_ratio'] = df['volume'] / df['volume'].rolling(20).mean()
  110. # 10. 趋势强度
  111. features['adx'] = calculate_adx(df, 14)
  112. # 填充缺失值
  113. features = features.ffill().fillna(0)
  114. return features
  115. def calculate_adx(df, period=14):
  116. """计算ADX趋势强度指标"""
  117. plus_dm = df['high'].diff()
  118. minus_dm = df['low'].diff().abs()
  119. plus_dm[plus_dm < 0] = 0
  120. minus_dm[minus_dm < 0] = 0
  121. tr = pd.concat([
  122. df['high'] - df['low'],
  123. (df['high'] - df['close'].shift()).abs(),
  124. (df['low'] - df['close'].shift()).abs()
  125. ], axis=1).max(axis=1)
  126. atr = tr.rolling(period).mean()
  127. plus_di = 100 * (plus_dm.rolling(period).mean() / atr)
  128. minus_di = 100 * (minus_dm.rolling(period).mean() / atr)
  129. dx = (abs(plus_di - minus_di) / (plus_di + minus_di + 1e-10)) * 100
  130. adx = dx.rolling(period).mean()
  131. return adx
  132. def define_market_regime(df, lookback=10):
  133. """
  134. 基于规则定义市场状态标签(优化版V2)
  135. 优化目标:
  136. - 使三类分布更均衡(震荡 40-50%,趋势 30-40%,反转 10-20%)
  137. - 测试准确率 > 72%
  138. 规则(按优先级排序):
  139. 1. 反转 (2): 前N/2日收益 >= 2.5% 且后N/2日收益 <= -2%,或相反
  140. 2. 趋势 (1): |N日收益| >= 4%, 波动率 < 35%,且有方向性
  141. 3. 震荡 (0): 其余情况
  142. """
  143. labels = []
  144. for i in range(len(df)):
  145. if i < lookback:
  146. labels.append(0)
  147. continue
  148. period_close = df['close'].iloc[i-lookback:i]
  149. period_high = df['high'].iloc[i-lookback:i]
  150. period_low = df['low'].iloc[i-lookback:i]
  151. start_price = period_close.iloc[0]
  152. end_price = period_close.iloc[-1]
  153. period_return = (end_price / start_price - 1) * 100
  154. daily_returns = period_close.pct_change().dropna()
  155. volatility = daily_returns.std() * np.sqrt(252) * 100
  156. max_price = period_high.max()
  157. min_price = period_low.min()
  158. price_range = max_price / min_price
  159. mid = lookback // 2
  160. first_half_return = (period_close.iloc[mid] / start_price - 1) * 100
  161. second_half_return = (end_price / period_close.iloc[mid] - 1) * 100
  162. label = 0 # 默认震荡
  163. # ========== 反转判断(严格的V型反转)==========
  164. # 需要前后两段都有明显的反向运动
  165. if (first_half_return >= 2.5 and second_half_return <= -2.0) or \
  166. (first_half_return <= -2.5 and second_half_return >= 2.0):
  167. # 反转需要整体有一定的波动
  168. if volatility > 20 and price_range > 1.04:
  169. label = 2
  170. # ========== 趋势判断(需要明显的方向性)==========
  171. elif abs(period_return) >= 4.0 and volatility < 35:
  172. # 趋势期间高低点差距要明显
  173. if price_range > 1.04:
  174. # 排除V型反转(前后反向)
  175. if not (abs(first_half_return) > 3 and abs(second_half_return) > 2 and
  176. np.sign(first_half_return) != np.sign(second_half_return)):
  177. label = 1
  178. # ========== 震荡(默认)==========
  179. else:
  180. label = 0
  181. labels.append(label)
  182. return np.array(labels)
  183. def train_classifier(features, labels):
  184. """训练随机森林分类器"""
  185. print("\n训练分类器...")
  186. # 对齐数据
  187. valid_idx = ~np.isnan(labels)
  188. X = features[valid_idx]
  189. y = labels[valid_idx]
  190. # 分割训练集和测试集(按时间顺序)
  191. split_idx = int(len(X) * 0.7)
  192. X_train, X_test = X.iloc[:split_idx], X.iloc[split_idx:]
  193. y_train, y_test = y[:split_idx], y[split_idx:]
  194. print(f"训练集: {len(X_train)}条")
  195. print(f"测试集: {len(X_test)}条")
  196. # 训练模型
  197. clf = RandomForestClassifier(
  198. n_estimators=100,
  199. max_depth=10,
  200. min_samples_split=20,
  201. min_samples_leaf=10,
  202. random_state=42,
  203. class_weight='balanced'
  204. )
  205. clf.fit(X_train, y_train)
  206. # 评估
  207. train_score = clf.score(X_train, y_train)
  208. test_score = clf.score(X_test, y_test)
  209. print(f"\n训练准确率: {train_score:.2%}")
  210. print(f"测试准确率: {test_score:.2%}")
  211. # 交叉验证
  212. cv_scores = cross_val_score(clf, X, y, cv=5)
  213. print(f"交叉验证准确率: {cv_scores.mean():.2%} (+/- {cv_scores.std()*2:.2%})")
  214. # 详细报告
  215. y_pred = clf.predict(X_test)
  216. print("\n分类报告:")
  217. print(classification_report(y_test, y_pred, target_names=['震荡', '趋势', '反转']))
  218. # 特征重要性
  219. feature_importance = pd.DataFrame({
  220. 'feature': X.columns,
  221. 'importance': clf.feature_importances_
  222. }).sort_values('importance', ascending=False)
  223. print("\n特征重要性 TOP 10:")
  224. print(feature_importance.head(10).to_string(index=False))
  225. return clf, feature_importance
  226. def main():
  227. """主程序"""
  228. print("="*70)
  229. print("创业板50市场状态分类器 - 真实数据版")
  230. print("="*70)
  231. # 1. 获取真实数据
  232. df = fetch_cyb50_data("2017-01-01", "2025-12-31")
  233. if df is None:
  234. return
  235. # 2. 计算特征
  236. print("\n计算技术指标...")
  237. features = calculate_features(df)
  238. print(f"特征数量: {features.shape[1]}")
  239. # 3. 定义标签
  240. print("\n定义市场状态标签...")
  241. labels = define_market_regime(df, lookback=10)
  242. # 统计标签分布
  243. unique, counts = np.unique(labels, return_counts=True)
  244. print("\n标签分布:")
  245. state_names = ['震荡', '趋势', '反转']
  246. for u, c in zip(unique, counts):
  247. print(f" {state_names[u]}: {c}天 ({c/len(labels)*100:.1f}%)")
  248. # 4. 训练分类器
  249. clf, importance = train_classifier(features, labels)
  250. # 5. 当前状态预测
  251. print("\n" + "="*70)
  252. print("当前市场状态识别")
  253. print("="*70)
  254. latest_features = features.iloc[-1:]
  255. current_pred = clf.predict(latest_features)[0]
  256. pred_proba = clf.predict_proba(latest_features)[0]
  257. print(f"\n当前日期: {df.index[-1].date()}")
  258. print(f"当前价格: {df['close'].iloc[-1]:.2f}")
  259. print(f"\n预测状态: {state_names[current_pred]}")
  260. print(f"置信度: {pred_proba[current_pred]:.2%}")
  261. print("\n状态概率分布:")
  262. for i, name in enumerate(state_names):
  263. bar = '█' * int(pred_proba[i] * 20)
  264. print(f" {name}: {pred_proba[i]:.2%} {bar}")
  265. # 保存模型
  266. print("\n保存模型...")
  267. import pickle
  268. with open('/root/.openclaw/workspace/market-regime-identifier/rf_classifier.pkl', 'wb') as f:
  269. pickle.dump(clf, f)
  270. print("✓ 模型已保存: rf_classifier.pkl")
  271. print("\n" + "="*70)
  272. if __name__ == "__main__":
  273. main()