CODE_ROOT = Path(__file__).parents[0]
PACKAGE_ROOT = Path(__file__).parents[1]

- # %% ../../nbs-dev/03_cnn_virus_data.ipynb 21
+ # %% ../../nbs-dev/03_cnn_virus_data.ipynb 22
class OriginalLabels:
-     """Converts labels to species name for original data"""
-     def __init__(self, p2mapping=None):
+     """Converts between labels and species name as per original training dataset"""
+     def __init__(
+         self,
+         p2mapping:Path|None=None  # Path to the mapping file. Uses `virus_name_mapping` by default
+     ):
        if p2mapping is None:
            p2mapping = ProjectFileSystem().data / 'CNN_Virus_data/virus_name_mapping'
-         assert p2mapping.is_file()
+         else:
+             p2mapping = safe_path(p2mapping)
+         if not p2mapping.is_file(): raise FileNotFoundError(f"Mapping file not found at {p2mapping}")
        df = pd.read_csv(p2mapping, sep='\t', header=None, names=['species', 'label'])
        self._label2species = df['species'].to_list()
        self._label2species.append('Unknown Virus Species')
@@ -53,18 +58,20 @@ def __init__(self, p2mapping=None):

    def search(self, s:str  # string to search through all original virus species
    ):
-         """Prints all species whose name contains the passed string `s`"""
+         """Prints all species whose name contains the passed string, with their numerical label"""
        print('\n'.join([f"{k}. Label: {v}" for k,v in self._species2label.items() if s in k.lower()]))

    def label2species(self, n:int  # label to convert to species name
    ):
+         """Converts a numerical label into the corresponding species name"""
        return self._label2species[n]

    def species2label(self, s:str  # string to convert to label
    ):
+         """Converts a species name into the corresponding label number"""
        return self._species2label[s]
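# A minimal usage sketch for `OriginalLabels` (assumes the default mapping file
# `virus_name_mapping` is present under the project's data directory; the label
# values shown are illustrative, not taken from the actual mapping):
#
#     labels = OriginalLabels()
#     labels.search('corona')                    # prints every species containing 'corona', with its label
#     name = labels.label2species(42)            # numerical label -> species name
#     assert labels.species2label(name) == 42    # species name -> numerical label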

- # %% ../../nbs-dev/03_cnn_virus_data.ipynb 39
+ # %% ../../nbs-dev/03_cnn_virus_data.ipynb 43
class FastaFileReader(TextFileBaseReader):
    """Wrap a FASTA file and retrieve its content in raw format and parsed format"""
    def __init__(
@@ -148,7 +155,7 @@ def review(self):
        return i + 1
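# For reference, a FASTA record is a '>'-prefixed definition line followed by one or more
# sequence lines; this is the shape of file `FastaFileReader` wraps (toy record, not real data):
#
#     >seq_001 toy virus segment
#     ACGTACGTTTGACCATTAGGACCA
#     TTGACCATTAGGACCAACGT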

- # %% ../../nbs-dev/03_cnn_virus_data.ipynb 87
+ # %% ../../nbs-dev/03_cnn_virus_data.ipynb 91
class FastqFileReader(TextFileBaseReader):
    """Iterator going through a fastq file's sequences and return each section + prob error as a dict"""
    def __init__(
@@ -218,7 +225,7 @@ def parse_file(

        return parsed
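# For reference, a FASTQ record spans four lines: '@' + read id, the called sequence, a '+'
# separator, and per-base quality characters (from which the error probability is derived);
# this is the shape of file `FastqFileReader` iterates over (toy record, not real data):
#
#     @read_001
#     ACGTACGTTTGACCATTAGG
#     +
#     IIIIHHHGGGFFFEEEDDDC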

- # %% ../../nbs-dev/03_cnn_virus_data.ipynb 102
+ # %% ../../nbs-dev/03_cnn_virus_data.ipynb 106
class AlnFileReader(TextFileBaseReader):
    """Iterator going through an ALN file"""
    def __init__(
@@ -415,9 +422,9 @@ def set_header_parsing_rules(
        # We used the iterator, now we need to reset it to make all lines available
        self.reset_iterator()

- # %% ../../nbs-dev/03_cnn_virus_data.ipynb 139
+ # %% ../../nbs-dev/03_cnn_virus_data.ipynb 143
class TextFileDataset(IterableDataset):
-     """Load data from the text file and create BHE sequence tensor and label, position OHE tensors"""
+     """Load data from text file and yield (BHE sequence tensor, (label OHE tensor, position OHE tensor))"""

    base2encoding = {
        'A': [1,0,0,0,0],
@@ -439,7 +446,10 @@ def __init__(
    def __iter__(self):
        with open(self.p2file, 'r') as f:
            for line in f:
-                 seq, lbl, pos = line.strip().split('\t')
+                 # wi = torch.utils.data.get_worker_info()
+                 # if wi:
+                 #     print(f"{wi.id} loading {line}")
+                 seq, lbl, pos = line.replace('\n','').strip().split('\t')
                seq_bhe = torch.tensor(list(map(self._bhe_fn, seq)))
                lbl_ohe = torch.zeros(self.nb_labels)
                lbl_ohe[int(lbl)] = 1
@@ -451,7 +461,7 @@ def _bhe_fn(self, base:str) -> list[int]:
        """Convert a base to a one hot encoding vector"""
        return self.base2encoding[base]
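# A minimal sketch of consuming `TextFileDataset` with a PyTorch DataLoader. The constructor
# arguments and file path below are assumptions, not the actual signature; adjust to the real
# __init__. Note that with num_workers > 0 every worker would replay the whole file unless the
# dataset shards lines per worker (see the commented-out get_worker_info() debug lines above).
#
#     from torch.utils.data import DataLoader
#
#     ds = TextFileDataset('data/CNN_Virus_data/50mer_training')   # hypothetical path / signature
#     dl = DataLoader(ds, batch_size=32)                           # IterableDataset: no sampler-based shuffling
#     xb, (yb_label, yb_pos) = next(iter(dl))
#     # xb: (32, 50, 5) base-hot-encoded reads; yb_label / yb_pos: one-hot label and position targets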

- # %% ../../nbs-dev/03_cnn_virus_data.ipynb 147
+ # %% ../../nbs-dev/03_cnn_virus_data.ipynb 152
class AlnFileDataset(IterableDataset):
    """Load data and metadata from ALN file, yield BHE sequence, OHE label, OHE position tensors + metadata
@@ -505,7 +515,7 @@ def _bhe_fn(self, base:str) -> list[int]:
        """Convert a base to a one hot encoding vector"""
        return self.base2encoding[base]

- # %% ../../nbs-dev/03_cnn_virus_data.ipynb 156
+ # %% ../../nbs-dev/03_cnn_virus_data.ipynb 161
def combine_predictions(
    labels:torch.Tensor,        # Label predictions for a set of 50-mer corresponding to a single k-mer
    label_probs:torch.Tensor,   # Probabilities for the labels
@@ -534,7 +544,7 @@ def combine_predictions(
    combined_pos = counter_pos.most_common(1)[0][0]
    return combined_label, combined_pos
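# The combination step is a plain majority vote: across all 50-mer reads belonging to one k-mer,
# the most frequent predicted label (and position) wins. A standalone sketch of that idea with
# collections.Counter, using toy tensors rather than the function's real inputs:
#
#     import torch
#     from collections import Counter
#
#     labels = torch.tensor([3, 3, 7, 3, 12])       # per-read label predictions
#     positions = torch.tensor([1, 1, 1, 4, 1])     # per-read position predictions
#     combined_label = Counter(labels.tolist()).most_common(1)[0][0]    # -> 3
#     combined_pos = Counter(positions.tolist()).most_common(1)[0][0]   # -> 1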

- # %% ../../nbs-dev/03_cnn_virus_data.ipynb 157
+ # %% ../../nbs-dev/03_cnn_virus_data.ipynb 162
def combine_prediction_batch(
    probs_elements:tuple[torch.Tensor, torch.Tensor]  # Tuple of label and position probabilities for a batch of 50-mer
):