test_compare.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. from __future__ import annotations
  2. import json
  3. import tempfile
  4. import unittest
  5. from pathlib import Path
  6. from src.backtest.compare import build_rows, render_markdown_table
  7. class CompareTests(unittest.TestCase):
  8. def test_build_rows_includes_momentum_profile_cost_label_and_relative_columns(self) -> None:
  9. temp_dir = tempfile.TemporaryDirectory()
  10. self.addCleanup(temp_dir.cleanup)
  11. root = Path(temp_dir.name)
  12. backtests_root = root / "outputs" / "backtests"
  13. result_dir = backtests_root / "demo_strategy"
  14. result_dir.mkdir(parents=True, exist_ok=True)
  15. (result_dir / "summary.json").write_text(
  16. json.dumps(
  17. {
  18. "cumulative_return": 0.20,
  19. "annual_return": 0.05,
  20. "max_drawdown": -0.10,
  21. "annual_volatility": 0.20,
  22. "sharpe": 0.25,
  23. "calmar": 0.50,
  24. "turnover": 12.0,
  25. "rebalance_count": 8,
  26. "cash_days_ratio": 0.10,
  27. }
  28. ),
  29. encoding="utf-8",
  30. )
  31. (result_dir / "benchmark_summary.json").write_text(
  32. json.dumps(
  33. {
  34. "equal_weight": {"cumulative_return": 0.15},
  35. "hs300": {"cumulative_return": 0.08},
  36. "chinext50": {"cumulative_return": 0.30},
  37. }
  38. ),
  39. encoding="utf-8",
  40. )
  41. frame = build_rows(
  42. backtests_root=backtests_root,
  43. strategy_configs={
  44. "demo_strategy": {
  45. "name": "demo_strategy",
  46. "top_n": 1,
  47. "rebalance_frequency": "every_5_days",
  48. "risk_penalty_multiplier": 0.5,
  49. "commission_bps": 7.5,
  50. "slippage_bps": 7.5,
  51. "momentum_weights": {
  52. "ret_5d": 0.25,
  53. "ret_10d": 0.25,
  54. "ret_20d": 0.30,
  55. "ret_60d": 0.20,
  56. },
  57. }
  58. },
  59. cost_scenarios={10.0: "optimistic", 15.0: "base", 20.0: "conservative"},
  60. )
  61. self.assertEqual(len(frame.index), 1)
  62. row = frame.iloc[0]
  63. self.assertEqual(row["cost_scenario"], "base")
  64. self.assertEqual(row["momentum_profile"], "5d=0.25, 10d=0.25, 20d=0.30, 60d=0.20")
  65. self.assertAlmostEqual(row["vs_equal_weight_cum"], 0.05)
  66. self.assertAlmostEqual(row["vs_hs300_cum"], 0.12)
  67. self.assertAlmostEqual(row["vs_chinext50_cum"], -0.10)
  68. markdown = render_markdown_table(frame)
  69. self.assertIn("Top1 Every 5 Days P05 Base 15bp Momentum Experiment", markdown)
  70. self.assertIn("5d=0.25, 10d=0.25, 20d=0.30, 60d=0.20", markdown)
  71. if __name__ == "__main__":
  72. unittest.main()