Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added lru caches to speed up validphys multi-replica initalization #1945

Merged
merged 6 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion conda-recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ requirements:
- pineappl >=0.6.2
- eko >=0.14.1
- fiatlux
- curio >=1.0 # reportengine uses it but it's not in its dependencies
- frozendict # needed for caching of data loading
- curio >=1.0 # reportengine uses it but it's not in its dependencies

test:
requires:
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ vp-nextfitruncard = "validphys.scripts.vp_nextfitruncard:main"
vp-hyperoptplot = "validphys.scripts.vp_hyperoptplot:main"
vp-deltachi2 = "validphys.scripts.vp_deltachi2:main"

[tool.poetry.dependencies]
[tool.poetry.dependencies]
# Generic dependencies (i.e., validphys)
python = "^3.9"
matplotlib = ">=3.3.0,<3.8"
Expand All @@ -68,6 +68,7 @@ pandas = "<2"
numpy = "*"
validobj = "*"
prompt_toolkit = "*"
frozendict = "*" # validphys: needed for caching of data loading
# Reportengine (and its dependencies) need to be installed in a bit more manual way
reportengine = { git = "https://github.com/NNPDF/reportengine", rev = "3bb2b1d"}
ruamel_yaml = {version = "<0.18"}
Expand Down
3 changes: 2 additions & 1 deletion validphys2/src/validphys/commondata.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
"""
from reportengine import collect
from validphys.commondataparser import load_commondata
import functools


@functools.lru_cache
def loaded_commondata_with_cuts(commondata, cuts):
"""Load the commondata and apply cuts.

Expand Down
11 changes: 9 additions & 2 deletions validphys2/src/validphys/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,12 @@
)
from validphys.paramfits.config import ParamfitsConfig
from validphys.plotoptions import get_info
from validphys.utils import freeze_args
from frozendict import frozendict
import validphys.scalevariations

log = logging.getLogger(__name__)

log = logging.getLogger(__name__)

class Environment(Environment):
"""Container for information to be filled at run time"""
Expand Down Expand Up @@ -1247,6 +1249,11 @@ def parse_default_filter_rules_recorded_spec_(self, spec):
def parse_added_filter_rules(self, rules: (list, type(None)) = None):
return rules

# Every parallel replica triggers a series of calls to this function,
# which should not happen since the rules are identical among replicas.
# E.g for NNPDF4.0 with 2 parallel replicas 693 calls, 3 parallel replicas 1001 calls...
@freeze_args
@functools.lru_cache
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a note to myself to check what enters here and whether maybe the problems are related to what I see in #1678 (since if so, that won't be a problem)

Are you absolutely certain this one is needed for the multi-replica? Since in principle this function should only happen once.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will check in master

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For NNPDF4.0 with 3 replicas it is called 1001 times. For 2 replicas 693 times.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Then indeed, let's keep the freeze args for now, but please add a comment with that info on top. That should definitely not happen since the rules should be the same per replica, so I believe this is a signal of a problem somewhere else which hopefully can eventually be solved.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added comment in 19771a0

def produce_rules(
self,
theoryid,
Expand Down Expand Up @@ -1286,7 +1293,7 @@ def produce_rules(

if added_filter_rules:
for i, rule in enumerate(added_filter_rules):
if not isinstance(rule, dict):
if not isinstance(rule, (dict, frozendict)):
raise ConfigError(f"added rule {i} is not a dict")
try:
rule_list.append(
Expand Down
3 changes: 2 additions & 1 deletion validphys2/src/validphys/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,8 +489,9 @@ def load_cfactors(self):

return [[parse_cfactor(c.open("rb")) for c in cfacs] for cfacs in self.cfactors]

@functools.lru_cache()
def load_with_cuts(self, cuts):
"""Load the fktable and apply cuts inmediately. Returns a FKTableData"""
"""Load the fktable and apply cuts immediately. Returns a FKTableData"""
return load_fktable(self).with_cuts(cuts)


Expand Down
2 changes: 2 additions & 0 deletions validphys2/src/validphys/covmats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
matrices on different levels of abstraction
"""
import logging
import functools

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -226,6 +227,7 @@ def dataset_inputs_covmat_from_systematics(


@check_cuts_considered
@functools.lru_cache
def dataset_t0_predictions(dataset, t0set):
"""Returns the t0 predictions for a ``dataset`` which are the predictions
calculated using the central member of ``pdf``. Note that if ``pdf`` has
Expand Down
4 changes: 4 additions & 0 deletions validphys2/src/validphys/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
from importlib.resources import read_text
import logging
import re
import functools

import numpy as np

from reportengine.checks import check, make_check
from reportengine.compat import yaml
from validphys.commondatawriter import write_commondata_to_file, write_systype_to_file
import validphys.cuts
from validphys.utils import freeze_args

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -561,6 +563,8 @@ def _make_point_namespace(self, dataset, idat) -> dict:
return ns


@freeze_args
@functools.lru_cache
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same note to myself.
But this one I think might be solved by some of the stuff I'm doing in the reader.

Copy link
Collaborator Author

@goord goord Feb 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For NNPDF4.0 with 3 replicas it is called 1001 times. For 2 replicas 693 times.

--> I'm sorry, this comment was intended for the other function produce_rules in config.py!!!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one yes, but I think it might be fixed now (in other PR). The other one I'm more worried about.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For NNPDF4.0 with 3 replicas it is called 1040 times. For 2 replicas 732 times.

def get_cuts_for_dataset(commondata, rules) -> list:
"""Function to generate a list containing the index
of all experimental points that passed kinematic
Expand Down
29 changes: 29 additions & 0 deletions validphys2/src/validphys/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,41 @@
@author: Zahari Kassabov
"""
import contextlib
import functools
import pathlib
import shutil
import tempfile
from typing import Any, Sequence, Mapping, Hashable

import numpy as np
from validobj import ValidationError, parse_input
from frozendict import frozendict


def make_hashable(obj: Any):
# So that we don't infinitely recurse since frozenset and tuples
# are Sequences.
APJansen marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(obj, Hashable):
return obj
elif isinstance(obj, Mapping):
return frozendict(obj)
elif isinstance(obj, Sequence):
return tuple([make_hashable(i) for i in obj])
else:
raise ValueError("Object is not hashable")


def freeze_args(func):
"""Transform mutable dictionary
Into immutable
Useful to be compatible with cache
"""
@functools.wraps(func)
def wrapped(*args, **kwargs):
args = tuple([make_hashable(arg) for arg in args])
kwargs = {k: make_hashable(v) for k, v in kwargs.items()}
return func(*args, **kwargs)
return wrapped


def parse_yaml_inp(inp, spec, path):
Expand Down