Skip to content

jastrow

Jastrow factors.

BackflowJastrow (Module) dataclass

Backflow-based general permutation invariant Jastrow.

Attributes:

Name Type Description
backflow Callable or None

function which computes position features from the electron positions. Has the signature ( stream_1e of shape (..., n, d'), optional stream_2e of shape (..., nelec, nelec, d2), ) -> stream_1e of shape (..., n, d'). Can pass None here to use a stream_1e from an already computed backflow.

logabs bool

whether to return the log jastrow (True) or the jastrow (False). Defaults to True.

setup(self)

Set up the dense layers for each split.

Source code in vmcnet/models/jastrow.py
def setup(self):
    """Set up the dense layers for each split."""
    # workaround MyPy's typing error for callable attribute, see
    # https://github.com/python/mypy/issues/708
    self._backflow = self.backflow

__call__(self, input_stream_1e, input_stream_2e, stream_1e, r_ei, r_ee) special

Compute backflow-based general permutation invariant Jastrow.

Parameters:

Name Type Description Default
input_stream_1e Array

input one-electron stream

required
input_stream_2e Array

input two-electron stream

required
stream_1e Array

one-electron stream, post-backflow

required
r_ei Array

electron-ion displacements with shape (..., nelec, nion, d); unused

required
r_ee Array

electron-electron displacements with shape (..., nelec, nelec, d); unused

required

Returns:

Type Description
Array

-mean_i ||Backflow_i||, or exp(-mean_i ||Backflow_i||) if logabs is False

Source code in vmcnet/models/jastrow.py
@flax.linen.compact
def __call__(  # type: ignore[override]
    self,
    input_stream_1e: Array,
    input_stream_2e: Array,
    stream_1e: Array,
    r_ei: Array,
    r_ee: Array,
) -> Array:
    """Compute backflow-based general permutation invariant Jastrow.

    Args:
        input_stream_1e (Array): input one-electron stream
        input_stream_2e (Array): input two-electron stream
        stream_1e (Array): one-electron stream, post-backflow
        r_ei (Array): electron-ion displacements with shape
            (..., nelec, nion, d); unused
        r_ee (Array): electron-electron displacements with shape
            (..., nelec, nelec, d); unused

    Returns:
        Array: -mean_i ||Backflow_i||, or exp(-mean_i ||Backflow_i||) if
        logabs is False
    """
    del r_ei, r_ee

    if self._backflow is not None:
        stream_1e = self._backflow(input_stream_1e, input_stream_2e)

    log_jastrow = -jnp.mean(jnp.linalg.norm(stream_1e, axis=-1), axis=-1)

    if self.logabs:
        return log_jastrow

    return jnp.exp(log_jastrow)

OneBodyExpDecay (Module) dataclass

Creates an isotropic exponential decay one-body Jastrow model.

The decay is centered at the coordinates of the nuclei, and the electron-nuclei displacements are multiplied by trainable params before a sum and exp(-x). The decay is isotropic and equal for all electrons, so it computes

-sum_ij ||a_j * (elec_i - ion_j)||

or the exponential if logabs is False. The tensor a_j * (elec_i - ion_j) is computed with a split dense operation.

Attributes:

Name Type Description
kernel_initializer WeightInitializer

kernel initializer for the decay rates a_j. This initializes a single decay rate number per ion. Has signature (key, shape, dtype) -> Array

logabs bool

whether to compute -sum_ij ||a_j * (elec_i - ion_j)||, when logabs is True, or exp of that expression when logabs is False. Defaults to True.

setup(self)

Setup the kernel initializer.

Source code in vmcnet/models/jastrow.py
def setup(self):
    """Setup the kernel initializer."""
    # workaround MyPy's typing error for callable attribute, see
    # https://github.com/python/mypy/issues/708
    self._kernel_initializer = self.kernel_initializer

__call__(self, input_stream_1e, input_stream_2e, stream_1e, r_ei, r_ee) special

Transform electron-ion displacements into an exp decay one-body Jastrow.

Parameters:

Name Type Description Default
input_stream_1e Array

input one-electron stream; unused

required
input_stream_2e Array

input two-electron stream; unused

required
stream_1e Array

one-electron stream, post-backflow; unused

required
r_ei Array

electron-ion displacements of shape (..., nelec, nion, d)

required
r_ee Array

electron-electron displacements of shape (..., nelec, nelec, d); unused

required

Returns:

Type Description
Array

-sum_ij ||a_j * (elec_i - ion_j)||, when self.logabs is True, or exp of that expression when self.logabs is False. If the input has shape (batch_dims, nelec, nion, d), then the output has shape (batch_dims,)

Source code in vmcnet/models/jastrow.py
@flax.linen.compact
def __call__(  # type: ignore[override]
    self,
    input_stream_1e: Array,
    input_stream_2e: Array,
    stream_1e: Array,
    r_ei: Array,
    r_ee: Array,
) -> Array:
    """Transform electron-ion displacements into an exp decay one-body Jastrow.

    Args:
        input_stream_1e (Array): input one-electron stream; unused
        input_stream_2e (Array): input two-electron stream; unused
        stream_1e (Array): one-electron stream, post-backflow; unused
        r_ei (Array): electron-ion displacements of shape
            (..., nelec, nion, d)
        r_ee (Array): electron-electron displacements of shape
            (..., nelec, nelec, d); unused

    Returns:
        Array: -sum_ij ||a_j * (elec_i - ion_j)||, when self.logabs is True,
        or exp of that expression when self.logabs is False. If the input has shape
        (batch_dims, nelec, nion, d), then the output has shape (batch_dims,)
    """
    del input_stream_1e, input_stream_2e, stream_1e, r_ee
    # scale_out has shape (..., nelec, 1, nion, d)
    scale_out = _isotropy_on_leaf(
        r_ei, 1, self._kernel_initializer, register_kfac=True
    )
    scaled_distances = jnp.linalg.norm(scale_out, axis=-1)

    abs_lin_comb_distances = jnp.sum(scaled_distances, axis=(-1, -2, -3))

    if self.logabs:
        return -abs_lin_comb_distances

    return jnp.exp(-abs_lin_comb_distances)

TwoBodyExpDecay (Module) dataclass

Isotropic exponential decay two-body Jastrow model.

The decay is isotropic in the sense that each electron-nuclei and electron-electron term is isotropic, i.e. radially symmetric. The computed interactions are:

sum_i(-sum_j Z_j ||elec_i - ion_j|| + sum_k Q ||elec_i - elec_k||)

or the exponential if logabs is False. Z_j and Q are initialized to init_ei_strength and init_ee_strength, respectively, and are trainable if trainable is True.

Attributes:

Name Type Description
init_ei_strength Array or Sequence[float]

1-d array or sequence of length nion which gives the initial strength of the electron-nucleus interaction per ion

init_ee_strength float

initial strength of the electron-electron interaction. Defaults to 1.0.

log_scale_factor float

Amount to add to the log jastrow (amounts to a multiplicative factor after exponentiation). Defaults to 0.0.

register_kfac bool

whether to register the computation with KFAC. Defaults to True.

logabs bool

whether to return the log jastrow (True) or the jastrow (False). Defaults to True.

trainable bool

whether to allow the jastrow to be trainable. Defaults to True.

__call__(self, input_stream_1e, input_stream_2e, stream_1e, r_ei, r_ee) special

Compute jastrow with both electron-ion and electron-electron effects.

Parameters:

Name Type Description Default
input_stream_1e Array

input one-electron stream; unused

required
input_stream_2e Array

input two-electron stream; unused

required
stream_1e Array

one-electron stream, post-backflow; unused

required
r_ei Array

electron-ion displacements with shape (..., nelec, nion, d)

required
r_ee Array

electron-electron displacements with shape (..., nelec, nelec, d)

required

Returns:

Type Description
Array

sum_i(-sum_j Z_j ||elec_i - ion_j|| + sum_k Q ||elec_i - elec_k||),

where Z_j and Q are trainable if trainable is true, and an exponential is taken if logabs is False

Source code in vmcnet/models/jastrow.py
@flax.linen.compact
def __call__(  # type: ignore[override]
    self,
    input_stream_1e: Array,
    input_stream_2e: Array,
    stream_1e: Array,
    r_ei: Array,
    r_ee: Array,
) -> Array:
    """Compute jastrow with both electron-ion and electron-electron effects.

    Args:
        input_stream_1e (Array): input one-electron stream; unused
        input_stream_2e (Array): input two-electron stream; unused
        stream_1e (Array): one-electron stream, post-backflow; unused
        r_ei (Array): electron-ion displacements with shape
            (..., nelec, nion, d)
        r_ee (Array): electron-electron displacements with shape
            (..., nelec, nelec, d)

    Returns:
        Array:

            sum_i(-sum_j Z_j ||elec_i - ion_j|| + sum_k Q ||elec_i - elec_k||),

        where Z_j and Q are trainable if trainable is true, and an exponential is
        taken if logabs is False
    """
    del input_stream_1e, input_stream_2e, stream_1e
    ei_distances = jnp.linalg.norm(r_ei, axis=-1)
    ee_distances = jnp.squeeze(compute_ee_norm_with_safe_diag(r_ee), axis=-1)
    sum_ee_effect = jnp.sum(jnp.triu(ee_distances), axis=-1, keepdims=True)

    if self.trainable:
        split_over_ions = jnp.split(ei_distances, ei_distances.shape[-1], axis=-1)
        # TODO: potentially add support for this to SplitDense or otherwise?
        split_scaled_ei_distances = [
            Dense(
                1,
                kernel_init=get_constant_init(self.init_ei_strength[i]),
                use_bias=False,
                register_kfac=self.register_kfac,
            )(single_ion_displacement)
            for i, single_ion_displacement in enumerate(split_over_ions)
        ]
        scaled_ei_distances = jnp.concatenate(split_scaled_ei_distances, axis=-1)

        sum_ee_effect = Dense(
            1,
            kernel_init=get_constant_init(self.init_ee_strength),
            use_bias=False,
            register_kfac=self.register_kfac,
        )(sum_ee_effect)
    else:
        scaled_ei_distances = self.init_ei_strength * ei_distances
        sum_ee_effect = self.init_ee_strength * sum_ee_effect

    sum_ee_effect = jnp.squeeze(sum_ee_effect, axis=-1)
    sum_ei_effect = jnp.sum(scaled_ei_distances, axis=-1)
    unscaled_interaction = jnp.sum(sum_ee_effect - sum_ei_effect, axis=-1)
    interaction = unscaled_interaction + self.log_scale_factor

    if self.logabs:
        return interaction

    return jnp.exp(interaction)

get_two_body_decay_scaled_for_chargeless_molecules(ion_pos, ion_charges, init_ee_strength=1.0, register_kfac=True, logabs=True, trainable=True)

Make molecular decay jastrow, scaled for chargeless molecules.

The scale factor is chosen so that the log jastrow is initialized to 0 when electrons are at ion positions.

Parameters:

Name Type Description Default
ion_pos Array

an (nion, d) array of ion positions.

required
ion_charges Array

an (nion,) array of ion charges, in units of one elementary charge (the charge of one electron)

required
init_ee_strength float

the initial strength of the electron-electron interaction. Defaults to 1.0.

1.0
register_kfac bool

whether to register the computation with KFAC. Defaults to True.

True
logabs bool

whether to return the log jastrow (True) or the jastrow (False). Defaults to True.

True
trainable bool

whether to allow the jastrow to be trainable. Defaults to True.

True

Returns:

Type Description
Callable

a flax Module with signature (r_ei, r_ee) -> jastrow or log jastrow

Source code in vmcnet/models/jastrow.py
def get_two_body_decay_scaled_for_chargeless_molecules(
    ion_pos: Array,
    ion_charges: Array,
    init_ee_strength: float = 1.0,
    register_kfac: bool = True,
    logabs: bool = True,
    trainable: bool = True,
) -> Jastrow:
    """Make molecular decay jastrow, scaled for chargeless molecules.

    The scale factor is chosen so that the log jastrow is initialized to 0 when
    electrons are at ion positions.

    Args:
        ion_pos (Array): an (nion, d) array of ion positions.
        ion_charges (Array): an (nion,) array of ion charges, in units of one
            elementary charge (the charge of one electron)
        init_ee_strength (float, optional): the initial strength of the
            electron-electron interaction. Defaults to 1.0.
        register_kfac (bool, optional): whether to register the computation with KFAC.
            Defaults to True.
        logabs (bool, optional): whether to return the log jastrow (True) or the jastrow
            (False). Defaults to True.
        trainable (bool, optional): whether to allow the jastrow to be trainable.
            Defaults to True.

    Returns:
        Callable: a flax Module with signature (r_ei, r_ee) -> jastrow or log jastrow
    """
    r_ii, charge_charge_prods = physics.potential._get_ion_ion_info(
        ion_pos, ion_charges
    )
    jastrow_scale_factor = 0.5 * jnp.sum(
        jnp.linalg.norm(r_ii, axis=-1) * charge_charge_prods
    )
    jastrow = TwoBodyExpDecay(
        ion_charges,
        init_ee_strength,
        log_scale_factor=jastrow_scale_factor,
        register_kfac=register_kfac,
        logabs=logabs,
        trainable=trainable,
    )
    return jastrow