Skip to content

Commit

Permalink
print best value
Browse files Browse the repository at this point in the history
  • Loading branch information
Roger Waleffe authored and Roger Waleffe committed Nov 7, 2023
1 parent c783807 commit 11100ea
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 60 deletions.
22 changes: 12 additions & 10 deletions src/cpp/include/reporting/reporting.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ class Metric {
public:
std::string name_;
std::string unit_;
double best_val_;
double best_test_;

virtual ~Metric(){};
};

class RankingMetric : public Metric {
public:
virtual torch::Tensor computeMetric(torch::Tensor ranks) = 0;
virtual torch::Tensor computeMetric(torch::Tensor ranks, bool val = false) = 0;
};

class HitskMetric : public RankingMetric {
Expand All @@ -26,33 +28,33 @@ class HitskMetric : public RankingMetric {
public:
HitskMetric(int k);

torch::Tensor computeMetric(torch::Tensor ranks);
torch::Tensor computeMetric(torch::Tensor ranks, bool val = false);
};

class MeanRankMetric : public RankingMetric {
public:
MeanRankMetric();

torch::Tensor computeMetric(torch::Tensor ranks);
torch::Tensor computeMetric(torch::Tensor ranks, bool val = false);
};

class MeanReciprocalRankMetric : public RankingMetric {
public:
MeanReciprocalRankMetric();

torch::Tensor computeMetric(torch::Tensor ranks);
torch::Tensor computeMetric(torch::Tensor ranks, bool val = false);
};

class ClassificationMetric : public Metric {
public:
virtual torch::Tensor computeMetric(torch::Tensor y_true, torch::Tensor y_pred) = 0;
virtual torch::Tensor computeMetric(torch::Tensor y_true, torch::Tensor y_pred, bool val = false) = 0;
};

class CategoricalAccuracyMetric : public ClassificationMetric {
public:
CategoricalAccuracyMetric();

torch::Tensor computeMetric(torch::Tensor y_true, torch::Tensor y_pred) override;
torch::Tensor computeMetric(torch::Tensor y_true, torch::Tensor y_pred, bool val = false) override;
};

class Reporter {
Expand All @@ -72,7 +74,7 @@ class Reporter {

void addMetric(shared_ptr<Metric> metric) { metrics_.emplace_back(metric); }

virtual void report() = 0;
virtual void report(bool val = false) = 0;
};

class LinkPredictionReporter : public Reporter {
Expand All @@ -94,7 +96,7 @@ class LinkPredictionReporter : public Reporter {

void addResult(torch::Tensor pos_scores, torch::Tensor neg_scores, torch::Tensor edges = torch::Tensor());

void report() override;
void report(bool val = false) override;

void save(string directory, bool scores, bool ranks);
};
Expand All @@ -116,7 +118,7 @@ class NodeClassificationReporter : public Reporter {

void addResult(torch::Tensor y_true, torch::Tensor y_pred, torch::Tensor node_ids = torch::Tensor());

void report() override;
void report(bool val = false) override;

void save(string directory, bool labels);
};
Expand All @@ -140,7 +142,7 @@ class ProgressReporter : public Reporter {

void addResult(int64_t items_processed, double loss = 0.0);

void report() override;
void report(bool val = false) override;
};

#endif // MARIUS_SRC_CPP_INCLUDE_REPORTING_H_
4 changes: 2 additions & 2 deletions src/cpp/src/pipeline/evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ void PipelineEvaluator::evaluate(bool validation) {
pipeline_->model_->distNotifyCompleteAndWait(true);

if (dataloader_->batch_worker_)
pipeline_->model_->reporter_->report();
pipeline_->model_->reporter_->report(validation);

int64_t epoch_time = timer.getDuration();
SPDLOG_INFO("Evaluation complete: {}ms", epoch_time);
Expand Down Expand Up @@ -113,5 +113,5 @@ void SynchronousEvaluator::evaluate(bool validation) {

model_->distNotifyCompleteAndWait(true);

model_->reporter_->report();
model_->reporter_->report(validation);
}
62 changes: 50 additions & 12 deletions src/cpp/src/reporting/reporting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,65 @@ HitskMetric::HitskMetric(int k) {
k_ = k;
name_ = "Hits@" + std::to_string(k_);
unit_ = "";
best_val_ = 0;
best_test_ = 0;
}

torch::Tensor HitskMetric::computeMetric(torch::Tensor ranks) { return torch::tensor((double)ranks.le(k_).nonzero().size(0) / ranks.size(0), torch::kFloat64); }
torch::Tensor HitskMetric::computeMetric(torch::Tensor ranks, bool val) {
return torch::tensor((double)ranks.le(k_).nonzero().size(0) / ranks.size(0), torch::kFloat64);
}

MeanRankMetric::MeanRankMetric() {
name_ = "Mean Rank";
unit_ = "";
best_val_ = 0;
best_test_ = 0;
}

torch::Tensor MeanRankMetric::computeMetric(torch::Tensor ranks) { return ranks.to(torch::kFloat64).mean(); }
torch::Tensor MeanRankMetric::computeMetric(torch::Tensor ranks, bool val) {
return ranks.to(torch::kFloat64).mean();
}

MeanReciprocalRankMetric::MeanReciprocalRankMetric() {
name_ = "MRR";
unit_ = "";
best_val_ = 0;
best_test_ = 0;
}

torch::Tensor MeanReciprocalRankMetric::computeMetric(torch::Tensor ranks) { return ranks.to(torch::kFloat32).reciprocal().mean(); }
torch::Tensor MeanReciprocalRankMetric::computeMetric(torch::Tensor ranks, bool val) {
torch::Tensor result = ranks.to(torch::kFloat32).reciprocal().mean();
if (val) {
if (result.item<double>() > best_val_) {
best_val_ = result.item<double>();
}
} else {
if (result.item<double>() > best_test_) {
best_test_ = result.item<double>();
}
}
return result;
}

CategoricalAccuracyMetric::CategoricalAccuracyMetric() {
name_ = "Accuracy";
unit_ = "%";
best_val_ = 0;
best_test_ = 0;
}

torch::Tensor CategoricalAccuracyMetric::computeMetric(torch::Tensor y_true, torch::Tensor y_pred) {
return 100 * torch::tensor({(double)(y_true == y_pred).nonzero().size(0) / y_true.size(0)}, torch::kFloat64);
torch::Tensor CategoricalAccuracyMetric::computeMetric(torch::Tensor y_true, torch::Tensor y_pred, bool val) {
torch::Tensor result = 100 * torch::tensor({(double)(y_true == y_pred).nonzero().size(0) / y_true.size(0)}, torch::kFloat64);
if (val) {
if (result.item<double>() > best_val_) {
best_val_ = result.item<double>();
}
} else {
if (result.item<double>() > best_test_) {
best_test_ = result.item<double>();
}
}
return result;
}

Reporter::~Reporter() { delete lock_; }
Expand Down Expand Up @@ -69,7 +103,7 @@ void LinkPredictionReporter::addResult(torch::Tensor pos_scores, torch::Tensor n
unlock();
}

void LinkPredictionReporter::report() {
void LinkPredictionReporter::report(bool val) {
all_ranks_ = torch::cat(per_batch_ranks_).to(torch::kCPU);
if (per_batch_scores_.size() > 0) {
all_scores_ = torch::cat(per_batch_scores_);
Expand All @@ -83,8 +117,10 @@ void LinkPredictionReporter::report() {

std::string tmp;
for (auto m : metrics_) {
torch::Tensor result = std::dynamic_pointer_cast<RankingMetric>(m)->computeMetric(all_ranks_);
tmp = m->name_ + ": " + std::to_string(result.item<double>()) + m->unit_ + "\n";
torch::Tensor result = std::dynamic_pointer_cast<RankingMetric>(m)->computeMetric(all_ranks_, val);
if (val)
tmp = m->name_ + ": " + std::to_string(result.item<double>()) + m->unit_ + "; Best: " + std::to_string(m->best_val_) + "\n";
else tmp = m->name_ + ": " + std::to_string(result.item<double>()) + m->unit_ + "; Best: " + std::to_string(m->best_test_) + "\n";
report_string = report_string + tmp;
}
std::string footer = "=================================";
Expand Down Expand Up @@ -201,7 +237,7 @@ void NodeClassificationReporter::addResult(torch::Tensor y_true, torch::Tensor y
unlock();
}

void NodeClassificationReporter::report() {
void NodeClassificationReporter::report(bool val) {
all_y_true_ = torch::cat(per_batch_y_true_);
all_y_pred_ = torch::cat(per_batch_y_pred_);
per_batch_y_true_ = {};
Expand All @@ -213,8 +249,10 @@ void NodeClassificationReporter::report() {

std::string tmp;
for (auto m : metrics_) {
torch::Tensor result = std::dynamic_pointer_cast<ClassificationMetric>(m)->computeMetric(all_y_true_, all_y_pred_);
tmp = m->name_ + ": " + std::to_string(result.item<double>()) + m->unit_ + "\n";
torch::Tensor result = std::dynamic_pointer_cast<ClassificationMetric>(m)->computeMetric(all_y_true_, all_y_pred_, val);
if (val)
tmp = m->name_ + ": " + std::to_string(result.item<double>()) + m->unit_ + "; Best: " + std::to_string(m->best_val_) + "\n";
else tmp = m->name_ + ": " + std::to_string(result.item<double>()) + m->unit_ + "; Best: " + std::to_string(m->best_test_) + "\n";
report_string = report_string + tmp;
}
std::string footer = "=================================";
Expand Down Expand Up @@ -321,7 +359,7 @@ void ProgressReporter::addResult(int64_t items_processed, double loss) {
unlock();
}

void ProgressReporter::report() {
void ProgressReporter::report(bool val) {
// std::string report_string = item_name_ + " processed: [" + std::to_string(current_item_) + "/" + std::to_string(total_items_) + "], " +
// fmt::format("{:.2f}", 100 * (double)current_item_ / total_items_) + "%";
std::string report_string = item_name_ + " processed: [" + std::to_string(current_item_) + "/" + std::to_string(total_items_) + "], " +
Expand Down
66 changes: 33 additions & 33 deletions src/python/tools/preprocess/datasets/ogb_wikikg90mv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,41 +74,41 @@ def preprocess(

dataset_stats = converter.convert()

node_features = np.load(self.input_node_feature_file).astype(np.float32)
rel_features = np.load(self.input_rel_feature_file).astype(np.float32)

if remap_ids:
node_mapping = np.genfromtxt(self.output_directory / Path(PathConstants.node_mapping_path), delimiter=",")
random_node_map = node_mapping[:, 1].astype(np.int32)
random_node_map_argsort = np.argsort(random_node_map)

with open(self.node_features_file, "wb") as f:
chunk_size = int(1e7)
num_chunks = int(np.ceil(node_mapping.shape[0] / chunk_size))

offset = 0

for chunk_id in range(num_chunks):
if offset + chunk_size >= node_mapping.shape[0]:
chunk_size = node_mapping.shape[0] - offset
f.write(bytes(node_features[random_node_map_argsort[offset : offset + chunk_size]]))

rel_mapping = np.genfromtxt(
self.output_directory / Path(PathConstants.relation_mapping_path), delimiter=","
)
random_rel_map = rel_mapping[:, 1].astype(np.int32)
random_rel_map_argsort = np.argsort(random_rel_map)
rel_features = rel_features[random_rel_map_argsort]
else:
with open(self.node_features_file, "wb") as f:
f.write(bytes(node_features))

with open(self.relation_features_file, "wb") as f:
f.write(bytes(rel_features))
# node_features = np.load(self.input_node_feature_file).astype(np.float32)
# rel_features = np.load(self.input_rel_feature_file).astype(np.float32)
#
# if remap_ids:
# node_mapping = np.genfromtxt(self.output_directory / Path(PathConstants.node_mapping_path), delimiter=",")
# random_node_map = node_mapping[:, 1].astype(np.int32)
# random_node_map_argsort = np.argsort(random_node_map)
#
# with open(self.node_features_file, "wb") as f:
# chunk_size = int(1e7)
# num_chunks = int(np.ceil(node_mapping.shape[0] / chunk_size))
#
# offset = 0
#
# for chunk_id in range(num_chunks):
# if offset + chunk_size >= node_mapping.shape[0]:
# chunk_size = node_mapping.shape[0] - offset
# f.write(bytes(node_features[random_node_map_argsort[offset : offset + chunk_size]]))
#
# rel_mapping = np.genfromtxt(
# self.output_directory / Path(PathConstants.relation_mapping_path), delimiter=","
# )
# random_rel_map = rel_mapping[:, 1].astype(np.int32)
# random_rel_map_argsort = np.argsort(random_rel_map)
# rel_features = rel_features[random_rel_map_argsort]
# else:
# with open(self.node_features_file, "wb") as f:
# f.write(bytes(node_features))
#
# with open(self.relation_features_file, "wb") as f:
# f.write(bytes(rel_features))

# update dataset yaml
dataset_stats.node_feature_dim = node_features.shape[1]
dataset_stats.rel_feature_dim = rel_features.shape[1]
# dataset_stats.node_feature_dim = node_features.shape[1]
# dataset_stats.rel_feature_dim = rel_features.shape[1]

with open(self.output_directory / Path("dataset.yaml"), "w") as f:
yaml_file = OmegaConf.to_yaml(dataset_stats)
Expand Down
11 changes: 8 additions & 3 deletions src/python/tools/preprocess/metis_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def load(self):
raise Exception()

def metis_partition(self):
from partitioning_helpers import relabel_edges, pymetis_partitioning, add_missing_nodes, balance_parts, create_edge_buckets
from partitioning_helpers import relabel_edges, pymetis_partitioning, add_missing_nodes, balance_parts, create_edge_buckets#, tree_partitioning

# partition based on the train_edges
edges = self.train_edges.numpy()
Expand All @@ -79,7 +79,12 @@ def metis_partition(self):
edges, unique_nodes, node_mapping = relabel_edges(edges, self.num_nodes, return_map=True)
num_unique = unique_nodes.shape[0]

parts = pymetis_partitioning(self.num_partitions, num_unique, edges, 0)
parts = pymetis_partitioning(self.num_partitions, num_unique, edges)#, True, False)
# import time
# t1 = time.time()
# parts = tree_partitioning(self.num_partitions, num_unique, edges)
# print("time: ", time.time() - t1)

parts = add_missing_nodes(parts, self.num_nodes)
parts = balance_parts(parts, np.ceil(self.num_nodes/self.num_partitions), None)
edge_bucket_sizes, _, _ = create_edge_buckets(edges, parts, 0, plot=False)
Expand Down Expand Up @@ -107,7 +112,7 @@ def metis_partition(self):
self.train_edges = self.train_edges[indices]

src_splits = torch.searchsorted(self.train_edges[:, 0].contiguous(),
np.ceil(self.num_nodes/self.num_partitions) * torch.arange(self.num_partitions))
int(np.ceil(self.num_nodes/self.num_partitions)) * torch.arange(self.num_partitions, dtype=torch.int32))
for ii in range(self.num_partitions): # src partition index
end_index = self.train_edges.shape[0] if ii == self.num_partitions - 1 else src_splits[ii+1]

Expand Down

0 comments on commit 11100ea

Please sign in to comment.