Skip to content

position_amplitude_core

Shared routines for position amplitude metropolis data.

PositionAmplitudeData (dict)

TypedDict of data holding positions, amplitudes, and optional metadata.

Holding both particle position and wavefn amplitude in the data can be advantageous to avoid recalculating amplitudes in some routines, e.g. acceptance probabilities. Furthermore, holding additional metadata can enable more sophisticated metropolis algorithms such as dynamically adjusted gaussian step sizes.

Attributes:

Name Type Description
walker_data PositionAmplitudeWalkerData

the positions and amplitudes

move_metadata any

any metadata needed for the metropolis algorithm

PositionAmplitudeWalkerData (dict)

TypedDict of walker data holding just positions and amplitudes.

Holding both particle position and wavefn amplitude in the same named tuple allows us to simultaneously mask over both in the acceptance function.

The first dimension of position and amplitude should match, but position can have more dimensions.

Attributes:

Name Type Description
position Array

array of shape (n, ...)

amplitude Array

array of shape (n,)

make_position_amplitude_data(position, amplitude, move_metadata)

Create PositionAmplitudeData from position, amplitude, and move_metadata.

Parameters:

Name Type Description Default
position Array

the particle positions

required
amplitude Array

the wavefunction amplitudes

required
move_metadata Any

other required metadata for the metropolis algorithm

required

Returns:

Type Description
PositionAmplitudeData

data containing positions, wavefn amplitudes, and move metadata

Source code in vmcnet/mcmc/position_amplitude_core.py
def make_position_amplitude_data(position: Array, amplitude: Array, move_metadata: Any):
    """Create PositionAmplitudeData from position, amplitude, and move_metadata.

    Args:
        position (Array): the particle positions
        amplitude (Array): the wavefunction amplitudes
        move_metadata (Any): other required metadata for the metropolis algorithm

    Returns:
        PositionAmplitudeData: data containing positions, wavefn amplitudes, and move
        metadata
    """
    return PositionAmplitudeData(
        walker_data=PositionAmplitudeWalkerData(position=position, amplitude=amplitude),
        move_metadata=move_metadata,
    )

get_position_from_data(data)

Get the position data from PositionAmplitudeData.

Parameters:

Name Type Description Default
data PositionAmplitudeData

the data

required

Returns:

Type Description
Array

the particle positions from the data

Source code in vmcnet/mcmc/position_amplitude_core.py
def get_position_from_data(data: PositionAmplitudeData) -> Array:
    """Get the position data from PositionAmplitudeData.

    Args:
        data (PositionAmplitudeData): the data

    Returns:
        Array: the particle positions from the data
    """
    return data["walker_data"]["position"]

get_amplitude_from_data(data)

Get the amplitude data from PositionAmplitudeData.

Parameters:

Name Type Description Default
data PositionAmplitudeData

the data

required

Returns:

Type Description
Array

the wave function amplitudes from the data

Source code in vmcnet/mcmc/position_amplitude_core.py
def get_amplitude_from_data(data: PositionAmplitudeData) -> Array:
    """Get the amplitude data from PositionAmplitudeData.

    Args:
        data (PositionAmplitudeData): the data

    Returns:
        Array: the wave function amplitudes from the data
    """
    return data["walker_data"]["amplitude"]

to_pam_tuple(data)

Returns data as a (position, amplitude, move_metadata) tuple.

Useful for quickly assigning all three pieces to local variables for further use.

Source code in vmcnet/mcmc/position_amplitude_core.py
def to_pam_tuple(data: PositionAmplitudeData) -> Tuple[Array, Array, Any]:
    """Returns data as a (position, amplitude, move_metadata) tuple.

    Useful for quickly assigning all three pieces to local variables for further use.
    """
    return (
        data["walker_data"]["position"],
        data["walker_data"]["amplitude"],
        data["move_metadata"],
    )

distribute_position_amplitude_data(data)

Distribute PositionAmplitudeData across devices.

Parameters:

Name Type Description Default
data PositionAmplitudeData

the data to distribute

required

Returns:

Type Description
PositionAmplitudeData

the distributed data.

Source code in vmcnet/mcmc/position_amplitude_core.py
def distribute_position_amplitude_data(
    data: PositionAmplitudeData,
) -> PositionAmplitudeData:
    """Distribute PositionAmplitudeData across devices.

    Args:
        data (PositionAmplitudeData): the data to distribute

    Returns:
        PositionAmplitudeData: the distributed data.
    """
    walker_data = data["walker_data"]
    move_metadata = data["move_metadata"]
    walker_data = default_distribute_data(walker_data)
    move_metadata = replicate_all_local_devices(move_metadata)
    return PositionAmplitudeData(walker_data=walker_data, move_metadata=move_metadata)

