Skip to content

Commit 14f35e5

Browse files
committed
improve doc data
1 parent 533025d commit 14f35e5

File tree

6 files changed

+306
-116
lines changed

6 files changed

+306
-116
lines changed

metagentorch/cnn_virus/data.py

+24-14
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,18 @@
3838
CODE_ROOT = Path(__file__).parents[0]
3939
PACKAGE_ROOT = Path(__file__).parents[1]
4040

41-
# %% ../../nbs-dev/03_cnn_virus_data.ipynb 21
41+
# %% ../../nbs-dev/03_cnn_virus_data.ipynb 22
4242
class OriginalLabels:
43-
"""Converts labels to species name for original data"""
44-
def __init__(self, p2mapping=None):
43+
"""Converts between labels and species name as per original training dataset"""
44+
def __init__(
45+
self,
46+
p2mapping:Path|None = None # Path to the mapping file. Uses `virus_name_mapping` by default
47+
):
4548
if p2mapping is None:
4649
p2mapping = ProjectFileSystem().data / 'CNN_Virus_data/virus_name_mapping'
47-
assert p2mapping.is_file()
50+
else:
51+
p2mapping = safe_path(p2mapping)
52+
if not p2mapping.is_file(): raise FileNotFoundError(f"Mapping file not found at {p2mapping}")
4853
df = pd.read_csv(p2mapping, sep='\t', header=None, names=['species', 'label'])
4954
self._label2species = df['species'].to_list()
5055
self._label2species.append('Unknown Virus Species')
@@ -53,18 +58,20 @@ def __init__(self, p2mapping=None):
5358

5459
def search(self, s:str # string to search through all original virus species
5560
):
56-
"""Prints all species whose name contains the passed string `s`"""
61+
"""Prints all species whose name contains the passed string, with their numerical label"""
5762
print('\n'.join([f"{k}. Label: {v}" for k,v in self._species2label.items() if s in k.lower()]))
5863

5964
def label2species(self, n:int # label to convert to species name
6065
):
66+
"""Converts a numerical label into the corresponding species name"""
6167
return self._label2species[n]
6268

6369
def species2label(self, s:str # string to convert to label
6470
):
71+
"""Converts a species name into the corresponding label number"""
6572
return self._species2label[s]
6673

67-
# %% ../../nbs-dev/03_cnn_virus_data.ipynb 39
74+
# %% ../../nbs-dev/03_cnn_virus_data.ipynb 43
6875
class FastaFileReader(TextFileBaseReader):
6976
"""Wrap a FASTA file and retrieve its content in raw format and parsed format"""
7077
def __init__(
@@ -148,7 +155,7 @@ def review(self):
148155
return i+1
149156

150157

151-
# %% ../../nbs-dev/03_cnn_virus_data.ipynb 87
158+
# %% ../../nbs-dev/03_cnn_virus_data.ipynb 91
152159
class FastqFileReader(TextFileBaseReader):
153160
"""Iterator going through a fastq file's sequences and returning each section + prob error as a dict"""
154161
def __init__(
@@ -218,7 +225,7 @@ def parse_file(
218225

219226
return parsed
220227

221-
# %% ../../nbs-dev/03_cnn_virus_data.ipynb 102
228+
# %% ../../nbs-dev/03_cnn_virus_data.ipynb 106
222229
class AlnFileReader(TextFileBaseReader):
223230
"""Iterator going through an ALN file"""
224231
def __init__(
@@ -415,9 +422,9 @@ def set_header_parsing_rules(
415422
# We used the iterator, now we need to reset it to make all lines available
416423
self.reset_iterator()
417424

418-
# %% ../../nbs-dev/03_cnn_virus_data.ipynb 139
425+
# %% ../../nbs-dev/03_cnn_virus_data.ipynb 143
419426
class TextFileDataset(IterableDataset):
420-
"""Load data from the text file and create BHE sequence tensor and label,position OHE tensors"""
427+
"""Load data from text file and yield (BHE sequence tensor, (label OHE tensor, position OHE tensor))"""
421428

422429
base2encoding = {
423430
'A': [1,0,0,0,0],
@@ -439,7 +446,10 @@ def __init__(
439446
def __iter__(self):
440447
with open(self.p2file, 'r') as f:
441448
for line in f:
442-
seq, lbl, pos = line.strip().split('\t')
449+
# wi = torch.utils.data.get_worker_info()
450+
# if wi:
451+
# print(f"{wi.id} loading {line}")
452+
seq, lbl, pos = line.replace('\n','').strip().split('\t')
443453
seq_bhe = torch.tensor(list(map(self._bhe_fn, seq)))
444454
lbl_ohe = torch.zeros(self.nb_labels)
445455
lbl_ohe[int(lbl)] = 1
@@ -451,7 +461,7 @@ def _bhe_fn(self, base:str) -> list[int]:
451461
"""Convert a base to a one hot encoding vector"""
452462
return self.base2encoding[base]
453463

454-
# %% ../../nbs-dev/03_cnn_virus_data.ipynb 147
464+
# %% ../../nbs-dev/03_cnn_virus_data.ipynb 152
455465
class AlnFileDataset(IterableDataset):
456466
"""Load data and metadata from ALN file, yield BHE sequence, OHE label, OHE position tensors + metadata
457467
@@ -505,7 +515,7 @@ def _bhe_fn(self, base:str) -> list[int]:
505515
"""Convert a base to a one hot encoding vector"""
506516
return self.base2encoding[base]
507517

508-
# %% ../../nbs-dev/03_cnn_virus_data.ipynb 156
518+
# %% ../../nbs-dev/03_cnn_virus_data.ipynb 161
509519
def combine_predictions(
510520
labels:torch.Tensor, # Label predictions for a set of 50-mer corresponding to a single k-mer
511521
label_probs: torch.Tensor, # Probabilities for the labels
@@ -534,7 +544,7 @@ def combine_predictions(
534544
combined_pos = counter_pos.most_common(1)[0][0]
535545
return combined_label, combined_pos
536546

537-
# %% ../../nbs-dev/03_cnn_virus_data.ipynb 157
547+
# %% ../../nbs-dev/03_cnn_virus_data.ipynb 162
538548
def combine_prediction_batch(
539549
probs_elements: tuple[torch.Tensor, torch.Tensor] # Tuple of label and position probabilities for a batch of 50-mer
540550
):

0 commit comments

Comments
 (0)