market_regime_hmm.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 市场环境识别器 (Market Regime Identifier)
  5. 基于HMM隐马尔可夫模型的市场状态识别系统
  6. 状态定义:
  7. - 状态0(震荡):价格波动大但无明显方向,Hurst指数≈0.5,自相关性低
  8. - 状态1(趋势):价格持续单向运动,Hurst指数>0.6,高自相关
  9. - 状态2(反转):超买/超卖后的V型反转,RSI极端值后的快速回归
  10. 作者: OpenClaw
  11. 日期: 2026-03-06
  12. """
  13. import numpy as np
  14. import pandas as pd
  15. from hmmlearn.hmm import GaussianHMM
  16. from scipy import stats
  17. import warnings
  18. warnings.filterwarnings('ignore')
  19. # ==================== 特征工程 ====================
  20. def calculate_hurst(prices, max_lag=100):
  21. """
  22. 计算Hurst指数
  23. H ≈ 0.5: 随机游走(震荡)
  24. H > 0.6: 趋势性
  25. H < 0.4: 均值回归
  26. """
  27. lags = range(2, min(max_lag, len(prices)//4))
  28. tau = [np.std(np.subtract(prices[lag:], prices[:-lag])) for lag in lags]
  29. if len(tau) < 2 or any(t <= 0 for t in tau):
  30. return 0.5
  31. reg = np.polyfit(np.log(lags), np.log(tau), 1)
  32. return reg[0]
  33. def calculate_rsi(prices, period=14):
  34. """计算RSI指标"""
  35. deltas = np.diff(prices)
  36. gains = np.where(deltas > 0, deltas, 0)
  37. losses = np.where(deltas < 0, -deltas, 0)
  38. avg_gains = np.convolve(gains, np.ones(period)/period, mode='valid')
  39. avg_losses = np.convolve(losses, np.ones(period)/period, mode='valid')
  40. rs = avg_gains / (avg_losses + 1e-10)
  41. rsi = 100 - (100 / (1 + rs))
  42. # 补齐长度
  43. padding = np.full(period, 50)
  44. return np.concatenate([padding, rsi])
  45. def extract_features(df):
  46. """
  47. 提取特征向量 X_t
  48. X_t = [收益率标准差(5日), 价格动量(10日), 波动率比率(短/长), 成交量变化率, 日内趋势强度]
  49. """
  50. features = pd.DataFrame(index=df.index)
  51. # 1. 收益率标准差(5日)
  52. returns = df['close'].pct_change()
  53. features['ret_std_5'] = returns.rolling(5).std() * np.sqrt(252)
  54. # 2. 价格动量(10日)
  55. features['momentum_10'] = (df['close'] / df['close'].shift(10) - 1) * 100
  56. # 3. 波动率比率(短/长)
  57. vol_short = returns.rolling(5).std()
  58. vol_long = returns.rolling(20).std()
  59. features['vol_ratio'] = vol_short / (vol_long + 1e-10)
  60. # 4. 成交量变化率
  61. features['volume_change'] = df['volume'].pct_change() * 100
  62. # 5. 日内趋势强度
  63. features['intraday_trend'] = ((df['close'] - df['open']) / (df['high'] - df['low'] + 1e-10)) * 100
  64. # 6. Hurst指数(额外特征)
  65. features['hurst'] = df['close'].rolling(100).apply(calculate_hurst, raw=True)
  66. # 7. RSI
  67. features['rsi'] = calculate_rsi(df['close'].values)
  68. # 8. 自相关性
  69. features['autocorr'] = returns.rolling(20).apply(lambda x: x.autocorr(lag=1) if len(x) > 1 else 0)
  70. # 填充缺失值
  71. features = features.ffill().fillna(0)
  72. return features
  73. # ==================== HMM模型 ====================
  74. class MarketRegimeHMM:
  75. """市场环境HMM模型"""
  76. # 状态名称
  77. STATE_NAMES = {
  78. 0: '震荡',
  79. 1: '趋势',
  80. 2: '反转'
  81. }
  82. def __init__(self, n_components=3, n_iter=100):
  83. # 先验转移概率矩阵
  84. self.PRIOR_TRANSITION = np.array([
  85. [0.85, 0.10, 0.05], # 震荡 -> 震荡/趋势/反转
  86. [0.15, 0.80, 0.05], # 趋势 -> 震荡/趋势/反转
  87. [0.20, 0.10, 0.70] # 反转 -> 震荡/趋势/反转
  88. ])
  89. self.model = GaussianHMM(
  90. n_components=n_components,
  91. covariance_type='full',
  92. n_iter=n_iter,
  93. random_state=42,
  94. init_params='mc' # 只初始化均值和协方差,不初始化转移矩阵
  95. )
  96. self.is_fitted = False
  97. def fit(self, features):
  98. """训练HMM模型"""
  99. print("训练HMM模型...")
  100. X = features.values
  101. # 先验状态分布(均匀分布)
  102. self.model.startprob_ = np.array([1/3, 1/3, 1/3])
  103. # 使用先验转移概率初始化
  104. self.model.transmat_ = self.PRIOR_TRANSITION.copy()
  105. # 拟合模型
  106. self.model.fit(X)
  107. self.is_fitted = True
  108. print(f"模型收敛: {self.model.monitor_.converged}")
  109. print(f"迭代次数: {self.model.n_iter}")
  110. print("\n学习到的转移概率矩阵:")
  111. print(self.model.transmat_.round(3))
  112. return self
  113. def predict(self, features):
  114. """预测状态序列"""
  115. if not self.is_fitted:
  116. raise ValueError("模型尚未训练,请先调用fit()")
  117. X = features.values
  118. states = self.model.predict(X)
  119. # 计算状态概率
  120. state_probs = self.model.predict_proba(X)
  121. return states, state_probs
  122. def get_current_regime(self, features):
  123. """获取当前市场状态"""
  124. states, probs = self.predict(features)
  125. current_state = states[-1]
  126. current_prob = probs[-1]
  127. return {
  128. 'state': current_state,
  129. 'state_name': self.STATE_NAMES[current_state],
  130. 'probabilities': {
  131. self.STATE_NAMES[i]: current_prob[i]
  132. for i in range(len(self.STATE_NAMES))
  133. },
  134. 'confidence': current_prob[current_state]
  135. }
  136. # ==================== 策略切换逻辑 ====================
  137. class StrategySelector:
  138. """基于市场状态的策略选择器"""
  139. STRATEGY_CONFIG = {
  140. 0: { # 震荡
  141. 'name': '均值回归',
  142. 'action': 'RSI超买超卖交易',
  143. 'position_size': 0.5, # 降低仓位
  144. 'stop_loss': '2N',
  145. 'description': '关闭趋势策略,使用RSI超买(>70)超卖(<30)信号'
  146. },
  147. 1: { # 趋势
  148. 'name': '海龟趋势',
  149. 'action': '全速运行',
  150. 'position_size': 1.0, # 全仓位
  151. 'stop_loss': '2N',
  152. 'description': '增加仓位,突破20日高低点交易'
  153. },
  154. 2: { # 反转
  155. 'name': '反向/观望',
  156. 'action': '反向信号或空仓',
  157. 'position_size': 0.3, # 最小仓位
  158. 'stop_loss': '1N', # 收紧止损
  159. 'description': '反向信号或观望,收紧止损'
  160. }
  161. }
  162. @classmethod
  163. def get_strategy(cls, state):
  164. """根据状态获取策略配置"""
  165. return cls.STRATEGY_CONFIG.get(state, cls.STRATEGY_CONFIG[0])
  166. @classmethod
  167. def generate_signal(cls, state, rsi_value, price, ma20):
  168. """生成交易信号"""
  169. strategy = cls.get_strategy(state)
  170. signal = {
  171. 'state': state,
  172. 'strategy': strategy['name'],
  173. 'position_size': strategy['position_size'],
  174. 'action': 'HOLD'
  175. }
  176. if state == 0: # 震荡 - RSI均值回归
  177. if rsi_value < 30:
  178. signal['action'] = 'BUY'
  179. signal['reason'] = 'RSI超卖'
  180. elif rsi_value > 70:
  181. signal['action'] = 'SELL'
  182. signal['reason'] = 'RSI超买'
  183. elif state == 1: # 趋势 - 突破系统
  184. if price > ma20 * 1.02:
  185. signal['action'] = 'BUY'
  186. signal['reason'] = '突破20日均线2%'
  187. elif price < ma20 * 0.98:
  188. signal['action'] = 'SELL'
  189. signal['reason'] = '跌破20日均线2%'
  190. elif state == 2: # 反转 - 反向或观望
  191. if rsi_value > 70:
  192. signal['action'] = 'SELL'
  193. signal['reason'] = '超买后反转'
  194. elif rsi_value < 30:
  195. signal['action'] = 'BUY'
  196. signal['reason'] = '超卖后反转'
  197. else:
  198. signal['action'] = 'HOLD'
  199. signal['reason'] = '观望'
  200. return signal
  201. # ==================== 模型评估 ====================
  202. def evaluate_model(hmm, features, true_states=None):
  203. """
  204. 评估模型性能
  205. 由于真实状态未知,使用以下指标:
  206. 1. 对数似然值
  207. 2. AIC/BIC
  208. 3. 状态持续时间合理性
  209. 4. 状态与价格行为的对应关系
  210. """
  211. X = features.values
  212. # 计算对数似然
  213. log_likelihood = hmm.model.score(X)
  214. # 计算AIC和BIC
  215. n_params = hmm.model.n_components * (hmm.model.n_features + hmm.model.n_features * (hmm.model.n_features + 1) / 2) + hmm.model.n_components * hmm.model.n_components
  216. n_samples = len(X)
  217. aic = -2 * log_likelihood + 2 * n_params
  218. bic = -2 * log_likelihood + n_params * np.log(n_samples)
  219. print(f"\n模型评估指标:")
  220. print(f"对数似然: {log_likelihood:.2f}")
  221. print(f"AIC: {aic:.2f}")
  222. print(f"BIC: {bic:.2f}")
  223. # 预测状态
  224. states, probs = hmm.predict(features)
  225. # 统计状态分布
  226. state_counts = pd.Series(states).value_counts().sort_index()
  227. state_pct = (state_counts / len(states) * 100).round(2)
  228. print(f"\n状态分布:")
  229. for state_id, state_name in hmm.STATE_NAMES.items():
  230. count = state_counts.get(state_id, 0)
  231. pct = state_pct.get(state_id, 0)
  232. print(f" {state_name}: {count}天 ({pct}%)")
  233. # 计算平均状态持续时间
  234. state_durations = []
  235. current_state = states[0]
  236. duration = 1
  237. for s in states[1:]:
  238. if s == current_state:
  239. duration += 1
  240. else:
  241. state_durations.append((current_state, duration))
  242. current_state = s
  243. duration = 1
  244. state_durations.append((current_state, duration))
  245. print(f"\n平均状态持续时间:")
  246. for state_id in range(3):
  247. durations = [d for s, d in state_durations if s == state_id]
  248. if durations:
  249. avg_duration = np.mean(durations)
  250. print(f" {hmm.STATE_NAMES[state_id]}: {avg_duration:.1f}天")
  251. return {
  252. 'log_likelihood': log_likelihood,
  253. 'aic': aic,
  254. 'bic': bic,
  255. 'state_distribution': state_counts.to_dict(),
  256. 'states': states,
  257. 'state_probs': probs
  258. }
  259. # ==================== 主程序 ====================
  260. def main():
  261. """主程序"""
  262. print("="*70)
  263. print("市场环境识别器 (Market Regime Identifier)")
  264. print("基于HMM隐马尔可夫模型")
  265. print("="*70)
  266. # 示例:使用随机数据演示
  267. print("\n注意:这是演示版本,请使用真实数据运行")
  268. print("数据格式要求:DataFrame包含 'open', 'high', 'low', 'close', 'volume' 列")
  269. # 生成示例数据
  270. np.random.seed(42)
  271. n_days = 500
  272. dates = pd.date_range('2023-01-01', periods=n_days, freq='B')
  273. # 模拟价格走势(包含趋势、震荡、反转三种状态)
  274. price = 100
  275. prices = []
  276. for i in range(n_days):
  277. # 模拟不同状态
  278. if i < 150: # 趋势
  279. price *= (1 + np.random.normal(0.001, 0.01))
  280. elif i < 300: # 震荡
  281. price *= (1 + np.random.normal(0, 0.015))
  282. else: # 反转
  283. if i < 375:
  284. price *= (1 + np.random.normal(-0.002, 0.012))
  285. else:
  286. price *= (1 + np.random.normal(0.002, 0.012))
  287. prices.append(price)
  288. df = pd.DataFrame({
  289. 'open': prices + np.random.normal(0, 0.5, n_days),
  290. 'high': np.array(prices) + np.abs(np.random.normal(1, 0.5, n_days)),
  291. 'low': np.array(prices) - np.abs(np.random.normal(1, 0.5, n_days)),
  292. 'close': prices,
  293. 'volume': np.random.randint(1000000, 5000000, n_days)
  294. }, index=dates)
  295. print(f"\n示例数据: {len(df)}天")
  296. print(f"日期范围: {df.index[0].date()} ~ {df.index[-1].date()}")
  297. # 特征提取
  298. print("\n提取特征...")
  299. features = extract_features(df)
  300. # 选择训练特征(核心5个)
  301. feature_cols = ['ret_std_5', 'momentum_10', 'vol_ratio', 'volume_change', 'intraday_trend']
  302. X_train = features[feature_cols].dropna()
  303. print(f"特征矩阵: {X_train.shape}")
  304. # 训练HMM模型
  305. hmm = MarketRegimeHMM(n_components=3, n_iter=100)
  306. hmm.fit(X_train)
  307. # 预测状态
  308. states, probs = hmm.predict(X_train)
  309. # 评估模型
  310. eval_results = evaluate_model(hmm, X_train)
  311. # 获取当前状态
  312. current_regime = hmm.get_current_regime(X_train)
  313. print("\n" + "="*70)
  314. print("当前市场状态识别")
  315. print("="*70)
  316. print(f"状态: {current_regime['state_name']} (状态{current_regime['state']})")
  317. print(f"置信度: {current_regime['confidence']:.2%}")
  318. print("\n状态概率分布:")
  319. for name, prob in current_regime['probabilities'].items():
  320. bar = '█' * int(prob * 20)
  321. print(f" {name:6s}: {prob:.2%} {bar}")
  322. # 策略建议
  323. strategy = StrategySelector.get_strategy(current_regime['state'])
  324. current_rsi = features['rsi'].iloc[-1]
  325. current_price = df['close'].iloc[-1]
  326. current_ma20 = df['close'].rolling(20).mean().iloc[-1]
  327. signal = StrategySelector.generate_signal(
  328. current_regime['state'],
  329. current_rsi,
  330. current_price,
  331. current_ma20
  332. )
  333. print("\n" + "="*70)
  334. print("策略建议")
  335. print("="*70)
  336. print(f"推荐策略: {strategy['name']}")
  337. print(f"操作策略: {strategy['action']}")
  338. print(f"仓位建议: {strategy['position_size']*100:.0f}%")
  339. print(f"止损设置: {strategy['stop_loss']}")
  340. print(f"描述: {strategy['description']}")
  341. print("\n交易信号:")
  342. print(f" 动作: {signal['action']}")
  343. if 'reason' in signal:
  344. print(f" 原因: {signal['reason']}")
  345. print("\n" + "="*70)
  346. print("使用说明:")
  347. print("="*70)
  348. print("1. 准备真实市场数据(2017-2025年)")
  349. print("2. 调用 extract_features(df) 提取特征")
  350. print("3. 使用 MarketRegimeHMM 训练模型")
  351. print("4. 根据 get_current_regime() 结果切换策略")
  352. print("\n验证要求: 状态识别准确率 > 72%")
  353. print("="*70)
  354. if __name__ == "__main__":
  355. main()