Geoff Ruddock

Interactive sankey diagrams (with Plotly)

from pprint import pprint
from typing import Tuple, List
from IPython.display import display

import numpy as np
import pandas as pd
import plotly as px
import plotly.graph_objects as go
Package    Version
---------  ---------
python     3.9.12
pandas     1.4.2
plotly     5.6.0

Basic diagram from dataframe

Generate dummy data

Let’s suppose we have a dataset in which each row is a unique path across potential values within any number of “levels” or stages.

This is a reasonably easy format to reach with an SQL query using GROUP BY 1, 2, 3, etc.

from itertools import product

# Possible values at each "level"; the dict's insertion order defines the level order.
states = {
    'Country': ['Canada', 'USA', 'Mexico'],
    'Occupation': ['Doctor', 'Lawyer', 'Banker'],
    'Life satisfaction': ['Happy', 'Meh', 'Unhappy']
}

# One row per combination of a single value from every level.
rows = list(product(*states.values()))
# One random integer weight (below 100) per path; unseeded, so values differ between runs.
vals = np.random.randint(100, size=len(rows))

df_unique_paths = pd.DataFrame(rows, columns=states.keys()).assign(n=vals)

# Truncate the displayed output to at most 10 rows.
with pd.option_context('display.max_rows', 10):
    display(df_unique_paths)

CountryOccupationLife satisfactionn
0CanadaDoctorHappy19
1CanadaDoctorMeh73
2CanadaDoctorUnhappy47
3CanadaLawyerHappy16
4CanadaLawyerMeh58
...............
22MexicoLawyerMeh39
23MexicoLawyerUnhappy95
24MexicoBankerHappy7
25MexicoBankerMeh83
26MexicoBankerUnhappy32

27 rows × 4 columns

Generate edge list

Now let’s translate the above dummy dataframe into a set of weighted edges to/from nodes.

def edges_from_unique_paths(df: pd.DataFrame, weight_col: str = 'n') -> pd.DataFrame:
    """ Generate a "long" set of edges, from a "wide" set of unique paths.

    Node values starting with an underscore are skipped, so a path links the
    neighbouring visible nodes on either side of a skipped one directly.

    Args:
        df: DataFrame whose columns (other than `weight_col`) represent
            "levels", ordered from the first stage to the last.
        weight_col: Name of the numeric count/weight column.

    Returns:
        A DataFrame with three columns: (source, target, count), holding one
        row per distinct (source, target) pair and its summed weight.

    """

    from collections import defaultdict

    level_cols = [c for c in df.columns if c != weight_col]
    edges = defaultdict(int)

    # itertuples yields plain tuples, avoiding the per-row Series that iterrows builds
    rows = df[level_cols + [weight_col]].itertuples(index=False, name=None)
    for *levels, weight in rows:
        # Only string values can be "hidden"; anything else is always kept
        visible = [
            v for v in levels
            if not (isinstance(v, str) and v.startswith('_'))
        ]
        for a, b in zip(visible, visible[1:]):
            edges[(a, b)] += weight

    return pd.DataFrame(
        [(a, b, n) for (a, b), n in edges.items()],
        columns=['source', 'target', 'count']
    )

    
# Collapse the wide table of paths into a (source, target, count) edge list.
df_edges = edges_from_unique_paths(df_unique_paths)

# Truncate the displayed output to at most 10 rows.
with pd.option_context('display.max_rows', 10):
    display(df_edges)

sourcetargetcount
0CanadaDoctor139
1DoctorHappy126
2DoctorMeh98
3DoctorUnhappy153
4CanadaLawyer98
............
13USALawyer148
14USABanker146
15MexicoDoctor114
16MexicoLawyer140
17MexicoBanker122

18 rows × 3 columns

If the structure of our data were less predictable—e.g. if some paths do not “pass” through all columns—it may be desirable to directly specify the edge list.

Generate plotly params

The main plotting function that we’ll use below, go.Sankey(), takes two key arguments:

  1. A dict of node-related info (labels, colors, etc.)
  2. A dict of link-related data (to/from, weights)
from functools import reduce
    
