| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 |
- #!/usr/bin/env python3
- from __future__ import annotations
- from datetime import datetime, timedelta
- from pathlib import Path
- import sys
- import time
- import akshare as ak
- import pandas as pd
- import requests
# Directory that contains this script.
ROOT = Path(__file__).resolve().parent

# Sibling "index-rotation" checkout. It is pushed to the front of sys.path
# (once) so that packages living there can be imported by this module.
INDEX_ROTATION_ROOT = ROOT.parent / "index-rotation"
if not any(entry == str(INDEX_ROTATION_ROOT) for entry in sys.path):
    sys.path.insert(0, str(INDEX_ROTATION_ROOT))
- from src.data.config import get_instrument, load_instruments
- from src.data.pipeline import merge_frames
- from src.data.transform import build_clean_frame, build_features_frame
# --- Index identifiers and remote endpoint ---------------------------------
# Bare index code, written to the ``index_code`` column by _with_meta.
DEFAULT_SYMBOL = "399673"
# Exchange-prefixed code passed to both the akshare and the Tencent requests.
DEFAULT_PROVIDER_SYMBOL = "sz399673"
# Endpoint queried by the Tencent fallback fetcher.
TENCENT_URL = "https://web.ifzq.gtimg.cn/appstock/app/fqkline/get"
# Default number of attempts per provider in fetch_chinext50_data.
DEFAULT_RETRIES = 3

# --- Files read and written by fetch_chinext50_data ------------------------
# CSV export produced next to this script.
DEFAULT_SAVE_PATH = ROOT / "chinext50.csv"
# Instrument configuration and the raw / clean / feature parquet stores,
# all located inside the index-rotation checkout.
CONFIG_PATH = INDEX_ROTATION_ROOT / "configs/instruments.yaml"
RAW_PATH = INDEX_ROTATION_ROOT / "data/raw/chinext50/price.parquet"
CLEAN_PATH = INDEX_ROTATION_ROOT / "data/clean/chinext50/daily.parquet"
FEATURE_PATH = INDEX_ROTATION_ROOT / "data/features/chinext50/daily.parquet"
- def _convert_feature_to_csv(feature_df: pd.DataFrame, save_path: Path) -> pd.DataFrame:
- out = feature_df[
- [
- "trade_date",
- "close",
- "daily_return",
- "ma_20",
- "vol_20d",
- "distance_to_ma_20",
- ]
- ].copy()
- out.rename(columns={"trade_date": "datetime"}, inplace=True)
- out["open"] = out["close"]
- out["high"] = out["close"]
- out["low"] = out["close"]
- out["volume"] = 0
- out["instrument"] = "chinext50"
- out = out[
- [
- "datetime",
- "open",
- "high",
- "low",
- "close",
- "volume",
- "instrument",
- "daily_return",
- "ma_20",
- "vol_20d",
- "distance_to_ma_20",
- ]
- ]
- out.to_csv(save_path, index=False)
- return out
def _fetch_incremental_akshare(start: str, end: str) -> pd.DataFrame:
    """Download daily index bars through akshare for the given date range.

    Args:
        start: First requested date, forwarded as akshare's ``start_date``.
        end: Last requested date, forwarded as akshare's ``end_date``.

    Returns:
        Frame with columns ``trade_date``, ``open``, ``close``, ``high``,
        ``low``, ``volume``, ``amount``, sorted by ``trade_date`` with one
        row per date (the last occurrence wins). An empty frame with the
        same columns is returned when akshare yields nothing.
    """
    wanted = ["trade_date", "open", "close", "high", "low", "volume", "amount"]

    raw = ak.stock_zh_index_daily_em(
        symbol=DEFAULT_PROVIDER_SYMBOL,
        start_date=start,
        end_date=end,
    )
    if raw is None or raw.empty:
        return pd.DataFrame(columns=wanted)

    renamed = raw.rename(columns={"date": "trade_date"})
    renamed["trade_date"] = pd.to_datetime(renamed["trade_date"])
    tidy = renamed[wanted].copy()
    tidy = tidy.sort_values("trade_date")
    tidy = tidy.drop_duplicates("trade_date", keep="last")
    return tidy.reset_index(drop=True)
