# [Python] 复制代码  -- "copy code" label left over from a forum paste (not code)
# -*- coding: utf-8 -*-
from __future__ import annotations # Py3.7 兼容,如果不喜欢可以删掉
from typing import Optional, Tuple, Union, List, Dict
import numpy as np
import pandas as pd
class OrderFlow(object):
    """
    Wrap raw (N, 2) order-flow numpy data in a helper with per-level metadata
    and filtering support.

    Raw ``data`` layout:
        data.shape = (N, 2), dtype=object
        data[i, 0] = ts      timestamp such as 20251121185900 (YYYYmmddHHMMSS)
        data[i, 1] = levels  ((price, buy, sell), ...), expected to be sorted
                             from lowest to highest price (order is used as-is)

    Internal ``df`` layout (one row per price level):
        row_idx        : row number in the raw data (which bar / K)
        ts             : timestamp
        ts_dt          : optional parsed datetime (parse_ts_to_datetime=True)
        pos            : position inside the bar, 0 = first level (lowest
                         price when levels are sorted ascending), 1 = next, ...
        price          : price
        buy            : aggressive buy volume
        sell           : aggressive sell volume
        total_vol      : buy + sell
        sell_ratio_up  : sell-side base ratio = sell[i] / buy[i+1], NaN if invalid
        buy_ratio_down : buy-side base ratio  = buy[i] / sell[i-1], NaN if invalid

    Extra:
        k_summary : per-bar summary DataFrame (see ``_build_k_summary``); None
                    until first needed when precompute_summary=False.
    """

    # Column order produced by _build_k_summary; also gives an empty summary
    # the same schema.
    _SUMMARY_COLUMNS = (
        "row_idx", "ts",
        "total_buy", "total_sell", "level_count", "min_price", "max_price",
        "total_net",
        "max_buy_pos", "max_buy_price", "max_buy_value",
        "max_sell_pos", "max_sell_price", "max_sell_value",
        "max_total_pos", "max_total_price", "max_total_value",
    )

    # ==================== init & preprocessing ====================
    def __init__(
        self,
        data: np.ndarray,
        parse_ts_to_datetime: bool = False,
        precompute_summary: bool = True,
        precompute_ratios: bool = True,
    ):
        """
        Args:
            data: raw (N, 2) object array, see the class docstring.
            parse_ts_to_datetime: add a ``ts_dt`` datetime column.
            precompute_summary: build ``k_summary`` now; otherwise it is built
                lazily on first use.
            precompute_ratios: add ``sell_ratio_up`` / ``buy_ratio_down``.
                ``total_vol`` is added either way.
        """
        self.df = self._orderflow_to_df(data)
        if parse_ts_to_datetime:
            self._add_ts_datetime()
        if precompute_ratios:
            self._precompute_total_and_ratios()
        else:
            # total_vol is cheap and widely used, so always provide it.
            self.df["total_vol"] = self.df["buy"] + self.df["sell"]
        self.k_summary: Optional[pd.DataFrame] = None
        if precompute_summary:
            self.k_summary = self._build_k_summary()

    @staticmethod
    def _orderflow_to_df(data: np.ndarray) -> pd.DataFrame:
        """
        Flatten the raw (N, 2) array into one row per price level.

        Columns: [row_idx, ts, pos, price, buy, sell]. ``pos`` follows the
        order of the raw ``levels`` tuple (0 = first level).

        Explicit dtypes are applied so that empty input still produces numeric
        columns (pandas would otherwise create object columns).
        """
        rows = []
        for row_idx, row in enumerate(data):
            ts = int(row[0])
            levels = row[1]  # ((price, buy, sell), ...), expected low -> high
            for pos, (price, buy, sell) in enumerate(levels):
                rows.append((row_idx, ts, pos, float(price), float(buy), float(sell)))
        df = pd.DataFrame(rows, columns=["row_idx", "ts", "pos", "price", "buy", "sell"])
        return df.astype({
            "row_idx": "int64",
            "ts": "int64",
            "pos": "int64",
            "price": "float64",
            "buy": "float64",
            "sell": "float64",
        })

    def _add_ts_datetime(self) -> None:
        """
        Add a ``ts_dt`` column by parsing ``ts`` with format %Y%m%d%H%M%S.

        Unparseable values become NaT (errors="coerce"). Does nothing if the
        column already exists.
        """
        if "ts_dt" in self.df.columns:
            return
        self.df["ts_dt"] = pd.to_datetime(
            self.df["ts"].astype(str),
            format="%Y%m%d%H%M%S",
            errors="coerce",
        )

    def _precompute_total_and_ratios(self) -> None:
        """
        Precompute, per bar:
          - total_vol      = buy + sell
          - sell_ratio_up  = sell[i] / buy[i+1]  (sell-side base imbalance)
          - buy_ratio_down = buy[i] / sell[i-1]  (buy-side base imbalance)

        A ratio is NaN where the neighbouring level does not exist (edge of the
        bar) or the denominator is <= 0.
        """
        # Sort so that levels of one bar are contiguous and ordered by pos.
        df = self.df.sort_values(["row_idx", "pos"]).copy()
        df["total_vol"] = df["buy"] + df["sell"]

        # Shift inside each bar only, so neighbours never leak across bars.
        buy_next = df.groupby("row_idx")["buy"].shift(-1)   # one level higher
        sell_prev = df.groupby("row_idx")["sell"].shift(1)  # one level lower

        # `den > 0` is False for NaN, so missing neighbours become NaN too.
        df["sell_ratio_up"] = (df["sell"] / buy_next).where(buy_next > 0)
        df["buy_ratio_down"] = (df["buy"] / sell_prev).where(sell_prev > 0)
        self.df = df

    # ==================== bar-level metadata ====================
    def _build_k_summary(self) -> pd.DataFrame:
        """
        Build one summary row per bar (grouped by row_idx and ts) with:
          - total_buy / total_sell / total_net
          - level_count / min_price / max_price
          - max_buy_pos / max_buy_price / max_buy_value
          - max_sell_pos / max_sell_price / max_sell_value
          - max_total_pos / max_total_price / max_total_value

        Peaks use ``idxmax``, so on ties the first level (lowest pos) wins.
        Adds ``total_vol`` to ``self.df`` if it is missing. An empty ``df``
        yields an empty frame with the same columns.
        """
        df = self.df
        if "total_vol" not in df.columns:
            df = df.copy()
            df["total_vol"] = df["buy"] + df["sell"]
            self.df = df
        if df.empty:
            return pd.DataFrame(columns=list(self._SUMMARY_COLUMNS))

        keys = ["row_idx", "ts"]
        summary = df.groupby(keys, as_index=False).agg(
            total_buy=("buy", "sum"),
            total_sell=("sell", "sum"),
            level_count=("pos", "count"),
            min_price=("price", "min"),
            max_price=("price", "max"),
        )
        summary["total_net"] = summary["total_buy"] - summary["total_sell"]

        by_bar = df.groupby(keys)
        # For each measure, take the row holding its per-bar maximum and attach
        # that row's pos / price / value under a measure-specific prefix.
        for value_col, prefix in (
            ("buy", "max_buy"),
            ("sell", "max_sell"),
            ("total_vol", "max_total"),
        ):
            peak_idx = by_bar[value_col].idxmax()
            peak_rows = df.loc[peak_idx, keys + ["pos", "price", value_col]].rename(columns={
                "pos": f"{prefix}_pos",
                "price": f"{prefix}_price",
                value_col: f"{prefix}_value",
            })
            summary = summary.merge(peak_rows, on=keys, how="left")
        return summary

    def _ensure_k_summary(self) -> None:
        """Build ``k_summary`` lazily if it has not been built yet."""
        if self.k_summary is None:
            self.k_summary = self._build_k_summary()

    def get_k_summary(
        self,
        ts: Optional[int] = None,
        row_idx: Optional[int] = None,
    ) -> Optional[dict]:
        """
        Return the summary of a single bar as a plain dict.

        Args:
            ts: timestamp of the bar.
            row_idx: raw-data row of the bar.
            At least one is required. When both are given, the bar must
            match both.

        Returns:
            dict (keys below) or None if no bar matches. If several bars match
            (e.g. duplicate ts), the first one is returned.

        Raises:
            ValueError: if neither ts nor row_idx is given.
        """
        if ts is None and row_idx is None:
            raise ValueError("get_k_summary 需要 ts 或 row_idx 至少一个参数。")
        self._ensure_k_summary()
        ks = self.k_summary

        hit = pd.Series(True, index=ks.index)
        if ts is not None:
            hit &= ks["ts"] == int(ts)
        if row_idx is not None:
            hit &= ks["row_idx"] == int(row_idx)
        sub = ks[hit]
        if sub.empty:
            return None
        row = sub.iloc[0]

        # (key, converter) pairs in output order.
        fields = (
            ("row_idx", int),            # row of this bar in the raw data
            ("ts", int),                 # bar timestamp
            ("total_buy", float),        # total aggressive buy volume
            ("total_sell", float),       # total aggressive sell volume
            ("total_net", float),        # buy - sell
            ("level_count", int),        # number of price levels
            ("min_price", float),        # lowest price
            ("max_price", float),        # highest price
            ("max_buy_pos", int),        # pos of the largest buy (0 = first level)
            ("max_buy_price", float),    # price of the largest buy
            ("max_buy_value", float),    # largest buy volume
            ("max_sell_pos", int),       # pos of the largest sell
            ("max_sell_price", float),   # price of the largest sell
            ("max_sell_value", float),   # largest sell volume
            ("max_total_pos", int),      # pos of the largest buy+sell
            ("max_total_price", float),  # price of the largest buy+sell
            ("max_total_value", float),  # largest buy+sell volume
        )
        return {key: cast(row[key]) for key, cast in fields}

    # ---- small convenience accessors (optional) ----
    def k_total_volume(
        self,
        ts: Optional[int] = None,
        row_idx: Optional[int] = None,
    ) -> Optional[Tuple[float, float]]:
        """Return (total_buy, total_sell) of one bar, or None if not found."""
        info = self.get_k_summary(ts=ts, row_idx=row_idx)
        if info is None:
            return None
        return info["total_buy"], info["total_sell"]

    def k_level_count(
        self,
        ts: Optional[int] = None,
        row_idx: Optional[int] = None,
    ) -> Optional[int]:
        """Return the number of price levels of one bar, or None if not found."""
        info = self.get_k_summary(ts=ts, row_idx=row_idx)
        if info is None:
            return None
        return info["level_count"]

    def k_max_total_volume_price(
        self,
        ts: Optional[int] = None,
        row_idx: Optional[int] = None,
    ) -> Optional[dict]:
        """
        Describe the level with the largest buy+sell volume in one bar.

        Returns a dict with:
            pos    : position of that level in the bar
            price  : price of that level
            buy    : the bar-wide largest buy volume (max_buy_value)
            sell   : the bar-wide largest sell volume (max_sell_value)
            volume : buy+sell volume of that level (max_total_value)

        NOTE: ``buy`` / ``sell`` are bar-wide maxima. They may sit at other
        price levels, so ``buy + sell`` need not equal ``volume``.

        Returns None if the bar is not found.
        """
        info = self.get_k_summary(ts=ts, row_idx=row_idx)
        if info is None:
            return None
        return {
            "pos": info["max_total_pos"],
            "price": info["max_total_price"],
            "buy": info["max_buy_value"],
            "sell": info["max_sell_value"],
            "volume": info["max_total_value"],
        }

    # ==================== selector layer ====================
    def select(
        self,
        ts: Optional[int] = None,                            # single timestamp
        ts_range: Optional[Tuple[int, int]] = None,          # closed ts interval
        ts_in: Optional[List[int]] = None,                   # set of timestamps
        row_idx: Optional[int] = None,                       # single bar index
        row_idx_range: Optional[Tuple[int, int]] = None,     # closed bar-index interval
        row_idx_in: Optional[List[int]] = None,              # set of bar indices
        price_range: Optional[Tuple[float, float]] = None,   # closed price interval
        price_slice: Optional[slice] = None,                 # half-open price interval
        price_in: Optional[List[float]] = None,              # set of prices
        pos_range: Optional[Tuple[int, int]] = None,         # closed pos interval
        pos_slice: Optional[slice] = None,                   # half-open pos interval
        min_buy: Optional[float] = None,                     # buy threshold
        min_sell: Optional[float] = None,                    # sell threshold
        min_total: Optional[float] = None,                   # buy+sell threshold
    ) -> "OrderFlowView":
        """
        Generic filter entry point. All given criteria are AND-ed together.
        Returns an OrderFlowView over a copy of the matching rows.

        Dimensions (any combination):
          - time:
              ts        : single timestamp
              ts_range  : (start_ts, end_ts) closed interval
              ts_in     : set of timestamps (ignored when empty)
          - bar index:
              row_idx       : single index
              row_idx_range : (i_min, i_max) closed interval
              row_idx_in    : set of indices (ignored when empty)
          - price:
              price_range : (p_min, p_max) closed interval
              price_slice : slice(p_min, p_max), half-open [p_min, p_max)
              price_in    : set of prices, exact float match (ignored when empty)
          - position inside each bar:
              pos_range : (i_min, i_max) closed interval
              pos_slice : slice(i_min, i_max), half-open [i_min, i_max)
          - volume thresholds:
              min_buy   : buy >= min_buy
              min_sell  : sell >= min_sell
              min_total : (buy + sell) >= min_total

        A None bound in a range/slice means "unbounded on that side";
        ``slice.step`` is ignored.
        """
        df = self.df
        mask = pd.Series(True, index=df.index)

        def in_bounds(col: str, lo, hi, cast, upper_inclusive: bool) -> pd.Series:
            # Lower bound always inclusive; upper bound inclusive for ranges,
            # exclusive for slices. None means no bound on that side.
            cond = pd.Series(True, index=df.index)
            if lo is not None:
                cond &= df[col] >= cast(lo)
            if hi is not None:
                cond &= (df[col] <= cast(hi)) if upper_inclusive else (df[col] < cast(hi))
            return cond

        # time
        if ts is not None:
            mask &= df["ts"] == int(ts)
        if ts_range is not None:
            t_min, t_max = ts_range
            mask &= in_bounds("ts", t_min, t_max, int, True)
        if ts_in is not None and len(ts_in) > 0:
            mask &= df["ts"].isin([int(x) for x in ts_in])

        # bar index
        if row_idx is not None:
            mask &= df["row_idx"] == int(row_idx)
        if row_idx_range is not None:
            i_min, i_max = row_idx_range
            mask &= in_bounds("row_idx", i_min, i_max, int, True)
        if row_idx_in is not None and len(row_idx_in) > 0:
            mask &= df["row_idx"].isin([int(x) for x in row_idx_in])

        # price
        if price_range is not None:
            p_min, p_max = price_range
            mask &= in_bounds("price", p_min, p_max, float, True)
        if price_slice is not None:
            mask &= in_bounds("price", price_slice.start, price_slice.stop, float, False)
        if price_in is not None and len(price_in) > 0:
            mask &= df["price"].isin([float(p) for p in price_in])

        # position
        if pos_range is not None:
            i_min, i_max = pos_range
            mask &= in_bounds("pos", i_min, i_max, int, True)
        if pos_slice is not None:
            mask &= in_bounds("pos", pos_slice.start, pos_slice.stop, int, False)

        # volume thresholds
        if min_buy is not None:
            mask &= df["buy"] >= float(min_buy)
        if min_sell is not None:
            mask &= df["sell"] >= float(min_sell)
        if min_total is not None:
            mask &= (df["buy"] + df["sell"]) >= float(min_total)

        return OrderFlowView(self, df[mask].copy())

    # ---- helper: contiguous True runs ----
    @staticmethod
    def _extract_runs(mask: np.ndarray, min_run_len: int) -> List[Tuple[int, int]]:
        """
        Extract the maximal runs of consecutive True values in ``mask`` whose
        length is >= ``min_run_len``. Values below 1 are treated as 1.

        Returns:
            [(start_idx, end_idx), ...] as closed intervals, in order.
        """
        min_len = max(int(min_run_len), 1)
        runs: List[Tuple[int, int]] = []
        run_start: Optional[int] = None
        for i, flag in enumerate(mask):
            if flag:
                if run_start is None:
                    run_start = i
            elif run_start is not None:
                if i - run_start >= min_len:
                    runs.append((run_start, i - 1))
                run_start = None
        # A run may extend to the final element.
        if run_start is not None and len(mask) - run_start >= min_len:
            runs.append((run_start, len(mask) - 1))
        return runs
