Grammar of Graphics in Mathematics
A gentle introduction
Say you’re a mathematician and you want to plot \(\sin(x)\) and \(\cos(x)\). How would you do that? You’d probably do something like this¹:
import matplotlib.pyplot as plt
import numpy as np
fig, ax = plt.subplots()
x = np.linspace(0, 10, 100)
ax.plot(x, np.sin(x), label="sin(x)")
ax.plot(x, np.cos(x), label="cos(x)")
ax.legend()
ax.set(title="Trigonometric functions")
plt.show()
What I want to show you in this post is that it’s worth altering that approach slightly. Let me show you.
import altair as alt
import pandas as pd
df = pd.concat(
    [
        pd.DataFrame({"x": x, "y": np.sin(x), "Function": "sin"}),
        pd.DataFrame({"x": x, "y": np.cos(x), "Function": "cos"}),
    ]
)
alt.Chart(df, title="Trigonometric functions").mark_line().encode(
    x="x",
    y="y",
    color="Function",
)
Let’s unwrap what happens here. As a first note, we are using altair as our plotting library, but something like seaborn, ggplot2, or Plot would have been fine as well.
Next, we are creating a pandas.DataFrame, which is nothing but a table:
  | x | y | Function |
---|---|---|---|
0 | 0.00000 | 0.000000 | sin |
1 | 0.10101 | 0.100838 | sin |
2 | 0.20202 | 0.200649 | sin |
3 | 0.30303 | 0.298414 | sin |
4 | 0.40404 | 0.393137 | sin |
... | ... | ... | ... |
95 | 9.59596 | -0.985384 | cos |
96 | 9.69697 | -0.963184 | cos |
97 | 9.79798 | -0.931165 | cos |
98 | 9.89899 | -0.889653 | cos |
99 | 10.00000 | -0.839072 | cos |
200 rows × 3 columns
The one special thing about this dataframe, however, is also the reason for this post. This is a long dataframe, not a wide dataframe. A wide dataframe would look like this:
  | x | cos | sin |
---|---|---|---|
0 | 0.00000 | 1.000000 | 0.000000 |
1 | 0.10101 | 0.994903 | 0.100838 |
2 | 0.20202 | 0.979663 | 0.200649 |
3 | 0.30303 | 0.954437 | 0.298414 |
4 | 0.40404 | 0.919480 | 0.393137 |
... | ... | ... | ... |
95 | 9.59596 | -0.985384 | -0.170347 |
96 | 9.69697 | -0.963184 | -0.268843 |
97 | 9.79798 | -0.931165 | -0.364599 |
98 | 9.89899 | -0.889653 | -0.456637 |
99 | 10.00000 | -0.839072 | -0.544021 |
100 rows × 3 columns
This dataframe is wide because, instead of stacking the values for \(\sin\) and \(\cos\) on top of one another, it places them side by side.
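To make the difference concrete, here is a minimal sketch of how one could convert between the two layouts with pandas. The variable names (long_df, wide_df, back_to_long) are my own for illustration, and this is not necessarily how the wide table above was produced.

```python
import numpy as np
import pandas as pd

x = np.linspace(0, 10, 100)

# Long format: one row per (x, Function) pair, as in the post.
long_df = pd.concat(
    [
        pd.DataFrame({"x": x, "y": np.sin(x), "Function": "sin"}),
        pd.DataFrame({"x": x, "y": np.cos(x), "Function": "cos"}),
    ]
)

# Long -> wide: every distinct value of "Function" becomes its own column.
wide_df = long_df.pivot(index="x", columns="Function", values="y").reset_index()
wide_df.columns.name = None  # drop the leftover "Function" label on the column axis

# Wide -> long: stack the "cos" and "sin" columns on top of one another again.
back_to_long = wide_df.melt(id_vars="x", var_name="Function", value_name="y")

print(wide_df.shape)       # 100 rows, columns x / cos / sin
print(back_to_long.shape)  # 200 rows, columns x / Function / y
```

Going from long to wide only works this cleanly because both functions are sampled on exactly the same x values.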
There are a couple of reasons why the long way is better than the wide way.
- A lot of modern visualization tools make heavy use of the Grammar of Graphics², an approach that is based on the long format.
- You can store time series of different lengths in the same dataframe.
- A lot of data transformations (e.g. groupby) are much easier to use this way.
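The last two points can be sketched in a few lines. The third series, sampled on a coarser grid and labeled "sin(2x), coarse", is a made-up example to show that series of different lengths fit into the same long dataframe.

```python
import numpy as np
import pandas as pd

x = np.linspace(0, 10, 100)
x_coarse = x[::2]  # hypothetical coarser grid with only 50 points

long_df = pd.concat(
    [
        pd.DataFrame({"x": x, "y": np.sin(x), "Function": "sin"}),
        pd.DataFrame({"x": x, "y": np.cos(x), "Function": "cos"}),
        # A shorter series simply contributes fewer rows.
        pd.DataFrame(
            {"x": x_coarse, "y": np.sin(2 * x_coarse), "Function": "sin(2x), coarse"}
        ),
    ]
)

# Number of samples per function: the series do not need equal lengths.
counts = long_df.groupby("Function").size()

# One row per function, one column per summary statistic.
summary = long_df.groupby("Function")["y"].agg(["min", "max", "mean"])

print(counts)
print(summary)
```

In a wide dataframe, the shorter series would have to be padded with missing values, and the summary would need one expression per column.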
One thing I did see as a potential downside, however, is the additional storage this requires. The wide dataframe needs to store \(3\cdot N\) double values (\(x\), \(\sin\), \(\cos\)), whereas the long format requires storage for \(4\cdot N\) double values and \(2\cdot N\) string values.
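If you want to check this estimate yourself, the following sketch counts the stored values and asks pandas for the actual memory footprint. The exact byte counts depend on your pandas version and on how strings are stored, so treat the printed numbers as indicative only.

```python
import numpy as np
import pandas as pd

n = 100
x = np.linspace(0, 10, n)

wide_df = pd.DataFrame({"x": x, "sin": np.sin(x), "cos": np.cos(x)})
long_df = pd.concat(
    [
        pd.DataFrame({"x": x, "y": np.sin(x), "Function": "sin"}),
        pd.DataFrame({"x": x, "y": np.cos(x), "Function": "cos"}),
    ],
    ignore_index=True,
)

# Count the stored values, mirroring the 3*N versus 4*N + 2*N estimate.
wide_doubles = wide_df.select_dtypes("float64").size  # 3 * N
long_doubles = long_df.select_dtypes("float64").size  # 4 * N
long_strings = long_df["Function"].size               # 2 * N

# Bytes actually used by the columns; deep=True also accounts for the strings.
wide_bytes = wide_df.memory_usage(index=False, deep=True).sum()
long_bytes = long_df.memory_usage(index=False, deep=True).sum()

print(wide_doubles, long_doubles, long_strings)
print(wide_bytes, long_bytes)
```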
Another drawback is the added complexity when thinking about how the data should be stored.
Before I argue why this extra memory usage and complexity might be worth it, let’s extend our example a little to make it slightly more complex. Let’s say we want to compare different frequencies for \(\sin\) and \(\cos\).
With our initial approach, this could look as follows:
fig, ax = plt.subplots()
x = np.linspace(0, 10, 100)
for i in range(1, 4):
    ax.plot(x, np.sin(i * x), label=f"sin({i}*x)")
    ax.plot(x, np.cos(i * x), label=f"cos({i}*x)")
ax.legend()
ax.set(title="Trigonometric functions")
plt.show()
Not pretty, but you get the idea.
Here’s the approach using altair and a long dataframe. First, let’s bring the data into the correct form.
df = pd.concat(
    [
        pd.DataFrame(
            {
                "x": x,
                "y": np.sin(i * x),
                "Function": "sin",
                "Frequency": i,
            }
        )
        for i in range(1, 4)
    ]
    + [
        pd.DataFrame(
            {
                "x": x,
                "y": np.cos(i * x),
                "Function": "cos",
                "Frequency": i,
            }
        )
        for i in range(1, 4)
    ]
)
df
  | x | y | Function | Frequency |
---|---|---|---|---|
0 | 0.00000 | 0.000000 | sin | 1 |
1 | 0.10101 | 0.100838 | sin | 1 |
2 | 0.20202 | 0.200649 | sin | 1 |
3 | 0.30303 | 0.298414 | sin | 1 |
4 | 0.40404 | 0.393137 | sin | 1 |
... | ... | ... | ... | ... |
95 | 9.59596 | -0.871008 | cos | 3 |
96 | 9.69697 | -0.684721 | cos | 3 |
97 | 9.79798 | -0.436037 | cos | 3 |
98 | 9.89899 | -0.147619 | cos | 3 |
99 | 10.00000 | 0.154251 | cos | 3 |
600 rows × 4 columns
This is of course much more effort than before. However, the data creation is clearly separated from the visualization. Let’s make use of this effort in the visualization.
alt.Chart(df).mark_line().encode(
x="x", y="y", color="Frequency:N", strokeDash="Function", row="Function"
).properties(width=500)
A couple of things are going on here. First, we split up the plot into two subplots by specifying row="Function", i.e. the column of the dataframe that should be used as a row identifier. Then we said color="Frequency:N". Note the :N here. Without specifying that our data is nominal, it would be treated as quantitative (:Q, the type altair infers for numeric columns), and the color map used for plotting would be a sequential color map instead of a qualitative one.
Now the great thing is that we can simply change what we want to color or arrange our plot by.
(
    alt.Chart(df)
    .mark_line()
    .encode(
        x="x",
        y="y",
        color="Function:N",
        strokeDash="Frequency",
        row="Frequency",
    )
    .properties(width=450, height=100)
)
I think this is quite a powerful framework, and libraries like altair, seaborn, ggplot2, and Plot fundamentally rely on it.
More complex examples
Let’s try to reproduce a figure from a paper I co-authored. The exact data does not really matter, but here’s what we want to end up with.
Now, because I don’t have access to the real data anymore, we’ll use some fake data instead.
def get_fake_data(n, alpha) -> pd.DataFrame:
    # Three damped oscillations stand in for the three cuts of the real data.
    nx = 10
    x = np.linspace(0, 1, nx)
    return pd.DataFrame(
        {
            "x": np.hstack([x, x, x]),
            "y": np.hstack(
                [
                    np.exp(-x * n) * np.sin(alpha * x),
                    np.exp(-x * n) * np.cos(alpha * x),
                    np.exp(-x * n) * np.cos(alpha * x) * np.sin(alpha * x),
                ]
            ),
            "cut": ["hori"] * nx + ["diag"] * nx + ["verti"] * nx,
            "alpha": [alpha] * (3 * nx),
            "n": [n] * (3 * nx),
        }
    )
N = [1, 2, 3, 4]
ALPHA = [1, 2, 3, 4]
df = pd.concat([get_fake_data(n, alpha) for n in N for alpha in ALPHA])
  | x | y | cut | alpha | n |
---|---|---|---|---|---|
0 | 0.000000 | 0.000000 | hori | 1 | 1 |
1 | 0.111111 | 0.099222 | hori | 1 | 1 |
2 | 0.222222 | 0.176481 | hori | 1 | 1 |
3 | 0.333333 | 0.234445 | hori | 1 | 1 |
4 | 0.444444 | 0.275680 | hori | 1 | 1 |
... | ... | ... | ... | ... | ... |
25 | 0.555556 | -0.052251 | verti | 4 | 4 |
26 | 0.666667 | -0.028256 | verti | 4 | 4 |
27 | 0.777778 | -0.001357 | verti | 4 | 4 |
28 | 0.888889 | 0.010520 | verti | 4 | 4 |
29 | 1.000000 | 0.009060 | verti | 4 | 4 |
480 rows × 5 columns
Let’s plot this data to recreate the original figure.
alt.Chart(df).mark_line().encode(
    x="x",
    y="y",
    row="n",
    column="alpha",
    color="cut",
).properties(width=100, height=100)
Not identical, but you get the idea. The main issue here is the lack of LaTeX support. However, this is altair-specific.
Lastly, note that there is nothing stopping us from having a nicer-looking plot by simply changing the theme.