from pprint import pprint
from typing import Tuple, List
from IPython.display import display
import numpy as np
import pandas as pd
import plotly as px
import plotly.graph_objects as go
Package Version
--------- ---------
python 3.9.12
pandas 1.4.2
plotly 5.6.0
Basic diagram from dataframe
Generate dummy data
Let’s suppose we have a dataset in which each row is a unique path across potential values within any number of “levels” or stages.
This is a reasonably easy format to reach with an SQL query using GROUP BY 1, 2, 3, etc.
from itertools import product
# Candidate values for each "level"; the dict's key order is the column order.
states = {
    'Country': ['Canada', 'USA', 'Mexico'],
    'Occupation': ['Doctor', 'Lawyer', 'Banker'],
    'Life satisfaction': ['Happy', 'Meh', 'Unhappy'],
}

# One row per combination of a single value from every level.
rows = [*product(*states.values())]

# A random integer weight in [0, 100) for each row.
vals = np.random.randint(100, size=len(rows))

df_unique_paths = pd.DataFrame(rows, columns=states.keys())
df_unique_paths = df_unique_paths.assign(n=vals)

# Limit the number of rows shown when the frame is displayed.
with pd.option_context('display.max_rows', 10):
    display(df_unique_paths)
Country | Occupation | Life satisfaction | n | |
---|---|---|---|---|
0 | Canada | Doctor | Happy | 19 |
1 | Canada | Doctor | Meh | 73 |
2 | Canada | Doctor | Unhappy | 47 |
3 | Canada | Lawyer | Happy | 16 |
4 | Canada | Lawyer | Meh | 58 |
... | ... | ... | ... | ... |
22 | Mexico | Lawyer | Meh | 39 |
23 | Mexico | Lawyer | Unhappy | 95 |
24 | Mexico | Banker | Happy | 7 |
25 | Mexico | Banker | Meh | 83 |
26 | Mexico | Banker | Unhappy | 32 |
27 rows × 4 columns
Generate edge list
Now let’s translate the above dummy dataframe into a set of weighted edges to/from nodes.
def edges_from_unique_paths(df: pd.DataFrame) -> pd.DataFrame:
    """ Generate a "long" set of edges, from a "wide" set of unique paths.

    Each row is walked left to right across its level columns, and every
    pair of neighbouring values contributes the row's weight to the edge
    (left value, right value). String values starting with an underscore
    are skipped, so their neighbours are linked directly to each other.

    Args:
        df: DataFrame with n-1 columns that represent "levels", plus a
            single numeric count/weight column. The weight column is the
            one named 'n' when present, otherwise the last column.

    Returns:
        A DataFrame with three columns: (source, target, count), one row
        per distinct edge, in order of first appearance.
    """
    from collections import defaultdict

    # Keep supporting the original 'n' convention, but fall back to the
    # documented "last column is the weight" layout.
    weight_col = 'n' if 'n' in df.columns else df.columns[-1]
    level_cols = [col for col in df.columns if col != weight_col]

    edges = defaultdict(int)
    level_rows = df[level_cols].itertuples(index=False, name=None)
    for levels, weight in zip(level_rows, df[weight_col]):
        # Only strings can be "hidden"; other values (e.g. NaN) are kept.
        visible = [
            value for value in levels
            if not (isinstance(value, str) and value.startswith('_'))
        ]
        for a, b in zip(visible, visible[1:]):
            edges[(a, b)] += weight

    return pd.DataFrame(
        [(a, b, n) for (a, b), n in edges.items()],
        columns=['source', 'target', 'count']
    )
# Collapse the wide table of unique paths into a long (source, target, count) edge list.
df_edges = edges_from_unique_paths(df=df_unique_paths)

# Limit the number of rows shown when the frame is displayed.
with pd.option_context('display.max_rows', 10):
    display(df_edges)
