dragon_indicators.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. from datetime import datetime
  4. from pathlib import Path
  5. from typing import Optional
  6. import numpy as np
  7. import pandas as pd
  8. import sys
  9. REPO_ROOT = Path(__file__).resolve().parents[3]
  10. ROOT_DRAGON_DIR = REPO_ROOT / "dragon"
  11. if not ROOT_DRAGON_DIR.exists():
  12. raise ModuleNotFoundError(f"Expected dragon dependency directory at {ROOT_DRAGON_DIR}")
  13. if str(ROOT_DRAGON_DIR) not in sys.path:
  14. sys.path.append(str(ROOT_DRAGON_DIR))
  15. import MyTT # noqa: E402
  16. from data_fetcher_v2 import DataFetcherV2 # noqa: E402
  17. @dataclass
  18. class DragonIndicatorConfig:
  19. symbol: str = "399673"
  20. start_date: str = "2015-01-01"
  21. end_date: Optional[str] = None
  22. def _cross_up(left: np.ndarray, right: np.ndarray) -> np.ndarray:
  23. left_prev = np.roll(left, 1)
  24. right_prev = np.roll(right, 1)
  25. result = (left > right) & (left_prev <= right_prev)
  26. result[0] = False
  27. return result
  28. class DragonIndicatorEngine:
  29. def __init__(self, config: Optional[DragonIndicatorConfig] = None):
  30. self.config = config or DragonIndicatorConfig()
  31. self.fetcher = DataFetcherV2()
  32. self.last_fetch_meta: dict[str, object] = {}
  33. def fetch_daily_data(self, include_intraday_snapshot: bool = False) -> pd.DataFrame:
  34. end_date = self.config.end_date or datetime.now().strftime("%Y-%m-%d")
  35. if include_intraday_snapshot:
  36. df = self.fetcher.fetch_index_data_with_latest_snapshot_v2(
  37. symbol=self.config.symbol,
  38. start_date=self.config.start_date,
  39. end_date=end_date,
  40. )
  41. else:
  42. df = self.fetcher.fetch_index_data_v2(
  43. symbol=self.config.symbol,
  44. start_date=self.config.start_date,
  45. end_date=end_date,
  46. )
  47. if df.empty:
  48. raise RuntimeError(f"Failed to fetch daily data for {self.config.symbol}")
  49. self.last_fetch_meta = {
  50. "intraday_snapshot_appended": bool(df.attrs.get("intraday_snapshot_appended", False)),
  51. "intraday_snapshot_timestamp": df.attrs.get("intraday_snapshot_timestamp"),
  52. "historical_latest_bar_date": df.attrs.get(
  53. "historical_latest_bar_date",
  54. df.index[-1].date().isoformat(),
  55. ),
  56. }
  57. result = df.sort_index().copy()
  58. result.attrs.update(self.last_fetch_meta)
  59. return result
  60. def compute(self, df: pd.DataFrame) -> pd.DataFrame:
  61. if df.empty:
  62. return df.copy()
  63. result = df.copy()
  64. close = result["close"].to_numpy(dtype=float)
  65. high = result["high"].to_numpy(dtype=float)
  66. low = result["low"].to_numpy(dtype=float)
  67. open_ = result["open"].to_numpy(dtype=float)
  68. h1_5 = np.nan_to_num(MyTT.EMA(close, 8), nan=0.0)
  69. h2_5 = np.nan_to_num(MyTT.EMA(h1_5, 20), nan=0.0)
  70. rsv = np.nan_to_num((close - MyTT.LLV(low, 7)) / (MyTT.HHV(high, 7) - MyTT.LLV(low, 7)) * 100)
  71. y0 = np.nan_to_num(MyTT.SMA(rsv, 3, 1), nan=0.0)
  72. y1 = np.nan_to_num(MyTT.SMA(y0, 3, 1), nan=0.0)
  73. rsv1 = np.nan_to_num((close - MyTT.LLV(low, 38)) / (MyTT.HHV(high, 38) - MyTT.LLV(low, 38)) * 100)
  74. y2 = np.nan_to_num(MyTT.SMA(rsv1, 5, 1), nan=0.0)
  75. y3 = np.nan_to_num(MyTT.SMA(y2, 10, 1), nan=0.0)
  76. avg_h = (h1_5 + h2_5) / 2.0
  77. a1 = np.divide(h1_5 - h2_5, avg_h, out=np.zeros_like(avg_h), where=avg_h != 0)
  78. b1 = (y2 - y3) / 100.0
  79. c1 = (y2 + y3) / 2.0
  80. xopen = (MyTT.REF(open_, 1) + MyTT.REF(close, 1)) / 2.0
  81. xopen = np.nan_to_num(xopen, nan=close)
  82. xclose = close
  83. xhigh = np.maximum(high, xopen)
  84. xlow = np.minimum(low, xopen)
  85. ql_volatility = np.nan_to_num(MyTT.MA(xhigh - xlow, 8), nan=0.0)
  86. ql_mid = np.nan_to_num(MyTT.MA(xclose, 5), nan=0.0)
  87. ql_upper = ql_mid + ql_volatility / 2.0
  88. ql_lower = ql_mid - ql_volatility / 2.0
  89. kdj_buy = _cross_up(y0, y1)
  90. kdj_sell = _cross_up(y1, y0)
  91. ql_buy = _cross_up(xclose, ql_upper)
  92. ql_sell = _cross_up(ql_lower, xclose)
  93. result["h1_5"] = h1_5
  94. result["h2_5"] = h2_5
  95. result["a1"] = a1
  96. result["y0"] = y0
  97. result["y1"] = y1
  98. result["kdj_buy"] = kdj_buy
  99. result["kdj_sell"] = kdj_sell
  100. result["y2"] = y2
  101. result["y3"] = y3
  102. result["b1"] = b1
  103. result["c1"] = c1
  104. result["ql_xopen"] = xopen
  105. result["ql_upper"] = ql_upper
  106. result["ql_lower"] = ql_lower
  107. result["ql_buy"] = ql_buy
  108. result["ql_sell"] = ql_sell
  109. return result