equivariance
Permutation equivariant functions.
DoublyEquivariantOrbitalLayer (Module)
dataclass
Equivariantly generate an orbital matrix corresponding to each input stream.
The calculation here is subtle, so it is worth explaining in some detail. Let the equivariant input vectors to this layer be y_i. This layer generates an orbital matrix M_p for each particle p, such that the (i,j)th element of M_p satisfies M_(p,i,j) = phi_j(y_p, y_i). This is essentially the usual orbital matrix formula M_(i,j) = phi_j(y_i), except with an added dependence on the particle index p, which allows us to generate a distinct matrix for each input particle. This construction gives a unique antisymmetric determinant D_p = det(M_p) for each input particle, which can then be the basis for an expressive antiequivariant layer.
If r_ei is provided in addition to the main inputs y_i, then an exponentially decaying envelope is also applied equally to every orbital matrix M_p in order to ensure that the orbital values decay to zero far from the ions.
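As a rough illustration of the construction described above, the following NumPy sketch builds M_(p,i,j) = phi_j(y_p, y_i) and checks that the per-particle determinants D_p are antiequivariant under a swap of two particles. The choice phi_j(y_p, y_i) = tanh(y_p . A_j + y_i . B_j), the weights A and B, and the function name are assumptions made for illustration; they are not the layer's actual parameterization.

```python
import numpy as np


def per_particle_orbital_matrices(y, A, B):
    """Return M with M[p, i, j] = tanh(y[p] @ A[:, j] + y[i] @ B[:, j])."""
    a = y @ A  # (n, norb): dependence on the particle index p
    b = y @ B  # (n, norb): dependence on the row index i
    return np.tanh(a[:, None, :] + b[None, :, :])  # (n, n, norb)


rng = np.random.default_rng(0)
n, d = 4, 3
y = rng.normal(size=(n, d))
A = rng.normal(size=(d, n))  # norb = n so that each M_p is square
B = rng.normal(size=(d, n))

M = per_particle_orbital_matrices(y, A, B)
D = np.linalg.det(M)  # one determinant per particle, shape (n,)

# Swap particles 0 and 1 (an odd permutation) and recompute.
perm = np.array([1, 0, 2, 3])
D_swapped = np.linalg.det(per_particle_orbital_matrices(y[perm], A, B))
```

Under the swap, the determinants are permuted along the particle index and pick up the sign of the permutation, i.e. D_swapped equals -D[perm].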
Attributes:
Name | Type | Description |
---|---|---|
orbitals_split |
ParticleSplit |
number of pieces to split the input equally,
or specified sequence of locations to split along the 2nd-to-last axis.
E.g., if nelec = 10, and |
norbitals_per_split |
Sequence[int] |
sequence of integers specifying the number of orbitals to create for each split. This determines the output shapes for each split, i.e. the outputs are shaped (..., split_size[i], norbitals[i]) |
kernel_initializer_linear |
WeightInitializer |
kernel initializer for the linear part of the orbitals. Has signature (key, shape, dtype) -> Array |
kernel_initializer_envelope_dim |
WeightInitializer |
kernel initializer for the
decay rate in the exponential envelopes. If |
kernel_initializer_envelope_ion |
WeightInitializer |
kernel initializer for the linear combination over the ions of exponential envelopes. Has signature (key, shape, dtype) -> Array |
bias_initializer_linear |
WeightInitializer |
bias initializer for the linear part of the orbitals. Has signature (key, shape, dtype) -> Array |
use_bias |
bool |
whether to add a bias term to the linear part of the orbitals. Defaults to True. |
isotropic_decay |
bool |
whether the decay for each ion should be anisotropic (w.r.t. the dimensions of the input), giving envelopes of the form exp(-||A(r - R)||) for a dxd matrix A, or isotropic, giving exp(-||a(r - R)||) for a number a. |
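The two envelope forms named in the isotropic_decay description can be compared directly. This is a minimal NumPy sketch for a single electron and a single ion; the function names are hypothetical, and the layer itself additionally takes a learned linear combination over the ions.

```python
import numpy as np


def anisotropic_envelope(r, R, A):
    """exp(-||A (r - R)||) for a d x d matrix A."""
    return np.exp(-np.linalg.norm(A @ (r - R)))


def isotropic_envelope(r, R, a):
    """exp(-||a (r - R)||) for a scalar a."""
    return np.exp(-np.linalg.norm(a * (r - R)))


R = np.zeros(3)  # ion position
r_near = np.array([0.5, 0.0, 0.0])
r_far = np.array([5.0, 0.0, 0.0])
a = 1.5

iso_near = isotropic_envelope(r_near, R, a)
iso_far = isotropic_envelope(r_far, R, a)
# With A = a * I the anisotropic form reduces to the isotropic one.
aniso_near = anisotropic_envelope(r_near, R, a * np.eye(3))
```

Both forms decay to zero as the electron moves away from the ion; the anisotropic form simply allows a different decay rate along each direction.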
setup(self)
Setup envelope kernel initializers.
Source code in vmcnet/models/equivariance.py
def setup(self):
    """Setup envelope kernel initializers."""
    # workaround MyPy's typing error for callable attribute, see
    # https://github.com/python/mypy/issues/708
    self._kernel_initializer_envelope_dim = self.kernel_initializer_envelope_dim
    self._kernel_initializer_envelope_ion = self.kernel_initializer_envelope_ion
__call__(self, x, r_ei=None)
special
Calculate an equivariant orbital matrix for each input particle.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x |
Array |
array of shape (..., nelec, d) |
required |
r_ei |
Array |
array of shape (..., nelec, nion, d) |
None |
Returns:
Type | Description |
---|---|
(ArrayList) |
list of length nsplits of arrays of shape (..., nelec[i], nelec[i], self.norbitals_per_split[i]). Here nelec[i] is the number of particles in the ith split. The output arrays have both their -2 and -3 axes equivariant with respect to the input particles. The exponential envelopes are computed only when r_ei is not None (so, when connected to FermiNetBackflow, when ion locations are specified). To output square matrices, say in order to be able to take antiequivariant per-particle determinants, nelec[i] should be equal to self.norbitals_per_split[i]. |
Source code in vmcnet/models/equivariance.py
@flax.linen.compact
def __call__(  # type: ignore[override]
    self, x: Array, r_ei: Array = None
) -> ArrayList:
    """Calculate an equivariant orbital matrix for each input particle.

    Args:
        x (Array): array of shape (..., nelec, d)
        r_ei (Array): array of shape (..., nelec, nion, d)

    Returns:
        (ArrayList): list of length nsplits of arrays of shape
        (..., nelec[i], nelec[i], self.norbitals_per_split[i]). Here nelec[i] is the
        number of particles in the ith split. The output arrays have both their -2
        and -3 axes equivariant with respect to the input particles. The exponential
        envelopes are computed only when r_ei is not None (so, when connected to
        FermiNetBackflow, when ion locations are specified). To output square
        matrices, say in order to be able to take antiequivariant per-particle
        determinants, nelec[i] should be equal to self.norbitals_per_split[i].
    """
    # split_x is a list of nsplits arrays of shape (..., nelec[i], d)
    split_x = jnp.split(x, self.orbitals_split, -2)
    # orbs is a list of nsplits arrays of shape
    # (..., nelec[i], nelec[i], norbitals[i])
    orbs = [
        self._get_orbital_matrices_one_split(x, self.norbitals_per_split[i])
        for (i, x) in enumerate(split_x)
    ]

    if r_ei is not None:
        exp_envelopes = _compute_exponential_envelopes_all_splits(
            r_ei,
            self.orbitals_split,
            self.norbitals_per_split,
            self._kernel_initializer_envelope_dim,
            self._kernel_initializer_envelope_ion,
            self.isotropic_decay,
        )
        # Envelope must be expanded to apply equally to each per-particle matrix.
        exp_envelopes = jax.tree_map(
            lambda x: jnp.expand_dims(x, axis=-3), exp_envelopes
        )
        orbs = tree_prod(orbs, exp_envelopes)
    return orbs
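The envelope expansion in DoublyEquivariantOrbitalLayer relies on broadcasting: a single envelope of shape (..., nelec[i], norbitals[i]) is given a length-1 axis at position -3 so that it multiplies every per-particle matrix equally. The shapes can be checked with a small NumPy sketch (np.expand_dims behaves like jnp.expand_dims here; the random arrays are placeholders, not real orbitals).

```python
import numpy as np

nelec, norb = 3, 3
rng = np.random.default_rng(0)

# One orbital matrix per particle p: orbs[p, i, j] = M_(p,i,j).
orbs = rng.normal(size=(nelec, nelec, norb))
# A single envelope shared by all particles: envelope[i, j].
envelope = rng.uniform(size=(nelec, norb))

# Insert a length-1 axis at position -3 so the envelope broadcasts over p.
envelope_expanded = np.expand_dims(envelope, axis=-3)  # (1, nelec, norb)
enveloped_orbs = orbs * envelope_expanded  # (nelec, nelec, norb)
```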
FermiNetBackflow (Module)
dataclass
The FermiNet equivariant part up until, but not including, the orbitals.
Repeated composition of the residual blocks in the parallel one-electron and two-electron streams.
Attributes:
Name | Type | Description |
---|---|---|
residual_blocks |
Sequence |
sequence of callable residual blocks which apply the one- and two-electron layers. Each residual block has the signature (in_1e, optional in_2e) -> (out_1e, optional out_2e), where in_1e has shape (..., n, d_1e), out_1e has shape (..., n, d_1e'), in_2e has shape (..., n, n, d_2e), and out_2e has shape (..., n, n, d_2e'). |
setup(self)
Setup called residual blocks.
Source code in vmcnet/models/equivariance.py
def setup(self):
    """Setup called residual blocks."""
    self._residual_block_list = [block for block in self.residual_blocks]
__call__(self, stream_1e, stream_2e=None)
special
Iteratively apply residual blocks to Ferminet input streams.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
stream_1e |
Array |
one-electron input stream of shape (..., nelec, d1). |
required |
stream_2e |
Array |
two-electron input of shape (..., nelec, nelec, d2). |
None |
Returns:
Type | Description |
---|---|
(Array) |
the output of the one-electron stream after applying self.residual_blocks to the initial input streams. |
Source code in vmcnet/models/equivariance.py
def __call__(  # type: ignore[override]
    self,
    stream_1e: Array,
    stream_2e: Optional[Array] = None,
) -> Array:
    """Iteratively apply residual blocks to Ferminet input streams.

    Args:
        stream_1e (Array): one-electron input stream of shape
            (..., nelec, d1).
        stream_2e (Array, optional): two-electron input of shape
            (..., nelec, nelec, d2).

    Returns:
        (Array): the output of the one-electron stream after applying
        self.residual_blocks to the initial input streams.
    """
    for block in self._residual_block_list:
        stream_1e, stream_2e = block(stream_1e, stream_2e)
    return stream_1e
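The control flow of FermiNetBackflow can be mimicked without Flax. In this sketch, make_block is a hypothetical stand-in for a residual block with the documented signature (in_1e, optional in_2e) -> (out_1e, optional out_2e); its arithmetic is arbitrary and only serves to show how both streams are threaded through the blocks while only the one-electron stream is returned.

```python
import numpy as np


def make_block(scale):
    """Hypothetical residual block: (in_1e, in_2e) -> (out_1e, out_2e)."""

    def block(in_1e, in_2e=None):
        out_1e = in_1e + scale * np.tanh(in_1e)
        if in_2e is not None:
            # Mix a per-particle average of the two-electron stream into 1e.
            out_1e = out_1e + in_2e.mean(axis=-2).sum(axis=-1, keepdims=True)
            in_2e = in_2e + scale * np.tanh(in_2e)
        return out_1e, in_2e

    return block


def backflow(residual_blocks, stream_1e, stream_2e=None):
    """Thread both streams through every block; return only the 1e stream."""
    for block in residual_blocks:
        stream_1e, stream_2e = block(stream_1e, stream_2e)
    return stream_1e


rng = np.random.default_rng(0)
stream_1e = rng.normal(size=(4, 5))  # (nelec, d1)
stream_2e = rng.normal(size=(4, 4, 2))  # (nelec, nelec, d2)
blocks = [make_block(0.5), make_block(0.25)]

out_with_2e = backflow(blocks, stream_1e, stream_2e)
out_without_2e = backflow(blocks, stream_1e)
```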
FermiNetOneElectronLayer (Module)
dataclass
A single layer in the one-electron stream of the FermiNet equivariant part.
Attributes:
Name | Type | Description |
---|---|---|
spin_split |
ParticleSplit |
number of spins to split the input equally,
or specified sequence of locations to split along the 2nd-to-last axis.
E.g., if nelec = 10, and |
ndense |
int |
number of dense nodes |
kernel_initializer_unmixed |
WeightInitializer |
kernel initializer for the unmixed part of the one-electron stream. This initializes the part of the dense kernel which multiplies the previous one-electron stream output. Has signature (key, shape, dtype) -> Array |
kernel_initializer_mixed |
WeightInitializer |
kernel initializer for the mixed part of the one-electron stream. This initializes the part of the dense kernel which multiplies the average of the previous one-electron stream output. Has signature (key, shape, dtype) -> Array |
kernel_initializer_2e |
WeightInitializer |
kernel initializer for the two-electron part of the one-electron stream. This initializes the part of the dense kernel which multiplies the average of the previous two-electron stream which is mixed into the one-electron stream. Has signature (key, shape, dtype) -> Array |
bias_initializer |
WeightInitializer |
bias initializer. Has signature (key, shape, dtype) -> Array |
activation_fn |
Activation |
activation function. Has the signature Array -> Array (shape is preserved) |
use_bias |
bool |
whether to add a bias term. Defaults to True. |
skip_connection |
bool |
whether to add residual skip connections whenever the shapes of the input and output match. Defaults to True. |
skip_connection_scale |
float |
quantity to scale the final output by if a skip connection is added. Defaults to 1.0. |
cyclic_spins |
bool |
whether the concatenation in the one-electron stream should satisfy a cyclic equivariance structure, i.e. if there are three spins (1, 2, 3), then in the mixed part of the stream, after averaging but before the linear transformation, cyclic equivariance means the inputs are [(1, 2, 3), (2, 3, 1), (3, 1, 2)]. If False, then the inputs are [(1, 2, 3), (1, 2, 3), (1, 2, 3)] (as in the original FermiNet). When there are only two spins (spin-1/2 case), then this is equivalent to true spin equivariance. Defaults to False (original FermiNet). |
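The two orderings described for cyclic_spins can be reproduced with plain Python. The helper name mixed_stream_inputs is hypothetical, and the integers stand in for the per-spin averages that the layer would concatenate.

```python
def mixed_stream_inputs(spin_averages, cyclic_spins):
    """Order the per-spin averages that are concatenated for each spin block."""
    nspins = len(spin_averages)
    if not cyclic_spins:
        # Original FermiNet: every spin block sees the same ordering.
        return [tuple(spin_averages) for _ in range(nspins)]
    # Cyclic: spin block i sees the averages rotated left by i positions.
    return [
        tuple(spin_averages[i:] + spin_averages[:i]) for i in range(nspins)
    ]


cyclic = mixed_stream_inputs([1, 2, 3], cyclic_spins=True)
non_cyclic = mixed_stream_inputs([1, 2, 3], cyclic_spins=False)
```

Here cyclic is [(1, 2, 3), (2, 3, 1), (3, 1, 2)] and non_cyclic is three copies of (1, 2, 3), matching the description.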
setup(self)
Setup Dense layers.
Source code in vmcnet/models/equivariance.py
def setup(self):
    """Setup Dense layers."""
    # workaround MyPy's typing error for callable attribute, see
    # https://github.com/python/mypy/issues/708
    self._activation_fn = self.activation_fn

    self._unmixed_dense = Dense(
        self.ndense,
        kernel_init=self.kernel_initializer_unmixed,
        bias_init=self.bias_initializer,
        use_bias=self.use_bias,
    )
    self._mixed_dense = Dense(
        self.ndense, kernel_init=self.kernel_initializer_mixed, use_bias=False
    )
    self._dense_2e = Dense(
        self.ndense, kernel_init=self.kernel_initializer_2e, use_bias=False
    )
__call__(self, in_1e, in_2e=None)
special
Add dense outputs on unmixed, mixed, and 2e terms to get the 1e output.
This implementation breaks the one-electron stream into three parts:
1) the unmixed one-particle part, which is a linear transformation applied in parallel for each particle to the inputs
2) the mixed one-particle part, which is a linear transformation applied to the averages of the inputs (concatenated over spin)
3) the two-particle part, which is a linear transformation applied in parallel for each particle to the average of the input interactions between that particle and all the other particles.
For 1), we take in_1e of shape (..., n_total, d_1e), batch apply a linear transformation to get (..., n_total, d'), and split over the spins i to get [i: (..., n[i], d')].
For 2), we split in_1e over the spins along the particle axis to get [i: (..., n[i], d_1e)], average over each spin to get [i: (..., 1, d_1e)], concatenate all averages for each spin to get [i: (..., 1, d_1e * nspins)], and apply a linear transformation to get [i: (..., 1, d')].
For 3), we split in_2e of shape (..., n_total, n_total, d_2e) over the spins along a particle axis to get [i: (..., n[i], n_total, d_2e)], average over the other particle axis to get [i: [j: (..., n[i], d_2e)]], concatenate the averages for each spin to get [i: (..., n[i], d_2e * nspins)], and apply a linear transformation to get [i: (..., n[i], d')].
Finally, for each spin, we add the three parts, each equivariant or symmetric, to get a final equivariant linear transformation of the inputs, to which a non-linearity is then applied and a skip connection optionally added.
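The three parts can be sketched in NumPy to check the equivariance claim numerically. This is a simplified illustration under stated assumptions: no bias, tanh as the activation, no skip connection, non-cyclic spin concatenation, and hypothetical weight names (w_unmixed, w_mixed, w_2e). The two-electron part is computed for all particles at once and split over spins afterwards, which gives the same result as splitting first.

```python
import numpy as np


def one_electron_layer(in_1e, in_2e, spin_split, w_unmixed, w_mixed, w_2e):
    """Sum of the unmixed, mixed, and two-electron parts, followed by tanh."""
    # 1) Unmixed: per-particle linear map, split over spins -> [i: (n[i], d')].
    unmixed = np.split(in_1e @ w_unmixed, spin_split, axis=-2)

    # 2) Mixed: average over each spin, concatenate the averages, linear map.
    split_1e = np.split(in_1e, spin_split, axis=-2)
    means_1e = [s.mean(axis=-2, keepdims=True) for s in split_1e]
    mixed = np.concatenate(means_1e, axis=-1) @ w_mixed  # (1, d')

    # 3) Two-electron: for every particle, average its interactions with the
    # particles of each spin, concatenate over spins, linear map, split.
    split_2e = np.split(in_2e, spin_split, axis=-2)  # [j: (n_total, n[j], d_2e)]
    means_2e = [s.mean(axis=-2) for s in split_2e]  # [j: (n_total, d_2e)]
    dense_2e = np.concatenate(means_2e, axis=-1) @ w_2e  # (n_total, d')
    dense_2e = np.split(dense_2e, spin_split, axis=-2)

    summed = [u + mixed + t for u, t in zip(unmixed, dense_2e)]
    return np.tanh(np.concatenate(summed, axis=-2))


rng = np.random.default_rng(0)
n_total, d_1e, d_2e, d_out, nspins = 4, 3, 2, 5, 2
in_1e = rng.normal(size=(n_total, d_1e))
in_2e = rng.normal(size=(n_total, n_total, d_2e))
weights = (
    rng.normal(size=(d_1e, d_out)),
    rng.normal(size=(d_1e * nspins, d_out)),
    rng.normal(size=(d_2e * nspins, d_out)),
)

out = one_electron_layer(in_1e, in_2e, nspins, *weights)

# Swap the two particles of the first spin in both streams.
perm = np.array([1, 0, 2, 3])
out_perm = one_electron_layer(in_1e[perm], in_2e[perm][:, perm], nspins, *weights)
```

Permuting particles within a spin permutes the output rows in the same way, so out_perm equals out[perm].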
Parameters:
Name | Type | Description | Default |
---|---|---|---|
in_1e |
Array |
array of shape (..., n_total, d_1e) |
required |
in_2e |
Array |
array of shape (..., n_total, n_total, d_2e). Defaults to None. |
None |
Returns:
Type | Description |
---|---|
ndarray |
Array of shape (..., n_total, self.ndense), the output one-electron stream |
Source code in vmcnet/models/equivariance.py
def __call__(  # type: ignore[override]
    self, in_1e: Array, in_2e: Array = None
) -> Array:
    """Add dense outputs on unmixed, mixed, and 2e terms to get the 1e output.

    This implementation breaks the one-electron stream into three parts:
        1) the unmixed one-particle part, which is a linear transformation applied
            in parallel for each particle to the inputs
        2) the mixed one-particle part, which is a linear transformation applied to
            the averages of the inputs (concatenated over spin)
        3) the two-particle part, which is a linear transformation applied in
            parallel for each particle to the average of the input interactions
            between that particle and all the other particles.

    For 1), we take `in_1e` of shape (..., n_total, d_1e), batch apply a linear
    transformation to get (..., n_total, d'), and split over the spins i to get
    [i: (..., n[i], d')].

    For 2), we split `in_1e` over the spins along the particle axis to get
    [i: (..., n[i], d_1e)], average over each spin to get [i: (..., 1, d_1e)],
    concatenate all averages for each spin to get [i: (..., 1, d_1e * nspins)], and
    apply a linear transformation to get [i: (..., 1, d')].

    For 3) we split in_2e of shape (..., n_total, n_total, d_2e) over the spins
    along a particle axis to get [i: (..., n[i], n_total, d_2e)], average over the
    other particle axis to get [i: [j: (..., n[i], d_2e)]], concatenate the averages
    for each spin to get [i: (..., n[i], d_2e * nspins)], and apply a linear
    transformation to get [i: (..., n[i], d')].

    Finally, for each spin, we add the three parts, each equivariant or symmetric,
    to get a final equivariant linear transformation of the inputs, to which a
    non-linearity is then applied and a skip connection optionally added.

    Args:
        in_1e (Array): array of shape (..., n_total, d_1e)
        in_2e (Array, optional): array of shape (..., n_total, n_total, d_2e).
            Defaults to None.

    Returns:
        Array of shape (..., n_total, self.ndense), the output one-electron
        stream
    """
    dense_unmixed = self._unmixed_dense(in_1e)
    dense_unmixed_split = jnp.split(dense_unmixed, self.spin_split, axis=-2)

    split_1e_means = _split_mean(in_1e, self.spin_split, axis=-2, keepdims=True)
    dense_mixed_split = self._compute_transformed_1e_means(split_1e_means)

    # adds the unmixed [i: (..., n[i], d')] to the mixed [i: (..., 1, d')] to get
    # an equivariant function. Without the two-electron mixing, this is a spinful
    # version of DeepSet's Lemma 3: https://arxiv.org/pdf/1703.06114.pdf
    dense_out = tree_sum(dense_unmixed_split, dense_mixed_split)

    if in_2e is not None:
        dense_2e_split = self._compute_transformed_2e_means(in_2e)
        dense_out = tree_sum(dense_out, dense_2e_split)

    dense_out_concat = jnp.concatenate(dense_out, axis=-2)
    nonlinear_out = self._activation_fn(dense_out_concat)

    if self.skip_connection and _valid_skip(in_1e, nonlinear_out):
        nonlinear_out = self.skip_connection_scale * (nonlinear_out + in_1e)

    return nonlinear_out
FermiNetOrbitalLayer (Module)
dataclass
Make the FermiNet orbitals (parallel linear layers with exp decay envelopes).
Attributes:
Name | Type | Description |
---|---|---|
orbitals_split |
ParticleSplit |
number of pieces to split the input equally,
or specified sequence of locations to split along the 2nd-to-last axis.
E.g., if nelec = 10, and |
norbitals_per_split |
Sequence[int] |
sequence of integers specifying the number of orbitals to create for each split. This determines the output shapes for each split, i.e. the outputs are shaped (..., split_size[i], norbitals[i]) |
kernel_initializer_linear |
WeightInitializer |
kernel initializer for the linear part of the orbitals. Has signature (key, shape, dtype) -> Array |
kernel_initializer_envelope_dim |
WeightInitializer |
kernel initializer for the
decay rate in the exponential envelopes. If |
kernel_initializer_envelope_ion |
WeightInitializer |
kernel initializer for the linear combination over the ions of exponential envelopes. Has signature (key, shape, dtype) -> Array |
bias_initializer_linear |
WeightInitializer |
bias initializer for the linear part of the orbitals. Has signature (key, shape, dtype) -> Array |
use_bias |
bool |
whether to add a bias term to the linear part of the orbitals. Defaults to True. |
isotropic_decay |
bool |
whether the decay for each ion should be anisotropic (w.r.t. the dimensions of the input), giving envelopes of the form exp(-||A(r - R)||) for a dxd matrix A, or isotropic, giving exp(-||a(r - R)||) for a number a. |
setup(self)
Setup envelope kernel initializers.
Source code in vmcnet/models/equivariance.py
def setup(self):
    """Setup envelope kernel initializers."""
    # workaround MyPy's typing error for callable attribute, see
    # https://github.com/python/mypy/issues/708
    self._kernel_initializer_envelope_dim = self.kernel_initializer_envelope_dim
    self._kernel_initializer_envelope_ion = self.kernel_initializer_envelope_ion
__call__(self, x, r_ei=None)
special
Apply a dense layer R -> R^n for each split and multiply by exp envelopes.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x |
Array |
array of shape (..., nelec, d) |
required |
r_ei |
Array |
array of shape (..., nelec, nion, d) |
None |
Returns:
Type | Description |
---|---|
[(..., nelec[i], self.norbitals_per_split[i])] |
list of FermiNet orbital matrices computed from an output stream x and the electron-ion displacements r_ei. Here nelec[i] is the number of particles in the ith split. The exponential envelopes are computed only when r_ei is not None (so, when connected to FermiNetBackflow, when ion locations are specified). To output square matrices, say for composing with the determinant anti-symmetry, nelec[i] should be equal to self.norbitals_per_split[i]. |
Source code in vmcnet/models/equivariance.py
@flax.linen.compact
def __call__(  # type: ignore[override]
    self, x: Array, r_ei: Array = None
) -> ArrayList:
    """Apply a dense layer R -> R^n for each split and multiply by exp envelopes.

    Args:
        x (Array): array of shape (..., nelec, d)
        r_ei (Array): array of shape (..., nelec, nion, d)

    Returns:
        [(..., nelec[i], self.norbitals_per_split[i])]: list of FermiNet orbital
        matrices computed from an output stream x and the electron-ion displacements
        r_ei. Here nelec[i] is the number of particles in the ith split. The
        exponential envelopes are computed only when r_ei is not None (so, when
        connected to FermiNetBackflow, when ion locations are specified). To output
        square matrices, say for composing with the determinant anti-symmetry,
        nelec[i] should be equal to self.norbitals_per_split[i].
    """
    orbs = SplitDense(
        self.orbitals_split,
        self.norbitals_per_split,
        self.kernel_initializer_linear,
        self.bias_initializer_linear,
        use_bias=self.use_bias,
    )(x)
    if r_ei is not None:
        exp_envelopes = _compute_exponential_envelopes_all_splits(
            r_ei,
            self.orbitals_split,
            self.norbitals_per_split,
            self._kernel_initializer_envelope_dim,
            self._kernel_initializer_envelope_ion,
            self.isotropic_decay,
        )
        orbs = tree_prod(orbs, exp_envelopes)
    return orbs
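To see why FermiNetOrbitalLayer needs nelec[i] equal to norbitals_per_split[i] for determinants, here is a NumPy sketch of a per-split dense layer multiplied by an exponential envelope. The parameter lists (kernels, decay, ion_weights) and the isotropic envelope form sum_I pi_(I,k) exp(-a_(I,k) |r_i - R_I|) are assumptions for illustration, not the library's exact parameterization.

```python
import numpy as np

rng = np.random.default_rng(0)
nelec, nion, d_in, d = 4, 2, 6, 3
split = 2  # two spins with 2 electrons each
norbitals_per_split = (2, 2)  # equal to nelec[i], so each matrix is square

# Hypothetical parameters: one set of weights per split.
kernels = [rng.normal(size=(d_in, k)) for k in norbitals_per_split]
decay = [rng.uniform(0.5, 1.5, size=(nion, k)) for k in norbitals_per_split]
ion_weights = [rng.normal(size=(nion, k)) for k in norbitals_per_split]


def orbital_matrices(x, r_ei):
    """Return one (nelec[i], norbitals[i]) orbital matrix per split."""
    x_split = np.split(x, split, axis=-2)
    dist_split = np.split(np.linalg.norm(r_ei, axis=-1), split, axis=-2)
    orbs = []
    for xs, rs, w, a, pi in zip(x_split, dist_split, kernels, decay, ion_weights):
        linear = xs @ w  # (n[i], norbitals[i])
        # Envelope: sum over ions of pi[I, k] * exp(-a[I, k] * |r_i - R_I|).
        envelope = np.sum(pi * np.exp(-a * rs[..., None]), axis=-2)
        orbs.append(linear * envelope)
    return orbs


x = rng.normal(size=(nelec, d_in))
r_ei = rng.normal(size=(nelec, nion, d))
dets = [np.linalg.det(m) for m in orbital_matrices(x, r_ei)]

# Swap the two electrons of the first split in both inputs.
perm = np.array([1, 0, 2, 3])
dets_swapped = [np.linalg.det(m) for m in orbital_matrices(x[perm], r_ei[perm])]
```

Swapping two electrons of the same split swaps two rows of that split's matrix, so its determinant changes sign while the other split's determinant is unchanged.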
FermiNetResidualBlock (Module)
dataclass
A single residual block in the FermiNet equivariant part.
Combines the one-electron and two-electron streams.
Attributes:
Name | Type | Description |
---|---|---|
one_electron_layer |
Callable |
function which takes in a previous one-electron stream output and two-electron stream output and mixes/transforms them to create a new one-electron stream output. Has the signature: (array of shape (..., n, d_1e), optional array of shape (..., n, n, d_2e)) -> array of shape (..., n, d_1e') |
two_electron_layer |
Callable |
function which takes in a previous two-electron stream output and batch applies a Dense layer along the last axis. Has the signature: array of shape (..., n, n, d_2e) -> array of shape (..., n, n, d_2e') |
setup(self)
Setup called one- and two- electron layers.
Source code in vmcnet/models/equivariance.py
def setup(self):
    """Setup called one- and two- electron layers."""
    self._one_electron_layer = self.one_electron_layer
    self._two_electron_layer = self.two_electron_layer
__call__(self, in_1e, in_2e=None)
special
Apply the one-electron layer and optionally the two-electron layer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
in_1e |
Array |
array of shape (..., n_total, d_1e) |
required |
in_2e |
Array |
array of shape (..., n_total, n_total, d_2e). Defaults to None. |
None |
Returns:
Type | Description |
---|---|
(Array, optional Array) |
tuple of (out_1e, out_2e) where out_1e is the output from the one-electron layer and out_2e is the output of the two-electron stream |
Source code in vmcnet/models/equivariance.py
def __call__(  # type: ignore[override]
    self, in_1e: Array, in_2e: Array = None
) -> Tuple[Array, Optional[Array]]:
    """Apply the one-electron layer and optionally the two-electron layer.

    Args:
        in_1e (Array): array of shape (..., n_total, d_1e)
        in_2e (Array, optional): array of shape (..., n_total, n_total, d_2e).
            Defaults to None.

    Returns:
        (Array, optional Array): tuple of (out_1e, out_2e) where out_1e
        is the output from the one-electron layer and out_2e is the output of the
        two-electron stream
    """
    out_1e = self._one_electron_layer(in_1e, in_2e)

    out_2e = in_2e
    if self.two_electron_layer is not None and in_2e is not None:
        out_2e = self._two_electron_layer(in_2e)

    return out_1e, out_2e
FermiNetTwoElectronLayer (Module)
dataclass
A single layer in the two-electron stream of the FermiNet equivariance.
Attributes:
Name | Type | Description |
---|---|---|
ndense |
int |
number of dense nodes |
kernel_initializer |
WeightInitializer |
kernel initializer. Has signature (key, shape, dtype) -> Array |
bias_initializer |
WeightInitializer |
bias initializer. Has signature (key, shape, dtype) -> Array |
activation_fn |
Activation |
activation function. Has the signature Array -> Array (shape is preserved) |
use_bias |
bool |
whether to add a bias term. Defaults to True. |
skip_connection |
bool |
whether to add residual skip connections whenever the shapes of the input and output match. Defaults to True. |
skip_connection_scale |
float |
quantity to scale the final output by if a skip connection is added. Defaults to 1.0. |
setup(self)
Setup Dense layer.
Source code in vmcnet/models/equivariance.py
def setup(self):
    """Setup Dense layer."""
    # workaround MyPy's typing error for callable attribute, see
    # https://github.com/python/mypy/issues/708
    self._activation_fn = self.activation_fn
    self._dense = Dense(
        self.ndense,
        kernel_init=self.kernel_initializer,
        bias_init=self.bias_initializer,
        use_bias=self.use_bias,
    )
__call__(self, x)
special
Apply a Dense layer in parallel to all electron pairs.
The expected use-case of this is to batch apply a dense layer to an input x of shape (..., n_total, n_total, d), getting an output of shape (..., n_total, n_total, d'), and optionally adding a skip connection. The function itself is just a standard residual network layer.
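A standard residual layer of this kind can be sketched in a few lines of NumPy. The function name residual_dense is hypothetical, tanh stands in for the configurable activation, and the skip condition is assumed here to mean that the input and output shapes are identical.

```python
import numpy as np


def residual_dense(x, kernel, bias, skip_connection=True, skip_connection_scale=1.0):
    """Dense layer on the last axis, tanh, and a skip connection if shapes match."""
    out = np.tanh(x @ kernel + bias)
    if skip_connection and out.shape == x.shape:
        out = skip_connection_scale * (out + x)
    return out


rng = np.random.default_rng(0)
x = rng.normal(size=(3, 3, 4))  # (n_total, n_total, d)

# d' == d: the skip connection is applied.
kernel_same, bias_same = rng.normal(size=(4, 4)), rng.normal(size=(4,))
out_same = residual_dense(x, kernel_same, bias_same)

# d' != d: the shapes differ, so the layer is a plain dense + activation.
kernel_wide, bias_wide = rng.normal(size=(4, 6)), rng.normal(size=(6,))
out_wide = residual_dense(x, kernel_wide, bias_wide)
```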
Source code in vmcnet/models/equivariance.py
def __call__(self, x: Array) -> Array:  # type: ignore[override]
    """Apply a Dense layer in parallel to all electron pairs.

    The expected use-case of this is to batch apply a dense layer to an input x of
    shape (..., n_total, n_total, d), getting an output of shape
    (..., n_total, n_total, d'), and optionally adding a skip connection. The
    function itself is just a standard residual network layer.
    """
    dense_out = self._dense(x)
    nonlinear_out = self._activation_fn(dense_out)

    if self.skip_connection and _valid_skip(x, nonlinear_out):
        nonlinear_out = self.skip_connection_scale * (nonlinear_out + x)

    return nonlinear_out
SplitDense (Module)
dataclass
Split input on the 2nd-to-last axis and apply unique Dense layers to each split.
Attributes:
Name | Type | Description |
---|---|---|
split |
ParticleSplit |
number of pieces to split the input equally,
or specified sequence of locations to split along the 2nd-to-last axis.
E.g., if nelec = 10, and |
ndense_per_split |
Sequence[int] |
sequence of integers specifying the number of dense nodes in the unique dense layer applied to each split of the input. This determines the output shapes for each split, i.e. the outputs are shaped (..., split_size[i], ndense[i]) |
kernel_initializer |
WeightInitializer |
kernel initializer. Has signature (key, shape, dtype) -> Array |
bias_initializer |
WeightInitializer |
bias initializer. Has signature (key, shape, dtype) -> Array. Defaults to random normal initialization. |
use_bias |
bool |
whether to add a bias term. Defaults to True. |
register_kfac |
bool |
whether to register the dense computations with KFAC. Defaults to True. |
setup(self)
Set up the dense layers for each split.
Source code in vmcnet/models/equivariance.py
def setup(self):
    """Set up the dense layers for each split."""
    nsplits = get_nsplits(self.split)

    if len(self.ndense_per_split) != nsplits:
        raise ValueError(
            "Incorrect number of dense output shapes specified for number of "
            "splits, should be one shape per split: shapes {} specified for the "
            "given split {}".format(self.ndense_per_split, self.split)
        )

    self._dense_layers = [
        Dense(
            self.ndense_per_split[i],
            kernel_init=self.kernel_initializer,
            bias_init=self.bias_initializer,
            use_bias=self.use_bias,
            register_kfac=self.register_kfac,
        )
        for i in range(nsplits)
    ]
__call__(self, x)
special
Split the input and apply a dense layer to each split.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x |
Array |
array of shape (..., n, d) |
required |
Returns:
Type | Description |
---|---|
[(..., n[i], self.ndense_per_split[i])] |
list of length nsplits, where nsplits is the number of splits created by jnp.split(x, self.split, axis=-2), and the ith entry of the output is the ith split transformed by a dense layer with self.ndense_per_split[i] nodes. |
Source code in vmcnet/models/equivariance.py
def __call__(self, x: Array) -> ArrayList:  # type: ignore[override]
    """Split the input and apply a dense layer to each split.

    Args:
        x (Array): array of shape (..., n, d)

    Returns:
        [(..., n[i], self.ndense_per_split[i])]: list of length nsplits, where
        nsplits is the number of splits created by
        jnp.split(x, self.split, axis=-2), and the ith entry of the output is the
        ith split transformed by a dense layer with self.ndense_per_split[i] nodes.
    """
    x_split = jnp.split(x, self.split, axis=-2)
    return [self._dense_layers[i](split) for i, split in enumerate(x_split)]
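The split argument of SplitDense follows the semantics of jnp.split, which match np.split: an integer gives that many equal pieces, while a sequence gives the indices at which to cut. The NumPy sketch below shows both cases and then applies one hypothetical kernel per split (biases omitted for brevity).

```python
import numpy as np

x = np.arange(10 * 3, dtype=float).reshape(10, 3)  # (n, d) with n = 10

# An integer splits the particle axis into that many equal pieces.
equal_sizes = [s.shape[-2] for s in np.split(x, 2, axis=-2)]
# A sequence gives the indices at which to cut, so there is one more piece
# than there are indices.
index_sizes = [s.shape[-2] for s in np.split(x, (2, 4), axis=-2)]

# One hypothetical dense kernel per split, each with its own output width.
ndense_per_split = (4, 7)
rng = np.random.default_rng(0)
kernels = [rng.normal(size=(x.shape[-1], k)) for k in ndense_per_split]
outputs = [s @ w for s, w in zip(np.split(x, 2, axis=-2), kernels)]
```

With n = 10, an integer split of 2 gives pieces of sizes [5, 5], and the index sequence (2, 4) gives pieces of sizes [2, 2, 6].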
compute_input_streams(elec_pos, ion_pos=None, include_2e_stream=True, include_ei_norm=True, include_ee_norm=True)
Create input streams with electron and optionally ion data.
If ion_pos is given, computes the electron-ion displacements (i.e. nuclear coordinates) and concatenates/flattens them along the ion dimension. If include_ei_norm is True, then the distances are also concatenated, so the map is elec_pos = (..., nelec, d) -> input_1e = (..., nelec, nion * (d + 1)).
If include_2e_stream is True, then a two-electron stream of shape (..., nelec, nelec, d) is also computed and returned (otherwise None is returned). If include_ee_norm is True, then this becomes (..., nelec, nelec, d + 1) by concatenating pairwise distances onto the stream.
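The shape bookkeeping described here can be traced with a NumPy sketch for a single configuration (no batch axes). The sign convention r_ei[i, I] = elec_pos[i] - ion_pos[I] is an assumption, and a plain norm is used for the electron-electron distances, whereas the library calls compute_ee_norm_with_safe_diag for that step.

```python
import numpy as np

rng = np.random.default_rng(0)
nelec, nion, d = 4, 2, 3
elec_pos = rng.normal(size=(nelec, d))
ion_pos = rng.normal(size=(nion, d))

# Electron-ion displacements (assumed sign: electron minus ion).
r_ei = elec_pos[:, None, :] - ion_pos[None, :, :]  # (nelec, nion, d)
ei_norm = np.linalg.norm(r_ei, axis=-1, keepdims=True)  # (nelec, nion, 1)
input_1e = np.concatenate([r_ei, ei_norm], axis=-1)  # (nelec, nion, d + 1)
# Flatten the ion axis into the feature axis.
input_1e = input_1e.reshape(input_1e.shape[:-2] + (-1,))  # (nelec, nion * (d + 1))

# Electron-electron displacements and distances.
r_ee = elec_pos[:, None, :] - elec_pos[None, :, :]  # (nelec, nelec, d)
ee_norm = np.linalg.norm(r_ee, axis=-1, keepdims=True)  # (nelec, nelec, 1)
input_2e = np.concatenate([r_ee, ee_norm], axis=-1)  # (nelec, nelec, d + 1)
```

With nelec = 4, nion = 2, and d = 3, the one-electron input has shape (4, 8) and the two-electron input has shape (4, 4, 4).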
Parameters:
Name | Type | Description | Default |
---|---|---|---|
elec_pos |
Array |
electron positions of shape (..., nelec, d) |
required |
ion_pos |
Array |
locations of (stationary) ions to compute relative electron positions, 2-d array of shape (nion, d). Defaults to None. |
None |
include_2e_stream |
bool |
whether to compute pairwise electron displacements/distances. Defaults to True. |
True |
include_ei_norm |
bool |
whether to include electron-ion distances in the one-electron input. Defaults to True. |
True |
include_ee_norm |
bool |
whether to include electron-electron distances in the two-electron input. Defaults to True. |
True |
Returns:
Type | Description |
---|---|
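(Array, Optional[Array], Optional[Array], Optional[Array]) |
First output: one-electron input of shape (..., nelec, d'), where d' = d if ion_pos is None, d' = nion * d if ion_pos is given and include_ei_norm is False, and d' = nion * (d + 1) if ion_pos is given and include_ei_norm is True. Second output: two-electron input of shape (..., nelec, nelec, d'), where d' = d if include_ee_norm is False, and d' = d + 1 if include_ee_norm is True. Third output: electron-ion displacements of shape (..., nelec, nion, d). Fourth output: electron-electron displacements of shape (..., nelec, nelec, d). If include_2e_stream is False, then the second and fourth outputs are None. If ion_pos is None, then the third output is None. |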
Source code in vmcnet/models/equivariance.py
def compute_input_streams(
    elec_pos: Array,
    ion_pos: Array = None,
    include_2e_stream: bool = True,
    include_ei_norm: bool = True,
    include_ee_norm: bool = True,
) -> InputStreams:
    """Create input streams with electron and optionally ion data.

    If `ion_pos` is given, computes the electron-ion displacements (i.e. nuclear
    coordinates) and concatenates/flattens them along the ion dimension. If
    `include_ei_norm` is True, then the distances are also concatenated, so the map is
    elec_pos = (..., nelec, d) -> input_1e = (..., nelec, nion * (d + 1)).

    If `include_2e_stream` is True, then a two-electron stream of shape
    (..., nelec, nelec, d) is also computed and returned (otherwise None is returned).
    If `include_ee_norm` is True, then this becomes (..., nelec, nelec, d + 1) by
    concatenating pairwise distances onto the stream.

    Args:
        elec_pos (Array): electron positions of shape (..., nelec, d)
        ion_pos (Array, optional): locations of (stationary) ions to compute
            relative electron positions, 2-d array of shape (nion, d). Defaults to
            None.
        include_2e_stream (bool, optional): whether to compute pairwise electron
            displacements/distances. Defaults to True.
        include_ei_norm (bool, optional): whether to include electron-ion distances
            in the one-electron input. Defaults to True.
        include_ee_norm (bool, optional): whether to include electron-electron
            distances in the two-electron input. Defaults to True.

    Returns:
        (
            Array,
            Optional[Array],
            Optional[Array],
            Optional[Array],
        ):

        first output: one-electron input of shape (..., nelec, d'), where
            d' = d if `ion_pos` is None,
            d' = nion * d if `ion_pos` is given and `include_ei_norm` is False, and
            d' = nion * (d + 1) if `ion_pos` is given and `include_ei_norm` is True.

        second output: two-electron input of shape (..., nelec, nelec, d'), where
            d' = d if `include_ee_norm` is False, and
            d' = d + 1 if `include_ee_norm` is True

        third output: electron-ion displacements of shape (..., nelec, nion, d)

        fourth output: electron-electron displacements of shape
            (..., nelec, nelec, d)

        If `include_2e_stream` is False, then the second and fourth outputs are None.
        If `ion_pos` is None, then the third output is None.
    """
    input_1e, r_ei = compute_electron_ion(elec_pos, ion_pos, include_ei_norm)
    input_2e = None
    r_ee = None
    if include_2e_stream:
        input_2e, r_ee = compute_electron_electron(elec_pos, include_ee_norm)
    return input_1e, input_2e, r_ei, r_ee
compute_electron_ion(elec_pos, ion_pos=None, include_ei_norm=True)
Compute electron-ion displacements and optionally add on the distances.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
elec_pos |
Array |
electron positions of shape (..., nelec, d) |
required |
ion_pos |
Array |
locations of (stationary) ions to compute relative electron positions, 2-d array of shape (nion, d). Defaults to None. |
None |
include_ei_norm |
bool |
whether to include electron-ion distances in the one-electron input. Defaults to True. |
True |
Returns:
Type | Description |
---|---|
(Array, Optional[Array]) |
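First output: one-electron input of shape (..., nelec, d'), where d' = d if ion_pos is None, d' = nion * d if ion_pos is given and include_ei_norm is False, and d' = nion * (d + 1) if ion_pos is given and include_ei_norm is True. Second output: electron-ion displacements of shape (..., nelec, nion, d). If ion_pos is None, then the second output is None. |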
Source code in vmcnet/models/equivariance.py
def compute_electron_ion(
    elec_pos: Array, ion_pos: Array = None, include_ei_norm: bool = True
) -> Tuple[Array, Optional[Array]]:
    """Compute electron-ion displacements and optionally add on the distances.

    Args:
        elec_pos (Array): electron positions of shape (..., nelec, d)
        ion_pos (Array, optional): locations of (stationary) ions to compute
            relative electron positions, 2-d array of shape (nion, d). Defaults to
            None.
        include_ei_norm (bool, optional): whether to include electron-ion distances
            in the one-electron input. Defaults to True.

    Returns:
        (Array, Optional[Array]):

        first output: one-electron input of shape (..., nelec, d'), where
            d' = d if `ion_pos` is None,
            d' = nion * d if `ion_pos` is given and `include_ei_norm` is False, and
            d' = nion * (d + 1) if `ion_pos` is given and `include_ei_norm` is True.

        second output: electron-ion displacements of shape (..., nelec, nion, d)

        If `ion_pos` is None, then the second output is None.
    """
    r_ei = None
    input_1e = elec_pos
    if ion_pos is not None:
        r_ei = _compute_displacements(input_1e, ion_pos)
        input_1e = r_ei
        if include_ei_norm:
            input_norm = jnp.linalg.norm(input_1e, axis=-1, keepdims=True)
            input_with_norm = jnp.concatenate([input_1e, input_norm], axis=-1)
            input_1e = jnp.reshape(input_with_norm, input_with_norm.shape[:-2] + (-1,))
    return input_1e, r_ei
compute_electron_electron(elec_pos, include_ee_norm=True)
Compute electron-electron displacements and optionally add on the distances.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
elec_pos |
Array |
electron positions of shape (..., nelec, d) |
required |
include_ee_norm |
bool |
whether to include electron-electron distances in the two-electron input. Defaults to True. |
True |
Returns:
Type | Description |
---|---|
(Array, Array) |
First output: two-electron input of shape (..., nelec, nelec, d'), where d' = d if include_ee_norm is False, and d' = d + 1 if include_ee_norm is True. Second output: two-electron displacements of shape (..., nelec, nelec, d). |
Source code in vmcnet/models/equivariance.py
def compute_electron_electron(
    elec_pos: Array, include_ee_norm: bool = True
) -> Tuple[Array, Array]:
    """Compute electron-electron displacements and optionally add on the distances.

    Args:
        elec_pos (Array): electron positions of shape (..., nelec, d)
        include_ee_norm (bool, optional): whether to include electron-electron
            distances in the two-electron input. Defaults to True.

    Returns:
        (Array, Array):

        first output: two-electron input of shape (..., nelec, nelec, d'), where
            d' = d if `include_ee_norm` is False, and
            d' = d + 1 if `include_ee_norm` is True

        second output: two-electron displacements of shape (..., nelec, nelec, d)
    """
    r_ee = _compute_displacements(elec_pos, elec_pos)
    input_2e = r_ee
    if include_ee_norm:
        r_ee_norm = compute_ee_norm_with_safe_diag(r_ee)
        input_2e = jnp.concatenate([input_2e, r_ee_norm], axis=-1)
    return input_2e, r_ee