source | target | count | |
---|---|---|---|
0 | Canada | Doctor | 139 |
1 | Doctor | Happy | 126 |
2 | Doctor | Meh | 98 |
3 | Doctor | Unhappy | 153 |
4 | Canada | Lawyer | 98 |
... | ... | ... | ... |
13 | USA | Lawyer | 148 |
14 | USA | Banker | 146 |
15 | Mexico | Doctor | 114 |
16 | Mexico | Lawyer | 140 |
17 | Mexico | Banker | 122 |
18 rows × 3 columns
If the structure of our data were less predictable—e.g. if some paths do not “pass” through all columns—it may be desirable to directly specify the edge list.
Generate plotly params
The main plotting function that we’ll use below, go.Sankey(), takes two main arguments:
- A dict of node-related info (labels, colors, etc.)
- A dict of link-related data (to/from, weights)
from functools import reduce
def make_sankey_params_v1(df: pd.DataFrame) -> Tuple[dict, dict]:
    """ Generate parameter dicts for go.Figure plotting function

    Args:
        df: Edge list with exactly three columns, in the order
            (source, target, value).

    Returns:
        A (nodes_dict, links_dict) pair. nodes_dict holds the unique node
        labels; links_dict holds, per edge, the positions of its source
        and target within that label list, plus the edge value.
    """
    # Unpack columns into lists
    sources, targets, values = df.values.T.tolist()

    # Unique labels in order of first appearance (sources first, then
    # targets), so the node order is the same on every run.
    labels = list(dict.fromkeys(sources + targets))

    # Map actual labels to their index value with a constant-time lookup
    label_index = {label: idx for idx, label in enumerate(labels)}
    source_idx = [label_index[label] for label in sources]
    target_idx = [label_index[label] for label in targets]

    # Assemble final outputs into expected format
    nodes_dict = {'label': labels}
    links_dict = {
        'source': source_idx,
        'target': target_idx,
        'value': values
    }
    return nodes_dict, links_dict
# Build the plotly node/link dicts from the edge list, then print both for inspection.
nodes, links = make_sankey_params_v1(df_edges)
pprint(nodes, compact=True)
# Blank line separating the two printed dicts
print('')
pprint(links, compact=True)
{'label': ['Meh', 'Banker', 'Mexico', 'Doctor', 'Canada', 'Lawyer', 'USA',
'Happy', 'Unhappy']}
{'source': [4, 3, 3, 3, 4, 5, 5, 5, 4, 1, 1, 1, 6, 6, 6, 2, 2, 2],
'target': [3, 7, 0, 8, 5, 7, 0, 8, 1, 7, 0, 8, 3, 5, 1, 3, 5, 1],
'value': [139, 126, 98, 153, 98, 96, 106, 184, 126, 156, 96, 142, 124, 148,
146, 114, 140, 122]}
Visualize
# Wrap the node/link dicts in a Sankey trace and put the trace in a figure.
data = go.Sankey(node=nodes, link=links)
fig = go.Figure(data)
# Margins dict that the later figures in this file pass to update_layout as well.
default_margins = {'l': 25, 'r': 25, 't': 50, 'b': 0}
fig.update_layout(title='Basic example sankey diagram', margin=default_margins)
Formatting
Specify node colors
We can also manually set the color of nodes. Suppose we are only interested in doing so for a handful of specific nodes, but would like to let plotly select default colors for the rest.
These are the colors in the qualitative color palette px.colors.qualitative.Plotly, with 80% opacity:
from IPython.core.display import HTML
def hex_to_rgba(h: str, a: float = 0.8) -> str:
    """ Convert a hex color into rgba with a specified opacity.

    Args:
        h: Hex color such as '#636EFA' or '636EFA'. The three-digit
            shorthand ('#fff') is also accepted and expanded to six digits.
        a: Opacity, inserted as-is as the alpha component.

    Returns:
        A string of the form 'rgba(r,g,b,a)' with r, g, b as integers.

    Raises:
        ValueError: If the red, green or blue digits are not valid hex.
    """
    digits = h.strip('#')
    if len(digits) == 3:
        # Shorthand: each digit stands for a doubled pair ('f' -> 'ff').
        digits = ''.join(ch * 2 for ch in digits)
    r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    return f'rgba({r},{g},{b},{a})'
# Palette colors converted to rgba strings (hex_to_rgba's default opacity).
colors_rgba = [hex_to_rgba(x) for x in px.colors.qualitative.Plotly]

# One bold <span> per color, each rendered in the color it names.
# Built with a single join rather than repeated string concatenation.
html_output = ''.join(
    f'<span style="color: {c}; font-weight: bold">{c}</span><br>'
    for c in colors_rgba
)
HTML(html_output)
rgba(99,110,250,0.8)
rgba(239,85,59,0.8)
rgba(0,204,150,0.8)
rgba(171,99,250,0.8)
rgba(255,161,90,0.8)
rgba(25,211,243,0.8)
rgba(255,102,146,0.8)
rgba(182,232,128,0.8)
rgba(255,151,255,0.8)
rgba(254,203,82,0.8)
If our diagram has more nodes than there are colors in the palette, colors will be re-used. This may be confusing for the nodes whose colors we wish to specify manually, so let’s exclude those colors from the palette.
from functools import reduce
from itertools import cycle
nodes, links = make_sankey_params_v1(df_edges)