def make_sankey_params_v1(df: pd.DataFrame) -> Tuple[dict, dict]:
    """ Generate the node and link parameter dicts passed to go.Sankey.

    Args:
        df: Edge list with exactly three columns, in this order:
            (source label, target label, value).

    Returns:
        A (nodes, links) tuple. `nodes` holds the unique labels in order of
        first appearance (sources first, then targets); `links` refers to
        nodes by their position in that label list.

    """

    # Unpack columns into lists
    sources, targets, values = df.values.T.tolist()

    # Unique labels in order of first appearance. Iterating a set of strings
    # gives an order that can change between interpreter runs.
    labels = list(dict.fromkeys(sources + targets))

    # Map each label to its index once, instead of scanning the list per edge
    label_idx = {label: i for i, label in enumerate(labels)}

    # Assemble final outputs into expected format
    nodes_dict = {'label': labels}
    links_dict = {
        'source': [label_idx[s] for s in sources],
        'target': [label_idx[t] for t in targets],
        'value': values
    }

    return nodes_dict, links_dict


nodes, links = make_sankey_params_v1(df_edges)

# Print both dicts, separated by a blank line.
pprint(nodes, compact=True)
print('')
pprint(links, compact=True)
{'label': ['Meh', 'Banker', 'Mexico', 'Doctor', 'Canada', 'Lawyer', 'USA',
           'Happy', 'Unhappy']}

{'source': [4, 3, 3, 3, 4, 5, 5, 5, 4, 1, 1, 1, 6, 6, 6, 2, 2, 2],
 'target': [3, 7, 0, 8, 5, 7, 0, 8, 1, 7, 0, 8, 3, 5, 1, 3, 5, 1],
 'value': [139, 126, 98, 153, 98, 96, 106, 184, 126, 156, 96, 142, 124, 148,
           146, 114, 140, 122]}

Visualize

# Wrap the node/link dicts in a Sankey trace and place it in a figure.
data = go.Sankey(node=nodes, link=links)
fig = go.Figure(data)
# Margins (left, right, top, bottom) reused by the figures further down.
default_margins = {'l': 25, 'r': 25, 't': 50, 'b': 0} 
fig.update_layout(title='Basic example sankey diagram', margin=default_margins)
MehBankerMexicoDoctorCanadaLawyerUSAHappyUnhappy
Basic example sankey diagram

Formatting

Specify node colors

We can also manually set the color of nodes. Suppose we are only interested in doing so for a handful of specific nodes, but would like to let plotly select default colors for the rest.

These are the colors in the qualitative color palette, with 80% opacity:

from IPython.core.display import HTML

def hex_to_rgba(h: str, a: float = 0.8) -> str:
    """ Convert a hex color into rgba with a specified opacity.

    Args:
        h: Hex colour such as '#636EFA' (the '#' is optional). The 3-digit
            shorthand form ('#0c9') is expanded to 6 digits first.
        a: Opacity, written into the output unchanged.

    Returns:
        A CSS-style string of the form 'rgba(r,g,b,a)'.

    Raises:
        ValueError: If the colour digits cannot be parsed as hexadecimal.

    """
    digits = h.strip('#')
    if len(digits) == 3:
        # Expand shorthand by doubling each digit, e.g. 'f0a' -> 'ff00aa'
        digits = ''.join(ch * 2 for ch in digits)
    r, g, b = (int(digits[i:i+2], 16) for i in (0, 2, 4))
    return f'rgba({r},{g},{b},{a})'
    
# Qualitative palette converted to rgba strings (reused by later cells)
colors_rgba = [hex_to_rgba(x) for x in px.colors.qualitative.Plotly]

# One line per colour, with the text drawn in that colour
html_output = ''.join(
    f'<span style="color: {c}; font-weight: bold">{c}</span><br>'
    for c in colors_rgba
)

HTML(html_output)

rgba(99,110,250,0.8)
rgba(239,85,59,0.8)
rgba(0,204,150,0.8)
rgba(171,99,250,0.8)
rgba(255,161,90,0.8)
rgba(25,211,243,0.8)
rgba(255,102,146,0.8)
rgba(182,232,128,0.8)
rgba(255,151,255,0.8)
rgba(254,203,82,0.8)

