|
|
@@ -0,0 +1,416 @@
|
|
|
+#!/usr/bin/env python3
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+"""
|
|
|
+创业板50市场状态分类器 - 真实数据版(优化反转识别V3)
|
|
|
+基于规则定义标签,使用有监督学习(Random Forest)
|
|
|
+
|
|
|
+优化重点:提高反转识别率
|
|
|
+"""
|
|
|
+
|
|
|
+import numpy as np
|
|
|
+import pandas as pd
|
|
|
+from sklearn.ensemble import RandomForestClassifier
|
|
|
+from sklearn.model_selection import train_test_split, cross_val_score
|
|
|
+from sklearn.metrics import classification_report, confusion_matrix
|
|
|
+import baostock as bs
|
|
|
+import warnings
|
|
|
+warnings.filterwarnings('ignore')
|
|
|
+
|
|
|
+
|
|
|
def fetch_cyb50_data(start_date="2017-01-01", end_date="2025-12-31"):
    """Download daily bars for index ``sz.399673`` from baostock.

    Args:
        start_date: First date to request, formatted ``YYYY-MM-DD``.
        end_date: Last date to request, formatted ``YYYY-MM-DD``.

    Returns:
        A DataFrame indexed by date (ascending) with the columns
        ``open``, ``high``, ``low``, ``close``, ``volume`` and ``return``
        (one-day percentage change of ``close``), or ``None`` when login,
        the query, or parsing fails, or when no usable rows come back.
    """
    print(f"获取创业板50数据 ({start_date} - {end_date})...")

    try:
        lg = bs.login()
        if lg.error_code != '0':
            print(f"baostock登录失败: {lg.error_msg}")
            return None

        # From here on a session is open: always release it, even when the
        # query or row parsing raises.
        try:
            rs = bs.query_history_k_data_plus(
                "sz.399673",
                "date,open,high,low,close,volume",
                start_date=start_date, end_date=end_date,
                frequency="d", adjustflag="3")

            if rs.error_code != '0':
                # Report the real cause instead of falling through to the
                # generic "no data" message.
                print(f"✗ 数据查询失败: {rs.error_msg}")
                return None

            data_list = []
            # `and` short-circuits, so next() is never called on a failed
            # result set.
            while rs.error_code == '0' and rs.next():
                record = _parse_kline_row(rs.get_row_data())
                if record is not None:
                    data_list.append(record)
        finally:
            bs.logout()

        if not data_list:
            print("✗ 未获取到数据")
            return None

        df = pd.DataFrame(data_list)
        df['date'] = pd.to_datetime(df['date'])
        df = df.set_index('date').sort_index()
        df['return'] = df['close'].pct_change()

        print(f"✓ 获取成功: {len(df)}条数据")
        print(f" 日期范围: {df.index[0].date()} ~ {df.index[-1].date()}")
        print(f" 价格范围: {df['close'].min():.2f} ~ {df['close'].max():.2f}")

        return df[['open', 'high', 'low', 'close', 'volume', 'return']]

    except Exception as e:
        # Top-level boundary for this script: report the failure and let the
        # caller decide what to do with the None result.
        print(f"✗ 数据获取失败: {e}")
        import traceback
        traceback.print_exc()
        return None


def _parse_kline_row(row):
    """Turn one ``[date, open, high, low, close, volume]`` row into a record.

    Returns ``None`` for rows without a date or with any empty price field.
    Storing a missing price as 0 would later yield -100% / infinite returns,
    so such rows are skipped. A missing volume is kept as 0.
    """
    if not row or not row[0]:
        return None
    prices = row[1:5]
    if len(prices) < 4 or not all(prices):
        return None
    open_, high, low, close = (float(value) for value in prices)
    raw_volume = row[5] if len(row) > 5 else ''
    return {
        'date': row[0],
        'open': open_,
        'high': high,
        'low': low,
        'close': close,
        'volume': int(float(raw_volume)) if raw_volume else 0,
    }
