Core local energy and gradient construction routines.

initialize_molecular_pos(key, nchains, ion_pos, ion_charges, nelec_total, init_width=1.0, dtype=jnp.float32)

def initialize_molecular_pos(
    key: PRNGKey,
    nchains: int,
    ion_pos: Array,
    ion_charges: Array,
    nelec_total: int,
    init_width: float = 1.0,
    dtype=jnp.float32,
) -> Tuple[PRNGKey, Array]:
    """Initialize a set of plausible initial electron positions.

    For each chain, each electron is assigned to a random ion and then its position is
    sampled from a normal distribution centered at that ion with standard deviation
    init_width in every coordinate (i.e. diagonal covariance init_width**2 * I).

    If there are no more electrons than there are ions, the assignment is done without
    replacement. If there are more electrons than ions, the assignment is done with
    replacement. In both cases the probability of choosing ion i is weighted by its
    relative charge (as a fraction of the sum of the ion charges).

    Args:
        key (PRNGKey): random key; it is split internally and the updated key is
            returned.
        nchains (int): number of independent chains to initialize.
        ion_pos (Array): ion positions, indexed by ion along axis 0 (presumably of
            shape (nion, d) -- confirm against callers).
        ion_charges (Array): ion charges, one per ion; normalized and used as the
            sampling weights for assigning electrons to ions.
        nelec_total (int): number of electrons per chain.
        init_width (float, optional): standard deviation of the Gaussian noise added
            around each assigned ion. Defaults to 1.0.
        dtype (optional): dtype of the sampled Gaussian noise. Defaults to
            jnp.float32.

    Returns:
        (PRNGKey, Array): the updated random key, and the electron positions of shape
        (nchains, nelec_total) + ion_pos.shape[1:].
    """
    nion = len(ion_charges)
    # Only sample with replacement when there are more electrons than ions.
    replace = nelec_total > nion
    weights = ion_charges / jnp.sum(ion_charges)

    assignments = []
    for _ in range(nchains):
        key, subkey = jax.random.split(key)
        choices = jax.random.choice(
            subkey,
            nion,
            shape=(nelec_total,),
            replace=replace,
            p=weights,
        )
        assignments.append(ion_pos[choices])
    elecs_at_ions = jnp.stack(assignments, axis=0)

    key, subkey = jax.random.split(key)
    noise = jax.random.normal(subkey, elecs_at_ions.shape, dtype=dtype)
    return key, elecs_at_ions + init_width * noise


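combine_local_energy_terms(local_energy_terms)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `local_energy_terms` | `Sequence` | sequence of local energy terms, each with the signature `(params, x) -> array of terms of shape (x.shape[0],)` | *required* |

Returns:

| Type | Description |
| --- | --- |
| `ModelApply[P]` | local energy function which computes the sum of the local energy terms. Has the signature `(params, x) -> local energy array of shape (x.shape[0],)` |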

def combine_local_energy_terms(
    local_energy_terms: Sequence[ModelApply[P]],
) -> ModelApply[P]:
    """Build a single local energy function that sums several local energy terms.

    Args:
        local_energy_terms (Sequence): sequence of local energy terms, each with the
            signature (params, x) -> array of terms of shape (x.shape[0],)

    Returns:
        Callable: local energy function which computes the sum of the local energy
        terms. Has the signature (params, x) -> local energy array of shape
        (x.shape[0],)
    """

    def local_energy_fn(params: P, x: Array) -> Array:
        # Seed the running total with the first term's output rather than a scalar 0,
        # so the sum keeps that term's array type and shape.
        running_total = local_energy_terms[0](params, x)
        remaining_terms = local_energy_terms[1:]
        for energy_term in remaining_terms:
            contribution = energy_term(params, x)
            running_total = cast(Array, running_total + contribution)
        return running_total

    return local_energy_fn

laplacian_psi_over_psi(grad_log_psi_apply, params, x)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `grad_log_psi_apply` | `Callable` | function which evaluates the derivative of log\|psi(x)\|, i.e. (nabla psi)(x) / psi(x), with respect to x. Has the signature `(params, x) -> (nabla psi)(x) / psi(x)`, so the derivative should be over the second arg, x, and the output shape should be the same as x | *required* |
| `params` | pytree | model parameters, passed as the first arg of `grad_log_psi_apply` | *required* |
| `x` | `Array` | second input to `grad_log_psi_apply` | *required* |

Returns:

| Type | Description |
| --- | --- |
| `jnp.float32` | "local" laplacian calculation, i.e. (nabla^2 psi) / psi |

def laplacian_psi_over_psi(
    grad_log_psi_apply: ModelApply,
    params: P,
    x: Array,
) -> jnp.float32:
    """Compute (nabla^2 psi) / psi at x from a function evaluating nabla log|psi|.

    Uses the identity

        (nabla^2 psi) / psi = sum_i [ (d_i log|psi|)^2 + d_i^2 log|psi| ],

    where the first term squares the entries of grad log|psi| and the second is the
    diagonal of its Jacobian. Each diagonal entry is obtained from one forward-mode
    jvp along a unit vector, accumulated inside a lax.scan over the flattened
    coordinates.

    Args:
        grad_log_psi_apply (Callable): function which evaluates the derivative of
            log|psi(x)|, i.e. (nabla psi)(x) / psi(x), with respect to x. Has the
            signature (params, x) -> (nabla psi)(x) / psi(x), so the derivative should
            be over the second arg, x, and the output shape should be the same as x
        params (pytree): model parameters, passed as the first arg of
            grad_log_psi_apply
        x (Array): second input to grad_log_psi_apply

    Returns:
        jnp.float32: "local" laplacian calculation, i.e. (nabla^2 psi) / psi
    """
    x_shape = x.shape
    flat_x = jnp.reshape(x, (-1,))
    n = flat_x.shape[0]
    # jax.jvp requires the tangent dtype to match the primal dtype, so build the unit
    # vectors in x's dtype rather than the default float dtype (which differs e.g.
    # for float32 inputs when x64 mode is enabled).
    identity_mat = jnp.eye(n, dtype=flat_x.dtype)

    def flattened_grad_log_psi_of_flat_x(flat_x_in):
        """Flattened input to flattened output version of grad_log_psi_apply."""
        grad_log_psi_out = grad_log_psi_apply(params, jnp.reshape(flat_x_in, x_shape))
        return jnp.reshape(grad_log_psi_out, (-1,))

    def step_fn(carry, unused):
        del unused
        i = carry[0]
        # The jvp along the i-th unit vector is column i of the Jacobian of
        # grad log|psi|; its i-th entry is the diagonal term d_i^2 log|psi|.
        primals, tangents = jax.jvp(
            flattened_grad_log_psi_of_flat_x, (flat_x,), (identity_mat[i],)
        )
        return (i + 1, carry[1] + jnp.square(primals[i]) + tangents[i]), None

    out, _ = jax.lax.scan(step_fn, (0, 0.0), xs=None, length=n)
    return out[1]

get_statistics_from_local_energy(local_energies, nchains, nan_safe=True)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `local_energies` | `Array` | local energies of shape (nchains,), possibly distributed across multiple devices via `utils.distribute.pmap` | *required* |
| `nchains` | `int` | total number of chains across all devices, used to compute a sample variance estimate of the local energy | *required* |
| `nan_safe` | `bool` | flag which controls whether `jnp.nanmean` is used instead of `jnp.mean`. Can be set to False when debugging to find the source of unexpected nans. | `True` |

Returns:

| Type | Description |
| --- | --- |
| `(jnp.float32, jnp.float32)` | local energy average, local energy (sample) variance |

def get_statistics_from_local_energy(
    local_energies: Array, nchains: int, nan_safe: bool = True
) -> Tuple[jnp.float32, jnp.float32]:
    """Compute the average and sample variance of the local energies.

    Args:
        local_energies (Array): local energies of shape (nchains,), possibly
            distributed across multiple devices via utils.distribute.pmap.
        nchains (int): total number of chains across all devices, used to compute a
            sample variance estimate of the local energy.
        nan_safe (bool, optional): flag which controls if jnp.nanmean is used instead
            of jnp.mean. Can be set to False when debugging if trying to find the
            source of unexpected nans. Defaults to True.

    Returns:
        (jnp.float32, jnp.float32): local energy average, local energy (sample)
        variance
    """
    # TODO(Jeffmin) might be worth investigating the numerical stability of the XLA
    # compiled version of these two computations, since the quality of the gradients
    # is fairly crucial to the success of the algorithm
    if nan_safe:
        allreduce_mean = utils.distribute.nanmean_all_local_devices
    else:
        allreduce_mean = utils.distribute.mean_all_local_devices
    energy = allreduce_mean(local_energies)
    # Adjust by n / (n - 1) to get an unbiased estimator. NOTE(review): this uses
    # nchains even when nan_safe drops nan entries from the mean, and nchains == 1
    # divides by zero.
    variance = (
        allreduce_mean(jnp.square(local_energies - energy)) * nchains / (nchains - 1)
    )
    return energy, variance

get_default_energy_bwd(log_psi_apply, mean_grad_fn)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `log_psi_apply` | `Callable` | computes log\|psi(x)\|, where the signature of this function is `(params, x) -> log|psi(x)|` | *required* |
| `mean_grad_fn` | `Callable` | function which is used to average the local gradient terms over all local devices. Has the signature `local_grads -> avg_grad / 2`, and should only average over the batch axis 0. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Callable` | function which computes the backward pass in the custom vjp of the total energy. Has the signature `(res, cotangents) -> (gradients, None)` |

def get_default_energy_bwd(
    log_psi_apply: ModelApply[P],
    mean_grad_fn: Callable[[Array], Array],
) -> Callable:
    """Get the default custom backward pass for the total energy calculation.

    The returned backward pass differentiates, with respect to the parameters,

        2 * mean_grad_fn((E_L(x) - E) * log|psi(x)|),

    holding the centered local energies (E_L - E) fixed. Here E_L are the local
    energies and E is their average. For a linear averaging mean_grad_fn this equals
    2 * <(E_L - E) * grad log|psi|>. The result is then scaled by the incoming energy
    cotangent.

    Args:
        log_psi_apply (Callable): computes log|psi(x)|, where the signature of this
            function is (params, x) -> log|psi(x)|
        mean_grad_fn (Callable): function which is used to average the local gradient
            terms over all local devices. Has the signature local_grads -> avg_grad / 2,
            and should only average over the batch axis 0.

    Returns:
        Callable: function which computes the backward pass in the custom vjp of the
        total energy. Has the signature (res, cotangents) -> (gradients, None)
    """

    def scaled_by_local_e(
        params: P, positions: Array, centered_local_energies: Array
    ) -> jnp.float32:
        log_psi = log_psi_apply(params, positions)
        # NOTE(review): presumably registers log_psi with the KFAC loss registry so a
        # KFAC optimizer can build its curvature estimate; confirm against
        # `loss_functions`.
        loss_functions.register_normal_predictive_distribution(log_psi[:, None])
        # centered_local_energies is passed in as data, so the gradient w.r.t. params
        # below only flows through log_psi.
        return 2.0 * mean_grad_fn(centered_local_energies * log_psi)

    _get_energy_grad = jax.grad(scaled_by_local_e, argnums=0)

    def energy_bwd(res, cotangents) -> Tuple[P, None]:
        # res is the residual tuple produced by the forward pass.
        energy, local_energies, params, positions = res
        centered_local_energies = local_energies - energy
        gradient = _get_energy_grad(params, positions, centered_local_energies)
        # Scale by the energy cotangent; no gradient is propagated to the positions.
        return jax.tree_util.tree_map(lambda x: x * cotangents[0], gradient), None

    return energy_bwd

create_value_and_grad_energy_fn(log_psi_apply, local_energy_fn, nchains, clipping_fn=None, nan_safe=True, get_energy_bwd=get_default_energy_bwd)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `log_psi_apply` | `Callable` | computes log\|psi(x)\|, where the signature of this function is `(params, x) -> log|psi(x)|` | *required* |
| `local_energy_fn` | `Callable` | computes local energies Hpsi / psi. Has signature `(params, x) -> (Hpsi / psi)(x)` | *required* |
| `nchains` | `int` | total number of chains across all devices, used to compute a sample variance estimate of the local energy | *required* |
| `clipping_fn` | `Callable` | post-processing function on the local energy, e.g. a function which clips the values to be within some multiple of the total variation from the median. The post-processed values are used for the gradient calculation, if available. | `None` |
| `nan_safe` | `bool` | flag which controls whether `jnp.nanmean` and `jnp.nansum` are used instead of `jnp.mean` and `jnp.sum` for the terms in the gradient calculation. Can be set to False when debugging to find the source of unexpected nans. | `True` |
| `get_energy_bwd` | `Callable` | function which returns a custom backward pass for the total energy calculation. Has the signature `(log_psi_apply, mean_grad_fn) -> energy_bwd`. The default, `get_default_energy_bwd`, computes the gradient of 2 * mean((E_L - E) * log\|psi\|) with respect to the parameters, holding the centered local energies (E_L - E) fixed. | `get_default_energy_bwd` |

Returns:

| Type | Description |
| --- | --- |
| `Callable` | function which computes the clipped energy value and gradient. Has the signature `(params, x) -> ((expected_energy, auxiliary_energy_data), grad_energy)`, where `auxiliary_energy_data` is the tuple `(expected_variance, local_energies, unclipped_energy, unclipped_variance)` |

def create_value_and_grad_energy_fn(
    log_psi_apply: ModelApply[P],
    local_energy_fn: ModelApply[P],
    nchains: int,
    clipping_fn: Optional[Callable[[Array], Array]] = None,
    nan_safe: bool = True,
    get_energy_bwd: Callable = get_default_energy_bwd,
) -> ValueGradEnergyFn[P]:
    """Create a function which computes the energy, auxiliary data, and energy gradient.

    The energy is the (optionally clipped) average local energy E. Its gradient with
    respect to the parameters comes from a custom vjp. The default backward pass
    differentiates

        2 * mean((E_L(x) - E) * log|psi(x)|)

    with respect to the parameters, holding the centered local energies fixed.

    Args:
        log_psi_apply (Callable): computes log|psi(x)|, where the signature of this
            function is (params, x) -> log|psi(x)|
        local_energy_fn (Callable): computes local energies Hpsi / psi. Has signature
            (params, x) -> (Hpsi / psi)(x)
        nchains (int): total number of chains across all devices, used to compute a
            sample variance estimate of the local energy
        clipping_fn (Callable, optional): post-processing function on the local energy,
            e.g. a function which clips the values to be within some multiple of the
            total variation from the median. The post-processed values are used for
            the gradient calculation, if available. Defaults to None.
        nan_safe (bool, optional): flag which controls if jnp.nanmean and jnp.nansum are
            used instead of jnp.mean and jnp.sum for the terms in the gradient
            calculation. Can be set to False when debugging if trying to find the source
            of unexpected nans. Defaults to True.
        get_energy_bwd (Callable): function which returns a custom backward pass for the
            total energy calculation. Has the signature
                (log_psi_apply, mean_grad_fn) -> energy_bwd.
            Defaults to get_default_energy_bwd, which computes the formula above.

    Returns:
        Callable: function which computes the clipped energy value and gradient. Has the
        signature
            (params, x)
            -> ((expected_energy, auxiliary_energy_data), grad_energy),
        where auxiliary_energy_data is the tuple
        (expected_variance, local_energies, unclipped_energy, unclipped_variance)
    """

    # A custom vjp is required so that the backward pass registered via defvjp below
    # replaces differentiating through the statistics computation directly.
    @jax.custom_vjp
    def compute_energy_data(params: P, positions: Array) -> EnergyData:
        local_energies_noclip = local_energy_fn(params, positions)
        if clipping_fn is not None:
            local_energies = clipping_fn(local_energies_noclip)
            energy, variance = get_statistics_from_local_energy(
                local_energies, nchains, nan_safe=nan_safe
            )

            # For the unclipped metrics, which are not used in the gradient, don't
            # do these in a nan-safe way. This makes nans more visible and makes sure
            # the command-line checkpoint_if_nans flag will work properly.
            energy_noclip, variance_noclip = get_statistics_from_local_energy(
                local_energies_noclip, nchains, nan_safe=False
            )
        else:
            local_energies = local_energies_noclip
            energy, variance = get_statistics_from_local_energy(
                local_energies, nchains, nan_safe=nan_safe
            )

            # Even though there's no clipping function, still record noclip metrics
            # without nan-safety so that checkpointing epochs with nans can be
            # supported.
            energy_noclip, variance_noclip = get_statistics_from_local_energy(
                local_energies, nchains, nan_safe=False
            )

        aux_data = (variance, local_energies, energy_noclip, variance_noclip)
        return energy, aux_data

    def energy_fwd(params: P, positions: Array):
        """Forward pass: return the energy data plus the residuals energy_bwd needs."""
        output = compute_energy_data(params, positions)
        energy = output[0]
        local_energies = output[1][1]
        return output, (energy, local_energies, params, positions)

    mean_grad_fn = utils.distribute.get_mean_over_first_axis_fn(nan_safe=nan_safe)
    energy_bwd = get_energy_bwd(log_psi_apply, mean_grad_fn)

    compute_energy_data.defvjp(energy_fwd, energy_bwd)

    energy_data_val_and_grad = jax.value_and_grad(
        compute_energy_data, argnums=0, has_aux=True
    )
    return energy_data_val_and_grad