|
@@ -1,10 +1,26 @@
|
|
|
import akshare as ak
|
|
import akshare as ak
|
|
|
import pandas as pd
|
|
import pandas as pd
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
|
|
+from dataclasses import dataclass
|
|
|
from datetime import datetime, timedelta
|
|
from datetime import datetime, timedelta
|
|
|
from typing import Optional, Union, List, Tuple
|
|
from typing import Optional, Union, List, Tuple
|
|
|
import time
|
|
import time
|
|
|
|
|
|
|
|
|
|
+try:
|
|
|
|
|
+ import requests
|
|
|
|
|
+except ImportError: # pragma: no cover
|
|
|
|
|
+ requests = None
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@dataclass
|
|
|
|
|
+class LatestSnapshot:
|
|
|
|
|
+ timestamp: datetime
|
|
|
|
|
+ open: float
|
|
|
|
|
+ high: float
|
|
|
|
|
+ low: float
|
|
|
|
|
+ close: float
|
|
|
|
|
+ volume: float = 0.0
|
|
|
|
|
+
|
|
|
class DataFetcherV2:
|
|
class DataFetcherV2:
|
|
|
"""
|
|
"""
|
|
|
数据获取类V2 - 基于用户提供的优化方法
|
|
数据获取类V2 - 基于用户提供的优化方法
|
|
@@ -30,6 +46,27 @@ class DataFetcherV2:
|
|
|
"""设置缓存"""
|
|
"""设置缓存"""
|
|
|
self.cache[cache_key] = data
|
|
self.cache[cache_key] = data
|
|
|
self.cache_expiry[cache_key] = time.time() + self.cache_duration
|
|
self.cache_expiry[cache_key] = time.time() + self.cache_duration
|
|
|
|
|
+
|
|
|
|
|
+ def _should_force_refresh_t_day(self, end_date: str) -> bool:
|
|
|
|
|
+ """
|
|
|
|
|
+ 是否应强制刷新当日(T日)请求。
|
|
|
|
|
+ 当请求窗口覆盖今天时,不直接使用旧缓存,避免拿到过期的当日数据。
|
|
|
|
|
+ """
|
|
|
|
|
+ try:
|
|
|
|
|
+ return pd.to_datetime(end_date).date() >= datetime.now().date()
|
|
|
|
|
+ except Exception:
|
|
|
|
|
+ return False
|
|
|
|
|
+
|
|
|
|
|
+ def _writeback_cache_if_exists(self, cache_key: str, data: pd.DataFrame) -> None:
|
|
|
|
|
+ """
|
|
|
|
|
+ 仅在缓存键已存在时回写缓存。
|
|
|
|
|
+ 若缓存键不存在,则跳过(符合“有则回写、无则算了”)。
|
|
|
|
|
+ """
|
|
|
|
|
+ if cache_key not in self.cache:
|
|
|
|
|
+ return
|
|
|
|
|
+ payload = data.copy()
|
|
|
|
|
+ payload.attrs.update(data.attrs)
|
|
|
|
|
+ self._set_cache(cache_key, payload)
|
|
|
|
|
|
|
|
def _format_index_code(self, symbol: str) -> str:
|
|
def _format_index_code(self, symbol: str) -> str:
|
|
|
"""
|
|
"""
|
|
@@ -56,6 +93,70 @@ class DataFetcherV2:
|
|
|
|
|
|
|
|
# 如果已经是格式化好的代码
|
|
# 如果已经是格式化好的代码
|
|
|
return symbol.lower()
|
|
return symbol.lower()
|
|
|
|
|
+
|
|
|
|
|
+ def _infer_realtime_prefix(self, code: str) -> str:
|
|
|
|
|
+ if code.startswith("399"):
|
|
|
|
|
+ return "sz"
|
|
|
|
|
+ if code.startswith("000"):
|
|
|
|
|
+ return "sh"
|
|
|
|
|
+ if code.startswith(("30", "00", "15")):
|
|
|
|
|
+ return "sz"
|
|
|
|
|
+ if code.startswith(("60", "68")):
|
|
|
|
|
+ return "sh"
|
|
|
|
|
+ return "sz"
|
|
|
|
|
+
|
|
|
|
|
+ def fetch_latest_snapshot(self, symbol: str) -> Optional[LatestSnapshot]:
|
|
|
|
|
+ if requests is None:
|
|
|
|
|
+ return None
|
|
|
|
|
+
|
|
|
|
|
+ formatted_symbol = self._format_index_code(symbol)
|
|
|
|
|
+ if formatted_symbol.startswith(("sz", "sh")):
|
|
|
|
|
+ prefix = formatted_symbol[:2]
|
|
|
|
|
+ code = formatted_symbol[2:]
|
|
|
|
|
+ else:
|
|
|
|
|
+ code = formatted_symbol
|
|
|
|
|
+ prefix = self._infer_realtime_prefix(code)
|
|
|
|
|
+
|
|
|
|
|
+ url = f"http://hq.sinajs.cn/list={prefix}{code}"
|
|
|
|
|
+ headers = {
|
|
|
|
|
+ "User-Agent": "Mozilla/5.0",
|
|
|
|
|
+ "Referer": "http://finance.sina.com.cn",
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ response = requests.get(url, headers=headers, timeout=10)
|
|
|
|
|
+ response.raise_for_status()
|
|
|
|
|
+ except Exception:
|
|
|
|
|
+ return None
|
|
|
|
|
+
|
|
|
|
|
+ response.encoding = "gbk"
|
|
|
|
|
+ text = response.text
|
|
|
|
|
+ if '"' not in text:
|
|
|
|
|
+ return None
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ payload = text.split('"')[1].split(",")
|
|
|
|
|
+ if len(payload) < 6:
|
|
|
|
|
+ return None
|
|
|
|
|
+ open_price = float(payload[1])
|
|
|
|
|
+ prev_close = float(payload[2])
|
|
|
|
|
+ close_price = float(payload[3])
|
|
|
|
|
+ high_price = float(payload[4])
|
|
|
|
|
+ low_price = float(payload[5])
|
|
|
|
|
+ except (ValueError, IndexError):
|
|
|
|
|
+ return None
|
|
|
|
|
+
|
|
|
|
|
+ if close_price <= 0 or prev_close <= 0 or high_price <= 0 or low_price <= 0:
|
|
|
|
|
+ return None
|
|
|
|
|
+
|
|
|
|
|
+ return LatestSnapshot(
|
|
|
|
|
+ timestamp=datetime.now(),
|
|
|
|
|
+ open=open_price,
|
|
|
|
|
+ high=high_price,
|
|
|
|
|
+ low=low_price,
|
|
|
|
|
+ close=close_price,
|
|
|
|
|
+ volume=0.0,
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
def fetch_index_data_v2(self,
|
|
def fetch_index_data_v2(self,
|
|
|
symbol: str,
|
|
symbol: str,
|
|
@@ -72,12 +173,11 @@ class DataFetcherV2:
|
|
|
Returns:
|
|
Returns:
|
|
|
包含OHLCV数据的DataFrame,索引为日期
|
|
包含OHLCV数据的DataFrame,索引为日期
|
|
|
"""
|
|
"""
|
|
|
- if end_date is None:
|
|
|
|
|
- end_date = datetime.now().strftime('%Y-%m-%d')
|
|
|
|
|
-
|
|
|
|
|
- cache_key = self._get_cache_key(symbol, start_date, end_date)
|
|
|
|
|
-
|
|
|
|
|
- if self._is_cache_valid(cache_key):
|
|
|
|
|
|
|
+ resolved_end_date = end_date or datetime.now().strftime('%Y-%m-%d')
|
|
|
|
|
+ cache_key = self._get_cache_key(symbol, start_date, resolved_end_date)
|
|
|
|
|
+ force_refresh_t_day = self._should_force_refresh_t_day(resolved_end_date)
|
|
|
|
|
+
|
|
|
|
|
+ if self._is_cache_valid(cache_key) and not force_refresh_t_day:
|
|
|
return self.cache[cache_key].copy()
|
|
return self.cache[cache_key].copy()
|
|
|
|
|
|
|
|
try:
|
|
try:
|
|
@@ -98,7 +198,7 @@ class DataFetcherV2:
|
|
|
|
|
|
|
|
# 筛选日期范围
|
|
# 筛选日期范围
|
|
|
start_datetime = pd.to_datetime(start_date)
|
|
start_datetime = pd.to_datetime(start_date)
|
|
|
- end_datetime = pd.to_datetime(end_date)
|
|
|
|
|
|
|
+ end_datetime = pd.to_datetime(resolved_end_date)
|
|
|
|
|
|
|
|
# 先筛选出指定日期之后的数据
|
|
# 先筛选出指定日期之后的数据
|
|
|
filtered_df = all_data_df[all_data_df.index >= start_datetime]
|
|
filtered_df = all_data_df[all_data_df.index >= start_datetime]
|
|
@@ -115,12 +215,68 @@ class DataFetcherV2:
|
|
|
print(f"获取数据量: {len(filtered_df)} 条")
|
|
print(f"获取数据量: {len(filtered_df)} 条")
|
|
|
|
|
|
|
|
# 缓存数据
|
|
# 缓存数据
|
|
|
|
|
+ filtered_df.attrs["intraday_snapshot_appended"] = False
|
|
|
|
|
+ filtered_df.attrs["intraday_snapshot_timestamp"] = None
|
|
|
|
|
+ filtered_df.attrs["historical_latest_bar_date"] = filtered_df.index[-1].date().isoformat()
|
|
|
self._set_cache(cache_key, filtered_df)
|
|
self._set_cache(cache_key, filtered_df)
|
|
|
return filtered_df.copy()
|
|
return filtered_df.copy()
|
|
|
|
|
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
print(f"Error fetching index data for {symbol}: {str(e)}")
|
|
print(f"Error fetching index data for {symbol}: {str(e)}")
|
|
|
return pd.DataFrame()
|
|
return pd.DataFrame()
|
|
|
|
|
+
|
|
|
|
|
+ def fetch_index_data_with_latest_snapshot_v2(
|
|
|
|
|
+ self,
|
|
|
|
|
+ symbol: str,
|
|
|
|
|
+ start_date: str = "2018-01-01",
|
|
|
|
|
+ end_date: Optional[str] = None,
|
|
|
|
|
+ ) -> pd.DataFrame:
|
|
|
|
|
+ resolved_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
|
|
|
|
+ cache_key = self._get_cache_key(symbol, start_date, resolved_end_date)
|
|
|
|
|
+
|
|
|
|
|
+ frame = self.fetch_index_data_v2(symbol=symbol, start_date=start_date, end_date=resolved_end_date)
|
|
|
|
|
+ if frame.empty:
|
|
|
|
|
+ return frame
|
|
|
|
|
+
|
|
|
|
|
+ today = datetime.now().date()
|
|
|
|
|
+ historical_latest_bar_date = frame.index[-1].date().isoformat()
|
|
|
|
|
+
|
|
|
|
|
+ frame.attrs["intraday_snapshot_appended"] = False
|
|
|
|
|
+ frame.attrs["intraday_snapshot_timestamp"] = None
|
|
|
|
|
+ frame.attrs["historical_latest_bar_date"] = historical_latest_bar_date
|
|
|
|
|
+
|
|
|
|
|
+ if end_date is not None and pd.to_datetime(end_date).date() < today:
|
|
|
|
|
+ return frame
|
|
|
|
|
+ if frame.index[-1].date() >= today:
|
|
|
|
|
+ return frame
|
|
|
|
|
+ if today.weekday() >= 5:
|
|
|
|
|
+ return frame
|
|
|
|
|
+
|
|
|
|
|
+ snapshot = self.fetch_latest_snapshot(symbol)
|
|
|
|
|
+ if snapshot is None:
|
|
|
|
|
+ return frame
|
|
|
|
|
+
|
|
|
|
|
+ latest_row = pd.DataFrame(
|
|
|
|
|
+ [
|
|
|
|
|
+ {
|
|
|
|
|
+ "open": snapshot.open,
|
|
|
|
|
+ "high": snapshot.high,
|
|
|
|
|
+ "low": snapshot.low,
|
|
|
|
|
+ "close": snapshot.close,
|
|
|
|
|
+ "volume": snapshot.volume,
|
|
|
|
|
+ }
|
|
|
|
|
+ ],
|
|
|
|
|
+ index=pd.DatetimeIndex([pd.Timestamp(snapshot.timestamp)]),
|
|
|
|
|
+ )
|
|
|
|
|
+ latest_row.index.name = "date"
|
|
|
|
|
+
|
|
|
|
|
+ merged = pd.concat([frame, latest_row])
|
|
|
|
|
+ merged = merged[~merged.index.duplicated(keep="last")].sort_index()
|
|
|
|
|
+ merged.attrs["intraday_snapshot_appended"] = True
|
|
|
|
|
+ merged.attrs["intraday_snapshot_timestamp"] = snapshot.timestamp.isoformat(timespec="seconds")
|
|
|
|
|
+ merged.attrs["historical_latest_bar_date"] = historical_latest_bar_date
|
|
|
|
|
+ self._writeback_cache_if_exists(cache_key, merged)
|
|
|
|
|
+ return merged
|
|
|
|
|
|
|
|
def fetch_stock_data_v2(self,
|
|
def fetch_stock_data_v2(self,
|
|
|
symbol: str,
|
|
symbol: str,
|
|
@@ -317,4 +473,4 @@ class DataManagerV2:
|
|
|
print(f"{symbol}:")
|
|
print(f"{symbol}:")
|
|
|
print(f" Period: {info['start_date']} to {info['end_date']}")
|
|
print(f" Period: {info['start_date']} to {info['end_date']}")
|
|
|
print(f" Total bars: {info['total_bars']}")
|
|
print(f" Total bars: {info['total_bars']}")
|
|
|
- print("="*70)
|
|
|
|
|
|
|
+ print("="*70)
|