Matplotlib not plotting full line

As part of my research project, I am performing linear regression on some data with statsmodels and plotting the result with matplotlib. Unfortunately, I cannot get my regression line to reach the origin; matplotlib seems to cut it off at the minimum value of my dataset. How can I fix this and make the line reach the origin? For reference, here is my code:

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from statsmodels import api as sm


def file_analysis(csv_file, state):
    """
    This method takes in a file object and the name of a state.

    :param csv_file: Pass in a csv file object.
    :param state: Name of the state as a string.
    :return: None.
    """
    data = pd.read_csv(csv_file)
    data = data[["Total Cases", "Total Deaths"]]

    y = data["Total Deaths"]
    x = data["Total Cases"]

    results = sm.OLS(y, x).fit()

    plt.scatter(x, y)
    yhat = results.params.iloc[0] * x  # params is indexed by column name, so use .iloc for position
    print(results.params)

    plt.ylim(ymin=0)
    plt.xlim(xmin=0)
    plt.margins(0)

    fig = plt.plot(x, yhat, lw=4, c="orange", label="regressionline")

    plt.xlabel("Total Cases", fontsize=20)
    plt.ylabel('Total Deaths', fontsize=20)
    plt.title(state)

    plt.savefig(state + "_scatterplot" + ".png")
    plt.show()

    with open(state + "_analysis.txt", "w") as file:
        file.write(results.summary().as_text())
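To see why the line stops short, note that `plt.plot` only connects the points it is given. The sketch below uses made-up numbers (hypothetical data, not the CSV from the question) and computes the no-intercept least-squares slope directly with NumPy, which is what `sm.OLS(y, x)` estimates when no constant column is added:

```python
import numpy as np

# Hypothetical data: the smallest x-value is well above 0.
x = np.array([12.0, 25.0, 40.0, 61.0, 78.0])
y = np.array([0.5, 1.1, 1.6, 2.5, 3.2])

# Least squares through the origin minimizes sum((y - b*x)**2),
# whose solution is b = sum(x*y) / sum(x*x).
slope = np.dot(x, y) / np.dot(x, x)

# Fitted values exist only at the observed x-values, so a line drawn
# through (x, yhat) starts at x.min() rather than at 0.
yhat = slope * x
print(f"slope = {slope:.5f}")
print(f"line starts at x = {x.min()}, y = {slope * x.min():.3f}")
```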

And here is the resulting scatter plot after passing in the name of the state and the CSV file for the state:

[Image: resulting scatter plot with regression line, not included]

asked May 12 '26 at 17:05 by Dude156


2 Answers

You should just change the x-values at which you evaluate and plot the fitted line so that they include 0:

x_line = np.linspace(0, x.max(), 100)  # evenly spaced from 0 up to the largest case count
yhat = results.params.iloc[0] * x_line

fig = plt.plot(x_line, yhat, lw=4, c="orange", label="regressionline")
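The same idea as a small reusable helper (the name `line_from_origin` is hypothetical), assuming the slope has already been read from `results.params`, for example with `results.params.iloc[0]`. It uses `np.linspace`, which works for integer or float case counts and includes `x.max()` as the last point:

```python
import numpy as np


def line_from_origin(slope, x, num=100):
    """Return x/y coordinates of y = slope * x from 0 to max(x).

    Hypothetical helper: evaluates the fitted no-intercept line on an
    evenly spaced grid that starts at 0 instead of at the data points.
    """
    xs = np.linspace(0, np.max(x), num)
    ys = slope * xs
    return xs, ys


# Hypothetical slope and data, for illustration only.
x = np.array([12.0, 25.0, 40.0, 61.0, 78.0])
xs, ys = line_from_origin(0.0405, x)
print(xs[0], ys[0], xs[-1])
```

Passing `xs, ys` to `plt.plot` in place of `x, yhat` draws the line starting at the origin.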
answered May 14 '26 at 06:05 by CopyOfA


I think the reason your line does not touch the origin is that you are only plotting it over the extent of your data. By computing the predicted deaths yhat from x itself, you restrict the line to the x-values in your dataset. You can easily fix this by supplying a wider range of x-values:

newX = np.arange(0, 80)  # x-values from 0 up to 79
yhat = results.params.iloc[0] * newX
fig = plt.plot(newX, yhat, lw=4, c="orange", label="regressionline")

By the way, are you fitting the model without an intercept on purpose?
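On the intercept question: `sm.OLS` does not add a constant term by itself, so the model in the question is deaths = b * cases, forced through the origin; statsmodels provides `sm.add_constant(x)` to add the intercept column. The NumPy sketch below (hypothetical data) shows the same difference by solving least squares with and without a column of ones:

```python
import numpy as np

# Hypothetical data generated from y = 2 + 3x, so the true intercept is 2.
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y = 2.0 + 3.0 * x

# Without intercept (like sm.OLS(y, x)): a single column, line through 0.
slope_only = np.linalg.lstsq(x.reshape(-1, 1), y, rcond=None)[0][0]

# With intercept (like sm.OLS(y, sm.add_constant(x))): prepend a column of ones.
X = np.column_stack([np.ones_like(x), x])
intercept, slope = np.linalg.lstsq(X, y, rcond=None)[0]

print(f"no intercept:   slope = {slope_only:.4f}")
print(f"with intercept: intercept = {intercept:.4f}, slope = {slope:.4f}")
```

With an intercept, the fitted line generally does not pass through the origin, so extending it to x = 0 ends at the intercept rather than at 0.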

answered May 14 '26 at 08:05 by Michael Mitter


