Skip to content

slog_helpers

Helper functions for dealing with (sign, logabs) data.

array_to_slog(x)

Converts a regular array into (sign, logabs) form.

Parameters:

Name Type Description Default
x Array

input data.

required

Returns:

Type Description
(SLArray)

data in form (sign(x), log(abs(x)))

Source code in vmcnet/utils/slog_helpers.py
def array_to_slog(x: Array) -> SLArray:
    """Converts a regular array into (sign, logabs) form.

    Args:
        x (Array): input data.

    Returns:
        (SLArray): data in form (sign(x), log(abs(x)))
    """
    return (jnp.sign(x), jnp.log(jnp.abs(x)))

array_from_slog(x)

Converts an slog data tuple into a regular array.

Parameters:

Name Type Description Default
x SLArray

input data in slog form. This data looks like (sign(z), log(abs(z))) for some z which represents the underlying data.

required

Returns:

Type Description
(Array)

the data as a single, regular array. In other words, the z such that x = (sign(z), log(abs(z)))

Source code in vmcnet/utils/slog_helpers.py
def array_from_slog(x: SLArray) -> Array:
    """Converts an slog data tuple into a regular array.

    Args:
        x (SLArray): input data in slog form. This data looks like
            (sign(z), log(abs(z))) for some z which represents the underlying data.

    Returns:
        (Array): the data as a single, regular array. In other words, the z
        such that x = (sign(z), log(abs(z)))
    """
    return x[0] * jnp.exp(x[1])

array_list_to_slog(x)

Map an ArrayList to SLArrayList form.

Parameters:

Name Type Description Default
x ArrayList

input data as a regular spin-split array.

required

Returns:

Type Description
(SLArrayList)

same data with each array transformed to slog form.

Source code in vmcnet/utils/slog_helpers.py
def array_list_to_slog(x: ArrayList) -> SLArrayList:
    """Map an ArrayList to SLArrayList form.

    Args:
        x (ArrayList): input data as a regular spin-split array.

    Returns:
        (SLArrayList): same data with each array transformed to slog form.
    """
    return [array_to_slog(arr) for arr in x]

array_list_from_slog(x)

Map a SLArrayList to ArrayList form.

Parameters:

Name Type Description Default
x SLArrayList

input data as a list of slog arrays.

required

Returns:

Type Description
(ArrayList)

same data with slog tuples transformed to single arrays.

Source code in vmcnet/utils/slog_helpers.py
def array_list_from_slog(x: SLArrayList) -> ArrayList:
    """Map a SLArrayList to ArrayList form.

    Args:
        x (SLArrayList): input data as a list of slog arrays.

    Returns:
        (ArrayList): same data with slog tuples transformed to single arrays.
    """
    return [array_from_slog(slog) for slog in x]

slog_multiply(x, y)

Computes the product of two slog array tuples, as another slog array tuple.

Signs are multiplied and logs are added.

Source code in vmcnet/utils/slog_helpers.py
def slog_multiply(x: SLArray, y: SLArray) -> SLArray:
    """Computes the product of two slog array tuples, as another slog array tuple.

    Signs are multiplied and logs are added.
    """
    (sx, lx) = x
    (sy, ly) = y
    return (sx * sy, lx + ly)

slog_sum_over_axis(x, axis=0)

Take the sum of a single slog array over a specified axis.

Source code in vmcnet/utils/slog_helpers.py
def slog_sum_over_axis(x: SLArray, axis: int = 0) -> SLArray:
    """Take the sum of a single slog array over a specified axis."""
    signs, logs = log_linear_exp(x[0], x[1], axis=axis)
    return (jnp.squeeze(signs, axis=axis), jnp.squeeze(logs, axis=axis))

slog_array_list_sum(x)

Take the sum of a list of SLArrays which are all of the same shape.

Source code in vmcnet/utils/slog_helpers.py
def slog_array_list_sum(x: SLArrayList) -> SLArray:
    """Take the sum of a list of SLArrays which are all of the same shape."""
    stacked_vals = (jnp.stack([a[0] for a in x]), jnp.stack([a[1] for a in x]))
    return slog_sum_over_axis(stacked_vals)

slog_sum(x, y)

Take the sum of two SLArrays which are of the same shape.

Source code in vmcnet/utils/slog_helpers.py
def slog_sum(x: SLArray, y: SLArray) -> SLArray:
    """Take the sum of two SLArrays which are of the same shape."""
    return slog_array_list_sum([x, y])