Skip to content

Commit

Permalink
Merge pull request #5 from AntObi/skipspecies
Browse files Browse the repository at this point in the history
Skipspecies
  • Loading branch information
AntObi authored Sep 15, 2023
2 parents 50fed4e + a7b8daf commit 50d7892
Show file tree
Hide file tree
Showing 3 changed files with 561 additions and 276 deletions.
56 changes: 40 additions & 16 deletions bin/create_csv_vectors_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import argparse
from sys import argv

from pymatgen.core import Element
from pymatgen.core import Element, Species

from skipatom import SkipAtomInducedModel, SkipAtomModel
from skipatom import SkipAtomInducedModel, SkipAtomModel, SkipSpeciesInducedModel

"""
e.g.
Expand Down Expand Up @@ -43,6 +43,9 @@
parser.add_argument(
"--induced", action="store_true", help="whether to use induced SkipAtom vectors"
)
parser.add_argument(
"--skipspecies", action="store_true", help="whether to use SkipSpecies vectors"
)
parser.add_argument(
"--min-count",
required=("induced" in argv),
Expand All @@ -59,24 +62,45 @@
args = parser.parse_args()

if args.induced:
model = SkipAtomInducedModel.load(
args.model, args.data, min_count=args.min_count, top_n=args.top_n
)
if args.skipspecies:
model = SkipSpeciesInducedModel.load(
args.model, args.data, min_count=args.min_count, top_n=args.top_n
)
else:
model = SkipAtomInducedModel.load(
args.model, args.data, min_count=args.min_count, top_n=args.top_n
)
else:
model = SkipAtomModel.load(args.model, args.data)

sorted_elems = sorted(
[(e, Element(e).number) for e in model.dictionary], key=lambda v: v[1]
)
if args.skipspecies:
sorted_specs = sorted(
[(s, Species.from_string(s).number) for s in model.dictionary],
key=lambda v: v[1],
)
else:
sorted_elems = sorted(
[(e, Element(e).number) for e in model.dictionary], key=lambda v: v[1]
)

dim = len(model.vectors[0])

with open(args.out, "w") as f:
header = ["element"]
header.extend([str(i) for i in range(dim)])
f.write("%s\n" % ",".join(header))
for elem, _ in sorted_elems:
vec = model.vectors[model.dictionary[elem]].tolist()
row = [elem]
row.extend([str(v) for v in vec])
f.write("%s\n" % ",".join(row))
if args.skipspecies:
header = ["species"]
header.extend([str(i) for i in range(dim)])
f.write("%s\n" % ",".join(header))
for spec, _ in sorted_specs:
vec = model.vectors[model.dictionary[spec]].tolist()
row = [spec]
row.extend([str(v) for v in vec])
f.write("%s\n" % ",".join(row))
else:
header = ["element"]
header.extend([str(i) for i in range(dim)])
f.write("%s\n" % ",".join(header))
for elem, _ in sorted_elems:
vec = model.vectors[model.dictionary[elem]].tolist()
row = [elem]
row.extend([str(v) for v in vec])
f.write("%s\n" % ",".join(row))
Loading

0 comments on commit 50d7892

Please sign in to comment.