| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- CYB50 择时过滤T+1回测系统 - 结合市场状态
- 只做多,使用市场状态过滤开仓信号
- """
- import csv
- import json
- from datetime import datetime, timedelta
- from collections import deque
- import math
# ==================== Technical indicator calculations ====================
class TechnicalIndicators:
    """Technical indicators in pure Python.

    Every method takes a chronologically ordered list (oldest first) and returns
    the indicator value(s) at its LAST element, or ``None`` when there is not
    enough data.
    """

    @staticmethod
    def _ema_series(data, period):
        """Return the exponential moving average of *data* at every index.

        The recursion is seeded with ``data[0]`` (not with an SMA) and uses the
        smoothing factor ``2 / (period + 1)``. Returns ``[]`` for empty input.
        """
        if not data:
            return []
        multiplier = 2 / (period + 1)
        value = data[0]
        series = [value]
        for price in data[1:]:
            value = (price - value) * multiplier + value
            series.append(value)
        return series

    @staticmethod
    def sma(data, period):
        """Simple moving average of the last *period* values, or None if too short."""
        if len(data) < period:
            return None
        return sum(data[-period:]) / period

    @staticmethod
    def ema(data, period):
        """Exponential moving average over all of *data*, seeded with ``data[0]``.

        Returns None when fewer than *period* values are available.
        """
        if len(data) < period:
            return None
        return TechnicalIndicators._ema_series(data, period)[-1]

    @staticmethod
    def rsi(prices, period=14):
        """Relative Strength Index using simple averages of the last *period* changes.

        Returns None when fewer than ``period + 1`` prices are available. Returns
        100 when there were no losing changes in the window (including a
        completely flat window).
        """
        if len(prices) < period + 1:
            return None
        # Only the last `period` price changes contribute to the averages.
        window = prices[-(period + 1):]
        changes = [curr - prev for prev, curr in zip(window, window[1:])]
        avg_gain = sum(c for c in changes if c > 0) / period
        avg_loss = sum(-c for c in changes if c < 0) / period
        if avg_loss == 0:
            return 100
        rs = avg_gain / avg_loss
        return 100 - (100 / (1 + rs))

    @staticmethod
    def bollinger_bands(prices, period=20, std_dev=2):
        """Bollinger bands over the last *period* prices (population std-dev).

        Returns ``(upper, middle, lower)``, or ``(None, None, None)`` if too short.
        """
        if len(prices) < period:
            return None, None, None
        window = prices[-period:]
        middle = sum(window) / period
        variance = sum((p - middle) ** 2 for p in window) / period
        std = math.sqrt(variance)
        upper = middle + (std * std_dev)
        lower = middle - (std * std_dev)
        return upper, middle, lower

    @staticmethod
    def macd(prices, fast=12, slow=26, signal=9):
        """MACD line, signal line and histogram at the last price.

        The fast/slow EMAs run over the whole of *prices*, seeded with
        ``prices[0]``. The MACD series is taken from index ``slow - 1`` onward,
        once the slow EMA has seen ``slow`` samples. The signal line is the EMA
        of that series.

        Returns:
            ``(macd_line, signal_line, histogram)``. All three are None when
            ``len(prices) < slow``. ``signal_line`` and ``histogram`` are None
            until the MACD series has at least *signal* values.
        """
        if len(prices) < slow:
            return None, None, None
        # Compute each EMA once over the full history. The previous code
        # re-seeded a fresh "EMA" on every short window, which is not an EMA
        # and costs O(n * slow).
        fast_series = TechnicalIndicators._ema_series(prices, fast)
        slow_series = TechnicalIndicators._ema_series(prices, slow)
        macd_series = [
            f - s for f, s in zip(fast_series[slow - 1:], slow_series[slow - 1:])
        ]
        macd_line = macd_series[-1]
        signal_line = None
        histogram = None
        if len(macd_series) >= signal:
            signal_line = TechnicalIndicators._ema_series(macd_series, signal)[-1]
            # Explicit None test: a signal value of exactly 0.0 is still valid.
            histogram = macd_line - signal_line
        return macd_line, signal_line, histogram
# ==================== Market regime manager ====================
class MarketRegimeManager:
    """Look up per-bar market-regime labels and use them to gate long entries.

    Regimes are read from a CSV with the columns ``datetime``, ``state``,
    ``prob_ranging``, ``prob_trend`` and ``prob_reversal``. They are keyed by the
    timestamp normalised to ``'%Y-%m-%d %H:%M:%S'``. As interpreted by
    :meth:`can_open_long`, ``state`` 1 means trend, 2 means reversal, and any
    other value (including the default 0) is treated as ranging.
    """

    _DT_FORMAT = '%Y-%m-%d %H:%M:%S'

    def __init__(self, regime_file):
        """Load regime data from *regime_file*. A load failure leaves the data empty."""
        self.regime_data = {}
        self.load_regime_data(regime_file)

    @staticmethod
    def _normalize_dt(dt_str):
        """Return *dt_str* re-formatted as 'YYYY-MM-DD HH:MM:SS' if it parses as ISO-8601.

        Falls back to the whitespace-stripped input when it does not parse.
        """
        text = dt_str.strip()
        try:
            return datetime.fromisoformat(text).strftime(MarketRegimeManager._DT_FORMAT)
        except ValueError:
            return text

    def load_regime_data(self, filepath):
        """Load regime rows from *filepath* into ``self.regime_data``.

        Rows with missing or unparsable fields are skipped and counted. If the
        file cannot be opened or decoded, an error is printed and
        ``regime_data`` is left empty. This is deliberately best-effort: every
        lookup then falls back to the "ranging" default, which blocks entries.
        """
        print(f"加载市场状态数据: {filepath}")
        loaded = {}
        skipped = 0
        try:
            # utf-8-sig strips a leading BOM (as load_data does). With plain
            # utf-8 the first header would read '\ufeffdatetime' and every row
            # would fail the 'datetime' lookup.
            with open(filepath, 'r', encoding='utf-8-sig', newline='') as f:
                for row in csv.DictReader(f):
                    try:
                        key = self._normalize_dt(row['datetime'])
                        loaded[key] = {
                            'state': int(row['state']),
                            'prob_ranging': float(row['prob_ranging']),
                            'prob_trend': float(row['prob_trend']),
                            'prob_reversal': float(row['prob_reversal'])
                        }
                    except (KeyError, TypeError, ValueError, AttributeError):
                        # Short/malformed row: skip it instead of discarding the file.
                        skipped += 1
        except (OSError, UnicodeDecodeError, csv.Error) as e:
            print(f"[ERROR] 加载失败: {e}")
            self.regime_data = {}
            return
        self.regime_data = loaded
        print(f"[OK] 加载成功: {len(self.regime_data)}条状态数据")
        if skipped:
            print(f"[WARN] 跳过无法解析的行: {skipped}条")

    def get_regime(self, dt_str):
        """Return the regime dict for *dt_str*.

        Tries an exact key match first, then the normalised form of *dt_str*.
        Unknown timestamps get a fresh "ranging" default (state 0, prob_ranging 1.0).
        """
        regime = self.regime_data.get(dt_str)
        if regime is None and isinstance(dt_str, str):
            regime = self.regime_data.get(self._normalize_dt(dt_str))
        if regime is None:
            regime = {
                'state': 0,  # default: ranging
                'prob_ranging': 1.0,
                'prob_trend': 0.0,
                'prob_reversal': 0.0
            }
        return regime

    def can_open_long(self, dt_str, min_trend_prob=0.5):
        """Decide whether a new long position may be opened at *dt_str*.

        Rules:
            - state == 1 and prob_trend >= min_trend_prob -> allowed
            - any other case -> blocked (reversal state, or treated as ranging)

        Returns:
            ``(allowed, reason)``, where *reason* is a short human-readable label.
        """
        regime = self.get_regime(dt_str)
        state = regime['state']
        trend_prob = regime['prob_trend']
        # Only open in a trend state with sufficiently high trend probability.
        if state == 1 and trend_prob >= min_trend_prob:
            return True, f"趋势状态(概率{trend_prob:.2f})"
        # Reversal state: entries forbidden.
        if state == 2:
            return False, f"反转状态(概率{regime['prob_reversal']:.2f})"
        # Ranging (or low-confidence trend): stay out.
        return False, f"震荡状态(概率{regime['prob_ranging']:.2f})"
# ==================== Backtest engine ====================
class BacktestEngine:
    """Long-only, bar-by-bar backtest engine whose entries are gated by a regime filter.

    Accounting: ``capital`` holds realized equity only. Opening a position does
    not debit cash; the trade's P&L is credited to ``capital`` when it closes.
    Mark-to-market equity is therefore ``capital + position * (close - entry_price)``.

    NOTE(review): despite the "T+1" name, nothing here prevents an exit on the
    same trading day as the entry. Confirm whether that constraint is intended.
    """

    def __init__(self, initial_capital=1000000, position_size=0.5):
        """Create an engine.

        Args:
            initial_capital: starting (realized) capital.
            position_size: fraction of current capital committed per entry.
        """
        self.initial_capital = initial_capital
        self.position_size = position_size
        self.capital = initial_capital
        self.position = 0  # shares held; 0 means flat
        self.entry_price = 0
        self.entry_time = None
        self.holding_periods = 0  # bars elapsed since entry
        self.max_holding_periods = 16  # max holding: 16 x 30-min bars (8 hours)
        # Output records
        self.equity_curve = []  # one dict per processed bar
        self.trades = []  # OPEN / CLOSE records
        self.signals = []  # technical entry signals blocked by the regime filter
        # Rolling history (last 100 bars) used for indicator calculation
        self.prices = deque(maxlen=100)
        self.highs = deque(maxlen=100)
        self.lows = deque(maxlen=100)

    def calculate_signals(self):
        """Compute indicator values from the rolling close-price history.

        Returns None until at least 50 closes are buffered. Otherwise returns a
        dict with keys rsi, bb_upper, bb_lower, bb_middle, ma5, ma10, ma20,
        macd, macd_signal and price (the latest close). Individual values may
        be None when an indicator lacks data.
        """
        if len(self.prices) < 50:
            return None
        price_list = list(self.prices)
        # Technical indicators
        rsi = TechnicalIndicators.rsi(price_list, 14)
        bb_upper, bb_middle, bb_lower = TechnicalIndicators.bollinger_bands(price_list, 20, 2)
        # Moving averages
        ma5 = TechnicalIndicators.sma(price_list, 5)
        ma10 = TechnicalIndicators.sma(price_list, 10)
        ma20 = TechnicalIndicators.sma(price_list, 20)
        # MACD (histogram not used by the strategy)
        macd_line, signal_line, _histogram = TechnicalIndicators.macd(price_list)
        return {
            'rsi': rsi,
            'bb_upper': bb_upper,
            'bb_lower': bb_lower,
            'bb_middle': bb_middle,
            'ma5': ma5,
            'ma10': ma10,
            'ma20': ma20,
            'macd': macd_line,
            'macd_signal': signal_line,
            'price': price_list[-1]
        }

    def check_long_signal(self, signals):
        """Check the technical long-entry rule: at least 3 of these 4 conditions.

        RSI < 65, MA5 > MA10, MACD above its signal line, price above the
        Bollinger middle band.

        Returns:
            ``(triggered, reason)``, where *reason* joins the satisfied conditions.
        """
        if signals is None:
            return False, "指标不足"
        conditions = []
        # RSI: avoid overbought entries
        if signals['rsi'] is not None and signals['rsi'] < 65:
            conditions.append('RSI<65')
        # Short MA above longer MA
        if (signals['ma5'] is not None and signals['ma10'] is not None and
                signals['ma5'] > signals['ma10']):
            conditions.append('MA5>MA10')
        # MACD above signal line (a level check, not a fresh cross)
        if (signals['macd'] is not None and signals['macd_signal'] is not None and
                signals['macd'] > signals['macd_signal']):
            conditions.append('MACD金叉')
        # Price above Bollinger middle band
        if (signals['bb_middle'] is not None and
                signals['price'] > signals['bb_middle']):
            conditions.append('价格>中轨')
        if len(conditions) >= 3:
            return True, '+'.join(conditions)
        return False, f"条件不足({len(conditions)}/3)"

    def check_exit_signal(self, signals, current_price):
        """Decide whether to close the open position at *current_price*.

        The checks run in order: stop-loss (-2.5%), take-profit (+4%), maximum
        holding time, then RSI > 75. Price- and time-based exits do not depend
        on indicators, so they fire even when *signals* is None.

        Returns:
            ``(should_exit, reason)``.
        """
        if self.position == 0:
            return False, ""
        # Stop-loss 2.5%
        stop_loss = self.entry_price * 0.975
        if current_price <= stop_loss:
            return True, f"止损({current_price:.2f}<={stop_loss:.2f})"
        # Take-profit 4%
        take_profit = self.entry_price * 1.04
        if current_price >= take_profit:
            return True, f"止盈({current_price:.2f}>={take_profit:.2f})"
        # Maximum holding time
        if self.holding_periods >= self.max_holding_periods:
            return True, f"时间平仓({self.holding_periods}周期)"
        # Overbought RSI exit (only when indicators are available)
        if signals is not None and signals['rsi'] is not None and signals['rsi'] > 75:
            return True, f"RSI超买({signals['rsi']:.1f})"
        return False, ""

    def open_position(self, price, time_str, reason):
        """Open a long worth ``capital * position_size`` at *price* and log an OPEN trade."""
        position_value = self.capital * self.position_size
        self.position = position_value / price
        self.entry_price = price
        self.entry_time = time_str
        self.holding_periods = 0
        self.trades.append({
            'action': 'OPEN',
            'time': time_str,
            'price': price,
            'shares': self.position,
            'value': position_value,
            'reason': reason
        })

    def close_position(self, price, time_str, reason):
        """Close the open position at *price*, book the P&L into capital and log a CLOSE trade.

        Does nothing when flat.
        """
        if self.position == 0:
            return
        pnl = (price - self.entry_price) * self.position
        pnl_pct = (price / self.entry_price - 1) * 100
        self.capital += pnl
        self.trades.append({
            'action': 'CLOSE',
            'time': time_str,
            'price': price,
            'shares': self.position,
            'pnl': pnl,
            'pnl_pct': pnl_pct,
            'reason': reason
        })
        self.position = 0
        self.entry_price = 0
        self.entry_time = None
        self.holding_periods = 0

    def update(self, timestamp, open_price, high, low, close, regime_manager):
        """Process one bar: record equity, then handle an exit or an entry at *close*.

        Args:
            timestamp: bar time (anything with ``strftime``, e.g. ``datetime``).
            open_price: unused; kept for interface compatibility.
            high, low: appended to the rolling buffers.
            close: price used for indicators, equity and fills.
            regime_manager: object exposing ``can_open_long(dt_str)`` -> ``(bool, str)``.

        Returns:
            Mark-to-market equity at this bar's close. It is recorded before any
            fill on this bar, but a fill at the same close leaves it unchanged.
        """
        self.prices.append(close)
        self.highs.append(high)
        self.lows.append(low)
        signals = self.calculate_signals()
        # Regime lookup uses the same string format as the regime file keys
        dt_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
        can_open, regime_reason = regime_manager.can_open_long(dt_str)

        # Record equity. Only the unrealized P&L is added, because the position's
        # cost was never deducted from capital. Adding position * close would
        # double-count the position value.
        equity = self.capital
        if self.position > 0:
            equity += self.position * (close - self.entry_price)
        self.equity_curve.append({
            'time': dt_str,
            'equity': equity,
            'close': close,
            'position': 1 if self.position > 0 else 0
        })

        if self.position > 0:
            # Holding: age the position and check exits
            self.holding_periods += 1
            should_exit, exit_reason = self.check_exit_signal(signals, close)
            if should_exit:
                self.close_position(close, dt_str, exit_reason)
        else:
            # Flat: technical signal first, then the regime filter
            tech_signal, tech_reason = self.check_long_signal(signals)
            if tech_signal:
                if can_open:
                    self.open_position(close, dt_str, f"{tech_reason}|{regime_reason}")
                else:
                    # Technical signal fired but the regime filter blocked it
                    self.signals.append({
                        'time': dt_str,
                        'price': close,
                        'tech_reason': tech_reason,
                        'block_reason': regime_reason
                    })
        return equity
# ==================== Main program ====================
def load_data(filepath):
    """Load 30-minute OHLCV bars from a CSV file.

    The file needs the columns DateTime (format '%Y-%m-%d %H:%M:%S'), Open,
    High, Low, Close and Volume. A leading BOM is tolerated. Rows that are
    missing a column or fail to parse are skipped and reported as a count.

    Returns:
        A list of dicts with keys datetime, open, high, low, close and volume,
        sorted chronologically (stable, so rows with equal timestamps keep file order).
    """
    print(f"加载数据: {filepath}")
    data = []
    skipped = 0
    # utf-8-sig handles a BOM; newline='' is what the csv module expects.
    with open(filepath, 'r', encoding='utf-8-sig', newline='') as f:
        for row in csv.DictReader(f):
            try:
                bar = {
                    'datetime': datetime.strptime(row['DateTime'], '%Y-%m-%d %H:%M:%S'),
                    'open': float(row['Open']),
                    'high': float(row['High']),
                    'low': float(row['Low']),
                    'close': float(row['Close']),
                    'volume': float(row['Volume'])
                }
            except (KeyError, TypeError, ValueError):
                # Missing column, short row, or unparsable value.
                skipped += 1
                continue
            data.append(bar)
    # The engine processes bars in list order, so enforce chronological order.
    data.sort(key=lambda bar: bar['datetime'])
    print(f"[OK] 加载成功: {len(data)}条")
    if skipped:
        print(f"[WARN] 跳过无法解析的行: {skipped}条")
    return data
def _closed_trade_stats(closed_trades):
    """Return win/loss statistics for a non-empty list of CLOSE trade records."""
    wins = [t for t in closed_trades if t['pnl'] > 0]
    losses = [t for t in closed_trades if t['pnl'] <= 0]
    total_profit = sum(t['pnl'] for t in wins)
    total_loss = sum(t['pnl'] for t in losses)
    if total_loss != 0:
        profit_factor = abs(total_profit / total_loss)
    elif total_profit > 0:
        # No losing trades: the profit factor is unbounded, not 0.
        profit_factor = float('inf')
    else:
        profit_factor = 0.0
    return {
        'win_count': len(wins),
        'loss_count': len(losses),
        'win_rate': len(wins) / len(closed_trades) * 100,
        'profit_factor': profit_factor,
        'avg_win': total_profit / len(wins) if wins else 0,
        'avg_loss': total_loss / len(losses) if losses else 0,
    }


def _write_rows_csv(path, rows, fieldnames=None):
    """Write a list of dicts to *path* as UTF-8 CSV.

    When *fieldnames* is None, the header is the union of all row keys in
    first-seen order. This matters because OPEN and CLOSE trade records have
    different keys. An empty *rows* with no fieldnames produces an empty file.
    """
    if fieldnames is None:
        fieldnames = list(dict.fromkeys(key for row in rows for key in row))
    with open(path, 'w', newline='', encoding='utf-8') as f:
        if not fieldnames:
            return
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def run_backtest(data_file, regime_file, output_dir='backtest_results'):
    """Run the regime-filtered backtest, print a summary and save result files.

    Args:
        data_file: OHLCV CSV passed to ``load_data``.
        regime_file: regime CSV passed to ``MarketRegimeManager``.
        output_dir: directory for the equity/trades/blocked-signal CSVs and
            the text report. It is created if missing.

    Returns:
        The ``BacktestEngine`` after all bars have been processed.
    """
    import os
    os.makedirs(output_dir, exist_ok=True)

    data = load_data(data_file)
    regime_manager = MarketRegimeManager(regime_file)
    engine = BacktestEngine(initial_capital=1000000, position_size=0.5)

    print("\n" + "="*70)
    print("开始回测 - 择时过滤T+1策略")
    print("="*70)
    print("策略规则:")
    print(" - 只做多,持仓上限50%")
    # Matches check_long_signal: at least 3 of the 4 conditions.
    print(" - 技术信号: 以下4项至少满足3项 - RSI<65 | MA5>MA10 | MACD>信号线 | 价格>布林带中轨")
    # Matches can_open_long: prob_trend >= min_trend_prob (0.5).
    print(" - 择时过滤: 只在趋势状态(state=1)且趋势概率>=0.5时开仓")
    print(" - 止损: -2.5% | 止盈: +4% | 最大持仓: 16周期(8小时)")
    print("="*70)

    for row in data:
        engine.update(
            row['datetime'],
            row['open'],
            row['high'],
            row['low'],
            row['close'],
            regime_manager
        )

    # ---- Summary ----
    print("\n" + "="*70)
    print("回测结果")
    print("="*70)
    initial = engine.initial_capital
    final = engine.equity_curve[-1]['equity'] if engine.equity_curve else initial
    total_return = (final / initial - 1) * 100
    print(f"初始资金: {initial:,.2f} 元")
    print(f"最终资金: {final:,.2f} 元")
    print(f"总收益率: {total_return:+.2f}%")

    trades = engine.trades
    closed_trades = [t for t in trades if t['action'] == 'CLOSE']
    print(f"\n总交易次数: {len(closed_trades)}")
    stats = _closed_trade_stats(closed_trades) if closed_trades else None
    if stats:
        print(f" 盈利: {stats['win_count']} | 亏损: {stats['loss_count']}")
        print(f" 胜率: {stats['win_rate']:.2f}%")
        print(f" 盈亏比: {stats['profit_factor']:.2f}")
        print(f" 平均每笔盈利: {stats['avg_win']:,.2f}")
        print(f" 平均每笔亏损: {stats['avg_loss']:,.2f}")

    blocked = engine.signals
    print(f"\n被择时过滤的信号: {len(blocked)}次")
    if blocked:
        print(" (技术信号满足但市场状态不允许开仓)")

    # ---- Save results ----
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    equity_file = os.path.join(output_dir, f"equity_with_regime_{timestamp}.csv")
    _write_rows_csv(equity_file, engine.equity_curve, ['time', 'equity', 'close', 'position'])
    # Trades mix OPEN and CLOSE records. Using trades[0].keys() as the header
    # made DictWriter raise on the first CLOSE row (pnl / pnl_pct).
    trades_file = os.path.join(output_dir, f"trades_with_regime_{timestamp}.csv")
    _write_rows_csv(trades_file, trades)
    blocked_file = None
    if blocked:
        blocked_file = os.path.join(output_dir, f"blocked_signals_{timestamp}.csv")
        _write_rows_csv(blocked_file, blocked)

    report_file = os.path.join(output_dir, f"report_with_regime_{timestamp}.txt")
    with open(report_file, 'w', encoding='utf-8') as f:
        f.write("="*70 + "\n")
        f.write("CYB50 择时过滤T+1策略回测报告\n")
        f.write("="*70 + "\n\n")
        f.write(f"初始资金: {initial:,.2f} 元\n")
        f.write(f"最终资金: {final:,.2f} 元\n")
        f.write(f"总收益率: {total_return:+.2f}%\n")
        f.write(f"总交易次数: {len(closed_trades)}\n")
        if stats:
            f.write(f"胜率: {stats['win_rate']:.2f}%\n")
            f.write(f"盈亏比: {stats['profit_factor']:.2f}\n")
        f.write(f"\n被择时过滤的信号: {len(blocked)}次\n")

    print(f"\n结果已保存到: {output_dir}/")
    print(f" - {equity_file}")
    print(f" - {trades_file}")
    if blocked_file:
        print(f" - {blocked_file}")
    print(f" - {report_file}")
    return engine
if __name__ == '__main__':
    # Script entry point. Both paths are relative, so they resolve against the
    # current working directory rather than this file's location.
    DATA_FILE = 'cyb50_30min_2023_to_20260325.csv'
    REGIME_FILE = '../../market-regime-identifier-30/cyb50_30min_regime_result.csv'
    engine = run_backtest(
        data_file=DATA_FILE,
        regime_file=REGIME_FILE,
    )
|