Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ExtraDataFunctor for integration with pasha #299

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions extra_data/keydata.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,3 +444,9 @@ def trains(self, keep_dims=False):
yield tid, ds[start]

start += count

def _pasha_functor_(self):
"""Integration with pasha for map operations."""

from .pasha_functor import ExtraDataFunctor
return ExtraDataFunctor(self)
65 changes: 65 additions & 0 deletions extra_data/pasha_functor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from os import getpid

import numpy as np

from . import DataCollection, SourceData, KeyData
from .read_machinery import split_trains


class ExtraDataFunctor:
"""Pasha functor for EXtra-data objects.

This functor wraps an EXtra-data DataCollection, SourceData or
KeyData and performs the map operation over its trains. The kernel
is passed the current train's index in the collection, the train ID
and the data mapping (for DataCollection and SourceData) or data
entry (for KeyData).
"""

def __init__(self, obj):
self.obj = obj
self.n_trains = len(self.obj.train_ids)

# Save PID of parent process where the functor is created to
# close files as appropriately later on, see comment below.
self._parent_pid = getpid()

@classmethod
def wrap(cls, value):
if isinstance(value, (DataCollection, SourceData, KeyData)):
return cls(value)

def split(self, num_workers):
return split_trains(self.n_trains, parts=num_workers)

def iterate(self, share):
subobj = self.obj.select_trains(np.s_[share])

# Older versions of HDF < 1.10.5 are not robust against sharing
# a file descriptor across threads or processes. If running in a
# different process than the functor was initially created in,
# close all file handles inherited from the parent collection to
# force re-opening them again in each child process.
if getpid() != self._parent_pid:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need this at all with the dependencies we require?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably not really.

It's a bit tricky to be definitive because the change is in HDF5, and h5py can be built with a wide range of HDF5 versions. EXtra-data supports h5py 2.10, and the pre-built packages of 2.10 on PyPI have HDF5 1.10.4. So I might bump the required version to >= 3.0, just to be a bit cautious. I hope very few people are still stuck on h5py 2.x.

for f in subobj.files:
f.close()

index_it = range(*share.indices(self.n_trains))

if isinstance(subobj, SourceData):
# SourceData has no trains() iterator yet, so simulate it
# ourselves by reconstructing a DataCollection object and
# use its trains() iterator.
dc = DataCollection(
subobj.files, {subobj.source: subobj}, subobj.train_ids,
inc_suspect_trains=subobj.inc_suspect_trains,
is_single_run=True)
data_it = ((train_id, data[subobj.source])
for train_id, data in dc.trains())
else:
# Use the regular trains() iterator for DataCollection and
# KeyData
data_it = subobj.trains()

for index, (train_id, data) in zip(index_it, data_it):
yield index, train_id, data
6 changes: 6 additions & 0 deletions extra_data/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,12 @@ def trains(self, devices=None, train_range=None, *, require_all=False,
return iter(TrainIterator(dc, require_all=require_all,
flat_keys=flat_keys, keep_dims=keep_dims))

def _pasha_functor_(self):
"""Integration with pasha for map operations."""

from .pasha_functor import ExtraDataFunctor
return ExtraDataFunctor(self)

def train_from_id(
self, train_id, devices=None, *, flat_keys=False, keep_dims=False):
"""Get train data for specified train ID.
Expand Down
6 changes: 6 additions & 0 deletions extra_data/sourcedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,9 @@ def union(self, *others) -> 'SourceData':
section=self.section,
inc_suspect_trains=self.inc_suspect_trains
)

def _pasha_functor_(self):
"""Integration with pasha for map operations."""

from .pasha_functor import ExtraDataFunctor
return ExtraDataFunctor(self)