Skip to content

Commit

Permalink
Replace uses of jnp.array in types with jnp.ndarray.
Browse files Browse the repository at this point in the history
`jnp.array` is a function, not a type:
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html
so it never makes sense to use `jnp.array` in a type annotation. Presumably the intent was to write `jnp.ndarray` aka `jax.Array`.

PiperOrigin-RevId: 559558609
  • Loading branch information
hawkinsp authored and OptaxDev committed Aug 23, 2023
1 parent 1b23e56 commit 1fa1fe0
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions optax/_src/control_variates.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,17 +210,18 @@ def _map(cv, params, samples, state):

def control_variates_jacobians(
function: Callable[[chex.Array], float],
control_variate_from_function: Callable[[Callable[[chex.Array], float]],
ControlVariate],
grad_estimator: Callable[..., jnp.array],
control_variate_from_function: Callable[
[Callable[[chex.Array], float]], ControlVariate
],
grad_estimator: Callable[..., jnp.ndarray],
params: base.Params,
dist_builder: Callable[..., Any],
rng: chex.PRNGKey,
num_samples: int,
control_variate_state: CvState = None,
estimate_cv_coeffs: bool = False,
estimate_cv_coeffs_num_samples: int = 20) -> Tuple[
Sequence[chex.Array], CvState]:
estimate_cv_coeffs_num_samples: int = 20,
) -> Tuple[Sequence[chex.Array], CvState]:
r"""Obtain jacobians using control variates.
We will compute each term individually. The first term will use stochastic
Expand Down Expand Up @@ -338,15 +339,17 @@ def param_fn(x):

def estimate_control_variate_coefficients(
function: Callable[[chex.Array], float],
control_variate_from_function: Callable[[Callable[[chex.Array], float]],
ControlVariate],
grad_estimator: Callable[..., jnp.array],
control_variate_from_function: Callable[
[Callable[[chex.Array], float]], ControlVariate
],
grad_estimator: Callable[..., jnp.ndarray],
params: base.Params,
dist_builder: Callable[..., Any],
rng: chex.PRNGKey,
num_samples: int,
control_variate_state: CvState = None,
eps: float = 1e-3) -> Sequence[float]:
eps: float = 1e-3,
) -> Sequence[float]:
r"""Estimates the control variate coefficients for the given parameters.
For each variable `var_k`, the coefficient is given by:
Expand Down

0 comments on commit 1fa1fe0

Please sign in to comment.