
Plot time series with colorbar in pandas + matplotlib

I'm trying to add a colorbar below this chart of partial cumulative returns plotted over 2 years, where each line's color depends on when its time series starts.

The code that generates the plot is:

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt

import seaborn as sns
sns.set()

def partial_cum_returns(start, cum_returns):
    return cum_returns.loc[start:].div(cum_returns.loc[start])

index = pd.DatetimeIndex(pd.date_range('20170101', '20190101', freq='W'))
np.random.seed(5)
returns = pd.Series(np.exp(np.random.normal(loc=0, scale=0.05, size=len(index))), index=index)
cum_returns = returns.cumprod()
df = pd.DataFrame(index=index)
for date in index:
    df[date] = partial_cum_returns(date, cum_returns)

df.plot(legend=False, colormap='viridis');
plt.colorbar();

But running it raises this error:

RuntimeError: No mappable was found to use for colorbar creation. First define a mappable such as an image (with imshow) or a contour set (with contourf).

I've tried adding the colorbar in other ways, such as the fig, ax = plt.figure()... approach, but I haven't been able to make it work so far. Any ideas? Thanks!

Xoel asked Oct 30 '25 13:10
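
For reference, the error comes from how pyplot finds a mappable: plt.colorbar() with no argument falls back to plt.gci(), the current "colorable" artist, which is only set by calls such as plt.imshow or plt.contourf. A pandas line plot only adds Line2D objects, which are not ScalarMappable instances, so there is nothing to find. The minimal sketch below shows this on a small made-up DataFrame (not the question's data), assuming a recent Matplotlib and the headless Agg backend.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.cm import ScalarMappable

# Toy frame with three columns; df.plot draws one Line2D per column
df = pd.DataFrame(np.random.default_rng(0).normal(size=(10, 3)).cumsum(axis=0))
fig, ax = plt.subplots()
df.plot(ax=ax, legend=False, colormap="viridis")

# plt.colorbar() with no argument uses plt.gci(), the "current image"
# (set by calls such as plt.imshow or plt.contourf). Line plots never set it.
current = plt.gci()
lines_are_mappable = any(isinstance(line, ScalarMappable) for line in ax.get_lines())

try:
    plt.colorbar()
    raised = False
except RuntimeError:  # "No mappable was found to use for colorbar creation..."
    raised = True

print(current, lines_are_mappable, raised)  # None False True
```

This is why the fix is to build a mappable by hand rather than to change how the figure or Axes is created.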


1 Answer

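The first point is that df.plot only draws lines (Line2D objects), which are not mappables, so plt.colorbar() has nothing to build a colorbar from. You need to create a ScalarMappable for the colorbar yourself: define the colormap, which in your case is 'viridis', and a norm spanning the minimum and maximum of the values the colorbar should cover, here the series start dates. Because the norm works on numeric values (the dates as integer nanoseconds, via Timestamp.value), the colorbar ticks come out as large numbers, so you then format them back into dates.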

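import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.ticker import FuncFormatter

# Define your mappable for colorbar creation: same colormap as the lines, with a
# norm spanning the start dates as integer nanoseconds (Timestamp.value)
sm = plt.cm.ScalarMappable(cmap='viridis',
                           norm=plt.Normalize(vmin=df.index.min().value,
                                              vmax=df.index.max().value))
sm.set_array([])  # only needed on Matplotlib < 3.1; public replacement for sm._A = []

ax = df.plot(legend=False, colormap='viridis', figsize=(12, 7))

# sm is not attached to any Axes, so tell the colorbar which Axes to take space from
# (newer Matplotlib versions warn or raise without ax=)
cbar = plt.colorbar(sm, ax=ax,
                    # Change the numeric ticks into dates that match the x-axis
                    format=FuncFormatter(lambda x, pos: pd.Timestamp(int(x)).strftime('%b %Y')))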
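
Since the question asks for the colorbar below the chart, here is a minimal sketch of that variant, assuming Matplotlib 3.x and pandas 1.x or later. It rebuilds the question's data (with pd.concat instead of the column-by-column loop, which should give the same frame without pandas' fragmentation warning), passes orientation="horizontal" so the colorbar sits under the Axes, and hands the date formatter to colorbar's format argument. The pad value and the label text are arbitrary choices.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
from matplotlib.ticker import FuncFormatter

# Same data as in the question; one column per start date, NaN before it starts
index = pd.date_range("20170101", "20190101", freq="W")
np.random.seed(5)
returns = pd.Series(np.exp(np.random.normal(loc=0, scale=0.05, size=len(index))), index=index)
cum_returns = returns.cumprod()
df = pd.concat({d: cum_returns.loc[d:] / cum_returns.loc[d] for d in index}, axis=1)

ax = df.plot(legend=False, colormap="viridis", figsize=(12, 7))

# Mappable spanning the series start dates (integer nanoseconds since the epoch)
sm = ScalarMappable(cmap="viridis",
                    norm=Normalize(vmin=index.min().value, vmax=index.max().value))
sm.set_array([])

# orientation="horizontal" places the colorbar under the Axes;
# the formatter turns the nanosecond tick values back into month labels
date_fmt = FuncFormatter(lambda x, pos: pd.Timestamp(int(x)).strftime("%b %Y"))
cbar = plt.colorbar(sm, ax=ax, orientation="horizontal", pad=0.1, format=date_fmt)
cbar.set_label("Series start date")
```

Using a formatter instead of set_yticklabels keeps the labels correct if the colorbar's ticks change, for example when the figure is resized.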

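One caveat: df.plot(colormap='viridis') assigns colors by column position, sampling the colormap evenly from 0 to 1, while the colorbar maps by date value. The two agree here only because the start dates are evenly spaced (weekly). If the start dates were irregular, one option is to draw the lines yourself and color each one through the same norm as the colorbar. The sketch below does that with hypothetical, randomly chosen start dates (not from the question) to make the difference visible.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize, to_rgba
from matplotlib.ticker import FuncFormatter

# Hypothetical data: 12 irregularly spaced start dates picked from a weekly index
rng = np.random.default_rng(5)
index = pd.date_range("20170101", "20190101", freq="W")
starts = index[np.sort(rng.choice(len(index), size=12, replace=False))]
cum_returns = pd.Series(np.exp(rng.normal(0, 0.05, len(index))), index=index).cumprod()

cmap = plt.get_cmap("viridis")
norm = Normalize(vmin=starts.min().value, vmax=starts.max().value)

fig, ax = plt.subplots(figsize=(12, 7))
for start in starts:
    partial = cum_returns.loc[start:] / cum_returns.loc[start]
    # Color by start date through the same norm the colorbar uses
    ax.plot(partial.index, partial.values, color=cmap(norm(start.value)))

sm = ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
date_fmt = FuncFormatter(lambda x, pos: pd.Timestamp(int(x)).strftime("%b %Y"))
cbar = fig.colorbar(sm, ax=ax, orientation="horizontal", pad=0.1, format=date_fmt)
```

The earliest series gets the bottom of the colormap and the latest gets the top, and every series in between lands at the position of its actual start date on the colorbar.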

ALollz answered Nov 01 '25 05:11


