有没有一种简单的方法可以从布尔表达式中从 Pandas DataFrame 中提取行?

问题描述

我目前正在努力尝试使用矢量化从 DataFrame 中提取行。我很确定有一种简单的方法、表达式或函数可以实现这一点,但我找不到。 我有这个数据框(来自 MysqL 数据库):

             date_taux    taux  taux_min  taux_max
0  2021-02-15 13:55:00  2.1166    2.1155    2.1232
1  2021-02-15 14:00:00  2.1256    2.1166    2.1300
2  2021-02-15 14:05:00  2.1312    2.1206    2.1348
3  2021-02-15 14:10:00  2.1174    2.1166    2.1416
4  2021-02-15 14:15:00  2.1103    2.1060    2.1253
5  2021-02-15 14:20:00  2.1269    2.1143    2.1277
6  2021-02-15 14:25:00  2.1239    2.1115    2.1300
7  2021-02-15 14:30:00  2.0880    2.0879    2.1299
8  2021-02-15 14:35:00  2.0827    2.0827    2.1060
9  2021-02-15 14:40:00  2.0747    2.0718    2.0996
10 2021-02-15 14:45:00  2.0846    2.0779    2.0861
11 2021-02-15 14:50:00  2.0826    2.0806    2.0894
12 2021-02-15 14:55:00  2.0350    2.0350    2.0857
13 2021-02-15 15:00:00  2.0796    2.0350    2.0797
14 2021-02-15 15:05:00  2.0717    2.0587    2.0800
15 2021-02-15 15:10:00  2.0762    2.0705    2.0819
16 2021-02-15 15:15:00  2.0793    2.0650    2.0884
17 2021-02-15 15:20:00  2.1005    2.0831    2.1064
18 2021-02-15 15:25:00  2.1164    2.1017    2.1206
19 2021-02-15 15:30:00  2.1199    2.1176    2.1300

我也有这个 numpy 数组:

[2.         2.01694915 2.03389831 2.05084746 2.06779661 2.08474576
 2.10169492 2.11864407 2.13559322 2.15254237 2.16949153 2.18644068
 2.20338983 2.22033898 2.23728814 2.25423729 2.27118644 2.28813559
 2.30508475 2.3220339  2.33898305 2.3559322  2.37288136 2.38983051
 2.40677966 2.42372881 2.44067797 2.45762712 2.47457627 2.49152542
 2.50847458 2.52542373 2.54237288 2.55932203 2.57627119 2.59322034
 2.61016949 2.62711864 2.6440678  2.66101695 2.6779661  2.69491525
 2.71186441 2.72881356 2.74576271 2.76271186 2.77966102 2.79661017
 2.81355932 2.83050847 2.84745763 2.86440678 2.88135593 2.89830508
 2.91525424 2.93220339 2.94915254 2.96610169 2.98305085 3.        ]

我的目标是向数据框中添加一列,数组中的数字数量介于 taux_mintaux_max 之间。预期结果是:

             date_taux    taux  taux_min  taux_max amount_lines
0  2021-02-15 13:55:00  2.1166    2.1155    2.1232            1
1  2021-02-15 14:00:00  2.1256    2.1166    2.1300            1
2  2021-02-15 14:05:00  2.1312    2.1206    2.1348            0
3  2021-02-15 14:10:00  2.1174    2.1166    2.1416            2
4  2021-02-15 14:15:00  2.1103    2.1060    2.1253            1
5  2021-02-15 14:20:00  2.1269    2.1143    2.1277            1
6  2021-02-15 14:25:00  2.1239    2.1115    2.1300            1
7  2021-02-15 14:30:00  2.0880    2.0879    2.1299            2
8  2021-02-15 14:35:00  2.0827    2.0827    2.1060            2
9  2021-02-15 14:40:00  2.0747    2.0718    2.0996            1
10 2021-02-15 14:45:00  2.0846    2.0779    2.0861            1
...

我尝试使用此代码

sql = dbm.MysqL()
data = sql.pdselect("SELECT date_taux,taux,taux_min,taux_max FROM binance_rates_grid WHERE action = %s AND date_taux > %s ORDER BY date_taux ASC","TOMOUSDT",datetime.utcNow()-timedelta(days=11))
print(data)

print("==================")
grids = np.linspace(2,4,60)

data["lignes"] = len(grids[(data["taux_min"] < grids) & (data["taux_max"] < grids)])

print(data)

但我很清楚这个错误ValueError: ('Lengths must match to compare',(2868,),(60,))

我很确定我在这里遗漏了一些东西,但我不知道是什么。

解决方法

让我们尝试SELECT * FROM bible ORDER BY cast(id AS integer); 广播:

numpy

x,y = df[['taux_min','taux_max']].values.T
mask = (x[:,None] <= arr) & (arr <= y[:,None])
df['amount_lines'] = mask.sum(1)
,

我会使用 goapply 来遍历数组:

lambda

其中 df['amount_lines'] = df.apply(lambda x: sum(np.logical_and(arr >= x['taux_min'],arr <= x['taux_max'])),axis=1) 是 numpy 数组。

举个简单的例子:

grids

输出

arr = np.array([1,2,3,4,5,6,7,9])
df = pd.DataFrame({'A':[1,52,10],'B':[3,100,13]})
df.apply(lambda x: sum(np.logical_and(arr >= x['A'],arr <= x['B'])),axis=1)