make_position_amplitude_gaussian_proposal(model_apply, get_std_move)

Create a gaussian proposal fn on PositionAmplitudeData.

Positions are perturbed by a guassian; amplitudes are evaluated using the supplied model; move_metadata is not modified.

Parameters:

Name Type Description Default
model_apply Callable

function which evaluates a model. Has signature (params, position) -> amplitude

required
get_std_move Callable

function which gets the standard deviation of the gaussian move, which can optionally depend on the data. Has signature (PositionAmplitudeData) -> std_move

required

Returns:

Type Description
Callable

proposal function which can be passed to the main VMC routine. Has signature (params, PositionAmplitudeData, key) -> (PositionAmplitudeData, key).

Source code in vmcnet/mcmc/position_amplitude_core.py
def make_position_amplitude_gaussian_proposal(
    model_apply: ModelApply[P],
    get_std_move: Callable[[PositionAmplitudeData], jnp.float32],
) -> Callable[
    [P, PositionAmplitudeData, PRNGKey], Tuple[PositionAmplitudeData, PRNGKey]
]:
    """Create a gaussian proposal fn on PositionAmplitudeData.

    Positions are perturbed by a guassian; amplitudes are evaluated using the supplied
    model; move_metadata is not modified.

    Args:
        model_apply (Callable): function which evaluates a model. Has signature
            (params, position) -> amplitude
        get_std_move (Callable): function which gets the standard deviation of the
            gaussian move, which can optionally depend on the data. Has signature
            (PositionAmplitudeData) -> std_move

    Returns:
        Callable: proposal function which can be passed to the main VMC routine. Has
        signature (params, PositionAmplitudeData, key) -> (PositionAmplitudeData, key).
    """

    def proposal_fn(params: P, data: PositionAmplitudeData, key: PRNGKey):
        std_move = get_std_move(data)
        proposed_position, key = metropolis.gaussian_proposal(
            data["walker_data"]["position"], std_move, key
        )
        proposed_amplitude = model_apply(params, proposed_position)
        return (
            make_position_amplitude_data(
                proposed_position, proposed_amplitude, data["move_metadata"]
            ),
            key,
        )

    return proposal_fn

make_position_amplitude_metropolis_symmetric_acceptance(logabs=True)

Create a Metropolis acceptance function on PositionAmplitudeData.

Parameters:

Name Type Description Default
logabs bool

whether amplitudes provided to acceptance_fn represent psi (logabs = False) or log|psi| (logabs = True). Defaults to True.

True

Returns:

Type Description
Callable

acceptance function which can be passed to the main VMC routine. Has signature (params, PositionAmplitudeData, PositionAmplitudeData) -> accept_ratio

Source code in vmcnet/mcmc/position_amplitude_core.py
def make_position_amplitude_metropolis_symmetric_acceptance(
    logabs: bool = True,
) -> Callable[[P, PositionAmplitudeData, PositionAmplitudeData], Array]:
    """Create a Metropolis acceptance function on PositionAmplitudeData.

    Args:
        logabs (bool, optional): whether amplitudes provided to `acceptance_fn`
            represent psi (logabs = False) or log|psi| (logabs = True). Defaults to
            True.

    Returns:
        Callable: acceptance function which can be passed to the main VMC routine. Has
        signature (params, PositionAmplitudeData, PositionAmplitudeData) -> accept_ratio
    """

    def acceptance_fn(
        params: P, data: PositionAmplitudeData, proposed_data: PositionAmplitudeData
    ):
        del params
        return metropolis.metropolis_symmetric_acceptance(
            data["walker_data"]["amplitude"],
            proposed_data["walker_data"]["amplitude"],
            logabs=logabs,
        )

    return acceptance_fn

make_position_amplitude_update(update_move_metadata_fn=None)

Factory for an update to PositionAmplitudeData.

The returned update takes a mask of approved MCMC walker moves move_mask and accepts those proposed moves from proposed_data, for both positions and amplitudes. The std_move gaussian step width can also be modified by an optional adjust_std_move_fn.

The moves in move_mask are applied along the first axis of the position data, and should be the same shape as the amplitude data (one-dimensional Array).

Parameters:

Name Type Description Default
update_move_metadata_fn Callable

function which calculates the new move_metadata. Has signature (old_move_metadata, move_mask) -> new_move_metadata

None

Returns:

Type Description
Callable

function with signature (PositionAmplitudeData, PositionAmplitudeData, Array) -> (PositionAmplitudeData), which takes in the original PositionAmplitudeData, the proposed PositionAmplitudeData, and a move mask. Uses the move mask to decide which proposed data to accept.

