dynamic_width_position_amplitude
Metropolis routines for position amplitude data with dynamically sized steps.
DWPAData (dict)
TypedDict holding positions and wavefunction amplitudes, plus MoveMetadata.
DynamicWidthPositionAmplitudeData (dict)
TypedDict holding positions and wavefunction amplitudes, plus MoveMetadata.
MoveMetadata (dict)
Metadata for the Metropolis algorithm with dynamically sized Gaussian steps.
Attributes:

| Name | Type | Description |
|---|---|---|
| std_move | jnp.float32 | the standard deviation of the Gaussian step |
| move_acceptance_sum | jnp.float32 | the sum of the move acceptance ratios of each step taken since the last std_move update. At update time, this sum is divided by moves_since_update to get the overall average, and std_move is adjusted to try to keep this average near some target. |
| moves_since_update | jnp.int32 | Number of moves since the last std_move update. This is tracked so that the Metropolis algorithm can update std_move at fixed intervals rather than at every step. |
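A minimal plain-Python sketch of this bookkeeping, assuming ordinary `float`/`int` values in place of the `jnp.float32`/`jnp.int32` scalars. `MoveMetadataSketch` and `average_acceptance` are hypothetical names that only mirror the three documented fields; they are not part of vmcnet.

```python
from typing import TypedDict


class MoveMetadataSketch(TypedDict):
    """Hypothetical stand-in for MoveMetadata using plain Python scalars."""

    std_move: float
    move_acceptance_sum: float
    moves_since_update: int


def average_acceptance(meta: MoveMetadataSketch) -> float:
    """Average acceptance ratio over the moves since the last std_move update."""
    return meta["move_acceptance_sum"] / meta["moves_since_update"]


# Three moves with mean acceptance ratios 0.6, 0.4 and 0.5 since the last update.
meta = MoveMetadataSketch(
    std_move=0.25, move_acceptance_sum=0.6 + 0.4 + 0.5, moves_since_update=3
)
print(average_acceptance(meta))  # this is the value compared against the target
```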
make_dynamic_width_position_amplitude_data(position, amplitude, std_move, move_acceptance_sum=0.0, moves_since_update=0)
Create instance of DynamicWidthPositionAmplitudeData.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| position | Array | the particle positions | required |
| amplitude | Array | the wavefunction amplitudes | required |
| std_move | jnp.float32 | std for Gaussian moves | required |
| move_acceptance_sum | jnp.float32 | sum of the acceptance ratios of each step since the last update. Keep the default of 0 when using this function to create initial data. | 0.0 |
| moves_since_update | jnp.int32 | the number of moves since std_move was last updated. Keep the default of 0 when using this function to create initial data. | 0 |
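Returns:

| Type | Description |
|---|---|
| DynamicWidthPositionAmplitudeData | DWPAData holding the given position, amplitude, and a MoveMetadata built from std_move, move_acceptance_sum, and moves_since_update |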
Source code in vmcnet/mcmc/dynamic_width_position_amplitude.py
```python
def make_dynamic_width_position_amplitude_data(
    position: Array,
    amplitude: Array,
    std_move: jnp.float32,
    move_acceptance_sum: jnp.float32 = 0.0,
    moves_since_update: jnp.int32 = 0,
) -> DWPAData:
    """Create instance of DynamicWidthPositionAmplitudeData.

    Args:
        position (Array): the particle positions
        amplitude (Array): the wavefunction amplitudes
        std_move (jnp.float32): std for gaussian moves
        move_acceptance_sum (jnp.float32): sum of the acceptance ratios of each step
            since the last update. Default of 0 should not be altered if using this
            function for initial data.
        moves_since_update (jnp.int32): the number of moves since the std_move was
            last updated. Default of 0 should not be altered if using this function
            for initial data.

    Returns:
        DWPAData
    """
    return make_position_amplitude_data(
        position,
        amplitude,
        MoveMetadata(
            std_move=std_move,
            move_acceptance_sum=move_acceptance_sum,
            moves_since_update=moves_since_update,
        ),
    )
```
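For orientation, the nested dict this factory is expected to produce can be sketched in plain Python. Only the `move_metadata` key is confirmed by the source (make_dynamic_pos_amp_gaussian_step reads `data["move_metadata"]["std_move"]`); the `position` and `amplitude` key names and the helper `make_dwpa_data_sketch` are assumptions made for illustration.

```python
from typing import Any, Dict, List


def make_dwpa_data_sketch(
    position: List[float],
    amplitude: float,
    std_move: float,
    move_acceptance_sum: float = 0.0,
    moves_since_update: int = 0,
) -> Dict[str, Any]:
    """Hypothetical stand-in for make_dynamic_width_position_amplitude_data.

    Assumes "position" and "amplitude" key names; only "move_metadata" is
    confirmed by the documented source.
    """
    return {
        "position": position,
        "amplitude": amplitude,
        "move_metadata": {
            "std_move": std_move,
            "move_acceptance_sum": move_acceptance_sum,
            "moves_since_update": moves_since_update,
        },
    }


# Initial data: leave the two bookkeeping fields at their zero defaults.
data = make_dwpa_data_sketch(position=[0.1, -0.3, 0.7], amplitude=-1.2, std_move=0.5)
print(data["move_metadata"]["std_move"])  # the width the Gaussian step reads
```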
make_threshold_adjust_std_move(target_acceptance_prob=0.5, threshold_delta=0.1, adjustment_delta=0.1)
Create a step size adjustment fn which aims to keep the average acceptance ratio near a target (50% by default).
Works by increasing the step size when the acceptance is more than threshold_delta above the target, and decreasing it when the acceptance is more than threshold_delta below the target.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| target_acceptance_prob | jnp.float32 | target value for the average acceptance ratio. Defaults to 0.5. | 0.5 |
| threshold_delta | jnp.float32 | how far away from the target the acceptance ratio must be to trigger a compensating update. Defaults to 0.1. | 0.1 |
| adjustment_delta | jnp.float32 | how big of an adjustment to make to the step width. Adjustments will multiply by either (1.0 + adjustment_delta) or (1.0 - adjustment_delta). Defaults to 0.1. | 0.1 |
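With the default parameters the rule works out as follows: an average acceptance of 0.8 exceeds 0.5 + 0.1, so std_move is multiplied by 1.1; an average of 0.2 falls below 0.5 - 0.1, so std_move is multiplied by 0.9; an average such as 0.55 leaves it unchanged. The plain-Python sketch below (the name `make_threshold_adjust_sketch` is hypothetical) expresses the same rule with ordinary `if` statements. The JAX source needs `jax.lax.cond` instead because, under tracing, its predicates are data dependent.

```python
from typing import Callable


def make_threshold_adjust_sketch(
    target_acceptance_prob: float = 0.5,
    threshold_delta: float = 0.1,
    adjustment_delta: float = 0.1,
) -> Callable[[float, float], float]:
    """Plain-Python version of the threshold rule (hypothetical helper)."""

    def adjust(old_std_move: float, avg_move_acceptance: float) -> float:
        # The two branches are mutually exclusive when threshold_delta >= 0,
        # matching the two sequential conds in the JAX source.
        if avg_move_acceptance > target_acceptance_prob + threshold_delta:
            # Accepting too often: steps are too timid, so widen them.
            return old_std_move * (1 + adjustment_delta)
        if avg_move_acceptance < target_acceptance_prob - threshold_delta:
            # Rejecting too often: steps overshoot, so shrink them.
            return old_std_move * (1 - adjustment_delta)
        # Inside the dead band around the target: leave the width alone.
        return old_std_move

    return adjust


adjust = make_threshold_adjust_sketch()
for avg in (0.8, 0.55, 0.2):
    print(avg, adjust(1.0, avg))
```

The dead band of width 2 * threshold_delta keeps the width from oscillating on every update when the acceptance is already close to the target.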
Source code in vmcnet/mcmc/dynamic_width_position_amplitude.py
```python
def make_threshold_adjust_std_move(
    target_acceptance_prob: jnp.float32 = 0.5,
    threshold_delta: jnp.float32 = 0.1,
    adjustment_delta: jnp.float32 = 0.1,
) -> Callable[[jnp.float32, jnp.float32], jnp.float32]:
    """Create a step size adjustment fn which aims to maintain a target acceptance rate.

    Works by increasing the step size when the acceptance is more than some delta
    above a target, and decreasing it when the acceptance is more than some delta
    below the target.

    Args:
        target_acceptance_prob (jnp.float32): target value for the average acceptance
            ratio. Defaults to 0.5.
        threshold_delta (jnp.float32): how far away from the target the acceptance ratio
            must be to trigger a compensating update. Defaults to 0.1.
        adjustment_delta (jnp.float32): how big of an adjustment to make to the step
            width. Adjustments will multiply by either (1.0 + adjustment_delta) or
            (1.0 - adjustment_delta). Defaults to 0.1.

    Returns:
        Callable: function with signature
            (old_std_move, avg_move_acceptance) -> new_std_move
    """

    def adjust_std_move(
        old_std_move: jnp.float32, avg_move_acceptance: jnp.float32
    ) -> jnp.float32:
        # Use jax.lax.cond since the predicates are data dependent.
        std_move = jax.lax.cond(
            avg_move_acceptance > target_acceptance_prob + threshold_delta,
            lambda old_val: old_val * (1 + adjustment_delta),
            lambda old_val: old_val,
            old_std_move,
        )
        std_move = jax.lax.cond(
            avg_move_acceptance < target_acceptance_prob - threshold_delta,
            lambda old_val: old_val * (1 - adjustment_delta),
            lambda old_val: old_val,
            std_move,
        )
        return std_move

    return adjust_std_move
```
make_update_move_metadata_fn(nmoves_per_update, adjust_std_move_fn)
Create a function that updates the move_metadata periodically.
Periodicity is controlled by the nmoves_per_update parameter, and the logic for updating the std of the Gaussian step is handled by adjust_std_move_fn.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| nmoves_per_update | jnp.int32 | std_move will be updated every time this many steps are taken. | required |
| adjust_std_move_fn | Callable | handles the logic for updating std_move. Has signature (old_std_move, avg_move_acceptance) -> new_std_move | required |
Returns:

| Type | Description |
|---|---|
| Callable | function with signature (old_move_metadata, move_mask) -> new_move_metadata. The result can be fed into the factory for a Metropolis step to handle the updating of the MoveMetadata. |
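The periodic logic can be traced without JAX. In this hypothetical sketch (`make_update_move_metadata_sketch` is not a vmcnet function), a plain mean over a list of booleans stands in for `mean_all_local_devices`, and an ordinary `if` replaces `jax.lax.cond`. With `nmoves_per_update=3`, the first two calls only accumulate; the third averages the three per-step acceptance fractions, passes that average to `adjust_std_move_fn`, and resets both counters.

```python
from typing import Callable, Dict, List, Union

Meta = Dict[str, Union[float, int]]


def make_update_move_metadata_sketch(
    nmoves_per_update: int,
    adjust_std_move_fn: Callable[[float, float], float],
) -> Callable[[Meta, List[bool]], Meta]:
    """Plain-Python analogue of make_update_move_metadata_fn (hypothetical)."""

    def update(meta: Meta, move_mask: List[bool]) -> Meta:
        # Fraction of walkers whose move was accepted on this step.
        current_avg_acceptance = sum(move_mask) / len(move_mask)
        acceptance_sum = meta["move_acceptance_sum"] + current_avg_acceptance
        moves = meta["moves_since_update"] + 1
        std_move = meta["std_move"]
        if moves >= nmoves_per_update:
            # Average over the window, adjust the width, then reset the counters.
            std_move = adjust_std_move_fn(std_move, acceptance_sum / moves)
            moves, acceptance_sum = 0, 0.0
        return {
            "std_move": std_move,
            "move_acceptance_sum": acceptance_sum,
            "moves_since_update": moves,
        }

    return update


# Usage with a toy rule: grow the width by the average acceptance.
update = make_update_move_metadata_sketch(3, lambda std, avg: std * (1 + avg))
meta = {"std_move": 1.0, "move_acceptance_sum": 0.0, "moves_since_update": 0}
for mask in ([True, True], [True, False], [False, False]):
    meta = update(meta, mask)
print(meta)  # {'std_move': 1.5, 'move_acceptance_sum': 0.0, 'moves_since_update': 0}
```

Any callable with the (old_std_move, avg_move_acceptance) -> new_std_move shape can be plugged in, including a threshold rule like the one above.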
Source code in vmcnet/mcmc/dynamic_width_position_amplitude.py
```python
def make_update_move_metadata_fn(
    nmoves_per_update: jnp.int32,
    adjust_std_move_fn: Callable[[jnp.float32, jnp.float32], jnp.float32],
) -> Callable[[MoveMetadata, Array], MoveMetadata]:
    """Create a function that updates the move_metadata periodically.

    Periodicity is controlled by the nmoves_per_update parameter and the logic for
    updating the std of the gaussian step is handled by adjust_std_move_fn.

    Args:
        nmoves_per_update (jnp.int32): std_move will be updated every time this many
            steps are taken.
        adjust_std_move_fn (Callable): handles the logic for updating std_move.
            Has signature (old_std_move, avg_move_acceptance) -> new_std_move

    Returns:
        Callable: function with signature
            (old_move_metadata, move_mask) -> new_move_metadata
            Result can be fed into the factory for a metropolis step to handle the updating
            of the MoveMetadata.
    """

    def update_move_metadata(
        move_metadata: MoveMetadata, current_move_mask: Array
    ) -> MoveMetadata:
        std_move = move_metadata["std_move"]
        move_acceptance_sum = move_metadata["move_acceptance_sum"]
        moves_since_update = move_metadata["moves_since_update"]

        current_avg_acceptance = mean_all_local_devices(current_move_mask)
        move_acceptance_sum = move_acceptance_sum + current_avg_acceptance
        moves_since_update = moves_since_update + 1

        def update_std_move(_):
            move_acceptance_avg = move_acceptance_sum / moves_since_update
            return (adjust_std_move_fn(std_move, move_acceptance_avg), 0, 0.0)

        def skip_update_std_move(_):
            return (std_move, moves_since_update, move_acceptance_sum)

        (std_move, moves_since_update, move_acceptance_sum) = jax.lax.cond(
            moves_since_update >= nmoves_per_update,
            update_std_move,
            skip_update_std_move,
            operand=None,
        )

        return MoveMetadata(
            std_move=std_move,
            move_acceptance_sum=move_acceptance_sum,
            moves_since_update=moves_since_update,
        )

    return update_move_metadata
```
make_dynamic_pos_amp_gaussian_step(model_apply, nmoves_per_update=10, adjust_std_move_fn=make_threshold_adjust_std_move(), logabs=True)
Create a Metropolis step with a dynamic Gaussian step width.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model_apply | Callable | function which evaluates a model. Has signature (params, position) -> amplitude | required |
| nmoves_per_update | jnp.int32 | number of Metropolis steps to take between each update to std_move | 10 |
| adjust_std_move_fn | Callable | handles the logic for updating std_move. Has signature (old_std_move, avg_move_acceptance) -> new_std_move | make_threshold_adjust_std_move() |
| logabs | bool | whether the provided amplitudes represent psi (logabs = False) or log\|psi\| (logabs = True). Defaults to True. | True |
Returns:

| Type | Description |
|---|---|
| Callable | function which runs a Metropolis step. Has the signature (params, DWPAData, key) -> (mean acceptance probability, DWPAData, new_key) |
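To see how the pieces combine, here is a toy single-walker sketch in plain Python. It is not the vmcnet implementation, which is batched over walkers, uses JAX PRNG keys, and delegates proposal and acceptance to make_position_amplitude_gaussian_metropolis_step. Assumptions: amplitudes are log|psi| (the logabs = True case), the sampled density is |psi|^2, proposals are symmetric Gaussian moves, a `random.Random` instance replaces the JAX key, and the first return value is the single-walker acceptance indicator rather than a batch mean. `make_dynamic_gaussian_step_sketch` and the trial function log|psi(x)| = -x^2/2 are hypothetical.

```python
import math
import random
from typing import Any, Callable, Dict, Tuple

Data = Dict[str, Any]


def make_dynamic_gaussian_step_sketch(
    log_psi: Callable[[float], float],
    nmoves_per_update: int = 10,
    target_acceptance_prob: float = 0.5,
    threshold_delta: float = 0.1,
    adjustment_delta: float = 0.1,
) -> Callable[[Data, random.Random], Tuple[float, Data]]:
    """Single-walker Metropolis step whose Gaussian width adapts periodically."""

    def step(data: Data, rng: random.Random) -> Tuple[float, Data]:
        meta = dict(data["move_metadata"])  # copy so the input is not mutated
        x, logabs = data["position"], data["amplitude"]
        # Symmetric Gaussian proposal whose width is the current std_move.
        x_new = x + rng.gauss(0.0, meta["std_move"])
        logabs_new = log_psi(x_new)
        # Sampling |psi|^2 with log|psi| amplitudes: log ratio is 2 * difference.
        accept_prob = math.exp(min(0.0, 2.0 * (logabs_new - logabs)))
        accepted = rng.random() < accept_prob
        if accepted:
            x, logabs = x_new, logabs_new
        # Bookkeeping as in update_move_metadata, with a one-walker "mask".
        meta["move_acceptance_sum"] += float(accepted)
        meta["moves_since_update"] += 1
        if meta["moves_since_update"] >= nmoves_per_update:
            avg = meta["move_acceptance_sum"] / meta["moves_since_update"]
            if avg > target_acceptance_prob + threshold_delta:
                meta["std_move"] *= 1 + adjustment_delta
            elif avg < target_acceptance_prob - threshold_delta:
                meta["std_move"] *= 1 - adjustment_delta
            meta["move_acceptance_sum"], meta["moves_since_update"] = 0.0, 0
        new_data = {"position": x, "amplitude": logabs, "move_metadata": meta}
        return float(accepted), new_data

    return step


# Start with a deliberately too-wide step on the hypothetical trial function.
rng = random.Random(0)
step = make_dynamic_gaussian_step_sketch(lambda x: -0.5 * x * x)
data = {
    "position": 0.0,
    "amplitude": 0.0,
    "move_metadata": {"std_move": 5.0, "move_acceptance_sum": 0.0, "moves_since_update": 0},
}
for _ in range(200):
    accepted, data = step(data, rng)
print(data["move_metadata"]["std_move"])  # expected to have shrunk from 5.0
```

Because the width only changes every nmoves_per_update steps, each adjustment is based on an average over a window of moves rather than on a single noisy accept/reject outcome.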
Source code in vmcnet/mcmc/dynamic_width_position_amplitude.py
```python
def make_dynamic_pos_amp_gaussian_step(
    model_apply: ModelApply[P],
    nmoves_per_update: jnp.int32 = 10,
    adjust_std_move_fn: Callable[
        [jnp.float32, jnp.float32], jnp.float32
    ] = make_threshold_adjust_std_move(),
    logabs: bool = True,
) -> MetropolisStep:
    """Create a metropolis step with dynamic gaussian step width.

    Args:
        model_apply (Callable): function which evaluates a model. Has signature
            (params, position) -> amplitude
        nmoves_per_update (jnp.int32): number of metropolis steps to take between each
            update to std_move
        adjust_std_move_fn (Callable): handles the logic for updating std_move. Has
            signature (old_std_move, avg_move_acceptance) -> new_std_move
        logabs (bool, optional): whether the provided amplitudes represent psi
            (logabs = False) or log|psi| (logabs = True). Defaults to True.

    Returns:
        Callable: function which runs a metropolis step. Has the signature
            (params, DWPAData, key)
            -> (mean acceptance probability, DWPAData, new_key)
    """
    update_move_metadata_fn = make_update_move_metadata_fn(
        nmoves_per_update, adjust_std_move_fn
    )

    return make_position_amplitude_gaussian_metropolis_step(
        model_apply,
        lambda data: data["move_metadata"]["std_move"],
        update_move_metadata_fn,
        logabs,
    )
```