data_loader.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
"""
Multi-asset data loader.

Supports loading and managing data for three indices: CSI 300, CSI 500 and
ChiNext 50.
"""
  5. from typing import Dict, List, Optional, Tuple
  6. from datetime import datetime, timedelta
  7. import pandas as pd
  8. import numpy as np
  9. class MultiAssetDataLoader:
  10. """
  11. 多品种数据加载器
  12. 支持三个宽基指数:
  13. - 沪深300 (000300.SH)
  14. - 中证500 (000905.SH)
  15. - 创业板50 (399006.SZ)
  16. """
  17. # 品种代码映射
  18. SYMBOLS = {
  19. "csi300": "000300.SH", # 沪深300
  20. "csi500": "000905.SH", # 中证500
  21. "chinext50": "399006.SZ", # 创业板50
  22. }
  23. # 品种名称映射
  24. NAMES = {
  25. "csi300": "沪深300",
  26. "csi500": "中证500",
  27. "chinext50": "创业板50",
  28. }
  29. def __init__(self, data_dir: str = "quant"):
  30. self.data_dir = data_dir
  31. self._cache: Dict[str, pd.DataFrame] = {}
  32. self._correlation_cache: Optional[pd.DataFrame] = None
  33. def load_all_data(
  34. self,
  35. start_date: Optional[str] = None,
  36. end_date: Optional[str] = None
  37. ) -> Dict[str, pd.DataFrame]:
  38. """
  39. 加载所有三个品种的数据
  40. Args:
  41. start_date: 开始日期 (YYYY-MM-DD)
  42. end_date: 结束日期 (YYYY-MM-DD)
  43. Returns:
  44. Dict[str, pd.DataFrame]: 品种代码 -> 数据框的映射
  45. """
  46. result = {}
  47. for symbol_key in self.SYMBOLS.keys():
  48. df = self.load_symbol(symbol_key, start_date, end_date)
  49. if df is not None:
  50. result[symbol_key] = df
  51. return result
  52. def load_symbol(
  53. self,
  54. symbol_key: str,
  55. start_date: Optional[str] = None,
  56. end_date: Optional[str] = None
  57. ) -> Optional[pd.DataFrame]:
  58. """
  59. 加载单个品种的数据
  60. Args:
  61. symbol_key: 品种代码键 (csi300/csi500/chinext50)
  62. start_date: 开始日期
  63. end_date: 结束日期
  64. Returns:
  65. pd.DataFrame or None: 加载的数据
  66. """
  67. # 检查缓存
  68. cache_key = f"{symbol_key}_{start_date}_{end_date}"
  69. if cache_key in self._cache:
  70. return self._cache[cache_key].copy()
  71. # 构建文件路径
  72. file_path = f"{self.data_dir}/{symbol_key}_daily.csv"
  73. try:
  74. df = pd.read_csv(file_path)
  75. df['date'] = pd.to_datetime(df['date'])
  76. df.set_index('date', inplace=True)
  77. df.sort_index(inplace=True)
  78. # 日期过滤
  79. if start_date:
  80. df = df[df.index >= start_date]
  81. if end_date:
  82. df = df[df.index <= end_date]
  83. # 缓存
  84. self._cache[cache_key] = df.copy()
  85. return df
  86. except FileNotFoundError:
  87. print(f"Warning: Data file not found for {symbol_key} at {file_path}")
  88. return None
  89. except Exception as e:
  90. print(f"Error loading {symbol_key}: {e}")
  91. return None
  92. def align_data(
  93. self,
  94. data_dict: Dict[str, pd.DataFrame],
  95. fill_method: str = "ffill"
  96. ) -> pd.DataFrame:
  97. """
  98. 对齐多个品种的数据到统一时间轴
  99. Args:
  100. data_dict: 品种数据字典
  101. fill_method: 缺失值填充方法 (ffill/bfill/drop)
  102. Returns:
  103. pd.DataFrame: 对齐后的多品种数据(宽格式)
  104. """
  105. if not data_dict:
  106. return pd.DataFrame()
  107. # 获取所有日期
  108. all_dates = set()
  109. for df in data_dict.values():
  110. all_dates.update(df.index)
  111. all_dates = sorted(all_dates)
  112. # 创建统一时间轴
  113. aligned_data = pd.DataFrame(index=all_dates)
  114. # 为每个品种添加数据
  115. for symbol_key, df in data_dict.items():
  116. # 收盘价
  117. aligned_data[f"{symbol_key}_close"] = df['close']
  118. # 成交量
  119. aligned_data[f"{symbol_key}_volume"] = df['volume']
  120. # 最高价
  121. aligned_data[f"{symbol_key}_high"] = df['high']
  122. # 最低价
  123. aligned_data[f"{symbol_key}_low"] = df['low']
  124. # 开盘价
  125. aligned_data[f"{symbol_key}_open"] = df['open']
  126. # 处理缺失值
  127. if fill_method == "ffill":
  128. aligned_data.fillna(method="ffill", inplace=True)
  129. elif fill_method == "bfill":
  130. aligned_data.fillna(method="bfill", inplace=True)
  131. elif fill_method == "drop":
  132. aligned_data.dropna(inplace=True)
  133. return aligned_data
  134. def calculate_correlation(
  135. self,
  136. data_dict: Dict[str, pd.DataFrame],
  137. lookback: int = 60,
  138. current_date: Optional[datetime] = None
  139. ) -> pd.DataFrame:
  140. """
  141. 计算品种间滚动相关系数
  142. Args:
  143. data_dict: 品种数据字典
  144. lookback: 滚动窗口天数
  145. current_date: 当前日期(用于历史回测)
  146. Returns:
  147. pd.DataFrame: 相关系数矩阵
  148. """
  149. # 提取收盘价
  150. close_prices = pd.DataFrame()
  151. for symbol_key, df in data_dict.items():
  152. close_prices[symbol_key] = df['close']
  153. # 日期过滤
  154. if current_date is not None:
  155. close_prices = close_prices[close_prices.index <= current_date]
  156. # 计算收益率
  157. returns = close_prices.pct_change().dropna()
  158. # 滚动相关系数(取最近lookback天)
  159. if len(returns) >= lookback:
  160. recent_returns = returns.iloc[-lookback:]
  161. corr_matrix = recent_returns.corr()
  162. else:
  163. # 数据不足时返回单位矩阵
  164. corr_matrix = pd.DataFrame(
  165. np.eye(len(self.SYMBOLS)),
  166. index=list(self.SYMBOLS.keys()),
  167. columns=list(self.SYMBOLS.keys())
  168. )
  169. return corr_matrix
  170. def get_symbol_data_at_date(
  171. self,
  172. data_dict: Dict[str, pd.DataFrame],
  173. symbol_key: str,
  174. date: datetime,
  175. lookback: int = 60
  176. ) -> Optional[pd.DataFrame]:
  177. """
  178. 获取指定品种在指定日期的历史数据窗口
  179. Args:
  180. data_dict: 品种数据字典
  181. symbol_key: 品种代码键
  182. date: 目标日期
  183. lookback: 回溯天数
  184. Returns:
  185. pd.DataFrame or None: 历史数据窗口
  186. """
  187. if symbol_key not in data_dict:
  188. return None
  189. df = data_dict[symbol_key]
  190. # 找到目标日期的位置
  191. try:
  192. idx = df.index.get_loc(date)
  193. if idx < lookback:
  194. return None
  195. # 返回历史窗口
  196. return df.iloc[idx - lookback:idx + 1]
  197. except KeyError:
  198. return None
  199. def get_current_prices(
  200. self,
  201. data_dict: Dict[str, pd.DataFrame],
  202. date: Optional[datetime] = None
  203. ) -> Dict[str, float]:
  204. """
  205. 获取所有品种的当前价格
  206. Args:
  207. data_dict: 品种数据字典
  208. date: 目标日期(None表示最新)
  209. Returns:
  210. Dict[str, float]: 品种代码 -> 当前价格
  211. """
  212. prices = {}
  213. for symbol_key, df in data_dict.items():
  214. if date is None:
  215. prices[symbol_key] = df['close'].iloc[-1]
  216. else:
  217. try:
  218. prices[symbol_key] = df.loc[date, 'close']
  219. except KeyError:
  220. # 找最近的有效日期
  221. valid_dates = df.index[df.index <= date]
  222. if len(valid_dates) > 0:
  223. prices[symbol_key] = df.loc[valid_dates[-1], 'close']
  224. return prices
  225. def clear_cache(self):
  226. """清除数据缓存"""
  227. self._cache.clear()
  228. self._correlation_cache = None