Why does GPU memory increase when recreating and reassigning a JAX numpy array to the same variable name?

When I recreate and reassign a JAX NumPy array to the same variable name, the GPU memory usage nearly doubles after the first recreation and then stays stable for subsequent recreations/reassignments.

Why does this happen and is this generally expected behavior for JAX arrays?

Fully runnable minimal example: https://colab.research.google.com/drive/1piUvyVylRBKm1xb1WsocsSVXJzvn5bdI?usp=sharing.

For posterity, in case the Colab link goes down:

%env XLA_PYTHON_CLIENT_PREALLOCATE=false
import jax
from jax import numpy as jnp
from jax import random
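
# get_gpu_memory() is defined in the linked colab. As a stand-in
# (an assumption, not necessarily the colab's exact helper), it can
# query nvidia-smi for the device's current memory usage:
import subprocess

def get_gpu_memory():
    # Print the GPU memory currently in use, as reported by nvidia-smi.
    print(subprocess.check_output(
        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader"]
    ).decode().strip())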

# First creation of jnp array
x = jnp.ones(shape=(int(1e8),), dtype=float)
get_gpu_memory() # the memory usage from the first call is 618 MB

# Second creation of jnp array, reassigning it to the same variable name
x = jnp.ones(shape=(int(1e8),), dtype=float)
get_gpu_memory() # the memory usage is now 1130 MB - almost double!

# Third creation of jnp array, reassigning it to the same variable name
x = jnp.ones(shape=(int(1e8),), dtype=float)
get_gpu_memory() # the memory usage is stable at 1130 MB.

Thank you!

Answer

This behavior comes from the interaction of several things:

  1. Without pre-allocation (XLA_PYTHON_CLIENT_PREALLOCATE=false, as in your snippet), GPU memory usage grows as needed, but does not shrink when buffers are deleted.

  2. When you reassign a Python variable, the old value still exists in memory until the Python garbage collector notices it is no longer referenced and deletes it. This takes a small amount of time to occur in the background (you can call import gc; gc.collect() to force it to happen at any point; see the sketch after this list).

  3. JAX sends instructions to the GPU asynchronously, meaning that once Python garbage-collects a GPU-backed value, the Python script may continue running for a short time before the corresponding buffer is actually removed from the device.
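
For example, to remove the Python-side delay from point 2 (a minimal, self-contained sketch):

import gc
from jax import numpy as jnp

x = jnp.ones(shape=(int(1e8),), dtype=float)

del x         # the array is now unreferenced...
gc.collect()  # ...and this forces Python to collect it immediately,
              # rather than waiting for a background collection pass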

All of this means there’s some delay between unassigning the previous x value and that memory being freed on the device. If you immediately allocate a new value, the device will likely expand its memory allocation to fit the new array before the old one is deleted.

So why does the memory use stay constant on the third call? Well, by this time the first allocation has been removed, and so there is already space for the third allocation without growing the memory footprint.

With these things in mind, you can keep the allocation constant by putting a delay between deleting the old value and creating the new value; i.e. replace this:

x = jnp.ones(shape=(int(1e8),), dtype=float)

with this:

import time

del x
time.sleep(1)
x = jnp.ones(shape=(int(1e8),), dtype=float)

When I run it this way, I see constant memory usage at 618MiB.
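
Putting it together (a minimal sketch, reusing the question's get_gpu_memory helper; gc.collect() is optional here, but it makes point 2's collection explicit rather than leaving it to the background collector):

import gc
import time
from jax import numpy as jnp

x = jnp.ones(shape=(int(1e8),), dtype=float)
get_gpu_memory()  # first allocation: ~618 MB

del x            # drop the reference to the old array (point 2)
gc.collect()     # force collection now rather than later (point 2)
time.sleep(1)    # let the asynchronous runtime free the buffer (point 3)

x = jnp.ones(shape=(int(1e8),), dtype=float)
get_gpu_memory()  # should report ~618 MB again instead of growing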
