| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586 |
- from __future__ import annotations
- import unittest
- import pandas as pd
- from src.backtest.engine import BacktestConfig, run_backtest
- from src.portfolio.rebalance import generate_signal_dates
def make_backtest_input() -> pd.DataFrame:
    """Build a tiny, deterministic feature panel for the backtest tests.

    The panel covers four consecutive calendar days starting 2020-01-01 and
    four instruments. Rows are emitted date-major: all instruments for the
    first day, then all instruments for the next day, and so on.

    Constant features:
        * Every row has ``close`` 100.0, ``ma_20`` 90.0 and ``ma_60`` 80.0.
        * Each instrument uses one momentum value for every ``ret_*d``
          column. ``sse50`` has the highest (0.20).
        * ``sse50`` has volatility 0.01; all other instruments have 0.02.

    Only ``sse50`` has non-zero ``daily_return`` values: 0.50 on 2020-01-02
    and 0.10 on 2020-01-03.

    Returns:
        A 16-row DataFrame with the columns ``instrument``, ``trade_date``,
        ``close``, ``daily_return``, ``ret_5d``, ``ret_10d``, ``ret_20d``,
        ``ret_60d``, ``ma_20``, ``ma_60``, ``vol_10d`` and ``vol_20d``, in
        that order.
    """
    calendar = pd.date_range("2020-01-01", periods=4, freq="D")
    # Per-instrument daily returns, positionally aligned with ``calendar``.
    # Insertion order also fixes the instrument order within each date.
    return_paths = {
        "sse50": [0.0, 0.50, 0.10, 0.00],
        "hs300": [0.0, 0.00, 0.00, 0.00],
        "chinext50": [0.0, 0.00, 0.00, 0.00],
        "star50": [0.0, 0.00, 0.00, 0.00],
    }
    # One momentum score per instrument, reused for every lookback window.
    momentum_by_instrument = {
        "sse50": 0.20,
        "hs300": 0.10,
        "chinext50": 0.05,
        "star50": 0.02,
    }

    records: list[dict[str, object]] = []
    for day_index, trade_date in enumerate(calendar):
        for instrument, path in return_paths.items():
            score = momentum_by_instrument[instrument]
            volatility = 0.01 if instrument == "sse50" else 0.02
            records.append(
                {
                    "instrument": instrument,
                    "trade_date": trade_date,
                    "close": 100.0,
                    "daily_return": path[day_index],
                    "ret_5d": score,
                    "ret_10d": score,
                    "ret_20d": score,
                    "ret_60d": score,
                    "ma_20": 90.0,
                    "ma_60": 80.0,
                    "vol_10d": volatility,
                    "vol_20d": volatility,
                }
            )
    return pd.DataFrame(records)
class BacktestTests(unittest.TestCase):
    """Checks for signal-date generation and the ``run_backtest`` outputs."""

    def test_generate_signal_dates_supports_weekly_and_every_five_days(self) -> None:
        """Weekly and every-5-days schedules pick the expected trade dates.

        The input is twelve business days, 2020-01-01 (Wed) through
        2020-01-16 (Thu).
          * ``weekly`` is expected to yield the Fridays 2020-01-03 and
            2020-01-10. The trailing week that ends on Thursday 2020-01-16
            yields no date.
          * ``every_5_days`` is expected to yield positions 0, 5 and 10 of
            the input.
        """
        business_days = pd.date_range("2020-01-01", periods=12, freq="B")

        weekly_dates = generate_signal_dates(business_days, "weekly")
        every_five_dates = generate_signal_dates(business_days, "every_5_days")

        expected_weekly = [pd.Timestamp(day) for day in ("2020-01-03", "2020-01-10")]
        expected_every_five = [
            pd.Timestamp(day) for day in ("2020-01-01", "2020-01-08", "2020-01-15")
        ]
        self.assertEqual(weekly_dates.tolist(), expected_weekly)
        self.assertEqual(every_five_dates.tolist(), expected_every_five)

    def test_signal_date_and_execution_date_are_separated(self) -> None:
        """A signal computed on one day is executed on the next trading day.

        NAV is expected to stay at 1.0 through the execution day
        (2020-01-02). It should rise to 1.1 only on 2020-01-03, which matches
        the fixture's 0.10 sse50 return that day.

        NOTE(review): the expected values assume the single top-ranked
        instrument is sse50 (the highest-momentum row in the fixture).
        """
        panel = make_backtest_input()
        config = BacktestConfig(top_n=1, rebalance_frequency="daily")
        outcome = run_backtest(panel, config)

        nav_by_date = outcome["daily_nav"].set_index("trade_date")
        unique_rebalances = outcome["rebalances"].drop_duplicates(subset=["execution_date"])
        first_rebalance = unique_rebalances.iloc[0]

        self.assertEqual(first_rebalance["signal_date"], pd.Timestamp("2020-01-01"))
        self.assertEqual(first_rebalance["execution_date"], pd.Timestamp("2020-01-02"))

        expected_nav = (
            ("2020-01-01", 1.0),
            ("2020-01-02", 1.0),
            ("2020-01-03", 1.1),
        )
        for day, nav_value in expected_nav:
            self.assertAlmostEqual(nav_by_date.loc[pd.Timestamp(day), "nav"], nav_value)

    def test_backtest_outputs_holdings_and_basic_metrics(self) -> None:
        """``run_backtest`` returns every expected section with sane contents."""
        panel = make_backtest_input()
        config = BacktestConfig(top_n=2, rebalance_frequency="daily")
        outcome = run_backtest(panel, config)

        required_sections = (
            "summary",
            "daily_nav",
            "daily_holdings",
            "rebalances",
            "benchmark_nav",
            "benchmark_summary",
            "annual_returns",
        )
        for section in required_sections:
            self.assertIn(section, outcome)

        self.assertGreaterEqual(outcome["summary"]["rebalance_count"], 1)
        # Holdings are reported for each of the four fixture dates.
        self.assertEqual(outcome["daily_holdings"]["trade_date"].nunique(), 4)
        # Three distinct execution dates are expected for four daily signals.
        # Presumably the last signal has no following trade date; confirm
        # against the engine.
        self.assertEqual(outcome["rebalances"]["execution_date"].nunique(), 3)

        benchmark_columns = outcome["benchmark_nav"].columns
        for column in ("strategy", "equal_weight"):
            self.assertIn(column, benchmark_columns)
        self.assertIn("strategy", outcome["benchmark_summary"])
        # The whole fixture lies in 2020, so every annual row belongs to 2020.
        self.assertTrue((outcome["annual_returns"]["year"] == 2020).all())
if __name__ == "__main__":
    # Allow running this test module directly, e.g. ``python <this file>``.
    unittest.main()
|