trend_quality_evaluator.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 趋势质量评估器 (Trend Quality Evaluator)
  5. 多因子评分模型:0-100分制,≥60分触发交易
  6. 因子权重:
  7. - ADX趋势强度: 30%
  8. - 均线斜率: 25%
  9. - 波动率收缩: 20%
  10. - 多时间框架共振: 15%
  11. - 成交量确认: 10%
  12. """
  13. import numpy as np
  14. import pandas as pd
  15. import baostock as bs
  16. from dataclasses import dataclass
  17. from typing import Optional, Tuple
  18. import warnings
  19. warnings.filterwarnings('ignore')
  20. @dataclass
  21. class TrendQualityScore:
  22. """趋势质量评分结果"""
  23. total_score: float # 总分 0-100
  24. adx_score: float # ADX得分 0-30
  25. ma_slope_score: float # 均线斜率得分 0-25
  26. volatility_score: float # 波动率得分 0-20
  27. timeframe_score: float # 时间框架得分 0-15
  28. volume_score: float # 成交量得分 0-10
  29. is_tradeable: bool # 是否可交易 (>=60分)
  30. adx_value: float # 原始ADX值
  31. ma_slope: float # 均线斜率
  32. volatility_ratio: float # 波动率比率
  33. volume_ratio: float # 成交量比率
  34. class TrendQualityEvaluator:
  35. """趋势质量评估器"""
  36. def __init__(self):
  37. self.weights = {
  38. 'adx': 30,
  39. 'ma_slope': 25,
  40. 'volatility': 20,
  41. 'timeframe': 15,
  42. 'volume': 10
  43. }
  44. def calculate_adx(self, df: pd.DataFrame, period: int = 14) -> pd.Series:
  45. """计算ADX趋势强度指标"""
  46. high, low, close = df['high'], df['low'], df['close']
  47. # +DM和-DM
  48. plus_dm = high.diff()
  49. minus_dm = low.diff().abs()
  50. plus_dm = plus_dm.where((plus_dm > minus_dm) & (plus_dm > 0), 0)
  51. minus_dm = minus_dm.where((minus_dm > plus_dm) & (minus_dm > 0), 0)
  52. # 真实波幅 TR
  53. tr1 = high - low
  54. tr2 = (high - close.shift()).abs()
  55. tr3 = (low - close.shift()).abs()
  56. tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
  57. # ATR
  58. atr = tr.rolling(period).mean()
  59. # +DI和-DI
  60. plus_di = 100 * (plus_dm.rolling(period).mean() / atr)
  61. minus_di = 100 * (minus_dm.rolling(period).mean() / atr)
  62. # DX和ADX
  63. dx = (abs(plus_di - minus_di) / (plus_di + minus_di + 1e-10)) * 100
  64. adx = dx.rolling(period).mean()
  65. return adx
  66. def evaluate(self, df: pd.DataFrame, df_weekly: Optional[pd.DataFrame] = None) -> TrendQualityScore:
  67. """
  68. 评估趋势质量
  69. Args:
  70. df: 日线数据 DataFrame (需要包含 open, high, low, close, volume)
  71. df_weekly: 周线数据 (可选,用于多时间框架共振)
  72. Returns:
  73. TrendQualityScore: 评分结果
  74. """
  75. latest = df.iloc[-1]
  76. # 1. ADX趋势强度 (30分) - 阈值: ADX > 25
  77. adx = self.calculate_adx(df, 14)
  78. latest_adx = adx.iloc[-1]
  79. # 评分: ADX 0-50映射到0-30分,>25得满分
  80. if latest_adx >= 25:
  81. adx_score = 30
  82. else:
  83. adx_score = min(30, latest_adx * 30 / 25)
  84. # 2. 均线斜率 (25分) - 阈值: MA20/MA20[5] > 1.002
  85. ma20 = df['close'].rolling(20).mean()
  86. ma20_current = ma20.iloc[-1]
  87. ma20_5days_ago = ma20.iloc[-5] if len(ma20) >= 5 else ma20_current
  88. ma_slope = ma20_current / ma20_5days_ago if ma20_5days_ago > 0 else 1
  89. # 评分: 斜率 > 1.002得满分,线性递减
  90. if ma_slope >= 1.005:
  91. ma_slope_score = 25
  92. elif ma_slope >= 1.002:
  93. ma_slope_score = 25 * (ma_slope - 1.002) / (1.005 - 1.002) + 15
  94. elif ma_slope >= 1.0:
  95. ma_slope_score = 15 * (ma_slope - 1.0) / 0.002
  96. else:
  97. ma_slope_score = 0
  98. # 3. 波动率收缩 (20分) - 阈值: ATR(14)/ATR(50) < 0.8
  99. atr14 = self._calculate_atr(df, 14)
  100. atr50 = self._calculate_atr(df, 50)
  101. volatility_ratio = atr14 / atr50 if atr50 > 0 else 1
  102. # 评分: 比率 < 0.8得满分,越小越好
  103. if volatility_ratio <= 0.6:
  104. volatility_score = 20
  105. elif volatility_ratio <= 0.8:
  106. volatility_score = 20 - 10 * (volatility_ratio - 0.6) / 0.2
  107. elif volatility_ratio <= 1.0:
  108. volatility_score = 10 - 10 * (volatility_ratio - 0.8) / 0.2
  109. else:
  110. volatility_score = 0
  111. # 4. 多时间框架共振 (15分) - 日线突破+周线方向一致
  112. timeframe_score = 0
  113. if df_weekly is not None and len(df_weekly) >= 5:
  114. # 日线突破: 收盘价 > MA20
  115. daily_breakout = latest['close'] > ma20_current
  116. # 周线方向: 周线MA5 > MA10
  117. weekly_ma5 = df_weekly['close'].rolling(5).mean().iloc[-1]
  118. weekly_ma10 = df_weekly['close'].rolling(10).mean().iloc[-1]
  119. weekly_aligned = weekly_ma5 > weekly_ma10
  120. if daily_breakout and weekly_aligned:
  121. timeframe_score = 15
  122. elif daily_breakout or weekly_aligned:
  123. timeframe_score = 7.5
  124. else:
  125. # 无周线数据时,仅看日线突破
  126. daily_breakout = latest['close'] > ma20_current
  127. timeframe_score = 15 if daily_breakout else 0
  128. # 5. 成交量确认 (10分) - 突破当日成交量 > 20日均量1.5倍
  129. volume_ma20 = df['volume'].rolling(20).mean().iloc[-1]
  130. volume_ratio = latest['volume'] / volume_ma20 if volume_ma20 > 0 else 1
  131. # 评分: >1.5倍得满分
  132. if volume_ratio >= 2.0:
  133. volume_score = 10
  134. elif volume_ratio >= 1.5:
  135. volume_score = 10 - 5 * (2.0 - volume_ratio) / 0.5
  136. elif volume_ratio >= 1.0:
  137. volume_score = 5 - 5 * (1.5 - volume_ratio) / 0.5
  138. else:
  139. volume_score = 0
  140. # 计算总分
  141. total_score = adx_score + ma_slope_score + volatility_score + timeframe_score + volume_score
  142. return TrendQualityScore(
  143. total_score=round(total_score, 1),
  144. adx_score=round(adx_score, 1),
  145. ma_slope_score=round(ma_slope_score, 1),
  146. volatility_score=round(volatility_score, 1),
  147. timeframe_score=round(timeframe_score, 1),
  148. volume_score=round(volume_score, 1),
  149. is_tradeable=total_score >= 60,
  150. adx_value=round(latest_adx, 2),
  151. ma_slope=round(ma_slope, 4),
  152. volatility_ratio=round(volatility_ratio, 3),
  153. volume_ratio=round(volume_ratio, 2)
  154. )
  155. def _calculate_atr(self, df: pd.DataFrame, period: int) -> float:
  156. """计算ATR"""
  157. high, low, close = df['high'], df['low'], df['close']
  158. tr1 = high - low
  159. tr2 = (high - close.shift()).abs()
  160. tr3 = (low - close.shift()).abs()
  161. tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
  162. return tr.rolling(period).mean().iloc[-1]
  163. def fetch_stock_data(symbol: str, start_date: str, end_date: str, frequency: str = "d") -> Optional[pd.DataFrame]:
  164. """获取股票数据"""
  165. try:
  166. bs.login()
  167. if symbol.startswith('6'):
  168. code = f"sh.{symbol}"
  169. elif symbol.startswith('0') or symbol.startswith('3'):
  170. code = f"sz.{symbol}"
  171. else:
  172. code = symbol
  173. rs = bs.query_history_k_data_plus(
  174. code,
  175. "date,open,high,low,close,volume",
  176. start_date=start_date,
  177. end_date=end_date,
  178. frequency=frequency,
  179. adjustflag="3"
  180. )
  181. data = []
  182. while rs.error_code == '0' and rs.next():
  183. row = rs.get_row_data()
  184. data.append({
  185. 'date': row[0],
  186. 'open': float(row[1]),
  187. 'high': float(row[2]),
  188. 'low': float(row[3]),
  189. 'close': float(row[4]),
  190. 'volume': int(float(row[5]))
  191. })
  192. bs.logout()
  193. if not data:
  194. return None
  195. df = pd.DataFrame(data)
  196. df['date'] = pd.to_datetime(df['date'])
  197. df = df.set_index('date').sort_index()
  198. return df
  199. except Exception as e:
  200. print(f"数据获取失败: {e}")
  201. return None
  202. def main():
  203. """主函数 - 示例用法"""
  204. print("="*70)
  205. print("趋势质量评估器 (Trend Quality Evaluator)")
  206. print("="*70)
  207. # 示例: 评估创业板50
  208. symbol = "399673" # 创业板50
  209. print(f"\n评估标的: 创业板50 ({symbol})")
  210. print("-"*70)
  211. # 获取日线数据
  212. df_daily = fetch_stock_data(symbol, "2024-01-01", "2026-12-31", "d")
  213. if df_daily is None:
  214. print("数据获取失败")
  215. return
  216. # 获取周线数据(用于多时间框架共振)
  217. df_weekly = fetch_stock_data(symbol, "2023-01-01", "2026-12-31", "w")
  218. # 评估趋势质量
  219. evaluator = TrendQualityEvaluator()
  220. score = evaluator.evaluate(df_daily, df_weekly)
  221. # 打印结果
  222. print(f"\n📊 评估日期: {df_daily.index[-1].strftime('%Y-%m-%d')}")
  223. print(f"📈 当前价格: {df_daily['close'].iloc[-1]:.2f}")
  224. print()
  225. print("="*50)
  226. print("评分详情 (满分100分)")
  227. print("="*50)
  228. print(f"{'1. ADX趋势强度 (30分):':<25} {score.adx_score:>6.1f}分 (ADX={score.adx_value:.2f})")
  229. print(f"{'2. 均线斜率 (25分):':<25} {score.ma_slope_score:>6.1f}分 (斜率={score.ma_slope:.4f})")
  230. print(f"{'3. 波动率收缩 (20分):':<25} {score.volatility_score:>6.1f}分 (ATR比={score.volatility_ratio:.3f})")
  231. print(f"{'4. 多时间框架共振 (15分):':<25} {score.timeframe_score:>6.1f}分")
  232. print(f"{'5. 成交量确认 (10分):':<25} {score.volume_score:>6.1f}分 (量比={score.volume_ratio:.2f}x)")
  233. print("-"*50)
  234. print(f"{'总分:':<25} {score.total_score:>6.1f}分")
  235. print("="*50)
  236. # 交易建议
  237. print(f"\n🎯 交易建议:")
  238. if score.is_tradeable:
  239. print(f" ✅ 趋势质量良好 (≥60分),建议交易")
  240. if score.total_score >= 80:
  241. print(f" 💎 优秀趋势!建议重仓")
  242. elif score.total_score >= 70:
  243. print(f" ⭐ 良好趋势!建议中等仓位")
  244. else:
  245. print(f" 📌 及格趋势!建议轻仓试探")
  246. else:
  247. print(f" ❌ 趋势质量不足 (<60分),建议观望")
  248. if score.total_score < 40:
  249. print(f" ⚠️ 趋势混乱,避免交易")
  250. print("\n" + "="*70)
  251. if __name__ == "__main__":
  252. main()