cyb50_30min_classifier.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 创业板50市场状态分类器 - 30分钟级别
  5. 基于本地5分钟数据文件,聚合成30分钟K线
  6. """
  7. import numpy as np
  8. import pandas as pd
  9. from sklearn.ensemble import RandomForestClassifier
  10. from sklearn.metrics import classification_report, confusion_matrix
  11. import warnings
# Globally silence every warning category for the whole script run.
# NOTE(review): this may also hide useful pandas/scikit-learn deprecation or
# data-quality warnings; consider narrowing it to specific categories.
warnings.filterwarnings('ignore')
  13. def load_5min_data(filepath='SZ#399673.txt'):
  14. """加载5分钟数据文件"""
  15. print(f"加载5分钟数据: {filepath}")
  16. # 读取数据,跳过前两行(标题行),过滤注释行
  17. df = pd.read_csv(filepath, sep='\t', skiprows=2, encoding='gbk', header=None,
  18. comment='#', on_bad_lines='skip')
  19. # 指定列名
  20. df.columns = ['date', 'time', 'open', 'high', 'low', 'close', 'volume', 'amount']
  21. # 过滤掉包含非日期数据的行
  22. df = df[df['date'].astype(str).str.match(r'\d{4}/\d{2}/\d{2}')].copy()
  23. # 创建datetime索引
  24. # 处理time列: 如果是数字,格式化为4位时间字符串
  25. def format_time(t):
  26. if pd.isna(t):
  27. return '0000'
  28. t = int(t)
  29. return f"{t:04d}"
  30. df['time_str'] = df['time'].apply(format_time)
  31. df['datetime'] = pd.to_datetime(df['date'] + ' ' + df['time_str'],
  32. format='%Y/%m/%d %H%M')
  33. df = df.set_index('datetime').sort_index()
  34. df = df.drop('time_str', axis=1)
  35. # 转换为数值类型
  36. for col in ['open', 'high', 'low', 'close', 'volume', 'amount']:
  37. df[col] = pd.to_numeric(df[col], errors='coerce')
  38. print(f"[OK] 加载成功: {len(df)}条5分钟数据")
  39. print(f" 日期范围: {df.index[0]} ~ {df.index[-1]}")
  40. print(f" 价格范围: {df['close'].min():.2f} ~ {df['close'].max():.2f}")
  41. return df
  42. def resample_to_30min(df_5min):
  43. """将5分钟数据聚合成30分钟数据"""
  44. print("\n聚合成30分钟数据...")
  45. # 30分钟重采样规则
  46. df_30min = df_5min.resample('30min').agg({
  47. 'open': 'first',
  48. 'high': 'max',
  49. 'low': 'min',
  50. 'close': 'last',
  51. 'volume': 'sum',
  52. 'amount': 'sum'
  53. }).dropna()
  54. # 计算收益率
  55. df_30min['return'] = df_30min['close'].pct_change()
  56. print(f"[OK] 聚合完成: {len(df_30min)}条30分钟数据")
  57. return df_30min
  58. def calculate_features_30min(df):
  59. """计算30分钟级别的技术指标特征"""
  60. features = pd.DataFrame(index=df.index)
  61. # 价格特征
  62. features['close'] = df['close']
  63. # 1. 收益率特征(30分钟周期)
  64. features['ret_1'] = df['return'] # 1个30分钟周期
  65. features['ret_4'] = df['close'].pct_change(4) # 2小时
  66. features['ret_8'] = df['close'].pct_change(8) # 4小时(半日)
  67. features['ret_16'] = df['close'].pct_change(16) # 8小时(1个交易日)
  68. # 2. 波动率特征(30分钟周期)
  69. features['volatility_4'] = df['return'].rolling(4).std() * np.sqrt(48) # 2小时波动率年化
  70. features['volatility_16'] = df['return'].rolling(16).std() * np.sqrt(48) # 日波动率年化
  71. features['volatility_ratio'] = features['volatility_4'] / (features['volatility_16'] + 1e-10)
  72. # 3. 动量特征
  73. features['momentum_8'] = df['close'] / df['close'].shift(8) - 1 # 4小时动量
  74. features['momentum_16'] = df['close'] / df['close'].shift(16) - 1 # 日动量
  75. # 4. 均线特征(30分钟周期)
  76. features['ma4'] = df['close'].rolling(4).mean() # 2小时均线
  77. features['ma16'] = df['close'].rolling(16).mean() # 日均线
  78. features['ma48'] = df['close'].rolling(48).mean() # 3日均线
  79. features['ma4_above_ma16'] = (features['ma4'] > features['ma16']).astype(int)
  80. # 5. RSI(14个30分钟周期 = 7小时)
  81. delta = df['close'].diff()
  82. gain = (delta.where(delta > 0, 0)).rolling(14).mean()
  83. loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
  84. rs = gain / (loss + 1e-10)
  85. features['rsi_14'] = 100 - (100 / (1 + rs))
  86. features['rsi_overbought'] = (features['rsi_14'] > 70).astype(int)
  87. features['rsi_oversold'] = (features['rsi_14'] < 30).astype(int)
  88. # 6. MACD
  89. ema12 = df['close'].ewm(span=12).mean()
  90. ema26 = df['close'].ewm(span=26).mean()
  91. features['macd'] = ema12 - ema26
  92. features['macd_signal'] = features['macd'].ewm(span=9).mean()
  93. features['macd_hist'] = features['macd'] - features['macd_signal']
  94. # 7. 布林带
  95. features['bb_middle'] = df['close'].rolling(20).mean()
  96. bb_std = df['close'].rolling(20).std()
  97. features['bb_upper'] = features['bb_middle'] + 2 * bb_std
  98. features['bb_lower'] = features['bb_middle'] - 2 * bb_std
  99. features['bb_position'] = (df['close'] - features['bb_lower']) / (features['bb_upper'] - features['bb_lower'] + 1e-10)
  100. # 8. 成交量特征
  101. features['volume_ratio'] = df['volume'] / df['volume'].rolling(16).mean()
  102. features['volume_spike'] = (features['volume_ratio'] > 2).astype(int)
  103. # 9. 趋势强度(ADX近似)
  104. high_low = df['high'] - df['low']
  105. features['atr_14'] = high_low.rolling(14).mean()
  106. features['atr_ratio'] = features['atr_14'] / df['close']
  107. # 10. 日内时间特征
  108. features['hour'] = df.index.hour
  109. features['is_morning'] = ((features['hour'] >= 9) & (features['hour'] < 11)).astype(int)
  110. features['is_afternoon'] = ((features['hour'] >= 13) & (features['hour'] < 15)).astype(int)
  111. # 11. 价格变化加速度
  112. features['price_accel'] = df['close'].diff().diff()
  113. features['price_accel_normalized'] = features['price_accel'] / (df['close'] * 0.01)
  114. # 12. 连续涨跌周期数
  115. features['consecutive_up'] = (df['return'] > 0).astype(int).groupby((df['return'] <= 0).astype(int).cumsum()).cumsum()
  116. features['consecutive_down'] = (df['return'] < 0).astype(int).groupby((df['return'] >= 0).astype(int).cumsum()).cumsum()
  117. # 填充缺失值
  118. features = features.ffill().fillna(0)
  119. return features
  120. def define_market_regime_30min(df, lookback=8):
  121. """
  122. 基于规则定义30分钟市场状态标签
  123. 参数:
  124. lookback: 回看周期数(默认8 = 4小时)
  125. """
  126. labels = []
  127. # 预计算RSI
  128. delta = df['close'].diff()
  129. gain = (delta.where(delta > 0, 0)).rolling(14).mean()
  130. loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
  131. rs = gain / (loss + 1e-10)
  132. rsi = 100 - (100 / (1 + rs))
  133. for i in range(len(df)):
  134. if i < lookback:
  135. labels.append(0)
  136. continue
  137. # 获取回看期间数据
  138. period_close = df['close'].iloc[i-lookback:i]
  139. period_high = df['high'].iloc[i-lookback:i]
  140. period_low = df['low'].iloc[i-lookback:i]
  141. period_rsi = rsi.iloc[i-lookback:i]
  142. start_price = period_close.iloc[0]
  143. end_price = period_close.iloc[-1]
  144. period_return = (end_price / start_price - 1) * 100
  145. daily_returns = period_close.pct_change().dropna()
  146. volatility = daily_returns.std() * np.sqrt(48) * 100
  147. max_price = period_high.max()
  148. min_price = period_low.min()
  149. mid = lookback // 2
  150. first_half_return = (period_close.iloc[mid] / start_price - 1) * 100
  151. second_half_return = (end_price / period_close.iloc[mid] - 1) * 100
  152. # RSI特征
  153. rsi_start = period_rsi.iloc[0]
  154. rsi_end = period_rsi.iloc[-1]
  155. rsi_change = rsi_end - rsi_start
  156. # 定义标签
  157. label = 0 # 默认震荡
  158. # ========== 反转判断 ==========
  159. condition_1 = (rsi_start > 68 and rsi_change < -15) or (rsi_start < 32 and rsi_change > 15)
  160. condition_2 = (first_half_return * second_half_return < 0 and
  161. abs(first_half_return) > 1.5 and abs(second_half_return) > 1.0)
  162. condition_3 = (period_rsi.max() > 72 or period_rsi.min() < 28)
  163. condition_4 = 12 < volatility < 40
  164. reversal_score = sum([condition_1, condition_2, condition_3, condition_4])
  165. if reversal_score >= 2:
  166. label = 2
  167. # ========== 趋势判断 ==========
  168. elif abs(period_return) >= 2.5 and volatility < 35:
  169. if reversal_score < 2:
  170. label = 1
  171. labels.append(label)
  172. return np.array(labels)
  173. def train_and_predict(df_30min, features, labels):
  174. """训练模型并预测"""
  175. print("\n训练30分钟级别分类器...")
  176. # 对齐数据
  177. valid_idx = ~np.isnan(labels)
  178. X = features[valid_idx]
  179. y = labels[valid_idx]
  180. df_aligned = df_30min.iloc[valid_idx].copy()
  181. # 分割训练集和测试集(按时间顺序,80%训练,20%测试)
  182. split_idx = int(len(X) * 0.8)
  183. X_train, X_test = X.iloc[:split_idx], X.iloc[split_idx:]
  184. y_train, y_test = y[:split_idx], y[split_idx:]
  185. print(f"训练集: {len(X_train)}条")
  186. print(f"测试集: {len(X_test)}条")
  187. # 训练模型
  188. clf = RandomForestClassifier(
  189. n_estimators=200,
  190. max_depth=15,
  191. min_samples_split=10,
  192. min_samples_leaf=5,
  193. random_state=42,
  194. class_weight={0: 1.0, 1: 1.2, 2: 2.0}
  195. )
  196. clf.fit(X_train, y_train)
  197. # 评估
  198. train_score = clf.score(X_train, y_train)
  199. test_score = clf.score(X_test, y_test)
  200. print(f"\n训练准确率: {train_score:.2%}")
  201. print(f"测试准确率: {test_score:.2%}")
  202. # 详细报告
  203. y_pred = clf.predict(X_test)
  204. print("\n分类报告:")
  205. print(classification_report(y_test, y_pred, target_names=['震荡', '趋势', '反转']))
  206. # 预测所有数据
  207. all_pred = clf.predict(X)
  208. all_proba = clf.predict_proba(X)
  209. # 添加预测结果到DataFrame
  210. df_aligned['state'] = all_pred
  211. df_aligned['prob_ranging'] = all_proba[:, 0]
  212. df_aligned['prob_trend'] = all_proba[:, 1]
  213. df_aligned['prob_reversal'] = all_proba[:, 2]
  214. # 特征重要性
  215. feature_importance = pd.DataFrame({
  216. 'feature': X.columns,
  217. 'importance': clf.feature_importances_
  218. }).sort_values('importance', ascending=False)
  219. print("\n特征重要性 TOP 10:")
  220. print(feature_importance.head(10).to_string(index=False))
  221. return clf, df_aligned, feature_importance
  222. def analyze_regime_distribution(df_result):
  223. """分析市场状态分布"""
  224. print("\n" + "="*70)
  225. print("30分钟市场状态分析")
  226. print("="*70)
  227. state_names = ['震荡', '趋势', '反转']
  228. # 整体分布
  229. print("\n【整体分布】")
  230. for i, name in enumerate(state_names):
  231. count = (df_result['state'] == i).sum()
  232. pct = count / len(df_result) * 100
  233. print(f" {name}: {count}个周期 ({pct:.1f}%)")
  234. # 按日期统计
  235. print("\n【最近5个交易日状态分布】")
  236. df_result['date'] = df_result.index.date
  237. recent_dates = df_result['date'].unique()[-5:]
  238. for date in recent_dates:
  239. day_data = df_result[df_result['date'] == date]
  240. print(f"\n {date}:")
  241. for i, name in enumerate(state_names):
  242. count = (day_data['state'] == i).sum()
  243. print(f" {name}: {count}个30分钟周期")
  244. # 当前状态
  245. latest = df_result.iloc[-1]
  246. current_state = state_names[int(latest['state'])]
  247. print("\n【当前状态】")
  248. print(f" 时间: {df_result.index[-1]}")
  249. print(f" 收盘价: {latest['close']:.2f}")
  250. print(f" 市场状态: {current_state}")
  251. print(f" 置信度: {latest[['prob_ranging', 'prob_trend', 'prob_reversal']].max():.2%}")
  252. print(f" 概率分布: 震荡{latest['prob_ranging']:.1%} / 趋势{latest['prob_trend']:.1%} / 反转{latest['prob_reversal']:.1%}")
  253. def main():
  254. """主程序"""
  255. print("="*70)
  256. print("创业板50市场状态分类器 - 30分钟级别")
  257. print("="*70)
  258. # 1. 加载5分钟数据
  259. df_5min = load_5min_data('SZ#399673.txt')
  260. # 2. 聚合成30分钟数据
  261. df_30min = resample_to_30min(df_5min)
  262. # 3. 计算特征
  263. print("\n计算30分钟技术指标...")
  264. features = calculate_features_30min(df_30min)
  265. print(f"特征数量: {features.shape[1]}")
  266. # 4. 定义标签
  267. print("\n定义市场状态标签...")
  268. labels = define_market_regime_30min(df_30min, lookback=8)
  269. # 统计标签分布
  270. unique, counts = np.unique(labels, return_counts=True)
  271. print("\n标签分布:")
  272. state_names = ['震荡', '趋势', '反转']
  273. for u, c in zip(unique, counts):
  274. print(f" {state_names[u]}: {c}个周期 ({c/len(labels)*100:.1f}%)")
  275. # 5. 训练并预测
  276. clf, df_result, importance = train_and_predict(df_30min, features, labels)
  277. # 6. 分析结果
  278. analyze_regime_distribution(df_result)
  279. # 7. 保存结果
  280. print("\n保存结果...")
  281. df_result.to_csv('cyb50_30min_regime_result.csv')
  282. print("[OK] 结果已保存: cyb50_30min_regime_result.csv")
  283. # 保存模型
  284. import pickle
  285. with open('rf_classifier_30min.pkl', 'wb') as f:
  286. pickle.dump(clf, f)
  287. print("[OK] 模型已保存: rf_classifier_30min.pkl")
  288. print("\n" + "="*70)
  289. if __name__ == "__main__":
  290. main()