pytree_helpers

Helper functions for pytrees.

tree_sum(tree1, tree2)

Leaf-wise sum of two pytrees with the same structure.

Source code in vmcnet/utils/pytree_helpers.py
def tree_sum(tree1: T, tree2: T) -> T:
    """Leaf-wise sum of two pytrees with the same structure."""
    return jax.tree_util.tree_map(lambda a, b: a + b, tree1, tree2)
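A minimal usage sketch follows. To keep it self-contained it re-declares the helper locally, using `jax.tree_util.tree_map` as a stand-in for the older `jax.tree_map` alias (assumed to behave identically here). The `params`/`update` trees are made-up illustrative data.

```python
from typing import TypeVar

import jax
import jax.numpy as jnp

T = TypeVar("T")


def tree_sum(tree1: T, tree2: T) -> T:
    """Leaf-wise sum of two pytrees with the same structure."""
    # tree_map walks both trees in lockstep and adds each pair of matching leaves.
    return jax.tree_util.tree_map(lambda a, b: a + b, tree1, tree2)


# Hypothetical trees; both share the structure {"w": array, "b": scalar array}.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
update = {"w": jnp.array([0.5, -0.25]), "b": jnp.array(1.0)}

summed = tree_sum(params, update)
# summed["w"] holds 1.5 and 1.75; summed["b"] holds 1.5
```

The result has the same structure as the inputs, so it can be passed anywhere the original tree was accepted.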

tree_prod(tree1, tree2)

Leaf-wise product of two pytrees with the same structure.

Source code in vmcnet/utils/pytree_helpers.py
def tree_prod(tree1: T, tree2: T) -> T:
    """Leaf-wise product of two pytrees with the same structure."""
    return jax.tree_util.tree_map(lambda a, b: a * b, tree1, tree2)
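The sketch below shows the leaf-wise product on a tuple-based pytree, used here to apply a hypothetical 0/1 mask to a gradient-like tree. The helper is re-declared with `jax.tree_util.tree_map` so the snippet runs on its own; the data and the masking use case are illustrative assumptions, not taken from vmcnet.

```python
from typing import TypeVar

import jax
import jax.numpy as jnp

T = TypeVar("T")


def tree_prod(tree1: T, tree2: T) -> T:
    """Leaf-wise product of two pytrees with the same structure."""
    # Multiply matching leaves elementwise; shapes within each leaf pair must broadcast.
    return jax.tree_util.tree_map(lambda a, b: a * b, tree1, tree2)


# A pytree can be any nesting of tuples, lists and dicts; here a tuple of two arrays.
grads = (jnp.array([2.0, 4.0]), jnp.array([[1.0, 3.0]]))
# Hypothetical mask that zeroes out the first entry of each leaf.
mask = (jnp.array([0.0, 1.0]), jnp.array([[0.0, 1.0]]))

masked = tree_prod(grads, mask)
# masked[0] holds [0., 4.] and masked[1] holds [[0., 3.]]
```

If the two trees do not share the same structure, `tree_map` raises an error instead of silently broadcasting across containers.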

multiply_tree_by_scalar(tree, scalar)

Multiply all leaves of a pytree by a scalar.

Source code in vmcnet/utils/pytree_helpers.py
def multiply_tree_by_scalar(tree: T, scalar: jnp.float32) -> T:
    """Multiply all leaves of a pytree by a scalar."""
    return jax.tree_util.tree_map(lambda x: scalar * x, tree)
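Scaling combines naturally with `tree_sum`. As an illustration only, the sketch below writes one plain gradient-descent step, `params - learning_rate * grads`, with the two helpers. The quadratic `loss` and the `learning_rate` value are hypothetical, this is not vmcnet's optimizer, and both helpers are re-declared with `jax.tree_util.tree_map` so the block is self-contained.

```python
from typing import TypeVar

import jax
import jax.numpy as jnp

T = TypeVar("T")


def tree_sum(tree1: T, tree2: T) -> T:
    """Leaf-wise sum of two pytrees with the same structure."""
    return jax.tree_util.tree_map(lambda a, b: a + b, tree1, tree2)


def multiply_tree_by_scalar(tree: T, scalar: float) -> T:
    """Multiply all leaves of a pytree by a scalar."""
    return jax.tree_util.tree_map(lambda x: scalar * x, tree)


def loss(params):
    # Hypothetical quadratic loss: sum of squares over every entry of every leaf.
    return jnp.sum(params["w"] ** 2) + params["b"] ** 2


params = {"w": jnp.array([1.0, -2.0]), "b": jnp.array(3.0)}
grads = jax.grad(loss)(params)  # same structure as params: {"w": 2 * w, "b": 2 * b}

learning_rate = 0.25  # hypothetical step size
# params <- params - learning_rate * grads, expressed with the two helpers.
new_params = tree_sum(params, multiply_tree_by_scalar(grads, -learning_rate))
# Each leaf is scaled by 1 - 2 * 0.25 = 0.5, so w becomes [0.5, -1.0] and b becomes 1.5
```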

tree_inner_product(tree1, tree2)

Inner product of two pytrees with the same structure: the sum, over all leaves, of the elementwise products of matching leaves.

Source code in vmcnet/utils/pytree_helpers.py
def tree_inner_product(tree1: T, tree2: T) -> Array:
    """Inner product of two pytrees with the same structure."""
    leaf_inner_prods = jax.tree_util.tree_map(lambda a, b: jnp.sum(a * b), tree1, tree2)
    return jnp.sum(jax.flatten_util.ravel_pytree(leaf_inner_prods)[0])
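A quick way to see what this computes: it matches the ordinary dot product of the two trees after each is flattened into a single vector. The sketch below re-declares the function with `jax.tree_util.tree_map` and an explicit `from jax.flatten_util import ravel_pytree` import, and uses `jax.Array` in place of the module's `Array` alias (an assumption); the trees are illustrative data.

```python
from typing import TypeVar

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

T = TypeVar("T")


def tree_inner_product(tree1: T, tree2: T) -> jax.Array:
    """Inner product of two pytrees with the same structure."""
    # Reduce each pair of matching leaves to a scalar, then sum those scalars.
    leaf_inner_prods = jax.tree_util.tree_map(lambda a, b: jnp.sum(a * b), tree1, tree2)
    return jnp.sum(ravel_pytree(leaf_inner_prods)[0])


tree_a = {"w": jnp.array([[1.0, 2.0], [3.0, 4.0]]), "b": jnp.array([1.0])}
tree_b = {"w": jnp.array([[1.0, 0.0], [0.0, 1.0]]), "b": jnp.array([2.0])}

ip = tree_inner_product(tree_a, tree_b)  # (1*1 + 4*1) + 1*2 = 7

# Sanity check: same value as a dot product of the two flattened trees.
flat_a, _ = ravel_pytree(tree_a)
flat_b, _ = ravel_pytree(tree_b)
assert jnp.allclose(ip, jnp.dot(flat_a, flat_b))
```

Neither argument is conjugated, so for complex-valued leaves this is a bilinear form rather than a Hermitian inner product.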

tree_reduce_l1(xs)

L1 norm of a pytree, treating all of its leaves as one flattened vector (the sum of the absolute values of every entry).

Source code in vmcnet/utils/pytree_helpers.py
def tree_reduce_l1(xs: PyTree) -> jnp.float32:
    """L1 norm of a pytree as a flattened vector."""
    concat_xs, _ = jax.flatten_util.ravel_pytree(xs)
    return jnp.sum(jnp.abs(concat_xs))
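The sketch below checks the flattened-vector L1 norm against a per-leaf reduction that avoids building the concatenated vector. It re-declares the function with an explicit `ravel_pytree` import and drops the `PyTree` annotation for self-containment; the nested tree is illustrative data, and the per-leaf variant is an alternative shown for comparison, not part of vmcnet.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def tree_reduce_l1(xs) -> jax.Array:
    """L1 norm of a pytree as a flattened vector."""
    # ravel_pytree concatenates every leaf into a single 1-D array.
    concat_xs, _ = ravel_pytree(xs)
    return jnp.sum(jnp.abs(concat_xs))


# Nested tree mixing a matrix, a 0-d scalar and a 1-element vector.
tree = {
    "w": jnp.array([[1.0, -2.0], [0.5, 0.0]]),
    "b": (jnp.array(-3.0), jnp.array([0.25])),
}
l1 = tree_reduce_l1(tree)  # 1 + 2 + 0.5 + 0 + 3 + 0.25 = 6.75

# Alternative: sum per-leaf L1 norms without materializing the concatenated vector.
l1_per_leaf = sum(jnp.sum(jnp.abs(leaf)) for leaf in jax.tree_util.tree_leaves(tree))
assert jnp.allclose(l1, l1_per_leaf)
```

Flattening copies every leaf into one buffer and promotes them to a common dtype; the per-leaf sum gives the same value while keeping leaves separate.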