Skip to content

Commit

Permalink
Merge pull request #102 from nhsx/harry/streamlit_viz
Browse files Browse the repository at this point in the history
Add `dashboard` module
  • Loading branch information
HarrisonWilde authored Aug 30, 2023
2 parents b6d1190 + 0ff8767 commit e7385fa
Show file tree
Hide file tree
Showing 32 changed files with 1,629 additions and 588 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ If you intend on contributing or working with the codebase directly, or if you w

*Note that you can omit the `poetry run` part and just type `nhssynth` if you followed the optional steps above to manage and activate your own virtual environment, or if you have executed `poetry shell` beforehand.*

- Through building the package with `poetry build` and using it in an existing project (`from nhssynth.modules... import ...`). You can then actively develop the package and test it.
- Through directly importing parts of the package to use in an existing project (`from nhssynth.modules... import ...`).

### Usage

Expand Down
8 changes: 6 additions & 2 deletions config/test_pipeline.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
seed: 1
experiment_name: test
run_type: pipeline
dataloader:
collapse_yaml: false
collapse_yaml: true
write_csv: true
model:
architecture:
- VAE
- DPVAE
num_epochs: 30
target_epsilon: 1.0
target_epsilon:
- 1.0
- 3.0
- 6.0
max_grad_norm: 5.0
secure_mode: false
repeats: 3
Expand Down
1,238 changes: 862 additions & 376 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ classifiers = [
"Docs" = "https://nhsx.github.io/NHSSynth"

[tool.poetry.dependencies]
python = ">=3.9,<3.12"
python = ">=3.9,<3.9.7 || >3.9.7,<3.12"
pandas = "^2.0.1"
scikit-learn = "^1.2.2"
tqdm = "^4.65.0"
Expand All @@ -31,6 +31,8 @@ networkx = "^3.1"
pyvis = "^0.3.2"
sdmetrics = "^0.11.0"
tornado = "^6.3.3"
streamlit = "^1.25.0"
plotly = "^5.16.1"

[tool.poetry.scripts]
nhssynth = "nhssynth.cli.__main__:__main__"
Expand Down
28 changes: 18 additions & 10 deletions src/nhssynth/cli/common_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,7 @@ def get_core_parser(overrides=False) -> argparse.ArgumentParser:
"--experiment-name",
type=str,
default=TIME,
help=f"name the experiment run to affect logging, config, and default-behaviour io",
)
core_grp.add_argument(
"-s",
"--seed",
type=int,
help="specify a seed for reproducibility, this is a recommended option for reproducibility",
help=f"name the experiment run to affect logging, config, and default-behaviour i/o",
)
core_grp.add_argument(
"--save-config",
Expand All @@ -41,6 +35,19 @@ def get_core_parser(overrides=False) -> argparse.ArgumentParser:
return core


def get_seed_parser(overrides=False) -> argparse.ArgumentParser:
"""Create the common parser for the seed."""
parser = argparse.ArgumentParser(add_help=False)
parser_grp = parser.add_argument_group(title="options")
parser_grp.add_argument(
"-s",
"--seed",
type=int,
help="specify a seed for reproducibility, this is a recommended option for reproducibility",
)
return parser


COMMON_TITLE: Final = "starting any of the following args with `_` defaults to a suffix on DATASET (e.g. `_metadata` -> `<DATASET>_metadata`);\nall filenames are relative to `experiments/<EXPERIMENT_NAME>/` unless otherwise stated"


Expand Down Expand Up @@ -71,6 +78,7 @@ def get_parser(overrides: bool = False) -> argparse.ArgumentParser:

COMMON_PARSERS: Final = {
"core": get_core_parser,
"seed": get_seed_parser,
"metadata": suffix_parser_generator(
"metadata",
"filename of the metadata, NOTE that `dataloader` attempts to read this from `<DATA_DIR>`",
Expand All @@ -95,12 +103,12 @@ def get_parser(overrides: bool = False) -> argparse.ArgumentParser:
"synthetic",
"filename of the synthetic data",
),
"experiment_bundle": suffix_parser_generator(
"experiment_bundle",
"experiments": suffix_parser_generator(
"experiments",
"filename of the experiment bundle, i.e. the collection of all seeds, models, and synthetic datasets",
),
"evaluation_bundle": suffix_parser_generator(
"evaluation_bundle",
"filename of the (collection of) evaluation(s) for a given `experiment_bundle`",
"filename of the (collection of) evaluation(s) for a given set of `experiments`",
),
}
27 changes: 22 additions & 5 deletions src/nhssynth/cli/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Read, write and process config files, including handling of module-specific / common config overrides."""
import argparse
import warnings
from importlib.metadata import version as ver
from typing import Any, Callable

import yaml
Expand Down Expand Up @@ -68,6 +69,14 @@ def read_config(

valid_run_types = [x for x in all_subparsers.keys() if x != "config"]

version = config_dict.pop("version", None)
if version and version != version("nhssynth"):
warnings.warn(
f"This config file's specified version ({version}) does not match the currently installed version of nhssynth ({version('nhssynth')}), results may differ."
)
elif not version:
version = ver("nhssynth")

run_type = config_dict.pop("run_type", None)

if run_type == "pipeline":
Expand Down Expand Up @@ -113,6 +122,7 @@ def read_config(
# Run the appropriate execution function(s)
if not new_args.seed:
warnings.warn("No seed has been specified, meaning the results of this run may not be reproducible.")
new_args.version = version
new_args.modules_to_run = modules_to_run
new_args.module_handover = {}
for module in new_args.modules_to_run:
Expand Down Expand Up @@ -184,11 +194,18 @@ def assemble_config(
for module_name in modules_to_run:
for k in args_dict.copy().keys():
# We want to keep dataset, experiment_name, seed and save_config at the top-level as they are core args
if k in module_args[module_name] and k not in {"dataset", "experiment_name", "seed", "save_config"}:
if out_dict.get(module_name):
out_dict[module_name].update({k: args_dict.pop(k)})
else:
out_dict[module_name] = {k: args_dict.pop(k)}
if k in module_args[module_name] and k not in {
"version",
"dataset",
"experiment_name",
"seed",
"save_config",
}:
if module_name not in out_dict:
out_dict[module_name] = {}
v = args_dict.pop(k)
if v is not None:
out_dict[module_name][k] = v

# Assemble the final dictionary in YAML-compliant form
return {**({"run_type": run_type} if run_type else {}), **args_dict, **out_dict}
Expand Down
8 changes: 8 additions & 0 deletions src/nhssynth/cli/model_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,43 +15,51 @@ def add_vae_args(group: argparse._ArgumentGroup, overrides: bool = False) -> Non
group.add_argument(
"--encoder-latent-dim",
type=int,
nargs="+",
help="the latent dimension of the encoder",
)
group.add_argument(
"--encoder-hidden-dim",
type=int,
nargs="+",
help="the hidden dimension of the encoder",
)
group.add_argument(
"--encoder-activation",
type=str,
nargs="+",
choices=list(ACTIVATION_FUNCTIONS.keys()),
help="the activation function of the encoder",
)
group.add_argument(
"--encoder-learning-rate",
type=float,
nargs="+",
help="the learning rate for the encoder",
)
group.add_argument(
"--decoder-latent-dim",
type=int,
nargs="+",
help="the latent dimension of the decoder",
)
group.add_argument(
"--decoder-hidden-dim",
type=int,
nargs="+",
help="the hidden dimension of the decoder",
)
group.add_argument(
"--decoder-activation",
type=str,
nargs="+",
choices=list(ACTIVATION_FUNCTIONS.keys()),
help="the activation function of the decoder",
)
group.add_argument(
"--decoder-learning-rate",
type=float,
nargs="+",
help="the learning rate for the decoder",
)
group.add_argument(
Expand Down
25 changes: 20 additions & 5 deletions src/nhssynth/cli/module_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,14 @@ def add_model_args(parser: argparse.ArgumentParser, group_title: str, overrides:
group.add_argument(
"--batch-size",
type=int,
nargs="+",
default=32,
help="the batch size for the model",
)
group.add_argument(
"--num-epochs",
type=int,
nargs="+",
default=100,
help="number of epochs to train for",
)
Expand Down Expand Up @@ -139,25 +141,23 @@ def add_model_args(parser: argparse.ArgumentParser, group_title: str, overrides:
help="the number of samples to generate from the model, defaults to the size of the original dataset",
)
privacy_group = parser.add_argument_group(title="model privacy options")
privacy_group.add_argument(
"--non-private",
action="store_true",
help="train the model in a non-private way",
)
privacy_group.add_argument(
"--target-epsilon",
type=float,
nargs="+",
default=1.0,
help="the target epsilon for differential privacy",
)
privacy_group.add_argument(
"--target-delta",
type=float,
nargs="+",
help="the target delta for differential privacy, defaults to `1 / len(dataset)` if not specified",
)
privacy_group.add_argument(
"--max-grad-norm",
type=float,
nargs="+",
default=5.0,
help="the clipping threshold for gradients (only relevant under differential privacy)",
)
Expand Down Expand Up @@ -266,3 +266,18 @@ def add_plotting_args(parser: argparse.ArgumentParser, group_title: str, overrid
action="store_true",
help="plot the t-SNE embeddings of the real and synthetic data",
)


def add_dashboard_args(parser: argparse.ArgumentParser, group_title: str, overrides: bool = False):
group = parser.add_argument_group(title=group_title)
group.add_argument(
"--file-size-limit",
type=str,
default="1000",
help="the maximum file size to upload in MB",
)
group.add_argument(
"--dont-load",
action="store_true",
help="don't attempt to automatically load data into the dashboard",
)
32 changes: 25 additions & 7 deletions src/nhssynth/cli/module_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@

from nhssynth.cli.common_arguments import COMMON_PARSERS
from nhssynth.cli.module_arguments import *
from nhssynth.modules import dataloader, evaluation, model, plotting, structure
from nhssynth.modules import (
dashboard,
dataloader,
evaluation,
model,
plotting,
structure,
)


class ModuleConfig:
Expand All @@ -26,18 +33,19 @@ def __init__(
description: str,
help: str,
common_parsers: Optional[list[str]] = None,
no_seed: bool = False,
) -> None:
self.func = func
self.add_args = add_args
self.description = description
self.help = help
self.common_parsers = ["core"]
self.common_parsers = ["core", "seed"] if not no_seed else ["core"]
if common_parsers:
assert set(common_parsers) <= COMMON_PARSERS.keys(), "Invalid common parser(s) specified."
# merge the below two assert statements
assert (
"core" not in common_parsers
), "The and 'core' parser group is automatically added to all modules, remove it from the ModuleConfig."
"core" not in common_parsers and "seed" not in common_parsers
), "The 'seed' and 'core' parser groups are automatically added to all modules, remove the from `ModuleConfig`s."
self.common_parsers += common_parsers

def __call__(self, args: argparse.Namespace) -> argparse.Namespace:
Expand Down Expand Up @@ -84,7 +92,7 @@ def add_config_args(parser: argparse.ArgumentParser) -> None:
"dataloader",
"model",
"evaluation",
"plotting",
"dashboard",
] # NOTE this determines the order of a pipeline run

MODULE_MAP: Final = {
Expand All @@ -106,14 +114,14 @@ def add_config_args(parser: argparse.ArgumentParser) -> None:
add_args=add_model_args,
description="run the model architecture module, to train a synthetic data generator",
help="train a model",
common_parsers=["transformed", "metatransformer", "synthetic", "experiment_bundle"],
common_parsers=["transformed", "metatransformer", "synthetic", "experiments"],
),
"evaluation": ModuleConfig(
func=evaluation.run,
add_args=add_evaluation_args,
description="run the evaluation module, to evaluate an experiment",
help="evaluate an experiment",
common_parsers=["sdv_metadata", "typed", "experiment_bundle", "evaluation_bundle"],
common_parsers=["sdv_metadata", "typed", "experiments", "evaluation_bundle"],
),
"plotting": ModuleConfig(
func=plotting.run,
Expand All @@ -122,6 +130,14 @@ def add_config_args(parser: argparse.ArgumentParser) -> None:
help="generate plots",
common_parsers=["typed", "evaluation_bundle"],
),
"dashboard": ModuleConfig(
func=dashboard.run,
add_args=add_dashboard_args,
description="run the dashboard module, to produce a streamlit dashboard",
help="start up a streamlit dashboard to view the results of an evaluation",
common_parsers=["typed", "experiments", "evaluation_bundle"],
no_seed=True,
),
"pipeline": ModuleConfig(
func=run_pipeline,
add_args=add_pipeline_args,
Expand Down Expand Up @@ -149,6 +165,8 @@ def get_parent_parsers(name: str, module_parsers: list[str]) -> list[argparse.Ar
"""Get a list of parent parsers for a given module, based on the module's `common_parsers` attribute."""
if name in {"pipeline", "config"}:
return [p(name == "config") for p in COMMON_PARSERS.values()]
elif name == "dashboard":
return [COMMON_PARSERS[pn](True) for pn in module_parsers]
else:
return [COMMON_PARSERS[pn]() for pn in module_parsers]

Expand Down
9 changes: 7 additions & 2 deletions src/nhssynth/cli/run.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import argparse
import time
import warnings

from nhssynth.cli.config import get_modules_to_run, read_config, write_config
from nhssynth.cli.module_setup import MODULE_MAP, add_subparser
from nhssynth.common.strings import format_timedelta


def run(sysargv) -> None:
print("Starting up the NHSSynth CLI! 🚀\n")
start_time = time.time()

parser = argparse.ArgumentParser(
prog="nhssynth",
description="CLI for preparing, training and evaluating a synthetic data generator.",
Expand All @@ -22,7 +27,7 @@ def run(sysargv) -> None:

executor = vars(args).get("func", None)
if executor:
if not args.seed:
if hasattr(args, "seed") and not args.seed:
warnings.warn("No seed has been specified, meaning the results of this run may not be reproducible.")
args.modules_to_run = get_modules_to_run(executor)
args.module_handover = {}
Expand All @@ -35,4 +40,4 @@ def run(sysargv) -> None:
if args.save_config:
write_config(args, all_subparsers)

print("Finished!")
print(f"Finished! 🎉 (Total run time: {format_timedelta(start_time, time.time())})")
Loading

0 comments on commit e7385fa

Please sign in to comment.