test_compare.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. from __future__ import annotations
  2. import json
  3. import tempfile
  4. import unittest
  5. from pathlib import Path
  6. from src.backtest.compare import build_rows
  7. class CompareTests(unittest.TestCase):
  8. def test_build_rows_includes_cost_label_and_relative_columns(self) -> None:
  9. temp_dir = tempfile.TemporaryDirectory()
  10. self.addCleanup(temp_dir.cleanup)
  11. root = Path(temp_dir.name)
  12. backtests_root = root / "outputs" / "backtests"
  13. result_dir = backtests_root / "demo_strategy"
  14. result_dir.mkdir(parents=True, exist_ok=True)
  15. (result_dir / "summary.json").write_text(
  16. json.dumps(
  17. {
  18. "cumulative_return": 0.20,
  19. "annual_return": 0.05,
  20. "max_drawdown": -0.10,
  21. "annual_volatility": 0.20,
  22. "sharpe": 0.25,
  23. "calmar": 0.50,
  24. "turnover": 12.0,
  25. "rebalance_count": 8,
  26. "cash_days_ratio": 0.10,
  27. }
  28. ),
  29. encoding="utf-8",
  30. )
  31. (result_dir / "benchmark_summary.json").write_text(
  32. json.dumps(
  33. {
  34. "equal_weight": {"cumulative_return": 0.15},
  35. "hs300": {"cumulative_return": 0.08},
  36. "chinext50": {"cumulative_return": 0.30},
  37. }
  38. ),
  39. encoding="utf-8",
  40. )
  41. frame = build_rows(
  42. backtests_root=backtests_root,
  43. strategy_configs={
  44. "demo_strategy": {
  45. "name": "demo_strategy",
  46. "top_n": 1,
  47. "rebalance_frequency": "every_5_days",
  48. "risk_penalty_multiplier": 0.5,
  49. "commission_bps": 5.0,
  50. "slippage_bps": 5.0,
  51. }
  52. },
  53. cost_scenarios={10.0: "optimistic", 15.0: "base", 20.0: "conservative"},
  54. )
  55. self.assertEqual(len(frame.index), 1)
  56. row = frame.iloc[0]
  57. self.assertEqual(row["cost_scenario"], "optimistic")
  58. self.assertAlmostEqual(row["vs_equal_weight_cum"], 0.05)
  59. self.assertAlmostEqual(row["vs_hs300_cum"], 0.12)
  60. self.assertAlmostEqual(row["vs_chinext50_cum"], -0.10)
  61. if __name__ == "__main__":
  62. unittest.main()