@@ -1,5 +1,15 @@
 {
  "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -50,7 +60,7 @@
    "outputs": [],
    "source": [
     "n_samples = 2000\n",
-    "n_classes = 2"
+    "n_classes = 3"
    ]
   },
   {
@@ -68,7 +78,7 @@
     "    random_state=42,\n",
     ")\n",
     "X_train, X_test, y_train, y_test = train_test_split(\n",
-    "    X, y, test_size=0.2, random_state=42\n",
+    "    X, y, test_size=0.5, random_state=42\n",
     ")"
    ]
   },
@@ -85,7 +95,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "model = MLPClassifier(hidden_layer_sizes=(50, 50, 50))\n",
+    "model = MLPClassifier(hidden_layer_sizes=(20, 20, 10))\n",
     "model.fit(X_train, y_train)"
    ]
   },
@@ -181,7 +191,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "ece = ECE(bins=10)"
+    "ece = ECE(bins=12)"
    ]
   },
   {
@@ -191,9 +201,9 @@
    "outputs": [],
    "source": [
     "# Evaluate uncalibrated predictions\n",
-    "uncalibrated_confidences = model.predict_proba(X_test)\n",
+    "y_pred = model.predict_proba(X_test)\n",
     "\n",
-    "pre_calibration_ece = ece.compute(uncalibrated_confidences, y_test)\n",
+    "pre_calibration_ece = ece.compute(y_pred, y_test)\n",
     "\n",
     "f\"ECE before calibration: {pre_calibration_ece}\""
    ]
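Aside: the ECE class used above is not shown in this diff, but the quantity it estimates is the standard binned expected calibration error. A minimal sketch of that computation, assuming the usual uniform-binning definition (illustrative only, not this library's implementation):

import numpy as np

def expected_calibration_error(confidences, y_true, n_bins=12):
    # Confidence and correctness of the top predicted class per sample.
    top_conf = confidences.max(axis=1)
    correct = (confidences.argmax(axis=1) == y_true).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (top_conf > lo) & (top_conf <= hi)
        if in_bin.any():
            # Weight each bin's |accuracy - mean confidence| gap by its share of samples.
            ece += in_bin.mean() * abs(correct[in_bin].mean() - top_conf[in_bin].mean())
    return ece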
@@ -212,10 +222,48 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "eval_stats = EvalStats(y_test, uncalibrated_confidences)\n",
-    "class_labels = [i for i in range(n_classes)]\n",
-    "\n",
-    "eval_stats.plot_reliability_curves(class_labels)"
+    "eval_stats = EvalStats(y_test, y_pred)\n",
+    "class_labels = range(n_classes)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fig = eval_stats.plot_reliability_curves(\n",
+    "    [\"top_class\", 0], display_weights=True, strategy=\"uniform\", n_bins=8\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The density of predictions is distributed highly inhomogeneously on the unit interval: some bins have\n",
+    "few members, and the reliability estimate there has high variance. This can be mitigated by employing\n",
+    "the \"quantile\" binning strategy, also called adaptive binning."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fig = eval_stats.plot_reliability_curves(\n",
+    "    [0, \"top_class\"], display_weights=True, n_bins=8, strategy=\"quantile\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now all bins have the same weight but different widths. The pointwise reliability estimates\n",
+    "have lower variance, but there are wide gaps between them, requiring more interpolation.\n",
+    "Both binning strategies have their advantages and disadvantages."
    ]
   },
   {
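The tradeoff described in the two markdown cells added above can be made concrete by comparing the bin edges the two strategies produce. A small illustrative sketch, using a skewed synthetic distribution as a stand-in for the notebook's confidences (the numbers here are assumptions, not outputs of this notebook):

import numpy as np

rng = np.random.default_rng(42)
conf = rng.beta(8, 2, size=2000)  # mass concentrated near 1.0, like typical confidences

n_bins = 8
uniform_edges = np.linspace(0.0, 1.0, n_bins + 1)                      # equal width
quantile_edges = np.quantile(conf, np.linspace(0.0, 1.0, n_bins + 1))  # equal mass

print(np.histogram(conf, bins=uniform_edges)[0])   # very uneven counts per bin
print(np.histogram(conf, bins=quantile_edges)[0])  # ~250 samples in every bin

Equal-width bins keep their positions fixed but can be nearly empty; equal-mass bins stabilize the per-bin estimate at the cost of very wide bins in sparse regions.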
@@ -455,38 +503,11 @@
    "source": [
     "ece.compute(confidences, ground_truth)"
    ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Once again, to verify that miscalibration will indeed increase with more samples, let's sample *5x* as many samples as\n",
-    "before and measure $ECE$ again:"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "uncalibrated_samples = shifted_sampler.get_sample_arrays(1000)\n",
-    "ground_truth, confidences = uncalibrated_samples\n",
-    "\n",
-    "ece.compute(confidences, ground_truth)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "Great! Calibration error goes up as we sample more instances."
-   ]
   }
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
@@ -500,9 +521,9 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.5"
+   "version": "3.8.13"
   }
  },
  "nbformat": 4,
  "nbformat_minor": 1
-}
+}