test_ecosystem.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
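"""
Unit tests for the ecosystem identification engine.

Coverage:
- Four-season (macro regime) identification
- Health-level grading
- Toxic order-flow detection
- HMM state identification
- Ecosystem fusion
"""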
  10. import unittest
  11. from datetime import datetime, timedelta
  12. import numpy as np
  13. import pandas as pd
  14. from core.ecosystem import (
  15. MacroEcosystemIdentifier, MacroRegime,
  16. MesoEcosystemIdentifier, HealthLevel,
  17. MicroEcosystemIdentifier, MicroState, FlowToxicity,
  18. InstantEcosystemIdentifier, ImbalanceDirection, TickActivity,
  19. EcosystemFusion
  20. )
  21. class TestMacroEcosystem(unittest.TestCase):
  22. """测试宏观生态识别器"""
  23. def setUp(self):
  24. self.identifier = MacroEcosystemIdentifier()
  25. def generate_price_data(self, regime_type: str, days: int = 100) -> pd.DataFrame:
  26. """生成模拟价格数据"""
  27. np.random.seed(42)
  28. dates = pd.date_range(end=datetime.now(), periods=days, freq='D')
  29. if regime_type == "summer":
  30. # 夏季:强趋势,中等波动
  31. returns = np.random.normal(0.001, 0.015, days)
  32. trend = np.linspace(0, 0.1, days)
  33. returns += trend
  34. elif regime_type == "winter":
  35. # 冬季:低波动,低成交量
  36. returns = np.random.normal(-0.0005, 0.008, days)
  37. elif regime_type == "autumn":
  38. # 秋季:高波动,无趋势
  39. returns = np.random.normal(0, 0.025, days)
  40. elif regime_type == "spring":
  41. # 春季:波动率回升
  42. returns = np.random.normal(0.0005, 0.012, days)
  43. else:
  44. returns = np.random.normal(0, 0.015, days)
  45. prices = 100 * np.exp(np.cumsum(returns))
  46. volumes = np.random.lognormal(15, 0.5, days) * (
  47. 0.7 if regime_type == "winter" else 1.0
  48. )
  49. return pd.DataFrame({
  50. 'open': prices * (1 + np.random.normal(0, 0.001, days)),
  51. 'high': prices * (1 + abs(np.random.normal(0, 0.01, days))),
  52. 'low': prices * (1 - abs(np.random.normal(0, 0.01, days))),
  53. 'close': prices,
  54. 'volume': volumes
  55. }, index=dates)
  56. def test_summer_identification(self):
  57. """测试夏季识别"""
  58. data = self.generate_price_data("summer")
  59. result = self.identifier.identify(data)
  60. self.assertIsNotNone(result)
  61. self.assertIn(result.regime, [MacroRegime.SUMMER, MacroRegime.SPRING])
  62. self.assertGreater(result.confidence, 0.2)
  63. self.assertGreater(result.adx_value, 15)
  64. def test_winter_identification(self):
  65. """测试冬季识别"""
  66. data = self.generate_price_data("winter")
  67. result = self.identifier.identify(data)
  68. self.assertIsNotNone(result)
  69. self.assertIn(result.regime, [MacroRegime.WINTER, MacroRegime.UNKNOWN])
  70. def test_adx_calculation(self):
  71. """测试ADX计算"""
  72. data = self.generate_price_data("summer")
  73. adx = self.identifier._calculate_adx(data)
  74. self.assertIsInstance(adx, (float, np.floating))
  75. self.assertGreater(adx, 0)
  76. self.assertLessEqual(adx, 100)
  77. def test_volatility_calculation(self):
  78. """测试波动率计算"""
  79. data = self.generate_price_data("autumn")
  80. volatility = self.identifier._calculate_volatility(data)
  81. self.assertIsInstance(volatility, pd.Series)
  82. self.assertGreater(volatility.iloc[-1], 0)
  83. class TestMesoEcosystem(unittest.TestCase):
  84. """测试中观生态识别器"""
  85. def setUp(self):
  86. self.identifier = MesoEcosystemIdentifier()
  87. def generate_price_data(self, health: str, days: int = 60) -> pd.DataFrame:
  88. """生成不同健康度的价格数据"""
  89. np.random.seed(42)
  90. dates = pd.date_range(end=datetime.now(), periods=days, freq='D')
  91. if health == "high":
  92. # 高健康度:低冲击,流动性好
  93. returns = np.random.normal(0.0005, 0.012, days)
  94. volumes = np.random.lognormal(15, 0.3, days)
  95. elif health == "low":
  96. # 低健康度:高冲击,流动性差
  97. returns = np.random.normal(0, 0.03, days)
  98. volumes = np.random.lognormal(14, 0.8, days)
  99. else:
  100. returns = np.random.normal(0, 0.015, days)
  101. volumes = np.random.lognormal(15, 0.5, days)
  102. prices = 100 * np.exp(np.cumsum(returns))
  103. return pd.DataFrame({
  104. 'open': prices * (1 + np.random.normal(0, 0.001, days)),
  105. 'high': prices * (1 + abs(np.random.normal(0, 0.01, days))),
  106. 'low': prices * (1 - abs(np.random.normal(0, 0.01, days))),
  107. 'close': prices,
  108. 'volume': volumes
  109. }, index=dates)
  110. def test_high_health_identification(self):
  111. """测试高健康度识别"""
  112. data = self.generate_price_data("high")
  113. result = self.identifier.identify(data)
  114. self.assertIsNotNone(result)
  115. self.assertIn(result.health_level, [HealthLevel.HIGH, HealthLevel.MEDIUM])
  116. self.assertGreater(result.health_score, 30)
  117. def test_low_health_identification(self):
  118. """测试低健康度识别"""
  119. data = self.generate_price_data("low")
  120. result = self.identifier.identify(data)
  121. self.assertIsNotNone(result)
  122. self.assertLess(result.health_score, 70)
  123. def test_health_components(self):
  124. """测试健康度各维度计算"""
  125. data = self.generate_price_data("medium")
  126. result = self.identifier.identify(data)
  127. # 检查各维度存在
  128. self.assertIn("price_impact", result.components)
  129. self.assertIn("order_flow", result.components)
  130. self.assertIn("liquidity_depth", result.components)
  131. self.assertIn("volatility_efficiency", result.components)
  132. self.assertIn("info_response", result.components)
  133. # 检查维度范围
  134. for key, value in result.components.items():
  135. self.assertGreaterEqual(value, 0)
  136. self.assertLessEqual(value, 100)
  137. def test_price_impact_normalization(self):
  138. """测试价格冲击系数标准化"""
  139. score_low = self.identifier._normalize_price_impact(0.0001)
  140. score_high = self.identifier._normalize_price_impact(0.01)
  141. self.assertGreater(score_low, score_high)
  142. self.assertGreaterEqual(score_low, 90)
  143. self.assertLessEqual(score_high, 20)
  144. class TestMicroEcosystem(unittest.TestCase):
  145. """测试微观生态识别器"""
  146. def setUp(self):
  147. self.identifier = MicroEcosystemIdentifier()
  148. def generate_price_data(self, state: str, days: int = 100) -> pd.DataFrame:
  149. """生成不同状态的价格数据"""
  150. np.random.seed(42)
  151. dates = pd.date_range(end=datetime.now(), periods=days, freq='D')
  152. if state == "trending":
  153. # 趋势状态
  154. returns = np.random.normal(0.001, 0.012, days)
  155. returns += np.linspace(0, 0.15, days)
  156. elif state == "ranging":
  157. # 震荡状态
  158. returns = np.random.normal(0, 0.01, days)
  159. else:
  160. returns = np.random.normal(0, 0.02, days)
  161. prices = 100 * np.exp(np.cumsum(returns))
  162. return pd.DataFrame({
  163. 'open': prices * (1 + np.random.normal(0, 0.001, days)),
  164. 'high': prices * (1 + abs(np.random.normal(0, 0.01, days))),
  165. 'low': prices * (1 - abs(np.random.normal(0, 0.01, days))),
  166. 'close': prices,
  167. 'volume': np.random.lognormal(15, 0.5, days)
  168. }, index=dates)
  169. def test_hmm_fit_and_predict(self):
  170. """测试HMM训练和预测"""
  171. data = self.generate_price_data("trending")
  172. self.identifier.fit(data)
  173. self.assertTrue(self.identifier._is_fitted)
  174. self.assertIsNotNone(self.identifier.hmm_model)
  175. result = self.identifier.identify(data)
  176. self.assertIsNotNone(result.state)
  177. self.assertIn(result.state, [MicroState.TRENDING, MicroState.RANGING, MicroState.REVERSING])
  178. def test_flow_toxicity_detection(self):
  179. """测试有毒订单流检测"""
  180. data = self.generate_price_data("ranging")
  181. # 创建有毒的成交数据(高成交量但价格不动)
  182. trade_data = pd.DataFrame({
  183. 'price': [100] * 20,
  184. 'volume': [1e8] * 20, # 超大成交量
  185. 'side': ['buy'] * 10 + ['sell'] * 10
  186. })
  187. toxicity = self.identifier._detect_flow_toxicity(data, trade_data)
  188. self.assertIsInstance(toxicity, FlowToxicity)
  189. def test_smart_money_detection(self):
  190. """测试主力资金识别"""
  191. data = self.generate_price_data("ranging")
  192. # 创建大单成交数据
  193. trade_data = pd.DataFrame({
  194. 'price': [100] * 10,
  195. 'volume': [2e6] * 10, # 大单
  196. 'side': ['buy'] * 10
  197. })
  198. smart_money = self.identifier._detect_smart_money(data, trade_data)
  199. self.assertIsInstance(smart_money.detected, bool)
  200. self.assertIsInstance(smart_money.confidence, float)
  201. def test_state_probabilities(self):
  202. """测试状态概率"""
  203. data = self.generate_price_data("trending")
  204. self.identifier.fit(data)
  205. result = self.identifier.identify(data)
  206. self.assertIsInstance(result.state_probability, dict)
  207. self.assertGreater(sum(result.state_probability.values()), 0.99)
  208. class TestInstantEcosystem(unittest.TestCase):
  209. """测试瞬时生态识别器"""
  210. def setUp(self):
  211. self.identifier = InstantEcosystemIdentifier()
  212. def generate_tick_data(self, scenario: str, minutes: int = 10) -> pd.DataFrame:
  213. """生成tick数据"""
  214. np.random.seed(42)
  215. timestamps = pd.date_range(end=datetime.now(), periods=minutes, freq='min')
  216. if scenario == "bid_dominant":
  217. # 买盘占优
  218. sides = ['buy'] * 7 + ['sell'] * 3
  219. volumes = np.random.lognormal(10, 0.5, minutes)
  220. elif scenario == "ask_dominant":
  221. # 卖盘占优
  222. sides = ['buy'] * 3 + ['sell'] * 7
  223. volumes = np.random.lognormal(10, 0.5, minutes)
  224. elif scenario == "spike":
  225. # 跳动率突变
  226. sides = np.random.choice(['buy', 'sell'], minutes * 3)
  227. volumes = np.random.lognormal(10, 0.5, minutes * 3)
  228. timestamps = pd.date_range(end=datetime.now(), periods=minutes * 3, freq='20s')
  229. else:
  230. sides = np.random.choice(['buy', 'sell'], minutes)
  231. volumes = np.random.lognormal(10, 0.5, minutes)
  232. if scenario == "spike":
  233. return pd.DataFrame({
  234. 'price': 100 + np.random.normal(0, 0.1, minutes * 3),
  235. 'volume': volumes,
  236. 'side': sides
  237. }, index=timestamps)
  238. return pd.DataFrame({
  239. 'price': 100 + np.random.normal(0, 0.1, minutes),
  240. 'volume': volumes,
  241. 'side': sides
  242. }, index=timestamps)
  243. def test_imbalance_detection(self):
  244. """测试买卖盘不平衡检测"""
  245. tick_data = self.generate_tick_data("bid_dominant")
  246. result = self.identifier.identify(tick_data)
  247. self.assertIn(
  248. result.imbalance_direction,
  249. [ImbalanceDirection.BID_DOMINANT, ImbalanceDirection.BALANCED]
  250. )
  251. def test_block_flow_calculation(self):
  252. """测试大单流向计算"""
  253. tick_data = self.generate_tick_data("normal")
  254. block_flow = self.identifier._calculate_block_flow(tick_data)
  255. self.assertIsInstance(block_flow['net_flow'], float)
  256. self.assertIsInstance(block_flow['buy_count'], int)
  257. self.assertIsInstance(block_flow['sell_count'], int)
  258. def test_tick_activity_detection(self):
  259. """测试跳动率检测"""
  260. tick_data = self.generate_tick_data("spike")
  261. result = self.identifier.identify(tick_data)
  262. self.assertIn(result.tick_activity, [TickActivity.SPIKE, TickActivity.ELEVATED, TickActivity.NORMAL])
  263. def test_trading_opportunity(self):
  264. """测试交易机会判断"""
  265. tick_data = self.generate_tick_data("bid_dominant")
  266. result = self.identifier.identify(tick_data)
  267. self.assertIsInstance(result.is_trading_opportunity(), bool)
  268. class TestEcosystemFusion(unittest.TestCase):
  269. """测试生态融合器"""
  270. def setUp(self):
  271. self.fusion = EcosystemFusion()
  272. def generate_complete_data(self) -> tuple:
  273. """生成完整的测试数据"""
  274. np.random.seed(42)
  275. days = 100
  276. dates = pd.date_range(end=datetime.now(), periods=days, freq='D')
  277. returns = np.random.normal(0.001, 0.015, days)
  278. prices = 100 * np.exp(np.cumsum(returns))
  279. price_data = pd.DataFrame({
  280. 'open': prices * (1 + np.random.normal(0, 0.001, days)),
  281. 'high': prices * (1 + abs(np.random.normal(0, 0.01, days))),
  282. 'low': prices * (1 - abs(np.random.normal(0, 0.01, days))),
  283. 'close': prices,
  284. 'volume': np.random.lognormal(15, 0.5, days)
  285. }, index=dates)
  286. tick_data = pd.DataFrame({
  287. 'price': 100 + np.random.normal(0, 0.1, 20),
  288. 'volume': np.random.lognormal(10, 0.5, 20),
  289. 'side': np.random.choice(['buy', 'sell'], 20)
  290. })
  291. return price_data, tick_data
  292. def test_fusion_output(self):
  293. """测试融合输出"""
  294. price_data, tick_data = self.generate_complete_data()
  295. result = self.fusion.fuse(
  296. price_data=price_data,
  297. tick_data=tick_data
  298. )
  299. self.assertIsNotNone(result)
  300. self.assertIsNotNone(result.macro)
  301. self.assertIsNotNone(result.meso)
  302. self.assertIsNotNone(result.micro)
  303. self.assertIsNotNone(result.instant)
  304. def test_confidence_calculation(self):
  305. """测试置信度计算"""
  306. price_data, _ = self.generate_complete_data()
  307. result = self.fusion.fuse(price_data=price_data)
  308. self.assertGreaterEqual(result.confidence, 0)
  309. self.assertLessEqual(result.confidence, 1)
  310. def test_position_suggestion(self):
  311. """测试仓位建议"""
  312. price_data, _ = self.generate_complete_data()
  313. result = self.fusion.fuse(price_data=price_data)
  314. self.assertGreaterEqual(result.suggested_position, 0)
  315. self.assertLessEqual(result.suggested_position, 1)
  316. def test_agent_recommendation(self):
  317. """测试智能体推荐"""
  318. price_data, _ = self.generate_complete_data()
  319. result = self.fusion.fuse(price_data=price_data)
  320. self.assertIsInstance(result.suggested_agents, list)
  321. # 至少推荐一个智能体或为空(如果生态不明)
  322. if result.macro.regime.value != "unknown":
  323. self.assertGreater(len(result.suggested_agents), 0)
  324. def test_warning_generation(self):
  325. """测试警告生成"""
  326. price_data, _ = self.generate_complete_data()
  327. result = self.fusion.fuse(price_data=price_data)
  328. self.assertIsInstance(result.warnings, list)
  329. def test_to_dict(self):
  330. """测试字典转换"""
  331. price_data, _ = self.generate_complete_data()
  332. result = self.fusion.fuse(price_data=price_data)
  333. dict_result = result.to_dict()
  334. self.assertIsInstance(dict_result, dict)
  335. self.assertIn("timestamp", dict_result)
  336. self.assertIn("overall_regime", dict_result)
  337. self.assertIn("confidence", dict_result)
  338. self.assertIn("trading_bias", dict_result)
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()