parse_config
Get update functions from ConfigDicts.
get_update_fn_and_init_optimizer(log_psi_apply, vmc_config, params, data, get_position_fn, energy_data_val_and_grad, key, apply_pmap=True)
Get an update function and initialize optimizer state from the vmc configuration.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| log_psi_apply | Callable | computes log\|psi(x)\|, where the signature of this function is (params, x) -> log\|psi(x)\| | required |
| vmc_config | ConfigDict | configuration for VMC | required |
| params | pytree | params with which to initialize optimizer state | required |
| data | pytree | data with which to initialize optimizer state | required |
| get_position_fn | Callable | function which gets the position array from the data | required |
| energy_data_val_and_grad | Callable | function which computes the clipped energy value and gradient. Has the signature (params, x) -> ((expected_energy, auxiliary_energy_data), grad_energy), where auxiliary_energy_data is the tuple (expected_variance, local_energies, unclipped_energy, unclipped_variance) | required |
| key | PRNGKey | PRNGKey with which to initialize optimizer state | required |
| apply_pmap | bool | whether to pmap the optimizer steps. Defaults to True. | True |
Exceptions:

| Type | Description |
| --- | --- |
| ValueError | A non-supported optimizer type is requested. Currently, KFAC, Adam, SGD, and SR (with either Adam or SGD) are supported. |
Returns:

| Type | Description |
| --- | --- |
| (UpdateParamFn, OptimizerState, PRNGKey) | update param function with signature (params, data, optimizer_state, key) -> (new params, new state, metrics, new key), initial optimizer state, and PRNGKey |
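In vmcnet, the energy_data_val_and_grad callable required here (and by every constructor below) is built elsewhere with energy clipping applied (see the physics.core.ValueGradEnergyFn annotation in the source). Purely to illustrate the documented signature (params, x) -> ((expected_energy, auxiliary_energy_data), grad_energy), here is a minimal, unclipped stub using jax.value_and_grad with has_aux=True and a hypothetical local_energy_fn; the params structure is an assumption for the example only:

```python
import jax
import jax.numpy as jnp


def local_energy_fn(params, x):
    # Hypothetical local-energy function; assumes params is a dict with a "shift"
    # leaf and x has shape (nwalkers, d). For illustration only.
    return jnp.sum((x - params["shift"]) ** 2, axis=-1)


def energy_data_val_and_grad(params, x):
    """Toy stand-in matching the documented signature (no clipping is applied)."""

    def total_energy(p):
        local_energies = local_energy_fn(p, x)
        expected_energy = jnp.mean(local_energies)
        expected_variance = jnp.var(local_energies)
        # aux = (expected_variance, local_energies, unclipped_energy, unclipped_variance)
        aux = (expected_variance, local_energies, expected_energy, expected_variance)
        return expected_energy, aux

    # Returns ((expected_energy, aux), grad_energy), matching the documented shape.
    return jax.value_and_grad(total_energy, has_aux=True)(params)
```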
Source code in vmcnet/updates/parse_config.py
```python
def get_update_fn_and_init_optimizer(
log_psi_apply: ModelApply[P],
vmc_config: ConfigDict,
params: P,
data: D,
get_position_fn: GetPositionFromData[D],
energy_data_val_and_grad: physics.core.ValueGradEnergyFn[P],
key: PRNGKey,
apply_pmap: bool = True,
) -> Tuple[UpdateParamFn[P, D, OptimizerState], OptimizerState, PRNGKey]:
"""Get an update function and initialize optimizer state from the vmc configuration.
Args:
log_psi_apply (Callable): computes log|psi(x)|, where the signature of this
function is (params, x) -> log|psi(x)|
vmc_config (ConfigDict): configuration for VMC
params (pytree): params with which to initialize optimizer state
data (pytree): data with which to initialize optimizer state
get_position_fn (Callable): function which gets the position array from the data
energy_data_val_and_grad (Callable): function which computes the clipped energy
value and gradient. Has the signature
(params, x)
-> ((expected_energy, auxiliary_energy_data), grad_energy),
where auxiliary_energy_data is the tuple
(expected_variance, local_energies, unclipped_energy, unclipped_variance)
key (PRNGKey): PRNGKey with which to initialize optimizer state
apply_pmap (bool, optional): whether to pmap the optimizer steps. Defaults to
True.
Raises:
ValueError: A non-supported optimizer type is requested. Currently, KFAC, Adam,
SGD, and SR (with either Adam or SGD) is supported.
Returns:
(UpdateParamFn, OptimizerState, PRNGKey):
update param function with signature
(params, data, optimizer_state, key)
-> (new params, new state, metrics, new key),
initial optimizer state, and
PRNGKey
"""
learning_rate_schedule = _get_learning_rate_schedule(
vmc_config.optimizer[vmc_config.optimizer_type]
)
if vmc_config.optimizer_type == "kfac":
return get_kfac_update_fn_and_state(
params,
data,
get_position_fn,
energy_data_val_and_grad,
key,
learning_rate_schedule,
vmc_config.optimizer.kfac,
vmc_config.record_param_l1_norm,
apply_pmap=apply_pmap,
)
elif vmc_config.optimizer_type == "sgd":
(update_param_fn, optimizer_state,) = get_sgd_update_fn_and_state(
params,
get_position_fn,
energy_data_val_and_grad,
learning_rate_schedule,
vmc_config.optimizer.sgd,
vmc_config.record_param_l1_norm,
apply_pmap=apply_pmap,
)
return update_param_fn, optimizer_state, key
elif vmc_config.optimizer_type == "adam":
(update_param_fn, optimizer_state,) = get_adam_update_fn_and_state(
params,
get_position_fn,
energy_data_val_and_grad,
learning_rate_schedule,
vmc_config.optimizer.adam,
vmc_config.record_param_l1_norm,
apply_pmap=apply_pmap,
)
return update_param_fn, optimizer_state, key
elif vmc_config.optimizer_type == "sr":
(update_param_fn, optimizer_state,) = get_sr_update_fn_and_state(
log_psi_apply,
params,
get_position_fn,
energy_data_val_and_grad,
learning_rate_schedule,
vmc_config.optimizer.sr,
vmc_config.optimizer[vmc_config.optimizer.sr.descent_type],
vmc_config.record_param_l1_norm,
apply_pmap=apply_pmap,
nan_safe=vmc_config.nan_safe,
)
return update_param_fn, optimizer_state, key
else:
raise ValueError(
"Requested optimizer type not supported; {} was requested".format(
vmc_config.optimizer_type
)
        )
```
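A hedged sketch of a call site follows. It assumes vmc_config and the model/physics objects (params, data, key, log_psi_apply, get_position_fn, energy_data_val_and_grad) were already built from vmcnet's ConfigDicts and setup code; the loop only demonstrates the documented signature of the returned update function:

```python
# Sketch only: the arguments below are assumed to have been constructed earlier.
update_param_fn, optimizer_state, key = get_update_fn_and_init_optimizer(
    log_psi_apply,
    vmc_config,
    params,
    data,
    get_position_fn,
    energy_data_val_and_grad,
    key,
    apply_pmap=False,  # single-device example
)

for _ in range(10):  # a handful of illustrative update steps
    # (in a real VMC loop, new walker data would be sampled between updates)
    params, optimizer_state, metrics, key = update_param_fn(
        params, data, optimizer_state, key
    )
```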
get_kfac_update_fn_and_state(params, data, get_position_fn, energy_data_val_and_grad, key, learning_rate_schedule, optimizer_config, record_param_l1_norm=False, apply_pmap=True)
Get an update param function, initial state, and key for KFAC.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | pytree | params with which to initialize optimizer state | required |
| data | pytree | data with which to initialize optimizer state | required |
| get_position_fn | Callable | function which gets the position array from the data | required |
| energy_data_val_and_grad | Callable | function which computes the clipped energy value and gradient. Has the signature (params, x) -> ((expected_energy, auxiliary_energy_data), grad_energy), where auxiliary_energy_data is the tuple (expected_variance, local_energies, unclipped_energy, unclipped_variance) | required |
| key | PRNGKey | PRNGKey with which to initialize optimizer state | required |
| learning_rate_schedule | Callable | function which returns a learning rate from epoch number. Has signature epoch -> learning_rate | required |
| optimizer_config | ConfigDict | configuration for KFAC | required |
| record_param_l1_norm | bool | whether to record the L1 norm of the parameters in the metrics. Defaults to False. | False |
| apply_pmap | bool | whether to pmap the optimizer steps. Defaults to True. | True |
Returns:

| Type | Description |
| --- | --- |
| (UpdateParamFn, kfac_opt.State, PRNGKey) | update param function with signature (params, data, optimizer_state, key) -> (new params, new state, metrics, new key), initial optimizer state, and PRNGKey |
Source code in vmcnet/updates/parse_config.py
```python
def get_kfac_update_fn_and_state(
params: P,
data: D,
get_position_fn: GetPositionFromData[D],
energy_data_val_and_grad: physics.core.ValueGradEnergyFn[P],
key: PRNGKey,
learning_rate_schedule: Callable[[int], jnp.float32],
optimizer_config: ConfigDict,
record_param_l1_norm: bool = False,
apply_pmap: bool = True,
) -> Tuple[UpdateParamFn[P, D, kfac_opt.State], kfac_opt.State, PRNGKey]:
"""Get an update param function, initial state, and key for KFAC.
Args:
params (pytree): params with which to initialize optimizer state
data (pytree): data with which to initialize optimizer state
get_position_fn (Callable): function which gets the position array from the data
energy_data_val_and_grad (Callable): function which computes the clipped energy
value and gradient. Has the signature
(params, x)
-> ((expected_energy, auxiliary_energy_data), grad_energy),
where auxiliary_energy_data is the tuple
(expected_variance, local_energies, unclipped_energy, unclipped_variance)
key (PRNGKey): PRNGKey with which to initialize optimizer state
learning_rate_schedule (Callable): function which returns a learning rate from
epoch number. Has signature epoch -> learning_rate
optimizer_config (ConfigDict): configuration for KFAC
record_param_l1_norm (bool, optional): whether to record the L1 norm of the
parameters in the metrics. Defaults to False.
apply_pmap (bool, optional): whether to pmap the optimizer steps. Defaults to
True.
Returns:
(UpdateParamFn, kfac_opt.State, PRNGKey):
update param function with signature
(params, data, optimizer_state, key)
-> (new params, new state, metrics, new key),
initial optimizer state, and
PRNGKey
"""
optimizer = kfac_ferminet_alpha.Optimizer(
energy_data_val_and_grad,
l2_reg=optimizer_config.l2_reg,
norm_constraint=optimizer_config.norm_constraint,
value_func_has_aux=True,
learning_rate_schedule=learning_rate_schedule,
curvature_ema=optimizer_config.curvature_ema,
inverse_update_period=optimizer_config.inverse_update_period,
min_damping=optimizer_config.min_damping,
num_burnin_steps=0,
register_only_generic=optimizer_config.register_only_generic,
estimation_mode=optimizer_config.estimation_mode,
multi_device=apply_pmap,
pmap_axis_name=utils.distribute.PMAP_AXIS_NAME,
)
key, subkey = utils.distribute.split_or_psplit_key(key, apply_pmap)
optimizer_state = optimizer.init(params, subkey, get_position_fn(data))
update_param_fn = create_kfac_update_param_fn(
optimizer,
optimizer_config.damping,
pacore.get_position_from_data,
record_param_l1_norm=record_param_l1_norm,
)
    return update_param_fn, optimizer_state, key
```
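Each constructor on this page takes a learning_rate_schedule with the documented signature epoch -> learning_rate; the private _get_learning_rate_schedule helper that builds it from the ConfigDict is not shown here. A minimal hand-rolled schedule of the right shape might look like the following (the hyperparameter names are illustrative, not vmcnet's):

```python
import jax.numpy as jnp


def make_learning_rate_schedule(base_rate=1e-4, decay_time=1e4):
    # Inverse-time decay with the documented epoch -> learning_rate signature.
    def learning_rate_schedule(epoch):
        return jnp.asarray(base_rate / (1.0 + epoch / decay_time), dtype=jnp.float32)

    return learning_rate_schedule
```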
get_adam_update_fn_and_state(params, get_position_fn, energy_data_val_and_grad, learning_rate_schedule, optimizer_config, record_param_l1_norm=False, apply_pmap=True)
Get an update param function and initial state for Adam.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | pytree | params with which to initialize optimizer state | required |
| get_position_fn | Callable | function which gets the position array from the data | required |
| energy_data_val_and_grad | Callable | function which computes the clipped energy value and gradient. Has the signature (params, x) -> ((expected_energy, auxiliary_energy_data), grad_energy), where auxiliary_energy_data is the tuple (expected_variance, local_energies, unclipped_energy, unclipped_variance) | required |
| learning_rate_schedule | Callable | function which returns a learning rate from epoch number. Has signature epoch -> learning_rate | required |
| optimizer_config | ConfigDict | configuration for Adam | required |
| record_param_l1_norm | bool | whether to record the L1 norm of the parameters in the metrics. Defaults to False. | False |
| apply_pmap | bool | whether to pmap the optimizer steps. Defaults to True. | True |
Returns:

| Type | Description |
| --- | --- |
| (UpdateParamFn, optax.OptState) | update param function with signature (params, data, optimizer_state, key) -> (new params, new state, metrics, new key), and initial optimizer state |
Source code in vmcnet/updates/parse_config.py
```python
def get_adam_update_fn_and_state(
params: P,
get_position_fn: GetPositionFromData[D],
energy_data_val_and_grad: physics.core.ValueGradEnergyFn[P],
learning_rate_schedule: Callable[[int], jnp.float32],
optimizer_config: ConfigDict,
record_param_l1_norm: bool = False,
apply_pmap: bool = True,
) -> Tuple[UpdateParamFn[P, D, optax.OptState], optax.OptState]:
"""Get an update param function and initial state for Adam.
Args:
params (pytree): params with which to initialize optimizer state
get_position_fn (Callable): function which gets the position array from the data
energy_data_val_and_grad (Callable): function which computes the clipped energy
value and gradient. Has the signature
(params, x)
-> ((expected_energy, auxiliary_energy_data), grad_energy),
where auxiliary_energy_data is the tuple
(expected_variance, local_energies, unclipped_energy, unclipped_variance)
learning_rate_schedule (Callable): function which returns a learning rate from
epoch number. Has signature epoch -> learning_rate
optimizer_config (ConfigDict): configuration for Adam
record_param_l1_norm (bool, optional): whether to record the L1 norm of the
parameters in the metrics. Defaults to False.
apply_pmap (bool, optional): whether to pmap the optimizer steps. Defaults to
True.
Returns:
(UpdateParamFn, optax.OptState):
update param function with signature
(params, data, optimizer_state, key)
-> (new params, new state, metrics, new key), and
initial optimizer state
"""
optimizer = _get_adam_optax_optimizer(learning_rate_schedule, optimizer_config)
return _get_optax_update_fn_and_state(
optimizer,
params,
get_position_fn,
energy_data_val_and_grad,
record_param_l1_norm,
apply_pmap,
    )
```
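The Adam path delegates to the private helpers _get_adam_optax_optimizer and _get_optax_update_fn_and_state, whose bodies are not listed on this page. Building an optax Adam transformation from a learning rate schedule presumably looks roughly like the sketch below; the hyperparameter names are optax's, not necessarily the fields of vmcnet's Adam ConfigDict:

```python
import optax


def make_adam_optimizer(learning_rate_schedule, b1=0.9, b2=0.999, eps=1e-8):
    # optax accepts either a float or a schedule (step -> learning rate) here.
    return optax.adam(learning_rate=learning_rate_schedule, b1=b1, b2=b2, eps=eps)
```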
get_sgd_update_fn_and_state(params, get_position_fn, energy_data_val_and_grad, learning_rate_schedule, optimizer_config, record_param_l1_norm=False, apply_pmap=True)
Get an update param function and initial state for SGD.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| params | pytree | params with which to initialize optimizer state | required |
| get_position_fn | Callable | function which gets the position array from the data | required |
| energy_data_val_and_grad | Callable | function which computes the clipped energy value and gradient. Has the signature (params, x) -> ((expected_energy, auxiliary_energy_data), grad_energy), where auxiliary_energy_data is the tuple (expected_variance, local_energies, unclipped_energy, unclipped_variance) | required |
| learning_rate_schedule | Callable | function which returns a learning rate from epoch number. Has signature epoch -> learning_rate | required |
| optimizer_config | ConfigDict | configuration for SGD | required |
| record_param_l1_norm | bool | whether to record the L1 norm of the parameters in the metrics. Defaults to False. | False |
| apply_pmap | bool | whether to pmap the optimizer steps. Defaults to True. | True |
Returns:

| Type | Description |
| --- | --- |
| (UpdateParamFn, optax.OptState) | update param function with signature (params, data, optimizer_state, key) -> (new params, new state, metrics, new key), and initial optimizer state |
Source code in vmcnet/updates/parse_config.py
```python
def get_sgd_update_fn_and_state(
params: P,
get_position_fn: GetPositionFromData[D],
energy_data_val_and_grad: physics.core.ValueGradEnergyFn[P],
learning_rate_schedule: Callable[[int], jnp.float32],
optimizer_config: ConfigDict,
record_param_l1_norm: bool = False,
apply_pmap: bool = True,
) -> Tuple[UpdateParamFn[P, D, optax.OptState], optax.OptState]:
"""Get an update param function and initial state for SGD.
Args:
params (pytree): params with which to initialize optimizer state
get_position_fn (Callable): function which gets the position array from the data
energy_data_val_and_grad (Callable): function which computes the clipped energy
value and gradient. Has the signature
(params, x)
-> ((expected_energy, auxiliary_energy_data), grad_energy),
where auxiliary_energy_data is the tuple
(expected_variance, local_energies, unclipped_energy, unclipped_variance)
learning_rate_schedule (Callable): function which returns a learning rate from
epoch number. Has signature epoch -> learning_rate
optimizer_config (ConfigDict): configuration for SGD
record_param_l1_norm (bool, optional): whether to record the L1 norm of the
parameters in the metrics. Defaults to False.
apply_pmap (bool, optional): whether to pmap the optimizer steps. Defaults to
True.
Returns:
(UpdateParamFn, optax.OptState):
update param function with signature
(params, data, optimizer_state, key)
-> (new params, new state, metrics, new key), and
initial optimizer state
"""
optimizer = _get_sgd_optax_optimizer(learning_rate_schedule, optimizer_config)
return _get_optax_update_fn_and_state(
optimizer,
params,
get_position_fn,
energy_data_val_and_grad,
record_param_l1_norm,
apply_pmap,
    )
```
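Both the Adam and SGD constructors return an optax.OptState along with the update function built by the shared private helper. Under the hood this presumably follows optax's standard init/update/apply_updates pattern, sketched below on a toy parameter pytree (the real helper additionally threads in the energy gradient, pmapping, and metrics):

```python
import jax.numpy as jnp
import optax

params = {"w": jnp.ones(3)}      # toy parameter pytree
grad = {"w": jnp.full(3, 0.1)}   # toy gradient with the same tree structure

optimizer = optax.sgd(learning_rate=1e-3, momentum=0.9)
opt_state = optimizer.init(params)                              # initial optimizer state

updates, opt_state = optimizer.update(grad, opt_state, params)  # one descent step
params = optax.apply_updates(params, updates)
```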
get_sr_update_fn_and_state(log_psi_apply, params, get_position_fn, energy_data_val_and_grad, learning_rate_schedule, optimizer_config, descent_config, record_param_l1_norm=False, apply_pmap=True, nan_safe=True)
Get an update param function and initial state for stochastic reconfiguration.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| log_psi_apply | Callable | computes log\|psi(x)\|, where the signature of this function is (params, x) -> log\|psi(x)\| | required |
| params | pytree | params with which to initialize optimizer state | required |
| get_position_fn | Callable | function which gets the position array from the data | required |
| energy_data_val_and_grad | Callable | function which computes the clipped energy value and gradient. Has the signature (params, x) -> ((expected_energy, auxiliary_energy_data), grad_energy), where auxiliary_energy_data is the tuple (expected_variance, local_energies, unclipped_energy, unclipped_variance) | required |
| learning_rate_schedule | Callable | function which returns a learning rate from epoch number. Has signature epoch -> learning_rate | required |
| optimizer_config | ConfigDict | configuration for stochastic reconfiguration | required |
| descent_config | ConfigDict | configuration for the gradient descent-like method used to apply the preconditioned updates | required |
| record_param_l1_norm | bool | whether to record the L1 norm of the parameters in the metrics. Defaults to False. | False |
| apply_pmap | bool | whether to pmap the optimizer steps. Defaults to True. | True |
| nan_safe | bool | whether the mean function used when centering the Jacobian of log\|psi(x)\| during the Fisher matvec is nan-safe. Defaults to True. | True |
Exceptions:

| Type | Description |
| --- | --- |
| ValueError | A non-supported descent type is requested. Currently only Adam and SGD are supported. |
Returns:

| Type | Description |
| --- | --- |
| (UpdateParamFn, optax.OptState) | update param function with signature (params, data, optimizer_state, key) -> (new params, new state, metrics, new key), and initial optimizer state |
Source code in vmcnet/updates/parse_config.py
```python
def get_sr_update_fn_and_state(
log_psi_apply: ModelApply[P],
params: P,
get_position_fn: GetPositionFromData[D],
energy_data_val_and_grad: physics.core.ValueGradEnergyFn[P],
learning_rate_schedule: Callable[[int], jnp.float32],
optimizer_config: ConfigDict,
descent_config: ConfigDict,
record_param_l1_norm: bool = False,
apply_pmap: bool = True,
nan_safe: bool = True,
) -> Tuple[UpdateParamFn[P, D, optax.OptState], optax.OptState]:
"""Get an update param function and initial state for stochastic reconfiguration.
Args:
log_psi_apply (Callable): computes log|psi(x)|, where the signature of this
function is (params, x) -> log|psi(x)|
params (pytree): params with which to initialize optimizer state
get_position_fn (Callable): function which gets the position array from the data
energy_data_val_and_grad (Callable): function which computes the clipped energy
value and gradient. Has the signature
(params, x)
-> ((expected_energy, auxiliary_energy_data), grad_energy),
where auxiliary_energy_data is the tuple
(expected_variance, local_energies, unclipped_energy, unclipped_variance)
learning_rate_schedule (Callable): function which returns a learning rate from
epoch number. Has signature epoch -> learning_rate
optimizer_config (ConfigDict): configuration for stochastic reconfiguration
descent_config (ConfigDict): configuration for the gradient descent-like method
used to apply the preconditioned updates
record_param_l1_norm (bool, optional): whether to record the L1 norm of the
parameters in the metrics. Defaults to False.
apply_pmap (bool, optional): whether to pmap the optimizer steps. Defaults to
True.
nan_safe (bool, optional): whether the mean function used when centering the
Jacobian of log|psi(x)| during the Fisher matvec is nan-safe. Defaults to
True.
Raises:
ValueError: A non-supported descent type is requested. Currently only Adam and
SGD are supported.
Returns:
(UpdateParamFn, optax.OptState):
update param function with signature
(params, data, optimizer_state, key)
-> (new params, new state, metrics, new key), and
initial optimizer state
"""
maxiter = optimizer_config.maxiter if optimizer_config.maxiter >= 0 else None
mean_grad_fn = utils.distribute.get_mean_over_first_axis_fn(nan_safe=nan_safe)
precondition_grad_fn = get_fisher_inverse_fn(
log_psi_apply,
mean_grad_fn,
damping=optimizer_config.damping,
maxiter=maxiter,
mode=SRMode[optimizer_config.mode.upper()],
)
if optimizer_config.descent_type == "adam":
descent_optimizer = _get_adam_optax_optimizer(
learning_rate_schedule, descent_config
)
elif optimizer_config.descent_type == "sgd":
descent_optimizer = _get_sgd_optax_optimizer(
learning_rate_schedule, descent_config
)
else:
raise ValueError(
"Requested descent type not supported; {} was requested".format(
optimizer_config.descent_type
)
)
def get_optimizer_step_count(optimizer_state):
return optimizer_state[1].count
def optimizer_apply(grad, params, optimizer_state, data):
preconditioned_grad = precondition_grad_fn(grad, params, get_position_fn(data))
step_count = get_optimizer_step_count(optimizer_state)
learning_rate = learning_rate_schedule(step_count)
constrained_grad = constrain_norm(
grad, preconditioned_grad, learning_rate, optimizer_config.norm_constraint
)
updates, optimizer_state = descent_optimizer.update(
constrained_grad, optimizer_state, params
)
params = optax.apply_updates(params, updates)
return params, optimizer_state
update_param_fn = create_grad_energy_update_param_fn(
energy_data_val_and_grad,
optimizer_apply,
get_position_fn=get_position_fn,
record_param_l1_norm=record_param_l1_norm,
apply_pmap=apply_pmap,
)
optimizer_state = _init_optax_optimizer(
descent_optimizer, params, apply_pmap=apply_pmap
)
    return update_param_fn, optimizer_state
```
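The SR constructor reads its hyperparameters directly from the two ConfigDicts: damping, maxiter, mode, descent_type, and norm_constraint from optimizer_config, plus a separate descent_config for the inner Adam or SGD. A sketch of assembling such a configuration with ml_collections follows; the values are placeholders rather than vmcnet defaults, and "lazy" is only assumed to name an SRMode member:

```python
from ml_collections import ConfigDict

sr_config = ConfigDict(
    {
        "damping": 1e-3,
        "maxiter": -1,          # negative values are passed through as maxiter=None (see the source above)
        "mode": "lazy",         # upper-cased and looked up in SRMode by the constructor; assumed member name
        "descent_type": "sgd",  # inner descent optimizer; "adam" is the other supported choice
        "norm_constraint": 1e-3,
    }
)

# Hyperparameters for the inner SGD/Adam live in a separate ConfigDict; its required
# fields are read by the private _get_sgd_optax_optimizer / _get_adam_optax_optimizer
# helpers and are not documented on this page, so none are filled in here.
descent_config = ConfigDict({})
```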