Major fixes and new features
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
Binary file not shown.
Binary file not shown.
1114
venv/lib/python3.12/site-packages/mypy/plugins/attrs.py
Normal file
1114
venv/lib/python3.12/site-packages/mypy/plugins/attrs.py
Normal file
File diff suppressed because it is too large
Load Diff
Binary file not shown.
343
venv/lib/python3.12/site-packages/mypy/plugins/common.py
Normal file
343
venv/lib/python3.12/site-packages/mypy/plugins/common.py
Normal file
@@ -0,0 +1,343 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from mypy.argmap import map_actuals_to_formals
|
||||
from mypy.fixup import TypeFixer
|
||||
from mypy.nodes import (
|
||||
ARG_POS,
|
||||
MDEF,
|
||||
SYMBOL_FUNCBASE_TYPES,
|
||||
Argument,
|
||||
Block,
|
||||
CallExpr,
|
||||
ClassDef,
|
||||
Decorator,
|
||||
Expression,
|
||||
FuncDef,
|
||||
JsonDict,
|
||||
NameExpr,
|
||||
Node,
|
||||
PassStmt,
|
||||
RefExpr,
|
||||
SymbolTableNode,
|
||||
Var,
|
||||
)
|
||||
from mypy.plugin import CheckerPluginInterface, ClassDefContext, SemanticAnalyzerPluginInterface
|
||||
from mypy.semanal_shared import (
|
||||
ALLOW_INCOMPATIBLE_OVERRIDE,
|
||||
parse_bool,
|
||||
require_bool_literal_argument,
|
||||
set_callable_name,
|
||||
)
|
||||
from mypy.typeops import try_getting_str_literals as try_getting_str_literals
|
||||
from mypy.types import (
|
||||
AnyType,
|
||||
CallableType,
|
||||
Instance,
|
||||
LiteralType,
|
||||
NoneType,
|
||||
Overloaded,
|
||||
Type,
|
||||
TypeOfAny,
|
||||
TypeType,
|
||||
TypeVarType,
|
||||
deserialize_type,
|
||||
get_proper_type,
|
||||
)
|
||||
from mypy.types_utils import is_overlapping_none
|
||||
from mypy.typevars import fill_typevars
|
||||
from mypy.util import get_unique_redefinition_name
|
||||
|
||||
|
||||
def _get_decorator_bool_argument(ctx: ClassDefContext, name: str, default: bool) -> bool:
|
||||
"""Return the bool argument for the decorator.
|
||||
|
||||
This handles both @decorator(...) and @decorator.
|
||||
"""
|
||||
if isinstance(ctx.reason, CallExpr):
|
||||
return _get_bool_argument(ctx, ctx.reason, name, default)
|
||||
else:
|
||||
return default
|
||||
|
||||
|
||||
def _get_bool_argument(ctx: ClassDefContext, expr: CallExpr, name: str, default: bool) -> bool:
|
||||
"""Return the boolean value for an argument to a call or the
|
||||
default if it's not found.
|
||||
"""
|
||||
attr_value = _get_argument(expr, name)
|
||||
if attr_value:
|
||||
return require_bool_literal_argument(ctx.api, attr_value, name, default)
|
||||
return default
|
||||
|
||||
|
||||
def _get_argument(call: CallExpr, name: str) -> Expression | None:
|
||||
"""Return the expression for the specific argument."""
|
||||
# To do this we use the CallableType of the callee to find the FormalArgument,
|
||||
# then walk the actual CallExpr looking for the appropriate argument.
|
||||
#
|
||||
# Note: I'm not hard-coding the index so that in the future we can support other
|
||||
# attrib and class makers.
|
||||
callee_type = _get_callee_type(call)
|
||||
if not callee_type:
|
||||
return None
|
||||
|
||||
argument = callee_type.argument_by_name(name)
|
||||
if not argument:
|
||||
return None
|
||||
assert argument.name
|
||||
|
||||
for i, (attr_name, attr_value) in enumerate(zip(call.arg_names, call.args)):
|
||||
if argument.pos is not None and not attr_name and i == argument.pos:
|
||||
return attr_value
|
||||
if attr_name == argument.name:
|
||||
return attr_value
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def find_shallow_matching_overload_item(overload: Overloaded, call: CallExpr) -> CallableType:
|
||||
"""Perform limited lookup of a matching overload item.
|
||||
|
||||
Full overload resolution is only supported during type checking, but plugins
|
||||
sometimes need to resolve overloads. This can be used in some such use cases.
|
||||
|
||||
Resolve overloads based on these things only:
|
||||
|
||||
* Match using argument kinds and names
|
||||
* If formal argument has type None, only accept the "None" expression in the callee
|
||||
* If formal argument has type Literal[True] or Literal[False], only accept the
|
||||
relevant bool literal
|
||||
|
||||
Return the first matching overload item, or the last one if nothing matches.
|
||||
"""
|
||||
for item in overload.items[:-1]:
|
||||
ok = True
|
||||
mapped = map_actuals_to_formals(
|
||||
call.arg_kinds,
|
||||
call.arg_names,
|
||||
item.arg_kinds,
|
||||
item.arg_names,
|
||||
lambda i: AnyType(TypeOfAny.special_form),
|
||||
)
|
||||
|
||||
# Look for extra actuals
|
||||
matched_actuals = set()
|
||||
for actuals in mapped:
|
||||
matched_actuals.update(actuals)
|
||||
if any(i not in matched_actuals for i in range(len(call.args))):
|
||||
ok = False
|
||||
|
||||
for arg_type, kind, actuals in zip(item.arg_types, item.arg_kinds, mapped):
|
||||
if kind.is_required() and not actuals:
|
||||
# Missing required argument
|
||||
ok = False
|
||||
break
|
||||
elif actuals:
|
||||
args = [call.args[i] for i in actuals]
|
||||
arg_type = get_proper_type(arg_type)
|
||||
arg_none = any(isinstance(arg, NameExpr) and arg.name == "None" for arg in args)
|
||||
if isinstance(arg_type, NoneType):
|
||||
if not arg_none:
|
||||
ok = False
|
||||
break
|
||||
elif (
|
||||
arg_none
|
||||
and not is_overlapping_none(arg_type)
|
||||
and not (
|
||||
isinstance(arg_type, Instance)
|
||||
and arg_type.type.fullname == "builtins.object"
|
||||
)
|
||||
and not isinstance(arg_type, AnyType)
|
||||
):
|
||||
ok = False
|
||||
break
|
||||
elif isinstance(arg_type, LiteralType) and type(arg_type.value) is bool:
|
||||
if not any(parse_bool(arg) == arg_type.value for arg in args):
|
||||
ok = False
|
||||
break
|
||||
if ok:
|
||||
return item
|
||||
return overload.items[-1]
|
||||
|
||||
|
||||
def _get_callee_type(call: CallExpr) -> CallableType | None:
|
||||
"""Return the type of the callee, regardless of its syntatic form."""
|
||||
|
||||
callee_node: Node | None = call.callee
|
||||
|
||||
if isinstance(callee_node, RefExpr):
|
||||
callee_node = callee_node.node
|
||||
|
||||
# Some decorators may be using typing.dataclass_transform, which is itself a decorator, so we
|
||||
# need to unwrap them to get at the true callee
|
||||
if isinstance(callee_node, Decorator):
|
||||
callee_node = callee_node.func
|
||||
|
||||
if isinstance(callee_node, (Var, SYMBOL_FUNCBASE_TYPES)) and callee_node.type:
|
||||
callee_node_type = get_proper_type(callee_node.type)
|
||||
if isinstance(callee_node_type, Overloaded):
|
||||
return find_shallow_matching_overload_item(callee_node_type, call)
|
||||
elif isinstance(callee_node_type, CallableType):
|
||||
return callee_node_type
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def add_method(
|
||||
ctx: ClassDefContext,
|
||||
name: str,
|
||||
args: list[Argument],
|
||||
return_type: Type,
|
||||
self_type: Type | None = None,
|
||||
tvar_def: TypeVarType | None = None,
|
||||
is_classmethod: bool = False,
|
||||
is_staticmethod: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Adds a new method to a class.
|
||||
Deprecated, use add_method_to_class() instead.
|
||||
"""
|
||||
add_method_to_class(
|
||||
ctx.api,
|
||||
ctx.cls,
|
||||
name=name,
|
||||
args=args,
|
||||
return_type=return_type,
|
||||
self_type=self_type,
|
||||
tvar_def=tvar_def,
|
||||
is_classmethod=is_classmethod,
|
||||
is_staticmethod=is_staticmethod,
|
||||
)
|
||||
|
||||
|
||||
def add_method_to_class(
|
||||
api: SemanticAnalyzerPluginInterface | CheckerPluginInterface,
|
||||
cls: ClassDef,
|
||||
name: str,
|
||||
args: list[Argument],
|
||||
return_type: Type,
|
||||
self_type: Type | None = None,
|
||||
tvar_def: TypeVarType | None = None,
|
||||
is_classmethod: bool = False,
|
||||
is_staticmethod: bool = False,
|
||||
) -> None:
|
||||
"""Adds a new method to a class definition."""
|
||||
|
||||
assert not (
|
||||
is_classmethod is True and is_staticmethod is True
|
||||
), "Can't add a new method that's both staticmethod and classmethod."
|
||||
|
||||
info = cls.info
|
||||
|
||||
# First remove any previously generated methods with the same name
|
||||
# to avoid clashes and problems in the semantic analyzer.
|
||||
if name in info.names:
|
||||
sym = info.names[name]
|
||||
if sym.plugin_generated and isinstance(sym.node, FuncDef):
|
||||
cls.defs.body.remove(sym.node)
|
||||
|
||||
if isinstance(api, SemanticAnalyzerPluginInterface):
|
||||
function_type = api.named_type("builtins.function")
|
||||
else:
|
||||
function_type = api.named_generic_type("builtins.function", [])
|
||||
|
||||
if is_classmethod:
|
||||
self_type = self_type or TypeType(fill_typevars(info))
|
||||
first = [Argument(Var("_cls"), self_type, None, ARG_POS, True)]
|
||||
elif is_staticmethod:
|
||||
first = []
|
||||
else:
|
||||
self_type = self_type or fill_typevars(info)
|
||||
first = [Argument(Var("self"), self_type, None, ARG_POS)]
|
||||
args = first + args
|
||||
|
||||
arg_types, arg_names, arg_kinds = [], [], []
|
||||
for arg in args:
|
||||
assert arg.type_annotation, "All arguments must be fully typed."
|
||||
arg_types.append(arg.type_annotation)
|
||||
arg_names.append(arg.variable.name)
|
||||
arg_kinds.append(arg.kind)
|
||||
|
||||
signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type)
|
||||
if tvar_def:
|
||||
signature.variables = [tvar_def]
|
||||
|
||||
func = FuncDef(name, args, Block([PassStmt()]))
|
||||
func.info = info
|
||||
func.type = set_callable_name(signature, func)
|
||||
func.is_class = is_classmethod
|
||||
func.is_static = is_staticmethod
|
||||
func._fullname = info.fullname + "." + name
|
||||
func.line = info.line
|
||||
|
||||
# NOTE: we would like the plugin generated node to dominate, but we still
|
||||
# need to keep any existing definitions so they get semantically analyzed.
|
||||
if name in info.names:
|
||||
# Get a nice unique name instead.
|
||||
r_name = get_unique_redefinition_name(name, info.names)
|
||||
info.names[r_name] = info.names[name]
|
||||
|
||||
# Add decorator for is_staticmethod. It's unnecessary for is_classmethod.
|
||||
if is_staticmethod:
|
||||
func.is_decorated = True
|
||||
v = Var(name, func.type)
|
||||
v.info = info
|
||||
v._fullname = func._fullname
|
||||
v.is_staticmethod = True
|
||||
dec = Decorator(func, [], v)
|
||||
dec.line = info.line
|
||||
sym = SymbolTableNode(MDEF, dec)
|
||||
else:
|
||||
sym = SymbolTableNode(MDEF, func)
|
||||
sym.plugin_generated = True
|
||||
info.names[name] = sym
|
||||
|
||||
info.defn.defs.body.append(func)
|
||||
|
||||
|
||||
def add_attribute_to_class(
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
cls: ClassDef,
|
||||
name: str,
|
||||
typ: Type,
|
||||
final: bool = False,
|
||||
no_serialize: bool = False,
|
||||
override_allow_incompatible: bool = False,
|
||||
fullname: str | None = None,
|
||||
is_classvar: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Adds a new attribute to a class definition.
|
||||
This currently only generates the symbol table entry and no corresponding AssignmentStatement
|
||||
"""
|
||||
info = cls.info
|
||||
|
||||
# NOTE: we would like the plugin generated node to dominate, but we still
|
||||
# need to keep any existing definitions so they get semantically analyzed.
|
||||
if name in info.names:
|
||||
# Get a nice unique name instead.
|
||||
r_name = get_unique_redefinition_name(name, info.names)
|
||||
info.names[r_name] = info.names[name]
|
||||
|
||||
node = Var(name, typ)
|
||||
node.info = info
|
||||
node.is_final = final
|
||||
node.is_classvar = is_classvar
|
||||
if name in ALLOW_INCOMPATIBLE_OVERRIDE:
|
||||
node.allow_incompatible_override = True
|
||||
else:
|
||||
node.allow_incompatible_override = override_allow_incompatible
|
||||
|
||||
if fullname:
|
||||
node._fullname = fullname
|
||||
else:
|
||||
node._fullname = info.fullname + "." + name
|
||||
|
||||
info.names[name] = SymbolTableNode(
|
||||
MDEF, node, plugin_generated=True, no_serialize=no_serialize
|
||||
)
|
||||
|
||||
|
||||
def deserialize_and_fixup_type(data: str | JsonDict, api: SemanticAnalyzerPluginInterface) -> Type:
|
||||
typ = deserialize_type(data)
|
||||
typ.accept(TypeFixer(api.modules, allow_missing=False))
|
||||
return typ
|
||||
Binary file not shown.
245
venv/lib/python3.12/site-packages/mypy/plugins/ctypes.py
Normal file
245
venv/lib/python3.12/site-packages/mypy/plugins/ctypes.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""Plugin to provide accurate types for some parts of the ctypes module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
# Fully qualified instead of "from mypy.plugin import ..." to avoid circular import problems.
|
||||
import mypy.plugin
|
||||
from mypy import nodes
|
||||
from mypy.maptype import map_instance_to_supertype
|
||||
from mypy.messages import format_type
|
||||
from mypy.subtypes import is_subtype
|
||||
from mypy.typeops import make_simplified_union
|
||||
from mypy.types import (
|
||||
AnyType,
|
||||
CallableType,
|
||||
Instance,
|
||||
NoneType,
|
||||
ProperType,
|
||||
Type,
|
||||
TypeOfAny,
|
||||
UnionType,
|
||||
flatten_nested_unions,
|
||||
get_proper_type,
|
||||
)
|
||||
|
||||
|
||||
def _find_simplecdata_base_arg(
|
||||
tp: Instance, api: mypy.plugin.CheckerPluginInterface
|
||||
) -> ProperType | None:
|
||||
"""Try to find a parametrized _SimpleCData in tp's bases and return its single type argument.
|
||||
|
||||
None is returned if _SimpleCData appears nowhere in tp's (direct or indirect) bases.
|
||||
"""
|
||||
if tp.type.has_base("_ctypes._SimpleCData"):
|
||||
simplecdata_base = map_instance_to_supertype(
|
||||
tp,
|
||||
api.named_generic_type("_ctypes._SimpleCData", [AnyType(TypeOfAny.special_form)]).type,
|
||||
)
|
||||
assert len(simplecdata_base.args) == 1, "_SimpleCData takes exactly one type argument"
|
||||
return get_proper_type(simplecdata_base.args[0])
|
||||
return None
|
||||
|
||||
|
||||
def _autoconvertible_to_cdata(tp: Type, api: mypy.plugin.CheckerPluginInterface) -> Type:
|
||||
"""Get a type that is compatible with all types that can be implicitly converted to the given
|
||||
CData type.
|
||||
|
||||
Examples:
|
||||
* c_int -> Union[c_int, int]
|
||||
* c_char_p -> Union[c_char_p, bytes, int, NoneType]
|
||||
* MyStructure -> MyStructure
|
||||
"""
|
||||
allowed_types = []
|
||||
# If tp is a union, we allow all types that are convertible to at least one of the union
|
||||
# items. This is not quite correct - strictly speaking, only types convertible to *all* of the
|
||||
# union items should be allowed. This may be worth changing in the future, but the more
|
||||
# correct algorithm could be too strict to be useful.
|
||||
for t in flatten_nested_unions([tp]):
|
||||
t = get_proper_type(t)
|
||||
# Every type can be converted from itself (obviously).
|
||||
allowed_types.append(t)
|
||||
if isinstance(t, Instance):
|
||||
unboxed = _find_simplecdata_base_arg(t, api)
|
||||
if unboxed is not None:
|
||||
# If _SimpleCData appears in tp's (direct or indirect) bases, its type argument
|
||||
# specifies the type's "unboxed" version, which can always be converted back to
|
||||
# the original "boxed" type.
|
||||
allowed_types.append(unboxed)
|
||||
|
||||
if t.type.has_base("ctypes._PointerLike"):
|
||||
# Pointer-like _SimpleCData subclasses can also be converted from
|
||||
# an int or None.
|
||||
allowed_types.append(api.named_generic_type("builtins.int", []))
|
||||
allowed_types.append(NoneType())
|
||||
|
||||
return make_simplified_union(allowed_types)
|
||||
|
||||
|
||||
def _autounboxed_cdata(tp: Type) -> ProperType:
|
||||
"""Get the auto-unboxed version of a CData type, if applicable.
|
||||
|
||||
For *direct* _SimpleCData subclasses, the only type argument of _SimpleCData in the bases list
|
||||
is returned.
|
||||
For all other CData types, including indirect _SimpleCData subclasses, tp is returned as-is.
|
||||
"""
|
||||
tp = get_proper_type(tp)
|
||||
|
||||
if isinstance(tp, UnionType):
|
||||
return make_simplified_union([_autounboxed_cdata(t) for t in tp.items])
|
||||
elif isinstance(tp, Instance):
|
||||
for base in tp.type.bases:
|
||||
if base.type.fullname == "_ctypes._SimpleCData":
|
||||
# If tp has _SimpleCData as a direct base class,
|
||||
# the auto-unboxed type is the single type argument of the _SimpleCData type.
|
||||
assert len(base.args) == 1
|
||||
return get_proper_type(base.args[0])
|
||||
# If tp is not a concrete type, or if there is no _SimpleCData in the bases,
|
||||
# the type is not auto-unboxed.
|
||||
return tp
|
||||
|
||||
|
||||
def _get_array_element_type(tp: Type) -> ProperType | None:
|
||||
"""Get the element type of the Array type tp, or None if not specified."""
|
||||
tp = get_proper_type(tp)
|
||||
if isinstance(tp, Instance):
|
||||
assert tp.type.fullname == "_ctypes.Array"
|
||||
if len(tp.args) == 1:
|
||||
return get_proper_type(tp.args[0])
|
||||
return None
|
||||
|
||||
|
||||
def array_constructor_callback(ctx: mypy.plugin.FunctionContext) -> Type:
|
||||
"""Callback to provide an accurate signature for the ctypes.Array constructor."""
|
||||
# Extract the element type from the constructor's return type, i. e. the type of the array
|
||||
# being constructed.
|
||||
et = _get_array_element_type(ctx.default_return_type)
|
||||
if et is not None:
|
||||
allowed = _autoconvertible_to_cdata(et, ctx.api)
|
||||
assert (
|
||||
len(ctx.arg_types) == 1
|
||||
), "The stub of the ctypes.Array constructor should have a single vararg parameter"
|
||||
for arg_num, (arg_kind, arg_type) in enumerate(zip(ctx.arg_kinds[0], ctx.arg_types[0]), 1):
|
||||
if arg_kind == nodes.ARG_POS and not is_subtype(arg_type, allowed):
|
||||
ctx.api.msg.fail(
|
||||
"Array constructor argument {} of type {}"
|
||||
" is not convertible to the array element type {}".format(
|
||||
arg_num,
|
||||
format_type(arg_type, ctx.api.options),
|
||||
format_type(et, ctx.api.options),
|
||||
),
|
||||
ctx.context,
|
||||
)
|
||||
elif arg_kind == nodes.ARG_STAR:
|
||||
ty = ctx.api.named_generic_type("typing.Iterable", [allowed])
|
||||
if not is_subtype(arg_type, ty):
|
||||
it = ctx.api.named_generic_type("typing.Iterable", [et])
|
||||
ctx.api.msg.fail(
|
||||
"Array constructor argument {} of type {}"
|
||||
" is not convertible to the array element type {}".format(
|
||||
arg_num,
|
||||
format_type(arg_type, ctx.api.options),
|
||||
format_type(it, ctx.api.options),
|
||||
),
|
||||
ctx.context,
|
||||
)
|
||||
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
def array_getitem_callback(ctx: mypy.plugin.MethodContext) -> Type:
|
||||
"""Callback to provide an accurate return type for ctypes.Array.__getitem__."""
|
||||
et = _get_array_element_type(ctx.type)
|
||||
if et is not None:
|
||||
unboxed = _autounboxed_cdata(et)
|
||||
assert (
|
||||
len(ctx.arg_types) == 1
|
||||
), "The stub of ctypes.Array.__getitem__ should have exactly one parameter"
|
||||
assert (
|
||||
len(ctx.arg_types[0]) == 1
|
||||
), "ctypes.Array.__getitem__'s parameter should not be variadic"
|
||||
index_type = get_proper_type(ctx.arg_types[0][0])
|
||||
if isinstance(index_type, Instance):
|
||||
if index_type.type.has_base("builtins.int"):
|
||||
return unboxed
|
||||
elif index_type.type.has_base("builtins.slice"):
|
||||
return ctx.api.named_generic_type("builtins.list", [unboxed])
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
def array_setitem_callback(ctx: mypy.plugin.MethodSigContext) -> CallableType:
|
||||
"""Callback to provide an accurate signature for ctypes.Array.__setitem__."""
|
||||
et = _get_array_element_type(ctx.type)
|
||||
if et is not None:
|
||||
allowed = _autoconvertible_to_cdata(et, ctx.api)
|
||||
assert len(ctx.default_signature.arg_types) == 2
|
||||
index_type = get_proper_type(ctx.default_signature.arg_types[0])
|
||||
if isinstance(index_type, Instance):
|
||||
arg_type = None
|
||||
if index_type.type.has_base("builtins.int"):
|
||||
arg_type = allowed
|
||||
elif index_type.type.has_base("builtins.slice"):
|
||||
arg_type = ctx.api.named_generic_type("builtins.list", [allowed])
|
||||
if arg_type is not None:
|
||||
# Note: arg_type can only be None if index_type is invalid, in which case we use
|
||||
# the default signature and let mypy report an error about it.
|
||||
return ctx.default_signature.copy_modified(
|
||||
arg_types=ctx.default_signature.arg_types[:1] + [arg_type]
|
||||
)
|
||||
return ctx.default_signature
|
||||
|
||||
|
||||
def array_iter_callback(ctx: mypy.plugin.MethodContext) -> Type:
|
||||
"""Callback to provide an accurate return type for ctypes.Array.__iter__."""
|
||||
et = _get_array_element_type(ctx.type)
|
||||
if et is not None:
|
||||
unboxed = _autounboxed_cdata(et)
|
||||
return ctx.api.named_generic_type("typing.Iterator", [unboxed])
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
def array_value_callback(ctx: mypy.plugin.AttributeContext) -> Type:
|
||||
"""Callback to provide an accurate type for ctypes.Array.value."""
|
||||
et = _get_array_element_type(ctx.type)
|
||||
if et is not None:
|
||||
types: list[Type] = []
|
||||
for tp in flatten_nested_unions([et]):
|
||||
tp = get_proper_type(tp)
|
||||
if isinstance(tp, AnyType):
|
||||
types.append(AnyType(TypeOfAny.from_another_any, source_any=tp))
|
||||
elif isinstance(tp, Instance) and tp.type.fullname == "ctypes.c_char":
|
||||
types.append(ctx.api.named_generic_type("builtins.bytes", []))
|
||||
elif isinstance(tp, Instance) and tp.type.fullname == "ctypes.c_wchar":
|
||||
types.append(ctx.api.named_generic_type("builtins.str", []))
|
||||
else:
|
||||
ctx.api.msg.fail(
|
||||
'Array attribute "value" is only available'
|
||||
' with element type "c_char" or "c_wchar", not {}'.format(
|
||||
format_type(et, ctx.api.options)
|
||||
),
|
||||
ctx.context,
|
||||
)
|
||||
return make_simplified_union(types)
|
||||
return ctx.default_attr_type
|
||||
|
||||
|
||||
def array_raw_callback(ctx: mypy.plugin.AttributeContext) -> Type:
|
||||
"""Callback to provide an accurate type for ctypes.Array.raw."""
|
||||
et = _get_array_element_type(ctx.type)
|
||||
if et is not None:
|
||||
types: list[Type] = []
|
||||
for tp in flatten_nested_unions([et]):
|
||||
tp = get_proper_type(tp)
|
||||
if (
|
||||
isinstance(tp, AnyType)
|
||||
or isinstance(tp, Instance)
|
||||
and tp.type.fullname == "ctypes.c_char"
|
||||
):
|
||||
types.append(ctx.api.named_generic_type("builtins.bytes", []))
|
||||
else:
|
||||
ctx.api.msg.fail(
|
||||
'Array attribute "raw" is only available'
|
||||
' with element type "c_char", not {}'.format(format_type(et, ctx.api.options)),
|
||||
ctx.context,
|
||||
)
|
||||
return make_simplified_union(types)
|
||||
return ctx.default_attr_type
|
||||
Binary file not shown.
1115
venv/lib/python3.12/site-packages/mypy/plugins/dataclasses.py
Normal file
1115
venv/lib/python3.12/site-packages/mypy/plugins/dataclasses.py
Normal file
File diff suppressed because it is too large
Load Diff
Binary file not shown.
504
venv/lib/python3.12/site-packages/mypy/plugins/default.py
Normal file
504
venv/lib/python3.12/site-packages/mypy/plugins/default.py
Normal file
@@ -0,0 +1,504 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from typing import Callable
|
||||
|
||||
import mypy.errorcodes as codes
|
||||
from mypy import message_registry
|
||||
from mypy.nodes import DictExpr, IntExpr, StrExpr, UnaryExpr
|
||||
from mypy.plugin import (
|
||||
AttributeContext,
|
||||
ClassDefContext,
|
||||
FunctionContext,
|
||||
FunctionSigContext,
|
||||
MethodContext,
|
||||
MethodSigContext,
|
||||
Plugin,
|
||||
)
|
||||
from mypy.plugins.common import try_getting_str_literals
|
||||
from mypy.subtypes import is_subtype
|
||||
from mypy.typeops import is_literal_type_like, make_simplified_union
|
||||
from mypy.types import (
|
||||
TPDICT_FB_NAMES,
|
||||
AnyType,
|
||||
CallableType,
|
||||
FunctionLike,
|
||||
Instance,
|
||||
LiteralType,
|
||||
NoneType,
|
||||
TupleType,
|
||||
Type,
|
||||
TypedDictType,
|
||||
TypeOfAny,
|
||||
TypeVarType,
|
||||
UnionType,
|
||||
get_proper_type,
|
||||
get_proper_types,
|
||||
)
|
||||
|
||||
|
||||
class DefaultPlugin(Plugin):
|
||||
"""Type checker plugin that is enabled by default."""
|
||||
|
||||
def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
|
||||
from mypy.plugins import ctypes, singledispatch
|
||||
|
||||
if fullname == "_ctypes.Array":
|
||||
return ctypes.array_constructor_callback
|
||||
elif fullname == "functools.singledispatch":
|
||||
return singledispatch.create_singledispatch_function_callback
|
||||
|
||||
return None
|
||||
|
||||
def get_function_signature_hook(
|
||||
self, fullname: str
|
||||
) -> Callable[[FunctionSigContext], FunctionLike] | None:
|
||||
from mypy.plugins import attrs, dataclasses
|
||||
|
||||
if fullname in ("attr.evolve", "attrs.evolve", "attr.assoc", "attrs.assoc"):
|
||||
return attrs.evolve_function_sig_callback
|
||||
elif fullname in ("attr.fields", "attrs.fields"):
|
||||
return attrs.fields_function_sig_callback
|
||||
elif fullname == "dataclasses.replace":
|
||||
return dataclasses.replace_function_sig_callback
|
||||
return None
|
||||
|
||||
def get_method_signature_hook(
|
||||
self, fullname: str
|
||||
) -> Callable[[MethodSigContext], FunctionLike] | None:
|
||||
from mypy.plugins import ctypes, singledispatch
|
||||
|
||||
if fullname == "typing.Mapping.get":
|
||||
return typed_dict_get_signature_callback
|
||||
elif fullname in {n + ".setdefault" for n in TPDICT_FB_NAMES}:
|
||||
return typed_dict_setdefault_signature_callback
|
||||
elif fullname in {n + ".pop" for n in TPDICT_FB_NAMES}:
|
||||
return typed_dict_pop_signature_callback
|
||||
elif fullname in {n + ".update" for n in TPDICT_FB_NAMES}:
|
||||
return typed_dict_update_signature_callback
|
||||
elif fullname == "_ctypes.Array.__setitem__":
|
||||
return ctypes.array_setitem_callback
|
||||
elif fullname == singledispatch.SINGLEDISPATCH_CALLABLE_CALL_METHOD:
|
||||
return singledispatch.call_singledispatch_function_callback
|
||||
return None
|
||||
|
||||
def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | None:
|
||||
from mypy.plugins import ctypes, singledispatch
|
||||
|
||||
if fullname == "typing.Mapping.get":
|
||||
return typed_dict_get_callback
|
||||
elif fullname == "builtins.int.__pow__":
|
||||
return int_pow_callback
|
||||
elif fullname == "builtins.int.__neg__":
|
||||
return int_neg_callback
|
||||
elif fullname in ("builtins.tuple.__mul__", "builtins.tuple.__rmul__"):
|
||||
return tuple_mul_callback
|
||||
elif fullname in {n + ".setdefault" for n in TPDICT_FB_NAMES}:
|
||||
return typed_dict_setdefault_callback
|
||||
elif fullname in {n + ".pop" for n in TPDICT_FB_NAMES}:
|
||||
return typed_dict_pop_callback
|
||||
elif fullname in {n + ".__delitem__" for n in TPDICT_FB_NAMES}:
|
||||
return typed_dict_delitem_callback
|
||||
elif fullname == "_ctypes.Array.__getitem__":
|
||||
return ctypes.array_getitem_callback
|
||||
elif fullname == "_ctypes.Array.__iter__":
|
||||
return ctypes.array_iter_callback
|
||||
elif fullname == singledispatch.SINGLEDISPATCH_REGISTER_METHOD:
|
||||
return singledispatch.singledispatch_register_callback
|
||||
elif fullname == singledispatch.REGISTER_CALLABLE_CALL_METHOD:
|
||||
return singledispatch.call_singledispatch_function_after_register_argument
|
||||
return None
|
||||
|
||||
def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
|
||||
from mypy.plugins import ctypes, enums
|
||||
|
||||
if fullname == "_ctypes.Array.value":
|
||||
return ctypes.array_value_callback
|
||||
elif fullname == "_ctypes.Array.raw":
|
||||
return ctypes.array_raw_callback
|
||||
elif fullname in enums.ENUM_NAME_ACCESS:
|
||||
return enums.enum_name_callback
|
||||
elif fullname in enums.ENUM_VALUE_ACCESS:
|
||||
return enums.enum_value_callback
|
||||
return None
|
||||
|
||||
def get_class_decorator_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
|
||||
from mypy.plugins import attrs, dataclasses
|
||||
|
||||
# These dataclass and attrs hooks run in the main semantic analysis pass
|
||||
# and only tag known dataclasses/attrs classes, so that the second
|
||||
# hooks (in get_class_decorator_hook_2) can detect dataclasses/attrs classes
|
||||
# in the MRO.
|
||||
if fullname in dataclasses.dataclass_makers:
|
||||
return dataclasses.dataclass_tag_callback
|
||||
if (
|
||||
fullname in attrs.attr_class_makers
|
||||
or fullname in attrs.attr_dataclass_makers
|
||||
or fullname in attrs.attr_frozen_makers
|
||||
or fullname in attrs.attr_define_makers
|
||||
):
|
||||
return attrs.attr_tag_callback
|
||||
|
||||
return None
|
||||
|
||||
def get_class_decorator_hook_2(
|
||||
self, fullname: str
|
||||
) -> Callable[[ClassDefContext], bool] | None:
|
||||
from mypy.plugins import attrs, dataclasses, functools
|
||||
|
||||
if fullname in dataclasses.dataclass_makers:
|
||||
return dataclasses.dataclass_class_maker_callback
|
||||
elif fullname in functools.functools_total_ordering_makers:
|
||||
return functools.functools_total_ordering_maker_callback
|
||||
elif fullname in attrs.attr_class_makers:
|
||||
return attrs.attr_class_maker_callback
|
||||
elif fullname in attrs.attr_dataclass_makers:
|
||||
return partial(attrs.attr_class_maker_callback, auto_attribs_default=True)
|
||||
elif fullname in attrs.attr_frozen_makers:
|
||||
return partial(
|
||||
attrs.attr_class_maker_callback, auto_attribs_default=None, frozen_default=True
|
||||
)
|
||||
elif fullname in attrs.attr_define_makers:
|
||||
return partial(
|
||||
attrs.attr_class_maker_callback, auto_attribs_default=None, slots_default=True
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType:
|
||||
"""Try to infer a better signature type for TypedDict.get.
|
||||
|
||||
This is used to get better type context for the second argument that
|
||||
depends on a TypedDict value type.
|
||||
"""
|
||||
signature = ctx.default_signature
|
||||
if (
|
||||
isinstance(ctx.type, TypedDictType)
|
||||
and len(ctx.args) == 2
|
||||
and len(ctx.args[0]) == 1
|
||||
and isinstance(ctx.args[0][0], StrExpr)
|
||||
and len(signature.arg_types) == 2
|
||||
and len(signature.variables) == 1
|
||||
and len(ctx.args[1]) == 1
|
||||
):
|
||||
key = ctx.args[0][0].value
|
||||
value_type = get_proper_type(ctx.type.items.get(key))
|
||||
ret_type = signature.ret_type
|
||||
if value_type:
|
||||
default_arg = ctx.args[1][0]
|
||||
if (
|
||||
isinstance(value_type, TypedDictType)
|
||||
and isinstance(default_arg, DictExpr)
|
||||
and len(default_arg.items) == 0
|
||||
):
|
||||
# Caller has empty dict {} as default for typed dict.
|
||||
value_type = value_type.copy_modified(required_keys=set())
|
||||
# Tweak the signature to include the value type as context. It's
|
||||
# only needed for type inference since there's a union with a type
|
||||
# variable that accepts everything.
|
||||
tv = signature.variables[0]
|
||||
assert isinstance(tv, TypeVarType)
|
||||
return signature.copy_modified(
|
||||
arg_types=[signature.arg_types[0], make_simplified_union([value_type, tv])],
|
||||
ret_type=ret_type,
|
||||
)
|
||||
return signature
|
||||
|
||||
|
||||
def typed_dict_get_callback(ctx: MethodContext) -> Type:
|
||||
"""Infer a precise return type for TypedDict.get with literal first argument."""
|
||||
if (
|
||||
isinstance(ctx.type, TypedDictType)
|
||||
and len(ctx.arg_types) >= 1
|
||||
and len(ctx.arg_types[0]) == 1
|
||||
):
|
||||
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
|
||||
if keys is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
output_types: list[Type] = []
|
||||
for key in keys:
|
||||
value_type = get_proper_type(ctx.type.items.get(key))
|
||||
if value_type is None:
|
||||
return ctx.default_return_type
|
||||
|
||||
if len(ctx.arg_types) == 1:
|
||||
output_types.append(value_type)
|
||||
elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
|
||||
default_arg = ctx.args[1][0]
|
||||
if (
|
||||
isinstance(default_arg, DictExpr)
|
||||
and len(default_arg.items) == 0
|
||||
and isinstance(value_type, TypedDictType)
|
||||
):
|
||||
# Special case '{}' as the default for a typed dict type.
|
||||
output_types.append(value_type.copy_modified(required_keys=set()))
|
||||
else:
|
||||
output_types.append(value_type)
|
||||
output_types.append(ctx.arg_types[1][0])
|
||||
|
||||
if len(ctx.arg_types) == 1:
|
||||
output_types.append(NoneType())
|
||||
|
||||
return make_simplified_union(output_types)
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
def typed_dict_pop_signature_callback(ctx: MethodSigContext) -> CallableType:
|
||||
"""Try to infer a better signature type for TypedDict.pop.
|
||||
|
||||
This is used to get better type context for the second argument that
|
||||
depends on a TypedDict value type.
|
||||
"""
|
||||
signature = ctx.default_signature
|
||||
str_type = ctx.api.named_generic_type("builtins.str", [])
|
||||
if (
|
||||
isinstance(ctx.type, TypedDictType)
|
||||
and len(ctx.args) == 2
|
||||
and len(ctx.args[0]) == 1
|
||||
and isinstance(ctx.args[0][0], StrExpr)
|
||||
and len(signature.arg_types) == 2
|
||||
and len(signature.variables) == 1
|
||||
and len(ctx.args[1]) == 1
|
||||
):
|
||||
key = ctx.args[0][0].value
|
||||
value_type = ctx.type.items.get(key)
|
||||
if value_type:
|
||||
# Tweak the signature to include the value type as context. It's
|
||||
# only needed for type inference since there's a union with a type
|
||||
# variable that accepts everything.
|
||||
tv = signature.variables[0]
|
||||
assert isinstance(tv, TypeVarType)
|
||||
typ = make_simplified_union([value_type, tv])
|
||||
return signature.copy_modified(arg_types=[str_type, typ], ret_type=typ)
|
||||
return signature.copy_modified(arg_types=[str_type, signature.arg_types[1]])
|
||||
|
||||
|
||||
def typed_dict_pop_callback(ctx: MethodContext) -> Type:
|
||||
"""Type check and infer a precise return type for TypedDict.pop."""
|
||||
if (
|
||||
isinstance(ctx.type, TypedDictType)
|
||||
and len(ctx.arg_types) >= 1
|
||||
and len(ctx.arg_types[0]) == 1
|
||||
):
|
||||
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
|
||||
if keys is None:
|
||||
ctx.api.fail(
|
||||
message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
|
||||
ctx.context,
|
||||
code=codes.LITERAL_REQ,
|
||||
)
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
|
||||
value_types = []
|
||||
for key in keys:
|
||||
if key in ctx.type.required_keys:
|
||||
ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
|
||||
|
||||
value_type = ctx.type.items.get(key)
|
||||
if value_type:
|
||||
value_types.append(value_type)
|
||||
else:
|
||||
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
|
||||
if len(ctx.args[1]) == 0:
|
||||
return make_simplified_union(value_types)
|
||||
elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
|
||||
return make_simplified_union([*value_types, ctx.arg_types[1][0]])
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
def typed_dict_setdefault_signature_callback(ctx: MethodSigContext) -> CallableType:
|
||||
"""Try to infer a better signature type for TypedDict.setdefault.
|
||||
|
||||
This is used to get better type context for the second argument that
|
||||
depends on a TypedDict value type.
|
||||
"""
|
||||
signature = ctx.default_signature
|
||||
str_type = ctx.api.named_generic_type("builtins.str", [])
|
||||
if (
|
||||
isinstance(ctx.type, TypedDictType)
|
||||
and len(ctx.args) == 2
|
||||
and len(ctx.args[0]) == 1
|
||||
and isinstance(ctx.args[0][0], StrExpr)
|
||||
and len(signature.arg_types) == 2
|
||||
and len(ctx.args[1]) == 1
|
||||
):
|
||||
key = ctx.args[0][0].value
|
||||
value_type = ctx.type.items.get(key)
|
||||
if value_type:
|
||||
return signature.copy_modified(arg_types=[str_type, value_type])
|
||||
return signature.copy_modified(arg_types=[str_type, signature.arg_types[1]])
|
||||
|
||||
|
||||
def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
|
||||
"""Type check TypedDict.setdefault and infer a precise return type."""
|
||||
if (
|
||||
isinstance(ctx.type, TypedDictType)
|
||||
and len(ctx.arg_types) == 2
|
||||
and len(ctx.arg_types[0]) == 1
|
||||
and len(ctx.arg_types[1]) == 1
|
||||
):
|
||||
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
|
||||
if keys is None:
|
||||
ctx.api.fail(
|
||||
message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
|
||||
ctx.context,
|
||||
code=codes.LITERAL_REQ,
|
||||
)
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
|
||||
default_type = ctx.arg_types[1][0]
|
||||
|
||||
value_types = []
|
||||
for key in keys:
|
||||
value_type = ctx.type.items.get(key)
|
||||
|
||||
if value_type is None:
|
||||
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
|
||||
# The signature_callback above can't always infer the right signature
|
||||
# (e.g. when the expression is a variable that happens to be a Literal str)
|
||||
# so we need to handle the check ourselves here and make sure the provided
|
||||
# default can be assigned to all key-value pairs we're updating.
|
||||
if not is_subtype(default_type, value_type):
|
||||
ctx.api.msg.typeddict_setdefault_arguments_inconsistent(
|
||||
default_type, value_type, ctx.context
|
||||
)
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
|
||||
value_types.append(value_type)
|
||||
|
||||
return make_simplified_union(value_types)
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
|
||||
"""Type check TypedDict.__delitem__."""
|
||||
if (
|
||||
isinstance(ctx.type, TypedDictType)
|
||||
and len(ctx.arg_types) == 1
|
||||
and len(ctx.arg_types[0]) == 1
|
||||
):
|
||||
keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
|
||||
if keys is None:
|
||||
ctx.api.fail(
|
||||
message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
|
||||
ctx.context,
|
||||
code=codes.LITERAL_REQ,
|
||||
)
|
||||
return AnyType(TypeOfAny.from_error)
|
||||
|
||||
for key in keys:
|
||||
if key in ctx.type.required_keys:
|
||||
ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
|
||||
elif key not in ctx.type.items:
|
||||
ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType:
|
||||
"""Try to infer a better signature type for TypedDict.update."""
|
||||
signature = ctx.default_signature
|
||||
if isinstance(ctx.type, TypedDictType) and len(signature.arg_types) == 1:
|
||||
arg_type = get_proper_type(signature.arg_types[0])
|
||||
assert isinstance(arg_type, TypedDictType)
|
||||
arg_type = arg_type.as_anonymous()
|
||||
arg_type = arg_type.copy_modified(required_keys=set())
|
||||
if ctx.args and ctx.args[0]:
|
||||
with ctx.api.msg.filter_errors():
|
||||
inferred = get_proper_type(
|
||||
ctx.api.get_expression_type(ctx.args[0][0], type_context=arg_type)
|
||||
)
|
||||
possible_tds = []
|
||||
if isinstance(inferred, TypedDictType):
|
||||
possible_tds = [inferred]
|
||||
elif isinstance(inferred, UnionType):
|
||||
possible_tds = [
|
||||
t
|
||||
for t in get_proper_types(inferred.relevant_items())
|
||||
if isinstance(t, TypedDictType)
|
||||
]
|
||||
items = []
|
||||
for td in possible_tds:
|
||||
item = arg_type.copy_modified(
|
||||
required_keys=(arg_type.required_keys | td.required_keys)
|
||||
& arg_type.items.keys()
|
||||
)
|
||||
if not ctx.api.options.extra_checks:
|
||||
item = item.copy_modified(item_names=list(td.items))
|
||||
items.append(item)
|
||||
if items:
|
||||
arg_type = make_simplified_union(items)
|
||||
return signature.copy_modified(arg_types=[arg_type])
|
||||
return signature
|
||||
|
||||
|
||||
def int_pow_callback(ctx: MethodContext) -> Type:
|
||||
"""Infer a more precise return type for int.__pow__."""
|
||||
# int.__pow__ has an optional modulo argument,
|
||||
# so we expect 2 argument positions
|
||||
if len(ctx.arg_types) == 2 and len(ctx.arg_types[0]) == 1 and len(ctx.arg_types[1]) == 0:
|
||||
arg = ctx.args[0][0]
|
||||
if isinstance(arg, IntExpr):
|
||||
exponent = arg.value
|
||||
elif isinstance(arg, UnaryExpr) and arg.op == "-" and isinstance(arg.expr, IntExpr):
|
||||
exponent = -arg.expr.value
|
||||
else:
|
||||
# Right operand not an int literal or a negated literal -- give up.
|
||||
return ctx.default_return_type
|
||||
if exponent >= 0:
|
||||
return ctx.api.named_generic_type("builtins.int", [])
|
||||
else:
|
||||
return ctx.api.named_generic_type("builtins.float", [])
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
def int_neg_callback(ctx: MethodContext) -> Type:
|
||||
"""Infer a more precise return type for int.__neg__.
|
||||
|
||||
This is mainly used to infer the return type as LiteralType
|
||||
if the original underlying object is a LiteralType object
|
||||
"""
|
||||
if isinstance(ctx.type, Instance) and ctx.type.last_known_value is not None:
|
||||
value = ctx.type.last_known_value.value
|
||||
fallback = ctx.type.last_known_value.fallback
|
||||
if isinstance(value, int):
|
||||
if is_literal_type_like(ctx.api.type_context[-1]):
|
||||
return LiteralType(value=-value, fallback=fallback)
|
||||
else:
|
||||
return ctx.type.copy_modified(
|
||||
last_known_value=LiteralType(
|
||||
value=-value, fallback=ctx.type, line=ctx.type.line, column=ctx.type.column
|
||||
)
|
||||
)
|
||||
elif isinstance(ctx.type, LiteralType):
|
||||
value = ctx.type.value
|
||||
fallback = ctx.type.fallback
|
||||
if isinstance(value, int):
|
||||
return LiteralType(value=-value, fallback=fallback)
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
def tuple_mul_callback(ctx: MethodContext) -> Type:
|
||||
"""Infer a more precise return type for tuple.__mul__ and tuple.__rmul__.
|
||||
|
||||
This is used to return a specific sized tuple if multiplied by Literal int
|
||||
"""
|
||||
if not isinstance(ctx.type, TupleType):
|
||||
return ctx.default_return_type
|
||||
|
||||
arg_type = get_proper_type(ctx.arg_types[0][0])
|
||||
if isinstance(arg_type, Instance) and arg_type.last_known_value is not None:
|
||||
value = arg_type.last_known_value.value
|
||||
if isinstance(value, int):
|
||||
return ctx.type.copy_modified(items=ctx.type.items * value)
|
||||
elif isinstance(ctx.type, LiteralType):
|
||||
value = arg_type.value
|
||||
if isinstance(value, int):
|
||||
return ctx.type.copy_modified(items=ctx.type.items * value)
|
||||
|
||||
return ctx.default_return_type
|
||||
Binary file not shown.
258
venv/lib/python3.12/site-packages/mypy/plugins/enums.py
Normal file
258
venv/lib/python3.12/site-packages/mypy/plugins/enums.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""
|
||||
This file contains a variety of plugins for refining how mypy infers types of
|
||||
expressions involving Enums.
|
||||
|
||||
Currently, this file focuses on providing better inference for expressions like
|
||||
'SomeEnum.FOO.name' and 'SomeEnum.FOO.value'. Note that the type of both expressions
|
||||
will vary depending on exactly which instance of SomeEnum we're looking at.
|
||||
|
||||
Note that this file does *not* contain all special-cased logic related to enums:
|
||||
we actually bake some of it directly in to the semantic analysis layer (see
|
||||
semanal_enum.py).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Final, Iterable, Sequence, TypeVar, cast
|
||||
|
||||
import mypy.plugin # To avoid circular imports.
|
||||
from mypy.nodes import TypeInfo
|
||||
from mypy.semanal_enum import ENUM_BASES
|
||||
from mypy.subtypes import is_equivalent
|
||||
from mypy.typeops import fixup_partial_type, make_simplified_union
|
||||
from mypy.types import CallableType, Instance, LiteralType, ProperType, Type, get_proper_type
|
||||
|
||||
ENUM_NAME_ACCESS: Final = {f"{prefix}.name" for prefix in ENUM_BASES} | {
|
||||
f"{prefix}._name_" for prefix in ENUM_BASES
|
||||
}
|
||||
ENUM_VALUE_ACCESS: Final = {f"{prefix}.value" for prefix in ENUM_BASES} | {
|
||||
f"{prefix}._value_" for prefix in ENUM_BASES
|
||||
}
|
||||
|
||||
|
||||
def enum_name_callback(ctx: mypy.plugin.AttributeContext) -> Type:
|
||||
"""This plugin refines the 'name' attribute in enums to act as if
|
||||
they were declared to be final.
|
||||
|
||||
For example, the expression 'MyEnum.FOO.name' normally is inferred
|
||||
to be of type 'str'.
|
||||
|
||||
This plugin will instead make the inferred type be a 'str' where the
|
||||
last known value is 'Literal["FOO"]'. This means it would be legal to
|
||||
use 'MyEnum.FOO.name' in contexts that expect a Literal type, just like
|
||||
any other Final variable or attribute.
|
||||
|
||||
This plugin assumes that the provided context is an attribute access
|
||||
matching one of the strings found in 'ENUM_NAME_ACCESS'.
|
||||
"""
|
||||
enum_field_name = _extract_underlying_field_name(ctx.type)
|
||||
if enum_field_name is None:
|
||||
return ctx.default_attr_type
|
||||
else:
|
||||
str_type = ctx.api.named_generic_type("builtins.str", [])
|
||||
literal_type = LiteralType(enum_field_name, fallback=str_type)
|
||||
return str_type.copy_modified(last_known_value=literal_type)
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def _first(it: Iterable[_T]) -> _T | None:
|
||||
"""Return the first value from any iterable.
|
||||
|
||||
Returns ``None`` if the iterable is empty.
|
||||
"""
|
||||
for val in it:
|
||||
return val
|
||||
return None
|
||||
|
||||
|
||||
def _infer_value_type_with_auto_fallback(
|
||||
ctx: mypy.plugin.AttributeContext, proper_type: ProperType | None
|
||||
) -> Type | None:
|
||||
"""Figure out the type of an enum value accounting for `auto()`.
|
||||
|
||||
This method is a no-op for a `None` proper_type and also in the case where
|
||||
the type is not "enum.auto"
|
||||
"""
|
||||
if proper_type is None:
|
||||
return None
|
||||
proper_type = get_proper_type(fixup_partial_type(proper_type))
|
||||
if not (isinstance(proper_type, Instance) and proper_type.type.fullname == "enum.auto"):
|
||||
return proper_type
|
||||
assert isinstance(ctx.type, Instance), "An incorrect ctx.type was passed."
|
||||
info = ctx.type.type
|
||||
# Find the first _generate_next_value_ on the mro. We need to know
|
||||
# if it is `Enum` because `Enum` types say that the return-value of
|
||||
# `_generate_next_value_` is `Any`. In reality the default `auto()`
|
||||
# returns an `int` (presumably the `Any` in typeshed is to make it
|
||||
# easier to subclass and change the returned type).
|
||||
type_with_gnv = _first(ti for ti in info.mro if ti.names.get("_generate_next_value_"))
|
||||
if type_with_gnv is None:
|
||||
return ctx.default_attr_type
|
||||
|
||||
stnode = type_with_gnv.names["_generate_next_value_"]
|
||||
|
||||
# This should be a `CallableType`
|
||||
node_type = get_proper_type(stnode.type)
|
||||
if isinstance(node_type, CallableType):
|
||||
if type_with_gnv.fullname == "enum.Enum":
|
||||
int_type = ctx.api.named_generic_type("builtins.int", [])
|
||||
return int_type
|
||||
return get_proper_type(node_type.ret_type)
|
||||
return ctx.default_attr_type
|
||||
|
||||
|
||||
def _implements_new(info: TypeInfo) -> bool:
|
||||
"""Check whether __new__ comes from enum.Enum or was implemented in a
|
||||
subclass. In the latter case, we must infer Any as long as mypy can't infer
|
||||
the type of _value_ from assignments in __new__.
|
||||
"""
|
||||
type_with_new = _first(
|
||||
ti
|
||||
for ti in info.mro
|
||||
if ti.names.get("__new__") and not ti.fullname.startswith("builtins.")
|
||||
)
|
||||
if type_with_new is None:
|
||||
return False
|
||||
return type_with_new.fullname not in ("enum.Enum", "enum.IntEnum", "enum.StrEnum")
|
||||
|
||||
|
||||
def enum_value_callback(ctx: mypy.plugin.AttributeContext) -> Type:
|
||||
"""This plugin refines the 'value' attribute in enums to refer to
|
||||
the original underlying value. For example, suppose we have the
|
||||
following:
|
||||
|
||||
class SomeEnum:
|
||||
FOO = A()
|
||||
BAR = B()
|
||||
|
||||
By default, mypy will infer that 'SomeEnum.FOO.value' and
|
||||
'SomeEnum.BAR.value' both are of type 'Any'. This plugin refines
|
||||
this inference so that mypy understands the expressions are
|
||||
actually of types 'A' and 'B' respectively. This better reflects
|
||||
the actual runtime behavior.
|
||||
|
||||
This plugin works simply by looking up the original value assigned
|
||||
to the enum. For example, when this plugin sees 'SomeEnum.BAR.value',
|
||||
it will look up whatever type 'BAR' had in the SomeEnum TypeInfo and
|
||||
use that as the inferred type of the overall expression.
|
||||
|
||||
This plugin assumes that the provided context is an attribute access
|
||||
matching one of the strings found in 'ENUM_VALUE_ACCESS'.
|
||||
"""
|
||||
enum_field_name = _extract_underlying_field_name(ctx.type)
|
||||
if enum_field_name is None:
|
||||
# We do not know the enum field name (perhaps it was passed to a
|
||||
# function and we only know that it _is_ a member). All is not lost
|
||||
# however, if we can prove that the all of the enum members have the
|
||||
# same value-type, then it doesn't matter which member was passed in.
|
||||
# The value-type is still known.
|
||||
if isinstance(ctx.type, Instance):
|
||||
info = ctx.type.type
|
||||
|
||||
# As long as mypy doesn't understand attribute creation in __new__,
|
||||
# there is no way to predict the value type if the enum class has a
|
||||
# custom implementation
|
||||
if _implements_new(info):
|
||||
return ctx.default_attr_type
|
||||
|
||||
stnodes = (info.get(name) for name in info.names)
|
||||
|
||||
# Enums _can_ have methods and instance attributes.
|
||||
# Omit methods and attributes created by assigning to self.*
|
||||
# for our value inference.
|
||||
node_types = (
|
||||
get_proper_type(n.type) if n else None
|
||||
for n in stnodes
|
||||
if n is None or not n.implicit
|
||||
)
|
||||
proper_types = list(
|
||||
_infer_value_type_with_auto_fallback(ctx, t)
|
||||
for t in node_types
|
||||
if t is None or not isinstance(t, CallableType)
|
||||
)
|
||||
underlying_type = _first(proper_types)
|
||||
if underlying_type is None:
|
||||
return ctx.default_attr_type
|
||||
|
||||
# At first we try to predict future `value` type if all other items
|
||||
# have the same type. For example, `int`.
|
||||
# If this is the case, we simply return this type.
|
||||
# See https://github.com/python/mypy/pull/9443
|
||||
all_same_value_type = all(
|
||||
proper_type is not None and proper_type == underlying_type
|
||||
for proper_type in proper_types
|
||||
)
|
||||
if all_same_value_type:
|
||||
if underlying_type is not None:
|
||||
return underlying_type
|
||||
|
||||
# But, after we started treating all `Enum` values as `Final`,
|
||||
# we start to infer types in
|
||||
# `item = 1` as `Literal[1]`, not just `int`.
|
||||
# So, for example types in this `Enum` will all be different:
|
||||
#
|
||||
# class Ordering(IntEnum):
|
||||
# one = 1
|
||||
# two = 2
|
||||
# three = 3
|
||||
#
|
||||
# We will infer three `Literal` types here.
|
||||
# They are not the same, but they are equivalent.
|
||||
# So, we unify them to make sure `.value` prediction still works.
|
||||
# Result will be `Literal[1] | Literal[2] | Literal[3]` for this case.
|
||||
all_equivalent_types = all(
|
||||
proper_type is not None and is_equivalent(proper_type, underlying_type)
|
||||
for proper_type in proper_types
|
||||
)
|
||||
if all_equivalent_types:
|
||||
return make_simplified_union(cast(Sequence[Type], proper_types))
|
||||
return ctx.default_attr_type
|
||||
|
||||
assert isinstance(ctx.type, Instance)
|
||||
info = ctx.type.type
|
||||
|
||||
# As long as mypy doesn't understand attribute creation in __new__,
|
||||
# there is no way to predict the value type if the enum class has a
|
||||
# custom implementation
|
||||
if _implements_new(info):
|
||||
return ctx.default_attr_type
|
||||
|
||||
stnode = info.get(enum_field_name)
|
||||
if stnode is None:
|
||||
return ctx.default_attr_type
|
||||
|
||||
underlying_type = _infer_value_type_with_auto_fallback(ctx, get_proper_type(stnode.type))
|
||||
if underlying_type is None:
|
||||
return ctx.default_attr_type
|
||||
|
||||
return underlying_type
|
||||
|
||||
|
||||
def _extract_underlying_field_name(typ: Type) -> str | None:
|
||||
"""If the given type corresponds to some Enum instance, returns the
|
||||
original name of that enum. For example, if we receive in the type
|
||||
corresponding to 'SomeEnum.FOO', we return the string "SomeEnum.Foo".
|
||||
|
||||
This helper takes advantage of the fact that Enum instances are valid
|
||||
to use inside Literal[...] types. An expression like 'SomeEnum.FOO' is
|
||||
actually represented by an Instance type with a Literal enum fallback.
|
||||
|
||||
We can examine this Literal fallback to retrieve the string.
|
||||
"""
|
||||
typ = get_proper_type(typ)
|
||||
if not isinstance(typ, Instance):
|
||||
return None
|
||||
|
||||
if not typ.type.is_enum:
|
||||
return None
|
||||
|
||||
underlying_literal = typ.last_known_value
|
||||
if underlying_literal is None:
|
||||
return None
|
||||
|
||||
# The checks above have verified this LiteralType is representing an enum value,
|
||||
# which means the 'value' field is guaranteed to be the name of the enum field
|
||||
# as a string.
|
||||
assert isinstance(underlying_literal.value, str)
|
||||
return underlying_literal.value
|
||||
Binary file not shown.
103
venv/lib/python3.12/site-packages/mypy/plugins/functools.py
Normal file
103
venv/lib/python3.12/site-packages/mypy/plugins/functools.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""Plugin for supporting the functools standard library module."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Final, NamedTuple
|
||||
|
||||
import mypy.plugin
|
||||
from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var
|
||||
from mypy.plugins.common import add_method_to_class
|
||||
from mypy.types import AnyType, CallableType, Type, TypeOfAny, UnboundType, get_proper_type
|
||||
|
||||
functools_total_ordering_makers: Final = {"functools.total_ordering"}
|
||||
|
||||
_ORDERING_METHODS: Final = {"__lt__", "__le__", "__gt__", "__ge__"}
|
||||
|
||||
|
||||
class _MethodInfo(NamedTuple):
|
||||
is_static: bool
|
||||
type: CallableType
|
||||
|
||||
|
||||
def functools_total_ordering_maker_callback(
|
||||
ctx: mypy.plugin.ClassDefContext, auto_attribs_default: bool = False
|
||||
) -> bool:
|
||||
"""Add dunder methods to classes decorated with functools.total_ordering."""
|
||||
comparison_methods = _analyze_class(ctx)
|
||||
if not comparison_methods:
|
||||
ctx.api.fail(
|
||||
'No ordering operation defined when using "functools.total_ordering": < > <= >=',
|
||||
ctx.reason,
|
||||
)
|
||||
return True
|
||||
|
||||
# prefer __lt__ to __le__ to __gt__ to __ge__
|
||||
root = max(comparison_methods, key=lambda k: (comparison_methods[k] is None, k))
|
||||
root_method = comparison_methods[root]
|
||||
if not root_method:
|
||||
# None of the defined comparison methods can be analysed
|
||||
return True
|
||||
|
||||
other_type = _find_other_type(root_method)
|
||||
bool_type = ctx.api.named_type("builtins.bool")
|
||||
ret_type: Type = bool_type
|
||||
if root_method.type.ret_type != ctx.api.named_type("builtins.bool"):
|
||||
proper_ret_type = get_proper_type(root_method.type.ret_type)
|
||||
if not (
|
||||
isinstance(proper_ret_type, UnboundType)
|
||||
and proper_ret_type.name.split(".")[-1] == "bool"
|
||||
):
|
||||
ret_type = AnyType(TypeOfAny.implementation_artifact)
|
||||
for additional_op in _ORDERING_METHODS:
|
||||
# Either the method is not implemented
|
||||
# or has an unknown signature that we can now extrapolate.
|
||||
if not comparison_methods.get(additional_op):
|
||||
args = [Argument(Var("other", other_type), other_type, None, ARG_POS)]
|
||||
add_method_to_class(ctx.api, ctx.cls, additional_op, args, ret_type)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _find_other_type(method: _MethodInfo) -> Type:
|
||||
"""Find the type of the ``other`` argument in a comparison method."""
|
||||
first_arg_pos = 0 if method.is_static else 1
|
||||
cur_pos_arg = 0
|
||||
other_arg = None
|
||||
for arg_kind, arg_type in zip(method.type.arg_kinds, method.type.arg_types):
|
||||
if arg_kind.is_positional():
|
||||
if cur_pos_arg == first_arg_pos:
|
||||
other_arg = arg_type
|
||||
break
|
||||
|
||||
cur_pos_arg += 1
|
||||
elif arg_kind != ARG_STAR2:
|
||||
other_arg = arg_type
|
||||
break
|
||||
|
||||
if other_arg is None:
|
||||
return AnyType(TypeOfAny.implementation_artifact)
|
||||
|
||||
return other_arg
|
||||
|
||||
|
||||
def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> dict[str, _MethodInfo | None]:
|
||||
"""Analyze the class body, its parents, and return the comparison methods found."""
|
||||
# Traverse the MRO and collect ordering methods.
|
||||
comparison_methods: dict[str, _MethodInfo | None] = {}
|
||||
# Skip object because total_ordering does not use methods from object
|
||||
for cls in ctx.cls.info.mro[:-1]:
|
||||
for name in _ORDERING_METHODS:
|
||||
if name in cls.names and name not in comparison_methods:
|
||||
node = cls.names[name].node
|
||||
if isinstance(node, FuncItem) and isinstance(node.type, CallableType):
|
||||
comparison_methods[name] = _MethodInfo(node.is_static, node.type)
|
||||
continue
|
||||
|
||||
if isinstance(node, Var):
|
||||
proper_type = get_proper_type(node.type)
|
||||
if isinstance(proper_type, CallableType):
|
||||
comparison_methods[name] = _MethodInfo(node.is_staticmethod, proper_type)
|
||||
continue
|
||||
|
||||
comparison_methods[name] = None
|
||||
|
||||
return comparison_methods
|
||||
Binary file not shown.
224
venv/lib/python3.12/site-packages/mypy/plugins/singledispatch.py
Normal file
224
venv/lib/python3.12/site-packages/mypy/plugins/singledispatch.py
Normal file
@@ -0,0 +1,224 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Final, NamedTuple, Sequence, TypeVar, Union
|
||||
from typing_extensions import TypeAlias as _TypeAlias
|
||||
|
||||
from mypy.messages import format_type
|
||||
from mypy.nodes import ARG_POS, Argument, Block, ClassDef, Context, SymbolTable, TypeInfo, Var
|
||||
from mypy.options import Options
|
||||
from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext, MethodSigContext
|
||||
from mypy.plugins.common import add_method_to_class
|
||||
from mypy.subtypes import is_subtype
|
||||
from mypy.types import (
|
||||
AnyType,
|
||||
CallableType,
|
||||
FunctionLike,
|
||||
Instance,
|
||||
NoneType,
|
||||
Overloaded,
|
||||
Type,
|
||||
TypeOfAny,
|
||||
get_proper_type,
|
||||
)
|
||||
|
||||
|
||||
class SingledispatchTypeVars(NamedTuple):
|
||||
return_type: Type
|
||||
fallback: CallableType
|
||||
|
||||
|
||||
class RegisterCallableInfo(NamedTuple):
|
||||
register_type: Type
|
||||
singledispatch_obj: Instance
|
||||
|
||||
|
||||
SINGLEDISPATCH_TYPE: Final = "functools._SingleDispatchCallable"
|
||||
|
||||
SINGLEDISPATCH_REGISTER_METHOD: Final = f"{SINGLEDISPATCH_TYPE}.register"
|
||||
|
||||
SINGLEDISPATCH_CALLABLE_CALL_METHOD: Final = f"{SINGLEDISPATCH_TYPE}.__call__"
|
||||
|
||||
|
||||
def get_singledispatch_info(typ: Instance) -> SingledispatchTypeVars | None:
|
||||
if len(typ.args) == 2:
|
||||
return SingledispatchTypeVars(*typ.args) # type: ignore[arg-type]
|
||||
return None
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_first_arg(args: list[list[T]]) -> T | None:
|
||||
"""Get the element that corresponds to the first argument passed to the function"""
|
||||
if args and args[0]:
|
||||
return args[0][0]
|
||||
return None
|
||||
|
||||
|
||||
REGISTER_RETURN_CLASS: Final = "_SingleDispatchRegisterCallable"
|
||||
|
||||
REGISTER_CALLABLE_CALL_METHOD: Final = f"functools.{REGISTER_RETURN_CLASS}.__call__"
|
||||
|
||||
|
||||
def make_fake_register_class_instance(
|
||||
api: CheckerPluginInterface, type_args: Sequence[Type]
|
||||
) -> Instance:
|
||||
defn = ClassDef(REGISTER_RETURN_CLASS, Block([]))
|
||||
defn.fullname = f"functools.{REGISTER_RETURN_CLASS}"
|
||||
info = TypeInfo(SymbolTable(), defn, "functools")
|
||||
obj_type = api.named_generic_type("builtins.object", []).type
|
||||
info.bases = [Instance(obj_type, [])]
|
||||
info.mro = [info, obj_type]
|
||||
defn.info = info
|
||||
|
||||
func_arg = Argument(Var("name"), AnyType(TypeOfAny.implementation_artifact), None, ARG_POS)
|
||||
add_method_to_class(api, defn, "__call__", [func_arg], NoneType())
|
||||
|
||||
return Instance(info, type_args)
|
||||
|
||||
|
||||
PluginContext: _TypeAlias = Union[FunctionContext, MethodContext]
|
||||
|
||||
|
||||
def fail(ctx: PluginContext, msg: str, context: Context | None) -> None:
|
||||
"""Emit an error message.
|
||||
|
||||
This tries to emit an error message at the location specified by `context`, falling back to the
|
||||
location specified by `ctx.context`. This is helpful when the only context information about
|
||||
where you want to put the error message may be None (like it is for `CallableType.definition`)
|
||||
and falling back to the location of the calling function is fine."""
|
||||
# TODO: figure out if there is some more reliable way of getting context information, so this
|
||||
# function isn't necessary
|
||||
if context is not None:
|
||||
err_context = context
|
||||
else:
|
||||
err_context = ctx.context
|
||||
ctx.api.fail(msg, err_context)
|
||||
|
||||
|
||||
def create_singledispatch_function_callback(ctx: FunctionContext) -> Type:
|
||||
"""Called for functools.singledispatch"""
|
||||
func_type = get_proper_type(get_first_arg(ctx.arg_types))
|
||||
if isinstance(func_type, CallableType):
|
||||
if len(func_type.arg_kinds) < 1:
|
||||
fail(
|
||||
ctx, "Singledispatch function requires at least one argument", func_type.definition
|
||||
)
|
||||
return ctx.default_return_type
|
||||
|
||||
elif not func_type.arg_kinds[0].is_positional(star=True):
|
||||
fail(
|
||||
ctx,
|
||||
"First argument to singledispatch function must be a positional argument",
|
||||
func_type.definition,
|
||||
)
|
||||
return ctx.default_return_type
|
||||
|
||||
# singledispatch returns an instance of functools._SingleDispatchCallable according to
|
||||
# typeshed
|
||||
singledispatch_obj = get_proper_type(ctx.default_return_type)
|
||||
assert isinstance(singledispatch_obj, Instance)
|
||||
singledispatch_obj.args += (func_type,)
|
||||
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
def singledispatch_register_callback(ctx: MethodContext) -> Type:
|
||||
"""Called for functools._SingleDispatchCallable.register"""
|
||||
assert isinstance(ctx.type, Instance)
|
||||
# TODO: check that there's only one argument
|
||||
first_arg_type = get_proper_type(get_first_arg(ctx.arg_types))
|
||||
if isinstance(first_arg_type, (CallableType, Overloaded)) and first_arg_type.is_type_obj():
|
||||
# HACK: We received a class as an argument to register. We need to be able
|
||||
# to access the function that register is being applied to, and the typeshed definition
|
||||
# of register has it return a generic Callable, so we create a new
|
||||
# SingleDispatchRegisterCallable class, define a __call__ method, and then add a
|
||||
# plugin hook for that.
|
||||
|
||||
# is_subtype doesn't work when the right type is Overloaded, so we need the
|
||||
# actual type
|
||||
register_type = first_arg_type.items[0].ret_type
|
||||
type_args = RegisterCallableInfo(register_type, ctx.type)
|
||||
register_callable = make_fake_register_class_instance(ctx.api, type_args)
|
||||
return register_callable
|
||||
elif isinstance(first_arg_type, CallableType):
|
||||
# TODO: do more checking for registered functions
|
||||
register_function(ctx, ctx.type, first_arg_type, ctx.api.options)
|
||||
# The typeshed stubs for register say that the function returned is Callable[..., T], even
|
||||
# though the function returned is the same as the one passed in. We return the type of the
|
||||
# function so that mypy can properly type check cases where the registered function is used
|
||||
# directly (instead of through singledispatch)
|
||||
return first_arg_type
|
||||
|
||||
# fallback in case we don't recognize the arguments
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
def register_function(
|
||||
ctx: PluginContext,
|
||||
singledispatch_obj: Instance,
|
||||
func: Type,
|
||||
options: Options,
|
||||
register_arg: Type | None = None,
|
||||
) -> None:
|
||||
"""Register a function"""
|
||||
|
||||
func = get_proper_type(func)
|
||||
if not isinstance(func, CallableType):
|
||||
return
|
||||
metadata = get_singledispatch_info(singledispatch_obj)
|
||||
if metadata is None:
|
||||
# if we never added the fallback to the type variables, we already reported an error, so
|
||||
# just don't do anything here
|
||||
return
|
||||
dispatch_type = get_dispatch_type(func, register_arg)
|
||||
if dispatch_type is None:
|
||||
# TODO: report an error here that singledispatch requires at least one argument
|
||||
# (might want to do the error reporting in get_dispatch_type)
|
||||
return
|
||||
fallback = metadata.fallback
|
||||
|
||||
fallback_dispatch_type = fallback.arg_types[0]
|
||||
if not is_subtype(dispatch_type, fallback_dispatch_type):
|
||||
fail(
|
||||
ctx,
|
||||
"Dispatch type {} must be subtype of fallback function first argument {}".format(
|
||||
format_type(dispatch_type, options), format_type(fallback_dispatch_type, options)
|
||||
),
|
||||
func.definition,
|
||||
)
|
||||
return
|
||||
return
|
||||
|
||||
|
||||
def get_dispatch_type(func: CallableType, register_arg: Type | None) -> Type | None:
|
||||
if register_arg is not None:
|
||||
return register_arg
|
||||
if func.arg_types:
|
||||
return func.arg_types[0]
|
||||
return None
|
||||
|
||||
|
||||
def call_singledispatch_function_after_register_argument(ctx: MethodContext) -> Type:
|
||||
"""Called on the function after passing a type to register"""
|
||||
register_callable = ctx.type
|
||||
if isinstance(register_callable, Instance):
|
||||
type_args = RegisterCallableInfo(*register_callable.args) # type: ignore[arg-type]
|
||||
func = get_first_arg(ctx.arg_types)
|
||||
if func is not None:
|
||||
register_function(
|
||||
ctx, type_args.singledispatch_obj, func, ctx.api.options, type_args.register_type
|
||||
)
|
||||
# see call to register_function in the callback for register
|
||||
return func
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
def call_singledispatch_function_callback(ctx: MethodSigContext) -> FunctionLike:
|
||||
"""Called for functools._SingleDispatchCallable.__call__"""
|
||||
if not isinstance(ctx.type, Instance):
|
||||
return ctx.default_signature
|
||||
metadata = get_singledispatch_info(ctx.type)
|
||||
if metadata is None:
|
||||
return ctx.default_signature
|
||||
return metadata.fallback
|
||||
Reference in New Issue
Block a user