runners
Entry points for running standard jobs.
total_variation_clipping_fn(local_energies, threshold=5.0)
Clip local energies to within a multiple of their total variation (mean absolute deviation) from the median.
Source code in vmcnet/train/runners.py
def total_variation_clipping_fn(local_energies, threshold=5.0):
    """Clip local energies to a symmetric window around their median.

    The window half-width is ``threshold`` times the "total variation", which here
    means the NaN-ignoring mean absolute deviation of the local energies from their
    NaN-ignoring median.

    Args:
        local_energies: array of local energies; NaN entries are ignored when
            computing the median and the spread.
        threshold: multiple of the total variation used as the half-width of the
            clipping window. Defaults to 5.0.

    Returns:
        An array like ``local_energies`` with every value clipped into
        ``[median - threshold * tv, median + threshold * tv]``.
    """
    center = jnp.nanmedian(local_energies)
    # Mean absolute deviation from the median; robust to outliers compared with std.
    spread = threshold * jnp.nanmean(jnp.abs(local_energies - center))
    return jnp.clip(local_energies, center - spread, center + spread)
run_molecule()
Run VMC on a molecule.
Source code in vmcnet/train/runners.py
def run_molecule() -> None:
    """Run VMC on a molecule, then evaluate the trained wavefunction.

    Steps, in order:

    1. Parse ``FLAGS`` into ``reload_config`` and ``config``, set the root logger
       level, and get (and save the config to) the log directory.
    2. Build ion positions/charges and electron counts, seed the PRNG key, and set up
       the VMC functions and initial state via ``_setup_vmc``.
    3. If reloading from a checkpoint file is requested, replace the freshly
       initialized data, params, optimizer state and key with the checkpointed ones
       (distributing them across devices only when ``config.distribute`` is set).
    4. Burn in and run the VMC optimization, writing checkpoints to ``logdir``.
    5. Set up an evaluation run with no optimizer state, regenerating walker data
       unless training data is to be reused and the chain counts match, and run it
       into ``<logdir>/eval`` without checkpointing.
    6. If local energies were recorded during evaluation, compute and save energy
       statistics into ``<logdir>/eval`` under the name ``"statistics"``.
    """
    reload_config, config = train.parse_config_flags.parse_flags(FLAGS)

    root_logger = logging.getLogger()
    root_logger.setLevel(config.logging_level)
    logdir = _get_logdir_and_save_config(reload_config, config)

    dtype_to_use = _get_dtype(config)
    ion_pos, ion_charges, nelec = _get_electron_ion_config_as_arrays(
        config, dtype=dtype_to_use
    )

    key = jax.random.PRNGKey(config.initial_seed)

    (
        log_psi_apply,
        burning_step,
        walker_fn,
        local_energy_fn,
        update_param_fn,
        get_amplitude_fn,
        params,
        data,
        optimizer_state,
        key,
    ) = _setup_vmc(
        config,
        ion_pos,
        ion_charges,
        nelec,
        key,
        dtype=dtype_to_use,
        apply_pmap=config.distribute,
    )

    reload_from_checkpoint = (
        reload_config.logdir != train.default_config.NO_RELOAD_LOG_DIR
        and reload_config.use_checkpoint_file
    )
    if reload_from_checkpoint:
        checkpoint_file_path = os.path.join(
            reload_config.logdir, reload_config.checkpoint_relative_file_path
        )
        directory, filename = os.path.split(checkpoint_file_path)

        # NOTE(review): the discarded first value is presumably the epoch at which
        # the checkpoint was saved; training therefore restarts its epoch count.
        # Confirm against utils.io.reload_vmc_state.
        _, data, params, optimizer_state, key = utils.io.reload_vmc_state(
            directory, filename
        )

        # Only distribute the reloaded state when the VMC functions were built with
        # apply_pmap=True above. This keeps the state's layout consistent with the
        # functions that will consume it, mirroring the apply_pmap flag passed to
        # _setup_vmc and _setup_eval.
        if config.distribute:
            (
                data,
                params,
                optimizer_state,
                key,
            ) = utils.distribute.distribute_vmc_state_from_checkpoint(
                data, params, optimizer_state, key
            )

    params, optimizer_state, data, key = _burn_and_run_vmc(
        config.vmc,
        logdir,
        params,
        optimizer_state,
        data,
        burning_step,
        walker_fn,
        update_param_fn,
        get_amplitude_fn,
        key,
        should_checkpoint=True,
    )

    logging.info("Completed VMC! Evaluating")

    # TODO: integrate the stuff in mcmc/statistics and write out an evaluation summary
    # (energy, var, overall mean acceptance ratio, std error, iac) to eval_logdir, post
    # evaluation
    eval_logdir = os.path.join(logdir, "eval")

    eval_update_param_fn, eval_burning_step, eval_walker_fn = _setup_eval(
        config,
        log_psi_apply,
        local_energy_fn,
        pacore.get_position_from_data,
        apply_pmap=config.distribute,
    )
    # The evaluation run is given no optimizer state (presumably eval_update_param_fn
    # does not optimize the params; confirm in _setup_eval).
    optimizer_state = None

    # Reuse the walker data returned by training only when requested AND when the
    # number of chains matches; otherwise generate fresh data for evaluation.
    eval_and_vmc_nchains_match = config.vmc.nchains == config.eval.nchains
    if not config.eval.use_data_from_training or not eval_and_vmc_nchains_match:
        key, data = _make_new_data_for_eval(
            config,
            log_psi_apply,
            params,
            ion_pos,
            ion_charges,
            nelec,
            key,
            dtype=dtype_to_use,
        )

    _burn_and_run_vmc(
        config.eval,
        eval_logdir,
        params,
        optimizer_state,
        data,
        eval_burning_step,
        eval_walker_fn,
        eval_update_param_fn,
        get_amplitude_fn,
        key,
        should_checkpoint=False,
    )

    # Check that local_energies.txt exists on disk as well as checking the config flag:
    # when config.eval.nepochs=0 the file is not created regardless of
    # config.eval.record_local_energies.
    local_energies_filepath = os.path.join(eval_logdir, "local_energies.txt")
    if config.eval.record_local_energies and os.path.exists(local_energies_filepath):
        _compute_and_save_energy_statistics(
            local_energies_filepath, eval_logdir, "statistics"
        )
vmc_statistics()
Calculate statistics from a VMC evaluation run and write them to disc.
Source code in vmcnet/train/runners.py
def vmc_statistics() -> None:
    """Command-line entry point: compute statistics for a VMC evaluation run.

    Reads two positional command-line arguments:

    * ``local_energies_file_path``: the file to load local energies from.
    * ``output_file_path``: where to write the statistics (the CLI help states that a
      ``.json`` suffix will be appended).

    The output path is made absolute and split into a directory and a file name. These
    are passed, with the local energies path, to ``_compute_and_save_energy_statistics``.
    """
    cli = argparse.ArgumentParser(
        description=(
            "Calculate statistics from a VMC evaluation run and write them to disc."
        )
    )
    cli.add_argument(
        "local_energies_file_path",
        type=str,
        help="File path to load local energies from",
    )
    cli.add_argument(
        "output_file_path",
        type=str,
        help=(
            "File path to which to write the output statistics. "
            "The '.json' suffix will be appended to the supplied path."
        ),
    )
    parsed = cli.parse_args()

    absolute_output = os.path.abspath(parsed.output_file_path)
    target_dir, target_name = os.path.split(absolute_output)
    _compute_and_save_energy_statistics(
        parsed.local_energies_file_path, target_dir, target_name
    )