diff --git a/micro_dl/cli/dataset_pooling.py b/micro_dl/cli/dataset_pooling.py index cfc8429d..2af07b83 100644 --- a/micro_dl/cli/dataset_pooling.py +++ b/micro_dl/cli/dataset_pooling.py @@ -38,7 +38,7 @@ def pool_dataset(config): num_workers = pool_config['num_workers'] pool_mode = pool_config['pool_mode'] frames_meta_dst_path = os.path.join(dst_dir, 'frames_meta.csv') - ints_meta_dst_path = os.path.join(dst_dir, 'ints_meta.csv') + ints_meta_dst_path = os.path.join(dst_dir, 'intensity_meta.csv') pos_idx_cur = 0 os.makedirs(dst_dir, exist_ok=True) if os.path.exists(frames_meta_dst_path) and pool_mode == 'add': diff --git a/micro_dl/cli/generate_meta.py b/micro_dl/cli/generate_meta.py index 6bff1b38..cd5ce4ff 100644 --- a/micro_dl/cli/generate_meta.py +++ b/micro_dl/cli/generate_meta.py @@ -36,6 +36,12 @@ def parse_args(): default=4, help="number of workers for multiprocessing", ) + parser.add_argument( + '--block_size', + type=int, + default=256, + help="Pixel block size for intensity sampling", + ) parser.add_argument( '--normalize_im', type=str, @@ -46,19 +52,40 @@ def parse_args(): def main(parsed_args): - meta_utils.frames_meta_generator(parsed_args.input, - parsed_args.order, - parsed_args.name_parser, - ) + """ + Generate metadata for each file by interpreting the file name. + Writes found data in frames_metadata.csv in input directory. 
+ Assumed default file naming convention is: + dir_name + | + |- im_c***_z***_t***_p***.png + |- im_c***_z***_t***_p***.png + + c is channel + z is slice in stack (z) + t is time + p is position (FOV) + + Other naming convention is: + img_channelname_t***_p***_z***.tif for parse_sms_name + + :param argparse parsed_args: Input arguments + """ + # Collect metadata for all image files + meta_utils.frames_meta_generator( + input_dir=parsed_args.input, + order=parsed_args.order, + name_parser=parsed_args.name_parser, + ) + # Compute intensity stats for all images if parsed_args.normalize_im in ['dataset', 'volume', 'slice']: - meta_utils.ints_meta_generator(parsed_args.input, - parsed_args.order, - parsed_args.name_parser, - parsed_args.num_workers, - ) + meta_utils.ints_meta_generator( + input_dir=parsed_args.input, + num_workers=parsed_args.num_workers, + block_size=parsed_args.block_size, + ) if __name__ == '__main__': parsed_args = parse_args() main(parsed_args) - diff --git a/micro_dl/cli/preprocess_script.py b/micro_dl/cli/preprocess_script.py index 9acf2e33..4666413d 100644 --- a/micro_dl/cli/preprocess_script.py +++ b/micro_dl/cli/preprocess_script.py @@ -15,7 +15,6 @@ ImageTilerNonUniform import micro_dl.utils.aux_utils as aux_utils import micro_dl.utils.meta_utils as meta_utils -import micro_dl.utils.preprocess_utils as preprocess_utils def parse_args(): @@ -24,7 +23,6 @@ def parse_args(): In python namespaces are implemented as dictionaries :return: namespace containing the arguments passed. """ - parser = argparse.ArgumentParser() parser.add_argument( '--config', @@ -35,20 +33,101 @@ def parse_args(): return args -def flat_field_correct(params_dict, block_size): +def get_required_params(preprocess_config): + """ + Create a dictionary with required parameters for preprocessing + from the preprocessing config. 
Required parameters are: + 'input_dir': Directory containing input image data + 'output_dir': Directory to write preprocessed data + 'slice_ids': Slice indices + 'time_ids': Time indices + 'pos_ids': Position indices + 'channel_ids': Channel indices + 'uniform_struct': (bool) If images are uniform + 'int2strlen': (int) How long of a string to convert integers to + 'normalize_channels': (list) Containing bools the length of channels + 'num_workers': Number of workers for multiprocessing + 'normalize_im': (str) Normalization scheme + (stack, dataset, slice, volume) + + :param dict preprocess_config: Preprocessing config + :return dict required_params: Required parameters + """ + input_dir = preprocess_config['input_dir'] + output_dir = preprocess_config['output_dir'] + slice_ids = -1 + if 'slice_ids' in preprocess_config: + slice_ids = preprocess_config['slice_ids'] + + time_ids = -1 + if 'time_ids' in preprocess_config: + time_ids = preprocess_config['time_ids'] + + pos_ids = -1 + if 'pos_ids' in preprocess_config: + pos_ids = preprocess_config['pos_ids'] + + channel_ids = -1 + if 'channel_ids' in preprocess_config: + channel_ids = preprocess_config['channel_ids'] + + uniform_struct = True + if 'uniform_struct' in preprocess_config: + uniform_struct = preprocess_config['uniform_struct'] + + int2str_len = 3 + if 'int2str_len' in preprocess_config: + int2str_len = preprocess_config['int2str_len'] + + num_workers = 4 + if 'num_workers' in preprocess_config: + num_workers = preprocess_config['num_workers'] + + normalize_im = 'stack' + normalize_channels = -1 + if 'normalize' in preprocess_config: + if 'normalize_im' in preprocess_config['normalize']: + normalize_im = preprocess_config['normalize']['normalize_im'] + if 'normalize_channels' in preprocess_config['normalize']: + normalize_channels = preprocess_config['normalize']['normalize_channels'] + if isinstance(channel_ids, list): + assert len(channel_ids) == len(normalize_channels), \ + "Nbr channels {} and 
normalization {} mismatch".format( + channel_ids, + normalize_channels, + ) + + required_params = { + 'input_dir': input_dir, + 'output_dir': output_dir, + 'slice_ids': slice_ids, + 'time_ids': time_ids, + 'pos_ids': pos_ids, + 'channel_ids': channel_ids, + 'uniform_struct': uniform_struct, + 'int2strlen': int2str_len, + 'normalize_channels': normalize_channels, + 'num_workers': num_workers, + 'normalize_im': normalize_im, + } + return required_params + + +def flat_field_correct(required_params, block_size): """Estimate flat_field_images - :param dict params_dict: dict with keys: input_dir, output_dir, time_ids, + :param dict required_params: dict with keys: input_dir, output_dir, time_ids, channel_ids, pos_ids, slice_ids, int2strlen, uniform_struct, num_workers + :param int block_size: Specify block size if different from default (32 pixels) :return str flat_field_dir: full path of dir with flat field correction images """ flat_field_inst = FlatFieldEstimator2D( - input_dir=params_dict['input_dir'], - output_dir=params_dict['output_dir'], - channel_ids=params_dict['channel_ids'], - slice_ids=params_dict['slice_ids'], + input_dir=required_params['input_dir'], + output_dir=required_params['output_dir'], + channel_ids=required_params['channel_ids'], + slice_ids=required_params['slice_ids'], block_size=block_size, ) flat_field_inst.estimate_flat_field() @@ -56,14 +135,14 @@ def flat_field_correct(params_dict, block_size): return flat_field_dir -def resize_images(params_dict, +def resize_images(required_params, scale_factor, num_slices_subvolume, resize_3d, flat_field_dir): """Resample images first - :param dict params_dict: dict with keys: input_dir, output_dir, time_ids, + :param dict required_params: dict with keys: input_dir, output_dir, time_ids, channel_ids, pos_ids, slice_ids, int2strlen, uniform_struct, num_workers :param int/list scale_factor: scale factor for each dimension :param int num_slices_subvolume: num of slices to be included in each @@ -82,18 
+161,18 @@ def resize_images(params_dict, scale_factor = np.array(scale_factor) if np.all(scale_factor == 1): - return params_dict['input_dir'], params_dict['slice_ids'] + return required_params['input_dir'], required_params['slice_ids'] resize_inst = ImageResizer( - input_dir=params_dict['input_dir'], - output_dir=params_dict['output_dir'], + input_dir=required_params['input_dir'], + output_dir=required_params['output_dir'], scale_factor=scale_factor, - channel_ids=params_dict['channel_ids'], - time_ids=params_dict['time_ids'], - slice_ids=params_dict['slice_ids'], - pos_ids=params_dict['pos_ids'], - int2str_len=params_dict['int2strlen'], - num_workers=params_dict['num_workers'], + channel_ids=required_params['channel_ids'], + time_ids=required_params['time_ids'], + slice_ids=required_params['slice_ids'], + pos_ids=required_params['pos_ids'], + int2str_len=required_params['int2strlen'], + num_workers=required_params['num_workers'], flat_field_dir=flat_field_dir ) @@ -102,12 +181,12 @@ def resize_images(params_dict, slice_ids = resize_inst.resize_volumes(num_slices_subvolume) else: resize_inst.resize_frames() - slice_ids = params_dict['slice_ids'] + slice_ids = required_params['slice_ids'] resize_dir = resize_inst.get_resize_dir() return resize_dir, slice_ids -def generate_masks(params_dict, +def generate_masks(required_params, mask_from_channel, flat_field_dir, str_elem_radius, @@ -116,9 +195,10 @@ def generate_masks(params_dict, mask_ext, mask_dir=None, ): - """Generate masks per image or volume + """ + Generate masks per image or volume - :param dict params_dict: dict with keys: input_dir, output_dir, time_ids, + :param dict required_params: dict with keys: input_dir, output_dir, time_ids, channel_ids, pos_ids, slice_ids, int2strlen, uniform_struct, num_workers :param int/list mask_from_channel: generate masks from sum of these channels @@ -141,21 +221,21 @@ def generate_masks(params_dict, ", not {}".format(mask_type) # If generating weights map, input dir is 
the mask dir - input_dir = params_dict['input_dir'] + input_dir = required_params['input_dir'] if mask_dir is not None: input_dir = mask_dir # Instantiate channel to mask processor mask_processor_inst = MaskProcessor( input_dir=input_dir, - output_dir=params_dict['output_dir'], + output_dir=required_params['output_dir'], channel_ids=mask_from_channel, flat_field_dir=flat_field_dir, - time_ids=params_dict['time_ids'], - slice_ids=params_dict['slice_ids'], - pos_ids=params_dict['pos_ids'], - int2str_len=params_dict['int2strlen'], - uniform_struct=params_dict['uniform_struct'], - num_workers=params_dict['num_workers'], + time_ids=required_params['time_ids'], + slice_ids=required_params['slice_ids'], + pos_ids=required_params['pos_ids'], + int2str_len=required_params['int2strlen'], + uniform_struct=required_params['uniform_struct'], + num_workers=required_params['num_workers'], mask_type=mask_type, mask_channel=mask_channel, mask_ext=mask_ext, @@ -173,36 +253,44 @@ def generate_masks(params_dict, return mask_dir, mask_channel -def generate_zscore_table(params_dict, +def generate_zscore_table(required_params, norm_dict, mask_dir): """ Compute z-score parameters and update frames_metadata based on the normalize_im - :param params_dict: - :param mask_dir: - :return: + :param dict required_params: Required preprocessing parameters + :param dict norm_dict: Normalization scheme (preprocess_config['normalization']) + :param str mask_dir: Directory containing masks """ - frames_metadata = aux_utils.read_meta(params_dict['input_dir']) - ints_metadata = aux_utils.read_meta(params_dict['input_dir'], - meta_fname='ints_meta.csv') + assert 'min_fraction' in norm_dict, \ + "normalization part of config must contain min_fraction" + frames_metadata = aux_utils.read_meta(required_params['input_dir']) + ints_metadata = aux_utils.read_meta( + required_params['input_dir'], + meta_fname='intensity_meta.csv', + ) mask_metadata = aux_utils.read_meta(mask_dir) cols_to_merge = 
ints_metadata.columns[ints_metadata.columns != 'fg_frac'] - ints_metadata = \ - pd.merge(ints_metadata[cols_to_merge], - mask_metadata[['pos_idx', 'time_idx', 'slice_idx', 'fg_frac']], - how='left', on=['pos_idx', 'time_idx', 'slice_idx']) - _, ints_metadata = \ - meta_utils.compute_zscore_params(frames_metadata, - ints_metadata, - params_dict['input_dir'], - normalize_im=params_dict['normalize_im'], - min_fraction=norm_dict['min_fraction'] - ) - ints_metadata.to_csv(os.path.join(params_dict['input_dir'], 'ints_meta.csv'), - sep=',') - - -def tile_images(params_dict, + ints_metadata = pd.merge( + ints_metadata[cols_to_merge], + mask_metadata[['pos_idx', 'time_idx', 'slice_idx', 'fg_frac']], + how='left', + on=['pos_idx', 'time_idx', 'slice_idx'], + ) + _, ints_metadata = meta_utils.compute_zscore_params( + frames_meta=frames_metadata, + ints_meta=ints_metadata, + input_dir=required_params['input_dir'], + normalize_im=required_params['normalize_im'], + min_fraction=norm_dict['min_fraction'], + ) + ints_metadata.to_csv( + os.path.join(required_params['input_dir'], 'intensity_meta.csv'), + sep=',', + ) + + +def tile_images(required_params, tile_dict, resize_flag, flat_field_dir, @@ -211,7 +299,7 @@ def tile_images(params_dict, """ Tile images. - :param dict params_dict: dict with keys: input_dir, output_dir, time_ids, + :param dict required_params: dict with keys: input_dir, output_dir, time_ids, channel_ids, pos_ids, slice_ids, int2strlen, uniform_struct, num_workers :param dict tile_dict: dict with tiling related keys: tile_size, step_size, image_format, depths, min_fraction. 
Optional: mask_channel, mask_dir, @@ -236,27 +324,27 @@ def tile_images(params_dict, if 'min_fraction' in tile_dict: min_fraction = tile_dict['min_fraction'] # setup tiling keyword arguments - kwargs = {'input_dir': params_dict['input_dir'], - 'output_dir': params_dict['output_dir'], - 'normalize_channels': params_dict["normalize_channels"], + kwargs = {'input_dir': required_params['input_dir'], + 'output_dir': required_params['output_dir'], + 'normalize_channels': required_params["normalize_channels"], 'tile_size': tile_dict['tile_size'], 'step_size': tile_dict['step_size'], 'depths': tile_dict['depths'], - 'time_ids': params_dict['time_ids'], - 'channel_ids': params_dict['channel_ids'], - 'slice_ids': params_dict['slice_ids'], - 'pos_ids': params_dict['pos_ids'], + 'time_ids': required_params['time_ids'], + 'channel_ids': required_params['channel_ids'], + 'slice_ids': required_params['slice_ids'], + 'pos_ids': required_params['pos_ids'], 'hist_clip_limits': hist_clip_limits, 'flat_field_dir': flat_field_dir, - 'num_workers': params_dict['num_workers'], + 'num_workers': required_params['num_workers'], 'tile_3d': tile_3d, - 'int2str_len': params_dict['int2strlen'], + 'int2str_len': required_params['int2strlen'], 'min_fraction': min_fraction, - 'normalize_im': params_dict['normalize_im'], + 'normalize_im': required_params['normalize_im'], 'tiles_exist': tiles_exist, } - if params_dict['uniform_struct']: + if required_params['uniform_struct']: if tile_3d: if resize_flag: warnings.warn( @@ -293,6 +381,35 @@ def tile_images(params_dict, return tile_dir +def save_config(cur_config, runtime): + """ + Save the current config (cur_config) or append to existing config. 
+ + :param dict cur_config: Current config + :param float runtime: Run time for preprocessing + """ + + # Read preprocessing.json if exists in input dir + parent_dir = cur_config['input_dir'].split(os.sep)[:-1] + parent_dir = os.sep.join(parent_dir) + + prior_config_fname = os.path.join(parent_dir, 'preprocessing_info.json') + prior_preprocess_config = None + if os.path.exists(prior_config_fname): + prior_preprocess_config = aux_utils.read_json(prior_config_fname) + + meta_path = os.path.join(cur_config['output_dir'], + 'preprocessing_info.json') + + processing_info = [{'processing_time': runtime, + 'config': cur_config}] + if prior_preprocess_config is not None: + prior_preprocess_config.append(processing_info[0]) + processing_info = prior_preprocess_config + os.makedirs(cur_config['output_dir'], exist_ok=True) + aux_utils.write_json(processing_info, meta_path) + + def pre_process(preprocess_config): """ Preprocess data. Possible options are: @@ -311,68 +428,32 @@ def pre_process(preprocess_config): :param dict preprocess_config: dict with key options: [input_dir, output_dir, slice_ids, time_ids, pos_ids correct_flat_field, use_masks, masks, tile_stack, tile] - :param dict req_params_dict: dict with commom params for all tasks + :param dict required_params: dict with commom params for all tasks :raises AssertionError: If 'masks' in preprocess_config contains both channels and mask_dir (the former is for generating masks from a channel) """ time_start = time.time() - input_dir = preprocess_config['input_dir'] - output_dir = preprocess_config['output_dir'] - slice_ids = -1 - if 'slice_ids' in preprocess_config: - slice_ids = preprocess_config['slice_ids'] - - time_ids = -1 - if 'time_ids' in preprocess_config: - time_ids = preprocess_config['time_ids'] - - pos_ids = -1 - if 'pos_ids' in preprocess_config: - pos_ids = preprocess_config['pos_ids'] - - channel_ids = -1 - if 'channel_ids' in preprocess_config: - channel_ids = preprocess_config['channel_ids'] - - 
uniform_struct = True - if 'uniform_struct' in preprocess_config: - uniform_struct = preprocess_config['uniform_struct'] - - int2str_len = 3 - if 'int2str_len' in preprocess_config: - int2str_len = preprocess_config['int2str_len'] - - num_workers = 4 - if 'num_workers' in preprocess_config: - num_workers = preprocess_config['num_workers'] - - normalize_im = 'stack' - normalize_channels = -1 - if 'normalize' in preprocess_config: - if 'normalize_im' in preprocess_config['normalize']: - normalize_im = preprocess_config['normalize']['normalize_im'] - if 'normalize_channels' in preprocess_config['normalize']: - normalize_channels = preprocess_config['normalize']['normalize_channels'] - if isinstance(channel_ids, list): - assert len(channel_ids) == len(normalize_channels), \ - "Nbr channels {} and normalization {} mismatch".format( - channel_ids, - normalize_channels, - ) - - req_params_dict = { - 'input_dir': input_dir, - 'output_dir': output_dir, - 'slice_ids': slice_ids, - 'time_ids': time_ids, - 'pos_ids': pos_ids, - 'channel_ids': channel_ids, - 'uniform_struct': uniform_struct, - 'int2strlen': int2str_len, - 'normalize_channels': normalize_channels, - 'num_workers': num_workers, - 'normalize_im': normalize_im, - } + required_params = get_required_params(preprocess_config) + + # ------------------Check or create metadata--------------------- + try: + # Check if metadata is present + aux_utils.read_meta(required_params['input_dir']) + except AssertionError as e: + print(e, "Generating metadata.") + order = 'cztp' + name_parser = 'parse_sms_name' + if 'metadata' in preprocess_config: + if 'order' in preprocess_config['metadata']: + order = preprocess_config['metadata']['order'] + if 'name_parser' in preprocess_config['metadata']: + name_parser = preprocess_config['metadata']['name_parser'] + # Create metadata from file names instead + meta_utils.frames_meta_generator( + input_dir=required_params['input_dir'], + order=order, + name_parser=name_parser, + ) # 
-----------------Estimate flat field images-------------------- flat_field_dir = None @@ -384,13 +465,26 @@ def pre_process(preprocess_config): block_size = None if 'block_size' in preprocess_config['flat_field']: block_size = preprocess_config['flat_field']['block_size'] - flat_field_dir = flat_field_correct(req_params_dict, block_size) + flat_field_dir = flat_field_correct(required_params, block_size) preprocess_config['flat_field']['flat_field_dir'] = flat_field_dir elif 'correct' in preprocess_config['flat_field'] and \ preprocess_config['flat_field']['correct']: flat_field_dir = preprocess_config['flat_field']['flat_field_dir'] + # -------Compute intensities for flatfield corrected images------- + if required_params['normalize_im'] in ['dataset', 'volume', 'slice']: + block_size = None + if 'metadata' in preprocess_config and 'block_size' in preprocess_config['metadata']: + block_size = preprocess_config['metadata']['block_size'] + meta_utils.ints_meta_generator( + input_dir=required_params['input_dir'], + num_workers=required_params['num_workers'], + block_size=block_size, + flat_field_dir=flat_field_dir, + channel_ids=required_params['channel_ids'], + ) + # -------------------------Resize images-------------------------- if 'resize' in preprocess_config: scale_factor = preprocess_config['resize']['scale_factor'] @@ -400,7 +494,7 @@ preprocess_config['resize']['num_slices_subvolume'] resize_dir, slice_ids = resize_images( - req_params_dict, + required_params, scale_factor, num_slices_subvolume, preprocess_config['resize']['resize_3d'], @@ -409,12 +503,8 @@ # the images are resized after flat field correction flat_field_dir = None preprocess_config['resize']['resize_dir'] = resize_dir - init_frames_meta = pd.read_csv( - os.path.join(req_params_dict['input_dir'], 'frames_meta.csv') - ) - mask_out_channel = int(init_frames_meta['channel_idx'].max() + 1) - req_params_dict['input_dir'] = resize_dir - req_params_dict['slice_ids'] 
= slice_ids + required_params['input_dir'] = resize_dir + required_params['slice_ids'] = slice_ids # ------------------------Generate masks------------------------- mask_dir = None @@ -436,7 +526,7 @@ def pre_process(preprocess_config): mask_ext = preprocess_config['masks']['mask_ext'] mask_dir, mask_channel = generate_masks( - params_dict=req_params_dict, + required_params=required_params, mask_from_channel=mask_from_channel, flat_field_dir=flat_field_dir, str_elem_radius=str_elem_radius, @@ -449,24 +539,19 @@ def pre_process(preprocess_config): "Don't specify channels to mask if using pre-generated masks" mask_dir = preprocess_config['masks']['mask_dir'] # Get preexisting masks from directory and match to input dir - mask_meta_fname = None - if 'csv_name' in preprocess_config['masks']: - mask_meta_fname = preprocess_config['masks']['csv_name'] mask_meta = meta_utils.mask_meta_generator( mask_dir, - name_parser='parse_sms_name', ) - frames_meta = aux_utils.read_meta(req_params_dict['input_dir']) + frames_meta = aux_utils.read_meta(required_params['input_dir']) # Automatically assign existing masks the next available channel number - mask_meta['channel_idx'] += (frames_meta['channel_idx'].max() + 1) - # use the first mask channel as the default mask for tiling - mask_channel = int(mask_meta['channel_idx'].unique()[0]) + mask_channel = (frames_meta['channel_idx'].max() + 1) + mask_meta['channel_idx'] = mask_channel # Write metadata mask_meta_fname = os.path.join(mask_dir, 'frames_meta.csv') mask_meta.to_csv(mask_meta_fname, sep=",") # mask_channel = preprocess_utils.validate_mask_meta( # mask_dir=mask_dir, - # input_dir=req_params_dict['input_dir'], + # input_dir=required_params['input_dir'], # csv_name=mask_meta_fname, # mask_channel=mask_channel, # ) @@ -476,12 +561,19 @@ def pre_process(preprocess_config): preprocess_config['masks']['mask_dir'] = mask_dir preprocess_config['masks']['mask_channel'] = mask_channel - if req_params_dict['normalize_im'] in ['dataset', 
'volume', 'slice']: + # ---------------------Generate z score table--------------------- + if required_params['normalize_im'] in ['dataset', 'volume', 'slice']: assert mask_dir is not None, \ "'dataset', 'volume', 'slice' normalization requires masks" - generate_zscore_table(req_params_dict, preprocess_config['normalize'], mask_dir) + generate_zscore_table( + required_params, + preprocess_config['normalize'], + mask_dir, + ) # ----------------------Generate weight map----------------------- + weights_dir = None + weights_channel = None if 'make_weight_map' in preprocess_config and preprocess_config['make_weight_map']: # Must have mask dir and mask channel defined to generate weight map assert mask_dir is not None,\ @@ -493,7 +585,7 @@ def pre_process(preprocess_config): weights_channel = mask_channel + 1 # Generate weights weights_dir, _ = generate_masks( - params_dict=req_params_dict, + required_params=required_params, mask_from_channel=mask_channel, flat_field_dir=None, str_elem_radius=5, @@ -519,25 +611,25 @@ def pre_process(preprocess_config): if 'mask_channel' not in preprocess_config['tile']: preprocess_config['tile']['mask_channel'] = mask_channel tile_dir = tile_images( - params_dict=req_params_dict, + required_params=required_params, tile_dict=preprocess_config['tile'], resize_flag=resize_flag, flat_field_dir=flat_field_dir, ) # Tile weight maps as well if they exist if 'weights' in preprocess_config: - weight_params_dict = req_params_dict.copy() - weight_params_dict["input_dir"] = weights_dir - weight_params_dict["channel_ids"] = [weights_channel] + weight_params = required_params.copy() + weight_params["input_dir"] = weights_dir + weight_params["channel_ids"] = [weights_channel] weight_tile_config = preprocess_config['tile'].copy() - weight_params_dict['normalize_channels'] = [False] + weight_params['normalize_channels'] = [False] # Weights depth should be the same as mask depth weight_tile_config['depths'] = 1 weight_tile_config.pop('mask_dir') if 
'mask_depth' in preprocess_config['tile']: weight_tile_config['depths'] = [preprocess_config['tile']['mask_depth']] tile_dir = tile_images( - params_dict=weight_params_dict, + required_params=weight_params, tile_dict=weight_tile_config, resize_flag=resize_flag, flat_field_dir=None, @@ -550,30 +642,6 @@ def pre_process(preprocess_config): return preprocess_config, time_el -def save_config(cur_config, runtime): - """Save the cur_config or append to existing config""" - - # Read preprocessing.json if exists in input dir - parent_dir = cur_config['input_dir'].split(os.sep)[:-1] - parent_dir = os.sep.join(parent_dir) - - prior_config_fname = os.path.join(parent_dir, 'preprocessing_info.json') - prior_preprocess_config = None - if os.path.exists(prior_config_fname): - prior_preprocess_config = aux_utils.read_json(prior_config_fname) - - meta_path = os.path.join(cur_config['output_dir'], - 'preprocessing_info.json') - - processing_info = [{'processing_time': runtime, - 'config': cur_config}] - if prior_preprocess_config is not None: - prior_preprocess_config.append(processing_info[0]) - processing_info = prior_preprocess_config - os.makedirs(cur_config['output_dir'], exist_ok=True) - aux_utils.write_json(processing_info, meta_path) - - if __name__ == '__main__': args = parse_args() preprocess_config = aux_utils.read_config(args.config) diff --git a/micro_dl/config_preprocess.yml b/micro_dl/config_preprocess.yml index a70531f9..53ec4bb3 100644 --- a/micro_dl/config_preprocess.yml +++ b/micro_dl/config_preprocess.yml @@ -21,10 +21,12 @@ masks: mask_type: 'unimodal' mask_ext: '.png' make_weight_map: False - tile: tile_size: [256, 256] step_size: [128, 128] depths: [1, 1, 1, 1] image_format: 'zyx' min_fraction: 0.25 +metadata: + order: 'cztp' + name_parser: 'parse_sms_name' diff --git a/micro_dl/deprecated/gen_mask_seg.py b/micro_dl/deprecated/gen_mask_seg.py index 8fc37239..076a6b72 100644 --- a/micro_dl/deprecated/gen_mask_seg.py +++ b/micro_dl/deprecated/gen_mask_seg.py @@ 
-138,7 +138,8 @@ def create_masks_for_stack(self, str_elem_radius=3): if self.correct_flat_field: cur_image = image_utils.apply_flat_field_correction( - cur_image, flat_field_image=cur_flat_field + cur_image, + flat_field_image=cur_flat_field, ) mask = micro_dl.utils.masks.create_otsu_mask( cur_image, str_elem_size=str_elem_radius diff --git a/micro_dl/deprecated/model_inference.py b/micro_dl/deprecated/model_inference.py index 80424cfc..c04fbee2 100644 --- a/micro_dl/deprecated/model_inference.py +++ b/micro_dl/deprecated/model_inference.py @@ -104,11 +104,12 @@ def _read_one(tp_dir, channel_ids, fname, flat_field_dir=None): fname) cur_image = np.load(cur_fname) if flat_field_dir is not None: - ff_fname = os.path.join(flat_field_dir, - 'flat-field_channel-{}.npy'.format(ch)) - ff_image = np.load(ff_fname) + ff_path = os.path.join( + flat_field_dir, + 'flat-field_channel-{}.npy'.format(ch), + ) cur_image = image_utils.apply_flat_field_correction( - cur_image, flat_field_image=ff_image) + cur_image, flat_field_path=ff_path) cur_image = zscore(cur_image) cur_images.append(cur_image) cur_images = np.stack(cur_images) diff --git a/micro_dl/inference/image_inference.py b/micro_dl/inference/image_inference.py index eb537b5a..ca2718bd 100644 --- a/micro_dl/inference/image_inference.py +++ b/micro_dl/inference/image_inference.py @@ -343,7 +343,6 @@ def _assign_3d_inference(self): if isinstance(num_overlap, list) and \ self.config['network']['class'] != 'UNet3D': num_overlap = self.num_overlap[-1] - overlap_dict = { 'overlap_shape': num_overlap, 'overlap_operation': self.tile_params['overlap_operation'] @@ -396,6 +395,7 @@ def _predict_sub_block_z(self, input_image): num_overlap = self.num_overlap if isinstance(self.num_overlap, list): num_overlap = self.num_overlap[-1] + num_blocks = np.ceil( num_z / (num_slices - num_overlap) ).astype('int') @@ -499,15 +499,14 @@ def unzscore(self, im_target, meta_row): """ - Revert z-score normalization applied during preprocessing. 
Necessary - before computing SSIM - - :param im_pred: Prediction image, normalized image for un-zscore - :param im_target: Target image to compute stats from - :param pd.DataFrame meta_row: Metadata row for image - :return im_pred: image at its original scale - """ + Revert z-score normalization applied during preprocessing. Necessary + before computing SSIM + :param im_pred: Prediction image, normalized image for un-zscore + :param im_target: Target image to compute stats from + :param pd.DataFrame meta_row: Metadata row for image + :return im_pred: image at its original scale + """ if self.normalize_im is not None: if self.normalize_im in ['dataset', 'volume', 'slice'] \ and ('zscore_median' in meta_row and @@ -526,7 +525,7 @@ def save_pred_image(self, im_pred, metric, meta_row, - pred_chan_name=None, + pred_chan_name=np.nan, ): """ Save predicted images with image extension given in init. @@ -536,9 +535,13 @@ def save_pred_image(self, :param np.array im_pred: 2D / 3D predicted image :param pd.series metric: xy similarity metrics between prediction and target :param pd.DataFrame meta_row: Row of meta dataframe containing sample - :param str/None pred_chan_name: Predicted channel name + :param str/NaN pred_chan_name: Predicted channel name """ - if pred_chan_name is None: + if pd.isnull(pred_chan_name): + if 'channel_name' in meta_row: + pred_chan_name = meta_row['channel_name'] + + if pd.isnull(pred_chan_name): im_name = aux_utils.get_im_name( time_idx=meta_row['time_idx'], channel_idx=meta_row['channel_idx'], @@ -577,7 +580,7 @@ def save_pred_image(self, raise ValueError( 'Unsupported file extension: {}'.format(self.image_ext), ) - if self.save_figs and self.image_ext != '.npy': + if self.save_figs and len(im_target.shape) == 2: # save predicted images assumes 2D fig_dir = os.path.join(self.pred_dir, 'figures') os.makedirs(self.pred_dir, exist_ok=True) @@ -692,7 +695,7 @@ def predict_2d(self, chan_slice_meta): """ Run prediction on 2D or 2.5D on indices given by 
metadata row. - :param list chan_slice_meta: Inference meta rows + :param pd.DataFrame chan_slice_meta: Inference meta rows :return np.array pred_stack: Prediction :return np.array target_stack: Target :return np.array/list mask_stack: Mask for metrics (empty list if @@ -859,6 +862,7 @@ def predict_3d(self, iteration_rows): pred_image = pred_image.astype(np.float32) target_image = target_image.astype(np.float32) cur_row = self.inf_frames_meta.iloc[iteration_rows.index[0]] + if self.model_task == 'regression': pred_image = self.unzscore( pred_image, @@ -869,10 +873,12 @@ def predict_3d(self, iteration_rows): pred_image = pred_image[0, ...] target_image = target_image[0, ...] input_image = input_image[0, ...] + if self.image_format == 'zyx': input_image = np.transpose(input_image, [0, 2, 3, 1]) pred_image = np.transpose(pred_image, [0, 2, 3, 1]) target_image = np.transpose(target_image, [0, 2, 3, 1]) + # get mask mask_image = None if self.masks_dict is not None: @@ -907,7 +913,6 @@ def run_prediction(self): pred_image, target_image, mask_image, input_image = self.predict_2d( chan_slice_meta, ) - for c, chan_idx in enumerate(self.target_channels): pred_names = [] slice_ids = chan_slice_meta.loc[chan_slice_meta['channel_idx'] == chan_idx, 'slice_idx'].to_list() diff --git a/micro_dl/plotting/plot_utils.py b/micro_dl/plotting/plot_utils.py index e480e099..a690d309 100644 --- a/micro_dl/plotting/plot_utils.py +++ b/micro_dl/plotting/plot_utils.py @@ -47,7 +47,6 @@ def save_predicted_images(input_imgs, axs.axis('off') fig.set_size_inches((12, 5 * n_rows)) axis_count = 0 - # add input images to plot for c, input_img in enumerate(input_imgs): input_imgs[c] = hist_clipping( diff --git a/micro_dl/preprocessing/generate_masks.py b/micro_dl/preprocessing/generate_masks.py index 41f24cbc..906dd489 100644 --- a/micro_dl/preprocessing/generate_masks.py +++ b/micro_dl/preprocessing/generate_masks.py @@ -105,7 +105,7 @@ def __init__(self, self.ints_metadata = None 
self.channel_thr_df = None if mask_type == 'dataset otsu': -        self.ints_metadata = aux_utils.read_meta(self.input_dir, 'ints_meta.csv') +        self.ints_metadata = aux_utils.read_meta(self.input_dir, 'intensity_meta.csv') self.channel_thr_df = self.get_channel_thr_df() # for channel_idx in channel_ids: #     row_idxs = self.ints_metadata['channel_idx'] == channel_idx diff --git a/micro_dl/utils/aux_utils.py b/micro_dl/utils/aux_utils.py index a0ee9233..9882ed42 100644 --- a/micro_dl/utils/aux_utils.py +++ b/micro_dl/utils/aux_utils.py @@ -173,7 +173,7 @@ def get_im_name(time_idx=None, def get_sms_im_name(time_idx=None, -                    channel_name=None, +                    channel_name=np.nan, slice_idx=None, pos_idx=None, extra_field=None, @@ -188,7 +188,7 @@ This function will alter list and dict in place. :param int time_idx: Time index - :param str/None channel_name: Channel name + :param str/NaN channel_name: Channel name :param int slice_idx: Slice (z) index :param int pos_idx: Position (FOV) index :param str extra_field: Any extra string you want to include in the name @@ -198,7 +198,7 @@ im_name = "img" - if channel_name is not None: + if not pd.isnull(channel_name): im_name += "_" + str(channel_name) if time_idx is not None: im_name += "_t" + str(time_idx).zfill(int2str_len) @@ -389,7 +389,7 @@ def make_dataframe(nbr_rows=None, df_names=DF_NAMES): def read_meta(input_dir, meta_fname='frames_meta.csv'): """ Read metadata file, which is assumed to be named 'frames_meta.csv' - in given directory + in given directory.
:param str input_dir: Directory containing data and metadata :param str meta_fname: Metadata file name @@ -403,8 +403,7 @@ def read_meta(input_dir, meta_fname='frames_meta.csv'): frames_metadata = pd.read_csv(meta_fname[0], index_col=0) except IOError as e: raise IOError('cannot read metadata csv file: {}'.format(e)) - # Replace NaNs with None - frames_metadata = frames_metadata.mask(frames_metadata.isna(), None) + return frames_metadata @@ -558,7 +557,7 @@ def parse_idx_from_name(im_name, df_names=DF_NAMES, order="cztp"): "Order needs 4 unique values, not {}".format(order) meta_row = dict.fromkeys(df_names) # Channel name can't be retrieved from image name - meta_row["channel_name"] = None + meta_row["channel_name"] = np.nan meta_row["file_name"] = im_name # Find all integers in name string ints = re.findall(r'\d+', im_name) diff --git a/micro_dl/utils/image_utils.py b/micro_dl/utils/image_utils.py index 10830c9b..e21db3b1 100644 --- a/micro_dl/utils/image_utils.py +++ b/micro_dl/utils/image_utils.py @@ -141,27 +141,20 @@ def apply_flat_field_correction(input_image, **kwargs): """Apply flat field correction. 
:param np.array input_image: image to be corrected - Kwargs: + Kwargs, either: flat_field_image (np.float): flat_field_image for correction - flat_field_dir (str): dir with split images from stack (or individual - sample images - channel_idx (int): input image channel index + flat_field_path (str): Full path to flatfield image :return: np.array (float) corrected image """ - input_image = input_image.astype('float') if 'flat_field_image' in kwargs: corrected_image = input_image / kwargs['flat_field_image'] - else: - msg = 'flat_field_dir and channel_id are required to fetch flat field image' - assert all(k in kwargs for k in ('flat_field_dir', 'channel_idx')), msg - flat_field_image = np.load( - os.path.join( - kwargs['flat_field_dir'], - 'flat-field_channel-{}.npy'.format(kwargs['channel_idx']), - ) - ) + elif 'flat_field_path' in kwargs: + flat_field_image = np.load(kwargs['flat_field_path']) corrected_image = input_image / flat_field_image + else: + print("Incorrect kwargs: {}, returning input image".format(kwargs)) + corrected_image = input_image.copy() return corrected_image @@ -389,7 +382,6 @@ def preprocess_imstack(frames_metadata, zscore_median = None zscore_iqr = None - # TODO: Are all the normalization schemes below the same now? if normalize_im in ['dataset', 'volume', 'slice']: if 'zscore_median' in frames_metadata: zscore_median = frames_metadata.loc[meta_idx, 'zscore_median'] diff --git a/micro_dl/utils/meta_utils.py b/micro_dl/utils/meta_utils.py index 71d43a55..43a25a0a 100644 --- a/micro_dl/utils/meta_utils.py +++ b/micro_dl/utils/meta_utils.py @@ -1,8 +1,9 @@ -import os -import pandas as pd +import itertools import micro_dl.utils.aux_utils as aux_utils import micro_dl.utils.mp_utils as mp_utils -import itertools +import os +import pandas as pd +import sys def frames_meta_generator( @@ -13,7 +14,7 @@ def frames_meta_generator( """ Generate metadata from file names for preprocessing. Will write found data in frames_metadata.csv in input directory. 
- Assumed default file naming convention is: + Assumed default file naming convention is for 'parse_idx_from_name': dir_name | |- im_c***_z***_t***_p***.png @@ -24,12 +25,12 @@ def frames_meta_generator( t is time p is position (FOV) - Other naming convention is: + Other naming convention for 'parse_sms_name': img_channelname_t***_p***_z***.tif for parse_sms_name - :param list args: parsed args containing - str input_dir: path to input directory containing images - str name_parser: Function in aux_utils for parsing indices from file name + :param str input_dir: path to input directory containing images + :param str order: Order in which file name encodes cztp + :param str name_parser: Function in aux_utils for parsing indices from file name """ parse_func = aux_utils.import_object('utils.aux_utils', name_parser, 'function') im_names = aux_utils.get_sorted_names(input_dir) @@ -50,12 +51,13 @@ def frames_meta_generator( frames_meta.to_csv(frames_meta_filename, sep=",") return frames_meta + def ints_meta_generator( input_dir, - order='cztp', - name_parser='parse_sms_name', num_workers=4, block_size=256, + flat_field_dir=None, + channel_ids=-1, ): """ Generate pixel intensity metadata for estimating image normalization @@ -78,42 +80,43 @@ def ints_meta_generator( Other naming convention is: img_channelname_t***_p***_z***.tif for parse_sms_name - :param list args: parsed args containing - str input_dir: path to input directory containing images - str name_parser: Function in aux_utils for parsing indices from file name - int num_workers: number of workers for multiprocessing - int block_size: block size for the grid sampling pattern. Default value works + :param str input_dir: path to input directory containing images + :param int num_workers: number of workers for multiprocessing + :param int block_size: block size for the grid sampling pattern. Default value works well for 2048 X 2048 images. 
+ :param str flat_field_dir: Directory containing flatfield images + :param list/int channel_ids: Channel indices to process """ -    parse_func = aux_utils.import_object('utils.aux_utils', name_parser, 'function') -    im_names = aux_utils.get_sorted_names(input_dir) -    channel_names = [] +    if block_size is None: +        block_size = 256 +    frames_metadata = aux_utils.read_meta(input_dir) +    if not isinstance(channel_ids, list): +        # Use all channels +        channel_ids = frames_metadata['channel_idx'].unique() mp_fn_args = [] -    # Fill dataframe with rows from image names -    for i in range(len(im_names)): -        kwargs = {"im_name": im_names[i]} -        if name_parser == 'parse_idx_from_name': -            kwargs["order"] = order -        elif name_parser == 'parse_sms_name': -            kwargs["channel_names"] = channel_names -        meta_row = parse_func(**kwargs) -        meta_row['dir_name'] = input_dir -        im_path = os.path.join(input_dir, im_names[i]) -        mp_fn_args.append((im_path, block_size, meta_row)) +    for i, meta_row in frames_metadata.iterrows(): +        im_path = os.path.join(input_dir, meta_row['file_name']) +        ff_path = None +        if flat_field_dir is not None: +            channel_idx = meta_row['channel_idx'] +            if channel_idx in channel_ids: +                ff_path = os.path.join( +                    flat_field_dir, +                    'flat-field_channel-{}.npy'.format(channel_idx) +                ) +        mp_fn_args.append((im_path, ff_path, block_size, meta_row)) im_ints_list = mp_utils.mp_sample_im_pixels(mp_fn_args, num_workers) im_ints_list = list(itertools.chain.from_iterable(im_ints_list)) ints_meta = pd.DataFrame.from_dict(im_ints_list) -    ints_meta_filename = os.path.join(input_dir, 'ints_meta.csv') +    ints_meta_filename = os.path.join(input_dir, 'intensity_meta.csv') ints_meta.to_csv(ints_meta_filename, sep=",") -    return ints_meta + def mask_meta_generator( input_dir, -    order='cztp', -    name_parser='parse_sms_name', num_workers=4, ): """ @@ -122,7 +125,7 @@ following a grid pattern defined by block_size to for efficient estimation of
median and interquatile range. Grid sampling is preferred over random sampling in the case due to the spatial correlation in images. - Will write found data in ints_meta.csv in input directory. + Will write found data in intensity_meta.csv in input directory. Assumed default file naming convention is: dir_name | @@ -137,55 +140,51 @@ def mask_meta_generator( Other naming convention is: img_channelname_t***_p***_z***.tif for parse_sms_name - :param list args: parsed args containing - str input_dir: path to input directory containing images - str name_parser: Function in aux_utils for parsing indices from file name - int num_workers: number of workers for multiprocessing - int block_size: block size for the grid sampling pattern. Default value works - well for 2048 X 2048 images. + :param str input_dir: path to input directory containing images + :param str order: Order in which file name encodes cztp + :param str name_parser: Function in aux_utils for parsing indices from file name + :param int num_workers: number of workers for multiprocessing + :return pd.DataFrame mask_meta: Metadata with mask info """ - parse_func = aux_utils.import_object('utils.aux_utils', name_parser, 'function') - im_names = aux_utils.get_sorted_names(input_dir) - channel_names = [] + frames_metadata = aux_utils.read_meta(input_dir) mp_fn_args = [] - # Fill dataframe with rows from image names - for i in range(len(im_names)): - kwargs = {"im_name": im_names[i]} - if name_parser == 'parse_idx_from_name': - kwargs["order"] = order - elif name_parser == 'parse_sms_name': - kwargs["channel_names"] = channel_names - meta_row = parse_func(**kwargs) + for i, meta_row in frames_metadata.iterrows(): meta_row['dir_name'] = input_dir - im_path = os.path.join(input_dir, im_names[i]) + im_path = os.path.join(input_dir, meta_row['file_name']) mp_fn_args.append((im_path, meta_row)) - meta_row_list = mp_utils.mp_wrapper(mp_utils.get_mask_meta_row, mp_fn_args, num_workers) + meta_row_list = 
mp_utils.mp_wrapper( + mp_utils.get_mask_meta_row, + mp_fn_args, + num_workers, + ) mask_meta = pd.DataFrame.from_dict(meta_row_list) mask_meta_filename = os.path.join(input_dir, 'mask_meta.csv') mask_meta.to_csv(mask_meta_filename, sep=",") return mask_meta + def compute_zscore_params(frames_meta, ints_meta, input_dir, normalize_im, min_fraction=0.99): - """Get zscore mean and standard deviation - - :param int time_idx: Time index - :param int channel_idx: Channel index - :param int slice_idx: Slice (z) index - :param int pos_idx: Position (FOV) index - :param int slice_ids: Index of which focal plane acquisition to - use (for 2D). - :param str mask_dir: Directory containing masks + """ + Get zscore median and interquartile range + + :param pd.DataFrame frames_meta: Dataframe containing all metadata + :param pd.DataFrame ints_meta: Metadata containing intensity statistics + each z-slice and foreground fraction for masks + :param str input_dir: Directory containing images :param None or str normalize_im: normalization scheme for input images - :param dataframe frames_meta: metadata contains mean and std info of each z-slice - :return float zscore_mean: mean for z-scoring the image - :return float zscore_std: std for z-scoring the image + :param float min_fraction: Minimum foreground fraction (in case of masks) + for computing intensity statistics. 
+ + :return pd.DataFrame frames_meta: Dataframe containing all metadata + :return pd.DataFrame ints_meta: Metadata containing intensity statistics + each z-slice """ assert normalize_im in [None, 'slice', 'volume', 'dataset'], \ @@ -194,9 +193,8 @@ def compute_zscore_params(frames_meta, if normalize_im is None: # No normalization frames_meta['zscore_median'] = 0 - frames_meta['zscore_median'] = 1 + frames_meta['zscore_iqr'] = 1 return frames_meta - elif normalize_im == 'dataset': agg_cols = ['time_idx', 'channel_idx', 'dir_name'] elif normalize_im == 'volume': @@ -215,27 +213,36 @@ def compute_zscore_params(frames_meta, ints_agg.columns = ['zscore_median'] ints_agg['zscore_iqr'] = ints_agg_hq['intensity'] - ints_agg_lq['intensity'] ints_agg.reset_index(inplace=True) - cols_to_merge = \ - frames_meta.columns[[ + + cols_to_merge = frames_meta.columns[[ col not in ['zscore_median', 'zscore_iqr'] for col in frames_meta.columns]] - frames_meta = \ - pd.merge(frames_meta[cols_to_merge], ints_agg, how='left', on=agg_cols) + frames_meta = pd.merge( + frames_meta[cols_to_merge], + ints_agg, + how='left', + on=agg_cols, + ) if frames_meta['zscore_median'].isnull().values.any(): - raise ValueError('Found NaN in normalization parameters. min_fraction might be too low or images might be corrupted.') + raise ValueError('Found NaN in normalization parameters. 
\ + min_fraction might be too low or images might be corrupted.') frames_meta_filename = os.path.join(input_dir, 'frames_meta.csv') frames_meta.to_csv(frames_meta_filename, sep=",") -    cols_to_merge = \ -        ints_meta.columns[[ +    cols_to_merge = ints_meta.columns[[ col not in ['zscore_median', 'zscore_iqr'] for col in ints_meta.columns]] -    ints_meta = \ -        pd.merge(ints_meta[cols_to_merge], ints_agg, how='left', on=agg_cols) +    ints_meta = pd.merge( +        ints_meta[cols_to_merge], +        ints_agg, +        how='left', +        on=agg_cols, +    ) ints_meta['intensity_norm'] = \ -        (ints_meta['intensity'] - ints_meta['zscore_median']) / ints_meta['zscore_iqr'] -    return frames_meta, ints_meta +        (ints_meta['intensity'] - ints_meta['zscore_median']) / \ +        (ints_meta['zscore_iqr'] + sys.float_info.epsilon) +    return frames_meta, ints_meta diff --git a/micro_dl/utils/mp_utils.py b/micro_dl/utils/mp_utils.py index 907346c6..2d206449 100644 --- a/micro_dl/utils/mp_utils.py +++ b/micro_dl/utils/mp_utils.py @@ -369,10 +369,9 @@ def resize_and_save(**kwargs): im = image_utils.read_image(kwargs['file_path']) if kwargs['ff_path'] is not None: -        ff_image = np.load(kwargs['ff_path']) im = image_utils.apply_flat_field_correction( im, -            flat_field_image=ff_image +            flat_field_path=kwargs['ff_path'], ) im_resized = image_utils.rescale_image( im=im, @@ -429,10 +428,9 @@ def rescale_vol_and_save(time_idx, cur_fname = frames_metadata.loc[meta_idx, 'file_name'] cur_img = image_utils.read_image(os.path.join(input_dir, cur_fname)) if ff_path is not None: -        ff_image = np.load(ff_path) cur_img = image_utils.apply_flat_field_correction( cur_img, -            flat_field_image=ff_image +            flat_field_path=ff_path, ) input_stack.append(cur_img) input_stack = np.stack(input_stack, axis=2) @@ -459,14 +457,18 @@ def mp_get_im_stats(fn_args, workers): with ProcessPoolExecutor(workers) as ex: # can't use map directly as it works only with single arg functions res = ex.map(get_im_stats, fn_args) return list(res) def 
get_im_stats(im_path): -    """Read and computes statistics of images -    """ +    """ +    Read and computes statistics of images +    :param str im_path: Full path to image +    :return dict meta_row: Dict with intensity data for image +    """ im = image_utils.read_image(im_path) meta_row = { 'mean': np.nanmean(im), @@ -474,6 +476,7 @@ } return meta_row + def mp_sample_im_pixels(fn_args, workers): """Read and computes statistics of images with multiprocessing @@ -488,12 +491,26 @@ return list(res) -def sample_im_pixels(im_path, grid_spacing, meta_row): -    """Read and computes statistics of images - +def sample_im_pixels(im_path, ff_path, grid_spacing, meta_row): +    """ +    Read and computes statistics of images for each point in a grid. +    Grid spacing determines distance in pixels between grid points +    for rows and cols. +    Applies flatfield correction prior to intensity sampling if flatfield +    path is specified. + +    :param str im_path: Full path to image +    :param str ff_path: Full path to flatfield image corresponding to image +    :param int grid_spacing: Distance in pixels between sampling points +    :param dict meta_row: Metadata row for image +    :return list meta_rows: Dicts with intensity data for each grid point """ -    im = image_utils.read_image(im_path) +    im = image_utils.read_image(im_path) +    if ff_path is not None: +        im = image_utils.apply_flat_field_correction( +            input_image=im, +            flat_field_path=ff_path, +        ) row_ids, col_ids, sample_values = \ image_utils.grid_sample_pixel_values(im, grid_spacing) @@ -502,6 +519,6 @@ 'row_idx': row_idx, 'col_idx': col_idx, 'intensity': sample_value} -        for row_idx, col_idx, sample_value -        in zip(row_ids, col_ids, sample_values)] +            for row_idx, col_idx, sample_value +            in zip(row_ids, col_ids, sample_values)] return meta_rows
nose==1.3.7 numpy==1.21.0 opencv-python==4.2.0.32 pandas==1.1.5 +protobuf==3.20.1 pydot==1.4.1 PyYAML>=5.4 scikit-image==0.17.2 diff --git a/requirements_docker.txt b/requirements_docker.txt index f2be3afe..f3c4ca32 100644 --- a/requirements_docker.txt +++ b/requirements_docker.txt @@ -5,6 +5,7 @@ nose==1.3.7 numpy==1.21.0 opencv-python==4.4.0.40 pandas==1.1.5 +protobuf==3.20.1 pydot==1.4.1 PyYAML>=5.4 scikit-image==0.17.2 diff --git a/tests/cli/generate_meta_tests.py b/tests/cli/generate_meta_tests.py index 03760a16..ca70c323 100644 --- a/tests/cli/generate_meta_tests.py +++ b/tests/cli/generate_meta_tests.py @@ -110,3 +110,40 @@ def test_generate_meta_wrong_function(self): name_parser='nonexisting_function', ) generate_meta.main(args) + + def test_generate_intensity_meta(self): + args = argparse.Namespace( + input=self.sms_dir, + name_parser='parse_sms_name', + order="cztp", + normalize_im='dataset', + num_workers=4, + block_size=10, + ) + generate_meta.main(args) + # Check intensity data + ints_meta = pd.read_csv( + os.path.join(self.sms_dir, 'intensity_meta.csv'), + ) + expected_cols = ['channel_idx', + 'pos_idx', + 'slice_idx', + 'time_idx', + 'channel_name', + 'dir_name', + 'file_name', + 'row_idx', + 'col_idx', + 'intensity'] + self.assertListEqual(list(ints_meta)[1:], expected_cols) + # With block size 10 and image size 30x20 there should be 2 points + # per image and there's 20 images + self.assertEqual(ints_meta.shape[0], 40) + # Check values in metadata, every other idx should be (10,10) and (20, 10) + for idx in range(0, 40, 2): + row_even = ints_meta.iloc[idx] + self.assertEqual(row_even['row_idx'], 10) + self.assertEqual(row_even['col_idx'], 10) + row_uneven = ints_meta.iloc[idx + 1] + self.assertEqual(row_uneven['row_idx'], 20) + self.assertEqual(row_uneven['col_idx'], 10) diff --git a/tests/cli/metrics_script_tests.py b/tests/cli/metrics_script_tests.py index 31192741..81ce7fba 100644 --- a/tests/cli/metrics_script_tests.py +++ 
b/tests/cli/metrics_script_tests.py @@ -135,7 +135,7 @@ def test_compute_metrics(self): for i, row in metrics_xy.iterrows(): expected_name = 't5_p7_xy{}'.format(i) self.assertEqual(row.pred_name, expected_name) - # TODO: Find out why metrics changed + # TODO: Double check values below # self.assertEqual(row.mse, 1.0) # self.assertEqual(row.mae, 1.0) # Same for xyz diff --git a/tests/cli/preprocess_script_test.py b/tests/cli/preprocess_script_test.py index 6ad5cbb3..b017e521 100644 --- a/tests/cli/preprocess_script_test.py +++ b/tests/cli/preprocess_script_test.py @@ -46,6 +46,7 @@ def setUp(self): meta_row = aux_utils.parse_idx_from_name(im_name) meta_row['mean'] = np.nanmean(im) meta_row['std'] = np.nanstd(im) + meta_row['dir_name'] = self.image_dir self.frames_meta = self.frames_meta.append( meta_row, ignore_index=True, @@ -241,6 +242,51 @@ def test_pre_process(self): self.assertTupleEqual(im.shape, (1, 10, 10)) self.assertTrue(im.dtype == np.float64) + def test_pre_process_no_meta(self): + # Remove frames metadata and make sure it regenerates + os.remove(os.path.join(self.image_dir, 'frames_meta.csv')) + out_config, runtime = pp.pre_process(self.pp_config) + self.assertIsInstance(runtime, np.float) + self.assertEqual( + self.base_config['input_dir'], + self.image_dir, + ) + frames_meta = aux_utils.read_meta(self.image_dir) + self.assertTupleEqual(frames_meta.shape, (72, 8)) + + def test_pre_process_intensity_meta(self): + cur_config = self.pp_config + # Use preexisiting masks with more than one class, otherwise + # weight map generation doesn't work + cur_config['normalize'] = { + 'normalize_im': 'volume', + 'min_fraction': .1, + } + cur_config['metadata'] = { + 'block_size': 10, + } + out_config, runtime = pp.pre_process(cur_config) + intensity_meta = aux_utils.read_meta(self.image_dir, 'intensity_meta.csv') + expected_rows = [ + 'channel_idx', + 'pos_idx', + 'slice_idx', + 'time_idx', + 'channel_name', + 'dir_name', + 'file_name', + 'mean', + 'std', + 
'row_idx', + 'col_idx', + 'intensity', + 'fg_frac', + 'zscore_median', + 'zscore_iqr', + 'intensity_norm', + ] + self.assertListEqual(list(intensity_meta), expected_rows) + def test_pre_process_weight_maps(self): cur_config = self.pp_config # Use preexisiting masks with more than one class, otherwise @@ -276,7 +322,6 @@ def test_pre_process_weight_maps(self): [self.time_idx], ) # Load one weights file and check contents - print(os.listdir(out_config['weights']['weights_dir'])) im = np.load(os.path.join( out_config['weights']['weights_dir'], 'im_c005_z002_t000_p007.npy', diff --git a/tests/preprocessing/tile_nonuniform_images_tests.py b/tests/preprocessing/tile_nonuniform_images_tests.py index cacd4f6f..c2a8c1be 100644 --- a/tests/preprocessing/tile_nonuniform_images_tests.py +++ b/tests/preprocessing/tile_nonuniform_images_tests.py @@ -306,7 +306,6 @@ def test_tile_mask_stack(self): mask_meta_df.to_csv(os.path.join(mask_dir, 'frames_meta.csv'), sep=',') self.tile_inst.pos_ids = [7] - self.tile_inst.normalize_channels = [None, None, None, False] self.tile_inst.min_fraction = 0.5 self.tile_inst.tile_mask_stack(mask_dir, diff --git a/tests/utils/meta_utils_tests.py b/tests/utils/meta_utils_tests.py new file mode 100644 index 00000000..3cbeb681 --- /dev/null +++ b/tests/utils/meta_utils_tests.py @@ -0,0 +1,196 @@ +import cv2 +import nose.tools +import numpy as np +import os +import pandas as pd +from testfixtures import TempDirectory +import unittest + +import micro_dl.utils.aux_utils as aux_utils +import micro_dl.utils.meta_utils as meta_utils + + +class TestMetaUtils(unittest.TestCase): + + def setUp(self): + """ + Set up a directory with some images to resample + """ + self.tempdir = TempDirectory() + self.temp_path = self.tempdir.path + self.input_dir = os.path.join(self.temp_path, 'input_dir') + self.tempdir.makedir('input_dir') + self.ff_dir = os.path.join(self.temp_path, 'ff_dir') + self.tempdir.makedir('ff_dir') + self.mask_dir = os.path.join(self.temp_path, 
'mask_dir') + self.tempdir.makedir('mask_dir') + self.slice_idx = 1 + self.time_idx = 2 + self.im = np.zeros((10, 20), np.uint8) + 5 + self.mask = np.zeros((10, 20), np.uint8) + self.mask[:, 10:] = 1 + ff_im = np.ones((10, 20), np.float) * 2 + # Mask meta file + self.csv_name = 'mask_image_matchup.csv' + self.input_meta = aux_utils.make_dataframe() + # Make input meta + for c in range(3): + ff_path = os.path.join( + self.ff_dir, + 'flat-field_channel-{}.npy'.format(c) + ) + np.save(ff_path, ff_im, allow_pickle=True, fix_imports=True) + for p in range(5): + im_name = aux_utils.get_im_name( + channel_idx=c, + slice_idx=self.slice_idx, + time_idx=self.time_idx, + pos_idx=p, + ) + cv2.imwrite( + os.path.join(self.input_dir, im_name), + self.im + p * 10, + ) + cv2.imwrite( + os.path.join(self.mask_dir, im_name), + self.mask, + ) + meta_row = aux_utils.parse_idx_from_name(im_name) + meta_row['dir_name'] = self.input_dir + self.input_meta = self.input_meta.append( + meta_row, + ignore_index=True, + ) + + def tearDown(self): + """ + Tear down temporary folder and file structure + """ + TempDirectory.cleanup_all() + nose.tools.assert_equal(os.path.isdir(self.temp_path), False) + + def test_frames_meta_generator(self): + frames_meta = meta_utils.frames_meta_generator( + input_dir=self.input_dir, + name_parser='parse_idx_from_name', + ) + for idx, row in frames_meta.iterrows(): + input_row = self.input_meta.iloc[idx] + nose.tools.assert_equal(input_row['file_name'], row['file_name']) + nose.tools.assert_equal(input_row['slice_idx'], row['slice_idx']) + nose.tools.assert_equal(input_row['time_idx'], row['time_idx']) + nose.tools.assert_equal(input_row['channel_idx'], row['channel_idx']) + nose.tools.assert_equal(input_row['pos_idx'], row['pos_idx']) + + def test_ints_meta_generator(self): + # Write metadata + self.input_meta.to_csv( + os.path.join(self.input_dir, 'frames_meta.csv'), + sep=',', + ) + meta_utils.ints_meta_generator( + input_dir=self.input_dir, + block_size=5, + 
) + intensity_meta = pd.read_csv( + os.path.join(self.input_dir, 'intensity_meta.csv'), + ) + # There's 15 images and each image should be sampled 3 times + # at col = 5, 10, 15 and row = 5 + self.assertEqual(intensity_meta.shape[0], 45) + # Check one image + meta_im = intensity_meta.loc[ + intensity_meta['file_name'] == 'im_c000_z001_t002_p000.png', + ] + for i, col_idx in enumerate([5, 10, 15]): + self.assertEqual(meta_im.loc[i, 'col_idx'], col_idx) + self.assertEqual(meta_im.loc[i, 'row_idx'], 5) + self.assertEqual(meta_im.loc[i, 'intensity'], 5) + + def test_ints_meta_generator_flatfield(self): + # Write metadata + self.input_meta.to_csv( + os.path.join(self.input_dir, 'frames_meta.csv'), + sep=',', + ) + meta_utils.ints_meta_generator( + input_dir=self.input_dir, + block_size=5, + flat_field_dir=self.ff_dir, + ) + intensity_meta = pd.read_csv( + os.path.join(self.input_dir, 'intensity_meta.csv'), + ) + # There's 15 images and each image should be sampled 3 times + self.assertEqual(intensity_meta.shape[0], 45) + # Check one image + meta_im = intensity_meta.loc[ + intensity_meta['file_name'] == 'im_c000_z001_t002_p000.png', + ] + for i, col_idx in enumerate([5, 10, 15]): + self.assertEqual(meta_im.loc[i, 'col_idx'], col_idx) + self.assertEqual(meta_im.loc[i, 'row_idx'], 5) + self.assertEqual(meta_im.loc[i, 'intensity'], 2.5) + + def test_mask_meta_generator(self): + self.input_meta.to_csv( + os.path.join(self.mask_dir, 'frames_meta.csv'), + sep=',', + ) + mask_meta = meta_utils.mask_meta_generator( + input_dir=self.mask_dir, + ) + self.assertEqual(mask_meta.shape[0], 15) + expected_cols = [ + 'channel_idx', + 'pos_idx', + 'slice_idx', + 'time_idx', + 'channel_name', + 'dir_name', + 'file_name', + 'fg_frac', + ] + self.assertListEqual(list(mask_meta), expected_cols) + # Foreground fraction should be 0.5 + for i in range(15): + self.assertEqual(mask_meta.loc[i, 'fg_frac'], .5) + + def test_compute_zscore_params(self): + self.input_meta.to_csv( + 
os.path.join(self.input_dir, 'frames_meta.csv'), + sep=',', + ) + meta_utils.ints_meta_generator( + input_dir=self.input_dir, + block_size=5, + ) + intensity_meta = pd.read_csv( + os.path.join(self.input_dir, 'intensity_meta.csv'), + ) + self.input_meta.to_csv( + os.path.join(self.mask_dir, 'frames_meta.csv'), + sep=',', + ) + mask_meta = meta_utils.mask_meta_generator( + input_dir=self.mask_dir, + ) + cols_to_merge = intensity_meta.columns[intensity_meta.columns != 'fg_frac'] + intensity_meta = pd.merge( + intensity_meta[cols_to_merge], + mask_meta[['pos_idx', 'time_idx', 'slice_idx', 'fg_frac']], + how='left', + on=['pos_idx', 'time_idx', 'slice_idx'], + ) + frames_meta, ints_meta = meta_utils.compute_zscore_params( + frames_meta=self.input_meta, + ints_meta=intensity_meta, + input_dir=self.input_dir, + normalize_im='volume', + min_fraction=.4, + ) + # Check medians and iqr values + for i, row in frames_meta.iterrows(): + self.assertEqual(row['zscore_iqr'], 0) + # Added 10 for each p when saving images + self.assertEqual(row['zscore_median'], 5 + row['pos_idx'] * 10) diff --git a/tests/utils/mp_utils_tests.py b/tests/utils/mp_utils_tests.py index 7a9f7c64..fb79b032 100644 --- a/tests/utils/mp_utils_tests.py +++ b/tests/utils/mp_utils_tests.py @@ -1,3 +1,4 @@ +import cv2 import nose.tools import numpy as np import numpy.testing @@ -254,3 +255,42 @@ def test_create_save_mask_border_map(self): for x_coord in range(self.params[0][0] + self.radius, self.params[1][0] - self.radius): distance_near_intersection = weight_map[x_coord, y_coord] nose.tools.assert_equal(max_weight_map, distance_near_intersection) + + +def test_mp_sample_im_pixels(): + with TempDirectory() as tempdir: + temp_path = tempdir.path + im = np.zeros((20, 30), np.uint8) + 50 + im1_path = os.path.join(temp_path, 'im1.tif') + im2_path = os.path.join(temp_path, 'im2.tif') + ff_path = os.path.join(temp_path, 'ff.npy') + cv2.imwrite(im1_path, im) + cv2.imwrite(im2_path, im + 100) + np.save(ff_path, im / 
2, allow_pickle=True, fix_imports=True) + meta_row = pd.DataFrame( + [[2, 1, 0, 3]], + columns=['time_idx', 'channel_idx', 'pos_idx', 'slice_idx'], + ) + fn_args = [ + (im1_path, ff_path, 10, meta_row), + (im2_path, ff_path, 10, meta_row), + ] + res = mp_utils.mp_sample_im_pixels(fn_args, 1) + nose.tools.assert_equal(len(res), 2) + # There should be row_idx=10 and col_idx=10, 20 for both images + # and intensity for flatfield corrected images for im1 and im2 + # should be 50/25=2 and 150/25=6 + im1_res = res[0] + nose.tools.assert_equal(im1_res[0]['row_idx'], 10) + nose.tools.assert_equal(im1_res[0]['col_idx'], 10) + nose.tools.assert_equal(im1_res[0]['intensity'], 2.0) + nose.tools.assert_equal(im1_res[1]['row_idx'], 10) + nose.tools.assert_equal(im1_res[1]['col_idx'], 20) + nose.tools.assert_equal(im1_res[1]['intensity'], 2.0) + im2_res = res[1] + nose.tools.assert_equal(im2_res[0]['row_idx'], 10) + nose.tools.assert_equal(im2_res[0]['col_idx'], 10) + nose.tools.assert_equal(im2_res[0]['intensity'], 6.0) + nose.tools.assert_equal(im2_res[1]['row_idx'], 10) + nose.tools.assert_equal(im2_res[1]['col_idx'], 20) + nose.tools.assert_equal(im2_res[1]['intensity'], 6.0)