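Problem description
I am working on the Titanic playground competition, and I am currently trying to impute the missing Age values.
The idea is to compute the mean Age for each [Pclass, Sex] group on the training set, and then use this information to replace the NaNs in both the training set and the test set.
Here is what I have so far: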
meanAgeTrain = train.groupby(['Pclass', 'Sex'])['Age'].transform('mean')
for df in [train, test]:
    df['Age'] = df['Age'].fillna(meanAgeTrain)
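The problem is that this still leaves some NaN values in the test set, while it removes all NaNs from the training set. I think this has to do with the index.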
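The index suspicion can be checked on a tiny example. The sketch below uses made-up frames (the values are illustrative assumptions, not Titanic data). `transform('mean')` returns a Series labelled by the training rows, so `fillna` on the test set aligns on row labels rather than on Pclass / Sex: some test rows receive the mean of the wrong group, and test rows whose label does not exist in the training set stay NaN.

```python
import numpy as np
import pandas as pd

# Illustrative toy data (assumed values, not the real Titanic set)
train = pd.DataFrame({
    "Pclass": [1, 1, 2],
    "Sex": ["male", "male", "female"],
    "Age": [30.0, np.nan, 20.0],
})
test = pd.DataFrame({
    "Pclass": [2, 1, 1, 2],
    "Sex": ["female", "male", "male", "female"],
    "Age": [np.nan, 50.0, np.nan, np.nan],
})

# One mean per *training row*, labelled 0, 1, 2
mean_age_train = train.groupby(["Pclass", "Sex"])["Age"].transform("mean")

# fillna aligns on the row label, not on (Pclass, Sex)
filled = test["Age"].fillna(mean_age_train)

print(filled)
# Row 0 is (2, female) but receives the mean stored at training label 0,
# which belongs to (1, male). Row 3 has no matching label and stays NaN.
```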
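What I need is:
- Compute the mean for each Pclass / Sex group in the training set
- Map all NaN values in the training set to the correct mean
- Map all NaN values in the test set to the correct mean (looked up by Pclass / Sex, not by index)
How can I do this correctly with pandas?
Edit:
Thanks for the suggestions. The one from @Reza works, but I don't understand it 100%, so I am trying to come up with my own solution.
This works, but I am new to pandas and I wonder whether there is a simpler way to do it.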
trainMeans = self.train.groupby(['Pclass', 'Sex'])['Age'].mean().reset_index()
def f(x):
    if x["Age"] == x["Age"]:  # not NaN
        return x["Age"]
    return trainMeans.loc[(trainMeans["Pclass"] == x["Pclass"]) & (trainMeans["Sex"] == x["Sex"])]["Age"].values[0]
self.train['Age'] = self.train.apply(f, axis=1)
self.test['Age'] = self.test.apply(f, axis=1)
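To me, the if inside the function in particular does not look like best practice. I need a way to apply the function only to the rows where Age is NaN.
Edit 2:
It turns out that resetting the index made things more complicated and slower, because after the groupby the index is already exactly what I can use as the mapping key. This is faster and simpler: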
trainMeans = self.train.groupby(['Pclass', 'Sex'])['Age'].mean()
def f(x):
    if not np.isnan(x["Age"]):  # not NaN
        return x["Age"]
    return trainMeans[x["Pclass"], x["Sex"]]
self.train['Age'] = self.train.apply(f, axis=1)
self.test['Age'] = self.test.apply(f, axis=1)
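Can this be simplified further?
Solution
- You will see that the two fill methods, groupby fillna with the mean and a random forest regressor, are within about 1/100 of each other
- See the bottom of the answer for the statistical comparison.
Fill nan values with the mean
- Use .groupby, .apply, and fillna together with .mean.
- The code below fills nans with the mean of each group, computed over the whole dataset.
- Titanic Age Analysis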
import pandas as pd
import seaborn as sns
# load dataset
df = sns.load_dataset('titanic')
# map sex to a numeric type
df.sex = df.sex.map({'male': 1,'female': 0})
# Populate Age_Fill
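# group_keys=False keeps the original row index, so the result aligns with df on recent pandas versions
df['Age_Fill'] = df['age'].groupby([df['pclass'], df['sex']], group_keys=False).apply(lambda x: x.fillna(x.mean()))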
# series with filled ages
groupby_result = df.Age_Fill[df.age.isnull()]
# display(df[df.age.isnull()].head())
survived pclass sex age sibsp parch fare embarked class who adult_male deck embark_town alive alone Age_Fill
0 3 male NaN 0 0 8.4583 Q Third man True NaN Queenstown no True 26.50759
1 2 male NaN 0 0 13.0000 S Second man True NaN Southampton yes True 30.74071
1 3 female NaN 0 0 7.2250 C Third woman False NaN Cherbourg yes True 21.75000
0 3 male NaN 0 0 7.2250 C Third man True NaN Cherbourg no True 26.50759
1 3 female NaN 0 0 7.8792 Q Third woman False NaN Queenstown yes True 21.75000
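As a possible alternative to `.apply` with a lambda, the same per-group fill can be sketched with `groupby(...).transform('mean')`, which returns one mean per row on the original index. The small frame below is made up for illustration (an assumption, not the seaborn dataset); the column names follow the code above.

```python
import numpy as np
import pandas as pd

# Illustrative toy data with the same column names as the seaborn dataset
df = pd.DataFrame({
    "pclass": [1, 1, 1, 3, 3, 3],
    "sex": [0, 0, 0, 1, 1, 1],
    "age": [30.0, np.nan, 40.0, 20.0, 22.0, np.nan],
})

# Mean age of each (pclass, sex) group, broadcast back to every row
group_mean = df.groupby(["pclass", "sex"])["age"].transform("mean")

# Only the missing ages are replaced; known ages are left untouched
df["Age_Fill"] = df["age"].fillna(group_mean)

print(df)
```

Because `transform` keeps the row index of `df`, the `fillna` alignment is safe here: both Series come from the same frame.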
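Fill nan values from a RandomForestRegressor
- sklearn.ensemble.RandomForestRegressor
- Kaggle: Titanic
- Age seems to be a promising feature, so it does not make sense to simply fill the null values with the median/mean/mode.
- Based on the results here, I don't think it makes a difference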
from sklearn.ensemble import RandomForestRegressor
import pandas as pd
import seaborn as sns
# load dataset
df = sns.load_dataset('titanic')
# map sex to a numeric type
df.sex = df.sex.map({'male': 1,'female': 0})
# split data
train = df.loc[(df.age.notnull())] # known age values
test = df.loc[(df.age.isnull())] # all nan age values
# select age column
y = train.values[:,3]
# select pclass and sex
X = train.values[:,[1,2]]
# create RandomForestRegressor model
rfr = RandomForestRegressor(n_estimators=2000,n_jobs=-1)
# Fit a model
rfr.fit(X,y)
# Use the fitted model to predict the missing values
predictedAges = rfr.predict(test.values[:, [1, 2]])
# create predicted age column
df['pred_age'] = df.age
# fill column
df.loc[(df.pred_age.isnull()),'pred_age'] = predictedAges
# display(df[df.age.isnull()].head())
survived pclass sex age sibsp parch fare embarked class who adult_male deck embark_town alive alone pred_age
0 3 1 NaN 0 0 8.4583 Q Third man True NaN Queenstown no True 26.49935
1 2 1 NaN 0 0 13.0000 S Second man True NaN Southampton yes True 30.73126
1 3 0 NaN 0 0 7.2250 C Third woman False NaN Cherbourg yes True 21.76513
0 3 1 NaN 0 0 7.2250 C Third man True NaN Cherbourg no True 26.49935
1 3 0 NaN 0 0 7.8792 Q Third woman False NaN Queenstown yes True 21.76513
Comparison of rfr and groupby
print((predictedAges - groupby_result).describe())
count 177.00000
mean 0.00362
std 0.01877
min -0.04167
25% 0.01121
50% 0.01121
75% 0.01131
max 0.02969
Name: Age_Fill,dtype: float64
# comparison dataframe
comp = pd.DataFrame({'rfr': predictedAges.tolist(),'gb': groupby_result.tolist()})
comp['diff'] = comp.rfr - comp.gb
# display(comp)
rfr gb diff
26.51880 26.50759 0.01121
30.69903 30.74071 -0.04167
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
34.63090 34.61176 0.01913
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
41.24592 41.28139 -0.03547
41.24592 41.28139 -0.03547
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
34.63090 34.61176 0.01913
41.24592 41.28139 -0.03547
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
30.69903 30.74071 -0.04167
41.24592 41.28139 -0.03547
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
21.76131 21.75000 0.01131
21.76131 21.75000 0.01131
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
34.63090 34.61176 0.01913
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
41.24592 41.28139 -0.03547
21.76131 21.75000 0.01131
30.69903 30.74071 -0.04167
41.24592 41.28139 -0.03547
41.24592 41.28139 -0.03547
41.24592 41.28139 -0.03547
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
28.75266 28.72297 0.02969
26.51880 26.50759 0.01121
34.63090 34.61176 0.01913
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
34.63090 34.61176 0.01913
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
41.24592 41.28139 -0.03547
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
21.76131 21.75000 0.01131
34.63090 34.61176 0.01913
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
30.69903 30.74071 -0.04167
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
34.63090 34.61176 0.01913
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
30.69903 30.74071 -0.04167
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
41.24592 41.28139 -0.03547
30.69903 30.74071 -0.04167
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
41.24592 41.28139 -0.03547
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
41.24592 41.28139 -0.03547
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
30.69903 30.74071 -0.04167
26.51880 26.50759 0.01121
41.24592 41.28139 -0.03547
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
28.75266 28.72297 0.02969
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
41.24592 41.28139 -0.03547
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
41.24592 41.28139 -0.03547
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
34.63090 34.61176 0.01913
30.69903 30.74071 -0.04167
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
41.24592 41.28139 -0.03547
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
30.69903 30.74071 -0.04167
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
41.24592 41.28139 -0.03547
26.51880 26.50759 0.01121
41.24592 41.28139 -0.03547
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
41.24592 41.28139 -0.03547
41.24592 41.28139 -0.03547
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
41.24592 41.28139 -0.03547
26.51880 26.50759 0.01121
34.63090 34.61176 0.01913
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
26.51880 26.50759 0.01121
26.51880 26.50759 0.01121
21.76131 21.75000 0.01131
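The closeness of the two fills can also be checked with a tolerance instead of reading the table row by row. This is a minimal sketch that reuses the six distinct (rfr, gb) pairs visible in the table above; the 0.05 tolerance is an assumption chosen for illustration.

```python
import numpy as np
import pandas as pd

# The six distinct (rfr, gb) pairs that appear in the comparison table
comp = pd.DataFrame({
    "rfr": [26.51880, 30.69903, 21.76131, 34.63090, 41.24592, 28.75266],
    "gb": [26.50759, 30.74071, 21.75000, 34.61176, 41.28139, 28.72297],
})
comp["diff"] = comp["rfr"] - comp["gb"]

# Largest absolute disagreement between the two imputation methods
max_abs_diff = comp["diff"].abs().max()

# True when every pair agrees within the assumed tolerance of 0.05 years
within_tolerance = np.allclose(comp["rfr"], comp["gb"], atol=0.05)

print(max_abs_diff, within_tolerance)
```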
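Compute the means on a random training set
- This example computes the means on a random training set, and then fills the nan values in both the training set and the test set
- With pandas.DataFrame.fillna, when both dataframes have matching indices and the fill column is the same, missing values in a dataframe column are filled from the other dataframe.
  - To look up by Pclass / Sex rather than by the row index, pclass and sex are set as the index, which is what .fillna aligns on.
- In this example, train is 67% of the data and test is 33%. Set test_size and train_size as needed.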
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
# load dataset
df = sns.load_dataset('titanic')
# map sex to a numeric type
df.sex = df.sex.map({'male': 1,'female': 0})
# select columns for X and y (these must be defined before the split)
X = df[['pclass', 'sex']]
y = df['age']
# randomly split the dataframe into a train and test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
# create a dataframe of train (X,y) and test (X,y)
train = pd.concat([X_train,y_train],axis=1).reset_index(drop=True)
test = pd.concat([X_test,y_test],axis=1).reset_index(drop=True)
# calculate means for train
train_means = train.groupby(['pclass','sex']).agg({'age': 'mean'})
# display train_means,a multi-index dataframe
age
pclass sex
1 0 34.66667
1 41.38710
2 0 27.90217
1 30.50000
3 0 21.56338
1 26.87163
# fill nan values in train
train = train.set_index(['pclass','sex']).age.fillna(train_means.age).reset_index()
# fill nan values in test
test = test.set_index(['pclass','sex']).age.fillna(train_means.age).reset_index()
You can first create a map for Age:
cols = ['Pclass','Sex']
age_class_sex = train.groupby(cols)['Age'].mean().reset_index()
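Then merge it with train and test separately, resetting and then restoring the index so that the merged values line up with the original rows: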
train['Age'] = train['Age'].fillna(train[cols].reset_index().merge(age_class_sex,how='left',on=cols).set_index('index')['Age'])
test['Age'] = test['Age'].fillna(test[cols].reset_index().merge(age_class_sex,on=cols).set_index('index')['Age'])
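The merge-based lookup above can also be wrapped in a small helper so that the means are computed once on the training set and applied to any frame. The function name `fill_from_group_means` and the toy frames are hypothetical, introduced only for this sketch; it assumes two or more key columns, so the group means carry a MultiIndex.

```python
import numpy as np
import pandas as pd

def fill_from_group_means(fit_df, apply_df, keys, target):
    """Fill NaNs in apply_df[target] with group means computed on fit_df.

    Hypothetical helper; assumes len(keys) >= 2.
    """
    # One mean per key combination, indexed by a MultiIndex of the keys
    means = fit_df.groupby(keys)[target].mean()
    # Look up each row's group mean by its key values, in row order
    lookup = means.reindex(pd.MultiIndex.from_frame(apply_df[keys])).to_numpy()
    out = apply_df.copy()
    # Re-attach the original row labels so fillna aligns row by row
    out[target] = out[target].fillna(pd.Series(lookup, index=out.index))
    return out

# Illustrative toy data (assumed values, not the real Titanic set)
train = pd.DataFrame({
    "Pclass": [1, 1, 2, 2],
    "Sex": ["m", "m", "f", "f"],
    "Age": [10.0, 20.0, np.nan, 40.0],
})
test = pd.DataFrame(
    {"Pclass": [1, 2], "Sex": ["m", "f"], "Age": [np.nan, np.nan]},
    index=[10, 11],  # labels that do not exist in train
)

cols = ["Pclass", "Sex"]
filled_train = fill_from_group_means(train, train, cols, "Age")
filled_test = fill_from_group_means(train, test, cols, "Age")
print(filled_test)
```

A key combination that never occurs in the training set has no mean to look up, so those rows would remain NaN and need a fallback such as the overall training mean.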