# Colors that are pinned to specific node labels.
colors_map = {
    'Happy': 'rgba(0,204,150,0.8)',
    'Meh': 'rgba(254,203,82,0.8)',
    'Unhappy': 'rgba(239,85,59,0.8)'
}

# Palette minus the pinned colors, kept in palette order so the result is
# the same on every run (iterating a set of strings is not).
reserved_colors = set(colors_map.values())
remaining_colors = cycle([c for c in colors_rgba if c not in reserved_colors])

# Only draw from the palette for labels that have no pinned color, so no
# palette color is consumed by a node that does not use it.
nodes['color'] = [
    colors_map[x] if x in colors_map else next(remaining_colors)
    for x in nodes['label']
]

data = go.Sankey(node=nodes, link=links)
fig = go.Figure(data)
fig.update_layout(title='Sankey diagram with specified node colours', margin=default_margins)
Misc
To manually resize the plot:
# Turn off autosizing and request a 400 x 400 figure.
fig.update_layout(autosize=False, width=400, height=400)
Specify horizontal position
Scenario
Suppose we are visualizing some flow such as an ecommerce funnel, with drop-offs at each stage.
A v1 might look something like this:
# Edge list for a three-stage funnel; every stage's drop-off goes to one shared 'Bounce' node.
funnel_df = pd.DataFrame([
['Visit', 'Search', 80],
['Visit', 'Bounce', 20],
['Search', 'Click', 60],
['Search', 'Bounce', 20],
['Click', 'Buy', 40],
['Click', 'Bounce', 20]
], columns=['source', 'target', 'count'])
# Build the node/link dicts and plot them.
nodes, links = make_sankey_params_v1(funnel_df)
data = go.Sankey(node=nodes, link=links, arrangement='snap')
fig = go.Figure(data)
fig.update_layout(title='An example ecommerce funnel', margin=default_margins)
Having a single node for “Bounce” at the end accentuates the overall conversion rate, but at the expense of making the per-stage conversion rate unclear. So we might split these nodes up:
# Same funnel, but each stage's drop-off gets its own 'Bounce (from ...)' node.
funnel_df2 = pd.DataFrame([
['Visit', 'Search', 80],
['Visit', 'Bounce (from visit)', 20],
['Search', 'Click', 60],
['Search', 'Bounce (from search)', 20],
['Click', 'Buy', 40],
['Click', 'Bounce (from click)', 20],
], columns=['source', 'target', 'count'])
# Build the node/link dicts and plot them.
nodes, links = make_sankey_params_v1(funnel_df2)
data = go.Sankey(node=nodes, link=links, arrangement='snap')
fig = go.Figure(data)
fig.update_layout(title='Separate "Bounce" nodes looks funny', margin=default_margins)
But this looks even worse, because by default, plotly puts all terminal nodes at the end.
Invisible nodes
One (somewhat clunky) approach is to use dummy nodes to enforce positioning, then set them as invisible. Inspiration:
- Plotly Sankey: how to avoid autoplacing all nodes to the right? → workaround to make invisible dummy nodes.
- Hide plotly sankey nodes and links while preserving total node value → nice list comprehension to set node and link colors for “invisible” dummy nodes.
from functools import reduce
def make_sankey_params_invisible_nodes(df: pd.DataFrame) -> Tuple[dict, dict]:
    """ Generate parameter dicts for go.Figure plotting function

    Nodes whose label starts with an underscore are treated as invisible:
    they get an empty display label and a fully transparent color, and any
    link touching one of them is fully transparent too.

    Args:
        df: Edge list with exactly three columns, in the order
            (source, target, value), where source and target are strings.

    Returns:
        A (nodes_dict, links_dict) pair with labels, colors and (for
        links) source/target positions and values.
    """
    transparent = 'rgba(0,0,0,0)'
    palette = px.colors.qualitative.Plotly

    def is_hidden(label: str) -> bool:
        return label.startswith('_')

    # Unpack columns into lists
    sources, targets, values = df.values.T.tolist()

    # Unique labels in order of first appearance (sources first, then
    # targets), so node order and node colors are the same on every run.
    labels = list(dict.fromkeys(sources + targets))

    # Map actual labels to their index value with a constant-time lookup
    label_index = {label: idx for idx, label in enumerate(labels)}
    source_idx = [label_index[label] for label in sources]
    target_idx = [label_index[label] for label in targets]

    # Nodes: hidden ones lose their label and color; visible ones take the
    # palette color at their position, wrapping around the palette.
    display_labels = ['' if is_hidden(x) else x for x in labels]
    node_colors = [
        transparent if is_hidden(x) else palette[i % len(palette)]
        for i, x in enumerate(labels)
    ]
    nodes_dict = {
        'label': display_labels,
        'color': node_colors,
        'line': {'width': 0}
    }

    # Links: transparent when either end is hidden, translucent grey otherwise.
    link_colors = [
        transparent if is_hidden(s) or is_hidden(t) else 'rgba(64,64,64,0.2)'
        for s, t in zip(sources, targets)
    ]
    links_dict = {
        'source': source_idx,
        'target': target_idx,
        'value': values,
        'color': link_colors
    }
    return nodes_dict, links_dict
