#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""ChiNext 50 (sz399673) market-regime classifier trained on real market data.

Pipeline:
    1. Download daily OHLCV history (baostock by default, akshare optionally)
       and, if requested, append today's real-time quote.
    2. Compute technical-indicator features.
    3. Label every day with a rule-based regime
       (0 = range-bound, 1 = trend, 2 = reversal).
    4. Fit a RandomForest on features -> labels and report the regime
       predicted for the most recent row.
"""

import pickle
import warnings
from datetime import datetime, timedelta
from pathlib import Path

import numpy as np
import pandas as pd

# Third-party packages needed only by part of the pipeline are imported
# defensively (the file already treats akshare this way) so that the pure
# pandas/numpy helpers stay importable without the full stack installed.
try:
    import baostock as bs
except ImportError:
    bs = None

try:
    import requests
except ImportError:
    requests = None

try:
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.metrics import classification_report, confusion_matrix
    from sklearn.model_selection import (
        TimeSeriesSplit,
        cross_val_score,
        train_test_split,
    )
except ImportError:
    RandomForestClassifier = None

warnings.filterwarnings('ignore')

# Display names indexed by the integer regime label.
STATE_NAMES = ['震荡', '趋势', '反转']

# Where main() stores the fitted classifier.
MODEL_PATH = Path(
    '/root/.openclaw/workspace/market-regime-identifier/rf_classifier.pkl'
)

_OHLCV_COLUMNS = ['open', 'high', 'low', 'close', 'volume', 'return']


def fetch_cyb50_data_baostock(start_date="2017-01-01", end_date="2025-12-31"):
    """Download daily ChiNext 50 history from baostock.

    Args:
        start_date: First day to request, ``YYYY-MM-DD``.
        end_date: Last day to request, ``YYYY-MM-DD``.

    Returns:
        DataFrame indexed by date with columns
        ``open, high, low, close, volume, return``, or ``None`` when
        baostock is unavailable, the login/query fails, or no rows come back.
    """
    print(f"[baostock] 获取创业板50数据 ({start_date} - {end_date})...")

    if bs is None:
        print("✗ baostock未安装,尝试安装: pip install baostock")
        return None

    try:
        lg = bs.login()
        if lg.error_code != '0':
            print(f"baostock登录失败: {lg.error_msg}")
            return None

        data_list = []
        try:
            rs = bs.query_history_k_data_plus(
                "sz.399673",
                "date,open,high,low,close,volume",
                start_date=start_date,
                end_date=end_date,
                frequency="d",
                adjustflag="3",
            )
            if rs.error_code != '0':
                print(f"✗ baostock查询失败: {rs.error_msg}")
                return None

            # `and` short-circuits, so next() is not called after an error.
            while rs.error_code == '0' and rs.next():
                row = rs.get_row_data()
                if not row[0]:
                    continue
                data_list.append({
                    'date': row[0],
                    # Missing prices become NaN (not 0) so they cannot
                    # produce -100% / infinite returns further down.
                    'open': float(row[1]) if row[1] else np.nan,
                    'high': float(row[2]) if row[2] else np.nan,
                    'low': float(row[3]) if row[3] else np.nan,
                    'close': float(row[4]) if row[4] else np.nan,
                    'volume': int(float(row[5])) if row[5] else 0,
                })
        finally:
            # Always release the session, even when the query raises.
            bs.logout()

        if not data_list:
            print("✗ baostock未获取到数据")
            return None

        df = pd.DataFrame(data_list)
        df = df.dropna(subset=['open', 'high', 'low', 'close'])
        if df.empty:
            print("✗ baostock未获取到数据")
            return None

        df['date'] = pd.to_datetime(df['date'])
        df = df.set_index('date').sort_index()
        df['return'] = df['close'].pct_change()

        print(f"✓ baostock获取成功: {len(df)}条数据 (至 {df.index[-1].date()})")
        return df[_OHLCV_COLUMNS]

    except Exception as e:
        print(f"✗ baostock获取失败: {e}")
        return None


def fetch_cyb50_data_akshare(start_date="2024-01-01", end_date=None):
    """Download daily ChiNext 50 history from akshare.

    Args:
        start_date: First day to request, ``YYYY-MM-DD``.
        end_date: Last day to request, ``YYYY-MM-DD``; today when omitted.

    Returns:
        DataFrame indexed by date with columns
        ``open, high, low, close, volume, return``, or ``None`` on failure.
    """
    print("[akshare] 获取创业板50数据...")

    try:
        import akshare as ak

        # The API expects a date string; never hand it None.
        if not end_date:
            end_date = datetime.now().strftime("%Y-%m-%d")

        # symbol "399673" is the ChiNext 50 index.
        df = ak.index_zh_a_hist(
            symbol="399673",
            period="daily",
            start_date=start_date.replace("-", ""),
            end_date=end_date.replace("-", ""),
        )

        if df is None or df.empty:
            print("✗ akshare未获取到数据")
            return None

        # Translate the Chinese column headers returned by akshare.
        df = df.rename(columns={
            '日期': 'date',
            '开盘': 'open',
            '收盘': 'close',
            '最高': 'high',
            '最低': 'low',
            '成交量': 'volume',
        })

        df['date'] = pd.to_datetime(df['date'])
        df = df.set_index('date').sort_index()
        df['return'] = df['close'].pct_change()

        print(f"✓ akshare获取成功: {len(df)}条数据 (至 {df.index[-1].date()})")
        return df[_OHLCV_COLUMNS]

    except ImportError:
        print("✗ akshare未安装,尝试安装: pip install akshare")
        return None
    except Exception as e:
        print(f"✗ akshare获取失败: {e}")
        return None


def fetch_cyb50_realtime_sina():
    """Fetch the current ChiNext 50 quote from the Sina Finance endpoint.

    Returns:
        dict with keys ``date, time, open, high, low, close, pre_close,
        volume`` (``close`` is the latest traded price), or ``None`` when the
        request fails or the payload is not in the expected shape.
    """
    print("[新浪财经] 获取创业板50实时数据...")

    if requests is None:
        print("✗ requests未安装,尝试安装: pip install requests")
        return None

    try:
        url = "https://hq.sinajs.cn/list=sz399673"
        headers = {
            'Referer': 'https://finance.sina.com.cn',
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        }

        response = requests.get(url, headers=headers, timeout=10)
        # Surface HTTP errors instead of trying to parse an error page.
        response.raise_for_status()
        response.encoding = 'gb2312'

        data_str = response.text
        if 'var hq_str_sz399673=' not in data_str:
            print("✗ 新浪财经返回格式异常")
            return None

        # The payload is the comma-separated text between the double quotes.
        data_part = data_str.split('"')[1]
        fields = data_part.split(',')

        if len(fields) < 33:
            print("✗ 新浪财经字段不足")
            return None

        # Field layout used below:
        #   0 name, 1 open, 2 previous close, 3 last price, 4 high, 5 low,
        #   8 volume (lots), 30 date, 31 time
        realtime_data = {
            'date': fields[30],   # YYYY-MM-DD
            'time': fields[31],   # HH:MM:SS
            'open': float(fields[1]),
            'high': float(fields[4]),
            'low': float(fields[5]),
            'close': float(fields[3]),  # latest price stands in for close
            'pre_close': float(fields[2]),
            'volume': int(float(fields[8])),
        }

        print(f"✓ 新浪财经实时数据: {realtime_data['date']} {realtime_data['time']} 收盘:{realtime_data['close']:.2f}")
        return realtime_data

    except Exception as e:
        print(f"✗ 新浪财经获取失败: {e}")
        return None


def fetch_cyb50_realtime_akshare():
    """Fetch the current ChiNext 50 quote from akshare's spot-index table.

    Returns:
        dict with keys ``date, time, open, high, low, close, pre_close,
        volume``, or ``None`` on any failure. ``date`` is the local calendar
        date at call time, not a date reported by the exchange.
    """
    print("[akshare实时] 获取创业板50实时行情...")

    try:
        import akshare as ak

        # NOTE(review): the function name and the column names read below
        # ('时间', '开盘', ...) depend on the installed akshare version;
        # confirm against the version in use. Any mismatch raises, lands in
        # the except branch and the caller falls back to Sina.
        df = ak.index_zh_a_spot_em()

        cyb50_row = df[df['代码'] == '399673']

        if cyb50_row.empty:
            print("✗ akshare未找到创业板50数据")
            return None

        row = cyb50_row.iloc[0]

        realtime_data = {
            'date': datetime.now().strftime('%Y-%m-%d'),
            'time': row['时间'],
            'open': float(row['开盘']),
            'high': float(row['最高']),
            'low': float(row['最低']),
            'close': float(row['最新价']),
            'pre_close': float(row['昨收']),
            'volume': int(float(row['成交量'])),
        }

        print(f"✓ akshare实时数据: {realtime_data['date']} {realtime_data['time']} 收盘:{realtime_data['close']:.2f}")
        return realtime_data

    except Exception as e:
        print(f"✗ akshare实时获取失败: {e}")
        return None


def merge_history_and_realtime(history_df, realtime_data):
    """Append a real-time quote to the history as one additional daily row.

    The quote is skipped (history returned unchanged) when either input is
    missing, the history is empty, the quote's date is already present, the
    date is not later than the last historical date, or the date falls on a
    weekend. The weekend check matters because the akshare real-time source
    stamps the quote with the local calendar date.

    Args:
        history_df: Date-indexed OHLCV frame, or ``None``.
        realtime_data: dict as produced by the ``fetch_cyb50_realtime_*``
            helpers, or ``None``.

    Returns:
        The merged DataFrame, or ``history_df`` unchanged when skipped.
    """
    if history_df is None or realtime_data is None or history_df.empty:
        return history_df

    realtime_date = pd.to_datetime(realtime_data['date'])

    if realtime_date in history_df.index:
        print(f"⚠️ 实时数据日期 {realtime_date.date()} 已存在于历史数据中,跳过合并")
        return history_df

    last_hist_date = history_df.index[-1]

    # An older quote would break the chronological order of the index.
    if realtime_date < last_hist_date:
        print(f"⚠️ 实时数据日期 {realtime_date.date()} 早于历史最后日期 {last_hist_date.date()},跳过合并")
        return history_df

    # Saturday/Sunday quotes just repeat the previous session.
    if realtime_date.weekday() >= 5:
        print(f"⚠️ 实时数据日期 {realtime_date.date()} 为周末(非交易日),跳过合并")
        return history_df

    # Next weekday after the last historical row (holidays are not known).
    expected_next_date = last_hist_date + timedelta(days=1)
    while expected_next_date.weekday() >= 5:  # 5 = Saturday, 6 = Sunday
        expected_next_date += timedelta(days=1)

    if realtime_date != expected_next_date and (realtime_date - last_hist_date).days > 3:
        print(f"⚠️ 日期跨度较大: 历史最后日期 {last_hist_date.date()}, 实时日期 {realtime_date.date()}")
        print(" 可能是节假日,仍尝试合并")

    # Fall back to the last historical close when the feed gives no usable
    # previous close, instead of dividing by zero.
    prev_close = realtime_data.get('pre_close') or history_df['close'].iloc[-1]

    new_row = pd.DataFrame({
        'open': [realtime_data['open']],
        'high': [realtime_data['high']],
        'low': [realtime_data['low']],
        'close': [realtime_data['close']],
        'volume': [realtime_data['volume']],
        'return': [realtime_data['close'] / prev_close - 1],
    }, index=[realtime_date])

    merged_df = pd.concat([history_df, new_row])

    print(f"✓ 数据合并完成: 历史{len(history_df)}条 + 实时1条 = {len(merged_df)}条")
    print(f" 最新日期: {merged_df.index[-1].date()} 收盘价: {merged_df['close'].iloc[-1]:.2f}")

    return merged_df


def fetch_cyb50_data(start_date="2017-01-01", end_date="2025-12-31",
                     use_realtime=True, prefer_source='baostock'):
    """Fetch ChiNext 50 data from the configured sources.

    Args:
        start_date: First day of history, ``YYYY-MM-DD``.
        end_date: Last day of history, ``YYYY-MM-DD``.
        use_realtime: When true, try to append today's quote (akshare first,
            Sina Finance as fallback).
        prefer_source: ``'baostock'``, ``'akshare'`` or ``'mixed'``
            (baostock first, akshare if that fails). Any other value fetches
            nothing and the function returns ``None``.

    Returns:
        DataFrame with columns ``open, high, low, close, volume, return``,
        or ``None`` when no history could be obtained.
    """
    print("="*60)
    print("创业板50数据获取 - 多数据源模式")
    print("="*60)

    history_df = None

    # 1. History (up to T-1).
    if prefer_source == 'baostock' or prefer_source == 'mixed':
        history_df = fetch_cyb50_data_baostock(start_date, end_date)
        if history_df is None and prefer_source == 'mixed':
            print("尝试备用数据源 akshare...")
            history_df = fetch_cyb50_data_akshare(start_date, end_date)
    elif prefer_source == 'akshare':
        history_df = fetch_cyb50_data_akshare(start_date, end_date)

    if history_df is None:
        print("✗ 历史数据获取失败")
        return None

    # 2. Optional real-time row.
    if use_realtime:
        print("\n" + "-"*40)
        print("尝试获取今日实时数据...")
        print("-"*40)

        realtime_data = fetch_cyb50_realtime_akshare()
        if realtime_data is None:
            realtime_data = fetch_cyb50_realtime_sina()

        if realtime_data:
            history_df = merge_history_and_realtime(history_df, realtime_data)
        else:
            print("⚠️ 未能获取实时数据,仅使用历史数据")

    print("\n" + "="*60)
    print(f"最终数据: {len(history_df)}条")
    print(f"日期范围: {history_df.index[0].date()} ~ {history_df.index[-1].date()}")
    print(f"价格范围: {history_df['close'].min():.2f} ~ {history_df['close'].max():.2f}")
    print("="*60)

    return history_df


def calculate_features(df):
    """Compute the technical-indicator feature matrix.

    Args:
        df: Date-indexed frame with ``high, low, close, volume, return``.

    Returns:
        DataFrame on the same index with 27 feature columns. Infinite values
        are treated as missing; missing values are forward-filled and any
        remaining gaps (the indicator warm-up rows) are set to 0.
    """
    features = pd.DataFrame(index=df.index)

    # Raw price level.
    features['close'] = df['close']

    # 1. Returns over several horizons.
    features['ret_1d'] = df['return']
    features['ret_5d'] = df['close'].pct_change(5)
    features['ret_10d'] = df['close'].pct_change(10)
    features['ret_20d'] = df['close'].pct_change(20)

    # 2. Annualised volatility (252 trading days).
    features['volatility_5d'] = df['return'].rolling(5).std() * np.sqrt(252)
    features['volatility_20d'] = df['return'].rolling(20).std() * np.sqrt(252)
    features['volatility_ratio'] = features['volatility_5d'] / (features['volatility_20d'] + 1e-10)

    # 3. Momentum.
    features['momentum_10d'] = df['close'] / df['close'].shift(10) - 1
    features['momentum_20d'] = df['close'] / df['close'].shift(20) - 1

    # 4. Moving averages and their crossover flags.
    features['ma5'] = df['close'].rolling(5).mean()
    features['ma20'] = df['close'].rolling(20).mean()
    features['ma60'] = df['close'].rolling(60).mean()
    features['ma5_above_ma20'] = (features['ma5'] > features['ma20']).astype(int)
    features['price_above_ma20'] = (df['close'] > features['ma20']).astype(int)

    # 5. RSI (simple-moving-average variant).
    delta = df['close'].diff()
    gain = (delta.where(delta > 0, 0)).rolling(14).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
    rs = gain / (loss + 1e-10)
    features['rsi_14'] = 100 - (100 / (1 + rs))

    # 6. MACD.
    ema12 = df['close'].ewm(span=12).mean()
    ema26 = df['close'].ewm(span=26).mean()
    features['macd'] = ema12 - ema26
    features['macd_signal'] = features['macd'].ewm(span=9).mean()
    features['macd_hist'] = features['macd'] - features['macd_signal']

    # 7. Bollinger bands.
    features['bb_middle'] = df['close'].rolling(20).mean()
    bb_std = df['close'].rolling(20).std()
    features['bb_upper'] = features['bb_middle'] + 2 * bb_std
    features['bb_lower'] = features['bb_middle'] - 2 * bb_std
    features['bb_position'] = (df['close'] - features['bb_lower']) / (features['bb_upper'] - features['bb_lower'] + 1e-10)

    # 8. ATR (average true range).
    high_low = df['high'] - df['low']
    high_close = np.abs(df['high'] - df['close'].shift())
    low_close = np.abs(df['low'] - df['close'].shift())
    tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
    features['atr_14'] = tr.rolling(14).mean()
    features['atr_ratio'] = features['atr_14'] / df['close']

    # 9. Volume relative to its 20-day mean.
    features['volume_ratio'] = df['volume'] / df['volume'].rolling(20).mean()

    # 10. Trend strength.
    features['adx'] = calculate_adx(df, 14)

    # Infinities (division by zero) would make the estimator reject the
    # matrix, so treat them as missing before filling.
    features = features.replace([np.inf, -np.inf], np.nan)
    features = features.ffill().fillna(0)

    return features


def calculate_adx(df, period=14):
    """Compute an ADX-style trend-strength series.

    Directional movement follows the standard definition: on each bar only
    the larger of the upward move (high minus previous high) and the
    downward move (previous low minus low) counts, and only if positive.
    Smoothing uses plain rolling means over ``period`` bars, not Wilder's
    recursive smoothing.

    Args:
        df: Frame with ``high, low, close`` columns.
        period: Look-back length for every rolling average.

    Returns:
        Series of ADX values in [0, 100]; NaN during the warm-up rows.
    """
    up_move = df['high'].diff()
    # Positive when the low falls. Taking abs() here would also count
    # rising lows as downward movement and cancel out every uptrend.
    down_move = -df['low'].diff()

    plus_dm = up_move.where((up_move > down_move) & (up_move > 0), 0.0)
    minus_dm = down_move.where((down_move > up_move) & (down_move > 0), 0.0)
    # The first bar has no predecessor, so its movement is undefined.
    plus_dm = plus_dm.where(up_move.notna())
    minus_dm = minus_dm.where(down_move.notna())

    tr = pd.concat([
        df['high'] - df['low'],
        (df['high'] - df['close'].shift()).abs(),
        (df['low'] - df['close'].shift()).abs()
    ], axis=1).max(axis=1)

    # A zero ATR (completely flat bars) would divide by zero below.
    atr = tr.rolling(period).mean().replace(0, np.nan)
    plus_di = 100 * (plus_dm.rolling(period).mean() / atr)
    minus_di = 100 * (minus_dm.rolling(period).mean() / atr)

    dx = (abs(plus_di - minus_di) / (plus_di + minus_di + 1e-10)) * 100
    adx = dx.rolling(period).mean()

    return adx


def define_market_regime(df, lookback=10):
    """Assign a rule-based market-regime label to every row.

    For row ``i`` the rules look at the ``lookback`` rows before it,
    ``[i - lookback, i)``; row ``i`` itself is not part of its own window.
    The first ``lookback`` rows are labelled 0.

    Rules, in priority order (returns in percent, volatility annualised):
        2 (reversal): the two halves of the window move in opposite
            directions (>= 2.5 then <= -2.0, or <= -2.5 then >= 2.0) with
            volatility > 20 and high/low range ratio > 1.04.
        1 (trend): |window return| >= 4, volatility < 35, range ratio > 1.04,
            and the halves do not form a V shape. Only evaluated when the
            half-window reversal pattern is absent.
        0 (range-bound): everything else.

    Args:
        df: Frame with ``close, high, low`` columns.
        lookback: Window length in rows; must be at least 1.

    Returns:
        Integer numpy array of labels, one per row of ``df``.

    Raises:
        ValueError: If ``lookback`` is smaller than 1.
    """
    if lookback < 1:
        raise ValueError("lookback must be at least 1")

    closes = df['close']
    highs = df['high']
    lows = df['low']
    mid = lookback // 2
    annualise = np.sqrt(252) * 100

    labels = []

    for i in range(len(df)):
        if i < lookback:
            labels.append(0)
            continue

        period_close = closes.iloc[i - lookback:i]

        start_price = period_close.iloc[0]
        end_price = period_close.iloc[-1]
        mid_price = period_close.iloc[mid]

        period_return = (end_price / start_price - 1) * 100
        volatility = period_close.pct_change().dropna().std() * annualise
        price_range = highs.iloc[i - lookback:i].max() / lows.iloc[i - lookback:i].min()

        first_half_return = (mid_price / start_price - 1) * 100
        second_half_return = (end_price / mid_price - 1) * 100

        label = 0  # range-bound unless a rule below fires

        reversal_shape = (
            (first_half_return >= 2.5 and second_half_return <= -2.0)
            or (first_half_return <= -2.5 and second_half_return >= 2.0)
        )

        if reversal_shape:
            # A reversal also needs enough overall movement.
            if volatility > 20 and price_range > 1.04:
                label = 2
        elif abs(period_return) >= 4.0 and volatility < 35:
            if price_range > 1.04:
                v_shaped = (
                    abs(first_half_return) > 3
                    and abs(second_half_return) > 2
                    and np.sign(first_half_return) != np.sign(second_half_return)
                )
                if not v_shaped:
                    label = 1

        labels.append(label)

    return np.array(labels, dtype=int)


def train_classifier(features, labels):
    """Fit and evaluate the RandomForest regime classifier.

    The data is split chronologically (first 70% train, last 30% test) and
    cross-validated with expanding time-series folds so that no fold is
    scored by a model trained on later data.

    Args:
        features: Feature DataFrame, one row per day in time order.
        labels: Array of integer regime labels aligned with ``features``.

    Returns:
        Tuple ``(clf, feature_importance)``: the fitted classifier and a
        DataFrame of features sorted by descending importance.

    Raises:
        ImportError: If scikit-learn is not installed.
    """
    if RandomForestClassifier is None:
        raise ImportError(
            "scikit-learn is required for training: pip install scikit-learn"
        )

    print("\n训练分类器...")

    # Keep only rows that have a label.
    labels = np.asarray(labels)
    valid_idx = ~np.isnan(labels.astype(float))
    X = features[valid_idx]
    y = labels[valid_idx]

    # Chronological split - never shuffle a time series.
    split_idx = int(len(X) * 0.7)
    X_train, X_test = X.iloc[:split_idx], X.iloc[split_idx:]
    y_train, y_test = y[:split_idx], y[split_idx:]

    print(f"训练集: {len(X_train)}条")
    print(f"测试集: {len(X_test)}条")

    clf = RandomForestClassifier(
        n_estimators=100,
        max_depth=10,
        min_samples_split=20,
        min_samples_leaf=10,
        random_state=42,
        class_weight='balanced'
    )

    clf.fit(X_train, y_train)

    train_score = clf.score(X_train, y_train)
    test_score = clf.score(X_test, y_test)

    print(f"\n训练准确率: {train_score:.2%}")
    print(f"测试准确率: {test_score:.2%}")

    # The default stratified K-fold would score early folds with models
    # trained on later data and overstate the accuracy.
    cv_scores = cross_val_score(clf, X, y, cv=TimeSeriesSplit(n_splits=5))
    print(f"交叉验证准确率: {cv_scores.mean():.2%} (+/- {cv_scores.std()*2:.2%})")

    y_pred = clf.predict(X_test)
    print("\n分类报告:")
    # Passing `labels` keeps the report valid when a class is absent from
    # the test split; otherwise target_names would not match and raise.
    print(classification_report(
        y_test,
        y_pred,
        labels=[0, 1, 2],
        target_names=STATE_NAMES,
        zero_division=0,
    ))

    feature_importance = pd.DataFrame({
        'feature': X.columns,
        'importance': clf.feature_importances_
    }).sort_values('importance', ascending=False)

    print("\n特征重要性 TOP 10:")
    print(feature_importance.head(10).to_string(index=False))

    return clf, feature_importance


def main():
    """Run the full pipeline: fetch, featurise, label, train, predict, save."""
    print("="*70)
    print("创业板50市场状态分类器 - 真实数据版")
    print("="*70)

    # 1. Market data.
    df = fetch_cyb50_data("2017-01-01", "2025-12-31")
    if df is None:
        return

    # 2. Features.
    print("\n计算技术指标...")
    features = calculate_features(df)
    print(f"特征数量: {features.shape[1]}")

    # 3. Labels.
    print("\n定义市场状态标签...")
    labels = define_market_regime(df, lookback=10)

    unique, counts = np.unique(labels, return_counts=True)
    print("\n标签分布:")
    state_names = STATE_NAMES
    for u, c in zip(unique, counts):
        print(f" {state_names[u]}: {c}天 ({c/len(labels)*100:.1f}%)")

    # 4. Training.
    clf, importance = train_classifier(features, labels)

    # 5. Regime of the most recent row.
    print("\n" + "="*70)
    print("当前市场状态识别")
    print("="*70)

    latest_features = features.iloc[-1:]
    current_pred = int(clf.predict(latest_features)[0])
    # predict_proba columns follow clf.classes_, which omits any class that
    # never occurred in the training split - map by label, not by position.
    proba_by_class = {
        int(cls): float(p)
        for cls, p in zip(clf.classes_, clf.predict_proba(latest_features)[0])
    }

    print(f"\n当前日期: {df.index[-1].date()}")
    print(f"当前价格: {df['close'].iloc[-1]:.2f}")
    print(f"\n预测状态: {state_names[current_pred]}")
    print(f"置信度: {proba_by_class.get(current_pred, 0.0):.2%}")

    print("\n状态概率分布:")
    for i, name in enumerate(state_names):
        p = proba_by_class.get(i, 0.0)
        bar = '█' * int(p * 20)
        print(f" {name}: {p:.2%} {bar}")

    # Persist the model; create the target directory on first run.
    print("\n保存模型...")
    try:
        MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
        with MODEL_PATH.open('wb') as f:
            pickle.dump(clf, f)
    except OSError as e:
        print(f"✗ 模型保存失败: {e}")
    else:
        print("✓ 模型已保存: rf_classifier.pkl")

    print("\n" + "="*70)


if __name__ == "__main__":
    main()