Browse Source

Refresh T-day fetch and conditionally write back cache

erwin 1 month ago
parent
commit
77ff5d4c17

+ 164 - 8
dragon/data_fetcher_v2.py

@@ -1,10 +1,26 @@
 import akshare as ak
 import akshare as ak
 import pandas as pd
 import pandas as pd
 import numpy as np
 import numpy as np
+from dataclasses import dataclass
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
 from typing import Optional, Union, List, Tuple
 from typing import Optional, Union, List, Tuple
 import time
 import time
 
 
+try:
+    import requests
+except ImportError:  # pragma: no cover
+    requests = None
+
+
+@dataclass
+class LatestSnapshot:
+    """Latest-quote snapshot shaped like a single OHLC(V) bar.
+
+    In this file, instances are produced by
+    ``DataFetcherV2.fetch_latest_snapshot`` and consumed by
+    ``DataFetcherV2.fetch_index_data_with_latest_snapshot_v2``, which appends
+    the snapshot as the newest row of a daily-bar DataFrame.
+    """
+    # Moment the snapshot was taken; fetch_latest_snapshot fills this with the
+    # local wall-clock time of the fetch (datetime.now()).
+    timestamp: datetime
+    # Opening price reported by the quote source.
+    open: float
+    # Highest price reported by the quote source.
+    high: float
+    # Lowest price reported by the quote source.
+    low: float
+    # Latest traded price; stands in for the close of the provisional bar.
+    close: float
+    # Traded volume; fetch_latest_snapshot always passes 0.0 for this field.
+    volume: float = 0.0
+
 class DataFetcherV2:
 class DataFetcherV2:
     """
     """
     数据获取类V2 - 基于用户提供的优化方法
     数据获取类V2 - 基于用户提供的优化方法
@@ -30,6 +46,27 @@ class DataFetcherV2:
         """设置缓存"""
         """设置缓存"""
         self.cache[cache_key] = data
         self.cache[cache_key] = data
         self.cache_expiry[cache_key] = time.time() + self.cache_duration
         self.cache_expiry[cache_key] = time.time() + self.cache_duration
+
+    def _should_force_refresh_t_day(self, end_date: str) -> bool:
+        """
+        Tell whether a request must bypass the cache because it covers today (T-day).
+
+        A window whose end date is today or later may be served stale same-day
+        data from the cache, so callers use this flag to force a fresh fetch.
+
+        Args:
+            end_date: End date of the requested window, in any form
+                ``pd.to_datetime`` can parse.
+
+        Returns:
+            True when ``end_date`` falls on or after today's local date;
+            False otherwise, including when ``end_date`` cannot be evaluated.
+        """
+        try:
+            requested_end = pd.to_datetime(end_date).date()
+            current_day = datetime.now().date()
+            covers_today = requested_end >= current_day
+        except Exception:
+            # Unparseable input is treated as "not T-day": normal cache rules apply.
+            covers_today = False
+        return covers_today
+
+    def _writeback_cache_if_exists(self, cache_key: str, data: pd.DataFrame) -> None:
+        """
+        Overwrite an existing cache entry with ``data``; do nothing for unknown keys.
+
+        This never creates a new entry ("write back if present, otherwise skip").
+
+        Args:
+            cache_key: Key to look up in ``self.cache``.
+            data: Frame to store. A copy is cached, so later mutation of
+                ``data`` by the caller does not affect the cached entry.
+        """
+        if cache_key in self.cache:
+            cached_copy = data.copy()
+            # Re-apply the metadata explicitly so the cached copy carries the
+            # same attrs as the frame handed in.
+            cached_copy.attrs.update(data.attrs)
+            self._set_cache(cache_key, cached_copy)
     
     
     def _format_index_code(self, symbol: str) -> str:
     def _format_index_code(self, symbol: str) -> str:
         """
         """
@@ -56,6 +93,70 @@ class DataFetcherV2:
         
         
         # 如果已经是格式化好的代码
         # 如果已经是格式化好的代码
         return symbol.lower()
         return symbol.lower()
