trend_mix_strategy.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Trend-Mix: 6种客观市场状态识别方法综合策略
  5. 针对创业板50指数 (399673) 的完整实现
  6. 方法:
  7. 1. 波动率分位数值法 (Volatility Percentile)
  8. 2. 方差比检验 (Variance Ratio Test)
  9. 3. Hurst指数 (R/S分析)
  10. 4. ADX+价格动量组合
  11. 5. 布林带宽度+波动率收缩 (Bollinger Bands Squeeze)
  12. 6. 马尔可夫区制转换模型 (MS-AR)
  13. 7. 综合状态机 (硬编码决策树)
  14. """
  15. import numpy as np
  16. import pandas as pd
  17. import baostock as bs
  18. from scipy import stats
  19. from sklearn.mixture import GaussianMixture
  20. import warnings
  21. warnings.filterwarnings('ignore')
  22. class TrendMixStrategy:
  23. """6种方法综合策略"""
  24. def __init__(self):
  25. self.data = None
  26. def fetch_data(self, symbol="399673", start_date="2017-01-01", end_date="2026-03-06"):
  27. """获取数据"""
  28. print(f"获取 {symbol} 数据...")
  29. bs.login()
  30. if symbol.startswith('3'):
  31. code = f"sz.{symbol}"
  32. elif symbol.startswith('6'):
  33. code = f"sh.{symbol}"
  34. else:
  35. code = symbol
  36. rs = bs.query_history_k_data_plus(
  37. code, "date,open,high,low,close,volume",
  38. start_date=start_date, end_date=end_date,
  39. frequency="d", adjustflag="3"
  40. )
  41. data = []
  42. while rs.error_code == '0' and rs.next():
  43. row = rs.get_row_data()
  44. data.append({
  45. 'date': row[0],
  46. 'open': float(row[1]),
  47. 'high': float(row[2]),
  48. 'low': float(row[3]),
  49. 'close': float(row[4]),
  50. 'volume': int(float(row[5]))
  51. })
  52. bs.logout()
  53. if not data:
  54. return None
  55. df = pd.DataFrame(data)
  56. df['date'] = pd.to_datetime(df['date'])
  57. df = df.set_index('date').sort_index()
  58. df['return'] = df['close'].pct_change()
  59. self.data = df
  60. print(f"✓ 获取成功: {len(df)}条数据")
  61. return df
  62. # ============================================
  63. # 方法1: 波动率分位数值法
  64. # ============================================
  65. def calc_volatility_percentile(self, lookback=252):
  66. """
  67. 波动率分位数值法
  68. - 计算20日ATR
  69. - 计算ATR的252日分位数
  70. - >70%: 高波动, <30%: 低波动, 中间: 常态
  71. """
  72. df = self.data.copy()
  73. # 计算TR和ATR
  74. high, low, close = df['high'], df['low'], df['close']
  75. tr1 = high - low
  76. tr2 = abs(high - close.shift())
  77. tr3 = abs(low - close.shift())
  78. tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
  79. df['ATR_20'] = tr.rolling(20).mean()
  80. df['Vol_Percentile'] = df['ATR_20'].rolling(lookback).apply(
  81. lambda x: pd.Series(x).rank(pct=True).iloc[-1] * 100
  82. )
  83. # 状态判定
  84. def classify_vol(pct):
  85. if pd.isna(pct):
  86. return '未知'
  87. if pct > 70:
  88. return '高波动'
  89. elif pct < 30:
  90. return '低波动'
  91. else:
  92. return '常态'
  93. df['Vol_State'] = df['Vol_Percentile'].apply(classify_vol)
  94. return df[['ATR_20', 'Vol_Percentile', 'Vol_State']]
  95. # ============================================
  96. # 方法2: 方差比检验 (修复版)
  97. # ============================================
  98. def calc_variance_ratio(self, k=5):
  99. """
  100. 方差比检验 (VR Test) - 修复版
  101. VR(k) = Var(r_t + r_{t-1} + ... + r_{t-k+1}) / (k * Var(r_t))
  102. - VR > 1 + 临界值: 趋势
  103. - VR < 1 - 临界值: 反转/均值回归
  104. - 中间: 随机/震荡
  105. """
  106. df = self.data.copy()
  107. df['VR'] = np.nan
  108. # 滚动计算VR
  109. window = 120 # 使用120天窗口
  110. for i in range(window + k, len(df)):
  111. r_window = df['return'].iloc[i-window:i].dropna()
  112. if len(r_window) >= window * 0.8: # 确保数据充足
  113. # k期累计收益
  114. k_ret = r_window.rolling(k).sum().dropna()
  115. if len(k_ret) > k:
  116. var_k = k_ret.var()
  117. var_1 = r_window.var()
  118. if var_1 > 0:
  119. df.loc[df.index[i], 'VR'] = var_k / (k * var_1)
  120. # 临界值 (95%置信区间)
  121. n = 120 # 样本数
  122. critical_value = 1.96 * np.sqrt(2 * (2*k - 1) * (k - 1) / (3 * k * n))
  123. df['VR_Upper'] = 1 + critical_value
  124. df['VR_Lower'] = 1 - critical_value
  125. # 状态判定
  126. def classify_vr(vr):
  127. if pd.isna(vr):
  128. return '未知'
  129. if vr > 1 + critical_value:
  130. return '趋势'
  131. elif vr < 1 - critical_value:
  132. return '反转'
  133. else:
  134. return '震荡'
  135. df['VR_State'] = df['VR'].apply(classify_vr)
  136. return df[['VR', 'VR_Upper', 'VR_Lower', 'VR_State']]
  137. # ============================================
  138. # 方法3: Hurst指数 (R/S分析) - 修复版
  139. # ============================================
  140. def calc_hurst(self, max_lag=50):
  141. """
  142. Hurst指数 R/S分析 - 修复版
  143. H > 0.55: 趋势 (长期记忆性)
  144. 0.45 <= H <= 0.55: 随机游走
  145. H < 0.45: 反转 (均值回归)
  146. """
  147. df = self.data.copy()
  148. df['Hurst'] = np.nan
  149. # 使用滚动窗口计算
  150. window = 200
  151. for i in range(window, len(df)):
  152. prices = df['close'].iloc[i-window:i].values
  153. if len(prices) >= window:
  154. h = self._compute_hurst_rs(prices, max_lag)
  155. if h is not None:
  156. df.loc[df.index[i], 'Hurst'] = h
  157. # 状态判定 - 使用更宽的阈值
  158. def classify_hurst(h):
  159. if pd.isna(h):
  160. return '未知'
  161. if h > 0.55:
  162. return '趋势'
  163. elif h < 0.45:
  164. return '反转'
  165. else:
  166. return '随机'
  167. df['Hurst_State'] = df['Hurst'].apply(classify_hurst)
  168. return df[['Hurst', 'Hurst_State']]
  169. def _compute_hurst_rs(self, prices, max_lag):
  170. """
  171. 标准R/S分析计算Hurst指数
  172. """
  173. try:
  174. # 计算对数收益率
  175. returns = np.diff(np.log(prices))
  176. n = len(returns)
  177. if n < max_lag * 2:
  178. return None
  179. # R/S分析
  180. lags = range(10, min(max_lag, n//4), 2)
  181. rs_values = []
  182. lag_values = []
  183. for lag in lags:
  184. # 将数据分成若干段
  185. n_segments = n // lag
  186. if n_segments < 2:
  187. continue
  188. rs_segments = []
  189. for i in range(n_segments):
  190. segment = returns[i*lag:(i+1)*lag]
  191. if len(segment) < lag:
  192. continue
  193. # 计算均值
  194. mean_seg = np.mean(segment)
  195. # 计算累积离差
  196. cumdev = np.cumsum(segment - mean_seg)
  197. # R = max - min of cumdev
  198. R = np.max(cumdev) - np.min(cumdev)
  199. # S = standard deviation
  200. S = np.std(segment)
  201. if S > 0:
  202. rs_segments.append(R / S)
  203. if rs_segments:
  204. rs_values.append(np.mean(rs_segments))
  205. lag_values.append(lag)
  206. if len(lag_values) < 5:
  207. return 0.5
  208. # 对数回归: log(R/S) = log(c) + H * log(n)
  209. log_lags = np.log(lag_values)
  210. log_rs = np.log(rs_values)
  211. slope, intercept, r_value, p_value, std_err = stats.linregress(log_lags, log_rs)
  212. # Hurst指数就是斜率
  213. hurst = slope
  214. # 限制在合理范围
  215. return max(0.1, min(0.9, hurst))
  216. except Exception as e:
  217. return 0.5
  218. # ============================================
  219. # 方法4: ADX + 价格动量组合
  220. # ============================================
  221. def calc_adx_momentum(self):
  222. """
  223. ADX + 价格动量组合
  224. - ADX衡量趋势强度
  225. - 价格与均线偏离度衡量趋势质量
  226. """
  227. df = self.data.copy()
  228. # 计算ADX
  229. high, low, close = df['high'], df['low'], df['close']
  230. plus_dm = high.diff()
  231. minus_dm = low.diff().abs()
  232. plus_dm = plus_dm.where((plus_dm > minus_dm) & (plus_dm > 0), 0)
  233. minus_dm = minus_dm.where((minus_dm > plus_dm) & (minus_dm > 0), 0)
  234. tr = pd.concat([high-low, (high-close.shift()).abs(), (low-close.shift()).abs()], axis=1).max(axis=1)
  235. atr = tr.rolling(14).mean()
  236. plus_di = 100 * (plus_dm.rolling(14).mean() / atr)
  237. minus_di = 100 * (minus_dm.rolling(14).mean() / atr)
  238. dx = (abs(plus_di - minus_di) / (plus_di + minus_di + 1e-10)) * 100
  239. df['ADX'] = dx.rolling(14).mean()
  240. # 计算偏离度
  241. df['MA20'] = df['close'].rolling(20).mean()
  242. df['Deviation'] = (df['close'] - df['MA20']) / df['MA20'] * 100
  243. # 状态判定
  244. def classify_adx_dev(row):
  245. adx = row['ADX']
  246. dev = abs(row['Deviation'])
  247. if pd.isna(adx) or pd.isna(dev):
  248. return '未知'
  249. # 强趋势
  250. if adx > 30 and dev > 2:
  251. return '强趋势'
  252. elif adx > 25 and dev > 1:
  253. return '趋势初期'
  254. elif adx > 20 and dev < 1:
  255. return '盘整观望'
  256. elif adx < 20 and dev > 2:
  257. return '假突破'
  258. else:
  259. return '震荡整理'
  260. df['ADX_State'] = df.apply(classify_adx_dev, axis=1)
  261. return df[['ADX', 'MA20', 'Deviation', 'ADX_State']]
  262. # ============================================
  263. # 方法5: 布林带宽度 + 波动率收缩
  264. # ============================================
  265. def calc_bollinger_squeeze(self, lookback=120):
  266. """
  267. 布林带宽度 + 波动率收缩
  268. BB_Percentile = percentile(Bandwidth, lookback)
  269. - < 10%: 极度收缩 (即将爆发)
  270. - > 90%: 极度扩张 (即将收敛)
  271. - 中间: 常态
  272. """
  273. df = self.data.copy()
  274. # 计算布林带
  275. df['MA20'] = df['close'].rolling(20).mean()
  276. df['STD20'] = df['close'].rolling(20).std()
  277. df['Upper'] = df['MA20'] + 2 * df['STD20']
  278. df['Lower'] = df['MA20'] - 2 * df['STD20']
  279. # 布林带宽度
  280. df['Bandwidth'] = (df['Upper'] - df['Lower']) / df['MA20'] * 100
  281. df['BB_Percentile'] = df['Bandwidth'].rolling(lookback).apply(
  282. lambda x: pd.Series(x).rank(pct=True).iloc[-1] * 100
  283. )
  284. # 状态判定
  285. def classify_bb(pct):
  286. if pd.isna(pct):
  287. return '未知'
  288. if pct < 10:
  289. return '极度收缩(即将爆发)'
  290. elif pct > 90:
  291. return '极度扩张(即将收敛)'
  292. elif pct < 30:
  293. return '收缩中'
  294. elif pct > 70:
  295. return '扩张中'
  296. else:
  297. return '常态'
  298. df['BB_State'] = df['BB_Percentile'].apply(classify_bb)
  299. return df[['Bandwidth', 'BB_Percentile', 'BB_State']]
  300. # ============================================
  301. # 方法6: 综合状态机 - 最终版
  302. # ============================================
  303. def calc_composite_state(self):
  304. """
  305. 综合状态机 - 硬编码决策树 (最终版)
  306. 优化目标: 提高趋势信号的胜率和收益
  307. """
  308. # 获取所有指标
  309. vol_df = self.calc_volatility_percentile()
  310. vr_df = self.calc_variance_ratio()
  311. hurst_df = self.calc_hurst()
  312. adx_df = self.calc_adx_momentum()
  313. bb_df = self.calc_bollinger_squeeze()
  314. # 合并所有状态
  315. df = self.data.copy()
  316. df['Vol_State'] = vol_df['Vol_State']
  317. df['VR_State'] = vr_df['VR_State']
  318. df['Hurst_State'] = hurst_df['Hurst_State']
  319. df['ADX_State'] = adx_df['ADX_State']
  320. df['BB_State'] = bb_df['BB_State']
  321. # 提取ADX和偏离度用于精细判断
  322. df['ADX'] = adx_df['ADX']
  323. df['Deviation'] = adx_df['Deviation']
  324. df['Vol_Pct'] = vol_df['Vol_Percentile']
  325. # 综合判定逻辑 - 最终版 (更严格)
  326. def composite_classify(row):
  327. states = {
  328. 'vol': row['Vol_State'],
  329. 'vr': row['VR_State'],
  330. 'hurst': row['Hurst_State'],
  331. 'adx': row['ADX_State'],
  332. 'bb': row['BB_State']
  333. }
  334. adx = row['ADX'] if not pd.isna(row['ADX']) else 0
  335. dev = row['Deviation'] if not pd.isna(row['Deviation']) else 0
  336. vol_pct = row['Vol_Pct'] if not pd.isna(row['Vol_Pct']) else 50
  337. # 强趋势判定: 需要所有关键指标同时支持,最严格
  338. if (states['vr'] == '趋势' and
  339. states['hurst'] == '趋势' and
  340. states['adx'] == '强趋势' and
  341. adx > 40 and abs(dev) > 3 and
  342. states['vol'] == '常态'):
  343. return '强趋势'
  344. # 趋势判定: 需要至少4个指标支持,严格
  345. trend_score = sum([
  346. states['vr'] == '趋势',
  347. states['hurst'] == '趋势',
  348. states['adx'] in ['强趋势', '趋势初期'],
  349. adx > 35 and abs(dev) > 2.5,
  350. states['vol'] in ['常态', '低波动']
  351. ])
  352. if trend_score >= 4:
  353. return '趋势'
  354. # 潜在爆发判定 - 低波动+收缩 (这个状态表现好,保持)
  355. squeeze_score = sum([
  356. states['bb'] == '极度收缩(即将爆发)',
  357. vol_pct < 25,
  358. states['adx'] == '盘整观望',
  359. states['vol'] == '低波动'
  360. ])
  361. if squeeze_score >= 3:
  362. return '潜在爆发'
  363. # 反转判定: 多个指标支持反转
  364. reversal_score = sum([
  365. states['vr'] == '反转',
  366. states['hurst'] == '反转',
  367. states['adx'] == '假突破',
  368. abs(dev) > 4 and adx < 20,
  369. states['bb'] == '极度扩张(即将收敛)'
  370. ])
  371. if reversal_score >= 3:
  372. return '反转'
  373. # 默认震荡
  374. return '震荡'
  375. df['Composite_State'] = df.apply(composite_classify, axis=1)
  376. return df[['Vol_State', 'VR_State', 'Hurst_State', 'ADX_State', 'BB_State',
  377. 'ADX', 'Deviation', 'Vol_Pct', 'Composite_State']]
  378. # ============================================
  379. # 回测验证
  380. # ============================================
  381. def backtest(self):
  382. """回测验证"""
  383. print("\n" + "="*70)
  384. print("开始回测验证...")
  385. print("="*70)
  386. # 获取综合状态
  387. states_df = self.calc_composite_state()
  388. # 合并到主数据
  389. df = self.data.copy()
  390. df['State'] = states_df['Composite_State']
  391. # 计算未来收益
  392. df['future_5d_return'] = df['close'].pct_change(5).shift(-5) * 100
  393. df['future_10d_return'] = df['close'].pct_change(10).shift(-10) * 100
  394. df['future_20d_return'] = df['close'].pct_change(20).shift(-20) * 100
  395. # 统计各状态表现
  396. print("\n【各状态表现统计】")
  397. print("-"*70)
  398. print(f"{'状态':<15} {'天数':<8} {'5日收益':<12} {'10日收益':<12} {'20日收益':<12}")
  399. print("-"*70)
  400. for state in df['State'].unique():
  401. if pd.isna(state):
  402. continue
  403. mask = df['State'] == state
  404. count = mask.sum()
  405. r5 = df[mask]['future_5d_return'].mean()
  406. r10 = df[mask]['future_10d_return'].mean()
  407. r20 = df[mask]['future_20d_return'].mean()
  408. print(f"{state:<15} {count:<8} {r5:>+10.2f}% {r10:>+10.2f}% {r20:>+10.2f}%")
  409. # 趋势状态 vs 其他
  410. print("\n【趋势信号验证】")
  411. print("-"*70)
  412. trend_mask = df['State'] == '趋势'
  413. reversal_mask = df['State'] == '反转'
  414. if trend_mask.sum() > 0:
  415. print(f"趋势信号天数: {trend_mask.sum()}")
  416. print(f"趋势信号20日收益: {df[trend_mask]['future_20d_return'].mean():+.2f}%")
  417. print(f"趋势信号胜率: {(df[trend_mask]['future_20d_return'] > 0).mean()*100:.1f}%")
  418. if reversal_mask.sum() > 0:
  419. print(f"\n反转信号天数: {reversal_mask.sum()}")
  420. print(f"反转信号20日收益: {df[reversal_mask]['future_20d_return'].mean():+.2f}%")
  421. # 最新状态
  422. latest = df.iloc[-1]
  423. print("\n【最新状态】")
  424. print("-"*70)
  425. print(f"日期: {df.index[-1].strftime('%Y-%m-%d')}")
  426. print(f"收盘价: {latest['close']:.2f}")
  427. print(f"综合状态: {latest['State']}")
  428. return df
  429. def main():
  430. """主函数"""
  431. print("="*70)
  432. print("Trend-Mix: 6种市场状态识别方法综合策略")
  433. print("针对创业板50指数的完整实现")
  434. print("="*70)
  435. strategy = TrendMixStrategy()
  436. # 获取数据
  437. df = strategy.fetch_data("399673", "2017-01-01", "2026-03-06")
  438. if df is None:
  439. print("数据获取失败")
  440. return
  441. # 运行回测
  442. result_df = strategy.backtest()
  443. print("\n" + "="*70)
  444. print("回测完成!")
  445. print("="*70)
  446. # 保存结果
  447. result_df.to_csv('/root/.openclaw/workspace/trend-mix/backtest_result.csv')
  448. print("\n✓ 结果已保存: backtest_result.csv")
  449. if __name__ == "__main__":
  450. main()