params
Routines which handle model parameter updating.
create_grad_energy_update_param_fn(energy_data_val_and_grad, optimizer_apply, get_position_fn, apply_pmap=True, record_param_l1_norm=False)
Create the `update_param_fn` based on the gradient of the total energy.
See :func:`~vmcnet.train.vmc.vmc_loop` for its usage.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `energy_data_val_and_grad` | `Callable` | function which computes the clipped energy value and gradient. Has the signature `(params, x) -> ((expected_energy, auxiliary_energy_data), grad_energy)`, where `auxiliary_energy_data` is the tuple `(expected_variance, local_energies, unclipped_energy, unclipped_variance)` | *required* |
| `optimizer_apply` | `Callable` | applies an update to the parameters. Has signature `(grad_energy, params, optimizer_state, data) -> (new_params, new_optimizer_state)`. | *required* |
| `get_position_fn` | `GetPositionFromData` | gets the walker positions from the MCMC data. | *required* |
| `apply_pmap` | `bool` | whether to apply `jax.pmap` to the walker function. If False, applies `jax.jit`. | `True` |
| `record_param_l1_norm` | `bool` | whether to record the L1 norm of the parameters in the metrics on each update step. | `False` |
Returns:

| Type | Description |
|---|---|
| `Callable` | function which updates the parameters given the current data, params, and optimizer state. The signature of this function is `(data, params, optimizer_state, key) -> (new_params, new_optimizer_state, metrics, key)`. The function is pmapped if `apply_pmap` is True, and jitted if `apply_pmap` is False. |
Source code in vmcnet/updates/params.py
```python
def create_grad_energy_update_param_fn(
    energy_data_val_and_grad: physics.core.ValueGradEnergyFn[P],
    optimizer_apply: Callable[[P, P, S, D], Tuple[P, S]],
    get_position_fn: GetPositionFromData[D],
    apply_pmap: bool = True,
    record_param_l1_norm: bool = False,
) -> UpdateParamFn[P, D, S]:
    """Create the `update_param_fn` based on the gradient of the total energy.

    See :func:`~vmcnet.train.vmc.vmc_loop` for its usage.

    Args:
        energy_data_val_and_grad (Callable): function which computes the clipped energy
            value and gradient. Has the signature
            (params, x)
            -> ((expected_energy, auxiliary_energy_data), grad_energy),
            where auxiliary_energy_data is the tuple
            (expected_variance, local_energies, unclipped_energy, unclipped_variance)
        optimizer_apply (Callable): applies an update to the parameters. Has signature
            (grad_energy, params, optimizer_state, data)
            -> (new_params, new_optimizer_state).
        get_position_fn (GetPositionFromData): gets the walker positions from the MCMC
            data.
        apply_pmap (bool, optional): whether to apply jax.pmap to the walker function.
            If False, applies jax.jit. Defaults to True.
        record_param_l1_norm (bool, optional): whether to record the L1 norm of the
            parameters in the metrics. Defaults to False.

    Returns:
        Callable: function which updates the parameters given the current data, params,
        and optimizer state. The signature of this function is
        (data, params, optimizer_state, key)
        -> (new_params, new_optimizer_state, metrics, key)
        The function is pmapped if apply_pmap is True, and jitted if apply_pmap is
        False.
    """

    def update_param_fn(params, data, optimizer_state, key):
        position = get_position_fn(data)
        energy_data, grad_energy = energy_data_val_and_grad(params, position)
        energy, aux_energy_data = energy_data

        grad_energy = utils.distribute.pmean_if_pmap(grad_energy)
        params, optimizer_state = optimizer_apply(
            grad_energy, params, optimizer_state, data
        )
        metrics = {"energy": energy, "variance": aux_energy_data[0]}
        metrics = _update_metrics_with_noclip(
            aux_energy_data[2], aux_energy_data[3], metrics
        )
        if record_param_l1_norm:
            metrics.update({"param_l1_norm": tree_reduce_l1(params)})
        return params, optimizer_state, metrics, key

    traced_fn = _make_traced_fn_with_single_metrics(update_param_fn, apply_pmap)

    return traced_fn
```
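To make the `optimizer_apply` interface concrete, here is a minimal sketch of a plain-SGD implementation with the four-argument signature expected above. The learning-rate-in-state layout and the `data["position"]` access are illustrative assumptions, not vmcnet's actual conventions.

```python
import jax

# A minimal SGD-style optimizer_apply matching the documented signature
# (grad_energy, params, optimizer_state, data) -> (new_params, new_optimizer_state).
# The trailing `data` argument is unused here but required by the interface.
def sgd_optimizer_apply(grad_energy, params, optimizer_state, data):
    lr = optimizer_state["learning_rate"]  # hypothetical state layout
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grad_energy)
    return new_params, optimizer_state

# Hypothetical wiring; `energy_val_and_grad` and the data pytree are assumptions.
# update_param_fn = create_grad_energy_update_param_fn(
#     energy_val_and_grad,
#     sgd_optimizer_apply,
#     get_position_fn=lambda data: data["position"],
#     apply_pmap=False,  # jax.jit on a single device
# )
# params, opt_state, metrics, key = update_param_fn(params, data, opt_state, key)
```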
create_kfac_update_param_fn(optimizer, damping, get_position_fn, record_param_l1_norm=False)
Create momentum-less KFAC update step function.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `optimizer` | `kfac_ferminet_alpha.Optimizer` | instance of the Optimizer class from kfac_ferminet_alpha | *required* |
| `damping` | `jnp.float32` | damping coefficient | *required* |
| `get_position_fn` | `GetPositionFromData` | function which gets the walker positions from the data. Has signature `data -> Array` | *required* |
| `record_param_l1_norm` | `bool` | whether to record the L1 norm of the parameters in the metrics on each update step. | `False` |
Returns:

| Type | Description |
|---|---|
| `Callable` | function which updates the parameters given the current data, params, and optimizer state. The signature of this function is `(data, params, optimizer_state, key) -> (new_params, new_optimizer_state, metrics, key)` |
Source code in vmcnet/updates/params.py
```python
def create_kfac_update_param_fn(
    optimizer: kfac_ferminet_alpha.Optimizer,
    damping: jnp.float32,
    get_position_fn: GetPositionFromData[D],
    record_param_l1_norm: bool = False,
) -> UpdateParamFn[kfac_opt.Parameters, D, kfac_opt.State]:
    """Create momentum-less KFAC update step function.

    Args:
        optimizer (kfac_ferminet_alpha.Optimizer): instance of the Optimizer class from
            kfac_ferminet_alpha
        damping (jnp.float32): damping coefficient
        get_position_fn (GetPositionFromData): function which gets the walker positions
            from the data. Has signature data -> Array
        record_param_l1_norm (bool, optional): whether to record the L1 norm of the
            parameters in the metrics. Defaults to False.

    Returns:
        Callable: function which updates the parameters given the current data, params,
        and optimizer state. The signature of this function is
        (data, params, optimizer_state, key)
        -> (new_params, new_optimizer_state, metrics, key)
    """
    momentum = jnp.asarray(0.0)
    damping = jnp.asarray(damping)
    if optimizer.multi_device:
        momentum = utils.distribute.replicate_all_local_devices(momentum)
        damping = utils.distribute.replicate_all_local_devices(damping)

    traced_compute_param_norm = _get_traced_compute_param_norm(optimizer.multi_device)

    def update_param_fn(params, data, optimizer_state, key):
        key, subkey = utils.distribute.split_or_psplit_key(key, optimizer.multi_device)
        params, optimizer_state, stats = optimizer.step(
            params=params,
            state=optimizer_state,
            rng=subkey,
            data_iterator=iter([get_position_fn(data)]),
            momentum=momentum,
            damping=damping,
        )
        energy = stats["loss"]
        variance = stats["aux"][0]
        energy_noclip = stats["aux"][2]
        variance_noclip = stats["aux"][3]
        picked_stats = (energy, variance, energy_noclip, variance_noclip)

        if record_param_l1_norm:
            param_l1_norm = traced_compute_param_norm(params)
            picked_stats = picked_stats + (param_l1_norm,)

        stats_to_save = picked_stats
        if optimizer.multi_device:
            stats_to_save = [utils.distribute.get_first(stat) for stat in picked_stats]

        metrics = {"energy": stats_to_save[0], "variance": stats_to_save[1]}
        metrics = _update_metrics_with_noclip(
            stats_to_save[2], stats_to_save[3], metrics
        )
        if record_param_l1_norm:
            metrics.update({"param_l1_norm": stats_to_save[4]})

        return params, optimizer_state, metrics, key

    return update_param_fn
```
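A sketch of how the returned function might be driven directly; the pre-built `kfac_ferminet_alpha` Optimizer and the `data["position"]` pytree layout are assumptions. In practice :func:`~vmcnet.train.vmc.vmc_loop` drives these updates, interleaving them with walker moves.

```python
from vmcnet.updates.params import create_kfac_update_param_fn

# Hypothetical driver loop; assumes `optimizer`, `params`, `optimizer_state`,
# `data`, and `key` already exist and follow the documented signatures.
def run_kfac_steps(optimizer, params, optimizer_state, data, key, n_steps=10):
    update_param_fn = create_kfac_update_param_fn(
        optimizer,
        damping=1e-3,
        get_position_fn=lambda d: d["position"],  # assumed data layout
        record_param_l1_norm=True,
    )
    for _ in range(n_steps):
        params, optimizer_state, metrics, key = update_param_fn(
            params, data, optimizer_state, key
        )
        print(metrics["energy"], metrics["variance"], metrics["param_l1_norm"])
    return params, optimizer_state
```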
create_eval_update_param_fn(local_energy_fn, nchains, get_position_fn, apply_pmap=True, record_local_energies=True, nan_safe=False)
Create a function which simply evaluates the local energies, with no parameter update, clipping, or gradient computation.
Can be used to do simple unclipped MCMC with :func:`~vmcnet.train.vmc.vmc_loop`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `local_energy_fn` | `Callable` | computes local energies Hpsi / psi. Has signature `(params, x) -> (Hpsi / psi)(x)` | *required* |
| `nchains` | `int` | total number of chains across all devices, used to compute a sample variance estimate of the local energy | *required* |
| `get_position_fn` | `GetPositionFromData` | gets the walker positions from the MCMC data. | *required* |
| `apply_pmap` | `bool` | whether to apply `jax.pmap` to the eval function. If False, applies `jax.jit`. | `True` |
| `record_local_energies` | `bool` | whether to include the per-walker local energies themselves in the metrics. | `True` |
| `nan_safe` | `bool` | whether or not to mask local energy nans in the evaluation process. This option should not be used under normal circumstances, as the energy estimates are of unclear validity if nans are masked. However, it can be used to get a coarse estimate of the energy of a wavefunction even if a few walkers are returning nans for their local energies. | `False` |
Returns:

| Type | Description |
|---|---|
| `Callable` | function which evaluates the local energies and averages them, without updating the parameters |
Source code in vmcnet/updates/params.py
```python
def create_eval_update_param_fn(
    local_energy_fn: ModelApply[P],
    nchains: int,
    get_position_fn: GetPositionFromData[D],
    apply_pmap: bool = True,
    record_local_energies: bool = True,
    nan_safe: bool = False,
) -> UpdateParamFn[P, D, OptimizerState]:
    """Create a function which evaluates local energies, with no update/clipping/grad.

    Can be used to do simple unclipped MCMC with :func:`~vmcnet.train.vmc.vmc_loop`.

    Args:
        local_energy_fn (Callable): computes local energies Hpsi / psi. Has signature
            (params, x) -> (Hpsi / psi)(x)
        nchains (int): total number of chains across all devices, used to compute a
            sample variance estimate of the local energy
        get_position_fn (GetPositionFromData): gets the walker positions from the MCMC
            data.
        apply_pmap (bool, optional): whether to apply jax.pmap to the eval function.
            If False, applies jax.jit. Defaults to True.
        record_local_energies (bool, optional): whether to include the per-walker
            local energies themselves in the metrics. Defaults to True.
        nan_safe (bool): whether or not to mask local energy nans in the evaluation
            process. This option should not be used under normal circumstances, as the
            energy estimates are of unclear validity if nans are masked. However,
            it can be used to get a coarse estimate of the energy of a wavefunction
            even if a few walkers are returning nans for their local energies.

    Returns:
        Callable: function which evaluates the local energies and averages them,
        without updating the parameters
    """

    def eval_update_param_fn(params, data, optimizer_state, key):
        local_energies = local_energy_fn(params, get_position_fn(data))
        energy, variance = physics.core.get_statistics_from_local_energy(
            local_energies, nchains, nan_safe=nan_safe
        )
        metrics = {"energy": energy, "variance": variance}
        if record_local_energies:
            metrics.update({"local_energies": local_energies})
        return params, optimizer_state, metrics, key

    traced_fn = _make_traced_fn_with_single_metrics(
        eval_update_param_fn, apply_pmap, {"energy", "variance"}
    )

    return traced_fn
```
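For reference, the statistics computed during evaluation reduce to a mean and sample variance of the local energies over all walkers. Below is a minimal single-device sketch of that computation; the unbiased `(nchains - 1)` normalization and the absence of nan-masking are assumptions about what `physics.core.get_statistics_from_local_energy` does, not its verbatim implementation.

```python
import jax.numpy as jnp

# Sketch of the evaluation statistics: mean local energy over all walkers and
# its sample variance. The (nchains - 1) normalization is an assumed convention.
def local_energy_stats(local_energies: jnp.ndarray, nchains: int):
    energy = jnp.mean(local_energies)
    variance = jnp.sum((local_energies - energy) ** 2) / (nchains - 1)
    return energy, variance
```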
constrain_norm(grads, preconditioned_grads, learning_rate, norm_constraint=0.001)
Constrains the preconditioned norm of the update, adapted from KFAC.
Source code in vmcnet/updates/params.py
```python
def constrain_norm(
    grads: P,
    preconditioned_grads: P,
    learning_rate: jnp.float32,
    norm_constraint: jnp.float32 = 0.001,
) -> P:
    """Constrains the preconditioned norm of the update, adapted from KFAC."""
    sq_norm_grads = tree_inner_product(preconditioned_grads, grads)
    sq_norm_scaled_grads = sq_norm_grads * learning_rate ** 2

    # Sync the norms here, see:
    # https://github.com/deepmind/deepmind-research/blob/30799687edb1abca4953aec507be87ebe63e432d/kfac_ferminet_alpha/optimizer.py#L585
    sq_norm_scaled_grads = utils.distribute.pmean_if_pmap(sq_norm_scaled_grads)

    max_coefficient = jnp.sqrt(norm_constraint / sq_norm_scaled_grads)
    coefficient = jnp.minimum(max_coefficient, 1)
    constrained_grads = multiply_tree_by_scalar(preconditioned_grads, coefficient)

    return constrained_grads
```
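In equations: with gradient $g$, preconditioned gradient $\hat{g}$, and learning rate $\eta$, the code caps the squared norm of the scaled step at the constraint $c$ (`norm_constraint`). Reading $\langle \hat{g}, g \rangle$ as an approximate squared Fisher norm of the step direction is the standard KFAC interpretation, assumed here:

```latex
s = \eta^2 \,\langle \hat{g},\, g \rangle ,
\qquad
\alpha = \min\!\left(\sqrt{\frac{c}{s}},\; 1\right),
\qquad
\text{constrained update} = \alpha\,\hat{g},
```

so that $\alpha^2 s \le c$ always holds: when the proposed step is already small enough ($s \le c$), $\alpha = 1$ and the update passes through unchanged.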