If our diagram has more nodes than colors in the palette, they will be re-used. This may be confusing for the nodes we wish to manually specify. So let’s exclude those colors from the palette.

from functools import reduce
from itertools import cycle

nodes, links = make_sankey_params_v1(df_edges)

# Colours reserved for specific nodes
colors_map = {
    'Happy': 'rgba(0,204,150,0.8)',
    'Meh': 'rgba(254,203,82,0.8)',
    'Unhappy': 'rgba(239,85,59,0.8)'
}

# Palette colours that are not reserved above, kept in palette order so the
# assignment is reproducible (iterating a set of strings is not).
reserved_colors = set(colors_map.values())
remaining_colors = cycle([c for c in colors_rgba if c not in reserved_colors])

# Only draw from the palette for nodes without a reserved colour
nodes['color'] = [
    colors_map[x] if x in colors_map else next(remaining_colors)
    for x in nodes['label']
]

data = go.Sankey(node=nodes, link=links)
fig = go.Figure(data)
fig.update_layout(title='Sankey diagram with specified node colours', margin=default_margins)
MehBankerMexicoDoctorCanadaLawyerUSAHappyUnhappy
Sankey diagram with specified node colours

Misc

To manually resize the plot:

# Turn off autosizing and fix the figure at 400 x 400.
fig.update_layout(autosize=False, width=400, height=400)

Specify horizontal position

Scenario

Suppose we are visualizing some flow such as an ecommerce funnel, with drop-offs at each stage.

A v1 might look something like this:

# Edge list for the funnel: every stage drops off into one shared 'Bounce' node.
funnel_df = pd.DataFrame([
    ['Visit', 'Search', 80],
    ['Visit', 'Bounce', 20],
    ['Search', 'Click', 60],
    ['Search', 'Bounce', 20],
    ['Click', 'Buy', 40],
    ['Click', 'Bounce', 20]
    
], columns=['source', 'target', 'count'])

nodes, links = make_sankey_params_v1(funnel_df)

# Same plotting steps as before, now with arrangement='snap'.
data = go.Sankey(node=nodes, link=links, arrangement='snap')
fig = go.Figure(data)
fig.update_layout(title='An example ecommerce funnel', margin=default_margins)

Having a single node for “Bounce” at the end accentuates the overall conversion rate, but at the expense of making the per-stage conversion rate unclear. So we might split these nodes up:

# Same funnel, but each stage now drops off into its own 'Bounce (from ...)' node.
funnel_df2 = pd.DataFrame([
    ['Visit', 'Search', 80],
    ['Visit', 'Bounce (from visit)', 20],
    ['Search', 'Click', 60],
    ['Search', 'Bounce (from search)', 20],
    ['Click', 'Buy', 40],
    ['Click', 'Bounce (from click)', 20],
    
], columns=['source', 'target', 'count'])

nodes, links = make_sankey_params_v1(funnel_df2)

data = go.Sankey(node=nodes, link=links, arrangement='snap')
fig = go.Figure(data)

fig.update_layout(title='Separate "Bounce" nodes looks funny', margin=default_margins)
ClickBuyBounce (from search)Bounce (from click)VisitBounce (from visit)Search
Separate "Bounce" nodes looks funny

But this looks even worse, because by default, plotly puts all terminal nodes at the end.

Invisible nodes

One (somewhat clunky) approach is to use dummy nodes to enforce positioning, then set them as invisible. Inspiration:

from functools import reduce
    
