cyb50_market_classifier_v3.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 创业板50市场状态分类器 - 真实数据版(优化反转识别V3)
  5. 基于规则定义标签,使用有监督学习(Random Forest)
  6. 优化重点:提高反转识别率
  7. """
  8. import numpy as np
  9. import pandas as pd
  10. from sklearn.ensemble import RandomForestClassifier
  11. from sklearn.model_selection import TimeSeriesSplit, cross_val_score
  12. from sklearn.metrics import classification_report, confusion_matrix
  13. import baostock as bs
  14. from pathlib import Path
  15. import warnings
  16. warnings.filterwarnings('ignore')
  17. PROJECT_DIR = Path(__file__).resolve().parent
  18. def fetch_cyb50_data(start_date="2017-01-01", end_date="2025-12-31"):
  19. """获取创业板50真实历史数据"""
  20. print(f"获取创业板50数据 ({start_date} - {end_date})...")
  21. try:
  22. lg = bs.login()
  23. if lg.error_code != '0':
  24. print(f"baostock登录失败: {lg.error_msg}")
  25. return None
  26. rs = bs.query_history_k_data_plus("sz.399673",
  27. "date,open,high,low,close,volume",
  28. start_date=start_date, end_date=end_date,
  29. frequency="d", adjustflag="3")
  30. data_list = []
  31. while (rs.error_code == '0') & rs.next():
  32. row = rs.get_row_data()
  33. if row[0]:
  34. data_list.append({
  35. 'date': row[0],
  36. 'open': float(row[1]) if row[1] else 0,
  37. 'high': float(row[2]) if row[2] else 0,
  38. 'low': float(row[3]) if row[3] else 0,
  39. 'close': float(row[4]) if row[4] else 0,
  40. 'volume': int(float(row[5])) if row[5] else 0
  41. })
  42. bs.logout()
  43. if not data_list:
  44. print("[ERR] 未获取到数据")
  45. return None
  46. df = pd.DataFrame(data_list)
  47. df['date'] = pd.to_datetime(df['date'])
  48. df = df.set_index('date').sort_index()
  49. df['return'] = df['close'].pct_change()
  50. print(f"[OK] 获取成功: {len(df)}条数据")
  51. print(f" 日期范围: {df.index[0].date()} ~ {df.index[-1].date()}")
  52. print(f" 价格范围: {df['close'].min():.2f} ~ {df['close'].max():.2f}")
  53. return df[['open', 'high', 'low', 'close', 'volume', 'return']]
  54. except Exception as e:
  55. print(f"[ERR] 数据获取失败: {e}")
  56. import traceback
  57. traceback.print_exc()
  58. return None
  59. def calculate_features(df):
  60. """计算技术指标特征(增加反转识别特征)"""
  61. features = pd.DataFrame(index=df.index)
  62. # 价格特征
  63. features['close'] = df['close']
  64. # 1. 收益率特征
  65. features['ret_1d'] = df['return']
  66. features['ret_5d'] = df['close'].pct_change(5)
  67. features['ret_10d'] = df['close'].pct_change(10)
  68. features['ret_20d'] = df['close'].pct_change(20)
  69. # 2. 波动率特征
  70. features['volatility_5d'] = df['return'].rolling(5).std() * np.sqrt(252)
  71. features['volatility_20d'] = df['return'].rolling(20).std() * np.sqrt(252)
  72. features['volatility_ratio'] = features['volatility_5d'] / (features['volatility_20d'] + 1e-10)
  73. # 3. 动量特征
  74. features['momentum_10d'] = df['close'] / df['close'].shift(10) - 1
  75. features['momentum_20d'] = df['close'] / df['close'].shift(20) - 1
  76. # 4. 均线特征
  77. features['ma5'] = df['close'].rolling(5).mean()
  78. features['ma20'] = df['close'].rolling(20).mean()
  79. features['ma60'] = df['close'].rolling(60).mean()
  80. features['ma5_above_ma20'] = (features['ma5'] > features['ma20']).astype(int)
  81. features['price_above_ma20'] = (df['close'] > features['ma20']).astype(int)
  82. # 5. RSI(增加超买超卖判断)
  83. delta = df['close'].diff()
  84. gain = (delta.where(delta > 0, 0)).rolling(14).mean()
  85. loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
  86. rs = gain / (loss + 1e-10)
  87. features['rsi_14'] = 100 - (100 / (1 + rs))
  88. # RSI极端值(用于识别反转)
  89. features['rsi_overbought'] = (features['rsi_14'] > 70).astype(int)
  90. features['rsi_oversold'] = (features['rsi_14'] < 30).astype(int)
  91. features['rsi_extreme'] = features['rsi_overbought'] + features['rsi_oversold']
  92. features['rsi_change'] = features['rsi_14'].diff(3) # 3日RSI变化
  93. # 6. MACD
  94. ema12 = df['close'].ewm(span=12).mean()
  95. ema26 = df['close'].ewm(span=26).mean()
  96. features['macd'] = ema12 - ema26
  97. features['macd_signal'] = features['macd'].ewm(span=9).mean()
  98. features['macd_hist'] = features['macd'] - features['macd_signal']
  99. # MACD金叉死叉(反转信号)
  100. features['macd_golden_cross'] = ((features['macd'] > features['macd_signal']) &
  101. (features['macd'].shift(1) <= features['macd_signal'].shift(1))).astype(int)
  102. features['macd_death_cross'] = ((features['macd'] < features['macd_signal']) &
  103. (features['macd'].shift(1) >= features['macd_signal'].shift(1))).astype(int)
  104. features['macd_cross'] = features['macd_golden_cross'] - features['macd_death_cross']
  105. # 7. 布林带
  106. features['bb_middle'] = df['close'].rolling(20).mean()
  107. bb_std = df['close'].rolling(20).std()
  108. features['bb_upper'] = features['bb_middle'] + 2 * bb_std
  109. features['bb_lower'] = features['bb_middle'] - 2 * bb_std
  110. features['bb_position'] = (df['close'] - features['bb_lower']) / (features['bb_upper'] - features['bb_lower'] + 1e-10)
  111. # 触及布林带上下轨(反转信号)
  112. features['bb_touch_upper'] = (df['close'] >= features['bb_upper'] * 0.99).astype(int)
  113. features['bb_touch_lower'] = (df['close'] <= features['bb_lower'] * 1.01).astype(int)
  114. features['bb_extreme'] = features['bb_touch_upper'] + features['bb_touch_lower']
  115. # 8. ATR
  116. high_low = df['high'] - df['low']
  117. high_close = np.abs(df['high'] - df['close'].shift())
  118. low_close = np.abs(df['low'] - df['close'].shift())
  119. tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
  120. features['atr_14'] = tr.rolling(14).mean()
  121. features['atr_ratio'] = features['atr_14'] / df['close']
  122. # 9. 成交量特征
  123. features['volume_ratio'] = df['volume'] / df['volume'].rolling(20).mean()
  124. features['volume_spike'] = (features['volume_ratio'] > 2).astype(int)
  125. # 10. 趋势强度
  126. features['adx'] = calculate_adx(df, 14)
  127. # 11. 价格变化加速度
  128. features['price_accel'] = df['close'].diff().diff()
  129. features['price_accel_normalized'] = features['price_accel'] / (df['close'] * 0.01)
  130. # 12. 日内反转强度
  131. features['intraday_reversal'] = ((df['high'] - df['close']) / (df['high'] - df['low'] + 1e-10) -
  132. (df['close'] - df['low']) / (df['high'] - df['low'] + 1e-10))
  133. # 13. 连续涨跌天数
  134. # 口径:return == 0 视为“中断连续序列”,且当天 up/down 都记 0
  135. ret_sign = np.sign(df['return'].fillna(0))
  136. up_mask = ret_sign > 0
  137. up_group = (~up_mask).cumsum()
  138. features['consecutive_up'] = up_mask.astype(int).groupby(up_group).cumsum()
  139. down_mask = ret_sign < 0
  140. down_group = (~down_mask).cumsum()
  141. features['consecutive_down'] = down_mask.astype(int).groupby(down_group).cumsum()
  142. # 14. 新增:5日价格位置(用于判断超买超卖后的位置)
  143. features['price_position_5d'] = (df['close'] - df['low'].rolling(5).min()) / (df['high'].rolling(5).max() - df['low'].rolling(5).min() + 1e-10)
  144. # 填充缺失值
  145. features = features.ffill().fillna(0)
  146. return features
  147. def calculate_adx(df, period=14):
  148. """计算ADX趋势强度指标(标准 Wilder 方法)"""
  149. up_move = df['high'].diff()
  150. down_move = -df['low'].diff()
  151. plus_dm = np.where((up_move > down_move) & (up_move > 0), up_move, 0.0)
  152. minus_dm = np.where((down_move > up_move) & (down_move > 0), down_move, 0.0)
  153. plus_dm = pd.Series(plus_dm, index=df.index)
  154. minus_dm = pd.Series(minus_dm, index=df.index)
  155. tr = pd.concat([
  156. df['high'] - df['low'],
  157. (df['high'] - df['close'].shift()).abs(),
  158. (df['low'] - df['close'].shift()).abs()
  159. ], axis=1).max(axis=1)
  160. atr = tr.ewm(alpha=1/period, adjust=False, min_periods=period).mean()
  161. plus_di = 100 * plus_dm.ewm(alpha=1/period, adjust=False, min_periods=period).mean() / (atr + 1e-10)
  162. minus_di = 100 * minus_dm.ewm(alpha=1/period, adjust=False, min_periods=period).mean() / (atr + 1e-10)
  163. dx = (plus_di - minus_di).abs() / (plus_di + minus_di + 1e-10) * 100
  164. adx = dx.ewm(alpha=1/period, adjust=False, min_periods=period).mean()
  165. return adx
  166. def define_market_regime(df, lookback=10):
  167. """
  168. 基于规则定义市场状态标签(最终平衡版)
  169. 目标:反转识别率50-60%,整体准确率>72%
  170. """
  171. labels = []
  172. # 预计算RSI和MACD
  173. delta = df['close'].diff()
  174. gain = (delta.where(delta > 0, 0)).rolling(14).mean()
  175. loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
  176. rs = gain / (loss + 1e-10)
  177. rsi = 100 - (100 / (1 + rs))
  178. ema12 = df['close'].ewm(span=12).mean()
  179. ema26 = df['close'].ewm(span=26).mean()
  180. macd = ema12 - ema26
  181. for i in range(len(df)):
  182. if i < lookback:
  183. labels.append(0)
  184. continue
  185. # 获取回看期间数据
  186. period_close = df['close'].iloc[i-lookback:i]
  187. period_high = df['high'].iloc[i-lookback:i]
  188. period_low = df['low'].iloc[i-lookback:i]
  189. period_rsi = rsi.iloc[i-lookback:i]
  190. start_price = period_close.iloc[0]
  191. end_price = period_close.iloc[-1]
  192. period_return = (end_price / start_price - 1) * 100
  193. daily_returns = period_close.pct_change().dropna()
  194. volatility = daily_returns.std() * np.sqrt(252) * 100
  195. max_price = period_high.max()
  196. min_price = period_low.min()
  197. price_range = max_price / min_price
  198. mid = lookback // 2
  199. first_half_return = (period_close.iloc[mid] / start_price - 1) * 100
  200. second_half_return = (end_price / period_close.iloc[mid] - 1) * 100
  201. # RSI特征
  202. rsi_start = period_rsi.iloc[0]
  203. rsi_end = period_rsi.iloc[-1]
  204. rsi_max = period_rsi.max()
  205. rsi_min = period_rsi.min()
  206. rsi_change = rsi_end - rsi_start
  207. # 定义标签
  208. label = 0 # 默认震荡
  209. # ========== 反转判断(收紧到明确前后反向)==========
  210. reversal_core = (
  211. (first_half_return >= 2.5 and second_half_return <= -2.0) or
  212. (first_half_return <= -2.5 and second_half_return >= 2.0)
  213. )
  214. rsi_confirmation = (
  215. (rsi_start > 68 and rsi_change < -18) or
  216. (rsi_start < 32 and rsi_change > 18) or
  217. (rsi_max > 72 or rsi_min < 28)
  218. )
  219. if reversal_core and volatility > 20 and price_range > 1.04 and rsi_confirmation:
  220. label = 2
  221. # ========== 趋势判断(向主线边界靠拢) ==========
  222. elif abs(period_return) >= 4.0 and volatility < 35:
  223. if price_range > 1.04:
  224. if not (abs(first_half_return) > 3 and abs(second_half_return) > 2 and
  225. np.sign(first_half_return) != np.sign(second_half_return)):
  226. label = 1
  227. # ========== 震荡判断(默认)=========
  228. else:
  229. label = 0
  230. labels.append(label)
  231. return np.array(labels)
  232. def train_classifier(features, labels):
  233. """训练随机森林分类器"""
  234. print("\n训练分类器...")
  235. # 对齐数据
  236. valid_idx = ~np.isnan(labels)
  237. X = features[valid_idx]
  238. y = labels[valid_idx]
  239. # 分割训练集和测试集(按时间顺序)
  240. split_idx = int(len(X) * 0.7)
  241. X_train, X_test = X.iloc[:split_idx], X.iloc[split_idx:]
  242. y_train, y_test = y[:split_idx], y[split_idx:]
  243. print(f"训练集: {len(X_train)}条")
  244. print(f"测试集: {len(X_test)}条")
  245. # 训练模型 - 调整参数提高对反转的识别
  246. clf = RandomForestClassifier(
  247. n_estimators=200, # 增加树的数量
  248. max_depth=15, # 增加深度
  249. min_samples_split=10,
  250. min_samples_leaf=5,
  251. random_state=42,
  252. class_weight='balanced'
  253. )
  254. clf.fit(X_train, y_train)
  255. # 评估
  256. train_score = clf.score(X_train, y_train)
  257. test_score = clf.score(X_test, y_test)
  258. # 时间序列交叉验证(避免未来数据泄漏到过去)
  259. tscv = TimeSeriesSplit(n_splits=5)
  260. cv_scores = cross_val_score(clf, X, y, cv=tscv)
  261. print(f"\n训练准确率: {train_score:.2%}")
  262. print(f"测试准确率: {test_score:.2%}")
  263. print(f"时间序列交叉验证准确率: {cv_scores.mean():.2%} (+/- {cv_scores.std()*2:.2%})")
  264. # 详细报告
  265. y_pred = clf.predict(X_test)
  266. print("\n分类报告:")
  267. print(classification_report(y_test, y_pred, target_names=['震荡', '趋势', '反转']))
  268. # 混淆矩阵
  269. cm = confusion_matrix(y_test, y_pred)
  270. print("\n混淆矩阵:")
  271. print(" 预测")
  272. print("真实 震荡 趋势 反转")
  273. for i, name in enumerate(['震荡', '趋势', '反转']):
  274. recall = cm[i][i] / cm[i].sum() if cm[i].sum() > 0 else 0
  275. print(f"{name:6s} {cm[i]} (召回:{recall:.1%})")
  276. # 特征重要性
  277. feature_importance = pd.DataFrame({
  278. 'feature': X.columns,
  279. 'importance': clf.feature_importances_
  280. }).sort_values('importance', ascending=False)
  281. print("\n特征重要性 TOP 10:")
  282. print(feature_importance.head(10).to_string(index=False))
  283. return clf, feature_importance
  284. def main():
  285. """主程序"""
  286. print("="*70)
  287. print("创业板50市场状态分类器 - 真实数据版(优化反转识别V3)")
  288. print("="*70)
  289. # 1. 获取真实数据
  290. df = fetch_cyb50_data("2017-01-01", "2025-12-31")
  291. if df is None:
  292. return
  293. # 2. 计算特征
  294. print("\n计算技术指标...")
  295. features = calculate_features(df)
  296. print(f"特征数量: {features.shape[1]}")
  297. # 3. 定义标签
  298. print("\n定义市场状态标签...")
  299. labels = define_market_regime(df, lookback=10)
  300. # 统计标签分布
  301. unique, counts = np.unique(labels, return_counts=True)
  302. print("\n标签分布:")
  303. state_names = ['震荡', '趋势', '反转']
  304. for u, c in zip(unique, counts):
  305. print(f" {state_names[u]}: {c}天 ({c/len(labels)*100:.1f}%)")
  306. # 4. 训练分类器
  307. clf, importance = train_classifier(features, labels)
  308. # 5. 当前状态预测
  309. print("\n" + "="*70)
  310. print("当前市场状态识别")
  311. print("="*70)
  312. latest_features = features.iloc[-1:]
  313. current_pred = clf.predict(latest_features)[0]
  314. pred_proba = clf.predict_proba(latest_features)[0]
  315. print(f"\n当前日期: {df.index[-1].date()}")
  316. print(f"当前价格: {df['close'].iloc[-1]:.2f}")
  317. print(f"\n预测状态: {state_names[current_pred]}")
  318. print(f"置信度: {pred_proba[current_pred]:.2%}")
  319. print("\n状态概率分布:")
  320. for i, name in enumerate(state_names):
  321. bar = '#' * int(pred_proba[i] * 20)
  322. print(f" {name}: {pred_proba[i]:.2%} {bar}")
  323. # 保存模型
  324. print("\n保存模型...")
  325. import pickle
  326. model_path = PROJECT_DIR / 'rf_classifier_v3.pkl'
  327. with open(model_path, 'wb') as f:
  328. pickle.dump(clf, f)
  329. print(f"[OK] 模型已保存: {model_path.name}")
  330. print("\n" + "="*70)
  331. if __name__ == "__main__":
  332. main()