Sfoglia il codice sorgente

Refresh T-day fetch and conditionally write back cache

erwin 1 mese fa
parent
commit
77ff5d4c17

+ 164 - 8
dragon/data_fetcher_v2.py

@@ -1,10 +1,26 @@
 import akshare as ak
 import pandas as pd
 import numpy as np
+from dataclasses import dataclass
 from datetime import datetime, timedelta
 from typing import Optional, Union, List, Tuple
 import time
 
+try:
+    import requests
+except ImportError:  # pragma: no cover
+    requests = None
+
+
+@dataclass
+class LatestSnapshot:
+    """One real-time OHLC quote, as returned by ``DataFetcherV2.fetch_latest_snapshot``."""
+
+    # Moment the snapshot was taken; fetch_latest_snapshot fills it with a
+    # naive local datetime.now().
+    timestamp: datetime
+    # Price fields; fetch_latest_snapshot parses them from the quote payload.
+    open: float
+    high: float
+    low: float
+    close: float
+    # Defaults to 0.0; fetch_latest_snapshot always passes 0.0 explicitly.
+    volume: float = 0.0
+
 class DataFetcherV2:
     """
     数据获取类V2 - 基于用户提供的优化方法
@@ -30,6 +46,27 @@ class DataFetcherV2:
         """设置缓存"""
         self.cache[cache_key] = data
         self.cache_expiry[cache_key] = time.time() + self.cache_duration
+
+    def _should_force_refresh_t_day(self, end_date: str) -> bool:
+        """
+        Decide whether a request that reaches today (the "T day") must be refreshed.
+
+        When the requested window covers today, the old cache is not used
+        directly, so that stale same-day data is not served.
+
+        Args:
+            end_date: End of the requested window, in any form accepted by
+                ``pd.to_datetime``.
+
+        Returns:
+            True when ``end_date`` falls on or after today's local date;
+            False otherwise, including when ``end_date`` cannot be parsed.
+        """
+        try:
+            return pd.to_datetime(end_date).date() >= datetime.now().date()
+        except Exception:
+            # Best effort: on any parsing/comparison failure, do not force a refresh.
+            return False
+
+    def _writeback_cache_if_exists(self, cache_key: str, data: pd.DataFrame) -> None:
+        """
+        Write ``data`` back to the cache only when the cache key already exists.
+
+        If the key is not present the call is a no-op ("write back if present,
+        otherwise leave it").
+
+        Args:
+            cache_key: Key to look up in ``self.cache``.
+            data: Frame to store; a copy is cached, not the object itself.
+        """
+        # Only membership in self.cache is checked here, not the entry's expiry.
+        if cache_key not in self.cache:
+            return
+        payload = data.copy()
+        # Re-apply the metadata explicitly onto the copy before storing it.
+        payload.attrs.update(data.attrs)
+        # _set_cache also resets the entry's expiry time.
+        self._set_cache(cache_key, payload)
     
     def _format_index_code(self, symbol: str) -> str:
         """
@@ -56,6 +93,70 @@ class DataFetcherV2:
         
         # 如果已经是格式化好的代码
         return symbol.lower()
+
+    def _infer_realtime_prefix(self, code: str) -> str:
+        """
+        Choose the "sz" or "sh" quote prefix from the leading digits of ``code``.
+
+        The checks run in order, so the three-digit rules win over the
+        two-digit ones: "000..." yields "sh" even though "00..." alone yields
+        "sz". Codes matching none of the rules default to "sz".
+        NOTE(review): presumably "sz"/"sh" denote the Shenzhen/Shanghai
+        exchanges and the "000" rule targets index codes -- confirm.
+        """
+        if code.startswith("399"):
+            return "sz"
+        if code.startswith("000"):
+            return "sh"
+        if code.startswith(("30", "00", "15")):
+            return "sz"
+        if code.startswith(("60", "68")):
+            return "sh"
+        # Fallback for any code not matched above.
+        return "sz"
+
+    def fetch_latest_snapshot(self, symbol: str) -> Optional[LatestSnapshot]:
+        """
+        Fetch the latest quote for ``symbol`` from hq.sinajs.cn.
+
+        Args:
+            symbol: Code passed through ``_format_index_code``; if the result
+                does not start with "sz"/"sh", the prefix is inferred by
+                ``_infer_realtime_prefix``.
+
+        Returns:
+            A ``LatestSnapshot`` stamped with the local ``datetime.now()``, or
+            None when ``requests`` is not installed, the HTTP request fails,
+            the response has no quoted payload or fewer than six fields, a
+            price field is not a number, or any of close / previous close /
+            high / low is not positive. ``volume`` is always 0.0.
+        """
+        # requests is an optional import; without it no snapshot can be fetched.
+        if requests is None:
+            return None
+
+        formatted_symbol = self._format_index_code(symbol)
+        if formatted_symbol.startswith(("sz", "sh")):
+            prefix = formatted_symbol[:2]
+            code = formatted_symbol[2:]
+        else:
+            code = formatted_symbol
+            prefix = self._infer_realtime_prefix(code)
+
+        url = f"http://hq.sinajs.cn/list={prefix}{code}"
+        headers = {
+            "User-Agent": "Mozilla/5.0",
+            "Referer": "http://finance.sina.com.cn",
+        }
+
+        try:
+            response = requests.get(url, headers=headers, timeout=10)
+            response.raise_for_status()
+        except Exception:
+            # Any request or HTTP-status failure is treated as "no snapshot".
+            return None
+
+        # The body is decoded as GBK before it is parsed.
+        response.encoding = "gbk"
+        text = response.text
+        if '"' not in text:
+            return None
+
+        try:
+            # The quote data is the comma-separated text after the first double quote.
+            payload = text.split('"')[1].split(",")
+            if len(payload) < 6:
+                return None
+            # Fields 1..5 are read as open, previous close, latest price, high
+            # and low; field 0 is not used.
+            open_price = float(payload[1])
+            prev_close = float(payload[2])
+            close_price = float(payload[3])
+            high_price = float(payload[4])
+            low_price = float(payload[5])
+        except (ValueError, IndexError):
+            return None
+
+        # prev_close is used only for this sanity check; open_price is not
+        # part of the check.
+        if close_price <= 0 or prev_close <= 0 or high_price <= 0 or low_price <= 0:
+            return None
+
+        return LatestSnapshot(
+            timestamp=datetime.now(),
+            open=open_price,
+            high=high_price,
+            low=low_price,
+            close=close_price,
+            volume=0.0,
+        )
     
     def fetch_index_data_v2(self, 
                             symbol: str, 
@@ -72,12 +173,11 @@ class DataFetcherV2:
         Returns:
             包含OHLCV数据的DataFrame,索引为日期
         """
-        if end_date is None:
-            end_date = datetime.now().strftime('%Y-%m-%d')
-        
-        cache_key = self._get_cache_key(symbol, start_date, end_date)
-        
-        if self._is_cache_valid(cache_key):
+        resolved_end_date = end_date or datetime.now().strftime('%Y-%m-%d')
+        cache_key = self._get_cache_key(symbol, start_date, resolved_end_date)
+        force_refresh_t_day = self._should_force_refresh_t_day(resolved_end_date)
+
+        if self._is_cache_valid(cache_key) and not force_refresh_t_day:
             return self.cache[cache_key].copy()
         
         try:
@@ -98,7 +198,7 @@ class DataFetcherV2:
             
             # 筛选日期范围
             start_datetime = pd.to_datetime(start_date)
-            end_datetime = pd.to_datetime(end_date)
+            end_datetime = pd.to_datetime(resolved_end_date)
             
             # 先筛选出指定日期之后的数据
             filtered_df = all_data_df[all_data_df.index >= start_datetime]
@@ -115,12 +215,68 @@ class DataFetcherV2:
             print(f"获取数据量: {len(filtered_df)} 条")
             
             # 缓存数据
+            filtered_df.attrs["intraday_snapshot_appended"] = False
+            filtered_df.attrs["intraday_snapshot_timestamp"] = None
+            filtered_df.attrs["historical_latest_bar_date"] = filtered_df.index[-1].date().isoformat()
             self._set_cache(cache_key, filtered_df)
             return filtered_df.copy()
             
         except Exception as e:
             print(f"Error fetching index data for {symbol}: {str(e)}")
             return pd.DataFrame()
+
+    def fetch_index_data_with_latest_snapshot_v2(
+        self,
+        symbol: str,
+        start_date: str = "2018-01-01",
+        end_date: Optional[str] = None,
+    ) -> pd.DataFrame:
+        """
+        Fetch index data and, when applicable, append the latest real-time snapshot.
+
+        The historical frame comes from ``fetch_index_data_v2``. A snapshot row
+        is appended only when all of these hold: the frame is not empty, no
+        explicit ``end_date`` earlier than today was given, the last historical
+        bar is dated before today, today is a weekday (Monday-Friday), and
+        ``fetch_latest_snapshot`` returns a snapshot.
+
+        Args:
+            symbol: Index code, forwarded to the underlying fetchers.
+            start_date: Start of the requested window.
+            end_date: End of the requested window; defaults to today's date.
+
+        Returns:
+            The historical frame, or the frame with one extra row indexed by
+            the snapshot timestamp. Non-empty results carry the attrs
+            ``intraday_snapshot_appended``, ``intraday_snapshot_timestamp`` and
+            ``historical_latest_bar_date``; an empty frame is returned as-is.
+        """
+        resolved_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
+        cache_key = self._get_cache_key(symbol, start_date, resolved_end_date)
+
+        frame = self.fetch_index_data_v2(symbol=symbol, start_date=start_date, end_date=resolved_end_date)
+        if frame.empty:
+            return frame
+
+        today = datetime.now().date()
+        historical_latest_bar_date = frame.index[-1].date().isoformat()
+
+        # Default metadata for every path that returns the frame without a snapshot.
+        frame.attrs["intraday_snapshot_appended"] = False
+        frame.attrs["intraday_snapshot_timestamp"] = None
+        frame.attrs["historical_latest_bar_date"] = historical_latest_bar_date
+
+        # An explicitly requested window that ends before today gets no snapshot.
+        if end_date is not None and pd.to_datetime(end_date).date() < today:
+            return frame
+        # The historical data already has a bar dated today (or later).
+        if frame.index[-1].date() >= today:
+            return frame
+        # weekday() >= 5 means Saturday or Sunday.
+        if today.weekday() >= 5:
+            return frame
+
+        snapshot = self.fetch_latest_snapshot(symbol)
+        if snapshot is None:
+            return frame
+
+        # One-row frame indexed by the snapshot's full timestamp (not a
+        # normalized date).
+        latest_row = pd.DataFrame(
+            [
+                {
+                    "open": snapshot.open,
+                    "high": snapshot.high,
+                    "low": snapshot.low,
+                    "close": snapshot.close,
+                    "volume": snapshot.volume,
+                }
+            ],
+            index=pd.DatetimeIndex([pd.Timestamp(snapshot.timestamp)]),
+        )
+        latest_row.index.name = "date"
+
+        merged = pd.concat([frame, latest_row])
+        # If index labels collide, keep the later (snapshot) row, then restore
+        # chronological order.
+        merged = merged[~merged.index.duplicated(keep="last")].sort_index()
+        # Metadata is assigned again on the merged frame.
+        merged.attrs["intraday_snapshot_appended"] = True
+        merged.attrs["intraday_snapshot_timestamp"] = snapshot.timestamp.isoformat(timespec="seconds")
+        merged.attrs["historical_latest_bar_date"] = historical_latest_bar_date
+        # Updates the cache entry only if one already exists under this key.
+        self._writeback_cache_if_exists(cache_key, merged)
+        return merged
     
     def fetch_stock_data_v2(self, 
                            symbol: str, 
@@ -317,4 +473,4 @@ class DataManagerV2:
             print(f"{symbol}:")
             print(f"  Period: {info['start_date']} to {info['end_date']}")
             print(f"  Total bars: {info['total_bars']}")
-        print("="*70)
+        print("="*70)

+ 4 - 0
research/dragon/v2/MEMORY.md

@@ -1112,3 +1112,7 @@
 - one-click tracking entry `update_dragon_reports.ps1` was upgraded:
 - now prints latest bar + rollout decision + active/fallback branch + gate counts.
 - supports `-StrictGate` (exit `2` when decision is not `FORWARD_OK`) and `-OpenReport`.
+- T-day data freshness/caching fix was applied in `dragon/data_fetcher_v2.py`:
+- when the requested end date is today or later, `fetch_index_data_v2` now bypasses the in-memory cache (even if the entry has not expired) and re-fetches the source data.
+- `fetch_index_data_with_latest_snapshot_v2` now writes merged latest snapshot back only when cache key already exists (skip if missing).
+- added test file `tests/test_data_fetcher_tday_cache.py`; full suite reached `20` passing tests.

+ 10 - 0
research/dragon/v2/memory/2026-04-09.md

@@ -124,3 +124,13 @@
 - added optional switches:
 - `-StrictGate` (non-`FORWARD_OK` exits with code `2`)
 - `-OpenReport` (opens `dragon_reports_index.html` after run).
+- Updated `dragon/data_fetcher_v2.py` for T-day freshness + cache writeback behavior:
+- `fetch_index_data_v2(...)` now forces refresh when `end_date` covers today (`T日`) instead of reusing stale in-memory cache.
+- `fetch_index_data_with_latest_snapshot_v2(...)` now writes merged latest-snapshot data back only if the matching cache key already exists.
+- If cache key does not exist, writeback is skipped by design (`有则回写,无则跳过`).
+- Added regression tests:
+- `tests/test_data_fetcher_tday_cache.py`
+- covers: non-T-day cache hit, T-day forced refresh, conditional cache writeback.
+- Validation after change:
+- `py -3 -m unittest discover -s tests -v` passed (`20` tests).
+- `py -3 dragon_daily_signal_pipeline.py` smoke run passed.

+ 134 - 0
research/dragon/v2/tests/test_data_fetcher_tday_cache.py

@@ -0,0 +1,134 @@
+from __future__ import annotations
+
+from datetime import datetime, timedelta
+from pathlib import Path
+import sys
+import unittest
+
+import pandas as pd
+
+
+# Make ``data_fetcher_v2`` importable by adding the ``dragon`` directory to sys.path.
+# parents[4] is four levels above this file's own directory; for a file under
+# research/dragon/v2/tests/ that is presumably the repository root -- confirm.
+DRAGON_ROOT = Path(__file__).resolve().parents[4] / "dragon"
+if str(DRAGON_ROOT) not in sys.path:
+    sys.path.append(str(DRAGON_ROOT))
+
+import data_fetcher_v2 as fetcher_module  # noqa: E402
+from data_fetcher_v2 import DataFetcherV2, LatestSnapshot  # noqa: E402
+
+
+def _make_index_frame(dates: list[pd.Timestamp]) -> pd.DataFrame:
+    """
+    Build a synthetic OHLCV frame with one row per entry of ``dates``.
+
+    Row ``idx`` uses ``base = 100 + idx`` with open=base, high=base+1,
+    low=base-1, close=base+0.5 and volume=10_000+idx. ``date`` is returned as
+    a datetime column, not as the index; callers that need a date index call
+    ``set_index("date")`` themselves.
+    """
+    rows = []
+    for idx, day in enumerate(dates):
+        base = float(100 + idx)
+        rows.append(
+            {
+                "date": day,
+                "open": base,
+                "high": base + 1.0,
+                "low": base - 1.0,
+                "close": base + 0.5,
+                "volume": 10_000 + idx,
+            }
+        )
+    df = pd.DataFrame(rows)
+    df["date"] = pd.to_datetime(df["date"])
+    return df
+
+
+class TestDataFetcherTDayCache(unittest.TestCase):
+    """Regression tests for T-day cache bypass and conditional cache writeback."""
+
+    def setUp(self) -> None:
+        """Create a fresh fetcher and capture today's and yesterday's local dates."""
+        self.fetcher = DataFetcherV2()
+        self.today = datetime.now().date()
+        self.yesterday = self.today - timedelta(days=1)
+
+    def test_non_tday_uses_valid_cache(self) -> None:
+        """A window ending yesterday is served from a valid cache entry."""
+        end_date = self.yesterday.strftime("%Y-%m-%d")
+        cache_key = self.fetcher._get_cache_key("399673", "2018-01-01", end_date)
+
+        cached = _make_index_frame(
+            [
+                pd.Timestamp(self.yesterday - timedelta(days=1)),
+                pd.Timestamp(self.yesterday),
+            ]
+        ).set_index("date")
+        self.fetcher._set_cache(cache_key, cached)
+
+        # Temporarily replace the akshare call with one that fails if invoked;
+        # the original is restored in the finally block.
+        original = fetcher_module.ak.stock_zh_index_daily
+
+        def _fail_call(symbol: str) -> pd.DataFrame:
+            raise AssertionError("ak.stock_zh_index_daily should not be called when non-tday cache is valid")
+
+        fetcher_module.ak.stock_zh_index_daily = _fail_call
+        try:
+            got = self.fetcher.fetch_index_data_v2("399673", "2018-01-01", end_date=end_date)
+        finally:
+            fetcher_module.ak.stock_zh_index_daily = original
+
+        # NOTE(review): fetch_index_data_v2 catches Exception and returns an
+        # empty frame, so if the stub were called inside that try block the
+        # AssertionError would presumably be swallowed and the length check
+        # below would be what fails -- confirm.
+        self.assertEqual(len(got), len(cached))
+        self.assertEqual(got.index[-1].date(), self.yesterday)
+
+    def test_tday_forces_refresh_even_when_cache_valid(self) -> None:
+        """A window ending today re-fetches even though a valid cache entry exists."""
+        end_date = self.today.strftime("%Y-%m-%d")
+        cache_key = self.fetcher._get_cache_key("399673", "2018-01-01", end_date)
+
+        # Cached data stops at yesterday, i.e. it lacks today's bar.
+        stale = _make_index_frame(
+            [
+                pd.Timestamp(self.yesterday - timedelta(days=1)),
+                pd.Timestamp(self.yesterday),
+            ]
+        ).set_index("date")
+        self.fetcher._set_cache(cache_key, stale)
+
+        # Count calls to the stubbed akshare function.
+        call_counter = {"n": 0}
+        original = fetcher_module.ak.stock_zh_index_daily
+
+        def _mock_call(symbol: str) -> pd.DataFrame:
+            call_counter["n"] += 1
+            return _make_index_frame([pd.Timestamp(self.yesterday), pd.Timestamp(self.today)])
+
+        fetcher_module.ak.stock_zh_index_daily = _mock_call
+        try:
+            got = self.fetcher.fetch_index_data_v2("399673", "2018-01-01", end_date=end_date)
+        finally:
+            fetcher_module.ak.stock_zh_index_daily = original
+
+        # The source was hit exactly once and the result includes today's bar.
+        self.assertEqual(call_counter["n"], 1)
+        self.assertEqual(got.index[-1].date(), self.today)
+
+    def test_snapshot_merge_writes_back_only_when_cache_exists(self) -> None:
+        """The merged snapshot frame is cached only if the key was already cached."""
+        start_date = "2018-01-01"
+        end_date = self.today.strftime("%Y-%m-%d")
+        cache_key = self.fetcher._get_cache_key("399673", start_date, end_date)
+
+        historical = _make_index_frame([pd.Timestamp(self.yesterday)]).set_index("date")
+        historical.attrs["intraday_snapshot_appended"] = False
+        historical.attrs["intraday_snapshot_timestamp"] = None
+        historical.attrs["historical_latest_bar_date"] = self.yesterday.isoformat()
+
+        # Pre-populate the cache so the writeback has an existing key to update.
+        self.fetcher._set_cache(cache_key, historical)
+
+        # Replace both data sources on the instance with in-memory stubs.
+        self.fetcher.fetch_index_data_v2 = lambda symbol, start_date, end_date: historical.copy()  # type: ignore[method-assign]
+        self.fetcher.fetch_latest_snapshot = lambda symbol: LatestSnapshot(  # type: ignore[method-assign]
+            timestamp=datetime.combine(self.today, datetime.min.time()),
+            open=101.0,
+            high=102.0,
+            low=100.0,
+            close=101.5,
+            volume=0.0,
+        )
+
+        # NOTE(review): fetch_index_data_with_latest_snapshot_v2 returns before
+        # appending when today.weekday() >= 5, so the assertions below appear
+        # to fail when this test runs on a Saturday or Sunday -- confirm.
+        merged = self.fetcher.fetch_index_data_with_latest_snapshot_v2("399673", start_date, end_date=end_date)
+        self.assertTrue(bool(merged.attrs.get("intraday_snapshot_appended", False)))
+        self.assertIn(cache_key, self.fetcher.cache)
+        self.assertEqual(self.fetcher.cache[cache_key].index[-1].date(), self.today)
+
+        # A second fetcher whose cache was never populated must not gain the key.
+        other = DataFetcherV2()
+        other.fetch_index_data_v2 = lambda symbol, start_date, end_date: historical.copy()  # type: ignore[method-assign]
+        other.fetch_latest_snapshot = self.fetcher.fetch_latest_snapshot  # type: ignore[method-assign]
+        other.fetch_index_data_with_latest_snapshot_v2("399673", start_date, end_date=end_date)
+        self.assertNotIn(cache_key, other.cache)
+
+
+# Run the tests in this module when the file is executed directly.
+if __name__ == "__main__":
+    unittest.main()
+