Binding std::generator<T> with pybind11?

I am trying to bind a std::generator<T> to a Python generator through pybind11. This is what I am currently using:

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <generator>
#include <ranges>

namespace py = pybind11;

std::generator<int> f1(int n = 0) {
    while (true) { co_yield n++; }
}

std::generator<int> f2(int n = 0) {
    co_yield n;
    co_yield std::ranges::elements_of(f1(n + 1));
}

PYBIND11_MODULE(test_generator, m) {
    py::class_<std::generator<int>>(m, "_generator_int", pybind11::module_local())
        .def("__iter__",
             [](std::generator<int>& gen) -> std::generator<int>& {
                 return gen;
             })
        .def("__next__", [](std::generator<int>& gen) {
            auto it = gen.begin();
            if (it != gen.end()) { return *it; }
            else                 { throw py::stop_iteration(); }
        });

    m.def("f1", &f1, py::arg("n") = 0);
    m.def("f2", &f2, py::arg("n") = 0);
}

The above code works with f1 but not with f2: when f2(0) is called from Python, it only yields the values 0 and 1, so I assume my implementation does not work with std::ranges::elements_of.

How can I make it work with std::ranges::elements_of without modifying f2?

Holt asked Oct 15 '25 03:10


1 Answer

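As pointed out in a comment, calling g.begin() more than once is undefined behavior: std::generator::begin() requires the coroutine to still be suspended at its initial suspend point, so it may only be called once per generator. I was able to fix the code by storing the iterator alongside the generator in an intermediate struct, so that begin() is called exactly once and __next__ only advances that iterator: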

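template <class T>
struct state {
    std::generator<T> g;        // owns the coroutine
    decltype(g.begin()) it;     // the single iterator, obtained once

    // begin() is called exactly once, right after the generator is stored;
    // `this->g` refers to the member, not the constructor parameter.
    state(std::generator<T> g) : g(std::move(g)), it(this->g.begin()) {}
};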

PYBIND11_MODULE(test_generator, m)
{
    py::class_<state<int>>(m, "_generator_int", pybind11::module_local())
        .def("__iter__",
             [](state<int>& gen) -> state<int>& {
                 return gen;
             })
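        .def("__next__", [](state<int>& s) {
            if (s.it != s.g.end()) {
                const auto v = *s.it;  // copy the current value before resuming
                ++s.it;                // resume the coroutine to its next co_yield
                return v;
            }
            else {
                throw py::stop_iteration();  // the generator has finished
            }
        });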

    m.def("f1", [](int n) -> state<int> { return f1(n); }, py::arg("n") = 0);
    m.def("f2", [](int n) -> state<int> { return f2(n); }, py::arg("n") = 0);
}
Holt answered Oct 18 '25 08:10