Geoff Ruddock

Interactive sankey diagrams (with Plotly)

from pprint import pprint
from typing import Tuple, List
from IPython.display import display

import numpy as np
import pandas as pd
import plotly as px
import plotly.graph_objects as go
Package    Version
---------  ---------
python     3.9.12
pandas     1.4.2
plotly     5.6.0

Basic diagram from dataframe

Generate dummy data

Let’s suppose we have a dataset in which each row is a unique path across potential values within any number of “levels” or stages.

This is a reasonably easy format to reach with an SQL query using GROUP BY 1, 2, 3, etc.

from itertools import product

# Candidate values for each "level" (column) of a path, in level order.
states = {
    'Country': ['Canada', 'USA', 'Mexico'],
    'Occupation': ['Doctor', 'Lawyer', 'Banker'],
    'Life satisfaction': ['Happy', 'Meh', 'Unhappy'],
}

# Cartesian product of all levels: one row per unique path.
rows = [*product(*states.values())]
# Random integer weight in [0, 100) for every path.
vals = np.random.randint(100, size=len(rows))

df_unique_paths = pd.DataFrame(rows, columns=list(states))
df_unique_paths['n'] = vals

with pd.option_context('display.max_rows', 10):
    display(df_unique_paths)

Country Occupation Life satisfaction n
0 Canada Doctor Happy 19
1 Canada Doctor Meh 73
2 Canada Doctor Unhappy 47
3 Canada Lawyer Happy 16
4 Canada Lawyer Meh 58
... ... ... ... ...
22 Mexico Lawyer Meh 39
23 Mexico Lawyer Unhappy 95
24 Mexico Banker Happy 7
25 Mexico Banker Meh 83
26 Mexico Banker Unhappy 32

27 rows × 4 columns

Generate edge list

Now let’s translate the above dummy dataframe into a set of weighted edges to/from nodes.

def edges_from_unique_paths(df: pd.DataFrame) -> pd.DataFrame:
    """ Generate a "long" set of edges, from a "wide" set of unique paths.

    Any level value that is a string starting with an underscore is treated
    as hidden: it is dropped from its path and its neighbours are joined
    directly instead. Weights of identical (source, target) pairs are summed.

    Args:
        df: DataFrame with n-1 columns that represent "levels",
            followed by a single numeric count/weight column. The weight
            column is always the LAST column, whatever its name.

    Returns:
        A DataFrame with three columns: (source, target, count)

    """

    from collections import defaultdict

    out_columns = ['source', 'target', 'count']
    if df.shape[1] < 2:
        # Need at least one level column plus the weight column.
        return pd.DataFrame([], columns=out_columns)

    level_cols = list(df.columns[:-1])
    weight_col = df.columns[-1]

    def is_hidden(value) -> bool:
        # Non-string level values (e.g. numbers) are never hidden.
        return isinstance(value, str) and value.startswith('_')

    edges = defaultdict(int)
    level_rows = df[level_cols].itertuples(index=False, name=None)
    for levels, weight in zip(level_rows, df[weight_col]):
        visible = [v for v in levels if not is_hidden(v)]
        # Consecutive visible levels form an edge.
        for a, b in zip(visible[:-1], visible[1:]):
            edges[(a, b)] += weight

    return pd.DataFrame(
        [(a, b, n) for (a, b), n in edges.items()],
        columns=out_columns
    )

    
# Collapse the wide path table into a weighted edge list.
df_edges = edges_from_unique_paths(df=df_unique_paths)

# Limit the printed table to 10 rows (head and tail).
with pd.option_context('display.max_rows', 10):
    display(df_edges)

source target count
0 Canada Doctor 139
1 Doctor Happy 126
2 Doctor Meh 98
3 Doctor Unhappy 153
4 Canada Lawyer 98
... ... ... ...
13 USA Lawyer 148
14 USA Banker 146
15 Mexico Doctor 114
16 Mexico Lawyer 140
17 Mexico Banker 122

18 rows × 3 columns

If the structure of our data were less predictable—e.g. if some paths do not “pass” through all columns—it may be desirable to directly specify the edge list.

Generate plotly params

The trace constructor we’ll use below, go.Sankey(), takes two main arguments:

  1. A dict of node-related info (labels, colors, etc.)
  2. A dict of link-related data (to/from, weights)
from functools import reduce
    