# The split-bounce funnel plus extra links into '_'-prefixed dummy nodes.
funnel_df3 = pd.DataFrame([
['Visit', 'Search', 80],
['Visit', 'Bounce (from visit)', 20],
['Search', 'Click', 60],
['Search', 'Bounce (from search)', 20],
['Click', 'Buy', 40],
['Click', 'Bounce (from click)', 20],
# invisible: '_'-prefixed nodes are drawn transparent by make_sankey_params_invisible_nodes
['Bounce (from visit)', '_bv1', 20],
['_bv1', '_bv2', 20],
['Bounce (from search)', '_bs1', 20]
], columns=['source', 'target', 'count'])
# Build the node/link dicts and plot them.
nodes, links = make_sankey_params_invisible_nodes(funnel_df3)
data = go.Sankey(node=nodes, link=links, arrangement='snap')
fig = go.Figure(data)
fig.update_layout(title='Horizontal positioning via hidden nodes/links', margin=default_margins)
Explicit x/y coordinates
Inspiration: Plotly: How to set node positions in a Sankey Diagram?
from functools import reduce
def calc_node_positions(
    nodes: list,
    node_groups: list,
    final_group_extra_gap: float = 0.05,
    y_sep: float = 0.3
) -> Tuple[list, list]:
    """ Calculate x and y coords for a list of nodes, given a grouping.

    X coords force plotly to place nodes that belong to the same conceptual "level" together.
    Y positions "encourage" (not force) nodes into a vertical order.

    Args:
        nodes:
            Flat list of node labels. Not used by the calculation; kept so
            existing callers do not need to change.
        node_groups:
            List of groups (lists of node labels), one group per level,
            ordered left to right. Empty groups contribute no positions.
        final_group_extra_gap:
            Amount of additional x-axis (0-1) to allocate to final grouping,
            to account for the fact that labels on this group appear on left,
            and may overlap with labels from the penultimate group.
        y_sep:
            Amount of "encouragement" for vertical ordering.
            Best results ~0.1, layout gets weird if higher.
            Gets divided by len(group) so that smaller values used in groups with many nodes.

    Returns:
        (x_pos, y_pos): two flat lists with one entry per node, in the
        order the nodes appear when node_groups is flattened.
    """
    last_idx = len(node_groups) - 1
    # With a single group there is no gap to distribute, so every node
    # sits at x = 0 and no extra gap is added.
    has_gaps = last_idx > 0
    normal_gap = (1 - final_group_extra_gap) / last_idx if has_gaps else 0.0

    x_pos = []
    y_pos = []
    for group_idx, group in enumerate(node_groups):
        size = len(group)
        if size == 0:
            # Nothing to place; also avoids dividing y_sep by zero.
            continue

        x = group_idx * normal_gap
        if has_gaps and group_idx == last_idx:
            x += final_group_extra_gap
        x_pos.extend([round(x, 2)] * size)

        # Evenly spaced y values: y_sep/size, 2*y_sep/size, ..., y_sep.
        steps = np.full(shape=size, fill_value=(y_sep / size))
        y_pos.extend(np.cumsum(steps).round(4).tolist())

    return x_pos, y_pos
