diff --git a/.gitignore b/.gitignore index 703a40179..93101057b 100644 --- a/.gitignore +++ b/.gitignore @@ -93,3 +93,7 @@ Local # .DS_Store .DS_Store + +data/atom_init_rev.json + +data/readme.txt diff --git a/README.md b/README.md index f39c16507..6ac2fd602 100644 --- a/README.md +++ b/README.md @@ -48,25 +48,33 @@ Please cite the following work if you want to use CGCNN. This package requires: -- [PyTorch](http://pytorch.org) -- [scikit-learn](http://scikit-learn.org/stable/) -- [pymatgen](http://pymatgen.org) +- [PyTorch](http://pytorch.org) (tested on v.1.4.0) +- [PyTorch Scatter](https://github.com/rusty1s/pytorch_scatter) +- [scikit-learn](http://scikit-learn.org/stable/) (tested on v.0.22.1) +- [pymatgen](http://pymatgen.org) (tested on v.2020.3.13) -If you are new to Python, the easiest way of installing the prerequisites is via [conda](https://conda.io/docs/index.html). After installing [conda](http://conda.pydata.org/), run the following command to create a new [environment](https://conda.io/docs/user-guide/tasks/manage-environments.html) named `cgcnn` and install all prerequisites: +If you are new to Python, the easiest way of installing the prerequisites is via [conda](https://conda.io/docs/index.html) and `pip`. After installing [conda](http://conda.pydata.org/), run the following command to create a new [environment](https://conda.io/docs/user-guide/tasks/manage-environments.html) named `cgcnn` and install all prerequisites: ```bash -conda upgrade conda -conda create -n cgcnn python=3 scikit-learn pytorch torchvision pymatgen -c pytorch -c conda-forge +conda create -n cgcnn python=3 +source activate cgcnn +conda install pytorch torchvision cudatoolkit=10.1 -c pytorch +pip install torch-scatter==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.4.0.html +conda install scikit-learn +pip install pymatgen ``` - -*Note: this code is tested for PyTorch v1.0.0+ and is not compatible with versions below v0.4.0 due to some breaking changes. - This creates a conda environment for running CGCNN. Before using CGCNN, activate the environment by: ```bash source activate cgcnn ``` +Finally, once inside the environment, install `torch-scatter` if the pinned CUDA wheel above was not installed: + +```bash +pip install torch-scatter +``` + Then, in directory `cgcnn`, you can test if all the prerequisites are installed properly by running: ```bash @@ -189,7 +197,7 @@ To reproduce our paper, you can download the corresponding datasets following th ## Authors -This software was primarily written by [Tian Xie](http://txie.me) who was advised by [Prof. Jeffrey Grossman](https://dmse.mit.edu/faculty/profile/grossman). +This software was primarily written by [Tian Xie](http://txie.me) and [Prof. Jeffrey Grossman](https://dmse.mit.edu/faculty/profile/grossman). This version contains [changes made by Rhys Goodall](https://github.com/txie-93/cgcnn/pull/16) to remove zero-padding. This version also contains several changes made by [Andrew S. Rosen](https://asrosen.com/) to enhance the usability of the CGCNN code, including but not limited to: new command-line arguments, saving the CIF data as .pkl files so they don't have to be regenerated on the fly, and [the option](https://github.com/txie-93/cgcnn/pull/17) to use crystal graphs generated by Pymatgen. If using the `--enable-tanh` flag, you should consider referencing the work of [Noh et al.](https://pubs.acs.org/doi/10.1021/acs.jcim.0c00003), who called this approach CGCNN-H.
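As a quick check that the pinned stack above is wired up correctly (a minimal sketch; the versions printed are the tested ones from the list above, not hard requirements):

```python
# Sanity-check the CGCNN dependency stack after activating the cgcnn environment.
import torch
import torch_scatter  # an ImportError here usually means a torch/CUDA wheel mismatch
import sklearn
import pymatgen

print(torch.__version__)          # tested on 1.4.0
print(sklearn.__version__)        # tested on 0.22.1
print(pymatgen.__version__)       # tested on 2020.3.13
print(torch.cuda.is_available())  # True if the cudatoolkit=10.1 build sees a GPU
```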
## License diff --git a/cgcnn/data.py b/cgcnn/data.py index dfd216734..2fb885671 100644 --- a/cgcnn/data.py +++ b/cgcnn/data.py @@ -1,15 +1,16 @@ -from __future__ import print_function, division - import csv import functools import json import os import random import warnings +import pickle import numpy as np import torch +from pymatgen.analysis import local_env from pymatgen.core.structure import Structure +from pymatgen.analysis.graphs import StructureGraph from torch.utils.data import Dataset, DataLoader from torch.utils.data.dataloader import default_collate from torch.utils.data.sampler import SubsetRandomSampler @@ -21,9 +22,7 @@ def get_train_val_test_loader(dataset, collate_fn=default_collate, num_workers=1, pin_memory=False, **kwargs): """ Utility function for dividing a dataset into train, val, and test datasets. - !!! The dataset needs to be shuffled before using the function !!! - Parameters ---------- dataset: torch.utils.data.Dataset @@ -38,7 +37,6 @@ data will be hidden. num_workers: int pin_memory: bool - Returns ------- train_loader: torch.utils.data.DataLoader @@ -97,23 +95,18 @@ """ Collate a list of data and return a batch for predicting crystal properties. - Parameters ---------- - dataset_list: list of tuples for each data point. (atom_fea, nbr_fea, self_fea_idx, nbr_fea_idx, target), where n_i is the number of atoms and e_i the number of edges in crystal i - atom_fea: torch.Tensor shape (n_i, atom_fea_len) nbr_fea: torch.Tensor shape (e_i, nbr_fea_len) self_fea_idx: torch.LongTensor shape (e_i,) nbr_fea_idx: torch.LongTensor shape (e_i,) target: torch.Tensor shape (1, ) cif_id: str or int - Returns ------- N = sum(n_i): total number of atoms in the batch; E = sum(e_i): total number of edges; N0: number of crystals in the batch - batch_atom_fea: torch.Tensor shape (N, orig_atom_fea_len) Atom features from atom type batch_nbr_fea: torch.Tensor shape (E, nbr_fea_len) @@ -126,25 +119,31 @@ def collate_pool(dataset_list): Target value for prediction batch_cif_ids: list """ - batch_atom_fea, batch_nbr_fea, batch_nbr_fea_idx = [], [], [] + batch_atom_fea, batch_nbr_fea = [], [] + batch_self_fea_idx, batch_nbr_fea_idx = [], [] crystal_atom_idx, batch_target = [], [] batch_cif_ids = [] base_idx = 0 - for i, ((atom_fea, nbr_fea, nbr_fea_idx), target, cif_id)\ + for i, ((atom_fea, nbr_fea, self_fea_idx, nbr_fea_idx), target, cif_id)\ in enumerate(dataset_list): + n_i = atom_fea.shape[0] # number of atoms for this crystal + batch_atom_fea.append(atom_fea) batch_nbr_fea.append(nbr_fea) + batch_self_fea_idx.append(self_fea_idx+base_idx) batch_nbr_fea_idx.append(nbr_fea_idx+base_idx) - new_idx = torch.LongTensor(np.arange(n_i)+base_idx) - crystal_atom_idx.append(new_idx) + + crystal_atom_idx.extend([i]*n_i) batch_target.append(target) batch_cif_ids.append(cif_id) base_idx += n_i + return (torch.cat(batch_atom_fea, dim=0), torch.cat(batch_nbr_fea, dim=0), + torch.cat(batch_self_fea_idx, dim=0), torch.cat(batch_nbr_fea_idx, dim=0), - crystal_atom_idx),\ + torch.LongTensor(crystal_atom_idx)),\ torch.stack(batch_target, dim=0),\ batch_cif_ids @@ -152,14 +151,13 @@ class GaussianDistance(object): """ Expands the distance by Gaussian basis.
- Unit: angstrom """ + def __init__(self, dmin, dmax, step, var=None): """ Parameters ---------- - dmin: float Minimum interatomic distance dmax: float @@ -176,14 +174,11 @@ def __init__(self, dmin, dmax, step, var=None): def expand(self, distances): """ - Apply Gaussian disntance filter to a numpy distance array - + Apply Gaussian distance filter to a numpy distance array Parameters ---------- - distances: np.array shape n-d array A distance matrix of any shape - Returns ------- expanded_distance: shape (n+1)-d array @@ -197,9 +192,9 @@ def expand(self, distances): class AtomInitializer(object): """ Base class for initializing the vector representation for atoms. - !!! Use one AtomInitializer per dataset !!! """ + def __init__(self, atom_types): self.atom_types = set(atom_types) self._embedding = {} @@ -229,13 +224,12 @@ class AtomCustomJSONInitializer(AtomInitializer): Initialize atom feature vectors using a JSON file, which is a python dictionary mapping from element number to a list representing the feature vector of the element. - Parameters ---------- - elem_embedding_file: str The path to the .json file """ + def __init__(self, elem_embedding_file): with open(elem_embedding_file) as f: elem_embedding = json.load(f) @@ -252,65 +246,101 @@ class CIFData(Dataset): The CIFData dataset is a wrapper for a dataset where the crystal structures are stored in the form of CIF files. The dataset should have the following directory structure: - root_dir ├── id_prop.csv ├── atom_init.json ├── id0.cif ├── id1.cif ├── ... - id_prop.csv: a CSV file with two columns. The first column records a unique ID for each crystal, and the second column records the value of the target property. - atom_init.json: a JSON file that stores the initialization vector for each element. - ID.cif: a CIF file that records the crystal structure, where ID is the unique ID for the crystal. - Parameters ---------- - root_dir: str The path to the root directory of the dataset max_num_nbr: int The maximum number of neighbors while constructing the crystal graph radius: float - The cutoff radius for searching neighbors + The (maximum) cutoff radius for searching neighbors + nn_method: str + The name of the pymatgen.analysis.local_env.NearNeighbors + algorithm (e.g. 'crystalnn') used to construct + a pymatgen.analysis.graphs.StructureGraph dmin: float The minimum distance for constructing GaussianDistance step: float The step size for constructing GaussianDistance + disable_save_torch: bool + If True, don't cache the constructed crystal graphs as .pkl files random_seed: int Random seed for shuffling the dataset - Returns ------- - atom_fea: torch.Tensor shape (n_i, atom_fea_len) nbr_fea: torch.Tensor shape (e_i, nbr_fea_len) self_fea_idx: torch.LongTensor shape (e_i,) nbr_fea_idx: torch.LongTensor shape (e_i,) target: torch.Tensor shape (1, ) cif_id: str or int """ - def __init__(self, root_dir, max_num_nbr=12, radius=8, dmin=0, step=0.2, - random_seed=123): + + def __init__(self, root_dir, max_num_nbr=12, radius=8, nn_method=None, + dmin=0, step=0.2, disable_save_torch=False, random_seed=123): self.root_dir = root_dir - self.max_num_nbr, self.radius = max_num_nbr, radius + self.max_num_nbr, self.radius, self.nn_method = max_num_nbr, radius, nn_method + self.disable_save_torch = disable_save_torch assert os.path.exists(root_dir), 'root_dir does not exist!' id_prop_file = os.path.join(self.root_dir, 'id_prop.csv') assert os.path.exists(id_prop_file), 'id_prop.csv does not exist!'
with open(id_prop_file) as f: reader = csv.reader(f) - self.id_prop_data = [row for row in reader] + self.id_prop_data = [[x.strip().replace('\ufeff', '') + for x in row] for row in reader] random.seed(random_seed) random.shuffle(self.id_prop_data) atom_init_file = os.path.join(self.root_dir, 'atom_init.json') assert os.path.exists(atom_init_file), 'atom_init.json does not exist!' self.ari = AtomCustomJSONInitializer(atom_init_file) self.gdf = GaussianDistance(dmin=dmin, dmax=self.radius, step=step) + self.torch_data_path = os.path.join(self.root_dir, 'cifdata') + if self.nn_method: + if self.nn_method.lower() == 'minimumvirenn': + self.nn_object = local_env.MinimumVIRENN() + elif self.nn_method.lower() == 'voronoinn': + self.nn_object = local_env.VoronoiNN() + elif self.nn_method.lower() == 'jmolnn': + self.nn_object = local_env.JmolNN() + elif self.nn_method.lower() == 'minimumdistancenn': + self.nn_object = local_env.MinimumDistanceNN() + elif self.nn_method.lower() == 'minimumokeeffenn': + self.nn_object = local_env.MinimumOKeeffeNN() + elif self.nn_method.lower() == 'brunnernn_real': + self.nn_object = local_env.BrunnerNN_real() + elif self.nn_method.lower() == 'brunnernn_reciprocal': + self.nn_object = local_env.BrunnerNN_reciprocal() + elif self.nn_method.lower() == 'brunnernn_relative': + self.nn_object = local_env.BrunnerNN_relative() + elif self.nn_method.lower() == 'econnn': + self.nn_object = local_env.EconNN() + elif self.nn_method.lower() == 'cutoffdictnn': + # requires a cutoff dictionary located in cgcnn/cut_off_dict.txt + self.nn_object = local_env.CutOffDictNN( + cut_off_dict='cut_off_dict.txt') + elif self.nn_method.lower() == 'critic2nn': + self.nn_object = local_env.Critic2NN() + elif self.nn_method.lower() == 'openbabelnn': + self.nn_object = local_env.OpenBabelNN() + elif self.nn_method.lower() == 'covalentbondnn': + self.nn_object = local_env.CovalentBondNN() + elif self.nn_method.lower() == 'crystalnn': + self.nn_object = local_env.CrystalNN() + else: + raise ValueError('Invalid NN algorithm specified') + else: + self.nn_object = None def __len__(self): return len(self.id_prop_data) @@ -318,33 +348,69 @@ def __len__(self): @functools.lru_cache(maxsize=None) # Cache loaded structures def __getitem__(self, idx): cif_id, target = self.id_prop_data[idx] - crystal = Structure.from_file(os.path.join(self.root_dir, - cif_id+'.cif')) - atom_fea = np.vstack([self.ari.get_atom_fea(crystal[i].specie.number) - for i in range(len(crystal))]) - atom_fea = torch.Tensor(atom_fea) - all_nbrs = crystal.get_all_neighbors(self.radius, include_index=True) - all_nbrs = [sorted(nbrs, key=lambda x: x[1]) for nbrs in all_nbrs] - nbr_fea_idx, nbr_fea = [], [] - for nbr in all_nbrs: - if len(nbr) < self.max_num_nbr: - warnings.warn('{} not find enough neighbors to build graph. ' - 'If it happens frequently, consider increase ' - 'radius.'.format(cif_id)) - nbr_fea_idx.append(list(map(lambda x: x[2], nbr)) + - [0] * (self.max_num_nbr - len(nbr))) - nbr_fea.append(list(map(lambda x: x[1], nbr)) + - [self.radius + 1.] 
* (self.max_num_nbr - - len(nbr))) + cif_id = cif_id.replace('', '') + + target = torch.Tensor([float(target)]) + + if os.path.exists(os.path.join(self.torch_data_path, cif_id+'.pkl')): + with open(os.path.join(self.torch_data_path, cif_id+'.pkl'), 'rb') as f: + pkl_data = pickle.load(f) + atom_fea = pkl_data[0] + nbr_fea = pkl_data[1] + self_fea_idx = pkl_data[2] + nbr_fea_idx = pkl_data[3] + + else: + crystal = Structure.from_file( + os.path.join(self.root_dir, cif_id+'.cif')) + # atom features + atom_fea = np.vstack([self.ari.get_atom_fea(crystal[i].specie.number) + for i in range(len(crystal))]) + + self_fea_idx, nbr_fea_idx, nbr_fea = [], [], [] + if self.nn_object: + graph = StructureGraph.with_local_env_strategy( + crystal, self.nn_object) + all_nbrs = [] + dist_idx = -1 + for i in range(len(crystal)): + nbr = graph.get_connected_sites(i) + nbr = [nbrs for nbrs in nbr if nbrs[dist_idx] <= self.radius] + all_nbrs.append(nbr) else: - nbr_fea_idx.append(list(map(lambda x: x[2], + dist_idx = 1 + all_nbrs = crystal.get_all_neighbors( + self.radius, include_index=True) + all_nbrs = [sorted(nbrs, key=lambda x: x[dist_idx]) + for nbrs in all_nbrs] + for i, nbr in enumerate(all_nbrs): + if len(nbr) < self.max_num_nbr: + warnings.warn('{} does not have enough neighbors to build graph. ' + 'If it happens frequently, consider increasing ' + 'radius or decreasing max_num_nbr.'.format(cif_id)) + + nbr_fea_idx.extend(list(map(lambda x: x[2], nbr))) + nbr_fea.extend(list(map(lambda x: x[dist_idx], nbr))) + + else: + nbr_fea_idx.extend(list(map(lambda x: x[2], + nbr[:self.max_num_nbr]))) + nbr_fea.extend(list(map(lambda x: x[dist_idx], nbr[:self.max_num_nbr]))) - nbr_fea.append(list(map(lambda x: x[1], - nbr[:self.max_num_nbr]))) - nbr_fea_idx, nbr_fea = np.array(nbr_fea_idx), np.array(nbr_fea) - nbr_fea = self.gdf.expand(nbr_fea) - atom_fea = torch.Tensor(atom_fea) - nbr_fea = torch.Tensor(nbr_fea) - nbr_fea_idx = torch.LongTensor(nbr_fea_idx) - target = torch.Tensor([float(target)]) - return (atom_fea, nbr_fea, nbr_fea_idx), target, cif_id + + self_fea_idx.extend([i]*min(len(nbr), self.max_num_nbr)) + + nbr_fea = np.array(nbr_fea) + nbr_fea = self.gdf.expand(nbr_fea) + + atom_fea = torch.Tensor(atom_fea) + nbr_fea = torch.Tensor(nbr_fea) + self_fea_idx = torch.LongTensor(self_fea_idx) + nbr_fea_idx = torch.LongTensor(nbr_fea_idx) + + if not self.disable_save_torch: + with open(os.path.join(self.torch_data_path, cif_id+'.pkl'), 'wb') as f: + pickle.dump( + (atom_fea, nbr_fea, self_fea_idx, nbr_fea_idx), f) + + return (atom_fea, nbr_fea, self_fea_idx, nbr_fea_idx), target, cif_id diff --git a/cgcnn/model.py b/cgcnn/model.py index de739532f..d238210b7 100644 --- a/cgcnn/model.py +++ b/cgcnn/model.py @@ -1,14 +1,14 @@ -from __future__ import print_function, division - import torch import torch.nn as nn +from torch_scatter import scatter_mean, scatter_add class ConvLayer(nn.Module): """ Convolutional operation on graphs """ - def __init__(self, atom_fea_len, nbr_fea_len): + + def __init__(self, atom_fea_len, nbr_fea_len, enable_tanh=False): """ Initialize ConvLayer. 
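The collate_pool rewrite above is the core of the zero-padding removal: instead of stacking (N, M) neighbor tensors padded with index 0, each crystal contributes flat self_fea_idx/nbr_fea_idx edge lists that are shifted into the global index space by a running base_idx. A minimal sketch of that bookkeeping with two toy crystals (the feature tensors are placeholders, not real atom features):

```python
import torch

# Crystal A: 2 atoms, each with 1 neighbor; Crystal B: 3 atoms, each with 1 neighbor.
crystals = [
    (torch.zeros(2, 4), torch.LongTensor([0, 1]), torch.LongTensor([1, 0])),
    (torch.zeros(3, 4), torch.LongTensor([0, 1, 2]), torch.LongTensor([1, 2, 0])),
]

batch_self, batch_nbr, crystal_atom_idx, base_idx = [], [], [], 0
for i, (atom_fea, self_idx, nbr_idx) in enumerate(crystals):
    n_i = atom_fea.shape[0]
    batch_self.append(self_idx + base_idx)   # shift edges into the global index space
    batch_nbr.append(nbr_idx + base_idx)
    crystal_atom_idx.extend([i] * n_i)       # which crystal each atom belongs to
    base_idx += n_i

print(torch.cat(batch_self))               # tensor([0, 1, 2, 3, 4])
print(torch.cat(batch_nbr))                # tensor([1, 0, 3, 4, 2])
print(torch.LongTensor(crystal_atom_idx))  # tensor([0, 0, 1, 1, 1])
```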
@@ -26,12 +26,17 @@ def __init__(self, atom_fea_len, nbr_fea_len): self.fc_full = nn.Linear(2*self.atom_fea_len+self.nbr_fea_len, 2*self.atom_fea_len) self.sigmoid = nn.Sigmoid() - self.softplus1 = nn.Softplus() + if enable_tanh: + self.act1 = nn.Tanh() + self.act2 = nn.Tanh() + else: + self.act1 = nn.Softplus() + self.act2 = nn.Softplus() self.bn1 = nn.BatchNorm1d(2*self.atom_fea_len) self.bn2 = nn.BatchNorm1d(self.atom_fea_len) - self.softplus2 = nn.Softplus() + self.pooling = SumPooling() - def forward(self, atom_in_fea, nbr_fea, nbr_fea_idx): + def forward(self, atom_in_fea, nbr_fea, self_fea_idx, nbr_fea_idx): """ Forward pass @@ -55,22 +60,26 @@ def forward(self, atom_in_fea, nbr_fea, nbr_fea_idx): Atom hidden features after convolution """ - # TODO will there be problems with the index zero padding? - N, M = nbr_fea_idx.shape # convolution atom_nbr_fea = atom_in_fea[nbr_fea_idx, :] - total_nbr_fea = torch.cat( - [atom_in_fea.unsqueeze(1).expand(N, M, self.atom_fea_len), - atom_nbr_fea, nbr_fea], dim=2) - total_gated_fea = self.fc_full(total_nbr_fea) - total_gated_fea = self.bn1(total_gated_fea.view( - -1, self.atom_fea_len*2)).view(N, M, self.atom_fea_len*2) - nbr_filter, nbr_core = total_gated_fea.chunk(2, dim=2) - nbr_filter = self.sigmoid(nbr_filter) - nbr_core = self.softplus1(nbr_core) - nbr_sumed = torch.sum(nbr_filter * nbr_core, dim=1) + atom_self_fea = atom_in_fea[self_fea_idx, :] + + total_fea = torch.cat([atom_self_fea, atom_nbr_fea, nbr_fea], dim=1) + + total_fea = self.fc_full(total_fea) + total_fea = self.bn1(total_fea) + + filter_fea, core_fea = total_fea.chunk(2, dim=1) + filter_fea = self.sigmoid(filter_fea) + core_fea = self.act1(core_fea) + + # take the elementwise product of the filter and core + nbr_msg = filter_fea * core_fea + nbr_sumed = self.pooling(nbr_msg, self_fea_idx) + nbr_sumed = self.bn2(nbr_sumed) - out = self.softplus2(atom_in_fea + nbr_sumed) + out = self.act2(atom_in_fea + nbr_sumed) + return out @@ -79,9 +88,10 @@ class CrystalGraphConvNet(nn.Module): Create a crystal graph convolutional neural network for predicting total material properties. """ + def __init__(self, orig_atom_fea_len, nbr_fea_len, atom_fea_len=64, n_conv=3, h_fea_len=128, n_h=1, - classification=False): + classification=False, enable_tanh=False): """ Initialize CrystalGraphConvNet. 
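The rewritten ConvLayer.forward above replaces the old fixed-size (N, M) neighbor sum with a scatter over the flat edge list: every per-edge message is summed back onto the atom it belongs to via self_fea_idx, which is what SumPooling wraps. A small sketch of that aggregation step (toy shapes; scatter_add is the same torch-scatter function imported at the top of model.py):

```python
import torch
from torch_scatter import scatter_add

# 5 edge messages over 3 atoms: edges 0-1 belong to atom 0, 2-3 to atom 1, 4 to atom 2.
nbr_msg = torch.ones(5, 4)
self_fea_idx = torch.LongTensor([0, 0, 1, 1, 2])

# Sum every edge message into the row of the atom it originates from.
nbr_sumed = scatter_add(nbr_msg, self_fea_idx, dim=0)
print(nbr_sumed[:, 0])  # tensor([2., 2., 1.]) -- per-atom neighbor counts for all-ones messages
```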
@@ -100,13 +110,18 @@ def __init__(self, orig_atom_fea_len, nbr_fea_len, Number of hidden features after pooling n_h: int Number of hidden layers after pooling + classification: bool + Whether to perform classification instead of regression + enable_tanh: bool + Use tanh instead of softplus """ super(CrystalGraphConvNet, self).__init__() self.classification = classification self.embedding = nn.Linear(orig_atom_fea_len, atom_fea_len) self.convs = nn.ModuleList([ConvLayer(atom_fea_len=atom_fea_len, - nbr_fea_len=nbr_fea_len) + nbr_fea_len=nbr_fea_len, enable_tanh=enable_tanh) for _ in range(n_conv)]) + self.pooling = MeanPooling() self.conv_to_fc = nn.Linear(atom_fea_len, h_fea_len) self.conv_to_fc_softplus = nn.Softplus() if n_h > 1: @@ -118,11 +133,12 @@ def __init__(self, orig_atom_fea_len, nbr_fea_len, self.fc_out = nn.Linear(h_fea_len, 2) else: self.fc_out = nn.Linear(h_fea_len, 1) + if self.classification: self.logsoftmax = nn.LogSoftmax(dim=1) self.dropout = nn.Dropout() - def forward(self, atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx): + def forward(self, atom_fea, nbr_fea, self_fea_idx, nbr_fea_idx, crystal_atom_idx): """ Forward pass @@ -150,38 +166,60 @@ def forward(self, atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx): """ atom_fea = self.embedding(atom_fea) + for conv_func in self.convs: - atom_fea = conv_func(atom_fea, nbr_fea, nbr_fea_idx) + atom_fea = conv_func(atom_fea, nbr_fea, self_fea_idx, nbr_fea_idx) + crys_fea = self.pooling(atom_fea, crystal_atom_idx) + crys_fea = self.conv_to_fc(self.conv_to_fc_softplus(crys_fea)) crys_fea = self.conv_to_fc_softplus(crys_fea) + if self.classification: crys_fea = self.dropout(crys_fea) + if hasattr(self, 'fcs') and hasattr(self, 'softpluses'): for fc, softplus in zip(self.fcs, self.softpluses): crys_fea = softplus(fc(crys_fea)) + out = self.fc_out(crys_fea) + if self.classification: out = self.logsoftmax(out) return out - def pooling(self, atom_fea, crystal_atom_idx): - """ - Pooling the atom features to crystal features - N: Total number of atoms in the batch - N0: Total number of crystals in the batch +class MeanPooling(nn.Module): + """ + Mean pooling of atom features into crystal features + """ - Parameters - ---------- + def __init__(self): + super(MeanPooling, self).__init__() - atom_fea: Variable(torch.Tensor) shape (N, atom_fea_len) - Atom feature vectors of the batch - crystal_atom_idx: list of torch.LongTensor of length N0 - Mapping from the crystal idx to atom idx - """ - assert sum([len(idx_map) for idx_map in crystal_atom_idx]) ==\ - atom_fea.data.shape[0] - summed_fea = [torch.mean(atom_fea[idx_map], dim=0, keepdim=True) - for idx_map in crystal_atom_idx] - return torch.cat(summed_fea, dim=0) + def forward(self, x, index): + + mean = scatter_mean(x, index, dim=0) + + return mean + + def __repr__(self): + return '{}'.format(self.__class__.__name__) + + +class SumPooling(nn.Module): + """ + Sum pooling of neighbor messages onto atoms + """ + + def __init__(self): + super(SumPooling, self).__init__() + + def forward(self, x, index): + + summed = scatter_add(x, index, dim=0) + + return summed + + def __repr__(self): + return '{}'.format(self.__class__.__name__) diff --git a/main.py b/main.py index 5f3cad321..5b77a973f 100644 --- a/main.py +++ b/main.py @@ -5,6 +5,7 @@ import time import warnings from random import sample +import csv import numpy as np import torch @@ -18,13 +19,14 @@ from cgcnn.data import collate_pool, get_train_val_test_loader from cgcnn.model import CrystalGraphConvNet -parser = argparse.ArgumentParser(description='Crystal Graph Convolutional Neural Networks') +parser =
argparse.ArgumentParser( + description='Crystal Graph Convolutional Neural Networks') parser.add_argument('data_options', metavar='OPTIONS', nargs='+', help='dataset options, started with the path to root dir, ' 'then other options') parser.add_argument('--task', choices=['regression', 'classification'], default='regression', help='complete a regression or ' - 'classification task (default: regression)') + 'classification task (default: regression)') parser.add_argument('--disable-cuda', action='store_true', help='Disable CUDA') parser.add_argument('-j', '--workers', default=0, type=int, metavar='N', @@ -49,21 +51,24 @@ metavar='N', help='print frequency (default: 10)') parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') + train_group = parser.add_mutually_exclusive_group() train_group.add_argument('--train-ratio', default=None, type=float, metavar='N', - help='number of training data to be loaded (default none)') + help='fraction of training data to be loaded (default none)') train_group.add_argument('--train-size', default=None, type=int, metavar='N', help='number of training data to be loaded (default none)') + valid_group = parser.add_mutually_exclusive_group() valid_group.add_argument('--val-ratio', default=0.1, type=float, metavar='N', - help='percentage of validation data to be loaded (default ' + help='percentage of validation data to be loaded (default ' '0.1)') valid_group.add_argument('--val-size', default=None, type=int, metavar='N', help='number of validation data to be loaded (default ' '1000)') + test_group = parser.add_mutually_exclusive_group() test_group.add_argument('--test-ratio', default=0.1, type=float, metavar='N', - help='percentage of test data to be loaded (default 0.1)') + help='percentage of test data to be loaded (default 0.1)') test_group.add_argument('--test-size', default=None, type=int, metavar='N', help='number of test data to be loaded (default 1000)') @@ -78,6 +83,19 @@ parser.add_argument('--n-h', default=1, type=int, metavar='N', help='number of hidden layers after pooling') +parser.add_argument('--max-num-nbr', default=12, type=int, metavar='N', + help='Maximum number of neighbors') +parser.add_argument('--radius', default=8, type=int, metavar='N', + help='Radial distance to search for neighbors') +parser.add_argument('--nn-method', default='', type=str, metavar='N', + help='NN algorithm to search for neighbors (defaults to a plain radius cutoff)') +parser.add_argument('--disable-save-torch', action='store_true', + help='Do not save CIF PyTorch data as .pkl files') +parser.add_argument('--clean-torch', action='store_true', + help='Clean CIF PyTorch data .pkl files') +parser.add_argument('--enable-tanh', action='store_true', + help='Use tanh instead of softplus') + args = parser.parse_args(sys.argv[1:]) args.cuda = not args.disable_cuda and torch.cuda.is_available() @@ -92,7 +110,9 @@ def main(): global args, best_mae_error # load data - dataset = CIFData(*args.data_options) + dataset = CIFData(*args.data_options, max_num_nbr=args.max_num_nbr, + radius=args.radius, nn_method=args.nn_method, + disable_save_torch=args.disable_save_torch) collate_fn = collate_pool train_loader, val_loader, test_loader = get_train_val_test_loader( dataset=dataset, @@ -108,6 +128,55 @@ def main(): test_size=args.test_size, return_test=True) + # Make sure >1 class is present + if args.task == 'classification': + total_train = 0 + n_train = 0 + total_val = 0 + n_val = 0 + total_test = 0 + n_test = 0 + + for i, (_, target, _) in enumerate(train_loader): + for
target_i in target.squeeze(): + total_train += target_i + n_train += 1 + if bool(total_train == 0): + raise ValueError('All 0s in train') + elif bool(total_train == n_train): + raise ValueError('All 1s in train') + + for i, (_, target, _) in enumerate(val_loader): + if len(target) == 1: + raise ValueError('Only single entry in val') + for target_i in target.squeeze(): + total_val += target_i + n_val += 1 + if bool(total_val == 0): + raise ValueError('All 0s in val') + elif bool(total_val == n_val): + raise ValueError('All 1s in val') + + for i, (_, target, _) in enumerate(test_loader): + if len(target) == 1: + raise ValueError('Only single entry in test') + for target_i in target.squeeze(): + total_test += target_i + n_test += 1 + if bool(total_test == 0): + raise ValueError('All 0s in test') + elif bool(total_test == n_test): + raise ValueError('All 1s in test') + + # make output folder if needed + if not os.path.exists('output'): + os.mkdir('output') + + # make and clean torch files if needed + torch_data_path = os.path.join(args.data_options[0], 'cifdata') + if args.clean_torch and os.path.exists(torch_data_path): + shutil.rmtree(torch_data_path) + if os.path.exists(torch_data_path): + if not args.clean_torch: + warnings.warn('Found cifdata folder at ' + + torch_data_path+'. Will read in cached .pkl files as available') + else: + os.mkdir(torch_data_path) + # obtain target value normalizer if args.task == 'classification': normalizer = Normalizer(torch.zeros(2)) @@ -133,7 +202,8 @@ def main(): h_fea_len=args.h_fea_len, n_h=args.n_h, classification=True if args.task == - 'classification' else False) + 'classification' else False, + enable_tanh=args.enable_tanh) if args.cuda: model.cuda() @@ -200,10 +270,20 @@ def main(): }, is_best) # test best model - print('---------Evaluate Model on Test Set---------------') - best_checkpoint = torch.load('model_best.pth.tar') + best_checkpoint = torch.load(os.path.join('output', 'model_best.pth.tar')) model.load_state_dict(best_checkpoint['state_dict']) - validate(test_loader, model, criterion, normalizer, test=True) + + print('---------Evaluate Best Model on Train Set---------------') + validate(train_loader, model, criterion, normalizer, test=True, + csv_name='train_results.csv') + + print('---------Evaluate Best Model on Val Set---------------') + validate(val_loader, model, criterion, normalizer, test=True, + csv_name='val_results.csv') + + print('---------Evaluate Best Model on Test Set---------------') + validate(test_loader, model, criterion, normalizer, test=True, + csv_name='test_results.csv') def train(train_loader, model, criterion, optimizer, epoch, normalizer): @@ -223,20 +303,15 @@ def train(train_loader, model, criterion, optimizer, epoch, normalizer): model.train() end = time.time() - for i, (input, target, _) in enumerate(train_loader): + for i, (input_, target, _) in enumerate(train_loader): # measure data loading time data_time.update(time.time() - end) if args.cuda: - input_var = (Variable(input[0].cuda(non_blocking=True)), - Variable(input[1].cuda(non_blocking=True)), - input[2].cuda(non_blocking=True), - [crys_idx.cuda(non_blocking=True) for crys_idx in input[3]]) + input_var = (tensor.to("cuda") for tensor in input_) else: - input_var = (Variable(input[0]), - Variable(input[1]), - input[2], - input[3]) + input_var = input_ + # normalize target if args.task == 'regression': target_normed = normalizer.norm(target) @@ -282,9 +357,9 @@ def train(train_loader, model, criterion, optimizer, epoch, normalizer): 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'MAE {mae_errors.val:.3f} ({mae_errors.avg:.3f})'.format( - epoch, i, len(train_loader), batch_time=batch_time, - data_time=data_time, loss=losses, mae_errors=mae_errors) - ) + epoch+1, i+1, len(train_loader), batch_time=batch_time, + data_time=data_time, loss=losses, mae_errors=mae_errors) + ) else: print('Epoch: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' @@ -295,14 +370,16 @@ def train(train_loader, model, criterion, optimizer, epoch, normalizer): 'Recall {recall.val:.3f} ({recall.avg:.3f})\t' 'F1 {f1.val:.3f} ({f1.avg:.3f})\t' 'AUC {auc.val:.3f} ({auc.avg:.3f})'.format( - epoch, i, len(train_loader), batch_time=batch_time, - data_time=data_time, loss=losses, accu=accuracies, - prec=precisions, recall=recalls, f1=fscores, - auc=auc_scores) - ) + epoch+1, i+1, len(train_loader), batch_time=batch_time, + data_time=data_time, loss=losses, accu=accuracies, + prec=precisions, recall=recalls, f1=fscores, + auc=auc_scores) + ) + +def validate(val_loader, model, criterion, normalizer, test=False, + csv_name='test_results.csv'): -def validate(val_loader, model, criterion, normalizer, test=False): batch_time = AverageMeter() losses = AverageMeter() if args.task == 'regression': @@ -322,19 +399,15 @@ def validate(val_loader, model, criterion, normalizer, test=False): model.eval() end = time.time() - for i, (input, target, batch_cif_ids) in enumerate(val_loader): + for i, (input_, target, batch_cif_ids) in enumerate(val_loader): if args.cuda: with torch.no_grad(): - input_var = (Variable(input[0].cuda(non_blocking=True)), - Variable(input[1].cuda(non_blocking=True)), - input[2].cuda(non_blocking=True), - [crys_idx.cuda(non_blocking=True) for crys_idx in input[3]]) + input_var = (tensor.to("cuda") for tensor in input_) + else: with torch.no_grad(): - input_var = (Variable(input[0]), - Variable(input[1]), - input[2], - input[3]) + input_var = input_ + if args.task == 'regression': target_normed = normalizer.norm(target) else: @@ -382,16 +455,21 @@ def validate(val_loader, model, criterion, normalizer, test=False): batch_time.update(time.time() - end) end = time.time() + if test: + print_name = 'Test' + else: + print_name = 'Validation' + if i % args.print_freq == 0: if args.task == 'regression': - print('Test: [{0}/{1}]\t' + print(print_name+': [{0}/{1}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'MAE {mae_errors.val:.3f} ({mae_errors.avg:.3f})'.format( - i, len(val_loader), batch_time=batch_time, loss=losses, - mae_errors=mae_errors)) + i+1, len(val_loader), batch_time=batch_time, loss=losses, + mae_errors=mae_errors)) else: - print('Test: [{0}/{1}]\t' + print(print_name+': [{0}/{1}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Accu {accu.val:.3f} ({accu.avg:.3f})\t' @@ -399,14 +477,13 @@ def validate(val_loader, model, criterion, normalizer, test=False): 'Recall {recall.val:.3f} ({recall.avg:.3f})\t' 'F1 {f1.val:.3f} ({f1.avg:.3f})\t' 'AUC {auc.val:.3f} ({auc.avg:.3f})'.format( - i, len(val_loader), batch_time=batch_time, loss=losses, - accu=accuracies, prec=precisions, recall=recalls, - f1=fscores, auc=auc_scores)) + i+1, len(val_loader), batch_time=batch_time, loss=losses, + accu=accuracies, prec=precisions, recall=recalls, + f1=fscores, auc=auc_scores)) if test: star_label = '**' - import csv - with open('test_results.csv', 'w') as f: + with open(os.path.join('output', csv_name), 'w') as f: writer = csv.writer(f) for cif_id, target, pred in zip(test_cif_ids, 
test_targets, test_preds): @@ -495,10 +572,10 @@ def update(self, val, n=1): self.avg = self.sum / self.count -def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): +def save_checkpoint(state, is_best, filename=os.path.join('output', 'checkpoint.pth.tar')): torch.save(state, filename) if is_best: - shutil.copyfile(filename, 'model_best.pth.tar') + shutil.copyfile(filename, os.path.join('output', 'model_best.pth.tar')) def adjust_learning_rate(optimizer, epoch, k): diff --git a/predict.py b/predict.py index 57df29e63..56a2fdd5e 100644 --- a/predict.py +++ b/predict.py @@ -1,8 +1,9 @@ import argparse import os -import shutil import sys import time +import shutil +import warnings import numpy as np import torch @@ -11,21 +12,26 @@ from torch.autograd import Variable from torch.utils.data import DataLoader -from cgcnn.data import CIFData +from cgcnn.data import CIFData, get_train_val_test_loader from cgcnn.data import collate_pool from cgcnn.model import CrystalGraphConvNet -parser = argparse.ArgumentParser(description='Crystal gated neural networks') +parser = argparse.ArgumentParser( + description='Crystal Graph Convolutional Neural Networks') parser.add_argument('modelpath', help='path to the trained model.') parser.add_argument('cifpath', help='path to the directory of CIF files.') -parser.add_argument('-b', '--batch-size', default=256, type=int, - metavar='N', help='mini-batch size (default: 256)') parser.add_argument('-j', '--workers', default=0, type=int, metavar='N', help='number of data loading workers (default: 0)') parser.add_argument('--disable-cuda', action='store_true', help='Disable CUDA') parser.add_argument('--print-freq', '-p', default=10, type=int, metavar='N', help='print frequency (default: 10)') +parser.add_argument('--disable-save-torch', action='store_true', + help='Do not save CIF PyTorch data as .pkl files') +parser.add_argument('--clean-torch', action='store_true', + help='Clean CIF PyTorch data .pkl files') +parser.add_argument('--train-val-test', action='store_true', + help='Return training/validation/testing results') args = parser.parse_args(sys.argv[1:]) if os.path.isfile(args.modelpath): @@ -49,11 +55,40 @@ def main(): global args, model_args, best_mae_error # load data - dataset = CIFData(args.cifpath) + dataset = CIFData(args.cifpath, max_num_nbr=model_args.max_num_nbr, + radius=model_args.radius, nn_method=model_args.nn_method, + disable_save_torch=args.disable_save_torch) collate_fn = collate_pool - test_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, - num_workers=args.workers, collate_fn=collate_fn, - pin_memory=args.cuda) + + if args.train_val_test: + train_loader, val_loader, test_loader = get_train_val_test_loader( + dataset=dataset, + collate_fn=collate_fn, + batch_size=model_args.batch_size, + train_ratio=model_args.train_ratio, + num_workers=args.workers, + val_ratio=model_args.val_ratio, + test_ratio=model_args.test_ratio, + pin_memory=args.cuda, + train_size=model_args.train_size, + val_size=model_args.val_size, + test_size=model_args.test_size, + return_test=True) + else: + test_loader = DataLoader(dataset, batch_size=model_args.batch_size, shuffle=True, + num_workers=args.workers, collate_fn=collate_fn, + pin_memory=args.cuda) + + # make and clean torch files if needed + torch_data_path = os.path.join(args.cifpath, 'cifdata') + if args.clean_torch and os.path.exists(torch_data_path): + shutil.rmtree(torch_data_path) + if os.path.exists(torch_data_path): + if not args.clean_torch: + warnings.warn('Found
cifdata folder at ' + torch_data_path+'. Will read in cached .pkl files as available') + else: + os.mkdir(torch_data_path) # build model structures, _, _ = dataset[0] @@ -65,7 +100,8 @@ def main(): h_fea_len=model_args.h_fea_len, n_h=model_args.n_h, classification=True if model_args.task == - 'classification' else False) + 'classification' else False, + enable_tanh=model_args.enable_tanh) if args.cuda: model.cuda() @@ -74,15 +110,6 @@ def main(): criterion = nn.NLLLoss() else: criterion = nn.MSELoss() - # if args.optim == 'SGD': - # optimizer = optim.SGD(model.parameters(), args.lr, - # momentum=args.momentum, - # weight_decay=args.weight_decay) - # elif args.optim == 'Adam': - # optimizer = optim.Adam(model.parameters(), args.lr, - # weight_decay=args.weight_decay) - # else: - # raise NameError('Only SGD or Adam is allowed as --optim') normalizer = Normalizer(torch.zeros(3)) @@ -99,10 +126,24 @@ def main(): else: print("=> no model found at '{}'".format(args.modelpath)) - validate(test_loader, model, criterion, normalizer, test=True) + if args.train_val_test: + print('---------Evaluate Model on Train Set---------------') + validate(train_loader, model, criterion, normalizer, test=True, + csv_name='train_results.csv') + print('---------Evaluate Model on Val Set---------------') + validate(val_loader, model, criterion, normalizer, test=True, + csv_name='val_results.csv') + print('---------Evaluate Model on Test Set---------------') + validate(test_loader, model, criterion, normalizer, test=True, + csv_name='test_results.csv') + else: + print('---------Evaluate Model on Dataset---------------') + validate(test_loader, model, criterion, normalizer, test=True, + csv_name='predictions.csv') -def validate(val_loader, model, criterion, normalizer, test=False): +def validate(val_loader, model, criterion, normalizer, test=False, + csv_name='test_results.csv'): batch_time = AverageMeter() losses = AverageMeter() if model_args.task == 'regression': @@ -122,18 +163,12 @@ def validate(val_loader, model, criterion, normalizer, test=False): model.eval() end = time.time() - for i, (input, target, batch_cif_ids) in enumerate(val_loader): + for i, (input_, target, batch_cif_ids) in enumerate(val_loader): with torch.no_grad(): if args.cuda: - input_var = (Variable(input[0].cuda(non_blocking=True)), - Variable(input[1].cuda(non_blocking=True)), - input[2].cuda(non_blocking=True), - [crys_idx.cuda(non_blocking=True) for crys_idx in input[3]]) + input_var = (tensor.to("cuda") for tensor in input_) else: - input_var = (Variable(input[0]), - Variable(input[1]), - input[2], - input[3]) + input_var = input_ if model_args.task == 'regression': target_normed = normalizer.norm(target) else: @@ -186,8 +221,8 @@ def validate(val_loader, model, criterion, normalizer, test=False): 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'MAE {mae_errors.val:.3f} ({mae_errors.avg:.3f})'.format( - i, len(val_loader), batch_time=batch_time, loss=losses, - mae_errors=mae_errors)) + i+1, len(val_loader), batch_time=batch_time, loss=losses, + mae_errors=mae_errors)) else: print('Test: [{0}/{1}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Accu {accu.val:.3f} ({accu.avg:.3f})\t' 'Precision {prec.val:.3f} ({prec.avg:.3f})\t' 'Recall {recall.val:.3f} ({recall.avg:.3f})\t' 'F1 {f1.val:.3f} ({f1.avg:.3f})\t' 'AUC {auc.val:.3f} ({auc.avg:.3f})'.format( - i, len(val_loader), batch_time=batch_time, loss=losses, - accu=accuracies, prec=precisions, recall=recalls, - f1=fscores,
auc=auc_scores)) + i+1, len(val_loader), batch_time=batch_time, loss=losses, + accu=accuracies, prec=precisions, recall=recalls, + f1=fscores, auc=auc_scores)) if test: star_label = '**' import csv - with open('test_results.csv', 'w') as f: + if not os.path.exists('output'): + os.mkdir('output') + with open(os.path.join('output', csv_name), 'w') as f: writer = csv.writer(f) for cif_id, target, pred in zip(test_cif_ids, test_targets, test_preds): @@ -223,6 +258,7 @@ def validate(val_loader, model, criterion, normalizer, test=False): class Normalizer(object): """Normalize a Tensor and restore it later. """ + def __init__(self, tensor): """tensor is taken as a sample to calculate the mean and std""" self.mean = torch.mean(tensor) @@ -273,6 +309,7 @@ def class_eval(prediction, target): class AverageMeter(object): """Computes and stores the average and current value""" + def __init__(self): self.reset() @@ -289,11 +326,5 @@ def update(self, val, n=1): self.avg = self.sum / self.count -def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): - torch.save(state, filename) - if is_best: - shutil.copyfile(filename, 'model_best.pth.tar') - - if __name__ == '__main__': main()
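After these changes, both main.py and predict.py write their result CSVs under output/, one (cif_id, target, prediction) row per structure, as produced by the writer.writerow loop in validate(). A minimal post-processing sketch (the file name assumes the default predictions.csv that predict.py writes when --train-val-test is not set):

```python
import csv

# Each row written by validate(): CIF id, target value, model prediction.
with open('output/predictions.csv') as f:
    rows = [(cif_id, float(target), float(pred))
            for cif_id, target, pred in csv.reader(f)]

# Mean absolute error over the predicted structures.
mae = sum(abs(t - p) for _, t, p in rows) / len(rows)
print('MAE over {} structures: {:.3f}'.format(len(rows), mae))
```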