test_pipeline.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. from __future__ import annotations
  2. import json
  3. import tempfile
  4. import unittest
  5. from datetime import date
  6. from pathlib import Path
  7. import pandas as pd
  8. from src.data.metadata import MetadataStore
  9. from src.data.pipeline import DataPipeline
  10. from src.data.providers.base import IndexPriceProvider
  11. from src.data.storage import InMemoryDataLake
  12. from src.data.transform import build_clean_frame, build_features_frame
  class FakeProvider(IndexPriceProvider):
      """In-memory provider stub that replays pre-seeded frames.

      ``frames`` maps an instrument key to a list of DataFrames. Each fetch for
      that key hands out the next frame in order; every fetch is recorded in
      ``calls`` so tests can assert how often the provider was hit.
      """

      name = "fake_provider"

      def __init__(self, frames: dict[str, list[pd.DataFrame]]) -> None:
          # Keep a reference to the caller's mapping (not a copy): fetches
          # consume frames from these very lists.
          self.frames = frames
          # One (instrument key, start_date, end_date) tuple per fetch.
          self.calls: list[tuple[str, date, date]] = []

      def fetch_price_history(self, instrument, start_date, end_date) -> pd.DataFrame:
          """Return the next queued frame for ``instrument.key``.

          The call is always logged in ``self.calls``. When nothing is queued
          for the instrument (or the key was never seeded), an empty frame with
          the standard price columns is returned instead.
          """
          key = instrument.key
          self.calls.append((key, start_date, end_date))
          # setdefault registers unseen keys with an empty list as a side effect.
          pending = self.frames.setdefault(key, [])
          if pending:
              # Hand out a copy so callers cannot mutate the seeded frame.
              return pending.pop(0).copy()
          empty_columns = ["trade_date", "open", "close", "high", "low", "volume", "amount"]
          return pd.DataFrame(columns=empty_columns)
  def make_price_frame(start: str, closes: list[float]) -> pd.DataFrame:
      """Build a synthetic daily price frame, one row per entry in ``closes``.

      Dates run consecutively (calendar-daily) from ``start``. ``open`` equals
      ``close``; ``high``/``low`` sit one unit above/below the close; ``volume``
      and ``amount`` count up from 1000 and 100000 respectively.
      """
      count = len(closes)
      columns = {
          "trade_date": pd.date_range(start=start, periods=count, freq="D"),
          "open": closes,
          "close": closes,
          "high": [price + 1 for price in closes],
          "low": [price - 1 for price in closes],
          "volume": list(range(1000, 1000 + count)),
          "amount": list(range(100000, 100000 + count)),
      }
      return pd.DataFrame(columns)
  class PipelineTests(unittest.TestCase):
      """Tests for the data pipeline and the clean/features transforms.

      Every test runs against a throwaway repository root that contains a
      minimal ``configs/instruments.yaml`` declaring the single ``sse50``
      instrument, plus an empty ``data/meta`` directory.
      """

      def setUp(self) -> None:
          """Create the temporary repo root and write the instrument config."""
          self.temp_dir = tempfile.TemporaryDirectory()
          # Registered immediately so the directory is removed even if the
          # remainder of setUp raises.
          self.addCleanup(self.temp_dir.cleanup)
          self.root = Path(self.temp_dir.name)
          (self.root / "configs").mkdir(parents=True, exist_ok=True)
          (self.root / "data" / "meta").mkdir(parents=True, exist_ok=True)
          self.config_path = self.root / "configs" / "instruments.yaml"
          # YAML expresses nesting through indentation: the instrument key sits
          # one level under ``instruments`` and its attributes one level deeper.
          # With every line at the same indent, ``sse50`` would parse as null
          # and its attributes as sibling entries of ``instruments``.
          self.config_path.write_text(
              "\n".join(
                  [
                      "instruments:",
                      "  sse50:",
                      "    name: 上证50",
                      "    index_code: \"000016\"",
                      "    provider_symbol: sh000016",
                      "    exchange: SSE",
                      "    price_type: price_index",
                      "    bootstrap_start: \"2003-12-31\"",
                  ]
              ),
              encoding="utf-8",
          )

      def _build_pipeline(
          self, provider: IndexPriceProvider
      ) -> tuple[DataPipeline, InMemoryDataLake, MetadataStore]:
          """Wire a ``DataPipeline`` to fresh storage under the temp root.

          Returns the pipeline together with the data lake and metadata store
          it was given, so tests can inspect or tamper with stored layers.
          """
          datalake = InMemoryDataLake(self.root / "memory")
          metadata = MetadataStore(self.root / "data" / "meta")
          pipeline = DataPipeline(
              repo_root=self.root,
              config_path=self.config_path,
              data_root=self.root / "data",
              provider=provider,
              datalake=datalake,
              metadata_store=metadata,
          )
          return pipeline, datalake, metadata

      def test_features_do_not_change_when_future_rows_are_appended(self) -> None:
          """Features of the first 25 rows are unchanged by a 26th, later row."""
          clean_a = pd.DataFrame(
              {
                  "instrument": ["sse50"] * 25,
                  "instrument_name": ["上证50"] * 25,
                  "index_code": ["000016"] * 25,
                  "provider": ["fake"] * 25,
                  "price_type": ["price_index"] * 25,
                  "trade_date": pd.date_range("2020-01-01", periods=25, freq="D"),
                  "open": range(1, 26),
                  "high": range(2, 27),
                  "low": range(0, 25),
                  "close": range(1, 26),
                  "prev_close": [None] + list(range(1, 25)),
                  "change_amount": [None] + [1] * 24,
                  "daily_return": [None] + [1.0 / value for value in range(1, 25)],
                  "volume": [100] * 25,
                  "amount": [1000] * 25,
              }
          )
          # Clone the last row and move it one day forward to act as the
          # "future" observation.
          future_row = clean_a.tail(1).assign(
              trade_date=pd.Timestamp("2020-01-26"),
              close=26,
              open=26,
              high=27,
              low=25,
              prev_close=25,
              change_amount=1,
              daily_return=0.04,
          )
          clean_b = pd.concat([clean_a, future_row], ignore_index=True)

          features_a = build_features_frame(clean_a)
          features_b = build_features_frame(clean_b)

          pd.testing.assert_frame_equal(
              features_a.iloc[:25].reset_index(drop=True),
              features_b.iloc[:25].reset_index(drop=True),
          )

      def test_bootstrap_then_incremental_update_merges_raw_and_updates_manifest(self) -> None:
          """Bootstrap followed by an update yields merged raw data and metadata."""
          provider = FakeProvider(
              {
                  "sse50": [
                      make_price_frame("2020-01-01", [10, 11, 12]),
                      make_price_frame("2020-01-04", [13, 14]),
                  ]
              }
          )
          pipeline, datalake, metadata = self._build_pipeline(provider)

          bootstrap = pipeline.bootstrap_all(today=date(2020, 1, 3))
          self.assertEqual(bootstrap["sse50"]["raw"]["rows"], 3)

          update = pipeline.update_since_last(today=date(2020, 1, 5))
          self.assertEqual(update["sse50"]["raw"]["rows"], 5)

          raw_frame = datalake.read_layer("raw", "sse50")
          self.assertEqual(len(raw_frame.index), 5)

          manifest = metadata.load_manifest()
          self.assertEqual(manifest["instruments"]["sse50"]["actual_start"], "2020-01-01")
          self.assertEqual(
              manifest["instruments"]["sse50"]["layers"]["features"]["end_date"],
              "2020-01-05",
          )

          fetch_log_path = self.root / "data" / "meta" / "fetch_log.jsonl"
          fetch_log_lines = fetch_log_path.read_text(encoding="utf-8").strip().splitlines()
          # NOTE(review): presumably one entry per layer per operation
          # (bootstrap + update) -- confirm against the pipeline's logging.
          self.assertEqual(len(fetch_log_lines), 6)
          latest_payload = json.loads(fetch_log_lines[-1])
          self.assertEqual(latest_payload["layer"], "features")
          self.assertEqual(latest_payload["operation"], "update")

      def test_repair_features_uses_local_clean_only(self) -> None:
          """Repairing ``features`` rebuilds it without another provider fetch."""
          provider = FakeProvider({"sse50": [make_price_frame("2020-01-01", [10, 11, 12, 13, 14])]})
          pipeline, datalake, _ = self._build_pipeline(provider)

          pipeline.bootstrap_all(today=date(2020, 1, 5))
          # Overwrite the stored features layer with junk to force a repair.
          datalake.write_layer("features", "sse50", pd.DataFrame({"broken": [1]}))

          repaired = pipeline.repair("sse50", "features")
          repaired_frame = datalake.read_layer("features", "sse50")

          self.assertIn("ret_1d", repaired_frame.columns)
          self.assertEqual(repaired["features"]["rows"], 5)
          # Still just the single fetch recorded during bootstrap.
          self.assertEqual(len(provider.calls), 1)

      def test_repair_clean_rebuilds_features_as_downstream_dependency(self) -> None:
          """Repairing ``clean`` also regenerates the ``features`` layer."""
          provider = FakeProvider({"sse50": [make_price_frame("2020-01-01", [10, 11, 12, 13, 14])]})
          pipeline, datalake, _ = self._build_pipeline(provider)

          pipeline.bootstrap_all(today=date(2020, 1, 5))
          # Overwrite the stored features layer with junk; only ``clean`` is
          # repaired explicitly below.
          datalake.write_layer("features", "sse50", pd.DataFrame({"broken": [1]}))

          repaired = pipeline.repair("sse50", "clean")
          repaired_frame = datalake.read_layer("features", "sse50")

          self.assertIn("ma_5", repaired_frame.columns)
          self.assertEqual(repaired["clean"]["rows"], 5)
          self.assertEqual(repaired["features"]["rows"], 5)

      def test_clean_layer_computes_prev_close_and_daily_return(self) -> None:
          """``build_clean_frame`` derives ``prev_close`` and ``daily_return``."""
          raw_frame = make_price_frame("2020-01-01", [10, 12, 18])
          raw_frame["instrument"] = "sse50"
          raw_frame["instrument_name"] = "上证50"
          raw_frame["index_code"] = "000016"
          raw_frame["provider"] = "fake"
          raw_frame = raw_frame[
              [
                  "instrument",
                  "instrument_name",
                  "index_code",
                  "provider",
                  "trade_date",
                  "open",
                  "high",
                  "low",
                  "close",
                  "volume",
                  "amount",
              ]
          ]

          class InstrumentStub:
              # Minimal stand-in: the only attribute this test supplies.
              price_type = "price_index"

          clean = build_clean_frame(raw_frame, InstrumentStub())

          # The first row has no predecessor, so no previous close.
          self.assertTrue(pd.isna(clean.loc[0, "prev_close"]))
          # 10 -> 12 is +20%, 12 -> 18 is +50%.
          self.assertAlmostEqual(clean.loc[1, "daily_return"], 0.2)
          self.assertAlmostEqual(clean.loc[2, "daily_return"], 0.5)
  # Allow running this module directly as a script; when it is merely
  # imported, the guard is false and nothing is executed here.
  if __name__ == "__main__":
      unittest.main()