# test_agents.py
"""
Unit tests for the multi-agent strategy matrix.

Test coverage:
- Each agent tested in isolation
- The dynamic routing algorithm
- The agent coordination mechanism
"""
  8. import unittest
  9. from datetime import datetime, timedelta
  10. import numpy as np
  11. import pandas as pd
  12. from agents import (
  13. TrendHunterAgent, MeanReversionAgent, MomentumSurferAgent,
  14. StructureArbitrageAgent, VolatilitySellerAgent, EventDrivenAgent,
  15. DynamicAgentRouter, AgentCoordinator, ConflictResolutionMethod,
  16. SignalDirection, AgentSignal
  17. )
  18. from core.ecosystem import (
  19. MacroEcosystem, MacroRegime,
  20. MesoEcosystem, HealthLevel,
  21. MicroEcosystem, MicroState, FlowToxicity, SmartMoneySignal,
  22. UnifiedEcosystem
  23. )
  24. class TestTrendHunterAgent(unittest.TestCase):
  25. """测试趋势猎手智能体"""
  26. def setUp(self):
  27. self.agent = TrendHunterAgent()
  28. self.price_data = self._generate_trending_data()
  29. def _generate_trending_data(self, days=60):
  30. """生成趋势数据"""
  31. np.random.seed(42)
  32. dates = pd.date_range(end=datetime.now(), periods=days, freq='D')
  33. returns = np.random.normal(0.002, 0.015, days)
  34. trend = np.linspace(0, 0.15, days)
  35. returns += trend
  36. prices = 100 * np.exp(np.cumsum(returns))
  37. return pd.DataFrame({
  38. 'open': prices * (1 + np.random.normal(0, 0.001, days)),
  39. 'high': prices * (1 + abs(np.random.normal(0, 0.01, days))),
  40. 'low': prices * (1 - abs(np.random.normal(0, 0.01, days))),
  41. 'close': prices,
  42. 'volume': np.random.lognormal(15, 0.5, days)
  43. }, index=dates)
  44. def test_signal_generation(self):
  45. """测试信号生成"""
  46. signal = self.agent.generate_signal(self.price_data)
  47. # 信号可能为None(当条件不满足时),但方法应正常运行
  48. if signal:
  49. self.assertEqual(signal.agent_name, "trend_hunter")
  50. self.assertIn(signal.direction, [SignalDirection.LONG, SignalDirection.SHORT, SignalDirection.NEUTRAL])
  51. def test_expected_return_calculation(self):
  52. """测试预期收益计算"""
  53. expected_return = self.agent.get_expected_return(self.price_data)
  54. self.assertIsInstance(expected_return, float)
  55. self.assertGreater(expected_return, 0)
  56. self.assertLess(expected_return, 1.0)
  57. def test_win_probability_calculation(self):
  58. """测试胜率计算"""
  59. win_prob = self.agent.get_win_probability(self.price_data)
  60. self.assertIsInstance(win_prob, float)
  61. self.assertGreaterEqual(win_prob, 0)
  62. self.assertLessEqual(win_prob, 1.0)
  63. class TestMeanReversionAgent(unittest.TestCase):
  64. """测试均值回归者智能体"""
  65. def setUp(self):
  66. self.agent = MeanReversionAgent()
  67. self.price_data = self._generate_ranging_data()
  68. def _generate_ranging_data(self, days=60):
  69. """生成震荡数据"""
  70. np.random.seed(42)
  71. dates = pd.date_range(end=datetime.now(), periods=days, freq='D')
  72. prices = 100 + np.cumsum(np.random.normal(0, 0.01, days))
  73. return pd.DataFrame({
  74. 'open': prices + np.random.normal(0, 0.1, days),
  75. 'high': prices + abs(np.random.normal(0, 1, days)),
  76. 'low': prices - abs(np.random.normal(0, 1, days)),
  77. 'close': prices,
  78. 'volume': np.random.lognormal(15, 0.5, days)
  79. }, index=dates)
  80. def test_signal_generation(self):
  81. """测试信号生成"""
  82. signal = self.agent.generate_signal(self.price_data)
  83. # 震荡市场可能没有信号
  84. if signal:
  85. self.assertEqual(signal.agent_name, "mean_reversion")
  86. def test_bollinger_position_calculation(self):
  87. """测试布林带位置计算"""
  88. bb_position = self.agent._calculate_bb_position(self.price_data)
  89. self.assertIsInstance(bb_position, float)
  90. class TestDynamicAgentRouter(unittest.TestCase):
  91. """测试动态路由算法"""
  92. def setUp(self):
  93. self.router = DynamicAgentRouter()
  94. self.agents = {
  95. "trend": TrendHunterAgent(),
  96. "mean_reversion": MeanReversionAgent(),
  97. }
  98. def _generate_price_data(self):
  99. """生成测试价格数据"""
  100. np.random.seed(42)
  101. days = 100
  102. dates = pd.date_range(end=datetime.now(), periods=days, freq='D')
  103. prices = 100 * np.exp(np.cumsum(np.random.normal(0, 0.01, days)))
  104. return pd.DataFrame({
  105. 'open': prices,
  106. 'high': prices * 1.01,
  107. 'low': prices * 0.99,
  108. 'close': prices,
  109. 'volume': np.random.lognormal(15, 0.5, days)
  110. }, index=dates)
  111. def _create_mock_ecosystem(self):
  112. """创建模拟生态系统"""
  113. macro = MacroEcosystem(
  114. regime=MacroRegime.SUMMER,
  115. confidence=0.8,
  116. volatility_trend=0.1,
  117. volume_trend=0.05,
  118. dispersion=30.0,
  119. adx_value=30.0,
  120. description="夏季繁荣"
  121. )
  122. meso = MesoEcosystem(
  123. health_score=75.0,
  124. health_level=HealthLevel.HIGH,
  125. price_impact=0.001,
  126. order_flow=0.1,
  127. liquidity_depth=1.5,
  128. volatility_efficiency=3.0,
  129. info_response=0.2,
  130. components={}
  131. )
  132. micro = MicroEcosystem(
  133. state=MicroState.TRENDING,
  134. state_probability={MicroState.TRENDING: 0.7, MicroState.RANGING: 0.2, MicroState.REVERSING: 0.1},
  135. flow_toxicity=FlowToxicity.NONE,
  136. smart_money=SmartMoneySignal(False, "neutral", 0.0, 0, 0.0),
  137. hmm_features={},
  138. warnings=[]
  139. )
  140. return UnifiedEcosystem(
  141. timestamp=datetime.now(),
  142. macro=macro,
  143. meso=meso,
  144. micro=micro,
  145. overall_regime="summer_trending",
  146. confidence=0.75,
  147. trading_bias="long",
  148. risk_level="low",
  149. suggested_position=0.8,
  150. suggested_agents=["trend_hunter"],
  151. instant=None,
  152. warnings=[]
  153. )
  154. def test_utility_calculation(self):
  155. """测试效用计算"""
  156. price_data = self._generate_price_data()
  157. ecosystem = self._create_mock_ecosystem()
  158. utilities = self.router._calculate_utilities(self.agents, price_data, ecosystem)
  159. self.assertIsInstance(utilities, dict)
  160. self.assertEqual(len(utilities), len(self.agents))
  161. def test_routing_decision(self):
  162. """测试路由决策"""
  163. price_data = self._generate_price_data()
  164. ecosystem = self._create_mock_ecosystem()
  165. decision = self.router.route(self.agents, price_data, ecosystem)
  166. self.assertIsNotNone(decision)
  167. self.assertIsInstance(decision.weights, dict)
  168. self.assertIsInstance(decision.active_agents, list)
  169. class TestAgentCoordinator(unittest.TestCase):
  170. """测试智能体协同机制"""
  171. def setUp(self):
  172. self.coordinator = AgentCoordinator()
  173. def _create_mock_signal(self, direction, confidence=0.7):
  174. """创建模拟信号"""
  175. return AgentSignal(
  176. agent_name="test_agent",
  177. direction=direction,
  178. strength=None,
  179. confidence=confidence,
  180. suggested_position=0.5,
  181. expected_return=0.1,
  182. win_probability=0.6,
  183. timestamp=datetime.now()
  184. )
  185. def test_conflict_resolution(self):
  186. """测试冲突解决"""
  187. long_signal = self._create_mock_signal(SignalDirection.LONG, 0.8)
  188. short_signal = self._create_mock_signal(SignalDirection.SHORT, 0.6)
  189. signals = {
  190. "agent1": long_signal,
  191. "agent2": short_signal
  192. }
  193. weights = {"agent1": 0.6, "agent2": 0.4}
  194. result = self.coordinator.coordinate(signals, weights)
  195. self.assertIsNotNone(result)
  196. self.assertEqual(result.coordination_type, "conflict_resolved")
  197. def test_reinforcement(self):
  198. """测试信号叠加增强"""
  199. signal1 = self._create_mock_signal(SignalDirection.LONG, 0.7)
  200. signal2 = self._create_mock_signal(SignalDirection.LONG, 0.8)
  201. signal3 = self._create_mock_signal(SignalDirection.LONG, 0.75)
  202. signals = {
  203. "agent1": signal1,
  204. "agent2": signal2,
  205. "agent3": signal3
  206. }
  207. weights = {"agent1": 0.33, "agent2": 0.33, "agent3": 0.34}
  208. result = self.coordinator.coordinate(signals, weights)
  209. self.assertIsNotNone(result)
  210. self.assertEqual(result.final_direction, SignalDirection.LONG)
  211. self.assertEqual(result.coordination_type, "reinforced")
  212. class TestVolatilitySellerAgent(unittest.TestCase):
  213. """测试波动率卖家智能体"""
  214. def setUp(self):
  215. self.agent = VolatilitySellerAgent()
  216. def _generate_high_vol_data(self, days=100):
  217. """生成高波动数据"""
  218. np.random.seed(42)
  219. dates = pd.date_range(end=datetime.now(), periods=days, freq='D')
  220. returns = np.random.normal(0, 0.03, days) # 高波动
  221. prices = 100 * np.exp(np.cumsum(returns))
  222. return pd.DataFrame({
  223. 'open': prices,
  224. 'high': prices * 1.03,
  225. 'low': prices * 0.97,
  226. 'close': prices,
  227. 'volume': np.random.lognormal(15, 0.5, days)
  228. }, index=dates)
  229. def test_iv_rank_calculation(self):
  230. """测试IV Rank计算"""
  231. price_data = self._generate_high_vol_data()
  232. iv_rank = self.agent._calculate_iv_rank(price_data)
  233. self.assertIsInstance(iv_rank, float)
  234. self.assertGreaterEqual(iv_rank, 0)
  235. self.assertLessEqual(iv_rank, 100)
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()