"""Daily index data access built on akshare, with an optional intraday snapshot.

``DataFetcherV2`` downloads daily bars through ``ak.stock_zh_index_daily`` and
keeps a time-limited in-memory cache; ``DataManagerV2`` keeps loaded frames
per (symbol, date range) and serves trailing windows from them.
"""

import time
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional, Union, List, Tuple

import numpy as np
import pandas as pd

try:
    import akshare as ak
except ImportError:  # pragma: no cover
    # Guarded like ``requests`` below so the offline helpers stay importable;
    # fetch_index_data_v2 reports the missing package instead of crashing.
    ak = None

try:
    import requests
except ImportError:  # pragma: no cover
    requests = None


@dataclass
class LatestSnapshot:
    """One realtime quote, shaped like a single OHLCV bar."""

    timestamp: datetime
    open: float
    high: float
    low: float
    close: float
    volume: float = 0.0


class DataFetcherV2:
    """Fetch daily index bars via ``ak.stock_zh_index_daily`` with caching."""

    # Exchange suffixes accepted after a dot (e.g. "000300.SH"), mapped to the
    # two-letter prefix used in the formatted code.
    _EXCHANGE_PREFIXES = {
        "sh": "sh",
        "ss": "sh",
        "xshg": "sh",
        "sz": "sz",
        "xshe": "sz",
    }

    # Chinese column names -> standardized English names.
    _COLUMN_MAPPING = {
        '开盘': 'open',
        '收盘': 'close',
        '最高': 'high',
        '最低': 'low',
        '成交量': 'volume',
        '成交额': 'amount',
        '涨跌幅': 'change_pct',
        '涨跌额': 'change',
        '振幅': 'amplitude',
        '换手率': 'turnover',
    }

    _REQUIRED_COLUMNS = ['open', 'high', 'low', 'close', 'volume']

    def __init__(self, cache_duration: int = 3600):
        """Create an empty cache.

        Args:
            cache_duration: Lifetime of a cache entry in seconds (default one
                hour).
        """
        self.cache = {}
        self.cache_expiry = {}
        self.cache_duration = cache_duration

    def _get_cache_key(self, symbol: str, start_date: str, end_date: str) -> str:
        """Build the cache key for a (symbol, start, end) request."""
        return f"{symbol}_{start_date}_{end_date}"

    def _is_cache_valid(self, cache_key: str) -> bool:
        """Return True when ``cache_key`` has an expiry that is still in the future."""
        expires_at = self.cache_expiry.get(cache_key)
        return expires_at is not None and time.time() < expires_at

    def _set_cache(self, cache_key: str, data: pd.DataFrame):
        """Store ``data`` under ``cache_key`` and restart its expiry timer."""
        self.cache[cache_key] = data
        self.cache_expiry[cache_key] = time.time() + self.cache_duration

    def _should_force_refresh_t_day(self, end_date: str) -> bool:
        """Return True when the requested window reaches today or later.

        Such requests bypass the cache so that a stale copy of the current
        day's data is not served. An unparseable ``end_date`` yields False.
        """
        try:
            return pd.to_datetime(end_date).date() >= datetime.now().date()
        except Exception:
            # Deliberate best effort: an unparseable date just means "no forced refresh".
            return False

    def _writeback_cache_if_exists(self, cache_key: str, data: pd.DataFrame) -> None:
        """Overwrite an existing cache entry with ``data``; do nothing if the key is absent."""
        if cache_key not in self.cache:
            return
        payload = data.copy()
        payload.attrs.update(data.attrs)
        self._set_cache(cache_key, payload)

    def _format_index_code(self, symbol: str) -> str:
        """Normalize an index code to the prefixed form, e.g. ``399673 -> sz399673``.

        Accepted inputs:
            * a bare 6-character code ("399673"),
            * a code with an exchange suffix ("000300.SH", "399001.XSHE"),
            * an already formatted code ("SH000300"), which is lower-cased.

        A recognized exchange suffix decides the prefix. Without one, codes
        starting with 60/68 get "sh" and everything else gets "sz".

        NOTE(review): bare "00xxxx" codes map to "sz" here, while
        ``_infer_realtime_prefix`` maps "000xxx" to "sh". The bare-code
        behavior is kept as-is for backward compatibility; pass an explicit
        suffix to choose the exchange.
        """
        symbol = symbol.strip()

        explicit_prefix = None
        if '.' in symbol:
            # partition() tolerates more than one dot, unlike a two-name unpack of split().
            code, _, exchange = symbol.partition('.')
            explicit_prefix = self._EXCHANGE_PREFIXES.get(exchange.strip().lower())
            symbol = code.strip()

        if len(symbol) == 6:
            if explicit_prefix is not None:
                return f"{explicit_prefix}{symbol}"
            if symbol.startswith(('60', '68')):
                return f"sh{symbol}"
            # '00' / '30' and every other leading pair default to "sz".
            return f"sz{symbol}"

        # Anything else is treated as already formatted.
        return symbol.lower()

    def _infer_realtime_prefix(self, code: str) -> str:
        """Guess the exchange prefix ("sz" or "sh") for a code that carries none."""
        if code.startswith("399"):
            return "sz"
        if code.startswith("000"):
            return "sh"
        if code.startswith(("30", "00", "15")):
            return "sz"
        if code.startswith(("60", "68")):
            return "sh"
        return "sz"

    @staticmethod
    def _parse_sina_quote(text: str) -> Optional[LatestSnapshot]:
        """Parse the body of an ``hq.sinajs.cn`` response into a snapshot.

        The first double-quoted section is split on commas; fields 1-5 are
        read as open, previous close, current price, high and low.

        Returns:
            A ``LatestSnapshot`` stamped with the current local time and
            volume 0.0, or None when the payload is missing, too short,
            non-numeric, or contains a non-positive price.
        """
        if '"' not in text:
            return None
        try:
            payload = text.split('"')[1].split(",")
            if len(payload) < 6:
                return None
            open_price, prev_close, close_price, high_price, low_price = (
                float(field) for field in payload[1:6]
            )
        except (ValueError, IndexError):
            return None

        # A zero in any field means there is no usable quote; a zero open
        # would otherwise become a corrupt bar.
        prices = (open_price, prev_close, close_price, high_price, low_price)
        if any(price <= 0 for price in prices):
            return None

        return LatestSnapshot(
            timestamp=datetime.now(),
            open=open_price,
            high=high_price,
            low=low_price,
            close=close_price,
            volume=0.0,
        )

    def fetch_latest_snapshot(self, symbol: str) -> Optional[LatestSnapshot]:
        """Fetch the latest quote for ``symbol`` from ``hq.sinajs.cn``.

        Returns:
            A ``LatestSnapshot``, or None when ``requests`` is unavailable,
            the HTTP request fails, or the response cannot be parsed.
        """
        if requests is None:
            return None

        formatted_symbol = self._format_index_code(symbol)
        if formatted_symbol.startswith(("sz", "sh")):
            prefix, code = formatted_symbol[:2], formatted_symbol[2:]
        else:
            code = formatted_symbol
            prefix = self._infer_realtime_prefix(code)

        url = f"http://hq.sinajs.cn/list={prefix}{code}"
        headers = {
            "User-Agent": "Mozilla/5.0",
            "Referer": "http://finance.sina.com.cn",
        }
        try:
            response = requests.get(url, headers=headers, timeout=10)
            response.raise_for_status()
        except Exception:
            # Best effort: any transport or HTTP error means "no snapshot".
            return None

        response.encoding = "gbk"
        return self._parse_sina_quote(response.text)

    def fetch_index_data_v2(self, symbol: str, start_date: str = "2018-01-01",
                            end_date: Optional[str] = None) -> pd.DataFrame:
        """Fetch daily bars for an index and restrict them to a date range.

        Args:
            symbol: Index code in any format ``_format_index_code`` accepts.
            start_date: Inclusive start date, default "2018-01-01".
            end_date: Inclusive end date; defaults to today.

        Returns:
            A DataFrame indexed by date with standardized OHLCV columns and
            the attrs ``intraday_snapshot_appended``,
            ``intraday_snapshot_timestamp`` and ``historical_latest_bar_date``.
            An empty DataFrame is returned when nothing is found or any error
            occurs (the error is printed, not raised).
        """
        resolved_end_date = end_date or datetime.now().strftime('%Y-%m-%d')
        cache_key = self._get_cache_key(symbol, start_date, resolved_end_date)

        # Windows that reach today always hit the network again.
        force_refresh_t_day = self._should_force_refresh_t_day(resolved_end_date)
        if self._is_cache_valid(cache_key) and not force_refresh_t_day:
            return self.cache[cache_key].copy()

        if ak is None:
            print(f"Error fetching index data for {symbol}: akshare is not installed")
            return pd.DataFrame()

        try:
            formatted_code = self._format_index_code(symbol)
            print(f"正在获取指数 {formatted_code} 的日线级别历史数据...")

            all_data_df = ak.stock_zh_index_daily(symbol=formatted_code)
            if all_data_df.empty:
                print(f"Warning: No data found for index {symbol}")
                return pd.DataFrame()

            all_data_df['date'] = pd.to_datetime(all_data_df['date'])
            all_data_df.set_index('date', inplace=True)

            start_datetime = pd.to_datetime(start_date)
            end_datetime = pd.to_datetime(resolved_end_date)
            in_range = (all_data_df.index >= start_datetime) & (all_data_df.index <= end_datetime)
            filtered_df = all_data_df[in_range]

            if filtered_df.empty:
                # Report the resolved end date; the raw argument may be None.
                print(f"Warning: No data found for {symbol} in date range {start_date} to {resolved_end_date}")
                return pd.DataFrame()

            filtered_df = self._standardize_columns(filtered_df)

            print(f"数据获取成功,期间为 {filtered_df.index[0]} 到 {filtered_df.index[-1]}")
            print(f"获取数据量: {len(filtered_df)} 条")

            filtered_df.attrs["intraday_snapshot_appended"] = False
            filtered_df.attrs["intraday_snapshot_timestamp"] = None
            filtered_df.attrs["historical_latest_bar_date"] = filtered_df.index[-1].date().isoformat()
            self._set_cache(cache_key, filtered_df)

            return filtered_df.copy()

        except Exception as e:
            print(f"Error fetching index data for {symbol}: {str(e)}")
            return pd.DataFrame()

    def fetch_index_data_with_latest_snapshot_v2(
        self,
        symbol: str,
        start_date: str = "2018-01-01",
        end_date: Optional[str] = None,
    ) -> pd.DataFrame:
        """Fetch daily bars and, when applicable, append today's realtime quote.

        The snapshot row is appended only when all of the following hold:
        the window is not explicitly limited to a past date, the history has
        no bar dated today or later, today is Monday-Friday, and a snapshot
        could be fetched. The appended row is indexed by the snapshot's full
        timestamp (it is not normalized to midnight).

        Returns:
            The daily frame (possibly with one extra row). Its attrs record
            whether a snapshot was appended, the snapshot timestamp, and the
            date of the last historical bar.
        """
        resolved_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
        cache_key = self._get_cache_key(symbol, start_date, resolved_end_date)

        frame = self.fetch_index_data_v2(symbol=symbol, start_date=start_date, end_date=resolved_end_date)
        if frame.empty:
            return frame

        today = datetime.now().date()
        historical_latest_bar_date = frame.index[-1].date().isoformat()
        frame.attrs["intraday_snapshot_appended"] = False
        frame.attrs["intraday_snapshot_timestamp"] = None
        frame.attrs["historical_latest_bar_date"] = historical_latest_bar_date

        if end_date is not None and pd.to_datetime(end_date).date() < today:
            return frame
        if frame.index[-1].date() >= today:
            return frame
        if today.weekday() >= 5:
            # Saturday / Sunday: no snapshot is requested.
            return frame

        snapshot = self.fetch_latest_snapshot(symbol)
        if snapshot is None:
            return frame

        latest_row = pd.DataFrame(
            [
                {
                    "open": snapshot.open,
                    "high": snapshot.high,
                    "low": snapshot.low,
                    "close": snapshot.close,
                    "volume": snapshot.volume,
                }
            ],
            index=pd.DatetimeIndex([pd.Timestamp(snapshot.timestamp)]),
        )
        latest_row.index.name = "date"

        merged = pd.concat([frame, latest_row])
        merged = merged[~merged.index.duplicated(keep="last")].sort_index()
        # Set the attrs explicitly on the merged frame rather than relying on
        # concat/indexing to carry them over.
        merged.attrs["intraday_snapshot_appended"] = True
        merged.attrs["intraday_snapshot_timestamp"] = snapshot.timestamp.isoformat(timespec="seconds")
        merged.attrs["historical_latest_bar_date"] = historical_latest_bar_date
        self._writeback_cache_if_exists(cache_key, merged)
        return merged

    def fetch_stock_data_v2(self, symbol: str, start_date: str = "2018-01-01",
                            end_date: Optional[str] = None) -> pd.DataFrame:
        """Placeholder for stock data: prints a warning and fetches ``symbol`` as an index."""
        if end_date is None:
            end_date = datetime.now().strftime('%Y-%m-%d')

        print(f"Warning: Stock data fetching not fully implemented, trying as index: {symbol}")
        return self.fetch_index_data_v2(symbol, start_date, end_date)

    def _standardize_columns(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``df`` with standardized column names and dtypes.

        * A non-datetime index is converted with ``pd.to_datetime`` when
          possible; otherwise it is left unchanged and a warning is printed.
        * Chinese column names are renamed to their English equivalents.
        * Missing OHLCV columns are added filled with NaN (with a warning).
        * OHLCV columns come first, followed by the remaining columns, and
          are coerced to numeric (unparseable values become NaN).

        The input frame is never modified.
        """
        # Work on a copy so index/column assignments never touch the caller's frame.
        df = df.copy()

        if not isinstance(df.index, pd.DatetimeIndex):
            try:
                df.index = pd.to_datetime(df.index)
            except (ValueError, TypeError, OverflowError) as exc:
                print(f"Warning: could not convert index to datetime: {exc}")

        df = df.rename(columns=self._COLUMN_MAPPING)

        required_columns = list(self._REQUIRED_COLUMNS)
        for col in required_columns:
            if col not in df.columns:
                print(f"Warning: Missing column {col}, filling with NaN")
                df[col] = np.nan

        other_columns = [col for col in df.columns if col not in required_columns]
        # .copy() makes the reordered selection an independent frame before assigning into it.
        df = df[required_columns + other_columns].copy()

        for col in required_columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')

        return df

    def get_data_by_date_range(self, symbol: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Main entry point: return bars for ``symbol`` between two dates.

        Args:
            symbol: Index code.
            start_date: Start date, 'YYYY-MM-DD'.
            end_date: End date, 'YYYY-MM-DD'.

        Returns:
            DataFrame with a date index and OHLCV columns.
        """
        return self.fetch_index_data_v2(symbol, start_date, end_date)

    def clear_cache(self):
        """Drop every cached frame together with its expiry time."""
        self.cache.clear()
        self.cache_expiry.clear()


class DataManagerV2:
    """Load frames through a ``DataFetcherV2`` and keep them per request key."""

    def __init__(self, data_fetcher: DataFetcherV2):
        self.data_fetcher = data_fetcher
        self.data_cache = {}
        self.market_info = {}

    def load_data(self, symbol: str, start_date: str, end_date: str) -> pd.DataFrame:
        """Load data for a symbol and date range, caching non-empty results.

        Empty results are returned but not cached, so a later call retries
        the fetch instead of being stuck with an empty frame.

        Args:
            symbol: Index code.
            start_date: Start date.
            end_date: End date.

        Returns:
            A copy of the loaded DataFrame (empty when nothing was returned).
        """
        cache_key = f"{symbol}_{start_date}_{end_date}"

        cached = self.data_cache.get(cache_key)
        if cached is not None:
            return cached.copy()

        print(f"Loading data for {symbol} from {start_date} to {end_date}...")
        data = self.data_fetcher.get_data_by_date_range(symbol, start_date, end_date)

        if data.empty:
            print(f"Warning: Empty data returned for {symbol}")
            return data.copy()

        print(f"Successfully loaded {len(data)} bars for {symbol}")
        self.market_info[symbol] = {
            'start_date': data.index[0],
            'end_date': data.index[-1],
            'total_bars': len(data),
            'has_data': True
        }
        self.data_cache[cache_key] = data
        return data.copy()

    def get_market_info(self, symbol: str) -> dict:
        """Return the recorded info for ``symbol`` (empty dict if never loaded)."""
        return self.market_info.get(symbol, {})

    def get_available_symbols(self) -> list:
        """Return the symbols for which data has been loaded."""
        return list(self.market_info.keys())

    def get_current_data(self, symbol: str, current_date: str, window_size: int = 100) -> pd.DataFrame:
        """Return the last ``window_size`` rows dated on or before ``current_date``.

        The first non-empty cached frame whose key belongs to ``symbol`` is
        used. Rows are selected on the DatetimeIndex when there is one,
        otherwise on a ``date`` column. The result has a fresh 0..n-1 index;
        an empty DataFrame is returned when nothing matches.
        """
        # Match the whole symbol segment of the key, so that e.g. "399" does
        # not pick up a frame cached for "399673".
        key_prefix = f"{symbol}_"
        data = next(
            (
                frame
                for key, frame in self.data_cache.items()
                if not frame.empty and key.startswith(key_prefix)
            ),
            None,
        )
        if data is None:
            return pd.DataFrame()

        current_datetime = pd.to_datetime(current_date)
        if isinstance(data.index, pd.DatetimeIndex):
            on_or_before = data.index <= current_datetime
        else:
            on_or_before = data['date'] <= current_datetime

        result_data = data[on_or_before].tail(window_size).reset_index(drop=True)
        if len(result_data) == 0:
            return pd.DataFrame()
        return result_data

    def print_data_summary(self):
        """Print the period and bar count of every loaded symbol."""
        print("\n" + "="*70)
        print("DATA MANAGER SUMMARY")
        print("="*70)

        for symbol, info in self.market_info.items():
            print(f"{symbol}:")
            print(f"  Period: {info['start_date']} to {info['end_date']}")
            print(f"  Total bars: {info['total_bars']}")

        print("="*70)