diff --git a/README.md b/README.md
index d2042ff..9870fed 100644
--- a/README.md
+++ b/README.md
@@ -172,35 +172,30 @@ On TPU pods, the command is:
 ## Model Configurations
 
 - _B_: batch_size
-- _K_: d_k
-- _V_: d_v
-- _F_: d_ff
-- _M_: d_model
-- _R_: n_rep_kv
-- _H_: n_heads_kv
 - _L_: seq_len
 - _S_: src_seq_len
 - _D_: dst_seq_len
 - _C_: vocab_size
 - _N_: n_layers
+- _K_: d_k
+- _V_: d_v
+- _H_: n_heads_kv
+- _R_: n_rep_kv
+- _M_: d_model
+- _F_: d_ff
 
-| Name | Parameters | _C_ | _N_ | _H_ | _R_ | _M_ | _F_ |
-| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
-| LLaMA 1 7B | 6738415616 | 32000 | 32 | 32 | 1 | 4096 | 11008 |
-| LLaMA 1 13B | | 32000 | 40 | 40 | 1 | 5120 | |
-| LLaMA 1 33B | | 32000 | 60 | 52 | 1 | 6656 | |
-| LLaMA 1 65B | | 32000 | 80 | 64 | 1 | 8192 | |
-| Llama 2 7B | 6738415616 | 32000 | 32 | 32 | 1 | 4096 | 11008 |
-| Llama 2 13B | | 32000 | | | | | |
-| Llama 2 70B | | 32000 | 80 | 8 | 8 | 8192 | 28672 |
+| Name | Parameters | _C_ | _N_ | _K_/_V_ | _H_ | _R_ | _M_ | _F_ |
+| -: | -: | -: | -: | -: | -: | -: | -: | -: |
+| LLaMA 1 7B | 6738415616 | 32000 | 32 | 128 | 32 | 1 | 4096 | 11008 |
+| Llama 2 7B | 6738415616 | 32000 | 32 | 128 | 32 | 1 | 4096 | 11008 |
+| LLaMA 1 13B | | 32000 | 40 | | 40 | 1 | 5120 | |
+| Llama 2 13B | 13015864320 | 32000 | 40 | 128 | 40 | 1 | 5120 | 13824 |
+| LLaMA 1 33B | | 32000 | 60 | | 52 | 1 | 6656 | |
+| LLaMA 1 65B | | 32000 | 80 | | 64 | 1 | 8192 | |
+| Llama 2 70B | 68976648192 | 32000 | 80 | 128 | 8 | 8 | 8192 | 28672 |
 
 ```
-n_parameters
-= 2 * vocab_size * d_model
-+ (2 * n_layers + 1) * d_model
-+ 2 * n_layers * d_model * n_rep_kv * n_heads_kv * d_k
-+ 2 * n_layers * d_model * n_heads_kv * d_k
-+ 3 * n_layers * d_model * d_ff
+n_params = 2CM + (2N + 1)M + 2NMRHK + 2NMHK + 3NMF
 ```
 
 ## Model Architecture
diff --git a/evaluate.py b/evaluate.py
index 7915da5..28d42d4 100644
--- a/evaluate.py
+++ b/evaluate.py
@@ -21,7 +21,7 @@ def main() -> None:
     if is_process_0:
         jax_smi.initialise_tracking(interval=0.5)
 
-    key = rand.PRNGKey(BEST_INTEGER)
+    key = rand.key(BEST_INTEGER)
     max_len = 640
     batch_size = 2
     seed = HASHED_BUDDHA
diff --git a/generate.py b/generate.py
index 48e8717..a32fbce 100644
--- a/generate.py
+++ b/generate.py
@@ -30,7 +30,7 @@ def main() -> None:
     print(jax.devices)
     initialise_tracking()
 
-    key = rand.PRNGKey(BEST_INTEGER)
+    key = rand.key(BEST_INTEGER)
     cpu_device = jax.devices('cpu')[0]
     with jax.default_device(cpu_device):
         params = load_params('llama2-7B.pickle')
diff --git a/scripts/determine_params.py b/scripts/determine_params.py
new file mode 100644
index 0000000..d97746e
--- /dev/null
+++ b/scripts/determine_params.py
@@ -0,0 +1,28 @@
+from pathlib import Path; import sys; sys.path.append(str(Path(__file__).resolve().parent.parent))
+from lib.proc_init_utils import initialise_cpu; initialise_cpu()
+
+from transformers import LlamaForCausalLM
+
+model = LlamaForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')
+
+q_size = model.model.layers[0].self_attn.q_proj.weight.shape[0] * model.model.layers[0].self_attn.q_proj.weight.shape[1]
+k_size = model.model.layers[0].self_attn.k_proj.weight.shape[0] * model.model.layers[0].self_attn.k_proj.weight.shape[1]
+
+c = model.model.embed_tokens.weight.shape[0]
+f = model.model.layers[0].mlp.gate_proj.weight.shape[0]
+m = model.model.layers[0].mlp.gate_proj.weight.shape[1]
+n = len(model.model.layers)
+r = q_size // k_size
+h = model.config.num_attention_heads // r
+k = q_size // m // h
+
+n_params = sum(x.numel() for x in model.parameters())
+
+print('C', c)
+print('N', n)
+print('K', k)
+print('H', h)
+print('R', r)
+print('M', m)
+print('F', f)
+print('n_params', n_params)
diff --git a/train.py b/train.py
index ba22387..ecad983 100644
--- a/train.py
+++ b/train.py
@@ -30,7 +30,7 @@ def load_params_from_disk(path: str) -> Llama:
     cpu_device = jax.devices('cpu')[0]
     with jax.default_device(cpu_device):
-        # params = init_llama(key=rand.PRNGKey(42), model_config=model_config_dummy)
+        # params = init_llama(key=rand.key(42), model_config=model_config_dummy)
         params = load_params(path)
     params = shard_model_params(params)
     return params
@@ -52,8 +52,8 @@ def save_params_to_disk() -> None:
 
 def save_params_signal_handler(signum, frame):
     save_params_to_disk()
-    print(f'Signal {signum} received. Params have been saved to disk.')
-    exit(0)
+    print(f'Signal {signum} received. Model params have been successfully saved to disk.')
+    exit(-1)
 
 @jax.value_and_grad
 def train_forward(params: Llama, data_batch: TrainData, *, key: Array):
@@ -89,7 +89,7 @@ def main() -> None:
     if is_process_0:
         wandb.init(project='llama-finetuning-gsm', config=dict(learning_rate=lr, batch_size=batch_size * n_accumulation_steps, n_epochs=n_epochs, optimiser='adamw'))
 
-    key = rand.PRNGKey(seed)
+    key = rand.key(seed)
     tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
     dataset = GSMDataset(split='train')
     collate_fn = partial(gsm_collate_fn_train, tokenizer, max_len)
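
The README's new compact formula `n_params = 2CM + (2N + 1)M + 2NMRHK + 2NMHK + 3NMF` (the letters are the symbols defined in the Model Configurations list) can be checked against the parameter counts in the table without downloading any weights. The following sketch is not part of the diff; it simply plugs the table rows into the formula as I read it from the old spelled-out version:

```python
# Sanity check: evaluate the README formula
#   n_params = 2CM + (2N + 1)M + 2NMRHK + 2NMHK + 3NMF
# and compare against the parameter counts listed in the Model Configurations table.

def n_params(C: int, N: int, K: int, H: int, R: int, M: int, F: int) -> int:
    # Direct transcription of the formula: 2*C*M + (2*N + 1)*M + 2*N*M*R*H*K + 2*N*M*H*K + 3*N*M*F
    return 2 * C * M + (2 * N + 1) * M + 2 * N * M * R * H * K + 2 * N * M * H * K + 3 * N * M * F

# Configuration values taken from the table rows above.
assert n_params(C=32000, N=32, K=128, H=32, R=1, M=4096, F=11008) == 6_738_415_616    # Llama 2 7B
assert n_params(C=32000, N=40, K=128, H=40, R=1, M=5120, F=13824) == 13_015_864_320   # Llama 2 13B
assert n_params(C=32000, N=80, K=128, H=8,  R=8, M=8192, F=28672) == 68_976_648_192   # Llama 2 70B
print('formula matches the table')
```

The same values can be recovered from the checkpoint itself with the new `scripts/determine_params.py`, which derives C, N, K, H, R, M, F from the Hugging Face weight shapes and prints the exact parameter count.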