Here's a reproducer showing that nvFuser uses more memory than the equivalent eager PyTorch function (on a224936). It prints:
nvFuser Peak memory diff: 2.75 MiB
Eager Peak memory diff: 0.75 MiB
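For scale: one 256x256 float32 tensor is 256 * 256 * 4 B = 0.25 MiB, so the eager diff corresponds to about three tensor-sized buffers live at the peak, while the nvFuser diff corresponds to about eleven. A quick back-of-the-envelope check (reading the diffs as live-buffer counts is my interpretation, not something the allocator reports):

tensor_mib = 256 * 256 * 4 / 2**20   # one 256x256 float32 tensor = 0.25 MiB
print(2.75 / tensor_mib)             # nvFuser diff ~ 11 tensor-sized buffers
print(0.75 / tensor_mib)             # eager diff  ~  3 tensor-sized buffers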
import torch
import gc
from nvfuser import FusionDefinition, DataType

N_PARALLEL_PATHS = 10

with FusionDefinition() as fd:
    T0s = [
        fd.define_tensor(shape=[256, 256], contiguity=[True, True],
                         dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
        for _ in range(N_PARALLEL_PATHS)
    ]
    a = fd.define_tensor(shape=[256, 256], contiguity=[True, True],
                         dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    # N_PARALLEL_PATHS independent relu/matmul chains feeding one accumulator
    for T0 in T0s:
        T1 = fd.ops.relu(T0)
        T2 = fd.ops.matmul(T1, T1)
        T3 = fd.ops.relu(T2)
        a = fd.ops.matmul(T3, a)
    fd.add_output(a)

t0s = [torch.randn(256, 256, device="cuda") for _ in range(N_PARALLEL_PATHS)]  # 0.25 MiB * N_PARALLEL_PATHS
a = torch.randn(256, 256, device="cuda")  # 0.25 MiB

# Record peak memory usage
fd.execute([*t0s, a])  # warm-up run so compilation doesn't skew the measurement
gc.collect(0)
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
before_in_MiB = torch.cuda.max_memory_allocated() / (1024 * 1024)
fd.execute([*t0s, a])
max_allocated_in_MiB = torch.cuda.max_memory_allocated() / (1024 * 1024) - before_in_MiB
print(f"nvFuser Peak memory diff: {max_allocated_in_MiB:.2f} MiB")
def eager_func(t0s, a):
    for t0 in t0s:
        t1 = torch.nn.functional.relu(t0); del t0
        t2 = torch.matmul(t1, t1); del t1
        t3 = torch.nn.functional.relu(t2); del t2
        a = torch.matmul(t3, a); del t3
    return a

# Record peak memory usage
eager_func(t0s, a)  # warm-up run
gc.collect(0)
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
before_in_MiB = torch.cuda.max_memory_allocated() / (1024 * 1024)
eager_func(t0s, a)
max_allocated_in_MiB = torch.cuda.max_memory_allocated() / (1024 * 1024) - before_in_MiB
print(f" Eager Peak memory diff: {max_allocated_in_MiB:.2f} MiB")
This was discovered while working on a similar problem in Thunder, trying to send the whole example computational graph to nvFuser (Lightning-AI/lightning-thunder#1337).
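If it helps triage, a quick way to check whether the gap grows with the number of parallel paths is to rebuild the fusion for a few path counts and repeat the same measurement recipe. Here is a minimal sketch; build_fusion and measure_peak_mib are helper names I made up for this sketch, not nvFuser API:

import gc
import torch
from nvfuser import FusionDefinition, DataType

def build_fusion(n):
    # Same graph as the reproducer, parameterized by the number of parallel paths.
    with FusionDefinition() as fd:
        t0s = [fd.define_tensor(shape=[256, 256], contiguity=[True, True],
                                dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
               for _ in range(n)]
        acc = fd.define_tensor(shape=[256, 256], contiguity=[True, True],
                               dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
        for t0 in t0s:
            t1 = fd.ops.relu(t0)
            t2 = fd.ops.matmul(t1, t1)
            t3 = fd.ops.relu(t2)
            acc = fd.ops.matmul(t3, acc)
        fd.add_output(acc)
    return fd

def measure_peak_mib(run, inputs):
    # Same recipe as the reproducer: warm up, reset stats, run again, report the diff.
    run(inputs)
    gc.collect(0)
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    before = torch.cuda.max_memory_allocated()
    run(inputs)
    return (torch.cuda.max_memory_allocated() - before) / 2**20

for n in (1, 5, 10, 20):
    fd_n = build_fusion(n)
    inputs = [torch.randn(256, 256, device="cuda") for _ in range(n + 1)]  # n paths + accumulator
    print(f"N_PARALLEL_PATHS={n}: {measure_peak_mib(fd_n.execute, inputs):.2f} MiB")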