之前尝试使用Spark MLlib 做机器学习,发现不是非常方便,也可能是在使用习惯上面不太适应(相对 python sklearn).
今天尝试使用Spark MLlib 针对Iris数据做一次实践,之后会尝试写一个包装类,将这些步骤简化。
0. 数据准备:
原始的数据以及相应的说明可以到[这里] 下载。 我在这基础之上,增加了header信息。 下载:https://pan.baidu.com/s/1c2d0hpA
如果是可以直接从NFS或者HDFS之类的文件服务里面读csv,会比较方便, 参考下面的python代码:
1 2 3 4 5 6 7 |
from pyspark.sql import SQLContext sqlContext = SQLContext(sc) df = sqlContext.read.format('com.databricks.spark.csv') .options(header='true', inferschema='true') .load('iris.csv') # Displays the content of the DataFrame to stdout df.show() |
在我的环境之中,如果你跟我一样只能从本地读取,就比较麻烦了。 可以参考下面的Java代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
// 首先读iris 数据 // 因为是从本地读取Sample数据,所以比较麻烦一些~ List<String> lines = FileUtils.readLines(new File("E:\\DataSet\\iris_data.txt"), "UTF-8"); List<Row> data = Lists.newArrayList(); String[] headers = lines.get(0).split(","); for(String line : lines.subList(1, lines.size())) { // 前面几个都是double String[] cells = line.split(","); Object[] values = new Object[cells.length]; for(int i = 0; i < cells.length - 1; i++) { values[i] = Double.parseDouble(cells[i]); } values[cells.length - 1] = cells[cells.length - 1]; data.add(RowFactory.create(values)); } // 创建Dataset StructField[] fields = new StructField[headers.length]; for(int i = 0; i < headers.length - 1; i++) { fields[i] = new StructField(headers[i], DataTypes.DoubleType, false, Metadata.empty()); } fields[headers.length - 1] = new StructField(headers[headers.length - 1], DataTypes.StringType, false, Metadata.empty()); StructType schema = new StructType(fields); Dataset<Row> df = ss.createDataFrame(data, schema); df.show(); |
df.show() 的结果如下所示:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
+------------+-----------+------------+-----------+-----------+ |sepal_length|sepal_width|petal_length|petal_width| classes| +------------+-----------+------------+-----------+-----------+ | 5.1| 3.5| 1.4| 0.2|Iris-setosa| | 4.9| 3.0| 1.4| 0.2|Iris-setosa| | 4.7| 3.2| 1.3| 0.2|Iris-setosa| | 4.6| 3.1| 1.5| 0.2|Iris-setosa| | 5.0| 3.6| 1.4| 0.2|Iris-setosa| | 5.4| 3.9| 1.7| 0.4|Iris-setosa| | 4.6| 3.4| 1.4| 0.3|Iris-setosa| | 5.0| 3.4| 1.5| 0.2|Iris-setosa| | 4.4| 2.9| 1.4| 0.2|Iris-setosa| | 4.9| 3.1| 1.5| 0.1|Iris-setosa| | 5.4| 3.7| 1.5| 0.2|Iris-setosa| | 4.8| 3.4| 1.6| 0.2|Iris-setosa| | 4.8| 3.0| 1.4| 0.1|Iris-setosa| | 4.3| 3.0| 1.1| 0.1|Iris-setosa| | 5.8| 4.0| 1.2| 0.2|Iris-setosa| | 5.7| 4.4| 1.5| 0.4|Iris-setosa| | 5.4| 3.9| 1.3| 0.4|Iris-setosa| | 5.1| 3.5| 1.4| 0.3|Iris-setosa| | 5.7| 3.8| 1.7| 0.3|Iris-setosa| | 5.1| 3.8| 1.5| 0.3|Iris-setosa| +------------+-----------+------------+-----------+-----------+ only showing top 20 rows |
1. 使用StringIndexer将字符型的label变成index
1 2 3 4 5 6 |
// 使用StringIndexer将字符型的label变成index StringIndexer indexer = new StringIndexer() .setInputCol("classes") .setOutputCol("classesIndex"); Dataset<Row> indexed = indexer.fit(df).transform(df); indexed.show(); |
这就是Spark第一个不太方便的地方:不能直接处理String类型的label
处理完成之后,show的结果如下:
1 2 3 4 5 6 7 |
+------------+-----------+------------+-----------+-----------+------------+ |sepal_length|sepal_width|petal_length|petal_width| classes|classesIndex| +------------+-----------+------------+-----------+-----------+------------+ | 5.1| 3.5| 1.4| 0.2|Iris-setosa| 2.0| | 4.9| 3.0| 1.4| 0.2|Iris-setosa| 2.0| | 4.7| 3.2| 1.3| 0.2|Iris-setosa| 2.0| | 4.6| 3.1| 1.5| 0.2|Iris-setosa| 2.0| |
如果需要将Index变回来,那么需要用到IndexToString:
1 2 3 4 5 6 |
IndexToString converter = new IndexToString() .setInputCol("classesIndex") .setOutputCol("originalClasses"); Dataset<Row> converted = converter.transform(indexed); converted.show(); |
2. 数据模型的创建与验证
在Spark的机器学习之中,有一个很容易让初学者混淆的问题:ml跟mllib有什么区别?
简单的说:
- spark.mllib中的算法接口是基于RDDs的
- spark.ml中的算法接口是基于DataFrames / Dataset 的
但是根据作者自己的经验,如果你处理的是CSV格式的数据,除非你现行转换成Libsvm的格式,否则后期处理非常非常的麻烦。具体的处理方式将在后续的文章之中尝试,敬请关注。
尝试使用RDD的方式(mllib)进行分类
使用RDD-based API, 核心就是整合出一个LabeledPoint.
1 2 3 4 5 6 7 8 9 10 11 12 |
// 将Row --> LabeledPoint JavaRDD<LabeledPoint> rowRDD = indexed.toJavaRDD().map(new Function<Row, LabeledPoint>() { @Override public LabeledPoint call(Row row) throws Exception { double[] features = new double[4]; for(int i = 0; i < 4; i++) { features[i] = row.getDouble(i); } LabeledPoint point = new LabeledPoint(row.getDouble(6), Vectors.dense(features)); return point; } }); |
当你整合出这个LabeledPoint RDD之后,就直接copy官网代码即可。
就不在这里贴代码了。 比如,如果你采用的是RandomForest, 可以请参考:https://spark.apache.org/docs/2.0.2/mllib-ensembles.html#random-forests
比如在喂给RandomForest的时候,我们需要设置好几个参数:
- numClasses
需要提前设置好有那几个类 - numTrees
有几棵树。 - categoricalFeaturesInfo
每一个feature有几个类别? - featureSubsetStrategy
auto: 默认参数。让算法自己决定,每颗树使用几条数据 - impurity / maxDepth / maxBins / seed
略
3. 检查预测结果
官网使用的是比较简单粗暴的比较方式:
1 2 3 4 5 6 7 8 9 |
Double testErr = 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() { @Override public Boolean call(Tuple2<Double, Double> pl) { return !pl._1().equals(pl._2()); } }).count() / testData.count(); System.out.println("Test Error: " + testErr); System.out.println("Learned classification forest model:\n" + model.toDebugString()); |
这样确实能看到结果,但是如果想查看比如classification_report等等,Spark自带的类能提供一些比较方便的东西。
在之前官网的基础之上,需要修改predictionAndLabel的数据类型:
1 2 3 4 5 6 7 8 |
JavaRDD<Tuple2<Object, Object>> predictionAndLabels = testData.map( new Function<LabeledPoint, Tuple2<Object, Object>>() { public Tuple2<Object, Object> call(LabeledPoint p) { Double prediction = model.predict(p.features()); return new Tuple2<Object, Object>(prediction, p.label()); } } ); |
有两个要注意的地方:
- 类型需要是Object, 之前的Double不行
- 从之前的JavaPairRDD 变成 JavaRDD<Tuple2>
做好这个准备之后,我们就可以调用Metrics相关的工具类了:
1 2 3 4 5 6 7 8 9 10 |
// 多分类: MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); // 二分类 BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd()); // 多标签分类:(相对来说少遇到) MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd()); |
一些简单实用的Sample:
1 2 3 4 5 6 7 8 9 10 |
System.out.println("准确率:" + metrics.accuracy()); // 准确率:0.9583333333333334 Matrix confusion = metrics.confusionMatrix(); System.out.println("混淆矩阵: \n" + confusion); // 混淆矩阵: // 15.0 1.0 0.0 // 0.0 18.0 0.0 // 0.0 1.0 13.0 |
至此,一个基本的流程算是走通了,但是我们可以看到,在这整个过程之中有一些很不方便的事情:
- 读取本地的CSV非常不方便
- 不支持String类型的label,需要使用StringIndexer
即使读取原始数据是数值类型,也需要使用StringIndexer, 因为除非使用spark-csv并且设置了inferSchema=true, 否则也自动被认为是String类型的值 - 在使用RandomForest的时候,好几个参数需要设置
我感觉有的是应该可以自动设置的。比如:numClasses 、 categoricalFeaturesInfo - 在最后检查结果的时候,比较麻烦:
- 需要自己选择是multi-class or binary-class.
- 混淆矩阵缺少label
- 缺少类似sklearn.classification_report 那种简单明了的report
后面会陆续针对这些问题,做一些wrapper
本文为原创文章,转载请注明出处
原文链接:http://www.flyml.net/2017/01/09/spark-2-0-ml-practice-iris-test/

文章评论