sign_symmetry

Routines for symmetrizing a function to be sign covariant.

ProductsSignCovariance (Module) dataclass

Sign covariance from a weighted sum of products of per-particle values.

Currently supports only two spins. Given per-spin antiequivariant vectors a_1, a_2, ..., and b_1, b_2, ..., computes an antisymmetry of the form sum_{i,j} (w_{i,j} sum_{k} a_{ik} b_{jk}), or several such antisymmetries if features > 1. If use_weights=False, no weights are used, which is equivalent to setting w_{i,j} = 1 for all i, j.
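As a minimal sketch of the unweighted case (w_{i,j} = 1), the sum above reduces to a single contraction. The helper name `unweighted_products_covariance` and the input values are illustrative assumptions and are not part of vmcnet. Because the sum is linear in each spin's vectors, negating all vectors of one spin negates the result.

```python
import jax.numpy as jnp


def unweighted_products_covariance(a, b):
    """Sum over all pairs (i, j) of the dot products a_i . b_j."""
    # a: (nelec_up, d), b: (nelec_down, d). Contract over k, then sum over i and j.
    return jnp.einsum("ik,jk->", a, b)


# Illustrative values only: two up-spin and three down-spin vectors with d = 2.
a = jnp.array([[1.0, 2.0], [3.0, -1.0]])
b = jnp.array([[0.5, 1.0], [2.0, 0.0], [-1.0, 4.0]])

value = unweighted_products_covariance(a, b)
flipped_up = unweighted_products_covariance(-a, b)  # one sign flip: negated
flipped_both = unweighted_products_covariance(-a, -b)  # two sign flips: unchanged
```

With weights enabled, the module instead passes the flattened pairwise dot products through a bias-free dense layer, so each output feature has its own w_{i,j}.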

Attributes:

features (int): the number of antisymmetric output features to generate. If use_weights is False, must be equal to 1.

kernel_init (WeightInitializer): initializer for the weights of the dense layer.

register_kfac (bool): whether to register the dense layer with KFAC. Defaults to True.

use_weights (bool): whether to use a weighted sum of products. Defaults to False.

__call__(self, x) special

Calculate weighted sum of products of up- and down-spin antiequivariances.

Parameters:

x (ArrayList, required): input antiequivariant arrays of shape [(..., nelec_up, d), (..., nelec_down, d)]

Returns:

Array: array of length features of antisymmetric values calculated by taking a weighted sum of the pairwise dot-products of the up- and down-spin antiequivariant inputs.

Source code in vmcnet/models/sign_symmetry.py
@flax.linen.compact
def __call__(self, x: ArrayList) -> Array:  # type: ignore[override]
    """Calculate weighted sum of products of up- and down-spin antiequivariances.

    Arguments:
        x (ArrayList): input antiequivariant arrays of shape
            [(..., nelec_up, d), (..., nelec_down, d)]

    Returns:
        Array: array of length features of antisymmetric values calculated
            by taking a weighted sum of the pairwise dot-products of the up- and
            down-spin antiequivariant inputs.
    """
    # TODO (ggoldsh): update this to support nspins != 2 as well
    if len(x) != 2:
        raise ValueError(
            "Products covariance only supported for nspins=2, got {}".format(len(x))
        )

    naxes = len(x[0].shape)
    # All leading axes are batch axes; the last axis (d) is the one contracted.
    batch_dims = tuple(range(naxes - 2))
    contraction_dim = (naxes - 1,)
    # Since the second-to-last axis is not specified as either a batch or
    # contraction dim, jax.lax.dot_general will automatically compute over all pairs
    # of up and down spins. pairwise_dots thus has shape (..., nelec_up, nelec_down).
    pairwise_dots = jax.lax.dot_general(
        x[0], x[1], ((contraction_dim, contraction_dim), (batch_dims, batch_dims))
    )

    if not self.use_weights:
        if self.features != 1:
            raise ValueError(
                "Can only return one output feature when use_weights is False. "
                "Received {} for features.".format(self.features)
            )
        return jnp.expand_dims(jnp.sum(pairwise_dots, axis=(-1, -2)), -1)

    shape = pairwise_dots.shape
    # flattened_dots has shape (..., nelec_up * nelec_down)
    flattened_dots = jnp.reshape(
        pairwise_dots, (*shape[:-2], shape[-1] * shape[-2])
    )

    return Dense(
        self.features,
        kernel_init=self.kernel_init,
        use_bias=False,
        register_kfac=self.register_kfac,
    )(flattened_dots)
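The pairwise contraction in the listing above can be checked in isolation. The following sketch uses made-up shapes (a batch of 4, 3 up-spin and 2 down-spin particles, d = 5) and compares `jax.lax.dot_general` against an equivalent `jnp.einsum`. It covers only the array manipulation, not the Flax module or the KFAC-registered `Dense` layer.

```python
import jax
import jax.numpy as jnp

key_up, key_down = jax.random.split(jax.random.PRNGKey(0))
up = jax.random.normal(key_up, (4, 3, 5))  # (batch, nelec_up, d)
down = jax.random.normal(key_down, (4, 2, 5))  # (batch, nelec_down, d)

naxes = up.ndim
batch_dims = tuple(range(naxes - 2))  # leading axes are batched
contraction_dim = (naxes - 1,)  # contract over the last axis d

# The particle axis is neither batched nor contracted, so every (up, down) pair
# gets its own dot product: shape (batch, nelec_up, nelec_down).
pairwise_dots = jax.lax.dot_general(
    up, down, ((contraction_dim, contraction_dim), (batch_dims, batch_dims))
)
reference = jnp.einsum("bik,bjk->bij", up, down)

# Unweighted branch: sum over all pairs, keep a trailing feature axis of size 1.
unweighted = jnp.expand_dims(jnp.sum(pairwise_dots, axis=(-1, -2)), -1)

# Weighted branch: flatten the pair axes before the dense layer is applied.
flattened_dots = jnp.reshape(pairwise_dots, (*pairwise_dots.shape[:-2], -1))
```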

apply_sign_symmetry_to_fn(fn, get_signs_and_syms, apply_output_signs, add_up_results)

Make a function of a list of inputs covariant in the sign of each input.

That is, output a function g(s_1, s_2, ..., s_n) which is odd with respect to each input, such that g(s_1, ...) = -g(-s_1, ...), and likewise for every other input.

This is done by taking the orbit of the inputs with respect to the sign group applied separately to each one, and adding up the results with appropriate covariant signs. For example, for two spins this calculates g(U,D) = f(U,D) - f(-U,D) - f(U, -D) + f(-U, -D).
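To see why the two-spin formula is odd in U, substitute -U for U and regroup. The general n-input form in the first line is one reading of the description above, written out here as an assumption rather than quoted from the library.

```latex
\begin{aligned}
g(s_1, \dots, s_n) &= \sum_{\sigma \in \{\pm 1\}^n}
    \Bigl(\prod_{i=1}^{n} \sigma_i\Bigr)\, f(\sigma_1 s_1, \dots, \sigma_n s_n), \\
g(-U, D) &= f(-U, D) - f(U, D) - f(-U, -D) + f(U, -D) \\
         &= -\bigl[f(U, D) - f(-U, D) - f(U, -D) + f(-U, -D)\bigr] \\
         &= -g(U, D).
\end{aligned}
```

The same regrouping with D in place of U gives g(U, -D) = -g(U, D).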

Inputs are assumed to be either Arrays or SLArrays, so that in either case the required symmetries can be stacked along a new axis of the underlying array values. The function fn is assumed to support the injection of a batch dimension, done in get_signs_and_syms, and pass it through to the output (e.g., a function which flattens the input would not be supported). The extra dimension is removed at the end via add_up_results.

Parameters:

fn (Callable, required): the function to symmetrize. The given axis is injected into the inputs and the sign orbit is computed, so this function should be able to treat the given sign orbit axis as a batch dimension, and the overall tensor rank should not change (len(input.shape) == len(output.shape)).

get_signs_and_syms (Callable, required): a function which gets the signs and symmetries for the input array. Returns a tuple of the symmetries plus the associated signs as a 1D array.

apply_output_signs (Callable, required): function for applying signs to the outputs of the symmetrized function. For example, if the outputs are Arrays, this would simply multiply the arrays by the signs along the appropriate axis.

add_up_results (Callable, required): function for combining the signed outputs into a single, sign-covariant output. For example, simple addition for Arrays or the slog_sum function for SLArrays.

Returns:

Callable: a function with the same signature as the input function, but which has been symmetrized so that its output will be covariant with respect to the sign of each input, or in other words, will be odd.

Source code in vmcnet/models/sign_symmetry.py
def apply_sign_symmetry_to_fn(
    fn: Callable[[List[A]], A],
    get_signs_and_syms: Callable[[List[A]], Tuple[List[A], Array]],
    apply_output_signs: Callable[[A, Array], A],
    add_up_results: Callable[[A], A],
) -> Callable[[List[A]], A]:
    """Make a function of a list of inputs covariant in the sign of each input.

    That is, output a function g(s_1, s_2, ..., s_n) which is odd with respect to
    each input, such that g(s_1, ...) = -g(-s_1, ...), and likewise for every other
    input.

    This is done by taking the orbit of the inputs with respect to the sign group
    applied separately to each one, and adding up the results with appropriate
    covariant signs. For example, for two spins this calculates
    g(U,D) = f(U,D) - f(-U,D) - f(U, -D) + f(-U, -D).

    Inputs are assumed to be either Arrays or SLArrays, so that in either case the
    required symmetries can be stacked along a new axis of the underlying array values.
    The function `fn` is assumed to support the injection of a batch dimension, done in
    `get_signs_and_syms`, and pass it through to the output (e.g., a function which
    flattens the input would not be supported). The extra dimension is removed at the
    end via `add_up_results`.

    Args:
        fn (Callable): the function to symmetrize. The given axis is injected into the
            inputs and the sign orbit is computed, so this function should be able to
            treat the given sign orbit axis as a batch dimension, and the overall tensor
            rank should not change (len(input.shape) == len(output.shape))
        get_signs_and_syms (Callable): a function which gets the signs and symmetries
            for the input array. Returns a tuple of the symmetries plus the associated
            signs as a 1D array.
        apply_output_signs (Callable): function for applying signs to the outputs of
            the symmetrized function. For example, if the outputs are Arrays, this
            would simply multiply the arrays by the signs along the appropriate axis.
        add_up_results (Callable): function for combining the signed outputs into a
            single, sign-covariant output. For example, simple addition for Arrays
            or the slog_sum function for SLArrays.

    Returns:
        Callable: a function with the same signature as the input function, but
        which has been symmetrized so that its output will be covariant with respect
        to the sign of each input, or in other words, will be odd.
    """

    def sign_covariant_fn(x: List[A]) -> A:
        symmetries, signs = get_signs_and_syms(x)
        outputs = fn(symmetries)
        signed_results = apply_output_signs(outputs, signs)
        return add_up_results(signed_results)

    return sign_covariant_fn
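A usage sketch of the generic wrapper on plain arrays follows. `apply_sign_symmetry_to_fn` is repeated from the listing above so the example is self-contained. Everything prefixed `toy_`, the function `f`, and the choice to stack the orbit along axis 0 are hypothetical and chosen for illustration; the library's own helpers may differ. For this particular `f`, the symmetrized result has the closed form 4 * sum_k sinh(u_k) sinh(2 d_k), which gives an independent check.

```python
import itertools
import math

import jax.numpy as jnp


def apply_sign_symmetry_to_fn(fn, get_signs_and_syms, apply_output_signs, add_up_results):
    # Same composition as the listing above, repeated for self-containment.
    def sign_covariant_fn(x):
        symmetries, signs = get_signs_and_syms(x)
        outputs = fn(symmetries)
        signed_results = apply_output_signs(outputs, signs)
        return add_up_results(signed_results)

    return sign_covariant_fn


def toy_get_signs_and_syms(x):
    # Hypothetical helper: stack every sign combination along a new leading axis.
    combos = list(itertools.product([1.0, -1.0], repeat=len(x)))
    syms = [jnp.stack([c[i] * xi for c in combos], axis=0) for i, xi in enumerate(x)]
    signs = jnp.array([math.prod(c) for c in combos])  # product of per-input signs
    return syms, signs


def f(xs):
    # Neither odd nor even; keeps the orbit axis (axis 0) and the tensor rank.
    return jnp.sum(jnp.exp(xs[0] + 2.0 * xs[1]), axis=-1, keepdims=True)


g = apply_sign_symmetry_to_fn(
    f,
    toy_get_signs_and_syms,
    lambda out, signs: out * signs[:, None],  # sign each orbit element
    lambda out: jnp.sum(out, axis=0),  # collapse the orbit axis
)

u = jnp.array([0.3, -1.2])
d = jnp.array([0.5, 0.1])
closed_form = 4.0 * jnp.sum(jnp.sinh(u) * jnp.sinh(2.0 * d))
```

Calling `g([u, d])` should agree with `closed_form` up to floating-point error, and negating either `u` or `d` should negate the output.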

make_sl_array_list_fn_sign_covariant(fn, axis=-2)

Make a function of an SLArrayList sign-covariant in the sign of each SLArray.

Shallow wrapper around the generic apply_sign_symmetry_to_fn.

Parameters:

fn (Callable, required): the function to symmetrize. The given axis is injected into the inputs and the sign orbit is computed, so this function should be able to treat the given sign orbit axis as a batch dimension, and the overall tensor rank should not change (len(input.shape) == len(output.shape)).

axis (int, default -2): the axis along which the sign orbit is injected into the inputs and over which the signed outputs are summed.

Returns:

Callable: a function with the same signature as the input function, but which has been symmetrized so that its output will be covariant with respect to the sign of each input, or in other words, will be odd.

Source code in vmcnet/models/sign_symmetry.py
def make_sl_array_list_fn_sign_covariant(
    fn: Callable[[SLArrayList], SLArray], axis: int = -2
) -> Callable[[SLArrayList], SLArray]:
    """Make a function of an SLArrayList sign-covariant in the sign of each SLArray.

    Shallow wrapper around the generic apply_sign_symmetry_to_fn.

    Args:
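        fn (Callable): the function to symmetrize. The given axis is injected into the
            inputs and the sign orbit is computed, so this function should be able to
            treat the given sign orbit axis as a batch dimension, and the overall tensor
            rank should not change (len(input.shape) == len(output.shape))
        axis (int, optional): the axis along which the sign orbit is injected into the
            inputs and over which the signed outputs are summed. Defaults to -2.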

    Returns:
        Callable: a function with the same signature as the input function, but
        which has been symmetrized so that its output will be covariant with respect
        to the sign of each input, or in other words, will be odd.
    """
    return apply_sign_symmetry_to_fn(
        fn,
        functools.partial(_get_sign_orbit_sl_array_list, axis=axis),
        lambda x, s: (_multiply_sign_along_axis(x[0], s, axis), x[1]),
        functools.partial(slog_sum_over_axis, axis=axis),
    )
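For SLArrays, the listing above treats each value as a (sign, log) pair: orbit signs multiply only the sign part, and the reduction is a signed sum carried out in the log domain. The sketch below shows one way such a reduction could work. `toy_slog_sum_over_axis` is a hypothetical stand-in written for this example and is not the library's `slog_sum_over_axis`; the input values are made up.

```python
import jax.numpy as jnp


def toy_slog_sum_over_axis(x, axis=-2):
    # Hypothetical stand-in: sums sign * exp(log) along `axis` and returns the
    # result as a (sign, log|value|) pair.
    sign_vals, log_vals = x
    max_log = jnp.max(log_vals, axis=axis, keepdims=True)
    # Factor out the largest magnitude so the exponentials cannot overflow.
    scaled = jnp.sum(sign_vals * jnp.exp(log_vals - max_log), axis=axis)
    return jnp.sign(scaled), jnp.log(jnp.abs(scaled)) + jnp.squeeze(max_log, axis=axis)


# Four orbit outputs along axis -2 with values +2, +5, -3, -2 in (sign, log) form.
orbit_sign = jnp.array([[1.0], [1.0], [-1.0], [-1.0]])
orbit_log = jnp.log(jnp.array([[2.0], [5.0], [3.0], [2.0]]))
orbit_signs = jnp.array([1.0, -1.0, -1.0, 1.0])

# Applying the orbit signs touches only the sign part; the logs are unchanged.
signed = (orbit_sign * orbit_signs[:, None], orbit_log)

# Signed sum: 2 - 5 + 3 - 2 = -2, i.e. sign -1 and log|value| = log(2).
total_sign, total_log = toy_slog_sum_over_axis(signed, axis=-2)
```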

make_array_list_fn_sign_covariant(fn, axis=-2)

Make a function of an ArrayList sign-covariant in the sign of each array.

Shallow wrapper around the generic apply_sign_symmetry_to_fn.

Parameters:

fn (Callable, required): the function to symmetrize. The given axis is injected into the inputs and the sign orbit is computed, so this function should be able to treat the given sign orbit axis as a batch dimension, and the overall tensor rank should not change (len(input.shape) == len(output.shape)).

axis (int, default -2): the axis along which the sign orbit is injected into the inputs and over which the signed outputs are summed.

Returns:

Callable: a function with the same signature as the input function, but which has been symmetrized so that its output will be covariant with respect to the sign of each input, or in other words, will be odd.

Source code in vmcnet/models/sign_symmetry.py
def make_array_list_fn_sign_covariant(
    fn: Callable[[ArrayList], Array], axis: int = -2
) -> Callable[[ArrayList], Array]:
    """Make a function of an ArrayList sign-covariant in the sign of each array.

    Shallow wrapper around the generic apply_sign_symmetry_to_fn.

    Args:
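        fn (Callable): the function to symmetrize. The given axis is injected into the
            inputs and the sign orbit is computed, so this function should be able to
            treat the given sign orbit axis as a batch dimension, and the overall tensor
            rank should not change (len(input.shape) == len(output.shape))
        axis (int, optional): the axis along which the sign orbit is injected into the
            inputs and over which the signed outputs are summed. Defaults to -2.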

    Returns:
        Callable: a function with the same signature as the input function, but
        which has been symmetrized so that its output will be covariant with respect
        to the sign of each input, or in other words, will be odd.
    """
    return apply_sign_symmetry_to_fn(
        fn,
        functools.partial(_get_sign_orbit_array_list, axis=axis),
        functools.partial(_multiply_sign_along_axis, axis=axis),
        functools.partial(jnp.sum, axis=axis),
    )
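The last two arguments in the listing above sign the orbit outputs along `axis` and then sum that axis away. A minimal sketch of that step is given here. `toy_multiply_sign_along_axis` is a hypothetical stand-in for the private helper `_multiply_sign_along_axis`, whose actual implementation is not shown on this page, and the `outputs` array is made up to stand for the result of `fn` on a two-spin orbit.

```python
import jax.numpy as jnp


def toy_multiply_sign_along_axis(x, signs, axis):
    # Hypothetical stand-in: reshape the 1D signs so that they broadcast against
    # `x` along `axis` only.
    shape = [1] * x.ndim
    shape[axis] = signs.shape[0]
    return x * jnp.reshape(signs, shape)


# Pretend these are outputs of fn over a two-spin orbit: (batch, orbit, features).
outputs = jnp.arange(24.0).reshape(2, 4, 3) ** 2
signs = jnp.array([1.0, -1.0, -1.0, 1.0])

signed = toy_multiply_sign_along_axis(outputs, signs, axis=-2)
covariant = jnp.sum(signed, axis=-2)  # orbit axis removed: shape (2, 3)
```

For this particular input every entry of `covariant` equals 36, since for consecutive orbit values x, x+3, x+6, x+9 the signed sum of squares is x^2 - (x+3)^2 - (x+6)^2 + (x+9)^2 = 36.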

make_array_list_fn_sign_invariant(fn, axis=-2)

Make a function of an ArrayList sign-invariant (even) in the sign of each array.

Shallow wrapper around the generic apply_sign_symmetry_to_fn.

Parameters:

fn (Callable, required): the function to symmetrize. The given axis is injected into the inputs and the sign orbit is computed, so this function should be able to treat the given sign orbit axis as a batch dimension, and the overall tensor rank should not change (len(input.shape) == len(output.shape)).

axis (int, default -2): the axis along which the sign orbit is injected into the inputs and over which the outputs are summed.

Returns:

Callable: a function with the same signature as the input function, but which has been symmetrized so that its output will be invariant with respect to the sign of each input, or in other words, will be even.

Source code in vmcnet/models/sign_symmetry.py
def make_array_list_fn_sign_invariant(
    fn: Callable[[ArrayList], Array], axis: int = -2
) -> Callable[[ArrayList], Array]:
    """Make a function of an ArrayList sign-invariant (even) in the sign of each array.

    Shallow wrapper around the generic apply_sign_symmetry_to_fn.

    Args:
        fn (Callable): the function to symmetrize. The given axis is injected into the
            inputs and the sign orbit is computed, so this function should be able to
            treat the given sign orbit axis as a batch dimension, and the overall tensor
            rank should not change (len(input.shape) == len(output.shape))
        axis (int, optional): the axis along which the sign orbit is injected into the
            inputs and over which the outputs are summed. Defaults to -2.

    Returns:
        Callable: a function with the same signature as the input function, but
        which has been symmetrized so that its output will be invariant with respect
        to the sign of each input, or in other words, will be even.
    """
    return apply_sign_symmetry_to_fn(
        fn,
        functools.partial(_get_sign_orbit_array_list, axis=axis),
        lambda x, _: x,  # Ignore the signs to get an invariance
        functools.partial(jnp.sum, axis=axis),
    )
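Dropping the signs turns the orbit sum into a plain average-like sum over all sign flips, which is even in each input. The sketch below illustrates this with a hypothetical `toy_sign_orbit` helper (an assumed stand-in for `_get_sign_orbit_array_list`, stacking the orbit with `jnp.stack` at the given axis) and a made-up function `f`. For this `f`, the even result has the closed form 4 * sum_k cosh(u_k) cosh(2 d_k).

```python
import itertools

import jax.numpy as jnp


def toy_sign_orbit(x, axis=-2):
    # Hypothetical stand-in: stack all sign combinations of the inputs along a
    # new axis at position `axis`.
    combos = list(itertools.product([1.0, -1.0], repeat=len(x)))
    return [jnp.stack([c[i] * xi for c in combos], axis=axis) for i, xi in enumerate(x)]


def f(xs):
    # Neither odd nor even; reduces only the last axis and keeps the tensor rank.
    return jnp.sum(jnp.exp(xs[0] + 2.0 * xs[1]), axis=-1, keepdims=True)


def even_g(x, axis=-2):
    outputs = f(toy_sign_orbit(x, axis=axis))  # shape (1, 4, 1) for the inputs below
    return jnp.sum(outputs, axis=axis)  # no signs applied: the result is even


u = jnp.array([[0.3, -1.2]])  # shape (1, 2)
d = jnp.array([[0.5, 0.1]])
closed_form = 4.0 * jnp.sum(jnp.cosh(u) * jnp.cosh(2.0 * d))
```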