Major fixes and new features
2025-09-25 15:51:48 +09:00
parent dd7349bb4c
commit ddce9f5125
5586 changed files with 1470941 additions and 0 deletions


@@ -0,0 +1 @@
# This page intentionally left blank


@@ -0,0 +1,37 @@
"""Mypy type checker command line tool."""
from __future__ import annotations
import os
import sys
import traceback
from mypy.main import main, process_options
from mypy.util import FancyFormatter
def console_entry() -> None:
try:
main()
sys.stdout.flush()
sys.stderr.flush()
except BrokenPipeError:
# Python flushes standard streams on exit; redirect remaining output
# to devnull to avoid another BrokenPipeError at shutdown
devnull = os.open(os.devnull, os.O_WRONLY)
os.dup2(devnull, sys.stdout.fileno())
sys.exit(2)
except KeyboardInterrupt:
_, options = process_options(args=sys.argv[1:])
if options.show_traceback:
sys.stdout.write(traceback.format_exc())
formatter = FancyFormatter(sys.stdout, sys.stderr, False)
msg = "Interrupted\n"
sys.stdout.write(formatter.style(msg, color="red", bold=True))
sys.stdout.flush()
sys.stderr.flush()
sys.exit(2)
if __name__ == "__main__":
console_entry()
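
The BrokenPipeError handler above matters when mypy's output is piped into a pager or `head` that exits early. A minimal standalone sketch of the same guard (not mypy code; runnable on its own):

```
import os
import sys

def main() -> None:
    try:
        for i in range(1_000_000):
            print(i)
        sys.stdout.flush()
    except BrokenPipeError:
        # The pipe reader (e.g. `head`) is gone; hand stdout to devnull so
        # the interpreter's final flush doesn't raise a second error.
        devnull = os.open(os.devnull, os.O_WRONLY)
        os.dup2(devnull, sys.stdout.fileno())
        sys.exit(2)

if __name__ == "__main__":
    main()  # try: python sketch.py | head -n 3
```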


@@ -0,0 +1,94 @@
"""This module makes it possible to use mypy as part of a Python application.
Since mypy still changes, the API was kept utterly simple and non-intrusive.
It just mimics command line activation without starting a new interpreter.
So the normal docs about the mypy command line apply.
Changes in the command line version of mypy will be immediately usable.
Just import this module and then call the 'run' function with a parameter of
type List[str], containing what normally would have been the command line
arguments to mypy.
Function 'run' returns a Tuple[str, str, int], namely
(<normal_report>, <error_report>, <exit_status>),
in which <normal_report> is what mypy normally writes to sys.stdout,
<error_report> is what mypy normally writes to sys.stderr and exit_status is
the exit status mypy normally returns to the operating system.
Any pretty formatting is left to the caller.
The 'run_dmypy' function is similar, but instead mimics invocation of
dmypy. Note that run_dmypy is not thread-safe and modifies sys.stdout
and sys.stderr during its invocation.
Note that these APIs don't support incremental generation of error
messages.
Trivial example of code using this module:
import sys
from mypy import api
result = api.run(sys.argv[1:])
if result[0]:
print('\nType checking report:\n')
print(result[0]) # stdout
if result[1]:
print('\nError report:\n')
print(result[1]) # stderr
print('\nExit status:', result[2])
"""
from __future__ import annotations
import sys
from io import StringIO
from typing import Callable, TextIO
def _run(main_wrapper: Callable[[TextIO, TextIO], None]) -> tuple[str, str, int]:
stdout = StringIO()
stderr = StringIO()
try:
main_wrapper(stdout, stderr)
exit_status = 0
except SystemExit as system_exit:
assert isinstance(system_exit.code, int)
exit_status = system_exit.code
return stdout.getvalue(), stderr.getvalue(), exit_status
def run(args: list[str]) -> tuple[str, str, int]:
# Lazy import to avoid needing to import all of mypy to call run_dmypy
from mypy.main import main
return _run(
lambda stdout, stderr: main(args=args, stdout=stdout, stderr=stderr, clean_exit=True)
)
def run_dmypy(args: list[str]) -> tuple[str, str, int]:
from mypy.dmypy.client import main
# A bunch of effort has been put into threading stdout and stderr
# through the main API to avoid the threadsafety problems of
# modifying sys.stdout/sys.stderr, but that hasn't been done for
# the dmypy client, so we just do the non-threadsafe thing.
def f(stdout: TextIO, stderr: TextIO) -> None:
old_stdout = sys.stdout
old_stderr = sys.stderr
try:
sys.stdout = stdout
sys.stderr = stderr
main(args)
finally:
sys.stdout = old_stdout
sys.stderr = old_stderr
return _run(f)
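
A hedged usage sketch for run_dmypy, complementing the run example in the docstring above; `program.py` is a hypothetical file, and the subcommands are the usual dmypy ones:

```
from mypy import api

api.run_dmypy(["start"])
stdout, stderr, status = api.run_dmypy(["check", "program.py"])
print(stdout, end="")
api.run_dmypy(["stop"])
```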


@@ -0,0 +1,148 @@
from __future__ import annotations
from typing import Callable, Sequence
import mypy.subtypes
from mypy.expandtype import expand_type
from mypy.nodes import Context
from mypy.types import (
AnyType,
CallableType,
ParamSpecType,
PartialType,
Type,
TypeVarId,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
UninhabitedType,
UnpackType,
get_proper_type,
)
def get_target_type(
tvar: TypeVarLikeType,
type: Type,
callable: CallableType,
report_incompatible_typevar_value: Callable[[CallableType, Type, str, Context], None],
context: Context,
skip_unsatisfied: bool,
) -> Type | None:
p_type = get_proper_type(type)
if isinstance(p_type, UninhabitedType) and tvar.has_default():
return tvar.default
if isinstance(tvar, ParamSpecType):
return type
if isinstance(tvar, TypeVarTupleType):
return type
assert isinstance(tvar, TypeVarType)
values = tvar.values
if values:
if isinstance(p_type, AnyType):
return type
if isinstance(p_type, TypeVarType) and p_type.values:
# Allow substituting T1 for T if every allowed value of T1
# is also a legal value of T.
if all(any(mypy.subtypes.is_same_type(v, v1) for v in values) for v1 in p_type.values):
return type
matching = []
for value in values:
if mypy.subtypes.is_subtype(type, value):
matching.append(value)
if matching:
best = matching[0]
            # If there is more than one matching value, select the narrowest.
for match in matching[1:]:
if mypy.subtypes.is_subtype(match, best):
best = match
return best
if skip_unsatisfied:
return None
report_incompatible_typevar_value(callable, type, tvar.name, context)
else:
upper_bound = tvar.upper_bound
if not mypy.subtypes.is_subtype(type, upper_bound):
if skip_unsatisfied:
return None
report_incompatible_typevar_value(callable, type, tvar.name, context)
return type
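# Example of the promotion above: for AnyStr = TypeVar("AnyStr", str, bytes),
# applying a subclass of str selects the declared value str itself -- subtype
# arguments are promoted to the narrowest allowed value they match.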
def apply_generic_arguments(
callable: CallableType,
orig_types: Sequence[Type | None],
report_incompatible_typevar_value: Callable[[CallableType, Type, str, Context], None],
context: Context,
skip_unsatisfied: bool = False,
) -> CallableType:
"""Apply generic type arguments to a callable type.
For example, applying [int] to 'def [T] (T) -> T' results in
'def (int) -> int'.
Note that each type can be None; in this case, it will not be applied.
If `skip_unsatisfied` is True, then just skip the types that don't satisfy type variable
bound or constraints, instead of giving an error.
"""
tvars = callable.variables
assert len(tvars) == len(orig_types)
# Check that inferred type variable values are compatible with allowed
# values and bounds. Also, promote subtype values to allowed values.
# Create a map from type variable id to target type.
id_to_type: dict[TypeVarId, Type] = {}
for tvar, type in zip(tvars, orig_types):
assert not isinstance(type, PartialType), "Internal error: must never apply partial type"
if type is None:
continue
target_type = get_target_type(
tvar, type, callable, report_incompatible_typevar_value, context, skip_unsatisfied
)
if target_type is not None:
id_to_type[tvar.id] = target_type
# TODO: validate arg_kinds/arg_names for ParamSpec and TypeVarTuple replacements,
# not just type variable bounds above.
param_spec = callable.param_spec()
if param_spec is not None:
nt = id_to_type.get(param_spec.id)
if nt is not None:
# ParamSpec expansion is special-cased, so we need to always expand callable
# as a whole, not expanding arguments individually.
callable = expand_type(callable, id_to_type)
assert isinstance(callable, CallableType)
return callable.copy_modified(
variables=[tv for tv in tvars if tv.id not in id_to_type]
)
# Apply arguments to argument types.
var_arg = callable.var_arg()
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
callable = expand_type(callable, id_to_type)
assert isinstance(callable, CallableType)
return callable.copy_modified(variables=[tv for tv in tvars if tv.id not in id_to_type])
else:
callable = callable.copy_modified(
arg_types=[expand_type(at, id_to_type) for at in callable.arg_types]
)
# Apply arguments to TypeGuard if any.
if callable.type_guard is not None:
type_guard = expand_type(callable.type_guard, id_to_type)
else:
type_guard = None
# The callable may retain some type vars if only some were applied.
# TODO: move apply_poly() logic from checkexpr.py here when new inference
# becomes universally used (i.e. in all passes + in unification).
# With this new logic we can actually *add* some new free variables.
remaining_tvars = [tv for tv in tvars if tv.id not in id_to_type]
return callable.copy_modified(
ret_type=expand_type(callable.ret_type, id_to_type),
variables=remaining_tvars,
type_guard=type_guard,
)
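
A user-level illustration of what apply_generic_arguments computes, matching the docstring example of applying [int] to 'def [T] (T) -> T' (a hedged sketch; Python 3.11+ for typing.reveal_type):

```
from typing import TypeVar, reveal_type

T = TypeVar("T")

def identity(x: T) -> T:
    return x

# Checking this call applies [int] to 'def [T] (T) -> T':
reveal_type(identity(1))  # revealed type is "builtins.int"
```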


@@ -0,0 +1,247 @@
"""Utilities for mapping between actual and formal arguments (and their types)."""
from __future__ import annotations
from typing import TYPE_CHECKING, Callable, Sequence
from mypy import nodes
from mypy.maptype import map_instance_to_supertype
from mypy.types import (
AnyType,
Instance,
ParamSpecType,
TupleType,
Type,
TypedDictType,
TypeOfAny,
get_proper_type,
)
if TYPE_CHECKING:
from mypy.infer import ArgumentInferContext
def map_actuals_to_formals(
actual_kinds: list[nodes.ArgKind],
actual_names: Sequence[str | None] | None,
formal_kinds: list[nodes.ArgKind],
formal_names: Sequence[str | None],
actual_arg_type: Callable[[int], Type],
) -> list[list[int]]:
"""Calculate mapping between actual (caller) args and formals.
The result contains a list of caller argument indexes mapping to each
callee argument index, indexed by callee index.
    The actual_arg_type argument should evaluate to the type of the actual
    argument with the given index.
"""
nformals = len(formal_kinds)
formal_to_actual: list[list[int]] = [[] for i in range(nformals)]
ambiguous_actual_kwargs: list[int] = []
fi = 0
for ai, actual_kind in enumerate(actual_kinds):
if actual_kind == nodes.ARG_POS:
if fi < nformals:
if not formal_kinds[fi].is_star():
formal_to_actual[fi].append(ai)
fi += 1
elif formal_kinds[fi] == nodes.ARG_STAR:
formal_to_actual[fi].append(ai)
elif actual_kind == nodes.ARG_STAR:
# We need to know the actual type to map varargs.
actualt = get_proper_type(actual_arg_type(ai))
if isinstance(actualt, TupleType):
# A tuple actual maps to a fixed number of formals.
for _ in range(len(actualt.items)):
                    if fi < nformals:
                        if formal_kinds[fi] != nodes.ARG_STAR2:
                            formal_to_actual[fi].append(ai)
                        else:
                            break
                        # Stay on an ARG_STAR formal so it can absorb the
                        # remaining tuple items; guarded by fi < nformals to
                        # avoid indexing past the formals.
                        if formal_kinds[fi] != nodes.ARG_STAR:
                            fi += 1
else:
# Assume that it is an iterable (if it isn't, there will be
# an error later).
while fi < nformals:
if formal_kinds[fi].is_named(star=True):
break
else:
formal_to_actual[fi].append(ai)
if formal_kinds[fi] == nodes.ARG_STAR:
break
fi += 1
elif actual_kind.is_named():
assert actual_names is not None, "Internal error: named kinds without names given"
name = actual_names[ai]
if name in formal_names:
formal_to_actual[formal_names.index(name)].append(ai)
elif nodes.ARG_STAR2 in formal_kinds:
formal_to_actual[formal_kinds.index(nodes.ARG_STAR2)].append(ai)
else:
assert actual_kind == nodes.ARG_STAR2
actualt = get_proper_type(actual_arg_type(ai))
if isinstance(actualt, TypedDictType):
for name in actualt.items:
if name in formal_names:
formal_to_actual[formal_names.index(name)].append(ai)
elif nodes.ARG_STAR2 in formal_kinds:
formal_to_actual[formal_kinds.index(nodes.ARG_STAR2)].append(ai)
else:
# We don't exactly know which **kwargs are provided by the
# caller, so we'll defer until all the other unambiguous
# actuals have been processed
ambiguous_actual_kwargs.append(ai)
if ambiguous_actual_kwargs:
# Assume the ambiguous kwargs will fill the remaining arguments.
#
# TODO: If there are also tuple varargs, we might be missing some potential
# matches if the tuple was short enough to not match everything.
unmatched_formals = [
fi
for fi in range(nformals)
if (
formal_names[fi]
and (
not formal_to_actual[fi]
or actual_kinds[formal_to_actual[fi][0]] == nodes.ARG_STAR
)
and formal_kinds[fi] != nodes.ARG_STAR
)
or formal_kinds[fi] == nodes.ARG_STAR2
]
for ai in ambiguous_actual_kwargs:
for fi in unmatched_formals:
formal_to_actual[fi].append(ai)
return formal_to_actual
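# Hedged worked example of the mapping above: for a callee
# `def f(x, y, *rest)` called as `f(1, 2, 3)`, formals 0 and 1 each get one
# positional actual and the *rest formal collects the third (AnyType stands
# in for real inferred argument types):
#
#     from mypy.types import AnyType, TypeOfAny
#     map_actuals_to_formals(
#         actual_kinds=[nodes.ARG_POS, nodes.ARG_POS, nodes.ARG_POS],
#         actual_names=None,
#         formal_kinds=[nodes.ARG_POS, nodes.ARG_POS, nodes.ARG_STAR],
#         formal_names=["x", "y", "rest"],
#         actual_arg_type=lambda i: AnyType(TypeOfAny.unannotated),
#     )  # -> [[0], [1], [2]]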
def map_formals_to_actuals(
actual_kinds: list[nodes.ArgKind],
actual_names: Sequence[str | None] | None,
formal_kinds: list[nodes.ArgKind],
formal_names: list[str | None],
actual_arg_type: Callable[[int], Type],
) -> list[list[int]]:
"""Calculate the reverse mapping of map_actuals_to_formals."""
formal_to_actual = map_actuals_to_formals(
actual_kinds, actual_names, formal_kinds, formal_names, actual_arg_type
)
# Now reverse the mapping.
actual_to_formal: list[list[int]] = [[] for _ in actual_kinds]
for formal, actuals in enumerate(formal_to_actual):
for actual in actuals:
actual_to_formal[actual].append(formal)
return actual_to_formal
class ArgTypeExpander:
"""Utility class for mapping actual argument types to formal arguments.
One of the main responsibilities is to expand caller tuple *args and TypedDict
**kwargs, and to keep track of which tuple/TypedDict items have already been
consumed.
Example:
def f(x: int, *args: str) -> None: ...
f(*(1, 'x', 1.1))
We'd call expand_actual_type three times:
1. The first call would provide 'int' as the actual type of 'x' (from '1').
2. The second call would provide 'str' as one of the actual types for '*args'.
    3. The third call would provide 'float' as one of the actual types for '*args'.
A single instance can process all the arguments for a single call. Each call
needs a separate instance since instances have per-call state.
"""
def __init__(self, context: ArgumentInferContext) -> None:
# Next tuple *args index to use.
self.tuple_index = 0
# Keyword arguments in TypedDict **kwargs used.
self.kwargs_used: set[str] = set()
# Type context for `*` and `**` arg kinds.
self.context = context
def expand_actual_type(
self,
actual_type: Type,
actual_kind: nodes.ArgKind,
formal_name: str | None,
formal_kind: nodes.ArgKind,
) -> Type:
"""Return the actual (caller) type(s) of a formal argument with the given kinds.
If the actual argument is a tuple *args, return the next individual tuple item that
maps to the formal arg.
If the actual argument is a TypedDict **kwargs, return the next matching typed dict
value type based on formal argument name and kind.
This is supposed to be called for each formal, in order. Call multiple times per
formal if multiple actuals map to a formal.
"""
original_actual = actual_type
actual_type = get_proper_type(actual_type)
if actual_kind == nodes.ARG_STAR:
if isinstance(actual_type, Instance) and actual_type.args:
from mypy.subtypes import is_subtype
if is_subtype(actual_type, self.context.iterable_type):
return map_instance_to_supertype(
actual_type, self.context.iterable_type.type
).args[0]
else:
# We cannot properly unpack anything other
# than `Iterable` type with `*`.
# Just return `Any`, other parts of code would raise
# a different error for improper use.
return AnyType(TypeOfAny.from_error)
elif isinstance(actual_type, TupleType):
# Get the next tuple item of a tuple *arg.
if self.tuple_index >= len(actual_type.items):
# Exhausted a tuple -- continue to the next *args.
self.tuple_index = 1
else:
self.tuple_index += 1
return actual_type.items[self.tuple_index - 1]
elif isinstance(actual_type, ParamSpecType):
# ParamSpec is valid in *args but it can't be unpacked.
return actual_type
else:
return AnyType(TypeOfAny.from_error)
elif actual_kind == nodes.ARG_STAR2:
from mypy.subtypes import is_subtype
if isinstance(actual_type, TypedDictType):
if formal_kind != nodes.ARG_STAR2 and formal_name in actual_type.items:
# Lookup type based on keyword argument name.
assert formal_name is not None
else:
# Pick an arbitrary item if no specified keyword is expected.
formal_name = (set(actual_type.items.keys()) - self.kwargs_used).pop()
self.kwargs_used.add(formal_name)
return actual_type.items[formal_name]
elif (
isinstance(actual_type, Instance)
and len(actual_type.args) > 1
and is_subtype(actual_type, self.context.mapping_type)
):
# Only `Mapping` type can be unpacked with `**`.
# Other types will produce an error somewhere else.
return map_instance_to_supertype(actual_type, self.context.mapping_type.type).args[
1
]
elif isinstance(actual_type, ParamSpecType):
# ParamSpec is valid in **kwargs but it can't be unpacked.
return actual_type
else:
return AnyType(TypeOfAny.from_error)
else:
# No translation for other kinds -- 1:1 mapping.
return original_actual
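
The tuple bookkeeping in expand_actual_type is easy to miss, so here it is isolated in a pure-Python sketch (not mypy's API) that reproduces the docstring example of f(*(1, 'x', 1.1)) feeding items to the formals one at a time:

```
class TupleArgExpander:
    """Hand out tuple *args items one per call, as expand_actual_type does."""

    def __init__(self) -> None:
        self.tuple_index = 0  # next tuple item to hand out

    def next_item(self, items: tuple) -> object:
        if self.tuple_index >= len(items):
            self.tuple_index = 1  # exhausted: restart for the next *args actual
        else:
            self.tuple_index += 1
        return items[self.tuple_index - 1]

exp = TupleArgExpander()
assert [exp.next_item((1, "x", 1.1)) for _ in range(3)] == [1, "x", 1.1]
```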


@@ -0,0 +1,455 @@
from __future__ import annotations
from collections import defaultdict
from contextlib import contextmanager
from typing import DefaultDict, Iterator, List, Optional, Tuple, Union, cast
from typing_extensions import TypeAlias as _TypeAlias
from mypy.erasetype import remove_instance_last_known_values
from mypy.join import join_simple
from mypy.literals import Key, literal, literal_hash, subkeys
from mypy.nodes import Expression, IndexExpr, MemberExpr, NameExpr, RefExpr, TypeInfo, Var
from mypy.subtypes import is_same_type, is_subtype
from mypy.types import (
AnyType,
NoneType,
PartialType,
Type,
TypeOfAny,
TypeType,
UnionType,
get_proper_type,
)
from mypy.typevars import fill_typevars_with_any
BindableExpression: _TypeAlias = Union[IndexExpr, MemberExpr, NameExpr]
class Frame:
"""A Frame represents a specific point in the execution of a program.
It carries information about the current types of expressions at
that point, arising either from assignments to those expressions
or the result of isinstance checks. It also records whether it is
possible to reach that point at all.
This information is not copied into a new Frame when it is pushed
onto the stack, so a given Frame only has information about types
that were assigned in that frame.
"""
def __init__(self, id: int, conditional_frame: bool = False) -> None:
self.id = id
self.types: dict[Key, Type] = {}
self.unreachable = False
self.conditional_frame = conditional_frame
self.suppress_unreachable_warnings = False
def __repr__(self) -> str:
return f"Frame({self.id}, {self.types}, {self.unreachable}, {self.conditional_frame})"
Assigns = DefaultDict[Expression, List[Tuple[Type, Optional[Type]]]]
class ConditionalTypeBinder:
"""Keep track of conditional types of variables.
NB: Variables are tracked by literal expression, so it is possible
to confuse the binder; for example,
```
class A:
a: Union[int, str] = None
x = A()
lst = [x]
reveal_type(x.a) # Union[int, str]
x.a = 1
reveal_type(x.a) # int
reveal_type(lst[0].a) # Union[int, str]
lst[0].a = 'a'
reveal_type(x.a) # int
reveal_type(lst[0].a) # str
```
"""
# Stored assignments for situations with tuple/list lvalue and rvalue of union type.
# This maps an expression to a list of bound types for every item in the union type.
type_assignments: Assigns | None = None
def __init__(self) -> None:
self.next_id = 1
# The stack of frames currently used. These map
# literal_hash(expr) -- literals like 'foo.bar' --
# to types. The last element of this list is the
# top-most, current frame. Each earlier element
# records the state as of when that frame was last
# on top of the stack.
self.frames = [Frame(self._get_id())]
# For frames higher in the stack, we record the set of
# Frames that can escape there, either by falling off
# the end of the frame or by a loop control construct
# or raised exception. The last element of self.frames
# has no corresponding element in this list.
self.options_on_return: list[list[Frame]] = []
# Maps literal_hash(expr) to get_declaration(expr)
# for every expr stored in the binder
self.declarations: dict[Key, Type | None] = {}
# Set of other keys to invalidate if a key is changed, e.g. x -> {x.a, x[0]}
# Whenever a new key (e.g. x.a.b) is added, we update this
self.dependencies: dict[Key, set[Key]] = {}
# Whether the last pop changed the newly top frame on exit
self.last_pop_changed = False
self.try_frames: set[int] = set()
self.break_frames: list[int] = []
self.continue_frames: list[int] = []
def _get_id(self) -> int:
self.next_id += 1
return self.next_id
def _add_dependencies(self, key: Key, value: Key | None = None) -> None:
if value is None:
value = key
else:
self.dependencies.setdefault(key, set()).add(value)
for elt in subkeys(key):
self._add_dependencies(elt, value)
def push_frame(self, conditional_frame: bool = False) -> Frame:
"""Push a new frame into the binder."""
f = Frame(self._get_id(), conditional_frame)
self.frames.append(f)
self.options_on_return.append([])
return f
def _put(self, key: Key, type: Type, index: int = -1) -> None:
self.frames[index].types[key] = type
def _get(self, key: Key, index: int = -1) -> Type | None:
if index < 0:
index += len(self.frames)
for i in range(index, -1, -1):
if key in self.frames[i].types:
return self.frames[i].types[key]
return None
def put(self, expr: Expression, typ: Type) -> None:
if not isinstance(expr, (IndexExpr, MemberExpr, NameExpr)):
return
if not literal(expr):
return
key = literal_hash(expr)
assert key is not None, "Internal error: binder tried to put non-literal"
if key not in self.declarations:
self.declarations[key] = get_declaration(expr)
self._add_dependencies(key)
self._put(key, typ)
def unreachable(self) -> None:
self.frames[-1].unreachable = True
def suppress_unreachable_warnings(self) -> None:
self.frames[-1].suppress_unreachable_warnings = True
def get(self, expr: Expression) -> Type | None:
key = literal_hash(expr)
assert key is not None, "Internal error: binder tried to get non-literal"
return self._get(key)
def is_unreachable(self) -> bool:
# TODO: Copy the value of unreachable into new frames to avoid
# this traversal on every statement?
return any(f.unreachable for f in self.frames)
def is_unreachable_warning_suppressed(self) -> bool:
return any(f.suppress_unreachable_warnings for f in self.frames)
def cleanse(self, expr: Expression) -> None:
"""Remove all references to a Node from the binder."""
key = literal_hash(expr)
        assert key is not None, "Internal error: binder tried to cleanse non-literal"
self._cleanse_key(key)
def _cleanse_key(self, key: Key) -> None:
"""Remove all references to a key from the binder."""
for frame in self.frames:
if key in frame.types:
del frame.types[key]
def update_from_options(self, frames: list[Frame]) -> bool:
"""Update the frame to reflect that each key will be updated
as in one of the frames. Return whether any item changes.
If a key is declared as AnyType, only update it if all the
options are the same.
"""
frames = [f for f in frames if not f.unreachable]
changed = False
keys = {key for f in frames for key in f.types}
for key in keys:
current_value = self._get(key)
resulting_values = [f.types.get(key, current_value) for f in frames]
if any(x is None for x in resulting_values):
# We didn't know anything about key before
# (current_value must be None), and we still don't
# know anything about key in at least one possible frame.
continue
type = resulting_values[0]
assert type is not None
declaration_type = get_proper_type(self.declarations.get(key))
if isinstance(declaration_type, AnyType):
# At this point resulting values can't contain None, see continue above
if not all(is_same_type(type, cast(Type, t)) for t in resulting_values[1:]):
type = AnyType(TypeOfAny.from_another_any, source_any=declaration_type)
else:
for other in resulting_values[1:]:
assert other is not None
type = join_simple(self.declarations[key], type, other)
if current_value is None or not is_same_type(type, current_value):
self._put(key, type)
changed = True
self.frames[-1].unreachable = not frames
return changed
def pop_frame(self, can_skip: bool, fall_through: int) -> Frame:
"""Pop a frame and return it.
See frame_context() for documentation of fall_through.
"""
if fall_through > 0:
self.allow_jump(-fall_through)
result = self.frames.pop()
options = self.options_on_return.pop()
if can_skip:
options.insert(0, self.frames[-1])
self.last_pop_changed = self.update_from_options(options)
return result
@contextmanager
def accumulate_type_assignments(self) -> Iterator[Assigns]:
"""Push a new map to collect assigned types in multiassign from union.
If this map is not None, actual binding is deferred until all items in
the union are processed (a union of collected items is later bound
manually by the caller).
"""
old_assignments = None
if self.type_assignments is not None:
old_assignments = self.type_assignments
self.type_assignments = defaultdict(list)
yield self.type_assignments
self.type_assignments = old_assignments
def assign_type(
self, expr: Expression, type: Type, declared_type: Type | None, restrict_any: bool = False
) -> None:
# We should erase last known value in binder, because if we are using it,
# it means that the target is not final, and therefore can't hold a literal.
type = remove_instance_last_known_values(type)
if self.type_assignments is not None:
# We are in a multiassign from union, defer the actual binding,
# just collect the types.
self.type_assignments[expr].append((type, declared_type))
return
if not isinstance(expr, (IndexExpr, MemberExpr, NameExpr)):
return None
if not literal(expr):
return
self.invalidate_dependencies(expr)
if declared_type is None:
# Not sure why this happens. It seems to mainly happen in
# member initialization.
return
if not is_subtype(type, declared_type):
            # Pretty sure this only happens when there's a type error.
# Ideally this function wouldn't be called if the
# expression has a type error, though -- do other kinds of
# errors cause this function to get called at invalid
# times?
return
p_declared = get_proper_type(declared_type)
p_type = get_proper_type(type)
enclosing_type = get_proper_type(self.most_recent_enclosing_type(expr, type))
if isinstance(enclosing_type, AnyType) and not restrict_any:
# If x is Any and y is int, after x = y we do not infer that x is int.
# This could be changed.
# Instead, since we narrowed type from Any in a recent frame (probably an
# isinstance check), but now it is reassigned, we broaden back
# to Any (which is the most recent enclosing type)
self.put(expr, enclosing_type)
# As a special case, when assigning Any to a variable with a
# declared Optional type that has been narrowed to None,
# replace all the Nones in the declared Union type with Any.
# This overrides the normal behavior of ignoring Any assignments to variables
# in order to prevent false positives.
# (See discussion in #3526)
elif (
isinstance(p_type, AnyType)
and isinstance(p_declared, UnionType)
and any(isinstance(get_proper_type(item), NoneType) for item in p_declared.items)
and isinstance(
get_proper_type(self.most_recent_enclosing_type(expr, NoneType())), NoneType
)
):
# Replace any Nones in the union type with Any
new_items = [
type if isinstance(get_proper_type(item), NoneType) else item
for item in p_declared.items
]
self.put(expr, UnionType(new_items))
elif isinstance(p_type, AnyType) and not (
isinstance(p_declared, UnionType)
and any(isinstance(get_proper_type(item), AnyType) for item in p_declared.items)
):
# Assigning an Any value doesn't affect the type to avoid false negatives, unless
# there is an Any item in a declared union type.
self.put(expr, declared_type)
else:
self.put(expr, type)
for i in self.try_frames:
# XXX This should probably not copy the entire frame, but
# just copy this variable into a single stored frame.
self.allow_jump(i)
def invalidate_dependencies(self, expr: BindableExpression) -> None:
"""Invalidate knowledge of types that include expr, but not expr itself.
For example, when expr is foo.bar, invalidate foo.bar.baz.
It is overly conservative: it invalidates globally, including
in code paths unreachable from here.
"""
key = literal_hash(expr)
assert key is not None
for dep in self.dependencies.get(key, set()):
self._cleanse_key(dep)
def most_recent_enclosing_type(self, expr: BindableExpression, type: Type) -> Type | None:
type = get_proper_type(type)
if isinstance(type, AnyType):
return get_declaration(expr)
key = literal_hash(expr)
assert key is not None
enclosers = [get_declaration(expr)] + [
f.types[key] for f in self.frames if key in f.types and is_subtype(type, f.types[key])
]
return enclosers[-1]
def allow_jump(self, index: int) -> None:
# self.frames and self.options_on_return have different lengths
# so make sure the index is positive
if index < 0:
index += len(self.options_on_return)
frame = Frame(self._get_id())
for f in self.frames[index + 1 :]:
frame.types.update(f.types)
if f.unreachable:
frame.unreachable = True
self.options_on_return[index].append(frame)
def handle_break(self) -> None:
self.allow_jump(self.break_frames[-1])
self.unreachable()
def handle_continue(self) -> None:
self.allow_jump(self.continue_frames[-1])
self.unreachable()
@contextmanager
def frame_context(
self,
*,
can_skip: bool,
fall_through: int = 1,
break_frame: int = 0,
continue_frame: int = 0,
conditional_frame: bool = False,
try_frame: bool = False,
) -> Iterator[Frame]:
"""Return a context manager that pushes/pops frames on enter/exit.
If can_skip is True, control flow is allowed to bypass the
newly-created frame.
If fall_through > 0, then it will allow control flow that
falls off the end of the frame to escape to its ancestor
`fall_through` levels higher. Otherwise control flow ends
at the end of the frame.
If break_frame > 0, then 'break' statements within this frame
will jump out to the frame break_frame levels higher than the
frame created by this call to frame_context. Similarly for
continue_frame and 'continue' statements.
If try_frame is true, then execution is allowed to jump at any
point within the newly created frame (or its descendants) to
its parent (i.e., to the frame that was on top before this
call to frame_context).
After the context manager exits, self.last_pop_changed indicates
whether any types changed in the newly-topmost frame as a result
of popping this frame.
"""
assert len(self.frames) > 1
if break_frame:
self.break_frames.append(len(self.frames) - break_frame)
if continue_frame:
self.continue_frames.append(len(self.frames) - continue_frame)
if try_frame:
self.try_frames.add(len(self.frames) - 1)
new_frame = self.push_frame(conditional_frame)
if try_frame:
# An exception may occur immediately
self.allow_jump(-1)
yield new_frame
self.pop_frame(can_skip, fall_through)
if break_frame:
self.break_frames.pop()
if continue_frame:
self.continue_frames.pop()
if try_frame:
self.try_frames.remove(len(self.frames) - 1)
@contextmanager
def top_frame_context(self) -> Iterator[Frame]:
"""A variant of frame_context for use at the top level of
a namespace (module, function, or class).
"""
assert len(self.frames) == 1
yield self.push_frame()
self.pop_frame(True, 0)
assert len(self.frames) == 1
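# Hedged usage sketch of the frame API above (runnable with mypy installed):
#
#     binder = ConditionalTypeBinder()
#     with binder.top_frame_context():
#         with binder.frame_context(can_skip=True, fall_through=1):
#             binder.unreachable()
#         # can_skip=True lets control flow bypass the inner frame, so the
#         # merged state is still reachable after the pop.
#         assert not binder.is_unreachable()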
def get_declaration(expr: BindableExpression) -> Type | None:
if isinstance(expr, RefExpr):
if isinstance(expr.node, Var):
type = expr.node.type
if not isinstance(get_proper_type(type), PartialType):
return type
elif isinstance(expr.node, TypeInfo):
return TypeType(fill_typevars_with_any(expr.node))
return None
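
For orientation, this is the kind of narrowing the binder records, seen from user code (a hedged illustration; Python 3.11+ for typing.reveal_type):

```
from typing import Union, reveal_type

def f(x: Union[int, str]) -> None:
    if isinstance(x, int):
        reveal_type(x)  # int: narrowed in the conditional frame
    else:
        reveal_type(x)  # str
    reveal_type(x)      # Union[int, str]: branch frames merged on pop
```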


@@ -0,0 +1,27 @@
"""A Bogus[T] type alias for marking when we subvert the type system
We need this for compiling with mypyc, which inserts runtime
typechecks that cause problems when we subvert the type system. So
when compiling with mypyc, we turn those places into Any, while
keeping the types around for normal typechecks.
Since this causes the runtime types to be Any, this is best used
in places where efficient access to properties is not important.
For those cases some other technique should be used.
"""
from __future__ import annotations
from typing import Any, TypeVar
from mypy_extensions import FlexibleAlias
T = TypeVar("T")
# This won't ever be true at runtime, but we consider it true during
# mypyc compilations.
MYPYC = False
if MYPYC:
Bogus = FlexibleAlias[T, Any]
else:
Bogus = FlexibleAlias[T, T]
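
A hedged sketch of how such an alias is used (the class here is hypothetical). Under normal checking Bogus[str] behaves as str; when MYPYC is true it erases to Any, so mypyc's inserted runtime checks don't fire. The __future__ import keeps the subscript from being evaluated at runtime:

```
from __future__ import annotations

from mypy.bogus_type import Bogus

class Node:
    # Checked as str, but erased to Any under mypyc compilation.
    fullname: Bogus[str]

    def __init__(self, fullname: str) -> None:
        self.fullname = fullname
```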

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -0,0 +1,721 @@
"""Pattern checker. This file is conceptually part of TypeChecker."""
from __future__ import annotations
from collections import defaultdict
from typing import Final, NamedTuple
import mypy.checker
from mypy import message_registry
from mypy.checkmember import analyze_member_access
from mypy.expandtype import expand_type_by_instance
from mypy.join import join_types
from mypy.literals import literal_hash
from mypy.maptype import map_instance_to_supertype
from mypy.meet import narrow_declared_type
from mypy.messages import MessageBuilder
from mypy.nodes import ARG_POS, Context, Expression, NameExpr, TypeAlias, TypeInfo, Var
from mypy.options import Options
from mypy.patterns import (
AsPattern,
ClassPattern,
MappingPattern,
OrPattern,
Pattern,
SequencePattern,
SingletonPattern,
StarredPattern,
ValuePattern,
)
from mypy.plugin import Plugin
from mypy.subtypes import is_subtype
from mypy.typeops import (
coerce_to_literal,
make_simplified_union,
try_getting_str_literals_from_type,
tuple_fallback,
)
from mypy.types import (
AnyType,
Instance,
LiteralType,
NoneType,
ProperType,
TupleType,
Type,
TypedDictType,
TypeOfAny,
UninhabitedType,
UnionType,
get_proper_type,
)
from mypy.typevars import fill_typevars
from mypy.visitor import PatternVisitor
self_match_type_names: Final = [
"builtins.bool",
"builtins.bytearray",
"builtins.bytes",
"builtins.dict",
"builtins.float",
"builtins.frozenset",
"builtins.int",
"builtins.list",
"builtins.set",
"builtins.str",
"builtins.tuple",
]
non_sequence_match_type_names: Final = ["builtins.str", "builtins.bytes", "builtins.bytearray"]
# For every Pattern a PatternType can be calculated. This requires recursively calculating
# the PatternTypes of the sub-patterns first.
# Using the data in the PatternType the match subject and captured names can be narrowed/inferred.
class PatternType(NamedTuple):
type: Type # The type the match subject can be narrowed to
rest_type: Type # The remaining type if the pattern didn't match
captures: dict[Expression, Type] # The variables captured by the pattern
class PatternChecker(PatternVisitor[PatternType]):
"""Pattern checker.
This class checks if a pattern can match a type, what the type can be narrowed to, and what
type capture patterns should be inferred as.
"""
# Some services are provided by a TypeChecker instance.
chk: mypy.checker.TypeChecker
# This is shared with TypeChecker, but stored also here for convenience.
msg: MessageBuilder
# Currently unused
plugin: Plugin
# The expression being matched against the pattern
subject: Expression
subject_type: Type
# Type of the subject to check the (sub)pattern against
type_context: list[Type]
# Types that match against self instead of their __match_args__ if used as a class pattern
# Filled in from self_match_type_names
self_match_types: list[Type]
# Types that are sequences, but don't match sequence patterns. Filled in from
# non_sequence_match_type_names
non_sequence_match_types: list[Type]
options: Options
def __init__(
self, chk: mypy.checker.TypeChecker, msg: MessageBuilder, plugin: Plugin, options: Options
) -> None:
self.chk = chk
self.msg = msg
self.plugin = plugin
self.type_context = []
self.self_match_types = self.generate_types_from_names(self_match_type_names)
self.non_sequence_match_types = self.generate_types_from_names(
non_sequence_match_type_names
)
self.options = options
def accept(self, o: Pattern, type_context: Type) -> PatternType:
self.type_context.append(type_context)
result = o.accept(self)
self.type_context.pop()
return result
def visit_as_pattern(self, o: AsPattern) -> PatternType:
current_type = self.type_context[-1]
if o.pattern is not None:
pattern_type = self.accept(o.pattern, current_type)
typ, rest_type, type_map = pattern_type
else:
typ, rest_type, type_map = current_type, UninhabitedType(), {}
if not is_uninhabited(typ) and o.name is not None:
typ, _ = self.chk.conditional_types_with_intersection(
current_type, [get_type_range(typ)], o, default=current_type
)
if not is_uninhabited(typ):
type_map[o.name] = typ
return PatternType(typ, rest_type, type_map)
def visit_or_pattern(self, o: OrPattern) -> PatternType:
current_type = self.type_context[-1]
#
# Check all the subpatterns
#
pattern_types = []
for pattern in o.patterns:
pattern_type = self.accept(pattern, current_type)
pattern_types.append(pattern_type)
current_type = pattern_type.rest_type
#
# Collect the final type
#
types = []
for pattern_type in pattern_types:
if not is_uninhabited(pattern_type.type):
types.append(pattern_type.type)
#
# Check the capture types
#
capture_types: dict[Var, list[tuple[Expression, Type]]] = defaultdict(list)
# Collect captures from the first subpattern
for expr, typ in pattern_types[0].captures.items():
node = get_var(expr)
capture_types[node].append((expr, typ))
# Check if other subpatterns capture the same names
for i, pattern_type in enumerate(pattern_types[1:]):
vars = {get_var(expr) for expr, _ in pattern_type.captures.items()}
if capture_types.keys() != vars:
self.msg.fail(message_registry.OR_PATTERN_ALTERNATIVE_NAMES, o.patterns[i])
for expr, typ in pattern_type.captures.items():
node = get_var(expr)
capture_types[node].append((expr, typ))
captures: dict[Expression, Type] = {}
for var, capture_list in capture_types.items():
typ = UninhabitedType()
for _, other in capture_list:
typ = join_types(typ, other)
captures[capture_list[0][0]] = typ
union_type = make_simplified_union(types)
return PatternType(union_type, current_type, captures)
def visit_value_pattern(self, o: ValuePattern) -> PatternType:
current_type = self.type_context[-1]
typ = self.chk.expr_checker.accept(o.expr)
typ = coerce_to_literal(typ)
narrowed_type, rest_type = self.chk.conditional_types_with_intersection(
current_type, [get_type_range(typ)], o, default=current_type
)
if not isinstance(get_proper_type(narrowed_type), (LiteralType, UninhabitedType)):
return PatternType(narrowed_type, UnionType.make_union([narrowed_type, rest_type]), {})
return PatternType(narrowed_type, rest_type, {})
def visit_singleton_pattern(self, o: SingletonPattern) -> PatternType:
current_type = self.type_context[-1]
value: bool | None = o.value
if isinstance(value, bool):
typ = self.chk.expr_checker.infer_literal_expr_type(value, "builtins.bool")
elif value is None:
typ = NoneType()
else:
assert False
narrowed_type, rest_type = self.chk.conditional_types_with_intersection(
current_type, [get_type_range(typ)], o, default=current_type
)
return PatternType(narrowed_type, rest_type, {})
def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
#
# check for existence of a starred pattern
#
current_type = get_proper_type(self.type_context[-1])
if not self.can_match_sequence(current_type):
return self.early_non_match()
star_positions = [i for i, p in enumerate(o.patterns) if isinstance(p, StarredPattern)]
star_position: int | None = None
if len(star_positions) == 1:
star_position = star_positions[0]
elif len(star_positions) >= 2:
assert False, "Parser should prevent multiple starred patterns"
required_patterns = len(o.patterns)
if star_position is not None:
required_patterns -= 1
#
# get inner types of original type
#
if isinstance(current_type, TupleType):
inner_types = current_type.items
size_diff = len(inner_types) - required_patterns
if size_diff < 0:
return self.early_non_match()
elif size_diff > 0 and star_position is None:
return self.early_non_match()
else:
inner_type = self.get_sequence_type(current_type, o)
if inner_type is None:
inner_type = self.chk.named_type("builtins.object")
inner_types = [inner_type] * len(o.patterns)
#
# match inner patterns
#
contracted_new_inner_types: list[Type] = []
contracted_rest_inner_types: list[Type] = []
captures: dict[Expression, Type] = {}
contracted_inner_types = self.contract_starred_pattern_types(
inner_types, star_position, required_patterns
)
for p, t in zip(o.patterns, contracted_inner_types):
pattern_type = self.accept(p, t)
typ, rest, type_map = pattern_type
contracted_new_inner_types.append(typ)
contracted_rest_inner_types.append(rest)
self.update_type_map(captures, type_map)
new_inner_types = self.expand_starred_pattern_types(
contracted_new_inner_types, star_position, len(inner_types)
)
rest_inner_types = self.expand_starred_pattern_types(
contracted_rest_inner_types, star_position, len(inner_types)
)
#
# Calculate new type
#
new_type: Type
rest_type: Type = current_type
if isinstance(current_type, TupleType):
narrowed_inner_types = []
inner_rest_types = []
for inner_type, new_inner_type in zip(inner_types, new_inner_types):
(
narrowed_inner_type,
inner_rest_type,
) = self.chk.conditional_types_with_intersection(
new_inner_type, [get_type_range(inner_type)], o, default=new_inner_type
)
narrowed_inner_types.append(narrowed_inner_type)
inner_rest_types.append(inner_rest_type)
if all(not is_uninhabited(typ) for typ in narrowed_inner_types):
new_type = TupleType(narrowed_inner_types, current_type.partial_fallback)
else:
new_type = UninhabitedType()
if all(is_uninhabited(typ) for typ in inner_rest_types):
# All subpatterns always match, so we can apply negative narrowing
rest_type = TupleType(rest_inner_types, current_type.partial_fallback)
else:
new_inner_type = UninhabitedType()
for typ in new_inner_types:
new_inner_type = join_types(new_inner_type, typ)
new_type = self.construct_sequence_child(current_type, new_inner_type)
if is_subtype(new_type, current_type):
new_type, _ = self.chk.conditional_types_with_intersection(
current_type, [get_type_range(new_type)], o, default=current_type
)
else:
new_type = current_type
return PatternType(new_type, rest_type, captures)
def get_sequence_type(self, t: Type, context: Context) -> Type | None:
t = get_proper_type(t)
if isinstance(t, AnyType):
return AnyType(TypeOfAny.from_another_any, t)
if isinstance(t, UnionType):
items = [self.get_sequence_type(item, context) for item in t.items]
not_none_items = [item for item in items if item is not None]
if not_none_items:
return make_simplified_union(not_none_items)
else:
return None
if self.chk.type_is_iterable(t) and isinstance(t, (Instance, TupleType)):
if isinstance(t, TupleType):
t = tuple_fallback(t)
return self.chk.iterable_item_type(t, context)
else:
return None
def contract_starred_pattern_types(
self, types: list[Type], star_pos: int | None, num_patterns: int
) -> list[Type]:
"""
Contracts a list of types in a sequence pattern depending on the position of a starred
capture pattern.
For example if the sequence pattern [a, *b, c] is matched against types [bool, int, str,
bytes] the contracted types are [bool, Union[int, str], bytes].
        If star_pos is None the types are returned unchanged.
"""
if star_pos is None:
return types
new_types = types[:star_pos]
star_length = len(types) - num_patterns
new_types.append(make_simplified_union(types[star_pos : star_pos + star_length]))
new_types += types[star_pos + star_length :]
return new_types
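    # Hedged illustration of the docstring example, with plain names instead
    # of mypy types: for the pattern [a, *b, c] (star_pos=1, num_patterns=2)
    # matched against [bool, int, str, bytes], star_length = 4 - 2 = 2, so
    # the two middle entries collapse into a single union entry for *b:
    #     [bool, int, str, bytes] -> [bool, Union[int, str], bytes]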
def expand_starred_pattern_types(
self, types: list[Type], star_pos: int | None, num_types: int
) -> list[Type]:
"""Undoes the contraction done by contract_starred_pattern_types.
For example if the sequence pattern is [a, *b, c] and types [bool, int, str] are extended
to length 4 the result is [bool, int, int, str].
"""
if star_pos is None:
return types
new_types = types[:star_pos]
star_length = num_types - len(types) + 1
new_types += [types[star_pos]] * star_length
new_types += types[star_pos + 1 :]
return new_types
def visit_starred_pattern(self, o: StarredPattern) -> PatternType:
captures: dict[Expression, Type] = {}
if o.capture is not None:
list_type = self.chk.named_generic_type("builtins.list", [self.type_context[-1]])
captures[o.capture] = list_type
return PatternType(self.type_context[-1], UninhabitedType(), captures)
def visit_mapping_pattern(self, o: MappingPattern) -> PatternType:
current_type = get_proper_type(self.type_context[-1])
can_match = True
captures: dict[Expression, Type] = {}
for key, value in zip(o.keys, o.values):
inner_type = self.get_mapping_item_type(o, current_type, key)
if inner_type is None:
can_match = False
inner_type = self.chk.named_type("builtins.object")
pattern_type = self.accept(value, inner_type)
if is_uninhabited(pattern_type.type):
can_match = False
else:
self.update_type_map(captures, pattern_type.captures)
if o.rest is not None:
mapping = self.chk.named_type("typing.Mapping")
if is_subtype(current_type, mapping) and isinstance(current_type, Instance):
mapping_inst = map_instance_to_supertype(current_type, mapping.type)
dict_typeinfo = self.chk.lookup_typeinfo("builtins.dict")
rest_type = Instance(dict_typeinfo, mapping_inst.args)
else:
object_type = self.chk.named_type("builtins.object")
rest_type = self.chk.named_generic_type(
"builtins.dict", [object_type, object_type]
)
captures[o.rest] = rest_type
if can_match:
# We can't narrow the type here, as Mapping key is invariant.
new_type = self.type_context[-1]
else:
new_type = UninhabitedType()
return PatternType(new_type, current_type, captures)
def get_mapping_item_type(
self, pattern: MappingPattern, mapping_type: Type, key: Expression
) -> Type | None:
mapping_type = get_proper_type(mapping_type)
if isinstance(mapping_type, TypedDictType):
with self.msg.filter_errors() as local_errors:
result: Type | None = self.chk.expr_checker.visit_typeddict_index_expr(
mapping_type, key
)
has_local_errors = local_errors.has_new_errors()
# If we can't determine the type statically fall back to treating it as a normal
# mapping
if has_local_errors:
with self.msg.filter_errors() as local_errors:
result = self.get_simple_mapping_item_type(pattern, mapping_type, key)
if local_errors.has_new_errors():
result = None
else:
with self.msg.filter_errors():
result = self.get_simple_mapping_item_type(pattern, mapping_type, key)
return result
def get_simple_mapping_item_type(
self, pattern: MappingPattern, mapping_type: Type, key: Expression
) -> Type:
result, _ = self.chk.expr_checker.check_method_call_by_name(
"__getitem__", mapping_type, [key], [ARG_POS], pattern
)
return result
def visit_class_pattern(self, o: ClassPattern) -> PatternType:
current_type = get_proper_type(self.type_context[-1])
#
# Check class type
#
type_info = o.class_ref.node
if type_info is None:
return PatternType(AnyType(TypeOfAny.from_error), AnyType(TypeOfAny.from_error), {})
if isinstance(type_info, TypeAlias) and not type_info.no_args:
self.msg.fail(message_registry.CLASS_PATTERN_GENERIC_TYPE_ALIAS, o)
return self.early_non_match()
if isinstance(type_info, TypeInfo):
any_type = AnyType(TypeOfAny.implementation_artifact)
typ: Type = Instance(type_info, [any_type] * len(type_info.defn.type_vars))
elif isinstance(type_info, TypeAlias):
typ = type_info.target
else:
if isinstance(type_info, Var) and type_info.type is not None:
name = type_info.type.str_with_options(self.options)
else:
name = type_info.name
self.msg.fail(message_registry.CLASS_PATTERN_TYPE_REQUIRED.format(name), o)
return self.early_non_match()
new_type, rest_type = self.chk.conditional_types_with_intersection(
current_type, [get_type_range(typ)], o, default=current_type
)
if is_uninhabited(new_type):
return self.early_non_match()
# TODO: Do I need this?
narrowed_type = narrow_declared_type(current_type, new_type)
#
# Convert positional to keyword patterns
#
keyword_pairs: list[tuple[str | None, Pattern]] = []
match_arg_set: set[str] = set()
captures: dict[Expression, Type] = {}
if len(o.positionals) != 0:
if self.should_self_match(typ):
if len(o.positionals) > 1:
self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o)
pattern_type = self.accept(o.positionals[0], narrowed_type)
if not is_uninhabited(pattern_type.type):
return PatternType(
pattern_type.type,
join_types(rest_type, pattern_type.rest_type),
pattern_type.captures,
)
captures = pattern_type.captures
else:
with self.msg.filter_errors() as local_errors:
match_args_type = analyze_member_access(
"__match_args__",
typ,
o,
False,
False,
False,
self.msg,
original_type=typ,
chk=self.chk,
)
has_local_errors = local_errors.has_new_errors()
if has_local_errors:
self.msg.fail(
message_registry.MISSING_MATCH_ARGS.format(
typ.str_with_options(self.options)
),
o,
)
return self.early_non_match()
proper_match_args_type = get_proper_type(match_args_type)
if isinstance(proper_match_args_type, TupleType):
match_arg_names = get_match_arg_names(proper_match_args_type)
if len(o.positionals) > len(match_arg_names):
self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o)
return self.early_non_match()
else:
match_arg_names = [None] * len(o.positionals)
for arg_name, pos in zip(match_arg_names, o.positionals):
keyword_pairs.append((arg_name, pos))
if arg_name is not None:
match_arg_set.add(arg_name)
#
# Check for duplicate patterns
#
keyword_arg_set = set()
has_duplicates = False
for key, value in zip(o.keyword_keys, o.keyword_values):
keyword_pairs.append((key, value))
if key in match_arg_set:
self.msg.fail(
message_registry.CLASS_PATTERN_KEYWORD_MATCHES_POSITIONAL.format(key), value
)
has_duplicates = True
elif key in keyword_arg_set:
self.msg.fail(
message_registry.CLASS_PATTERN_DUPLICATE_KEYWORD_PATTERN.format(key), value
)
has_duplicates = True
keyword_arg_set.add(key)
if has_duplicates:
return self.early_non_match()
#
# Check keyword patterns
#
can_match = True
for keyword, pattern in keyword_pairs:
key_type: Type | None = None
with self.msg.filter_errors() as local_errors:
if keyword is not None:
key_type = analyze_member_access(
keyword,
narrowed_type,
pattern,
False,
False,
False,
self.msg,
original_type=new_type,
chk=self.chk,
)
else:
key_type = AnyType(TypeOfAny.from_error)
has_local_errors = local_errors.has_new_errors()
if has_local_errors or key_type is None:
key_type = AnyType(TypeOfAny.from_error)
self.msg.fail(
message_registry.CLASS_PATTERN_UNKNOWN_KEYWORD.format(
typ.str_with_options(self.options), keyword
),
pattern,
)
inner_type, inner_rest_type, inner_captures = self.accept(pattern, key_type)
if is_uninhabited(inner_type):
can_match = False
else:
self.update_type_map(captures, inner_captures)
if not is_uninhabited(inner_rest_type):
rest_type = current_type
if not can_match:
new_type = UninhabitedType()
return PatternType(new_type, rest_type, captures)
def should_self_match(self, typ: Type) -> bool:
typ = get_proper_type(typ)
if isinstance(typ, Instance) and typ.type.is_named_tuple:
return False
for other in self.self_match_types:
if is_subtype(typ, other):
return True
return False
def can_match_sequence(self, typ: ProperType) -> bool:
if isinstance(typ, UnionType):
return any(self.can_match_sequence(get_proper_type(item)) for item in typ.items)
for other in self.non_sequence_match_types:
# We have to ignore promotions, as memoryview should match, but bytes,
# which it can be promoted to, shouldn't
if is_subtype(typ, other, ignore_promotions=True):
return False
sequence = self.chk.named_type("typing.Sequence")
# If the static type is more general than sequence the actual type could still match
return is_subtype(typ, sequence) or is_subtype(sequence, typ)
def generate_types_from_names(self, type_names: list[str]) -> list[Type]:
types: list[Type] = []
for name in type_names:
try:
types.append(self.chk.named_type(name))
except KeyError as e:
                # Some built-in types are not defined in all test cases
if not name.startswith("builtins."):
raise e
return types
def update_type_map(
self, original_type_map: dict[Expression, Type], extra_type_map: dict[Expression, Type]
) -> None:
        # Calculating this would not be needed if TypeMap directly used literal hashes
        # instead of expressions, as suggested in the TODO above its definition
already_captured = {literal_hash(expr) for expr in original_type_map}
for expr, typ in extra_type_map.items():
if literal_hash(expr) in already_captured:
node = get_var(expr)
self.msg.fail(
message_registry.MULTIPLE_ASSIGNMENTS_IN_PATTERN.format(node.name), expr
)
else:
original_type_map[expr] = typ
def construct_sequence_child(self, outer_type: Type, inner_type: Type) -> Type:
"""
        If outer_type is a child class of typing.Sequence, returns a new instance of
        outer_type that is a Sequence of inner_type. If outer_type is not a child class
        of typing.Sequence, just returns a Sequence of inner_type.
For example:
construct_sequence_child(List[int], str) = List[str]
TODO: this doesn't make sense. For example if one has class S(Sequence[int], Generic[T])
or class T(Sequence[Tuple[T, T]]), there is no way any of those can map to Sequence[str].
"""
proper_type = get_proper_type(outer_type)
if isinstance(proper_type, UnionType):
types = [
self.construct_sequence_child(item, inner_type)
for item in proper_type.items
if self.can_match_sequence(get_proper_type(item))
]
return make_simplified_union(types)
sequence = self.chk.named_generic_type("typing.Sequence", [inner_type])
if is_subtype(outer_type, self.chk.named_type("typing.Sequence")):
proper_type = get_proper_type(outer_type)
if isinstance(proper_type, TupleType):
proper_type = tuple_fallback(proper_type)
assert isinstance(proper_type, Instance)
empty_type = fill_typevars(proper_type.type)
partial_type = expand_type_by_instance(empty_type, sequence)
return expand_type_by_instance(partial_type, proper_type)
else:
return sequence
def early_non_match(self) -> PatternType:
return PatternType(UninhabitedType(), self.type_context[-1], {})
def get_match_arg_names(typ: TupleType) -> list[str | None]:
args: list[str | None] = []
for item in typ.items:
values = try_getting_str_literals_from_type(item)
if values is None or len(values) != 1:
args.append(None)
else:
args.append(values[0])
return args
def get_var(expr: Expression) -> Var:
"""
    Warning: this is only true for expressions captured by a match statement.
    Don't call it from anywhere else.
"""
assert isinstance(expr, NameExpr)
node = expr.node
assert isinstance(node, Var)
return node
def get_type_range(typ: Type) -> mypy.checker.TypeRange:
typ = get_proper_type(typ)
if (
isinstance(typ, Instance)
and typ.last_known_value
and isinstance(typ.last_known_value.value, bool)
):
typ = typ.last_known_value
return mypy.checker.TypeRange(typ, is_upper_bound=False)
def is_uninhabited(typ: Type) -> bool:
return isinstance(get_proper_type(typ), UninhabitedType)
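
End-to-end, the narrowing this checker computes is visible from user code (a hedged illustration; Python 3.11+ for match statements and typing.reveal_type). Each case narrows the subject, and its rest_type flows into the next case:

```
from typing import Union, reveal_type

def f(subject: Union[int, str, bytes]) -> None:
    match subject:
        case int():
            reveal_type(subject)  # int
        case str():
            reveal_type(subject)  # str
        case rest:
            reveal_type(rest)     # bytes: the rest type of earlier cases
```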

File diff suppressed because it is too large


@@ -0,0 +1,638 @@
from __future__ import annotations
import argparse
import configparser
import glob as fileglob
import os
import re
import sys
from io import StringIO
from mypy.errorcodes import error_codes
if sys.version_info >= (3, 11):
import tomllib
else:
import tomli as tomllib
from typing import (
Any,
Callable,
Dict,
Final,
Iterable,
List,
Mapping,
MutableMapping,
Sequence,
TextIO,
Tuple,
Union,
)
from typing_extensions import TypeAlias as _TypeAlias
from mypy import defaults
from mypy.options import PER_MODULE_OPTIONS, Options
_CONFIG_VALUE_TYPES: _TypeAlias = Union[
str, bool, int, float, Dict[str, str], List[str], Tuple[int, int]
]
_INI_PARSER_CALLABLE: _TypeAlias = Callable[[Any], _CONFIG_VALUE_TYPES]
def parse_version(v: str | float) -> tuple[int, int]:
m = re.match(r"\A(\d)\.(\d+)\Z", str(v))
if not m:
raise argparse.ArgumentTypeError(f"Invalid python version '{v}' (expected format: 'x.y')")
major, minor = int(m.group(1)), int(m.group(2))
if major == 2 and minor == 7:
pass # Error raised elsewhere
elif major == 3:
if minor < defaults.PYTHON3_VERSION_MIN[1]:
msg = "Python 3.{} is not supported (must be {}.{} or higher)".format(
minor, *defaults.PYTHON3_VERSION_MIN
)
if isinstance(v, float):
msg += ". You may need to put quotes around your Python version"
raise argparse.ArgumentTypeError(msg)
else:
raise argparse.ArgumentTypeError(
f"Python major version '{major}' out of range (must be 3)"
)
return major, minor
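# Hedged examples of the behavior above:
#     parse_version("3.11") == (3, 11)
#     parse_version(3.8)    == (3, 8)   # floats are accepted via str()...
#     parse_version(3.10)   # ...but this float renders as "3.1" and is
#                           # rejected, hence the "put quotes around" hint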
def try_split(v: str | Sequence[str], split_regex: str = "[,]") -> list[str]:
"""Split and trim a str or list of str into a list of str"""
if isinstance(v, str):
return [p.strip() for p in re.split(split_regex, v)]
return [p.strip() for p in v]
def validate_codes(codes: list[str]) -> list[str]:
invalid_codes = set(codes) - set(error_codes.keys())
if invalid_codes:
raise argparse.ArgumentTypeError(
f"Invalid error code(s): {', '.join(sorted(invalid_codes))}"
)
return codes
def validate_package_allow_list(allow_list: list[str]) -> list[str]:
for p in allow_list:
msg = f"Invalid allow list entry: {p}"
if "*" in p:
raise argparse.ArgumentTypeError(
f"{msg} (entries are already prefixes so must not contain *)"
)
if "\\" in p or "/" in p:
raise argparse.ArgumentTypeError(
f"{msg} (entries must be packages like foo.bar not directories or files)"
)
return allow_list
def expand_path(path: str) -> str:
"""Expand the user home directory and any environment variables contained within
the provided path.
"""
return os.path.expandvars(os.path.expanduser(path))
def str_or_array_as_list(v: str | Sequence[str]) -> list[str]:
if isinstance(v, str):
return [v.strip()] if v.strip() else []
return [p.strip() for p in v if p.strip()]
def split_and_match_files_list(paths: Sequence[str]) -> list[str]:
"""Take a list of files/directories (with support for globbing through the glob library).
Where a path/glob matches no file, we still include the raw path in the resulting list.
Returns a list of file paths
"""
expanded_paths = []
for path in paths:
path = expand_path(path.strip())
globbed_files = fileglob.glob(path, recursive=True)
if globbed_files:
expanded_paths.extend(globbed_files)
else:
expanded_paths.append(path)
return expanded_paths
def split_and_match_files(paths: str) -> list[str]:
"""Take a string representing a list of files/directories (with support for globbing
through the glob library).
Where a path/glob matches no file, we still include the raw path in the resulting list.
Returns a list of file paths
"""
return split_and_match_files_list(paths.split(","))
def check_follow_imports(choice: str) -> str:
choices = ["normal", "silent", "skip", "error"]
if choice not in choices:
raise argparse.ArgumentTypeError(
"invalid choice '{}' (choose from {})".format(
choice, ", ".join(f"'{x}'" for x in choices)
)
)
return choice
def split_commas(value: str) -> list[str]:
# Use a slightly smarter split: tolerate a trailing comma by dropping
# the resulting empty `""` item at the end.
items = value.split(",")
if items and items[-1] == "":
items.pop(-1)
return items
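# Illustrative behavior, not in the original file:
#   split_commas("a,b,") == ["a", "b"]      # trailing comma tolerated
#   split_commas("a,,b") == ["a", "", "b"]  # only a trailing empty item is dropped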
# For most options, the type of the default value set in options.py is
# sufficient, and we don't have to do anything here. This table
# exists to specify types for values initialized to None or container
# types.
ini_config_types: Final[dict[str, _INI_PARSER_CALLABLE]] = {
"python_version": parse_version,
"custom_typing_module": str,
"custom_typeshed_dir": expand_path,
"mypy_path": lambda s: [expand_path(p.strip()) for p in re.split("[,:]", s)],
"files": split_and_match_files,
"quickstart_file": expand_path,
"junit_xml": expand_path,
"follow_imports": check_follow_imports,
"no_site_packages": bool,
"plugins": lambda s: [p.strip() for p in split_commas(s)],
"always_true": lambda s: [p.strip() for p in split_commas(s)],
"always_false": lambda s: [p.strip() for p in split_commas(s)],
"untyped_calls_exclude": lambda s: validate_package_allow_list(
[p.strip() for p in split_commas(s)]
),
"enable_incomplete_feature": lambda s: [p.strip() for p in split_commas(s)],
"disable_error_code": lambda s: validate_codes([p.strip() for p in split_commas(s)]),
"enable_error_code": lambda s: validate_codes([p.strip() for p in split_commas(s)]),
"package_root": lambda s: [p.strip() for p in split_commas(s)],
"cache_dir": expand_path,
"python_executable": expand_path,
"strict": bool,
"exclude": lambda s: [s.strip()],
"packages": try_split,
"modules": try_split,
}
# Reuse the ini_config_types and overwrite the diff
toml_config_types: Final[dict[str, _INI_PARSER_CALLABLE]] = ini_config_types.copy()
toml_config_types.update(
{
"python_version": parse_version,
"mypy_path": lambda s: [expand_path(p) for p in try_split(s, "[,:]")],
"files": lambda s: split_and_match_files_list(try_split(s)),
"follow_imports": lambda s: check_follow_imports(str(s)),
"plugins": try_split,
"always_true": try_split,
"always_false": try_split,
"untyped_calls_exclude": lambda s: validate_package_allow_list(try_split(s)),
"enable_incomplete_feature": try_split,
"disable_error_code": lambda s: validate_codes(try_split(s)),
"enable_error_code": lambda s: validate_codes(try_split(s)),
"package_root": try_split,
"exclude": str_or_array_as_list,
"packages": try_split,
"modules": try_split,
}
)
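# Illustrative pyproject.toml input handled by the converters above, not in the
# original file:
#   [tool.mypy]
#   python_version = "3.11"
#   files = ["src", "tests"]
#   disable_error_code = ["union-attr"]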
def parse_config_file(
options: Options,
set_strict_flags: Callable[[], None],
filename: str | None,
stdout: TextIO | None = None,
stderr: TextIO | None = None,
) -> None:
"""Parse a config file into an Options object.
Errors are written to stderr but are not fatal.
If filename is None, fall back to default config files.
"""
stdout = stdout or sys.stdout
stderr = stderr or sys.stderr
if filename is not None:
config_files: tuple[str, ...] = (filename,)
else:
config_files_iter: Iterable[str] = map(os.path.expanduser, defaults.CONFIG_FILES)
config_files = tuple(config_files_iter)
config_parser = configparser.RawConfigParser()
for config_file in config_files:
if not os.path.exists(config_file):
continue
try:
if is_toml(config_file):
with open(config_file, "rb") as f:
toml_data = tomllib.load(f)
# Filter down to just mypy relevant toml keys
toml_data = toml_data.get("tool", {})
if "mypy" not in toml_data:
continue
toml_data = {"mypy": toml_data["mypy"]}
parser: MutableMapping[str, Any] = destructure_overrides(toml_data)
config_types = toml_config_types
else:
config_parser.read(config_file)
parser = config_parser
config_types = ini_config_types
except (tomllib.TOMLDecodeError, configparser.Error, ConfigTOMLValueError) as err:
print(f"{config_file}: {err}", file=stderr)
else:
if config_file in defaults.SHARED_CONFIG_FILES and "mypy" not in parser:
continue
file_read = config_file
options.config_file = file_read
break
else:
return
os.environ["MYPY_CONFIG_FILE_DIR"] = os.path.dirname(os.path.abspath(config_file))
if "mypy" not in parser:
if filename or file_read not in defaults.SHARED_CONFIG_FILES:
print(f"{file_read}: No [mypy] section in config file", file=stderr)
else:
section = parser["mypy"]
prefix = f"{file_read}: [mypy]: "
updates, report_dirs = parse_section(
prefix, options, set_strict_flags, section, config_types, stderr
)
for k, v in updates.items():
setattr(options, k, v)
options.report_dirs.update(report_dirs)
for name, section in parser.items():
if name.startswith("mypy-"):
prefix = get_prefix(file_read, name)
updates, report_dirs = parse_section(
prefix, options, set_strict_flags, section, config_types, stderr
)
if report_dirs:
print(
"%sPer-module sections should not specify reports (%s)"
% (prefix, ", ".join(s + "_report" for s in sorted(report_dirs))),
file=stderr,
)
if set(updates) - PER_MODULE_OPTIONS:
print(
"%sPer-module sections should only specify per-module flags (%s)"
% (prefix, ", ".join(sorted(set(updates) - PER_MODULE_OPTIONS))),
file=stderr,
)
updates = {k: v for k, v in updates.items() if k in PER_MODULE_OPTIONS}
globs = name[5:]
for glob in globs.split(","):
# For backwards compatibility, replace (back)slashes with dots.
glob = glob.replace(os.sep, ".")
if os.altsep:
glob = glob.replace(os.altsep, ".")
if any(c in glob for c in "?[]!") or any(
"*" in x and x != "*" for x in glob.split(".")
):
print(
"%sPatterns must be fully-qualified module names, optionally "
"with '*' in some components (e.g spam.*.eggs.*)" % prefix,
file=stderr,
)
else:
options.per_module_options[glob] = updates
def get_prefix(file_read: str, name: str) -> str:
if is_toml(file_read):
module_name_str = 'module = "%s"' % "-".join(name.split("-")[1:])
else:
module_name_str = name
return f"{file_read}: [{module_name_str}]: "
def is_toml(filename: str) -> bool:
return filename.lower().endswith(".toml")
def destructure_overrides(toml_data: dict[str, Any]) -> dict[str, Any]:
"""Take the new [[tool.mypy.overrides]] section array in the pyproject.toml file,
and convert it back to a flatter structure that the existing config_parser can handle.
E.g. the following pyproject.toml file:
[[tool.mypy.overrides]]
module = [
"a.b",
"b.*"
]
disallow_untyped_defs = true
[[tool.mypy.overrides]]
module = 'c'
disallow_untyped_defs = false
Would map to the following config dict that it would have gotten from parsing an equivalent
ini file:
{
"mypy-a.b": {
disallow_untyped_defs = true,
},
"mypy-b.*": {
disallow_untyped_defs = true,
},
"mypy-c": {
disallow_untyped_defs: false,
},
}
"""
if "overrides" not in toml_data["mypy"]:
return toml_data
if not isinstance(toml_data["mypy"]["overrides"], list):
raise ConfigTOMLValueError(
"tool.mypy.overrides sections must be an array. Please make "
"sure you are using double brackets like so: [[tool.mypy.overrides]]"
)
result = toml_data.copy()
for override in result["mypy"]["overrides"]:
if "module" not in override:
raise ConfigTOMLValueError(
"toml config file contains a [[tool.mypy.overrides]] "
"section, but no module to override was specified."
)
if isinstance(override["module"], str):
modules = [override["module"]]
elif isinstance(override["module"], list):
modules = override["module"]
else:
raise ConfigTOMLValueError(
"toml config file contains a [[tool.mypy.overrides]] "
"section with a module value that is not a string or a list of "
"strings"
)
for module in modules:
module_overrides = override.copy()
del module_overrides["module"]
old_config_name = f"mypy-{module}"
if old_config_name not in result:
result[old_config_name] = module_overrides
else:
for new_key, new_value in module_overrides.items():
if (
new_key in result[old_config_name]
and result[old_config_name][new_key] != new_value
):
raise ConfigTOMLValueError(
"toml config file contains "
"[[tool.mypy.overrides]] sections with conflicting "
"values. Module '%s' has two different values for '%s'"
% (module, new_key)
)
result[old_config_name][new_key] = new_value
del result["mypy"]["overrides"]
return result
def parse_section(
prefix: str,
template: Options,
set_strict_flags: Callable[[], None],
section: Mapping[str, Any],
config_types: dict[str, Any],
stderr: TextIO = sys.stderr,
) -> tuple[dict[str, object], dict[str, str]]:
"""Parse one section of a config file.
Returns a dict of option values encountered, and a dict of report directories.
"""
results: dict[str, object] = {}
report_dirs: dict[str, str] = {}
for key in section:
invert = False
options_key = key
if key in config_types:
ct = config_types[key]
else:
dv = None
# We have to keep new_semantic_analyzer in Options
# for plugin compatibility but it is not a valid option anymore.
assert hasattr(template, "new_semantic_analyzer")
if key != "new_semantic_analyzer":
dv = getattr(template, key, None)
if dv is None:
if key.endswith("_report"):
report_type = key[:-7].replace("_", "-")
if report_type in defaults.REPORTER_NAMES:
report_dirs[report_type] = str(section[key])
else:
print(f"{prefix}Unrecognized report type: {key}", file=stderr)
continue
if key.startswith("x_"):
pass # Don't complain about `x_blah` flags
elif key.startswith("no_") and hasattr(template, key[3:]):
options_key = key[3:]
invert = True
elif key.startswith("allow") and hasattr(template, "dis" + key):
options_key = "dis" + key
invert = True
elif key.startswith("disallow") and hasattr(template, key[3:]):
options_key = key[3:]
invert = True
elif key.startswith("show_") and hasattr(template, "hide_" + key[5:]):
options_key = "hide_" + key[5:]
invert = True
elif key == "strict":
pass # Special handling below
else:
print(f"{prefix}Unrecognized option: {key} = {section[key]}", file=stderr)
if invert:
dv = getattr(template, options_key, None)
else:
continue
ct = type(dv)
v: Any = None
try:
if ct is bool:
if isinstance(section, dict):
v = convert_to_boolean(section.get(key))
else:
v = section.getboolean(key) # type: ignore[attr-defined] # Until better stub
if invert:
v = not v
elif callable(ct):
if invert:
print(f"{prefix}Can not invert non-boolean key {options_key}", file=stderr)
continue
try:
v = ct(section.get(key))
except argparse.ArgumentTypeError as err:
print(f"{prefix}{key}: {err}", file=stderr)
continue
else:
print(f"{prefix}Don't know what type {key} should have", file=stderr)
continue
except ValueError as err:
print(f"{prefix}{key}: {err}", file=stderr)
continue
if key == "strict":
if v:
set_strict_flags()
continue
results[options_key] = v
# These two flags act as per-module overrides, so store the empty defaults.
if "disable_error_code" not in results:
results["disable_error_code"] = []
if "enable_error_code" not in results:
results["enable_error_code"] = []
return results, report_dirs
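# Illustrative walk-through of the inversion logic above, not in the original
# file: an ini key "no_warn_no_return = True" is absent from config_types, but
# Options has "warn_no_return", so it resolves to options_key="warn_no_return"
# with invert=True and ultimately yields warn_no_return=False.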
def convert_to_boolean(value: Any | None) -> bool:
"""Return a boolean value translating from other types if necessary."""
if isinstance(value, bool):
return value
if not isinstance(value, str):
value = str(value)
if value.lower() not in configparser.RawConfigParser.BOOLEAN_STATES:
raise ValueError(f"Not a boolean: {value}")
return configparser.RawConfigParser.BOOLEAN_STATES[value.lower()]
def split_directive(s: str) -> tuple[list[str], list[str]]:
"""Split s on commas, except during quoted sections.
Returns the parts and a list of error messages."""
parts = []
cur: list[str] = []
errors = []
i = 0
while i < len(s):
if s[i] == ",":
parts.append("".join(cur).strip())
cur = []
elif s[i] == '"':
i += 1
while i < len(s) and s[i] != '"':
cur.append(s[i])
i += 1
if i == len(s):
errors.append("Unterminated quote in configuration comment")
cur.clear()
else:
cur.append(s[i])
i += 1
if cur:
parts.append("".join(cur).strip())
return parts, errors
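# Illustrative behavior, not in the original file:
#   split_directive('a, "b, c", d') == (["a", "b, c", "d"], [])
#   split_directive('a, "b') == (["a"], ["Unterminated quote in configuration comment"])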
def mypy_comments_to_config_map(line: str, template: Options) -> tuple[dict[str, str], list[str]]:
"""Rewrite the mypy comment syntax into ini file syntax."""
options = {}
entries, errors = split_directive(line)
for entry in entries:
if "=" not in entry:
name = entry
value = None
else:
name, value = (x.strip() for x in entry.split("=", 1))
name = name.replace("-", "_")
if value is None:
value = "True"
options[name] = value
return options, errors
def parse_mypy_comments(
args: list[tuple[int, str]], template: Options
) -> tuple[dict[str, object], list[tuple[int, str]]]:
"""Parse a collection of inline mypy: configuration comments.
Returns a dictionary of options to be applied and a list of error messages
generated.
"""
errors: list[tuple[int, str]] = []
sections = {}
for lineno, line in args:
# In order to easily match the behavior for bools, we abuse configparser.
# Oddly, the only way to get the SectionProxy object with the getboolean
# method is to create a config parser.
parser = configparser.RawConfigParser()
options, parse_errors = mypy_comments_to_config_map(line, template)
parser["dummy"] = options
errors.extend((lineno, x) for x in parse_errors)
stderr = StringIO()
strict_found = False
def set_strict_flags() -> None:
nonlocal strict_found
strict_found = True
new_sections, reports = parse_section(
"", template, set_strict_flags, parser["dummy"], ini_config_types, stderr=stderr
)
errors.extend((lineno, x) for x in stderr.getvalue().strip().split("\n") if x)
if reports:
errors.append((lineno, "Reports not supported in inline configuration"))
if strict_found:
errors.append(
(
lineno,
'Setting "strict" not supported in inline configuration: specify it in '
"a configuration file instead, or set individual inline flags "
'(see "mypy -h" for the list of flags enabled in strict mode)',
)
)
sections.update(new_sections)
return sections, errors
def get_config_module_names(filename: str | None, modules: list[str]) -> str:
if not filename or not modules:
return ""
if not is_toml(filename):
return ", ".join(f"[mypy-{module}]" for module in modules)
return "module = ['%s']" % ("', '".join(sorted(modules)))
class ConfigTOMLValueError(ValueError):
pass
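# A minimal sketch of driving the parser directly; illustrative only, not part
# of the original file. If no mypy.ini is present, parse_config_file simply
# leaves the options untouched.
if __name__ == "__main__":
    opts = Options()
    parse_config_file(opts, set_strict_flags=lambda: None, filename="mypy.ini")
    print(opts.config_file)  # "mypy.ini" when the file was found and parsed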

View File

@@ -0,0 +1,187 @@
"""Constant folding of expressions.
For example, 3 + 5 can be constant folded into 8.
"""
from __future__ import annotations
from typing import Final, Union
from mypy.nodes import (
ComplexExpr,
Expression,
FloatExpr,
IntExpr,
NameExpr,
OpExpr,
StrExpr,
UnaryExpr,
Var,
)
# All possible result types of constant folding
ConstantValue = Union[int, bool, float, complex, str]
CONST_TYPES: Final = (int, bool, float, complex, str)
def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | None:
"""Return the constant value of an expression for supported operations.
Among other things, support int arithmetic and string
concatenation. For example, the expression 3 + 5 has the constant
value 8.
Also bind simple references to final constants defined in the
current module (cur_mod_id). Binding to references is best effort
-- we don't bind references to other modules. Mypyc trusts these
to be correct in compiled modules, so that it can replace a
constant expression (or a reference to one) with the statically
computed value. We don't want to infer constant values based on
stubs, in particular, as these might not match the implementation
(due to version skew, for example).
Return None if unsuccessful.
"""
if isinstance(expr, IntExpr):
return expr.value
if isinstance(expr, StrExpr):
return expr.value
if isinstance(expr, FloatExpr):
return expr.value
if isinstance(expr, ComplexExpr):
return expr.value
elif isinstance(expr, NameExpr):
if expr.name == "True":
return True
elif expr.name == "False":
return False
node = expr.node
if (
isinstance(node, Var)
and node.is_final
and node.fullname.rsplit(".", 1)[0] == cur_mod_id
):
value = node.final_value
if isinstance(value, CONST_TYPES):
return value
elif isinstance(expr, OpExpr):
left = constant_fold_expr(expr.left, cur_mod_id)
right = constant_fold_expr(expr.right, cur_mod_id)
if left is not None and right is not None:
return constant_fold_binary_op(expr.op, left, right)
elif isinstance(expr, UnaryExpr):
value = constant_fold_expr(expr.expr, cur_mod_id)
if value is not None:
return constant_fold_unary_op(expr.op, value)
return None
def constant_fold_binary_op(
op: str, left: ConstantValue, right: ConstantValue
) -> ConstantValue | None:
if isinstance(left, int) and isinstance(right, int):
return constant_fold_binary_int_op(op, left, right)
# Float and mixed int/float arithmetic.
if isinstance(left, float) and isinstance(right, float):
return constant_fold_binary_float_op(op, left, right)
elif isinstance(left, float) and isinstance(right, int):
return constant_fold_binary_float_op(op, left, right)
elif isinstance(left, int) and isinstance(right, float):
return constant_fold_binary_float_op(op, left, right)
# String concatenation and multiplication.
if op == "+" and isinstance(left, str) and isinstance(right, str):
return left + right
elif op == "*" and isinstance(left, str) and isinstance(right, int):
return left * right
elif op == "*" and isinstance(left, int) and isinstance(right, str):
return left * right
# Complex construction.
if op == "+" and isinstance(left, (int, float)) and isinstance(right, complex):
return left + right
elif op == "+" and isinstance(left, complex) and isinstance(right, (int, float)):
return left + right
elif op == "-" and isinstance(left, (int, float)) and isinstance(right, complex):
return left - right
elif op == "-" and isinstance(left, complex) and isinstance(right, (int, float)):
return left - right
return None
def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | float | None:
if op == "+":
return left + right
if op == "-":
return left - right
elif op == "*":
return left * right
elif op == "/":
if right != 0:
return left / right
elif op == "//":
if right != 0:
return left // right
elif op == "%":
if right != 0:
return left % right
elif op == "&":
return left & right
elif op == "|":
return left | right
elif op == "^":
return left ^ right
elif op == "<<":
if right >= 0:
return left << right
elif op == ">>":
if right >= 0:
return left >> right
elif op == "**":
if right >= 0:
ret = left**right
assert isinstance(ret, int)
return ret
return None
def constant_fold_binary_float_op(op: str, left: int | float, right: int | float) -> float | None:
assert not (isinstance(left, int) and isinstance(right, int)), (op, left, right)
if op == "+":
return left + right
elif op == "-":
return left - right
elif op == "*":
return left * right
elif op == "/":
if right != 0:
return left / right
elif op == "//":
if right != 0:
return left // right
elif op == "%":
if right != 0:
return left % right
elif op == "**":
if (left < 0 and isinstance(right, int)) or left > 0:
try:
ret = left**right
except OverflowError:
return None
else:
assert isinstance(ret, float), ret
return ret
return None
def constant_fold_unary_op(op: str, value: ConstantValue) -> int | float | None:
if op == "-" and isinstance(value, (int, float)):
return -value
elif op == "~" and isinstance(value, int):
return ~value
elif op == "+" and isinstance(value, (int, float)):
return value
return None
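# A tiny self-check of the helpers above; illustrative only, not part of the
# original file.
if __name__ == "__main__":
    assert constant_fold_binary_op("+", 3, 5) == 8
    assert constant_fold_binary_op("*", "ab", 3) == "ababab"
    assert constant_fold_unary_op("~", 5) == -6
    print("constant folding examples hold")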

File diff suppressed because it is too large

View File

@@ -0,0 +1,133 @@
from __future__ import annotations
from typing import Any, cast
from mypy.types import (
AnyType,
CallableType,
DeletedType,
ErasedType,
Instance,
LiteralType,
NoneType,
Overloaded,
Parameters,
ParamSpecType,
PartialType,
ProperType,
TupleType,
TypeAliasType,
TypedDictType,
TypeType,
TypeVarTupleType,
TypeVarType,
UnboundType,
UninhabitedType,
UnionType,
UnpackType,
)
# type_visitor needs to be imported after types
from mypy.type_visitor import TypeVisitor # ruff: isort: skip
def copy_type(t: ProperType) -> ProperType:
"""Create a shallow copy of a type.
This can be used to mutate the copy with truthiness information.
Classes compiled with mypyc don't support copy.copy(), so we need
a custom implementation.
"""
return t.accept(TypeShallowCopier())
class TypeShallowCopier(TypeVisitor[ProperType]):
def visit_unbound_type(self, t: UnboundType) -> ProperType:
return t
def visit_any(self, t: AnyType) -> ProperType:
return self.copy_common(t, AnyType(t.type_of_any, t.source_any, t.missing_import_name))
def visit_none_type(self, t: NoneType) -> ProperType:
return self.copy_common(t, NoneType())
def visit_uninhabited_type(self, t: UninhabitedType) -> ProperType:
dup = UninhabitedType(t.is_noreturn)
dup.ambiguous = t.ambiguous
return self.copy_common(t, dup)
def visit_erased_type(self, t: ErasedType) -> ProperType:
return self.copy_common(t, ErasedType())
def visit_deleted_type(self, t: DeletedType) -> ProperType:
return self.copy_common(t, DeletedType(t.source))
def visit_instance(self, t: Instance) -> ProperType:
dup = Instance(t.type, t.args, last_known_value=t.last_known_value)
dup.invalid = t.invalid
return self.copy_common(t, dup)
def visit_type_var(self, t: TypeVarType) -> ProperType:
return self.copy_common(t, t.copy_modified())
def visit_param_spec(self, t: ParamSpecType) -> ProperType:
dup = ParamSpecType(
t.name, t.fullname, t.id, t.flavor, t.upper_bound, t.default, prefix=t.prefix
)
return self.copy_common(t, dup)
def visit_parameters(self, t: Parameters) -> ProperType:
dup = Parameters(
t.arg_types,
t.arg_kinds,
t.arg_names,
variables=t.variables,
is_ellipsis_args=t.is_ellipsis_args,
)
return self.copy_common(t, dup)
def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
dup = TypeVarTupleType(
t.name, t.fullname, t.id, t.upper_bound, t.tuple_fallback, t.default
)
return self.copy_common(t, dup)
def visit_unpack_type(self, t: UnpackType) -> ProperType:
dup = UnpackType(t.type)
return self.copy_common(t, dup)
def visit_partial_type(self, t: PartialType) -> ProperType:
return self.copy_common(t, PartialType(t.type, t.var, t.value_type))
def visit_callable_type(self, t: CallableType) -> ProperType:
return self.copy_common(t, t.copy_modified())
def visit_tuple_type(self, t: TupleType) -> ProperType:
return self.copy_common(t, TupleType(t.items, t.partial_fallback, implicit=t.implicit))
def visit_typeddict_type(self, t: TypedDictType) -> ProperType:
return self.copy_common(t, TypedDictType(t.items, t.required_keys, t.fallback))
def visit_literal_type(self, t: LiteralType) -> ProperType:
return self.copy_common(t, LiteralType(value=t.value, fallback=t.fallback))
def visit_union_type(self, t: UnionType) -> ProperType:
return self.copy_common(t, UnionType(t.items))
def visit_overloaded(self, t: Overloaded) -> ProperType:
return self.copy_common(t, Overloaded(items=t.items))
def visit_type_type(self, t: TypeType) -> ProperType:
# Use cast since the type annotations in TypeType are imprecise.
return self.copy_common(t, TypeType(cast(Any, t.item)))
def visit_type_alias_type(self, t: TypeAliasType) -> ProperType:
assert False, "only ProperTypes supported"
def copy_common(self, t: ProperType, t2: ProperType) -> ProperType:
t2.line = t.line
t2.column = t.column
t2.can_be_false = t.can_be_false
t2.can_be_true = t.can_be_true
return t2
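# Illustrative demo, not part of the original file: copy_type returns a fresh
# object while copy_common carries over position and truthiness flags.
if __name__ == "__main__":
    t = NoneType()
    t.line = 7
    dup = copy_type(t)
    assert dup is not t and dup.line == 7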

View File

@@ -0,0 +1,46 @@
from __future__ import annotations
import os
from typing import Final
# Earliest fully supported Python 3.x version. Used as the default Python
# version in tests. Mypy wheels should be built starting with this version,
# and CI tests should be run on this version (and later versions).
PYTHON3_VERSION: Final = (3, 8)
# Earliest Python 3.x version supported via --python-version 3.x. To run
# mypy, at least version PYTHON3_VERSION is needed.
PYTHON3_VERSION_MIN: Final = (3, 7) # Keep in sync with typeshed's python support
CACHE_DIR: Final = ".mypy_cache"
CONFIG_FILE: Final = ["mypy.ini", ".mypy.ini"]
PYPROJECT_CONFIG_FILES: Final = ["pyproject.toml"]
SHARED_CONFIG_FILES: Final = ["setup.cfg"]
USER_CONFIG_FILES: Final = ["~/.config/mypy/config", "~/.mypy.ini"]
if os.environ.get("XDG_CONFIG_HOME"):
USER_CONFIG_FILES.insert(0, os.path.join(os.environ["XDG_CONFIG_HOME"], "mypy/config"))
CONFIG_FILES: Final = (
CONFIG_FILE + PYPROJECT_CONFIG_FILES + SHARED_CONFIG_FILES + USER_CONFIG_FILES
)
# This must include all reporters defined in mypy.report. This is defined here
# to make reporter names available without importing mypy.report -- this speeds
# up startup.
REPORTER_NAMES: Final = [
"linecount",
"any-exprs",
"linecoverage",
"memory-xml",
"cobertura-xml",
"xml",
"xslt-html",
"xslt-txt",
"html",
"txt",
"lineprecision",
]
# Threshold after which we sometimes filter out most errors to avoid very
# verbose output. The default is to show all errors.
MANY_ERRORS_THRESHOLD: Final = -1

View File

@@ -0,0 +1,6 @@
from __future__ import annotations
from mypy.dmypy.client import console_entry
if __name__ == "__main__":
console_entry()

View File

@@ -0,0 +1,748 @@
"""Client for mypy daemon mode.
This manages a daemon process which keeps useful state in memory
rather than having to read it back from disk on each run.
"""
from __future__ import annotations
import argparse
import base64
import json
import os
import pickle
import sys
import time
import traceback
from typing import Any, Callable, Mapping, NoReturn
from mypy.dmypy_os import alive, kill
from mypy.dmypy_util import DEFAULT_STATUS_FILE, receive
from mypy.ipc import IPCClient, IPCException
from mypy.util import check_python_version, get_terminal_width, should_force_color
from mypy.version import __version__
# Argument parser. Subparsers are tied to action functions by the
# @action(subparse) decorator.
class AugmentedHelpFormatter(argparse.RawDescriptionHelpFormatter):
def __init__(self, prog: str) -> None:
super().__init__(prog=prog, max_help_position=30)
parser = argparse.ArgumentParser(
prog="dmypy", description="Client for mypy daemon mode", fromfile_prefix_chars="@"
)
parser.set_defaults(action=None)
parser.add_argument(
"--status-file", default=DEFAULT_STATUS_FILE, help="status file to retrieve daemon details"
)
parser.add_argument(
"-V",
"--version",
action="version",
version="%(prog)s " + __version__,
help="Show program's version number and exit",
)
subparsers = parser.add_subparsers()
start_parser = p = subparsers.add_parser("start", help="Start daemon")
p.add_argument("--log-file", metavar="FILE", type=str, help="Direct daemon stdout/stderr to FILE")
p.add_argument(
"--timeout", metavar="TIMEOUT", type=int, help="Server shutdown timeout (in seconds)"
)
p.add_argument(
"flags", metavar="FLAG", nargs="*", type=str, help="Regular mypy flags (precede with --)"
)
restart_parser = p = subparsers.add_parser(
"restart", help="Restart daemon (stop or kill followed by start)"
)
p.add_argument("--log-file", metavar="FILE", type=str, help="Direct daemon stdout/stderr to FILE")
p.add_argument(
"--timeout", metavar="TIMEOUT", type=int, help="Server shutdown timeout (in seconds)"
)
p.add_argument(
"flags", metavar="FLAG", nargs="*", type=str, help="Regular mypy flags (precede with --)"
)
status_parser = p = subparsers.add_parser("status", help="Show daemon status")
p.add_argument("-v", "--verbose", action="store_true", help="Print detailed status")
p.add_argument("--fswatcher-dump-file", help="Collect information about the current file state")
stop_parser = p = subparsers.add_parser("stop", help="Stop daemon (asks it politely to go away)")
kill_parser = p = subparsers.add_parser("kill", help="Kill daemon (kills the process)")
check_parser = p = subparsers.add_parser(
"check", formatter_class=AugmentedHelpFormatter, help="Check some files (requires daemon)"
)
p.add_argument("-v", "--verbose", action="store_true", help="Print detailed status")
p.add_argument("-q", "--quiet", action="store_true", help=argparse.SUPPRESS) # Deprecated
p.add_argument("--junit-xml", help="Write junit.xml to the given file")
p.add_argument("--perf-stats-file", help="write performance information to the given file")
p.add_argument("files", metavar="FILE", nargs="+", help="File (or directory) to check")
p.add_argument(
"--export-types",
action="store_true",
help="Store types of all expressions in a shared location (useful for inspections)",
)
run_parser = p = subparsers.add_parser(
"run",
formatter_class=AugmentedHelpFormatter,
help="Check some files, [re]starting daemon if necessary",
)
p.add_argument("-v", "--verbose", action="store_true", help="Print detailed status")
p.add_argument("--junit-xml", help="Write junit.xml to the given file")
p.add_argument("--perf-stats-file", help="write performance information to the given file")
p.add_argument(
"--timeout", metavar="TIMEOUT", type=int, help="Server shutdown timeout (in seconds)"
)
p.add_argument("--log-file", metavar="FILE", type=str, help="Direct daemon stdout/stderr to FILE")
p.add_argument(
"--export-types",
action="store_true",
help="Store types of all expressions in a shared location (useful for inspections)",
)
p.add_argument(
"flags",
metavar="ARG",
nargs="*",
type=str,
help="Regular mypy flags and files (precede with --)",
)
recheck_parser = p = subparsers.add_parser(
"recheck",
formatter_class=AugmentedHelpFormatter,
help="Re-check the previous list of files, with optional modifications (requires daemon)",
)
p.add_argument("-v", "--verbose", action="store_true", help="Print detailed status")
p.add_argument("-q", "--quiet", action="store_true", help=argparse.SUPPRESS) # Deprecated
p.add_argument("--junit-xml", help="Write junit.xml to the given file")
p.add_argument("--perf-stats-file", help="write performance information to the given file")
p.add_argument(
"--export-types",
action="store_true",
help="Store types of all expressions in a shared location (useful for inspections)",
)
p.add_argument(
"--update",
metavar="FILE",
nargs="*",
help="Files in the run to add or check again (default: all from previous run)",
)
p.add_argument("--remove", metavar="FILE", nargs="*", help="Files to remove from the run")
suggest_parser = p = subparsers.add_parser(
"suggest", help="Suggest a signature or show call sites for a specific function"
)
p.add_argument(
"function",
metavar="FUNCTION",
type=str,
help="Function specified as '[package.]module.[class.]function'",
)
p.add_argument(
"--json",
action="store_true",
help="Produce json that pyannotate can use to apply a suggestion",
)
p.add_argument(
"--no-errors", action="store_true", help="Only produce suggestions that cause no errors"
)
p.add_argument(
"--no-any", action="store_true", help="Only produce suggestions that don't contain Any"
)
p.add_argument(
"--flex-any",
type=float,
help="Allow anys in types if they go above a certain score (scores are from 0-1)",
)
p.add_argument(
"--callsites", action="store_true", help="Find callsites instead of suggesting a type"
)
p.add_argument(
"--use-fixme",
metavar="NAME",
type=str,
help="A dummy name to use instead of Any for types that can't be inferred",
)
p.add_argument(
"--max-guesses",
type=int,
help="Set the maximum number of types to try for a function (default 64)",
)
inspect_parser = p = subparsers.add_parser(
"inspect", help="Locate and statically inspect expression(s)"
)
p.add_argument(
"location",
metavar="LOCATION",
type=str,
help="Location specified as path/to/file.py:line:column[:end_line:end_column]."
" If position is given (i.e. only line and column), this will return all"
" enclosing expressions",
)
p.add_argument(
"--show",
metavar="INSPECTION",
type=str,
default="type",
choices=["type", "attrs", "definition"],
help="What kind of inspection to run",
)
p.add_argument(
"--verbose",
"-v",
action="count",
default=0,
help="Increase verbosity of the type string representation (can be repeated)",
)
p.add_argument(
"--limit",
metavar="NUM",
type=int,
default=0,
help="Return at most NUM innermost expressions (if position is given); 0 means no limit",
)
p.add_argument(
"--include-span",
action="store_true",
help="Prepend each inspection result with the span of corresponding expression"
' (e.g. 1:2:3:4:"int")',
)
p.add_argument(
"--include-kind",
action="store_true",
help="Prepend each inspection result with the kind of corresponding expression"
' (e.g. NameExpr:"int")',
)
p.add_argument(
"--include-object-attrs",
action="store_true",
help='Include attributes of "object" in "attrs" inspection',
)
p.add_argument(
"--union-attrs",
action="store_true",
help="Include attributes valid for some of possible expression types"
" (by default an intersection is returned)",
)
p.add_argument(
"--force-reload",
action="store_true",
help="Re-parse and re-type-check file before inspection (may be slow)",
)
hang_parser = p = subparsers.add_parser("hang", help="Hang for 100 seconds")
daemon_parser = p = subparsers.add_parser("daemon", help="Run daemon in foreground")
p.add_argument(
"--timeout", metavar="TIMEOUT", type=int, help="Server shutdown timeout (in seconds)"
)
p.add_argument("--log-file", metavar="FILE", type=str, help="Direct daemon stdout/stderr to FILE")
p.add_argument(
"flags", metavar="FLAG", nargs="*", type=str, help="Regular mypy flags (precede with --)"
)
p.add_argument("--options-data", help=argparse.SUPPRESS)
help_parser = p = subparsers.add_parser("help")
del p
class BadStatus(Exception):
"""Exception raised when there is something wrong with the status file.
For example:
- No status file found
- Status file malformed
- Process whose pid is in the status file does not exist
"""
def main(argv: list[str]) -> None:
"""The code is top-down."""
check_python_version("dmypy")
args = parser.parse_args(argv)
if not args.action:
parser.print_usage()
else:
try:
args.action(args)
except BadStatus as err:
fail(err.args[0])
except Exception:
# We do this explicitly to avoid exceptions percolating up
# through mypy.api invocations
traceback.print_exc()
sys.exit(2)
def fail(msg: str) -> NoReturn:
print(msg, file=sys.stderr)
sys.exit(2)
ActionFunction = Callable[[argparse.Namespace], None]
def action(subparser: argparse.ArgumentParser) -> Callable[[ActionFunction], ActionFunction]:
"""Decorator to tie an action function to a subparser."""
def register(func: ActionFunction) -> ActionFunction:
subparser.set_defaults(action=func)
return func
return register
# Action functions (run in client from command line).
@action(start_parser)
def do_start(args: argparse.Namespace) -> None:
"""Start daemon (it must not already be running).
This is where mypy flags are set from the command line.
Setting flags is a bit awkward; you have to use e.g.:
dmypy start -- --strict
since we don't want to duplicate mypy's huge list of flags.
"""
try:
get_status(args.status_file)
except BadStatus:
# Bad or missing status file or dead process; good to start.
pass
else:
fail("Daemon is still alive")
start_server(args)
@action(restart_parser)
def do_restart(args: argparse.Namespace) -> None:
"""Restart daemon (it may or may not be running; but not hanging).
We first try to stop it politely if it's running. This also sets
mypy flags from the command line (see do_start()).
"""
restart_server(args)
def restart_server(args: argparse.Namespace, allow_sources: bool = False) -> None:
"""Restart daemon (it may or may not be running; but not hanging)."""
try:
do_stop(args)
except BadStatus:
# Bad or missing status file or dead process; good to start.
pass
start_server(args, allow_sources)
def start_server(args: argparse.Namespace, allow_sources: bool = False) -> None:
"""Start the server from command arguments and wait for it."""
# Lazy import so this import doesn't slow down other commands.
from mypy.dmypy_server import daemonize, process_start_options
start_options = process_start_options(args.flags, allow_sources)
if daemonize(start_options, args.status_file, timeout=args.timeout, log_file=args.log_file):
sys.exit(2)
wait_for_server(args.status_file)
def wait_for_server(status_file: str, timeout: float = 5.0) -> None:
"""Wait until the server is up.
Exit if it doesn't happen within the timeout.
"""
endtime = time.time() + timeout
while time.time() < endtime:
try:
data = read_status(status_file)
except BadStatus:
# If the file isn't there yet, retry later.
time.sleep(0.1)
continue
# If the file's content is bogus or the process is dead, fail.
check_status(data)
print("Daemon started")
return
fail("Timed out waiting for daemon to start")
@action(run_parser)
def do_run(args: argparse.Namespace) -> None:
"""Do a check, starting (or restarting) the daemon as necessary
Restarts the daemon if the running daemon reports that it is
required (due to a configuration change, for example).
Setting flags is a bit awkward; you have to use e.g.:
dmypy run -- --strict a.py b.py ...
since we don't want to duplicate mypy's huge list of flags.
(The -- is only necessary if flags are specified.)
"""
if not is_running(args.status_file):
# Bad or missing status file or dead process; good to start.
start_server(args, allow_sources=True)
t0 = time.time()
response = request(
args.status_file,
"run",
version=__version__,
args=args.flags,
export_types=args.export_types,
)
# If the daemon signals that a restart is necessary, do it
if "restart" in response:
print(f"Restarting: {response['restart']}")
restart_server(args, allow_sources=True)
response = request(
args.status_file,
"run",
version=__version__,
args=args.flags,
export_types=args.export_types,
)
t1 = time.time()
response["roundtrip_time"] = t1 - t0
check_output(response, args.verbose, args.junit_xml, args.perf_stats_file)
@action(status_parser)
def do_status(args: argparse.Namespace) -> None:
"""Print daemon status.
This verifies that it is responsive to requests.
"""
status = read_status(args.status_file)
if args.verbose:
show_stats(status)
# Both check_status() and request() may raise BadStatus,
# which will be handled by main().
check_status(status)
response = request(
args.status_file, "status", fswatcher_dump_file=args.fswatcher_dump_file, timeout=5
)
if args.verbose or "error" in response:
show_stats(response)
if "error" in response:
fail(f"Daemon is stuck; consider {sys.argv[0]} kill")
print("Daemon is up and running")
@action(stop_parser)
def do_stop(args: argparse.Namespace) -> None:
"""Stop daemon via a 'stop' request."""
# May raise BadStatus, which will be handled by main().
response = request(args.status_file, "stop", timeout=5)
if "error" in response:
show_stats(response)
fail(f"Daemon is stuck; consider {sys.argv[0]} kill")
else:
print("Daemon stopped")
@action(kill_parser)
def do_kill(args: argparse.Namespace) -> None:
"""Kill daemon process with SIGKILL."""
pid, _ = get_status(args.status_file)
try:
kill(pid)
except OSError as err:
fail(str(err))
else:
print("Daemon killed")
@action(check_parser)
def do_check(args: argparse.Namespace) -> None:
"""Ask the daemon to check a list of files."""
t0 = time.time()
response = request(args.status_file, "check", files=args.files, export_types=args.export_types)
t1 = time.time()
response["roundtrip_time"] = t1 - t0
check_output(response, args.verbose, args.junit_xml, args.perf_stats_file)
@action(recheck_parser)
def do_recheck(args: argparse.Namespace) -> None:
"""Ask the daemon to recheck the previous list of files, with optional modifications.
If at least one of --remove or --update is given, the server will
update the list of files to check accordingly and assume that any other files
are unchanged. If none of these flags are given, the server will call stat()
on each file last checked to determine its status.
Files given in --update ought to exist. Files given in --remove need not exist;
if they don't they will be ignored.
The lists may be empty but shouldn't contain duplicates or overlap.
NOTE: The list of files is lost when the daemon is restarted.
"""
t0 = time.time()
if args.remove is not None or args.update is not None:
response = request(
args.status_file,
"recheck",
export_types=args.export_types,
remove=args.remove,
update=args.update,
)
else:
response = request(args.status_file, "recheck", export_types=args.export_types)
t1 = time.time()
response["roundtrip_time"] = t1 - t0
check_output(response, args.verbose, args.junit_xml, args.perf_stats_file)
@action(suggest_parser)
def do_suggest(args: argparse.Namespace) -> None:
"""Ask the daemon for a suggested signature.
This just prints whatever the daemon reports as output.
For now it may be closer to a list of call sites.
"""
response = request(
args.status_file,
"suggest",
function=args.function,
json=args.json,
callsites=args.callsites,
no_errors=args.no_errors,
no_any=args.no_any,
flex_any=args.flex_any,
use_fixme=args.use_fixme,
max_guesses=args.max_guesses,
)
check_output(response, verbose=False, junit_xml=None, perf_stats_file=None)
@action(inspect_parser)
def do_inspect(args: argparse.Namespace) -> None:
"""Ask daemon to print the type of an expression."""
response = request(
args.status_file,
"inspect",
show=args.show,
location=args.location,
verbosity=args.verbose,
limit=args.limit,
include_span=args.include_span,
include_kind=args.include_kind,
include_object_attrs=args.include_object_attrs,
union_attrs=args.union_attrs,
force_reload=args.force_reload,
)
check_output(response, verbose=False, junit_xml=None, perf_stats_file=None)
def check_output(
response: dict[str, Any], verbose: bool, junit_xml: str | None, perf_stats_file: str | None
) -> None:
"""Print the output from a check or recheck command.
Call sys.exit() unless the status code is zero.
"""
if "error" in response:
fail(response["error"])
try:
out, err, status_code = response["out"], response["err"], response["status"]
except KeyError:
fail(f"Response: {str(response)}")
sys.stdout.write(out)
sys.stdout.flush()
sys.stderr.write(err)
sys.stderr.flush()
if verbose:
show_stats(response)
if junit_xml:
# Lazy import so this import doesn't slow things down when not writing junit
from mypy.util import write_junit_xml
messages = (out + err).splitlines()
write_junit_xml(
response["roundtrip_time"],
bool(err),
messages,
junit_xml,
response["python_version"],
response["platform"],
)
if perf_stats_file:
telemetry = response.get("stats", {})
with open(perf_stats_file, "w") as f:
json.dump(telemetry, f)
if status_code:
sys.exit(status_code)
def show_stats(response: Mapping[str, object]) -> None:
for key, value in sorted(response.items()):
if key in ("out", "err", "stdout", "stderr"):
# Special case text output to display just 40 characters of text
value = repr(value)[1:-1]
if len(value) > 50:
value = f"{value[:40]} ... {len(value)-40} more characters"
print("%-24s: %s" % (key, value))
continue
print("%-24s: %10s" % (key, "%.3f" % value if isinstance(value, float) else value))
@action(hang_parser)
def do_hang(args: argparse.Namespace) -> None:
"""Hang for 100 seconds, as a debug hack."""
print(request(args.status_file, "hang", timeout=1))
@action(daemon_parser)
def do_daemon(args: argparse.Namespace) -> None:
"""Serve requests in the foreground."""
# Lazy import so this import doesn't slow down other commands.
from mypy.dmypy_server import Server, process_start_options
if args.log_file:
sys.stdout = sys.stderr = open(args.log_file, "a", buffering=1)
fd = sys.stdout.fileno()
os.dup2(fd, 2)
os.dup2(fd, 1)
if args.options_data:
from mypy.options import Options
options_dict = pickle.loads(base64.b64decode(args.options_data))
options_obj = Options()
options = options_obj.apply_changes(options_dict)
else:
options = process_start_options(args.flags, allow_sources=False)
Server(options, args.status_file, timeout=args.timeout).serve()
@action(help_parser)
def do_help(args: argparse.Namespace) -> None:
"""Print full help (same as dmypy --help)."""
parser.print_help()
# Client-side infrastructure.
def request(
status_file: str, command: str, *, timeout: int | None = None, **kwds: object
) -> dict[str, Any]:
"""Send a request to the daemon.
Return the JSON dict with the response.
Raise BadStatus if there is something wrong with the status file
or if the process whose pid is in the status file has died.
Return {'error': <message>} if an IPC operation or receive()
raised OSError. This covers cases such as connection refused or
closed prematurely as well as invalid JSON received.
"""
response: dict[str, str] = {}
args = dict(kwds)
args["command"] = command
# Tell the server whether this request was initiated from a human-facing terminal,
# so that it can format the type checking output accordingly.
args["is_tty"] = sys.stdout.isatty() or should_force_color()
args["terminal_width"] = get_terminal_width()
bdata = json.dumps(args).encode("utf8")
_, name = get_status(status_file)
try:
with IPCClient(name, timeout) as client:
client.write(bdata)
response = receive(client)
except (OSError, IPCException) as err:
return {"error": str(err)}
# TODO: Other errors, e.g. ValueError, UnicodeError
else:
# Display debugging output written to stdout/stderr in the server process for convenience.
# This should not be confused with "out" and "err" fields in the response.
# Those fields hold the output of the "check" command, and are handled in check_output().
stdout = response.get("stdout")
if stdout:
sys.stdout.write(stdout)
stderr = response.get("stderr")
if stderr:
print("-" * 79)
print("stderr:")
sys.stdout.write(stderr)
return response
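# Illustrative payload, not in the original file: for `dmypy check foo.py` the
# JSON written to the daemon looks roughly like
#   {"command": "check", "files": ["foo.py"], "export_types": false,
#    "is_tty": true, "terminal_width": 80}
# with the exact is_tty/terminal_width values depending on the environment.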
def get_status(status_file: str) -> tuple[int, str]:
"""Read status file and check if the process is alive.
Return (pid, connection_name) on success.
Raise BadStatus if something's wrong.
"""
data = read_status(status_file)
return check_status(data)
def check_status(data: dict[str, Any]) -> tuple[int, str]:
"""Check if the process is alive.
Return (pid, connection_name) on success.
Raise BadStatus if something's wrong.
"""
if "pid" not in data:
raise BadStatus("Invalid status file (no pid field)")
pid = data["pid"]
if not isinstance(pid, int):
raise BadStatus("pid field is not an int")
if not alive(pid):
raise BadStatus("Daemon has died")
if "connection_name" not in data:
raise BadStatus("Invalid status file (no connection_name field)")
connection_name = data["connection_name"]
if not isinstance(connection_name, str):
raise BadStatus("connection_name field is not a string")
return pid, connection_name
def read_status(status_file: str) -> dict[str, object]:
"""Read status file.
Raise BadStatus if the status file doesn't exist or contains
invalid JSON or the JSON is not a dict.
"""
if not os.path.isfile(status_file):
raise BadStatus("No status file found")
with open(status_file) as f:
try:
data = json.load(f)
except Exception as e:
raise BadStatus("Malformed status file (not JSON)") from e
if not isinstance(data, dict):
raise BadStatus("Invalid status file (not a dict)")
return data
def is_running(status_file: str) -> bool:
"""Check if the server is running cleanly"""
try:
get_status(status_file)
except BadStatus:
return False
return True
# Run main().
def console_entry() -> None:
main(sys.argv[1:])

View File

@@ -0,0 +1,42 @@
from __future__ import annotations
import sys
from typing import Any, Callable
if sys.platform == "win32":
import ctypes
import subprocess
from ctypes.wintypes import DWORD, HANDLE
PROCESS_QUERY_LIMITED_INFORMATION = ctypes.c_ulong(0x1000)
kernel32 = ctypes.windll.kernel32
OpenProcess: Callable[[DWORD, int, int], HANDLE] = kernel32.OpenProcess
GetExitCodeProcess: Callable[[HANDLE, Any], int] = kernel32.GetExitCodeProcess
else:
import os
import signal
def alive(pid: int) -> bool:
"""Is the process alive?"""
if sys.platform == "win32":
# why can't anything be easy...
status = DWORD()
handle = OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, 0, pid)
GetExitCodeProcess(handle, ctypes.byref(status))
return status.value == 259 # STILL_ACTIVE
else:
try:
os.kill(pid, 0)
except OSError:
return False
return True
def kill(pid: int) -> None:
"""Kill the process."""
if sys.platform == "win32":
subprocess.check_output(f"taskkill /pid {pid} /f /t")
else:
os.kill(pid, signal.SIGKILL)
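# Illustrative demo, not part of the original file: the current process is
# always alive; kill() would terminate it outright, so it is not called here.
if __name__ == "__main__":
    import os as _os
    print(alive(_os.getpid()))  # True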

File diff suppressed because it is too large

View File

@@ -0,0 +1,31 @@
"""Shared code between dmypy.py and dmypy_server.py.
This should be pretty lightweight and not depend on other mypy code (other than ipc).
"""
from __future__ import annotations
import json
from typing import Any, Final
from mypy.ipc import IPCBase
DEFAULT_STATUS_FILE: Final = ".dmypy.json"
def receive(connection: IPCBase) -> Any:
"""Receive JSON data from a connection until EOF.
Raise OSError if the data received is not valid JSON or if it is
not a dict.
"""
bdata = connection.read()
if not bdata:
raise OSError("No data received")
try:
data = json.loads(bdata.decode("utf8"))
except Exception as e:
raise OSError("Data received is not valid JSON") from e
if not isinstance(data, dict):
raise OSError(f"Data received is not a dict ({type(data)})")
return data

View File

@@ -0,0 +1,233 @@
from __future__ import annotations
from typing import Callable, Container, cast
from mypy.nodes import ARG_STAR, ARG_STAR2
from mypy.types import (
AnyType,
CallableType,
DeletedType,
ErasedType,
Instance,
LiteralType,
NoneType,
Overloaded,
Parameters,
ParamSpecType,
PartialType,
ProperType,
TupleType,
Type,
TypeAliasType,
TypedDictType,
TypeOfAny,
TypeTranslator,
TypeType,
TypeVarId,
TypeVarTupleType,
TypeVarType,
TypeVisitor,
UnboundType,
UninhabitedType,
UnionType,
UnpackType,
get_proper_type,
get_proper_types,
)
def erase_type(typ: Type) -> ProperType:
"""Erase any type variables from a type.
Also replace tuple types with the corresponding concrete types.
Examples:
A -> A
B[X] -> B[Any]
Tuple[A, B] -> tuple
Callable[[A1, A2, ...], R] -> Callable[..., Any]
Type[X] -> Type[Any]
"""
typ = get_proper_type(typ)
return typ.accept(EraseTypeVisitor())
class EraseTypeVisitor(TypeVisitor[ProperType]):
def visit_unbound_type(self, t: UnboundType) -> ProperType:
# TODO: replace with an assert after UnboundType can't leak from semantic analysis.
return AnyType(TypeOfAny.from_error)
def visit_any(self, t: AnyType) -> ProperType:
return t
def visit_none_type(self, t: NoneType) -> ProperType:
return t
def visit_uninhabited_type(self, t: UninhabitedType) -> ProperType:
return t
def visit_erased_type(self, t: ErasedType) -> ProperType:
return t
def visit_partial_type(self, t: PartialType) -> ProperType:
# Should not get here.
raise RuntimeError("Cannot erase partial types")
def visit_deleted_type(self, t: DeletedType) -> ProperType:
return t
def visit_instance(self, t: Instance) -> ProperType:
return Instance(t.type, [AnyType(TypeOfAny.special_form)] * len(t.args), t.line)
def visit_type_var(self, t: TypeVarType) -> ProperType:
return AnyType(TypeOfAny.special_form)
def visit_param_spec(self, t: ParamSpecType) -> ProperType:
return AnyType(TypeOfAny.special_form)
def visit_parameters(self, t: Parameters) -> ProperType:
raise RuntimeError("Parameters should have been bound to a class")
def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
return AnyType(TypeOfAny.special_form)
def visit_unpack_type(self, t: UnpackType) -> ProperType:
return AnyType(TypeOfAny.special_form)
def visit_callable_type(self, t: CallableType) -> ProperType:
# We must preserve the fallback type for overload resolution to work.
any_type = AnyType(TypeOfAny.special_form)
return CallableType(
arg_types=[any_type, any_type],
arg_kinds=[ARG_STAR, ARG_STAR2],
arg_names=[None, None],
ret_type=any_type,
fallback=t.fallback,
is_ellipsis_args=True,
implicit=True,
)
def visit_overloaded(self, t: Overloaded) -> ProperType:
return t.fallback.accept(self)
def visit_tuple_type(self, t: TupleType) -> ProperType:
return t.partial_fallback.accept(self)
def visit_typeddict_type(self, t: TypedDictType) -> ProperType:
return t.fallback.accept(self)
def visit_literal_type(self, t: LiteralType) -> ProperType:
# The fallback for literal types should always be either
# something like int or str, or an enum class -- types that
# don't contain any TypeVars. So there's no need to visit it.
return t
def visit_union_type(self, t: UnionType) -> ProperType:
erased_items = [erase_type(item) for item in t.items]
from mypy.typeops import make_simplified_union
return make_simplified_union(erased_items)
def visit_type_type(self, t: TypeType) -> ProperType:
return TypeType.make_normalized(t.item.accept(self), line=t.line)
def visit_type_alias_type(self, t: TypeAliasType) -> ProperType:
raise RuntimeError("Type aliases should be expanded before accepting this visitor")
def erase_typevars(t: Type, ids_to_erase: Container[TypeVarId] | None = None) -> Type:
"""Replace all type variables in a type with any,
or just the ones in the provided collection.
"""
def erase_id(id: TypeVarId) -> bool:
if ids_to_erase is None:
return True
return id in ids_to_erase
return t.accept(TypeVarEraser(erase_id, AnyType(TypeOfAny.special_form)))
def replace_meta_vars(t: Type, target_type: Type) -> Type:
"""Replace unification variables in a type with the target type."""
return t.accept(TypeVarEraser(lambda id: id.is_meta_var(), target_type))
class TypeVarEraser(TypeTranslator):
"""Implementation of type erasure"""
def __init__(self, erase_id: Callable[[TypeVarId], bool], replacement: Type) -> None:
self.erase_id = erase_id
self.replacement = replacement
def visit_type_var(self, t: TypeVarType) -> Type:
if self.erase_id(t.id):
return self.replacement
return t
def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
if self.erase_id(t.id):
return self.replacement
return t
def visit_param_spec(self, t: ParamSpecType) -> Type:
if self.erase_id(t.id):
return self.replacement
return t
def visit_type_alias_type(self, t: TypeAliasType) -> Type:
# Type alias target can't contain bound type variables (not bound by the type
# alias itself), so it is safe to just erase the arguments.
return t.copy_modified(args=[a.accept(self) for a in t.args])
def remove_instance_last_known_values(t: Type) -> Type:
return t.accept(LastKnownValueEraser())
class LastKnownValueEraser(TypeTranslator):
"""Removes the Literal[...] type that may be associated with any
Instance types."""
def visit_instance(self, t: Instance) -> Type:
if not t.last_known_value and not t.args:
return t
return t.copy_modified(args=[a.accept(self) for a in t.args], last_known_value=None)
def visit_type_alias_type(self, t: TypeAliasType) -> Type:
# Type aliases can't contain literal values, because they are
# always constructed as explicit types.
return t
def visit_union_type(self, t: UnionType) -> Type:
new = cast(UnionType, super().visit_union_type(t))
# Erasure can result in many duplicate items; merge them.
# Call make_simplified_union only on lists of instance types
# that all have the same fullname, to avoid simplifying too
# much.
instances = [item for item in new.items if isinstance(get_proper_type(item), Instance)]
# Avoid merge in simple cases such as optional types.
if len(instances) > 1:
instances_by_name: dict[str, list[Instance]] = {}
p_new_items = get_proper_types(new.items)
for p_item in p_new_items:
if isinstance(p_item, Instance) and not p_item.args:
instances_by_name.setdefault(p_item.type.fullname, []).append(p_item)
merged: list[Type] = []
for item in new.items:
orig_item = item
item = get_proper_type(item)
if isinstance(item, Instance) and not item.args:
types = instances_by_name.get(item.type.fullname)
if types is not None:
if len(types) == 1:
merged.append(item)
else:
from mypy.typeops import make_simplified_union
merged.append(make_simplified_union(types))
del instances_by_name[item.type.fullname]
else:
merged.append(orig_item)
return UnionType.make_union(merged)
return new

View File

@@ -0,0 +1,263 @@
"""Classification of possible errors mypy can detect.
These can be used for filtering specific errors.
"""
from __future__ import annotations
from collections import defaultdict
from typing import Final
from mypy_extensions import mypyc_attr
error_codes: dict[str, ErrorCode] = {}
sub_code_map: dict[str, set[str]] = defaultdict(set)
@mypyc_attr(allow_interpreted_subclasses=True)
class ErrorCode:
def __init__(
self,
code: str,
description: str,
category: str,
default_enabled: bool = True,
sub_code_of: ErrorCode | None = None,
) -> None:
self.code = code
self.description = description
self.category = category
self.default_enabled = default_enabled
self.sub_code_of = sub_code_of
if sub_code_of is not None:
assert sub_code_of.sub_code_of is None, "Nested subcategories are not supported"
sub_code_map[sub_code_of.code].add(code)
error_codes[code] = self
def __str__(self) -> str:
return f"<ErrorCode {self.code}>"
def __eq__(self, other: object) -> bool:
if not isinstance(other, ErrorCode):
return False
return self.code == other.code
def __hash__(self) -> int:
return hash((self.code,))
ATTR_DEFINED: Final = ErrorCode("attr-defined", "Check that attribute exists", "General")
NAME_DEFINED: Final = ErrorCode("name-defined", "Check that name is defined", "General")
CALL_ARG: Final[ErrorCode] = ErrorCode(
"call-arg", "Check number, names and kinds of arguments in calls", "General"
)
ARG_TYPE: Final = ErrorCode("arg-type", "Check argument types in calls", "General")
CALL_OVERLOAD: Final = ErrorCode(
"call-overload", "Check that an overload variant matches arguments", "General"
)
VALID_TYPE: Final[ErrorCode] = ErrorCode(
"valid-type", "Check that type (annotation) is valid", "General"
)
VAR_ANNOTATED: Final = ErrorCode(
"var-annotated", "Require variable annotation if type can't be inferred", "General"
)
OVERRIDE: Final = ErrorCode(
"override", "Check that method override is compatible with base class", "General"
)
RETURN: Final[ErrorCode] = ErrorCode(
"return", "Check that function always returns a value", "General"
)
RETURN_VALUE: Final[ErrorCode] = ErrorCode(
"return-value", "Check that return value is compatible with signature", "General"
)
ASSIGNMENT: Final[ErrorCode] = ErrorCode(
"assignment", "Check that assigned value is compatible with target", "General"
)
METHOD_ASSIGN: Final[ErrorCode] = ErrorCode(
"method-assign",
"Check that assignment target is not a method",
"General",
sub_code_of=ASSIGNMENT,
)
TYPE_ARG: Final = ErrorCode("type-arg", "Check that generic type arguments are present", "General")
TYPE_VAR: Final = ErrorCode("type-var", "Check that type variable values are valid", "General")
UNION_ATTR: Final = ErrorCode(
"union-attr", "Check that attribute exists in each item of a union", "General"
)
INDEX: Final = ErrorCode("index", "Check indexing operations", "General")
OPERATOR: Final = ErrorCode("operator", "Check that operator is valid for operands", "General")
LIST_ITEM: Final = ErrorCode(
"list-item", "Check list items in a list expression [item, ...]", "General"
)
DICT_ITEM: Final = ErrorCode(
"dict-item", "Check dict items in a dict expression {key: value, ...}", "General"
)
TYPEDDICT_ITEM: Final = ErrorCode(
"typeddict-item", "Check items when constructing TypedDict", "General"
)
TYPEDDICT_UNKNOWN_KEY: Final = ErrorCode(
"typeddict-unknown-key",
"Check unknown keys when constructing TypedDict",
"General",
sub_code_of=TYPEDDICT_ITEM,
)
HAS_TYPE: Final = ErrorCode(
"has-type", "Check that type of reference can be determined", "General"
)
IMPORT: Final = ErrorCode(
"import", "Require that imported module can be found or has stubs", "General"
)
IMPORT_NOT_FOUND: Final = ErrorCode(
"import-not-found", "Require that imported module can be found", "General", sub_code_of=IMPORT
)
IMPORT_UNTYPED: Final = ErrorCode(
"import-untyped", "Require that imported module has stubs", "General", sub_code_of=IMPORT
)
NO_REDEF: Final = ErrorCode("no-redef", "Check that each name is defined once", "General")
FUNC_RETURNS_VALUE: Final = ErrorCode(
"func-returns-value", "Check that called function returns a value in value context", "General"
)
ABSTRACT: Final = ErrorCode(
"abstract", "Prevent instantiation of classes with abstract attributes", "General"
)
TYPE_ABSTRACT: Final = ErrorCode(
"type-abstract", "Require only concrete classes where Type[...] is expected", "General"
)
VALID_NEWTYPE: Final = ErrorCode(
"valid-newtype", "Check that argument 2 to NewType is valid", "General"
)
STRING_FORMATTING: Final = ErrorCode(
"str-format", "Check that string formatting/interpolation is type-safe", "General"
)
STR_BYTES_PY3: Final = ErrorCode(
"str-bytes-safe", "Warn about implicit coercions related to bytes and string types", "General"
)
EXIT_RETURN: Final = ErrorCode(
"exit-return", "Warn about too general return type for '__exit__'", "General"
)
LITERAL_REQ: Final = ErrorCode("literal-required", "Check that value is a literal", "General")
UNUSED_COROUTINE: Final = ErrorCode(
"unused-coroutine", "Ensure that all coroutines are used", "General"
)
# TODO: why do we need the explicit type here? Without it mypyc CI builds fail with
# mypy/message_registry.py:37: error: Cannot determine type of "EMPTY_BODY" [has-type]
EMPTY_BODY: Final[ErrorCode] = ErrorCode(
"empty-body",
"A dedicated error code to opt out return errors for empty/trivial bodies",
"General",
)
SAFE_SUPER: Final = ErrorCode(
"safe-super", "Warn about calls to abstract methods with empty/trivial bodies", "General"
)
TOP_LEVEL_AWAIT: Final = ErrorCode(
"top-level-await", "Warn about top level await expressions", "General"
)
AWAIT_NOT_ASYNC: Final = ErrorCode(
"await-not-async", 'Warn about "await" outside coroutine ("async def")', "General"
)
# These error codes aren't enabled by default.
NO_UNTYPED_DEF: Final[ErrorCode] = ErrorCode(
"no-untyped-def", "Check that every function has an annotation", "General"
)
NO_UNTYPED_CALL: Final = ErrorCode(
"no-untyped-call",
"Disallow calling functions without type annotations from annotated functions",
"General",
)
REDUNDANT_CAST: Final = ErrorCode(
"redundant-cast", "Check that cast changes type of expression", "General"
)
ASSERT_TYPE: Final = ErrorCode("assert-type", "Check that assert_type() call succeeds", "General")
COMPARISON_OVERLAP: Final = ErrorCode(
"comparison-overlap", "Check that types in comparisons and 'in' expressions overlap", "General"
)
NO_ANY_UNIMPORTED: Final = ErrorCode(
"no-any-unimported", 'Reject "Any" types from unfollowed imports', "General"
)
NO_ANY_RETURN: Final = ErrorCode(
"no-any-return",
'Reject returning value with "Any" type if return type is not "Any"',
"General",
)
UNREACHABLE: Final = ErrorCode(
"unreachable", "Warn about unreachable statements or expressions", "General"
)
ANNOTATION_UNCHECKED = ErrorCode(
"annotation-unchecked", "Notify about type annotations in unchecked functions", "General"
)
POSSIBLY_UNDEFINED: Final[ErrorCode] = ErrorCode(
"possibly-undefined",
"Warn about variables that are defined only in some execution paths",
"General",
default_enabled=False,
)
REDUNDANT_EXPR: Final = ErrorCode(
"redundant-expr", "Warn about redundant expressions", "General", default_enabled=False
)
TRUTHY_BOOL: Final[ErrorCode] = ErrorCode(
"truthy-bool",
"Warn about expressions that could always evaluate to true in boolean contexts",
"General",
default_enabled=False,
)
TRUTHY_FUNCTION: Final[ErrorCode] = ErrorCode(
"truthy-function",
"Warn about function that always evaluate to true in boolean contexts",
"General",
)
TRUTHY_ITERABLE: Final[ErrorCode] = ErrorCode(
"truthy-iterable",
"Warn about Iterable expressions that could always evaluate to true in boolean contexts",
"General",
default_enabled=False,
)
NAME_MATCH: Final = ErrorCode(
"name-match", "Check that type definition has consistent naming", "General"
)
NO_OVERLOAD_IMPL: Final = ErrorCode(
"no-overload-impl",
"Check that overloaded functions outside stub files have an implementation",
"General",
)
IGNORE_WITHOUT_CODE: Final = ErrorCode(
"ignore-without-code",
"Warn about '# type: ignore' comments which do not have error codes",
"General",
default_enabled=False,
)
UNUSED_AWAITABLE: Final = ErrorCode(
"unused-awaitable",
"Ensure that all awaitable values are used",
"General",
default_enabled=False,
)
REDUNDANT_SELF_TYPE = ErrorCode(
"redundant-self",
"Warn about redundant Self type annotations on method first argument",
"General",
default_enabled=False,
)
USED_BEFORE_DEF: Final[ErrorCode] = ErrorCode(
"used-before-def", "Warn about variables that are used before they are defined", "General"
)
UNUSED_IGNORE: Final = ErrorCode(
"unused-ignore", "Ensure that all type ignores are used", "General", default_enabled=False
)
EXPLICIT_OVERRIDE_REQUIRED: Final = ErrorCode(
"explicit-override",
"Require @override decorator if method is overriding a base class method",
"General",
default_enabled=False,
)
# Syntax errors are often blocking.
SYNTAX: Final[ErrorCode] = ErrorCode("syntax", "Report syntax errors", "General")
# This is an internal marker code for a whole-file ignore. It is not intended to
# be user-visible.
FILE: Final = ErrorCode("file", "Internal marker for a whole file being ignored", "General")
del error_codes[FILE.code]
# This is a catch-all for remaining uncategorized errors.
MISC: Final = ErrorCode("misc", "Miscellaneous other checks", "General")
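
The registry above makes codes discoverable at runtime. A minimal sketch of looking up a code and its sub-code relationship, assuming this module is importable as mypy.errorcodes (its path in the mypy distribution):

from mypy import errorcodes

# Look up a code by its string form, as used in "# type: ignore[...]" comments.
code = errorcodes.error_codes["method-assign"]
print(code.code, code.category, code.default_enabled)  # method-assign General True

# Sub-codes are tracked both on the code itself and in sub_code_map.
print(code.sub_code_of)                       # <ErrorCode assignment>
print(errorcodes.sub_code_map["assignment"])  # {'method-assign'}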

File diff suppressed because it is too large

View File

@@ -0,0 +1,204 @@
"""
Evaluate an expression.
Used by stubtest; in a separate file because things break if we don't
put it in a mypyc-compiled file.
"""
import ast
from typing import Final
import mypy.nodes
from mypy.visitor import ExpressionVisitor
UNKNOWN = object()
class _NodeEvaluator(ExpressionVisitor[object]):
def visit_int_expr(self, o: mypy.nodes.IntExpr) -> int:
return o.value
def visit_str_expr(self, o: mypy.nodes.StrExpr) -> str:
return o.value
def visit_bytes_expr(self, o: mypy.nodes.BytesExpr) -> object:
# The value of a BytesExpr is a string created from the repr()
# of the bytes object. Get the original bytes back.
try:
return ast.literal_eval(f"b'{o.value}'")
except SyntaxError:
return ast.literal_eval(f'b"{o.value}"')
def visit_float_expr(self, o: mypy.nodes.FloatExpr) -> float:
return o.value
def visit_complex_expr(self, o: mypy.nodes.ComplexExpr) -> object:
return o.value
def visit_ellipsis(self, o: mypy.nodes.EllipsisExpr) -> object:
return Ellipsis
def visit_star_expr(self, o: mypy.nodes.StarExpr) -> object:
return UNKNOWN
def visit_name_expr(self, o: mypy.nodes.NameExpr) -> object:
if o.name == "True":
return True
elif o.name == "False":
return False
elif o.name == "None":
return None
# TODO: Handle more names by figuring out a way to hook into the
# symbol table.
return UNKNOWN
def visit_member_expr(self, o: mypy.nodes.MemberExpr) -> object:
return UNKNOWN
def visit_yield_from_expr(self, o: mypy.nodes.YieldFromExpr) -> object:
return UNKNOWN
def visit_yield_expr(self, o: mypy.nodes.YieldExpr) -> object:
return UNKNOWN
def visit_call_expr(self, o: mypy.nodes.CallExpr) -> object:
return UNKNOWN
def visit_op_expr(self, o: mypy.nodes.OpExpr) -> object:
return UNKNOWN
def visit_comparison_expr(self, o: mypy.nodes.ComparisonExpr) -> object:
return UNKNOWN
def visit_cast_expr(self, o: mypy.nodes.CastExpr) -> object:
return o.expr.accept(self)
def visit_assert_type_expr(self, o: mypy.nodes.AssertTypeExpr) -> object:
return o.expr.accept(self)
def visit_reveal_expr(self, o: mypy.nodes.RevealExpr) -> object:
return UNKNOWN
def visit_super_expr(self, o: mypy.nodes.SuperExpr) -> object:
return UNKNOWN
def visit_unary_expr(self, o: mypy.nodes.UnaryExpr) -> object:
operand = o.expr.accept(self)
if operand is UNKNOWN:
return UNKNOWN
if o.op == "-":
if isinstance(operand, (int, float, complex)):
return -operand
elif o.op == "+":
if isinstance(operand, (int, float, complex)):
return +operand
elif o.op == "~":
if isinstance(operand, int):
return ~operand
elif o.op == "not":
if isinstance(operand, (bool, int, float, str, bytes)):
return not operand
return UNKNOWN
def visit_assignment_expr(self, o: mypy.nodes.AssignmentExpr) -> object:
return o.value.accept(self)
def visit_list_expr(self, o: mypy.nodes.ListExpr) -> object:
items = [item.accept(self) for item in o.items]
if all(item is not UNKNOWN for item in items):
return items
return UNKNOWN
def visit_dict_expr(self, o: mypy.nodes.DictExpr) -> object:
items = [
(UNKNOWN if key is None else key.accept(self), value.accept(self))
for key, value in o.items
]
        if all(key is not UNKNOWN and value is not UNKNOWN for key, value in items):
return dict(items)
return UNKNOWN
def visit_tuple_expr(self, o: mypy.nodes.TupleExpr) -> object:
items = [item.accept(self) for item in o.items]
if all(item is not UNKNOWN for item in items):
return tuple(items)
return UNKNOWN
def visit_set_expr(self, o: mypy.nodes.SetExpr) -> object:
items = [item.accept(self) for item in o.items]
if all(item is not UNKNOWN for item in items):
return set(items)
return UNKNOWN
def visit_index_expr(self, o: mypy.nodes.IndexExpr) -> object:
return UNKNOWN
def visit_type_application(self, o: mypy.nodes.TypeApplication) -> object:
return UNKNOWN
def visit_lambda_expr(self, o: mypy.nodes.LambdaExpr) -> object:
return UNKNOWN
def visit_list_comprehension(self, o: mypy.nodes.ListComprehension) -> object:
return UNKNOWN
def visit_set_comprehension(self, o: mypy.nodes.SetComprehension) -> object:
return UNKNOWN
def visit_dictionary_comprehension(self, o: mypy.nodes.DictionaryComprehension) -> object:
return UNKNOWN
def visit_generator_expr(self, o: mypy.nodes.GeneratorExpr) -> object:
return UNKNOWN
def visit_slice_expr(self, o: mypy.nodes.SliceExpr) -> object:
return UNKNOWN
def visit_conditional_expr(self, o: mypy.nodes.ConditionalExpr) -> object:
return UNKNOWN
def visit_type_var_expr(self, o: mypy.nodes.TypeVarExpr) -> object:
return UNKNOWN
def visit_paramspec_expr(self, o: mypy.nodes.ParamSpecExpr) -> object:
return UNKNOWN
def visit_type_var_tuple_expr(self, o: mypy.nodes.TypeVarTupleExpr) -> object:
return UNKNOWN
def visit_type_alias_expr(self, o: mypy.nodes.TypeAliasExpr) -> object:
return UNKNOWN
def visit_namedtuple_expr(self, o: mypy.nodes.NamedTupleExpr) -> object:
return UNKNOWN
def visit_enum_call_expr(self, o: mypy.nodes.EnumCallExpr) -> object:
return UNKNOWN
def visit_typeddict_expr(self, o: mypy.nodes.TypedDictExpr) -> object:
return UNKNOWN
def visit_newtype_expr(self, o: mypy.nodes.NewTypeExpr) -> object:
return UNKNOWN
def visit__promote_expr(self, o: mypy.nodes.PromoteExpr) -> object:
return UNKNOWN
def visit_await_expr(self, o: mypy.nodes.AwaitExpr) -> object:
return UNKNOWN
def visit_temp_node(self, o: mypy.nodes.TempNode) -> object:
return UNKNOWN
_evaluator: Final = _NodeEvaluator()
def evaluate_expression(expr: mypy.nodes.Expression) -> object:
"""Evaluate an expression at runtime.
Return the result of the expression, or UNKNOWN if the expression cannot be
evaluated.
"""
return expr.accept(_evaluator)
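
A minimal sketch of driving the evaluator directly, assuming this file lives at mypy/evalexpr.py as in the mypy source tree; expressions that cannot be resolved come back as the UNKNOWN sentinel:

import mypy.nodes
from mypy.evalexpr import UNKNOWN, evaluate_expression

# Literals and simple unary operations evaluate to concrete values.
print(evaluate_expression(mypy.nodes.IntExpr(42)))                             # 42
print(evaluate_expression(mypy.nodes.UnaryExpr("-", mypy.nodes.IntExpr(7))))   # -7

# Names other than True/False/None cannot be resolved here.
print(evaluate_expression(mypy.nodes.NameExpr("sys")) is UNKNOWN)              # True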

View File

@@ -0,0 +1,516 @@
from __future__ import annotations
from typing import Final, Iterable, Mapping, Sequence, TypeVar, cast, overload
from mypy.nodes import ARG_STAR, Var
from mypy.state import state
from mypy.types import (
ANY_STRATEGY,
AnyType,
BoolTypeQuery,
CallableType,
DeletedType,
ErasedType,
FunctionLike,
Instance,
LiteralType,
NoneType,
Overloaded,
Parameters,
ParamSpecFlavor,
ParamSpecType,
PartialType,
ProperType,
TrivialSyntheticTypeTranslator,
TupleType,
Type,
TypeAliasType,
TypedDictType,
TypeType,
TypeVarId,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
UnboundType,
UninhabitedType,
UnionType,
UnpackType,
flatten_nested_unions,
get_proper_type,
split_with_prefix_and_suffix,
)
from mypy.typevartuples import split_with_instance
# Solving the import cycle:
import mypy.type_visitor # ruff: isort: skip
# WARNING: these functions should never (directly or indirectly) depend on
# is_subtype(), meet_types(), join_types() etc.
# TODO: add a static dependency test for this.
@overload
def expand_type(typ: CallableType, env: Mapping[TypeVarId, Type]) -> CallableType:
...
@overload
def expand_type(typ: ProperType, env: Mapping[TypeVarId, Type]) -> ProperType:
...
@overload
def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type:
...
def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type:
"""Substitute any type variable references in a type given by a type
environment.
"""
return typ.accept(ExpandTypeVisitor(env))
@overload
def expand_type_by_instance(typ: CallableType, instance: Instance) -> CallableType:
...
@overload
def expand_type_by_instance(typ: ProperType, instance: Instance) -> ProperType:
...
@overload
def expand_type_by_instance(typ: Type, instance: Instance) -> Type:
...
def expand_type_by_instance(typ: Type, instance: Instance) -> Type:
"""Substitute type variables in type using values from an Instance.
Type variables are considered to be bound by the class declaration."""
if not instance.args:
return typ
else:
variables: dict[TypeVarId, Type] = {}
if instance.type.has_type_var_tuple_type:
assert instance.type.type_var_tuple_prefix is not None
assert instance.type.type_var_tuple_suffix is not None
args_prefix, args_middle, args_suffix = split_with_instance(instance)
tvars_prefix, tvars_middle, tvars_suffix = split_with_prefix_and_suffix(
tuple(instance.type.defn.type_vars),
instance.type.type_var_tuple_prefix,
instance.type.type_var_tuple_suffix,
)
tvar = tvars_middle[0]
assert isinstance(tvar, TypeVarTupleType)
variables = {tvar.id: TupleType(list(args_middle), tvar.tuple_fallback)}
instance_args = args_prefix + args_suffix
tvars = tvars_prefix + tvars_suffix
else:
tvars = tuple(instance.type.defn.type_vars)
instance_args = instance.args
for binder, arg in zip(tvars, instance_args):
assert isinstance(binder, TypeVarLikeType)
variables[binder.id] = arg
return expand_type(typ, variables)
F = TypeVar("F", bound=FunctionLike)
def freshen_function_type_vars(callee: F) -> F:
"""Substitute fresh type variables for generic function type variables."""
if isinstance(callee, CallableType):
if not callee.is_generic():
return cast(F, callee)
tvs = []
tvmap: dict[TypeVarId, Type] = {}
for v in callee.variables:
tv = v.new_unification_variable(v)
tvs.append(tv)
tvmap[v.id] = tv
fresh = expand_type(callee, tvmap).copy_modified(variables=tvs)
return cast(F, fresh)
else:
assert isinstance(callee, Overloaded)
fresh_overload = Overloaded([freshen_function_type_vars(item) for item in callee.items])
return cast(F, fresh_overload)
class HasGenericCallable(BoolTypeQuery):
def __init__(self) -> None:
super().__init__(ANY_STRATEGY)
def visit_callable_type(self, t: CallableType) -> bool:
return t.is_generic() or super().visit_callable_type(t)
# Share a singleton since this is performance sensitive
has_generic_callable: Final = HasGenericCallable()
T = TypeVar("T", bound=Type)
def freshen_all_functions_type_vars(t: T) -> T:
result: Type
has_generic_callable.reset()
if not t.accept(has_generic_callable):
return t # Fast path to avoid expensive freshening
else:
result = t.accept(FreshenCallableVisitor())
assert isinstance(result, type(t))
return result
class FreshenCallableVisitor(mypy.type_visitor.TypeTranslator):
def visit_callable_type(self, t: CallableType) -> Type:
result = super().visit_callable_type(t)
assert isinstance(result, ProperType) and isinstance(result, CallableType)
return freshen_function_type_vars(result)
def visit_type_alias_type(self, t: TypeAliasType) -> Type:
# Same as for ExpandTypeVisitor
return t.copy_modified(args=[arg.accept(self) for arg in t.args])
class ExpandTypeVisitor(TrivialSyntheticTypeTranslator):
"""Visitor that substitutes type variables with values."""
variables: Mapping[TypeVarId, Type] # TypeVar id -> TypeVar value
def __init__(self, variables: Mapping[TypeVarId, Type]) -> None:
self.variables = variables
def visit_unbound_type(self, t: UnboundType) -> Type:
return t
def visit_any(self, t: AnyType) -> Type:
return t
def visit_none_type(self, t: NoneType) -> Type:
return t
def visit_uninhabited_type(self, t: UninhabitedType) -> Type:
return t
def visit_deleted_type(self, t: DeletedType) -> Type:
return t
def visit_erased_type(self, t: ErasedType) -> Type:
        # This may happen during type inference if some function argument type
        # is a generic callable: its erased form will appear in inferred
        # constraints, and the solver may then check subtyping between them,
        # which triggers unify_generic_callables(). That is why we can get here.
        # Another example is inferring the type of a lambda in a generic
        # context, where the lambda body contains a generic method of a
        # generic class.
return t
def visit_instance(self, t: Instance) -> Type:
args = self.expand_types_with_unpack(list(t.args))
if isinstance(args, list):
return t.copy_modified(args=args)
else:
return args
def visit_type_var(self, t: TypeVarType) -> Type:
# Normally upper bounds can't contain other type variables, the only exception is
# special type variable Self`0 <: C[T, S], where C is the class where Self is used.
if t.id.raw_id == 0:
t = t.copy_modified(upper_bound=t.upper_bound.accept(self))
repl = self.variables.get(t.id, t)
if isinstance(repl, ProperType) and isinstance(repl, Instance):
# TODO: do we really need to do this?
# If I try to remove this special-casing ~40 tests fail on reveal_type().
return repl.copy_modified(last_known_value=None)
return repl
def visit_param_spec(self, t: ParamSpecType) -> Type:
# Set prefix to something empty, so we don't duplicate it below.
repl = self.variables.get(t.id, t.copy_modified(prefix=Parameters([], [], [])))
if isinstance(repl, ParamSpecType):
return repl.copy_modified(
flavor=t.flavor,
prefix=t.prefix.copy_modified(
arg_types=self.expand_types(t.prefix.arg_types) + repl.prefix.arg_types,
arg_kinds=t.prefix.arg_kinds + repl.prefix.arg_kinds,
arg_names=t.prefix.arg_names + repl.prefix.arg_names,
),
)
elif isinstance(repl, Parameters):
assert t.flavor == ParamSpecFlavor.BARE
return Parameters(
self.expand_types(t.prefix.arg_types) + repl.arg_types,
t.prefix.arg_kinds + repl.arg_kinds,
t.prefix.arg_names + repl.arg_names,
variables=[*t.prefix.variables, *repl.variables],
)
else:
# TODO: replace this with "assert False"
return repl
def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
# Sometimes solver may need to expand a type variable with (a copy of) itself
# (usually together with other TypeVars, but it is hard to filter out TypeVarTuples).
repl = self.variables.get(t.id, t)
if isinstance(repl, TypeVarTupleType):
return repl
raise NotImplementedError
def visit_unpack_type(self, t: UnpackType) -> Type:
# It is impossible to reasonably implement visit_unpack_type, because
# unpacking inherently expands to something more like a list of types.
#
# Relevant sections that can call unpack should call expand_unpack()
# instead.
# However, if the item is a variadic tuple, we can simply carry it over.
# In particular, if we expand A[*tuple[T, ...]] with substitutions {T: str},
# it is hard to assert this without getting proper type. Another important
# example is non-normalized types when called from semanal.py.
return UnpackType(t.type.accept(self))
def expand_unpack(self, t: UnpackType) -> list[Type]:
assert isinstance(t.type, TypeVarTupleType)
repl = get_proper_type(self.variables.get(t.type.id, t.type))
if isinstance(repl, TupleType):
return repl.items
elif (
isinstance(repl, Instance)
and repl.type.fullname == "builtins.tuple"
or isinstance(repl, TypeVarTupleType)
):
return [UnpackType(typ=repl)]
elif isinstance(repl, (AnyType, UninhabitedType)):
            # Replace *Ts = Any with *Ts = *tuple[Any, ...] and same for <nothing>.
# These types may appear here as a result of user error or failed inference.
return [UnpackType(t.type.tuple_fallback.copy_modified(args=[repl]))]
else:
raise RuntimeError(f"Invalid type replacement to expand: {repl}")
def visit_parameters(self, t: Parameters) -> Type:
return t.copy_modified(arg_types=self.expand_types(t.arg_types))
def interpolate_args_for_unpack(self, t: CallableType, var_arg: UnpackType) -> list[Type]:
star_index = t.arg_kinds.index(ARG_STAR)
prefix = self.expand_types(t.arg_types[:star_index])
suffix = self.expand_types(t.arg_types[star_index + 1 :])
var_arg_type = get_proper_type(var_arg.type)
# We have something like Unpack[Tuple[Unpack[Ts], X1, X2]]
if isinstance(var_arg_type, TupleType):
expanded_tuple = var_arg_type.accept(self)
assert isinstance(expanded_tuple, ProperType) and isinstance(expanded_tuple, TupleType)
expanded_items = expanded_tuple.items
fallback = var_arg_type.partial_fallback
else:
# We have plain Unpack[Ts]
assert isinstance(var_arg_type, TypeVarTupleType)
fallback = var_arg_type.tuple_fallback
expanded_items = self.expand_unpack(var_arg)
new_unpack = UnpackType(TupleType(expanded_items, fallback))
return prefix + [new_unpack] + suffix
def visit_callable_type(self, t: CallableType) -> CallableType:
param_spec = t.param_spec()
if param_spec is not None:
repl = self.variables.get(param_spec.id)
# If a ParamSpec in a callable type is substituted with a
# callable type, we can't use normal substitution logic,
# since ParamSpec is actually split into two components
# *P.args and **P.kwargs in the original type. Instead, we
# must expand both of them with all the argument types,
# kinds and names in the replacement. The return type in
# the replacement is ignored.
if isinstance(repl, Parameters):
# We need to expand both the types in the prefix and the ParamSpec itself
return t.copy_modified(
arg_types=self.expand_types(t.arg_types[:-2]) + repl.arg_types,
arg_kinds=t.arg_kinds[:-2] + repl.arg_kinds,
arg_names=t.arg_names[:-2] + repl.arg_names,
ret_type=t.ret_type.accept(self),
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
imprecise_arg_kinds=(t.imprecise_arg_kinds or repl.imprecise_arg_kinds),
variables=[*repl.variables, *t.variables],
)
elif isinstance(repl, ParamSpecType):
# We're substituting one ParamSpec for another; this can mean that the prefix
# changes, e.g. substitute Concatenate[int, P] in place of Q.
prefix = repl.prefix
clean_repl = repl.copy_modified(prefix=Parameters([], [], []))
return t.copy_modified(
arg_types=self.expand_types(t.arg_types[:-2])
+ prefix.arg_types
+ [
clean_repl.with_flavor(ParamSpecFlavor.ARGS),
clean_repl.with_flavor(ParamSpecFlavor.KWARGS),
],
arg_kinds=t.arg_kinds[:-2] + prefix.arg_kinds + t.arg_kinds[-2:],
arg_names=t.arg_names[:-2] + prefix.arg_names + t.arg_names[-2:],
ret_type=t.ret_type.accept(self),
from_concatenate=t.from_concatenate or bool(repl.prefix.arg_types),
imprecise_arg_kinds=(t.imprecise_arg_kinds or prefix.imprecise_arg_kinds),
)
var_arg = t.var_arg()
needs_normalization = False
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
needs_normalization = True
arg_types = self.interpolate_args_for_unpack(t, var_arg.typ)
else:
arg_types = self.expand_types(t.arg_types)
expanded = t.copy_modified(
arg_types=arg_types,
ret_type=t.ret_type.accept(self),
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
)
if needs_normalization:
return expanded.with_normalized_var_args()
return expanded
def visit_overloaded(self, t: Overloaded) -> Type:
items: list[CallableType] = []
for item in t.items:
new_item = item.accept(self)
assert isinstance(new_item, ProperType)
assert isinstance(new_item, CallableType)
items.append(new_item)
return Overloaded(items)
def expand_types_with_unpack(
self, typs: Sequence[Type]
) -> list[Type] | AnyType | UninhabitedType:
"""Expands a list of types that has an unpack.
In corner cases, this can return a type rather than a list, in which case this
indicates use of Any or some error occurred earlier. In this case callers should
simply propagate the resulting type.
"""
items: list[Type] = []
for item in typs:
if isinstance(item, UnpackType) and isinstance(item.type, TypeVarTupleType):
items.extend(self.expand_unpack(item))
else:
items.append(item.accept(self))
return items
def visit_tuple_type(self, t: TupleType) -> Type:
items = self.expand_types_with_unpack(t.items)
if isinstance(items, list):
if len(items) == 1:
# Normalize Tuple[*Tuple[X, ...]] -> Tuple[X, ...]
item = items[0]
if isinstance(item, UnpackType):
unpacked = get_proper_type(item.type)
if isinstance(unpacked, Instance):
assert unpacked.type.fullname == "builtins.tuple"
if t.partial_fallback.type.fullname != "builtins.tuple":
# If it is a subtype (like named tuple) we need to preserve it,
# this essentially mimics the logic in tuple_fallback().
return t.partial_fallback.accept(self)
return unpacked
fallback = t.partial_fallback.accept(self)
assert isinstance(fallback, ProperType) and isinstance(fallback, Instance)
return t.copy_modified(items=items, fallback=fallback)
else:
return items
def visit_typeddict_type(self, t: TypedDictType) -> Type:
fallback = t.fallback.accept(self)
assert isinstance(fallback, ProperType) and isinstance(fallback, Instance)
return t.copy_modified(item_types=self.expand_types(t.items.values()), fallback=fallback)
def visit_literal_type(self, t: LiteralType) -> Type:
# TODO: Verify this implementation is correct
return t
def visit_union_type(self, t: UnionType) -> Type:
expanded = self.expand_types(t.items)
        # After substituting for type variables in t.items, some resulting types
        # might be subtypes of others; however, calling make_simplified_union()
        # can cause recursion, so we just remove strict duplicates.
simplified = UnionType.make_union(
remove_trivial(flatten_nested_unions(expanded)), t.line, t.column
)
# This call to get_proper_type() is unfortunate but is required to preserve
# the invariant that ProperType will stay ProperType after applying expand_type(),
# otherwise a single item union of a type alias will break it. Note this should not
# cause infinite recursion since pathological aliases like A = Union[A, B] are
# banned at the semantic analysis level.
return get_proper_type(simplified)
def visit_partial_type(self, t: PartialType) -> Type:
return t
def visit_type_type(self, t: TypeType) -> Type:
# TODO: Verify that the new item type is valid (instance or
# union of instances or Any). Sadly we can't report errors
# here yet.
item = t.item.accept(self)
return TypeType.make_normalized(item)
def visit_type_alias_type(self, t: TypeAliasType) -> Type:
# Target of the type alias cannot contain type variables (not bound by the type
# alias itself), so we just expand the arguments.
args = self.expand_types_with_unpack(t.args)
if isinstance(args, list):
# TODO: normalize if target is Tuple, and args are [*tuple[X, ...]]?
return t.copy_modified(args=args)
else:
return args
def expand_types(self, types: Iterable[Type]) -> list[Type]:
a: list[Type] = []
for t in types:
a.append(t.accept(self))
return a
@overload
def expand_self_type(var: Var, typ: ProperType, replacement: ProperType) -> ProperType:
...
@overload
def expand_self_type(var: Var, typ: Type, replacement: Type) -> Type:
...
def expand_self_type(var: Var, typ: Type, replacement: Type) -> Type:
"""Expand appearances of Self type in a variable type."""
if var.info.self_type is not None and not var.is_property:
return expand_type(typ, {var.info.self_type.id: replacement})
return typ
def remove_trivial(types: Iterable[Type]) -> list[Type]:
"""Make trivial simplifications on a list of types without calling is_subtype().
    This makes the following simplifications:
* Remove bottom types (taking into account strict optional setting)
* Remove everything else if there is an `object`
* Remove strict duplicate types
"""
removed_none = False
new_types = []
all_types = set()
for t in types:
p_t = get_proper_type(t)
if isinstance(p_t, UninhabitedType):
continue
if isinstance(p_t, NoneType) and not state.strict_optional:
removed_none = True
continue
if isinstance(p_t, Instance) and p_t.type.fullname == "builtins.object":
return [p_t]
if p_t not in all_types:
new_types.append(t)
all_types.add(p_t)
if new_types:
return new_types
if removed_none:
return [NoneType()]
return [UninhabitedType()]
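
The core idea of ExpandTypeVisitor is plain substitution through a type tree. A toy analogue of that recursion, deliberately not using mypy's real type objects:

from __future__ import annotations

from typing import Mapping, Union

# A toy type is either a type-variable name (str) or an applied type
# such as ("dict", key, value).
ToyType = Union[str, tuple]

def expand(typ: ToyType, env: Mapping[str, ToyType]) -> ToyType:
    if isinstance(typ, str):
        # A type variable: replace it if the environment has a value for it.
        return env.get(typ, typ)
    # An applied type: keep the constructor, expand the arguments recursively.
    name, *args = typ
    return (name, *(expand(arg, env) for arg in args))

# dict[T, list[S]] with {T: str, S: int} -> dict[str, list[int]]
print(expand(("dict", "T", ("list", "S")), {"T": "str", "S": "int"}))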

View File

@@ -0,0 +1,201 @@
"""Translate an Expression to a Type value."""
from __future__ import annotations
from mypy.fastparse import parse_type_string
from mypy.nodes import (
BytesExpr,
CallExpr,
ComplexExpr,
EllipsisExpr,
Expression,
FloatExpr,
IndexExpr,
IntExpr,
ListExpr,
MemberExpr,
NameExpr,
OpExpr,
RefExpr,
StarExpr,
StrExpr,
TupleExpr,
UnaryExpr,
get_member_expr_fullname,
)
from mypy.options import Options
from mypy.types import (
ANNOTATED_TYPE_NAMES,
AnyType,
CallableArgument,
EllipsisType,
ProperType,
RawExpressionType,
Type,
TypeList,
TypeOfAny,
UnboundType,
UnionType,
UnpackType,
)
class TypeTranslationError(Exception):
"""Exception raised when an expression is not valid as a type."""
def _extract_argument_name(expr: Expression) -> str | None:
if isinstance(expr, NameExpr) and expr.name == "None":
return None
elif isinstance(expr, StrExpr):
return expr.value
else:
raise TypeTranslationError()
def expr_to_unanalyzed_type(
expr: Expression,
options: Options | None = None,
allow_new_syntax: bool = False,
_parent: Expression | None = None,
allow_unpack: bool = False,
) -> ProperType:
"""Translate an expression to the corresponding type.
The result is not semantically analyzed. It can be UnboundType or TypeList.
Raise TypeTranslationError if the expression cannot represent a type.
If allow_new_syntax is True, allow all type syntax independent of the target
Python version (used in stubs).
"""
    # The `_parent` parameter is used in recursive calls to provide context for
    # understanding whether a CallableArgument is ok.
name: str | None = None
if isinstance(expr, NameExpr):
name = expr.name
if name == "True":
return RawExpressionType(True, "builtins.bool", line=expr.line, column=expr.column)
elif name == "False":
return RawExpressionType(False, "builtins.bool", line=expr.line, column=expr.column)
else:
return UnboundType(name, line=expr.line, column=expr.column)
elif isinstance(expr, MemberExpr):
fullname = get_member_expr_fullname(expr)
if fullname:
return UnboundType(fullname, line=expr.line, column=expr.column)
else:
raise TypeTranslationError()
elif isinstance(expr, IndexExpr):
base = expr_to_unanalyzed_type(expr.base, options, allow_new_syntax, expr)
if isinstance(base, UnboundType):
if base.args:
raise TypeTranslationError()
if isinstance(expr.index, TupleExpr):
args = expr.index.items
else:
args = [expr.index]
if isinstance(expr.base, RefExpr) and expr.base.fullname in ANNOTATED_TYPE_NAMES:
# TODO: this is not the optimal solution as we are basically getting rid
# of the Annotation definition and only returning the type information,
# losing all the annotations.
return expr_to_unanalyzed_type(args[0], options, allow_new_syntax, expr)
else:
base.args = tuple(
expr_to_unanalyzed_type(arg, options, allow_new_syntax, expr) for arg in args
)
if not base.args:
base.empty_tuple_index = True
return base
else:
raise TypeTranslationError()
elif (
isinstance(expr, OpExpr)
and expr.op == "|"
and ((options and options.python_version >= (3, 10)) or allow_new_syntax)
):
return UnionType(
[
expr_to_unanalyzed_type(expr.left, options, allow_new_syntax),
expr_to_unanalyzed_type(expr.right, options, allow_new_syntax),
]
)
elif isinstance(expr, CallExpr) and isinstance(_parent, ListExpr):
c = expr.callee
names = []
# Go through the dotted member expr chain to get the full arg
# constructor name to look up
while True:
if isinstance(c, NameExpr):
names.append(c.name)
break
elif isinstance(c, MemberExpr):
names.append(c.name)
c = c.expr
else:
raise TypeTranslationError()
arg_const = ".".join(reversed(names))
# Go through the constructor args to get its name and type.
name = None
default_type = AnyType(TypeOfAny.unannotated)
typ: Type = default_type
for i, arg in enumerate(expr.args):
if expr.arg_names[i] is not None:
if expr.arg_names[i] == "name":
if name is not None:
# Two names
raise TypeTranslationError()
name = _extract_argument_name(arg)
continue
elif expr.arg_names[i] == "type":
if typ is not default_type:
# Two types
raise TypeTranslationError()
typ = expr_to_unanalyzed_type(arg, options, allow_new_syntax, expr)
continue
else:
raise TypeTranslationError()
elif i == 0:
typ = expr_to_unanalyzed_type(arg, options, allow_new_syntax, expr)
elif i == 1:
name = _extract_argument_name(arg)
else:
raise TypeTranslationError()
return CallableArgument(typ, name, arg_const, expr.line, expr.column)
elif isinstance(expr, ListExpr):
return TypeList(
[
expr_to_unanalyzed_type(t, options, allow_new_syntax, expr, allow_unpack=True)
for t in expr.items
],
line=expr.line,
column=expr.column,
)
elif isinstance(expr, StrExpr):
return parse_type_string(expr.value, "builtins.str", expr.line, expr.column)
elif isinstance(expr, BytesExpr):
return parse_type_string(expr.value, "builtins.bytes", expr.line, expr.column)
elif isinstance(expr, UnaryExpr):
typ = expr_to_unanalyzed_type(expr.expr, options, allow_new_syntax)
if isinstance(typ, RawExpressionType):
if isinstance(typ.literal_value, int) and expr.op == "-":
typ.literal_value *= -1
return typ
raise TypeTranslationError()
elif isinstance(expr, IntExpr):
return RawExpressionType(expr.value, "builtins.int", line=expr.line, column=expr.column)
elif isinstance(expr, FloatExpr):
    # Floats are not valid parameters for RawExpressionType, so we just
# pass in 'None' for now. We'll report the appropriate error at a later stage.
return RawExpressionType(None, "builtins.float", line=expr.line, column=expr.column)
elif isinstance(expr, ComplexExpr):
# Same thing as above with complex numbers.
return RawExpressionType(None, "builtins.complex", line=expr.line, column=expr.column)
elif isinstance(expr, EllipsisExpr):
return EllipsisType(expr.line)
elif allow_unpack and isinstance(expr, StarExpr):
return UnpackType(expr_to_unanalyzed_type(expr.expr, options, allow_new_syntax))
else:
raise TypeTranslationError()
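
A minimal sketch of the translation, assuming this module is importable as mypy.exprtotype; bare names come back as UnboundType placeholders and invalid expressions raise TypeTranslationError:

import mypy.nodes
from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type

# A bare name becomes an UnboundType, to be resolved by semantic analysis later.
t = expr_to_unanalyzed_type(mypy.nodes.NameExpr("int"))
print(type(t).__name__)  # UnboundType

# Expressions that can't denote a type are rejected.
try:
    expr_to_unanalyzed_type(mypy.nodes.YieldExpr(None))
except TypeTranslationError:
    print("not a type")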

File diff suppressed because it is too large

View File

@@ -0,0 +1,243 @@
"""Routines for finding the sources that mypy will check"""
from __future__ import annotations
import functools
import os
from typing import Final, Sequence
from mypy.fscache import FileSystemCache
from mypy.modulefinder import PYTHON_EXTENSIONS, BuildSource, matches_exclude, mypy_path
from mypy.options import Options
PY_EXTENSIONS: Final = tuple(PYTHON_EXTENSIONS)
class InvalidSourceList(Exception):
"""Exception indicating a problem in the list of sources given to mypy."""
def create_source_list(
paths: Sequence[str],
options: Options,
fscache: FileSystemCache | None = None,
allow_empty_dir: bool = False,
) -> list[BuildSource]:
"""From a list of source files/directories, makes a list of BuildSources.
Raises InvalidSourceList on errors.
"""
fscache = fscache or FileSystemCache()
finder = SourceFinder(fscache, options)
sources = []
for path in paths:
path = os.path.normpath(path)
if path.endswith(PY_EXTENSIONS):
# Can raise InvalidSourceList if a directory doesn't have a valid module name.
name, base_dir = finder.crawl_up(path)
sources.append(BuildSource(path, name, None, base_dir))
elif fscache.isdir(path):
sub_sources = finder.find_sources_in_dir(path)
if not sub_sources and not allow_empty_dir:
raise InvalidSourceList(f"There are no .py[i] files in directory '{path}'")
sources.extend(sub_sources)
else:
mod = os.path.basename(path) if options.scripts_are_modules else None
sources.append(BuildSource(path, mod, None))
return sources
def keyfunc(name: str) -> tuple[bool, int, str]:
"""Determines sort order for directory listing.
The desirable properties are:
1) foo < foo.pyi < foo.py
2) __init__.py[i] < foo
"""
base, suffix = os.path.splitext(name)
for i, ext in enumerate(PY_EXTENSIONS):
if suffix == ext:
return (base != "__init__", i, base)
return (base != "__init__", -1, name)
def normalise_package_base(root: str) -> str:
if not root:
root = os.curdir
root = os.path.abspath(root)
if root.endswith(os.sep):
root = root[:-1]
return root
def get_explicit_package_bases(options: Options) -> list[str] | None:
"""Returns explicit package bases to use if the option is enabled, or None if disabled.
    We currently use MYPYPATH and the current directory as the package bases. In the future,
    when --namespace-packages is the default, we could also use the values passed with the
    --package-root flag; see #9632.
Values returned are normalised so we can use simple string comparisons in
SourceFinder.is_explicit_package_base
"""
if not options.explicit_package_bases:
return None
roots = mypy_path() + options.mypy_path + [os.getcwd()]
return [normalise_package_base(root) for root in roots]
class SourceFinder:
def __init__(self, fscache: FileSystemCache, options: Options) -> None:
self.fscache = fscache
self.explicit_package_bases = get_explicit_package_bases(options)
self.namespace_packages = options.namespace_packages
self.exclude = options.exclude
self.verbosity = options.verbosity
def is_explicit_package_base(self, path: str) -> bool:
assert self.explicit_package_bases
return normalise_package_base(path) in self.explicit_package_bases
def find_sources_in_dir(self, path: str) -> list[BuildSource]:
sources = []
seen: set[str] = set()
names = sorted(self.fscache.listdir(path), key=keyfunc)
for name in names:
# Skip certain names altogether
if name in ("__pycache__", "site-packages", "node_modules") or name.startswith("."):
continue
subpath = os.path.join(path, name)
if matches_exclude(subpath, self.exclude, self.fscache, self.verbosity >= 2):
continue
if self.fscache.isdir(subpath):
sub_sources = self.find_sources_in_dir(subpath)
if sub_sources:
seen.add(name)
sources.extend(sub_sources)
else:
stem, suffix = os.path.splitext(name)
if stem not in seen and suffix in PY_EXTENSIONS:
seen.add(stem)
module, base_dir = self.crawl_up(subpath)
sources.append(BuildSource(subpath, module, None, base_dir))
return sources
def crawl_up(self, path: str) -> tuple[str, str]:
"""Given a .py[i] filename, return module and base directory.
For example, given "xxx/yyy/foo/bar.py", we might return something like:
("foo.bar", "xxx/yyy")
        If namespace packages is off, we crawl upwards until we find a directory without
        an __init__.py.
        If namespace packages is on, we crawl upwards until the nearest explicit base directory.
        Failing that, we return one past the highest directory containing an __init__.py.
We won't crawl past directories with invalid package names.
The base directory returned is an absolute path.
"""
path = os.path.abspath(path)
parent, filename = os.path.split(path)
module_name = strip_py(filename) or filename
parent_module, base_dir = self.crawl_up_dir(parent)
if module_name == "__init__":
return parent_module, base_dir
# Note that module_name might not actually be a valid identifier, but that's okay
# Ignoring this possibility sidesteps some search path confusion
module = module_join(parent_module, module_name)
return module, base_dir
def crawl_up_dir(self, dir: str) -> tuple[str, str]:
return self._crawl_up_helper(dir) or ("", dir)
@functools.lru_cache # noqa: B019
def _crawl_up_helper(self, dir: str) -> tuple[str, str] | None:
"""Given a directory, maybe returns module and base directory.
We return a non-None value if we were able to find something clearly intended as a base
directory (as adjudicated by being an explicit base directory or by containing a package
with __init__.py).
This distinction is necessary for namespace packages, so that we know when to treat
ourselves as a subpackage.
"""
# stop crawling if we're an explicit base directory
if self.explicit_package_bases is not None and self.is_explicit_package_base(dir):
return "", dir
parent, name = os.path.split(dir)
if name.endswith("-stubs"):
name = name[:-6] # PEP-561 stub-only directory
# recurse if there's an __init__.py
init_file = self.get_init_file(dir)
if init_file is not None:
if not name.isidentifier():
                # In most cases an invalid directory name just stops the upward crawl,
                # but if there's an __init__.py in the directory, something is messed up.
raise InvalidSourceList(f"{name} is not a valid Python package name")
# we're definitely a package, so we always return a non-None value
mod_prefix, base_dir = self.crawl_up_dir(parent)
return module_join(mod_prefix, name), base_dir
# stop crawling if we're out of path components or our name is an invalid identifier
if not name or not parent or not name.isidentifier():
return None
# stop crawling if namespace packages is off (since we don't have an __init__.py)
if not self.namespace_packages:
return None
# at this point: namespace packages is on, we don't have an __init__.py and we're not an
# explicit base directory
result = self._crawl_up_helper(parent)
if result is None:
# we're not an explicit base directory and we don't have an __init__.py
# and none of our parents are either, so return
return None
# one of our parents was an explicit base directory or had an __init__.py, so we're
# definitely a subpackage! chain our name to the module.
mod_prefix, base_dir = result
return module_join(mod_prefix, name), base_dir
def get_init_file(self, dir: str) -> str | None:
"""Check whether a directory contains a file named __init__.py[i].
If so, return the file's name (with dir prefixed). If not, return None.
This prefers .pyi over .py (because of the ordering of PY_EXTENSIONS).
"""
for ext in PY_EXTENSIONS:
f = os.path.join(dir, "__init__" + ext)
if self.fscache.isfile(f):
return f
if ext == ".py" and self.fscache.init_under_package_root(f):
return f
return None
def module_join(parent: str, child: str) -> str:
"""Join module ids, accounting for a possibly empty parent."""
if parent:
return parent + "." + child
return child
def strip_py(arg: str) -> str | None:
"""Strip a trailing .py or .pyi suffix.
Return None if no such suffix is found.
"""
for ext in PY_EXTENSIONS:
if arg.endswith(ext):
return arg[: -len(ext)]
return None
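
The ordering that keyfunc produces can be checked directly; a small sketch, assuming the in-tree module path mypy.find_sources:

from mypy.find_sources import keyfunc

# __init__ files sort first, then stems, with .pyi preferred over .py.
names = ["foo.py", "foo.pyi", "foo", "bar", "__init__.py"]
print(sorted(names, key=keyfunc))
# ['__init__.py', 'bar', 'foo', 'foo.pyi', 'foo.py']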

View File

@@ -0,0 +1,410 @@
"""Fix up various things after deserialization."""
from __future__ import annotations
from typing import Any, Final
from mypy.lookup import lookup_fully_qualified
from mypy.nodes import (
Block,
ClassDef,
Decorator,
FuncDef,
MypyFile,
OverloadedFuncDef,
ParamSpecExpr,
SymbolTable,
TypeAlias,
TypeInfo,
TypeVarExpr,
TypeVarTupleExpr,
Var,
)
from mypy.types import (
NOT_READY,
AnyType,
CallableType,
Instance,
LiteralType,
Overloaded,
Parameters,
ParamSpecType,
TupleType,
TypeAliasType,
TypedDictType,
TypeOfAny,
TypeType,
TypeVarTupleType,
TypeVarType,
TypeVisitor,
UnboundType,
UnionType,
UnpackType,
)
from mypy.visitor import NodeVisitor
# N.B.: we do an allow_missing fixup when fixing up a fine-grained
# incremental cache load (since there may be cross-refs into deleted
# modules).
def fixup_module(tree: MypyFile, modules: dict[str, MypyFile], allow_missing: bool) -> None:
node_fixer = NodeFixer(modules, allow_missing)
node_fixer.visit_symbol_table(tree.names, tree.fullname)
# TODO: Fix up .info when deserializing, i.e. much earlier.
class NodeFixer(NodeVisitor[None]):
current_info: TypeInfo | None = None
def __init__(self, modules: dict[str, MypyFile], allow_missing: bool) -> None:
self.modules = modules
self.allow_missing = allow_missing
self.type_fixer = TypeFixer(self.modules, allow_missing)
# NOTE: This method isn't (yet) part of the NodeVisitor API.
def visit_type_info(self, info: TypeInfo) -> None:
save_info = self.current_info
try:
self.current_info = info
if info.defn:
info.defn.accept(self)
if info.names:
self.visit_symbol_table(info.names, info.fullname)
if info.bases:
for base in info.bases:
base.accept(self.type_fixer)
if info._promote:
for p in info._promote:
p.accept(self.type_fixer)
if info.tuple_type:
info.tuple_type.accept(self.type_fixer)
info.update_tuple_type(info.tuple_type)
if info.special_alias:
info.special_alias.alias_tvars = list(info.defn.type_vars)
if info.typeddict_type:
info.typeddict_type.accept(self.type_fixer)
info.update_typeddict_type(info.typeddict_type)
if info.special_alias:
info.special_alias.alias_tvars = list(info.defn.type_vars)
if info.declared_metaclass:
info.declared_metaclass.accept(self.type_fixer)
if info.metaclass_type:
info.metaclass_type.accept(self.type_fixer)
if info.alt_promote:
info.alt_promote.accept(self.type_fixer)
instance = Instance(info, [])
# Hack: We may also need to add a backwards promotion (from int to native int),
# since it might not be serialized.
if instance not in info.alt_promote.type._promote:
info.alt_promote.type._promote.append(instance)
if info._mro_refs:
info.mro = [
lookup_fully_qualified_typeinfo(
self.modules, name, allow_missing=self.allow_missing
)
for name in info._mro_refs
]
info._mro_refs = None
finally:
self.current_info = save_info
# NOTE: This method *definitely* isn't part of the NodeVisitor API.
def visit_symbol_table(self, symtab: SymbolTable, table_fullname: str) -> None:
# Copy the items because we may mutate symtab.
for key, value in list(symtab.items()):
cross_ref = value.cross_ref
if cross_ref is not None: # Fix up cross-reference.
value.cross_ref = None
if cross_ref in self.modules:
value.node = self.modules[cross_ref]
else:
stnode = lookup_fully_qualified(
cross_ref, self.modules, raise_on_missing=not self.allow_missing
)
if stnode is not None:
assert stnode.node is not None, (table_fullname + "." + key, cross_ref)
value.node = stnode.node
elif not self.allow_missing:
assert False, f"Could not find cross-ref {cross_ref}"
else:
# We have a missing crossref in allow missing mode, need to put something
value.node = missing_info(self.modules)
else:
if isinstance(value.node, TypeInfo):
# TypeInfo has no accept(). TODO: Add it?
self.visit_type_info(value.node)
elif value.node is not None:
value.node.accept(self)
else:
assert False, f"Unexpected empty node {key!r}: {value}"
def visit_func_def(self, func: FuncDef) -> None:
if self.current_info is not None:
func.info = self.current_info
if func.type is not None:
func.type.accept(self.type_fixer)
def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None:
if self.current_info is not None:
o.info = self.current_info
if o.type:
o.type.accept(self.type_fixer)
for item in o.items:
item.accept(self)
if o.impl:
o.impl.accept(self)
def visit_decorator(self, d: Decorator) -> None:
if self.current_info is not None:
d.var.info = self.current_info
if d.func:
d.func.accept(self)
if d.var:
d.var.accept(self)
for node in d.decorators:
node.accept(self)
def visit_class_def(self, c: ClassDef) -> None:
for v in c.type_vars:
if isinstance(v, TypeVarType):
for value in v.values:
value.accept(self.type_fixer)
v.upper_bound.accept(self.type_fixer)
v.default.accept(self.type_fixer)
def visit_type_var_expr(self, tv: TypeVarExpr) -> None:
for value in tv.values:
value.accept(self.type_fixer)
tv.upper_bound.accept(self.type_fixer)
tv.default.accept(self.type_fixer)
def visit_paramspec_expr(self, p: ParamSpecExpr) -> None:
p.upper_bound.accept(self.type_fixer)
p.default.accept(self.type_fixer)
def visit_type_var_tuple_expr(self, tv: TypeVarTupleExpr) -> None:
tv.upper_bound.accept(self.type_fixer)
tv.default.accept(self.type_fixer)
def visit_var(self, v: Var) -> None:
if self.current_info is not None:
v.info = self.current_info
if v.type is not None:
v.type.accept(self.type_fixer)
def visit_type_alias(self, a: TypeAlias) -> None:
a.target.accept(self.type_fixer)
for v in a.alias_tvars:
v.accept(self.type_fixer)
class TypeFixer(TypeVisitor[None]):
def __init__(self, modules: dict[str, MypyFile], allow_missing: bool) -> None:
self.modules = modules
self.allow_missing = allow_missing
def visit_instance(self, inst: Instance) -> None:
# TODO: Combine Instances that are exactly the same?
type_ref = inst.type_ref
if type_ref is None:
return # We've already been here.
inst.type_ref = None
inst.type = lookup_fully_qualified_typeinfo(
self.modules, type_ref, allow_missing=self.allow_missing
)
# TODO: Is this needed or redundant?
# Also fix up the bases, just in case.
for base in inst.type.bases:
if base.type is NOT_READY:
base.accept(self)
for a in inst.args:
a.accept(self)
if inst.last_known_value is not None:
inst.last_known_value.accept(self)
def visit_type_alias_type(self, t: TypeAliasType) -> None:
type_ref = t.type_ref
if type_ref is None:
return # We've already been here.
t.type_ref = None
t.alias = lookup_fully_qualified_alias(
self.modules, type_ref, allow_missing=self.allow_missing
)
for a in t.args:
a.accept(self)
def visit_any(self, o: Any) -> None:
pass # Nothing to descend into.
def visit_callable_type(self, ct: CallableType) -> None:
if ct.fallback:
ct.fallback.accept(self)
for argt in ct.arg_types:
# argt may be None, e.g. for __self in NamedTuple constructors.
if argt is not None:
argt.accept(self)
if ct.ret_type is not None:
ct.ret_type.accept(self)
for v in ct.variables:
v.accept(self)
for arg in ct.bound_args:
if arg:
arg.accept(self)
if ct.type_guard is not None:
ct.type_guard.accept(self)
def visit_overloaded(self, t: Overloaded) -> None:
for ct in t.items:
ct.accept(self)
def visit_erased_type(self, o: Any) -> None:
# This type should exist only temporarily during type inference
raise RuntimeError("Shouldn't get here", o)
def visit_deleted_type(self, o: Any) -> None:
pass # Nothing to descend into.
def visit_none_type(self, o: Any) -> None:
pass # Nothing to descend into.
def visit_uninhabited_type(self, o: Any) -> None:
pass # Nothing to descend into.
def visit_partial_type(self, o: Any) -> None:
raise RuntimeError("Shouldn't get here", o)
def visit_tuple_type(self, tt: TupleType) -> None:
if tt.items:
for it in tt.items:
it.accept(self)
if tt.partial_fallback is not None:
tt.partial_fallback.accept(self)
def visit_typeddict_type(self, tdt: TypedDictType) -> None:
if tdt.items:
for it in tdt.items.values():
it.accept(self)
if tdt.fallback is not None:
if tdt.fallback.type_ref is not None:
if (
lookup_fully_qualified(
tdt.fallback.type_ref,
self.modules,
raise_on_missing=not self.allow_missing,
)
is None
):
# We reject fake TypeInfos for TypedDict fallbacks because
# the latter are used in type checking and must be valid.
tdt.fallback.type_ref = "typing._TypedDict"
tdt.fallback.accept(self)
def visit_literal_type(self, lt: LiteralType) -> None:
lt.fallback.accept(self)
def visit_type_var(self, tvt: TypeVarType) -> None:
if tvt.values:
for vt in tvt.values:
vt.accept(self)
tvt.upper_bound.accept(self)
tvt.default.accept(self)
def visit_param_spec(self, p: ParamSpecType) -> None:
p.upper_bound.accept(self)
p.default.accept(self)
def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
t.upper_bound.accept(self)
t.default.accept(self)
def visit_unpack_type(self, u: UnpackType) -> None:
u.type.accept(self)
def visit_parameters(self, p: Parameters) -> None:
for argt in p.arg_types:
if argt is not None:
argt.accept(self)
for var in p.variables:
var.accept(self)
def visit_unbound_type(self, o: UnboundType) -> None:
for a in o.args:
a.accept(self)
def visit_union_type(self, ut: UnionType) -> None:
if ut.items:
for it in ut.items:
it.accept(self)
def visit_void(self, o: Any) -> None:
pass # Nothing to descend into.
def visit_type_type(self, t: TypeType) -> None:
t.item.accept(self)
def lookup_fully_qualified_typeinfo(
modules: dict[str, MypyFile], name: str, *, allow_missing: bool
) -> TypeInfo:
stnode = lookup_fully_qualified(name, modules, raise_on_missing=not allow_missing)
node = stnode.node if stnode else None
if isinstance(node, TypeInfo):
return node
else:
# Looks like a missing TypeInfo during an initial daemon load, put something there
assert (
allow_missing
), "Should never get here in normal mode, got {}:{} instead of TypeInfo".format(
type(node).__name__, node.fullname if node else ""
)
return missing_info(modules)
def lookup_fully_qualified_alias(
modules: dict[str, MypyFile], name: str, *, allow_missing: bool
) -> TypeAlias:
stnode = lookup_fully_qualified(name, modules, raise_on_missing=not allow_missing)
node = stnode.node if stnode else None
if isinstance(node, TypeAlias):
return node
elif isinstance(node, TypeInfo):
if node.special_alias:
# Already fixed up.
return node.special_alias
if node.tuple_type:
alias = TypeAlias.from_tuple_type(node)
elif node.typeddict_type:
alias = TypeAlias.from_typeddict_type(node)
else:
assert allow_missing
return missing_alias()
node.special_alias = alias
return alias
else:
# Looks like a missing TypeAlias during an initial daemon load, put something there
assert (
allow_missing
), "Should never get here in normal mode, got {}:{} instead of TypeAlias".format(
type(node).__name__, node.fullname if node else ""
)
return missing_alias()
_SUGGESTION: Final = "<missing {}: *should* have gone away during fine-grained update>"
def missing_info(modules: dict[str, MypyFile]) -> TypeInfo:
suggestion = _SUGGESTION.format("info")
dummy_def = ClassDef(suggestion, Block([]))
dummy_def.fullname = suggestion
info = TypeInfo(SymbolTable(), dummy_def, "<missing>")
obj_type = lookup_fully_qualified_typeinfo(modules, "builtins.object", allow_missing=False)
info.bases = [Instance(obj_type, [])]
info.mro = [info, obj_type]
return info
def missing_alias() -> TypeAlias:
suggestion = _SUGGESTION.format("alias")
return TypeAlias(AnyType(TypeOfAny.special_form), suggestion, line=-1, column=-1)
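
The overall pattern here is a two-phase load: deserialize nodes with string cross-references, then resolve those references in place. A toy analogue with a hypothetical Node class, not mypy's real machinery:

from __future__ import annotations

class Node:
    def __init__(self, name: str, base_ref: str | None) -> None:
        self.name = name
        self.base_ref = base_ref  # serialized cross-reference (a fullname string)
        self.base: Node | None = None  # resolved object, filled in by fixup()

def fixup(nodes: dict[str, Node]) -> None:
    # Second pass: replace string cross-refs with the objects they name.
    for node in nodes.values():
        if node.base_ref is not None:
            node.base = nodes[node.base_ref]
            node.base_ref = None  # mark as already fixed up

nodes = {"builtins.object": Node("builtins.object", None),
         "m.C": Node("m.C", "builtins.object")}
fixup(nodes)
print(nodes["m.C"].base.name)  # builtins.object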

View File

@@ -0,0 +1,23 @@
"""Generic node traverser visitor"""
from __future__ import annotations
from mypy.nodes import Block, MypyFile
from mypy.traverser import TraverserVisitor
class TreeFreer(TraverserVisitor):
def visit_block(self, block: Block) -> None:
super().visit_block(block)
block.body.clear()
def free_tree(tree: MypyFile) -> None:
"""Free all the ASTs associated with a module.
This needs to be done recursively, since symbol tables contain
references to definitions, so those won't be freed but we want their
contents to be.
"""
tree.accept(TreeFreer())
tree.defs.clear()

View File

@@ -0,0 +1,309 @@
"""Interface for accessing the file system with automatic caching.
The idea is to cache the results of any file system state reads during
a single transaction. This has two main benefits:
* This avoids redundant syscalls, as we won't perform the same OS
operations multiple times.
* This makes it easier to reason about concurrent FS updates, as different
operations targeting the same paths can't report different state during
a transaction.
Note that this only deals with reading state, not writing.
Properties maintained by the API:
* The contents of the file are always from the same or later time compared
to the reported mtime of the file, even if mtime is queried after reading
a file.
* Repeating an operation produces the same result as the first one during
a transaction.
* Call flush() to start a new transaction (flush the caches).
The API is a bit limited. It's easy to add new cached operations, however.
You should perform all file system reads through the API to actually take
advantage of the benefits.
"""
from __future__ import annotations
import os
import stat
from mypy_extensions import mypyc_attr
from mypy.util import hash_digest
@mypyc_attr(allow_interpreted_subclasses=True) # for tests
class FileSystemCache:
def __init__(self) -> None:
# The package root is not flushed with the caches.
# It is set by set_package_root() below.
self.package_root: list[str] = []
self.flush()
def set_package_root(self, package_root: list[str]) -> None:
self.package_root = package_root
def flush(self) -> None:
"""Start another transaction and empty all caches."""
self.stat_cache: dict[str, os.stat_result] = {}
self.stat_error_cache: dict[str, OSError] = {}
self.listdir_cache: dict[str, list[str]] = {}
self.listdir_error_cache: dict[str, OSError] = {}
self.isfile_case_cache: dict[str, bool] = {}
self.exists_case_cache: dict[str, bool] = {}
self.read_cache: dict[str, bytes] = {}
self.read_error_cache: dict[str, Exception] = {}
self.hash_cache: dict[str, str] = {}
self.fake_package_cache: set[str] = set()
def stat(self, path: str) -> os.stat_result:
if path in self.stat_cache:
return self.stat_cache[path]
if path in self.stat_error_cache:
raise copy_os_error(self.stat_error_cache[path])
try:
st = os.stat(path)
except OSError as err:
if self.init_under_package_root(path):
try:
return self._fake_init(path)
except OSError:
pass
# Take a copy to get rid of associated traceback and frame objects.
# Just assigning to __traceback__ doesn't free them.
self.stat_error_cache[path] = copy_os_error(err)
raise err
self.stat_cache[path] = st
return st
def init_under_package_root(self, path: str) -> bool:
"""Is this path an __init__.py under a package root?
This is used to detect packages that don't contain __init__.py
files, which is needed to support Bazel. The function should
only be called for non-existing files.
        It will return True if it refers to an __init__.py file that
Bazel would create, so that at runtime Python would think the
directory containing it is a package. For this to work you
must pass one or more package roots using the --package-root
flag.
As an exceptional case, any directory that is a package root
        itself will not be considered to contain an __init__.py file.
This is different from the rules Bazel itself applies, but is
necessary for mypy to properly distinguish packages from other
directories.
See https://docs.bazel.build/versions/master/be/python.html,
where this behavior is described under legacy_create_init.
"""
if not self.package_root:
return False
dirname, basename = os.path.split(path)
if basename != "__init__.py":
return False
if not os.path.basename(dirname).isidentifier():
# Can't put an __init__.py in a place that's not an identifier
return False
try:
st = self.stat(dirname)
except OSError:
return False
else:
if not stat.S_ISDIR(st.st_mode):
return False
ok = False
drive, path = os.path.splitdrive(path) # Ignore Windows drive name
if os.path.isabs(path):
path = os.path.relpath(path)
path = os.path.normpath(path)
for root in self.package_root:
if path.startswith(root):
if path == root + basename:
# A package root itself is never a package.
ok = False
break
else:
ok = True
return ok
def _fake_init(self, path: str) -> os.stat_result:
"""Prime the cache with a fake __init__.py file.
This makes code that looks for path believe an empty file by
that name exists. Should only be called after
init_under_package_root() returns True.
"""
dirname, basename = os.path.split(path)
assert basename == "__init__.py", path
assert not os.path.exists(path), path # Not cached!
dirname = os.path.normpath(dirname)
st = self.stat(dirname) # May raise OSError
# Get stat result as a list so we can modify it.
seq: list[float] = list(st)
seq[stat.ST_MODE] = stat.S_IFREG | 0o444
seq[stat.ST_INO] = 1
seq[stat.ST_NLINK] = 1
seq[stat.ST_SIZE] = 0
st = os.stat_result(seq)
self.stat_cache[path] = st
# Make listdir() and read() also pretend this file exists.
self.fake_package_cache.add(dirname)
return st
def listdir(self, path: str) -> list[str]:
path = os.path.normpath(path)
if path in self.listdir_cache:
res = self.listdir_cache[path]
# Check the fake cache.
if path in self.fake_package_cache and "__init__.py" not in res:
res.append("__init__.py") # Updates the result as well as the cache
return res
if path in self.listdir_error_cache:
raise copy_os_error(self.listdir_error_cache[path])
try:
results = os.listdir(path)
except OSError as err:
# Like above, take a copy to reduce memory use.
self.listdir_error_cache[path] = copy_os_error(err)
raise err
self.listdir_cache[path] = results
# Check the fake cache.
if path in self.fake_package_cache and "__init__.py" not in results:
results.append("__init__.py")
return results
def isfile(self, path: str) -> bool:
try:
st = self.stat(path)
except OSError:
return False
return stat.S_ISREG(st.st_mode)
def isfile_case(self, path: str, prefix: str) -> bool:
"""Return whether path exists and is a file.
On case-insensitive filesystems (like Mac or Windows) this returns
False if the case of path's last component does not exactly match
the case found in the filesystem.
We check also the case of other path components up to prefix.
For example, if path is 'user-stubs/pack/mod.pyi' and prefix is 'user-stubs',
        we check that the case of 'pack' and 'mod.pyi' matches exactly; 'user-stubs' will be
case insensitive on case insensitive filesystems.
The caller must ensure that prefix is a valid file system prefix of path.
"""
if not self.isfile(path):
# Fast path
return False
if path in self.isfile_case_cache:
return self.isfile_case_cache[path]
head, tail = os.path.split(path)
if not tail:
self.isfile_case_cache[path] = False
return False
try:
names = self.listdir(head)
# This allows one to check file name case sensitively in
# case-insensitive filesystems.
res = tail in names
except OSError:
res = False
if res:
# Also recursively check the other path components in case sensitive way.
res = self.exists_case(head, prefix)
self.isfile_case_cache[path] = res
return res
def exists_case(self, path: str, prefix: str) -> bool:
"""Return whether path exists - checking path components in case sensitive
fashion, up to prefix.
"""
if path in self.exists_case_cache:
return self.exists_case_cache[path]
head, tail = os.path.split(path)
if not head.startswith(prefix) or not tail:
# Only perform the check for paths under prefix.
self.exists_case_cache[path] = True
return True
try:
names = self.listdir(head)
# This allows one to check file name case sensitively in
# case-insensitive filesystems.
res = tail in names
except OSError:
res = False
if res:
# Also recursively check other path components.
res = self.exists_case(head, prefix)
self.exists_case_cache[path] = res
return res
def isdir(self, path: str) -> bool:
try:
st = self.stat(path)
except OSError:
return False
return stat.S_ISDIR(st.st_mode)
def exists(self, path: str) -> bool:
try:
self.stat(path)
except FileNotFoundError:
return False
return True
def read(self, path: str) -> bytes:
if path in self.read_cache:
return self.read_cache[path]
if path in self.read_error_cache:
raise self.read_error_cache[path]
# Need to stat first so that the contents of file are from no
# earlier instant than the mtime reported by self.stat().
self.stat(path)
dirname, basename = os.path.split(path)
dirname = os.path.normpath(dirname)
# Check the fake cache.
if basename == "__init__.py" and dirname in self.fake_package_cache:
data = b""
else:
try:
with open(path, "rb") as f:
data = f.read()
except OSError as err:
self.read_error_cache[path] = err
raise
self.read_cache[path] = data
self.hash_cache[path] = hash_digest(data)
return data
def hash_digest(self, path: str) -> str:
if path not in self.hash_cache:
self.read(path)
return self.hash_cache[path]
def samefile(self, f1: str, f2: str) -> bool:
s1 = self.stat(f1)
s2 = self.stat(f2)
return os.path.samestat(s1, s2)
def copy_os_error(e: OSError) -> OSError:
new = OSError(*e.args)
new.errno = e.errno
new.strerror = e.strerror
new.filename = e.filename
if e.filename2:
new.filename2 = e.filename2
return new
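if __name__ == "__main__":
    # Usage sketch (hypothetical demo, not from the original file): within one
    # transaction repeated reads are served from the cache and the hash is
    # computed from the cached bytes; flush() starts a new transaction.
    fs = FileSystemCache()
    data = fs.read(__file__)
    assert fs.read(__file__) is data  # same object: served from the cache
    print(fs.hash_digest(__file__))  # reuses the bytes read above
    fs.flush()  # later reads hit the file system again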

View File

@@ -0,0 +1,106 @@
"""Watch parts of the file system for changes."""
from __future__ import annotations
from typing import AbstractSet, Iterable, NamedTuple
from mypy.fscache import FileSystemCache
class FileData(NamedTuple):
st_mtime: float
st_size: int
hash: str
class FileSystemWatcher:
"""Watcher for file system changes among specific paths.
All file system access is performed using FileSystemCache. We
detect changed files by stat()ing them all and comparing hashes
of potentially changed files. If a file has both size and mtime
unmodified, the file is assumed to be unchanged.
An important goal of this class is to make it easier to eventually
use file system events to detect file changes.
Note: This class doesn't flush the file system cache. If you don't
manually flush it, changes won't be seen.
"""
# TODO: Watching directories?
# TODO: Handle non-files
def __init__(self, fs: FileSystemCache) -> None:
self.fs = fs
self._paths: set[str] = set()
self._file_data: dict[str, FileData | None] = {}
def dump_file_data(self) -> dict[str, tuple[float, int, str]]:
return {k: v for k, v in self._file_data.items() if v is not None}
def set_file_data(self, path: str, data: FileData) -> None:
self._file_data[path] = data
def add_watched_paths(self, paths: Iterable[str]) -> None:
for path in paths:
if path not in self._paths:
# By storing None this path will get reported as changed by
# find_changed if it exists.
self._file_data[path] = None
self._paths |= set(paths)
def remove_watched_paths(self, paths: Iterable[str]) -> None:
for path in paths:
if path in self._file_data:
del self._file_data[path]
self._paths -= set(paths)
def _update(self, path: str) -> None:
st = self.fs.stat(path)
hash_digest = self.fs.hash_digest(path)
self._file_data[path] = FileData(st.st_mtime, st.st_size, hash_digest)
def _find_changed(self, paths: Iterable[str]) -> AbstractSet[str]:
changed = set()
for path in paths:
old = self._file_data[path]
try:
st = self.fs.stat(path)
except FileNotFoundError:
if old is not None:
# File was deleted.
changed.add(path)
self._file_data[path] = None
else:
if old is None:
# File is new.
changed.add(path)
self._update(path)
# Round mtimes down, to match the mtimes we write to meta files
elif st.st_size != old.st_size or int(st.st_mtime) != int(old.st_mtime):
# Only look for changes if size or mtime has changed as an
# optimization, since calculating hash is expensive.
new_hash = self.fs.hash_digest(path)
self._update(path)
if st.st_size != old.st_size or new_hash != old.hash:
# Changed file.
changed.add(path)
return changed
def find_changed(self) -> AbstractSet[str]:
"""Return paths that have changes since the last call, in the watched set."""
return self._find_changed(self._paths)
def update_changed(self, remove: list[str], update: list[str]) -> AbstractSet[str]:
"""Alternative to find_changed() given explicit changes.
This only calls self.fs.stat() on added or updated files, not
on all files. It believes all other files are unchanged!
Implies add_watched_paths() for add and update, and
remove_watched_paths() for remove.
"""
self.remove_watched_paths(remove)
self.add_watched_paths(update)
return self._find_changed(update)
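if __name__ == "__main__":
    # Usage sketch (hypothetical demo, not from the original file): a newly
    # watched existing file is reported as changed once; with no edits (and a
    # flushed cache) the next poll reports nothing.
    fs = FileSystemCache()
    watcher = FileSystemWatcher(fs)
    watcher.add_watched_paths([__file__])
    print(watcher.find_changed())  # reports this file as new
    fs.flush()  # the watcher never flushes the cache itself
    print(watcher.find_changed())  # set(): unchanged since the last poll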

View File

@@ -0,0 +1,47 @@
from __future__ import annotations
import gc
import time
from typing import Mapping
class GcLogger:
"""Context manager to log GC stats and overall time."""
def __enter__(self) -> GcLogger:
self.gc_start_time: float | None = None
self.gc_time = 0.0
self.gc_calls = 0
self.gc_collected = 0
self.gc_uncollectable = 0
gc.callbacks.append(self.gc_callback)
self.start_time = time.time()
return self
def gc_callback(self, phase: str, info: Mapping[str, int]) -> None:
if phase == "start":
assert self.gc_start_time is None, "Start phase out of sequence"
self.gc_start_time = time.time()
elif phase == "stop":
assert self.gc_start_time is not None, "Stop phase out of sequence"
self.gc_calls += 1
self.gc_time += time.time() - self.gc_start_time
self.gc_start_time = None
self.gc_collected += info["collected"]
self.gc_uncollectable += info["uncollectable"]
else:
assert False, f"Unrecognized gc phase ({phase!r})"
def __exit__(self, *args: object) -> None:
while self.gc_callback in gc.callbacks:
gc.callbacks.remove(self.gc_callback)
def get_stats(self) -> Mapping[str, float]:
end_time = time.time()
result = {}
result["gc_time"] = self.gc_time
result["gc_calls"] = self.gc_calls
result["gc_collected"] = self.gc_collected
result["gc_uncollectable"] = self.gc_uncollectable
result["build_time"] = end_time - self.start_time
return result
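if __name__ == "__main__":
    # Usage sketch (hypothetical demo, not from the original file): wrap some
    # allocation-heavy work and report GC time next to overall wall time.
    with GcLogger() as gc_logger:
        junk = [[object() for _ in range(1000)] for _ in range(100)]
        del junk
        gc.collect()
    print(gc_logger.get_stats())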

View File

@@ -0,0 +1,34 @@
"""Git utilities."""
# Used also from setup.py, so don't pull in anything additional here (like mypy or typing):
from __future__ import annotations
import os
import subprocess
def is_git_repo(dir: str) -> bool:
"""Is the given directory version-controlled with git?"""
return os.path.exists(os.path.join(dir, ".git"))
def have_git() -> bool:
"""Can we run the git executable?"""
try:
subprocess.check_output(["git", "--help"])
return True
except subprocess.CalledProcessError:
return False
except OSError:
return False
def git_revision(dir: str) -> bytes:
"""Get the SHA-1 of the HEAD of a git repository."""
return subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=dir).strip()
def is_dirty(dir: str) -> bool:
"""Check whether a git repository has uncommitted changes."""
output = subprocess.check_output(["git", "status", "-uno", "--porcelain"], cwd=dir)
return output.strip() != b""
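if __name__ == "__main__":
    # Usage sketch (hypothetical demo, not from the original file): report on
    # the current directory if it is a git checkout.
    if have_git() and is_git_repo("."):
        print(git_revision(".").decode())
        print("dirty" if is_dirty(".") else "clean")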

View File

@@ -0,0 +1,112 @@
"""Helpers for manipulations with graphs."""
from __future__ import annotations
from typing import AbstractSet, Iterable, Iterator, TypeVar
T = TypeVar("T")
def strongly_connected_components(
vertices: AbstractSet[T], edges: dict[T, list[T]]
) -> Iterator[set[T]]:
"""Compute Strongly Connected Components of a directed graph.
Args:
vertices: the labels for the vertices
edges: for each vertex, gives the target vertices of its outgoing edges
Returns:
An iterator yielding strongly connected components, each
represented as a set of vertices. Each input vertex will occur
exactly once; vertices not part of a SCC are returned as
singleton sets.
From https://code.activestate.com/recipes/578507/.
"""
identified: set[T] = set()
stack: list[T] = []
index: dict[T, int] = {}
boundaries: list[int] = []
def dfs(v: T) -> Iterator[set[T]]:
index[v] = len(stack)
stack.append(v)
boundaries.append(index[v])
for w in edges[v]:
if w not in index:
yield from dfs(w)
elif w not in identified:
while index[w] < boundaries[-1]:
boundaries.pop()
if boundaries[-1] == index[v]:
boundaries.pop()
scc = set(stack[index[v] :])
del stack[index[v] :]
identified.update(scc)
yield scc
for v in vertices:
if v not in index:
yield from dfs(v)
def prepare_sccs(
sccs: list[set[T]], edges: dict[T, list[T]]
) -> dict[AbstractSet[T], set[AbstractSet[T]]]:
"""Use original edges to organize SCCs in a graph by dependencies between them."""
sccsmap = {v: frozenset(scc) for scc in sccs for v in scc}
data: dict[AbstractSet[T], set[AbstractSet[T]]] = {}
for scc in sccs:
deps: set[AbstractSet[T]] = set()
for v in scc:
deps.update(sccsmap[x] for x in edges[v])
data[frozenset(scc)] = deps
return data
def topsort(data: dict[T, set[T]]) -> Iterable[set[T]]:
"""Topological sort.
Args:
data: A map from vertices to all vertices that it has an edge
connecting it to. NOTE: This data structure
is modified in place -- for normalization purposes,
self-dependencies are removed and entries representing
orphans are added.
Returns:
An iterator yielding sets of vertices that have an equivalent
ordering.
Example:
Suppose the input has the following structure:
{A: {B, C}, B: {D}, C: {D}}
This is normalized to:
{A: {B, C}, B: {D}, C: {D}, D: {}}
The algorithm will yield the following values:
{D}
{B, C}
{A}
From https://code.activestate.com/recipes/577413/.
"""
# TODO: Use a faster algorithm?
for k, v in data.items():
v.discard(k) # Ignore self dependencies.
for item in set.union(*data.values()) - set(data.keys()):
data[item] = set()
while True:
ready = {item for item, dep in data.items() if not dep}
if not ready:
break
yield ready
data = {item: (dep - ready) for item, dep in data.items() if item not in ready}
assert not data, f"A cyclic dependency exists amongst {data!r}"
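if __name__ == "__main__":
    # Usage sketch (hypothetical demo, not from the original file): find the
    # SCCs of a small graph, then order them by inter-SCC dependencies.
    vertices = {"a", "b", "c", "d"}
    edges = {"a": ["b"], "b": ["a"], "c": ["d"], "d": []}
    sccs = list(strongly_connected_components(vertices, edges))
    print(sccs)  # {"a", "b"} forms a cycle; "c" and "d" are singletons
    for ready in topsort(prepare_sccs(sccs, edges)):
        print(ready)  # dependency-free components come out first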

View File

@@ -0,0 +1,121 @@
from __future__ import annotations
from typing import Iterable, Set
import mypy.types as types
from mypy.types import TypeVisitor
from mypy.util import split_module_names
def extract_module_names(type_name: str | None) -> list[str]:
"""Returns the module names of a fully qualified type name."""
if type_name is not None:
# Discard the first one, which is just the qualified name of the type
possible_module_names = split_module_names(type_name)
return possible_module_names[1:]
else:
return []
class TypeIndirectionVisitor(TypeVisitor[Set[str]]):
"""Returns all module references within a particular type."""
def __init__(self) -> None:
self.cache: dict[types.Type, set[str]] = {}
self.seen_aliases: set[types.TypeAliasType] = set()
def find_modules(self, typs: Iterable[types.Type]) -> set[str]:
self.seen_aliases.clear()
return self._visit(typs)
def _visit(self, typ_or_typs: types.Type | Iterable[types.Type]) -> set[str]:
typs = [typ_or_typs] if isinstance(typ_or_typs, types.Type) else typ_or_typs
output: set[str] = set()
for typ in typs:
if isinstance(typ, types.TypeAliasType):
# Avoid infinite recursion for recursive type aliases.
if typ in self.seen_aliases:
continue
self.seen_aliases.add(typ)
if typ in self.cache:
modules = self.cache[typ]
else:
modules = typ.accept(self)
self.cache[typ] = set(modules)
output.update(modules)
return output
def visit_unbound_type(self, t: types.UnboundType) -> set[str]:
return self._visit(t.args)
def visit_any(self, t: types.AnyType) -> set[str]:
return set()
def visit_none_type(self, t: types.NoneType) -> set[str]:
return set()
def visit_uninhabited_type(self, t: types.UninhabitedType) -> set[str]:
return set()
def visit_erased_type(self, t: types.ErasedType) -> set[str]:
return set()
def visit_deleted_type(self, t: types.DeletedType) -> set[str]:
return set()
def visit_type_var(self, t: types.TypeVarType) -> set[str]:
return self._visit(t.values) | self._visit(t.upper_bound) | self._visit(t.default)
def visit_param_spec(self, t: types.ParamSpecType) -> set[str]:
return self._visit(t.upper_bound) | self._visit(t.default)
def visit_type_var_tuple(self, t: types.TypeVarTupleType) -> set[str]:
return self._visit(t.upper_bound) | self._visit(t.default)
def visit_unpack_type(self, t: types.UnpackType) -> set[str]:
return t.type.accept(self)
def visit_parameters(self, t: types.Parameters) -> set[str]:
return self._visit(t.arg_types)
def visit_instance(self, t: types.Instance) -> set[str]:
out = self._visit(t.args)
if t.type:
# Uses of a class depend on everything in the MRO,
# as changes to classes in the MRO can add types to methods,
# change property types, change the MRO itself, etc.
for s in t.type.mro:
out.update(split_module_names(s.module_name))
if t.type.metaclass_type is not None:
out.update(split_module_names(t.type.metaclass_type.type.module_name))
return out
def visit_callable_type(self, t: types.CallableType) -> set[str]:
out = self._visit(t.arg_types) | self._visit(t.ret_type)
if t.definition is not None:
out.update(extract_module_names(t.definition.fullname))
return out
def visit_overloaded(self, t: types.Overloaded) -> set[str]:
return self._visit(t.items) | self._visit(t.fallback)
def visit_tuple_type(self, t: types.TupleType) -> set[str]:
return self._visit(t.items) | self._visit(t.partial_fallback)
def visit_typeddict_type(self, t: types.TypedDictType) -> set[str]:
return self._visit(t.items.values()) | self._visit(t.fallback)
def visit_literal_type(self, t: types.LiteralType) -> set[str]:
return self._visit(t.fallback)
def visit_union_type(self, t: types.UnionType) -> set[str]:
return self._visit(t.items)
def visit_partial_type(self, t: types.PartialType) -> set[str]:
return set()
def visit_type_type(self, t: types.TypeType) -> set[str]:
return self._visit(t.item)
def visit_type_alias_type(self, t: types.TypeAliasType) -> set[str]:
return self._visit(types.get_proper_type(t))
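if __name__ == "__main__":
    # Usage sketch (hypothetical demo, not from the original file): a fully
    # qualified name such as "os.path.sep" may be defined in "os.path" or "os".
    print(extract_module_names("os.path.sep"))  # ['os.path', 'os']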

View File

@@ -0,0 +1,71 @@
"""Utilities for type argument inference."""
from __future__ import annotations
from typing import NamedTuple, Sequence
from mypy.constraints import (
SUBTYPE_OF,
SUPERTYPE_OF,
infer_constraints,
infer_constraints_for_callable,
)
from mypy.nodes import ArgKind
from mypy.solve import solve_constraints
from mypy.types import CallableType, Instance, Type, TypeVarLikeType
class ArgumentInferContext(NamedTuple):
"""Type argument inference context.
We need this because we pass around ``Mapping`` and ``Iterable`` types.
These types are only known by ``TypeChecker`` itself.
It is required for ``*`` and ``**`` argument inference.
https://github.com/python/mypy/issues/11144
"""
mapping_type: Instance
iterable_type: Instance
def infer_function_type_arguments(
callee_type: CallableType,
arg_types: Sequence[Type | None],
arg_kinds: list[ArgKind],
arg_names: Sequence[str | None] | None,
formal_to_actual: list[list[int]],
context: ArgumentInferContext,
strict: bool = True,
allow_polymorphic: bool = False,
) -> tuple[list[Type | None], list[TypeVarLikeType]]:
"""Infer the type arguments of a generic function.
Return an array of lower bound types for the type variables -1 (at
index 0), -2 (at index 1), etc. A lower bound is None if a value
could not be inferred.
Arguments:
callee_type: the target generic function
arg_types: argument types at the call site (each optional; if None,
we are not considering this argument in the current pass)
arg_kinds: nodes.ARG_* values for arg_types
formal_to_actual: mapping from formal to actual variable indices
"""
# Infer constraints.
constraints = infer_constraints_for_callable(
callee_type, arg_types, arg_kinds, arg_names, formal_to_actual, context
)
# Solve constraints.
type_vars = callee_type.variables
return solve_constraints(type_vars, constraints, strict, allow_polymorphic)
def infer_type_arguments(
type_vars: Sequence[TypeVarLikeType], template: Type, actual: Type, is_supertype: bool = False
) -> list[Type | None]:
# Like infer_function_type_arguments, but only match a single type
# against a generic type.
constraints = infer_constraints(template, actual, SUPERTYPE_OF if is_supertype else SUBTYPE_OF)
return solve_constraints(type_vars, constraints)[0]

View File

@@ -0,0 +1,625 @@
from __future__ import annotations
import os
from collections import defaultdict
from functools import cmp_to_key
from typing import Callable
from mypy.build import State
from mypy.find_sources import InvalidSourceList, SourceFinder
from mypy.messages import format_type
from mypy.modulefinder import PYTHON_EXTENSIONS
from mypy.nodes import (
LDEF,
Decorator,
Expression,
FuncBase,
MemberExpr,
MypyFile,
Node,
OverloadedFuncDef,
RefExpr,
SymbolNode,
TypeInfo,
Var,
)
from mypy.server.update import FineGrainedBuildManager
from mypy.traverser import ExtendedTraverserVisitor
from mypy.typeops import tuple_fallback
from mypy.types import (
FunctionLike,
Instance,
LiteralType,
ProperType,
TupleType,
TypedDictType,
TypeVarType,
UnionType,
get_proper_type,
)
from mypy.typevars import fill_typevars_with_any
def node_starts_after(o: Node, line: int, column: int) -> bool:
return o.line > line or o.line == line and o.column > column
def node_ends_before(o: Node, line: int, column: int) -> bool:
# Unfortunately, end positions for some statements are a mess,
# e.g. overloaded functions, so we return False when we don't know.
if o.end_line is not None and o.end_column is not None:
if o.end_line < line or o.end_line == line and o.end_column < column:
return True
return False
def expr_span(expr: Expression) -> str:
"""Format expression span as in mypy error messages."""
return f"{expr.line}:{expr.column + 1}:{expr.end_line}:{expr.end_column}"
def get_instance_fallback(typ: ProperType) -> list[Instance]:
"""Returns the Instance fallback for this type if one exists or None."""
if isinstance(typ, Instance):
return [typ]
elif isinstance(typ, TupleType):
return [tuple_fallback(typ)]
elif isinstance(typ, TypedDictType):
return [typ.fallback]
elif isinstance(typ, FunctionLike):
return [typ.fallback]
elif isinstance(typ, LiteralType):
return [typ.fallback]
elif isinstance(typ, TypeVarType):
if typ.values:
res = []
for t in typ.values:
res.extend(get_instance_fallback(get_proper_type(t)))
return res
return get_instance_fallback(get_proper_type(typ.upper_bound))
elif isinstance(typ, UnionType):
res = []
for t in typ.items:
res.extend(get_instance_fallback(get_proper_type(t)))
return res
return []
def find_node(name: str, info: TypeInfo) -> Var | FuncBase | None:
"""Find the node defining member 'name' in given TypeInfo."""
# TODO: this code shares some logic with checkmember.py
method = info.get_method(name)
if method:
if isinstance(method, Decorator):
return method.var
if method.is_property:
assert isinstance(method, OverloadedFuncDef)
dec = method.items[0]
assert isinstance(dec, Decorator)
return dec.var
return method
else:
# don't have such method, maybe variable?
node = info.get(name)
v = node.node if node else None
if isinstance(v, Var):
return v
return None
def find_module_by_fullname(fullname: str, modules: dict[str, State]) -> State | None:
"""Find module by a node fullname.
This logic mimics the one we use in fixup, so should be good enough.
"""
head = fullname
# Special case: a module symbol is considered to be defined in itself, not in enclosing
# package, since this is what users want when clicking go to definition on a module.
if head in modules:
return modules[head]
while True:
if "." not in head:
return None
head, tail = head.rsplit(".", maxsplit=1)
mod = modules.get(head)
if mod is not None:
return mod
class SearchVisitor(ExtendedTraverserVisitor):
"""Visitor looking for an expression whose span matches given one exactly."""
def __init__(self, line: int, column: int, end_line: int, end_column: int) -> None:
self.line = line
self.column = column
self.end_line = end_line
self.end_column = end_column
self.result: Expression | None = None
def visit(self, o: Node) -> bool:
if node_starts_after(o, self.line, self.column):
return False
if node_ends_before(o, self.end_line, self.end_column):
return False
if (
o.line == self.line
and o.end_line == self.end_line
and o.column == self.column
and o.end_column == self.end_column
):
if isinstance(o, Expression):
self.result = o
return self.result is None
def find_by_location(
tree: MypyFile, line: int, column: int, end_line: int, end_column: int
) -> Expression | None:
"""Find an expression matching given span, or None if not found."""
if end_line < line:
raise ValueError('"end_line" must not be before "line"')
if end_line == line and end_column <= column:
raise ValueError('"end_column" must be after "column"')
visitor = SearchVisitor(line, column, end_line, end_column)
tree.accept(visitor)
return visitor.result
class SearchAllVisitor(ExtendedTraverserVisitor):
"""Visitor looking for all expressions whose spans enclose given position."""
def __init__(self, line: int, column: int) -> None:
self.line = line
self.column = column
self.result: list[Expression] = []
def visit(self, o: Node) -> bool:
if node_starts_after(o, self.line, self.column):
return False
if node_ends_before(o, self.line, self.column):
return False
if isinstance(o, Expression):
self.result.append(o)
return True
def find_all_by_location(tree: MypyFile, line: int, column: int) -> list[Expression]:
"""Find all expressions enclosing given position starting from innermost."""
visitor = SearchAllVisitor(line, column)
tree.accept(visitor)
return list(reversed(visitor.result))
class InspectionEngine:
"""Engine for locating and statically inspecting expressions."""
def __init__(
self,
fg_manager: FineGrainedBuildManager,
*,
verbosity: int = 0,
limit: int = 0,
include_span: bool = False,
include_kind: bool = False,
include_object_attrs: bool = False,
union_attrs: bool = False,
force_reload: bool = False,
) -> None:
self.fg_manager = fg_manager
self.finder = SourceFinder(
self.fg_manager.manager.fscache, self.fg_manager.manager.options
)
self.verbosity = verbosity
self.limit = limit
self.include_span = include_span
self.include_kind = include_kind
self.include_object_attrs = include_object_attrs
self.union_attrs = union_attrs
self.force_reload = force_reload
# Module for which inspection was requested.
self.module: State | None = None
def parse_location(self, location: str) -> tuple[str, list[int]]:
if location.count(":") not in [2, 4]:
raise ValueError("Format should be file:line:column[:end_line:end_column]")
parts = location.split(":")
module, *rest = parts
return module, [int(p) for p in rest]
def reload_module(self, state: State) -> None:
"""Reload given module while temporary exporting types."""
old = self.fg_manager.manager.options.export_types
self.fg_manager.manager.options.export_types = True
try:
self.fg_manager.flush_cache()
assert state.path is not None
self.fg_manager.update([(state.id, state.path)], [])
finally:
self.fg_manager.manager.options.export_types = old
def expr_type(self, expression: Expression) -> tuple[str, bool]:
"""Format type for an expression using current options.
If type is known, second item returned is True. If type is not known, an error
message is returned instead, and second item returned is False.
"""
expr_type = self.fg_manager.manager.all_types.get(expression)
if expr_type is None:
return self.missing_type(expression), False
type_str = format_type(
expr_type, self.fg_manager.manager.options, verbosity=self.verbosity
)
return self.add_prefixes(type_str, expression), True
def object_type(self) -> Instance:
builtins = self.fg_manager.graph["builtins"].tree
assert builtins is not None
object_node = builtins.names["object"].node
assert isinstance(object_node, TypeInfo)
return Instance(object_node, [])
def collect_attrs(self, instances: list[Instance]) -> dict[TypeInfo, list[str]]:
"""Collect attributes from all union/typevar variants."""
def item_attrs(attr_dict: dict[TypeInfo, list[str]]) -> set[str]:
attrs = set()
for base in attr_dict:
attrs |= set(attr_dict[base])
return attrs
def cmp_types(x: TypeInfo, y: TypeInfo) -> int:
if x in y.mro:
return 1
if y in x.mro:
return -1
return 0
# First gather all attributes for every union variant.
assert instances
all_attrs = []
for instance in instances:
attrs = {}
mro = instance.type.mro
if not self.include_object_attrs:
mro = mro[:-1]
for base in mro:
attrs[base] = sorted(base.names)
all_attrs.append(attrs)
# Find attributes valid for all variants in a union or type variable.
intersection = item_attrs(all_attrs[0])
for item in all_attrs[1:]:
intersection &= item_attrs(item)
# Combine attributes from all variants into a single dict while
# also removing invalid attributes (unless using --union-attrs).
combined_attrs = defaultdict(list)
for item in all_attrs:
for base in item:
if base in combined_attrs:
continue
for name in item[base]:
if self.union_attrs or name in intersection:
combined_attrs[base].append(name)
        # Sort bases by MRO; unrelated ones appear in the order they appeared as union variants.
sorted_bases = sorted(combined_attrs.keys(), key=cmp_to_key(cmp_types))
result = {}
for base in sorted_bases:
if not combined_attrs[base]:
                # Skip bases where everything was filtered out.
continue
result[base] = combined_attrs[base]
return result
def _fill_from_dict(
self, attrs_strs: list[str], attrs_dict: dict[TypeInfo, list[str]]
) -> None:
for base in attrs_dict:
cls_name = base.name if self.verbosity < 1 else base.fullname
attrs = [f'"{attr}"' for attr in attrs_dict[base]]
attrs_strs.append(f'"{cls_name}": [{", ".join(attrs)}]')
def expr_attrs(self, expression: Expression) -> tuple[str, bool]:
"""Format attributes that are valid for a given expression.
If expression type is not an Instance, try using fallback. Attributes are
returned as a JSON (ordered by MRO) that maps base class name to list of
attributes. Attributes may appear in multiple bases if overridden (we simply
follow usual mypy logic for creating new Vars etc).
"""
expr_type = self.fg_manager.manager.all_types.get(expression)
if expr_type is None:
return self.missing_type(expression), False
expr_type = get_proper_type(expr_type)
instances = get_instance_fallback(expr_type)
if not instances:
# Everything is an object in Python.
instances = [self.object_type()]
attrs_dict = self.collect_attrs(instances)
# Special case: modules have names apart from those from ModuleType.
if isinstance(expression, RefExpr) and isinstance(expression.node, MypyFile):
node = expression.node
names = sorted(node.names)
if "__builtins__" in names:
                # This is just to make tests stable. No one will really need this name.
names.remove("__builtins__")
mod_dict = {f'"<{node.fullname}>"': [f'"{name}"' for name in names]}
else:
mod_dict = {}
# Special case: for class callables, prepend with the class attributes.
# TODO: also handle cases when such callable appears in a union.
if isinstance(expr_type, FunctionLike) and expr_type.is_type_obj():
template = fill_typevars_with_any(expr_type.type_object())
class_dict = self.collect_attrs(get_instance_fallback(template))
else:
class_dict = {}
        # We don't use JSON dump, to be sure key order is always preserved.
base_attrs = []
if mod_dict:
for mod in mod_dict:
base_attrs.append(f'{mod}: [{", ".join(mod_dict[mod])}]')
self._fill_from_dict(base_attrs, class_dict)
self._fill_from_dict(base_attrs, attrs_dict)
return self.add_prefixes(f'{{{", ".join(base_attrs)}}}', expression), True
def format_node(self, module: State, node: FuncBase | SymbolNode) -> str:
return f"{module.path}:{node.line}:{node.column + 1}:{node.name}"
def collect_nodes(self, expression: RefExpr) -> list[FuncBase | SymbolNode]:
"""Collect nodes that can be referred to by an expression.
        Note: it can be more than one, for example in the case of a union attribute.
"""
node: FuncBase | SymbolNode | None = expression.node
nodes: list[FuncBase | SymbolNode]
if node is None:
# Tricky case: instance attribute
if isinstance(expression, MemberExpr) and expression.kind is None:
base_type = self.fg_manager.manager.all_types.get(expression.expr)
if base_type is None:
return []
# Now we use the base type to figure out where the attribute is defined.
base_type = get_proper_type(base_type)
instances = get_instance_fallback(base_type)
nodes = []
for instance in instances:
node = find_node(expression.name, instance.type)
if node:
nodes.append(node)
if not nodes:
# Try checking class namespace if attribute is on a class object.
if isinstance(base_type, FunctionLike) and base_type.is_type_obj():
instances = get_instance_fallback(
fill_typevars_with_any(base_type.type_object())
)
for instance in instances:
node = find_node(expression.name, instance.type)
if node:
nodes.append(node)
else:
# Still no luck, give up.
return []
else:
return []
else:
# Easy case: a module-level definition
nodes = [node]
return nodes
def modules_for_nodes(
self, nodes: list[FuncBase | SymbolNode], expression: RefExpr
) -> tuple[dict[FuncBase | SymbolNode, State], bool]:
"""Gather modules where given nodes where defined.
Also check if they need to be refreshed (cached nodes may have
lines/columns missing).
"""
modules = {}
reload_needed = False
for node in nodes:
module = find_module_by_fullname(node.fullname, self.fg_manager.graph)
if not module:
if expression.kind == LDEF and self.module:
module = self.module
else:
continue
modules[node] = module
if not module.tree or module.tree.is_cache_skeleton or self.force_reload:
reload_needed |= not module.tree or module.tree.is_cache_skeleton
self.reload_module(module)
return modules, reload_needed
def expression_def(self, expression: Expression) -> tuple[str, bool]:
"""Find and format definition location for an expression.
If it is not a RefExpr, it is effectively skipped by returning an
empty result.
"""
if not isinstance(expression, RefExpr):
# If there are no suitable matches at all, we return error later.
return "", True
nodes = self.collect_nodes(expression)
if not nodes:
return self.missing_node(expression), False
modules, reload_needed = self.modules_for_nodes(nodes, expression)
if reload_needed:
# TODO: line/column are not stored in cache for vast majority of symbol nodes.
            # Adding them will make things faster, but will have visible memory impact.
nodes = self.collect_nodes(expression)
modules, reload_needed = self.modules_for_nodes(nodes, expression)
assert not reload_needed
result = []
for node in modules:
result.append(self.format_node(modules[node], node))
if not result:
return self.missing_node(expression), False
return self.add_prefixes(", ".join(result), expression), True
def missing_type(self, expression: Expression) -> str:
alt_suggestion = ""
if not self.force_reload:
alt_suggestion = " or try --force-reload"
return (
f'No known type available for "{type(expression).__name__}"'
f" (maybe unreachable{alt_suggestion})"
)
def missing_node(self, expression: Expression) -> str:
return (
f'Cannot find definition for "{type(expression).__name__}"'
f" at {expr_span(expression)}"
)
def add_prefixes(self, result: str, expression: Expression) -> str:
prefixes = []
if self.include_kind:
prefixes.append(f"{type(expression).__name__}")
if self.include_span:
prefixes.append(expr_span(expression))
if prefixes:
prefix = ":".join(prefixes) + " -> "
else:
prefix = ""
return prefix + result
def run_inspection_by_exact_location(
self,
tree: MypyFile,
line: int,
column: int,
end_line: int,
end_column: int,
method: Callable[[Expression], tuple[str, bool]],
) -> dict[str, object]:
"""Get type of an expression matching a span.
Type or error is returned as a standard daemon response dict.
"""
try:
expression = find_by_location(tree, line, column - 1, end_line, end_column)
except ValueError as err:
return {"error": str(err)}
if expression is None:
span = f"{line}:{column}:{end_line}:{end_column}"
return {"out": f"Can't find expression at span {span}", "err": "", "status": 1}
inspection_str, success = method(expression)
return {"out": inspection_str, "err": "", "status": 0 if success else 1}
def run_inspection_by_position(
self,
tree: MypyFile,
line: int,
column: int,
method: Callable[[Expression], tuple[str, bool]],
) -> dict[str, object]:
"""Get types of all expressions enclosing a position.
Types and/or errors are returned as a standard daemon response dict.
"""
expressions = find_all_by_location(tree, line, column - 1)
if not expressions:
position = f"{line}:{column}"
return {
"out": f"Can't find any expressions at position {position}",
"err": "",
"status": 1,
}
inspection_strs = []
status = 0
for expression in expressions:
inspection_str, success = method(expression)
if not success:
status = 1
if inspection_str:
inspection_strs.append(inspection_str)
if self.limit:
inspection_strs = inspection_strs[: self.limit]
return {"out": "\n".join(inspection_strs), "err": "", "status": status}
def find_module(self, file: str) -> tuple[State | None, dict[str, object]]:
"""Find module by path, or return a suitable error message.
Note we don't use exceptions to simplify handling 1 vs 2 statuses.
"""
if not any(file.endswith(ext) for ext in PYTHON_EXTENSIONS):
return None, {"error": "Source file is not a Python file"}
try:
module, _ = self.finder.crawl_up(os.path.normpath(file))
except InvalidSourceList:
return None, {"error": "Invalid source file name: " + file}
state = self.fg_manager.graph.get(module)
self.module = state
return (
state,
{"out": f"Unknown module: {module}", "err": "", "status": 1} if state is None else {},
)
def run_inspection(
self, location: str, method: Callable[[Expression], tuple[str, bool]]
) -> dict[str, object]:
"""Top-level logic to inspect expression(s) at a location.
This can be re-used by various simple inspections.
"""
try:
file, pos = self.parse_location(location)
except ValueError as err:
return {"error": str(err)}
state, err_dict = self.find_module(file)
if state is None:
assert err_dict
return err_dict
# Force reloading to load from cache, account for any edits, etc.
if not state.tree or state.tree.is_cache_skeleton or self.force_reload:
self.reload_module(state)
assert state.tree is not None
if len(pos) == 4:
# Full span, return an exact match only.
line, column, end_line, end_column = pos
return self.run_inspection_by_exact_location(
state.tree, line, column, end_line, end_column, method
)
assert len(pos) == 2
# Inexact location, return all expressions.
line, column = pos
return self.run_inspection_by_position(state.tree, line, column, method)
def get_type(self, location: str) -> dict[str, object]:
"""Get types of expression(s) at a location."""
return self.run_inspection(location, self.expr_type)
def get_attrs(self, location: str) -> dict[str, object]:
"""Get attributes of expression(s) at a location."""
return self.run_inspection(location, self.expr_attrs)
def get_definition(self, location: str) -> dict[str, object]:
"""Get symbol definitions of expression(s) at a location."""
result = self.run_inspection(location, self.expression_def)
if "out" in result and not result["out"]:
# None of the expressions found turns out to be a RefExpr.
_, location = location.split(":", maxsplit=1)
result["out"] = f"No name or member expressions at {location}"
result["status"] = 1
return result
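if __name__ == "__main__":
    # Usage sketch (hypothetical demo, not from the original file): spans are
    # rendered with 1-based columns, matching mypy error messages.
    from mypy.nodes import NameExpr

    e = NameExpr("x")
    e.set_line(3, 0, 3, 1)  # line, column (0-based), end_line, end_column
    print(expr_span(e))  # 3:1:3:1
    print(node_starts_after(e, 2, 99))  # True: it starts on a later line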

View File

@@ -0,0 +1,268 @@
"""Cross platform abstractions for inter-process communication
On Unix, this uses AF_UNIX sockets.
On Windows, this uses NamedPipes.
"""
from __future__ import annotations
import base64
import os
import shutil
import sys
import tempfile
from types import TracebackType
from typing import Callable, Final
if sys.platform == "win32":
# This may be private, but it is needed for IPC on Windows, and is basically stable
import ctypes
import _winapi
_IPCHandle = int
kernel32 = ctypes.windll.kernel32
DisconnectNamedPipe: Callable[[_IPCHandle], int] = kernel32.DisconnectNamedPipe
FlushFileBuffers: Callable[[_IPCHandle], int] = kernel32.FlushFileBuffers
else:
import socket
_IPCHandle = socket.socket
class IPCException(Exception):
"""Exception for IPC issues."""
class IPCBase:
"""Base class for communication between the dmypy client and server.
This contains logic shared between the client and server, such as reading
and writing.
"""
connection: _IPCHandle
def __init__(self, name: str, timeout: float | None) -> None:
self.name = name
self.timeout = timeout
def read(self, size: int = 100000) -> bytes:
"""Read bytes from an IPC connection until its empty."""
bdata = bytearray()
if sys.platform == "win32":
while True:
ov, err = _winapi.ReadFile(self.connection, size, overlapped=True)
try:
if err == _winapi.ERROR_IO_PENDING:
timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
res = _winapi.WaitForSingleObject(ov.event, timeout)
if res != _winapi.WAIT_OBJECT_0:
raise IPCException(f"Bad result from I/O wait: {res}")
except BaseException:
ov.cancel()
raise
_, err = ov.GetOverlappedResult(True)
more = ov.getbuffer()
if more:
bdata.extend(more)
if err == 0:
# we are done!
break
elif err == _winapi.ERROR_MORE_DATA:
# read again
continue
elif err == _winapi.ERROR_OPERATION_ABORTED:
raise IPCException("ReadFile operation aborted.")
else:
while True:
more = self.connection.recv(size)
if not more:
break
bdata.extend(more)
return bytes(bdata)
def write(self, data: bytes) -> None:
"""Write bytes to an IPC connection."""
if sys.platform == "win32":
try:
ov, err = _winapi.WriteFile(self.connection, data, overlapped=True)
try:
if err == _winapi.ERROR_IO_PENDING:
timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
res = _winapi.WaitForSingleObject(ov.event, timeout)
if res != _winapi.WAIT_OBJECT_0:
raise IPCException(f"Bad result from I/O wait: {res}")
elif err != 0:
raise IPCException(f"Failed writing to pipe with error: {err}")
except BaseException:
ov.cancel()
raise
bytes_written, err = ov.GetOverlappedResult(True)
assert err == 0, err
assert bytes_written == len(data)
except OSError as e:
raise IPCException(f"Failed to write with error: {e.winerror}") from e
else:
self.connection.sendall(data)
self.connection.shutdown(socket.SHUT_WR)
def close(self) -> None:
if sys.platform == "win32":
if self.connection != _winapi.NULL:
_winapi.CloseHandle(self.connection)
else:
self.connection.close()
class IPCClient(IPCBase):
"""The client side of an IPC connection."""
def __init__(self, name: str, timeout: float | None) -> None:
super().__init__(name, timeout)
if sys.platform == "win32":
timeout = int(self.timeout * 1000) if self.timeout else _winapi.NMPWAIT_WAIT_FOREVER
try:
_winapi.WaitNamedPipe(self.name, timeout)
except FileNotFoundError as e:
raise IPCException(f"The NamedPipe at {self.name} was not found.") from e
except OSError as e:
if e.winerror == _winapi.ERROR_SEM_TIMEOUT:
raise IPCException("Timed out waiting for connection.") from e
else:
raise
try:
self.connection = _winapi.CreateFile(
self.name,
_winapi.GENERIC_READ | _winapi.GENERIC_WRITE,
0,
_winapi.NULL,
_winapi.OPEN_EXISTING,
_winapi.FILE_FLAG_OVERLAPPED,
_winapi.NULL,
)
except OSError as e:
if e.winerror == _winapi.ERROR_PIPE_BUSY:
raise IPCException("The connection is busy.") from e
else:
raise
_winapi.SetNamedPipeHandleState(
self.connection, _winapi.PIPE_READMODE_MESSAGE, None, None
)
else:
self.connection = socket.socket(socket.AF_UNIX)
self.connection.settimeout(timeout)
self.connection.connect(name)
def __enter__(self) -> IPCClient:
return self
def __exit__(
self,
exc_ty: type[BaseException] | None = None,
exc_val: BaseException | None = None,
exc_tb: TracebackType | None = None,
) -> None:
self.close()
class IPCServer(IPCBase):
BUFFER_SIZE: Final = 2**16
def __init__(self, name: str, timeout: float | None = None) -> None:
if sys.platform == "win32":
name = r"\\.\pipe\{}-{}.pipe".format(
name, base64.urlsafe_b64encode(os.urandom(6)).decode()
)
else:
name = f"{name}.sock"
super().__init__(name, timeout)
if sys.platform == "win32":
self.connection = _winapi.CreateNamedPipe(
self.name,
_winapi.PIPE_ACCESS_DUPLEX
| _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE
| _winapi.FILE_FLAG_OVERLAPPED,
_winapi.PIPE_READMODE_MESSAGE
| _winapi.PIPE_TYPE_MESSAGE
| _winapi.PIPE_WAIT
| 0x8, # PIPE_REJECT_REMOTE_CLIENTS
1, # one instance
self.BUFFER_SIZE,
self.BUFFER_SIZE,
_winapi.NMPWAIT_WAIT_FOREVER,
0, # Use default security descriptor
)
if self.connection == -1: # INVALID_HANDLE_VALUE
err = _winapi.GetLastError()
raise IPCException(f"Invalid handle to pipe: {err}")
else:
self.sock_directory = tempfile.mkdtemp()
sockfile = os.path.join(self.sock_directory, self.name)
self.sock = socket.socket(socket.AF_UNIX)
self.sock.bind(sockfile)
self.sock.listen(1)
if timeout is not None:
self.sock.settimeout(timeout)
def __enter__(self) -> IPCServer:
if sys.platform == "win32":
# NOTE: It is theoretically possible that this will hang forever if the
# client never connects, though this can be "solved" by killing the server
try:
ov = _winapi.ConnectNamedPipe(self.connection, overlapped=True)
except OSError as e:
# Don't raise if the client already exists, or the client already connected
if e.winerror not in (_winapi.ERROR_PIPE_CONNECTED, _winapi.ERROR_NO_DATA):
raise
else:
try:
timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
res = _winapi.WaitForSingleObject(ov.event, timeout)
assert res == _winapi.WAIT_OBJECT_0
except BaseException:
ov.cancel()
_winapi.CloseHandle(self.connection)
raise
_, err = ov.GetOverlappedResult(True)
assert err == 0
else:
try:
self.connection, _ = self.sock.accept()
except socket.timeout as e:
raise IPCException("The socket timed out") from e
return self
def __exit__(
self,
exc_ty: type[BaseException] | None = None,
exc_val: BaseException | None = None,
exc_tb: TracebackType | None = None,
) -> None:
if sys.platform == "win32":
try:
# Wait for the client to finish reading the last write before disconnecting
if not FlushFileBuffers(self.connection):
raise IPCException(
"Failed to flush NamedPipe buffer, maybe the client hung up?"
)
finally:
DisconnectNamedPipe(self.connection)
else:
self.close()
def cleanup(self) -> None:
if sys.platform == "win32":
self.close()
else:
shutil.rmtree(self.sock_directory)
@property
def connection_name(self) -> str:
if sys.platform == "win32":
return self.name
else:
name = self.sock.getsockname()
assert isinstance(name, str)
return name
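if __name__ == "__main__":
    # Usage sketch (hypothetical demo, not from the original file): serve one
    # request from a background thread. read() returns once the peer shuts
    # down its write end, so a simple request/response round trip works.
    import threading

    def serve(server: IPCServer) -> None:
        with server:
            print(b"server got " + server.read())
            server.write(b"ack")
        server.cleanup()

    srv = IPCServer("dmypy-demo", timeout=5)
    t = threading.Thread(target=serve, args=(srv,))
    t.start()
    with IPCClient(srv.connection_name, timeout=5) as client:
        client.write(b"ping")
        print(b"client got " + client.read())
    t.join()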

View File

@@ -0,0 +1,688 @@
"""Calculation of the least upper bound types (joins)."""
from __future__ import annotations
from typing import overload
import mypy.typeops
from mypy.maptype import map_instance_to_supertype
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT
from mypy.state import state
from mypy.subtypes import (
SubtypeContext,
find_member,
is_equivalent,
is_proper_subtype,
is_protocol_implementation,
is_subtype,
)
from mypy.types import (
AnyType,
CallableType,
DeletedType,
ErasedType,
FunctionLike,
Instance,
LiteralType,
NoneType,
Overloaded,
Parameters,
ParamSpecType,
PartialType,
ProperType,
TupleType,
Type,
TypeAliasType,
TypedDictType,
TypeOfAny,
TypeType,
TypeVarTupleType,
TypeVarType,
TypeVisitor,
UnboundType,
UninhabitedType,
UnionType,
UnpackType,
get_proper_type,
get_proper_types,
)
class InstanceJoiner:
def __init__(self) -> None:
self.seen_instances: list[tuple[Instance, Instance]] = []
def join_instances(self, t: Instance, s: Instance) -> ProperType:
if (t, s) in self.seen_instances or (s, t) in self.seen_instances:
return object_from_instance(t)
self.seen_instances.append((t, s))
# Calculate the join of two instance types
if t.type == s.type:
# Simplest case: join two types with the same base type (but
# potentially different arguments).
# Combine type arguments.
args: list[Type] = []
# N.B: We use zip instead of indexing because the lengths might have
# mismatches during daemon reprocessing.
for ta, sa, type_var in zip(t.args, s.args, t.type.defn.type_vars):
ta_proper = get_proper_type(ta)
sa_proper = get_proper_type(sa)
new_type: Type | None = None
if isinstance(ta_proper, AnyType):
new_type = AnyType(TypeOfAny.from_another_any, ta_proper)
elif isinstance(sa_proper, AnyType):
new_type = AnyType(TypeOfAny.from_another_any, sa_proper)
elif isinstance(type_var, TypeVarType):
if type_var.variance == COVARIANT:
new_type = join_types(ta, sa, self)
if len(type_var.values) != 0 and new_type not in type_var.values:
self.seen_instances.pop()
return object_from_instance(t)
if not is_subtype(new_type, type_var.upper_bound):
self.seen_instances.pop()
return object_from_instance(t)
# TODO: contravariant case should use meet but pass seen instances as
# an argument to keep track of recursive checks.
elif type_var.variance in (INVARIANT, CONTRAVARIANT):
if not is_equivalent(ta, sa):
self.seen_instances.pop()
return object_from_instance(t)
# If the types are different but equivalent, then an Any is involved
# so using a join in the contravariant case is also OK.
new_type = join_types(ta, sa, self)
else:
# ParamSpec type variables behave the same, independent of variance
if not is_equivalent(ta, sa):
return get_proper_type(type_var.upper_bound)
new_type = join_types(ta, sa, self)
assert new_type is not None
args.append(new_type)
result: ProperType = Instance(t.type, args)
elif t.type.bases and is_proper_subtype(
t, s, subtype_context=SubtypeContext(ignore_type_params=True)
):
result = self.join_instances_via_supertype(t, s)
else:
# Now t is not a subtype of s, and t != s. Now s could be a subtype
# of t; alternatively, we need to find a common supertype. This works
            # in both cases.
result = self.join_instances_via_supertype(s, t)
self.seen_instances.pop()
return result
def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType:
# Give preference to joins via duck typing relationship, so that
# join(int, float) == float, for example.
for p in t.type._promote:
if is_subtype(p, s):
return join_types(p, s, self)
for p in s.type._promote:
if is_subtype(p, t):
return join_types(t, p, self)
# Compute the "best" supertype of t when joined with s.
# The definition of "best" may evolve; for now it is the one with
# the longest MRO. Ties are broken by using the earlier base.
best: ProperType | None = None
for base in t.type.bases:
mapped = map_instance_to_supertype(t, base.type)
res = self.join_instances(mapped, s)
if best is None or is_better(res, best):
best = res
assert best is not None
for promote in t.type._promote:
if isinstance(promote, Instance):
res = self.join_instances(promote, s)
if is_better(res, best):
best = res
return best
def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType:
"""Return a simple least upper bound given the declared type.
This function should be only used by binder, and should not recurse.
For all other uses, use `join_types()`.
"""
declaration = get_proper_type(declaration)
s = get_proper_type(s)
t = get_proper_type(t)
if (s.can_be_true, s.can_be_false) != (t.can_be_true, t.can_be_false):
# if types are restricted in different ways, use the more general versions
s = mypy.typeops.true_or_false(s)
t = mypy.typeops.true_or_false(t)
if isinstance(s, AnyType):
return s
if isinstance(s, ErasedType):
return t
if is_proper_subtype(s, t, ignore_promotions=True):
return t
if is_proper_subtype(t, s, ignore_promotions=True):
return s
if isinstance(declaration, UnionType):
return mypy.typeops.make_simplified_union([s, t])
if isinstance(s, NoneType) and not isinstance(t, NoneType):
s, t = t, s
if isinstance(s, UninhabitedType) and not isinstance(t, UninhabitedType):
s, t = t, s
# Meets/joins require callable type normalization.
s, t = normalize_callables(s, t)
if isinstance(s, UnionType) and not isinstance(t, UnionType):
s, t = t, s
value = t.accept(TypeJoinVisitor(s))
if declaration is None or is_subtype(value, declaration):
return value
return declaration
def trivial_join(s: Type, t: Type) -> Type:
"""Return one of types (expanded) if it is a supertype of other, otherwise top type."""
if is_subtype(s, t):
return t
elif is_subtype(t, s):
return s
else:
return object_or_any_from_type(get_proper_type(t))
@overload
def join_types(
s: ProperType, t: ProperType, instance_joiner: InstanceJoiner | None = None
) -> ProperType:
...
@overload
def join_types(s: Type, t: Type, instance_joiner: InstanceJoiner | None = None) -> Type:
...
def join_types(s: Type, t: Type, instance_joiner: InstanceJoiner | None = None) -> Type:
"""Return the least upper bound of s and t.
For example, the join of 'int' and 'object' is 'object'.
"""
if mypy.typeops.is_recursive_pair(s, t):
# This case can trigger an infinite recursion, general support for this will be
# tricky so we use a trivial join (like for protocols).
return trivial_join(s, t)
s = get_proper_type(s)
t = get_proper_type(t)
if (s.can_be_true, s.can_be_false) != (t.can_be_true, t.can_be_false):
# if types are restricted in different ways, use the more general versions
s = mypy.typeops.true_or_false(s)
t = mypy.typeops.true_or_false(t)
if isinstance(s, UnionType) and not isinstance(t, UnionType):
s, t = t, s
if isinstance(s, AnyType):
return s
if isinstance(s, ErasedType):
return t
if isinstance(s, NoneType) and not isinstance(t, NoneType):
s, t = t, s
if isinstance(s, UninhabitedType) and not isinstance(t, UninhabitedType):
s, t = t, s
# Meets/joins require callable type normalization.
s, t = normalize_callables(s, t)
# Use a visitor to handle non-trivial cases.
return t.accept(TypeJoinVisitor(s, instance_joiner))
class TypeJoinVisitor(TypeVisitor[ProperType]):
"""Implementation of the least upper bound algorithm.
Attributes:
s: The other (left) type operand.
"""
def __init__(self, s: ProperType, instance_joiner: InstanceJoiner | None = None) -> None:
self.s = s
self.instance_joiner = instance_joiner
def visit_unbound_type(self, t: UnboundType) -> ProperType:
return AnyType(TypeOfAny.special_form)
def visit_union_type(self, t: UnionType) -> ProperType:
if is_proper_subtype(self.s, t):
return t
else:
return mypy.typeops.make_simplified_union([self.s, t])
def visit_any(self, t: AnyType) -> ProperType:
return t
def visit_none_type(self, t: NoneType) -> ProperType:
if state.strict_optional:
if isinstance(self.s, (NoneType, UninhabitedType)):
return t
elif isinstance(self.s, UnboundType):
return AnyType(TypeOfAny.special_form)
else:
return mypy.typeops.make_simplified_union([self.s, t])
else:
return self.s
def visit_uninhabited_type(self, t: UninhabitedType) -> ProperType:
return self.s
def visit_deleted_type(self, t: DeletedType) -> ProperType:
return self.s
def visit_erased_type(self, t: ErasedType) -> ProperType:
return self.s
def visit_type_var(self, t: TypeVarType) -> ProperType:
if isinstance(self.s, TypeVarType) and self.s.id == t.id:
return self.s
else:
return self.default(self.s)
def visit_param_spec(self, t: ParamSpecType) -> ProperType:
if self.s == t:
return t
return self.default(self.s)
def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
if self.s == t:
return t
return self.default(self.s)
def visit_unpack_type(self, t: UnpackType) -> UnpackType:
raise NotImplementedError
def visit_parameters(self, t: Parameters) -> ProperType:
if isinstance(self.s, Parameters):
if len(t.arg_types) != len(self.s.arg_types):
return self.default(self.s)
return t.copy_modified(
# Note that since during constraint inference we already treat whole ParamSpec as
# contravariant, we should join individual items, not meet them like for Callables
arg_types=[join_types(s_a, t_a) for s_a, t_a in zip(self.s.arg_types, t.arg_types)]
)
else:
return self.default(self.s)
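    # Illustration (hypothetical Parameters values): equal-length parameter
    # lists are joined pointwise, e.g. Parameters([int]) joined with
    # Parameters([str]) yields Parameters([object]); unequal lengths fall
    # back to self.default(), which produces Any here.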
def visit_instance(self, t: Instance) -> ProperType:
if isinstance(self.s, Instance):
if self.instance_joiner is None:
self.instance_joiner = InstanceJoiner()
nominal = self.instance_joiner.join_instances(t, self.s)
structural: Instance | None = None
if t.type.is_protocol and is_protocol_implementation(self.s, t):
structural = t
elif self.s.type.is_protocol and is_protocol_implementation(t, self.s):
structural = self.s
# Structural join is preferred in the case where we have found both
# structural and nominal and they have same MRO length (see two comments
# in join_instances_via_supertype). Otherwise, just return the nominal join.
if not structural or is_better(nominal, structural):
return nominal
return structural
elif isinstance(self.s, FunctionLike):
if t.type.is_protocol:
call = unpack_callback_protocol(t)
if call:
return join_types(call, self.s)
return join_types(t, self.s.fallback)
elif isinstance(self.s, TypeType):
return join_types(t, self.s)
elif isinstance(self.s, TypedDictType):
return join_types(t, self.s)
elif isinstance(self.s, TupleType):
return join_types(t, self.s)
elif isinstance(self.s, LiteralType):
return join_types(t, self.s)
else:
return self.default(self.s)
def visit_callable_type(self, t: CallableType) -> ProperType:
if isinstance(self.s, CallableType) and is_similar_callables(t, self.s):
if is_equivalent(t, self.s):
return combine_similar_callables(t, self.s)
result = join_similar_callables(t, self.s)
# We set the from_type_type flag to suppress error when a collection of
# concrete class objects gets inferred as their common abstract superclass.
if not (
(t.is_type_obj() and t.type_object().is_abstract)
or (self.s.is_type_obj() and self.s.type_object().is_abstract)
):
result.from_type_type = True
if any(
isinstance(tp, (NoneType, UninhabitedType))
for tp in get_proper_types(result.arg_types)
):
# We don't want to return unusable Callable, attempt fallback instead.
return join_types(t.fallback, self.s)
return result
elif isinstance(self.s, Overloaded):
            # Switch the order of arguments so that we'll get to visit_overloaded.
return join_types(t, self.s)
elif isinstance(self.s, Instance) and self.s.type.is_protocol:
call = unpack_callback_protocol(self.s)
if call:
return join_types(t, call)
return join_types(t.fallback, self.s)
def visit_overloaded(self, t: Overloaded) -> ProperType:
# This is more complex than most other cases. Here are some
# examples that illustrate how this works.
#
# First let's define a concise notation:
# - Cn are callable types (for n in 1, 2, ...)
# - Ov(C1, C2, ...) is an overloaded type with items C1, C2, ...
# - Callable[[T, ...], S] is written as [T, ...] -> S.
#
# We want some basic properties to hold (assume Cn are all
# unrelated via Any-similarity):
#
# join(Ov(C1, C2), C1) == C1
# join(Ov(C1, C2), Ov(C1, C2)) == Ov(C1, C2)
# join(Ov(C1, C2), Ov(C1, C3)) == C1
        #     join(Ov(C1, C2), C3) == join of fallback types
#
# The presence of Any types makes things more interesting. The join is the
# most general type we can get with respect to Any:
#
        #     join(Ov([int] -> int, [str] -> str), [Any] -> str) == [Any] -> str
#
# We could use a simplification step that removes redundancies, but that's not
# implemented right now. Consider this example, where we get a redundancy:
#
# join(Ov([int, Any] -> Any, [str, Any] -> Any), [Any, int] -> Any) ==
# Ov([Any, int] -> Any, [Any, int] -> Any)
#
# TODO: Consider more cases of callable subtyping.
result: list[CallableType] = []
s = self.s
if isinstance(s, FunctionLike):
# The interesting case where both types are function types.
for t_item in t.items:
for s_item in s.items:
if is_similar_callables(t_item, s_item):
if is_equivalent(t_item, s_item):
result.append(combine_similar_callables(t_item, s_item))
elif is_subtype(t_item, s_item):
result.append(s_item)
if result:
# TODO: Simplify redundancies from the result.
if len(result) == 1:
return result[0]
else:
return Overloaded(result)
return join_types(t.fallback, s.fallback)
elif isinstance(s, Instance) and s.type.is_protocol:
call = unpack_callback_protocol(s)
if call:
return join_types(t, call)
return join_types(t.fallback, s)
def visit_tuple_type(self, t: TupleType) -> ProperType:
# When given two fixed-length tuples:
# * If they have the same length, join their subtypes item-wise:
# Tuple[int, bool] + Tuple[bool, bool] becomes Tuple[int, bool]
# * If lengths do not match, return a variadic tuple:
# Tuple[bool, int] + Tuple[bool] becomes Tuple[int, ...]
#
# Otherwise, `t` is a fixed-length tuple but `self.s` is NOT:
# * Joining with a variadic tuple returns variadic tuple:
# Tuple[int, bool] + Tuple[bool, ...] becomes Tuple[int, ...]
# * Joining with any Sequence also returns a Sequence:
# Tuple[int, bool] + List[bool] becomes Sequence[int]
        if isinstance(self.s, TupleType) and self.s.length() == t.length():
            if self.instance_joiner is None:
                self.instance_joiner = InstanceJoiner()
            fallback = self.instance_joiner.join_instances(
                mypy.typeops.tuple_fallback(self.s), mypy.typeops.tuple_fallback(t)
            )
            assert isinstance(fallback, Instance)
            # The guard above ensures equal lengths, so join the items pairwise.
            items: list[Type] = []
            for i in range(t.length()):
                items.append(join_types(t.items[i], self.s.items[i]))
            return TupleType(items, fallback)
        else:
            return join_types(self.s, mypy.typeops.tuple_fallback(t))
def visit_typeddict_type(self, t: TypedDictType) -> ProperType:
if isinstance(self.s, TypedDictType):
items = {
item_name: s_item_type
for (item_name, s_item_type, t_item_type) in self.s.zip(t)
if (
is_equivalent(s_item_type, t_item_type)
and (item_name in t.required_keys) == (item_name in self.s.required_keys)
)
}
fallback = self.s.create_anonymous_fallback()
# We need to filter by items.keys() since some required keys present in both t and
# self.s might be missing from the join if the types are incompatible.
required_keys = set(items.keys()) & t.required_keys & self.s.required_keys
return TypedDictType(items, required_keys, fallback)
elif isinstance(self.s, Instance):
return join_types(self.s, t.fallback)
else:
return self.default(self.s)
def visit_literal_type(self, t: LiteralType) -> ProperType:
if isinstance(self.s, LiteralType):
if t == self.s:
return t
if self.s.fallback.type.is_enum and t.fallback.type.is_enum:
return mypy.typeops.make_simplified_union([self.s, t])
return join_types(self.s.fallback, t.fallback)
else:
return join_types(self.s, t.fallback)
def visit_partial_type(self, t: PartialType) -> ProperType:
# We only have partial information so we can't decide the join result. We should
# never get here.
assert False, "Internal error"
def visit_type_type(self, t: TypeType) -> ProperType:
if isinstance(self.s, TypeType):
return TypeType.make_normalized(join_types(t.item, self.s.item), line=t.line)
elif isinstance(self.s, Instance) and self.s.type.fullname == "builtins.type":
return self.s
else:
return self.default(self.s)
def visit_type_alias_type(self, t: TypeAliasType) -> ProperType:
assert False, f"This should be never called, got {t}"
def default(self, typ: Type) -> ProperType:
typ = get_proper_type(typ)
if isinstance(typ, Instance):
return object_from_instance(typ)
elif isinstance(typ, UnboundType):
return AnyType(TypeOfAny.special_form)
elif isinstance(typ, TupleType):
return self.default(mypy.typeops.tuple_fallback(typ))
elif isinstance(typ, TypedDictType):
return self.default(typ.fallback)
elif isinstance(typ, FunctionLike):
return self.default(typ.fallback)
elif isinstance(typ, TypeVarType):
return self.default(typ.upper_bound)
elif isinstance(typ, ParamSpecType):
return self.default(typ.upper_bound)
else:
return AnyType(TypeOfAny.special_form)
def is_better(t: Type, s: Type) -> bool:
# Given two possible results from join_instances_via_supertype(),
# indicate whether t is the better one.
t = get_proper_type(t)
s = get_proper_type(s)
if isinstance(t, Instance):
if not isinstance(s, Instance):
return True
# Use len(mro) as a proxy for the better choice.
if len(t.type.mro) > len(s.type.mro):
return True
return False
def normalize_callables(s: ProperType, t: ProperType) -> tuple[ProperType, ProperType]:
if isinstance(s, (CallableType, Overloaded)):
s = s.with_unpacked_kwargs()
if isinstance(t, (CallableType, Overloaded)):
t = t.with_unpacked_kwargs()
return s, t
def is_similar_callables(t: CallableType, s: CallableType) -> bool:
"""Return True if t and s have identical numbers of
arguments, default arguments and varargs.
"""
return (
len(t.arg_types) == len(s.arg_types)
and t.min_args == s.min_args
and t.is_var_arg == s.is_var_arg
)
def join_similar_callables(t: CallableType, s: CallableType) -> CallableType:
from mypy.meet import meet_types
arg_types: list[Type] = []
for i in range(len(t.arg_types)):
arg_types.append(meet_types(t.arg_types[i], s.arg_types[i]))
# TODO in combine_similar_callables also applies here (names and kinds; user metaclasses)
# The fallback type can be either 'function', 'type', or some user-provided metaclass.
    # The result should always use 'function' as a fallback if either operand uses it.
if t.fallback.type.fullname == "builtins.function":
fallback = t.fallback
else:
fallback = s.fallback
return t.copy_modified(
arg_types=arg_types,
arg_names=combine_arg_names(t, s),
ret_type=join_types(t.ret_type, s.ret_type),
fallback=fallback,
name=None,
)
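# Illustration (not in the original): argument positions are contravariant,
# so joining two similar callables *meets* the argument types and *joins* the
# return types. For example, joining (x: int) -> int with (x: object) -> str
# gives (x: int) -> object, since meet(int, object) == int and
# join(int, str) == object.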
def combine_similar_callables(t: CallableType, s: CallableType) -> CallableType:
arg_types: list[Type] = []
for i in range(len(t.arg_types)):
arg_types.append(join_types(t.arg_types[i], s.arg_types[i]))
# TODO kinds and argument names
# TODO what should happen if one fallback is 'type' and the other is a user-provided metaclass?
# The fallback type can be either 'function', 'type', or some user-provided metaclass.
    # The result should always use 'function' as a fallback if either operand uses it.
if t.fallback.type.fullname == "builtins.function":
fallback = t.fallback
else:
fallback = s.fallback
return t.copy_modified(
arg_types=arg_types,
arg_names=combine_arg_names(t, s),
ret_type=join_types(t.ret_type, s.ret_type),
fallback=fallback,
name=None,
)
def combine_arg_names(t: CallableType, s: CallableType) -> list[str | None]:
"""Produces a list of argument names compatible with both callables.
For example, suppose 't' and 's' have the following signatures:
- t: (a: int, b: str, X: str) -> None
- s: (a: int, b: str, Y: str) -> None
This function would return ["a", "b", None]. This information
is then used above to compute the join of t and s, which results
in a signature of (a: int, b: str, str) -> None.
Note that the third argument's name is omitted and 't' and 's'
are both valid subtypes of this inferred signature.
    Precondition: is_similar_callables(t, s) is true.
"""
num_args = len(t.arg_types)
new_names = []
for i in range(num_args):
t_name = t.arg_names[i]
s_name = s.arg_names[i]
if t_name == s_name or t.arg_kinds[i].is_named() or s.arg_kinds[i].is_named():
new_names.append(t_name)
else:
new_names.append(None)
return new_names
def object_from_instance(instance: Instance) -> Instance:
"""Construct the type 'builtins.object' from an instance type."""
# Use the fact that 'object' is always the last class in the mro.
res = Instance(instance.type.mro[-1], [])
return res
def object_or_any_from_type(typ: ProperType) -> ProperType:
# Similar to object_from_instance() but tries hard for all types.
# TODO: find a better way to get object, or make this more reliable.
if isinstance(typ, Instance):
return object_from_instance(typ)
elif isinstance(typ, (CallableType, TypedDictType, LiteralType)):
return object_from_instance(typ.fallback)
elif isinstance(typ, TupleType):
return object_from_instance(typ.partial_fallback)
elif isinstance(typ, TypeType):
return object_or_any_from_type(typ.item)
elif isinstance(typ, TypeVarType) and isinstance(typ.upper_bound, ProperType):
return object_or_any_from_type(typ.upper_bound)
elif isinstance(typ, UnionType):
for item in typ.items:
if isinstance(item, ProperType):
candidate = object_or_any_from_type(item)
if isinstance(candidate, Instance):
return candidate
return AnyType(TypeOfAny.implementation_artifact)
def join_type_list(types: list[Type]) -> Type:
if not types:
# This is a little arbitrary but reasonable. Any empty tuple should be compatible
# with all variable length tuples, and this makes it possible.
return UninhabitedType()
joined = types[0]
for t in types[1:]:
joined = join_types(joined, t)
return joined
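# Illustration: join_type_list folds join_types over the list, e.g.
# [bool, int] joins to builtins.int (bool is a subclass of int), while the
# empty list yields UninhabitedType() so an empty tuple stays compatible
# with every variable-length tuple.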
def unpack_callback_protocol(t: Instance) -> ProperType | None:
assert t.type.is_protocol
if t.type.protocol_members == ["__call__"]:
return get_proper_type(find_member("__call__", t, t, is_operator=True))
return None

View File

@@ -0,0 +1,306 @@
from __future__ import annotations
from typing import Any, Final, Iterable, Optional, Tuple
from typing_extensions import TypeAlias as _TypeAlias
from mypy.nodes import (
LITERAL_NO,
LITERAL_TYPE,
LITERAL_YES,
AssertTypeExpr,
AssignmentExpr,
AwaitExpr,
BytesExpr,
CallExpr,
CastExpr,
ComparisonExpr,
ComplexExpr,
ConditionalExpr,
DictExpr,
DictionaryComprehension,
EllipsisExpr,
EnumCallExpr,
Expression,
FloatExpr,
GeneratorExpr,
IndexExpr,
IntExpr,
LambdaExpr,
ListComprehension,
ListExpr,
MemberExpr,
NamedTupleExpr,
NameExpr,
NewTypeExpr,
OpExpr,
ParamSpecExpr,
PromoteExpr,
RevealExpr,
SetComprehension,
SetExpr,
SliceExpr,
StarExpr,
StrExpr,
SuperExpr,
TempNode,
TupleExpr,
TypeAliasExpr,
TypeApplication,
TypedDictExpr,
TypeVarExpr,
TypeVarTupleExpr,
UnaryExpr,
Var,
YieldExpr,
YieldFromExpr,
)
from mypy.visitor import ExpressionVisitor
# [Note Literals and literal_hash]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Mypy uses the term "literal" to refer to any expression built out of
# the following:
#
# * Plain literal expressions, like `1` (integer, float, string, etc.)
#
# * Compound literal expressions, like `(lit1, lit2)` (list, dict,
# set, or tuple)
#
# * Operator expressions, like `lit1 + lit2`
#
# * Variable references, like `x`
#
# * Member references, like `lit.m`
#
# * Index expressions, like `lit[0]`
#
# A typical "literal" looks like `x[(i,j+1)].m`.
#
# An expression that is a literal has a `literal_hash`, with the
# following properties.
#
# * `literal_hash` is a Key: a tuple containing basic data types and
# possibly other Keys. So it can be used as a key in a dictionary
# that will be compared by value (as opposed to the Node itself,
# which is compared by identity).
#
# * Two expressions have equal `literal_hash`es if and only if they
# are syntactically equal expressions. (NB: Actually, we also
# identify as equal expressions like `3` and `3.0`; is this a good
# idea?)
#
# * The elements of `literal_hash` that are tuples are exactly the
# subexpressions of the original expression (e.g. the base and index
# of an index expression, or the operands of an operator expression).
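# Illustration (hypothetical Var node for x): for the expression ``x[0]`` the
# hasher defined below produces the nested Key
#
#     ("Index", ("Var", <Var x>), ("Literal", 0))
#
# so syntactically equal expressions get value-equal keys, and subkeys()
# recovers the keys of the subexpressions.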
def literal(e: Expression) -> int:
if isinstance(e, ComparisonExpr):
return min(literal(o) for o in e.operands)
elif isinstance(e, OpExpr):
return min(literal(e.left), literal(e.right))
elif isinstance(e, (MemberExpr, UnaryExpr, StarExpr)):
return literal(e.expr)
elif isinstance(e, AssignmentExpr):
return literal(e.target)
elif isinstance(e, IndexExpr):
if literal(e.index) == LITERAL_YES:
return literal(e.base)
else:
return LITERAL_NO
elif isinstance(e, NameExpr):
if isinstance(e.node, Var) and e.node.is_final and e.node.final_value is not None:
return LITERAL_YES
return LITERAL_TYPE
if isinstance(e, (IntExpr, FloatExpr, ComplexExpr, StrExpr, BytesExpr)):
return LITERAL_YES
if literal_hash(e):
return LITERAL_YES
return LITERAL_NO
Key: _TypeAlias = Tuple[Any, ...]
def subkeys(key: Key) -> Iterable[Key]:
return [elt for elt in key if isinstance(elt, tuple)]
def literal_hash(e: Expression) -> Key | None:
return e.accept(_hasher)
def extract_var_from_literal_hash(key: Key) -> Var | None:
"""If key refers to a Var node, return it.
Return None otherwise.
"""
if len(key) == 2 and key[0] == "Var" and isinstance(key[1], Var):
return key[1]
return None
class _Hasher(ExpressionVisitor[Optional[Key]]):
def visit_int_expr(self, e: IntExpr) -> Key:
return ("Literal", e.value)
def visit_str_expr(self, e: StrExpr) -> Key:
return ("Literal", e.value)
def visit_bytes_expr(self, e: BytesExpr) -> Key:
return ("Literal", e.value)
def visit_float_expr(self, e: FloatExpr) -> Key:
return ("Literal", e.value)
def visit_complex_expr(self, e: ComplexExpr) -> Key:
return ("Literal", e.value)
def visit_star_expr(self, e: StarExpr) -> Key:
return ("Star", literal_hash(e.expr))
def visit_name_expr(self, e: NameExpr) -> Key:
if isinstance(e.node, Var) and e.node.is_final and e.node.final_value is not None:
return ("Literal", e.node.final_value)
# N.B: We use the node itself as the key, and not the name,
# because using the name causes issues when there is shadowing
# (for example, in list comprehensions).
return ("Var", e.node)
def visit_member_expr(self, e: MemberExpr) -> Key:
return ("Member", literal_hash(e.expr), e.name)
def visit_op_expr(self, e: OpExpr) -> Key:
return ("Binary", e.op, literal_hash(e.left), literal_hash(e.right))
def visit_comparison_expr(self, e: ComparisonExpr) -> Key:
rest: tuple[str | Key | None, ...] = tuple(e.operators)
rest += tuple(literal_hash(o) for o in e.operands)
return ("Comparison",) + rest
def visit_unary_expr(self, e: UnaryExpr) -> Key:
return ("Unary", e.op, literal_hash(e.expr))
def seq_expr(self, e: ListExpr | TupleExpr | SetExpr, name: str) -> Key | None:
if all(literal(x) == LITERAL_YES for x in e.items):
rest: tuple[Key | None, ...] = tuple(literal_hash(x) for x in e.items)
return (name,) + rest
return None
def visit_list_expr(self, e: ListExpr) -> Key | None:
return self.seq_expr(e, "List")
def visit_dict_expr(self, e: DictExpr) -> Key | None:
if all(a and literal(a) == literal(b) == LITERAL_YES for a, b in e.items):
rest: tuple[Key | None, ...] = tuple(
(literal_hash(a) if a else None, literal_hash(b)) for a, b in e.items
)
return ("Dict",) + rest
return None
def visit_tuple_expr(self, e: TupleExpr) -> Key | None:
return self.seq_expr(e, "Tuple")
def visit_set_expr(self, e: SetExpr) -> Key | None:
return self.seq_expr(e, "Set")
def visit_index_expr(self, e: IndexExpr) -> Key | None:
if literal(e.index) == LITERAL_YES:
return ("Index", literal_hash(e.base), literal_hash(e.index))
return None
def visit_assignment_expr(self, e: AssignmentExpr) -> Key | None:
return literal_hash(e.target)
def visit_call_expr(self, e: CallExpr) -> None:
return None
def visit_slice_expr(self, e: SliceExpr) -> None:
return None
def visit_cast_expr(self, e: CastExpr) -> None:
return None
def visit_assert_type_expr(self, e: AssertTypeExpr) -> None:
return None
def visit_conditional_expr(self, e: ConditionalExpr) -> None:
return None
def visit_ellipsis(self, e: EllipsisExpr) -> None:
return None
def visit_yield_from_expr(self, e: YieldFromExpr) -> None:
return None
def visit_yield_expr(self, e: YieldExpr) -> None:
return None
def visit_reveal_expr(self, e: RevealExpr) -> None:
return None
def visit_super_expr(self, e: SuperExpr) -> None:
return None
def visit_type_application(self, e: TypeApplication) -> None:
return None
def visit_lambda_expr(self, e: LambdaExpr) -> None:
return None
def visit_list_comprehension(self, e: ListComprehension) -> None:
return None
def visit_set_comprehension(self, e: SetComprehension) -> None:
return None
def visit_dictionary_comprehension(self, e: DictionaryComprehension) -> None:
return None
def visit_generator_expr(self, e: GeneratorExpr) -> None:
return None
def visit_type_var_expr(self, e: TypeVarExpr) -> None:
return None
def visit_paramspec_expr(self, e: ParamSpecExpr) -> None:
return None
def visit_type_var_tuple_expr(self, e: TypeVarTupleExpr) -> None:
return None
def visit_type_alias_expr(self, e: TypeAliasExpr) -> None:
return None
def visit_namedtuple_expr(self, e: NamedTupleExpr) -> None:
return None
def visit_enum_call_expr(self, e: EnumCallExpr) -> None:
return None
def visit_typeddict_expr(self, e: TypedDictExpr) -> None:
return None
def visit_newtype_expr(self, e: NewTypeExpr) -> None:
return None
def visit__promote_expr(self, e: PromoteExpr) -> None:
return None
def visit_await_expr(self, e: AwaitExpr) -> None:
return None
def visit_temp_node(self, e: TempNode) -> None:
return None
_hasher: Final = _Hasher()

View File

@@ -0,0 +1,61 @@
"""
This is a module for various lookup functions:
functions that will find a semantic node by its name.
"""
from __future__ import annotations
from mypy.nodes import MypyFile, SymbolTableNode, TypeInfo
# TODO: gradually move existing lookup functions to this module.
def lookup_fully_qualified(
name: str, modules: dict[str, MypyFile], *, raise_on_missing: bool = False
) -> SymbolTableNode | None:
"""Find a symbol using it fully qualified name.
The algorithm has two steps: first we try splitting the name on '.' to find
the module, then iteratively look for each next chunk after a '.' (e.g. for
nested classes).
    This function should *not* be used to find a module. Those should be looked
    up in the modules dictionary.
"""
head = name
rest = []
# 1. Find a module tree in modules dictionary.
while True:
if "." not in head:
if raise_on_missing:
assert "." in head, f"Cannot find module for {name}"
return None
head, tail = head.rsplit(".", maxsplit=1)
rest.append(tail)
mod = modules.get(head)
if mod is not None:
break
names = mod.names
# 2. Find the symbol in the module tree.
if not rest:
        # Looks like a module; don't use this function for modules, to avoid confusion.
if raise_on_missing:
assert rest, f"Cannot find {name}, got a module symbol"
return None
while True:
key = rest.pop()
if key not in names:
if raise_on_missing:
assert key in names, f"Cannot find component {key!r} for {name!r}"
return None
stnode = names[key]
if not rest:
return stnode
node = stnode.node
# In fine-grained mode, could be a cross-reference to a deleted module
# or a Var made up for a missing module.
if not isinstance(node, TypeInfo):
if raise_on_missing:
assert node, f"Cannot find {name}"
return None
names = node.names
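# Illustration (hypothetical modules dict): for
# lookup_fully_qualified("pkg.mod.Outer.Inner", modules), step 1 strips
# components until "pkg.mod" is found in modules (rest == ["Inner", "Outer"]),
# and step 2 pops "Outer" and then "Inner" from the symbol tables, returning
# Inner's SymbolTableNode (or None if any component is missing).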

File diff suppressed because it is too large

View File

@@ -0,0 +1,117 @@
from __future__ import annotations
from mypy.expandtype import expand_type
from mypy.nodes import TypeInfo
from mypy.types import AnyType, Instance, TupleType, Type, TypeOfAny, TypeVarId, has_type_vars
def map_instance_to_supertype(instance: Instance, superclass: TypeInfo) -> Instance:
"""Produce a supertype of `instance` that is an Instance
of `superclass`, mapping type arguments up the chain of bases.
If `superclass` is not a nominal superclass of `instance.type`,
then all type arguments are mapped to 'Any'.
"""
if instance.type == superclass:
# Fast path: `instance` already belongs to `superclass`.
return instance
if superclass.fullname == "builtins.tuple" and instance.type.tuple_type:
if has_type_vars(instance.type.tuple_type):
# We special case mapping generic tuple types to tuple base, because for
# such tuples fallback can't be calculated before applying type arguments.
alias = instance.type.special_alias
assert alias is not None
if not alias._is_recursive:
# Unfortunately we can't support this for generic recursive tuples.
# If we skip this special casing we will fall back to tuple[Any, ...].
env = instance_to_type_environment(instance)
tuple_type = expand_type(instance.type.tuple_type, env)
if isinstance(tuple_type, TupleType):
# Make the import here to avoid cyclic imports.
import mypy.typeops
return mypy.typeops.tuple_fallback(tuple_type)
if not superclass.type_vars:
# Fast path: `superclass` has no type variables to map to.
return Instance(superclass, [])
return map_instance_to_supertypes(instance, superclass)[0]
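# Illustration (hypothetical TypeInfos): given ``class C(Dict[int, str])``,
# mapping an Instance of C to the builtins.dict TypeInfo substitutes the
# inherited type arguments up the base-class chain:
#
#     map_instance_to_supertype(c_instance, dict_info)  # -> dict[int, str]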
def map_instance_to_supertypes(instance: Instance, supertype: TypeInfo) -> list[Instance]:
# FIX: Currently we should only have one supertype per interface, so no
# need to return an array
result: list[Instance] = []
for path in class_derivation_paths(instance.type, supertype):
types = [instance]
for sup in path:
a: list[Instance] = []
for t in types:
a.extend(map_instance_to_direct_supertypes(t, sup))
types = a
result.extend(types)
if result:
return result
else:
# Nothing. Presumably due to an error. Construct a dummy using Any.
any_type = AnyType(TypeOfAny.from_error)
return [Instance(supertype, [any_type] * len(supertype.type_vars))]
def class_derivation_paths(typ: TypeInfo, supertype: TypeInfo) -> list[list[TypeInfo]]:
"""Return an array of non-empty paths of direct base classes from
type to supertype. Return [] if no such path could be found.
InterfaceImplementationPaths(A, B) == [[B]] if A inherits B
InterfaceImplementationPaths(A, C) == [[B, C]] if A inherits B and
B inherits C
"""
# FIX: Currently we might only ever have a single path, so this could be
# simplified
result: list[list[TypeInfo]] = []
for base in typ.bases:
btype = base.type
if btype == supertype:
result.append([btype])
else:
# Try constructing a longer path via the base class.
for path in class_derivation_paths(btype, supertype):
result.append([btype] + path)
return result
def map_instance_to_direct_supertypes(instance: Instance, supertype: TypeInfo) -> list[Instance]:
    # FIX: There should only ever be one supertype.
typ = instance.type
result: list[Instance] = []
for b in typ.bases:
if b.type == supertype:
env = instance_to_type_environment(instance)
t = expand_type(b, env)
assert isinstance(t, Instance)
result.append(t)
if result:
return result
else:
# Relationship with the supertype not specified explicitly. Use dynamic
# type arguments implicitly.
any_type = AnyType(TypeOfAny.unannotated)
return [Instance(supertype, [any_type] * len(supertype.type_vars))]
def instance_to_type_environment(instance: Instance) -> dict[TypeVarId, Type]:
"""Given an Instance, produce the resulting type environment for type
variables bound by the Instance's class definition.
An Instance is a type application of a class (a TypeInfo) to its
required number of type arguments. So this environment consists
of the class's type variables mapped to the Instance's actual
arguments. The type variables are mapped by their `id`.
"""
return {binder.id: arg for binder, arg in zip(instance.type.defn.type_vars, instance.args)}
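# Illustration: for ``class Pair(Generic[T, S])`` and the Instance
# ``Pair[int, str]``, the resulting environment maps each TypeVarId from the
# class definition to the actual argument, conceptually {T.id: int, S.id: str}.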

File diff suppressed because it is too large

View File

@@ -0,0 +1,121 @@
"""Utility for dumping memory usage stats.
This is tailored to mypy and knows (a little) about which list objects are
owned by particular AST nodes, etc.
"""
from __future__ import annotations
import gc
import sys
from collections import defaultdict
from typing import Dict, Iterable, cast
from mypy.nodes import FakeInfo, Node
from mypy.types import Type
from mypy.util import get_class_descriptors
def collect_memory_stats() -> tuple[dict[str, int], dict[str, int]]:
"""Return stats about memory use.
Return a tuple with these items:
- Dict from object kind to number of instances of that kind
- Dict from object kind to total bytes used by all instances of that kind
"""
objs = gc.get_objects()
find_recursive_objects(objs)
inferred = {}
for obj in objs:
if type(obj) is FakeInfo:
# Processing these would cause a crash.
continue
n = type(obj).__name__
if hasattr(obj, "__dict__"):
# Keep track of which class a particular __dict__ is associated with.
inferred[id(obj.__dict__)] = f"{n} (__dict__)"
if isinstance(obj, (Node, Type)): # type: ignore[misc]
if hasattr(obj, "__dict__"):
for x in obj.__dict__.values():
if isinstance(x, list):
# Keep track of which node a list is associated with.
inferred[id(x)] = f"{n} (list)"
if isinstance(x, tuple):
                        # Keep track of which node a tuple is associated with.
inferred[id(x)] = f"{n} (tuple)"
for k in get_class_descriptors(type(obj)):
x = getattr(obj, k, None)
if isinstance(x, list):
inferred[id(x)] = f"{n} (list)"
if isinstance(x, tuple):
inferred[id(x)] = f"{n} (tuple)"
freqs: dict[str, int] = {}
memuse: dict[str, int] = {}
for obj in objs:
if id(obj) in inferred:
name = inferred[id(obj)]
else:
name = type(obj).__name__
freqs[name] = freqs.get(name, 0) + 1
memuse[name] = memuse.get(name, 0) + sys.getsizeof(obj)
return freqs, memuse
def print_memory_profile(run_gc: bool = True) -> None:
if not sys.platform.startswith("win"):
import resource
system_memuse = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
else:
system_memuse = -1 # TODO: Support this on Windows
if run_gc:
gc.collect()
freqs, memuse = collect_memory_stats()
print("%7s %7s %7s %s" % ("Freq", "Size(k)", "AvgSize", "Type"))
print("-------------------------------------------")
totalmem = 0
i = 0
for n, mem in sorted(memuse.items(), key=lambda x: -x[1]):
f = freqs[n]
if i < 50:
print("%7d %7d %7.0f %s" % (f, mem // 1024, mem / f, n))
i += 1
totalmem += mem
print()
print("Mem usage RSS ", system_memuse // 1024)
print("Total reachable ", totalmem // 1024)
def find_recursive_objects(objs: list[object]) -> None:
"""Find additional objects referenced by objs and append them to objs.
    We use this since gc.get_objects() does not return objects that contain
    no pointers, such as strings.
"""
seen = {id(o) for o in objs}
def visit(o: object) -> None:
if id(o) not in seen:
objs.append(o)
seen.add(id(o))
for obj in objs.copy():
if type(obj) is FakeInfo:
# Processing these would cause a crash.
continue
if type(obj) in (dict, defaultdict):
for key, val in cast(Dict[object, object], obj).items():
visit(key)
visit(val)
if type(obj) in (list, tuple, set):
for x in cast(Iterable[object], obj):
visit(x)
if hasattr(obj, "__slots__"):
for base in type.mro(type(obj)):
for slot in getattr(base, "__slots__", ()):
if hasattr(obj, slot):
visit(getattr(obj, slot))

View File

@@ -0,0 +1,319 @@
"""Message constants for generating error messages during type checking.
Literal messages should be defined as constants in this module so they won't get out of sync
if used in more than one place, and so that they can be easily introspected. These messages are
ultimately consumed by messages.MessageBuilder.fail(). For more non-trivial message generation,
add a method to MessageBuilder and call this instead.
"""
from __future__ import annotations
from typing import Final, NamedTuple
from mypy import errorcodes as codes
class ErrorMessage(NamedTuple):
value: str
code: codes.ErrorCode | None = None
def format(self, *args: object, **kwargs: object) -> ErrorMessage:
return ErrorMessage(self.value.format(*args, **kwargs), code=self.code)
def with_additional_msg(self, info: str) -> ErrorMessage:
return ErrorMessage(self.value + info, code=self.code)
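# Illustration: the constants below fill their placeholders at report time
# while preserving the attached error code, e.g.
#
#     MUST_HAVE_NONE_RETURN_TYPE.format("__init__")
#     # -> ErrorMessage('The return type of "__init__" must be None', code=None)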
# Invalid types
INVALID_TYPE_RAW_ENUM_VALUE: Final = ErrorMessage(
"Invalid type: try using Literal[{}.{}] instead?", codes.VALID_TYPE
)
# Type checker error message constants
NO_RETURN_VALUE_EXPECTED: Final = ErrorMessage("No return value expected", codes.RETURN_VALUE)
MISSING_RETURN_STATEMENT: Final = ErrorMessage("Missing return statement", codes.RETURN)
EMPTY_BODY_ABSTRACT: Final = ErrorMessage(
"If the method is meant to be abstract, use @abc.abstractmethod", codes.EMPTY_BODY
)
INVALID_IMPLICIT_RETURN: Final = ErrorMessage("Implicit return in function which does not return")
INCOMPATIBLE_RETURN_VALUE_TYPE: Final = ErrorMessage(
"Incompatible return value type", codes.RETURN_VALUE
)
RETURN_VALUE_EXPECTED: Final = ErrorMessage("Return value expected", codes.RETURN_VALUE)
NO_RETURN_EXPECTED: Final = ErrorMessage("Return statement in function which does not return")
INVALID_EXCEPTION: Final = ErrorMessage("Exception must be derived from BaseException")
INVALID_EXCEPTION_TYPE: Final = ErrorMessage(
"Exception type must be derived from BaseException (or be a tuple of exception classes)"
)
INVALID_EXCEPTION_GROUP: Final = ErrorMessage(
"Exception type in except* cannot derive from BaseExceptionGroup"
)
RETURN_IN_ASYNC_GENERATOR: Final = ErrorMessage(
'"return" with value in async generator is not allowed'
)
INVALID_RETURN_TYPE_FOR_GENERATOR: Final = ErrorMessage(
    'The return type of a generator function should be "Generator" or one of its supertypes'
)
INVALID_RETURN_TYPE_FOR_ASYNC_GENERATOR: Final = ErrorMessage(
'The return type of an async generator function should be "AsyncGenerator" or one of its '
"supertypes"
)
YIELD_VALUE_EXPECTED: Final = ErrorMessage("Yield value expected")
INCOMPATIBLE_TYPES: Final = ErrorMessage("Incompatible types")
INCOMPATIBLE_TYPES_IN_ASSIGNMENT: Final = ErrorMessage(
"Incompatible types in assignment", code=codes.ASSIGNMENT
)
INCOMPATIBLE_TYPES_IN_AWAIT: Final = ErrorMessage('Incompatible types in "await"')
INCOMPATIBLE_REDEFINITION: Final = ErrorMessage("Incompatible redefinition")
INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AENTER: Final = (
'Incompatible types in "async with" for "__aenter__"'
)
INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AEXIT: Final = (
'Incompatible types in "async with" for "__aexit__"'
)
INCOMPATIBLE_TYPES_IN_ASYNC_FOR: Final = 'Incompatible types in "async for"'
INVALID_TYPE_FOR_SLOTS: Final = 'Invalid type for "__slots__"'
ASYNC_FOR_OUTSIDE_COROUTINE: Final = '"async for" outside async function'
ASYNC_WITH_OUTSIDE_COROUTINE: Final = '"async with" outside async function'
INCOMPATIBLE_TYPES_IN_YIELD: Final = ErrorMessage('Incompatible types in "yield"')
INCOMPATIBLE_TYPES_IN_YIELD_FROM: Final = ErrorMessage('Incompatible types in "yield from"')
INCOMPATIBLE_TYPES_IN_STR_INTERPOLATION: Final = "Incompatible types in string interpolation"
INCOMPATIBLE_TYPES_IN_CAPTURE: Final = ErrorMessage("Incompatible types in capture pattern")
MUST_HAVE_NONE_RETURN_TYPE: Final = ErrorMessage('The return type of "{}" must be None')
TUPLE_INDEX_OUT_OF_RANGE: Final = ErrorMessage("Tuple index out of range")
INVALID_SLICE_INDEX: Final = ErrorMessage("Slice index must be an integer, SupportsIndex or None")
CANNOT_INFER_LAMBDA_TYPE: Final = ErrorMessage("Cannot infer type of lambda")
CANNOT_ACCESS_INIT: Final = (
'Accessing "__init__" on an instance is unsound, since instance.__init__ could be from'
" an incompatible subclass"
)
NON_INSTANCE_NEW_TYPE: Final = ErrorMessage('"__new__" must return a class instance (got {})')
INVALID_NEW_TYPE: Final = ErrorMessage('Incompatible return type for "__new__"')
BAD_CONSTRUCTOR_TYPE: Final = ErrorMessage("Unsupported decorated constructor type")
CANNOT_ASSIGN_TO_METHOD: Final = "Cannot assign to a method"
CANNOT_ASSIGN_TO_TYPE: Final = "Cannot assign to a type"
INCONSISTENT_ABSTRACT_OVERLOAD: Final = ErrorMessage(
"Overloaded method has both abstract and non-abstract variants"
)
MULTIPLE_OVERLOADS_REQUIRED: Final = ErrorMessage("Single overload definition, multiple required")
READ_ONLY_PROPERTY_OVERRIDES_READ_WRITE: Final = ErrorMessage(
"Read-only property cannot override read-write property"
)
FORMAT_REQUIRES_MAPPING: Final = "Format requires a mapping"
RETURN_TYPE_CANNOT_BE_CONTRAVARIANT: Final = ErrorMessage(
"Cannot use a contravariant type variable as return type"
)
FUNCTION_PARAMETER_CANNOT_BE_COVARIANT: Final = ErrorMessage(
"Cannot use a covariant type variable as a parameter"
)
INCOMPATIBLE_IMPORT_OF: Final = ErrorMessage('Incompatible import of "{}"', code=codes.ASSIGNMENT)
FUNCTION_TYPE_EXPECTED: Final = ErrorMessage(
"Function is missing a type annotation", codes.NO_UNTYPED_DEF
)
ONLY_CLASS_APPLICATION: Final = ErrorMessage(
"Type application is only supported for generic classes"
)
RETURN_TYPE_EXPECTED: Final = ErrorMessage(
"Function is missing a return type annotation", codes.NO_UNTYPED_DEF
)
ARGUMENT_TYPE_EXPECTED: Final = ErrorMessage(
"Function is missing a type annotation for one or more arguments", codes.NO_UNTYPED_DEF
)
KEYWORD_ARGUMENT_REQUIRES_STR_KEY_TYPE: Final = ErrorMessage(
'Keyword argument only valid with "str" key type in call to "dict"'
)
ALL_MUST_BE_SEQ_STR: Final = ErrorMessage("Type of __all__ must be {}, not {}")
INVALID_TYPEDDICT_ARGS: Final = ErrorMessage(
"Expected keyword arguments, {...}, or dict(...) in TypedDict constructor"
)
TYPEDDICT_KEY_MUST_BE_STRING_LITERAL: Final = ErrorMessage(
"Expected TypedDict key to be string literal"
)
MALFORMED_ASSERT: Final = ErrorMessage("Assertion is always true, perhaps remove parentheses?")
DUPLICATE_TYPE_SIGNATURES: Final = ErrorMessage("Function has duplicate type signatures")
DESCRIPTOR_SET_NOT_CALLABLE: Final = ErrorMessage("{}.__set__ is not callable")
DESCRIPTOR_GET_NOT_CALLABLE: Final = "{}.__get__ is not callable"
MODULE_LEVEL_GETATTRIBUTE: Final = ErrorMessage(
"__getattribute__ is not valid at the module level"
)
CLASS_VAR_CONFLICTS_SLOTS: Final = '"{}" in __slots__ conflicts with class variable access'
NAME_NOT_IN_SLOTS: Final = ErrorMessage(
'Trying to assign name "{}" that is not in "__slots__" of type "{}"'
)
TYPE_ALWAYS_TRUE: Final = ErrorMessage(
"{} which does not implement __bool__ or __len__ "
"so it could always be true in boolean context",
code=codes.TRUTHY_BOOL,
)
TYPE_ALWAYS_TRUE_UNIONTYPE: Final = ErrorMessage(
"{} of which no members implement __bool__ or __len__ "
"so it could always be true in boolean context",
code=codes.TRUTHY_BOOL,
)
FUNCTION_ALWAYS_TRUE: Final = ErrorMessage(
"Function {} could always be true in boolean context", code=codes.TRUTHY_FUNCTION
)
ITERABLE_ALWAYS_TRUE: Final = ErrorMessage(
"{} which can always be true in boolean context. Consider using {} instead.",
code=codes.TRUTHY_ITERABLE,
)
NOT_CALLABLE: Final = "{} not callable"
TYPE_MUST_BE_USED: Final = "Value of type {} must be used"
# Generic
GENERIC_INSTANCE_VAR_CLASS_ACCESS: Final = (
"Access to generic instance variables via class is ambiguous"
)
GENERIC_CLASS_VAR_ACCESS: Final = "Access to generic class variables is ambiguous"
BARE_GENERIC: Final = "Missing type parameters for generic type {}"
IMPLICIT_GENERIC_ANY_BUILTIN: Final = (
'Implicit generic "Any". Use "{}" and specify generic parameters'
)
INVALID_UNPACK: Final = "{} cannot be unpacked (must be tuple or TypeVarTuple)"
INVALID_UNPACK_POSITION: Final = "Unpack is only valid in a variadic position"
# TypeVar
INCOMPATIBLE_TYPEVAR_VALUE: Final = 'Value of type variable "{}" of {} cannot be {}'
CANNOT_USE_TYPEVAR_AS_EXPRESSION: Final = 'Type variable "{}.{}" cannot be used as an expression'
INVALID_TYPEVAR_AS_TYPEARG: Final = 'Type variable "{}" not valid as type argument value for "{}"'
INVALID_TYPEVAR_ARG_BOUND: Final = 'Type argument {} of "{}" must be a subtype of {}'
INVALID_TYPEVAR_ARG_VALUE: Final = 'Invalid type argument value for "{}"'
TYPEVAR_VARIANCE_DEF: Final = 'TypeVar "{}" may only be a literal bool'
TYPEVAR_ARG_MUST_BE_TYPE: Final = '{} "{}" must be a type'
TYPEVAR_UNEXPECTED_ARGUMENT: Final = 'Unexpected argument to "TypeVar()"'
UNBOUND_TYPEVAR: Final = (
"A function returning TypeVar should receive at least "
"one argument containing the same TypeVar"
)
# Super
TOO_MANY_ARGS_FOR_SUPER: Final = ErrorMessage('Too many arguments for "super"')
SUPER_WITH_SINGLE_ARG_NOT_SUPPORTED: Final = ErrorMessage(
'"super" with a single argument not supported'
)
UNSUPPORTED_ARG_1_FOR_SUPER: Final = ErrorMessage('Unsupported argument 1 for "super"')
UNSUPPORTED_ARG_2_FOR_SUPER: Final = ErrorMessage('Unsupported argument 2 for "super"')
SUPER_VARARGS_NOT_SUPPORTED: Final = ErrorMessage('Varargs not supported with "super"')
SUPER_POSITIONAL_ARGS_REQUIRED: Final = ErrorMessage('"super" only accepts positional arguments')
SUPER_ARG_2_NOT_INSTANCE_OF_ARG_1: Final = ErrorMessage(
'Argument 2 for "super" not an instance of argument 1'
)
TARGET_CLASS_HAS_NO_BASE_CLASS: Final = ErrorMessage("Target class has no base class")
SUPER_OUTSIDE_OF_METHOD_NOT_SUPPORTED: Final = ErrorMessage(
"super() outside of a method is not supported"
)
SUPER_ENCLOSING_POSITIONAL_ARGS_REQUIRED: Final = ErrorMessage(
"super() requires one or more positional arguments in enclosing function"
)
# Self-type
MISSING_OR_INVALID_SELF_TYPE: Final = ErrorMessage(
"Self argument missing for a non-static method (or an invalid type for self)"
)
ERASED_SELF_TYPE_NOT_SUPERTYPE: Final = ErrorMessage(
'The erased type of self "{}" is not a supertype of its class "{}"'
)
# Final
CANNOT_INHERIT_FROM_FINAL: Final = ErrorMessage('Cannot inherit from final class "{}"')
DEPENDENT_FINAL_IN_CLASS_BODY: Final = ErrorMessage(
"Final name declared in class body cannot depend on type variables"
)
CANNOT_ACCESS_FINAL_INSTANCE_ATTR: Final = (
'Cannot access final instance attribute "{}" on class object'
)
CANNOT_MAKE_DELETABLE_FINAL: Final = ErrorMessage("Deletable attribute cannot be final")
# Enum
ENUM_MEMBERS_ATTR_WILL_BE_OVERRIDEN: Final = ErrorMessage(
'Assigned "__members__" will be overridden by "Enum" internally'
)
# ClassVar
CANNOT_OVERRIDE_INSTANCE_VAR: Final = ErrorMessage(
'Cannot override instance variable (previously declared on base class "{}") with class '
"variable"
)
CANNOT_OVERRIDE_CLASS_VAR: Final = ErrorMessage(
'Cannot override class variable (previously declared on base class "{}") with instance '
"variable"
)
CLASS_VAR_WITH_TYPEVARS: Final = "ClassVar cannot contain type variables"
CLASS_VAR_WITH_GENERIC_SELF: Final = "ClassVar cannot contain Self type in generic classes"
CLASS_VAR_OUTSIDE_OF_CLASS: Final = "ClassVar can only be used for assignments in class body"
# Protocol
RUNTIME_PROTOCOL_EXPECTED: Final = ErrorMessage(
"Only @runtime_checkable protocols can be used with instance and class checks"
)
CANNOT_INSTANTIATE_PROTOCOL: Final = ErrorMessage('Cannot instantiate protocol class "{}"')
TOO_MANY_UNION_COMBINATIONS: Final = ErrorMessage(
"Not all union combinations were tried because there are too many unions"
)
CONTIGUOUS_ITERABLE_EXPECTED: Final = ErrorMessage("Contiguous iterable with same type expected")
ITERABLE_TYPE_EXPECTED: Final = ErrorMessage("Invalid type '{}' for *expr (iterable expected)")
TYPE_GUARD_POS_ARG_REQUIRED: Final = ErrorMessage("Type guard requires positional argument")
# Match Statement
MISSING_MATCH_ARGS: Final = 'Class "{}" doesn\'t define "__match_args__"'
OR_PATTERN_ALTERNATIVE_NAMES: Final = "Alternative patterns bind different names"
CLASS_PATTERN_GENERIC_TYPE_ALIAS: Final = (
"Class pattern class must not be a type alias with type parameters"
)
CLASS_PATTERN_TYPE_REQUIRED: Final = 'Expected type in class pattern; found "{}"'
CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS: Final = "Too many positional patterns for class pattern"
CLASS_PATTERN_KEYWORD_MATCHES_POSITIONAL: Final = (
'Keyword "{}" already matches a positional pattern'
)
CLASS_PATTERN_DUPLICATE_KEYWORD_PATTERN: Final = 'Duplicate keyword pattern "{}"'
CLASS_PATTERN_UNKNOWN_KEYWORD: Final = 'Class "{}" has no attribute "{}"'
CLASS_PATTERN_CLASS_OR_STATIC_METHOD: Final = "Cannot have both classmethod and staticmethod"
MULTIPLE_ASSIGNMENTS_IN_PATTERN: Final = 'Multiple assignments to name "{}" in pattern'
CANNOT_MODIFY_MATCH_ARGS: Final = 'Cannot assign to "__match_args__"'
DATACLASS_FIELD_ALIAS_MUST_BE_LITERAL: Final = (
'"alias" argument to dataclass field must be a string literal'
)
DATACLASS_POST_INIT_MUST_BE_A_FUNCTION: Final = '"__post_init__" method must be an instance method'
# fastparse
FAILED_TO_MERGE_OVERLOADS: Final = ErrorMessage(
"Condition can't be inferred, unable to merge overloads"
)
TYPE_IGNORE_WITH_ERRCODE_ON_MODULE: Final = ErrorMessage(
"type ignore with error code is not supported for modules; "
'use `# mypy: disable-error-code="{}"`',
codes.SYNTAX,
)
INVALID_TYPE_IGNORE: Final = ErrorMessage('Invalid "type: ignore" comment', codes.SYNTAX)
TYPE_COMMENT_SYNTAX_ERROR_VALUE: Final = ErrorMessage(
'Syntax error in type comment "{}"', codes.SYNTAX
)
ELLIPSIS_WITH_OTHER_TYPEARGS: Final = ErrorMessage(
"Ellipses cannot accompany other argument types in function type signature", codes.SYNTAX
)
TYPE_SIGNATURE_TOO_MANY_ARGS: Final = ErrorMessage(
"Type signature has too many arguments", codes.SYNTAX
)
TYPE_SIGNATURE_TOO_FEW_ARGS: Final = ErrorMessage(
"Type signature has too few arguments", codes.SYNTAX
)
ARG_CONSTRUCTOR_NAME_EXPECTED: Final = ErrorMessage("Expected arg constructor name", codes.SYNTAX)
ARG_CONSTRUCTOR_TOO_MANY_ARGS: Final = ErrorMessage(
"Too many arguments for argument constructor", codes.SYNTAX
)
MULTIPLE_VALUES_FOR_NAME_KWARG: Final = ErrorMessage(
'"{}" gets multiple values for keyword argument "name"', codes.SYNTAX
)
MULTIPLE_VALUES_FOR_TYPE_KWARG: Final = ErrorMessage(
'"{}" gets multiple values for keyword argument "type"', codes.SYNTAX
)
ARG_CONSTRUCTOR_UNEXPECTED_ARG: Final = ErrorMessage(
'Unexpected argument "{}" for argument constructor', codes.SYNTAX
)
ARG_NAME_EXPECTED_STRING_LITERAL: Final = ErrorMessage(
"Expected string literal for argument name, got {}", codes.SYNTAX
)

Some files were not shown because too many files have changed in this diff