| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436 |
- """
- 生态识别引擎单元测试
- 测试覆盖:
- - 四季识别
- - 健康度分级
- - 有毒订单流检测
- - HMM状态识别
- - 生态融合
- """
- import unittest
- from datetime import datetime, timedelta
- import numpy as np
- import pandas as pd
- from core.ecosystem import (
- MacroEcosystemIdentifier, MacroRegime,
- MesoEcosystemIdentifier, HealthLevel,
- MicroEcosystemIdentifier, MicroState, FlowToxicity,
- InstantEcosystemIdentifier, ImbalanceDirection, TickActivity,
- EcosystemFusion
- )
class TestMacroEcosystem(unittest.TestCase):
    """Tests for the macro ("four seasons") ecosystem identifier."""

    def setUp(self):
        self.identifier = MacroEcosystemIdentifier()

    def generate_price_data(self, regime_type: str, days: int = 100) -> pd.DataFrame:
        """Generate a synthetic daily OHLCV frame for one macro regime.

        Args:
            regime_type: "summer", "winter", "autumn" or "spring"; any other
                value yields a neutral random walk.
            days: number of daily bars to generate.

        Returns:
            DataFrame on a daily DatetimeIndex ending now, with columns
            ``open``, ``high``, ``low``, ``close`` and ``volume``. Every bar
            satisfies ``low <= min(open, close) <= max(open, close) <= high``.
        """
        # Re-seed the global RNG so every test sees the same series.
        np.random.seed(42)
        dates = pd.date_range(end=datetime.now(), periods=days, freq='D')

        if regime_type == "summer":
            # Summer: strong trend, moderate volatility. A linearly growing
            # drift is added on top of the daily noise.
            returns = np.random.normal(0.001, 0.015, days)
            returns += np.linspace(0, 0.1, days)
        elif regime_type == "winter":
            # Winter: low volatility (volume is also scaled down below).
            returns = np.random.normal(-0.0005, 0.008, days)
        elif regime_type == "autumn":
            # Autumn: high volatility, no trend.
            returns = np.random.normal(0, 0.025, days)
        elif regime_type == "spring":
            # Spring: volatility picking back up.
            returns = np.random.normal(0.0005, 0.012, days)
        else:
            returns = np.random.normal(0, 0.015, days)

        closes = 100 * np.exp(np.cumsum(returns))
        volumes = np.random.lognormal(15, 0.5, days) * (
            0.7 if regime_type == "winter" else 1.0
        )

        # The RNG draw order (open, high, low) is unchanged from the original
        # fixture. High/low are anchored on the open/close envelope rather
        # than on close alone; otherwise the open noise could poke outside
        # the bar's range and produce invalid OHLC bars.
        opens = closes * (1 + np.random.normal(0, 0.001, days))
        highs = np.maximum(opens, closes) * (1 + np.abs(np.random.normal(0, 0.01, days)))
        lows = np.minimum(opens, closes) * (1 - np.abs(np.random.normal(0, 0.01, days)))

        return pd.DataFrame({
            'open': opens,
            'high': highs,
            'low': lows,
            'close': closes,
            'volume': volumes,
        }, index=dates)

    def test_summer_identification(self):
        """A trending series is classified as summer (or spring)."""
        data = self.generate_price_data("summer")
        result = self.identifier.identify(data)
        self.assertIsNotNone(result)
        self.assertIn(result.regime, [MacroRegime.SUMMER, MacroRegime.SPRING])
        self.assertGreater(result.confidence, 0.2)
        self.assertGreater(result.adx_value, 15)

    def test_winter_identification(self):
        """A quiet, low-volume series is classified as winter (or unknown)."""
        data = self.generate_price_data("winter")
        result = self.identifier.identify(data)
        self.assertIsNotNone(result)
        self.assertIn(result.regime, [MacroRegime.WINTER, MacroRegime.UNKNOWN])

    def test_adx_calculation(self):
        """ADX is a float in the (0, 100] range."""
        data = self.generate_price_data("summer")
        adx = self.identifier._calculate_adx(data)
        self.assertIsInstance(adx, (float, np.floating))
        self.assertGreater(adx, 0)
        self.assertLessEqual(adx, 100)

    def test_volatility_calculation(self):
        """Volatility is returned as a Series with a positive latest value."""
        data = self.generate_price_data("autumn")
        volatility = self.identifier._calculate_volatility(data)
        self.assertIsInstance(volatility, pd.Series)
        self.assertGreater(volatility.iloc[-1], 0)
