
Spark DataFrame UDF partitioning columns

  •  1
  • user4054919  · Tech Community  · 7 years ago

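    I want to transform a column. The new column should contain only one partition (a fixed-size chunk) of the original column. I defined the following UDF: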

    def extract (index : Integer) = udf((v: Seq[Double]) => v.grouped(16).toSeq(index))
    

    myDF = myDF.withColumn("measurement_"+i,extract(i)($"vector"))
    

    The original vector column was created as follows:

    var vectors :Seq[Seq[Double]] = myVectors
    vectors.toDF("vector")
    

    But in the end I get the following error:

    Failed to execute user defined function(anonfun$user$sparkapp$MyClass$$extract$2$1: (array<double>) => array<double>)
    

    1 answer  |  as of 7 years ago
  •  3
  •   akuiper    7 years ago

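    I can reproduce the error when I try to extract an element that does not exist, i.e. when the index is greater than or equal to the number of groups produced by grouped: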

    val myDF = Seq(Seq(1.0, 2.0 ,3, 4.0), Seq(4.0,3,2,1)).toDF("vector")
    // myDF: org.apache.spark.sql.DataFrame = [vector: array<double>]
    
    def extract (index : Integer) = udf((v: Seq[Double]) => v.grouped(2).toSeq(index))
    // extract: (index: Integer)org.apache.spark.sql.expressions.UserDefinedFunction
    
    val i = 2
    
    myDF.withColumn("measurement_"+i,extract(i)($"vector")).show
    

    This raises the following error:

    org.apache.spark.SparkException: Failed to execute user defined function($anonfun$extract$1: (array<double>) => array<double>)
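
The UDF failure typically wraps an ordinary Scala exception raised inside the function body. A minimal sketch without Spark, assuming the same four-element vector as above (the object name `GroupedIndexDemo` is only illustrative), shows that `grouped(2)` yields two chunks, so index 2 is already out of range:

```scala
import scala.util.Try

object GroupedIndexDemo {
  val v: Seq[Double] = Seq(1.0, 2.0, 3.0, 4.0)

  // grouped(2) returns an iterator over chunks of at most 2 elements
  val groups: Seq[Seq[Double]] = v.grouped(2).toSeq

  // Valid indices are 0 and 1; applying index 2 throws, so wrap it in Try
  val outOfRange: Try[Seq[Double]] = Try(groups(2))

  def main(args: Array[String]): Unit = {
    println(groups.length)        // 2
    println(outOfRange.isFailure) // true
  }
}
```

Because indices are zero-based, the largest valid index is the number of chunks minus one.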
    

    Instead of toSeq(index), try toSeq.lift(index), which returns None when the index is out of bounds:

    def extract (index : Integer) = udf((v: Seq[Double]) => v.grouped(2).toSeq.lift(index))
    // extract: (index: Integer)org.apache.spark.sql.expressions.UserDefinedFunction
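
To see what `lift` does independently of Spark, here is a small sketch using plain Scala collections. The helper `chunkAt` and the object name are hypothetical and exist only for this illustration; the body mirrors the UDF's logic. `lift` turns the index lookup into an `Option`, and a Spark UDF that returns `None` produces a null cell.

```scala
object LiftDemo {
  // Hypothetical helper: same chunking logic as the UDF body, without Spark
  def chunkAt(v: Seq[Double], size: Int, index: Int): Option[Seq[Double]] =
    v.grouped(size).toSeq.lift(index)

  def main(args: Array[String]): Unit = {
    val v = Seq(1.0, 2.0, 3.0, 4.0)
    println(chunkAt(v, 2, 1)) // Some(List(3.0, 4.0))
    println(chunkAt(v, 2, 2)) // None
  }
}
```

Returning `Option` keeps the failure case explicit instead of throwing inside the executor.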
    

    Normal index:

    val i = 1    
    myDF.withColumn("measurement_"+i,extract(i)($"vector")).show
    +--------------------+-------------+
    |              vector|measurement_1|
    +--------------------+-------------+
    |[1.0, 2.0, 3.0, 4.0]|   [3.0, 4.0]|
    |[4.0, 3.0, 2.0, 1.0]|   [2.0, 1.0]|
    +--------------------+-------------+
    

    Index out of bounds:

    val i = 2
    myDF.withColumn("measurement_"+i,extract(i)($"vector")).show
    +--------------------+-------------+
    |              vector|measurement_2|
    +--------------------+-------------+
    |[1.0, 2.0, 3.0, 4.0]|         null|
    |[4.0, 3.0, 2.0, 1.0]|         null|
    +--------------------+-------------+