Skip to content

weights

Functions to get weight initializers from names.

validate_kernel_initializer(name)

Check that a kernel initializer name is in the list of supported kernel inits.

Source code in vmcnet/models/weights.py
def validate_kernel_initializer(name: str) -> None:
    """Raise unless ``name`` is one of the supported kernel initializer names.

    Args:
        name (str): the requested kernel initializer name.

    Raises:
        ValueError: if ``name`` is not in VALID_KERNEL_INITIALIZERS. The message
            lists every supported name.
    """
    if name in VALID_KERNEL_INITIALIZERS:
        return
    prefix = (
        "Invalid kernel initializer requested, {} was requested, but available "
        "initializers are: "
    ).format(name)
    supported = ", ".join(VALID_KERNEL_INITIALIZERS)
    raise ValueError(prefix + supported)

get_kernel_initializer(name, dtype=jnp.float32, **kwargs)

Get a kernel initializer.

Source code in vmcnet/models/weights.py
def get_kernel_initializer(
    name: str, dtype=jnp.float32, **kwargs: Any
) -> WeightInitializer:
    """Build the kernel initializer registered under ``name``.

    Args:
        name (str): a supported kernel initializer name.
        dtype: dtype forwarded to the initializer constructor. Defaults to
            jnp.float32.
        **kwargs: extra options. Only ``scale`` (default 1.0) is read, and only
            for the "orthogonal" and "delta_orthogonal" initializers; every
            other keyword is ignored.

    Returns:
        WeightInitializer: the value returned by the registered constructor.

    Raises:
        ValueError: if ``name`` is not a supported kernel initializer.
    """
    validate_kernel_initializer(name)
    make_init = INITIALIZER_CONSTRUCTORS[name]
    if name not in ("orthogonal", "delta_orthogonal"):
        return make_init(dtype=dtype)
    # Only the orthogonal-family constructors are given a scale.
    return make_init(scale=kwargs.get("scale", 1.0), dtype=dtype)

get_kernel_init_from_config(config, dtype=jnp.float32)

Get a kernel initializer from a ConfigDict.

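The ConfigDict must have the key "type", which names the initializer. Any other keys are forwarded as keyword arguments to `get_kernel_initializer`. Currently only `scale` is used there, and only by the "orthogonal" and "delta_orthogonal" initializers.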

Source code in vmcnet/models/weights.py
def get_kernel_init_from_config(config: ConfigDict, dtype=jnp.float32):
    """Get a kernel initializer from a ConfigDict.

    The ConfigDict should have the key "type", as well as any other kwargs to pass
    to the initializer constructor.

    Args:
        config (ConfigDict): must contain the key "type" naming the initializer.
            Every other key is forwarded as a keyword argument to
            get_kernel_initializer.
        dtype: dtype for the initializer. Defaults to jnp.float32.

    Returns:
        The kernel initializer built by get_kernel_initializer.

    Raises:
        ValueError: if config.type is not a supported kernel initializer.
    """
    # Forward only the *other* keys: "type" is already passed positionally as the
    # initializer name, so re-sending it as a keyword argument is redundant.
    # keys()/__getitem__ is the same protocol that ``**config`` unpacking uses.
    extra_kwargs = {key: config[key] for key in config.keys() if key != "type"}
    return get_kernel_initializer(config.type, dtype=dtype, **extra_kwargs)

validate_bias_initializer(name)

Check that a bias initializer name is in the list of supported bias inits.

Source code in vmcnet/models/weights.py
def validate_bias_initializer(name: str) -> None:
    """Raise unless ``name`` is one of the supported bias initializer names.

    Args:
        name (str): the requested bias initializer name.

    Raises:
        ValueError: if ``name`` is not in VALID_BIAS_INITIALIZERS. The message
            lists every supported name.
    """
    if name in VALID_BIAS_INITIALIZERS:
        return
    prefix = (
        "Invalid bias initializer requested, {} was requested, but available "
        "initializers are: "
    ).format(name)
    supported = ", ".join(VALID_BIAS_INITIALIZERS)
    raise ValueError(prefix + supported)

get_bias_initializer(name, dtype=jnp.float32)

Get a bias initializer.

Source code in vmcnet/models/weights.py
def get_bias_initializer(name: str, dtype=jnp.float32) -> WeightInitializer:
    """Build the bias initializer registered under ``name``.

    Args:
        name (str): a supported bias initializer name.
        dtype: dtype forwarded to the initializer constructor. Defaults to
            jnp.float32.

    Returns:
        WeightInitializer: the value returned by the registered constructor.

    Raises:
        ValueError: if ``name`` is not a supported bias initializer.
    """
    validate_bias_initializer(name)
    make_init = INITIALIZER_CONSTRUCTORS[name]
    return make_init(dtype=dtype)

get_bias_init_from_config(config, dtype=jnp.float32)

Get a bias initializer from a ConfigDict.

The ConfigDict must have the key "type", which names the initializer. Any other keys are currently ignored, because `get_bias_initializer` takes no extra keyword arguments.

Source code in vmcnet/models/weights.py
def get_bias_init_from_config(config: "ConfigDict", dtype=jnp.float32):
    """Get a bias initializer from a ConfigDict.

    Only the "type" key of ``config`` is read; it names the initializer. Any other
    keys are ignored, because get_bias_initializer accepts no extra keyword
    arguments to forward them to.

    Args:
        config (ConfigDict): must contain the key "type".
        dtype: dtype for the initializer. Defaults to jnp.float32.

    Returns:
        The bias initializer built by get_bias_initializer.

    Raises:
        ValueError: if config.type is not a supported bias initializer.
    """
    return get_bias_initializer(config.type, dtype=dtype)

get_constant_init(constant)

Get a weight initializer that fills an array with a constant value. The dtype is chosen when the initializer is called, and the PRNG key passed to it is ignored.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `constant` | `float` | the number to initialize to | required |
Source code in vmcnet/models/weights.py
def get_constant_init(constant: float):
    """Get a weight initializer that fills an array with a constant value.

    The dtype is chosen when the returned initializer is called, not here, and the
    PRNG key passed to it is ignored.

    Args:
        constant (float): the number to initialize to

    Returns:
        Callable: an initializer ``init_fn(key, shape, dtype=jnp.float32)``. It
        returns an array of the given shape and dtype in which every entry equals
        ``constant`` cast to ``dtype``.
    """

    def init_fn(key, shape, dtype=jnp.float32):
        # The key is accepted only to match the (key, shape, dtype) signature;
        # a constant fill needs no randomness.
        del key
        # jnp.full casts ``constant`` to ``dtype`` and fills in one step, rather
        # than allocating an array of ones and multiplying it.
        return jnp.full(shape, constant, dtype=dtype)

    return init_fn