calibrate_execution_constraints.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. from __future__ import annotations
  2. from pathlib import Path
  3. import sys
  4. from typing import Any, Mapping
  5. ROOT = Path(__file__).resolve().parents[1]
  6. if str(ROOT) not in sys.path:
  7. sys.path.insert(0, str(ROOT))
  8. import argparse
  9. import copy
  10. import json
  11. import pandas as pd
  12. from backtest.frozen_walkforward import run_strategy_bundle
  13. from config.loader import load_config
  14. from data.io import evaluate_data_quality_gate, load_full_pit_data
  15. def _resolve_data_quality_settings(
  16. config: dict[str, Any],
  17. *,
  18. strict_cli: bool,
  19. min_coverage_cli: float | None,
  20. ) -> tuple[bool, float, list[str] | None, list[str] | None, dict[str, float]]:
  21. quality_cfg = config.get('data_quality', {})
  22. strict_mode = bool(quality_cfg.get('strict_mode_default', False)) or strict_cli
  23. default_min_coverage = float(quality_cfg.get('default_min_coverage', 0.95))
  24. if min_coverage_cli is not None:
  25. default_min_coverage = float(min_coverage_cli)
  26. critical_columns = [str(col).strip().lower() for col in quality_cfg.get('critical_columns', [])]
  27. blocking_columns = [str(col).strip().lower() for col in quality_cfg.get('blocking_columns', critical_columns)]
  28. column_min_coverage = {
  29. str(column).strip().lower(): float(value) for column, value in quality_cfg.get('column_min_coverage', {}).items()
  30. }
  31. return strict_mode, default_min_coverage, (critical_columns or None), (blocking_columns or None), column_min_coverage
  32. def _deep_merge_dict(base: Mapping[str, Any], overrides: Mapping[str, Any]) -> dict[str, Any]:
  33. out = copy.deepcopy(dict(base))
  34. for key, value in overrides.items():
  35. if isinstance(value, Mapping) and isinstance(out.get(key), Mapping):
  36. out[key] = _deep_merge_dict(dict(out[key]), value)
  37. else:
  38. out[key] = copy.deepcopy(value)
  39. return out
  40. def _parse_float_list(raw: str, *, label: str) -> list[float]:
  41. values: list[float] = []
  42. for item in raw.split(','):
  43. text = item.strip()
  44. if not text:
  45. continue
  46. values.append(float(text))
  47. if not values:
  48. raise ValueError(f'{label} must include at least one float value.')
  49. return values
  50. def _calibration_score(metrics: Mapping[str, Any]) -> float:
  51. utility = float(metrics.get('utility_total_score', 0.0))
  52. annual_return = float(metrics.get('annual_return', 0.0))
  53. upside_capture = float(metrics.get('upside_capture', 0.0))
  54. tracking_abs = float(metrics.get('tracking_diff_abs_mean', 0.0))
  55. tracking_p95 = float(metrics.get('tracking_error_20_p95', 0.0))
  56. max_drawdown = float(metrics.get('max_drawdown', 0.0))
  57. return (
  58. 0.60 * utility
  59. + 0.25 * annual_return
  60. + 0.15 * upside_capture
  61. - 0.50 * max_drawdown
  62. - 2.0 * max(0.0, tracking_p95 - 0.003)
  63. - 1.0 * max(0.0, tracking_abs - 0.001)
  64. )
  65. def main() -> None:
  66. parser = argparse.ArgumentParser(description='Calibrate execution constraint parameters on full PIT data.')
  67. parser.add_argument('--pit-csv', '--data-csv', dest='pit_csv', type=str, required=True, help='Required CSV/parquet full PIT input keyed by date.')
  68. parser.add_argument('--strict-data', action='store_true', help='Fail fast when blocking quality breaches are detected.')
  69. parser.add_argument('--min-coverage', type=float, default=None, help='Override default minimum non-null coverage ratio.')
  70. parser.add_argument('--cost-multipliers', type=str, default='1.0,1.25,1.5,1.75', help='Comma-separated extreme_day_cost_multiplier candidates.')
  71. parser.add_argument('--gap-slippage-factors', type=str, default='0.0,0.01,0.02,0.03', help='Comma-separated gap_slippage_factor candidates.')
  72. parser.add_argument('--config', type=str, default=None, help='Optional config YAML path.')
  73. parser.add_argument('--output-dir', type=str, default='outputs/execution_calibration', help='Directory for calibration artifacts.')
  74. args = parser.parse_args()
  75. output_dir = Path(args.output_dir)
  76. output_dir.mkdir(parents=True, exist_ok=True)
  77. config = load_config(args.config)
  78. raw = load_full_pit_data(args.pit_csv)
  79. strict_mode, min_coverage, critical_columns, blocking_columns, column_min_coverage = _resolve_data_quality_settings(
  80. config,
  81. strict_cli=args.strict_data,
  82. min_coverage_cli=args.min_coverage,
  83. )
  84. quality_summary = evaluate_data_quality_gate(
  85. raw,
  86. strict=strict_mode,
  87. critical_columns=critical_columns,
  88. blocking_columns=blocking_columns,
  89. default_min_coverage=min_coverage,
  90. column_min_coverage=column_min_coverage,
  91. )
  92. with (output_dir / 'data_quality_summary.json').open('w', encoding='utf-8') as fh:
  93. json.dump(quality_summary, fh, ensure_ascii=False, indent=2)
  94. if quality_summary['blocking']:
  95. failed_items = quality_summary.get('errors') or quality_summary['breaches']
  96. breached = ', '.join(item['column'] for item in failed_items)
  97. raise ValueError(f'Data quality gate failed in strict mode. Breached columns: {breached}')
  98. config.setdefault('_runtime', {})['strict_feature_gate'] = strict_mode
  99. multipliers = _parse_float_list(args.cost_multipliers, label='cost-multipliers')
  100. gap_factors = _parse_float_list(args.gap_slippage_factors, label='gap-slippage-factors')
  101. rows: list[dict[str, Any]] = []
  102. for multiplier in multipliers:
  103. for gap_factor in gap_factors:
  104. candidate_config = _deep_merge_dict(
  105. config,
  106. {
  107. 'trading': {
  108. 'extreme_day_cost_multiplier': float(multiplier),
  109. 'gap_slippage_factor': float(gap_factor),
  110. }
  111. },
  112. )
  113. _, _, metrics = run_strategy_bundle(raw, candidate_config)
  114. score = _calibration_score(metrics)
  115. rows.append(
  116. {
  117. 'extreme_day_cost_multiplier': float(multiplier),
  118. 'gap_slippage_factor': float(gap_factor),
  119. 'calibration_score': float(score),
  120. 'utility_total_score': float(metrics.get('utility_total_score', 0.0)),
  121. 'annual_return': float(metrics.get('annual_return', 0.0)),
  122. 'sharpe': float(metrics.get('sharpe', 0.0)),
  123. 'max_drawdown': float(metrics.get('max_drawdown', 0.0)),
  124. 'tracking_diff_mean': float(metrics.get('tracking_diff_mean', 0.0)),
  125. 'tracking_diff_abs_mean': float(metrics.get('tracking_diff_abs_mean', 0.0)),
  126. 'tracking_error_20_p95': float(metrics.get('tracking_error_20_p95', 0.0)),
  127. }
  128. )
  129. grid = pd.DataFrame(rows).sort_values(by='calibration_score', ascending=False).reset_index(drop=True)
  130. grid.to_csv(output_dir / 'execution_calibration_grid.csv', index=False)
  131. best = grid.iloc[0].to_dict()
  132. recommendation = {
  133. 'input': {
  134. 'pit_path': str(args.pit_csv),
  135. 'row_count': int(len(raw)),
  136. 'date_start': raw.index.min().date().isoformat() if len(raw) else None,
  137. 'date_end': raw.index.max().date().isoformat() if len(raw) else None,
  138. },
  139. 'score_formula': '0.60*utility_total_score + 0.25*annual_return + 0.15*upside_capture - 0.50*max_drawdown - 2.0*max(0, tracking_error_20_p95 - 0.003) - 1.0*max(0, tracking_diff_abs_mean - 0.001)',
  140. 'search_space': {
  141. 'cost_multipliers': multipliers,
  142. 'gap_slippage_factors': gap_factors,
  143. 'combination_count': int(len(grid)),
  144. },
  145. 'recommended': {
  146. 'extreme_day_cost_multiplier': float(best['extreme_day_cost_multiplier']),
  147. 'gap_slippage_factor': float(best['gap_slippage_factor']),
  148. 'calibration_score': float(best['calibration_score']),
  149. },
  150. 'top_candidates': grid.head(5).to_dict(orient='records'),
  151. }
  152. with (output_dir / 'execution_calibration_recommendation.json').open('w', encoding='utf-8') as fh:
  153. json.dump(recommendation, fh, ensure_ascii=False, indent=2)
  154. if __name__ == '__main__':
  155. main()