cyb50_parameter_optimization.py 51 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102
  1. import pandas as pd
  2. import numpy as np
  3. import itertools
  4. import json
  5. import time
  6. from datetime import datetime, timedelta
  7. from typing import Dict, List, Tuple, Any
  8. from concurrent.futures import ThreadPoolExecutor, as_completed
  9. from threading import Lock
  10. import warnings
  11. import sys
  12. import io
  13. warnings.filterwarnings('ignore')
  14. # 导入原有策略模块
  15. from cyb50_30min_dual_direction import (
  16. ConfigManager, IntradayDataFetcher, DualDirectionSignalGenerator,
  17. DualDirectionExecutor, validate_dual_direction_results
  18. )
  19. class ParameterOptimizer:
  20. """参数优化器 - 网格搜索最优参数组合"""
  21. def __init__(self, config_file='config.json', max_workers=4):
  22. self.config_manager = ConfigManager(config_file)
  23. self.optimization_results = []
  24. self.best_params = None
  25. self.best_score = None
  26. self.max_workers = max_workers
  27. self.results_lock = Lock() # 线程锁,保护共享数据
  28. def define_parameter_grid(self) -> Dict[str, List]:
  29. """定义参数搜索网格 - 精简优化版本"""
  30. parameter_grid = {
  31. # 核心交易参数(重点优化)
  32. 'position_size_pct': [0.7, 0.85, 1.0],
  33. 'stop_loss_pct': [0.006, 0.008, 0.010, 0.012],
  34. 'take_profit_pct': [0.012, 0.015, 0.018, 0.022],
  35. 'max_hold_bars': [14, 16, 18, 22],
  36. 'min_signal_strength': [3, 4, 5],
  37. # RSI参数(精简)
  38. 'rsi_oversold': [25, 30],
  39. 'rsi_overbought': [70, 75],
  40. # KDJ参数(精简)
  41. 'kdj_oversold': [20],
  42. 'kdj_overbought': [80],
  43. # 成交量参数(精简)
  44. 'volume_ratio_threshold': [1.2],
  45. # 布林带参数(精简)
  46. 'bb_upper_threshold': [0.995],
  47. 'bb_lower_threshold': [1.005],
  48. # MACD参数(固定)
  49. 'macd_fast': [12],
  50. 'macd_slow': [26],
  51. # 连续涨跌参数(精简)
  52. 'consecutive_bars': [4],
  53. 'consecutive_change_pct': [0.015]
  54. }
  55. return parameter_grid
  56. def generate_parameter_combinations(self, parameter_grid: Dict[str, List]) -> List[Dict[str, Any]]:
  57. """生成所有参数组合"""
  58. param_names = list(parameter_grid.keys())
  59. param_values = list(parameter_grid.values())
  60. total_combinations = 1
  61. for values in param_values:
  62. total_combinations *= len(values)
  63. print(f"参数网格总组合数: {total_combinations:,}")
  64. print(f"参数维度: {len(param_names)}")
  65. print("\n参数搜索范围:")
  66. for name, values in parameter_grid.items():
  67. print(f" {name}: {values}")
  68. # 生成所有组合
  69. combinations = []
  70. for combination in itertools.product(*param_values):
  71. params = dict(zip(param_names, combination))
  72. combinations.append(params)
  73. return combinations
  74. def calculate_performance_score(self, trades_df: pd.DataFrame, results_df: pd.DataFrame,
  75. initial_capital: float) -> Dict[str, float]:
  76. """计算综合性能得分"""
  77. if len(trades_df) == 0:
  78. return {
  79. 'total_return': -100,
  80. 'sharpe_ratio': -999,
  81. 'max_drawdown': -100,
  82. 'win_rate': 0,
  83. 'profit_factor': 0,
  84. 'total_trades': 0,
  85. 'composite_score': -9999
  86. }
  87. # 基础指标
  88. final_capital = results_df['net_value'].iloc[-1]
  89. total_return = (final_capital - initial_capital) / initial_capital * 100
  90. # 夏普比率计算
  91. returns = results_df['net_value'].pct_change().dropna()
  92. sharpe_ratio = returns.mean() / returns.std() * np.sqrt(252 * 16) if returns.std() > 0 else 0
  93. # 最大回撤
  94. cumulative_returns = (1 + returns).cumprod()
  95. running_max = cumulative_returns.expanding().max()
  96. drawdown = (cumulative_returns - running_max) / running_max
  97. max_drawdown = drawdown.min() * 100
  98. # 胜率
  99. winning_trades = trades_df[trades_df['盈亏金额'] > 0]
  100. losing_trades = trades_df[trades_df['盈亏金额'] < 0]
  101. win_rate = len(winning_trades) / len(trades_df) * 100 if len(trades_df) > 0 else 0
  102. # 盈亏比
  103. avg_win = winning_trades['盈亏金额'].mean() if len(winning_trades) > 0 else 0
  104. avg_loss = abs(losing_trades['盈亏金额'].mean()) if len(losing_trades) > 0 else 1
  105. profit_factor = avg_win / avg_loss if avg_loss > 0 else 0
  106. # 综合得分计算 (可自定义权重)
  107. composite_score = (
  108. total_return * 0.3 + # 30%权重收益率
  109. sharpe_ratio * 10 + # 夏普比率权重
  110. max_drawdown * 0.2 + # 20%权重回撤控制
  111. win_rate * 0.2 + # 20%权重胜率
  112. profit_factor * 5 # 盈亏比权重
  113. )
  114. return {
  115. 'total_return': total_return,
  116. 'sharpe_ratio': sharpe_ratio,
  117. 'max_drawdown': max_drawdown,
  118. 'win_rate': win_rate,
  119. 'profit_factor': profit_factor,
  120. 'total_trades': len(trades_df),
  121. 'composite_score': composite_score
  122. }
  123. def run_single_backtest(self, params: Dict[str, Any], backtest_data: pd.DataFrame,
  124. initial_capital: float) -> Tuple[bool, Dict, Any]:
  125. """运行单次回测 - 静默模式(使用预加载数据)"""
  126. try:
  127. # 创建自定义信号生成器和执行器
  128. signal_gen = CustomSignalGenerator(params, silent_mode=True)
  129. executor = CustomExecutor(params, initial_capital, silent_mode=True)
  130. if len(backtest_data) < 50:
  131. return False, {'error': '数据不足'}, None, None
  132. # 生成信号
  133. signals_df = signal_gen.generate_dual_direction_signals(backtest_data)
  134. # 执行交易
  135. results_df, trades_df = executor.execute_dual_direction_trades(signals_df)
  136. # 计算性能指标
  137. performance = self.calculate_performance_score(trades_df, results_df, initial_capital)
  138. return True, performance, trades_df, results_df
  139. except Exception as e:
  140. return False, {'error': str(e)}, None, None
  141. def run_single_backtest_thread(self, params: Dict[str, Any], backtest_data: pd.DataFrame,
  142. initial_capital: float) -> Tuple[bool, Dict, Any]:
  143. """线程安全的单次回测(使用预加载数据)"""
  144. return self.run_single_backtest(params, backtest_data, initial_capital)
  145. def run_optimization(self, max_iterations: int = None, time_limit: int = 3600):
  146. """运行参数优化 - 多线程版本"""
  147. print("=" * 80)
  148. print("创业板50策略参数优化 - 多线程网格搜索")
  149. print("=" * 80)
  150. # 获取回测配置
  151. BACKTEST_START_DATE = self.config_manager.get('strategy', 'backtest_start_date', "2025-10-01")
  152. backtest_end_config = self.config_manager.get('strategy', 'backtest_end_date', "now")
  153. BACKTEST_END_DATE = datetime.now().strftime('%Y-%m-%d') if backtest_end_config.lower() == "now" else backtest_end_config
  154. INITIAL_CAPITAL = self.config_manager.get('strategy', 'initial_capital', 1000000)
  155. start_date = datetime.strptime(BACKTEST_START_DATE, "%Y-%m-%d")
  156. end_date = datetime.strptime(BACKTEST_END_DATE, "%Y-%m-%d").replace(hour=23, minute=59, second=59)
  157. print(f"回测期间: {BACKTEST_START_DATE} 至 {BACKTEST_END_DATE}")
  158. print(f"初始资金: {INITIAL_CAPITAL:,}元")
  159. print(f"并发线程数: {self.max_workers}")
  160. # 预加载数据(只加载一次)
  161. print(f"\n预加载回测数据...")
  162. # 重定向标准输出来控制数据加载时的混乱输出
  163. old_stdout = sys.stdout
  164. sys.stdout = io.StringIO()
  165. try:
  166. fetcher = IntradayDataFetcher(self.config_manager)
  167. prewamp_days = self.config_manager.get('strategy', 'prewamp_days', 30)
  168. data_start_date = start_date - timedelta(days=prewamp_days)
  169. full_data = fetcher.fetch_30min_data(start_date=data_start_date, end_date=end_date)
  170. full_data = fetcher.calculate_intraday_indicators(full_data)
  171. backtest_data = full_data[(full_data.index >= start_date) & (full_data.index <= end_date)].copy()
  172. finally:
  173. # 恢复标准输出
  174. sys.stdout = old_stdout
  175. print(f"数据预加载完成: {len(backtest_data)}条数据")
  176. print(f"数据范围: {backtest_data.index[0]} 到 {backtest_data.index[-1]}")
  177. # 生成参数组合
  178. parameter_grid = self.define_parameter_grid()
  179. all_combinations = self.generate_parameter_combinations(parameter_grid)
  180. # 限制迭代次数
  181. if max_iterations and max_iterations < len(all_combinations):
  182. import random
  183. all_combinations = random.sample(all_combinations, max_iterations)
  184. print(f"随机选择 {max_iterations} 个组合进行测试")
  185. # 运行优化
  186. start_time = time.time()
  187. successful_tests = 0
  188. failed_tests = 0
  189. last_progress_time = start_time
  190. print(f"\n开始多线程网格搜索...")
  191. print(f"时间限制: {time_limit}秒 ({time_limit/3600:.1f}小时)" if time_limit else "无时间限制")
  192. print(f"预期总时间: {len(all_combinations)/self.max_workers * 2:.0f}-{len(all_combinations)/self.max_workers * 5:.0f}分钟")
  193. # 使用线程池执行
  194. with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
  195. # 提交所有任务
  196. future_to_params = {}
  197. for i, params in enumerate(all_combinations):
  198. future = executor.submit(
  199. self.run_single_backtest_thread,
  200. params, backtest_data, INITIAL_CAPITAL
  201. )
  202. future_to_params[future] = (i, params)
  203. # 控制提交速度,避免内存占用过大
  204. if len(future_to_params) >= 100: # 最多同时处理100个任务
  205. # 等待一些任务完成
  206. completed_futures = []
  207. for future in as_completed(list(future_to_params.keys())):
  208. completed_futures.append(future)
  209. if len(completed_futures) >= 50: # 完成一半后继续提交
  210. break
  211. # 处理完成的任务
  212. for future in completed_futures:
  213. i, params = future_to_params[future]
  214. try:
  215. success, performance, trades_df, results_df = future.result()
  216. with self.results_lock: # 使用锁保护共享数据
  217. if success:
  218. successful_tests += 1
  219. result = {
  220. 'params': params.copy(),
  221. 'performance': performance,
  222. 'trades_count': len(trades_df) if trades_df is not None else 0
  223. }
  224. self.optimization_results.append(result)
  225. # 更新最优参数
  226. if self.best_score is None or performance['composite_score'] > self.best_score:
  227. self.best_score = performance['composite_score']
  228. self.best_params = params.copy()
  229. self.best_performance = performance
  230. print(f" 🎯 发现更优参数! 综合得分: {performance['composite_score']:.2f}")
  231. else:
  232. failed_tests += 1
  233. # 增加进度打印频率
  234. total_tests = successful_tests + failed_tests
  235. current_time = time.time()
  236. if total_tests % 5 == 0 or (current_time - last_progress_time) > 30: # 每5个测试或每30秒打印一次
  237. elapsed = current_time - start_time
  238. progress = total_tests / len(all_combinations) * 100
  239. avg_time_per_test = elapsed / total_tests
  240. remaining_tests = len(all_combinations) - total_tests
  241. eta_seconds = remaining_tests * avg_time_per_test / self.max_workers
  242. print(f"进度: {total_tests}/{len(all_combinations)} ({progress:.1f}%) | "
  243. f"成功: {successful_tests} 失败: {failed_tests} | "
  244. f"已用时: {elapsed/60:.1f}分钟 | "
  245. f"预计剩余: {eta_seconds/60:.1f}分钟")
  246. last_progress_time = current_time
  247. # 检查时间限制
  248. if time_limit and (current_time - start_time) > time_limit:
  249. print(f"\n⏰ 达到时间限制 {time_limit/3600:.1f}小时,停止优化")
  250. executor.shutdown(wait=False)
  251. break
  252. except Exception as e:
  253. with self.results_lock:
  254. failed_tests += 1
  255. print(f"回测异常: {e}")
  256. # 移除已处理的future
  257. del future_to_params[future]
  258. # 处理剩余的任务
  259. for future in as_completed(future_to_params.keys()):
  260. i, params = future_to_params[future]
  261. try:
  262. success, performance, trades_df, results_df = future.result()
  263. with self.results_lock:
  264. if success:
  265. successful_tests += 1
  266. result = {
  267. 'params': params.copy(),
  268. 'performance': performance,
  269. 'trades_count': len(trades_df) if trades_df is not None else 0
  270. }
  271. self.optimization_results.append(result)
  272. if self.best_score is None or performance['composite_score'] > self.best_score:
  273. self.best_score = performance['composite_score']
  274. self.best_params = params.copy()
  275. self.best_performance = performance
  276. print(f" 🎯 发现更优参数! 综合得分: {performance['composite_score']:.2f}")
  277. else:
  278. failed_tests += 1
  279. # 增加进度打印
  280. total_tests = successful_tests + failed_tests
  281. current_time = time.time()
  282. if total_tests % 5 == 0 or (current_time - last_progress_time) > 30:
  283. elapsed = current_time - start_time
  284. progress = total_tests / len(all_combinations) * 100
  285. avg_time_per_test = elapsed / total_tests
  286. remaining_tests = len(all_combinations) - total_tests
  287. eta_seconds = remaining_tests * avg_time_per_test / self.max_workers
  288. print(f"进度: {total_tests}/{len(all_combinations)} ({progress:.1f}%) | "
  289. f"成功: {successful_tests} 失败: {failed_tests} | "
  290. f"已用时: {elapsed/60:.1f}分钟 | "
  291. f"预计剩余: {eta_seconds/60:.1f}分钟")
  292. last_progress_time = current_time
  293. except Exception as e:
  294. with self.results_lock:
  295. failed_tests += 1
  296. print(f"回测异常: {e}")
  297. elapsed_time = time.time() - start_time
  298. # 输出优化结果
  299. print(f"\n优化完成!")
  300. print(f"总测试次数: {len(all_combinations)}")
  301. print(f"成功测试: {successful_tests}")
  302. print(f"失败测试: {failed_tests}")
  303. print(f"总用时: {elapsed_time:.1f}秒")
  304. print(f"平均每次测试: {elapsed_time/len(all_combinations):.2f}秒")
  305. print(f"线程并发效果: 理论加速比 {self.max_workers}x,实际加速比 {elapsed_time/(elapsed_time/self.max_workers):.1f}x")
  306. # 分析结果
  307. self.analyze_results()
  308. def analyze_results(self):
  309. """分析优化结果"""
  310. if not self.optimization_results:
  311. print("没有成功的优化结果")
  312. return
  313. print("\n" + "=" * 80)
  314. print("优化结果分析")
  315. print("=" * 80)
  316. # 转换为DataFrame便于分析
  317. results_data = []
  318. for result in self.optimization_results:
  319. row = result['params'].copy()
  320. row.update(result['performance'])
  321. results_data.append(row)
  322. results_df = pd.DataFrame(results_data)
  323. # 排序找到最优参数
  324. top_results = results_df.nlargest(10, 'composite_score')
  325. print(f"\n🏆 TOP 10 最优参数组合:")
  326. print("-" * 80)
  327. for i, (_, row) in enumerate(top_results.iterrows(), 1):
  328. print(f"\n第{i}名 (综合得分: {row['composite_score']:.2f})")
  329. print(f" 收益率: {row['total_return']:.2f}%")
  330. print(f" 夏普比率: {row['sharpe_ratio']:.2f}")
  331. print(f" 最大回撤: {row['max_drawdown']:.2f}%")
  332. print(f" 胜率: {row['win_rate']:.1f}%")
  333. print(f" 盈亏比: {row['profit_factor']:.2f}")
  334. print(f" 交易次数: {int(row['total_trades'])}")
  335. print(f" 关键参数:")
  336. print(f" 仓位比例: {row['position_size_pct']:.1f}")
  337. print(f" 止损: {row['stop_loss_pct']*100:.1f}%")
  338. print(f" 止盈: {row['take_profit_pct']*100:.1f}%")
  339. print(f" 最大持仓: {int(row['max_hold_bars'])}周期")
  340. print(f" 信号强度: {int(row['min_signal_strength'])}")
  341. print(f" RSI超卖/超买: {int(row['rsi_oversold'])}/{int(row['rsi_overbought'])}")
  342. print(f" KDJ超卖/超买: {int(row['kdj_oversold'])}/{int(row['kdj_overbought'])}")
  343. # 参数敏感性分析
  344. print(f"\n📊 参数敏感性分析:")
  345. print("-" * 80)
  346. key_params = ['position_size_pct', 'stop_loss_pct', 'take_profit_pct', 'max_hold_bars']
  347. for param in key_params:
  348. if param in results_df.columns:
  349. param_analysis = results_df.groupby(param)['composite_score'].agg(['mean', 'std', 'min', 'max'])
  350. print(f"\n{param}:")
  351. for value, stats in param_analysis.iterrows():
  352. print(f" {value}: 平均得分 {stats['mean']:.2f} (±{stats['std']:.2f})")
  353. # 保存结果
  354. self.save_optimization_results(results_df)
  355. def save_optimization_results(self, results_df: pd.DataFrame):
  356. """保存优化结果"""
  357. timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
  358. output_file = f'cyb50_optimization_results_{timestamp}.csv'
  359. # 按综合得分排序
  360. results_df = results_df.sort_values('composite_score', ascending=False)
  361. results_df.to_csv(output_file, index=False, encoding='utf-8-sig')
  362. print(f"\n💾 优化结果已保存到: {output_file}")
  363. # 保存最优参数到配置文件
  364. if self.best_params:
  365. best_params_file = 'best_parameters.json'
  366. with open(best_params_file, 'w', encoding='utf-8') as f:
  367. json.dump({
  368. 'best_params': self.best_params,
  369. 'best_performance': self.best_performance,
  370. 'optimization_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
  371. }, f, indent=2, ensure_ascii=False)
  372. print(f"💾 最优参数已保存到: {best_params_file}")
  373. class CustomSignalGenerator(DualDirectionSignalGenerator):
  374. """自定义信号生成器 - 支持动态参数"""
  375. def __init__(self, params: Dict[str, Any], silent_mode=False):
  376. super().__init__()
  377. self.params = params
  378. self.silent_mode = silent_mode
  379. def _calculate_long_signals(self, current_bar, df, i):
  380. """计算做多信号强度 - 使用自定义参数"""
  381. long_score = 0
  382. long_signals = []
  383. # RSI超卖做多
  384. rsi_oversold = self.params.get('rsi_oversold', 30)
  385. if current_bar['RSI'] < rsi_oversold:
  386. long_score += 2
  387. long_signals.append(f"RSI超卖(<{rsi_oversold})")
  388. elif current_bar['RSI'] < rsi_oversold + 5:
  389. long_score += 1
  390. long_signals.append("RSI偏弱")
  391. # KDJ超卖做多
  392. kdj_oversold = self.params.get('kdj_oversold', 20)
  393. if current_bar['K'] < kdj_oversold and current_bar['D'] < kdj_oversold:
  394. long_score += 2
  395. long_signals.append(f"KDJ超卖(<{kdj_oversold})")
  396. elif current_bar['J'] < 0:
  397. long_score += 1
  398. long_signals.append("KDJ极端超卖")
  399. # MACD金叉
  400. macd_fast = self.params.get('macd_fast', 12)
  401. macd_slow = self.params.get('macd_slow', 26)
  402. if current_bar['MACD_hist'] > 0 and df.iloc[i-1]['MACD_hist'] <= 0:
  403. long_score += 2
  404. long_signals.append("MACD金叉")
  405. elif current_bar['MACD_hist'] > df.iloc[i-1]['MACD_hist']:
  406. long_score += 1
  407. long_signals.append("MACD改善")
  408. # 价格触及布林带下轨
  409. bb_lower_threshold = self.params.get('bb_lower_threshold', 1.005)
  410. if current_bar['Close'] <= current_bar['BB_lower'] * bb_lower_threshold:
  411. long_score += 2
  412. long_signals.append("触及下轨")
  413. elif current_bar['Close'] <= current_bar['BB_lower'] * (bb_lower_threshold + 0.005):
  414. long_score += 1
  415. long_signals.append("接近下轨")
  416. # 连续下跌后的反转
  417. consecutive_bars = self.params.get('consecutive_bars', 4)
  418. consecutive_change_pct = self.params.get('consecutive_change_pct', 0.015)
  419. recent_returns = df.iloc[i-6:i]['Returns']
  420. if recent_returns.min() < -consecutive_change_pct:
  421. consecutive_decline = sum(recent_returns < 0)
  422. if consecutive_decline >= consecutive_bars:
  423. long_score += 2
  424. long_signals.append("连续下跌反转")
  425. # 成交量配合
  426. volume_ratio_threshold = self.params.get('volume_ratio_threshold', 1.2)
  427. if current_bar['Volume_Ratio'] > volume_ratio_threshold:
  428. long_score += 1
  429. long_signals.append("放量配合")
  430. # 当日开盘价格关系
  431. try:
  432. daily_high = df[df.index.date == df.index[i].date()]['High'].max()
  433. daily_low = df[df.index.date == df.index[i].date()]['Low'].min()
  434. daily_range = daily_high - daily_low
  435. if daily_range > 0:
  436. position_in_day = (current_bar['Close'] - daily_low) / daily_range
  437. if position_in_day < 0.3:
  438. long_score += 1
  439. long_signals.append("日内低位")
  440. except:
  441. pass
  442. # MA趋势过滤
  443. if current_bar['MA6'] < current_bar['MA12'] < current_bar['MA24']:
  444. long_score -= 1
  445. long_signals.append("MA下降趋势惩罚")
  446. elif current_bar['MA6'] > current_bar['MA12']:
  447. long_score += 1
  448. long_signals.append("MA短期上行")
  449. return long_score, long_signals
  450. def _calculate_short_signals(self, current_bar, df, i):
  451. """计算做空信号强度 - 使用自定义参数"""
  452. short_score = 0
  453. short_signals = []
  454. # RSI超买做空
  455. rsi_overbought = self.params.get('rsi_overbought', 70)
  456. if current_bar['RSI'] > rsi_overbought:
  457. short_score += 2
  458. short_signals.append(f"RSI超买(>{rsi_overbought})")
  459. elif current_bar['RSI'] > rsi_overbought - 5:
  460. short_score += 1
  461. short_signals.append("RSI偏强")
  462. # KDJ超买做空
  463. kdj_overbought = self.params.get('kdj_overbought', 80)
  464. if current_bar['K'] > kdj_overbought and current_bar['D'] > kdj_overbought:
  465. short_score += 2
  466. short_signals.append(f"KDJ超买(>{kdj_overbought})")
  467. elif current_bar['J'] > 100:
  468. short_score += 1
  469. short_signals.append("KDJ极端超买")
  470. # MACD死叉
  471. if current_bar['MACD_hist'] < 0 and df.iloc[i-1]['MACD_hist'] >= 0:
  472. short_score += 2
  473. short_signals.append("MACD死叉")
  474. elif current_bar['MACD_hist'] < df.iloc[i-1]['MACD_hist']:
  475. short_score += 1
  476. short_signals.append("MACD恶化")
  477. # 价格触及布林带上轨
  478. bb_upper_threshold = self.params.get('bb_upper_threshold', 0.995)
  479. if current_bar['Close'] >= current_bar['BB_upper'] * bb_upper_threshold:
  480. short_score += 2
  481. short_signals.append("触及上轨")
  482. elif current_bar['Close'] >= current_bar['BB_upper'] * (bb_upper_threshold - 0.005):
  483. short_score += 1
  484. short_signals.append("接近上轨")
  485. # 连续上涨后的反转
  486. consecutive_bars = self.params.get('consecutive_bars', 4)
  487. consecutive_change_pct = self.params.get('consecutive_change_pct', 0.015)
  488. recent_returns = df.iloc[i-6:i]['Returns']
  489. if recent_returns.max() > consecutive_change_pct:
  490. consecutive_rise = sum(recent_returns > 0)
  491. if consecutive_rise >= consecutive_bars:
  492. short_score += 2
  493. short_signals.append("连续上涨反转")
  494. # 成交量配合
  495. volume_ratio_threshold = self.params.get('volume_ratio_threshold', 1.2)
  496. if current_bar['Volume_Ratio'] > volume_ratio_threshold:
  497. short_score += 1
  498. short_signals.append("放量配合")
  499. # 当日开盘价格关系
  500. try:
  501. daily_high = df[df.index.date == df.index[i].date()]['High'].max()
  502. daily_low = df[df.index.date == df.index[i].date()]['Low'].min()
  503. daily_range = daily_high - daily_low
  504. if daily_range > 0:
  505. position_in_day = (current_bar['Close'] - daily_low) / daily_range
  506. if position_in_day > 0.7:
  507. short_score += 1
  508. short_signals.append("日内高位")
  509. except:
  510. pass
  511. # MA趋势过滤
  512. if current_bar['MA6'] > current_bar['MA12'] > current_bar['MA24']:
  513. short_score -= 1
  514. short_signals.append("MA上升趋势惩罚")
  515. elif current_bar['MA6'] < current_bar['MA12']:
  516. short_score += 1
  517. short_signals.append("MA短期下行")
  518. return short_score, short_signals
  519. def generate_dual_direction_signals(self, data: pd.DataFrame) -> pd.DataFrame:
  520. """生成多空双向信号 - 静默模式"""
  521. if not self.silent_mode:
  522. print("正在生成多空双向信号...")
  523. signals = []
  524. df = data.copy()
  525. for i in range(24, len(df)): # 至少需要12小时(24个30分钟)的历史数据
  526. current_bar = df.iloc[i]
  527. current_time = df.index[i]
  528. # 跳过不适合交易的时间段
  529. if hasattr(current_time, 'hour'): # 有小时信息的30分钟数据
  530. hour = current_time.hour
  531. if hour < 9 or hour > 15: # 只在交易时间内
  532. continue
  533. # 生成基础信号数据
  534. signal = {
  535. 'DateTime': str(current_time),
  536. 'Open': current_bar['Open'],
  537. 'High': current_bar['High'],
  538. 'Low': current_bar['Low'],
  539. 'Close': current_bar['Close'],
  540. 'Volume': current_bar['Volume'],
  541. 'RSI': current_bar['RSI'],
  542. 'MACD': current_bar['MACD'],
  543. 'MACD_hist': current_bar['MACD_hist'],
  544. 'K': current_bar['K'],
  545. 'D': current_bar['D'],
  546. 'J': current_bar['J'],
  547. 'ATR_Pct': current_bar['ATR_Pct'],
  548. 'Volume_Ratio': current_bar['Volume_Ratio'],
  549. 'Price_Momentum': current_bar['Price_Momentum'],
  550. 'Close_Open_Pct': current_bar['Close_Open_Pct']
  551. }
  552. # 计算做多信号强度
  553. long_score, long_signals = self._calculate_long_signals(current_bar, df, i)
  554. # 计算做空信号强度
  555. short_score, short_signals = self._calculate_short_signals(current_bar, df, i)
  556. # 设置信号分数和描述
  557. signal['Long_Score'] = long_score
  558. signal['Long_Signals'] = ', '.join(long_signals) if long_signals else ''
  559. signal['Short_Score'] = short_score
  560. signal['Short_Signals'] = ', '.join(short_signals) if short_signals else ''
  561. # 决定最终信号方向和强度
  562. final_signal = 0
  563. signal_type = ''
  564. # 信号优先级和冲突处理
  565. if long_score >= 4 and short_score >= 4:
  566. # 两个方向都达到阈值,选择信号强度更高的
  567. if long_score > short_score:
  568. final_signal = 1
  569. signal_type = f'做多翻转(强度{long_score} vs {short_score})'
  570. self.long_signal_count += 1
  571. elif short_score > long_score:
  572. final_signal = -1
  573. signal_type = f'做空反转(强度{short_score} vs {long_score})'
  574. self.short_signal_count += 1
  575. else:
  576. # 强度相等时,根据当前价格位置决定
  577. bb_position = (current_bar['Close'] - current_bar['BB_lower']) / (current_bar['BB_upper'] - current_bar['BB_lower'])
  578. if bb_position < 0.3: # 偏向下轨,优先做多
  579. final_signal = 1
  580. signal_type = f'做多翻转(位置优先)'
  581. self.long_signal_count += 1
  582. elif bb_position > 0.7: # 偏向上轨,优先做空
  583. final_signal = -1
  584. signal_type = f'做空反转(位置优先)'
  585. self.short_signal_count += 1
  586. else:
  587. # 中间位置,暂不开仓
  588. final_signal = 0
  589. signal_type = '信号冲突(强度相等)'
  590. elif long_score >= 4:
  591. final_signal = 1
  592. signal_type = '做多翻转'
  593. self.long_signal_count += 1
  594. elif short_score >= 4:
  595. final_signal = -1
  596. signal_type = '做空反转'
  597. self.short_signal_count += 1
  598. self.total_signal_count = self.long_signal_count + self.short_signal_count
  599. signal['Signal'] = final_signal
  600. signal['Signal_Type'] = signal_type
  601. signals.append(signal)
  602. signals_df = pd.DataFrame(signals)
  603. if len(signals_df) > 0:
  604. signals_df.set_index('DateTime', inplace=True)
  605. if not self.silent_mode:
  606. print(f"多空双向信号生成完成")
  607. print(f"做多信号: {self.long_signal_count}个")
  608. print(f"做空信号: {self.short_signal_count}个")
  609. print(f"总信号: {self.total_signal_count}个")
  610. if len(signals_df) > 0:
  611. print(f"信号密度: {self.total_signal_count/len(signals_df)*100:.2f}%")
  612. return signals_df
  613. class CustomExecutor(DualDirectionExecutor):
  614. """自定义交易执行器 - 支持动态参数"""
  615. def __init__(self, params: Dict[str, Any], initial_capital: float = 1000000, silent_mode=False):
  616. super().__init__(initial_capital)
  617. # 更新参数
  618. self.params['position_size_pct'] = params.get('position_size_pct', 1.0)
  619. self.params['stop_loss_pct'] = params.get('stop_loss_pct', 0.008)
  620. self.params['take_profit_pct'] = params.get('take_profit_pct', 0.02)
  621. self.params['max_hold_bars'] = params.get('max_hold_bars', 16)
  622. self.params['min_signal_strength'] = params.get('min_signal_strength', 4)
  623. self.silent_mode = silent_mode
  624. def execute_dual_direction_trades(self, signals_df: pd.DataFrame) -> tuple:
  625. """执行多空双向交易 - 静默模式"""
  626. if not self.silent_mode:
  627. print("正在执行多空双向交易...")
  628. # 复制原有逻辑,但移除详细打印
  629. df = signals_df.copy()
  630. # 初始化
  631. trades = []
  632. capital = self.initial_capital
  633. # 持仓状态
  634. long_position = 0 # 做多持仓数量
  635. short_position = 0 # 做空持仓数量
  636. long_entry_price = 0 # 做多开仓价
  637. short_entry_price = 0 # 做空开仓价
  638. long_entry_time = None # 做多开仓时间
  639. short_entry_time = None # 做空开仓时间
  640. long_holding_bars = 0 # 做多持仓周期
  641. short_holding_bars = 0 # 做空持仓周期
  642. long_entry_signals = '' # 做多入场信号
  643. short_entry_signals = '' # 做空入场信号
  644. # 添加资金列
  645. df = df.copy()
  646. df['capital'] = capital
  647. df['long_position'] = 0
  648. df['short_position'] = 0
  649. df['net_value'] = capital
  650. for i in range(len(df)):
  651. current_time = df.index[i]
  652. current_bar = df.iloc[i]
  653. price = current_bar['Close']
  654. # 更新当前净值
  655. current_value = capital
  656. if long_position > 0:
  657. current_value += long_position * price
  658. if short_position < 0:
  659. # 做空盈亏
  660. short_pnl = (short_entry_price - price) * abs(short_position)
  661. margin_held = abs(short_position) * short_entry_price
  662. current_value += margin_held + short_pnl
  663. df.iloc[i, df.columns.get_loc('net_value')] = current_value
  664. # 开仓逻辑 - 只在无持仓时开仓
  665. if long_position == 0 and short_position == 0:
  666. # 做多开仓
  667. if current_bar['Signal'] == 1:
  668. position_size = int((capital * self.params['position_size_pct']) / price)
  669. if position_size > 0:
  670. cost = position_size * price * (1 + self.params['commission_rate'] + self.params['slippage_rate'])
  671. if cost <= capital:
  672. long_position = position_size
  673. long_entry_price = price
  674. long_entry_time = current_time
  675. long_entry_signals = current_bar.get('Long_Signals', '')
  676. long_holding_bars = 0
  677. capital -= cost
  678. # 计算预计止损止盈价格
  679. long_stop_loss_price = long_entry_price * (1 - self.params['stop_loss_pct'])
  680. long_take_profit_price = long_entry_price * (1 + self.params['take_profit_pct'])
  681. df.iloc[i, df.columns.get_loc('long_position')] = long_position
  682. # 做空开仓
  683. elif current_bar['Signal'] == -1:
  684. position_value = capital * self.params['position_size_pct']
  685. position_size = int(position_value / price)
  686. if position_size > 0:
  687. margin_required = position_size * price
  688. commission = position_size * price * (self.params['commission_rate'] + self.params['slippage_rate'])
  689. total_cost = margin_required + commission
  690. if total_cost <= capital:
  691. short_position = -position_size
  692. short_entry_price = price
  693. short_entry_time = current_time
  694. short_entry_signals = current_bar.get('Short_Signals', '')
  695. short_holding_bars = 0
  696. capital -= total_cost
  697. # 计算预计止损止盈价格
  698. short_stop_loss_price = short_entry_price * (1 + self.params['stop_loss_pct'])
  699. short_take_profit_price = short_entry_price * (1 - self.params['take_profit_pct'])
  700. df.iloc[i, df.columns.get_loc('short_position')] = short_position
  701. # 平仓逻辑 - 做多平仓
  702. elif long_position > 0:
  703. long_holding_bars += 1
  704. # 计算止损止盈价格
  705. stop_loss = long_entry_price * (1 - self.params['stop_loss_pct'])
  706. take_profit = long_entry_price * (1 + self.params['take_profit_pct'])
  707. exit_signal = False
  708. exit_reason = ''
  709. exit_price = price
  710. # 止损
  711. if price <= stop_loss:
  712. exit_signal = True
  713. loss_pct = (long_entry_price - stop_loss) / long_entry_price * 100
  714. exit_reason = f"做多止损触发(价格{price:.2f}跌破止损线{stop_loss:.2f},亏损{loss_pct:.2f}%)"
  715. exit_price = price
  716. # 止盈
  717. elif price >= take_profit:
  718. exit_signal = True
  719. profit_pct = (price - long_entry_price) / long_entry_price * 100
  720. exit_reason = f"做多止盈触发(价格{price:.2f}突破止盈线{take_profit:.2f},盈利{profit_pct:.2f}%)"
  721. exit_price = price
  722. # 最大持仓时间
  723. elif long_holding_bars >= self.params['max_hold_bars']:
  724. exit_signal = True
  725. current_pnl_pct = (price - long_entry_price) / long_entry_price * 100
  726. exit_reason = f"做多时间止损(持仓{long_holding_bars}周期达上限{self.params['max_hold_bars']}周期,当前盈亏{current_pnl_pct:+.2f}%)"
  727. # 做多信号消失
  728. elif current_bar['RSI'] > 70:
  729. exit_signal = True
  730. current_pnl_pct = (price - long_entry_price) / long_entry_price * 100
  731. exit_reason = f"做多RSI超买平仓(RSI={current_bar['RSI']:.1f}超买,信号消失,当前盈亏{current_pnl_pct:+.2f}%)"
  732. # 执行平仓
  733. if exit_signal:
  734. # 计算盈亏
  735. gross_pnl = (exit_price - long_entry_price) * long_position
  736. open_cost = long_position * long_entry_price * (self.params['commission_rate'] + self.params['slippage_rate'])
  737. close_revenue = long_position * exit_price
  738. close_cost = close_revenue * (self.params['commission_rate'] + self.params['slippage_rate'])
  739. pnl = gross_pnl - open_cost - close_cost
  740. # 更新资金
  741. capital += close_revenue - close_cost
  742. # 记录交易
  743. trade = {
  744. '交易方向': '做多',
  745. '开仓时间': long_entry_time,
  746. '平仓时间': current_time,
  747. '开仓价格': long_entry_price,
  748. '平仓价格': exit_price,
  749. '仓位': long_position,
  750. '盈亏金额': pnl,
  751. '盈亏百分比': (exit_price - long_entry_price) / long_entry_price * 100,
  752. '退出原因': exit_reason,
  753. '持仓周期数': long_holding_bars,
  754. '持仓小时数': long_holding_bars * 0.5,
  755. '入场信号': long_entry_signals,
  756. '平仓时资金': capital,
  757. '开仓市值': long_position * long_entry_price,
  758. '预计止损价格': long_stop_loss_price,
  759. '预计止盈价格': long_take_profit_price
  760. }
  761. trades.append(trade)
  762. # 重置做多持仓
  763. long_position = 0
  764. long_entry_price = 0
  765. long_entry_time = None
  766. long_holding_bars = 0
  767. # 平仓逻辑 - 做空平仓
  768. elif short_position < 0:
  769. short_holding_bars += 1
  770. # 计算止损止盈价格(做空逻辑相反)
  771. stop_loss_price = short_entry_price * (1 + self.params['stop_loss_pct']) # 价格上涨止损
  772. take_profit_price = short_entry_price * (1 - self.params['take_profit_pct']) # 价格下跌止盈
  773. exit_signal = False
  774. exit_reason = ''
  775. exit_price = price
  776. # 止损(价格上涨)
  777. if price >= stop_loss_price:
  778. exit_signal = True
  779. loss_pct = (stop_loss_price - short_entry_price) / short_entry_price * 100
  780. exit_reason = f"做空止损触发(价格{price:.2f}突破止损线{stop_loss_price:.2f},亏损{loss_pct:.2f}%)"
  781. exit_price = price
  782. # 止盈(价格下跌)
  783. elif price <= take_profit_price:
  784. exit_signal = True
  785. profit_pct = (short_entry_price - price) / short_entry_price * 100
  786. exit_reason = f"做空止盈触发(价格{price:.2f}跌破止盈线{take_profit_price:.2f},盈利{profit_pct:.2f}%)"
  787. exit_price = price
  788. # 最大持仓时间
  789. elif short_holding_bars >= self.params['max_hold_bars']:
  790. exit_signal = True
  791. current_pnl_pct = (short_entry_price - price) / short_entry_price * 100
  792. exit_reason = f"做空时间止损(持仓{short_holding_bars}周期达上限{self.params['max_hold_bars']}周期,当前盈亏{current_pnl_pct:+.2f}%)"
  793. # 做空信号消失
  794. elif current_bar['RSI'] < 30:
  795. exit_signal = True
  796. current_pnl_pct = (short_entry_price - price) / short_entry_price * 100
  797. exit_reason = f"做空RSI超卖平仓(RSI={current_bar['RSI']:.1f}超卖,信号消失,当前盈亏{current_pnl_pct:+.2f}%)"
  798. # 执行平仓
  799. if exit_signal:
  800. # 计算盈亏
  801. gross_pnl = (short_entry_price - exit_price) * abs(short_position)
  802. open_commission = abs(short_position) * short_entry_price * (self.params['commission_rate'] + self.params['slippage_rate'])
  803. close_commission = abs(short_position) * exit_price * (self.params['commission_rate'] + self.params['slippage_rate'])
  804. net_pnl = gross_pnl - close_commission
  805. # 更新资金(返还保证金 + 净盈亏)
  806. margin_returned = abs(short_position) * short_entry_price
  807. capital += margin_returned + net_pnl
  808. # 记录交易
  809. trade = {
  810. '交易方向': '做空',
  811. '开仓时间': short_entry_time,
  812. '平仓时间': current_time,
  813. '开仓价格': short_entry_price,
  814. '平仓价格': exit_price,
  815. '仓位': abs(short_position),
  816. '盈亏金额': net_pnl,
  817. '盈亏百分比': (short_entry_price - exit_price) / short_entry_price * 100,
  818. '退出原因': exit_reason,
  819. '持仓周期数': short_holding_bars,
  820. '持仓小时数': short_holding_bars * 0.5,
  821. '入场信号': short_entry_signals,
  822. '平仓时资金': capital,
  823. '开仓市值': abs(short_position) * short_entry_price,
  824. '保证金返还': margin_returned,
  825. '预计止损价格': short_stop_loss_price,
  826. '预计止盈价格': short_take_profit_price
  827. }
  828. trades.append(trade)
  829. # 重置做空持仓
  830. short_position = 0
  831. short_entry_price = 0
  832. short_entry_time = None
  833. short_holding_bars = 0
  834. # 更新资金和持仓状态
  835. df.iloc[i, df.columns.get_loc('capital')] = capital
  836. df.iloc[i, df.columns.get_loc('long_position')] = long_position
  837. df.iloc[i, df.columns.get_loc('short_position')] = short_position
  838. # 强制平仓剩余持仓 - 做多
  839. if long_position > 0:
  840. final_price = df.iloc[-1]['Close']
  841. gross_pnl = (final_price - long_entry_price) * long_position
  842. open_cost = long_position * long_entry_price * (self.params['commission_rate'] + self.params['slippage_rate'])
  843. close_revenue = long_position * final_price
  844. close_cost = close_revenue * (self.params['commission_rate'] + self.params['slippage_rate'])
  845. pnl = gross_pnl - open_cost - close_cost
  846. capital += close_revenue - close_cost
  847. trade = {
  848. '交易方向': '做多',
  849. '开仓时间': long_entry_time,
  850. '平仓时间': df.index[-1],
  851. '开仓价格': long_entry_price,
  852. '平仓价格': final_price,
  853. '仓位': long_position,
  854. '盈亏金额': pnl,
  855. '盈亏百分比': (final_price - long_entry_price) / long_entry_price * 100,
  856. '退出原因': f'做多强制平仓(回测结束,持仓{long_holding_bars}周期,最终价格{final_price:.2f},盈亏{(final_price - long_entry_price) / long_entry_price * 100:+.2f}%)',
  857. '持仓周期数': long_holding_bars,
  858. '持仓小时数': long_holding_bars * 0.5,
  859. '入场信号': long_entry_signals,
  860. '平仓时资金': capital,
  861. '开仓市值': long_position * long_entry_price,
  862. '预计止损价格': long_stop_loss_price,
  863. '预计止盈价格': long_take_profit_price
  864. }
  865. trades.append(trade)
  866. # 做空持仓强制平仓
  867. if short_position < 0:
  868. final_price = df.iloc[-1]['Close']
  869. gross_pnl = (short_entry_price - final_price) * abs(short_position)
  870. close_commission = abs(short_position) * final_price * (self.params['commission_rate'] + self.params['slippage_rate'])
  871. net_pnl = gross_pnl - close_commission
  872. margin_returned = abs(short_position) * short_entry_price
  873. capital += margin_returned + net_pnl
  874. trade = {
  875. '交易方向': '做空',
  876. '开仓时间': short_entry_time,
  877. '平仓时间': df.index[-1],
  878. '开仓价格': short_entry_price,
  879. '平仓价格': final_price,
  880. '仓位': abs(short_position),
  881. '盈亏金额': net_pnl,
  882. '盈亏百分比': (short_entry_price - final_price) / short_entry_price * 100,
  883. '退出原因': f'做空强制平仓(回测结束,持仓{short_holding_bars}周期,最终价格{final_price:.2f},盈亏{(short_entry_price - final_price) / short_entry_price * 100:+.2f}%)',
  884. '持仓周期数': short_holding_bars,
  885. '持仓小时数': short_holding_bars * 0.5,
  886. '入场信号': short_entry_signals,
  887. '平仓时资金': capital,
  888. '开仓市值': abs(short_position) * short_entry_price,
  889. '保证金返还': margin_returned,
  890. '预计止损价格': short_stop_loss_price,
  891. '预计止盈价格': short_take_profit_price
  892. }
  893. trades.append(trade)
  894. trades_df = pd.DataFrame(trades)
  895. if len(trades_df) > 0:
  896. # 统一时间格式
  897. for col in trades_df.columns:
  898. if '时间' in col:
  899. trades_df[col] = pd.to_datetime(trades_df[col])
  900. trades_df = trades_df.sort_values('开仓时间')
  901. if not self.silent_mode:
  902. print(f"多空双向交易执行完成,共{len(trades_df)}笔交易")
  903. return df, trades_df
  904. def main():
  905. """主程序 - 运行参数优化"""
  906. # 创建优化器 - 设置4个并发线程
  907. optimizer = ParameterOptimizer('config.json', max_workers=4)
  908. # 运行优化 (可设置最大迭代次数和时间限制)
  909. # max_iterations: 限制测试次数,None表示测试所有组合
  910. # time_limit: 时间限制(秒),None表示无限制
  911. optimizer.run_optimization(
  912. max_iterations=None, # 测试所有组合 (现在只有1,152个组合)
  913. time_limit=86400 # 24小时 (24*60*60=86400秒)
  914. )
  915. print(f"\n参数优化完成!")
  916. if __name__ == "__main__":
  917. main()