如何在“泰坦尼克时代”栏中通过估算来填充NaN值?

问题描述

正在进行《泰坦尼克号》花花公子比赛,我目前正在试图估算缺失的Age值。

想法是计算训练集上每个Age组的平均值[Pclass,Sex],然后使用此信息替换训练集和测试集上的NaN

这是我到目前为止所拥有的:

meanAgeTrain = train.groupby(['Pclass','Sex'])['Age'].transform('mean')
    
for df in [train,test]:
    df['Age'] = df['Age'].fillna(meanAgeTrain)

问题在于,这仍然在测试集中保留了一些NaN值,同时消除了训练集中的所有Nan。我认为这与索引有关。

我需要的是:

  1. 计算训练集中每个Pclass /性别组的平均值
  2. 将训练集中的所有NaN值映射到正确的平均值
  3. 将测试集中的所有NaN值映射到正确的平均值(通过Pclass / Sex查找,而不是基于索引)

如何使用熊猫正确完成此操作?

编辑:

感谢您的建议。 @Reza的那个有效,但我不是100%理解。因此,我正在尝试提出自己的解决方案。

这可行,但是我是Pandas的新手,我想知道是否有更简单的方法来实现它。

trainMeans = self.train.groupby(['Pclass','Sex'])['Age'].mean().reset_index()

def f(x):
    if x["Age"] == x["Age"]:  # not NaN
        return x["Age"]
    return trainMeans.loc[(trainMeans["Pclass"] == x["Pclass"]) & (trainMeans["Sex"] == x["Sex"])]["Age"].values[0]

 self.train['Age'] = self.train.apply(f,axis=1)
 self.test['Age'] = self.test.apply(f,axis=1)

对我来说,特别是函数中的if似乎不是最佳实践。我需要一种仅将功能应用于NaN年龄段的方法

编辑2

事实证明,重置索引使事情变得更加复杂和缓慢,因为将索引分组之后,我已经完全可以用作映射键了。这样更快,更容易:

trainMeans = self.train.groupby(['Pclass','Sex'])['Age'].mean()

def f(x):
    if not np.isnan(x["Age"]):  # not NaN
        return x["Age"]
    return trainMeans[x["Pclass"],x["Sex"]]

self.train['Age'] = self.train.apply(f,axis=1)
self.test['Age'] = self.test.apply(f,axis=1)

这可以进一步简化吗?

解决方法

  • 您会看到两种填充方法,即具有均值的groupby fillna 随机森林回归器,彼此之间的距离在大约1/100之内
    • 有关统计比较,请参见答案底部。

用平均值填充nan值

import pandas as pd
import seaborn as sns

# load dataset
df = sns.load_dataset('titanic')

# map sex to a numeric type
df.sex = df.sex.map({'male': 1,'female': 0})

# Populate Age_Fill
df['Age_Fill'] = df['age'].groupby([df['pclass'],df['sex']]).apply(lambda x: x.fillna(x.mean()))

# series with filled ages
groupby_result = df.Age_Fill[df.age.isnull()]

# display(df[df.age.isnull()].head())
 survived  pclass     sex  age  sibsp  parch     fare embarked   class    who  adult_male deck  embark_town alive  alone  Age_Fill
        0       3    male  NaN      0      0   8.4583        Q   Third    man        True  NaN   Queenstown    no   True  26.50759
        1       2    male  NaN      0      0  13.0000        S  Second    man        True  NaN  Southampton   yes   True  30.74071
        1       3  female  NaN      0      0   7.2250        C   Third  woman       False  NaN    Cherbourg   yes   True  21.75000
        0       3    male  NaN      0      0   7.2250        C   Third    man        True  NaN    Cherbourg    no   True  26.50759
        1       3  female  NaN      0      0   7.8792        Q   Third  woman       False  NaN   Queenstown   yes   True  21.75000

从RandomForestRegressor填充nan值

from sklearn.ensemble import RandomForestRegressor
import pandas as pd
import seaborn as sns

# load dataset
df = sns.load_dataset('titanic')

# map sex to a numeric type
df.sex = df.sex.map({'male': 1,'female': 0})

# split data
train = df.loc[(df.age.notnull())]  # known age values
test = df.loc[(df.age.isnull())]  # all nan age values

# select age column
y = train.values[:,3]

# select pclass and sex
X = train.values[:,[1,2]]

# create RandomForestRegressor model
rfr = RandomForestRegressor(n_estimators=2000,n_jobs=-1)

# Fit a model
rfr.fit(X,y)

# Use the fitted model to predict the missing values
predictedAges = rfr.predict(test.values[:,2]])

# create predicted age column
df['pred_age'] = df.age

# fill column
df.loc[(df.pred_age.isnull()),'pred_age'] = predictedAges 

