#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CYB50 regime-filtered T+1 backtest.

Long-only strategy on 30-minute bars: technical indicators propose entries and
a market-regime CSV (produced elsewhere) vetoes entries outside trend states.
Positions opened on a calendar day cannot be closed on that same day (T+1).
"""

import csv
import json
import math
import os
from collections import deque
from datetime import datetime, timedelta

# Canonical timestamp format shared by the price data, the regime lookup keys
# and every time string recorded by the engine.
DT_FORMAT = '%Y-%m-%d %H:%M:%S'


# ==================== Technical indicators ====================

class TechnicalIndicators:
    """Pure-Python technical indicators evaluated at the tail of a price list."""

    @staticmethod
    def sma(data, period):
        """Simple moving average of the last ``period`` values, or None if too short."""
        if len(data) < period:
            return None
        return sum(data[-period:]) / period

    @staticmethod
    def ema(data, period):
        """Exponential moving average over all of ``data``.

        The recursion is seeded with ``data[0]`` (not with an SMA), so the
        result depends on the whole length of ``data``. Returns None when
        ``len(data) < period``.
        """
        if len(data) < period:
            return None
        multiplier = 2 / (period + 1)
        ema = data[0]
        for price in data[1:]:
            ema = (price - ema) * multiplier + ema
        return ema

    @staticmethod
    def rsi(prices, period=14):
        """RSI from simple (not Wilder-smoothed) averages of the last ``period`` changes.

        Returns None when fewer than ``period + 1`` prices are available, and
        100 when the window contains no losing change.
        """
        if len(prices) < period + 1:
            return None
        # Only the last `period` price changes contribute, so only scan those.
        window = list(prices)[-(period + 1):]
        total_gain = 0
        total_loss = 0
        for prev, cur in zip(window, window[1:]):
            change = cur - prev
            if change > 0:
                total_gain += change
            else:
                total_loss -= change
        avg_gain = total_gain / period
        avg_loss = total_loss / period
        if avg_loss == 0:
            return 100
        rs = avg_gain / avg_loss
        return 100 - (100 / (1 + rs))

    @staticmethod
    def bollinger_bands(prices, period=20, std_dev=2):
        """Bollinger bands (upper, middle, lower) using the population std-dev.

        Returns (None, None, None) when fewer than ``period`` prices exist.
        """
        if len(prices) < period:
            return None, None, None
        middle = sum(prices[-period:]) / period
        variance = sum((p - middle) ** 2 for p in prices[-period:]) / period
        std = math.sqrt(variance)
        upper = middle + (std * std_dev)
        lower = middle - (std * std_dev)
        return upper, middle, lower

    @staticmethod
    def macd(prices, fast=12, slow=26, signal=9):
        """MACD line, signal line and histogram at the last price.

        Each EMA is computed only over its own trailing window (the last
        ``fast`` / ``slow`` prices), seeded with the first value of that
        window. The signal line applies the same windowed EMA to the last
        ``signal`` MACD values, so it needs ``slow + signal - 1`` prices;
        with fewer, the signal line and histogram are None.
        """
        if len(prices) < slow:
            return None, None, None

        def calc_ema(data, period):
            multiplier = 2 / (period + 1)
            ema = data[0]
            for price in data[1:]:
                ema = (price - ema) * multiplier + ema
            return ema

        ema_fast = calc_ema(prices[-fast:], fast) if len(prices) >= fast else None
        ema_slow = calc_ema(prices[-slow:], slow) if len(prices) >= slow else None
        if ema_fast is None or ema_slow is None:
            return None, None, None

        macd_line = ema_fast - ema_slow

        # Signal line: EMA of the MACD series, where the MACD value at each
        # end index i is recomputed with the same windowed EMAs as above.
        macd_prices = []
        for i in range(slow, len(prices) + 1):
            fast_ema = calc_ema(prices[i - fast:i], fast)
            slow_ema = calc_ema(prices[i - slow:i], slow)
            macd_prices.append(fast_ema - slow_ema)

        signal_line = None
        if len(macd_prices) >= signal:
            signal_line = calc_ema(macd_prices[-signal:], signal)

        # Compare against None explicitly: a signal line of exactly 0.0 is valid.
        histogram = macd_line - signal_line if signal_line is not None else None
        return macd_line, signal_line, histogram


# ==================== Market regime ====================

def _normalize_dt_key(raw):
    """Return ``raw`` re-formatted with DT_FORMAT when it parses as ISO-8601.

    This makes regime keys such as '2024-01-02T10:00:00' or
    '2024-01-02 10:00:00.000' match the keys built by BacktestEngine.update.
    Unparseable values are returned stripped but otherwise unchanged.
    """
    text = str(raw).strip()
    try:
        return datetime.fromisoformat(text).strftime(DT_FORMAT)
    except ValueError:
        return text


class MarketRegimeManager:
    """Per-bar market-regime lookup used to veto long entries."""

    def __init__(self, regime_file):
        # Maps timestamp string -> {'state', 'prob_ranging', 'prob_trend', 'prob_reversal'}.
        self.regime_data = {}
        self.load_regime_data(regime_file)

    def load_regime_data(self, filepath):
        """Load the regime CSV into ``self.regime_data``.

        Expects columns ``datetime``, ``state``, ``prob_ranging``,
        ``prob_trend`` and ``prob_reversal``. On a read/parse failure the data
        is cleared, which makes every bar fall back to the default (ranging)
        regime, so no entries will be opened.
        """
        print(f"加载市场状态数据: {filepath}")
        try:
            # utf-8-sig strips a BOM; otherwise the first header would become
            # '\ufeffdatetime' and every row would fail with KeyError.
            with open(filepath, 'r', encoding='utf-8-sig', newline='') as f:
                reader = csv.DictReader(f)
                for row in reader:
                    key = _normalize_dt_key(row['datetime'])
                    self.regime_data[key] = {
                        'state': int(row['state']),
                        'prob_ranging': float(row['prob_ranging']),
                        'prob_trend': float(row['prob_trend']),
                        'prob_reversal': float(row['prob_reversal'])
                    }
            print(f"[OK] 加载成功: {len(self.regime_data)}条状态数据")
        except (OSError, KeyError, ValueError, TypeError, csv.Error) as e:
            print(f"[ERROR] 加载失败: {e}")
            print("[WARN] 无市场状态数据, 所有K线按震荡状态处理, 不会开仓")
            self.regime_data = {}

    def get_regime(self, dt_str):
        """Return the regime dict for ``dt_str``; unknown timestamps count as ranging."""
        return self.regime_data.get(dt_str, {
            'state': 0,  # default: ranging
            'prob_ranging': 1.0,
            'prob_trend': 0.0,
            'prob_reversal': 0.0
        })

    def can_open_long(self, dt_str, min_trend_prob=0.5):
        """Decide whether a long entry is allowed at ``dt_str``.

        Allowed only when state == 1 (trend) and prob_trend >= min_trend_prob.
        State 2 (reversal) and everything else (ranging) are rejected.

        Returns:
            (allowed, human-readable reason)
        """
        regime = self.get_regime(dt_str)
        state = regime['state']
        trend_prob = regime['prob_trend']

        if state == 1 and trend_prob >= min_trend_prob:
            return True, f"趋势状态(概率{trend_prob:.2f})"

        if state == 2:
            return False, f"反转状态(概率{regime['prob_reversal']:.2f})"

        return False, f"震荡状态(概率{regime['prob_ranging']:.2f})"


# ==================== Backtest engine ====================

class BacktestEngine:
    """Long-only, single-position backtest engine driven bar by bar via ``update``."""

    def __init__(self, initial_capital=1000000, position_size=0.5, t_plus_one=True):
        self.initial_capital = initial_capital
        self.position_size = position_size
        # Realized equity: changes only when a position is closed (it is NOT
        # debited on entry), so open positions are marked to market on P&L.
        self.capital = initial_capital
        # When True, a position cannot be closed on the calendar day it was opened.
        self.t_plus_one = t_plus_one

        self.position = 0  # number of shares held
        self.entry_price = 0
        self.entry_time = None
        self.holding_periods = 0
        self.max_holding_periods = 16  # 16 bars x 30 min = 8 hours of bars

        # Records
        self.equity_curve = []
        self.trades = []
        self.signals = []  # entries with a technical signal but vetoed by the regime

        # Rolling price history for the indicators
        self.prices = deque(maxlen=100)
        self.highs = deque(maxlen=100)
        self.lows = deque(maxlen=100)

    def calculate_signals(self):
        """Compute the indicator snapshot for the latest bar, or None before 50 bars."""
        if len(self.prices) < 50:
            return None

        price_list = list(self.prices)

        rsi = TechnicalIndicators.rsi(price_list, 14)
        bb_upper, bb_middle, bb_lower = TechnicalIndicators.bollinger_bands(price_list, 20, 2)
        ma5 = TechnicalIndicators.sma(price_list, 5)
        ma10 = TechnicalIndicators.sma(price_list, 10)
        ma20 = TechnicalIndicators.sma(price_list, 20)
        macd_line, signal_line, _ = TechnicalIndicators.macd(price_list)

        return {
            'rsi': rsi,
            'bb_upper': bb_upper,
            'bb_lower': bb_lower,
            'bb_middle': bb_middle,
            'ma5': ma5,
            'ma10': ma10,
            'ma20': ma20,
            'macd': macd_line,
            'macd_signal': signal_line,
            'price': price_list[-1]
        }

    def check_long_signal(self, signals):
        """Return (True, reasons) when at least 3 of the 4 long conditions hold.

        Conditions: RSI < 65, MA5 > MA10, MACD above its signal line, and
        price above the Bollinger middle band.
        """
        if signals is None:
            return False, "指标不足"

        conditions = []

        if signals['rsi'] is not None and signals['rsi'] < 65:
            conditions.append('RSI<65')

        if (signals['ma5'] is not None and signals['ma10'] is not None and
                signals['ma5'] > signals['ma10']):
            conditions.append('MA5>MA10')

        if (signals['macd'] is not None and signals['macd_signal'] is not None and
                signals['macd'] > signals['macd_signal']):
            conditions.append('MACD金叉')

        if (signals['bb_middle'] is not None and
                signals['price'] > signals['bb_middle']):
            conditions.append('价格>中轨')

        if len(conditions) >= 3:
            return True, '+'.join(conditions)

        return False, f"条件不足({len(conditions)}/3)"

    def _is_entry_day(self, current_time):
        """True when T+1 applies and ``current_time`` is on the entry's calendar day.

        Both times are compared on their first 10 characters, i.e. the
        'YYYY-MM-DD' prefix of DT_FORMAT strings.
        """
        return (self.t_plus_one and self.entry_time is not None and
                str(self.entry_time)[:10] == str(current_time)[:10])

    def check_exit_signal(self, signals, current_price, current_time=None):
        """Decide whether the open position should be closed at ``current_price``.

        Args:
            signals: indicator dict from ``calculate_signals`` (None -> no exit).
            current_price: price at which the exit would execute.
            current_time: time string of the current bar (DT_FORMAT). When
                given, and ``t_plus_one`` is enabled, no exit happens on the
                entry day. When omitted, the T+1 rule is not applied.

        Returns:
            (should_exit, reason)
        """
        if signals is None or self.position == 0:
            return False, ""

        # T+1: shares bought today cannot be sold until the next trading day.
        if current_time is not None and self._is_entry_day(current_time):
            return False, ""

        # Stop loss at -2.5%
        stop_loss = self.entry_price * 0.975
        if current_price <= stop_loss:
            return True, f"止损({current_price:.2f}<={stop_loss:.2f})"

        # Take profit at +4%
        take_profit = self.entry_price * 1.04
        if current_price >= take_profit:
            return True, f"止盈({current_price:.2f}>={take_profit:.2f})"

        if self.holding_periods >= self.max_holding_periods:
            return True, f"时间平仓({self.holding_periods}周期)"

        if signals['rsi'] is not None and signals['rsi'] > 75:
            return True, f"RSI超买({signals['rsi']:.1f})"

        return False, ""

    def open_position(self, price, time_str, reason):
        """Buy ``position_size`` of realized equity at ``price``.

        ``self.capital`` is intentionally not debited; equity is marked to
        market through unrealized P&L in ``update``.
        """
        position_value = self.capital * self.position_size
        self.position = position_value / price
        self.entry_price = price
        self.entry_time = time_str
        self.holding_periods = 0

        self.trades.append({
            'action': 'OPEN',
            'time': time_str,
            'price': price,
            'shares': self.position,
            'value': position_value,
            'reason': reason
        })

    def close_position(self, price, time_str, reason):
        """Close the whole position at ``price`` and book the P&L into capital."""
        if self.position == 0:
            return

        pnl = (price - self.entry_price) * self.position
        pnl_pct = (price / self.entry_price - 1) * 100
        self.capital += pnl

        self.trades.append({
            'action': 'CLOSE',
            'time': time_str,
            'price': price,
            'shares': self.position,
            'pnl': pnl,
            'pnl_pct': pnl_pct,
            'reason': reason
        })

        self.position = 0
        self.entry_price = 0
        self.holding_periods = 0

    def update(self, timestamp, open_price, high, low, close, regime_manager):
        """Process one bar: record equity, then exit or (regime-filtered) entry at ``close``.

        ``timestamp`` must be a datetime. Returns the bar's marked-to-market equity,
        measured before any trade executed on this bar.
        """
        self.prices.append(close)
        self.highs.append(high)
        self.lows.append(low)

        signals = self.calculate_signals()

        dt_str = timestamp.strftime(DT_FORMAT)
        # NOTE(review): the regime at bar t gates an entry at bar t's close;
        # confirm the regime file is computed without data after bar t.
        can_open, regime_reason = regime_manager.can_open_long(dt_str)

        # Mark to market: capital is realized equity (not debited on entry),
        # so only the unrealized P&L of the open position is added.
        equity = self.capital
        if self.position > 0:
            equity += self.position * (close - self.entry_price)

        self.equity_curve.append({
            'time': dt_str,
            'equity': equity,
            'close': close,
            'position': 1 if self.position > 0 else 0
        })

        if self.position > 0:
            self.holding_periods += 1
            should_exit, exit_reason = self.check_exit_signal(signals, close, dt_str)
            if should_exit:
                self.close_position(close, dt_str, exit_reason)
            return equity

        # Flat: technical signal first, then the regime filter.
        tech_signal, tech_reason = self.check_long_signal(signals)
        if tech_signal:
            if can_open:
                self.open_position(close, dt_str, f"{tech_reason}|{regime_reason}")
            else:
                self.signals.append({
                    'time': dt_str,
                    'price': close,
                    'tech_reason': tech_reason,
                    'block_reason': regime_reason
                })

        return equity


# ==================== Main program ====================

def load_data(filepath):
    """Load 30-minute OHLCV bars from CSV, sorted by time.

    Expects columns DateTime (DT_FORMAT), Open, High, Low, Close and an
    optional Volume. Rows that cannot be parsed are skipped and counted.
    """
    print(f"加载数据: {filepath}")
    data = []
    skipped = 0
    with open(filepath, 'r', encoding='utf-8-sig', newline='') as f:  # utf-8-sig handles BOM
        for row in csv.DictReader(f):
            try:
                data.append({
                    'datetime': datetime.strptime(row['DateTime'], DT_FORMAT),
                    'open': float(row['Open']),
                    'high': float(row['High']),
                    'low': float(row['Low']),
                    'close': float(row['Close']),
                    # Volume is not used by the strategy; don't drop bars without it.
                    'volume': float(row.get('Volume') or 0)
                })
            except (KeyError, ValueError, TypeError):
                skipped += 1

    # The engine is stateful bar by bar, so bars must be in time order.
    data.sort(key=lambda bar: bar['datetime'])
    print(f"[OK] 加载成功: {len(data)}条")
    if skipped:
        print(f"[WARN] 跳过无法解析的行: {skipped}条")
    return data


def _write_csv(path, rows, fieldnames=None):
    """Write dict ``rows`` to ``path`` as CSV.

    Without ``fieldnames`` the header is the ordered union of every row's
    keys (missing cells are left empty), and an empty ``rows`` yields an
    empty file. With ``fieldnames`` the header is always written.
    """
    if fieldnames is None:
        fieldnames = list(dict.fromkeys(key for row in rows for key in row))
    with open(path, 'w', newline='', encoding='utf-8') as f:
        if not fieldnames:
            return
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def _trade_stats(closed_trades):
    """Win/loss statistics for a non-empty list of CLOSE trade records."""
    wins = [t for t in closed_trades if t['pnl'] > 0]
    losses = [t for t in closed_trades if t['pnl'] <= 0]
    total_profit = sum(t['pnl'] for t in wins)
    total_loss = sum(t['pnl'] for t in losses)
    if total_loss != 0:
        profit_factor = abs(total_profit / total_loss)
    else:
        # No losing trade: the factor is unbounded if there was any profit.
        profit_factor = math.inf if total_profit > 0 else 0
    return {
        'win_count': len(wins),
        'loss_count': len(losses),
        'win_rate': len(wins) / len(closed_trades) * 100,
        'profit_factor': profit_factor,
        'avg_win': total_profit / len(wins) if wins else 0,
        'avg_loss': total_loss / len(losses) if losses else 0,
    }


def run_backtest(data_file, regime_file, output_dir='backtest_results'):
    """Run the regime-filtered T+1 backtest and write CSV/text reports.

    Args:
        data_file: path of the OHLCV CSV (see ``load_data``).
        regime_file: path of the market-regime CSV (see ``MarketRegimeManager``).
        output_dir: directory for the equity/trade/blocked-signal CSVs and report.

    Returns:
        The BacktestEngine after all bars have been processed.
    """
    os.makedirs(output_dir, exist_ok=True)

    data = load_data(data_file)
    regime_manager = MarketRegimeManager(regime_file)
    engine = BacktestEngine(initial_capital=1000000, position_size=0.5)

    print("\n" + "=" * 70)
    print("开始回测 - 择时过滤T+1策略")
    print("=" * 70)
    print("策略规则:")
    print(" - 只做多,持仓上限50%")
    print(" - 技术信号: RSI<65 + MA5>MA10 + MACD金叉 + 价格>布林带中轨 (至少满足3项)")
    print(" - 择时过滤: 只在趋势状态(state=1)且趋势概率>=0.5时开仓")
    print(" - T+1: 开仓当日不平仓")
    print(" - 止损: -2.5% | 止盈: +4% | 最大持仓: 16周期(8小时)")
    print("=" * 70)

    for row in data:
        engine.update(
            row['datetime'],
            row['open'],
            row['high'],
            row['low'],
            row['close'],
            regime_manager
        )

    print("\n" + "=" * 70)
    print("回测结果")
    print("=" * 70)

    initial = engine.initial_capital
    final = engine.equity_curve[-1]['equity'] if engine.equity_curve else initial
    total_return = (final / initial - 1) * 100

    print(f"初始资金: {initial:,.2f} 元")
    print(f"最终资金: {final:,.2f} 元")
    print(f"总收益率: {total_return:+.2f}%")
    if engine.position > 0:
        print(f"[INFO] 回测结束时仍有持仓(开仓于 {engine.entry_time}), 最终资金含浮动盈亏")

    trades = engine.trades
    closed_trades = [t for t in trades if t['action'] == 'CLOSE']
    stats = _trade_stats(closed_trades) if closed_trades else None

    print(f"\n总交易次数: {len(closed_trades)}")
    if stats:
        print(f" 盈利: {stats['win_count']} | 亏损: {stats['loss_count']}")
        print(f" 胜率: {stats['win_rate']:.2f}%")
        print(f" 盈亏比: {stats['profit_factor']:.2f}")
        print(f" 平均每笔盈利: {stats['avg_win']:,.2f}")
        print(f" 平均每笔亏损: {stats['avg_loss']:,.2f}")

    blocked = engine.signals
    print(f"\n被择时过滤的信号: {len(blocked)}次")
    if blocked:
        print(" (技术信号满足但市场状态不允许开仓)")

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    equity_file = os.path.join(output_dir, f"equity_with_regime_{timestamp}.csv")
    _write_csv(equity_file, engine.equity_curve,
               fieldnames=['time', 'equity', 'close', 'position'])

    # OPEN and CLOSE records have different keys; _write_csv uses their union.
    trades_file = os.path.join(output_dir, f"trades_with_regime_{timestamp}.csv")
    _write_csv(trades_file, trades)

    blocked_file = None
    if blocked:
        blocked_file = os.path.join(output_dir, f"blocked_signals_{timestamp}.csv")
        _write_csv(blocked_file, blocked)

    report_file = os.path.join(output_dir, f"report_with_regime_{timestamp}.txt")
    with open(report_file, 'w', encoding='utf-8') as f:
        f.write("=" * 70 + "\n")
        f.write("CYB50 择时过滤T+1策略回测报告\n")
        f.write("=" * 70 + "\n\n")
        f.write(f"初始资金: {initial:,.2f} 元\n")
        f.write(f"最终资金: {final:,.2f} 元\n")
        f.write(f"总收益率: {total_return:+.2f}%\n")
        f.write(f"总交易次数: {len(closed_trades)}\n")
        if stats:
            f.write(f"胜率: {stats['win_rate']:.2f}%\n")
            f.write(f"盈亏比: {stats['profit_factor']:.2f}\n")
        f.write(f"\n被择时过滤的信号: {len(blocked)}次\n")

    print(f"\n结果已保存到: {output_dir}/")
    print(f" - {equity_file}")
    print(f" - {trades_file}")
    if blocked_file:
        print(f" - {blocked_file}")
    print(f" - {report_file}")

    return engine


if __name__ == '__main__':
    DATA_FILE = 'cyb50_30min_2023_to_20260325.csv'
    REGIME_FILE = '../../market-regime-identifier-30/cyb50_30min_regime_result.csv'
    engine = run_backtest(DATA_FILE, REGIME_FILE)