|
|
|
+
|
|
|
+
|
|
|
def calculate_features(df):
    """Build the technical-indicator feature matrix for the classifier.

    Args:
        df: DataFrame with ``open``/``high``/``low``/``close``/``volume``
            columns. A ``return`` column (one-day percentage change of
            ``close``) is used when present and computed otherwise.

    Returns:
        A DataFrame on the same index as ``df`` containing return,
        volatility, momentum, moving-average, RSI, MACD, Bollinger, ATR,
        volume, ADX, acceleration, intraday-reversal, streak and 5-day
        position features. Infinite values are treated as missing; missing
        values are forward-filled and any that remain are set to 0.
        Column names and column order are part of the contract because the
        trained model depends on them.
    """
    eps = 1e-10
    close, high, low, volume = df['close'], df['high'], df['low'], df['volume']
    # Fall back to recomputing daily returns when the caller did not supply them.
    daily_ret = df['return'] if 'return' in df.columns else close.pct_change()

    features = pd.DataFrame(index=df.index)

    # Raw price level.
    features['close'] = close

    # 1. Returns over several horizons.
    features['ret_1d'] = daily_ret
    features['ret_5d'] = close.pct_change(5)
    features['ret_10d'] = close.pct_change(10)
    features['ret_20d'] = close.pct_change(20)

    # 2. Annualised volatility (252 trading days) and its short/long ratio.
    annualise = np.sqrt(252)
    features['volatility_5d'] = daily_ret.rolling(5).std() * annualise
    features['volatility_20d'] = daily_ret.rolling(20).std() * annualise
    features['volatility_ratio'] = features['volatility_5d'] / (features['volatility_20d'] + eps)

    # 3. Momentum (kept alongside ret_10d/ret_20d so the column set is unchanged).
    features['momentum_10d'] = close / close.shift(10) - 1
    features['momentum_20d'] = close / close.shift(20) - 1

    # 4. Moving averages and their relative position.
    features['ma5'] = close.rolling(5).mean()
    features['ma20'] = close.rolling(20).mean()
    features['ma60'] = close.rolling(60).mean()
    features['ma5_above_ma20'] = (features['ma5'] > features['ma20']).astype(int)
    features['price_above_ma20'] = (close > features['ma20']).astype(int)

    # 5. RSI(14) using simple rolling means of gains and losses.
    delta = close.diff()
    avg_gain = delta.where(delta > 0, 0).rolling(14).mean()
    avg_loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
    rs = avg_gain / (avg_loss + eps)
    features['rsi_14'] = 100 - (100 / (1 + rs))

    # RSI extremes and 3-day RSI change, intended as reversal hints.
    features['rsi_overbought'] = (features['rsi_14'] > 70).astype(int)
    features['rsi_oversold'] = (features['rsi_14'] < 30).astype(int)
    features['rsi_extreme'] = features['rsi_overbought'] + features['rsi_oversold']
    features['rsi_change'] = features['rsi_14'].diff(3)

    # 6. MACD (12/26 EMA difference, 9-span signal line).
    features['macd'] = close.ewm(span=12).mean() - close.ewm(span=26).mean()
    features['macd_signal'] = features['macd'].ewm(span=9).mean()
    features['macd_hist'] = features['macd'] - features['macd_signal']

    # MACD crossing its signal line: +1 golden cross, -1 death cross.
    macd_now, signal_now = features['macd'], features['macd_signal']
    macd_prev, signal_prev = macd_now.shift(1), signal_now.shift(1)
    features['macd_golden_cross'] = ((macd_now > signal_now) & (macd_prev <= signal_prev)).astype(int)
    features['macd_death_cross'] = ((macd_now < signal_now) & (macd_prev >= signal_prev)).astype(int)
    features['macd_cross'] = features['macd_golden_cross'] - features['macd_death_cross']

    # 7. Bollinger bands (20-day mean +/- 2 standard deviations).
    features['bb_middle'] = close.rolling(20).mean()
    bb_std = close.rolling(20).std()
    features['bb_upper'] = features['bb_middle'] + 2 * bb_std
    features['bb_lower'] = features['bb_middle'] - 2 * bb_std
    features['bb_position'] = (close - features['bb_lower']) / (features['bb_upper'] - features['bb_lower'] + eps)

    # Close within 1% of a band counts as touching it.
    features['bb_touch_upper'] = (close >= features['bb_upper'] * 0.99).astype(int)
    features['bb_touch_lower'] = (close <= features['bb_lower'] * 1.01).astype(int)
    features['bb_extreme'] = features['bb_touch_upper'] + features['bb_touch_lower']

    # 8. ATR(14) from the true range, plus ATR relative to price.
    prev_close = close.shift()
    true_range = pd.concat(
        [high - low, np.abs(high - prev_close), np.abs(low - prev_close)],
        axis=1,
    ).max(axis=1)
    features['atr_14'] = true_range.rolling(14).mean()
    features['atr_ratio'] = features['atr_14'] / close

    # 9. Volume relative to its 20-day average.
    features['volume_ratio'] = volume / volume.rolling(20).mean()
    features['volume_spike'] = (features['volume_ratio'] > 2).astype(int)

    # 10. Trend strength.
    features['adx'] = calculate_adx(df, 14)

    # 11. Price acceleration (second difference), also scaled by 1% of price.
    features['price_accel'] = close.diff().diff()
    features['price_accel_normalized'] = features['price_accel'] / (close * 0.01)

    # 12. Where the close sits inside the day's range: +1 at the low, -1 at the high.
    day_range = high - low + eps
    features['intraday_reversal'] = (high - close) / day_range - (close - low) / day_range

    # 13. Length of the current up / down streak; the group key advances on
    #     every day that breaks the streak.
    is_up = (daily_ret > 0).astype(int)
    is_down = (daily_ret < 0).astype(int)
    features['consecutive_up'] = is_up.groupby((daily_ret <= 0).astype(int).cumsum()).cumsum()
    features['consecutive_down'] = is_down.groupby((daily_ret >= 0).astype(int).cumsum()).cumsum()

    # 14. Position of the close inside the 5-day high/low range.
    low_5d = low.rolling(5).min()
    high_5d = high.rolling(5).max()
    features['price_position_5d'] = (close - low_5d) / (high_5d - low_5d + eps)

    # A zero price or zero denominator yields +/-inf, which fillna does not
    # touch and which estimators reject; treat it as missing before filling.
    features = features.replace([np.inf, -np.inf], np.nan).ffill().fillna(0)

    return features
