cyb50_market_classifier.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
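ChiNext 50 (CYB50) market-regime classifier - real-data version.
Labels are defined by rules, then learned with a supervised model (Random Forest).
Data sources: baostock (default) or akshare daily history for the ChiNext 50 index (399673), optionally extended with a live quote from akshare or Sina Finance.
Label definitions are based on rules over real price behaviour.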
  8. """
  9. import numpy as np
  10. import pandas as pd
  11. from sklearn.ensemble import RandomForestClassifier
  12. from sklearn.model_selection import train_test_split, cross_val_score
  13. from sklearn.metrics import classification_report, confusion_matrix
  14. import baostock as bs
  15. import requests
  16. from datetime import datetime, timedelta
  17. import warnings
  18. warnings.filterwarnings('ignore')
  19. def fetch_cyb50_data_baostock(start_date="2017-01-01", end_date="2025-12-31"):
  20. """从baostock获取创业板50历史数据"""
  21. print(f"[baostock] 获取创业板50数据 ({start_date} - {end_date})...")
  22. try:
  23. lg = bs.login()
  24. if lg.error_code != '0':
  25. print(f"baostock登录失败: {lg.error_msg}")
  26. return None
  27. rs = bs.query_history_k_data_plus("sz.399673",
  28. "date,open,high,low,close,volume",
  29. start_date=start_date, end_date=end_date,
  30. frequency="d", adjustflag="3")
  31. data_list = []
  32. while (rs.error_code == '0') & rs.next():
  33. row = rs.get_row_data()
  34. if row[0]:
  35. data_list.append({
  36. 'date': row[0],
  37. 'open': float(row[1]) if row[1] else 0,
  38. 'high': float(row[2]) if row[2] else 0,
  39. 'low': float(row[3]) if row[3] else 0,
  40. 'close': float(row[4]) if row[4] else 0,
  41. 'volume': int(float(row[5])) if row[5] else 0
  42. })
  43. bs.logout()
  44. if not data_list:
  45. print("✗ baostock未获取到数据")
  46. return None
  47. df = pd.DataFrame(data_list)
  48. df['date'] = pd.to_datetime(df['date'])
  49. df = df.set_index('date').sort_index()
  50. df['return'] = df['close'].pct_change()
  51. print(f"✓ baostock获取成功: {len(df)}条数据 (至 {df.index[-1].date()})")
  52. return df[['open', 'high', 'low', 'close', 'volume', 'return']]
  53. except Exception as e:
  54. print(f"✗ baostock获取失败: {e}")
  55. return None
  56. def fetch_cyb50_data_akshare(start_date="2024-01-01", end_date=None):
  57. """从akshare获取创业板50数据(支持实时数据)"""
  58. print(f"[akshare] 获取创业板50数据...")
  59. try:
  60. import akshare as ak
  61. # 获取创业板50历史数据
  62. # akshare的index_zh_a_hist接口,symbol="399673"为创业板50
  63. df = ak.index_zh_a_hist(symbol="399673", period="daily",
  64. start_date=start_date.replace("-", ""),
  65. end_date=end_date.replace("-", "") if end_date else None)
  66. if df is None or df.empty:
  67. print("✗ akshare未获取到数据")
  68. return None
  69. # 列名转换
  70. df = df.rename(columns={
  71. '日期': 'date',
  72. '开盘': 'open',
  73. '收盘': 'close',
  74. '最高': 'high',
  75. '最低': 'low',
  76. '成交量': 'volume'
  77. })
  78. df['date'] = pd.to_datetime(df['date'])
  79. df = df.set_index('date').sort_index()
  80. df['return'] = df['close'].pct_change()
  81. print(f"✓ akshare获取成功: {len(df)}条数据 (至 {df.index[-1].date()})")
  82. return df[['open', 'high', 'low', 'close', 'volume', 'return']]
  83. except ImportError:
  84. print("✗ akshare未安装,尝试安装: pip install akshare")
  85. return None
  86. except Exception as e:
  87. print(f"✗ akshare获取失败: {e}")
  88. return None
  89. def fetch_cyb50_realtime_sina():
  90. """从新浪财经获取创业板50实时数据"""
  91. print("[新浪财经] 获取创业板50实时数据...")
  92. try:
  93. # 新浪财经接口: sz399673
  94. url = "https://hq.sinajs.cn/list=sz399673"
  95. headers = {
  96. 'Referer': 'https://finance.sina.com.cn',
  97. 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
  98. }
  99. response = requests.get(url, headers=headers, timeout=10)
  100. response.encoding = 'gb2312'
  101. # 解析返回数据
  102. data_str = response.text
  103. if 'var hq_str_sz399673=' not in data_str:
  104. print("✗ 新浪财经返回格式异常")
  105. return None
  106. # 提取数据部分
  107. data_part = data_str.split('"')[1]
  108. fields = data_part.split(',')
  109. if len(fields) < 33:
  110. print("✗ 新浪财经字段不足")
  111. return None
  112. # 字段说明:
  113. # 0: 指数名称 1: 今日开盘 2: 昨日收盘 3: 当前价格 4: 今日最高 5: 今日最低
  114. # 8: 成交量(手) 30: 日期 31: 时间
  115. realtime_data = {
  116. 'date': fields[30], # YYYY-MM-DD
  117. 'time': fields[31], # HH:MM:SS
  118. 'open': float(fields[1]),
  119. 'high': float(fields[4]),
  120. 'low': float(fields[5]),
  121. 'close': float(fields[3]), # 当前价作为close
  122. 'pre_close': float(fields[2]),
  123. 'volume': int(float(fields[8]))
  124. }
  125. print(f"✓ 新浪财经实时数据: {realtime_data['date']} {realtime_data['time']} 收盘:{realtime_data['close']:.2f}")
  126. return realtime_data
  127. except Exception as e:
  128. print(f"✗ 新浪财经获取失败: {e}")
  129. return None
  130. def fetch_cyb50_realtime_akshare():
  131. """从akshare获取创业板50实时数据"""
  132. print("[akshare实时] 获取创业板50实时行情...")
  133. try:
  134. import akshare as ak
  135. # 获取实时行情
  136. df = ak.index_zh_a_spot_em()
  137. # 筛选创业板50
  138. cyb50_row = df[df['代码'] == '399673']
  139. if cyb50_row.empty:
  140. print("✗ akshare未找到创业板50数据")
  141. return None
  142. row = cyb50_row.iloc[0]
  143. realtime_data = {
  144. 'date': datetime.now().strftime('%Y-%m-%d'),
  145. 'time': row['时间'],
  146. 'open': float(row['开盘']),
  147. 'high': float(row['最高']),
  148. 'low': float(row['最低']),
  149. 'close': float(row['最新价']),
  150. 'pre_close': float(row['昨收']),
  151. 'volume': int(float(row['成交量']))
  152. }
  153. print(f"✓ akshare实时数据: {realtime_data['date']} {realtime_data['time']} 收盘:{realtime_data['close']:.2f}")
  154. return realtime_data
  155. except Exception as e:
  156. print(f"✗ akshare实时获取失败: {e}")
  157. return None
  158. def merge_history_and_realtime(history_df, realtime_data):
  159. """合并历史数据和实时数据"""
  160. if history_df is None or realtime_data is None:
  161. return history_df
  162. realtime_date = pd.to_datetime(realtime_data['date'])
  163. # 检查实时数据日期是否已存在于历史数据中
  164. if realtime_date in history_df.index:
  165. print(f"⚠️ 实时数据日期 {realtime_date.date()} 已存在于历史数据中,跳过合并")
  166. return history_df
  167. # 检查实时数据是否是下一个交易日
  168. last_hist_date = history_df.index[-1]
  169. expected_next_date = last_hist_date + timedelta(days=1)
  170. # 处理周末和节假日
  171. while expected_next_date.weekday() >= 5: # 5=周六, 6=周日
  172. expected_next_date += timedelta(days=1)
  173. if realtime_date != expected_next_date and (realtime_date - last_hist_date).days > 3:
  174. print(f"⚠️ 日期跨度较大: 历史最后日期 {last_hist_date.date()}, 实时日期 {realtime_date.date()}")
  175. print(" 可能是节假日,仍尝试合并")
  176. # 创建实时数据行
  177. new_row = pd.DataFrame({
  178. 'open': [realtime_data['open']],
  179. 'high': [realtime_data['high']],
  180. 'low': [realtime_data['low']],
  181. 'close': [realtime_data['close']],
  182. 'volume': [realtime_data['volume']],
  183. 'return': [realtime_data['close'] / realtime_data['pre_close'] - 1]
  184. }, index=[realtime_date])
  185. # 合并
  186. merged_df = pd.concat([history_df, new_row])
  187. print(f"✓ 数据合并完成: 历史{len(history_df)}条 + 实时1条 = {len(merged_df)}条")
  188. print(f" 最新日期: {merged_df.index[-1].date()} 收盘价: {merged_df['close'].iloc[-1]:.2f}")
  189. return merged_df
  190. def fetch_cyb50_data(start_date="2017-01-01", end_date="2025-12-31",
  191. use_realtime=True, prefer_source='baostock'):
  192. """
  193. 获取创业板50数据,支持多数据源和实时数据合并
  194. 参数:
  195. start_date: 开始日期
  196. end_date: 结束日期
  197. use_realtime: 是否尝试获取实时数据
  198. prefer_source: 优先使用的数据源 ('baostock', 'akshare', 'mixed')
  199. 返回:
  200. DataFrame with columns: [open, high, low, close, volume, return]
  201. """
  202. print("="*60)
  203. print("创业板50数据获取 - 多数据源模式")
  204. print("="*60)
  205. history_df = None
  206. # 1. 获取历史数据 (T-1及之前)
  207. if prefer_source == 'baostock' or prefer_source == 'mixed':
  208. history_df = fetch_cyb50_data_baostock(start_date, end_date)
  209. if history_df is None and prefer_source == 'mixed':
  210. print("尝试备用数据源 akshare...")
  211. history_df = fetch_cyb50_data_akshare(start_date, end_date)
  212. elif prefer_source == 'akshare':
  213. history_df = fetch_cyb50_data_akshare(start_date, end_date)
  214. if history_df is None:
  215. print("✗ 历史数据获取失败")
  216. return None
  217. # 2. 获取实时数据并合并
  218. if use_realtime:
  219. print("\n" + "-"*40)
  220. print("尝试获取今日实时数据...")
  221. print("-"*40)
  222. realtime_data = None
  223. # 尝试akshare实时数据
  224. realtime_data = fetch_cyb50_realtime_akshare()
  225. # 如果失败,尝试新浪财经
  226. if realtime_data is None:
  227. realtime_data = fetch_cyb50_realtime_sina()
  228. # 合并数据
  229. if realtime_data:
  230. history_df = merge_history_and_realtime(history_df, realtime_data)
  231. else:
  232. print("⚠️ 未能获取实时数据,仅使用历史数据")
  233. print("\n" + "="*60)
  234. print(f"最终数据: {len(history_df)}条")
  235. print(f"日期范围: {history_df.index[0].date()} ~ {history_df.index[-1].date()}")
  236. print(f"价格范围: {history_df['close'].min():.2f} ~ {history_df['close'].max():.2f}")
  237. print("="*60)
  238. return history_df
  239. def calculate_features(df):
  240. """计算技术指标特征"""
  241. features = pd.DataFrame(index=df.index)
  242. # 价格特征
  243. features['close'] = df['close']
  244. # 1. 收益率特征
  245. features['ret_1d'] = df['return']
  246. features['ret_5d'] = df['close'].pct_change(5)
  247. features['ret_10d'] = df['close'].pct_change(10)
  248. features['ret_20d'] = df['close'].pct_change(20)
  249. # 2. 波动率特征
  250. features['volatility_5d'] = df['return'].rolling(5).std() * np.sqrt(252)
  251. features['volatility_20d'] = df['return'].rolling(20).std() * np.sqrt(252)
  252. features['volatility_ratio'] = features['volatility_5d'] / (features['volatility_20d'] + 1e-10)
  253. # 3. 动量特征
  254. features['momentum_10d'] = df['close'] / df['close'].shift(10) - 1
  255. features['momentum_20d'] = df['close'] / df['close'].shift(20) - 1
  256. # 4. 均线特征
  257. features['ma5'] = df['close'].rolling(5).mean()
  258. features['ma20'] = df['close'].rolling(20).mean()
  259. features['ma60'] = df['close'].rolling(60).mean()
  260. features['ma5_above_ma20'] = (features['ma5'] > features['ma20']).astype(int)
  261. features['price_above_ma20'] = (df['close'] > features['ma20']).astype(int)
  262. # 5. RSI
  263. delta = df['close'].diff()
  264. gain = (delta.where(delta > 0, 0)).rolling(14).mean()
  265. loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
  266. rs = gain / (loss + 1e-10)
  267. features['rsi_14'] = 100 - (100 / (1 + rs))
  268. # 6. MACD
  269. ema12 = df['close'].ewm(span=12).mean()
  270. ema26 = df['close'].ewm(span=26).mean()
  271. features['macd'] = ema12 - ema26
  272. features['macd_signal'] = features['macd'].ewm(span=9).mean()
  273. features['macd_hist'] = features['macd'] - features['macd_signal']
  274. # 7. 布林带
  275. features['bb_middle'] = df['close'].rolling(20).mean()
  276. bb_std = df['close'].rolling(20).std()
  277. features['bb_upper'] = features['bb_middle'] + 2 * bb_std
  278. features['bb_lower'] = features['bb_middle'] - 2 * bb_std
  279. features['bb_position'] = (df['close'] - features['bb_lower']) / (features['bb_upper'] - features['bb_lower'] + 1e-10)
  280. # 8. ATR (平均真实波幅)
  281. high_low = df['high'] - df['low']
  282. high_close = np.abs(df['high'] - df['close'].shift())
  283. low_close = np.abs(df['low'] - df['close'].shift())
  284. tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
  285. features['atr_14'] = tr.rolling(14).mean()
  286. features['atr_ratio'] = features['atr_14'] / df['close']
  287. # 9. 成交量特征
  288. features['volume_ratio'] = df['volume'] / df['volume'].rolling(20).mean()
  289. # 10. 趋势强度
  290. features['adx'] = calculate_adx(df, 14)
  291. # 填充缺失值
  292. features = features.ffill().fillna(0)
  293. return features
  294. def calculate_adx(df, period=14):
  295. """计算ADX趋势强度指标"""
  296. plus_dm = df['high'].diff()
  297. minus_dm = df['low'].diff().abs()
  298. plus_dm[plus_dm < 0] = 0
  299. minus_dm[minus_dm < 0] = 0
  300. tr = pd.concat([
  301. df['high'] - df['low'],
  302. (df['high'] - df['close'].shift()).abs(),
  303. (df['low'] - df['close'].shift()).abs()
  304. ], axis=1).max(axis=1)
  305. atr = tr.rolling(period).mean()
  306. plus_di = 100 * (plus_dm.rolling(period).mean() / atr)
  307. minus_di = 100 * (minus_dm.rolling(period).mean() / atr)
  308. dx = (abs(plus_di - minus_di) / (plus_di + minus_di + 1e-10)) * 100
  309. adx = dx.rolling(period).mean()
  310. return adx
  311. def define_market_regime(df, lookback=10):
  312. """
  313. 基于规则定义市场状态标签(优化版V2)
  314. 优化目标:
  315. - 使三类分布更均衡(震荡 40-50%,趋势 30-40%,反转 10-20%)
  316. - 测试准确率 > 72%
  317. 规则(按优先级排序):
  318. 1. 反转 (2): 前N/2日收益 >= 2.5% 且后N/2日收益 <= -2%,或相反
  319. 2. 趋势 (1): |N日收益| >= 4%, 波动率 < 35%,且有方向性
  320. 3. 震荡 (0): 其余情况
  321. """
  322. labels = []
  323. for i in range(len(df)):
  324. if i < lookback:
  325. labels.append(0)
  326. continue
  327. period_close = df['close'].iloc[i-lookback:i]
  328. period_high = df['high'].iloc[i-lookback:i]
  329. period_low = df['low'].iloc[i-lookback:i]
  330. start_price = period_close.iloc[0]
  331. end_price = period_close.iloc[-1]
  332. period_return = (end_price / start_price - 1) * 100
  333. daily_returns = period_close.pct_change().dropna()
  334. volatility = daily_returns.std() * np.sqrt(252) * 100
  335. max_price = period_high.max()
  336. min_price = period_low.min()
  337. price_range = max_price / min_price
  338. mid = lookback // 2
  339. first_half_return = (period_close.iloc[mid] / start_price - 1) * 100
  340. second_half_return = (end_price / period_close.iloc[mid] - 1) * 100
  341. label = 0 # 默认震荡
  342. # ========== 反转判断(严格的V型反转)==========
  343. # 需要前后两段都有明显的反向运动
  344. if (first_half_return >= 2.5 and second_half_return <= -2.0) or \
  345. (first_half_return <= -2.5 and second_half_return >= 2.0):
  346. # 反转需要整体有一定的波动
  347. if volatility > 20 and price_range > 1.04:
  348. label = 2
  349. # ========== 趋势判断(需要明显的方向性)==========
  350. elif abs(period_return) >= 4.0 and volatility < 35:
  351. # 趋势期间高低点差距要明显
  352. if price_range > 1.04:
  353. # 排除V型反转(前后反向)
  354. if not (abs(first_half_return) > 3 and abs(second_half_return) > 2 and
  355. np.sign(first_half_return) != np.sign(second_half_return)):
  356. label = 1
  357. # ========== 震荡(默认)==========
  358. else:
  359. label = 0
  360. labels.append(label)
  361. return np.array(labels)
  362. def train_classifier(features, labels):
  363. """训练随机森林分类器"""
  364. print("\n训练分类器...")
  365. # 对齐数据
  366. valid_idx = ~np.isnan(labels)
  367. X = features[valid_idx]
  368. y = labels[valid_idx]
  369. # 分割训练集和测试集(按时间顺序)
  370. split_idx = int(len(X) * 0.7)
  371. X_train, X_test = X.iloc[:split_idx], X.iloc[split_idx:]
  372. y_train, y_test = y[:split_idx], y[split_idx:]
  373. print(f"训练集: {len(X_train)}条")
  374. print(f"测试集: {len(X_test)}条")
  375. # 训练模型
  376. clf = RandomForestClassifier(
  377. n_estimators=100,
  378. max_depth=10,
  379. min_samples_split=20,
  380. min_samples_leaf=10,
  381. random_state=42,
  382. class_weight='balanced'
  383. )
  384. clf.fit(X_train, y_train)
  385. # 评估
  386. train_score = clf.score(X_train, y_train)
  387. test_score = clf.score(X_test, y_test)
  388. print(f"\n训练准确率: {train_score:.2%}")
  389. print(f"测试准确率: {test_score:.2%}")
  390. # 交叉验证
  391. cv_scores = cross_val_score(clf, X, y, cv=5)
  392. print(f"交叉验证准确率: {cv_scores.mean():.2%} (+/- {cv_scores.std()*2:.2%})")
  393. # 详细报告
  394. y_pred = clf.predict(X_test)
  395. print("\n分类报告:")
  396. print(classification_report(y_test, y_pred, target_names=['震荡', '趋势', '反转']))
  397. # 特征重要性
  398. feature_importance = pd.DataFrame({
  399. 'feature': X.columns,
  400. 'importance': clf.feature_importances_
  401. }).sort_values('importance', ascending=False)
  402. print("\n特征重要性 TOP 10:")
  403. print(feature_importance.head(10).to_string(index=False))
  404. return clf, feature_importance
  405. def main():
  406. """主程序"""
  407. print("="*70)
  408. print("创业板50市场状态分类器 - 真实数据版")
  409. print("="*70)
  410. # 1. 获取真实数据
  411. df = fetch_cyb50_data("2017-01-01", "2025-12-31")
  412. if df is None:
  413. return
  414. # 2. 计算特征
  415. print("\n计算技术指标...")
  416. features = calculate_features(df)
  417. print(f"特征数量: {features.shape[1]}")
  418. # 3. 定义标签
  419. print("\n定义市场状态标签...")
  420. labels = define_market_regime(df, lookback=10)
  421. # 统计标签分布
  422. unique, counts = np.unique(labels, return_counts=True)
  423. print("\n标签分布:")
  424. state_names = ['震荡', '趋势', '反转']
  425. for u, c in zip(unique, counts):
  426. print(f" {state_names[u]}: {c}天 ({c/len(labels)*100:.1f}%)")
  427. # 4. 训练分类器
  428. clf, importance = train_classifier(features, labels)
  429. # 5. 当前状态预测
  430. print("\n" + "="*70)
  431. print("当前市场状态识别")
  432. print("="*70)
  433. latest_features = features.iloc[-1:]
  434. current_pred = clf.predict(latest_features)[0]
  435. pred_proba = clf.predict_proba(latest_features)[0]
  436. print(f"\n当前日期: {df.index[-1].date()}")
  437. print(f"当前价格: {df['close'].iloc[-1]:.2f}")
  438. print(f"\n预测状态: {state_names[current_pred]}")
  439. print(f"置信度: {pred_proba[current_pred]:.2%}")
  440. print("\n状态概率分布:")
  441. for i, name in enumerate(state_names):
  442. bar = '█' * int(pred_proba[i] * 20)
  443. print(f" {name}: {pred_proba[i]:.2%} {bar}")
  444. # 保存模型
  445. print("\n保存模型...")
  446. import pickle
  447. with open('/root/.openclaw/workspace/market-regime-identifier/rf_classifier.pkl', 'wb') as f:
  448. pickle.dump(clf, f)
  449. print("✓ 模型已保存: rf_classifier.pkl")
  450. print("\n" + "="*70)
  451. if __name__ == "__main__":
  452. main()