Geoff Ruddock

Heatmaps

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
Package     Version
----------  ---------
python      3.8.8
numpy       1.21.5
pandas      1.4.2
matplotlib  3.5.1
seaborn     0.11.2

2D

2D with matplotlib and seaborn

# Sample 5x5 matrix of integers in [0, 100). A fixed seed keeps the example
# (and the heatmap rendered from it below) reproducible across runs.
df = pd.DataFrame(
    np.random.default_rng(42).integers(0, 100, size=(5, 5)),
    index=list('ABCDE'),
    columns=list('ABCDE'),
)

df

A B C D E
A 9 51 85 42 33
B 26 50 93 89 15
C 9 0 72 19 93
D 71 65 53 86 8
E 51 57 57 55 26
from matplotlib.colors import LogNorm
import matplotlib as mpl
import functools

def basic_chart(default_figsize=(6, 4), grid=True):
    """ Boilerplate chart labels and axis formatting.

    Decorator factory. The decorated function is called as
    ``func(fig, ax, *args, **kwargs)`` on a freshly created figure/axes pair
    and must return that ``(fig, ax)`` pair. The wrapper then applies any
    formatting keyword arguments and calls ``plt.show()``.

    Formatting keyword arguments, consumed by the wrapper and NOT passed on to
    the decorated function: ``title``, ``subtitle``, ``ylabel``, ``xlabel``,
    ``hide_y`` (truthy -> blank y tick labels), ``xlim`` (unpacked into
    ``ax.set_xlim``) and ``figsize``.

    Args:
        default_figsize: figure size used when the caller passes no ``figsize``.
        grid: if True, draw dashed light-grey major grid lines.

    Returns:
        A decorator. The resulting wrapper shows the chart and returns None.

    Raises:
        TypeError: (from the wrapper) if the decorated function does not
            return a ``(Figure, Axes)`` pair.
    """
    format_keys = ('title', 'subtitle', 'ylabel', 'xlabel', 'hide_y', 'xlim')

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Pull out formatting-related kwargs before calling function
            format_kwargs = {key: kwargs.pop(key) for key in format_keys if key in kwargs}
            figsize = kwargs.pop('figsize', default_figsize)

            # Initialize chart.
            # NOTE: this updates the global rcParams, so the font size also
            # applies to every figure created later in the session.
            mpl.rcParams.update({'font.size': 10})
            fig, ax = plt.subplots(figsize=figsize, dpi=100)
            fig.set_facecolor('white')

            # Run the primary plot
            err = 'Wrapped function must return a (fig, ax) pair'
            result = func(fig, ax, *args, **kwargs)
            try:
                fig, ax = result
            except (TypeError, ValueError):
                # e.g. the function forgot `return fig, ax` and returned None
                raise TypeError(err) from None
            # Explicit check rather than `assert`, which `python -O` strips.
            if not (isinstance(fig, mpl.figure.Figure) and isinstance(ax, mpl.axes.Axes)):
                raise TypeError(err)

            # Formatting
            if 'title' in format_kwargs:
                ax.set_title(format_kwargs['title'])
            if 'subtitle' in format_kwargs:
                # Positioned in axes-relative coordinates, just above the plot area.
                ax.text(0, 1.04, format_kwargs['subtitle'], size='medium', color='grey', transform=ax.transAxes)
            if 'ylabel' in format_kwargs:
                ax.set_ylabel(format_kwargs['ylabel'], size='medium', labelpad=10)
            if 'xlabel' in format_kwargs:
                ax.set_xlabel(format_kwargs['xlabel'], size='large', labelpad=10)
            if format_kwargs.get('hide_y'):
                ax.yaxis.set_ticklabels([])
            if 'xlim' in format_kwargs:
                ax.set_xlim(*format_kwargs['xlim'])
            if grid:
                ax.grid(which='major', color='lightgrey', linestyle='--')

            plt.show()

        return wrapper

    return decorator


@basic_chart(default_figsize=(6, 6), grid=False)
def heatmap(fig, ax, df: pd.DataFrame, log_transform=False, **kwargs) -> tuple:
    """ Wrapper for sns.heatmap with optional log transform

    Runs through the ``basic_chart`` decorator, which supplies ``fig`` and
    ``ax`` and consumes formatting kwargs such as ``title`` and ``figsize``.

    Args:
        fig: matplotlib Figure to draw on (supplied by the decorator).
        ax: matplotlib Axes to draw on (supplied by the decorator).
        df: a pandas DataFrame in heatmap format (x-axis as columns, y-axis as rows)
        log_transform: if True, colour cells on a log scale of ``df + 1`` (so
            zeros are representable); the annotations still show the original
            values. If ``cbar=True`` is also passed, its ticks reflect the +1 shift.
        **kwargs: extra keyword arguments forwarded to ``sns.heatmap``; they
            override the defaults used here (annot, cmap, fmt, cbar, norm).

    Returns:
        The ``(fig, ax)`` pair, as ``basic_chart`` requires.

    Raises:
        ValueError: if ``log_transform`` is True and some value is <= -1, since
            the shifted data would then not be strictly positive.
    """
    if log_transform:
        # Shift once by +1 (log1p-style) and take the colour limits from the
        # *shifted* data, so vmin <= vmax holds even when max - min < 1.
        color_data = df + 1
        low = color_data.min().min()
        if low <= 0:
            raise ValueError('log_transform requires every value in df to be greater than -1')
        norm = LogNorm(vmin=low, vmax=color_data.max().max())
        # Annotate with the unshifted values so the printed numbers are truthful.
        annot = df.to_numpy()
    else:
        color_data = df
        norm = None
        annot = True

    options = {'annot': annot, 'cmap': 'Blues', 'fmt': '.0f', 'cbar': False, 'norm': norm}
    options.update(kwargs)

    ax = sns.heatmap(color_data, ax=ax, **options)

    # Post-plot formatting.
    # NOTE(review): presumably forces a render so the tick label objects are
    # finalized before they are rotated — confirm it is still needed.
    fig.canvas.draw()

    for tick in ax.get_xticklabels():
        tick.set_rotation(90)

    for tick in ax.get_yticklabels():
        tick.set_rotation(0)

    return fig, ax


# Draw the sample frame; `figsize` is consumed by the basic_chart wrapper.
heatmap(df=df, figsize=(3, 3))

[Figure: the rendered heatmap of `df`]


comments powered by Disqus