package cn.itcast.tags.ml.classification
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel, VectorIndexer, VectorIndexerModel}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}
import org.apache.spark.sql.{DataFrame, SparkSession}
/**
 * Random-forest multiclass classification demo:
 * loads libsvm data, indexes label/features, tunes a RandomForestClassifier
 * via cross-validated grid search, and reports test-set accuracy.
 */
object RfModel {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName(this.getClass.getSimpleName.stripSuffix("$"))
      .master("local[4]")
      .getOrCreate() // FIX: was `getorCreate()` — no such method on Builder
    import spark.implicits._

    // 1. Load the dataset (libsvm format: label + sparse feature vector).
    val dataframe: DataFrame = spark.read
      .format("libsvm")
      .load("datas/ship/total001.txt")

    // Split into training (80%) and testing (20%) sets.
    val Array(trainingDF, testingDF) = dataframe.randomSplit(Array(0.8, 0.2))

    // 2. Feature engineering.
    // 2.1. Index the raw label column into [0, K-1]; fit on the FULL dataset
    //      so every label value is known to the indexer.
    val labelIndexer: StringIndexerModel = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("index_label") // FIX: was `setoutputCol` — wrong casing
      .fit(dataframe)
    val df1: DataFrame = labelIndexer.transform(dataframe)

    // 2.2. Automatically treat feature columns with <= maxCategories distinct
    //      values as categorical and index them; others stay continuous.
    val featureIndexer: VectorIndexerModel = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("index_features") // FIX: was `setoutputCol` — wrong casing
      .setMaxCategories(4) // max distinct values for a column to count as categorical
      .fit(df1)

    // 3. The estimator: random forest over the indexed label/features.
    val rf = new RandomForestClassifier()
      .setLabelCol("index_label")
      .setFeaturesCol("index_features")

    // 4. Pipeline: each stage is either an Estimator or an (already fitted)
    //    Transformer; fitting the pipeline applies them in order.
    val pipeline: Pipeline = new Pipeline()
      .setStages(
        Array(labelIndexer, featureIndexer, rf)
      )

    // Hyper-parameter grid explored by the cross validator.
    val paramGrid: Array[ParamMap] = new ParamGridBuilder()
      .addGrid(rf.maxDepth, Array(5, 10, 15, 20, 25, 30))
      .addGrid(rf.impurity, Array("gini", "entropy"))
      .addGrid(rf.maxBins, Array(32, 64))
      .addGrid(rf.numTrees, Array(5, 10, 20, 30, 40, 50))
      .addGrid(rf.featureSubsetStrategy, Array("auto", "sqrt"))
      .build()

    // Multiclass evaluator.
    // Metric name, one of: f1, weightedPrecision, weightedRecall, accuracy.
    // FIX: that note was a bare (non-comment) line in the original — compile error.
    val evaluator: MulticlassClassificationEvaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("index_label")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")

    // K-fold cross validation over the pipeline x parameter grid.
    val validator: CrossValidator = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(4)

    // Train: selects the best model over the grid via 4-fold CV.
    println("正在训练模型...")
    val model: CrossValidatorModel = validator.fit(trainingDF)
    println(model.toString())

    // 5. Evaluate on the held-out test split.
    val predictionDF: DataFrame = model.transform(testingDF)
    predictionDF.printSchema()
    predictionDF
      .select($"probability", $"prediction", $"index_label")
      .show(100, truncate = false)
    val accuracy: Double = evaluator.evaluate(predictionDF)
    println(s"Accuracy = $accuracy")

    spark.stop()
  }
}
// 版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 [email protected] 举报,一经查实,本站将立刻删除。