This repository has been archived by the owner on May 6, 2022. It is now read-only.
# OLS and lasso
## Load packages
```{r load_packages}
library(glmnet)
library(ggplot2)
```
## Load data
Load the `train_x_reg`, `train_y_reg`, `test_x_reg`, and `test_y_reg` variables that we defined in 02-preprocessing.Rmd for the OLS and lasso *regression* tasks.
```{r}
# Load the objects saved by 02-preprocessing.Rmd, including the regression
# training and test sets used below.
load("data/preprocessed.RData")
```
## Overview
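* LASSO (L1 penalty) = shrinks Beta coefficients toward zero and sets some of them exactly to zero, effectively removing predictors that are weakly related to Y
* RIDGE (L2 penalty) = shrinks Beta coefficients toward zero but never exactly to zero, so every predictor stays in the model
* ELASTICNET = a weighted combination of the LASSO and RIDGE penalties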
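As described in the glmnet vignette, glmnet fits all three with a single objective, where the mixing parameter alpha selects the penalty (alpha = 1 is the lasso, alpha = 0 is ridge, values in between are the elastic net):

$$\min_{\beta_0, \beta} \frac{1}{2n}\sum_{i=1}^{n} \left(y_i - \beta_0 - x_i^\top \beta\right)^2 + \lambda\left[\frac{1-\alpha}{2}\lVert \beta \rVert_2^2 + \alpha \lVert \beta \rVert_1\right]$$

The sketch below illustrates the difference on simulated data. The data, the seed, and the fixed lambda of 0.3 are arbitrary assumptions for illustration, not values from this workshop. It fits each penalty at the same lambda and counts how many coefficients are exactly zero.

```{r}
library(glmnet)

# Hypothetical simulated data: only x1 and x2 are related to y;
# x3 to x10 are pure noise.
set.seed(42)
n_sim = 200
p_sim = 10
x_demo = matrix(rnorm(n_sim * p_sim), nrow = n_sim,
                dimnames = list(NULL, paste0("x", 1:p_sim)))
y_demo = 3 * x_demo[, "x1"] - 2 * x_demo[, "x2"] + rnorm(n_sim)

# Fit each penalty at the same, arbitrarily chosen lambda.
lambda_demo = 0.3
penalty_coefs = sapply(c(lasso = 1, elastic_net = 0.5, ridge = 0), function(a) {
  as.matrix(coef(glmnet(x_demo, y_demo, family = "gaussian",
                        alpha = a, lambda = lambda_demo)))[, 1]
})
round(penalty_coefs, 3)

# Number of predictors (intercept excluded) whose coefficient is exactly zero.
(n_zero = colSums(penalty_coefs[-1, ] == 0))
```

Because lambda is fixed by hand here, the counts only illustrate the qualitative difference; in practice lambda is chosen by cross-validation, as with `cv.glmnet()` later in this notebook.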
Below is a refresher on ordinary least squares (OLS) linear regression, predicting age from all of the other variables.
```{r}
# Fit the regression model; lm() automatically adds an intercept term
fit = lm(train_y_reg ~ ., data = train_x_reg)
# View the output
summary(fit)
# Predict outcome for the test data
predicted = predict(fit, test_x_reg)
# Calculate mean-squared error
(mse_reg = mean((test_y_reg - predicted)^2))
# Root mean-squared error, in years: about 7.54 years when predicting age
sqrt(mse_reg)
```
Review "Challenge 0" in the Challenges folder for a refresher on how OLS regression works, and [see the yhat blog](http://blog.yhat.com/posts/r-lm-summary.html) for help interpreting the `summary()` output.
Linear regression is a useful introduction to machine learning, but in your research you might see warning messages after `predict()` about the [rank of your matrix](https://stats.stackexchange.com/questions/35071/what-is-rank-deficiency-and-how-to-deal-with-it). This happens when some predictors are exact linear combinations of others, or when there are more predictors than observations, so `lm()` cannot estimate a separate coefficient for every predictor.
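A minimal sketch of how rank deficiency arises, using hypothetical toy data (`toy`, `y_toy`, and `new_points` are made up for illustration): `x3` is built as `x1 + x2`, so `lm()` cannot estimate a separate coefficient for it and reports `NA`. Predicting at points that break that relationship triggers a rank-deficiency warning; the exact wording of the warning differs between R versions.

```{r}
# Hypothetical toy data: x3 is an exact linear combination of x1 and x2.
set.seed(1)
toy = data.frame(x1 = rnorm(20), x2 = rnorm(20))
toy$x3 = toy$x1 + toy$x2
y_toy = 1 + 2 * toy$x1 - toy$x2 + rnorm(20)

toy_fit = lm(y_toy ~ ., data = toy)
# Only 3 of the 4 coefficients (intercept, x1, x2, x3) can be estimated.
toy_fit$rank
# The aliased column's coefficient is reported as NA.
coef(toy_fit)

# New points where x3 != x1 + x2: predict() warns about the rank-deficient fit.
new_points = data.frame(x1 = c(0, 1), x2 = c(0, 1), x3 = c(5, -5))
predict(toy_fit, newdata = new_points)
```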
The lasso can help here: its penalty shrinks coefficients toward zero and sets some of them exactly to zero, removing predictors that are weakly associated with the outcome. Because glmnet expects a numeric matrix of predictors, use `as.matrix()` to convert the predictor data frame to a matrix.
Be sure to [read the glmnet vignette](https://web.stanford.edu/~hastie/Papers/Glmnet_Vignette.pdf).
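`as.matrix()` is enough as long as every predictor column is already numeric. A small sketch with a hypothetical data frame (`df_mixed` is made up) shows what happens otherwise: with a factor column, `as.matrix()` returns a character matrix, which is not a valid glmnet input, whereas base R's `model.matrix()` dummy-codes the factor into numeric columns.

```{r}
# Hypothetical data frame with one factor column.
df_mixed = data.frame(height = c(150, 165, 180),
                      weight = c(55, 70, 85),
                      group = factor(c("a", "b", "a")))

# Mixed column types: as.matrix() falls back to a character matrix.
m_char = as.matrix(df_mixed)
typeof(m_char)

# model.matrix() dummy-codes the factor; drop its intercept column
# because glmnet fits its own intercept.
m_num = model.matrix(~ ., data = df_mixed)[, -1]
typeof(m_num)
colnames(m_num)
```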
## Fit model
```{r}
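# Cross-validation (10 folds by default) over a sequence of lambda values;
# alpha = 1 selects the lasso penalty.
lasso = cv.glmnet(as.matrix(train_x_reg), train_y_reg, family = "gaussian", alpha = 1)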
```
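`cv.glmnet()` assigns observations to folds at random, so the cross-validation curve and the selected lambdas can shift between runs. Below is a sketch on hypothetical simulated data (`x_sim`, `y_sim`, and `fold_id` are illustrative names) that fixes the fold assignment with the `foldid` argument, confirms that refitting with the same folds reproduces the same curve, and shows that the result stores one mean CV error (`cvm`) and one standard error (`cvsd`) per lambda.

```{r}
library(glmnet)

# Hypothetical simulated data standing in for train_x_reg / train_y_reg.
set.seed(2022)
x_sim = matrix(rnorm(150 * 8), nrow = 150)
y_sim = 2 * x_sim[, 1] + x_sim[, 3] + rnorm(150)

# Assign each row to one of 10 folds once, then reuse that assignment.
fold_id = sample(rep(1:10, length.out = nrow(x_sim)))
cv_a = cv.glmnet(x_sim, y_sim, family = "gaussian", alpha = 1, foldid = fold_id)
cv_b = cv.glmnet(x_sim, y_sim, family = "gaussian", alpha = 1, foldid = fold_id)

# Same folds, same data: the two CV curves agree.
all.equal(cv_a$cvm, cv_b$cvm)

# One mean CV error and one standard error per lambda value.
c(n_lambda = length(cv_a$lambda), n_cvm = length(cv_a$cvm), n_cvsd = length(cv_a$cvsd))
c(lambda.min = cv_a$lambda.min, lambda.1se = cv_a$lambda.1se)
```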
## Investigate Results
Plot the cross-validated mean-squared error against log(lambda):
```{r}
plot(lasso)
# Help interpreting this plot: https://stats.stackexchange.com/questions/404795/interpretation-of-cross-validation-plot-for-lasso-regression
# Generate our own version, plotting lambda (not on the log scale) vs. cross-validated RMSE.
ggplot(data.frame(lambda = lasso$lambda, rmse = sqrt(lasso$cvm)),
       aes(x = lambda, y = rmse)) +
  geom_point() +
  theme_minimal()
```
> NOTE: log(lambda) = 0 means lambda = 1. In the `plot(lasso)` graph, the far right side is over-penalized: the penalty dominates the fit and shrinks most coefficients to zero. As log(lambda) becomes increasingly negative, lambda approaches zero and the lasso fit approaches the OLS solution.
```{r}
# And here is a plot of log(lambda) vs lambda.
ggplot(data.frame(log_lambda = log(lasso$lambda), lambda = lasso$lambda),
       aes(x = log_lambda, y = lambda)) +
  geom_point() +
  theme_minimal()
```
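The NOTE's last point can be checked directly. This sketch uses hypothetical simulated data (`x_small`, `y_small`) and a user-supplied lambda sequence ending at 0, and reports for each lambda the largest absolute difference between the lasso coefficients and the `lm()` coefficients. The tight convergence threshold (`thresh = 1e-12`) is an assumption chosen so that the lambda = 0 fit matches OLS to many digits.

```{r}
library(glmnet)

# Hypothetical simulated data with 3 predictors.
set.seed(7)
x_small = matrix(rnorm(100 * 3), nrow = 100,
                 dimnames = list(NULL, c("a", "b", "c")))
y_small = 1 + 0.5 * x_small[, "a"] - 1.5 * x_small[, "b"] + rnorm(100)

# OLS coefficients: intercept, a, b, c (same order as glmnet's coef()).
ols_coef = coef(lm(y_small ~ x_small))

# Lasso on a decreasing, user-supplied lambda sequence that ends at 0.
lambdas = c(1, 0.1, 0.01, 0)
lasso_path = glmnet(x_small, y_small, family = "gaussian", alpha = 1,
                    lambda = lambdas, thresh = 1e-12)
lasso_coef = as.matrix(coef(lasso_path))

# Largest absolute gap between each lasso fit and OLS; it shrinks as lambda -> 0.
gaps = apply(lasso_coef, 2, function(b) max(abs(b - ols_coef)))
names(gaps) = paste0("lambda=", lambdas)
gaps
```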
Plot each predictor's coefficient path across lambda values (`label = TRUE` labels each curve with its predictor's column number):
```{r}
plot(lasso$glmnet.fit, xvar = "lambda", label = TRUE)
```
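The coefficient-path plot has a numeric counterpart: each glmnet fit stores `df`, the number of nonzero coefficients (intercept excluded) at each lambda. A sketch on hypothetical simulated data (`x_path`, `y_path`); in this notebook the same counts should be available as `lasso$glmnet.fit$df`, and the `cv.glmnet()` result also stores them as `lasso$nzero`.

```{r}
library(glmnet)

# Hypothetical simulated data: 6 predictors, only the first 2 informative.
set.seed(3)
x_path = matrix(rnorm(120 * 6), nrow = 120)
y_path = 2 * x_path[, 1] - x_path[, 2] + rnorm(120)

fit_path = glmnet(x_path, y_path, family = "gaussian", alpha = 1)

# df: number of nonzero coefficients (intercept excluded) at each lambda.
path_size = data.frame(lambda = fit_path$lambda, nonzero = fit_path$df)
head(path_size)
tail(path_size)
```

The first row is always 0: glmnet starts its default sequence at the smallest lambda for which every coefficient is zero, then decreases lambda so predictors enter one by one.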
Show the lambda that results in the minimum cross-validated mean-squared error (MSE):
```{r}
lasso$lambda.min
```
Show the largest lambda whose cross-validated error is within [one standard error](https://stats.stackexchange.com/questions/80268/empirical-justification-for-the-one-standard-error-rule-when-using-cross-validat) of the minimum:
```{r}
lasso$lambda.1se
# Log scale versions:
log(c("log_min" = lasso$lambda.min, "log_1se" = lasso$lambda.1se))
```
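To see how the two lambdas are chosen, this sketch recomputes them from the stored cross-validation results on hypothetical simulated data (`x_cv`, `y_cv`, and `cv_fit` are illustrative). It applies the one-standard-error rule as glmnet documents it: take the minimum mean CV error, add that lambda's standard error, and pick the largest lambda whose mean CV error stays at or below the resulting threshold.

```{r}
library(glmnet)

# Hypothetical simulated data.
set.seed(11)
x_cv = matrix(rnorm(200 * 10), nrow = 200)
y_cv = x_cv[, 1] + 0.5 * x_cv[, 2] + rnorm(200)
cv_fit = cv.glmnet(x_cv, y_cv, family = "gaussian", alpha = 1)

# lambda.min: the lambda with the smallest mean cross-validated error.
i_min = which.min(cv_fit$cvm)
manual_min = cv_fit$lambda[i_min]

# One-standard-error rule: the largest lambda whose mean CV error is no more
# than the minimum error plus its standard error.
threshold = cv_fit$cvm[i_min] + cv_fit$cvsd[i_min]
manual_1se = max(cv_fit$lambda[cv_fit$cvm <= threshold])

c(manual_min = manual_min, glmnet_min = cv_fit$lambda.min,
  manual_1se = manual_1se, glmnet_1se = cv_fit$lambda.1se)
```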
Look at the coefficients for lambda.1se:
```{r}
(coef_1se = coef(lasso, s = "lambda.1se"))
```
Look at the coefficients for lambda.min:
```{r}
(coef_min = coef(lasso, s = "lambda.min"))
# Compare side-by-side, with one named column per lambda
cbind(lambda_1se = as.matrix(coef_1se)[, 1], lambda_min = as.matrix(coef_min)[, 1])
```
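Printing the full sparse coefficient matrix gets unwieldy with many predictors. A small hypothetical helper, `selected_predictors()`, lists only the predictors the lasso kept (nonzero coefficients, intercept excluded). It is shown here on simulated data (`x_sel`, `y_sel`, with `v1` and `v4` as the only true signals); in this notebook it could be called as `selected_predictors(lasso, "lambda.1se")`.

```{r}
library(glmnet)

# Hypothetical simulated data: only v1 and v4 are related to y.
set.seed(5)
x_sel = matrix(rnorm(150 * 8), nrow = 150,
               dimnames = list(NULL, paste0("v", 1:8)))
y_sel = 3 * x_sel[, "v1"] - 2 * x_sel[, "v4"] + rnorm(150)
cv_sel = cv.glmnet(x_sel, y_sel, family = "gaussian", alpha = 1)

# Hypothetical helper: coef() returns a sparse one-column matrix,
# so convert it to a named numeric vector and keep the nonzero entries.
selected_predictors = function(cv_object, s) {
  b = as.matrix(coef(cv_object, s = s))[, 1]
  b = b[names(b) != "(Intercept)"]
  names(b)[b != 0]
}

selected_predictors(cv_sel, "lambda.1se")
selected_predictors(cv_sel, "lambda.min")
```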
Predict on the test set, using lambda.1se:
```{r}
predictions = predict(lasso, newx = as.matrix(test_x_reg),
s = lasso$lambda.1se)
# How far off were we, based on absolute error?
rounded_errors = round(abs(test_y_reg - predictions))
table(rounded_errors)
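# Group the absolute error into 4 bins of roughly 5 years each
# (bin 3 collects everything beyond about 12.5 years).
grouped_errors = round(abs(test_y_reg - predictions) / 5)
grouped_errors[grouped_errors > 2] = 3
table(grouped_errors)
# 4 categories of accuracy; listing all 4 levels avoids an error when a bin is empty
how_close = factor(grouped_errors, levels = 0:3,
                   labels = c("very close", "close", "meh", "far"))
table(rounded_errors, how_close)
# Scatter plot of actual vs. predicted
ggplot(data.frame(actual = test_y_reg, predicted = as.numeric(predictions),
                  how_close = how_close),
       aes(x = actual, y = predicted, color = how_close)) +
  geom_point(size = 5, alpha = 0.7) +
  theme_minimal()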
```
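The binning above relies on `round()`, and R's `round()` rounds exact halves to the even digit (`round(0.5)` is 0 but `round(1.5)` is 2). Errors that land exactly on 2.5, 7.5, or 12.5 years are therefore assigned inconsistently. Below is a sketch of an alternative using `cut()` with explicit, left-closed breaks, applied to hypothetical error values chosen to sit on and near the bin edges.

```{r}
# Hypothetical absolute errors, chosen to sit on and near the bin edges.
abs_errors = c(0, 2.4, 2.5, 7.4, 7.5, 12.4, 12.5, 30)

# round() rounds halves to even, so 2.5 falls in bin 0 but 7.5 jumps to bin 2:
round(abs_errors / 5)

# cut() with explicit breaks: [0, 2.5), [2.5, 7.5), [7.5, 12.5), [12.5, Inf).
bins = cut(abs_errors, breaks = c(0, 2.5, 7.5, 12.5, Inf), right = FALSE,
           labels = c("very close", "close", "meh", "far"))
table(bins)
data.frame(abs_errors, bins)
```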
Calculate MSE and RMSE:
```{r}
# Calculate mean-squared error.
mean((predictions - test_y_reg)^2)
# Calculate root mean-squared error.
sqrt(mean((predictions - test_y_reg)^2))
```
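If you compute these metrics repeatedly (for example, to compare the OLS and lasso test errors in this notebook), small helper functions keep the formulas in one place. `mse()` and `rmse()` are hypothetical helpers, not part of glmnet; `as.numeric()` flattens the one-column matrix that `predict()` returns for glmnet models so both arguments are plain vectors. The toy values below are made up.

```{r}
# Hypothetical helpers for mean-squared and root mean-squared error.
mse = function(actual, predicted) {
  mean((as.numeric(actual) - as.numeric(predicted))^2)
}
rmse = function(actual, predicted) sqrt(mse(actual, predicted))

# Toy values: errors of 3, 4, and 0 years.
toy_actual = c(30, 45, 60)
toy_predicted = matrix(c(33, 41, 60), ncol = 1)
mse(toy_actual, toy_predicted)
rmse(toy_actual, toy_predicted)
```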
## Challenge 1
Open Challenge 1 in the "Challenges" folder.