Geoff Ruddock

matplotlib

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

print(mpl.__version__)
3.3.4

Plot types

Scatter plot

Basic

def plot_something(srs: pd.Series) -> None:
    fig, ax = plt.subplots(figsize=(4, 4), dpi=100)
    
    ax.scatter(srs.index, srs.values, label='', zorder=2)
    
    ax.set_title('')
    ax.set_ylabel('')
    ax.set_xlabel('')
    
    ax.grid()
    plt.savefig('chart.png')
    
srs = pd.Series(np.random.randint(20, size=10))
plot_something(srs)


Labelled scatter plot near intercept

Here’s an example of a simple scatter plot of labeled data points centered around the origin. I initially wrote this as the base for a function that visualizes how the k-means algorithm works step by step. A couple of things to note:

  1. We set a higher zorder for the elements we want to appear “on top” of others.
  2. ax.scatter takes an array of data points, but ax.annotate labels a single point per call, so we need to loop.
data = np.asarray([
    [2, 2],
    [-1, 1],
    [3, 1],
    [0, -1],
    [-2, -2]
])

def scatter_plot_near_intercept(data):
    fig, ax = plt.subplots(figsize=(5, 5), dpi=100)
    
    # Format grid
    ax.grid(zorder=1)
    ax.axvline(0, c='black', zorder=2)
    ax.axhline(0, c='black', zorder=2)
    ax.set_xlim(-4, 4)
    ax.set_ylim(-4, 4)
    
    # Plot data points
    ax.scatter(*data.T, s=256, c='yellow', edgecolors='black', zorder=3)
    
    # ax.annotate labels one point per call, so loop over the points
    for i, (x, y) in enumerate(data):
        ax.annotate(str(i), (x, y), size=10, ha='center', va='center')
    
scatter_plot_near_intercept(data)


Histogram

Default

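Here is what we get “out-of-the-box” from seaborn.distplot. I typically start here because it has much more sensible defaults (bin width, etc.) than plt.hist, and it returns a matplotlib Axes object anyway. Note that distplot is deprecated as of seaborn 0.11, as the FutureWarning below shows; histplot is its axes-level replacement.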

data = np.random.exponential(size=1000)
sns.distplot(data, kde=False);
/Users/geoffruddock/opt/anaconda3/lib/python3.8/site-packages/seaborn/distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
  warnings.warn(msg, FutureWarning)
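For comparison, plt.hist falls back to rcParams['hist.bins'] equal-width bins (10 by default). Here is a minimal sketch for getting data-driven bin widths without seaborn. Matplotlib hands the bins argument to NumPy, which accepts rule names such as 'fd' (Freedman-Diaconis) or 'auto'. I believe distplot also picks its bin count with a Freedman-Diaconis rule, so the result should look similar, but treat that as an assumption. The simulated data (seed 0) just mirrors the exponential sample above.

```python
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

data = np.random.default_rng(0).exponential(size=1000)

fig, (ax_default, ax_fd) = plt.subplots(ncols=2, figsize=(8, 3), dpi=100)

# Default: rcParams['hist.bins'] equal-width bins (10 unless changed)
counts_default, edges_default, _ = ax_default.hist(data)
ax_default.set_title("default bins")

# 'fd' asks NumPy for Freedman-Diaconis bin widths, based on the IQR and sample size
counts_fd, edges_fd, _ = ax_fd.hist(data, bins="fd")
ax_fd.set_title("bins='fd'")
```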


Clean

We can tidy up this output a bit by hiding the y-axis, which is often irrelevant when our primary goal is to visualize the shape of a distribution. If your plot has a grid, use ax.yaxis.set_ticklabels([]) instead of ax.get_yaxis().set_visible(False) so that the horizontal grid lines remain.
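The difference between the two options is easiest to see side by side. Below is a minimal sketch that uses plain ax.hist rather than distplot, to keep it self-contained. Both panels have the grid turned on, and the data is simulated.

```python
import matplotlib.pyplot as plt
import numpy as np

data = np.random.default_rng(0).exponential(size=1000)
fig, (ax_hidden, ax_grid) = plt.subplots(ncols=2, figsize=(8, 3), dpi=100)

# Option 1: hide the entire y-axis; its horizontal grid lines disappear with it
ax_hidden.hist(data, bins="fd")
ax_hidden.grid()
ax_hidden.get_yaxis().set_visible(False)
ax_hidden.set_title("set_visible(False)")

# Option 2: blank out only the tick labels; the horizontal grid lines stay visible
ax_grid.hist(data, bins="fd")
ax_grid.grid()
ax_grid.yaxis.set_ticklabels([])
ax_grid.set_title("set_ticklabels([])")

fig.canvas.draw()  # render once so tick labels are populated
```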

from typing import Tuple

def histogram(data: np.ndarray, **kwargs) -> Tuple[mpl.figure.Figure, mpl.axes.Axes]:
    """ Plot a seaborn histogram with the y-axis hidden; returns (fig, ax). """
    
    # plot core data on an explicit Axes
    fig, ax = plt.subplots(figsize=(6, 4), dpi=100)
    sns.distplot(data, ax=ax, **kwargs)
    
    # formatting: the y-axis is rarely needed when the goal is to show the distribution's shape
    ax.get_yaxis().set_visible(False)
    
    return fig, ax

histogram(data, kde=False);


Annotate with summary stats

If we’re interested in the shape of a distribution, we’re probably also interested in some summary statistics of the data. Rather than calculating these separately, we can display them inline on the plot using annotations.

We generally want to leave a bit of padding between each vertical line and its text label, to improve readability. But we want this padding to be invariant to the scale of the data, so that the formatting looks consistent. We can achieve this by passing transform=ax.transAxes to ax.text. This lets us specify the (x, y) position in axes coordinates, where [0, 1] spans the plot area, rather than in the coordinates of the data itself. The pad and the vertical position can then be expressed as fractions of the axes. The horizontal position of the text, however, depends on the data. So we need to convert the data coordinate to an axes coordinate using ax.transLimits.transform((raw_value, 0))[0]. More details at Transformations Tutorial (matplotlib docs).

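Here is a quick sanity check of that conversion, as a sketch with made-up limits of (0, 10). On a linear axis, ax.transLimits maps the view limits onto [0, 1], so a data value x lands at (x - xmin) / (xmax - xmin). Note that transLimits works on values after the axis scale is applied, so on a set_xscale('log') axis it would expect log10(x). Composing ax.transData with the inverse of ax.transAxes handles the scale for you, and it gives the same answer here.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(6, 4), dpi=100)
ax.set_xlim(0, 10)
ax.set_ylim(0, 1)

x_data = 2.5  # hypothetical value to annotate, in data coordinates

# Data -> axes fraction via the view limits: (2.5 - 0) / (10 - 0) = 0.25
rel_x = ax.transLimits.transform((x_data, 0))[0]

# Data -> display -> axes fraction; also correct for non-linear axis scales
rel_x_general = (ax.transData + ax.transAxes.inverted()).transform((x_data, 0))[0]

# Line in data coordinates, label offset by a fixed fraction of the axes width
pad = 0.01
ax.axvline(x_data, linestyle="--", color="grey")
ax.text(rel_x + pad, 0.95, f"x = {x_data}", va="center", ha="left", transform=ax.transAxes)
```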
def annotate_distplot(ax: mpl.axes.Axes,
                      data: np.ndarray,
                      mean_color='grey',
                      median_color='pink',
                      pad=0.01,
                      precision=2,
                      log_transform=False) -> None:
    """ Draw lines and annotate text on a histogram to indicate mean and median values. """
    
    text_defaults = dict(va='center', ha='left', transform=ax.transAxes)
    mean = mean_label = np.mean(data)
    median = median_label = np.median(data)
    
    if log_transform:
        mean, mean_label = np.log10(mean), mean
        median, median_label = np.log10(median), median
        
    mean_label = np.round(mean_label, precision)
    median_label = np.round(median_label, precision)
    
    # annotate mean
    ax.axvline(mean, linestyle='--', color=mean_color)
    rel_pos = ax.transLimits.transform((mean, 0))[0]
    ax.text(rel_pos+pad, 0.95, f'Mean: {mean_label}', color=mean_color, **text_defaults)
    
    # annotate median
    ax.axvline(median, linestyle='--', color=median_color)
    rel_pos = ax.transLimits.transform((median, 0))[0]
    ax.text(rel_pos+pad, 0.88, f'Median: {median_label}', color=median_color, **text_defaults)
    
    
_, ax = histogram(data, kde=False)
annotate_distplot(ax, data)


With log transform on x-axis

Suppose our data has a long right tail (for example, exponential or log-normal data) and is difficult to visualize on a linear scale. We can get a better sense of “orders of magnitude” by performing a log transformation on the x-axis.

We can use ax.set_xscale('log') to automatically handle the log transformation, but its output may not be ideal. The log-distributed ticks and scientific notation may be confusing for a non-technical audience, and may detract from the main takeaway of your plot.

_, ax = histogram(data, kde=False)
ax.set_xscale('log')
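Part of what makes that output awkward is that the bins were chosen on a linear scale. Once the axis is log-scaled, they look progressively narrower toward the right. Below is a minimal sketch for keeping set_xscale('log') while avoiding both problems. It builds bin edges evenly spaced in log space with np.geomspace, and it replaces the scientific-notation tick labels with plain numbers via a FuncFormatter. The data is simulated, and the bin count of 30 is an arbitrary choice.

```python
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np

data = np.random.default_rng(0).exponential(size=1000)

# 31 edges -> 30 bins, evenly spaced in log space; geomspace keeps the exact min/max as endpoints
bins = np.geomspace(data.min(), data.max(), num=31)

fig, ax = plt.subplots(figsize=(6, 4), dpi=100)
counts, edges, _ = ax.hist(data, bins=bins)
ax.set_xscale("log")

# Plain decimal labels (0.01, 0.1, 1, ...) instead of 10^-2, 10^-1, 10^0
ax.xaxis.set_major_formatter(mticker.FuncFormatter(lambda value, pos: f"{value:g}"))
ax.get_yaxis().set_visible(False)
```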


Alternatively, we can perform our own log transform on the raw data, and then manipulate the x labels to express them as orders of magnitude. Note that we must execute fig.canvas.draw() before calling ax.get_xticklabels(), otherwise it will return an array of empty labels.

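def log10_xlabels(ax: mpl.axes.Axes) -> None:
    """ Transform xlabel text to reflect a log10 transformation. """
    # Tick labels use a unicode minus sign (U+2212), which float() cannot parse
    labels = [obj.get_text().replace('−', '-') for obj in ax.get_xticklabels()]
    # float() rather than int(), since ticks may fall on fractional powers such as -0.5
    new_labels = [f'{10 ** float(x):g}' for x in labels]
    # Pin the current tick positions so the fixed labels stay aligned with them
    # (this also avoids the "FixedFormatter should only be used together with FixedLocator" warning)
    ax.set_xticks(ax.get_xticks())
    ax.set_xticklabels(new_labels)


def log_histogram(data, **kwargs):
    """ Plot a histogram of log10(data), labelling the x-axis with untransformed values. """
    log_data = np.log10(data)
    fig, ax = histogram(log_data, **kwargs)  # plot the transformed data, not the raw data
    fig.canvas.draw()  # necessary to populate xtick labels
    log10_xlabels(ax)
    return fig, ax


log_histogram(data, kde=False);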

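Here is an alternative sketch that avoids both fig.canvas.draw() and parsing label strings. It plots the log10-transformed data as before, then attaches a FuncFormatter that renders each tick position v as 10**v. Because the formatter runs at draw time, the labels should stay correct if the view is zoomed or the ticks change. The helper name undo_log10 is hypothetical, and the data is simulated.

```python
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np

data = np.random.default_rng(0).exponential(size=1000)
log_data = np.log10(data)

fig, ax = plt.subplots(figsize=(6, 4), dpi=100)
ax.hist(log_data, bins="fd")

def undo_log10(value, pos):
    """Label a tick at position `value` (in log10 units) with 10**value."""
    return f"{10 ** value:g}"

# The formatter is called at draw time, so no canvas.draw() or label parsing is needed
ax.xaxis.set_major_formatter(mticker.FuncFormatter(undo_log10))
ax.get_yaxis().set_visible(False)
```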

Bar chart

data = pd.Series([0.2, 0.3, 0.1, 0.4], index=['A', 'B', 'C', 'D'])
data
A    0.2
B    0.3
C    0.1
D    0.4
dtype: float64
data.plot(kind='bar');


plt.bar(x=data.index, height=data.values);  # string index -> categorical x-axis labelled A-D


Horizontal

data.plot(kind='barh');


plt.barh(y=data.index, width=data.values);  # string index -> categorical y-axis labelled A-D


Plotting functions

Plot a pdf

x = np.arange(-4, 4, 0.1)

def plot_pdf(x, y, title=None) -> None:
    fig, ax = plt.subplots(figsize=(4, 4), dpi=100)
    
    ax.plot(x, y)
    
    ax.set_title(title)
    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.grid()
    ax.axvline(0, c='black', linewidth=1, zorder=2)
    ax.axhline(0, c='black', linewidth=1, zorder=2)
    plt.savefig('chart.png')
    
plot_pdf(x, x*np.exp(-x**2/2), r'Plot of $z e^{-\frac{z^2}{2}}$')


Plot multiple pdfs

from scipy.stats import norm

def plot_pdfs(x, pdfs: dict) -> None:
    fig, ax = plt.subplots(figsize=(4, 4), dpi=100)
    
    for label, pdf in pdfs.items():
        ax.plot(x, pdf, label=label)
    
    ax.set_title('')
    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.legend(loc='best')
    ax.grid()
    ax.axvline(0, c='black', linewidth=1, zorder=2)
    ax.axhline(0, c='black', linewidth=1, zorder=2)
    plt.savefig('chart.png')
    
x = np.arange(-3, 3, 0.1)
log_lk = lambda z: -(1/2)*np.log(2*np.pi) - z**2/2
dldz = lambda z: -z
dldz_sqr = lambda z: z**2

pdfs = {
    r'$f(x|\theta)$': norm.pdf(x),
    r'$\ell(x|\theta)$': log_lk(x),
    r'$\frac{d \ell}{dx}(x|\theta)$': dldz(x),
    r'$[\frac{d \ell}{dx}(x|\theta)]^2$': dldz_sqr(x)
}

plot_pdfs(x, pdfs)


Discontinuous function

def plot_discontinuous_function(x, y, discontinuities=(), title=None) -> None:
    """ Plot y against x as separate segments, with open circles at each discontinuity. """
    fig, ax = plt.subplots(figsize=(4, 4), dpi=100)
    
    # Segment boundaries: -inf, each discontinuity (in order), +inf
    edges = [-np.inf, *sorted(discontinuities), np.inf]
    
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (x > lo) & (x < hi)
        if not mask.any():
            continue
        x_seg, y_seg = x[mask], y[mask]
        ax.plot(x_seg, y_seg, color='blue', zorder=3)
        
        # Open circles at the segment ends that touch a discontinuity
        if np.isfinite(lo):
            ax.scatter(x_seg[0], y_seg[0], marker='o', zorder=4, facecolor='white', edgecolor='blue')
        if np.isfinite(hi):
            ax.scatter(x_seg[-1], y_seg[-1], marker='o', zorder=4, facecolor='white', edgecolor='blue')
    
    ax.set_title(title)
    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.grid()
    ax.axvline(0, c='black', linewidth=1, zorder=2)
    ax.axhline(0, c='black', linewidth=1, zorder=2)
    plt.savefig('chart.png')
    
    
x = np.arange(-2, 2, 0.01)
y = 1/x - 1/np.abs(x)

plot_discontinuous_function(x, y, [0], r'')


Cumulative distribution functions

def cdf(x, cdfs: dict) -> None:
    """ Plot one or more CDF curves over a shared x-grid. """
    fig, ax = plt.subplots(figsize=(4, 4), dpi=100)

    for name, y in cdfs.items():
        ax.plot(x, y, label=name)

    ax.legend(loc='best')
    plt.show()

x = np.arange(-3, 3, 0.1)
y = norm.cdf(x)

cdf(x, {'standard normal': y})

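When you have samples rather than a closed-form distribution, an empirical CDF can be drawn as a step function: sort the data and plot the i-th smallest value at height i/n. Below is a minimal sketch with simulated normal samples and a hypothetical ecdf helper.

```python
import matplotlib.pyplot as plt
import numpy as np

def ecdf(samples):
    """Return the sorted samples and their cumulative fractions i/n for i = 1..n."""
    xs = np.sort(np.asarray(samples))
    ys = np.arange(1, len(xs) + 1) / len(xs)
    return xs, ys

samples = np.random.default_rng(0).normal(size=500)
xs, ys = ecdf(samples)

fig, ax = plt.subplots(figsize=(4, 4), dpi=100)
# where="post": the curve jumps up at each observation and stays flat until the next one
ax.step(xs, ys, where="post", label="empirical CDF")
ax.set_ylim(0, 1.05)
ax.legend(loc="best")
```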

Formatting

from scipy.stats import norm

def basic_plot():
    fig, ax = plt.subplots(figsize=(4, 4), dpi=100)

    x = np.arange(-3, 3, 0.1)
    y = norm.pdf(x)
    ax.plot(x, y)
        
    return fig, ax

basic_plot()
(<Figure size 400x400 with 1 Axes>, <AxesSubplot:>)


Legend

Multiple plots

Multiple rows and columns

# larger grid
fig, axes = plt.subplots(figsize=(6, 6),
                         nrows=2,
                         ncols=2,
                         sharex=True,
                         sharey=True,
                         gridspec_kw={'wspace': 0.1},
                         dpi=100)
                               
axes = axes.flat


Unequal sizes

fig, (ax1, ax2) = plt.subplots(nrows=2,
                               sharex=True,
                               gridspec_kw={'height_ratios': [3, 1]},
                               figsize=(18, 9),
                               dpi=100)        


Axis formatting

Percent

https://matplotlib.org/stable/api/ticker_api.html#matplotlib.ticker.PercentFormatter

fig, ax = basic_plot()

import matplotlib.ticker as mtick
ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))


from matplotlib.ticker import FuncFormatter

def currency(x, pos):
    'The two args are the value and tick position'
    if x >= 1000000:
        return '${:1.1f}M'.format(x*1e-6)
    return '${:1.0f}K'.format(x*1e-3)

formatter = FuncFormatter(currency)
ax.xaxis.set_major_formatter(formatter)

Dual Y axis

def add_relabelled_y_axis_to_right_side(ax, new_labels: list):
    """ Add a set of labels to the RHS of a plot, to aid in interpretability. """
    label_objs = list(ax.get_yticklabels())
    for label, new_text in zip(label_objs, new_labels):
        label.set_text(new_text)  # public setter instead of the private _text attribute
    
    # create a new y-axis on RHS
    ax2 = ax.twinx()

    # Copy tick locations
    _ = ax2.set_yticks(ax.get_yticks())

    # Copy tick labels
    _ = ax2.set_yticklabels(label_objs)

    # Copy axis limits
    _ = ax2.set_ylim(ax.get_ylim())

    # Resize RHS ticks to match LHS
    if 'labelsize' in ax.yaxis._major_tick_kw:
        label_size = ax.yaxis._major_tick_kw['labelsize']
        for tick in ax2.yaxis.majorTicks:
            tick.label2.set_fontsize(label_size)
            
    return ax


fig, ax = plt.subplots()
x = np.arange(1, 101)
ax.plot(x, x**2)
ax.set_ylim(0, 11000)

add_relabelled_y_axis_to_right_side(ax, ['Zero', 'Two', 'Four', 'Six', 'Eight', 'Ten']);


Unsorted

fig, ax = plt.subplots()
x = np.arange(1, 101)
ax.plot(x, x**2)
ax.set_ylim(0, 11000)
ax.axhline(2000, linestyle='--', color='purple')
<matplotlib.lines.Line2D at 0x7f9381003e20>


Other

Auto-tilt x-axis labels

fig.autofmt_xdate()

Change interval of dates on x-axis

import matplotlib.dates as mdates
ax.xaxis.set_major_locator(mdates.MonthLocator(interval=1))
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
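The snippets above assume an existing fig and ax with dates on the x-axis. Below is a self-contained sketch that combines them. It uses made-up daily data for roughly the first half of 2021; the dates and values are arbitrary.

```python
import datetime as dt

import matplotlib.dates as mdates
import matplotlib.pyplot as plt

# Arbitrary daily series covering roughly the first half of 2021
dates = [dt.date(2021, 1, 1) + dt.timedelta(days=i) for i in range(180)]
values = list(range(180))

fig, ax = plt.subplots(figsize=(6, 3), dpi=100)
ax.plot(dates, values)

# One major tick on the first day of every month, labelled as YYYY-MM-DD
ax.xaxis.set_major_locator(mdates.MonthLocator(interval=1))
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))

# Rotate and right-align the date labels so they don't overlap
fig.autofmt_xdate()

fig.canvas.draw()  # render once so tick labels are populated
tick_labels = [label.get_text() for label in ax.get_xticklabels()]
```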

📚 Further reading

Real Python: Python Plotting With Matplotlib – Great introductory resource, with an explanation of the two different APIs, which is a key source of confusion for new users.

Python Data Science Handbook: Visualization with Matplotlib – Collection of 10+ notebooks with details on how to achieve specific tweaks.
