data_fetcher_v2.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. import akshare as ak
  2. import pandas as pd
  3. import numpy as np
  4. from dataclasses import dataclass
  5. from datetime import datetime, timedelta
  6. from typing import Optional, Union, List, Tuple
  7. import time
  8. try:
  9. import requests
  10. except ImportError: # pragma: no cover
  11. requests = None
  12. @dataclass
  13. class LatestSnapshot:
  14. timestamp: datetime
  15. open: float
  16. high: float
  17. low: float
  18. close: float
  19. volume: float = 0.0
  20. class DataFetcherV2:
  21. """
  22. 数据获取类V2 - 基于用户提供的优化方法
  23. 使用ak.stock_zh_index_daily获取更可靠的数据
  24. """
  25. def __init__(self):
  26. self.cache = {}
  27. self.cache_expiry = {}
  28. self.cache_duration = 3600 # 缓存1小时
  29. def _get_cache_key(self, symbol: str, start_date: str, end_date: str) -> str:
  30. """生成缓存键"""
  31. return f"{symbol}_{start_date}_{end_date}"
  32. def _is_cache_valid(self, cache_key: str) -> bool:
  33. """检查缓存是否有效"""
  34. if cache_key not in self.cache_expiry:
  35. return False
  36. return time.time() < self.cache_expiry[cache_key]
  37. def _set_cache(self, cache_key: str, data: pd.DataFrame):
  38. """设置缓存"""
  39. self.cache[cache_key] = data
  40. self.cache_expiry[cache_key] = time.time() + self.cache_duration
  41. def _should_force_refresh_t_day(self, end_date: str) -> bool:
  42. """
  43. 是否应强制刷新当日(T日)请求。
  44. 当请求窗口覆盖今天时,不直接使用旧缓存,避免拿到过期的当日数据。
  45. """
  46. try:
  47. return pd.to_datetime(end_date).date() >= datetime.now().date()
  48. except Exception:
  49. return False
  50. def _writeback_cache_if_exists(self, cache_key: str, data: pd.DataFrame) -> None:
  51. """
  52. 仅在缓存键已存在时回写缓存。
  53. 若缓存键不存在,则跳过(符合“有则回写、无则算了”)。
  54. """
  55. if cache_key not in self.cache:
  56. return
  57. payload = data.copy()
  58. payload.attrs.update(data.attrs)
  59. self._set_cache(cache_key, payload)
  60. def _format_index_code(self, symbol: str) -> str:
  61. """
  62. 格式化指数代码为akshare标准格式
  63. 例如: 399673 -> sz399673
  64. """
  65. symbol = symbol.strip()
  66. # 移除可能的前缀
  67. if '.' in symbol:
  68. code, exchange = symbol.split('.')
  69. symbol = code
  70. # 确保是6位代码
  71. if len(symbol) == 6:
  72. # 判断交易所并添加前缀
  73. if symbol.startswith(('00', '30')): # 深交所
  74. return f"sz{symbol}"
  75. elif symbol.startswith(('60', '68')): # 上交所
  76. return f"sh{symbol}"
  77. else:
  78. # 其他交易所默认使用sz
  79. return f"sz{symbol}"
  80. # 如果已经是格式化好的代码
  81. return symbol.lower()
  82. def _infer_realtime_prefix(self, code: str) -> str:
  83. if code.startswith("399"):
  84. return "sz"
  85. if code.startswith("000"):
  86. return "sh"
  87. if code.startswith(("30", "00", "15")):
  88. return "sz"
  89. if code.startswith(("60", "68")):
  90. return "sh"
  91. return "sz"
  92. def fetch_latest_snapshot(self, symbol: str) -> Optional[LatestSnapshot]:
  93. if requests is None:
  94. return None
  95. formatted_symbol = self._format_index_code(symbol)
  96. if formatted_symbol.startswith(("sz", "sh")):
  97. prefix = formatted_symbol[:2]
  98. code = formatted_symbol[2:]
  99. else:
  100. code = formatted_symbol
  101. prefix = self._infer_realtime_prefix(code)
  102. url = f"http://hq.sinajs.cn/list={prefix}{code}"
  103. headers = {
  104. "User-Agent": "Mozilla/5.0",
  105. "Referer": "http://finance.sina.com.cn",
  106. }
  107. try:
  108. response = requests.get(url, headers=headers, timeout=10)
  109. response.raise_for_status()
  110. except Exception:
  111. return None
  112. response.encoding = "gbk"
  113. text = response.text
  114. if '"' not in text:
  115. return None
  116. try:
  117. payload = text.split('"')[1].split(",")
  118. if len(payload) < 6:
  119. return None
  120. open_price = float(payload[1])
  121. prev_close = float(payload[2])
  122. close_price = float(payload[3])
  123. high_price = float(payload[4])
  124. low_price = float(payload[5])
  125. except (ValueError, IndexError):
  126. return None
  127. if close_price <= 0 or prev_close <= 0 or high_price <= 0 or low_price <= 0:
  128. return None
  129. return LatestSnapshot(
  130. timestamp=datetime.now(),
  131. open=open_price,
  132. high=high_price,
  133. low=low_price,
  134. close=close_price,
  135. volume=0.0,
  136. )
  137. def fetch_index_data_v2(self,
  138. symbol: str,
  139. start_date: str = "2018-01-01",
  140. end_date: Optional[str] = None) -> pd.DataFrame:
  141. """
  142. 使用优化的方法获取指数数据
  143. Args:
  144. symbol: 指数代码,支持多种格式
  145. start_date: 开始日期,默认2018-01-01
  146. end_date: 结束日期,默认为当前日期
  147. Returns:
  148. 包含OHLCV数据的DataFrame,索引为日期
  149. """
  150. resolved_end_date = end_date or datetime.now().strftime('%Y-%m-%d')
  151. cache_key = self._get_cache_key(symbol, start_date, resolved_end_date)
  152. force_refresh_t_day = self._should_force_refresh_t_day(resolved_end_date)
  153. if self._is_cache_valid(cache_key) and not force_refresh_t_day:
  154. return self.cache[cache_key].copy()
  155. try:
  156. # 格式化指数代码
  157. formatted_code = self._format_index_code(symbol)
  158. print(f"正在获取指数 {formatted_code} 的日线级别历史数据...")
  159. # 使用akshare获取日线数据(用户提供的优化方法)
  160. all_data_df = ak.stock_zh_index_daily(symbol=formatted_code)
  161. if all_data_df.empty:
  162. print(f"Warning: No data found for index {symbol}")
  163. return pd.DataFrame()
  164. # 处理日期列
  165. all_data_df['date'] = pd.to_datetime(all_data_df['date'])
  166. all_data_df.set_index('date', inplace=True)
  167. # 筛选日期范围
  168. start_datetime = pd.to_datetime(start_date)
  169. end_datetime = pd.to_datetime(resolved_end_date)
  170. # 先筛选出指定日期之后的数据
  171. filtered_df = all_data_df[all_data_df.index >= start_datetime]
  172. filtered_df = filtered_df[filtered_df.index <= end_datetime]
  173. if filtered_df.empty:
  174. print(f"Warning: No data found for {symbol} in date range {start_date} to {end_date}")
  175. return pd.DataFrame()
  176. # 标准化列名
  177. filtered_df = self._standardize_columns(filtered_df)
  178. print(f"数据获取成功,期间为 {filtered_df.index[0]} 到 {filtered_df.index[-1]}")
  179. print(f"获取数据量: {len(filtered_df)} 条")
  180. # 缓存数据
  181. filtered_df.attrs["intraday_snapshot_appended"] = False
  182. filtered_df.attrs["intraday_snapshot_timestamp"] = None
  183. filtered_df.attrs["historical_latest_bar_date"] = filtered_df.index[-1].date().isoformat()
  184. self._set_cache(cache_key, filtered_df)
  185. return filtered_df.copy()
  186. except Exception as e:
  187. print(f"Error fetching index data for {symbol}: {str(e)}")
  188. return pd.DataFrame()
  189. def fetch_index_data_with_latest_snapshot_v2(
  190. self,
  191. symbol: str,
  192. start_date: str = "2018-01-01",
  193. end_date: Optional[str] = None,
  194. ) -> pd.DataFrame:
  195. resolved_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
  196. cache_key = self._get_cache_key(symbol, start_date, resolved_end_date)
  197. frame = self.fetch_index_data_v2(symbol=symbol, start_date=start_date, end_date=resolved_end_date)
  198. if frame.empty:
  199. return frame
  200. today = datetime.now().date()
  201. historical_latest_bar_date = frame.index[-1].date().isoformat()
  202. frame.attrs["intraday_snapshot_appended"] = False
  203. frame.attrs["intraday_snapshot_timestamp"] = None
  204. frame.attrs["historical_latest_bar_date"] = historical_latest_bar_date
  205. if end_date is not None and pd.to_datetime(end_date).date() < today:
  206. return frame
  207. if frame.index[-1].date() >= today:
  208. return frame
  209. if today.weekday() >= 5:
  210. return frame
  211. snapshot = self.fetch_latest_snapshot(symbol)
  212. if snapshot is None:
  213. return frame
  214. latest_row = pd.DataFrame(
  215. [
  216. {
  217. "open": snapshot.open,
  218. "high": snapshot.high,
  219. "low": snapshot.low,
  220. "close": snapshot.close,
  221. "volume": snapshot.volume,
  222. }
  223. ],
  224. index=pd.DatetimeIndex([pd.Timestamp(snapshot.timestamp)]),
  225. )
  226. latest_row.index.name = "date"
  227. merged = pd.concat([frame, latest_row])
  228. merged = merged[~merged.index.duplicated(keep="last")].sort_index()
  229. merged.attrs["intraday_snapshot_appended"] = True
  230. merged.attrs["intraday_snapshot_timestamp"] = snapshot.timestamp.isoformat(timespec="seconds")
  231. merged.attrs["historical_latest_bar_date"] = historical_latest_bar_date
  232. self._writeback_cache_if_exists(cache_key, merged)
  233. return merged
  234. def fetch_stock_data_v2(self,
  235. symbol: str,
  236. start_date: str = "2018-01-01",
  237. end_date: Optional[str] = None) -> pd.DataFrame:
  238. """
  239. 获取股票数据(如果akshare支持的话)
  240. """
  241. if end_date is None:
  242. end_date = datetime.now().strftime('%Y-%m-%d')
  243. # 目前主要支持指数数据,股票数据作为占位符
  244. print(f"Warning: Stock data fetching not fully implemented, trying as index: {symbol}")
  245. return self.fetch_index_data_v2(symbol, start_date, end_date)
  246. def _standardize_columns(self, df: pd.DataFrame) -> pd.DataFrame:
  247. """
  248. 标准化DataFrame列名和索引
  249. """
  250. # 确保索引是日期类型
  251. if not isinstance(df.index, pd.DatetimeIndex):
  252. try:
  253. df.index = pd.to_datetime(df.index)
  254. except:
  255. pass
  256. # 创建列名映射(中英文对照)
  257. column_mapping = {
  258. '开盘': 'open',
  259. '收盘': 'close',
  260. '最高': 'high',
  261. '最低': 'low',
  262. '成交量': 'volume',
  263. '成交额': 'amount',
  264. '涨跌幅': 'change_pct',
  265. '涨跌额': 'change',
  266. '振幅': 'amplitude',
  267. '换手率': 'turnover'
  268. }
  269. # 重命名列
  270. df = df.rename(columns=column_mapping)
  271. # 确保必要的列存在
  272. required_columns = ['open', 'high', 'low', 'close', 'volume']
  273. for col in required_columns:
  274. if col not in df.columns:
  275. # 尝试从中文列名获取
  276. chinese_map = {
  277. 'open': '开盘', 'high': '最高', 'low': '最低',
  278. 'close': '收盘', 'volume': '成交量'
  279. }
  280. if chinese_map[col] in df.columns:
  281. df[col] = df[chinese_map[col]]
  282. else:
  283. print(f"Warning: Missing column {col}, filling with NaN")
  284. df[col] = np.nan
  285. # 选择并排序列
  286. result_columns = [col for col in required_columns if col in df.columns]
  287. other_columns = [col for col in df.columns if col not in required_columns]
  288. df = df[result_columns + other_columns]
  289. # 确保数值列是数值类型
  290. numeric_columns = ['open', 'high', 'low', 'close', 'volume']
  291. for col in numeric_columns:
  292. if col in df.columns:
  293. df[col] = pd.to_numeric(df[col], errors='coerce')
  294. # 删除包含NaN的行(可选)
  295. # df = df.dropna(subset=['close'])
  296. return df
  297. def get_data_by_date_range(self,
  298. symbol: str,
  299. start_date: str,
  300. end_date: str) -> pd.DataFrame:
  301. """
  302. 获取指定日期范围的数据(主要接口)
  303. Args:
  304. symbol: 指数代码
  305. start_date: 开始日期 'YYYY-MM-DD'
  306. end_date: 结束日期 'YYYY-MM-DD'
  307. Returns:
  308. DataFrame with date index and OHLCV columns
  309. """
  310. return self.fetch_index_data_v2(symbol, start_date, end_date)
  311. def clear_cache(self):
  312. """清除缓存"""
  313. self.cache.clear()
  314. self.cache_expiry.clear()
  315. class DataManagerV2:
  316. """
  317. 数据管理类V2 - 基于优化的数据获取方式
  318. """
  319. def __init__(self, data_fetcher: DataFetcherV2):
  320. self.data_fetcher = data_fetcher
  321. self.data_cache = {}
  322. self.market_info = {}
  323. def load_data(self,
  324. symbol: str,
  325. start_date: str,
  326. end_date: str) -> pd.DataFrame:
  327. """
  328. 加载数据并缓存
  329. Args:
  330. symbol: 指数代码
  331. start_date: 开始日期
  332. end_date: 结束日期
  333. Returns:
  334. DataFrame with standardized format
  335. """
  336. cache_key = f"{symbol}_{start_date}_{end_date}"
  337. if cache_key not in self.data_cache:
  338. print(f"Loading data for {symbol} from {start_date} to {end_date}...")
  339. data = self.data_fetcher.get_data_by_date_range(symbol, start_date, end_date)
  340. if data.empty:
  341. print(f"Warning: Empty data returned for {symbol}")
  342. else:
  343. print(f"Successfully loaded {len(data)} bars for {symbol}")
  344. # 存储市场信息
  345. self.market_info[symbol] = {
  346. 'start_date': data.index[0],
  347. 'end_date': data.index[-1],
  348. 'total_bars': len(data),
  349. 'has_data': True
  350. }
  351. self.data_cache[cache_key] = data
  352. return self.data_cache[cache_key].copy()
  353. def get_market_info(self, symbol: str) -> dict:
  354. """获取市场信息"""
  355. return self.market_info.get(symbol, {})
  356. def get_available_symbols(self) -> list:
  357. """获取已加载的标的列表"""
  358. return list(self.market_info.keys())
  359. def get_current_data(self,
  360. symbol: str,
  361. current_date: str,
  362. window_size: int = 100) -> pd.DataFrame:
  363. """
  364. 获取当前日期之前的数据窗口(支持日期索引格式)
  365. """
  366. # 查找对应的数据缓存
  367. cache_key = None
  368. for key, data in self.data_cache.items():
  369. if not data.empty and key.startswith(symbol):
  370. cache_key = key
  371. break
  372. if cache_key is None:
  373. return pd.DataFrame()
  374. data = self.data_cache[cache_key]
  375. current_datetime = pd.to_datetime(current_date)
  376. # 支持两种数据格式:日期索引或date列
  377. if isinstance(data.index, pd.DatetimeIndex):
  378. # 日期索引格式
  379. historical_data = data[data.index <= current_datetime].copy()
  380. result_data = historical_data.tail(window_size).reset_index(drop=True)
  381. else:
  382. # date列格式
  383. historical_data = data[data['date'] <= current_datetime].copy()
  384. result_data = historical_data.tail(window_size).reset_index(drop=True)
  385. if len(result_data) == 0:
  386. return pd.DataFrame()
  387. return result_data
  388. def print_data_summary(self):
  389. """打印数据摘要"""
  390. print("\n" + "="*70)
  391. print("DATA MANAGER SUMMARY")
  392. print("="*70)
  393. for symbol, info in self.market_info.items():
  394. print(f"{symbol}:")
  395. print(f" Period: {info['start_date']} to {info['end_date']}")
  396. print(f" Total bars: {info['total_bars']}")
  397. print("="*70)