def make_sankey_params_invisible_nodes(df: pd.DataFrame) -> Tuple[dict, dict]:
    """ Generate go.Sankey node/link dicts, hiding underscore-prefixed nodes.

    Nodes whose label starts with '_' get an empty display label and a fully
    transparent colour; any link touching such a node is made transparent too.

    Args:
        df: Edge list with exactly three columns, in this order:
            (source label, target label, value).

    Returns:
        A (nodes, links) tuple. Labels are ordered by first appearance
        (sources first, then targets), so colours are stable between runs.

    """

    # Unpack columns into lists
    sources, targets, values = df.values.T.tolist()

    # Unique labels in order of first appearance. Iterating a set of strings
    # gives an order (and so a colour assignment) that can change between runs.
    labels = list(dict.fromkeys(sources + targets))
    label_idx = {label: i for i, label in enumerate(labels)}

    hidden_flags = [label.startswith('_') for label in labels]

    # Hidden nodes keep their slot in the layout but show no text
    display_labels = [
        '' if hidden else label
        for label, hidden in zip(labels, hidden_flags)
    ]

    # Visible nodes cycle through the qualitative palette by node index
    node_colors = [
        'rgba(0,0,0,0)'
        if hidden
        else px.colors.qualitative.Plotly[i % len(px.colors.qualitative.Plotly)]
        for i, hidden in enumerate(hidden_flags)
    ]

    nodes_dict = {
        'label': display_labels,
        'color': node_colors,
        'line': {'width': 0}
    }

    # A link is hidden when either of its endpoints is hidden
    link_colors = [
        'rgba(0,0,0,0)'
        if s.startswith('_') or t.startswith('_')
        else 'rgba(64,64,64,0.2)'
        for s, t in zip(sources, targets)
    ]

    links_dict = {
        'source': [label_idx[s] for s in sources],
        'target': [label_idx[t] for t in targets],
        'value': values,
        'color': link_colors
    }

    return nodes_dict, links_dict


# The split-bounce funnel plus underscore-prefixed dummy nodes.
funnel_df3 = pd.DataFrame([
    ['Visit', 'Search', 80],
    ['Visit', 'Bounce (from visit)', 20],
    ['Search', 'Click', 60],
    ['Search', 'Bounce (from search)', 20],
    ['Click', 'Buy', 40],
    ['Click', 'Bounce (from click)', 20],
    
    # invisible: dummy nodes chained after the early bounce nodes, so those
    # bounce nodes are no longer terminal
    ['Bounce (from visit)', '_bv1', 20],
    ['_bv1', '_bv2', 20],
    ['Bounce (from search)', '_bs1', 20]
    
], columns=['source', 'target', 'count'])


nodes, links = make_sankey_params_invisible_nodes(funnel_df3)

data = go.Sankey(node=nodes, link=links, arrangement='snap')
fig = go.Figure(data)
fig.update_layout(title='Horizontal positioning via hidden nodes/links', margin=default_margins)
ClickBounce (from search)BuyBounce (from click)VisitBounce (from visit)Search
Horizontal positioning via hidden nodes/links

Explicit x/y coordinates

Inspiration: Plotly: How to set node positions in a Sankey Diagram?

from functools import reduce

def calc_node_positions(
    nodes: list,
    node_groups: list,
    final_group_extra_gap: float = 0.05,
    y_sep=0.3
) -> Tuple[list, list]:
    """ Calculate x and y coords for a list of nodes, given a grouping.

    X coords force plotly to place nodes that belong to the same conceptual "level" together.
    Y positions "encourage" (not force) nodes into a vertical order.

    Args:
        nodes:
            Flat list of node labels. Not used in the calculation; kept so
            existing callers keep working.
        node_groups:
            List of lists of node labels, one inner list per level, ordered
            from the first level to the last.
        final_group_extra_gap:
            Amount of additional x-axis (0-1) to allocate to final grouping,
            to account for the fact that labels on this group appear on left,
            and may overlap with labels from the penultimate group.
        y_sep:
            Amount of "encouragement" for vertical ordering.
            Best results ~0.1, layout gets weird if higher.
            Gets divided by len(group) so that smaller values used in groups with many nodes.

    Returns:
        An (x_pos, y_pos) tuple of lists with one entry per node, in the
        order the nodes appear when `node_groups` is flattened.

    """

    n_groups = len(node_groups)

    # With a single group there is no spacing to compute (avoids dividing by zero)
    if n_groups > 1:
        normal_gap = (1 - final_group_extra_gap) / (n_groups - 1)
    else:
        normal_gap = 0.0

    x_pos, y_pos = [], []
    for group_idx, group in enumerate(node_groups):
        if not group:
            # An empty group contributes no nodes (and would divide by zero below)
            continue

        x = group_idx * normal_gap
        if n_groups > 1 and group_idx == n_groups - 1:
            x += final_group_extra_gap
        x_pos.extend([round(x, 2)] * len(group))

        # Evenly spaced y values: step, 2*step, ..., y_sep
        step = y_sep / len(group)
        y_pos.extend(
            np.cumsum(np.full(shape=len(group), fill_value=step)).round(4).tolist()
        )

    return x_pos, y_pos

    
