dual_ma_strategy.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. """
  2. 双均线交易策略
  3. 该策略使用两条不同周期的移动平均线(MA)产生交易信号:
  4. - 短期均线上穿长期均线时,产生买入信号
  5. - 短期均线下穿长期均线时,产生卖出信号
  6. """
  7. import numpy as np
  8. import pandas as pd
  9. from typing import Dict, List, Optional, Tuple
  10. from dataclasses import dataclass
  11. from enum import Enum
  12. class SignalType(Enum):
  13. """信号类型枚举"""
  14. BUY = 1 # 买入
  15. SELL = -1 # 卖出
  16. HOLD = 0 # 持有
  17. @dataclass
  18. class TradeSignal:
  19. """交易信号数据类"""
  20. timestamp: pd.Timestamp # 时间戳
  21. signal_type: SignalType # 信号类型
  22. price: float # 当前价格
  23. short_ma: float # 短期均线值
  24. long_ma: float # 长期均线值
  25. reason: str # 信号原因
  26. @dataclass
  27. class Position:
  28. """持仓信息数据类"""
  29. quantity: float # 持仓数量
  30. entry_price: float # 入场价格
  31. entry_time: pd.Timestamp # 入场时间
  32. side: int # 方向:1多头,-1空头
  33. class DualMAStrategy:
  34. """
  35. 双均线交易策略类
  36. 参数:
  37. short_window: 短期均线窗口期(默认5)
  38. long_window: 长期均线窗口期(默认20)
  39. initial_capital: 初始资金(默认100000)
  40. """
  41. def __init__(
  42. self,
  43. short_window: int = 5,
  44. long_window: int = 20,
  45. initial_capital: float = 100000.0
  46. ):
  47. # 参数校验
  48. if short_window >= long_window:
  49. raise ValueError("短期均线周期必须小于长期均线周期")
  50. self.short_window = short_window
  51. self.long_window = long_window
  52. self.initial_capital = initial_capital
  53. # 状态变量
  54. self.position: Optional[Position] = None
  55. self.cash = initial_capital
  56. self.equity = initial_capital
  57. self.signals: List[TradeSignal] = []
  58. self.trades: List[Dict] = []
  59. def calculate_ma(self, data: pd.Series, window: int) -> pd.Series:
  60. """
  61. 计算简单移动平均线
  62. 参数:
  63. data: 价格序列
  64. window: 均线周期
  65. 返回:
  66. 移动平均线序列
  67. """
  68. return data.rolling(window=window, min_periods=window).mean()
  69. def generate_signals(self, df: pd.DataFrame) -> pd.DataFrame:
  70. """
  71. 生成交易信号
  72. 参数:
  73. df: 包含'close'列的DataFrame
  74. 返回:
  75. 添加了均线和信号列的DataFrame
  76. """
  77. df = df.copy()
  78. # 计算双均线
  79. df['short_ma'] = self.calculate_ma(df['close'], self.short_window)
  80. df['long_ma'] = self.calculate_ma(df['close'], self.long_window)
  81. # 初始化信号列
  82. df['signal'] = 0
  83. # 计算均线差值
  84. df['ma_diff'] = df['short_ma'] - df['long_ma']
  85. # 计算差值的一阶差分(判断穿越方向)
  86. df['ma_diff_prev'] = df['ma_diff'].shift(1)
  87. # 金叉:短期均线上穿长期均线
  88. golden_cross = (df['ma_diff'] > 0) & (df['ma_diff_prev'] <= 0)
  89. df.loc[golden_cross, 'signal'] = 1
  90. # 死叉:短期均线下穿长期均线
  91. death_cross = (df['ma_diff'] < 0) & (df['ma_diff_prev'] >= 0)
  92. df.loc[death_cross, 'signal'] = -1
  93. return df
  94. def on_bar(self, timestamp: pd.Timestamp, row: pd.Series) -> Optional[TradeSignal]:
  95. """
  96. 处理每根K线数据
  97. 参数:
  98. timestamp: 时间戳
  99. row: 包含价格数据和信号的Series
  100. 返回:
  101. 如果有信号则返回TradeSignal,否则返回None
  102. """
  103. signal = None
  104. current_price = row['close']
  105. short_ma = row['short_ma']
  106. long_ma = row['long_ma']
  107. # 处理买入信号(金叉)
  108. if row['signal'] == 1:
  109. # 如果没有持仓,则开多仓
  110. if self.position is None:
  111. signal = TradeSignal(
  112. timestamp=timestamp,
  113. signal_type=SignalType.BUY,
  114. price=current_price,
  115. short_ma=short_ma,
  116. long_ma=long_ma,
  117. reason=f"金叉: 短期MA({self.short_window})上穿长期MA({self.long_window})"
  118. )
  119. self._open_position(timestamp, current_price, 1)
  120. # 处理卖出信号(死叉)
  121. elif row['signal'] == -1:
  122. # 如果持有多仓,则平仓
  123. if self.position is not None and self.position.side == 1:
  124. signal = TradeSignal(
  125. timestamp=timestamp,
  126. signal_type=SignalType.SELL,
  127. price=current_price,
  128. short_ma=short_ma,
  129. long_ma=long_ma,
  130. reason=f"死叉: 短期MA({self.short_window})下穿长期MA({self.long_window})"
  131. )
  132. self._close_position(timestamp, current_price)
  133. # 更新权益
  134. self._update_equity(current_price)
  135. if signal:
  136. self.signals.append(signal)
  137. return signal
  138. def _open_position(
  139. self,
  140. timestamp: pd.Timestamp,
  141. price: float,
  142. side: int
  143. ):
  144. """开仓"""
  145. # 全仓买入
  146. quantity = (self.cash * 0.99) / price # 预留1%作为手续费缓冲
  147. self.position = Position(
  148. quantity=quantity,
  149. entry_price=price,
  150. entry_time=timestamp,
  151. side=side
  152. )
  153. cost = quantity * price
  154. self.cash -= cost
  155. print(f"[{timestamp}] 开仓 | 方向: {'多' if side == 1 else '空'} | "
  156. f"价格: {price:.2f} | 数量: {quantity:.4f} | 成本: {cost:.2f}")
  157. def _close_position(self, timestamp: pd.Timestamp, price: float):
  158. """平仓"""
  159. if self.position is None:
  160. return
  161. # 计算盈亏
  162. quantity = self.position.quantity
  163. entry_price = self.position.entry_price
  164. if self.position.side == 1:
  165. # 多头平仓
  166. pnl = (price - entry_price) * quantity
  167. pnl_pct = (price / entry_price - 1) * 100
  168. else:
  169. # 空头平仓
  170. pnl = (entry_price - price) * quantity
  171. pnl_pct = (entry_price / price - 1) * 100
  172. # 回收资金
  173. proceeds = quantity * price
  174. self.cash += proceeds
  175. # 记录交易
  176. trade = {
  177. 'entry_time': self.position.entry_time,
  178. 'exit_time': timestamp,
  179. 'entry_price': entry_price,
  180. 'exit_price': price,
  181. 'quantity': quantity,
  182. 'side': self.position.side,
  183. 'pnl': pnl,
  184. 'pnl_pct': pnl_pct,
  185. 'holding_periods': (timestamp - self.position.entry_time).days
  186. }
  187. self.trades.append(trade)
  188. print(f"[{timestamp}] 平仓 | 价格: {price:.2f} | "
  189. f"盈亏: {pnl:+.2f} ({pnl_pct:+.2f}%) | "
  190. f"持仓周期: {trade['holding_periods']}天")
  191. self.position = None
  192. def _update_equity(self, current_price: float):
  193. """更新账户权益"""
  194. position_value = 0
  195. if self.position is not None:
  196. position_value = self.position.quantity * current_price
  197. self.equity = self.cash + position_value
  198. def get_performance_summary(self) -> Dict:
  199. """
  200. 获取策略绩效汇总
  201. 返回:
  202. 包含各项绩效指标的字典
  203. """
  204. if not self.trades:
  205. return {
  206. 'total_trades': 0,
  207. 'winning_trades': 0,
  208. 'losing_trades': 0,
  209. 'win_rate': 0,
  210. 'total_pnl': 0,
  211. 'avg_pnl': 0,
  212. 'max_pnl': 0,
  213. 'min_pnl': 0,
  214. 'total_return_pct': 0
  215. }
  216. pnl_list = [t['pnl'] for t in self.trades]
  217. winning = [p for p in pnl_list if p > 0]
  218. losing = [p for p in pnl_list if p < 0]
  219. total_return = (self.equity - self.initial_capital) / self.initial_capital * 100
  220. return {
  221. 'total_trades': len(self.trades),
  222. 'winning_trades': len(winning),
  223. 'losing_trades': len(losing),
  224. 'win_rate': len(winning) / len(self.trades) * 100 if self.trades else 0,
  225. 'total_pnl': sum(pnl_list),
  226. 'avg_pnl': np.mean(pnl_list),
  227. 'max_pnl': max(pnl_list),
  228. 'min_pnl': min(pnl_list),
  229. 'avg_win': np.mean(winning) if winning else 0,
  230. 'avg_loss': np.mean(losing) if losing else 0,
  231. 'profit_factor': abs(sum(winning) / sum(losing)) if sum(losing) != 0 else float('inf'),
  232. 'final_equity': self.equity,
  233. 'total_return_pct': total_return
  234. }
  235. def reset(self):
  236. """重置策略状态"""
  237. self.position = None
  238. self.cash = self.initial_capital
  239. self.equity = self.initial_capital
  240. self.signals = []
  241. self.trades = []