Sourcery Starbot ⭐ refactored AgainstEntropy/NLOS-Track #4

Open · wants to merge 1 commit into base: main
6 changes: 3 additions & 3 deletions data/dataset.py
@@ -57,9 +57,9 @@ def __getitem__(self, idx):
 def split_dataset(phase: str = 'train', train_ratio: float = 0.8, **kwargs):
     full_dataset = TrackingDataset(**kwargs)

-    if phase == 'train':
+    if phase == 'test':
+        return full_dataset
+    elif phase == 'train':
         train_size = int(len(full_dataset) * train_ratio)
         val_size = len(full_dataset) - train_size
         return random_split(full_dataset, [train_size, val_size])
-    elif phase == 'test':
-        return full_dataset
Comment on lines -60 to -65
Function split_dataset refactored with the following changes:

  • Simplify conditional into switch-like form (switch)
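The reordered branches can be exercised in isolation. The sketch below uses plain-Python stand-ins for TrackingDataset and torch's random_split (a shuffled index split), so the list dataset and the shuffle are illustrative only:

```python
import random

def split_dataset(full_dataset: list, phase: str = 'train', train_ratio: float = 0.8):
    # Trivial case first, as in the refactored conditional.
    if phase == 'test':
        return full_dataset
    elif phase == 'train':
        train_size = int(len(full_dataset) * train_ratio)
        # Stand-in for torch.utils.data.random_split: shuffle indices, then cut.
        indices = list(range(len(full_dataset)))
        random.shuffle(indices)
        train = [full_dataset[i] for i in indices[:train_size]]
        val = [full_dataset[i] for i in indices[train_size:]]
        return train, val

data = list(range(10))
train, val = split_dataset(data, phase='train')
print(len(train), len(val))  # 8 2
```

Putting the early return first keeps the common 'train' path readable while every phase still exits through exactly one branch.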

4 changes: 1 addition & 3 deletions data/preprocess.py
@@ -4,9 +4,7 @@

 def sub_mean(frames: Tensor) -> Tensor:
     mean_frame = frames.mean(axis=0, keepdim=True)
-    frames_sub_mean = frames.sub(mean_frame)
-
-    return frames_sub_mean
+    return frames.sub(mean_frame)

Function sub_mean refactored with the following changes:



 def diff(frames: Tensor) -> Tensor:
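The behaviour being preserved (subtracting each pixel's temporal mean) can be checked with NumPy standing in for torch; `keepdims=True` mirrors the original `keepdim=True`, and plain subtraction mirrors `Tensor.sub`:

```python
import numpy as np

def sub_mean(frames: np.ndarray) -> np.ndarray:
    # Mean over the time axis, kept as a broadcastable (1, H, W) frame.
    mean_frame = frames.mean(axis=0, keepdims=True)
    return frames - mean_frame

frames = np.arange(12, dtype=float).reshape(3, 2, 2)  # (T, H, W)
out = sub_mean(frames)
print(out.mean(axis=0))  # zero at every pixel
```

Inlining the immediately-returned variable changes nothing observable; the result is still a new array with zero temporal mean per pixel.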
8 changes: 3 additions & 5 deletions models.py
@@ -12,7 +12,7 @@ def __init__(self, model_name: str, pretrained: bool = True,
                  rnn_hdim: int = 128):
         super(PAC_Cell, self).__init__()

-        assert model_name in ['PAC_Net', 'P_Net', 'C_Net', 'baseline']
+        assert model_name in {'PAC_Net', 'P_Net', 'C_Net', 'baseline'}
Function PAC_Cell.__init__ refactored with the following changes:

         self.rnn_hdim = rnn_hdim

         self.backbone_builder = {
@@ -66,7 +66,7 @@ def __init__(self, model_name: str, pretrained: bool,
         self.rnn_hdim = rnn_hdim
         self.v_loss = v_loss

-        assert model_name in ['PAC_Net', 'P_Net', 'C_Net', 'baseline']
+        assert model_name in {'PAC_Net', 'P_Net', 'C_Net', 'baseline'}
Function PAC_Net_Base.__init__ refactored with the following changes:

         # CNN
         self.backbone_builder = {
             'PAC_Net': tvmodels.resnet18,
@@ -366,9 +366,7 @@ def warm_up(self, I: Tensor):
         fx = self.warmup_encoder(rearrange(I, 'b c t h w -> (b t) c h w'))
         fx = rearrange(fx, '(b t) d -> b t d', b=B)

-        hx = self.warmup_rnn(fx)[1]  # (2, B, D)
-
-        return hx
+        return self.warmup_rnn(fx)[1]
Function C_Net.warm_up refactored with the following changes:

This removes the following comments (why?):

# (2, B, D)



 class NLOS_baseline(PAC_Net_Base):
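Swapping the assert's list for a set literal is behaviour-preserving for these string literals; a set also makes membership an O(1) hash lookup instead of a linear scan. A minimal check (the helper function name is illustrative, not from the repo):

```python
VALID_MODELS = {'PAC_Net', 'P_Net', 'C_Net', 'baseline'}

def check_model_name(model_name: str) -> None:
    # Same semantics as `in [...]`, but a hash lookup instead of an O(n) scan.
    assert model_name in VALID_MODELS, f'unknown model: {model_name}'

check_model_name('PAC_Net')  # passes silently
try:
    check_model_name('resnet50')
except AssertionError as e:
    print(e)  # unknown model: resnet50
```

For a four-element collection the speed difference is negligible; the set literal mainly documents that order and duplicates are irrelevant.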
2 changes: 1 addition & 1 deletion train.py
@@ -14,7 +14,7 @@ def main(cfg):
     # os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

     world_size = len(dist_cfgs['device_ids'].split(','))
-    dist_cfgs['distributed'] = True if world_size > 1 else False
+    dist_cfgs['distributed'] = world_size > 1
Function main refactored with the following changes:

     dist_cfgs['world_size'] = world_size
     cfg['loader_kwargs']['batch_size'] = cfg['train_configs']['batch_size'] // world_size
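`True if world_size > 1 else False` evaluates to exactly `world_size > 1`, so the comparison can be assigned directly. A self-contained check of the same logic (the helper and the device-id strings are illustrative):

```python
def distributed_flag(device_ids: str) -> tuple[int, bool]:
    # world_size = number of comma-separated device ids, as in main().
    world_size = len(device_ids.split(','))
    return world_size, world_size > 1  # the comparison already yields a bool

print(distributed_flag('0'))        # (1, False)
print(distributed_flag('0,1,2,3'))  # (4, True)
```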
6 changes: 3 additions & 3 deletions utils/routes.py
@@ -48,7 +48,7 @@ def generate_route(self,
         self.route_length = route_length

         self._init_pv()
-        for step in range(route_length):
+        for _ in range(route_length):
Function route_generator.generate_route refactored with the following changes:

             # print(self.velocity)
             self.next_step(turn_rate=turn_rate)
             self.e_route.append(self.e_position.copy())
@@ -116,8 +116,8 @@ def load_route(self,
         mat_path = os.path.join(save_dir, mat_name)
         save_dict = loadmat(mat_path)

-        self.e_route = [p for p in save_dict['route']]
-        self.velocities = [v for v in save_dict['velocities']]
+        self.e_route = list(save_dict['route'])
+        self.velocities = list(save_dict['velocities'])
Comment on lines -119 to +120
Function route_generator.load_route refactored with the following changes:

         print(f'Load data from {mat_path} successfully!')
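Both route changes are identity-preserving: `[p for p in xs]` builds the same shallow copy that `list(xs)` expresses directly, and `_` marks the loop variable as unused. A quick equivalence check, with plain nested lists standing in for the loadmat arrays:

```python
route = [[0.0, 0.0], [0.1, 0.2], [0.3, 0.5]]

comprehension_copy = [p for p in route]
list_copy = list(route)

# Same elements, same shallow-copy semantics, new outer list either way.
print(comprehension_copy == list_copy)  # True
print(list_copy is route)               # False: a fresh outer list
print(list_copy[0] is route[0])         # True: inner elements are shared
```

The same holds when `save_dict['route']` is a 2-D ndarray: both forms iterate over its rows and collect them into a new Python list.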
42 changes: 22 additions & 20 deletions utils/trainer.py
@@ -192,20 +192,20 @@ class Trainer_tracking(Trainer_Base):
     def __init__(self, cfg):
         super(Trainer_tracking, self).__init__(cfg=cfg)

-        log_train_cfg = {
-            "model_name": self.model_name,
-            **self.model_cfgs,
-            "batch_size": self.train_cfgs['batch_size'],
-            "v_loss_alpha": self.train_cfgs['v_loss_alpha'],
-            "loss_total_alpha": self.train_cfgs['loss_total_alpha'],
-            "resume": self.train_cfgs['resume'],
-            "route_len": self.dataset_cfgs['route_len'],
-            "noise_factor": self.dataset_cfgs['noise_factor'],
-            **self.optim_kwargs,
-            "epochs": self.schedule_cfgs['max_epoch'],
-        }
-
         if self.dist_cfgs['local_rank'] == 0:
Comment on lines -195 to 208
Function Trainer_tracking.__init__ refactored with the following changes:

+            log_train_cfg = {
+                "model_name": self.model_name,
+                **self.model_cfgs,
+                "batch_size": self.train_cfgs['batch_size'],
+                "v_loss_alpha": self.train_cfgs['v_loss_alpha'],
+                "loss_total_alpha": self.train_cfgs['loss_total_alpha'],
+                "resume": self.train_cfgs['resume'],
+                "route_len": self.dataset_cfgs['route_len'],
+                "noise_factor": self.dataset_cfgs['noise_factor'],
+                **self.optim_kwargs,
+                "epochs": self.schedule_cfgs['max_epoch'],
+            }
+
             self._init_recorder(log_train_cfg)

         self.val_metrics = {'x_loss': 0.0,
@@ -278,14 +278,14 @@ def run(self):
                         'Metric/val/dtw': val_metric[2],
                     }, step=epoch + 1)
                 if self.epoch % 5 == 0:
-                    logger.info(f'Logging images...')
+                    logger.info('Logging images...')
                     self.test_plot(epoch=self.epoch, phase='train')
                     self.test_plot(epoch=self.epoch, phase='val')

             self.scheduler.step()

             if ((epoch + 1) % self.log_cfgs['save_epoch_interval'] == 0) \
-                or (epoch + 1) == self.schedule_cfgs['max_epoch']:
+                    or (epoch + 1) == self.schedule_cfgs['max_epoch']:
Comment on lines -281 to +288
Function Trainer_tracking.run refactored with the following changes:

                 checkpoint_path = os.path.join(self.ckpt_dir, f"epoch_{(epoch + 1)}.pth")
                 self.save_checkpoint(checkpoint_path)
@@ -315,7 +315,7 @@ def train(self, epoch):
                         dynamic_ncols=True,
                         ascii=(platform.version() == 'Windows'))

-        for step in range(len_loader):
+        for _ in range(len_loader):
Function Trainer_tracking.train refactored with the following changes:

             try:
                 inputs, labels, map_sizes = next(iter_loader)
             except Exception as e:
@@ -387,7 +387,7 @@ def train(self, epoch):
         pbar.close()

         return (x_loss_recorder.avg, v_loss_recorder.avg), \
-            (pcm_recorder.avg, area_recorder.avg, dtw_recorder.avg)
+               (pcm_recorder.avg, area_recorder.avg, dtw_recorder.avg)

     def val(self, epoch):
         self.model.eval()
@@ -482,13 +482,15 @@ def val(self, epoch):
         metrics = [self.val_metrics[name] for name in names]
         res_table.add_row([f"{m:.4}" if type(m) is float else m for m in metrics[:-1]] + [metrics[-1]])

-        logger.info(f'Performance on validation set at epoch: {epoch + 1}\n' + res_table.get_string())
+        logger.info(
+            f'Performance on validation set at epoch: {epoch + 1}\n{res_table.get_string()}'
+        )

         return (self.val_metrics['x_loss'], self.val_metrics['v_loss']), \
-            (self.val_metrics['pcm'], self.val_metrics['area'], self.val_metrics['dtw'])
+               (self.val_metrics['pcm'], self.val_metrics['area'], self.val_metrics['dtw'])
Comment on lines -485 to +490
Function Trainer_tracking.val refactored with the following changes:


     def test_plot(self, epoch, phase: str):
-        assert phase in ['train', 'val']
+        assert phase in {'train', 'val'}
Comment on lines -491 to +493
Function Trainer_tracking.test_plot refactored with the following changes:

         self.model.eval()
         iter_loader = iter(self.val_loader) if phase == 'val' else iter(self.train_loader)
         frames, gt_routes, map_sizes = next(iter_loader)
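The structural change in `__init__` is that `log_train_cfg` is now built only inside the `local_rank == 0` branch, so non-logging ranks skip constructing a dict they never use. The pattern, sketched with a hypothetical two-key config in place of the full one:

```python
def init_recorder_config(local_rank: int, model_name: str, batch_size: int):
    """Return the logging config on rank 0, None elsewhere (illustrative sketch)."""
    if local_rank == 0:
        # Built lazily: only the rank that actually logs pays for construction.
        log_train_cfg = {
            'model_name': model_name,
            'batch_size': batch_size,
        }
        return log_train_cfg
    return None

print(init_recorder_config(0, 'PAC_Net', 32))  # {'model_name': 'PAC_Net', 'batch_size': 32}
print(init_recorder_config(1, 'PAC_Net', 32))  # None
```

The saving is tiny for a dict of literals, but scoping the value to the branch that consumes it also makes the data flow easier to follow.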
12 changes: 9 additions & 3 deletions utils/vis.py
@@ -23,7 +23,7 @@ def draw_route(map_size: ndarray, route: ndarray,
     lc.set_array(idxs)
     lc.set_linewidth(3)
     line = ax.add_collection(lc)
-    fig.colorbar(line, ax=ax, ticks=idxs[::int(len(idxs) / 10)], label='step')
+    fig.colorbar(line, ax=ax, ticks=idxs[::len(idxs) // 10], label='step')
Function draw_route refactored with the following changes:


     ax.set_xlim(0, map_size[0])
     ax.set_xlabel('x')
@@ -40,7 +40,7 @@ def draw_route(map_size: ndarray, route: ndarray,


 def draw_routes(routes: tuple[ndarray, ndarray], return_mode: str = None):
-    assert return_mode in ['plt_fig', 'fig_array', None]
+    assert return_mode in {'plt_fig', 'fig_array', None}
Function draw_routes refactored with the following changes:

     titles = ('GT', 'pred')
     cmaps = ('viridis', 'plasma')

@@ -59,7 +59,13 @@ def draw_routes(routes: tuple[ndarray, ndarray], return_mode: str = None):
         lc.set_array(idxs)
         lc.set_linewidth(3)
         line = axes[i].add_collection(lc)
-        fig.colorbar(line, ax=axes[i], ticks=idxs[::int(len(idxs) / 10)], label='step', fraction=0.05)
+        fig.colorbar(
+            line,
+            ax=axes[i],
+            ticks=idxs[:: len(idxs) // 10],
+            label='step',
+            fraction=0.05,
+        )

         axes[i].set_title(titles[i])
         axes[i].set_xlim(0, 1)
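For non-negative lengths, `int(len(idxs) / 10)` and `len(idxs) // 10` produce the same step, and floor division avoids the float round-trip; the slice then keeps roughly every tenth tick. A check with an assumed 50-step route (note that both forms yield a zero step, and hence a slicing ValueError, for routes shorter than 10 steps):

```python
idxs = list(range(50))  # assumed route length of 50 steps

step_old = int(len(idxs) / 10)  # truncated float division
step_new = len(idxs) // 10      # integer floor division

print(step_old == step_new)  # True
print(idxs[::step_new])      # [0, 5, 10, 15, 20, 25, 30, 35, 40, 45]
```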