def make_sankey_params_v3(df: pd.DataFrame, node_groups: list = None) -> Tuple[dict, dict]:
    """ Generate go.Sankey node/link dicts, optionally with explicit node positions.

    Args:
        df: Edge list with 'source' and 'target' columns; its three columns
            must be ordered (source, target, value).
        node_groups: List of lists of node labels, one inner list per level.
            When given, nodes are ordered by group and receive x/y coords.
            When omitted, nodes are ordered by first appearance in `df` and
            no coords are set, leaving the layout to plotly.

    Returns:
        A (nodes, links) tuple of dicts.

    Raises:
        AssertionError: If `node_groups` is given and its labels do not match
            the labels found in the source/target columns.

    """

    # Unpack columns into lists
    sources, targets, values = df.values.T.tolist()

    if node_groups is None:
        # No grouping supplied: unique labels in order of first appearance
        labels = list(dict.fromkeys(sources + targets))
    else:
        observed_labels = df['source'].pipe(set) | df['target'].pipe(set)
        labels = [label for group in node_groups for label in group]
        assert observed_labels == set(labels), \
            'Mismatch between node_groups and unique values in source/target columns'

    # Map each label to the index of its first occurrence
    label_idx = {}
    for i, label in enumerate(labels):
        label_idx.setdefault(label, i)

    # Nodes output
    nodes_dict = {'label': labels}
    if node_groups is not None:
        nodes_dict['x'], nodes_dict['y'] = calc_node_positions(labels, node_groups)

    # Links output
    links_dict = {
        'source': [label_idx[s] for s in sources],
        'target': [label_idx[t] for t in targets],
        'value': values
    }

    return nodes_dict, links_dict


# Nodes listed level by level, first to last; used for x/y positioning.
node_groups = [
    ['Visit'],
    ['Search', 'Bounce (from visit)'],
    ['Click', 'Bounce (from search)'],
    ['Buy', 'Bounce (from click)']
]

nodes, links = make_sankey_params_v3(funnel_df2, node_groups)

data = go.Sankey(node=nodes, link=links, arrangement='snap')
fig = go.Figure(data)
fig.update_layout(title='Horizontal positioning via manual coordinates', margin=default_margins)
VisitSearchBounce (from visit)ClickBounce (from search)BuyBounce (from click)
Horizontal positioning via manual coordinates

Hover tooltips

As you interact with the diagram, you may catch yourself doing mental arithmetic in order to calculate percentages. To alleviate this, we can add tooltips that show percentage flows:

See also: Hovertemplate and customdata of Sankey diagrams (Plotly docs)