class TestMesoEcosystem(unittest.TestCase):
    """Tests for the meso-level (market health) ecosystem identifier."""

    def setUp(self):
        self.identifier = MesoEcosystemIdentifier()

    def generate_price_data(self, health: str, days: int = 60) -> pd.DataFrame:
        """Generate a synthetic daily OHLCV frame for one health profile.

        Args:
            health: "high" or "low"; any other value yields a medium profile.
            days: number of daily bars to generate.

        Returns:
            DataFrame on a daily DatetimeIndex ending now, with columns
            ``open``, ``high``, ``low``, ``close`` and ``volume``. Every bar
            satisfies ``low <= min(open, close) <= max(open, close) <= high``.
        """
        # Re-seed the global RNG so every test sees the same series.
        np.random.seed(42)
        dates = pd.date_range(end=datetime.now(), periods=days, freq='D')

        if health == "high":
            # High health: low price impact, good liquidity.
            returns = np.random.normal(0.0005, 0.012, days)
            volumes = np.random.lognormal(15, 0.3, days)
        elif health == "low":
            # Low health: high price impact, poor liquidity.
            returns = np.random.normal(0, 0.03, days)
            volumes = np.random.lognormal(14, 0.8, days)
        else:
            returns = np.random.normal(0, 0.015, days)
            volumes = np.random.lognormal(15, 0.5, days)

        closes = 100 * np.exp(np.cumsum(returns))

        # The RNG draw order (open, high, low) is unchanged. High/low are
        # anchored on the open/close envelope so the open noise can never
        # fall outside the bar's range.
        opens = closes * (1 + np.random.normal(0, 0.001, days))
        highs = np.maximum(opens, closes) * (1 + np.abs(np.random.normal(0, 0.01, days)))
        lows = np.minimum(opens, closes) * (1 - np.abs(np.random.normal(0, 0.01, days)))

        return pd.DataFrame({
            'open': opens,
            'high': highs,
            'low': lows,
            'close': closes,
            'volume': volumes,
        }, index=dates)

    def test_high_health_identification(self):
        """A liquid, low-impact series scores as high (or medium) health."""
        data = self.generate_price_data("high")
        result = self.identifier.identify(data)
        self.assertIsNotNone(result)
        self.assertIn(result.health_level, [HealthLevel.HIGH, HealthLevel.MEDIUM])
        self.assertGreater(result.health_score, 30)

    def test_low_health_identification(self):
        """A volatile, illiquid series scores below 70."""
        data = self.generate_price_data("low")
        result = self.identifier.identify(data)
        self.assertIsNotNone(result)
        self.assertLess(result.health_score, 70)

    def test_health_components(self):
        """Every health dimension is present and scored within [0, 100]."""
        data = self.generate_price_data("medium")
        result = self.identifier.identify(data)

        expected_components = (
            "price_impact",
            "order_flow",
            "liquidity_depth",
            "volatility_efficiency",
            "info_response",
        )
        for name in expected_components:
            self.assertIn(name, result.components)

        # subTest reports every out-of-range component by name instead of
        # stopping at the first failure.
        for name, value in result.components.items():
            with self.subTest(component=name):
                self.assertGreaterEqual(value, 0)
                self.assertLessEqual(value, 100)

    def test_price_impact_normalization(self):
        """A smaller raw price impact normalizes to a higher score."""
        score_low = self.identifier._normalize_price_impact(0.0001)
        score_high = self.identifier._normalize_price_impact(0.01)
        self.assertGreater(score_low, score_high)
        self.assertGreaterEqual(score_low, 90)
        self.assertLessEqual(score_high, 20)
