Skip to content

Commit

Permalink
Run pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
AntObi committed Sep 15, 2023
1 parent 9fb67f7 commit 5c6bef6
Showing 1 changed file with 75 additions and 17 deletions.
92 changes: 75 additions & 17 deletions bin/run_matbench_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import pandas as pd
from keras.callbacks import Callback, CSVLogger, EarlyStopping, ModelCheckpoint
from matbench.bench import MatbenchBenchmark
from pymatgen.core import Composition
from pymatgen.analysis.bond_valence import BVAnalyzer
from pymatgen.core import Composition
from sklearn.model_selection import train_test_split

from skipatom import ElemNet, ElemNetClassifier, max_pool, mean_pool, sum_pool
Expand Down Expand Up @@ -45,6 +45,7 @@ def atom_vectors_from_csv(embedding_csv):
dictionary = {e: i for i, e in enumerate(elements)}
return dictionary, embeddings


def species_vectors_from_csv(embedding_csv):
logger.info(f"reading species vectors from {embedding_csv}")
df = pd.read_csv(embedding_csv)
Expand All @@ -63,18 +64,20 @@ def get_composition(val, input_type):
else:
raise Exception(f"unrecognized input type: {input_type}")


def get_composition_species(val, input_type):
if input_type == "composition":
return Composition(val).add_charges_from_oxi_state_guesses()
elif input_type == "structure":
try:
return val.composition.add_charges_from_oxi_state_guesses()
except:
bv= BVAnalyzer()
bv = BVAnalyzer()
return bv.get_oxi_state_decorated_structure(val).composition
else:
raise Exception(f"unrecognized input type: {input_type}")


def featurize(X, input_type, atom_dictionary, atom_embeddings, pool):
X_featurized = []
for val in X.values:
Expand All @@ -86,24 +89,50 @@ def featurize(X, input_type, atom_dictionary, atom_embeddings, pool):
X_featurized.append(pool(composition, atom_dictionary, atom_embeddings))
return np.array(X_featurized)

def featurize_species(X, input_type, species_dictionary, species_embeddings, atom_dictionary, atom_embeddings, pool):

def featurize_species(
X,
input_type,
species_dictionary,
species_embeddings,
atom_dictionary,
atom_embeddings,
pool,
):
X_featurized = []
for val in X.values:
composition = get_composition_species(val, input_type)
#print(composition.reduced_formula)

if any([e.to_pretty_string() not in species_dictionary for e in composition.elements]):
# print(composition.reduced_formula)

if any(
[
e.to_pretty_string() not in species_dictionary
for e in composition.elements
]
):
try:
composition = get_composition(val, input_type)
if any([e.name not in atom_dictionary for e in composition.elements]):
raise Exception(f"{composition.reduced_formula} contains unsupported atoms")
raise Exception(
f"{composition.reduced_formula} contains unsupported atoms"
)
X_featurized.append(pool(composition, atom_dictionary, atom_embeddings))
except:
raise Exception(f"{composition.reduced_formula} contains unsupported species and atoms")
raise Exception(
f"{composition.reduced_formula} contains unsupported species and atoms"
)
else:
X_featurized.append(pool(composition, species_dictionary, species_embeddings, species_mode=True))
X_featurized.append(
pool(
composition,
species_dictionary,
species_embeddings,
species_mode=True,
)
)
return np.array(X_featurized)


if __name__ == "__main__":
mb = MatbenchBenchmark(autoload=False)
suported_tasks = list(mb.tasks_map.keys())
Expand Down Expand Up @@ -141,7 +170,10 @@ def featurize_species(X, input_type, species_dictionary, species_embeddings, ato
"--vectors", required=True, type=str, help="path to the atom vectors .csv file"
)
parser.add_argument(
"--species-vectors", required=False, type=str, help="path to the species vectors .csv file"
"--species-vectors",
required=False,
type=str,
help="path to the species vectors .csv file",
)
parser.add_argument(
"--pooling",
Expand Down Expand Up @@ -245,13 +277,15 @@ def featurize_species(X, input_type, species_dictionary, species_embeddings, ato
pool = max_pool
else:
raise Exception(f"unsupported pooling: {args.pooling}")

if args.species:
atom_dictionary, atom_embeddings = atom_vectors_from_csv(args.vectors)
species_dictionary, species_embeddings = species_vectors_from_csv(args.species_vectors)
species_dictionary, species_embeddings = species_vectors_from_csv(
args.species_vectors
)
else:
atom_dictionary, atom_embeddings = atom_vectors_from_csv(args.vectors)

input_type = task.metadata["input_type"]

logger.info(f"architecture: {architecture}")
Expand All @@ -274,10 +308,28 @@ def featurize_species(X, input_type, species_dictionary, species_embeddings, ato
shuffle=True,
)
if args.species:
X_train = featurize_species(X_train, input_type,species_dictionary, species_embeddings, atom_dictionary, atom_embeddings, pool)
X_val = featurize_species(X_val, input_type, species_dictionary, species_embeddings, atom_dictionary, atom_embeddings, pool)
X_train = featurize_species(
X_train,
input_type,
species_dictionary,
species_embeddings,
atom_dictionary,
atom_embeddings,
pool,
)
X_val = featurize_species(
X_val,
input_type,
species_dictionary,
species_embeddings,
atom_dictionary,
atom_embeddings,
pool,
)
else:
X_train = featurize(X_train, input_type, atom_dictionary, atom_embeddings, pool)
X_train = featurize(
X_train, input_type, atom_dictionary, atom_embeddings, pool
)
X_val = featurize(X_val, input_type, atom_dictionary, atom_embeddings, pool)

y_train = y_train.to_numpy()
Expand Down Expand Up @@ -330,7 +382,13 @@ def featurize_species(X, input_type, species_dictionary, species_embeddings, ato
test_inputs = task.get_test_data(fold, include_target=False)
if args.species:
X_test = featurize_species(
test_inputs, input_type,species_dictionary, species_embeddings, atom_dictionary, atom_embeddings, pool
test_inputs,
input_type,
species_dictionary,
species_embeddings,
atom_dictionary,
atom_embeddings,
pool,
)
else:
X_test = featurize(
Expand Down

0 comments on commit 5c6bef6

Please sign in to comment.