def _fetch_incremental_tencent(start_date: str, end_date: str) -> pd.DataFrame:
    """Download daily index bars from the Tencent kline endpoint.

    Args:
        start_date: First requested date, embedded verbatim in the request
            (the caller in this module passes ``YYYY-MM-DD``).
        end_date: Last requested date, same format.

    Returns:
        Frame with columns ``trade_date``, ``open``, ``close``, ``high``,
        ``low``, ``volume``, ``amount``, sorted by ``trade_date`` with one
        row per date. ``amount`` is always missing because it is not read
        from the response. Rows whose date or close cannot be parsed are
        dropped. An empty frame with the same columns is returned when the
        response carries no rows.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
        Exception: Any other error raised by ``requests`` or by decoding
            the JSON body is propagated unchanged.
    """
    columns = ["trade_date", "open", "close", "high", "low", "volume", "amount"]
    price_fields = ["trade_date", "open", "close", "high", "low", "volume"]

    headers = {"User-Agent": "Mozilla/5.0"}
    # NOTE(review): the trailing "1000,qfq" presumably means a 1000-row cap and
    # forward-adjusted prices - confirm against the endpoint's behaviour.
    param = f"{DEFAULT_PROVIDER_SYMBOL},day,{start_date},{end_date},1000,qfq"
    response = requests.get(TENCENT_URL, params={"param": param}, headers=headers, timeout=20)
    response.raise_for_status()
    payload = response.json()

    data = (payload.get("data") or {}).get(DEFAULT_PROVIDER_SYMBOL) or {}
    rows = data.get("day") or []
    if not rows:
        return pd.DataFrame(columns=columns)

    # Keep only the six leading fields of every row: the DataFrame constructor
    # raises when a row carries more values than there are column names.
    trimmed = [list(row)[: len(price_fields)] for row in rows]
    df = pd.DataFrame(trimmed, columns=price_fields)

    # Coerce instead of raising so that the dropna() below can actually
    # discard rows with an unparseable date.
    df["trade_date"] = pd.to_datetime(df["trade_date"], errors="coerce")
    for c in price_fields[1:]:
        df[c] = pd.to_numeric(df[c], errors="coerce")
    df["amount"] = pd.NA

    df = df[columns].dropna(subset=["trade_date", "close"])
    # Same ordering / uniqueness guarantees as the akshare fetcher.
    return df.sort_values("trade_date").drop_duplicates("trade_date", keep="last").reset_index(drop=True)
def _with_meta(df: pd.DataFrame, provider: str = "akshare_eastmoney") -> pd.DataFrame:
    """Return a copy of ``df`` with the instrument metadata columns attached.

    Args:
        df: Price frame containing ``trade_date``, ``open``, ``high``,
            ``low``, ``close``, ``volume`` and ``amount``. It is not
            modified.
        provider: Value stored in the ``provider`` column. Defaults to
            ``"akshare_eastmoney"``, the label that was previously
            hard-coded, so existing callers are unaffected; callers that
            obtained the rows from another source can pass its name.

    Returns:
        New frame with columns ``instrument``, ``instrument_name``,
        ``index_code``, ``provider`` followed by the price columns in the
        order ``trade_date``, ``open``, ``high``, ``low``, ``close``,
        ``volume``, ``amount``.
    """
    out = df.copy()
    out["instrument"] = "chinext50"
    out["instrument_name"] = "创业板50"
    out["index_code"] = DEFAULT_SYMBOL
    out["provider"] = provider
    return out[
        [
            "instrument",
            "instrument_name",
            "index_code",
            "provider",
            "trade_date",
            "open",
            "high",
            "low",
            "close",
            "volume",
            "amount",
        ]
    ]
- def _validate_overlap(existing_raw: pd.DataFrame, fetched_raw: pd.DataFrame, overlap_days: int = 5) -> None:
- if existing_raw.empty or fetched_raw.empty:
- return
- last_existing = pd.to_datetime(existing_raw["trade_date"]).max()
- overlap_start = last_existing - pd.Timedelta(days=overlap_days * 3)
- existing_tail = existing_raw[pd.to_datetime(existing_raw["trade_date"]) >= overlap_start].copy()
- incoming = fetched_raw[pd.to_datetime(fetched_raw["trade_date"]) <= last_existing].copy()
- if incoming.empty:
- return
- merged = existing_tail.merge(
- incoming,
- on="trade_date",
- suffixes=("_old", "_new"),
- )
- if merged.empty:
- raise RuntimeError("增量更新校验失败:没有可比对的重叠交易日")
- for col in ["open", "high", "low", "close", "volume"]:
- diff = (pd.to_numeric(merged[f"{col}_old"], errors="coerce") - pd.to_numeric(merged[f"{col}_new"], errors="coerce")).abs().fillna(0)
- if float(diff.max()) > 1e-6:
- bad = merged.loc[diff.idxmax(), ["trade_date", f"{col}_old", f"{col}_new"]].to_dict()
- raise RuntimeError(f"增量更新校验失败:字段 {col} 与旧基线不一致,样例={bad}")
def _fetch_with_retries(fetch, retries: int) -> tuple[pd.DataFrame, Exception | None]:
    """Call ``fetch`` until it yields rows or ``retries`` attempts are used up.

    Args:
        fetch: Zero-argument callable returning a price DataFrame.
        retries: Maximum number of attempts; values below 1 mean no attempt.

    Returns:
        ``(frame, last_error)``. ``frame`` is the first non-empty result
        with metadata columns attached via ``_with_meta``, or an empty
        DataFrame if no attempt produced rows. ``last_error`` is the most
        recent exception raised by an attempt, or ``None``.
    """
    last_error: Exception | None = None
    for attempt in range(1, retries + 1):
        try:
            frame = fetch()
            if not frame.empty:
                return _with_meta(frame), last_error
        except Exception as exc:  # any provider failure just triggers the next attempt
            last_error = exc
            # Linear back-off; waiting after the final attempt would be wasted time.
            if attempt < retries:
                time.sleep(attempt)
    return pd.DataFrame(), last_error