def make_sankey_params_v3(df: pd.DataFrame, node_groups: list = None) -> Tuple[dict, dict]:
    """ Generate parameter dicts for go.Figure plotting function

    Args:
        df: Edge list with exactly three columns, in the order
            (source, target, value), with the first two named
            'source' and 'target'.
        node_groups: List of groups (lists of node labels), one per level.
            Required despite the None default; together the groups must
            contain exactly the labels found in the source/target columns.

    Returns:
        A (nodes_dict, links_dict) pair. nodes_dict holds the labels in
        flattened node_groups order plus their x/y positions; links_dict
        holds source/target positions and values.

    Raises:
        TypeError: If node_groups is None.
        AssertionError: If node_groups and the observed labels differ.
    """
    if node_groups is None:
        raise TypeError('node_groups is required: pass one list of node labels per level')

    # Unpack columns into lists
    sources, targets, values = df.values.T.tolist()

    # Flatten into a new list, so the caller's group lists are never aliased.
    expected_labels = [label for group in node_groups for label in group]
    observed_labels = df['source'].pipe(set) | df['target'].pipe(set)
    assert observed_labels == set(expected_labels), \
        'Mismatch between node_groups and unique values in source/target columns'

    # Map actual labels to their index value; setdefault keeps the first
    # position if a label is listed more than once.
    label_index = {}
    for idx, label in enumerate(expected_labels):
        label_index.setdefault(label, idx)
    source_idx = [label_index[label] for label in sources]
    target_idx = [label_index[label] for label in targets]

    # Nodes output
    nodes_dict = {'label': expected_labels}
    nodes_dict['x'], nodes_dict['y'] = calc_node_positions(expected_labels, node_groups)

    # Links output
    links_dict = {
        'source': source_idx,
        'target': target_idx,
        'value': values
    }
    return nodes_dict, links_dict
# One inner list per level, left to right; together they list every node label in funnel_df2.
node_groups = [
['Visit'],
['Search', 'Bounce (from visit)'],
['Click', 'Bounce (from search)'],
['Buy', 'Bounce (from click)']
]
# Build node/link dicts with explicit x/y positions and plot them.
nodes, links = make_sankey_params_v3(funnel_df2, node_groups)
data = go.Sankey(node=nodes, link=links, arrangement='snap')
fig = go.Figure(data)
fig.update_layout(title='Horizontal positioning via manual coordinates', margin=default_margins)
Hover tooltips
As you interact with the diagram, you may catch yourself doing mental arithmetic in order to calculate percentages. To alleviate this, we can add tooltips that show percentage flows:
- Nodes – percentage of total inputs/outputs
- Links – percentage of source/target node
See also: Hovertemplate and customdata of Sankey diagrams (Plotly docs)
def calc_node_tooltips(df: pd.DataFrame, node_groups: list) -> Tuple[List[Tuple[str, str]], str]:
    """ Generate data/template for tooltips on a Plotly sankey diagram.

    Expects input df with three columns: (source, target, count)

    Each node's flow is the sum of its incoming counts; nodes with no
    incoming links use the sum of their outgoing counts instead, so they
    get a real percentage rather than 'nan%'.

    Final template:
        {{ label }}
        x% of Total (in)
        x% of Total (out)

    Args:
        df: Edge list with columns (source, target, count).
        node_groups: List of groups (lists of node labels), one per level.
            The first group defines the total input (sum of counts leaving
            it); the last group defines the total output (sum of counts
            entering it).

    Returns:
        (customdata, hovertemplate): customdata has one
        (pct_of_total_in, pct_of_total_out) pair of formatted strings per
        node, in flattened node_groups order. A percentage that cannot be
        computed (node absent from df, or a zero total) is 'n/a'.
    """
    sorted_nodes = [node for group in node_groups for node in group]

    # calculate total flows through each node
    inflow = df.groupby('target')['count'].sum().reindex(sorted_nodes).tolist()
    outflow = df.groupby('source')['count'].sum().reindex(sorted_nodes).tolist()
    node_totals = [
        flow_out if pd.isna(flow_in) else flow_in
        for flow_in, flow_out in zip(inflow, outflow)
    ]

    def as_pct_of(total) -> List[str]:
        return [
            f'{value / total:.0%}' if total and pd.notna(value) else 'n/a'
            for value in node_totals
        ]

    # calculate % of inputs
    total_input = df.loc[df['source'].isin(node_groups[0]), 'count'].sum()
    pct_of_total_in = as_pct_of(total_input)

    # calculate % of outputs
    total_output = df.loc[df['target'].isin(node_groups[-1]), 'count'].sum()
    pct_of_total_out = as_pct_of(total_output)

    customdata = list(zip(pct_of_total_in, pct_of_total_out))
    hovertemplate = '%{label}<br>' + \
        ' %{customdata[0]} of total inputs<br>' + \
        ' %{customdata[1]} of total outputs'
    return customdata, hovertemplate
