Skip to content

Commit

Permalink
Merge pull request #3877 from jtkrogel/nx_maxsec
Browse files Browse the repository at this point in the history
Nexus: proper use of max_seconds in legacy drivers
  • Loading branch information
ye-luo authored Feb 25, 2022
2 parents 30b7547 + f85336b commit 861297b
Showing 1 changed file with 37 additions and 15 deletions.
52 changes: 37 additions & 15 deletions nexus/lib/qmcpack_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -2482,7 +2482,7 @@ class linear(QIxml):
'tries','min_walkers','samplesperthread',
'shift_i','shift_s','max_relative_change','max_param_change',
'chase_lowest','chase_closest','block_lm','nblocks','nolds',
'nkept',
'nkept','max_seconds'
]
costs = ['energy','unreweightedvariance','reweightedvariance','variance','difference']
write_types = obj(gpu=yesno,usedrift=yesno,nonlocalpp=yesno,usebuffer=yesno,use_nonlocalpp_deriv=yesno,chase_lowest=yesno,chase_closest=yesno,block_lm=yesno)
Expand All @@ -2509,7 +2509,7 @@ class vmc(QIxml):
tag = 'qmc'
attributes = ['method','multiple','warp','move','gpu','checkpoint','trace','target','completed','id']
elements = ['estimator','record']
parameters = ['walkers','warmupsteps','blocks','steps','substeps','timestep','usedrift','stepsbetweensamples','samples','samplesperthread','nonlocalpp','tau','walkersperthread','reconfiguration','dmcwalkersperthread','current','ratio','firststep','minimumtargetwalkers']
parameters = ['walkers','warmupsteps','blocks','steps','substeps','timestep','usedrift','stepsbetweensamples','samples','samplesperthread','nonlocalpp','tau','walkersperthread','reconfiguration','dmcwalkersperthread','current','ratio','firststep','minimumtargetwalkers','max_seconds']
write_types = obj(gpu=yesno,usedrift=yesno,nonlocalpp=yesno,reconfiguration=yesno,ratio=yesno,completed=yesno)
#end class vmc

Expand All @@ -2518,7 +2518,7 @@ class dmc(QIxml):
tag = 'qmc'
attributes = ['method','move','gpu','multiple','warp','checkpoint','trace','target','completed','id','continue']
elements = ['estimator']
parameters = ['walkers','warmupsteps','blocks','steps','timestep','nonlocalmove','nonlocalmoves','pop_control','reconfiguration','targetwalkers','minimumtargetwalkers','sigmabound','energybound','feedback','recordwalkers','fastgrad','popcontrol','branchinterval','usedrift','storeconfigs','en_ref','tau','alpha','gamma','stepsbetweensamples','max_branch','killnode','swap_walkers','swap_trigger','branching_cutoff_scheme','l2_diffusion','maxage']
parameters = ['walkers','warmupsteps','blocks','steps','timestep','nonlocalmove','nonlocalmoves','pop_control','reconfiguration','targetwalkers','minimumtargetwalkers','sigmabound','energybound','feedback','recordwalkers','fastgrad','popcontrol','branchinterval','usedrift','storeconfigs','en_ref','tau','alpha','gamma','stepsbetweensamples','max_branch','killnode','swap_walkers','swap_trigger','branching_cutoff_scheme','l2_diffusion','maxage','max_seconds']
write_types = obj(gpu=yesno,nonlocalmoves=yesnostr,reconfiguration=yesno,fastgrad=yesno,completed=yesno,killnode=yesno,swap_walkers=yesno,l2_diffusion=yesno)
#end class dmc

Expand Down Expand Up @@ -6193,6 +6193,7 @@ def generate_opts(opt_reqs,**kwargs):
substeps = 10,
timestep = 0.3,
usedrift = False,
max_seconds = None,
)