def make_sankey_params_v1(df: pd.DataFrame) -> Tuple[dict, dict]:
    """ Generate (nodes, links) parameter dicts for go.Sankey.

    Args:
        df: Edge list whose first three columns are (source, target, value),
            unpacked positionally.

    Returns:
        A (nodes_dict, links_dict) tuple. nodes_dict holds the 'label' list;
        links_dict holds integer 'source'/'target' indices into that list
        plus the link 'value' weights.
    """

    # Unpack columns into lists (positional: source, target, value)
    sources, targets, values = df.values.T.tolist()

    # Unique labels in order of first appearance (sources before targets).
    # Iterating a set would follow string-hash order, which changes between
    # Python sessions and reshuffles node order (and default colours) each run.
    labels = list(dict.fromkeys(sources + targets))

    # Build the label -> index map once instead of a list.index() scan per edge
    label_to_idx = {label: i for i, label in enumerate(labels)}
    source_idx = [label_to_idx[s] for s in sources]
    target_idx = [label_to_idx[t] for t in targets]

    # Assemble final outputs into expected format
    nodes_dict = {'label': labels}
    links_dict = {
        'source': source_idx,
        'target': target_idx,
        'value': values
    }

    return nodes_dict, links_dict


nodes, links = make_sankey_params_v1(df=df_edges)

# Inspect the generated parameters: node labels first, then link indices/weights.
pprint(nodes, compact=True)
print('')
pprint(links, compact=True)
{'label': ['Meh', 'Banker', 'Mexico', 'Doctor', 'Canada', 'Lawyer', 'USA',
           'Happy', 'Unhappy']}

{'source': [4, 3, 3, 3, 4, 5, 5, 5, 4, 1, 1, 1, 6, 6, 6, 2, 2, 2],
 'target': [3, 7, 0, 8, 5, 7, 0, 8, 1, 7, 0, 8, 3, 5, 1, 3, 5, 1],
 'value': [139, 126, 98, 153, 98, 96, 106, 184, 126, 156, 96, 142, 124, 148,
           146, 114, 140, 122]}

Visualize

# Margins reused by every figure below (bottom margin set to 0).
default_margins = dict(l=25, r=25, t=50, b=0)

data = go.Sankey(link=links, node=nodes)
fig = go.Figure(data=data)
fig.update_layout(margin=default_margins, title='Basic example sankey diagram')

Formatting

Specify node colors

We can also manually set the color of nodes. Suppose we are only interested in doing so for a handful of specific nodes, but would like to let plotly select default colors for the rest.

These are the colors of Plotly’s default qualitative palette (px.colors.qualitative.Plotly), converted to rgba with 80% opacity:

from IPython.core.display import HTML

def hex_to_rgba(h: str, a: float = 0.8) -> str:
    """ Convert a hex color into rgba with a specified opacity.

    Accepts 6-digit ('#636efa') and 3-digit shorthand ('#abc') colours, with
    or without the '#'. For longer strings, only the first 6 hex digits are
    used (as before).

    Args:
        h: Hex colour string.
        a: Alpha (opacity) to embed in the result.

    Returns:
        A string such as 'rgba(99,110,250,0.8)'.

    Raises:
        ValueError: if `h` has fewer than 6 hex digits and is not 3-digit
            shorthand, or contains non-hex characters.
    """
    digits = h.strip().strip('#')
    if len(digits) == 3:
        # Shorthand: each digit is doubled, e.g. 'abc' -> 'aabbcc'
        digits = ''.join(ch * 2 for ch in digits)
    if len(digits) < 6:
        raise ValueError(f'Expected a 3- or 6-digit hex colour, got {h!r}')
    r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    return f'rgba({r},{g},{b},{a})'
    
# Palette converted to rgba (alpha 0.8), each shown as a bold line in its own colour.
colors_rgba = [hex_to_rgba(x) for x in px.colors.qualitative.Plotly]
html_output = ''.join(
    f'<span style="color: {c}; font-weight: bold">{c}</span><br>'
    for c in colors_rgba
)

HTML(html_output)

rgba(99,110,250,0.8)
rgba(239,85,59,0.8)
rgba(0,204,150,0.8)
rgba(171,99,250,0.8)
rgba(255,161,90,0.8)
rgba(25,211,243,0.8)
rgba(255,102,146,0.8)
rgba(182,232,128,0.8)
rgba(255,151,255,0.8)
rgba(254,203,82,0.8)

