cyb50_30min_classifier_v2.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 创业板50市场状态分类器 - 30分钟级别(优化版)
  5. 专为30分钟交易策略优化,增强交易信号生成
  6. """
  7. import numpy as np
  8. import pandas as pd
  9. from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
  10. from sklearn.metrics import classification_report, confusion_matrix
  11. import warnings
  12. warnings.filterwarnings('ignore')
  13. def load_5min_data(filepath='SZ#399673.txt'):
  14. """加载5分钟数据文件"""
  15. print(f"加载5分钟数据: {filepath}")
  16. df = pd.read_csv(filepath, sep='\t', skiprows=2, encoding='gbk', header=None,
  17. comment='#', on_bad_lines='skip')
  18. df.columns = ['date', 'time', 'open', 'high', 'low', 'close', 'volume', 'amount']
  19. df = df[df['date'].astype(str).str.match(r'\d{4}/\d{2}/\d{2}')].copy()
  20. def format_time(t):
  21. if pd.isna(t):
  22. return '0000'
  23. t = int(t)
  24. return f"{t:04d}"
  25. df['time_str'] = df['time'].apply(format_time)
  26. df['datetime'] = pd.to_datetime(df['date'] + ' ' + df['time_str'],
  27. format='%Y/%m/%d %H%M')
  28. df = df.set_index('datetime').sort_index()
  29. df = df.drop('time_str', axis=1)
  30. for col in ['open', 'high', 'low', 'close', 'volume', 'amount']:
  31. df[col] = pd.to_numeric(df[col], errors='coerce')
  32. print(f"[OK] 加载成功: {len(df)}条5分钟数据")
  33. print(f" 日期范围: {df.index[0]} ~ {df.index[-1]}")
  34. return df
  35. def resample_to_30min(df_5min):
  36. """将5分钟数据聚合成30分钟数据"""
  37. print("\n聚合成30分钟数据...")
  38. df_30min = df_5min.resample('30min').agg({
  39. 'open': 'first',
  40. 'high': 'max',
  41. 'low': 'min',
  42. 'close': 'last',
  43. 'volume': 'sum',
  44. 'amount': 'sum'
  45. }).dropna()
  46. df_30min['return'] = df_30min['close'].pct_change()
  47. # 计算K线实体和影线
  48. df_30min['body'] = df_30min['close'] - df_30min['open']
  49. df_30min['upper_shadow'] = df_30min['high'] - df_30min[['open', 'close']].max(axis=1)
  50. df_30min['lower_shadow'] = df_30min[['open', 'close']].min(axis=1) - df_30min['low']
  51. df_30min['body_pct'] = abs(df_30min['body']) / (df_30min['high'] - df_30min['low'] + 1e-10)
  52. print(f"[OK] 聚合完成: {len(df_30min)}条30分钟数据")
  53. return df_30min
  54. def calculate_features_30min_v2(df):
  55. """优化版30分钟特征计算 - 更适合交易决策"""
  56. features = pd.DataFrame(index=df.index)
  57. features['close'] = df['close']
  58. # ========== 1. 收益率特征 ==========
  59. features['ret_1'] = df['return'] # 当前周期
  60. features['ret_2'] = df['close'].pct_change(2) # 1小时
  61. features['ret_4'] = df['close'].pct_change(4) # 2小时
  62. features['ret_8'] = df['close'].pct_change(8) # 半日
  63. features['ret_16'] = df['close'].pct_change(16) # 1日
  64. # 累计收益率
  65. features['cum_ret_4h'] = (df['close'] / df['close'].shift(8) - 1) # 4小时累计
  66. features['cum_ret_1d'] = (df['close'] / df['close'].shift(16) - 1) # 1日累计
  67. # ========== 2. 波动率特征 ==========
  68. features['volatility_4'] = df['return'].rolling(4).std() * np.sqrt(48)
  69. features['volatility_8'] = df['return'].rolling(8).std() * np.sqrt(48)
  70. features['volatility_16'] = df['return'].rolling(16).std() * np.sqrt(48)
  71. features['vol_ratio'] = features['volatility_4'] / (features['volatility_16'] + 1e-10)
  72. # 波动率变化
  73. features['vol_change'] = features['volatility_8'].diff()
  74. # ========== 3. 趋势特征 ==========
  75. # 多周期均线
  76. features['ma4'] = df['close'].rolling(4).mean() # 2小时
  77. features['ma8'] = df['close'].rolling(8).mean() # 半日
  78. features['ma16'] = df['close'].rolling(16).mean() # 1日
  79. features['ma48'] = df['close'].rolling(48).mean() # 3日
  80. # 均线关系
  81. features['ma4_above_ma8'] = (features['ma4'] > features['ma8']).astype(int)
  82. features['ma8_above_ma16'] = (features['ma8'] > features['ma16']).astype(int)
  83. features['ma_slope_4'] = features['ma4'].diff(4) / features['ma4'].shift(4) * 100
  84. # 价格与均线偏离
  85. features['dist_to_ma4'] = (df['close'] - features['ma4']) / features['ma4'] * 100
  86. features['dist_to_ma16'] = (df['close'] - features['ma16']) / features['ma16'] * 100
  87. # ========== 4. 动量指标 ==========
  88. # RSI
  89. delta = df['close'].diff()
  90. gain = (delta.where(delta > 0, 0)).rolling(14).mean()
  91. loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
  92. rs = gain / (loss + 1e-10)
  93. features['rsi_14'] = 100 - (100 / (1 + rs))
  94. features['rsi_change'] = features['rsi_14'].diff(2)
  95. # RSI状态
  96. features['rsi_overbought'] = (features['rsi_14'] > 70).astype(int)
  97. features['rsi_oversold'] = (features['rsi_14'] < 30).astype(int)
  98. features['rsi_neutral'] = ((features['rsi_14'] >= 40) & (features['rsi_14'] <= 60)).astype(int)
  99. # ========== 5. MACD ==========
  100. ema12 = df['close'].ewm(span=12).mean()
  101. ema26 = df['close'].ewm(span=26).mean()
  102. features['macd'] = ema12 - ema26
  103. features['macd_signal'] = features['macd'].ewm(span=9).mean()
  104. features['macd_hist'] = features['macd'] - features['macd_signal']
  105. features['macd_cross'] = ((features['macd'] > features['macd_signal']) &
  106. (features['macd'].shift(1) <= features['macd_signal'].shift(1))).astype(int)
  107. # ========== 6. 布林带 ==========
  108. features['bb_middle'] = df['close'].rolling(20).mean()
  109. bb_std = df['close'].rolling(20).std()
  110. features['bb_upper'] = features['bb_middle'] + 2 * bb_std
  111. features['bb_lower'] = features['bb_middle'] - 2 * bb_std
  112. features['bb_width'] = (features['bb_upper'] - features['bb_lower']) / features['bb_middle'] * 100
  113. features['bb_position'] = (df['close'] - features['bb_lower']) / (features['bb_upper'] - features['bb_lower'] + 1e-10)
  114. # 是否触及上下轨
  115. features['touch_upper'] = (df['close'] >= features['bb_upper'] * 0.995).astype(int)
  116. features['touch_lower'] = (df['close'] <= features['bb_lower'] * 1.005).astype(int)
  117. # ========== 7. K线形态特征 ==========
  118. features['body_pct'] = df['body_pct']
  119. features['upper_shadow_ratio'] = df['upper_shadow'] / (df['high'] - df['low'] + 1e-10)
  120. features['lower_shadow_ratio'] = df['lower_shadow'] / (df['high'] - df['low'] + 1e-10)
  121. # 锤子/吊颈线识别
  122. features['hammer'] = ((features['lower_shadow_ratio'] > 0.6) &
  123. (features['body_pct'] < 0.3)).astype(int)
  124. features['hanging_man'] = ((features['upper_shadow_ratio'] > 0.6) &
  125. (features['body_pct'] < 0.3)).astype(int)
  126. # ========== 8. 成交量特征 ==========
  127. features['volume_ratio'] = df['volume'] / df['volume'].rolling(16).mean()
  128. features['volume_spike'] = (features['volume_ratio'] > 2).astype(int)
  129. features['volume_trend'] = df['volume'].rolling(8).apply(lambda x: np.polyfit(range(len(x)), x, 1)[0] if len(x) == 8 else 0)
  130. # 量价关系
  131. features['vol_price_corr'] = df['volume'].rolling(8).corr(df['close'])
  132. # ========== 9. ATR与波动 ==========
  133. high_low = df['high'] - df['low']
  134. high_close = abs(df['high'] - df['close'].shift())
  135. low_close = abs(df['low'] - df['close'].shift())
  136. tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
  137. features['atr_14'] = tr.rolling(14).mean()
  138. features['atr_ratio'] = features['atr_14'] / df['close'] * 100
  139. # ========== 10. 时间特征 ==========
  140. features['hour'] = df.index.hour
  141. features['minute'] = df.index.minute
  142. features['is_open'] = ((df.index.hour == 9) & (df.index.minute == 30)).astype(int)
  143. features['is_morning'] = ((df.index.hour >= 9) & (df.index.hour < 11)).astype(int)
  144. features['is_afternoon'] = ((df.index.hour >= 13) & (df.index.hour < 15)).astype(int)
  145. features['is_close'] = ((df.index.hour == 15) & (df.index.minute == 0)).astype(int)
  146. # ========== 11. 价格行为 ==========
  147. # 连续涨跌
  148. features['consecutive_up'] = (df['return'] > 0).astype(int).groupby(
  149. (df['return'] <= 0).astype(int).cumsum()).cumsum()
  150. features['consecutive_down'] = (df['return'] < 0).astype(int).groupby(
  151. (df['return'] >= 0).astype(int).cumsum()).cumsum()
  152. # 加速度
  153. features['price_accel'] = df['close'].diff().diff()
  154. features['return_accel'] = df['return'].diff()
  155. # ========== 12. 支撑阻力 ==========
  156. # 近期高低点
  157. features['near_high_8'] = (df['close'] >= df['high'].rolling(8).max() * 0.995).astype(int)
  158. features['near_low_8'] = (df['close'] <= df['low'].rolling(8).min() * 1.005).astype(int)
  159. # 填充缺失值
  160. features = features.ffill().fillna(0)
  161. return features
  162. def define_market_regime_30min_v2(df, features, lookback=8):
  163. """
  164. 优化版30分钟市场状态标签定义
  165. 状态定义:
  166. 0 = 震荡 - 适合观望或区间交易
  167. 1 = 趋势 - 适合趋势跟随
  168. 2 = 反转 - 适合反向交易或减仓
  169. """
  170. labels = []
  171. n = len(df)
  172. for i in range(n):
  173. if i < lookback:
  174. labels.append(0)
  175. continue
  176. # 获取回看窗口数据
  177. window_close = df['close'].iloc[i-lookback:i]
  178. window_high = df['high'].iloc[i-lookback:i]
  179. window_low = df['low'].iloc[i-lookback:i]
  180. window_rsi = features['rsi_14'].iloc[i-lookback:i]
  181. window_vol = features['volatility_4'].iloc[i-lookback:i]
  182. start_price = window_close.iloc[0]
  183. end_price = window_close.iloc[-1]
  184. period_return = (end_price / start_price - 1) * 100
  185. # 波动率
  186. volatility = window_vol.mean()
  187. # 价格区间
  188. max_price = window_high.max()
  189. min_price = window_low.min()
  190. price_range = (max_price - min_price) / start_price * 100
  191. # RSI特征
  192. rsi_start = window_rsi.iloc[0]
  193. rsi_end = window_rsi.iloc[-1]
  194. rsi_change = rsi_end - rsi_start
  195. rsi_max = window_rsi.max()
  196. rsi_min = window_rsi.min()
  197. # 判断逻辑
  198. label = 0 # 默认震荡
  199. # ===== 反转信号判断 =====
  200. reversal_signals = 0
  201. # RSI极值反转
  202. if (rsi_start > 70 and rsi_change < -15) or (rsi_start < 30 and rsi_change > 15):
  203. reversal_signals += 2
  204. elif (rsi_max > 75 or rsi_min < 25):
  205. reversal_signals += 1
  206. # 价格触及极端后回落
  207. if price_range > 4 and abs(period_return) < 1:
  208. reversal_signals += 1
  209. # RSI背离
  210. if period_return > 2 and rsi_change < -5:
  211. reversal_signals += 2
  212. elif period_return < -2 and rsi_change > 5:
  213. reversal_signals += 2
  214. # 布林带触及
  215. bb_pos = features['bb_position'].iloc[i]
  216. if (bb_pos > 0.95 and period_return < 0) or (bb_pos < 0.05 and period_return > 0):
  217. reversal_signals += 1
  218. if reversal_signals >= 3:
  219. label = 2 # 反转
  220. # ===== 趋势信号判断 =====
  221. elif label == 0: # 不是反转才判断趋势
  222. trend_signals = 0
  223. # 明显的价格方向
  224. if abs(period_return) >= 2.5:
  225. trend_signals += 2
  226. elif abs(period_return) >= 1.5:
  227. trend_signals += 1
  228. # 低波动率(趋势市场通常波动率适中)
  229. if 10 < volatility < 30:
  230. trend_signals += 1
  231. # RSI趋势一致
  232. if period_return > 0 and rsi_change > 5:
  233. trend_signals += 1
  234. elif period_return < 0 and rsi_change < -5:
  235. trend_signals += 1
  236. # 均线排列
  237. if features['ma4_above_ma8'].iloc[i] == 1 and period_return > 0:
  238. trend_signals += 1
  239. elif features['ma4_above_ma8'].iloc[i] == 0 and period_return < 0:
  240. trend_signals += 1
  241. # MACD支持
  242. if features['macd_hist'].iloc[i] * period_return > 0:
  243. trend_signals += 1
  244. if trend_signals >= 4:
  245. label = 1 # 趋势
  246. labels.append(label)
  247. return np.array(labels)
  248. def backtest_strategy(df_result, initial_capital=1000000):
  249. """
  250. 基于30分钟状态识别进行策略回测
  251. 策略规则:
  252. - 震荡:观望(不持仓)
  253. - 趋势:跟随趋势(买入或做空)
  254. - 反转:反向交易或减仓
  255. """
  256. print("\n" + "="*70)
  257. print("30分钟状态策略回测")
  258. print("="*70)
  259. # 计算收益率
  260. df_result = df_result.copy()
  261. df_result['ret'] = df_result['close'].pct_change()
  262. capital = initial_capital
  263. position = 0 # 0=空仓, 1=做多, -1=做空
  264. entry_price = 0
  265. trades = []
  266. for i in range(1, len(df_result)):
  267. current = df_result.iloc[i]
  268. prev = df_result.iloc[i-1]
  269. state = int(current['state'])
  270. price = current['close']
  271. ret = current['ret']
  272. # 状态转换信号
  273. if position == 0: # 空仓
  274. if state == 1: # 趋势 -> 开仓
  275. # 根据当前趋势方向决定多空
  276. position = 1 if ret > 0 else -1
  277. entry_price = price
  278. trades.append({
  279. 'time': df_result.index[i],
  280. 'action': 'OPEN',
  281. 'position': 'LONG' if position == 1 else 'SHORT',
  282. 'price': price
  283. })
  284. else: # 有持仓
  285. # 检查出场条件
  286. exit_signal = False
  287. if state == 2: # 反转信号 -> 出场
  288. exit_signal = True
  289. elif state == 0: # 震荡 -> 出场观望
  290. exit_signal = True
  291. elif position == 1 and ret < -0.008: # 做多止损0.8%
  292. exit_signal = True
  293. elif position == -1 and ret > 0.008: # 做空止损0.8%
  294. exit_signal = True
  295. if exit_signal:
  296. pnl = (price / entry_price - 1) * position
  297. capital *= (1 + pnl)
  298. trades.append({
  299. 'time': df_result.index[i],
  300. 'action': 'CLOSE',
  301. 'position': 'LONG' if position == 1 else 'SHORT',
  302. 'price': price,
  303. 'pnl': pnl
  304. })
  305. position = 0
  306. # 计算回测结果
  307. total_return = (capital / initial_capital - 1) * 100
  308. print(f"\n初始资金: {initial_capital:,.0f}")
  309. print(f"最终资金: {capital:,.0f}")
  310. print(f"总收益率: {total_return:+.2f}%")
  311. print(f"交易次数: {len([t for t in trades if t['action'] == 'CLOSE'])}")
  312. if len(trades) > 0:
  313. closes = [t for t in trades if t['action'] == 'CLOSE']
  314. wins = len([t for t in closes if t.get('pnl', 0) > 0])
  315. win_rate = wins / len(closes) * 100 if closes else 0
  316. print(f"胜率: {win_rate:.1f}%")
  317. return trades, capital
  318. def train_and_evaluate(df_30min, features, labels):
  319. """训练和评估模型"""
  320. print("\n训练30分钟分类器...")
  321. valid_idx = ~np.isnan(labels)
  322. X = features[valid_idx]
  323. y = labels[valid_idx]
  324. df_aligned = df_30min.iloc[valid_idx].copy()
  325. # 时间序列分割
  326. split_idx = int(len(X) * 0.8)
  327. X_train, X_test = X.iloc[:split_idx], X.iloc[split_idx:]
  328. y_train, y_test = y[:split_idx], y[split_idx:]
  329. print(f"训练集: {len(X_train)}条")
  330. print(f"测试集: {len(X_test)}条")
  331. # 使用GBDT(通常比RF更适合时序)
  332. clf = GradientBoostingClassifier(
  333. n_estimators=150,
  334. max_depth=5,
  335. learning_rate=0.1,
  336. random_state=42
  337. )
  338. clf.fit(X_train, y_train)
  339. train_score = clf.score(X_train, y_train)
  340. test_score = clf.score(X_test, y_test)
  341. print(f"\n训练准确率: {train_score:.2%}")
  342. print(f"测试准确率: {test_score:.2%}")
  343. y_pred = clf.predict(X_test)
  344. print("\n分类报告:")
  345. print(classification_report(y_test, y_pred, target_names=['震荡', '趋势', '反转']))
  346. # 预测所有数据
  347. all_pred = clf.predict(X)
  348. all_proba = clf.predict_proba(X)
  349. df_aligned['state'] = all_pred
  350. df_aligned['prob_ranging'] = all_proba[:, 0]
  351. df_aligned['prob_trend'] = all_proba[:, 1]
  352. df_aligned['prob_reversal'] = all_proba[:, 2]
  353. # 特征重要性
  354. importance = pd.DataFrame({
  355. 'feature': X.columns,
  356. 'importance': clf.feature_importances_
  357. }).sort_values('importance', ascending=False)
  358. print("\n特征重要性 TOP 15:")
  359. print(importance.head(15).to_string(index=False))
  360. return clf, df_aligned, importance
  361. def main():
  362. """主程序"""
  363. print("="*70)
  364. print("创业板50市场状态分类器 - 30分钟级别(优化版)")
  365. print("="*70)
  366. # 1. 加载数据
  367. df_5min = load_5min_data('SZ#399673.txt')
  368. # 2. 聚合30分钟
  369. df_30min = resample_to_30min(df_5min)
  370. # 3. 计算优化特征
  371. print("\n计算30分钟特征(优化版)...")
  372. features = calculate_features_30min_v2(df_30min)
  373. print(f"特征数量: {features.shape[1]}")
  374. # 4. 定义标签
  375. print("\n定义市场状态标签(优化版)...")
  376. labels = define_market_regime_30min_v2(df_30min, features, lookback=8)
  377. # 统计
  378. unique, counts = np.unique(labels, return_counts=True)
  379. print("\n标签分布:")
  380. state_names = ['震荡', '趋势', '反转']
  381. for u, c in zip(unique, counts):
  382. print(f" {state_names[u]}: {c}个周期 ({c/len(labels)*100:.1f}%)")
  383. # 5. 训练模型
  384. clf, df_result, importance = train_and_evaluate(df_30min, features, labels)
  385. # 6. 策略回测
  386. trades, final_capital = backtest_strategy(df_result)
  387. # 7. 保存结果
  388. print("\n保存结果...")
  389. df_result.to_csv('cyb50_30min_regime_v2.csv')
  390. print("[OK] 结果已保存: cyb50_30min_regime_v2.csv")
  391. import pickle
  392. with open('rf_classifier_30min_v2.pkl', 'wb') as f:
  393. pickle.dump(clf, f)
  394. print("[OK] 模型已保存: rf_classifier_30min_v2.pkl")
  395. print("\n" + "="*70)
  396. if __name__ == "__main__":
  397. main()