Skip to content

Commit

Permalink
Merge branch 'quantization' of https://github.com/ITMO-NSS-team/torch_DE_solver into quantization
Browse files Browse the repository at this point in the history
  • Loading branch information
nikiniki1 committed Oct 18, 2023
2 parents d7392e9 + 28bd31d commit 769801d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
18 changes: 13 additions & 5 deletions tedeous/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def grid_format_prepare(coord_list, mode='NN') -> torch.Tensor:


class Plots():
def __init__(self, model, grid, mode, tol = 0):
def __init__(self, model, grid, mode, tol=0):
self.model = model
self.grid = grid
self.mode = mode
Expand All @@ -66,6 +66,8 @@ def print_nn(self, title: str):
nvars_model = self.model[-1].out_features
except:
nvars_model = self.model.model[-1].out_features
# else:
# nvars_model = self.model[-2].out_features

nparams = self.grid.shape[1]
fig = plt.figure(figsize=(15, 8))
Expand Down Expand Up @@ -287,6 +289,11 @@ def solve(self,

Cache_class.change_cache_dir(cache_dir)

device = device_type()

Cache_class = Model_prepare(self.grid, self.equal_cls,
self.model, self.mode, self.weak_form)

# prepare input data to uniform format
r = create_random_fn(model_randomize_parameter)

Expand All @@ -301,7 +308,7 @@ def solve(self,
cache_verbose,
model_randomize_parameter,
cache_model,
return_normalized_loss=normalized_loss_stop)
return_normalized_loss=normalized_loss_stop)

Solution_class = Solution(self.grid, self.equal_cls,
self.model, self.mode, self.weak_form,
Expand All @@ -313,6 +320,7 @@ def solve(self,

min_loss , _ = Solution_class.evaluate()


self.plot = Plots(self.model, self.grid, self.mode, tol)

optimizer = self.optimizer_choice(optimizer_mode, learning_rate)
Expand Down Expand Up @@ -409,14 +417,14 @@ def solve(self,
solution_save=step_plot_save,
save_dir=image_save_dir)
stop_dings += 1

# print('t',t)
if print_every != None and (t % print_every == 0) and verbose:
print('[{}] Print every {} step'.format(
datetime.datetime.now(), print_every))
print(info_string)
if inverse_parameters is not None:
print(self.str_param(inverse_parameters))

# print('loss', closure().item(), 'loss_norm', cur_loss)
if step_plot_print or step_plot_save:
self.plot.solution_print(title='Iteration = ' + str(t),
solution_print=step_plot_print,
Expand All @@ -433,4 +441,4 @@ def solve(self,
Cache_class.save_model(self.model, self.model.state_dict(),
optimizer.state_dict(),
name=name)
return self.model
return self.model
4 changes: 2 additions & 2 deletions tedeous/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ def update(self, op_length: list,

return lambda_op, lambda_bnd


class PadTransform(Module):
"""Pad tensor to a fixed length with given padding value.
Expand Down Expand Up @@ -237,10 +236,11 @@ def closure_mixed_cpu(self):
def do(self, normalized_loss_stop):
if self.mixed_precision:
loss, loss_normalized = self.closure_mixed_cpu()
print(f'Mixed precision enabled. The device is {self.device}')
else:
loss, loss_normalized = self.closure()

if normalized_loss_stop:
return loss_normalized
else:
return loss
return loss

0 comments on commit 769801d

Please sign in to comment.