If our diagram has more nodes than colors in the palette, they will be re-used. This may be confusing for the nodes we wish to manually specify. So let’s exclude those colors from the palette.

from functools import reduce
from itertools import cycle

nodes, links = make_sankey_params_v1(df_edges)

colors_map = {
    'Happy': 'rgba(0,204,150,0.8)',
    'Meh': 'rgba(254,203,82,0.8)',
    'Unhappy': 'rgba(239,85,59,0.8)'
}

# Palette minus the manually assigned colours, keeping the palette's own order.
# (Iterating a set difference would follow string-hash order, which changes
# between sessions and reshuffles the default colours.) Fall back to the full
# palette if every colour is reserved, since cycle([]) would raise StopIteration.
reserved_colors = set(colors_map.values())
remaining_colors = cycle([c for c in colors_rgba if c not in reserved_colors] or colors_rgba)

# Only draw from the palette for nodes without a manual colour; passing
# next(...) as a .get() default would consume a colour for every node.
nodes['color'] = [
    colors_map[x] if x in colors_map else next(remaining_colors)
    for x in nodes['label']
]

data = go.Sankey(node=nodes, link=links)
fig = go.Figure(data)
fig.update_layout(title='Sankey diagram with specified node colours', margin=default_margins)

Misc

To manually resize the plot:

fig.update_layout(width=400, height=400, autosize=False)  # fixed 400x400 px figure

Specify horizontal position

Scenario

Suppose we are visualizing some flow such as an ecommerce funnel, with drop-offs at each stage.

A v1 might look something like this:

# Funnel with a single shared "Bounce" sink for every stage.
funnel_df = pd.DataFrame(
    [
        ('Visit', 'Search', 80),
        ('Visit', 'Bounce', 20),
        ('Search', 'Click', 60),
        ('Search', 'Bounce', 20),
        ('Click', 'Buy', 40),
        ('Click', 'Bounce', 20),
    ],
    columns=['source', 'target', 'count'],
)

nodes, links = make_sankey_params_v1(df=funnel_df)

data = go.Sankey(arrangement='snap', node=nodes, link=links)
fig = go.Figure(data=data)
fig.update_layout(margin=default_margins, title='An example ecommerce funnel')

Having a single node for “Bounce” at the end accentuates the overall conversion rate, but at the expense of making the per-stage conversion rate unclear. So we might split these nodes up:

# Same funnel, but each stage now drops off into its own "Bounce" node.
funnel_df2 = pd.DataFrame(
    [
        ('Visit', 'Search', 80),
        ('Visit', 'Bounce (from visit)', 20),
        ('Search', 'Click', 60),
        ('Search', 'Bounce (from search)', 20),
        ('Click', 'Buy', 40),
        ('Click', 'Bounce (from click)', 20),
    ],
    columns=['source', 'target', 'count'],
)

nodes, links = make_sankey_params_v1(df=funnel_df2)

data = go.Sankey(arrangement='snap', node=nodes, link=links)
fig = go.Figure(data=data)

fig.update_layout(margin=default_margins, title='Separate "Bounce" nodes looks funny')

But this looks even worse, because by default, plotly puts all terminal nodes at the end.

Invisible nodes

One (somewhat clunky) approach is to use dummy nodes to enforce positioning, then set them as invisible. Inspiration:

from functools import reduce
    
def make_sankey_params_invisible_nodes(df: pd.DataFrame) -> Tuple[dict, dict]:
    """ Generate (nodes, links) parameter dicts for go.Sankey, hiding dummy nodes.

    Any node whose label is a string starting with an underscore is rendered
    invisibly: its label is blanked, its fill is transparent, and every link
    touching it is transparent too.

    Args:
        df: Edge list whose first three columns are (source, target, value),
            unpacked positionally.

    Returns:
        A (nodes_dict, links_dict) tuple ready for go.Sankey(node=..., link=...).
    """
    palette = px.colors.qualitative.Plotly
    transparent = 'rgba(0,0,0,0)'
    visible_link_color = 'rgba(64,64,64,0.2)'

    # Unpack columns into lists (positional: source, target, value)
    sources, targets, values = df.values.T.tolist()

    # Unique labels in first-appearance order. A set would iterate in
    # string-hash order, so node order and the palette colour assigned
    # below would change between sessions.
    labels = list(dict.fromkeys(sources + targets))
    label_to_idx = {label: i for i, label in enumerate(labels)}
    source_idx = [label_to_idx[s] for s in sources]
    target_idx = [label_to_idx[t] for t in targets]

    def is_hidden(label) -> bool:
        # Non-string labels are never hidden (and must not crash on startswith)
        return isinstance(label, str) and label.startswith('_')

    nodes_dict = {
        'label': ['' if is_hidden(x) else x for x in labels],
        # Palette colours are assigned by node position, wrapping around
        'color': [
            transparent if is_hidden(x) else palette[i % len(palette)]
            for i, x in enumerate(labels)
        ],
        'line': {'width': 0}
    }

    links_dict = {
        'source': source_idx,
        'target': target_idx,
        'value': values,
        # A link is hidden if either endpoint is hidden
        'color': [
            transparent if is_hidden(s) or is_hidden(t) else visible_link_color
            for s, t in zip(sources, targets)
        ]
    }

    return nodes_dict, links_dict


