train.py OOM on TPUv3-8 #10

Open
ethanhe42 opened this issue Sep 10, 2023 · 9 comments

Comments

@ethanhe42

Successfully loaded and sharded model parameters!
  0%|                                                                                   | 0/155 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/home/yihuihe_yh/llama-2-jax/train.py", line 113, in <module>
    main()
  File "/home/yihuihe_yh/llama-2-jax/train.py", line 101, in main
    params, opt_state, total_loss, loss, key = train_step(params, opt_state, total_loss, data_batch, key)
                                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yihuihe_yh/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/yihuihe_yh/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 253, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
                                                 ^^^^^^^^^^^^^^^^^^^^
  File "/home/yihuihe_yh/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 166, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yihuihe_yh/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/core.py", line 2596, in bind
    return self.bind_with_trace(top_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yihuihe_yh/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/core.py", line 389, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yihuihe_yh/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/core.py", line 821, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yihuihe_yh/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 1209, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yihuihe_yh/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 1192, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
                         ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yihuihe_yh/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 1148, in _pjit_call_impl_python
    return compiled.unsafe_call(*args), compiled
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yihuihe_yh/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/yihuihe_yh/llama-2-jax/venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 1229, in __call__
    results = self.xla_executable.execute_sharded(input_bufs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 344.00M. That was not possible. There are 91.66M free.; (0x1x0_HBM0): while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/home/yihuihe_yh/llama-2-jax/train.py", line 113, in <module>
    main()
  File "/home/yihuihe_yh/llama-2-jax/train.py", line 101, in main
    params, opt_state, total_loss, loss, key = train_step(params, opt_state, total_loss, data_batch, key)
                                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 344.00M. That was not possible. There are 91.66M free.; (0x1x0_HBM0): while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
@ayaka14732
Owner

I got the same error

@ethanhe42
Author

I guess TPUv3-8 doesn't have enough memory. Is there an easy way to use PEFT for fine-tuning?

@ayaka14732
Owner

I've tested on TPU v3-32 and there is no OOM error

@ayaka14732
Owner

Is there an easy way to use PEFT for fine-tuning?

I haven't started to implement PEFT in this project yet

@zhangzx-uiuc

jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 344.00M. That was not possible. There are 91.66M free.; (0x1x0_HBM0): while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).

Would it work if you use float16 or bfloat16 for training?

@ayaka14732
Owner

ayaka14732 commented Oct 4, 2023

@zhangzx-uiuc Using float32 won't fit on TPU v3-8. For bfloat16, I remember that I tried it and I still got the OOM error. I have identified the issue in google-deepmind/optax#472 (comment). It seems that this issue has been fixed, so it might work on the main branch of the optax library.

Besides, I contacted the OP of google-deepmind/optax#377, and I learnt that "it is bad practice to keep the actual params in bf16 during training". I think the performance would be better if we stick to float32 during training and save the parameters in bfloat16.
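
For reference, a minimal sketch of that approach: keep the training copy in float32 and cast only the checkpoint that gets written to disk. The helper below is illustrative, not this repo's actual saving code.

import jax.numpy as jnp
from jax import tree_util

def to_bf16_for_saving(params):
    # Cast every leaf of the float32 parameter pytree to bfloat16 before
    # serialization; the float32 copy used for training is left untouched.
    return tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)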

Another thing worth noting is the precision of the rotary embedding: https://www.qbitai.com/2023/08/78565.html. I haven't fixed this yet.
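
As an illustration of the rotary precision point, a minimal sketch (not this repo's implementation) that builds the sin/cos tables in float32 and casts only the final tables to the compute dtype:

import jax.numpy as jnp

def rotary_sin_cos(seq_len, head_dim, base=10000.0, out_dtype=jnp.bfloat16):
    # Compute the angles in float32 so small frequency differences survive,
    # then cast only the resulting sin/cos tables to the compute dtype.
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim))
    angles = jnp.outer(jnp.arange(seq_len, dtype=jnp.float32), inv_freq)
    return jnp.sin(angles).astype(out_dtype), jnp.cos(angles).astype(out_dtype)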

@lodestone-rock

@ayaka14732 if you're using MultiSteps, you need to add a context manager so it gets initialized in CPU memory instead of on TPU device 0:

with jax.default_device(jax.devices("cpu")[0]):
    # your optax multisteps code here 

@Beomi

Beomi commented Nov 27, 2023

@lodestone-rock
Hi, I'm using this code for gradient accumulation:

optimizer = optax.MultiSteps(optimizer, config.accumulate_gradient_steps)

Should I use the MultiSteps optimizer like this to avoid TPU OOM?

with jax.default_device(jax.devices("cpu")[0]):
    # your optax multisteps code here 
    optimizer = optax.MultiSteps(optimizer, config.accumulate_gradient_steps)

@lodestone-rock

@Beomi
Yes, MultiSteps initializes the accumulator weights directly on accelerator 0 for some reason. With that context manager, JAX should initialize them on the host instead.
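
Putting the two suggestions together, a minimal sketch of the pattern described above (optax.adamw, config, and params stand in for whatever train.py already defines):

import jax
import optax

# Create the MultiSteps wrapper and its accumulator state under the CPU
# context manager, as suggested above, so the extra buffers are placed on
# the host rather than on TPU device 0.
with jax.default_device(jax.devices("cpu")[0]):
    optimizer = optax.MultiSteps(optax.adamw(learning_rate=1e-5), config.accumulate_gradient_steps)
    opt_state = optimizer.init(params)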
