Skip to content

Commit

Permalink
[query] rework flags and fix QoB flags
Browse files Browse the repository at this point in the history
CHANGELOG: Before this release, Hail always used the same global seed: 0. With this release, Hail generates a random seed at startup. You can restore the old behavior by executing the following before executing any Hail code: `hl.init(global_seed=0)`

Flags now use the same user configuration machinery we use for Batch and QoB. I am not certain
this is the right choice. Feedback very welcome. The configuration_of function lets us uniformly
treat any configuration by checking, in order: explicit argument, envvar, config file, or a
fallback.

I added a bit of code to allow us to support the envvars which do not conform to the new envvar
scheme.

I also removed a few flags that are no longer used.

I kind of think these flags should actually be under a new section like "query_compiler" or
something.

@tpoterba, thoughts?
  • Loading branch information
Daniel King committed Jan 31, 2023
1 parent 9f03cf4 commit 5b737d2
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 30 deletions.
37 changes: 36 additions & 1 deletion hail/python/hail/backend/backend.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Mapping, List, Union, Tuple, Dict, Optional, Any
from typing import Mapping, List, Union, Tuple, Dict, Optional, Any, AbstractSet
import abc
import orjson
import pkg_resources
import zipfile

from hailtop.config.user_config import configuration_of

from ..fs.fs import FS
from ..builtin_references import BUILTIN_REFERENCE_RESOURCE_PATHS
from ..expr import Expression
Expand All @@ -25,6 +28,32 @@ def fatal_error_from_java_error_triplet(short_message, expanded_message, error_i


class Backend(abc.ABC):
# Must match knownFlags in HailFeatureFlags.py
_flags_env_vars_and_defaults: Dict[str, Tuple[str, Optional[str]]] = {
"no_whole_stage_codegen": ("HAIL_DEV_NO_WHOLE_STAGE_CODEGEN", None),
"no_ir_logging": ("HAIL_DEV_NO_IR_LOG", None),
"lower": ("HAIL_DEV_LOWER", None),
"lower_only": ("HAIL_DEV_LOWER_ONLY", None),
"lower_bm": ("HAIL_DEV_LOWER_BM", None),
"max_leader_scans": ("HAIL_DEV_MAX_LEADER_SCANS", "1000"),
"distributed_scan_comb_op": ("HAIL_DEV_DISTRIBUTED_SCAN_COMB_OP", None),
"jvm_bytecode_dump": ("HAIL_DEV_JVM_BYTECODE_DUMP", None),
"write_ir_files": ("HAIL_WRITE_IR_FILES", None),
"method_split_ir_limit": ("HAIL_DEV_METHOD_SPLIT_LIMIT", "16"),
"use_new_shuffle": ("HAIL_USE_NEW_SHUFFLE", None),
"shuffle_max_branch_factor": ("HAIL_SHUFFLE_MAX_BRANCH", "64"),
"shuffle_cutoff_to_local_sort": ("HAIL_SHUFFLE_CUTOFF", "512000000"), # This is in bytes
"grouped_aggregate_buffer_size": ("HAIL_GROUPED_AGGREGATE_BUFFER_SIZE", "50"),
"use_ssa_logs": ("HAIL_USE_SSA_LOGS", None),
"gcs_requester_pays_project": ("HAIL_GCS_REQUESTER_PAYS_PROJECT", None),
"gcs_requester_pays_buckets": ("HAIL_GCS_REQUESTER_PAYS_BUCKETS", None),
"index_branching_factor": ("HAIL_INDEX_BRANCHING_FACTOR", None),
"rng_nonce": ("HAIL_RNG_NONCE", "0")
}

def _valid_flags(self) -> AbstractSet[str]:
return self._flags_env_vars_and_defaults.keys()

@abc.abstractmethod
def __init__(self):
self._persisted_locations = dict()
Expand Down Expand Up @@ -167,6 +196,12 @@ def register_ir_function(self,
def persist_expression(self, expr: Expression) -> Expression:
pass

def _initialize_flags(self) -> None:
self.set_flags(**{
k: configuration_of('query', k, None, default, deprecated_envvar=deprecated_envvar)
for k, (deprecated_envvar, default) in Backend._flags_env_vars_and_defaults.items()
})

@abc.abstractmethod
def set_flags(self, **flags: Mapping[str, str]):
"""Set Hail flags."""
Expand Down
1 change: 1 addition & 0 deletions hail/python/hail/backend/local_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def __init__(self, tmpdir, log, quiet, append, branching_factor,

if not quiet:
connect_logger(self._utils_package_object, 'localhost', 12888)
self._initialize_flags()

def jvm(self):
return self._jvm
Expand Down
15 changes: 11 additions & 4 deletions hail/python/hail/backend/service_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from ..fs.fs import FS
from ..fs.router_fs import RouterFS
from ..ir import BaseIR
from ..utils import frozendict


ReferenceGenomeConfig = Dict[str, Any]
Expand Down Expand Up @@ -223,7 +222,7 @@ async def create(*,

flags = {"use_new_shuffle": "1", **(flags or {})}

return ServiceBackend(
sb = ServiceBackend(
billing_project=billing_project,
sync_fs=sync_fs,
async_fs=async_fs,
Expand All @@ -239,6 +238,8 @@ async def create(*,
worker_memory=worker_memory,
name_prefix=name_prefix or '',
)
sb._initialize_flags()
return sb

def __init__(self,
*,
Expand Down Expand Up @@ -632,10 +633,16 @@ def persist_expression(self, expr):
return read_expression(fname, _assert_type=expr.dtype)

def set_flags(self, **flags: str):
unknown_flags = set(flags) - self._valid_flags()
if unknown_flags:
raise ValueError(f'unknown flags: {", ".join(unknown_flags)}')
self.flags.update(flags)

def get_flags(self, *flags) -> Mapping[str, str]:
return frozendict(self.flags)
def get_flags(self, *flags: str) -> Mapping[str, str]:
unknown_flags = set(flags) - self._valid_flags()
if unknown_flags:
raise ValueError(f'unknown flags: {", ".join(unknown_flags)}')
return {flag: self.flags[flag] for flag in flags if flag in self.flags}

@property
def requires_lowering(self):
Expand Down
1 change: 1 addition & 0 deletions hail/python/hail/backend/spark_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def __init__(self, idempotent, sc, spark_conf, app_name, master,
connect_logger(self._utils_package_object, 'localhost', 12888)

self._jbackend.startProgressBar()
self._initialize_flags()

def jvm(self):
return self._jvm
Expand Down
9 changes: 4 additions & 5 deletions hail/python/hail/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,12 @@ def __init__(self, log, quiet, append, tmpdir, local_tmpdir, global_seed, backen
' the latest changes weekly.\n')
sys.stderr.write(f'LOGGING: writing to {log}\n')

self._user_specified_rng_nonce = True
if global_seed is None:
if 'rng_nonce' not in backend.get_flags('rng_nonce'):
backend.set_flags(rng_nonce=hex(Random().randrange(-2**63, 2**63 - 1)))
self._user_specified_rng_nonce = False
self._user_specified_rng_nonce = False
backend.set_flags(rng_nonce=str(Random().randrange(-2**63, 2**63 - 1)))
else:
backend.set_flags(rng_nonce=hex(global_seed))
self._user_specified_rng_nonce = True
backend.set_flags(rng_nonce=str(global_seed))
Env._hc = self

def initialize_references(self, references, default_reference):
Expand Down
19 changes: 17 additions & 2 deletions hail/python/hailtop/config/user_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,31 @@ def get_user_config() -> configparser.ConfigParser:
T = TypeVar('T')


def configuration_of(section: str, option: str, explicit_argument: Optional[str], fallback: T) -> Union[str, T]:
def configuration_of(section: str,
option: str,
explicit_argument: Optional[T],
fallback: T,
*,
deprecated_envvar: Optional[str] = None) -> Union[str, T]:
assert VALID_SECTION_AND_OPTION_RE.fullmatch(section), (section, option)
assert VALID_SECTION_AND_OPTION_RE.fullmatch(option), (section, option)

if explicit_argument is not None:
return explicit_argument

envval = os.environ.get('HAIL_' + section.upper() + '_' + option.upper(), None)
envvar = 'HAIL_' + section.upper() + '_' + option.upper()
envval = os.environ.get(envvar, None)
deprecated_envval = None if deprecated_envvar is None else os.environ.get(deprecated_envvar)
if envval is not None:
if deprecated_envval is not None:
raise ValueError(f'Value for configuration variable {section}/{option} is ambiguous '
f'because both {envvar} and {deprecated_envvar} are set (respectively '
f'to: {envval} and {deprecated_envval}.')
return envval
if deprecated_envval is not None:
warnings.warn(f'Use of deprecated envvar {deprecated_envvar} for configuration variable '
f'{section}/{option}. Please use {envvar} instead.')
return deprecated_envval

from_user_config = get_user_config().get(section, option, fallback=None)
if from_user_config is not None:
Expand Down
4 changes: 4 additions & 0 deletions hail/python/test/hail/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,7 @@ def test_top_level_functions_are_do_not_error(self):

def test_tmpdir_runs(self):
isinstance(hl.tmp_dir(), str)

def test_get_flags(self):
assert hl._get_flags() == {}
assert list(hl._get_flags('use_new_shuffle')) == ['use_new_shuffle']
16 changes: 8 additions & 8 deletions hail/python/test/hail/test_randomness.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,16 @@ def test_matrix_table_entries():
mt = hl.utils.range_matrix_table(5, 2)
mt = mt.annotate_entries(x = hl.rand_int32(5))
expected = [
hl.Struct(row_idx=0, col_idx=0, x=0),
hl.Struct(row_idx=0, col_idx=1, x=3),
hl.Struct(row_idx=0, col_idx=0, x=1),
hl.Struct(row_idx=0, col_idx=1, x=2),
hl.Struct(row_idx=1, col_idx=0, x=2),
hl.Struct(row_idx=1, col_idx=1, x=4),
hl.Struct(row_idx=2, col_idx=0, x=1),
hl.Struct(row_idx=1, col_idx=1, x=3),
hl.Struct(row_idx=2, col_idx=0, x=0),
hl.Struct(row_idx=2, col_idx=1, x=4),
hl.Struct(row_idx=3, col_idx=0, x=4),
hl.Struct(row_idx=3, col_idx=1, x=2),
hl.Struct(row_idx=3, col_idx=0, x=3),
hl.Struct(row_idx=3, col_idx=1, x=0),
hl.Struct(row_idx=4, col_idx=0, x=4),
hl.Struct(row_idx=4, col_idx=1, x=4),
hl.Struct(row_idx=4, col_idx=1, x=3)
]
actual = mt.entries().collect()
assert expected == actual
Expand All @@ -84,7 +84,7 @@ def test_table_filter():
ht = hl.utils.range_table(5)
ht = ht.annotate(x = hl.rand_int32(5))
ht = ht.filter(ht.x % 3 == 0)
expected = [hl.Struct(idx=1, x=3), hl.Struct(idx=3, x=3), hl.Struct(idx=4, x=3)]
expected = [hl.Struct(idx=2, x=3)]
actual = ht.collect()
assert expected == actual

Expand Down
16 changes: 8 additions & 8 deletions hail/src/main/scala/is/hail/HailFeatureFlags.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ import org.json4s.JsonAST.{JArray, JObject, JString}
import scala.collection.mutable

object HailFeatureFlags {
val defaults: Map[String, (String, String)] = Map[String, (String, String)](
val defaults = Map[String, (String, String)](
// Must match __flags_env_vars_and_defaults in hail/backend/backend.py
//
// The default values and envvars here are only used in the Scala tests. In all other
// conditions, Python initializes the flags, see HailContext._initialize_flags in context.py.
("no_whole_stage_codegen", ("HAIL_DEV_NO_WHOLE_STAGE_CODEGEN" -> null)),
("no_ir_logging", ("HAIL_DEV_NO_IR_LOG" -> null)),
("lower", ("HAIL_DEV_LOWER" -> null)),
Expand All @@ -17,11 +21,6 @@ object HailFeatureFlags {
("max_leader_scans", ("HAIL_DEV_MAX_LEADER_SCANS" -> "1000")),
("distributed_scan_comb_op", ("HAIL_DEV_DISTRIBUTED_SCAN_COMB_OP" -> null)),
("jvm_bytecode_dump", ("HAIL_DEV_JVM_BYTECODE_DUMP" -> null)),
("use_packed_int_encoding", ("HAIL_DEV_USE_PACKED_INT_ENCODING" -> null)),
("use_column_encoding", ("HAIL_DEV_USE_COLUMN_ENCODING" -> null)),
("use_spicy_ptypes", ("HAIL_USE_SPICY_PTYPES" -> null)),
("log_service_timing", ("HAIL_DEV_LOG_SERVICE_TIMING" -> null)),
("cache_service_input", ("HAIL_DEV_CACHE_SERVICE_INPUT" -> null)),
("write_ir_files", ("HAIL_WRITE_IR_FILES" -> null)),
("method_split_ir_limit", ("HAIL_DEV_METHOD_SPLIT_LIMIT" -> "16")),
("use_new_shuffle", ("HAIL_USE_NEW_SHUFFLE" -> null)),
Expand All @@ -32,7 +31,7 @@ object HailFeatureFlags {
("gcs_requester_pays_project", "HAIL_GCS_REQUESTER_PAYS_PROJECT" -> null),
("gcs_requester_pays_buckets", "HAIL_GCS_REQUESTER_PAYS_BUCKETS" -> null),
("index_branching_factor", "HAIL_INDEX_BRANCHING_FACTOR" -> null),
("rng_nonce", "HAIL_RNG_NONCE" -> "0x0")
("rng_nonce", "HAIL_RNG_NONCE" -> "0")
)

def fromEnv(): HailFeatureFlags =
Expand All @@ -54,13 +53,14 @@ object HailFeatureFlags {
)
}

class HailFeatureFlags(
class HailFeatureFlags private (
val flags: mutable.Map[String, String]
) extends Serializable {
val available: java.util.ArrayList[String] =
new java.util.ArrayList[String](java.util.Arrays.asList[String](flags.keys.toSeq: _*))

def set(flag: String, value: String): Unit = {
assert(exists(flag))
flags.update(flag, value)
}

Expand Down
9 changes: 7 additions & 2 deletions hail/src/main/scala/is/hail/backend/ExecuteContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import is.hail.{HailContext, HailFeatureFlags}
import is.hail.annotations.{Region, RegionPool}
import is.hail.expr.ir.Threefry
import is.hail.io.fs.FS
import is.hail.utils.{ExecutionTimer, using}
import is.hail.utils._
import is.hail.variant.ReferenceGenome

import java.io._
Expand Down Expand Up @@ -111,7 +111,12 @@ class ExecuteContext(
) extends Closeable {
var backendContext: BackendContext = _

val rngNonce: Long = java.lang.Long.decode(getFlag("rng_nonce"))
val rngNonce: Long = try {
getFlag("rng_nonce").toLong
} catch {
case exc: NumberFormatException =>
fatal(s"Could not parse flag rng_nonce as a 64-bit signed integer: ${getFlag("rng_nonce")}", exc)
}

private val tempFileManager: TempFileManager = if (_tempFileManager != null)
_tempFileManager
Expand Down

0 comments on commit 5b737d2

Please sign in to comment.