梯度迭代樹(GBDT)演算法原理及Spark MLlib調用執行個體(Scala/Java/python)__編碼

來源:互聯網
上載者:User

梯度迭代樹

演算法簡介:

        梯度提升樹是一種決策樹的整合演算法。它通過反覆迭代訓練決策樹來最小化損失函數。決策樹類似,梯度提升樹具有可處理類別特徵、易擴充到多分類問題、不需特徵縮放等性質。Spark.ml通過使用現有decision tree工具來實現。

       梯度提升樹依次迭代訓練一系列的決策樹。在一次迭代中,演算法使用現有的整合來對每個訓練執行個體的類別進行預測,然後將預測結果與真實的標籤值進行比較。通過重新標記,來賦予預測結果不好的執行個體更高的權重。所以,在下次迭代中,決策樹會對先前的錯誤進行修正。

       對執行個體標籤進行重新標記的機制由損失函數來指定。每次迭代過程中,梯度迭代樹在訓練資料上進一步減少損失函數的值。spark.ml為分類問題提供一種損失函數(Log Loss),為迴歸問題提供兩種損失函數(平方誤差與絕對誤差)。

       Spark.ml支援二分類以及迴歸的隨機森林演算法,適用於連續特徵以及類別特徵。

*注意梯度提升樹目前不支援多分類問題。

參數:

checkpointInterval:

類型:整數型。

含義:設定檢查點間隔(>=1),或不設定檢查點(-1)。

featuresCol:

類型:字串型。

含義:特徵列名。

impurity:

類型:字串型。

含義:計算資訊增益的準則(不區分大小寫)。

labelCol:

類型:字串型。

含義:標籤列名。

lossType:

類型:字串型。

含義:損失函數類型。

maxBins:

類型:整數型。

含義:連續特徵離散化的最大數量,以及選擇每個節點分裂特徵的方式。

maxDepth:

類型:整數型。

含義:樹的最大深度(>=0)。

maxIter:

類型:整數型。

含義:迭代次數(>=0)。

minInfoGain:

類型:雙精確度型。

含義:分裂節點時所需最小資訊增益。

minInstancesPerNode:

類型:整數型。

含義:分裂後自節點最少包含的執行個體數量。

predictionCol:

類型:字串型。

含義:預測結果列名。

rawPredictionCol:

類型:字串型。

含義:原始預測。

seed:

類型:長整型。

含義:隨機種子。

subsamplingRate:

類型:雙精確度型。

含義:學習一棵決策樹使用的訓練資料比例,範圍[0,1]。

stepSize:

類型:雙精確度型。

含義:每次迭代最佳化步長。

樣本:

       下面的例子匯入LibSVM格式資料,並將之劃分為訓練資料和測試資料。使用第一部分資料進行訓練,剩下資料來測試。訓練之前我們使用了兩種資料預先處理方法來對特徵進行轉換,並且添加了中繼資料到DataFrame。

Scala:

import org.apache.spark.ml.Pipelineimport org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluatorimport org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}// Load and parse the data file, converting it to a DataFrame.val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")// Index labels, adding metadata to the label column.// Fit on whole dataset to include all labels in index.val labelIndexer = new StringIndexer()  .setInputCol("label")  .setOutputCol("indexedLabel")  .fit(data)// Automatically identify categorical features, and index them.// Set maxCategories so features with > 4 distinct values are treated as continuous.val featureIndexer = new VectorIndexer()  .setInputCol("features")  .setOutputCol("indexedFeatures")  .setMaxCategories(4)  .fit(data)// Split the data into training and test sets (30% held out for testing).val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))// Train a GBT model.val gbt = new GBTClassifier()  .setLabelCol("indexedLabel")  .setFeaturesCol("indexedFeatures")  .setMaxIter(10)// Convert indexed labels back to original labels.val labelConverter = new IndexToString()  .setInputCol("prediction")  .setOutputCol("predictedLabel")  .setLabels(labelIndexer.labels)// Chain indexers and GBT in a Pipeline.val pipeline = new Pipeline()  .setStages(Array(labelIndexer, featureIndexer, gbt, labelConverter))// Train model. This also runs the indexers.val model = pipeline.fit(trainingData)// Make predictions.val predictions = model.transform(testData)// Select example rows to display.predictions.select("predictedLabel", "label", "features").show(5)// Select (prediction, true label) and compute test error.val evaluator = new MulticlassClassificationEvaluator()  .setLabelCol("indexedLabel")  .setPredictionCol("prediction")  .setMetricName("accuracy")val accuracy = evaluator.evaluate(predictions)println("Test Error = " + (1.0 - accuracy))val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel]println("Learned classification GBT model:\n" + gbtModel.toDebugString)

Java:

