Skip to content

runners

Entry points for running standard jobs.

total_variation_clipping_fn(local_energies, threshold=5.0)

Clip local energies to lie within a multiple of their "total variation" (the mean absolute deviation from the median) on either side of the median.

Source code in vmcnet/train/runners.py
def total_variation_clipping_fn(local_energies, threshold=5.0):
    """Clamp local energies to a symmetric window centred on their median.

    The half-width of the window is ``threshold`` times the "total variation",
    i.e. the mean absolute deviation of ``local_energies`` from their median.
    Both the median and the mean are taken over all elements with
    ``jnp.nanmedian`` / ``jnp.nanmean``, so NaN entries are ignored when
    computing the window (NaN entries themselves pass through ``jnp.clip``).

    Args:
        local_energies: array of local energy values.
        threshold: number of total variations allowed on each side of the
            median. Defaults to 5.0.

    Returns:
        ``local_energies`` clipped elementwise to
        ``[median - threshold * tv, median + threshold * tv]``.
    """
    center = jnp.nanmedian(local_energies)
    spread = jnp.nanmean(jnp.abs(local_energies - center))
    half_width = threshold * spread
    return jnp.clip(local_energies, center - half_width, center + half_width)

run_molecule()

Run VMC on a molecule.

Source code in vmcnet/train/runners.py
def run_molecule() -> None:
    """Run VMC on a molecule, then evaluate the resulting wavefunction.

    Steps performed:

    1. Parse the command-line flags into ``reload_config`` and ``config``, set
       the root logger level, and resolve the logging directory.
    2. Build the VMC components with ``_setup_vmc``. If a reload logdir was
       given and ``reload_config.use_checkpoint_file`` is set, replace the
       freshly initialized data, params, optimizer state and PRNG key with the
       values stored in the checkpoint file.
    3. Run the training loop (burn-in + VMC) with checkpointing enabled.
    4. Run an evaluation pass into ``<logdir>/eval`` using the evaluation
       update/burning/walker functions from ``_setup_eval``. New walker data is
       generated unless ``config.eval.use_data_from_training`` is set and the
       training and evaluation chain counts match.
    5. If local energies were recorded during evaluation, compute energy
       statistics from them and save them under the name ``"statistics"`` in
       the eval directory.
    """
    reload_config, config = train.parse_config_flags.parse_flags(FLAGS)

    root_logger = logging.getLogger()
    root_logger.setLevel(config.logging_level)
    logdir = _get_logdir_and_save_config(reload_config, config)

    dtype_to_use = _get_dtype(config)

    ion_pos, ion_charges, nelec = _get_electron_ion_config_as_arrays(
        config, dtype=dtype_to_use
    )

    key = jax.random.PRNGKey(config.initial_seed)

    (
        log_psi_apply,
        burning_step,
        walker_fn,
        local_energy_fn,
        update_param_fn,
        get_amplitude_fn,
        params,
        data,
        optimizer_state,
        key,
    ) = _setup_vmc(
        config,
        ion_pos,
        ion_charges,
        nelec,
        key,
        dtype=dtype_to_use,
        apply_pmap=config.distribute,
    )

    reload_from_checkpoint = (
        reload_config.logdir != train.default_config.NO_RELOAD_LOG_DIR
        and reload_config.use_checkpoint_file
    )

    if reload_from_checkpoint:
        checkpoint_file_path = os.path.join(
            reload_config.logdir, reload_config.checkpoint_relative_file_path
        )
        directory, filename = os.path.split(checkpoint_file_path)
        # The first value returned by reload_vmc_state is not used here.
        _, data, params, optimizer_state, key = utils.io.reload_vmc_state(
            directory, filename
        )
        # Re-distribute the reloaded state; presumably this shards it across
        # devices when running distributed -- confirm in utils.distribute.
        (
            data,
            params,
            optimizer_state,
            key,
        ) = utils.distribute.distribute_vmc_state_from_checkpoint(
            data, params, optimizer_state, key
        )

    params, optimizer_state, data, key = _burn_and_run_vmc(
        config.vmc,
        logdir,
        params,
        optimizer_state,
        data,
        burning_step,
        walker_fn,
        update_param_fn,
        get_amplitude_fn,
        key,
        should_checkpoint=True,
    )

    logging.info("Completed VMC! Evaluating")

    # TODO: integrate the stuff in mcmc/statistics and write out an evaluation summary
    # (energy, var, overall mean acceptance ratio, std error, iac) to eval_logdir, post
    # evaluation
    eval_logdir = os.path.join(logdir, "eval")

    eval_update_param_fn, eval_burning_step, eval_walker_fn = _setup_eval(
        config,
        log_psi_apply,
        local_energy_fn,
        pacore.get_position_from_data,
        apply_pmap=config.distribute,
    )
    # The evaluation pass is given no optimizer state (presumably
    # eval_update_param_fn does not optimize params -- confirm in _setup_eval).
    optimizer_state = None

    # Training walkers can only be reused when requested and when the number of
    # chains agrees between training and evaluation.
    eval_and_vmc_nchains_match = config.vmc.nchains == config.eval.nchains
    if not config.eval.use_data_from_training or not eval_and_vmc_nchains_match:
        key, data = _make_new_data_for_eval(
            config,
            log_psi_apply,
            params,
            ion_pos,
            ion_charges,
            nelec,
            key,
            dtype=dtype_to_use,
        )

    _burn_and_run_vmc(
        config.eval,
        eval_logdir,
        params,
        optimizer_state,
        data,
        eval_burning_step,
        eval_walker_fn,
        eval_update_param_fn,
        get_amplitude_fn,
        key,
        should_checkpoint=False,
    )

    # Need to check for local_energies.txt because when config.eval.nepochs=0 the
    # file is not created regardless of config.eval.record_local_energies.
    local_energies_filepath = os.path.join(eval_logdir, "local_energies.txt")
    local_es_were_recorded = os.path.exists(local_energies_filepath)
    if config.eval.record_local_energies and local_es_were_recorded:
        _compute_and_save_energy_statistics(
            local_energies_filepath, eval_logdir, "statistics"
        )

vmc_statistics()

Calculate energy statistics from a VMC evaluation run and write them to disk.

Source code in vmcnet/train/runners.py
def vmc_statistics() -> None:
    """Command-line entry point: compute energy statistics for an eval run.

    Reads two positional command-line arguments:

    * ``local_energies_file_path`` -- file to load local energies from.
    * ``output_file_path`` -- destination for the statistics. It is resolved to
      an absolute path and split into a directory and a file name before being
      passed to ``_compute_and_save_energy_statistics``.
    """
    arg_parser = argparse.ArgumentParser(
        description=(
            "Calculate statistics from a VMC evaluation run and write them to disc."
        )
    )
    # (name, help) pairs for the positional arguments, in registration order.
    positional_args = (
        ("local_energies_file_path", "File path to load local energies from"),
        (
            "output_file_path",
            "File path to which to write the output statistics. The '.json' "
            "suffix will be appended to the supplied path.",
        ),
    )
    for arg_name, arg_help in positional_args:
        arg_parser.add_argument(arg_name, type=str, help=arg_help)
    cli_args = arg_parser.parse_args()

    absolute_output_path = os.path.abspath(cli_args.output_file_path)
    destination_dir, destination_name = os.path.split(absolute_output_path)
    _compute_and_save_energy_statistics(
        cli_args.local_energies_file_path, destination_dir, destination_name
    )