funnel_df3 = pd.DataFrame(
    [
        # visible funnel steps
        ('Visit', 'Search', 80),
        ('Visit', 'Bounce (from visit)', 20),
        ('Search', 'Click', 60),
        ('Search', 'Bounce (from search)', 20),
        ('Click', 'Buy', 40),
        ('Click', 'Bounce (from click)', 20),
    ]
    + [
        # invisible spacer links: the early "Bounce" nodes are no longer
        # terminal, so they are not pushed into the final column
        ('Bounce (from visit)', '_bv1', 20),
        ('_bv1', '_bv2', 20),
        ('Bounce (from search)', '_bs1', 20),
    ],
    columns=['source', 'target', 'count'],
)


nodes, links = make_sankey_params_invisible_nodes(df=funnel_df3)

data = go.Sankey(arrangement='snap', node=nodes, link=links)
fig = go.Figure(data=data)
fig.update_layout(margin=default_margins, title='Horizontal positioning via hidden nodes/links')

Explicit x/y coordinates

Inspiration: Plotly: How to set node positions in a Sankey Diagram?

from functools import reduce

def calc_node_positions(
    nodes: list,
    node_groups: list,
    final_group_extra_gap: float = 0.05,
    y_sep=0.3
    ) -> Tuple[list, list]:
    """ Calculate x and y coords for a list of nodes, given a grouping.

    X coords force plotly to place nodes that belong to the same conceptual "level" together.
    Y positions "encourage" (not force) nodes into a vertical order.

    Args:
        nodes:
            Flat list of node labels. Not read here: output order is always
            node_groups flattened, so callers should pass labels in that order.
        node_groups:
            List of lists of labels, one inner list per level, left to right.
        final_group_extra_gap:
            Amount of additional x-axis (0-1) to allocate to final grouping,
            to account for the fact that labels on this group appear on left,
            and may overlap with labels from the penultimate group.
        y_sep:
            Amount of "encouragement" for vertical ordering.
            Best results ~0.1, layout gets weird if higher.
            Gets divided by len(group) so that smaller values used in groups with many nodes.

    Returns:
        (x_pos, y_pos): one coordinate per node, ordered as node_groups
        flattened. x is rounded to 2 decimals, y to 4.
    """
    n_groups = len(node_groups)
    # A single group has no horizontal spread: every node sits at x=0
    # (previously this divided by zero).
    normal_gap = (1 - final_group_extra_gap) / (n_groups - 1) if n_groups > 1 else 0.0

    x_pos = []
    y_pos = []
    for group_idx, group in enumerate(node_groups):
        x = group_idx * normal_gap
        if n_groups > 1 and group_idx == n_groups - 1:
            x += final_group_extra_gap
        x_pos.extend([round(x, 2)] * len(group))

        # Evenly stepped y values within the group; skip empty groups,
        # which would otherwise divide by zero.
        if group:
            step = y_sep / len(group)
            y_pos.extend(
                np.cumsum(np.full(shape=len(group), fill_value=step)).round(4).tolist()
            )

    return x_pos, y_pos

    