|
|
|
+
|
|
|
+
|
|
|
def calculate_adx(df, period=14):
    """Compute an ADX-style trend-strength series.

    Directional movement follows the standard definition: on each bar only
    the larger of the upward move (``high - previous high``) and the downward
    move (``previous low - low``) counts, and only when it is positive.
    Smoothing uses simple rolling means over ``period`` bars rather than
    Wilder's recursive smoothing.

    Args:
        df: DataFrame with ``high``, ``low`` and ``close`` columns.
        period: Window length for ATR, the directional indicators and the
            final averaging of DX.

    Returns:
        A Series aligned with ``df.index``; the leading warm-up rows are NaN.
    """
    high, low = df['high'], df['low']
    prev_close = df['close'].shift()

    up_move = high.diff()
    # Positive when the low moved lower. Taking the absolute value of the
    # low's change would also count rising lows as downward movement.
    down_move = -low.diff()

    # The first bar has no predecessor: keep it missing rather than zero.
    has_prev = up_move.notna() & down_move.notna()
    plus_dm = up_move.where((up_move > down_move) & (up_move > 0), 0.0).where(has_prev)
    minus_dm = down_move.where((down_move > up_move) & (down_move > 0), 0.0).where(has_prev)

    true_range = pd.concat([
        high - low,
        (high - prev_close).abs(),
        (low - prev_close).abs(),
    ], axis=1).max(axis=1)
    atr = true_range.rolling(period).mean()

    plus_di = 100 * (plus_dm.rolling(period).mean() / atr)
    minus_di = 100 * (minus_dm.rolling(period).mean() / atr)

    # The epsilon avoids 0/0 when neither direction registered any movement.
    dx = ((plus_di - minus_di).abs() / (plus_di + minus_di + 1e-10)) * 100
    return dx.rolling(period).mean()
|
|
|
+
|
|
|
+
|
|
|
def define_market_regime(df, lookback=10):
    """Assign a rule-based market-regime label to every row of ``df``.

    The label for row ``i`` is derived from the ``lookback`` rows that
    precede it (rows ``i - lookback`` to ``i - 1``); row ``i`` itself is not
    part of the window. Rows without a full window get label 0.

    Labels:
        0 - ranging (default), 1 - trending, 2 - reversal.

    Args:
        df: DataFrame with ``close``, ``high`` and ``low`` columns.
        lookback: Number of preceding rows in each window; must be >= 1.

    Returns:
        Integer numpy array of length ``len(df)``.

    Raises:
        ValueError: If ``lookback`` is smaller than 1.
    """
    if lookback < 1:
        raise ValueError(f"lookback must be >= 1, got {lookback}")

    close, high, low = df['close'], df['high'], df['low']

    # RSI(14) from simple rolling means of gains and losses.
    delta = close.diff()
    avg_gain = delta.where(delta > 0, 0).rolling(14).mean()
    avg_loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
    rsi = 100 - (100 / (1 + avg_gain / (avg_loss + 1e-10)))

    mid = lookback // 2  # position splitting the window into two halves
    labels = np.zeros(len(df), dtype=int)
    for i in range(lookback, len(df)):
        window = slice(i - lookback, i)
        labels[i] = _classify_window(
            close.iloc[window], high.iloc[window], low.iloc[window],
            rsi.iloc[window], mid,
        )
    return labels