+
+    def _infer_realtime_prefix(self, code: str) -> str:
+        """
+        Guess the exchange prefix ("sz" or "sh") for a bare numeric code.
+
+        Args:
+            code: Security/index code without an exchange prefix.
+
+        Returns:
+            "sz" or "sh" according to the leading digits; "sz" when no rule matches.
+        """
+        # Rules are evaluated top to bottom; "000" must be tested before the
+        # broader "00" prefix, otherwise it would never map to "sh".
+        prefix_rules = (
+            (("399",), "sz"),
+            (("000",), "sh"),
+            (("30", "00", "15"), "sz"),
+            (("60", "68"), "sh"),
+        )
+        for leading_digits, exchange in prefix_rules:
+            if code.startswith(leading_digits):
+                return exchange
+        return "sz"
+
+    def fetch_latest_snapshot(self, symbol: str) -> Optional[LatestSnapshot]:
+        """
+        Fetch the latest quote for ``symbol`` from the Sina quote endpoint.
+
+        This is a best-effort call: every failure mode yields ``None`` instead
+        of raising, so callers can simply fall back to historical data.
+
+        Args:
+            symbol: Index/stock code. It is normalised with
+                ``_format_index_code``; when the result carries no "sz"/"sh"
+                prefix, one is inferred with ``_infer_realtime_prefix``.
+
+        Returns:
+            A ``LatestSnapshot`` stamped with the local fetch time and
+            ``volume=0.0``, or ``None`` when ``requests`` is not installed,
+            the HTTP request fails, the response cannot be parsed, or any of
+            the parsed prices is not a positive finite number.
+        """
+        if requests is None:
+            return None
+
+        formatted_symbol = self._format_index_code(symbol)
+        if formatted_symbol.startswith(("sz", "sh")):
+            prefix = formatted_symbol[:2]
+            code = formatted_symbol[2:]
+        else:
+            code = formatted_symbol
+            prefix = self._infer_realtime_prefix(code)
+
+        url = f"http://hq.sinajs.cn/list={prefix}{code}"
+        headers = {
+            "User-Agent": "Mozilla/5.0",
+            "Referer": "http://finance.sina.com.cn",
+        }
+
+        try:
+            response = requests.get(url, headers=headers, timeout=10)
+            response.raise_for_status()
+        except Exception:
+            # Deliberately broad: network problems must never break the caller.
+            return None
+
+        response.encoding = "gbk"
+        text = response.text
+        if '"' not in text:
+            return None
+
+        try:
+            # The quote is the comma-separated text between the first pair of
+            # double quotes; fields 1..5 are read as open, previous close,
+            # latest price, high and low.
+            payload = text.split('"')[1].split(",")
+            if len(payload) < 6:
+                return None
+            open_price = float(payload[1])
+            prev_close = float(payload[2])
+            close_price = float(payload[3])
+            high_price = float(payload[4])
+            low_price = float(payload[5])
+        except (ValueError, IndexError):
+            return None
+
+        # Every price must be a positive finite number. The chained comparison
+        # also rejects NaN (which a plain `<= 0` test lets through), and the
+        # open price is validated too so a zero/invalid open can never be
+        # written into the provisional bar.
+        prices = (open_price, prev_close, close_price, high_price, low_price)
+        if not all(0 < price < float("inf") for price in prices):
+            return None
+
+        return LatestSnapshot(
+            timestamp=datetime.now(),
+            open=open_price,
+            high=high_price,
+            low=low_price,
+            close=close_price,
+            volume=0.0,
+        )
     
     
     def fetch_index_data_v2(self, 
     def fetch_index_data_v2(self, 
                             symbol: str, 
                             symbol: str, 
@@ -72,12 +173,11 @@ class DataFetcherV2:
         Returns:
         Returns:
             包含OHLCV数据的DataFrame,索引为日期
             包含OHLCV数据的DataFrame,索引为日期
         """
         """
-        if end_date is None:
-            end_date = datetime.now().strftime('%Y-%m-%d')
-        
-        cache_key = self._get_cache_key(symbol, start_date, end_date)
-        
-        if self._is_cache_valid(cache_key):
+        resolved_end_date = end_date or datetime.now().strftime('%Y-%m-%d')
+        cache_key = self._get_cache_key(symbol, start_date, resolved_end_date)
+        force_refresh_t_day = self._should_force_refresh_t_day(resolved_end_date)
+
+        if self._is_cache_valid(cache_key) and not force_refresh_t_day:
             return self.cache[cache_key].copy()
             return self.cache[cache_key].copy()
         
         
         try:
         try:
@@ -98,7 +198,7 @@ class DataFetcherV2:
             
             
             # 筛选日期范围
             # 筛选日期范围
             start_datetime = pd.to_datetime(start_date)
             start_datetime = pd.to_datetime(start_date)
-            end_datetime = pd.to_datetime(end_date)
+            end_datetime = pd.to_datetime(resolved_end_date)
             
             
             # 先筛选出指定日期之后的数据
             # 先筛选出指定日期之后的数据
             filtered_df = all_data_df[all_data_df.index >= start_datetime]
             filtered_df = all_data_df[all_data_df.index >= start_datetime]
@@ -115,12 +215,68 @@ class DataFetcherV2:
             print(f"获取数据量: {len(filtered_df)} 条")
             print(f"获取数据量: {len(filtered_df)} 条")
             
             
             # 缓存数据
             # 缓存数据
+            filtered_df.attrs["intraday_snapshot_appended"] = False
+            filtered_df.attrs["intraday_snapshot_timestamp"] = None
+            filtered_df.attrs["historical_latest_bar_date"] = filtered_df.index[-1].date().isoformat()
             self._set_cache(cache_key, filtered_df)
             self._set_cache(cache_key, filtered_df)
             return filtered_df.copy()
             return filtered_df.copy()
             
             
         except Exception as e:
         except Exception as e:
             print(f"Error fetching index data for {symbol}: {str(e)}")
             print(f"Error fetching index data for {symbol}: {str(e)}")
             return pd.DataFrame()
             return pd.DataFrame()
+
+    def fetch_index_data_with_latest_snapshot_v2(
+        self,
+        symbol: str,
+        start_date: str = "2018-01-01",
+        end_date: Optional[str] = None,
+    ) -> pd.DataFrame:
+        """
+        Fetch daily index bars and append a live snapshot row when today's bar is missing.
+
+        Args:
+            symbol: Index code, passed through to ``fetch_index_data_v2`` and
+                ``fetch_latest_snapshot``.
+            start_date: First date of the requested window.
+            end_date: Last date of the requested window; ``None`` or an empty
+                string means "today".
+
+        Returns:
+            The historical frame, optionally extended by one snapshot row.
+            ``DataFrame.attrs`` always carries:
+            ``intraday_snapshot_appended`` (bool),
+            ``intraday_snapshot_timestamp`` (ISO string or None) and
+            ``historical_latest_bar_date`` (ISO date of the last historical bar).
+            An empty frame from ``fetch_index_data_v2`` is returned unchanged.
+        """
+        resolved_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
+        cache_key = self._get_cache_key(symbol, start_date, resolved_end_date)
+
+        frame = self.fetch_index_data_v2(symbol=symbol, start_date=start_date, end_date=resolved_end_date)
+        if frame.empty:
+            return frame
+
+        today = datetime.now().date()
+        historical_latest_bar_date = frame.index[-1].date().isoformat()
+
+        frame.attrs["intraday_snapshot_appended"] = False
+        frame.attrs["intraday_snapshot_timestamp"] = None
+        frame.attrs["historical_latest_bar_date"] = historical_latest_bar_date
+
+        # No snapshot for windows that end in the past. Truthiness (not
+        # `is not None`) mirrors the `end_date or ...` fallback above: an empty
+        # string means "today", and parsing "" would give NaT, whose .date()
+        # raises instead of comparing.
+        if end_date and pd.to_datetime(end_date).date() < today:
+            return frame
+        # Today's bar is already present in the historical data.
+        if frame.index[-1].date() >= today:
+            return frame
+        # Saturday/Sunday: no snapshot is appended.
+        if today.weekday() >= 5:
+            return frame
+
+        snapshot = self.fetch_latest_snapshot(symbol)
+        if snapshot is None:
+            return frame
+
+        # NOTE: the snapshot row is indexed by the full snapshot timestamp
+        # (including time of day), not by a normalised date.
+        latest_row = pd.DataFrame(
+            [
+                {
+                    "open": snapshot.open,
+                    "high": snapshot.high,
+                    "low": snapshot.low,
+                    "close": snapshot.close,
+                    "volume": snapshot.volume,
+                }
+            ],
+            index=pd.DatetimeIndex([pd.Timestamp(snapshot.timestamp)]),
+        )
+        latest_row.index.name = "date"
+
+        merged = pd.concat([frame, latest_row])
+        merged = merged[~merged.index.duplicated(keep="last")].sort_index()
+        # attrs are set explicitly on the merged frame rather than relying on
+        # concat to carry them over.
+        merged.attrs["intraday_snapshot_appended"] = True
+        merged.attrs["intraday_snapshot_timestamp"] = snapshot.timestamp.isoformat(timespec="seconds")
+        merged.attrs["historical_latest_bar_date"] = historical_latest_bar_date
+        # Only refreshes an entry that already exists; never creates one.
+        self._writeback_cache_if_exists(cache_key, merged)
+        return merged
     
     
     def fetch_stock_data_v2(self, 
     def fetch_stock_data_v2(self, 
                            symbol: str, 
                            symbol: str, 
@@ -317,4 +473,4 @@ class DataManagerV2:
             print(f"{symbol}:")
             print(f"{symbol}:")
             print(f"  Period: {info['start_date']} to {info['end_date']}")
             print(f"  Period: {info['start_date']} to {info['end_date']}")
             print(f"  Total bars: {info['total_bars']}")
             print(f"  Total bars: {info['total_bars']}")
-        print("="*70)
+        print("="*70)

+ 4 - 0
research/dragon/v2/MEMORY.md

@@ -1112,3 +1112,7 @@
 - one-click tracking entry `update_dragon_reports.ps1` was upgraded:
 - one-click tracking entry `update_dragon_reports.ps1` was upgraded:
 - now prints latest bar + rollout decision + active/fallback branch + gate counts.
 - now prints latest bar + rollout decision + active/fallback branch + gate counts.
 - supports `-StrictGate` (exit `2` when decision is not `FORWARD_OK`) and `-OpenReport`.
 - supports `-StrictGate` (exit `2` when decision is not `FORWARD_OK`) and `-OpenReport`.
+- T-day data freshness/caching fix was applied in `dragon/data_fetcher_v2.py`:
+- when request end date covers today, `fetch_index_data_v2` now bypasses stale cache and refreshes source data.
+- `fetch_index_data_with_latest_snapshot_v2` now writes merged latest snapshot back only when cache key already exists (skip if missing).
+- added test file `tests/test_data_fetcher_tday_cache.py`; full suite reached `20` passing tests.

+ 10 - 0
research/dragon/v2/memory/2026-04-09.md

@@ -124,3 +124,13 @@
 - added optional switches:
 - added optional switches:
 - `-StrictGate` (non-`FORWARD_OK` exits with code `2`)
 - `-StrictGate` (non-`FORWARD_OK` exits with code `2`)
 - `-OpenReport` (opens `dragon_reports_index.html` after run).
 - `-OpenReport` (opens `dragon_reports_index.html` after run).
+- Updated `dragon/data_fetcher_v2.py` for T-day freshness + cache writeback behavior:
+- `fetch_index_data_v2(...)` now forces refresh when `end_date` covers today (`T日`) instead of reusing stale in-memory cache.
+- `fetch_index_data_with_latest_snapshot_v2(...)` now writes merged latest-snapshot data back only if the matching cache key already exists.
+- If cache key does not exist, writeback is skipped by design (`有则回写,无则跳过`).
+- Added regression tests:
+- `tests/test_data_fetcher_tday_cache.py`
+- covers: non-T-day cache hit, T-day forced refresh, conditional cache writeback.
+- Validation after change:
+- `py -3 -m unittest discover -s tests -v` passed (`20` tests).
+- `py -3 dragon_daily_signal_pipeline.py` smoke run passed.

+ 134 - 0
research/dragon/v2/tests/test_data_fetcher_tday_cache.py

@@ -0,0 +1,134 @@
+from __future__ import annotations
+
+from datetime import datetime, timedelta
+from pathlib import Path
+import sys
+import unittest
+
+import pandas as pd
+
+
+# Make the "dragon" directory importable so that data_fetcher_v2 can be
+# imported as a top-level module below. parents[4] walks four levels up from
+# this file's directory (presumably the repository root -- confirm if the
+# test file is ever moved).
+DRAGON_ROOT = Path(__file__).resolve().parents[4] / "dragon"
+if str(DRAGON_ROOT) not in sys.path:
+    sys.path.append(str(DRAGON_ROOT))
+
+import data_fetcher_v2 as fetcher_module  # noqa: E402
+from data_fetcher_v2 import DataFetcherV2, LatestSnapshot  # noqa: E402
+
+
+def _make_index_frame(dates: list[pd.Timestamp]) -> pd.DataFrame:
+    """Build a synthetic OHLCV frame with one row per entry in ``dates``.
+
+    Row ``i`` uses a base price of ``100 + i``: open = base, high = base + 1,
+    low = base - 1, close = base + 0.5 and volume = 10_000 + i. The result
+    keeps ``date`` as a regular datetime column; callers set the index
+    themselves when they need one.
+    """
+    records = [
+        {
+            "date": day,
+            "open": float(100 + position),
+            "high": float(100 + position) + 1.0,
+            "low": float(100 + position) - 1.0,
+            "close": float(100 + position) + 0.5,
+            "volume": 10_000 + position,
+        }
+        for position, day in enumerate(dates)
+    ]
+    frame = pd.DataFrame(records)
+    frame["date"] = pd.to_datetime(frame["date"])
+    return frame
+
+
+class TestDataFetcherTDayCache(unittest.TestCase):
+    """Regression tests for T-day cache bypass and conditional cache writeback."""
+
+    # Fixed "now" seen by the fetcher module: Wednesday 2026-04-08 10:30.
+    # A weekday is required because fetch_index_data_with_latest_snapshot_v2
+    # returns early (no snapshot) when today is Saturday or Sunday; with the
+    # real clock the snapshot test would fail on weekends and could also race
+    # around midnight.
+    FROZEN_NOW = datetime(2026, 4, 8, 10, 30)
+
+    def setUp(self) -> None:
+        frozen_now = self.FROZEN_NOW
+
+        class _FrozenDatetime(datetime):
+            """datetime stand-in whose now() always returns the frozen instant."""
+
+            @classmethod
+            def now(cls, tz=None):
+                return frozen_now
+
+        # Swap the datetime class referenced by data_fetcher_v2 and restore it
+        # after each test, whether the test passes or fails.
+        original_datetime = fetcher_module.datetime
+        fetcher_module.datetime = _FrozenDatetime
+        self.addCleanup(setattr, fetcher_module, "datetime", original_datetime)
+
+        self.fetcher = DataFetcherV2()
+        self.today = frozen_now.date()
+        self.yesterday = self.today - timedelta(days=1)
+
+    def test_non_tday_uses_valid_cache(self) -> None:
+        """A window ending before today is served from a valid cache entry."""
+        end_date = self.yesterday.strftime("%Y-%m-%d")
+        cache_key = self.fetcher._get_cache_key("399673", "2018-01-01", end_date)
+
+        cached = _make_index_frame(
+            [
+                pd.Timestamp(self.yesterday - timedelta(days=1)),
+                pd.Timestamp(self.yesterday),
+            ]
+        ).set_index("date")
+        self.fetcher._set_cache(cache_key, cached)
+
+        original = fetcher_module.ak.stock_zh_index_daily
+
+        def _fail_call(symbol: str) -> pd.DataFrame:
+            raise AssertionError("ak.stock_zh_index_daily should not be called when non-tday cache is valid")
+
+        fetcher_module.ak.stock_zh_index_daily = _fail_call
+        try:
+            got = self.fetcher.fetch_index_data_v2("399673", "2018-01-01", end_date=end_date)
+        finally:
+            fetcher_module.ak.stock_zh_index_daily = original
+
+        self.assertEqual(len(got), len(cached))
+        self.assertEqual(got.index[-1].date(), self.yesterday)
+
+    def test_tday_forces_refresh_even_when_cache_valid(self) -> None:
+        """A window ending today refetches from the source despite a valid cache."""
+        end_date = self.today.strftime("%Y-%m-%d")
+        cache_key = self.fetcher._get_cache_key("399673", "2018-01-01", end_date)
+
+        stale = _make_index_frame(
+            [
+                pd.Timestamp(self.yesterday - timedelta(days=1)),
+                pd.Timestamp(self.yesterday),
+            ]
+        ).set_index("date")
+        self.fetcher._set_cache(cache_key, stale)
+
+        call_counter = {"n": 0}
+        original = fetcher_module.ak.stock_zh_index_daily
+
+        def _mock_call(symbol: str) -> pd.DataFrame:
+            call_counter["n"] += 1
+            return _make_index_frame([pd.Timestamp(self.yesterday), pd.Timestamp(self.today)])
+
+        fetcher_module.ak.stock_zh_index_daily = _mock_call
+        try:
+            got = self.fetcher.fetch_index_data_v2("399673", "2018-01-01", end_date=end_date)
+        finally:
+            fetcher_module.ak.stock_zh_index_daily = original
+
+        self.assertEqual(call_counter["n"], 1)
+        self.assertEqual(got.index[-1].date(), self.today)
+
+    def test_snapshot_merge_writes_back_only_when_cache_exists(self) -> None:
+        """The merged snapshot frame is cached only if the cache key already exists."""
+        start_date = "2018-01-01"
+        end_date = self.today.strftime("%Y-%m-%d")
+        cache_key = self.fetcher._get_cache_key("399673", start_date, end_date)
+
+        historical = _make_index_frame([pd.Timestamp(self.yesterday)]).set_index("date")
+        historical.attrs["intraday_snapshot_appended"] = False
+        historical.attrs["intraday_snapshot_timestamp"] = None
+        historical.attrs["historical_latest_bar_date"] = self.yesterday.isoformat()
+
+        self.fetcher._set_cache(cache_key, historical)
+
+        self.fetcher.fetch_index_data_v2 = lambda symbol, start_date, end_date: historical.copy()  # type: ignore[method-assign]
+        self.fetcher.fetch_latest_snapshot = lambda symbol: LatestSnapshot(  # type: ignore[method-assign]
+            timestamp=datetime.combine(self.today, datetime.min.time()),
+            open=101.0,
+            high=102.0,
+            low=100.0,
+            close=101.5,
+            volume=0.0,
+        )
+
+        merged = self.fetcher.fetch_index_data_with_latest_snapshot_v2("399673", start_date, end_date=end_date)
+        self.assertTrue(bool(merged.attrs.get("intraday_snapshot_appended", False)))
+        self.assertIn(cache_key, self.fetcher.cache)
+        self.assertEqual(self.fetcher.cache[cache_key].index[-1].date(), self.today)
+
+        # A fetcher with an empty cache must not gain an entry from the merge.
+        other = DataFetcherV2()
+        other.fetch_index_data_v2 = lambda symbol, start_date, end_date: historical.copy()  # type: ignore[method-assign]
+        other.fetch_latest_snapshot = self.fetcher.fetch_latest_snapshot  # type: ignore[method-assign]
+        other.fetch_index_data_with_latest_snapshot_v2("399673", start_date, end_date=end_date)
+        self.assertNotIn(cache_key, other.cache)
+
+
+# Run this module's tests when the file is executed directly.
+if __name__ == "__main__":
+    unittest.main()
+