---
title: "Supervised ML in clinical applications"
subtitle: "EXMD 601 McGill University"
author: "Alton Russell"
date: "27 March 2024"
format: revealjs
editor: visual
---
## R packages used
```{r}
#| echo: true
# install.packages(c("tidyverse", "tidymodels", "ranger", "caret", "vip"))
library(tidyverse)
library(tidymodels)
library(ranger)
library(caret) #for calibration plot
library(vip) #for variable importance plot
theme_set(theme_bw())
```
## Agenda
- **Types of supervised learning**
- Model development and selection
- Clinically useful models
## Supervised learning in a nutshell
- **Training:** Learn to predict on **labeled examples**
- Model maps features (covariates) to outcome (label)
- Can be complex, non-linear functions
- **Deployment:** Generate prediction for new examples
![](stroke_neural_net_example.png)
## ML vs. statistical models
::: columns
::: {.column width="60%"}
- Learn relationship between variables from data (not pre-specified)
- Allow complex non-linear relationships
- Less interpretable
- Fewer theoretical guarantees
- Generally, not suited for answering causal questions
:::
::: {.column width="40%"}
![](XKCD_ml.png)
:::
:::
## Regression
::: columns
::: {.column width="50%"}
- Predict a **continuous** value (hemoglobin level, length of stay)
- Statistics analog: linear regression
:::
::: {.column width="50%"}
![](regression_example.png)
:::
:::
## "Classification" or risk prediction
::: columns
::: {.column width="55%"}
- Predict a **categorical** outcome/event (death, recurrence, rehospitalization)
- Statistics analog: logistic regression
- Can be binary (cancer/no cancer) or multiclass (which bacteria is causing the urinary tract infection)
:::
::: {.column width="45%"}
![](classification_plot.png)
:::
:::
## Classifying vs. predicting risk
- Strict classification returns the most likely outcome
- "Patient will get pneumonia"
- Ignores uncertainty 😔
- Estimated risk \>\> strict classification
- 51% vs. 99% chance of pneumonia can have very different clinical implications
- [**Always**]{.underline} question classification models that don't provide ***probabilistic*** estimates
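## Sketch: class vs. probability output
A minimal sketch of the difference, using a simple logistic regression on R's built-in `mtcars` data rather than a clinical dataset; the 0.5 cutoff is an arbitrary illustrative choice.
```{r}
#| echo: true
# Logistic regression on built-in data (illustration only, not the CTG example)
fit <- glm(am ~ wt, data = mtcars, family = binomial)
new_car <- data.frame(wt = 2.8)
p_hat <- predict(fit, new_car, type = "response") # estimated probability
p_hat
ifelse(p_hat > 0.5, "manual", "automatic") # strict class label discards the uncertainty
```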
## Visualizing estimated risk
![](pdp-cervical-2d.jpeg)
[Interpretable Machine Learning by Molnar](https://christophm.github.io/interpretable-ml-book/) (partial dependence plot showing probability of cervical cancer given the interaction between age and number of pregnancies)
## Computer vision
::: columns
::: {.column width="50%"}
- Analyze pixel data
- Often, the goal is to outline objects and assign a label
- Person, car or tree
- Abnormality, tumor
- Used for X-ray, ultrasound, microscopy images
:::
::: {.column width="50%"}
![](computer_vision_xray.PNG)
[Kundu et al. 2021](https://doi.org/10.1371/journal.pone.0256630)
:::
:::
## Generative AI/Large language models
::: columns
::: {.column width="70%"}
![](LLM_model_tuning.png)
:::
::: {.column width="30%"}
- Trained to mimic human text.
- Prone to bias and can 'hallucinate'!
- Can save time when outputs are vetted by an expert
:::
:::
[Thirunavukarasu et al. Nature Medicine 2023](https://doi.org/10.1038/s41591-023-02448-8)
## Agenda
- Types of supervised learning
- **Model development and selection**
- Clinically useful models
## Our example data[^1]
[^1]: Dataset: [Fetal health classification on Kaggle](https://www.kaggle.com/datasets/andrewmvd/fetal-health-classification). Images: [Thinkstock](https://www.babycenter.in/x1045384/what-is-cardiotocography-ctg-and-why-do-i-need-it); [geekymedics.com](https://geekymedics.com/how-to-read-a-ctg/)
Tabular data derived from **cardiotocograms (CTGs)** from 2126 pregnant patients.
Outcome: fetal health classified as normal vs. suspect or pathological.
::: columns
::: {.column width="50%"}
![](CTG_device.jpeg)
:::
::: {.column width="50%"}
![](CTG_output.jpeg)
:::
:::
## Our example data
```{r}
#| echo: true
dt <- read_csv("fetal_health.csv") |>
mutate(fetal_health = as.factor(ifelse(fetal_health==1,"Normal", "Abnormal")))
str(dt)
```
## Model development must avoid over-/under-fitting
![](overfitting_classifier.jpeg)
[raaicode.com](https://www.raaicode.com/what-is-overfitting/)
## Holding out a test set
- At start, set aside some data as test set
- Use the remaining data to train model
- At the [**very end**]{.underline}, run [**only one**]{.underline} [**final model**]{.underline} on the test set for an unbiased estimate of performance
![](train_test_split.png)
## Splitting with rsample
```{r}
#| echo: true
set.seed(125) #for reproducibility
ctg_split <- initial_split(dt, prop = 0.8)
ctg_split
# Distribution of outcome in the training set
table(training(ctg_split)$fetal_health)/nrow(training(ctg_split))
# Distribution of outcome in the test set
table(testing(ctg_split)$fetal_health)/nrow(testing(ctg_split))
```
## Stratifying data split
Stratifying is easy and ensures data in each split are representative
```{r}
#| echo: true
# Split while stratifying on outcome
set.seed(456)
ctg_split_strat <- initial_split(dt, prop = 0.8, strata = fetal_health)
# Distribution of outcome in the training set
table(training(ctg_split_strat)$fetal_health)/nrow(training(ctg_split_strat))
# Distribution of outcome in the test set
table(testing(ctg_split_strat)$fetal_health)/nrow(testing(ctg_split_strat))
```
## Model selection
- If trying just one model configuration, use full training data
- Usually, want to select best model from several **configurations** (algorithm + hyperparameters)
- Algorithm: type of model (e.g., random forest)
- Hyperparameter: 'setting' of that model (e.g., minimum node size)
- Further steps needed to avoid biased model selection in training data
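## Sketch: fitting a single configuration
If we only wanted one configuration, a minimal sketch (reusing the stratified split from earlier) could look like this; the 500 trees and default `mtry`/`min_n` are arbitrary illustrative choices, not a recommended setting.
```{r}
#| echo: true
# One random forest with fixed hyperparameters, trained on the full training set
# (no tuning, so no cross validation is needed)
set.seed(456)
rf_single <- workflow() |>
  add_recipe(recipe(fetal_health ~ ., data = training(ctg_split_strat))) |>
  add_model(rand_forest(trees = 500, mode = "classification") |> set_engine("ranger")) |>
  fit(data = training(ctg_split_strat))
rf_single
```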
## Cross validation
![](cross_validation_overview.png)
[Statology](https://www.statology.org/validation-set-vs-test-set/)
## Cross validation
Randomly assign each example to a 'fold'
![](three-CV.svg)
[workshops.tidymodels.org](https://workshops.tidymodels.org)
## Cross validation
::: columns
::: {.column width="75%"}
![](three-CV-iter.svg)
:::
::: {.column width="25%"}
Select the model with the best **average** performance across the CV folds
:::
:::
[workshops.tidymodels.org](https://workshops.tidymodels.org)
## Split training set for cross validation
Multiple repeats of CV can further protect against choosing a model that 'got lucky'.
```{r}
#| echo: true
set.seed(456)
ctg_folds <- vfold_cv(training(ctg_split_strat),
                      v = 3,                  # three folds
                      strata = fetal_health,  # stratified on outcome
                      repeats = 2)            # two repeats
ctg_folds
```
## Evaluating models
- Performance metrics quantify how 'good' the model is for a given dataset
- Perfectly predicting outcome = perfect performance
- Error metrics: smaller is better (0 error = perfect)
- Positive metrics: bigger is better
- Metrics differ for regression vs. classification
## Regression metrics
::: columns
::: {.column width="50%"}
For each example, the error is the difference between the actual outcome $y$ and the predicted outcome $\hat{y}$
- Mean absolute error: $\frac{1}{n} \sum \mid y - \hat{y} \mid$
- Mean squared error: $\frac{1}{n} \sum [y - \hat{y}]^2$
- Root mean squared error: $\sqrt{\frac{1}{n} \sum [y - \hat{y}]^2}$
:::
::: {.column width="50%"}
![](rmse_vs_mae.webp)
:::
:::
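## Sketch: regression metrics in R
A minimal sketch on made-up numbers (the CTG outcome is categorical, so this is illustration only); `mae()` and `rmse()` are yardstick functions loaded with tidymodels.
```{r}
#| echo: true
# Toy actual vs. predicted values (made-up numbers)
toy_reg <- tibble(y = c(10, 12, 9, 14), y_hat = c(11, 11, 10, 17))
mean(abs(toy_reg$y - toy_reg$y_hat))       # mean absolute error
sqrt(mean((toy_reg$y - toy_reg$y_hat)^2))  # root mean squared error
mae(toy_reg, truth = y, estimate = y_hat)  # same via yardstick
rmse(toy_reg, truth = y, estimate = y_hat)
```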
## Classification metrics: confusion matrix-based
::: columns
::: {.column width="65%"}
**After dichotomizing predicted risk**, you can create a confusion matrix
Many metrics are derived from it
- Accuracy: $\frac{TP+TN}{TP+TN+FP+FN}$
- Precision: $\frac{TP}{TP+FP}$
- Recall: $\frac{TP}{TP+FN}$
:::
::: {.column width="35%"}
![](confusion_matrix.webp)
[Harikrishnan N B Medium](https://medium.com/analytics-vidhya/confusion-matrix-accuracy-precision-recall-f1-score-ade299cf63cd#:~:text=What%20is%20the%20accuracy%20of,the%20accuracy%20will%20be%2085%25.)
:::
:::
::: callout-caution
## Often, we do not want to dichotomize before selecting a model
:::
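## Sketch: confusion matrix metrics in R
A minimal toy sketch with made-up labels; the `yardstick::` prefixes are there because caret (loaded above) also exports `precision()` and `recall()`.
```{r}
#| echo: true
# Made-up dichotomized predictions and true outcomes
toy_cls <- tibble(
  truth = factor(c("Abnormal", "Abnormal", "Normal", "Normal", "Normal", "Abnormal")),
  pred  = factor(c("Abnormal", "Normal",   "Normal", "Normal", "Abnormal", "Abnormal"))
)
conf_mat(toy_cls, truth = truth, estimate = pred)
yardstick::accuracy(toy_cls, truth = truth, estimate = pred)
yardstick::precision(toy_cls, truth = truth, estimate = pred) # first level ("Abnormal") is the event
yardstick::recall(toy_cls, truth = truth, estimate = pred)
```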
## Area under the ROC curve
::: columns
::: {.column width="60%"}
Measure of **discrimination**: higher when examples with a positive outcome are assigned higher risk scores
- -,-,-,-,-,+,+,+,+ $\rightarrow$ AUC = 1 (perfectly discriminates +'s from -'s)
- -,+,-,+,-,+,-,+,-,+ $\rightarrow$ AUC $\approx$ 0.5 (little better than chance)
:::
::: {.column width="40%"}
![](ROC_curve.png)
:::
:::
Does not measure **calibration** (how closely the predicted probabilities match actual risk)
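## Sketch: AUC on toy orderings
A minimal toy sketch of how orderings like those above translate into AUC, using yardstick's `roc_auc_vec()`; the labels and risk scores are made up.
```{r}
#| echo: true
# Eight made-up examples ordered by increasing risk score
risk <- (1:8) / 10
separated   <- factor(c("-", "-", "-", "-", "+", "+", "+", "+"), levels = c("+", "-"))
interleaved <- factor(c("+", "-", "-", "+", "+", "-", "-", "+"), levels = c("+", "-"))
roc_auc_vec(separated, risk)   # 1: positives always scored above negatives
roc_auc_vec(interleaved, risk) # 0.5: no discrimination
```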
## Performance in train, test, validate data
- Performance on [**training data**]{.underline} optimistic due to overfitting
- Use for nothing
- Top model performance on **validation data** (e.g., cross validation folds) can still be optimistic if many configurations are assessed
- Use for selection only
- Performance on [**test data**]{.underline} only unbiased measure of performance
## [tidymodels](https://www.tidymodels.org/find/parsnip/)
![](parsnip_models.png)
## Random forest model
- Ensembles (combines) many decision trees
- To develop each tree, model randomly selects:
- Which training examples to use (bootstrapping)
- Random subset of covariates to consider for each branch
- Tree 'votes' are counted to estimate probability of outcome
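## Sketch: the same knobs in ranger
Before the tidymodels interface on the following slides, a minimal sketch of those ingredients called directly through `ranger()` on R's built-in `iris` data; the settings are arbitrary illustrative choices.
```{r}
#| echo: true
# Each tree is grown on a bootstrap sample of rows, and a random subset of
# covariates (mtry) is considered at each split; probability = TRUE returns
# the share of tree 'votes' as class probabilities
set.seed(456)
rf_sketch <- ranger(Species ~ ., data = iris,
                    num.trees = 500,    # trees in the ensemble
                    mtry = 2,           # covariates considered at each split
                    min.node.size = 5,  # stop splitting small nodes
                    probability = TRUE)
rf_sketch$prediction.error  # out-of-bag error (Brier score for probability forests)
```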
## Decision tree example
![](tree-example.svg)
Outcome: hockey shot on goal. [workshops.tidymodels.org](https://workshops.tidymodels.org)
## Random forest in tidymodels
![](rand_forest_parsnip.png)
## Random forest arguments
![](parsnip_RF_arguments.png)
## Tuning random forest model
```{r}
#| echo: true
# Specify a recipe (prediction task as formula;
# can also include preprocessing)
ctg_recipe <- recipe(fetal_health ~ ., data = dt)
#specify random forest model for tuning
rf_tune_spec <- rand_forest(mtry = tune(),
                            trees = 1000,
                            min_n = tune(),
                            mode = "classification")
rf_tune_spec
```
## Grid of hyperparameter settings
```{r}
#| echo: true
# Create grid of hyperparameters for tuning
rf_grid <- grid_regular(
  mtry(range = c(5, 15)),  # number of predictors sampled at each split of a tree
  min_n(range = c(2, 8)),  # minimum datapoints in a node for a further split
  levels = 3
)
rf_grid
```
## Develop & select random forest model
```{r}
#| echo: true
# Tune the model
set.seed(456)
rf_tune_results <- tune_grid(
  rf_tune_spec,
  ctg_recipe,
  resamples = ctg_folds,
  grid = rf_grid
)
```
- 2 repeats of 3-fold cross validation
- Grid search over \[5, 10, 15\] randomly-selected predictors at each branch (`mtry`) and \[2, 5, 8\] as minimum node size (`min_n`)
How many random forest models will we train before selecting the top model?
## Compare AUC by hyperparameter setting
```{r}
#| echo: true
autoplot(rf_tune_results)
```
## Evaluate top configuration in test set
```{r}
#| echo: true
show_best(rf_tune_results, metric="roc_auc", n=3)
best_auc <- select_best(rf_tune_results, metric = "roc_auc")
# Specify a model with the best hyperparameters
rf_best_spec <- rand_forest(mtry = best_auc$mtry,
                            trees = 1000,
                            min_n = best_auc$min_n,
                            mode = "classification") |>
  set_engine("ranger", importance = "impurity")
# Train the top configuration on the full training set; predict on the test set
rf_test_results <- last_fit(
  rf_best_spec,
  ctg_recipe,
  split = ctg_split_strat)
```
## Predict on test set
```{r}
#| echo: true
# Estimate unbiased performance on test set
rf_test_results %>% collect_metrics()
# Compare predicted risk to actual outcome
preds <- predict(extract_workflow(rf_test_results), testing(ctg_split_strat), type = "prob")
dt_pred_outcome <- cbind(preds,
                         truth = testing(ctg_split_strat)$fetal_health)
head(dt_pred_outcome, 5)
```
## Plot the ROC curve
::: columns
::: {.column width="50%"}
```{r}
#| echo: true
roc <- roc_curve(dt_pred_outcome,
                 truth,
                 .pred_Abnormal)
head(roc, 9)
#autoplot(roc)
```
:::
::: {.column width="50%"}
```{r}
#| echo: false
#| fig-width: 5
#| fig-height: 5
autoplot(roc)
```
:::
:::
## Assess calibration
```{r}
#| echo: true
calibration_obj <- caret::calibration(truth ~ .pred_Abnormal,
                                      data = dt_pred_outcome,
                                      cuts = 6)
ggplot(calibration_obj)
```
::: columns
::: {.column width="60%"}
```{r}
#| echo: false
ggplot(calibration_obj)
```
:::
::: {.column width="40%"}
- For a perfectly calibrated model, the bin midpoints fall on the diagonal line
- Interpretation: risk predictions above \~12% may be overestimates
:::
:::
## Plot variable importance
```{r}
#| echo: true
rf_test_results |>
extract_fit_parsnip() |>
vip(num_features = 10)
```
## Model development takeaways
- Separate data for training, validation/comparison/selection, and testing to avoid over-/under-fitting
- An unbiased performance estimate comes from test data not used for training or selection
- In most applied projects, rigorous model development \>\>\> in-depth understanding of algorithms
## Agenda
- Types of supervised learning
- Model development and selection
- **Clinically useful models**
## Switching to Powerpoint
<https://github.com/altonrus/EMD601_ML_guest_lecture/blob/master/clinically_useful_ML_prediction.pptx>