| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320 |
- import akshare as ak
- import pandas as pd
- import numpy as np
- from datetime import datetime, timedelta
- from typing import Optional, Union, List, Tuple
- import time
class DataFetcherV2:
    """
    Data fetcher V2: downloads daily bars via ``ak.stock_zh_index_daily``.

    Results are filtered to the requested date range, column-standardized,
    and kept in an in-memory cache for ``cache_duration`` seconds.
    """

    # Exchange identifiers accepted around a dot ("000001.SH", "sz.399673"),
    # mapped to the lowercase prefix akshare expects. Includes Yahoo's "SS"
    # and the ISO MIC codes XSHG/XSHE.
    _EXCHANGE_ALIASES = {
        'sh': 'sh', 'ss': 'sh', 'xshg': 'sh',
        'sz': 'sz', 'xshe': 'sz',
        'bj': 'bj',
    }

    # Chinese -> English column names used by various akshare endpoints.
    _COLUMN_MAPPING = {
        '开盘': 'open',
        '收盘': 'close',
        '最高': 'high',
        '最低': 'low',
        '成交量': 'volume',
        '成交额': 'amount',
        '涨跌幅': 'change_pct',
        '涨跌额': 'change',
        '振幅': 'amplitude',
        '换手率': 'turnover',
    }

    # Columns that are always present (NaN-filled if missing), numeric,
    # and placed first in the output.
    _REQUIRED_COLUMNS = ('open', 'high', 'low', 'close', 'volume')

    def __init__(self):
        # cache_key -> DataFrame
        self.cache = {}
        # cache_key -> absolute expiry time (time.time() seconds)
        self.cache_expiry = {}
        self.cache_duration = 3600  # cache lifetime in seconds (1 hour)

    def _get_cache_key(self, symbol: str, start_date: str, end_date: str) -> str:
        """Build the cache key for a (symbol, start_date, end_date) request."""
        return f"{symbol}_{start_date}_{end_date}"

    def _is_cache_valid(self, cache_key: str) -> bool:
        """
        Return True if ``cache_key`` is cached and not yet expired.

        An expired entry is removed from both dicts so stale DataFrames do
        not accumulate for the lifetime of the fetcher.
        """
        expiry = self.cache_expiry.get(cache_key)
        if expiry is None:
            return False
        if time.time() < expiry:
            return True
        self.cache.pop(cache_key, None)
        self.cache_expiry.pop(cache_key, None)
        return False

    def _set_cache(self, cache_key: str, data: pd.DataFrame):
        """Store ``data`` under ``cache_key`` with a fresh expiry time."""
        self.cache[cache_key] = data
        self.cache_expiry[cache_key] = time.time() + self.cache_duration

    def _format_index_code(self, symbol: str) -> str:
        """
        Normalize a code into akshare's ``<exchange><code>`` form.

        Examples: ``399673 -> sz399673``, ``399673.SZ -> sz399673``,
        ``000001.SH -> sh000001``, ``sh.000001 -> sh000001``,
        ``SH000001 -> sh000001``.

        An explicit exchange given around a dot (prefix or suffix) always
        wins. For a bare 6-digit code the exchange is guessed from the
        leading digits: ``60``/``68`` -> ``sh``, anything else -> ``sz``.
        That guess is ambiguous for ``000xxx`` codes (``000001`` is both the
        Shanghai Composite index and a Shenzhen stock), so pass an explicit
        exchange when it matters. Any other input is returned lowercased.
        """
        symbol = symbol.strip()

        if '.' in symbol:
            head, _, tail = symbol.partition('.')
            # "399673.SZ" has the digits first; "sz.399673" has them last.
            code, exchange = (head, tail) if head.isdigit() else (tail, head)
            prefix = self._EXCHANGE_ALIASES.get(exchange.strip().lower())
            if prefix is not None and code.isdigit():
                return f"{prefix}{code}"
            # Unknown exchange tag: fall back to guessing from the digits.
            symbol = code.strip()

        if len(symbol) == 6 and symbol.isdigit():
            if symbol.startswith(('60', '68')):  # Shanghai
                return f"sh{symbol}"
            return f"sz{symbol}"  # Shenzhen (also the default guess)

        # Already prefixed (e.g. "SH000001") or unrecognized: normalize case.
        return symbol.lower()

    def fetch_index_data_v2(self,
                            symbol: str,
                            start_date: str = "2018-01-01",
                            end_date: Optional[str] = None) -> pd.DataFrame:
        """
        Fetch daily bars for an index, clipped to ``[start_date, end_date]``.

        Args:
            symbol: Index code in any format accepted by ``_format_index_code``.
            start_date: Inclusive start date, default ``"2018-01-01"``.
            end_date: Inclusive end date; defaults to today's date.

        Returns:
            A DataFrame indexed by date (ascending) whose first columns are
            ``open, high, low, close, volume``. An empty DataFrame is returned
            when no rows are found or the fetch fails; errors are printed, not
            raised. Non-empty results are cached, and a copy is always returned.
        """
        if end_date is None:
            end_date = datetime.now().strftime('%Y-%m-%d')

        cache_key = self._get_cache_key(symbol, start_date, end_date)
        if self._is_cache_valid(cache_key):
            return self.cache[cache_key].copy()

        try:
            formatted_code = self._format_index_code(symbol)
            print(f"正在获取指数 {formatted_code} 的日线级别历史数据...")

            # The endpoint returns the full history; date filtering is local.
            all_data_df = ak.stock_zh_index_daily(symbol=formatted_code)

            if all_data_df is None or all_data_df.empty:
                print(f"Warning: No data found for index {symbol}")
                return pd.DataFrame()

            all_data_df['date'] = pd.to_datetime(all_data_df['date'])
            # Sort so index[0]/index[-1] really are the first/last dates.
            all_data_df = all_data_df.set_index('date').sort_index()

            start_datetime = pd.to_datetime(start_date)
            end_datetime = pd.to_datetime(end_date)
            in_range = ((all_data_df.index >= start_datetime)
                        & (all_data_df.index <= end_datetime))
            filtered_df = all_data_df[in_range]

            if filtered_df.empty:
                print(f"Warning: No data found for {symbol} in date range {start_date} to {end_date}")
                return pd.DataFrame()

            filtered_df = self._standardize_columns(filtered_df)

            print(f"数据获取成功,期间为 {filtered_df.index[0]} 到 {filtered_df.index[-1]}")
            print(f"获取数据量: {len(filtered_df)} 条")

            self._set_cache(cache_key, filtered_df)
            return filtered_df.copy()

        except Exception as e:
            # Deliberate best-effort boundary: callers expect an empty frame
            # rather than an exception on any network/parse failure.
            print(f"Error fetching index data for {symbol}: {str(e)}")
            return pd.DataFrame()

    def fetch_stock_data_v2(self,
                            symbol: str,
                            start_date: str = "2018-01-01",
                            end_date: Optional[str] = None) -> pd.DataFrame:
        """
        Fetch stock data.

        This is a placeholder: it prints a warning and routes the request
        through ``fetch_index_data_v2``.
        """
        if end_date is None:
            end_date = datetime.now().strftime('%Y-%m-%d')

        print(f"Warning: Stock data fetching not fully implemented, trying as index: {symbol}")
        return self.fetch_index_data_v2(symbol, start_date, end_date)

    def _standardize_columns(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Return a standardized copy of ``df``; the input is not modified.

        The following steps are applied:

        - Chinese column names are renamed to English.
        - The index is converted to datetimes when possible.
        - The required OHLCV columns are guaranteed to exist (NaN-filled
          with a warning if missing) and are coerced to numeric.
        - The required OHLCV columns are moved to the front.
        """
        # rename() returns a new frame, so later assignments do not touch
        # the caller's DataFrame.
        df = df.rename(columns=self._COLUMN_MAPPING)

        if not isinstance(df.index, pd.DatetimeIndex):
            try:
                df.index = pd.to_datetime(df.index)
            except (ValueError, TypeError, OverflowError):
                # Best effort: keep a non-date index rather than failing.
                pass

        for col in self._REQUIRED_COLUMNS:
            if col not in df.columns:
                print(f"Warning: Missing column {col}, filling with NaN")
                df[col] = np.nan
            df[col] = pd.to_numeric(df[col], errors='coerce')

        required = list(self._REQUIRED_COLUMNS)
        others = [col for col in df.columns if col not in self._REQUIRED_COLUMNS]
        return df[required + others]

    def get_data_by_date_range(self,
                               symbol: str,
                               start_date: str,
                               end_date: str) -> pd.DataFrame:
        """
        Fetch data for an explicit date range (main public entry point).

        Args:
            symbol: Index code.
            start_date: Start date, 'YYYY-MM-DD'.
            end_date: End date, 'YYYY-MM-DD'.

        Returns:
            DataFrame with a date index and OHLCV columns (see
            ``fetch_index_data_v2``).
        """
        return self.fetch_index_data_v2(symbol, start_date, end_date)

    def clear_cache(self):
        """Drop every cached entry and its expiry time."""
        self.cache.clear()
        self.cache_expiry.clear()
class DataManagerV2:
    """
    Data manager V2: loads bars through a DataFetcherV2 and keeps them in memory.

    Non-empty frames are cached per ``(symbol, start_date, end_date)`` for the
    lifetime of the manager, and a short summary of the most recently loaded
    range is kept per symbol in ``market_info``.
    """

    def __init__(self, data_fetcher: "DataFetcherV2"):
        self.data_fetcher = data_fetcher
        # "<symbol>_<start_date>_<end_date>" -> loaded DataFrame
        self.data_cache = {}
        # symbol -> summary dict of the last non-empty load for that symbol
        self.market_info = {}

    @staticmethod
    def _make_cache_key(symbol: str, start_date: str, end_date: str) -> str:
        """Build the ``data_cache`` key for a load request."""
        return f"{symbol}_{start_date}_{end_date}"

    @staticmethod
    def _symbol_of_key(cache_key: str) -> str:
        """
        Recover the exact symbol from a cache key.

        Splits off the two trailing date fields, which are assumed to contain
        no underscores (true for 'YYYY-MM-DD' strings).
        """
        return cache_key.rsplit('_', 2)[0]

    def load_data(self,
                  symbol: str,
                  start_date: str,
                  end_date: str) -> pd.DataFrame:
        """
        Load data for a symbol and date range, caching non-empty results.

        Args:
            symbol: Index code.
            start_date: Start date.
            end_date: End date.

        Returns:
            A copy of the standardized DataFrame; the copy may be empty.
            Empty results are NOT cached, because they usually come from a
            failed fetch. A later call therefore retries instead of replaying
            the failure forever.
        """
        cache_key = self._make_cache_key(symbol, start_date, end_date)

        cached = self.data_cache.get(cache_key)
        if cached is not None:
            return cached.copy()

        print(f"Loading data for {symbol} from {start_date} to {end_date}...")
        data = self.data_fetcher.get_data_by_date_range(symbol, start_date, end_date)

        if data.empty:
            print(f"Warning: Empty data returned for {symbol}")
            return data.copy()

        print(f"Successfully loaded {len(data)} bars for {symbol}")
        self.market_info[symbol] = {
            'start_date': data.index[0],
            'end_date': data.index[-1],
            'total_bars': len(data),
            'has_data': True
        }
        self.data_cache[cache_key] = data
        return data.copy()

    def get_market_info(self, symbol: str) -> dict:
        """Return the stored summary for ``symbol``, or ``{}`` if none."""
        return self.market_info.get(symbol, {})

    def get_available_symbols(self) -> list:
        """Return the symbols that have been loaded with non-empty data."""
        return list(self.market_info)

    def get_current_data(self,
                         symbol: str,
                         current_date: str,
                         window_size: int = 100) -> pd.DataFrame:
        """
        Return up to ``window_size`` bars dated on or before ``current_date``.

        The data comes from the first cached, non-empty range whose symbol
        matches ``symbol`` exactly. Matching is exact: a symbol that is merely
        a prefix of a loaded symbol does not match.

        Cached frames may carry either a DatetimeIndex or a ``date`` column.
        The result always has a fresh 0..n-1 RangeIndex; with a DatetimeIndex
        source, the dates are dropped.

        An empty DataFrame is returned when there is no matching data, no bar
        on or before the date, or ``window_size`` is not positive.
        """
        if window_size <= 0:
            return pd.DataFrame()

        data = next(
            (frame for key, frame in self.data_cache.items()
             if not frame.empty and self._symbol_of_key(key) == symbol),
            None,
        )
        if data is None:
            return pd.DataFrame()

        current_datetime = pd.to_datetime(current_date)
        if isinstance(data.index, pd.DatetimeIndex):
            up_to_date = data.index <= current_datetime
        else:
            up_to_date = data['date'] <= current_datetime

        # Boolean indexing copies, so the cached frame is never exposed.
        result_data = data[up_to_date].tail(window_size).reset_index(drop=True)
        if result_data.empty:
            return pd.DataFrame()
        return result_data

    def print_data_summary(self):
        """Print the period and bar count of every loaded symbol."""
        print("\n" + "="*70)
        print("DATA MANAGER SUMMARY")
        print("="*70)
        for symbol, info in self.market_info.items():
            print(f"{symbol}:")
            print(f" Period: {info['start_date']} to {info['end_date']}")
            print(f" Total bars: {info['total_bars']}")
        print("="*70)
|