| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297 |
- """
- 多智能体策略矩阵单元测试
- 测试覆盖:
- - 各智能体独立测试
- - 动态路由算法测试
- - 协同机制测试
- """
- import unittest
- from datetime import datetime, timedelta
- import numpy as np
- import pandas as pd
- from agents import (
- TrendHunterAgent, MeanReversionAgent, MomentumSurferAgent,
- StructureArbitrageAgent, VolatilitySellerAgent, EventDrivenAgent,
- DynamicAgentRouter, AgentCoordinator, ConflictResolutionMethod,
- SignalDirection, AgentSignal
- )
- from core.ecosystem import (
- MacroEcosystem, MacroRegime,
- MesoEcosystem, HealthLevel,
- MicroEcosystem, MicroState, FlowToxicity, SmartMoneySignal,
- UnifiedEcosystem
- )
class TestTrendHunterAgent(unittest.TestCase):
    """Tests for the trend-hunter agent (TrendHunterAgent)."""

    def setUp(self):
        self.agent = TrendHunterAgent()
        self.price_data = self._generate_trending_data()

    def _generate_trending_data(self, days=60, total_trend=0.15):
        """Build a synthetic, upward-trending daily OHLCV frame.

        Args:
            days: Number of daily bars to generate.
            total_trend: Extra log-price gain ramped linearly across the window,
                on top of the random-walk drift (0.15 is roughly +16% overall).

        Returns:
            DataFrame indexed by date with columns
            ``open``/``high``/``low``/``close``/``volume``.
        """
        # A private RandomState seeded with 42 yields exactly the draws that
        # np.random.seed(42) + np.random.* produced, without mutating NumPy's
        # global RNG (which would make other tests order-dependent).
        rng = np.random.RandomState(42)
        dates = pd.date_range(end=datetime.now(), periods=days, freq='D')
        returns = rng.normal(0.002, 0.015, days)
        # The trend is a *level* ramp on the log-price path. Adding it to the
        # per-bar returns instead compounds into daily returns of up to 15%
        # and a ~100x price move over 60 bars.
        log_path = np.cumsum(returns) + np.linspace(0, total_trend, days)
        close = 100 * np.exp(log_path)
        open_ = close * (1 + rng.normal(0, 0.001, days))
        high = close * (1 + np.abs(rng.normal(0, 0.01, days)))
        low = close * (1 - np.abs(rng.normal(0, 0.01, days)))
        volume = rng.lognormal(15, 0.5, days)
        # Enforce a valid bar (low <= open, close <= high): the independent
        # open noise can otherwise land outside the sampled high/low.
        high = np.maximum(high, np.maximum(open_, close))
        low = np.minimum(low, np.minimum(open_, close))
        return pd.DataFrame(
            {
                'open': open_,
                'high': high,
                'low': low,
                'close': close,
                'volume': volume,
            },
            index=dates,
        )

    def test_signal_generation(self):
        """generate_signal runs; any emitted signal is well-formed."""
        signal = self.agent.generate_signal(self.price_data)
        # None is an acceptable result when the entry conditions are not met;
        # the method must still run without raising.
        if signal is not None:
            self.assertEqual(signal.agent_name, "trend_hunter")
            self.assertIn(
                signal.direction,
                [SignalDirection.LONG, SignalDirection.SHORT, SignalDirection.NEUTRAL],
            )

    def test_expected_return_calculation(self):
        """Expected return on trending data is a float in (0, 1)."""
        expected_return = self.agent.get_expected_return(self.price_data)
        self.assertIsInstance(expected_return, float)
        self.assertGreater(expected_return, 0)
        self.assertLess(expected_return, 1.0)

    def test_win_probability_calculation(self):
        """Win probability is a float in [0, 1]."""
        win_prob = self.agent.get_win_probability(self.price_data)
        self.assertIsInstance(win_prob, float)
        self.assertGreaterEqual(win_prob, 0)
        self.assertLessEqual(win_prob, 1.0)
