distribute
Helper functions for distributing computation to multiple devices.
wrap_if_pmap(p_func)
Wrap a function so that it runs only in a pmapped context; otherwise its input is returned unchanged.
Source code in vmcnet/utils/distribute.py
```python
def wrap_if_pmap(p_func: Callable) -> Callable:
    """Make a function run if in a pmapped context."""

    def p_func_if_pmap(obj, axis_name):
        try:
            core.axis_frame(axis_name)
            return p_func(obj, axis_name)
        except NameError:
            return obj

    return p_func_if_pmap
```
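The collectives used later on this page (`pmean_if_pmap`, and similarly `p_split`) are presumably built from this wrapper. A minimal sketch, assuming an axis name constant `PMAP_AXIS_NAME` that matches the one passed to `jax.pmap` (the constant and its value are assumptions):

```python
import functools

import jax

from vmcnet.utils.distribute import wrap_if_pmap

PMAP_AXIS_NAME = "pmap_axis"  # assumption: must match the axis_name given to jax.pmap

# Mean across devices when inside pmap; the identity otherwise.
pmean_if_pmap = functools.partial(wrap_if_pmap(jax.lax.pmean), axis_name=PMAP_AXIS_NAME)
```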
replicate_all_local_devices(obj)
Replicate a pytree on all local devices.
Source code in vmcnet/utils/distribute.py
```python
def replicate_all_local_devices(obj: T) -> T:
    """Replicate a pytree on all local devices."""
    if obj is None:
        return None

    n = jax.local_device_count()
    obj_stacked = jax.tree_map(lambda x: jnp.stack([x] * n, axis=0), obj)
    return broadcast_all_local_devices(obj_stacked)
```
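`broadcast_all_local_devices` is not shown on this page; it is presumably an identity pmap (i.e. `jax.pmap(lambda x: x)`) that places an already-stacked pytree onto the devices. A usage sketch under that assumption:

```python
import jax
import jax.numpy as jnp

from vmcnet.utils.distribute import replicate_all_local_devices

params = {"w": jnp.ones((3, 3))}
replicated = replicate_all_local_devices(params)
# Each leaf gains a leading device axis:
# replicated["w"].shape == (jax.local_device_count(), 3, 3)
```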
make_different_rng_key_on_all_devices(rng)
Split a PRNG key to all local devices.
Source code in vmcnet/utils/distribute.py
```python
def make_different_rng_key_on_all_devices(rng: PRNGKey) -> PRNGKey:
    """Split a PRNG key to all local devices."""
    rng = jax.random.fold_in(rng, jax.process_index())
    rng = jax.random.split(rng, jax.local_device_count())
    return broadcast_all_local_devices(rng)
```
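A usage sketch: starting from one host key, each local device ends up with a distinct key, and the `fold_in` on the process index keeps keys distinct across hosts:

```python
import jax

from vmcnet.utils.distribute import make_different_rng_key_on_all_devices

key = jax.random.PRNGKey(0)
sharded_key = make_different_rng_key_on_all_devices(key)
# One distinct key per local device, laid out on the devices:
# sharded_key.shape == (jax.local_device_count(), 2)
```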
get_first(obj)
Get the first object in each leaf of a pytree.
Can be used to grab the first instance of a replicated object on the first local device.
Source code in vmcnet/utils/distribute.py
```python
def get_first(obj: T) -> T:
    """Get the first object in each leaf of a pytree.

    Can be used to grab the first instance of a replicated object on the first local
    device.
    """
    return jax.tree_map(lambda x: x[0], obj)
```
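For example, to pull a single host-side copy of replicated parameters (e.g. before checkpointing):

```python
import jax.numpy as jnp

from vmcnet.utils.distribute import get_first, replicate_all_local_devices

replicated_params = replicate_all_local_devices({"w": jnp.ones((3, 3))})
params_copy = get_first(replicated_params)
# params_copy["w"].shape == (3, 3): the device axis is gone
```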
mean_all_local_devices(x)
Compute mean over all local devices if distributed, otherwise the usual mean.
Source code in vmcnet/utils/distribute.py
```python
def mean_all_local_devices(x: Array) -> jnp.float32:
    """Compute mean over all local devices if distributed, otherwise the usual mean."""
    return pmean_if_pmap(jnp.mean(x))
```
nanmean_all_local_devices(x)
Compute a nan-safe mean over all local devices.
Source code in vmcnet/utils/distribute.py
```python
def nanmean_all_local_devices(x: Array) -> jnp.float32:
    """Compute a nan-safe mean over all local devices."""
    return pmean_if_pmap(jnp.nanmean(x))
```
get_mean_over_first_axis_fn(nan_safe=True)
Get a function which averages over the first axis over all local devices.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `nan_safe` | `bool` | whether to use jnp.nanmean or jnp.mean in the local average computation. Defaults to True. | `True` |

Returns:

| Type | Description |
|---|---|
| `Callable` | function which averages an array over its first axis over all local devices. |
Source code in vmcnet/utils/distribute.py
```python
def get_mean_over_first_axis_fn(
    nan_safe: bool = True,
) -> Callable[[Array], Array]:
    """Get a function which averages over the first axis over all local devices.

    Args:
        nan_safe (bool, optional): whether to use jnp.nanmean or jnp.mean in the local
            average computation. Defaults to True.

    Returns:
        Callable: function which averages an array over its first axis over all local
            devices.
    """
    if nan_safe:
        local_mean_fn = functools.partial(jnp.nanmean, axis=0)
    else:
        local_mean_fn = functools.partial(jnp.mean, axis=0)

    def mean_fn(x: Array) -> Array:
        return pmean_if_pmap(local_mean_fn(x))

    return mean_fn
```
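Usage sketch: inside a pmapped function, the `pmean` turns the per-device average into a global one. The axis name below is an assumption; it must match the one `pmean_if_pmap` was built with:

```python
import jax
import jax.numpy as jnp

from vmcnet.utils.distribute import get_mean_over_first_axis_fn

mean_fn = get_mean_over_first_axis_fn()

n = jax.local_device_count()
x = jnp.arange(8.0 * n).reshape(n, 8)  # one row of 8 values per device

# "pmap_axis" is an assumption; it must match the module's internal axis name.
global_mean = jax.pmap(mean_fn, axis_name="pmap_axis")(x)
# global_mean holds one (identical) entry per device: the mean over all 8 * n values
```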
split_or_psplit_key(key, multi_device=True)
Split PRNG key, potentially on multiple devices.
Source code in vmcnet/utils/distribute.py
```python
def split_or_psplit_key(key: PRNGKey, multi_device: bool = True) -> PRNGKey:
    """Split PRNG key, potentially on multiple devices."""
    return p_split(key) if multi_device else jax.random.split(key)
```
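`p_split` is not shown on this page; it is presumably a pmapped split, so each device advances its own key in parallel. A sketch under that assumption:

```python
import jax

# Assumption: p_split is a pmapped per-device split along these lines.
p_split = jax.pmap(lambda key: tuple(jax.random.split(key)))

n = jax.local_device_count()
sharded_key = jax.pmap(lambda k: k)(jax.random.split(jax.random.PRNGKey(0), n))

new_key, subkey = p_split(sharded_key)  # each device splits its own key
```

This implies `multi_device=True` expects a key that already carries a leading device axis, while `multi_device=False` takes a single host key.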
reshape_data_leaves_for_distribution(data_leaf)
For a leaf of a pytree, reshape it for distributing to all local devices.
Source code in vmcnet/utils/distribute.py
```python
def reshape_data_leaves_for_distribution(data_leaf: Array) -> Array:
    """For a leaf of a pytree, reshape it for distributing to all local devices."""
    num_devices = jax.local_device_count()
    nchains = data_leaf.shape[0]
    if nchains % num_devices != 0:
        raise ValueError(
            "Number of chains must be divisible by number of devices, "
            "got nchains {} for {} devices.".format(nchains, num_devices)
        )
    distributed_data_shape = (num_devices, nchains // num_devices)
    data = jnp.reshape(data_leaf, distributed_data_shape + data_leaf.shape[1:])
    return data
```
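A shape example:

```python
import jax.numpy as jnp

from vmcnet.utils.distribute import reshape_data_leaves_for_distribution

x = jnp.zeros((8, 5))  # nchains = 8 walkers, 5 features each
y = reshape_data_leaves_for_distribution(x)
# With 2 local devices: y.shape == (2, 4, 5).
# With 3 local devices: ValueError, since 8 % 3 != 0.
```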
default_distribute_data(data)
Split all data across all devices. The first axis of each leaf must be divisible by the number of local devices.
Source code in vmcnet/utils/distribute.py
```python
def default_distribute_data(data: D) -> D:
    """Split all data to all devices. The first axis must be divisible by ndevices."""
    data = jax.tree_map(reshape_data_leaves_for_distribution, data)
    data = broadcast_all_local_devices(data)
    return data
```
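A usage sketch (the pytree contents are hypothetical; the batch size is chosen to be divisible by the device count):

```python
import jax
import jax.numpy as jnp

from vmcnet.utils.distribute import default_distribute_data

n = jax.local_device_count()
data = {"positions": jnp.zeros((16 * n, 6))}  # hypothetical walker batch
data = default_distribute_data(data)
# Each leaf is split along its first axis and placed on the devices:
# data["positions"].shape == (n, 16, 6)
```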
distribute_vmc_state(data, params, optimizer_state, key, distribute_data_fn=default_distribute_data)
Split data, replicate params and opt state, and split PRNG key to all devices.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `~D` | the MCMC data to distribute | required |
| `params` | `~P` | model parameters | required |
| `optimizer_state` | `~S` | optimizer state | required |
| `key` | `PRNGKeyArray` | RNG key | required |
| `distribute_data_fn` | `Callable[[~D], ~D]` | custom function for distributing the MCMC data, for the case where some of the data needs to be replicated instead of distributed across the devices. Default works if there is no data that requires replication. | `default_distribute_data` |

Returns:

| Type | Description |
|---|---|
| `(D, P, S, PRNGKey)` | tuple of data, params, optimizer_state, and key, each of which has been either distributed or replicated across all devices, as appropriate. |
Source code in vmcnet/utils/distribute.py
```python
def distribute_vmc_state(
    data: D,
    params: P,
    optimizer_state: S,
    key: PRNGKey,
    distribute_data_fn: Callable[[D], D] = default_distribute_data,
) -> Tuple[D, P, S, PRNGKey]:
    """Split data, replicate params and opt state, and split PRNG key to all devices.

    Args:
        data: the MCMC data to distribute
        params: model parameters
        optimizer_state: optimizer state
        key: RNG key
        distribute_data_fn: custom function for distributing the MCMC data, for the
            case where some of the data needs to be replicated instead of distributed
            across the devices. Default works if there is no data that requires
            replication.

    Returns:
        (D, P, S, PRNGKey): tuple of data, params, optimizer_state, and key,
            each of which has been either distributed or replicated across all devices,
            as appropriate.
    """
    data = distribute_data_fn(data)
    params = replicate_all_local_devices(params)
    optimizer_state = replicate_all_local_devices(optimizer_state)
    sharded_key = make_different_rng_key_on_all_devices(key)
    return data, params, optimizer_state, sharded_key
```
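An end-to-end sketch of the resulting shapes (the pytree contents are hypothetical):

```python
import jax
import jax.numpy as jnp

from vmcnet.utils.distribute import distribute_vmc_state

n = jax.local_device_count()

data = {"positions": jnp.zeros((16 * n, 6))}  # hypothetical walker batch
params = {"w": jnp.ones((4, 4))}
opt_state = {"step": jnp.array(0)}
key = jax.random.PRNGKey(0)

data, params, opt_state, sharded_key = distribute_vmc_state(
    data, params, opt_state, key
)
# data["positions"].shape == (n, 16, 6)  -- split across devices
# params["w"].shape == (n, 4, 4)         -- replicated on every device
# sharded_key.shape == (n, 2)            -- a distinct key per device
```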
distribute_vmc_state_from_checkpoint(data, params, optimizer_state, key)
Distribute vmc state that was reloaded from a saved checkpoint.
Data and key are saved independently for each device, so on reload we simply broadcast them back to the devices. Params and optimizer state are saved as a single copy, so on reload we replicate them to all devices.
Source code in vmcnet/utils/distribute.py
```python
def distribute_vmc_state_from_checkpoint(
    data: D,
    params: P,
    optimizer_state: S,
    key: PRNGKey,
) -> Tuple[D, P, S, PRNGKey]:
    """Distribute vmc state that was reloaded from a saved checkpoint.

    Data and key are saved independently for each device, so on reload
    we simply broadcast them back to the devices. Params and optimizer state are saved
    as a single copy, so on reload we replicate them to all devices.
    """
    data = broadcast_all_local_devices(data)
    params = replicate_all_local_devices(params)
    optimizer_state = replicate_all_local_devices(optimizer_state)
    key = broadcast_all_local_devices(key)
    return data, params, optimizer_state, key
```
is_distributed(data)
Test whether the given data has been distributed using pmap.
Source code in vmcnet/utils/distribute.py
```python
def is_distributed(data: PyTree) -> bool:
    """Tests whether given data has been distributed using pmap."""
    return isinstance(jax.tree_leaves(data)[0], pxla.ShardedDeviceArray)
```
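Usage sketch (in the JAX versions this module targets, pmap outputs are `ShardedDeviceArray`s, so this check also passes on a single-device machine):

```python
import jax.numpy as jnp

from vmcnet.utils.distribute import is_distributed, replicate_all_local_devices

host_data = {"w": jnp.ones(3)}
assert not is_distributed(host_data)
assert is_distributed(replicate_all_local_devices(host_data))
```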
get_first_if_distributed(data)
Get a single copy of the input data, which may or may not be replicated.
Source code in vmcnet/utils/distribute.py
```python
def get_first_if_distributed(data: PyTree) -> PyTree:
    """Gets single copy of input data, which may or may not be replicated."""
    if is_distributed(data):
        return get_first(data)
    return data
```