bayesian_optimizer.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. """
  2. 贝叶斯优化器
  3. 用于智能体参数的在线优化
  4. """
  5. from typing import Callable, Dict, List, Tuple, Optional
  6. import numpy as np
  7. from skopt import gp_minimize
  8. from skopt.space import Real, Integer, Categorical
  9. from skopt.utils import use_named_args
  10. class BayesianOptimizer:
  11. """
  12. 贝叶斯优化器
  13. 使用高斯过程代理模型进行样本高效的参数优化
  14. """
  15. def __init__(
  16. self,
  17. dimensions: List,
  18. n_initial_points: int = 5,
  19. acquisition_function: str = "EI", # Expected Improvement
  20. noise: float = 1e-5
  21. ):
  22. self.dimensions = dimensions
  23. self.n_initial_points = n_initial_points
  24. self.acquisition_function = acquisition_function
  25. self.noise = noise
  26. self.X: List[List] = [] # 参数历史
  27. self.y: List[float] = [] # 结果历史
  28. def optimize(
  29. self,
  30. objective: Callable,
  31. n_calls: int = 10,
  32. x0: Optional[List] = None,
  33. y0: Optional[List] = None
  34. ) -> Dict:
  35. """
  36. 执行优化
  37. Args:
  38. objective: 目标函数,返回要最小化的值(负收益)
  39. n_calls: 总采样次数
  40. x0: 初始参数点
  41. y0: 初始结果
  42. Returns:
  43. 优化结果字典
  44. """
  45. result = gp_minimize(
  46. func=objective,
  47. dimensions=self.dimensions,
  48. n_calls=n_calls,
  49. n_initial_points=self.n_initial_points,
  50. acq_func=self.acquisition_function,
  51. noise=self.noise,
  52. x0=x0,
  53. y0=y0,
  54. random_state=42
  55. )
  56. return {
  57. "best_params": result.x,
  58. "best_value": result.fun,
  59. "all_params": result.x_iters,
  60. "all_values": result.func_vals
  61. }
  62. def suggest_next_point(self) -> List:
  63. """建议下一个采样点(用于在线增量优化)"""
  64. if len(self.X) < self.n_initial_points:
  65. # 随机采样初始点
  66. return [dim.rvs()[0] for dim in self.dimensions]
  67. # 使用采集函数选择下一点(简化实现)
  68. # 实际应使用训练好的GP模型
  69. return [dim.rvs()[0] for dim in self.dimensions]
  70. def update(self, params: List, value: float):
  71. """更新观测历史"""
  72. self.X.append(params)
  73. self.y.append(value)