如何使用R有效地将每一行拆分为测试和训练子集?

问题描述

我有一个数据表,提供给定向量的长度和组成 例如:

set.seed(1)

dt = data.table(length = c(100,150),n_A = c(30,30),n_B = c(20,100),n_C = c(50,20))

我需要将每个向量随机分为两个子集,分别具有80%和20%的观察值。我目前可以使用for循环执行此操作。例如:

dt_80_list <- list() # create output lists
dt_20_list <- list()

for (i in 1:nrow(dt)){ # for each row in the data.table
  
  sample_vec <- sample( c(   rep("A",dt$n_A[i]),# create a randomised vector with the given nnumber of each component. 
                             rep("B",dt$n_B[i]),rep("C",dt$n_C[i]) ) )
  
  sample_vec_80 <- sample_vec[1:floor(length(sample_vec)*0.8)] # subset 80% of the vector
  
  dt_80_list[[i]] <- data.table(   length = length(sample_vec_80),# count the number of each component in the subset and output to list
                         n_A = length(sample_vec_80[which(sample_vec_80 == "A")]),n_B = length(sample_vec_80[which(sample_vec_80 == "B")]),n_C = length(sample_vec_80[which(sample_vec_80 == "C")])
  )
  
  dt_20_list[[i]] <- data.table(   length = dt$length[i] - dt_80_list[[i]]$length,# subtract the number of each component in the 80% to identify the number in the 20%
                         n_A = dt$n_A[i] - dt_80_list[[i]]$n_A,n_B = dt$n_B[i] - dt_80_list[[i]]$n_B,n_C = dt$n_C[i] - dt_80_list[[i]]$n_C
  )
}
dt_80 <- do.call("rbind",dt_80_list) # collapse lists to output data.tables
dt_20 <- do.call("rbind",dt_20_list)

但是,我需要将此应用于的数据集非常大,而且速度太慢。有人对我如何提高性能有任何建议吗?

谢谢。

解决方法

(我假设您的数据集包含更多行(但只有几行)。)

这是我想出的一个版本,主要有三个变化

  • 使用.Nby=来计算每一行中绘制的“ A”,“ B”和“ C”的数量
  • sample中使用size参数
  • 加入原始dtdt_80以计算dt_20,而无需for循环
## draw training data
dt_80 <- dcast(
      dt[,row:=1:nrow(dt)
       ][,.(draw=sample(c(rep("A80",n_A),rep("B80",n_B),rep("C80",n_C)),size=.8*length)  ),by=row
       ][,.N,by=.(row,draw)],row~draw,value.var="N")[,length80:=A80+B80+C80]

## draw test data
dt_20 <- dt[dt_80,.(A20=n_A-A80,B20=n_B-B80,C20=n_C-C80),on="row"][,length20:=A20+B20+C20]

可能仍有优化的空间,但我希望它已经可以帮助:)

编辑

在这里,我添加了我最初的第一个想法,我没有发布它,因为上面的代码要快得多。但这可能会提高内存效率,这对您而言至关重要。因此,即使您已经有了可行的解决方案,也可能会对此感兴趣……

library(data.table)
library(Rfast)

## add row numbers
dt[,row:=1:nrow(dt)]

## sampling function
sampfunc <- function(n_A,n_B,n_C){ 
  draw <- sample(c(rep("A80",size=.8*(n_A+n_B+n_C))
  out <- Rfast::Table(draw)
  return(as.list(out))
}

## draw training data
dt_80 <- dt[,sampfunc(n_A,n_C),by=row]