代码之家  ›  专栏  ›  技术社区  ›  Raphael Roth

如何将Spark列的名称作为字符串获取?

  •  0
  • Raphael Roth  · 技术社区  · 6 年前

    我想编写一个方法,在不执行以下操作的情况下对数值列进行四舍五入:

    df
    .select(round($"x",2).as("x"))
    

    因此,我需要一个可重用的列表达式,如:

    def roundKeepName(c:Column,scale:Int) = round(c,scale).as(c.name)
    

    不幸的是 c.name 不存在,因此不编译上述代码。我找到了解决问题的办法 ColumName :

     def roundKeepName(c:ColumnName,scale:Int) = round(c,scale).as(c.string.name)
    

    Column (如果我使用 col("x") 而不是 $"x"

    2 回复  |  直到 6 年前
        1
  •  4
  •   Oli    5 年前

    不确定这个问题是否真的得到了回答。您的函数可以这样实现( toString 返回列的名称):

    def roundKeepname(c:Column,scale:Int) = round(c,scale).as(c.toString)
    

    如果您不喜欢依赖toString,这里有一个更健壮的版本。您可以依赖于基础表达式,将其强制转换为NamedExpression并取其名称。

    import org.apache.spark.sql.catalyst.expressions.NamedExpression
    def roundKeepname(c:Column,scale:Int) = 
        c.expr.asInstanceOf[NamedExpression].name
    

    它的工作原理是:

    scala> spark.range(2).select(roundKeepname('id, 2)).show
    +---+
    | id|
    +---+
    |  0|
    |  1|
    +---+  
    

    最后,如果您可以使用列的名称而不是列对象,那么您可以更改函数的签名,从而得到更简单的实现:

    def roundKeepName(columnName:String, scale:Int) = 
        round(col(columnName),scale).as(columnName)
    
        2
  •  1
  •   stack0114106    6 年前

    更新:

    scala> val df = Seq((1.22,4.34,8.93),(3.44,12.66,17.44),(5.66,9.35,6.54)).toDF("x","y","z")
    df: org.apache.spark.sql.DataFrame = [x: double, y: double ... 1 more field]
    
    scala> df.show
    +----+-----+-----+
    |   x|    y|    z|
    +----+-----+-----+
    |1.22| 4.34| 8.93|
    |3.44|12.66|17.44|
    |5.66| 9.35| 6.54|
    +----+-----+-----+
    
    
    scala>  df.columns.foldLeft(df)( (acc,p)  => (acc.withColumn(p+"_t",round(col(p),1)).drop(p).withColumnRenamed(p+"_t",p))).show
    +---+----+----+
    |  x|   y|   z|
    +---+----+----+
    |1.2| 4.3| 8.9|
    |3.4|12.7|17.4|
    |5.7| 9.4| 6.5|
    +---+----+----+
    
    
    scala>