| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294 |
- """
- 双均线交易策略
- 该策略使用两条不同周期的移动平均线(MA)产生交易信号:
- - 短期均线上穿长期均线时,产生买入信号
- - 短期均线下穿长期均线时,产生卖出信号
- """
- import numpy as np
- import pandas as pd
- from typing import Dict, List, Optional, Tuple
- from dataclasses import dataclass
- from enum import Enum
- class SignalType(Enum):
- """信号类型枚举"""
- BUY = 1 # 买入
- SELL = -1 # 卖出
- HOLD = 0 # 持有
- @dataclass
- class TradeSignal:
- """交易信号数据类"""
- timestamp: pd.Timestamp # 时间戳
- signal_type: SignalType # 信号类型
- price: float # 当前价格
- short_ma: float # 短期均线值
- long_ma: float # 长期均线值
- reason: str # 信号原因
- @dataclass
- class Position:
- """持仓信息数据类"""
- quantity: float # 持仓数量
- entry_price: float # 入场价格
- entry_time: pd.Timestamp # 入场时间
- side: int # 方向:1多头,-1空头
- class DualMAStrategy:
- """
- 双均线交易策略类
- 参数:
- short_window: 短期均线窗口期(默认5)
- long_window: 长期均线窗口期(默认20)
- initial_capital: 初始资金(默认100000)
- """
- def __init__(
- self,
- short_window: int = 5,
- long_window: int = 20,
- initial_capital: float = 100000.0
- ):
- # 参数校验
- if short_window >= long_window:
- raise ValueError("短期均线周期必须小于长期均线周期")
- self.short_window = short_window
- self.long_window = long_window
- self.initial_capital = initial_capital
- # 状态变量
- self.position: Optional[Position] = None
- self.cash = initial_capital
- self.equity = initial_capital
- self.signals: List[TradeSignal] = []
- self.trades: List[Dict] = []
- def calculate_ma(self, data: pd.Series, window: int) -> pd.Series:
- """
- 计算简单移动平均线
- 参数:
- data: 价格序列
- window: 均线周期
- 返回:
- 移动平均线序列
- """
- return data.rolling(window=window, min_periods=window).mean()
- def generate_signals(self, df: pd.DataFrame) -> pd.DataFrame:
- """
- 生成交易信号
- 参数:
- df: 包含'close'列的DataFrame
- 返回:
- 添加了均线和信号列的DataFrame
- """
- df = df.copy()
- # 计算双均线
- df['short_ma'] = self.calculate_ma(df['close'], self.short_window)
- df['long_ma'] = self.calculate_ma(df['close'], self.long_window)
- # 初始化信号列
- df['signal'] = 0
- # 计算均线差值
- df['ma_diff'] = df['short_ma'] - df['long_ma']
- # 计算差值的一阶差分(判断穿越方向)
- df['ma_diff_prev'] = df['ma_diff'].shift(1)
- # 金叉:短期均线上穿长期均线
- golden_cross = (df['ma_diff'] > 0) & (df['ma_diff_prev'] <= 0)
- df.loc[golden_cross, 'signal'] = 1
- # 死叉:短期均线下穿长期均线
- death_cross = (df['ma_diff'] < 0) & (df['ma_diff_prev'] >= 0)
- df.loc[death_cross, 'signal'] = -1
- return df
- def on_bar(self, timestamp: pd.Timestamp, row: pd.Series) -> Optional[TradeSignal]:
- """
- 处理每根K线数据
- 参数:
- timestamp: 时间戳
- row: 包含价格数据和信号的Series
- 返回:
- 如果有信号则返回TradeSignal,否则返回None
- """
- signal = None
- current_price = row['close']
- short_ma = row['short_ma']
- long_ma = row['long_ma']
- # 处理买入信号(金叉)
- if row['signal'] == 1:
- # 如果没有持仓,则开多仓
- if self.position is None:
- signal = TradeSignal(
- timestamp=timestamp,
- signal_type=SignalType.BUY,
- price=current_price,
- short_ma=short_ma,
- long_ma=long_ma,
- reason=f"金叉: 短期MA({self.short_window})上穿长期MA({self.long_window})"
- )
- self._open_position(timestamp, current_price, 1)
- # 处理卖出信号(死叉)
- elif row['signal'] == -1:
- # 如果持有多仓,则平仓
- if self.position is not None and self.position.side == 1:
- signal = TradeSignal(
- timestamp=timestamp,
- signal_type=SignalType.SELL,
- price=current_price,
- short_ma=short_ma,
- long_ma=long_ma,
- reason=f"死叉: 短期MA({self.short_window})下穿长期MA({self.long_window})"
- )
- self._close_position(timestamp, current_price)
- # 更新权益
- self._update_equity(current_price)
- if signal:
- self.signals.append(signal)
- return signal
- def _open_position(
- self,
- timestamp: pd.Timestamp,
- price: float,
- side: int
- ):
- """开仓"""
- # 全仓买入
- quantity = (self.cash * 0.99) / price # 预留1%作为手续费缓冲
- self.position = Position(
- quantity=quantity,
- entry_price=price,
- entry_time=timestamp,
- side=side
- )
- cost = quantity * price
- self.cash -= cost
- print(f"[{timestamp}] 开仓 | 方向: {'多' if side == 1 else '空'} | "
- f"价格: {price:.2f} | 数量: {quantity:.4f} | 成本: {cost:.2f}")
- def _close_position(self, timestamp: pd.Timestamp, price: float):
- """平仓"""
- if self.position is None:
- return
- # 计算盈亏
- quantity = self.position.quantity
- entry_price = self.position.entry_price
- if self.position.side == 1:
- # 多头平仓
- pnl = (price - entry_price) * quantity
- pnl_pct = (price / entry_price - 1) * 100
- else:
- # 空头平仓
- pnl = (entry_price - price) * quantity
- pnl_pct = (entry_price / price - 1) * 100
- # 回收资金
- proceeds = quantity * price
- self.cash += proceeds
- # 记录交易
- trade = {
- 'entry_time': self.position.entry_time,
- 'exit_time': timestamp,
- 'entry_price': entry_price,
- 'exit_price': price,
- 'quantity': quantity,
- 'side': self.position.side,
- 'pnl': pnl,
- 'pnl_pct': pnl_pct,
- 'holding_periods': (timestamp - self.position.entry_time).days
- }
- self.trades.append(trade)
- print(f"[{timestamp}] 平仓 | 价格: {price:.2f} | "
- f"盈亏: {pnl:+.2f} ({pnl_pct:+.2f}%) | "
- f"持仓周期: {trade['holding_periods']}天")
- self.position = None
- def _update_equity(self, current_price: float):
- """更新账户权益"""
- position_value = 0
- if self.position is not None:
- position_value = self.position.quantity * current_price
- self.equity = self.cash + position_value
- def get_performance_summary(self) -> Dict:
- """
- 获取策略绩效汇总
- 返回:
- 包含各项绩效指标的字典
- """
- if not self.trades:
- return {
- 'total_trades': 0,
- 'winning_trades': 0,
- 'losing_trades': 0,
- 'win_rate': 0,
- 'total_pnl': 0,
- 'avg_pnl': 0,
- 'max_pnl': 0,
- 'min_pnl': 0,
- 'total_return_pct': 0
- }
- pnl_list = [t['pnl'] for t in self.trades]
- winning = [p for p in pnl_list if p > 0]
- losing = [p for p in pnl_list if p < 0]
- total_return = (self.equity - self.initial_capital) / self.initial_capital * 100
- return {
- 'total_trades': len(self.trades),
- 'winning_trades': len(winning),
- 'losing_trades': len(losing),
- 'win_rate': len(winning) / len(self.trades) * 100 if self.trades else 0,
- 'total_pnl': sum(pnl_list),
- 'avg_pnl': np.mean(pnl_list),
- 'max_pnl': max(pnl_list),
- 'min_pnl': min(pnl_list),
- 'avg_win': np.mean(winning) if winning else 0,
- 'avg_loss': np.mean(losing) if losing else 0,
- 'profit_factor': abs(sum(winning) / sum(losing)) if sum(losing) != 0 else float('inf'),
- 'final_equity': self.equity,
- 'total_return_pct': total_return
- }
- def reset(self):
- """重置策略状态"""
- self.position = None
- self.cash = self.initial_capital
- self.equity = self.initial_capital
- self.signals = []
- self.trades = []
|