
In Spark, converting a Seq[(String, Any)] to Seq[(String, org.apache.spark.ml.PredictionModel[_, _])]

  •  -1
  •  shane  ·  6 years ago

    I trained my dataset with several different models: nbModel, dtModel, rfModel, and gbmModel. All of these are machine learning models.

    Now, when I collect them into a variable

    val models = Seq(("NB", nbModel), ("DT", dtModel), ("RF", rfModel), ("GBM",gbmModel))
    

    I get a Seq[(String, Any)]:

    models: Seq[(String, Any)] = List((NB,NaiveBayesModel (uid=nb_c35f79982850) with 2 classes), (DT,()), (RF,RandomForestClassificationModel (uid=rfc_3f42daf4ea14) with 15 trees), (GBM,GBTClassificationModel (uid=gbtc_534a972357fa) with 20 trees))
    

    For a single model, such as nbModel,

     val models = ("NB", nbModel)
    

    Output: models: (String, org.apache.spark.ml.classification.NaiveBayesModel) = (NB,NaiveBayesModel (uid=nb_c35f79982850) with 2 classes)

    When I try to merge the prediction columns of these models, I get a type mismatch error:

    val mlTrainData= mlData(transferData, "value", models).drop("row_id")
    

    <console>:75: error: type mismatch; found : Seq[(String, Any)] required: Seq[(String, org.apache.spark.ml.PredictionModel[_, _])] val mlTrainData= mlData(transferData, "value", models).drop("row_id")
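
    One way to surface this problem earlier (a sketch, assuming the same nbModel/dtModel/rfModel/gbmModel values as above) is to ascribe the intended element type at the definition site, so the compiler points at the offending entry directly instead of failing later at the call site:

    ```scala
    import org.apache.spark.ml.PredictionModel

    // With the element type ascribed up front, the compiler flags the bad
    // ("DT", dtModel) entry immediately, rather than inferring Any and
    // only failing when the Seq is passed to mlData.
    val models: Seq[(String, PredictionModel[_, _])] =
      Seq(("NB", nbModel), ("DT", dtModel), ("RF", rfModel), ("GBM", gbmModel))
    ```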

    My mlData function is:

    def mlData(inputData: DataFrame, responseColumn: String,
               baseModels: Seq[(String, PredictionModel[_, _])]): DataFrame = {
      baseModels.map { case (name, model) =>
        model.transform(inputData)
          .select("row_id", model.getPredictionCol)
          .withColumnRenamed("prediction", s"${name}_prediction")
      }.reduceLeft((a, b) => a.join(b, Seq("row_id"), "inner"))
        .join(inputData.select("row_id", responseColumn), Seq("row_id"), "inner")
    }
    

    Output: mlData: (inputData: org.apache.spark.sql.DataFrame, responseColumn: String, baseModels: Seq[(String, org.apache.spark.ml.PredictionModel[_, _])])org.apache.spark.sql.DataFrame
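
    If the sequence may genuinely contain non-model entries, one option (a sketch, not from the original post) is to filter them out with a pattern match before calling mlData:

    ```scala
    import org.apache.spark.ml.PredictionModel

    // collect keeps only pairs whose second element really is a
    // PredictionModel, silently dropping placeholders such as ("DT", ()).
    // The type parameters are erased at runtime, so only the
    // PredictionModel class itself is checked here.
    val typedModels: Seq[(String, PredictionModel[_, _])] =
      models.collect { case (name, m: PredictionModel[_, _]) => (name, m) }

    val mlTrainData = mlData(transferData, "value", typedModels).drop("row_id")
    ```

    Note that silently dropping entries hides the underlying bug (dtModel being Unit), so the type-ascription approach is usually preferable.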

    1 Answer  |  6 years ago
  •  0
  •  shuvomiah  ·  6 years ago

    Can you replace this code

    val models = Seq(("NB", nbModel), ("DT", dtModel), ("RF", rfModel), ("GBM",gbmModel))
    

    with

    val models = Seq(("NB", nbModel), ("DT", null: org.apache.spark.ml.classification.DecisionTreeClassificationModel), ("RF", rfModel), ("GBM", gbmModel))
    

    What I'm pointing out is that dtModel was assigned (), whose type is Unit. The element type inferred for the whole Seq therefore becomes the common superclass of the model types and Unit, which is Any. You need to make sure that dtModel actually has the decision-tree model type rather than Unit. A null of that type is fine as long as you handle the null case downstream; an empty decision-tree model would also work.
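
    The inference behaviour described above can be reproduced in plain Scala, without Spark (ModelA and ModelB are hypothetical classes, purely for illustration):

    ```scala
    class ModelA
    class ModelB

    // Both second elements are reference types, so the inferred element
    // type stays a reference type (their least upper bound):
    val ok = Seq(("a", new ModelA), ("b", new ModelB))   // Seq[(String, AnyRef)]

    // Mixing in (), whose type is Unit (a value type), pushes the least
    // upper bound of the second components all the way up to Any:
    val bad = Seq(("a", new ModelA), ("b", ()))          // Seq[(String, Any)]
    ```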