def make_sankey_params_v3(df: pd.DataFrame, node_groups: list = None) -> Tuple[dict, dict]:
    """ Generate (nodes, links) parameter dicts for go.Sankey with explicit node positions.

    Args:
        df: Edge list whose first three columns are (source, target, value)
            (unpacked positionally), with label columns named 'source'/'target'.
        node_groups: List of lists of node labels, one inner list per level
            (left to right). Required despite the None default; every label
            in df must appear in exactly one group.

    Returns:
        A (nodes_dict, links_dict) tuple. nodes_dict holds 'label', 'x' and
        'y' (ordered as node_groups flattened); links_dict holds
        'source'/'target' indices into that order plus the link 'value's.

    Raises:
        ValueError: if node_groups is missing/empty, repeats a label, or does
            not contain exactly the labels found in df.
    """
    if not node_groups:
        raise ValueError('node_groups must be a non-empty list of lists of node labels')

    # Unpack columns into lists (positional: source, target, value)
    sources, targets, values = df.values.T.tolist()

    # Flatten groups into node order (a comprehension avoids reduce(+)'s
    # repeated list copying)
    expected_labels = [label for group in node_groups for label in group]
    expected_labels_set = set(expected_labels)
    # A repeated label would map to one index but receive several positions,
    # misaligning every coordinate after it.
    if len(expected_labels_set) != len(expected_labels):
        raise ValueError('Each label may appear in only one of node_groups')

    observed_labels_set = set(df['source']) | set(df['target'])
    if observed_labels_set != expected_labels_set:
        raise ValueError(
            'Mismatch between node_groups and unique values in source/target columns\n'
            f'\tMissing from node_groups: {observed_labels_set - expected_labels_set}\n'
            f'\tMissing from df: {expected_labels_set - observed_labels_set}'
        )

    # Map actual labels to their index value
    label_to_idx = {label: i for i, label in enumerate(expected_labels)}
    source_idx = [label_to_idx[s] for s in sources]
    target_idx = [label_to_idx[t] for t in targets]

    # Nodes output
    nodes_dict = {'label': expected_labels}
    nodes_dict['x'], nodes_dict['y'] = calc_node_positions(expected_labels, node_groups)

    # Links output
    links_dict = {
        'source': source_idx,
        'target': target_idx,
        'value': values
    }

    return nodes_dict, links_dict


# One inner list per column, left to right.
node_groups = [
    ['Visit'],
    ['Search', 'Bounce (from visit)'],
    ['Click', 'Bounce (from search)'],
    ['Buy', 'Bounce (from click)'],
]

nodes, links = make_sankey_params_v3(df=funnel_df2, node_groups=node_groups)

data = go.Sankey(arrangement='snap', node=nodes, link=links)
fig = go.Figure(data=data)
fig.update_layout(
    title='Horizontal positioning via manual coordinates',
    margin=default_margins,
)

Hover tooltips

As you interact with the diagram, you may catch yourself doing mental arithmetic in order to calculate percentages. To alleviate this, we can add tooltips that show percentage flows:

See also: Hovertemplate and customdata of Sankey diagrams (Plotly docs)

def calc_node_tooltips(df: pd.DataFrame, node_groups: list) -> Tuple[List[Tuple[str, str]], str]:
    """ Generate data/template for tooltips on a Plotly sankey diagram.

    Expects input df with three columns: (source, target, count)

    A node's throughput is its total incoming flow. Nodes with no incoming
    links (e.g. the entry stage) use their total outgoing flow instead;
    previously they rendered as 'nan%' in the "outputs" line.

    Final template:
    {{ label }}
     x% of total inputs
     x% of total outputs

    Args:
        df: Edge list with 'source', 'target' and 'count' columns.
        node_groups: List of lists of labels, one per level. The first group's
            outgoing flow is the "total inputs" denominator; the last group's
            incoming flow is the "total outputs" denominator.

    Returns:
        (customdata, hovertemplate): customdata has one (pct_of_inputs,
        pct_of_outputs) pair of formatted strings per node, ordered as
        node_groups flattened.
    """

    # calculate total flows through each node
    sorted_nodes = [label for group in node_groups for label in group]
    inflow = df.groupby('target')['count'].sum()
    outflow = df.groupby('source')['count'].sum()
    node_totals = inflow.combine_first(outflow).reindex(sorted_nodes)

    total_input = df.set_index('source').loc[node_groups[0], 'count'].sum()
    total_output = df.set_index('target').loc[node_groups[-1], 'count'].sum()

    def fmt_pct(v) -> str:
        # Only a label absent from df can still be NaN here
        return '{:.0%}'.format(v) if pd.notna(v) else 'n/a'

    pct_of_total_in = [fmt_pct(v) for v in node_totals / total_input]
    pct_of_total_out = [fmt_pct(v) for v in node_totals / total_output]

    customdata = list(zip(pct_of_total_in, pct_of_total_out))

    hovertemplate = '%{label}<br>' + \
            ' %{customdata[0]} of total inputs<br>' + \
            ' %{customdata[1]} of total outputs'

    return customdata, hovertemplate