def fetch_chinext50_data(save_path: str | Path = DEFAULT_SAVE_PATH, retries: int = DEFAULT_RETRIES) -> pd.DataFrame:
    """Incrementally refresh the stored price data and export the CSV.

    Reads the raw baseline from ``RAW_PATH``, requests every day after its
    last date (akshare first, Tencent as fallback), validates the fetched
    rows against the baseline, rebuilds the clean and feature frames,
    rewrites the three parquet files and finally writes the CSV export.

    Args:
        save_path: Destination of the CSV export.
        retries: Maximum attempts per provider.

    Returns:
        The exported frame as produced by ``_convert_feature_to_csv``.

    Raises:
        RuntimeError: If neither provider returned any rows (chained to the
            last provider exception, when there was one), or if
            ``_validate_overlap`` rejects the fetched rows.
    """
    save_path = Path(save_path)
    raw_existing = pd.read_parquet(RAW_PATH)
    last_date = pd.to_datetime(raw_existing["trade_date"]).max().date()
    today = datetime.now().date()
    request_start = last_date + timedelta(days=1)
    if request_start > today:
        # Baseline already reaches today: re-export the stored features
        # without touching the network.
        return _convert_feature_to_csv(pd.read_parquet(FEATURE_PATH), save_path)

    start_ymd = request_start.strftime("%Y%m%d")
    end_ymd = today.strftime("%Y%m%d")
    end_iso = today.strftime("%Y-%m-%d")
    # The fallback request starts 10 calendar days early so that
    # _validate_overlap has baseline dates to compare against.
    overlap_start = (request_start - timedelta(days=10)).strftime("%Y-%m-%d")

    fetched, last_error = _fetch_with_retries(
        lambda: _fetch_incremental_akshare(start_ymd, end_ymd),
        retries,
    )
    if fetched.empty:
        # NOTE(review): rows from this fallback receive the same provider
        # label from _with_meta as akshare rows - confirm whether anything
        # downstream relies on the provider column.
        fetched, fallback_error = _fetch_with_retries(
            lambda: _fetch_incremental_tencent(overlap_start, end_iso),
            retries,
        )
        if fallback_error is not None:
            last_error = fallback_error

    if fetched.empty:
        raise RuntimeError(
            f"远程数据刷新失败(akshare + tencent fallback 均失败): {last_error}"
        ) from last_error

    _validate_overlap(raw_existing, fetched)

    # Overlap rows were only needed for validation; append strictly new dates.
    fetched_new = fetched[pd.to_datetime(fetched["trade_date"]).dt.date > last_date].copy()
    merged_raw = merge_frames(raw_existing, fetched_new)

    instruments = load_instruments(CONFIG_PATH)
    instrument = get_instrument(instruments, "chinext50")
    clean_df = build_clean_frame(merged_raw, instrument)
    feature_df = build_features_frame(clean_df)

    outputs = ((merged_raw, RAW_PATH), (clean_df, CLEAN_PATH), (feature_df, FEATURE_PATH))
    for _, path in outputs:
        path.parent.mkdir(parents=True, exist_ok=True)
    for frame, path in outputs:
        frame.to_parquet(path, index=False)

    return _convert_feature_to_csv(feature_df, save_path)
if __name__ == "__main__":
    # Script entry point: refresh with the default settings, then report the
    # export location, row count and covered date range.
    df = fetch_chinext50_data()
    dates = pd.to_datetime(df["datetime"])
    first_day = dates.min().date()
    last_day = dates.max().date()
    print(f"数据已刷新: {DEFAULT_SAVE_PATH} | rows={len(df)} | range={first_day} ~ {last_day}")
|