#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CYB50 multi-timeframe confirmation + parameter-optimisation backtest.

Combines a daily trend filter with 30-minute entry timing and supports a
small parameter scan.
"""

import csv
import json
from datetime import datetime, timedelta
from collections import deque
import math
import os


# ==================== Technical indicators ====================

class TechnicalIndicators:
    """Technical indicator calculations implemented in pure Python."""

    @staticmethod
    def sma(data, period):
        """Return the simple moving average of the last *period* values.

        Returns None when fewer than *period* values are available.
        """
        if len(data) < period:
            return None
        return sum(data[-period:]) / period

    @staticmethod
    def rsi(prices, period=14):
        """Return the RSI over the last *period* price changes.

        Uses plain (non-smoothed) averages of gains and losses.
        Returns None when fewer than ``period + 1`` prices are available,
        and 100 when there were no losses in the window.
        """
        if len(prices) < period + 1:
            return None

        # Only the last `period` changes contribute to the result, so only
        # the last `period + 1` prices need to be examined.
        window = prices[-(period + 1):]
        changes = [curr - prev for prev, curr in zip(window, window[1:])]

        avg_gain = sum(c for c in changes if c > 0) / period
        avg_loss = sum(-c for c in changes if c < 0) / period

        if avg_loss == 0:
            return 100

        rs = avg_gain / avg_loss
        return 100 - (100 / (1 + rs))

    @staticmethod
    def bollinger_bands(prices, period=20, std_dev=2):
        """Return ``(upper, middle, lower)`` Bollinger bands.

        The middle band is the mean of the last *period* prices and the
        band width is *std_dev* population standard deviations.
        Returns ``(None, None, None)`` when there is not enough data.
        """
        if len(prices) < period:
            return None, None, None

        window = prices[-period:]
        middle = sum(window) / period
        variance = sum((p - middle) ** 2 for p in window) / period
        offset = math.sqrt(variance) * std_dev

        return middle + offset, middle, middle - offset

    @staticmethod
    def macd(prices, fast=12, slow=26, signal=9):
        """Return ``(macd_line, signal_line, histogram)``.

        Each EMA is seeded with the first value of its own trailing window
        (the last *fast* / *slow* prices), not with the full history.
        The signal line is the same kind of windowed EMA over the last
        *signal* MACD values; it and the histogram are None until enough
        MACD values exist. Returns ``(None, None, None)`` when there are
        not enough prices for the longer of the two EMAs.
        """
        n = len(prices)
        longest = max(fast, slow)
        if n < longest:
            return None, None, None

        def window_ema(window, period):
            multiplier = 2 / (period + 1)
            ema = window[0]
            for price in window[1:]:
                ema = (price - ema) * multiplier + ema
            return ema

        def macd_at(end):
            # MACD value using the windows that end just before index `end`.
            return (window_ema(prices[end - fast:end], fast)
                    - window_ema(prices[end - slow:end], slow))

        macd_line = macd_at(n)

        # Only the last `signal` MACD values feed the signal line, so the
        # earlier windows are not computed at all.
        signal_line = None
        available = n - longest + 1
        if signal > 0 and available >= signal:
            recent_macd = [macd_at(end) for end in range(n - signal + 1, n + 1)]
            signal_line = window_ema(recent_macd, signal)

        # Compare against None explicitly: a signal line of exactly 0.0 is
        # a valid value and must still produce a histogram.
        histogram = macd_line - signal_line if signal_line is not None else None

        return macd_line, signal_line, histogram


# ==================== Daily trend manager ====================

class DailyTrendManager:
    """Holds daily bars and a per-date trend classification."""

    def __init__(self, daily_file):
        self.daily_data = {}
        self.daily_trend = {}  # date string -> trend info dict
        self.load_daily_data(daily_file)
        self.calculate_daily_trend()

    def load_daily_data(self, filepath):
        """Load daily OHLCV rows from a CSV file into ``self.daily_data``.

        Rows that are missing a column or fail to parse are skipped.
        A failure to open/read the file is reported and leaves the data
        loaded so far in place.
        """
        print(f"加载日线数据: {filepath}")

        try:
            with open(filepath, 'r', encoding='utf-8-sig') as f:
                for row in csv.DictReader(f):
                    try:
                        dt = datetime.strptime(row['datetime'], '%Y-%m-%d %H:%M:%S')
                        bar = {
                            'open': float(row['open']),
                            'high': float(row['high']),
                            'low': float(row['low']),
                            'close': float(row['close']),
                            'volume': float(row['volume'])
                        }
                    except (KeyError, ValueError, TypeError):
                        # Malformed or incomplete row: skip it.
                        continue
                    self.daily_data[dt.strftime('%Y-%m-%d')] = bar

            print(f"[OK] 加载成功: {len(self.daily_data)}条日线数据")
        except Exception as e:
            print(f"[ERROR] 加载失败: {e}")

    def calculate_daily_trend(self, ma_period=20):
        """Classify every loaded date as up (1), down (-1) or flat (0).

        A date is "up" when its close is more than 2% above the
        *ma_period* moving average, "down" when more than 2% below, and
        flat otherwise. Dates without enough history get trend 0 and no
        average. The average is stored under the key ``'ma20'`` regardless
        of *ma_period*.
        """
        print(f"计算日线趋势 (MA{ma_period})...")

        dates = sorted(self.daily_data.keys())
        closes = [self.daily_data[d]['close'] for d in dates]

        for i, date in enumerate(dates):
            if i < ma_period - 1:
                self.daily_trend[date] = {'trend': 0, 'ma20': None, 'trend_strength': 0}
                continue

            ma = sum(closes[i - ma_period + 1:i + 1]) / ma_period
            close = closes[i]

            if close > ma * 1.02:
                trend = 1    # clearly above the average
            elif close < ma * 0.98:
                trend = -1   # clearly below the average
            else:
                trend = 0    # within the +/-2% band

            self.daily_trend[date] = {
                'trend': trend,
                'ma20': ma,
                # Percentage distance of the close from the average.
                'trend_strength': (close - ma) / ma * 100
            }

        print("[OK] 日线趋势计算完成")

    def get_daily_trend(self, date_str):
        """Return the trend info for *date_str*, or a neutral default."""
        return self.daily_trend.get(
            date_str, {'trend': 0, 'ma20': None, 'trend_strength': 0})

    def can_trade_long(self, date_str, require_uptrend=True):
        """Return ``(allowed, trend_info)`` for opening longs on a date.

        With *require_uptrend* the daily trend must be up; otherwise flat
        and up are both allowed and only a downtrend blocks.
        """
        trend_info = self.get_daily_trend(date_str)

        if require_uptrend:
            return trend_info['trend'] == 1, trend_info
        return trend_info['trend'] >= 0, trend_info


# ==================== 30-minute market regime manager ====================

class MarketRegimeManager:
    """Holds per-bar 30-minute market regime data."""

    def __init__(self, regime_file):
        self.regime_data = {}
        self.load_regime_data(regime_file)

    def load_regime_data(self, filepath):
        """Load regime rows from a CSV file, keyed by the raw datetime string.

        Any error (including a malformed row) aborts the load and is
        reported; rows read before the error are kept.
        """
        print(f"加载30分钟状态数据: {filepath}")

        try:
            with open(filepath, 'r', encoding='utf-8-sig') as f:
                for row in csv.DictReader(f):
                    self.regime_data[row['datetime']] = {
                        'state': int(row['state']),
                        'prob_ranging': float(row['prob_ranging']),
                        'prob_trend': float(row['prob_trend']),
                        'prob_reversal': float(row['prob_reversal'])
                    }

            print(f"[OK] 加载成功: {len(self.regime_data)}条状态数据")
        except Exception as e:
            print(f"[ERROR] 加载失败: {e}")

    def get_regime(self, dt_str):
        """Return the regime for *dt_str*, defaulting to state 0 (ranging)."""
        return self.regime_data.get(dt_str, {
            'state': 0,
            'prob_ranging': 1.0,
            'prob_trend': 0.0,
            'prob_reversal': 0.0
        })

    def can_open_long(self, dt_str, min_trend_prob=0.5):
        """Return ``(allowed, regime)`` for opening a long at *dt_str*.

        Longs are allowed only when the state is 1 and its trend
        probability is at least *min_trend_prob*.
        """
        regime = self.get_regime(dt_str)
        allowed = regime['state'] == 1 and regime['prob_trend'] >= min_trend_prob
        return allowed, regime


# ==================== Backtest engine (parameterised) ====================

class BacktestEngine:
    """Long-only backtest engine with daily + 30-minute confirmation.

    Accounting model: ``capital`` is never debited when a position is
    opened; it only changes by the realised P&L when a position is closed.
    Account equity is therefore ``capital`` plus the unrealised P&L of the
    open position.
    """

    def __init__(self, initial_capital=1000000, position_size=0.5,
                 min_trend_prob=0.5, require_daily_uptrend=True):
        self.initial_capital = initial_capital
        self.position_size = position_size
        self.min_trend_prob = min_trend_prob
        self.require_daily_uptrend = require_daily_uptrend

        self.capital = initial_capital
        self.position = 0          # number of shares held (0 = flat)
        self.entry_price = 0
        self.entry_time = None
        self.holding_periods = 0   # bars elapsed since entry
        self.max_holding_periods = 16

        self.equity_curve = []
        self.trades = []
        self.signals = []          # technical signals that were blocked
        self.block_reasons = {'daily': 0, 'regime': 0}

        self.prices = deque(maxlen=100)
        self.highs = deque(maxlen=100)
        self.lows = deque(maxlen=100)

    def calculate_signals(self):
        """Return a dict of indicator values, or None with < 50 closes."""
        if len(self.prices) < 50:
            return None

        price_list = list(self.prices)

        rsi = TechnicalIndicators.rsi(price_list, 14)
        bb_upper, bb_middle, bb_lower = TechnicalIndicators.bollinger_bands(price_list, 20, 2)
        ma5 = TechnicalIndicators.sma(price_list, 5)
        ma10 = TechnicalIndicators.sma(price_list, 10)
        ma20 = TechnicalIndicators.sma(price_list, 20)
        macd_line, signal_line, histogram = TechnicalIndicators.macd(price_list)

        return {
            'rsi': rsi,
            'bb_upper': bb_upper,
            'bb_lower': bb_lower,
            'bb_middle': bb_middle,
            'ma5': ma5,
            'ma10': ma10,
            'ma20': ma20,
            'macd': macd_line,
            'macd_signal': signal_line,
            'price': price_list[-1]
        }

    def check_long_signal(self, signals):
        """Return ``(triggered, description)`` for the long entry rule.

        The rule triggers when at least three of four conditions hold:
        RSI below 65, MA5 above MA10, MACD above its signal line, and
        price above the Bollinger middle band.
        """
        if signals is None:
            return False, "指标不足"

        conditions = []

        if signals['rsi'] is not None and signals['rsi'] < 65:
            conditions.append('RSI<65')

        if (signals['ma5'] is not None and signals['ma10'] is not None
                and signals['ma5'] > signals['ma10']):
            conditions.append('MA5>MA10')

        if (signals['macd'] is not None and signals['macd_signal'] is not None
                and signals['macd'] > signals['macd_signal']):
            conditions.append('MACD金叉')

        if (signals['bb_middle'] is not None
                and signals['price'] > signals['bb_middle']):
            conditions.append('价格>中轨')

        if len(conditions) >= 3:
            return True, '+'.join(conditions)

        return False, f"条件不足({len(conditions)}/3)"

    def check_exit_signal(self, signals, current_price):
        """Return ``(should_exit, reason)`` for the open position.

        Checked in order: 2.5% stop loss, 4% take profit, maximum holding
        time, RSI above 75. Never exits when indicators are unavailable
        or there is no position.
        """
        if signals is None or self.position == 0:
            return False, ""

        stop_loss = self.entry_price * 0.975
        if current_price <= stop_loss:
            return True, f"止损({current_price:.2f}<={stop_loss:.2f})"

        take_profit = self.entry_price * 1.04
        if current_price >= take_profit:
            return True, f"止盈({current_price:.2f}>={take_profit:.2f})"

        if self.holding_periods >= self.max_holding_periods:
            return True, f"时间平仓({self.holding_periods}周期)"

        if signals['rsi'] is not None and signals['rsi'] > 75:
            return True, f"RSI超买({signals['rsi']:.1f})"

        return False, ""

    def open_position(self, price, time_str, reason):
        """Open a long sized at ``position_size`` of current capital.

        ``capital`` itself is left untouched (see the class docstring).
        """
        position_value = self.capital * self.position_size
        self.position = position_value / price
        self.entry_price = price
        self.entry_time = time_str
        self.holding_periods = 0

        self.trades.append({
            'action': 'OPEN',
            'time': time_str,
            'price': price,
            'shares': self.position,
            'value': position_value,
            'reason': reason
        })

    def close_position(self, price, time_str, reason):
        """Close the open position and add its realised P&L to capital."""
        if self.position == 0:
            return

        pnl = (price - self.entry_price) * self.position
        pnl_pct = (price / self.entry_price - 1) * 100

        self.capital += pnl

        self.trades.append({
            'action': 'CLOSE',
            'time': time_str,
            'price': price,
            'shares': self.position,
            'pnl': pnl,
            'pnl_pct': pnl_pct,
            'reason': reason
        })

        self.position = 0
        self.entry_price = 0
        self.holding_periods = 0

    def update(self, timestamp, open_price, high, low, close, daily_manager, regime_manager):
        """Process one 30-minute bar and return the account equity.

        Records the equity point first, then either manages the open
        position (exit checks) or evaluates a new entry, which requires
        the technical signal plus both the daily and the 30-minute filter.
        Technical signals rejected by a filter are logged in
        ``self.signals`` and counted in ``self.block_reasons``.
        """
        self.prices.append(close)
        self.highs.append(high)
        self.lows.append(low)

        signals = self.calculate_signals()
        dt_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
        date_str = timestamp.strftime('%Y-%m-%d')

        # Multi-timeframe confirmation.
        # NOTE(review): the daily trend for `date_str` is computed from that
        # same date's daily close, so intraday bars are filtered with
        # information that may not be known until the session ends
        # (possible look-ahead bias) - confirm whether the previous trading
        # day's trend should be used instead.
        daily_ok, daily_info = daily_manager.can_trade_long(
            date_str, self.require_daily_uptrend)
        regime_ok, regime_info = regime_manager.can_open_long(
            dt_str, self.min_trend_prob)

        # Capital is not debited at entry, so equity is capital plus the
        # unrealised P&L only. Adding the full position value here would
        # count the invested amount twice.
        equity = self.capital
        if self.position > 0:
            equity += self.position * (close - self.entry_price)

        self.equity_curve.append({
            'time': dt_str,
            'equity': equity,
            'close': close,
            'position': 1 if self.position > 0 else 0
        })

        if self.position > 0:
            self.holding_periods += 1
            should_exit, exit_reason = self.check_exit_signal(signals, close)
            if should_exit:
                self.close_position(close, dt_str, exit_reason)
            return equity

        tech_signal, tech_reason = self.check_long_signal(signals)
        if not tech_signal:
            return equity

        block_reason = []

        if not daily_ok:
            block_reason.append(f"日线趋势向下(强度:{daily_info['trend_strength']:.2f}%)")
            self.block_reasons['daily'] += 1

        if not regime_ok:
            block_reason.append(f"30分钟非趋势状态(state={regime_info['state']},概率={regime_info['prob_trend']:.2f})")
            self.block_reasons['regime'] += 1

        if daily_ok and regime_ok:
            self.open_position(close, dt_str,
                               f"{tech_reason}|日线向上|30分钟趋势(prob={regime_info['prob_trend']:.2f})")
        else:
            self.signals.append({
                'time': dt_str,
                'price': close,
                'tech_reason': tech_reason,
                'block_reason': '|'.join(block_reason)
            })

        return equity


# ==================== Main program ====================

def load_data(filepath):
    """Load 30-minute OHLCV bars from a CSV file.

    Returns a list of dicts in file order; rows that are missing a column
    or fail to parse are skipped. Errors opening the file propagate.
    """
    print(f"加载30分钟数据: {filepath}")
    data = []

    with open(filepath, 'r', encoding='utf-8-sig') as f:
        for row in csv.DictReader(f):
            try:
                bar = {
                    'datetime': datetime.strptime(row['DateTime'], '%Y-%m-%d %H:%M:%S'),
                    'open': float(row['Open']),
                    'high': float(row['High']),
                    'low': float(row['Low']),
                    'close': float(row['Close']),
                    'volume': float(row['Volume'])
                }
            except (KeyError, ValueError, TypeError):
                # Malformed or incomplete row: skip it.
                continue
            data.append(bar)

    print(f"[OK] 加载成功: {len(data)}条")
    return data


def run_single_backtest(data, daily_manager, regime_manager, params, output_dir='backtest_results'):
    """Run one backtest and return ``(result_dict, engine)``.

    *params* must provide ``'min_trend_prob'`` and
    ``'require_daily_uptrend'``. *output_dir* is created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    engine = BacktestEngine(
        initial_capital=1000000,
        position_size=0.5,
        min_trend_prob=params['min_trend_prob'],
        require_daily_uptrend=params['require_daily_uptrend']
    )

    for row in data:
        engine.update(
            row['datetime'], row['open'], row['high'], row['low'], row['close'],
            daily_manager, regime_manager
        )

    # Summary statistics.
    initial = engine.initial_capital
    final = engine.equity_curve[-1]['equity'] if engine.equity_curve else initial
    total_return = (final / initial - 1) * 100

    closed_trades = [t for t in engine.trades if t['action'] == 'CLOSE']
    winners = [t['pnl'] for t in closed_trades if t['pnl'] > 0]
    losers = [t['pnl'] for t in closed_trades if t['pnl'] <= 0]

    win_count = len(winners)
    loss_count = len(losers)
    win_rate = win_count / len(closed_trades) * 100 if closed_trades else 0

    total_profit = sum(winners)
    total_loss = sum(losers)
    # Reported as 0 when there are no losses, which also keeps the value
    # JSON-serialisable.
    profit_factor = abs(total_profit / total_loss) if total_loss != 0 else 0

    result = {
        'params': params,
        'total_return': total_return,
        'trade_count': len(closed_trades),
        'win_count': win_count,
        'loss_count': loss_count,
        'win_rate': win_rate,
        'profit_factor': profit_factor,
        'blocked_daily': engine.block_reasons['daily'],
        'blocked_regime': engine.block_reasons['regime'],
        'final_capital': final
    }

    return result, engine