def calc_link_tooltips(df: pd.DataFrame, node_groups: list) -> Tuple[List[Tuple[str, str]], str]:
    """ Generate data/template for tooltips on a Plotly sankey diagram.

    Expects input df with three columns: (source, target, count)

    Final template:
    {{ source }} → {{ target }}
     x% of {{ source }}
     x% of {{ target }}

    Args:
        df: Edge list with 'source', 'target' and 'count' columns.
        node_groups: Not used here (presumably kept so the signature mirrors
            calc_node_tooltips).

    Returns:
        (customdata, hovertemplate): customdata has one (pct_of_source,
        pct_of_target) pair of formatted strings per row of df, in row order.
    """

    counts = df['count']
    # Each link's share of its source's total outflow / its target's total inflow
    share_of_source = counts / df.groupby('source')['count'].transform('sum')
    share_of_target = counts / df.groupby('target')['count'].transform('sum')

    pct_of_source = share_of_source.map('{:.0%}'.format).tolist()
    pct_of_target = share_of_target.map('{:.0%}'.format).tolist()

    customdata = list(zip(pct_of_source, pct_of_target))

    hovertemplate = '%{source.label} → %{target.label}<br>' + \
            ' %{customdata[0]} of %{source.label}<br>' + \
            ' %{customdata[1]} of %{target.label}'

    return customdata, hovertemplate

    
def make_sankey_params_with_tooltips(df: pd.DataFrame, node_groups: list = None) -> Tuple[dict, dict]:
    """ Generate (nodes, links) parameter dicts for go.Sankey, with positions and hover tooltips.

    Args:
        df: Edge list with 'source', 'target' and 'count' columns, in that
            order (the three columns are also unpacked positionally).
        node_groups: List of lists of node labels, one inner list per level
            (left to right). Required despite the None default; every label
            in df must appear in exactly one group.

    Returns:
        A (nodes_dict, links_dict) tuple. Both include 'customdata' and
        'hovertemplate' entries for tooltips; nodes_dict also has 'x'/'y'.

    Raises:
        ValueError: if node_groups is missing/empty, repeats a label, or does
            not contain exactly the labels found in df.
    """
    if not node_groups:
        raise ValueError('node_groups must be a non-empty list of lists of node labels')

    # Flatten groups into node order; validate against the labels in df
    expected_labels = [label for group in node_groups for label in group]
    expected_labels_set = set(expected_labels)
    # A repeated label would map to one index but receive several positions
    if len(expected_labels_set) != len(expected_labels):
        raise ValueError('Each label may appear in only one of node_groups')

    observed_labels_set = set(df['source']) | set(df['target'])
    if observed_labels_set != expected_labels_set:
        raise ValueError(
            'Mismatch between node_groups and unique values in source/target columns\n'
            f'\tMissing from node_groups: {observed_labels_set - expected_labels_set}\n'
            f'\tMissing from df: {expected_labels_set - observed_labels_set}'
        )

    # Nodes
    nodes_dict = {'label': expected_labels}
    nodes_dict['x'], nodes_dict['y'] = calc_node_positions(expected_labels, node_groups)

    # Links
    sources, targets, values = df.values.T.tolist()
    label_to_idx = {label: i for i, label in enumerate(expected_labels)}
    source_idx = [label_to_idx[s] for s in sources]
    target_idx = [label_to_idx[t] for t in targets]
    links_dict = {'source': source_idx, 'target': target_idx, 'value': values}

    # Tooltips
    nodes_dict['customdata'], nodes_dict['hovertemplate'] = calc_node_tooltips(df, node_groups)
    links_dict['customdata'], links_dict['hovertemplate'] = calc_link_tooltips(df, node_groups)

    return nodes_dict, links_dict


# One inner list per column, left to right.
node_groups = [
    ['Visit'],
    ['Search', 'Bounce (from visit)'],
    ['Click', 'Bounce (from search)'],
    ['Buy', 'Bounce (from click)'],
]

nodes, links = make_sankey_params_with_tooltips(df=funnel_df2, node_groups=node_groups)