def calc_node_tooltips(df: pd.DataFrame, node_groups: list) -> Tuple[List[Tuple[str, str]], str]:
    """ Generate data/template for node tooltips on a Plotly sankey diagram.

    Expects input df with three columns: (source, target, count)

    Each node's flow is its total inflow; nodes with no inflow (those in the
    first level) use their total outflow instead, so no tooltip shows 'nan%'.

    Args:
        df: Edge list with 'source', 'target' and 'count' columns.
        node_groups: List of lists of node labels, one inner list per level.
            The first group defines "total inputs" (flow leaving it); the last
            group defines "total outputs" (flow arriving in it).

    Returns:
        A (customdata, hovertemplate) tuple. `customdata` has one
        (pct_of_inputs, pct_of_outputs) pair of formatted strings per node,
        in flattened `node_groups` order.

    Final template:
    {{ label }}
     x% of total inputs
     x% of total outputs
    """

    # Total flow through each node, in flattened node_groups order
    sorted_nodes = [node for group in node_groups for node in group]
    inflow = df.groupby('target')['count'].sum().reindex(sorted_nodes)
    outflow = df.groupby('source')['count'].sum().reindex(sorted_nodes)
    node_totals = inflow.fillna(outflow)

    # Denominators; boolean masks tolerate labels that are missing from df
    total_input = df.loc[df['source'].isin(node_groups[0]), 'count'].sum()
    total_output = df.loc[df['target'].isin(node_groups[-1]), 'count'].sum()

    pct_of_total_in = (node_totals / total_input).map('{:.0%}'.format).tolist()
    pct_of_total_out = (node_totals / total_output).map('{:.0%}'.format).tolist()

    customdata = list(zip(pct_of_total_in, pct_of_total_out))

    hovertemplate = (
        '%{label}<br>'
        ' %{customdata[0]} of total inputs<br>'
        ' %{customdata[1]} of total outputs'
    )

    return customdata, hovertemplate


def calc_link_tooltips(df: pd.DataFrame, node_groups: list) -> Tuple[List[Tuple[str, str]], str]:
    """ Generate data/template for link tooltips on a Plotly sankey diagram.

    Expects input df with three columns: (source, target, count)

    Args:
        df: Edge list with 'source', 'target' and 'count' columns.
        node_groups: Not used in the calculation; accepted so the signature
            mirrors calc_node_tooltips.

    Returns:
        A (customdata, hovertemplate) tuple. `customdata` has one
        (pct_of_source, pct_of_target) pair of formatted strings per row of
        `df`, in row order.

    Final template:
    {{ source }} → {{ target }}
     x% of {{ source }}
     x% of {{ target }}
    """

    def share_within(col: str) -> list:
        # Each link's count as a share of the total count of its `col` node
        totals = df.groupby(col)['count'].transform('sum')
        return (df['count'] / totals).map('{:.0%}'.format).tolist()

    pct_of_source = share_within('source')
    pct_of_target = share_within('target')

    customdata = list(zip(pct_of_source, pct_of_target))

    hovertemplate = (
        '%{source.label} → %{target.label}<br>'
        ' %{customdata[0]} of %{source.label}<br>'
        ' %{customdata[1]} of %{target.label}'
    )

    return customdata, hovertemplate

    
def make_sankey_params_with_tooltips(df: pd.DataFrame, node_groups: list = None) -> dict:
    """ Build the node and link dicts for go.Sankey, with positions and hover tooltips.

    The labels in `node_groups` must match the labels found in the source and
    target columns of `df` exactly; otherwise an AssertionError lists the
    differences.
    """

    # Labels seen in the data vs. labels declared by the grouping
    labels_in_df = set(df['source']) | set(df['target'])
    flat_labels = reduce(lambda acc, grp: acc + grp, node_groups)
    labels_in_groups = set(flat_labels)
    assert labels_in_df == labels_in_groups, \
        'Mismatch between node_groups and unique values in source/target columns\n' \
        f'\tMissing from node_groups: {labels_in_df-labels_in_groups}\n' \
        f'\tMissing from df: {labels_in_groups-labels_in_df}'

    # Node coordinates
    x_coords, y_coords = calc_node_positions(flat_labels, node_groups)

    # Links refer to nodes by their position in the flattened label list
    sources, targets, values = df.values.T.tolist()
    source_idx = [flat_labels.index(s) for s in sources]
    target_idx = [flat_labels.index(t) for t in targets]

    # Tooltip payloads
    node_custom, node_template = calc_node_tooltips(df, node_groups)
    link_custom, link_template = calc_link_tooltips(df, node_groups)

    nodes_dict = {
        'label': flat_labels,
        'x': x_coords,
        'y': y_coords,
        'customdata': node_custom,
        'hovertemplate': node_template,
    }
    links_dict = {
        'source': source_idx,
        'target': target_idx,
        'value': values,
        'customdata': link_custom,
        'hovertemplate': link_template,
    }

    return nodes_dict, links_dict


