Pārlūkot izejas kodu

Standardize research cost scenarios and comparison table

erwin 1 mēnesi atpakaļ
vecāks
revīzija
32f0875c22

+ 31 - 0
index-rotation/README.md

@@ -126,6 +126,22 @@ final_score = score_mom - 0.30 * score_risk_penalty
 
 这能明确避免未来函数和时点错配。测试已覆盖该约束。
 
+## 研究成本假设标准
+
+当前统一使用三档成本场景:
+
+- `optimistic` = **10bp 总成本**
+- `base` = **15bp 总成本**(默认研究基准)
+- `conservative` = **20bp 总成本**
+
+配置文件位于:
+
+```text
+configs/research/cost_scenarios.yaml
+```
+
+后续比较策略时,优先看 `base`,并同时检查 `optimistic` / `conservative` 两侧的稳定性。
+
 ## 回测输出
 
 回测至少输出:
@@ -176,6 +192,7 @@ Top1 每 5 个交易日(主候选 + 成本敏感性配置):
 
 ```bash
 python3 -m src.backtest.run --config configs/strategy/top1_every_5_days_p05_cost10bp.yaml
+python3 -m src.backtest.run --config configs/strategy/top1_every_5_days_p05_cost15bp.yaml
 python3 -m src.backtest.run --config configs/strategy/top1_every_5_days_p05_cost20bp.yaml
 ```
 