data = go.Sankey(arrangement='snap', node=nodes, link=links)
fig = go.Figure(data=data)
fig.update_layout(
    title='Hover tooltips (for both nodes and links)',
    margin=default_margins,
)

Final version

Everything all together now:

def make_sankey_params(
    df: pd.DataFrame,
    node_groups: list = None,
    colors: dict = None
) -> Tuple[dict, dict]:
    """ Generate (nodes, links) parameter dicts for go.Sankey: positions, colours and tooltips.

    Args:
        df: Edge list with 'source', 'target' and 'count' columns, in that
            order (the three columns are also unpacked positionally).
        node_groups: List of lists of node labels, one inner list per level
            (left to right). Required despite the None default; every label
            in df must appear in exactly one group.
        colors: Optional mapping of node label -> colour string. Nodes not in
            the mapping take colours, in palette order, from plotly's
            qualitative palette (converted to rgba, alpha 0.8). Colours
            already used in the mapping are skipped, unless that would leave
            no colours at all.

    Returns:
        A (nodes_dict, links_dict) tuple. Both include 'customdata' and
        'hovertemplate' for tooltips; nodes_dict also has 'x'/'y' and, if
        `colors` is given, 'color'.

    Raises:
        ValueError: if node_groups is missing/empty, repeats a label, or does
            not contain exactly the labels found in df.
    """
    if not node_groups:
        raise ValueError('node_groups must be a non-empty list of lists of node labels')

    # Flatten groups into node order; validate against the labels in df
    expected_labels = [label for group in node_groups for label in group]
    expected_labels_set = set(expected_labels)
    # A repeated label would map to one index but receive several positions
    if len(expected_labels_set) != len(expected_labels):
        raise ValueError('Each label may appear in only one of node_groups')

    observed_labels_set = set(df['source']) | set(df['target'])
    if observed_labels_set != expected_labels_set:
        raise ValueError(
            'Mismatch between node_groups and unique values in source/target columns\n'
            f'\tMissing from node_groups: {observed_labels_set - expected_labels_set}\n'
            f'\tMissing from df: {expected_labels_set - observed_labels_set}'
        )

    # Nodes
    nodes_dict = {'label': expected_labels}
    nodes_dict['x'], nodes_dict['y'] = calc_node_positions(expected_labels, node_groups)
    if colors:
        palette = [hex_to_rgba(x) for x in px.colors.qualitative.Plotly]
        reserved = set(colors.values())
        # Keep palette order (a set difference would iterate in hash order and
        # vary between runs); fall back to the full palette if all are reserved.
        remaining_colors = cycle([c for c in palette if c not in reserved] or palette)
        # Draw from the palette only for unmapped nodes; a .get() default
        # would call next() (and consume a colour) for every node.
        nodes_dict['color'] = [
            colors[x] if x in colors else next(remaining_colors)
            for x in expected_labels
        ]

    # Links
    sources, targets, values = df.values.T.tolist()
    label_to_idx = {label: i for i, label in enumerate(expected_labels)}
    source_idx = [label_to_idx[s] for s in sources]
    target_idx = [label_to_idx[t] for t in targets]
    links_dict = {'source': source_idx, 'target': target_idx, 'value': values}

    # Tooltips
    nodes_dict['customdata'], nodes_dict['hovertemplate'] = calc_node_tooltips(df, node_groups)
    links_dict['customdata'], links_dict['hovertemplate'] = calc_link_tooltips(df, node_groups)

    return nodes_dict, links_dict


# One inner list per column, left to right.
node_groups = [
    ['Visit'],
    ['Search', 'Bounce (from visit)'],
    ['Click', 'Bounce (from search)'],
    ['Buy', 'Bounce (from click)'],
]

# Every bounce node shares the same red colour.
colors_map = dict.fromkeys(
    ['Bounce (from visit)', 'Bounce (from search)', 'Bounce (from click)'],
    'rgba(239,85,59,0.8)',
)

nodes, links = make_sankey_params(df=funnel_df2, node_groups=node_groups, colors=colors_map)
data = go.Sankey(arrangement='snap', node=nodes, link=links)
fig = go.Figure(data=data)
fig.update_layout(
    title='Sankey w/ hover tooltips, specified colours and x/y positions',
    margin=default_margins,
)

Interactive widgets

Suppose our data is segmented by some sort of dimension such as country or platform. We’d probably visualize the combined overview first, but would then like to drill-down into particular dimensions.

