backtest_multi_timeframe.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. CYB50 多周期确认 + 参数优化回测系统
  5. 结合日线趋势和30分钟择时,支持参数扫描
  6. """
  7. import csv
  8. import json
  9. from datetime import datetime, timedelta
  10. from collections import deque
  11. import math
  12. import os
  13. # ==================== 技术指标计算类 ====================
  14. class TechnicalIndicators:
  15. """技术指标计算 - 纯Python实现"""
  16. @staticmethod
  17. def sma(data, period):
  18. """简单移动平均线"""
  19. if len(data) < period:
  20. return None
  21. return sum(data[-period:]) / period
  22. @staticmethod
  23. def rsi(prices, period=14):
  24. """RSI计算"""
  25. if len(prices) < period + 1:
  26. return None
  27. gains = []
  28. losses = []
  29. for i in range(1, len(prices)):
  30. change = prices[i] - prices[i-1]
  31. if change > 0:
  32. gains.append(change)
  33. losses.append(0)
  34. else:
  35. gains.append(0)
  36. losses.append(abs(change))
  37. if len(gains) < period:
  38. return None
  39. avg_gain = sum(gains[-period:]) / period
  40. avg_loss = sum(losses[-period:]) / period
  41. if avg_loss == 0:
  42. return 100
  43. rs = avg_gain / avg_loss
  44. return 100 - (100 / (1 + rs))
  45. @staticmethod
  46. def bollinger_bands(prices, period=20, std_dev=2):
  47. """布林带计算"""
  48. if len(prices) < period:
  49. return None, None, None
  50. middle = sum(prices[-period:]) / period
  51. variance = sum((p - middle) ** 2 for p in prices[-period:]) / period
  52. std = math.sqrt(variance)
  53. upper = middle + (std * std_dev)
  54. lower = middle - (std * std_dev)
  55. return upper, middle, lower
  56. @staticmethod
  57. def macd(prices, fast=12, slow=26, signal=9):
  58. """MACD计算"""
  59. if len(prices) < slow:
  60. return None, None, None
  61. def calc_ema(data, period):
  62. multiplier = 2 / (period + 1)
  63. ema = data[0]
  64. for price in data[1:]:
  65. ema = (price - ema) * multiplier + ema
  66. return ema
  67. ema_fast = calc_ema(prices[-fast:], fast) if len(prices) >= fast else None
  68. ema_slow = calc_ema(prices[-slow:], slow) if len(prices) >= slow else None
  69. if ema_fast is None or ema_slow is None:
  70. return None, None, None
  71. macd_line = ema_fast - ema_slow
  72. macd_prices = []
  73. for i in range(slow, len(prices) + 1):
  74. fast_ema = calc_ema(prices[i-fast:i], fast)
  75. slow_ema = calc_ema(prices[i-slow:i], slow)
  76. macd_prices.append(fast_ema - slow_ema)
  77. signal_line = None
  78. if len(macd_prices) >= signal:
  79. signal_line = calc_ema(macd_prices[-signal:], signal)
  80. histogram = macd_line - signal_line if signal_line else None
  81. return macd_line, signal_line, histogram
  82. # ==================== 日线趋势管理器 ====================
  83. class DailyTrendManager:
  84. """管理日线趋势数据,提供多周期确认"""
  85. def __init__(self, daily_file):
  86. self.daily_data = {}
  87. self.daily_trend = {} # date -> trend info
  88. self.load_daily_data(daily_file)
  89. self.calculate_daily_trend()
  90. def load_daily_data(self, filepath):
  91. """加载日线数据"""
  92. print(f"加载日线数据: {filepath}")
  93. try:
  94. with open(filepath, 'r', encoding='utf-8-sig') as f:
  95. reader = csv.DictReader(f)
  96. for row in reader:
  97. try:
  98. dt = datetime.strptime(row['datetime'], '%Y-%m-%d %H:%M:%S')
  99. date_str = dt.strftime('%Y-%m-%d')
  100. self.daily_data[date_str] = {
  101. 'open': float(row['open']),
  102. 'high': float(row['high']),
  103. 'low': float(row['low']),
  104. 'close': float(row['close']),
  105. 'volume': float(row['volume'])
  106. }
  107. except:
  108. continue
  109. print(f"[OK] 加载成功: {len(self.daily_data)}条日线数据")
  110. except Exception as e:
  111. print(f"[ERROR] 加载失败: {e}")
  112. def calculate_daily_trend(self, ma_period=20):
  113. """计算日线趋势"""
  114. print(f"计算日线趋势 (MA{ma_period})...")
  115. dates = sorted(self.daily_data.keys())
  116. closes = [self.daily_data[d]['close'] for d in dates]
  117. for i, date in enumerate(dates):
  118. if i < ma_period - 1:
  119. self.daily_trend[date] = {'trend': 0, 'ma20': None, 'trend_strength': 0}
  120. continue
  121. # 计算MA20
  122. ma20 = sum(closes[i-ma_period+1:i+1]) / ma_period
  123. close = closes[i]
  124. # 趋势方向: 1=向上, -1=向下, 0=横盘
  125. if close > ma20 * 1.02:
  126. trend = 1 # 明显向上
  127. elif close < ma20 * 0.98:
  128. trend = -1 # 明显向下
  129. else:
  130. trend = 0 # 横盘
  131. # 趋势强度
  132. trend_strength = (close - ma20) / ma20 * 100
  133. self.daily_trend[date] = {
  134. 'trend': trend,
  135. 'ma20': ma20,
  136. 'trend_strength': trend_strength
  137. }
  138. print(f"[OK] 日线趋势计算完成")
  139. def get_daily_trend(self, date_str):
  140. """获取指定日期的日线趋势"""
  141. return self.daily_trend.get(date_str, {'trend': 0, 'ma20': None, 'trend_strength': 0})
  142. def can_trade_long(self, date_str, require_uptrend=True):
  143. """检查是否允许做多"""
  144. trend_info = self.get_daily_trend(date_str)
  145. if require_uptrend:
  146. # 要求日线趋势向上
  147. return trend_info['trend'] == 1, trend_info
  148. else:
  149. # 允许横盘和向上,禁止向下
  150. return trend_info['trend'] >= 0, trend_info
  151. # ==================== 30分钟市场状态管理器 ====================
  152. class MarketRegimeManager:
  153. """管理30分钟市场状态数据"""
  154. def __init__(self, regime_file):
  155. self.regime_data = {}
  156. self.load_regime_data(regime_file)
  157. def load_regime_data(self, filepath):
  158. """加载市场状态数据"""
  159. print(f"加载30分钟状态数据: {filepath}")
  160. try:
  161. with open(filepath, 'r', encoding='utf-8-sig') as f:
  162. reader = csv.DictReader(f)
  163. for row in reader:
  164. dt_str = row['datetime']
  165. self.regime_data[dt_str] = {
  166. 'state': int(row['state']),
  167. 'prob_ranging': float(row['prob_ranging']),
  168. 'prob_trend': float(row['prob_trend']),
  169. 'prob_reversal': float(row['prob_reversal'])
  170. }
  171. print(f"[OK] 加载成功: {len(self.regime_data)}条状态数据")
  172. except Exception as e:
  173. print(f"[ERROR] 加载失败: {e}")
  174. def get_regime(self, dt_str):
  175. """获取指定时间的市场状态"""
  176. return self.regime_data.get(dt_str, {
  177. 'state': 0,
  178. 'prob_ranging': 1.0,
  179. 'prob_trend': 0.0,
  180. 'prob_reversal': 0.0
  181. })
  182. def can_open_long(self, dt_str, min_trend_prob=0.5):
  183. """判断是否允许开多单"""
  184. regime = self.get_regime(dt_str)
  185. state = regime['state']
  186. trend_prob = regime['prob_trend']
  187. if state == 1 and trend_prob >= min_trend_prob:
  188. return True, regime
  189. if state == 2:
  190. return False, regime
  191. return False, regime
  192. # ==================== 回测引擎(带参数) ====================
  193. class BacktestEngine:
  194. """多周期确认回测引擎"""
  195. def __init__(self, initial_capital=1000000, position_size=0.5,
  196. min_trend_prob=0.5, require_daily_uptrend=True):
  197. self.initial_capital = initial_capital
  198. self.position_size = position_size
  199. self.min_trend_prob = min_trend_prob
  200. self.require_daily_uptrend = require_daily_uptrend
  201. self.capital = initial_capital
  202. self.position = 0
  203. self.entry_price = 0
  204. self.entry_time = None
  205. self.holding_periods = 0
  206. self.max_holding_periods = 16
  207. self.equity_curve = []
  208. self.trades = []
  209. self.signals = []
  210. self.block_reasons = {'daily': 0, 'regime': 0}
  211. self.prices = deque(maxlen=100)
  212. self.highs = deque(maxlen=100)
  213. self.lows = deque(maxlen=100)
  214. def calculate_signals(self):
  215. """计算交易信号"""
  216. if len(self.prices) < 50:
  217. return None
  218. price_list = list(self.prices)
  219. rsi = TechnicalIndicators.rsi(price_list, 14)
  220. bb_upper, bb_middle, bb_lower = TechnicalIndicators.bollinger_bands(price_list, 20, 2)
  221. ma5 = TechnicalIndicators.sma(price_list, 5)
  222. ma10 = TechnicalIndicators.sma(price_list, 10)
  223. ma20 = TechnicalIndicators.sma(price_list, 20)
  224. macd_line, signal_line, histogram = TechnicalIndicators.macd(price_list)
  225. return {
  226. 'rsi': rsi,
  227. 'bb_upper': bb_upper,
  228. 'bb_lower': bb_lower,
  229. 'bb_middle': bb_middle,
  230. 'ma5': ma5,
  231. 'ma10': ma10,
  232. 'ma20': ma20,
  233. 'macd': macd_line,
  234. 'macd_signal': signal_line,
  235. 'price': price_list[-1]
  236. }
  237. def check_long_signal(self, signals):
  238. """检查做多信号"""
  239. if signals is None:
  240. return False, "指标不足"
  241. conditions = []
  242. if signals['rsi'] is not None and signals['rsi'] < 65:
  243. conditions.append('RSI<65')
  244. if (signals['ma5'] is not None and signals['ma10'] is not None and
  245. signals['ma5'] > signals['ma10']):
  246. conditions.append('MA5>MA10')
  247. if (signals['macd'] is not None and signals['macd_signal'] is not None and
  248. signals['macd'] > signals['macd_signal']):
  249. conditions.append('MACD金叉')
  250. if (signals['bb_middle'] is not None and
  251. signals['price'] > signals['bb_middle']):
  252. conditions.append('价格>中轨')
  253. if len(conditions) >= 3:
  254. return True, '+'.join(conditions)
  255. return False, f"条件不足({len(conditions)}/3)"
  256. def check_exit_signal(self, signals, current_price):
  257. """检查平仓信号"""
  258. if signals is None or self.position == 0:
  259. return False, ""
  260. stop_loss = self.entry_price * 0.975
  261. if current_price <= stop_loss:
  262. return True, f"止损({current_price:.2f}<={stop_loss:.2f})"
  263. take_profit = self.entry_price * 1.04
  264. if current_price >= take_profit:
  265. return True, f"止盈({current_price:.2f}>={take_profit:.2f})"
  266. if self.holding_periods >= self.max_holding_periods:
  267. return True, f"时间平仓({self.holding_periods}周期)"
  268. if signals['rsi'] is not None and signals['rsi'] > 75:
  269. return True, f"RSI超买({signals['rsi']:.1f})"
  270. return False, ""
  271. def open_position(self, price, time_str, reason):
  272. """开仓"""
  273. position_value = self.capital * self.position_size
  274. self.position = position_value / price
  275. self.entry_price = price
  276. self.entry_time = time_str
  277. self.holding_periods = 0
  278. self.trades.append({
  279. 'action': 'OPEN',
  280. 'time': time_str,
  281. 'price': price,
  282. 'shares': self.position,
  283. 'value': position_value,
  284. 'reason': reason
  285. })
  286. def close_position(self, price, time_str, reason):
  287. """平仓"""
  288. if self.position == 0:
  289. return
  290. pnl = (price - self.entry_price) * self.position
  291. pnl_pct = (price / self.entry_price - 1) * 100
  292. self.capital += pnl
  293. self.trades.append({
  294. 'action': 'CLOSE',
  295. 'time': time_str,
  296. 'price': price,
  297. 'shares': self.position,
  298. 'pnl': pnl,
  299. 'pnl_pct': pnl_pct,
  300. 'reason': reason
  301. })
  302. self.position = 0
  303. self.entry_price = 0
  304. self.holding_periods = 0
  305. def update(self, timestamp, open_price, high, low, close,
  306. daily_manager, regime_manager):
  307. """更新回测状态"""
  308. self.prices.append(close)
  309. self.highs.append(high)
  310. self.lows.append(low)
  311. signals = self.calculate_signals()
  312. dt_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
  313. date_str = timestamp.strftime('%Y-%m-%d')
  314. # 多周期确认
  315. daily_ok, daily_info = daily_manager.can_trade_long(
  316. date_str, self.require_daily_uptrend)
  317. regime_ok, regime_info = regime_manager.can_open_long(
  318. dt_str, self.min_trend_prob)
  319. equity = self.capital
  320. if self.position > 0:
  321. equity += self.position * close
  322. self.equity_curve.append({
  323. 'time': dt_str,
  324. 'equity': equity,
  325. 'close': close,
  326. 'position': 1 if self.position > 0 else 0
  327. })
  328. if self.position > 0:
  329. self.holding_periods += 1
  330. should_exit, exit_reason = self.check_exit_signal(signals, close)
  331. if should_exit:
  332. self.close_position(close, dt_str, exit_reason)
  333. else:
  334. tech_signal, tech_reason = self.check_long_signal(signals)
  335. if tech_signal:
  336. block_reason = []
  337. if not daily_ok:
  338. block_reason.append(f"日线趋势向下(强度:{daily_info['trend_strength']:.2f}%)")
  339. self.block_reasons['daily'] += 1
  340. if not regime_ok:
  341. block_reason.append(f"30分钟非趋势状态(state={regime_info['state']},概率={regime_info['prob_trend']:.2f})")
  342. self.block_reasons['regime'] += 1
  343. if daily_ok and regime_ok:
  344. self.open_position(close, dt_str,
  345. f"{tech_reason}|日线向上|30分钟趋势(prob={regime_info['prob_trend']:.2f})")
  346. else:
  347. self.signals.append({
  348. 'time': dt_str,
  349. 'price': close,
  350. 'tech_reason': tech_reason,
  351. 'block_reason': '|'.join(block_reason)
  352. })
  353. return equity
  354. # ==================== 主程序 ====================
  355. def load_data(filepath):
  356. """加载30分钟数据"""
  357. print(f"加载30分钟数据: {filepath}")
  358. data = []
  359. with open(filepath, 'r', encoding='utf-8-sig') as f:
  360. reader = csv.DictReader(f)
  361. for row in reader:
  362. try:
  363. dt = datetime.strptime(row['DateTime'], '%Y-%m-%d %H:%M:%S')
  364. data.append({
  365. 'datetime': dt,
  366. 'open': float(row['Open']),
  367. 'high': float(row['High']),
  368. 'low': float(row['Low']),
  369. 'close': float(row['Close']),
  370. 'volume': float(row['Volume'])
  371. })
  372. except:
  373. continue
  374. print(f"[OK] 加载成功: {len(data)}条")
  375. return data
  376. def run_single_backtest(data, daily_manager, regime_manager, params,
  377. output_dir='backtest_results'):
  378. """运行单次回测"""
  379. os.makedirs(output_dir, exist_ok=True)
  380. engine = BacktestEngine(
  381. initial_capital=1000000,
  382. position_size=0.5,
  383. min_trend_prob=params['min_trend_prob'],
  384. require_daily_uptrend=params['require_daily_uptrend']
  385. )
  386. for row in data:
  387. engine.update(
  388. row['datetime'],
  389. row['open'],
  390. row['high'],
  391. row['low'],
  392. row['close'],
  393. daily_manager,
  394. regime_manager
  395. )
  396. # 统计结果
  397. initial = engine.initial_capital
  398. final = engine.equity_curve[-1]['equity'] if engine.equity_curve else initial
  399. total_return = (final / initial - 1) * 100
  400. closed_trades = [t for t in engine.trades if t['action'] == 'CLOSE']
  401. win_count = len([t for t in closed_trades if t['pnl'] > 0])
  402. loss_count = len([t for t in closed_trades if t['pnl'] <= 0])
  403. win_rate = win_count / len(closed_trades) * 100 if closed_trades else 0
  404. total_profit = sum(t['pnl'] for t in closed_trades if t['pnl'] > 0)
  405. total_loss = sum(t['pnl'] for t in closed_trades if t['pnl'] <= 0)
  406. profit_factor = abs(total_profit / total_loss) if total_loss != 0 else 0
  407. result = {
  408. 'params': params,
  409. 'total_return': total_return,
  410. 'trade_count': len(closed_trades),
  411. 'win_count': win_count,
  412. 'loss_count': loss_count,
  413. 'win_rate': win_rate,
  414. 'profit_factor': profit_factor,
  415. 'blocked_daily': engine.block_reasons['daily'],
  416. 'blocked_regime': engine.block_reasons['regime'],
  417. 'final_capital': final
  418. }
  419. return result, engine
  420. def run_parameter_scan(data_file, daily_file, regime_file, output_dir='optimization_results'):
  421. """参数扫描优化"""
  422. os.makedirs(output_dir, exist_ok=True)
  423. # 加载数据
  424. data = load_data(data_file)
  425. daily_manager = DailyTrendManager(daily_file)
  426. regime_manager = MarketRegimeManager(regime_file)
  427. # 参数网格
  428. param_grid = [
  429. {'min_trend_prob': 0.3, 'require_daily_uptrend': True},
  430. {'min_trend_prob': 0.4, 'require_daily_uptrend': True},
  431. {'min_trend_prob': 0.5, 'require_daily_uptrend': True},
  432. {'min_trend_prob': 0.6, 'require_daily_uptrend': True},
  433. {'min_trend_prob': 0.7, 'require_daily_uptrend': True},
  434. {'min_trend_prob': 0.5, 'require_daily_uptrend': False}, # 允许横盘
  435. ]
  436. print("\n" + "="*70)
  437. print("参数优化扫描")
  438. print("="*70)
  439. all_results = []
  440. for i, params in enumerate(param_grid):
  441. print(f"\n[{i+1}/{len(param_grid)}] 测试参数: {params}")
  442. result, engine = run_single_backtest(
  443. data, daily_manager, regime_manager, params, output_dir)
  444. all_results.append(result)
  445. print(f" 收益率: {result['total_return']:+.2f}%")
  446. print(f" 交易次数: {result['trade_count']}")
  447. print(f" 胜率: {result['win_rate']:.1f}%")
  448. print(f" 盈亏比: {result['profit_factor']:.2f}")
  449. # 排序结果
  450. all_results.sort(key=lambda x: x['total_return'], reverse=True)
  451. print("\n" + "="*70)
  452. print("参数优化结果排名")
  453. print("="*70)
  454. for i, r in enumerate(all_results[:5]):
  455. print(f"\n第{i+1}名:")
  456. print(f" 参数: 趋势概率阈值={r['params']['min_trend_prob']}, "
  457. f"要求日线向上={r['params']['require_daily_uptrend']}")
  458. print(f" 收益率: {r['total_return']:+.2f}%")
  459. print(f" 交易次数: {r['trade_count']}")
  460. print(f" 胜率: {r['win_rate']:.1f}%")
  461. print(f" 盈亏比: {r['profit_factor']:.2f}")
  462. print(f" 被日线过滤: {r['blocked_daily']}次")
  463. print(f" 被30分钟过滤: {r['blocked_regime']}次")
  464. # 保存优化结果
  465. result_file = f"{output_dir}/parameter_optimization_results.json"
  466. with open(result_file, 'w', encoding='utf-8') as f:
  467. json.dump(all_results, f, indent=2, ensure_ascii=False)
  468. print(f"\n优化结果已保存: {result_file}")
  469. return all_results
  470. if __name__ == '__main__':
  471. DATA_FILE = 'cyb50_30min_2023_to_20260325.csv'
  472. DAILY_FILE = '../data-fetch/data/399673_SZ_day_20150101_20260325.csv'
  473. REGIME_FILE = '../../market-regime-identifier-30/cyb50_30min_regime_result.csv'
  474. results = run_parameter_scan(DATA_FILE, DAILY_FILE, REGIME_FILE)