vmc
Main VMC loop.
vmc_loop(params, optimizer_state, data, nchains, nepochs, walker_fn, update_param_fn, key, logdir=None, checkpoint_every=1000, best_checkpoint_every=100, checkpoint_dir='checkpoints', checkpoint_variance_scale=10.0, checkpoint_if_nans=False, only_checkpoint_first_nans=True, record_amplitudes=False, get_amplitude_fn=None, nhistory_max=200)
Main Variational Monte Carlo loop routine.
Variational Monte Carlo (VMC) can be viewed generically as stochastically minimizing a
parameterized variational loss, estimated from Monte Carlo samples of a data
distribution. This function implements that idea at a high level: `walker_fn`
samples the data distribution, and each optimization step is delegated to the
generic function `update_param_fn`.
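As a rough illustration of that sample-then-update structure, the sketch below strips away all checkpointing and runs the same two calls per epoch on a toy problem: fitting a scalar `theta` by stochastic gradient descent on the loss E[(theta - x)^2] with x ~ N(3, 1). Everything here (`toy_walker_fn`, `toy_update_param_fn`, `run_loop`, `split_key`, and the use of an integer seed in place of a JAX PRNGKey) is hypothetical and only meant to show the control flow, not vmcnet's code.

```python
import random
from typing import List, Tuple

LEARNING_RATE = 0.05


def split_key(key: int) -> Tuple[int, int]:
    """Hypothetical stand-in for jax.random.split: derive two new integer seeds."""
    rng = random.Random(key)
    return rng.getrandbits(32), rng.getrandbits(32)


def toy_walker_fn(params: float, data: List[float], key: int):
    """Sample fresh data x ~ N(3, 1). In real VMC the distribution depends on params."""
    key, subkey = split_key(key)
    rng = random.Random(subkey)
    new_data = [rng.gauss(3.0, 1.0) for _ in data]
    return 1.0, new_data, key  # exact sampling, so every "proposal" is accepted


def toy_update_param_fn(params: float, data: List[float], opt_state: int, key: int):
    """One SGD step on the Monte Carlo estimate of the loss E[(theta - x)^2]."""
    losses = [(params - x) ** 2 for x in data]
    energy = sum(losses) / len(losses)  # the loop expects the loss under "energy"
    variance = sum((v - energy) ** 2 for v in losses) / len(losses)
    grad = sum(2.0 * (params - x) for x in data) / len(data)
    metrics = {"energy": energy, "variance": variance}
    return params - LEARNING_RATE * grad, opt_state + 1, metrics, key


def run_loop(params, opt_state, data, nepochs, walker_fn, update_param_fn, key):
    """Sample-then-update skeleton of vmc_loop, with all checkpointing removed."""
    for _ in range(nepochs):
        accept_ratio, data, key = walker_fn(params, data, key)
        params, opt_state, metrics, key = update_param_fn(params, data, opt_state, key)
    return params, opt_state, data, key


params, opt_state, data, key = run_loop(
    0.0, 0, [0.0] * 500, 200, toy_walker_fn, toy_update_param_fn, key=1234
)
print(round(params, 2))  # should be close to 3.0, the minimizer of the loss
```

The key is threaded functionally (returned and passed back in) rather than mutated, mirroring how a JAX PRNG key flows through `vmc_loop`.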
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `params` | pytree-like | model parameters which are trained | required |
| `optimizer_state` | pytree-like | initial state of the optimizer | required |
| `data` | pytree-like | initial data | required |
| `nchains` | int | number of parallel MCMC chains being run. This can be difficult to infer from `data`, depending on its structure, whether it has been pmapped, etc. | required |
| `nepochs` | int | number of parameter updates to do | required |
| `walker_fn` | Callable | function which does a number of walker steps between each parameter update. Has the signature (params, data, key) -> (mean accept prob, new data, new key) | required |
| `update_param_fn` | Callable | function which updates the parameters. Has the signature (params, data, optimizer_state, key) -> (new_params, optimizer_state, dict: metrics, key). If metrics is not None, it must contain at least the entries "energy" and "variance". If metrics is None, no checkpointing is done. | required |
| `key` | PRNGKey | an array with shape (2,) representing a JAX PRNG key, passed to `walker_fn` and `update_param_fn` and threaded through the loop | required |
| `logdir` | str | name of the parent log directory. If None, no checkpointing is done. Defaults to None. | None |
| `checkpoint_every` | int | how often to regularly save checkpoints. If None, checkpoints are only saved when the error-adjusted running average of the energy improves. Defaults to 1000. | 1000 |
| `best_checkpoint_every` | int | limit on how often to save the best checkpoint, even if the energy is improving. When the error-adjusted running average of the energy improves, instead of immediately saving a checkpoint, we hold onto the data from that epoch in memory, and if it is still the best one when we hit an epoch which is a multiple of `best_checkpoint_every`, we save it then. This avoids spending time saving best checkpoints too often while the energy is on a downward trajectory. Defaults to 100. | 100 |
| `checkpoint_dir` | str | name of the subdirectory in which to save the regular checkpoints. These are saved as "logdir/checkpoint_dir/(epoch + 1).npz". Defaults to "checkpoints". | 'checkpoints' |
| `checkpoint_variance_scale` | float | scale of the variance term in the error-adjusted running average of the energy. Higher means the variance is more important, and lower means the energy is more important. See `get_checkpoint_metric`. Defaults to 10.0. | 10.0 |
| `checkpoint_if_nans` | bool | whether to save checkpoints when nan energy values are recorded. Defaults to False. | False |
| `only_checkpoint_first_nans` | bool | whether to checkpoint only the first time nans are encountered, or every time. Useful for capturing a nan checkpoint without risking writing too many checkpoints if the optimization starts to hit nans in most or all epochs after some point. Only relevant if `checkpoint_if_nans` is True. Defaults to True. | True |
| `nhistory_max` | int | maximum number of entries to keep in the running histories of the energy and variance. Defaults to 200. | 200 |
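To make the `walker_fn` and `update_param_fn` contracts concrete, here is a hedged sketch of a pair of functions using the argument orders that `vmc_loop` actually calls them with, for a toy 1D harmonic oscillator. The problem, the trial wavefunction psi_a(x) = exp(-a x^2), the fixed step size and learning rate, and the use of a NumPy `Generator` as a stand-in for a JAX PRNGKey are all illustrative assumptions, not part of vmcnet.

```python
import numpy as np

# Toy problem (illustrative assumption): 1D harmonic oscillator
# H = -1/2 d^2/dx^2 + 1/2 x^2, trial wavefunction psi_a(x) = exp(-a x^2).
# The exact ground state is a = 0.5, with energy 0.5.
NSTEPS = 10          # Metropolis steps per epoch
STEP_SIZE = 1.0      # std of the Gaussian proposal
LEARNING_RATE = 0.1


def local_energy(a, x):
    # E_L = (H psi) / psi = a + x^2 (1/2 - 2 a^2)
    return a + x**2 * (0.5 - 2.0 * a**2)


def walker_fn(params, data, key):
    """(params, data, key) -> (mean accept prob, new data, new key)."""
    a = params["a"]
    accept_total = 0.0
    for _ in range(NSTEPS):
        proposal = data + STEP_SIZE * key.normal(size=data.shape)
        # Metropolis ratio |psi(x')|^2 / |psi(x)|^2 = exp(-2 a (x'^2 - x^2))
        log_ratio = -2.0 * a * (proposal**2 - data**2)
        accept = key.uniform(size=data.shape) < np.exp(np.minimum(log_ratio, 0.0))
        data = np.where(accept, proposal, data)
        accept_total += accept.mean()
    return accept_total / NSTEPS, data, key


def update_param_fn(params, data, optimizer_state, key):
    """(params, data, optimizer_state, key) -> (new_params, optimizer_state, metrics, key)."""
    a = params["a"]
    e_loc = local_energy(a, data)
    energy = e_loc.mean()
    # VMC gradient for a real wavefunction: dE/da = 2 <(E_L - E) d(log psi)/da>,
    # with d(log psi)/da = -x^2.
    grad = 2.0 * np.mean((e_loc - energy) * (-data**2))
    metrics = {"energy": float(energy), "variance": float(e_loc.var())}
    return {"a": a - LEARNING_RATE * grad}, optimizer_state + 1, metrics, key


# A NumPy Generator stands in for the JAX PRNGKey; it advances in place and is
# returned only to match the documented signatures.
key = np.random.default_rng(0)
params, optimizer_state, data = {"a": 1.0}, 0, np.zeros(1000)
for epoch in range(200):
    accept_ratio, data, key = walker_fn(params, data, key)
    params, optimizer_state, metrics, key = update_param_fn(
        params, data, optimizer_state, key
    )
print(params["a"], metrics["energy"], metrics["variance"])
```

Starting from `a = 1.0`, the parameter should drift toward the exact value 0.5, where the local energy is constant, so the reported variance shrinks toward zero.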
Returns:
| Type | Description |
|---|---|
| `Tuple[~P, ~S, ~D, jax._src.prng.PRNGKeyArray]` | A tuple of (trained parameters, final optimizer state, final data, final key). These have the same structure as (params, optimizer_state, data, key). |
Source code in vmcnet/train/vmc.py
```python
def vmc_loop(
    params: P,
    optimizer_state: S,
    data: D,
    nchains: int,
    nepochs: int,
    walker_fn: WalkerFn[P, D],
    update_param_fn: UpdateParamFn[P, D, S],
    key: PRNGKey,
    logdir: Optional[str] = None,
    checkpoint_every: Optional[int] = 1000,
    best_checkpoint_every: Optional[int] = 100,
    checkpoint_dir: str = "checkpoints",
    checkpoint_variance_scale: float = 10.0,
    checkpoint_if_nans: bool = False,
    only_checkpoint_first_nans: bool = True,
    record_amplitudes: bool = False,
    get_amplitude_fn: Optional[GetAmplitudeFromData[D]] = None,
    nhistory_max: int = 200,
) -> Tuple[P, S, D, PRNGKey]:
    """Main Variational Monte Carlo loop routine.

    Variational Monte Carlo (VMC) can be generically viewed as minimizing a
    parameterized variational loss stochastically by sampling over a data distribution
    via Monte Carlo sampling. This function implements this idea at a high level, using
    a walker_fn to sample the data distribution, and passing the optimization step to a
    generic function `update_param_fn`.

    Args:
        params (pytree-like): model parameters which are trained
        optimizer_state (pytree-like): initial state of the optimizer
        data (pytree-like): initial data
        nchains (int): number of parallel MCMC chains being run. This can be difficult
            to infer from data, depending on the structure of data, whether data has
            been pmapped, etc.
        nepochs (int): number of parameter updates to do
        walker_fn (Callable): function which does a number of walker steps between each
            parameter update. Has the signature
            (params, data, key) -> (mean accept prob, new data, new key)
        update_param_fn (Callable): function which updates the parameters. Has signature
            (params, data, optimizer_state, key)
            -> (new_params, optimizer_state, dict: metrics, key).
            If metrics is not None, it is required to have the entries "energy" and
            "variance" at a minimum. If metrics is None, no checkpointing is done.
        key (PRNGKey): an array with shape (2,) representing a jax PRNG key, passed
            to walker_fn and update_param_fn and threaded through the loop
        logdir (str, optional): name of parent log directory. If None, no checkpointing
            is done. Defaults to None.
        checkpoint_every (int, optional): how often to regularly save checkpoints. If
            None, checkpoints are only saved when the error-adjusted running avg of the
            energy improves. Defaults to 1000.
        best_checkpoint_every (int, optional): limit on how often to save best
            checkpoint, even if energy is improving. When the error-adjusted running avg
            of the energy improves, instead of immediately saving a checkpoint, we hold
            onto the data from that epoch in memory, and if it's still the best one when
            we hit an epoch which is a multiple of `best_checkpoint_every`, we save it
            then. This ensures we don't waste time saving best checkpoints too often
            when the energy is on a downward trajectory (as we hope it often is!).
            Defaults to 100.
        checkpoint_dir (str, optional): name of subdirectory to save the regular
            checkpoints. These are saved as "logdir/checkpoint_dir/(epoch + 1).npz".
            Defaults to "checkpoints".
        checkpoint_variance_scale (float, optional): scale of the variance term in the
            error-adjusted running avg of the energy. Higher means the variance is more
            important, and lower means the energy is more important. See
            :func:`~vmctrain.train.vmc.get_checkpoint_metric`. Defaults to 10.0.
        checkpoint_if_nans (bool, optional): whether to save checkpoints when
            nan energy values are recorded. Defaults to False.
        only_checkpoint_first_nans (bool, optional): whether to checkpoint only the
            first time nans are encountered, or every time. Useful to capture a nan
            checkpoint without risking writing too many checkpoints if the optimization
            starts to hit nans in most or all epochs after some point. Only relevant if
            checkpoint_if_nans is True. Defaults to True.
        nhistory_max (int, optional): How much history to keep in the running histories
            of the energy and variance. Defaults to 200.

    Returns:
        A tuple of (trained parameters, final optimizer state, final data, final key).
        These are the same structure as (params, optimizer_state, data, key).
    """
    (
        checkpoint_dir,
        checkpoint_metric,
        running_energy_and_variance,
        best_checkpoint_data,
        saved_nans_checkpoint,
    ) = utils.checkpoint.initialize_checkpointing(
        checkpoint_dir, nhistory_max, logdir, checkpoint_every
    )

    with CheckpointWriter() as checkpoint_writer, MetricsWriter() as metrics_writer:
        for epoch in range(nepochs):
            # Save state for checkpointing at the start of the epoch for two reasons:
            # 1. To save the model that generates the best energy and variance metrics,
            # rather than the model one parameter UPDATE after the best metrics.
            # 2. To ensure a fully consistent state can be reloaded from a checkpoint,
            # and the exact subsequent behavior can be reproduced (if run on same machine).
            old_params = params
            old_state = optimizer_state
            old_data = data
            old_key = key

            accept_ratio, data, key = walker_fn(params, data, key)

            params, optimizer_state, metrics, key = update_param_fn(
                params, data, optimizer_state, key
            )

            if metrics is None:  # don't checkpoint if no metrics to checkpoint
                continue

            metrics["accept_ratio"] = accept_ratio

            (
                checkpoint_metric,
                checkpoint_str,
                best_checkpoint_data,
                saved_nans_checkpoint,
            ) = utils.checkpoint.save_metrics_and_handle_checkpoints(
                epoch,
                old_params,
                params,
                old_state,
                old_data,
                data,
                old_key,
                metrics,
                nchains,
                running_energy_and_variance,
                checkpoint_writer,
                metrics_writer,
                checkpoint_metric,
                logdir=logdir,
                variance_scale=checkpoint_variance_scale,
                checkpoint_every=checkpoint_every,
                best_checkpoint_every=best_checkpoint_every,
                best_checkpoint_data=best_checkpoint_data,
                checkpoint_dir=checkpoint_dir,
                checkpoint_if_nans=checkpoint_if_nans,
                only_checkpoint_first_nans=only_checkpoint_first_nans,
                saved_nans_checkpoint=saved_nans_checkpoint,
                record_amplitudes=record_amplitudes,
                get_amplitude_fn=get_amplitude_fn,
            )
            utils.checkpoint.log_vmc_loop_state(epoch, metrics, checkpoint_str)

            # TODO: add flag which gives a way to break out of the VMC loop when the
            # first nan has been hit, to keep jobs from running past useful output in
            # some cases

        utils.checkpoint.finish_checkpointing(
            checkpoint_writer, best_checkpoint_data, logdir
        )

    return params, optimizer_state, data, key
```
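The checkpoint bookkeeping delegated to `utils.checkpoint.save_metrics_and_handle_checkpoints` follows the rules spelled out in the docstring. The sketch below is a hypothetical, simplified restatement of those rules (regular checkpoints, deferred best checkpoints, and nan checkpoints) so that their interaction is easier to see. It is not vmcnet's implementation: the error-adjusted metric is passed in precomputed, "saving" only records `(epoch + 1, tag)` pairs, the `(epoch + 1) % n == 0` tests are assumed, and flushing a pending best checkpoint in `finish()` is an assumption about what `finish_checkpointing` does.

```python
import math
from typing import Any, List, Optional, Tuple


class CheckpointPolicy:
    """Hypothetical sketch of vmc_loop's checkpoint rules (not vmcnet's code)."""

    def __init__(
        self,
        checkpoint_every: Optional[int] = 1000,
        best_checkpoint_every: Optional[int] = 100,
        checkpoint_if_nans: bool = False,
        only_checkpoint_first_nans: bool = True,
    ):
        self.checkpoint_every = checkpoint_every
        self.best_checkpoint_every = best_checkpoint_every
        self.checkpoint_if_nans = checkpoint_if_nans
        self.only_checkpoint_first_nans = only_checkpoint_first_nans
        self.best_metric = math.inf
        self.pending_best: Optional[Tuple[int, Any]] = None  # (epoch, state) held in memory
        self.saved_nans = False
        self.saved: List[Tuple[int, str]] = []  # stand-in for files written to disk

    def step(self, epoch: int, metric: float, energy: float, state: Any) -> None:
        # Regular checkpoints, labeled epoch + 1 like "(epoch + 1).npz".
        if self.checkpoint_every is not None and (epoch + 1) % self.checkpoint_every == 0:
            self.saved.append((epoch + 1, "regular"))

        # Nan checkpoints: optionally only the first occurrence.
        if self.checkpoint_if_nans and math.isnan(energy):
            if not (self.only_checkpoint_first_nans and self.saved_nans):
                self.saved.append((epoch + 1, "nans"))
                self.saved_nans = True

        # Best checkpoints: remember an improvement, but defer the write.
        # A nan metric never compares less than best_metric, so it is never "best".
        if metric < self.best_metric:
            self.best_metric = metric
            self.pending_best = (epoch, state)
        if (
            self.best_checkpoint_every is not None
            and (epoch + 1) % self.best_checkpoint_every == 0
        ):
            self._flush_best()

    def _flush_best(self) -> None:
        if self.pending_best is not None:
            best_epoch, _ = self.pending_best
            self.saved.append((best_epoch + 1, "best"))
            self.pending_best = None

    def finish(self) -> None:
        # Assumed: any still-pending best checkpoint is written at the end.
        self._flush_best()


# Energy improves every epoch except for a burst of nans at epochs 150-160.
policy = CheckpointPolicy(checkpoint_if_nans=True)
for epoch in range(250):
    energy = math.nan if 150 <= epoch <= 160 else 1.0 / (epoch + 1)
    policy.step(epoch, metric=energy, energy=energy, state=None)
policy.finish()
print(policy.saved)  # [(100, 'best'), (151, 'nans'), (200, 'best'), (250, 'best')]
```

Even though the metric improves at every non-nan epoch, only three best checkpoints are written, and the nan burst produces a single checkpoint because `only_checkpoint_first_nans` defaults to True.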