Skip to content

Commit

Permalink
Reduces latency of rendering Jax arrays.
Browse files Browse the repository at this point in the history
Before:
  Cold rendering of an array of any size would take ~15-25 seconds (in Colab).

Now:
  Cold rendering of small-ish (less than 10M bytes) arrays is instantaneous (<< 1s)
  Larger arrays (e.g. > 2K x 2K) take 3-4 seconds, which is much more manageable

To do this we do two things:
 a) We convert arrays smaller than 10M to NumPy for computing stats and doing slicing. (We still maintain the JAX visualization of sharding and types.)
 b) For arrays larger than 10M, we use a single jitted function to compute all summaries, rather than an individual JAX invocation for each stat, each of which ends up jitted separately.

PiperOrigin-RevId: 681055675
  • Loading branch information
marksandler2 authored and Treescope Developers committed Oct 1, 2024
1 parent c05b5fe commit 438b9a5
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 91 deletions.
18 changes: 18 additions & 0 deletions tests/renderer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from treescope import layout_algorithms
from treescope import lowering
from treescope import rendering_parts
from treescope.external import jax_support
from tests.fixtures import treescope_examples_fixture as fixture_lib


Expand Down Expand Up @@ -635,6 +636,23 @@ def inner_fn(y):
lowering.render_to_text_as_root(rendering),
)

def test_render_jax_array_within_jitted_function(self):
  """Checks that summarizing a JAX array works while tracing a jitted function."""
  x = jnp.arange(10)
  old = jax_support.SUMMARIZE_USING_NUMPY_THRESHOLD
  # With a threshold of 0 no array is "small", so the summary is computed via
  # the jitted JAX code path rather than the NumPy one.
  jax_support.SUMMARIZE_USING_NUMPY_THRESHOLD = 0
  try:
    # Verify that we don't fail when called inside a jitted function.
    @jax.jit
    def go(s):
      adapter = jax_support.JAXArrayAdapter()
      self.assertNotEmpty(adapter.get_array_summary(x, False))
      return jax.numpy.sum(s)

    go(jnp.arange(3))
  finally:
    # Always restore the module-level threshold, even if the assertion above
    # fails, so that the override cannot leak into other tests.
    jax_support.SUMMARIZE_USING_NUMPY_THRESHOLD = old

def test_fallback_repr_pytree_node(self):
target = [fixture_lib.UnknownPytreeNode(1234, 5678)]
renderer = treescope.active_renderer.get()
Expand Down
221 changes: 130 additions & 91 deletions treescope/external/jax_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import functools
import typing
from typing import Mapping

Expand All @@ -39,17 +40,6 @@
# pylint: enable=g-import-not-at-top


def _finite_mean_std_any(array):
  """Returns (mean, std, any_finite), using only the finite entries of `array`.

  Non-finite entries (NaN and +/-inf) are replaced by NaN so that the
  NaN-aware reductions ignore them. The third result tells whether at least one
  finite entry exists.
  """
  assert jax is not None
  xnp = jax.numpy
  finite_mask = xnp.isfinite(array)
  nan_fill = xnp.array(xnp.nan, dtype=array.dtype)
  finite_only = xnp.where(finite_mask, array, nan_fill)
  return (
      xnp.nanmean(finite_only),
      xnp.nanstd(finite_only),
      xnp.any(finite_mask),
  )


