Helper function for the log-sum-exp trick with weights.

log_linear_exp(signs, vals, weights=None, axis=0, register_kfac=True)

Stably compute sign and log(abs(.)) of sum_i(sign_i * w_ij * exp(vals_i)).

In order to avoid overflow when computing

log(abs(sum_i(sign_i * w_ij * exp(vals_i)))),

the largest term, exp(max) with max = max_i(vals_i), is divided out of every term, and its log, max, is added back after the outer log, i.e.

log(abs(sum_i(sign_i * w_ij * exp(vals_i - max)))) + max.

This trick also avoids underflow when every val_i is so negative that exp(val_i) rounds to 0, since the largest shifted term becomes exp(0) = 1.
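A minimal jax.numpy sketch of the trick follows. It covers the unweighted case (all w_ij = 1) and uses made-up values. In JAX's default float32, exp(1000) overflows to inf and exp(-1000) underflows to 0, so the naive formula returns inf or -inf. The max-shifted form stays finite.

```python
import jax.numpy as jnp

signs = jnp.array([1.0, 1.0])

def naive(signs, vals):
    # log(abs(sum_i(sign_i * exp(vals_i)))) computed directly
    return jnp.log(jnp.abs(jnp.sum(signs * jnp.exp(vals))))

def shifted(signs, vals):
    # divide out exp(max) inside the sum, add max back after the log
    max_val = jnp.max(vals)
    return jnp.log(jnp.abs(jnp.sum(signs * jnp.exp(vals - max_val)))) + max_val

big = jnp.array([1000.0, 999.0])      # exp overflows to inf in float32
small = jnp.array([-1000.0, -999.0])  # exp underflows to 0 in float32

naive_big, stable_big = naive(signs, big), shifted(signs, big)
naive_small, stable_small = naive(signs, small), shifted(signs, small)
```

Both stable results equal max + log(1 + exp(-1)), which is about max + 0.3133.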


Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| signs | Array | array of signs of the input x, with shape (..., d, ...), where d is the size of the given axis | required |
| vals | Array | array of log(abs(x)), with shape (..., d, ...), where d is the size of the given axis | required |
| weights | Array | weights of a linear transformation to apply along the given axis, with shape (d, d'). If not provided, a simple sum is taken instead, equivalent to (d, 1) weights equal to 1. | None |
| axis | int | axis along which to take the sum and max. | 0 |
| register_kfac | bool | if weights is not None, whether to register the linear part of the computation with KFAC. | True |
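The claim that omitting weights is equivalent to (d, 1) weights of ones can be checked directly on the shifted terms. The max shift is identical in both branches, so comparing the linear part is enough. This is a small jax.numpy sketch with made-up terms. It mirrors the source's swap-contract-swap pattern and assumes `jnp.dot` as the contraction.

```python
import jax.numpy as jnp

# Hypothetical shifted terms sign_i * exp(vals_i - max): d = 4 along axis 0, 3 columns.
terms = jnp.array([[1.0, -2.0, 0.5],
                   [3.0, 1.0, -1.5],
                   [-0.5, 4.0, 2.0],
                   [2.0, -1.0, 1.0]])
axis = 0

# weights=None branch: plain sum over the axis, kept as a size-1 dimension.
summed = jnp.sum(terms, axis=axis, keepdims=True)  # shape (1, 3)

# weights branch with (d, 1) weights equal to 1:
# move the axis to the end, contract, move it back.
ones = jnp.ones((terms.shape[axis], 1))
moved = jnp.swapaxes(terms, axis, -1)                  # shape (3, 4)
dotted = jnp.swapaxes(jnp.dot(moved, ones), axis, -1)  # shape (1, 3)
```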



Returns:

| Type | Description |
|---|---|
| SLArray | sign and log(abs(.)) of the linear combination. Both outputs have shape (..., d', ...), where d' = 1 if weights is None, and d' = weights.shape[1] otherwise. |
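The sketch below shows where d' ends up in the output. It uses a hypothetical (3, 4, 5) input with axis=1 and d' = 2. Swapping axis with -1, contracting with jnp.dot (assumed, as above), and swapping back puts d' in the position the summed axis occupied. An explicit einsum gives the same result.

```python
import jax.numpy as jnp

# Hypothetical sizes: leading dim 3, d = 4 on axis 1, trailing dim 5, d' = 2.
x = jnp.arange(3 * 4 * 5, dtype=jnp.float32).reshape(3, 4, 5)
weights = jnp.arange(8, dtype=jnp.float32).reshape(4, 2)
axis = 1

moved = jnp.swapaxes(x, axis, -1)         # (3, 5, 4): summed axis moved last
contracted = jnp.dot(moved, weights)      # (3, 5, 2): contract d against weights
out = jnp.swapaxes(contracted, axis, -1)  # (3, 2, 5): d' back in position 1

# The same contraction written out with einsum.
expected = jnp.einsum("adb,de->aeb", x, weights)

# Without weights, d' = 1.
summed = jnp.sum(x, axis=axis, keepdims=True)  # (3, 1, 5)
```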

Source code in vmcnet/utils/
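from typing import Optional

import jax.numpy as jnp

# Array, SLArray and register_batch_dense come from the vmcnet package; their
# import paths are not shown in this excerpt.


def log_linear_exp(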
    signs: Array,
    vals: Array,
    weights: Optional[Array] = None,
    axis: int = 0,
    register_kfac: bool = True,
) -> SLArray:
    """Stably compute sign and log(abs(.)) of sum_i(sign_i * w_ij * exp(vals_i)).

    In order to avoid overflow when computing

        log(abs(sum_i(sign_i * w_ij * exp(vals_i)))),

    the largest term exp(max), where max = max_i(vals_i), is divided out of every
    term and its log, max, is added back after the outer log, i.e.

        log(abs(sum_i(sign_i * w_ij * exp(vals_i - max)))) + max.

    This trick also avoids underflow when every val_i is so negative that
    exp(val_i) rounds to 0, since the largest shifted term becomes exp(0) = 1.

    Args:
        signs (Array): array of signs of the input x with shape (..., d, ...),
            where d is the size of the given axis
        vals (Array): array of log(abs(x)) with shape (..., d, ...), where d is
            the size of the given axis
        weights (Array, optional): weights of a linear transformation to apply along
            the given axis, with shape (d, d'). If not provided, a simple sum is taken
            instead, equivalent to (d, 1) weights equal to 1. Defaults to None.
        axis (int, optional): axis along which to take the sum and max. Defaults to 0.
        register_kfac (bool, optional): if weights is not None, whether to register
            the linear part of the computation with KFAC. Defaults to True.

    Returns:
        (SLArray): sign and log(abs(.)) of the linear combination. Both outputs have
        shape (..., d', ...), where d' = 1 if weights is None, and
        d' = weights.shape[1] otherwise.
    """
    max_val = jnp.max(vals, axis=axis, keepdims=True)
    terms_divided_by_max = signs * jnp.exp(vals - max_val)
    if weights is not None:
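        # swap axis and -1 to conform to the jnp.dot and register_batch_dense api
        terms_divided_by_max = jnp.swapaxes(terms_divided_by_max, axis, -1)
        # contract the last axis (size d) against weights of shape (d, d')
        transformed_divided_by_max = jnp.dot(terms_divided_by_max, weights)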
        if register_kfac:
            transformed_divided_by_max = register_batch_dense(
                transformed_divided_by_max, terms_divided_by_max, weights, None
            )

        # swap axis and -1 back after the contraction and registration
        transformed_divided_by_max = jnp.swapaxes(transformed_divided_by_max, axis, -1)
    else:
        # no weights: plain sum over the axis, kept as a size-1 dimension
        transformed_divided_by_max = jnp.sum(
            terms_divided_by_max, axis=axis, keepdims=True
        )
    signs = jnp.sign(transformed_divided_by_max)
    vals = jnp.log(jnp.abs(transformed_divided_by_max)) + max_val
    return signs, vals
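To try the computation without the vmcnet dependencies, the sketch below reproduces the listing's logic in a self-contained form. It drops the KFAC registration, and `log_linear_exp_sketch` is a hypothetical name. It matches the original only under the assumption that `register_batch_dense` leaves forward values unchanged, and it assumes `jnp.dot` for the contraction. The example encodes x = [2, -3, 0.5] as (sign, log|x|) pairs and checks the result against the direct product x @ weights = [-0.5, -4.0].

```python
from typing import Optional, Tuple

import jax.numpy as jnp


def log_linear_exp_sketch(
    signs: jnp.ndarray,
    vals: jnp.ndarray,
    weights: Optional[jnp.ndarray] = None,
    axis: int = 0,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Hypothetical re-implementation of log_linear_exp without KFAC registration."""
    max_val = jnp.max(vals, axis=axis, keepdims=True)
    terms = signs * jnp.exp(vals - max_val)  # largest term has magnitude 1
    if weights is not None:
        # move the axis last, contract with (d, d') weights, move it back
        terms = jnp.swapaxes(terms, axis, -1)
        combined = jnp.swapaxes(jnp.dot(terms, weights), axis, -1)
    else:
        combined = jnp.sum(terms, axis=axis, keepdims=True)
    return jnp.sign(combined), jnp.log(jnp.abs(combined)) + max_val


# Weighted case: x @ weights = [2 - 3 + 0.5, -3 - 1] = [-0.5, -4.0].
x = jnp.array([2.0, -3.0, 0.5])
weights = jnp.array([[1.0, 0.0], [1.0, 1.0], [1.0, -2.0]])
out_signs, out_logs = log_linear_exp_sketch(jnp.sign(x), jnp.log(jnp.abs(x)), weights)

# Unweighted case with magnitudes exp(1000), which overflow if computed naively.
big_signs, big_logs = log_linear_exp_sketch(
    jnp.array([1.0, 1.0]), jnp.array([1000.0, 1000.0])
)
```

Because the max is taken along `axis` before the contraction, the magnitudes entering `jnp.dot` are at most |w_ij|. The weighted sum therefore stays finite even when the log-magnitudes in `vals` are in the thousands.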