sign_symmetry
Routines for symmetrizing a function to be sign covariant.
ProductsSignCovariance (Module)
dataclass
Sign covariance from a weighted sum of products of per-particle values.
Only supports two spins at the moment. Given per-spin antiequivariant vectors a_1, a_2, ... (up spin) and b_1, b_2, ... (down spin), computes the antisymmetric quantity sum_{i,j} w_{i,j} sum_{k} a_ik b_jk, or several such quantities (each with its own weights) if features > 1. If use_weights=False, no weights are used, which is equivalent to setting w_{i,j} = 1 for all i, j.
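Written out, with a_i and b_j in R^d the rows of the up- and down-spin inputs and f indexing output features, the quantity is as follows. In the code, the pair index (i, j) becomes row i · n_↓ + j of the dense-layer kernel once the pairwise dot products are flattened. The last line is a sketch of why the unweighted case is antisymmetric, and it assumes that "antiequivariant" means permuting the up-spin particles by σ maps each row a_i to sgn(σ) a_{σ(i)} (and likewise for the down spins).

```latex
\begin{aligned}
y_f &= \sum_{i=1}^{n_\uparrow} \sum_{j=1}^{n_\downarrow} w_{(i,j),f} \sum_{k=1}^{d} a_{ik}\, b_{jk},
  \qquad f = 1, \dots, \text{features}, \\
y\big|_{w \equiv 1} &= \sum_{i=1}^{n_\uparrow} \sum_{j=1}^{n_\downarrow} a_i \cdot b_j
  = \Big( \sum_{i=1}^{n_\uparrow} a_i \Big) \cdot \Big( \sum_{j=1}^{n_\downarrow} b_j \Big), \\
a_i \mapsto \operatorname{sgn}(\sigma)\, a_{\sigma(i)}
  &\;\Longrightarrow\; y\big|_{w \equiv 1} \mapsto \operatorname{sgn}(\sigma)\, y\big|_{w \equiv 1}.
\end{aligned}
```

Because the sum over i runs over every up-spin row, reordering the rows leaves it unchanged and only the sign factor survives; the same argument applies to the down-spin factor.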
Attributes:
Name | Type | Description |
---|---|---|
features | int | the number of antisymmetric output features to generate. If use_weights is False, must be equal to 1. |
kernel_init | WeightInitializer | initializer for the weights of the dense layer. |
register_kfac | bool | whether to register the dense layer with KFAC. Defaults to True. |
use_weights | bool | whether to use a weighted sum of products. Defaults to False. |
__call__(self, x)
special
Calculate weighted sum of products of up- and down-spin antiequivariances.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | ArrayList | input antiequivariant arrays of shape [(..., nelec_up, d), (..., nelec_down, d)] | required |
Returns:
Type | Description |
---|---|
Array | array of shape (..., features) of antisymmetric values, calculated by taking a weighted sum of the pairwise dot products of the up- and down-spin antiequivariant inputs. |
Source code in vmcnet/models/sign_symmetry.py
```python
@flax.linen.compact
def __call__(self, x: ArrayList) -> Array:  # type: ignore[override]
    """Calculate weighted sum of products of up- and down-spin antiequivariances.

    Arguments:
        x (ArrayList): input antiequivariant arrays of shape
            [(..., nelec_up, d), (..., nelec_down, d)]

    Returns:
        Array: array of length features of antisymmetric values calculated
        by taking a weighted sum of the pairwise dot-products of the up- and
        down-spin antiequivariant inputs.
    """
    # TODO (ggoldsh): update this to support nspins != 2 as well
    if len(x) != 2:
        raise ValueError(
            "Products covariance only supported for nspins=2, got {}".format(len(x))
        )

    naxes = len(x[0].shape)
    batch_dims = range(naxes - 2)
    contraction_dim = (naxes - 1,)

    # Since the second last axis is not specified as either a batch or contraction
    # dim, jax.lax.dot_general will automatically compute over all pairs of up and
    # down spins. pairwise_dots thus has shape (..., nelec_up, nelec_down).
    pairwise_dots = jax.lax.dot_general(
        x[0], x[1], ((contraction_dim, contraction_dim), (batch_dims, batch_dims))
    )

    if not self.use_weights:
        if self.features != 1:
            raise ValueError(
                "Can only return one output feature when use_weights is False. "
                "Received {} for features.".format(self.features)
            )
        return jnp.expand_dims(jnp.sum(pairwise_dots, axis=(-1, -2)), -1)

    shape = pairwise_dots.shape
    # flattened_dots has shape (..., nelec_up * nelec_down)
    flattened_dots = jnp.reshape(
        pairwise_dots, (*shape[:-2], shape[-1] * shape[-2])
    )

    return Dense(
        self.features,
        kernel_init=self.kernel_init,
        use_bias=False,
        register_kfac=self.register_kfac,
    )(flattened_dots)
```
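For readers without a Flax setup, the computation in `__call__` can be mirrored in plain NumPy. This is a sketch under stated assumptions, not the module itself: `products_sign_covariance` is a hypothetical helper, `np.einsum` stands in for `jax.lax.dot_general` (contracting the last axis and treating all leading axes as batch axes), and an explicit `kernel` matrix of shape `(nelec_up * nelec_down, features)` plays the role of the bias-free `Dense` layer.

```python
import numpy as np


def products_sign_covariance(x, kernel=None):
    """NumPy mirror of ProductsSignCovariance.__call__ (hypothetical helper).

    x: list of two arrays with shapes (..., nelec_up, d) and (..., nelec_down, d).
    kernel: None for the unweighted case, otherwise an array of shape
        (nelec_up * nelec_down, features) standing in for the Dense weights.
    """
    if len(x) != 2:
        raise ValueError(
            "Products covariance only supported for nspins=2, got {}".format(len(x))
        )
    up, down = x
    # Contract the last axis, keep leading axes as batch axes, and form every
    # (up, down) pair: shape (..., nelec_up, nelec_down).
    pairwise_dots = np.einsum("...ik,...jk->...ij", up, down)
    if kernel is None:
        # Unweighted: sum over all pairs and keep a trailing feature axis of size 1.
        return np.sum(pairwise_dots, axis=(-1, -2))[..., None]
    shape = pairwise_dots.shape
    # Row-major flattening puts pair (i, j) at index i * nelec_down + j.
    flattened_dots = np.reshape(pairwise_dots, (*shape[:-2], shape[-2] * shape[-1]))
    # A bias-free dense layer is a matrix product with the kernel.
    return flattened_dots @ kernel


rng = np.random.default_rng(0)
up = rng.normal(size=(5, 3, 4))    # batch of 5, nelec_up = 3, d = 4
down = rng.normal(size=(5, 2, 4))  # batch of 5, nelec_down = 2, d = 4

unweighted = products_sign_covariance([up, down])
weighted = products_sign_covariance([up, down], rng.normal(size=(3 * 2, 7)))
print(unweighted.shape, weighted.shape)  # (5, 1) (5, 7)
```

Without weights there is nothing to distinguish one output feature from another, which is consistent with the module rejecting `features != 1` when `use_weights` is False.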
apply_sign_symmetry_to_fn(fn, get_signs_and_syms, apply_output_signs, add_up_results)
Make a function of a list of inputs covariant in the sign of each input.
That is, output a function g(s_1, s_2, ..., s_n) which is odd with respect to each input, such that g(s_1, ...) = -g(-s_1, ...), and likewise for every other input.
This is done by taking the orbit of the inputs with respect to the sign group applied separately to each one, and adding up the results with appropriate covariant signs. For example, for two spins this calculates g(U,D) = f(U,D) - f(-U,D) - f(U, -D) + f(-U, -D).
Inputs are assumed to be either Arrays or SLArrays, so that in either case the
required symmetries can be stacked along a new axis of the underlying array values.
The function `fn` is assumed to support the injection of a batch dimension, done in `get_signs_and_syms`, and to pass it through to the output (e.g., a function which flattens the input would not be supported). The extra dimension is removed at the end via `add_up_results`.
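The orbit sum can be written out directly as a loop over all 2^n sign patterns, with each term weighted by the product of the flipped signs. This NumPy sketch is for intuition only: `sign_orbit_sum` and the test function `f` are hypothetical, and the loop skips the batched-axis mechanics that `get_signs_and_syms` is responsible for.

```python
import itertools

import numpy as np


def sign_orbit_sum(f, xs):
    """Sum of (s_1 * ... * s_n) * f([s_1 x_1, ..., s_n x_n]) over all sign patterns."""
    total = 0.0
    for pattern in itertools.product([1.0, -1.0], repeat=len(xs)):
        flipped = [s * x for s, x in zip(pattern, xs)]
        total = total + np.prod(pattern) * f(flipped)
    return total


def f(xs):
    # Deliberately has no sign symmetry of its own.
    u, d = xs
    return np.exp(np.sum(u) + 2.0 * np.sum(d))


U = np.array([0.3, -0.1])
D = np.array([0.2, 0.4])
# For two inputs this is f(U,D) - f(U,-D) - f(-U,D) + f(-U,-D).
g = sign_orbit_sum(f, [U, D])
print(np.isclose(sign_orbit_sum(f, [-U, D]), -g))  # True
print(np.isclose(sign_orbit_sum(f, [U, -D]), -g))  # True
```

For this particular f the sum factorizes as (e^a - e^{-a})(e^{2b} - e^{-2b}) = 4 sinh(a) sinh(2b), with a = sum(U) and b = sum(D), which makes the oddness in each input easy to see.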
Parameters:
Name | Type | Description | Default |
---|---|---|---|
fn | Callable | the function to symmetrize. get_signs_and_syms injects a sign-orbit axis into the inputs, so this function should be able to treat that axis as a batch dimension, and the overall tensor rank should not change (len(input.shape) == len(output.shape)). | required |
get_signs_and_syms | Callable | a function which gets the signs and symmetries for the list of input arrays. Returns a tuple of the symmetries plus the associated signs as a 1D array. | required |
apply_output_signs | Callable | function for applying signs to the outputs of the symmetrized function. For example, if the outputs are Arrays, this would simply multiply the arrays by the signs along the appropriate axis. | required |
add_up_results | Callable | function for combining the signed outputs into a single, sign-covariant output. For example, simple addition for Arrays or the slog_sum function for SLArrays. | required |
Returns:
Type | Description |
---|---|
Callable | a function with the same signature as the input function, but which has been symmetrized so that its output will be covariant with respect to the sign of each input, or in other words, will be odd. |
Source code in vmcnet/models/sign_symmetry.py
```python
def apply_sign_symmetry_to_fn(
    fn: Callable[[List[A]], A],
    get_signs_and_syms: Callable[[List[A]], Tuple[List[A], Array]],
    apply_output_signs: Callable[[A, Array], A],
    add_up_results: Callable[[A], A],
) -> Callable[[List[A]], A]:
    """Make a function of a list of inputs covariant in the sign of each input.

    That is, output a function g(s_1, s_2, ..., s_n) which is odd with respect to each
    input, such that g(s_1, ...) = -g(-s_1, ...), and likewise for every other input.

    This is done by taking the orbit of the inputs with respect to the sign group
    applied separately to each one, and adding up the results with appropriate
    covariant signs. For example, for two spins this calculates
    g(U,D) = f(U,D) - f(-U,D) - f(U, -D) + f(-U, -D).

    Inputs are assumed to be either Arrays or SLArrays, so that in either case the
    required symmetries can be stacked along a new axis of the underlying array values.
    The function `fn` is assumed to support the injection of a batch dimension, done in
    `get_signs_and_syms`, and to pass it through to the output (e.g., a function which
    flattens the input would not be supported). The extra dimension is removed at the
    end via `add_up_results`.

    Args:
        fn (Callable): the function to symmetrize. The given axis is injected into the
            inputs and the sign orbit is computed, so this function should be able to
            treat the given sign orbit axis as a batch dimension, and the overall tensor
            rank should not change (len(input.shape) == len(output.shape))
        get_signs_and_syms (Callable): a function which gets the signs and symmetries
            for the input array. Returns a tuple of the symmetries plus the associated
            signs as a 1D array.
        apply_output_signs (Callable): function for applying signs to the outputs of
            the symmetrized function. For example, if the outputs are Arrays, this
            would simply multiply the arrays by the signs along the appropriate axis.
        add_up_results (Callable): function for combining the signed outputs into a
            single, sign-covariant output. For example, simple addition for Arrays
            or the slog_sum function for SLArrays.

    Returns:
        Callable: a function with the same signature as the input function, but
        which has been symmetrized so that its output will be covariant with respect
        to the sign of each input, or in other words, will be odd.
    """

    def sign_covariant_fn(x: List[A]) -> A:
        symmetries, signs = get_signs_and_syms(x)
        outputs = fn(symmetries)
        signed_results = apply_output_signs(outputs, signs)
        return add_up_results(signed_results)

    return sign_covariant_fn
```
make_sl_array_list_fn_sign_covariant(fn, axis=-2)
Make a function of an SLArrayList sign-covariant in the sign of each SLArray.
Shallow wrapper around the generic apply_sign_symmetry_to_fn.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
fn | Callable | the function to symmetrize. The given axis is injected into the inputs and the sign orbit is computed, so this function should be able to treat the given sign orbit axis as a batch dimension, and the overall tensor rank should not change (len(input.shape) == len(output.shape)). | required |
axis | int | the axis at which the sign orbit is injected into the inputs and along which the signed outputs are combined. | -2 |
Returns:
Type | Description |
---|---|
Callable | a function with the same signature as the input function, but which has been symmetrized so that its output will be covariant with respect to the sign of each input, or in other words, will be odd. |
Source code in vmcnet/models/sign_symmetry.py
```python
def make_sl_array_list_fn_sign_covariant(
    fn: Callable[[SLArrayList], SLArray], axis: int = -2
) -> Callable[[SLArrayList], SLArray]:
    """Make a function of an SLArrayList sign-covariant in the sign of each SLArray.

    Shallow wrapper around the generic apply_sign_symmetry_to_fn.

    Args:
        fn (Callable): the function to symmetrize. The given axis is injected into the
            inputs and the sign orbit is computed, so this function should be able to
            treat the given sign orbit axis as a batch dimension, and the overall tensor
            rank should not change (len(input.shape) == len(output.shape))

    Returns:
        Callable: a function with the same signature as the input function, but
        which has been symmetrized so that its output will be covariant with respect
        to the sign of each input, or in other words, will be odd.
    """
    return apply_sign_symmetry_to_fn(
        fn,
        functools.partial(_get_sign_orbit_sl_array_list, axis=axis),
        lambda x, s: (_multiply_sign_along_axis(x[0], s, axis), x[1]),
        functools.partial(slog_sum_over_axis, axis=axis),
    )
```
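Judging from the `lambda` passed as `apply_output_signs`, which multiplies `x[0]` by the signs and passes `x[1]` through unchanged, an SLArray is a (sign, log-magnitude) pair: flipping the sign of a value never changes its log-magnitude. Combining the signed orbit terms then needs a signed log-sum-exp. The function below is a hypothetical NumPy stand-in for that reduction, written under the assumption that `slog_sum_over_axis` returns the sign and log-magnitude of the sum along `axis`; it is not vmcnet's implementation.

```python
import numpy as np


def slog_sum_over_axis_sketch(slx, axis=-2):
    """Sign and log|.| of sum(signs * exp(logs)) along `axis`, computed stably.

    slx is a (signs, logs) pair of same-shaped arrays (hypothetical helper).
    """
    signs, logs = slx
    # Shift by the max log so the largest term becomes exp(0) = 1, avoiding overflow.
    shift = np.max(logs, axis=axis, keepdims=True)
    total = np.sum(signs * np.exp(logs - shift), axis=axis)
    return np.sign(total), np.log(np.abs(total)) + np.squeeze(shift, axis=axis)


# Four signed orbit terms whose magnitudes would overflow if exponentiated directly.
signs = np.array([[1.0], [-1.0], [-1.0], [1.0]])  # shape (4, 1): orbit axis is -2
logs = np.array([[1000.0], [999.0], [998.0], [990.0]])
out_sign, out_log = slog_sum_over_axis_sketch((signs, logs), axis=-2)
# Mathematically: sign +1, log = 1000 + log(1 - e^-1 - e^-2 + e^-10).
```

Subtracting the maximum before exponentiating is the standard log-sum-exp trick; the signs ride along as plain ±1 multipliers, which is why only the sign component needs to see `apply_output_signs`.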
make_array_list_fn_sign_covariant(fn, axis=-2)
Make a function of an ArrayList sign-covariant in the sign of each array.
Shallow wrapper around the generic apply_sign_symmetry_to_fn.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
fn | Callable | the function to symmetrize. The given axis is injected into the inputs and the sign orbit is computed, so this function should be able to treat the given sign orbit axis as a batch dimension, and the overall tensor rank should not change (len(input.shape) == len(output.shape)). | required |
axis | int | the axis at which the sign orbit is injected into the inputs and along which the signed outputs are summed. | -2 |
Returns:
Type | Description |
---|---|
Callable | a function with the same signature as the input function, but which has been symmetrized so that its output will be covariant with respect to the sign of each input, or in other words, will be odd. |
Source code in vmcnet/models/sign_symmetry.py
```python
def make_array_list_fn_sign_covariant(
    fn: Callable[[ArrayList], Array], axis: int = -2
) -> Callable[[ArrayList], Array]:
    """Make a function of an ArrayList sign-covariant in the sign of each array.

    Shallow wrapper around the generic apply_sign_symmetry_to_fn.

    Args:
        fn (Callable): the function to symmetrize. The given axis is injected into the
            inputs and the sign orbit is computed, so this function should be able to
            treat the given sign orbit axis as a batch dimension, and the overall tensor
            rank should not change (len(input.shape) == len(output.shape))

    Returns:
        Callable: a function with the same signature as the input function, but
        which has been symmetrized so that its output will be covariant with respect
        to the sign of each input, or in other words, will be odd.
    """
    return apply_sign_symmetry_to_fn(
        fn,
        functools.partial(_get_sign_orbit_array_list, axis=axis),
        functools.partial(_multiply_sign_along_axis, axis=axis),
        functools.partial(jnp.sum, axis=axis),
    )
```
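The three callables wired up here can be imitated in NumPy to see how the sign orbit travels along `axis`. In this sketch `get_sign_orbit`, `multiply_sign_along_axis`, `make_sign_covariant`, and the test function `fn` are hypothetical stand-ins for vmcnet's private helpers `_get_sign_orbit_array_list` and `_multiply_sign_along_axis`, whose implementations are not shown on this page; the ordering of the orbit is an assumption.

```python
import itertools

import numpy as np


def get_sign_orbit(xs, axis=-2):
    """Stack all 2^n sign flips of each input along a new `axis`.

    Returns (symmetries, signs): symmetries[i] holds input i under every sign
    pattern, and signs[p] is the product of the flips in pattern p.
    """
    patterns = np.array(list(itertools.product([1.0, -1.0], repeat=len(xs))))
    symmetries = [
        np.stack([s * x for s in patterns[:, i]], axis=axis) for i, x in enumerate(xs)
    ]
    return symmetries, np.prod(patterns, axis=1)


def multiply_sign_along_axis(x, signs, axis=-2):
    # Move the orbit axis last so `signs` broadcasts against it, then move it back.
    return np.moveaxis(np.moveaxis(x, axis, -1) * signs, -1, axis)


def make_sign_covariant(fn, axis=-2):
    def sign_covariant_fn(xs):
        symmetries, signs = get_sign_orbit(xs, axis)
        outputs = fn(symmetries)  # the orbit axis is treated as a batch axis
        return np.sum(multiply_sign_along_axis(outputs, signs, axis), axis=axis)

    return sign_covariant_fn


def fn(xs):
    # Inputs of shape (..., d) become (..., 2^n, d); output is (..., 2^n, 1),
    # so the tensor rank is preserved as required.
    u, d = xs
    return np.exp(
        np.sum(u, axis=-1, keepdims=True) + 2.0 * np.sum(d, axis=-1, keepdims=True)
    )


g = make_sign_covariant(fn)
U, D = np.array([0.3, -0.1]), np.array([0.2, 0.4])
value = g([U, D])  # shape (1,), mathematically 4 * sinh(0.2) * sinh(1.2)
```

Because the orbit lives on its own axis, `fn` is evaluated once on a stacked batch instead of 2^n times, and a leading batch axis on the inputs passes straight through.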
make_array_list_fn_sign_invariant(fn, axis=-2)
Make a function of an ArrayList sign-invariant (even) in the sign of each array.
Shallow wrapper around the generic apply_sign_symmetry_to_fn.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
fn | Callable | the function to symmetrize. The given axis is injected into the inputs and the sign orbit is computed, so this function should be able to treat the given sign orbit axis as a batch dimension, and the overall tensor rank should not change (len(input.shape) == len(output.shape)). | required |
axis | int | the axis at which the sign orbit is injected into the inputs and along which the outputs are summed. | -2 |
Returns:
Type | Description |
---|---|
Callable | a function with the same signature as the input function, but which has been symmetrized so that its output will be invariant with respect to the sign of each input, or in other words, will be even. |
Source code in vmcnet/models/sign_symmetry.py
```python
def make_array_list_fn_sign_invariant(
    fn: Callable[[ArrayList], Array], axis: int = -2
) -> Callable[[ArrayList], Array]:
    """Make a function of an ArrayList sign-invariant (even) in the sign of each array.

    Shallow wrapper around the generic apply_sign_symmetry_to_fn.

    Args:
        fn (Callable): the function to symmetrize. The given axis is injected into the
            inputs and the sign orbit is computed, so this function should be able to
            treat the given sign orbit axis as a batch dimension, and the overall tensor
            rank should not change (len(input.shape) == len(output.shape))

    Returns:
        Callable: a function with the same signature as the input function, but
        which has been symmetrized so that its output will be invariant with respect
        to the sign of each input, or in other words, will be even.
    """
    return apply_sign_symmetry_to_fn(
        fn,
        functools.partial(_get_sign_orbit_array_list, axis=axis),
        lambda x, _: x,  # Ignore the signs to get an invariance
        functools.partial(jnp.sum, axis=axis),
    )
```
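Dropping the signs turns the same orbit sum into an even function: for two inputs it computes f(U,D) + f(U,-D) + f(-U,D) + f(-U,-D). The self-contained NumPy sketch below uses hypothetical helpers `get_sign_orbit` and `make_sign_invariant` in place of the vmcnet functions, with the orbit stacked along `axis` as assumed above.

```python
import itertools

import numpy as np


def get_sign_orbit(xs, axis=-2):
    # Stack every sign pattern of each input along a new `axis` (signs discarded).
    patterns = np.array(list(itertools.product([1.0, -1.0], repeat=len(xs))))
    return [
        np.stack([s * x for s in patterns[:, i]], axis=axis) for i, x in enumerate(xs)
    ]


def make_sign_invariant(fn, axis=-2):
    def sign_invariant_fn(xs):
        # No sign multiplication before the sum, so every orbit term counts with +1.
        return np.sum(fn(get_sign_orbit(xs, axis)), axis=axis)

    return sign_invariant_fn


def fn(xs):
    u, d = xs
    return np.exp(
        np.sum(u, axis=-1, keepdims=True) + 2.0 * np.sum(d, axis=-1, keepdims=True)
    )


h = make_sign_invariant(fn)
U, D = np.array([0.3, -0.1]), np.array([0.2, 0.4])
value = h([U, D])  # shape (1,), mathematically 4 * cosh(0.2) * cosh(1.2)
```

Compared with the covariant wrapper, the only change is the identity `apply_output_signs`, which mirrors the `lambda x, _: x` in the source.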