-
Notifications
You must be signed in to change notification settings - Fork 83
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #236 from Renumics/layout_templates
Layout templates
- Loading branch information
Showing
6 changed files
with
326 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .default import default | ||
from .model_debug import debug_classification | ||
from .model_compare import compare_classification | ||
|
||
__all__ = ["default", "debug_classification", "compare_classification"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from renumics.spotlight.layout import ( | ||
histogram, | ||
inspector, | ||
layout, | ||
scatterplot, | ||
similaritymap, | ||
split, | ||
tab, | ||
table, | ||
) | ||
from renumics.spotlight.layout.nodes import Layout | ||
|
||
|
||
def default() -> Layout: | ||
""" | ||
Default layout for spotlight. | ||
""" | ||
|
||
return layout( | ||
split( | ||
tab(table(), weight=60), | ||
tab(similaritymap(), scatterplot(), histogram(), weight=40), | ||
weight=60, | ||
), | ||
tab(inspector(), weight=40), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
from typing import Optional, Union, Dict, Any | ||
from renumics.spotlight import dtypes | ||
from renumics.spotlight import layout | ||
from renumics.spotlight.layout import ( | ||
Layout, | ||
Tab, | ||
Split, | ||
lenses, | ||
table, | ||
similaritymap, | ||
inspector, | ||
split, | ||
tab, | ||
metric, | ||
issues, | ||
confusion_matrix, | ||
) | ||
|
||
|
||
def compare_classification( | ||
label: str = "label", | ||
model1_prediction: str = "m1_prediction", | ||
model1_embedding: str = "", | ||
model1_correct: str = "", | ||
model2_prediction: str = "m2_prediction", | ||
model2_embedding: str = "", | ||
model2_correct: str = "", | ||
inspect: Optional[Dict[str, Any]] = None, | ||
) -> Layout: | ||
"""This function generates a Spotlight layout for comparing two different machine learning classification models. | ||
Args: | ||
label: Name of the column that contains the label. | ||
model1_prediction: Name of the column that contains the prediction for model 1. | ||
model1_embedding: Name of the column that contains thee embedding for model 1. | ||
model1_correct: Name of the column that contains a flag if the data sample is predicted correctly by model 1. | ||
model2_prediction: Name of the column that contains the prediction for model 2. | ||
model2_embedding: Name of the column that contains thee embedding for model 2. | ||
model2_correct: Name of the column that contains a flag if the data sample is predicted correctly by model 2. | ||
inspect: Name and type of the columns that are displayed in the inspector, e.g. {'audio': spotlight.dtypes.audio_dtype}. | ||
Returns: | ||
The configured layout for `spotlight.show`. | ||
""" | ||
|
||
# first column: table + issues | ||
metrics = split( | ||
[ | ||
tab( | ||
metric( | ||
name="Accuracy model 1", | ||
metric="accuracy", | ||
columns=[label, model1_prediction], | ||
) | ||
), | ||
tab( | ||
metric( | ||
name="Accuracy model 2", | ||
metric="accuracy", | ||
columns=[label, model2_prediction], | ||
) | ||
), | ||
], | ||
orientation="vertical", | ||
weight=15, | ||
) | ||
column1 = split( | ||
[metrics, tab(table(), weight=65)], weight=80, orientation="horizontal" | ||
) | ||
column1 = split( | ||
[column1, tab(issues(), weight=40)], weight=80, orientation="horizontal" | ||
) | ||
|
||
column2_list = [] | ||
column2_list.append( | ||
tab( | ||
confusion_matrix( | ||
name="Model 1 confusion matrix", | ||
x_column=label, | ||
y_column=model1_prediction, | ||
), | ||
confusion_matrix( | ||
name="Model 2 confusion matrix", | ||
x_column=label, | ||
y_column=model2_prediction, | ||
), | ||
weight=40, | ||
) | ||
) | ||
|
||
# third column: similarity maps | ||
if model1_correct != "": | ||
if model2_correct != "": | ||
row2 = tab( | ||
confusion_matrix( | ||
name="Model1 vs. Model2 - binned scatterplot", | ||
x_column=model1_correct, | ||
y_column=model2_correct, | ||
), | ||
weight=40, | ||
) | ||
column2_list.append(row2) | ||
|
||
if model1_embedding != "": | ||
if model2_embedding != "": | ||
row3 = tab( | ||
similaritymap( | ||
name="Model 1 embedding", | ||
columns=[model1_embedding], | ||
color_by_column=label, | ||
), | ||
similaritymap( | ||
name="Model 2 embedding", | ||
columns=[model2_embedding], | ||
color_by_column=label, | ||
), | ||
weight=40, | ||
) | ||
column2_list.append(row3) | ||
|
||
column2: Union[Tab, Split] | ||
|
||
if len(column2_list) == 1: | ||
column2 = column2_list[0] | ||
elif len(column2_list) == 2: | ||
column2 = split(column2_list, orientation="horizontal") | ||
else: | ||
column2 = split( | ||
[column2_list[0], column2_list[1]], weight=80, orientation="horizontal" | ||
) | ||
column2 = split([column2, column2_list[2]], orientation="horizontal") | ||
|
||
# fourth column: inspector | ||
inspector_fields = [] | ||
if inspect: | ||
for item, dtype_like in inspect.items(): | ||
dtype = dtypes.create_dtype(dtype_like) | ||
if dtypes.is_audio_dtype(dtype): | ||
inspector_fields.append(lenses.audio(item)) | ||
elif dtypes.is_image_dtype(dtype): | ||
inspector_fields.append(lenses.image(item)) | ||
else: | ||
print(f"Type {dtype} not supported by this layout.") | ||
|
||
inspector_fields.append(lenses.scalar(label)) | ||
inspector_fields.append(lenses.scalar(model1_prediction)) | ||
inspector_fields.append(lenses.scalar(model2_prediction)) | ||
|
||
inspector_view = inspector("Inspector", lenses=inspector_fields, num_columns=4) | ||
|
||
else: | ||
inspector_view = inspector("Inspector", num_columns=4) | ||
|
||
# build everything together | ||
column2.weight = 40 | ||
half1 = split([column1, column2], weight=80, orientation="vertical") | ||
half2 = tab(inspector_view, weight=40) | ||
|
||
nodes = [half1, half2] | ||
|
||
the_layout = layout.layout(nodes) | ||
|
||
return the_layout |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
from typing import Optional, Union, Dict, List, Any | ||
from renumics.spotlight import layout | ||
from renumics.spotlight.layout import ( | ||
Layout, | ||
Tab, | ||
Split, | ||
lenses, | ||
table, | ||
similaritymap, | ||
inspector, | ||
split, | ||
tab, | ||
metric, | ||
issues, | ||
confusion_matrix, | ||
histogram, | ||
) | ||
from renumics.spotlight.dtypes import create_dtype, is_audio_dtype, is_image_dtype | ||
|
||
|
||
def debug_classification( | ||
label: str = "label", | ||
prediction: str = "prediction", | ||
embedding: str = "", | ||
inspect: Optional[Dict[str, Any]] = None, | ||
features: Optional[List[str]] = None, | ||
) -> Layout: | ||
"""This function generates a Spotlight layout for debugging a machine learning classification model. | ||
Args: | ||
label: Name of the column that contains the label. | ||
prediction: Name of the column that contains the prediction. | ||
embedding: Name of the column that contains the embedding. | ||
inspect: Name and type of the columns that are displayed in the inspector, e.g. {'audio': spotlight.dtypes.audio_dtype}. | ||
features: Names of the columns that contain useful metadata and features. | ||
Returns: | ||
The configured layout for `spotlight.show`. | ||
""" | ||
|
||
# first column: table + issues | ||
metrics = tab( | ||
metric(name="Accuracy", metric="accuracy", columns=[label, prediction]), | ||
weight=15, | ||
) | ||
column1 = split( | ||
[metrics, tab(table(), weight=65)], weight=80, orientation="horizontal" | ||
) | ||
column1 = split( | ||
[column1, tab(issues(), weight=40)], weight=80, orientation="horizontal" | ||
) | ||
|
||
column2_list = [] | ||
column2_list.append( | ||
tab( | ||
confusion_matrix( | ||
name="Confusion matrix", x_column=label, y_column=prediction | ||
), | ||
weight=40, | ||
) | ||
) | ||
|
||
# second column: confusion matric, feature histograms (optional), embedding (optional) | ||
if features is not None: | ||
histogram_list = [] | ||
for idx, feature in enumerate(features): | ||
if idx > 2: | ||
break | ||
h = histogram( | ||
name="Histogram {}".format(feature), | ||
column=feature, | ||
stack_by_column=label, | ||
) | ||
histogram_list.append(h) | ||
|
||
row2 = tab(*histogram_list, weight=40) | ||
column2_list.append(row2) | ||
|
||
if embedding != "": | ||
row3 = tab( | ||
similaritymap(name="Embedding", columns=[embedding], color_by_column=label), | ||
weight=40, | ||
) | ||
column2_list.append(row3) | ||
|
||
column2: Union[Tab, Split] | ||
|
||
if len(column2_list) == 1: | ||
column2 = column2_list[0] | ||
elif len(column2_list) == 2: | ||
column2 = split(column2_list, orientation="horizontal") | ||
else: | ||
column2 = split( | ||
[column2_list[0], column2_list[1]], weight=80, orientation="horizontal" | ||
) | ||
column2 = split([column2, column2_list[2]], orientation="horizontal") | ||
|
||
# fourth column: inspector | ||
inspector_fields = [] | ||
if inspect: | ||
for item, dtype_like in inspect.items(): | ||
dtype = create_dtype(dtype_like) | ||
if is_audio_dtype(dtype): | ||
inspector_fields.append(lenses.audio(item)) | ||
elif is_image_dtype(dtype): | ||
inspector_fields.append(lenses.image(item)) | ||
else: | ||
print("Type {} not supported by this layout.".format(dtype)) | ||
|
||
inspector_fields.append(lenses.scalar(label)) | ||
inspector_fields.append(lenses.scalar(prediction)) | ||
|
||
inspector_view = inspector("Inspector", lenses=inspector_fields, num_columns=4) | ||
|
||
else: | ||
inspector_view = inspector("Inspector", num_columns=4) | ||
|
||
# build everything together | ||
column2.weight = 40 | ||
half1 = split([column1, column2], weight=80, orientation="vertical") | ||
half2 = tab(inspector_view, weight=40) | ||
|
||
nodes = [half1, half2] | ||
|
||
the_layout = layout.layout(nodes) | ||
|
||
return the_layout |