
PyTorch: How to do inference in batches (inference in parallel)

Tags:

pytorch

How do you do inference in batches in PyTorch? And how can inference be run in parallel to speed up that part of the code?

I've started with the standard way of doing inference:

with torch.no_grad():
    for inputs, labels in dataloader['predict']:
        inputs = inputs.to(device)   # move the batch to the target device
        output = model(inputs)       # forward pass on the whole batch
        output = output.to(device)   # the output is already on `device`, so this is effectively a no-op

From my research, the only mention of doing inference in parallel (on the same machine) seems to be with the Dask library: https://examples.dask.org/machine-learning/torch-prediction.html

I'm currently trying to understand that library and put together a working example. In the meantime, do you know of a better way?
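
From what I understand so far, the pattern in that Dask example boils down to roughly the sketch below. `predict_batch` is just a placeholder name of mine, and I'm assuming a CPU model here, since sharing a single GPU between parallel workers is a separate problem:

import dask
import torch

def predict_batch(model, inputs):
    # forward pass for one batch; no gradients needed for inference
    with torch.no_grad():
        return model(inputs)

# one delayed task per batch, then compute them in parallel
tasks = [dask.delayed(predict_batch)(model, inputs)
         for inputs, labels in dataloader['predict']]
outputs = dask.compute(*tasks, scheduler='threads')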

asked Oct 17 '25 by WurmD


1 Answer

In PyTorch, input tensors conventionally have the batch dimension as the first dimension, so inference in batches is the default behavior: you just need to make the batch dimension larger than 1.
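
For example (a minimal sketch; `predict_dataset` is a stand-in for whatever dataset backs your dataloader['predict']), raising batch_size in the DataLoader is all it takes:

import torch
from torch.utils.data import DataLoader

# predict_dataset is a placeholder for the dataset behind dataloader['predict']
predict_loader = DataLoader(predict_dataset, batch_size=64, shuffle=False)

model.eval()
with torch.no_grad():
    for inputs, labels in predict_loader:
        inputs = inputs.to(device)
        outputs = model(inputs)   # one forward pass scores 64 samples at once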

For example, if your single input is [1, 1], its input tensor is [[1, 1]] with shape (1, 2). If you have two inputs, [1, 1] and [2, 2], build the input tensor as [[1, 1], [2, 2]] with shape (2, 2). This is usually handled by the batch generator, such as your DataLoader.
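
As a toy illustration of the shapes (the 2-feature linear model is made up just for this example):

import torch
import torch.nn as nn

model = nn.Linear(2, 1)  # toy model taking 2 input features

single = torch.tensor([[1., 1.]])             # shape (1, 2): batch of one
batch = torch.tensor([[1., 1.], [2., 2.]])    # shape (2, 2): batch of two

with torch.no_grad():
    print(model(single).shape)  # torch.Size([1, 1])
    print(model(batch).shape)   # torch.Size([2, 1]), one output row per input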

answered Oct 21 '25 by THN