diff --git a/CHANGELOG.md b/CHANGELOG.md index 95aebdf7..92050da1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,5 +19,6 @@ Released changes are shown in the ### Removed ### Fixed +- Fixed importing the same proposed actions CSV file twice ### Security diff --git a/azimuth/app.py b/azimuth/app.py index 2b69fc58..738fcb19 100644 --- a/azimuth/app.py +++ b/azimuth/app.py @@ -155,7 +155,7 @@ def create_app() -> FastAPI: Returns: FastAPI. """ - app = FastAPI( + api = FastAPI( title="Azimuth API", description="Azimuth API", version="1.0", @@ -171,101 +171,103 @@ def create_app() -> FastAPI: ) # Setup routes - from azimuth.routers.app import router as app_router - from azimuth.routers.class_overlap import router as class_overlap_router - from azimuth.routers.config import router as config_router - from azimuth.routers.custom_utterances import router as custom_utterances_router - from azimuth.routers.dataset_warnings import router as dataset_warnings_router - from azimuth.routers.export import router as export_router - from azimuth.routers.model_performance.confidence_histogram import ( - router as confidence_histogram_router, + from azimuth.routers import ( + app, + class_overlap, + config, + custom_utterances, + dataset_warnings, + export, + top_words, + utterances, + ) + from azimuth.routers.model_performance import ( + confidence_histogram, + confusion_matrix, + metrics, + outcome_count, + utterance_count, ) - from azimuth.routers.model_performance.confusion_matrix import router as confusion_matrix_router - from azimuth.routers.model_performance.metrics import router as metrics_router - from azimuth.routers.model_performance.outcome_count import router as outcome_count_router - from azimuth.routers.model_performance.utterance_count import router as utterance_count_router - from azimuth.routers.top_words import router as top_words_router - from azimuth.routers.utterances import router as utterances_router from azimuth.utils.routers import require_application_ready, require_available_model api_router = APIRouter() - api_router.include_router(app_router, prefix="", tags=["App"]) - api_router.include_router(config_router, prefix="/config", tags=["Config"]) + api_router.include_router(app.router, prefix="", tags=["App"]) + api_router.include_router(config.router, prefix="/config", tags=["Config"]) api_router.include_router( - class_overlap_router, + class_overlap.router, prefix="/dataset_splits/{dataset_split_name}/class_overlap", tags=["Class Overlap"], dependencies=[Depends(require_application_ready)], ) api_router.include_router( - confidence_histogram_router, + confidence_histogram.router, prefix="/dataset_splits/{dataset_split_name}/confidence_histogram", tags=["Confidence Histogram"], dependencies=[Depends(require_application_ready), Depends(require_available_model)], ) api_router.include_router( - dataset_warnings_router, + dataset_warnings.router, prefix="/dataset_warnings", tags=["Dataset Warnings"], dependencies=[Depends(require_application_ready)], ) api_router.include_router( - metrics_router, + metrics.router, prefix="/dataset_splits/{dataset_split_name}/metrics", tags=["Metrics"], dependencies=[Depends(require_application_ready), Depends(require_available_model)], ) api_router.include_router( - outcome_count_router, + outcome_count.router, prefix="/dataset_splits/{dataset_split_name}/outcome_count", tags=["Outcome Count"], dependencies=[Depends(require_application_ready), Depends(require_available_model)], ) api_router.include_router( - utterance_count_router, + 
utterance_count.router, prefix="/dataset_splits/{dataset_split_name}/utterance_count", tags=["Utterance Count"], dependencies=[Depends(require_application_ready)], ) api_router.include_router( - utterances_router, + utterances.router, prefix="/dataset_splits/{dataset_split_name}/utterances", tags=["Utterances"], dependencies=[Depends(require_application_ready)], ) api_router.include_router( - export_router, + export.router, prefix="/export", tags=["Export"], dependencies=[Depends(require_application_ready)], ) api_router.include_router( - custom_utterances_router, + custom_utterances.router, prefix="/custom_utterances", tags=["Custom Utterances"], dependencies=[Depends(require_application_ready)], ) api_router.include_router( - top_words_router, + top_words.router, prefix="/dataset_splits/{dataset_split_name}/top_words", tags=["Top Words"], dependencies=[Depends(require_application_ready), Depends(require_available_model)], ) api_router.include_router( - confusion_matrix_router, + confusion_matrix.router, prefix="/dataset_splits/{dataset_split_name}/confusion_matrix", tags=["Confusion Matrix"], dependencies=[Depends(require_application_ready), Depends(require_available_model)], ) - app.include_router(api_router) + api.include_router(api_router) - app.add_middleware( + api.add_middleware( CORSMiddleware, allow_methods=["*"], allow_headers=["*"], ) - return app + return api def load_dataset_split_managers_from_config( diff --git a/azimuth/config.py b/azimuth/config.py index bfee3133..d24dd39f 100644 --- a/azimuth/config.py +++ b/azimuth/config.py @@ -284,11 +284,15 @@ def get_project_hash(self): class ArtifactsConfig(AzimuthBaseSettings, extra=Extra.ignore): artifact_path: str = Field( - "cache", + default_factory=lambda: os.path.abspath("cache"), description="Where to store artifacts (Azimuth config history, HDF5 files, HF datasets).", exclude_from_cache=True, ) + @validator("artifact_path") + def validate_artifact_path(cls, artifact_path): + return os.path.abspath(artifact_path) + def get_config_history_path(self): return f"{self.artifact_path}/config_history.jsonl" @@ -329,7 +333,7 @@ class ModelContractConfig(CommonFieldsConfig): # Uncertainty configuration uncertainty: UncertaintyOptions = UncertaintyOptions() # Layer name where to calculate the gradients, normally the word embeddings layer. 
- saliency_layer: Optional[str] = Field(None, nullable=True) + saliency_layer: Union[Literal["auto"], str, None] = Field("auto", nullable=True) @validator("pipelines", pre=True) def _check_pipeline_names(cls, pipeline_definitions): diff --git a/azimuth/modules/model_contracts/hf_text_classification.py b/azimuth/modules/model_contracts/hf_text_classification.py index c76c8fed..d6755079 100644 --- a/azimuth/modules/model_contracts/hf_text_classification.py +++ b/azimuth/modules/model_contracts/hf_text_classification.py @@ -16,6 +16,7 @@ from azimuth.types.task import PredictionResponse, SaliencyResponse from azimuth.utils.ml.mc_dropout import MCDropout from azimuth.utils.ml.saliency import ( + find_word_embeddings_layer, get_saliency, register_embedding_gradient_hook, register_embedding_list_hook, @@ -127,6 +128,7 @@ def saliency(self, batch: Dataset) -> List[SaliencyResponse]: raise ValueError("This method should not be called when saliency_layer is not defined.") hf_pipeline = self.get_model() + hf_model = hf_pipeline.model inputs = hf_pipeline.tokenizer( batch[self.config.columns.text_input], @@ -140,18 +142,19 @@ def saliency(self, batch: Dataset) -> List[SaliencyResponse]: inputs["input_ids"] = inputs["input_ids"].to(hf_pipeline.device) inputs["attention_mask"] = inputs["attention_mask"].to(hf_pipeline.device) - logits = hf_pipeline.model(**inputs)[0] + logits = hf_model(**inputs)[0] output = torch.softmax(logits, dim=1).detach().cpu().numpy() prediction = output.argmax(-1) - embeddings_list: List[np.ndarray] = [] - handle = register_embedding_list_hook( - hf_pipeline.model, embeddings_list, self.saliency_layer + embedding_layer = ( + hf_model.base_model.get_input_embeddings() + if self.saliency_layer == "auto" + else find_word_embeddings_layer(hf_model, self.saliency_layer) ) + embeddings_list: List[np.ndarray] = [] + handle = register_embedding_list_hook(embeddings_list, embedding_layer) embeddings_gradients: List[np.ndarray] = [] - hook = register_embedding_gradient_hook( - hf_pipeline.model, embeddings_gradients, self.saliency_layer - ) + hook = register_embedding_gradient_hook(embeddings_gradients, embedding_layer) filter_class = self.mod_options.filter_class selected_classes = ( @@ -162,8 +165,8 @@ def saliency(self, batch: Dataset) -> List[SaliencyResponse]: ) # Do backward pass to compute gradients - hf_pipeline.model.zero_grad() - _loss = hf_pipeline.model(**inputs)[0] # loss is at index 0 when passing labels + hf_model.zero_grad() + _loss = hf_model(**inputs)[0] # loss is at index 0 when passing labels _loss.backward() handle.remove() hook.remove() diff --git a/azimuth/routers/config.py b/azimuth/routers/config.py index 790ef6a9..33153f67 100644 --- a/azimuth/routers/config.py +++ b/azimuth/routers/config.py @@ -1,12 +1,15 @@ # Copyright ServiceNow, Inc. 2021 – 2022 # This source code is licensed under the Apache 2.0 license found in the LICENSE file # in the root directory of this source tree. 
-from typing import Any, Dict, List +from typing import Dict, List import structlog from fastapi import APIRouter, Body, Depends, HTTPException, Query -from pydantic import ValidationError -from starlette.status import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR +from starlette.status import ( + HTTP_400_BAD_REQUEST, + HTTP_403_FORBIDDEN, + HTTP_500_INTERNAL_SERVER_ERROR, +) from azimuth.app import ( get_config, @@ -73,6 +76,25 @@ def get_config_def( return config +@router.patch( + "/validate", + summary="Validate config", + description="Validate the given partial config update and return the complete config that would" + " result if this update was applied.", + response_model=AzimuthConfig, + dependencies=[Depends(require_editable_config)], +) +def validate_config( + config: AzimuthConfig = Depends(get_config), + partial_config: Dict = Body(...), +) -> AzimuthConfig: + new_config = update_config(old_config=config, partial_config=partial_config) + + assert_permission_to_update_config(old_config=config, new_config=new_config) + + return new_config + + @router.patch( "", summary="Update config", @@ -85,19 +107,17 @@ def patch_config( config: AzimuthConfig = Depends(get_config), partial_config: Dict = Body(...), ) -> AzimuthConfig: - if attribute_changed_in_config("artifact_path", partial_config, config): - raise HTTPException( - HTTP_400_BAD_REQUEST, - detail="Cannot edit artifact_path, otherwise config history would become inconsistent.", - ) + log.info(f"Validating config change with {partial_config}.") + new_config = update_config(old_config=config, partial_config=partial_config) + + assert_permission_to_update_config(old_config=config, new_config=new_config) + + if new_config.large_dask_cluster != config.large_dask_cluster: + cluster = default_cluster(new_config.large_dask_cluster) + else: + cluster = task_manager.cluster try: - log.info(f"Validating config change with {partial_config}.") - new_config = update_config(old_config=config, partial_config=partial_config) - if attribute_changed_in_config("large_dask_cluster", partial_config, config): - cluster = default_cluster(partial_config["large_dask_cluster"]) - else: - cluster = task_manager.cluster run_startup_tasks(new_config, cluster) log.info(f"Config successfully updated with {partial_config}.") except Exception as e: @@ -107,8 +127,6 @@ def patch_config( log.info("Config update cancelled.") if isinstance(e, AzimuthValidationError): raise HTTPException(HTTP_400_BAD_REQUEST, detail=str(e)) - if isinstance(e, ValidationError): - raise else: raise HTTPException( HTTP_500_INTERNAL_SERVER_ERROR, detail="Error when loading the new config." 
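# A minimal client-side usage sketch for the new PATCH /config/validate endpoint
# added above: it merges the partial update, runs the artifact_path permission
# check, and returns the resulting config without triggering startup tasks.
# Assumptions for illustration only: a locally running backend and its base URL.
import requests

BASE_URL = "http://localhost:8091"  # assumed backend address, adjust as needed

# Dry run: get the config that would result from this partial update.
resp = requests.patch(f"{BASE_URL}/config/validate", json={"batch_size": 64})
resp.raise_for_status()
merged_config = resp.json()

# Changing artifact_path is rejected (403) by assert_permission_to_update_config.
resp = requests.patch(f"{BASE_URL}/config", json={"artifact_path": "somewhere/else"})
assert resp.status_code == 403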
@@ -117,7 +135,9 @@ def patch_config( return new_config -def attribute_changed_in_config( - attribute: str, partial_config: Dict[str, Any], config: AzimuthConfig -) -> bool: - return attribute in partial_config and partial_config[attribute] != getattr(config, attribute) +def assert_permission_to_update_config(*, old_config: AzimuthConfig, new_config: AzimuthConfig): + if old_config.artifact_path != new_config.artifact_path: + raise HTTPException( + HTTP_403_FORBIDDEN, + detail="Cannot edit artifact_path, otherwise config history would become inconsistent.", + ) diff --git a/azimuth/utils/ml/saliency.py b/azimuth/utils/ml/saliency.py index 13a0ec9f..3caa71c3 100644 --- a/azimuth/utils/ml/saliency.py +++ b/azimuth/utils/ml/saliency.py @@ -5,6 +5,7 @@ import numpy as np import structlog +from torch.nn import Embedding from azimuth.types.general.module_arguments import GradientCalculation @@ -36,14 +37,13 @@ def find_word_embeddings_layer(model: Any, layer_name: str) -> Any: def register_embedding_list_hook( - model: Any, embeddings_list: List[np.ndarray], layer_name: str + embeddings_list: List[np.ndarray], embedding_layer: Embedding ) -> Any: """Register hook to get the embedding values from model. Args: - model: Model. embeddings_list: Variable to save values. - layer_name: Name of the embedding layer. + embedding_layer: Embedding layer on which to compute the saliency map. Returns: Hook. @@ -52,21 +52,17 @@ def register_embedding_list_hook( def forward_hook(module, inputs, output): embeddings_list.append(output.detach().cpu().clone().numpy()) - embedding_layer = find_word_embeddings_layer(model, layer_name) - handle = embedding_layer.register_forward_hook(forward_hook) - - return handle + return embedding_layer.register_forward_hook(forward_hook) def register_embedding_gradient_hook( - model: Any, embeddings_gradients: List[np.ndarray], layer_name: str + embeddings_gradients: List[np.ndarray], embedding_layer: Embedding ) -> Any: """Register hook to get the gradient values from the embedding layer. Args: - model: Model. embeddings_gradients: Variable to save values. - layer_name: Name of the embedding layer. + embedding_layer: Embedding layer on which to compute the saliency map. Returns: Hook. 
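# A minimal sketch of how the refactored hook helpers above are meant to be used
# together (mirroring the updated hf_text_classification.saliency): the caller
# resolves the embedding layer once, then passes it to both registrations instead
# of a model plus a layer name. The model and inputs are placeholders; any torch
# module returning logits at index 0 fits the pattern.
import torch
from torch import nn


def capture_embeddings_and_gradients(model: nn.Module, embedding_layer: nn.Embedding, inputs: dict):
    embeddings_list: list = []  # filled by the forward hook
    gradients_list: list = []  # filled by the backward hook

    handle = embedding_layer.register_forward_hook(
        lambda module, ins, out: embeddings_list.append(out.detach().cpu().numpy())
    )
    hook = embedding_layer.register_full_backward_hook(
        lambda module, grad_in, grad_out: gradients_list.append(grad_out[0].detach().cpu().numpy())
    )
    try:
        model.zero_grad()
        logits = model(**inputs)[0]
        # Backward pass on the top predicted class (a stand-in for the labelled
        # loss used in the actual module) populates gradients_list.
        logits[torch.arange(logits.shape[0]), logits.argmax(-1)].sum().backward()
    finally:
        handle.remove()
        hook.remove()
    return embeddings_list, gradients_list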
@@ -76,10 +72,7 @@ def register_embedding_gradient_hook( def hook_layers(module, grad_in, grad_out): embeddings_gradients.append(grad_out[0].detach().cpu().clone().numpy()) - embedding_layer = find_word_embeddings_layer(model, layer_name) - hook = embedding_layer.register_full_backward_hook(hook_layers) - - return hook + return embedding_layer.register_full_backward_hook(hook_layers) def get_saliency( diff --git a/config/development/clinc/conf.json b/config/development/clinc/conf.json index 666c810f..0f671a33 100644 --- a/config/development/clinc/conf.json +++ b/config/development/clinc/conf.json @@ -20,7 +20,6 @@ } }, "batch_size": 64, - "saliency_layer": "distilbert.embeddings.word_embeddings", "model_contract": "hf_text_classification", "rejection_class": "NO_INTENT" } diff --git a/config/development/clinc/conf_light.json b/config/development/clinc/conf_light.json index 13849df9..4d11bc73 100644 --- a/config/development/clinc/conf_light.json +++ b/config/development/clinc/conf_light.json @@ -20,7 +20,6 @@ } }, "batch_size": 64, - "saliency_layer": "distilbert.embeddings.word_embeddings", "model_contract": "hf_text_classification", "behavioral_testing": null, "similarity": null, diff --git a/config/development/clinc/conf_multipipeline.json b/config/development/clinc/conf_multipipeline.json index ab247f1f..17f4f1b5 100644 --- a/config/development/clinc/conf_multipipeline.json +++ b/config/development/clinc/conf_multipipeline.json @@ -39,7 +39,6 @@ } }, "batch_size": 64, - "saliency_layer": "distilbert.embeddings.word_embeddings", "model_contract": "hf_text_classification", "rejection_class": "NO_INTENT", "uncertainty": { diff --git a/config/development/clinc_dummy/conf.json b/config/development/clinc_dummy/conf.json index 2524aa28..7d846312 100644 --- a/config/development/clinc_dummy/conf.json +++ b/config/development/clinc_dummy/conf.json @@ -19,7 +19,6 @@ "full_path": "/azimuth_shr/files/clinc-demo/oos-eval/data/clinc_data_dummy.json" } }, - "saliency_layer": "distilbert.embeddings.word_embeddings", "model_contract": "hf_text_classification", "rejection_class": "NO_INTENT" } diff --git a/config/development/clinc_dummy/conf_bma.json b/config/development/clinc_dummy/conf_bma.json index 287be8cb..cca15d5b 100644 --- a/config/development/clinc_dummy/conf_bma.json +++ b/config/development/clinc_dummy/conf_bma.json @@ -19,7 +19,6 @@ "full_path": "/azimuth_shr/files/clinc-demo/oos-eval/data/clinc_data_dummy.json" } }, - "saliency_layer": "distilbert.embeddings.word_embeddings", "model_contract": "hf_text_classification", "rejection_class": "NO_INTENT", "uncertainty": { diff --git a/config/development/clinc_dummy/conf_multipipeline.json b/config/development/clinc_dummy/conf_multipipeline.json index 7d9bb7c4..f1596a04 100644 --- a/config/development/clinc_dummy/conf_multipipeline.json +++ b/config/development/clinc_dummy/conf_multipipeline.json @@ -38,7 +38,6 @@ "full_path": "/azimuth_shr/files/clinc-demo/oos-eval/data/clinc_data_dummy.json" } }, - "saliency_layer": "distilbert.embeddings.word_embeddings", "model_contract": "hf_text_classification", "rejection_class": "NO_INTENT" } diff --git a/config/development/clinc_dummy/conf_no_eval.json b/config/development/clinc_dummy/conf_no_eval.json index 1e9e2ef6..aac619d8 100644 --- a/config/development/clinc_dummy/conf_no_eval.json +++ b/config/development/clinc_dummy/conf_no_eval.json @@ -20,7 +20,6 @@ "eval": false } }, - "saliency_layer": "distilbert.embeddings.word_embeddings", "model_contract": "hf_text_classification", "rejection_class": 
"NO_INTENT" } diff --git a/config/development/clinc_dummy/conf_no_model.json b/config/development/clinc_dummy/conf_no_model.json index c674bf15..2c7eb894 100644 --- a/config/development/clinc_dummy/conf_no_model.json +++ b/config/development/clinc_dummy/conf_no_model.json @@ -9,7 +9,6 @@ "full_path": "/azimuth_shr/files/clinc-demo/oos-eval/data/clinc_data_dummy.json" } }, - "saliency_layer": "distilbert.embeddings.word_embeddings", "model_contract": "hf_text_classification", "rejection_class": "NO_INTENT" } diff --git a/config/development/clinc_dummy/conf_no_train.json b/config/development/clinc_dummy/conf_no_train.json index 059cd284..d024a17e 100644 --- a/config/development/clinc_dummy/conf_no_train.json +++ b/config/development/clinc_dummy/conf_no_train.json @@ -20,7 +20,6 @@ "train": false } }, - "saliency_layer": "distilbert.embeddings.word_embeddings", "model_contract": "hf_text_classification", "rejection_class": "NO_INTENT" } diff --git a/config/development/configurable_fake_dataset/conf.json b/config/development/configurable_fake_dataset/conf.json index 0c818f65..6a7d18d9 100644 --- a/config/development/configurable_fake_dataset/conf.json +++ b/config/development/configurable_fake_dataset/conf.json @@ -21,6 +21,5 @@ "behavioral_testing": null, "similarity": null, "rejection_class": null, - "saliency_layer": "distilbert.embeddings.word_embeddings", "model_contract": "custom_text_classification" } diff --git a/config/examples/banking77/conf.json b/config/examples/banking77/conf.json index 71b4e1b6..08cf8a8c 100644 --- a/config/examples/banking77/conf.json +++ b/config/examples/banking77/conf.json @@ -22,6 +22,5 @@ } } } - ], - "saliency_layer": "roberta.embeddings.word_embeddings" + ] } diff --git a/config/examples/clinc_oos/conf.json b/config/examples/clinc_oos/conf.json index e9a29618..862eb68e 100644 --- a/config/examples/clinc_oos/conf.json +++ b/config/examples/clinc_oos/conf.json @@ -24,6 +24,5 @@ } } ], - "saliency_layer": "distilbert.embeddings.word_embeddings", "batch_size": 256 } diff --git a/config/examples/clinc_oos/conf_multipipeline.json b/config/examples/clinc_oos/conf_multipipeline.json index b81e62e1..49eabb9c 100644 --- a/config/examples/clinc_oos/conf_multipipeline.json +++ b/config/examples/clinc_oos/conf_multipipeline.json @@ -43,7 +43,6 @@ ] } ], - "saliency_layer": "distilbert.embeddings.word_embeddings", "batch_size": 128, "uncertainty": { "iterations": 20 diff --git a/config/examples/sst2/conf.json b/config/examples/sst2/conf.json index 76959e1b..64882a4f 100644 --- a/config/examples/sst2/conf.json +++ b/config/examples/sst2/conf.json @@ -16,6 +16,5 @@ } } } - ], - "saliency_layer": "distilbert.embeddings.word_embeddings" + ] } diff --git a/docs/docs/getting-started/changelog.md b/docs/docs/getting-started/changelog.md index 1b3df584..fcf61364 100644 --- a/docs/docs/getting-started/changelog.md +++ b/docs/docs/getting-started/changelog.md @@ -1,5 +1,15 @@ # Releases +## [2.8.0] - 2023-07-06 + +### Added +- Config UI: + - Two buttons in the config UI to import/export a JSON config file. + - Drop-down menu to load a previous config. + - Discard buttons for individual config fields. +- HF pipelines: + - The saliency layer is now automatically detected and no longer needs to be specified in the config. To disable saliency maps, the user can still specify `null`. 
+ ## [2.7.0] - 2023-06-12 ### Added diff --git a/docs/docs/reference/configuration/model_contract.md b/docs/docs/reference/configuration/model_contract.md index 4fc956fd..ebc95096 100644 --- a/docs/docs/reference/configuration/model_contract.md +++ b/docs/docs/reference/configuration/model_contract.md @@ -16,7 +16,7 @@ Fields from this scope defines how Azimuth interacts with the ML pipelines and t model_contract: SupportedModelContract = SupportedModelContract.hf_text_classification # (1) pipelines: Optional[List[PipelineDefinition]] = None # (2) uncertainty: UncertaintyOptions = UncertaintyOptions() # (3) - saliency_layer: Optional[str] = None # (4) + saliency_layer: Union[Literal["auto"], str, None] = "auto" # (4) metrics: Dict[str, MetricDefinition] = { # (5) "Precision": MetricDefinition( class_name="datasets.load_metric", @@ -40,7 +40,7 @@ Fields from this scope defines how Azimuth interacts with the ML pipelines and t 2. List of pipelines. Can also be set to `null` to launch Azimuth with a dataset only. 3. Enable uncertainty quantification. 4. Layer name where to calculate the gradients, normally the word embeddings layer. - Only available for Pytorch models. + Only available for Pytorch models. Defaults to "auto", in which case the embedding layer will be detected automatically using `model.base_model.get_input_embeddings()`. 5. HuggingFace Metrics. === "Config Example" @@ -208,10 +208,11 @@ in [:material-link: Uncertainty Estimation](../../key-concepts/uncertainty.md). ## Saliency Layer -🟡 **Default value**: `None` +🟡 **Default value**: `auto` If using a Pytorch model, [:material-link: Saliency Maps](../../key-concepts/saliency.md) can be -available. Specify the name of the embedding layer on which to compute them. +available. Specify the name of the embedding layer on which to compute them. The default is set to "auto", +in which case the embedding layer will be detected automatically using `model.base_model.get_input_embeddings()`. Example: `distilbert.embeddings.word_embeddings`. diff --git a/poetry.lock b/poetry.lock index 39b3e409..2ab9c535 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2483,18 +2483,20 @@ files = [ [[package]] name = "mpmath" -version = "1.2.1" +version = "1.3.0" description = "Python library for arbitrary-precision floating-point arithmetic" category = "main" optional = false python-versions = "*" files = [ - {file = "mpmath-1.2.1-py3-none-any.whl", hash = "sha256:604bc21bd22d2322a177c73bdb573994ef76e62edd595d17e00aff24b0667e5c"}, - {file = "mpmath-1.2.1.tar.gz", hash = "sha256:79ffb45cf9f4b101a807595bcb3e72e0396202e0b1d25d689134b48c4216a81a"}, + {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, + {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"}, ] [package.extras] develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"] +docs = ["sphinx"] +gmpy = ["gmpy2 (>=2.1.0a4)"] tests = ["pytest (>=4.6)"] [[package]] @@ -4439,21 +4441,21 @@ files = [ [[package]] name = "requests" -version = "2.28.1" +version = "2.31.0" description = "Python HTTP for Humans." 
category = "main" optional = false -python-versions = ">=3.7, <4" +python-versions = ">=3.7" files = [ - {file = "requests-2.28.1-py3-none-any.whl", hash = "sha256:8fefa2a1a1365bf5520aac41836fbee479da67864514bdb821f31ce07ce65349"}, - {file = "requests-2.28.1.tar.gz", hash = "sha256:7c5599b102feddaa661c826c56ab4fee28bfd17f5abca1ebbe3e7f19d7c97983"}, + {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, + {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, ] [package.dependencies] certifi = ">=2017.4.17" -charset-normalizer = ">=2,<3" +charset-normalizer = ">=2,<4" idna = ">=2.5,<4" -urllib3 = ">=1.21.1,<1.27" +urllib3 = ">=1.21.1,<3" [package.extras] socks = ["PySocks (>=1.5.6,!=1.5.7)"] diff --git a/pyproject.toml b/pyproject.toml index 18c8b798..a366df47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "azimuth" -version = "2.7.0" +version = "2.8.0" description = "Azimuth provides a unified error analysis experience to data scientists." readme = "README.md" authors = ["Azimuth team "] diff --git a/tests/conftest.py b/tests/conftest.py index 0eb7309e..20c0915d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -82,7 +82,6 @@ def simple_text_config(tmp_path): artifact_path=str(tmp_path), batch_size=10, model_contract="hf_text_classification", - saliency_layer="distilbert.embeddings.word_embeddings", rejection_class=None, behavioral_testing=SIMPLE_PERTURBATION_TESTING_CONFIG, ) @@ -156,7 +155,6 @@ def simple_text_config_french(tmp_path): artifact_path=str(tmp_path), batch_size=10, model_contract="hf_text_classification", - saliency_layer="distilbert.embeddings.word_embeddings", rejection_class=None, language=SupportedLanguage.fr, ) diff --git a/tests/test_config.py b/tests/test_config.py index dff7757b..b8e18551 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -8,6 +8,7 @@ from pydantic import ValidationError from azimuth.config import ( + ArtifactsConfig, AzimuthConfig, AzimuthConfigHistoryWithHash, PipelineDefinition, @@ -289,3 +290,10 @@ def test_config_history_with_hash(): ValidationError, match="1 validation error for AzimuthConfigHistoryWithHash\nconfig -> name" ): AzimuthConfigHistoryWithHash(config={"name": None}) + + +def test_artifact_path_equality(): + # This is important since we forbid updating the config if the artifact_path differs. 
+ default = ArtifactsConfig() + assert ArtifactsConfig(artifact_path="cache/../cache") == default + assert ArtifactsConfig(artifact_path=f"{os.getcwd()}/cache") == default diff --git a/tests/test_routers/conftest.py b/tests/test_routers/conftest.py index 2eb7aff0..2c0300d6 100644 --- a/tests/test_routers/conftest.py +++ b/tests/test_routers/conftest.py @@ -58,7 +58,6 @@ def app() -> FastAPI: batch_size=16, use_cuda=False, model_contract="custom_text_classification", - saliency_layer="distilbert.embeddings.word_embeddings", rejection_class=None, behavioral_testing=SIMPLE_PERTURBATION_TESTING_CONFIG, ) diff --git a/tests/test_routers/test_config.py b/tests/test_routers/test_config.py index 5c639cd0..08ee27d9 100644 --- a/tests/test_routers/test_config.py +++ b/tests/test_routers/test_config.py @@ -1,5 +1,12 @@ +import os.path + from fastapi import FastAPI -from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR +from starlette.status import ( + HTTP_200_OK, + HTTP_400_BAD_REQUEST, + HTTP_403_FORBIDDEN, + HTTP_500_INTERNAL_SERVER_ERROR, +) from starlette.testclient import TestClient from azimuth.config import SupportedLanguage, config_defaults_per_language @@ -23,7 +30,7 @@ def test_get_default_config(app: FastAPI): "persistent_id": "row_idx", }, "rejection_class": "REJECTION_CLASS", - "artifact_path": "cache", + "artifact_path": os.path.abspath("cache"), "batch_size": 32, "use_cuda": "auto", "large_dask_cluster": False, @@ -64,7 +71,7 @@ def test_get_default_config(app: FastAPI): } ], "uncertainty": {"iterations": 1, "high_epistemic_threshold": 0.1}, - "saliency_layer": None, + "saliency_layer": "auto", "behavioral_testing": { "neutral_token": { "threshold": 1.0, @@ -220,7 +227,7 @@ def test_get_config(app: FastAPI): } ], "rejection_class": None, - "saliency_layer": "distilbert.embeddings.word_embeddings", + "saliency_layer": "auto", "similarity": { "faiss_encoder": "all-MiniLM-L12-v2", "conflicting_neighbors_threshold": 0.9, @@ -242,10 +249,19 @@ def test_get_config(app: FastAPI): def test_update_config(app: FastAPI, wait_for_startup_after): client = TestClient(app) initial_config = client.get("/config").json() - initial_contract = initial_config["model_contract"] - initial_pipelines = initial_config["pipelines"] initial_config_count = len(client.get("/config/history").json()) + resp = client.patch("/config", json={"artifact_path": "something/else"}) + assert resp.status_code == HTTP_403_FORBIDDEN, resp.text + + relative_artifact_path = os.path.relpath(initial_config["artifact_path"]) + assert relative_artifact_path != initial_config["artifact_path"] + resp = client.patch("/config", json={"artifact_path": relative_artifact_path}) + assert resp.status_code == HTTP_200_OK, resp.text + assert resp.json() == initial_config + new_config_count = len(client.get("/config/history").json()) + assert new_config_count == initial_config_count + resp = client.patch( "/config", json={"model_contract": "file_based_text_classification", "pipelines": None}, @@ -309,9 +325,7 @@ def test_update_config(app: FastAPI, wait_for_startup_after): assert not loaded_configs[-1]["config"]["pipelines"] # Revert config change - _ = client.patch( - "/config", json={"model_contract": initial_contract, "pipelines": initial_pipelines} - ) + _ = client.patch("/config", json=initial_config) loaded_configs = client.get("/config/history").json() assert loaded_configs[-1]["config"] == loaded_configs[initial_config_count - 1]["config"] diff --git 
a/webapp/src/components/Analysis/UtterancesTable.tsx b/webapp/src/components/Analysis/UtterancesTable.tsx index 96bd140f..50f07768 100644 --- a/webapp/src/components/Analysis/UtterancesTable.tsx +++ b/webapp/src/components/Analysis/UtterancesTable.tsx @@ -1,13 +1,13 @@ import { ArrowDropDown, Close, + Download, Fullscreen, - GetApp, SvgIconComponent, + Upload, } from "@mui/icons-material"; import FilterAltOutlinedIcon from "@mui/icons-material/FilterAltOutlined"; import MultilineChartIcon from "@mui/icons-material/MultilineChart"; -import UploadIcon from "@mui/icons-material/Upload"; import { Box, Button, @@ -33,6 +33,7 @@ import CopyButton from "components/CopyButton"; import Description from "components/Description"; import OutcomeIcon from "components/Icons/OutcomeIcon"; import TargetIcon from "components/Icons/Target"; +import FileInputButton from "components/FileInputButton"; import Loading from "components/Loading"; import SmartTagFamilyBadge from "components/SmartTagFamilyBadge"; import { Column, RowProps, Table } from "components/Table"; @@ -436,35 +437,28 @@ const UtterancesTable: React.FC = ({ }, ]; - const importProposedActions = (file: File) => { - const fileReader = new FileReader(); - fileReader.onload = ({ target }) => { - if (target) { - const result = target.result as string; - const [header, ...rows] = result.trimEnd().split(/\r?\n/); - if (rows.length === 0) { - raiseErrorToast("There are no records in the CSV file."); - return; - } - if (header !== `${config.columns.persistent_id},proposed_action`) { - raiseErrorToast( - `The CSV file must have column headers ${config.columns.persistent_id} and proposed_action, in that order.` - ); - return; - } + const importProposedActions = (text: string) => { + const [header, ...rows] = text.trimEnd().split(/\r?\n/); + if (rows.length === 0) { + raiseErrorToast("There are no records in the CSV file."); + return; + } + if (header !== `${config.columns.persistent_id},proposed_action`) { + raiseErrorToast( + `The CSV file must have column headers ${config.columns.persistent_id} and proposed_action, in that order.` + ); + return; + } - const body = rows.map((row) => { - const [persistentId, dataAction] = row.split(","); - return { persistentId, dataAction } as UtterancePatch; - }); - updateDataAction({ - ignoreNotFound: true, - body, - ...getUtterancesQueryState, - }); - } - }; - fileReader.readAsText(file); + const body = rows.map((row) => { + const [persistentId, dataAction] = row.split(","); + return { persistentId, dataAction } as UtterancePatch; + }); + updateDataAction({ + ignoreNotFound: true, + body, + ...getUtterancesQueryState, + }); }; const RowLink = (props: RowProps) => ( @@ -484,25 +478,19 @@ const UtterancesTable: React.FC = ({ link="user-guide/exploration-space/utterance-table/" /> - + +); + +export default React.memo(FileInputButton); diff --git a/webapp/src/components/HashChip.tsx b/webapp/src/components/HashChip.tsx new file mode 100644 index 00000000..32f9c98a --- /dev/null +++ b/webapp/src/components/HashChip.tsx @@ -0,0 +1,19 @@ +import { Chip, Tooltip } from "@mui/material"; +import React from "react"; + +const HashChip: React.FC<{ hash: string }> = ({ hash }) => ( + + ({ + backgroundColor: `#${hash}`, + color: theme.palette.getContrastText(`#${hash}`), + cursor: "unset", + fontFamily: "Monospace", + })} + /> + +); + +export default React.memo(HashChip); diff --git a/webapp/src/components/Metrics/DeltaComputationBar.tsx b/webapp/src/components/Metrics/DeltaComputationBar.tsx index 1afd69f9..5ebbff30 100644 --- 
a/webapp/src/components/Metrics/DeltaComputationBar.tsx +++ b/webapp/src/components/Metrics/DeltaComputationBar.tsx @@ -1,5 +1,5 @@ import React from "react"; -import { alpha, Box, Typography } from "@mui/material"; +import { Box, Typography, alpha } from "@mui/material"; import { motion } from "framer-motion"; type Props = { diff --git a/webapp/src/components/NoMaxWidthTooltip.tsx b/webapp/src/components/NoMaxWidthTooltip.tsx index 7a9d0dbc..2c969721 100644 --- a/webapp/src/components/NoMaxWidthTooltip.tsx +++ b/webapp/src/components/NoMaxWidthTooltip.tsx @@ -1,4 +1,4 @@ -import { TooltipProps, Tooltip, tooltipClasses, styled } from "@mui/material"; +import { Tooltip, TooltipProps, styled, tooltipClasses } from "@mui/material"; // From https://mui.com/material-ui/react-tooltip/#variable-width const NoMaxWidthTooltip = styled(({ className, ...props }: TooltipProps) => ( diff --git a/webapp/src/components/PerturbationTestingSummary/PerturbationTestingExporter.tsx b/webapp/src/components/PerturbationTestingSummary/PerturbationTestingExporter.tsx index ad7b821b..ea1ffa18 100644 --- a/webapp/src/components/PerturbationTestingSummary/PerturbationTestingExporter.tsx +++ b/webapp/src/components/PerturbationTestingSummary/PerturbationTestingExporter.tsx @@ -1,4 +1,4 @@ -import { ArrowDropDown, GetApp } from "@mui/icons-material"; +import { ArrowDropDown, Download } from "@mui/icons-material"; import { Button, Menu, MenuItem } from "@mui/material"; import React from "react"; import { QueryPipelineState } from "types/models"; @@ -42,7 +42,7 @@ const PerturbationTestingExporter: React.FC = ({ jobId, pipeline }) => { aria-controls="perturbation-testing-exporter-menu" aria-haspopup="true" onClick={handleClick} - startIcon={} + startIcon={} endIcon={} > Export diff --git a/webapp/src/components/Settings/AutocompleteStringField.tsx b/webapp/src/components/Settings/AutocompleteStringField.tsx index 4df57da9..78f776f8 100644 --- a/webapp/src/components/Settings/AutocompleteStringField.tsx +++ b/webapp/src/components/Settings/AutocompleteStringField.tsx @@ -6,12 +6,13 @@ import { } from "@mui/material"; import React from "react"; import { FieldProps, FIELD_COMMON_PROPS } from "./utils"; +import { DiscardButton } from "./DiscardButton"; const filterOptions = createFilterOptions(); const AutocompleteStringField: React.FC< Omit & FieldProps & { options: string[] } -> = ({ value, onChange, options, disabled, ...props }) => ( +> = ({ value, originalValue, onChange, options, disabled, ...props }) => ( onChange(newValue as string))} + onChange={(_, newValue) => onChange(newValue as string)} // The autoSelect prop would normally cause onChange when onBlur happens, // except it doesn't work when the input is empty, so we do it manually. 
- onBlur={ - onChange && - ((event: React.FocusEvent) => { - if (event.target.value !== value) { - onChange(event.target.value); - } - }) - } + onBlur={(event: React.FocusEvent) => { + if (event.target.value !== value) { + onChange(event.target.value); + } + }} renderInput={(params) => ( + {params.InputProps.endAdornment} + {originalValue !== undefined && originalValue !== value && ( + onChange(originalValue)} + /> + )} + + ), + }} {...FIELD_COMMON_PROPS} {...props} {...(value.trim() === "" && { error: true, helperText: "Set a value" })} diff --git a/webapp/src/components/Settings/CustomObjectFields.tsx b/webapp/src/components/Settings/CustomObjectFields.tsx index 0558b434..a386517e 100644 --- a/webapp/src/components/Settings/CustomObjectFields.tsx +++ b/webapp/src/components/Settings/CustomObjectFields.tsx @@ -7,13 +7,15 @@ const CustomObjectFields: React.FC<{ excludeClassName?: boolean; disabled: boolean; value: CustomObject; + originalValue: CustomObject | undefined; onChange: (update: Partial) => void; -}> = ({ excludeClassName, disabled, value, onChange }) => ( +}> = ({ excludeClassName, disabled, value, originalValue, onChange }) => ( <> {!excludeClassName && ( onChange({ class_name })} /> @@ -22,6 +24,7 @@ const CustomObjectFields: React.FC<{ label="remote" nullable value={value.remote} + originalValue={originalValue?.remote} disabled={disabled} onChange={(remote) => onChange({ remote })} /> @@ -29,12 +32,14 @@ const CustomObjectFields: React.FC<{ array label="args" value={value.args} + originalValue={originalValue?.args} disabled={disabled} onChange={(args) => onChange({ args })} /> onChange({ kwargs })} /> diff --git a/webapp/src/components/Settings/DiscardButton.tsx b/webapp/src/components/Settings/DiscardButton.tsx new file mode 100644 index 00000000..4fbc1158 --- /dev/null +++ b/webapp/src/components/Settings/DiscardButton.tsx @@ -0,0 +1,19 @@ +import { Replay } from "@mui/icons-material"; +import { InputAdornment, Tooltip, IconButton } from "@mui/material"; + +export const DiscardButton: React.FC<{ + title: string; + disabled: boolean; + onClick: () => void; +}> = ({ title, ...props }) => ( + + + + + + + +); diff --git a/webapp/src/components/Settings/JSONField.tsx b/webapp/src/components/Settings/JSONField.tsx index 192813b9..ab6e3eb6 100644 --- a/webapp/src/components/Settings/JSONField.tsx +++ b/webapp/src/components/Settings/JSONField.tsx @@ -1,6 +1,7 @@ import { TextFieldProps, TextField, Typography } from "@mui/material"; import React from "react"; import { FieldProps, FIELD_COMMON_PROPS } from "./utils"; +import { DiscardButton } from "./DiscardButton"; const stringifyJSON = (value: unknown, spaces = 2) => JSON.stringify(value, null, spaces) @@ -13,13 +14,15 @@ const JSONField: React.FC< | ({ array: true } & FieldProps) | ({ array?: false } & FieldProps>) ) -> = ({ value, onChange, array, ...props }) => { +> = ({ value, originalValue, onChange, array, ...props }) => { const [stringValue, setStringValue] = React.useState(stringifyJSON(value)); React.useEffect(() => setStringValue(stringifyJSON(value)), [value]); const [errorText, setErrorText] = React.useState(""); + const stringOriginalValue = originalValue && stringifyJSON(originalValue); + const adornments = array ? 
(["[", "]"] as const) : (["{", "}"] as const); const handleChange = (newStringValue: string) => { @@ -35,15 +38,13 @@ const JSONField: React.FC< } }; - const handleBlur = - onChange && - ((newStringValue: string) => { - try { - onChange(JSON.parse(adornments.join(newStringValue))); - } catch (error) { - setErrorText((error as SyntaxError).message); - } - }); + const handleBlur = (newStringValue: string) => { + try { + onChange(JSON.parse(adornments.join(newStringValue))); + } catch (error) { + setErrorText((error as SyntaxError).message); + } + }; return ( handleChange(event.target.value)} - onBlur={handleBlur && ((event) => handleBlur(event.target.value))} + onBlur={(event) => handleBlur(event.target.value)} InputProps={{ startAdornment: ( @@ -61,9 +62,19 @@ const JSONField: React.FC< ), endAdornment: ( - -  {adornments[1]} - + <> + +  {adornments[1]} + + {originalValue !== undefined && + stringOriginalValue !== stringValue && ( + onChange(originalValue as any)} + /> + )} + ), }} {...props} diff --git a/webapp/src/components/Settings/NumberField.tsx b/webapp/src/components/Settings/NumberField.tsx index 8bcc4658..749e428a 100644 --- a/webapp/src/components/Settings/NumberField.tsx +++ b/webapp/src/components/Settings/NumberField.tsx @@ -1,11 +1,12 @@ -import { InputAdornment, TextField, TextFieldProps } from "@mui/material"; +import { TextField, TextFieldProps, Typography } from "@mui/material"; import React from "react"; import { FieldProps, FIELD_COMMON_PROPS } from "./utils"; +import { DiscardButton } from "./DiscardButton"; const NumberField: React.FC< Omit & FieldProps & { scale?: number; units?: string } -> = ({ value, scale = 1, units, onChange, ...props }) => { +> = ({ value, originalValue, scale = 1, units, onChange, ...props }) => { // Control value with a `string` (and not with a `number`) so that for example // when hitting backspace at the end of `0.01`, you get `0.0` (and not `0`). const [stringValue, setStringValue] = React.useState(String(value * scale)); @@ -16,6 +17,12 @@ const NumberField: React.FC< } }, [value, scale, stringValue]); + const formatNumber = React.useCallback( + (x: number) => + `${x}${units ? ` ${x === 1 ? units.replace(/s$/, "") : units}` : ""}`, + [units] + ); + const helperText = React.useMemo(() => { if (props.inputProps === undefined) { return; @@ -31,10 +38,8 @@ const NumberField: React.FC< } const limit = props.inputProps[prop]; - return `Set ${prop}imum ${limit}${ - units ? ` ${limit === 1 ? units.replace(/s$/, "") : units}` : "" - }`; - }, [props.inputProps, scale, units, value]); + return `Set ${prop}imum ${formatNumber(limit)}`; + }, [props.inputProps, scale, value, formatNumber]); return ( {units} + endAdornment: ( + <> + {units && ( + // If we put text in an , it gets a different font size and color (that doesn't get disabled). 
+ + {units} + + )} + {originalValue !== undefined && originalValue !== value && ( + onChange(originalValue)} + /> + )} + ), }} onChange={(event) => { setStringValue(event.target.value); - onChange && onChange(Number(event.target.value) / scale); + onChange(Number(event.target.value) / scale); }} {...props} /> diff --git a/webapp/src/components/Settings/StringArrayField.tsx b/webapp/src/components/Settings/StringArrayField.tsx index 3337825c..07d67712 100644 --- a/webapp/src/components/Settings/StringArrayField.tsx +++ b/webapp/src/components/Settings/StringArrayField.tsx @@ -1,35 +1,62 @@ import { Autocomplete, - TextField, Chip, + TextField, + TextFieldProps, autocompleteClasses, chipClasses, } from "@mui/material"; import React from "react"; import { FieldProps, FIELD_COMMON_PROPS } from "./utils"; +import { DiscardButton } from "./DiscardButton"; const StringArrayField: React.FC< - FieldProps & { label?: string; units?: string; disabled: boolean } -> = ({ value, onChange, label, units = label || "token", disabled }) => ( - & + FieldProps & { label?: string; units?: string } +> = ({ + value, + originalValue, + onChange, + label, + units = label || "token", + disabled, + ...props +}) => ( + disableClearable freeSolo multiple options={[]} value={value} disabled={disabled} - onChange={onChange && ((_, newValue) => onChange(newValue as string[]))} + onChange={(_, newValue) => onChange(newValue as string[])} renderInput={(params) => ( - Write a{/^[aeiou]/.test(units) && "n"} {units} and press enter + Write a{/^[aeio]/.test(units) && "n"} {units} and press enter } + InputProps={{ + ...params.InputProps, + endAdornment: ( + <> + {params.InputProps.endAdornment} + {originalValue !== undefined && + JSON.stringify(originalValue) !== JSON.stringify(value) && ( + onChange(originalValue)} + /> + )} + + ), + }} + {...props} /> )} renderTags={(value, getTagProps) => diff --git a/webapp/src/components/Settings/StringField.tsx b/webapp/src/components/Settings/StringField.tsx index d5078155..b4c82ea6 100644 --- a/webapp/src/components/Settings/StringField.tsx +++ b/webapp/src/components/Settings/StringField.tsx @@ -1,8 +1,10 @@ import { MenuItem, TextField, TextFieldProps } from "@mui/material"; import { FieldProps, FIELD_COMMON_PROPS } from "./utils"; +import { DiscardButton } from "./DiscardButton"; const StringField = ({ value, + originalValue, onChange, nullable, options, @@ -20,12 +22,22 @@ const StringField = ({ select={Boolean(options)} inputProps={{ sx: { textOverflow: "ellipsis" } }} value={value ?? ""} + required={!nullable && !options} {...(value?.trim() === "" && { error: true, helperText: "Set a value" })} + InputProps={{ + endAdornment: + originalValue === undefined || originalValue === value ? null : ( + onChange(originalValue)} + /> + ), + }} onChange={ - onChange && - (nullable + nullable ? 
(event) => onChange((event.target.value || null) as T) - : (event) => onChange(event.target.value as T)) + : (event) => onChange(event.target.value as T) } {...props} > diff --git a/webapp/src/components/Settings/utils.ts b/webapp/src/components/Settings/utils.ts index 78e64d96..b69195d7 100644 --- a/webapp/src/components/Settings/utils.ts +++ b/webapp/src/components/Settings/utils.ts @@ -1,4 +1,9 @@ -export type FieldProps = { value: T; onChange?: (newValue: T) => void }; +export type FieldProps = { + value: T; + originalValue: T | undefined; + disabled: boolean; + onChange: (newValue: T) => void; +}; export const FIELD_COMMON_PROPS = { size: "small", diff --git a/webapp/src/components/ThresholdPlot.tsx b/webapp/src/components/ThresholdPlot.tsx index 8b250ef6..ae8003d1 100644 --- a/webapp/src/components/ThresholdPlot.tsx +++ b/webapp/src/components/ThresholdPlot.tsx @@ -1,10 +1,10 @@ import { Info } from "@mui/icons-material"; import { - alpha, Box, CircularProgress, Tooltip, Typography, + alpha, } from "@mui/material"; import makeStyles from "@mui/styles/makeStyles"; import React from "react"; diff --git a/webapp/src/pages/Settings.tsx b/webapp/src/pages/Settings.tsx index 0dd149e4..3f420488 100644 --- a/webapp/src/pages/Settings.tsx +++ b/webapp/src/pages/Settings.tsx @@ -1,4 +1,11 @@ -import { Close, Warning } from "@mui/icons-material"; +import { + ArrowDropDown, + Close, + Download, + History, + Upload, + Warning, +} from "@mui/icons-material"; import { Box, Button, @@ -18,10 +25,15 @@ import { InputBaseComponentProps, inputClasses, inputLabelClasses, + Menu, + MenuItem, + Theme, Typography, } from "@mui/material"; import noData from "assets/void.svg"; import AccordionLayout from "components/AccordionLayout"; +import FileInputButton from "components/FileInputButton"; +import HashChip from "components/HashChip"; import Loading from "components/Loading"; import AutocompleteStringField from "components/Settings/AutocompleteStringField"; import CustomObjectFields from "components/Settings/CustomObjectFields"; @@ -35,8 +47,10 @@ import React from "react"; import { useParams } from "react-router-dom"; import { getConfigEndpoint, + getConfigHistoryEndpoint, getDefaultConfigEndpoint, updateConfigEndpoint, + validateConfigEndpoint, } from "services/api"; import { AzimuthConfig, @@ -49,12 +63,33 @@ import { ThresholdConfig, } from "types/api"; import { PickByValue } from "types/models"; +import { downloadBlob } from "utils/api"; import { UNKNOWN_ERROR } from "utils/const"; +import { formatDateISO } from "utils/format"; +import { raiseErrorToast } from "utils/helpers"; type MetricState = MetricDefinition & { name: string }; type ConfigState = Omit & { metrics: MetricState[] }; +const azimuthConfigToConfigState = ({ + metrics, + ...rest +}: AzimuthConfig): ConfigState => ({ + ...rest, + metrics: Object.entries(metrics).map(([name, m]) => ({ name, ...m })), +}); + +const configStateToAzimuthConfig = ({ + metrics, + ...rest +}: Partial): Partial => ({ + ...rest, + ...(metrics && { + metrics: Object.fromEntries(metrics.map(({ name, ...m }) => [name, m])), + }), +}); + const CONFIG_UPDATE_MESSAGE = "Please wait while the config changes are validated."; const PERCENTAGE = { scale: 100, units: "%", inputProps: { min: 0, max: 100 } }; @@ -102,13 +137,19 @@ const USE_CUDA_OPTIONS = ["auto", "true", "false"] as const; type UseCUDAOption = typeof USE_CUDA_OPTIONS[number]; const FIELDS_TRIGGERING_STARTUP_TASKS: (keyof ConfigState)[] = [ + "dataset", + "columns", + "rejection_class", "behavioral_testing", 
"similarity", "dataset_warnings", "syntax", + "model_contract", "pipelines", "uncertainty", + "saliency_layer", "metrics", + "language", ]; type KnownPostprocessor = TemperatureScaling | ThresholdConfig; @@ -120,8 +161,18 @@ const KNOWN_POSTPROCESSORS: { "azimuth.utils.ml.postprocessing.Thresholding": { threshold: 0.5 }, }; -const Columns: React.FC<{ columns?: number }> = ({ columns = 1, children }) => ( - +const Columns: React.FC<{ columns: number | string }> = ({ + columns, + children, +}) => ( + {children} ); @@ -132,11 +183,32 @@ const displaySectionTitle = (section: string) => ( ); -const KeyValuePairs: React.FC = ({ children }) => ( - - {children} - -); +const KeyValuePairs: React.FC<{ + label: string; + disabled: boolean; + keyValuePairs: [string, React.ReactNode][]; +}> = ({ label, disabled, keyValuePairs }) => { + const sx = disabled + ? { color: (theme: Theme) => theme.palette.text.disabled } + : {}; + return ( + + + {label} + + + {keyValuePairs.map(([key, value]) => ( + + + {key}: + + {value} + + ))} + + + ); +}; const updateArrayAt = (array: T[], index: number, update: Partial) => splicedArray(array, index, 1, { ...array[index], ...update }); @@ -152,18 +224,19 @@ const Settings: React.FC = ({ open, onClose }) => { SupportedLanguage | undefined >(); const { data: azimuthConfig } = getConfigEndpoint.useQuery({ jobId }); - const config = React.useMemo(() => { - if (azimuthConfig === undefined) return undefined; - const { metrics, ...rest } = azimuthConfig; - return { - metrics: Object.entries(metrics).map(([name, m]) => ({ name, ...m })), - ...rest, - }; - }, [azimuthConfig]); + const config = React.useMemo( + () => azimuthConfig && azimuthConfigToConfigState(azimuthConfig), + [azimuthConfig] + ); + + const [validateConfig, { isLoading: isValidatingConfig }] = + validateConfigEndpoint.useMutation(); const [updateConfig, { isLoading: isUpdatingConfig }] = updateConfigEndpoint.useMutation(); + const areInputsDisabled = isValidatingConfig || isUpdatingConfig; + const [partialConfig, setPartialConfig] = React.useState< Partial >({}); @@ -194,6 +267,11 @@ const Settings: React.FC = ({ open, onClose }) => { language: language ?? resultingConfig.language, }); + const { data: configHistory } = getConfigHistoryEndpoint.useQuery({ jobId }); + + const [configHistoryAnchor, setConfigHistoryAnchor] = + React.useState(null); + const updatePartialConfig = React.useCallback( (update: Partial) => setPartialConfig((partialConfig) => ({ ...partialConfig, ...update })), @@ -243,6 +321,47 @@ const Settings: React.FC = ({ open, onClose }) => { metricsNames.has("") || resultingConfig.metrics.some(({ class_name }) => class_name.trim() === ""); + const fullHashCount = new Set(configHistory?.map(({ hash }) => hash)).size; + const hashCount = new Set(configHistory?.map(({ hash }) => hash.slice(0, 3))) + .size; + const nameCount = new Set(configHistory?.map(({ config }) => config.name)) + .size; + + // Don't show the hash (hashSize = null) if no two different configs have the same name. + // Show a 6-char hash if there is a collision in the first 3 chars. + // Otherwise, show a 3-char hash. + // Probability of a hash collision with the 3-char hash: + // 10 different configs: 1 % + // 30 different configs: 10 % + // 76 different configs: 50 % + // With the 6-char hash: + // 581 different configs: 1 % + const hashChars = + nameCount === fullHashCount ? null : hashCount === fullHashCount ? 
3 : 6; + + const handleFileRead = (text: string) => { + try { + const body = JSON.parse(text); + validateConfig({ jobId, body }) + .unwrap() + .then((config) => setPartialConfig(azimuthConfigToConfigState(config))) + .catch(() => {}); // Avoid the uncaught error log. Toast already raised by `rtkQueryErrorInterceptor` middleware. + } catch (error) { + raiseErrorToast( + `Something went wrong parsing JSON file\n${ + (error as SyntaxError).message + }` + ); + } + }; + + const handleDownload = () => { + const azimuthConfig = configStateToAzimuthConfig(resultingConfig); + const text = JSON.stringify(azimuthConfig, null, 2); + const blob = new Blob([text], { type: "application/json" }); + downloadBlob(blob, "config.json"); + }; + const renderDialog = (children: React.ReactNode) => ( = ({ open, onClose }) => { open={open} > - - - View and edit certain fields from your config file. Once your - changes are saved, expect some delays for recomputing the affected - tasks. + + + Configuration + {configHistory?.length && ( + <> + + setConfigHistoryAnchor(null)} + > + {configHistory + .map(({ config, created_on, hash }, index) => ( + { + setConfigHistoryAnchor(null); + setPartialConfig(azimuthConfigToConfigState(config)); + }} + > + {config.name} + {hashChars && ( + + )} + + {formatDateISO(new Date(created_on))} + + + )) + .reverse()} + + + )} + } + onFileRead={handleFileRead} + > + Import JSON config file + + { if ( isEmptyPartialConfig || @@ -299,7 +472,7 @@ const Settings: React.FC = ({ open, onClose }) => {