class TestMicroEcosystem(unittest.TestCase):
    """Tests for the micro-level (HMM state / order flow) ecosystem identifier."""

    def setUp(self):
        self.identifier = MicroEcosystemIdentifier()

    def generate_price_data(self, state: str, days: int = 100) -> pd.DataFrame:
        """Generate a synthetic daily OHLCV frame for one micro state.

        Args:
            state: "trending" or "ranging"; any other value yields a noisier
                random walk.
            days: number of daily bars to generate.

        Returns:
            DataFrame on a daily DatetimeIndex ending now, with columns
            ``open``, ``high``, ``low``, ``close`` and ``volume``. Every bar
            satisfies ``low <= min(open, close) <= max(open, close) <= high``.
        """
        # Re-seed the global RNG so every test sees the same series.
        np.random.seed(42)
        dates = pd.date_range(end=datetime.now(), periods=days, freq='D')

        if state == "trending":
            # Trending: daily noise plus a linearly growing drift.
            returns = np.random.normal(0.001, 0.012, days)
            returns += np.linspace(0, 0.15, days)
        elif state == "ranging":
            # Ranging: zero-mean, low-volatility noise.
            returns = np.random.normal(0, 0.01, days)
        else:
            returns = np.random.normal(0, 0.02, days)

        closes = 100 * np.exp(np.cumsum(returns))

        # The RNG draw order (open, high, low, volume) is unchanged. High/low
        # are anchored on the open/close envelope so bars stay valid OHLC.
        opens = closes * (1 + np.random.normal(0, 0.001, days))
        highs = np.maximum(opens, closes) * (1 + np.abs(np.random.normal(0, 0.01, days)))
        lows = np.minimum(opens, closes) * (1 - np.abs(np.random.normal(0, 0.01, days)))
        volumes = np.random.lognormal(15, 0.5, days)

        return pd.DataFrame({
            'open': opens,
            'high': highs,
            'low': lows,
            'close': closes,
            'volume': volumes,
        }, index=dates)

    def test_hmm_fit_and_predict(self):
        """Fitting marks the model as trained and identify() yields a known state."""
        data = self.generate_price_data("trending")
        self.identifier.fit(data)
        self.assertTrue(self.identifier._is_fitted)
        self.assertIsNotNone(self.identifier.hmm_model)

        result = self.identifier.identify(data)
        self.assertIsNotNone(result.state)
        self.assertIn(
            result.state,
            [MicroState.TRENDING, MicroState.RANGING, MicroState.REVERSING],
        )

    def test_flow_toxicity_detection(self):
        """Toxicity detection returns a FlowToxicity for heavy, price-flat trades."""
        data = self.generate_price_data("ranging")
        # Toxic-looking trades: very large volume while the price does not move.
        trade_data = pd.DataFrame({
            'price': [100] * 20,
            'volume': [1e8] * 20,
            'side': ['buy'] * 10 + ['sell'] * 10,
        })
        toxicity = self.identifier._detect_flow_toxicity(data, trade_data)
        self.assertIsInstance(toxicity, FlowToxicity)

    def test_smart_money_detection(self):
        """Smart-money detection reports a bool flag and a float confidence."""
        data = self.generate_price_data("ranging")
        # One-sided block trades.
        trade_data = pd.DataFrame({
            'price': [100] * 10,
            'volume': [2e6] * 10,
            'side': ['buy'] * 10,
        })
        smart_money = self.identifier._detect_smart_money(data, trade_data)
        self.assertIsInstance(smart_money.detected, bool)
        self.assertIsInstance(smart_money.confidence, float)

    def test_state_probabilities(self):
        """State probabilities form a distribution that sums to 1."""
        data = self.generate_price_data("trending")
        self.identifier.fit(data)
        result = self.identifier.identify(data)
        self.assertIsInstance(result.state_probability, dict)
        # The old `> 0.99` check also accepted non-normalized outputs
        # (e.g. a total of 2.0); pin the sum to 1 within tolerance.
        self.assertAlmostEqual(sum(result.state_probability.values()), 1.0, delta=0.01)
