equivariance

Permutation equivariant functions.

DoublyEquivariantOrbitalLayer (Module) dataclass

Equivariantly generate an orbital matrix corresponding to each input stream.

The calculation here is subtle, so it is worth explaining in some detail. Let the equivariant input vectors to this layer be y_i. This layer generates an orbital matrix M_p for each particle p, such that the (i,j)th element of M_p satisfies M_(p,i,j) = phi_j(y_p, y_i). This is essentially the usual orbital matrix formula M_(i,j) = phi_j(y_i), with an added dependence on the particle index p that yields a distinct matrix for each input particle. The construction therefore gives a distinct antisymmetric determinant D_p = det(M_p) for each input particle, which can then serve as the basis for an expressive antiequivariant layer.

If r_ei is provided in addition to the main inputs y_i, then an exponentially decaying envelope is also applied equally to every orbital matrix M_p in order to ensure that the orbital values decay to zero far from the ions.
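As a rough illustration of the construction described above, the following sketch builds M[p, i, j] = phi_j(y_p, y_i) and checks how the per-particle determinants respond to swapping two particles. The form of phi used here (a tanh of a single linear map on the concatenation of y_p and y_i) and the function name are assumptions made for the example; the layer's actual orbital function is not shown in this section. NumPy stands in for jax.numpy.

```python
import numpy as np


def per_particle_orbital_matrices(y, weights):
    """Return M with M[p, i, j] = phi_j(y_p, y_i) for a toy phi.

    y has shape (n, d); weights has shape (2 * d, norbitals). The choice of
    phi (tanh of a linear map) is an assumption made for this sketch only.
    """
    n, d = y.shape
    y_p = np.broadcast_to(y[:, None, :], (n, n, d))  # varies along axis 0 (p)
    y_i = np.broadcast_to(y[None, :, :], (n, n, d))  # varies along axis 1 (i)
    pairs = np.concatenate([y_p, y_i], axis=-1)  # (n, n, 2 * d)
    return np.tanh(pairs @ weights)  # (n, n, norbitals)


rng = np.random.default_rng(0)
n, d = 4, 3
y = rng.normal(size=(n, d))
weights = rng.normal(size=(2 * d, n))  # norbitals == n gives square matrices

# D_p = det(M_p) for every particle p, shape (n,)
dets = np.linalg.det(per_particle_orbital_matrices(y, weights))

# Swap particles 0 and 1: each D_p moves with its particle and flips sign.
perm = np.array([1, 0, 2, 3])
dets_swapped = np.linalg.det(per_particle_orbital_matrices(y[perm], weights))
```

Under the swap, `dets_swapped` should equal `-dets[perm]`: the determinants are permuted along with the particles (equivariance) and each picks up the sign of the permutation (antisymmetry).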

Attributes:

Name Type Description
orbitals_split ParticleSplit

number of pieces to split the input equally, or specified sequence of locations to split along the 2nd-to-last axis. E.g., if nelec = 10, and orbitals_split = 2, then the input is split (5, 5). If nelec = 10, and orbitals_split = (2, 4), then the input is split at indices 2 and 4 into pieces of size (2, 2, 6) -- note when orbitals_split is a sequence, there will be one more split than the length of the sequence. In the original use-case of spin-1/2 particles, orbitals_split should be either the number 2 (for closed-shell systems) or should be a Sequence with length 1 whose element is less than the total number of electrons.
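The layer passes orbitals_split to jnp.split, so an integer requests that many equal pieces while a sequence lists the indices at which to cut. The short check below uses NumPy's split, which follows the same convention; the array sizes are arbitrary choices for the example.

```python
import numpy as np

nelec, d = 10, 3
x = np.zeros((nelec, d))

# An integer asks for that many equal pieces along the 2nd-to-last axis.
equal_sizes = [piece.shape[-2] for piece in np.split(x, 2, axis=-2)]

# A sequence gives the indices at which to cut, not the piece sizes.
index_sizes = [piece.shape[-2] for piece in np.split(x, (2, 4), axis=-2)]

# Spin-1/2 example with 3 spin-up and 7 spin-down electrons.
open_shell_sizes = [piece.shape[-2] for piece in np.split(x, (3,), axis=-2)]
```

Here `equal_sizes` is `[5, 5]`, `index_sizes` is `[2, 2, 6]`, and `open_shell_sizes` is `[3, 7]`.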

norbitals_per_split Sequence[int]

sequence of integers specifying the number of orbitals to create for each split. This determines the output shapes for each split, i.e. the outputs are shaped (..., split_size[i], norbitals[i])

kernel_initializer_linear WeightInitializer

kernel initializer for the linear part of the orbitals. Has signature (key, shape, dtype) -> Array

kernel_initializer_envelope_dim WeightInitializer

kernel initializer for the decay rate in the exponential envelopes. If isotropic_decay is True, then this initializes a single decay rate number per ion and orbital. If isotropic_decay is False, then this initializes a 3x3 matrix per ion and orbital. Has signature (key, shape, dtype) -> Array

kernel_initializer_envelope_ion WeightInitializer

kernel initializer for the linear combination over the ions of exponential envelopes. Has signature (key, shape, dtype) -> Array

bias_initializer_linear WeightInitializer

bias initializer for the linear part of the orbitals. Has signature (key, shape, dtype) -> Array

use_bias bool

whether to add a bias term to the linear part of the orbitals. Defaults to True.

isotropic_decay bool

whether the decay for each ion should be isotropic, giving envelopes of the form exp(-||a(r - R)||) for a number a, or anisotropic (w.r.t. the dimensions of the input), giving envelopes of the form exp(-||A(r - R)||) for a dxd matrix A. If True, the decay is isotropic.
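The two envelope forms can be compared with a small sketch. The function names and the use of a single decay parameter are assumptions for illustration; per the attribute descriptions, the library keeps one decay parameter per ion and orbital and then takes a linear combination over ions, which is omitted here. NumPy stands in for jax.numpy.

```python
import numpy as np


def isotropic_envelope(r_ei, a):
    """exp(-||a (r - R)||) for a scalar a; r_ei has shape (nelec, nion, d)."""
    return np.exp(-np.linalg.norm(a * r_ei, axis=-1))


def anisotropic_envelope(r_ei, A):
    """exp(-||A (r - R)||) for a (d, d) matrix A; r_ei is (nelec, nion, d)."""
    scaled = np.einsum("ab,eib->eia", A, r_ei)  # apply A to each displacement
    return np.exp(-np.linalg.norm(scaled, axis=-1))


rng = np.random.default_rng(0)
r_ei = rng.normal(size=(4, 2, 3))  # 4 electrons, 2 ions, d = 3

iso = isotropic_envelope(r_ei, 0.5)
aniso = anisotropic_envelope(r_ei, 0.5 * np.eye(3))
```

With A set to a multiple of the identity, the anisotropic form reduces to the isotropic one, so `iso` and `aniso` agree in this example.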

setup(self)

Setup envelope kernel initializers.

Source code in vmcnet/models/equivariance.py
def setup(self):
    """Setup envelope kernel initializers."""
    # workaround MyPy's typing error for callable attribute, see
    # https://github.com/python/mypy/issues/708
    self._kernel_initializer_envelope_dim = self.kernel_initializer_envelope_dim
    self._kernel_initializer_envelope_ion = self.kernel_initializer_envelope_ion

__call__(self, x, r_ei=None) special

Calculate an equivariant orbital matrix for each input particle.

Parameters:

Name Type Description Default
x Array

array of shape (..., nelec, d)

required
r_ei Array

array of shape (..., nelec, nion, d)

None

Returns:

Type Description
(ArrayList)

list of length nsplits of arrays of shape (..., nelec[i], nelec[i], self.norbitals_per_split[i]). Here nelec[i] is the number of particles in the ith split. The output arrays have both their -2 and -3 axes equivariant with respect to the input particles. The exponential envelopes are computed only when r_ei is not None (so, when connected to FermiNetBackflow, when ion locations are specified). To output square matrices, say in order to be able to take antiequivariant per-particle determinants, nelec[i] should be equal to self.norbitals_per_split[i].

Source code in vmcnet/models/equivariance.py
@flax.linen.compact
def __call__(  # type: ignore[override]
    self, x: Array, r_ei: Array = None
) -> ArrayList:
    """Calculate an equivariant orbital matrix for each input particle.

    Args:
        x (Array): array of shape (..., nelec, d)
        r_ei (Array): array of shape (..., nelec, nion, d)

    Returns:
        (ArrayList): list of length nsplits of arrays of shape
        (..., nelec[i], nelec[i], self.norbitals_per_split[i]). Here nelec[i] is the
        number of particles in the ith split. The output arrays have both their -2
        and -3 axes equivariant with respect to the input particles. The exponential
        envelopes are computed only when r_ei is not None (so, when connected to
        FermiNetBackflow, when ion locations are specified). To output square
        matrices, say in order to be able to take antiequivariant per-particle
        determinants, nelec[i] should be equal to self.norbitals_per_split[i].
    """
    # split_x is a list of nsplits arrays of shape (..., nelec[i], d)
    split_x = jnp.split(x, self.orbitals_split, -2)
    # orbs is a list of nsplits arrays of shape
    # (..., nelec[i], nelec[i], norbitals[i])
    orbs = [
        self._get_orbital_matrices_one_split(x, self.norbitals_per_split[i])
        for (i, x) in enumerate(split_x)
    ]

    if r_ei is not None:
        exp_envelopes = _compute_exponential_envelopes_all_splits(
            r_ei,
            self.orbitals_split,
            self.norbitals_per_split,
            self._kernel_initializer_envelope_dim,
            self._kernel_initializer_envelope_ion,
            self.isotropic_decay,
        )
        # Envelope must be expanded to apply equally to each per-particle matrix.
        exp_envelopes = jax.tree_map(
            lambda x: jnp.expand_dims(x, axis=-3), exp_envelopes
        )
        orbs = tree_prod(orbs, exp_envelopes)
    return orbs
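
The axis expansion applied to the envelopes in the listing above can be checked in isolation. This is a minimal sketch with made-up shapes for a single split, using NumPy in place of jax.numpy and plain arrays in place of the list-of-arrays pytree.

```python
import numpy as np

rng = np.random.default_rng(0)
nelec, norbitals = 3, 3
orbs = rng.normal(size=(nelec, nelec, norbitals))  # indexed as [p, i, j]
envelope = rng.uniform(size=(nelec, norbitals))  # indexed as [i, j]

# Insert a length-1 axis in the p position so that broadcasting repeats the
# same envelope across all per-particle matrices.
expanded = np.expand_dims(envelope, axis=-3)  # (1, nelec, norbitals)
enveloped = orbs * expanded  # (nelec, nelec, norbitals)
```

Each slice `enveloped[p]` equals `orbs[p] * envelope`, i.e. every per-particle matrix is multiplied by the same envelope.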

FermiNetBackflow (Module) dataclass

The FermiNet equivariant part up until, but not including, the orbitals.

Repeated composition of the residual blocks in the parallel one-electron and two-electron streams.

Attributes:

Name Type Description
residual_blocks Sequence

sequence of callable residual blocks which apply the one- and two-electron layers. Each residual block has the signature (in_1e, optional in_2e) -> (out_1e, optional out_2e), where in_1e has shape (..., n, d_1e), out_1e has shape (..., n, d_1e'), in_2e has shape (..., n, n, d_2e), and out_2e has shape (..., n, n, d_2e').

setup(self)

Setup called residual blocks.

Source code in vmcnet/models/equivariance.py
def setup(self):
    """Setup called residual blocks."""
    self._residual_block_list = [block for block in self.residual_blocks]

__call__(self, stream_1e, stream_2e=None) special

Iteratively apply residual blocks to FermiNet input streams.

Parameters:

Name Type Description Default
stream_1e Array

one-electron input stream of shape (..., nelec, d1).

required
stream_2e Array

two-electron input of shape (..., nelec, nelec, d2).

None

Returns:

Type Description
(Array)

the output of the one-electron stream after applying self.residual_blocks to the initial input streams.

Source code in vmcnet/models/equivariance.py
def __call__(  # type: ignore[override]
    self,
    stream_1e: Array,
    stream_2e: Optional[Array] = None,
) -> Array:
    """Iteratively apply residual blocks to FermiNet input streams.

    Args:
        stream_1e (Array): one-electron input stream of shape
            (..., nelec, d1).
        stream_2e (Array, optional): two-electron input of shape
            (..., nelec, nelec, d2).

    Returns:
        (Array): the output of the one-electron stream after applying
        self.residual_blocks to the initial input streams.
    """
    for block in self._residual_block_list:
        stream_1e, stream_2e = block(stream_1e, stream_2e)

    return stream_1e
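
The loop above only relies on each block having the (in_1e, in_2e) -> (out_1e, out_2e) signature. The sketch below threads two streams through toy blocks; `make_toy_block` and `toy_backflow` are hypothetical names, and the arithmetic inside the block is arbitrary, chosen only so that the result is easy to follow.

```python
import numpy as np


def make_toy_block(scale):
    """Return a hypothetical block: (in_1e, in_2e) -> (out_1e, out_2e)."""

    def block(in_1e, in_2e=None):
        out_1e = scale * in_1e
        if in_2e is not None:
            # Mix the mean over the second particle axis into the 1e stream.
            out_1e = out_1e + in_2e.mean(axis=-2).sum(axis=-1, keepdims=True)
        return out_1e, in_2e

    return block


def toy_backflow(blocks, stream_1e, stream_2e=None):
    """Thread both streams through every block; return only the 1e stream."""
    for block in blocks:
        stream_1e, stream_2e = block(stream_1e, stream_2e)
    return stream_1e


stream_1e = np.ones((4, 3))
out_no_2e = toy_backflow([make_toy_block(2.0), make_toy_block(3.0)], stream_1e)
out_with_2e = toy_backflow([make_toy_block(2.0)], stream_1e, np.ones((4, 4, 2)))
```

Without a two-electron stream the toy blocks simply compose (every entry of `out_no_2e` is 6.0); with one, the second stream is passed along unchanged between blocks while only the one-electron stream is returned.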

FermiNetOneElectronLayer (Module) dataclass

A single layer in the one-electron stream of the FermiNet equivariant part.

Attributes:

Name Type Description
spin_split ParticleSplit

number of spins to split the input equally, or specified sequence of locations to split along the 2nd-to-last axis. E.g., if nelec = 10, and spin_split = 2, then the input is split (5, 5). If nelec = 10, and spin_split = (2, 4), then the input is split at indices 2 and 4 into pieces of size (2, 2, 6) -- note when spin_split is a sequence, there will be one more spin than the length of the sequence. In the original use-case of spin-1/2 particles, spin_split should be either the number 2 (for closed-shell systems) or should be a Sequence with length 1 whose element is less than the total number of electrons.

ndense int

number of dense nodes

kernel_initializer_unmixed WeightInitializer

kernel initializer for the unmixed part of the one-electron stream. This initializes the part of the dense kernel which multiplies the previous one-electron stream output. Has signature (key, shape, dtype) -> Array

kernel_initializer_mixed WeightInitializer

kernel initializer for the mixed part of the one-electron stream. This initializes the part of the dense kernel which multiplies the average of the previous one-electron stream output. Has signature (key, shape, dtype) -> Array

kernel_initializer_2e WeightInitializer

kernel initializer for the two-electron part of the one-electron stream. This initializes the part of the dense kernel which multiplies the average of the previous two-electron stream which is mixed into the one-electron stream. Has signature (key, shape, dtype) -> Array

bias_initializer WeightInitializer

bias initializer. Has signature (key, shape, dtype) -> Array

activation_fn Activation

activation function. Has the signature Array -> Array (shape is preserved)

use_bias bool

whether to add a bias term. Defaults to True.

skip_connection bool

whether to add residual skip connections whenever the shapes of the input and output match. Defaults to True.

skip_connection_scale float

quantity to scale the final output by if a skip connection is added. Defaults to 1.0.
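The interaction between skip_connection and skip_connection_scale can be sketched as follows. The helper name `apply_skip` is hypothetical, and the rule that a skip is valid exactly when the input and output shapes agree is an assumption based on the attribute description; the library's own validity check is not shown in this section.

```python
import numpy as np


def apply_skip(x, nonlinear_out, skip_connection=True, skip_connection_scale=1.0):
    """Add x back onto nonlinear_out when shapes match, then rescale the sum."""
    # Assumption: a skip is valid exactly when input and output shapes agree.
    if skip_connection and x.shape == nonlinear_out.shape:
        return skip_connection_scale * (nonlinear_out + x)
    return nonlinear_out


x = np.ones((4, 8))
same_width = apply_skip(x, np.tanh(x), skip_connection_scale=0.5)
new_width = apply_skip(x, np.zeros((4, 16)))
```

Note that the scale multiplies the whole sum, not just the skipped input, and that a layer which changes the feature width (as in `new_width`) passes its output through unchanged.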

cyclic_spins bool

whether the concatenation in the one-electron stream should satisfy a cyclic equivariance structure, i.e. if there are three spins (1, 2, 3), then in the mixed part of the stream, after averaging but before the linear transformation, cyclic equivariance means the inputs are [(1, 2, 3), (2, 3, 1), (3, 1, 2)]. If False, then the inputs are [(1, 2, 3), (1, 2, 3), (1, 2, 3)] (as in the original FermiNet). When there are only two spins (spin-1/2 case), then this is equivalent to true spin equivariance. Defaults to False (original FermiNet).
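The two concatenation orders can be made concrete with a small sketch. The function name is hypothetical and the per-spin means are replaced by labelled one-element arrays so the ordering is visible; NumPy stands in for jax.numpy.

```python
import numpy as np


def concatenate_spin_means(means, cyclic_spins=False):
    """Build the mixed-stream input for each spin from per-spin means."""
    nspins = len(means)
    if not cyclic_spins:
        # Every spin sees the means in the same order: (1, 2, 3).
        return [np.concatenate(means, axis=-1) for _ in range(nspins)]
    # Spin i sees the means rotated so that its own mean comes first.
    return [np.concatenate(means[i:] + means[:i], axis=-1) for i in range(nspins)]


# Three spins with one feature each, labelled by value for readability.
means = [np.array([[1.0]]), np.array([[2.0]]), np.array([[3.0]])]
plain = [m.ravel().tolist() for m in concatenate_spin_means(means)]
cyclic = [m.ravel().tolist() for m in concatenate_spin_means(means, True)]
```

Here `plain` repeats `[1.0, 2.0, 3.0]` for every spin, while `cyclic` is `[[1.0, 2.0, 3.0], [2.0, 3.0, 1.0], [3.0, 1.0, 2.0]]`.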

setup(self)

Setup Dense layers.

Source code in vmcnet/models/equivariance.py
def setup(self):
    """Setup Dense layers."""
    # workaround MyPy's typing error for callable attribute, see
    # https://github.com/python/mypy/issues/708
    self._activation_fn = self.activation_fn

    self._unmixed_dense = Dense(
        self.ndense,
        kernel_init=self.kernel_initializer_unmixed,
        bias_init=self.bias_initializer,
        use_bias=self.use_bias,
    )
    self._mixed_dense = Dense(
        self.ndense, kernel_init=self.kernel_initializer_mixed, use_bias=False
    )
    self._dense_2e = Dense(
        self.ndense, kernel_init=self.kernel_initializer_2e, use_bias=False
    )

__call__(self, in_1e, in_2e=None) special

Add dense outputs on unmixed, mixed, and 2e terms to get the 1e output.

This implementation breaks the one-electron stream into three parts:

1) the unmixed one-particle part, which is a linear transformation applied in parallel for each particle to the inputs

2) the mixed one-particle part, which is a linear transformation applied to the averages of the inputs (concatenated over spin)

3) the two-particle part, which is a linear transformation applied in parallel for each particle to the average of the input interactions between that particle and all the other particles.

For 1), we take in_1e of shape (..., n_total, d_1e), batch apply a linear transformation to get (..., n_total, d'), and split over the spins i to get [i: (..., n[i], d')].

For 2), we split in_1e over the spins along the particle axis to get [i: (..., n[i], d_1e)], average over each spin to get [i: (..., 1, d_1e)], concatenate all averages for each spin to get [i: (..., 1, d_1e * nspins)], and apply a linear transformation to get [i: (..., 1, d')].

For 3), we split in_2e of shape (..., n_total, n_total, d_2e) over the spins along a particle axis to get [i: (..., n[i], n_total, d_2e)], average over the other particle axis to get [i: [j: (..., n[i], d_2e)]], concatenate the averages for each spin to get [i: (..., n[i], d_2e * nspins)], and apply a linear transformation to get [i: (..., n[i], d')].

Finally, for each spin, we add the three parts, each equivariant or symmetric, to get a final equivariant linear transformation of the inputs, to which a non-linearity is then applied and a skip connection optionally added.
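The three parts described above can be sketched in a few lines. This is an illustrative re-creation rather than the library's code: the function name and weight arguments are hypothetical, the bias, activation, and skip connection are omitted, and only the non-cyclic case is covered. NumPy stands in for jax.numpy.

```python
import numpy as np


def one_electron_linear_part(in_1e, in_2e, spin_split, w_unmixed, w_mixed, w_2e):
    """Sum the unmixed, mixed, and two-electron terms for each spin."""
    split_1e = np.split(in_1e, spin_split, axis=-2)
    nspins = len(split_1e)

    # 1) Unmixed: same linear map for every particle, then split over spins.
    unmixed = np.split(in_1e @ w_unmixed, spin_split, axis=-2)

    # 2) Mixed: per-spin means, concatenated over spins, then a linear map.
    means_1e = [s.mean(axis=-2, keepdims=True) for s in split_1e]
    mixed = np.concatenate(means_1e, axis=-1) @ w_mixed  # (..., 1, d')

    # 3) Two-electron: for each particle, average its interactions over the
    # particles of each spin, concatenate over spins, then a linear map.
    means_2e = [s.mean(axis=-2) for s in np.split(in_2e, spin_split, axis=-2)]
    dense_2e = np.concatenate(means_2e, axis=-1) @ w_2e  # (..., n_total, d')
    dense_2e = np.split(dense_2e, spin_split, axis=-2)

    return [unmixed[i] + mixed + dense_2e[i] for i in range(nspins)]


rng = np.random.default_rng(0)
n_total, d_1e, d_2e, d_out = 5, 3, 2, 4
spin_split = (2,)  # 2 spin-up and 3 spin-down particles
in_1e = rng.normal(size=(n_total, d_1e))
in_2e = rng.normal(size=(n_total, n_total, d_2e))
w_unmixed = rng.normal(size=(d_1e, d_out))
w_mixed = rng.normal(size=(2 * d_1e, d_out))
w_2e = rng.normal(size=(2 * d_2e, d_out))

out = np.concatenate(
    one_electron_linear_part(in_1e, in_2e, spin_split, w_unmixed, w_mixed, w_2e),
    axis=-2,
)

# Swapping the two spin-up particles swaps the matching output rows.
perm = np.array([1, 0, 2, 3, 4])
out_perm = np.concatenate(
    one_electron_linear_part(
        in_1e[perm], in_2e[perm][:, perm], spin_split, w_unmixed, w_mixed, w_2e
    ),
    axis=-2,
)
```

The final comparison, `out_perm` against `out[perm]`, is the equivariance property: permuting particles within a spin permutes the output rows in the same way.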

Parameters:

Name Type Description Default
in_1e Array

array of shape (..., n_total, d_1e)

required
in_2e Array

array of shape (..., n_total, n_total, d_2e). Defaults to None.

None

Returns:

Type Description
ndarray

Array of shape (..., n_total, self.ndense), the output one-electron stream

Source code in vmcnet/models/equivariance.py
def __call__(  # type: ignore[override]
    self, in_1e: Array, in_2e: Array = None
) -> Array:
    """Add dense outputs on unmixed, mixed, and 2e terms to get the 1e output.

    This implementation breaks the one-electron stream into three parts:
        1) the unmixed one-particle part, which is a linear transformation applied
            in parallel for each particle to the inputs
        2) the mixed one-particle part, which is a linear transformation applied to
            the averages of the inputs (concatenated over spin)
        3) the two-particle part, which is a linear transformation applied in
            parallel for each particle to the average of the input interactions
            between that particle and all the other particles.

    For 1), we take `in_1e` of shape (..., n_total, d_1e), batch apply a linear
    transformation to get (..., n_total, d'), and split over the spins i to get
    [i: (..., n[i], d')].

    For 2), we split `in_1e` over the spins along the particle axis to get
    [i: (..., n[i], d_1e)], average over each spin to get [i: (..., 1, d_1e)],
    concatenate all averages for each spin to get [i: (..., 1, d_1e * nspins)], and
    apply a linear transformation to get [i: (..., 1, d')].

    For 3) we split in_2e of shape (..., n_total, n_total, d_2e) over the spins
    along a particle axis to get [i: (..., n[i], n_total, d_2e)], average over the
    other particle axis to get [i: [j: (..., n[i], d_2e)]], concatenate the averages
    for each spin to get [i: (..., n[i], d_2e * nspins)], and apply a linear
    transformation to get [i: (..., n[i], d')].

    Finally, for each spin, we add the three parts, each equivariant or symmetric,
    to get a final equivariant linear transformation of the inputs, to which a
    non-linearity is then applied and a skip connection optionally added.

    Args:
        in_1e (Array): array of shape (..., n_total, d_1e)
        in_2e (Array, optional): array of shape (..., n_total, n_total, d_2e).
            Defaults to None.

    Returns:
        Array of shape (..., n_total, self.ndense), the output one-electron
        stream
    """
    dense_unmixed = self._unmixed_dense(in_1e)
    dense_unmixed_split = jnp.split(dense_unmixed, self.spin_split, axis=-2)

    split_1e_means = _split_mean(in_1e, self.spin_split, axis=-2, keepdims=True)
    dense_mixed_split = self._compute_transformed_1e_means(split_1e_means)

    # adds the unmixed [i: (..., n[i], d')] to the mixed [i: (..., 1, d')] to get
    # an equivariant function. Without the two-electron mixing, this is a spinful
    # version of DeepSet's Lemma 3: https://arxiv.org/pdf/1703.06114.pdf
    dense_out = tree_sum(dense_unmixed_split, dense_mixed_split)

    if in_2e is not None:
        dense_2e_split = self._compute_transformed_2e_means(in_2e)
        dense_out = tree_sum(dense_out, dense_2e_split)

    dense_out_concat = jnp.concatenate(dense_out, axis=-2)
    nonlinear_out = self._activation_fn(dense_out_concat)

    if self.skip_connection and _valid_skip(in_1e, nonlinear_out):
        nonlinear_out = self.skip_connection_scale * (nonlinear_out + in_1e)

    return nonlinear_out

FermiNetOrbitalLayer (Module) dataclass

Make the FermiNet orbitals (parallel linear layers with exp decay envelopes).

Attributes:

Name Type Description
orbitals_split ParticleSplit

number of pieces to split the input equally, or specified sequence of locations to split along the 2nd-to-last axis. E.g., if nelec = 10, and orbitals_split = 2, then the input is split (5, 5). If nelec = 10, and orbitals_split = (2, 4), then the input is split at indices 2 and 4 into pieces of size (2, 2, 6) -- note when orbitals_split is a sequence, there will be one more split than the length of the sequence. In the original use-case of spin-1/2 particles, orbitals_split should be either the number 2 (for closed-shell systems) or should be a Sequence with length 1 whose element is less than the total number of electrons.

norbitals_per_split Sequence[int]

sequence of integers specifying the number of orbitals to create for each split. This determines the output shapes for each split, i.e. the outputs are shaped (..., split_size[i], norbitals[i])

kernel_initializer_linear WeightInitializer

kernel initializer for the linear part of the orbitals. Has signature (key, shape, dtype) -> Array

kernel_initializer_envelope_dim WeightInitializer

kernel initializer for the decay rate in the exponential envelopes. If isotropic_decay is True, then this initializes a single decay rate number per ion and orbital. If isotropic_decay is False, then this initializes a 3x3 matrix per ion and orbital. Has signature (key, shape, dtype) -> Array

kernel_initializer_envelope_ion WeightInitializer

kernel initializer for the linear combination over the ions of exponential envelopes. Has signature (key, shape, dtype) -> Array

bias_initializer_linear WeightInitializer

bias initializer for the linear part of the orbitals. Has signature (key, shape, dtype) -> Array

use_bias bool

whether to add a bias term to the linear part of the orbitals. Defaults to True.

isotropic_decay bool

whether the decay for each ion should be isotropic, giving envelopes of the form exp(-||a(r - R)||) for a number a, or anisotropic (w.r.t. the dimensions of the input), giving envelopes of the form exp(-||A(r - R)||) for a dxd matrix A. If True, the decay is isotropic.

setup(self)

Setup envelope kernel initializers.

Source code in vmcnet/models/equivariance.py
def setup(self):
    """Setup envelope kernel initializers."""
    # workaround MyPy's typing error for callable attribute, see
    # https://github.com/python/mypy/issues/708
    self._kernel_initializer_envelope_dim = self.kernel_initializer_envelope_dim
    self._kernel_initializer_envelope_ion = self.kernel_initializer_envelope_ion

__call__(self, x, r_ei=None) special

Apply a dense layer R -> R^n for each split and multiply by exp envelopes.

Parameters:

Name Type Description Default
x Array

array of shape (..., nelec, d)

required
r_ei Array

array of shape (..., nelec, nion, d)

None

Returns:

Type Description
[(..., nelec[i], self.norbitals_per_split[i])]

list of FermiNet orbital matrices computed from an output stream x and the electron-ion displacements r_ei. Here nelec[i] is the number of particles in the ith split. The exponential envelopes are computed only when r_ei is not None (so, when connected to FermiNetBackflow, when ion locations are specified). To output square matrices, say for composing with the determinant anti-symmetry, nelec[i] should be equal to self.norbitals_per_split[i].

Source code in vmcnet/models/equivariance.py
@flax.linen.compact
def __call__(  # type: ignore[override]
    self, x: Array, r_ei: Array = None
) -> ArrayList:
    """Apply a dense layer R -> R^n for each split and multiply by exp envelopes.

    Args:
        x (Array): array of shape (..., nelec, d)
        r_ei (Array): array of shape (..., nelec, nion, d)

    Returns:
        [(..., nelec[i], self.norbitals_per_split[i])]: list of FermiNet orbital
        matrices computed from an output stream x and the electron-ion displacements
        r_ei. Here nelec[i] is the number of particles in the ith split. The exponential
        envelopes are computed only when r_ei is not None (so, when connected to
        FermiNetBackflow, when ion locations are specified). To output square
        matrices, say for composing with the determinant anti-symmetry,
        nelec[i] should be equal to self.norbitals_per_split[i].
    """
    orbs = SplitDense(
        self.orbitals_split,
        self.norbitals_per_split,
        self.kernel_initializer_linear,
        self.bias_initializer_linear,
        use_bias=self.use_bias,
    )(x)
    if r_ei is not None:
        exp_envelopes = _compute_exponential_envelopes_all_splits(
            r_ei,
            self.orbitals_split,
            self.norbitals_per_split,
            self._kernel_initializer_envelope_dim,
            self._kernel_initializer_envelope_ion,
            self.isotropic_decay,
        )
        orbs = tree_prod(orbs, exp_envelopes)
    return orbs

FermiNetResidualBlock (Module) dataclass

A single residual block in the FermiNet equivariant part.

Combines the one-electron and two-electron streams.

Attributes:

Name Type Description
one_electron_layer Callable

function which takes in a previous one-electron stream output and two-electron stream output and mixes/transforms them to create a new one-electron stream output. Has the signature: (array of shape (..., n, d_1e), optional array of shape (..., n, n, d_2e)) -> array of shape (..., n, d_1e')

two_electron_layer Callable

function which takes in a previous two-electron stream output and batch applies a Dense layer along the last axis. Has the signature: array of shape (..., n, n, d_2e) -> array of shape (..., n, n, d_2e')

setup(self)

Setup called one- and two- electron layers.

Source code in vmcnet/models/equivariance.py
def setup(self):
    """Setup called one- and two- electron layers."""
    self._one_electron_layer = self.one_electron_layer
    self._two_electron_layer = self.two_electron_layer

__call__(self, in_1e, in_2e=None) special

Apply the one-electron layer and optionally the two-electron layer.

Parameters:

Name Type Description Default
in_1e Array

array of shape (..., n_total, d_1e)

required
in_2e Array

array of shape (..., n_total, n_total, d_2e). Defaults to None.

None

Returns:

Type Description
(Array, optional Array)

tuple of (out_1e, out_2e) where out_1e is the output from the one-electron layer and out_2e is the output of the two-electron stream

Source code in vmcnet/models/equivariance.py
def __call__(  # type: ignore[override]
    self, in_1e: Array, in_2e: Array = None
) -> Tuple[Array, Optional[Array]]:
    """Apply the one-electron layer and optionally the two-electron layer.

    Args:
        in_1e (Array): array of shape (..., n_total, d_1e)
        in_2e (Array, optional): array of shape (..., n_total, n_total, d_2e).
            Defaults to None.

    Returns:
        (Array, optional Array): tuple of (out_1e, out_2e) where out_1e
        is the output from the one-electron layer and out_2e is the output of the
        two-electron stream
    """
    out_1e = self._one_electron_layer(in_1e, in_2e)

    out_2e = in_2e
    if self.two_electron_layer is not None and in_2e is not None:
        out_2e = self._two_electron_layer(in_2e)

    return out_1e, out_2e

FermiNetTwoElectronLayer (Module) dataclass

A single layer in the two-electron stream of the FermiNet equivariant part.

Attributes:

Name Type Description
ndense int

number of dense nodes

kernel_initializer WeightInitializer

kernel initializer. Has signature (key, shape, dtype) -> Array

bias_initializer WeightInitializer

bias initializer. Has signature (key, shape, dtype) -> Array

activation_fn Activation

activation function. Has the signature Array -> Array (shape is preserved)

use_bias bool

whether to add a bias term. Defaults to True.

skip_connection bool

whether to add residual skip connections whenever the shapes of the input and output match. Defaults to True.

skip_connection_scale float

quantity to scale the final output by if a skip connection is added. Defaults to 1.0.

setup(self)

Setup Dense layer.

Source code in vmcnet/models/equivariance.py
def setup(self):
    """Setup Dense layer."""
    # workaround MyPy's typing error for callable attribute, see
    # https://github.com/python/mypy/issues/708
    self._activation_fn = self.activation_fn
    self._dense = Dense(
        self.ndense,
        kernel_init=self.kernel_initializer,
        bias_init=self.bias_initializer,
        use_bias=self.use_bias,
    )

__call__(self, x) special

Apply a Dense layer in parallel to all electron pairs.

The expected use-case of this is to batch apply a dense layer to an input x of shape (..., n_total, n_total, d), getting an output of shape (..., n_total, n_total, d'), and optionally adding a skip connection. The function itself is just a standard residual network layer.

Source code in vmcnet/models/equivariance.py
def __call__(self, x: Array) -> Array:  # type: ignore[override]
    """Apply a Dense layer in parallel to all electron pairs.

    The expected use-case of this is to batch apply a dense layer to an input x of
    shape (..., n_total, n_total, d), getting an output of shape
    (..., n_total, n_total, d'), and optionally adding a skip connection. The
    function itself is just a standard residual network layer.
    """
    dense_out = self._dense(x)
    nonlinear_out = self._activation_fn(dense_out)

    if self.skip_connection and _valid_skip(x, nonlinear_out):
        nonlinear_out = self.skip_connection_scale * (nonlinear_out + x)

    return nonlinear_out

SplitDense (Module) dataclass

Split input on the 2nd-to-last axis and apply unique Dense layers to each split.

Attributes:

Name Type Description
split ParticleSplit

number of pieces to split the input equally, or specified sequence of locations to split along the 2nd-to-last axis. E.g., if nelec = 10, and split = 2, then the input is split (5, 5). If nelec = 10, and split = (2, 4), then the input is split at indices 2 and 4 into pieces of size (2, 2, 6) -- note when split is a sequence, there will be one more split than the length of the sequence. In the original use-case of spin-1/2 particles, split should be either the number 2 (for closed-shell systems) or should be a Sequence with length 1 whose element is less than the total number of electrons.

ndense_per_split Sequence[int]

sequence of integers specifying the number of dense nodes in the unique dense layer applied to each split of the input. This determines the output shapes for each split, i.e. the outputs are shaped (..., split_size[i], ndense[i])

kernel_initializer WeightInitializer

kernel initializer. Has signature (key, shape, dtype) -> Array

bias_initializer WeightInitializer

bias initializer. Has signature (key, shape, dtype) -> Array. Defaults to random normal initialization.

use_bias bool

whether to add a bias term. Defaults to True.

register_kfac bool

whether to register the dense computations with KFAC. Defaults to True.

setup(self)

Set up the dense layers for each split.

Source code in vmcnet/models/equivariance.py
def setup(self):
    """Set up the dense layers for each split."""
    nsplits = get_nsplits(self.split)

    if len(self.ndense_per_split) != nsplits:
        raise ValueError(
            "Incorrect number of dense output shapes specified for number of "
            "splits, should be one shape per split: shapes {} specified for the "
            "given split {}".format(self.ndense_per_split, self.split)
        )

    self._dense_layers = [
        Dense(
            self.ndense_per_split[i],
            kernel_init=self.kernel_initializer,
            bias_init=self.bias_initializer,
            use_bias=self.use_bias,
            register_kfac=self.register_kfac,
        )
        for i in range(nsplits)
    ]

__call__(self, x) special

Split the input and apply a dense layer to each split.

Parameters:

Name Type Description Default
x Array

array of shape (..., n, d)

required

Returns:

Type Description
[(..., n[i], self.ndense_per_split[i])]

list of length nsplits, where nsplits is the number of splits created by jnp.split(x, self.split, axis=-2), and the ith entry of the output is the ith split transformed by a dense layer with self.ndense_per_split[i] nodes.

Source code in vmcnet/models/equivariance.py
def __call__(self, x: Array) -> ArrayList:  # type: ignore[override]
    """Split the input and apply a dense layer to each split.

    Args:
        x (Array): array of shape (..., n, d)

    Returns:
        [(..., n[i], self.ndense_per_split[i])]: list of length nsplits, where
        nsplits is the number of splits created by
        jnp.split(x, self.split, axis=-2), and the ith entry of the output is the
        ith split transformed by a dense layer with self.ndense_per_split[i] nodes.
    """
    x_split = jnp.split(x, self.split, axis=-2)
    return [self._dense_layers[i](split) for i, split in enumerate(x_split)]
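
The split-then-dense pattern above can be mimicked without Flax. In this sketch the function name, the explicit kernel and bias lists, and the sizes are all assumptions for illustration; NumPy stands in for jax.numpy.

```python
import numpy as np


def split_dense(x, split, kernels, biases):
    """Apply a separate affine map to each split of x along axis -2."""
    pieces = np.split(x, split, axis=-2)
    if len(kernels) != len(pieces):
        raise ValueError("need one kernel per split")
    return [piece @ k + b for piece, k, b in zip(pieces, kernels, biases)]


rng = np.random.default_rng(0)
n, d = 7, 3
ndense_per_split = (3, 4)  # one output width per split
x = rng.normal(size=(n, d))
kernels = [rng.normal(size=(d, nd)) for nd in ndense_per_split]
biases = [np.zeros(nd) for nd in ndense_per_split]

outputs = split_dense(x, (3,), kernels, biases)
output_shapes = [out.shape for out in outputs]
```

Choosing each output width equal to the size of its split, as here, gives square outputs of shape (3, 3) and (4, 4), which is the case needed when the outputs feed determinants.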

compute_input_streams(elec_pos, ion_pos=None, include_2e_stream=True, include_ei_norm=True, include_ee_norm=True)

Create input streams with electron and optionally ion data.

If ion_pos is given, computes the electron-ion displacements (i.e. nuclear coordinates) and concatenates/flattens them along the ion dimension. If include_ei_norm is True, then the distances are also concatenated, so the map is elec_pos = (..., nelec, d) -> input_1e = (..., nelec, nion * (d + 1)).

If include_2e_stream is True, then a two-electron stream of shape (..., nelec, nelec, d) is also computed and returned (otherwise None is returned). If include_ee_norm is True, then this becomes (..., nelec, nelec, d + 1) by concatenating pairwise distances onto the stream.

Parameters:

Name Type Description Default
elec_pos Array

electron positions of shape (..., nelec, d)

required
ion_pos Array

locations of (stationary) ions to compute relative electron positions, 2-d array of shape (nion, d). Defaults to None.

None
include_2e_stream bool

whether to compute pairwise electron displacements/distances. Defaults to True.

True
include_ei_norm bool

whether to include electron-ion distances in the one-electron input. Defaults to True.

True
include_ee_norm bool

whether to include electron-electron distances in the two-electron input. Defaults to True.

True

Returns:

Type Description
(Array, Optional[Array], Optional[Array], Optional[Array])

first output: one-electron input of shape (..., nelec, d'), where d' = d if ion_pos is None, d' = nion * d if ion_pos is given and include_ei_norm is False, and d' = nion * (d + 1) if ion_pos is given and include_ei_norm is True.

second output: two-electron input of shape (..., nelec, nelec, d'), where d' = d if include_ee_norm is False, and d' = d + 1 if include_ee_norm is True

third output: electron-ion displacements of shape (..., nelec, nion, d)

fourth output: electron-electron displacements of shape (..., nelec, nelec, d)

If include_2e_stream is False, then the second and fourth outputs are None. If ion_pos is None, then the third output is None.
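
The shapes listed above can be reproduced with a short sketch. This is a hypothetical re-creation under stated assumptions, not the library's implementation: the function name, the sign convention of the displacements, and the order in which displacements and distances are concatenated are all guesses. NumPy stands in for jax.numpy.

```python
import numpy as np


def toy_input_streams(elec_pos, ion_pos=None, include_2e_stream=True,
                      include_ei_norm=True, include_ee_norm=True):
    """Hypothetical re-creation of the documented input/output shapes."""
    input_1e, r_ei = elec_pos, None
    if ion_pos is not None:
        # (..., nelec, 1, d) - (nion, d) -> (..., nelec, nion, d)
        r_ei = elec_pos[..., None, :] - ion_pos
        input_1e = r_ei
        if include_ei_norm:
            norms = np.linalg.norm(r_ei, axis=-1, keepdims=True)
            input_1e = np.concatenate([r_ei, norms], axis=-1)
        # Flatten the ion and feature axes together.
        input_1e = input_1e.reshape(input_1e.shape[:-2] + (-1,))

    input_2e, r_ee = None, None
    if include_2e_stream:
        # Pairwise displacements of shape (..., nelec, nelec, d).
        r_ee = elec_pos[..., :, None, :] - elec_pos[..., None, :, :]
        input_2e = r_ee
        if include_ee_norm:
            norms = np.linalg.norm(r_ee, axis=-1, keepdims=True)
            input_2e = np.concatenate([r_ee, norms], axis=-1)
    return input_1e, input_2e, r_ei, r_ee


rng = np.random.default_rng(0)
elec_pos = rng.normal(size=(8, 4, 3))  # batch of 8, nelec = 4, d = 3
ion_pos = rng.normal(size=(2, 3))  # nion = 2

in_1e, in_2e, r_ei, r_ee = toy_input_streams(elec_pos, ion_pos)
only_1e = toy_input_streams(elec_pos, include_2e_stream=False)
```

With nion = 2 and d = 3, the one-electron input has last dimension nion * (d + 1) = 8 and the two-electron input has last dimension d + 1 = 4; without ions and without the two-electron stream, only the first output is not None.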

Source code in vmcnet/models/equivariance.py
def compute_input_streams(
    elec_pos: Array,
    ion_pos: Optional[Array] = None,
    include_2e_stream: bool = True,
    include_ei_norm: bool = True,
    include_ee_norm: bool = True,
) -> InputStreams:
    """Create input streams with electron and optionally ion data.

    If `ion_pos` is given, computes the electron-ion displacements (i.e. electron
    positions relative to each ion) and concatenates/flattens them along the ion
    dimension. If `include_ei_norm` is True, then the distances are also concatenated,
    so the map is elec_pos = (..., nelec, d) -> input_1e = (..., nelec, nion * (d + 1)).

    If `include_2e_stream` is True, then a two-electron stream of shape
    (..., nelec, nelec, d) is also computed and returned (otherwise None is returned).
    If `include_ee_norm` is True, then this becomes (..., nelec, nelec, d + 1) by
    concatenating pairwise distances onto the stream.

    Args:
        elec_pos (Array): electron positions of shape (..., nelec, d)
        ion_pos (Array, optional): locations of (stationary) ions to compute
            relative electron positions, 2-d array of shape (nion, d). Defaults to None.
        include_2e_stream (bool, optional): whether to compute pairwise electron
            displacements/distances. Defaults to True.
        include_ei_norm (bool, optional): whether to include electron-ion distances in
            the one-electron input. Defaults to True.
        include_ee_norm (bool, optional): whether to include electron-electron distances
            in the two-electron input. Defaults to True.

    Returns:
        (
            Array,
            Optional[Array],
            Optional[Array],
            Optional[Array],
        ):

        first output: one-electron input of shape (..., nelec, d'), where
            d' = d if `ion_pos` is None,
            d' = nion * d if `ion_pos` is given and `include_ei_norm` is False, and
            d' = nion * (d + 1) if `ion_pos` is given and `include_ei_norm` is True.

        second output: two-electron input of shape (..., nelec, nelec, d'), where
            d' = d if `include_ee_norm` is False, and
            d' = d + 1 if `include_ee_norm` is True

        third output: electron-ion displacements of shape (..., nelec, nion, d)

        fourth output: electron-electron displacements of shape (..., nelec, nelec, d)

        If `include_2e_stream` is False, then the second and fourth outputs are None. If
        `ion_pos` is None, then the third output is None.
    """
    input_1e, r_ei = compute_electron_ion(elec_pos, ion_pos, include_ei_norm)
    input_2e = None
    r_ee = None
    if include_2e_stream:
        input_2e, r_ee = compute_electron_electron(elec_pos, include_ee_norm)
    return input_1e, input_2e, r_ei, r_ee

compute_electron_ion(elec_pos, ion_pos=None, include_ei_norm=True)

Compute electron-ion displacements and optionally add on the distances.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| elec_pos | Array | electron positions of shape (..., nelec, d) | required |
| ion_pos | Array | locations of (stationary) ions to compute relative electron positions, 2-d array of shape (nion, d). Defaults to None. | None |
| include_ei_norm | bool | whether to include electron-ion distances in the one-electron input. Defaults to True. | True |

Returns:

Type Description
(Array, Optional[Array])

first output: one-electron input of shape (..., nelec, d'), where d' = d if ion_pos is None, d' = nion * d if ion_pos is given and include_ei_norm is False, and d' = nion * (d + 1) if ion_pos is given and include_ei_norm is True.

second output: electron-ion displacements of shape (..., nelec, nion, d)

If ion_pos is None, then the second output is None.
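The three cases for d' can be illustrated with a small JAX sketch. It is an illustration written to follow the shapes documented above, not the library's code: `displacements` and `electron_ion_sketch` are hypothetical names, and the broadcasting used for the displacements is an assumption about how such a helper might work.

```python
import jax.numpy as jnp


def displacements(x, y):
    """Hypothetical helper: pairwise x_i - y_j with shape (..., n, m, d)."""
    return jnp.expand_dims(x, axis=-2) - jnp.expand_dims(y, axis=-3)


def electron_ion_sketch(elec_pos, ion_pos=None, include_ei_norm=True):
    """Sketch of the documented behavior; returns (input_1e, r_ei)."""
    if ion_pos is None:
        return elec_pos, None
    r_ei = displacements(elec_pos, ion_pos)  # (..., nelec, nion, d)
    feats = r_ei
    if include_ei_norm:
        dist = jnp.linalg.norm(r_ei, axis=-1, keepdims=True)  # (..., nelec, nion, 1)
        feats = jnp.concatenate([r_ei, dist], axis=-1)  # (..., nelec, nion, d + 1)
    # Merge the ion axis and the feature axis into a single axis of size d'.
    input_1e = jnp.reshape(feats, feats.shape[:-2] + (-1,))
    return input_1e, r_ei


# batch of 2 walkers, nelec = 4, d = 3
elec_pos = jnp.arange(2 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 3)
# nion = 2
ion_pos = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])

no_ions, r_none = electron_ion_sketch(elec_pos)
without_norm, _ = electron_ion_sketch(elec_pos, ion_pos, include_ei_norm=False)
with_norm, r_ei = electron_ion_sketch(elec_pos, ion_pos)
print(no_ions.shape, without_norm.shape, with_norm.shape, r_ei.shape)
# (2, 4, 3) (2, 4, 6) (2, 4, 8) (2, 4, 2, 3)
```

After the reshape, the features for each electron are ordered ion by ion: the d displacement components for the first ion, then its distance (if included), then the same for the next ion.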

Source code in vmcnet/models/equivariance.py
def compute_electron_ion(
    elec_pos: Array, ion_pos: Optional[Array] = None, include_ei_norm: bool = True
) -> Tuple[Array, Optional[Array]]:
    """Compute electron-ion displacements and optionally add on the distances.

    Args:
        elec_pos (Array): electron positions of shape (..., nelec, d)
        ion_pos (Array, optional): locations of (stationary) ions to compute
            relative electron positions, 2-d array of shape (nion, d). Defaults to None.
        include_ei_norm (bool, optional): whether to include electron-ion distances in
            the one-electron input. Defaults to True.

    Returns:
        (Array, Optional[Array]):

        first output: one-electron input of shape (..., nelec, d'), where
            d' = d if `ion_pos` is None,
            d' = nion * d if `ion_pos` is given and `include_ei_norm` is False, and
            d' = nion * (d + 1) if `ion_pos` is given and `include_ei_norm` is True.

        second output: electron-ion displacements of shape (..., nelec, nion, d)

        If `ion_pos` is None, then the second output is None.
    """
    r_ei = None
    input_1e = elec_pos
    if ion_pos is not None:
        r_ei = _compute_displacements(input_1e, ion_pos)
        input_1e = r_ei
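        if include_ei_norm:
            input_norm = jnp.linalg.norm(input_1e, axis=-1, keepdims=True)
            input_1e = jnp.concatenate([input_1e, input_norm], axis=-1)
        # Flatten the ion axis in both cases so that the output has the documented
        # shape (..., nelec, nion * d) or (..., nelec, nion * (d + 1)).
        input_1e = jnp.reshape(input_1e, input_1e.shape[:-2] + (-1,))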
    return input_1e, r_ei

compute_electron_electron(elec_pos, include_ee_norm=True)

Compute electron-electron displacements and optionally add on the distances.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| elec_pos | Array | electron positions of shape (..., nelec, d) | required |
| include_ee_norm | bool | whether to include electron-electron distances in the two-electron input. Defaults to True. | True |

Returns:

Type Description
(Array, Array)

first output: two-electron input of shape (..., nelec, nelec, d'), where d' = d if include_ee_norm is False, and d' = d + 1 if include_ee_norm is True

second output: electron-electron displacements of shape (..., nelec, nelec, d)
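The pairwise displacements are exactly zero on the diagonal (each electron paired with itself), and the derivative of a Euclidean norm is undefined at zero, which typically shows up as NaN gradients under automatic differentiation. The sketch below shows one way to keep the diagonal safe. It is an assumption about the intent suggested by the name compute_ee_norm_with_safe_diag, and the library's implementation may differ; `safe_diag_norm` and `electron_electron_sketch` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def safe_diag_norm(r_ee):
    """Pairwise distances of shape (..., nelec, nelec, 1) with a zero diagonal.

    Assumption: illustrative only, not the library's implementation.
    """
    nelec = r_ee.shape[-2]
    eye = jnp.eye(nelec)[..., None]  # (nelec, nelec, 1)
    # Shift the diagonal displacements away from zero before taking the norm,
    # then mask the diagonal distances back to zero.
    shifted = r_ee + eye
    return jnp.linalg.norm(shifted, axis=-1, keepdims=True) * (1.0 - eye)


def electron_electron_sketch(elec_pos, include_ee_norm=True):
    """Sketch of the documented behavior; returns (input_2e, r_ee)."""
    # r_ee[..., i, j, :] = elec_pos[..., i, :] - elec_pos[..., j, :]
    r_ee = jnp.expand_dims(elec_pos, -2) - jnp.expand_dims(elec_pos, -3)
    input_2e = r_ee
    if include_ee_norm:
        input_2e = jnp.concatenate([r_ee, safe_diag_norm(r_ee)], axis=-1)
    return input_2e, r_ee


# nelec = 3, d = 3
elec_pos = jnp.array([[0.0, 0.0, 0.0], [3.0, 4.0, 0.0], [0.0, 0.0, 2.0]])
input_2e, r_ee = electron_electron_sketch(elec_pos)

# Differentiate the summed features with respect to the positions and check
# that every gradient entry is finite.
grad = jax.grad(lambda x: electron_electron_sketch(x)[0].sum())(elec_pos)
print(input_2e.shape, r_ee.shape, bool(jnp.all(jnp.isfinite(grad))))
```

In this example the distance between the first two electrons, stored at input_2e[0, 1, 3], is 5, and the diagonal distances are exactly 0.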

Source code in vmcnet/models/equivariance.py
def compute_electron_electron(
    elec_pos: Array, include_ee_norm: bool = True
) -> Tuple[Array, Array]:
    """Compute electron-electron displacements and optionally add on the distances.

    Args:
        elec_pos (Array): electron positions of shape (..., nelec, d)
        include_ee_norm (bool, optional): whether to include electron-electron distances
            in the two-electron input. Defaults to True.

    Returns:
        (Array, Array):

        first output: two-electron input of shape (..., nelec, nelec, d'), where
            d' = d if `include_ee_norm` is False, and
            d' = d + 1 if `include_ee_norm` is True

        second output: electron-electron displacements of shape (..., nelec, nelec, d)
    """
    r_ee = _compute_displacements(elec_pos, elec_pos)
    input_2e = r_ee
    if include_ee_norm:
        r_ee_norm = compute_ee_norm_with_safe_diag(r_ee)
        input_2e = jnp.concatenate([input_2e, r_ee_norm], axis=-1)
    return input_2e, r_ee