aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJelle Zijlstra <jelle.zijlstra@gmail.com>2022-04-16 10:11:18 -0700
committerGitHub <noreply@github.com>2022-04-16 10:11:18 -0700
commit35dff91370a382e312dd53d002d948e3efedb317 (patch)
treed9f2470d05a3464ed8104f7aef29e5a877fe301e
parent2acaa5acd01aeabb295e961913c111a7df52656d (diff)
downloadtyping-35dff91370a382e312dd53d002d948e3efedb317.tar.gz
Add get_overloads() (#1140)
Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
-rw-r--r--typing_extensions/CHANGELOG5
-rw-r--r--typing_extensions/README.rst6
-rw-r--r--typing_extensions/src/test_typing_extensions.py74
-rw-r--r--typing_extensions/src/typing_extensions.py70
4 files changed, 152 insertions, 3 deletions
diff --git a/typing_extensions/CHANGELOG b/typing_extensions/CHANGELOG
index a9a5980..970bbd4 100644
--- a/typing_extensions/CHANGELOG
+++ b/typing_extensions/CHANGELOG
@@ -1,6 +1,9 @@
# Unreleased
-- Add `typing.assert_type`. Backport from bpo-46480.
+- Add `typing_extensions.get_overloads` and
+ `typing_extensions.clear_overloads`, and add registry support to
+ `typing_extensions.overload`. Backport from python/cpython#89263.
+- Add `typing_extensions.assert_type`. Backport from bpo-46480.
- Drop support for Python 3.6. Original patch by Adam Turner (@AA-Turner).
# Release 4.1.1 (February 13, 2022)
diff --git a/typing_extensions/README.rst b/typing_extensions/README.rst
index 9abed04..3a23b75 100644
--- a/typing_extensions/README.rst
+++ b/typing_extensions/README.rst
@@ -47,6 +47,8 @@ This module currently contains the following:
- ``assert_never``
- ``assert_type``
+ - ``clear_overloads``
+ - ``get_overloads``
- ``LiteralString`` (see PEP 675)
- ``Never``
- ``NotRequired`` (see PEP 655)
@@ -122,6 +124,10 @@ Certain objects were changed after they were added to ``typing``, and
Python 3.8 and lack support for ``ParamSpecArgs`` and ``ParamSpecKwargs``
in 3.9.
- ``@final`` was changed in Python 3.11 to set the ``.__final__`` attribute.
+- ``@overload`` was changed in Python 3.11 to make function overloads
+ introspectable at runtime. In order to access overloads with
+ ``typing_extensions.get_overloads()``, you must use
+ ``@typing_extensions.overload``.
There are a few types whose interface was modified between different
versions of typing. For example, ``typing.Sequence`` was modified to
diff --git a/typing_extensions/src/test_typing_extensions.py b/typing_extensions/src/test_typing_extensions.py
index 1439e51..ab03244 100644
--- a/typing_extensions/src/test_typing_extensions.py
+++ b/typing_extensions/src/test_typing_extensions.py
@@ -3,6 +3,7 @@ import os
import abc
import contextlib
import collections
+from collections import defaultdict
import collections.abc
from functools import lru_cache
import inspect
@@ -10,6 +11,7 @@ import pickle
import subprocess
import types
from unittest import TestCase, main, skipUnless, skipIf
+from unittest.mock import patch
from test import ann_module, ann_module2, ann_module3
import typing
from typing import TypeVar, Optional, Union, Any, AnyStr
@@ -21,9 +23,10 @@ import typing_extensions
from typing_extensions import NoReturn, ClassVar, Final, IntVar, Literal, Type, NewType, TypedDict, Self
from typing_extensions import TypeAlias, ParamSpec, Concatenate, ParamSpecArgs, ParamSpecKwargs, TypeGuard
from typing_extensions import Awaitable, AsyncIterator, AsyncContextManager, Required, NotRequired
-from typing_extensions import Protocol, runtime, runtime_checkable, Annotated, overload, final, is_typeddict
+from typing_extensions import Protocol, runtime, runtime_checkable, Annotated, final, is_typeddict
from typing_extensions import TypeVarTuple, Unpack, dataclass_transform, reveal_type, Never, assert_never, LiteralString
from typing_extensions import assert_type, get_type_hints, get_origin, get_args
+from typing_extensions import clear_overloads, get_overloads, overload
# Flags used to mark tests that only apply after a specific
# version of the typing module.
@@ -403,6 +406,20 @@ class LiteralTests(BaseTestCase):
Literal[1][1]
+class MethodHolder:
+ @classmethod
+ def clsmethod(cls): ...
+ @staticmethod
+ def stmethod(): ...
+ def method(self): ...
+
+
+if TYPING_3_11_0:
+ registry_holder = typing
+else:
+ registry_holder = typing_extensions
+
+
class OverloadTests(BaseTestCase):
def test_overload_fails(self):
@@ -424,6 +441,61 @@ class OverloadTests(BaseTestCase):
blah()
+ def set_up_overloads(self):
+ def blah():
+ pass
+
+ overload1 = blah
+ overload(blah)
+
+ def blah():
+ pass
+
+ overload2 = blah
+ overload(blah)
+
+ def blah():
+ pass
+
+ return blah, [overload1, overload2]
+
+ # Make sure we don't clear the global overload registry
+ @patch(
+ f"{registry_holder.__name__}._overload_registry",
+ defaultdict(lambda: defaultdict(dict))
+ )
+ def test_overload_registry(self):
+ registry = registry_holder._overload_registry
+ # The registry starts out empty
+ self.assertEqual(registry, {})
+
+ impl, overloads = self.set_up_overloads()
+ self.assertNotEqual(registry, {})
+ self.assertEqual(list(get_overloads(impl)), overloads)
+
+ def some_other_func(): pass
+ overload(some_other_func)
+ other_overload = some_other_func
+ def some_other_func(): pass
+ self.assertEqual(list(get_overloads(some_other_func)), [other_overload])
+
+ # Make sure that after we clear all overloads, the registry is
+ # completely empty.
+ clear_overloads()
+ self.assertEqual(registry, {})
+ self.assertEqual(get_overloads(impl), [])
+
+ # Querying a function with no overloads shouldn't change the registry.
+ def the_only_one(): pass
+ self.assertEqual(get_overloads(the_only_one), [])
+ self.assertEqual(registry, {})
+
+ def test_overload_registry_repeated(self):
+ for _ in range(2):
+ impl, overloads = self.set_up_overloads()
+
+ self.assertEqual(list(get_overloads(impl)), overloads)
+
class AssertTypeTests(BaseTestCase):
diff --git a/typing_extensions/src/typing_extensions.py b/typing_extensions/src/typing_extensions.py
index d5e4049..4911099 100644
--- a/typing_extensions/src/typing_extensions.py
+++ b/typing_extensions/src/typing_extensions.py
@@ -1,6 +1,7 @@
import abc
import collections
import collections.abc
+import functools
import operator
import sys
import types as _types
@@ -46,7 +47,9 @@ __all__ = [
'Annotated',
'assert_never',
'assert_type',
+ 'clear_overloads',
'dataclass_transform',
+ 'get_overloads',
'final',
'get_args',
'get_origin',
@@ -249,7 +252,72 @@ else:
_overload_dummy = typing._overload_dummy # noqa
-overload = typing.overload
+
+
+if hasattr(typing, "get_overloads"): # 3.11+
+ overload = typing.overload
+ get_overloads = typing.get_overloads
+ clear_overloads = typing.clear_overloads
+else:
+ # {module: {qualname: {firstlineno: func}}}
+ _overload_registry = collections.defaultdict(
+ functools.partial(collections.defaultdict, dict)
+ )
+
+ def overload(func):
+ """Decorator for overloaded functions/methods.
+
+ In a stub file, place two or more stub definitions for the same
+ function in a row, each decorated with @overload. For example:
+
+ @overload
+ def utf8(value: None) -> None: ...
+ @overload
+ def utf8(value: bytes) -> bytes: ...
+ @overload
+ def utf8(value: str) -> bytes: ...
+
+ In a non-stub file (i.e. a regular .py file), do the same but
+ follow it with an implementation. The implementation should *not*
+ be decorated with @overload. For example:
+
+ @overload
+ def utf8(value: None) -> None: ...
+ @overload
+ def utf8(value: bytes) -> bytes: ...
+ @overload
+ def utf8(value: str) -> bytes: ...
+ def utf8(value):
+ # implementation goes here
+
+ The overloads for a function can be retrieved at runtime using the
+ get_overloads() function.
+ """
+ # classmethod and staticmethod
+ f = getattr(func, "__func__", func)
+ try:
+ _overload_registry[f.__module__][f.__qualname__][
+ f.__code__.co_firstlineno
+ ] = func
+ except AttributeError:
+ # Not a normal function; ignore.
+ pass
+ return _overload_dummy
+
+ def get_overloads(func):
+ """Return all defined overloads for *func* as a sequence."""
+ # classmethod and staticmethod
+ f = getattr(func, "__func__", func)
+ if f.__module__ not in _overload_registry:
+ return []
+ mod_dict = _overload_registry[f.__module__]
+ if f.__qualname__ not in mod_dict:
+ return []
+ return list(mod_dict[f.__qualname__].values())
+
+ def clear_overloads():
+ """Clear all overloads in the registry."""
+ _overload_registry.clear()
# This is not a real generic class. Don't use outside annotations.