Skip to content

Commit

Permalink
[GraphBolt] use torch.load/save instead of load/save_fused_xxx() (#6707)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rhett-Ying authored Dec 7, 2023
1 parent 0348ad3 commit c213444
Show file tree
Hide file tree
Showing 9 changed files with 54 additions and 127 deletions.
32 changes: 11 additions & 21 deletions graphbolt/include/graphbolt/serialize.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,32 @@ namespace torch {

/**
* @brief Overload input stream operator for FusedCSCSamplingGraph
* deserialization.
* deserialization. This enables `torch::load()` for FusedCSCSamplingGraph.
*
* @param archive Input stream for deserializing.
* @param graph FusedCSCSamplingGraph.
*
* @return archive
*
* @code
* auto&& graph = c10::make_intrusive<sampling::FusedCSCSamplingGraph>();
* torch::load(*graph, filename);
*/
inline serialize::InputArchive& operator>>(
serialize::InputArchive& archive,
graphbolt::sampling::FusedCSCSamplingGraph& graph);

/**
* @brief Overload output stream operator for FusedCSCSamplingGraph
* serialization.
* serialization. This enables `torch::save()` for FusedCSCSamplingGraph.
* @param archive Output stream for serializing.
* @param graph FusedCSCSamplingGraph.
*
* @return archive
*
* @code
* auto&& graph = c10::make_intrusive<sampling::FusedCSCSamplingGraph>();
* torch::save(*graph, filename);
*/
inline serialize::OutputArchive& operator<<(
serialize::OutputArchive& archive,
Expand All @@ -47,25 +56,6 @@ inline serialize::OutputArchive& operator<<(

namespace graphbolt {

/**
* @brief Load FusedCSCSamplingGraph from file.
* @param filename File name to read.
*
* @return FusedCSCSamplingGraph.
*/
c10::intrusive_ptr<sampling::FusedCSCSamplingGraph> LoadFusedCSCSamplingGraph(
const std::string& filename);

/**
* @brief Save FusedCSCSamplingGraph to file.
* @param graph FusedCSCSamplingGraph to save.
* @param filename File name to save.
*
*/
void SaveFusedCSCSamplingGraph(
c10::intrusive_ptr<sampling::FusedCSCSamplingGraph> graph,
const std::string& filename);

/**
* @brief Read data from archive.
* @param archive Input archive.
Expand Down
2 changes: 0 additions & 2 deletions graphbolt/src/python_binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ TORCH_LIBRARY(graphbolt, m) {
return g;
});
m.def("from_fused_csc", &FusedCSCSamplingGraph::FromCSC);
m.def("load_fused_csc_sampling_graph", &LoadFusedCSCSamplingGraph);
m.def("save_fused_csc_sampling_graph", &SaveFusedCSCSamplingGraph);
m.def(
"load_from_shared_memory", &FusedCSCSamplingGraph::LoadFromSharedMemory);
m.def("unique_and_compact", &UniqueAndCompact);
Expand Down
13 changes: 0 additions & 13 deletions graphbolt/src/serialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,6 @@ serialize::OutputArchive& operator<<(

namespace graphbolt {

c10::intrusive_ptr<sampling::FusedCSCSamplingGraph> LoadFusedCSCSamplingGraph(
const std::string& filename) {
auto&& graph = c10::make_intrusive<sampling::FusedCSCSamplingGraph>();
torch::load(*graph, filename);
return graph;
}

void SaveFusedCSCSamplingGraph(
c10::intrusive_ptr<sampling::FusedCSCSamplingGraph> graph,
const std::string& filename) {
torch::save(*graph, filename);
}

torch::IValue read_from_archive(
torch::serialize::InputArchive& archive, const std::string& key) {
torch::IValue data;
Expand Down
7 changes: 5 additions & 2 deletions python/dgl/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import numpy as np

import torch

from .. import backend as F
from ..base import DGLError, EID, ETYPE, NID, NTYPE
from ..convert import to_homogeneous
Expand Down Expand Up @@ -1236,6 +1238,7 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
part_config : str
The partition configuration JSON file.
"""

# As only this function requires GraphBolt for now, let's import here.
from .. import graphbolt

Expand Down Expand Up @@ -1279,6 +1282,6 @@ def init_type_per_edge(graph, gpb):
part_meta[f"part-{part_id}"]["part_graph"],
)
csc_graph_path = os.path.join(
os.path.dirname(orig_graph_path), "fused_csc_sampling_graph.tar"
os.path.dirname(orig_graph_path), "fused_csc_sampling_graph.pt"
)
graphbolt.save_fused_csc_sampling_graph(csc_graph, csc_graph_path)
torch.save(csc_graph, csc_graph_path)
52 changes: 7 additions & 45 deletions python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
"""CSC format sampling graph."""
# pylint: disable= invalid-name
import os
import tarfile
import tempfile
from collections import defaultdict
from typing import Dict, Optional, Union

Expand All @@ -27,8 +24,6 @@
"FusedCSCSamplingGraph",
"from_fused_csc",
"load_from_shared_memory",
"load_fused_csc_sampling_graph",
"save_fused_csc_sampling_graph",
"from_dglgraph",
]

Expand Down Expand Up @@ -99,11 +94,11 @@ def __repr__(self):
return _csc_sampling_graph_str(self)

def __init__(
self, c_csc_graph: torch.ScriptObject, metadata: Optional[GraphMetadata]
self,
c_csc_graph: torch.ScriptObject,
):
super().__init__()
self._c_csc_graph = c_csc_graph
self._metadata = metadata

@property
def total_num_nodes(self) -> int:
Expand Down Expand Up @@ -318,12 +313,16 @@ def edge_attributes(
def metadata(self) -> Optional[GraphMetadata]:
"""Returns the metadata of the graph.
[TODO][Rui] This API needs to be updated.
Returns
-------
GraphMetadata or None
If present, returns the metadata of the graph.
"""
return self._metadata
if self.node_type_to_id is None or self.edge_type_to_id is None:
return None
return GraphMetadata(self.node_type_to_id, self.edge_type_to_id)

def in_subgraph(
self, nodes: Union[torch.Tensor, Dict[str, torch.Tensor]]
Expand Down Expand Up @@ -884,7 +883,6 @@ def copy_to_shared_memory(self, shared_memory_name: str):
"""
return FusedCSCSamplingGraph(
self._c_csc_graph.copy_to_shared_memory(shared_memory_name),
self._metadata,
)

def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
Expand Down Expand Up @@ -975,13 +973,11 @@ def from_fused_csc(
edge_type_to_id,
edge_attributes,
),
metadata,
)


def load_from_shared_memory(
shared_memory_name: str,
metadata: Optional[GraphMetadata] = None,
) -> FusedCSCSamplingGraph:
"""Load a FusedCSCSamplingGraph object from shared memory.
Expand All @@ -997,7 +993,6 @@ def load_from_shared_memory(
"""
return FusedCSCSamplingGraph(
torch.ops.graphbolt.load_from_shared_memory(shared_memory_name),
metadata,
)


Expand Down Expand Up @@ -1033,38 +1028,6 @@ def _add_indent(_str, indent):
return final_str


def load_fused_csc_sampling_graph(filename):
"""Load FusedCSCSamplingGraph from tar file."""
with tempfile.TemporaryDirectory() as temp_dir:
with tarfile.open(filename, "r") as archive:
archive.extractall(temp_dir)
graph_filename = os.path.join(temp_dir, "fused_csc_sampling_graph.pt")
metadata_filename = os.path.join(temp_dir, "metadata.pt")
return FusedCSCSamplingGraph(
torch.ops.graphbolt.load_fused_csc_sampling_graph(graph_filename),
torch.load(metadata_filename),
)


def save_fused_csc_sampling_graph(graph, filename):
"""Save FusedCSCSamplingGraph to tar file."""
with tempfile.TemporaryDirectory() as temp_dir:
graph_filename = os.path.join(temp_dir, "fused_csc_sampling_graph.pt")
torch.ops.graphbolt.save_fused_csc_sampling_graph(
graph._c_csc_graph, graph_filename
)
metadata_filename = os.path.join(temp_dir, "metadata.pt")
torch.save(graph.metadata, metadata_filename)
with tarfile.open(filename, "w") as archive:
archive.add(
graph_filename, arcname=os.path.basename(graph_filename)
)
archive.add(
metadata_filename, arcname=os.path.basename(metadata_filename)
)
print(f"FusedCSCSamplingGraph has been saved to {filename}.")


def from_dglgraph(
g: DGLGraph,
is_homogeneous: bool = False,
Expand Down Expand Up @@ -1114,5 +1077,4 @@ def from_dglgraph(
edge_type_to_id,
edge_attributes,
),
metadata,
)
13 changes: 4 additions & 9 deletions python/dgl/graphbolt/impl/ondisk_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@
from ..internal import copy_or_convert_data, read_data
from ..itemset import ItemSet, ItemSetDict
from ..sampling_graph import SamplingGraph
from .fused_csc_sampling_graph import (
from_dglgraph,
FusedCSCSamplingGraph,
load_fused_csc_sampling_graph,
save_fused_csc_sampling_graph,
)
from .fused_csc_sampling_graph import from_dglgraph, FusedCSCSamplingGraph
from .ondisk_metadata import (
OnDiskGraphTopology,
OnDiskMetaData,
Expand Down Expand Up @@ -147,10 +142,10 @@ def preprocess_ondisk_dataset(
output_config["graph_topology"] = {}
output_config["graph_topology"]["type"] = "FusedCSCSamplingGraph"
output_config["graph_topology"]["path"] = os.path.join(
processed_dir_prefix, "fused_csc_sampling_graph.tar"
processed_dir_prefix, "fused_csc_sampling_graph.pt"
)

save_fused_csc_sampling_graph(
torch.save(
fused_csc_sampling_graph,
os.path.join(
dataset_dir,
Expand Down Expand Up @@ -452,7 +447,7 @@ def _load_graph(
if graph_topology is None:
return None
if graph_topology.type == "FusedCSCSamplingGraph":
return load_fused_csc_sampling_graph(graph_topology.path)
return torch.load(graph_topology.path)
raise NotImplementedError(
f"Graph topology type {graph_topology.type} is not supported."
)
Expand Down
8 changes: 4 additions & 4 deletions tests/distributed/test_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,9 +695,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_homo(
orig_g = dgl.load_graphs(
os.path.join(test_dir, f"part{part_id}/graph.dgl")
)[0][0]
new_g = dgl.graphbolt.load_fused_csc_sampling_graph(
new_g = th.load(
os.path.join(
test_dir, f"part{part_id}/fused_csc_sampling_graph.tar"
test_dir, f"part{part_id}/fused_csc_sampling_graph.pt"
)
)
orig_indptr, orig_indices, _ = orig_g.adj().csc()
Expand Down Expand Up @@ -728,9 +728,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_hetero(
orig_g = dgl.load_graphs(
os.path.join(test_dir, f"part{part_id}/graph.dgl")
)[0][0]
new_g = dgl.graphbolt.load_fused_csc_sampling_graph(
new_g = th.load(
os.path.join(
test_dir, f"part{part_id}/fused_csc_sampling_graph.tar"
test_dir, f"part{part_id}/fused_csc_sampling_graph.pt"
)
)
orig_indptr, orig_indices, _ = orig_g.adj().csc()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,9 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges):
graph = gb.from_fused_csc(csc_indptr, indices)

with tempfile.TemporaryDirectory() as test_dir:
filename = os.path.join(test_dir, "fused_csc_sampling_graph.tar")
gb.save_fused_csc_sampling_graph(graph, filename)
graph2 = gb.load_fused_csc_sampling_graph(filename)
filename = os.path.join(test_dir, "fused_csc_sampling_graph.pt")
torch.save(graph, filename)
graph2 = torch.load(filename)

assert graph.total_num_nodes == graph2.total_num_nodes
assert graph.total_num_edges == graph2.total_num_edges
Expand Down Expand Up @@ -338,9 +338,9 @@ def test_load_save_hetero_graph(
)

with tempfile.TemporaryDirectory() as test_dir:
filename = os.path.join(test_dir, "fused_csc_sampling_graph.tar")
gb.save_fused_csc_sampling_graph(graph, filename)
graph2 = gb.load_fused_csc_sampling_graph(filename)
filename = os.path.join(test_dir, "fused_csc_sampling_graph.pt")
torch.save(graph, filename)
graph2 = torch.load(filename)

assert graph.total_num_nodes == graph2.total_num_nodes
assert graph.total_num_edges == graph2.total_num_edges
Expand Down Expand Up @@ -1103,7 +1103,7 @@ def test_homo_graph_on_shared_memory(

shm_name = "test_homo_g"
graph1 = graph.copy_to_shared_memory(shm_name)
graph2 = gb.load_from_shared_memory(shm_name, graph.metadata)
graph2 = gb.load_from_shared_memory(shm_name)

assert graph1.total_num_nodes == total_num_nodes
assert graph1.total_num_nodes == total_num_nodes
Expand Down Expand Up @@ -1181,7 +1181,7 @@ def test_hetero_graph_on_shared_memory(

shm_name = "test_hetero_g"
graph1 = graph.copy_to_shared_memory(shm_name)
graph2 = gb.load_from_shared_memory(shm_name, graph.metadata)
graph2 = gb.load_from_shared_memory(shm_name)

assert graph1.total_num_nodes == total_num_nodes
assert graph1.total_num_nodes == total_num_nodes
Expand Down
Loading

0 comments on commit c213444

Please sign in to comment.