
FR: Dimension dependent axis-item labels for treescope.render_array #3

Open
amifalk opened this issue Jul 26, 2024 · 2 comments
Labels
feature-request New feature or request

Comments

@amifalk

amifalk commented Jul 26, 2024

Currently, the axis-item labels for an array cannot depend on multiple dimensions, but sometimes the meaning of an axis's items depends on the position along another dimension that is sliced into.

This interpretation should be coherent as long as the named axes of the axis-item-label array are a subset of the axes of the array being rendered.
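To make the subset condition concrete, here is a minimal sketch in plain Python. It assumes, hypothetically, that the label array is stored as nested lists together with an ordered tuple of its axis names; `lookup_label` is an illustrative helper, not part of treescope or Penzai. Because every axis of the label array also exists on the rendered array, each cell position of the rendered array picks out exactly one label, and positions along the remaining axes are simply ignored.

```python
from typing import Any, Mapping, Sequence


def lookup_label(
    labels: Any,
    label_axes: Sequence[str],
    array_axes: Sequence[str],
    position: Mapping[str, int],
) -> Any:
    """Returns the label for one cell of the rendered array.

    Hypothetical helper: `labels` is a nested list whose nesting order
    follows `label_axes`, and `position` maps each rendered axis name to
    an integer index.
    """
    # The lookup is only well defined if the label axes are a subset
    # of the axes of the array being rendered.
    missing = set(label_axes) - set(array_axes)
    if missing:
        raise ValueError(f"label axes not in rendered array: {sorted(missing)}")
    # Index into the nested list using only the axes the labels depend on;
    # positions along any other rendered axes are ignored.
    result = labels
    for name in label_axes:
        result = result[position[name]]
    return result


# Labels for "y" that change with the "offset" slice (offset-major nesting).
labels = [[f"offset{o}-y{y}" for y in range(4)] for o in range(3)]
print(lookup_label(labels, ("offset", "y"), ("offset", "x", "y"),
                   {"offset": 2, "x": 0, "y": 1}))  # offset2-y1
```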

In general, it would be great if axis_item_labels could accept Penzai named arrays or JAX arrays (internally casting each to a list or equivalent so that the object is JSON-serializable).

```python
import treescope
from penzai import pz

# Proposed (not currently supported): dimension-dependent axis-item labels.
# This should render different labels for each slice of `offset`.
axis_item_labels = pz.nx.ones({"x": 10}) * pz.nx.arange("y", 10) + pz.nx.arange("offset", 3)
treescope.render_array(pz.nx.ones({"offset": 3, "x": 10, "y": 10}), axis_item_labels=axis_item_labels)

# Proposed (not currently supported): item labels for x and y given as named arrays.
treescope.render_array(pz.nx.ones({"x": 10, "y": 10}), axis_item_labels=[pz.nx.arange("x", 10), pz.nx.arange("y", 10)])
```
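The internal cast to lists mentioned above could look roughly like the following minimal sketch. It assumes NumPy is available and that `np.asarray` can convert the label array (for a JAX array this transfers the data to host memory; a Penzai named array would first need to be converted to a positional array, which is not shown). `labels_to_json_lists` is a hypothetical helper name, and converting every element to `str` is an assumption about how labels would be displayed.

```python
import json

import numpy as np


def labels_to_json_lists(labels):
    """Converts an array of axis-item labels into JSON-serializable nested lists.

    Hypothetical helper: every element is turned into a Python string.
    """
    arr = np.asarray(labels)
    # An np.ndarray is not JSON-serializable, but .tolist() returns nested
    # Python lists of built-in scalars (int, float, str, ...).
    nested = arr.tolist()

    def to_str(x):
        # Recurse through the nested lists and stringify the leaves.
        return [to_str(v) for v in x] if isinstance(x, list) else str(x)

    return to_str(nested)


# Labels that depend on two dimensions (3 x 4), built with broadcasting.
labels = np.arange(3)[:, None] + np.arange(4)[None, :]
print(json.dumps(labels_to_json_lists(labels)))
```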
danieldjohnson added the feature-request label on Jul 31, 2024
@danieldjohnson
Collaborator

Thanks for the suggestions! I recently ran into a similar situation myself.

I think allowing axis item labels to depend on multiple dimensions would definitely be possible. Unfortunately, I don't think JAX lets you store strings in arrays, so using arrays to specify axis labels seems too limiting.

Another option would be to allow the keys of axis_item_labels to be tuples of axis names, with the values given as nested lists. Perhaps something like:

```python
treescope.render_array(..., axis_item_labels={
    "x": ["a", "b", "c"],
    ("y", "z"): [["A", "B", "C"], ["D", "E", "F"]],
})
```

which would associate each y-z pair with a label. (This would probably require changes to how the axis item labels are actually shown, though.)
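One thing a renderer would likely need for the tuple-keyed format proposed above is a shape check. The sketch below is a hypothetical validator, not an existing treescope function: it assumes the array's axis sizes are available as a mapping from axis name to size, and that a tuple key's nested lists are indexed in key order (so `("y", "z")` means the outer list runs over `y` and the inner lists over `z`).

```python
from typing import Any, Mapping


def check_label_shapes(axis_item_labels: Mapping[Any, Any], named_shape: Mapping[str, int]) -> None:
    """Checks that each label entry matches the array's axis sizes.

    Hypothetical helper: a string key must map to a list with one label
    per item of that axis; a tuple key must map to nested lists whose
    sizes match the named axes in key order.
    """
    for key, labels in axis_item_labels.items():
        names = (key,) if isinstance(key, str) else tuple(key)
        # Walk the nested lists level by level, one axis per level.
        level = [labels]
        for name in names:
            if name not in named_shape:
                raise ValueError(f"unknown axis {name!r} in key {key!r}")
            expected = named_shape[name]
            next_level = []
            for row in level:
                if not isinstance(row, list) or len(row) != expected:
                    raise ValueError(f"key {key!r}: axis {name!r} needs {expected} entries")
                next_level.extend(row)
            level = next_level


# The example from the comment above, with x of size 3, y of size 2, z of size 3.
check_label_shapes(
    {"x": ["a", "b", "c"], ("y", "z"): [["A", "B", "C"], ["D", "E", "F"]]},
    {"x": 3, "y": 2, "z": 3},
)  # passes silently
```

Checking each level separately makes ragged nested lists (rows of different lengths) fail early, which matters if the renderer looks up a label for every y-z pair.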

@amifalk
Author

amifalk commented Aug 2, 2024

That looks great! Since NumPy arrays do support strings, accepting them as input could be a nice-to-have too, but I could always write an internal wrapper if not.
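A wrapper of the kind mentioned here could be small. The sketch below assumes the tuple-keyed format proposed earlier in the thread is adopted and that the caller supplies one axis name per NumPy dimension, in order; `numpy_axis_labels` is a hypothetical name, not part of treescope.

```python
import numpy as np


def numpy_axis_labels(labels: np.ndarray, axis_names: tuple) -> dict:
    """Wraps a NumPy string array as one entry of a tuple-keyed label dict.

    Hypothetical helper: `axis_names[i]` names dimension i of `labels`.
    """
    if labels.ndim != len(axis_names):
        raise ValueError(f"expected {labels.ndim} axis names, got {len(axis_names)}")
    # NumPy string arrays convert to nested lists of Python str via tolist().
    nested = labels.astype(str).tolist()
    # A single axis uses a plain string key; several axes use a tuple key.
    key = axis_names[0] if len(axis_names) == 1 else tuple(axis_names)
    return {key: nested}


yz = np.array([["A", "B", "C"], ["D", "E", "F"]])
print(numpy_axis_labels(yz, ("y", "z")))
# {('y', 'z'): [['A', 'B', 'C'], ['D', 'E', 'F']]}
```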
