Routines which handle model parameter updating.

create_grad_energy_update_param_fn(energy_data_val_and_grad, optimizer_apply, get_position_fn, apply_pmap=True, record_param_l1_norm=False)

Create the `update_param_fn` based on the gradient of the total energy.

See `vmcnet.train.vmc.vmc_loop` for its usage.


Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `energy_data_val_and_grad` | `Callable` | Function which computes the clipped energy value and gradient. Has the signature `(params, x) -> ((expected_energy, auxiliary_energy_data), grad_energy)`, where `auxiliary_energy_data` is the tuple `(expected_variance, local_energies, unclipped_energy, unclipped_variance)`. | required |
| `optimizer_apply` | `Callable` | Applies an update to the parameters. Has signature `(grad_energy, params, optimizer_state, data) -> (new_params, new_optimizer_state)`. | required |
| `get_position_fn` | `GetPositionFromData` | Gets the walker positions from the MCMC data. | required |
| `apply_pmap` | `bool` | Whether to apply `jax.pmap` to the walker function. If False, applies `jax.jit`. | `True` |
| `record_param_l1_norm` | `bool` | Whether to record the L1 norm of the updated parameters in the metrics under the key `"param_l1_norm"`. | `False` |



Returns:

| Type | Description |
|------|-------------|
| `Callable` | Function which updates the parameters given the current data, params, and optimizer state. The signature of this function is `(params, data, optimizer_state, key) -> (new_params, new_optimizer_state, metrics, key)`. The function is pmapped if `apply_pmap` is True, and jitted if `apply_pmap` is False. |

Source code in vmcnet/updates/
def create_grad_energy_update_param_fn(
    energy_data_val_and_grad: physics.core.ValueGradEnergyFn[P],
    optimizer_apply: Callable[[P, P, S, D], Tuple[P, S]],
    get_position_fn: GetPositionFromData[D],
    apply_pmap: bool = True,
    record_param_l1_norm: bool = False,
) -> UpdateParamFn[P, D, S]:
    """Create the `update_param_fn` based on the gradient of the total energy.

    See :func:`~vmcnet.train.vmc.vmc_loop` for its usage.

    Args:
        energy_data_val_and_grad (Callable): function which computes the clipped energy
            value and gradient. Has the signature
                (params, x)
                -> ((expected_energy, auxiliary_energy_data), grad_energy),
            where auxiliary_energy_data is the tuple
            (expected_variance, local_energies, unclipped_energy, unclipped_variance)
        optimizer_apply (Callable): applies an update to the parameters. Has signature
                (grad_energy, params, optimizer_state, data)
                -> (new_params, new_optimizer_state).
        get_position_fn (GetPositionFromData): gets the walker positions from the MCMC
            data.
        apply_pmap (bool, optional): whether to apply jax.pmap to the walker function.
            If False, applies jax.jit. Defaults to True.
        record_param_l1_norm (bool, optional): whether to add the L1 norm of the
            updated parameters to the metrics under the key "param_l1_norm".
            Defaults to False.

    Returns:
        Callable: function which updates the parameters given the current data, params,
        and optimizer state. The signature of this function is
            (params, data, optimizer_state, key)
            -> (new_params, new_optimizer_state, metrics, key)
        The function is pmapped if apply_pmap is True, and jitted if apply_pmap is
        False.
    """

    def update_param_fn(params, data, optimizer_state, key):
        position = get_position_fn(data)
        energy_data, grad_energy = energy_data_val_and_grad(params, position)
        energy, aux_energy_data = energy_data

        # Presumably averages the gradient over devices when running under pmap, so
        # that every replica applies the same update.
        grad_energy = utils.distribute.pmean_if_pmap(grad_energy)
        params, optimizer_state = optimizer_apply(
            grad_energy, params, optimizer_state, data
        )

        # aux_energy_data layout:
        # (variance, local_energies, unclipped_energy, unclipped_variance).
        metrics = {"energy": energy, "variance": aux_energy_data[0]}
        metrics = _update_metrics_with_noclip(
            aux_energy_data[2], aux_energy_data[3], metrics
        )
        if record_param_l1_norm:
            metrics.update({"param_l1_norm": tree_reduce_l1(params)})
        return params, optimizer_state, metrics, key

    traced_fn = _make_traced_fn_with_single_metrics(update_param_fn, apply_pmap)

    return traced_fn

create_kfac_update_param_fn(optimizer, damping, get_position_fn, record_param_l1_norm=False)

Create momentum-less KFAC update step function.


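Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `optimizer` | `kfac_ferminet_alpha.Optimizer` | Instance of the `Optimizer` class from `kfac_ferminet_alpha`. | required |
| `damping` | `jnp.float32` | Damping coefficient. | required |
| `get_position_fn` | `GetPositionFromData` | Function which gets the walker positions from the data. Has signature `data -> Array`. | required |
| `record_param_l1_norm` | `bool` | Whether to record the L1 norm of the updated parameters in the metrics under the key `"param_l1_norm"`. | `False` |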



Returns:

| Type | Description |
|------|-------------|
| `Callable` | Function which updates the parameters given the current data, params, and optimizer state. The signature of this function is `(params, data, optimizer_state, key) -> (new_params, new_optimizer_state, metrics, key)`. |

Source code in vmcnet/updates/
def create_kfac_update_param_fn(
    optimizer: kfac_ferminet_alpha.Optimizer,
    damping: jnp.float32,
    get_position_fn: GetPositionFromData[D],
    record_param_l1_norm: bool = False,
) -> UpdateParamFn[kfac_opt.Parameters, D, kfac_opt.State]:
    """Create momentum-less KFAC update step function.

    Args:
        optimizer (kfac_ferminet_alpha.Optimizer): instance of the Optimizer class from
            kfac_ferminet_alpha
        damping (jnp.float32): damping coefficient
        get_position_fn (GetPositionFromData): function which gets the walker positions
            from the data. Has signature data -> Array
        record_param_l1_norm (bool, optional): whether to add the L1 norm of the
            updated parameters to the metrics under the key "param_l1_norm".
            Defaults to False.

    Returns:
        Callable: function which updates the parameters given the current data, params,
        and optimizer state. The signature of this function is
            (params, data, optimizer_state, key)
            -> (new_params, new_optimizer_state, metrics, key)
    """
    # Momentum is pinned to zero: this is what makes the step "momentum-less".
    momentum = jnp.asarray(0.0)
    damping = jnp.asarray(damping)
    if optimizer.multi_device:
        # Replicated once up front rather than on every step; presumably the
        # multi-device optimizer expects one copy per local device -- confirm.
        momentum = utils.distribute.replicate_all_local_devices(momentum)
        damping = utils.distribute.replicate_all_local_devices(damping)

    traced_compute_param_norm = _get_traced_compute_param_norm(optimizer.multi_device)

    def update_param_fn(params, data, optimizer_state, key):
        key, subkey = utils.distribute.split_or_psplit_key(key, optimizer.multi_device)
        # NOTE(review): keyword arguments chosen to match kfac_ferminet_alpha's
        # Optimizer.step(params, state, rng, data_iterator, momentum, damping);
        # confirm against the installed optimizer version.
        params, optimizer_state, stats = optimizer.step(
            params=params,
            state=optimizer_state,
            rng=subkey,
            data_iterator=iter([get_position_fn(data)]),
            momentum=momentum,
            damping=damping,
        )
        energy = stats["loss"]
        # Assumes stats["aux"] follows the auxiliary energy data layout
        # (variance, local_energies, unclipped_energy, unclipped_variance) -- confirm.
        variance = stats["aux"][0]
        energy_noclip = stats["aux"][2]
        variance_noclip = stats["aux"][3]
        picked_stats = (energy, variance, energy_noclip, variance_noclip)

        if record_param_l1_norm:
            param_l1_norm = traced_compute_param_norm(params)
            picked_stats = picked_stats + (param_l1_norm,)

        stats_to_save = picked_stats
        if optimizer.multi_device:
            # Keep a single copy of each (replicated) statistic for logging.
            stats_to_save = [utils.distribute.get_first(stat) for stat in picked_stats]

        metrics = {"energy": stats_to_save[0], "variance": stats_to_save[1]}
        metrics = _update_metrics_with_noclip(
            stats_to_save[2], stats_to_save[3], metrics
        )

        if record_param_l1_norm:
            metrics.update({"param_l1_norm": stats_to_save[4]})

        return params, optimizer_state, metrics, key

    return update_param_fn

create_eval_update_param_fn(local_energy_fn, nchains, get_position_fn, apply_pmap=True, record_local_energies=True, nan_safe=False)

Create a function which simply evaluates the local energies, without any parameter update, clipping, or gradient computation.

Can be used to do simple unclipped MCMC with `vmcnet.train.vmc.vmc_loop`.


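Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `local_energy_fn` | `Callable` | Computes local energies `Hpsi / psi`. Has signature `(params, x) -> (Hpsi / psi)(x)`. | required |
| `nchains` | `int` | Total number of chains across all devices, used to compute a sample variance estimate of the local energy. | required |
| `get_position_fn` | `GetPositionFromData` | Gets the walker positions from the MCMC data. | required |
| `apply_pmap` | `bool` | Whether to apply `jax.pmap` to the walker function. If False, applies `jax.jit`. | `True` |
| `record_local_energies` | `bool` | Whether to include the local energies themselves in the metrics under the key `"local_energies"`. | `True` |
| `nan_safe` | `bool` | Whether or not to mask local energy nans in the evaluation process. This option should not be used under normal circumstances, as the energy estimates are of unclear validity if nans are masked. However, it can be used to get a coarse estimate of the energy of a wavefunction even if a few walkers are returning nans for their local energies. | `False` |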



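Returns:

| Type | Description |
|------|-------------|
| `Callable` | Function which evaluates the local energies and averages them, without updating the parameters. Its signature is `(params, data, optimizer_state, key) -> (params, optimizer_state, metrics, key)`. |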

Source code in vmcnet/updates/
def create_eval_update_param_fn(
    local_energy_fn: ModelApply[P],
    nchains: int,
    get_position_fn: GetPositionFromData[D],
    apply_pmap: bool = True,
    record_local_energies: bool = True,
    nan_safe: bool = False,
) -> UpdateParamFn[P, D, OptimizerState]:
    """No update/clipping/grad function which simply evaluates the local energies.

    Can be used to do simple unclipped MCMC with :func:`~vmcnet.train.vmc.vmc_loop`.

    Args:
        local_energy_fn (Callable): computes local energies Hpsi / psi. Has signature
            (params, x) -> (Hpsi / psi)(x)
        nchains (int): total number of chains across all devices, used to compute a
            sample variance estimate of the local energy
        get_position_fn (GetPositionFromData): gets the walker positions from the MCMC
            data.
        apply_pmap (bool, optional): whether to apply jax.pmap to the walker function.
            If False, applies jax.jit. Defaults to True.
        record_local_energies (bool, optional): whether to include the local energies
            themselves in the metrics under the key "local_energies". Defaults to True.
        nan_safe (bool, optional): whether or not to mask local energy nans in the
            evaluation process. This option should not be used under normal
            circumstances, as the energy estimates are of unclear validity if nans are
            masked. However, it can be used to get a coarse estimate of the energy of a
            wavefunction even if a few walkers are returning nans for their local
            energies. Defaults to False.

    Returns:
        Callable: function which evaluates the local energies and averages them, without
        updating the parameters. The signature of this function is
            (params, data, optimizer_state, key)
            -> (params, optimizer_state, metrics, key)
    """

    def eval_update_param_fn(params, data, optimizer_state, key):
        local_energies = local_energy_fn(params, get_position_fn(data))
        energy, variance = physics.core.get_statistics_from_local_energy(
            local_energies, nchains, nan_safe=nan_safe
        )
        metrics = {"energy": energy, "variance": variance}
        if record_local_energies:
            metrics.update({"local_energies": local_energies})
        # params and optimizer_state are passed through unchanged.
        return params, optimizer_state, metrics, key

    # NOTE(review): presumably only the "energy" and "variance" metrics are reduced
    # to single values, while "local_energies" is left as-is -- confirm against
    # _make_traced_fn_with_single_metrics.
    traced_fn = _make_traced_fn_with_single_metrics(
        eval_update_param_fn, apply_pmap, {"energy", "variance"}
    )

    return traced_fn

constrain_norm(grads, preconditioned_grads, learning_rate, norm_constraint=0.001)

Constrains the preconditioned norm of the update, adapted from KFAC.

Source code in vmcnet/updates/
def constrain_norm(
    grads: P,
    preconditioned_grads: P,
    learning_rate: jnp.float32,
    norm_constraint: jnp.float32 = 0.001,
) -> P:
    """Scale down the preconditioned update so that its norm respects a bound.

    Adapted from KFAC. The squared norm of the step is measured as
    ``<preconditioned_grads, grads> * learning_rate ** 2``. When it exceeds
    ``norm_constraint``, the preconditioned gradients are multiplied by
    ``sqrt(norm_constraint / squared_norm)``; otherwise they are left unscaled.

    Args:
        grads (pytree): the gradients before preconditioning.
        preconditioned_grads (pytree): the preconditioned gradients, expected to share
            the tree structure of grads.
        learning_rate (jnp.float32): step size used to scale the squared norm.
        norm_constraint (jnp.float32, optional): upper bound on the scaled squared
            norm of the step. Defaults to 0.001.

    Returns:
        pytree: preconditioned_grads multiplied by
        min(sqrt(norm_constraint / scaled_squared_norm), 1).
    """
    step_sq_norm = tree_inner_product(preconditioned_grads, grads) * learning_rate ** 2

    # Every device must use the same scaling coefficient, so the squared norm is
    # synchronized across devices (when running under pmap) before it is used.
    step_sq_norm = utils.distribute.pmean_if_pmap(step_sq_norm)

    scale = jnp.minimum(jnp.sqrt(norm_constraint / step_sq_norm), 1)
    return multiply_tree_by_scalar(preconditioned_grads, scale)