## Training script

This directory contains the script used to run the full training, ```run_model.py```.

Run the script with ```python run_model.py --cfgfile $path_to_config --train-metadata $path_to_train_jsondict --eval-metadata $path_to_eval_dict --num-gpus $ngpu --run-name $name_of_run --output-dir $path_to_output```.

You can test this with the double/single_test.json files in ```/tests/deepdisc/test_data/dc2/```. You should download the pre-trained weights [here](https://dl.fbaipublicfiles.com/detectron2/ViTDet/COCO/cascade_mask_rcnn_swin_b_in21k/f342979038/model_final_246a82.pkl).
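For example, a single-GPU test run might look like the following (a sketch only; the run name, output directory, and relative paths are hypothetical and depend on your checkout):

```python run_model.py --cfgfile configs/solo/solo_swin.py --train-metadata tests/deepdisc/test_data/dc2/single_test.json --eval-metadata tests/deepdisc/test_data/dc2/single_test.json --num-gpus 1 --run-name test_run --output-dir ./outputs```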
The command line options are explained below:

- cfgfile: The configuration file used to build the model, learning-rate optimizer, trainer, and dataloaders. See ```/configs/solo/solo_swin.py``` for an example config; a sketch of the fields the script reads follows this list.
- train-metadata: The training data, stored as a list of dicts in JSON format. Each dict should have the "instance detection/segmentation" keys specified in the [detectron2 repo](https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html); an example entry is sketched after this list.
- eval-metadata: The same as the training metadata, but for the evaluation set.
- num-gpus: The number of GPUs used to train the model. The batch size specified in the config must be a multiple of this number.
- run-name: A string prefix used to name the outputs of the script, such as model weights and loss curves.
- output-dir: The directory in which to save the outputs.
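Since the config drives everything, here is an outline of the config fields that ```run_model.py``` reads. The field list is taken from the script below; the descriptions are an interpretation, so consult ```/configs/solo/solo_swin.py``` for real values:

```python
# Outline of the LazyConfig fields accessed by run_model.py:
#
#   cfg.MISC                     # optional dict; entries get promoted to the cfg top level
#   cfg.DATASETS.TRAIN           # name under which the training set is registered
#   cfg.DATASETS.TEST            # name under which the evaluation set is registered
#   cfg.metadata.classes         # thing_classes passed to register_data_set
#   cfg.dataloader.epoch         # iterations per epoch; sets the LR schedule milestones
#   cfg.dataloader.train.mapper  # mapper class, e.g. DictMapper or RedshiftDictMapper
#   cfg.dataloader.imagereader   # image reader, e.g. DC2ImageReader or HSCImageReader
#   cfg.dataloader.key_mapper    # maps a metadata dict to the image to read
#   cfg.dataloader.augs          # augmentations, e.g. train_augs
#   cfg.optimizer                # optimizer config; the script overrides optimizer.lr per phase
#   cfg.train.init_checkpoint    # checkpoint path; set to the phase-1 weights in phase 2
#   cfg.SOLVER.BASE_LR, cfg.SOLVER.MAX_ITER, cfg.SOLVER.STEPS  # solver settings set in phase 2
```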
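And a minimal sketch of one entry in the train/eval metadata list (the file name and numbers here are hypothetical; see the detectron2 datasets tutorial linked above for the full key specification):

```python
# One entry of the JSON metadata list (hypothetical values).
{
    "file_name": "tests/deepdisc/test_data/dc2/image_0.npy",  # image path on disk
    "image_id": 0,
    "height": 512,
    "width": 512,
    "annotations": [
        {
            "bbox": [100.0, 120.0, 180.0, 200.0],
            "bbox_mode": 0,  # detectron2 BoxMode.XYXY_ABS serializes to 0 in JSON
            "category_id": 1,  # index into the classes list from the config
            "segmentation": [[100.0, 120.0, 180.0, 120.0, 180.0, 200.0]],
        },
    ],
}
```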
The full contents of the training script, ```run_model.py```:
try:
    # ignore ShapelyDeprecationWarning from fvcore
    import warnings

    from shapely.errors import ShapelyDeprecationWarning

    warnings.filterwarnings("ignore", category=ShapelyDeprecationWarning)
except ImportError:
    pass
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# Some basic setup:
# Setup detectron2 logger
from detectron2.utils.logger import setup_logger

setup_logger()

import gc
import os
import time

import detectron2.utils.comm as comm

# import some common libraries
import numpy as np
import torch

# import some common detectron2 utilities
from detectron2.config import LazyConfig, get_cfg
from detectron2.engine import launch

from deepdisc.data_format.augment_image import hsc_test_augs, train_augs
from deepdisc.data_format.image_readers import DC2ImageReader, HSCImageReader
from deepdisc.data_format.register_data import register_data_set
from deepdisc.model.loaders import DictMapper, RedshiftDictMapper, return_test_loader, return_train_loader
from deepdisc.model.models import RedshiftPDFCasROIHeads, return_lazy_model
from deepdisc.training.trainers import (
    return_evallosshook,
    return_lazy_trainer,
    return_optimizer,
    return_savehook,
    return_schedulerhook,
)
from deepdisc.utils.parse_arguments import dtype_from_args, make_training_arg_parser


def main(args, freeze):
    # Hack if you get SSL certificate error
    import ssl

    ssl._create_default_https_context = ssl._create_unverified_context

    # Handle args
    output_dir = args.output_dir
    run_name = args.run_name

    # Get file locations
    trainfile = args.train_metadata
    evalfile = args.eval_metadata

    cfgfile = args.cfgfile

    # Load the config and promote any MISC entries to the top level
    cfg = LazyConfig.load(cfgfile)
    for key in cfg.get("MISC", dict()).keys():
        cfg[key] = cfg.MISC[key]

    # Register the data sets
    astrotrain_metadata = register_data_set(
        cfg.DATASETS.TRAIN, trainfile, thing_classes=cfg.metadata.classes
    )
    astroval_metadata = register_data_set(
        cfg.DATASETS.TEST, evalfile, thing_classes=cfg.metadata.classes
    )

    # Set the output directory
    cfg.OUTPUT_DIR = output_dir
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    # Iterations for 15, 25, 30, and 50 epochs
    epoch = cfg.dataloader.epoch
    e1 = epoch * 15
    e2 = epoch * 25
    e3 = epoch * 30
    efinal = epoch * 50

    # How often (in iterations) the evaluation-loss hook runs
    # val_per = epoch
    val_per = 5

    model = return_lazy_model(cfg, freeze)

    mapper = cfg.dataloader.train.mapper(
        cfg.dataloader.imagereader, cfg.dataloader.key_mapper, cfg.dataloader.augs
    ).map_data

    loader = return_train_loader(cfg, mapper)
    eval_loader = return_test_loader(cfg, mapper)

    cfg.optimizer.params.model = model

    if freeze:
        # Phase 1: train only the unfrozen head layers at a higher learning rate
        cfg.optimizer.lr = 0.001
        optimizer = return_optimizer(cfg)

        saveHook = return_savehook(run_name)
        lossHook = return_evallosshook(val_per, model, eval_loader)
        schedulerHook = return_schedulerhook(optimizer)
        hookList = [lossHook, schedulerHook, saveHook]

        trainer = return_lazy_trainer(model, loader, optimizer, cfg, hookList)
        trainer.set_period(epoch // 2)
        # trainer.train(0, e1)  # full head-training schedule
        trainer.train(0, 10)  # shortened run for testing
        if comm.is_main_process():
            np.save(os.path.join(output_dir, run_name + "_losses"), trainer.lossList)
            np.save(os.path.join(output_dir, run_name + "_val_losses"), trainer.vallossList)

        return

    else:
        # Phase 2: initialize from the head-trained weights and train all layers
        cfg.train.init_checkpoint = os.path.join(output_dir, run_name + ".pth")
        cfg.SOLVER.BASE_LR = 0.0001
        cfg.SOLVER.MAX_ITER = efinal  # for DefaultTrainer
        cfg.SOLVER.STEPS = [e2, e3]

        cfg.optimizer.lr = 0.0001

        optimizer = return_optimizer(cfg)

        saveHook = return_savehook(run_name)
        lossHook = return_evallosshook(val_per, model, eval_loader)
        schedulerHook = return_schedulerhook(optimizer)
        hookList = [lossHook, schedulerHook, saveHook]

        trainer = return_lazy_trainer(model, loader, optimizer, cfg, hookList)
        trainer.set_period(epoch // 2)
        # trainer.train(e1, efinal)  # full schedule
        trainer.train(10, 20)  # shortened run for testing
        if comm.is_main_process():
            losses = np.load(os.path.join(output_dir, run_name + "_losses.npy"))
            losses = np.concatenate((losses, trainer.lossList))
            np.save(os.path.join(output_dir, run_name + "_losses"), losses)
        return


if __name__ == "__main__":
    args = make_training_arg_parser().parse_args()
    print("Command Line Args:", args)

    print("Training head layers")
    freeze = True
    t0 = time.time()
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args, freeze),
    )

    torch.cuda.empty_cache()
    gc.collect()

    ######
    # After finetuning the head layers, train the whole model
    ######

    print("Training all layers")
    freeze = False
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args, freeze),
    )

    torch.cuda.empty_cache()
    gc.collect()

    print(f"Took {time.time() - t0} seconds")  # total time for both training phases