-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathcollate.py
41 lines (34 loc) · 1.22 KB
/
collate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from typing import Any, List, Optional, Sequence
from ase.atoms import Atoms
from torch_geometric.loader.dataloader import Collater
from sevenn.atom_graph_data import AtomGraphData
from .dataload import atoms_to_graph
class AtomsToGraphCollater(Collater):
    """Collate function that turns raw structures into one batched graph.

    Each item of an incoming batch is converted with ``atoms_to_graph``
    (using this collater's ``cutoff``, ``transfer_info`` and
    ``y_from_calc``), wrapped with ``AtomGraphData.from_numpy_dict``, and the
    resulting list is handed to the parent PyG ``Collater`` for batching.

    Args:
        dataset: The structures this collater serves. Stored as
            ``self.dataset``; not read by the methods of this class.
        cutoff: Passed positionally to ``atoms_to_graph`` for every item.
        transfer_info: Forwarded to ``atoms_to_graph`` as ``transfer_info``.
        follow_batch: Forwarded to the parent ``Collater``.
        exclude_keys: Forwarded to the parent ``Collater``.
        y_from_calc: Forwarded to ``atoms_to_graph`` as ``y_from_calc``.
    """

    def __init__(
        self,
        dataset: Sequence[Atoms],
        cutoff: float,
        transfer_info: bool = False,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        y_from_calc: bool = True,
    ):
        # An empty list stands in for the parent's dataset argument to quiet
        # the original Collater's type mismatch (it does not expect a
        # sequence of Atoms); the real dataset is attached further below.
        super().__init__([], follow_batch, exclude_keys)
        self.cutoff = cutoff
        self.transfer_info = transfer_info
        self.y_from_calc = y_from_calc
        # Must come after super().__init__, which also assigns self.dataset.
        self.dataset = dataset

    def __call__(self, batch: List[Any]) -> Any:
        """Convert every structure in ``batch`` to a graph and collate them.

        Args:
            batch: Items to convert; each is passed to ``atoms_to_graph``
                (presumably ``ase.Atoms`` instances — confirm with callers).

        Returns:
            Whatever the parent ``Collater.__call__`` produces for a list of
            ``AtomGraphData`` objects.
        """
        options = dict(
            transfer_info=self.transfer_info,
            y_from_calc=self.y_from_calc,
        )
        graphs = [
            AtomGraphData.from_numpy_dict(
                atoms_to_graph(atoms, self.cutoff, **options)
            )
            for atoms in batch
        ]
        return super().__call__(graphs)