
Handling a model with circular references in Spark SQL?

  • angelcervera  ·  6 years ago

    Scala / Spark SQL 2.2.1

    Suppose this is the initial modeling approach, which of course does not work (keep in mind that the real model has dozens of attributes):

    case class Branch(id: Int, branches: List[Branch] = List.empty)
    case class Tree(id: Int, branches: List[Branch])
    
    val trees = Seq(Tree(1, List(Branch(2, List.empty), Branch(3, List(Branch(4, List.empty))))))
    
    val ds = spark.createDataset(trees)
    ds.show
    

    This is the error it throws:

    java.lang.UnsupportedOperationException: cannot have circular references in class, but got the circular reference of class Branch
    

    Our maximum depth is 5 levels, so as a workaround I came up with:

    case class BranchLevel5(id: Int)
    case class BranchLevel4(id: Int, branches: List[BranchLevel5] = List.empty)
    case class BranchLevel3(id: Int, branches: List[BranchLevel4] = List.empty)
    case class BranchLevel2(id: Int, branches: List[BranchLevel3] = List.empty)
    case class BranchLevel1(id: Int, branches: List[BranchLevel2] = List.empty)
    case class Tree(id: Int, branches: List[BranchLevel1])
    

    Of course this works, but it is not at all elegant, and you can imagine the pain in the implementation (readability, coupling, maintenance, usability, code duplication, and so on).
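    The duplication also spreads to any code that builds these values: converting from the natural recursive model needs one hand-written converter per level. A minimal sketch of that boilerplate (the `toLevelN` names are hypothetical, invented for this illustration):

    ```scala
    case class Branch(id: Int, branches: List[Branch] = List.empty)

    case class BranchLevel5(id: Int)
    case class BranchLevel4(id: Int, branches: List[BranchLevel5] = List.empty)
    case class BranchLevel3(id: Int, branches: List[BranchLevel4] = List.empty)
    case class BranchLevel2(id: Int, branches: List[BranchLevel3] = List.empty)
    case class BranchLevel1(id: Int, branches: List[BranchLevel2] = List.empty)

    // One near-identical converter per level; anything below level 5 is cut off.
    def toLevel5(b: Branch): BranchLevel5 = BranchLevel5(b.id)
    def toLevel4(b: Branch): BranchLevel4 = BranchLevel4(b.id, b.branches.map(toLevel5))
    def toLevel3(b: Branch): BranchLevel3 = BranchLevel3(b.id, b.branches.map(toLevel4))
    def toLevel2(b: Branch): BranchLevel2 = BranchLevel2(b.id, b.branches.map(toLevel3))
    def toLevel1(b: Branch): BranchLevel1 = BranchLevel1(b.id, b.branches.map(toLevel2))
    ```

    Every new attribute on the model has to be copied into five case classes and five converters, which is exactly the maintenance cost described above.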

    So the question is: how do you handle a model with circular references in Spark SQL?

    1 Answer  |  6 years ago

  •   Worakarn Isaratham  ·  6 years ago

    If you are comfortable using a private API, here is one way that works: treat the whole self-referencing structure as a user-defined type (UDT). I followed this answer: https://stackoverflow.com/a/51957666/1823254 .

    package org.apache.spark.custom.udts // we're calling some private API so need to be under 'org.apache.spark'
    
    import java.io._
    import org.apache.spark.sql.types.{DataType, UDTRegistration, UserDefinedType}
    
    class BranchUDT extends UserDefinedType[Branch] {
    
      override def sqlType: DataType = org.apache.spark.sql.types.BinaryType
      override def serialize(obj: Branch): Any = {
        val bos = new ByteArrayOutputStream()
        val oos = new ObjectOutputStream(bos)
        oos.writeObject(obj)
        oos.close() // flush the ObjectOutputStream's buffer before reading the bytes
        bos.toByteArray
      }
      override def deserialize(datum: Any): Branch = {
        val bis = new ByteArrayInputStream(datum.asInstanceOf[Array[Byte]])
        val ois = new ObjectInputStream(bis)
        val obj = ois.readObject()
        obj.asInstanceOf[Branch]
      }
    
      override def userClass: Class[Branch] = classOf[Branch]
    }
    
    object BranchUDT {
      def register() = UDTRegistration.register(classOf[Branch].getName, classOf[BranchUDT].getName)
    }
    

    BranchUDT.register()
    val trees = Seq(Tree(1, List(Branch(2, List.empty), Branch(3, List(Branch(4, List.empty))))))
    
    val ds = spark.createDataset(trees)
    ds.show(false)
    
    //+---+----------------------------------------------------+
    //|id |branches                                            |
    //+---+----------------------------------------------------+
    //|1  |[Branch(2,List()), Branch(3,List(Branch(4,List())))]|
    //+---+----------------------------------------------------+
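
    Note that the UDT above is just standard Java serialization wrapped in Spark's UDT interface, so the round-trip it performs can be checked in isolation, without a SparkSession. A minimal sketch (the `toBytes`/`fromBytes` names are made up here; they mirror the `serialize`/`deserialize` bodies of `BranchUDT`):

    ```scala
    import java.io._

    case class Branch(id: Int, branches: List[Branch] = List.empty)

    // Mirrors BranchUDT.serialize: Java-serialize the object graph to a byte array.
    def toBytes(obj: Branch): Array[Byte] = {
      val bos = new ByteArrayOutputStream()
      val oos = new ObjectOutputStream(bos)
      oos.writeObject(obj)
      oos.close()
      bos.toByteArray
    }

    // Mirrors BranchUDT.deserialize: rebuild the Branch from the byte array.
    def fromBytes(bytes: Array[Byte]): Branch = {
      val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
      try ois.readObject().asInstanceOf[Branch]
      finally ois.close()
    }
    ```

    The trade-off of this approach is that the column is stored as opaque `BinaryType` bytes: Spark cannot see inside it, so you lose columnar pruning and SQL access to the nested fields, and can only work with the values through the typed Dataset API.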