Skip to content

Commit

Permalink
Fix rotary embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
ayaka14732 committed Oct 5, 2023
1 parent 45b396a commit da2b6f7
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions lib/llama/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def forward_rotary_embedding(m: Array) -> Array:

with jax.ensure_compile_time_eval():
sin_val, cos_val = _make_weights(seq_len, d_k)
sin_val = sin_val.astype(m.dtype)
cos_val = cos_val.astype(m.dtype)
assert sin_val.dtype == jnp.float32
assert cos_val.dtype == jnp.float32

n = _rotate_half(m)
a = op.einsum(m, cos_val, '... seq_len d_k, seq_len d_k -> ... seq_len d_k')
b = op.einsum(n, sin_val, '... seq_len d_k, seq_len d_k -> ... seq_len d_k')
a = op.einsum(m, cos_val, '... seq_len d_k, seq_len d_k -> ... seq_len d_k').astype(m.dtype)
b = op.einsum(n, sin_val, '... seq_len d_k, seq_len d_k -> ... seq_len d_k').astype(m.dtype)
return a + b

0 comments on commit da2b6f7

Please sign in to comment.