How do machines learn: a simple explanation
My introduction to the world of data science was a university course called Introduction to Machine Learning, which I took during my bachelor's. And it was, by far, my least favorite class of all time. Even though I was generally a high-scoring student, I struggled with it a lot.
At the time, I blamed the professor for not doing a good job of explaining things. In retrospect, though, I think the main problem was my perspective.
Machine Learning was a new way of approaching problems, and I didn't take enough time to internalize the logic behind how it worked. As a result, I had a hard time learning ML algorithms and fundamental concepts.
I recognize similar struggles in newcomers to this field. That's why I'd like to share a simple explanation of how machine learning works, in the hope that it will make your learning smoother than mine was.
By now you have probably heard this explanation: ML algorithms learn like humans. You give them examples, and they recognize and remember the patterns in them. You need to give them a lot of examples, though, so that they can learn accurately.
Okay, that’s clear. But HOW does it learn?
You have probably also seen this image somewhere:
[Figure: input (examples) → model → output (predictions)]
This image illustrates the sentence above on a more formal level. You provide examples (input) to the model. Based on this input, the model learns and is then able to produce predictions (output).
The learning happens inside the model, and the model itself is basically a mathematical equation, like this one:
a*x+b = y
It says that a times x plus b equals y. In the context of machine learning, the input we provide would be x, and the result of the equation, y, would be the output (or the prediction, if you will).
The goal of machine learning is to find the values of a and b that make the equation hold true for all the x and y pairs.
So, basically, the whole learning process consists of getting examples of x and y and, based on those examples, trying to figure out what the a and b values should be by making guesses and improving those guesses.
Let's look at a simple example. Let's say we have a training set of one data point. We know that if x is 5, y should be 12. As a tuple, it would look like this: (5, 12). We need to find the a, b pair that fits these numbers, as in the equation below:
a*5 + b = 12
From this point on, the model starts its search for the a and b values. But it does not use any prior knowledge. No mathematics knowledge, no assumptions, nothing. To start with, random numbers are assigned to the a and b values.
Let’s say, in our example, a is initialized as 8 and b as 3. This would give us:
8*5 + 3 = 43
Our error is 43 - 12 = 31. This tells us that our guess overshot, so our a and b values are too high.
The learning process is basically trial and error mashed up with the hot-and-cold game (where you tell someone they are getting hotter as they get closer to an item you hid from them and colder as they move further away).
The model tries new values each time, guided by how far its guess was from the actual value and in which direction: was the guess higher or lower than the expected value?
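The guided guessing above can be sketched as gradient descent, one common way to turn the size and direction of the error into the next guess. This is a minimal sketch, not the only way models search: it assumes a mean-squared-error measure and a hypothetical learning rate of 0.01, and the function names are made up for illustration. It starts from the same first guess as the walkthrough (a = 8, b = 3) and uses only the single data point (5, 12):

```python
def predict(a, b, x):
    """The model from the article: a*x + b."""
    return a * x + b


def gradient_step(a, b, points, learning_rate):
    """Move a and b against the error, by an amount proportional to its size."""
    grad_a = 0.0
    grad_b = 0.0
    for x, y in points:
        error = predict(a, b, x) - y  # > 0: guess too high, < 0: guess too low
        # Gradient of the mean squared error with respect to a and b
        grad_a += 2 * error * x / len(points)
        grad_b += 2 * error / len(points)
    return a - learning_rate * grad_a, b - learning_rate * grad_b


def fit(points, a=8.0, b=3.0, learning_rate=0.01, steps=1000):
    """Start from the first guess (8, 3) and keep improving it."""
    for _ in range(steps):
        a, b = gradient_step(a, b, points, learning_rate)
    return a, b


a, b = fit([(5, 12)])
print(f"a = {a:.3f}, b = {b:.3f}, prediction for x = 5: {predict(a, b, 5):.3f}")
# a = 2.038, b = 1.808, prediction for x = 5: 12.000
```

Notice that it settles on a ≈ 2.04 and b ≈ 1.81 rather than the 3 and -3 used in the walkthrough. With a single data point, every pair satisfying 5a + b = 12 has zero error, and which one the search lands on depends on where it starts.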
Our model might try 2 and 1 next. This would give us:
2*5 + 1 = 11
This is a much better guess. Our error is only 11 - 12 = -1. This error means that now we underestimated the y value. After going on like this for a while, let’s say our model tries 3 and -3:
3*5 + (-3) = 12
Done! We were able to figure out the a and b values! But not so fast. It turns out that there was another pair of x, y values that just came to our attention: (3, 4). Plugging these values into our newly built model, we see a disappointing result.
3*3 + (-3) = 6
Unfortunately, it doesn’t work. We have an error of 6 - 4 = 2. After some more guessing, our model will likely come up with the actual values of a and b, which are 4 and -8.
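To see why (3, -3) is rejected once the second point arrives, it helps to score a guess against every known example at once. A minimal sketch, assuming a sum-of-squared-errors score (one common choice; the article does not name a specific one) and a hypothetical `total_error` helper:

```python
def total_error(a, b, points):
    """Sum of squared errors of the model a*x + b over all (x, y) points."""
    return sum((a * x + b - y) ** 2 for x, y in points)


training_set = [(5, 12), (3, 4)]

print(total_error(3, -3, training_set))  # 0**2 + 2**2 = 4
print(total_error(4, -8, training_set))  # 0**2 + 0**2 = 0
```

Squaring makes overshooting and undershooting both count against a guess, so the only pair that scores zero here is the one that fits both points.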
So this is how an ML model learns in principle. Here is how some fundamental concepts relate to what we just learned:
Machine Learning models need enough data to learn accurately
What is enough data? It is the number of data points that will help us find the most accurate explanation (line) for the problem (equation) at hand. And yes, you guessed it: this number depends on the complexity of the problem and the model.
Sometimes multiple explanations (or lines) might fit the data, but we need enough data to figure out the most accurate one. In the above example, our problem can be represented as a straight line, and the model only needs 2 points to find it accurately. One point is not enough because many different lines can pass through the same point.
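The two-point argument can be checked directly. This sketch (the helper names are hypothetical) first shows that several (a, b) pairs pass through the single point (5, 12), then computes the one line through both (5, 12) and (3, 4) using the slope formula:

```python
def passes_through(a, b, point):
    """True if the line a*x + b goes exactly through the (x, y) point."""
    x, y = point
    return a * x + b == y


# With only one point, many lines fit: every pair with 5a + b = 12 works.
candidates = [(3, -3), (4, -8), (2, 2)]
print([passes_through(a, b, (5, 12)) for a, b in candidates])  # [True, True, True]


def line_through(p1, p2):
    """Return (a, b) for the straight line through two points with different x."""
    (x1, y1), (x2, y2) = p1, p2
    a = (y2 - y1) / (x2 - x1)  # slope
    b = y1 - a * x1            # intercept
    return a, b


print(line_through((5, 12), (3, 4)))  # (4.0, -8.0)
```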
The actual problems ML deals with are more complex. The relationship between input and output is often not linear, so the straight line becomes a curve. Additionally, there is usually more than one input variable, which makes the problem multi-dimensional and effectively impossible to plot beyond three dimensions. All of this means we need many more data points.
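To make the "more inputs, more data" point concrete: with two input variables, the equation becomes a1*x1 + a2*x2 + b = y, which has three unknowns, so in the simple noise-free case it takes three well-chosen points to pin it down. The sketch below assumes a made-up relationship y = 2*x1 + 3*x2 + 1 and recovers it with Cramer's rule. It is only an illustration of counting unknowns, not how ML libraries fit models:

```python
from fractions import Fraction


def det3(m):
    """Determinant of a 3x3 matrix given as a list of rows."""
    return (m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
            - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
            + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]))


def solve_two_input_line(points):
    """Find (a1, a2, b) with a1*x1 + a2*x2 + b = y from three ((x1, x2), y) points."""
    rows = [[x1, x2, 1] for (x1, x2), _ in points]
    ys = [y for _, y in points]
    d = det3(rows)
    if d == 0:
        raise ValueError("these points do not determine a unique solution")
    solution = []
    for col in range(3):
        # Cramer's rule: replace one column with the y values
        replaced = [row[:col] + [y] + row[col + 1:] for row, y in zip(rows, ys)]
        solution.append(Fraction(det3(replaced)) / Fraction(d))
    return tuple(solution)


# Hypothetical data generated from y = 2*x1 + 3*x2 + 1
points = [((1, 0), 3), ((0, 1), 4), ((1, 1), 6)]
print(solve_two_input_line(points))  # (Fraction(2, 1), Fraction(3, 1), Fraction(1, 1))
```

With only two of these points there would be more unknowns than equations, and, just as with one point and a straight line, many different solutions would fit.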
Not just any data, though: relevant data
We need a lot of data for an accurate model, but the data points we train our model on also need to be comprehensive. They need to represent a significant portion of the real world. Otherwise, even a (seemingly) perfect model will fail in the real world. Just as in our example, where we didn't know that the correct a, b pair was 4 and -8 until we saw the second data point, our model needs access to diverse enough data points to learn an accurate enough estimation of the world.
Neglecting data quality and diversity can also cause bias and unfairness in practice. One example would be a classification algorithm that favors men when classifying who can be the president of a country, because 95% of the data points in its training set are male presidents and only 5% are female presidents.
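A minimal sketch of how a skewed training set turns into a skewed model. It assumes a deliberately naive, hypothetical "model" that scores a candidate only by how often their group appears among past presidents; the 95/5 split mirrors the example above and the data is made up:

```python
from collections import Counter

# Hypothetical training set: the sex of 100 past presidents (95% male, 5% female)
past_presidents = ["male"] * 95 + ["female"] * 5
counts = Counter(past_presidents)


def president_score(sex):
    """Naive 'can be president' score learned only from past frequency."""
    return counts[sex] / len(past_presidents)


print(president_score("male"), president_score("female"))  # 0.95 0.05
```

Nothing about the candidates themselves enters the score: the model has simply memorized an imbalance in its examples.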
Inputs might have many shapes
The input can be numerical, as in our example here, but it can also be categorical. Or it can be text, or even an image. A good thing to know is that no matter what the input looks like, we translate it into a numerical form before we feed it to the model. For example, a 100x100-pixel grayscale image turns into a list of 100x100 = 10,000 elements, each representing one pixel's brightness with a number. (A color image typically needs three numbers per pixel: red, green and blue.)
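A minimal sketch of turning non-numeric inputs into numbers. It assumes a grayscale image stored as a list of rows of brightness values, and uses one-hot encoding for categories (one common choice among several); the data and helper names are hypothetical:

```python
def flatten_image(image):
    """Turn a 2D grid of pixel values (rows x columns) into one flat list."""
    return [pixel for row in image for pixel in row]


def one_hot(value, categories):
    """Encode a categorical value as a list of 0s with a single 1."""
    return [1 if value == category else 0 for category in categories]


# Hypothetical 100x100 grayscale image where every pixel has brightness 128
image = [[128] * 100 for _ in range(100)]
features = flatten_image(image)
print(len(features))  # 10000

print(one_hot("unemployed", ["employed", "unemployed", "student"]))  # [0, 1, 0]
```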
Machine Learning is an estimation
In data science, we work with problems that do not have a perfect mathematical explanation.
The example we worked with in this article can be solved with a couple of data points because it follows an exact mathematical formula. But many problems don't, especially problems where humans are involved. One example is calculating a credit risk score. If you want to calculate the likelihood that someone will pay their debt back without any problems, you are dealing with the human factor. And no matter how many variables you take into account (employed/unemployed, sex, age, residence, and so on), there can always be someone who falls outside the pattern.
There is randomness to these problems. That’s why, no matter how many data points we give as examples, what we are doing at the end of the day is estimation. We are trying to come up with a mathematical explanation that closely estimates what happens in the real world.
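The estimation point can be simulated. This sketch assumes a made-up world where y = 4x - 8 plus random noise (standing in for the unpredictable human factor) and fits a line with ordinary least squares, a standard closed-form method for linear regression. The fitted a and b land close to 4 and -8, but the error never reaches zero:

```python
import random


def least_squares(points):
    """Closed-form ordinary least squares fit of y = a*x + b."""
    n = len(points)
    mean_x = sum(x for x, _ in points) / n
    mean_y = sum(y for _, y in points) / n
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in points)
    sxx = sum((x - mean_x) ** 2 for x, _ in points)
    a = sxy / sxx
    b = mean_y - a * mean_x
    return a, b


random.seed(0)  # fixed seed so the run is repeatable
# Hypothetical noisy data: the "true" rule plus Gaussian noise (standard deviation 2)
points = [(x, 4 * x - 8 + random.gauss(0, 2)) for x in range(100)]

a, b = least_squares(points)
mse = sum((a * x + b - y) ** 2 for x, y in points) / len(points)
print(f"a = {a:.2f}, b = {b:.2f}, mean squared error = {mse:.2f}")
```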
The equation looks different in every algorithm
This mathematical equation that needs to be solved is different for each algorithm. For linear regression, it actually does look like the example we worked on. But in decision trees, for example, the values we are trying to find are the thresholds at which the tree decides to split the data. Let's not get into the details here. Just know that the working logic is the same, but what the constants are differs for each ML algorithm.
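For a feel of what "the constants" are in a decision tree, here is a minimal sketch of a single split (a decision stump) on hypothetical one-dimensional data. It tries each midpoint between neighboring x values as the threshold, predicts the average y on each side, and keeps the threshold with the lowest squared error. Real tree implementations add many refinements; this shows only the core idea:

```python
def best_split(xs, ys):
    """Return (threshold, error) for the single split that best separates ys."""
    pairs = sorted(zip(xs, ys))
    best_threshold, best_error = None, float("inf")
    for i in range(1, len(pairs)):
        if pairs[i - 1][0] == pairs[i][0]:
            continue  # no threshold fits between equal x values
        threshold = (pairs[i - 1][0] + pairs[i][0]) / 2
        left = [y for x, y in pairs if x <= threshold]
        right = [y for x, y in pairs if x > threshold]
        # Each side predicts its own average y; sum the squared errors
        error = sum((y - sum(left) / len(left)) ** 2 for y in left)
        error += sum((y - sum(right) / len(right)) ** 2 for y in right)
        if error < best_error:
            best_threshold, best_error = threshold, error
    return best_threshold, best_error


# Hypothetical data: small x values give 0, large x values give 1
xs = [1, 2, 3, 10, 11, 12]
ys = [0, 0, 0, 1, 1, 1]
print(best_split(xs, ys))  # (6.5, 0.0)
```

Here the "learned constant" is the threshold 6.5 rather than an a and b, but the search is driven by the same idea: try candidates and keep the one with the smallest error.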
And that's all. How a machine learns is, on a high level, actually this simple. Understanding this helped me greatly when learning about new machine learning algorithms. The trick is understanding what is being optimized and what value I’m trying to estimate.
Talking about all this and understanding the logic behind machine learning is one thing, but practice is what will get you a job. If you'd like to learn all of this and more in practice, go to Hands-on Data Science: Complete your first portfolio project.