[Bugfix] Fix GAT's NaN (#310)
* Fix NaN (GAT)

* Fix markdown requirements
cenyk1230 authored Nov 19, 2021
1 parent d7d259c commit 7361540
Showing 5 changed files with 9 additions and 4 deletions.

cogdl/experiments.py (1 change: 1 addition & 0 deletions)
@@ -201,6 +201,7 @@ def train(args):  # noqa: C901
         checkpoint_path=args.checkpoint_path,
         resume_training=args.resume_training,
         patience=args.patience,
+        eval_step=args.eval_step,
         logger=args.logger,
         log_path=args.log_path,
         project=args.project,

cogdl/operators/edge_softmax/edge_softmax.cu (4 changes: 2 additions & 2 deletions)
@@ -12,12 +12,12 @@ __global__ void edge_softmax(
     int lb = rowptr[rid];
     int hb = rowptr[(rid + 1)];
     int loop = 1 + (hb - lb) / 32;
-    float weightMax = 0;
+    float weightMax = -1e8;
     float expAll = 0;
     for (int j = 0; j < loop; j++)
     {
         int pid = threadIdx.x + (j << 5) + lb;
-        float weight = 0;
+        float weight = -1e8;
         if(pid < hb)
         {
             weight = values[pid * head + hid];
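
This kernel computes a numerically stable softmax over each node's incident edges: it first reduces the maximum attention logit of the row, then normalizes exp(logit - max). Seeding the running maximum with 0 means the "maximum" can never drop below 0, so when every logit on a row is strongly negative all the exponentials underflow, the normalizer expAll becomes 0, and the division produces NaN, which is the GAT failure this commit targets. Seeding with -1e8 lets the reduction track the true row maximum instead. A minimal NumPy sketch of the same computation (illustrative only, not CogDL's API):

import numpy as np

# CSR layout: values[rowptr[i]:rowptr[i+1]] holds the attention logits of row i
# (a single attention head, for brevity). `init_max` mimics how the kernel seeds
# its running maximum `weightMax`.
def edge_softmax(rowptr, values, init_max):
    out = np.empty_like(values)
    for i in range(len(rowptr) - 1):
        lb, hb = rowptr[i], rowptr[i + 1]
        logits = values[lb:hb]
        w_max = np.float32(init_max)
        if hb > lb:
            w_max = max(w_max, logits.max())
        exp = np.exp(logits - w_max)    # float32, like the kernel
        out[lb:hb] = exp / exp.sum()    # 0 / 0 -> NaN when every term underflows
    return out

rowptr = np.array([0, 3])
# Strongly negative logits, e.g. from large-magnitude attention scores.
values = np.array([-150.0, -160.0, -170.0], dtype=np.float32)

print(edge_softmax(rowptr, values, init_max=0.0))   # [nan nan nan]: max stuck at 0 (NumPy also warns)
print(edge_softmax(rowptr, values, init_max=-1e8))  # ~[1.0, 4.5e-05, 2.1e-09]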

cogdl/options.py (1 change: 1 addition & 0 deletions)
@@ -34,6 +34,7 @@ def get_parser():
     parser.add_argument("--use-best-config", action="store_true", help="use best config")
     parser.add_argument("--unsup", action="store_true")
     parser.add_argument("--nstage", type=int, default=1)
+    parser.add_argument("--eval-step", type=int, default=1)
     parser.add_argument("--n-trials", type=int, default=3)

     parser.add_argument("--devices", default=[0], type=int, nargs="+", help="which GPU to use")
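
The new --eval-step flag defaults to 1, so existing command lines keep validating every epoch; larger values make the trainer skip validation on intermediate epochs, and cogdl/experiments.py (first hunk above) simply forwards the parsed value into the Trainer. A standalone argparse sketch of the flag's behavior (not CogDL's full parser):

import argparse

# Standalone illustration of the new option and its default.
parser = argparse.ArgumentParser()
parser.add_argument("--eval-step", type=int, default=1)  # run validation every N epochs

args = parser.parse_args([])                 # flag omitted
assert args.eval_step == 1                   # default keeps the old validate-every-epoch behavior

args = parser.parse_args(["--eval-step", "5"])
print(args.eval_step)                        # 5: validation only on every 5th epoch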

cogdl/trainer/trainer.py (6 changes: 4 additions & 2 deletions)
@@ -307,8 +307,8 @@ def train(self, rank, model_w, dataset_w):
         epoch_printer = Printer(print, rank=rank, world_size=self.world_size)

         self.logger.start()
+        print_str_dict = dict()
         for epoch in epoch_iter:
-            print_str_dict = dict()
             for hook in self.pre_epoch_hooks:
                 hook(self)

@@ -321,7 +321,7 @@ def train(self, rank, model_w, dataset_w):
             print_str_dict["train_loss"] = training_loss

             val_loader = dataset_w.on_val_wrapper()
-            if val_loader is not None and (epoch % self.eval_step) == 0:
+            if val_loader is not None and epoch % self.eval_step == 0:
                 # inductive setting ..
                 dataset_w.eval()
                 # do validation in inference device
@@ -377,6 +377,7 @@ def validate(self, model_w: ModelWrapper, dataset_w: DataWrapper, device):
        # ------- distributed training ---------

        model_w.eval()
+       dataset_w.eval()
        if self.cpu_inference:
            model_w.to("cpu")
            _device = device
@@ -396,6 +397,7 @@ def test(self, model_w: ModelWrapper, dataset_w: DataWrapper, device):
        # ------- distributed training ---------

        model_w.eval()
+       dataset_w.eval()
        if self.cpu_inference:
            model_w.to("cpu")
            _device = device
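
Two behavioral changes ride along here: print_str_dict now outlives a single epoch, so epochs that skip validation still print the most recent validation metrics instead of an empty dict, and validate()/test() now put the data wrapper into eval mode alongside the model wrapper. A minimal sketch of the gated loop (hypothetical names, not CogDL's Trainer internals):

# The print dict persists across epochs; validation only runs on every
# `eval_step`-th epoch, so skipped epochs still report the last val metric.
def run_training(model, train_one_epoch, validate, max_epoch=10, eval_step=3):
    print_str_dict = dict()                  # initialized once, outside the loop
    for epoch in range(max_epoch):
        print_str_dict["train_loss"] = train_one_epoch(model)
        if epoch % eval_step == 0:           # gate validation on eval_step
            print_str_dict["val_acc"] = validate(model)
        print(f"epoch {epoch}: {print_str_dict}")

# Tiny stand-ins so the sketch runs on its own.
run_training(
    model=None,
    train_one_epoch=lambda m: 0.5,
    validate=lambda m: 0.8,
)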

docs/requirements.txt (1 change: 1 addition & 0 deletions)
@@ -1,5 +1,6 @@
 sphinx==4.2.0
 sphinx_rtd_theme==1.0.0
+markdown==3.3.4
 sphinx-markdown-tables==0.0.15
 recommonmark==0.7.1
 networkx
