[Bugfix] Fix data memory #317

Merged 2 commits on Dec 9, 2021.

cogdl/trainer/trainer.py (26 changes: 16 additions & 10 deletions)

@@ -187,6 +187,12 @@ def run(self, model_w: ModelWrapper, dataset_w: DataWrapper):
             return best_model_w.model

         final_test = self.evaluate(best_model_w, dataset_w)
+
+        # clear the GPU memory
+        dataset = dataset_w.get_dataset()
+        if isinstance(dataset.data, Graph):
+            dataset.data.to("cpu")
+
         return final_test

     def evaluate(self, model_w: ModelWrapper, dataset_w: DataWrapper, cpu=False):
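
The substantive change is the block above: after the final evaluation, run moves the dataset's Graph back to host memory so its CUDA tensors can be reclaimed. As a hedged illustration of the same idea in plain PyTorch (the helper name below is hypothetical, and empty_cache() is an optional extra that is not part of this diff):

    import torch

    def release_gpu_memory(dataset):
        # Dropping the last device-side references lets the CUDA caching
        # allocator reuse the memory; the diff relies on cogdl's Graph.to
        # moving its tensors off the GPU in the same way.
        dataset.data = dataset.data.to("cpu")
        # Optionally hand cached blocks back to the driver.
        torch.cuda.empty_cache()
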
@@ -244,16 +250,16 @@ def build_optimizer(self, model_w):
         opt_wrap = model_w.setup_optimizer()
         if isinstance(opt_wrap, list) or isinstance(opt_wrap, tuple):
             assert len(opt_wrap) == 2
-            optimizers, lr_schedulars = opt_wrap
+            optimizers, lr_schedulers = opt_wrap
         else:
             optimizers = opt_wrap
-            lr_schedulars = None
+            lr_schedulers = None

         if not isinstance(optimizers, list):
             optimizers = [optimizers]
-        if lr_schedulars and not isinstance(lr_schedulars, list):
-            lr_schedulars = [lr_schedulars]
-        return optimizers, lr_schedulars
+        if lr_schedulers and not isinstance(lr_schedulers, list):
+            lr_schedulers = [lr_schedulers]
+        return optimizers, lr_schedulers

     def initialize(self, model_w, rank=0, master_addr: str = "localhost", master_port: int = 10008):
         if self.distributed_training:
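
For context, build_optimizer accepts either a bare optimizer or an (optimizers, lr_schedulers) pair from setup_optimizer and normalizes both sides into lists. A minimal sketch of the pair-returning shape, assuming a toy wrapper (the class below is illustrative, not cogdl's API):

    import torch

    class ToyModelWrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(16, 2)

        def setup_optimizer(self):
            # Return an (optimizer, scheduler) pair; build_optimizer wraps
            # each side into a one-element list before training starts.
            optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50)
            return optimizer, scheduler
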
@@ -274,7 +280,7 @@ def train(self, rank, model_w, dataset_w):
         self.data_controller.prepare_data_wrapper(dataset_w, rank)
         self.eval_data_back_to_cpu = dataset_w.data_back_to_cpu

-        optimizers, lr_schedulars = self.build_optimizer(model_w)
+        optimizers, lr_schedulers = self.build_optimizer(model_w)
         if optimizers[0] is None:
             return

@@ -315,7 +321,7 @@ def train(self, rank, model_w, dataset_w):
             # inductive setting ..
             dataset_w.train()
             train_loader = dataset_w.on_train_wrapper()
-            training_loss = self.training_step(model_w, train_loader, optimizers, lr_schedulars, rank)
+            training_loss = self.training_step(model_w, train_loader, optimizers, lr_schedulers, rank)

            print_str_dict["Epoch"] = epoch
            print_str_dict["train_loss"] = training_loss
@@ -428,7 +434,7 @@ def distributed_test(self, model_w: ModelWrapper, loader, rank, fn):
         dist.broadcast_object_list(object_list, src=0)
         return object_list[0]

-    def training_step(self, model_w, train_loader, optimizers, lr_schedulars, device):
+    def training_step(self, model_w, train_loader, optimizers, lr_schedulers, device):
         model_w.train()
         losses = []

@@ -449,8 +455,8 @@ def training_step(self, model_w, train_loader, optimizers, lr_schedulars, device):
                optimizer.step()

            losses.append(loss.item())
-           if lr_schedulars is not None:
-               for lr_schedular in lr_schedulars:
+           if lr_schedulers is not None:
+               for lr_schedular in lr_schedulers:
                    lr_schedular.step()
        return np.mean(losses)
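
One detail worth noting in training_step: the schedulers advance once per batch, not once per epoch. A self-contained toy loop showing the same pattern in plain PyTorch (the model and data are stand-ins):

    import torch
    import torch.nn.functional as F

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    for step in range(3):
        x, y = torch.randn(8, 4), torch.randn(8, 1)
        loss = F.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()  # stepped every batch, as in training_step above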