Linear Support Vector Machine (SVM) and Decision Tree Classifiers with Spark and Scala
Decision trees and Support Vector Machines are two supervised machine learning algorithms used for both classification and regression.
Support Vector Machines vs Maximal Margin Classifier
Support Vector Machine is a supervised machine learning algorithm used for both classification and regression, though it is mostly used for classification problems. SVM evolved from a simple classifier called the maximal margin classifier, which separates the classes with a linear boundary.
Because such a classifier requires the classes to be separable by a linear boundary, it cannot be applied to most real datasets, where the classes are not linearly separable. Hence an improved version called the Support Vector Machine was introduced.
Support Vector Machines
Suppose a dataset has p dimensions (features). A hyperplane in this p-dimensional space is defined by the following equation:
$$\beta_0 + \beta_1 X_1 + \beta_2 X_2 + \dots + \beta_p X_p = 0$$
Such a hyperplane is called a separating hyperplane, and it forms the decision boundary. An observation is classified by the sign of the left-hand side: if the value is greater than zero the observation falls on one side of the hyperplane, and if it is less than zero it falls on the other side, as shown in the figure.
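The sign rule can be sketched in plain Scala, assuming the hyperplane has the usual form beta0 + beta1*x1 + ... + betap*xp = 0. The object name `HyperplaneSketch`, the functions `decisionValue` and `classify`, and the coefficients are illustrative assumptions, not part of Spark's API.

```scala
object HyperplaneSketch {
  // Left-hand side of the hyperplane equation:
  // f(x) = beta0 + beta1*x1 + ... + betap*xp
  def decisionValue(beta0: Double, beta: Array[Double], x: Array[Double]): Double = {
    require(beta.length == x.length, "beta and x must have the same dimension")
    beta0 + beta.zip(x).map { case (b, xi) => b * xi }.sum
  }

  // Returns 1 on one side of the hyperplane, -1 on the other,
  // and 0 for a point lying exactly on it.
  def classify(beta0: Double, beta: Array[Double], x: Array[Double]): Int =
    math.signum(decisionValue(beta0, beta, x)).toInt

  def main(args: Array[String]): Unit = {
    // Hypothetical hyperplane in two dimensions: -1 + x1 + x2 = 0
    val beta0 = -1.0
    val beta = Array(1.0, 1.0)
    println(classify(beta0, beta, Array(2.0, 2.0))) // point on the positive side
    println(classify(beta0, beta, Array(0.0, 0.0))) // point on the negative side
  }
}
```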
Kernel functions
SVM uses a kernel to handle non-linearity in the dataset. Kernel methods map the data into a higher-dimensional space.
Some of the commonly used kernel functions are:
Linear kernel: The most basic type of kernel, which allows only lines or hyperplanes as decision boundaries.
Polynomial kernel: A kernel that can address non-linearity up to the order of the polynomial.
Radial basis function (RBF) kernel: A good choice when you are not sure which kernel to use. It does not work well when the number of features is huge.
Sigmoid kernel: The sigmoid function has its roots in neural networks. An SVM with a sigmoid kernel is similar to a two-layer perceptron neural network.
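To make the four kernels concrete, here is a minimal plain-Scala sketch using their common textbook forms. The object name `KernelSketch`, the parameter names `gamma`, `coef0` and `degree`, and their default values are assumptions chosen for illustration; they are not taken from Spark.

```scala
object KernelSketch {
  type Vec = Array[Double]

  private def dot(x: Vec, y: Vec): Double = {
    require(x.length == y.length, "vectors must have the same dimension")
    x.zip(y).map { case (a, b) => a * b }.sum
  }

  // Linear: K(x, y) = x . y
  def linear(x: Vec, y: Vec): Double = dot(x, y)

  // Polynomial: K(x, y) = (gamma * x . y + coef0)^degree
  def polynomial(x: Vec, y: Vec, degree: Int,
                 gamma: Double = 1.0, coef0: Double = 1.0): Double =
    math.pow(gamma * dot(x, y) + coef0, degree)

  // RBF: K(x, y) = exp(-gamma * ||x - y||^2)
  def rbf(x: Vec, y: Vec, gamma: Double = 1.0): Double = {
    require(x.length == y.length, "vectors must have the same dimension")
    val squaredDistance = x.zip(y).map { case (a, b) => (a - b) * (a - b) }.sum
    math.exp(-gamma * squaredDistance)
  }

  // Sigmoid: K(x, y) = tanh(gamma * x . y + coef0)
  def sigmoid(x: Vec, y: Vec, gamma: Double = 1.0, coef0: Double = 0.0): Double =
    math.tanh(gamma * dot(x, y) + coef0)
}
```

Note that Spark MLlib's `SVMWithSGD` fits a linear SVM only, so the non-linear kernels here are for understanding rather than for direct use with that class.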
Training an SVM
- Preprocess the data: convert the categorical variables into numerical values (0, 1 or -1).
- Visualize the data to see which kernel function fits best; use cross-validation if visualization is not possible.
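The preprocessing step can be sketched in plain Scala as follows. The object `EncodingSketch` and the function `encode` are hypothetical helpers for illustration; in a Spark pipeline the same job is usually done by `StringIndexer`.

```scala
object EncodingSketch {
  // Assign each distinct category a numeric index in order of first appearance,
  // and return both the encoded values and the mapping that was used.
  def encode(values: Seq[String]): (Seq[Double], Map[String, Double]) = {
    val mapping = values.distinct.zipWithIndex
      .map { case (value, index) => value -> index.toDouble }
      .toMap
    (values.map(mapping), mapping)
  }

  def main(args: Array[String]): Unit = {
    val (encoded, mapping) = encode(Seq("yes", "no", "yes"))
    println(encoded)
    println(mapping)
  }
}
```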
Scala implementation
import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.{SparkConf, SparkContext}

object SVMExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local").setAppName("testSVM")
    val sc = new SparkContext(conf)

    /* Sample data as sparse vectors of size 5.
     * SVMWithSGD expects binary labels (0.0 or 1.0).
     * Two rows only illustrate the format; a real run needs
     * enough rows for both splits to be non-empty. */
    val rdd = sc.parallelize(Seq(
      LabeledPoint(1.0, Vectors.sparse(5, Array(1, 3), Array(9.0, 3.2))),
      LabeledPoint(0.0, Vectors.sparse(5, Array(0, 2, 4), Array(1.0, 2.0, 3.0)))
    ))

    /* Split the data into 60 percent training set and 40 percent test set */
    val splits = rdd.randomSplit(Array(0.6, 0.4), seed = 11L)
    /* Cache the training data */
    val training = splits(0).cache()
    val test = splits(1)

    /* Train the model on the training dataset */
    val model = SVMWithSGD.train(training, numIterations = 100)

    /* Clear the default threshold so that predict returns raw scores */
    model.clearThreshold()

    /* Calculate raw scores on the test set */
    val scoreAndLabels = test.map { point =>
      val score = model.predict(point.features)
      (score, point.label)
    }

    /* Evaluate the scores with the area under the ROC curve */
    val metrics = new BinaryClassificationMetrics(scoreAndLabels)
    println(s"Area under ROC = ${metrics.areaUnderROC()}")

    sc.stop()
  }
}
Decision Trees
A decision tree is an inverted tree, with the root at the top and the leaf nodes growing downwards.
Impurity measures
Impurity is calculated to find the best split. Most impurity measures are probability-based:
Probability of a class = number of observations of that class / total number of observations
Some commonly used impurity measures are:
Gini Index
Mainly intended for continuous measurements or features of the dataset.
The Gini index is 1 minus the sum of the squared class probabilities. If all observations of a response belong to a single class j, then the probability of that class (Pj) will be 1, as there is only one class, and Pj squared will also be 1. This makes the Gini index zero.
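As a quick check of this reasoning, here is a minimal plain-Scala sketch of the Gini index computed as 1 minus the sum of squared class probabilities. The object `GiniSketch` and the function `gini` are illustrative names, not Spark's internal implementation.

```scala
object GiniSketch {
  // Gini = 1 - sum over classes j of (Pj)^2,
  // where Pj = observations of class j / total observations.
  def gini(labels: Seq[String]): Double = {
    require(labels.nonEmpty, "labels must not be empty")
    val total = labels.size.toDouble
    val probabilities = labels.groupBy(identity).values.toSeq.map(_.size / total)
    1.0 - probabilities.map(p => p * p).sum
  }

  def main(args: Array[String]): Unit = {
    println(gini(Seq("a", "a", "a"))) // single class: pure node
    println(gini(Seq("a", "b")))      // two equally likely classes
  }
}
```

A node with a single class gives zero, and the value grows as the classes become more evenly mixed.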
Entropy
Intended for categorical values.
Entropy is the negative sum of Pj * log(Pj) over all classes. If all observations of a response belong to a single class, then the probability of that class (Pj) will be 1, and log(Pj) will be zero. This makes the entropy zero.
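The same check for entropy, as a minimal plain-Scala sketch. It assumes a base-2 logarithm, so the result is measured in bits; the object `EntropySketch` and the function `entropy` are illustrative names.

```scala
object EntropySketch {
  private def log2(x: Double): Double = math.log(x) / math.log(2.0)

  // Entropy = sum over classes j of -Pj * log2(Pj),
  // where Pj = observations of class j / total observations.
  // Only classes that actually occur are counted, so Pj is never zero.
  def entropy(labels: Seq[String]): Double = {
    require(labels.nonEmpty, "labels must not be empty")
    val total = labels.size.toDouble
    val probabilities = labels.groupBy(identity).values.toSeq.map(_.size / total)
    probabilities.map(p => -p * log2(p)).sum
  }

  def main(args: Array[String]): Unit = {
    println(entropy(Seq("a", "a", "a"))) // single class: pure node
    println(entropy(Seq("a", "b")))      // two equally likely classes
  }
}
```

In Spark, `DecisionTreeClassifier` lets you choose between the two measures with `setImpurity("gini")` or `setImpurity("entropy")`.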
Scala implementation
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorAssembler}
import org.apache.spark.ml.classification.DecisionTreeClassifier

object decisionTree {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().getOrCreate()
    val trainDF = spark.read.options(Map("header" -> "true",
      "inferSchema" -> "true")).csv("/path/to/file")

    /* Define the decision tree pipeline.
     * StringIndexer maps each label value to a numeric index. */
    val lblIdx = new StringIndexer().
      setInputCol("Label").
      setOutputCol("indexedLabel")

    /* Create the list of labels used to decode predictions */
    val labels = lblIdx.fit(trainDF).labels

    /* Define the text column indexing stage */
    val fIdx = new StringIndexer().
      setInputCol("Text").
      setOutputCol("indexedText")

    /* Vector assembler: combine the indexed columns into one feature vector */
    val va = new VectorAssembler().
      setInputCols(Array("indexedText")).
      setOutputCol("features")

    /* Define the decision tree classifier; set the label and feature columns */
    val dt = new DecisionTreeClassifier().
      setLabelCol("indexedLabel").
      setFeaturesCol("features")

    /* Label converter to convert the prediction index back to a string */
    val lc = new IndexToString().
      setInputCol("prediction").
      setOutputCol("predictedLabel").
      setLabels(labels)

    /* Chain the stages together to form a pipeline */
    val dt_pipeline = new Pipeline().setStages(
      Array(lblIdx, fIdx, va, dt, lc))

    /* Fit the pipeline on the training data and apply it to the same data */
    val resultDF = dt_pipeline.fit(trainDF).transform(trainDF)
    resultDF.select("Text", "Label", "features", "prediction", "predictedLabel").show()
  }
}
The final result will look like this:
+----+-----+--------+----------+--------------+
|Text|Label|features|prediction|predictedLabel|
+----+-----+--------+----------+--------------+
|   A|    1|   [1.0]|       1.0|             1|
|   B|    2|   [0.0]|       0.0|             2|
|   C|    3|   [2.0]|       2.0|             3|
|   A|    1|   [1.0]|       1.0|             1|
|   B|    2|   [0.0]|       0.0|             2|
+----+-----+--------+----------+--------------+