frozen_hypothesis_validation.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. from __future__ import annotations
  2. from pathlib import Path
  3. import sys
# Put the directory two levels above this file (presumably the repository root)
# at the front of sys.path, only if it is not already there, so the project
# packages imported below (backtest, config, data) resolve when this file is
# run directly as a script.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
  7. import argparse
  8. import json
  9. from typing import Any
  10. from backtest.frozen_walkforward import (
  11. normalize_hypothesis_candidates,
  12. run_frozen_walkforward,
  13. run_strategy_bundle,
  14. )
  15. from backtest.walkforward import WindowSpec, build_expanding_windows
  16. from config.loader import load_config
  17. from data.io import (
  18. evaluate_data_quality_gate,
  19. load_full_pit_data,
  20. )
  21. def _resolve_data_quality_settings(
  22. config: dict[str, Any],
  23. *,
  24. strict_cli: bool,
  25. min_coverage_cli: float | None,
  26. ) -> tuple[bool, float, list[str] | None, list[str] | None, dict[str, float]]:
  27. quality_cfg = config.get('data_quality', {})
  28. strict_mode = bool(quality_cfg.get('strict_mode_default', False)) or strict_cli
  29. default_min_coverage = float(quality_cfg.get('default_min_coverage', 0.95))
  30. if min_coverage_cli is not None:
  31. default_min_coverage = float(min_coverage_cli)
  32. critical_columns = [str(col).strip().lower() for col in quality_cfg.get('critical_columns', [])]
  33. blocking_columns = [str(col).strip().lower() for col in quality_cfg.get('blocking_columns', critical_columns)]
  34. column_min_coverage = {
  35. str(column).strip().lower(): float(value) for column, value in quality_cfg.get('column_min_coverage', {}).items()
  36. }
  37. return strict_mode, default_min_coverage, (critical_columns or None), (blocking_columns or None), column_min_coverage
  38. def _load_candidate_payload(path: str | None) -> list[dict[str, Any]] | None:
  39. if not path:
  40. return None
  41. with Path(path).open('r', encoding='utf-8') as fh:
  42. payload = json.load(fh)
  43. if not isinstance(payload, list):
  44. raise ValueError('Candidate file must be a JSON list of candidate objects.')
  45. return payload
  46. def _resolve_frozen_settings(
  47. config: dict[str, Any],
  48. *,
  49. candidates_json: str | None,
  50. min_train_rows_cli: int | None,
  51. min_test_rows_cli: int | None,
  52. ) -> tuple[list[Any], int, int]:
  53. frozen_cfg = config.get('frozen_validation', {})
  54. raw_candidates = _load_candidate_payload(candidates_json) or frozen_cfg.get('candidates')
  55. candidates = normalize_hypothesis_candidates(raw_candidates)
  56. min_train_rows = int(frozen_cfg.get('min_train_rows', 120))
  57. min_test_rows = int(frozen_cfg.get('min_test_rows', 40))
  58. if min_train_rows_cli is not None:
  59. min_train_rows = int(min_train_rows_cli)
  60. if min_test_rows_cli is not None:
  61. min_test_rows = int(min_test_rows_cli)
  62. return candidates, min_train_rows, min_test_rows
  63. def _serialize_windows(windows: list[WindowSpec]) -> list[dict[str, str]]:
  64. return [
  65. {
  66. 'train_start': window.train_start,
  67. 'train_end': window.train_end,
  68. 'test_start': window.test_start,
  69. 'test_end': window.test_end,
  70. }
  71. for window in windows
  72. ]
  73. def _resolve_walkforward_windows(config: dict[str, Any], raw_index) -> list[WindowSpec]:
  74. frozen_cfg = config.get('frozen_validation', {})
  75. window_mode = str(frozen_cfg.get('window_mode', 'expanding')).strip().lower()
  76. if window_mode != 'expanding':
  77. raise ValueError(f'Unsupported window_mode: {window_mode}')
  78. return build_expanding_windows(
  79. raw_index,
  80. min_train_years=int(frozen_cfg.get('min_train_years', 2)),
  81. test_years=int(frozen_cfg.get('test_years', 1)),
  82. allow_partial_last_test=bool(frozen_cfg.get('allow_partial_last_test', True)),
  83. )
  84. def main() -> None:
  85. parser = argparse.ArgumentParser(description='Run frozen-hypothesis validation for the ChiNext 50 regime scaffold.')
  86. parser.add_argument(
  87. '--pit-csv',
  88. '--data-csv',
  89. dest='pit_csv',
  90. type=str,
  91. required=True,
  92. help='Required CSV/parquet full PIT input keyed by date.',
  93. )
  94. parser.add_argument(
  95. '--strict-data',
  96. action='store_true',
  97. help='Fail fast when critical input columns breach coverage thresholds.',
  98. )
  99. parser.add_argument(
  100. '--min-coverage',
  101. type=float,
  102. default=None,
  103. help='Override the default minimum non-null coverage ratio for data quality gate.',
  104. )
  105. parser.add_argument(
  106. '--candidates-json',
  107. type=str,
  108. default=None,
  109. help='Optional JSON file describing frozen-validation candidate set.',
  110. )
  111. parser.add_argument(
  112. '--min-train-rows',
  113. type=int,
  114. default=None,
  115. help='Override minimum required rows for each training window.',
  116. )
  117. parser.add_argument(
  118. '--min-test-rows',
  119. type=int,
  120. default=None,
  121. help='Override minimum required rows for each test window.',
  122. )
  123. parser.add_argument('--config', type=str, default=None, help='Optional config YAML path.')
  124. parser.add_argument('--output-dir', type=str, default='outputs/frozen_validation', help='Directory for validation artifacts.')
  125. args = parser.parse_args()
  126. output_dir = Path(args.output_dir)
  127. output_dir.mkdir(parents=True, exist_ok=True)
  128. config = load_config(args.config)
  129. raw = load_full_pit_data(args.pit_csv)
  130. strict_mode, min_coverage, critical_columns, blocking_columns, column_min_coverage = _resolve_data_quality_settings(
  131. config,
  132. strict_cli=args.strict_data,
  133. min_coverage_cli=args.min_coverage,
  134. )
  135. quality_summary = evaluate_data_quality_gate(
  136. raw,
  137. strict=strict_mode,
  138. critical_columns=critical_columns,
  139. blocking_columns=blocking_columns,
  140. default_min_coverage=min_coverage,
  141. column_min_coverage=column_min_coverage,
  142. )
  143. with (output_dir / 'data_quality_summary.json').open('w', encoding='utf-8') as fh:
  144. json.dump(quality_summary, fh, ensure_ascii=False, indent=2)
  145. if quality_summary['blocking']:
  146. failed_items = quality_summary.get('errors') or quality_summary['breaches']
  147. breached = ', '.join(item['column'] for item in failed_items)
  148. raise ValueError(f'Data quality gate failed in strict mode. Breached columns: {breached}')
  149. config.setdefault('_runtime', {})['strict_feature_gate'] = strict_mode
  150. candidates, min_train_rows, min_test_rows = _resolve_frozen_settings(
  151. config,
  152. candidates_json=args.candidates_json,
  153. min_train_rows_cli=args.min_train_rows,
  154. min_test_rows_cli=args.min_test_rows,
  155. )
  156. windows = _resolve_walkforward_windows(config, raw.index)
  157. board, frozen_summary = run_frozen_walkforward(
  158. raw=raw,
  159. config=config,
  160. windows=windows,
  161. candidates=candidates,
  162. min_train_rows=min_train_rows,
  163. min_test_rows=min_test_rows,
  164. )
  165. _, _, full_metrics = run_strategy_bundle(raw, config)
  166. summary = {
  167. 'window_count': int(frozen_summary['total_windows']),
  168. 'processed_window_count': int(frozen_summary['processed_window_count']),
  169. 'skipped_window_count': int(frozen_summary['skipped_window_count']),
  170. 'positive_window_ratio': float(frozen_summary['positive_window_ratio']),
  171. 'selected_candidate_distribution': dict(frozen_summary['selected_candidate_distribution']),
  172. 'window_status_counts': dict(frozen_summary['window_status_counts']),
  173. 'selection_mode_distribution': dict(frozen_summary.get('selection_mode_distribution', {})),
  174. 'windows_with_hard_pass_candidate_count': int(frozen_summary.get('windows_with_hard_pass_candidate_count', 0)),
  175. 'windows_without_hard_pass_candidate_count': int(
  176. frozen_summary.get('windows_without_hard_pass_candidate_count', 0)
  177. ),
  178. 'hard_pass_window_ratio': float(frozen_summary.get('hard_pass_window_ratio', 0.0)),
  179. 'candidate_selection': dict(frozen_summary.get('candidate_selection', {})),
  180. 'candidate_ids': list(frozen_summary['candidate_ids']),
  181. 'min_train_rows': int(frozen_summary['min_train_rows']),
  182. 'min_test_rows': int(frozen_summary['min_test_rows']),
  183. 'windows': _serialize_windows(windows),
  184. 'full_sample_metrics': full_metrics,
  185. }
  186. board.to_csv(output_dir / 'frozen_validation_board.csv', index=False)
  187. with (output_dir / 'frozen_validation_summary.json').open('w', encoding='utf-8') as fh:
  188. json.dump(summary, fh, ensure_ascii=False, indent=2)
  189. if __name__ == '__main__':
  190. main()