diff options
author | Jelle Zijlstra <jelle.zijlstra@gmail.com> | 2022-04-16 10:11:18 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-04-16 10:11:18 -0700 |
commit | 35dff91370a382e312dd53d002d948e3efedb317 (patch) | |
tree | d9f2470d05a3464ed8104f7aef29e5a877fe301e | |
parent | 2acaa5acd01aeabb295e961913c111a7df52656d (diff) | |
download | typing-35dff91370a382e312dd53d002d948e3efedb317.tar.gz |
Add get_overloads() (#1140)
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
-rw-r--r-- | typing_extensions/CHANGELOG | 5 | ||||
-rw-r--r-- | typing_extensions/README.rst | 6 | ||||
-rw-r--r-- | typing_extensions/src/test_typing_extensions.py | 74 | ||||
-rw-r--r-- | typing_extensions/src/typing_extensions.py | 70 |
4 files changed, 152 insertions, 3 deletions
diff --git a/typing_extensions/CHANGELOG b/typing_extensions/CHANGELOG index a9a5980..970bbd4 100644 --- a/typing_extensions/CHANGELOG +++ b/typing_extensions/CHANGELOG @@ -1,6 +1,9 @@ # Unreleased -- Add `typing.assert_type`. Backport from bpo-46480. +- Add `typing_extensions.get_overloads` and + `typing_extensions.clear_overloads`, and add registry support to + `typing_extensions.overload`. Backport from python/cpython#89263. +- Add `typing_extensions.assert_type`. Backport from bpo-46480. - Drop support for Python 3.6. Original patch by Adam Turner (@AA-Turner). # Release 4.1.1 (February 13, 2022) diff --git a/typing_extensions/README.rst b/typing_extensions/README.rst index 9abed04..3a23b75 100644 --- a/typing_extensions/README.rst +++ b/typing_extensions/README.rst @@ -47,6 +47,8 @@ This module currently contains the following: - ``assert_never`` - ``assert_type`` + - ``clear_overloads`` + - ``get_overloads`` - ``LiteralString`` (see PEP 675) - ``Never`` - ``NotRequired`` (see PEP 655) @@ -122,6 +124,10 @@ Certain objects were changed after they were added to ``typing``, and Python 3.8 and lack support for ``ParamSpecArgs`` and ``ParamSpecKwargs`` in 3.9. - ``@final`` was changed in Python 3.11 to set the ``.__final__`` attribute. +- ``@overload`` was changed in Python 3.11 to make function overloads + introspectable at runtime. In order to access overloads with + ``typing_extensions.get_overloads()``, you must use + ``@typing_extensions.overload``. There are a few types whose interface was modified between different versions of typing. For example, ``typing.Sequence`` was modified to diff --git a/typing_extensions/src/test_typing_extensions.py b/typing_extensions/src/test_typing_extensions.py index 1439e51..ab03244 100644 --- a/typing_extensions/src/test_typing_extensions.py +++ b/typing_extensions/src/test_typing_extensions.py @@ -3,6 +3,7 @@ import os import abc import contextlib import collections +from collections import defaultdict import collections.abc from functools import lru_cache import inspect @@ -10,6 +11,7 @@ import pickle import subprocess import types from unittest import TestCase, main, skipUnless, skipIf +from unittest.mock import patch from test import ann_module, ann_module2, ann_module3 import typing from typing import TypeVar, Optional, Union, Any, AnyStr @@ -21,9 +23,10 @@ import typing_extensions from typing_extensions import NoReturn, ClassVar, Final, IntVar, Literal, Type, NewType, TypedDict, Self from typing_extensions import TypeAlias, ParamSpec, Concatenate, ParamSpecArgs, ParamSpecKwargs, TypeGuard from typing_extensions import Awaitable, AsyncIterator, AsyncContextManager, Required, NotRequired -from typing_extensions import Protocol, runtime, runtime_checkable, Annotated, overload, final, is_typeddict +from typing_extensions import Protocol, runtime, runtime_checkable, Annotated, final, is_typeddict from typing_extensions import TypeVarTuple, Unpack, dataclass_transform, reveal_type, Never, assert_never, LiteralString from typing_extensions import assert_type, get_type_hints, get_origin, get_args +from typing_extensions import clear_overloads, get_overloads, overload # Flags used to mark tests that only apply after a specific # version of the typing module. @@ -403,6 +406,20 @@ class LiteralTests(BaseTestCase): Literal[1][1] +class MethodHolder: + @classmethod + def clsmethod(cls): ... + @staticmethod + def stmethod(): ... + def method(self): ... + + +if TYPING_3_11_0: + registry_holder = typing +else: + registry_holder = typing_extensions + + class OverloadTests(BaseTestCase): def test_overload_fails(self): @@ -424,6 +441,61 @@ class OverloadTests(BaseTestCase): blah() + def set_up_overloads(self): + def blah(): + pass + + overload1 = blah + overload(blah) + + def blah(): + pass + + overload2 = blah + overload(blah) + + def blah(): + pass + + return blah, [overload1, overload2] + + # Make sure we don't clear the global overload registry + @patch( + f"{registry_holder.__name__}._overload_registry", + defaultdict(lambda: defaultdict(dict)) + ) + def test_overload_registry(self): + registry = registry_holder._overload_registry + # The registry starts out empty + self.assertEqual(registry, {}) + + impl, overloads = self.set_up_overloads() + self.assertNotEqual(registry, {}) + self.assertEqual(list(get_overloads(impl)), overloads) + + def some_other_func(): pass + overload(some_other_func) + other_overload = some_other_func + def some_other_func(): pass + self.assertEqual(list(get_overloads(some_other_func)), [other_overload]) + + # Make sure that after we clear all overloads, the registry is + # completely empty. + clear_overloads() + self.assertEqual(registry, {}) + self.assertEqual(get_overloads(impl), []) + + # Querying a function with no overloads shouldn't change the registry. + def the_only_one(): pass + self.assertEqual(get_overloads(the_only_one), []) + self.assertEqual(registry, {}) + + def test_overload_registry_repeated(self): + for _ in range(2): + impl, overloads = self.set_up_overloads() + + self.assertEqual(list(get_overloads(impl)), overloads) + class AssertTypeTests(BaseTestCase): diff --git a/typing_extensions/src/typing_extensions.py b/typing_extensions/src/typing_extensions.py index d5e4049..4911099 100644 --- a/typing_extensions/src/typing_extensions.py +++ b/typing_extensions/src/typing_extensions.py @@ -1,6 +1,7 @@ import abc import collections import collections.abc +import functools import operator import sys import types as _types @@ -46,7 +47,9 @@ __all__ = [ 'Annotated', 'assert_never', 'assert_type', + 'clear_overloads', 'dataclass_transform', + 'get_overloads', 'final', 'get_args', 'get_origin', @@ -249,7 +252,72 @@ else: _overload_dummy = typing._overload_dummy # noqa -overload = typing.overload + + +if hasattr(typing, "get_overloads"): # 3.11+ + overload = typing.overload + get_overloads = typing.get_overloads + clear_overloads = typing.clear_overloads +else: + # {module: {qualname: {firstlineno: func}}} + _overload_registry = collections.defaultdict( + functools.partial(collections.defaultdict, dict) + ) + + def overload(func): + """Decorator for overloaded functions/methods. + + In a stub file, place two or more stub definitions for the same + function in a row, each decorated with @overload. For example: + + @overload + def utf8(value: None) -> None: ... + @overload + def utf8(value: bytes) -> bytes: ... + @overload + def utf8(value: str) -> bytes: ... + + In a non-stub file (i.e. a regular .py file), do the same but + follow it with an implementation. The implementation should *not* + be decorated with @overload. For example: + + @overload + def utf8(value: None) -> None: ... + @overload + def utf8(value: bytes) -> bytes: ... + @overload + def utf8(value: str) -> bytes: ... + def utf8(value): + # implementation goes here + + The overloads for a function can be retrieved at runtime using the + get_overloads() function. + """ + # classmethod and staticmethod + f = getattr(func, "__func__", func) + try: + _overload_registry[f.__module__][f.__qualname__][ + f.__code__.co_firstlineno + ] = func + except AttributeError: + # Not a normal function; ignore. + pass + return _overload_dummy + + def get_overloads(func): + """Return all defined overloads for *func* as a sequence.""" + # classmethod and staticmethod + f = getattr(func, "__func__", func) + if f.__module__ not in _overload_registry: + return [] + mod_dict = _overload_registry[f.__module__] + if f.__qualname__ not in mod_dict: + return [] + return list(mod_dict[f.__qualname__].values()) + + def clear_overloads(): + """Clear all overloads in the registry.""" + _overload_registry.clear() # This is not a real generic class. Don't use outside annotations. |