
dynamic_width_position_amplitude

Metropolis routines for position amplitude data with dynamically sized steps.

DWPAData (dict)

TypedDict holding positions and wavefunction amplitudes, plus MoveMetadata.

DynamicWidthPositionAmplitudeData (dict)

TypedDict holding positions and wavefunction amplitudes, plus MoveMetadata.

MoveMetadata (dict)

Metadata for the Metropolis algorithm with dynamically sized Gaussian steps.

Attributes:

std_move (jnp.float32): the standard deviation of the Gaussian step.

move_acceptance_sum (jnp.float32): the sum of the move acceptance ratios of each step taken since the last std_move update. At update time, this sum is divided by moves_since_update to get the average acceptance, and std_move is adjusted to try to keep that average near a target value.

moves_since_update (jnp.int32): the number of moves since the last std_move update. This is tracked so that the Metropolis algorithm can update std_move at fixed intervals rather than at every step.
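The averaging described for move_acceptance_sum can be illustrated with a small, self-contained sketch. The MoveMetadata class below is a local stand-in that mirrors the documented fields; it is not imported from vmcnet, and average_acceptance is a hypothetical helper.

```python
from typing import TypedDict

import jax.numpy as jnp


class MoveMetadata(TypedDict):
    """Local mirror of the documented fields (not the vmcnet class)."""

    std_move: jnp.ndarray
    move_acceptance_sum: jnp.ndarray
    moves_since_update: jnp.ndarray


def average_acceptance(metadata: MoveMetadata) -> jnp.ndarray:
    """Hypothetical helper: mean acceptance ratio since the last std_move update."""
    return metadata["move_acceptance_sum"] / metadata["moves_since_update"]


# Three moves since the last update, with acceptance ratios 0.4, 0.6 and 0.5.
metadata = MoveMetadata(
    std_move=jnp.float32(0.25),
    move_acceptance_sum=jnp.float32(0.4 + 0.6 + 0.5),
    moves_since_update=jnp.int32(3),
)
avg = average_acceptance(metadata)  # 1.5 / 3, i.e. 0.5
```

This average is the value that the step size adjustment function compares against its target.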

make_dynamic_width_position_amplitude_data(position, amplitude, std_move, move_acceptance_sum=0.0, moves_since_update=0)

Create instance of DynamicWidthPositionAmplitudeData.

Parameters:

position (Array, required): the particle positions.

amplitude (Array, required): the wavefunction amplitudes.

std_move (jnp.float32, required): the standard deviation of the Gaussian moves.

move_acceptance_sum (jnp.float32, default 0.0): the sum of the acceptance ratios of each step since the last update. Leave the default of 0 unchanged when using this function to create initial data.

moves_since_update (jnp.int32, default 0): the number of moves since std_move was last updated. Leave the default of 0 unchanged when using this function to create initial data.

Returns:

DynamicWidthPositionAmplitudeData (DWPAData): the positions and wavefunction amplitudes, plus MoveMetadata.

Source code in vmcnet/mcmc/dynamic_width_position_amplitude.py
def make_dynamic_width_position_amplitude_data(
    position: Array,
    amplitude: Array,
    std_move: jnp.float32,
    move_acceptance_sum: jnp.float32 = 0.0,
    moves_since_update: jnp.int32 = 0,
) -> DWPAData:
    """Create instance of DynamicWidthPositionAmplitudeData.

    Args:
        position (Array): the particle positions
        amplitude (Array): the wavefunction amplitudes
        std_move (jnp.float32): std for gaussian moves
        move_acceptance_sum (jnp.float32): sum of the acceptance ratios of each step
            since the last update. Default of 0 should not be altered if using this
            function for initial data.
        moves_since_update (jnp.int32): the number of moves since the std_move was
            last updated. Default of 0 should not be altered if using this function
            for initial data.

    Returns:
        DWPAData
    """
    return make_position_amplitude_data(
        position,
        amplitude,
        MoveMetadata(
            std_move=std_move,
            move_acceptance_sum=move_acceptance_sum,
            moves_since_update=moves_since_update,
        ),
    )

make_threshold_adjust_std_move(target_acceptance_prob=0.5, threshold_delta=0.1, adjustment_delta=0.1)

Create a step size adjustment fn which aims to keep the average acceptance ratio near a target (50% by default).

Works by increasing the step size when the average acceptance is more than threshold_delta above the target, and decreasing it when the average acceptance is more than threshold_delta below the target. Within that band, the step size is left unchanged.

Parameters:

target_acceptance_prob (jnp.float32, default 0.5): the target value for the average acceptance ratio.

threshold_delta (jnp.float32, default 0.1): how far the average acceptance ratio must be from the target to trigger a compensating update.

adjustment_delta (jnp.float32, default 0.1): the relative size of each adjustment to the step width. Each adjustment multiplies std_move by either (1.0 + adjustment_delta) or (1.0 - adjustment_delta).

Returns:

Callable: function with signature (old_std_move, avg_move_acceptance) -> new_std_move.
Source code in vmcnet/mcmc/dynamic_width_position_amplitude.py
def make_threshold_adjust_std_move(
    target_acceptance_prob: jnp.float32 = 0.5,
    threshold_delta: jnp.float32 = 0.1,
    adjustment_delta: jnp.float32 = 0.1,
) -> Callable[[jnp.float32, jnp.float32], jnp.float32]:
    """Create a step size adjustment fn which aims to maintain a 50% acceptance rating.

    Works by increasing the step size when the acceptance is at least some delta above
    a target, and decreasing it when the acceptance is the some delta below the target.

    Args:
        target_acceptance_prob (jnp.float32): target value for the average acceptance
            ratio. Defaults to 0.5.
        threshold_delta (jnp.float32): how far away from the target the acceptance ratio
            must be to trigger a compensating update. Defaults to 0.1.
        adjustment_delta (jnp.float32): how big of an adjustment to make to the step
            width. Adjustments will multiply by either (1.0 + adjustment_delta) or
            (1.0 - adjustment_delta). Defaults to 0.1.
    """

    def adjust_std_move(
        old_std_move: jnp.float32, avg_move_acceptance: jnp.float32
    ) -> jnp.float32:
        # Use jax.lax.cond since the predicates are data dependent.
        std_move = jax.lax.cond(
            avg_move_acceptance > target_acceptance_prob + threshold_delta,
            lambda old_val: old_val * (1 + adjustment_delta),
            lambda old_val: old_val,
            old_std_move,
        )
        std_move = jax.lax.cond(
            avg_move_acceptance < target_acceptance_prob - threshold_delta,
            lambda old_val: old_val * (1 - adjustment_delta),
            lambda old_val: old_val,
            std_move,
        )
        return std_move

    return adjust_std_move
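As a rough illustration of the threshold rule, here is a sketch that uses jnp.where in place of the two jax.lax.cond calls above. make_threshold_adjust_sketch is a hypothetical name, and the swap to jnp.where is a deliberate substitution: it operates elementwise, so it also accepts arrays of step widths, whereas jax.lax.cond needs a scalar predicate. The two-stage structure (grow first, then shrink) mirrors the original.

```python
import jax
import jax.numpy as jnp


def make_threshold_adjust_sketch(target=0.5, threshold_delta=0.1, adjustment_delta=0.1):
    """Hypothetical jnp.where version of the threshold step-size rule."""

    def adjust(old_std_move, avg_move_acceptance):
        # Too many acceptances means the steps are too small, so grow them.
        grown = jnp.where(
            avg_move_acceptance > target + threshold_delta,
            old_std_move * (1 + adjustment_delta),
            old_std_move,
        )
        # Too few acceptances means the steps are too large, so shrink them.
        return jnp.where(
            avg_move_acceptance < target - threshold_delta,
            grown * (1 - adjustment_delta),
            grown,
        )

    return jax.jit(adjust)


adjust = make_threshold_adjust_sketch()
std = jnp.float32(1.0)
above = adjust(std, 0.7)  # above 0.6, so std grows to about 1.1
inside = adjust(std, 0.55)  # inside [0.4, 0.6], so std stays at 1.0
below = adjust(std, 0.3)  # below 0.4, so std shrinks to about 0.9
```

With the default arguments, acceptance ratios between 0.4 and 0.6 form a dead band in which std_move is left unchanged, which keeps the step width from oscillating on every update.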

make_update_move_metadata_fn(nmoves_per_update, adjust_std_move_fn)

Create a function that updates the move_metadata periodically.

Periodicity is controlled by the nmoves_per_update parameter, and the logic for updating the standard deviation of the Gaussian step is handled by adjust_std_move_fn.

Parameters:

nmoves_per_update (jnp.int32, required): std_move will be updated every time this many steps are taken.

adjust_std_move_fn (Callable, required): handles the logic for updating std_move. Has signature (old_std_move, avg_move_acceptance) -> new_std_move.

Returns:

Callable: function with signature (old_move_metadata, move_mask) -> new_move_metadata. The result can be passed to the factory for a Metropolis step to handle updating the MoveMetadata.

Source code in vmcnet/mcmc/dynamic_width_position_amplitude.py
def make_update_move_metadata_fn(
    nmoves_per_update: jnp.int32,
    adjust_std_move_fn: Callable[[jnp.float32, jnp.float32], jnp.float32],
) -> Callable[[MoveMetadata, Array], MoveMetadata]:
    """Create a function that updates the move_metadata periodically.

    Periodicity is controlled by the nmoves_per_update parameter and the logic for
    updating the std of the gaussian step is handled by adjust_std_move_fn.

    Args:
        nmoves_per_update (jnp.int32): std_move will be updated every time this many
            steps are taken.
        adjust_std_move_fn (Callable): handles the logic for updating std_move.
            Has signature (old_std_move, avg_move_acceptance) -> new_std_move

    Returns:
        Callable: function with signature
            (old_move_metadata, move_mask) -> new_move_metadata
        Result can be fed into the factory for a metropolis step to handle the updating
        of the MoveMetadata.
    """

    def update_move_metadata(
        move_metadata: MoveMetadata, current_move_mask: Array
    ) -> MoveMetadata:
        std_move = move_metadata["std_move"]
        move_acceptance_sum = move_metadata["move_acceptance_sum"]
        moves_since_update = move_metadata["moves_since_update"]

        current_avg_acceptance = mean_all_local_devices(current_move_mask)
        move_acceptance_sum = move_acceptance_sum + current_avg_acceptance
        moves_since_update = moves_since_update + 1

        def update_std_move(_):
            move_acceptance_avg = move_acceptance_sum / moves_since_update
            return (adjust_std_move_fn(std_move, move_acceptance_avg), 0, 0.0)

        def skip_update_std_move(_):
            return (std_move, moves_since_update, move_acceptance_sum)

        (std_move, moves_since_update, move_acceptance_sum) = jax.lax.cond(
            moves_since_update >= nmoves_per_update,
            update_std_move,
            skip_update_std_move,
            operand=None,
        )

        return MoveMetadata(
            std_move=std_move,
            move_acceptance_sum=move_acceptance_sum,
            moves_since_update=moves_since_update,
        )

    return update_move_metadata
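The periodic bookkeeping can be traced by hand with a simplified, single-device sketch. Here jnp.mean stands in for mean_all_local_devices (an assumption that only one device is in use), make_update_sketch and toy_adjust are hypothetical names, and the metadata is a plain dict with the same keys as MoveMetadata.

```python
import jax
import jax.numpy as jnp


def make_update_sketch(nmoves_per_update, adjust_std_move_fn):
    """Hypothetical single-device version of the periodic metadata update."""

    def update(move_metadata, current_move_mask):
        # Single-device stand-in for mean_all_local_devices.
        acceptance_sum = move_metadata["move_acceptance_sum"] + jnp.mean(current_move_mask)
        moves = move_metadata["moves_since_update"] + 1

        def do_update(_):
            avg = acceptance_sum / moves
            new_std = adjust_std_move_fn(move_metadata["std_move"], avg)
            # Reset both counters after adjusting std_move.
            return new_std, jnp.int32(0), jnp.float32(0.0)

        def skip(_):
            return move_metadata["std_move"], moves, acceptance_sum

        std_move, moves, acceptance_sum = jax.lax.cond(
            moves >= nmoves_per_update, do_update, skip, None
        )
        return {
            "std_move": std_move,
            "move_acceptance_sum": acceptance_sum,
            "moves_since_update": moves,
        }

    return update


def toy_adjust(std, avg):
    # Hypothetical rule: grow the step by 10% when acceptance exceeds 0.6.
    return jnp.where(avg > 0.6, std * 1.1, std)


update = make_update_sketch(3, toy_adjust)
meta = {
    "std_move": jnp.float32(1.0),
    "move_acceptance_sum": jnp.float32(0.0),
    "moves_since_update": jnp.int32(0),
}
mask = jnp.ones(8, dtype=bool)  # every walker accepted its move
counts = []
for _ in range(3):
    meta = update(meta, mask)
    counts.append(int(meta["moves_since_update"]))
# counts is [1, 2, 0]: the third move triggers the update and resets the counter.
```

Both branches return values with matching dtypes (float32, int32, float32), since jax.lax.cond requires the branch outputs to have the same structure and types.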

make_dynamic_pos_amp_gaussian_step(model_apply, nmoves_per_update=10, adjust_std_move_fn=make_threshold_adjust_std_move(), logabs=True)

Create a Metropolis step with a dynamic Gaussian step width.

Parameters:

model_apply (Callable, required): function which evaluates a model. Has signature (params, position) -> amplitude.

nmoves_per_update (jnp.int32, default 10): number of Metropolis steps to take between each update to std_move.

adjust_std_move_fn (Callable, default make_threshold_adjust_std_move()): handles the logic for updating std_move. Has signature (old_std_move, avg_move_acceptance) -> new_std_move.

logabs (bool, default True): whether the provided amplitudes represent psi (logabs = False) or log|psi| (logabs = True).

Returns:

Callable: function which runs a Metropolis step. Has signature (params, DWPAData, key) -> (mean acceptance probability, DWPAData, new_key).
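Assuming the walkers sample |psi|^2, as is standard in VMC, the logabs flag only changes how the usual Metropolis acceptance probability for a symmetric Gaussian proposal is evaluated. This is a sketch of what the flag means, not a formula taken from the vmcnet source:

```latex
A(x \to x') = \min\left(1, \frac{|\psi(x')|^2}{|\psi(x)|^2}\right)
            = \min\left(1, \exp\bigl(2\,[\log|\psi(x')| - \log|\psi(x)|]\bigr)\right)
```

The first form fits amplitudes stored as psi (logabs = False); the second, evaluated in log space, fits amplitudes stored as log|psi| (logabs = True) and is numerically safer when amplitudes span many orders of magnitude.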

Source code in vmcnet/mcmc/dynamic_width_position_amplitude.py
def make_dynamic_pos_amp_gaussian_step(
    model_apply: ModelApply[P],
    nmoves_per_update: jnp.int32 = 10,
    adjust_std_move_fn: Callable[
        [jnp.float32, jnp.float32], jnp.float32
    ] = make_threshold_adjust_std_move(),
    logabs: bool = True,
) -> MetropolisStep:
    """Create a metropolis step with dynamic gaussian step width.

    Args:
        model_apply (Callable): function which evaluates a model. Has signature
            (params, position) -> amplitude
        nmoves_per_update (jnp.int32): number of metropolis steps to take between each
            update to std_move
        adjust_std_move_fn (Callable): handles the logic for updating std_move. Has
            signature (old_std_move, avg_move_acceptance) -> new_std_move
        logabs (bool, optional): whether the provided amplitudes represent psi
            (logabs = False) or log|psi| (logabs = True). Defaults to True.

    Returns:
        Callable: function which runs a metropolis step. Has the signature
            (params, DWPAData, key)
            -> (mean acceptance probability, DWPAData, new_key)
    """
    update_move_metadata_fn = make_update_move_metadata_fn(
        nmoves_per_update, adjust_std_move_fn
    )

    return make_position_amplitude_gaussian_metropolis_step(
        model_apply,
        lambda data: data["move_metadata"]["std_move"],
        update_move_metadata_fn,
        logabs,
    )
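To see how the pieces fit together, here is a self-contained sketch of one dynamic-width Gaussian Metropolis step for logabs=True. It is not the vmcnet implementation: the flat data dict (keys position, amplitude, move_metadata), the toy model toy_log_psi, and the names adjust_std, dynamic_gaussian_step and NMOVES_PER_UPDATE are all hypothetical. The acceptance test assumes the walkers sample |psi|^2, and jnp.where replaces jax.lax.cond for the periodic update.

```python
import jax
import jax.numpy as jnp

NMOVES_PER_UPDATE = 10  # assumed value, matching the documented default


def toy_log_psi(params, position):
    """Hypothetical model: log|psi| = -params * |x|^2 for each walker."""
    return -params * jnp.sum(position**2, axis=-1)


def adjust_std(std_move, avg_acceptance):
    """Threshold rule with the documented defaults (target 0.5, deltas 0.1)."""
    std_move = jnp.where(avg_acceptance > 0.6, std_move * 1.1, std_move)
    return jnp.where(avg_acceptance < 0.4, std_move * 0.9, std_move)


@jax.jit
def dynamic_gaussian_step(params, data, key):
    key, key_move, key_accept = jax.random.split(key, 3)
    meta = data["move_metadata"]
    position, log_amp = data["position"], data["amplitude"]

    # Propose Gaussian moves whose width is read from the metadata.
    noise = jax.random.normal(key_move, position.shape)
    proposal = position + meta["std_move"] * noise
    proposal_log_amp = toy_log_psi(params, proposal)

    # Metropolis test on |psi'|^2 / |psi|^2, evaluated in log space.
    log_ratio = 2.0 * (proposal_log_amp - log_amp)
    accept = jnp.log(jax.random.uniform(key_accept, log_ratio.shape)) < log_ratio
    mean_accept = jnp.mean(accept)

    # Accumulate acceptance and adjust std_move every NMOVES_PER_UPDATE moves.
    acc_sum = meta["move_acceptance_sum"] + mean_accept
    moves = meta["moves_since_update"] + 1
    do_update = moves >= NMOVES_PER_UPDATE
    new_meta = {
        "std_move": jnp.where(
            do_update, adjust_std(meta["std_move"], acc_sum / moves), meta["std_move"]
        ),
        "move_acceptance_sum": jnp.where(do_update, 0.0, acc_sum),
        "moves_since_update": jnp.where(do_update, 0, moves),
    }
    new_data = {
        "position": jnp.where(accept[:, None], proposal, position),
        "amplitude": jnp.where(accept, proposal_log_amp, log_amp),
        "move_metadata": new_meta,
    }
    return mean_accept, new_data, key


key, init_key = jax.random.split(jax.random.PRNGKey(0))
params = jnp.float32(0.5)
position = jax.random.normal(init_key, (256, 3))
data = {
    "position": position,
    "amplitude": toy_log_psi(params, position),
    "move_metadata": {
        "std_move": jnp.float32(0.1),
        "move_acceptance_sum": jnp.float32(0.0),
        "moves_since_update": jnp.int32(0),
    },
}
for _ in range(200):
    accept_rate, data, key = dynamic_gaussian_step(params, data, key)
```

Starting from a deliberately small step of 0.1, nearly every move is accepted at first, so std_move grows at each of the 20 updates until the average acceptance falls into the dead band.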