| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476 |
- import akshare as ak
- import pandas as pd
- import numpy as np
- from dataclasses import dataclass
- from datetime import datetime, timedelta
- from typing import Optional, Union, List, Tuple
- import time
- try:
- import requests
- except ImportError: # pragma: no cover
- requests = None
@dataclass()
class LatestSnapshot:
    """A single point-in-time OHLC quote for one symbol.

    ``volume`` is optional and defaults to 0.0 for sources that do not
    report it.
    """

    timestamp: datetime  # moment the quote refers to / was taken
    open: float
    high: float
    low: float
    close: float
    volume: float = 0.0  # 0.0 when the source provides no volume
class DataFetcherV2:
    """Daily index data fetcher (V2).

    Historical daily bars come from ``ak.stock_zh_index_daily``. Results are
    kept in a per-instance in-memory cache with a fixed lifetime. A
    best-effort intraday snapshot from Sina's real-time quote endpoint can
    optionally be appended.
    """

    def __init__(self):
        self.cache = {}  # cache_key -> DataFrame
        self.cache_expiry = {}  # cache_key -> expiry time, as returned by time.time()
        self.cache_duration = 3600  # cache lifetime in seconds (1 hour)

    def _get_cache_key(self, symbol: str, start_date: str, end_date: str) -> str:
        """Build the cache key for a ``symbol`` / date-range request."""
        return f"{symbol}_{start_date}_{end_date}"

    def _is_cache_valid(self, cache_key: str) -> bool:
        """Return True if ``cache_key`` is cached and its expiry is still in the future."""
        # Also require the data entry itself, so callers that index self.cache
        # right after this check cannot hit a KeyError.
        if cache_key not in self.cache_expiry or cache_key not in self.cache:
            return False
        return time.time() < self.cache_expiry[cache_key]

    def _set_cache(self, cache_key: str, data: pd.DataFrame) -> None:
        """Store ``data`` under ``cache_key`` and restart its expiry timer."""
        self.cache[cache_key] = data
        self.cache_expiry[cache_key] = time.time() + self.cache_duration

    def _should_force_refresh_t_day(self, end_date: str) -> bool:
        """Return True when the requested window reaches today (T day) or later.

        Such requests bypass the cache, so a stale copy of today's data is never
        served. An unparseable ``end_date`` yields False.
        """
        try:
            return pd.to_datetime(end_date).date() >= datetime.now().date()
        except Exception:
            return False

    def _writeback_cache_if_exists(self, cache_key: str, data: pd.DataFrame) -> None:
        """Overwrite the cache entry for ``cache_key`` with a copy of ``data``.

        The write only happens when the key is already cached. Otherwise this is
        a no-op ("write back if present, skip if absent").
        """
        if cache_key not in self.cache:
            return
        payload = data.copy()
        payload.attrs.update(data.attrs)  # keep the snapshot metadata with the cached copy
        self._set_cache(cache_key, payload)

    def _format_index_code(self, symbol: str) -> str:
        """Normalize ``symbol`` to the akshare/Sina form, e.g. ``399673 -> sz399673``.

        Accepted inputs:
          * a bare 6-character code (``"399673"``): the exchange is inferred
            from the leading digits;
          * a suffixed code (``"000300.SH"``, ``"399673.SZ"``; ``.SS`` also
            means Shanghai);
          * a prefixed dotted code (``"sh.000300"``);
          * an already-prefixed code (``"SH000300"``): returned lower-cased.
        An exchange given explicitly in dotted form always wins over inference.
        """
        symbol = symbol.strip()
        explicit_prefix = None

        if '.' in symbol:
            head, _, tail = symbol.partition('.')
            # "000300.SH" has the code first; "sh.000300" has the exchange first.
            if tail.strip().isdigit() and not head.strip().isdigit():
                code, exchange = tail, head
            else:
                code, exchange = head, tail
            explicit_prefix = {"sh": "sh", "ss": "sh", "sz": "sz"}.get(exchange.strip().lower())
            symbol = code.strip()

        if len(symbol) == 6:
            if explicit_prefix is not None:
                return f"{explicit_prefix}{symbol}"
            # Inference from leading digits: 60/68 -> Shanghai, 00/30 and anything
            # else -> Shenzhen.
            # NOTE(review): Shanghai *index* codes such as 000300 are mapped to
            # "sz" here, while _infer_realtime_prefix maps 000xxx to "sh". The
            # bare-code mapping is kept for backward compatibility. Pass
            # "000300.SH" to force Shanghai.
            if symbol.startswith(('60', '68')):
                return f"sh{symbol}"
            return f"sz{symbol}"

        # Anything else (e.g. "sh000300") is assumed to be formatted already.
        return symbol.lower()

    def _infer_realtime_prefix(self, code: str) -> str:
        """Guess the Sina exchange prefix ("sh"/"sz") for a code without one."""
        if code.startswith("399"):
            return "sz"
        if code.startswith("000"):
            return "sh"
        if code.startswith(("30", "00", "15")):
            return "sz"
        if code.startswith(("60", "68")):
            return "sh"
        return "sz"

    def fetch_latest_snapshot(self, symbol: str) -> Optional["LatestSnapshot"]:
        """Fetch a best-effort real-time quote for ``symbol`` from Sina.

        Returns None if ``requests`` is unavailable, the HTTP call fails, the
        payload cannot be parsed, or any price is non-positive. ``volume`` is
        always 0.0. ``timestamp`` is the local fetch time, not the exchange
        quote time.
        """
        if requests is None:
            return None

        formatted_symbol = self._format_index_code(symbol)
        if formatted_symbol.startswith(("sz", "sh")):
            prefix = formatted_symbol[:2]
            code = formatted_symbol[2:]
        else:
            code = formatted_symbol
            prefix = self._infer_realtime_prefix(code)

        url = f"http://hq.sinajs.cn/list={prefix}{code}"
        headers = {
            "User-Agent": "Mozilla/5.0",
            "Referer": "http://finance.sina.com.cn",
        }
        try:
            response = requests.get(url, headers=headers, timeout=10)
            response.raise_for_status()
        except Exception:
            # Best effort: any network/HTTP failure simply means "no snapshot".
            return None

        response.encoding = "gbk"
        text = response.text
        if '"' not in text:
            return None
        try:
            # The payload is the comma-separated text between the first pair of
            # quotes. Fields 1..5 are read as open, prev_close, current, high, low.
            payload = text.split('"')[1].split(",")
            if len(payload) < 6:
                return None
            open_price = float(payload[1])
            prev_close = float(payload[2])
            close_price = float(payload[3])
            high_price = float(payload[4])
            low_price = float(payload[5])
        except (ValueError, IndexError):
            return None

        # Reject non-positive prices. A zero open (like a zero high/low) is not a
        # usable bar and must not be appended to the daily history.
        if any(price <= 0 for price in (open_price, prev_close, close_price, high_price, low_price)):
            return None

        return LatestSnapshot(
            timestamp=datetime.now(),
            open=open_price,
            high=high_price,
            low=low_price,
            close=close_price,
            volume=0.0,
        )

    def fetch_index_data_v2(self,
                            symbol: str,
                            start_date: str = "2018-01-01",
                            end_date: Optional[str] = None) -> pd.DataFrame:
        """Fetch daily bars for an index via ``ak.stock_zh_index_daily``.

        Args:
            symbol: Index code in any form accepted by ``_format_index_code``.
            start_date: Inclusive start date. Defaults to ``"2018-01-01"``.
            end_date: Inclusive end date. Defaults to today.

        Returns:
            DataFrame indexed by date with open/high/low/close/volume first.
            ``attrs`` records the snapshot state and the latest historical bar
            date. An empty DataFrame is returned when nothing is found or the
            fetch fails.
        """
        resolved_end_date = end_date or datetime.now().strftime('%Y-%m-%d')
        cache_key = self._get_cache_key(symbol, start_date, resolved_end_date)
        # Windows that reach today are always refetched, so today's bar is never stale.
        force_refresh_t_day = self._should_force_refresh_t_day(resolved_end_date)
        if self._is_cache_valid(cache_key) and not force_refresh_t_day:
            return self.cache[cache_key].copy()

        try:
            formatted_code = self._format_index_code(symbol)
            print(f"正在获取指数 {formatted_code} 的日线级别历史数据...")

            all_data_df = ak.stock_zh_index_daily(symbol=formatted_code)
            if all_data_df is None or all_data_df.empty:
                print(f"Warning: No data found for index {symbol}")
                return pd.DataFrame()

            all_data_df['date'] = pd.to_datetime(all_data_df['date'])
            all_data_df.set_index('date', inplace=True)

            start_datetime = pd.to_datetime(start_date)
            end_datetime = pd.to_datetime(resolved_end_date)
            in_range = (all_data_df.index >= start_datetime) & (all_data_df.index <= end_datetime)
            filtered_df = all_data_df[in_range]

            if filtered_df.empty:
                # Report the resolved end date; end_date itself may be None.
                print(f"Warning: No data found for {symbol} in date range {start_date} to {resolved_end_date}")
                return pd.DataFrame()

            filtered_df = self._standardize_columns(filtered_df)

            print(f"数据获取成功,期间为 {filtered_df.index[0]} 到 {filtered_df.index[-1]}")
            print(f"获取数据量: {len(filtered_df)} 条")

            filtered_df.attrs["intraday_snapshot_appended"] = False
            filtered_df.attrs["intraday_snapshot_timestamp"] = None
            filtered_df.attrs["historical_latest_bar_date"] = filtered_df.index[-1].date().isoformat()
            self._set_cache(cache_key, filtered_df)
            return filtered_df.copy()

        except Exception as e:
            print(f"Error fetching index data for {symbol}: {str(e)}")
            return pd.DataFrame()

    def fetch_index_data_with_latest_snapshot_v2(
        self,
        symbol: str,
        start_date: str = "2018-01-01",
        end_date: Optional[str] = None,
    ) -> pd.DataFrame:
        """Fetch daily bars and, when today's bar is missing, append a live snapshot.

        No snapshot is appended in these cases:
          * the window ends before today;
          * the history already contains a bar for today;
          * today is a weekend;
          * the quote fetch fails.
        The appended row is indexed at midnight of the snapshot's day, so it
        lines up with the daily bars. Its exact fetch time is kept in
        ``attrs["intraday_snapshot_timestamp"]``.
        NOTE(review): weekday market holidays are not detected, so a snapshot
        may still be appended on such days.
        """
        resolved_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
        cache_key = self._get_cache_key(symbol, start_date, resolved_end_date)
        frame = self.fetch_index_data_v2(symbol=symbol, start_date=start_date, end_date=resolved_end_date)
        if frame.empty:
            return frame

        today = datetime.now().date()
        last_bar_date = frame.index[-1].date()
        historical_latest_bar_date = last_bar_date.isoformat()
        frame.attrs["intraday_snapshot_appended"] = False
        frame.attrs["intraday_snapshot_timestamp"] = None
        frame.attrs["historical_latest_bar_date"] = historical_latest_bar_date

        if end_date is not None and pd.to_datetime(end_date).date() < today:
            return frame
        if last_bar_date >= today:
            return frame
        if today.weekday() >= 5:
            return frame

        snapshot = self.fetch_latest_snapshot(symbol)
        if snapshot is None:
            return frame

        # Normalize to the calendar day. The index stays a pure daily index, and
        # duplicated(keep="last") can actually replace a same-day bar.
        bar_date = pd.Timestamp(snapshot.timestamp).normalize()
        latest_row = pd.DataFrame(
            [
                {
                    "open": snapshot.open,
                    "high": snapshot.high,
                    "low": snapshot.low,
                    "close": snapshot.close,
                    "volume": snapshot.volume,
                }
            ],
            index=pd.DatetimeIndex([bar_date], name="date"),
        )
        merged = pd.concat([frame, latest_row])
        merged = merged[~merged.index.duplicated(keep="last")].sort_index()
        merged.attrs["intraday_snapshot_appended"] = True
        merged.attrs["intraday_snapshot_timestamp"] = snapshot.timestamp.isoformat(timespec="seconds")
        merged.attrs["historical_latest_bar_date"] = historical_latest_bar_date
        self._writeback_cache_if_exists(cache_key, merged)
        return merged

    def fetch_stock_data_v2(self,
                            symbol: str,
                            start_date: str = "2018-01-01",
                            end_date: Optional[str] = None) -> pd.DataFrame:
        """Placeholder for stock data: currently delegates to ``fetch_index_data_v2``."""
        if end_date is None:
            end_date = datetime.now().strftime('%Y-%m-%d')

        print(f"Warning: Stock data fetching not fully implemented, trying as index: {symbol}")
        return self.fetch_index_data_v2(symbol, start_date, end_date)

    def _standardize_columns(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``df`` with English OHLCV column names and a DatetimeIndex.

        Chinese column names are mapped to English. Any missing
        open/high/low/close/volume column is added as NaN. Those five columns
        are moved to the front and coerced to numeric (invalid values become
        NaN). The input frame is not modified.
        """
        column_mapping = {
            '开盘': 'open',
            '收盘': 'close',
            '最高': 'high',
            '最低': 'low',
            '成交量': 'volume',
            '成交额': 'amount',
            '涨跌幅': 'change_pct',
            '涨跌额': 'change',
            '振幅': 'amplitude',
            '换手率': 'turnover'
        }
        # rename() returns a new frame. Doing it first means the index conversion
        # below never touches the caller's object.
        df = df.rename(columns=column_mapping)

        if not isinstance(df.index, pd.DatetimeIndex):
            try:
                df.index = pd.to_datetime(df.index)
            except Exception:
                # Best effort: keep the original index if it is not date-like.
                pass

        required_columns = ['open', 'high', 'low', 'close', 'volume']
        for col in required_columns:
            # Chinese names were already mapped above, so a missing column is truly absent.
            if col not in df.columns:
                print(f"Warning: Missing column {col}, filling with NaN")
                df[col] = np.nan

        other_columns = [col for col in df.columns if col not in required_columns]
        df = df[required_columns + other_columns]

        for col in required_columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')

        return df

    def get_data_by_date_range(self,
                               symbol: str,
                               start_date: str,
                               end_date: str) -> pd.DataFrame:
        """Main entry point: daily bars for ``symbol`` between two 'YYYY-MM-DD' dates.

        Returns a DataFrame with a date index and OHLCV columns (empty on failure).
        """
        return self.fetch_index_data_v2(symbol, start_date, end_date)

    def clear_cache(self):
        """Drop every cached frame and its expiry."""
        self.cache.clear()
        self.cache_expiry.clear()
class DataManagerV2:
    """Data manager (V2): loads frames through a fetcher and keeps them in memory.

    Loaded frames are cached per ``symbol/start_date/end_date``. Summary
    information about each successfully loaded symbol is kept in
    ``market_info``.
    """

    def __init__(self, data_fetcher: "DataFetcherV2"):
        self.data_fetcher = data_fetcher
        self.data_cache = {}  # "symbol_start_end" -> DataFrame (non-empty loads only)
        self.market_info = {}  # symbol -> summary dict of its last successful load

    def load_data(self,
                  symbol: str,
                  start_date: str,
                  end_date: str) -> pd.DataFrame:
        """Load data for ``symbol`` over a date range, caching successful results.

        Args:
            symbol: Index code.
            start_date: Start date.
            end_date: End date.

        Returns:
            A copy of the loaded DataFrame (empty if nothing was returned).
            Empty results are not cached, so a failed fetch is retried on the
            next call.
        """
        cache_key = f"{symbol}_{start_date}_{end_date}"

        cached = self.data_cache.get(cache_key)
        if cached is not None:
            return cached.copy()

        print(f"Loading data for {symbol} from {start_date} to {end_date}...")
        data = self.data_fetcher.get_data_by_date_range(symbol, start_date, end_date)

        if data.empty:
            print(f"Warning: Empty data returned for {symbol}")
            # DataFetcherV2 also returns an empty frame on fetch errors. Caching
            # it would make a transient failure permanent for this key.
            return data.copy()

        print(f"Successfully loaded {len(data)} bars for {symbol}")
        self.market_info[symbol] = {
            'start_date': data.index[0],
            'end_date': data.index[-1],
            'total_bars': len(data),
            'has_data': True
        }
        self.data_cache[cache_key] = data
        return data.copy()

    def get_market_info(self, symbol: str) -> dict:
        """Return the summary dict for ``symbol``, or ``{}`` if it was never loaded."""
        return self.market_info.get(symbol, {})

    def get_available_symbols(self) -> list:
        """Return the symbols that have been loaded successfully."""
        return list(self.market_info.keys())

    def get_current_data(self,
                         symbol: str,
                         current_date: str,
                         window_size: int = 100) -> pd.DataFrame:
        """Return up to ``window_size`` rows dated at or before ``current_date``.

        Uses the first non-empty cached frame loaded for exactly ``symbol``.
        Both a DatetimeIndex and a ``date`` column are supported. The result
        has a fresh RangeIndex.
        NOTE(review): for DatetimeIndex frames, reset_index(drop=True) discards
        the dates from the result. Confirm callers do not need them.
        """
        # Cache keys are "symbol_start_end". Matching the trailing "_" keeps
        # "A" from picking up data cached for "AB".
        key_prefix = f"{symbol}_"
        data = next(
            (frame for key, frame in self.data_cache.items()
             if not frame.empty and key.startswith(key_prefix)),
            None,
        )
        if data is None:
            return pd.DataFrame()

        current_datetime = pd.to_datetime(current_date)

        if isinstance(data.index, pd.DatetimeIndex):
            historical_data = data[data.index <= current_datetime].copy()
        else:
            historical_data = data[data['date'] <= current_datetime].copy()
        result_data = historical_data.tail(window_size).reset_index(drop=True)

        if len(result_data) == 0:
            return pd.DataFrame()

        return result_data

    def print_data_summary(self):
        """Print the period and bar count of every loaded symbol."""
        print("\n" + "="*70)
        print("DATA MANAGER SUMMARY")
        print("="*70)
        for symbol, info in self.market_info.items():
            print(f"{symbol}:")
            print(f" Period: {info['start_date']} to {info['end_date']}")
            print(f" Total bars: {info['total_bars']}")
        print("="*70)
|