Python: simple linear regression – two possibilities, either NumPy's polyfit or sklearn

Let's start with the most basic version, which uses NumPy's polyfit.

I recommend checking out data36.com for in-depth explanations and other great content! Here I wanted to condense the approach into just five steps that get you to a linear regression.

Step 1 – Get the data & visualise it

# credit
# https://data36.com/linear-regression-in-python-numpy-polyfit/
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

students = {'hours': [29, 9, 10, 38, 16, 26, 50, 10, 30, 33, 43, 2, 39, 15, 44, 29, 41, 15, 24, 50],
            'test_results': [65, 7, 8, 76, 23, 56, 100, 3, 74, 48, 73, 0, 62, 37, 74, 40, 90, 42, 58, 100]}

student_data = pd.DataFrame(data=students)

x = student_data.hours
y = student_data.test_results

plt.scatter(x,y)

Step 2 – Use NumPy's polyfit to get the coefficient and intercept

model = np.polyfit(x, y, 1)

# query the regression coefficient and intercept values for your model
model

# polyfit returns the coefficients from the highest power down:
# 2.01467487 is the regression coefficient (the slope, the a value)
# and -3.9057602 is the intercept (the b value), so y = a*x + b
 array([ 2.01467487, -3.9057602 ]) 
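For a degree-1 fit, polyfit solves an ordinary least-squares problem, so the same two numbers can be computed directly from the closed-form formulas. The sketch below is only a cross-check; the array names `hours` and `results` are illustrative and simply hold the Step 1 data as NumPy arrays.

```python
import numpy as np

# Step 1 data, copied into plain NumPy arrays (names chosen for this sketch)
hours = np.array([29, 9, 10, 38, 16, 26, 50, 10, 30, 33, 43, 2, 39, 15, 44, 29, 41, 15, 24, 50])
results = np.array([65, 7, 8, 76, 23, 56, 100, 3, 74, 48, 73, 0, 62, 37, 74, 40, 90, 42, 58, 100])

# Closed-form least squares for y = a*x + b:
#   a = sum((x - mean_x) * (y - mean_y)) / sum((x - mean_x)**2)
#   b = mean_y - a * mean_x
x_mean, y_mean = hours.mean(), results.mean()
a = np.sum((hours - x_mean) * (results - y_mean)) / np.sum((hours - x_mean) ** 2)
b = y_mean - a * x_mean

print(a, b)
# polyfit with degree 1 should agree up to floating-point rounding
print(np.allclose([a, b], np.polyfit(hours, results, 1)))
```

The second print should report `True`, confirming that polyfit and the textbook formulas describe the same line.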

Step 3 – Test the prediction

# If a student tells you how many hours she studied,
# you can predict the estimated results of her exam.

predict = np.poly1d(model)
hours_studied = 10
predict(hours_studied)
16.240988519071156
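Under the hood, the poly1d object just evaluates a*x + b. The sketch below hard-codes the rounded coefficients printed in Step 2 so it runs on its own, and compares poly1d with the hand calculation and with `np.polyval`, which evaluates the same coefficient array without building an object.

```python
import numpy as np

# Rounded coefficients from Step 2, in polyfit's order: [slope a, intercept b]
model = np.array([2.01467487, -3.9057602])
predict = np.poly1d(model)

hours_studied = 10
by_poly1d = predict(hours_studied)
by_hand = model[0] * hours_studied + model[1]   # a * x + b
by_polyval = np.polyval(model, hours_studied)   # functional alternative to poly1d

print(by_poly1d, by_hand, by_polyval)
```

All three values come out at about 16.24; they differ from the output above only in the last digits because the coefficients here are rounded.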

Step 4 – Get the R² score (coefficient of determination)

# For a least-squares line with an intercept, scored on the
# same data it was fitted to, R-squared lies between 0 and 1.
# The closer it is to 1, the better the line fits the data.

from sklearn.metrics import r2_score
r2_score(y, predict(x))
 0.8777480188408424 
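R² compares the squared errors of the fitted line with the squared errors of always predicting the mean: R² = 1 - SS_res / SS_tot. The sketch below computes that by hand (again with the illustrative array names `hours` and `results` for the Step 1 data). For simple linear regression with an intercept, it also equals the squared Pearson correlation between x and y.

```python
import numpy as np

hours = np.array([29, 9, 10, 38, 16, 26, 50, 10, 30, 33, 43, 2, 39, 15, 44, 29, 41, 15, 24, 50])
results = np.array([65, 7, 8, 76, 23, 56, 100, 3, 74, 48, 73, 0, 62, 37, 74, 40, 90, 42, 58, 100])

a, b = np.polyfit(hours, results, 1)
predicted = a * hours + b

ss_res = np.sum((results - predicted) ** 2)       # residual sum of squares
ss_tot = np.sum((results - results.mean()) ** 2)  # total sum of squares around the mean
r2_manual = 1 - ss_res / ss_tot

# Squared Pearson correlation coefficient of x and y
r = np.corrcoef(hours, results)[0, 1]

print(r2_manual, r ** 2)
```

Both numbers should match the `r2_score` output above (about 0.8777).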

Step 5 – Visualise the linear regression in a scatter plot

# range over which you want to display the linear regression
# model; in our case it's between 0 and 50 hours.
x_lin_reg = range(0, 51)

# calculates the y values for all the x values between 0 and 50
y_lin_reg = predict(x_lin_reg)

plt.scatter(x, y)
plt.plot(x_lin_reg, y_lin_reg, c='r')  # draw the fitted line in red

The following just prints the list of calculated values.

print(y_lin_reg)
 [-3.9057602  -1.89108532  0.12358955  2.13826442  4.15293929  6.16761416   8.18228903 10.1969639  12.21163878 14.22631365 16.24098852 18.25566339  20.27033826 22.28501313 24.299688   26.31436288 28.32903775 30.34371262  32.35838749 34.37306236 36.38773723 38.4024121  40.41708698 42.43176185  44.44643672 46.46111159 48.47578646 50.49046133 52.50513621 54.51981108  56.53448595 58.54916082 60.56383569 62.57851056 64.59318543 66.60786031  68.62253518 70.63721005 72.65188492 74.66655979 76.68123466 78.69590953  80.71058441 82.72525928 84.73993415 86.75460902 88.76928389 90.78395876  92.79863363 94.81330851 96.82798338] 
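The title names sklearn as the second possibility. As a rough sketch (not the data36 walkthrough), the same fit can be done with scikit-learn's `LinearRegression`, assuming scikit-learn is installed; `hours` and `results` are again illustrative names for the Step 1 data.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

hours = np.array([29, 9, 10, 38, 16, 26, 50, 10, 30, 33, 43, 2, 39, 15, 44, 29, 41, 15, 24, 50])
results = np.array([65, 7, 8, 76, 23, 56, 100, 3, 74, 48, 73, 0, 62, 37, 74, 40, 90, 42, 58, 100])

# scikit-learn expects a 2-D feature matrix of shape (n_samples, n_features)
X = hours.reshape(-1, 1)

reg = LinearRegression().fit(X, results)
print(reg.coef_[0], reg.intercept_)        # slope a and intercept b
print(reg.predict([[10]])[0])              # predicted result for 10 hours studied
print(r2_score(results, reg.predict(X)))   # same value as reg.score(X, results)
```

The slope, intercept, prediction and R² should match the polyfit results above; sklearn mainly becomes more convenient once there is more than one input feature.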
