Grammar of Graphics in Mathematics

Author

Thomas Camminady

Published

July 14, 2023

A gentle introduction

Say you’re a mathematician and you want to plot \(\sin(x)\) and \(\cos(x)\). How would you do that? You’d probably do something like this [1]:

import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
x = np.linspace(0, 10, 100)
ax.plot(x, np.sin(x), label="sin(x)")
ax.plot(x, np.cos(x), label="cos(x)")
ax.legend()
ax.set(title="Trigonometric functions")
plt.show()

What I want to show you in this post is that it’s worth altering that approach slightly. Let me show you.

import altair as alt
import pandas as pd

df = pd.concat(
    [
        pd.DataFrame({"x": x, "y": np.sin(x), "Function": "sin"}),
        pd.DataFrame({"x": x, "y": np.cos(x), "Function": "cos"}),
    ]
)
alt.Chart(df, title="Trigonometric functions").mark_line().encode(
    x="x",
    y="y",
    color="Function",
)

Let’s unwrap what happens here. As a first note, we are using altair as our plotting library, but something like seaborn, ggplot2, or Plot would have been fine as well.

Next, we are creating a pandas.DataFrame, which is nothing but a table:

df
           x         y Function
0    0.00000  0.000000      sin
1    0.10101  0.100838      sin
2    0.20202  0.200649      sin
3    0.30303  0.298414      sin
4    0.40404  0.393137      sin
...      ...       ...      ...
95   9.59596 -0.985384      cos
96   9.69697 -0.963184      cos
97   9.79798 -0.931165      cos
98   9.89899 -0.889653      cos
99  10.00000 -0.839072      cos

200 rows × 3 columns

The one special thing about this dataframe, however, is also the reason for this post: it is a long dataframe, not a wide one. A wide dataframe would look like this:

(
    df.pivot(index="x", columns="Function", values="y")
    .reset_index()
    .rename_axis(None, axis=1)
)
           x       cos       sin
0    0.00000  1.000000  0.000000
1    0.10101  0.994903  0.100838
2    0.20202  0.979663  0.200649
3    0.30303  0.954437  0.298414
4    0.40404  0.919480  0.393137
...      ...       ...       ...
95   9.59596 -0.985384 -0.170347
96   9.69697 -0.963184 -0.268843
97   9.79798 -0.931165 -0.364599
98   9.89899 -0.889653 -0.456637
99  10.00000 -0.839072 -0.544021

100 rows × 3 columns

This dataframe is wide because, instead of stacking the values for \(\sin\) and \(\cos\) on top of one another, it places them side by side.
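Going in the other direction, from wide to long, is a single call as well. Here is a minimal sketch using pandas' melt; the names wide_df and long_df are only chosen for this illustration.

```python
import numpy as np
import pandas as pd

x = np.linspace(0, 10, 100)

# Wide layout: one column per function, values side by side.
wide_df = pd.DataFrame({"x": x, "sin": np.sin(x), "cos": np.cos(x)})

# Long layout: one row per (x, Function) pair, values stacked.
long_df = wide_df.melt(id_vars="x", var_name="Function", value_name="y")

print(wide_df.shape)  # (100, 3)
print(long_df.shape)  # (200, 3)
```

The column that used to be part of the header ("sin", "cos") becomes ordinary data in the Function column, which is exactly what the plotting call needs for its color encoding.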

There are a couple of reasons why the long format is better than the wide format.

  • A lot of modern visualization tools make heavy use of the Grammar of Graphics [2], an approach that is based on the long format.
  • You can store time series of different lengths in the same dataframe.
  • A lot of data transformations (e.g. groupby) are much easier to use this way.
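To make the last two points concrete, here is a small sketch. The two grids of different length are made up for this illustration and are not part of the example above.

```python
import numpy as np
import pandas as pd

# Two series sampled on grids of different length (hypothetical data).
x_coarse = np.linspace(0, 10, 50)
x_fine = np.linspace(0, 10, 100)

df = pd.concat(
    [
        pd.DataFrame({"x": x_coarse, "y": np.sin(x_coarse), "Function": "sin"}),
        pd.DataFrame({"x": x_fine, "y": np.cos(x_fine), "Function": "cos"}),
    ],
    ignore_index=True,
)

# A single groupby summarizes every series, no matter how many there are.
summary = df.groupby("Function")["y"].agg(["count", "min", "max"])
print(summary)
```

In a wide dataframe the shorter series would have to be padded with missing values, and the summary would have to be computed column by column.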

Something I initially saw as a potential downside, however, is the additional storage that is needed. The wide dataframe needs to store \(3\cdot N\) double values (\(x\), \(\sin\), \(\cos\)), whereas the long format requires storage for \(4\cdot N\) double values (\(x\) and \(y\) for each of the \(2\cdot N\) rows) and \(2\cdot N\) string values.
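If you want to check this on your own machine, pandas can report the memory per column. The sketch below is only an illustration: the exact byte counts for the string column depend on the pandas version and string backend, so only the size of the wide frame is spelled out. Converting the label column to a categorical is my suggestion here, not something the rest of the post relies on.

```python
import numpy as np
import pandas as pd

n = 100
x = np.linspace(0, 10, n)
wide_df = pd.DataFrame({"x": x, "sin": np.sin(x), "cos": np.cos(x)})
long_df = wide_df.melt(id_vars="x", var_name="Function", value_name="y")

# deep=True also counts the memory held by the string labels.
wide_bytes = wide_df.memory_usage(index=False, deep=True).sum()
long_bytes = long_df.memory_usage(index=False, deep=True).sum()

# A categorical stores each distinct label once plus a small integer code per row.
compact_df = long_df.astype({"Function": "category"})
compact_bytes = compact_df.memory_usage(index=False, deep=True).sum()

print(wide_bytes)  # 2400, i.e. 3 columns * 100 rows * 8 bytes
print(long_bytes, compact_bytes)
```

With a categorical label column, most of the overhead of the string values goes away, and what remains is the duplicated \(x\) column.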

Another drawback is the added complexity when thinking about how the data should be stored.

So before I try to explain why this extra memory usage and complexity might be justified, let’s extend our example a little to make it slightly more complex. Let’s say we want to compare different frequencies for \(\sin\) and \(\cos\).

With our initial approach, this could look like this:

fig, ax = plt.subplots()
x = np.linspace(0, 10, 100)
for i in range(1, 4):
    ax.plot(x, np.sin(i * x), label=f"sin({i}*x)")
    ax.plot(x, np.cos(i * x), label=f"cos({i}*x)")
ax.legend()
ax.set(title="Trigonometric functions")
plt.show()

Not pretty, but you get the idea.

Here’s the approach using altair and a long dataframe. First, let’s bring the data into the correct form.

df = pd.concat(
    [
        pd.DataFrame(
            {
                "x": x,
                "y": np.sin(i * x),
                "Function": "sin",
                "Frequency": i,
            }
        )
        for i in range(1, 4)
    ]
    + [
        pd.DataFrame(
            {
                "x": x,
                "y": np.cos(i * x),
                "Function": "cos",
                "Frequency": i,
            }
        )
        for i in range(1, 4)
    ]
)
df
           x         y Function Frequency
0    0.00000  0.000000      sin         1
1    0.10101  0.100838      sin         1
2    0.20202  0.200649      sin         1
3    0.30303  0.298414      sin         1
4    0.40404  0.393137      sin         1
...      ...       ...      ...       ...
95   9.59596 -0.871008      cos         3
96   9.69697 -0.684721      cos         3
97   9.79798 -0.436037      cos         3
98   9.89899 -0.147619      cos         3
99  10.00000  0.154251      cos         3

600 rows × 4 columns

This is of course much more effort than before. However, the data creation is clearly separated from the visualization. Let’s make use of this effort in the visualization.
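Some of that effort can be trimmed. As a sketch, the two list comprehensions can be folded into one by looping over a small dictionary of functions; the dictionary name functions is made up for this illustration. The resulting rows come out in the same order as before.

```python
import numpy as np
import pandas as pd

x = np.linspace(0, 10, 100)

# Map each label to the function it stands for (illustrative helper).
functions = {"sin": np.sin, "cos": np.cos}

df = pd.concat(
    [
        pd.DataFrame(
            {"x": x, "y": func(i * x), "Function": name, "Frequency": i}
        )
        for name, func in functions.items()
        for i in range(1, 4)
    ]
)
print(df.shape)  # (600, 4)
```

Adding another function or frequency now means changing one line of the data creation, and nothing in the plotting code.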

alt.Chart(df).mark_line().encode(
    x="x", y="y", color="Frequency:N", strokeDash="Function", row="Function"
).properties(width=500)

A couple of things are going on here. First, we split the plot into two subplots by specifying row="Function", i.e. the column of the dataframe that should be used as the row identifier. Then we set color="Frequency:N". Note the :N here. Without it, altair would infer from the integer column that the data is quantitative (:Q), and the color map used for plotting would be a sequential color map instead of a qualitative one.

Now the great thing is that we can simply change what we want to color or arrange our plot by.

(
    alt.Chart(df)
    .mark_line()
    .encode(
        x="x",
        y="y",
        color="Function:N",
        strokeDash="Frequency",
        row="Frequency",
    )
    .properties(width=450, height=100)
)

I think this is quite a powerful framework, and libraries like altair, seaborn, ggplot2, and Plot fundamentally rely on it.

More complex examples

Let’s try to reproduce a figure from a paper I co-authored. The exact data does not really matter, but here’s what we want to end up with.

A figure from a paper that I co-authored, https://arxiv.org/pdf/1808.05846.pdf

Now, because I don’t have access to the real data anymore, we’ll use some fake data instead.

def get_fake_data(n, alpha) -> pd.DataFrame:
    nx = 10
    x = np.linspace(0, 1, nx)

    return pd.DataFrame(
        {
            "x": np.hstack([x, x, x]),
            "y": np.hstack(
                [
                    np.exp(-x * n) * np.sin(alpha * x),
                    np.exp(-x * n) * np.cos(alpha * x),
                    np.exp(-x * n) * np.cos(alpha * x) * np.sin(alpha * x),
                ]
            ),
            "cut": ["hori"] * nx + ["diag"] * nx + ["verti"] * nx,
            "alpha": [alpha] * (3 * nx),
            "n": [n] * (3 * nx),
        }
    )


N = [1, 2, 3, 4]
ALPHA = [1, 2, 3, 4]
df = pd.concat([get_fake_data(n, alpha) for n in N for alpha in ALPHA])
df
           x         y   cut alpha   n
0   0.000000  0.000000  hori     1   1
1   0.111111  0.099222  hori     1   1
2   0.222222  0.176481  hori     1   1
3   0.333333  0.234445  hori     1   1
4   0.444444  0.275680  hori     1   1
...      ...       ...   ...   ... ...
25  0.555556 -0.052251 verti     4   4
26  0.666667 -0.028256 verti     4   4
27  0.777778 -0.001357 verti     4   4
28  0.888889  0.010520 verti     4   4
29  1.000000  0.009060 verti     4   4

480 rows × 5 columns

Let’s plot this data to recreate the original figure.

alt.Chart(df).mark_line().encode(
    x="x",
    y="y",
    row="n",
    column="alpha",
    color="cut",
).properties(width=100, height=100)

Not identical, but you get the idea. The main issue here is the lack of LaTeX support. However, this is specific to altair.

Lastly, note that nothing is stopping us from having a nicer-looking plot by simply changing the theme.

from camminapy.plot import altair_theme

altair_theme()
alt.Chart(df).mark_line().encode(
    x="x",
    y="y",
    row="n",
    column="alpha",
    color="cut",
).properties(width=100, height=100)

Footnotes

  1. If you’re using MATLAB instead of Python or Julia, it might be worth considering a switch.

  2. https://link.springer.com/book/10.1007/0-387-28695-0