-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sourcery Starbot ⭐ refactored AgainstEntropy/NLOS-Track #4
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,9 +4,7 @@ | |
|
||
def sub_mean(frames: Tensor) -> Tensor: | ||
mean_frame = frames.mean(axis=0, keepdim=True) | ||
frames_sub_mean = frames.sub(mean_frame) | ||
|
||
return frames_sub_mean | ||
return frames.sub(mean_frame) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
|
||
def diff(frames: Tensor) -> Tensor: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,7 @@ def __init__(self, model_name: str, pretrained: bool = True, | |
rnn_hdim: int = 128): | ||
super(PAC_Cell, self).__init__() | ||
|
||
assert model_name in ['PAC_Net', 'P_Net', 'C_Net', 'baseline'] | ||
assert model_name in {'PAC_Net', 'P_Net', 'C_Net', 'baseline'} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
self.rnn_hdim = rnn_hdim | ||
|
||
self.backbone_builder = { | ||
|
@@ -66,7 +66,7 @@ def __init__(self, model_name: str, pretrained: bool, | |
self.rnn_hdim = rnn_hdim | ||
self.v_loss = v_loss | ||
|
||
assert model_name in ['PAC_Net', 'P_Net', 'C_Net', 'baseline'] | ||
assert model_name in {'PAC_Net', 'P_Net', 'C_Net', 'baseline'} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
# CNN | ||
self.backbone_builder = { | ||
'PAC_Net': tvmodels.resnet18, | ||
|
@@ -366,9 +366,7 @@ def warm_up(self, I: Tensor): | |
fx = self.warmup_encoder(rearrange(I, 'b c t h w -> (b t) c h w')) | ||
fx = rearrange(fx, '(b t) d -> b t d', b=B) | ||
|
||
hx = self.warmup_rnn(fx)[1] # (2, B, D) | ||
|
||
return hx | ||
return self.warmup_rnn(fx)[1] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
This removes the following comments ( why? ):
|
||
|
||
|
||
class NLOS_baseline(PAC_Net_Base): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,7 +14,7 @@ def main(cfg): | |
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1' | ||
|
||
world_size = len(dist_cfgs['device_ids'].split(',')) | ||
dist_cfgs['distributed'] = True if world_size > 1 else False | ||
dist_cfgs['distributed'] = world_size > 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
dist_cfgs['world_size'] = world_size | ||
cfg['loader_kwargs']['batch_size'] = cfg['train_configs']['batch_size'] // world_size | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,7 +48,7 @@ def generate_route(self, | |
self.route_length = route_length | ||
|
||
self._init_pv() | ||
for step in range(route_length): | ||
for _ in range(route_length): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
# print(self.velocity) | ||
self.next_step(turn_rate=turn_rate) | ||
self.e_route.append(self.e_position.copy()) | ||
|
@@ -116,8 +116,8 @@ def load_route(self, | |
mat_path = os.path.join(save_dir, mat_name) | ||
save_dict = loadmat(mat_path) | ||
|
||
self.e_route = [p for p in save_dict['route']] | ||
self.velocities = [v for v in save_dict['velocities']] | ||
self.e_route = list(save_dict['route']) | ||
self.velocities = list(save_dict['velocities']) | ||
Comment on lines
-119
to
+120
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
print(f'Load data from {mat_path} successfully!') | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -192,20 +192,20 @@ class Trainer_tracking(Trainer_Base): | |
def __init__(self, cfg): | ||
super(Trainer_tracking, self).__init__(cfg=cfg) | ||
|
||
log_train_cfg = { | ||
"model_name": self.model_name, | ||
**self.model_cfgs, | ||
"batch_size": self.train_cfgs['batch_size'], | ||
"v_loss_alpha": self.train_cfgs['v_loss_alpha'], | ||
"loss_total_alpha": self.train_cfgs['loss_total_alpha'], | ||
"resume": self.train_cfgs['resume'], | ||
"route_len": self.dataset_cfgs['route_len'], | ||
"noise_factor": self.dataset_cfgs['noise_factor'], | ||
**self.optim_kwargs, | ||
"epochs": self.schedule_cfgs['max_epoch'], | ||
} | ||
|
||
if self.dist_cfgs['local_rank'] == 0: | ||
Comment on lines
-195
to
208
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
log_train_cfg = { | ||
"model_name": self.model_name, | ||
**self.model_cfgs, | ||
"batch_size": self.train_cfgs['batch_size'], | ||
"v_loss_alpha": self.train_cfgs['v_loss_alpha'], | ||
"loss_total_alpha": self.train_cfgs['loss_total_alpha'], | ||
"resume": self.train_cfgs['resume'], | ||
"route_len": self.dataset_cfgs['route_len'], | ||
"noise_factor": self.dataset_cfgs['noise_factor'], | ||
**self.optim_kwargs, | ||
"epochs": self.schedule_cfgs['max_epoch'], | ||
} | ||
|
||
self._init_recorder(log_train_cfg) | ||
|
||
self.val_metrics = {'x_loss': 0.0, | ||
|
@@ -278,14 +278,14 @@ def run(self): | |
'Metric/val/dtw': val_metric[2], | ||
}, step=epoch + 1) | ||
if self.epoch % 5 == 0: | ||
logger.info(f'Logging images...') | ||
logger.info('Logging images...') | ||
self.test_plot(epoch=self.epoch, phase='train') | ||
self.test_plot(epoch=self.epoch, phase='val') | ||
|
||
self.scheduler.step() | ||
|
||
if ((epoch + 1) % self.log_cfgs['save_epoch_interval'] == 0) \ | ||
or (epoch + 1) == self.schedule_cfgs['max_epoch']: | ||
or (epoch + 1) == self.schedule_cfgs['max_epoch']: | ||
Comment on lines
-281
to
+288
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
checkpoint_path = os.path.join(self.ckpt_dir, f"epoch_{(epoch + 1)}.pth") | ||
self.save_checkpoint(checkpoint_path) | ||
|
||
|
@@ -315,7 +315,7 @@ def train(self, epoch): | |
dynamic_ncols=True, | ||
ascii=(platform.version() == 'Windows')) | ||
|
||
for step in range(len_loader): | ||
for _ in range(len_loader): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
try: | ||
inputs, labels, map_sizes = next(iter_loader) | ||
except Exception as e: | ||
|
@@ -387,7 +387,7 @@ def train(self, epoch): | |
pbar.close() | ||
|
||
return (x_loss_recorder.avg, v_loss_recorder.avg), \ | ||
(pcm_recorder.avg, area_recorder.avg, dtw_recorder.avg) | ||
(pcm_recorder.avg, area_recorder.avg, dtw_recorder.avg) | ||
|
||
def val(self, epoch): | ||
self.model.eval() | ||
|
@@ -482,13 +482,15 @@ def val(self, epoch): | |
metrics = [self.val_metrics[name] for name in names] | ||
res_table.add_row([f"{m:.4}" if type(m) is float else m for m in metrics[:-1]] + [metrics[-1]]) | ||
|
||
logger.info(f'Performance on validation set at epoch: {epoch + 1}\n' + res_table.get_string()) | ||
logger.info( | ||
f'Performance on validation set at epoch: {epoch + 1}\n{res_table.get_string()}' | ||
) | ||
|
||
return (self.val_metrics['x_loss'], self.val_metrics['v_loss']), \ | ||
(self.val_metrics['pcm'], self.val_metrics['area'], self.val_metrics['dtw']) | ||
(self.val_metrics['pcm'], self.val_metrics['area'], self.val_metrics['dtw']) | ||
Comment on lines
-485
to
+490
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
def test_plot(self, epoch, phase: str): | ||
assert phase in ['train', 'val'] | ||
assert phase in {'train', 'val'} | ||
Comment on lines
-491
to
+493
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
self.model.eval() | ||
iter_loader = iter(self.val_loader) if phase == 'val' else iter(self.train_loader) | ||
frames, gt_routes, map_sizes = next(iter_loader) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,7 +23,7 @@ def draw_route(map_size: ndarray, route: ndarray, | |
lc.set_array(idxs) | ||
lc.set_linewidth(3) | ||
line = ax.add_collection(lc) | ||
fig.colorbar(line, ax=ax, ticks=idxs[::int(len(idxs) / 10)], label='step') | ||
fig.colorbar(line, ax=ax, ticks=idxs[::len(idxs) // 10], label='step') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
ax.set_xlim(0, map_size[0]) | ||
ax.set_xlabel('x') | ||
|
@@ -40,7 +40,7 @@ def draw_route(map_size: ndarray, route: ndarray, | |
|
||
|
||
def draw_routes(routes: tuple[ndarray, ndarray], return_mode: str = None): | ||
assert return_mode in ['plt_fig', 'fig_array', None] | ||
assert return_mode in {'plt_fig', 'fig_array', None} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
titles = ('GT', 'pred') | ||
cmaps = ('viridis', 'plasma') | ||
|
||
|
@@ -59,7 +59,13 @@ def draw_routes(routes: tuple[ndarray, ndarray], return_mode: str = None): | |
lc.set_array(idxs) | ||
lc.set_linewidth(3) | ||
line = axes[i].add_collection(lc) | ||
fig.colorbar(line, ax=axes[i], ticks=idxs[::int(len(idxs) / 10)], label='step', fraction=0.05) | ||
fig.colorbar( | ||
line, | ||
ax=axes[i], | ||
ticks=idxs[:: len(idxs) // 10], | ||
label='step', | ||
fraction=0.05, | ||
) | ||
|
||
axes[i].set_title(titles[i]) | ||
axes[i].set_xlim(0, 1) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function
split_dataset
refactored with the following changes:switch
)