def _classify_window(period_close, period_high, period_low, period_rsi, mid):
    """Return the regime label (0, 1 or 2) for a single look-back window."""
    start_price = period_close.iloc[0]
    mid_price = period_close.iloc[mid]
    end_price = period_close.iloc[-1]

    # All returns are expressed in percent.
    period_return = (end_price / start_price - 1) * 100
    first_half_return = (mid_price / start_price - 1) * 100
    second_half_return = (end_price / mid_price - 1) * 100

    # Annualised volatility of daily returns inside the window, in percent.
    volatility = period_close.pct_change().dropna().std() * np.sqrt(252) * 100
    price_range = period_high.max() / period_low.min()

    rsi_start = period_rsi.iloc[0]
    rsi_change = period_rsi.iloc[-1] - rsi_start

    # Reversal evidence; comparisons against NaN RSI values are simply False.
    # 1. RSI started at an extreme and then moved sharply the other way.
    rsi_snapback = ((rsi_start > 68 and rsi_change < -18) or
                    (rsi_start < 32 and rsi_change > 18))
    # 2. The two halves of the window moved in opposite directions.
    halves_oppose = (first_half_return * second_half_return < 0 and
                     abs(first_half_return) > 1.8 and
                     abs(second_half_return) > 1.2)
    # 3. RSI reached overbought or oversold territory inside the window.
    rsi_extreme = period_rsi.max() > 72 or period_rsi.min() < 28
    # 4. Volatility is in a moderate band.
    moderate_vol = 15 < volatility < 45

    # Two or more pieces of evidence make a reversal; this takes priority
    # over the trend rule.
    if sum([rsi_snapback, halves_oppose, rsi_extreme, moderate_vol]) >= 2:
        return 2

    # Trend: a sizeable net move, bounded volatility and a wide enough range.
    if abs(period_return) >= 3.2 and volatility < 38 and price_range > 1.035:
        return 1

    return 0
|
|
|
+
|
|
|
+
|
|
|
def train_classifier(features, labels):
    """Train and evaluate a random-forest regime classifier.

    The first 70% of the rows (in their existing order) are used for
    training and the remaining 30% for testing. Accuracy, a forward-chaining
    cross-validation score, a classification report, a confusion matrix and
    the ten most important features are printed.

    Args:
        features: Feature DataFrame, one row per observation.
        labels: Array-like of class ids (0 ranging, 1 trending, 2 reversal)
            aligned with ``features``; NaN entries are dropped.

    Returns:
        Tuple ``(clf, feature_importance)``: the classifier fitted on the
        training slice, and a DataFrame with ``feature`` / ``importance``
        columns sorted by descending importance.
    """
    # Imported locally so this function carries its own dependency.
    from sklearn.model_selection import TimeSeriesSplit

    print("\n训练分类器...")

    class_ids = [0, 1, 2]
    class_names = ['震荡', '趋势', '反转']

    # Drop rows without a label.
    labels = np.asarray(labels)
    valid_idx = ~np.isnan(labels)
    X = features.loc[valid_idx]
    y = labels[valid_idx]

    # Chronological split: train on the earlier part, test on the later part.
    split_idx = int(len(X) * 0.7)
    X_train, X_test = X.iloc[:split_idx], X.iloc[split_idx:]
    y_train, y_test = y[:split_idx], y[split_idx:]

    print(f"训练集: {len(X_train)}条")
    print(f"测试集: {len(X_test)}条")

    # The reversal class gets the largest weight to favour its recall.
    clf = RandomForestClassifier(
        n_estimators=200,
        max_depth=15,
        min_samples_split=10,
        min_samples_leaf=5,
        random_state=42,
        class_weight={0: 1.0, 1: 1.2, 2: 2.0},
    )
    clf.fit(X_train, y_train)

    train_score = clf.score(X_train, y_train)
    test_score = clf.score(X_test, y_test)

    # Forward-chaining folds: each fold trains only on rows that come before
    # its validation rows, so the score is not inflated by look-ahead.
    cv_scores = cross_val_score(clf, X, y, cv=TimeSeriesSplit(n_splits=5))

    print(f"\n训练准确率: {train_score:.2%}")
    print(f"测试准确率: {test_score:.2%}")
    print(f"交叉验证准确率: {cv_scores.mean():.2%} (+/- {cv_scores.std()*2:.2%})")

    # Passing labels= keeps the report and matrix three classes wide even
    # when a class is missing from the test slice or the predictions.
    y_pred = clf.predict(X_test)
    print("\n分类报告:")
    print(classification_report(y_test, y_pred, labels=class_ids,
                                target_names=class_names, zero_division=0))

    cm = confusion_matrix(y_test, y_pred, labels=class_ids)
    print("\n混淆矩阵:")
    print(" 预测")
    print("真实 震荡 趋势 反转")
    for i, name in enumerate(class_names):
        row_total = cm[i].sum()
        recall = cm[i][i] / row_total if row_total > 0 else 0
        print(f"{name:6s} {cm[i]} (召回:{recall:.1%})")

    feature_importance = pd.DataFrame({
        'feature': X.columns,
        'importance': clf.feature_importances_,
    }).sort_values('importance', ascending=False)

    print("\n特征重要性 TOP 10:")
    print(feature_importance.head(10).to_string(index=False))

    return clf, feature_importance
