bayesian_fusion.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
"""
贝叶斯信号融合引擎
使用贝叶斯网络进行非线性信号融合
P(上涨|信号,生态) ∝ P(上涨|生态) × ∏ P(信号i|上涨,生态) / P(信号i|生态)
"""
  6. from dataclasses import dataclass, field
  7. from typing import Dict, List, Optional, Tuple, Any
  8. from datetime import datetime
  9. import numpy as np
  10. import pandas as pd
  11. from scipy import stats
  12. @dataclass
  13. class FusedSignal:
  14. """融合后的信号"""
  15. up_probability: float # 上涨概率
  16. down_probability: float # 下跌概率
  17. neutral_probability: float # 横盘概率
  18. overall_confidence: float # 整体置信度
  19. signal_grade: str # 信号等级:strong/medium/weak/none
  20. recommended_action: str # 建议动作
  21. fusion_metadata: Dict[str, Any] = field(default_factory=dict)
  22. timestamp: datetime = field(default_factory=datetime.now)
  23. class BayesianSignalFusion:
  24. """
  25. 贝叶斯信号融合器
  26. 融合逻辑:
  27. 1. 一级信号 → 贝叶斯网络输入
  28. 2. 生态作为隐变量调节条件概率
  29. 3. 输出后验概率分布
  30. """
  31. def __init__(
  32. self,
  33. prior_up: float = 0.33,
  34. prior_down: float = 0.33,
  35. prior_neutral: float = 0.34,
  36. min_sample_size: int = 1000
  37. ):
  38. self.prior = {
  39. "up": prior_up,
  40. "down": prior_down,
  41. "neutral": prior_neutral
  42. }
  43. self.min_sample_size = min_sample_size
  44. # 历史条件概率表(简化实现)
  45. self.conditional_probs: Dict[str, Dict] = {}
  46. def fuse(
  47. self,
  48. primary_signals: List[Any],
  49. ecosystem: Any,
  50. fallback_to_weighted: bool = True
  51. ) -> FusedSignal:
  52. """
  53. 执行贝叶斯融合
  54. Args:
  55. primary_signals: 一级信号列表
  56. ecosystem: 市场生态
  57. fallback_to_weighted: 小样本时是否回退到加权平均
  58. Returns:
  59. FusedSignal: 融合后的信号
  60. """
  61. if not primary_signals:
  62. return FusedSignal(
  63. up_probability=self.prior["up"],
  64. down_probability=self.prior["down"],
  65. neutral_probability=self.prior["neutral"],
  66. overall_confidence=0.0,
  67. signal_grade="none",
  68. recommended_action="hold"
  69. )
  70. # 检查样本量
  71. if fallback_to_weighted:
  72. sample_size = self._estimate_sample_size(ecosystem)
  73. if sample_size < self.min_sample_size:
  74. return self._weighted_fusion(primary_signals, ecosystem, sample_size)
  75. # 贝叶斯推断
  76. posterior = self._bayesian_inference(primary_signals, ecosystem)
  77. # 确定信号等级和建议
  78. grade, action = self._determine_signal_grade(posterior)
  79. return FusedSignal(
  80. up_probability=posterior["up"],
  81. down_probability=posterior["down"],
  82. neutral_probability=posterior["neutral"],
  83. overall_confidence=self._calculate_confidence(posterior, primary_signals),
  84. signal_grade=grade,
  85. recommended_action=action,
  86. fusion_metadata={
  87. "method": "bayesian",
  88. "signal_count": len(primary_signals),
  89. "prior": self.prior.copy(),
  90. "posterior": posterior
  91. }
  92. )
  93. def _bayesian_inference(
  94. self,
  95. signals: List[Any],
  96. ecosystem: Any
  97. ) -> Dict[str, float]:
  98. """执行贝叶斯推断"""
  99. # 初始化后验为 prior
  100. posterior = self.prior.copy()
  101. # 获取生态条件
  102. regime = getattr(ecosystem.macro, 'regime', None) if ecosystem else None
  103. # 对每个信号更新后验
  104. for signal in signals:
  105. likelihood = self._calculate_likelihood(signal, regime)
  106. # 贝叶斯更新: P(H|D) ∝ P(D|H) * P(H)
  107. for direction in ["up", "down", "neutral"]:
  108. posterior[direction] *= likelihood.get(direction, 0.33)
  109. # 归一化
  110. total = sum(posterior.values())
  111. if total > 0:
  112. posterior = {k: v/total for k, v in posterior.items()}
  113. return posterior
  114. def _calculate_likelihood(
  115. self,
  116. signal: Any,
  117. regime: Any
  118. ) -> Dict[str, float]:
  119. """计算似然函数 P(信号|方向,生态)"""
  120. value = getattr(signal, 'value', 0)
  121. # 基于信号值和生态计算似然
  122. # 简化模型:信号值越正,上涨概率越高
  123. likelihood = {
  124. "up": max(0.1, min(0.9, 0.5 + value * 0.4)),
  125. "down": max(0.1, min(0.9, 0.5 - value * 0.4)),
  126. "neutral": max(0.1, min(0.9, 1 - abs(value) * 0.5))
  127. }
  128. # 生态调节
  129. if regime:
  130. regime_boost = {
  131. "summer": {"up": 1.2, "down": 0.8, "neutral": 0.9},
  132. "winter": {"up": 0.8, "down": 1.2, "neutral": 0.9},
  133. "spring": {"up": 1.1, "down": 0.9, "neutral": 0.95},
  134. "autumn": {"up": 0.9, "down": 0.9, "neutral": 1.1}
  135. }.get(regime.value, {"up": 1, "down": 1, "neutral": 1})
  136. for direction in likelihood:
  137. likelihood[direction] *= regime_boost[direction]
  138. # 重新归一化
  139. total = sum(likelihood.values())
  140. return {k: v/total for k, v in likelihood.items()}
  141. def _weighted_fusion(
  142. self,
  143. signals: List[Any],
  144. ecosystem: Any,
  145. sample_size: int
  146. ) -> FusedSignal:
  147. """小样本回退:加权平均融合"""
  148. if not signals:
  149. return self._create_neutral_signal("no_signals")
  150. # 加权平均
  151. weighted_sum = sum(s.value * getattr(s, 'confidence', 0.5) for s in signals)
  152. total_weight = sum(getattr(s, 'confidence', 0.5) for s in signals)
  153. if total_weight == 0:
  154. return self._create_neutral_signal("zero_weight")
  155. avg_signal = weighted_sum / total_weight
  156. # 转换为概率
  157. up_prob = max(0, min(1, 0.5 + avg_signal * 0.5))
  158. down_prob = max(0, min(1, 0.5 - avg_signal * 0.5))
  159. neutral_prob = 1 - abs(avg_signal)
  160. # 归一化
  161. total = up_prob + down_prob + neutral_prob
  162. probs = {
  163. "up": up_prob / total,
  164. "down": down_prob / total,
  165. "neutral": neutral_prob / total
  166. }
  167. grade, action = self._determine_signal_grade(probs)
  168. return FusedSignal(
  169. up_probability=probs["up"],
  170. down_probability=probs["down"],
  171. neutral_probability=probs["neutral"],
  172. overall_confidence=abs(avg_signal) * (sample_size / self.min_sample_size),
  173. signal_grade=grade,
  174. recommended_action=action,
  175. fusion_metadata={
  176. "method": "weighted_fallback",
  177. "reason": f"sample_size {sample_size} < {self.min_sample_size}",
  178. "signal_count": len(signals)
  179. }
  180. )
  181. def _determine_signal_grade(self, probs: Dict[str, float]) -> Tuple[str, str]:
  182. """确定信号等级"""
  183. max_prob = max(probs.values())
  184. dominant = max(probs, key=probs.get)
  185. if max_prob > 0.7 and probs["neutral"] < 0.3:
  186. grade = "strong"
  187. elif max_prob > 0.55:
  188. grade = "medium"
  189. elif max_prob > 0.45:
  190. grade = "weak"
  191. else:
  192. grade = "none"
  193. # 建议动作
  194. if dominant == "up" and grade in ["strong", "medium"]:
  195. action = "buy"
  196. elif dominant == "down" and grade in ["strong", "medium"]:
  197. action = "sell"
  198. else:
  199. action = "hold"
  200. return grade, action
  201. def _calculate_confidence(
  202. self,
  203. posterior: Dict[str, float],
  204. signals: List[Any]
  205. ) -> float:
  206. """计算整体置信度"""
  207. # 基于概率分布的熵计算置信度
  208. entropy = -sum(p * np.log(p + 1e-10) for p in posterior.values())
  209. max_entropy = np.log(3) # 三分类最大熵
  210. confidence = 1 - (entropy / max_entropy)
  211. # 信号数量加成
  212. signal_bonus = min(0.2, len(signals) * 0.02)
  213. return min(1.0, confidence + signal_bonus)
  214. def _estimate_sample_size(self, ecosystem: Any) -> int:
  215. """估算当前生态的历史样本量"""
  216. # 简化:返回固定值,实际应从数据库查询
  217. return 2000 # 假设充足样本
  218. def _create_neutral_signal(self, reason: str) -> FusedSignal:
  219. """创建中性信号"""
  220. return FusedSignal(
  221. up_probability=self.prior["up"],
  222. down_probability=self.prior["down"],
  223. neutral_probability=self.prior["neutral"],
  224. overall_confidence=0.0,
  225. signal_grade="none",
  226. recommended_action="hold",
  227. fusion_metadata={"reason": reason}
  228. )