def calc_link_tooltips(df: pd.DataFrame, node_groups: list) -> Tuple[List[Tuple[str, str]], str]:
    """ Generate data/template for tooltips on a Plotly sankey diagram.

    Expects input df with three columns: (source, target, count)

    Final template:
        {{ source }} → {{ target }}
        x% of {{ source }}
        x% of {{ target }}

    Args:
        df: Edge list with columns (source, target, count).
        node_groups: Not used by the calculation; kept so the signature
            matches calc_node_tooltips.

    Returns:
        (customdata, hovertemplate): customdata has one
        (pct_of_source, pct_of_target) pair of formatted strings per link,
        in the row order of df. Each percentage is the link's count over
        the summed count of all links sharing its source (or target).
    """
    counts = df['count']

    def share_within(column: str) -> List[str]:
        # transform('sum') returns the group total aligned to each row,
        # so no index realignment on duplicated labels is needed.
        group_total = counts.groupby(df[column]).transform('sum')
        return (counts / group_total).map('{:.0%}'.format).tolist()

    pct_of_source = share_within('source')
    pct_of_target = share_within('target')

    customdata = list(zip(pct_of_source, pct_of_target))
    hovertemplate = '%{source.label} → %{target.label}<br>' + \
        ' %{customdata[0]} of %{source.label}<br>' + \
        ' %{customdata[1]} of %{target.label}'
    return customdata, hovertemplate
def make_sankey_params_with_tooltips(df: pd.DataFrame, node_groups: list = None) -> Tuple[dict, dict]:
    """ Generate parameter dicts for go.Figure plotting function

    Args:
        df: Edge list with exactly three columns, in the order
            (source, target, count), named 'source', 'target', 'count'.
        node_groups: List of groups (lists of node labels), one per level.
            Required despite the None default; together the groups must
            contain exactly the labels found in the source/target columns.

    Returns:
        A (nodes_dict, links_dict) pair including x/y node positions and,
        for both nodes and links, 'customdata' and 'hovertemplate' entries.

    Raises:
        TypeError: If node_groups is None.
        AssertionError: If node_groups and the observed labels differ.
    """
    if node_groups is None:
        raise TypeError('node_groups is required: pass one list of node labels per level')

    # Create list of unique labels across node columns (source, target).
    # Flatten into a new list so the caller's group lists are never aliased.
    expected_labels = [label for group in node_groups for label in group]
    observed_labels_set = df['source'].pipe(set) | df['target'].pipe(set)
    expected_labels_set = set(expected_labels)
    assert observed_labels_set == expected_labels_set, \
        'Mismatch between node_groups and unique values in source/target columns\n' \
        f'\tMissing from node_groups: {observed_labels_set-expected_labels_set}\n' \
        f'\tMissing from df: {expected_labels_set-observed_labels_set}'

    # Nodes
    nodes_dict = {'label': expected_labels}
    nodes_dict['x'], nodes_dict['y'] = calc_node_positions(expected_labels, node_groups)

    # Links: setdefault keeps the first position if a label is listed twice.
    label_index = {}
    for idx, label in enumerate(expected_labels):
        label_index.setdefault(label, idx)
    sources, targets, values = df.values.T.tolist()
    source_idx = [label_index[label] for label in sources]
    target_idx = [label_index[label] for label in targets]
    links_dict = {'source': source_idx, 'target': target_idx, 'value': values}

    # Tooltips
    nodes_dict['customdata'], nodes_dict['hovertemplate'] = calc_node_tooltips(df, node_groups)
    links_dict['customdata'], links_dict['hovertemplate'] = calc_link_tooltips(df, node_groups)
    return nodes_dict, links_dict
