backtest_t1_standalone.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. CYB50 只做多T+1回测系统
  5. 不依赖pandas,使用纯Python实现
  6. """
  7. import csv
  8. import json
  9. from datetime import datetime, timedelta
  10. from collections import deque
  11. import math
  12. # ==================== 技术指标计算类 ====================
  13. class TechnicalIndicators:
  14. """技术指标计算 - 纯Python实现"""
  15. @staticmethod
  16. def sma(data, period):
  17. """简单移动平均线"""
  18. if len(data) < period:
  19. return None
  20. return sum(data[-period:]) / period
  21. @staticmethod
  22. def ema(data, period):
  23. """指数移动平均线"""
  24. if len(data) < period:
  25. return None
  26. multiplier = 2 / (period + 1)
  27. ema = data[0]
  28. for price in data[1:]:
  29. ema = (price - ema) * multiplier + ema
  30. return ema
  31. @staticmethod
  32. def rsi(prices, period=14):
  33. """RSI计算"""
  34. if len(prices) < period + 1:
  35. return None
  36. gains = []
  37. losses = []
  38. for i in range(1, len(prices)):
  39. change = prices[i] - prices[i-1]
  40. if change > 0:
  41. gains.append(change)
  42. losses.append(0)
  43. else:
  44. gains.append(0)
  45. losses.append(abs(change))
  46. if len(gains) < period:
  47. return None
  48. avg_gain = sum(gains[-period:]) / period
  49. avg_loss = sum(losses[-period:]) / period
  50. if avg_loss == 0:
  51. return 100
  52. rs = avg_gain / avg_loss
  53. return 100 - (100 / (1 + rs))
  54. @staticmethod
  55. def bollinger_bands(prices, period=20, std_dev=2):
  56. """布林带计算"""
  57. if len(prices) < period:
  58. return None, None, None
  59. middle = sum(prices[-period:]) / period
  60. variance = sum((p - middle) ** 2 for p in prices[-period:]) / period
  61. std = math.sqrt(variance)
  62. upper = middle + (std * std_dev)
  63. lower = middle - (std * std_dev)
  64. return upper, middle, lower
  65. @staticmethod
  66. def macd(prices, fast=12, slow=26, signal=9):
  67. """MACD计算"""
  68. if len(prices) < slow:
  69. return None, None, None
  70. # 计算EMA
  71. def calc_ema(data, period):
  72. multiplier = 2 / (period + 1)
  73. ema = data[0]
  74. for price in data[1:]:
  75. ema = (price - ema) * multiplier + ema
  76. return ema
  77. ema_fast = calc_ema(prices[-fast:], fast) if len(prices) >= fast else None
  78. ema_slow = calc_ema(prices[-slow:], slow) if len(prices) >= slow else None
  79. if ema_fast is None or ema_slow is None:
  80. return None, None, None
  81. macd_line = ema_fast - ema_slow
  82. # 简化:使用当前MACD作为信号线近似
  83. signal_line = macd_line * 0.8 # 近似值
  84. histogram = macd_line - signal_line
  85. return macd_line, signal_line, histogram
  86. @staticmethod
  87. def kdj(highs, lows, closes, period=9):
  88. """KDJ计算"""
  89. if len(closes) < period:
  90. return None, None, None
  91. low_n = min(lows[-period:])
  92. high_n = max(highs[-period:])
  93. close = closes[-1]
  94. if high_n == low_n:
  95. rsv = 50
  96. else:
  97. rsv = (close - low_n) / (high_n - low_n) * 100
  98. # 简化KDJ计算
  99. k = rsv
  100. d = k
  101. j = 3 * k - 2 * d
  102. return k, d, j
  103. @staticmethod
  104. def atr(highs, lows, closes, period=14):
  105. """ATR计算"""
  106. if len(closes) < period + 1:
  107. return None
  108. tr_values = []
  109. for i in range(1, len(closes)):
  110. high_low = highs[i] - lows[i]
  111. high_close = abs(highs[i] - closes[i-1])
  112. low_close = abs(lows[i] - closes[i-1])
  113. tr = max(high_low, high_close, low_close)
  114. tr_values.append(tr)
  115. if len(tr_values) < period:
  116. return None
  117. return sum(tr_values[-period:]) / period
  118. # ==================== 数据加载类 ====================
  119. class DataLoader:
  120. """CSV数据加载器"""
  121. def __init__(self, file_path):
  122. self.file_path = file_path
  123. self.data = []
  124. def load(self):
  125. """加载CSV数据"""
  126. print(f"正在加载数据文件: {self.file_path}")
  127. with open(self.file_path, 'r', encoding='utf-8-sig') as f:
  128. reader = csv.DictReader(f)
  129. for row in reader:
  130. # 解析时间
  131. dt_str = row['DateTime']
  132. dt = datetime.strptime(dt_str, '%Y-%m-%d %H:%M:%S')
  133. self.data.append({
  134. 'datetime': dt,
  135. 'date': dt.date(),
  136. 'time': dt.time(),
  137. 'open': float(row['Open']),
  138. 'high': float(row['High']),
  139. 'low': float(row['Low']),
  140. 'close': float(row['Close']),
  141. 'volume': float(row['Volume']),
  142. 'a': float(row['a']) if row['a'] else 0,
  143. 'pc': float(row['pc']) if row['pc'] else 0,
  144. 'sf': float(row['sf']) if row['sf'] else 0
  145. })
  146. print(f"✅ 数据加载完成: {len(self.data)}条K线")
  147. print(f" 数据区间: {self.data[0]['datetime']} ~ {self.data[-1]['datetime']}")
  148. return self.data
  149. # ==================== 信号生成器 ====================
  150. class SignalGenerator:
  151. """只做多信号生成器"""
  152. def __init__(self):
  153. self.prices = []
  154. self.highs = []
  155. self.lows = []
  156. self.volumes = []
  157. self.macd_histograms = []
  158. def update(self, bar):
  159. """更新数据"""
  160. self.prices.append(bar['close'])
  161. self.highs.append(bar['high'])
  162. self.lows.append(bar['low'])
  163. self.volumes.append(bar['volume'])
  164. def calculate_indicators(self):
  165. """计算所有技术指标"""
  166. if len(self.prices) < 26:
  167. return None
  168. ti = TechnicalIndicators()
  169. # 移动平均线
  170. ma6 = ti.sma(self.prices, 6)
  171. ma12 = ti.sma(self.prices, 12)
  172. ma24 = ti.sma(self.prices, 24)
  173. # RSI
  174. rsi = ti.rsi(self.prices, 14)
  175. # 布林带
  176. bb_upper, bb_middle, bb_lower = ti.bollinger_bands(self.prices, 20)
  177. # MACD
  178. macd_line, macd_signal, macd_hist = ti.macd(self.prices, 12, 26, 9)
  179. if macd_hist is not None:
  180. self.macd_histograms.append(macd_hist)
  181. # KDJ
  182. k, d, j = ti.kdj(self.highs, self.lows, self.prices, 9)
  183. # ATR
  184. atr = ti.atr(self.highs, self.lows, self.prices, 14)
  185. atr_pct = atr / self.prices[-1] if atr else None
  186. # 成交量比率
  187. volume_ma = ti.sma(self.volumes, 12)
  188. volume_ratio = self.volumes[-1] / volume_ma if volume_ma else 1.0
  189. # 价格动量
  190. price_momentum = (self.prices[-1] - self.prices[-6]) / self.prices[-6] if len(self.prices) >= 6 else 0
  191. # 涨跌幅
  192. returns = (self.prices[-1] - self.prices[-2]) / self.prices[-2] if len(self.prices) >= 2 else 0
  193. close_open_pct = (self.prices[-1] - self.highs[-1]) / self.highs[-1] # 简化计算
  194. return {
  195. 'ma6': ma6,
  196. 'ma12': ma12,
  197. 'ma24': ma24,
  198. 'rsi': rsi,
  199. 'bb_upper': bb_upper,
  200. 'bb_middle': bb_middle,
  201. 'bb_lower': bb_lower,
  202. 'macd': macd_line,
  203. 'macd_signal': macd_signal,
  204. 'macd_hist': macd_hist,
  205. 'k': k,
  206. 'd': d,
  207. 'j': j,
  208. 'atr_pct': atr_pct,
  209. 'volume_ratio': volume_ratio,
  210. 'price_momentum': price_momentum,
  211. 'returns': returns,
  212. 'close_open_pct': close_open_pct
  213. }
  214. def generate_long_signal(self, indicators, bar_idx):
  215. """生成做多信号"""
  216. if indicators is None:
  217. return 0, []
  218. score = 0
  219. signals = []
  220. # 1. RSI超卖
  221. if indicators['rsi'] < 30:
  222. score += 2
  223. signals.append("RSI超卖")
  224. elif indicators['rsi'] < 35:
  225. score += 1
  226. signals.append("RSI偏弱")
  227. # 2. KDJ超卖
  228. if indicators['k'] < 20 and indicators['d'] < 20:
  229. score += 2
  230. signals.append("KDJ超卖")
  231. elif indicators['j'] < 0:
  232. score += 1
  233. signals.append("KDJ极端超卖")
  234. # 3. MACD金叉
  235. if len(self.macd_histograms) >= 2:
  236. if indicators['macd_hist'] > 0 and self.macd_histograms[-2] <= 0:
  237. score += 2
  238. signals.append("MACD金叉")
  239. elif indicators['macd_hist'] > self.macd_histograms[-2]:
  240. score += 1
  241. signals.append("MACD改善")
  242. # 4. 价格触及布林带下轨
  243. current_price = self.prices[-1]
  244. if indicators['bb_lower'] and current_price <= indicators['bb_lower'] * 1.005:
  245. score += 2
  246. signals.append("触及下轨")
  247. elif indicators['bb_lower'] and current_price <= indicators['bb_lower'] * 1.01:
  248. score += 1
  249. signals.append("接近下轨")
  250. # 5. 连续下跌后的反转
  251. if len(self.prices) >= 7:
  252. recent_returns = [(self.prices[i] - self.prices[i-1]) / self.prices[i-1]
  253. for i in range(len(self.prices)-6, len(self.prices))]
  254. if min(recent_returns) < -0.015:
  255. consecutive_decline = sum(1 for r in recent_returns if r < 0)
  256. if consecutive_decline >= 4:
  257. score += 2
  258. signals.append("连续下跌反转")
  259. # 6. 价格动量反转
  260. if indicators['price_momentum'] < -0.02:
  261. score += 1
  262. signals.append("动量超卖")
  263. # 7. 成交量配合
  264. if indicators['volume_ratio'] > 1.2:
  265. score += 1
  266. signals.append("放量配合")
  267. # 8. MA趋势过滤
  268. if indicators['ma6'] and indicators['ma12'] and indicators['ma24']:
  269. if indicators['ma6'] < indicators['ma12'] < indicators['ma24']:
  270. score -= 1
  271. signals.append("MA下降趋势惩罚")
  272. elif indicators['ma6'] > indicators['ma12']:
  273. score += 1
  274. signals.append("MA短期上行")
  275. return score, signals
  276. # ==================== T+1交易执行器 ====================
  277. class T1BacktestExecutor:
  278. """T+1回测执行器"""
  279. def __init__(self, initial_capital=1000000):
  280. self.initial_capital = initial_capital
  281. self.capital = initial_capital
  282. self.position = 0
  283. self.entry_price = 0
  284. self.entry_time = None
  285. self.entry_date = None
  286. self.entry_signals = []
  287. self.holding_bars = 0
  288. # 参数
  289. self.commission_rate = 0.0001 # 万分之一
  290. self.stop_loss_pct = 0.008 # 0.8%止损
  291. self.take_profit_pct = 0.02 # 2%止盈
  292. self.max_hold_bars = 16 # 最大持仓8小时
  293. # 交易记录
  294. self.trades = []
  295. self.equity_curve = []
  296. # 待平仓队列 (T+1规则:当天买入的次日才能卖出)
  297. self.pending_positions = [] # 存储不能当天卖出的持仓信息
  298. def can_trade(self, current_date):
  299. """检查是否可以交易(T+1限制)"""
  300. # 检查是否有前一天买入的持仓可以卖出
  301. available_to_sell = []
  302. still_pending = []
  303. for pos in self.pending_positions:
  304. if pos['entry_date'] < current_date:
  305. # 可以卖出了
  306. available_to_sell.append(pos)
  307. else:
  308. # 还不能卖出
  309. still_pending.append(pos)
  310. self.pending_positions = still_pending
  311. return available_to_sell
  312. def check_exit(self, bar, position_info):
  313. """检查是否需要平仓"""
  314. price = bar['close']
  315. entry_price = position_info['entry_price']
  316. holding_bars = position_info['holding_bars']
  317. stop_loss = entry_price * (1 - self.stop_loss_pct)
  318. take_profit = entry_price * (1 + self.take_profit_pct)
  319. # 止损
  320. if price <= stop_loss:
  321. return True, f"止损({price:.2f}<={stop_loss:.2f})", price
  322. # 止盈
  323. if price >= take_profit:
  324. return True, f"止盈({price:.2f}>={take_profit:.2f})", price
  325. # 最大持仓时间
  326. if holding_bars >= self.max_hold_bars:
  327. return True, f"时间平仓({holding_bars}周期)", price
  328. return False, "", price
  329. def execute_buy(self, bar, score, signals):
  330. """执行买入"""
  331. price = bar['close']
  332. date = bar['date']
  333. dt = bar['datetime']
  334. # 计算仓位(全仓)
  335. position_value = self.capital
  336. position_size = int(position_value / price)
  337. if position_size <= 0:
  338. return False
  339. cost = position_size * price * (1 + self.commission_rate)
  340. if cost > self.capital:
  341. position_size = int(self.capital / (price * (1 + self.commission_rate)))
  342. cost = position_size * price * (1 + self.commission_rate)
  343. self.capital -= cost
  344. # 记录持仓信息(T+1规则下,当天不能卖出)
  345. position_info = {
  346. 'entry_price': price,
  347. 'entry_time': dt,
  348. 'entry_date': date,
  349. 'position_size': position_size,
  350. 'holding_bars': 0,
  351. 'entry_signals': signals,
  352. 'score': score,
  353. 'stop_loss': price * (1 - self.stop_loss_pct),
  354. 'take_profit': price * (1 + self.take_profit_pct)
  355. }
  356. self.pending_positions.append(position_info)
  357. print(f"\n[开仓] {dt} 价格:{price:.2f} 数量:{position_size} 信号分数:{score}")
  358. print(f" 信号: {', '.join(signals)}")
  359. return True
  360. def execute_sell(self, bar, position_info, exit_reason):
  361. """执行卖出"""
  362. price = bar['close']
  363. dt = bar['datetime']
  364. entry_price = position_info['entry_price']
  365. position_size = position_info['position_size']
  366. entry_time = position_info['entry_time']
  367. holding_bars = position_info['holding_bars']
  368. # 计算盈亏
  369. gross_pnl = (price - entry_price) * position_size
  370. open_cost = position_size * entry_price * self.commission_rate
  371. close_revenue = position_size * price
  372. close_cost = close_revenue * self.commission_rate
  373. pnl = gross_pnl - open_cost - close_cost
  374. pnl_pct = (price - entry_price) / entry_price * 100
  375. # 更新资金
  376. self.capital += close_revenue - close_cost
  377. # 记录交易
  378. trade = {
  379. 'entry_time': entry_time.strftime('%Y-%m-%d %H:%M:%S'),
  380. 'exit_time': dt.strftime('%Y-%m-%d %H:%M:%S'),
  381. 'entry_price': round(entry_price, 2),
  382. 'exit_price': round(price, 2),
  383. 'position': position_size,
  384. 'pnl': round(pnl, 2),
  385. 'pnl_pct': round(pnl_pct, 2),
  386. 'exit_reason': exit_reason,
  387. 'holding_bars': holding_bars,
  388. 'holding_hours': round(holding_bars * 0.5, 1),
  389. 'entry_signals': '|'.join(position_info['entry_signals']),
  390. 'capital': round(self.capital, 2),
  391. 'position_value': round(position_size * entry_price, 2)
  392. }
  393. self.trades.append(trade)
  394. status = "盈利" if pnl > 0 else "亏损"
  395. print(f"[平仓] {dt} 价格:{price:.2f} 盈亏:{pnl:+.2f}({pnl_pct:+.2f}%) [{status}] 原因:{exit_reason}")
  396. return trade
  397. def update_equity(self, bar, active_position=None):
  398. """更新权益曲线"""
  399. price = bar['close']
  400. dt = bar['datetime']
  401. total_value = self.capital
  402. if active_position:
  403. total_value += active_position['position_size'] * price
  404. self.equity_curve.append({
  405. 'datetime': dt.strftime('%Y-%m-%d %H:%M:%S'),
  406. 'price': round(price, 2),
  407. 'capital': round(self.capital, 2),
  408. 'total_value': round(total_value, 2),
  409. 'return_pct': round((total_value / self.initial_capital - 1) * 100, 2)
  410. })
  411. def run_backtest(self, data):
  412. """运行回测"""
  413. print("\n" + "="*80)
  414. print("开始T+1回测")
  415. print("="*80)
  416. signal_gen = SignalGenerator()
  417. active_position = None # 当前活跃持仓(可以卖出的)
  418. for i, bar in enumerate(data):
  419. current_date = bar['date']
  420. # 更新信号生成器
  421. signal_gen.update(bar)
  422. # 检查T+1限制,获取可以卖出的持仓
  423. available_positions = self.can_trade(current_date)
  424. # 如果有可卖出的持仓,选择第一个作为活跃持仓
  425. if available_positions and active_position is None:
  426. active_position = available_positions[0]
  427. for pos in available_positions[1:]:
  428. self.pending_positions.append(pos)
  429. # 更新活跃持仓的持仓时间
  430. if active_position:
  431. active_position['holding_bars'] += 1
  432. # 检查是否需要平仓(只有活跃持仓可以平仓)
  433. if active_position:
  434. should_exit, exit_reason, exit_price = self.check_exit(bar, active_position)
  435. if should_exit:
  436. self.execute_sell(bar, active_position, exit_reason)
  437. active_position = None
  438. # 检查是否开新仓(无持仓时)
  439. if active_position is None and len(self.pending_positions) == 0 and i >= 26:
  440. indicators = signal_gen.calculate_indicators()
  441. score, signals = signal_gen.generate_long_signal(indicators, i)
  442. # 信号分数>=4且开仓
  443. if score >= 4:
  444. self.execute_buy(bar, score, signals)
  445. # 更新权益曲线
  446. self.update_equity(bar, active_position)
  447. # 回测结束,强制平仓所有持仓
  448. print("\n" + "="*80)
  449. print("回测结束,强制平仓")
  450. print("="*80)
  451. if active_position:
  452. self.execute_sell(data[-1], active_position, "回测结束")
  453. # 处理pending中的持仓(如果数据结束但还有持仓)
  454. for pos in self.pending_positions:
  455. pos['holding_bars'] = self.max_hold_bars # 强制达到平仓条件
  456. self.execute_sell(data[-1], pos, "回测结束(T+1)")
  457. return self.trades, self.equity_curve
  458. # ==================== 回测报告生成器 ====================
  459. class BacktestReport:
  460. """生成回测报告"""
  461. def __init__(self, trades, equity_curve, initial_capital=1000000):
  462. self.trades = trades
  463. self.equity_curve = equity_curve
  464. self.initial_capital = initial_capital
  465. def calculate_metrics(self):
  466. """计算回测指标"""
  467. if not self.trades:
  468. return {
  469. 'total_trades': 0,
  470. 'win_rate': 0,
  471. 'profit_factor': 0,
  472. 'total_return': 0,
  473. 'max_drawdown': 0,
  474. 'sharpe_ratio': 0
  475. }
  476. total_trades = len(self.trades)
  477. winning_trades = [t for t in self.trades if t['pnl'] > 0]
  478. losing_trades = [t for t in self.trades if t['pnl'] <= 0]
  479. win_count = len(winning_trades)
  480. loss_count = len(losing_trades)
  481. win_rate = (win_count / total_trades * 100) if total_trades > 0 else 0
  482. total_profit = sum(t['pnl'] for t in winning_trades)
  483. total_loss = abs(sum(t['pnl'] for t in losing_trades))
  484. profit_factor = total_profit / total_loss if total_loss > 0 else 0
  485. # 总收益
  486. final_capital = self.trades[-1]['capital'] if self.trades else self.initial_capital
  487. total_return = (final_capital - self.initial_capital) / self.initial_capital * 100
  488. # 最大回撤
  489. max_drawdown = self._calculate_max_drawdown()
  490. # 夏普比率(简化计算)
  491. sharpe_ratio = self._calculate_sharpe()
  492. return {
  493. 'total_trades': total_trades,
  494. 'win_count': win_count,
  495. 'loss_count': loss_count,
  496. 'win_rate': round(win_rate, 2),
  497. 'profit_factor': round(profit_factor, 2),
  498. 'total_profit': round(total_profit, 2),
  499. 'total_loss': round(total_loss, 2),
  500. 'total_return': round(total_return, 2),
  501. 'max_drawdown': round(max_drawdown, 2),
  502. 'sharpe_ratio': round(sharpe_ratio, 2),
  503. 'initial_capital': self.initial_capital,
  504. 'final_capital': round(final_capital, 2),
  505. 'net_profit': round(final_capital - self.initial_capital, 2)
  506. }
  507. def _calculate_max_drawdown(self):
  508. """计算最大回撤"""
  509. if not self.equity_curve:
  510. return 0
  511. max_dd = 0
  512. peak = self.equity_curve[0]['total_value']
  513. for point in self.equity_curve:
  514. value = point['total_value']
  515. if value > peak:
  516. peak = value
  517. dd = (peak - value) / peak * 100
  518. if dd > max_dd:
  519. max_dd = dd
  520. return max_dd
  521. def _calculate_sharpe(self):
  522. """计算夏普比率(简化版)"""
  523. if len(self.equity_curve) < 2:
  524. return 0
  525. # 计算收益率序列
  526. returns = []
  527. for i in range(1, len(self.equity_curve)):
  528. prev = self.equity_curve[i-1]['total_value']
  529. curr = self.equity_curve[i]['total_value']
  530. if prev > 0:
  531. returns.append((curr - prev) / prev)
  532. if not returns:
  533. return 0
  534. avg_return = sum(returns) / len(returns)
  535. # 计算标准差
  536. variance = sum((r - avg_return) ** 2 for r in returns) / len(returns)
  537. std = math.sqrt(variance) if variance > 0 else 0
  538. # 年化夏普(简化:假设每个bar代表30分钟)
  539. if std > 0:
  540. sharpe = (avg_return * 48 * 252) / (std * math.sqrt(48)) # 48个30分钟/天,252交易日/年
  541. return sharpe
  542. return 0
  543. def generate_report(self):
  544. """生成文字报告"""
  545. metrics = self.calculate_metrics()
  546. report = []
  547. report.append("="*80)
  548. report.append("CYB50 只做多T+1策略回测报告")
  549. report.append("="*80)
  550. report.append("")
  551. report.append("【回测参数】")
  552. report.append(f" 初始资金: {metrics['initial_capital']:,.0f} 元")
  553. report.append(f" 最终资金: {metrics['final_capital']:,.2f} 元")
  554. report.append(f" 净盈亏: {metrics['net_profit']:+,.2f} 元")
  555. report.append(f" 总收益率: {metrics['total_return']:+.2f}%")
  556. report.append("")
  557. report.append("【交易统计】")
  558. report.append(f" 总交易次数: {metrics['total_trades']} 笔")
  559. report.append(f" 盈利次数: {metrics['win_count']} 笔")
  560. report.append(f" 亏损次数: {metrics['loss_count']} 笔")
  561. report.append(f" 胜率: {metrics['win_rate']}%")
  562. report.append(f" 盈亏比: {metrics['profit_factor']}")
  563. report.append(f" 总盈利: {metrics['total_profit']:,.2f} 元")
  564. report.append(f" 总亏损: {metrics['total_loss']:,.2f} 元")
  565. report.append("")
  566. report.append("【风险指标】")
  567. report.append(f" 最大回撤: {metrics['max_drawdown']}%")
  568. report.append(f" 夏普比率: {metrics['sharpe_ratio']}")
  569. report.append("")
  570. if self.trades:
  571. report.append("【最近20笔交易明细】")
  572. report.append("-"*120)
  573. report.append(f"{'开仓时间':<20} {'平仓时间':<20} {'开仓价':>10} {'平仓价':>10} {'盈亏':>12} {'盈亏%':>8} {'持仓h':>6} {'原因':<20}")
  574. report.append("-"*120)
  575. for t in self.trades[-20:]:
  576. report.append(f"{t['entry_time']:<20} {t['exit_time']:<20} {t['entry_price']:>10.2f} {t['exit_price']:>10.2f} "
  577. f"{t['pnl']:>+12.2f} {t['pnl_pct']:>+7.2f}% {t['holding_hours']:>6.1f} {t['exit_reason']:<20}")
  578. report.append("-"*120)
  579. report.append("")
  580. report.append("="*80)
  581. return "\n".join(report), metrics
  582. def save_results(self, output_dir="."):
  583. """保存结果到文件"""
  584. import os
  585. os.makedirs(output_dir, exist_ok=True)
  586. timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
  587. # 1. 保存交易明细
  588. trades_file = os.path.join(output_dir, f"trades_{timestamp}.csv")
  589. if self.trades:
  590. with open(trades_file, 'w', newline='', encoding='utf-8-sig') as f:
  591. writer = csv.DictWriter(f, fieldnames=self.trades[0].keys())
  592. writer.writeheader()
  593. writer.writerows(self.trades)
  594. print(f"✅ 交易明细已保存: {trades_file}")
  595. # 2. 保存权益曲线
  596. equity_file = os.path.join(output_dir, f"equity_{timestamp}.csv")
  597. if self.equity_curve:
  598. with open(equity_file, 'w', newline='', encoding='utf-8-sig') as f:
  599. writer = csv.DictWriter(f, fieldnames=self.equity_curve[0].keys())
  600. writer.writeheader()
  601. writer.writerows(self.equity_curve)
  602. print(f"✅ 权益曲线已保存: {equity_file}")
  603. # 3. 保存报告
  604. report_text, metrics = self.generate_report()
  605. report_file = os.path.join(output_dir, f"report_{timestamp}.txt")
  606. with open(report_file, 'w', encoding='utf-8') as f:
  607. f.write(report_text)
  608. print(f"✅ 回测报告已保存: {report_file}")
  609. # 4. 保存指标JSON
  610. json_file = os.path.join(output_dir, f"metrics_{timestamp}.json")
  611. with open(json_file, 'w', encoding='utf-8') as f:
  612. json.dump(metrics, f, indent=2, ensure_ascii=False)
  613. print(f"✅ 指标数据已保存: {json_file}")
  614. return trades_file, equity_file, report_file, json_file
  615. # ==================== 主函数 ====================
  616. def main():
  617. """主程序"""
  618. print("="*80)
  619. print("CYB50 只做多T+1回测系统")
  620. print("="*80)
  621. # 数据文件路径
  622. data_file = "/home/erwin/.openclaw/workspace/cyb50-quant/cat-fly/t1/cyb50_30min_2023_to_20260325.csv"
  623. # 1. 加载数据
  624. loader = DataLoader(data_file)
  625. data = loader.load()
  626. # 2. 运行回测
  627. executor = T1BacktestExecutor(initial_capital=1000000)
  628. trades, equity_curve = executor.run_backtest(data)
  629. # 3. 生成报告
  630. report = BacktestReport(trades, equity_curve, initial_capital=1000000)
  631. report_text, metrics = report.generate_report()
  632. # 4. 打印报告
  633. print("\n" + report_text)
  634. # 5. 保存结果
  635. output_dir = "/home/erwin/.openclaw/workspace/cyb50-quant/cat-fly/t1/backtest_results"
  636. report.save_results(output_dir)
  637. print(f"\n✅ 回测完成!")
  638. print(f" 总收益率: {metrics['total_return']:+.2f}%")
  639. print(f" 交易次数: {metrics['total_trades']} 笔")
  640. print(f" 胜率: {metrics['win_rate']}%")
  641. if __name__ == "__main__":
  642. main()