distribute

Helper functions for distributing computation to multiple devices.

wrap_if_pmap(p_func)

Wrap a function so that it runs only in a pmapped context; outside of pmap, the wrapped function returns its input unchanged.

Source code in vmcnet/utils/distribute.py
def wrap_if_pmap(p_func: Callable) -> Callable:
    """Make a function run if in a pmapped context."""

    def p_func_if_pmap(obj, axis_name):
        try:
            core.axis_frame(axis_name)
            return p_func(obj, axis_name)
        except NameError:
            return obj

    return p_func_if_pmap
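The conditional collectives used below (pmean_if_pmap, and p_split for key splitting) are presumably built from wrap_if_pmap. A minimal sketch of one plausible construction, assuming a module-level axis-name constant (the name PMAP_AXIS_NAME here is hypothetical):

import functools
import jax

PMAP_AXIS_NAME = "pmap_axis"  # hypothetical; vmcnet defines its own constant

# Inside pmap, jax.lax.pmean averages across devices; outside pmap the
# axis name is unbound and the wrapped function returns its input as-is.
pmean_if_pmap = functools.partial(
    wrap_if_pmap(jax.lax.pmean), axis_name=PMAP_AXIS_NAME
)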

replicate_all_local_devices(obj)

Replicate a pytree on all local devices.

Source code in vmcnet/utils/distribute.py
def replicate_all_local_devices(obj: T) -> T:
    """Replicate a pytree on all local devices."""
    if obj is None:
        return None
    n = jax.local_device_count()
    obj_stacked = jax.tree_map(lambda x: jnp.stack([x] * n, axis=0), obj)
    return broadcast_all_local_devices(obj_stacked)
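For example, replicating a small (illustrative) parameter pytree adds a leading device axis to every leaf:

import jax
import jax.numpy as jnp

params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}
replicated = replicate_all_local_devices(params)
# Each leaf gains a leading axis of length jax.local_device_count(),
# e.g. replicated["w"].shape == (jax.local_device_count(), 3, 2).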

make_different_rng_key_on_all_devices(rng)

Split a PRNG key to all local devices.

Source code in vmcnet/utils/distribute.py
def make_different_rng_key_on_all_devices(rng: PRNGKey) -> PRNGKey:
    """Split a PRNG key to all local devices."""
    rng = jax.random.fold_in(rng, jax.process_index())
    rng = jax.random.split(rng, jax.local_device_count())
    return broadcast_all_local_devices(rng)
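A sketch of typical use; folding in the process index first keeps keys distinct across hosts as well as across local devices:

import jax

key = jax.random.PRNGKey(0)
sharded_key = make_different_rng_key_on_all_devices(key)
# One independent key per local device, placed on the devices:
# the result has a leading axis of length jax.local_device_count().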

get_first(obj)

Get the first object in each leaf of a pytree.

Can be used to grab the first instance of a replicated object on the first local device.

Source code in vmcnet/utils/distribute.py
def get_first(obj: T) -> T:
    """Get the first object in each leaf of a pytree.

    Can be used to grab the first instance of a replicated object on the first local
    device.
    """
    return jax.tree_map(lambda x: x[0], obj)
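Typical use is pulling one host-side copy of replicated state back out, e.g. before writing a checkpoint:

import jax.numpy as jnp

replicated = replicate_all_local_devices({"w": jnp.ones(4)})
single = get_first(replicated)
# single["w"].shape == (4,): the leading device axis is gone.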

mean_all_local_devices(x)

Compute mean over all local devices if distributed, otherwise the usual mean.

Source code in vmcnet/utils/distribute.py
def mean_all_local_devices(x: Array) -> jnp.float32:
    """Compute mean over all local devices if distributed, otherwise the usual mean."""
    return pmean_if_pmap(jnp.mean(x))

nanmean_all_local_devices(x)

Compute a nan-safe mean over all local devices.

Source code in vmcnet/utils/distribute.py
def nanmean_all_local_devices(x: Array) -> jnp.float32:
    """Compute a nan-safe mean over all local devices."""
    return pmean_if_pmap(jnp.nanmean(x))
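Both reducers follow the same pattern; outside of pmap the pmean_if_pmap wrapper is the identity, so they reduce to the plain jnp reductions:

import jax.numpy as jnp

mean_all_local_devices(jnp.array([1.0, 2.0, 3.0]))         # 2.0
nanmean_all_local_devices(jnp.array([1.0, jnp.nan, 3.0]))  # 2.0, NaN ignored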

get_mean_over_first_axis_fn(nan_safe=True)

Get a function which averages over the first axis over all local devices.

Parameters:

nan_safe (bool, default True): whether to use jnp.nanmean or jnp.mean in the local average computation.

Returns:

Callable: function which averages an array over its first axis over all local devices.

Source code in vmcnet/utils/distribute.py
def get_mean_over_first_axis_fn(
    nan_safe: bool = True,
) -> Callable[[Array], Array]:
    """Get a function which averages over the first axis over all local devices.

    Args:
        nan_safe (bool, optional): whether to use jnp.nanmean or jnp.mean in the local
            average computation. Defaults to True.

    Returns:
        Callable: function which averages an array over its first axis over all local
        devices.
    """
    if nan_safe:
        local_mean_fn = functools.partial(jnp.nanmean, axis=0)
    else:
        local_mean_fn = functools.partial(jnp.mean, axis=0)

    def mean_fn(x: Array) -> Array:
        return pmean_if_pmap(local_mean_fn(x))

    return mean_fn
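For example, averaging per-chain local energies over the chain axis (and over devices when called under pmap):

import jax.numpy as jnp

mean_fn = get_mean_over_first_axis_fn(nan_safe=True)
local_energies = jnp.array([1.0, jnp.nan, 3.0])
mean_fn(local_energies)  # 2.0: the NaN entry is excluded from the average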

split_or_psplit_key(key, multi_device=True)

Split PRNG key, potentially on multiple devices.

Source code in vmcnet/utils/distribute.py
def split_or_psplit_key(key: PRNGKey, multi_device: bool = True) -> PRNGKey:
    """Split PRNG key, potentially on multiple devices."""
    return p_split(key) if multi_device else jax.random.split(key)
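p_split is presumably the pmapped analogue of jax.random.split defined elsewhere in this module. In the single-device case the behavior is the ordinary split:

import jax

key = jax.random.PRNGKey(0)
new_key, subkey = split_or_psplit_key(key, multi_device=False)
# With multi_device=True, key must already be sharded across devices
# (e.g. the output of make_different_rng_key_on_all_devices).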

reshape_data_leaves_for_distribution(data_leaf)

For a leaf of a pytree, reshape it for distributing to all local devices.

Source code in vmcnet/utils/distribute.py
def reshape_data_leaves_for_distribution(data_leaf: Array) -> Array:
    """For a leaf of a pytree, reshape it for distributing to all local devices."""
    num_devices = jax.local_device_count()
    nchains = data_leaf.shape[0]
    if nchains % num_devices != 0:
        raise ValueError(
            "Number of chains must be divisible by number of devices, "
            "got nchains {} for {} devices.".format(nchains, num_devices)
        )
    distributed_data_shape = (num_devices, nchains // num_devices)
    data = jnp.reshape(data_leaf, distributed_data_shape + data_leaf.shape[1:])
    return data
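For instance, with 8 chains on 4 local devices a leaf of shape (8, 3) becomes (4, 2, 3):

import jax.numpy as jnp

# Supposing jax.local_device_count() == 4:
leaf = jnp.arange(24.0).reshape(8, 3)
reshaped = reshape_data_leaves_for_distribution(leaf)
# reshaped.shape == (4, 2, 3): (num_devices, nchains // num_devices, ...)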

default_distribute_data(data)

Split all data to all devices. The first axis must be divisible by ndevices.

Source code in vmcnet/utils/distribute.py
def default_distribute_data(data: D) -> D:
    """Split all data to all devices. The first axis must be divisible by ndevices."""
    data = jax.tree_map(reshape_data_leaves_for_distribution, data)
    data = broadcast_all_local_devices(data)
    return data
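A sketch with a small data pytree (the leaf names here are illustrative); every leaf's first axis must share the same chain count:

import jax.numpy as jnp

data = {"positions": jnp.zeros((8, 5, 3)), "weights": jnp.ones(8)}
data = default_distribute_data(data)
# On 4 devices: positions -> (4, 2, 5, 3) and weights -> (4, 2),
# with each device holding its own slice of the chains.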

distribute_vmc_state(data, params, optimizer_state, key, distribute_data_fn=default_distribute_data)

Split data, replicate params and opt state, and split PRNG key to all devices.

Parameters:

data (D): the MCMC data to distribute. Required.

params (P): model parameters. Required.

optimizer_state (S): optimizer state. Required.

key (PRNGKey): RNG key. Required.

distribute_data_fn (Callable[[D], D], default default_distribute_data): custom function for distributing the MCMC data, for the case where some of the data needs to be replicated instead of distributed across the devices. The default works if no data requires replication.

Returns:

(D, P, S, PRNGKey): tuple of data, params, optimizer_state, and key, each of which has been either distributed or replicated across all devices, as appropriate.

Source code in vmcnet/utils/distribute.py
def distribute_vmc_state(
    data: D,
    params: P,
    optimizer_state: S,
    key: PRNGKey,
    distribute_data_fn: Callable[[D], D] = default_distribute_data,
) -> Tuple[D, P, S, PRNGKey]:
    """Split data, replicate params and opt state, and split PRNG key to all devices.

    Args:
        data: the MCMC data to distribute
        params: model parameters
        optimizer_state: optimizer state
        key: RNG key
        distribute_data_fn: custom function for distributing the MCMC data, for the case
            where some of the data needs to be replicated instead of distributed across
            the devices. Default works if there is no data that requires replication.

    Returns:
        (D, P, S, PRNGKey): tuple of data, params, optimizer_state, and key,
        each of which has been either distributed or replicated across all devices,
        as appropriate.
    """
    data = distribute_data_fn(data)
    params = replicate_all_local_devices(params)
    optimizer_state = replicate_all_local_devices(optimizer_state)
    sharded_key = make_different_rng_key_on_all_devices(key)

    return data, params, optimizer_state, sharded_key
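Putting the pieces together at the start of a run (a minimal sketch; the pytrees here are illustrative stand-ins for real MCMC data, model, and optimizer state):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
data = {"positions": jnp.zeros((8, 5, 3))}  # nchains divisible by device count
params = {"w": jnp.ones((3, 2))}
optimizer_state = {"step": jnp.array(0)}

data, params, optimizer_state, sharded_key = distribute_vmc_state(
    data, params, optimizer_state, key
)
# data is split across devices, params and optimizer_state are replicated,
# and sharded_key holds a distinct PRNG key per device.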

distribute_vmc_state_from_checkpoint(data, params, optimizer_state, key)

Distribute vmc state that was reloaded from a saved checkpoint.

Data and key are saved independently for each device, so on reload we simply broadcast them back to the devices. Params and optimizer state are saved as a single copy, so on reload we replicate them to all devices.

Source code in vmcnet/utils/distribute.py
def distribute_vmc_state_from_checkpoint(
    data: D,
    params: P,
    optimizer_state: S,
    key: PRNGKey,
) -> Tuple[D, P, S, PRNGKey]:
    """Distribute vmc state that was reloaded from a saved checkpoint.

    Data and key are saved independently for each device, so on reload
    we simply broadcast them back to the devices. Params and optimizer state are saved
    as a single copy, so on reload we replicate them to all devices.
    """
    data = broadcast_all_local_devices(data)
    params = replicate_all_local_devices(params)
    optimizer_state = replicate_all_local_devices(optimizer_state)
    key = broadcast_all_local_devices(key)

    return data, params, optimizer_state, key

is_distributed(data)

Test whether the given data has been distributed using pmap.

Source code in vmcnet/utils/distribute.py
def is_distributed(data: PyTree) -> bool:
    """Tests whether given data has been distributed using pmap."""
    return isinstance(jax.tree_leaves(data)[0], pxla.ShardedDeviceArray)

get_first_if_distributed(data)

Get a single copy of the input data, which may or may not be replicated.

Source code in vmcnet/utils/distribute.py
def get_first_if_distributed(data: PyTree) -> PyTree:
    """Gets single copy of input data, which may or may not be replicated."""
    if is_distributed(data):
        return get_first(data)
    return data
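This makes code agnostic to whether state has been distributed yet, e.g. when saving a checkpoint:

import jax.numpy as jnp

params = replicate_all_local_devices({"w": jnp.ones(4)})
params_to_save = get_first_if_distributed(params)
# Here the leading device axis is dropped; params that were never
# distributed would pass through unchanged.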