class OrderFlowView(object):
    """
    Filtered view over an OrderFlow. Holds a sub-DataFrame and a reference to
    the parent object.

    Usage:
        of = OrderFlow(data)
        view = of.select(ts_range=(..., ...), min_total=100)
        # 1) cross-bar volume totals of the filtered rows
        total_buy, total_sell, total_both = view.sum_volume(side="both")
        # 2) per-ts imbalance stacking statistics
        imb = view.imbalance_segments(
            threshold_sell=2.0,
            threshold_buy=2.0,
            min_sell=0.0,
            min_buy=0.0,
            run_len_sell=2,
            run_len_buy=2,
        )
    """

    def __init__(self, parent: "OrderFlow", df: pd.DataFrame):
        self.parent = parent
        self.df = df

    def get_df(self, copy: bool = True) -> pd.DataFrame:
        """Return the view's rows, as a copy by default."""
        return self.df.copy() if copy else self.df

    # ==================== 1. cross-bar volume totals ====================
    def sum_volume(self, side: str = "both") -> Tuple[float, float, float]:
        """
        Sum volumes over every row in this view.

        ``side`` is accepted for readability only and does not change the
        result. Always returns (total_buy, total_sell, total_buy + total_sell).
        """
        total_buy = float(self.df["buy"].sum())
        total_sell = float(self.df["sell"].sum())
        return total_buy, total_sell, total_buy + total_sell

    # ==================== 2. per-bar imbalance stacking ====================
    def imbalance_segments(
        self,
        threshold_sell: Optional[float] = None,
        threshold_buy: Optional[float] = None,
        min_sell: float = 0.0,
        min_buy: float = 0.0,
        run_len_sell: int = 1,
        run_len_buy: int = 1,
    ) -> Dict[int, Dict[str, List[Dict[str, object]]]]:
        """
        Find imbalance stacks independently for every ts present in this view.

        Scope:
          - The view only decides WHICH ts are processed. For each ts, the
            bar's full ladder of levels is taken from ``parent.df``, so
            price/pos filters applied to the view do not trim the ladder.
          - Results are keyed by ts. ts is assumed to identify a single bar. If
            several bars share a ts, their levels are combined (NOTE(review):
            confirm ts is unique per bar).

        Base ratios (taken from ``parent.df`` when precomputed, otherwise
        computed here per bar):
          - sell side: sell_ratio_up  = sell[i] / buy[i+1]
          - buy side : buy_ratio_down = buy[i] / sell[i-1]
          A ratio is NaN when the neighbour is missing or its volume is <= 0.

        Thresholds:
          - threshold_sell: sell ratio threshold; None skips the sell side
          - threshold_buy : buy ratio threshold; None skips the buy side
          - min_sell      : a level counts on the sell side only if sell >= min_sell
          - min_buy       : a level counts on the buy side only if buy >= min_buy

        Stack length:
          - run_len_* >= 2: maximal runs of at least run_len_* consecutive
            qualifying levels form one segment.
          - run_len_* <= 1: every qualifying level is its own length-1 segment
            (adjacent levels are not merged).

        Returns:
            {ts: {"sell": [segment, ...], "buy": [segment, ...]}, ...}
            where each segment is
            {
                "start_pos": int, "end_pos": int, "length": int,
                "start_price": float, "end_price": float,
                "price_list": [float, ...],
                "ratio_list": [float, ...],   # ratios of that side
                "volume_list": [float, ...],  # sell volumes (sell side) / buy volumes (buy side)
            }
        """
        parent_df = self.parent.df
        result: Dict[int, Dict[str, List[Dict[str, object]]]] = {}

        ts_values = sorted(set(self.df["ts"].tolist()))
        if not ts_values:
            return result

        # Slice the parent once and group by ts (O(N)) instead of rescanning the
        # whole parent frame for every ts.
        in_view = parent_df[parent_df["ts"].isin(ts_values)]
        snapshots = {key: grp for key, grp in in_view.groupby("ts", sort=False)}

        for ts_val in ts_values:
            snap = snapshots.get(ts_val)
            if snap is None or len(snap) < 2:
                result[ts_val] = {"sell": [], "buy": []}
                continue
            snap = snap.sort_values("pos").reset_index(drop=True)

            sell_runs: List[Dict[str, object]] = []
            buy_runs: List[Dict[str, object]] = []

            # ---------- sell-side stacking ----------
            if threshold_sell is not None:
                ratio = self._snapshot_ratio(snap, "sell_ratio_up", "sell", "buy", -1)
                mask = (
                    ratio.notna()
                    & (ratio >= float(threshold_sell))
                    & (snap["sell"] >= float(min_sell))
                )
                sell_runs = self._collect_segments(mask, run_len_sell, snap, ratio, "sell")

            # ---------- buy-side stacking ----------
            if threshold_buy is not None:
                ratio = self._snapshot_ratio(snap, "buy_ratio_down", "buy", "sell", 1)
                mask = (
                    ratio.notna()
                    & (ratio >= float(threshold_buy))
                    & (snap["buy"] >= float(min_buy))
                )
                buy_runs = self._collect_segments(mask, run_len_buy, snap, ratio, "buy")

            result[ts_val] = {
                "sell": sell_runs,
                "buy": buy_runs,
            }
        return result

    @staticmethod
    def _snapshot_ratio(
        snap: pd.DataFrame,
        name: str,
        num_col: str,
        den_col: str,
        shift: int,
    ) -> pd.Series:
        """
        Return column ``name`` of ``snap`` when present. Otherwise compute
        num_col[i] / den_col[i - shift] inside each bar (row_idx), with NaN
        where the neighbour is missing or its volume is <= 0.
        """
        if name in snap.columns:
            return snap[name]
        den = snap.groupby("row_idx")[den_col].shift(shift)
        return (snap[num_col] / den).where(den > 0)

    @staticmethod
    def _collect_segments(
        mask: pd.Series,
        run_len: int,
        snap: pd.DataFrame,
        ratio: pd.Series,
        vol_col: str,
    ) -> List[Dict[str, object]]:
        """Turn a per-level boolean mask of one bar into segment dicts."""
        flags = mask.to_numpy()
        if run_len >= 2:
            spans = OrderFlow._extract_runs(flags, run_len)
        else:
            # Each qualifying level is its own length-1 segment.
            spans = [(i, i) for i, flag in enumerate(flags) if flag]

        pos_s = snap["pos"]
        price_s = snap["price"]
        vol_s = snap[vol_col]
        segments: List[Dict[str, object]] = []
        for start, end in spans:
            block = slice(start, end + 1)
            segments.append({
                "start_pos": int(pos_s.iloc[start]),
                "end_pos": int(pos_s.iloc[end]),
                "length": int(end - start + 1),
                "start_price": float(price_s.iloc[start]),
                "end_price": float(price_s.iloc[end]),
                "price_list": price_s.iloc[block].astype(float).tolist(),
                "ratio_list": ratio.iloc[block].astype(float).tolist(),
                "volume_list": vol_s.iloc[block].astype(float).tolist(),
            })
        return segments

    # ==================== segment selection helpers ====================
    @staticmethod
    def _iter_side_segments(ts_data: Dict[str, List[Dict[str, object]]], side: str):
        """Yield (side_name, segment) for the requested side(s), sell before buy."""
        for name in ("sell", "buy"):
            if side in (name, "both"):
                for seg in ts_data.get(name, []):
                    yield name, seg

    @staticmethod
    def _segment_total_volume(seg: Dict[str, object]) -> float:
        """Sum of a segment's ``volume_list`` (0.0 when missing)."""
        return float(sum(float(v) for v in seg.get("volume_list", [])))

    @staticmethod
    def _pick_best(
        imb_result: Dict[int, Dict[str, List[Dict[str, object]]]],
        side: str,
        score,
    ) -> Optional[Tuple[float, int, str, Dict[str, object]]]:
        """
        Return (score, ts, side, segment) of the highest-scoring segment, or
        None if there are no candidates.

        Ties: the larger ts wins. Within the same ts, the earliest segment
        wins (sell before buy). NaN scores are skipped.
        """
        best: Optional[Tuple[float, int, str, Dict[str, object]]] = None
        for ts_val, ts_data in imb_result.items():
            ts_int = int(ts_val)
            for seg_side, seg in OrderFlowView._iter_side_segments(ts_data, side):
                value = score(seg)
                if pd.isna(value):
                    continue
                if (
                    best is None
                    or value > best[0]
                    or (value == best[0] and ts_int > best[1])
                ):
                    best = (value, ts_int, seg_side, seg)
        return best

    @staticmethod
    def longest_segment_for_ts(
        imb_result: Dict[int, Dict[str, List[Dict[str, object]]]],
        ts: int,
        side: str = "both",
    ) -> Optional[Dict[str, object]]:
        """
        Longest imbalance segment of a single ts in an imbalance_segments() result.

        Args:
            imb_result: return value of imbalance_segments(...)
            ts: target timestamp
            side: "sell" (sell side only), "buy" (buy side only), or
                  "both" (both sides; the longer wins, first one on ties)

        Returns:
            A copy of the segment plus "ts" and "side" ("sell" / "buy"), or
            None when ts is absent or the requested side has no segments.
        """
        if ts not in imb_result:
            return None
        best = OrderFlowView._pick_best(
            {ts: imb_result[ts]}, side.lower(), lambda seg: int(seg.get("length", 0))
        )
        if best is None:
            return None
        _, _, best_side, best_seg = best
        result = dict(best_seg)  # copy so callers cannot mutate imb_result
        result["ts"] = int(ts)
        result["side"] = best_side
        return result

    @staticmethod
    def longest_segment_global(
        imb_result: Dict[int, Dict[str, List[Dict[str, object]]]],
        side: str = "both",
    ) -> Optional[Dict[str, object]]:
        """
        Longest imbalance segment across all ts.

        side: "sell" / "buy" / "both". Ties on length go to the larger ts. Within
        one ts, the earliest segment wins.
        Returns a copy of the segment with "ts", "side" and "length" set, or None.
        """
        best = OrderFlowView._pick_best(
            imb_result, side.lower(), lambda seg: int(seg.get("length", 0))
        )
        if best is None:
            return None
        best_length, best_ts, best_side, best_seg = best
        result = dict(best_seg)
        result["ts"] = int(best_ts)
        result["side"] = best_side
        result["length"] = int(best_length)
        return result

    @staticmethod
    def max_volume_segment_for_ts(
        imb_result: Dict[int, Dict[str, List[Dict[str, object]]]],
        ts: int,
        side: str = "both",
    ) -> Optional[Dict[str, object]]:
        """
        Segment of a single ts with the largest summed ``volume_list``.

        side: "sell" / "buy" / "both" (larger total wins, first one on ties).
        Returns a copy of the segment with "ts", "side" and "total_volume" set,
        or None.
        """
        if ts not in imb_result:
            return None
        best = OrderFlowView._pick_best(
            {ts: imb_result[ts]}, side.lower(), OrderFlowView._segment_total_volume
        )
        if best is None:
            return None
        best_total, _, best_side, best_seg = best
        result = dict(best_seg)
        result["ts"] = int(ts)
        result["side"] = best_side
        result["total_volume"] = float(best_total)
        return result

    @staticmethod
    def max_volume_segment_global(
        imb_result: Dict[int, Dict[str, List[Dict[str, object]]]],
        side: str = "both",
    ) -> Optional[Dict[str, object]]:
        """
        Segment with the largest summed ``volume_list`` across all ts.

        Ties on volume go to the larger ts. Within one ts, the earliest segment
        wins.
        Returns a copy of the segment with "ts", "side" and "total_volume" set,
        or None.
        """
        best = OrderFlowView._pick_best(
            imb_result, side.lower(), OrderFlowView._segment_total_volume
        )
        if best is None:
            return None
        best_total, best_ts, best_side, best_seg = best
        result = dict(best_seg)
        result["ts"] = int(best_ts)
        result["side"] = best_side
        result["total_volume"] = float(best_total)
        return result