implement mini-epochs in training #60

Merged: 5 commits, Oct 17, 2021
112 changes: 59 additions & 53 deletions clair3/Train.py
@@ -62,6 +62,48 @@ def call(self, y_true, y_pred):
        return reduce_fl


class DataSequence(tf.keras.utils.Sequence):
    """Serve chunked training data, splitting one pass over the data into
    `mini_epochs` shorter Keras epochs so that end-of-epoch work (validation,
    checkpointing, learning-rate scheduling) runs more frequently."""

    def __init__(self, data, chunk_list, param, tensor_shape, mini_epochs=1, add_indel_length=False, validation=False):
        self.data = data
        self.chunk_list = chunk_list
        self.batch_size = param.trainBatchSize
        self.chunk_size = param.chunk_size
        self.chunks_per_batch = self.batch_size // self.chunk_size
        self.label_shape_cum = param.label_shape_cum[0:4 if add_indel_length else 2]
        self.mini_epochs = mini_epochs
        self.mini_epochs_count = -1
        self.validation = validation
        self.position_matrix = np.empty([self.batch_size] + tensor_shape, np.int32)
        self.label = np.empty((self.batch_size, param.label_size), np.float32)
        self.random_offset = 0
        self.on_epoch_end()

    def __len__(self):
        # Batches per mini-epoch: a mini_epochs-th of one full pass over the data.
        return int((len(self.chunk_list) // self.chunks_per_batch) // self.mini_epochs)

    def __getitem__(self, index):
        # Offset into chunk_list so that consecutive mini-epochs read disjoint slices.
        mini_epoch_offset = self.mini_epochs_count * self.__len__()
        chunk_batch_list = self.chunk_list[(mini_epoch_offset + index) * self.chunks_per_batch:(mini_epoch_offset + index + 1) * self.chunks_per_batch]
        for chunk_idx, (bin_id, chunk_id) in enumerate(chunk_batch_list):
            start_pos = self.random_offset + chunk_id * self.chunk_size
            self.position_matrix[chunk_idx * self.chunk_size:(chunk_idx + 1) * self.chunk_size] = \
                self.data[bin_id].root.position_matrix[start_pos:start_pos + self.chunk_size]
            self.label[chunk_idx * self.chunk_size:(chunk_idx + 1) * self.chunk_size] = \
                self.data[bin_id].root.label[start_pos:start_pos + self.chunk_size]

        return self.position_matrix, tuple(
            np.split(self.label, self.label_shape_cum, axis=1)[:len(self.label_shape_cum)]
        )

    def on_epoch_end(self):
        self.mini_epochs_count += 1
        if (self.mini_epochs_count % self.mini_epochs) == 0:
            # A full pass over the data is done: reshuffle and re-offset for training.
            self.mini_epochs_count = 0
            if not self.validation:
                self.random_offset = np.random.randint(0, self.chunk_size)
                np.random.shuffle(self.chunk_list)


def get_chunk_list(chunk_offset, train_chunk_num):
"""
get chunk list for training and validation data. we will randomly split training and validation dataset,
@@ -100,18 +142,16 @@ def train_model(args):
        model = model_path.Clair3_F(add_indel_length=add_indel_length)

    tensor_shape = param.ont_input_shape if platform == 'ont' else param.input_shape
    label_size, label_shape = param.label_size, param.label_shape
    label_shape_cum = list(accumulate(label_shape))
    label_shape = param.label_shape
    label_shape_cum = param.label_shape_cum
    batch_size, chunk_size = param.trainBatchSize, param.chunk_size
    chunks_per_batch = batch_size // chunk_size
    random.seed(param.RANDOM_SEED)
    np.random.seed(param.RANDOM_SEED)
    learning_rate = args.learning_rate if args.learning_rate else param.initialLearningRate
    max_epoch = args.maxEpoch if args.maxEpoch else param.maxEpoch
    task_num = 4 if add_indel_length else 2
    TensorShape = (
        tf.TensorShape([None] + tensor_shape), tuple(tf.TensorShape([None, label_shape[task]]) for task in range(task_num)))
    TensorDtype = (tf.int32, tuple(tf.float32 for _ in range(task_num)))
    mini_epochs = args.mini_epochs

    def populate_dataset_table(file_list, file_path):
        chunk_offset = np.zeros(len(file_list), dtype=int)
@@ -155,50 +195,13 @@ def populate_dataset_table(file_list, file_path):
    train_chunk_num = len(train_shuffle_chunk_list)
    validate_chunk_num = len(validate_shuffle_chunk_list)


    def DataGenerator(x, shuffle_chunk_list, train_flag=True):

        """
        data generator for pileup or full alignment data processing, pytables with blosc:lz4hc are used for extreme fast
        compression and decompression. random chunk shuffling and random start position to increase training model robustness.

        """

        batch_num = len(shuffle_chunk_list) // chunks_per_batch
        position_matrix = np.empty([batch_size] + tensor_shape, np.int32)
        label = np.empty((batch_size, param.label_size), np.float32)

        random_start_position = np.random.randint(0, chunk_size) if train_flag else 0
        if train_flag:
            np.random.shuffle(shuffle_chunk_list)
        for batch_idx in range(batch_num):
            for chunk_idx in range(chunks_per_batch):
                bin_id, chunk_id = shuffle_chunk_list[batch_idx * chunks_per_batch + chunk_idx]
                position_matrix[chunk_idx * chunk_size:(chunk_idx + 1) * chunk_size] = x[bin_id].root.position_matrix[
                    random_start_position + chunk_id * chunk_size:random_start_position + (chunk_id + 1) * chunk_size]
                label[chunk_idx * chunk_size:(chunk_idx + 1) * chunk_size] = x[bin_id].root.label[
                    random_start_position + chunk_id * chunk_size:random_start_position + (chunk_id + 1) * chunk_size]

            if add_indel_length:
                yield position_matrix, (
                    label[:, :label_shape_cum[0]],
                    label[:, label_shape_cum[0]:label_shape_cum[1]],
                    label[:, label_shape_cum[1]:label_shape_cum[2]],
                    label[:, label_shape_cum[2]:]
                )
            else:
                yield position_matrix, (
                    label[:, :label_shape_cum[0]],
                    label[:, label_shape_cum[0]:label_shape_cum[1]]
                )


    train_dataset = tf.data.Dataset.from_generator(
        lambda: DataGenerator(table_dataset_list, train_shuffle_chunk_list, True), TensorDtype,
        TensorShape).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    validate_dataset = tf.data.Dataset.from_generator(
        lambda: DataGenerator(validate_table_dataset_list if validation_fn else table_dataset_list, validate_shuffle_chunk_list, False), TensorDtype,
        TensorShape).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    train_seq = DataSequence(table_dataset_list, train_shuffle_chunk_list, param, tensor_shape,
                             mini_epochs=mini_epochs, add_indel_length=add_indel_length)
    if add_validation_dataset:
        val_seq = DataSequence(validate_table_dataset_list if validation_fn else table_dataset_list, validate_shuffle_chunk_list, param, tensor_shape,
                               mini_epochs=1, add_indel_length=add_indel_length, validation=True)
    else:
        val_seq = None

    total_steps = max_epoch * (train_chunk_num // chunks_per_batch)

@@ -234,16 +237,16 @@ def DataGenerator(x, shuffle_chunk_list, train_flag=True):
    logging.info("[INFO] The training learning_rate: {}".format(learning_rate))
    logging.info("[INFO] Total training steps: {}".format(total_steps))
    logging.info("[INFO] Maximum training epoch: {}".format(max_epoch))
    logging.info("[INFO] Mini-epochs per epoch: {}".format(mini_epochs))
    logging.info("[INFO] Start training...")

    validate_dataset = validate_dataset if add_validation_dataset else None
    if args.chkpnt_fn is not None:
        model.load_weights(args.chkpnt_fn)
        logging.info("[INFO] Starting from model {}".format(args.chkpnt_fn))

    train_history = model.fit(x=train_dataset,
                              epochs=max_epoch,
                              validation_data=validate_dataset,
    train_history = model.fit(x=train_seq,
                              epochs=max_epoch * mini_epochs,
                              validation_data=val_seq,
                              callbacks=[early_stop_callback,
                                         model_save_callback,
                                         model_best_callback,
@@ -293,6 +296,9 @@ def main():
    parser.add_argument('--exclude_training_samples', type=str, default=None,
                        help="Define training samples to be excluded")

    parser.add_argument('--mini_epochs', type=int, default=1,
                        help="Number of mini-epochs per epoch")

    # Internal process control
    ## In pileup training mode or not
    parser.add_argument('--pileup', action='store_true',
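Note: the following is an illustrative sketch and not part of this pull request. It walks through the mini-epoch bookkeeping implemented by DataSequence and the epochs=max_epoch * mini_epochs change above; total_chunks, chunks_per_batch, mini_epochs and max_epoch are made-up toy values, not Clair3 defaults.

# --- Illustrative sketch, not part of this PR ---
total_chunks = 1200       # e.g. len(train_shuffle_chunk_list)
chunks_per_batch = 10     # trainBatchSize // chunk_size
mini_epochs = 4           # --mini_epochs
max_epoch = 30            # args.maxEpoch

batches_per_full_pass = total_chunks // chunks_per_batch       # 120
batches_per_mini_epoch = batches_per_full_pass // mini_epochs  # 30, i.e. DataSequence.__len__()

# Keras treats each mini-epoch as an epoch, so fit() is given
# epochs = max_epoch * mini_epochs; the batches trained per full pass are unchanged,
# only the frequency of end-of-epoch callbacks changes.
keras_epochs = max_epoch * mini_epochs                          # 120
assert keras_epochs * batches_per_mini_epoch == max_epoch * batches_per_full_pass

# Within one full pass, mini-epoch k reads the k-th slice of the shuffled chunk
# list, mirroring mini_epoch_offset = mini_epochs_count * __len__() in __getitem__.
for k in range(mini_epochs):
    first_chunk = k * batches_per_mini_epoch * chunks_per_batch
    last_chunk = (k + 1) * batches_per_mini_epoch * chunks_per_batch
    print("mini-epoch {}: chunks [{}, {})".format(k, first_chunk, last_chunk))

Splitting an epoch this way does not change how much data is consumed per pass; it only lets validation, checkpointing and learning-rate callbacks fire mini_epochs times per pass instead of once.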
26 changes: 16 additions & 10 deletions clair3/model.py
@@ -262,21 +262,27 @@ def __init__(self, spatial_pool_size=(3, 2, 1)):
        super(PyramidPolling, self).__init__()

        self.spatial_pool_size = spatial_pool_size
        self.pool_len = len(self.spatial_pool_size)
        self.window_h = np.empty(self.pool_len, dtype=int)
        self.stride_h = np.empty(self.pool_len, dtype=int)
        self.window_w = np.empty(self.pool_len, dtype=int)
        self.stride_w = np.empty(self.pool_len, dtype=int)

        self.flatten = tf.keras.layers.Flatten()

    def call(self, x):

        height = int(x.get_shape()[1])
        width = int(x.get_shape()[2])

        for i in range(len(self.spatial_pool_size)):
    def build(self, input_shape):
        height = int(input_shape[1])
        width = int(input_shape[2])

            window_h = stride_h = int(np.ceil(height / self.spatial_pool_size[i]))
        for i in range(self.pool_len):
            self.window_h[i] = self.stride_h[i] = int(np.ceil(height / self.spatial_pool_size[i]))
            self.window_w[i] = self.stride_w[i] = int(np.ceil(width / self.spatial_pool_size[i]))

            window_w = stride_w = int(np.ceil(width / self.spatial_pool_size[i]))

            max_pool = tf.nn.max_pool(x, ksize=[1, window_h, window_w, 1], strides=[1, stride_h, stride_w, 1],
    def call(self, x):
        for i in range(self.pool_len):
            max_pool = tf.nn.max_pool(x,
                                      ksize=[1, self.window_h[i], self.window_w[i], 1],
                                      strides=[1, self.stride_h[i], self.stride_w[i], 1],
                                      padding='SAME')
            if i == 0:
                pp = self.flatten(max_pool)
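Note: a standalone sketch, not part of this pull request, showing the build()/call() split applied to PyramidPolling above. The class name SpatialPyramidPool, the input shape, and the final concatenation (the diff above is truncated before the original concatenation code) are illustrative assumptions; TF 2.x with NHWC inputs is assumed.

# --- Illustrative sketch, not part of this PR ---
import numpy as np
import tensorflow as tf

class SpatialPyramidPool(tf.keras.layers.Layer):
    """Pools the input at several grid resolutions and concatenates the results."""

    def __init__(self, spatial_pool_size=(3, 2, 1)):
        super().__init__()
        self.spatial_pool_size = spatial_pool_size
        self.pool_len = len(spatial_pool_size)
        self.window_h = np.empty(self.pool_len, dtype=int)
        self.stride_h = np.empty(self.pool_len, dtype=int)
        self.window_w = np.empty(self.pool_len, dtype=int)
        self.stride_w = np.empty(self.pool_len, dtype=int)
        self.flatten = tf.keras.layers.Flatten()

    def build(self, input_shape):
        # Window and stride sizes depend only on the static input height/width,
        # so they are computed once here instead of on every call().
        height, width = int(input_shape[1]), int(input_shape[2])
        for i in range(self.pool_len):
            self.window_h[i] = self.stride_h[i] = int(np.ceil(height / self.spatial_pool_size[i]))
            self.window_w[i] = self.stride_w[i] = int(np.ceil(width / self.spatial_pool_size[i]))

    def call(self, x):
        pooled = []
        for i in range(self.pool_len):
            max_pool = tf.nn.max_pool(x,
                                      ksize=[1, self.window_h[i], self.window_w[i], 1],
                                      strides=[1, self.stride_h[i], self.stride_w[i], 1],
                                      padding='SAME')
            pooled.append(self.flatten(max_pool))
        return tf.concat(pooled, axis=-1)

# Usage: for a 2 x 33 x 18 x 8 input, the three levels flatten to 3*3, 2*2 and
# 1*1 cells of 8 channels each, i.e. 112 features per sample.
x = tf.random.normal([2, 33, 18, 8])
print(SpatialPyramidPool()(x).shape)  # (2, 112)

Computing the windows once in build() avoids recomputing shape-dependent values on every forward pass and keeps call() free of Python-side shape logic.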