import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
| Package    | Version |
|------------|---------|
| python     | 3.8.8   |
| numpy      | 1.21.5  |
| pandas     | 1.4.2   |
| matplotlib | 3.5.1   |
| seaborn    | 0.11.2  |
## 2D

### 2D heatmap with matplotlib and seaborn
# Sample input: a 5x5 frame of random integers drawn from [0, 100),
# labelled A-E on both axes so it can be shown directly as a heatmap.
df = pd.DataFrame(
    data=np.random.randint(low=0, high=100, size=(5, 5)),
    index=[*'ABCDE'],
    columns=[*'ABCDE'],
)
df
|   | A  | B  | C  | D  | E  |
|---|----|----|----|----|----|
| A | 9  | 51 | 85 | 42 | 33 |
| B | 26 | 50 | 93 | 89 | 15 |
| C | 9  | 0  | 72 | 19 | 93 |
| D | 71 | 65 | 53 | 86 | 8  |
| E | 51 | 57 | 57 | 55 | 26 |
from matplotlib.colors import LogNorm
import matplotlib as mpl
import functools
def basic_chart(default_figsize=(6, 4), grid=True):
    """Decorator factory adding boilerplate figure setup and axis formatting.

    The decorated function is called as ``func(fig, ax, *args, **kwargs)`` on a
    freshly created figure and must return the ``(fig, ax)`` pair it drew on.
    The wrapper then applies the optional formatting below, draws the grid if
    requested, and calls ``plt.show()``. The wrapper itself returns None.

    Args:
        default_figsize: figure size used when the caller passes no ``figsize``.
        grid: if True, draw a dashed light-grey grid on the major ticks.

    Keyword arguments consumed by the wrapper (never forwarded to ``func``):
        title, subtitle, ylabel, xlabel, hide_y, xlim, figsize.

    Raises:
        TypeError: if the wrapped function does not return a
            ``(Figure, Axes)`` pair.
    """
    format_keys = ('title', 'subtitle', 'ylabel', 'xlabel', 'hide_y', 'xlim')

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Pull formatting-related kwargs out so they are not passed to func.
            format_kwargs = {key: kwargs.pop(key) for key in format_keys if key in kwargs}
            figsize = kwargs.pop('figsize', default_figsize)

            # Initialize chart. Note: this updates the global rcParams, so the
            # 10pt font size persists for figures created afterwards as well.
            mpl.rcParams.update({'font.size': 10})
            fig, ax = plt.subplots(figsize=figsize, dpi=100)
            fig.set_facecolor('white')

            # Run the primary plot. Validate explicitly rather than with
            # `assert` (stripped under `python -O`), and before unpacking so a
            # non-pair return yields the helpful message, not an unpack error.
            result = func(fig, ax, *args, **kwargs)
            err = 'Wrapped function must return a (fig, ax) pair'
            try:
                fig, ax = result
            except (TypeError, ValueError):
                raise TypeError(err) from None
            if not (isinstance(fig, mpl.figure.Figure) and isinstance(ax, mpl.axes.Axes)):
                raise TypeError(err)

            # Formatting
            if 'title' in format_kwargs:
                ax.set_title(format_kwargs['title'])
            if 'subtitle' in format_kwargs:
                ax.text(0, 1.04, format_kwargs['subtitle'], size='medium',
                        color='grey', transform=ax.transAxes)
            if 'ylabel' in format_kwargs:
                ax.set_ylabel(format_kwargs['ylabel'], size='medium', labelpad=10)
            if 'xlabel' in format_kwargs:
                ax.set_xlabel(format_kwargs['xlabel'], size='large', labelpad=10)
            if format_kwargs.get('hide_y'):
                ax.yaxis.set_ticklabels([])
            if 'xlim' in format_kwargs:
                ax.set_xlim(*format_kwargs['xlim'])
            if grid:
                ax.grid(which='major', color='lightgrey', linestyle='--')
            plt.show()

        return wrapper

    return decorator
@basic_chart(default_figsize=(6, 6), grid=False)
def heatmap(fig, ax, df: pd.DataFrame, log_transform=False, **kwargs) -> tuple:
    """Draw an annotated seaborn heatmap of ``df`` on the supplied axes.

    Args:
        fig: figure created by ``basic_chart``.
        ax: axes created by ``basic_chart``; the heatmap is drawn on it.
        df: a pandas DataFrame in heatmap format (x-axis as columns, y-axis
            as rows). It is copied, so the caller's frame is never modified.
        log_transform: if True, colour cells on a logarithmic scale via
            ``LogNorm`` with ``vmin = min + 1`` and ``vmax = max`` of the
            original data; the plotted copy is shifted by 0.01 so zero-valued
            cells stay positive. NOTE(review): presumably intended for
            non-negative data only -- LogNorm rejects a non-positive vmin.
        **kwargs: accepted for call compatibility but currently ignored.

    Returns:
        The ``(fig, ax)`` pair, as ``basic_chart`` requires.
    """
    # Work on a copy: the log branch shifts values in place.
    df = df.copy()
    if log_transform:
        log_scale = LogNorm(vmin=(df.min().min() + 1), vmax=df.max().max())
        # Small shift keeps zeros off log(0) while the '.0f' annotations
        # still show the original integer values.
        df += 0.01
    else:
        log_scale = None

    ax = sns.heatmap(
        df,
        ax=ax,
        annot=True,
        cmap='Blues',
        fmt='.0f',
        cbar=False,
        norm=log_scale,
    )

    # Post-plot formatting: render once, then make x tick labels vertical
    # and y tick labels horizontal.
    fig.canvas.draw()
    for tick in ax.get_xticklabels():
        tick.set_rotation(90)
    for tick in ax.get_yticklabels():
        tick.set_rotation(0)
    return fig, ax
# Render the sample frame as a compact 3x3-inch annotated heatmap.
heatmap(df, figsize=(3, 3))
