Skip to content

metropolis

Proposal and acceptance fns for Metropolis-Hastings Markov-Chain Monte Carlo.

make_metropolis_step(proposal_fn, acceptance_fn, update_data_fn)

Factory to create a function which takes a single metropolis step.

Following Metropolis-Hastings Markov Chain Monte Carlo, a transition from one data state to another is split into proposal and acceptance. When used in a Metropolis routine to approximate a stationary distribution P, the proposal and acceptance functions should satisfy detailed balance, i.e.,

proposal_prob_ij * acceptance_ij * P_i = proposal_prob_ji * acceptance_ji * P_j,

where proposal_prob_ij is the likelihood of proposing the transition from state i to state j, acceptance_ij is the likelihood of accepting a transition from state i to state j, and P_i is the probability of being in state i.

Parameters:

Name Type Description Default
proposal_fn Callable

proposal function which produces new proposed data. Has the signature (params, data, key) -> proposed_data, key

required
acceptance_fn Callable

acceptance function which produces a vector of numbers used to create a mask for accepting the proposals. Has the signature (params, data, proposed_data) -> Array: acceptance probabilities

required
update_data_fn Callable

function used to update the data given the original data, the proposed data, and the array mask identifying which proposals to accept. Has the signature (data, proposed_data, mask) -> new_data

required

Returns:

Type Description
Callable

function which takes in (data, params, key) and outputs (mean acceptance probability, new data, new jax PRNG key split from previous one)

Source code in vmcnet/mcmc/metropolis.py
def make_metropolis_step(
    proposal_fn: Callable[[P, D, PRNGKey], Tuple[D, PRNGKey]],
    acceptance_fn: Callable[[P, D, D], Array],
    update_data_fn: Callable[[D, D, Array], D],
) -> MetropolisStep[P, D]:
    """Factory to create a function which takes a single metropolis step.

    Following Metropolis-Hastings Markov Chain Monte Carlo, a transition from one data
    state to another is split into proposal and acceptance. When used in a Metropolis
    routine to approximate a stationary distribution P, the proposal and acceptance
    functions should satisfy detailed balance, i.e.,

        proposal_prob_ij * acceptance_ij * P_i = proposal_prob_ji * acceptance_ji * P_j,

    where proposal_prob_ij is the likelihood of proposing the transition from state i to
    state j, acceptance_ij is the likelihood of accepting a transition from state i
    to state j, and P_i is the probability of being in state i.

    Args:
        proposal_fn (Callable): proposal function which produces new proposed data. Has
            the signature (params, data, key) -> proposed_data, key
        acceptance_fn (Callable): acceptance function which produces a vector of numbers
            used to create a mask for accepting the proposals. Has the signature
            (params, data, proposed_data) -> Array: acceptance probabilities
        update_data_fn (Callable): function used to update the data given the original
            data, the proposed data, and the array mask identifying which proposals to
            accept. Has the signature
            (data, proposed_data, mask) -> new_data

    Returns:
        Callable: function which takes in (data, params, key) and outputs
        (mean acceptance probability, new data, new jax PRNG key split from previous
        one)
    """

    def metrop_step_fn(
        params: P, data: D, key: PRNGKey
    ) -> Tuple[jnp.float32, D, PRNGKey]:
        """Take a single metropolis step."""
        key, subkey = jax.random.split(key)
        proposed_data, key = proposal_fn(params, data, key)
        accept_prob = acceptance_fn(params, data, proposed_data)
        move_mask = cast(
            Array,
            jax.random.uniform(subkey, shape=accept_prob.shape) < accept_prob,
        )
        new_data = update_data_fn(data, proposed_data, move_mask)

        return jnp.mean(accept_prob), new_data, key

    return metrop_step_fn

walk_data(nsteps, params, data, key, metrop_step_fn)

Take multiple Metropolis-Hastings steps.

This function is roughly equivalent to:

accept_sum = 0.0
for _ in range(nsteps):
    accept_prob, data, key = metropolis_step_fn(data, params, key)
    accept_sum += accept_prob
return accept_sum / nsteps, data, key

but has better tracing/pmap behavior due to the use of jax.lax.scan instead of a python for loop. See :func:~vmcnet.train.vmc.take_metropolis_step.

Parameters:

Name Type Description Default
nsteps int

number of steps to take

required
data pytree-like

data to walk (update) with each step

required
params pytree-like

parameters passed to proposal_fn and acceptance_fn, e.g. model params

required
key PRNGKey