linear_quartic_legacy_defaults = obj(
Expand Down Expand Up @@ -6248,6 +6249,7 @@ def generate_opts(opt_reqs,**kwargs):
timestep = 0.3,
checkpoint = -1,
usedrift = None,
max_seconds = None,
)
vmc_test_legacy_defaults = obj(
warmupsteps = 10,
Expand Down Expand Up @@ -6289,6 +6291,7 @@ def generate_opts(opt_reqs,**kwargs):
maxage = None,
feedback = None,
sigmabound = None,
max_seconds = None,
)
dmc_test_legacy_defaults = obj(
vmc_warmupsteps = 10,
Expand Down Expand Up @@ -6564,6 +6567,11 @@ def generate_legacy_opt_calculations(
if len(invalid)>0:
error('invalid optimization inputs provided\ninvalid inputs: {}\nvalid options are: {}'.format(sorted(invalid),sorted(allowed_opt_method_legacy_inputs)))
#end if
for k in list(opt_inputs.keys()):
if opt_inputs[k] is None:
del opt_inputs[k]
#end if
#end for
if 'minmethod' in opt_inputs and opt_inputs.minmethod.lower().startswith('oneshift'):
opt_inputs.minmethod = 'OneShiftOnly'
oneshift = True
Expand Down Expand Up @@ -6629,6 +6637,7 @@ def generate_legacy_vmc_calculations(
timestep ,
checkpoint ,
usedrift ,
max_seconds,
loc = 'generate_vmc_calculations',
):

Expand All @@ -6645,6 +6654,9 @@ def generate_legacy_vmc_calculations(
if usedrift is not None:
vmc_calc.usedrift = usedrift
#end if
if max_seconds is not None:
vmc_calc.max_seconds = max_seconds
#end if

vmc_calcs = [vmc_calc]

Expand Down Expand Up @@ -6682,6 +6694,7 @@ def generate_legacy_dmc_calculations(
maxage ,
feedback ,
sigmabound ,
max_seconds ,
loc = 'generate_dmc_calculations',
):

Expand Down Expand Up @@ -6709,6 +6722,9 @@ def generate_legacy_dmc_calculations(
if vmc_usedrift is not None:
vmc_calc.usedrift = vmc_usedrift
#end if
if max_seconds is not None:
vmc_calc.max_seconds = max_seconds
#end if

dmc_calcs = [vmc_calc]
if eq_dmc:
Expand Down Expand Up @@ -6743,6 +6759,7 @@ def generate_legacy_dmc_calculations(
maxage = maxage ,
feedback = feedback ,
sigmabound = sigmabound,
max_seconds = max_seconds,
)
for calc in dmc_calcs:
if isinstance(calc,dmc):
Expand Down Expand Up @@ -7092,11 +7109,11 @@ def generate_basic_input(**kwargs):
# apply method specific defaults
if kw.qmc is not None:
if kw.driver not in qmc_defaults:
QmcpackInput.class_error('Invalid input for argument "driver"\nInvalid input: {}\nValid options are: {}'.format(kw.driver,sorted(qmc_defaults.keys())),'generate_basic_input')
QmcpackInput.class_error('Invalid input for argument "driver".\nInvalid input: {}\nValid options are: {}'.format(kw.driver,sorted(qmc_defaults.keys())),'generate_qmcpack_input')
#end if
qmc_driver_defaults = qmc_defaults[kw.driver]
if kw.qmc not in qmc_driver_defaults:
QmcpackInput.class_error('Invalid input for argument "qmc"\nInvalid input: {}\nValid options are: {}'.format(kw.qmc,sorted(qmc_driver_defaults.keys())),'generate_basic_input')
QmcpackInput.class_error('Invalid input for argument "qmc".\nInvalid input: {}\nValid options are: {}'.format(kw.qmc,sorted(qmc_driver_defaults.keys())),'generate_qmcpack_input')
#end if
qmc_keys = ['driver']
kw.set_optional(**qmc_driver_defaults[kw.qmc])
Expand All @@ -7105,7 +7122,7 @@ def generate_basic_input(**kwargs):
opt_method_driver_defaults = opt_method_defaults[kw.driver]
key = (kw.method,kw.minmethod.lower())
if key not in opt_method_driver_defaults:
QmcpackInput.class_error('invalid input for arguments "method,minmethod"\ninvalid input: {}\nvalid options are: {}'.format(key,sorted(opt_method_driver_defaults.keys())),'generate_basic_input')
QmcpackInput.class_error('invalid input for arguments "method,minmethod".\nInvalid input: {}\nValid options are: {}'.format(key,sorted(opt_method_driver_defaults.keys())),'generate_qmcpack_input')
#end if
kw.set_optional(**opt_method_driver_defaults[key])
qmc_keys += list(opt_method_driver_defaults[key].keys())
Expand All @@ -7116,11 +7133,14 @@ def generate_basic_input(**kwargs):
# screen for invalid keywords
invalid_kwargs = set(kw.keys())-valid
if len(invalid_kwargs)>0:
QmcpackInput.class_error('invalid input parameters encountered\ninvalid input parameters: {0}\nvalid options are: {1}'.format(sorted(invalid_kwargs),sorted(valid)),'generate_qmcpack_input')
QmcpackInput.class_error('invalid input parameters encountered.\nInvalid input parameters: {0}\nValid options are: {1}'.format(sorted(invalid_kwargs),sorted(valid)),'generate_qmcpack_input')
#end if

batched = kw.driver=='batched'
legacy = kw.driver=='legacy'

if kw.system=='missing':
QmcpackInput.class_error('generate_basic_input argument system is missing\nif you really do not want particlesets to be generated, set system to None')
QmcpackInput.class_error('argument "system" is missing.\nIf you really do not want particlesets to be generated, set system to None.','generate_qmcpack_input')
#end if
if kw.bconds is None:
if kw.system is not None:
Expand Down Expand Up @@ -7168,11 +7188,13 @@ def generate_basic_input(**kwargs):
series = kw.series,
application = application(),
)
if kw.maxcpusecs is not None:
proj.maxcpusecs = kw.maxcpusecs
#end if
if kw.max_seconds is not None:
proj.max_seconds = kw.max_seconds
if batched:
if kw.maxcpusecs is not None:
proj.maxcpusecs = kw.maxcpusecs
#end if
if kw.max_seconds is not None:
proj.max_seconds = kw.max_seconds
#end if
#end if

simcell = generate_simulationcell(
Expand All @@ -7196,7 +7218,7 @@ def generate_basic_input(**kwargs):

if kw.det_format=='new':
if kw.excitation is not None:
QmcpackInput.class_error('user provided "excitation" input argument with new style determinant format\nplease add det_format="old" and try again')
QmcpackInput.class_error('user provided "excitation" input argument with new style determinant format.\nPlease add det_format="old" and try again','generate_qmcpack_input')
#end if
if kw.system is not None and isinstance(kw.system.structure,Jellium):
ssb = generate_sposet_builder(
Expand Down Expand Up @@ -7256,7 +7278,7 @@ def generate_basic_input(**kwargs):
system = kw.system,
)
else:
QmcpackInput.class_error('generate_basic_input argument det_format is invalid\n received: {0}\n valid options are: new,old'.format(det_format))
QmcpackInput.class_error('argument "det_format" is invalid.\nReceived: {0}\nValid options are: new, old'.format(det_format),'generate_qmcpack_input')
#end if


Expand Down

0 comments on commit 861297b

Please sign in to comment.