def _is_subdtype(dtype, base) -> bool:
"""Safely checks for dtype subtyping."""
assert jax is not None
Expand Down Expand Up @@ -89,17 +79,29 @@ def _is_subdtype(dtype, base) -> bool:
"""


# NOTE: this value is compared against `array.size` (an element count), so it
# is a threshold on the number of elements rather than on the number of bytes.
SUMMARIZE_USING_NUMPY_THRESHOLD = 10_000_000
"""Threshold for using NumPy to summarize and render JAX Arrays.
Moving arrays to main memory and using numpy is significantly faster because
it avoid jitting the summary and rendering functions.
"""


def _is_locally_available(array: jax.Array) -> bool:
  """Tells whether `array` is fully addressable or fully replicated.

  Either attribute is treated as False when the object does not define it.
  """
  fully_addressable = getattr(array, "is_fully_addressable", False)
  if fully_addressable:
    return fully_addressable
  return getattr(array, "is_fully_replicated", False)


def safe_to_summarize(array: jax.Array) -> bool:
"""Checks if the array is safe to summarize (e.g. not a tracer, not deleted, and locally available)."""
assert jax is not None, "JAX is not available."
if isinstance(array, jax.core.Tracer):
return False
if array.is_deleted():
return False
if not (
getattr(array, "is_fully_addressable", False)
or getattr(array, "is_fully_replicated", False)
):
if not _is_locally_available(array):
return False
thresh_dict = summarization_threshold.get()
[platform] = set(device.platform for device in array.devices())
Expand All @@ -114,24 +116,29 @@ def _truncate_part_with_slices(
mask: jax.Array,
prefix_slices: tuple[slice, ...],
remaining_edge_items_per_axis: tuple[int | None, ...],
xnp=None,
) -> tuple[jax.Array, jax.Array]:
"""Helper to truncate names of an array.
Args:
array: An array to truncate.
mask: Mask array, which must have the same number of dimensions as `array`,
and whose axis sizes must be either 1 or the same as that axis of `array`
(e.g. they are broadcast compatible).
mask: Mask array, which must be broadcastable to `array`.
prefix_slices: Slices to apply to each axis of `array` and `mask`, starting
at axis 0, which we have already computed.
remaining_edge_items_per_axis: Number of edge items to keep for each axis,
ignoring any axes whose slices are already computed in `prefix_slices`.
xnp: backend to use (numpy or jax.numpy).
Returns:
Truncated array and mask, which will both be the same shape.
"""
assert jax is not None, "JAX is not available."
jnp = jax.numpy
if xnp is None:
assert jax is not None, "JAX is not available."
xnp = jax.numpy

array = xnp.array(array)
mask = xnp.array(mask)
mask = xnp.broadcast_to(mask, array.shape)
if not remaining_edge_items_per_axis:
# Perform the base case slice.
assert len(prefix_slices) == len(array.shape)
Expand All @@ -141,8 +148,8 @@ def _truncate_part_with_slices(
slice(None) if mask.shape[i] == 1 else array_slice
for i, array_slice in enumerate(prefix_slices)
)
truncated_mask = jnp.broadcast_to(
jnp.array(mask[valid_mask_slices]), truncated_array.shape
truncated_mask = xnp.broadcast_to(
xnp.array(mask[valid_mask_slices]), truncated_array.shape
)
return truncated_array, truncated_mask

Expand All @@ -157,6 +164,7 @@ def _truncate_part_with_slices(
mask,
prefix_slices=prefix_slices + (slice(None),),
remaining_edge_items_per_axis=remaining_edge_items_per_axis[1:],
xnp=xnp,
)
else:
assert array.shape[axis] > 2 * edge_items
Expand All @@ -165,21 +173,23 @@ def _truncate_part_with_slices(
mask,
prefix_slices=prefix_slices + (slice(None, edge_items),),
remaining_edge_items_per_axis=remaining_edge_items_per_axis[1:],
xnp=xnp,
)
result_b, valid_b = _truncate_part_with_slices(
array,
mask,
prefix_slices=prefix_slices + (slice(-edge_items, None),),
remaining_edge_items_per_axis=remaining_edge_items_per_axis[1:],
xnp=xnp,
)
padding_shape = list(result_a.shape)
padding_shape[axis] = 1
result = jnp.concatenate(
[result_a, jnp.zeros(padding_shape, result_a.dtype), result_b],
result = xnp.concatenate(
[result_a, xnp.zeros(padding_shape, result_a.dtype), result_b],
axis=axis,
)
valid = jnp.concatenate(
[valid_a, jnp.zeros(padding_shape, valid_a.dtype), valid_b], axis=axis
valid = xnp.concatenate(
[valid_a, xnp.zeros(padding_shape, valid_a.dtype), valid_b], axis=axis
)
return result, valid

Expand Down Expand Up @@ -233,9 +243,12 @@ def truncate_array_and_mask(
array.sharding._device_assignment # pylint: disable=protected-access
)
)
fn = jax.jit(
_truncate_part_with_slices, static_argnums=(2, 3), **sharding_kwargs
)
if array.size < SUMMARIZE_USING_NUMPY_THRESHOLD and safe_to_summarize(array):
fn = functools.partial(_truncate_part_with_slices, xnp=np)
else:
fn = jax.jit(
_truncate_part_with_slices, static_argnums=(2, 3), **sharding_kwargs
)
return fn(array, mask, (), edge_items_per_axis)


Expand Down Expand Up @@ -272,7 +285,7 @@ def faster_array_repr(array: jax.Array) -> str:
edge_items_per_axis.append(None)
array_edges, _ = truncate_array_and_mask(
array,
jnp.ones((1,) * array.ndim, dtype=jnp.bool_),
np.ones((1,) * array.ndim, dtype=jnp.bool_),
edge_items_per_axis=tuple(edge_items_per_axis),
)
prefix = "Array("
Expand Down Expand Up @@ -344,6 +357,76 @@ def render_precision(
)


def _compute_summary(
    x: jax.Array, is_floating: bool, is_integer: bool, is_bool: bool, xnp=None
) -> dict[str, jax.Array]:
  """Computes all summary statistics of the given array in one call.

  Args:
    x: The array to summarize.
    is_floating: Whether to compute the floating-point statistics (`mean`,
      `std`, `nanmin`, `nanmax`, `nan`, `inf`, `-inf`, `any_finite`).
    is_integer: Whether to compute the integer statistics (`min`, `max`).
    is_bool: Whether to compute the boolean statistics (`true`, `false`).
    xnp: Array backend to use (`numpy` or `jax.numpy`). Defaults to
      `jax.numpy`.

  Returns:
    A dict mapping statistic names to scalar values. `zero` and `nonzero` are
    included whenever `is_floating` or `is_integer` is set. For a floating
    array without any finite value (including an empty one), `any_finite` is
    False and the mean/std/nanmin/nanmax entries are not meaningful.
  """
  import warnings  # pylint: disable=g-import-not-at-top

  if xnp is None:
    assert jax is not None, "JAX is not available."
    xnp = jax.numpy
  # `asarray` avoids an extra copy when `x` already is an array of the backend.
  x = xnp.asarray(x)
  result = {}
  if is_floating:
    isfinite = xnp.isfinite(x)
    inf_to_nan = xnp.where(isfinite, x, xnp.array(xnp.nan, dtype=x.dtype))
    # NumPy emits RuntimeWarnings ("Mean of empty slice", "All-NaN slice
    # encountered", ...) when there is nothing finite to reduce over. That case
    # is reported through `any_finite`, so the warnings are only noise.
    with warnings.catch_warnings():
      warnings.simplefilter("ignore", RuntimeWarning)
      result.update(mean=xnp.nanmean(inf_to_nan), std=xnp.nanstd(inf_to_nan))
      if x.size:
        result.update(nanmin=xnp.nanmin(x), nanmax=xnp.nanmax(x))
      else:
        # Min/max reductions raise on zero-size arrays; use NaN placeholders
        # so that the result keeps the same set of keys.
        nan_scalar = xnp.array(xnp.nan, dtype=x.dtype)
        result.update(nanmin=nan_scalar, nanmax=nan_scalar)
    result.update(
        nan=xnp.count_nonzero(xnp.isnan(x)),
        inf=xnp.count_nonzero(xnp.isposinf(x)),
    )
    result["any_finite"] = xnp.any(isfinite)
    result["-inf"] = xnp.count_nonzero(xnp.isneginf(x))
  if is_integer:
    result.update(min=xnp.min(x), max=xnp.max(x))
  if is_floating or is_integer:
    result.update(zero=xnp.count_nonzero(x == 0), nonzero=xnp.count_nonzero(x))
  if is_bool:
    result.update(
        true=xnp.count_nonzero(x), false=xnp.count_nonzero(xnp.logical_not(x))
    )
  return result


def _summarize_array_data_unconditionally(array: jax.Array) -> list[str]:
  """Summarizes the data of a JAX array without checking if that is safe.

  Arrays with fewer than `SUMMARIZE_USING_NUMPY_THRESHOLD` elements are
  summarized with NumPy; larger ones use a single jitted call that computes
  every statistic at once.

  Args:
    array: The array to summarize.

  Returns:
    A list of string fragments describing the values of the array (mean and
    standard deviation, value range, and counts of notable values). The list is
    empty for arrays without elements.
  """
  assert jax is not None, "JAX is not available."
  if array.size == 0:
    # There is nothing to describe, and the min/max reductions used for the
    # summary raise on zero-size arrays.
    return []
  jnp = jax.numpy
  output_parts = []
  # This is required if treescope is invoked inside jitted function.
  with jax.core.ensure_compile_time_eval():
    is_floating = _is_subdtype(array.dtype, jnp.floating)
    is_integer = _is_subdtype(array.dtype, jnp.integer)
    is_bool = _is_subdtype(array.dtype, jnp.bool_)
    if array.size < SUMMARIZE_USING_NUMPY_THRESHOLD:
      compute_summary = functools.partial(_compute_summary, xnp=np)
    else:
      compute_summary = jax.jit(_compute_summary, static_argnums=(1, 2, 3))
    stat = compute_summary(array, is_floating, is_integer, is_bool)
    # Get values in parallel.
    stat = jax.device_get(stat)

  if is_floating and stat["any_finite"]:
    output_parts.append(f" ≈{stat['mean']:.2} ±{stat['std']:.2}")
    output_parts.append(f" [≥{stat['nanmin']:.2}, ≤{stat['nanmax']:.2}]")

  if is_integer:
    output_parts.append(f" [≥{stat['min']:_d}, ≤{stat['max']:_d}]")

  # Counts are reported in a fixed order, and only when they are nonzero.
  count_names = []
  if is_floating or is_integer:
    count_names += ["zero", "nonzero"]
  if is_floating:
    count_names += ["nan", "inf", "-inf"]
  if is_bool:
    count_names += ["true", "false"]
  for name in count_names:
    if stat[name]:
      output_parts.append(f" {name}:{stat[name]:_d}")
  return output_parts


def summarize_array_data(array: jax.Array) -> str:
"""Summarized the data of a JAX array.
Expand All @@ -353,60 +436,20 @@ def summarize_array_data(array: jax.Array) -> str:
Returns:
A string summarizing the data of the array.
"""
assert jax is not None, "JAX is not available."
jnp = jax.numpy

output_parts = []
if array.is_deleted():

if isinstance(array, jax.core.Tracer):
output_parts.append(" - tracer.")
elif array.is_deleted():
output_parts.append(" - deleted!")
elif not _is_locally_available(array):
output_parts.append(" - multi-host array!")
elif safe_to_summarize(array):
with jax.core.ensure_compile_time_eval():
is_floating = _is_subdtype(array.dtype, jnp.floating)
is_integer = _is_subdtype(array.dtype, jnp.integer)
is_bool = _is_subdtype(array.dtype, jnp.bool_)

if is_floating:
mean, std, any_finite = jax.jit(_finite_mean_std_any)(array)

if any_finite:
output_parts.append(f" ≈{float(mean):.2} ±{float(std):.2}")
output_parts.append(
f" [≥{float(jnp.nanmin(array)):.2},"
f" ≤{float(jnp.nanmax(array)):.2}]"
)

if is_integer:
output_parts.append(f" [≥{jnp.min(array):_d}, ≤{jnp.max(array):_d}]")

if is_floating or is_integer:
ct_zero = jnp.count_nonzero(array == 0)
if ct_zero:
output_parts.append(f" zero:{ct_zero:_d}")

ct_nonzero = jnp.count_nonzero(array)
if ct_nonzero:
output_parts.append(f" nonzero:{ct_nonzero:_d}")

if is_floating:
ct_nan = jnp.count_nonzero(jnp.isnan(array))
if ct_nan:
output_parts.append(f" nan:{ct_nan:_d}")

ct_inf = jnp.count_nonzero(jnp.isposinf(array))
if ct_inf:
output_parts.append(f" inf:{ct_inf:_d}")

ct_neginf = jnp.count_nonzero(jnp.isneginf(array))
if ct_neginf:
output_parts.append(f" -inf:{ct_neginf:_d}")

if is_bool:
ct_true = jnp.count_nonzero(array)
if ct_true:
output_parts.append(f" true:{ct_true:_d}")

ct_false = jnp.count_nonzero(jnp.logical_not(array))
if ct_false:
output_parts.append(f" false:{ct_false:_d}")
output_parts.extend(_summarize_array_data_unconditionally(array))
else:
output_parts.append("- too large to summarize.")

return "".join(output_parts)


Expand All @@ -427,23 +470,19 @@ def get_array_data_with_truncation(
array: jax.Array,
mask: jax.Array | None,
edge_items_per_axis: tuple[int | None, ...],
) -> tuple[jax.Array, jax.Array]:
) -> tuple[np.ndarray, np.ndarray]:
assert jax is not None, "JAX is not available."
jnp = jax.numpy
assert not isinstance(array, jax.core.Tracer)
assert not array.is_deleted()
if mask is not None:
# Make sure we can broadcast the shape correctly.
_ = jax.eval_shape(lambda: jnp.broadcast_to(mask, array.shape))
mask = mask[(None,) * (array.ndim - mask.ndim) + (...,)]
else:
mask = jnp.ones((1,) * array.ndim, dtype=jnp.bool_)
if mask is None:
mask = np.array(True)

if edge_items_per_axis == (None,) * array.ndim:
# No truncation.
return array, jnp.broadcast_to(mask, array.shape)
return np.array(array), np.broadcast_to(mask, array.shape)

return truncate_array_and_mask(array, mask, edge_items_per_axis)
array, mask = truncate_array_and_mask(array, mask, edge_items_per_axis)
return jax.device_get((array, mask))

def get_array_summary(self, array: jax.Array, fast: bool) -> str:
output_parts = ["jax.Array "]
Expand Down

0 comments on commit 438b9a5

Please sign in to comment.