| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- CYB50 最优参数完整回测 + 数据核对
- 参数: min_trend_prob=0.3, require_daily_uptrend=True
- """
- import csv
- import json
- from datetime import datetime, timedelta
- from collections import deque
- import math
- import os
class TechnicalIndicators:
    """Stateless indicator helpers that operate on the tail of a price list.

    Each method returns None (or a tuple of Nones) when the input is too
    short for the requested period.
    """

    @staticmethod
    def sma(data, period):
        """Return the simple average of the last ``period`` values, or None."""
        if len(data) < period:
            return None
        return sum(data[-period:]) / period

    @staticmethod
    def rsi(prices, period=14):
        """Return the RSI of the last ``period`` price changes, or None.

        Gains and losses are plain arithmetic means over the window (not
        Wilder's smoothing). Returns 100 when the window contains no losses,
        which includes a completely flat window.
        """
        if len(prices) < period + 1:
            return None
        # Only the last ``period`` changes are averaged, and they depend on
        # just the last ``period + 1`` prices, so don't walk the whole history.
        window = list(prices)[-(period + 1):]
        gains, losses = [], []
        for prev, cur in zip(window, window[1:]):
            change = cur - prev
            gains.append(change if change > 0 else 0)
            losses.append(abs(change) if change < 0 else 0)
        avg_gain = sum(gains) / period
        avg_loss = sum(losses) / period
        if avg_loss == 0:
            return 100
        return 100 - (100 / (1 + avg_gain / avg_loss))

    @staticmethod
    def bollinger_bands(prices, period=20, std_dev=2):
        """Return ``(upper, middle, lower)`` over the last ``period`` prices.

        Uses the population standard deviation. Returns (None, None, None)
        when there is not enough history.
        """
        if len(prices) < period:
            return None, None, None
        window = prices[-period:]
        middle = sum(window) / period
        variance = sum((p - middle) ** 2 for p in window) / period
        std = math.sqrt(variance)
        return middle + std * std_dev, middle, middle - std * std_dev

    @staticmethod
    def macd(prices, fast=12, slow=26, signal=9):
        """Return ``(macd, signal_line, histogram)`` for the latest bar.

        NOTE: every EMA here is seeded with the first value of a fixed-length
        window. The two lines use the last ``fast``/``slow`` prices; the signal
        line uses the last ``signal`` MACD values. It does not run over the
        whole history, so values differ from a textbook MACD.

        Returns (None, None, None) with fewer than ``slow`` prices. The signal
        line and histogram stay None until ``signal`` MACD values exist.
        """
        if len(prices) < slow:
            return None, None, None

        def calc_ema(data, period):
            mult = 2 / (period + 1)
            ema = data[0]
            for p in data[1:]:
                ema = (p - ema) * mult + ema
            return ema

        n = len(prices)
        # Each MACD value depends only on its own price windows, and only the
        # last ``signal`` values feed the signal line, so older ones are skipped.
        start = max(slow, n - signal + 1)
        macd_vals = [
            calc_ema(prices[i - fast:i], fast) - calc_ema(prices[i - slow:i], slow)
            for i in range(start, n + 1)
        ]
        sig = calc_ema(macd_vals[-signal:], signal) if len(macd_vals) >= signal else None
        # Compare with None explicitly. A signal line of exactly 0.0 is valid
        # and must still yield a histogram (a truthiness test returned None).
        hist = macd_vals[-1] - sig if sig is not None else None
        return macd_vals[-1], sig, hist
class DailyTrendManager:
    """Daily bars loaded from CSV plus a per-day trend label versus a moving average.

    ``daily_data`` maps 'YYYY-MM-DD' to that day's open/high/low/close.
    ``daily_trend`` maps the same keys to ``{'trend', 'ma20', 'trend_strength'}``.
    """

    def __init__(self, daily_file):
        self.daily_data = {}
        self.daily_trend = {}
        self.load_daily_data(daily_file)
        self.calculate_daily_trend()

    def load_daily_data(self, filepath):
        """Read daily OHLC rows keyed by the date part of the 'datetime' column.

        The 'datetime' column must match '%Y-%m-%d %H:%M:%S'. Rows that are
        missing a column or fail to parse are skipped. A later row for the
        same date overwrites an earlier one.
        """
        with open(filepath, 'r', encoding='utf-8-sig') as f:
            for row in csv.DictReader(f):
                try:
                    dt = datetime.strptime(row['datetime'], '%Y-%m-%d %H:%M:%S')
                    bar = {
                        'open': float(row['open']), 'high': float(row['high']),
                        'low': float(row['low']), 'close': float(row['close']),
                    }
                except (KeyError, TypeError, ValueError):
                    # Possible causes:
                    #   KeyError   - missing column
                    #   TypeError  - short row (DictReader fills missing fields with None)
                    #   ValueError - unparsable date or number
                    # Narrowed from a bare ``except`` so KeyboardInterrupt and genuine
                    # bugs are no longer swallowed.
                    continue
                self.daily_data[dt.strftime('%Y-%m-%d')] = bar

    def calculate_daily_trend(self, ma_period=20):
        """Label each date +1 / -1 / 0 against its ``ma_period``-day SMA.

        +1 means the close is more than 2% above the SMA, -1 more than 2%
        below it, and 0 otherwise. The first ``ma_period - 1`` dates get
        trend 0 and ma20 None.

        The 'ma20' key holds the ``ma_period`` SMA whatever the period is.
        'trend_strength' is the close's distance from that SMA in percent.
        """
        dates = sorted(self.daily_data)  # 'YYYY-MM-DD' sorts chronologically
        closes = [self.daily_data[d]['close'] for d in dates]
        for i, date in enumerate(dates):
            if i < ma_period - 1:
                self.daily_trend[date] = {'trend': 0, 'ma20': None, 'trend_strength': 0}
                continue
            ma20 = sum(closes[i - ma_period + 1:i + 1]) / ma_period
            close = closes[i]
            trend = 1 if close > ma20 * 1.02 else (-1 if close < ma20 * 0.98 else 0)
            self.daily_trend[date] = {
                'trend': trend, 'ma20': ma20,
                'trend_strength': (close - ma20) / ma20 * 100,
            }

    def get_daily_trend(self, date_str):
        """Return the trend entry for 'YYYY-MM-DD'; unknown dates get a neutral default."""
        return self.daily_trend.get(date_str, {'trend': 0, 'ma20': None, 'trend_strength': 0})
class MarketRegimeManager:
    """Lookup table of per-bar market-regime labels read from a CSV file.

    Keys are the raw 'datetime' strings exactly as they appear in the file.
    Values hold the integer 'state' column and the float 'prob_trend' column.
    """

    def __init__(self, regime_file):
        self.regime_data = {}
        self.load_regime_data(regime_file)

    def load_regime_data(self, filepath):
        """Fill ``regime_data`` from ``filepath``; a malformed row raises."""
        with open(filepath, 'r', encoding='utf-8-sig') as handle:
            for record in csv.DictReader(handle):
                entry = {
                    'state': int(record['state']),
                    'prob_trend': float(record['prob_trend']),
                }
                self.regime_data[record['datetime']] = entry

    def get_regime(self, dt_str):
        """Return the entry for ``dt_str``, or a fresh state-0 / prob-0.0 default."""
        if dt_str in self.regime_data:
            return self.regime_data[dt_str]
        return {'state': 0, 'prob_trend': 0.0}
class BacktestEngine:
    """Long-only, bar-by-bar backtester driven by closing prices.

    Accounting model, as implemented by ``open``/``close``:
    - Opening a trade does not debit ``capital``. It records ``position``
      shares bought for ``capital * position_size`` at ``entry_price``.
    - Closing adds the trade's P&L to ``capital``.
    - Mark-to-market equity is therefore
      ``capital + position * (price - entry_price)``.
    """

    def __init__(self):
        self.initial_capital = 1000000
        self.position_size = 0.5          # fraction of capital committed per trade
        self.capital = self.initial_capital
        self.position = 0                 # shares held; 0 when flat
        self.entry_price = 0
        self.holding_periods = 0          # bars elapsed since the current entry
        self.max_holding_periods = 16     # time stop, in bars
        self.equity_curve = []
        self.trades = []
        self.prices = deque(maxlen=100)   # rolling window of closes for indicators

    def calculate_signals(self):
        """Return the indicator snapshot for the latest bar, or None with < 50 closes."""
        if len(self.prices) < 50:
            return None
        pl = list(self.prices)
        # MACD is the costliest indicator; compute it once for both fields.
        macd_line, macd_signal, _ = TechnicalIndicators.macd(pl)
        return {
            'rsi': TechnicalIndicators.rsi(pl),
            'bb_middle': TechnicalIndicators.bollinger_bands(pl)[1],
            'ma5': TechnicalIndicators.sma(pl, 5),
            'ma10': TechnicalIndicators.sma(pl, 10),
            'macd': macd_line,
            'macd_signal': macd_signal,
            'price': pl[-1],
        }

    def check_long_signal(self, s):
        """Return (True, joined condition names) when at least 3 of 4 conditions hold.

        Otherwise return (False, 'k/3'). Indicator values are tested with
        ``is not None`` so legitimate zeros (e.g. RSI 0, MACD signal 0.0)
        are not mistaken for missing data.
        """
        if not s:
            return False, ""
        c = []
        if s['rsi'] is not None and s['rsi'] < 65:
            c.append('RSI<65')
        if s['ma5'] is not None and s['ma10'] is not None and s['ma5'] > s['ma10']:
            c.append('MA5>MA10')
        if s['macd'] is not None and s['macd_signal'] is not None and s['macd'] > s['macd_signal']:
            c.append('MACD金叉')
        if s['bb_middle'] is not None and s['price'] > s['bb_middle']:
            c.append('价格>中轨')
        return (True, '+'.join(c)) if len(c) >= 3 else (False, f"{len(c)}/3")

    def check_exit(self, s, price):
        """Return (True, reason) if the open position should be closed at ``price``.

        Checks in order: -2.5% stop loss, +4% take profit, time stop, RSI > 75.
        Only the bar close is tested; intrabar highs/lows are not used.
        """
        if not s or self.position == 0:
            return False, ""
        if price <= self.entry_price * 0.975:
            return True, f"止损({price:.2f})"
        if price >= self.entry_price * 1.04:
            return True, f"止盈({price:.2f})"
        if self.holding_periods >= self.max_holding_periods:
            return True, "时间平仓"
        if s['rsi'] is not None and s['rsi'] > 75:
            return True, f"RSI超买({s['rsi']:.1f})"
        return False, ""

    def open(self, price, time_str, reason):
        """Buy ``capital * position_size`` worth of shares at ``price`` and log it."""
        val = self.capital * self.position_size
        self.position = val / price
        self.entry_price = price
        self.holding_periods = 0
        self.trades.append({'action': 'OPEN', 'time': time_str, 'price': price,
                            'shares': self.position, 'value': val, 'reason': reason})

    def close(self, price, time_str, reason):
        """Sell the whole position at ``price``, book the P&L into capital and log it."""
        if self.position == 0:
            return
        pnl = (price - self.entry_price) * self.position
        pnl_pct = (price / self.entry_price - 1) * 100
        self.capital += pnl
        self.trades.append({'action': 'CLOSE', 'time': time_str, 'price': price,
                            'shares': self.position, 'pnl': pnl, 'pnl_pct': pnl_pct, 'reason': reason})
        self.position = 0

    def update(self, ts, o, h, l, c, dm, rm):
        """Process one bar and return its mark-to-market equity.

        Equity is recorded before any trade decided on this bar. Only the
        close ``c`` is used; ``o``, ``h`` and ``l`` are accepted but unused.
        ``dm`` must provide ``get_daily_trend(date_str)`` and ``rm`` must
        provide ``get_regime(dt_str)``.
        """
        self.prices.append(c)
        dt_str = ts.strftime('%Y-%m-%d %H:%M:%S')
        date_str = ts.strftime('%Y-%m-%d')
        # NOTE(review): the daily trend is looked up for the bar's own date.
        # If the daily file's close for a date is that day's end-of-day close,
        # intraday bars see information from later in the same day
        # (look-ahead). Consider using the previous trading day's trend; confirm.
        daily = dm.get_daily_trend(date_str)
        regime = rm.get_regime(dt_str)
        # ``capital`` is never debited on open, so only unrealized P&L is added.
        # Adding ``position * c`` would count the committed cost twice.
        unrealized = self.position * (c - self.entry_price) if self.position > 0 else 0
        equity = self.capital + unrealized
        self.equity_curve.append({
            'time': dt_str, 'equity': equity, 'close': c,
            'position': 1 if self.position else 0,
            'daily_trend': daily['trend'], 'daily_strength': daily['trend_strength'],
            'regime_state': regime['state'], 'regime_prob': regime['prob_trend'],
        })
        if self.position > 0:
            self.holding_periods += 1
            s = self.calculate_signals()
            ex, reason = self.check_exit(s, c)
            if ex:
                self.close(c, dt_str, reason)
        else:
            s = self.calculate_signals()
            ok, tech_reason = self.check_long_signal(s)
            # Entry also requires an up daily trend and a trending 30-minute
            # regime with probability >= 0.3.
            if ok and daily['trend'] == 1 and regime['state'] == 1 and regime['prob_trend'] >= 0.3:
                self.open(c, dt_str, f"{tech_reason}|日线向上|30分钟趋势{regime['prob_trend']:.2f}")
        return equity
def load_data(fp):
    """Load OHLC bars from a CSV with DateTime/Open/High/Low/Close columns.

    'DateTime' must match '%Y-%m-%d %H:%M:%S'. Returns a list of dicts with
    keys 'datetime' (naive datetime) and 'open'/'high'/'low'/'close' (floats),
    in file order (not re-sorted). Rows that cannot be parsed are skipped.
    """
    data = []
    with open(fp, 'r', encoding='utf-8-sig') as f:
        for row in csv.DictReader(f):
            try:
                bar = {
                    'datetime': datetime.strptime(row['DateTime'], '%Y-%m-%d %H:%M:%S'),
                    'open': float(row['Open']), 'high': float(row['High']),
                    'low': float(row['Low']), 'close': float(row['Close']),
                }
            except (KeyError, TypeError, ValueError):
                # Possible causes:
                #   KeyError   - missing column
                #   TypeError  - short row (DictReader fills missing fields with None)
                #   ValueError - bad date or number
                # Narrowed from a bare ``except`` so Ctrl-C and real bugs propagate.
                continue
            data.append(bar)
    return data
def verify_data_integrity(data, dm, rm):
    """Check that every bar in ``data`` has a daily row and a regime row.

    A bar's date ('%Y-%m-%d') is looked up in ``dm.daily_data`` and its full
    timestamp ('%Y-%m-%d %H:%M:%S') in ``rm.regime_data``. Progress is
    printed every 1000 bars. A summary follows, listing at most the first
    10 problems.

    Returns:
        True if no bar was missing either entry, otherwise False.
    """
    separator = "=" * 70
    print("\n" + separator)
    print("数据准确性核对报告")
    print(separator)

    problems = []
    total = len(data)
    for count, bar in enumerate(data, start=1):
        stamp = bar['datetime']
        day_key = stamp.strftime('%Y-%m-%d')
        bar_key = stamp.strftime('%Y-%m-%d %H:%M:%S')
        # For a given bar, a missing daily row is recorded before a missing regime row.
        if day_key not in dm.daily_data:
            problems.append(f"缺少日线数据: {day_key}")
        if bar_key not in rm.regime_data:
            problems.append(f"缺少30分钟状态: {bar_key}")
        if count % 1000 == 0:
            print(f" 已核对 {count}/{total} 条数据...")

    summary = [
        "\n数据核对完成:",
        f" 总数据条数: {total}",
        f" 日线数据: {len(dm.daily_data)}条",
        f" 30分钟状态: {len(rm.regime_data)}条",
        f" 发现问题: {len(problems)}个",
    ]
    for line in summary:
        print(line)
    if problems:
        print("\n前10个问题:")
        for problem in problems[:10]:
            print(f" - {problem}")
    return not problems
- def run_backtest(data_file, daily_file, regime_file, output_dir='final_backtest'):
- os.makedirs(output_dir, exist_ok=True)
- print("加载数据...")
- data = load_data(data_file)
- dm = DailyTrendManager(daily_file)
- rm = MarketRegimeManager(regime_file)
- # 核对数据
- data_ok = verify_data_integrity(data, dm, rm)
- if not data_ok:
- print("\n[警告] 数据存在问题,但继续回测...")
- print("\n运行最优参数回测...")
- engine = BacktestEngine()
- for row in data:
- engine.update(row['datetime'], row['open'], row['high'], row['low'], row['close'], dm, rm)
- # 统计
- initial = engine.initial_capital
- final = engine.equity_curve[-1]['equity']
- total_ret = (final / initial - 1) * 100
- closed = [t for t in engine.trades if t['action'] == 'CLOSE']
- wins = [t for t in closed if t['pnl'] > 0]
- losses = [t for t in closed if t['pnl'] <= 0]
- win_rate = len(wins) / len(closed) * 100 if closed else 0
- total_profit = sum(t['pnl'] for t in wins) if wins else 0
- total_loss = sum(t['pnl'] for t in losses) if losses else 0
- profit_factor = abs(total_profit / total_loss) if total_loss else 0
- # 计算最大回撤
- peak = initial
- max_dd = 0
- for e in engine.equity_curve:
- if e['equity'] > peak:
- peak = e['equity']
- dd = (peak - e['equity']) / peak * 100
- if dd > max_dd:
- max_dd = dd
- # 保存权益曲线
- with open(f"{output_dir}/equity_final.csv", 'w', newline='') as f:
- w = csv.DictWriter(f, fieldnames=['time', 'equity', 'close', 'position', 'daily_trend', 'daily_strength', 'regime_state', 'regime_prob'])
- w.writeheader()
- w.writerows(engine.equity_curve)
- # 保存交易记录
- with open(f"{output_dir}/trades_final.csv", 'w', newline='') as f:
- if engine.trades:
- # 获取所有可能的字段
- all_fields = set()
- for t in engine.trades:
- all_fields.update(t.keys())
- fieldnames = sorted(all_fields)
- w = csv.DictWriter(f, fieldnames=fieldnames)
- w.writeheader()
- w.writerows(engine.trades)
- # 生成详细报告
- report = f"""
- ================================================================================
- CYB50 最优参数回测报告 - 详细版
- ================================================================================
- 回测参数:
- - 初始资金: 1,000,000 元
- - 持仓上限: 50%
- - 30分钟趋势概率阈值: 0.3 (最优)
- - 日线要求: 必须向上 (MA20之上)
- - 止损: -2.5% | 止盈: +4% | 最大持仓: 16周期(8小时)
- ================================================================================
- 整体表现
- ================================================================================
- 初始资金: {initial:>15,.2f} 元
- 最终资金: {final:>15,.2f} 元
- 净盈亏: {final-initial:>15,.2f} 元
- 总收益率: {total_ret:>15.2f} %
- 最大回撤: {max_dd:>15.2f} %
- ================================================================================
- 交易统计
- ================================================================================
- 总交易次数: {len(closed):>15} 笔
- 盈利次数: {len(wins):>15} 笔
- 亏损次数: {len(losses):>15} 笔
- 胜率: {win_rate:>15.2f} %
- 盈亏比: {profit_factor:>15.2f}
- 总盈利: {total_profit:>15,.2f} 元
- 总亏损: {total_loss:>15,.2f} 元
- 平均每笔盈利: {total_profit/len(wins) if wins else 0:>15,.2f} 元
- 平均每笔亏损: {total_loss/len(losses) if losses else 0:>15,.2f} 元
- ================================================================================
- 最近10笔交易明细
- ================================================================================
- """
- for t in closed[-10:]:
- report += f" {t['time']} | 平仓价: {t['price']:.2f} | 盈亏: {t['pnl']:>+10,.2f} ({t['pnl_pct']:+.2f}%) | {t['reason']}\n"
- report += f"""
- ================================================================================
- 数据核对结果
- ================================================================================
- 30分钟数据条数: {len(data)} 条
- 日线数据条数: {len(dm.daily_data)} 条
- 30分钟状态条数: {len(rm.regime_data)} 条
- 数据完整性: {'通过 ✓' if data_ok else '存在问题 ✗'}
- ================================================================================
- 文件输出
- ================================================================================
- - {output_dir}/equity_final.csv (权益曲线)
- - {output_dir}/trades_final.csv (交易明细)
- - {output_dir}/report_final.txt (本报告)
- ================================================================================
- """
- with open(f"{output_dir}/report_final.txt", 'w') as f:
- f.write(report)
- print(report)
- print(f"\n所有文件已保存到: {output_dir}/")
- return engine
if __name__ == '__main__':
    # Relative paths are resolved against the current working directory.
    bars_csv = 'cyb50_30min_2023_to_20260325.csv'
    daily_csv = '../data-fetch/data/399673_SZ_day_20150101_20260325.csv'
    regime_csv = '../../market-regime-identifier-30/cyb50_30min_regime_result.csv'
    run_backtest(bars_csv, daily_csv, regime_csv)
|