# Nodes listed level by level, first to last (same grouping as the previous example).
node_groups = [
    ['Visit'],
    ['Search', 'Bounce (from visit)'],
    ['Click', 'Bounce (from search)'],
    ['Buy', 'Bounce (from click)']
]

nodes, links = make_sankey_params_with_tooltips(funnel_df2, node_groups)

data = go.Sankey(node=nodes, link=links, arrangement='snap')
fig = go.Figure(data)
fig.update_layout(title='Hover tooltips (for both nodes and links)', margin=default_margins)
VisitSearchBounce (from visit)ClickBounce (from search)BuyBounce (from click)
Hover tooltips (for both nodes and links)

Final version

Everything all together now:

def make_sankey_params(
    df: pd.DataFrame,
    node_groups: list = None,
    colors: dict = None
) -> Tuple[dict, dict]:
    """ Generate go.Sankey node/link dicts with positions, colours and tooltips.

    Args:
        df: Edge list with 'source', 'target' and 'count' columns, in that order.
        node_groups: List of lists of node labels, one inner list per level.
            Required in practice: it drives node order, x/y positions and the
            tooltip percentages.
        colors: Optional mapping of node label -> colour string. Nodes not in
            the mapping cycle through the qualitative palette, skipping any
            colour already used in the mapping.

    Returns:
        A (nodes, links) tuple of dicts.

    Raises:
        AssertionError: If the labels in `node_groups` do not match the labels
            found in the source/target columns.

    """

    # Create list of unique labels across node columns (source, target)
    observed_labels_set = df['source'].pipe(set) | df['target'].pipe(set)
    expected_labels = reduce(lambda x, y: x+y, node_groups)
    expected_labels_set = set(expected_labels)
    assert observed_labels_set == expected_labels_set, \
        'Mismatch between node_groups and unique values in source/target columns\n' \
        f'\tMissing from node_groups: {observed_labels_set-expected_labels_set}\n' \
        f'\tMissing from df: {expected_labels_set-observed_labels_set}'

    # Nodes
    nodes_dict = {'label': expected_labels}
    nodes_dict['x'], nodes_dict['y'] = calc_node_positions(expected_labels, node_groups)
    if colors:
        palette = [hex_to_rgba(x) for x in px.colors.qualitative.Plotly]
        reserved = set(colors.values())
        # Keep palette order so colours are reproducible between runs (set
        # iteration order is not). Fall back to the full palette if every
        # colour is reserved, instead of cycling over nothing.
        available = [c for c in palette if c not in reserved] or palette
        remaining_colors = cycle(available)
        # Only draw from the palette for nodes without a specified colour
        nodes_dict['color'] = [
            colors[x] if x in colors else next(remaining_colors)
            for x in expected_labels
        ]

    # Links
    sources, targets, values = df.values.T.tolist()
    source_idx = list(map(expected_labels.index, sources))
    target_idx = list(map(expected_labels.index, targets))
    links_dict = {'source': source_idx, 'target': target_idx, 'value': values}

    # Tooltips
    nodes_dict['customdata'], nodes_dict['hovertemplate'] = calc_node_tooltips(df, node_groups)
    links_dict['customdata'], links_dict['hovertemplate'] = calc_link_tooltips(df, node_groups)

    return nodes_dict, links_dict


# Nodes listed level by level, first to last.
node_groups = [
    ['Visit'],
    ['Search', 'Bounce (from visit)'],
    ['Click', 'Bounce (from search)'],
    ['Buy', 'Bounce (from click)']
]

# All three bounce nodes share one colour; other nodes get palette colours.
colors_map = {
    'Bounce (from visit)': 'rgba(239,85,59,0.8)',
    'Bounce (from search)': 'rgba(239,85,59,0.8)',
    'Bounce (from click)': 'rgba(239,85,59,0.8)'
}

nodes, links = make_sankey_params(funnel_df2, node_groups, colors=colors_map)
data = go.Sankey(node=nodes, link=links, arrangement='snap')
fig = go.Figure(data)
fig.update_layout(title='Sankey w/ hover tooltips, specified colours and x/y positions', margin=default_margins)
VisitSearchBounce (from visit)ClickBounce (from search)BuyBounce (from click)
Sankey w/ hover tooltips, specified colours and x/y positions

