fetch_data.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. #!/usr/bin/env python3
  2. from __future__ import annotations
  3. from datetime import datetime, timedelta
  4. from pathlib import Path
  5. import sys
  6. import time
  7. import akshare as ak
  8. import pandas as pd
  9. import requests
# Directory that contains this script.
ROOT = Path(__file__).resolve().parent
# Sibling "index-rotation" directory; all parquet/config paths below live under it.
INDEX_ROTATION_ROOT = ROOT.parent / "index-rotation"
# Put that directory at the front of sys.path. This must run before the
# `src.data.*` imports below, which presumably resolve from there (TODO confirm).
if str(INDEX_ROTATION_ROOT) not in sys.path:
    sys.path.insert(0, str(INDEX_ROTATION_ROOT))
from src.data.config import get_instrument, load_instruments
from src.data.pipeline import merge_frames
from src.data.transform import build_clean_frame, build_features_frame
# Default destination of the CSV export written by _convert_feature_to_csv.
DEFAULT_SAVE_PATH = ROOT / "chinext50.csv"
# Parquet files read and rewritten by fetch_chinext50_data (raw -> clean -> features).
RAW_PATH = INDEX_ROTATION_ROOT / "data/raw/chinext50/price.parquet"
CLEAN_PATH = INDEX_ROTATION_ROOT / "data/clean/chinext50/daily.parquet"
FEATURE_PATH = INDEX_ROTATION_ROOT / "data/features/chinext50/daily.parquet"
# Instrument configuration file handed to load_instruments.
CONFIG_PATH = INDEX_ROTATION_ROOT / "configs/instruments.yaml"
# Value stored in the "index_code" metadata column.
DEFAULT_SYMBOL = "399673"
# Symbol string sent to both remote providers (akshare and the Tencent endpoint).
DEFAULT_PROVIDER_SYMBOL = "sz399673"
# Default number of attempts per provider in fetch_chinext50_data.
DEFAULT_RETRIES = 3
# HTTP endpoint queried by _fetch_incremental_tencent.
TENCENT_URL = "https://web.ifzq.gtimg.cn/appstock/app/fqkline/get"
  26. def _convert_feature_to_csv(feature_df: pd.DataFrame, save_path: Path) -> pd.DataFrame:
  27. out = feature_df[
  28. [
  29. "trade_date",
  30. "close",
  31. "daily_return",
  32. "ma_20",
  33. "vol_20d",
  34. "distance_to_ma_20",
  35. ]
  36. ].copy()
  37. out.rename(columns={"trade_date": "datetime"}, inplace=True)
  38. out["open"] = out["close"]
  39. out["high"] = out["close"]
  40. out["low"] = out["close"]
  41. out["volume"] = 0
  42. out["instrument"] = "chinext50"
  43. out = out[
  44. [
  45. "datetime",
  46. "open",
  47. "high",
  48. "low",
  49. "close",
  50. "volume",
  51. "instrument",
  52. "daily_return",
  53. "ma_20",
  54. "vol_20d",
  55. "distance_to_ma_20",
  56. ]
  57. ]
  58. out.to_csv(save_path, index=False)
  59. return out
  60. def _fetch_incremental_akshare(start: str, end: str) -> pd.DataFrame:
  61. frame = ak.stock_zh_index_daily_em(symbol=DEFAULT_PROVIDER_SYMBOL, start_date=start, end_date=end)
  62. if frame is None or frame.empty:
  63. return pd.DataFrame(columns=["trade_date", "open", "close", "high", "low", "volume", "amount"])
  64. frame = frame.rename(columns={"date": "trade_date"})
  65. frame["trade_date"] = pd.to_datetime(frame["trade_date"])
  66. frame = frame[["trade_date", "open", "close", "high", "low", "volume", "amount"]].copy()
  67. frame = frame.sort_values("trade_date").drop_duplicates("trade_date", keep="last").reset_index(drop=True)
  68. return frame
  69. def _fetch_incremental_tencent(start_date: str, end_date: str) -> pd.DataFrame:
  70. headers = {"User-Agent": "Mozilla/5.0"}
  71. param = f"{DEFAULT_PROVIDER_SYMBOL},day,{start_date},{end_date},1000,qfq"
  72. r = requests.get(TENCENT_URL, params={"param": param}, headers=headers, timeout=20)
  73. r.raise_for_status()
  74. payload = r.json()
  75. data = (payload.get("data") or {}).get(DEFAULT_PROVIDER_SYMBOL) or {}
  76. rows = data.get("day") or []
  77. if not rows:
  78. return pd.DataFrame(columns=["trade_date", "open", "close", "high", "low", "volume", "amount"])
  79. df = pd.DataFrame(rows, columns=["trade_date", "open", "close", "high", "low", "volume"])
  80. df["trade_date"] = pd.to_datetime(df["trade_date"])
  81. for c in ["open", "close", "high", "low", "volume"]:
  82. df[c] = pd.to_numeric(df[c], errors="coerce")
  83. df["amount"] = pd.NA
  84. return df[["trade_date", "open", "close", "high", "low", "volume", "amount"]].dropna(subset=["trade_date", "close"])
  85. def _with_meta(df: pd.DataFrame) -> pd.DataFrame:
  86. out = df.copy()
  87. out["instrument"] = "chinext50"
  88. out["instrument_name"] = "创业板50"
  89. out["index_code"] = DEFAULT_SYMBOL
  90. out["provider"] = "akshare_eastmoney"
  91. return out[
  92. ["instrument", "instrument_name", "index_code", "provider", "trade_date", "open", "high", "low", "close", "volume", "amount"]
  93. ]
  94. def _validate_overlap(existing_raw: pd.DataFrame, fetched_raw: pd.DataFrame, overlap_days: int = 5) -> None:
  95. if existing_raw.empty or fetched_raw.empty:
  96. return
  97. last_existing = pd.to_datetime(existing_raw["trade_date"]).max()
  98. overlap_start = last_existing - pd.Timedelta(days=overlap_days * 3)
  99. existing_tail = existing_raw[pd.to_datetime(existing_raw["trade_date"]) >= overlap_start].copy()
  100. incoming = fetched_raw[pd.to_datetime(fetched_raw["trade_date"]) <= last_existing].copy()
  101. if incoming.empty:
  102. return
  103. merged = existing_tail.merge(
  104. incoming,
  105. on="trade_date",
  106. suffixes=("_old", "_new"),
  107. )
  108. if merged.empty:
  109. raise RuntimeError("增量更新校验失败:没有可比对的重叠交易日")
  110. for col in ["open", "high", "low", "close", "volume"]:
  111. diff = (pd.to_numeric(merged[f"{col}_old"], errors="coerce") - pd.to_numeric(merged[f"{col}_new"], errors="coerce")).abs().fillna(0)
  112. if float(diff.max()) > 1e-6:
  113. bad = merged.loc[diff.idxmax(), ["trade_date", f"{col}_old", f"{col}_new"]].to_dict()
  114. raise RuntimeError(f"增量更新校验失败:字段 {col} 与旧基线不一致,样例={bad}")
  115. def fetch_chinext50_data(save_path: str | Path = DEFAULT_SAVE_PATH, retries: int = DEFAULT_RETRIES) -> pd.DataFrame:
  116. save_path = Path(save_path)
  117. raw_existing = pd.read_parquet(RAW_PATH)
  118. last_date = pd.to_datetime(raw_existing["trade_date"]).max().date()
  119. today = datetime.now().date()
  120. request_start = last_date + timedelta(days=1)
  121. if request_start > today:
  122. feature_df = pd.read_parquet(FEATURE_PATH)
  123. return _convert_feature_to_csv(feature_df, save_path)
  124. start_ymd = request_start.strftime("%Y%m%d")
  125. end_ymd = today.strftime("%Y%m%d")
  126. start_iso = request_start.strftime("%Y-%m-%d")
  127. end_iso = today.strftime("%Y-%m-%d")
  128. last_error = None
  129. fetched = pd.DataFrame()
  130. for attempt in range(1, retries + 1):
  131. try:
  132. fetched = _fetch_incremental_akshare(start_ymd, end_ymd)
  133. if not fetched.empty:
  134. fetched = _with_meta(fetched)
  135. break
  136. except Exception as e:
  137. last_error = e
  138. time.sleep(attempt)
  139. else:
  140. for attempt in range(1, retries + 1):
  141. try:
  142. # include overlap window for same-source compatibility check
  143. overlap_start = (request_start - timedelta(days=10)).strftime("%Y-%m-%d")
  144. fetched = _fetch_incremental_tencent(overlap_start, end_iso)
  145. if not fetched.empty:
  146. fetched = _with_meta(fetched)
  147. break
  148. except Exception as e:
  149. last_error = e
  150. time.sleep(attempt)
  151. if fetched.empty:
  152. raise RuntimeError(f"远程数据刷新失败(akshare + tencent fallback 均失败): {last_error}")
  153. _validate_overlap(raw_existing, fetched)
  154. fetched_new = fetched[pd.to_datetime(fetched["trade_date"]).dt.date > last_date].copy()
  155. merged_raw = merge_frames(raw_existing, fetched_new)
  156. instruments = load_instruments(CONFIG_PATH)
  157. instrument = get_instrument(instruments, "chinext50")
  158. clean_df = build_clean_frame(merged_raw, instrument)
  159. feature_df = build_features_frame(clean_df)
  160. RAW_PATH.parent.mkdir(parents=True, exist_ok=True)
  161. CLEAN_PATH.parent.mkdir(parents=True, exist_ok=True)
  162. FEATURE_PATH.parent.mkdir(parents=True, exist_ok=True)
  163. merged_raw.to_parquet(RAW_PATH, index=False)
  164. clean_df.to_parquet(CLEAN_PATH, index=False)
  165. feature_df.to_parquet(FEATURE_PATH, index=False)
  166. return _convert_feature_to_csv(feature_df, save_path)
  167. if __name__ == "__main__":
  168. df = fetch_chinext50_data()
  169. print(
  170. f"数据已刷新: {DEFAULT_SAVE_PATH} | rows={len(df)} | "
  171. f"range={pd.to_datetime(df['datetime']).min().date()} ~ {pd.to_datetime(df['datetime']).max().date()}"
  172. )