Major fixes and new features
continuous-integration/drone/push: Build is passing

2025-09-25 15:51:48 +09:00
parent dd7349bb4c
commit ddce9f5125
5586 changed files with 1470941 additions and 0 deletions

File diff suppressed because it is too large


@@ -0,0 +1,343 @@
from __future__ import annotations
from mypy.argmap import map_actuals_to_formals
from mypy.fixup import TypeFixer
from mypy.nodes import (
ARG_POS,
MDEF,
SYMBOL_FUNCBASE_TYPES,
Argument,
Block,
CallExpr,
ClassDef,
Decorator,
Expression,
FuncDef,
JsonDict,
NameExpr,
Node,
PassStmt,
RefExpr,
SymbolTableNode,
Var,
)
from mypy.plugin import CheckerPluginInterface, ClassDefContext, SemanticAnalyzerPluginInterface
from mypy.semanal_shared import (
ALLOW_INCOMPATIBLE_OVERRIDE,
parse_bool,
require_bool_literal_argument,
set_callable_name,
)
from mypy.typeops import try_getting_str_literals as try_getting_str_literals
from mypy.types import (
AnyType,
CallableType,
Instance,
LiteralType,
NoneType,
Overloaded,
Type,
TypeOfAny,
TypeType,
TypeVarType,
deserialize_type,
get_proper_type,
)
from mypy.types_utils import is_overlapping_none
from mypy.typevars import fill_typevars
from mypy.util import get_unique_redefinition_name
def _get_decorator_bool_argument(ctx: ClassDefContext, name: str, default: bool) -> bool:
"""Return the bool argument for the decorator.
This handles both @decorator(...) and @decorator.
"""
if isinstance(ctx.reason, CallExpr):
return _get_bool_argument(ctx, ctx.reason, name, default)
else:
return default
def _get_bool_argument(ctx: ClassDefContext, expr: CallExpr, name: str, default: bool) -> bool:
"""Return the boolean value for an argument to a call or the
default if it's not found.
"""
attr_value = _get_argument(expr, name)
if attr_value:
return require_bool_literal_argument(ctx.api, attr_value, name, default)
return default
def _get_argument(call: CallExpr, name: str) -> Expression | None:
"""Return the expression for the specific argument."""
# To do this we use the CallableType of the callee to find the FormalArgument,
# then walk the actual CallExpr looking for the appropriate argument.
#
# Note: I'm not hard-coding the index so that in the future we can support other
# attrib and class makers.
callee_type = _get_callee_type(call)
if not callee_type:
return None
argument = callee_type.argument_by_name(name)
if not argument:
return None
assert argument.name
for i, (attr_name, attr_value) in enumerate(zip(call.arg_names, call.args)):
if argument.pos is not None and not attr_name and i == argument.pos:
return attr_value
if attr_name == argument.name:
return attr_value
return None
def find_shallow_matching_overload_item(overload: Overloaded, call: CallExpr) -> CallableType:
"""Perform limited lookup of a matching overload item.
Full overload resolution is only supported during type checking, but plugins
sometimes need to resolve overloads. This can be used in some such use cases.
Resolve overloads based on these things only:
* Match using argument kinds and names
    * If formal argument has type None, only accept the "None" expression in the call
* If formal argument has type Literal[True] or Literal[False], only accept the
relevant bool literal
Return the first matching overload item, or the last one if nothing matches.
"""
for item in overload.items[:-1]:
ok = True
mapped = map_actuals_to_formals(
call.arg_kinds,
call.arg_names,
item.arg_kinds,
item.arg_names,
lambda i: AnyType(TypeOfAny.special_form),
)
# Look for extra actuals
matched_actuals = set()
for actuals in mapped:
matched_actuals.update(actuals)
if any(i not in matched_actuals for i in range(len(call.args))):
ok = False
for arg_type, kind, actuals in zip(item.arg_types, item.arg_kinds, mapped):
if kind.is_required() and not actuals:
# Missing required argument
ok = False
break
elif actuals:
args = [call.args[i] for i in actuals]
arg_type = get_proper_type(arg_type)
arg_none = any(isinstance(arg, NameExpr) and arg.name == "None" for arg in args)
if isinstance(arg_type, NoneType):
if not arg_none:
ok = False
break
elif (
arg_none
and not is_overlapping_none(arg_type)
and not (
isinstance(arg_type, Instance)
and arg_type.type.fullname == "builtins.object"
)
and not isinstance(arg_type, AnyType)
):
ok = False
break
elif isinstance(arg_type, LiteralType) and type(arg_type.value) is bool:
if not any(parse_bool(arg) == arg_type.value for arg in args):
ok = False
break
if ok:
return item
return overload.items[-1]
def _get_callee_type(call: CallExpr) -> CallableType | None:
"""Return the type of the callee, regardless of its syntatic form."""
callee_node: Node | None = call.callee
if isinstance(callee_node, RefExpr):
callee_node = callee_node.node
# Some decorators may be using typing.dataclass_transform, which is itself a decorator, so we
# need to unwrap them to get at the true callee
if isinstance(callee_node, Decorator):
callee_node = callee_node.func
if isinstance(callee_node, (Var, SYMBOL_FUNCBASE_TYPES)) and callee_node.type:
callee_node_type = get_proper_type(callee_node.type)
if isinstance(callee_node_type, Overloaded):
return find_shallow_matching_overload_item(callee_node_type, call)
elif isinstance(callee_node_type, CallableType):
return callee_node_type
return None
def add_method(
ctx: ClassDefContext,
name: str,
args: list[Argument],
return_type: Type,
self_type: Type | None = None,
tvar_def: TypeVarType | None = None,
is_classmethod: bool = False,
is_staticmethod: bool = False,
) -> None:
"""
Adds a new method to a class.
Deprecated, use add_method_to_class() instead.
"""
add_method_to_class(
ctx.api,
ctx.cls,
name=name,
args=args,
return_type=return_type,
self_type=self_type,
tvar_def=tvar_def,
is_classmethod=is_classmethod,
is_staticmethod=is_staticmethod,
)
def add_method_to_class(
api: SemanticAnalyzerPluginInterface | CheckerPluginInterface,
cls: ClassDef,
name: str,
args: list[Argument],
return_type: Type,
self_type: Type | None = None,
tvar_def: TypeVarType | None = None,
is_classmethod: bool = False,
is_staticmethod: bool = False,
) -> None:
"""Adds a new method to a class definition."""
assert not (
is_classmethod is True and is_staticmethod is True
), "Can't add a new method that's both staticmethod and classmethod."
info = cls.info
# First remove any previously generated methods with the same name
# to avoid clashes and problems in the semantic analyzer.
if name in info.names:
sym = info.names[name]
if sym.plugin_generated and isinstance(sym.node, FuncDef):
cls.defs.body.remove(sym.node)
if isinstance(api, SemanticAnalyzerPluginInterface):
function_type = api.named_type("builtins.function")
else:
function_type = api.named_generic_type("builtins.function", [])
if is_classmethod:
self_type = self_type or TypeType(fill_typevars(info))
first = [Argument(Var("_cls"), self_type, None, ARG_POS, True)]
elif is_staticmethod:
first = []
else:
self_type = self_type or fill_typevars(info)
first = [Argument(Var("self"), self_type, None, ARG_POS)]
args = first + args
arg_types, arg_names, arg_kinds = [], [], []
for arg in args:
assert arg.type_annotation, "All arguments must be fully typed."
arg_types.append(arg.type_annotation)
arg_names.append(arg.variable.name)
arg_kinds.append(arg.kind)
signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type)
if tvar_def:
signature.variables = [tvar_def]
func = FuncDef(name, args, Block([PassStmt()]))
func.info = info
func.type = set_callable_name(signature, func)
func.is_class = is_classmethod
func.is_static = is_staticmethod
func._fullname = info.fullname + "." + name
func.line = info.line
# NOTE: we would like the plugin generated node to dominate, but we still
# need to keep any existing definitions so they get semantically analyzed.
if name in info.names:
# Get a nice unique name instead.
r_name = get_unique_redefinition_name(name, info.names)
info.names[r_name] = info.names[name]
# Add decorator for is_staticmethod. It's unnecessary for is_classmethod.
if is_staticmethod:
func.is_decorated = True
v = Var(name, func.type)
v.info = info
v._fullname = func._fullname
v.is_staticmethod = True
dec = Decorator(func, [], v)
dec.line = info.line
sym = SymbolTableNode(MDEF, dec)
else:
sym = SymbolTableNode(MDEF, func)
sym.plugin_generated = True
info.names[name] = sym
info.defn.defs.body.append(func)
def add_attribute_to_class(
api: SemanticAnalyzerPluginInterface,
cls: ClassDef,
name: str,
typ: Type,
final: bool = False,
no_serialize: bool = False,
override_allow_incompatible: bool = False,
fullname: str | None = None,
is_classvar: bool = False,
) -> None:
"""
Adds a new attribute to a class definition.
This currently only generates the symbol table entry and no corresponding AssignmentStatement
"""
info = cls.info
# NOTE: we would like the plugin generated node to dominate, but we still
# need to keep any existing definitions so they get semantically analyzed.
if name in info.names:
# Get a nice unique name instead.
r_name = get_unique_redefinition_name(name, info.names)
info.names[r_name] = info.names[name]
node = Var(name, typ)
node.info = info
node.is_final = final
node.is_classvar = is_classvar
if name in ALLOW_INCOMPATIBLE_OVERRIDE:
node.allow_incompatible_override = True
else:
node.allow_incompatible_override = override_allow_incompatible
if fullname:
node._fullname = fullname
else:
node._fullname = info.fullname + "." + name
info.names[name] = SymbolTableNode(
MDEF, node, plugin_generated=True, no_serialize=no_serialize
)
def deserialize_and_fixup_type(data: str | JsonDict, api: SemanticAnalyzerPluginInterface) -> Type:
typ = deserialize_type(data)
typ.accept(TypeFixer(api.modules, allow_missing=False))
return typ
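
As a rough illustration of how the helpers above are meant to be used (this sketch is not part of the commit), a plugin might call add_method_to_class() and add_attribute_to_class() from a class-decorator hook. The decorator fullname "example.auto_repr" and the attribute name are hypothetical.

from __future__ import annotations

from typing import Callable

from mypy.plugin import ClassDefContext, Plugin
from mypy.plugins.common import add_attribute_to_class, add_method_to_class


class AutoReprPlugin(Plugin):
    def get_class_decorator_hook_2(
        self, fullname: str
    ) -> Callable[[ClassDefContext], bool] | None:
        if fullname == "example.auto_repr":  # hypothetical decorator fullname
            return self._apply
        return None

    def _apply(self, ctx: ClassDefContext) -> bool:
        str_type = ctx.api.named_type("builtins.str")
        int_type = ctx.api.named_type("builtins.int")
        # Plugin-generated __repr__(self) -> str; no extra arguments beyond "self".
        add_method_to_class(ctx.api, ctx.cls, "__repr__", args=[], return_type=str_type)
        # Plugin-generated attribute; only the symbol table entry is created.
        add_attribute_to_class(ctx.api, ctx.cls, "repr_width", int_type)
        return True


def plugin(version: str) -> type[Plugin]:
    return AutoReprPlugin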


@@ -0,0 +1,245 @@
"""Plugin to provide accurate types for some parts of the ctypes module."""
from __future__ import annotations
# Fully qualified instead of "from mypy.plugin import ..." to avoid circular import problems.
import mypy.plugin
from mypy import nodes
from mypy.maptype import map_instance_to_supertype
from mypy.messages import format_type
from mypy.subtypes import is_subtype
from mypy.typeops import make_simplified_union
from mypy.types import (
AnyType,
CallableType,
Instance,
NoneType,
ProperType,
Type,
TypeOfAny,
UnionType,
flatten_nested_unions,
get_proper_type,
)
def _find_simplecdata_base_arg(
tp: Instance, api: mypy.plugin.CheckerPluginInterface
) -> ProperType | None:
"""Try to find a parametrized _SimpleCData in tp's bases and return its single type argument.
None is returned if _SimpleCData appears nowhere in tp's (direct or indirect) bases.
"""
if tp.type.has_base("_ctypes._SimpleCData"):
simplecdata_base = map_instance_to_supertype(
tp,
api.named_generic_type("_ctypes._SimpleCData", [AnyType(TypeOfAny.special_form)]).type,
)
assert len(simplecdata_base.args) == 1, "_SimpleCData takes exactly one type argument"
return get_proper_type(simplecdata_base.args[0])
return None
def _autoconvertible_to_cdata(tp: Type, api: mypy.plugin.CheckerPluginInterface) -> Type:
"""Get a type that is compatible with all types that can be implicitly converted to the given
CData type.
Examples:
* c_int -> Union[c_int, int]
* c_char_p -> Union[c_char_p, bytes, int, NoneType]
* MyStructure -> MyStructure
"""
allowed_types = []
# If tp is a union, we allow all types that are convertible to at least one of the union
# items. This is not quite correct - strictly speaking, only types convertible to *all* of the
# union items should be allowed. This may be worth changing in the future, but the more
# correct algorithm could be too strict to be useful.
for t in flatten_nested_unions([tp]):
t = get_proper_type(t)
# Every type can be converted from itself (obviously).
allowed_types.append(t)
if isinstance(t, Instance):
unboxed = _find_simplecdata_base_arg(t, api)
if unboxed is not None:
# If _SimpleCData appears in tp's (direct or indirect) bases, its type argument
# specifies the type's "unboxed" version, which can always be converted back to
# the original "boxed" type.
allowed_types.append(unboxed)
if t.type.has_base("ctypes._PointerLike"):
# Pointer-like _SimpleCData subclasses can also be converted from
# an int or None.
allowed_types.append(api.named_generic_type("builtins.int", []))
allowed_types.append(NoneType())
return make_simplified_union(allowed_types)
def _autounboxed_cdata(tp: Type) -> ProperType:
"""Get the auto-unboxed version of a CData type, if applicable.
For *direct* _SimpleCData subclasses, the only type argument of _SimpleCData in the bases list
is returned.
For all other CData types, including indirect _SimpleCData subclasses, tp is returned as-is.
"""
tp = get_proper_type(tp)
if isinstance(tp, UnionType):
return make_simplified_union([_autounboxed_cdata(t) for t in tp.items])
elif isinstance(tp, Instance):
for base in tp.type.bases:
if base.type.fullname == "_ctypes._SimpleCData":
# If tp has _SimpleCData as a direct base class,
# the auto-unboxed type is the single type argument of the _SimpleCData type.
assert len(base.args) == 1
return get_proper_type(base.args[0])
# If tp is not a concrete type, or if there is no _SimpleCData in the bases,
# the type is not auto-unboxed.
return tp
def _get_array_element_type(tp: Type) -> ProperType | None:
"""Get the element type of the Array type tp, or None if not specified."""
tp = get_proper_type(tp)
if isinstance(tp, Instance):
assert tp.type.fullname == "_ctypes.Array"
if len(tp.args) == 1:
return get_proper_type(tp.args[0])
return None
def array_constructor_callback(ctx: mypy.plugin.FunctionContext) -> Type:
"""Callback to provide an accurate signature for the ctypes.Array constructor."""
    # Extract the element type from the constructor's return type, i.e. the type of the array
# being constructed.
et = _get_array_element_type(ctx.default_return_type)
if et is not None:
allowed = _autoconvertible_to_cdata(et, ctx.api)
assert (
len(ctx.arg_types) == 1
), "The stub of the ctypes.Array constructor should have a single vararg parameter"
for arg_num, (arg_kind, arg_type) in enumerate(zip(ctx.arg_kinds[0], ctx.arg_types[0]), 1):
if arg_kind == nodes.ARG_POS and not is_subtype(arg_type, allowed):
ctx.api.msg.fail(
"Array constructor argument {} of type {}"
" is not convertible to the array element type {}".format(
arg_num,
format_type(arg_type, ctx.api.options),
format_type(et, ctx.api.options),
),
ctx.context,
)
elif arg_kind == nodes.ARG_STAR:
ty = ctx.api.named_generic_type("typing.Iterable", [allowed])
if not is_subtype(arg_type, ty):
it = ctx.api.named_generic_type("typing.Iterable", [et])
ctx.api.msg.fail(
"Array constructor argument {} of type {}"
" is not convertible to the array element type {}".format(
arg_num,
format_type(arg_type, ctx.api.options),
format_type(it, ctx.api.options),
),
ctx.context,
)
return ctx.default_return_type
def array_getitem_callback(ctx: mypy.plugin.MethodContext) -> Type:
"""Callback to provide an accurate return type for ctypes.Array.__getitem__."""
et = _get_array_element_type(ctx.type)
if et is not None:
unboxed = _autounboxed_cdata(et)
assert (
len(ctx.arg_types) == 1
), "The stub of ctypes.Array.__getitem__ should have exactly one parameter"
assert (
len(ctx.arg_types[0]) == 1
), "ctypes.Array.__getitem__'s parameter should not be variadic"
index_type = get_proper_type(ctx.arg_types[0][0])
if isinstance(index_type, Instance):
if index_type.type.has_base("builtins.int"):
return unboxed
elif index_type.type.has_base("builtins.slice"):
return ctx.api.named_generic_type("builtins.list", [unboxed])
return ctx.default_return_type
def array_setitem_callback(ctx: mypy.plugin.MethodSigContext) -> CallableType:
"""Callback to provide an accurate signature for ctypes.Array.__setitem__."""
et = _get_array_element_type(ctx.type)
if et is not None:
allowed = _autoconvertible_to_cdata(et, ctx.api)
assert len(ctx.default_signature.arg_types) == 2
index_type = get_proper_type(ctx.default_signature.arg_types[0])
if isinstance(index_type, Instance):
arg_type = None
if index_type.type.has_base("builtins.int"):
arg_type = allowed
elif index_type.type.has_base("builtins.slice"):
arg_type = ctx.api.named_generic_type("builtins.list", [allowed])
if arg_type is not None:
# Note: arg_type can only be None if index_type is invalid, in which case we use
# the default signature and let mypy report an error about it.
return ctx.default_signature.copy_modified(
arg_types=ctx.default_signature.arg_types[:1] + [arg_type]
)
return ctx.default_signature
def array_iter_callback(ctx: mypy.plugin.MethodContext) -> Type:
"""Callback to provide an accurate return type for ctypes.Array.__iter__."""
et = _get_array_element_type(ctx.type)
if et is not None:
unboxed = _autounboxed_cdata(et)
return ctx.api.named_generic_type("typing.Iterator", [unboxed])
return ctx.default_return_type
def array_value_callback(ctx: mypy.plugin.AttributeContext) -> Type:
"""Callback to provide an accurate type for ctypes.Array.value."""
et = _get_array_element_type(ctx.type)
if et is not None:
types: list[Type] = []
for tp in flatten_nested_unions([et]):
tp = get_proper_type(tp)
if isinstance(tp, AnyType):
types.append(AnyType(TypeOfAny.from_another_any, source_any=tp))
elif isinstance(tp, Instance) and tp.type.fullname == "ctypes.c_char":
types.append(ctx.api.named_generic_type("builtins.bytes", []))
elif isinstance(tp, Instance) and tp.type.fullname == "ctypes.c_wchar":
types.append(ctx.api.named_generic_type("builtins.str", []))
else:
ctx.api.msg.fail(
'Array attribute "value" is only available'
' with element type "c_char" or "c_wchar", not {}'.format(
format_type(et, ctx.api.options)
),
ctx.context,
)
return make_simplified_union(types)
return ctx.default_attr_type
def array_raw_callback(ctx: mypy.plugin.AttributeContext) -> Type:
"""Callback to provide an accurate type for ctypes.Array.raw."""
et = _get_array_element_type(ctx.type)
if et is not None:
types: list[Type] = []
for tp in flatten_nested_unions([et]):
tp = get_proper_type(tp)
if (
isinstance(tp, AnyType)
or isinstance(tp, Instance)
and tp.type.fullname == "ctypes.c_char"
):
types.append(ctx.api.named_generic_type("builtins.bytes", []))
else:
ctx.api.msg.fail(
'Array attribute "raw" is only available'
' with element type "c_char", not {}'.format(format_type(et, ctx.api.options)),
ctx.context,
)
return make_simplified_union(types)
return ctx.default_attr_type
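
For orientation (not part of the commit), a sketch of the ctypes code these callbacks refine when the file is checked by mypy; the comments describe the intended inference rather than literal mypy output.

import ctypes

arr = (ctypes.c_int * 3)(1, 2, 3)  # array_constructor_callback checks every positional argument
reveal_type(arr[0])      # auto-unboxed element type, i.e. int (array_getitem_callback)
reveal_type(arr[0:2])    # list of the unboxed type, i.e. list[int]
for item in arr:         # array_iter_callback yields the unboxed type
    reveal_type(item)    # int
arr[0] = 5               # accepted: int is auto-convertible to c_int (array_setitem_callback)
arr[1] = "text"          # rejected: str is not convertible to the element type

buf = (ctypes.c_char * 2)(b"a", b"b")
reveal_type(buf.value)   # bytes (array_value_callback)
reveal_type(buf.raw)     # bytes (array_raw_callback)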

File diff suppressed because it is too large


@@ -0,0 +1,504 @@
from __future__ import annotations
from functools import partial
from typing import Callable
import mypy.errorcodes as codes
from mypy import message_registry
from mypy.nodes import DictExpr, IntExpr, StrExpr, UnaryExpr
from mypy.plugin import (
AttributeContext,
ClassDefContext,
FunctionContext,
FunctionSigContext,
MethodContext,
MethodSigContext,
Plugin,
)
from mypy.plugins.common import try_getting_str_literals
from mypy.subtypes import is_subtype
from mypy.typeops import is_literal_type_like, make_simplified_union
from mypy.types import (
TPDICT_FB_NAMES,
AnyType,
CallableType,
FunctionLike,
Instance,
LiteralType,
NoneType,
TupleType,
Type,
TypedDictType,
TypeOfAny,
TypeVarType,
UnionType,
get_proper_type,
get_proper_types,
)
class DefaultPlugin(Plugin):
"""Type checker plugin that is enabled by default."""
def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
from mypy.plugins import ctypes, singledispatch
if fullname == "_ctypes.Array":
return ctypes.array_constructor_callback
elif fullname == "functools.singledispatch":
return singledispatch.create_singledispatch_function_callback
return None
def get_function_signature_hook(
self, fullname: str
) -> Callable[[FunctionSigContext], FunctionLike] | None:
from mypy.plugins import attrs, dataclasses
if fullname in ("attr.evolve", "attrs.evolve", "attr.assoc", "attrs.assoc"):
return attrs.evolve_function_sig_callback
elif fullname in ("attr.fields", "attrs.fields"):
return attrs.fields_function_sig_callback
elif fullname == "dataclasses.replace":
return dataclasses.replace_function_sig_callback
return None
def get_method_signature_hook(
self, fullname: str
) -> Callable[[MethodSigContext], FunctionLike] | None:
from mypy.plugins import ctypes, singledispatch
if fullname == "typing.Mapping.get":
return typed_dict_get_signature_callback
elif fullname in {n + ".setdefault" for n in TPDICT_FB_NAMES}:
return typed_dict_setdefault_signature_callback
elif fullname in {n + ".pop" for n in TPDICT_FB_NAMES}:
return typed_dict_pop_signature_callback
elif fullname in {n + ".update" for n in TPDICT_FB_NAMES}:
return typed_dict_update_signature_callback
elif fullname == "_ctypes.Array.__setitem__":
return ctypes.array_setitem_callback
elif fullname == singledispatch.SINGLEDISPATCH_CALLABLE_CALL_METHOD:
return singledispatch.call_singledispatch_function_callback
return None
def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | None:
from mypy.plugins import ctypes, singledispatch
if fullname == "typing.Mapping.get":
return typed_dict_get_callback
elif fullname == "builtins.int.__pow__":
return int_pow_callback
elif fullname == "builtins.int.__neg__":
return int_neg_callback
elif fullname in ("builtins.tuple.__mul__", "builtins.tuple.__rmul__"):
return tuple_mul_callback
elif fullname in {n + ".setdefault" for n in TPDICT_FB_NAMES}:
return typed_dict_setdefault_callback
elif fullname in {n + ".pop" for n in TPDICT_FB_NAMES}:
return typed_dict_pop_callback
elif fullname in {n + ".__delitem__" for n in TPDICT_FB_NAMES}:
return typed_dict_delitem_callback
elif fullname == "_ctypes.Array.__getitem__":
return ctypes.array_getitem_callback
elif fullname == "_ctypes.Array.__iter__":
return ctypes.array_iter_callback
elif fullname == singledispatch.SINGLEDISPATCH_REGISTER_METHOD:
return singledispatch.singledispatch_register_callback
elif fullname == singledispatch.REGISTER_CALLABLE_CALL_METHOD:
return singledispatch.call_singledispatch_function_after_register_argument
return None
def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
from mypy.plugins import ctypes, enums
if fullname == "_ctypes.Array.value":
return ctypes.array_value_callback
elif fullname == "_ctypes.Array.raw":
return ctypes.array_raw_callback
elif fullname in enums.ENUM_NAME_ACCESS:
return enums.enum_name_callback
elif fullname in enums.ENUM_VALUE_ACCESS:
return enums.enum_value_callback
return None
def get_class_decorator_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
from mypy.plugins import attrs, dataclasses
# These dataclass and attrs hooks run in the main semantic analysis pass
# and only tag known dataclasses/attrs classes, so that the second
# hooks (in get_class_decorator_hook_2) can detect dataclasses/attrs classes
# in the MRO.
if fullname in dataclasses.dataclass_makers:
return dataclasses.dataclass_tag_callback
if (
fullname in attrs.attr_class_makers
or fullname in attrs.attr_dataclass_makers
or fullname in attrs.attr_frozen_makers
or fullname in attrs.attr_define_makers
):
return attrs.attr_tag_callback
return None
def get_class_decorator_hook_2(
self, fullname: str
) -> Callable[[ClassDefContext], bool] | None:
from mypy.plugins import attrs, dataclasses, functools
if fullname in dataclasses.dataclass_makers:
return dataclasses.dataclass_class_maker_callback
elif fullname in functools.functools_total_ordering_makers:
return functools.functools_total_ordering_maker_callback
elif fullname in attrs.attr_class_makers:
return attrs.attr_class_maker_callback
elif fullname in attrs.attr_dataclass_makers:
return partial(attrs.attr_class_maker_callback, auto_attribs_default=True)
elif fullname in attrs.attr_frozen_makers:
return partial(
attrs.attr_class_maker_callback, auto_attribs_default=None, frozen_default=True
)
elif fullname in attrs.attr_define_makers:
return partial(
attrs.attr_class_maker_callback, auto_attribs_default=None, slots_default=True
)
return None
def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType:
"""Try to infer a better signature type for TypedDict.get.
This is used to get better type context for the second argument that
depends on a TypedDict value type.
"""
signature = ctx.default_signature
if (
isinstance(ctx.type, TypedDictType)
and len(ctx.args) == 2
and len(ctx.args[0]) == 1
and isinstance(ctx.args[0][0], StrExpr)
and len(signature.arg_types) == 2
and len(signature.variables) == 1
and len(ctx.args[1]) == 1
):
key = ctx.args[0][0].value
value_type = get_proper_type(ctx.type.items.get(key))
ret_type = signature.ret_type
if value_type:
default_arg = ctx.args[1][0]
if (
isinstance(value_type, TypedDictType)
and isinstance(default_arg, DictExpr)
and len(default_arg.items) == 0
):
# Caller has empty dict {} as default for typed dict.
value_type = value_type.copy_modified(required_keys=set())
# Tweak the signature to include the value type as context. It's
# only needed for type inference since there's a union with a type
# variable that accepts everything.
tv = signature.variables[0]
assert isinstance(tv, TypeVarType)
return signature.copy_modified(
arg_types=[signature.arg_types[0], make_simplified_union([value_type, tv])],
ret_type=ret_type,
)
return signature
def typed_dict_get_callback(ctx: MethodContext) -> Type:
"""Infer a precise return type for TypedDict.get with literal first argument."""
if (
isinstance(ctx.type, TypedDictType)
and len(ctx.arg_types) >= 1
and len(ctx.arg_types[0]) == 1
):
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
if keys is None:
return ctx.default_return_type
output_types: list[Type] = []
for key in keys:
value_type = get_proper_type(ctx.type.items.get(key))
if value_type is None:
return ctx.default_return_type
if len(ctx.arg_types) == 1:
output_types.append(value_type)
elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
default_arg = ctx.args[1][0]
if (
isinstance(default_arg, DictExpr)
and len(default_arg.items) == 0
and isinstance(value_type, TypedDictType)
):
# Special case '{}' as the default for a typed dict type.
output_types.append(value_type.copy_modified(required_keys=set()))
else:
output_types.append(value_type)
output_types.append(ctx.arg_types[1][0])
if len(ctx.arg_types) == 1:
output_types.append(NoneType())
return make_simplified_union(output_types)
return ctx.default_return_type
def typed_dict_pop_signature_callback(ctx: MethodSigContext) -> CallableType:
"""Try to infer a better signature type for TypedDict.pop.
This is used to get better type context for the second argument that
depends on a TypedDict value type.
"""
signature = ctx.default_signature
str_type = ctx.api.named_generic_type("builtins.str", [])
if (
isinstance(ctx.type, TypedDictType)
and len(ctx.args) == 2
and len(ctx.args[0]) == 1
and isinstance(ctx.args[0][0], StrExpr)
and len(signature.arg_types) == 2
and len(signature.variables) == 1
and len(ctx.args[1]) == 1
):
key = ctx.args[0][0].value
value_type = ctx.type.items.get(key)
if value_type:
# Tweak the signature to include the value type as context. It's
# only needed for type inference since there's a union with a type
# variable that accepts everything.
tv = signature.variables[0]
assert isinstance(tv, TypeVarType)
typ = make_simplified_union([value_type, tv])
return signature.copy_modified(arg_types=[str_type, typ], ret_type=typ)
return signature.copy_modified(arg_types=[str_type, signature.arg_types[1]])
def typed_dict_pop_callback(ctx: MethodContext) -> Type:
"""Type check and infer a precise return type for TypedDict.pop."""
if (
isinstance(ctx.type, TypedDictType)
and len(ctx.arg_types) >= 1
and len(ctx.arg_types[0]) == 1
):
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
if keys is None:
ctx.api.fail(
message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
ctx.context,
code=codes.LITERAL_REQ,
)
return AnyType(TypeOfAny.from_error)
value_types = []
for key in keys:
if key in ctx.type.required_keys:
ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
value_type = ctx.type.items.get(key)
if value_type:
value_types.append(value_type)
else:
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
return AnyType(TypeOfAny.from_error)
if len(ctx.args[1]) == 0:
return make_simplified_union(value_types)
elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
return make_simplified_union([*value_types, ctx.arg_types[1][0]])
return ctx.default_return_type
def typed_dict_setdefault_signature_callback(ctx: MethodSigContext) -> CallableType:
"""Try to infer a better signature type for TypedDict.setdefault.
This is used to get better type context for the second argument that
depends on a TypedDict value type.
"""
signature = ctx.default_signature
str_type = ctx.api.named_generic_type("builtins.str", [])
if (
isinstance(ctx.type, TypedDictType)
and len(ctx.args) == 2
and len(ctx.args[0]) == 1
and isinstance(ctx.args[0][0], StrExpr)
and len(signature.arg_types) == 2
and len(ctx.args[1]) == 1
):
key = ctx.args[0][0].value
value_type = ctx.type.items.get(key)
if value_type:
return signature.copy_modified(arg_types=[str_type, value_type])
return signature.copy_modified(arg_types=[str_type, signature.arg_types[1]])
def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
"""Type check TypedDict.setdefault and infer a precise return type."""
if (
isinstance(ctx.type, TypedDictType)
and len(ctx.arg_types) == 2
and len(ctx.arg_types[0]) == 1
and len(ctx.arg_types[1]) == 1
):
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
if keys is None:
ctx.api.fail(
message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
ctx.context,
code=codes.LITERAL_REQ,
)
return AnyType(TypeOfAny.from_error)
default_type = ctx.arg_types[1][0]
value_types = []
for key in keys:
value_type = ctx.type.items.get(key)
if value_type is None:
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
return AnyType(TypeOfAny.from_error)
# The signature_callback above can't always infer the right signature
# (e.g. when the expression is a variable that happens to be a Literal str)
# so we need to handle the check ourselves here and make sure the provided
# default can be assigned to all key-value pairs we're updating.
if not is_subtype(default_type, value_type):
ctx.api.msg.typeddict_setdefault_arguments_inconsistent(
default_type, value_type, ctx.context
)
return AnyType(TypeOfAny.from_error)
value_types.append(value_type)
return make_simplified_union(value_types)
return ctx.default_return_type
def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
"""Type check TypedDict.__delitem__."""
if (
isinstance(ctx.type, TypedDictType)
and len(ctx.arg_types) == 1
and len(ctx.arg_types[0]) == 1
):
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
if keys is None:
ctx.api.fail(
message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
ctx.context,
code=codes.LITERAL_REQ,
)
return AnyType(TypeOfAny.from_error)
for key in keys:
if key in ctx.type.required_keys:
ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
elif key not in ctx.type.items:
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
return ctx.default_return_type
def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType:
"""Try to infer a better signature type for TypedDict.update."""
signature = ctx.default_signature
if isinstance(ctx.type, TypedDictType) and len(signature.arg_types) == 1:
arg_type = get_proper_type(signature.arg_types[0])
assert isinstance(arg_type, TypedDictType)
arg_type = arg_type.as_anonymous()
arg_type = arg_type.copy_modified(required_keys=set())
if ctx.args and ctx.args[0]:
with ctx.api.msg.filter_errors():
inferred = get_proper_type(
ctx.api.get_expression_type(ctx.args[0][0], type_context=arg_type)
)
possible_tds = []
if isinstance(inferred, TypedDictType):
possible_tds = [inferred]
elif isinstance(inferred, UnionType):
possible_tds = [
t
for t in get_proper_types(inferred.relevant_items())
if isinstance(t, TypedDictType)
]
items = []
for td in possible_tds:
item = arg_type.copy_modified(
required_keys=(arg_type.required_keys | td.required_keys)
& arg_type.items.keys()
)
if not ctx.api.options.extra_checks:
item = item.copy_modified(item_names=list(td.items))
items.append(item)
if items:
arg_type = make_simplified_union(items)
return signature.copy_modified(arg_types=[arg_type])
return signature
def int_pow_callback(ctx: MethodContext) -> Type:
"""Infer a more precise return type for int.__pow__."""
# int.__pow__ has an optional modulo argument,
# so we expect 2 argument positions
if len(ctx.arg_types) == 2 and len(ctx.arg_types[0]) == 1 and len(ctx.arg_types[1]) == 0:
arg = ctx.args[0][0]
if isinstance(arg, IntExpr):
exponent = arg.value
elif isinstance(arg, UnaryExpr) and arg.op == "-" and isinstance(arg.expr, IntExpr):
exponent = -arg.expr.value
else:
# Right operand not an int literal or a negated literal -- give up.
return ctx.default_return_type
if exponent >= 0:
return ctx.api.named_generic_type("builtins.int", [])
else:
return ctx.api.named_generic_type("builtins.float", [])
return ctx.default_return_type
def int_neg_callback(ctx: MethodContext) -> Type:
"""Infer a more precise return type for int.__neg__.
This is mainly used to infer the return type as LiteralType
if the original underlying object is a LiteralType object
"""
if isinstance(ctx.type, Instance) and ctx.type.last_known_value is not None:
value = ctx.type.last_known_value.value
fallback = ctx.type.last_known_value.fallback
if isinstance(value, int):
if is_literal_type_like(ctx.api.type_context[-1]):
return LiteralType(value=-value, fallback=fallback)
else:
return ctx.type.copy_modified(
last_known_value=LiteralType(
value=-value, fallback=ctx.type, line=ctx.type.line, column=ctx.type.column
)
)
elif isinstance(ctx.type, LiteralType):
value = ctx.type.value
fallback = ctx.type.fallback
if isinstance(value, int):
return LiteralType(value=-value, fallback=fallback)
return ctx.default_return_type
def tuple_mul_callback(ctx: MethodContext) -> Type:
"""Infer a more precise return type for tuple.__mul__ and tuple.__rmul__.
This is used to return a specific sized tuple if multiplied by Literal int
"""
if not isinstance(ctx.type, TupleType):
return ctx.default_return_type
arg_type = get_proper_type(ctx.arg_types[0][0])
if isinstance(arg_type, Instance) and arg_type.last_known_value is not None:
value = arg_type.last_known_value.value
if isinstance(value, int):
return ctx.type.copy_modified(items=ctx.type.items * value)
elif isinstance(ctx.type, LiteralType):
value = arg_type.value
if isinstance(value, int):
return ctx.type.copy_modified(items=ctx.type.items * value)
return ctx.default_return_type
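
An illustrative snippet (not part of the commit) of the inferences the default-plugin callbacks above produce when checked by mypy; the comments are approximate descriptions, not exact reveal_type output.

from typing import Final, TypedDict


class Movie(TypedDict):
    name: str
    year: int


m: Movie = {"name": "Blade Runner", "year": 1982}
reveal_type(m.get("year"))                # int | None  (typed_dict_get_callback)
reveal_type(m.get("year", 0))             # int
reveal_type(m.setdefault("name", "n/a"))  # str         (typed_dict_setdefault_callback)
m.setdefault("year", "1982")              # error: default is not assignable to int

reveal_type(2**3)                         # int         (int_pow_callback)
reveal_type(2**-1)                        # float
five: Final = 5
reveal_type(-five)                        # stays a literal, roughly Literal[-5] (int_neg_callback)
reveal_type((1, "a") * 2)                 # tuple[int, str, int, str] (tuple_mul_callback)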


@@ -0,0 +1,258 @@
"""
This file contains a variety of plugins for refining how mypy infers types of
expressions involving Enums.
Currently, this file focuses on providing better inference for expressions like
'SomeEnum.FOO.name' and 'SomeEnum.FOO.value'. Note that the type of both expressions
will vary depending on exactly which instance of SomeEnum we're looking at.
Note that this file does *not* contain all special-cased logic related to enums:
we actually bake some of it directly in to the semantic analysis layer (see
semanal_enum.py).
"""
from __future__ import annotations
from typing import Final, Iterable, Sequence, TypeVar, cast
import mypy.plugin # To avoid circular imports.
from mypy.nodes import TypeInfo
from mypy.semanal_enum import ENUM_BASES
from mypy.subtypes import is_equivalent
from mypy.typeops import fixup_partial_type, make_simplified_union
from mypy.types import CallableType, Instance, LiteralType, ProperType, Type, get_proper_type
ENUM_NAME_ACCESS: Final = {f"{prefix}.name" for prefix in ENUM_BASES} | {
f"{prefix}._name_" for prefix in ENUM_BASES
}
ENUM_VALUE_ACCESS: Final = {f"{prefix}.value" for prefix in ENUM_BASES} | {
f"{prefix}._value_" for prefix in ENUM_BASES
}
def enum_name_callback(ctx: mypy.plugin.AttributeContext) -> Type:
"""This plugin refines the 'name' attribute in enums to act as if
they were declared to be final.
For example, the expression 'MyEnum.FOO.name' normally is inferred
to be of type 'str'.
This plugin will instead make the inferred type be a 'str' where the
last known value is 'Literal["FOO"]'. This means it would be legal to
use 'MyEnum.FOO.name' in contexts that expect a Literal type, just like
any other Final variable or attribute.
This plugin assumes that the provided context is an attribute access
matching one of the strings found in 'ENUM_NAME_ACCESS'.
"""
enum_field_name = _extract_underlying_field_name(ctx.type)
if enum_field_name is None:
return ctx.default_attr_type
else:
str_type = ctx.api.named_generic_type("builtins.str", [])
literal_type = LiteralType(enum_field_name, fallback=str_type)
return str_type.copy_modified(last_known_value=literal_type)
_T = TypeVar("_T")
def _first(it: Iterable[_T]) -> _T | None:
"""Return the first value from any iterable.
Returns ``None`` if the iterable is empty.
"""
for val in it:
return val
return None
def _infer_value_type_with_auto_fallback(
ctx: mypy.plugin.AttributeContext, proper_type: ProperType | None
) -> Type | None:
"""Figure out the type of an enum value accounting for `auto()`.
This method is a no-op for a `None` proper_type and also in the case where
the type is not "enum.auto"
"""
if proper_type is None:
return None
proper_type = get_proper_type(fixup_partial_type(proper_type))
if not (isinstance(proper_type, Instance) and proper_type.type.fullname == "enum.auto"):
return proper_type
assert isinstance(ctx.type, Instance), "An incorrect ctx.type was passed."
info = ctx.type.type
# Find the first _generate_next_value_ on the mro. We need to know
# if it is `Enum` because `Enum` types say that the return-value of
# `_generate_next_value_` is `Any`. In reality the default `auto()`
# returns an `int` (presumably the `Any` in typeshed is to make it
# easier to subclass and change the returned type).
type_with_gnv = _first(ti for ti in info.mro if ti.names.get("_generate_next_value_"))
if type_with_gnv is None:
return ctx.default_attr_type
stnode = type_with_gnv.names["_generate_next_value_"]
# This should be a `CallableType`
node_type = get_proper_type(stnode.type)
if isinstance(node_type, CallableType):
if type_with_gnv.fullname == "enum.Enum":
int_type = ctx.api.named_generic_type("builtins.int", [])
return int_type
return get_proper_type(node_type.ret_type)
return ctx.default_attr_type
def _implements_new(info: TypeInfo) -> bool:
"""Check whether __new__ comes from enum.Enum or was implemented in a
subclass. In the latter case, we must infer Any as long as mypy can't infer
the type of _value_ from assignments in __new__.
"""
type_with_new = _first(
ti
for ti in info.mro
if ti.names.get("__new__") and not ti.fullname.startswith("builtins.")
)
if type_with_new is None:
return False
return type_with_new.fullname not in ("enum.Enum", "enum.IntEnum", "enum.StrEnum")
def enum_value_callback(ctx: mypy.plugin.AttributeContext) -> Type:
"""This plugin refines the 'value' attribute in enums to refer to
the original underlying value. For example, suppose we have the
following:
class SomeEnum:
FOO = A()
BAR = B()
By default, mypy will infer that 'SomeEnum.FOO.value' and
'SomeEnum.BAR.value' both are of type 'Any'. This plugin refines
this inference so that mypy understands the expressions are
actually of types 'A' and 'B' respectively. This better reflects
the actual runtime behavior.
This plugin works simply by looking up the original value assigned
to the enum. For example, when this plugin sees 'SomeEnum.BAR.value',
it will look up whatever type 'BAR' had in the SomeEnum TypeInfo and
use that as the inferred type of the overall expression.
This plugin assumes that the provided context is an attribute access
matching one of the strings found in 'ENUM_VALUE_ACCESS'.
"""
enum_field_name = _extract_underlying_field_name(ctx.type)
if enum_field_name is None:
# We do not know the enum field name (perhaps it was passed to a
        # function and we only know that it _is_ a member). All is not lost,
        # however: if we can prove that all of the enum members have the
        # same value-type, then it doesn't matter which member was passed in.
# The value-type is still known.
if isinstance(ctx.type, Instance):
info = ctx.type.type
# As long as mypy doesn't understand attribute creation in __new__,
# there is no way to predict the value type if the enum class has a
# custom implementation
if _implements_new(info):
return ctx.default_attr_type
stnodes = (info.get(name) for name in info.names)
# Enums _can_ have methods and instance attributes.
# Omit methods and attributes created by assigning to self.*
# for our value inference.
node_types = (
get_proper_type(n.type) if n else None
for n in stnodes
if n is None or not n.implicit
)
proper_types = list(
_infer_value_type_with_auto_fallback(ctx, t)
for t in node_types
if t is None or not isinstance(t, CallableType)
)
underlying_type = _first(proper_types)
if underlying_type is None:
return ctx.default_attr_type
# At first we try to predict future `value` type if all other items
# have the same type. For example, `int`.
# If this is the case, we simply return this type.
# See https://github.com/python/mypy/pull/9443
all_same_value_type = all(
proper_type is not None and proper_type == underlying_type
for proper_type in proper_types
)
if all_same_value_type:
if underlying_type is not None:
return underlying_type
# But, after we started treating all `Enum` values as `Final`,
# we start to infer types in
# `item = 1` as `Literal[1]`, not just `int`.
# So, for example types in this `Enum` will all be different:
#
# class Ordering(IntEnum):
# one = 1
# two = 2
# three = 3
#
# We will infer three `Literal` types here.
# They are not the same, but they are equivalent.
# So, we unify them to make sure `.value` prediction still works.
# Result will be `Literal[1] | Literal[2] | Literal[3]` for this case.
all_equivalent_types = all(
proper_type is not None and is_equivalent(proper_type, underlying_type)
for proper_type in proper_types
)
if all_equivalent_types:
return make_simplified_union(cast(Sequence[Type], proper_types))
return ctx.default_attr_type
assert isinstance(ctx.type, Instance)
info = ctx.type.type
# As long as mypy doesn't understand attribute creation in __new__,
# there is no way to predict the value type if the enum class has a
# custom implementation
if _implements_new(info):
return ctx.default_attr_type
stnode = info.get(enum_field_name)
if stnode is None:
return ctx.default_attr_type
underlying_type = _infer_value_type_with_auto_fallback(ctx, get_proper_type(stnode.type))
if underlying_type is None:
return ctx.default_attr_type
return underlying_type
def _extract_underlying_field_name(typ: Type) -> str | None:
"""If the given type corresponds to some Enum instance, returns the
original name of that enum. For example, if we receive in the type
corresponding to 'SomeEnum.FOO', we return the string "SomeEnum.Foo".
This helper takes advantage of the fact that Enum instances are valid
to use inside Literal[...] types. An expression like 'SomeEnum.FOO' is
actually represented by an Instance type with a Literal enum fallback.
We can examine this Literal fallback to retrieve the string.
"""
typ = get_proper_type(typ)
if not isinstance(typ, Instance):
return None
if not typ.type.is_enum:
return None
underlying_literal = typ.last_known_value
if underlying_literal is None:
return None
# The checks above have verified this LiteralType is representing an enum value,
# which means the 'value' field is guaranteed to be the name of the enum field
# as a string.
assert isinstance(underlying_literal.value, str)
return underlying_literal.value
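
To make the intent concrete (not part of the commit), roughly what the enum callbacks infer when run under mypy:

from enum import Enum


class Color(Enum):
    RED = 1
    BLUE = 2


reveal_type(Color.RED.name)   # str, usable where Literal["RED"] is expected (enum_name_callback)
reveal_type(Color.RED.value)  # the assigned value's type, an int literal here (enum_value_callback)


def show(color: Color) -> None:
    # The member is unknown, but every member has an equivalent int value type,
    # so .value is still inferred precisely instead of Any.
    reveal_type(color.value)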


@@ -0,0 +1,103 @@
"""Plugin for supporting the functools standard library module."""
from __future__ import annotations
from typing import Final, NamedTuple
import mypy.plugin
from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var
from mypy.plugins.common import add_method_to_class
from mypy.types import AnyType, CallableType, Type, TypeOfAny, UnboundType, get_proper_type
functools_total_ordering_makers: Final = {"functools.total_ordering"}
_ORDERING_METHODS: Final = {"__lt__", "__le__", "__gt__", "__ge__"}
class _MethodInfo(NamedTuple):
is_static: bool
type: CallableType
def functools_total_ordering_maker_callback(
ctx: mypy.plugin.ClassDefContext, auto_attribs_default: bool = False
) -> bool:
"""Add dunder methods to classes decorated with functools.total_ordering."""
comparison_methods = _analyze_class(ctx)
if not comparison_methods:
ctx.api.fail(
'No ordering operation defined when using "functools.total_ordering": < > <= >=',
ctx.reason,
)
return True
# prefer __lt__ to __le__ to __gt__ to __ge__
root = max(comparison_methods, key=lambda k: (comparison_methods[k] is None, k))
root_method = comparison_methods[root]
if not root_method:
# None of the defined comparison methods can be analysed
return True
other_type = _find_other_type(root_method)
bool_type = ctx.api.named_type("builtins.bool")
ret_type: Type = bool_type
if root_method.type.ret_type != ctx.api.named_type("builtins.bool"):
proper_ret_type = get_proper_type(root_method.type.ret_type)
if not (
isinstance(proper_ret_type, UnboundType)
and proper_ret_type.name.split(".")[-1] == "bool"
):
ret_type = AnyType(TypeOfAny.implementation_artifact)
for additional_op in _ORDERING_METHODS:
# Either the method is not implemented
# or has an unknown signature that we can now extrapolate.
if not comparison_methods.get(additional_op):
args = [Argument(Var("other", other_type), other_type, None, ARG_POS)]
add_method_to_class(ctx.api, ctx.cls, additional_op, args, ret_type)
return True
def _find_other_type(method: _MethodInfo) -> Type:
"""Find the type of the ``other`` argument in a comparison method."""
first_arg_pos = 0 if method.is_static else 1
cur_pos_arg = 0
other_arg = None
for arg_kind, arg_type in zip(method.type.arg_kinds, method.type.arg_types):
if arg_kind.is_positional():
if cur_pos_arg == first_arg_pos:
other_arg = arg_type
break
cur_pos_arg += 1
elif arg_kind != ARG_STAR2:
other_arg = arg_type
break
if other_arg is None:
return AnyType(TypeOfAny.implementation_artifact)
return other_arg
def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> dict[str, _MethodInfo | None]:
"""Analyze the class body, its parents, and return the comparison methods found."""
# Traverse the MRO and collect ordering methods.
comparison_methods: dict[str, _MethodInfo | None] = {}
# Skip object because total_ordering does not use methods from object
for cls in ctx.cls.info.mro[:-1]:
for name in _ORDERING_METHODS:
if name in cls.names and name not in comparison_methods:
node = cls.names[name].node
if isinstance(node, FuncItem) and isinstance(node.type, CallableType):
comparison_methods[name] = _MethodInfo(node.is_static, node.type)
continue
if isinstance(node, Var):
proper_type = get_proper_type(node.type)
if isinstance(proper_type, CallableType):
comparison_methods[name] = _MethodInfo(node.is_staticmethod, proper_type)
continue
comparison_methods[name] = None
return comparison_methods
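
A small usage sketch (not part of the commit) of what this plugin enables for functools.total_ordering under mypy:

from __future__ import annotations

import functools


@functools.total_ordering
class Version:
    def __init__(self, major: int) -> None:
        self.major = major

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Version) and self.major == other.major

    def __lt__(self, other: Version) -> bool:
        return self.major < other.major


# __le__, __gt__, and __ge__ are plugin-generated from __lt__'s signature,
# so comparisons using them type check and are inferred as bool.
ok = Version(1) >= Version(2)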


@@ -0,0 +1,224 @@
from __future__ import annotations
from typing import Final, NamedTuple, Sequence, TypeVar, Union
from typing_extensions import TypeAlias as _TypeAlias
from mypy.messages import format_type
from mypy.nodes import ARG_POS, Argument, Block, ClassDef, Context, SymbolTable, TypeInfo, Var
from mypy.options import Options
from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext, MethodSigContext
from mypy.plugins.common import add_method_to_class
from mypy.subtypes import is_subtype
from mypy.types import (
AnyType,
CallableType,
FunctionLike,
Instance,
NoneType,
Overloaded,
Type,
TypeOfAny,
get_proper_type,
)
class SingledispatchTypeVars(NamedTuple):
return_type: Type
fallback: CallableType
class RegisterCallableInfo(NamedTuple):
register_type: Type
singledispatch_obj: Instance
SINGLEDISPATCH_TYPE: Final = "functools._SingleDispatchCallable"
SINGLEDISPATCH_REGISTER_METHOD: Final = f"{SINGLEDISPATCH_TYPE}.register"
SINGLEDISPATCH_CALLABLE_CALL_METHOD: Final = f"{SINGLEDISPATCH_TYPE}.__call__"
def get_singledispatch_info(typ: Instance) -> SingledispatchTypeVars | None:
if len(typ.args) == 2:
return SingledispatchTypeVars(*typ.args) # type: ignore[arg-type]
return None
T = TypeVar("T")
def get_first_arg(args: list[list[T]]) -> T | None:
"""Get the element that corresponds to the first argument passed to the function"""
if args and args[0]:
return args[0][0]
return None
REGISTER_RETURN_CLASS: Final = "_SingleDispatchRegisterCallable"
REGISTER_CALLABLE_CALL_METHOD: Final = f"functools.{REGISTER_RETURN_CLASS}.__call__"
def make_fake_register_class_instance(
api: CheckerPluginInterface, type_args: Sequence[Type]
) -> Instance:
defn = ClassDef(REGISTER_RETURN_CLASS, Block([]))
defn.fullname = f"functools.{REGISTER_RETURN_CLASS}"
info = TypeInfo(SymbolTable(), defn, "functools")
obj_type = api.named_generic_type("builtins.object", []).type
info.bases = [Instance(obj_type, [])]
info.mro = [info, obj_type]
defn.info = info
func_arg = Argument(Var("name"), AnyType(TypeOfAny.implementation_artifact), None, ARG_POS)
add_method_to_class(api, defn, "__call__", [func_arg], NoneType())
return Instance(info, type_args)
PluginContext: _TypeAlias = Union[FunctionContext, MethodContext]
def fail(ctx: PluginContext, msg: str, context: Context | None) -> None:
"""Emit an error message.
This tries to emit an error message at the location specified by `context`, falling back to the
location specified by `ctx.context`. This is helpful when the only context information about
where you want to put the error message may be None (like it is for `CallableType.definition`)
and falling back to the location of the calling function is fine."""
# TODO: figure out if there is some more reliable way of getting context information, so this
# function isn't necessary
if context is not None:
err_context = context
else:
err_context = ctx.context
ctx.api.fail(msg, err_context)
def create_singledispatch_function_callback(ctx: FunctionContext) -> Type:
"""Called for functools.singledispatch"""
func_type = get_proper_type(get_first_arg(ctx.arg_types))
if isinstance(func_type, CallableType):
if len(func_type.arg_kinds) < 1:
fail(
ctx, "Singledispatch function requires at least one argument", func_type.definition
)
return ctx.default_return_type
elif not func_type.arg_kinds[0].is_positional(star=True):
fail(
ctx,
"First argument to singledispatch function must be a positional argument",
func_type.definition,
)
return ctx.default_return_type
# singledispatch returns an instance of functools._SingleDispatchCallable according to
# typeshed
singledispatch_obj = get_proper_type(ctx.default_return_type)
assert isinstance(singledispatch_obj, Instance)
singledispatch_obj.args += (func_type,)
return ctx.default_return_type
def singledispatch_register_callback(ctx: MethodContext) -> Type:
"""Called for functools._SingleDispatchCallable.register"""
assert isinstance(ctx.type, Instance)
# TODO: check that there's only one argument
first_arg_type = get_proper_type(get_first_arg(ctx.arg_types))
if isinstance(first_arg_type, (CallableType, Overloaded)) and first_arg_type.is_type_obj():
# HACK: We received a class as an argument to register. We need to be able
# to access the function that register is being applied to, and the typeshed definition
# of register has it return a generic Callable, so we create a new
# SingleDispatchRegisterCallable class, define a __call__ method, and then add a
# plugin hook for that.
# is_subtype doesn't work when the right type is Overloaded, so we need the
# actual type
register_type = first_arg_type.items[0].ret_type
type_args = RegisterCallableInfo(register_type, ctx.type)
register_callable = make_fake_register_class_instance(ctx.api, type_args)
return register_callable
elif isinstance(first_arg_type, CallableType):
# TODO: do more checking for registered functions
register_function(ctx, ctx.type, first_arg_type, ctx.api.options)
# The typeshed stubs for register say that the function returned is Callable[..., T], even
# though the function returned is the same as the one passed in. We return the type of the
# function so that mypy can properly type check cases where the registered function is used
# directly (instead of through singledispatch)
return first_arg_type
# fallback in case we don't recognize the arguments
return ctx.default_return_type
def register_function(
ctx: PluginContext,
singledispatch_obj: Instance,
func: Type,
options: Options,
register_arg: Type | None = None,
) -> None:
"""Register a function"""
func = get_proper_type(func)
if not isinstance(func, CallableType):
return
metadata = get_singledispatch_info(singledispatch_obj)
if metadata is None:
# if we never added the fallback to the type variables, we already reported an error, so
# just don't do anything here
return
dispatch_type = get_dispatch_type(func, register_arg)
if dispatch_type is None:
# TODO: report an error here that singledispatch requires at least one argument
# (might want to do the error reporting in get_dispatch_type)
return
fallback = metadata.fallback
fallback_dispatch_type = fallback.arg_types[0]
if not is_subtype(dispatch_type, fallback_dispatch_type):
fail(
ctx,
"Dispatch type {} must be subtype of fallback function first argument {}".format(
format_type(dispatch_type, options), format_type(fallback_dispatch_type, options)
),
func.definition,
)
return
return
def get_dispatch_type(func: CallableType, register_arg: Type | None) -> Type | None:
if register_arg is not None:
return register_arg
if func.arg_types:
return func.arg_types[0]
return None
def call_singledispatch_function_after_register_argument(ctx: MethodContext) -> Type:
"""Called on the function after passing a type to register"""
register_callable = ctx.type
if isinstance(register_callable, Instance):
type_args = RegisterCallableInfo(*register_callable.args) # type: ignore[arg-type]
func = get_first_arg(ctx.arg_types)
if func is not None:
register_function(
ctx, type_args.singledispatch_obj, func, ctx.api.options, type_args.register_type
)
# see call to register_function in the callback for register
return func
return ctx.default_return_type
def call_singledispatch_function_callback(ctx: MethodSigContext) -> FunctionLike:
"""Called for functools._SingleDispatchCallable.__call__"""
if not isinstance(ctx.type, Instance):
return ctx.default_signature
metadata = get_singledispatch_info(ctx.type)
if metadata is None:
return ctx.default_signature
return metadata.fallback
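
For context (not part of the commit), the kind of singledispatch code these hooks check:

from functools import singledispatch


@singledispatch
def describe(obj: object) -> str:
    return "something"


@describe.register
def _describe_int(obj: int) -> str:  # dispatch type int must be a subtype of object: ok
    return f"int {obj}"


@describe.register(str)              # class argument: handled via _SingleDispatchRegisterCallable
def _describe_str(obj: str) -> str:
    return f"str {obj}"


reveal_type(describe(3))             # str, via call_singledispatch_function_callback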