""" 双均线交易策略 该策略使用两条不同周期的移动平均线(MA)产生交易信号: - 短期均线上穿长期均线时,产生买入信号 - 短期均线下穿长期均线时,产生卖出信号 """ import numpy as np import pandas as pd from typing import Dict, List, Optional, Tuple from dataclasses import dataclass from enum import Enum class SignalType(Enum): """信号类型枚举""" BUY = 1 # 买入 SELL = -1 # 卖出 HOLD = 0 # 持有 @dataclass class TradeSignal: """交易信号数据类""" timestamp: pd.Timestamp # 时间戳 signal_type: SignalType # 信号类型 price: float # 当前价格 short_ma: float # 短期均线值 long_ma: float # 长期均线值 reason: str # 信号原因 @dataclass class Position: """持仓信息数据类""" quantity: float # 持仓数量 entry_price: float # 入场价格 entry_time: pd.Timestamp # 入场时间 side: int # 方向:1多头,-1空头 class DualMAStrategy: """ 双均线交易策略类 参数: short_window: 短期均线窗口期(默认5) long_window: 长期均线窗口期(默认20) initial_capital: 初始资金(默认100000) """ def __init__( self, short_window: int = 5, long_window: int = 20, initial_capital: float = 100000.0 ): # 参数校验 if short_window >= long_window: raise ValueError("短期均线周期必须小于长期均线周期") self.short_window = short_window self.long_window = long_window self.initial_capital = initial_capital # 状态变量 self.position: Optional[Position] = None self.cash = initial_capital self.equity = initial_capital self.signals: List[TradeSignal] = [] self.trades: List[Dict] = [] def calculate_ma(self, data: pd.Series, window: int) -> pd.Series: """ 计算简单移动平均线 参数: data: 价格序列 window: 均线周期 返回: 移动平均线序列 """ return data.rolling(window=window, min_periods=window).mean() def generate_signals(self, df: pd.DataFrame) -> pd.DataFrame: """ 生成交易信号 参数: df: 包含'close'列的DataFrame 返回: 添加了均线和信号列的DataFrame """ df = df.copy() # 计算双均线 df['short_ma'] = self.calculate_ma(df['close'], self.short_window) df['long_ma'] = self.calculate_ma(df['close'], self.long_window) # 初始化信号列 df['signal'] = 0 # 计算均线差值 df['ma_diff'] = df['short_ma'] - df['long_ma'] # 计算差值的一阶差分(判断穿越方向) df['ma_diff_prev'] = df['ma_diff'].shift(1) # 金叉:短期均线上穿长期均线 golden_cross = (df['ma_diff'] > 0) & (df['ma_diff_prev'] <= 0) df.loc[golden_cross, 'signal'] = 1 # 死叉:短期均线下穿长期均线 death_cross = (df['ma_diff'] < 0) & (df['ma_diff_prev'] >= 0) df.loc[death_cross, 'signal'] = -1 return df def on_bar(self, timestamp: pd.Timestamp, row: pd.Series) -> Optional[TradeSignal]: """ 处理每根K线数据 参数: timestamp: 时间戳 row: 包含价格数据和信号的Series 返回: 如果有信号则返回TradeSignal,否则返回None """ signal = None current_price = row['close'] short_ma = row['short_ma'] long_ma = row['long_ma'] # 处理买入信号(金叉) if row['signal'] == 1: # 如果没有持仓,则开多仓 if self.position is None: signal = TradeSignal( timestamp=timestamp, signal_type=SignalType.BUY, price=current_price, short_ma=short_ma, long_ma=long_ma, reason=f"金叉: 短期MA({self.short_window})上穿长期MA({self.long_window})" ) self._open_position(timestamp, current_price, 1) # 处理卖出信号(死叉) elif row['signal'] == -1: # 如果持有多仓,则平仓 if self.position is not None and self.position.side == 1: signal = TradeSignal( timestamp=timestamp, signal_type=SignalType.SELL, price=current_price, short_ma=short_ma, long_ma=long_ma, reason=f"死叉: 短期MA({self.short_window})下穿长期MA({self.long_window})" ) self._close_position(timestamp, current_price) # 更新权益 self._update_equity(current_price) if signal: self.signals.append(signal) return signal def _open_position( self, timestamp: pd.Timestamp, price: float, side: int ): """开仓""" # 全仓买入 quantity = (self.cash * 0.99) / price # 预留1%作为手续费缓冲 self.position = Position( quantity=quantity, entry_price=price, entry_time=timestamp, side=side ) cost = quantity * price self.cash -= cost print(f"[{timestamp}] 开仓 | 方向: {'多' if side == 1 else '空'} | " f"价格: {price:.2f} | 数量: {quantity:.4f} | 成本: {cost:.2f}") def _close_position(self, timestamp: pd.Timestamp, price: float): """平仓""" if self.position is None: return # 计算盈亏 quantity = self.position.quantity entry_price = self.position.entry_price if self.position.side == 1: # 多头平仓 pnl = (price - entry_price) * quantity pnl_pct = (price / entry_price - 1) * 100 else: # 空头平仓 pnl = (entry_price - price) * quantity pnl_pct = (entry_price / price - 1) * 100 # 回收资金 proceeds = quantity * price self.cash += proceeds # 记录交易 trade = { 'entry_time': self.position.entry_time, 'exit_time': timestamp, 'entry_price': entry_price, 'exit_price': price, 'quantity': quantity, 'side': self.position.side, 'pnl': pnl, 'pnl_pct': pnl_pct, 'holding_periods': (timestamp - self.position.entry_time).days } self.trades.append(trade) print(f"[{timestamp}] 平仓 | 价格: {price:.2f} | " f"盈亏: {pnl:+.2f} ({pnl_pct:+.2f}%) | " f"持仓周期: {trade['holding_periods']}天") self.position = None def _update_equity(self, current_price: float): """更新账户权益""" position_value = 0 if self.position is not None: position_value = self.position.quantity * current_price self.equity = self.cash + position_value def get_performance_summary(self) -> Dict: """ 获取策略绩效汇总 返回: 包含各项绩效指标的字典 """ if not self.trades: return { 'total_trades': 0, 'winning_trades': 0, 'losing_trades': 0, 'win_rate': 0, 'total_pnl': 0, 'avg_pnl': 0, 'max_pnl': 0, 'min_pnl': 0, 'total_return_pct': 0 } pnl_list = [t['pnl'] for t in self.trades] winning = [p for p in pnl_list if p > 0] losing = [p for p in pnl_list if p < 0] total_return = (self.equity - self.initial_capital) / self.initial_capital * 100 return { 'total_trades': len(self.trades), 'winning_trades': len(winning), 'losing_trades': len(losing), 'win_rate': len(winning) / len(self.trades) * 100 if self.trades else 0, 'total_pnl': sum(pnl_list), 'avg_pnl': np.mean(pnl_list), 'max_pnl': max(pnl_list), 'min_pnl': min(pnl_list), 'avg_win': np.mean(winning) if winning else 0, 'avg_loss': np.mean(losing) if losing else 0, 'profit_factor': abs(sum(winning) / sum(losing)) if sum(losing) != 0 else float('inf'), 'final_equity': self.equity, 'total_return_pct': total_return } def reset(self): """重置策略状态""" self.position = None self.cash = self.initial_capital self.equity = self.initial_capital self.signals = [] self.trades = []