an array with shape (2,) representing a jax PRNG key passed to proposal_fn and used to randomly accept proposals with probabilities output by acceptance_fn

required
metrop_step_fn Callable

function which does a metropolis step. Has the signature (data, params, key) -> (mean accept prob, new data, new key)

required

Returns:

Type Description
(jnp.float32, pytree-like, PRNGKey)

acceptance probability, new data, new jax PRNG key split (possibly multiple times) from previous one

Source code in vmcnet/mcmc/metropolis.py
def walk_data(
    nsteps: int,
    params: P,
    data: D,
    key: PRNGKey,
    metrop_step_fn: MetropolisStep[P, D],
) -> Tuple[jnp.float32, D, PRNGKey]:
    """Take multiple Metropolis-Hastings steps.

    This function is roughly equivalent to:
    ```
    accept_sum = 0.0
    for _ in range(nsteps):
        accept_prob, data, key = metropolis_step_fn(data, params, key)
        accept_sum += accept_prob
    return accept_sum / nsteps, data, key
    ```

    but has better tracing/pmap behavior due to the use of jax.lax.scan instead of a
    python for loop. See :func:`~vmcnet.train.vmc.take_metropolis_step`.

    Args:
        nsteps (int): number of steps to take
        data (pytree-like): data to walk (update) with each step
        params (pytree-like): parameters passed to proposal_fn and acceptance_fn, e.g.
            model params
        key (PRNGKey): an array with shape (2,) representing a jax PRNG key passed
            to proposal_fn and used to randomly accept proposals with probabilities
            output by acceptance_fn
        metrop_step_fn (Callable): function which does a metropolis step. Has the
            signature (data, params, key) -> (mean accept prob, new data, new key)

    Returns:
        (jnp.float32, pytree-like, PRNGKey): acceptance probability, new data,
            new jax PRNG key split (possibly multiple times) from previous one
    """

    def step_fn(carry, x):
        del x
        accept_prob, data, key = metrop_step_fn(params, carry[1], carry[2])
        return (carry[0] + accept_prob, data, key), None

    out = jax.lax.scan(step_fn, (0.0, data, key), xs=None, length=nsteps)
    accept_sum, data, key = out[0]
    return accept_sum / nsteps, data, key

make_jitted_burning_step(metrop_step_fn, apply_pmap=True)

Factory to create a burning step, which is an optionally pmapped Metropolis step.

