Skip to content

Commit

Permalink
Merge pull request #236 from Renumics/layout_templates
Browse files Browse the repository at this point in the history
Layout templates
  • Loading branch information
neindochoh authored Sep 20, 2023
2 parents be9f877 + 327849c commit 69fa300
Show file tree
Hide file tree
Showing 6 changed files with 326 additions and 20 deletions.
10 changes: 5 additions & 5 deletions renumics/spotlight/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from renumics.spotlight.backend.middlewares.timing import add_timing_middleware
from renumics.spotlight.app_config import AppConfig
from renumics.spotlight.data_source import DataSource, create_datasource
from renumics.spotlight.layout.default import DEFAULT_LAYOUT
from renumics.spotlight import layouts

from renumics.spotlight.data_store import DataStore

Expand Down Expand Up @@ -86,7 +86,7 @@ class SpotlightApp(FastAPI):

task_manager: TaskManager
websocket_manager: Optional[WebsocketManager]
_layout: Optional[Layout]
_layout: Layout
config: Config
username: str
filebrowsing_allowed: bool
Expand All @@ -106,7 +106,7 @@ def __init__(self) -> None:
self.task_manager = TaskManager()
self.websocket_manager = None
self.config = Config()
self._layout = None
self._layout = layouts.default()
self.project_root = Path.cwd()
self.vite_url = None
self.username = ""
Expand Down Expand Up @@ -368,11 +368,11 @@ def layout(self) -> Layout:
"""
Frontend layout
"""
return self._layout or DEFAULT_LAYOUT
return self._layout

@layout.setter
def layout(self, layout: Optional[Layout]) -> None:
self._layout = layout
self._layout = layout or layouts.default()
self._broadcast(ResetLayoutMessage())

async def get_current_layout_dict(self, user_id: str) -> Optional[Dict]:
Expand Down
15 changes: 0 additions & 15 deletions renumics/spotlight/layout/default.py

This file was deleted.

5 changes: 5 additions & 0 deletions renumics/spotlight/layouts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .default import default
from .model_debug import debug_classification
from .model_compare import compare_classification

__all__ = ["default", "debug_classification", "compare_classification"]
26 changes: 26 additions & 0 deletions renumics/spotlight/layouts/default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from renumics.spotlight.layout import (
histogram,
inspector,
layout,
scatterplot,
similaritymap,
split,
tab,
table,
)
from renumics.spotlight.layout.nodes import Layout


def default() -> Layout:
"""
Default layout for spotlight.
"""

return layout(
split(
tab(table(), weight=60),
tab(similaritymap(), scatterplot(), histogram(), weight=40),
weight=60,
),
tab(inspector(), weight=40),
)
163 changes: 163 additions & 0 deletions renumics/spotlight/layouts/model_compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
from typing import Optional, Union, Dict, Any
from renumics.spotlight import dtypes
from renumics.spotlight import layout
from renumics.spotlight.layout import (
Layout,
Tab,
Split,
lenses,
table,
similaritymap,
inspector,
split,
tab,
metric,
issues,
confusion_matrix,
)


def compare_classification(
label: str = "label",
model1_prediction: str = "m1_prediction",
model1_embedding: str = "",
model1_correct: str = "",
model2_prediction: str = "m2_prediction",
model2_embedding: str = "",
model2_correct: str = "",
inspect: Optional[Dict[str, Any]] = None,
) -> Layout:
"""This function generates a Spotlight layout for comparing two different machine learning classification models.
Args:
label: Name of the column that contains the label.
model1_prediction: Name of the column that contains the prediction for model 1.
model1_embedding: Name of the column that contains thee embedding for model 1.
model1_correct: Name of the column that contains a flag if the data sample is predicted correctly by model 1.
model2_prediction: Name of the column that contains the prediction for model 2.
model2_embedding: Name of the column that contains thee embedding for model 2.
model2_correct: Name of the column that contains a flag if the data sample is predicted correctly by model 2.
inspect: Name and type of the columns that are displayed in the inspector, e.g. {'audio': spotlight.dtypes.audio_dtype}.
Returns:
The configured layout for `spotlight.show`.
"""

# first column: table + issues
metrics = split(
[
tab(
metric(
name="Accuracy model 1",
metric="accuracy",
columns=[label, model1_prediction],
)
),
tab(
metric(
name="Accuracy model 2",
metric="accuracy",
columns=[label, model2_prediction],
)
),
],
orientation="vertical",
weight=15,
)
column1 = split(
[metrics, tab(table(), weight=65)], weight=80, orientation="horizontal"
)
column1 = split(
[column1, tab(issues(), weight=40)], weight=80, orientation="horizontal"
)

column2_list = []
column2_list.append(
tab(
confusion_matrix(
name="Model 1 confusion matrix",
x_column=label,
y_column=model1_prediction,
),
confusion_matrix(
name="Model 2 confusion matrix",
x_column=label,
y_column=model2_prediction,
),
weight=40,
)
)

# third column: similarity maps
if model1_correct != "":
if model2_correct != "":
row2 = tab(
confusion_matrix(
name="Model1 vs. Model2 - binned scatterplot",
x_column=model1_correct,
y_column=model2_correct,
),
weight=40,
)
column2_list.append(row2)