Source code in vmcnet/mcmc/position_amplitude_core.py
def make_position_amplitude_update(
    update_move_metadata_fn: Optional[Callable[[M, Array], M]] = None
) -> Callable[
    [
        PositionAmplitudeData,
        PositionAmplitudeData,
        Array,
    ],
    PositionAmplitudeData,
]:
    """Factory for an update to PositionAmplitudeData.

    The returned update takes a mask of approved MCMC walker moves `move_mask` and
    accepts those proposed moves from `proposed_data`, for both positions and
    amplitudes. The `std_move` gaussian step width can also be modified by an optional
    `adjust_std_move_fn`.

    The moves in `move_mask` are applied along the first axis of the position data, and
    should be the same shape as the amplitude data (one-dimensional Array).

    Args:
        update_move_metadata_fn (Callable): function which calculates the new
            move_metadata. Has signature
            (old_move_metadata, move_mask) -> new_move_metadata

    Returns:
        Callable: function with signature
            (PositionAmplitudeData, PositionAmplitudeData, Array) ->
                (PositionAmplitudeData),
            which takes in the original PositionAmplitudeData, the proposed
            PositionAmplitudeData, and a move mask. Uses
            the move mask to decide which proposed data to accept.
    """

    def update_position_amplitude(
        data: PositionAmplitudeData,
        proposed_data: PositionAmplitudeData,
        move_mask: Array,
    ) -> PositionAmplitudeData:
        def mask_on_first_dimension(old_data: Array, proposal: Array):
            shaped_mask = jnp.reshape(move_mask, (-1, *((1,) * (old_data.ndim - 1))))
            return jnp.where(shaped_mask, proposal, old_data)

        new_walker_data = jax.tree_map(
            mask_on_first_dimension, data["walker_data"], proposed_data["walker_data"]
        )

        new_move_metadata = proposed_data["move_metadata"]
        if update_move_metadata_fn is not None:
            new_move_metadata = update_move_metadata_fn(
                data["move_metadata"], move_mask
            )

        return PositionAmplitudeData(
            walker_data=new_walker_data, move_metadata=new_move_metadata
        )

    return update_position_amplitude

make_position_amplitude_gaussian_metropolis_step(model_apply, get_std_move, update_move_metadata_fn=None, logabs=True)

Make a gaussian proposal with Metropolis acceptance for PositionAmplitudeData.

Parameters:

Name Type Description Default
model_apply Callable

function which evaluates a model. Has signature (params, position) -> amplitude

required
get_std_move Callable

function which gets the standard deviation of the gaussian move, which can optionally depend on the data. Has signature (PositionAmplitudeData) -> std_move

required
update_move_metadata_fn Callable

function which calculates the new move_metadata. Has signature (old_move_metadata, move_mask) -> new_move_metadata.

None
logabs bool

whether the provided amplitudes represent psi (logabs = False) or log|psi| (logabs = True). Defaults to True.

True

Returns:

Type Description
Callable

function which does a metropolis step. Has the signature (params, PositionAmplitudeData, key) -> (mean acceptance probability, PositionAmplitudeData, new_key)

Source code in vmcnet/mcmc/position_amplitude_core.py
def make_position_amplitude_gaussian_metropolis_step(
    model_apply: ModelApply[P],
    get_std_move: Callable[[PositionAmplitudeData], jnp.float32],
    update_move_metadata_fn: Optional[Callable[[M, Array], M]] = None,
    logabs: bool = True,
) -> metropolis.MetropolisStep[P, PositionAmplitudeData]:
    """Make a gaussian proposal with Metropolis acceptance for PositionAmplitudeData.

    Args:
        model_apply (Callable): function which evaluates a model. Has signature
            (params, position) -> amplitude
        get_std_move (Callable): function which gets the standard deviation of the
            gaussian move, which can optionally depend on the data. Has signature
            (PositionAmplitudeData) -> std_move
        update_move_metadata_fn (Callable, optional): function which calculates the new
            move_metadata. Has signature
            (old_move_metadata, move_mask) -> new_move_metadata.
        logabs (bool, optional): whether the provided amplitudes represent psi
            (logabs = False) or log|psi| (logabs = True). Defaults to True.

    Returns:
        Callable: function which does a metropolis step. Has the signature
            (params, PositionAmplitudeData, key)
            -> (mean acceptance probability, PositionAmplitudeData, new_key)
    """
    proposal_fn = make_position_amplitude_gaussian_proposal(model_apply, get_std_move)
    accept_fn = make_position_amplitude_metropolis_symmetric_acceptance(logabs=logabs)
    metrop_step_fn = metropolis.make_metropolis_step(
        proposal_fn,
        accept_fn,
        make_position_amplitude_update(update_move_metadata_fn),
    )
    return metrop_step_fn