
sr

Stochastic reconfiguration (SR) routine.

SRMode (Enum)

Modes for computing the preconditioning by the Fisher inverse during SR.

If LAZY, then uses composed jvp and vjp calls to lazily compute the various Jacobian-vector products. This is more computationally and memory-efficient. If DEBUG, then directly computes the Jacobian (per-example gradients) and uses jnp.matmul to compute the Jacobian-vector products. Defaults to LAZY.
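
For reference, the enum as described can be sketched as follows; LAZY = 1 matches the default shown in the signature below, while DEBUG's integer value is an assumption here rather than something taken from the source.

# Sketch of SRMode as described above; DEBUG's integer value is assumed.
from enum import Enum, auto

class SRMode(Enum):
    LAZY = auto()   # composed jvp/vjp calls, lazy Jacobian-vector products
    DEBUG = auto()  # explicit per-example Jacobian plus jnp.matmul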

get_fisher_inverse_fn(log_psi_apply, mean_grad_fn, damping=0.001, maxiter=None, mode=SRMode.LAZY)

Get a Fisher-preconditioned update.

Given a gradient update grad_E, the function returned here approximates

(0.25 * F + damping * I)^{-1} * grad_E,

where F is the Fisher information matrix. The inversion is approximated via the conjugate gradient algorithm (possibly truncated to a finite number of iterations).

This preconditioned gradient update, when used as-is, is also known as the stochastic reconfiguration algorithm. See https://arxiv.org/pdf/1909.02487.pdf, Appendix C for the connection between natural gradient descent and stochastic reconfiguration.
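
As a minimal, self-contained illustration (not vmcnet code), the same damped solve can be reproduced with jax.scipy.sparse.linalg.cg and a matrix-free operator standing in for 0.25 * F; the toy matrix and vector below are arbitrary.

# Toy sketch of the CG-based approximation of (0.25 * F + damping * I)^{-1} grad_E.
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

damping = 1e-3
quarter_F = jnp.array([[1.0, 0.5], [0.5, 2.0]])  # stand-in for 0.25 * F (symmetric PSD)
grad_E = jnp.array([0.3, -0.7])                  # stand-in for the gradient update

def damped_fisher_apply(x):
    # applies (0.25 * F + damping * I) to x without forming the damped matrix
    return quarter_F @ x + damping * x

preconditioned_grad, _ = cg(damped_fisher_apply, grad_E, x0=grad_E, maxiter=10)

The function returned by get_fisher_inverse_fn follows the same pattern, except that the Fisher-vector product is formed over the parameter pytree (lazily via jvp/vjp in LAZY mode, or from explicit per-example gradients in DEBUG mode).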

Parameters:

    log_psi_apply (Callable, required): computes log|psi(x)|, where the
        signature of this function is (params, x) -> log|psi(x)|
    mean_grad_fn (Callable, required): function which is used to average the
        local gradient terms over all local devices. Has the signature
        local_grads -> avg_grad / 2, and should only average over the batch
        axis 0.
    damping (float, default 0.001): multiple of the identity to add to the
        Fisher before inverting. Without this term, the approximation to the
        Fisher will always be less than full rank when nchains < nparams, and
        so CG will fail to converge. This should be tuned together with the
        learning rate.
    maxiter (int, default None): maximum number of CG iterations to do when
        computing the inverse application of the Fisher. Defaults to None,
        which uses maxiter equal to 10 * number of params.
    mode (SRMode, default SRMode.LAZY): mode of computing the forward
        Fisher-vector products. If LAZY, then uses composed jvp and vjp calls
        to lazily compute the various Jacobian-vector products. This is more
        computationally and memory-efficient. If DEBUG, then directly computes
        the Jacobian (per-example gradients) and uses jnp.matmul to compute
        the Jacobian-vector products.

Returns:

    Callable: function which computes the gradient preconditioned with the
        inverse of the Fisher information matrix. Has the signature
        (energy_grad, params, positions) -> preconditioned_grad (a usage
        sketch follows below).
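
Below is a hypothetical end-to-end usage sketch. The toy log|psi|, the single-device mean_grad_fn, and all shapes are illustrative assumptions rather than vmcnet's actual models or distribution utilities, and it assumes the pmean inside the returned function is a no-op when run on a single device outside pmap.

# Hypothetical usage sketch; not taken verbatim from vmcnet.
import functools

import jax
import jax.numpy as jnp

from vmcnet.updates.sr import SRMode, get_fisher_inverse_fn


def toy_log_psi_apply(params, x):
    # (params, x) -> log|psi(x)|, one value per walker in the batch
    return jnp.sum(jnp.tanh(x @ params["w"]), axis=(-2, -1))


# Naive single-device average over batch axis 0; the exact averaging convention
# should follow the mean_grad_fn contract described above.
mean_grad_fn = functools.partial(jnp.mean, axis=0)

params = {"w": jnp.ones((3, 2))}
positions = jax.random.normal(jax.random.PRNGKey(0), (16, 5, 3))  # (nchains, nelec, ndim)
energy_grad = {"w": 0.1 * jnp.ones((3, 2))}  # same pytree structure as params

precondition_fn = get_fisher_inverse_fn(
    toy_log_psi_apply, mean_grad_fn, damping=1e-3, maxiter=100, mode=SRMode.LAZY
)
sr_grad = precondition_fn(energy_grad, params, positions)  # same structure as params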

Source code in vmcnet/updates/sr.py
def get_fisher_inverse_fn(
    log_psi_apply: ModelApply[P],
    mean_grad_fn: Callable[[Array], Array],
    damping: float = 0.001,
    maxiter: Optional[int] = None,
    mode: SRMode = SRMode.LAZY,
):
    """Get a Fisher-preconditioned update.

    Given a gradient update grad_E, the function returned here approximates

        (0.25 * F + damping * I)^{-1} * grad_E,

    where F is the Fisher information matrix. The inversion is approximated via the
    conjugate gradient algorithm (possibly truncated to a finite number of iterations).

    This preconditioned gradient update, when used as-is, is also known as the
    stochastic reconfiguration algorithm. See https://arxiv.org/pdf/1909.02487.pdf,
    Appendix C for the connection between natural gradient descent and stochastic
    reconfiguration.

    Args:
        log_psi_apply (Callable): computes log|psi(x)|, where the signature of this
            function is (params, x) -> log|psi(x)|
        mean_grad_fn (Callable): function which is used to average the local gradient
            terms over all local devices. Has the signature local_grads -> avg_grad / 2,
            and should only average over the batch axis 0.
        damping (float, optional): multiple of the identity to add to the Fisher before
            inverting. Without this term, the approximation to the Fisher will always
            be less than full rank when nchains < nparams, and so CG will fail to
            converge. This should be tuned together with the learning rate. Defaults to
            0.001.
        maxiter (int, optional): maximum number of CG iterations to do when computing
            the inverse application of the Fisher. Defaults to None, which uses maxiter
            equal to 10 * number of params.
        mode (SRMode, optional): mode of computing the forward Fisher-vector products.
            If LAZY, then uses composed jvp and vjp calls to lazily compute the various
            Jacobian-vector products. This is more computationally and memory-efficient.
            If DEBUG, then directly computes the Jacobian (per-example gradients) and
            uses jnp.matmul to compute the Jacobian-vector products. Defaults to LAZY.

    Returns:
        Callable: function which computes the gradient preconditioned with the inverse
        of the Fisher information matrix. Has the signature
            (energy_grad, params, positions) -> preconditioned_grad
    """
    # TODO(Jeffmin): explore preconditioners for speeding up convergence and to provide
    # more stability
    # TODO(Jeffmin): investigate damping scheduling and possibly adaptive damping
    if mode == SRMode.DEBUG:

        def raveled_log_psi_grad(params: P, positions: Array) -> Array:
            log_grads = jax.grad(log_psi_apply)(params, positions)
            return jax.flatten_util.ravel_pytree(log_grads)[0]

        batch_raveled_log_psi_grad = jax.vmap(raveled_log_psi_grad, in_axes=(None, 0))

        def precondition_grad_with_fisher(
            energy_grad: P, params: P, positions: Array
        ) -> P:
            raveled_energy_grad, unravel_fn = jax.flatten_util.ravel_pytree(energy_grad)

            log_psi_grads = batch_raveled_log_psi_grad(params, positions)
            mean_log_psi_grads = mean_grad_fn(log_psi_grads)
            centered_log_psi_grads = (
                log_psi_grads - mean_log_psi_grads
            )  # shape (nchains, nparams)

            def fisher_apply(x: Array) -> Array:
                # x is shape (nparams,)
                nchains_local = centered_log_psi_grads.shape[0]
                centered_jacobian_vector_prod = jnp.matmul(centered_log_psi_grads, x)
                local_fisher_times_x = (
                    jnp.matmul(
                        jnp.transpose(centered_log_psi_grads),
                        centered_jacobian_vector_prod,
                    )
                    / nchains_local
                )
                fisher_times_x = pmean_if_pmap(local_fisher_times_x)
                return fisher_times_x + damping * x

            sr_grad, _ = jscp.sparse.linalg.cg(
                fisher_apply,
                raveled_energy_grad,
                x0=raveled_energy_grad,
                maxiter=maxiter,
            )

            return unravel_fn(sr_grad)

    elif mode == SRMode.LAZY:

        def precondition_grad_with_fisher(
            energy_grad: P, params: P, positions: Array
        ) -> P:
            def partial_log_psi_apply(params: P) -> Array:
                return log_psi_apply(params, positions)

            _, vjp_fn = jax.vjp(partial_log_psi_apply, params)

            def fisher_apply(x: Array) -> Array:
                # x is a pytree with same structure as params
                nchains_local = positions.shape[0]
                _, jacobian_vector_prod = jax.jvp(
                    partial_log_psi_apply, (params,), (x,)
                )
                mean_jacobian_vector_prod = mean_grad_fn(jacobian_vector_prod)
                centered_jacobian_vector_prod = (
                    jacobian_vector_prod - mean_jacobian_vector_prod
                )
                local_device_fisher_times_x = multiply_tree_by_scalar(
                    vjp_fn(centered_jacobian_vector_prod)[0], 1.0 / nchains_local
                )
                fisher_times_x = pmean_if_pmap(local_device_fisher_times_x)
                return tree_sum(fisher_times_x, multiply_tree_by_scalar(x, damping))

            sr_grad, _ = jscp.sparse.linalg.cg(
                fisher_apply,
                energy_grad,
                x0=energy_grad,
                maxiter=maxiter,
            )

            return sr_grad

    else:
        raise ValueError(
            "Requested Fisher apply mode not supported; only {} are supported, "
            "but {} was requested.".format(", ".join(SRMode.__members__.keys()), mode)
        )

    return precondition_grad_with_fisher