# One inner list per level, left to right; together they list every node label in funnel_df2.
node_groups = [
['Visit'],
['Search', 'Bounce (from visit)'],
['Click', 'Bounce (from search)'],
['Buy', 'Bounce (from click)']
]
# Build node/link dicts (positions plus hover tooltips) and plot them.
nodes, links = make_sankey_params_with_tooltips(funnel_df2, node_groups)
data = go.Sankey(node=nodes, link=links, arrangement='snap')
fig = go.Figure(data)
fig.update_layout(title='Hover tooltips (for both nodes and links)', margin=default_margins)
Final version
Everything all together now:
def make_sankey_params(
    df: pd.DataFrame,
    node_groups: list = None,
    colors: dict = None
) -> Tuple[dict, dict]:
    """ Generate parameter dicts for go.Figure plotting function

    Args:
        df: Edge list with exactly three columns, in the order
            (source, target, count), named 'source', 'target', 'count'.
        node_groups: List of groups (lists of node labels), one per level.
            Required despite the None default; together the groups must
            contain exactly the labels found in the source/target columns.
        colors: Optional mapping of node label -> color string. When given
            (and non-empty), every node gets a color: mapped labels use
            their mapped color, the rest cycle through the palette colors
            that are not among the mapped colors, in palette order.

    Returns:
        A (nodes_dict, links_dict) pair including x/y node positions,
        optional node colors, and tooltip 'customdata'/'hovertemplate'
        entries for both nodes and links.

    Raises:
        TypeError: If node_groups is None.
        AssertionError: If node_groups and the observed labels differ.
    """
    if node_groups is None:
        raise TypeError('node_groups is required: pass one list of node labels per level')

    # Create list of unique labels across node columns (source, target).
    # Flatten into a new list so the caller's group lists are never aliased.
    expected_labels = [label for group in node_groups for label in group]
    observed_labels_set = df['source'].pipe(set) | df['target'].pipe(set)
    expected_labels_set = set(expected_labels)
    assert observed_labels_set == expected_labels_set, \
        'Mismatch between node_groups and unique values in source/target columns\n' \
        f'\tMissing from node_groups: {observed_labels_set-expected_labels_set}\n' \
        f'\tMissing from df: {expected_labels_set-observed_labels_set}'

    # Nodes
    nodes_dict = {'label': expected_labels}
    nodes_dict['x'], nodes_dict['y'] = calc_node_positions(expected_labels, node_groups)
    if colors:
        palette = [hex_to_rgba(x) for x in px.colors.qualitative.Plotly]
        reserved = set(colors.values())
        # Keep palette order (a set would reorder between runs); if every
        # palette color is reserved, fall back to the whole palette rather
        # than cycling over nothing.
        spare = [c for c in palette if c not in reserved] or palette
        remaining_colors = cycle(spare)
        # Only advance the cycle for labels without a mapped color.
        nodes_dict['color'] = [
            colors[x] if x in colors else next(remaining_colors)
            for x in expected_labels
        ]

    # Links: setdefault keeps the first position if a label is listed twice.
    label_index = {}
    for idx, label in enumerate(expected_labels):
        label_index.setdefault(label, idx)
    sources, targets, values = df.values.T.tolist()
    source_idx = [label_index[label] for label in sources]
    target_idx = [label_index[label] for label in targets]
    links_dict = {'source': source_idx, 'target': target_idx, 'value': values}

    # Tooltips
    nodes_dict['customdata'], nodes_dict['hovertemplate'] = calc_node_tooltips(df, node_groups)
    links_dict['customdata'], links_dict['hovertemplate'] = calc_link_tooltips(df, node_groups)
    return nodes_dict, links_dict
