| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- CYB50 多周期确认 + 参数优化回测系统
- 结合日线趋势和30分钟择时,支持参数扫描
- """
- import csv
- import json
- from datetime import datetime, timedelta
- from collections import deque
- import math
- import os
# ==================== Technical indicators ====================
class TechnicalIndicators:
    """Technical indicators implemented in pure Python.

    Every method is a ``staticmethod`` operating on a plain sequence (list) of
    prices ordered oldest -> newest. When there is not enough history for the
    requested period the method returns ``None`` (or a tuple of ``None``).
    """

    @staticmethod
    def sma(data, period):
        """Return the simple moving average of the last *period* values.

        Returns None when ``len(data) < period``.
        """
        if len(data) < period:
            return None
        return sum(data[-period:]) / period

    @staticmethod
    def rsi(prices, period=14):
        """Return the Relative Strength Index over the last *period* changes.

        Gains and losses are averaged with a plain arithmetic mean of the last
        *period* price changes (not Wilder's recursive smoothing). Returns 100
        when the window contains no losses, and None when fewer than
        ``period + 1`` prices are available.
        """
        if len(prices) < period + 1:
            return None
        # Only the last `period` changes enter the averages, so only those
        # are computed (the result is identical to diffing the whole history).
        window = prices[-(period + 1):]
        changes = [curr - prev for prev, curr in zip(window, window[1:])]
        avg_gain = sum(c for c in changes if c > 0) / period
        avg_loss = sum(-c for c in changes if c < 0) / period
        if avg_loss == 0:
            return 100
        rs = avg_gain / avg_loss
        return 100 - (100 / (1 + rs))

    @staticmethod
    def bollinger_bands(prices, period=20, std_dev=2):
        """Return ``(upper, middle, lower)`` Bollinger bands of the last *period* prices.

        The middle band is the simple mean and the width uses the population
        standard deviation. Returns ``(None, None, None)`` when there is too
        little history.
        """
        if len(prices) < period:
            return None, None, None
        middle = sum(prices[-period:]) / period
        variance = sum((p - middle) ** 2 for p in prices[-period:]) / period
        std = math.sqrt(variance)
        upper = middle + (std * std_dev)
        lower = middle - (std * std_dev)
        return upper, middle, lower

    @staticmethod
    def macd(prices, fast=12, slow=26, signal=9):
        """Return ``(macd_line, signal_line, histogram)`` for *prices*.

        Each EMA is computed over a window of exactly its own period, seeded
        with the first value of that window (a short-window approximation of
        a true EMA). ``signal_line`` and ``histogram`` are None when fewer
        than *signal* MACD values are available; all three are None when
        fewer than ``max(fast, slow)`` prices are available.
        """
        if len(prices) < max(fast, slow):
            return None, None, None

        def calc_ema(data, period):
            # Seed with the oldest value, then apply the standard EMA update.
            multiplier = 2 / (period + 1)
            ema = data[0]
            for price in data[1:]:
                ema = (price - ema) * multiplier + ema
            return ema

        # MACD value at every bar that has `slow` prices of history; the last
        # element is the current MACD line.
        macd_series = [
            calc_ema(prices[i - fast:i], fast) - calc_ema(prices[i - slow:i], slow)
            for i in range(slow, len(prices) + 1)
        ]
        macd_line = macd_series[-1]

        signal_line = None
        if len(macd_series) >= signal:
            signal_line = calc_ema(macd_series[-signal:], signal)
        # Compare against None explicitly: a signal line of exactly 0.0 is a
        # valid value and must still yield a histogram.
        histogram = macd_line - signal_line if signal_line is not None else None
        return macd_line, signal_line, histogram
# ==================== Daily trend manager ====================
class DailyTrendManager:
    """Daily trend filter used for multi-timeframe confirmation.

    Loads daily bars from a CSV file, classifies each trading day as up (1),
    down (-1) or sideways (0) relative to a moving average of closes, and
    answers whether long entries are allowed on a given date.
    """

    # Trend info reported for dates without enough history or without data.
    _NO_TREND = {'trend': 0, 'ma20': None, 'trend_strength': 0}

    def __init__(self, daily_file):
        """Load *daily_file* and compute the daily trend table immediately."""
        self.daily_data = {}    # 'YYYY-MM-DD' -> {'open', 'high', 'low', 'close', 'volume'}
        self.daily_trend = {}   # 'YYYY-MM-DD' -> trend info dict
        self._trend_dates = []  # sorted keys of daily_trend, for prior-day lookup
        self.load_daily_data(daily_file)
        self.calculate_daily_trend()

    @staticmethod
    def _parse_date(value):
        """Return 'YYYY-MM-DD' for a 'YYYY-MM-DD HH:MM:SS' or 'YYYY-MM-DD' string.

        Raises ValueError for any other format.
        """
        value = value.strip()
        for fmt in ('%Y-%m-%d %H:%M:%S', '%Y-%m-%d'):
            try:
                return datetime.strptime(value, fmt).strftime('%Y-%m-%d')
            except ValueError:
                continue
        raise ValueError(f"unrecognised date: {value!r}")

    def load_daily_data(self, filepath):
        """Load daily bars from *filepath* into ``self.daily_data``.

        Expects columns ``datetime, open, high, low, close, volume``. Rows with
        a missing column or an unparsable value are skipped and counted. If
        the file cannot be opened or decoded the error is printed and
        ``daily_data`` keeps whatever was loaded so far (best effort).
        """
        print(f"加载日线数据: {filepath}")
        skipped = 0
        try:
            with open(filepath, 'r', encoding='utf-8-sig') as f:
                for row in csv.DictReader(f):
                    try:
                        date_str = self._parse_date(row['datetime'])
                        bar = {
                            'open': float(row['open']),
                            'high': float(row['high']),
                            'low': float(row['low']),
                            'close': float(row['close']),
                            'volume': float(row['volume'])
                        }
                    except (KeyError, ValueError, TypeError, AttributeError):
                        # AttributeError/TypeError: short rows yield None values.
                        skipped += 1
                        continue
                    self.daily_data[date_str] = bar
            print(f"[OK] 加载成功: {len(self.daily_data)}条日线数据")
            if skipped:
                print(f"[WARN] 跳过{skipped}条无法解析的日线数据")
        except (OSError, csv.Error, UnicodeError) as e:
            print(f"[ERROR] 加载失败: {e}")

    def calculate_daily_trend(self, ma_period=20, band=0.02):
        """Classify every loaded day relative to its *ma_period*-day moving average.

        A day is up (1) when close > MA * (1 + band), down (-1) when
        close < MA * (1 - band), otherwise sideways (0). Days with fewer than
        *ma_period* closes of history get the neutral ``_NO_TREND`` info.
        The MA is stored under the key ``'ma20'`` for compatibility, whatever
        *ma_period* is; ``trend_strength`` is the % distance of close from MA.
        """
        print(f"计算日线趋势 (MA{ma_period})...")
        dates = sorted(self.daily_data)
        closes = [self.daily_data[d]['close'] for d in dates]
        upper, lower = 1 + band, 1 - band
        for i, (date, close) in enumerate(zip(dates, closes)):
            if i < ma_period - 1:
                self.daily_trend[date] = dict(self._NO_TREND)
                continue
            ma = sum(closes[i - ma_period + 1:i + 1]) / ma_period
            if close > ma * upper:
                trend = 1
            elif close < ma * lower:
                trend = -1
            else:
                trend = 0
            self.daily_trend[date] = {
                'trend': trend,
                'ma20': ma,
                'trend_strength': (close - ma) / ma * 100
            }
        self._trend_dates = sorted(self.daily_trend)
        print("[OK] 日线趋势计算完成")

    def get_daily_trend(self, date_str):
        """Return the trend info computed for exactly *date_str* ('YYYY-MM-DD').

        Unknown dates get a fresh copy of the neutral ``_NO_TREND`` dict.
        """
        info = self.daily_trend.get(date_str)
        return dict(self._NO_TREND) if info is None else info

    def _prior_trading_date(self, date_str):
        """Return the latest date with trend info strictly before *date_str*, or None.

        Relies on 'YYYY-MM-DD' strings sorting chronologically.
        """
        from bisect import bisect_left
        idx = bisect_left(self._trend_dates, date_str)
        return self._trend_dates[idx - 1] if idx > 0 else None

    def can_trade_long(self, date_str, require_uptrend=True, avoid_lookahead=True):
        """Return ``(allowed, trend_info)`` for opening a long on *date_str*.

        With *require_uptrend* the daily trend must be up (1); otherwise
        sideways (0) is accepted too and only a downtrend (-1) blocks.

        With *avoid_lookahead* (default) the decision uses the most recent
        trading day strictly before *date_str*: a day's own trend depends on
        its closing price, which is not yet known during that day's intraday
        bars. Pass ``avoid_lookahead=False`` to use the same-day trend.
        """
        if avoid_lookahead:
            ref_date = self._prior_trading_date(date_str)
            trend_info = (self.get_daily_trend(ref_date) if ref_date is not None
                          else dict(self._NO_TREND))
        else:
            trend_info = self.get_daily_trend(date_str)
        if require_uptrend:
            allowed = trend_info['trend'] == 1
        else:
            allowed = trend_info['trend'] >= 0
        return allowed, trend_info
# ==================== 30-minute market regime manager ====================
class MarketRegimeManager:
    """Lookup table of 30-minute market-regime classifications.

    Loads a CSV with columns ``datetime, state, prob_ranging, prob_trend,
    prob_reversal`` keyed by the ``datetime`` string. The meaning of each
    ``state`` code comes from the external model that produced the file
    (presumably 1 = trending, 2 = reversal — confirm); this class only treats
    ``state == 1`` as tradable.
    """

    # Regime reported for timestamps that are not in the file.
    _DEFAULT_REGIME = {
        'state': 0,
        'prob_ranging': 1.0,
        'prob_trend': 0.0,
        'prob_reversal': 0.0
    }

    def __init__(self, regime_file):
        """Load *regime_file* immediately."""
        self.regime_data = {}  # datetime string -> regime dict
        self.load_regime_data(regime_file)

    def load_regime_data(self, filepath):
        """Load regime rows from *filepath* into ``self.regime_data``.

        Malformed rows are skipped and counted rather than aborting the load.
        If the file cannot be opened or decoded, the error is printed and
        whatever was loaded so far is kept (best effort).
        """
        print(f"加载30分钟状态数据: {filepath}")
        skipped = 0
        try:
            with open(filepath, 'r', encoding='utf-8-sig') as f:
                for row in csv.DictReader(f):
                    try:
                        dt_str = row['datetime'].strip()
                        record = {
                            'state': int(row['state']),
                            'prob_ranging': float(row['prob_ranging']),
                            'prob_trend': float(row['prob_trend']),
                            'prob_reversal': float(row['prob_reversal'])
                        }
                    except (KeyError, ValueError, TypeError, AttributeError):
                        # AttributeError/TypeError: short rows yield None values.
                        skipped += 1
                        continue
                    self.regime_data[dt_str] = record
            print(f"[OK] 加载成功: {len(self.regime_data)}条状态数据")
            if skipped:
                print(f"[WARN] 跳过{skipped}条无法解析的状态数据")
        except (OSError, csv.Error, UnicodeError) as e:
            print(f"[ERROR] 加载失败: {e}")

    def get_regime(self, dt_str):
        """Return the regime for *dt_str*, or a fresh copy of the default regime."""
        regime = self.regime_data.get(dt_str)
        return dict(self._DEFAULT_REGIME) if regime is None else regime

    def can_open_long(self, dt_str, min_trend_prob=0.5):
        """Return ``(allowed, regime)`` for opening a long at *dt_str*.

        Allowed only when ``state == 1`` and ``prob_trend >= min_trend_prob``;
        every other state is rejected.
        """
        regime = self.get_regime(dt_str)
        allowed = regime['state'] == 1 and regime['prob_trend'] >= min_trend_prob
        return allowed, regime
# ==================== Backtest engine (parameterised) ====================
class BacktestEngine:
    """Long-only backtest engine with multi-timeframe confirmation.

    Each 30-minute bar is fed to :meth:`update`. A technical entry signal is
    acted on only when both the daily filter (``daily_manager.can_trade_long``)
    and the 30-minute regime filter (``regime_manager.can_open_long``) allow a
    long. Open positions exit on stop-loss, take-profit, a holding-period
    limit, or RSI > 75.

    Accounting: ``capital`` is not debited when a position is opened; it only
    changes by the realised PnL on close. Mark-to-market equity is therefore
    ``capital + position * (price - entry_price)``.
    """

    # Display labels for the daily trend codes (1 / 0 / -1).
    _TREND_LABELS = {1: '向上', 0: '横盘', -1: '向下'}

    def __init__(self, initial_capital=1000000, position_size=0.5,
                 min_trend_prob=0.5, require_daily_uptrend=True,
                 stop_loss_ratio=0.975, take_profit_ratio=1.04,
                 max_holding_periods=16):
        """Create an engine.

        Args:
            initial_capital: starting account value.
            position_size: fraction of current capital committed per trade.
            min_trend_prob: minimum regime trend probability to allow entry.
            require_daily_uptrend: if True the daily trend must be up; if
                False a sideways daily trend is accepted as well.
            stop_loss_ratio: exit when price <= entry_price * this ratio.
            take_profit_ratio: exit when price >= entry_price * this ratio.
            max_holding_periods: exit after this many bars in a position.
        """
        self.initial_capital = initial_capital
        self.position_size = position_size
        self.min_trend_prob = min_trend_prob
        self.require_daily_uptrend = require_daily_uptrend
        self.stop_loss_ratio = stop_loss_ratio
        self.take_profit_ratio = take_profit_ratio
        self.capital = initial_capital
        self.position = 0          # shares held; 0 when flat
        self.entry_price = 0
        self.entry_time = None
        self.holding_periods = 0
        self.max_holding_periods = max_holding_periods
        self.equity_curve = []
        self.trades = []
        self.signals = []          # technical signals rejected by a filter
        self.block_reasons = {'daily': 0, 'regime': 0}
        # Rolling windows of the last 100 bars; highs/lows are recorded but
        # not used by any current signal.
        self.prices = deque(maxlen=100)
        self.highs = deque(maxlen=100)
        self.lows = deque(maxlen=100)

    def calculate_signals(self):
        """Return a dict of indicator values for the latest bar, or None.

        Requires at least 50 closes in the rolling window.
        """
        if len(self.prices) < 50:
            return None
        price_list = list(self.prices)
        rsi = TechnicalIndicators.rsi(price_list, 14)
        bb_upper, bb_middle, bb_lower = TechnicalIndicators.bollinger_bands(price_list, 20, 2)
        ma5 = TechnicalIndicators.sma(price_list, 5)
        ma10 = TechnicalIndicators.sma(price_list, 10)
        ma20 = TechnicalIndicators.sma(price_list, 20)
        macd_line, signal_line, _histogram = TechnicalIndicators.macd(price_list)
        return {
            'rsi': rsi,
            'bb_upper': bb_upper,
            'bb_lower': bb_lower,
            'bb_middle': bb_middle,
            'ma5': ma5,
            'ma10': ma10,
            'ma20': ma20,
            'macd': macd_line,
            'macd_signal': signal_line,
            'price': price_list[-1]
        }

    def check_long_signal(self, signals):
        """Return ``(is_signal, reason)``; a long needs at least 3 of 4 conditions.

        Conditions: RSI < 65, MA5 > MA10, MACD above its signal line, and
        price above the Bollinger middle band.
        """
        if signals is None:
            return False, "指标不足"
        conditions = []
        if signals['rsi'] is not None and signals['rsi'] < 65:
            conditions.append('RSI<65')
        if (signals['ma5'] is not None and signals['ma10'] is not None and
                signals['ma5'] > signals['ma10']):
            conditions.append('MA5>MA10')
        if (signals['macd'] is not None and signals['macd_signal'] is not None and
                signals['macd'] > signals['macd_signal']):
            conditions.append('MACD金叉')
        if (signals['bb_middle'] is not None and
                signals['price'] > signals['bb_middle']):
            conditions.append('价格>中轨')
        if len(conditions) >= 3:
            return True, '+'.join(conditions)
        return False, f"条件不足({len(conditions)}/3)"

    def check_exit_signal(self, signals, current_price):
        """Return ``(should_exit, reason)`` for the open position.

        Stop-loss, take-profit and the holding-time limit are evaluated even
        when *signals* is None; only the RSI-overbought exit needs indicators.
        """
        if self.position == 0:
            return False, ""
        stop_loss = self.entry_price * self.stop_loss_ratio
        if current_price <= stop_loss:
            return True, f"止损({current_price:.2f}<={stop_loss:.2f})"
        take_profit = self.entry_price * self.take_profit_ratio
        if current_price >= take_profit:
            return True, f"止盈({current_price:.2f}>={take_profit:.2f})"
        if self.holding_periods >= self.max_holding_periods:
            return True, f"时间平仓({self.holding_periods}周期)"
        if signals is not None and signals['rsi'] is not None and signals['rsi'] > 75:
            return True, f"RSI超买({signals['rsi']:.1f})"
        return False, ""

    def current_equity(self, price):
        """Return mark-to-market equity at *price* (capital + unrealised PnL)."""
        if self.position > 0:
            return self.capital + self.position * (price - self.entry_price)
        return self.capital

    def open_position(self, price, time_str, reason):
        """Buy ``capital * position_size`` worth of shares at *price*."""
        position_value = self.capital * self.position_size
        self.position = position_value / price
        self.entry_price = price
        self.entry_time = time_str
        self.holding_periods = 0
        self.trades.append({
            'action': 'OPEN',
            'time': time_str,
            'price': price,
            'shares': self.position,
            'value': position_value,
            'reason': reason
        })

    def close_position(self, price, time_str, reason):
        """Sell the whole position at *price* and book the realised PnL."""
        if self.position == 0:
            return
        pnl = (price - self.entry_price) * self.position
        pnl_pct = (price / self.entry_price - 1) * 100
        self.capital += pnl
        self.trades.append({
            'action': 'CLOSE',
            'time': time_str,
            'price': price,
            'shares': self.position,
            'pnl': pnl,
            'pnl_pct': pnl_pct,
            'reason': reason
        })
        self.position = 0
        self.entry_price = 0
        self.holding_periods = 0

    def update(self, timestamp, open_price, high, low, close,
               daily_manager, regime_manager):
        """Process one bar and return the equity recorded for it.

        Equity is marked to market at *close* before any trade on this bar.
        Trades execute at *close*. *timestamp* is a ``datetime``.
        """
        self.prices.append(close)
        self.highs.append(high)
        self.lows.append(low)
        signals = self.calculate_signals()
        dt_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
        date_str = timestamp.strftime('%Y-%m-%d')

        # Multi-timeframe confirmation.
        daily_ok, daily_info = daily_manager.can_trade_long(
            date_str, self.require_daily_uptrend)
        regime_ok, regime_info = regime_manager.can_open_long(
            dt_str, self.min_trend_prob)

        equity = self.current_equity(close)
        self.equity_curve.append({
            'time': dt_str,
            'equity': equity,
            'close': close,
            'position': 1 if self.position > 0 else 0
        })

        if self.position > 0:
            self.holding_periods += 1
            should_exit, exit_reason = self.check_exit_signal(signals, close)
            if should_exit:
                self.close_position(close, dt_str, exit_reason)
            return equity

        tech_signal, tech_reason = self.check_long_signal(signals)
        if not tech_signal:
            return equity

        # Describe the actual daily trend (it may be sideways, not only up/down).
        trend_label = self._TREND_LABELS.get(daily_info.get('trend'), '未知')
        if daily_ok and regime_ok:
            self.open_position(
                close, dt_str,
                f"{tech_reason}|日线{trend_label}|30分钟趋势(prob={regime_info['prob_trend']:.2f})")
            return equity

        block_reason = []
        if not daily_ok:
            block_reason.append(f"日线趋势{trend_label}(强度:{daily_info['trend_strength']:.2f}%)")
            self.block_reasons['daily'] += 1
        if not regime_ok:
            block_reason.append(f"30分钟非趋势状态(state={regime_info['state']},概率={regime_info['prob_trend']:.2f})")
            self.block_reasons['regime'] += 1
        self.signals.append({
            'time': dt_str,
            'price': close,
            'tech_reason': tech_reason,
            'block_reason': '|'.join(block_reason)
        })
        return equity
# ==================== Main program ====================
def load_data(filepath):
    """Load 30-minute OHLCV bars from *filepath*, sorted chronologically.

    Expects columns ``DateTime, Open, High, Low, Close, Volume`` with
    ``DateTime`` formatted ``YYYY-MM-DD HH:MM:SS``. Rows with a missing column
    or an unparsable value are skipped and counted. A missing or unreadable
    file raises (OSError), since there is nothing to backtest without it.

    Returns:
        A list of dicts with keys ``datetime`` (datetime), ``open``, ``high``,
        ``low``, ``close``, ``volume`` (floats).
    """
    print(f"加载30分钟数据: {filepath}")
    data = []
    skipped = 0
    with open(filepath, 'r', encoding='utf-8-sig') as f:
        for row in csv.DictReader(f):
            try:
                bar = {
                    'datetime': datetime.strptime(row['DateTime'], '%Y-%m-%d %H:%M:%S'),
                    'open': float(row['Open']),
                    'high': float(row['High']),
                    'low': float(row['Low']),
                    'close': float(row['Close']),
                    'volume': float(row['Volume'])
                }
            except (KeyError, ValueError, TypeError):
                # TypeError: short rows yield None values.
                skipped += 1
                continue
            data.append(bar)
    # The backtest engine assumes chronological order; the sort is stable, so
    # already-ordered input is returned unchanged.
    data.sort(key=lambda bar: bar['datetime'])
    print(f"[OK] 加载成功: {len(data)}条")
    if skipped:
        print(f"[WARN] 跳过{skipped}条无法解析的数据")
    return data
def run_single_backtest(data, daily_manager, regime_manager, params,
                        output_dir='backtest_results'):
    """Run one backtest over *data* with the given *params*.

    Args:
        data: 30-minute bars as returned by ``load_data`` (chronological).
        daily_manager: object providing ``can_trade_long`` (e.g. DailyTrendManager).
        regime_manager: object providing ``can_open_long`` (e.g. MarketRegimeManager).
        params: dict with required keys ``min_trend_prob`` and
            ``require_daily_uptrend``; optional ``initial_capital``
            (default 1000000) and ``position_size`` (default 0.5).
        output_dir: directory created if missing (nothing is written to it here).

    Returns:
        ``(result, engine)``: a summary dict and the finished BacktestEngine.
        ``profit_factor`` is gross profit / |gross loss|; it is ``inf`` when
        there are winning trades but no losing ones, and 0 when there are no
        winning trades and no losses.
    """
    os.makedirs(output_dir, exist_ok=True)
    engine = BacktestEngine(
        initial_capital=params.get('initial_capital', 1000000),
        position_size=params.get('position_size', 0.5),
        min_trend_prob=params['min_trend_prob'],
        require_daily_uptrend=params['require_daily_uptrend']
    )
    for bar in data:
        engine.update(
            bar['datetime'],
            bar['open'],
            bar['high'],
            bar['low'],
            bar['close'],
            daily_manager,
            regime_manager
        )

    # Summary statistics. Final equity is marked to market, so a position
    # still open at the end contributes its unrealised PnL.
    initial = engine.initial_capital
    final = engine.equity_curve[-1]['equity'] if engine.equity_curve else initial
    total_return = (final / initial - 1) * 100

    closed_trades = [t for t in engine.trades if t['action'] == 'CLOSE']
    wins = [t['pnl'] for t in closed_trades if t['pnl'] > 0]
    losses = [t['pnl'] for t in closed_trades if t['pnl'] <= 0]
    win_rate = len(wins) / len(closed_trades) * 100 if closed_trades else 0
    total_profit = sum(wins)
    total_loss = sum(losses)
    if total_loss != 0:
        profit_factor = abs(total_profit / total_loss)
    elif total_profit > 0:
        # Profits with no losses: the ratio is unbounded, not zero.
        profit_factor = float('inf')
    else:
        profit_factor = 0

    result = {
        'params': params,
        'total_return': total_return,
        'trade_count': len(closed_trades),
        'win_count': len(wins),
        'loss_count': len(losses),
        'win_rate': win_rate,
        'profit_factor': profit_factor,
        'blocked_daily': engine.block_reasons['daily'],
        'blocked_regime': engine.block_reasons['regime'],
        'final_capital': final
    }
    return result, engine
def run_parameter_scan(data_file, daily_file, regime_file,
                       output_dir='optimization_results', param_grid=None,
                       top_n=5):
    """Backtest every parameter set in *param_grid* and rank them by total return.

    Args:
        data_file: 30-minute bar CSV (read by ``load_data``).
        daily_file: daily bar CSV (read by ``DailyTrendManager``).
        regime_file: 30-minute regime CSV (read by ``MarketRegimeManager``).
        output_dir: directory for ``parameter_optimization_results.json``.
        param_grid: list of param dicts for ``run_single_backtest`` (each
            needs ``min_trend_prob`` and ``require_daily_uptrend``); defaults
            to the built-in grid below.
        top_n: number of best results to print.

    Returns:
        All result dicts, sorted by ``total_return`` descending. The same list
        is written as JSON. A profit factor of ``inf`` is written as
        ``Infinity`` (Python's json default), which strict JSON parsers reject.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Load inputs once; they are shared by every backtest run.
    data = load_data(data_file)
    daily_manager = DailyTrendManager(daily_file)
    regime_manager = MarketRegimeManager(regime_file)

    if param_grid is None:
        param_grid = [
            {'min_trend_prob': 0.3, 'require_daily_uptrend': True},
            {'min_trend_prob': 0.4, 'require_daily_uptrend': True},
            {'min_trend_prob': 0.5, 'require_daily_uptrend': True},
            {'min_trend_prob': 0.6, 'require_daily_uptrend': True},
            {'min_trend_prob': 0.7, 'require_daily_uptrend': True},
            {'min_trend_prob': 0.5, 'require_daily_uptrend': False},  # also allow sideways
        ]

    print("\n" + "="*70)
    print("参数优化扫描")
    print("="*70)

    all_results = []
    for i, params in enumerate(param_grid, start=1):
        print(f"\n[{i}/{len(param_grid)}] 测试参数: {params}")
        result, _ = run_single_backtest(
            data, daily_manager, regime_manager, params, output_dir)
        all_results.append(result)
        print(f" 收益率: {result['total_return']:+.2f}%")
        print(f" 交易次数: {result['trade_count']}")
        print(f" 胜率: {result['win_rate']:.1f}%")
        print(f" 盈亏比: {result['profit_factor']:.2f}")

    all_results.sort(key=lambda r: r['total_return'], reverse=True)

    print("\n" + "="*70)
    print("参数优化结果排名")
    print("="*70)
    for rank, r in enumerate(all_results[:top_n], start=1):
        print(f"\n第{rank}名:")
        print(f" 参数: 趋势概率阈值={r['params']['min_trend_prob']}, "
              f"要求日线向上={r['params']['require_daily_uptrend']}")
        print(f" 收益率: {r['total_return']:+.2f}%")
        print(f" 交易次数: {r['trade_count']}")
        print(f" 胜率: {r['win_rate']:.1f}%")
        print(f" 盈亏比: {r['profit_factor']:.2f}")
        print(f" 被日线过滤: {r['blocked_daily']}次")
        print(f" 被30分钟过滤: {r['blocked_regime']}次")

    result_file = os.path.join(output_dir, 'parameter_optimization_results.json')
    with open(result_file, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2, ensure_ascii=False)
    print(f"\n优化结果已保存: {result_file}")
    return all_results
if __name__ == '__main__':
    # Input CSVs, resolved relative to the current working directory.
    DATA_FILE = 'cyb50_30min_2023_to_20260325.csv'
    DAILY_FILE = '../data-fetch/data/399673_SZ_day_20150101_20260325.csv'
    REGIME_FILE = '../../market-regime-identifier-30/cyb50_30min_regime_result.csv'

    results = run_parameter_scan(
        data_file=DATA_FILE,
        daily_file=DAILY_FILE,
        regime_file=REGIME_FILE,
    )
|