if model1_embedding != "":
if model2_embedding != "":
row3 = tab(
similaritymap(
name="Model 1 embedding",
columns=[model1_embedding],
color_by_column=label,
),
similaritymap(
name="Model 2 embedding",
columns=[model2_embedding],
color_by_column=label,
),
weight=40,
)
column2_list.append(row3)

column2: Union[Tab, Split]

if len(column2_list) == 1:
column2 = column2_list[0]
elif len(column2_list) == 2:
column2 = split(column2_list, orientation="horizontal")
else:
column2 = split(
[column2_list[0], column2_list[1]], weight=80, orientation="horizontal"
)
column2 = split([column2, column2_list[2]], orientation="horizontal")

# fourth column: inspector
inspector_fields = []
if inspect:
for item, dtype_like in inspect.items():
dtype = dtypes.create_dtype(dtype_like)
if dtypes.is_audio_dtype(dtype):
inspector_fields.append(lenses.audio(item))
elif dtypes.is_image_dtype(dtype):
inspector_fields.append(lenses.image(item))
else:
print(f"Type {dtype} not supported by this layout.")

inspector_fields.append(lenses.scalar(label))
inspector_fields.append(lenses.scalar(model1_prediction))
inspector_fields.append(lenses.scalar(model2_prediction))

inspector_view = inspector("Inspector", lenses=inspector_fields, num_columns=4)

else:
inspector_view = inspector("Inspector", num_columns=4)

# build everything together
column2.weight = 40
half1 = split([column1, column2], weight=80, orientation="vertical")
half2 = tab(inspector_view, weight=40)

nodes = [half1, half2]

the_layout = layout.layout(nodes)

return the_layout
127 changes: 127 additions & 0 deletions renumics/spotlight/layouts/model_debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from typing import Optional, Union, Dict, List, Any
from renumics.spotlight import layout
from renumics.spotlight.layout import (
Layout,
Tab,
Split,
lenses,
table,
similaritymap,
inspector,
split,
tab,
metric,
issues,
confusion_matrix,
histogram,
)
from renumics.spotlight.dtypes import create_dtype, is_audio_dtype, is_image_dtype


def debug_classification(
label: str = "label",
prediction: str = "prediction",
embedding: str = "",
inspect: Optional[Dict[str, Any]] = None,
features: Optional[List[str]] = None,
) -> Layout:
"""This function generates a Spotlight layout for debugging a machine learning classification model.
Args:
label: Name of the column that contains the label.
prediction: Name of the column that contains the prediction.
embedding: Name of the column that contains the embedding.
inspect: Name and type of the columns that are displayed in the inspector, e.g. {'audio': spotlight.dtypes.audio_dtype}.
features: Names of the columns that contain useful metadata and features.
Returns:
The configured layout for `spotlight.show`.
"""

# first column: table + issues
metrics = tab(
metric(name="Accuracy", metric="accuracy", columns=[label, prediction]),
weight=15,
)
column1 = split(
[metrics, tab(table(), weight=65)], weight=80, orientation="horizontal"
)
column1 = split(
[column1, tab(issues(), weight=40)], weight=80, orientation="horizontal"
)

column2_list = []
column2_list.append(
tab(
confusion_matrix(
name="Confusion matrix", x_column=label, y_column=prediction
),
weight=40,
)
)

# second column: confusion matric, feature histograms (optional), embedding (optional)
if features is not None:
histogram_list = []
for idx, feature in enumerate(features):
if idx > 2:
break
h = histogram(
name="Histogram {}".format(feature),
column=feature,
stack_by_column=label,
)
histogram_list.append(h)

row2 = tab(*histogram_list, weight=40)
column2_list.append(row2)

if embedding != "":
row3 = tab(
similaritymap(name="Embedding", columns=[embedding], color_by_column=label),
weight=40,
)
column2_list.append(row3)

column2: Union[Tab, Split]

if len(column2_list) == 1:
column2 = column2_list[0]
elif len(column2_list) == 2:
column2 = split(column2_list, orientation="horizontal")
else:
column2 = split(
[column2_list[0], column2_list[1]], weight=80, orientation="horizontal"
)
column2 = split([column2, column2_list[2]], orientation="horizontal")

# fourth column: inspector
inspector_fields = []
if inspect:
for item, dtype_like in inspect.items():
dtype = create_dtype(dtype_like)
if is_audio_dtype(dtype):
inspector_fields.append(lenses.audio(item))
elif is_image_dtype(dtype):
inspector_fields.append(lenses.image(item))
else:
print("Type {} not supported by this layout.".format(dtype))

inspector_fields.append(lenses.scalar(label))
inspector_fields.append(lenses.scalar(prediction))

inspector_view = inspector("Inspector", lenses=inspector_fields, num_columns=4)

else:
inspector_view = inspector("Inspector", num_columns=4)

# build everything together
column2.weight = 40
half1 = split([column1, column2], weight=80, orientation="vertical")
half2 = tab(inspector_view, weight=40)

nodes = [half1, half2]

the_layout = layout.layout(nodes)

return the_layout

0 comments on commit 69fa300

Please sign in to comment.