Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Selecting train IDs in DataCollection and SourceData #559

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions extra_data/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,17 @@
import numpy as np

from . import locality, voview
from .aliases import AliasIndexer
from .exceptions import (MultiRunError, PropertyNameError, SourceNameError,
TrainIDError)
from .file_access import FileAccess
from .keydata import KeyData
from .read_machinery import (DETECTOR_SOURCE_RE, FilenameInfo, by_id, by_index,
find_proposal, glob_wildcards_re, same_run,
select_train_ids, split_trains)
from .read_machinery import (DETECTOR_SOURCE_RE, by_id, by_index,
find_proposal, glob_wildcards_re, is_int_like,
same_run, select_train_ids)
from .run_files_map import RunFilesMap
from .sourcedata import SourceData
from .utils import available_cpu_cores
from .aliases import AliasIndexer

__all__ = [
'H5File',
Expand Down Expand Up @@ -278,8 +278,13 @@ def __getitem__(self, item):
return self._get_key_data(*item)
elif isinstance(item, str):
return self._get_source_data(item)
elif (
isinstance(item, (by_id, by_index, list, np.ndarray, slice)) or
is_int_like(item)
):
return self.select_trains(item)

raise TypeError("Expected data[source] or data[source, key]")
raise TypeError("Expected data[source], data[source, key] or data[train_selection]")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While trying to trigger this error, I ran across a behavior that surprised me: the following does not raise this error:

run['SA3_XTD10_XGM/XGM/DOOCS:output', 2]

rather, it returns a very sensible object:

<extra_data.SourceData source='SA3_XTD10_XGM/XGM/DOOCS:output' for 1 trains>

I suppose this is part of train_selection but I couldn't easily find it documented anywhere.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh, interesting side effect. But yes: since this essentially does run[source][key], and SourceData[key] can now be used for filtering trains, it makes sense.
I'm fine with keeping that behavior, unless others object.


def _ipython_key_completions_(self):
return list(self.all_sources)
Expand Down
16 changes: 11 additions & 5 deletions extra_data/sourcedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
import re
from typing import Dict, List, Optional

import numpy as np
import h5py
import numpy as np

from .exceptions import MultiRunError, PropertyNameError, NoDataError
from .exceptions import MultiRunError, NoDataError, PropertyNameError
from .file_access import FileAccess
from .keydata import KeyData
from .read_machinery import (
glob_wildcards_re, same_run, select_train_ids, split_trains, trains_files_index
)
from .read_machinery import (by_id, by_index, glob_wildcards_re, is_int_like,
same_run, select_train_ids, split_trains,
trains_files_index)


class SourceData:
Expand Down Expand Up @@ -67,6 +67,12 @@ def __contains__(self, key):
return res

def __getitem__(self, key):
if (
isinstance(key, (by_id, by_index, list, np.ndarray, slice)) or
is_int_like(key)
):
return self.select_trains(key)

if key not in self:
raise PropertyNameError(key, self.source)
ds0 = self.files[0].file[
Expand Down
6 changes: 6 additions & 0 deletions extra_data/tests/test_reader_mockdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,12 @@
with pytest.raises(IndexError):
run.select_trains(by_index[[480]])

assert run[10].train_ids == [10010]
assert run[by_id[10000]].train_ids == [10000]
assert run[by_index[479:555]].train_ids == [10479]
with pytest.raises(IndexError):
run[555]
Dismissed Show dismissed Hide dismissed


def test_split_trains(mock_fxe_raw_run):
run = RunDirectory(mock_fxe_raw_run)
Expand Down
13 changes: 13 additions & 0 deletions extra_data/tests/test_sourcedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,19 @@ def test_select_trains(mock_spb_raw_run):
assert sel.train_ids == []
assert sel.keys() == xgm.keys()

sel = xgm[by_id[10020:10040]]
assert sel.train_ids == list(range(10020, 10040))

sel = xgm[by_index[:10]]
assert sel.train_ids == list(range(10000, 10010))

sel = xgm[10]
assert sel.train_ids == [10010]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I see that this is the same behavior I flagged above ✔️, so it's only a matter of documentation, if it is indeed missing (I could simply have missed it).


sel = xgm[999:1000]
assert sel.train_ids == []
assert sel.keys() == xgm.keys()


def test_split_trains(mock_spb_raw_run):
run = RunDirectory(mock_spb_raw_run)
Expand Down