#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
backtest_final_optimal.py

CYB50 full backtest with the optimal parameters, plus a data consistency check.
Parameters: min_trend_prob=0.3, require_daily_uptrend=True
"""
import csv
import json
import math
import os
from collections import deque
from datetime import datetime, timedelta
  13. class TechnicalIndicators:
  14. @staticmethod
  15. def sma(data, period):
  16. if len(data) < period:
  17. return None
  18. return sum(data[-period:]) / period
  19. @staticmethod
  20. def rsi(prices, period=14):
  21. if len(prices) < period + 1:
  22. return None
  23. gains, losses = [], []
  24. for i in range(1, len(prices)):
  25. change = prices[i] - prices[i-1]
  26. gains.append(change if change > 0 else 0)
  27. losses.append(abs(change) if change < 0 else 0)
  28. avg_gain = sum(gains[-period:]) / period
  29. avg_loss = sum(losses[-period:]) / period
  30. if avg_loss == 0:
  31. return 100
  32. return 100 - (100 / (1 + avg_gain / avg_loss))
  33. @staticmethod
  34. def bollinger_bands(prices, period=20, std_dev=2):
  35. if len(prices) < period:
  36. return None, None, None
  37. middle = sum(prices[-period:]) / period
  38. variance = sum((p - middle) ** 2 for p in prices[-period:]) / period
  39. std = math.sqrt(variance)
  40. return middle + std*std_dev, middle, middle - std*std_dev
  41. @staticmethod
  42. def macd(prices, fast=12, slow=26, signal=9):
  43. if len(prices) < slow:
  44. return None, None, None
  45. def calc_ema(data, period):
  46. mult = 2 / (period + 1)
  47. ema = data[0]
  48. for p in data[1:]:
  49. ema = (p - ema) * mult + ema
  50. return ema
  51. macd_vals = []
  52. for i in range(slow, len(prices)+1):
  53. f = calc_ema(prices[i-fast:i], fast)
  54. s = calc_ema(prices[i-slow:i], slow)
  55. macd_vals.append(f - s)
  56. sig = calc_ema(macd_vals[-signal:], signal) if len(macd_vals) >= signal else None
  57. return macd_vals[-1], sig, macd_vals[-1] - sig if sig else None
  58. class DailyTrendManager:
  59. def __init__(self, daily_file):
  60. self.daily_data = {}
  61. self.daily_trend = {}
  62. self.load_daily_data(daily_file)
  63. self.calculate_daily_trend()
  64. def load_daily_data(self, filepath):
  65. with open(filepath, 'r', encoding='utf-8-sig') as f:
  66. reader = csv.DictReader(f)
  67. for row in reader:
  68. try:
  69. dt = datetime.strptime(row['datetime'], '%Y-%m-%d %H:%M:%S')
  70. self.daily_data[dt.strftime('%Y-%m-%d')] = {
  71. 'open': float(row['open']), 'high': float(row['high']),
  72. 'low': float(row['low']), 'close': float(row['close'])
  73. }
  74. except:
  75. continue
  76. def calculate_daily_trend(self, ma_period=20):
  77. dates = sorted(self.daily_data.keys())
  78. closes = [self.daily_data[d]['close'] for d in dates]
  79. for i, date in enumerate(dates):
  80. if i < ma_period - 1:
  81. self.daily_trend[date] = {'trend': 0, 'ma20': None, 'trend_strength': 0}
  82. continue
  83. ma20 = sum(closes[i-ma_period+1:i+1]) / ma_period
  84. close = closes[i]
  85. trend = 1 if close > ma20 * 1.02 else (-1 if close < ma20 * 0.98 else 0)
  86. self.daily_trend[date] = {
  87. 'trend': trend, 'ma20': ma20,
  88. 'trend_strength': (close - ma20) / ma20 * 100
  89. }
  90. def get_daily_trend(self, date_str):
  91. return self.daily_trend.get(date_str, {'trend': 0, 'ma20': None, 'trend_strength': 0})
  92. class MarketRegimeManager:
  93. def __init__(self, regime_file):
  94. self.regime_data = {}
  95. self.load_regime_data(regime_file)
  96. def load_regime_data(self, filepath):
  97. with open(filepath, 'r', encoding='utf-8-sig') as f:
  98. reader = csv.DictReader(f)
  99. for row in reader:
  100. self.regime_data[row['datetime']] = {
  101. 'state': int(row['state']),
  102. 'prob_trend': float(row['prob_trend'])
  103. }
  104. def get_regime(self, dt_str):
  105. return self.regime_data.get(dt_str, {'state': 0, 'prob_trend': 0.0})
  106. class BacktestEngine:
  107. def __init__(self):
  108. self.initial_capital = 1000000
  109. self.position_size = 0.5
  110. self.capital = self.initial_capital
  111. self.position = 0
  112. self.entry_price = 0
  113. self.holding_periods = 0
  114. self.max_holding_periods = 16
  115. self.equity_curve = []
  116. self.trades = []
  117. self.prices = deque(maxlen=100)
  118. def calculate_signals(self):
  119. if len(self.prices) < 50:
  120. return None
  121. pl = list(self.prices)
  122. return {
  123. 'rsi': TechnicalIndicators.rsi(pl),
  124. 'bb_middle': TechnicalIndicators.bollinger_bands(pl)[1],
  125. 'ma5': TechnicalIndicators.sma(pl, 5),
  126. 'ma10': TechnicalIndicators.sma(pl, 10),
  127. 'macd': TechnicalIndicators.macd(pl)[0],
  128. 'macd_signal': TechnicalIndicators.macd(pl)[1],
  129. 'price': pl[-1]
  130. }
  131. def check_long_signal(self, s):
  132. if not s:
  133. return False, ""
  134. c = []
  135. if s['rsi'] and s['rsi'] < 65: c.append('RSI<65')
  136. if s['ma5'] and s['ma10'] and s['ma5'] > s['ma10']: c.append('MA5>MA10')
  137. if s['macd'] and s['macd_signal'] and s['macd'] > s['macd_signal']: c.append('MACD金叉')
  138. if s['bb_middle'] and s['price'] > s['bb_middle']: c.append('价格>中轨')
  139. return (True, '+'.join(c)) if len(c) >= 3 else (False, f"{len(c)}/3")
  140. def check_exit(self, s, price):
  141. if not s or self.position == 0:
  142. return False, ""
  143. if price <= self.entry_price * 0.975: return True, f"止损({price:.2f})"
  144. if price >= self.entry_price * 1.04: return True, f"止盈({price:.2f})"
  145. if self.holding_periods >= self.max_holding_periods: return True, "时间平仓"
  146. if s['rsi'] and s['rsi'] > 75: return True, f"RSI超买({s['rsi']:.1f})"
  147. return False, ""
  148. def open(self, price, time_str, reason):
  149. val = self.capital * self.position_size
  150. self.position = val / price
  151. self.entry_price = price
  152. self.holding_periods = 0
  153. self.trades.append({'action': 'OPEN', 'time': time_str, 'price': price,
  154. 'shares': self.position, 'value': val, 'reason': reason})
  155. def close(self, price, time_str, reason):
  156. if self.position == 0: return
  157. pnl = (price - self.entry_price) * self.position
  158. pnl_pct = (price / self.entry_price - 1) * 100
  159. self.capital += pnl
  160. self.trades.append({'action': 'CLOSE', 'time': time_str, 'price': price,
  161. 'shares': self.position, 'pnl': pnl, 'pnl_pct': pnl_pct, 'reason': reason})
  162. self.position = 0
  163. def update(self, ts, o, h, l, c, dm, rm):
  164. self.prices.append(c)
  165. dt_str = ts.strftime('%Y-%m-%d %H:%M:%S')
  166. date_str = ts.strftime('%Y-%m-%d')
  167. daily = dm.get_daily_trend(date_str)
  168. regime = rm.get_regime(dt_str)
  169. equity = self.capital + (self.position * c if self.position > 0 else 0)
  170. self.equity_curve.append({'time': dt_str, 'equity': equity, 'close': c, 'position': 1 if self.position else 0,
  171. 'daily_trend': daily['trend'], 'daily_strength': daily['trend_strength'],
  172. 'regime_state': regime['state'], 'regime_prob': regime['prob_trend']})
  173. if self.position > 0:
  174. self.holding_periods += 1
  175. s = self.calculate_signals()
  176. ex, reason = self.check_exit(s, c)
  177. if ex: self.close(c, dt_str, reason)
  178. else:
  179. s = self.calculate_signals()
  180. ok, tech_reason = self.check_long_signal(s)
  181. if ok and daily['trend'] == 1 and regime['state'] == 1 and regime['prob_trend'] >= 0.3:
  182. self.open(c, dt_str, f"{tech_reason}|日线向上|30分钟趋势{regime['prob_trend']:.2f}")
  183. return equity
  184. def load_data(fp):
  185. data = []
  186. with open(fp, 'r', encoding='utf-8-sig') as f:
  187. for row in csv.DictReader(f):
  188. try:
  189. data.append({
  190. 'datetime': datetime.strptime(row['DateTime'], '%Y-%m-%d %H:%M:%S'),
  191. 'open': float(row['Open']), 'high': float(row['High']),
  192. 'low': float(row['Low']), 'close': float(row['Close'])
  193. })
  194. except:
  195. continue
  196. return data
  197. def verify_data_integrity(data, dm, rm):
  198. """核对数据完整性"""
  199. print("\n" + "="*70)
  200. print("数据准确性核对报告")
  201. print("="*70)
  202. issues = []
  203. checked = 0
  204. for row in data:
  205. dt_str = row['datetime'].strftime('%Y-%m-%d %H:%M:%S')
  206. date_str = row['datetime'].strftime('%Y-%m-%d')
  207. # 检查日线数据
  208. if date_str not in dm.daily_data:
  209. issues.append(f"缺少日线数据: {date_str}")
  210. # 检查30分钟状态
  211. if dt_str not in rm.regime_data:
  212. issues.append(f"缺少30分钟状态: {dt_str}")
  213. checked += 1
  214. if checked % 1000 == 0:
  215. print(f" 已核对 {checked}/{len(data)} 条数据...")
  216. print(f"\n数据核对完成:")
  217. print(f" 总数据条数: {len(data)}")
  218. print(f" 日线数据: {len(dm.daily_data)}条")
  219. print(f" 30分钟状态: {len(rm.regime_data)}条")
  220. print(f" 发现问题: {len(issues)}个")
  221. if issues:
  222. print(f"\n前10个问题:")
  223. for i in issues[:10]:
  224. print(f" - {i}")
  225. return len(issues) == 0
  226. def run_backtest(data_file, daily_file, regime_file, output_dir='final_backtest'):
  227. os.makedirs(output_dir, exist_ok=True)
  228. print("加载数据...")
  229. data = load_data(data_file)
  230. dm = DailyTrendManager(daily_file)
  231. rm = MarketRegimeManager(regime_file)
  232. # 核对数据
  233. data_ok = verify_data_integrity(data, dm, rm)
  234. if not data_ok:
  235. print("\n[警告] 数据存在问题,但继续回测...")
  236. print("\n运行最优参数回测...")
  237. engine = BacktestEngine()
  238. for row in data:
  239. engine.update(row['datetime'], row['open'], row['high'], row['low'], row['close'], dm, rm)
  240. # 统计
  241. initial = engine.initial_capital
  242. final = engine.equity_curve[-1]['equity']
  243. total_ret = (final / initial - 1) * 100
  244. closed = [t for t in engine.trades if t['action'] == 'CLOSE']
  245. wins = [t for t in closed if t['pnl'] > 0]
  246. losses = [t for t in closed if t['pnl'] <= 0]
  247. win_rate = len(wins) / len(closed) * 100 if closed else 0
  248. total_profit = sum(t['pnl'] for t in wins) if wins else 0
  249. total_loss = sum(t['pnl'] for t in losses) if losses else 0
  250. profit_factor = abs(total_profit / total_loss) if total_loss else 0
  251. # 计算最大回撤
  252. peak = initial
  253. max_dd = 0
  254. for e in engine.equity_curve:
  255. if e['equity'] > peak:
  256. peak = e['equity']
  257. dd = (peak - e['equity']) / peak * 100
  258. if dd > max_dd:
  259. max_dd = dd
  260. # 保存权益曲线
  261. with open(f"{output_dir}/equity_final.csv", 'w', newline='') as f:
  262. w = csv.DictWriter(f, fieldnames=['time', 'equity', 'close', 'position', 'daily_trend', 'daily_strength', 'regime_state', 'regime_prob'])
  263. w.writeheader()
  264. w.writerows(engine.equity_curve)
  265. # 保存交易记录
  266. with open(f"{output_dir}/trades_final.csv", 'w', newline='') as f:
  267. if engine.trades:
  268. # 获取所有可能的字段
  269. all_fields = set()
  270. for t in engine.trades:
  271. all_fields.update(t.keys())
  272. fieldnames = sorted(all_fields)
  273. w = csv.DictWriter(f, fieldnames=fieldnames)
  274. w.writeheader()
  275. w.writerows(engine.trades)
  276. # 生成详细报告
  277. report = f"""
  278. ================================================================================
  279. CYB50 最优参数回测报告 - 详细版
  280. ================================================================================
  281. 回测参数:
  282. - 初始资金: 1,000,000 元
  283. - 持仓上限: 50%
  284. - 30分钟趋势概率阈值: 0.3 (最优)
  285. - 日线要求: 必须向上 (MA20之上)
  286. - 止损: -2.5% | 止盈: +4% | 最大持仓: 16周期(8小时)
  287. ================================================================================
  288. 整体表现
  289. ================================================================================
  290. 初始资金: {initial:>15,.2f} 元
  291. 最终资金: {final:>15,.2f} 元
  292. 净盈亏: {final-initial:>15,.2f} 元
  293. 总收益率: {total_ret:>15.2f} %
  294. 最大回撤: {max_dd:>15.2f} %
  295. ================================================================================
  296. 交易统计
  297. ================================================================================
  298. 总交易次数: {len(closed):>15} 笔
  299. 盈利次数: {len(wins):>15} 笔
  300. 亏损次数: {len(losses):>15} 笔
  301. 胜率: {win_rate:>15.2f} %
  302. 盈亏比: {profit_factor:>15.2f}
  303. 总盈利: {total_profit:>15,.2f} 元
  304. 总亏损: {total_loss:>15,.2f} 元
  305. 平均每笔盈利: {total_profit/len(wins) if wins else 0:>15,.2f} 元
  306. 平均每笔亏损: {total_loss/len(losses) if losses else 0:>15,.2f} 元
  307. ================================================================================
  308. 最近10笔交易明细
  309. ================================================================================
  310. """
  311. for t in closed[-10:]:
  312. report += f" {t['time']} | 平仓价: {t['price']:.2f} | 盈亏: {t['pnl']:>+10,.2f} ({t['pnl_pct']:+.2f}%) | {t['reason']}\n"
  313. report += f"""
  314. ================================================================================
  315. 数据核对结果
  316. ================================================================================
  317. 30分钟数据条数: {len(data)} 条
  318. 日线数据条数: {len(dm.daily_data)} 条
  319. 30分钟状态条数: {len(rm.regime_data)} 条
  320. 数据完整性: {'通过 ✓' if data_ok else '存在问题 ✗'}
  321. ================================================================================
  322. 文件输出
  323. ================================================================================
  324. - {output_dir}/equity_final.csv (权益曲线)
  325. - {output_dir}/trades_final.csv (交易明细)
  326. - {output_dir}/report_final.txt (本报告)
  327. ================================================================================
  328. """
  329. with open(f"{output_dir}/report_final.txt", 'w') as f:
  330. f.write(report)
  331. print(report)
  332. print(f"\n所有文件已保存到: {output_dir}/")
  333. return engine
  334. if __name__ == '__main__':
  335. run_backtest(
  336. 'cyb50_30min_2023_to_20260325.csv',
  337. '../data-fetch/data/399673_SZ_day_20150101_20260325.csv',
  338. '../../market-regime-identifier-30/cyb50_30min_regime_result.csv'
  339. )