@@ -185,6 +202,19 @@ python3 -m src.backtest.run --config configs/strategy/top1_every_5_days_p05_cost
 outputs/backtests/<config_name>/
 ```
 
+生成统一对比表:
+
+```bash
+python3 -m src.backtest.compare
+```
+
+默认会输出:
+
+```text
+outputs/research/strategy_comparison.csv
+outputs/research/strategy_comparison.md
+```
+
 例如:
 
 - `outputs/backtests/top2_weekly/summary.json`
@@ -249,6 +279,7 @@ python3 -m unittest discover -s tests -v
 - `Top1 / Top2 / 空仓` 分配
 - `t` 日信号、`t+1` 执行的时点约束
 - 回测净值、持仓、调仓记录的基础正确性
+- 统一比较表中的成本标签与相对收益字段
 
 ## 边界与后续
 

+ 10 - 0
index-rotation/configs/research/cost_scenarios.yaml

@@ -0,0 +1,10 @@
+scenarios:
+  optimistic:
+    total_cost_bps: 10
+    description: 低成本研究假设;适合高流动性实现、执行优化较好的理想场景
+  base:
+    total_cost_bps: 15
+    description: 当前默认研究基准;后续比较策略优先参考这一档
+  conservative:
+    total_cost_bps: 20
+    description: 更保守的成本压力测试;用于检验策略是否对成本过度敏感

+ 8 - 0
index-rotation/configs/strategy/top1_every_5_days_p05_cost15bp.yaml

@@ -0,0 +1,8 @@
+name: top1_every_5_days_p05_cost15bp
+top_n: 1
+rebalance_frequency: every_5_days
+commission_bps: 7.5
+slippage_bps: 7.5
+cash_return: 0.0
+risk_penalty_multiplier: 0.5
+start_date: "2019-12-31"

+ 166 - 0
index-rotation/src/backtest/compare.py

@@ -0,0 +1,166 @@
+from __future__ import annotations
+
+import argparse
+import json
+from pathlib import Path
+from typing import Any
+
+import pandas as pd
+import yaml
+
+
+def repo_root() -> Path:
+    return Path(__file__).resolve().parents[2]
+
+
+def build_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(description="Build a unified comparison table for backtest outputs.")
+    root = repo_root()
+    parser.add_argument("--backtests-root", type=Path, default=root / "outputs" / "backtests")
+    parser.add_argument("--strategy-config-root", type=Path, default=root / "configs" / "strategy")
+    parser.add_argument("--cost-scenarios", type=Path, default=root / "configs" / "research" / "cost_scenarios.yaml")
+    parser.add_argument("--output-csv", type=Path, default=root / "outputs" / "research" / "strategy_comparison.csv")
+    parser.add_argument("--output-md", type=Path, default=root / "outputs" / "research" / "strategy_comparison.md")
+    return parser
+
+
+def load_yaml(path: Path) -> dict[str, Any]:
+    with path.open("r", encoding="utf-8") as handle:
+        return yaml.safe_load(handle) or {}
+
+
+def load_cost_scenarios(path: Path) -> dict[float, str]:
+    payload = load_yaml(path)
+    mapping: dict[float, str] = {}
+    for label, config in (payload.get("scenarios") or {}).items():
+        total_cost_bps = float(config["total_cost_bps"])
+        mapping[total_cost_bps] = label
+    return mapping
+
+
+def load_strategy_configs(root: Path) -> dict[str, dict[str, Any]]:
+    configs: dict[str, dict[str, Any]] = {}
+    for path in sorted(root.glob("*.yaml")):
+        payload = load_yaml(path)
+        config_name = str(payload.get("name") or path.stem)
+        payload["_path"] = str(path)
+        configs[config_name] = payload
+    return configs
+
+
+def build_rows(
+    *,
+    backtests_root: Path,
+    strategy_configs: dict[str, dict[str, Any]],
+    cost_scenarios: dict[float, str],
+) -> pd.DataFrame:
+    rows: list[dict[str, Any]] = []
+    for summary_path in sorted(backtests_root.glob("*/summary.json")):
+        result_dir = summary_path.parent
+        name = result_dir.name
+        benchmark_path = result_dir / "benchmark_summary.json"
+        if not benchmark_path.exists():
+            continue
+
+        summary = json.loads(summary_path.read_text(encoding="utf-8"))
+        benchmark_summary = json.loads(benchmark_path.read_text(encoding="utf-8"))
+        cfg = strategy_configs.get(name, {})
+        total_cost_bps = float(cfg.get("commission_bps", 0.0)) + float(cfg.get("slippage_bps", 0.0))
+        row = {
+            "name": name,
+            "top_n": cfg.get("top_n"),
+            "rebalance_frequency": cfg.get("rebalance_frequency"),
+            "risk_penalty_multiplier": cfg.get("risk_penalty_multiplier", 0.30),
+            "total_cost_bps": total_cost_bps,
+            "cost_scenario": cost_scenarios.get(total_cost_bps, "custom" if total_cost_bps else "none"),
+            "cumulative_return": summary.get("cumulative_return"),
+            "annual_return": summary.get("annual_return"),
+            "max_drawdown": summary.get("max_drawdown"),
+            "annual_volatility": summary.get("annual_volatility"),
+            "sharpe": summary.get("sharpe"),
+            "calmar": summary.get("calmar"),
+            "turnover": summary.get("turnover"),
+            "rebalance_count": summary.get("rebalance_count"),
+            "cash_days_ratio": summary.get("cash_days_ratio"),
+            "vs_equal_weight_cum": summary.get("cumulative_return", 0.0) - benchmark_summary.get("equal_weight", {}).get("cumulative_return", 0.0),
+            "vs_hs300_cum": summary.get("cumulative_return", 0.0) - benchmark_summary.get("hs300", {}).get("cumulative_return", 0.0),
+            "vs_chinext50_cum": summary.get("cumulative_return", 0.0) - benchmark_summary.get("chinext50", {}).get("cumulative_return", 0.0),
+        }
+        rows.append(row)
+
+    frame = pd.DataFrame(rows)
+    if frame.empty:
+        return frame
+    return frame.sort_values(
+        ["total_cost_bps", "sharpe", "annual_return", "cumulative_return"],
+        ascending=[True, False, False, False],
+    ).reset_index(drop=True)
+
+
+def render_markdown_table(frame: pd.DataFrame) -> str:
+    if frame.empty:
+        return "# Strategy Comparison\n\n_No comparison rows available._\n"
+
+    display = frame[
+        [
+            "name",
+            "cost_scenario",
+            "total_cost_bps",
+            "top_n",
+            "rebalance_frequency",
+            "risk_penalty_multiplier",
+            "cumulative_return",
+            "annual_return",
+            "max_drawdown",
+            "sharpe",
+            "vs_equal_weight_cum",
+            "vs_hs300_cum",
+            "vs_chinext50_cum",
+        ]
+    ].copy()
+    for column in [
+        "cumulative_return",
+        "annual_return",
+        "max_drawdown",
+        "sharpe",
+        "vs_equal_weight_cum",
+        "vs_hs300_cum",
+        "vs_chinext50_cum",
+    ]:
+        display[column] = display[column].map(lambda value: f"{float(value):.4f}")
+    return "# Strategy Comparison\n\n" + display.to_markdown(index=False) + "\n"
+
+
+def main(argv: list[str] | None = None) -> int:
+    parser = build_parser()
+    args = parser.parse_args(argv)
+
+    strategy_configs = load_strategy_configs(args.strategy_config_root)
+    cost_scenarios = load_cost_scenarios(args.cost_scenarios)
+    frame = build_rows(
+        backtests_root=args.backtests_root,
+        strategy_configs=strategy_configs,
+        cost_scenarios=cost_scenarios,
+    )
+
+    args.output_csv.parent.mkdir(parents=True, exist_ok=True)
+    args.output_md.parent.mkdir(parents=True, exist_ok=True)
+    frame.to_csv(args.output_csv, index=False)
+    args.output_md.write_text(render_markdown_table(frame), encoding="utf-8")
+
+    print(
+        json.dumps(
+            {
+                "rows": int(len(frame.index)),
+                "output_csv": str(args.output_csv),
+                "output_md": str(args.output_md),
+            },
+            ensure_ascii=False,
+            indent=2,
+        )
+    )
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())

+ 71 - 0
index-rotation/tests/test_compare.py

@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+import json
+import tempfile
+import unittest
+from pathlib import Path
+
+from src.backtest.compare import build_rows
+
+
+class CompareTests(unittest.TestCase):
+    def test_build_rows_includes_cost_label_and_relative_columns(self) -> None:
+        temp_dir = tempfile.TemporaryDirectory()
+        self.addCleanup(temp_dir.cleanup)
+        root = Path(temp_dir.name)
+        backtests_root = root / "outputs" / "backtests"
+        result_dir = backtests_root / "demo_strategy"
+        result_dir.mkdir(parents=True, exist_ok=True)
+
+        (result_dir / "summary.json").write_text(
+            json.dumps(
+                {
+                    "cumulative_return": 0.20,
+                    "annual_return": 0.05,
+                    "max_drawdown": -0.10,
+                    "annual_volatility": 0.20,
+                    "sharpe": 0.25,
+                    "calmar": 0.50,
+                    "turnover": 12.0,
+                    "rebalance_count": 8,
+                    "cash_days_ratio": 0.10,
+                }
+            ),
+            encoding="utf-8",
+        )
+        (result_dir / "benchmark_summary.json").write_text(
+            json.dumps(
+                {
+                    "equal_weight": {"cumulative_return": 0.15},
+                    "hs300": {"cumulative_return": 0.08},
+                    "chinext50": {"cumulative_return": 0.30},
+                }
+            ),
+            encoding="utf-8",
+        )
+
+        frame = build_rows(
+            backtests_root=backtests_root,
+            strategy_configs={
+                "demo_strategy": {
+                    "name": "demo_strategy",
+                    "top_n": 1,
+                    "rebalance_frequency": "every_5_days",
+                    "risk_penalty_multiplier": 0.5,
+                    "commission_bps": 5.0,
+                    "slippage_bps": 5.0,
+                }
+            },
+            cost_scenarios={10.0: "optimistic", 15.0: "base", 20.0: "conservative"},
+        )
+
+        self.assertEqual(len(frame.index), 1)
+        row = frame.iloc[0]
+        self.assertEqual(row["cost_scenario"], "optimistic")
+        self.assertAlmostEqual(row["vs_equal_weight_cum"], 0.05)
+        self.assertAlmostEqual(row["vs_hs300_cum"], 0.12)
+        self.assertAlmostEqual(row["vs_chinext50_cum"], -0.10)
+
+
+if __name__ == "__main__":
+    unittest.main()