backtest_dual_with_timing.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. #!/usr/bin/env python3
  2. """
  3. DualDirection 策略 + 择时过滤
  4. 保持原策略所有参数不变,只增加日线趋势和30分钟状态过滤
  5. """
  6. import csv
  7. from datetime import datetime
  8. import math
  9. class DualDirectionWithTiming:
  10. def __init__(self, initial_capital=1000000):
  11. self.initial_capital = initial_capital
  12. self.position_size_pct = 1.0
  13. self.stop_loss_pct = 0.008
  14. self.take_profit_pct = 0.02
  15. self.max_hold_bars = 16
  16. self.min_trend_prob = 0.3
  17. self.require_daily_uptrend = True
  18. self.long_signal_count = 0
  19. self.filtered_count = 0
  20. self.trades = []
  21. self.capital = initial_capital
  22. def load_daily_data(self, filepath):
  23. daily_data = {}
  24. with open(filepath, 'r', encoding='utf-8-sig') as f:
  25. reader = csv.DictReader(f)
  26. for row in reader:
  27. try:
  28. dt = datetime.strptime(row['datetime'], '%Y-%m-%d %H:%M:%S')
  29. date_str = dt.strftime('%Y-%m-%d')
  30. daily_data[date_str] = {
  31. 'open': float(row['open']),
  32. 'high': float(row['high']),
  33. 'low': float(row['low']),
  34. 'close': float(row['close'])
  35. }
  36. except:
  37. continue
  38. dates = sorted(daily_data.keys())
  39. closes = [daily_data[d]['close'] for d in dates]
  40. for i, date in enumerate(dates):
  41. if i < 19:
  42. daily_data[date]['ma20'] = None
  43. daily_data[date]['trend'] = 0
  44. else:
  45. ma20 = sum(closes[i-19:i+1]) / 20
  46. daily_data[date]['ma20'] = ma20
  47. close = closes[i]
  48. if close > ma20 * 1.02:
  49. daily_data[date]['trend'] = 1
  50. elif close < ma20 * 0.98:
  51. daily_data[date]['trend'] = -1
  52. else:
  53. daily_data[date]['trend'] = 0
  54. return daily_data
  55. def detect_market_regime(self, data, current_idx):
  56. if current_idx < 16:
  57. return 0, 0.0
  58. window = data[current_idx-16:current_idx]
  59. closes = [row['Close'] for row in window]
  60. highs = [row['High'] for row in window]
  61. lows = [row['Low'] for row in window]
  62. start_price = closes[0]
  63. end_price = closes[-1]
  64. period_return = (end_price / start_price - 1) * 100
  65. max_price = max(highs)
  66. min_price = min(lows)
  67. price_range = (max_price - min_price) / start_price * 100
  68. gains = []
  69. losses = []
  70. for i in range(1, len(closes)):
  71. change = closes[i] - closes[i-1]
  72. gains.append(max(0, change))
  73. losses.append(max(0, -change))
  74. avg_gain = sum(gains[-14:]) / 14 if len(gains) >= 14 else sum(gains) / len(gains)
  75. avg_loss = sum(losses[-14:]) / 14 if len(losses) >= 14 else sum(losses) / len(losses)
  76. if avg_loss == 0:
  77. rsi = 100
  78. else:
  79. rs = avg_gain / avg_loss
  80. rsi = 100 - (100 / (1 + rs))
  81. returns = [(closes[i] - closes[i-1]) / closes[i-1] * 100 for i in range(1, len(closes))]
  82. volatility = math.sqrt(sum(r**2 for r in returns) / len(returns)) if returns else 0
  83. reversal_score = 0
  84. if rsi > 70 or rsi < 30:
  85. reversal_score += 2
  86. elif rsi > 65 or rsi < 35:
  87. reversal_score += 1
  88. if price_range > 4 and abs(period_return) < 1.5:
  89. reversal_score += 1
  90. if reversal_score >= 3:
  91. return 2, 0.3
  92. trend_score = 0
  93. if abs(period_return) >= 2.0:
  94. trend_score += 3
  95. elif abs(period_return) >= 1.0:
  96. trend_score += 2
  97. elif abs(period_return) >= 0.5:
  98. trend_score += 1
  99. if 0.5 < volatility < 2.0:
  100. trend_score += 1
  101. first_half = closes[:len(closes)//2]
  102. second_half = closes[len(closes)//2:]
  103. first_avg = sum(first_half) / len(first_half)
  104. second_avg = sum(second_half) / len(second_half)
  105. if (period_return > 0 and second_avg > first_avg) or (period_return < 0 and second_avg < first_avg):
  106. trend_score += 1
  107. if trend_score >= 4:
  108. prob = min(0.95, 0.5 + abs(period_return) / 10)
  109. return 1, prob
  110. return 0, 0.2
  111. def calculate_indicators(self, data):
  112. for i, row in enumerate(data):
  113. if i < 24:
  114. row['RSI'] = 50
  115. row['MACD_hist'] = 0
  116. row['BB_lower'] = row['Close'] * 0.98
  117. row['Volume_Ratio'] = 1.0
  118. row['Price_Momentum'] = 0
  119. continue
  120. closes = [data[j]['Close'] for j in range(i-23, i+1)]
  121. highs = [data[j]['High'] for j in range(i-23, i+1)]
  122. lows = [data[j]['Low'] for j in range(i-23, i+1)]
  123. volumes = [data[j]['Volume'] for j in range(i-23, i+1)]
  124. gains = []
  125. losses = []
  126. for j in range(1, 15):
  127. change = closes[-j] - closes[-j-1]
  128. gains.append(max(0, change))
  129. losses.append(max(0, -change))
  130. avg_gain = sum(gains) / 14
  131. avg_loss = sum(losses) / 14
  132. if avg_loss == 0:
  133. row['RSI'] = 100
  134. else:
  135. rs = avg_gain / avg_loss
  136. row['RSI'] = 100 - (100 / (1 + rs))
  137. bb_middle = sum(closes[-20:]) / 20
  138. variance = sum((c - bb_middle) ** 2 for c in closes[-20:]) / 20
  139. bb_std = variance ** 0.5
  140. row['BB_lower'] = bb_middle - bb_std * 2
  141. ema12 = sum(closes[-12:]) / 12
  142. ema26 = sum(closes[-26:]) / 26 if len(closes) >= 26 else sum(closes) / len(closes)
  143. row['MACD'] = ema12 - ema26
  144. row['MACD_hist'] = row['MACD']
  145. vol_ma = sum(volumes[-12:]) / 12
  146. row['Volume_Ratio'] = row['Volume'] / vol_ma if vol_ma > 0 else 1
  147. row['Price_Momentum'] = (row['Close'] - closes[-6]) / closes[-6] if closes[-6] > 0 else 0
  148. return data
  149. def calculate_long_score(self, row, prev_rows):
  150. long_score = 0
  151. long_signals = []
  152. if row['RSI'] < 30:
  153. long_score += 2
  154. long_signals.append("RSI超卖")
  155. elif row['RSI'] < 35:
  156. long_score += 1
  157. long_signals.append("RSI偏弱")
  158. if row['Close'] <= row['BB_lower'] * 1.01:
  159. long_score += 2
  160. long_signals.append("触及下轨")
  161. elif row['Close'] <= row['BB_lower'] * 1.03:
  162. long_score += 1
  163. long_signals.append("接近下轨")
  164. if len(prev_rows) > 0:
  165. prev_macd_hist = prev_rows[-1]['MACD_hist']
  166. if row['MACD_hist'] > 0 and prev_macd_hist <= 0:
  167. long_score += 2
  168. long_signals.append("MACD金叉")
  169. elif row['MACD_hist'] > prev_macd_hist:
  170. long_score += 1
  171. long_signals.append("MACD改善")
  172. if row['Price_Momentum'] > 0.005:
  173. long_score += 1
  174. long_signals.append("动量向上")
  175. if row['Volume_Ratio'] > 1.5:
  176. long_score += 1
  177. long_signals.append("放量")
  178. return long_score, long_signals
  179. def run_backtest(self, data_file, daily_file):
  180. print("="*70)
  181. print("DualDirection + 择时过滤 回测")
  182. print("="*70)
  183. print(f"\nDualDirection参数: 止损0.8% 止盈2% 最大持仓16周期")
  184. print(f"择时过滤: 日线向上 + 30分钟趋势概率>=0.3")
  185. print(f"\n[1/4] 加载30分钟数据...")
  186. data = []
  187. with open(data_file, 'r', encoding='utf-8-sig') as f:
  188. reader = csv.DictReader(f)
  189. for row in reader:
  190. data.append({
  191. 'DateTime': row['DateTime'],
  192. 'Open': float(row['Open']),
  193. 'High': float(row['High']),
  194. 'Low': float(row['Low']),
  195. 'Close': float(row['Close']),
  196. 'Volume': float(row['Volume'])
  197. })
  198. print(f" {len(data)}条")
  199. print(f"\n[2/4] 加载日线数据...")
  200. daily_data = self.load_daily_data(daily_file)
  201. print(f" {len(daily_data)}条")
  202. print(f"\n[3/4] 计算技术指标...")
  203. data = self.calculate_indicators(data)
  204. print(f"\n[4/4] 执行回测...")
  205. position = 0
  206. entry_price = 0
  207. entry_idx = 0
  208. for i in range(24, len(data)):
  209. row = data[i]
  210. current_time = row['DateTime']
  211. current_price = row['Close']
  212. date_str = current_time[:10]
  213. daily_info = daily_data.get(date_str, {'trend': 0})
  214. daily_trend = daily_info['trend']
  215. regime_state, trend_prob = self.detect_market_regime(data, i)
  216. prev_rows = data[max(0, i-5):i]
  217. long_score, long_signals = self.calculate_long_score(row, prev_rows)
  218. if position > 0:
  219. holding_bars = i - entry_idx
  220. pnl_pct = (current_price - entry_price) / entry_price
  221. exit_reason = None
  222. if pnl_pct <= -self.stop_loss_pct:
  223. exit_reason = f"止损({current_price:.2f})"
  224. elif pnl_pct >= self.take_profit_pct:
  225. exit_reason = f"止盈({current_price:.2f})"
  226. elif holding_bars >= self.max_hold_bars:
  227. exit_reason = f"时间平仓({holding_bars}周期)"
  228. elif row['RSI'] > 75:
  229. exit_reason = f"RSI超买({row['RSI']:.1f})"
  230. if exit_reason:
  231. pnl = (current_price - entry_price) * position
  232. self.capital += pnl
  233. self.trades.append({
  234. 'action': 'CLOSE', 'time': current_time, 'price': current_price,
  235. 'pnl': pnl, 'pnl_pct': pnl_pct * 100, 'reason': exit_reason
  236. })
  237. position = 0
  238. entry_price = 0
  239. elif long_score >= 4 and position == 0:
  240. self.long_signal_count += 1
  241. can_trade = True
  242. if self.require_daily_uptrend and daily_trend != 1:
  243. can_trade = False
  244. if regime_state != 1 or trend_prob < self.min_trend_prob:
  245. can_trade = False
  246. if can_trade:
  247. position_value = self.capital * self.position_size_pct
  248. position = position_value / current_price
  249. entry_price = current_price
  250. entry_idx = i
  251. self.trades.append({
  252. 'action': 'OPEN', 'time': current_time, 'price': current_price,
  253. 'value': position_value,
  254. 'reason': f"信号{long_score}分|日线向上|趋势{trend_prob:.2f}"
  255. })
  256. else:
  257. self.filtered_count += 1
  258. closed = [t for t in self.trades if t['action']=='CLOSE']
  259. print(f" 信号: {self.long_signal_count} 过滤: {self.filtered_count} 交易: {len(closed)}")
  260. return self.generate_report()
  261. def generate_report(self):
  262. closed_trades = [t for t in self.trades if t['action'] == 'CLOSE']
  263. if not closed_trades:
  264. print("\n无交易")
  265. return None
  266. wins = [t for t in closed_trades if t['pnl'] > 0]
  267. losses = [t for t in closed_trades if t['pnl'] <= 0]
  268. total_pnl = sum(t['pnl'] for t in closed_trades)
  269. final_capital = self.initial_capital + total_pnl
  270. total_return = (final_capital / self.initial_capital - 1) * 100
  271. win_rate = len(wins) / len(closed_trades) * 100
  272. total_profit = sum(t['pnl'] for t in wins) if wins else 0
  273. total_loss = abs(sum(t['pnl'] for t in losses)) if losses else 0
  274. profit_factor = total_profit / total_loss if total_loss > 0 else 0
  275. print("\n" + "="*70)
  276. print("回测报告 - DualDirection + 择时过滤")
  277. print("="*70)
  278. print(f" 收益率: {total_return:+.2f}%")
  279. print(f" 信号: {self.long_signal_count} 过滤: {self.filtered_count} 交易: {len(closed_trades)}")
  280. print(f" 胜率: {win_rate:.2f}% 盈亏比: {profit_factor:.2f}")
  281. print(f"\n最近5笔:")
  282. for t in closed_trades[-5:]:
  283. print(f" {t['time']} | {t['pnl']:+10,.2f} | {t['reason']}")
  284. print("="*70)
  285. return {'total_return': total_return, 'win_rate': win_rate}
  286. if __name__ == '__main__':
  287. backtest = DualDirectionWithTiming()
  288. backtest.run_backtest('cyb50_30min_2023_to_20260325.csv', '../data-fetch/data/399673_SZ_day_20150101_20260325.csv')