This provides the functionality to optionally apply jax.pmap to a single Metropolis step. Only one step is traced so that the first burning step is traced but subsequent steps are properly jit-compiled. The acceptance probabilities (which typically don't mean much during burning) are thrown away.

For more about the Metropolis step itself, see :func:~vmcnet.mcmc.metropolis.make_metropolis_step.

Parameters:

Name Type Description Default
metrop_step_fn Callable

function which does a metropolis step. Has the signature (data, params, key) -> (mean accept prob, new data, new key)

required
apply_pmap bool

whether to apply jax.pmap to the burning step. If False, applies jax.jit. Defaults to True.

True

Returns:

Type Description
Callable

function with signature (data, params, key) -> (data, key), with jax.pmap optionally applied if apply_pmap is True.

Source code in vmcnet/mcmc/metropolis.py
def make_jitted_burning_step(
    metrop_step_fn: MetropolisStep[P, D],
    apply_pmap: bool = True,
) -> BurningStep[P, D]:
    """Factory to create a burning step, which is an optionally pmapped Metropolis step.

    This provides the functionality to optionally apply jax.pmap to a single Metropolis
    step. Only one step is traced so that the first burning step is traced but
    subsequent steps are properly jit-compiled. The acceptance probabilities (which
    typically don't mean much during burning) are thrown away.

    For more about the Metropolis step itself, see
    :func:`~vmcnet.mcmc.metropolis.make_metropolis_step`.

    Args:
        metrop_step_fn (Callable): function which does a metropolis step. Has the
            signature (data, params, key) -> (mean accept prob, new data, new key)
        apply_pmap (bool, optional): whether to apply jax.pmap to the burning step. If
            False, applies jax.jit. Defaults to True.

    Returns:
        Callable: function with signature
            (data, params, key) -> (data, key),
        with jax.pmap optionally applied if apply_pmap is True.
    """

    def burning_step(params: P, data: D, key: PRNGKey) -> Tuple[D, PRNGKey]:
        _, data, key = metrop_step_fn(params, data, key)
        return data, key

    if not apply_pmap:
        return jax.jit(burning_step)

    return utils.distribute.pmap(burning_step)

make_jitted_walker_fn(nsteps, metrop_step_fn, apply_pmap=True)

Factory to create a function which takes multiple Metropolis steps.

This provides the functionality to optionally apply jax.pmap to a jax.lax.scan loop of multiple metropolis steps. A typical use case would be to run this function between parameter updates in a VMC loop. An accumulated mean acceptance probability statistic is returned from this walker function.

See :func:~vmcnet.train.vmc.vmc_loop for usage.

Parameters:

Name Type Description Default
nsteps int

number of metropolis steps to take in each call

required
metrop_step_fn Callable

function which does a metropolis step. Has the signature (data, params, key) -> (mean accept probl, new data, new key)

required
apply_pmap bool

whether to apply jax.pmap to the walker function. If False, applies jax.jit. Defaults to True.

True

Returns:

Type Description
Callable

funciton with signature (params, data, key) -> (mean accept prob, new data, new key) with jax.pmap optionally applied if pmapped is True, and jax.jit applied if apply_pmap is False.

Source code in vmcnet/mcmc/metropolis.py
def make_jitted_walker_fn(
    nsteps: int,
    metrop_step_fn: MetropolisStep[P, D],
    apply_pmap: bool = True,
) -> WalkerFn[P, D]:
    """Factory to create a function which takes multiple Metropolis steps.

    This provides the functionality to optionally apply jax.pmap to a jax.lax.scan loop
    of multiple metropolis steps. A typical use case would be to run this function
    between parameter updates in a VMC loop. An accumulated mean acceptance probability
    statistic is returned from this walker function.

    See :func:`~vmcnet.train.vmc.vmc_loop` for usage.

    Args:
        nsteps (int): number of metropolis steps to take in each call
        metrop_step_fn (Callable): function which does a metropolis step. Has the
            signature (data, params, key) -> (mean accept probl, new data, new key)
        apply_pmap (bool, optional): whether to apply jax.pmap to the walker function.
            If False, applies jax.jit. Defaults to True.

    Returns:
        Callable: funciton with signature
            (params, data, key) -> (mean accept prob, new data, new key)
        with jax.pmap optionally applied if pmapped is True, and jax.jit applied if
        apply_pmap is False.
    """

    def walker_fn(params: P, data: D, key: PRNGKey) -> Tuple[jnp.float32, D, PRNGKey]:
        accept_ratio, data, key = walk_data(nsteps, params, data, key, metrop_step_fn)
        accept_ratio = utils.distribute.pmean_if_pmap(accept_ratio)
        return accept_ratio, data, key

    if not apply_pmap:
        return jax.jit(walker_fn)

    pmapped_walker_fn = utils.distribute.pmap(walker_fn)

    def pmapped_walker_fn_with_single_accept_ratio(
        params: P, data: D, key: PRNGKey
    ) -> Tuple[jnp.float32, D, PRNGKey]:
        accept_ratio, data, key = pmapped_walker_fn(params, data, key)
        accept_ratio = utils.distribute.get_first(accept_ratio)
        return accept_ratio, data, key

    return pmapped_walker_fn_with_single_accept_ratio

burn_data(burning_step, nsteps_to_burn, params, data, key)

Repeatedly apply a burning step.

Parameters:

Name Type Description Default
burning_step BurningStep

function which does a burning step. Has the signature (data, params, key) -> (new data, new key)

required
nsteps_to_burn int

number of times to call burning_step

required
data pytree-like

initial data

required
params pytree-like

parameters passed to the burning step

required
key PRNGKey

an array with shape (2,) representing a jax PRNG key passed to proposal_fn and used to randomly accept proposals with probabilities output by acceptance_fn

required

Returns:

Type Description
(pytree-like, PRNGKey)

new data, new key

Source code in vmcnet/mcmc/metropolis.py
def burn_data(
    burning_step: BurningStep[P, D],
    nsteps_to_burn: int,
    params: P,
    data: D,
    key: PRNGKey,
) -> Tuple[D, PRNGKey]:
    """Repeatedly apply a burning step.

    Args:
        burning_step (BurningStep): function which does a burning step. Has the
            signature (data, params, key) -> (new data, new key)
        nsteps_to_burn (int): number of times to call burning_step
        data (pytree-like): initial data
        params (pytree-like): parameters passed to the burning step
        key (PRNGKey): an array with shape (2,) representing a jax PRNG key passed
            to proposal_fn and used to randomly accept proposals with probabilities
            output by acceptance_fn

    Returns:
        (pytree-like, PRNGKey): new data, new key
    """
    logging.info("Burning data for %d steps", nsteps_to_burn)
    for _ in range(nsteps_to_burn):
        data, key = burning_step(params, data, key)
    return data, key

gaussian_proposal(positions, std_move, key)

Simple symmetric gaussian proposal in all positions at once.

Parameters:

Name Type Description Default
positions Array

original positions

required
std_move jnp.float32

standard deviation of the moves

required
key PRNGKey

an array with shape (2,) representing a jax PRNG key

required

Returns:

Type Description
(Array, Array)

(new positions, new key split from previous)

Source code in vmcnet/mcmc/metropolis.py
def gaussian_proposal(
    positions: Array, std_move: jnp.float32, key: PRNGKey
) -> Tuple[Array, PRNGKey]:
    """Simple symmetric gaussian proposal in all positions at once.

    Args:
        positions (Array): original positions
        std_move (jnp.float32): standard deviation of the moves
        key (PRNGKey): an array with shape (2,) representing a jax PRNG key

    Returns:
        (Array, Array): (new positions, new key split from previous)
    """
    key, subkey = jax.random.split(key)
    return positions + std_move * jax.random.normal(subkey, shape=positions.shape), key

metropolis_symmetric_acceptance(amplitude, proposed_amplitude, logabs=True)

Standard Metropolis acceptance ratio for a symmetric proposal function.

The general Metropolis-Hastings choice of acceptance ratio for moves from state i to state j is given by

accept_ij = min(1, (P_j * proposal_prob_ji) / (P_i * proposal_prob_ij)).

When proposal_prob is symmetric (assumed in this function), this simply reduces to accept_ij = min(1, P_j / P_i). Some care is taken to avoid numerical overflow and division by zero.

The inputs are wavefunction amplitudes psi or log(|psi|), so the probability P_i refers to |psi(i)|^2.

Parameters:

Name Type Description Default
amplitude Array

one-dimensional array of wavefunction amplitudes for the current state, or log wavefunction amplitudes if logabs is True

required
proposed_amplitude Array

one-dimensional array of wavefunction amplitudes for the proposed state, or log wavefunction amplitudes if logabs is True

required
logabs bool

whether the provided amplitudes represent psi (logabs = False) or log|psi| (logabs = True). Defaults to True.

True

Returns:

Type Description
Array

one-dimensional array of acceptance ratios for the Metropolis algorithm

Source code in vmcnet/mcmc/metropolis.py
def metropolis_symmetric_acceptance(
    amplitude: Array, proposed_amplitude: Array, logabs: bool = True
) -> Array:
    """Standard Metropolis acceptance ratio for a symmetric proposal function.

    The general Metropolis-Hastings choice of acceptance ratio for moves from state i to
    state j is given by

        accept_ij = min(1, (P_j * proposal_prob_ji) / (P_i * proposal_prob_ij)).

    When proposal_prob is symmetric (assumed in this function), this simply reduces to
    accept_ij = min(1, P_j / P_i). Some care is taken to avoid numerical overflow and
    division by zero.

    The inputs are wavefunction amplitudes psi or log(|psi|), so the probability P_i
    refers to |psi(i)|^2.

    Args:
        amplitude (Array): one-dimensional array of wavefunction amplitudes for
            the current state, or log wavefunction amplitudes if logabs is True
        proposed_amplitude (Array): one-dimensional array of wavefunction
            amplitudes for the proposed state, or log wavefunction amplitudes if logabs
            is True
        logabs (bool, optional): whether the provided amplitudes represent psi
            (logabs = False) or log|psi| (logabs = True). Defaults to True.

    Returns:
        Array: one-dimensional array of acceptance ratios for the Metropolis
        algorithm
    """
    if not logabs:
        prob_old = jnp.square(amplitude)
        prob_new = jnp.square(proposed_amplitude)
        ratio = prob_new / prob_old
        # safe division by zero
        ratio = jnp.where(
            jnp.logical_or(prob_old < prob_new, prob_old == 0.0),
            jnp.ones_like(ratio),
            ratio,
        )
        return ratio

    log_prob_old = 2.0 * amplitude
    log_prob_new = 2.0 * proposed_amplitude
    # avoid overflow if log_prob_new - log_prob_old is large
    return jnp.where(
        log_prob_new > log_prob_old,
        jnp.ones_like(log_prob_new),
        jnp.exp(log_prob_new - log_prob_old),
    )