diff --git a/hail/python/hail/backend/backend.py b/hail/python/hail/backend/backend.py index 17c6a090584..d4c7870b838 100644 --- a/hail/python/hail/backend/backend.py +++ b/hail/python/hail/backend/backend.py @@ -1,8 +1,11 @@ -from typing import Mapping, List, Union, Tuple, Dict, Optional, Any +from typing import Mapping, List, Union, Tuple, Dict, Optional, Any, AbstractSet import abc import orjson import pkg_resources import zipfile + +from hailtop.config.user_config import configuration_of + from ..fs.fs import FS from ..builtin_references import BUILTIN_REFERENCE_RESOURCE_PATHS from ..expr import Expression @@ -25,6 +28,34 @@ def fatal_error_from_java_error_triplet(short_message, expanded_message, error_i class Backend(abc.ABC): + # Must match `defaults` in HailFeatureFlags.scala + _flags_env_vars_and_defaults: Dict[str, Tuple[str, Optional[str]]] = { + "no_whole_stage_codegen": ("HAIL_DEV_NO_WHOLE_STAGE_CODEGEN", None), + "no_ir_logging": ("HAIL_DEV_NO_IR_LOG", None), + "lower": ("HAIL_DEV_LOWER", None), + "lower_only": ("HAIL_DEV_LOWER_ONLY", None), + "lower_bm": ("HAIL_DEV_LOWER_BM", None), + "print_ir_on_worker": ("HAIL_DEV_PRINT_IR_ON_WORKER", None), + "print_inputs_on_worker": ("HAIL_DEV_PRINT_INPUTS_ON_WORKER", None), + "max_leader_scans": ("HAIL_DEV_MAX_LEADER_SCANS", "1000"), + "distributed_scan_comb_op": ("HAIL_DEV_DISTRIBUTED_SCAN_COMB_OP", None), + "jvm_bytecode_dump": ("HAIL_DEV_JVM_BYTECODE_DUMP", None), + "write_ir_files": ("HAIL_WRITE_IR_FILES", None), + "method_split_ir_limit": ("HAIL_DEV_METHOD_SPLIT_LIMIT", "16"), + "use_new_shuffle": ("HAIL_USE_NEW_SHUFFLE", None), + "shuffle_max_branch_factor": ("HAIL_SHUFFLE_MAX_BRANCH", "64"), + "shuffle_cutoff_to_local_sort": ("HAIL_SHUFFLE_CUTOFF", "512000000"), # This is in bytes + "grouped_aggregate_buffer_size": ("HAIL_GROUPED_AGGREGATE_BUFFER_SIZE", "50"), + "use_ssa_logs": ("HAIL_USE_SSA_LOGS", None), + "gcs_requester_pays_project": ("HAIL_GCS_REQUESTER_PAYS_PROJECT", None), + 
"gcs_requester_pays_buckets": ("HAIL_GCS_REQUESTER_PAYS_BUCKETS", None), + "index_branching_factor": ("HAIL_INDEX_BRANCHING_FACTOR", None), + "rng_nonce": ("HAIL_RNG_NONCE", "0x0") + } + + def _valid_flags(self) -> AbstractSet[str]: + return self._flags_env_vars_and_defaults.keys() + @abc.abstractmethod def __init__(self): self._persisted_locations = dict() @@ -167,6 +198,12 @@ def register_ir_function(self, def persist_expression(self, expr: Expression) -> Expression: pass + def _initialize_flags(self) -> None: + self.set_flags(**{ + k: configuration_of('query', k, None, default, deprecated_envvar=deprecated_envvar) + for k, (deprecated_envvar, default) in Backend._flags_env_vars_and_defaults.items() + }) + @abc.abstractmethod def set_flags(self, **flags: Mapping[str, str]): """Set Hail flags.""" diff --git a/hail/python/hail/backend/local_backend.py b/hail/python/hail/backend/local_backend.py index 5934f1da55e..0c42ab10125 100644 --- a/hail/python/hail/backend/local_backend.py +++ b/hail/python/hail/backend/local_backend.py @@ -172,6 +172,7 @@ def __init__(self, tmpdir, log, quiet, append, branching_factor, if not quiet: connect_logger(self._utils_package_object, 'localhost', 12888) + self._initialize_flags() def jvm(self): return self._jvm diff --git a/hail/python/hail/backend/service_backend.py b/hail/python/hail/backend/service_backend.py index abc149f5edd..df24a1d4f1c 100644 --- a/hail/python/hail/backend/service_backend.py +++ b/hail/python/hail/backend/service_backend.py @@ -29,7 +29,6 @@ from ..fs.fs import FS from ..fs.router_fs import RouterFS from ..ir import BaseIR -from ..utils import frozendict ReferenceGenomeConfig = Dict[str, Any] @@ -223,7 +222,7 @@ async def create(*, flags = {"use_new_shuffle": "1", **(flags or {})} - return ServiceBackend( + sb = ServiceBackend( billing_project=billing_project, sync_fs=sync_fs, async_fs=async_fs, @@ -239,6 +238,8 @@ async def create(*, worker_memory=worker_memory, name_prefix=name_prefix or '', ) + 
sb._initialize_flags() + return sb def __init__(self, *, @@ -632,10 +633,16 @@ def persist_expression(self, expr): return read_expression(fname, _assert_type=expr.dtype) def set_flags(self, **flags: str): + unknown_flags = set(flags) - self._valid_flags() + if unknown_flags: + raise ValueError(f'unknown flags: {", ".join(sorted(unknown_flags))}') self.flags.update(flags) - def get_flags(self, *flags) -> Mapping[str, str]: - return frozendict(self.flags) + def get_flags(self, *flags: str) -> Mapping[str, str]: + unknown_flags = set(flags) - self._valid_flags() + if unknown_flags: + raise ValueError(f'unknown flags: {", ".join(sorted(unknown_flags))}') + return {flag: self.flags[flag] for flag in flags if flag in self.flags} @property def requires_lowering(self): diff --git a/hail/python/hail/backend/spark_backend.py b/hail/python/hail/backend/spark_backend.py index dd5219104b4..ddbb22aaa95 100644 --- a/hail/python/hail/backend/spark_backend.py +++ b/hail/python/hail/backend/spark_backend.py @@ -224,6 +224,7 @@ def __init__(self, idempotent, sc, spark_conf, app_name, master, connect_logger(self._utils_package_object, 'localhost', 12888) self._jbackend.startProgressBar() + self._initialize_flags() def jvm(self): return self._jvm diff --git a/hail/python/hailtop/config/user_config.py b/hail/python/hailtop/config/user_config.py index 31ac420f597..e207aeba6e6 100644 --- a/hail/python/hailtop/config/user_config.py +++ b/hail/python/hailtop/config/user_config.py @@ -39,16 +39,31 @@ def get_user_config() -> configparser.ConfigParser: T = TypeVar('T') -def configuration_of(section: str, option: str, explicit_argument: Optional[str], fallback: T) -> Union[str, T]: +def configuration_of(section: str, + option: str, + explicit_argument: Optional[T], + fallback: T, + *, + deprecated_envvar: Optional[str] = None) -> Union[str, T]: assert VALID_SECTION_AND_OPTION_RE.fullmatch(section), (section, option) assert VALID_SECTION_AND_OPTION_RE.fullmatch(option), (section, option) if explicit_argument 
is not None: return explicit_argument - envval = os.environ.get('HAIL_' + section.upper() + '_' + option.upper(), None) + envvar = 'HAIL_' + section.upper() + '_' + option.upper() + envval = os.environ.get(envvar, None) + deprecated_envval = None if deprecated_envvar is None else os.environ.get(deprecated_envvar) if envval is not None: + if deprecated_envval is not None: + raise ValueError(f'Value for configuration variable {section}/{option} is ambiguous ' + f'because both {envvar} and {deprecated_envvar} are set (respectively ' + f'to: {envval} and {deprecated_envval}).') return envval + if deprecated_envval is not None: + warnings.warn(f'Use of deprecated envvar {deprecated_envvar} for configuration variable ' + f'{section}/{option}. Please use {envvar} instead.') + return deprecated_envval from_user_config = get_user_config().get(section, option, fallback=None) if from_user_config is not None: diff --git a/hail/python/test/hail/test_context.py b/hail/python/test/hail/test_context.py index f71f6ba19d5..109b6e4fc93 100644 --- a/hail/python/test/hail/test_context.py +++ b/hail/python/test/hail/test_context.py @@ -1,7 +1,28 @@ +from typing import Tuple, Dict, Optional import unittest import hail as hl from hail.utils.java import Env +from hail.backend.backend import Backend +from hail.backend.spark_backend import SparkBackend +from test.hail.helpers import skip_unless_spark_backend + + +def _scala_map_str_to_tuple_str_str_to_dict(scala) -> Dict[str, Tuple[Optional[str], Optional[str]]]: + it = scala.iterator() + s: Dict[str, Tuple[Optional[str], Optional[str]]] = {} + while it.hasNext(): + kv = it.next() + k = kv._1() + assert isinstance(k, str) + v = kv._2() + l = v._1() + r = v._2() + assert l is None or isinstance(l, str) + assert r is None or isinstance(r, str) + assert k not in s + s[k] = (l, r) + return s class Tests(unittest.TestCase): @@ -16,7 +37,7 @@ def test_init_hail_context_twice(self): hl.init(idempotent=True) # Should be no error - if 
isinstance(Env.backend(), hl.backend.spark_backend.SparkBackend): + if isinstance(Env.backend(), SparkBackend): hl.init(hl.spark_context(), idempotent=True) # Should be no error def test_top_level_functions_are_do_not_error(self): @@ -25,3 +46,15 @@ def test_top_level_functions_are_do_not_error(self): def test_tmpdir_runs(self): isinstance(hl.tmp_dir(), str) + + def test_get_flags(self): + assert hl._get_flags() == {} + assert list(hl._get_flags('use_new_shuffle')) == ['use_new_shuffle'] + + @skip_unless_spark_backend(reason='requires JVM') + def test_flags_same_in_scala_and_python(self): + b = hl.current_backend() + assert isinstance(b, SparkBackend) + + scala_flag_map = _scala_map_str_to_tuple_str_str_to_dict(b._hail_package.HailFeatureFlags.defaults()) + assert scala_flag_map == Backend._flags_env_vars_and_defaults diff --git a/hail/src/main/scala/is/hail/HailFeatureFlags.scala b/hail/src/main/scala/is/hail/HailFeatureFlags.scala index f44249af84d..865ed48b74b 100644 --- a/hail/src/main/scala/is/hail/HailFeatureFlags.scala +++ b/hail/src/main/scala/is/hail/HailFeatureFlags.scala @@ -6,7 +6,11 @@ import org.json4s.JsonAST.{JArray, JObject, JString} import scala.collection.mutable object HailFeatureFlags { - val defaults: Map[String, (String, String)] = Map[String, (String, String)]( + val defaults = Map[String, (String, String)]( + // Must match Backend._flags_env_vars_and_defaults in hail/backend/backend.py + // + // The default values and envvars here are only used in the Scala tests. In all other + // conditions, Python initializes the flags, see Backend._initialize_flags in hail/backend/backend.py. 
("no_whole_stage_codegen", ("HAIL_DEV_NO_WHOLE_STAGE_CODEGEN" -> null)), ("no_ir_logging", ("HAIL_DEV_NO_IR_LOG" -> null)), ("lower", ("HAIL_DEV_LOWER" -> null)), @@ -17,11 +21,6 @@ object HailFeatureFlags { ("max_leader_scans", ("HAIL_DEV_MAX_LEADER_SCANS" -> "1000")), ("distributed_scan_comb_op", ("HAIL_DEV_DISTRIBUTED_SCAN_COMB_OP" -> null)), ("jvm_bytecode_dump", ("HAIL_DEV_JVM_BYTECODE_DUMP" -> null)), - ("use_packed_int_encoding", ("HAIL_DEV_USE_PACKED_INT_ENCODING" -> null)), - ("use_column_encoding", ("HAIL_DEV_USE_COLUMN_ENCODING" -> null)), - ("use_spicy_ptypes", ("HAIL_USE_SPICY_PTYPES" -> null)), - ("log_service_timing", ("HAIL_DEV_LOG_SERVICE_TIMING" -> null)), - ("cache_service_input", ("HAIL_DEV_CACHE_SERVICE_INPUT" -> null)), ("write_ir_files", ("HAIL_WRITE_IR_FILES" -> null)), ("method_split_ir_limit", ("HAIL_DEV_METHOD_SPLIT_LIMIT" -> "16")), ("use_new_shuffle", ("HAIL_USE_NEW_SHUFFLE" -> null)), @@ -54,13 +53,14 @@ object HailFeatureFlags { ) } -class HailFeatureFlags( +class HailFeatureFlags private ( val flags: mutable.Map[String, String] ) extends Serializable { val available: java.util.ArrayList[String] = new java.util.ArrayList[String](java.util.Arrays.asList[String](flags.keys.toSeq: _*)) def set(flag: String, value: String): Unit = { + assert(exists(flag)) flags.update(flag, value) }