Every month I learn a new machine learning algorithm. So far I’ve learned about ten, and whenever I try to solve a machine learning problem, the question is always “Which algorithm should I use?”

Almost every machine learning practitioner knows that the answer depends first on whether the problem is supervised or unsupervised, and then on whether it is classification or regression. So choosing an algorithm is quite straightforward, right?

Well, no. Take classification for example. We can use Logistic Regression, Naive Bayes, Support Vector Machine, Decision Tree, Random Forest, Gradient Boosting or Neural Network. Which one should we use?

“Well, it depends” is the answer we often hear. “Depends on what?” That is the question! It would be helpful to know which factors to consider, right?

So in this article I would like to try to answer those questions. I’m going to first address the general question, “Which machine learning algorithm should I use?” This is useful when you are new to machine learning and have never heard of classification and regression, let alone ensembles and boosting. There are many good articles already written about this, so I’m going to point you to them.

Then, as an example, I’m going to dive specifically into classification algorithms. I’ll try to give a brief outline of the factors we need to consider when deciding, such as linearity, interpretability, multiclass and accuracy. I’ll also cover the strengths and weaknesses of each algorithm.

**General guide on which ML algorithms to use**

I would recommend that you start with Hui Li’s diagram: link. She categorised ML algorithms into 4: clustering, regression, classification and dimensionality reduction:

It is very easy to follow, and it is detailed enough. She wrote it in 2017 but by and large it is still relevant today.

The second one that I’d recommend is Microsoft’s guide: link, which is newer (2019) and more comprehensive. They categorise ML algorithms into 8: clustering, regression, classification (2 class and multiclass), text analytics, image classification, recommenders, and anomaly detection:

So now you know roughly which algorithm to use for each case, using the combination of Hui Li’s and Microsoft’s diagrams. In addition to that, it would be helpful to read Danny Varghese’s comparative study of machine learning algorithms: link. For every algorithm, Danny outlines the advantages and disadvantages against other algorithms in the same category. So once you have chosen an algorithm based on Hui Li’s and Microsoft’s diagrams, check that algorithm against the alternatives on Danny’s list and make sure that the advantages outweigh the disadvantages.

**Classification algorithms: which one should I use?**

For classification we can use Logistic Regression, Naive Bayes, Support Vector Machine, Decision Tree, Random Forest, Gradient Boosting Machine (GBM), Perceptron, Linear Discriminant Analysis (LDA), K Nearest Neighbours (KNN), Learning Vector Quantisation (LVQ) or Neural Network. What factors do we need to consider when deciding? And what are the strengths and weaknesses of each algorithm?

The factors we need to consider are: linearity, interpretability and multiclass.

The first consideration is linearity of the data. The data is linear (linearly separable) if, when we plot the data points by their predictors, the classes can be separated by a straight line, like below.

Note that the plots above are over-simplified, as in reality we have not just 2 dimensions but many (e.g. we may have 8 predictors, or 8 X axes), so the separator is not a line but a hyperplane.

1. If the data is linear, we can use (link): Logistic Regression, Naive Bayes, Support Vector Machine, Perceptron, Linear Discriminant Analysis.
2. If the data is not linear, we can use (link): Decision Tree, Random Forest, Gradient Boosting Machine, K Nearest Neighbours, Neural Network, Support Vector Machine using Kernel, Learning Vector Quantisation.

Can we use algorithms in #2 for linear classification? Yes we can, but #1 is more suitable.

Can we use #1 for non-linear classification? No we can’t, not without modification. But there are ways to transform the data so that it becomes linearly separable, typically by mapping it into a higher-dimensional space. The best known is the “kernel trick”, see my article here: link.
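To make the idea concrete, here is a minimal sketch in plain Python. It is not the kernel trick itself (which does the mapping implicitly), but its simpler cousin: an explicit feature map. The XOR-style data and the choice of the extra feature `x1 * x2` are my own assumptions for illustration. A perceptron (a linear classifier) cannot separate the raw points, but it can once the extra feature is added.

```python
def train_perceptron(points, labels, max_epochs=1000):
    """Classic perceptron; labels are +1 / -1. Returns (weights, bias)."""
    weights = [0.0] * len(points[0])
    bias = 0.0
    for _ in range(max_epochs):
        mistakes = 0
        for x, y in zip(points, labels):
            score = sum(w * xi for w, xi in zip(weights, x)) + bias
            if y * score <= 0:  # misclassified (or exactly on the boundary)
                weights = [w + y * xi for w, xi in zip(weights, x)]
                bias += y
                mistakes += 1
        if mistakes == 0:  # a full pass with no errors: converged
            break
    return weights, bias


def accuracy(points, labels, weights, bias):
    """Fraction of points whose predicted sign matches the label."""
    correct = 0
    for x, y in zip(points, labels):
        score = sum(w * xi for w, xi in zip(weights, x)) + bias
        predicted = 1 if score > 0 else -1
        correct += predicted == y
    return correct / len(points)


# XOR-style data: no straight line separates the two classes.
raw_points = [(0, 0), (0, 1), (1, 0), (1, 1)]
labels = [-1, 1, 1, -1]

# Illustrative feature map (an assumption): append x1 * x2 as a third feature.
mapped_points = [(x1, x2, x1 * x2) for x1, x2 in raw_points]

raw_acc = accuracy(raw_points, labels, *train_perceptron(raw_points, labels))
mapped_acc = accuracy(mapped_points, labels, *train_perceptron(mapped_points, labels))
print("raw:", raw_acc, "mapped:", mapped_acc)
```

In the mapped 3-dimensional space a separating plane exists, so the same linear algorithm now classifies every point correctly. A kernel achieves a similar effect without ever building the extra features explicitly.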

The second factor that we need to consider is interpretability, i.e. the ability to explain why a data point is classified into a certain class. Christoph Molnar explains interpretability in great detail: link.

- If we need to be able to explain, we can use Logistic Regression, Naive Bayes, Decision Tree, Linear Support Vector Machine.
- If we don’t need to be able to explain, we can use Random Forest, Support Vector Machine with Kernel (see Hugo Dolan’s article: link), Gradient Boosting Machine, K Nearest Neighbours, Neural Network, Perceptron, Linear Discriminant Analysis, Learning Vector Quantisation.

The third factor that we need to consider is whether we are classifying into two classes (binary classification) or more than two classes (multi-class). Support Vector Machine (SVM) and Perceptron are binary classifiers in their basic form, and Linear Discriminant Analysis (LDA) is often introduced that way too, but everything else can be used for both binary and multi-class. We can make LDA multi-class, see here: link. Ditto SVM: link.
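One common way to stretch a binary classifier to multi-class is “one-vs-rest”: train one binary model per class (that class against everything else) and pick the class whose model scores highest. The sketch below uses a plain-Python perceptron as the binary learner. The function names and the toy data are made up for illustration, and this is not necessarily the approach taken in the linked articles.

```python
def train_binary_perceptron(points, targets, max_epochs=5000):
    """Perceptron for targets in {+1, -1}. Returns weights with the bias last."""
    weights = [0.0] * (len(points[0]) + 1)
    for _ in range(max_epochs):
        mistakes = 0
        for x, y in zip(points, targets):
            features = list(x) + [1.0]  # the constant 1.0 acts as the bias input
            score = sum(w * f for w, f in zip(weights, features))
            if y * score <= 0:
                weights = [w + y * f for w, f in zip(weights, features)]
                mistakes += 1
        if mistakes == 0:
            break
    return weights


def train_one_vs_rest(points, labels):
    """Train one binary classifier per class: that class (+1) vs the rest (-1)."""
    models = {}
    for cls in sorted(set(labels)):
        targets = [1 if label == cls else -1 for label in labels]
        models[cls] = train_binary_perceptron(points, targets)
    return models


def predict_one_vs_rest(models, x):
    """Pick the class whose binary classifier gives the highest score."""
    features = list(x) + [1.0]
    scores = {cls: sum(w * f for w, f in zip(weights, features))
              for cls, weights in models.items()}
    return max(scores, key=scores.get)


# Toy data (made up for illustration): three well-separated clusters.
points = [(0, 0), (1, 0), (0, 1), (1, 1),
          (9, 0), (10, 0), (9, 1), (10, 1),
          (0, 9), (1, 9), (0, 10), (1, 10)]
labels = ["a"] * 4 + ["b"] * 4 + ["c"] * 4

models = train_one_vs_rest(points, labels)
predictions = [predict_one_vs_rest(models, p) for p in points]
print(predictions == labels)
```

The price is training as many binary models as there are classes, which is worth keeping in mind when the binary algorithm is already slow to train.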

**1. Logistic Regression**

**Strengths:** good accuracy on small amounts of data, easy to interpret (the coefficients indicate feature importance), easy to implement, efficient to train (doesn’t need high compute power), can do multi-class.

**Weaknesses:** tends to overfit in high dimensions (use regularisation), can’t do non-linear classification (or capture complex relationships), not good with multicollinearity, sensitive to outliers, requires a linear relationship between the predictors and the log odds of the target.
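The log odds requirement is also what makes Logistic Regression easy to interpret. Here is a small sketch with a hypothetical one-predictor model (the intercept and coefficient are invented numbers, not fitted to any data): because the log odds are linear in the predictor, a one-unit increase always adds the coefficient to the log odds, which multiplies the odds by `exp(coef)`.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def log_odds(p):
    return math.log(p / (1.0 - p))


# Hypothetical model with one predictor: log-odds = intercept + coef * x
intercept, coef = -3.0, 0.8


def predict_proba(x):
    return sigmoid(intercept + coef * x)


p_at_2 = predict_proba(2.0)
p_at_3 = predict_proba(3.0)

# A one-unit step in x adds `coef` to the log-odds...
log_odds_step = log_odds(p_at_3) - log_odds(p_at_2)

# ...which multiplies the odds by exp(coef), wherever on the x axis we start.
odds_ratio = (p_at_3 / (1 - p_at_3)) / (p_at_2 / (1 - p_at_2))

print(log_odds_step, odds_ratio, math.exp(coef))
```

If the true relationship between a predictor and the log odds is curved, this single number per feature can no longer describe it, which is where the linearity weakness bites.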

**2. Naive Bayes**

**Strengths:** good accuracy on small amounts of data, efficient to train (doesn’t need high compute power), easy to implement, highly scalable, can do multi-class, can handle both continuous and discrete data, not sensitive to irrelevant features.

**Weaknesses:** assumes the features are independent of each other (given the class); a category which exists in the test dataset but not in the training dataset will get zero probability (the zero frequency problem).
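The usual remedy for the zero frequency problem is Laplace (add-one) smoothing. Below is a minimal sketch for a single categorical feature within a single class. The colour data and the helper name `category_likelihoods` are made up for illustration.

```python
from collections import Counter


def category_likelihoods(observed, all_categories, alpha=0.0):
    """Estimate P(category | class) from the values seen for one class.

    alpha=0 gives raw frequencies; alpha=1 is Laplace (add-one) smoothing.
    """
    counts = Counter(observed)
    total = len(observed) + alpha * len(all_categories)
    return {c: (counts[c] + alpha) / total for c in all_categories}


# Made-up training values of a "colour" feature for a single class.
train_colours = ["red", "red", "green", "red", "green"]
# "blue" only turns up at prediction time.
all_colours = ["red", "green", "blue"]

unsmoothed = category_likelihoods(train_colours, all_colours, alpha=0.0)
smoothed = category_likelihoods(train_colours, all_colours, alpha=1.0)

print(unsmoothed["blue"], smoothed["blue"])
```

Without smoothing, the zero for “blue” would wipe out the whole product of probabilities for that class, no matter what the other features say. With smoothing, the unseen category gets a small but non-zero share.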

**3. Decision Tree**

**Strengths:** easy to interpret (intuitive, shows interactions between variables), can classify non-linear data, data doesn’t need to be normalised or scaled, not affected by missing values, not affected by outliers, performs well with unbalanced data (the nature of the data distribution does not matter), can do both classification and regression, can handle both numerical and categorical data, provides feature importance (calculated from the decrease in node impurity), good with large datasets, able to handle multicollinearity.

**Weaknesses:** has a tendency to overfit (biased towards the training set, requires pruning), not robust (high variance: a small change in the training data can result in a major change in the model and its output), not good with continuous variables, requires a longer time to train the model (resource intensive).
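To show what “decrease in node impurity” means, here is a small sketch that uses Gini impurity (one common impurity measure, assumed here for illustration) to compare two candidate splits of the same node. The labels and the two hypothetical features are invented.

```python
from collections import Counter


def gini(labels):
    """Gini impurity: 0 for a pure node, higher for a more mixed one."""
    n = len(labels)
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def impurity_decrease(parent, left, right):
    """Parent impurity minus the size-weighted impurity of the two children."""
    n = len(parent)
    weighted_children = (len(left) / n) * gini(left) + (len(right) / n) * gini(right)
    return gini(parent) - weighted_children


parent = ["yes", "yes", "yes", "no", "no", "no"]

# Hypothetical feature A splits the node cleanly; feature B barely helps.
split_a = (["yes", "yes", "yes"], ["no", "no", "no"])
split_b = (["yes", "yes", "no"], ["yes", "no", "no"])

gain_a = impurity_decrease(parent, *split_a)
gain_b = impurity_decrease(parent, *split_b)
print(gain_a, gain_b)
```

The tree picks the split with the larger decrease, and adding up these decreases per feature over the whole tree is one way a feature importance score can be obtained.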

**4. Random Forest**

**Strengths:** high accuracy, doesn’t need pruning, much less prone to overfitting than a single tree, low bias with quite low/moderate variance (because of bootstrapping and averaging over many trees), can do both classification and regression, can handle both numerical and categorical data, can classify non-linear data, data doesn’t need to be normalised or scaled, not affected by missing values, not affected by outliers, performs well with unbalanced data (the nature of the data distribution does not matter), can be parallelised (can use multiple CPUs in parallel), good with high dimensionality.

**Weaknesses:** long training time, requires large memory, not interpretable (because there are hundreds of trees).

**5. Support Vector Machine (Linear Vanilla)**

**Strengths:** scales well with high-dimensional data, stable (low variance), less risk of overfitting, doesn’t rely on the entire dataset (the boundary is determined by the support vectors only), works well with noise.

**Weaknesses:** long training time for large data, requires feature scaling.

**6. Support Vector Machine (with Kernel)**

**Strengths:** scales well with high-dimensional data, stable (low variance), handles non-linear data very well, less risk of overfitting (because of regularisation), good with outliers (gamma and C control this), can detect outliers in anomaly detection, works well with noise.

**Weaknesses:** long training time for large data, tricky to find an appropriate kernel, needs large memory, requires feature scaling, difficult to interpret.
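Both SVM variants list feature scaling as a requirement. The reason is that they rest on distances, just like K Nearest Neighbours: a feature measured in large units drowns out the others. The sketch below uses a plain Euclidean nearest-point lookup rather than an actual SVM, and the ages, incomes and labels are made up, but it shows how the “closest” training point changes once each feature is standardised.

```python
import math
from statistics import mean, pstdev


def euclidean(a, b):
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))


def nearest(train, query):
    """Return the label of the training point closest to the query."""
    return min(train, key=lambda item: euclidean(item[0], query))[1]


# Made-up data: (age in years, income in currency units) -> label
train = [((25, 50000), "A"), ((60, 50500), "B")]
query = (27, 50400)

# Standardise each feature using the training mean and standard deviation.
columns = list(zip(*(features for features, _ in train)))
means = [mean(col) for col in columns]
stds = [pstdev(col) for col in columns]


def standardise(features):
    return tuple((v - m) / s for v, m, s in zip(features, means, stds))


scaled_train = [(standardise(features), label) for features, label in train]

raw_neighbour = nearest(train, query)
scaled_neighbour = nearest(scaled_train, standardise(query))
print(raw_neighbour, scaled_neighbour)
```

On the raw numbers the income column decides everything, so the query lands next to “B” despite the 33-year age gap. After standardising, both features carry comparable weight and the query lands next to “A”.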

**7. Gradient Boosting**

**Strengths:** high accuracy, flexible with various loss functions, minimal pre-processing, not affected by missing values, works well with unbalanced data, can do both classification and regression.

**Weaknesses:** tendency to overfit (because it continues to minimise errors), sensitive to outliers, large memory requirement (thousands of trees), long training time, large grid search for hyperparameters, not good with noise, difficult to interpret.
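To see why it “continues to minimise errors”, here is a stripped-down sketch of boosting in plain Python. To keep it short I am assuming a regression setting with squared error and one-split trees (stumps) on a single feature, rather than a real classification GBM. The data, the learning rate and the function names are all made up for illustration.

```python
def fit_stump(xs, residuals):
    """Best single-threshold split on a 1-D feature under squared error."""
    best = None
    for threshold in sorted(set(xs))[:-1]:  # skip the max so both sides are non-empty
        left = [r for x, r in zip(xs, residuals) if x <= threshold]
        right = [r for x, r in zip(xs, residuals) if x > threshold]
        left_mean = sum(left) / len(left)
        right_mean = sum(right) / len(right)
        error = (sum((r - left_mean) ** 2 for r in left)
                 + sum((r - right_mean) ** 2 for r in right))
        if best is None or error < best[0]:
            best = (error, threshold, left_mean, right_mean)
    _, threshold, left_mean, right_mean = best
    return lambda x: left_mean if x <= threshold else right_mean


def boost(xs, ys, n_rounds, learning_rate=0.5):
    """Return the training MSE after each boosting round."""
    predictions = [sum(ys) / len(ys)] * len(ys)  # start from the mean
    history = []
    for _ in range(n_rounds):
        residuals = [y - p for y, p in zip(ys, predictions)]
        stump = fit_stump(xs, residuals)  # each new tree fits what is left over
        predictions = [p + learning_rate * stump(x) for p, x in zip(predictions, xs)]
        history.append(sum((y - p) ** 2 for y, p in zip(ys, predictions)) / len(ys))
    return history


# Made-up noisy 1-D data; the last point is an outlier.
xs = [1, 2, 3, 4, 5, 6, 7, 8]
ys = [1.0, 1.2, 0.9, 3.1, 2.9, 3.2, 3.0, 9.0]

mse_history = boost(xs, ys, n_rounds=50)
print(mse_history[0], mse_history[-1])
```

The training error keeps falling round after round, because every new tree chases whatever the previous ones got wrong, noise and outliers included. That is the mechanism behind the overfitting, outlier and noise weaknesses listed above, and why the number of rounds and the learning rate need tuning.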

**8. K Nearest Neighbours**

**Strengths:** simple to understand (intuitive), simple to implement (both binary and multi-class), handles non-linear data well, non-parametric (no requirements on data distribution), responds quickly to data changes in real-time implementations, can do both classification and regression.

**Weaknesses:** slow prediction on large datasets (it is a lazy learner, so the cost is paid at query time rather than at training time), doesn’t work well with high-dimensional data, requires scaling, doesn’t work well with imbalanced data, sensitive to outliers and noise, affected by missing values.

**9. Neural Network**

**Strengths:** high accuracy, handles non-linear data well, generalises well to unseen data (low variance), non-parametric (no requirements on data distribution), works with heteroskedastic data (non-constant variance), works with highly volatile data (time series), works with incomplete data (not affected by missing values), fault tolerant.

**Weaknesses:** requires a large amount of data, computationally expensive (requires parallel processors/GPU and large memory), not interpretable, tricky to get the architecture right (#layers, #neurons, functions, etc.).