pytree_helpers
Helper functions for pytrees.
tree_sum(tree1, tree2)
Leaf-wise sum of two pytrees with the same structure.
Source code in vmcnet/utils/pytree_helpers.py
def tree_sum(tree1: T, tree2: T) -> T:
    """Leaf-wise sum of two pytrees with the same structure.

    Args:
        tree1: first pytree.
        tree2: second pytree; must have the same tree structure as ``tree1``.
            Corresponding leaves must support ``+``.

    Returns:
        A pytree with the structure of ``tree1`` whose leaves are
        ``leaf1 + leaf2``. JAX raises an error if the structures differ.
    """
    # jax.tree_util.tree_map is the stable spelling; the top-level
    # jax.tree_map alias was deprecated (JAX 0.4.26) and later removed.
    return jax.tree_util.tree_map(lambda a, b: a + b, tree1, tree2)
tree_prod(tree1, tree2)
Leaf-wise product of two pytrees with the same structure.
Source code in vmcnet/utils/pytree_helpers.py
def tree_prod(tree1: T, tree2: T) -> T:
    """Leaf-wise product of two pytrees with the same structure.

    Args:
        tree1: first pytree.
        tree2: second pytree; must have the same tree structure as ``tree1``.
            Corresponding leaves must support ``*``.

    Returns:
        A pytree with the structure of ``tree1`` whose leaves are
        ``leaf1 * leaf2`` (element-wise for arrays). JAX raises an error if
        the structures differ.
    """
    # jax.tree_util.tree_map is the stable spelling; the top-level
    # jax.tree_map alias was deprecated (JAX 0.4.26) and later removed.
    return jax.tree_util.tree_map(lambda a, b: a * b, tree1, tree2)
multiply_tree_by_scalar(tree, scalar)
Multiply all leaves of a pytree by a scalar.
Source code in vmcnet/utils/pytree_helpers.py
def multiply_tree_by_scalar(tree: T, scalar: jnp.float32) -> T:
    """Multiply all leaves of a pytree by a scalar.

    Args:
        tree: pytree whose leaves support multiplication by ``scalar``.
        scalar: value each leaf is multiplied by.

    Returns:
        A pytree with the structure of ``tree`` whose leaves are
        ``scalar * leaf``.
    """
    # jax.tree_util.tree_map is the stable spelling; the top-level
    # jax.tree_map alias was deprecated (JAX 0.4.26) and later removed.
    return jax.tree_util.tree_map(lambda x: scalar * x, tree)
tree_inner_product(tree1, tree2)
Inner product of two pytrees with the same structure.
Source code in vmcnet/utils/pytree_helpers.py
def tree_inner_product(tree1: T, tree2: T) -> Array:
    """Inner product of two pytrees with the same structure.

    Computes sum over all leaves of ``sum(leaf1 * leaf2)``, i.e. the dot
    product of the two pytrees viewed as flattened vectors. No complex
    conjugation is applied.

    Args:
        tree1: first pytree.
        tree2: second pytree; must have the same tree structure as ``tree1``.

    Returns:
        A scalar array holding the inner product. An empty pytree yields 0.
    """
    # Reduce each leaf pair to a scalar first, then sum those scalars.
    # jax.tree_util.tree_map replaces the removed jax.tree_map alias.
    leaf_inner_prods = jax.tree_util.tree_map(
        lambda a, b: jnp.sum(a * b), tree1, tree2
    )
    # ravel_pytree concatenates the per-leaf scalars into one vector,
    # promoting them to a common dtype before the final sum.
    return jnp.sum(jax.flatten_util.ravel_pytree(leaf_inner_prods)[0])
tree_reduce_l1(xs)
L1 norm of a pytree, treating all of its leaves as one flattened vector.
Source code in vmcnet/utils/pytree_helpers.py
def tree_reduce_l1(xs: PyTree) -> jnp.float32:
    """Compute the L1 norm of a pytree.

    All leaves are concatenated into a single flat vector (via
    ``ravel_pytree``), and the absolute values of its entries are summed.

    Args:
        xs: pytree of numeric leaves.

    Returns:
        Scalar array equal to the sum of absolute values of every entry.
    """
    flat_vector = jax.flatten_util.ravel_pytree(xs)[0]
    magnitudes = jnp.abs(flat_vector)
    return jnp.sum(magnitudes)