# display(df[df.age.isnull()].head())
 survived  pclass  sex  age  sibsp  parch     fare embarked   class    who  adult_male deck  embark_town alive  alone  pred_age
        0       3    1  NaN      0      0   8.4583        Q   Third    man        True  NaN   Queenstown    no   True  26.49935
        1       2    1  NaN      0      0  13.0000        S  Second    man        True  NaN  Southampton   yes   True  30.73126
        1       3    0  NaN      0      0   7.2250        C   Third  woman       False  NaN    Cherbourg   yes   True  21.76513
        0       3    1  NaN      0      0   7.2250        C   Third    man        True  NaN    Cherbourg    no   True  26.49935
        1       3    0  NaN      0      0   7.8792        Q   Third  woman       False  NaN   Queenstown   yes   True  21.76513

rfr组的比较

print(predictedAges - groupby_result).describe())

count    177.00000
mean       0.00362
std        0.01877
min       -0.04167
25%        0.01121
50%        0.01121
75%        0.01131
max        0.02969
Name: Age_Fill,dtype: float64

# comparison dataframe
comp = pd.DataFrame({'rfr': predictedAges.tolist(),'gb': groupby_result.tolist()})
comp['diff'] = comp.rfr - comp.gb

# display(comp)
      rfr        gb     diff
 26.51880  26.50759  0.01121
 30.69903  30.74071 -0.04167
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 34.63090  34.61176  0.01913
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 41.24592  41.28139 -0.03547
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 34.63090  34.61176  0.01913
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 30.69903  30.74071 -0.04167
 41.24592  41.28139 -0.03547
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 21.76131  21.75000  0.01131
 21.76131  21.75000  0.01131
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 34.63090  34.61176  0.01913
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 41.24592  41.28139 -0.03547
 21.76131  21.75000  0.01131
 30.69903  30.74071 -0.04167
 41.24592  41.28139 -0.03547
 41.24592  41.28139 -0.03547
 41.24592  41.28139 -0.03547
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 28.75266  28.72297  0.02969
 26.51880  26.50759  0.01121
 34.63090  34.61176  0.01913
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 34.63090  34.61176  0.01913
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 21.76131  21.75000  0.01131
 34.63090  34.61176  0.01913
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 30.69903  30.74071 -0.04167
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 34.63090  34.61176  0.01913
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 30.69903  30.74071 -0.04167
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 41.24592  41.28139 -0.03547
 30.69903  30.74071 -0.04167
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 30.69903  30.74071 -0.04167
 26.51880  26.50759  0.01121
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 28.75266  28.72297  0.02969
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 34.63090  34.61176  0.01913
 30.69903  30.74071 -0.04167
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 30.69903  30.74071 -0.04167
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 41.24592  41.28139 -0.03547
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 34.63090  34.61176  0.01913
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131

在随机训练集上计算均值

  • 此示例计算出随机训练集的平均值,然后将nan值填充到训练集和测试集中
  • 使用pandas.DataFrame.fillna,当两个数据帧都具有匹配的索引并且填充列相同时,它将从另一个数据帧填充数据框列中的缺失值。
    • Pclass / Sex而不基于索引pclasssex被设置为索引,.fillna的工作方式就是这样。
  • 在此示例中,train是数据的67%,而test是数据的33%。
  • 根据需要设置
  • test_sizetrain_size
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split

# load dataset
df = sns.load_dataset('titanic')

# map sex to a numeric type
df.sex = df.sex.map({'male': 1,'female': 0})

# randomly split the dataframe into a train and test set
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.33,random_state=42)

# select columns for X and y
X = df[['pclass','sex']]
y = df['age']

# create a dataframe of train (X,y) and test (X,y)
train = pd.concat([X_train,y_train],axis=1).reset_index(drop=True)
test = pd.concat([X_test,y_test],axis=1).reset_index(drop=True)

# calculate means for train
train_means = train.groupby(['pclass','sex']).agg({'age': 'mean'})

# display train_means,a multi-index dataframe
                 age
pclass sex          
1      0    34.66667
       1    41.38710
2      0    27.90217
       1    30.50000
3      0    21.56338
       1    26.87163

# fill nan values in train
train = train.set_index(['pclass','sex']).age.fillna(train_means.age).reset_index()

# fill nan values in test
test = test.set_index(['pclass','sex']).age.fillna(train_means.age).reset_index()
,

您可以先为Age创建一个地图:

cols = ['Pclass','Sex']
age_class_sex = train.groupby(cols)['Age'].mean().reset_index()

然后将其与测试合并并分别训练,以便解决索引

train['Age'] = train['Age'].fillna(train[cols].reset_index().merge(age_class_sex,how='left',on=cols).set_index('index')['Age'])
test['Age'] = test['Age'].fillna(test[cols].reset_index().merge(age_class_sex,on=cols).set_index('index')['Age'])