class TestMeanReversionAgent(unittest.TestCase):
    """Tests for the mean-reversion agent (MeanReversionAgent)."""

    def setUp(self):
        self.agent = MeanReversionAgent()
        self.price_data = self._generate_ranging_data()

    def _generate_ranging_data(self, days=60):
        """Build a synthetic, range-bound daily OHLCV frame.

        Args:
            days: Number of daily bars to generate.

        Returns:
            DataFrame indexed by date with columns
            ``open``/``high``/``low``/``close``/``volume``.
        """
        # Private RNG: same draws as np.random.seed(42), but no global side effect.
        rng = np.random.RandomState(42)
        dates = pd.date_range(end=datetime.now(), periods=days, freq='D')
        # NOTE(review): close steps are N(0, 0.01) in absolute price units while
        # the high/low offsets are |N(0, 1)|, so close is nearly flat relative to
        # the bar range. Confirm this scale is intended.
        close = 100 + np.cumsum(rng.normal(0, 0.01, days))
        open_ = close + rng.normal(0, 0.1, days)
        high = close + np.abs(rng.normal(0, 1, days))
        low = close - np.abs(rng.normal(0, 1, days))
        volume = rng.lognormal(15, 0.5, days)
        # Enforce a valid bar (low <= open, close <= high): the independent
        # open noise can otherwise land outside the sampled high/low.
        high = np.maximum(high, np.maximum(open_, close))
        low = np.minimum(low, np.minimum(open_, close))
        return pd.DataFrame(
            {
                'open': open_,
                'high': high,
                'low': low,
                'close': close,
                'volume': volume,
            },
            index=dates,
        )

    def test_signal_generation(self):
        """generate_signal runs; any emitted signal carries the agent's name."""
        signal = self.agent.generate_signal(self.price_data)
        # A ranging market may legitimately produce no signal (None).
        if signal is not None:
            self.assertEqual(signal.agent_name, "mean_reversion")

    def test_bollinger_position_calculation(self):
        """The Bollinger-band position helper returns a float."""
        bb_position = self.agent._calculate_bb_position(self.price_data)
        self.assertIsInstance(bb_position, float)
class TestDynamicAgentRouter(unittest.TestCase):
    """Tests for the dynamic agent-routing algorithm (DynamicAgentRouter)."""

    def setUp(self):
        self.router = DynamicAgentRouter()
        self.agents = {
            "trend": TrendHunterAgent(),
            "mean_reversion": MeanReversionAgent(),
        }

    def _generate_price_data(self):
        """Build a 100-bar daily random-walk OHLCV frame with a fixed +/-1% range."""
        # Private RNG: same draws as np.random.seed(42), but no global side effect.
        rng = np.random.RandomState(42)
        days = 100
        dates = pd.date_range(end=datetime.now(), periods=days, freq='D')
        prices = 100 * np.exp(np.cumsum(rng.normal(0, 0.01, days)))
        return pd.DataFrame({
            'open': prices,
            'high': prices * 1.01,
            'low': prices * 0.99,
            'close': prices,
            'volume': rng.lognormal(15, 0.5, days)
        }, index=dates)

    def _create_mock_ecosystem(self):
        """Build a fixed UnifiedEcosystem: 'summer' macro regime, trending micro state."""
        macro = MacroEcosystem(
            regime=MacroRegime.SUMMER,
            confidence=0.8,
            volatility_trend=0.1,
            volume_trend=0.05,
            dispersion=30.0,
            adx_value=30.0,
            description="夏季繁荣"
        )
        meso = MesoEcosystem(
            health_score=75.0,
            health_level=HealthLevel.HIGH,
            price_impact=0.001,
            order_flow=0.1,
            liquidity_depth=1.5,
            volatility_efficiency=3.0,
            info_response=0.2,
            components={}
        )
        micro = MicroEcosystem(
            state=MicroState.TRENDING,
            state_probability={MicroState.TRENDING: 0.7, MicroState.RANGING: 0.2, MicroState.REVERSING: 0.1},
            flow_toxicity=FlowToxicity.NONE,
            smart_money=SmartMoneySignal(False, "neutral", 0.0, 0, 0.0),
            hmm_features={},
            warnings=[]
        )
        return UnifiedEcosystem(
            timestamp=datetime.now(),
            macro=macro,
            meso=meso,
            micro=micro,
            overall_regime="summer_trending",
            confidence=0.75,
            trading_bias="long",
            risk_level="low",
            suggested_position=0.8,
            suggested_agents=["trend_hunter"],
            instant=None,
            warnings=[]
        )

    def test_utility_calculation(self):
        """_calculate_utilities returns a dict with one entry per agent."""
        price_data = self._generate_price_data()
        ecosystem = self._create_mock_ecosystem()
        utilities = self.router._calculate_utilities(self.agents, price_data, ecosystem)
        self.assertIsInstance(utilities, dict)
        self.assertEqual(len(utilities), len(self.agents))

    def test_routing_decision(self):
        """route returns a decision exposing a weights dict and an active-agents list."""
        price_data = self._generate_price_data()
        ecosystem = self._create_mock_ecosystem()
        decision = self.router.route(self.agents, price_data, ecosystem)
        self.assertIsNotNone(decision)
        self.assertIsInstance(decision.weights, dict)
        self.assertIsInstance(decision.active_agents, list)
