construct
Combine pieces to form full models.
AntiequivarianceNet (Module)
dataclass
Antisymmetry from anti-equivariance, backflow -> antieq -> odd invariance.
Attributes:
Name | Type | Description |
---|---|---|
spin_split |
ParticleSplit |
either the number of spins, in which case the input is split equally,
or an explicit sequence of locations at which to split along the 2nd-to-last axis.
E.g., if nelec = 10, and |
compute_input_streams |
ComputeInputStreams |
function to compute input streams from electron positions. Has the signature (elec_pos of shape (..., n, d)) -> ( stream_1e of shape (..., n, d'), optional stream_2e of shape (..., nelec, nelec, d2), optional r_ei of shape (..., n, nion, d), optional r_ee of shape (..., n, n, d), ) |
backflow |
Callable |
function which computes position features from the electron positions. Has the signature ( stream_1e of shape (..., n, d'), optional stream_2e of shape (..., nelec, nelec, d2), ) -> stream_1e of shape (..., n, d') |
antiequivariant_layer |
Callable |
function which computes antiequivariances per spin. Has the signature (stream_1e of shape (..., n, d_backflow), r_ei of shape (..., n, nion, d)) -> (antieqs of shapes [spin: (..., n[spin], d_antieq)]) |
array_list_sign_covariance |
Callable |
function which is sign-covariant with respect to each spin. Has the signature [(..., nelec[spin], d_antieq)] -> (..., d_antisym). Since this function is sign-covariant, its outputs are antisymmetric, so Psi can be calculated by summing over the final axis of the result. |
multiply_by_eq_features |
bool |
If True, the antiequivariance from the antiequivariant_layer is multiplied by the equivariant features from the backflow before being fed into the sign covariant function. If False, the antiequivariance is processed directly by the sign covariant function. Defaults to False. |
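The two forms of spin_split described above can be illustrated with jnp.split, which FactorizedAntisymmetry and GenericAntisymmetry call as jnp.split(stream_1e, self.spin_split, axis=-2). This is only a sketch; the array shape and the split values are made up for the example and are not taken from the library.

```python
import jax.numpy as jnp

# Toy one-electron stream: 2 walkers, 10 electrons, 3 features (made-up sizes).
stream_1e = jnp.zeros((2, 10, 3))

# An integer spin_split divides the electron axis into that many equal blocks.
equal_parts = jnp.split(stream_1e, 2, axis=-2)

# A sequence spin_split lists the locations at which to cut the electron axis.
# Here (6,) gives 6 electrons in the first block and 4 in the second.
uneven_parts = jnp.split(stream_1e, (6,), axis=-2)

equal_shapes = [part.shape for part in equal_parts]
uneven_shapes = [part.shape for part in uneven_parts]
```

With an integer, the electron axis must be divisible by that number; the sequence form is the one to use when the spin blocks have different sizes.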
setup(self)
Setup backflow.
Source code in vmcnet/models/construct.py
def setup(self):
    """Setup backflow."""
    # workaround MyPy's typing error for callable attribute, see
    # https://github.com/python/mypy/issues/708
    self._compute_input_streams = self.compute_input_streams
    self._backflow = self.backflow
    self._antiequivariant_layer = self.antiequivariant_layer
    self._array_list_sign_covariance = self.array_list_sign_covariance
__call__(self, elec_pos)
special
Compose backflow -> antiequivariance -> sign covariant equivariance -> sum.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
elec_pos |
Array |
array of particle positions (..., nelec, d) |
required |
Returns:
Type | Description |
---|---|
Array |
log(abs(psi)), where psi is a general odd invariance of an anti-equivariant backflow. If the inputs have shape (batch_dims, nelec, d), then the output has shape (batch_dims,). |
Source code in vmcnet/models/construct.py
@flax.linen.compact
def __call__(self, elec_pos: Array) -> SLArray:  # type: ignore[override]
    """Compose backflow -> antiequivariance -> sign covariant equivariance -> sum.

    Args:
        elec_pos (Array): array of particle positions (..., nelec, d)

    Returns:
        Array: log(abs(psi)), where psi is a general odd invariance of an
        anti-equivariant backflow. If the inputs have shape (batch_dims, nelec, d),
        then the output has shape (batch_dims,).
    """
    stream_1e, stream_2e, r_ei, _ = self._compute_input_streams(elec_pos)
    backflow_out = self._backflow(stream_1e, stream_2e)
    antiequivariant_out = self._antiequivariant_layer(backflow_out, r_ei)

    if self.multiply_by_eq_features:
        antiequivariant_out = antiequivariance.multiply_antieq_by_eq_features(
            antiequivariant_out, backflow_out, self.spin_split
        )

    antisym_vector = self._array_list_sign_covariance(antiequivariant_out)
    return array_to_slog(jnp.sum(antisym_vector, axis=-1))
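The sign-covariance property required of array_list_sign_covariance can be illustrated on plain scalars. The sketch below symmetrizes an arbitrary function over the sign group so that flipping the sign of any one input flips the sign of the output. The helper make_sign_covariant and the test function are hypothetical and written for this example; they are not the library's make_array_list_fn_sign_covariant, which operates on lists of arrays.

```python
import itertools
import math


def make_sign_covariant(fn, nspins):
    """Symmetrize fn over the sign group so it is odd in each of its inputs."""

    def covariant_fn(*xs):
        total = 0.0
        # Sum over all 2**nspins sign patterns, weighting by the product of signs.
        for signs in itertools.product((1.0, -1.0), repeat=nspins):
            flipped = [s * x for s, x in zip(signs, xs)]
            total += math.prod(signs) * fn(*flipped)
        return total / 2**nspins

    return covariant_fn


def arbitrary_fn(x_up, x_down):
    # No particular symmetry; only the part odd in both inputs survives.
    return math.exp(x_up) + x_up * x_down**3 + math.cos(x_down)


g = make_sign_covariant(arbitrary_fn, nspins=2)
```

For this particular arbitrary_fn, the only term that is odd in both arguments is x_up * x_down**3, so g reduces to that term up to floating-point error.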
DeterminantFnMode (Enum)
Enum specifying how to use determinant resnet in FermiNet model.
EmbeddedParticleFermiNet (FermiNet)
dataclass
Model that expands its inputs with extra hidden particles, then applies FermiNet.
Note: the backflow argument supplied for the construction of this model should use the spin_split for the TOTAL number of particles, visible and hidden, for each spin, not the standard spin_split for the visible particles only.
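To make the note concrete, the sketch below computes split locations for the visible particles and for the visible plus hidden particles. The helper split_locations is a hypothetical stand-in written for this example; the library's own get_spin_split is not shown on this page, so its exact behavior is an assumption here. The particle counts are made up.

```python
from itertools import accumulate


def split_locations(nparticles_per_spin):
    """Cumulative split locations along the particle axis (last block implied)."""
    return tuple(accumulate(nparticles_per_spin))[:-1]


nelec = (3, 2)                      # visible electrons per spin (made up)
nhidden_fermions_per_spin = (1, 2)  # hidden fermions per spin (made up)
total_per_spin = tuple(v + h for v, h in zip(nelec, nhidden_fermions_per_spin))

visible_split = split_locations(nelec)          # split for the visible particles
total_split = split_locations(total_per_spin)   # split the backflow should use
```

In this example the visible particles split at location 3, while the backflow passed to EmbeddedParticleFermiNet would need the split at location 4, since each spin block now holds four particles.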
Attributes:
Name | Type | Description |
---|---|---|
nhidden_fermions_per_spin |
Sequence[int] |
number of hidden fermions to generate for each spin. Must have length nspins. |
invariance_compute_input_streams |
ComputeInputStreams |
function to compute input streams from electron positions, for the invariance that is used to generate the hidden particle positions. Has the signature (elec_pos of shape (..., n, d)) -> ( stream_1e of shape (..., n, d'), optional stream_2e of shape (..., nelec, nelec, d2), optional r_ei of shape (..., n, nion, d), optional r_ee of shape (..., n, n, d), ) |
invariance_backflow |
Callable |
backflow function to be used for the invariance which generates the hidden fermion positions. Has the signature ( stream_1e of shape (..., n, d'), optional stream_2e of shape (..., nelec, nelec, d2), ) -> stream_1e of shape (..., n, d') |
invariance_kernel_initializer |
WeightInitializer |
kernel initializer for the invariance dense layer. Has signature (key, shape, dtype) -> Array |
invariance_bias_initializer |
WeightInitializer |
bias initializer for the invariance dense layer. Has signature (key, shape, dtype) -> Array |
invariance_use_bias |
bool |
whether to add a bias term in the dense layer of the invariance. Defaults to True. |
invariance_register_kfac |
bool |
whether to register the dense layer of the invariance with KFAC. Defaults to True. |
setup(self)
Setup EmbeddedParticleFermiNet.
Source code in vmcnet/models/construct.py
def setup(self):
    """Setup EmbeddedParticleFermiNet."""
    # workaround MyPy's typing error for callable attribute, see
    # https://github.com/python/mypy/issues/708
    super().setup()
    self._invariance_compute_input_streams = self.invariance_compute_input_streams
    self._invariance_backflow = self.invariance_backflow
ExtendedOrbitalMatrixFermiNet (FermiNet)
dataclass
FermiNet-based model with larger orbital matrices via padding with invariance.
Attributes:
Name | Type | Description |
---|---|---|
nhidden_fermions_per_spin |
Sequence[int] |
sequence of integers specifying how many extra hidden particle dimensions and corresponding virtual orbitals to add to the orbital matrices. If not None, must have length nspins. Defaults to None (no extra dims added, equivalent to FermiNet). |
invariance_kernel_initializer |
WeightInitializer |
kernel initializer for the invariance dense layer. Has signature (key, shape, dtype) -> Array. Defaults to an orthogonal initializer. |
invariance_bias_initializer |
WeightInitializer |
bias initializer for the invariance dense layer. Has signature (key, shape, dtype) -> Array. Defaults to a scaled random normal initializer. |
invariance_use_bias |
bool |
whether to add a bias term in the dense layer of the invariance. Defaults to True. |
invariance_register_kfac |
bool |
whether to register the dense layer of the invariance with KFAC. Defaults to True. |
invariance_backflow |
Callable |
backflow function to be used for the invariance which generates the hidden fermion positions. If None, the outputs of the regular FermiNet backflow are used instead to form an invariance. Defaults to None. |
FactorizedAntisymmetry (Module)
dataclass
A sum of products of explicitly antisymmetrized ResNets, composed with backflow.
This connects the computational graph between a backflow, a factorized antisymmetrized ResNet, and a jastrow.
See https://arxiv.org/abs/2112.03491 for a description of the factorized antisymmetric layer.
Attributes:
Name | Type | Description |
---|---|---|
spin_split |
ParticleSplit |
either the number of spins, in which case the input is split equally,
or an explicit sequence of locations at which to split along the 2nd-to-last axis.
E.g., if nelec = 10, and |
compute_input_streams |
ComputeInputStreams |
function to compute input streams from electron positions. Has the signature (elec_pos of shape (..., n, d)) -> ( stream_1e of shape (..., n, d'), optional stream_2e of shape (..., nelec, nelec, d2), optional r_ei of shape (..., n, nion, d), optional r_ee of shape (..., n, n, d), ) |
backflow |
Callable |
function which computes position features from the electron positions. Has the signature ( stream_1e of shape (..., n, d'), optional stream_2e of shape (..., nelec, nelec, d2), ) -> stream_1e of shape (..., n, d') |
jastrow |
Callable |
function which computes a log Jastrow factor. In __call__ it is called with (input_stream_1e, input_stream_2e, stream_1e, r_ei, r_ee), where r_ei has shape (batch_dims, n, nion, d) and r_ee has shape (batch_dims, n, n, d), and it returns the log Jastrow of shape (batch_dims,) |
rank |
int |
The rank of the explicit antisymmetry. In practical terms, the number of resnets to antisymmetrize for each spin. This is analogous to ndeterminants for regular FermiNet. |
ndense_resnet |
int |
number of dense nodes in each layer of each antisymmetrized ResNet |
nlayers_resnet |
int |
number of layers in each antisymmetrized ResNet |
kernel_initializer_resnet |
WeightInitializer |
kernel initializer for the dense layers in the antisymmetrized ResNets. Has signature (key, shape, dtype) -> Array |
bias_initializer_resnet |
WeightInitializer |
bias initializer for the dense layers in the antisymmetrized ResNets. Has signature (key, shape, dtype) -> Array |
activation_fn_resnet |
Activation |
activation function in the antisymmetrized ResNets. Has the signature Array -> Array (shape is preserved) |
resnet_use_bias |
bool |
whether to add a bias term in the dense layers of the antisymmetrized ResNets. Defaults to True. |
setup(self)
Setup backflow.
Source code in vmcnet/models/construct.py
def setup(self):
    """Setup backflow."""
    # workaround MyPy's typing error for callable attribute, see
    # https://github.com/python/mypy/issues/708
    self._compute_input_streams = self.compute_input_streams
    self._backflow = self.backflow
    self._jastrow = self.jastrow
__call__(self, elec_pos)
special
Compose FermiNet backflow -> antisymmetrized ResNets -> logabs product.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
elec_pos |
Array |
array of particle positions (..., nelec, d) |
required |
Returns:
Type | Description |
---|---|
Array |
spinful antisymmetrized output; logarithm of the absolute value of an anti-symmetric function of elec_pos, where the anti-symmetry is with respect to the second-to-last axis of elec_pos. The anti-symmetry holds for particles within the same split, but not for permutations which swap particles across different spin splits. If the inputs have shape (batch_dims, nelec, d), then the output has shape (batch_dims,). |
Source code in vmcnet/models/construct.py
@flax.linen.compact
def __call__(self, elec_pos: Array) -> SLArray:  # type: ignore[override]
    """Compose FermiNet backflow -> antisymmetrized ResNets -> logabs product.

    Args:
        elec_pos (Array): array of particle positions (..., nelec, d)

    Returns:
        Array: spinful antisymmetrized output; logarithm of the absolute value
        of an anti-symmetric function of elec_pos, where the anti-symmetry is with
        respect to the second-to-last axis of elec_pos. The anti-symmetry holds for
        particles within the same split, but not for permutations which swap
        particles across different spin splits. If the inputs have shape
        (batch_dims, nelec, d), then the output has shape (batch_dims,).
    """
    input_stream_1e, input_stream_2e, r_ei, r_ee = self._compute_input_streams(
        elec_pos
    )
    stream_1e = self._backflow(input_stream_1e, input_stream_2e)
    split_spins = jnp.split(stream_1e, self.spin_split, axis=-2)

    def fn_to_antisymmetrize(x_one_spin):
        resnet_outputs = [
            SimpleResNet(
                self.ndense_resnet,
                1,
                self.nlayers_resnet,
                self.activation_fn_resnet,
                self.kernel_initializer_resnet,
                self.bias_initializer_resnet,
                use_bias=self.resnet_use_bias,
            )(x_one_spin)
            for _ in range(self.rank)
        ]
        return jnp.concatenate(resnet_outputs, axis=-1)

    # TODO (ggoldsh/jeffminlin): better typing for the Array vs SLArray version of
    # this model to avoid having to cast the return type.
    slog_antisyms = cast(
        SLArray,
        FactorizedAntisymmetrize([fn_to_antisymmetrize for _ in split_spins])(
            split_spins
        ),
    )
    sign_psi, log_antisyms = slog_sum_over_axis(slog_antisyms, axis=-1)

    jastrow_part = self._jastrow(
        input_stream_1e, input_stream_2e, stream_1e, r_ei, r_ee
    )

    return sign_psi, log_antisyms + jastrow_part
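The "sum of products of explicitly antisymmetrized" functions can be sketched in plain Python on scalar particle coordinates. Everything below is an assumption made for illustration: antisymmetrize and factorized_antisymmetry are hypothetical helpers written for this example, not the library's FactorizedAntisymmetrize, and small closed-form functions stand in for the ResNets. The sketch omits the backflow, the Jastrow factor, and the log-domain arithmetic.

```python
import itertools
import math


def permutation_sign(perm):
    """Sign of a permutation given as a tuple of indices, via inversion count."""
    n = len(perm)
    inversions = sum(
        1 for i in range(n) for j in range(i + 1, n) if perm[i] > perm[j]
    )
    return -1.0 if inversions % 2 else 1.0


def antisymmetrize(fn):
    """Return xs -> sum over permutations p of sign(p) * fn(xs permuted by p)."""

    def antisym_fn(xs):
        return sum(
            permutation_sign(perm) * fn([xs[i] for i in perm])
            for perm in itertools.permutations(range(len(xs)))
        )

    return antisym_fn


def factorized_antisymmetry(fns_up, fns_down, x_up, x_down):
    """Sum over the rank index k of A[f_up_k](x_up) * A[f_down_k](x_down)."""
    return sum(
        antisymmetrize(f_up)(x_up) * antisymmetrize(f_down)(x_down)
        for f_up, f_down in zip(fns_up, fns_down)
    )


# Two arbitrary, non-symmetric functions per spin, i.e. rank = 2.
fns_up = [lambda x: x[0] * math.sin(x[1]), lambda x: x[0] ** 2 * x[1]]
fns_down = [lambda x: math.exp(x[0]) * x[1], lambda x: x[0] - 2.0 * x[1] ** 3]

psi = factorized_antisymmetry(fns_up, fns_down, [0.1, 0.5], [0.3, -0.2])
# Swapping the two up-spin particles flips the sign of the result.
psi_swapped = factorized_antisymmetry(fns_up, fns_down, [0.5, 0.1], [0.3, -0.2])
```

The explicit sum over permutations grows factorially with the number of particles per spin, which is why this construction is only practical for small systems.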
FermiNet (Module)
dataclass
FermiNet/generalized Slater determinant model.
This model was first introduced in the following papers: https://journals.aps.org/prresearch/abstract/10.1103/PhysRevResearch.2.033429 and https://arxiv.org/abs/2011.07125. The authors' repository can be found at https://github.com/deepmind/ferminet, which includes a JAX branch.
Attributes:
Name | Type | Description |
---|---|---|
spin_split |
ParticleSplit |
either the number of spins, in which case the input is split equally,
or an explicit sequence of locations at which to split along the 2nd-to-last axis.
E.g., if nelec = 10, and |
compute_input_streams |
ComputeInputStreams |
function to compute input streams from electron positions. Has the signature (elec_pos of shape (..., n, d)) -> ( stream_1e of shape (..., n, d'), optional stream_2e of shape (..., nelec, nelec, d2), optional r_ei of shape (..., n, nion, d), optional r_ee of shape (..., n, n, d), ) |
backflow |
Callable |
function which computes position features from the electron positions. Has the signature ( stream_1e of shape (..., n, d'), optional stream_2e of shape (..., nelec, nelec, d2), ) -> stream_1e of shape (..., n, d') |
ndeterminants |
int |
number of determinants in the FermiNet model, i.e. the number of distinct orbital layers applied |
kernel_initializer_orbital_linear |
WeightInitializer |
kernel initializer for the linear part of the orbitals. Has signature (key, shape, dtype) -> Array |
kernel_initializer_envelope_dim |
WeightInitializer |
kernel initializer for the
decay rate in the exponential envelopes. If |
kernel_initializer_envelope_ion |
WeightInitializer |
kernel initializer for the linear combination over the ions of exponential envelopes. Has signature (key, shape, dtype) -> Array |
bias_initializer_orbital_linear |
WeightInitializer |
bias initializer for the linear part of the orbitals. Has signature (key, shape, dtype) -> Array |
orbitals_use_bias |
bool |
whether to add a bias term in the linear part of the orbitals. |
isotropic_decay |
bool |
whether the decay for each ion should be anisotropic (w.r.t. the dimensions of the input), giving envelopes of the form exp(-||A(r - R)||) for a d x d matrix A, or isotropic, giving exp(-||a(r - R)||) for a scalar a. |
determinant_fn |
DeterminantFn or None |
An optional function with signature dout, [nspins: (..., ndeterminants)] -> (..., dout). If not None, the function will be used to calculate Psi based on the outputs of the orbital matrix determinants. Depending on the determinant_fn_mode selected, this function can be used in one of several ways. If the mode is SIGN_COVARIANCE, the function will use dout=1 and will be explicitly symmetrized over the sign group, on a per-spin basis, to be sign-covariant (odd). If PARALLEL_EVEN or PAIRWISE_EVEN is selected, the function will be symmetrized to be spin-wise sign-invariant (even). For PARALLEL_EVEN, the function will use dout=ndeterminants, and each output will be multiplied by the product of corresponding determinants. That is, for 2 spins, with up determinants u_i and down determinants d_i, the ansatz will be sum_{i}(u_i * d_i * f_i(u,d)), where f_i(u,d) is the symmetrized determinant function. For PAIRWISE_EVEN, the function will use dout=ndeterminants**nspins, and each output will again be multiplied by a product of determinants, but this time the determinants will range over all pairs. That is, for 2 spins, the ansatz will be sum_{i, j}(u_i * d_j * f_{i,j}(u,d)). Currently, PAIRWISE_EVEN mode only supports nspins = 2. If None, the equivalent of PARALLEL_EVEN mode (overriding any set determinant_fn_mode) is used without a symmetrized resnet (so the output, before any log-transformations, is a sum of products of determinants). |
determinant_fn_mode |
DeterminantFnMode |
One of SIGN_COVARIANCE, PARALLEL_EVEN, or PAIRWISE_EVEN. Used to decide how exactly to use the provided determinant_fn to calculate an ansatz for Psi; irrelevant if determinant_fn is set to None. |
full_det |
bool |
If True, the model will use a single, "full" determinant with orbitals from particles of all spins. For example, for a spin_split of (2,2), the original FermiNet with ndeterminants=1 would calculate two separate 2x2 orbital matrices and multiply their determinants together. A full determinant model would instead calculate a single 4x4 matrix, with the first two particle indices corresponding to the up-spin particles and the last two particle indices corresponding to the down-spin particles. The output of the model would then be the determinant of that single matrix, if ndeterminants=1, or the sum of multiple such determinants if ndeterminants>1. |
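The PARALLEL_EVEN and PAIRWISE_EVEN formulas in the determinant_fn description can be written out numerically for two spins. This is a sketch of the arithmetic only: the determinant values are made up, and fixed arrays stand in for the outputs f_i(u,d) and f_{i,j}(u,d) of the symmetrized determinant function, which in the model depend on the determinants themselves.

```python
import jax.numpy as jnp

# Toy determinant values for ndeterminants = 3 and a single walker.
u = jnp.array([1.0, -2.0, 0.5])  # up-spin determinants u_i
d = jnp.array([3.0, 1.0, -4.0])  # down-spin determinants d_i

# Stand-ins for the outputs of the symmetrized (even) determinant function.
f_parallel = jnp.array([0.1, 0.2, 0.3])            # f_i, shape (3,)
f_pairwise = jnp.arange(9.0).reshape(3, 3) / 10.0  # f_{i,j}, shape (3, 3)

# PARALLEL_EVEN: sum_i u_i * d_i * f_i
psi_parallel = jnp.sum(u * d * f_parallel)

# PAIRWISE_EVEN: sum_{i,j} u_i * d_j * f_{i,j}
psi_pairwise = jnp.einsum("i,j,ij->", u, d, f_pairwise)

# determinant_fn=None: a plain sum of products of determinants.
psi_plain = jnp.sum(u * d)
```

The pairwise form needs ndeterminants**2 outputs from the determinant function for two spins, which matches the dout passed to it in setup for that mode.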
setup(self)
Setup backflow and symmetrized determinant function.
Source code in vmcnet/models/construct.py
def setup(self):
    """Setup backflow and symmetrized determinant function."""
    # workaround MyPy's typing error for callable attribute, see
    # https://github.com/python/mypy/issues/708
    self._compute_input_streams = self.compute_input_streams
    self._backflow = self.backflow
    self._symmetrized_det_fn = None
    if self.determinant_fn is not None:
        if self.determinant_fn_mode == DeterminantFnMode.SIGN_COVARIANCE:
            self._symmetrized_det_fn = make_array_list_fn_sign_covariant(
                functools.partial(self.determinant_fn, 1)
            )
        elif self.determinant_fn_mode == DeterminantFnMode.PARALLEL_EVEN:
            self._symmetrized_det_fn = make_array_list_fn_sign_invariant(
                functools.partial(self.determinant_fn, self.ndeterminants)
            )
        elif self.determinant_fn_mode == DeterminantFnMode.PAIRWISE_EVEN:
            # TODO (ggoldsh): build support for PAIRWISE_EVEN for nspins != 2
            self._symmetrized_det_fn = make_array_list_fn_sign_invariant(
                functools.partial(self.determinant_fn, self.ndeterminants ** 2)
            )
        else:
            raise self._get_bad_determinant_fn_mode_error()
__call__(self, elec_pos)
special
Compose FermiNet backflow -> orbitals -> logabs determinant product.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
elec_pos |
Array |
array of particle positions (..., nelec, d) |
required |
Returns:
Type | Description |
---|---|
Array |
FermiNet output; logarithm of the absolute value of an anti-symmetric function of elec_pos, where the anti-symmetry is with respect to the second-to-last axis of elec_pos. The anti-symmetry holds for particles within the same split, but not for permutations which swap particles across different spin splits. If the inputs have shape (batch_dims, nelec, d), then the output has shape (batch_dims,). |
Source code in vmcnet/models/construct.py
@flax.linen.compact
def __call__(self, elec_pos: Array) -> SLArray:  # type: ignore[override]
    """Compose FermiNet backflow -> orbitals -> logabs determinant product.

    Args:
        elec_pos (Array): array of particle positions (..., nelec, d)

    Returns:
        Array: FermiNet output; logarithm of the absolute value of an
        anti-symmetric function of elec_pos, where the anti-symmetry is with respect
        to the second-to-last axis of elec_pos. The anti-symmetry holds for
        particles within the same split, but not for permutations which swap
        particles across different spin splits. If the inputs have shape
        (batch_dims, nelec, d), then the output has shape (batch_dims,).
    """
    elec_pos, orbitals_split = self._get_elec_pos_and_orbitals_split(elec_pos)

    input_stream_1e, input_stream_2e, r_ei, _ = self._compute_input_streams(
        elec_pos
    )
    stream_1e = self._backflow(input_stream_1e, input_stream_2e)

    norbitals_per_split = self._get_norbitals_per_split(elec_pos, orbitals_split)
    # orbitals is [norb_splits: (ndeterminants, ..., nelec[i], norbitals[i])]
    orbitals = self._eval_orbitals(
        orbitals_split,
        norbitals_per_split,
        input_stream_1e,
        input_stream_2e,
        stream_1e,
        r_ei,
    )

    if self.full_det:
        orbitals = [jnp.concatenate(orbitals, axis=-2)]

    if self._symmetrized_det_fn is not None:
        # dets is ArrayList of shape [norb_splits: (ndeterminants, ...)]
        dets = jax.tree_map(jnp.linalg.det, orbitals)
        # Move axis to get shape [norb_splits: (..., ndeterminants)]
        fn_inputs = jax.tree_map(lambda x: jnp.moveaxis(x, 0, -1), dets)
        if self.determinant_fn_mode == DeterminantFnMode.SIGN_COVARIANCE:
            psi = jnp.squeeze(self._symmetrized_det_fn(fn_inputs), -1)
        elif self.determinant_fn_mode == DeterminantFnMode.PARALLEL_EVEN:
            psi = self._calculate_psi_parallel_even(fn_inputs)
        elif self.determinant_fn_mode == DeterminantFnMode.PAIRWISE_EVEN:
            psi = self._calculate_psi_pairwise_even(fn_inputs)
        else:
            raise self._get_bad_determinant_fn_mode_error()
        return array_to_slog(psi)

    # slog_det_prods is SLArray of shape (ndeterminants, ...)
    slog_det_prods = slogdet_product(orbitals)
    return slog_sum_over_axis(slog_det_prods)
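When no determinant_fn is set, the last two statements of __call__ take a product of determinants over spins and then a signed sum over the determinant axis, both in (sign, log) form. The sketch below reproduces that arithmetic with jnp.linalg.slogdet and jax.scipy.special.logsumexp on toy matrices. It is an assumption-laden illustration and not the library's slogdet_product or slog_sum_over_axis, whose implementations are not shown on this page.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Toy orbital matrices: ndeterminants = 2, for 2 up-spin and 1 down-spin particle.
orbitals_up = jnp.array(
    [[[1.0, 2.0], [3.0, 4.0]], [[0.0, 1.0], [1.0, 0.0]]]
)  # shape (2, 2, 2)
orbitals_down = jnp.array([[[2.0]], [[-3.0]]])  # shape (2, 1, 1)

# Sign and log|det| of every orbital matrix, for each spin.
sign_up, logabs_up = jnp.linalg.slogdet(orbitals_up)
sign_down, logabs_down = jnp.linalg.slogdet(orbitals_down)

# Product over spins: signs multiply, logs add.
sign_prod = sign_up * sign_down
logabs_prod = logabs_up + logabs_down

# Signed sum over the determinant axis, carried out in log space.
log_psi, sign_psi = logsumexp(logabs_prod, b=sign_prod, return_sign=True)
```

Working with (sign, log|psi|) pairs avoids overflow and underflow in the determinants themselves, which can span many orders of magnitude as the number of particles grows.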
GenericAntisymmetry (Module)
dataclass
A single ResNet antisymmetrized over all input leaves, composed with backflow.
The ResNet is antisymmetrized with respect to each spin split separately (i.e. the antisymmetrization operators for each spin are composed and applied).
This connects the computational graph between a backflow, a generic antisymmetrized ResNet, and a jastrow.
See https://arxiv.org/abs/2112.03491 for a description of the generic antisymmetric layer.
Attributes:
Name | Type | Description |
---|---|---|
spin_split |
ParticleSplit |
either the number of spins, in which case the input is split equally,
or an explicit sequence of locations at which to split along the 2nd-to-last axis.
E.g., if nelec = 10, and |
compute_input_streams |
ComputeInputStreams |
function to compute input streams from electron positions. Has the signature (elec_pos of shape (..., n, d)) -> ( stream_1e of shape (..., n, d'), optional stream_2e of shape (..., nelec, nelec, d2), optional r_ei of shape (..., n, nion, d), optional r_ee of shape (..., n, n, d), ) |
backflow |
Callable |
function which computes position features from the electron positions. Has the signature ( stream_1e of shape (..., n, d'), optional stream_2e of shape (..., nelec, nelec, d2), ) -> stream_1e of shape (..., n, d') |
jastrow |
Callable |
function which computes a log Jastrow factor. In __call__ it is called with (input_stream_1e, input_stream_2e, stream_1e, r_ei, r_ee), where r_ei has shape (batch_dims, n, nion, d) and r_ee has shape (batch_dims, n, n, d), and it returns the log Jastrow of shape (batch_dims,) |
ndense_resnet |
int |
number of dense nodes in each layer of the ResNet |
nlayers_resnet |
int |
number of layers in the antisymmetrized ResNet |
kernel_initializer_resnet |
WeightInitializer |
kernel initializer for the dense layers in the antisymmetrized ResNet. Has signature (key, shape, dtype) -> Array |
bias_initializer_resnet |
WeightInitializer |
bias initializer for the dense layers in the antisymmetrized ResNet. Has signature (key, shape, dtype) -> Array |
activation_fn_resnet |
Activation |
activation function in the antisymmetrized ResNet. Has the signature Array -> Array (shape is preserved) |
resnet_use_bias |
bool |
whether to add a bias term in the dense layers of the antisymmetrized ResNet. Defaults to True. |
setup(self)
Setup backflow.
Source code in vmcnet/models/construct.py
def setup(self):
    """Setup backflow."""
    # workaround MyPy's typing error for callable attribute, see
    # https://github.com/python/mypy/issues/708
    self._compute_input_streams = self.compute_input_streams
    self._backflow = self.backflow
    self._jastrow = self.jastrow
    self._activation_fn_resnet = self.activation_fn_resnet
__call__(self, elec_pos)
special
Compose FermiNet backflow -> antisymmetrized ResNet -> logabs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
elec_pos |
Array |
array of particle positions (..., nelec, d) |
required |
Returns:
Type | Description |
---|---|
Array |
spinful antisymmetrized output; logarithm of the absolute value of an anti-symmetric function of elec_pos, where the anti-symmetry is with respect to the second-to-last axis of elec_pos. The anti-symmetry holds for particles within the same split, but not for permutations which swap particles across different spin splits. If the inputs have shape (batch_dims, nelec, d), then the output has shape (batch_dims,). |
Source code in vmcnet/models/construct.py
@flax.linen.compact
def __call__(self, elec_pos: Array) -> SLArray:  # type: ignore[override]
    """Compose FermiNet backflow -> antisymmetrized ResNet -> logabs.

    Args:
        elec_pos (Array): array of particle positions (..., nelec, d)

    Returns:
        Array: spinful antisymmetrized output; logarithm of the absolute value
        of an anti-symmetric function of elec_pos, where the anti-symmetry is with
        respect to the second-to-last axis of elec_pos. The anti-symmetry holds for
        particles within the same split, but not for permutations which swap
        particles across different spin splits. If the inputs have shape
        (batch_dims, nelec, d), then the output has shape (batch_dims,).
    """
    input_stream_1e, input_stream_2e, r_ei, r_ee = self._compute_input_streams(
        elec_pos
    )
    stream_1e = self._backflow(input_stream_1e, input_stream_2e)
    split_spins = jnp.split(stream_1e, self.spin_split, axis=-2)
    sign_psi, log_antisym = GenericAntisymmetrize(
        SimpleResNet(
            self.ndense_resnet,
            1,
            self.nlayers_resnet,
            self._activation_fn_resnet,
            self.kernel_initializer_resnet,
            self.bias_initializer_resnet,
            use_bias=self.resnet_use_bias,
        )
    )(split_spins)

    jastrow_part = self._jastrow(
        input_stream_1e, input_stream_2e, stream_1e, r_ei, r_ee
    )

    return sign_psi, log_antisym + jastrow_part
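In contrast to the factorized form, the generic form antisymmetrizes one function of all particles, composing the antisymmetrization operators of each spin. A plain-Python sketch on scalar coordinates follows. The helper generic_antisymmetry and the toy function are hypothetical and written for this example; they are not the library's GenericAntisymmetrize, and the backflow, Jastrow factor, and log-domain arithmetic are left out.

```python
import itertools
import math


def permutation_sign(perm):
    """Sign of a permutation given as a tuple of indices, via inversion count."""
    n = len(perm)
    inversions = sum(
        1 for i in range(n) for j in range(i + 1, n) if perm[i] > perm[j]
    )
    return -1.0 if inversions % 2 else 1.0


def generic_antisymmetry(fn, x_up, x_down):
    """Antisymmetrize fn(x_up, x_down) over permutations within each spin."""
    total = 0.0
    for p_up in itertools.permutations(range(len(x_up))):
        for p_down in itertools.permutations(range(len(x_down))):
            sign = permutation_sign(p_up) * permutation_sign(p_down)
            total += sign * fn([x_up[i] for i in p_up], [x_down[j] for j in p_down])
    return total


def coupled_fn(x_up, x_down):
    # Couples the two spins, so it does not factor into per-spin pieces.
    return x_up[0] * math.sin(x_up[1] + x_down[0]) * math.exp(x_down[1])


psi = generic_antisymmetry(coupled_fn, [0.1, 0.5], [0.3, -0.2])
# Swapping particles within one spin flips the sign of the result.
psi_up_swapped = generic_antisymmetry(coupled_fn, [0.5, 0.1], [0.3, -0.2])
psi_down_swapped = generic_antisymmetry(coupled_fn, [0.1, 0.5], [-0.2, 0.3])
```

The nested loop visits every pair of per-spin permutations, so the cost is the product of the factorials of the per-spin particle counts.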
slog_psi_to_log_psi_apply(slog_psi_apply)
Get a log|psi| model apply callable from a sign(psi), log|psi| apply callable.
Source code in vmcnet/models/construct.py
def slog_psi_to_log_psi_apply(
    slog_psi_apply: Callable[..., SLArray]
) -> Callable[..., Array]:
    """Get a log|psi| model apply callable from a sign(psi), log|psi| apply callable."""

    def log_psi_apply(*args) -> Array:
        return slog_psi_apply(*args)[1]

    return log_psi_apply
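A short usage sketch: the wrapper is reproduced from the listing above without its type annotations so that the block runs on its own, and toy_slog_psi_apply is a hypothetical stand-in for a real model's apply function, with a scalar "parameter" and a scalar input.

```python
import math


def slog_psi_to_log_psi_apply(slog_psi_apply):
    """Same logic as the listing above, with type annotations dropped."""

    def log_psi_apply(*args):
        return slog_psi_apply(*args)[1]

    return log_psi_apply


def toy_slog_psi_apply(params, x):
    """Hypothetical model apply: psi = params * x, returned as (sign, log|psi|)."""
    psi = params * x
    return math.copysign(1.0, psi), math.log(abs(psi))


log_psi_apply = slog_psi_to_log_psi_apply(toy_slog_psi_apply)
sign_psi, _ = toy_slog_psi_apply(2.0, -3.0)
log_psi = log_psi_apply(2.0, -3.0)  # only the log|psi| part is returned
```

The wrapper simply discards the sign, which is convenient for code that only needs log|psi|.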
get_model_from_config(model_config, nelec, ion_pos, ion_charges, dtype=jnp.float32)
Get a model from a hyperparameter config.
Source code in vmcnet/models/construct.py
def get_model_from_config(
    model_config: ConfigDict,
    nelec: Array,
    ion_pos: Array,
    ion_charges: Array,
    dtype=jnp.float32,
) -> Module:
    """Get a model from a hyperparameter config."""
    spin_split = get_spin_split(nelec)

    compute_input_streams = get_compute_input_streams_from_config(
        model_config.input_streams, ion_pos
    )
    backflow = get_backflow_from_config(
        model_config.backflow,
        spin_split,
        dtype=dtype,
    )

    kernel_init_constructor, bias_init_constructor = _get_dtype_init_constructors(dtype)

    ferminet_model_types = [
        "ferminet",
        "embedded_particle_ferminet",
        "extended_orbital_matrix_ferminet",
    ]
    if model_config.type in ferminet_model_types:
        determinant_fn = None
        resnet_config = model_config.det_resnet
        if model_config.use_det_resnet:
            determinant_fn = get_resnet_determinant_fn_for_ferminet(
                resnet_config.ndense,
                resnet_config.nlayers,
                _get_named_activation_fn(resnet_config.activation),
                kernel_init_constructor(resnet_config.kernel_init),
                bias_init_constructor(resnet_config.bias_init),
                resnet_config.use_bias,
                resnet_config.register_kfac,
            )

        # TODO(Jeffmin): make interface more flexible w.r.t. different types of Jastrows
        if model_config.type == "ferminet":
            return FermiNet(
                spin_split,
                compute_input_streams,
                backflow,
                model_config.ndeterminants,
                kernel_initializer_orbital_linear=kernel_init_constructor(
                    model_config.kernel_init_orbital_linear
                ),
                kernel_initializer_envelope_dim=kernel_init_constructor(
                    model_config.kernel_init_envelope_dim
                ),
                kernel_initializer_envelope_ion=kernel_init_constructor(
                    model_config.kernel_init_envelope_ion
                ),
                bias_initializer_orbital_linear=bias_init_constructor(
                    model_config.bias_init_orbital_linear
                ),
                orbitals_use_bias=model_config.orbitals_use_bias,
                isotropic_decay=model_config.isotropic_decay,
                determinant_fn=determinant_fn,
                determinant_fn_mode=DeterminantFnMode[resnet_config.mode.upper()],
                full_det=model_config.full_det,
            )
        elif model_config.type == "embedded_particle_ferminet":
            total_nelec = jnp.array(model_config.nhidden_fermions_per_spin) + nelec
            total_spin_split = get_spin_split(total_nelec)

            backflow = get_backflow_from_config(
                model_config.backflow,
                total_spin_split,
                dtype=dtype,
            )
            invariance_config = model_config.invariance
            invariance_compute_input_streams = get_compute_input_streams_from_config(
                invariance_config.input_streams, ion_pos
            )
            invariance_backflow: Optional[Module] = get_backflow_from_config(
                invariance_config.backflow,
                spin_split,
                dtype=dtype,
            )
            return EmbeddedParticleFermiNet(
                spin_split,
                compute_input_streams,
                backflow,
                model_config.ndeterminants,
                kernel_initializer_orbital_linear=kernel_init_constructor(
                    model_config.kernel_init_orbital_linear
                ),
                kernel_initializer_envelope_dim=kernel_init_constructor(
                    model_config.kernel_init_envelope_dim
                ),
                kernel_initializer_envelope_ion=kernel_init_constructor(
                    model_config.kernel_init_envelope_ion
                ),
                bias_initializer_orbital_linear=bias_init_constructor(
                    model_config.bias_init_orbital_linear
                ),
                orbitals_use_bias=model_config.orbitals_use_bias,
                isotropic_decay=model_config.isotropic_decay,
                determinant_fn=determinant_fn,
                determinant_fn_mode=DeterminantFnMode[resnet_config.mode.upper()],
                full_det=model_config.full_det,
                nhidden_fermions_per_spin=model_config.nhidden_fermions_per_spin,
                invariance_compute_input_streams=invariance_compute_input_streams,
                invariance_backflow=invariance_backflow,
                invariance_kernel_initializer=kernel_init_constructor(
                    invariance_config.kernel_initializer
                ),
                invariance_bias_initializer=bias_init_constructor(
                    invariance_config.bias_initializer
                ),
                invariance_use_bias=invariance_config.use_bias,
                invariance_register_kfac=invariance_config.register_kfac,
            )
        elif model_config.type == "extended_orbital_matrix_ferminet":
            invariance_config = model_config.invariance
            if model_config.use_separate_invariance_backflow:
                invariance_backflow = get_backflow_from_config(
                    invariance_config.backflow,
                    spin_split,
                    dtype=dtype,
                )
            else:
                invariance_backflow = None
            return ExtendedOrbitalMatrixFermiNet(
                spin_split,
                compute_input_streams,
                backflow,
                model_config.ndeterminants,
                kernel_initializer_orbital_linear=kernel_init_constructor(
                    model_config.kernel_init_orbital_linear
                ),
                kernel_initializer_envelope_dim=kernel_init_constructor(
                    model_config.kernel_init_envelope_dim
                ),
                kernel_initializer_envelope_ion=kernel_init_constructor(
                    model_config.kernel_init_envelope_ion
                ),
                bias_initializer_orbital_linear=bias_init_constructor(
                    model_config.bias_init_orbital_linear
                ),
                orbitals_use_bias=model_config.orbitals_use_bias,
                isotropic_decay=model_config.isotropic_decay,
                determinant_fn=determinant_fn,
                determinant_fn_mode=DeterminantFnMode[resnet_config.mode.upper()],
                full_det=model_config.full_det,
                nhidden_fermions_per_spin=model_config.nhidden_fermions_per_spin,
                invariance_backflow=invariance_backflow,
                invariance_kernel_initializer=kernel_init_constructor(
                    invariance_config.kernel_initializer
                ),
                invariance_bias_initializer=bias_init_constructor(
                    invariance_config.bias_initializer
                ),
                invariance_use_bias=invariance_config.use_bias,
                invariance_register_kfac=invariance_config.register_kfac,
            )
        else:
            raise ValueError(
                "FermiNet model type {} requested, but the only supported "
                "types are: {}".format(model_config.type, ferminet_model_types)
            )
    elif model_config.type in ["orbital_cofactor_net", "per_particle_dets_net"]:
        if model_config.type == "orbital_cofactor_net":
            antieq_layer: Callable[
                [Array, Array], ArrayList
            ] = antiequivariance.OrbitalCofactorAntiequivarianceLayer(
                spin_split,
                kernel_initializer_orbital_linear=kernel_init_constructor(
                    model_config.kernel_init_orbital_linear
                ),
                kernel_initializer_envelope_dim=kernel_init_constructor(
                    model_config.kernel_init_envelope_dim
                ),
                kernel_initializer_envelope_ion=kernel_init_constructor(
                    model_config.kernel_init_envelope_ion
                ),
                bias_initializer_orbital_linear=bias_init_constructor(
                    model_config.bias_init_orbital_linear
                ),
                orbitals_use_bias=model_config.orbitals_use_bias,
                isotropic_decay=model_config.isotropic_decay,
            )
        elif model_config.type == "per_particle_dets_net":
            antieq_layer = antiequivariance.PerParticleDeterminantAntiequivarianceLayer(
                spin_split,
                kernel_initializer_orbital_linear=kernel_init_constructor(
                    model_config.kernel_init_orbital_linear
                ),
                kernel_initializer_envelope_dim=kernel_init_constructor(
                    model_config.kernel_init_envelope_dim
                ),
                kernel_initializer_envelope_ion=kernel_init_constructor(
                    model_config.kernel_init_envelope_ion
                ),
                bias_initializer_orbital_linear=bias_init_constructor(
                    model_config.bias_init_orbital_linear
                ),
                orbitals_use_bias=model_config.orbitals_use_bias,
                isotropic_decay=model_config.isotropic_decay,
            )

        array_list_sign_covariance = get_sign_covariance_from_config(
            model_config, spin_split, kernel_init_constructor, dtype
        )

        return AntiequivarianceNet(
            spin_split,
            compute_input_streams,
            backflow,
            antieq_layer,
            array_list_sign_covariance,
            multiply_by_eq_features=model_config.multiply_by_eq_features,
        )
    elif model_config.type == "explicit_antisym":
        jastrow_config = model_config.jastrow

        def _get_two_body_decay_jastrow():
            return get_two_body_decay_scaled_for_chargeless_molecules(
                ion_pos,
                ion_charges,
                init_ee_strength=jastrow_config.two_body_decay.init_ee_strength,
                trainable=jastrow_config.two_body_decay.trainable,
            )

        def _get_backflow_based_jastrow():
            if jastrow_config.backflow_based.use_separate_jastrow_backflow:
                jastrow_backflow = get_backflow_from_config(
                    jastrow_config.backflow_based.backflow,
                    spin_split,
                    dtype=dtype,
                )
            else:
                jastrow_backflow = None
            return BackflowJastrow(backflow=jastrow_backflow)

        if jastrow_config.type == "one_body_decay":
            jastrow: Jastrow = OneBodyExpDecay(
                kernel_initializer=kernel_init_constructor(
                    jastrow_config.one_body_decay.kernel_init
                )
            )
        elif jastrow_config.type == "two_body_decay":
            jastrow = _get_two_body_decay_jastrow()
        elif jastrow_config.type == "backflow_based":
            jastrow = _get_backflow_based_jastrow()
        elif jastrow_config.type == "two_body_decay_and_backflow_based":
            two_body_decay_jastrow = _get_two_body_decay_jastrow()
            backflow_jastrow = _get_backflow_based_jastrow()
            jastrow = AddedModel([two_body_decay_jastrow, backflow_jastrow])
        else:
            raise ValueError(
                "Unsupported jastrow type; {} was requested, but the only supported "
                "types are: {}".format(
                    jastrow_config.type, ", ".join(VALID_JASTROW_TYPES)
                )
            )
        if model_config.antisym_type == "factorized":
            return FactorizedAntisymmetry(
                spin_split,
                compute_input_streams,
                backflow,
                jastrow,
                rank=model_config.rank,
                ndense_resnet=model_config.ndense_resnet,
                nlayers_resnet=model_config.nlayers_resnet,
                kernel_initializer_resnet=kernel_init_constructor(
                    model_config.kernel_init_resnet
                ),
                bias_initializer_resnet=bias_init_constructor(
                    model_config.bias_init_resnet
                ),
                activation_fn_resnet=_get_named_activation_fn(
                    model_config.activation_fn_resnet
                ),
                resnet_use_bias=model_config.resnet_use_bias,
            )
        elif model_config.antisym_type == "generic":
            return GenericAntisymmetry(
                spin_split,
                compute_input_streams,
                backflow,
                jastrow,
                ndense_resnet=model_config.ndense_resnet,
                nlayers_resnet=model_config.nlayers_resnet,
                kernel_initializer_resnet=kernel_init_constructor(
                    model_config.kernel_init_resnet
                ),
                bias_initializer_resnet=bias_init_constructor(
                    model_config.bias_init_resnet
                ),
                activation_fn_resnet=_get_named_activation_fn(
                    model_config.activation_fn_resnet
                ),
                resnet_use_bias=model_config.resnet_use_bias,
            )
        else:
            raise ValueError(
                "Unsupported explicit antisymmetry type; {} was requested".format(
                    model_config.antisym_type
                )
            )
    else:
        raise ValueError(
            "Unsupported model type; {} was requested".format(model_config.type)
        )
get_compute_input_streams_from_config(input_streams_config, ion_pos=None)
Get a function for computing input streams from a model configuration.
Source code in vmcnet/models/construct.py
def get_compute_input_streams_from_config(
    input_streams_config: ConfigDict, ion_pos: Optional[Array] = None
) -> ComputeInputStreams:
    """Get a function for computing input streams from a model configuration."""
    return functools.partial(
        compute_input_streams,
        ion_pos=ion_pos,
        include_2e_stream=input_streams_config.include_2e_stream,
        include_ei_norm=input_streams_config.include_ei_norm,
        include_ee_norm=input_streams_config.include_ee_norm,
    )
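The function above only binds configuration values onto `compute_input_streams` with `functools.partial`, so that the electron positions are the one argument left to supply. The sketch below illustrates that pattern in isolation; `toy_compute_input_streams` and the plain-dict config are hypothetical stand-ins, not vmcnet code.

```python
import functools


def toy_compute_input_streams(elec_pos, ion_pos=None, include_2e_stream=True):
    """Hypothetical stand-in: report which streams would be built."""
    streams = ["stream_1e"]
    if include_2e_stream:
        streams.append("stream_2e")
    if ion_pos is not None:
        streams.append("r_ei")
    return streams


# Hypothetical config values; the real code reads them from a ConfigDict.
toy_config = {"include_2e_stream": False}

compute = functools.partial(
    toy_compute_input_streams,
    ion_pos=[[0.0, 0.0, 0.0]],
    include_2e_stream=toy_config["include_2e_stream"],
)

# Only the electron positions remain to be passed at call time.
result = compute([[0.1, 0.2, 0.3]])
```

Binding the options once keeps the model classes independent of the config object: they only ever see a callable of the electron positions.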
get_backflow_from_config(backflow_config, spin_split, dtype=jnp.float32)
Get a FermiNet backflow from a model configuration.
Source code in vmcnet/models/construct.py
def get_backflow_from_config(
    backflow_config,
    spin_split,
    dtype=jnp.float32,
) -> Module:
    """Get a FermiNet backflow from a model configuration."""
    kernel_init_constructor, bias_init_constructor = _get_dtype_init_constructors(dtype)
    residual_blocks = get_residual_blocks_for_ferminet_backflow(
        spin_split,
        backflow_config.ndense_list,
        kernel_initializer_unmixed=kernel_init_constructor(
            backflow_config.kernel_init_unmixed
        ),
        kernel_initializer_mixed=kernel_init_constructor(
            backflow_config.kernel_init_mixed
        ),
        kernel_initializer_2e_1e_stream=kernel_init_constructor(
            backflow_config.kernel_init_2e_1e_stream
        ),
        kernel_initializer_2e_2e_stream=kernel_init_constructor(
            backflow_config.kernel_init_2e_2e_stream
        ),
        bias_initializer_1e_stream=bias_init_constructor(
            backflow_config.bias_init_1e_stream
        ),
        bias_initializer_2e_stream=bias_init_constructor(
            backflow_config.bias_init_2e_stream
        ),
        activation_fn=_get_named_activation_fn(backflow_config.activation_fn),
        use_bias=backflow_config.use_bias,
        one_electron_skip=backflow_config.one_electron_skip,
        one_electron_skip_scale=backflow_config.one_electron_skip_scale,
        two_electron_skip=backflow_config.two_electron_skip,
        two_electron_skip_scale=backflow_config.two_electron_skip_scale,
        cyclic_spins=backflow_config.cyclic_spins,
    )
    return FermiNetBackflow(residual_blocks)
get_sign_covariance_from_config(model_config, spin_split, kernel_init_constructor, dtype)
Get a sign covariance from a model config, for use in AntiequivarianceNet.
Source code in vmcnet/models/construct.py
def get_sign_covariance_from_config(
    model_config: ConfigDict,
    spin_split: ParticleSplit,
    kernel_init_constructor: Callable[[ConfigDict], WeightInitializer],
    dtype: jnp.dtype,
) -> Callable[[ArrayList], Array]:
    """Get a sign covariance from a model config, for use in AntiequivarianceNet."""
    if model_config.use_products_covariance:
        return ProductsSignCovariance(
            1,
            kernel_init_constructor(model_config.products_covariance.kernel_init),
            model_config.products_covariance.register_kfac,
            use_weights=model_config.products_covariance.use_weights,
        )
    else:

        def backflow_based_equivariance(x: ArrayList) -> Array:
            concat_x = jnp.concatenate(x, axis=-2)
            return get_backflow_from_config(
                model_config.invariance,
                spin_split=spin_split,
                dtype=dtype,
            )(concat_x)

        odd_equivariance = make_array_list_fn_sign_covariant(
            backflow_based_equivariance, axis=-3
        )
        return lambda x: jnp.sum(odd_equivariance(x), axis=-2)
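A function is sign covariant in each input when flipping the sign of any one input flips the sign of the output. One standard way to obtain this property from an arbitrary function is to sum over all sign flips of the inputs, weighting each term by the product of the signs. The sketch below shows that construction on scalars; it is an assumption that `make_array_list_fn_sign_covariant` works along these lines, and `make_sign_covariant` and `toy_fn` are hypothetical names.

```python
import itertools
import math


def make_sign_covariant(fn):
    """Return g such that negating any single input negates g's output."""

    def covariant_fn(xs):
        total = 0.0
        # Visit all 2^n sign patterns, one sign per input.
        for signs in itertools.product((1.0, -1.0), repeat=len(xs)):
            flipped = [s * x for s, x in zip(signs, xs)]
            total += math.prod(signs) * fn(flipped)
        return total

    return covariant_fn


def toy_fn(xs):
    # Arbitrary function with no particular symmetry (hypothetical example).
    a, b = xs
    return (a + 2.0) * (b - 1.0) + a * b


g = make_sign_covariant(toy_fn)
value = g([0.5, 1.5])
flipped_first = g([-0.5, 1.5])
flipped_both = g([-0.5, -1.5])
```

Here `flipped_first` is the negative of `value`, and `flipped_both` equals `value` again because two sign flips cancel.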
get_residual_blocks_for_ferminet_backflow(spin_split, ndense_list, kernel_initializer_unmixed, kernel_initializer_mixed, kernel_initializer_2e_1e_stream, kernel_initializer_2e_2e_stream, bias_initializer_1e_stream, bias_initializer_2e_stream, activation_fn, use_bias=True, one_electron_skip=True, one_electron_skip_scale=1.0, two_electron_skip=True, two_electron_skip_scale=1.0, cyclic_spins=True)
Construct a list of FermiNet residual blocks composed by FermiNetBackflow.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
spin_split | int or Sequence[int] | number of spins to split the input equally, or specified sequence of locations to split along the 2nd-to-last axis. E.g., if nelec = 10, and `spin_split` = 2, then the input is split (5, 5). If nelec = 10, and `spin_split` = (2, 4), then the input is split into (2, 4, 4); note that when `spin_split` is a sequence, there will be one more spin than the length of the sequence. In the original use-case of spin-1/2 particles, `spin_split` should be either the number 2 (for closed-shell systems) or a Sequence with length 1 whose element is less than the total number of electrons. | required |
ndense_list | List[Tuple[int, ...]] | number of dense nodes in each of the residual blocks, (ndense_1e, optional ndense_2e). The length of this list determines the number of residual blocks which are composed on top of the input streams. If ndense_2e is specified, then a dense layer is applied to the two-electron stream with an optional skip connection; otherwise the two-electron stream is mixed into the one-electron stream but no transformation is done. | required |
kernel_initializer_unmixed | WeightInitializer | kernel initializer for the unmixed part of the one-electron stream. This initializes the part of the dense kernel which multiplies the previous one-electron stream output. Has signature (key, shape, dtype) -> Array | required |
kernel_initializer_mixed | WeightInitializer | kernel initializer for the mixed part of the one-electron stream. This initializes the part of the dense kernel which multiplies the average of the previous one-electron stream output. Has signature (key, shape, dtype) -> Array | required |
kernel_initializer_2e_1e_stream | WeightInitializer | kernel initializer for the two-electron part of the one-electron stream. This initializes the part of the dense kernel which multiplies the average of the previous two-electron stream which is mixed into the one-electron stream. Has signature (key, shape, dtype) -> Array | required |
kernel_initializer_2e_2e_stream | WeightInitializer | kernel initializer for the two-electron stream. Has signature (key, shape, dtype) -> Array | required |
bias_initializer_1e_stream | WeightInitializer | bias initializer for the one-electron stream. Has signature (key, shape, dtype) -> Array | required |
bias_initializer_2e_stream | WeightInitializer | bias initializer for the two-electron stream. Has signature (key, shape, dtype) -> Array | required |
activation_fn | Activation | activation function in the electron streams. Has the signature Array -> Array (shape is preserved) | required |
use_bias | bool | whether to add a bias term in the electron streams. Defaults to True. | True |
one_electron_skip | bool | whether to add a residual skip connection to the one-electron layer whenever the shapes of the input and output match. Defaults to True. | True |
one_electron_skip_scale | float | quantity to scale the one-electron output by if a skip connection is added. Defaults to 1.0. | 1.0 |
two_electron_skip | bool | whether to add a residual skip connection to the two-electron layer whenever the shapes of the input and output match. Defaults to True. | True |
two_electron_skip_scale | float | quantity to scale the two-electron output by if a skip connection is added. Defaults to 1.0. | 1.0 |
cyclic_spins | bool | whether the concatenation in the one-electron stream should satisfy a cyclic equivariance structure, i.e. if there are three spins (1, 2, 3), then in the mixed part of the stream, after averaging but before the linear transformation, cyclic equivariance means the inputs are [(1, 2, 3), (2, 3, 1), (3, 1, 2)]. If False, then the inputs are [(1, 2, 3), (1, 2, 3), (1, 2, 3)] (as in the original FermiNet). When there are only two spins (spin-1/2 case), this is equivalent to true spin equivariance. Defaults to True. | True |
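The ordering described for `cyclic_spins` can be made concrete with a small sketch. It only reproduces the arrangement of per-spin averages stated in the description; `mixed_stream_inputs` is a hypothetical helper, and the labels 1, 2, 3 stand in for the averaged features of each spin.

```python
def mixed_stream_inputs(spin_means, cyclic_spins=True):
    """Arrange per-spin averages as seen by each spin's mixed input."""
    nspins = len(spin_means)
    if cyclic_spins:
        # Spin i sees the averages rotated so that its own comes first.
        return [tuple(spin_means[i:] + spin_means[:i]) for i in range(nspins)]
    # Every spin sees the averages in the same fixed order.
    return [tuple(spin_means) for _ in range(nspins)]


cyclic = mixed_stream_inputs([1, 2, 3], cyclic_spins=True)
plain = mixed_stream_inputs([1, 2, 3], cyclic_spins=False)
two_spin = mixed_stream_inputs([1, 2])
```

With two spins the cyclic arrangement is `[(1, 2), (2, 1)]`, i.e. each spin sees its own average first and the other spin's second, which is the spin-equivariant case mentioned above.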
Source code in vmcnet/models/construct.py
def get_residual_blocks_for_ferminet_backflow(
    spin_split: ParticleSplit,
    ndense_list: List[Tuple[int, ...]],
    kernel_initializer_unmixed: WeightInitializer,
    kernel_initializer_mixed: WeightInitializer,
    kernel_initializer_2e_1e_stream: WeightInitializer,
    kernel_initializer_2e_2e_stream: WeightInitializer,
    bias_initializer_1e_stream: WeightInitializer,
    bias_initializer_2e_stream: WeightInitializer,
    activation_fn: Activation,
    use_bias: bool = True,
    one_electron_skip: bool = True,
    one_electron_skip_scale: float = 1.0,
    two_electron_skip: bool = True,
    two_electron_skip_scale: float = 1.0,
    cyclic_spins: bool = True,
) -> List[FermiNetResidualBlock]:
    """Construct a list of FermiNet residual blocks composed by FermiNetBackflow.

    Arguments:
        spin_split (int or Sequence[int]): number of spins to split the input equally,
            or specified sequence of locations to split along the 2nd-to-last axis.
            E.g., if nelec = 10, and `spin_split` = 2, then the input is split (5, 5).
            If nelec = 10, and `spin_split` = (2, 4), then the input is split into
            (2, 4, 4) -- note when `spin_split` is a sequence, there will be one more
            spin than the length of the sequence. In the original use-case of spin-1/2
            particles, `spin_split` should be either the number 2 (for closed-shell
            systems) or should be a Sequence with length 1 whose element is less than
            the total number of electrons.
        ndense_list (list of (int, ...)): number of dense nodes in each of the
            residual blocks, (ndense_1e, optional ndense_2e). The length of this list
            determines the number of residual blocks which are composed on top of the
            input streams. If ndense_2e is specified, then a dense layer is applied
            to the two-electron stream with an optional skip connection, otherwise
            the two-electron stream is mixed into the one-electron stream but no
            transformation is done.
        kernel_initializer_unmixed (WeightInitializer): kernel initializer for the
            unmixed part of the one-electron stream. This initializes the part of the
            dense kernel which multiplies the previous one-electron stream output. Has
            signature (key, shape, dtype) -> Array
        kernel_initializer_mixed (WeightInitializer): kernel initializer for the
            mixed part of the one-electron stream. This initializes the part of the
            dense kernel which multiplies the average of the previous one-electron
            stream output. Has signature (key, shape, dtype) -> Array
        kernel_initializer_2e_1e_stream (WeightInitializer): kernel initializer for the
            two-electron part of the one-electron stream. This initializes the part of
            the dense kernel which multiplies the average of the previous two-electron
            stream which is mixed into the one-electron stream. Has signature
            (key, shape, dtype) -> Array
        kernel_initializer_2e_2e_stream (WeightInitializer): kernel initializer for the
            two-electron stream. Has signature (key, shape, dtype) -> Array
        bias_initializer_1e_stream (WeightInitializer): bias initializer for the
            one-electron stream. Has signature (key, shape, dtype) -> Array
        bias_initializer_2e_stream (WeightInitializer): bias initializer for the
            two-electron stream. Has signature (key, shape, dtype) -> Array
        activation_fn (Activation): activation function in the electron streams. Has
            the signature Array -> Array (shape is preserved)
        use_bias (bool, optional): whether to add a bias term in the electron streams.
            Defaults to True.
        one_electron_skip (bool, optional): whether to add a residual skip connection to
            the one-electron layer whenever the shapes of the input and output match.
            Defaults to True.
        one_electron_skip_scale (float, optional): quantity to scale the one-electron
            output by if a skip connection is added. Defaults to 1.0.
        two_electron_skip (bool, optional): whether to add a residual skip connection to
            the two-electron layer whenever the shapes of the input and output match.
            Defaults to True.
        two_electron_skip_scale (float, optional): quantity to scale the two-electron
            output by if a skip connection is added. Defaults to 1.0.
        cyclic_spins (bool, optional): whether the concatenation in the one-electron
            stream should satisfy a cyclic equivariance structure, i.e. if there are
            three spins (1, 2, 3), then in the mixed part of the stream, after averaging
            but before the linear transformation, cyclic equivariance means the inputs
            are [(1, 2, 3), (2, 3, 1), (3, 1, 2)]. If False, then the inputs are
            [(1, 2, 3), (1, 2, 3), (1, 2, 3)] (as in the original FermiNet).
            When there are only two spins (spin-1/2 case), then this is equivalent to
            true spin equivariance. Defaults to True.
    """
    residual_blocks = []
    for ndense in ndense_list:
        one_electron_layer = FermiNetOneElectronLayer(
            spin_split,
            ndense[0],
            kernel_initializer_unmixed,
            kernel_initializer_mixed,
            kernel_initializer_2e_1e_stream,
            bias_initializer_1e_stream,
            activation_fn,
            use_bias,
            skip_connection=one_electron_skip,
            skip_connection_scale=one_electron_skip_scale,
            cyclic_spins=cyclic_spins,
        )
        two_electron_layer = None
        if len(ndense) > 1:
            two_electron_layer = FermiNetTwoElectronLayer(
                ndense[1],
                kernel_initializer_2e_2e_stream,
                bias_initializer_2e_stream,
                activation_fn,
                use_bias,
                skip_connection=two_electron_skip,
                skip_connection_scale=two_electron_skip_scale,
            )
        residual_blocks.append(
            FermiNetResidualBlock(one_electron_layer, two_electron_layer)
        )
    return residual_blocks
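The loop above reads each `ndense_list` entry as `(ndense_1e, optional ndense_2e)` and builds a two-electron layer only when a second element is present. The following sketch mirrors just that bookkeeping, without constructing any layers; the widths are hypothetical values chosen for illustration.

```python
# Hypothetical widths: two blocks with a two-electron layer, one without.
ndense_list = [(256, 16), (256, 16), (256,)]

plan = []
for ndense in ndense_list:
    ndense_1e = ndense[0]
    # Same condition as in the source: a second entry enables the 2e layer.
    ndense_2e = ndense[1] if len(ndense) > 1 else None
    plan.append((ndense_1e, ndense_2e))

nblocks = len(plan)
```

A common reason to end with a one-element tuple is that the final block's two-electron output is never consumed, so transforming it would add parameters without affecting the result.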
get_resnet_determinant_fn_for_ferminet(ndense, nlayers, activation, kernel_initializer, bias_initializer, use_bias=True, register_kfac=False)
Get a resnet-based determinant function for FermiNet construction.
The returned function is used as a more general way to combine the determinant outputs into the final wavefunction value, relative to the original method of a sum of products. The function takes as its first argument the number of requested output features because several variants of this method are supported, and each requires the function to generate a different output size.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ndense | int | the number of neurons in the dense layers of the ResNet. | required |
nlayers | int | the number of layers in the ResNet. | required |
activation | Activation | the activation function to use for the ResNet. | required |
kernel_initializer | WeightInitializer | kernel initializer for the ResNet. | required |
bias_initializer | WeightInitializer | bias initializer for the ResNet. | required |
use_bias | bool | Whether to use a bias in the ResNet. Defaults to True. | True |
register_kfac | bool | Whether to register the ResNet Dense layers with KFAC. Currently, params for this ResNet explode to huge values and cause nans if register_kfac is True, so this flag defaults to False and should only be overridden with care. The reason for the instability is not known. | False |
Returns:
Type | Description |
---|---|
DeterminantFn | A resnet-based function. Has the signature dout, [nspins: (..., ndeterminants)] -> (..., dout). |
Source code in vmcnet/models/construct.py
def get_resnet_determinant_fn_for_ferminet(
    ndense: int,
    nlayers: int,
    activation: Activation,
    kernel_initializer: WeightInitializer,
    bias_initializer: WeightInitializer,
    use_bias: bool = True,
    register_kfac: bool = False,
) -> DeterminantFn:
    """Get a resnet-based determinant function for FermiNet construction.

    The returned function is used as a more general way to combine the determinant
    outputs into the final wavefunction value, relative to the original method of a
    sum of products. The function takes as its first argument the number of requested
    output features because several variants of this method are supported, and each
    requires the function to generate a different output size.

    Args:
        ndense (int): the number of neurons in the dense layers
            of the ResNet.
        nlayers (int): the number of layers in the ResNet.
        activation (Activation): the activation function to use for the resnet.
        kernel_initializer (WeightInitializer): kernel initializer for the resnet.
        bias_initializer (WeightInitializer): bias initializer for the resnet.
        use_bias (bool): Whether to use a bias in the ResNet. Defaults to True.
        register_kfac (bool): Whether to register the ResNet Dense layers with KFAC.
            Currently, params for this ResNet explode to huge values and cause nans
            if register_kfac is True, so this flag defaults to False and should only be
            overridden with care. The reason for the instability is not known.

    Returns:
        (DeterminantFn): A resnet-based function. Has the signature
        dout, [nspins: (..., ndeterminants)] -> (..., dout).
    """

    def fn(dout: int, det_values: ArrayList) -> Array:
        concat_values = jnp.concatenate(det_values, axis=-1)
        return SimpleResNet(
            ndense,
            dout,
            nlayers,
            activation,
            kernel_initializer,
            bias_initializer,
            use_bias,
            register_kfac=register_kfac,
        )(concat_values)

    return fn
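For contrast with the ResNet-based function, the sketch below shows one reading of the "sum of products" baseline: multiply the determinants across spins for each determinant index, then sum over the indices. This reading, the omission of any learned weights, and the helper names are assumptions made for illustration. The second helper shows the flat per-spin concatenation that the ResNet receives in place of the products.

```python
import math


def sum_of_products(det_values):
    """det_values holds one list of ndeterminants values per spin."""
    # zip(*...) groups the k-th determinant of every spin together.
    return sum(math.prod(dets) for dets in zip(*det_values))


def concat_for_resnet(det_values):
    """Flatten per-spin values, analogous to concatenating on the last axis."""
    return [d for dets in det_values for d in dets]


up = [1.0, 2.0]    # hypothetical determinant values for the first spin
down = [3.0, 4.0]  # hypothetical determinant values for the second spin

psi = sum_of_products([up, down])  # 1*3 + 2*4
features = concat_for_resnet([up, down])
```

The ResNet input therefore has `nspins * ndeterminants` features, and the `dout` argument sets how many outputs it produces from them.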