|
|
|
+
|
|
|
+
|
|
|
def main():
    """Run the pipeline: fetch data, build features, label, train, predict, save.

    Prints progress and results to stdout. Returns early (with ``None``)
    when the data download fails. The trained classifier is pickled to a
    fixed path; a failure to write it is reported but does not raise.
    """
    import pickle
    from pathlib import Path

    print("="*70)
    print("创业板50市场状态分类器 - 真实数据版(优化反转识别V3)")
    print("="*70)

    # 1. Fetch historical data.
    df = fetch_cyb50_data("2017-01-01", "2025-12-31")
    if df is None:
        return

    # 2. Compute features.
    print("\n计算技术指标...")
    features = calculate_features(df)
    print(f"特征数量: {features.shape[1]}")

    # 3. Rule-based labels.
    print("\n定义市场状态标签...")
    labels = define_market_regime(df, lookback=10)

    # Label distribution.
    unique, counts = np.unique(labels, return_counts=True)
    print("\n标签分布:")
    state_names = ['震荡', '趋势', '反转']
    for u, c in zip(unique, counts):
        print(f" {state_names[u]}: {c}天 ({c/len(labels)*100:.1f}%)")

    # 4. Train the classifier.
    clf, importance = train_classifier(features, labels)

    # 5. Classify the most recent row.
    print("\n" + "="*70)
    print("当前市场状态识别")
    print("="*70)

    latest_features = features.iloc[-1:]
    current_pred = int(clf.predict(latest_features)[0])
    # predict_proba columns follow clf.classes_, which only contains classes
    # seen during training; key the probabilities by class id so a missing
    # class neither shifts the others nor raises an IndexError.
    proba_by_class = {
        int(class_id): float(p)
        for class_id, p in zip(clf.classes_, clf.predict_proba(latest_features)[0])
    }

    print(f"\n当前日期: {df.index[-1].date()}")
    print(f"当前价格: {df['close'].iloc[-1]:.2f}")
    print(f"\n预测状态: {state_names[current_pred]}")
    print(f"置信度: {proba_by_class[current_pred]:.2%}")

    print("\n状态概率分布:")
    for class_id, name in enumerate(state_names):
        prob = proba_by_class.get(class_id, 0.0)
        bar = '█' * int(prob * 20)
        print(f" {name}: {prob:.2%} {bar}")

    # Persist the model. The target directory may not exist on this machine,
    # so create it first and report write failures instead of crashing after
    # all the work above has completed.
    print("\n保存模型...")
    model_path = Path('/root/.openclaw/workspace/market-regime-identifier/rf_classifier_v3.pkl')
    try:
        model_path.parent.mkdir(parents=True, exist_ok=True)
        with model_path.open('wb') as f:
            pickle.dump(clf, f)
    except OSError as e:
        print(f"✗ 模型保存失败: {e}")
    else:
        print("✓ 模型已保存: rf_classifier_v3.pkl")

    print("\n" + "="*70)


if __name__ == "__main__":
    main()
|