def run_parameter_scan(data_file, daily_file, regime_file, output_dir='optimization_results'):
    """Run the backtest over a fixed parameter grid and rank the results.

    Prints a report, writes all results (sorted by total return,
    descending) to ``parameter_optimization_results.json`` inside
    *output_dir*, and returns the sorted list.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Load inputs.
    data = load_data(data_file)
    daily_manager = DailyTrendManager(daily_file)
    regime_manager = MarketRegimeManager(regime_file)

    # Parameter grid.
    param_grid = [
        {'min_trend_prob': 0.3, 'require_daily_uptrend': True},
        {'min_trend_prob': 0.4, 'require_daily_uptrend': True},
        {'min_trend_prob': 0.5, 'require_daily_uptrend': True},
        {'min_trend_prob': 0.6, 'require_daily_uptrend': True},
        {'min_trend_prob': 0.7, 'require_daily_uptrend': True},
        {'min_trend_prob': 0.5, 'require_daily_uptrend': False},  # also allow a flat daily trend
    ]

    print("\n" + "="*70)
    print("参数优化扫描")
    print("="*70)

    all_results = []

    for i, params in enumerate(param_grid):
        print(f"\n[{i+1}/{len(param_grid)}] 测试参数: {params}")

        result, engine = run_single_backtest(
            data, daily_manager, regime_manager, params, output_dir)
        all_results.append(result)

        print(f" 收益率: {result['total_return']:+.2f}%")
        print(f" 交易次数: {result['trade_count']}")
        print(f" 胜率: {result['win_rate']:.1f}%")
        print(f" 盈亏比: {result['profit_factor']:.2f}")

    # Rank by total return, best first.
    all_results.sort(key=lambda x: x['total_return'], reverse=True)

    print("\n" + "="*70)
    print("参数优化结果排名")
    print("="*70)

    for i, r in enumerate(all_results[:5]):
        print(f"\n第{i+1}名:")
        print(f" 参数: 趋势概率阈值={r['params']['min_trend_prob']}, "
              f"要求日线向上={r['params']['require_daily_uptrend']}")
        print(f" 收益率: {r['total_return']:+.2f}%")
        print(f" 交易次数: {r['trade_count']}")
        print(f" 胜率: {r['win_rate']:.1f}%")
        print(f" 盈亏比: {r['profit_factor']:.2f}")
        print(f" 被日线过滤: {r['blocked_daily']}次")
        print(f" 被30分钟过滤: {r['blocked_regime']}次")

    # Persist the ranked results.
    result_file = f"{output_dir}/parameter_optimization_results.json"
    with open(result_file, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2, ensure_ascii=False)

    print(f"\n优化结果已保存: {result_file}")

    return all_results


if __name__ == '__main__':
    DATA_FILE = 'cyb50_30min_2023_to_20260325.csv'
    DAILY_FILE = '../data-fetch/data/399673_SZ_day_20150101_20260325.csv'
    REGIME_FILE = '../../market-regime-identifier-30/cyb50_30min_regime_result.csv'

    results = run_parameter_scan(DATA_FILE, DAILY_FILE, REGIME_FILE)