class TestInstantEcosystem(unittest.TestCase):
    """Tests for the instant (tick-level) ecosystem identifier."""

    def setUp(self):
        self.identifier = InstantEcosystemIdentifier()

    def generate_tick_data(self, scenario: str, minutes: int = 10) -> pd.DataFrame:
        """Generate synthetic tick data for one order-flow scenario.

        Args:
            scenario: "bid_dominant" (70% buys), "ask_dominant" (30% buys),
                "spike" (three ticks per minute, random sides); any other
                value yields one random-side tick per minute.
            minutes: length of the window in minutes.

        Returns:
            DataFrame with ``price``, ``volume`` and ``side`` columns on a
            DatetimeIndex ending now (1-minute spacing, or 20-second spacing
            for "spike").
        """
        # Re-seed the global RNG so every test sees the same ticks.
        np.random.seed(42)
        n_ticks = minutes
        freq = 'min'

        if scenario == "bid_dominant":
            # Buy side dominates. The split scales with `minutes`; the old
            # hard-coded 10-element list broke for any other window length.
            n_buy = round(minutes * 0.7)
            sides = ['buy'] * n_buy + ['sell'] * (minutes - n_buy)
        elif scenario == "ask_dominant":
            # Sell side dominates.
            n_buy = round(minutes * 0.3)
            sides = ['buy'] * n_buy + ['sell'] * (minutes - n_buy)
        elif scenario == "spike":
            # Tick-rate spike: three ticks per minute.
            n_ticks = minutes * 3
            freq = '20s'
            sides = np.random.choice(['buy', 'sell'], n_ticks)
        else:
            sides = np.random.choice(['buy', 'sell'], n_ticks)

        # The RNG draw order (sides, volumes, prices) matches the original
        # fixture, so the generated data is unchanged at minutes=10.
        volumes = np.random.lognormal(10, 0.5, n_ticks)
        timestamps = pd.date_range(end=datetime.now(), periods=n_ticks, freq=freq)
        prices = 100 + np.random.normal(0, 0.1, n_ticks)

        return pd.DataFrame({
            'price': prices,
            'volume': volumes,
            'side': sides,
        }, index=timestamps)

    def test_imbalance_detection(self):
        """A buy-heavy tape is read as bid-dominant (or balanced)."""
        tick_data = self.generate_tick_data("bid_dominant")
        result = self.identifier.identify(tick_data)
        self.assertIn(
            result.imbalance_direction,
            [ImbalanceDirection.BID_DOMINANT, ImbalanceDirection.BALANCED],
        )

    def test_block_flow_calculation(self):
        """Block-flow stats expose a float net flow and int buy/sell counts."""
        tick_data = self.generate_tick_data("normal")
        block_flow = self.identifier._calculate_block_flow(tick_data)
        self.assertIsInstance(block_flow['net_flow'], float)
        self.assertIsInstance(block_flow['buy_count'], int)
        self.assertIsInstance(block_flow['sell_count'], int)

    def test_tick_activity_detection(self):
        """Tick activity is classified into one of the known levels."""
        tick_data = self.generate_tick_data("spike")
        result = self.identifier.identify(tick_data)
        self.assertIn(
            result.tick_activity,
            [TickActivity.SPIKE, TickActivity.ELEVATED, TickActivity.NORMAL],
        )

    def test_trading_opportunity(self):
        """is_trading_opportunity() returns a bool."""
        tick_data = self.generate_tick_data("bid_dominant")
        result = self.identifier.identify(tick_data)
        self.assertIsInstance(result.is_trading_opportunity(), bool)
