Move methods from TXPipe base stage and update documentation #118

Merged
merged 7 commits on Feb 5, 2025
243 changes: 228 additions & 15 deletions ceci/stage.py
@@ -9,6 +9,7 @@
import pdb
import datetime
import warnings
import socket

from abc import abstractmethod
from . import errors
@@ -147,7 +148,10 @@ def get_aliased_tag(self, tag):

@abstractmethod
def run(self): # pragma: no cover
"""Run the stage and return the execution status"""
"""Run the stage and return the execution status.

Subclasses must implement this method.
"""
raise NotImplementedError("run")

def validate(self):
@@ -348,7 +352,8 @@ def __init_subclass__(cls, **kwargs):
path = pathlib.Path(filename).resolve()

# Add a description of the parameters to the end of the docstring
if stage_is_complete:
# If no config options are specified, omit this.
if stage_is_complete and cls.config_options:
config_text = cls._describe_configuration_text()
if cls.__doc__ is None:
cls.__doc__ = f"Stage {cls.name}\n\nConfiguration Parameters:\n{config_text}"
@@ -810,13 +815,24 @@ def is_parallel(self):

def is_mpi(self):
"""
Returns True if the stage is being run under MPI.
Check if the stage is being run under MPI.

Returns
-------
bool
True if the stage is being run under MPI
"""
return self._parallel == MPI_PARALLEL

def is_dask(self):
"""
Returns True if the stage is being run in parallel with Dask.
Check if the stage is being run in parallel with Dask.

Returns
-------
bool
True if the stage is being run with Dask
"""
return self._parallel == DASK_PARALLEL
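
As a quick usage sketch (not part of this diff; the body below is hypothetical), a stage can branch on these checks inside its run method:

```python
# Hypothetical sketch of branching on the parallelization mode.
def run(self):
    if self.is_dask():
        # Under Dask, typically only the client process continues from here
        ...
    elif self.is_mpi():
        print(f"Running under MPI as rank {self.rank} of {self.size}")
    else:
        print("Running in serial mode")
```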

@@ -967,6 +983,11 @@ def data_ranges_by_rank(self, n_rows, chunk_rows, parallel=True):
parallel: bool
Whether to split data by rank or just give all procs all data.
Default=True

Yields
------
start, end: tuple
The start and end of the range of rows to be read by this process for each chunk
"""
n_chunks = n_rows // chunk_rows
if n_chunks * chunk_rows < n_rows: # pragma: no cover
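
For context, a hedged sketch of how these ranges usually drive a chunked read loop (the row counts and process_chunk helper are made up):

```python
# Each process loops over only its own chunks when parallel=True.
for start, end in self.data_ranges_by_rank(n_rows=1_000_000, chunk_rows=10_000):
    process_chunk(start, end)  # process_chunk is a hypothetical helper
```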
@@ -988,6 +1009,17 @@ def get_input(self, tag):
"""
Return the path of an input file with the given tag,
which can be aliased.

Parameters
----------
tag: str
Tag as listed in self.inputs

Returns
-------
path: str
The path to the input file

"""
tag = self.get_aliased_tag(tag)
return self._inputs[tag]
@@ -1000,7 +1032,21 @@ def get_output(self, tag, final_name=False):
which can be aliased already.

If final_name is False then use a temporary name - file will
be moved to its final name at the end
be moved to its final name at the end. The temporary name
is prefixed with `inprogress_`.

Parameters
----------
tag: str
Tag as listed in self.outputs

final_name: bool
Default=False. Whether to save to the final name.

Returns
-------
path: str
The path to the output file
"""

tag = self.get_aliased_tag(tag)
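
As an illustration (the tags here are hypothetical, not defined in this diff):

```python
# get_input returns the configured path; get_output returns the
# temporary inprogress_ path unless final_name=True.
cat_path = self.get_input("shear_catalog")
out_path = self.get_output("tomography_catalog")                  # inprogress_ prefix
fin_path = self.get_output("tomography_catalog", final_name=True) # final path
```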
@@ -1023,6 +1069,21 @@ def open_input(self, tag, wrapper=False, **kwargs):
For specialized file types like FITS or HDF5 it will return
a more specific object - see the types.py file for more info.

Parameters
----------
tag: str
Tag as listed in self.inputs

wrapper: bool
Whether to return an underlying file object (False) or a data type instance (True)

**kwargs: dict
Extra arguments to pass to the file class constructor

Returns
-------
obj: file or object
The opened file or object
"""
path = self.get_input(tag)
input_class = self.get_input_type(tag)
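
A minimal sketch of the two modes, assuming a hypothetical "shear_catalog" tag:

```python
f = self.open_input("shear_catalog")                   # underlying file object
data = self.open_input("shear_catalog", wrapper=True)  # ceci file-type instance
```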
@@ -1039,7 +1100,7 @@ def open_output(
Find and open an output file with the given tag, in write mode.

If final_name is True then they will be opened using their final
target output name. Otherwise we will prepend "inprogress_" to their
target output name. Otherwise we will prepend `inprogress_` to their
file name. This means we know that if the final file exists then it
is completed.

@@ -1050,19 +1111,22 @@

Parameters
----------
tag: str
Tag as listed in self.outputs

wrapper: bool
Default=False. Whether to return a wrapped file
Whether to return an underlying file object (False) or a data type instance (True)

final_name: bool
Default=False. Whether to save to the final name.

**kwargs:
Extra args are passed on to the file's class constructor.

Returns
-------
obj: file or object
The opened file or object
"""
path = self.get_output(tag, final_name=final_name)
output_class = self.get_output_type(tag)
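
For example, a sketch assuming an HDF5 output type (the tag and dataset names are hypothetical):

```python
# With wrapper=False on an HDF5 output this is an h5py.File opened for writing,
# at the temporary inprogress_ path.
with self.open_output("tomography_catalog") as f:
    f.create_dataset("tomography/bin", data=[0] * 100)
```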
@@ -1107,33 +1171,63 @@ def open_output(
@classmethod
def inputs_(cls):
"""
Return the dict of inputs
Return the list of (tag, type) pairs for this stage's inputs.

Returns
-------
in_types : list[tuple]
The tag and file type class of each input
"""
return cls.inputs # pylint: disable=no-member

@classmethod
def outputs_(cls):
"""
Return the dict of inputs
Return the list of (tag, type) pairs for this stage's outputs.

Returns
-------
out_types : list[tuple]
The tag and file type class of each output
"""
return cls.outputs # pylint: disable=no-member

@classmethod
def output_tags(cls):
"""
Return the list of output tags required by this stage
Return the list of output tags required by this stage.

Returns
-------
out_tags : list[str]
The list of output tags
"""
return [tag for tag, _ in cls.outputs_()]

@classmethod
def input_tags(cls):
"""
Return the list of input tags required by this stage
Return the list of input tags required by this stage.

Returns
-------
in_tags : list[str]
The list of input tags
"""
return [tag for tag, _ in cls.inputs_()]

def get_input_type(self, tag):
"""Return the file type class of an input file with the given tag."""
"""
Return the file type class of an input file with the given tag.

Parameters
----------
tag : str
The tag of the input file

Returns
-------
ftype : FileType
The file type class
"""
tag = self.get_aliased_tag(tag)
for t, dt in self.inputs_():
t = self.get_aliased_tag(t)
@@ -1142,7 +1236,19 @@ def get_input_type(self, tag):
raise ValueError(f"Tag {tag} is not a known input") # pragma: no cover

def get_output_type(self, tag):
"""Return the file type class of an output file with the given tag."""
"""
Return the file type class of an output file with the given tag.

Parameters
----------
tag : str
The tag of the output file

Returns
-------
ftype : FileType
The file type class
"""
tag = self.get_aliased_tag(tag)
for t, dt in self.outputs_():
t = self.get_aliased_tag(t)
@@ -1162,8 +1268,12 @@ def instance_name(self):
@property
def config(self):
"""
Returns the configuration dictionary for this stage, aggregating command
The configuration dictionary for this stage, aggregating command
line options and optional configuration file.

Options specified in the subclass variable `config_options` are
read from the configuration file, command line, or `make_stage` choices,
and stored in this dictionary.
"""
return self._configs
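
To make the flow concrete, here is a hedged sketch of a stage declaring and reading options (the class name, tags, and option names are invented for illustration):

```python
from ceci import PipelineStage  # top-level import assumed

class MyStage(PipelineStage):
    name = "MyStage"
    inputs = []
    outputs = []
    # Defaults here may be overridden by the config file, the command
    # line, or keyword choices passed to make_stage.
    config_options = {"chunk_rows": 10000, "zmax": 3.0}

    def run(self):
        chunk_rows = self.config["chunk_rows"]
        zmax = self.config["zmax"]
```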

@@ -1292,6 +1402,7 @@ def iterate_fits(
Loop through chunks of the input data from a FITS file with the given tag

TODO: add ceci tests of this function

Parameters
----------
tag: str
@@ -1380,6 +1491,59 @@ def iterate_hdf(
data = {col: group[col][start:end] for col in cols}
yield start, end, data

def combined_iterators(self, rows, *inputs, parallel=True):
"""
Iterate through multiple files at the same time.

If you have several HDF5 files with columns
of the same length then you can use this method to
iterate through them all at once, combining the data from
all of them into a single dictionary.

Parameters
----------
rows: int
The number of rows to read in each chunk

*inputs: sequence
A flattened sequence of tag, group, cols arguments, repeated
for each file to read. In each triplet, tag is the input file
name tag, group is the group within the HDF5 file to read, and
cols is a list of columns to read from that group. Specify
multiple triplets to read from multiple files.

parallel: bool
Whether to split up data among processes (parallel=True) or give
all processes all data (parallel=False). Default = True.

Returns
-------
it: iterator
Iterator yielding (int, int, dict) tuples of (start, end, data)
"""
if len(inputs) % 3 != 0:
raise ValueError(
"Arguments to combined_iterators should be in threes: "
"tag, group, cols"
)
n = len(inputs) // 3

iterators = []
for i in range(n):
tag = inputs[3 * i]
section = inputs[3 * i + 1]
cols = inputs[3 * i + 2]
iterators.append(
self.iterate_hdf(tag, section, cols, rows, parallel=parallel)
)

for it in zip(*iterators):
data = {}
for (s, e, d) in it:
data.update(d)
yield s, e, data
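
A usage sketch with two hypothetical HDF5 inputs (the tags, groups, and columns are invented for illustration):

```python
it = self.combined_iterators(
    10_000,
    "shear_catalog", "shear", ["ra", "dec", "g1", "g2"],
    "photometry_catalog", "photometry", ["mag_i"],
)
for start, end, data in it:
    # data holds all five columns for rows start:end of both files
    analyze(data)  # analyze is a hypothetical function
```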


################################
# Pipeline-related methods
################################
@@ -1579,3 +1743,52 @@ def generate_cwl(cls, log_dir=None):
# cwl_tool.metadata = cwlgen.Metadata(**metadata)

return cwl_tool


def time_stamp(self, tag=""):
"""
Print a time stamp with an optional tag.

Parameters
----------
tag: str
Additional info to print in the output line. Default is empty.
"""
t = datetime.datetime.now()
print(f"Process {self.rank}: {tag} {t}")
sys.stdout.flush()

def memory_report(self, tag=None):
"""
Print a report about memory currently available
on the node the process is running on.

Parameters
----------
tag: str
Additional info to print in the output line. Default is empty.
"""
import psutil

t = datetime.datetime.now()

# The different types of memory are really fiddly and don't
# correspond to how you usually imagine. The simplest thing
# to report here is just how much memory is left on the machine.
mem = psutil.virtual_memory()
avail = mem.available / 1024**3
total = mem.total / 1024**3

if tag is None:
tag = ""
else:
tag = f" {tag}:"

# This gives you the name of the host. At NERSC that is the node name
host = socket.gethostname()

# Print message
print(
f"{t}: Process {self.rank}:{tag} Remaining memory on {host} {avail:.1f} GB / {total:.1f} GB"
)
sys.stdout.flush()
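
A hedged sketch of instrumenting a stage with these two helpers (load_data is a hypothetical expensive step):

```python
def run(self):
    self.time_stamp("starting load")   # prints rank, tag, and wall-clock time
    data = self.load_data()            # hypothetical expensive step
    self.memory_report("after load")   # prints remaining memory on this node
```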