11  Lasso

The goal of this chapter is to create an interactive data visualization that explains the Lasso, a machine learning model for regularized linear regression.

Chapter outline:

  • 11.1 Static plots of the coefficient regularization path
  • 11.2 Interactive visualization of the regularization path
  • 11.3 Re-design with moving tallrects
  • 11.4 Chapter summary and exercises

11.1 Static plots of the coefficient regularization path

We begin by loading the prostate cancer data set.

if(!requireNamespace("animint2data"))
  remotes::install_github("animint/animint2data")
Loading required namespace: animint2data
data(prostate, package="animint2data")
library(data.table)
print(prostate, topn=1, trunc.cols = TRUE)
        lcavol  lweight age       lbph svi       lcp gleason pgg45       lpsa
 1: -0.5798185 2.769459  50 -1.3862944   0 -1.386294       6     0 -0.4307829
---                                                                          
97:  3.4719665 3.974998  68  0.4382549   1  2.904165       7    20  5.5829322
1 variable not shown: [train]

The output above shows the first and last rows of the data table. The train column indicates a pre-defined split of the data. In this visualization, we will study the regularization path of the Lasso, for which we need a hold-out set to learn the optimal degree of regularization. We will name the two subsets subtrain (used to compute the linear model coefficients) and validation (used to select the regularization parameter, by minimizing prediction error on that set).

prostate[
, set := ifelse(train, "subtrain", "validation")
][, table(set)]
set
  subtrain validation 
        67         30 

We construct subtrain inputs x and outputs y using the code below.

input.cols <- c(
  "lcavol", "lweight", "age", "lbph",
  "svi", "lcp", "gleason", "pgg45")
prostate.inputs <- prostate[, ..input.cols]
is.subtrain <- prostate$set == "subtrain"
x <- as.matrix(prostate.inputs[is.subtrain])
head(x)
         lcavol  lweight age      lbph svi       lcp gleason pgg45
[1,] -0.5798185 2.769459  50 -1.386294   0 -1.386294       6     0
[2,] -0.9942523 3.319626  58 -1.386294   0 -1.386294       6     0
[3,] -0.5108256 2.691243  74 -1.386294   0 -1.386294       7    20
[4,] -1.2039728 3.282789  58 -1.386294   0 -1.386294       6     0
[5,]  0.7514161 3.432373  62 -1.386294   0 -1.386294       6     0
[6,] -1.0498221 3.228826  50 -1.386294   0 -1.386294       6     0
y <- prostate[is.subtrain, lpsa]
head(y)
[1] -0.4307829 -0.1625189 -0.1625189 -0.1625189  0.3715636  0.7654678

Below we fit the full path of lasso solutions using the lars package.

library(lars)
Loaded lars 1.3
fit <- lars(x,y,type="lasso")
fit$lambda
[1] 7.1939462 3.7172742 2.9403866 1.7305064 1.7002813 0.4933166 0.3711651
[8] 0.0403451

The lambda values along the path are not evenly spaced.
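
To check this, we can look at the gaps between consecutive lambda values (a quick sketch, not part of the original code):

# gaps between consecutive lambda values along the path
diff(fit$lambda)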

pred.nox <- predict(fit, type="coef")
beta <- scale(pred.nox$coefficients, FALSE, 1/fit$normx)
arclength <- rowSums(abs(beta))
path.list <- list()
for(variable in colnames(beta)){
  standardized.coef <- beta[, variable]
  path.list[[variable]] <- data.table::data.table(
    step=seq_along(standardized.coef),
    lambda=c(fit$lambda, 0),
    variable,
    standardized.coef,
    fraction=pred.nox$fraction,
    arclength)
}
path <- do.call(rbind, path.list)
variable.colors <- c(
  "#E41A1C", "#377EB8", "#4DAF4A", "#984EA3", "#FF7F00", "#FFFF33", 
  "#A65628", "#F781BF", "#999999")
library(animint2)
gg.lambda <- ggplot()+
  theme_bw()+
  theme(panel.margin=grid::unit(0, "lines"))+
  scale_color_manual(values=variable.colors)+
  geom_line(aes(
    lambda, standardized.coef, color=variable, group=variable),
    data=path)
gg.lambda

The plot above shows the entire lasso path: the optimal weights of the L1-regularized least squares regression problem, as a function of the regularization parameter lambda. The path begins at the least squares solution, lambda=0, on the left. It ends at the completely regularized intercept-only model on the right. To see the equivalence with the ordinary least squares solution, we add dots in the plot below.
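
For reference, the weights on the path solve the L1-regularized least squares problem (this formula is not shown in the original text; the exact scaling of lambda depends on the convention used by lars, and the intercept is handled separately by centering):

$$
\hat{\beta}(\lambda) = \operatorname*{arg\,min}_{\beta \in \mathbb{R}^p} \;\; \frac{1}{2} \sum_{i=1}^{n} \big(y_i - x_i^\top \beta\big)^2 + \lambda \sum_{j=1}^{p} |\beta_j|,
$$

so lambda=0 recovers the ordinary least squares solution, and a sufficiently large lambda shrinks every weight to zero.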

x.scaled <- with(fit, scale(x, meanx, normx))
lfit <- lm.fit(x.scaled, y)
lpoints <- data.table::data.table(
  variable=colnames(x),
  standardized.coef=lfit$coefficients,
  arclength=sum(abs(lfit$coefficients)))
gg.lambda+
  geom_point(aes(
    0, standardized.coef, color=variable),
    data=lpoints)
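
As a quick numeric check (a sketch, not part of the original code), the standardized coefficients at the last step of the lasso path should agree with the least squares coefficients plotted above:

# compare final lasso step with the ordinary least squares coefficients
rbind(
  lasso.end=beta[nrow(beta), ],
  least.squares=lfit$coefficients)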

In the plot below, we show the path as a function of the L1 norm (arclength), along with points on an evenly spaced grid of penalty fractions that we will use later for animation.

fraction <- sort(unique(c(
  seq(0, 1, l=21))))
pred.fraction <- predict(
  fit, prostate.inputs,
  type="coef", mode="fraction", s=fraction)
coef.grid.list <- list()
coef.grid.mat <- scale(pred.fraction$coefficients, FALSE, 1/fit$normx)
for(fraction.i in seq_along(fraction)){
  standardized.coef <- coef.grid.mat[fraction.i,]
  coef.grid.list[[fraction.i]] <- data.table::data.table(
    fraction=fraction[[fraction.i]],
    variable=colnames(x),
    standardized.coef,
    arclength=sum(abs(standardized.coef)))
}
coef.grid <- do.call(rbind, coef.grid.list)
ggplot()+
  theme_bw()+
  theme(panel.margin=grid::unit(0, "lines"))+
  scale_color_manual(values=variable.colors)+
  geom_line(aes(
    arclength, standardized.coef, color=variable, group=variable),
    data=path)+
  geom_point(aes(
    arclength, standardized.coef, color=variable),
    data=lpoints)+
  geom_point(aes(
    arclength, standardized.coef, color=variable),
    shape=21,
    fill=NA,
    size=3,
    data=coef.grid)

The plot above shows that the weights at the grid points are consistent with the lines that represent the entire path of solutions. The LARS algorithm quickly provides Lasso solutions for as many grid points as you like. More precisely, since LARS only computes the change-points of the piecewise linear path, its time complexity depends on the number of change-points rather than the number of grid points; coefficients at any grid point are obtained by linear interpolation between change-points.
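
Because of this interpolation, requesting a much denser grid of solutions costs essentially nothing extra. The sketch below (not part of the original code) predicts coefficients at 1001 evenly spaced fractions:

# coefficients at 1001 fractions, interpolated along the piecewise linear path
pred.dense <- predict(
  fit, type="coef", mode="fraction", s=seq(0, 1, l=1001))
dim(pred.dense$coefficients)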

11.2 Interactive visualization of the regularization path

In this section, we combine the lasso weight path with the subtrain/validation error plot. First, we compute a data table with one row per penalty value (fraction) and set.

pred.list <- predict(
  fit, prostate.inputs,
  mode="fraction", s=fraction)
residual.mat <- pred.list$fit - prostate$lpsa
squares.mat <- residual.mat * residual.mat
mean.error <- prostate[, data.table(
  fraction,
  mse=colMeans(squares.mat[.I, ]),
  arclength=rowSums(abs(coef.grid.mat))
), by=set]
print(mean.error, topn=2)
           set fraction       mse  arclength
 1:   subtrain     0.00 1.4370365  0.0000000
 2:   subtrain     0.05 1.2524384  0.9182159
---                                         
41: validation     0.95 0.5090004 17.4461020
42: validation     1.00 0.5212740 18.3643178

Note in the code above that we used the data.table special symbol .I, which contains the row indices of the current by=set group; we use those indices to compute the mse for each set. The table in the output above is used to plot the error curves below.
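
To see what .I does in isolation, consider the toy example below (not part of the analysis): within each group defined by by=, .I contains the row numbers of that group in the full table.

# .I gives the row numbers of each by= group in the full table
demo.dt <- data.table::data.table(set=c("a", "a", "b"), value=1:3)
demo.dt[, .(rows=paste(.I, collapse=",")), by=set]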

rect.width <- diff(mean.error$arclength[1:2])/2
addY <- function(dt, y){
  data.table::data.table(dt, y.var=factor(y, c("error", "weights")))
}
tallrect.dt <- coef.grid[variable==variable[1],]
gg.path <- ggplot()+
  theme_bw()+
  theme(panel.margin=grid::unit(0, "lines"))+
  theme_animint(width=300, rowspan=1)+
  facet_grid(y.var ~ ., scales="free")+
  ylab("")+
  scale_color_manual(values=variable.colors)+
  geom_line(aes(
    arclength, standardized.coef, color=variable, group=variable),
    data=addY(path, "weights"))+
  geom_line(aes(
    arclength, mse, linetype=set, group=set),
    data=addY(mean.error, "error"))+
  geom_tallrect(aes(
    xmin=arclength-rect.width,
    xmax=arclength+rect.width),
    clickSelects="arclength",
    alpha=0.5,
    data=tallrect.dt)
gg.path

Finally, we add a plot of residuals versus actual values.

lasso.res.list <- list()
for(fraction.i in seq_along(fraction)){
  lasso.res.list[[fraction.i]] <- data.table::data.table(
    observation.i=1:nrow(prostate),
    fraction=fraction[[fraction.i]],
    residual=residual.mat[, fraction.i],
    response=prostate$lpsa,
    arclength=sum(abs(coef.grid.mat[fraction.i,])),
    set=prostate$set)
}
lasso.res <- do.call(rbind, lasso.res.list)
hline.dt <- data.table::data.table(residual=0)
gg.res <- ggplot()+
  theme_bw()+
  geom_hline(aes(
    yintercept=residual),
    data=hline.dt,
    color="grey")+
  geom_point(aes(
    response, residual, fill=set, 
    key=observation.i),
    showSelected="arclength",
    shape=21,
    data=lasso.res)
gg.res

Below, we combine the ggplots above into a single animint. Clicking the first plot changes the regularization parameter, which updates the residuals shown in the second plot.

animint(
  gg.path,
  gg.res,
  duration=list(arclength=2000),
  time=list(variable="arclength", ms=2000))

11.3 Re-design with moving tallrects

The re-design below has two changes. First, you may have noticed that there are two different set legends in the previous animint (linetype=set in the first path plot, and fill=set in the second residual plot). It would be easier for the reader to decode if the set variable had just one mapping. So in the re-design below we remove the set mapping from the geom_point in the second plot, and instead encode set via a geom_segment with linetype=set, matching the linetype legend of the first plot.

Second, we have replaced the single tallrect in the first plot with two tallrects. The first tallrect has showSelected=arclength and shows the selected arclength as a grey rectangle. Since we specify a duration for the arclength variable, and the same key=1 value, we observe a smooth transition of the selected grey tallrect. The second tallrect has clickSelects=arclength, so clicking it changes the selected value of arclength. We specify another data set with more rows, and use named clickSelects/showSelected variables to indicate that arclength should also be used as a showSelected variable.

tallrect.show.list <- list()
for(a in tallrect.dt$arclength){
  is.selected <- tallrect.dt$arclength == a
  not.selected <- tallrect.dt[!is.selected]
  tallrect.show.list[[paste(a)]] <- data.table::data.table(
    not.selected, show.val=a, show.var="arclength")
}
tallrect.show <- do.call(rbind, tallrect.show.list)
animint(
  path=ggplot()+
    theme_bw()+
    theme(panel.margin=grid::unit(0, "lines"))+
    facet_grid(y.var ~ ., scales="free")+
    ylab("")+
    scale_color_manual(values=variable.colors)+
    geom_line(aes(
      arclength, standardized.coef, color=variable, group=variable),
      data=addY(path, "weights"))+
    geom_line(aes(
      arclength, mse, linetype=set, group=set),
      data=addY(mean.error, "error"))+
    geom_tallrect(aes(
      xmin=arclength-rect.width,
      xmax=arclength+rect.width,
      key=1),
      showSelected="arclength",
      alpha=0.5,
      data=tallrect.dt)+
    geom_tallrect(aes(
      xmin=arclength-rect.width,
      xmax=arclength+rect.width,
      key=paste(arclength, show.val)),
      clickSelects="arclength",
      showSelected=c("show.var"="show.val"),
      alpha=0.5,
      data=tallrect.show),
  res=ggplot()+
    theme_bw()+
    geom_hline(aes(
      yintercept=residual),
      data=hline.dt,
      color="grey")+
    guides(linetype="none")+
    geom_point(aes(
      response, residual, 
      key=observation.i),
      showSelected=c("set", "arclength"),
      shape=21,
      fill=NA,
      color="black",
      data=lasso.res)+
    geom_text(aes(
      3, 2.5, label=sprintf("L1 arclength = %.1f", arclength),
      key=1),
      size=15,
      showSelected="arclength",
      data=tallrect.dt)+
    geom_text(aes(
      0, ifelse(set=="subtrain", -2, -2.5),
      label=sprintf("%s error = %.3f", set, mse),
      key=set),
      size=15,
      showSelected=c("set", "arclength"),
      hjust=0,
      data=mean.error)+
    geom_segment(aes(
      response, residual,
      xend=response, yend=0,
      linetype=set,
      key=observation.i),
      showSelected=c("set", "arclength"),
      size=1,
      data=lasso.res),
  duration=list(arclength=2000),
  time=list(variable="arclength", ms=2000))

11.4 Chapter summary and exercises

We created a visualization of the Lasso machine learning model, which simultaneously shows the regularization path and error curves. Interactivity was used to show details for different values of the regularization parameter.

Exercises:

  • Re-make this data viz, including the same visual effect for the tallrects, using only one geom_tallrect. Hint: create another data set with expand.grid(arclength.click=arclength, arclength.show=arclength), as in the definition of the make_tallrect_or_widerect function.
  • Add another scatterplot that shows predicted values versus response, with a geom_abline in the background to indicate perfect prediction.
  • How would the error curves look if other train/validation splits were chosen? Perform 4-fold cross-validation and add a plot that can be used to select the test fold.

Next, Chapter 12 explains how to visualize the Support Vector Machine.