
Dataclasses: Matching Generic TypeVar names to attributes in the origin class

Say I have a Generic dataclass like the following:

from dataclasses import dataclass
from typing import TypeVar, Generic

T = TypeVar('T')
U = TypeVar('U')


@dataclass
class Class(Generic[T, U]):
    foo: U
    bar: T


IntStrClass = Class[int, str]

Reading the code, we can see that for IntStrClass:

  • the T lines up with int, which makes the type of bar an int.
  • the U lines up with str, which makes the type of foo a str.

But how can I figure this out programmatically?

I've been playing around with the typing module, but can't see from the outputs how I would match them up. What I have is:

from typing import get_type_hints, get_origin, get_args

print("Class field types:", get_type_hints(get_origin(IntStrClass)))
print("Class generic args:", get_args(IntStrClass))

Output:

Class field types: {'foo': ~U, 'bar': ~T}
Class generic args: (<class 'int'>, <class 'str'>)

What I'm missing is a way to determine, from the definition of Class, that T -> int and U -> str. With that information, I could infer the proper types of foo and bar.

Thanks in advance!

flakes asked Sep 12 '25 17:09


1 Answer

How about this?

[Has been significantly edited following a conversation in the comments.]

from dataclasses import dataclass
from typing import TypeVar, Generic, get_type_hints, get_args, get_origin

T = TypeVar('T')
U = TypeVar('U')


@dataclass
class Class(Generic[T, U]):
    foo: U
    spam: str
    bar: T
    baz: int


IntStrClass = Class[int, str]

def get_annotations(generic_subclass):
    # Split the alias (e.g. Class[int, str]) into its origin class and type arguments.
    generic_origin = get_origin(generic_subclass)
    annotations_map = get_type_hints(generic_origin)
    generic_args = get_args(generic_subclass)

    try:
        generic_params = generic_origin.__parameters__
    except AttributeError as err:
        raise AttributeError(
            f"{generic_origin} has no attribute '__parameters__'. "
            "The likely cause of this is that the typing module's "
            "API for the Generic class has changed "
            "since this function was written."
            ) from err

    # __parameters__ lists the TypeVars in declaration order (T, U),
    # so zipping it with the arguments pairs T -> int and U -> str.
    type_var_map = dict(zip(generic_params, generic_args))

    # Substitute annotations that are a bare TypeVar; leave concrete types as they are.
    for field, annotation in annotations_map.items():
        if isinstance(annotation, TypeVar):
            annotations_map[field] = type_var_map[annotation]

    return annotations_map

print("Resolved attributes:", get_annotations(IntStrClass))

Output:

Resolved attributes: {'foo': <class 'str'>, 'spam': <class 'str'>, 'bar': <class 'int'>, 'baz': <class 'int'>}
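
Note that the function above only substitutes annotations that are a bare TypeVar, so a field annotated as, say, List[T] would come back unchanged. If you need that case too, something along these lines might work. This is only a rough sketch: Pair, type_var_map and resolve are names made up for illustration, and it assumes that parameterised annotations expose their free TypeVars via __parameters__ and can be subscripted again, which holds for typing aliases like List[T] as far as I know.

```python
from dataclasses import dataclass
from typing import Generic, List, TypeVar, get_args, get_origin, get_type_hints

T = TypeVar('T')
U = TypeVar('U')


@dataclass
class Pair(Generic[T, U]):  # hypothetical example class, not from the question
    foo: U
    bar: List[T]


def type_var_map(alias):
    """Pair each TypeVar of the origin class with the argument given in the alias."""
    origin = get_origin(alias)
    return dict(zip(origin.__parameters__, get_args(alias)))


def resolve(annotation, mapping):
    """Substitute TypeVars in an annotation, including nested ones such as List[T]."""
    if isinstance(annotation, TypeVar):
        return mapping.get(annotation, annotation)
    # Assumption: aliases like List[T] list their free TypeVars in __parameters__.
    params = getattr(annotation, '__parameters__', ())
    if params:
        # Re-subscript the alias with the concrete types, e.g. List[T][int] -> List[int].
        return annotation[tuple(mapping.get(p, p) for p in params)]
    return annotation  # already concrete, e.g. int or str


mapping = type_var_map(Pair[int, str])
resolved = {name: resolve(ann, mapping) for name, ann in get_type_hints(Pair).items()}
print(resolved)  # {'foo': <class 'str'>, 'bar': typing.List[int]}
```

Unknown TypeVars are left in place rather than raising, which seemed the safer default for a helper like this.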
Alex Waygood answered Sep 14 '25 07:09