Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ntxent fix #946

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions optax/losses/_self_supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,45 @@
from jax import lax
import jax.numpy as jnp
from optax.losses import _regression
import numpy as np


def ntxent(
embeddings: chex.Array, labels: chex.Array, temperature: chex.Numeric = 0.07
) -> chex.Numeric:
"""Normalized temperature scaled cross entropy loss (NT-Xent).

Examples:
>>> import jax
>>> import optax
>>> import jax.numpy as jnp
>>>
>>> key = jax.random.key(42)
>>> key1, key2, key3 = jax.random.split(key, 3)
>>> x = jax.random.normal(key1, shape=(4,2))
>>> labels = jnp.array([0, 0, 1, 1])
>>>
>>> print("input:", x)
input: [[-0.9155995 1.5534698 ]
[ 0.2623586 -1.5908985 ]
[-0.15977189 0.480501 ]
[ 0.58389133 0.10497775]]
>>> print("labels:", labels)
labels: [0 0 1 1]
>>>
>>> w = jax.random.normal(key2, shape=(2,1)) # params
>>> b = jax.random.normal(key3, shape=(1,)) # params
>>> out = x @ w + b # model
>>>
>>> print("Embeddings:", out)
Embeddings: [[-1.0076267]
[-1.2960069]
[-1.1829865]
[-1.3485558]]
>>> loss = optax.ntxent(out, labels)
>>> print("loss:", loss)
loss: 1.0986123

References:
T. Chen et al `A Simple Framework for Contrastive Learning of Visual
Representations <http://arxiv.org/abs/2002.05709>`_, 2020
Expand All @@ -34,8 +66,7 @@ def ntxent(
embeddings: batch of embeddings, with shape [batch, feature_length]
labels: labels for groups that are positive pairs. e.g. if you have a batch
of 4 embeddings and the first two and last two were positive pairs your
`labels` should look like [0, 0, 1, 1]. labels SHOULD NOT be all the same
(e.g. [0, 0, 0, 0]) you will get a NaN result. Shape [batch]
`labels` should look like [0, 0, 1, 1]. Shape [batch]
temperature: temperature scaling parameter.

Returns:
Expand All @@ -55,7 +86,8 @@ def ntxent(
# cosine similarity matrix
xcs = (
_regression.cosine_similarity(
embeddings[None, :, :], embeddings[:, None, :]
embeddings[None, :, :], embeddings[:, None, :],
epsilon=np.finfo(embeddings.dtype).eps
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to import numpy, you can do the same with jnp instead of np

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it! Will change shortly.

)
/ temperature
)
Expand Down
13 changes: 13 additions & 0 deletions optax/losses/_self_supervised_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,18 @@ def setUp(self):
[1.8745, -0.0195],
[-0.6719, -1.9210],
])
self.ys_2 = jnp.array([
[0.0, 0.0],
[ 0.2380, -0.5703],
[ 1.8745, -0.0195],
[-0.6719, -1.9210],
])
self.ts_1 = jnp.array([0, 0, 1, 1])
self.ts_2 = jnp.array([0, 0, 0, 1])
# Calculated expected output
self.exp_1 = jnp.array(14.01032)
self.exp_2 = jnp.array(8.968544)
self.exp_3 = jnp.array(9.2889)

@chex.all_variants
def test_batched(self):
Expand All @@ -52,6 +59,12 @@ def test_batched(self):
atol=1e-4,
)

np.testing.assert_allclose(
self.variant(_self_supervised.ntxent)(self.ys_2, self.ts_1),
self.exp_3,
atol=1e-4,
)


if __name__ == '__main__':
absltest.main()
Loading