data_fetcher_v2.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. import akshare as ak
  2. import pandas as pd
  3. import numpy as np
  4. from datetime import datetime, timedelta
  5. from typing import Optional, Union, List, Tuple
  6. import time
  7. class DataFetcherV2:
  8. """
  9. 数据获取类V2 - 基于用户提供的优化方法
  10. 使用ak.stock_zh_index_daily获取更可靠的数据
  11. """
  12. def __init__(self):
  13. self.cache = {}
  14. self.cache_expiry = {}
  15. self.cache_duration = 3600 # 缓存1小时
  16. def _get_cache_key(self, symbol: str, start_date: str, end_date: str) -> str:
  17. """生成缓存键"""
  18. return f"{symbol}_{start_date}_{end_date}"
  19. def _is_cache_valid(self, cache_key: str) -> bool:
  20. """检查缓存是否有效"""
  21. if cache_key not in self.cache_expiry:
  22. return False
  23. return time.time() < self.cache_expiry[cache_key]
  24. def _set_cache(self, cache_key: str, data: pd.DataFrame):
  25. """设置缓存"""
  26. self.cache[cache_key] = data
  27. self.cache_expiry[cache_key] = time.time() + self.cache_duration
  28. def _format_index_code(self, symbol: str) -> str:
  29. """
  30. 格式化指数代码为akshare标准格式
  31. 例如: 399673 -> sz399673
  32. """
  33. symbol = symbol.strip()
  34. # 移除可能的前缀
  35. if '.' in symbol:
  36. code, exchange = symbol.split('.')
  37. symbol = code
  38. # 确保是6位代码
  39. if len(symbol) == 6:
  40. # 判断交易所并添加前缀
  41. if symbol.startswith(('00', '30')): # 深交所
  42. return f"sz{symbol}"
  43. elif symbol.startswith(('60', '68')): # 上交所
  44. return f"sh{symbol}"
  45. else:
  46. # 其他交易所默认使用sz
  47. return f"sz{symbol}"
  48. # 如果已经是格式化好的代码
  49. return symbol.lower()
  50. def fetch_index_data_v2(self,
  51. symbol: str,
  52. start_date: str = "2018-01-01",
  53. end_date: Optional[str] = None) -> pd.DataFrame:
  54. """
  55. 使用优化的方法获取指数数据
  56. Args:
  57. symbol: 指数代码,支持多种格式
  58. start_date: 开始日期,默认2018-01-01
  59. end_date: 结束日期,默认为当前日期
  60. Returns:
  61. 包含OHLCV数据的DataFrame,索引为日期
  62. """
  63. if end_date is None:
  64. end_date = datetime.now().strftime('%Y-%m-%d')
  65. cache_key = self._get_cache_key(symbol, start_date, end_date)
  66. if self._is_cache_valid(cache_key):
  67. return self.cache[cache_key].copy()
  68. try:
  69. # 格式化指数代码
  70. formatted_code = self._format_index_code(symbol)
  71. print(f"正在获取指数 {formatted_code} 的日线级别历史数据...")
  72. # 使用akshare获取日线数据(用户提供的优化方法)
  73. all_data_df = ak.stock_zh_index_daily(symbol=formatted_code)
  74. if all_data_df.empty:
  75. print(f"Warning: No data found for index {symbol}")
  76. return pd.DataFrame()
  77. # 处理日期列
  78. all_data_df['date'] = pd.to_datetime(all_data_df['date'])
  79. all_data_df.set_index('date', inplace=True)
  80. # 筛选日期范围
  81. start_datetime = pd.to_datetime(start_date)
  82. end_datetime = pd.to_datetime(end_date)
  83. # 先筛选出指定日期之后的数据
  84. filtered_df = all_data_df[all_data_df.index >= start_datetime]
  85. filtered_df = filtered_df[filtered_df.index <= end_datetime]
  86. if filtered_df.empty:
  87. print(f"Warning: No data found for {symbol} in date range {start_date} to {end_date}")
  88. return pd.DataFrame()
  89. # 标准化列名
  90. filtered_df = self._standardize_columns(filtered_df)
  91. print(f"数据获取成功,期间为 {filtered_df.index[0]} 到 {filtered_df.index[-1]}")
  92. print(f"获取数据量: {len(filtered_df)} 条")
  93. # 缓存数据
  94. self._set_cache(cache_key, filtered_df)
  95. return filtered_df.copy()
  96. except Exception as e:
  97. print(f"Error fetching index data for {symbol}: {str(e)}")
  98. return pd.DataFrame()
  99. def fetch_stock_data_v2(self,
  100. symbol: str,
  101. start_date: str = "2018-01-01",
  102. end_date: Optional[str] = None) -> pd.DataFrame:
  103. """
  104. 获取股票数据(如果akshare支持的话)
  105. """
  106. if end_date is None:
  107. end_date = datetime.now().strftime('%Y-%m-%d')
  108. # 目前主要支持指数数据,股票数据作为占位符
  109. print(f"Warning: Stock data fetching not fully implemented, trying as index: {symbol}")
  110. return self.fetch_index_data_v2(symbol, start_date, end_date)
  111. def _standardize_columns(self, df: pd.DataFrame) -> pd.DataFrame:
  112. """
  113. 标准化DataFrame列名和索引
  114. """
  115. # 确保索引是日期类型
  116. if not isinstance(df.index, pd.DatetimeIndex):
  117. try:
  118. df.index = pd.to_datetime(df.index)
  119. except:
  120. pass
  121. # 创建列名映射(中英文对照)
  122. column_mapping = {
  123. '开盘': 'open',
  124. '收盘': 'close',
  125. '最高': 'high',
  126. '最低': 'low',
  127. '成交量': 'volume',
  128. '成交额': 'amount',
  129. '涨跌幅': 'change_pct',
  130. '涨跌额': 'change',
  131. '振幅': 'amplitude',
  132. '换手率': 'turnover'
  133. }
  134. # 重命名列
  135. df = df.rename(columns=column_mapping)
  136. # 确保必要的列存在
  137. required_columns = ['open', 'high', 'low', 'close', 'volume']
  138. for col in required_columns:
  139. if col not in df.columns:
  140. # 尝试从中文列名获取
  141. chinese_map = {
  142. 'open': '开盘', 'high': '最高', 'low': '最低',
  143. 'close': '收盘', 'volume': '成交量'
  144. }
  145. if chinese_map[col] in df.columns:
  146. df[col] = df[chinese_map[col]]
  147. else:
  148. print(f"Warning: Missing column {col}, filling with NaN")
  149. df[col] = np.nan
  150. # 选择并排序列
  151. result_columns = [col for col in required_columns if col in df.columns]
  152. other_columns = [col for col in df.columns if col not in required_columns]
  153. df = df[result_columns + other_columns]
  154. # 确保数值列是数值类型
  155. numeric_columns = ['open', 'high', 'low', 'close', 'volume']
  156. for col in numeric_columns:
  157. if col in df.columns:
  158. df[col] = pd.to_numeric(df[col], errors='coerce')
  159. # 删除包含NaN的行(可选)
  160. # df = df.dropna(subset=['close'])
  161. return df
  162. def get_data_by_date_range(self,
  163. symbol: str,
  164. start_date: str,
  165. end_date: str) -> pd.DataFrame:
  166. """
  167. 获取指定日期范围的数据(主要接口)
  168. Args:
  169. symbol: 指数代码
  170. start_date: 开始日期 'YYYY-MM-DD'
  171. end_date: 结束日期 'YYYY-MM-DD'
  172. Returns:
  173. DataFrame with date index and OHLCV columns
  174. """
  175. return self.fetch_index_data_v2(symbol, start_date, end_date)
  176. def clear_cache(self):
  177. """清除缓存"""
  178. self.cache.clear()
  179. self.cache_expiry.clear()
  180. class DataManagerV2:
  181. """
  182. 数据管理类V2 - 基于优化的数据获取方式
  183. """
  184. def __init__(self, data_fetcher: DataFetcherV2):
  185. self.data_fetcher = data_fetcher
  186. self.data_cache = {}
  187. self.market_info = {}
  188. def load_data(self,
  189. symbol: str,
  190. start_date: str,
  191. end_date: str) -> pd.DataFrame:
  192. """
  193. 加载数据并缓存
  194. Args:
  195. symbol: 指数代码
  196. start_date: 开始日期
  197. end_date: 结束日期
  198. Returns:
  199. DataFrame with standardized format
  200. """
  201. cache_key = f"{symbol}_{start_date}_{end_date}"
  202. if cache_key not in self.data_cache:
  203. print(f"Loading data for {symbol} from {start_date} to {end_date}...")
  204. data = self.data_fetcher.get_data_by_date_range(symbol, start_date, end_date)
  205. if data.empty:
  206. print(f"Warning: Empty data returned for {symbol}")
  207. else:
  208. print(f"Successfully loaded {len(data)} bars for {symbol}")
  209. # 存储市场信息
  210. self.market_info[symbol] = {
  211. 'start_date': data.index[0],
  212. 'end_date': data.index[-1],
  213. 'total_bars': len(data),
  214. 'has_data': True
  215. }
  216. self.data_cache[cache_key] = data
  217. return self.data_cache[cache_key].copy()
  218. def get_market_info(self, symbol: str) -> dict:
  219. """获取市场信息"""
  220. return self.market_info.get(symbol, {})
  221. def get_available_symbols(self) -> list:
  222. """获取已加载的标的列表"""
  223. return list(self.market_info.keys())
  224. def get_current_data(self,
  225. symbol: str,
  226. current_date: str,
  227. window_size: int = 100) -> pd.DataFrame:
  228. """
  229. 获取当前日期之前的数据窗口(支持日期索引格式)
  230. """
  231. # 查找对应的数据缓存
  232. cache_key = None
  233. for key, data in self.data_cache.items():
  234. if not data.empty and key.startswith(symbol):
  235. cache_key = key
  236. break
  237. if cache_key is None:
  238. return pd.DataFrame()
  239. data = self.data_cache[cache_key]
  240. current_datetime = pd.to_datetime(current_date)
  241. # 支持两种数据格式:日期索引或date列
  242. if isinstance(data.index, pd.DatetimeIndex):
  243. # 日期索引格式
  244. historical_data = data[data.index <= current_datetime].copy()
  245. result_data = historical_data.tail(window_size).reset_index(drop=True)
  246. else:
  247. # date列格式
  248. historical_data = data[data['date'] <= current_datetime].copy()
  249. result_data = historical_data.tail(window_size).reset_index(drop=True)
  250. if len(result_data) == 0:
  251. return pd.DataFrame()
  252. return result_data
  253. def print_data_summary(self):
  254. """打印数据摘要"""
  255. print("\n" + "="*70)
  256. print("DATA MANAGER SUMMARY")
  257. print("="*70)
  258. for symbol, info in self.market_info.items():
  259. print(f"{symbol}:")
  260. print(f" Period: {info['start_date']} to {info['end_date']}")
  261. print(f" Total bars: {info['total_bars']}")
  262. print("="*70)