

When in Doubt, Use GAMs!

Welcome to Short & Taught: the series that shows you how to apply interesting analysis techniques in the shortest time possible.

Premise

You are a data analyst for a consulting firm. Your client has just handed you a spreadsheet of bivariate data: an outcome of interest and one related variable. They want you to make sense of the relationship between the two. You plot the data and discover that the relationship is highly non-linear. What do you do?

There are plenty of approaches you could take here. We are going to explore a very underrated and underused one: generalised additive models (GAMs).

GAMs

GAMs are an extremely flexible extension of generalised linear models in which the (link-transformed) expected outcome is modelled as a sum of unknown smooth functions, one for each predictor. Users have control over many parameters associated with this smoothing, which makes GAMs well suited to non-linear, spatial, and many other types of data.
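In symbols, a GAM is commonly written as follows. Here $g$ is a link function (the identity for the Gaussian model fitted in this tutorial), each smooth $f_j$ is represented by $K_j$ basis functions $b_{jk}$ with coefficients $\beta_{jk}$, and the smoothing parameters $\lambda_j$ trade goodness of fit against wiggliness. This is a textbook-style sketch rather than a description of every option mgcv offers. The second line is the penalised least-squares criterion for the Gaussian, identity-link case; the tutorial's `y ~ s(x)` is the special case $p = 1$, and mgcv's `gam()` chooses the $\lambda_j$ automatically (by GCV by default).

$$
g\big(\mathbb{E}[y_i]\big) = \beta_0 + \sum_{j=1}^{p} f_j(x_{ij}), \qquad f_j(x) = \sum_{k=1}^{K_j} \beta_{jk}\, b_{jk}(x)
$$

$$
\hat{\boldsymbol\beta} = \arg\min_{\boldsymbol\beta} \; \sum_{i=1}^{n} \Big(y_i - \beta_0 - \sum_{j=1}^{p} f_j(x_{ij})\Big)^2 + \sum_{j=1}^{p} \lambda_j \int f_j''(x)^2 \, dx
$$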

Let's code one up!

Setting up the R code

First, we'll load the libraries we need for this tutorial, simulate some noisy non-linear data (a random walk built as the cumulative sum of 1,000 standard normal draws), and plot it.

library(dplyr)
library(ggplot2)
library(mgcv)
library(tidymv)

# Simulate data

set.seed(123)

d <- data.frame(y = cumsum(rnorm(1000, mean = 0, sd = 1))) %>%
  mutate(x = row_number())

# Draw summary plot

d %>%
  ggplot(aes(x = x, y = y)) +
  geom_point(size = 0.8, colour = "steelblue2") +
  labs(title = "Raw data",
       x = "X",
       y = "Y") +
  theme_bw() +
  theme(panel.grid.minor = element_blank())
[Figure: Raw data (Y against X)]

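With an idea of what the data looks like, we can now fit a very basic model (no parameter tuning or model diagnostics are presented in this very short tutorial). We will also extract the proportion of deviance explained. This generalises R-squared to the non-Gaussian response distributions that GAMs can handle, which makes it a more natural summary for additive models; for a Gaussian model like this one, it equals the unadjusted R-squared.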

# Fit GAM model

m1 <- gam(formula = y ~ s(x), data = d)

# Extract model deviance explained

dev.expl <- round(summary(m1)$dev.expl, digits = 4)*100
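If you want to see where that number comes from, the sketch below recomputes it by hand. It assumes mgcv's definition of deviance explained as the share of the null (intercept-only) deviance removed by the model, and it re-simulates the same data in base R so it runs on its own. Because a Gaussian model's deviance is its residual sum of squares, the result should also match the unadjusted R-squared.

```r
library(mgcv)

# Re-create the simulated random walk so this sketch is self-contained
set.seed(123)
d <- data.frame(y = cumsum(rnorm(1000, mean = 0, sd = 1)))
d$x <- seq_len(nrow(d))

m1 <- gam(y ~ s(x), data = d)

# Deviance explained: share of the null deviance removed by the model
dev_expl_manual <- 1 - m1$deviance / m1$null.deviance

# Gaussian case: deviance = residual sum of squares, so this is the
# ordinary (unadjusted) R-squared
r2_unadj <- 1 - sum(residuals(m1)^2) / sum((d$y - mean(d$y))^2)

c(summary = summary(m1)$dev.expl, manual = dev_expl_manual, r2 = r2_unadj)
```

All three values should agree up to floating-point error; for non-Gaussian families only the first two would.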

We can then use this model to make predictions across the range of our X variable, and re-create the first plot with the model predictions and their associated 95% confidence interval.

# Predict new data

model_p <- predict_gam(m1)

# Re-plot

model_p %>%
  ggplot(aes(x, fit)) +
  geom_smooth_ci() +
  geom_point(data = d, aes(x = x, y = y), size = 0.8, colour = "steelblue2") +
  labs(title = "Raw data with model-predicted 95% confidence interval",
       subtitle = paste0("Model deviance explained: ",dev.expl,"%"),
       x = "X",
       y = "Y") +
  theme_bw() +
  theme(panel.grid.minor = element_blank())
[Figure: Raw data with model-predicted 95% confidence interval]
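`predict_gam()` and `geom_smooth_ci()` come from the tidymv package. If you would rather rely on mgcv and ggplot2 alone, a roughly equivalent plot can be built from `predict()` with `se.fit = TRUE`. This is a sketch under a few assumptions: the 100-point prediction grid is an arbitrary choice, and the band is a pointwise fit ± 1.96 standard errors, which can be read directly on the response scale here only because a Gaussian GAM uses the identity link.

```r
library(mgcv)
library(ggplot2)

# Re-create the simulated data and model so this sketch runs on its own
set.seed(123)
d <- data.frame(y = cumsum(rnorm(1000, mean = 0, sd = 1)))
d$x <- seq_len(nrow(d))
m1 <- gam(y ~ s(x), data = d)

# Prediction grid spanning the observed range of x (100 points: arbitrary)
grid <- data.frame(x = seq(min(d$x), max(d$x), length.out = 100))

# Fitted values and standard errors (link scale = response scale for identity link)
p <- predict(m1, newdata = grid, se.fit = TRUE)
grid$fit   <- as.numeric(p$fit)
grid$lower <- grid$fit - 1.96 * as.numeric(p$se.fit)
grid$upper <- grid$fit + 1.96 * as.numeric(p$se.fit)

# Raw points, approximate 95% band, and fitted curve
plt <- ggplot(grid, aes(x = x)) +
  geom_point(data = d, aes(x = x, y = y), size = 0.8, colour = "steelblue2") +
  geom_ribbon(aes(ymin = lower, ymax = upper), alpha = 0.3) +
  geom_line(aes(y = fit)) +
  labs(title = "Raw data with model-predicted 95% confidence interval",
       x = "X",
       y = "Y") +
  theme_bw() +
  theme(panel.grid.minor = element_blank())

# Draw the plot
plt
```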

With default settings and no parameter tuning, almost 80% deviance explained is not bad at all, especially considering how interpretable a GAM is compared to many common machine learning models: the fitted smooth can be plotted and inspected directly. Such is the power of GAMs!
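To put that figure in context, one option (not part of the original tutorial) is to compare it against a straight-line fit of the same data and to look at the smooth's effective degrees of freedom (EDF). The sketch re-simulates the same data. Because mgcv's default smooth contains straight lines as an unpenalised special case, the smooth model should never explain less deviance than the linear one, and an EDF well above 1 indicates a clearly non-linear fit.

```r
library(mgcv)

# Re-create the simulated random walk
set.seed(123)
d <- data.frame(y = cumsum(rnorm(1000, mean = 0, sd = 1)))
d$x <- seq_len(nrow(d))

# Linear baseline: same interface, but x enters as a straight line
m0 <- gam(y ~ x, data = d)
# Smooth model from the tutorial
m1 <- gam(y ~ s(x), data = d)

# Deviance explained (%) for each model
dev_expl <- c(linear = summary(m0)$dev.expl, smooth = summary(m1)$dev.expl)
round(dev_expl * 100, 2)

# Effective degrees of freedom of s(x); 1 would mean a straight line
summary(m1)$edf
```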