微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

[Spark][spark_ml]#2_分类算法

/**
 * Trains and evaluates a decision-tree classifier on the iris data set.
 *
 * The program reads `iris.data` as CSV, converts the four measurement columns
 * to doubles, encodes the species name in the fifth column as an integer label,
 * assembles the measurements into a feature vector, splits the rows 80/20 into
 * training and test sets, and prints the accuracy measured on the test set.
 */
object Main {
  /**
   * Program entry point.
   *
   * @param args command-line arguments (not used)
   * @throws IllegalArgumentException if a row carries a species name other than
   *                                  the three known iris species
   */
  def main(args: Array[String]): Unit = {

    val conf = new SparkConf().setMaster("local").setAppName("iris")
    val spark = SparkSession.builder().config(conf).getOrCreate()
    try {
      spark.sparkContext.setLogLevel("WARN") // log level

      val file = spark.read.format("csv").load("iris.data")
      //file.show()

      import spark.implicits._
      val random = new Random()
      val data = file.map(row => {
        // Encode the species name as the numeric label used for training.
        val label = row.getString(4) match {
          case "Iris-setosa" => 0
          case "Iris-versicolor" => 1
          case "Iris-virginica" => 2
          case other =>
            // Fail with the offending value instead of an opaque MatchError.
            throw new IllegalArgumentException(s"Unknown iris species: $other")
        }

        (row.getString(0).toDouble,
          row.getString(1).toDouble,
          row.getString(2).toDouble,
          row.getString(3).toDouble,
          label,
          // Random key; sorting on it below shuffles the rows.
          random.nextDouble())
      }).toDF("_c0", "_c1", "_c2", "_c3", "label", "rand").sort("rand") //.where("label = 1 or label = 0")

      // Combine the four measurement columns into the single "features" vector column.
      val assembler = new VectorAssembler()
        .setInputCols(Array("_c0", "_c1", "_c2", "_c3"))
        .setOutputCol("features")

      val dataset = assembler.transform(data)
      val Array(train, test) = dataset.randomSplit(Array(0.8, 0.2))

      /*
      //bayes
        val bayes = new NaiveBayes().setFeaturesCol("features").setLabelCol("label")
        val model = bayes.fit(train) // fit on the training set
        model.transform(test).show() // run on the test set to inspect the result
        */
      //SVM
      /*
      val svm = new LinearSVC().setMaxIter(20).setRegParam(0.1)
        .setFeaturesCol("features").setLabelCol("label")
      val model = svm.fit(train)
      model.transform(test).show()
      */

      val dt = new DecisionTreeClassifier().setFeaturesCol("features").setLabelCol("label")
      val model = dt.fit(train)
      val result = model.transform(test)
      result.show()
      val evaluator = new MulticlassClassificationEvaluator()
        .setLabelCol("label")
        .setPredictionCol("prediction")
        .setMetricName("accuracy")
      val accuracy = evaluator.evaluate(result)
      println(s"""accuracy is $accuracy""")
    } finally {
      // Release the session even when loading, training or evaluation fails.
      spark.stop()
    }
  }
}

/**
 * Trains and evaluates a decision-tree classifier that predicts a category
 * from a (height, weight) pair.
 *
 * Samples are read from `male.txt` (category 1) and `female.txt` (category 2).
 * Every bracketed, comma-separated pair found in those files, once spaces are
 * removed, becomes one sample. The combined samples are shuffled, split 80/20
 * into training and test sets, and the accuracy on the test set is printed.
 */
object Main {
  /**
   * Program entry point.
   *
   * @param args command-line arguments (not used)
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("gender").setMaster("local")
    val session = SparkSession.builder().config(conf).getOrCreate()
    try {
      val sc = session.sparkContext

      // Loads one file and yields (first value, second value, category, random key) tuples.
      val pattern = (filename: String, category: Int) => {
        // Non-greedy match of one "[...]" group.
        val patternString = "\\[(.*?)\\]".r
        val rand = new Random()
        sc.textFile(filename)
          .flatMap(text => patternString.findAllIn(text.replace(" ", "")))
          .map(text => {
            // Strip the surrounding brackets, then split the pair on the comma.
            val pairwise = text.substring(1, text.length - 1).split(",")
            (pairwise(0).toDouble, pairwise(1).toDouble, category, rand.nextDouble())
          })
      }
      val male = pattern("male.txt", 1)
      val female = pattern("female.txt", 2)

      val maleDF = session
        .createDataFrame(male)
        .toDF("height", "weight", "category", "rand")
      val femaleDF = session
        .createDataFrame(female)
        .toDF("height", "weight", "category", "rand")
      // Sorting on the random key interleaves the two categories.
      val dataset = maleDF.union(femaleDF).sort("rand")
      val assembler = new VectorAssembler()
        .setInputCols(Array("height", "weight"))
        .setOutputCol("features")

      val transformedDataset = assembler.transform(dataset)
      transformedDataset.show()
      val Array(train, test) = transformedDataset.randomSplit(Array(0.8, 0.2))

      val classifier = new DecisionTreeClassifier()
        .setFeaturesCol("features")
        .setLabelCol("category")
      val model = classifier.fit(train)
      val result = model.transform(test)
      result.show()

      val evaluator = new MulticlassClassificationEvaluator()
        .setLabelCol("category")
        .setPredictionCol("prediction")
        .setMetricName("accuracy")
      val accuracy = evaluator.evaluate(result)
      println(s"""accuracy is $accuracy""")
    } finally {
      // Release the session even when loading, training or evaluation fails.
      session.stop()
    }
  }
}

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 [email protected] 举报,一经查实,本站将立刻删除。

相关推荐