代码之家  ›  专栏  ›  技术社区  ›  Jeff Saremi

Spark ML:决策树分类模型如何知道树的权重?

  •  0
  • Jeff Saremi  · 技术社区  · 6 年前

    我想从保存(或未保存)中获取树节点的权重 DecisionTreeClassificationModel . 但是我找不到任何类似的东西。

    模型在不知道这些的情况下如何实际执行分类。下面是保存在模型中的参数:

    {"class":"org.apache.spark.ml.classification.DecisionTreeClassificationModel"
    "timestamp":1551207582648
    "sparkVersion":"2.3.2"
    "uid":"DecisionTreeClassifier_4ffc94d20f1ddb29f282"
    "paramMap":{
    "cacheNodeIds":false
    "maxBins":32
    "minInstancesPerNode":1
    "predictionCol":"prediction"
    "minInfoGain":0.0
    "rawPredictionCol":"rawPrediction"
    "featuresCol":"features"
    "probabilityCol":"probability"
    "checkpointInterval":10
    "seed":956191873026065186
    "impurity":"gini"
    "maxMemoryInMB":256
    "maxDepth":2
    "labelCol":"indexed"
    }
    "numFeatures":1
    "numClasses":2
    }
    
    0 回复  |  直到 6 年前
        1
  •  1
  •   10465355 user11020637    6 年前

    通过使用 treeWeights :

    树重

    返回每棵树的权重

    版本1.5.0中的新功能。

    所以

    模型在不知道这些的情况下如何实际执行分类。

    权重是存储的,而不是作为元数据的一部分。如果你有 model

    from pyspark.ml.classification import RandomForestClassificationModel
    
    model: RandomForestClassificationModel = ...
    

    并保存到磁盘

    path: str = ...
    
    model.save(path)
    

    你会看到作者创造了 treesMetadata 子目录。如果加载内容(默认编写器使用拼花):

    import os
    
    trees_metadata = spark.read.parquet(os.path.join(path, "treesMetadata"))
    

    您将看到以下结构:

    trees_metadata.printSchema()
    
    root
     |-- treeID: integer (nullable = true)
     |-- metadata: string (nullable = true)
     |-- weights: double (nullable = true)
    

    哪里 weights 列包含由 treeID .

    类似地,节点数据存储在 data 子目录(参见示例 Extract and Visualize Model Trees from Sparklyr ):

    spark.read.parquet(os.path.join(path, "data")).printSchema()     
    
    root
     |-- id: integer (nullable = true)
     |-- prediction: double (nullable = true)
     |-- impurity: double (nullable = true)
     |-- impurityStats: array (nullable = true)
     |    |-- element: double (containsNull = true)
     |-- gain: double (nullable = true)
     |-- leftChild: integer (nullable = true)
     |-- rightChild: integer (nullable = true)
     |-- split: struct (nullable = true)
     |    |-- featureIndex: integer (nullable = true)
     |    |-- leftCategoriesOrThreshold: array (nullable = true)
     |    |    |-- element: double (containsNull = true)
     |    |-- numCategories: integer (nullable = true)
    

    等效信息(减去树数据和树权重)可用于 DecisionTreeClassificationModel 也。