core

Core local energy and gradient construction routines.

initialize_molecular_pos(key, nchains, ion_pos, ion_charges, nelec_total, init_width=1.0, dtype=jnp.float32)

Initialize a set of plausible initial electron positions.

For each chain, each electron is assigned to a random ion and then its position is sampled from a normal distribution centered at that ion with diagonal covariance with diagonal entries all equal to init_width.

If there are no more electrons than there are ions, the assignment is done without replacement. If there are more electrons than ions, the assignment is done with replacement, and the probability of choosing ion i is its relative charge (as a fraction of the sum of the ion charges).

Source code in vmcnet/physics/core.py
def initialize_molecular_pos(
    key: PRNGKey,
    nchains: int,
    ion_pos: Array,
    ion_charges: Array,
    nelec_total: int,
    init_width: float = 1.0,
    dtype=jnp.float32,
) -> Tuple[PRNGKey, Array]:
    """Initialize a set of plausible initial electron positions.

    For each chain, each electron is assigned to a random ion and then its position is
    sampled from a normal distribution centered at that ion with diagonal covariance
    with diagonal entries all equal to init_width.

    If there are no more electrons than there are ions, the assignment is done without
    replacement. If there are more electrons than ions, the assignment is done with
    replacement, and the probability of choosing ion i is its relative charge (as a
    fraction of the sum of the ion charges).
    """
    nion = len(ion_charges)
    replace = True

    if nelec_total <= nion:
        replace = False

    assignments = []
    for _ in range(nchains):
        key, subkey = jax.random.split(key)
        choices = jax.random.choice(
            subkey,
            nion,
            shape=(nelec_total,),
            replace=replace,
            p=ion_charges / jnp.sum(ion_charges),
        )
        assignments.append(ion_pos[choices])
    elecs_at_ions = jnp.stack(assignments, axis=0)
    key, subkey = jax.random.split(key)
    return key, elecs_at_ions + init_width * jax.random.normal(
        subkey, elecs_at_ions.shape, dtype=dtype
    )
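
A minimal usage sketch (the two-ion geometry, charges, and chain count below are illustrative, not taken from vmcnet):

import jax
import jax.numpy as jnp

from vmcnet.physics.core import initialize_molecular_pos

# Hypothetical diatomic: two ions on the z-axis with charges 3 and 1.
ion_pos = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 3.0]])
ion_charges = jnp.array([3.0, 1.0])

key = jax.random.PRNGKey(0)
# 4 electrons > 2 ions, so ions are chosen with replacement, weighted by charge.
key, elec_pos = initialize_molecular_pos(
    key, nchains=10, ion_pos=ion_pos, ion_charges=ion_charges, nelec_total=4
)
# elec_pos.shape == (10, 4, 3), i.e. (nchains, nelec_total, spatial dimension)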

combine_local_energy_terms(local_energy_terms)

Combine a sequence of local energy terms by adding them.

Parameters:

- local_energy_terms (Sequence, required): sequence of local energy terms, each with the signature (params, x) -> array of terms of shape (x.shape[0],)

Returns:

- Callable: local energy function which computes the sum of the local energy terms. Has the signature (params, x) -> local energy array of shape (x.shape[0],)

Source code in vmcnet/physics/core.py
def combine_local_energy_terms(
    local_energy_terms: Sequence[ModelApply[P]],
) -> ModelApply[P]:
    """Combine a sequence of local energy terms by adding them.

    Args:
        local_energy_terms (Sequence): sequence of local energy terms, each with the
            signature (params, x) -> array of terms of shape (x.shape[0],)

    Returns:
        Callable: local energy function which computes the sum of the local energy
        terms. Has the signature
        (params, x) -> local energy array of shape (x.shape[0],)
    """

    def local_energy_fn(params: P, x: Array) -> Array:
        local_energy_sum = local_energy_terms[0](params, x)
        for term in local_energy_terms[1:]:
            local_energy_sum = cast(Array, local_energy_sum + term(params, x))
        return local_energy_sum

    return local_energy_fn
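
For example, with two toy terms standing in for real kinetic and potential contributions (these stand-ins are ours, not from vmcnet):

import jax.numpy as jnp

from vmcnet.physics.core import combine_local_energy_terms

# Toy local energy terms, each with the (params, x) -> (x.shape[0],) signature.
def quadratic_term(params, x):
    return jnp.sum(jnp.square(x), axis=-1)

def constant_term(params, x):
    return -jnp.ones(x.shape[0])

local_energy_fn = combine_local_energy_terms([quadratic_term, constant_term])
x = jnp.ones((5, 3))
local_energy_fn(None, x)  # shape (5,), each entry 3.0 + (-1.0) = 2.0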

laplacian_psi_over_psi(grad_log_psi_apply, params, x)

Compute (nabla^2 psi) / psi at x given a function which evaluates psi'(x)/psi.

The computation is done by computing (forward-mode) derivatives of the gradient to get the columns of the Hessian, and accumulating the (i, i)th entries (but this implementation is significantly more memory efficient than directly computing the Hessian).

This function uses the identity

(nabla^2 psi) / psi = (nabla^2 log|psi|) + (nabla log|psi|)^2

to avoid leaving the log domain during the computation.
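
For reference, the identity follows from nabla psi = psi * (nabla log|psi|) and the product rule, wherever psi is nonzero:

    (nabla^2 psi) / psi = nabla . (psi * nabla log|psi|) / psi
                        = nabla^2 log|psi| + (nabla log|psi|) . (nabla psi / psi)
                        = nabla^2 log|psi| + (nabla log|psi|)^2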

This function should be vmapped in order to be applied to batches of inputs, as it completely flattens x in order to take second derivatives w.r.t. each component.

This approach is extremely similar to the one in the FermiNet repo (in the jax branch, as of this writing -- see https://github.com/deepmind/ferminet/blob/aade61b3d30883b3238d6b50c85404d0e8176155/ferminet/hamiltonian.py).

The main difference is that we are being explicit about the flattening of x within the Laplacian calculation, so that it does not have to be handled outside of this function (psi is free to take x shapes which are not flat).

Parameters:

- grad_log_psi_apply (Callable, required): function which evaluates the derivative of log|psi(x)|, i.e. (nabla psi)(x) / psi(x), with respect to x. Has the signature (params, x) -> (nabla psi)(x) / psi(x), so the derivative should be over the second arg, x, and the output shape should be the same as x
- params (pytree, required): model parameters, passed as the first arg of grad_log_psi
- x (Array, required): second input to grad_log_psi

Returns:

- jnp.float32: "local" laplacian calculation, i.e. (nabla^2 psi) / psi

Source code in vmcnet/physics/core.py
def laplacian_psi_over_psi(
    grad_log_psi_apply: ModelApply,
    params: P,
    x: Array,
) -> jnp.float32:
    """Compute (nabla^2 psi) / psi at x given a function which evaluates psi'(x)/psi.

    The computation is done by computing (forward-mode) derivatives of the gradient to
    get the columns of the Hessian, and accumulating the (i, i)th entries (but this
    implementation is significantly more memory efficient than directly computing the
    Hessian).

    This function uses the identity

        (nabla^2 psi) / psi = (nabla^2 log|psi|) + (nabla log|psi|)^2

    to avoid leaving the log domain during the computation.

    This function should be vmapped in order to be applied to batches of inputs, as it
    completely flattens x in order to take second derivatives w.r.t. each component.

    This approach is extremely similar to the one in the FermiNet repo
    (in the jax branch, as of this writing -- see
    https://github.com/deepmind/ferminet/blob/aade61b3d30883b3238d6b50c85404d0e8176155/ferminet/hamiltonian.py).

    The main difference is that we are being explicit about the flattening of x within
    the Laplacian calculation, so that it does not have to be handled outside of this
    function (psi is free to take x shapes which are not flat).

    Args:
        grad_log_psi_apply (Callable): function which evaluates the derivative of
            log|psi(x)|, i.e. (nabla psi)(x) / psi(x), with respect to x. Has the
            signature (params, x) -> (nabla psi)(x) / psi(x), so the derivative should
            be over the second arg, x, and the output shape should be the same as x
        params (pytree): model parameters, passed as the first arg of grad_log_psi
        x (Array): second input to grad_log_psi

    Returns:
        jnp.float32: "local" laplacian calculation, i.e. (nabla^2 psi) / psi
    """
    x_shape = x.shape
    flat_x = jnp.reshape(x, (-1,))
    n = flat_x.shape[0]
    identity_mat = jnp.eye(n)

    def flattened_grad_log_psi_of_flat_x(flat_x_in):
        """Flattened input to flattened output version of grad_log_psi."""
        grad_log_psi_out = grad_log_psi_apply(params, jnp.reshape(flat_x_in, x_shape))
        return jnp.reshape(grad_log_psi_out, (-1,))

    def step_fn(carry, unused):
        del unused
        i = carry[0]
        primals, tangents = jax.jvp(
            flattened_grad_log_psi_of_flat_x, (flat_x,), (identity_mat[i],)
        )
        return (i + 1, carry[1] + jnp.square(primals[i]) + tangents[i]), None

    out, _ = jax.lax.scan(step_fn, (0, 0.0), xs=None, length=n)
    return out[1]
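
A quick sanity check on a Gaussian (the test function here is ours, not from the source): for log|psi(x)| = -||x||^2 / 2, the gradient is -x, and the identity gives (nabla^2 psi) / psi = ||x||^2 - n.

import jax
import jax.numpy as jnp

from vmcnet.physics.core import laplacian_psi_over_psi

def grad_log_psi_apply(params, x):
    return -x  # gradient of log|psi(x)| = -||x||^2 / 2 with respect to x

x = jnp.array([1.0, 2.0, 3.0])
laplacian_psi_over_psi(grad_log_psi_apply, None, x)  # ||x||^2 - 3 = 11.0

# For batches of positions, vmap over the position argument only:
batch_x = jnp.stack([x, 2.0 * x])
jax.vmap(
    lambda single_x: laplacian_psi_over_psi(grad_log_psi_apply, None, single_x)
)(batch_x)  # shape (2,)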

get_statistics_from_local_energy(local_energies, nchains, nan_safe=True)

Collectively reduce local energies to an average energy and variance.

Parameters:

- local_energies (Array, required): local energies of shape (nchains,), possibly distributed across multiple devices via utils.distribute.pmap
- nchains (int, required): total number of chains across all devices, used to compute a sample variance estimate of the local energy
- nan_safe (bool, optional): flag which controls if jnp.nanmean is used instead of jnp.mean. Can be set to False when debugging if trying to find the source of unexpected nans. Defaults to True.

Returns:

- (jnp.float32, jnp.float32): local energy average, local energy (sample) variance

Source code in vmcnet/physics/core.py
def get_statistics_from_local_energy(
    local_energies: Array, nchains: int, nan_safe: bool = True
) -> Tuple[jnp.float32, jnp.float32]:
    """Collectively reduce local energies to an average energy and variance.

    Args:
        local_energies (Array): local energies of shape (nchains,), possibly
            distributed across multiple devices via utils.distribute.pmap.
        nchains (int): total number of chains across all devices, used to compute a
            sample variance estimate of the local energy
        nan_safe (bool, optional): flag which controls if jnp.nanmean is used instead of
            jnp.mean. Can be set to False when debugging if trying to find the source of
            unexpected nans. Defaults to True.

    Returns:
        (jnp.float32, jnp.float32): local energy average, local energy (sample) variance
    """
    # TODO(Jeffmin) might be worth investigating the numerical stability of the XLA
    # compiled version of these two computations, since the quality of the gradients
    # is fairly crucial to the success of the algorithm
    if nan_safe:
        allreduce_mean = utils.distribute.nanmean_all_local_devices
    else:
        allreduce_mean = utils.distribute.mean_all_local_devices
    energy = allreduce_mean(local_energies)
    variance = (
        allreduce_mean(jnp.square(local_energies - energy)) * nchains / (nchains - 1)
    )  # adjust by n / (n - 1) to get an unbiased estimator
    return energy, variance
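
In single-device terms the computation is equivalent to the sketch below; the real function routes the reductions through utils.distribute so they also work when the data is sharded via pmap.

import jax.numpy as jnp

local_energies = jnp.array([-0.5, -0.4, -0.6, -0.5])
nchains = local_energies.shape[0]
energy = jnp.nanmean(local_energies)  # -0.5
# The n / (n - 1) factor turns the biased variance estimate into an unbiased one.
variance = (
    jnp.nanmean(jnp.square(local_energies - energy)) * nchains / (nchains - 1)
)  # 0.02 / 3 ≈ 0.00667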

get_default_energy_bwd(log_psi_apply, mean_grad_fn)

Use a standard variance reduction formula to get the bwd pass of the energy.

The formula is 2 * E_p[(local_e - E_p[local_e]) * grad_log_psi], where the symbol E_p[] refers to the expectation over the probability distribution p defined by p(x) = |psi(x)|^2 / <psi | psi>. This is an unbiased estimator of the gradient of E_p[local_e], and has a lower variance than directly differentiating E_p[local_e] with respect to the parameters.

Parameters:

- log_psi_apply (Callable, required): computes log|psi(x)|, where the signature of this function is (params, x) -> log|psi(x)|
- mean_grad_fn (Callable, required): function which is used to average the local gradient terms over all local devices. Has the signature local_grads -> avg_grad / 2, and should only average over the batch axis 0.

Returns:

- Callable: function which computes the backward pass in the custom vjp of the total energy. Has the signature (res, cotangents) -> (gradients, None)

Source code in vmcnet/physics/core.py
def get_default_energy_bwd(
    log_psi_apply: ModelApply[P],
    mean_grad_fn: Callable[[Array], Array],
):
    """Use a standard variance reduction formula to get the bwd pass of the energy.

    The formula is 2 * E_p[(local_e - E_p[local_e]) * grad_log_psi], where the
    symbol E_p[] refers to the expectation over the probability distribution p defined
    by p(x) = |psi(x)|^2 / <psi | psi>. This is an unbiased estimator of the gradient
    of E_p[local_e], and has a lower variance than directly differentiating E_p[local_e]
    with respect to the parameters.

    Args:
        log_psi_apply (Callable): computes log|psi(x)|, where the signature of this
            function is (params, x) -> log|psi(x)|
        mean_grad_fn (Callable): function which is used to average the local gradient
            terms over all local devices. Has the signature local_grads -> avg_grad / 2,
            and should only average over the batch axis 0.

    Returns:
        Callable: function which computes the backward pass in the custom vjp of the
        total energy. Has the signature (res, cotangents) -> (gradients, None)
    """

    def scaled_by_local_e(
        params: P, positions: Array, centered_local_energies: Array
    ) -> jnp.float32:
        log_psi = log_psi_apply(params, positions)
        loss_functions.register_normal_predictive_distribution(log_psi[:, None])
        return 2.0 * mean_grad_fn(centered_local_energies * log_psi)

    _get_energy_grad = jax.grad(scaled_by_local_e, argnums=0)

    def energy_bwd(res, cotangents) -> Tuple[P, None]:
        energy, local_energies, params, positions = res
        centered_local_energies = local_energies - energy
        gradient = _get_energy_grad(params, positions, centered_local_energies)
        return jax.tree_map(lambda x: x * cotangents[0], gradient), None

    return energy_bwd
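
Stripped of the KFAC registration and the custom-vjp plumbing, the estimator computed by this backward pass reduces to the following sketch (the toy model and numbers are made up for illustration):

import jax
import jax.numpy as jnp

# Toy log|psi| with the expected (params, x) -> shape (x.shape[0],) signature.
def log_psi_apply(params, x):
    return -params * jnp.sum(jnp.square(x), axis=-1)

params = 0.5
positions = jnp.arange(12.0).reshape(4, 3)
local_energies = jnp.array([-0.5, -0.4, -0.6, -0.5])
centered = local_energies - jnp.mean(local_energies)

# 2 * E_p[(local_e - E_p[local_e]) * grad_params log|psi|], via a surrogate loss
# whose gradient w.r.t. params is exactly that expression (centered is constant).
def surrogate(p):
    return 2.0 * jnp.mean(centered * log_psi_apply(p, positions))

energy_gradient = jax.grad(surrogate)(params)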

create_value_and_grad_energy_fn(log_psi_apply, local_energy_fn, nchains, clipping_fn=None, nan_safe=True, get_energy_bwd=get_default_energy_bwd)

Create a function which computes unbiased energy gradients.

Due to the Hermiticity of the Hamiltonian, we can get an unbiased estimate of the gradient of the expected energy with lower variance than the naive gradient of the mean of sampled local energies. Specifically, the gradient of the expected energy expect[E_L] takes the form

2 * expect[(E_L - expect[E_L]) * (grad_psi / psi)(x)],

where E_L is the local energy and expect[] denotes the expectation with respect to the distribution |psi|^2.

Parameters:

- log_psi_apply (Callable, required): computes log|psi(x)|, where the signature of this function is (params, x) -> log|psi(x)|
- local_energy_fn (Callable, required): computes local energies Hpsi / psi. Has signature (params, x) -> (Hpsi / psi)(x)
- nchains (int, required): total number of chains across all devices, used to compute a sample variance estimate of the local energy
- clipping_fn (Callable, optional): post-processing function on the local energy, e.g. a function which clips the values to be within some multiple of the total variation from the median. The post-processed values are used for the gradient calculation, if available. Defaults to None.
- nan_safe (bool, optional): flag which controls if jnp.nanmean and jnp.nansum are used instead of jnp.mean and jnp.sum for the terms in the gradient calculation. Can be set to False when debugging if trying to find the source of unexpected nans. Defaults to True.
- get_energy_bwd (Callable, optional): function which returns a custom backward pass for the total energy calculation. Has the signature (log_psi_apply, mean_grad_fn) -> energy_bwd. Defaults to get_default_energy_bwd, which computes the formula above.

Returns:

- Callable: function which computes the clipped energy value and gradient. Has the signature (params, x) -> ((expected_energy, auxiliary_energy_data), grad_energy), where auxiliary_energy_data is the tuple (expected_variance, local_energies, unclipped_energy, unclipped_variance)

Source code in vmcnet/physics/core.py
def create_value_and_grad_energy_fn(
    log_psi_apply: ModelApply[P],
    local_energy_fn: ModelApply[P],
    nchains: int,
    clipping_fn: Optional[Callable[[Array], Array]] = None,
    nan_safe: bool = True,
    get_energy_bwd: Callable = get_default_energy_bwd,
) -> ValueGradEnergyFn[P]:
    """Create a function which computes unbiased energy gradients.

    Due to the Hermiticity of the Hamiltonian, we can get an unbiased estimate of
    the gradient of the expected energy with lower variance than the naive
    gradient of the mean of sampled local energies. Specifically, the gradient of
    the expected energy expect[E_L] takes the form

        2 * expect[(E_L - expect[E_L]) * (grad_psi / psi)(x)],

    where E_L is the local energy and expect[] denotes the expectation with respect to
    the distribution |psi|^2.

    Args:
        log_psi_apply (Callable): computes log|psi(x)|, where the signature of this
            function is (params, x) -> log|psi(x)|
        local_energy_fn (Callable): computes local energies Hpsi / psi. Has signature
            (params, x) -> (Hpsi / psi)(x)
        nchains (int): total number of chains across all devices, used to compute a
            sample variance estimate of the local energy
        clipping_fn (Callable, optional): post-processing function on the local energy,
            e.g. a function which clips the values to be within some multiple of the
            total variation from the median. The post-processed values are used for
            the gradient calculation, if available. Defaults to None.
        nan_safe (bool, optional): flag which controls if jnp.nanmean and jnp.nansum are
            used instead of jnp.mean and jnp.sum for the terms in the gradient
            calculation. Can be set to False when debugging if trying to find the source
            of unexpected nans. Defaults to True.
        get_energy_bwd (Callable): function which returns a custom backward pass for the
            total energy calculation. Has the signature
                (log_psi_apply, mean_grad_fn) -> energy_bwd.
            Defaults to get_default_energy_bwd, which computes the formula above.

    Returns:
        Callable: function which computes the clipped energy value and gradient. Has the
        signature
            (params, x)
            -> ((expected_energy, auxiliary_energy_data), grad_energy),
        where auxiliary_energy_data is the tuple
        (expected_variance, local_energies, unclipped_energy, unclipped_variance)
    """

    @jax.custom_vjp
    def compute_energy_data(params: P, positions: Array) -> EnergyData:
        local_energies_noclip = local_energy_fn(params, positions)
        if clipping_fn is not None:
            local_energies = clipping_fn(local_energies_noclip)
            energy, variance = get_statistics_from_local_energy(
                local_energies, nchains, nan_safe=nan_safe
            )

            # For the unclipped metrics, which are not used in the gradient, don't
            # do these in a nan-safe way. This makes nans more visible and makes sure
            # the command-line checkpoint_if_nans flag will work properly.
            energy_noclip, variance_noclip = get_statistics_from_local_energy(
                local_energies_noclip, nchains, nan_safe=False
            )

            aux_data = (variance, local_energies, energy_noclip, variance_noclip)
        else:
            local_energies = local_energies_noclip
            energy, variance = get_statistics_from_local_energy(
                local_energies, nchains, nan_safe=nan_safe
            )

            # Even though there's no clipping function, still record noclip metrics
            # without nan-safety so that checkpointing epochs with nans can be
            # supported.
            energy_noclip, variance_noclip = get_statistics_from_local_energy(
                local_energies, nchains, nan_safe=False
            )
            aux_data = (variance, local_energies, energy_noclip, variance_noclip)
        return energy, aux_data

    def energy_fwd(params: P, positions: Array):
        output = compute_energy_data(params, positions)
        energy = output[0]
        local_energies = output[1][1]
        return output, (energy, local_energies, params, positions)

    mean_grad_fn = utils.distribute.get_mean_over_first_axis_fn(nan_safe=nan_safe)
    energy_bwd = get_energy_bwd(log_psi_apply, mean_grad_fn)

    compute_energy_data.defvjp(energy_fwd, energy_bwd)

    energy_data_val_and_grad = jax.value_and_grad(
        compute_energy_data, argnums=0, has_aux=True
    )
    return energy_data_val_and_grad
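
A minimal end-to-end sketch, with toy stand-ins for the model and Hamiltonian (the helpers below are illustrative, not part of vmcnet; depending on how utils.distribute implements its reductions, this may need to run under utils.distribute.pmap, as it does in the real training loop):

import jax.numpy as jnp

from vmcnet.physics.core import create_value_and_grad_energy_fn

def log_psi_apply(params, x):
    return -params["alpha"] * jnp.sum(jnp.square(x), axis=-1)

def local_energy_fn(params, x):
    return jnp.sum(jnp.square(x), axis=-1)  # placeholder local energies

nchains = 8
energy_val_and_grad = create_value_and_grad_energy_fn(
    log_psi_apply, local_energy_fn, nchains
)

params = {"alpha": jnp.array(0.5)}
x = jnp.ones((nchains, 3))
(energy, aux_data), grad_energy = energy_val_and_grad(params, x)
# aux_data is (variance, local_energies, energy_noclip, variance_noclip)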