Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
ayaka14732 committed Jul 30, 2023
1 parent e1605a4 commit 7e51bbb
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 19 deletions.
22 changes: 21 additions & 1 deletion lib/loss/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,21 @@
from .cross_entropy_loss import cross_entropy_loss
from jax import Array
import jax.numpy as jnp
import optax

def cross_entropy_loss(logits: Array, labels: Array, *, mask: Array) -> Array:
'''
Computes the mean softmax cross entropy loss for the provided logits and labels, only considering elements where the mask is True.
Args:
logits (Array): The model's predictions. Typically an array of shape (batch_size, seq_len, vocab_size).
labels (Array): The true labels. Typically an array of shape (batch_size, seq_len).
mask (Array): Typically a boolean array of shape (batch_size, seq_len). Specifies which elements of the loss array should be included in the mean loss calculation.
Returns:
Array: The mean masked softmax cross entropy loss.
Notes:
The mask array typically specifies which elements of the sequence are actual tokens as opposed to padding tokens, so the loss is not calculated over padding tokens.
'''
loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels)
return jnp.mean(loss, where=mask)
7 changes: 0 additions & 7 deletions lib/loss/cross_entropy_loss.py

This file was deleted.

9 changes: 2 additions & 7 deletions lib/multihost_utils/shard_model_params_to_multihost.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,7 @@
embedding=...,
decoder=Decoder(
input_norm=...,
attention=Attention(
q_proj=3,
k_proj=2,
v_proj=2,
out_proj=2,
),
attention=Attention(q_proj=3, k_proj=2, v_proj=2, out_proj=2),
post_attn_norm=...,
gate_proj=2,
up_proj=2,
Expand All @@ -23,5 +18,5 @@
lm_head=...,
)

def shard_model_params_to_multihost(params: Llama):
def shard_model_params_to_multihost(params: Llama) -> Llama:
return tree_apply(shard_array_to_multihost, params, sharding_mp)
17 changes: 16 additions & 1 deletion lib/seeding/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
# best integer for seeding, proposed in https://arxiv.org/abs/2109.08203
MEMPTY = 0
'''
The identity element of the monoid that is the set of integers.
A type `a` is a monoid if it provides an associative function that lets you combine any two values of type `a` into one, and a neutral element (`mempty`) such that
```haskell
a <> mempty == mempty <> a == a.
```
A monoid is a semigroup with the added requirement of a neutral element. Therefore, any monoid is a semigroup, but not the other way around.
'''

BEST_INTEGER = 3407
'''The best integer for seeding, as proposed in https://arxiv.org/abs/2109.08203.'''

BUDDHA = r'''
_oo0oo_
Expand All @@ -23,5 +36,7 @@
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
佛祖保佑 永無 BUG
'''
'''The "May Buddha bless us: no bugs forever" ASCII art. Placing this ASCII art in the codebase is a common practice to prevent bugs and avoid having to debug the code.'''

HASHED_BUDDHA = 3516281645 # hash(BUDDHA) % 2**32
'''The hashed value of the `BUDDHA` string.'''
7 changes: 4 additions & 3 deletions lib/tree_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import jax
import jax.numpy as jnp
from typing import Callable

# https://docs.liesel-project.org/en/v0.1.4/_modules/liesel/goose/pytree.html#stack_leaves
def stack_leaves(pytrees, axis: int=0):
Expand All @@ -24,17 +25,17 @@ def unstack_leaves(pytrees):
pytrees: A PyTree.
Returns:
A list of PyTrees, where each PyTree has the same structure as the input PyTree, but with only one of the original leaves.
A list of PyTrees, where each PyTree has the same structure as the input PyTree, but each leaf contains only one part of the original leaf.
'''
leaves, treedef = jax.tree_util.tree_flatten(pytrees)
return [treedef.unflatten(leaf) for leaf in zip(*leaves, strict=True)]

def tree_apply(func, *pytrees):
def tree_apply(func: Callable, *pytrees):
'''
Apply a function to the leaves of one or more PyTrees.
Args:
func (callable): Function to apply to each leaf. It must take the same number of arguments as there are PyTrees.
func (Callable): Function to apply to each leaf. It must take the same number of arguments as there are PyTrees.
pytrees: One or more PyTrees.
Returns:
Expand Down

0 comments on commit 7e51bbb

Please sign in to comment.