Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ResNet v1 as default backbone #19

Merged
merged 1 commit into from
Apr 26, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,21 @@ class Parser(hyperparams.Config):
small_instance_weight: float = 3.0
dtype = 'float32'


@dataclasses.dataclass
class DataDecoder(common.DataDecoder):
"""Data decoder config."""
simple_decoder: common.TfExampleDecoder = common.TfExampleDecoder()


@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
"""Input config for training."""
decoder: DataDecoder = DataDecoder()
parser: Parser = Parser()
file_type: str = 'tfrecord'


@dataclasses.dataclass
class PanopticDeeplabHead(hyperparams.Config):
"""Panoptic Deeplab head config."""
Expand All @@ -75,16 +78,19 @@ class PanopticDeeplabHead(hyperparams.Config):
low_level_num_filters: Union[List[int], Tuple[int]] = (64, 32)
fusion_num_output_filters: int = 256


@dataclasses.dataclass
class SemanticHead(PanopticDeeplabHead):
"""Semantic head config."""
prediction_kernel_size: int = 1


@dataclasses.dataclass
class InstanceHead(PanopticDeeplabHead):
"""Instance head config."""
prediction_kernel_size: int = 1


@dataclasses.dataclass
class PanopticDeeplabPostProcessor(hyperparams.Config):
"""Panoptic Deeplab PostProcessing config."""
Expand All @@ -99,6 +105,7 @@ class PanopticDeeplabPostProcessor(hyperparams.Config):
keep_k_centers: int = 200
rescale_predictions: bool = True


@dataclasses.dataclass
class PanopticDeeplab(hyperparams.Config):
"""Panoptic Deeplab model config."""
Expand All @@ -116,6 +123,7 @@ class PanopticDeeplab(hyperparams.Config):
generate_panoptic_masks: bool = True
post_processor: PanopticDeeplabPostProcessor = PanopticDeeplabPostProcessor()


@dataclasses.dataclass
class Losses(hyperparams.Config):
label_smoothing: float = 0.0
Expand All @@ -127,6 +135,7 @@ class Losses(hyperparams.Config):
center_heatmap_loss_weight: float = 200
center_offset_loss_weight: float = 0.01


@dataclasses.dataclass
class Evaluation(hyperparams.Config):
""" Evaluation config """
Expand All @@ -141,6 +150,7 @@ class Evaluation(hyperparams.Config):
report_per_class_iou: bool = False
report_train_mean_iou: bool = True # Turning this off can speed up training.


@dataclasses.dataclass
class PanopticDeeplabTask(cfg.TaskConfig):
model: PanopticDeeplab = PanopticDeeplab()
Expand Down Expand Up @@ -175,10 +185,9 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig:
output_stride = 16
aspp_dilation_rates = [6, 12, 18]
multigrid = [1, 2, 4]
stem_type = 'v0'
stem_type = 'v1'
level = int(np.math.log2(output_stride))


config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(
mixed_precision_dtype='bfloat16', enable_xla=True),
Expand All @@ -191,16 +200,20 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig:
backbone=backbones.Backbone(
type='dilated_resnet', dilated_resnet=backbones.DilatedResNet(
model_id=50,
stem_type=stem_type,
output_stride=output_stride,
multigrid=multigrid,
stem_type=stem_type)),
se_ratio=0.25,
last_stage_repeats=1,
stochastic_depth_drop_rate=0.2)),
decoder=decoders.Decoder(
type='aspp',
aspp=decoders.ASPP(
level=level,
num_filters=256,
pool_kernel_size=input_size[:2],
dilation_rates=aspp_dilation_rates,
use_depthwise_convolution=True,
dropout_rate=0.1)),
semantic_head=SemanticHead(
level=level,
Expand Down