test_cli.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. from __future__ import annotations
  2. import io
  3. import tempfile
  4. import unittest
  5. from contextlib import redirect_stdout
  6. from pathlib import Path
  7. from unittest.mock import patch
  8. from src.data import bootstrap, status
  9. class CliTests(unittest.TestCase):
  10. def test_bootstrap_requires_all_flag(self) -> None:
  11. with self.assertRaises(SystemExit):
  12. bootstrap.main([])
  13. def test_status_renders_manifest(self) -> None:
  14. temp_dir = tempfile.TemporaryDirectory()
  15. self.addCleanup(temp_dir.cleanup)
  16. root = Path(temp_dir.name)
  17. (root / "src" / "data").mkdir(parents=True, exist_ok=True)
  18. manifest = {
  19. "provider": "fake",
  20. "generated_at": "2026-01-01T00:00:00+00:00",
  21. "common_sample": {"start_date": "2020-01-01", "end_date": "2020-01-31"},
  22. "instruments": {
  23. "sse50": {
  24. "name": "上证50",
  25. "actual_start": "2020-01-01",
  26. "layers": {
  27. "clean": {"end_date": "2020-01-31"},
  28. "features": {"end_date": "2020-01-31"},
  29. },
  30. }
  31. },
  32. }
  33. class PipelineStub:
  34. def status_snapshot(self):
  35. return manifest
  36. buf = io.StringIO()
  37. with patch("src.data.status.build_pipeline", return_value=PipelineStub()):
  38. with redirect_stdout(buf):
  39. exit_code = status.main([])
  40. output = buf.getvalue()
  41. self.assertEqual(exit_code, 0)
  42. self.assertIn("provider: fake", output)
  43. self.assertIn("[sse50] 上证50", output)
  44. if __name__ == "__main__":
  45. unittest.main()