
core

Core model building parts.

AddedModel (Module) dataclass

A model made from added parts.

Attributes:
    submodels (Sequence[Union[Callable, Module]]): a sequence of functions or
        Modules which are called on the same args and can be added

__call__(self, *args) special

Add the outputs of the submodels.

Source code in vmcnet/models/core.py
@flax.linen.compact
def __call__(self, *args):
    """Add the outputs of the submodels."""
    return sum(submodel(*args) for submodel in self.submodels)

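A minimal usage sketch (not from the vmcnet source): the submodels below, a Dense layer and a constant-offset lambda, are illustrative choices and assume the documented defaults for Dense's initializers; any callables whose outputs share a shape can be added.

import jax
import jax.numpy as jnp
from vmcnet.models.core import AddedModel, Dense

# Two submodels called on the same input; their outputs are summed elementwise.
model = AddedModel(submodels=[Dense(4), lambda x: 2.0 * jnp.ones((x.shape[0], 4))])

x = jnp.ones((3, 7))
params = model.init(jax.random.PRNGKey(0), x)
out = model.apply(params, x)  # shape (3, 4)
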
ComposedModel (Module) dataclass

A model made from composable parts.

Attributes:
    submodels (Sequence[Union[Callable, Module]]): a sequence of functions or
        Modules which can be composed sequentially

__call__(self, x) special

Call each submodel on the output of the previous one, one at a time.

Source code in vmcnet/models/core.py
@flax.linen.compact
def __call__(self, x):
    """Call submodels on the output of the previous one one at a time."""
    outputs = x
    for model in self.submodels:
        outputs = model(outputs)
    return outputs

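A minimal usage sketch (illustrative widths and activation, assuming the documented defaults for Dense's initializers):

import jax
import jax.numpy as jnp
from vmcnet.models.core import ComposedModel, Dense

# Dense -> tanh -> Dense, applied in sequence, each to the previous output.
model = ComposedModel(submodels=[Dense(8), jnp.tanh, Dense(2)])

x = jnp.ones((3, 5))
params = model.init(jax.random.PRNGKey(0), x)
out = model.apply(params, x)  # shape (3, 2)
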
Dense (Module) dataclass

A linear transformation applied over the last dimension of the input.

This is a copy of the flax Dense layer, but with registration of the weights for use with KFAC.

Attributes:
    features (int): the number of output features.
    kernel_init (WeightInitializer): initializer function for the weight matrix.
        Defaults to orthogonal initialization.
    bias_init (WeightInitializer): initializer function for the bias. Defaults to
        random normal initialization.
    use_bias (bool): whether to add a bias to the output. Defaults to True.
    register_kfac (bool): whether to register the computation with KFAC. Defaults
        to True.

__call__(self, inputs) special

Applies a linear transformation with optional bias along the last dimension.

Parameters:
    inputs (Array): The nd-array to be transformed. Required.

Returns:
    Array: The transformed input.

Source code in vmcnet/models/core.py
@flax.linen.compact
def __call__(self, inputs: Array) -> Array:  # type: ignore[override]
    """Applies a linear transformation with optional bias along the last dimension.

    Args:
        inputs (Array): The nd-array to be transformed.

    Returns:
        Array: The transformed input.
    """
    kernel = self.param(
        "kernel", self.kernel_init, (inputs.shape[-1], self.features)
    )
    y = jnp.dot(inputs, kernel)
    bias = None
    if self.use_bias:
        bias = self.param("bias", self.bias_init, (self.features,))
        y = y + bias

    if self.register_kfac:
        return register_batch_dense(y, inputs, kernel, bias)
    else:
        return y

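A minimal usage sketch, assuming the documented defaults for kernel_init and bias_init; the shapes are illustrative.

import jax
import jax.numpy as jnp
from vmcnet.models.core import Dense

layer = Dense(features=10)
x = jnp.ones((4, 6))

params = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(params, x)  # shape (4, 10)

# The registered parameters follow the usual flax linen layout:
# params["params"]["kernel"].shape == (6, 10), params["params"]["bias"].shape == (10,)
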
LogDomainDense (Module) dataclass

A linear transformation applied on the last axis of the input, in the log domain.

If the inputs are (sign(x), log(abs(x))), the outputs are (sign(Wx + b), log(abs(Wx + b))).

The bias is implemented by extending the inputs with a vector of ones.

Attributes:
    features (int): the number of output features.
    kernel_init (WeightInitializer): initializer function for the weight matrix.
        Defaults to orthogonal initialization.
    use_bias (bool): whether to add a bias to the output. Defaults to True.
    register_kfac (bool): whether to register the computation with KFAC. Defaults
        to True.

__call__(self, x) special

Applies a linear transformation with optional bias along the last dimension.

Parameters:
    x (SLArray): The nd-array in slog form to be transformed. Required.

Returns:
    SLArray: The transformed input, in slog form.

Source code in vmcnet/models/core.py
@flax.linen.compact
def __call__(self, x: SLArray) -> SLArray:  # type: ignore[override]
    """Applies a linear transformation with optional bias along the last dimension.

    Args:
        x (SLArray): The nd-array in slog form to be transformed.

    Returns:
        SLArray: The transformed input, in slog form.
    """
    sign_x, log_abs_x = x
    input_dim = log_abs_x.shape[-1]

    if self.use_bias:
        input_dim += 1
        sign_x = jnp.concatenate([sign_x, jnp.ones_like(sign_x[..., 0:1])], axis=-1)
        log_abs_x = jnp.concatenate(
            [log_abs_x, jnp.zeros_like(log_abs_x[..., 0:1])], axis=-1
        )

    kernel = self.param("kernel", self.kernel_init, (input_dim, self.features))

    return log_linear_exp(
        sign_x,
        log_abs_x,
        kernel,
        axis=-1,
        register_kfac=self.register_kfac,
    )

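A minimal usage sketch: the slog form built below, a (sign, log abs) pair, follows the convention stated above; the shapes and seeds are illustrative.

import jax
import jax.numpy as jnp
from vmcnet.models.core import LogDomainDense

x = jax.random.normal(jax.random.PRNGKey(0), (4, 6))
slog_x = (jnp.sign(x), jnp.log(jnp.abs(x)))  # (sign(x), log(abs(x)))

layer = LogDomainDense(features=3)
params = layer.init(jax.random.PRNGKey(1), slog_x)
sign_out, log_abs_out = layer.apply(params, slog_x)  # each of shape (4, 3)

# Back in the ordinary domain, sign_out * exp(log_abs_out) equals Wx + b.
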
LogDomainResNet (Module) dataclass

Simplest fully-connected ResNet, implemented in the log domain.

Attributes:
    ndense_inner (int): number of dense nodes in layers before the final layer.
    ndense_final (int): number of output features, i.e. the number of dense nodes
        in the final Dense call.
    nlayers (int): number of dense layers applied to the input, including the final
        layer. If this is 0, the final dense layer will still be applied.
    activation_fn (SLActivation): activation function between intermediate layers
        (is not applied after the final dense layer). Has the signature
        SLArray -> SLArray (shape is preserved).
    kernel_init (WeightInitializer): initializer function for the weight matrices
        of each layer. Defaults to orthogonal initialization.
    use_bias (bool): whether the dense layers should all have bias terms or not.
        Defaults to True.

setup(self)

Setup dense layers.

Source code in vmcnet/models/core.py
def setup(self):
    """Setup dense layers."""
    # workaround MyPy's typing error for callable attribute, see
    # https://github.com/python/mypy/issues/708
    self._activation_fn = self.activation_fn

    self.inner_dense = [
        LogDomainDense(
            self.ndense_inner,
            kernel_init=self.kernel_init,
            use_bias=self.use_bias,
        )
        for _ in range(self.nlayers - 1)
    ]
    self.final_dense = LogDomainDense(
        self.ndense_final,
        kernel_init=self.kernel_init,
        use_bias=False,
    )

__call__(self, x) special

Repeated application of (dense layer -> activation -> optional skip) block.

Parameters:
    x (SLArray): an slog input array of shape (..., d). Required.

Returns:
    SLArray: slog array of shape (..., self.ndense_final)

Source code in vmcnet/models/core.py
def __call__(self, x: SLArray) -> SLArray:  # type: ignore[override]
    """Repeated application of (dense layer -> activation -> optional skip) block.

    Args:
        x (SLArray): an slog input array of shape (..., d)

    Returns:
        SLArray: slog array of shape (..., self.ndense_final)
    """
    for dense_layer in self.inner_dense:
        prev_x = x
        x = dense_layer(prev_x)
        x = self._activation_fn(x)
        if _sl_valid_skip(prev_x, x):
            x = slog_sum(x, prev_x)

    return self.final_dense(x)

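A minimal usage sketch. The slog_gelu activation below is hypothetical (vmcnet's own SLActivation functions are not shown on this page); it only illustrates the required SLArray -> SLArray signature.

import jax
import jax.numpy as jnp
from vmcnet.models.core import LogDomainResNet

def slog_gelu(x):
    # Hypothetical SLArray -> SLArray activation, used only for illustration.
    sign, log_abs = x
    vals = jax.nn.gelu(sign * jnp.exp(log_abs))
    return jnp.sign(vals), jnp.log(jnp.abs(vals))

net = LogDomainResNet(ndense_inner=8, ndense_final=2, nlayers=3, activation_fn=slog_gelu)

x = jax.random.normal(jax.random.PRNGKey(0), (4, 5))
slog_x = (jnp.sign(x), jnp.log(jnp.abs(x)))
params = net.init(jax.random.PRNGKey(1), slog_x)
sign_out, log_abs_out = net.apply(params, slog_x)  # each of shape (4, 2)
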
SimpleResNet (Module) dataclass

Simplest fully-connected ResNet.

Attributes:
    ndense_inner (int): number of dense nodes in layers before the final layer.
    ndense_final (int): number of output features, i.e. the number of dense nodes
        in the final Dense call.
    nlayers (int): number of dense layers applied to the input, including the final
        layer. If this is 0, the final dense layer will still be applied.
    kernel_init (WeightInitializer): initializer function for the weight matrices
        of each layer. Defaults to orthogonal initialization.
    bias_init (WeightInitializer): initializer function for the bias. Defaults to
        random normal initialization.
    activation_fn (Activation): activation function between intermediate layers
        (is not applied after the final dense layer). Has the signature
        Array -> Array (shape is preserved).
    use_bias (bool): whether the dense layers should all have bias terms or not.
        Defaults to True.
    register_kfac (bool): whether to register the dense layers with KFAC. Defaults
        to True.

setup(self)

Setup dense layers.

Source code in vmcnet/models/core.py
def setup(self):
    """Setup dense layers."""
    # workaround MyPy's typing error for callable attribute, see
    # https://github.com/python/mypy/issues/708
    self._activation_fn = self.activation_fn

    self.inner_dense = [
        Dense(
            self.ndense_inner,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            use_bias=self.use_bias,
            register_kfac=self.register_kfac,
        )
        for _ in range(self.nlayers - 1)
    ]
    self.final_dense = Dense(
        self.ndense_final,
        kernel_init=self.kernel_init,
        bias_init=self.bias_init,
        use_bias=False,
        register_kfac=self.register_kfac,
    )

__call__(self, x) special

Repeated application of (dense layer -> activation -> optional skip) block.

Parameters:
    x (Array): an input array of shape (..., d). Required.

Returns:
    Array: array of shape (..., self.ndense_final)

Source code in vmcnet/models/core.py
def __call__(self, x: Array) -> Array:  # type: ignore[override]
    """Repeated application of (dense layer -> activation -> optional skip) block.

    Args:
        x (Array): an input array of shape (..., d)

    Returns:
        Array: array of shape (..., self.ndense_final)
    """
    for dense_layer in self.inner_dense:
        prev_x = x
        x = dense_layer(prev_x)
        x = self._activation_fn(x)
        if _valid_skip(prev_x, x):
            x = cast(Array, x + prev_x)

    return self.final_dense(x)

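A minimal usage sketch, assuming the documented defaults for kernel_init and bias_init; the widths and activation are illustrative choices.

import jax
import jax.numpy as jnp
from vmcnet.models.core import SimpleResNet

net = SimpleResNet(ndense_inner=16, ndense_final=1, nlayers=3, activation_fn=jnp.tanh)

x = jnp.ones((4, 7))
params = net.init(jax.random.PRNGKey(0), x)
y = net.apply(params, x)  # shape (4, 1)
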
compute_ee_norm_with_safe_diag(r_ee)

Get electron-electron distances with a safe derivative along the diagonal.

Avoids computing norm(x - x) along the diagonal, since autograd will be unhappy about differentiating through the norm function evaluated at 0. Instead compute 0 * norm(x - x + 1) along the diagonal.

Parameters:
    r_ee (Array): electron-electron displacements with shape (..., n, n, d). Required.

Returns:
    Array: electron-electron distances with shape (..., n, n, 1)

Source code in vmcnet/models/core.py
def compute_ee_norm_with_safe_diag(r_ee):
    """Get electron-electron distances with a safe derivative along the diagonal.

    Avoids computing norm(x - x) along the diagonal, since autograd will be unhappy
    about differentiating through the norm function evaluated at 0. Instead compute
    0 * norm(x - x + 1) along the diagonal.

    Args:
        r_ee (Array): electron-electron displacements with shape (..., n, n, d)

    Returns:
        Array: electron-electron distances with shape (..., n, n, 1)
    """
    n = r_ee.shape[-2]
    eye_n = jnp.expand_dims(jnp.eye(n), axis=-1)
    r_ee_diag_ones = r_ee + eye_n
    return jnp.linalg.norm(r_ee_diag_ones, axis=-1, keepdims=True) * (1.0 - eye_n)

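A minimal sketch of how this is typically used: build displacements from positions, take safe norms, and differentiate through the result. The displacement convention (position i minus position j) is an illustrative assumption.

import jax
import jax.numpy as jnp
from vmcnet.models.core import compute_ee_norm_with_safe_diag

# Positions with shape (batch, n, d); r_ee[..., i, j, :] = pos_i - pos_j.
pos = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 3))
r_ee = jnp.expand_dims(pos, axis=-2) - jnp.expand_dims(pos, axis=-3)

dists = compute_ee_norm_with_safe_diag(r_ee)  # shape (2, 4, 4, 1), zeros on the diagonal

# Differentiating through the distances stays finite despite the zero diagonal.
def total_distance(p):
    disp = jnp.expand_dims(p, axis=-2) - jnp.expand_dims(p, axis=-3)
    return jnp.sum(compute_ee_norm_with_safe_diag(disp))

grads = jax.grad(total_distance)(pos)  # shape (2, 4, 3), no NaNs
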
is_tuple_of_arrays(x)

Returns True if x is a tuple of Array objects.

Source code in vmcnet/models/core.py
def is_tuple_of_arrays(x: PyTree) -> bool:
    """Returns True if x is a tuple of Array objects."""
    return isinstance(x, tuple) and all(isinstance(x_i, jnp.ndarray) for x_i in x)

get_alternating_signs(n)

Return alternating series of 1 and -1, of length n.

Source code in vmcnet/models/core.py
def get_alternating_signs(n: int) -> Array:
    """Return alternating series of 1 and -1, of length n."""
    return jax.ops.index_update(jnp.ones(n), jax.ops.index[1::2], -1.0)

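Note that jax.ops.index_update has since been removed from JAX; on newer JAX versions the same result is obtained with the .at[] indexed-update API, as in this equivalent sketch:

import jax.numpy as jnp

def get_alternating_signs(n: int) -> jnp.ndarray:
    """Return alternating series of 1 and -1, of length n."""
    # Same result as the jax.ops.index_update call above, via the .at[] API.
    return jnp.ones(n).at[1::2].set(-1.0)
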
get_nsplits(split)

Get the number of splits from a particle split specification.

Source code in vmcnet/models/core.py
def get_nsplits(split: ParticleSplit) -> int:
    """Get the number of splits from a particle split specification."""
    if isinstance(split, int):
        return split

    return len(split) + 1

get_nelec_per_split(split, nelec_total)

From a particle split and nelec_total, get the number of particles per split.

If the number of particles per split is nelec_per_spin = (n1, n2, ..., nk), then split should be jnp.cumsum(nelec_per_spin)[:-1], or an integer (the number of equal-sized splits) if these are all equal. This function is the inverse of that operation.

Source code in vmcnet/models/core.py
def get_nelec_per_split(split: ParticleSplit, nelec_total: int) -> Tuple[int, ...]:
    """From a particle split and nelec_total, get the number of particles per split.

    If the number of particles per split is nelec_per_spin = (n1, n2, ..., nk), then
    split should be jnp.cumsum(nelec_per_spin)[:-1], or an integer (the number of
    equal-sized splits) if these are all equal. This function is the inverse of
    that operation.
    """
    if isinstance(split, int):
        return (nelec_total // split,) * split
    else:
        spin_diffs = jnp.diff(jnp.array(split))
        return (
            split[0],
            *tuple([int(i) for i in spin_diffs]),
            nelec_total - split[-1],
        )

get_spin_split(n_per_split)

Calculate spin split from n_per_split, making sure to output a Tuple of ints.

Source code in vmcnet/models/core.py
def get_spin_split(n_per_split: Union[Sequence[int], Array]) -> Tuple[int, ...]:
    """Calculate spin split from n_per_split, making sure to output a Tuple of ints."""
    cumsum = np.cumsum(n_per_split[:-1])
    # Convert to tuple of python ints.
    return tuple([int(i) for i in cumsum])
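
A short round-trip sketch tying the three split helpers together; the particle counts are illustrative.

from vmcnet.models.core import get_nsplits, get_nelec_per_split, get_spin_split

n_per_split = (2, 3, 4)              # e.g. 2 spin-up, 3 spin-down, 4 of a third species
split = get_spin_split(n_per_split)  # (2, 5): cumulative boundaries between groups

get_nelec_per_split(split, 9)        # (2, 3, 4), recovering n_per_split
get_nsplits(split)                   # 3

# An integer split means equal-sized groups:
get_nelec_per_split(2, 10)           # (5, 5)
get_nsplits(2)                       # 2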