Skip to content

Commit

Permalink
Re-release 2.8.0 (#593)
Browse files Browse the repository at this point in the history
  • Loading branch information
JosephMarinier authored Jul 6, 2023
2 parents fa7fc5e + e7ccc47 commit 0c3928a
Show file tree
Hide file tree
Showing 50 changed files with 933 additions and 415 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@ Released changes are shown in the
### Removed

### Fixed
- Fixed importing the same proposed actions CSV file twice

### Security
64 changes: 33 additions & 31 deletions azimuth/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def create_app() -> FastAPI:
Returns:
FastAPI.
"""
app = FastAPI(
api = FastAPI(
title="Azimuth API",
description="Azimuth API",
version="1.0",
Expand All @@ -171,101 +171,103 @@ def create_app() -> FastAPI:
)

# Setup routes
from azimuth.routers.app import router as app_router
from azimuth.routers.class_overlap import router as class_overlap_router
from azimuth.routers.config import router as config_router
from azimuth.routers.custom_utterances import router as custom_utterances_router
from azimuth.routers.dataset_warnings import router as dataset_warnings_router
from azimuth.routers.export import router as export_router
from azimuth.routers.model_performance.confidence_histogram import (
router as confidence_histogram_router,
from azimuth.routers import (
app,
class_overlap,
config,
custom_utterances,
dataset_warnings,
export,
top_words,
utterances,
)
from azimuth.routers.model_performance import (
confidence_histogram,
confusion_matrix,
metrics,
outcome_count,
utterance_count,
)
from azimuth.routers.model_performance.confusion_matrix import router as confusion_matrix_router
from azimuth.routers.model_performance.metrics import router as metrics_router
from azimuth.routers.model_performance.outcome_count import router as outcome_count_router
from azimuth.routers.model_performance.utterance_count import router as utterance_count_router
from azimuth.routers.top_words import router as top_words_router
from azimuth.routers.utterances import router as utterances_router
from azimuth.utils.routers import require_application_ready, require_available_model

api_router = APIRouter()
api_router.include_router(app_router, prefix="", tags=["App"])
api_router.include_router(config_router, prefix="/config", tags=["Config"])
api_router.include_router(app.router, prefix="", tags=["App"])
api_router.include_router(config.router, prefix="/config", tags=["Config"])
api_router.include_router(
class_overlap_router,
class_overlap.router,
prefix="/dataset_splits/{dataset_split_name}/class_overlap",
tags=["Class Overlap"],
dependencies=[Depends(require_application_ready)],
)
api_router.include_router(
confidence_histogram_router,
confidence_histogram.router,
prefix="/dataset_splits/{dataset_split_name}/confidence_histogram",
tags=["Confidence Histogram"],
dependencies=[Depends(require_application_ready), Depends(require_available_model)],
)
api_router.include_router(
dataset_warnings_router,
dataset_warnings.router,
prefix="/dataset_warnings",
tags=["Dataset Warnings"],
dependencies=[Depends(require_application_ready)],
)
api_router.include_router(
metrics_router,
metrics.router,
prefix="/dataset_splits/{dataset_split_name}/metrics",
tags=["Metrics"],
dependencies=[Depends(require_application_ready), Depends(require_available_model)],
)
api_router.include_router(
outcome_count_router,
outcome_count.router,
prefix="/dataset_splits/{dataset_split_name}/outcome_count",
tags=["Outcome Count"],
dependencies=[Depends(require_application_ready), Depends(require_available_model)],
)
api_router.include_router(
utterance_count_router,
utterance_count.router,
prefix="/dataset_splits/{dataset_split_name}/utterance_count",
tags=["Utterance Count"],
dependencies=[Depends(require_application_ready)],
)
api_router.include_router(
utterances_router,
utterances.router,
prefix="/dataset_splits/{dataset_split_name}/utterances",
tags=["Utterances"],
dependencies=[Depends(require_application_ready)],
)
api_router.include_router(
export_router,
export.router,
prefix="/export",
tags=["Export"],
dependencies=[Depends(require_application_ready)],
)
api_router.include_router(
custom_utterances_router,
custom_utterances.router,
prefix="/custom_utterances",
tags=["Custom Utterances"],
dependencies=[Depends(require_application_ready)],
)
api_router.include_router(
top_words_router,
top_words.router,
prefix="/dataset_splits/{dataset_split_name}/top_words",
tags=["Top Words"],
dependencies=[Depends(require_application_ready), Depends(require_available_model)],
)
api_router.include_router(
confusion_matrix_router,
confusion_matrix.router,
prefix="/dataset_splits/{dataset_split_name}/confusion_matrix",
tags=["Confusion Matrix"],
dependencies=[Depends(require_application_ready), Depends(require_available_model)],
)
app.include_router(api_router)
api.include_router(api_router)

app.add_middleware(
api.add_middleware(
CORSMiddleware,
allow_methods=["*"],
allow_headers=["*"],
)

return app
return api


def load_dataset_split_managers_from_config(
Expand Down
8 changes: 6 additions & 2 deletions azimuth/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,11 +284,15 @@ def get_project_hash(self):

class ArtifactsConfig(AzimuthBaseSettings, extra=Extra.ignore):
artifact_path: str = Field(
"cache",
default_factory=lambda: os.path.abspath("cache"),
description="Where to store artifacts (Azimuth config history, HDF5 files, HF datasets).",
exclude_from_cache=True,
)

@validator("artifact_path")
def validate_artifact_path(cls, artifact_path):
return os.path.abspath(artifact_path)

def get_config_history_path(self):
return f"{self.artifact_path}/config_history.jsonl"

Expand Down Expand Up @@ -329,7 +333,7 @@ class ModelContractConfig(CommonFieldsConfig):
# Uncertainty configuration
uncertainty: UncertaintyOptions = UncertaintyOptions()
# Layer name where to calculate the gradients, normally the word embeddings layer.
saliency_layer: Optional[str] = Field(None, nullable=True)
saliency_layer: Union[Literal["auto"], str, None] = Field("auto", nullable=True)

@validator("pipelines", pre=True)
def _check_pipeline_names(cls, pipeline_definitions):
Expand Down
21 changes: 12 additions & 9 deletions azimuth/modules/model_contracts/hf_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from azimuth.types.task import PredictionResponse, SaliencyResponse
from azimuth.utils.ml.mc_dropout import MCDropout
from azimuth.utils.ml.saliency import (
find_word_embeddings_layer,
get_saliency,
register_embedding_gradient_hook,
register_embedding_list_hook,
Expand Down Expand Up @@ -127,6 +128,7 @@ def saliency(self, batch: Dataset) -> List[SaliencyResponse]:
raise ValueError("This method should not be called when saliency_layer is not defined.")

hf_pipeline = self.get_model()
hf_model = hf_pipeline.model

inputs = hf_pipeline.tokenizer(
batch[self.config.columns.text_input],
Expand All @@ -140,18 +142,19 @@ def saliency(self, batch: Dataset) -> List[SaliencyResponse]:
inputs["input_ids"] = inputs["input_ids"].to(hf_pipeline.device)
inputs["attention_mask"] = inputs["attention_mask"].to(hf_pipeline.device)

logits = hf_pipeline.model(**inputs)[0]
logits = hf_model(**inputs)[0]
output = torch.softmax(logits, dim=1).detach().cpu().numpy()
prediction = output.argmax(-1)

embeddings_list: List[np.ndarray] = []
handle = register_embedding_list_hook(
hf_pipeline.model, embeddings_list, self.saliency_layer
embedding_layer = (
hf_model.base_model.get_input_embeddings()
if self.saliency_layer == "auto"
else find_word_embeddings_layer(hf_model, self.saliency_layer)
)
embeddings_list: List[np.ndarray] = []
handle = register_embedding_list_hook(embeddings_list, embedding_layer)
embeddings_gradients: List[np.ndarray] = []
hook = register_embedding_gradient_hook(
hf_pipeline.model, embeddings_gradients, self.saliency_layer
)
hook = register_embedding_gradient_hook(embeddings_gradients, embedding_layer)

filter_class = self.mod_options.filter_class
selected_classes = (
Expand All @@ -162,8 +165,8 @@ def saliency(self, batch: Dataset) -> List[SaliencyResponse]:
)

# Do backward pass to compute gradients
hf_pipeline.model.zero_grad()
_loss = hf_pipeline.model(**inputs)[0] # loss is at index 0 when passing labels
hf_model.zero_grad()
_loss = hf_model(**inputs)[0] # loss is at index 0 when passing labels
_loss.backward()
handle.remove()
hook.remove()
Expand Down
60 changes: 40 additions & 20 deletions azimuth/routers/config.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# Copyright ServiceNow, Inc. 2021 – 2022
# This source code is licensed under the Apache 2.0 license found in the LICENSE file
# in the root directory of this source tree.
from typing import Any, Dict, List
from typing import Dict, List

import structlog
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from pydantic import ValidationError
from starlette.status import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR
from starlette.status import (
HTTP_400_BAD_REQUEST,
HTTP_403_FORBIDDEN,
HTTP_500_INTERNAL_SERVER_ERROR,
)

from azimuth.app import (
get_config,
Expand Down Expand Up @@ -73,6 +76,25 @@ def get_config_def(
return config


@router.patch(
"/validate",
summary="Validate config",
description="Validate the given partial config update and return the complete config that would"
" result if this update was applied.",
response_model=AzimuthConfig,
dependencies=[Depends(require_editable_config)],
)
def validate_config(
config: AzimuthConfig = Depends(get_config),
partial_config: Dict = Body(...),
) -> AzimuthConfig:
new_config = update_config(old_config=config, partial_config=partial_config)

assert_permission_to_update_config(old_config=config, new_config=new_config)

return new_config


@router.patch(
"",
summary="Update config",
Expand All @@ -85,19 +107,17 @@ def patch_config(
config: AzimuthConfig = Depends(get_config),
partial_config: Dict = Body(...),
) -> AzimuthConfig:
if attribute_changed_in_config("artifact_path", partial_config, config):
raise HTTPException(
HTTP_400_BAD_REQUEST,
detail="Cannot edit artifact_path, otherwise config history would become inconsistent.",
)
log.info(f"Validating config change with {partial_config}.")
new_config = update_config(old_config=config, partial_config=partial_config)

assert_permission_to_update_config(old_config=config, new_config=new_config)

if new_config.large_dask_cluster != config.large_dask_cluster:
cluster = default_cluster(new_config.large_dask_cluster)
else:
cluster = task_manager.cluster

try:
log.info(f"Validating config change with {partial_config}.")
new_config = update_config(old_config=config, partial_config=partial_config)
if attribute_changed_in_config("large_dask_cluster", partial_config, config):
cluster = default_cluster(partial_config["large_dask_cluster"])
else:
cluster = task_manager.cluster
run_startup_tasks(new_config, cluster)
log.info(f"Config successfully updated with {partial_config}.")
except Exception as e:
Expand All @@ -107,8 +127,6 @@ def patch_config(
log.info("Config update cancelled.")
if isinstance(e, AzimuthValidationError):
raise HTTPException(HTTP_400_BAD_REQUEST, detail=str(e))
if isinstance(e, ValidationError):
raise
else:
raise HTTPException(
HTTP_500_INTERNAL_SERVER_ERROR, detail="Error when loading the new config."
Expand All @@ -117,7 +135,9 @@ def patch_config(
return new_config


def attribute_changed_in_config(
attribute: str, partial_config: Dict[str, Any], config: AzimuthConfig
) -> bool:
return attribute in partial_config and partial_config[attribute] != getattr(config, attribute)
def assert_permission_to_update_config(*, old_config: AzimuthConfig, new_config: AzimuthConfig):
if old_config.artifact_path != new_config.artifact_path:
raise HTTPException(
HTTP_403_FORBIDDEN,
detail="Cannot edit artifact_path, otherwise config history would become inconsistent.",
)
21 changes: 7 additions & 14 deletions azimuth/utils/ml/saliency.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import structlog
from torch.nn import Embedding

from azimuth.types.general.module_arguments import GradientCalculation

Expand Down Expand Up @@ -36,14 +37,13 @@ def find_word_embeddings_layer(model: Any, layer_name: str) -> Any:


def register_embedding_list_hook(
model: Any, embeddings_list: List[np.ndarray], layer_name: str
embeddings_list: List[np.ndarray], embedding_layer: Embedding
) -> Any:
"""Register hook to get the embedding values from model.
Args:
model: Model.
embeddings_list: Variable to save values.
layer_name: Name of the embedding layer.
embedding_layer: Embedding layer on which to compute the saliency map.
Returns:
Hook.
Expand All @@ -52,21 +52,17 @@ def register_embedding_list_hook(
def forward_hook(module, inputs, output):
embeddings_list.append(output.detach().cpu().clone().numpy())

embedding_layer = find_word_embeddings_layer(model, layer_name)
handle = embedding_layer.register_forward_hook(forward_hook)

return handle
return embedding_layer.register_forward_hook(forward_hook)


def register_embedding_gradient_hook(
model: Any, embeddings_gradients: List[np.ndarray], layer_name: str
embeddings_gradients: List[np.ndarray], embedding_layer: Embedding
) -> Any:
"""Register hook to get the gradient values from the embedding layer.
Args:
model: Model.
embeddings_gradients: Variable to save values.
layer_name: Name of the embedding layer.
embedding_layer: Embedding layer on which to compute the saliency map.
Returns:
Hook.
Expand All @@ -76,10 +72,7 @@ def register_embedding_gradient_hook(
def hook_layers(module, grad_in, grad_out):
embeddings_gradients.append(grad_out[0].detach().cpu().clone().numpy())

embedding_layer = find_word_embeddings_layer(model, layer_name)
hook = embedding_layer.register_full_backward_hook(hook_layers)

return hook
return embedding_layer.register_full_backward_hook(hook_layers)


def get_saliency(
Expand Down
1 change: 0 additions & 1 deletion config/development/clinc/conf.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
}
},
"batch_size": 64,
"saliency_layer": "distilbert.embeddings.word_embeddings",
"model_contract": "hf_text_classification",
"rejection_class": "NO_INTENT"
}
Loading

0 comments on commit 0c3928a

Please sign in to comment.