问题描述
我想让多类 Pandas 数据框在训练中更加平衡。我的训练集的简化版本如下所示:
不平衡数据帧:0、1、2 类的计数分别为 7、3 和 1
animal class
0 dog1 0
1 dog2 0
2 dog3 0
3 dog4 0
4 dog5 0
5 dog6 0
6 dog7 0
7 cat1 1
8 cat2 1
9 cat3 1
10 fish1 2
我用代码做的:
import pandas as pd
data = {'animal': ['dog1','dog2','dog3','dog4','dog5','dog6','dog7','cat1','cat2','cat3','fish1'],'class': [0,1,2]}
df = pd.DataFrame(data)
现在我想对多数类进行随机抽样,并随机对少数类进行抽样以达到每个类的指定值,以获得更平衡的数据框。
问题是,我可以在网上找到的所有 Pandas 教程或有关此主题的有关 stackoverflow 的其他问题都涉及将少数类随机过度采样到多数类的级别(例如:Duplicating training examples to handle class imbalance in a pandas data frame)或随机下将多数类抽样到少数类的水平。
由于我面临极端不平衡,我无法使多数类的大小等于少数类的大小。因此,我能找到的这些代码片段通常对我不起作用。理想情况下,我将能够指定每个类的确切样本数,然后通过过采样或欠采样生成(取决于我为该类指定的数量以及该类包含的样本数)。
例如
如果我指定:
- counts_0 = 5(是 7,所以意味着随机欠采样 2 个样本),
- counts_1 = 4(是 3,所以意味着随机过度采样 1 个样本),
- counts_2 = 3(1 表示随机过度采样 2 个样本)
我想变成这样:
更平衡的数据框:0、1和2类的计数分别为5、4和3
animal class
0 dog2 0
1 dog3 0
2 dog5 0
3 dog6 0
4 dog7 0
5 cat1 1
6 cat2 1
7 cat3 1
8 cat2 1
9 fish1 2
10 fish1 2
11 fish1 2
解决方法
由于 groupby.sample
不允许 n
大于组大小,如果 replace
不是 True
但替换为 True
意味着替换将甚至在本可以被降采样的组中也会发生。
相反,让我们尝试使用 groupby.apply
+ sample
并有条件地为每个组启用 replace
。
创建一个字典,将每个类映射到样本数量,并使用条件逻辑来确定有无替换:
sample_amounts = {0: 5,1: 4,2: 3}
s = (
df.groupby('class').apply(lambda g: g.sample(
# lookup number of samples to take
n=sample_amounts[g.name],# enable replacement if len is less than number of samples expected
replace=len(g) < sample_amounts[g.name]
))
)
s
:
animal class
class
0 5 dog6 0
3 dog4 0
6 dog7 0
4 dog5 0
2 dog3 0
1 9 cat3 1
8 cat2 1
7 cat1 1
8 cat2 1
2 10 fish1 2
10 fish1 2
10 fish1 2
droplevel
可用于保留初始索引(如果重要):
sample_amounts = {0: 5,2: 3}
s = (
df.groupby('class').apply(lambda g: g.sample(
n=sample_amounts[g.name],replace=len(g) < sample_amounts[g.name]
))
.droplevel(0)
)
s
:
animal class
6 dog7 0
3 dog4 0
2 dog3 0
4 dog5 0
1 dog2 0
7 cat1 1
8 cat2 1
8 cat2 1
8 cat2 1
10 fish1 2
10 fish1 2
10 fish1 2
如果索引不重要,可以使用
sample_amounts = {0: 5,replace=len(g) < sample_amounts[g.name]
))
.reset_index(drop=True)
)
s
:
animal class
0 dog1 0
1 dog2 0
2 dog4 0
3 dog5 0
4 dog3 0
5 cat3 1
6 cat2 1
7 cat1 1
8 cat3 1
9 fish1 2
10 fish1 2
11 fish1 2