When in Doubt, Use GAMs!
Welcome to Short & Taught: the series that shows you how to apply interesting analysis techniques in as little time as possible.
Premise
You are a data analyst for a consulting firm. Your client has just handed you a spreadsheet of bivariate data: an outcome of interest and one related variable. They want you to make sense of the relationship between the two. You open the data file, plot it, and see that the relationship is highly nonlinear. What do you do?
There are plenty of approaches you could take here. We are going to explore a very underrated and underused one: generalised additive models (GAMs).
GAMs
GAMs are an extremely flexible extension of generalised linear models in which the outcome (on the scale of the link function) is modelled as a sum of unknown smooth functions of the predictors, each estimated from the data. Users have control over a great number of parameters associated with this smoothing, making GAMs well suited to nonlinear, spatial, and many other kinds of data.
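Written out, a GAM keeps the link function $g$ and response distribution of a GLM but replaces each linear term $\beta_j x_{ji}$ with a smooth function $f_j(x_{ji})$. In mgcv each $f_j$ is represented by a spline basis of size $K_j$, and its wiggliness is controlled by a penalty weighted by a smoothing parameter $\lambda_j$, which is estimated from the data (for example by GCV or REML). The penalised least-squares line below is the Gaussian, identity-link case used in this tutorial, with the second-derivative penalty that applies to one-dimensional thin plate and cubic splines:

```latex
\begin{aligned}
g\big(\mathbb{E}[y_i]\big) &= \beta_0 + \sum_{j=1}^{p} f_j(x_{ji}), \\
f_j(x) &= \sum_{l=1}^{K_j} \beta_{jl}\, b_{jl}(x), \\
\hat{\beta} &= \arg\min_{\beta} \; \sum_{i=1}^{n} \Big(y_i - \beta_0 - \sum_{j=1}^{p} f_j(x_{ji})\Big)^2 + \sum_{j=1}^{p} \lambda_j \int f_j''(x)^2 \, dx .
\end{aligned}
```

With one predictor, as in the model fitted below, this reduces to $y_i = \beta_0 + f(x_i) + \varepsilon_i$: a large $\lambda$ pushes $f$ towards a straight line, and a small $\lambda$ lets it follow the data more closely.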
Let's code one up!
Setting up the R code
First, we'll load the libraries we need for this tutorial, simulate some noisy data, and plot it.
```r
library(dplyr)    # data manipulation and the %>% pipe
library(ggplot2)  # plotting
library(mgcv)     # gam() and s()
library(tidymv)   # predict_gam() and geom_smooth_ci(), used below

# Simulate data: a 1,000-step Gaussian random walk indexed by x
set.seed(123)
d <- data.frame(y = cumsum(rnorm(1000, mean = 0, sd = 1))) %>%
  mutate(x = row_number())

# Draw summary plot
d %>%
  ggplot(aes(x = x, y = y)) +
  geom_point(size = 0.8, colour = "steelblue2") +
  labs(title = "Raw data",
       x = "X",
       y = "Y") +
  theme_bw() +
  theme(panel.grid.minor = element_blank())
```
Now that we have an idea of what the data look like, we can fit a very basic model (no parameter tuning or model diagnostics are presented in this very short tutorial). We will also extract the model's deviance explained, the proportion of the null deviance accounted for by the model, which generalises R-squared to the non-Gaussian response families that GAMs support.
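```r
# Fit a GAM with a single smooth term of x
# (mgcv defaults: Gaussian family, thin plate regression spline basis)
m1 <- gam(formula = y ~ s(x), data = d)

# Extract deviance explained as a percentage, to two decimal places
dev.expl <- round(summary(m1)$dev.expl, digits = 4) * 100
```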
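To see where that number comes from, the sketch below recomputes it from two components stored on the fitted `gam` object, `null.deviance` and `deviance`. It re-simulates the tutorial's data (using base R instead of `mutate()`, an equivalent choice) so it runs on its own. Because `gam()` defaults to a Gaussian family with identity link, the deviance here is simply the residual sum of squares, so the result should coincide with the unadjusted R-squared.

```r
library(mgcv)

# Recreate the tutorial's simulated random walk and basic GAM
set.seed(123)
d <- data.frame(y = cumsum(rnorm(1000, mean = 0, sd = 1)))
d$x <- seq_len(nrow(d))
m1 <- gam(y ~ s(x), data = d)

# Deviance explained: share of the null (intercept-only) deviance removed by the model
dev_expl_manual <- (m1$null.deviance - m1$deviance) / m1$null.deviance

# Gaussian + identity link: deviance is the residual sum of squares,
# so this is the same quantity as the unadjusted R-squared
rsq_unadjusted <- 1 - sum(residuals(m1)^2) / sum((d$y - mean(d$y))^2)

print(c(summary = summary(m1)$dev.expl,
        manual  = dev_expl_manual,
        rsq     = rsq_unadjusted))
```

For other families (Poisson counts, binomial proportions) the same formula still applies, whereas the residual-sum-of-squares form of R-squared no longer matches the model's likelihood.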
We can then use this model to make predictions across the range of X and re-create the first plot with the model predictions and their 95% confidence interval.
```r
# Predict fitted values and standard errors over a grid of x values
model_p <- predict_gam(m1)

# Re-plot the raw data with the fitted curve and its 95% confidence band
model_p %>%
  ggplot(aes(x, fit)) +
  geom_smooth_ci() +
  geom_point(data = d, aes(x = x, y = y), size = 0.8, colour = "steelblue2") +
  labs(title = "Raw data with model-predicted 95% confidence interval",
       subtitle = paste0("Model deviance explained: ", dev.expl, "%"),
       x = "X",
       y = "Y") +
  theme_bw() +
  theme(panel.grid.minor = element_blank())
```
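`predict_gam()` and `geom_smooth_ci()` are convenience wrappers from tidymv. If that package is not available, a roughly equivalent figure can be sketched with mgcv's own `predict()` method (with `se.fit = TRUE`) and `ggplot2::geom_ribbon()`. This is an assumption-laden stand-in rather than a reproduction of tidymv's internals: the 100-point grid is an arbitrary choice, and the band is an approximate pointwise interval of fit ± 1.96 standard errors, computed on the link scale, which here is also the response scale because of the identity link.

```r
library(mgcv)
library(ggplot2)

# Recreate the tutorial's data and model so this block runs on its own
set.seed(123)
d <- data.frame(y = cumsum(rnorm(1000, mean = 0, sd = 1)))
d$x <- seq_len(nrow(d))
m1 <- gam(y ~ s(x), data = d)

# Grid of 100 evenly spaced x values spanning the observed range
new_d <- data.frame(x = seq(min(d$x), max(d$x), length.out = 100))

# Fitted values and their standard errors at each grid point
p <- predict(m1, newdata = new_d, se.fit = TRUE)

# Approximate pointwise 95% interval: fit +/- qnorm(0.975) * standard error
z <- qnorm(0.975)
pred <- data.frame(
  x     = new_d$x,
  fit   = as.numeric(p$fit),
  lower = as.numeric(p$fit - z * p$se.fit),
  upper = as.numeric(p$fit + z * p$se.fit)
)

p_plot <- ggplot(pred, aes(x = x)) +
  geom_point(data = d, aes(x = x, y = y), size = 0.8, colour = "steelblue2") +
  geom_ribbon(aes(ymin = lower, ymax = upper), alpha = 0.3) +
  geom_line(aes(y = fit)) +
  labs(title = "Raw data with model-predicted 95% confidence interval",
       x = "X",
       y = "Y") +
  theme_bw() +
  theme(panel.grid.minor = element_blank())

print(p_plot)
```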
With no detailed parameter specification, almost 80% deviance explained is not bad at all, especially given that a GAM stays far more interpretable than many common machine learning models: each smooth term can be plotted and inspected on its own. Such is the power of GAMs!
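"No detailed parameter specification" means the fit above relies on mgcv's defaults for the smooth: a thin plate regression spline basis (`bs = "tp"`) with basis dimension `k = 10`, and `gam()`'s default smoothing-parameter selection (`method = "GCV.Cp"`). The sketch below writes those choices out next to an alternative with illustrative, untuned values (`k = 30`, a cubic regression spline basis, REML). It also shows two quick routes to interpretation: the smooth's effective degrees of freedom (edf), where an edf of 1 corresponds to a straight line, and `plot()` of the estimated smooth.

```r
library(mgcv)

# Recreate the tutorial's simulated random walk so this block runs on its own
set.seed(123)
d <- data.frame(y = cumsum(rnorm(1000, mean = 0, sd = 1)))
d$x <- seq_len(nrow(d))

# Default smooth: thin plate regression spline (bs = "tp"), k = 10,
# smoothing parameter selected by GCV (method = "GCV.Cp")
m_default <- gam(y ~ s(x), data = d)

# Illustrative, untuned choices written out explicitly: larger basis (k = 30),
# cubic regression spline basis, smoothing parameter selected by REML
m_explicit <- gam(y ~ s(x, k = 30, bs = "cr"), data = d, method = "REML")

# Effective degrees of freedom of the smooth (edf = 1 is a straight line)
# alongside deviance explained, for each model
comparison <- data.frame(
  model    = c("default", "explicit"),
  edf      = c(summary(m_default)$s.table[1, "edf"],
               summary(m_explicit)$s.table[1, "edf"]),
  dev_expl = c(summary(m_default)$dev.expl,
               summary(m_explicit)$dev.expl)
)
print(comparison)

# Plot the estimated smooth f(x) with a shaded confidence band
plot(m_explicit, shade = TRUE)
```

The edf of a smooth cannot exceed `k - 1`, so an edf close to that ceiling hints that the basis may be too small; mgcv's `gam.check()` reports a more formal check of the basis dimension.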