import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# Package     Version
# ----------  -------
# python      3.8.8
# numpy       1.21.5
# pandas      1.4.2
# matplotlib  3.5.1
# seaborn     0.11.2
# 2D
# 2D with matplotlib
# Example 5x5 grid of random integers in [0, 100), labelled A-E on both axes.
# A seeded Generator (rather than the legacy global ``np.random.randint``)
# makes the example data -- and therefore the heatmap below -- reproducible.
df = pd.DataFrame(
    np.random.default_rng(0).integers(0, 100, size=(5, 5)),
    index=list('ABCDE'),
    columns=list('ABCDE'),
)
df  # displayed as the cell output in a notebook; a no-op in a plain script
# Sample output:
#
#       A   B   C   D   E
#   A   9  51  85  42  33
#   B  26  50  93  89  15
#   C   9   0  72  19  93
#   D  71  65  53  86   8
#   E  51  57  57  55  26
from matplotlib.colors import LogNorm
import matplotlib as mpl
import functools
# Formatting keywords consumed by ``basic_chart``; they are stripped from the
# call before the wrapped plotting function runs.
_FORMAT_KEYS = ('title', 'subtitle', 'ylabel', 'xlabel', 'hide_y', 'xlim')


def _apply_chart_formatting(ax, format_kwargs, grid):
    """Apply the optional labels, limits and grid collected by ``basic_chart``."""
    if 'title' in format_kwargs:
        ax.set_title(format_kwargs['title'])
    if 'subtitle' in format_kwargs:
        # Drawn in axes-fraction coordinates: left edge, just above the top.
        ax.text(0, 1.04, format_kwargs['subtitle'], size='medium',
                color='grey', transform=ax.transAxes)
    if 'ylabel' in format_kwargs:
        ax.set_ylabel(format_kwargs['ylabel'], size='medium', labelpad=10)
    if 'xlabel' in format_kwargs:
        ax.set_xlabel(format_kwargs['xlabel'], size='large', labelpad=10)
    if format_kwargs.get('hide_y'):
        ax.yaxis.set_ticklabels([])
    if 'xlim' in format_kwargs:
        ax.set_xlim(*format_kwargs['xlim'])
    if grid:
        ax.grid(which='major', color='lightgrey', linestyle='--')


def basic_chart(default_figsize=(6, 4), grid=True):
    """Decorator factory adding boilerplate figure setup and axis formatting.

    The decorated function must accept ``(fig, ax, *args, **kwargs)`` and
    return the ``(fig, ax)`` pair it drew on. Callers of the decorated
    function do not pass ``fig``/``ax``; instead they may pass these optional
    keyword arguments, which are consumed here and NOT forwarded:

    - ``figsize``: figure size in inches (defaults to *default_figsize*)
    - ``title``, ``subtitle``, ``xlabel``, ``ylabel``: text labels
    - ``hide_y``: if truthy, blank out the y tick labels
    - ``xlim``: a sequence unpacked into ``ax.set_xlim``

    All other arguments are passed through to the decorated function.
    The decorated function shows the figure with ``plt.show()`` and returns
    ``None``.

    Note: every call sets ``font.size`` to 10 in the global
    ``matplotlib.rcParams``, and that setting persists after the call.

    Args:
        default_figsize: figure size used when the caller passes no ``figsize``.
        grid: draw a light dashed major grid when True.

    Raises:
        TypeError: if the decorated function does not return a
            ``(Figure, Axes)`` pair.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Pull out formatting-related kwargs before calling the function.
            format_kwargs = {key: kwargs.pop(key)
                             for key in _FORMAT_KEYS if key in kwargs}
            figsize = kwargs.pop('figsize', default_figsize)

            # Initialize the chart.
            mpl.rcParams.update({'font.size': 10})
            fig, ax = plt.subplots(figsize=figsize, dpi=100)
            fig.set_facecolor('white')

            # Run the primary plot.
            fig, ax = func(fig, ax, *args, **kwargs)
            # Explicit check rather than ``assert``, which ``python -O`` strips;
            # ``mpl.axes.Axes`` is the public name of the private ``_axes.Axes``.
            if not (isinstance(fig, mpl.figure.Figure)
                    and isinstance(ax, mpl.axes.Axes)):
                raise TypeError('Wrapped function must return a (fig, ax) pair')

            _apply_chart_formatting(ax, format_kwargs, grid)
            plt.show()
        return wrapper
    return decorator
@basic_chart(default_figsize=(6, 6), grid=False)
def heatmap(fig, ax, df: pd.DataFrame, log_transform=False, **kwargs) -> tuple:
    """Draw *df* as an annotated seaborn heatmap, optionally on a log color scale.

    Decorated with ``basic_chart``, so callers omit ``fig``/``ax`` and may pass
    that decorator's formatting keywords (``figsize``, ``title``, ...).

    Args:
        fig: figure created by ``basic_chart``.
        ax: axes to draw on, created by ``basic_chart``.
        df: a pandas DataFrame in heatmap format (x-axis as columns, y-axis
            as rows). It is copied, never modified.
        log_transform: color cells with a ``LogNorm``. 0.01 is added to every
            cell first so zero-valued cells can sit on the log scale.
        **kwargs: extra keyword arguments forwarded to ``sns.heatmap``; they
            override the defaults ``annot=True, cmap='Blues', fmt='.0f',
            cbar=False`` (and the log ``norm``).

    Returns:
        The ``(fig, ax)`` pair, as ``basic_chart`` requires.

    Raises:
        ValueError: if *log_transform* is set and *df* contains negative values.
    """
    df = df.copy()
    norm = None
    if log_transform:
        if df.min().min() < 0:
            raise ValueError('log_transform requires non-negative values')
        # Offset keeps zeros positive; with fmt='.0f' it is invisible in the
        # annotations.
        df += 0.01
        # Bounds come from the *shifted* data so no cell falls outside the
        # norm (a vmin of min + 1 would clip the smallest cells).
        norm = LogNorm(vmin=df.min().min(), vmax=df.max().max())

    options = {
        'annot': True,
        'cmap': 'Blues',
        'fmt': '.0f',
        'cbar': False,
        'norm': norm,
        **kwargs,  # caller-supplied options take precedence
    }
    ax = sns.heatmap(df, ax=ax, **options)

    # Post-plot formatting: vertical column labels, horizontal row labels.
    ax.tick_params(axis='x', labelrotation=90)
    ax.tick_params(axis='y', labelrotation=0)
    return fig, ax
heatmap(df=df, figsize=(3, 3))  # render the example frame as a small 3x3-inch heatmap
comments powered by Disqus