
Optional tensors in a PyTorch C++ extension

Tags: c++, pytorch, torch

I'm writing a C++ extension for PyTorch using its C++ API. My forward function needs to take an optional tensor, and inside the function I want to do different things depending on whether that argument was passed. For optional pointer arguments in C++, we would normally default to NULL and check for it inside the function. I don't know how to do the equivalent for the at::Tensor type in Torch's C++ API.

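void xyz_forward(
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    const at::Tensor optional_constraints = something)  // "something" is a placeholder: what goes here?
{
    if (optional_constraints) {  // intended check: was the optional tensor passed?
        // do something
    } else {
        // do something else
    }
}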

Note that I can't write const at::Tensor optional_constraints = at::ones or something similar, because the parameter can hold any real values and can vary in size and shape, so no particular tensor can serve as a default. Is there a NULL equivalent for at::Tensor?

sanjeev mk Avatar asked Dec 07 '25 05:12



1 Answer

One possibility is to use std::optional and declare the parameter as std::optional<at::Tensor> optional_constraints = std::nullopt. A std::optional is contextually convertible to bool, so you can check it with if (optional_constraints). If a tensor was passed, use the .value() method to get it. If the argument was omitted, the parameter holds the default std::nullopt and the check evaluates to false.
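As a minimal sketch of this pattern, the block below uses a hypothetical stand-in struct named Tensor in place of at::Tensor, so that it compiles with only the standard library (C++17). The std::string return value and the labels in it are also assumptions made purely for illustration; they are not part of the original function. In a real extension you would presumably include the Torch headers and write std::optional<at::Tensor> instead. Depending on your PyTorch version, the optional type may be spelled c10::optional, so check the headers you build against.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for at::Tensor, used only so this sketch
// builds without libtorch. It is NOT the real tensor type.
struct Tensor {
    std::vector<float> data;
};

// Returns a label describing which branch ran (illustrative only).
std::string xyz_forward(
    const Tensor& xyz1,
    const Tensor& xyz2,
    const std::optional<Tensor>& optional_constraints = std::nullopt)
{
    // Illustrative precondition: both inputs have the same length.
    assert(xyz1.data.size() == xyz2.data.size());

    if (optional_constraints) {
        // A tensor was passed: .value() gives access to it.
        const Tensor& constraints = optional_constraints.value();
        return "constrained:" + std::to_string(constraints.data.size());
    }
    // No tensor was passed: the parameter holds std::nullopt.
    return "unconstrained";
}
```

Calling xyz_forward(a, b) takes the "no constraints" branch, while xyz_forward(a, b, c) wraps c in an optional implicitly and takes the other branch. Taking the optional by const reference avoids a copy when the caller already holds a std::optional. Whether that matters for at::Tensor, which is a reference-counted handle, is a separate design choice.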

cantordust Avatar answered Dec 08 '25 19:12



