**package** mllib.tree
**import** org.apache.log4j.{Level, Logger}
**import** org.apache.spark.mllib.evaluation.MulticlassMetrics
**import** org.apache.spark.mllib.linalg.Vectors
**import** org.apache.spark.mllib.regression.LabeledPoint
**import** org.apache.spark.mllib.tree.{RandomForest, DecisionTree}
**import** org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel}
**import** org.apache.spark.rdd.RDD
**import** org.apache.spark.{SparkContext, SparkConf}
_/**_ _* Created by_ _汪本成_ _on 2016/7/18._ _*/_ **object** randomForest {
//屏蔽不必要的日志显示在终端上
// Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
// Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF)
**var** _beg_ = System.currentTimeMillis()
//创建入口对象
**val** _conf_ = **new** SparkConf().setAppName("rndomForest").setMaster("local")
**val** _sc_ = **new** SparkContext( _conf_ )
**val** _HDFS_COVDATA_PATH_ = "hdfs://192.168.43.150:9000/user/spark/@R_220_4502@arning/mllib/covtype.data"
**val** _rawData_ = _sc_.textFile( _HDFS_COVDATA_PATH_ )
//设置LabeledPoint格式
**val** _data_ = _rawData_.map{
line =>
**val** values = line.split(",").map(_.todouble)
// init返回除最后一个值之外的所有值,最后一列是目标
**val** FeatureVector = Vectors.dense(values.init)
//决策树要求(目标变量)label从0开始,所以要减一
**val** label = values.last - 1
LabeledPoint(label, FeatureVector)
}
//分成训练集(80%),交叉验证集(10%),测试集(10%)
**val** Array( _trainData_ , _cvData_ , _testData_ ) = _data_.randomSplit(Array(0.8, 0.1, 0.1))
_trainData_.cache()
_cvData_.cache()
_testData_.cache()
//新建随机森林
**val** _numClass_ = 7 //分类数量
**val** _categoricalFeaturesInfo_ = _Map_ [Int, Int](10 -> 4, 11-> 40) //用map存储类别(离散)特征及每个类特征对应值(类别)的数量
**val** _impurity_ = "entropy" //纯度计算方法,用于信息增益的计算
**val** _number_ = 20 //构建树的数量
**val** _maxDepth_ = 4 //树的最大高度
**val** _maxBins_ = 100 // 用于分裂特征的最大划分数量
//训练分类决策树模型
**val** _model_ = RandomForest.trainClassifier( _trainData_ , _numClass_ , _categoricalFeaturesInfo_ , _number_ , "auto", _impurity_ , _maxDepth_ , _maxBins_ )
**val** _metrics_ = getMetrics( _model_ , _cvData_ )
//计算精确度(样本比例)
**val** _precision_ = _metrics_. _precision_ __ //计算每个样本的准确度(召回率)
**val** _recall_ = (0 until 7).map( //DecisionTreeModel模型的类别号从0开始
cat => ( _metrics_.precision(cat), _metrics_.recall(cat))
)
**val** _end_ = System.currentTimeMillis()
//耗时时间
**var** _castTime_ = _end_ - _beg_ ____**def** main(args: Array[String]) {
println("========================================================================================")
//精确度(样本比例)
println("精确度: " + _precision_ )
println("========================================================================================")
//准确度(召回率)
println("准确度: ")
_recall_.foreach(println)
println("========================================================================================")
println(" 运行程序耗时: " + _castTime_ /1000 + "s")
}
_/**_ _*_ _在训练集构建RandomForestModel_ ___*_ ** _@param model_** ** __**_*_ ** _@param data_** ** __**_*_ ** _@return_** ** __**_*/_ __**def** getMetrics(model: RandomForestModel, data: RDD[LabeledPoint]): MulticlassMetrics = {
**val** predictionsAndLabels = data.map(example => (model.predict(example.features), example.label))
**new** MulticlassMetrics(predictionsAndLabels)
}
_/**_ _*_ _按照类别在训练集出现的比例预测类别_ ___*_ _*_ ** _@param data_** ** __**_*_ ** _@return_** ** __**_*/_ __**def** classprobabilities(data: RDD[LabeledPoint]): Array[Double] = {
//计算数据中每个类别的样本数(类别, 样本数)
**val** countsByCategory = data.map(_.label).countByValue()
//对类别的样本数进行排序并取出样本数
**val** counts = countsByCategory.toArray.sortBy(_._1).map(_._2)
counts.map(_.todouble / counts.sum)
}
}
[/code]
---
运行结果如下
16/07/18 23:30:11 INFO DAGScheduler: ResultStage 17 (collectAsMap at
MulticlassMetrics.scala:54) finished in 0.003 s
16/07/18 23:30:11 INFO TaskSchedulerImpl: Removed TaskSet 17.0, whose tasks
have all completed, from pool
16/07/18 23:30:11 INFO DAGScheduler: Job 9 finished: collectAsMap at
MulticlassMetrics.scala:54, took 0.197033 s
========================================================================================
精确度: 0.5307208847065288
========================================================================================
准确度:
(0.8087885985748219,0.03206818609907704)
(0.5233824352041768,0.9884494841004331)
(0.5730994152046783,0.6121521862578081)
(0.0,0.0)
(0.0,0.0)
(0.0,0.0)
(0.0,0.0)
========================================================================================
运行程序耗时: 44s
16/07/18 23:30:12 INFO SparkContext: Invoking stop() from shutdown hook
16/07/18 23:30:12 INFO SparkUI: Stopped Spark web UI at
http://192.168.43.1:4040
16/07/18 23:30:12 INFO MapOutputTrackerMasterEndpoint:
MapOutputTrackerMasterEndpoint stopped!
16/07/18 23:30:12 INFO MemoryStore: MemoryStore cleared
16/07/18 23:30:12 INFO BlockManager: BlockManager stopped
16/07/18 23:30:12 INFO BlockManagerMaster: BlockManagerMaster stopped
16/07/18 23:30:12 INFO
OutputCommitCoordinator$OutputCommitCoordinatorEndpoint:
OutputCommitCoordinator stopped!
16/07/18 23:30:12 INFO SparkContext: Successfully stopped SparkContext
16/07/18 23:30:12 INFO ShutdownHookManager: Shutdown hook called
16/07/18 23:30:12 INFO ShutdownHookManager: Deleting directory
C:\Users\Administrator\AppData\Local\Temp\spark-a375456b-af35-40aa-8416-ac6b61b39b19
16/07/18 23:30:12 INFO RemoteActorRefProvider$RemotingTerminator: Shutting
down remote daemon.
16/07/18 23:30:12 INFO RemoteActorRefProvider$RemotingTerminator: Remote
daemon shut down; proceeding with flushing remote transports.
16/07/18 23:30:12 INFO RemoteActorRefProvider$RemotingTerminator: Remoting
shut down.
Process finished with exit code 0
---

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 [email protected] 举报,一经查实,本站将立刻删除。