import org.apache.spark.ml.Pipeline;import org.apache.spark.ml.PipelineModel;import org.apache.spark.ml.PipelineStage;import org.apache.spark.ml.classification.GBTClassificationModel;import org.apache.spark.ml.classification.GBTClassifier;import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;import org.apache.spark.ml.feature.*;import org.apache.spark.sql.Dataset;import org.apache.spark.sql.Row;import org.apache.spark.sql.SparkSession;// Load and parse the data file, converting it to a DataFrame.Dataset<Row> data = spark  .read()  .format("libsvm")  .load("data/mllib/sample_libsvm_data.txt");// Index labels, adding metadata to the label column.// Fit on whole dataset to include all labels in index.StringIndexerModel labelIndexer = new StringIndexer()  .setInputCol("label")  .setOutputCol("indexedLabel")  .fit(data);// Automatically identify categorical features, and index them.// Set maxCategories so features with > 4 distinct values are treated as continuous.VectorIndexerModel featureIndexer = new VectorIndexer()  .setInputCol("features")  .setOutputCol("indexedFeatures")  .setMaxCategories(4)  .fit(data);// Split the data into training and test sets (30% held out for testing)Dataset<Row>[] splits = data.randomSplit(new double[] {0.7, 0.3});Dataset<Row> trainingData = splits[0];Dataset<Row> testData = splits[1];// Train a GBT model.GBTClassifier gbt = new GBTClassifier()  .setLabelCol("indexedLabel")  .setFeaturesCol("indexedFeatures")  .setMaxIter(10);// Convert indexed labels back to original labels.IndexToString labelConverter = new IndexToString()  .setInputCol("prediction")  .setOutputCol("predictedLabel")  .setLabels(labelIndexer.labels());// Chain indexers and GBT in a Pipeline.Pipeline pipeline = new Pipeline()  .setStages(new PipelineStage[] {labelIndexer, featureIndexer, gbt, labelConverter});// Train model. This also runs the indexers.PipelineModel model = pipeline.fit(trainingData);// Make predictions.Dataset<Row> predictions = model.transform(testData);// Select example rows to display.predictions.select("predictedLabel", "label", "features").show(5);// Select (prediction, true label) and compute test error.MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()  .setLabelCol("indexedLabel")  .setPredictionCol("prediction")  .setMetricName("accuracy");double accuracy = evaluator.evaluate(predictions);System.out.println("Test Error = " + (1.0 - accuracy));GBTClassificationModel gbtModel = (GBTClassificationModel)(model.stages()[2]);System.out.println("Learned classification GBT model:\n" + gbtModel.toDebugString());

Python:

from pyspark.ml import Pipelinefrom pyspark.ml.classification import GBTClassifierfrom pyspark.ml.feature import StringIndexer, VectorIndexerfrom pyspark.ml.evaluation import MulticlassClassificationEvaluator# Load and parse the data file, converting it to a DataFrame.data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")# Index labels, adding metadata to the label column.# Fit on whole dataset to include all labels in index.labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data)# Automatically identify categorical features, and index them.# Set maxCategories so features with > 4 distinct values are treated as continuous.featureIndexer =\    VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data)# Split the data into training and test sets (30% held out for testing)(trainingData, testData) = data.randomSplit([0.7, 0.3])# Train a GBT model.gbt = GBTClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures", maxIter=10)# Chain indexers and GBT in a Pipelinepipeline = Pipeline(stages=[labelIndexer, featureIndexer, gbt])# Train model.  This also runs the indexers.model = pipeline.fit(trainingData)# Make predictions.predictions = model.transform(testData)# Select example rows to display.predictions.select("prediction", "indexedLabel", "features").show(5)# Select (prediction, true label) and compute test errorevaluator = MulticlassClassificationEvaluator(    labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy")accuracy = evaluator.evaluate(predictions)print("Test Error = %g" % (1.0 - accuracy))gbtModel = model.stages[2]print(gbtModel)  # summary only


相關文章

聯繫我們

該頁面正文內容均來源於網絡整理,並不代表阿里雲官方的觀點,該頁面所提到的產品和服務也與阿里云無關,如果該頁面內容對您造成了困擾,歡迎寫郵件給我們,收到郵件我們將在5個工作日內處理。

如果您發現本社區中有涉嫌抄襲的內容,歡迎發送郵件至: info-contact@alibabacloud.com 進行舉報並提供相關證據,工作人員會在 5 個工作天內聯絡您,一經查實,本站將立刻刪除涉嫌侵權內容。

A Free Trial That Lets You Build Big!

Start building with 50+ products and up to 12 months usage for Elastic Compute Service

  • Sales Support

    1 on 1 presale consultation

  • After-Sales Support

    24/7 Technical Support 6 Free Tickets per Quarter Faster Response

  • Alibaba Cloud offers highly flexible support services tailored to meet your exact needs.