dragon_rule_ablation.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. from __future__ import annotations
  2. from pathlib import Path
  3. import pandas as pd
  4. from dragon_indicators import DragonIndicatorConfig, DragonIndicatorEngine
  5. from dragon_strategy import DragonRuleEngine
  6. from dragon_strategy_config import StrategyConfig
  7. from dragon_workbook import DragonWorkbook
  8. def _find_workbook(base_dir: Path) -> Path:
  9. matches = sorted(base_dir.glob("*.xlsx"))
  10. if not matches:
  11. raise FileNotFoundError(f"No workbook found in {base_dir}")
  12. return matches[0]
  13. def _load_workbook_events(workbook_path: Path) -> pd.DataFrame:
  14. workbook = DragonWorkbook(workbook_path)
  15. return pd.DataFrame(
  16. [
  17. {
  18. "date": event.date.isoformat(),
  19. "side": event.side,
  20. "layer": event.layer,
  21. "signal_reason": event.signal_reason,
  22. "note": event.note,
  23. }
  24. for event in workbook.split_layers()
  25. ]
  26. )
  27. def _event_match_report(workbook_events: pd.DataFrame, strategy_events: pd.DataFrame, side: str, layer: str) -> dict[str, object]:
  28. wb = set(workbook_events[(workbook_events["side"] == side) & (workbook_events["layer"] == layer)]["date"])
  29. st = set(strategy_events[(strategy_events["side"] == side) & (strategy_events["layer"] == layer)]["date"])
  30. hit = wb & st
  31. return {
  32. "workbook": len(wb),
  33. "strategy": len(st),
  34. "overlap": len(hit),
  35. "missing": len(wb - st),
  36. "extra": len(st - wb),
  37. }
  38. def _profit_factor(series: pd.Series) -> float:
  39. gross_profit = series[series > 0].sum()
  40. gross_loss = -series[series < 0].sum()
  41. if gross_loss == 0:
  42. return float("inf") if gross_profit > 0 else 0.0
  43. return float(gross_profit / gross_loss)
  44. def _trade_quality(trades: pd.DataFrame, indicator_df: pd.DataFrame) -> tuple[float, float]:
  45. if trades.empty:
  46. return float("nan"), float("nan")
  47. lookup = indicator_df.reset_index().rename(columns={"index": "dt"})
  48. lookup["date_str"] = lookup["date"].dt.date.astype(str)
  49. pos_lookup = {date_str: idx for idx, date_str in enumerate(lookup["date_str"])}
  50. mfe_values: list[float] = []
  51. mae_values: list[float] = []
  52. for _, trade in trades.iterrows():
  53. buy_idx = pos_lookup.get(trade["buy_date"])
  54. sell_idx = pos_lookup.get(trade["sell_date"])
  55. if buy_idx is None or sell_idx is None or sell_idx < buy_idx:
  56. continue
  57. window = lookup.iloc[buy_idx : sell_idx + 1]
  58. entry_price = float(trade["buy_price"])
  59. mfe_values.append(float(window["high"].max()) / entry_price - 1.0)
  60. mae_values.append(float(window["low"].min()) / entry_price - 1.0)
  61. return (
  62. float(pd.Series(mfe_values).mean()) if mfe_values else float("nan"),
  63. float(pd.Series(mae_values).mean()) if mae_values else float("nan"),
  64. )
  65. def _run_single_experiment(
  66. label: str,
  67. config: StrategyConfig,
  68. workbook_events: pd.DataFrame,
  69. indicator_df: pd.DataFrame,
  70. first_workbook_date: str,
  71. last_workbook_date: str,
  72. ) -> dict[str, object]:
  73. strategy = DragonRuleEngine(config=config)
  74. events, trades = strategy.run(indicator_df)
  75. events = events[(events["date"] >= first_workbook_date) & (events["date"] <= last_workbook_date)].copy()
  76. trades = trades[
  77. (trades["buy_date"] >= first_workbook_date)
  78. & (trades["buy_date"] <= last_workbook_date)
  79. & (trades["sell_date"] >= first_workbook_date)
  80. & (trades["sell_date"] <= last_workbook_date)
  81. ].copy()
  82. buy_match = _event_match_report(workbook_events, events, "BUY", "real_trade")
  83. sell_match = _event_match_report(workbook_events, events, "SELL", "real_trade")
  84. aux_buy_match = _event_match_report(workbook_events, events, "BUY", "aux_signal")
  85. aux_sell_match = _event_match_report(workbook_events, events, "SELL", "aux_signal")
  86. avg_mfe, avg_mae = _trade_quality(trades, indicator_df)
  87. win_rate = float((trades["return_pct"] > 0).mean()) if not trades.empty else float("nan")
  88. avg_return = float(trades["return_pct"].mean()) if not trades.empty else float("nan")
  89. median_return = float(trades["return_pct"].median()) if not trades.empty else float("nan")
  90. profit_factor = _profit_factor(trades["return_pct"]) if not trades.empty else float("nan")
  91. return {
  92. "experiment": label,
  93. "disabled_rules": "|".join(sorted(config.disabled_rules)),
  94. "aux_sell_same_side_once_per_cycle": config.aux_sell_same_side_once_per_cycle,
  95. "enable_knife_take_profit_2_wait_ql": config.enable_knife_take_profit_2_wait_ql,
  96. "trades": int(len(trades)),
  97. "win_rate": win_rate,
  98. "avg_return": avg_return,
  99. "median_return": median_return,
  100. "profit_factor": profit_factor,
  101. "avg_mfe_pct": avg_mfe,
  102. "avg_mae_pct": avg_mae,
  103. "real_buy_overlap": int(buy_match["overlap"]),
  104. "real_buy_missing": int(buy_match["missing"]),
  105. "real_buy_extra": int(buy_match["extra"]),
  106. "real_sell_overlap": int(sell_match["overlap"]),
  107. "real_sell_missing": int(sell_match["missing"]),
  108. "real_sell_extra": int(sell_match["extra"]),
  109. "aux_buy_overlap": int(aux_buy_match["overlap"]),
  110. "aux_buy_missing": int(aux_buy_match["missing"]),
  111. "aux_buy_extra": int(aux_buy_match["extra"]),
  112. "aux_sell_overlap": int(aux_sell_match["overlap"]),
  113. "aux_sell_missing": int(aux_sell_match["missing"]),
  114. "aux_sell_extra": int(aux_sell_match["extra"]),
  115. }
  116. def main() -> None:
  117. base_dir = Path(__file__).resolve().parent
  118. workbook_path = _find_workbook(base_dir)
  119. workbook_events = _load_workbook_events(workbook_path)
  120. first_workbook_date = pd.to_datetime(workbook_events["date"]).min().date().isoformat()
  121. last_workbook_date = pd.to_datetime(workbook_events["date"]).max().date().isoformat()
  122. engine = DragonIndicatorEngine(DragonIndicatorConfig(start_date="2015-01-01", end_date="2026-01-31"))
  123. indicator_df = engine.compute(engine.fetch_daily_data())
  124. baseline_config = StrategyConfig()
  125. experiments: list[tuple[str, StrategyConfig]] = [
  126. ("baseline", baseline_config),
  127. ("disable_entry_glued_buy", baseline_config.with_updates(disabled_rules={"glued_buy"})),
  128. ("disable_entry_deep_oversold_rebound_buy", baseline_config.with_updates(disabled_rules={"deep_oversold_rebound_buy"})),
  129. ("disable_entry_oversold_recovery_buy", baseline_config.with_updates(disabled_rules={"oversold_recovery_buy"})),
  130. ("disable_entry_post_sell_rebound_buy", baseline_config.with_updates(disabled_rules={"post_sell_rebound_buy"})),
  131. ("disable_entry_oversold_reversal_after_ql_buy", baseline_config.with_updates(disabled_rules={"oversold_reversal_after_ql_buy"})),
  132. ("disable_entry_non_glued_positive_expansion_buy", baseline_config.with_updates(disabled_rules={"non_glued_positive_expansion_buy"})),
  133. ("disable_entry_early_crash_probe_buy", baseline_config.with_updates(disabled_rules={"early_crash_probe_buy"})),
  134. ("disable_entry_dual_gold_resonance_buy", baseline_config.with_updates(disabled_rules={"dual_gold_resonance_buy"})),
  135. ("disable_exit_knife_take_profit_1", baseline_config.with_updates(disabled_rules={"knife_take_profit_1"})),
  136. ("disable_exit_knife_take_profit_2_glued", baseline_config.with_updates(disabled_rules={"knife_take_profit_2_glued"})),
  137. ("disable_exit_ql_mid_zone_take_profit", baseline_config.with_updates(disabled_rules={"ql_mid_zone_take_profit"})),
  138. ("disable_exit_high_regime_confirmed_exit_kdj", baseline_config.with_updates(disabled_rules={"high_regime_confirmed_exit:kdj_sell"})),
  139. ("disable_exit_predictive_b1_break_exit", baseline_config.with_updates(disabled_rules={"predictive_b1_break_exit"})),
  140. ("disable_exit_prewarning_reduction_exit", baseline_config.with_updates(disabled_rules={"prewarning_reduction_exit"})),
  141. ("disable_exit_crash_protection_exit", baseline_config.with_updates(disabled_rules={"crash_protection_exit"})),
  142. ("disable_aux_same_side_cycle_cap", baseline_config.with_updates(aux_sell_same_side_once_per_cycle=False)),
  143. ("disable_knife_take_profit_2_wait_ql", baseline_config.with_updates(enable_knife_take_profit_2_wait_ql=False)),
  144. ]
  145. rows = [
  146. _run_single_experiment(
  147. label,
  148. config,
  149. workbook_events,
  150. indicator_df,
  151. first_workbook_date,
  152. last_workbook_date,
  153. )
  154. for label, config in experiments
  155. ]
  156. result_df = pd.DataFrame(rows)
  157. baseline_row = result_df[result_df["experiment"] == "baseline"].iloc[0]
  158. for col in [
  159. "trades",
  160. "win_rate",
  161. "avg_return",
  162. "median_return",
  163. "profit_factor",
  164. "avg_mfe_pct",
  165. "avg_mae_pct",
  166. "real_buy_overlap",
  167. "real_sell_overlap",
  168. "aux_sell_overlap",
  169. ]:
  170. result_df[f"delta_{col}"] = result_df[col] - baseline_row[col]
  171. result_df.to_csv(base_dir / "dragon_rule_ablation.csv", index=False, encoding="utf-8-sig")
  172. protected = result_df[
  173. (result_df["experiment"] != "baseline")
  174. & (result_df["real_buy_overlap"] == baseline_row["real_buy_overlap"])
  175. & (result_df["real_sell_overlap"] == baseline_row["real_sell_overlap"])
  176. ].copy()
  177. protected_sorted = protected.sort_values("delta_avg_return", ascending=False)
  178. harmful_sorted = result_df[result_df["experiment"] != "baseline"].sort_values("delta_avg_return")
  179. lines = [
  180. "# Dragon Rule Ablation",
  181. "",
  182. "## Baseline",
  183. f"- trades: `{int(baseline_row['trades'])}`",
  184. f"- win_rate: `{baseline_row['win_rate']:.2%}`",
  185. f"- avg_return: `{baseline_row['avg_return']:.2%}`",
  186. f"- profit_factor: `{baseline_row['profit_factor']:.2f}`",
  187. f"- real BUY overlap: `{int(baseline_row['real_buy_overlap'])}`",
  188. f"- real SELL overlap: `{int(baseline_row['real_sell_overlap'])}`",
  189. "",
  190. "## Protected Experiments",
  191. "- Interpretation: these experiments preserved current real-trade overlap and only changed quality or auxiliary behavior.",
  192. ]
  193. if protected_sorted.empty:
  194. lines.append("- None.")
  195. else:
  196. for _, row in protected_sorted.head(8).iterrows():
  197. lines.append(
  198. f"- `{row['experiment']}`: delta_avg_return `{row['delta_avg_return']:.2%}`, "
  199. f"delta_profit_factor `{row['delta_profit_factor']:.2f}`, delta_aux_sell_overlap `{int(row['delta_aux_sell_overlap'])}`"
  200. )
  201. lines.extend(["", "## Most Harmful Removals"])
  202. for _, row in harmful_sorted.head(8).iterrows():
  203. lines.append(
  204. f"- `{row['experiment']}`: delta_avg_return `{row['delta_avg_return']:.2%}`, "
  205. f"real BUY `{int(row['real_buy_overlap'])}`, real SELL `{int(row['real_sell_overlap'])}`, "
  206. f"delta_trades `{int(row['delta_trades'])}`"
  207. )
  208. lines.extend(["", "## Best Removal Candidates"])
  209. best_candidates = result_df[
  210. (result_df["experiment"] != "baseline")
  211. & (result_df["delta_avg_return"] > 0)
  212. ].sort_values(["real_buy_overlap", "real_sell_overlap", "delta_avg_return"], ascending=[False, False, False])
  213. for _, row in best_candidates.head(8).iterrows():
  214. lines.append(
  215. f"- `{row['experiment']}`: delta_avg_return `{row['delta_avg_return']:.2%}`, "
  216. f"delta_profit_factor `{row['delta_profit_factor']:.2f}`, "
  217. f"real BUY `{int(row['real_buy_overlap'])}`, real SELL `{int(row['real_sell_overlap'])}`"
  218. )
  219. (base_dir / "dragon_rule_ablation.md").write_text("\n".join(lines) + "\n", encoding="utf-8")
  220. if __name__ == "__main__":
  221. main()