From 6973a8d51df4c140ef74476c85945a960c8ac91b Mon Sep 17 00:00:00 2001 From: Nathan Hughes Date: Fri, 19 Jul 2024 19:24:49 +0000 Subject: [PATCH] fix pytorch tests --- python/src/spark_dsg/torch_conversion.py | 26 ++++++++---------------- python/tests/test_torch.py | 14 ++++++------- 2 files changed, 16 insertions(+), 24 deletions(-) diff --git a/python/src/spark_dsg/torch_conversion.py b/python/src/spark_dsg/torch_conversion.py index 12dda39..f730b3a 100644 --- a/python/src/spark_dsg/torch_conversion.py +++ b/python/src/spark_dsg/torch_conversion.py @@ -39,18 +39,14 @@ graphs. Note that `DynamicSceneGraph.to_torch()` calls into the relevant homogeneous or heterogeneous conversion function. """ -from spark_dsg._dsg_bindings import ( - DynamicSceneGraph, - SceneGraphLayer, - SceneGraphNode, - SceneGraphEdge, - DsgLayers, - LayerView, -) -from typing import Callable, Optional, Dict, Union -import numpy as np import importlib +from typing import Callable, Dict, Optional, Union + +import numpy as np +from spark_dsg._dsg_bindings import (DsgLayers, DynamicSceneGraph, LayerView, + SceneGraphEdge, SceneGraphLayer, + SceneGraphNode) NodeConversionFunc = Callable[[DynamicSceneGraph, SceneGraphNode], np.ndarray] EdgeConversionFunc = Callable[[DynamicSceneGraph, SceneGraphEdge], np.ndarray] @@ -136,17 +132,13 @@ def scene_graph_layer_to_torch( scene graph layer. """ torch, torch_geometric = _get_torch() - # output torch tensor data types - if double_precision: - dtype_float = torch.float64 - else: - dtype_float = torch.float32 + dtype_float = torch.float64 if double_precision else torch.float32 N = G.num_nodes() node_features = [] - node_positions = torch.zeros((N, 3), dtype=torch.float64) + node_positions = torch.zeros((N, 3), dtype=dtype_float) id_map = {} for node in G.nodes: @@ -168,7 +160,7 @@ def scene_graph_layer_to_torch( edge_features.append(edge_converter(G, edge)) if edge_converter is not None: - edge_features = torch.tensor(np.array(edge_features), dtype_float) + edge_features = torch.tensor(np.array(edge_features), dtype=dtype_float) if edge_index.size(dim=1) > 0: if edge_converter is None: diff --git a/python/tests/test_torch.py b/python/tests/test_torch.py index 0c27943..e6108ca 100644 --- a/python/tests/test_torch.py +++ b/python/tests/test_torch.py @@ -73,7 +73,7 @@ def _check_interlayer_edges(G, data, to_check, has_edge_attrs=False): edge_name = f"{source}_to_{target}" assert (source, edge_name, target) in metadata[1] assert data[source, edge_name, target].edge_index.size(dim=0) == 2 - assert data[source, edge_name, target].edge_index.size(dim=1) >= 2 + assert data[source, edge_name, target].edge_index.size(dim=1) >= 1 if has_edge_attrs: assert data[source, edge_name, target].edge_attr.size(dim=1) == 20 assert data[source, edge_name, target].edge_attr.size(dim=0) == data[ @@ -86,7 +86,7 @@ def test_torch_layer(resource_dir, has_torch): if not has_torch: return pytest.skip(reason="requires pytorch and pytorch geometric") - G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_igx_dsg.json")) + G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_dsg.json")) places = G.get_layer(dsg.DsgLayers.PLACES) assert places.num_nodes() > 0 assert places.num_edges() > 0 @@ -104,7 +104,7 @@ def test_torch_layer_edge_features(resource_dir, has_torch): if not has_torch: return pytest.skip(reason="requires pytorch and pytorch geometric") - G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_igx_dsg.json")) + G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_dsg.json")) places = G.get_layer(dsg.DsgLayers.PLACES) assert places.num_nodes() > 0 assert places.num_edges() > 0 @@ -122,7 +122,7 @@ def test_torch_homogeneous(resource_dir, has_torch): if not has_torch: return pytest.skip(reason="requires pytorch and pytorch geometric") - G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_igx_dsg.json")) + G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_dsg.json")) assert G.num_nodes() > 0 assert G.num_edges() > 0 @@ -139,7 +139,7 @@ def test_torch_homogeneous_edge_features(resource_dir, has_torch): if not has_torch: return pytest.skip(reason="requires pytorch and pytorch geometric") - G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_igx_dsg.json")) + G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_dsg.json")) assert G.num_nodes() > 0 assert G.num_edges() > 0 @@ -160,7 +160,7 @@ def test_torch_hetereogeneous(resource_dir, has_torch): if not has_torch: return pytest.skip(reason="requires pytorch and pytorch geometric") - G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_igx_dsg.json")) + G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_dsg.json")) assert G.num_nodes() > 0 assert G.num_edges() > 0 @@ -190,7 +190,7 @@ def test_torch_hetereogeneous_edge_features(resource_dir, has_torch): if not has_torch: return pytest.skip(reason="requires pytorch and pytorch geometric") - G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_igx_dsg.json")) + G = dsg.DynamicSceneGraph.load(str(resource_dir / "apartment_dsg.json")) assert G.num_nodes() > 0 assert G.num_edges() > 0