Interactive widgets

Suppose our data is segmented by some sort of dimension such as country or platform. We’d probably visualize the combined overview first, but would then like to drill down into particular segments.

Naively we could loop over the previous logic to generate separate flowcharts. But plotly diagrams are interactive, so we can do better—using a drop-down filter.

First let’s create an example segmented dataset:

# Two copies of the funnel tagged by platform; the mobile copy has 10 added to every count.
# NOTE: this rebinds funnel_df3, which held the invisible-node example earlier.
funnel_df3 = pd.concat([
    funnel_df2.assign(platform='desktop'),
    funnel_df2.assign(platform='mobile').assign(count=lambda x: x['count'] + 10)
]).reset_index(drop=True)

funnel_df3

sourcetargetcountplatform
0VisitSearch80desktop
1VisitBounce (from visit)20desktop
2SearchClick60desktop
3SearchBounce (from search)20desktop
4ClickBuy40desktop
5ClickBounce (from click)20desktop
6VisitSearch90mobile
7VisitBounce (from visit)30mobile
8SearchClick70mobile
9SearchBounce (from search)30mobile
10ClickBuy50mobile
11ClickBounce (from click)30mobile
# Combined view: sum the counts across platforms for each (source, target) pair.
filtered_default = funnel_df3.groupby(['source', 'target'])['count'].sum().reset_index()
filtered_default

sourcetargetcount
0ClickBounce (from click)50
1ClickBuy90
2SearchBounce (from search)50
3SearchClick130
4VisitBounce (from visit)50
5VisitSearch170
def assemble_update_menus(label: str, nodes: dict, links: dict) -> dict:
    """ Build one drop-down button entry that swaps in the given sankey data.

    The returned dict uses the 'animate' method with a single argument that
    carries a sankey trace (built from `nodes` and `links`) and a layout whose
    title is `label`.
    """

    sankey_trace = {
        'type': 'sankey',
        'arrangement': 'snap',
        'node': nodes,
        'link': links
    }
    frame = {
        'data': [sankey_trace],
        'layout': {'title': label}
    }
    button = {
        'method': 'animate',
        'label': label,
        'args': [frame]
    }
    return button

# generate default view (all platforms combined); the figure starts from this trace
nodes, links = make_sankey_params(filtered_default, node_groups, colors=colors_map)
default_args = assemble_update_menus('All', nodes, links)
fig = go.Figure(default_args['args'][0]['data'][0])

# generate set of filtered views: one button per platform, after the 'All' button
filtered_datasets = [default_args]
for group_name, grouped_df in funnel_df3.groupby('platform'):
    # drop the platform column so the remaining columns are (source, target, count)
    sub_df = grouped_df.drop('platform', axis=1)
    nodes, links = make_sankey_params(sub_df, node_groups, colors=colors_map)
    filtered_datasets.append(assemble_update_menus(group_name, nodes, links))
    
# add filter options; x/y set where the menu is placed
update_menus = [{'buttons': filtered_datasets, 'x': 0.7, 'y': 1.17}]
fig.update_layout(updatemenus=update_menus, title='Sankey diagram with interactive filters', margin=default_margins)
VisitSearchBounce (from visit)ClickBounce (from search)BuyBounce (from click)
Sankey diagram with interactive filtersAll

Todo

Coloured links

# Sketch: colour selected links differently from the rest.
# NOTE(review): sources, targets, source_idx, target_idx and values are not
# defined at module level in the code shown here; this presumably belongs
# inside one of the make_sankey_params* functions — confirm before running.
link_colors = []
for source, target in zip(sources, targets):
    print(source, target)  # NOTE(review): looks like leftover debug output
    # Highlight the A -> B link and every link into C; all others share a second colour
    if (source, target) == ('A', 'B') or target == 'C':
        link_colors.append('rgba(180,21,21,0.4)')
    else:
        link_colors.append('rgba(113, 113, 113, 0.4)')

links_dict = {'source': source_idx, 'target': target_idx, 'value': values, 'color': link_colors}

Further reading