Skip to content

Commit

Permalink
Fix UEManager.load (#265)
Browse files Browse the repository at this point in the history
* upd load

* fix
  • Loading branch information
ArtemVazh authored Nov 28, 2024
1 parent a495779 commit e2282c8
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions src/lm_polygraph/utils/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
FactoryStatCalculator,
StatCalculatorContainer,
)
from lm_polygraph.defaults.register_default_stat_calculators import (
register_default_stat_calculators,
)
from lm_polygraph.utils.common import flatten_results

import logging
Expand Down Expand Up @@ -442,8 +445,8 @@ def save(self, save_path: str):
@staticmethod
def load(
load_path: str,
builder_env_stat_calc: BuilderEnvironmentStatCalculator,
available_stat_calculators: List[StatCalculatorContainer],
builder_env_stat_calc: BuilderEnvironmentStatCalculator = None,
available_stat_calculators: List[StatCalculatorContainer] = None,
) -> "UEManager":
"""
Loads UEManager from the specified path. To save the calculated manager results, see UEManager.save().
Expand All @@ -452,6 +455,16 @@ def load(
load_path (str): Path to file with saved benchmark results to load.
"""
res_dict = torch.load(load_path)

if available_stat_calculators is None:
result_stat_calculators = dict()
scs = register_default_stat_calculators("Whitebox")
for sc in scs:
result_stat_calculators[sc.name] = sc
available_stat_calculators = list(result_stat_calculators.values())
if builder_env_stat_calc is None:
builder_env_stat_calc = BuilderEnvironmentStatCalculator(model=None)

man = UEManager(
data=None,
model=None,
Expand Down

0 comments on commit e2282c8

Please sign in to comment.