backtest_t1_with_regime.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. CYB50 择时过滤T+1回测系统 - 结合市场状态
  5. 只做多,使用市场状态过滤开仓信号
  6. """
  7. import csv
  8. import json
  9. from datetime import datetime, timedelta
  10. from collections import deque
  11. import math
  12. # ==================== 技术指标计算类 ====================
  13. class TechnicalIndicators:
  14. """技术指标计算 - 纯Python实现"""
  15. @staticmethod
  16. def sma(data, period):
  17. """简单移动平均线"""
  18. if len(data) < period:
  19. return None
  20. return sum(data[-period:]) / period
  21. @staticmethod
  22. def ema(data, period):
  23. """指数移动平均线"""
  24. if len(data) < period:
  25. return None
  26. multiplier = 2 / (period + 1)
  27. ema = data[0]
  28. for price in data[1:]:
  29. ema = (price - ema) * multiplier + ema
  30. return ema
  31. @staticmethod
  32. def rsi(prices, period=14):
  33. """RSI计算"""
  34. if len(prices) < period + 1:
  35. return None
  36. gains = []
  37. losses = []
  38. for i in range(1, len(prices)):
  39. change = prices[i] - prices[i-1]
  40. if change > 0:
  41. gains.append(change)
  42. losses.append(0)
  43. else:
  44. gains.append(0)
  45. losses.append(abs(change))
  46. if len(gains) < period:
  47. return None
  48. avg_gain = sum(gains[-period:]) / period
  49. avg_loss = sum(losses[-period:]) / period
  50. if avg_loss == 0:
  51. return 100
  52. rs = avg_gain / avg_loss
  53. return 100 - (100 / (1 + rs))
  54. @staticmethod
  55. def bollinger_bands(prices, period=20, std_dev=2):
  56. """布林带计算"""
  57. if len(prices) < period:
  58. return None, None, None
  59. middle = sum(prices[-period:]) / period
  60. variance = sum((p - middle) ** 2 for p in prices[-period:]) / period
  61. std = math.sqrt(variance)
  62. upper = middle + (std * std_dev)
  63. lower = middle - (std * std_dev)
  64. return upper, middle, lower
  65. @staticmethod
  66. def macd(prices, fast=12, slow=26, signal=9):
  67. """MACD计算"""
  68. if len(prices) < slow:
  69. return None, None, None
  70. def calc_ema(data, period):
  71. multiplier = 2 / (period + 1)
  72. ema = data[0]
  73. for price in data[1:]:
  74. ema = (price - ema) * multiplier + ema
  75. return ema
  76. ema_fast = calc_ema(prices[-fast:], fast) if len(prices) >= fast else None
  77. ema_slow = calc_ema(prices[-slow:], slow) if len(prices) >= slow else None
  78. if ema_fast is None or ema_slow is None:
  79. return None, None, None
  80. macd_line = ema_fast - ema_slow
  81. # 计算信号线 (EMA of MACD)
  82. macd_prices = []
  83. for i in range(slow, len(prices) + 1):
  84. fast_ema = calc_ema(prices[i-fast:i], fast)
  85. slow_ema = calc_ema(prices[i-slow:i], slow)
  86. macd_prices.append(fast_ema - slow_ema)
  87. signal_line = None
  88. if len(macd_prices) >= signal:
  89. signal_line = calc_ema(macd_prices[-signal:], signal)
  90. histogram = macd_line - signal_line if signal_line else None
  91. return macd_line, signal_line, histogram
  92. # ==================== 市场状态管理器 ====================
  93. class MarketRegimeManager:
  94. """管理市场状态数据,提供择时过滤"""
  95. def __init__(self, regime_file):
  96. self.regime_data = {}
  97. self.load_regime_data(regime_file)
  98. def load_regime_data(self, filepath):
  99. """加载市场状态数据"""
  100. print(f"加载市场状态数据: {filepath}")
  101. try:
  102. with open(filepath, 'r', encoding='utf-8') as f:
  103. reader = csv.DictReader(f)
  104. for row in reader:
  105. # 解析datetime
  106. dt_str = row['datetime']
  107. self.regime_data[dt_str] = {
  108. 'state': int(row['state']),
  109. 'prob_ranging': float(row['prob_ranging']),
  110. 'prob_trend': float(row['prob_trend']),
  111. 'prob_reversal': float(row['prob_reversal'])
  112. }
  113. print(f"[OK] 加载成功: {len(self.regime_data)}条状态数据")
  114. except Exception as e:
  115. print(f"[ERROR] 加载失败: {e}")
  116. self.regime_data = {}
  117. def get_regime(self, dt_str):
  118. """获取指定时间的市场状态"""
  119. return self.regime_data.get(dt_str, {
  120. 'state': 0, # 默认震荡
  121. 'prob_ranging': 1.0,
  122. 'prob_trend': 0.0,
  123. 'prob_reversal': 0.0
  124. })
  125. def can_open_long(self, dt_str, min_trend_prob=0.5):
  126. """
  127. 判断是否允许开多单
  128. 规则:
  129. - 趋势状态(state=1) + 趋势概率 > min_trend_prob -> 允许
  130. - 其他状态 -> 禁止
  131. """
  132. regime = self.get_regime(dt_str)
  133. state = regime['state']
  134. trend_prob = regime['prob_trend']
  135. # 只在趋势状态且概率足够高时允许开仓
  136. if state == 1 and trend_prob >= min_trend_prob:
  137. return True, f"趋势状态(概率{trend_prob:.2f})"
  138. # 反转状态 - 禁止开仓
  139. if state == 2:
  140. return False, f"反转状态(概率{regime['prob_reversal']:.2f})"
  141. # 震荡状态 - 观望
  142. return False, f"震荡状态(概率{regime['prob_ranging']:.2f})"
  143. # ==================== 回测引擎 ====================
  144. class BacktestEngine:
  145. """择时过滤T+1回测引擎"""
  146. def __init__(self, initial_capital=1000000, position_size=0.5):
  147. self.initial_capital = initial_capital
  148. self.position_size = position_size
  149. self.capital = initial_capital
  150. self.position = 0 # 持仓数量
  151. self.entry_price = 0
  152. self.entry_time = None
  153. self.holding_periods = 0
  154. self.max_holding_periods = 16 # 最大持仓周期(8小时)
  155. # 记录
  156. self.equity_curve = []
  157. self.trades = []
  158. self.signals = []
  159. # 指标
  160. self.prices = deque(maxlen=100)
  161. self.highs = deque(maxlen=100)
  162. self.lows = deque(maxlen=100)
  163. def calculate_signals(self):
  164. """计算交易信号"""
  165. if len(self.prices) < 50:
  166. return None
  167. price_list = list(self.prices)
  168. high_list = list(self.highs)
  169. low_list = list(self.lows)
  170. # 技术指标
  171. rsi = TechnicalIndicators.rsi(price_list, 14)
  172. bb_upper, bb_middle, bb_lower = TechnicalIndicators.bollinger_bands(price_list, 20, 2)
  173. # 均线
  174. ma5 = TechnicalIndicators.sma(price_list, 5)
  175. ma10 = TechnicalIndicators.sma(price_list, 10)
  176. ma20 = TechnicalIndicators.sma(price_list, 20)
  177. # MACD
  178. macd_line, signal_line, histogram = TechnicalIndicators.macd(price_list)
  179. return {
  180. 'rsi': rsi,
  181. 'bb_upper': bb_upper,
  182. 'bb_lower': bb_lower,
  183. 'bb_middle': bb_middle,
  184. 'ma5': ma5,
  185. 'ma10': ma10,
  186. 'ma20': ma20,
  187. 'macd': macd_line,
  188. 'macd_signal': signal_line,
  189. 'price': price_list[-1]
  190. }
  191. def check_long_signal(self, signals):
  192. """检查做多信号"""
  193. if signals is None:
  194. return False, "指标不足"
  195. conditions = []
  196. # RSI条件 - 避免超买
  197. if signals['rsi'] is not None and signals['rsi'] < 65:
  198. conditions.append('RSI<65')
  199. # 均线条件 - 短期在长期之上
  200. if (signals['ma5'] is not None and signals['ma10'] is not None and
  201. signals['ma5'] > signals['ma10']):
  202. conditions.append('MA5>MA10')
  203. # MACD条件
  204. if (signals['macd'] is not None and signals['macd_signal'] is not None and
  205. signals['macd'] > signals['macd_signal']):
  206. conditions.append('MACD金叉')
  207. # 布林带条件 - 价格在布林带中轨之上
  208. if (signals['bb_middle'] is not None and
  209. signals['price'] > signals['bb_middle']):
  210. conditions.append('价格>中轨')
  211. # 至少需要3个条件满足
  212. if len(conditions) >= 3:
  213. return True, '+'.join(conditions)
  214. return False, f"条件不足({len(conditions)}/3)"
  215. def check_exit_signal(self, signals, current_price):
  216. """检查平仓信号"""
  217. if signals is None or self.position == 0:
  218. return False, ""
  219. # 止损 2.5%
  220. stop_loss = self.entry_price * 0.975
  221. if current_price <= stop_loss:
  222. return True, f"止损({current_price:.2f}<={stop_loss:.2f})"
  223. # 止盈 4%
  224. take_profit = self.entry_price * 1.04
  225. if current_price >= take_profit:
  226. return True, f"止盈({current_price:.2f}>={take_profit:.2f})"
  227. # 最大持仓时间
  228. if self.holding_periods >= self.max_holding_periods:
  229. return True, f"时间平仓({self.holding_periods}周期)"
  230. # RSI超买平仓
  231. if signals['rsi'] is not None and signals['rsi'] > 75:
  232. return True, f"RSI超买({signals['rsi']:.1f})"
  233. return False, ""
  234. def open_position(self, price, time_str, reason):
  235. """开仓"""
  236. position_value = self.capital * self.position_size
  237. self.position = position_value / price
  238. self.entry_price = price
  239. self.entry_time = time_str
  240. self.holding_periods = 0
  241. self.trades.append({
  242. 'action': 'OPEN',
  243. 'time': time_str,
  244. 'price': price,
  245. 'shares': self.position,
  246. 'value': position_value,
  247. 'reason': reason
  248. })
  249. def close_position(self, price, time_str, reason):
  250. """平仓"""
  251. if self.position == 0:
  252. return
  253. pnl = (price - self.entry_price) * self.position
  254. pnl_pct = (price / self.entry_price - 1) * 100
  255. self.capital += pnl
  256. self.trades.append({
  257. 'action': 'CLOSE',
  258. 'time': time_str,
  259. 'price': price,
  260. 'shares': self.position,
  261. 'pnl': pnl,
  262. 'pnl_pct': pnl_pct,
  263. 'reason': reason
  264. })
  265. self.position = 0
  266. self.entry_price = 0
  267. self.holding_periods = 0
  268. def update(self, timestamp, open_price, high, low, close, regime_manager):
  269. """更新回测状态"""
  270. self.prices.append(close)
  271. self.highs.append(high)
  272. self.lows.append(low)
  273. # 计算信号
  274. signals = self.calculate_signals()
  275. # 获取市场状态
  276. dt_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
  277. can_open, regime_reason = regime_manager.can_open_long(dt_str)
  278. # 记录权益
  279. equity = self.capital
  280. if self.position > 0:
  281. equity += self.position * close
  282. self.equity_curve.append({
  283. 'time': dt_str,
  284. 'equity': equity,
  285. 'close': close,
  286. 'position': 1 if self.position > 0 else 0
  287. })
  288. # 持仓更新
  289. if self.position > 0:
  290. self.holding_periods += 1
  291. # 检查平仓
  292. should_exit, exit_reason = self.check_exit_signal(signals, close)
  293. if should_exit:
  294. self.close_position(close, dt_str, exit_reason)
  295. else:
  296. # 空仓 - 检查开仓
  297. # 先检查技术信号
  298. tech_signal, tech_reason = self.check_long_signal(signals)
  299. if tech_signal:
  300. # 技术信号满足,再检查择时过滤
  301. if can_open:
  302. self.open_position(close, dt_str, f"{tech_reason}|{regime_reason}")
  303. else:
  304. # 技术信号满足但被择时过滤
  305. self.signals.append({
  306. 'time': dt_str,
  307. 'price': close,
  308. 'tech_reason': tech_reason,
  309. 'block_reason': regime_reason
  310. })
  311. return equity
  312. # ==================== 主程序 ====================
  313. def load_data(filepath):
  314. """加载30分钟数据"""
  315. print(f"加载数据: {filepath}")
  316. data = []
  317. with open(filepath, 'r', encoding='utf-8-sig') as f: # utf-8-sig handles BOM
  318. reader = csv.DictReader(f)
  319. for row in reader:
  320. try:
  321. dt = datetime.strptime(row['DateTime'], '%Y-%m-%d %H:%M:%S')
  322. data.append({
  323. 'datetime': dt,
  324. 'open': float(row['Open']),
  325. 'high': float(row['High']),
  326. 'low': float(row['Low']),
  327. 'close': float(row['Close']),
  328. 'volume': float(row['Volume'])
  329. })
  330. except Exception as e:
  331. continue
  332. print(f"[OK] 加载成功: {len(data)}条")
  333. return data
  334. def run_backtest(data_file, regime_file, output_dir='backtest_results'):
  335. """运行回测"""
  336. import os
  337. os.makedirs(output_dir, exist_ok=True)
  338. # 加载数据
  339. data = load_data(data_file)
  340. regime_manager = MarketRegimeManager(regime_file)
  341. # 创建回测引擎
  342. engine = BacktestEngine(initial_capital=1000000, position_size=0.5)
  343. print("\n" + "="*70)
  344. print("开始回测 - 择时过滤T+1策略")
  345. print("="*70)
  346. print("策略规则:")
  347. print(" - 只做多,持仓上限50%")
  348. print(" - 技术信号: RSI<65 + MA5>MA10 + MACD金叉 + 价格>布林带中轨")
  349. print(" - 择时过滤: 只在趋势状态(state=1)且趋势概率>0.5时开仓")
  350. print(" - 止损: -2.5% | 止盈: +4% | 最大持仓: 16周期(8小时)")
  351. print("="*70)
  352. # 运行回测
  353. for row in data:
  354. engine.update(
  355. row['datetime'],
  356. row['open'],
  357. row['high'],
  358. row['low'],
  359. row['close'],
  360. regime_manager
  361. )
  362. # 统计结果
  363. print("\n" + "="*70)
  364. print("回测结果")
  365. print("="*70)
  366. initial = engine.initial_capital
  367. final = engine.equity_curve[-1]['equity'] if engine.equity_curve else initial
  368. total_return = (final / initial - 1) * 100
  369. print(f"初始资金: {initial:,.2f} 元")
  370. print(f"最终资金: {final:,.2f} 元")
  371. print(f"总收益率: {total_return:+.2f}%")
  372. # 交易统计
  373. trades = engine.trades
  374. closed_trades = [t for t in trades if t['action'] == 'CLOSE']
  375. print(f"\n总交易次数: {len(closed_trades)}")
  376. if closed_trades:
  377. wins = [t for t in closed_trades if t['pnl'] > 0]
  378. losses = [t for t in closed_trades if t['pnl'] <= 0]
  379. win_count = len(wins)
  380. loss_count = len(losses)
  381. win_rate = win_count / len(closed_trades) * 100
  382. total_profit = sum(t['pnl'] for t in wins) if wins else 0
  383. total_loss = sum(t['pnl'] for t in losses) if losses else 0
  384. avg_win = total_profit / win_count if win_count > 0 else 0
  385. avg_loss = total_loss / loss_count if loss_count > 0 else 0
  386. profit_factor = abs(total_profit / total_loss) if total_loss != 0 else 0
  387. print(f" 盈利: {win_count} | 亏损: {loss_count}")
  388. print(f" 胜率: {win_rate:.2f}%")
  389. print(f" 盈亏比: {profit_factor:.2f}")
  390. print(f" 平均每笔盈利: {avg_win:,.2f}")
  391. print(f" 平均每笔亏损: {avg_loss:,.2f}")
  392. # 过滤掉的信号统计
  393. blocked = engine.signals
  394. print(f"\n被择时过滤的信号: {len(blocked)}次")
  395. if blocked:
  396. print(" (技术信号满足但市场状态不允许开仓)")
  397. # 保存结果
  398. timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
  399. # 保存权益曲线
  400. equity_file = f"{output_dir}/equity_with_regime_{timestamp}.csv"
  401. with open(equity_file, 'w', newline='', encoding='utf-8') as f:
  402. writer = csv.DictWriter(f, fieldnames=['time', 'equity', 'close', 'position'])
  403. writer.writeheader()
  404. writer.writerows(engine.equity_curve)
  405. # 保存交易记录
  406. trades_file = f"{output_dir}/trades_with_regime_{timestamp}.csv"
  407. with open(trades_file, 'w', newline='', encoding='utf-8') as f:
  408. if trades and len(trades) > 0:
  409. writer = csv.DictWriter(f, fieldnames=trades[0].keys())
  410. writer.writeheader()
  411. writer.writerows(trades)
  412. # 保存过滤信号
  413. if blocked:
  414. blocked_file = f"{output_dir}/blocked_signals_{timestamp}.csv"
  415. with open(blocked_file, 'w', newline='', encoding='utf-8') as f:
  416. writer = csv.DictWriter(f, fieldnames=blocked[0].keys())
  417. writer.writeheader()
  418. writer.writerows(blocked)
  419. # 保存报告
  420. report_file = f"{output_dir}/report_with_regime_{timestamp}.txt"
  421. with open(report_file, 'w', encoding='utf-8') as f:
  422. f.write("="*70 + "\n")
  423. f.write("CYB50 择时过滤T+1策略回测报告\n")
  424. f.write("="*70 + "\n\n")
  425. f.write(f"初始资金: {initial:,.2f} 元\n")
  426. f.write(f"最终资金: {final:,.2f} 元\n")
  427. f.write(f"总收益率: {total_return:+.2f}%\n")
  428. f.write(f"总交易次数: {len(closed_trades)}\n")
  429. if closed_trades:
  430. f.write(f"胜率: {win_rate:.2f}%\n")
  431. f.write(f"盈亏比: {profit_factor:.2f}\n")
  432. f.write(f"\n被择时过滤的信号: {len(blocked)}次\n")
  433. print(f"\n结果已保存到: {output_dir}/")
  434. print(f" - {equity_file}")
  435. print(f" - {trades_file}")
  436. print(f" - {report_file}")
  437. return engine
  438. if __name__ == '__main__':
  439. DATA_FILE = 'cyb50_30min_2023_to_20260325.csv'
  440. REGIME_FILE = '../../market-regime-identifier-30/cyb50_30min_regime_result.csv'
  441. engine = run_backtest(DATA_FILE, REGIME_FILE)