PyTorch Python Distributed Multiprocessing: Gather/Concatenate tensor arrays of different lengths/sizes

If you have tensor arrays of different lengths across several GPU ranks, the default all_gather method does not work because it requires every rank to contribute a tensor of the same length.

For example, if you have:

if gpu == 0:
    q = torch.tensor([1.5, 2.3], device=torch.device(gpu))
else:
    q = torch.tensor([5.3], device=torch.device(gpu))

and you need to gather these two tensor arrays as follows:

all_q = [torch.tensor([1.5, 2.3]), torch.tensor([5.3])]

then the default torch.distributed.all_gather does not work because the lengths (2 and 1) are different.
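
For context, here is a minimal sketch of how the built-in collective is normally called (this assumes import torch.distributed as dist and an already-initialized process group of size 2); every rank has to pass tensors of the same shape:

world_size = 2
# Sketch only: the output list is pre-allocated with tensors matching the input shape,
# and that shape must be identical on every rank.
gathered = [torch.zeros(2, device=torch.device(gpu)) for _ in range(world_size)]
dist.all_gather(gathered, q)  # only valid if q has shape (2,) on every rank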

asked Sep 05 '25 by omsrisagar

2 Answers

Since it is not directly possible to gather variable-length tensors using the built-in methods, we need to write a custom function with the following steps:

  1. Use dist.all_gather to get sizes of all arrays.
  2. Find the max size.
  3. Pad local array to max size using zeros/constants.
  4. Use dist.all_gather to get all padded arrays.
  5. Unpad the added zeros/constants using sizes found in step 1.

The function below implements these steps:

import torch
import torch.distributed as dist


def all_gather(q, ws, device):
    """
    Gathers tensor arrays of different lengths across multiple gpus
    
    Parameters
    ----------
        q : tensor array
        ws : world size
        device : current gpu device
        
    Returns
    -------
        all_q : list of gathered tensor arrays from all the gpus

    """
    # 1. Gather the sizes of the local arrays from every rank
    local_size = torch.tensor(q.size(), device=device)
    all_sizes = [torch.zeros_like(local_size) for _ in range(ws)]
    dist.all_gather(all_sizes, local_size)

    # 2. Find the maximum size
    max_size = max(all_sizes)

    # 3. Pad the local array with zeros up to the maximum size
    size_diff = max_size.item() - local_size.item()
    if size_diff:
        padding = torch.zeros(size_diff, device=device, dtype=q.dtype)
        q = torch.cat((q, padding))

    # 4. Gather the padded arrays from every rank
    all_qs_padded = [torch.zeros_like(q) for _ in range(ws)]
    dist.all_gather(all_qs_padded, q)

    # 5. Trim each gathered array back to its original size
    all_qs = []
    for q_padded, size in zip(all_qs_padded, all_sizes):
        all_qs.append(q_padded[:size])
    return all_qs

Once we are able to do the above, we can then easily use torch.cat to further concatenate into a single array if needed:

torch.cat(all_q)
# tensor([1.5000, 2.3000, 5.3000])
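
For completeness, here is a rough end-to-end sketch of how the helper above could be driven; it is an illustration only and assumes CPU tensors with the gloo backend, the all_gather function defined in the same script, and an arbitrary free port:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, ws):
    # Each spawned process joins the same process group (gloo runs on CPU)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"  # arbitrary free port
    dist.init_process_group("gloo", rank=rank, world_size=ws)

    device = torch.device("cpu")
    q = torch.tensor([1.5, 2.3]) if rank == 0 else torch.tensor([5.3])

    all_q = all_gather(q, ws, device)  # the helper defined above
    if rank == 0:
        print(torch.cat(all_q))  # tensor([1.5000, 2.3000, 5.3000])

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)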

Adapted from: github

answered Sep 07 '25 by omsrisagar


Here is an extension of @omsrisagar's solution that supports tensors of any number of dimensions (not only 1-dimensional tensors).

import torch
import torch.distributed as dist


def all_gather_nd(tensor):
    """
    Gathers tensor arrays of different lengths in a list.
    The length dimension is 0. This supports any number of extra dimensions in the tensors.
    All the other dimensions should be equal between the tensors.

    Args:
        tensor (Tensor): Tensor to be broadcast from current process.

    Returns:
        (list[Tensor]): output list of gathered tensors, which can have different lengths
    """
    world_size = dist.get_world_size()

    # Gather the full shapes of the local tensors from every rank
    local_size = torch.tensor(tensor.size(), device=tensor.device)
    all_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(all_sizes, local_size)

    # Pad the local tensor along dimension 0 up to the maximum length
    max_length = max(size[0] for size in all_sizes)
    length_diff = max_length.item() - local_size[0].item()
    if length_diff:
        pad_size = (length_diff, *tensor.size()[1:])
        padding = torch.zeros(pad_size, device=tensor.device, dtype=tensor.dtype)
        tensor = torch.cat((tensor, padding))

    # Gather the padded tensors and trim each back to its original length
    all_tensors_padded = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(all_tensors_padded, tensor)
    all_tensors = []
    for tensor_, size in zip(all_tensors_padded, all_sizes):
        all_tensors.append(tensor_[:size[0]])
    return all_tensors

Note that this requires all the tensors to have the same number of dimensions and matching sizes in every dimension except the first.
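
As a quick usage sketch (assuming two ranks, a process group that is already initialized, and tensors that differ only in their first dimension), with rank 0 holding a (2, 3) tensor and rank 1 a (1, 3) tensor:

rank = dist.get_rank()
if rank == 0:
    t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
else:
    t = torch.full((1, 3), 7.0)

gathered = all_gather_nd(t)     # [tensor of shape (2, 3), tensor of shape (1, 3)]
combined = torch.cat(gathered)  # shape (3, 3), identical on every rank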

answered Sep 07 '25 by jb0u