Refactor code for performance #46

Closed
wants to merge 1 commit into from
2 changes: 1 addition & 1 deletion common/batch.py
@@ -53,7 +53,7 @@ def batch_size(self) -> int:
class DataclassBatch(BatchBase):
@classmethod
def feature_names(cls):
return list(cls.__dataclass_fields__.keys())
return [* cls.__dataclass_fields__]

def as_dict(self):
return {
2 changes: 1 addition & 1 deletion common/checkpointing/snapshot.py
@@ -109,7 +109,7 @@ def load_snapshot_to_weight(
"""
start_time = time.time()
manifest = embedding_snapshot.get_manifest()
for path in manifest.keys():
for path in manifest:
if path.startswith("0") and snapshot_emb_name in path:
snapshot_path_to_load = path
embedding_snapshot.read_object(snapshot_path_to_load, weight_tensor)
4 changes: 2 additions & 2 deletions common/log_weights.py
@@ -23,7 +23,7 @@ def weights_to_log(
if not how_to_log:
return

to_log = dict()
to_log = {}
named_parameters = model.named_parameters()
logging.info(f"Using DMP: {isinstance(model, DistributedModelParallel)}")
if isinstance(model, DistributedModelParallel):
@@ -58,7 +58,7 @@ def log_ebc_norms(
i.e. model.embeddings.ebc.embedding_bags.meta__user_id.weight
sample_size: Limits number of rows per rank to compute average on to avoid OOM.
"""
norm_logs = dict()
norm_logs = {}
for emb_key in ebc_keys:
norms = (torch.ones(1, dtype=torch.float32) * -1).to(torch.device(f"cuda:{dist.get_rank()}"))
if emb_key in model_state_dict:
2 changes: 1 addition & 1 deletion optimizers/optimizer.py
@@ -61,7 +61,7 @@ def __init__(
):
self.optimizer = optimizer
self.lr_dict = lr_dict
self.group_names = list(self.lr_dict.keys())
self.group_names = [* self.lr_dict]

This way compromises readability.

Author (@wiseaidev)

Yeah, there is a tradeoff between readability and performance, and it is up to the team to decide which to maximize. You may argue that you would prefer readability and maintainability over that small performance gain, and that's totally fine.

There's no practical performance benefit from doing it this way.
It's trivial to benchmark the difference:

In [4]: def test1():
   ...:     return list(a.keys())
   ...: 

In [5]: def test2():
   ...:     return [*a]
   ...: 

In [6]: %timeit -n 10_000_000 test2
23.1 ns ± 0.214 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

In [7]: %timeit -n 10_000_000 test1
23.1 ns ± 0.246 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

(using IPython)

@wiseaidev (Author), Apr 1, 2023

>>> from timeit import timeit
>>> timeit("[]")
0.09845466999831842
>>> timeit("list()")
0.22419986899876676

Also, a smaller memory footprint:

>>> import dis
>>> dis.dis("[]")
  1           0 BUILD_LIST               0
              2 RETURN_VALUE
>>> dis.dis("list()")
  1           0 LOAD_NAME                0 (list)
              2 CALL_FUNCTION            0
              4 RETURN_VALUE

@KristinnVikarJ, Apr 1, 2023

That is simply not the same code at all; you're comparing apples to oranges.
My benchmark exercises the exact change you made, and it shows no performance improvement in the slightest.

In fact, if you run the benchmark using Python 3.11, the older method is consistently faster.

Python 3.11.2 (main, Feb  8 2023, 14:49:25) [GCC 11.3.0]
Type 'copyright', 'credits' or 'license' for more information
IPython 8.12.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: a = {"a": 5, "k": 3, "asdf": 53}

In [2]: def test1():
   ...:     return list(a.keys())
   ...: 

In [3]: def test2():
   ...:     return [*a]
   ...: 

In [4]: %timeit -n 10_000_000 test2
13.7 ns ± 1.16 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [5]: %timeit -n 10_000_000 test1
13 ns ± 1.15 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [6]: %timeit -n 10_000_000 test1
13 ns ± 1.22 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [7]: %timeit -n 10_000_000 test2
14.1 ns ± 1.15 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
In [12]: dis.dis(test1)
  1           0 RESUME                   0

  2           2 LOAD_GLOBAL              1 (NULL + list)
             14 LOAD_GLOBAL              2 (a)
             26 LOAD_METHOD              2 (keys)
             48 PRECALL                  0
             52 CALL                     0
             62 PRECALL                  1
             66 CALL                     1
             76 RETURN_VALUE

In [13]: dis.dis(test2)
  1           0 RESUME                   0

  2           2 BUILD_LIST               0
              4 LOAD_GLOBAL              0 (a)
             16 LIST_EXTEND              1
             18 RETURN_VALUE

While test2 has a smaller bytecode output, it is still slower than test1. Not all opcodes are equally fast.
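
As an aside, the same comparison is easy to reproduce with the standard library alone. A minimal sketch, assuming a small throwaway dict and an arbitrary iteration count:

from timeit import timeit

a = {"a": 5, "k": 3, "asdf": 53}

# Time both ways of materialising the dict keys into a new list.
t_keys = timeit(lambda: list(a.keys()), number=1_000_000)
t_unpack = timeit(lambda: [*a], number=1_000_000)

print(f"list(a.keys()): {t_keys:.3f} s")
print(f"[*a]:           {t_unpack:.3f} s")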

@wiseaidev (Author), Apr 1, 2023

It seems like Python 3.11 introduced a lot of optimizations out of the box, so this is now the default behavior under the hood. On Python 3.9, I am getting these results:

>>> timeit('[* {"a": 5, "k": 3, "asdf": 53}]')
0.7558078520014533
>>> timeit('list({"a": 5, "k": 3, "asdf": 53}.keys())')
1.142674779999652

@KristinnVikarJ, Apr 1, 2023

Even if the change had a minor performance benefit, none of these code paths seem to be hot, so these nano-optimizations gain nothing in practice.
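
For what it's worth, whether a given path is actually hot can be checked with the standard profiler before micro-optimizing. A minimal sketch, where train_step is a hypothetical stand-in for whatever entry point exercises these helpers:

import cProfile
import pstats

def train_step():
    # Hypothetical placeholder for the real workload (batching, optimizer setup, etc.).
    pass

# Profile one representative run and list the top call sites by cumulative time.
cProfile.run("train_step()", "profile.out")
pstats.Stats("profile.out").sort_stats("cumulative").print_stats(20)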


num_param_groups = sum(1 for _, _optim in optimizer._optims for _ in _optim.param_groups)
if num_param_groups != len(lr_dict):
4 changes: 2 additions & 2 deletions projects/home/recap/data/dataset.py
@@ -250,7 +250,7 @@ def __init__(
vocab_mapper: tf.keras.Model = None,
):
logging.info("***** Labels *****")
logging.info(list(data_config.tasks.keys()))
logging.info([* data_config.tasks])

self._data_config = data_config
self._parse_fn = get_seg_dense_parse_fn(data_config)
@@ -295,7 +295,7 @@ def __init__(
add_weights=should_add_weights,
)

sparse_feature_names = list(vocab_mapper.vocabs.keys()) if vocab_mapper else None
sparse_feature_names = [* vocab_mapper.vocabs] if vocab_mapper else None

self._tf_dataset = self._create_tf_dataset()

4 changes: 2 additions & 2 deletions projects/home/recap/data/tfe_parsing.py
@@ -25,7 +25,7 @@ def create_tf_example_schema(
A dictionary schema suitable for deserializing tf.Example.
"""
segdense_config = data_config.seg_dense_schema
labels = list(data_config.tasks.keys())
labels = [* data_config.tasks]
used_features = (
segdense_config.features + list(segdense_config.renamed_features.values()) + labels
)
@@ -96,7 +96,7 @@ def parse_tf_example(
# at TF level.
# We should not return empty tensors if we dont use embeddings.
# Otherwise, it breaks numpy->pt conversion
renamed_keys = list(seg_dense_schema_config.renamed_features.keys())
renamed_keys = [* seg_dense_schema_config.renamed_features]
for renamed_key in renamed_keys:
if "embedding" in renamed_key and (renamed_key not in inputs):
inputs[renamed_key] = tf.zeros([], tf.float32)
4 changes: 2 additions & 2 deletions projects/home/recap/data/util.py
@@ -16,7 +16,7 @@ def keyed_tensor_from_tensors_dict(
Returns:

"""
keys = list(tensor_map.keys())
keys = [* tensor_map]
# We expect batch size to be first dim. However, if we get a shape [Batch_size],
# KeyedTensor will not find the correct batch_size. So, in those cases we make sure the shape is
# [Batch_size x 1].
@@ -84,7 +84,7 @@ def keyed_jagged_tensor_from_tensors_dict(
lengths = torch.cat(lengths, axis=0)

return torchrec.KeyedJaggedTensor(
keys=list(tensor_map.keys()),
keys=[* tensor_map],
values=values,
lengths=lengths,
)
2 changes: 1 addition & 1 deletion projects/home/recap/main.py
@@ -47,7 +47,7 @@ def run(unused_argv: str, data_service_dispatcher: Optional[str] = None):

loss_fn = losses.build_multi_task_loss(
loss_type=LossType.BCE_WITH_LOGITS,
tasks=list(config.model.tasks.keys()),
tasks=[* config.model.tasks],
pos_weights=[task.pos_weight for task in config.model.tasks.values()],
)

2 changes: 1 addition & 1 deletion projects/home/recap/model/entrypoint.py
@@ -149,7 +149,7 @@ def __init__(
neg_downsampling_rate=data_config.tasks[task_name].neg_downsampling_rate,
)

self._task_names = list(config.tasks.keys())
self._task_names = [* config.tasks]
self._towers = torch.nn.ModuleDict(_towers)
self._affine_maps = torch.nn.ModuleDict(_affine_maps)
self._calibrators = torch.nn.ModuleDict(_calibrators)
4 changes: 2 additions & 2 deletions projects/home/recap/optimizer/optimizer.py
@@ -40,7 +40,7 @@ def __init__(
):
self.optimizer = optimizer
self.lr_dict = lr_dict
self.group_names = list(self.lr_dict.keys())
self.group_names = [* self.lr_dict]
self.emb_learning_rate = emb_learning_rate

# We handle sparse LR scheduling separately, so only validate LR groups against dense param groups
@@ -146,7 +146,7 @@ def build_optimizer(
)
)

if not parameter_groups.keys() == all_learning_rates.keys():
if not parameter_groups == all_learning_rates:
raise ValueError("Learning rates do not match optimizers")

# If the optimiser is dense, model.fused_optimizer will be empty (but not None)
2 changes: 1 addition & 1 deletion projects/twhin/data/edges.py
@@ -38,7 +38,7 @@ def __init__(

self.table_sizes = table_sizes
self.num_tables = len(table_sizes)
self.table_names = list(table_sizes.keys())
self.table_names = [* table_sizes]

self.relations = relations
self.relations_t = torch.tensor(
2 changes: 1 addition & 1 deletion projects/twhin/metrics.py
@@ -7,7 +7,7 @@
def create_metrics(
device: torch.device,
):
metrics = dict()
metrics = {}
metrics.update(
{
"AUC": core_metrics.Auc(128),