Interactive regression

Author

Thomas Camminady

Published

June 28, 2023

This example shows how to create an interactive plot with regression lines using Altair. The regression lines are computed over the currently selected window and update accordingly when the selected region is moved.

This code block contains all import statements.
import polars as pl
import numpy as np
import altair as alt
from vega_datasets import data

Let’s look at a sample data set from vega_datasets, inspired by an example from the altair documentation.

# Load the iris sample data set from vega_datasets.
source = data.iris()

# Axis encodings; zero=False means the scales are not forced to include zero.
x_axis = alt.X("sepalLength").scale(zero=False)
y_axis = alt.Y("sepalWidth").scale(zero=False, padding=1)

# Scatter plot of sepal width against sepal length, coloured by species.
scatter = alt.Chart(source).mark_circle()
chart = scatter.encode(x_axis, y_axis, color="species").properties(width=600)

# Bare expression so the notebook cell renders the chart.
chart

Now we can add a polynomial regression line for each species very easily:

# Per-species polynomial fit (order 5) with sepalLength as the first field
# and sepalWidth as the second, derived from the scatter plot defined above.
fitted = chart.transform_regression(
    "sepalLength", "sepalWidth", groupby=["species"], method="poly", order=5
)

# Draw the fitted values as lines and overlay them on the points.
regression = fitted.mark_line()
alt.layer(chart, regression)

Now let’s allow for some interactivity. We want to be able to select points and have the regression line be updated based on that selection of points. However, we still want to draw the line over the full x-domain, which is why we set extent on a second regression layer. We’ll make the stroke dashed where the line extrapolates beyond the selected points and solid over the range covered by the selection.

# Interval selection used both to highlight points and to filter the
# data that feeds the regression fits.
brush = alt.selection_interval()

# Fixed axis domains; the dashed fit below is drawn across the full x-domain.
x_domain = (4, 8)
y_domain = (1, 6)

# Scatter plot that carries the brush parameter.
points = alt.Chart(source).mark_circle()
chart = (
    points.encode(
        alt.X("sepalLength").scale(domain=x_domain),
        alt.Y("sepalWidth").scale(domain=y_domain),
        color="species:N",
    )
    .properties(width=600)
    .add_params(brush)
)

# Options shared by both regression layers: cubic polynomial per species.
fit_options = {"groupby": ["species"], "method": "poly", "order": 3}

# Only the brushed points feed the fits.
selected = chart.transform_filter(brush)

# Solid line: no explicit extent is passed for this fit.
regression_solid = selected.transform_regression(
    "sepalLength", "sepalWidth", **fit_options
).mark_line(clip=True)

# Dashed line: same fit, with the extent set to the x-axis domain.
regression_dash = selected.transform_regression(
    "sepalLength", "sepalWidth", extent=list(x_domain), **fit_options
).mark_line(clip=True, strokeDash=[5, 5])

# Points inside the brush stay opaque (1.0); all others are faded to 0.1.
highlighted = chart.encode(
    opacity=alt.condition(brush, alt.value(1.0), alt.value(0.1)),
)

# The dashed layer goes below the solid one.
alt.layer(highlighted, regression_dash, regression_solid)