From a185d8a8bc4513c98f2467ffa56722d8c91dd8a1 Mon Sep 17 00:00:00 2001 From: Gautam Jajoo Date: Sun, 20 Oct 2024 22:05:47 +0530 Subject: [PATCH] fix domainnet target issue (#118) * update requirements.txt * update docs * add desc * fix domainnet data target * remove unnecessary comments * update algo_config.py --------- Co-authored-by: tremblerz --- src/configs/sys_config.py | 10 +++++++--- src/data_loaders/domainnet.py | 1 + 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index 36fb607..5baa403 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -265,9 +265,13 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "seed": 1, "num_collaborators": NUM_COLLABORATORS, "load_existing": False, + "device_ids": get_device_ids(num_users=swarm_users, gpus_available=[1, 2]), + # "algo": get_algo_configs(num_users=swarm_users, algo_configs=default_config_list), # type: ignore + "algos": get_algo_configs( + num_users=swarm_users, + algo_configs=default_config_list, + ), # type: ignore "dump_dir": DUMP_DIR, - "device_ids": get_device_ids(num_users=swarm_users, gpus_available=[3, 4]), - "algo": get_algo_configs(num_users=swarm_users, algo_configs=default_config_list), # type: ignore # Dataset params "dset": get_domainnet_support( swarm_users @@ -275,7 +279,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "dpath": domainnet_dpath, # wilds_dpath,#domainnet_dpath, "train_label_distribution": "iid", # Either "iid", "shard" "support", "test_label_distribution": "iid", # Either "iid" "support", - "samples_per_user": 32, + "samples_per_user": 500, "test_samples_per_class": 100, "community_type": "dataset", "exp_keys": [], diff --git a/src/data_loaders/domainnet.py b/src/data_loaders/domainnet.py index 4354524..0edca0f 100644 --- a/src/data_loaders/domainnet.py +++ b/src/data_loaders/domainnet.py @@ -43,6 +43,7 @@ def __init__(self, data_paths, data_labels, transforms, domain_name, cache=False self.transforms = transforms self.domain_name = domain_name self.cached_data = [] + self.targets = data_labels if cache: for idx, _ in enumerate(data_paths):