代码之家  ›  专栏  ›  技术社区  ›  Chuang

在Spark dataframe udf中,像struct(col1,col2)这样的函数参数是什么类型的?

  •  1
  • Chuang  · 技术社区  · 6 年前

    背景:

    我有一个数据框,它有三列: id, x, y 。x、 y是双倍的。

    • 首先,我 struct (col("x"),col("y")) 获取坐标列。
    • 然后 groupBy(col("id")) agg(collect_list(col("coordinate")))

    因此,现在df只有两列: id ,coordinate

    我认为坐标的数据类型是 collection.mutable.WrappedArray[(Double,Double)] 。 所以我把它交给了udf。但是,数据类型是错误的。我在运行代码时出错。我不知道为什么。结构(col1,col2)的真正数据类型是什么?还是有其他方法可以轻松获得正确答案?

    这是代码:

    def getMedianPoint = udf((array1: collection.mutable.WrappedArray[(Double,Double)]) => {  
        var l = (array1.length/2)
        var c = array1(l)
        val x = c._1.asInstanceOf[Double]
        val y = c._2.asInstanceOf[Double]
        (x,y)
    })
    
    df.withColumn("coordinate",struct(col("x"),col("y")))
      .groupBy(col("id"))
      .agg(collect_list("coordinate").as("coordinate")
      .withColumn("median",getMedianPoint(col("coordinate")))
    

    非常感谢!

    1 回复  |  直到 6 年前
        1
  •  1
  •   Ramesh Maharjan    6 年前

    我认为坐标的数据类型是集合。可变。WrappedArray[(双,双)]

    是的,你说的完全正确 。和 您在udf函数中定义的数据类型以及作为参数传递的数据类型也是正确的 。但是 主要问题是struct列的键的名称 。因为你一定有以下问题

    由于数据类型不匹配,无法解析“UDF(坐标)”:参数1需要数组>但是,键入' coordinate '是数组的(>);类型

    只需使用 alias 重命名结构键

    df.withColumn("coordinate",struct(col("x").as("_1"),col("y").as("_2")))
      .groupBy(col("id"))
      .agg(collect_list("coordinate").as("coordinate"))
        .withColumn("median",getMedianPoint(col("coordinate")))
    

    以便关键字名称匹配。

    但是

    这将引发另一个问题

      var c = array1(l)
    

    原因:java。lang.ClassCastException:组织。阿帕奇。火花sql。催化剂表达式。无法将GenericRowWithSchema转换为scala。元组2

    所以我建议你改变 udf 功能为

    import org.apache.spark.sql.functions._
    
    def getMedianPoint = udf((array1: Seq[Row]) => {
      var l = (array1.length/2)
      (array1(l)(0).asInstanceOf[Double], array1(l)(1).asInstanceOf[Double])
    })
    

    这样你就不需要使用 别名 也因此,完整的解决方案是

    import org.apache.spark.sql.functions._
    
    def getMedianPoint = udf((array1: Seq[Row]) => {
      var l = (array1.length/2)
      (array1(l)(0).asInstanceOf[Double], array1(l)(1).asInstanceOf[Double])
    })
    
    df.withColumn("coordinate",struct(col("x"),col("y")))
      .groupBy(col("id"))
      .agg(collect_list("coordinate").as("coordinate"))
        .withColumn("median",getMedianPoint(col("coordinate")))
      .show(false)
    

    我希望答案有帮助