mairui_fetcher.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. #!/usr/bin/env python3
  2. """
  3. 创业板50指数数据获取模块 (纯标准库版本)
  4. 使用 mairui API 获取 K线数据,无需安装 pandas
  5. """
  6. import urllib.request
  7. import urllib.error
  8. import json
  9. import os
  10. import csv
  11. from datetime import datetime
  12. from typing import Optional, List, Dict
  13. class MairuiDataFetcher:
  14. """mairui 数据获取类 (纯标准库)"""
  15. # API 配置
  16. BASE_URL = "https://api.mairuiapi.com/hsindex/history"
  17. TOKEN = "AE17EE23-AAE4-492F-A959-EC883DFA5A76"
  18. # 指数代码映射
  19. INDEX_CODES = {
  20. "cyb50": "399673.SZ",
  21. "cy": "399006.SZ",
  22. "sh": "000001.SH",
  23. "hs300": "000300.SH",
  24. "sz": "399001.SZ",
  25. }
  26. def __init__(self, data_dir: str = "./data"):
  27. self.data_dir = data_dir
  28. os.makedirs(data_dir, exist_ok=True)
  29. def fetch_data(
  30. self,
  31. index_code: str = "399673.SZ",
  32. timeframe: str = "30",
  33. start_date: Optional[str] = None,
  34. end_date: Optional[str] = None
  35. ) -> List[Dict]:
  36. """
  37. 获取K线数据
  38. Args:
  39. index_code: 指数代码
  40. timeframe: K线周期 (d=日线, 30=30分钟, 60=60分钟)
  41. start_date: 开始日期 (YYYY-MM-DD)
  42. end_date: 结束日期 (YYYY-MM-DD)
  43. Returns:
  44. 数据列表,每个元素是一个字典
  45. """
  46. # 日期格式转换
  47. if start_date:
  48. st = start_date.replace("-", "")
  49. else:
  50. st = "20230101"
  51. if end_date:
  52. et = end_date.replace("-", "")
  53. else:
  54. et = datetime.now().strftime("%Y%m%d")
  55. print(f"正在获取 {index_code} 的{timeframe}分钟K线数据...")
  56. print(f"时间范围: {start_date or st} 至 {end_date or et}")
  57. # 构建API URL
  58. url = f"{self.BASE_URL}/{index_code}/{timeframe}/{self.TOKEN}?st={st}&et={et}"
  59. print(f"API URL: {url}")
  60. try:
  61. # 使用标准库发送请求
  62. req = urllib.request.Request(url, headers={
  63. 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.0'
  64. })
  65. with urllib.request.urlopen(req, timeout=60) as response:
  66. data = json.loads(response.read().decode('utf-8'))
  67. # 解析数据
  68. if isinstance(data, list):
  69. records = data
  70. elif isinstance(data, dict):
  71. records = data.get("data", data.get("list", []))
  72. else:
  73. records = []
  74. if not records:
  75. print("✗ API返回空数据")
  76. return []
  77. print(f"✓ 获取到 {len(records)} 条数据")
  78. # 标准化数据
  79. standardized_records = []
  80. for record in records:
  81. std_record = {
  82. "datetime": record.get("d") or record.get("t"),
  83. "open": record.get("o"),
  84. "high": record.get("h"),
  85. "low": record.get("l"),
  86. "close": record.get("c"),
  87. "volume": record.get("v"),
  88. "amount": record.get("a"),
  89. }
  90. standardized_records.append(std_record)
  91. # 按时间排序
  92. standardized_records.sort(key=lambda x: x["datetime"])
  93. if standardized_records:
  94. print(f" 时间范围: {standardized_records[0]['datetime']} 至 {standardized_records[-1]['datetime']}")
  95. return standardized_records
  96. except urllib.error.URLError as e:
  97. print(f"✗ 网络请求失败: {e}")
  98. return []
  99. except json.JSONDecodeError as e:
  100. print(f"✗ JSON解析失败: {e}")
  101. return []
  102. except Exception as e:
  103. print(f"✗ 处理失败: {e}")
  104. import traceback
  105. traceback.print_exc()
  106. return []
  107. def save_to_csv(self, records: List[Dict], filename: str) -> str:
  108. """保存数据到CSV文件"""
  109. if not records:
  110. print("✗ 没有数据可保存")
  111. return ""
  112. filepath = os.path.join(self.data_dir, filename)
  113. # 获取所有字段名
  114. fieldnames = ["datetime", "open", "high", "low", "close", "volume", "amount"]
  115. with open(filepath, 'w', newline='', encoding='utf-8') as f:
  116. writer = csv.DictWriter(f, fieldnames=fieldnames)
  117. writer.writeheader()
  118. writer.writerows(records)
  119. print(f"✓ 数据已保存到: {filepath}")
  120. return filepath
  121. def load_from_csv(self, filename: str) -> List[Dict]:
  122. """从CSV文件加载数据"""
  123. filepath = os.path.join(self.data_dir, filename)
  124. if not os.path.exists(filepath):
  125. print(f"✗ 文件不存在: {filepath}")
  126. return []
  127. records = []
  128. with open(filepath, 'r', encoding='utf-8') as f:
  129. reader = csv.DictReader(f)
  130. for row in reader:
  131. records.append(dict(row))
  132. print(f"✓ 已从 {filepath} 加载 {len(records)} 条记录")
  133. return records
  134. def print_preview(self, records: List[Dict], head: int = 10, tail: int = 5):
  135. """打印数据预览"""
  136. if not records:
  137. print("没有数据")
  138. return
  139. print(f"\n数据预览 (前{head}条):")
  140. print("-" * 80)
  141. print(f"{'datetime':<20} {'open':<10} {'high':<10} {'low':<10} {'close':<10} {'volume':<12}")
  142. print("-" * 80)
  143. for r in records[:head]:
  144. print(f"{r.get('datetime', ''):<20} {r.get('open', ''):<10} {r.get('high', ''):<10} "
  145. f"{r.get('low', ''):<10} {r.get('close', ''):<10} {r.get('volume', ''):<12}")
  146. if len(records) > head + tail:
  147. print(f"\n... ({len(records) - head - tail} 条数据省略) ...\n")
  148. print(f"数据预览 (后{tail}条):")
  149. print("-" * 80)
  150. for r in records[-tail:]:
  151. print(f"{r.get('datetime', ''):<20} {r.get('open', ''):<10} {r.get('high', ''):<10} "
  152. f"{r.get('low', ''):<10} {r.get('close', ''):<10} {r.get('volume', ''):<12}")
  153. print("-" * 80)
  154. def fetch_cyb50(
  155. timeframe: str = "d",
  156. start_date: str = "2023-01-01",
  157. end_date: Optional[str] = None,
  158. save: bool = True
  159. ) -> List[Dict]:
  160. """获取创业板50指数数据(便捷函数)"""
  161. fetcher = MairuiDataFetcher(data_dir="./data")
  162. records = fetcher.fetch_data(
  163. index_code="399673.SZ",
  164. timeframe=timeframe,
  165. start_date=start_date,
  166. end_date=end_date
  167. )
  168. if save and records:
  169. end_str = end_date.replace("-", "") if end_date else datetime.now().strftime("%Y%m%d")
  170. tf_name = "day" if timeframe == "d" else f"{timeframe}min"
  171. filename = f"cyb50_{tf_name}_{start_date.replace('-', '')}_{end_str}.csv"
  172. fetcher.save_to_csv(records, filename)
  173. return records
  174. if __name__ == "__main__":
  175. import argparse
  176. parser = argparse.ArgumentParser(description="获取指数K线数据")
  177. parser.add_argument("--code", default="399673.SZ", help="指数代码")
  178. parser.add_argument("--tf", default="d", help="K线周期: d=日线, 30=30分钟")
  179. parser.add_argument("--start", default="2015-01-01", help="开始日期 (YYYY-MM-DD)")
  180. parser.add_argument("--end", default=None, help="结束日期 (YYYY-MM-DD)")
  181. parser.add_argument("--no-save", action="store_true", help="不保存到文件")
  182. args = parser.parse_args()
  183. print("=" * 80)
  184. print("mairui 指数数据获取工具 (纯标准库版本)")
  185. print("=" * 80)
  186. fetcher = MairuiDataFetcher(data_dir="./data")
  187. records = fetcher.fetch_data(
  188. index_code=args.code,
  189. timeframe=args.tf,
  190. start_date=args.start,
  191. end_date=args.end
  192. )
  193. if records:
  194. fetcher.print_preview(records)
  195. print(f"\n数据统计:")
  196. print(f" 总记录数: {len(records)}")
  197. if not args.no_save:
  198. end_str = args.end.replace("-", "") if args.end else datetime.now().strftime("%Y%m%d")
  199. tf_name = "day" if args.tf == "d" else f"{args.tf}min"
  200. filename = f"{args.code.replace('.', '_')}_{tf_name}_{args.start.replace('-', '')}_{end_str}.csv"
  201. fetcher.save_to_csv(records, filename)
  202. else:
  203. print("\n✗ 数据获取失败")
  204. exit(1)
  205. print("\n" + "=" * 80)
  206. print("完成!")
  207. print("=" * 80)