core
Core model building parts.
AddedModel (Module)
dataclass
A model made from added parts.
Attributes:

Name | Type | Description |
---|---|---|
submodels | Sequence[Union[Callable, Module]] | a sequence of functions or Modules which are called on the same args and can be added |
__call__(self, *args)
special
Add the outputs of the submodels.
Source code in vmcnet/models/core.py
@flax.linen.compact
def __call__(self, *args):
    """Add the outputs of the submodels."""
    return sum(submodel(*args) for submodel in self.submodels)
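A minimal usage sketch (assuming the vmcnet.models.core import path and the standard flax.linen init/apply pattern; the submodels must produce arrays of matching shape so their outputs can be summed):

import jax
import jax.numpy as jnp
from vmcnet.models.core import AddedModel, Dense

# Two branches called on the same input; their outputs are summed elementwise.
model = AddedModel(submodels=[Dense(4), Dense(4)])
x = jnp.ones((2, 3))
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, x)  # shape (2, 4)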
ComposedModel (Module)
dataclass
A model made from composable parts.
Attributes:

Name | Type | Description |
---|---|---|
submodels | Sequence[Union[Callable, Module]] | a sequence of functions or Modules which can be composed sequentially |
__call__(self, x)
special
Call submodels on the output of the previous one, one at a time.
Source code in vmcnet/models/core.py
@flax.linen.compact
def __call__(self, x):
    """Call submodels on the output of the previous one, one at a time."""
    outputs = x
    for model in self.submodels:
        outputs = model(outputs)
    return outputs
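A minimal usage sketch (same assumptions as above); because the submodels are applied left to right, plain functions such as jnp.tanh can be interleaved with Modules:

import jax
import jax.numpy as jnp
from vmcnet.models.core import ComposedModel, Dense

# Dense -> tanh -> Dense, applied in sequence.
model = ComposedModel(submodels=[Dense(16), jnp.tanh, Dense(4)])
x = jnp.ones((2, 3))
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, x)  # shape (2, 4)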
Dense (Module)
dataclass
A linear transformation applied over the last dimension of the input.
This is a copy of the flax Dense layer, but with registration of the weights for use with KFAC.
Attributes:

Name | Type | Description |
---|---|---|
features | int | the number of output features. |
kernel_init | WeightInitializer | initializer function for the weight matrix. Defaults to orthogonal initialization. |
bias_init | WeightInitializer | initializer function for the bias. Defaults to random normal initialization. |
use_bias | bool | whether to add a bias to the output. Defaults to True. |
register_kfac | bool | whether to register the computation with KFAC. Defaults to True. |
__call__(self, inputs)
special
Applies a linear transformation with optional bias along the last dimension.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
inputs | Array | The nd-array to be transformed. | required |

Returns:

Type | Description |
---|---|
Array | The transformed input. |
Source code in vmcnet/models/core.py
@flax.linen.compact
def __call__(self, inputs: Array) -> Array:  # type: ignore[override]
    """Applies a linear transformation with optional bias along the last dimension.

    Args:
        inputs (Array): The nd-array to be transformed.

    Returns:
        Array: The transformed input.
    """
    kernel = self.param(
        "kernel", self.kernel_init, (inputs.shape[-1], self.features)
    )
    y = jnp.dot(inputs, kernel)

    bias = None
    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features,))
        y = y + bias

    if self.register_kfac:
        return register_batch_dense(y, inputs, kernel, bias)
    else:
        return y
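A minimal usage sketch (relying on the documented defaults for kernel_init, bias_init, use_bias, and register_kfac):

import jax
import jax.numpy as jnp
from vmcnet.models.core import Dense

layer = Dense(features=8)  # orthogonal kernel init, random normal bias init by default
x = jnp.ones((4, 3))
params = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(params, x)  # shape (4, 8)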
LogDomainDense (Module)
dataclass
A linear transformation applied on the last axis of the input, in the log domain.
If the inputs are (sign(x), log(abs(x))), the outputs are (sign(Wx + b), log(abs(Wx + b))).
The bias is implemented by extending the inputs with a vector of ones.
Attributes:

Name | Type | Description |
---|---|---|
features | int | the number of output features. |
kernel_init | WeightInitializer | initializer function for the weight matrix. Defaults to orthogonal initialization. |
use_bias | bool | whether to add a bias to the output. Defaults to True. |
register_kfac | bool | whether to register the computation with KFAC. Defaults to True. |
__call__(self, x)
special
Applies a linear transformation with optional bias along the last dimension.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
x | SLArray | The nd-array in slog form to be transformed. | required |

Returns:

Type | Description |
---|---|
SLArray | The transformed input, in slog form. |
Source code in vmcnet/models/core.py
@flax.linen.compact
def __call__(self, x: SLArray) -> SLArray:  # type: ignore[override]
    """Applies a linear transformation with optional bias along the last dimension.

    Args:
        x (SLArray): The nd-array in slog form to be transformed.

    Returns:
        SLArray: The transformed input, in slog form.
    """
    sign_x, log_abs_x = x
    input_dim = log_abs_x.shape[-1]

    if self.use_bias:
        input_dim += 1
        sign_x = jnp.concatenate([sign_x, jnp.ones_like(sign_x[..., 0:1])], axis=-1)
        log_abs_x = jnp.concatenate(
            [log_abs_x, jnp.zeros_like(log_abs_x[..., 0:1])], axis=-1
        )

    kernel = self.param("kernel", self.kernel_init, (input_dim, self.features))

    return log_linear_exp(
        sign_x,
        log_abs_x,
        kernel,
        axis=-1,
        register_kfac=self.register_kfac,
    )
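A minimal usage sketch; the input is an SLArray tuple (sign(x), log(abs(x))), built here by hand for illustration:

import jax
import jax.numpy as jnp
from vmcnet.models.core import LogDomainDense

x = jnp.array([[0.5, -2.0, 3.0]])
slog_x = (jnp.sign(x), jnp.log(jnp.abs(x)))  # slog form of x

layer = LogDomainDense(features=4)
params = layer.init(jax.random.PRNGKey(0), slog_x)
sign_out, log_abs_out = layer.apply(params, slog_x)  # each has shape (1, 4)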
LogDomainResNet (Module)
dataclass
Simplest fully-connected ResNet, implemented in the log domain.
Attributes:

Name | Type | Description |
---|---|---|
ndense_inner | int | number of dense nodes in layers before the final layer. |
ndense_final | int | number of output features, i.e. the number of dense nodes in the final Dense call. |
nlayers | int | number of dense layers applied to the input, including the final layer. If this is 0, the final dense layer will still be applied. |
activation_fn | SLActivation | activation function between intermediate layers (is not applied after the final dense layer). Has the signature SLArray -> SLArray (shape is preserved). |
kernel_init | WeightInitializer | initializer function for the weight matrices of each layer. Defaults to orthogonal initialization. |
use_bias | bool | whether the dense layers should all have bias terms or not. Defaults to True. |
setup(self)
Setup dense layers.
Source code in vmcnet/models/core.py
def setup(self):
    """Setup dense layers."""
    # workaround MyPy's typing error for callable attribute, see
    # https://github.com/python/mypy/issues/708
    self._activation_fn = self.activation_fn

    self.inner_dense = [
        LogDomainDense(
            self.ndense_inner,
            kernel_init=self.kernel_init,
            use_bias=self.use_bias,
        )
        for _ in range(self.nlayers - 1)
    ]
    self.final_dense = LogDomainDense(
        self.ndense_final,
        kernel_init=self.kernel_init,
        use_bias=False,
    )
__call__(self, x)
special
Repeated application of (dense layer -> activation -> optional skip) block.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
x | SLArray | an slog input array of shape (..., d) | required |

Returns:

Type | Description |
---|---|
SLArray | slog array of shape (..., self.ndense_final) |
Source code in vmcnet/models/core.py
def __call__(self, x: SLArray) -> SLArray:  # type: ignore[override]
    """Repeated application of (dense layer -> activation -> optional skip) block.

    Args:
        x (SLArray): an slog input array of shape (..., d)

    Returns:
        SLArray: slog array of shape (..., self.ndense_final)
    """
    for dense_layer in self.inner_dense:
        prev_x = x
        x = dense_layer(prev_x)
        x = self._activation_fn(x)

        if _sl_valid_skip(prev_x, x):
            x = slog_sum(x, prev_x)

    return self.final_dense(x)
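A minimal usage sketch; sl_tanh below is a hypothetical SLArray -> SLArray activation written only for illustration (it detours through the linear domain, which defeats the purpose of log-domain arithmetic; a real model would use an activation designed for slog inputs):

import jax
import jax.numpy as jnp
from vmcnet.models.core import LogDomainResNet

def sl_tanh(slx):
    """Illustrative shape-preserving SLArray -> SLArray activation."""
    sign, log_abs = slx
    y = jnp.tanh(sign * jnp.exp(log_abs))
    return jnp.sign(y), jnp.log(jnp.abs(y))

x = jnp.array([[0.5, -2.0, 3.0]])
slog_x = (jnp.sign(x), jnp.log(jnp.abs(x)))

resnet = LogDomainResNet(
    ndense_inner=8, ndense_final=2, nlayers=3, activation_fn=sl_tanh
)
params = resnet.init(jax.random.PRNGKey(0), slog_x)
sign_out, log_abs_out = resnet.apply(params, slog_x)  # each has shape (1, 2)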
SimpleResNet (Module)
dataclass
Simplest fully-connected ResNet.
Attributes:

Name | Type | Description |
---|---|---|
ndense_inner | int | number of dense nodes in layers before the final layer. |
ndense_final | int | number of output features, i.e. the number of dense nodes in the final Dense call. |
nlayers | int | number of dense layers applied to the input, including the final layer. If this is 0, the final dense layer will still be applied. |
kernel_init | WeightInitializer | initializer function for the weight matrices of each layer. Defaults to orthogonal initialization. |
bias_init | WeightInitializer | initializer function for the bias. Defaults to random normal initialization. |
activation_fn | Activation | activation function between intermediate layers (is not applied after the final dense layer). Has the signature Array -> Array (shape is preserved). |
use_bias | bool | whether the dense layers should all have bias terms or not. Defaults to True. |
register_kfac | bool | whether to register the dense layers with KFAC. Defaults to True. |
setup(self)
Setup dense layers.
Source code in vmcnet/models/core.py
def setup(self):
    """Setup dense layers."""
    # workaround MyPy's typing error for callable attribute, see
    # https://github.com/python/mypy/issues/708
    self._activation_fn = self.activation_fn

    self.inner_dense = [
        Dense(
            self.ndense_inner,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            use_bias=self.use_bias,
            register_kfac=self.register_kfac,
        )
        for _ in range(self.nlayers - 1)
    ]
    self.final_dense = Dense(
        self.ndense_final,
        kernel_init=self.kernel_init,
        bias_init=self.bias_init,
        use_bias=False,
        register_kfac=self.register_kfac,
    )
__call__(self, x)
special
Repeated application of (dense layer -> activation -> optional skip) block.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
x | Array | an input array of shape (..., d) | required |

Returns:

Type | Description |
---|---|
Array | array of shape (..., self.ndense_final) |
Source code in vmcnet/models/core.py
def __call__(self, x: Array) -> Array:  # type: ignore[override]
    """Repeated application of (dense layer -> activation -> optional skip) block.

    Args:
        x (Array): an input array of shape (..., d)

    Returns:
        Array: array of shape (..., self.ndense_final)
    """
    for dense_layer in self.inner_dense:
        prev_x = x
        x = dense_layer(prev_x)
        x = self._activation_fn(x)

        if _valid_skip(prev_x, x):
            x = cast(Array, x + prev_x)

    return self.final_dense(x)
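A minimal usage sketch (keyword arguments follow the attribute names listed above; jnp.tanh serves as the Array -> Array activation):

import jax
import jax.numpy as jnp
from vmcnet.models.core import SimpleResNet

resnet = SimpleResNet(
    ndense_inner=16, ndense_final=4, nlayers=3, activation_fn=jnp.tanh
)
x = jnp.ones((2, 5))
params = resnet.init(jax.random.PRNGKey(0), x)
y = resnet.apply(params, x)  # shape (2, 4)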
compute_ee_norm_with_safe_diag(r_ee)
Get electron-electron distances with a safe derivative along the diagonal.
Avoids computing norm(x - x) along the diagonal, since autograd will be unhappy about differentiating through the norm function evaluated at 0. Instead compute 0 * norm(x - x + 1) along the diagonal.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
r_ee | Array | electron-electron displacements with shape (..., n, n, d) | required |

Returns:

Type | Description |
---|---|
Array | electron-electron distances with shape (..., n, n, 1) |
Source code in vmcnet/models/core.py
def compute_ee_norm_with_safe_diag(r_ee):
    """Get electron-electron distances with a safe derivative along the diagonal.

    Avoids computing norm(x - x) along the diagonal, since autograd will be unhappy
    about differentiating through the norm function evaluated at 0. Instead compute
    0 * norm(x - x + 1) along the diagonal.

    Args:
        r_ee (Array): electron-electron displacements with shape (..., n, n, d)

    Returns:
        Array: electron-electron distances with shape (..., n, n, 1)
    """
    n = r_ee.shape[-2]
    eye_n = jnp.expand_dims(jnp.eye(n), axis=-1)
    r_ee_diag_ones = r_ee + eye_n
    return jnp.linalg.norm(r_ee_diag_ones, axis=-1, keepdims=True) * (1.0 - eye_n)
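A minimal usage sketch showing that the result has a zero diagonal and that gradients through it stay finite:

import jax
import jax.numpy as jnp
from vmcnet.models.core import compute_ee_norm_with_safe_diag

elec_pos = jax.random.normal(jax.random.PRNGKey(0), (6, 3))  # n=6 electrons in d=3
r_ee = elec_pos[:, None, :] - elec_pos[None, :, :]           # displacements, shape (6, 6, 3)
dists = compute_ee_norm_with_safe_diag(r_ee)                 # shape (6, 6, 1), zero diagonal

# No NaNs from differentiating norm(0): the gradient is finite everywhere.
grads = jax.grad(lambda r: jnp.sum(compute_ee_norm_with_safe_diag(r)))(r_ee)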
is_tuple_of_arrays(x)
Returns True if x is a tuple of Array objects.
Source code in vmcnet/models/core.py
def is_tuple_of_arrays(x: PyTree) -> bool:
    """Returns True if x is a tuple of Array objects."""
    return isinstance(x, tuple) and all(isinstance(x_i, jnp.ndarray) for x_i in x)
get_alternating_signs(n)
Return alternating series of 1 and -1, of length n.
Source code in vmcnet/models/core.py
def get_alternating_signs(n: int) -> Array:
    """Return alternating series of 1 and -1, of length n."""
    return jax.ops.index_update(jnp.ones(n), jax.ops.index[1::2], -1.0)
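Note that jax.ops.index_update and jax.ops.index have been removed from recent JAX releases; a sketch of an equivalent using the current .at[...] indexed-update syntax:

import jax.numpy as jnp

def get_alternating_signs_v2(n: int):
    """Same result as above, expressed with the .at[...] indexed-update API."""
    return jnp.ones(n).at[1::2].set(-1.0)

get_alternating_signs_v2(5)  # [ 1., -1.,  1., -1.,  1.]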
get_nsplits(split)
Get the number of splits from a particle split specification.
Source code in vmcnet/models/core.py
def get_nsplits(split: ParticleSplit) -> int:
    """Get the number of splits from a particle split specification."""
    if isinstance(split, int):
        return split
    return len(split) + 1
get_nelec_per_split(split, nelec_total)
From a particle split and nelec_total, get the number of particles per split.
If the number of particles per split is nelec_per_spin = (n1, n2, ..., nk), then split should be jnp.cumsum(nelec_per_spin)[:-1], or an integer if these are all equal. This function is the inverse of that operation.
Source code in vmcnet/models/core.py
def get_nelec_per_split(split: ParticleSplit, nelec_total: int) -> Tuple[int, ...]:
    """From a particle split and nelec_total, get the number of particles per split.

    If the number of particles per split is nelec_per_spin = (n1, n2, ..., nk), then
    split should be jnp.cumsum(nelec_per_spin)[:-1], or an integer if these are all
    equal. This function is the inverse of that operation.
    """
    if isinstance(split, int):
        return (nelec_total // split,) * split
    else:
        spin_diffs = jnp.diff(jnp.array(split))
        return (
            split[0],
            *tuple([int(i) for i in spin_diffs]),
            nelec_total - split[-1],
        )
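A few illustrative calls covering both forms of the split specification:

from vmcnet.models.core import get_nelec_per_split, get_nsplits

get_nsplits(2)                # 2 splits, specified as an integer
get_nsplits((3,))             # 2 splits, specified by cumulative boundaries
get_nelec_per_split(2, 6)     # (3, 3): 6 electrons divided evenly into 2 splits
get_nelec_per_split((3,), 5)  # (3, 2): boundary at 3 within 5 electrons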
get_spin_split(n_per_split)
Calculate spin split from n_per_split, making sure to output a Tuple of ints.
Source code in vmcnet/models/core.py
def get_spin_split(n_per_split: Union[Sequence[int], Array]) -> Tuple[int, ...]:
    """Calculate spin split from n_per_split, making sure to output a Tuple of ints."""
    cumsum = np.cumsum(n_per_split[:-1])
    # Convert to tuple of python ints.
    return tuple([int(i) for i in cumsum])
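A sketch of the round trip with get_nelec_per_split:

from vmcnet.models.core import get_spin_split, get_nelec_per_split

nelec_per_spin = (3, 2)
split = get_spin_split(nelec_per_spin)            # (3,)
get_nelec_per_split(split, sum(nelec_per_spin))   # recovers (3, 2)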