 

Is there a non-recursive alternative to dataclasses.astuple?

dataclasses.astuple is recursive (according to the documentation):

Each dataclass is converted to a tuple of its field values. dataclasses, dicts, lists, and tuples are recursed into.

Indeed, consider this example:

In [1]: from dataclasses import dataclass, astuple
   ...: from typing import List
   ...:
   ...: @dataclass
   ...: class Point:
   ...:     x: float
   ...:     y: float
   ...:
   ...: @dataclass
   ...: class Side:
   ...:     left: Point
   ...:     right: Point
   ...:
   ...: side = Side(Point(2, 3), Point(1, 2))
   ...: astuple(side)
   ...:
   ...:
Out[1]: ((2, 3), (1, 2))
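The same recursion applies to the containers mentioned in the quoted documentation. In this small sketch (the Path class is a hypothetical example, not from the question), a list field comes back as a list whose Point elements have themselves been turned into tuples:

```python
from dataclasses import dataclass, astuple

@dataclass
class Point:
    x: float
    y: float

@dataclass
class Path:  # hypothetical example class holding a list of dataclasses
    points: list

path = Path([Point(1, 2), Point(3, 4)])
# astuple recurses into the list and converts each Point as well
print(astuple(path))  # ([(1, 2), (3, 4)],)
```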

Is there a simple built-in way to obtain a two-tuple of points, i.e. (Point(x=2, y=3), Point(x=1, y=2)), instead?
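For comparison, a one-level conversion needs nothing beyond dataclasses.fields: read each field with getattr and stop there. This is a minimal sketch; the helper name shallow_astuple is hypothetical and not part of the dataclasses module:

```python
from dataclasses import dataclass, fields, is_dataclass

def shallow_astuple(obj):
    """Return a tuple of obj's field values without recursing into them."""
    # is_dataclass() is also true for dataclass *classes*; only accept instances
    if not is_dataclass(obj) or isinstance(obj, type):
        raise TypeError(f"expected a dataclass instance, got {type(obj).__name__}")
    # fields() yields fields in definition order, the same order astuple uses
    return tuple(getattr(obj, f.name) for f in fields(obj))

@dataclass
class Point:
    x: float
    y: float

@dataclass
class Side:
    left: Point
    right: Point

side = Side(Point(2, 3), Point(1, 2))
print(shallow_astuple(side))  # (Point(x=2, y=3), Point(x=1, y=2))
```

Unlike astuple, which deep-copies non-dataclass values, this helper returns the field values themselves, so the first element is the very same object as side.left.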

asked Nov 16 '25 at 12:11 by Ilya V. Schurov



1 Answer

As an alternative that uses the same core as MisterMiyagi's answer, but makes the result a little more reusable and cooperative with existing Python tools by leveraging the iterator protocol, consider this dataclass-specific mixin:

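from dataclasses import fields, is_dataclass

class UnpackDCMixin:

    def __iter__(self):
        # Refuse to work on classes that are not dataclasses
        if not is_dataclass(self):
            raise TypeError(f"This mixin is dataclass-only, which {type(self)} isn't.")
        # Yield each field's value as-is, in definition order, without recursing
        return (getattr(self, field.name) for field in fields(self))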

If you mix it into your dataclasses, you can use the standard tuple() builtin to turn instances into tuples, and the same goes for any other tool that expects an iterable (such as list(), set(), or a for loop). And since __iter__ doesn't recurse into the field values, the resulting containers stay shallow:

>>> @dataclass
... class Point(UnpackDCMixin):
...     x: float
...     y: float
... 
>>> @dataclass
... class Side(UnpackDCMixin):
...     left: Point
...     right: Point
...
>>> side = Side(Point(2, 3), Point(1, 2))
>>> tuple(side)         # plug into container constructors...
(Point(x=2, y=3), Point(x=1, y=2))
>>> for point in side:  # ...or other iterable-contexts
...     print(point)
Point(x=2, y=3)
Point(x=1, y=2)
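Because the mixin only implements __iter__, anything built on the iterator protocol works too. This sketch repeats the mixin and classes from above so it runs on its own, and shows unpacking, list(), and a builtin such as max() on a single Point. It also shows that dataclasses.astuple is unaffected by the mixin and still recurses:

```python
from dataclasses import astuple, dataclass, fields, is_dataclass

class UnpackDCMixin:
    # Same mixin as above: iterate over field values, one level deep
    def __iter__(self):
        if not is_dataclass(self):
            raise TypeError(f"This mixin is dataclass-only, which {type(self)} isn't.")
        return (getattr(self, field.name) for field in fields(self))

@dataclass
class Point(UnpackDCMixin):
    x: float
    y: float

@dataclass
class Side(UnpackDCMixin):
    left: Point
    right: Point

side = Side(Point(2, 3), Point(1, 2))

left, right = side      # iterable unpacking
print(left, right)      # Point(x=2, y=3) Point(x=1, y=2)
print(list(side))       # [Point(x=2, y=3), Point(x=1, y=2)]
print(max(side.left))   # 3, since Point is iterable too
print(astuple(side))    # ((2, 3), (1, 2)), astuple itself still recurses
```

One caveat on set(): with the default @dataclass (eq=True, frozen=False), __hash__ is set to None, so set(side) raises TypeError because the Point values are unhashable. Declaring Point with @dataclass(frozen=True) makes it hashable.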
answered Nov 18 '25 at 05:11 by Arne