Naively we could loop over the previous logic to generate separate flowcharts. But plotly diagrams are interactive, so we can do better—using a drop-down filter.

First let’s create an example segmented dataset:

funnel_df3 = pd.concat(
    [
        funnel_df2.assign(platform='desktop'),
        # mobile: same funnel with every count shifted up by 10
        funnel_df2.assign(platform='mobile', count=lambda d: d['count'] + 10),
    ],
    ignore_index=True,
)

funnel_df3

source target count platform
0 Visit Search 80 desktop
1 Visit Bounce (from visit) 20 desktop
2 Search Click 60 desktop
3 Search Bounce (from search) 20 desktop
4 Click Buy 40 desktop
5 Click Bounce (from click) 20 desktop
6 Visit Search 90 mobile
7 Visit Bounce (from visit) 30 mobile
8 Search Click 70 mobile
9 Search Bounce (from search) 30 mobile
10 Click Buy 50 mobile
11 Click Bounce (from click) 30 mobile
# Combined view: sum counts across platforms for each (source, target) pair.
filtered_default = (
    funnel_df3
    .groupby(['source', 'target'], as_index=False)['count']
    .sum()
)
filtered_default

source target count
0 Click Bounce (from click) 50
1 Click Buy 90
2 Search Bounce (from search) 50
3 Search Click 130
4 Visit Bounce (from visit) 50
5 Visit Search 170
def assemble_update_menus(label: str, nodes: dict, links: dict) -> dict:
    """ Build one drop-down button entry that swaps in a new sankey dataset.

    Args:
        label: Button text; also used as the layout title when selected.
        nodes: Node parameter dict for the sankey trace (stored as-is).
        links: Link parameter dict for the sankey trace (stored as-is).

    Returns:
        A button dict with method 'animate' whose single arg holds a one-trace
        'data' list and a 'layout' title.
    """
    sankey_trace = {
        'type': 'sankey',
        'arrangement': 'snap',
        'node': nodes,
        'link': links,
    }
    frame = {'data': [sankey_trace], 'layout': {'title': label}}
    return {'method': 'animate', 'label': label, 'args': [frame]}

# generate default view
nodes, links = make_sankey_params(filtered_default, node_groups, colors=colors_map)
default_args = assemble_update_menus('All', nodes, links)
fig = go.Figure(default_args['args'][0]['data'][0])

# generate set of filtered views
filtered_datasets = [default_args]
for group_name, grouped_df in funnel_df3.groupby('platform'):
    sub_df = grouped_df.drop('platform', axis=1)
    nodes, links = make_sankey_params(sub_df, node_groups, colors=colors_map)
    filtered_datasets.append(assemble_update_menus(group_name, nodes, links))
    
# add filter options
update_menus = [{'buttons': filtered_datasets, 'x': 0.7, 'y': 1.17}]
fig.update_layout(updatemenus=update_menus, title='Sankey diagram with interactive filters', margin=default_margins)

Todo

Coloured links

def make_link_colors(
    sources: list,
    targets: list,
    highlight_edges: frozenset = frozenset({('A', 'B')}),
    highlight_targets: frozenset = frozenset({'C'}),
    highlight_color: str = 'rgba(180,21,21,0.4)',
    default_color: str = 'rgba(113, 113, 113, 0.4)',
) -> list:
    """ Return one colour per link, highlighting selected edges/targets.

    A link is highlighted when its (source, target) pair is in
    `highlight_edges` or its target is in `highlight_targets`; every other
    link gets `default_color`. The defaults reproduce the earlier sketch's
    rule (highlight A -> B and anything flowing into C).

    Args:
        sources: Link source labels.
        targets: Link target labels, parallel to `sources`.
        highlight_edges: (source, target) pairs to highlight.
        highlight_targets: Targets whose incoming links are all highlighted.
        highlight_color: Colour for highlighted links.
        default_color: Colour for all other links.

    Returns:
        A list of colour strings, one per (source, target) pair.
    """
    return [
        highlight_color
        if (source, target) in highlight_edges or target in highlight_targets
        else default_color
        for source, target in zip(sources, targets)
    ]

# Usage, e.g. inside make_sankey_params where sources/targets/source_idx exist:
# links_dict = {'source': source_idx, 'target': target_idx, 'value': values,
#               'color': make_link_colors(sources, targets)}

Further reading


comments powered by Disqus