Skip to content

Commit

Permalink
fix domainnet target issue (#118)
Browse files Browse the repository at this point in the history
* update requirements.txt

* update docs

* add desc

* fix domainnet data target

* remove unnecessary comments

* update algo_config.py

---------

Co-authored-by: tremblerz <abhishek.s14@iiits.in>
  • Loading branch information
gautamjajoo and tremblerz authored Oct 20, 2024
1 parent a4bbe4e commit a185d8a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/configs/sys_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,17 +265,21 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE):
"seed": 1,
"num_collaborators": NUM_COLLABORATORS,
"load_existing": False,
"device_ids": get_device_ids(num_users=swarm_users, gpus_available=[1, 2]),
# "algo": get_algo_configs(num_users=swarm_users, algo_configs=default_config_list), # type: ignore
"algos": get_algo_configs(
num_users=swarm_users,
algo_configs=default_config_list,
), # type: ignore
"dump_dir": DUMP_DIR,
"device_ids": get_device_ids(num_users=swarm_users, gpus_available=[3, 4]),
"algo": get_algo_configs(num_users=swarm_users, algo_configs=default_config_list), # type: ignore
# Dataset params
"dset": get_domainnet_support(
swarm_users
), # get_camelyon17_support(fedcentral_client), #get_domainnet_support(fedcentral_client),
"dpath": domainnet_dpath, # wilds_dpath,#domainnet_dpath,
"train_label_distribution": "iid", # Either "iid", "shard" "support",
"test_label_distribution": "iid", # Either "iid" "support",
"samples_per_user": 32,
"samples_per_user": 500,
"test_samples_per_class": 100,
"community_type": "dataset",
"exp_keys": [],
Expand Down
1 change: 1 addition & 0 deletions src/data_loaders/domainnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(self, data_paths, data_labels, transforms, domain_name, cache=False
self.transforms = transforms
self.domain_name = domain_name
self.cached_data = []
self.targets = data_labels

if cache:
for idx, _ in enumerate(data_paths):
Expand Down

0 comments on commit a185d8a

Please sign in to comment.