Commit

Update
Update parameter description and update based on JEP 9263
ayaka14732 committed Sep 25, 2023
1 parent 8f8fc94 commit 45b396a
Showing 5 changed files with 50 additions and 27 deletions.
37 changes: 16 additions & 21 deletions README.md
@@ -172,35 +172,30 @@ On TPU pods, the command is:
 ## Model Configurations
 
 - _B_: batch_size
-- _K_: d_k
-- _V_: d_v
-- _F_: d_ff
-- _M_: d_model
-- _R_: n_rep_kv
-- _H_: n_heads_kv
 - _L_: seq_len
 - _S_: src_seq_len
 - _D_: dst_seq_len
 - _C_: vocab_size
 - _N_: n_layers
+- _K_: d_k
+- _V_: d_v
+- _H_: n_heads_kv
+- _R_: n_rep_kv
+- _M_: d_model
+- _F_: d_ff
 
-| Name | Parameters | _C_ | _N_ | _H_ | _R_ | _M_ | _F_ |
-| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
-| LLaMA 1 7B | 6738415616 | 32000 | 32 | 32 | 1 | 4096 | 11008 |
-| LLaMA 1 13B | | 32000 | 40 | 40 | 1 | 5120 | |
-| LLaMA 1 33B | | 32000 | 60 | 52 | 1 | 6656 | |
-| LLaMA 1 65B | | 32000 | 80 | 64 | 1 | 8192 | |
-| Llama 2 7B | 6738415616 | 32000 | 32 | 32 | 1 | 4096 | 11008 |
-| Llama 2 13B | | 32000 | | | | | |
-| Llama 2 70B | | 32000 | 80 | 8 | 8 | 8192 | 28672 |
+| Name | Parameters | _C_ | _N_ | _K_/_V_ | _H_ | _R_ | _M_ | _F_ |
+| -: | -: | -: | -: | -: | -: | -: | -: | -: |
+| LLaMA 1 7B | 6738415616 | 32000 | 32 | 128 | 32 | 1 | 4096 | 11008 |
+| Llama 2 7B | 6738415616 | 32000 | 32 | 128 | 32 | 1 | 4096 | 11008 |
+| LLaMA 1 13B | | 32000 | 40 | | 40 | 1 | 5120 | |
+| Llama 2 13B | 13015864320 | 32000 | 40 | 128 | 40 | 1 | 5120 | 13824 |
+| LLaMA 1 33B | | 32000 | 60 | | 52 | 1 | 6656 | |
+| LLaMA 1 65B | | 32000 | 80 | | 64 | 1 | 8192 | |
+| Llama 2 70B | 68976648192 | 32000 | 80 | 128 | 8 | 8 | 8192 | 28672 |
 
 ```
-n_parameters
-= 2 * vocab_size * d_model
-+ (2 * n_layers + 1) * d_model
-+ 2 * n_layers * d_model * n_rep_kv * n_heads_kv * d_k
-+ 2 * n_layers * d_model * n_heads_kv * d_k
-+ 3 * n_layers * d_model * d_ff
+n_params = 2CM + (2N + 1)M + 2NMRHK + 2NMHK + 3NMF
 ```
 
 ## Model Architecture
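As a quick sanity check on the compact formula just added to the README, the filled-in Parameters cells in the new table can be reproduced from the other columns alone (a standalone sketch using only values from the table, not code from this repository):

```python
def n_params(C, N, K, H, R, M, F):
    # n_params = 2CM + (2N + 1)M + 2NMRHK + 2NMHK + 3NMF
    # 2CM: token embedding + LM head; (2N + 1)M: RMSNorm scales;
    # 2NMRHK: Q and output projections; 2NMHK: K and V projections; 3NMF: gate/up/down MLP weights.
    return 2 * C * M + (2 * N + 1) * M + 2 * N * M * R * H * K + 2 * N * M * H * K + 3 * N * M * F

assert n_params(32000, 32, 128, 32, 1, 4096, 11008) == 6738415616    # LLaMA 1 7B / Llama 2 7B
assert n_params(32000, 40, 128, 40, 1, 5120, 13824) == 13015864320   # Llama 2 13B
assert n_params(32000, 80, 128, 8, 8, 8192, 28672) == 68976648192    # Llama 2 70B
```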
2 changes: 1 addition & 1 deletion evaluate.py
@@ -21,7 +21,7 @@ def main() -> None:
     if is_process_0:
         jax_smi.initialise_tracking(interval=0.5)
 
-    key = rand.PRNGKey(BEST_INTEGER)
+    key = rand.key(BEST_INTEGER)
     max_len = 640
     batch_size = 2
     seed = HASHED_BUDDHA
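The `rand.PRNGKey` to `rand.key` switch above follows JAX's JEP 9263 typed PRNG keys: `jax.random.key` returns a scalar array with a key dtype, whereas `jax.random.PRNGKey` returns the legacy uint32 key of shape (2,). Both work with the same sampling APIs; a minimal sketch (assuming `rand` is `jax.random`, as the call sites suggest):

```python
import jax.random as rand

legacy_key = rand.PRNGKey(42)   # legacy style: uint32 array of shape (2,)
key = rand.key(42)              # new style (JEP 9263): scalar array with a key dtype

# Both styles are used the same way downstream.
key, subkey = rand.split(key)
x = rand.normal(subkey, (3,))

# The underlying bits of a typed key can still be recovered if needed.
raw = rand.key_data(key)        # shape (2,) for the default threefry2x32 implementation
```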
2 changes: 1 addition & 1 deletion generate.py
@@ -30,7 +30,7 @@ def main() -> None:
     print(jax.devices)
     initialise_tracking()
 
-    key = rand.PRNGKey(BEST_INTEGER)
+    key = rand.key(BEST_INTEGER)
     cpu_device = jax.devices('cpu')[0]
     with jax.default_device(cpu_device):
         params = load_params('llama2-7B.pickle')
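The unchanged context in generate.py shows the checkpoint being unpickled under a CPU default device before anything is placed on accelerators. A minimal illustration of that host-first pattern (stand-in array and shape, not the repository's actual loader):

```python
import jax
import jax.numpy as jnp

# Materialise large weights on the host so they never have to fit on a single accelerator.
cpu_device = jax.devices('cpu')[0]
with jax.default_device(cpu_device):
    embedding = jnp.zeros((32000, 4096))   # stand-in for a loaded weight

# Then move (or shard) the arrays onto the accelerator backend explicitly.
embedding = jax.device_put(embedding, jax.devices()[0])
```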
28 changes: 28 additions & 0 deletions scripts/determine_params.py
@@ -0,0 +1,28 @@
+from pathlib import Path; import sys; sys.path.append(str(Path(__file__).resolve().parent.parent))
+from lib.proc_init_utils import initialise_cpu; initialise_cpu()
+
+from transformers import LlamaForCausalLM
+
+model = LlamaForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')
+
+q_size = model.model.layers[0].self_attn.q_proj.weight.shape[0] * model.model.layers[0].self_attn.q_proj.weight.shape[1]
+k_size = model.model.layers[0].self_attn.k_proj.weight.shape[0] * model.model.layers[0].self_attn.k_proj.weight.shape[1]
+
+c = model.model.embed_tokens.weight.shape[0]
+f = model.model.layers[0].mlp.gate_proj.weight.shape[0]
+m = model.model.layers[0].mlp.gate_proj.weight.shape[1]
+n = len(model.model.layers)
+r = q_size // k_size
+h = model.config.num_attention_heads // r
+k = q_size // m // h
+
+n_params = sum(x.numel() for x in model.parameters())
+
+print('C', c)
+print('N', n)
+print('K', k)
+print('H', h)
+print('R', r)
+print('M', m)
+print('F', f)
+print('n_params', n_params)
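One caveat about the shape arithmetic above: `q_proj.weight` has shape (R*H*K, M) while `k_proj.weight` has shape (H*K, M), so `k = q_size // m // h` evaluates to R*K rather than K whenever R > 1. That is harmless for the 7B checkpoint the script targets (R = 1), but for a grouped-query model such as Llama 2 70B the head dimension is better derived from the K projection. A standalone sketch using the 70B values from the README table:

```python
# Llama 2 70B: M = 8192, K = 128, H = 8 KV heads, R = 8 query heads per KV head.
M, K, H, R = 8192, 128, 8, 8
q_size = (R * H * K) * M          # size of q_proj.weight
k_size = (H * K) * M              # size of k_proj.weight

r = q_size // k_size              # 8
h = 64 // r                       # 8  (num_attention_heads = R * H = 64)
assert q_size // M // h == R * K  # 1024: the script's expression over-counts by a factor of R
assert k_size // M // h == K      # 128: dividing the K-projection size recovers K directly
```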
8 changes: 4 additions & 4 deletions train.py
@@ -30,7 +30,7 @@
 def load_params_from_disk(path: str) -> Llama:
     cpu_device = jax.devices('cpu')[0]
     with jax.default_device(cpu_device):
-        # params = init_llama(key=rand.PRNGKey(42), model_config=model_config_dummy)
+        # params = init_llama(key=rand.key(42), model_config=model_config_dummy)
         params = load_params(path)
         params = shard_model_params(params)
     return params
@@ -52,8 +52,8 @@ def save_params_to_disk() -> None:
 
 def save_params_signal_handler(signum, frame):
     save_params_to_disk()
-    print(f'Signal {signum} received. Params have been saved to disk.')
-    exit(0)
+    print(f'Signal {signum} received. Model params have been successfully saved to disk.')
+    exit(-1)
 
 @jax.value_and_grad
 def train_forward(params: Llama, data_batch: TrainData, *, key: Array):
@@ -89,7 +89,7 @@ def main() -> None:
     if is_process_0:
         wandb.init(project='llama-finetuning-gsm', config=dict(learning_rate=lr, batch_size=batch_size * n_accumulation_steps, n_epochs=n_epochs, optimiser='adamw'))
 
-    key = rand.PRNGKey(seed)
+    key = rand.key(seed)
     tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
     dataset = GSMDataset(split='train')
     collate_fn = partial(gsm_collate_fn_train, tokenizer, max_len)
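The `(signum, frame)` signature above is what Python's signal module passes to a handler; the hunk does not show where the handler is registered, so the wiring below is only an assumed illustration:

```python
import signal

# Hypothetical registration: save params, then exit, when the training job is interrupted.
signal.signal(signal.SIGINT, save_params_signal_handler)
signal.signal(signal.SIGTERM, save_params_signal_handler)
```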