class TestEcosystemFusion(unittest.TestCase):
    """Tests for the fusion of macro/meso/micro/instant ecosystem results."""

    def setUp(self):
        self.fusion = EcosystemFusion()

    def generate_complete_data(self) -> tuple:
        """Build the (price_data, tick_data) pair used by the fusion tests.

        Returns:
            A 2-tuple of:
            - 100 daily OHLCV bars on a DatetimeIndex ending now, where every
              bar satisfies ``low <= min(open, close) <= max(open, close) <= high``;
            - 20 ticks with ``price``/``volume``/``side`` columns on a default
              RangeIndex.
        """
        # Re-seed the global RNG so every test sees the same data.
        np.random.seed(42)
        days = 100
        dates = pd.date_range(end=datetime.now(), periods=days, freq='D')
        returns = np.random.normal(0.001, 0.015, days)
        closes = 100 * np.exp(np.cumsum(returns))

        # The RNG draw order (open, high, low, volume) is unchanged. High/low
        # are anchored on the open/close envelope so bars stay valid OHLC.
        opens = closes * (1 + np.random.normal(0, 0.001, days))
        highs = np.maximum(opens, closes) * (1 + np.abs(np.random.normal(0, 0.01, days)))
        lows = np.minimum(opens, closes) * (1 - np.abs(np.random.normal(0, 0.01, days)))
        price_data = pd.DataFrame({
            'open': opens,
            'high': highs,
            'low': lows,
            'close': closes,
            'volume': np.random.lognormal(15, 0.5, days),
        }, index=dates)

        n_ticks = 20
        tick_data = pd.DataFrame({
            'price': 100 + np.random.normal(0, 0.1, n_ticks),
            'volume': np.random.lognormal(10, 0.5, n_ticks),
            'side': np.random.choice(['buy', 'sell'], n_ticks),
        })
        return price_data, tick_data

    def test_fusion_output(self):
        """Fusing price and tick data populates all four ecosystem layers."""
        price_data, tick_data = self.generate_complete_data()
        result = self.fusion.fuse(price_data=price_data, tick_data=tick_data)
        self.assertIsNotNone(result)
        self.assertIsNotNone(result.macro)
        self.assertIsNotNone(result.meso)
        self.assertIsNotNone(result.micro)
        self.assertIsNotNone(result.instant)

    def test_confidence_calculation(self):
        """Fused confidence lies in [0, 1]."""
        price_data, _ = self.generate_complete_data()
        result = self.fusion.fuse(price_data=price_data)
        self.assertGreaterEqual(result.confidence, 0)
        self.assertLessEqual(result.confidence, 1)

    def test_position_suggestion(self):
        """Suggested position size lies in [0, 1]."""
        price_data, _ = self.generate_complete_data()
        result = self.fusion.fuse(price_data=price_data)
        self.assertGreaterEqual(result.suggested_position, 0)
        self.assertLessEqual(result.suggested_position, 1)

    def test_agent_recommendation(self):
        """At least one agent is recommended whenever the macro regime is known."""
        price_data, _ = self.generate_complete_data()
        result = self.fusion.fuse(price_data=price_data)
        self.assertIsInstance(result.suggested_agents, list)
        # An empty list is only acceptable when the regime is unknown. Compare
        # enum members rather than relying on the member's string value.
        if result.macro.regime != MacroRegime.UNKNOWN:
            self.assertGreater(len(result.suggested_agents), 0)

    def test_warning_generation(self):
        """Warnings are returned as a list."""
        price_data, _ = self.generate_complete_data()
        result = self.fusion.fuse(price_data=price_data)
        self.assertIsInstance(result.warnings, list)

    def test_to_dict(self):
        """to_dict() exposes the core summary keys."""
        price_data, _ = self.generate_complete_data()
        result = self.fusion.fuse(price_data=price_data)
        dict_result = result.to_dict()
        self.assertIsInstance(dict_result, dict)
        for key in ("timestamp", "overall_regime", "confidence", "trading_bias"):
            self.assertIn(key, dict_result)
if __name__ == "__main__":
    # Run every TestCase in this module when it is executed as a script.
    unittest.main()
|