diff --git a/alphafold/model/utils.py b/alphafold/model/utils.py index 634f0388a..3e5ac625c 100644 --- a/alphafold/model/utils.py +++ b/alphafold/model/utils.py @@ -163,7 +163,7 @@ def inner(key, shape, **kwargs): keys = grid_keys(key, shape) signature = ( '()->()' - if jax.dtypes.issubdtype(keys.dtype, jax.dtypes.prng_key) + if isinstance(keys, jax.random.PRNGKeyArray) else '(2)->()' ) return jnp.vectorize(