added cropping
felbecker committed Jan 30, 2024
1 parent bf63a48 commit 20606b1
Showing 8 changed files with 36,033 additions and 49 deletions.
2 changes: 1 addition & 1 deletion learnMSA/_version.py
@@ -1 +1 @@
__version__ = "1.3.2"
__version__ = "1.3.4"
20 changes: 12 additions & 8 deletions learnMSA/msa_hmm/Align.py
@@ -3,6 +3,7 @@
import time
import os
import sys
import math
from pathlib import Path
import subprocess
from shutil import which
@@ -213,12 +214,15 @@ def get_state_expectations(data : SequenceDataset,
num_indices = indices.shape[0]
sorted_indices = np.array([[i,j] for l,i,j in sorted(zip(data.seq_lens[indices], indices, range(num_indices)))])
msa_hmm_layer.cell.recurrent_init()
cell = msa_hmm_layer.cell
old_crop_long_seqs = batch_generator.crop_long_seqs
batch_generator.crop_long_seqs = math.inf #do not crop sequences
ds = train.make_dataset(sorted_indices[:,0],
batch_generator,
batch_size,
shuffle=False)

cell = msa_hmm_layer.cell
shuffle=False,
bucket_by_seq_length=True,
model_lengths=cell.length)

@tf.function(input_signature=[[tf.TensorSpec(x.shape, dtype=x.dtype) for x in encoder.inputs]])
def batch_posterior_state_probs(inputs):
@@ -233,14 +237,14 @@ def batch_posterior_state_probs(inputs):

if reduce:
posterior_probs = tf.zeros((cell.num_models, cell.max_num_states), cell.dtype)
for inputs, _ in ds:
for (*inputs, _), _ in ds:
posterior_probs += batch_posterior_state_probs(inputs)
return posterior_probs.numpy()
else:
posterior_probs = np.zeros((cell.num_models, num_indices, cell.max_num_states), cell.dtype)
for i, (inputs, _) in enumerate(ds):
posterior_probs[:,i*batch_size : (i+1)*batch_size] = batch_posterior_state_probs(inputs)
return posterior_probs
for (*inputs, batch_indices), _ in ds:
posterior_probs[:,batch_indices] = batch_posterior_state_probs(inputs)
batch_generator.crop_long_seqs = old_crop_long_seqs
return posterior_probs


def get_discard_or_expand_positions(am,
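In the hunk above, get_state_expectations now saves the batch generator's crop_long_seqs, sets it to math.inf so the posterior state probabilities are computed on full-length sequences, and restores the old limit afterwards. A minimal sketch of that save/restore pattern, wrapped in a context manager for illustration (cropping_disabled and DummyBatchGenerator are hypothetical names, not part of this commit):

import math
from contextlib import contextmanager

@contextmanager
def cropping_disabled(batch_generator):
    """Temporarily disable sequence cropping, mirroring the inline
    save/restore in Align.py's get_state_expectations."""
    old_crop = batch_generator.crop_long_seqs
    batch_generator.crop_long_seqs = math.inf  # do not crop sequences
    try:
        yield batch_generator
    finally:
        batch_generator.crop_long_seqs = old_crop  # restore the previous limit

class DummyBatchGenerator:  # stand-in object for the example only
    crop_long_seqs = 500

gen = DummyBatchGenerator()
with cropping_disabled(gen):
    print(gen.crop_long_seqs)  # inf while the dataset is built and consumed
print(gen.crop_long_seqs)      # 500 again afterwards

Wrapping the pattern this way also restores the limit if the computation raises, whereas the inline version relies on reaching the restore statement.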
28 changes: 21 additions & 7 deletions learnMSA/msa_hmm/Training.py
@@ -145,7 +145,7 @@ def __call__(self, indices):

def get_out_types(self):
if self.return_only_sequences:
return (tf.uint8)
return (tf.uint8, )
else:
return (tf.uint8, tf.int64)

@@ -263,35 +263,44 @@ def __call__(self, indices):

def get_out_types(self):
if self.return_only_sequences:
return (tf.uint8)
return (tf.uint8, )
else:
return (tf.uint8, tf.int64, tf.float32)


# batch_generator is a callable object that maps a vector of sequence indices to
# inputs compatible with the model
def make_dataset(indices, batch_generator, batch_size=512, shuffle=True, bucket_by_seq_length=False, model_lengths=[0]):
shuffle = shuffle and not bucket_by_seq_length
batch_generator.shuffle = shuffle
ds = tf.data.Dataset.from_tensor_slices(indices)
ds_len = tf.data.Dataset.from_tensor_slices(batch_generator.data.seq_lens[indices].astype(np.int32))
ds = tf.data.Dataset.zip((ds, ds_len))
if bucket_by_seq_length:
ds_len = tf.data.Dataset.from_tensor_slices(batch_generator.data.seq_lens[indices].astype(np.int32))
ds_ind = tf.data.Dataset.from_tensor_slices(np.arange(indices.size))
ds = tf.data.Dataset.zip((ds, ds_len, ds_ind))
adaptive_batch = batch_generator.config["batch_size"]
if not callable(adaptive_batch):
raise ValueError("""Batch generator must be configured with a configuration that supports an adaptive batch size callback,
if bucket_by_seq_length is True.""")
bucket_boundaries = [200, 520, 700, 850, 1200, 2000, 4000, math.inf]
bucket_batch_sizes = [adaptive_batch(model_lengths, b) for b in bucket_boundaries]
ds = ds.bucket_by_sequence_length(
element_length_func=lambda _,L: L,
element_length_func=lambda i,L,j: L,
bucket_boundaries=bucket_boundaries[:-1],
bucket_batch_sizes=bucket_batch_sizes)

batch_func_out_types = batch_generator.get_out_types() + (tf.int64,)
func = (lambda i,j: (batch_generator(i), j)) if len(batch_func_out_types) == 2 else lambda i,j: (*batch_generator(i), j)
batch_func = lambda i,_,j: tf.numpy_function(func=func, inp=[i,j], Tout=batch_func_out_types)
else:
if shuffle:
ds = ds.shuffle(indices.size, reshuffle_each_iteration=True)
ds = ds.repeat()
ds = ds.batch(batch_size)
ds = ds.map(lambda i,_: tf.numpy_function(func=batch_generator, inp=[i], Tout=batch_generator.get_out_types()),

batch_func = lambda i: tf.numpy_function(func=batch_generator, inp=[i], Tout=batch_generator.get_out_types())

ds = ds.map(batch_func,
# no parallel processing if using an indexed dataset
num_parallel_calls=None if batch_generator.data.indexed else tf.data.AUTOTUNE,
deterministic=True)
@@ -302,7 +311,7 @@ def make_dataset(indices, batch_generator, batch_size=512, shuffle=True, bucket_
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
ds = ds.with_options(options)
ds = tf.data.Dataset.zip((ds, tf.data.Dataset.from_tensor_slices(tf.zeros(1))))
ds = tf.data.Dataset.zip((ds, tf.data.Dataset.from_tensor_slices(tf.zeros(1)).repeat()))
return ds
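The bucket_by_seq_length path above zips each sequence index with its length and its position in the index array, then relies on tf.data's bucket_by_sequence_length so that sequences of similar length end up in the same batch, with one batch size per bucket taken from the adaptive batch-size callback in batch_generator.config["batch_size"]. A self-contained sketch of the bucketing mechanism with made-up lengths and batch sizes (illustrative values only, not the adaptive ones):

import tensorflow as tf

# Toy sequence lengths standing in for data.seq_lens[indices].
seq_lens = tf.constant([30, 400, 90, 1500, 250, 3000, 150, 60], dtype=tf.int32)
ds = tf.data.Dataset.from_tensor_slices(seq_lens)

# One batch size per bucket, i.e. len(bucket_boundaries) + 1 entries, which is
# why the code above passes bucket_boundaries[:-1] next to a full-length
# bucket_batch_sizes list.
bucket_boundaries = [200, 520, 700, 850, 1200, 2000, 4000]
bucket_batch_sizes = [64, 32, 24, 16, 8, 4, 2, 1]

bucketed = ds.bucket_by_sequence_length(
    element_length_func=lambda length: length,
    bucket_boundaries=bucket_boundaries,
    bucket_batch_sizes=bucket_batch_sizes)

for batch in bucketed:
    print(batch.numpy())  # each batch only contains lengths from one bucket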


@@ -328,6 +337,11 @@ def fit_model(model_generator,
print("Using sequence weights ", sequence_weights, ".")
else:
print("Don't use sequence weights.")
if batch_generator.crop_long_seqs < math.inf:
num_cropped = np.sum(data.seq_lens[indices] > batch_generator.crop_long_seqs)
if num_cropped > 0:
print(f"""{num_cropped} sequences are longer than {batch_generator.crop_long_seqs} and will be cropped for training.""")
print("To disable cropping, use --crop disable. To change the cropping limit to X, use --crop X.")
def make_and_compile():
model = model_generator(num_seq=data.num_seq,
effective_num_seq=indices.shape[0],
23 changes: 12 additions & 11 deletions learnMSA/msa_hmm/Viterbi.py
@@ -3,6 +3,7 @@
import learnMSA.msa_hmm.Training as train
from learnMSA.msa_hmm.SequenceDataset import SequenceDataset
import time
import math


def viterbi_step(gamma_prev, emission_probs_i, hmm_cell):
@@ -120,16 +121,16 @@ def get_state_seqs_max_lik(data : SequenceDataset,
num_gpu = len([x.name for x in tf.config.list_logical_devices() if x.device_type == 'GPU'])
num_devices = num_gpu + int(num_gpu==0) #account for the CPU-only case
batch_size = int(batch_size / num_devices)
#compute an optimized order for decoding that sorts sequences of equal length into the same batch
sorted_indices = np.array([[i,j] for l,i,j in sorted(zip(data.seq_lens[indices], indices, range(indices.size)))])
hmm_cell.recurrent_init()
ds = train.make_dataset(sorted_indices[:,0],
old_crop_long_seqs = batch_generator.crop_long_seqs
batch_generator.crop_long_seqs = math.inf #do not crop sequences
ds = train.make_dataset(indices,
batch_generator,
batch_size,
shuffle=False,
bucket_by_seq_length=True,
model_lengths=hmm_cell.length)
seq_len = data.seq_lens[sorted_indices[-1,0]]+1
seq_len = np.amax(data.seq_lens[indices]+1)
#initialize with terminal states
state_seqs_max_lik = np.zeros((hmm_cell.num_models, indices.size, seq_len),
dtype=np.uint16)
@@ -154,15 +155,15 @@ def call_viterbi_single(inputs):

for i,q in enumerate(hmm_cell.num_states):
state_seqs_max_lik[i] = q-1 #terminal state
i = 0
for inputs, _ in ds:
for (*inputs, batch_indices), _ in ds:
if hasattr(batch_generator, "return_only_sequences") and batch_generator.return_only_sequences:
state_seqs_max_lik_batch = call_viterbi_single(inputs).numpy()
state_seqs_max_lik_batch = call_viterbi_single(inputs[0]).numpy()
else:
state_seqs_max_lik_batch = call_viterbi(inputs).numpy()
_,b,l = state_seqs_max_lik_batch.shape
state_seqs_max_lik[:, i:i+b, :l] = state_seqs_max_lik_batch
i += b
#reorder back to the original order
state_seqs_max_lik = state_seqs_max_lik[:,np.argsort(sorted_indices[:,1])]
state_seqs_max_lik[:, batch_indices, :l] = state_seqs_max_lik_batch

# revert batch generator state
batch_generator.crop_long_seqs = old_crop_long_seqs

return state_seqs_max_lik
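get_state_seqs_max_lik previously sorted the sequences by length, decoded them in that order with a running offset i, and undid the sorting at the end. With the bucketed dataset, batches arrive in no particular order, so each batch now carries the original positions of its sequences (batch_indices) and the decoded states are scattered straight into the output array. A small NumPy sketch of that scatter, with made-up shapes and state values:

import numpy as np

num_models, num_seqs, max_len = 2, 6, 10
# Initialize with a sentinel (the real code uses each model's terminal state).
state_seqs_max_lik = np.zeros((num_models, num_seqs, max_len), dtype=np.uint16)

# Hypothetical decoded batches: (state sequences, original sequence positions).
batches = [
    (np.ones((num_models, 3, 7), dtype=np.uint16), np.array([4, 0, 2])),
    (np.full((num_models, 3, 5), 2, dtype=np.uint16), np.array([1, 5, 3])),
]

for decoded, batch_indices in batches:
    _, b, l = decoded.shape
    # Scatter the batch to its original rows; positions beyond l keep the sentinel.
    state_seqs_max_lik[:, batch_indices, :l] = decoded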
13 changes: 10 additions & 3 deletions learnMSA/run/console.py
@@ -57,7 +57,9 @@ def error(self, message):
parser.add_argument("--indexed_data", dest="indexed_data", action='store_true', help="Don't load all data into memory at once at the cost of training time.")

parser.add_argument("--unaligned_insertions", dest="unaligned_insertions", action='store_true', help="Insertions will be left unaligned.")
parser.add_argument("--crop_long_seqs", dest="crop_long_seqs", type=float, default=math.inf, help="During training, sequences longer than the given value will be cropped randomly. Increases training speed and reduces memory usage, but might produce inaccurate results if too much of the sequences is cropped.")
parser.add_argument("--crop", dest="crop", type=str, default="auto", help="""During training, sequences longer than the given value will be cropped randomly.
Reduces training runtime and memory usage, but might produce inaccurate results if too much of the sequences is cropped. The output alignment will not be cropped.
Can be set to auto in which case sequences longer than 3 times the average length are cropped. Can be set to disable. (default: %(default)s)""")

parser.add_argument("--sequence_weights", dest="sequence_weights", action='store_true', help="Uses mmseqs2 to rapidly cluster the sequences and compute sequence weights before the MSA. (default: %(default)s)")
parser.add_argument("--cluster_dir", dest="cluster_dir", type=str, default="tmp", help="Directory where the sequence clustering is stored. (default: %(default)s)")
@@ -113,7 +115,6 @@ def error(self, message):
config["surgery_ins"] = args.surgery_ins
config["model_criterion"] = args.model_criterion
config["use_language_model"] = args.use_language_model
config["crop_long_seqs"] = args.crop_long_seqs
transitioners = config["transitioner"] if hasattr(config["transitioner"], '__iter__') else [config["transitioner"]]
for trans in transitioners:
trans.prior.alpha_flank = args.alpha_flank
@@ -157,6 +158,12 @@ def error(self, message):
try:
with SequenceDataset(args.input_file, "fasta", indexed=args.indexed_data) as data:
data.validate_dataset()
if args.crop == "disable":
config["crop_long_seqs"] = math.inf
elif args.crop == "auto":
config["crop_long_seqs"] = int(np.ceil(3 * np.mean(data.seq_lens)))
else:
config["crop_long_seqs"] = int(args.crop)
_ = Align.run_learnMSA(data,
out_filename = args.output_file,
config = config,
@@ -167,7 +174,7 @@ def error(self, message):
verbose = not args.silent)
except ValueError as e:
raise SystemExit(e)


if __name__ == '__main__':
run_main()
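The old float-valued --crop_long_seqs option is replaced by a string-valued --crop with three modes, resolved only after the dataset has been opened so that "auto" can look at the actual sequence lengths. The commit does this inline in console.py; the helper below is a hypothetical standalone version of the same resolution logic:

import math
import numpy as np

def resolve_crop_limit(crop_arg, seq_lens):
    """Hypothetical helper mirroring the inline --crop handling in console.py."""
    if crop_arg == "disable":
        return math.inf                             # never crop
    if crop_arg == "auto":
        return int(np.ceil(3 * np.mean(seq_lens)))  # 3x the average length
    return int(crop_arg)                            # explicit integer limit

print(resolve_crop_limit("auto", np.array([100, 120, 140])))     # 360
print(resolve_crop_limit("disable", np.array([100, 120, 140])))  # inf
print(resolve_crop_limit("800", np.array([100, 120, 140])))      # 800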
32 changes: 13 additions & 19 deletions learnMSA_demo.ipynb
@@ -75,19 +75,19 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"id": "59efc747-cd9a-47bc-a38c-ad654566b172",
"metadata": {},
"outputs": [],
"source": [
"# Your fasta file with unaligned sequences.\n",
"\n",
"train_filename = \"test/data/egf.fasta\"\n",
"train_filename = \"test/data/rhv.fasta\"\n",
"\n",
"# Reference file with aligned sequences that have matching IDs to (potentially a subset of) the \n",
"# sequences in the train_file.\n",
"# Replace with empty string if no reference is available.\n",
"ref_filename = \"test/data/egf.ref\"\n",
"ref_filename = \"test/data/rhv.ref\"\n",
"\n",
"# The number of independently trained models.\n",
"num_models = 2\n",
@@ -370,36 +370,30 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"id": "ceefb9fa-1d18-48d1-991d-12bce94a7529",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"HERE: 1ixa\n",
"HERE: 1apo\n",
"HERE: 1urk\n",
"HERE: 1fsb\n",
"HERE: 1esl\n",
"HERE: 1hre\n",
"HERE: 1epi\n",
"HERE: 4tgf\n",
"HERE: 1hcgb\n",
"HERE: 1dan1\n",
"HERE: 1dan2\n",
"HERE: 1rfnb\n"
"HERE: 1tme\n",
"HERE: 2mev\n",
"HERE: 1bbt\n",
"HERE: 1r1a\n",
"HERE: 4rhv\n",
"HERE: 2plv\n"
]
}
],
"source": [
"!id_list=$(sed -n '/^>/p' {ref_filename} | sed 's/^.//') ; export MAX_N_PID_4_TCOFFEE=10000000 ; t_coffee -other_pg seq_reformat -in test/data/interactive.alignment.fasta -action +extract_seq_list ${{id_list[@]}} +rm_gap > test/data/interactive.projection.fasta"
"!id_list=$(sed -n '/^>/p' {ref_filename} | sed 's/^.//') ; export MAX_N_PID_4_TCOFFEE=10000000 ; t_coffee -other_pg seq_reformat -in rhv.out -action +extract_seq_list ${{id_list[@]}} +rm_gap > test/data/interactive.projection.fasta"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 7,
"id": "4fb89f14-dd1f-4b28-8869-59ebc41a1778",
"metadata": {},
"outputs": [
@@ -409,7 +403,7 @@
"text": [
"*****************************************************\n",
"seq1 seq2 Sim [ALL] Tot \n",
"egf 12 31.1 77.2 [100.0] [ 5182]\n"
"rhv 6 33.1 24.1 [100.0] [20998]\n"
]
}
],