Implement KV cache (#12)
ayaka14732 authored Oct 7, 2023
1 parent da2b6f7 commit 18e9625
Showing 40 changed files with 235 additions and 81 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
47 changes: 37 additions & 10 deletions README.md
@@ -32,27 +32,28 @@ This project is supported by Cloud TPUs from Google's [TPU Research Cloud](https
- [x] [Llama Model](lib/llama/llama_model.py)
- [x] [Cross entropy loss](lib/loss/cross_entropy_loss.py)
- [x] [Logits processing](lib/logits_processing/)
- [x] [Bias](lib/logits_processing/bias.py)
- [x] [Penalize presence](lib/logits_processing/penalize_presence.py)
- [x] [Penalize frequency](lib/logits_processing/penalize_frequency.py)
- [ ] Generation
- [ ] KV cache
- [x] [KV cache](lib/llama/kv_cache.py)
- [ ] Left padding
- [ ] Beam search
- [ ] Beam sampling
- [x] [Top-_k_ sampling](lib/generation/top_k.py)
- [x] [Top-_p_ sampling](lib/generation/top_p.py)
- [ ] Top-_k_ sampling
- [ ] Top-_p_ sampling
- [x] [Data loading](lib/dataloader/LlamaDataLoader.py)
- [x] Inference
- [x] Training
- [x] Parallelisation
- [ ] Data parallelism
- [x] [Model parallelism](lib/multihost_utils/shard_model_params_to_multihost.py)
- [ ] Other parallelisation schemes
- [ ] Documentation

The documentation for this project's library is published on [GitHub Pages](https://ayaka14732.github.io/llama-2-jax/).

## Environment Setup

This project requires at least Python 3.11, JAX 0.4.16, PyTorch 2.1.0, Optax 0.1.8.dev0 and Transformers 4.32.0.dev0.
This project requires at least Python 3.11, JAX 0.4.18, PyTorch 2.1.0, Optax 0.1.8.dev0 and Transformers 4.32.0.dev0.

PyTorch and Transformers are needed for testing purposes. Additionally, the data loader depends on PyTorch `DataLoader`, while the profiling functionality requires TensorFlow.

@@ -69,6 +70,10 @@ pip install -U pip
pip install -U wheel
```

### Special configuration for TPU Pods

If you are running on TPU Pods, put the IP addresses of all the other hosts in `~/podips.txt` (one IP address per line). In addition, make sure that the local host can SSH into itself and into every host listed in the file.
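
For illustration (the addresses below are placeholders), `~/podips.txt` is a plain list of worker IP addresses, and the SSH requirement can be checked with a short loop:

```sh
# Hypothetical contents of ~/podips.txt (one internal IP address per line):
#   10.130.0.2
#   10.130.0.3
#   10.130.0.4

# Verify that the local host can SSH into itself and into every listed host without a prompt.
for ip in localhost $(cat ~/podips.txt); do
    ssh -o BatchMode=yes "$ip" true && echo "ok: $ip" || echo "cannot reach: $ip"
done
```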

### Install the proper version of JAX

You need to follow the installation instructions on JAX's [official GitHub page](https://github.com/google/jax#installation).
@@ -79,6 +84,18 @@ Typically, you only need to install the CPU version of PyTorch since we perform

To install PyTorch, you can follow the [official installation guide](https://pytorch.org/get-started/locally/).

On TPU VMs, this is usually:

```sh
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
```

On TPU Pods:

```sh
./podrun -i -- ~/venv/bin/pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
```

### Install other dependencies

```sh
@@ -87,6 +104,14 @@ pip install git+https://github.com/deepmind/optax.git # https://github.com/goog
pip install -r requirements.txt
```

On TPU Pods:

```sh
./podrun -i -- ~/venv/bin/pip install git+https://github.com/huggingface/transformers.git
./podrun -i -- ~/venv/bin/pip install git+https://github.com/deepmind/optax.git
./podrun -iw -- ~/venv/bin/pip install -r requirements.txt
```

### Download LLaMA weights

LLaMA 1:
@@ -117,6 +142,12 @@ Alternatively, in case you are not using an interactive shell, you can login in
python -c "from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('<YOUR_HUGGING_FACE_TOKEN>')"
```

On TPU Pods:

```sh
./podrun -i -- ~/venv/bin/python -c "from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('<YOUR_HUGGING_FACE_TOKEN>')"
```

### Convert parameters

```sh
@@ -125,10 +156,6 @@ python scripts/convert_params_runner.py llama2-7B
python scripts/convert_params_runner.py llama2-70B
```

### Special configuration for TPU Pods

If you are running on TPU pods, you need to put the IP address of all other hosts in `~/podips.txt` (one IP address per line). Besides, you should make sure that the local host can SSH into itself and all other hosts listed in the file.

### Generation

```sh
5 changes: 3 additions & 2 deletions hf_evaluate.py
@@ -2,13 +2,14 @@

from functools import partial
from itertools import chain, repeat
import json
from typing import NamedTuple

import jax
import jax.numpy as jnp
import json
import torch
from tqdm import tqdm
from transformers import LlamaConfig, LlamaTokenizer
from typing import NamedTuple

from lib.dataloader import LlamaDataLoader
from lib.gsm_data import GSMDataset
3 changes: 2 additions & 1 deletion lib/data/__init__.py
@@ -1,6 +1,7 @@
from jax import Array
from typing import NamedTuple

from jax import Array

class TrainData(NamedTuple):
seq: Array
seq_mask: Array
3 changes: 2 additions & 1 deletion lib/gsm_data/GSMDataset.py
@@ -1,8 +1,9 @@
import json
import os
from torch.utils.data import Dataset
from typing import Literal, Union

from torch.utils.data import Dataset

def load_data(*, split=Union[Literal['train'], Literal['test']]):
path = os.path.join(f'../grade-school-math/grade_school_math/data/{split}.jsonl')
res = []
1 change: 1 addition & 0 deletions lib/gsm_data/gsm_collate_fn.py
@@ -1,4 +1,5 @@
from itertools import chain, repeat

import jax.numpy as jnp
from transformers import LlamaTokenizer

16 changes: 16 additions & 0 deletions lib/llama/ModelConfig.py
@@ -49,6 +49,22 @@ class ModelConfig(NamedTuple):

model_config_llama2_7B = model_config_llama1_7B

model_config_llama2_13B = ModelConfig(
d_ff=13824,
d_k=128,
d_model=5120,
d_v=128,
dropout_rate=0.1,
n_heads_kv=40,
n_layers=40,
n_rep_kv=1,
rms_norm_eps=1e-6,
token_id_bos=1,
token_id_eos=2,
token_id_pad=0,
vocab_size=32000,
)

model_config_llama2_70B = ModelConfig(
d_ff=28672,
d_k=128,
3 changes: 2 additions & 1 deletion lib/llama/__init__.py
@@ -1,3 +1,4 @@
from .ModelConfig import ModelConfig, model_config_dummy, model_config_llama1_7B, model_config_llama2_70B, model_config_llama2_7B
from .ModelConfig import ModelConfig, model_config_dummy, model_config_llama1_7B, model_config_llama2_13B, model_config_llama2_70B, model_config_llama2_7B
from .kv_cache import KVCache, init_kv_cache
from .llama import Llama, check_llama, forward_llama, init_llama
from .llama_model import LlamaModel, check_llama_model, forward_llama_model, init_llama_model
35 changes: 22 additions & 13 deletions lib/llama/attention.py
@@ -1,14 +1,16 @@
import einops as op
from functools import partial
import math
from typing import Any, NamedTuple

import einops as op
import jax
from jax import Array
import jax.nn as nn
import jax.numpy as jnp
import jax.random as rand
import math
from typing import Any, NamedTuple

from .ModelConfig import ModelConfig
from .kv_cache import KVCache
from .rotary_embedding import forward_rotary_embedding

class Attention(NamedTuple):
@@ -38,21 +40,28 @@ def init_attention(*, key: Array, model_config: ModelConfig) -> Attention:
return Attention(q_proj, k_proj, v_proj, out_proj)

@partial(jax.jit, static_argnames=('model_config',))
def forward_attention(params: Attention, src_seq: Array, dst_seq: Array, attn_mask: Array, *, model_config: ModelConfig) -> Array:
q = op.einsum(src_seq, params.q_proj, 'batch_size src_seq_len d_model, d_model n_rep_kv n_heads_kv d_k -> batch_size n_rep_kv n_heads_kv src_seq_len d_k')
k = op.einsum(dst_seq, params.k_proj, 'batch_size dst_seq_len d_model, d_model n_heads_kv d_k -> batch_size n_heads_kv dst_seq_len d_k')
v = op.einsum(dst_seq, params.v_proj, 'batch_size dst_seq_len d_model, d_model n_heads_kv d_v -> batch_size n_heads_kv dst_seq_len d_v')
def forward_attention(params: Attention, src_seq: Array, dst_seq: Array, qk_mask: Array, *, cache_position: Array | None=None, kv_cache: KVCache | None=None, model_config: ModelConfig) -> tuple[Array, KVCache | None]:
q = op.einsum(src_seq, params.q_proj, 'B S M, M R H K -> B R H S K')
k = op.einsum(dst_seq, params.k_proj, 'B D M, M H K -> B H D K')
v = op.einsum(dst_seq, params.v_proj, 'B D M, M H V -> B H D V')

if cache_position is not None and kv_cache is not None:
k_cache, v_cache = kv_cache
start_indices = jnp.array([0, 0, cache_position, 0], dtype=jnp.uint16)
k = jax.lax.dynamic_update_slice(k_cache, k, start_indices=start_indices)
v = jax.lax.dynamic_update_slice(v_cache, v, start_indices=start_indices)
kv_cache = KVCache(k, v)

q = forward_rotary_embedding(q)
k = forward_rotary_embedding(k)

qk = op.einsum(q, k, 'batch_size n_rep_kv n_heads_kv src_seq_len d_k, batch_size n_heads_kv dst_seq_len d_k -> batch_size n_rep_kv n_heads_kv src_seq_len dst_seq_len')
qk = op.einsum(q, k, 'B R H S K, B H D K -> B R H S D')
qk /= math.sqrt(model_config.d_k)
qk = jnp.where(attn_mask, qk, -jnp.inf)
qk = jnp.where(qk_mask, qk, -jnp.inf)
qk = nn.softmax(qk)
qk = jnp.where(attn_mask, qk, 0) # TODO: why this line?
qk = jnp.where(qk_mask, qk, 0) # TODO: why this line?

qkv = op.einsum(qk, v, 'batch_size n_rep_kv n_heads_kv src_seq_len dst_seq_len, batch_size n_heads_kv dst_seq_len d_v -> batch_size n_rep_kv n_heads_kv src_seq_len d_v')
qkv = op.einsum(qk, v, 'B R H S D, B H D V -> B R H S V')

out = op.einsum(qkv, params.out_proj, 'batch_size n_rep_kv n_heads_kv src_seq_len d_v, n_rep_kv n_heads_kv d_v d_model -> batch_size src_seq_len d_model')
return out
out = op.einsum(qkv, params.out_proj, 'B R H S V, R H V M -> B S M')
return out, kv_cache
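
The cache update in `forward_attention` relies on `jax.lax.dynamic_update_slice` to write the freshly computed K/V blocks into a pre-allocated buffer at `cache_position`. A minimal, self-contained sketch of that pattern, with toy shapes rather than the model's real dimensions:

```python
import jax
import jax.numpy as jnp

# Toy shapes for illustration only.
batch_size, n_heads_kv, max_dst_len, d_k = 2, 4, 16, 8

k_cache = jnp.zeros((batch_size, n_heads_kv, max_dst_len, d_k))  # pre-allocated cache
k_new = jnp.ones((batch_size, n_heads_kv, 1, d_k))                # keys for one new token
cache_position = 5                                                # write offset on the sequence axis

# Write the new block into the cache at the given position; every other entry is left untouched.
k_cache = jax.lax.dynamic_update_slice(k_cache, k_new, (0, 0, cache_position, 0))
print(k_cache[0, 0, 4:7, 0])  # [0. 1. 0.] -- only position 5 was updated
```

Because the cache shape is fixed up front, only the start index is dynamic, which keeps the update compatible with `jax.jit`.
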
17 changes: 12 additions & 5 deletions lib/llama/decoder.py
@@ -1,12 +1,15 @@
from functools import partial

import jax
from jax import Array
import jax.numpy as jnp
import jax.random as rand

from ..rand_utils import split_key_nullable
from ..tree_utils import stack_leaves
from .ModelConfig import ModelConfig
from .decoder_block import DecoderBlock, DecoderBlock as Decoder, check_decoder_block, forward_decoder_block, init_decoder_block
from .kv_cache import KVCache

def check_decoder(params: Decoder, *, model_config: ModelConfig) -> None:
def inner(state, input_):
@@ -19,11 +22,15 @@ def init_decoder(*, key: Array, model_config: ModelConfig) -> Decoder:
return stack_leaves([init_decoder_block(key=subkey, model_config=model_config) for subkey in rand.split(key, num=model_config.n_layers)])

@partial(jax.jit, static_argnames=('model_config',))
def forward_decoder(params: Decoder, seq: Array, attn_mask: Array, *, key: Array | None, model_config: ModelConfig) -> Array:
def forward_decoder(params: Decoder, seq: Array, attn_mask: Array, *, cache_position: Array | None=None, kv_cache: KVCache | None=None, key: Array | None=None, model_config: ModelConfig) -> tuple[Array, KVCache | None]:
qk_mask = jnp.tril(jnp.einsum('bi,bj->bij', attn_mask, attn_mask))[:, None, None]

def inner(state, input_):
key, seq = state
params, kv_cache = input_
key, subkey = split_key_nullable(key)
seq = forward_decoder_block(input_, seq, attn_mask, key=subkey, model_config=model_config)
return (key, seq), None
(key, seq), _ = jax.lax.scan(inner, (key, seq), params)
return seq
seq, kv_cache = forward_decoder_block(params, seq, qk_mask, cache_position=cache_position, kv_cache=kv_cache, key=subkey, model_config=model_config)
return (key, seq), kv_cache

(key, seq), kv_cache = jax.lax.scan(inner, (key, seq), (params, kv_cache))
return seq, kv_cache
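
`forward_decoder` now scans over the stacked layer parameters and the stacked per-layer caches together, carrying `(key, seq)` and collecting each layer's updated cache as the scan output. A rough sketch of that carry/stack pattern, using stand-in arrays instead of real decoder blocks:

```python
import jax
import jax.numpy as jnp

n_layers, d_model = 3, 4
seq = jnp.ones((d_model,))
layer_scales = jnp.arange(1.0, n_layers + 1.0)  # stands in for the stacked layer params
layer_caches = jnp.zeros((n_layers, d_model))   # stands in for the stacked KV cache

def inner(carry, input_):
    seq = carry
    scale, _cache = input_        # one layer's params and cache slice
    seq = seq * scale             # stands in for forward_decoder_block
    return seq, seq               # new carry, per-layer output that scan stacks back up

seq, new_caches = jax.lax.scan(inner, seq, (layer_scales, layer_caches))
print(seq, new_caches.shape)      # final activations and a (n_layers, d_model) stack
```

The real code additionally carries the PRNG key alongside `seq` so each layer gets its own dropout key via `split_key_nullable`.
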
14 changes: 8 additions & 6 deletions lib/llama/decoder_block.py
@@ -1,14 +1,16 @@
from functools import partial
import math
from typing import Any, NamedTuple

import jax
from jax import Array
import jax.random as rand
import math
from typing import Any, NamedTuple

from ..rand_utils import split_key_nullable
from .attention import Attention, check_attention, forward_attention, init_attention
from .ModelConfig import ModelConfig
from .attention import Attention, check_attention, forward_attention, init_attention
from .dropout import forward_dropout
from .kv_cache import KVCache
from .rms_norm import check_rms_norm, forward_rms_norm, init_rms_norm

class DecoderBlock(NamedTuple):
@@ -46,12 +48,12 @@ def init_decoder_block(*, key: Array, model_config: ModelConfig) -> DecoderBlock
return DecoderBlock(input_norm, attention, post_attn_norm, gate_proj, up_proj, down_proj)

@partial(jax.jit, static_argnames=('model_config',))
def forward_decoder_block(params: DecoderBlock, seq: Array, attn_mask: Array, *, key: Array | None, model_config: ModelConfig) -> Array:
def forward_decoder_block(params: DecoderBlock, seq: Array, qk_mask: Array, *, cache_position: Array | None=None, kv_cache: KVCache | None=None, key: Array | None=None, model_config: ModelConfig) -> tuple[Array, KVCache | None]:
key0, key1, key2 = split_key_nullable(key, num=3)

seq_ = seq
seq = forward_rms_norm(params.input_norm, seq, model_config=model_config)
seq = forward_attention(params.attention, seq, seq, attn_mask, model_config=model_config)
seq, kv_cache = forward_attention(params.attention, seq, seq, qk_mask, cache_position=cache_position, kv_cache=kv_cache, model_config=model_config)
seq = forward_dropout(seq, key=key0, model_config=model_config)
seq += seq_

@@ -63,4 +65,4 @@ def forward_decoder_block(params: DecoderBlock, seq: Array, attn_mask: Array, *,
seq = forward_dropout(seq, key=key2, model_config=model_config)
seq += seq_

return seq
return seq, kv_cache
3 changes: 2 additions & 1 deletion lib/llama/dropout.py
@@ -1,12 +1,13 @@
from functools import partial

import jax
from jax import Array
import jax.random as rand

from .ModelConfig import ModelConfig

@partial(jax.jit, static_argnames=('model_config',))
def forward_dropout(x: Array, *, key: Array | None, model_config: ModelConfig) -> Array:
def forward_dropout(x: Array, *, key: Array | None=None, model_config: ModelConfig) -> Array:
if key is None or model_config.dropout_rate is None: # should disable dropout
return x

3 changes: 2 additions & 1 deletion lib/llama/embedding.py
@@ -1,6 +1,7 @@
import math

from jax import Array
import jax.random as rand
import math

from .ModelConfig import ModelConfig

14 changes: 14 additions & 0 deletions lib/llama/kv_cache.py
@@ -0,0 +1,14 @@
from typing import Any, NamedTuple

import jax.numpy as jnp

from .ModelConfig import ModelConfig

class KVCache(NamedTuple):
k_cache: Any # Array
v_cache: Any # Array

def init_kv_cache(batch_size: int, dst_len: int, *, model_config: ModelConfig) -> KVCache:
k_cache = jnp.zeros((model_config.n_layers, batch_size, model_config.n_heads_kv, dst_len, model_config.d_k))
v_cache = jnp.zeros((model_config.n_layers, batch_size, model_config.n_heads_kv, dst_len, model_config.d_v))
return KVCache(k_cache, v_cache)
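
A rough usage sketch for the new module (it assumes the repository root is on `PYTHONPATH`; the concrete numbers below assume the usual 7B configuration of 32 layers, 32 KV heads and head dimension 128):

```python
# Allocate a zero-filled cache for a batch of 2 sequences and a maximum decoded length of 640 tokens.
from lib.llama import init_kv_cache, model_config_llama2_7B

kv_cache = init_kv_cache(2, 640, model_config=model_config_llama2_7B)

# The layout is (n_layers, batch_size, n_heads_kv, dst_len, d_k or d_v);
# under the assumed 7B configuration that is (32, 2, 32, 640, 128).
print(kv_cache.k_cache.shape, kv_cache.v_cache.shape)
```

The leading `n_layers` axis is what lets `forward_decoder` hand each layer its own cache slice inside the `jax.lax.scan` above.
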
14 changes: 8 additions & 6 deletions lib/llama/llama.py
@@ -1,11 +1,13 @@
from functools import partial
import math
from typing import Any, NamedTuple

import jax
from jax import Array
import jax.random as rand
import math
from typing import Any, NamedTuple

from .ModelConfig import ModelConfig
from .kv_cache import KVCache
from .llama_model import LlamaModel, check_llama_model, forward_llama_model, init_llama_model

class Llama(NamedTuple):
@@ -27,7 +29,7 @@ def init_llama(*, key: Array, model_config: ModelConfig) -> Llama:
return Llama(model, lm_head)

@partial(jax.jit, static_argnames=('model_config'))
def forward_llama(params: Llama, seq: Array, attn_mask: Array, *, key: Array | None, model_config: ModelConfig) -> Array:
outputs = forward_llama_model(params.model, seq, attn_mask, key=key, model_config=model_config)
logits = outputs @ params.lm_head
return logits
def forward_llama(params: Llama, seq: Array, attn_mask: Array, *, cache_position: Array | None=None, kv_cache: KVCache | None=None, key: Array | None=None, model_config: ModelConfig) -> tuple[Array, KVCache | None]:
outputs, kv_cache = forward_llama_model(params.model, seq, attn_mask, cache_position=cache_position, kv_cache=kv_cache, key=key, model_config=model_config)
logits = outputs @ params.lm_head
return logits, kv_cache