# One inner list per level, left to right; together they list every node label in funnel_df2.
node_groups = [
['Visit'],
['Search', 'Bounce (from visit)'],
['Click', 'Bounce (from search)'],
['Buy', 'Bounce (from click)']
]
# All three bounce nodes share one color; other nodes are colored from the palette.
colors_map = {
'Bounce (from visit)': 'rgba(239,85,59,0.8)',
'Bounce (from search)': 'rgba(239,85,59,0.8)',
'Bounce (from click)': 'rgba(239,85,59,0.8)'
}
# Build node/link dicts (positions, colors, tooltips) and plot them.
nodes, links = make_sankey_params(funnel_df2, node_groups, colors=colors_map)
data = go.Sankey(node=nodes, link=links, arrangement='snap')
fig = go.Figure(data)
fig.update_layout(title='Sankey w/ hover tooltips, specified colours and x/y positions', margin=default_margins)
Interactive widgets
Suppose our data is segmented by some sort of dimension such as country or platform. We’d probably visualize the combined overview first, but would then like to drill down into particular dimensions.
Naively we could loop over the previous logic to generate separate flowcharts. But plotly diagrams are interactive, so we can do better—using a drop-down filter.
First let’s create an example segmented dataset:
# Segmented example: funnel_df2 once per platform, with every mobile count raised by 10.
# NOTE(review): this rebinds funnel_df3, which earlier in the file held the invisible-node edge list.
funnel_df3 = pd.concat([
funnel_df2.assign(platform='desktop'),
funnel_df2.assign(platform='mobile').assign(count=lambda x: x['count'] + 10)
]).reset_index(drop=True)
# Bare expression so the notebook displays the frame.
funnel_df3
source | target | count | platform | |
---|---|---|---|---|
0 | Visit | Search | 80 | desktop |
1 | Visit | Bounce (from visit) | 20 | desktop |
2 | Search | Click | 60 | desktop |
3 | Search | Bounce (from search) | 20 | desktop |
4 | Click | Buy | 40 | desktop |
5 | Click | Bounce (from click) | 20 | desktop |
6 | Visit | Search | 90 | mobile |
7 | Visit | Bounce (from visit) | 30 | mobile |
8 | Search | Click | 70 | mobile |
9 | Search | Bounce (from search) | 30 | mobile |
10 | Click | Buy | 50 | mobile |
11 | Click | Bounce (from click) | 30 | mobile |
# Combined view: sum the counts across platforms for each (source, target) pair.
filtered_default = funnel_df3.groupby(['source', 'target'])['count'].sum().reset_index()
# Bare expression so the notebook displays the frame.
filtered_default
source | target | count | |
---|---|---|---|
0 | Click | Bounce (from click) | 50 |
1 | Click | Buy | 90 |
2 | Search | Bounce (from search) | 50 |
3 | Search | Click | 130 |
4 | Visit | Bounce (from visit) | 50 |
5 | Visit | Search | 170 |
def assemble_update_menus(label: str, nodes: dict, links: dict) -> dict:
    """ Build one drop-down button entry that swaps in a sankey dataset.

    Args:
        label: Text of the button; also used as the layout title.
        nodes: Node parameter dict, stored by reference as the trace's 'node'.
        links: Link parameter dict, stored by reference as the trace's 'link'.

    Returns:
        A dict with keys 'method' ('animate'), 'label' and 'args', where
        'args' is a one-element list holding the trace data and the layout.
    """
    # The sankey trace that should be shown when the button is chosen.
    sankey_trace = {'type': 'sankey', 'arrangement': 'snap'}
    sankey_trace['node'] = nodes
    sankey_trace['link'] = links

    # Trace plus the title that goes with it.
    view = {'data': [sankey_trace], 'layout': {'title': label}}

    return {'method': 'animate', 'label': label, 'args': [view]}
# Default view: the combined dataset, labelled 'All'.
nodes, links = make_sankey_params(filtered_default, node_groups, colors=colors_map)
default_args = assemble_update_menus('All', nodes, links)

# The figure starts out showing the trace stored inside the default button entry.
default_trace = default_args['args'][0]['data'][0]
fig = go.Figure(default_trace)

# One further button entry per platform, after the default one.
filtered_datasets = [default_args]
for group_name, grouped_df in funnel_df3.groupby('platform'):
    # The params function expects only the three edge-list columns.
    sub_df = grouped_df.drop('platform', axis=1)
    nodes, links = make_sankey_params(sub_df, node_groups, colors=colors_map)
    button = assemble_update_menus(group_name, nodes, links)
    filtered_datasets.append(button)

# Attach the buttons to the figure as a single menu.
update_menus = [{'buttons': filtered_datasets, 'x': 0.7, 'y': 1.17}]
fig.update_layout(
    updatemenus=update_menus,
    title='Sankey diagram with interactive filters',
    margin=default_margins,
)
Todo
Coloured links
# Sketch for per-link colors.
# NOTE(review): this fragment reads sources, targets, source_idx, target_idx
# and values, which are only defined inside the make_sankey_params* functions;
# presumably it is meant to be placed inside one of them.
highlight_color = 'rgba(180,21,21,0.4)'
neutral_color = 'rgba(113, 113, 113, 0.4)'

# Highlight the A -> B link and every link into C; all other links are neutral.
link_colors = [
    highlight_color
    if (source, target) == ('A', 'B') or target == 'C'
    else neutral_color
    for source, target in zip(sources, targets)
]
links_dict = {'source': source_idx, 'target': target_idx, 'value': values, 'color': link_colors}
comments powered by Disqus