fit.py
"""
Automated fitting script.
"""
import os
import sys
import fnmatch
import argparse
import logging
import multiprocessing
from paramselect import fit, load_datasets
from distributed import Client, LocalCluster

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--dask-scheduler",
    metavar="HOST:PORT",
    help="Host and port of dask distributed scheduler")
parser.add_argument(
    "--iter-record",
    metavar="FILE",
    help="Output file for recording iterations (CSV)")
parser.add_argument(
    "--tracefile",
    metavar="FILE",
    help="Output file for recording MCMC trace (HDF5)")
parser.add_argument(
    "--fit-settings",
    metavar="FILE",
    default="input.json",
    help="Input JSON file with settings for fit")
parser.add_argument(
    "--input-tdb",
    metavar="FILE",
    default=None,
    help="Input TDB file, with desired degrees of freedom to fit specified as FUNCTIONs starting with 'VV'")
parser.add_argument(
    "--output-tdb",
    metavar="FILE",
    default="out.tdb",
    help="Output TDB file")


def recursive_glob(start, pattern):
    """Recursively collect all files under ``start`` matching ``pattern``, sorted."""
    matches = []
    for root, dirnames, filenames in os.walk(start):
        for filename in fnmatch.filter(filenames, pattern):
            matches.append(os.path.join(root, filename))
    return sorted(matches)


if __name__ == '__main__':
    args = parser.parse_args(sys.argv[1:])
    # Configure logging so the INFO-level startup message below is actually emitted.
    logging.basicConfig(level=logging.INFO)
    if not args.dask_scheduler:
        # No scheduler address given: start a local cluster with one
        # single-threaded worker per two CPU cores.
        args.dask_scheduler = LocalCluster(n_workers=int(multiprocessing.cpu_count() / 2),
                                           threads_per_worker=1, nanny=True)
    client = Client(args.dask_scheduler)
    logging.info(
        "Running with dask scheduler: %s [%s cores]" % (
            args.dask_scheduler,
            sum(client.ncores().values())))
    # Load every JSON dataset found under the Al-Ni directory.
    datasets = load_datasets(sorted(recursive_glob('Al-Ni', '*.json')))
    recfile = open(args.iter_record, 'a') if args.iter_record else None
    tracefile = args.tracefile if args.tracefile else None
    try:
        dbf, mdl, model_dof = fit(args.fit_settings, datasets, scheduler=client,
                                  recfile=recfile, tracefile=tracefile)
    finally:
        if recfile:
            recfile.close()
    # Write the fitted database, overwriting any existing output file.
    dbf.to_file(args.output_tdb, if_exists='overwrite')
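
A minimal sketch of how one might invoke this script, assuming an Al-Ni dataset directory and an input.json settings file sit alongside fit.py; the scheduler address and the iters.csv/trace.hdf5 output names are placeholders, not files the script provides:

    # Run with an auto-started local cluster (one worker per two CPU cores):
    python fit.py --fit-settings input.json --output-tdb out.tdb

    # Or connect to an already-running dask scheduler and record progress:
    python fit.py --dask-scheduler 127.0.0.1:8786 --iter-record iters.csv --tracefile trace.hdf5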