test_phase2_backtest.py 3.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. from __future__ import annotations
  2. import unittest
  3. import pandas as pd
  4. from src.backtest.engine import BacktestConfig, run_backtest
  5. from src.portfolio.rebalance import generate_signal_dates
  6. def make_backtest_input() -> pd.DataFrame:
  7. dates = pd.date_range("2020-01-01", periods=4, freq="D")
  8. rows: list[dict[str, object]] = []
  9. daily_returns = {
  10. "sse50": [0.0, 0.50, 0.10, 0.00],
  11. "hs300": [0.0, 0.00, 0.00, 0.00],
  12. "chinext50": [0.0, 0.00, 0.00, 0.00],
  13. "star50": [0.0, 0.00, 0.00, 0.00],
  14. }
  15. momentum = {"sse50": 0.20, "hs300": 0.10, "chinext50": 0.05, "star50": 0.02}
  16. for trade_date in dates:
  17. for instrument in ["sse50", "hs300", "chinext50", "star50"]:
  18. rows.append(
  19. {
  20. "instrument": instrument,
  21. "trade_date": trade_date,
  22. "close": 100.0,
  23. "daily_return": daily_returns[instrument][(trade_date - dates[0]).days],
  24. "ret_5d": momentum[instrument],
  25. "ret_10d": momentum[instrument],
  26. "ret_20d": momentum[instrument],
  27. "ret_60d": momentum[instrument],
  28. "ma_20": 90.0,
  29. "ma_60": 80.0,
  30. "vol_10d": 0.01 if instrument == "sse50" else 0.02,
  31. "vol_20d": 0.01 if instrument == "sse50" else 0.02,
  32. }
  33. )
  34. return pd.DataFrame(rows)
  35. class BacktestTests(unittest.TestCase):
  36. def test_generate_signal_dates_supports_weekly_and_every_five_days(self) -> None:
  37. trade_dates = pd.date_range("2020-01-01", periods=12, freq="B")
  38. weekly = generate_signal_dates(trade_dates, "weekly")
  39. every_five = generate_signal_dates(trade_dates, "every_5_days")
  40. self.assertEqual(weekly.tolist(), [pd.Timestamp("2020-01-03"), pd.Timestamp("2020-01-10")])
  41. self.assertEqual(every_five.tolist(), [pd.Timestamp("2020-01-01"), pd.Timestamp("2020-01-08"), pd.Timestamp("2020-01-15")])
  42. def test_signal_date_and_execution_date_are_separated(self) -> None:
  43. result = run_backtest(
  44. make_backtest_input(),
  45. BacktestConfig(top_n=1, rebalance_frequency="daily"),
  46. )
  47. nav = result["daily_nav"].set_index("trade_date")
  48. rebalances = result["rebalances"].drop_duplicates(subset=["execution_date"])
  49. self.assertEqual(rebalances.iloc[0]["signal_date"], pd.Timestamp("2020-01-01"))
  50. self.assertEqual(rebalances.iloc[0]["execution_date"], pd.Timestamp("2020-01-02"))
  51. self.assertAlmostEqual(nav.loc[pd.Timestamp("2020-01-01"), "nav"], 1.0)
  52. self.assertAlmostEqual(nav.loc[pd.Timestamp("2020-01-02"), "nav"], 1.0)
  53. self.assertAlmostEqual(nav.loc[pd.Timestamp("2020-01-03"), "nav"], 1.1)
  54. def test_backtest_outputs_holdings_and_basic_metrics(self) -> None:
  55. result = run_backtest(
  56. make_backtest_input(),
  57. BacktestConfig(top_n=2, rebalance_frequency="daily"),
  58. )
  59. self.assertIn("summary", result)
  60. self.assertIn("daily_nav", result)
  61. self.assertIn("daily_holdings", result)
  62. self.assertIn("rebalances", result)
  63. self.assertIn("benchmark_nav", result)
  64. self.assertIn("benchmark_summary", result)
  65. self.assertIn("annual_returns", result)
  66. self.assertGreaterEqual(result["summary"]["rebalance_count"], 1)
  67. self.assertEqual(result["daily_holdings"]["trade_date"].nunique(), 4)
  68. self.assertEqual(result["rebalances"]["execution_date"].nunique(), 3)
  69. self.assertIn("strategy", result["benchmark_nav"].columns)
  70. self.assertIn("equal_weight", result["benchmark_nav"].columns)
  71. self.assertIn("strategy", result["benchmark_summary"])
  72. self.assertTrue((result["annual_returns"]["year"] == 2020).all())
  73. if __name__ == "__main__":
  74. unittest.main()