class TestAgentCoordinator(unittest.TestCase):
    """Tests for the inter-agent coordination mechanism (AgentCoordinator)."""

    def setUp(self):
        self.coordinator = AgentCoordinator()

    def _create_mock_signal(self, direction, confidence=0.7):
        """Return a synthetic AgentSignal in *direction* with the given confidence."""
        fields = dict(
            agent_name="test_agent",
            direction=direction,
            strength=None,
            confidence=confidence,
            suggested_position=0.5,
            expected_return=0.1,
            win_probability=0.6,
            timestamp=datetime.now(),
        )
        return AgentSignal(**fields)

    def test_conflict_resolution(self):
        """Opposing LONG/SHORT signals are reported as 'conflict_resolved'."""
        signals = {
            "agent1": self._create_mock_signal(SignalDirection.LONG, 0.8),
            "agent2": self._create_mock_signal(SignalDirection.SHORT, 0.6),
        }
        outcome = self.coordinator.coordinate(signals, {"agent1": 0.6, "agent2": 0.4})
        self.assertIsNotNone(outcome)
        self.assertEqual(outcome.coordination_type, "conflict_resolved")

    def test_reinforcement(self):
        """Three agreeing LONG signals stay LONG and are reported as 'reinforced'."""
        confidence_by_agent = {"agent1": 0.7, "agent2": 0.8, "agent3": 0.75}
        signals = {
            agent_id: self._create_mock_signal(SignalDirection.LONG, conf)
            for agent_id, conf in confidence_by_agent.items()
        }
        weights = {"agent1": 0.33, "agent2": 0.33, "agent3": 0.34}
        outcome = self.coordinator.coordinate(signals, weights)
        self.assertIsNotNone(outcome)
        self.assertEqual(outcome.final_direction, SignalDirection.LONG)
        self.assertEqual(outcome.coordination_type, "reinforced")
class TestVolatilitySellerAgent(unittest.TestCase):
    """Tests for the volatility-seller agent (VolatilitySellerAgent)."""

    def setUp(self):
        self.agent = VolatilitySellerAgent()

    def _generate_high_vol_data(self, days=100):
        """Build a high-volatility daily OHLCV frame (3% daily log-return std).

        Args:
            days: Number of daily bars to generate.

        Returns:
            DataFrame indexed by date with columns
            ``open``/``high``/``low``/``close``/``volume``.
        """
        # Private RNG: same draws as np.random.seed(42), but no global side effect.
        rng = np.random.RandomState(42)
        dates = pd.date_range(end=datetime.now(), periods=days, freq='D')
        returns = rng.normal(0, 0.03, days)  # high volatility
        prices = 100 * np.exp(np.cumsum(returns))
        return pd.DataFrame({
            'open': prices,
            'high': prices * 1.03,
            'low': prices * 0.97,
            'close': prices,
            'volume': rng.lognormal(15, 0.5, days)
        }, index=dates)

    def test_iv_rank_calculation(self):
        """IV Rank is a float within [0, 100]."""
        price_data = self._generate_high_vol_data()
        iv_rank = self.agent._calculate_iv_rank(price_data)
        self.assertIsInstance(iv_rank, float)
        self.assertGreaterEqual(iv_rank, 0)
        self.assertLessEqual(iv_rank, 100)
if __name__ == "__main__":
    # Discover and run every TestCase in this module; '__main__' is the default module.
    unittest.main(module="__main__")
|