
kinetic

Kinetic energy terms.

create_continuous_kinetic_energy(log_psi_apply)

Create the local kinetic energy function (params, x) -> -0.5 (nabla^2 psi(x) / psi(x)).
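Since log_psi_apply returns log|psi| rather than psi, the ratio nabla^2 psi / psi has to be recovered from derivatives of log|psi|. For a real psi away from its nodes, the standard identity is the one below. It is presumably why the source passes jax.grad(log_psi_apply) to physics.core.laplacian_psi_over_psi, although that helper's body is not shown on this page.

```latex
\nabla \psi = \psi \, \nabla \log|\psi|, \qquad
\nabla^2 \psi = \nabla \cdot \left( \psi \, \nabla \log|\psi| \right)
             = \psi \left( \nabla^2 \log|\psi| + \lVert \nabla \log|\psi| \rVert^2 \right)

-\frac{1}{2} \frac{\nabla^2 \psi}{\psi}
  = -\frac{1}{2} \left( \nabla^2 \log|\psi| + \lVert \nabla \log|\psi| \rVert^2 \right)
```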

Parameters:

log_psi_apply (Callable, required): a function which computes log|psi(x)| for a single input x. It may also return batched outputs when given a batch of x, as long as it returns a single number for a single x. Its signature is (params, single_x_in) -> log|psi(single_x_in)|.
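As an illustration of this contract, here is a hypothetical log_psi_apply for the unnormalized Gaussian psi(x) = exp(-alpha ||x||^2). It assumes that params is a dict with an "alpha" entry and that a single x has shape (n_particles, dim); neither convention is taken from vmcnet. Summing over the last two axes returns one number for a single x and one number per configuration for a batch.

```python
import jax
import jax.numpy as jnp


def gaussian_log_psi(params, x):
    """Hypothetical log|psi| for psi(x) = exp(-alpha * ||x||^2).

    Summing over the last two axes gives a scalar for a single x of shape
    (n_particles, dim) and an array of shape (batch,) for a batch of x.
    """
    return -params["alpha"] * jnp.sum(x**2, axis=(-2, -1))


params = {"alpha": 0.5}
single_x = jnp.ones((3, 2))    # 3 particles in 2 dimensions
batch_x = jnp.ones((4, 3, 2))  # batch of 4 configurations

print(gaussian_log_psi(params, single_x).shape)  # ()
print(gaussian_log_psi(params, batch_x).shape)   # (4,)

# Differentiate with respect to x (argnums=1), as the source does.
grad_x = jax.grad(gaussian_log_psi, argnums=1)(params, single_x)
print(grad_x.shape)  # (3, 2)
```

The scalar output for a single x matters because jax.grad, which the source applies with argnums=1, only accepts functions with scalar outputs.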

Returns:

Callable: a function which computes the local kinetic energy -0.5 nabla^2 psi / psi for continuous problems (as opposed to discrete/lattice problems). Because it is wrapped in jax.vmap, it operates on batches and has the signature (params, x) -> kinetic energy array of shape (x.shape[0],).
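The returned function can be mimicked with a self-contained sketch that does not import vmcnet. It evaluates -0.5 (nabla^2 log|psi| + ||nabla log|psi|||^2) for one configuration, using a dense jax.hessian for simplicity, and wraps it in jax.vmap with in_axes=(None, 0) as the source does. The Gaussian trial function, the name make_kinetic_energy, and the (walkers, particles, dim) layout of x are assumptions for illustration. For psi(x) = exp(-alpha ||x||^2) over D coordinates, the exact local kinetic energy is alpha D - 2 alpha^2 ||x||^2, which the last line checks.

```python
import jax
import jax.numpy as jnp


def log_psi(params, x):
    # Hypothetical trial function: log|psi(x)| = -alpha * ||x||^2 for one x.
    return -params["alpha"] * jnp.sum(x**2)


def make_kinetic_energy(log_psi_apply):
    """Sketch: -0.5 * (lap log|psi| + |grad log|psi||^2), vmapped over x."""
    grad_fn = jax.grad(log_psi_apply, argnums=1)
    hess_fn = jax.hessian(log_psi_apply, argnums=1)

    def single_ke(params, x):
        g = grad_fn(params, x)
        h = hess_fn(params, x)            # shape x.shape + x.shape
        n = x.size
        lap = jnp.trace(h.reshape(n, n))  # Laplacian of log|psi|
        return -0.5 * (lap + jnp.sum(g**2))

    # params are shared across the batch (None); x is mapped over axis 0.
    return jax.vmap(single_ke, in_axes=(None, 0), out_axes=0)


params = {"alpha": 0.5}
x = jax.random.normal(jax.random.PRNGKey(0), (8, 3, 2))  # 8 walkers, 3 particles, 2D
ke = make_kinetic_energy(log_psi)(params, x)
print(ke.shape)  # (8,)

# Exact value for this Gaussian: alpha * D - 2 * alpha^2 * ||x||^2, with D = 3 * 2.
alpha, D = params["alpha"], 6
exact = alpha * D - 2 * alpha**2 * jnp.sum(x**2, axis=(1, 2))
print(jnp.allclose(ke, exact, atol=1e-4))  # True
```

A dense Hessian needs memory quadratic in the number of coordinates, so this form only suits small examples; the vmcnet helper may compute the Laplacian differently.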

Source code in vmcnet/physics/kinetic.py
```python
def create_continuous_kinetic_energy(
    log_psi_apply: Callable[[P, Array], Union[jnp.float32, Array]]
) -> ModelApply[P]:
    """Create the local kinetic energy fn (params, x) -> -0.5 (nabla^2 psi(x) / psi(x)).

    Args:
        log_psi_apply (Callable): a function which computes log|psi(x)| for single
            inputs x. It is okay for it to produce batch outputs on batches of x as long
            as it produces a single number for single x. Has the signature
            (params, single_x_in) -> log|psi(single_x_in)|

    Returns:
        Callable: function which computes the local kinetic energy for continuous
        problems (as opposed to discrete/lattice problems), i.e. -0.5 nabla^2 psi / psi.
        Evaluates on batches due to the jax.vmap call, so it has signature
        (params, x) -> kinetic energy array with shape (x.shape[0],)
    """
    grad_log_psi_apply = jax.grad(log_psi_apply, argnums=1)

    def kinetic_energy_fn(params: P, x: Array) -> jnp.float32:
        return -0.5 * physics.core.laplacian_psi_over_psi(grad_log_psi_apply, params, x)

    return jax.vmap(kinetic_energy_fn, in_axes=(None, 0), out_axes=0)
```
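The listing delegates the second-derivative work to physics.core.laplacian_psi_over_psi, whose body is not shown on this page. One common way to build such a helper from grad_log_psi_apply alone is forward-over-reverse differentiation: a jax.jvp of the gradient along a basis vector e_i gives the Hessian-vector product H e_i, and e_i . (H e_i) is the diagonal entry H_ii. The function below, laplacian_psi_over_psi_sketch, is a hypothetical stand-in with the same argument order, not vmcnet's implementation; it assumes a real psi, and the trial function is likewise made up for testing.

```python
import jax
import jax.numpy as jnp


def laplacian_psi_over_psi_sketch(grad_log_psi_apply, params, x):
    """Hypothetical helper: nabla^2 psi / psi for one (unbatched) x.

    Uses nabla^2 psi / psi = nabla^2 log|psi| + ||nabla log|psi|||^2, with the
    Laplacian of log|psi| taken as the trace of a forward-over-reverse Hessian.
    """
    x_flat = x.reshape(-1)

    def grad_flat(y_flat):
        # Gradient of log|psi| with respect to the flattened coordinates.
        return grad_log_psi_apply(params, y_flat.reshape(x.shape)).reshape(-1)

    def diag_term(e):
        # jvp returns (gradient at x, H @ e); e . (H @ e) is one diagonal entry of H.
        _, hess_e = jax.jvp(grad_flat, (x_flat,), (e,))
        return jnp.dot(e, hess_e)

    basis = jnp.eye(x_flat.shape[0], dtype=x_flat.dtype)
    lap_log_psi = jnp.sum(jax.vmap(diag_term)(basis))
    return lap_log_psi + jnp.sum(grad_flat(x_flat) ** 2)


def log_psi(params, x):
    # Hypothetical non-Gaussian trial function, used only to exercise the helper.
    return -params["a"] * jnp.sum(jnp.sqrt(1.0 + x**2))


grad_log_psi = jax.grad(log_psi, argnums=1)
kinetic_energy = jax.vmap(
    lambda p, x: -0.5 * laplacian_psi_over_psi_sketch(grad_log_psi, p, x),
    in_axes=(None, 0),
    out_axes=0,
)

params = {"a": 0.7}
xs = jax.random.normal(jax.random.PRNGKey(1), (5, 4, 3))  # 5 walkers, 4 particles, 3D
ke = kinetic_energy(params, xs)
print(ke.shape)  # (5,)
```

For this trial function each coordinate contributes -0.5 (-a (1 + x_i^2)^(-3/2) + a^2 x_i^2 / (1 + x_i^2)), which gives an independent check of the sketch. Because jax.vmap batches all n Hessian-vector products, this variant does not save memory over a dense Hessian; looping over coordinates (for example with jax.lax.fori_loop) is one way to trade speed for memory.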