| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276 |
- """
- 多品种数据加载器
- 支持沪深300、中证500、创业板50三个指数的数据获取与管理
- """
- from typing import Dict, List, Optional, Tuple
- from datetime import datetime, timedelta
- import pandas as pd
- import numpy as np
class MultiAssetDataLoader:
    """Load, cache and align daily data for several broad-based indices.

    Three indices are configured:

    - CSI 300 (000300.SH)
    - CSI 500 (000905.SH)
    - ChiNext 50 (399006.SZ)

    Each index is read from ``<data_dir>/<symbol_key>_daily.csv``. The CSV
    must contain a ``date`` column; the alignment and pricing helpers also
    read the ``open``, ``high``, ``low``, ``close`` and ``volume`` columns.
    """

    # Symbol key -> exchange ticker.
    # NOTE(review): 399006.SZ is commonly quoted as the ChiNext Index rather
    # than ChiNext 50 -- confirm which ticker is actually intended.
    SYMBOLS = {
        "csi300": "000300.SH",  # CSI 300
        "csi500": "000905.SH",  # CSI 500
        "chinext50": "399006.SZ",  # ChiNext 50
    }

    # Symbol key -> display name (runtime strings, kept as-is).
    NAMES = {
        "csi300": "沪深300",
        "csi500": "中证500",
        "chinext50": "创业板50",
    }

    # Columns copied into the wide frame by ``align_data``, in output order.
    _PRICE_FIELDS = ("close", "volume", "high", "low", "open")

    def __init__(self, data_dir: str = "quant"):
        """Create a loader that reads CSV files from ``data_dir``."""
        self.data_dir = data_dir
        # Keyed by "<symbol_key>_<start_date>_<end_date>".
        self._cache: Dict[str, pd.DataFrame] = {}
        self._correlation_cache: Optional[pd.DataFrame] = None

    def load_all_data(
        self,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None
    ) -> Dict[str, pd.DataFrame]:
        """Load every configured symbol.

        Args:
            start_date: Inclusive start date (YYYY-MM-DD), or None.
            end_date: Inclusive end date (YYYY-MM-DD), or None.

        Returns:
            Mapping of symbol key to data frame. Symbols whose data could
            not be loaded are omitted.
        """
        result = {}
        for symbol_key in self.SYMBOLS:
            df = self.load_symbol(symbol_key, start_date, end_date)
            if df is not None:
                result[symbol_key] = df
        return result

    def load_symbol(
        self,
        symbol_key: str,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None
    ) -> Optional[pd.DataFrame]:
        """Load one symbol from ``<data_dir>/<symbol_key>_daily.csv``.

        Args:
            symbol_key: Symbol key (csi300/csi500/chinext50).
            start_date: Inclusive start date, or None for no lower bound.
            end_date: Inclusive end date, or None for no upper bound.

        Returns:
            A frame indexed by date in ascending order, or None when the
            file is missing or cannot be parsed (a message is printed).
        """
        cache_key = f"{symbol_key}_{start_date}_{end_date}"
        cached = self._cache.get(cache_key)
        if cached is not None:
            # Hand out a copy so callers cannot mutate the cached frame.
            return cached.copy()

        file_path = f"{self.data_dir}/{symbol_key}_daily.csv"
        try:
            df = pd.read_csv(file_path)
            df['date'] = pd.to_datetime(df['date'])
            df = df.set_index('date').sort_index()

            if start_date:
                df = df[df.index >= start_date]
            if end_date:
                df = df[df.index <= end_date]
        except FileNotFoundError:
            print(f"Warning: Data file not found for {symbol_key} at {file_path}")
            return None
        except Exception as e:
            # Deliberate best-effort: any parse/filter failure is reported
            # and treated as "no data" rather than propagated.
            print(f"Error loading {symbol_key}: {e}")
            return None

        self._cache[cache_key] = df.copy()
        return df

    def align_data(
        self,
        data_dict: Dict[str, pd.DataFrame],
        fill_method: str = "ffill"
    ) -> pd.DataFrame:
        """Align several symbols onto one shared date axis (wide format).

        Args:
            data_dict: Mapping of symbol key to data frame.
            fill_method: How to treat gaps: "ffill", "bfill" or "drop".
                Any other value leaves missing values untouched.

        Returns:
            A frame indexed by the sorted union of all dates, with columns
            named ``<symbol_key>_<field>`` for close, volume, high, low and
            open. An empty frame is returned for an empty ``data_dict``.
        """
        if not data_dict:
            return pd.DataFrame()

        all_dates = set()
        for df in data_dict.values():
            all_dates.update(df.index)

        aligned = pd.DataFrame(index=sorted(all_dates))
        for symbol_key, df in data_dict.items():
            for field in self._PRICE_FIELDS:
                # Series assignment aligns on the index; dates missing from
                # this symbol become NaN.
                aligned[f"{symbol_key}_{field}"] = df[field]

        # DataFrame.ffill()/bfill() replace the deprecated
        # fillna(method=...) spelling.
        if fill_method == "ffill":
            aligned = aligned.ffill()
        elif fill_method == "bfill":
            aligned = aligned.bfill()
        elif fill_method == "drop":
            aligned = aligned.dropna()

        return aligned

    @staticmethod
    def _identity_correlation(labels) -> pd.DataFrame:
        """Return an identity matrix labelled by ``labels`` on both axes."""
        labels = list(labels)
        return pd.DataFrame(np.eye(len(labels)), index=labels, columns=labels)

    def calculate_correlation(
        self,
        data_dict: Dict[str, pd.DataFrame],
        lookback: int = 60,
        current_date: Optional[datetime] = None
    ) -> pd.DataFrame:
        """Compute the correlation of recent close-to-close returns.

        Args:
            data_dict: Mapping of symbol key to data frame.
            lookback: Number of most recent return rows to use.
            current_date: If given, rows after this date are ignored
                (for historical backtests).

        Returns:
            Correlation matrix labelled by the keys of ``data_dict``. When
            fewer than ``lookback`` return rows are available, an identity
            matrix with the same labels is returned; for an empty
            ``data_dict`` the identity over all configured symbols is used.
        """
        if not data_dict:
            return self._identity_correlation(self.SYMBOLS)

        # Column-by-column assignment keeps the first symbol's dates as the
        # reference axis; later symbols are aligned onto it.
        close_prices = pd.DataFrame()
        for symbol_key, df in data_dict.items():
            close_prices[symbol_key] = df['close']

        if current_date is not None:
            close_prices = close_prices[close_prices.index <= current_date]

        # Forward-fill explicitly instead of relying on pct_change()'s
        # deprecated implicit padding, so results do not depend on the
        # installed pandas version.
        returns = close_prices.ffill().pct_change(fill_method=None).dropna()

        if len(returns) >= lookback:
            return returns.iloc[-lookback:].corr()

        # Not enough history: fall back to "uncorrelated", using the same
        # labels the regular path would have produced.
        return self._identity_correlation(close_prices.columns)

    def get_symbol_data_at_date(
        self,
        data_dict: Dict[str, pd.DataFrame],
        symbol_key: str,
        date: datetime,
        lookback: int = 60
    ) -> Optional[pd.DataFrame]:
        """Return the history window of one symbol ending at ``date``.

        Args:
            data_dict: Mapping of symbol key to data frame.
            symbol_key: Symbol key to look up.
            date: Target date; must be present in the symbol's index.
            lookback: Number of rows before ``date`` to include.

        Returns:
            ``lookback + 1`` rows ending at ``date`` (inclusive), or None if
            the symbol or the date is unknown or fewer than ``lookback``
            rows precede the date.
        """
        df = data_dict.get(symbol_key)
        if df is None:
            return None

        try:
            # Assumes dates in the index are unique, so that get_loc yields
            # a single integer position.
            idx = df.index.get_loc(date)
        except KeyError:
            return None

        if idx < lookback:
            return None
        return df.iloc[idx - lookback:idx + 1]

    def get_current_prices(
        self,
        data_dict: Dict[str, pd.DataFrame],
        date: Optional[datetime] = None
    ) -> Dict[str, float]:
        """Return the close price of every symbol as of ``date``.

        Args:
            data_dict: Mapping of symbol key to data frame.
            date: Target date; None means the last row of each frame.

        Returns:
            Mapping of symbol key to close price. If ``date`` is missing
            for a symbol, the last row at or before ``date`` is used.
            Symbols with no usable row (including empty frames) are omitted.
        """
        prices = {}
        for symbol_key, df in data_dict.items():
            if df.empty:
                # Nothing to report; avoids IndexError on iloc[-1] and
                # matches how the dated branch treats missing data.
                continue

            if date is None:
                prices[symbol_key] = df['close'].iloc[-1]
                continue

            try:
                prices[symbol_key] = df.loc[date, 'close']
            except KeyError:
                # Fall back to the last row at or before the target date.
                valid_dates = df.index[df.index <= date]
                if len(valid_dates) > 0:
                    prices[symbol_key] = df.loc[valid_dates[-1], 'close']
        return prices

    def clear_cache(self):
        """Drop all cached frames and the cached correlation matrix."""
        self._cache.clear()
        self._correlation_cache = None
|