A Beginner’s Guide to Simple Linear Regression
Imagine this scenario: your team has tasked you with identifying key insights about company sales. You have been creating plots to try to visualize the relationship between individual variables and sales. Wouldn’t it be nice to be able to quantify that relationship for your team?
Enter Linear Regression.
Linear regression is a statistical model that uses an independent variable (X) to predict the value of a dependent variable (Y). Simple linear regression uses a single X variable to explain the Y variable.
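In its simplest form, the model fits a straight line, Y = β₀ + β₁X, where β₀ is the y-intercept and β₁ is the slope. Fitting the model means estimating those two numbers from the data, and we will read them off the fitted model in step 6.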
Important Notes (Before We Start):
- Both the X and the Y columns have to be continuous numeric variables.
- Ensure there are no null values in either column, or the model will not work (we will verify this after loading the data below)
To illustrate the concepts, we will be using a sample Lemonade Sales dataset from Kaggle (credit to irfan5826801): https://www.kaggle.com/irfansth/lemonade
1. To begin, import the following libraries
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.linear_model import LinearRegression
from sklearn import metrics
2. Read in the dataset as a Pandas DataFrame
# Read CSV from specified location
lemonade = pd.read_csv('./Lemonade.csv')

# Returns top 5 rows of dataset
lemonade.head()
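Per the note above, the model will not accept null values. A quick way to confirm both columns are clean is to count the missing values per column:
# Count missing values in each column; every count should be 0
lemonade.isnull().sum()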
3. Create a Visual to see if Linear Regression will be a good model
# Sets size of scatterplot to be 8 x 5 (Width x Height)
plt.figure(figsize = (8, 5))

# Creates Scatterplot and Regression Line using Seaborn
sns.regplot(x = 'Temperature', y = 'Sales', data = lemonade, ci = None,
            scatter_kws = {'s': 3},
            line_kws = {'color': 'orange'})

# Creates Line of Mean Sales
plt.axhline(lemonade['Sales'].mean(), color = 'grey')

# Creates Labels and Title for Regression Plot
plt.xlabel('Temperature')
plt.ylabel('Sales')
plt.title('Mean in grey, OLS regression in orange');
This is the result:
The regression line seems to capture the relationship between temperature and sales pretty well! Much better than if we only relied on looking at average sales!
Let’s continue with fitting the model:
4. Assemble the X and Y variables
# Double brackets keep X two-dimensional (a DataFrame), which sklearn expects
X = lemonade[['Temperature']]

# Y can be one-dimensional (a Series)
Y = lemonade['Sales']
5. Instantiate the Linear Regression Model from the sklearn package and fit against X & Y
# Instantiate the model, then fit it to the X and Y data
lr = LinearRegression()
lr.fit(X, Y)
6. Now that the model is fit — let’s take a look at the model coefficient (slope) and y-intercept
print(lr.coef_, lr.intercept_)
Interpretation of Slope — For a 1 degree increase in temperature, there is a 0.42 unit increase in sales.
Interpretation of Y-Intercept — If the temperature were 0, we would expect -0.26 units of sales (which does not seem very realistic)
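As a quick sanity check on those two numbers, you can plug a temperature into the fitted line by hand and compare against lr.predict. The 30 degrees below is just an illustrative value; the exact output depends on your fitted coefficients.
# Manual check of the fitted line: Y = intercept + slope * X
# e.g. at 30 degrees: -0.26 + 0.42 * 30 ≈ 12.3 units of sales
example_temp = pd.DataFrame({'Temperature': [30]})
print(lr.predict(example_temp))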
7. Find the R-Squared for the Regression Model
The R-Squared, or coefficient of determination, is the proportion of the variability in Y that can be explained by X.
lr.score(X, Y)
With this measure, you can confidently say that ~97.98% of the variability in Sales can be explained by Temperature!
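The same number can be computed from the predictions with sklearn's r2_score, which also puts the metrics import from step 1 to use:
# R-Squared computed from actuals vs. predictions; matches lr.score(X, Y)
Y_predictions = lr.predict(X)
print(metrics.r2_score(Y, Y_predictions))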
Before you call it a day, you should also validate that your model satisfies, to a large extent, the 4 assumptions of Simple Linear Regression:
1. Linear Relationship — The target variable Y should have a roughly linear relationship with the explanatory variable X
(Easy to Check — If the scatterplot of the underlying data points looks like it should be explained by a curve instead of a straight line, then this assumption is violated)
2. Independence — Each observation (data point) should be independent of and uncorrelated with the others
(Usually satisfied, unless you are working explicitly with time series data and time is plotted as the X variable)
3. Normality — The residuals (Y-Actual minus Y-Prediction) should follow a normal distribution
(Create a histogram to see distribution of the residuals)
# calculating prediction values of Y using linear regression model
Y_predictions = lr.predict(X)

# calculating residuals (Y-Actual minus Y-Prediction)
residuals = Y - Y_predictions

# creating a histogram with 15 bins
plt.hist(residuals, bins = 15);
For our lemonade model — the distribution of the residuals does not seem to be normal, so this assumption might be violated.
Ideally you would like to see a bell-shaped distribution
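If the histogram is ambiguous, a Q-Q plot is another common way to eyeball normality. Note this sketch uses scipy, which is not among the original imports:
from scipy import stats

# Q-Q plot: points hugging the diagonal reference line suggest roughly normal residuals
stats.probplot(residuals, dist = 'norm', plot = plt);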
4. Equal Variance — The residuals (Y-Actual minus Y-Prediction) should have roughly constant spread at every level of X. This is also known as Homoscedasticity.
(To check for Homoscedasticity, create a residual plot to see how the residuals are scattered around a horizontal line at 0, the ideal state where the error is 0)
# calculating prediction values of Y using linear regression model
Y_predictions = lr.predict(X)

# calculating residuals (Y-Actual minus Y-Prediction)
residuals = Y - Y_predictions

# plotting predictions on X-Axis and residuals on Y-Axis
plt.scatter(Y_predictions, residuals)
plt.axhline(0, color='orange');
For our Lemonade model — the spread of the residuals does not seem to be equal as we move across the X-Axis. The equal variance assumption might be violated.
Ideally you would like to see a somewhat equal and parallel band around the 0 line
Final Thoughts — So you have a model with a high coefficient of determination, but the model also violates multiple assumptions. What does this mean? Should you worry?
The answer is… it depends.
If you are planning to use the model for prediction (predict sales based on temperature), then it is likely ok to continue to use the model.
If you are planning to use the model for inference (interpret sales as a function of temperature), then the residual checks suggest that there are additional variables that could be useful in explaining the variation in sales. (Hint: Multiple Linear Regression — a future topic I plan to explore!)