在 dbplyr 中传递要作为函数参数应用的函数

问题描述

假设我想创建一个函数,该函数可以使用用户传递的任何函数来改变列。我需要知道如何在函数到达 dbplyr 解析器之前引用和取消引用该函数。让我们看一个例子,假设我有一个这样的函数

testFun <- function(data,fun,colName,colOut = "myAwesomeColumn") {
  dplyr::mutate(.data = data,!!colOut := fun(.data[[colName]]))
}

sc <- sparklyr::spark_connection(master = "local")
mtcars_spark <- dplyr::copy_to(sc,mtcars,"mtcars")
testFun(mtcars_spark,mean,"mpg")

因此在上面的示例中,我想将 mean() 函数应用于 "mpg" 列并将其存储在名为 "myAwesomeColumn" 的新列中。

当使用 Spark,特别是 sparklyr 时,dbplyr 将尝试将此代码转换sql 并将其发送到 Spark。我的理解是 dbplyr 应用以下规则:

  1. 如果它可以找到等效的 Spark sql,它将使用它(例如 mean() -> AVG
  2. 否则它将按原样传递函数以查找 Scala 扩展或 UDF

第二个选项是这里发生的事情,因为它找不到函数 fun,因此它返回一个 Spark 错误

Error: org.apache.spark.sql.AnalysisException: Undefined function: 'fun'.
This function is neither a registered temporary function nor a permanent function
registered in the database 'default'.; line 1 pos 85
...

所以我们需要另一种方法。问题是让 rlang 在 dbplyr 解释之前将 fun 转换为 mean。我知道如果我将函数名称作为字符串传递并使用 rlang::parse_expr(),我可以做到这一点,例如:

testFun <- function(data,colOut = "myAwesomeColumn") {
  dplyr::mutate(data,!!colOut := rlang::parse_expr(paste0(fun,"(",")"))
}
testFun(mtcars_spark,"mean","mpg")
# # Source: spark<?> [?? x 12]
#      mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb myAwesomeColumn
#    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>           <dbl>
#  1  21       6  160    110  3.9   2.62  16.5     0     1     4     4            20.1
#  2  21       6  160    110  3.9   2.88  17.0     0     1     4     4            20.1
# # ... with more rows

解决方法

为了让它起作用,我们必须引用和取消引用 fun 参数。我们还构建了我们实际上想要传递到我们对 mutate() 的调用中的表达式。解决方法见下文。

testFun <- function(data,fun,colAmount,colOut = "output") { 
  fun <- rlang::enquo(fun) 
  dplyr::mutate(.data = data,!!colOut := rlang::call2(.fn = !!fun,rlang::sym(colAmount))) 
} 
     
testFun(mtcars_spark,mean,"mpg")                                                                                                                                 
# # Source: spark<?> [?? x 12]
#      mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb output
#    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>  <dbl>
#  1  21       6  160    110  3.9   2.62  16.5     0     1     4     4   20.1
#  2  21       6  160    110  3.9   2.88  17.0     0     1     4     4   20.1
# # ... with more rows

请注意,如果您使用 data.frame 而不是 tbl_spark,这会简单得多。