weights
Functions to get weight initializers from names.
validate_kernel_initializer(name)
Check that a kernel initializer name is in the list of supported kernel inits.
Source code in vmcnet/models/weights.py
def validate_kernel_initializer(name: str) -> None:
    """Check that a kernel initializer name is in the list of supported kernel inits."""
    # Guard clause: a supported name needs no further work.
    if name in VALID_KERNEL_INITIALIZERS:
        return
    supported = ", ".join(VALID_KERNEL_INITIALIZERS)
    raise ValueError(
        "Invalid kernel initializer requested, {} was requested, but available "
        "initializers are: ".format(name) + supported
    )
get_kernel_initializer(name, dtype=jnp.float32, **kwargs)
Get a kernel initializer.
Source code in vmcnet/models/weights.py
def get_kernel_initializer(
    name: str, dtype=jnp.float32, **kwargs: Any
) -> WeightInitializer:
    """Get a kernel initializer."""
    validate_kernel_initializer(name)
    make_initializer = INITIALIZER_CONSTRUCTORS[name]
    # Only the (delta-)orthogonal constructors accept a scale keyword.
    if name in ("orthogonal", "delta_orthogonal"):
        scale = kwargs.get("scale", 1.0)
        return make_initializer(scale=scale, dtype=dtype)
    return make_initializer(dtype=dtype)
get_kernel_init_from_config(config, dtype=jnp.float32)
Get a kernel initializer from a ConfigDict.
The ConfigDict should have the key "type", as well as any other kwargs to pass to the initializer constructor.
Source code in vmcnet/models/weights.py
def get_kernel_init_from_config(config: ConfigDict, dtype=jnp.float32):
    """Get a kernel initializer from a ConfigDict.

    The ConfigDict should have the key "type", as well as any other kwargs to pass
    to the initializer constructor.
    """
    init_name = config.type
    # Every entry of the config (including "type", which is absorbed by the
    # **kwargs of get_kernel_initializer) is forwarded as a keyword argument.
    return get_kernel_initializer(init_name, dtype=dtype, **config)
validate_bias_initializer(name)
Check that a bias initializer name is in the list of supported bias inits.
Source code in vmcnet/models/weights.py
def validate_bias_initializer(name: str) -> None:
    """Check that a bias initializer name is in the list of supported bias inits."""
    # Guard clause: a supported name needs no further work.
    if name in VALID_BIAS_INITIALIZERS:
        return
    supported = ", ".join(VALID_BIAS_INITIALIZERS)
    raise ValueError(
        "Invalid bias initializer requested, {} was requested, but available "
        "initializers are: ".format(name) + supported
    )
get_bias_initializer(name, dtype=jnp.float32)
Get a bias initializer.
Source code in vmcnet/models/weights.py
def get_bias_initializer(name: str, dtype=jnp.float32) -> WeightInitializer:
    """Get a bias initializer."""
    validate_bias_initializer(name)
    make_initializer = INITIALIZER_CONSTRUCTORS[name]
    return make_initializer(dtype=dtype)
get_bias_init_from_config(config, dtype=jnp.float32)
Get a bias initializer from a ConfigDict.
The ConfigDict should have the key "type", as well as any other kwargs to pass to the initializer constructor.
Source code in vmcnet/models/weights.py
def get_bias_init_from_config(config: ConfigDict, dtype=jnp.float32):
    """Get a bias initializer from a ConfigDict.

    The ConfigDict should have the key "type" naming the bias initializer.
    Unlike the kernel counterpart, no other entries of the config are forwarded,
    because bias initializer constructors take no extra kwargs.
    """
    return get_bias_initializer(config.type, dtype=dtype)
get_constant_init(constant)
Get a weight initializer for a constant array with specified dtype, ignoring key.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `constant` | `float` | the number to initialize to | *required* |
Source code in vmcnet/models/weights.py
def get_constant_init(constant: float):
    """Get a weight initializer for a constant array with specified dtype, ignoring key.

    Args:
        constant (float): the number to initialize to

    Returns:
        A function with the jax initializer signature (key, shape, dtype) -> array
        that returns an array of the given shape filled with `constant`; the PRNG
        key is ignored since the output is deterministic.
    """

    def init_fn(key, shape, dtype=jnp.float32):
        del key  # deterministic init; the PRNG key is unused
        # jnp.full builds the constant array directly, avoiding the original's
        # intermediate ones-array and elementwise multiply.
        return jnp.full(shape, constant, dtype=dtype)

    return init_fn