init commit
Some checks failed
continuous-integration/drone Build is failing

This commit is contained in:
2025-09-04 01:51:59 +09:00
commit aca280b64d
1841 changed files with 753304 additions and 0 deletions

247
.venv/bin/Activate.ps1 Normal file
View File

@@ -0,0 +1,247 @@
<#
.Synopsis
Activate a Python virtual environment for the current PowerShell session.
.Description
Pushes the python executable for a virtual environment to the front of the
$Env:PATH environment variable and sets the prompt to signify that you are
in a Python virtual environment. Makes use of the command line switches as
well as the `pyvenv.cfg` file values present in the virtual environment.
.Parameter VenvDir
Path to the directory that contains the virtual environment to activate. The
default value for this is the parent of the directory that the Activate.ps1
script is located within.
.Parameter Prompt
The prompt prefix to display when this virtual environment is activated. By
default, this prompt is the name of the virtual environment folder (VenvDir)
surrounded by parentheses and followed by a single space (ie. '(.venv) ').
.Example
Activate.ps1
Activates the Python virtual environment that contains the Activate.ps1 script.
.Example
Activate.ps1 -Verbose
Activates the Python virtual environment that contains the Activate.ps1 script,
and shows extra information about the activation as it executes.
.Example
Activate.ps1 -VenvDir C:\Users\MyUser\Common\.venv
Activates the Python virtual environment located in the specified location.
.Example
Activate.ps1 -Prompt "MyPython"
Activates the Python virtual environment that contains the Activate.ps1 script,
and prefixes the current prompt with the specified string (surrounded in
parentheses) while the virtual environment is active.
.Notes
On Windows, it may be required to enable this Activate.ps1 script by setting the
execution policy for the user. You can do this by issuing the following PowerShell
command:
PS C:\> Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
For more information on Execution Policies:
https://go.microsoft.com/fwlink/?LinkID=135170
#>
Param(
[Parameter(Mandatory = $false)]
[String]
$VenvDir,
[Parameter(Mandatory = $false)]
[String]
$Prompt
)
<# Function declarations --------------------------------------------------- #>
<#
.Synopsis
Remove all shell session elements added by the Activate script, including the
addition of the virtual environment's Python executable from the beginning of
the PATH variable.
.Parameter NonDestructive
If present, do not remove this function from the global namespace for the
session.
#>
function global:deactivate ([switch]$NonDestructive) {
# Revert to original values
# The prior prompt:
if (Test-Path -Path Function:_OLD_VIRTUAL_PROMPT) {
Copy-Item -Path Function:_OLD_VIRTUAL_PROMPT -Destination Function:prompt
Remove-Item -Path Function:_OLD_VIRTUAL_PROMPT
}
# The prior PYTHONHOME:
if (Test-Path -Path Env:_OLD_VIRTUAL_PYTHONHOME) {
Copy-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME -Destination Env:PYTHONHOME
Remove-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME
}
# The prior PATH:
if (Test-Path -Path Env:_OLD_VIRTUAL_PATH) {
Copy-Item -Path Env:_OLD_VIRTUAL_PATH -Destination Env:PATH
Remove-Item -Path Env:_OLD_VIRTUAL_PATH
}
# Just remove the VIRTUAL_ENV altogether:
if (Test-Path -Path Env:VIRTUAL_ENV) {
Remove-Item -Path env:VIRTUAL_ENV
}
# Just remove VIRTUAL_ENV_PROMPT altogether.
if (Test-Path -Path Env:VIRTUAL_ENV_PROMPT) {
Remove-Item -Path env:VIRTUAL_ENV_PROMPT
}
# Just remove the _PYTHON_VENV_PROMPT_PREFIX altogether:
if (Get-Variable -Name "_PYTHON_VENV_PROMPT_PREFIX" -ErrorAction SilentlyContinue) {
Remove-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Scope Global -Force
}
# Leave deactivate function in the global namespace if requested:
if (-not $NonDestructive) {
Remove-Item -Path function:deactivate
}
}
<#
.Description
Get-PyVenvConfig parses the values from the pyvenv.cfg file located in the
given folder, and returns them in a map.
For each line in the pyvenv.cfg file, if that line can be parsed into exactly
two strings separated by `=` (with any amount of whitespace surrounding the =)
then it is considered a `key = value` line. The left hand string is the key,
the right hand is the value.
If the value starts with a `'` or a `"` then the first and last character is
stripped from the value before being captured.
.Parameter ConfigDir
Path to the directory that contains the `pyvenv.cfg` file.
#>
function Get-PyVenvConfig(
[String]
$ConfigDir
) {
Write-Verbose "Given ConfigDir=$ConfigDir, obtain values in pyvenv.cfg"
# Ensure the file exists, and issue a warning if it doesn't (but still allow the function to continue).
$pyvenvConfigPath = Join-Path -Resolve -Path $ConfigDir -ChildPath 'pyvenv.cfg' -ErrorAction Continue
# An empty map will be returned if no config file is found.
$pyvenvConfig = @{ }
if ($pyvenvConfigPath) {
Write-Verbose "File exists, parse `key = value` lines"
$pyvenvConfigContent = Get-Content -Path $pyvenvConfigPath
$pyvenvConfigContent | ForEach-Object {
$keyval = $PSItem -split "\s*=\s*", 2
if ($keyval[0] -and $keyval[1]) {
$val = $keyval[1]
# Remove extraneous quotations around a string value.
if ("'""".Contains($val.Substring(0, 1))) {
$val = $val.Substring(1, $val.Length - 2)
}
$pyvenvConfig[$keyval[0]] = $val
Write-Verbose "Adding Key: '$($keyval[0])'='$val'"
}
}
}
return $pyvenvConfig
}
<# Begin Activate script --------------------------------------------------- #>
# Determine the containing directory of this script
$VenvExecPath = Split-Path -Parent $MyInvocation.MyCommand.Definition
$VenvExecDir = Get-Item -Path $VenvExecPath
Write-Verbose "Activation script is located in path: '$VenvExecPath'"
Write-Verbose "VenvExecDir Fullname: '$($VenvExecDir.FullName)"
Write-Verbose "VenvExecDir Name: '$($VenvExecDir.Name)"
# Set values required in priority: CmdLine, ConfigFile, Default
# First, get the location of the virtual environment, it might not be
# VenvExecDir if specified on the command line.
if ($VenvDir) {
Write-Verbose "VenvDir given as parameter, using '$VenvDir' to determine values"
}
else {
Write-Verbose "VenvDir not given as a parameter, using parent directory name as VenvDir."
$VenvDir = $VenvExecDir.Parent.FullName.TrimEnd("\\/")
Write-Verbose "VenvDir=$VenvDir"
}
# Next, read the `pyvenv.cfg` file to determine any required value such
# as `prompt`.
$pyvenvCfg = Get-PyVenvConfig -ConfigDir $VenvDir
# Next, set the prompt from the command line, or the config file, or
# just use the name of the virtual environment folder.
if ($Prompt) {
Write-Verbose "Prompt specified as argument, using '$Prompt'"
}
else {
Write-Verbose "Prompt not specified as argument to script, checking pyvenv.cfg value"
if ($pyvenvCfg -and $pyvenvCfg['prompt']) {
Write-Verbose " Setting based on value in pyvenv.cfg='$($pyvenvCfg['prompt'])'"
$Prompt = $pyvenvCfg['prompt'];
}
else {
Write-Verbose " Setting prompt based on parent's directory's name. (Is the directory name passed to venv module when creating the virtual environment)"
Write-Verbose " Got leaf-name of $VenvDir='$(Split-Path -Path $venvDir -Leaf)'"
$Prompt = Split-Path -Path $venvDir -Leaf
}
}
Write-Verbose "Prompt = '$Prompt'"
Write-Verbose "VenvDir='$VenvDir'"
# Deactivate any currently active virtual environment, but leave the
# deactivate function in place.
deactivate -nondestructive
# Now set the environment variable VIRTUAL_ENV, used by many tools to determine
# that there is an activated venv.
$env:VIRTUAL_ENV = $VenvDir
if (-not $Env:VIRTUAL_ENV_DISABLE_PROMPT) {
Write-Verbose "Setting prompt to '$Prompt'"
# Set the prompt to include the env name
# Make sure _OLD_VIRTUAL_PROMPT is global
function global:_OLD_VIRTUAL_PROMPT { "" }
Copy-Item -Path function:prompt -Destination function:_OLD_VIRTUAL_PROMPT
New-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Description "Python virtual environment prompt prefix" -Scope Global -Option ReadOnly -Visibility Public -Value $Prompt
function global:prompt {
Write-Host -NoNewline -ForegroundColor Green "($_PYTHON_VENV_PROMPT_PREFIX) "
_OLD_VIRTUAL_PROMPT
}
$env:VIRTUAL_ENV_PROMPT = $Prompt
}
# Clear PYTHONHOME
if (Test-Path -Path Env:PYTHONHOME) {
Copy-Item -Path Env:PYTHONHOME -Destination Env:_OLD_VIRTUAL_PYTHONHOME
Remove-Item -Path Env:PYTHONHOME
}
# Add the venv to the PATH
Copy-Item -Path Env:PATH -Destination Env:_OLD_VIRTUAL_PATH
$Env:PATH = "$VenvExecDir$([System.IO.Path]::PathSeparator)$Env:PATH"

70
.venv/bin/activate Normal file
View File

@@ -0,0 +1,70 @@
# This file must be used with "source bin/activate" *from bash*
# You cannot run it directly
deactivate () {
# reset old environment variables
if [ -n "${_OLD_VIRTUAL_PATH:-}" ] ; then
PATH="${_OLD_VIRTUAL_PATH:-}"
export PATH
unset _OLD_VIRTUAL_PATH
fi
if [ -n "${_OLD_VIRTUAL_PYTHONHOME:-}" ] ; then
PYTHONHOME="${_OLD_VIRTUAL_PYTHONHOME:-}"
export PYTHONHOME
unset _OLD_VIRTUAL_PYTHONHOME
fi
# Call hash to forget past commands. Without forgetting
# past commands the $PATH changes we made may not be respected
hash -r 2> /dev/null
if [ -n "${_OLD_VIRTUAL_PS1:-}" ] ; then
PS1="${_OLD_VIRTUAL_PS1:-}"
export PS1
unset _OLD_VIRTUAL_PS1
fi
unset VIRTUAL_ENV
unset VIRTUAL_ENV_PROMPT
if [ ! "${1:-}" = "nondestructive" ] ; then
# Self destruct!
unset -f deactivate
fi
}
# unset irrelevant variables
deactivate nondestructive
# on Windows, a path can contain colons and backslashes and has to be converted:
if [ "${OSTYPE:-}" = "cygwin" ] || [ "${OSTYPE:-}" = "msys" ] ; then
# transform D:\path\to\venv to /d/path/to/venv on MSYS
# and to /cygdrive/d/path/to/venv on Cygwin
export VIRTUAL_ENV=$(cygpath /home/data/post_bot/.venv)
else
# use the path as-is
export VIRTUAL_ENV=/home/data/post_bot/.venv
fi
_OLD_VIRTUAL_PATH="$PATH"
PATH="$VIRTUAL_ENV/"bin":$PATH"
export PATH
# unset PYTHONHOME if set
# this will fail if PYTHONHOME is set to the empty string (which is bad anyway)
# could use `if (set -u; : $PYTHONHOME) ;` in bash
if [ -n "${PYTHONHOME:-}" ] ; then
_OLD_VIRTUAL_PYTHONHOME="${PYTHONHOME:-}"
unset PYTHONHOME
fi
if [ -z "${VIRTUAL_ENV_DISABLE_PROMPT:-}" ] ; then
_OLD_VIRTUAL_PS1="${PS1:-}"
PS1='(.venv) '"${PS1:-}"
export PS1
VIRTUAL_ENV_PROMPT='(.venv) '
export VIRTUAL_ENV_PROMPT
fi
# Call hash to forget past commands. Without forgetting
# past commands the $PATH changes we made may not be respected
hash -r 2> /dev/null

27
.venv/bin/activate.csh Normal file
View File

@@ -0,0 +1,27 @@
# This file must be used with "source bin/activate.csh" *from csh*.
# You cannot run it directly.
# Created by Davide Di Blasi <davidedb@gmail.com>.
# Ported to Python 3.3 venv by Andrew Svetlov <andrew.svetlov@gmail.com>
alias deactivate 'test $?_OLD_VIRTUAL_PATH != 0 && setenv PATH "$_OLD_VIRTUAL_PATH" && unset _OLD_VIRTUAL_PATH; rehash; test $?_OLD_VIRTUAL_PROMPT != 0 && set prompt="$_OLD_VIRTUAL_PROMPT" && unset _OLD_VIRTUAL_PROMPT; unsetenv VIRTUAL_ENV; unsetenv VIRTUAL_ENV_PROMPT; test "\!:*" != "nondestructive" && unalias deactivate'
# Unset irrelevant variables.
deactivate nondestructive
setenv VIRTUAL_ENV /home/data/post_bot/.venv
set _OLD_VIRTUAL_PATH="$PATH"
setenv PATH "$VIRTUAL_ENV/"bin":$PATH"
set _OLD_VIRTUAL_PROMPT="$prompt"
if (! "$?VIRTUAL_ENV_DISABLE_PROMPT") then
set prompt = '(.venv) '"$prompt"
setenv VIRTUAL_ENV_PROMPT '(.venv) '
endif
alias pydoc python -m pydoc
rehash

69
.venv/bin/activate.fish Normal file
View File

@@ -0,0 +1,69 @@
# This file must be used with "source <venv>/bin/activate.fish" *from fish*
# (https://fishshell.com/). You cannot run it directly.
function deactivate -d "Exit virtual environment and return to normal shell environment"
# reset old environment variables
if test -n "$_OLD_VIRTUAL_PATH"
set -gx PATH $_OLD_VIRTUAL_PATH
set -e _OLD_VIRTUAL_PATH
end
if test -n "$_OLD_VIRTUAL_PYTHONHOME"
set -gx PYTHONHOME $_OLD_VIRTUAL_PYTHONHOME
set -e _OLD_VIRTUAL_PYTHONHOME
end
if test -n "$_OLD_FISH_PROMPT_OVERRIDE"
set -e _OLD_FISH_PROMPT_OVERRIDE
# prevents error when using nested fish instances (Issue #93858)
if functions -q _old_fish_prompt
functions -e fish_prompt
functions -c _old_fish_prompt fish_prompt
functions -e _old_fish_prompt
end
end
set -e VIRTUAL_ENV
set -e VIRTUAL_ENV_PROMPT
if test "$argv[1]" != "nondestructive"
# Self-destruct!
functions -e deactivate
end
end
# Unset irrelevant variables.
deactivate nondestructive
set -gx VIRTUAL_ENV /home/data/post_bot/.venv
set -gx _OLD_VIRTUAL_PATH $PATH
set -gx PATH "$VIRTUAL_ENV/"bin $PATH
# Unset PYTHONHOME if set.
if set -q PYTHONHOME
set -gx _OLD_VIRTUAL_PYTHONHOME $PYTHONHOME
set -e PYTHONHOME
end
if test -z "$VIRTUAL_ENV_DISABLE_PROMPT"
# fish uses a function instead of an env var to generate the prompt.
# Save the current fish_prompt function as the function _old_fish_prompt.
functions -c fish_prompt _old_fish_prompt
# With the original prompt function renamed, we can override with our own.
function fish_prompt
# Save the return status of the last command.
set -l old_status $status
# Output the venv prompt; color taken from the blue of the Python logo.
printf "%s%s%s" (set_color 4B8BBE) '(.venv) ' (set_color normal)
# Restore the return status of the previous command.
echo "exit $old_status" | .
# Output the original/"old" prompt.
_old_fish_prompt
end
set -gx _OLD_FISH_PROMPT_OVERRIDE "$VIRTUAL_ENV"
set -gx VIRTUAL_ENV_PROMPT '(.venv) '
end

8
.venv/bin/dotenv Executable file
View File

@@ -0,0 +1,8 @@
#!/home/data/post_bot/.venv/bin/python3
# -*- coding: utf-8 -*-
import re
import sys
from dotenv.__main__ import cli
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(cli())

8
.venv/bin/httpx Executable file
View File

@@ -0,0 +1,8 @@
#!/home/data/post_bot/.venv/bin/python3
# -*- coding: utf-8 -*-
import re
import sys
from httpx import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

8
.venv/bin/pip Executable file
View File

@@ -0,0 +1,8 @@
#!/home/data/post_bot/.venv/bin/python3
# -*- coding: utf-8 -*-
import re
import sys
from pip._internal.cli.main import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

8
.venv/bin/pip3 Executable file
View File

@@ -0,0 +1,8 @@
#!/home/data/post_bot/.venv/bin/python3
# -*- coding: utf-8 -*-
import re
import sys
from pip._internal.cli.main import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

8
.venv/bin/pip3.12 Executable file
View File

@@ -0,0 +1,8 @@
#!/home/data/post_bot/.venv/bin/python3
# -*- coding: utf-8 -*-
import re
import sys
from pip._internal.cli.main import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

8
.venv/bin/py.test Executable file
View File

@@ -0,0 +1,8 @@
#!/home/data/post_bot/.venv/bin/python3
# -*- coding: utf-8 -*-
import re
import sys
from pytest import console_main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(console_main())

8
.venv/bin/pygmentize Executable file
View File

@@ -0,0 +1,8 @@
#!/home/data/post_bot/.venv/bin/python3
# -*- coding: utf-8 -*-
import re
import sys
from pygments.cmdline import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

8
.venv/bin/pytest Executable file
View File

@@ -0,0 +1,8 @@
#!/home/data/post_bot/.venv/bin/python3
# -*- coding: utf-8 -*-
import re
import sys
from pytest import console_main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(console_main())

1
.venv/bin/python Symbolic link
View File

@@ -0,0 +1 @@
python3

1
.venv/bin/python3 Symbolic link
View File

@@ -0,0 +1 @@
/usr/bin/python3

1
.venv/bin/python3.12 Symbolic link
View File

@@ -0,0 +1 @@
python3

View File

@@ -0,0 +1,164 @@
/* -*- indent-tabs-mode: nil; tab-width: 4; -*- */
/* Greenlet object interface */
#ifndef Py_GREENLETOBJECT_H
#define Py_GREENLETOBJECT_H
#include <Python.h>
#ifdef __cplusplus
extern "C" {
#endif
/* This is deprecated and undocumented. It does not change. */
#define GREENLET_VERSION "1.0.0"
#ifndef GREENLET_MODULE
#define implementation_ptr_t void*
#endif
typedef struct _greenlet {
PyObject_HEAD
PyObject* weakreflist;
PyObject* dict;
implementation_ptr_t pimpl;
} PyGreenlet;
#define PyGreenlet_Check(op) (op && PyObject_TypeCheck(op, &PyGreenlet_Type))
/* C API functions */
/* Total number of symbols that are exported */
#define PyGreenlet_API_pointers 12
#define PyGreenlet_Type_NUM 0
#define PyExc_GreenletError_NUM 1
#define PyExc_GreenletExit_NUM 2
#define PyGreenlet_New_NUM 3
#define PyGreenlet_GetCurrent_NUM 4
#define PyGreenlet_Throw_NUM 5
#define PyGreenlet_Switch_NUM 6
#define PyGreenlet_SetParent_NUM 7
#define PyGreenlet_MAIN_NUM 8
#define PyGreenlet_STARTED_NUM 9
#define PyGreenlet_ACTIVE_NUM 10
#define PyGreenlet_GET_PARENT_NUM 11
#ifndef GREENLET_MODULE
/* This section is used by modules that uses the greenlet C API */
static void** _PyGreenlet_API = NULL;
# define PyGreenlet_Type \
(*(PyTypeObject*)_PyGreenlet_API[PyGreenlet_Type_NUM])
# define PyExc_GreenletError \
((PyObject*)_PyGreenlet_API[PyExc_GreenletError_NUM])
# define PyExc_GreenletExit \
((PyObject*)_PyGreenlet_API[PyExc_GreenletExit_NUM])
/*
* PyGreenlet_New(PyObject *args)
*
* greenlet.greenlet(run, parent=None)
*/
# define PyGreenlet_New \
(*(PyGreenlet * (*)(PyObject * run, PyGreenlet * parent)) \
_PyGreenlet_API[PyGreenlet_New_NUM])
/*
* PyGreenlet_GetCurrent(void)
*
* greenlet.getcurrent()
*/
# define PyGreenlet_GetCurrent \
(*(PyGreenlet * (*)(void)) _PyGreenlet_API[PyGreenlet_GetCurrent_NUM])
/*
* PyGreenlet_Throw(
* PyGreenlet *greenlet,
* PyObject *typ,
* PyObject *val,
* PyObject *tb)
*
* g.throw(...)
*/
# define PyGreenlet_Throw \
(*(PyObject * (*)(PyGreenlet * self, \
PyObject * typ, \
PyObject * val, \
PyObject * tb)) \
_PyGreenlet_API[PyGreenlet_Throw_NUM])
/*
* PyGreenlet_Switch(PyGreenlet *greenlet, PyObject *args)
*
* g.switch(*args, **kwargs)
*/
# define PyGreenlet_Switch \
(*(PyObject * \
(*)(PyGreenlet * greenlet, PyObject * args, PyObject * kwargs)) \
_PyGreenlet_API[PyGreenlet_Switch_NUM])
/*
* PyGreenlet_SetParent(PyObject *greenlet, PyObject *new_parent)
*
* g.parent = new_parent
*/
# define PyGreenlet_SetParent \
(*(int (*)(PyGreenlet * greenlet, PyGreenlet * nparent)) \
_PyGreenlet_API[PyGreenlet_SetParent_NUM])
/*
* PyGreenlet_GetParent(PyObject* greenlet)
*
* return greenlet.parent;
*
* This could return NULL even if there is no exception active.
* If it does not return NULL, you are responsible for decrementing the
* reference count.
*/
# define PyGreenlet_GetParent \
(*(PyGreenlet* (*)(PyGreenlet*)) \
_PyGreenlet_API[PyGreenlet_GET_PARENT_NUM])
/*
* deprecated, undocumented alias.
*/
# define PyGreenlet_GET_PARENT PyGreenlet_GetParent
# define PyGreenlet_MAIN \
(*(int (*)(PyGreenlet*)) \
_PyGreenlet_API[PyGreenlet_MAIN_NUM])
# define PyGreenlet_STARTED \
(*(int (*)(PyGreenlet*)) \
_PyGreenlet_API[PyGreenlet_STARTED_NUM])
# define PyGreenlet_ACTIVE \
(*(int (*)(PyGreenlet*)) \
_PyGreenlet_API[PyGreenlet_ACTIVE_NUM])
/* Macro that imports greenlet and initializes C API */
/* NOTE: This has actually moved to ``greenlet._greenlet._C_API``, but we
keep the older definition to be sure older code that might have a copy of
the header still works. */
# define PyGreenlet_Import() \
{ \
_PyGreenlet_API = (void**)PyCapsule_Import("greenlet._C_API", 0); \
}
#endif /* GREENLET_MODULE */
#ifdef __cplusplus
}
#endif
#endif /* !Py_GREENLETOBJECT_H */

View File

@@ -0,0 +1,13 @@
from __future__ import annotations
__all__ = ["__version__", "version_tuple"]
try:
from ._version import version as __version__
from ._version import version_tuple
except ImportError: # pragma: no cover
# broken installation, we don't even try
# unknown only works because we do poor mans version compare
__version__ = "unknown"
version_tuple = (0, 0, "unknown")

View File

@@ -0,0 +1,117 @@
"""Allow bash-completion for argparse with argcomplete if installed.
Needs argcomplete>=0.5.6 for python 3.2/3.3 (older versions fail
to find the magic string, so _ARGCOMPLETE env. var is never set, and
this does not need special code).
Function try_argcomplete(parser) should be called directly before
the call to ArgumentParser.parse_args().
The filescompleter is what you normally would use on the positional
arguments specification, in order to get "dirname/" after "dirn<TAB>"
instead of the default "dirname ":
optparser.add_argument(Config._file_or_dir, nargs='*').completer=filescompleter
Other, application specific, completers should go in the file
doing the add_argument calls as they need to be specified as .completer
attributes as well. (If argcomplete is not installed, the function the
attribute points to will not be used).
SPEEDUP
=======
The generic argcomplete script for bash-completion
(/etc/bash_completion.d/python-argcomplete.sh)
uses a python program to determine startup script generated by pip.
You can speed up completion somewhat by changing this script to include
# PYTHON_ARGCOMPLETE_OK
so the python-argcomplete-check-easy-install-script does not
need to be called to find the entry point of the code and see if that is
marked with PYTHON_ARGCOMPLETE_OK.
INSTALL/DEBUGGING
=================
To include this support in another application that has setup.py generated
scripts:
- Add the line:
# PYTHON_ARGCOMPLETE_OK
near the top of the main python entry point.
- Include in the file calling parse_args():
from _argcomplete import try_argcomplete, filescompleter
Call try_argcomplete just before parse_args(), and optionally add
filescompleter to the positional arguments' add_argument().
If things do not work right away:
- Switch on argcomplete debugging with (also helpful when doing custom
completers):
export _ARC_DEBUG=1
- Run:
python-argcomplete-check-easy-install-script $(which appname)
echo $?
will echo 0 if the magic line has been found, 1 if not.
- Sometimes it helps to find early on errors using:
_ARGCOMPLETE=1 _ARC_DEBUG=1 appname
which should throw a KeyError: 'COMPLINE' (which is properly set by the
global argcomplete script).
"""
from __future__ import annotations
import argparse
from glob import glob
import os
import sys
from typing import Any
class FastFilesCompleter:
"""Fast file completer class."""
def __init__(self, directories: bool = True) -> None:
self.directories = directories
def __call__(self, prefix: str, **kwargs: Any) -> list[str]:
# Only called on non option completions.
if os.sep in prefix[1:]:
prefix_dir = len(os.path.dirname(prefix) + os.sep)
else:
prefix_dir = 0
completion = []
globbed = []
if "*" not in prefix and "?" not in prefix:
# We are on unix, otherwise no bash.
if not prefix or prefix[-1] == os.sep:
globbed.extend(glob(prefix + ".*"))
prefix += "*"
globbed.extend(glob(prefix))
for x in sorted(globbed):
if os.path.isdir(x):
x += "/"
# Append stripping the prefix (like bash, not like compgen).
completion.append(x[prefix_dir:])
return completion
if os.environ.get("_ARGCOMPLETE"):
try:
import argcomplete.completers
except ImportError:
sys.exit(-1)
filescompleter: FastFilesCompleter | None = FastFilesCompleter()
def try_argcomplete(parser: argparse.ArgumentParser) -> None:
argcomplete.autocomplete(parser, always_complete_options=False)
else:
def try_argcomplete(parser: argparse.ArgumentParser) -> None:
pass
filescompleter = None

View File

@@ -0,0 +1,26 @@
"""Python inspection/code generation API."""
from __future__ import annotations
from .code import Code
from .code import ExceptionInfo
from .code import filter_traceback
from .code import Frame
from .code import getfslineno
from .code import Traceback
from .code import TracebackEntry
from .source import getrawcode
from .source import Source
__all__ = [
"Code",
"ExceptionInfo",
"Frame",
"Source",
"Traceback",
"TracebackEntry",
"filter_traceback",
"getfslineno",
"getrawcode",
]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,225 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import ast
from bisect import bisect_right
from collections.abc import Iterable
from collections.abc import Iterator
import inspect
import textwrap
import tokenize
import types
from typing import overload
import warnings
class Source:
"""An immutable object holding a source code fragment.
When using Source(...), the source lines are deindented.
"""
def __init__(self, obj: object = None) -> None:
if not obj:
self.lines: list[str] = []
self.raw_lines: list[str] = []
elif isinstance(obj, Source):
self.lines = obj.lines
self.raw_lines = obj.raw_lines
elif isinstance(obj, (tuple, list)):
self.lines = deindent(x.rstrip("\n") for x in obj)
self.raw_lines = list(x.rstrip("\n") for x in obj)
elif isinstance(obj, str):
self.lines = deindent(obj.split("\n"))
self.raw_lines = obj.split("\n")
else:
try:
rawcode = getrawcode(obj)
src = inspect.getsource(rawcode)
except TypeError:
src = inspect.getsource(obj) # type: ignore[arg-type]
self.lines = deindent(src.split("\n"))
self.raw_lines = src.split("\n")
def __eq__(self, other: object) -> bool:
if not isinstance(other, Source):
return NotImplemented
return self.lines == other.lines
# Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore
@overload
def __getitem__(self, key: int) -> str: ...
@overload
def __getitem__(self, key: slice) -> Source: ...
def __getitem__(self, key: int | slice) -> str | Source:
if isinstance(key, int):
return self.lines[key]
else:
if key.step not in (None, 1):
raise IndexError("cannot slice a Source with a step")
newsource = Source()
newsource.lines = self.lines[key.start : key.stop]
newsource.raw_lines = self.raw_lines[key.start : key.stop]
return newsource
def __iter__(self) -> Iterator[str]:
return iter(self.lines)
def __len__(self) -> int:
return len(self.lines)
def strip(self) -> Source:
"""Return new Source object with trailing and leading blank lines removed."""
start, end = 0, len(self)
while start < end and not self.lines[start].strip():
start += 1
while end > start and not self.lines[end - 1].strip():
end -= 1
source = Source()
source.raw_lines = self.raw_lines
source.lines[:] = self.lines[start:end]
return source
def indent(self, indent: str = " " * 4) -> Source:
"""Return a copy of the source object with all lines indented by the
given indent-string."""
newsource = Source()
newsource.raw_lines = self.raw_lines
newsource.lines = [(indent + line) for line in self.lines]
return newsource
def getstatement(self, lineno: int) -> Source:
"""Return Source statement which contains the given linenumber
(counted from 0)."""
start, end = self.getstatementrange(lineno)
return self[start:end]
def getstatementrange(self, lineno: int) -> tuple[int, int]:
"""Return (start, end) tuple which spans the minimal statement region
which containing the given lineno."""
if not (0 <= lineno < len(self)):
raise IndexError("lineno out of range")
ast, start, end = getstatementrange_ast(lineno, self)
return start, end
def deindent(self) -> Source:
"""Return a new Source object deindented."""
newsource = Source()
newsource.lines[:] = deindent(self.lines)
newsource.raw_lines = self.raw_lines
return newsource
def __str__(self) -> str:
return "\n".join(self.lines)
#
# helper functions
#
def findsource(obj) -> tuple[Source | None, int]:
try:
sourcelines, lineno = inspect.findsource(obj)
except Exception:
return None, -1
source = Source()
source.lines = [line.rstrip() for line in sourcelines]
source.raw_lines = sourcelines
return source, lineno
def getrawcode(obj: object, trycall: bool = True) -> types.CodeType:
"""Return code object for given function."""
try:
return obj.__code__ # type: ignore[attr-defined,no-any-return]
except AttributeError:
pass
if trycall:
call = getattr(obj, "__call__", None)
if call and not isinstance(obj, type):
return getrawcode(call, trycall=False)
raise TypeError(f"could not get code object for {obj!r}")
def deindent(lines: Iterable[str]) -> list[str]:
return textwrap.dedent("\n".join(lines)).splitlines()
def get_statement_startend2(lineno: int, node: ast.AST) -> tuple[int, int | None]:
# Flatten all statements and except handlers into one lineno-list.
# AST's line numbers start indexing at 1.
values: list[int] = []
for x in ast.walk(node):
if isinstance(x, (ast.stmt, ast.ExceptHandler)):
# The lineno points to the class/def, so need to include the decorators.
if isinstance(x, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
for d in x.decorator_list:
values.append(d.lineno - 1)
values.append(x.lineno - 1)
for name in ("finalbody", "orelse"):
val: list[ast.stmt] | None = getattr(x, name, None)
if val:
# Treat the finally/orelse part as its own statement.
values.append(val[0].lineno - 1 - 1)
values.sort()
insert_index = bisect_right(values, lineno)
start = values[insert_index - 1]
if insert_index >= len(values):
end = None
else:
end = values[insert_index]
return start, end
def getstatementrange_ast(
lineno: int,
source: Source,
assertion: bool = False,
astnode: ast.AST | None = None,
) -> tuple[ast.AST, int, int]:
if astnode is None:
content = str(source)
# See #4260:
# Don't produce duplicate warnings when compiling source to find AST.
with warnings.catch_warnings():
warnings.simplefilter("ignore")
astnode = ast.parse(content, "source", "exec")
start, end = get_statement_startend2(lineno, astnode)
# We need to correct the end:
# - ast-parsing strips comments
# - there might be empty lines
# - we might have lesser indented code blocks at the end
if end is None:
end = len(source.lines)
if end > start + 1:
# Make sure we don't span differently indented code blocks
# by using the BlockFinder helper used which inspect.getsource() uses itself.
block_finder = inspect.BlockFinder()
# If we start with an indented line, put blockfinder to "started" mode.
block_finder.started = (
bool(source.lines[start]) and source.lines[start][0].isspace()
)
it = ((x + "\n") for x in source.lines[start:end])
try:
for tok in tokenize.generate_tokens(lambda: next(it)):
block_finder.tokeneater(*tok)
except (inspect.EndOfBlock, IndentationError):
end = block_finder.last + start
except Exception:
pass
# The end might still point to a comment or empty line, correct it.
while end:
line = source.lines[end - 1].lstrip()
if line.startswith("#") or not line:
end -= 1
else:
break
return astnode, start, end

View File

@@ -0,0 +1,10 @@
from __future__ import annotations
from .terminalwriter import get_terminal_width
from .terminalwriter import TerminalWriter
__all__ = [
"TerminalWriter",
"get_terminal_width",
]

View File

@@ -0,0 +1,673 @@
# mypy: allow-untyped-defs
# This module was imported from the cpython standard library
# (https://github.com/python/cpython/) at commit
# c5140945c723ae6c4b7ee81ff720ac8ea4b52cfd (python3.12).
#
#
# Original Author: Fred L. Drake, Jr.
# fdrake@acm.org
#
# This is a simple little module I wrote to make life easier. I didn't
# see anything quite like it in the library, though I may have overlooked
# something. I wrote this when I was trying to read some heavily nested
# tuples with fairly non-descriptive content. This is modeled very much
# after Lisp/Scheme - style pretty-printing of lists. If you find it
# useful, thank small children who sleep at night.
from __future__ import annotations
import collections as _collections
from collections.abc import Callable
from collections.abc import Iterator
import dataclasses as _dataclasses
from io import StringIO as _StringIO
import re
import types as _types
from typing import Any
from typing import IO
class _safe_key:
"""Helper function for key functions when sorting unorderable objects.
The wrapped-object will fallback to a Py2.x style comparison for
unorderable types (sorting first comparing the type name and then by
the obj ids). Does not work recursively, so dict.items() must have
_safe_key applied to both the key and the value.
"""
__slots__ = ["obj"]
def __init__(self, obj):
self.obj = obj
def __lt__(self, other):
try:
return self.obj < other.obj
except TypeError:
return (str(type(self.obj)), id(self.obj)) < (
str(type(other.obj)),
id(other.obj),
)
def _safe_tuple(t):
"""Helper function for comparing 2-tuples"""
return _safe_key(t[0]), _safe_key(t[1])
class PrettyPrinter:
def __init__(
self,
indent: int = 4,
width: int = 80,
depth: int | None = None,
) -> None:
"""Handle pretty printing operations onto a stream using a set of
configured parameters.
indent
Number of spaces to indent for each level of nesting.
width
Attempted maximum number of columns in the output.
depth
The maximum depth to print out nested structures.
"""
if indent < 0:
raise ValueError("indent must be >= 0")
if depth is not None and depth <= 0:
raise ValueError("depth must be > 0")
if not width:
raise ValueError("width must be != 0")
self._depth = depth
self._indent_per_level = indent
self._width = width
def pformat(self, object: Any) -> str:
sio = _StringIO()
self._format(object, sio, 0, 0, set(), 0)
return sio.getvalue()
def _format(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
objid = id(object)
if objid in context:
stream.write(_recursion(object))
return
p = self._dispatch.get(type(object).__repr__, None)
if p is not None:
context.add(objid)
p(self, object, stream, indent, allowance, context, level + 1)
context.remove(objid)
elif (
_dataclasses.is_dataclass(object)
and not isinstance(object, type)
and object.__dataclass_params__.repr # type:ignore[attr-defined]
and
# Check dataclass has generated repr method.
hasattr(object.__repr__, "__wrapped__")
and "__create_fn__" in object.__repr__.__wrapped__.__qualname__
):
context.add(objid)
self._pprint_dataclass(
object, stream, indent, allowance, context, level + 1
)
context.remove(objid)
else:
stream.write(self._repr(object, context, level))
def _pprint_dataclass(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
cls_name = object.__class__.__name__
items = [
(f.name, getattr(object, f.name))
for f in _dataclasses.fields(object)
if f.repr
]
stream.write(cls_name + "(")
self._format_namespace_items(items, stream, indent, allowance, context, level)
stream.write(")")
_dispatch: dict[
Callable[..., str],
Callable[[PrettyPrinter, Any, IO[str], int, int, set[int], int], None],
] = {}
def _pprint_dict(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
write = stream.write
write("{")
items = sorted(object.items(), key=_safe_tuple)
self._format_dict_items(items, stream, indent, allowance, context, level)
write("}")
_dispatch[dict.__repr__] = _pprint_dict
def _pprint_ordered_dict(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
if not len(object):
stream.write(repr(object))
return
cls = object.__class__
stream.write(cls.__name__ + "(")
self._pprint_dict(object, stream, indent, allowance, context, level)
stream.write(")")
_dispatch[_collections.OrderedDict.__repr__] = _pprint_ordered_dict
def _pprint_list(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
stream.write("[")
self._format_items(object, stream, indent, allowance, context, level)
stream.write("]")
_dispatch[list.__repr__] = _pprint_list
def _pprint_tuple(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
stream.write("(")
self._format_items(object, stream, indent, allowance, context, level)
stream.write(")")
_dispatch[tuple.__repr__] = _pprint_tuple
def _pprint_set(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
if not len(object):
stream.write(repr(object))
return
typ = object.__class__
if typ is set:
stream.write("{")
endchar = "}"
else:
stream.write(typ.__name__ + "({")
endchar = "})"
object = sorted(object, key=_safe_key)
self._format_items(object, stream, indent, allowance, context, level)
stream.write(endchar)
_dispatch[set.__repr__] = _pprint_set
_dispatch[frozenset.__repr__] = _pprint_set
def _pprint_str(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
write = stream.write
if not len(object):
write(repr(object))
return
chunks = []
lines = object.splitlines(True)
if level == 1:
indent += 1
allowance += 1
max_width1 = max_width = self._width - indent
for i, line in enumerate(lines):
rep = repr(line)
if i == len(lines) - 1:
max_width1 -= allowance
if len(rep) <= max_width1:
chunks.append(rep)
else:
# A list of alternating (non-space, space) strings
parts = re.findall(r"\S*\s*", line)
assert parts
assert not parts[-1]
parts.pop() # drop empty last part
max_width2 = max_width
current = ""
for j, part in enumerate(parts):
candidate = current + part
if j == len(parts) - 1 and i == len(lines) - 1:
max_width2 -= allowance
if len(repr(candidate)) > max_width2:
if current:
chunks.append(repr(current))
current = part
else:
current = candidate
if current:
chunks.append(repr(current))
if len(chunks) == 1:
write(rep)
return
if level == 1:
write("(")
for i, rep in enumerate(chunks):
if i > 0:
write("\n" + " " * indent)
write(rep)
if level == 1:
write(")")
_dispatch[str.__repr__] = _pprint_str
def _pprint_bytes(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
write = stream.write
if len(object) <= 4:
write(repr(object))
return
parens = level == 1
if parens:
indent += 1
allowance += 1
write("(")
delim = ""
for rep in _wrap_bytes_repr(object, self._width - indent, allowance):
write(delim)
write(rep)
if not delim:
delim = "\n" + " " * indent
if parens:
write(")")
_dispatch[bytes.__repr__] = _pprint_bytes
def _pprint_bytearray(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
write = stream.write
write("bytearray(")
self._pprint_bytes(
bytes(object), stream, indent + 10, allowance + 1, context, level + 1
)
write(")")
_dispatch[bytearray.__repr__] = _pprint_bytearray
def _pprint_mappingproxy(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
stream.write("mappingproxy(")
self._format(object.copy(), stream, indent, allowance, context, level)
stream.write(")")
_dispatch[_types.MappingProxyType.__repr__] = _pprint_mappingproxy
def _pprint_simplenamespace(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
if type(object) is _types.SimpleNamespace:
# The SimpleNamespace repr is "namespace" instead of the class
# name, so we do the same here. For subclasses; use the class name.
cls_name = "namespace"
else:
cls_name = object.__class__.__name__
items = object.__dict__.items()
stream.write(cls_name + "(")
self._format_namespace_items(items, stream, indent, allowance, context, level)
stream.write(")")
_dispatch[_types.SimpleNamespace.__repr__] = _pprint_simplenamespace
def _format_dict_items(
self,
items: list[tuple[Any, Any]],
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
if not items:
return
write = stream.write
item_indent = indent + self._indent_per_level
delimnl = "\n" + " " * item_indent
for key, ent in items:
write(delimnl)
write(self._repr(key, context, level))
write(": ")
self._format(ent, stream, item_indent, 1, context, level)
write(",")
write("\n" + " " * indent)
def _format_namespace_items(
self,
items: list[tuple[Any, Any]],
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
if not items:
return
write = stream.write
item_indent = indent + self._indent_per_level
delimnl = "\n" + " " * item_indent
for key, ent in items:
write(delimnl)
write(key)
write("=")
if id(ent) in context:
# Special-case representation of recursion to match standard
# recursive dataclass repr.
write("...")
else:
self._format(
ent,
stream,
item_indent + len(key) + 1,
1,
context,
level,
)
write(",")
write("\n" + " " * indent)
def _format_items(
self,
items: list[Any],
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
if not items:
return
write = stream.write
item_indent = indent + self._indent_per_level
delimnl = "\n" + " " * item_indent
for item in items:
write(delimnl)
self._format(item, stream, item_indent, 1, context, level)
write(",")
write("\n" + " " * indent)
def _repr(self, object: Any, context: set[int], level: int) -> str:
return self._safe_repr(object, context.copy(), self._depth, level)
def _pprint_default_dict(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
rdf = self._repr(object.default_factory, context, level)
stream.write(f"{object.__class__.__name__}({rdf}, ")
self._pprint_dict(object, stream, indent, allowance, context, level)
stream.write(")")
_dispatch[_collections.defaultdict.__repr__] = _pprint_default_dict
def _pprint_counter(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
stream.write(object.__class__.__name__ + "(")
if object:
stream.write("{")
items = object.most_common()
self._format_dict_items(items, stream, indent, allowance, context, level)
stream.write("}")
stream.write(")")
_dispatch[_collections.Counter.__repr__] = _pprint_counter
def _pprint_chain_map(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
if not len(object.maps) or (len(object.maps) == 1 and not len(object.maps[0])):
stream.write(repr(object))
return
stream.write(object.__class__.__name__ + "(")
self._format_items(object.maps, stream, indent, allowance, context, level)
stream.write(")")
_dispatch[_collections.ChainMap.__repr__] = _pprint_chain_map
def _pprint_deque(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
stream.write(object.__class__.__name__ + "(")
if object.maxlen is not None:
stream.write(f"maxlen={object.maxlen}, ")
stream.write("[")
self._format_items(object, stream, indent, allowance + 1, context, level)
stream.write("])")
_dispatch[_collections.deque.__repr__] = _pprint_deque
def _pprint_user_dict(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1)
_dispatch[_collections.UserDict.__repr__] = _pprint_user_dict
def _pprint_user_list(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1)
_dispatch[_collections.UserList.__repr__] = _pprint_user_list
def _pprint_user_string(
self,
object: Any,
stream: IO[str],
indent: int,
allowance: int,
context: set[int],
level: int,
) -> None:
self._format(object.data, stream, indent, allowance, context, level - 1)
_dispatch[_collections.UserString.__repr__] = _pprint_user_string
def _safe_repr(
self, object: Any, context: set[int], maxlevels: int | None, level: int
) -> str:
typ = type(object)
if typ in _builtin_scalars:
return repr(object)
r = getattr(typ, "__repr__", None)
if issubclass(typ, dict) and r is dict.__repr__:
if not object:
return "{}"
objid = id(object)
if maxlevels and level >= maxlevels:
return "{...}"
if objid in context:
return _recursion(object)
context.add(objid)
components: list[str] = []
append = components.append
level += 1
for k, v in sorted(object.items(), key=_safe_tuple):
krepr = self._safe_repr(k, context, maxlevels, level)
vrepr = self._safe_repr(v, context, maxlevels, level)
append(f"{krepr}: {vrepr}")
context.remove(objid)
return "{{{}}}".format(", ".join(components))
if (issubclass(typ, list) and r is list.__repr__) or (
issubclass(typ, tuple) and r is tuple.__repr__
):
if issubclass(typ, list):
if not object:
return "[]"
format = "[%s]"
elif len(object) == 1:
format = "(%s,)"
else:
if not object:
return "()"
format = "(%s)"
objid = id(object)
if maxlevels and level >= maxlevels:
return format % "..."
if objid in context:
return _recursion(object)
context.add(objid)
components = []
append = components.append
level += 1
for o in object:
orepr = self._safe_repr(o, context, maxlevels, level)
append(orepr)
context.remove(objid)
return format % ", ".join(components)
return repr(object)
_builtin_scalars = frozenset(
{str, bytes, bytearray, float, complex, bool, type(None), int}
)
def _recursion(object: Any) -> str:
return f"<Recursion on {type(object).__name__} with id={id(object)}>"
def _wrap_bytes_repr(object: Any, width: int, allowance: int) -> Iterator[str]:
current = b""
last = len(object) // 4 * 4
for i in range(0, len(object), 4):
part = object[i : i + 4]
candidate = current + part
if i == last:
width -= allowance
if len(repr(candidate)) > width:
if current:
yield repr(current)
current = part
else:
current = candidate
if current:
yield repr(current)

View File

@@ -0,0 +1,130 @@
from __future__ import annotations
import pprint
import reprlib
def _try_repr_or_str(obj: object) -> str:
try:
return repr(obj)
except (KeyboardInterrupt, SystemExit):
raise
except BaseException:
return f'{type(obj).__name__}("{obj}")'
def _format_repr_exception(exc: BaseException, obj: object) -> str:
try:
exc_info = _try_repr_or_str(exc)
except (KeyboardInterrupt, SystemExit):
raise
except BaseException as inner_exc:
exc_info = f"unpresentable exception ({_try_repr_or_str(inner_exc)})"
return (
f"<[{exc_info} raised in repr()] {type(obj).__name__} object at 0x{id(obj):x}>"
)
def _ellipsize(s: str, maxsize: int) -> str:
if len(s) > maxsize:
i = max(0, (maxsize - 3) // 2)
j = max(0, maxsize - 3 - i)
return s[:i] + "..." + s[len(s) - j :]
return s
class SafeRepr(reprlib.Repr):
"""
repr.Repr that limits the resulting size of repr() and includes
information on exceptions raised during the call.
"""
def __init__(self, maxsize: int | None, use_ascii: bool = False) -> None:
"""
:param maxsize:
If not None, will truncate the resulting repr to that specific size, using ellipsis
somewhere in the middle to hide the extra text.
If None, will not impose any size limits on the returning repr.
"""
super().__init__()
# ``maxstring`` is used by the superclass, and needs to be an int; using a
# very large number in case maxsize is None, meaning we want to disable
# truncation.
self.maxstring = maxsize if maxsize is not None else 1_000_000_000
self.maxsize = maxsize
self.use_ascii = use_ascii
def repr(self, x: object) -> str:
try:
if self.use_ascii:
s = ascii(x)
else:
s = super().repr(x)
except (KeyboardInterrupt, SystemExit):
raise
except BaseException as exc:
s = _format_repr_exception(exc, x)
if self.maxsize is not None:
s = _ellipsize(s, self.maxsize)
return s
def repr_instance(self, x: object, level: int) -> str:
try:
s = repr(x)
except (KeyboardInterrupt, SystemExit):
raise
except BaseException as exc:
s = _format_repr_exception(exc, x)
if self.maxsize is not None:
s = _ellipsize(s, self.maxsize)
return s
def safeformat(obj: object) -> str:
"""Return a pretty printed string for the given object.
Failing __repr__ functions of user instances will be represented
with a short exception info.
"""
try:
return pprint.pformat(obj)
except Exception as exc:
return _format_repr_exception(exc, obj)
# Maximum size of overall repr of objects to display during assertion errors.
DEFAULT_REPR_MAX_SIZE = 240
def saferepr(
obj: object, maxsize: int | None = DEFAULT_REPR_MAX_SIZE, use_ascii: bool = False
) -> str:
"""Return a size-limited safe repr-string for the given object.
Failing __repr__ functions of user instances will be represented
with a short exception info and 'saferepr' generally takes
care to never raise exceptions itself.
This function is a wrapper around the Repr/reprlib functionality of the
stdlib.
"""
return SafeRepr(maxsize, use_ascii).repr(obj)
def saferepr_unlimited(obj: object, use_ascii: bool = True) -> str:
"""Return an unlimited-size safe repr-string for the given object.
As with saferepr, failing __repr__ functions of user instances
will be represented with a short exception info.
This function is a wrapper around simple repr.
Note: a cleaner solution would be to alter ``saferepr``this way
when maxsize=None, but that might affect some other code.
"""
try:
if use_ascii:
return ascii(obj)
return repr(obj)
except Exception as exc:
return _format_repr_exception(exc, obj)

View File

@@ -0,0 +1,254 @@
"""Helper functions for writing to terminals and files."""
from __future__ import annotations
from collections.abc import Sequence
import os
import shutil
import sys
from typing import final
from typing import Literal
from typing import TextIO
import pygments
from pygments.formatters.terminal import TerminalFormatter
from pygments.lexer import Lexer
from pygments.lexers.diff import DiffLexer
from pygments.lexers.python import PythonLexer
from ..compat import assert_never
from .wcwidth import wcswidth
# This code was initially copied from py 1.8.1, file _io/terminalwriter.py.
def get_terminal_width() -> int:
width, _ = shutil.get_terminal_size(fallback=(80, 24))
# The Windows get_terminal_size may be bogus, let's sanify a bit.
if width < 40:
width = 80
return width
def should_do_markup(file: TextIO) -> bool:
if os.environ.get("PY_COLORS") == "1":
return True
if os.environ.get("PY_COLORS") == "0":
return False
if os.environ.get("NO_COLOR"):
return False
if os.environ.get("FORCE_COLOR"):
return True
return (
hasattr(file, "isatty") and file.isatty() and os.environ.get("TERM") != "dumb"
)
@final
class TerminalWriter:
_esctable = dict(
black=30,
red=31,
green=32,
yellow=33,
blue=34,
purple=35,
cyan=36,
white=37,
Black=40,
Red=41,
Green=42,
Yellow=43,
Blue=44,
Purple=45,
Cyan=46,
White=47,
bold=1,
light=2,
blink=5,
invert=7,
)
def __init__(self, file: TextIO | None = None) -> None:
if file is None:
file = sys.stdout
if hasattr(file, "isatty") and file.isatty() and sys.platform == "win32":
try:
import colorama
except ImportError:
pass
else:
file = colorama.AnsiToWin32(file).stream
assert file is not None
self._file = file
self.hasmarkup = should_do_markup(file)
self._current_line = ""
self._terminal_width: int | None = None
self.code_highlight = True
@property
def fullwidth(self) -> int:
if self._terminal_width is not None:
return self._terminal_width
return get_terminal_width()
@fullwidth.setter
def fullwidth(self, value: int) -> None:
self._terminal_width = value
@property
def width_of_current_line(self) -> int:
"""Return an estimate of the width so far in the current line."""
return wcswidth(self._current_line)
def markup(self, text: str, **markup: bool) -> str:
for name in markup:
if name not in self._esctable:
raise ValueError(f"unknown markup: {name!r}")
if self.hasmarkup:
esc = [self._esctable[name] for name, on in markup.items() if on]
if esc:
text = "".join(f"\x1b[{cod}m" for cod in esc) + text + "\x1b[0m"
return text
def sep(
self,
sepchar: str,
title: str | None = None,
fullwidth: int | None = None,
**markup: bool,
) -> None:
if fullwidth is None:
fullwidth = self.fullwidth
# The goal is to have the line be as long as possible
# under the condition that len(line) <= fullwidth.
if sys.platform == "win32":
# If we print in the last column on windows we are on a
# new line but there is no way to verify/neutralize this
# (we may not know the exact line width).
# So let's be defensive to avoid empty lines in the output.
fullwidth -= 1
if title is not None:
# we want 2 + 2*len(fill) + len(title) <= fullwidth
# i.e. 2 + 2*len(sepchar)*N + len(title) <= fullwidth
# 2*len(sepchar)*N <= fullwidth - len(title) - 2
# N <= (fullwidth - len(title) - 2) // (2*len(sepchar))
N = max((fullwidth - len(title) - 2) // (2 * len(sepchar)), 1)
fill = sepchar * N
line = f"{fill} {title} {fill}"
else:
# we want len(sepchar)*N <= fullwidth
# i.e. N <= fullwidth // len(sepchar)
line = sepchar * (fullwidth // len(sepchar))
# In some situations there is room for an extra sepchar at the right,
# in particular if we consider that with a sepchar like "_ " the
# trailing space is not important at the end of the line.
if len(line) + len(sepchar.rstrip()) <= fullwidth:
line += sepchar.rstrip()
self.line(line, **markup)
def write(self, msg: str, *, flush: bool = False, **markup: bool) -> None:
if msg:
current_line = msg.rsplit("\n", 1)[-1]
if "\n" in msg:
self._current_line = current_line
else:
self._current_line += current_line
msg = self.markup(msg, **markup)
try:
self._file.write(msg)
except UnicodeEncodeError:
# Some environments don't support printing general Unicode
# strings, due to misconfiguration or otherwise; in that case,
# print the string escaped to ASCII.
# When the Unicode situation improves we should consider
# letting the error propagate instead of masking it (see #7475
# for one brief attempt).
msg = msg.encode("unicode-escape").decode("ascii")
self._file.write(msg)
if flush:
self.flush()
def line(self, s: str = "", **markup: bool) -> None:
self.write(s, **markup)
self.write("\n")
def flush(self) -> None:
self._file.flush()
def _write_source(self, lines: Sequence[str], indents: Sequence[str] = ()) -> None:
"""Write lines of source code possibly highlighted.
Keeping this private for now because the API is clunky. We should discuss how
to evolve the terminal writer so we can have more precise color support, for example
being able to write part of a line in one color and the rest in another, and so on.
"""
if indents and len(indents) != len(lines):
raise ValueError(
f"indents size ({len(indents)}) should have same size as lines ({len(lines)})"
)
if not indents:
indents = [""] * len(lines)
source = "\n".join(lines)
new_lines = self._highlight(source).splitlines()
for indent, new_line in zip(indents, new_lines):
self.line(indent + new_line)
def _get_pygments_lexer(self, lexer: Literal["python", "diff"]) -> Lexer:
if lexer == "python":
return PythonLexer()
elif lexer == "diff":
return DiffLexer()
else:
assert_never(lexer)
def _get_pygments_formatter(self) -> TerminalFormatter:
from _pytest.config.exceptions import UsageError
theme = os.getenv("PYTEST_THEME")
theme_mode = os.getenv("PYTEST_THEME_MODE", "dark")
try:
return TerminalFormatter(bg=theme_mode, style=theme)
except pygments.util.ClassNotFound as e:
raise UsageError(
f"PYTEST_THEME environment variable has an invalid value: '{theme}'. "
"Hint: See available pygments styles with `pygmentize -L styles`."
) from e
except pygments.util.OptionError as e:
raise UsageError(
f"PYTEST_THEME_MODE environment variable has an invalid value: '{theme_mode}'. "
"The allowed values are 'dark' (default) and 'light'."
) from e
def _highlight(
self, source: str, lexer: Literal["diff", "python"] = "python"
) -> str:
"""Highlight the given source if we have markup support."""
if not source or not self.hasmarkup or not self.code_highlight:
return source
pygments_lexer = self._get_pygments_lexer(lexer)
pygments_formatter = self._get_pygments_formatter()
highlighted: str = pygments.highlight(
source, pygments_lexer, pygments_formatter
)
# pygments terminal formatter may add a newline when there wasn't one.
# We don't want this, remove.
if highlighted[-1] == "\n" and source[-1] != "\n":
highlighted = highlighted[:-1]
# Some lexers will not set the initial color explicitly
# which may lead to the previous color being propagated to the
# start of the expression, so reset first.
highlighted = "\x1b[0m" + highlighted
return highlighted

View File

@@ -0,0 +1,57 @@
from __future__ import annotations
from functools import lru_cache
import unicodedata
@lru_cache(100)
def wcwidth(c: str) -> int:
"""Determine how many columns are needed to display a character in a terminal.
Returns -1 if the character is not printable.
Returns 0, 1 or 2 for other characters.
"""
o = ord(c)
# ASCII fast path.
if 0x20 <= o < 0x07F:
return 1
# Some Cf/Zp/Zl characters which should be zero-width.
if (
o == 0x0000
or 0x200B <= o <= 0x200F
or 0x2028 <= o <= 0x202E
or 0x2060 <= o <= 0x2063
):
return 0
category = unicodedata.category(c)
# Control characters.
if category == "Cc":
return -1
# Combining characters with zero width.
if category in ("Me", "Mn"):
return 0
# Full/Wide east asian characters.
if unicodedata.east_asian_width(c) in ("F", "W"):
return 2
return 1
def wcswidth(s: str) -> int:
"""Determine how many columns are needed to display a string in a terminal.
Returns -1 if the string contains non-printable characters.
"""
width = 0
for c in unicodedata.normalize("NFC", s):
wc = wcwidth(c)
if wc < 0:
return -1
width += wc
return width

View File

@@ -0,0 +1,119 @@
"""create errno-specific classes for IO or os calls."""
from __future__ import annotations
from collections.abc import Callable
import errno
import os
import sys
from typing import TYPE_CHECKING
from typing import TypeVar
if TYPE_CHECKING:
from typing_extensions import ParamSpec
P = ParamSpec("P")
R = TypeVar("R")
class Error(EnvironmentError):
def __repr__(self) -> str:
return "{}.{} {!r}: {} ".format(
self.__class__.__module__,
self.__class__.__name__,
self.__class__.__doc__,
" ".join(map(str, self.args)),
# repr(self.args)
)
def __str__(self) -> str:
s = "[{}]: {}".format(
self.__class__.__doc__,
" ".join(map(str, self.args)),
)
return s
_winerrnomap = {
2: errno.ENOENT,
3: errno.ENOENT,
17: errno.EEXIST,
18: errno.EXDEV,
13: errno.EBUSY, # empty cd drive, but ENOMEDIUM seems unavailable
22: errno.ENOTDIR,
20: errno.ENOTDIR,
267: errno.ENOTDIR,
5: errno.EACCES, # anything better?
}
class ErrorMaker:
"""lazily provides Exception classes for each possible POSIX errno
(as defined per the 'errno' module). All such instances
subclass EnvironmentError.
"""
_errno2class: dict[int, type[Error]] = {}
def __getattr__(self, name: str) -> type[Error]:
if name[0] == "_":
raise AttributeError(name)
eno = getattr(errno, name)
cls = self._geterrnoclass(eno)
setattr(self, name, cls)
return cls
def _geterrnoclass(self, eno: int) -> type[Error]:
try:
return self._errno2class[eno]
except KeyError:
clsname = errno.errorcode.get(eno, f"UnknownErrno{eno}")
errorcls = type(
clsname,
(Error,),
{"__module__": "py.error", "__doc__": os.strerror(eno)},
)
self._errno2class[eno] = errorcls
return errorcls
def checked_call(
self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> R:
"""Call a function and raise an errno-exception if applicable."""
__tracebackhide__ = True
try:
return func(*args, **kwargs)
except Error:
raise
except OSError as value:
if not hasattr(value, "errno"):
raise
if sys.platform == "win32":
try:
# error: Invalid index type "Optional[int]" for "dict[int, int]"; expected type "int" [index]
# OK to ignore because we catch the KeyError below.
cls = self._geterrnoclass(_winerrnomap[value.errno]) # type:ignore[index]
except KeyError:
raise value
else:
# we are not on Windows, or we got a proper OSError
if value.errno is None:
cls = type(
"UnknownErrnoNone",
(Error,),
{"__module__": "py.error", "__doc__": None},
)
else:
cls = self._geterrnoclass(value.errno)
raise cls(f"{func.__name__}{args!r}")
_error_maker = ErrorMaker()
checked_call = _error_maker.checked_call
def __getattr__(attr: str) -> type[Error]:
return getattr(_error_maker, attr) # type: ignore[no-any-return]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,21 @@
# file generated by setuptools-scm
# don't change, don't track in version control
__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import Tuple
from typing import Union
VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
VERSION_TUPLE = object
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
__version__ = version = '8.4.1'
__version_tuple__ = version_tuple = (8, 4, 1)

View File

@@ -0,0 +1,208 @@
# mypy: allow-untyped-defs
"""Support for presenting detailed information in failing assertions."""
from __future__ import annotations
from collections.abc import Generator
import sys
from typing import Any
from typing import Protocol
from typing import TYPE_CHECKING
from _pytest.assertion import rewrite
from _pytest.assertion import truncate
from _pytest.assertion import util
from _pytest.assertion.rewrite import assertstate_key
from _pytest.config import Config
from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser
from _pytest.nodes import Item
if TYPE_CHECKING:
from _pytest.main import Session
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig")
group.addoption(
"--assert",
action="store",
dest="assertmode",
choices=("rewrite", "plain"),
default="rewrite",
metavar="MODE",
help=(
"Control assertion debugging tools.\n"
"'plain' performs no assertion debugging.\n"
"'rewrite' (the default) rewrites assert statements in test modules"
" on import to provide assert expression information."
),
)
parser.addini(
"enable_assertion_pass_hook",
type="bool",
default=False,
help="Enables the pytest_assertion_pass hook. "
"Make sure to delete any previously generated pyc cache files.",
)
parser.addini(
"truncation_limit_lines",
default=None,
help="Set threshold of LINES after which truncation will take effect",
)
parser.addini(
"truncation_limit_chars",
default=None,
help=("Set threshold of CHARS after which truncation will take effect"),
)
Config._add_verbosity_ini(
parser,
Config.VERBOSITY_ASSERTIONS,
help=(
"Specify a verbosity level for assertions, overriding the main level. "
"Higher levels will provide more detailed explanation when an assertion fails."
),
)
def register_assert_rewrite(*names: str) -> None:
"""Register one or more module names to be rewritten on import.
This function will make sure that this module or all modules inside
the package will get their assert statements rewritten.
Thus you should make sure to call this before the module is
actually imported, usually in your __init__.py if you are a plugin
using a package.
:param names: The module names to register.
"""
for name in names:
if not isinstance(name, str):
msg = "expected module names as *args, got {0} instead" # type: ignore[unreachable]
raise TypeError(msg.format(repr(names)))
rewrite_hook: RewriteHook
for hook in sys.meta_path:
if isinstance(hook, rewrite.AssertionRewritingHook):
rewrite_hook = hook
break
else:
rewrite_hook = DummyRewriteHook()
rewrite_hook.mark_rewrite(*names)
class RewriteHook(Protocol):
def mark_rewrite(self, *names: str) -> None: ...
class DummyRewriteHook:
"""A no-op import hook for when rewriting is disabled."""
def mark_rewrite(self, *names: str) -> None:
pass
class AssertionState:
"""State for the assertion plugin."""
def __init__(self, config: Config, mode) -> None:
self.mode = mode
self.trace = config.trace.root.get("assertion")
self.hook: rewrite.AssertionRewritingHook | None = None
def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:
"""Try to install the rewrite hook, raise SystemError if it fails."""
config.stash[assertstate_key] = AssertionState(config, "rewrite")
config.stash[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config)
sys.meta_path.insert(0, hook)
config.stash[assertstate_key].trace("installed rewrite import hook")
def undo() -> None:
hook = config.stash[assertstate_key].hook
if hook is not None and hook in sys.meta_path:
sys.meta_path.remove(hook)
config.add_cleanup(undo)
return hook
def pytest_collection(session: Session) -> None:
# This hook is only called when test modules are collected
# so for example not in the managing process of pytest-xdist
# (which does not collect test modules).
assertstate = session.config.stash.get(assertstate_key, None)
if assertstate:
if assertstate.hook is not None:
assertstate.hook.set_session(session)
@hookimpl(wrapper=True, tryfirst=True)
def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
"""Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks.
The rewrite module will use util._reprcompare if it exists to use custom
reporting via the pytest_assertrepr_compare hook. This sets up this custom
comparison for the test.
"""
ihook = item.ihook
def callbinrepr(op, left: object, right: object) -> str | None:
"""Call the pytest_assertrepr_compare hook and prepare the result.
This uses the first result from the hook and then ensures the
following:
* Overly verbose explanations are truncated unless configured otherwise
(eg. if running in verbose mode).
* Embedded newlines are escaped to help util.format_explanation()
later.
* If the rewrite mode is used embedded %-characters are replaced
to protect later % formatting.
The result can be formatted by util.format_explanation() for
pretty printing.
"""
hook_result = ihook.pytest_assertrepr_compare(
config=item.config, op=op, left=left, right=right
)
for new_expl in hook_result:
if new_expl:
new_expl = truncate.truncate_if_required(new_expl, item)
new_expl = [line.replace("\n", "\\n") for line in new_expl]
res = "\n~".join(new_expl)
if item.config.getvalue("assertmode") == "rewrite":
res = res.replace("%", "%%")
return res
return None
saved_assert_hooks = util._reprcompare, util._assertion_pass
util._reprcompare = callbinrepr
util._config = item.config
if ihook.pytest_assertion_pass.get_hookimpls():
def call_assertion_pass_hook(lineno: int, orig: str, expl: str) -> None:
ihook.pytest_assertion_pass(item=item, lineno=lineno, orig=orig, expl=expl)
util._assertion_pass = call_assertion_pass_hook
try:
return (yield)
finally:
util._reprcompare, util._assertion_pass = saved_assert_hooks
util._config = None
def pytest_sessionfinish(session: Session) -> None:
assertstate = session.config.stash.get(assertstate_key, None)
if assertstate:
if assertstate.hook is not None:
assertstate.hook.set_session(None)
def pytest_assertrepr_compare(
config: Config, op: str, left: Any, right: Any
) -> list[str] | None:
return util.assertrepr_compare(config=config, op=op, left=left, right=right)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,137 @@
"""Utilities for truncating assertion output.
Current default behaviour is to truncate assertion explanations at
terminal lines, unless running with an assertions verbosity level of at least 2 or running on CI.
"""
from __future__ import annotations
from _pytest.assertion import util
from _pytest.config import Config
from _pytest.nodes import Item
DEFAULT_MAX_LINES = 8
DEFAULT_MAX_CHARS = DEFAULT_MAX_LINES * 80
USAGE_MSG = "use '-vv' to show"
def truncate_if_required(explanation: list[str], item: Item) -> list[str]:
"""Truncate this assertion explanation if the given test item is eligible."""
should_truncate, max_lines, max_chars = _get_truncation_parameters(item)
if should_truncate:
return _truncate_explanation(
explanation,
max_lines=max_lines,
max_chars=max_chars,
)
return explanation
def _get_truncation_parameters(item: Item) -> tuple[bool, int, int]:
"""Return the truncation parameters related to the given item, as (should truncate, max lines, max chars)."""
# We do not need to truncate if one of conditions is met:
# 1. Verbosity level is 2 or more;
# 2. Test is being run in CI environment;
# 3. Both truncation_limit_lines and truncation_limit_chars
# .ini parameters are set to 0 explicitly.
max_lines = item.config.getini("truncation_limit_lines")
max_lines = int(max_lines if max_lines is not None else DEFAULT_MAX_LINES)
max_chars = item.config.getini("truncation_limit_chars")
max_chars = int(max_chars if max_chars is not None else DEFAULT_MAX_CHARS)
verbose = item.config.get_verbosity(Config.VERBOSITY_ASSERTIONS)
should_truncate = verbose < 2 and not util.running_on_ci()
should_truncate = should_truncate and (max_lines > 0 or max_chars > 0)
return should_truncate, max_lines, max_chars
def _truncate_explanation(
input_lines: list[str],
max_lines: int,
max_chars: int,
) -> list[str]:
"""Truncate given list of strings that makes up the assertion explanation.
Truncates to either max_lines, or max_chars - whichever the input reaches
first, taking the truncation explanation into account. The remaining lines
will be replaced by a usage message.
"""
# Check if truncation required
input_char_count = len("".join(input_lines))
# The length of the truncation explanation depends on the number of lines
# removed but is at least 68 characters:
# The real value is
# 64 (for the base message:
# '...\n...Full output truncated (1 line hidden), use '-vv' to show")'
# )
# + 1 (for plural)
# + int(math.log10(len(input_lines) - max_lines)) (number of hidden line, at least 1)
# + 3 for the '...' added to the truncated line
# But if there's more than 100 lines it's very likely that we're going to
# truncate, so we don't need the exact value using log10.
tolerable_max_chars = (
max_chars + 70 # 64 + 1 (for plural) + 2 (for '99') + 3 for '...'
)
# The truncation explanation add two lines to the output
tolerable_max_lines = max_lines + 2
if (
len(input_lines) <= tolerable_max_lines
and input_char_count <= tolerable_max_chars
):
return input_lines
# Truncate first to max_lines, and then truncate to max_chars if necessary
if max_lines > 0:
truncated_explanation = input_lines[:max_lines]
else:
truncated_explanation = input_lines
truncated_char = True
# We reevaluate the need to truncate chars following removal of some lines
if len("".join(truncated_explanation)) > tolerable_max_chars and max_chars > 0:
truncated_explanation = _truncate_by_char_count(
truncated_explanation, max_chars
)
else:
truncated_char = False
if truncated_explanation == input_lines:
# No truncation happened, so we do not need to add any explanations
return truncated_explanation
truncated_line_count = len(input_lines) - len(truncated_explanation)
if truncated_explanation[-1]:
# Add ellipsis and take into account part-truncated final line
truncated_explanation[-1] = truncated_explanation[-1] + "..."
if truncated_char:
# It's possible that we did not remove any char from this line
truncated_line_count += 1
else:
# Add proper ellipsis when we were able to fit a full line exactly
truncated_explanation[-1] = "..."
return [
*truncated_explanation,
"",
f"...Full output truncated ({truncated_line_count} line"
f"{'' if truncated_line_count == 1 else 's'} hidden), {USAGE_MSG}",
]
def _truncate_by_char_count(input_lines: list[str], max_chars: int) -> list[str]:
# Find point at which input length exceeds total allowed length
iterated_char_count = 0
for iterated_index, input_line in enumerate(input_lines):
if iterated_char_count + len(input_line) > max_chars:
break
iterated_char_count += len(input_line)
# Create truncated explanation with modified final line
truncated_result = input_lines[:iterated_index]
final_line = input_lines[iterated_index]
if final_line:
final_line_truncate_point = max_chars - iterated_char_count
final_line = final_line[:final_line_truncate_point]
truncated_result.append(final_line)
return truncated_result

View File

@@ -0,0 +1,621 @@
# mypy: allow-untyped-defs
"""Utilities for assertion debugging."""
from __future__ import annotations
import collections.abc
from collections.abc import Callable
from collections.abc import Iterable
from collections.abc import Mapping
from collections.abc import Sequence
from collections.abc import Set as AbstractSet
import os
import pprint
from typing import Any
from typing import Literal
from typing import Protocol
from unicodedata import normalize
from _pytest import outcomes
import _pytest._code
from _pytest._io.pprint import PrettyPrinter
from _pytest._io.saferepr import saferepr
from _pytest._io.saferepr import saferepr_unlimited
from _pytest.config import Config
# The _reprcompare attribute on the util module is used by the new assertion
# interpretation code and assertion rewriter to detect this plugin was
# loaded and in turn call the hooks defined here as part of the
# DebugInterpreter.
_reprcompare: Callable[[str, object, object], str | None] | None = None
# Works similarly as _reprcompare attribute. Is populated with the hook call
# when pytest_runtest_setup is called.
_assertion_pass: Callable[[int, str, str], None] | None = None
# Config object which is assigned during pytest_runtest_protocol.
_config: Config | None = None
class _HighlightFunc(Protocol):
def __call__(self, source: str, lexer: Literal["diff", "python"] = "python") -> str:
"""Apply highlighting to the given source."""
def dummy_highlighter(source: str, lexer: Literal["diff", "python"] = "python") -> str:
"""Dummy highlighter that returns the text unprocessed.
Needed for _notin_text, as the diff gets post-processed to only show the "+" part.
"""
return source
def format_explanation(explanation: str) -> str:
r"""Format an explanation.
Normally all embedded newlines are escaped, however there are
three exceptions: \n{, \n} and \n~. The first two are intended
cover nested explanations, see function and attribute explanations
for examples (.visit_Call(), visit_Attribute()). The last one is
for when one explanation needs to span multiple lines, e.g. when
displaying diffs.
"""
lines = _split_explanation(explanation)
result = _format_lines(lines)
return "\n".join(result)
def _split_explanation(explanation: str) -> list[str]:
r"""Return a list of individual lines in the explanation.
This will return a list of lines split on '\n{', '\n}' and '\n~'.
Any other newlines will be escaped and appear in the line as the
literal '\n' characters.
"""
raw_lines = (explanation or "").split("\n")
lines = [raw_lines[0]]
for values in raw_lines[1:]:
if values and values[0] in ["{", "}", "~", ">"]:
lines.append(values)
else:
lines[-1] += "\\n" + values
return lines
def _format_lines(lines: Sequence[str]) -> list[str]:
"""Format the individual lines.
This will replace the '{', '}' and '~' characters of our mini formatting
language with the proper 'where ...', 'and ...' and ' + ...' text, taking
care of indentation along the way.
Return a list of formatted lines.
"""
result = list(lines[:1])
stack = [0]
stackcnt = [0]
for line in lines[1:]:
if line.startswith("{"):
if stackcnt[-1]:
s = "and "
else:
s = "where "
stack.append(len(result))
stackcnt[-1] += 1
stackcnt.append(0)
result.append(" +" + " " * (len(stack) - 1) + s + line[1:])
elif line.startswith("}"):
stack.pop()
stackcnt.pop()
result[stack[-1]] += line[1:]
else:
assert line[0] in ["~", ">"]
stack[-1] += 1
indent = len(stack) if line.startswith("~") else len(stack) - 1
result.append(" " * indent + line[1:])
assert len(stack) == 1
return result
def issequence(x: Any) -> bool:
return isinstance(x, collections.abc.Sequence) and not isinstance(x, str)
def istext(x: Any) -> bool:
return isinstance(x, str)
def isdict(x: Any) -> bool:
return isinstance(x, dict)
def isset(x: Any) -> bool:
return isinstance(x, (set, frozenset))
def isnamedtuple(obj: Any) -> bool:
return isinstance(obj, tuple) and getattr(obj, "_fields", None) is not None
def isdatacls(obj: Any) -> bool:
return getattr(obj, "__dataclass_fields__", None) is not None
def isattrs(obj: Any) -> bool:
return getattr(obj, "__attrs_attrs__", None) is not None
def isiterable(obj: Any) -> bool:
try:
iter(obj)
return not istext(obj)
except Exception:
return False
def has_default_eq(
obj: object,
) -> bool:
"""Check if an instance of an object contains the default eq
First, we check if the object's __eq__ attribute has __code__,
if so, we check the equally of the method code filename (__code__.co_filename)
to the default one generated by the dataclass and attr module
for dataclasses the default co_filename is <string>, for attrs class, the __eq__ should contain "attrs eq generated"
"""
# inspired from https://github.com/willmcgugan/rich/blob/07d51ffc1aee6f16bd2e5a25b4e82850fb9ed778/rich/pretty.py#L68
if hasattr(obj.__eq__, "__code__") and hasattr(obj.__eq__.__code__, "co_filename"):
code_filename = obj.__eq__.__code__.co_filename
if isattrs(obj):
return "attrs generated " in code_filename
return code_filename == "<string>" # data class
return True
def assertrepr_compare(
config, op: str, left: Any, right: Any, use_ascii: bool = False
) -> list[str] | None:
"""Return specialised explanations for some operators/operands."""
verbose = config.get_verbosity(Config.VERBOSITY_ASSERTIONS)
# Strings which normalize equal are often hard to distinguish when printed; use ascii() to make this easier.
# See issue #3246.
use_ascii = (
isinstance(left, str)
and isinstance(right, str)
and normalize("NFD", left) == normalize("NFD", right)
)
if verbose > 1:
left_repr = saferepr_unlimited(left, use_ascii=use_ascii)
right_repr = saferepr_unlimited(right, use_ascii=use_ascii)
else:
# XXX: "15 chars indentation" is wrong
# ("E AssertionError: assert "); should use term width.
maxsize = (
80 - 15 - len(op) - 2
) // 2 # 15 chars indentation, 1 space around op
left_repr = saferepr(left, maxsize=maxsize, use_ascii=use_ascii)
right_repr = saferepr(right, maxsize=maxsize, use_ascii=use_ascii)
summary = f"{left_repr} {op} {right_repr}"
highlighter = config.get_terminal_writer()._highlight
explanation = None
try:
if op == "==":
explanation = _compare_eq_any(left, right, highlighter, verbose)
elif op == "not in":
if istext(left) and istext(right):
explanation = _notin_text(left, right, verbose)
elif op == "!=":
if isset(left) and isset(right):
explanation = ["Both sets are equal"]
elif op == ">=":
if isset(left) and isset(right):
explanation = _compare_gte_set(left, right, highlighter, verbose)
elif op == "<=":
if isset(left) and isset(right):
explanation = _compare_lte_set(left, right, highlighter, verbose)
elif op == ">":
if isset(left) and isset(right):
explanation = _compare_gt_set(left, right, highlighter, verbose)
elif op == "<":
if isset(left) and isset(right):
explanation = _compare_lt_set(left, right, highlighter, verbose)
except outcomes.Exit:
raise
except Exception:
repr_crash = _pytest._code.ExceptionInfo.from_current()._getreprcrash()
explanation = [
f"(pytest_assertion plugin: representation of details failed: {repr_crash}.",
" Probably an object has a faulty __repr__.)",
]
if not explanation:
return None
if explanation[0] != "":
explanation = ["", *explanation]
return [summary, *explanation]
def _compare_eq_any(
left: Any, right: Any, highlighter: _HighlightFunc, verbose: int = 0
) -> list[str]:
explanation = []
if istext(left) and istext(right):
explanation = _diff_text(left, right, highlighter, verbose)
else:
from _pytest.python_api import ApproxBase
if isinstance(left, ApproxBase) or isinstance(right, ApproxBase):
# Although the common order should be obtained == expected, this ensures both ways
approx_side = left if isinstance(left, ApproxBase) else right
other_side = right if isinstance(left, ApproxBase) else left
explanation = approx_side._repr_compare(other_side)
elif type(left) is type(right) and (
isdatacls(left) or isattrs(left) or isnamedtuple(left)
):
# Note: unlike dataclasses/attrs, namedtuples compare only the
# field values, not the type or field names. But this branch
# intentionally only handles the same-type case, which was often
# used in older code bases before dataclasses/attrs were available.
explanation = _compare_eq_cls(left, right, highlighter, verbose)
elif issequence(left) and issequence(right):
explanation = _compare_eq_sequence(left, right, highlighter, verbose)
elif isset(left) and isset(right):
explanation = _compare_eq_set(left, right, highlighter, verbose)
elif isdict(left) and isdict(right):
explanation = _compare_eq_dict(left, right, highlighter, verbose)
if isiterable(left) and isiterable(right):
expl = _compare_eq_iterable(left, right, highlighter, verbose)
explanation.extend(expl)
return explanation
def _diff_text(
left: str, right: str, highlighter: _HighlightFunc, verbose: int = 0
) -> list[str]:
"""Return the explanation for the diff between text.
Unless --verbose is used this will skip leading and trailing
characters which are identical to keep the diff minimal.
"""
from difflib import ndiff
explanation: list[str] = []
if verbose < 1:
i = 0 # just in case left or right has zero length
for i in range(min(len(left), len(right))):
if left[i] != right[i]:
break
if i > 42:
i -= 10 # Provide some context
explanation = [
f"Skipping {i} identical leading characters in diff, use -v to show"
]
left = left[i:]
right = right[i:]
if len(left) == len(right):
for i in range(len(left)):
if left[-i] != right[-i]:
break
if i > 42:
i -= 10 # Provide some context
explanation += [
f"Skipping {i} identical trailing "
"characters in diff, use -v to show"
]
left = left[:-i]
right = right[:-i]
keepends = True
if left.isspace() or right.isspace():
left = repr(str(left))
right = repr(str(right))
explanation += ["Strings contain only whitespace, escaping them using repr()"]
# "right" is the expected base against which we compare "left",
# see https://github.com/pytest-dev/pytest/issues/3333
explanation.extend(
highlighter(
"\n".join(
line.strip("\n")
for line in ndiff(right.splitlines(keepends), left.splitlines(keepends))
),
lexer="diff",
).splitlines()
)
return explanation
def _compare_eq_iterable(
left: Iterable[Any],
right: Iterable[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
if verbose <= 0 and not running_on_ci():
return ["Use -v to get more diff"]
# dynamic import to speedup pytest
import difflib
left_formatting = PrettyPrinter().pformat(left).splitlines()
right_formatting = PrettyPrinter().pformat(right).splitlines()
explanation = ["", "Full diff:"]
# "right" is the expected base against which we compare "left",
# see https://github.com/pytest-dev/pytest/issues/3333
explanation.extend(
highlighter(
"\n".join(
line.rstrip()
for line in difflib.ndiff(right_formatting, left_formatting)
),
lexer="diff",
).splitlines()
)
return explanation
def _compare_eq_sequence(
left: Sequence[Any],
right: Sequence[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes)
explanation: list[str] = []
len_left = len(left)
len_right = len(right)
for i in range(min(len_left, len_right)):
if left[i] != right[i]:
if comparing_bytes:
# when comparing bytes, we want to see their ascii representation
# instead of their numeric values (#5260)
# using a slice gives us the ascii representation:
# >>> s = b'foo'
# >>> s[0]
# 102
# >>> s[0:1]
# b'f'
left_value = left[i : i + 1]
right_value = right[i : i + 1]
else:
left_value = left[i]
right_value = right[i]
explanation.append(
f"At index {i} diff:"
f" {highlighter(repr(left_value))} != {highlighter(repr(right_value))}"
)
break
if comparing_bytes:
# when comparing bytes, it doesn't help to show the "sides contain one or more
# items" longer explanation, so skip it
return explanation
len_diff = len_left - len_right
if len_diff:
if len_diff > 0:
dir_with_more = "Left"
extra = saferepr(left[len_right])
else:
len_diff = 0 - len_diff
dir_with_more = "Right"
extra = saferepr(right[len_left])
if len_diff == 1:
explanation += [
f"{dir_with_more} contains one more item: {highlighter(extra)}"
]
else:
explanation += [
f"{dir_with_more} contains {len_diff} more items, first extra item: {highlighter(extra)}"
]
return explanation
def _compare_eq_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
explanation = []
explanation.extend(_set_one_sided_diff("left", left, right, highlighter))
explanation.extend(_set_one_sided_diff("right", right, left, highlighter))
return explanation
def _compare_gt_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
explanation = _compare_gte_set(left, right, highlighter)
if not explanation:
return ["Both sets are equal"]
return explanation
def _compare_lt_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
explanation = _compare_lte_set(left, right, highlighter)
if not explanation:
return ["Both sets are equal"]
return explanation
def _compare_gte_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
return _set_one_sided_diff("right", right, left, highlighter)
def _compare_lte_set(
left: AbstractSet[Any],
right: AbstractSet[Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
return _set_one_sided_diff("left", left, right, highlighter)
def _set_one_sided_diff(
posn: str,
set1: AbstractSet[Any],
set2: AbstractSet[Any],
highlighter: _HighlightFunc,
) -> list[str]:
explanation = []
diff = set1 - set2
if diff:
explanation.append(f"Extra items in the {posn} set:")
for item in diff:
explanation.append(highlighter(saferepr(item)))
return explanation
def _compare_eq_dict(
left: Mapping[Any, Any],
right: Mapping[Any, Any],
highlighter: _HighlightFunc,
verbose: int = 0,
) -> list[str]:
explanation: list[str] = []
set_left = set(left)
set_right = set(right)
common = set_left.intersection(set_right)
same = {k: left[k] for k in common if left[k] == right[k]}
if same and verbose < 2:
explanation += [f"Omitting {len(same)} identical items, use -vv to show"]
elif same:
explanation += ["Common items:"]
explanation += highlighter(pprint.pformat(same)).splitlines()
diff = {k for k in common if left[k] != right[k]}
if diff:
explanation += ["Differing items:"]
for k in diff:
explanation += [
highlighter(saferepr({k: left[k]}))
+ " != "
+ highlighter(saferepr({k: right[k]}))
]
extra_left = set_left - set_right
len_extra_left = len(extra_left)
if len_extra_left:
explanation.append(
f"Left contains {len_extra_left} more item{'' if len_extra_left == 1 else 's'}:"
)
explanation.extend(
highlighter(pprint.pformat({k: left[k] for k in extra_left})).splitlines()
)
extra_right = set_right - set_left
len_extra_right = len(extra_right)
if len_extra_right:
explanation.append(
f"Right contains {len_extra_right} more item{'' if len_extra_right == 1 else 's'}:"
)
explanation.extend(
highlighter(pprint.pformat({k: right[k] for k in extra_right})).splitlines()
)
return explanation
def _compare_eq_cls(
left: Any, right: Any, highlighter: _HighlightFunc, verbose: int
) -> list[str]:
if not has_default_eq(left):
return []
if isdatacls(left):
import dataclasses
all_fields = dataclasses.fields(left)
fields_to_check = [info.name for info in all_fields if info.compare]
elif isattrs(left):
all_fields = left.__attrs_attrs__
fields_to_check = [field.name for field in all_fields if getattr(field, "eq")]
elif isnamedtuple(left):
fields_to_check = left._fields
else:
assert False
indent = " "
same = []
diff = []
for field in fields_to_check:
if getattr(left, field) == getattr(right, field):
same.append(field)
else:
diff.append(field)
explanation = []
if same or diff:
explanation += [""]
if same and verbose < 2:
explanation.append(f"Omitting {len(same)} identical items, use -vv to show")
elif same:
explanation += ["Matching attributes:"]
explanation += highlighter(pprint.pformat(same)).splitlines()
if diff:
explanation += ["Differing attributes:"]
explanation += highlighter(pprint.pformat(diff)).splitlines()
for field in diff:
field_left = getattr(left, field)
field_right = getattr(right, field)
explanation += [
"",
f"Drill down into differing attribute {field}:",
f"{indent}{field}: {highlighter(repr(field_left))} != {highlighter(repr(field_right))}",
]
explanation += [
indent + line
for line in _compare_eq_any(
field_left, field_right, highlighter, verbose
)
]
return explanation
def _notin_text(term: str, text: str, verbose: int = 0) -> list[str]:
index = text.find(term)
head = text[:index]
tail = text[index + len(term) :]
correct_text = head + tail
diff = _diff_text(text, correct_text, dummy_highlighter, verbose)
newdiff = [f"{saferepr(term, maxsize=42)} is contained here:"]
for line in diff:
if line.startswith("Skipping"):
continue
if line.startswith("- "):
continue
if line.startswith("+ "):
newdiff.append(" " + line[2:])
else:
newdiff.append(line)
return newdiff
def running_on_ci() -> bool:
"""Check if we're currently running on a CI system."""
env_vars = ["CI", "BUILD_NUMBER"]
return any(var in os.environ for var in env_vars)

View File

@@ -0,0 +1,625 @@
# mypy: allow-untyped-defs
"""Implementation of the cache provider."""
# This plugin was not named "cache" to avoid conflicts with the external
# pytest-cache version.
from __future__ import annotations
from collections.abc import Generator
from collections.abc import Iterable
import dataclasses
import errno
import json
import os
from pathlib import Path
import tempfile
from typing import final
from .pathlib import resolve_from_str
from .pathlib import rm_rf
from .reports import CollectReport
from _pytest import nodes
from _pytest._io import TerminalWriter
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser
from _pytest.deprecated import check_ispytest
from _pytest.fixtures import fixture
from _pytest.fixtures import FixtureRequest
from _pytest.main import Session
from _pytest.nodes import Directory
from _pytest.nodes import File
from _pytest.reports import TestReport
README_CONTENT = """\
# pytest cache directory #
This directory contains data from the pytest's cache plugin,
which provides the `--lf` and `--ff` options, as well as the `cache` fixture.
**Do not** commit this to version control.
See [the docs](https://docs.pytest.org/en/stable/how-to/cache.html) for more information.
"""
CACHEDIR_TAG_CONTENT = b"""\
Signature: 8a477f597d28d172789f06886806bc55
# This file is a cache directory tag created by pytest.
# For information about cache directory tags, see:
# https://bford.info/cachedir/spec.html
"""
@final
@dataclasses.dataclass
class Cache:
"""Instance of the `cache` fixture."""
_cachedir: Path = dataclasses.field(repr=False)
_config: Config = dataclasses.field(repr=False)
# Sub-directory under cache-dir for directories created by `mkdir()`.
_CACHE_PREFIX_DIRS = "d"
# Sub-directory under cache-dir for values created by `set()`.
_CACHE_PREFIX_VALUES = "v"
def __init__(
self, cachedir: Path, config: Config, *, _ispytest: bool = False
) -> None:
check_ispytest(_ispytest)
self._cachedir = cachedir
self._config = config
@classmethod
def for_config(cls, config: Config, *, _ispytest: bool = False) -> Cache:
"""Create the Cache instance for a Config.
:meta private:
"""
check_ispytest(_ispytest)
cachedir = cls.cache_dir_from_config(config, _ispytest=True)
if config.getoption("cacheclear") and cachedir.is_dir():
cls.clear_cache(cachedir, _ispytest=True)
return cls(cachedir, config, _ispytest=True)
@classmethod
def clear_cache(cls, cachedir: Path, _ispytest: bool = False) -> None:
"""Clear the sub-directories used to hold cached directories and values.
:meta private:
"""
check_ispytest(_ispytest)
for prefix in (cls._CACHE_PREFIX_DIRS, cls._CACHE_PREFIX_VALUES):
d = cachedir / prefix
if d.is_dir():
rm_rf(d)
@staticmethod
def cache_dir_from_config(config: Config, *, _ispytest: bool = False) -> Path:
"""Get the path to the cache directory for a Config.
:meta private:
"""
check_ispytest(_ispytest)
return resolve_from_str(config.getini("cache_dir"), config.rootpath)
def warn(self, fmt: str, *, _ispytest: bool = False, **args: object) -> None:
"""Issue a cache warning.
:meta private:
"""
check_ispytest(_ispytest)
import warnings
from _pytest.warning_types import PytestCacheWarning
warnings.warn(
PytestCacheWarning(fmt.format(**args) if args else fmt),
self._config.hook,
stacklevel=3,
)
def _mkdir(self, path: Path) -> None:
self._ensure_cache_dir_and_supporting_files()
path.mkdir(exist_ok=True, parents=True)
def mkdir(self, name: str) -> Path:
"""Return a directory path object with the given name.
If the directory does not yet exist, it will be created. You can use
it to manage files to e.g. store/retrieve database dumps across test
sessions.
.. versionadded:: 7.0
:param name:
Must be a string not containing a ``/`` separator.
Make sure the name contains your plugin or application
identifiers to prevent clashes with other cache users.
"""
path = Path(name)
if len(path.parts) > 1:
raise ValueError("name is not allowed to contain path separators")
res = self._cachedir.joinpath(self._CACHE_PREFIX_DIRS, path)
self._mkdir(res)
return res
def _getvaluepath(self, key: str) -> Path:
return self._cachedir.joinpath(self._CACHE_PREFIX_VALUES, Path(key))
def get(self, key: str, default):
"""Return the cached value for the given key.
If no value was yet cached or the value cannot be read, the specified
default is returned.
:param key:
Must be a ``/`` separated value. Usually the first
name is the name of your plugin or your application.
:param default:
The value to return in case of a cache-miss or invalid cache value.
"""
path = self._getvaluepath(key)
try:
with path.open("r", encoding="UTF-8") as f:
return json.load(f)
except (ValueError, OSError):
return default
def set(self, key: str, value: object) -> None:
"""Save value for the given key.
:param key:
Must be a ``/`` separated value. Usually the first
name is the name of your plugin or your application.
:param value:
Must be of any combination of basic python types,
including nested types like lists of dictionaries.
"""
path = self._getvaluepath(key)
try:
self._mkdir(path.parent)
except OSError as exc:
self.warn(
f"could not create cache path {path}: {exc}",
_ispytest=True,
)
return
data = json.dumps(value, ensure_ascii=False, indent=2)
try:
f = path.open("w", encoding="UTF-8")
except OSError as exc:
self.warn(
f"cache could not write path {path}: {exc}",
_ispytest=True,
)
else:
with f:
f.write(data)
def _ensure_cache_dir_and_supporting_files(self) -> None:
"""Create the cache dir and its supporting files."""
if self._cachedir.is_dir():
return
self._cachedir.parent.mkdir(parents=True, exist_ok=True)
with tempfile.TemporaryDirectory(
prefix="pytest-cache-files-",
dir=self._cachedir.parent,
) as newpath:
path = Path(newpath)
# Reset permissions to the default, see #12308.
# Note: there's no way to get the current umask atomically, eek.
umask = os.umask(0o022)
os.umask(umask)
path.chmod(0o777 - umask)
with open(path.joinpath("README.md"), "x", encoding="UTF-8") as f:
f.write(README_CONTENT)
with open(path.joinpath(".gitignore"), "x", encoding="UTF-8") as f:
f.write("# Created by pytest automatically.\n*\n")
with open(path.joinpath("CACHEDIR.TAG"), "xb") as f:
f.write(CACHEDIR_TAG_CONTENT)
try:
path.rename(self._cachedir)
except OSError as e:
# If 2 concurrent pytests both race to the rename, the loser
# gets "Directory not empty" from the rename. In this case,
# everything is handled so just continue (while letting the
# temporary directory be cleaned up).
# On Windows, the error is a FileExistsError which translates to EEXIST.
if e.errno not in (errno.ENOTEMPTY, errno.EEXIST):
raise
else:
# Create a directory in place of the one we just moved so that
# `TemporaryDirectory`'s cleanup doesn't complain.
#
# TODO: pass ignore_cleanup_errors=True when we no longer support python < 3.10.
# See https://github.com/python/cpython/issues/74168. Note that passing
# delete=False would do the wrong thing in case of errors and isn't supported
# until python 3.12.
path.mkdir()
class LFPluginCollWrapper:
def __init__(self, lfplugin: LFPlugin) -> None:
self.lfplugin = lfplugin
self._collected_at_least_one_failure = False
@hookimpl(wrapper=True)
def pytest_make_collect_report(
self, collector: nodes.Collector
) -> Generator[None, CollectReport, CollectReport]:
res = yield
if isinstance(collector, (Session, Directory)):
# Sort any lf-paths to the beginning.
lf_paths = self.lfplugin._last_failed_paths
# Use stable sort to prioritize last failed.
def sort_key(node: nodes.Item | nodes.Collector) -> bool:
return node.path in lf_paths
res.result = sorted(
res.result,
key=sort_key,
reverse=True,
)
elif isinstance(collector, File):
if collector.path in self.lfplugin._last_failed_paths:
result = res.result
lastfailed = self.lfplugin.lastfailed
# Only filter with known failures.
if not self._collected_at_least_one_failure:
if not any(x.nodeid in lastfailed for x in result):
return res
self.lfplugin.config.pluginmanager.register(
LFPluginCollSkipfiles(self.lfplugin), "lfplugin-collskip"
)
self._collected_at_least_one_failure = True
session = collector.session
result[:] = [
x
for x in result
if x.nodeid in lastfailed
# Include any passed arguments (not trivial to filter).
or session.isinitpath(x.path)
# Keep all sub-collectors.
or isinstance(x, nodes.Collector)
]
return res
class LFPluginCollSkipfiles:
def __init__(self, lfplugin: LFPlugin) -> None:
self.lfplugin = lfplugin
@hookimpl
def pytest_make_collect_report(
self, collector: nodes.Collector
) -> CollectReport | None:
if isinstance(collector, File):
if collector.path not in self.lfplugin._last_failed_paths:
self.lfplugin._skipped_files += 1
return CollectReport(
collector.nodeid, "passed", longrepr=None, result=[]
)
return None
class LFPlugin:
"""Plugin which implements the --lf (run last-failing) option."""
def __init__(self, config: Config) -> None:
self.config = config
active_keys = "lf", "failedfirst"
self.active = any(config.getoption(key) for key in active_keys)
assert config.cache
self.lastfailed: dict[str, bool] = config.cache.get("cache/lastfailed", {})
self._previously_failed_count: int | None = None
self._report_status: str | None = None
self._skipped_files = 0 # count skipped files during collection due to --lf
if config.getoption("lf"):
self._last_failed_paths = self.get_last_failed_paths()
config.pluginmanager.register(
LFPluginCollWrapper(self), "lfplugin-collwrapper"
)
def get_last_failed_paths(self) -> set[Path]:
"""Return a set with all Paths of the previously failed nodeids and
their parents."""
rootpath = self.config.rootpath
result = set()
for nodeid in self.lastfailed:
path = rootpath / nodeid.split("::")[0]
result.add(path)
result.update(path.parents)
return {x for x in result if x.exists()}
def pytest_report_collectionfinish(self) -> str | None:
if self.active and self.config.get_verbosity() >= 0:
return f"run-last-failure: {self._report_status}"
return None
def pytest_runtest_logreport(self, report: TestReport) -> None:
if (report.when == "call" and report.passed) or report.skipped:
self.lastfailed.pop(report.nodeid, None)
elif report.failed:
self.lastfailed[report.nodeid] = True
def pytest_collectreport(self, report: CollectReport) -> None:
passed = report.outcome in ("passed", "skipped")
if passed:
if report.nodeid in self.lastfailed:
self.lastfailed.pop(report.nodeid)
self.lastfailed.update((item.nodeid, True) for item in report.result)
else:
self.lastfailed[report.nodeid] = True
@hookimpl(wrapper=True, tryfirst=True)
def pytest_collection_modifyitems(
self, config: Config, items: list[nodes.Item]
) -> Generator[None]:
res = yield
if not self.active:
return res
if self.lastfailed:
previously_failed = []
previously_passed = []
for item in items:
if item.nodeid in self.lastfailed:
previously_failed.append(item)
else:
previously_passed.append(item)
self._previously_failed_count = len(previously_failed)
if not previously_failed:
# Running a subset of all tests with recorded failures
# only outside of it.
self._report_status = (
f"{len(self.lastfailed)} known failures not in selected tests"
)
else:
if self.config.getoption("lf"):
items[:] = previously_failed
config.hook.pytest_deselected(items=previously_passed)
else: # --failedfirst
items[:] = previously_failed + previously_passed
noun = "failure" if self._previously_failed_count == 1 else "failures"
suffix = " first" if self.config.getoption("failedfirst") else ""
self._report_status = (
f"rerun previous {self._previously_failed_count} {noun}{suffix}"
)
if self._skipped_files > 0:
files_noun = "file" if self._skipped_files == 1 else "files"
self._report_status += f" (skipped {self._skipped_files} {files_noun})"
else:
self._report_status = "no previously failed tests, "
if self.config.getoption("last_failed_no_failures") == "none":
self._report_status += "deselecting all items."
config.hook.pytest_deselected(items=items[:])
items[:] = []
else:
self._report_status += "not deselecting items."
return res
def pytest_sessionfinish(self, session: Session) -> None:
config = self.config
if config.getoption("cacheshow") or hasattr(config, "workerinput"):
return
assert config.cache is not None
saved_lastfailed = config.cache.get("cache/lastfailed", {})
if saved_lastfailed != self.lastfailed:
config.cache.set("cache/lastfailed", self.lastfailed)
class NFPlugin:
"""Plugin which implements the --nf (run new-first) option."""
def __init__(self, config: Config) -> None:
self.config = config
self.active = config.option.newfirst
assert config.cache is not None
self.cached_nodeids = set(config.cache.get("cache/nodeids", []))
@hookimpl(wrapper=True, tryfirst=True)
def pytest_collection_modifyitems(self, items: list[nodes.Item]) -> Generator[None]:
res = yield
if self.active:
new_items: dict[str, nodes.Item] = {}
other_items: dict[str, nodes.Item] = {}
for item in items:
if item.nodeid not in self.cached_nodeids:
new_items[item.nodeid] = item
else:
other_items[item.nodeid] = item
items[:] = self._get_increasing_order(
new_items.values()
) + self._get_increasing_order(other_items.values())
self.cached_nodeids.update(new_items)
else:
self.cached_nodeids.update(item.nodeid for item in items)
return res
def _get_increasing_order(self, items: Iterable[nodes.Item]) -> list[nodes.Item]:
return sorted(items, key=lambda item: item.path.stat().st_mtime, reverse=True)
def pytest_sessionfinish(self) -> None:
config = self.config
if config.getoption("cacheshow") or hasattr(config, "workerinput"):
return
if config.getoption("collectonly"):
return
assert config.cache is not None
config.cache.set("cache/nodeids", sorted(self.cached_nodeids))
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group.addoption(
"--lf",
"--last-failed",
action="store_true",
dest="lf",
help="Rerun only the tests that failed at the last run (or all if none failed)",
)
group.addoption(
"--ff",
"--failed-first",
action="store_true",
dest="failedfirst",
help="Run all tests, but run the last failures first. "
"This may re-order tests and thus lead to "
"repeated fixture setup/teardown.",
)
group.addoption(
"--nf",
"--new-first",
action="store_true",
dest="newfirst",
help="Run tests from new files first, then the rest of the tests "
"sorted by file mtime",
)
group.addoption(
"--cache-show",
action="append",
nargs="?",
dest="cacheshow",
help=(
"Show cache contents, don't perform collection or tests. "
"Optional argument: glob (default: '*')."
),
)
group.addoption(
"--cache-clear",
action="store_true",
dest="cacheclear",
help="Remove all cache contents at start of test run",
)
cache_dir_default = ".pytest_cache"
if "TOX_ENV_DIR" in os.environ:
cache_dir_default = os.path.join(os.environ["TOX_ENV_DIR"], cache_dir_default)
parser.addini("cache_dir", default=cache_dir_default, help="Cache directory path")
group.addoption(
"--lfnf",
"--last-failed-no-failures",
action="store",
dest="last_failed_no_failures",
choices=("all", "none"),
default="all",
help="With ``--lf``, determines whether to execute tests when there "
"are no previously (known) failures or when no "
"cached ``lastfailed`` data was found. "
"``all`` (the default) runs the full test suite again. "
"``none`` just emits a message about no known failures and exits successfully.",
)
def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
if config.option.cacheshow and not config.option.help:
from _pytest.main import wrap_session
return wrap_session(config, cacheshow)
return None
@hookimpl(tryfirst=True)
def pytest_configure(config: Config) -> None:
config.cache = Cache.for_config(config, _ispytest=True)
config.pluginmanager.register(LFPlugin(config), "lfplugin")
config.pluginmanager.register(NFPlugin(config), "nfplugin")
@fixture
def cache(request: FixtureRequest) -> Cache:
"""Return a cache object that can persist state between testing sessions.
cache.get(key, default)
cache.set(key, value)
Keys must be ``/`` separated strings, where the first part is usually the
name of your plugin or application to avoid clashes with other cache users.
Values can be any object handled by the json stdlib module.
"""
assert request.config.cache is not None
return request.config.cache
def pytest_report_header(config: Config) -> str | None:
"""Display cachedir with --cache-show and if non-default."""
if config.option.verbose > 0 or config.getini("cache_dir") != ".pytest_cache":
assert config.cache is not None
cachedir = config.cache._cachedir
# TODO: evaluate generating upward relative paths
# starting with .., ../.. if sensible
try:
displaypath = cachedir.relative_to(config.rootpath)
except ValueError:
displaypath = cachedir
return f"cachedir: {displaypath}"
return None
def cacheshow(config: Config, session: Session) -> int:
from pprint import pformat
assert config.cache is not None
tw = TerminalWriter()
tw.line("cachedir: " + str(config.cache._cachedir))
if not config.cache._cachedir.is_dir():
tw.line("cache is empty")
return 0
glob = config.option.cacheshow[0]
if glob is None:
glob = "*"
dummy = object()
basedir = config.cache._cachedir
vdir = basedir / Cache._CACHE_PREFIX_VALUES
tw.sep("-", f"cache values for {glob!r}")
for valpath in sorted(x for x in vdir.rglob(glob) if x.is_file()):
key = str(valpath.relative_to(vdir))
val = config.cache.get(key, dummy)
if val is dummy:
tw.line(f"{key} contains unreadable content, will be ignored")
else:
tw.line(f"{key} contains:")
for line in pformat(val).splitlines():
tw.line(" " + line)
ddir = basedir / Cache._CACHE_PREFIX_DIRS
if ddir.is_dir():
contents = sorted(ddir.rglob(glob))
tw.sep("-", f"cache directories for {glob!r}")
for p in contents:
# if p.is_dir():
# print("%s/" % p.relative_to(basedir))
if p.is_file():
key = str(p.relative_to(basedir))
tw.line(f"{key} is a file of length {p.stat().st_size}")
return 0

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,322 @@
# mypy: allow-untyped-defs
"""Python version compatibility code."""
from __future__ import annotations
from collections.abc import Callable
import enum
import functools
import inspect
from inspect import Parameter
from inspect import signature
import os
from pathlib import Path
import sys
from typing import Any
from typing import Final
from typing import NoReturn
import py
#: constant to prepare valuing pylib path replacements/lazy proxies later on
# intended for removal in pytest 8.0 or 9.0
# fmt: off
# intentional space to create a fake difference for the verification
LEGACY_PATH = py.path. local
# fmt: on
def legacy_path(path: str | os.PathLike[str]) -> LEGACY_PATH:
"""Internal wrapper to prepare lazy proxies for legacy_path instances"""
return LEGACY_PATH(path)
# fmt: off
# Singleton type for NOTSET, as described in:
# https://www.python.org/dev/peps/pep-0484/#support-for-singleton-types-in-unions
class NotSetType(enum.Enum):
token = 0
NOTSET: Final = NotSetType.token
# fmt: on
def iscoroutinefunction(func: object) -> bool:
"""Return True if func is a coroutine function (a function defined with async
def syntax, and doesn't contain yield), or a function decorated with
@asyncio.coroutine.
Note: copied and modified from Python 3.5's builtin coroutines.py to avoid
importing asyncio directly, which in turns also initializes the "logging"
module as a side-effect (see issue #8).
"""
return inspect.iscoroutinefunction(func) or getattr(func, "_is_coroutine", False)
def is_async_function(func: object) -> bool:
"""Return True if the given function seems to be an async function or
an async generator."""
return iscoroutinefunction(func) or inspect.isasyncgenfunction(func)
def getlocation(function, curdir: str | os.PathLike[str] | None = None) -> str:
function = get_real_func(function)
fn = Path(inspect.getfile(function))
lineno = function.__code__.co_firstlineno
if curdir is not None:
try:
relfn = fn.relative_to(curdir)
except ValueError:
pass
else:
return f"{relfn}:{lineno + 1}"
return f"{fn}:{lineno + 1}"
def num_mock_patch_args(function) -> int:
"""Return number of arguments used up by mock arguments (if any)."""
patchings = getattr(function, "patchings", None)
if not patchings:
return 0
mock_sentinel = getattr(sys.modules.get("mock"), "DEFAULT", object())
ut_mock_sentinel = getattr(sys.modules.get("unittest.mock"), "DEFAULT", object())
return len(
[
p
for p in patchings
if not p.attribute_name
and (p.new is mock_sentinel or p.new is ut_mock_sentinel)
]
)
def getfuncargnames(
function: Callable[..., object],
*,
name: str = "",
cls: type | None = None,
) -> tuple[str, ...]:
"""Return the names of a function's mandatory arguments.
Should return the names of all function arguments that:
* Aren't bound to an instance or type as in instance or class methods.
* Don't have default values.
* Aren't bound with functools.partial.
* Aren't replaced with mocks.
The cls arguments indicate that the function should be treated as a bound
method even though it's not unless the function is a static method.
The name parameter should be the original name in which the function was collected.
"""
# TODO(RonnyPfannschmidt): This function should be refactored when we
# revisit fixtures. The fixture mechanism should ask the node for
# the fixture names, and not try to obtain directly from the
# function object well after collection has occurred.
# The parameters attribute of a Signature object contains an
# ordered mapping of parameter names to Parameter instances. This
# creates a tuple of the names of the parameters that don't have
# defaults.
try:
parameters = signature(function).parameters.values()
except (ValueError, TypeError) as e:
from _pytest.outcomes import fail
fail(
f"Could not determine arguments of {function!r}: {e}",
pytrace=False,
)
arg_names = tuple(
p.name
for p in parameters
if (
p.kind is Parameter.POSITIONAL_OR_KEYWORD
or p.kind is Parameter.KEYWORD_ONLY
)
and p.default is Parameter.empty
)
if not name:
name = function.__name__
# If this function should be treated as a bound method even though
# it's passed as an unbound method or function, and its first parameter
# wasn't defined as positional only, remove the first parameter name.
if not any(p.kind is Parameter.POSITIONAL_ONLY for p in parameters) and (
# Not using `getattr` because we don't want to resolve the staticmethod.
# Not using `cls.__dict__` because we want to check the entire MRO.
cls
and not isinstance(
inspect.getattr_static(cls, name, default=None), staticmethod
)
):
arg_names = arg_names[1:]
# Remove any names that will be replaced with mocks.
if hasattr(function, "__wrapped__"):
arg_names = arg_names[num_mock_patch_args(function) :]
return arg_names
def get_default_arg_names(function: Callable[..., Any]) -> tuple[str, ...]:
# Note: this code intentionally mirrors the code at the beginning of
# getfuncargnames, to get the arguments which were excluded from its result
# because they had default values.
return tuple(
p.name
for p in signature(function).parameters.values()
if p.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
and p.default is not Parameter.empty
)
_non_printable_ascii_translate_table = {
i: f"\\x{i:02x}" for i in range(128) if i not in range(32, 127)
}
_non_printable_ascii_translate_table.update(
{ord("\t"): "\\t", ord("\r"): "\\r", ord("\n"): "\\n"}
)
def ascii_escaped(val: bytes | str) -> str:
r"""If val is pure ASCII, return it as an str, otherwise, escape
bytes objects into a sequence of escaped bytes:
b'\xc3\xb4\xc5\xd6' -> r'\xc3\xb4\xc5\xd6'
and escapes strings into a sequence of escaped unicode ids, e.g.:
r'4\nV\U00043efa\x0eMXWB\x1e\u3028\u15fd\xcd\U0007d944'
Note:
The obvious "v.decode('unicode-escape')" will return
valid UTF-8 unicode if it finds them in bytes, but we
want to return escaped bytes for any byte, even if they match
a UTF-8 string.
"""
if isinstance(val, bytes):
ret = val.decode("ascii", "backslashreplace")
else:
ret = val.encode("unicode_escape").decode("ascii")
return ret.translate(_non_printable_ascii_translate_table)
def get_real_func(obj):
"""Get the real function object of the (possibly) wrapped object by
:func:`functools.wraps`, or :func:`functools.partial`."""
obj = inspect.unwrap(obj)
if isinstance(obj, functools.partial):
obj = obj.func
return obj
def getimfunc(func):
try:
return func.__func__
except AttributeError:
return func
def safe_getattr(object: Any, name: str, default: Any) -> Any:
"""Like getattr but return default upon any Exception or any OutcomeException.
Attribute access can potentially fail for 'evil' Python objects.
See issue #214.
It catches OutcomeException because of #2490 (issue #580), new outcomes
are derived from BaseException instead of Exception (for more details
check #2707).
"""
from _pytest.outcomes import TEST_OUTCOME
try:
return getattr(object, name, default)
except TEST_OUTCOME:
return default
def safe_isclass(obj: object) -> bool:
"""Ignore any exception via isinstance on Python 3."""
try:
return inspect.isclass(obj)
except Exception:
return False
def get_user_id() -> int | None:
"""Return the current process's real user id or None if it could not be
determined.
:return: The user id or None if it could not be determined.
"""
# mypy follows the version and platform checking expectation of PEP 484:
# https://mypy.readthedocs.io/en/stable/common_issues.html?highlight=platform#python-version-and-system-platform-checks
# Containment checks are too complex for mypy v1.5.0 and cause failure.
if sys.platform == "win32" or sys.platform == "emscripten":
# win32 does not have a getuid() function.
# Emscripten has a return 0 stub.
return None
else:
# On other platforms, a return value of -1 is assumed to indicate that
# the current process's real user id could not be determined.
ERROR = -1
uid = os.getuid()
return uid if uid != ERROR else None
# Perform exhaustiveness checking.
#
# Consider this example:
#
# MyUnion = Union[int, str]
#
# def handle(x: MyUnion) -> int {
# if isinstance(x, int):
# return 1
# elif isinstance(x, str):
# return 2
# else:
# raise Exception('unreachable')
#
# Now suppose we add a new variant:
#
# MyUnion = Union[int, str, bytes]
#
# After doing this, we must remember ourselves to go and update the handle
# function to handle the new variant.
#
# With `assert_never` we can do better:
#
# // raise Exception('unreachable')
# return assert_never(x)
#
# Now, if we forget to handle the new variant, the type-checker will emit a
# compile-time error, instead of the runtime error we would have gotten
# previously.
#
# This also work for Enums (if you use `is` to compare) and Literals.
def assert_never(value: NoReturn) -> NoReturn:
assert False, f"Unhandled value: {value} ({type(value).__name__})"
class CallableBool:
"""
A bool-like object that can also be called, returning its true/false value.
Used for backwards compatibility in cases where something was supposed to be a method
but was implemented as a simple attribute by mistake (see `TerminalReporter.isatty`).
Do not use in new code.
"""
def __init__(self, value: bool) -> None:
self._value = value
def __bool__(self) -> bool:
return self._value
def __call__(self) -> bool:
return self._value

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,533 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import argparse
from collections.abc import Callable
from collections.abc import Mapping
from collections.abc import Sequence
import os
from typing import Any
from typing import cast
from typing import final
from typing import Literal
from typing import NoReturn
import _pytest._io
from _pytest.config.exceptions import UsageError
from _pytest.deprecated import check_ispytest
FILE_OR_DIR = "file_or_dir"
class NotSet:
def __repr__(self) -> str:
return "<notset>"
NOT_SET = NotSet()
@final
class Parser:
"""Parser for command line arguments and ini-file values.
:ivar extra_info: Dict of generic param -> value to display in case
there's an error processing the command line arguments.
"""
prog: str | None = None
def __init__(
self,
usage: str | None = None,
processopt: Callable[[Argument], None] | None = None,
*,
_ispytest: bool = False,
) -> None:
check_ispytest(_ispytest)
self._anonymous = OptionGroup("Custom options", parser=self, _ispytest=True)
self._groups: list[OptionGroup] = []
self._processopt = processopt
self._usage = usage
self._inidict: dict[str, tuple[str, str | None, Any]] = {}
self._ininames: list[str] = []
self.extra_info: dict[str, Any] = {}
def processoption(self, option: Argument) -> None:
if self._processopt:
if option.dest:
self._processopt(option)
def getgroup(
self, name: str, description: str = "", after: str | None = None
) -> OptionGroup:
"""Get (or create) a named option Group.
:param name: Name of the option group.
:param description: Long description for --help output.
:param after: Name of another group, used for ordering --help output.
:returns: The option group.
The returned group object has an ``addoption`` method with the same
signature as :func:`parser.addoption <pytest.Parser.addoption>` but
will be shown in the respective group in the output of
``pytest --help``.
"""
for group in self._groups:
if group.name == name:
return group
group = OptionGroup(name, description, parser=self, _ispytest=True)
i = 0
for i, grp in enumerate(self._groups):
if grp.name == after:
break
self._groups.insert(i + 1, group)
return group
def addoption(self, *opts: str, **attrs: Any) -> None:
"""Register a command line option.
:param opts:
Option names, can be short or long options.
:param attrs:
Same attributes as the argparse library's :meth:`add_argument()
<argparse.ArgumentParser.add_argument>` function accepts.
After command line parsing, options are available on the pytest config
object via ``config.option.NAME`` where ``NAME`` is usually set
by passing a ``dest`` attribute, for example
``addoption("--long", dest="NAME", ...)``.
"""
self._anonymous.addoption(*opts, **attrs)
def parse(
self,
args: Sequence[str | os.PathLike[str]],
namespace: argparse.Namespace | None = None,
) -> argparse.Namespace:
from _pytest._argcomplete import try_argcomplete
self.optparser = self._getparser()
try_argcomplete(self.optparser)
strargs = [os.fspath(x) for x in args]
return self.optparser.parse_args(strargs, namespace=namespace)
def _getparser(self) -> MyOptionParser:
from _pytest._argcomplete import filescompleter
optparser = MyOptionParser(self, self.extra_info, prog=self.prog)
groups = [*self._groups, self._anonymous]
for group in groups:
if group.options:
desc = group.description or group.name
arggroup = optparser.add_argument_group(desc)
for option in group.options:
n = option.names()
a = option.attrs()
arggroup.add_argument(*n, **a)
file_or_dir_arg = optparser.add_argument(FILE_OR_DIR, nargs="*")
# bash like autocompletion for dirs (appending '/')
# Type ignored because typeshed doesn't know about argcomplete.
file_or_dir_arg.completer = filescompleter # type: ignore
return optparser
def parse_setoption(
self,
args: Sequence[str | os.PathLike[str]],
option: argparse.Namespace,
namespace: argparse.Namespace | None = None,
) -> list[str]:
parsedoption = self.parse(args, namespace=namespace)
for name, value in parsedoption.__dict__.items():
setattr(option, name, value)
return cast(list[str], getattr(parsedoption, FILE_OR_DIR))
def parse_known_args(
self,
args: Sequence[str | os.PathLike[str]],
namespace: argparse.Namespace | None = None,
) -> argparse.Namespace:
"""Parse the known arguments at this point.
:returns: An argparse namespace object.
"""
return self.parse_known_and_unknown_args(args, namespace=namespace)[0]
def parse_known_and_unknown_args(
self,
args: Sequence[str | os.PathLike[str]],
namespace: argparse.Namespace | None = None,
) -> tuple[argparse.Namespace, list[str]]:
"""Parse the known arguments at this point, and also return the
remaining unknown arguments.
:returns:
A tuple containing an argparse namespace object for the known
arguments, and a list of the unknown arguments.
"""
optparser = self._getparser()
strargs = [os.fspath(x) for x in args]
return optparser.parse_known_args(strargs, namespace=namespace)
def addini(
self,
name: str,
help: str,
type: Literal["string", "paths", "pathlist", "args", "linelist", "bool"]
| None = None,
default: Any = NOT_SET,
) -> None:
"""Register an ini-file option.
:param name:
Name of the ini-variable.
:param type:
Type of the variable. Can be:
* ``string``: a string
* ``bool``: a boolean
* ``args``: a list of strings, separated as in a shell
* ``linelist``: a list of strings, separated by line breaks
* ``paths``: a list of :class:`pathlib.Path`, separated as in a shell
* ``pathlist``: a list of ``py.path``, separated as in a shell
* ``int``: an integer
* ``float``: a floating-point number
.. versionadded:: 8.4
The ``float`` and ``int`` types.
For ``paths`` and ``pathlist`` types, they are considered relative to the ini-file.
In case the execution is happening without an ini-file defined,
they will be considered relative to the current working directory (for example with ``--override-ini``).
.. versionadded:: 7.0
The ``paths`` variable type.
.. versionadded:: 8.1
Use the current working directory to resolve ``paths`` and ``pathlist`` in the absence of an ini-file.
Defaults to ``string`` if ``None`` or not passed.
:param default:
Default value if no ini-file option exists but is queried.
The value of ini-variables can be retrieved via a call to
:py:func:`config.getini(name) <pytest.Config.getini>`.
"""
assert type in (
None,
"string",
"paths",
"pathlist",
"args",
"linelist",
"bool",
"int",
"float",
)
if default is NOT_SET:
default = get_ini_default_for_type(type)
self._inidict[name] = (help, type, default)
self._ininames.append(name)
def get_ini_default_for_type(
type: Literal[
"string", "paths", "pathlist", "args", "linelist", "bool", "int", "float"
]
| None,
) -> Any:
"""
Used by addini to get the default value for a given ini-option type, when
default is not supplied.
"""
if type is None:
return ""
elif type in ("paths", "pathlist", "args", "linelist"):
return []
elif type == "bool":
return False
elif type == "int":
return 0
elif type == "float":
return 0.0
else:
return ""
class ArgumentError(Exception):
"""Raised if an Argument instance is created with invalid or
inconsistent arguments."""
def __init__(self, msg: str, option: Argument | str) -> None:
self.msg = msg
self.option_id = str(option)
def __str__(self) -> str:
if self.option_id:
return f"option {self.option_id}: {self.msg}"
else:
return self.msg
class Argument:
"""Class that mimics the necessary behaviour of optparse.Option.
It's currently a least effort implementation and ignoring choices
and integer prefixes.
https://docs.python.org/3/library/optparse.html#optparse-standard-option-types
"""
def __init__(self, *names: str, **attrs: Any) -> None:
"""Store params in private vars for use in add_argument."""
self._attrs = attrs
self._short_opts: list[str] = []
self._long_opts: list[str] = []
try:
self.type = attrs["type"]
except KeyError:
pass
try:
# Attribute existence is tested in Config._processopt.
self.default = attrs["default"]
except KeyError:
pass
self._set_opt_strings(names)
dest: str | None = attrs.get("dest")
if dest:
self.dest = dest
elif self._long_opts:
self.dest = self._long_opts[0][2:].replace("-", "_")
else:
try:
self.dest = self._short_opts[0][1:]
except IndexError as e:
self.dest = "???" # Needed for the error repr.
raise ArgumentError("need a long or short option", self) from e
def names(self) -> list[str]:
return self._short_opts + self._long_opts
def attrs(self) -> Mapping[str, Any]:
# Update any attributes set by processopt.
attrs = "default dest help".split()
attrs.append(self.dest)
for attr in attrs:
try:
self._attrs[attr] = getattr(self, attr)
except AttributeError:
pass
return self._attrs
def _set_opt_strings(self, opts: Sequence[str]) -> None:
"""Directly from optparse.
Might not be necessary as this is passed to argparse later on.
"""
for opt in opts:
if len(opt) < 2:
raise ArgumentError(
f"invalid option string {opt!r}: "
"must be at least two characters long",
self,
)
elif len(opt) == 2:
if not (opt[0] == "-" and opt[1] != "-"):
raise ArgumentError(
f"invalid short option string {opt!r}: "
"must be of the form -x, (x any non-dash char)",
self,
)
self._short_opts.append(opt)
else:
if not (opt[0:2] == "--" and opt[2] != "-"):
raise ArgumentError(
f"invalid long option string {opt!r}: "
"must start with --, followed by non-dash",
self,
)
self._long_opts.append(opt)
def __repr__(self) -> str:
args: list[str] = []
if self._short_opts:
args += ["_short_opts: " + repr(self._short_opts)]
if self._long_opts:
args += ["_long_opts: " + repr(self._long_opts)]
args += ["dest: " + repr(self.dest)]
if hasattr(self, "type"):
args += ["type: " + repr(self.type)]
if hasattr(self, "default"):
args += ["default: " + repr(self.default)]
return "Argument({})".format(", ".join(args))
class OptionGroup:
"""A group of options shown in its own section."""
def __init__(
self,
name: str,
description: str = "",
parser: Parser | None = None,
*,
_ispytest: bool = False,
) -> None:
check_ispytest(_ispytest)
self.name = name
self.description = description
self.options: list[Argument] = []
self.parser = parser
def addoption(self, *opts: str, **attrs: Any) -> None:
"""Add an option to this group.
If a shortened version of a long option is specified, it will
be suppressed in the help. ``addoption('--twowords', '--two-words')``
results in help showing ``--two-words`` only, but ``--twowords`` gets
accepted **and** the automatic destination is in ``args.twowords``.
:param opts:
Option names, can be short or long options.
:param attrs:
Same attributes as the argparse library's :meth:`add_argument()
<argparse.ArgumentParser.add_argument>` function accepts.
"""
conflict = set(opts).intersection(
name for opt in self.options for name in opt.names()
)
if conflict:
raise ValueError(f"option names {conflict} already added")
option = Argument(*opts, **attrs)
self._addoption_instance(option, shortupper=False)
def _addoption(self, *opts: str, **attrs: Any) -> None:
option = Argument(*opts, **attrs)
self._addoption_instance(option, shortupper=True)
def _addoption_instance(self, option: Argument, shortupper: bool = False) -> None:
if not shortupper:
for opt in option._short_opts:
if opt[0] == "-" and opt[1].islower():
raise ValueError("lowercase shortoptions reserved")
if self.parser:
self.parser.processoption(option)
self.options.append(option)
class MyOptionParser(argparse.ArgumentParser):
def __init__(
self,
parser: Parser,
extra_info: dict[str, Any] | None = None,
prog: str | None = None,
) -> None:
self._parser = parser
super().__init__(
prog=prog,
usage=parser._usage,
add_help=False,
formatter_class=DropShorterLongHelpFormatter,
allow_abbrev=False,
fromfile_prefix_chars="@",
)
# extra_info is a dict of (param -> value) to display if there's
# an usage error to provide more contextual information to the user.
self.extra_info = extra_info if extra_info else {}
def error(self, message: str) -> NoReturn:
"""Transform argparse error message into UsageError."""
msg = f"{self.prog}: error: {message}"
if hasattr(self._parser, "_config_source_hint"):
msg = f"{msg} ({self._parser._config_source_hint})"
raise UsageError(self.format_usage() + msg)
# Type ignored because typeshed has a very complex type in the superclass.
def parse_args( # type: ignore
self,
args: Sequence[str] | None = None,
namespace: argparse.Namespace | None = None,
) -> argparse.Namespace:
"""Allow splitting of positional arguments."""
parsed, unrecognized = self.parse_known_args(args, namespace)
if unrecognized:
for arg in unrecognized:
if arg and arg[0] == "-":
lines = [
"unrecognized arguments: {}".format(" ".join(unrecognized))
]
for k, v in sorted(self.extra_info.items()):
lines.append(f" {k}: {v}")
self.error("\n".join(lines))
getattr(parsed, FILE_OR_DIR).extend(unrecognized)
return parsed
class DropShorterLongHelpFormatter(argparse.HelpFormatter):
"""Shorten help for long options that differ only in extra hyphens.
- Collapse **long** options that are the same except for extra hyphens.
- Shortcut if there are only two options and one of them is a short one.
- Cache result on the action object as this is called at least 2 times.
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
# Use more accurate terminal width.
if "width" not in kwargs:
kwargs["width"] = _pytest._io.get_terminal_width()
super().__init__(*args, **kwargs)
def _format_action_invocation(self, action: argparse.Action) -> str:
orgstr = super()._format_action_invocation(action)
if orgstr and orgstr[0] != "-": # only optional arguments
return orgstr
res: str | None = getattr(action, "_formatted_action_invocation", None)
if res:
return res
options = orgstr.split(", ")
if len(options) == 2 and (len(options[0]) == 2 or len(options[1]) == 2):
# a shortcut for '-h, --help' or '--abc', '-a'
action._formatted_action_invocation = orgstr # type: ignore
return orgstr
return_list = []
short_long: dict[str, str] = {}
for option in options:
if len(option) == 2 or option[2] == " ":
continue
if not option.startswith("--"):
raise ArgumentError(
f'long optional argument without "--": [{option}]', option
)
xxoption = option[2:]
shortened = xxoption.replace("-", "")
if shortened not in short_long or len(short_long[shortened]) < len(
xxoption
):
short_long[shortened] = xxoption
# now short_long has been filled out to the longest with dashes
# **and** we keep the right option ordering from add_argument
for option in options:
if len(option) == 2 or option[2] == " ":
return_list.append(option)
if option[2:] == short_long.get(option.replace("-", "")):
return_list.append(option.replace(" ", "=", 1))
formatted_action_invocation = ", ".join(return_list)
action._formatted_action_invocation = formatted_action_invocation # type: ignore
return formatted_action_invocation
def _split_lines(self, text, width):
"""Wrap lines after splitting on original newlines.
This allows to have explicit line breaks in the help text.
"""
import textwrap
lines = []
for line in text.splitlines():
lines.extend(textwrap.wrap(line.strip(), width))
return lines

View File

@@ -0,0 +1,85 @@
from __future__ import annotations
from collections.abc import Mapping
import functools
from pathlib import Path
from typing import Any
import warnings
import pluggy
from ..compat import LEGACY_PATH
from ..compat import legacy_path
from ..deprecated import HOOK_LEGACY_PATH_ARG
# hookname: (Path, LEGACY_PATH)
imply_paths_hooks: Mapping[str, tuple[str, str]] = {
"pytest_ignore_collect": ("collection_path", "path"),
"pytest_collect_file": ("file_path", "path"),
"pytest_pycollect_makemodule": ("module_path", "path"),
"pytest_report_header": ("start_path", "startdir"),
"pytest_report_collectionfinish": ("start_path", "startdir"),
}
def _check_path(path: Path, fspath: LEGACY_PATH) -> None:
if Path(fspath) != path:
raise ValueError(
f"Path({fspath!r}) != {path!r}\n"
"if both path and fspath are given they need to be equal"
)
class PathAwareHookProxy:
"""
this helper wraps around hook callers
until pluggy supports fixingcalls, this one will do
it currently doesn't return full hook caller proxies for fixed hooks,
this may have to be changed later depending on bugs
"""
def __init__(self, hook_relay: pluggy.HookRelay) -> None:
self._hook_relay = hook_relay
def __dir__(self) -> list[str]:
return dir(self._hook_relay)
def __getattr__(self, key: str) -> pluggy.HookCaller:
hook: pluggy.HookCaller = getattr(self._hook_relay, key)
if key not in imply_paths_hooks:
self.__dict__[key] = hook
return hook
else:
path_var, fspath_var = imply_paths_hooks[key]
@functools.wraps(hook)
def fixed_hook(**kw: Any) -> Any:
path_value: Path | None = kw.pop(path_var, None)
fspath_value: LEGACY_PATH | None = kw.pop(fspath_var, None)
if fspath_value is not None:
warnings.warn(
HOOK_LEGACY_PATH_ARG.format(
pylib_path_arg=fspath_var, pathlib_path_arg=path_var
),
stacklevel=2,
)
if path_value is not None:
if fspath_value is not None:
_check_path(path_value, fspath_value)
else:
fspath_value = legacy_path(path_value)
else:
assert fspath_value is not None
path_value = Path(fspath_value)
kw[path_var] = path_value
kw[fspath_var] = fspath_value
return hook(**kw)
fixed_hook.name = hook.name # type: ignore[attr-defined]
fixed_hook.spec = hook.spec # type: ignore[attr-defined]
fixed_hook.__name__ = key
self.__dict__[key] = fixed_hook
return fixed_hook # type: ignore[return-value]

View File

@@ -0,0 +1,13 @@
from __future__ import annotations
from typing import final
@final
class UsageError(Exception):
"""Error in pytest usage or invocation."""
class PrintHelp(Exception):
"""Raised when pytest should print its help to skip the rest of the
argument parsing and validation."""

View File

@@ -0,0 +1,239 @@
from __future__ import annotations
from collections.abc import Iterable
from collections.abc import Sequence
import os
from pathlib import Path
import sys
from typing import TYPE_CHECKING
import iniconfig
from .exceptions import UsageError
from _pytest.outcomes import fail
from _pytest.pathlib import absolutepath
from _pytest.pathlib import commonpath
from _pytest.pathlib import safe_exists
if TYPE_CHECKING:
from typing import Union
from typing_extensions import TypeAlias
# Even though TOML supports richer data types, all values are converted to str/list[str] during
# parsing to maintain compatibility with the rest of the configuration system.
ConfigDict: TypeAlias = dict[str, Union[str, list[str]]]
def _parse_ini_config(path: Path) -> iniconfig.IniConfig:
"""Parse the given generic '.ini' file using legacy IniConfig parser, returning
the parsed object.
Raise UsageError if the file cannot be parsed.
"""
try:
return iniconfig.IniConfig(str(path))
except iniconfig.ParseError as exc:
raise UsageError(str(exc)) from exc
def load_config_dict_from_file(
filepath: Path,
) -> ConfigDict | None:
"""Load pytest configuration from the given file path, if supported.
Return None if the file does not contain valid pytest configuration.
"""
# Configuration from ini files are obtained from the [pytest] section, if present.
if filepath.suffix == ".ini":
iniconfig = _parse_ini_config(filepath)
if "pytest" in iniconfig:
return dict(iniconfig["pytest"].items())
else:
# "pytest.ini" files are always the source of configuration, even if empty.
if filepath.name == "pytest.ini":
return {}
# '.cfg' files are considered if they contain a "[tool:pytest]" section.
elif filepath.suffix == ".cfg":
iniconfig = _parse_ini_config(filepath)
if "tool:pytest" in iniconfig.sections:
return dict(iniconfig["tool:pytest"].items())
elif "pytest" in iniconfig.sections:
# If a setup.cfg contains a "[pytest]" section, we raise a failure to indicate users that
# plain "[pytest]" sections in setup.cfg files is no longer supported (#3086).
fail(CFG_PYTEST_SECTION.format(filename="setup.cfg"), pytrace=False)
# '.toml' files are considered if they contain a [tool.pytest.ini_options] table.
elif filepath.suffix == ".toml":
if sys.version_info >= (3, 11):
import tomllib
else:
import tomli as tomllib
toml_text = filepath.read_text(encoding="utf-8")
try:
config = tomllib.loads(toml_text)
except tomllib.TOMLDecodeError as exc:
raise UsageError(f"{filepath}: {exc}") from exc
result = config.get("tool", {}).get("pytest", {}).get("ini_options", None)
if result is not None:
# TOML supports richer data types than ini files (strings, arrays, floats, ints, etc),
# however we need to convert all scalar values to str for compatibility with the rest
# of the configuration system, which expects strings only.
def make_scalar(v: object) -> str | list[str]:
return v if isinstance(v, list) else str(v)
return {k: make_scalar(v) for k, v in result.items()}
return None
def locate_config(
invocation_dir: Path,
args: Iterable[Path],
) -> tuple[Path | None, Path | None, ConfigDict]:
"""Search in the list of arguments for a valid ini-file for pytest,
and return a tuple of (rootdir, inifile, cfg-dict)."""
config_names = [
"pytest.ini",
".pytest.ini",
"pyproject.toml",
"tox.ini",
"setup.cfg",
]
args = [x for x in args if not str(x).startswith("-")]
if not args:
args = [invocation_dir]
found_pyproject_toml: Path | None = None
for arg in args:
argpath = absolutepath(arg)
for base in (argpath, *argpath.parents):
for config_name in config_names:
p = base / config_name
if p.is_file():
if p.name == "pyproject.toml" and found_pyproject_toml is None:
found_pyproject_toml = p
ini_config = load_config_dict_from_file(p)
if ini_config is not None:
return base, p, ini_config
if found_pyproject_toml is not None:
return found_pyproject_toml.parent, found_pyproject_toml, {}
return None, None, {}
def get_common_ancestor(
invocation_dir: Path,
paths: Iterable[Path],
) -> Path:
common_ancestor: Path | None = None
for path in paths:
if not path.exists():
continue
if common_ancestor is None:
common_ancestor = path
else:
if common_ancestor in path.parents or path == common_ancestor:
continue
elif path in common_ancestor.parents:
common_ancestor = path
else:
shared = commonpath(path, common_ancestor)
if shared is not None:
common_ancestor = shared
if common_ancestor is None:
common_ancestor = invocation_dir
elif common_ancestor.is_file():
common_ancestor = common_ancestor.parent
return common_ancestor
def get_dirs_from_args(args: Iterable[str]) -> list[Path]:
def is_option(x: str) -> bool:
return x.startswith("-")
def get_file_part_from_node_id(x: str) -> str:
return x.split("::")[0]
def get_dir_from_path(path: Path) -> Path:
if path.is_dir():
return path
return path.parent
# These look like paths but may not exist
possible_paths = (
absolutepath(get_file_part_from_node_id(arg))
for arg in args
if not is_option(arg)
)
return [get_dir_from_path(path) for path in possible_paths if safe_exists(path)]
CFG_PYTEST_SECTION = "[pytest] section in {filename} files is no longer supported, change to [tool:pytest] instead."
def determine_setup(
*,
inifile: str | None,
args: Sequence[str],
rootdir_cmd_arg: str | None,
invocation_dir: Path,
) -> tuple[Path, Path | None, ConfigDict]:
"""Determine the rootdir, inifile and ini configuration values from the
command line arguments.
:param inifile:
The `--inifile` command line argument, if given.
:param args:
The free command line arguments.
:param rootdir_cmd_arg:
The `--rootdir` command line argument, if given.
:param invocation_dir:
The working directory when pytest was invoked.
"""
rootdir = None
dirs = get_dirs_from_args(args)
if inifile:
inipath_ = absolutepath(inifile)
inipath: Path | None = inipath_
inicfg = load_config_dict_from_file(inipath_) or {}
if rootdir_cmd_arg is None:
rootdir = inipath_.parent
else:
ancestor = get_common_ancestor(invocation_dir, dirs)
rootdir, inipath, inicfg = locate_config(invocation_dir, [ancestor])
if rootdir is None and rootdir_cmd_arg is None:
for possible_rootdir in (ancestor, *ancestor.parents):
if (possible_rootdir / "setup.py").is_file():
rootdir = possible_rootdir
break
else:
if dirs != [ancestor]:
rootdir, inipath, inicfg = locate_config(invocation_dir, dirs)
if rootdir is None:
rootdir = get_common_ancestor(
invocation_dir, [invocation_dir, ancestor]
)
if is_fs_root(rootdir):
rootdir = ancestor
if rootdir_cmd_arg:
rootdir = absolutepath(os.path.expandvars(rootdir_cmd_arg))
if not rootdir.is_dir():
raise UsageError(
f"Directory '{rootdir}' not found. Check your '--rootdir' option."
)
assert rootdir is not None
return rootdir, inipath, inicfg or {}
def is_fs_root(p: Path) -> bool:
r"""
Return True if the given path is pointing to the root of the
file system ("/" on Unix and "C:\\" on Windows for example).
"""
return os.path.splitdrive(str(p))[1] == os.sep

View File

@@ -0,0 +1,407 @@
# mypy: allow-untyped-defs
# ruff: noqa: T100
"""Interactive debugging with PDB, the Python Debugger."""
from __future__ import annotations
import argparse
from collections.abc import Callable
from collections.abc import Generator
import functools
import sys
import types
from typing import Any
import unittest
from _pytest import outcomes
from _pytest._code import ExceptionInfo
from _pytest.capture import CaptureManager
from _pytest.config import Config
from _pytest.config import ConftestImportFailure
from _pytest.config import hookimpl
from _pytest.config import PytestPluginManager
from _pytest.config.argparsing import Parser
from _pytest.config.exceptions import UsageError
from _pytest.nodes import Node
from _pytest.reports import BaseReport
from _pytest.runner import CallInfo
def _validate_usepdb_cls(value: str) -> tuple[str, str]:
"""Validate syntax of --pdbcls option."""
try:
modname, classname = value.split(":")
except ValueError as e:
raise argparse.ArgumentTypeError(
f"{value!r} is not in the format 'modname:classname'"
) from e
return (modname, classname)
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group.addoption(
"--pdb",
dest="usepdb",
action="store_true",
help="Start the interactive Python debugger on errors or KeyboardInterrupt",
)
group.addoption(
"--pdbcls",
dest="usepdb_cls",
metavar="modulename:classname",
type=_validate_usepdb_cls,
help="Specify a custom interactive Python debugger for use with --pdb."
"For example: --pdbcls=IPython.terminal.debugger:TerminalPdb",
)
group.addoption(
"--trace",
dest="trace",
action="store_true",
help="Immediately break when running each test",
)
def pytest_configure(config: Config) -> None:
import pdb
if config.getvalue("trace"):
config.pluginmanager.register(PdbTrace(), "pdbtrace")
if config.getvalue("usepdb"):
config.pluginmanager.register(PdbInvoke(), "pdbinvoke")
pytestPDB._saved.append(
(pdb.set_trace, pytestPDB._pluginmanager, pytestPDB._config)
)
pdb.set_trace = pytestPDB.set_trace
pytestPDB._pluginmanager = config.pluginmanager
pytestPDB._config = config
# NOTE: not using pytest_unconfigure, since it might get called although
# pytest_configure was not (if another plugin raises UsageError).
def fin() -> None:
(
pdb.set_trace,
pytestPDB._pluginmanager,
pytestPDB._config,
) = pytestPDB._saved.pop()
config.add_cleanup(fin)
class pytestPDB:
"""Pseudo PDB that defers to the real pdb."""
_pluginmanager: PytestPluginManager | None = None
_config: Config | None = None
_saved: list[
tuple[Callable[..., None], PytestPluginManager | None, Config | None]
] = []
_recursive_debug = 0
_wrapped_pdb_cls: tuple[type[Any], type[Any]] | None = None
@classmethod
def _is_capturing(cls, capman: CaptureManager | None) -> str | bool:
if capman:
return capman.is_capturing()
return False
@classmethod
def _import_pdb_cls(cls, capman: CaptureManager | None):
if not cls._config:
import pdb
# Happens when using pytest.set_trace outside of a test.
return pdb.Pdb
usepdb_cls = cls._config.getvalue("usepdb_cls")
if cls._wrapped_pdb_cls and cls._wrapped_pdb_cls[0] == usepdb_cls:
return cls._wrapped_pdb_cls[1]
if usepdb_cls:
modname, classname = usepdb_cls
try:
__import__(modname)
mod = sys.modules[modname]
# Handle --pdbcls=pdb:pdb.Pdb (useful e.g. with pdbpp).
parts = classname.split(".")
pdb_cls = getattr(mod, parts[0])
for part in parts[1:]:
pdb_cls = getattr(pdb_cls, part)
except Exception as exc:
value = ":".join((modname, classname))
raise UsageError(
f"--pdbcls: could not import {value!r}: {exc}"
) from exc
else:
import pdb
pdb_cls = pdb.Pdb
wrapped_cls = cls._get_pdb_wrapper_class(pdb_cls, capman)
cls._wrapped_pdb_cls = (usepdb_cls, wrapped_cls)
return wrapped_cls
@classmethod
def _get_pdb_wrapper_class(cls, pdb_cls, capman: CaptureManager | None):
import _pytest.config
class PytestPdbWrapper(pdb_cls):
_pytest_capman = capman
_continued = False
def do_debug(self, arg):
cls._recursive_debug += 1
ret = super().do_debug(arg)
cls._recursive_debug -= 1
return ret
if hasattr(pdb_cls, "do_debug"):
do_debug.__doc__ = pdb_cls.do_debug.__doc__
def do_continue(self, arg):
ret = super().do_continue(arg)
if cls._recursive_debug == 0:
assert cls._config is not None
tw = _pytest.config.create_terminal_writer(cls._config)
tw.line()
capman = self._pytest_capman
capturing = pytestPDB._is_capturing(capman)
if capturing:
if capturing == "global":
tw.sep(">", "PDB continue (IO-capturing resumed)")
else:
tw.sep(
">",
f"PDB continue (IO-capturing resumed for {capturing})",
)
assert capman is not None
capman.resume()
else:
tw.sep(">", "PDB continue")
assert cls._pluginmanager is not None
cls._pluginmanager.hook.pytest_leave_pdb(config=cls._config, pdb=self)
self._continued = True
return ret
if hasattr(pdb_cls, "do_continue"):
do_continue.__doc__ = pdb_cls.do_continue.__doc__
do_c = do_cont = do_continue
def do_quit(self, arg):
# Raise Exit outcome when quit command is used in pdb.
#
# This is a bit of a hack - it would be better if BdbQuit
# could be handled, but this would require to wrap the
# whole pytest run, and adjust the report etc.
ret = super().do_quit(arg)
if cls._recursive_debug == 0:
outcomes.exit("Quitting debugger")
return ret
if hasattr(pdb_cls, "do_quit"):
do_quit.__doc__ = pdb_cls.do_quit.__doc__
do_q = do_quit
do_exit = do_quit
def setup(self, f, tb):
"""Suspend on setup().
Needed after do_continue resumed, and entering another
breakpoint again.
"""
ret = super().setup(f, tb)
if not ret and self._continued:
# pdb.setup() returns True if the command wants to exit
# from the interaction: do not suspend capturing then.
if self._pytest_capman:
self._pytest_capman.suspend_global_capture(in_=True)
return ret
def get_stack(self, f, t):
stack, i = super().get_stack(f, t)
if f is None:
# Find last non-hidden frame.
i = max(0, len(stack) - 1)
while i and stack[i][0].f_locals.get("__tracebackhide__", False):
i -= 1
return stack, i
return PytestPdbWrapper
@classmethod
def _init_pdb(cls, method, *args, **kwargs):
"""Initialize PDB debugging, dropping any IO capturing."""
import _pytest.config
if cls._pluginmanager is None:
capman: CaptureManager | None = None
else:
capman = cls._pluginmanager.getplugin("capturemanager")
if capman:
capman.suspend(in_=True)
if cls._config:
tw = _pytest.config.create_terminal_writer(cls._config)
tw.line()
if cls._recursive_debug == 0:
# Handle header similar to pdb.set_trace in py37+.
header = kwargs.pop("header", None)
if header is not None:
tw.sep(">", header)
else:
capturing = cls._is_capturing(capman)
if capturing == "global":
tw.sep(">", f"PDB {method} (IO-capturing turned off)")
elif capturing:
tw.sep(
">",
f"PDB {method} (IO-capturing turned off for {capturing})",
)
else:
tw.sep(">", f"PDB {method}")
_pdb = cls._import_pdb_cls(capman)(**kwargs)
if cls._pluginmanager:
cls._pluginmanager.hook.pytest_enter_pdb(config=cls._config, pdb=_pdb)
return _pdb
@classmethod
def set_trace(cls, *args, **kwargs) -> None:
"""Invoke debugging via ``Pdb.set_trace``, dropping any IO capturing."""
frame = sys._getframe().f_back
_pdb = cls._init_pdb("set_trace", *args, **kwargs)
_pdb.set_trace(frame)
class PdbInvoke:
def pytest_exception_interact(
self, node: Node, call: CallInfo[Any], report: BaseReport
) -> None:
capman = node.config.pluginmanager.getplugin("capturemanager")
if capman:
capman.suspend_global_capture(in_=True)
out, err = capman.read_global_capture()
sys.stdout.write(out)
sys.stdout.write(err)
assert call.excinfo is not None
if not isinstance(call.excinfo.value, unittest.SkipTest):
_enter_pdb(node, call.excinfo, report)
def pytest_internalerror(self, excinfo: ExceptionInfo[BaseException]) -> None:
exc_or_tb = _postmortem_exc_or_tb(excinfo)
post_mortem(exc_or_tb)
class PdbTrace:
@hookimpl(wrapper=True)
def pytest_pyfunc_call(self, pyfuncitem) -> Generator[None, object, object]:
wrap_pytest_function_for_tracing(pyfuncitem)
return (yield)
def wrap_pytest_function_for_tracing(pyfuncitem) -> None:
"""Change the Python function object of the given Function item by a
wrapper which actually enters pdb before calling the python function
itself, effectively leaving the user in the pdb prompt in the first
statement of the function."""
_pdb = pytestPDB._init_pdb("runcall")
testfunction = pyfuncitem.obj
# we can't just return `partial(pdb.runcall, testfunction)` because (on
# python < 3.7.4) runcall's first param is `func`, which means we'd get
# an exception if one of the kwargs to testfunction was called `func`.
@functools.wraps(testfunction)
def wrapper(*args, **kwargs) -> None:
func = functools.partial(testfunction, *args, **kwargs)
_pdb.runcall(func)
pyfuncitem.obj = wrapper
def maybe_wrap_pytest_function_for_tracing(pyfuncitem) -> None:
"""Wrap the given pytestfunct item for tracing support if --trace was given in
the command line."""
if pyfuncitem.config.getvalue("trace"):
wrap_pytest_function_for_tracing(pyfuncitem)
def _enter_pdb(
node: Node, excinfo: ExceptionInfo[BaseException], rep: BaseReport
) -> BaseReport:
# XXX we reuse the TerminalReporter's terminalwriter
# because this seems to avoid some encoding related troubles
# for not completely clear reasons.
tw = node.config.pluginmanager.getplugin("terminalreporter")._tw
tw.line()
showcapture = node.config.option.showcapture
for sectionname, content in (
("stdout", rep.capstdout),
("stderr", rep.capstderr),
("log", rep.caplog),
):
if showcapture in (sectionname, "all") and content:
tw.sep(">", "captured " + sectionname)
if content[-1:] == "\n":
content = content[:-1]
tw.line(content)
tw.sep(">", "traceback")
rep.toterminal(tw)
tw.sep(">", "entering PDB")
tb_or_exc = _postmortem_exc_or_tb(excinfo)
rep._pdbshown = True # type: ignore[attr-defined]
post_mortem(tb_or_exc)
return rep
def _postmortem_exc_or_tb(
excinfo: ExceptionInfo[BaseException],
) -> types.TracebackType | BaseException:
from doctest import UnexpectedException
get_exc = sys.version_info >= (3, 13)
if isinstance(excinfo.value, UnexpectedException):
# A doctest.UnexpectedException is not useful for post_mortem.
# Use the underlying exception instead:
underlying_exc = excinfo.value
if get_exc:
return underlying_exc.exc_info[1]
return underlying_exc.exc_info[2]
elif isinstance(excinfo.value, ConftestImportFailure):
# A config.ConftestImportFailure is not useful for post_mortem.
# Use the underlying exception instead:
cause = excinfo.value.cause
if get_exc:
return cause
assert cause.__traceback__ is not None
return cause.__traceback__
else:
assert excinfo._excinfo is not None
if get_exc:
return excinfo._excinfo[1]
return excinfo._excinfo[2]
def post_mortem(tb_or_exc: types.TracebackType | BaseException) -> None:
p = pytestPDB._init_pdb("post_mortem")
p.reset()
p.interaction(None, tb_or_exc)
if p.quitting:
outcomes.exit("Quitting debugger")

View File

@@ -0,0 +1,91 @@
"""Deprecation messages and bits of code used elsewhere in the codebase that
is planned to be removed in the next pytest release.
Keeping it in a central location makes it easy to track what is deprecated and should
be removed when the time comes.
All constants defined in this module should be either instances of
:class:`PytestWarning`, or :class:`UnformattedWarning`
in case of warnings which need to format their messages.
"""
from __future__ import annotations
from warnings import warn
from _pytest.warning_types import PytestDeprecationWarning
from _pytest.warning_types import PytestRemovedIn9Warning
from _pytest.warning_types import UnformattedWarning
# set of plugins which have been integrated into the core; we use this list to ignore
# them during registration to avoid conflicts
DEPRECATED_EXTERNAL_PLUGINS = {
"pytest_catchlog",
"pytest_capturelog",
"pytest_faulthandler",
}
# This can be* removed pytest 8, but it's harmless and common, so no rush to remove.
# * If you're in the future: "could have been".
YIELD_FIXTURE = PytestDeprecationWarning(
"@pytest.yield_fixture is deprecated.\n"
"Use @pytest.fixture instead; they are the same."
)
# This deprecation is never really meant to be removed.
PRIVATE = PytestDeprecationWarning("A private pytest class or function was used.")
HOOK_LEGACY_PATH_ARG = UnformattedWarning(
PytestRemovedIn9Warning,
"The ({pylib_path_arg}: py.path.local) argument is deprecated, please use ({pathlib_path_arg}: pathlib.Path)\n"
"see https://docs.pytest.org/en/latest/deprecations.html"
"#py-path-local-arguments-for-hooks-replaced-with-pathlib-path",
)
NODE_CTOR_FSPATH_ARG = UnformattedWarning(
PytestRemovedIn9Warning,
"The (fspath: py.path.local) argument to {node_type_name} is deprecated. "
"Please use the (path: pathlib.Path) argument instead.\n"
"See https://docs.pytest.org/en/latest/deprecations.html"
"#fspath-argument-for-node-constructors-replaced-with-pathlib-path",
)
HOOK_LEGACY_MARKING = UnformattedWarning(
PytestDeprecationWarning,
"The hook{type} {fullname} uses old-style configuration options (marks or attributes).\n"
"Please use the pytest.hook{type}({hook_opts}) decorator instead\n"
" to configure the hooks.\n"
" See https://docs.pytest.org/en/latest/deprecations.html"
"#configuring-hook-specs-impls-using-markers",
)
MARKED_FIXTURE = PytestRemovedIn9Warning(
"Marks applied to fixtures have no effect\n"
"See docs: https://docs.pytest.org/en/stable/deprecations.html#applying-a-mark-to-a-fixture-function"
)
# You want to make some `__init__` or function "private".
#
# def my_private_function(some, args):
# ...
#
# Do this:
#
# def my_private_function(some, args, *, _ispytest: bool = False):
# check_ispytest(_ispytest)
# ...
#
# Change all internal/allowed calls to
#
# my_private_function(some, args, _ispytest=True)
#
# All other calls will get the default _ispytest=False and trigger
# the warning (possibly error in the future).
def check_ispytest(ispytest: bool) -> None:
if not ispytest:
warn(PRIVATE, stacklevel=3)

View File

@@ -0,0 +1,754 @@
# mypy: allow-untyped-defs
"""Discover and run doctests in modules and test files."""
from __future__ import annotations
import bdb
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterable
from collections.abc import Sequence
from contextlib import contextmanager
import functools
import inspect
import os
from pathlib import Path
import platform
import re
import sys
import traceback
import types
from typing import Any
from typing import TYPE_CHECKING
import warnings
from _pytest import outcomes
from _pytest._code.code import ExceptionInfo
from _pytest._code.code import ReprFileLocation
from _pytest._code.code import TerminalRepr
from _pytest._io import TerminalWriter
from _pytest.compat import safe_getattr
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.fixtures import fixture
from _pytest.fixtures import TopRequest
from _pytest.nodes import Collector
from _pytest.nodes import Item
from _pytest.outcomes import OutcomeException
from _pytest.outcomes import skip
from _pytest.pathlib import fnmatch_ex
from _pytest.python import Module
from _pytest.python_api import approx
from _pytest.warning_types import PytestWarning
if TYPE_CHECKING:
import doctest
from typing_extensions import Self
DOCTEST_REPORT_CHOICE_NONE = "none"
DOCTEST_REPORT_CHOICE_CDIFF = "cdiff"
DOCTEST_REPORT_CHOICE_NDIFF = "ndiff"
DOCTEST_REPORT_CHOICE_UDIFF = "udiff"
DOCTEST_REPORT_CHOICE_ONLY_FIRST_FAILURE = "only_first_failure"
DOCTEST_REPORT_CHOICES = (
DOCTEST_REPORT_CHOICE_NONE,
DOCTEST_REPORT_CHOICE_CDIFF,
DOCTEST_REPORT_CHOICE_NDIFF,
DOCTEST_REPORT_CHOICE_UDIFF,
DOCTEST_REPORT_CHOICE_ONLY_FIRST_FAILURE,
)
# Lazy definition of runner class
RUNNER_CLASS = None
# Lazy definition of output checker class
CHECKER_CLASS: type[doctest.OutputChecker] | None = None
def pytest_addoption(parser: Parser) -> None:
parser.addini(
"doctest_optionflags",
"Option flags for doctests",
type="args",
default=["ELLIPSIS"],
)
parser.addini(
"doctest_encoding", "Encoding used for doctest files", default="utf-8"
)
group = parser.getgroup("collect")
group.addoption(
"--doctest-modules",
action="store_true",
default=False,
help="Run doctests in all .py modules",
dest="doctestmodules",
)
group.addoption(
"--doctest-report",
type=str.lower,
default="udiff",
help="Choose another output format for diffs on doctest failure",
choices=DOCTEST_REPORT_CHOICES,
dest="doctestreport",
)
group.addoption(
"--doctest-glob",
action="append",
default=[],
metavar="pat",
help="Doctests file matching pattern, default: test*.txt",
dest="doctestglob",
)
group.addoption(
"--doctest-ignore-import-errors",
action="store_true",
default=False,
help="Ignore doctest collection errors",
dest="doctest_ignore_import_errors",
)
group.addoption(
"--doctest-continue-on-failure",
action="store_true",
default=False,
help="For a given doctest, continue to run after the first failure",
dest="doctest_continue_on_failure",
)
def pytest_unconfigure() -> None:
global RUNNER_CLASS
RUNNER_CLASS = None
def pytest_collect_file(
file_path: Path,
parent: Collector,
) -> DoctestModule | DoctestTextfile | None:
config = parent.config
if file_path.suffix == ".py":
if config.option.doctestmodules and not any(
(_is_setup_py(file_path), _is_main_py(file_path))
):
return DoctestModule.from_parent(parent, path=file_path)
elif _is_doctest(config, file_path, parent):
return DoctestTextfile.from_parent(parent, path=file_path)
return None
def _is_setup_py(path: Path) -> bool:
if path.name != "setup.py":
return False
contents = path.read_bytes()
return b"setuptools" in contents or b"distutils" in contents
def _is_doctest(config: Config, path: Path, parent: Collector) -> bool:
if path.suffix in (".txt", ".rst") and parent.session.isinitpath(path):
return True
globs = config.getoption("doctestglob") or ["test*.txt"]
return any(fnmatch_ex(glob, path) for glob in globs)
def _is_main_py(path: Path) -> bool:
return path.name == "__main__.py"
class ReprFailDoctest(TerminalRepr):
def __init__(
self, reprlocation_lines: Sequence[tuple[ReprFileLocation, Sequence[str]]]
) -> None:
self.reprlocation_lines = reprlocation_lines
def toterminal(self, tw: TerminalWriter) -> None:
for reprlocation, lines in self.reprlocation_lines:
for line in lines:
tw.line(line)
reprlocation.toterminal(tw)
class MultipleDoctestFailures(Exception):
def __init__(self, failures: Sequence[doctest.DocTestFailure]) -> None:
super().__init__()
self.failures = failures
def _init_runner_class() -> type[doctest.DocTestRunner]:
import doctest
class PytestDoctestRunner(doctest.DebugRunner):
"""Runner to collect failures.
Note that the out variable in this case is a list instead of a
stdout-like object.
"""
def __init__(
self,
checker: doctest.OutputChecker | None = None,
verbose: bool | None = None,
optionflags: int = 0,
continue_on_failure: bool = True,
) -> None:
super().__init__(checker=checker, verbose=verbose, optionflags=optionflags)
self.continue_on_failure = continue_on_failure
def report_failure(
self,
out,
test: doctest.DocTest,
example: doctest.Example,
got: str,
) -> None:
failure = doctest.DocTestFailure(test, example, got)
if self.continue_on_failure:
out.append(failure)
else:
raise failure
def report_unexpected_exception(
self,
out,
test: doctest.DocTest,
example: doctest.Example,
exc_info: tuple[type[BaseException], BaseException, types.TracebackType],
) -> None:
if isinstance(exc_info[1], OutcomeException):
raise exc_info[1]
if isinstance(exc_info[1], bdb.BdbQuit):
outcomes.exit("Quitting debugger")
failure = doctest.UnexpectedException(test, example, exc_info)
if self.continue_on_failure:
out.append(failure)
else:
raise failure
return PytestDoctestRunner
def _get_runner(
checker: doctest.OutputChecker | None = None,
verbose: bool | None = None,
optionflags: int = 0,
continue_on_failure: bool = True,
) -> doctest.DocTestRunner:
# We need this in order to do a lazy import on doctest
global RUNNER_CLASS
if RUNNER_CLASS is None:
RUNNER_CLASS = _init_runner_class()
# Type ignored because the continue_on_failure argument is only defined on
# PytestDoctestRunner, which is lazily defined so can't be used as a type.
return RUNNER_CLASS( # type: ignore
checker=checker,
verbose=verbose,
optionflags=optionflags,
continue_on_failure=continue_on_failure,
)
class DoctestItem(Item):
def __init__(
self,
name: str,
parent: DoctestTextfile | DoctestModule,
runner: doctest.DocTestRunner,
dtest: doctest.DocTest,
) -> None:
super().__init__(name, parent)
self.runner = runner
self.dtest = dtest
# Stuff needed for fixture support.
self.obj = None
fm = self.session._fixturemanager
fixtureinfo = fm.getfixtureinfo(node=self, func=None, cls=None)
self._fixtureinfo = fixtureinfo
self.fixturenames = fixtureinfo.names_closure
self._initrequest()
@classmethod
def from_parent( # type: ignore[override]
cls,
parent: DoctestTextfile | DoctestModule,
*,
name: str,
runner: doctest.DocTestRunner,
dtest: doctest.DocTest,
) -> Self:
# incompatible signature due to imposed limits on subclass
"""The public named constructor."""
return super().from_parent(name=name, parent=parent, runner=runner, dtest=dtest)
def _initrequest(self) -> None:
self.funcargs: dict[str, object] = {}
self._request = TopRequest(self, _ispytest=True) # type: ignore[arg-type]
def setup(self) -> None:
self._request._fillfixtures()
globs = dict(getfixture=self._request.getfixturevalue)
for name, value in self._request.getfixturevalue("doctest_namespace").items():
globs[name] = value
self.dtest.globs.update(globs)
def runtest(self) -> None:
_check_all_skipped(self.dtest)
self._disable_output_capturing_for_darwin()
failures: list[doctest.DocTestFailure] = []
# Type ignored because we change the type of `out` from what
# doctest expects.
self.runner.run(self.dtest, out=failures) # type: ignore[arg-type]
if failures:
raise MultipleDoctestFailures(failures)
def _disable_output_capturing_for_darwin(self) -> None:
"""Disable output capturing. Otherwise, stdout is lost to doctest (#985)."""
if platform.system() != "Darwin":
return
capman = self.config.pluginmanager.getplugin("capturemanager")
if capman:
capman.suspend_global_capture(in_=True)
out, err = capman.read_global_capture()
sys.stdout.write(out)
sys.stderr.write(err)
# TODO: Type ignored -- breaks Liskov Substitution.
def repr_failure( # type: ignore[override]
self,
excinfo: ExceptionInfo[BaseException],
) -> str | TerminalRepr:
import doctest
failures: (
Sequence[doctest.DocTestFailure | doctest.UnexpectedException] | None
) = None
if isinstance(
excinfo.value, (doctest.DocTestFailure, doctest.UnexpectedException)
):
failures = [excinfo.value]
elif isinstance(excinfo.value, MultipleDoctestFailures):
failures = excinfo.value.failures
if failures is None:
return super().repr_failure(excinfo)
reprlocation_lines = []
for failure in failures:
example = failure.example
test = failure.test
filename = test.filename
if test.lineno is None:
lineno = None
else:
lineno = test.lineno + example.lineno + 1
message = type(failure).__name__
# TODO: ReprFileLocation doesn't expect a None lineno.
reprlocation = ReprFileLocation(filename, lineno, message) # type: ignore[arg-type]
checker = _get_checker()
report_choice = _get_report_choice(self.config.getoption("doctestreport"))
if lineno is not None:
assert failure.test.docstring is not None
lines = failure.test.docstring.splitlines(False)
# add line numbers to the left of the error message
assert test.lineno is not None
lines = [
f"{i + test.lineno + 1:03d} {x}" for (i, x) in enumerate(lines)
]
# trim docstring error lines to 10
lines = lines[max(example.lineno - 9, 0) : example.lineno + 1]
else:
lines = [
"EXAMPLE LOCATION UNKNOWN, not showing all tests of that example"
]
indent = ">>>"
for line in example.source.splitlines():
lines.append(f"??? {indent} {line}")
indent = "..."
if isinstance(failure, doctest.DocTestFailure):
lines += checker.output_difference(
example, failure.got, report_choice
).split("\n")
else:
inner_excinfo = ExceptionInfo.from_exc_info(failure.exc_info)
lines += [f"UNEXPECTED EXCEPTION: {inner_excinfo.value!r}"]
lines += [
x.strip("\n") for x in traceback.format_exception(*failure.exc_info)
]
reprlocation_lines.append((reprlocation, lines))
return ReprFailDoctest(reprlocation_lines)
def reportinfo(self) -> tuple[os.PathLike[str] | str, int | None, str]:
return self.path, self.dtest.lineno, f"[doctest] {self.name}"
def _get_flag_lookup() -> dict[str, int]:
import doctest
return dict(
DONT_ACCEPT_TRUE_FOR_1=doctest.DONT_ACCEPT_TRUE_FOR_1,
DONT_ACCEPT_BLANKLINE=doctest.DONT_ACCEPT_BLANKLINE,
NORMALIZE_WHITESPACE=doctest.NORMALIZE_WHITESPACE,
ELLIPSIS=doctest.ELLIPSIS,
IGNORE_EXCEPTION_DETAIL=doctest.IGNORE_EXCEPTION_DETAIL,
COMPARISON_FLAGS=doctest.COMPARISON_FLAGS,
ALLOW_UNICODE=_get_allow_unicode_flag(),
ALLOW_BYTES=_get_allow_bytes_flag(),
NUMBER=_get_number_flag(),
)
def get_optionflags(config: Config) -> int:
optionflags_str = config.getini("doctest_optionflags")
flag_lookup_table = _get_flag_lookup()
flag_acc = 0
for flag in optionflags_str:
flag_acc |= flag_lookup_table[flag]
return flag_acc
def _get_continue_on_failure(config: Config) -> bool:
continue_on_failure: bool = config.getvalue("doctest_continue_on_failure")
if continue_on_failure:
# We need to turn off this if we use pdb since we should stop at
# the first failure.
if config.getvalue("usepdb"):
continue_on_failure = False
return continue_on_failure
class DoctestTextfile(Module):
obj = None
def collect(self) -> Iterable[DoctestItem]:
import doctest
# Inspired by doctest.testfile; ideally we would use it directly,
# but it doesn't support passing a custom checker.
encoding = self.config.getini("doctest_encoding")
text = self.path.read_text(encoding)
filename = str(self.path)
name = self.path.name
globs = {"__name__": "__main__"}
optionflags = get_optionflags(self.config)
runner = _get_runner(
verbose=False,
optionflags=optionflags,
checker=_get_checker(),
continue_on_failure=_get_continue_on_failure(self.config),
)
parser = doctest.DocTestParser()
test = parser.get_doctest(text, globs, name, filename, 0)
if test.examples:
yield DoctestItem.from_parent(
self, name=test.name, runner=runner, dtest=test
)
def _check_all_skipped(test: doctest.DocTest) -> None:
"""Raise pytest.skip() if all examples in the given DocTest have the SKIP
option set."""
import doctest
all_skipped = all(x.options.get(doctest.SKIP, False) for x in test.examples)
if all_skipped:
skip("all tests skipped by +SKIP option")
def _is_mocked(obj: object) -> bool:
"""Return if an object is possibly a mock object by checking the
existence of a highly improbable attribute."""
return (
safe_getattr(obj, "pytest_mock_example_attribute_that_shouldnt_exist", None)
is not None
)
@contextmanager
def _patch_unwrap_mock_aware() -> Generator[None]:
"""Context manager which replaces ``inspect.unwrap`` with a version
that's aware of mock objects and doesn't recurse into them."""
real_unwrap = inspect.unwrap
def _mock_aware_unwrap(
func: Callable[..., Any], *, stop: Callable[[Any], Any] | None = None
) -> Any:
try:
if stop is None or stop is _is_mocked:
return real_unwrap(func, stop=_is_mocked)
_stop = stop
return real_unwrap(func, stop=lambda obj: _is_mocked(obj) or _stop(func))
except Exception as e:
warnings.warn(
f"Got {e!r} when unwrapping {func!r}. This is usually caused "
"by a violation of Python's object protocol; see e.g. "
"https://github.com/pytest-dev/pytest/issues/5080",
PytestWarning,
)
raise
inspect.unwrap = _mock_aware_unwrap
try:
yield
finally:
inspect.unwrap = real_unwrap
class DoctestModule(Module):
def collect(self) -> Iterable[DoctestItem]:
import doctest
class MockAwareDocTestFinder(doctest.DocTestFinder):
py_ver_info_minor = sys.version_info[:2]
is_find_lineno_broken = (
py_ver_info_minor < (3, 11)
or (py_ver_info_minor == (3, 11) and sys.version_info.micro < 9)
or (py_ver_info_minor == (3, 12) and sys.version_info.micro < 3)
)
if is_find_lineno_broken:
def _find_lineno(self, obj, source_lines):
"""On older Pythons, doctest code does not take into account
`@property`. https://github.com/python/cpython/issues/61648
Moreover, wrapped Doctests need to be unwrapped so the correct
line number is returned. #8796
"""
if isinstance(obj, property):
obj = getattr(obj, "fget", obj)
if hasattr(obj, "__wrapped__"):
# Get the main obj in case of it being wrapped
obj = inspect.unwrap(obj)
# Type ignored because this is a private function.
return super()._find_lineno( # type:ignore[misc]
obj,
source_lines,
)
if sys.version_info < (3, 10):
def _find(
self, tests, obj, name, module, source_lines, globs, seen
) -> None:
"""Override _find to work around issue in stdlib.
https://github.com/pytest-dev/pytest/issues/3456
https://github.com/python/cpython/issues/69718
"""
if _is_mocked(obj):
return # pragma: no cover
with _patch_unwrap_mock_aware():
# Type ignored because this is a private function.
super()._find( # type:ignore[misc]
tests, obj, name, module, source_lines, globs, seen
)
if sys.version_info < (3, 13):
def _from_module(self, module, object):
"""`cached_property` objects are never considered a part
of the 'current module'. As such they are skipped by doctest.
Here we override `_from_module` to check the underlying
function instead. https://github.com/python/cpython/issues/107995
"""
if isinstance(object, functools.cached_property):
object = object.func
# Type ignored because this is a private function.
return super()._from_module(module, object) # type: ignore[misc]
try:
module = self.obj
except Collector.CollectError:
if self.config.getvalue("doctest_ignore_import_errors"):
skip(f"unable to import module {self.path!r}")
else:
raise
# While doctests currently don't support fixtures directly, we still
# need to pick up autouse fixtures.
self.session._fixturemanager.parsefactories(self)
# Uses internal doctest module parsing mechanism.
finder = MockAwareDocTestFinder()
optionflags = get_optionflags(self.config)
runner = _get_runner(
verbose=False,
optionflags=optionflags,
checker=_get_checker(),
continue_on_failure=_get_continue_on_failure(self.config),
)
for test in finder.find(module, module.__name__):
if test.examples: # skip empty doctests
yield DoctestItem.from_parent(
self, name=test.name, runner=runner, dtest=test
)
def _init_checker_class() -> type[doctest.OutputChecker]:
import doctest
class LiteralsOutputChecker(doctest.OutputChecker):
# Based on doctest_nose_plugin.py from the nltk project
# (https://github.com/nltk/nltk) and on the "numtest" doctest extension
# by Sebastien Boisgerault (https://github.com/boisgera/numtest).
_unicode_literal_re = re.compile(r"(\W|^)[uU]([rR]?[\'\"])", re.UNICODE)
_bytes_literal_re = re.compile(r"(\W|^)[bB]([rR]?[\'\"])", re.UNICODE)
_number_re = re.compile(
r"""
(?P<number>
(?P<mantissa>
(?P<integer1> [+-]?\d*)\.(?P<fraction>\d+)
|
(?P<integer2> [+-]?\d+)\.
)
(?:
[Ee]
(?P<exponent1> [+-]?\d+)
)?
|
(?P<integer3> [+-]?\d+)
(?:
[Ee]
(?P<exponent2> [+-]?\d+)
)
)
""",
re.VERBOSE,
)
def check_output(self, want: str, got: str, optionflags: int) -> bool:
if super().check_output(want, got, optionflags):
return True
allow_unicode = optionflags & _get_allow_unicode_flag()
allow_bytes = optionflags & _get_allow_bytes_flag()
allow_number = optionflags & _get_number_flag()
if not allow_unicode and not allow_bytes and not allow_number:
return False
def remove_prefixes(regex: re.Pattern[str], txt: str) -> str:
return re.sub(regex, r"\1\2", txt)
if allow_unicode:
want = remove_prefixes(self._unicode_literal_re, want)
got = remove_prefixes(self._unicode_literal_re, got)
if allow_bytes:
want = remove_prefixes(self._bytes_literal_re, want)
got = remove_prefixes(self._bytes_literal_re, got)
if allow_number:
got = self._remove_unwanted_precision(want, got)
return super().check_output(want, got, optionflags)
def _remove_unwanted_precision(self, want: str, got: str) -> str:
wants = list(self._number_re.finditer(want))
gots = list(self._number_re.finditer(got))
if len(wants) != len(gots):
return got
offset = 0
for w, g in zip(wants, gots):
fraction: str | None = w.group("fraction")
exponent: str | None = w.group("exponent1")
if exponent is None:
exponent = w.group("exponent2")
precision = 0 if fraction is None else len(fraction)
if exponent is not None:
precision -= int(exponent)
if float(w.group()) == approx(float(g.group()), abs=10**-precision):
# They're close enough. Replace the text we actually
# got with the text we want, so that it will match when we
# check the string literally.
got = (
got[: g.start() + offset] + w.group() + got[g.end() + offset :]
)
offset += w.end() - w.start() - (g.end() - g.start())
return got
return LiteralsOutputChecker
def _get_checker() -> doctest.OutputChecker:
"""Return a doctest.OutputChecker subclass that supports some
additional options:
* ALLOW_UNICODE and ALLOW_BYTES options to ignore u'' and b''
prefixes (respectively) in string literals. Useful when the same
doctest should run in Python 2 and Python 3.
* NUMBER to ignore floating-point differences smaller than the
precision of the literal number in the doctest.
An inner class is used to avoid importing "doctest" at the module
level.
"""
global CHECKER_CLASS
if CHECKER_CLASS is None:
CHECKER_CLASS = _init_checker_class()
return CHECKER_CLASS()
def _get_allow_unicode_flag() -> int:
"""Register and return the ALLOW_UNICODE flag."""
import doctest
return doctest.register_optionflag("ALLOW_UNICODE")
def _get_allow_bytes_flag() -> int:
"""Register and return the ALLOW_BYTES flag."""
import doctest
return doctest.register_optionflag("ALLOW_BYTES")
def _get_number_flag() -> int:
"""Register and return the NUMBER flag."""
import doctest
return doctest.register_optionflag("NUMBER")
def _get_report_choice(key: str) -> int:
"""Return the actual `doctest` module flag value.
We want to do it as late as possible to avoid importing `doctest` and all
its dependencies when parsing options, as it adds overhead and breaks tests.
"""
import doctest
return {
DOCTEST_REPORT_CHOICE_UDIFF: doctest.REPORT_UDIFF,
DOCTEST_REPORT_CHOICE_CDIFF: doctest.REPORT_CDIFF,
DOCTEST_REPORT_CHOICE_NDIFF: doctest.REPORT_NDIFF,
DOCTEST_REPORT_CHOICE_ONLY_FIRST_FAILURE: doctest.REPORT_ONLY_FIRST_FAILURE,
DOCTEST_REPORT_CHOICE_NONE: 0,
}[key]
@fixture(scope="session")
def doctest_namespace() -> dict[str, Any]:
"""Fixture that returns a :py:class:`dict` that will be injected into the
namespace of doctests.
Usually this fixture is used in conjunction with another ``autouse`` fixture:
.. code-block:: python
@pytest.fixture(autouse=True)
def add_np(doctest_namespace):
doctest_namespace["np"] = numpy
For more details: :ref:`doctest_namespace`.
"""
return dict()

View File

@@ -0,0 +1,105 @@
from __future__ import annotations
from collections.abc import Generator
import os
import sys
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.nodes import Item
from _pytest.stash import StashKey
import pytest
fault_handler_original_stderr_fd_key = StashKey[int]()
fault_handler_stderr_fd_key = StashKey[int]()
def pytest_addoption(parser: Parser) -> None:
help = (
"Dump the traceback of all threads if a test takes "
"more than TIMEOUT seconds to finish"
)
parser.addini("faulthandler_timeout", help, default=0.0)
def pytest_configure(config: Config) -> None:
import faulthandler
# at teardown we want to restore the original faulthandler fileno
# but faulthandler has no api to return the original fileno
# so here we stash the stderr fileno to be used at teardown
# sys.stderr and sys.__stderr__ may be closed or patched during the session
# so we can't rely on their values being good at that point (#11572).
stderr_fileno = get_stderr_fileno()
if faulthandler.is_enabled():
config.stash[fault_handler_original_stderr_fd_key] = stderr_fileno
config.stash[fault_handler_stderr_fd_key] = os.dup(stderr_fileno)
faulthandler.enable(file=config.stash[fault_handler_stderr_fd_key])
def pytest_unconfigure(config: Config) -> None:
import faulthandler
faulthandler.disable()
# Close the dup file installed during pytest_configure.
if fault_handler_stderr_fd_key in config.stash:
os.close(config.stash[fault_handler_stderr_fd_key])
del config.stash[fault_handler_stderr_fd_key]
# Re-enable the faulthandler if it was originally enabled.
if fault_handler_original_stderr_fd_key in config.stash:
faulthandler.enable(config.stash[fault_handler_original_stderr_fd_key])
del config.stash[fault_handler_original_stderr_fd_key]
def get_stderr_fileno() -> int:
try:
fileno = sys.stderr.fileno()
# The Twisted Logger will return an invalid file descriptor since it is not backed
# by an FD. So, let's also forward this to the same code path as with pytest-xdist.
if fileno == -1:
raise AttributeError()
return fileno
except (AttributeError, ValueError):
# pytest-xdist monkeypatches sys.stderr with an object that is not an actual file.
# https://docs.python.org/3/library/faulthandler.html#issue-with-file-descriptors
# This is potentially dangerous, but the best we can do.
assert sys.__stderr__ is not None
return sys.__stderr__.fileno()
def get_timeout_config_value(config: Config) -> float:
return float(config.getini("faulthandler_timeout") or 0.0)
@pytest.hookimpl(wrapper=True, trylast=True)
def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
timeout = get_timeout_config_value(item.config)
if timeout > 0:
import faulthandler
stderr = item.config.stash[fault_handler_stderr_fd_key]
faulthandler.dump_traceback_later(timeout, file=stderr)
try:
return (yield)
finally:
faulthandler.cancel_dump_traceback_later()
else:
return (yield)
@pytest.hookimpl(tryfirst=True)
def pytest_enter_pdb() -> None:
"""Cancel any traceback dumping due to timeout before entering pdb."""
import faulthandler
faulthandler.cancel_dump_traceback_later()
@pytest.hookimpl(tryfirst=True)
def pytest_exception_interact() -> None:
"""Cancel any traceback dumping due to an interactive exception being
raised."""
import faulthandler
faulthandler.cancel_dump_traceback_later()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,45 @@
"""Provides a function to report all internal modules for using freezing
tools."""
from __future__ import annotations
from collections.abc import Iterator
import types
def freeze_includes() -> list[str]:
"""Return a list of module names used by pytest that should be
included by cx_freeze."""
import _pytest
result = list(_iter_all_modules(_pytest))
return result
def _iter_all_modules(
package: str | types.ModuleType,
prefix: str = "",
) -> Iterator[str]:
"""Iterate over the names of all modules that can be found in the given
package, recursively.
>>> import _pytest
>>> list(_iter_all_modules(_pytest))
['_pytest._argcomplete', '_pytest._code.code', ...]
"""
import os
import pkgutil
if isinstance(package, str):
path = package
else:
# Type ignored because typeshed doesn't define ModuleType.__path__
# (only defined on packages).
package_path = package.__path__
path, prefix = package_path[0], package.__name__ + "."
for _, name, is_package in pkgutil.iter_modules([path]):
if is_package:
for m in _iter_all_modules(os.path.join(path, name), prefix=name + "."):
yield prefix + m
else:
yield prefix + name

View File

@@ -0,0 +1,283 @@
# mypy: allow-untyped-defs
"""Version info, help messages, tracing configuration."""
from __future__ import annotations
from argparse import Action
from collections.abc import Generator
import os
import sys
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import PrintHelp
from _pytest.config.argparsing import Parser
from _pytest.terminal import TerminalReporter
import pytest
class HelpAction(Action):
"""An argparse Action that will raise an exception in order to skip the
rest of the argument parsing when --help is passed.
This prevents argparse from quitting due to missing required arguments
when any are defined, for example by ``pytest_addoption``.
This is similar to the way that the builtin argparse --help option is
implemented by raising SystemExit.
"""
def __init__(self, option_strings, dest=None, default=False, help=None):
super().__init__(
option_strings=option_strings,
dest=dest,
const=True,
default=default,
nargs=0,
help=help,
)
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, self.const)
# We should only skip the rest of the parsing after preparse is done.
if getattr(parser._parser, "after_preparse", False):
raise PrintHelp
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig")
group.addoption(
"--version",
"-V",
action="count",
default=0,
dest="version",
help="Display pytest version and information about plugins. "
"When given twice, also display information about plugins.",
)
group._addoption( # private to use reserved lower-case short option
"-h",
"--help",
action=HelpAction,
dest="help",
help="Show help message and configuration info",
)
group._addoption( # private to use reserved lower-case short option
"-p",
action="append",
dest="plugins",
default=[],
metavar="name",
help="Early-load given plugin module name or entry point (multi-allowed). "
"To avoid loading of plugins, use the `no:` prefix, e.g. "
"`no:doctest`. See also --disable-plugin-autoload.",
)
group.addoption(
"--disable-plugin-autoload",
action="store_true",
default=False,
help="Disable plugin auto-loading through entry point packaging metadata. "
"Only plugins explicitly specified in -p or env var PYTEST_PLUGINS will be loaded.",
)
group.addoption(
"--traceconfig",
"--trace-config",
action="store_true",
default=False,
help="Trace considerations of conftest.py files",
)
group.addoption(
"--debug",
action="store",
nargs="?",
const="pytestdebug.log",
dest="debug",
metavar="DEBUG_FILE_NAME",
help="Store internal tracing debug information in this log file. "
"This file is opened with 'w' and truncated as a result, care advised. "
"Default: pytestdebug.log.",
)
group._addoption( # private to use reserved lower-case short option
"-o",
"--override-ini",
dest="override_ini",
action="append",
help='Override ini option with "option=value" style, '
"e.g. `-o xfail_strict=True -o cache_dir=cache`.",
)
@pytest.hookimpl(wrapper=True)
def pytest_cmdline_parse() -> Generator[None, Config, Config]:
config = yield
if config.option.debug:
# --debug | --debug <file.log> was provided.
path = config.option.debug
debugfile = open(path, "w", encoding="utf-8")
debugfile.write(
"versions pytest-{}, "
"python-{}\ninvocation_dir={}\ncwd={}\nargs={}\n\n".format(
pytest.__version__,
".".join(map(str, sys.version_info)),
config.invocation_params.dir,
os.getcwd(),
config.invocation_params.args,
)
)
config.trace.root.setwriter(debugfile.write)
undo_tracing = config.pluginmanager.enable_tracing()
sys.stderr.write(f"writing pytest debug information to {path}\n")
def unset_tracing() -> None:
debugfile.close()
sys.stderr.write(f"wrote pytest debug information to {debugfile.name}\n")
config.trace.root.setwriter(None)
undo_tracing()
config.add_cleanup(unset_tracing)
return config
def showversion(config: Config) -> None:
if config.option.version > 1:
sys.stdout.write(
f"This is pytest version {pytest.__version__}, imported from {pytest.__file__}\n"
)
plugininfo = getpluginversioninfo(config)
if plugininfo:
for line in plugininfo:
sys.stdout.write(line + "\n")
else:
sys.stdout.write(f"pytest {pytest.__version__}\n")
def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
if config.option.version > 0:
showversion(config)
return 0
elif config.option.help:
config._do_configure()
showhelp(config)
config._ensure_unconfigure()
return 0
return None
def showhelp(config: Config) -> None:
import textwrap
reporter: TerminalReporter | None = config.pluginmanager.get_plugin(
"terminalreporter"
)
assert reporter is not None
tw = reporter._tw
tw.write(config._parser.optparser.format_help())
tw.line()
tw.line(
"[pytest] ini-options in the first "
"pytest.ini|tox.ini|setup.cfg|pyproject.toml file found:"
)
tw.line()
columns = tw.fullwidth # costly call
indent_len = 24 # based on argparse's max_help_position=24
indent = " " * indent_len
for name in config._parser._ininames:
help, type, default = config._parser._inidict[name]
if type is None:
type = "string"
if help is None:
raise TypeError(f"help argument cannot be None for {name}")
spec = f"{name} ({type}):"
tw.write(f" {spec}")
spec_len = len(spec)
if spec_len > (indent_len - 3):
# Display help starting at a new line.
tw.line()
helplines = textwrap.wrap(
help,
columns,
initial_indent=indent,
subsequent_indent=indent,
break_on_hyphens=False,
)
for line in helplines:
tw.line(line)
else:
# Display help starting after the spec, following lines indented.
tw.write(" " * (indent_len - spec_len - 2))
wrapped = textwrap.wrap(help, columns - indent_len, break_on_hyphens=False)
if wrapped:
tw.line(wrapped[0])
for line in wrapped[1:]:
tw.line(indent + line)
tw.line()
tw.line("Environment variables:")
vars = [
(
"CI",
"When set (regardless of value), pytest knows it is running in a "
"CI process and does not truncate summary info",
),
("BUILD_NUMBER", "Equivalent to CI"),
("PYTEST_ADDOPTS", "Extra command line options"),
("PYTEST_PLUGINS", "Comma-separated plugins to load during startup"),
("PYTEST_DISABLE_PLUGIN_AUTOLOAD", "Set to disable plugin auto-loading"),
("PYTEST_DEBUG", "Set to enable debug tracing of pytest's internals"),
]
for name, help in vars:
tw.line(f" {name:<24} {help}")
tw.line()
tw.line()
tw.line("to see available markers type: pytest --markers")
tw.line("to see available fixtures type: pytest --fixtures")
tw.line(
"(shown according to specified file_or_dir or current dir "
"if not specified; fixtures with leading '_' are only shown "
"with the '-v' option"
)
for warningreport in reporter.stats.get("warnings", []):
tw.line("warning : " + warningreport.message, red=True)
conftest_options = [("pytest_plugins", "list of plugin names to load")]
def getpluginversioninfo(config: Config) -> list[str]:
lines = []
plugininfo = config.pluginmanager.list_plugin_distinfo()
if plugininfo:
lines.append("registered third-party plugins:")
for plugin, dist in plugininfo:
loc = getattr(plugin, "__file__", repr(plugin))
content = f"{dist.project_name}-{dist.version} at {loc}"
lines.append(" " + content)
return lines
def pytest_report_header(config: Config) -> list[str]:
lines = []
if config.option.debug or config.option.traceconfig:
lines.append(f"using: pytest-{pytest.__version__}")
verinfo = getpluginversioninfo(config)
if verinfo:
lines.extend(verinfo)
if config.option.traceconfig:
lines.append("active plugins:")
items = config.pluginmanager.list_name_plugin()
for name, plugin in items:
if hasattr(plugin, "__file__"):
r = plugin.__file__
else:
r = repr(plugin)
lines.append(f" {name:<20}: {r}")
return lines

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,692 @@
# mypy: allow-untyped-defs
"""Report test results in JUnit-XML format, for use with Jenkins and build
integration servers.
Based on initial code from Ross Lawley.
Output conforms to
https://github.com/jenkinsci/xunit-plugin/blob/master/src/main/resources/org/jenkinsci/plugins/xunit/types/model/xsd/junit-10.xsd
"""
from __future__ import annotations
from collections.abc import Callable
import functools
import os
import platform
import re
import xml.etree.ElementTree as ET
from _pytest import nodes
from _pytest import timing
from _pytest._code.code import ExceptionRepr
from _pytest._code.code import ReprFileLocation
from _pytest.config import Config
from _pytest.config import filename_arg
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureRequest
from _pytest.reports import TestReport
from _pytest.stash import StashKey
from _pytest.terminal import TerminalReporter
import pytest
xml_key = StashKey["LogXML"]()
def bin_xml_escape(arg: object) -> str:
r"""Visually escape invalid XML characters.
For example, transforms
'hello\aworld\b'
into
'hello#x07world#x08'
Note that the #xABs are *not* XML escapes - missing the ampersand &#xAB.
The idea is to escape visually for the user rather than for XML itself.
"""
def repl(matchobj: re.Match[str]) -> str:
i = ord(matchobj.group())
if i <= 0xFF:
return f"#x{i:02X}"
else:
return f"#x{i:04X}"
# The spec range of valid chars is:
# Char ::= #x9 | #xA | #xD | [#x20-#xD7FF] | [#xE000-#xFFFD] | [#x10000-#x10FFFF]
# For an unknown(?) reason, we disallow #x7F (DEL) as well.
illegal_xml_re = (
"[^\u0009\u000a\u000d\u0020-\u007e\u0080-\ud7ff\ue000-\ufffd\u10000-\u10ffff]"
)
return re.sub(illegal_xml_re, repl, str(arg))
def merge_family(left, right) -> None:
result = {}
for kl, vl in left.items():
for kr, vr in right.items():
if not isinstance(vl, list):
raise TypeError(type(vl))
result[kl] = vl + vr
left.update(result)
families = { # pylint: disable=dict-init-mutate
"_base": {"testcase": ["classname", "name"]},
"_base_legacy": {"testcase": ["file", "line", "url"]},
}
# xUnit 1.x inherits legacy attributes.
families["xunit1"] = families["_base"].copy()
merge_family(families["xunit1"], families["_base_legacy"])
# xUnit 2.x uses strict base attributes.
families["xunit2"] = families["_base"]
class _NodeReporter:
def __init__(self, nodeid: str | TestReport, xml: LogXML) -> None:
self.id = nodeid
self.xml = xml
self.add_stats = self.xml.add_stats
self.family = self.xml.family
self.duration = 0.0
self.properties: list[tuple[str, str]] = []
self.nodes: list[ET.Element] = []
self.attrs: dict[str, str] = {}
def append(self, node: ET.Element) -> None:
self.xml.add_stats(node.tag)
self.nodes.append(node)
def add_property(self, name: str, value: object) -> None:
self.properties.append((str(name), bin_xml_escape(value)))
def add_attribute(self, name: str, value: object) -> None:
self.attrs[str(name)] = bin_xml_escape(value)
def make_properties_node(self) -> ET.Element | None:
"""Return a Junit node containing custom properties, if any."""
if self.properties:
properties = ET.Element("properties")
for name, value in self.properties:
properties.append(ET.Element("property", name=name, value=value))
return properties
return None
def record_testreport(self, testreport: TestReport) -> None:
names = mangle_test_address(testreport.nodeid)
existing_attrs = self.attrs
classnames = names[:-1]
if self.xml.prefix:
classnames.insert(0, self.xml.prefix)
attrs: dict[str, str] = {
"classname": ".".join(classnames),
"name": bin_xml_escape(names[-1]),
"file": testreport.location[0],
}
if testreport.location[1] is not None:
attrs["line"] = str(testreport.location[1])
if hasattr(testreport, "url"):
attrs["url"] = testreport.url
self.attrs = attrs
self.attrs.update(existing_attrs) # Restore any user-defined attributes.
# Preserve legacy testcase behavior.
if self.family == "xunit1":
return
# Filter out attributes not permitted by this test family.
# Including custom attributes because they are not valid here.
temp_attrs = {}
for key in self.attrs:
if key in families[self.family]["testcase"]:
temp_attrs[key] = self.attrs[key]
self.attrs = temp_attrs
def to_xml(self) -> ET.Element:
testcase = ET.Element("testcase", self.attrs, time=f"{self.duration:.3f}")
properties = self.make_properties_node()
if properties is not None:
testcase.append(properties)
testcase.extend(self.nodes)
return testcase
def _add_simple(self, tag: str, message: str, data: str | None = None) -> None:
node = ET.Element(tag, message=message)
node.text = bin_xml_escape(data)
self.append(node)
def write_captured_output(self, report: TestReport) -> None:
if not self.xml.log_passing_tests and report.passed:
return
content_out = report.capstdout
content_log = report.caplog
content_err = report.capstderr
if self.xml.logging == "no":
return
content_all = ""
if self.xml.logging in ["log", "all"]:
content_all = self._prepare_content(content_log, " Captured Log ")
if self.xml.logging in ["system-out", "out-err", "all"]:
content_all += self._prepare_content(content_out, " Captured Out ")
self._write_content(report, content_all, "system-out")
content_all = ""
if self.xml.logging in ["system-err", "out-err", "all"]:
content_all += self._prepare_content(content_err, " Captured Err ")
self._write_content(report, content_all, "system-err")
content_all = ""
if content_all:
self._write_content(report, content_all, "system-out")
def _prepare_content(self, content: str, header: str) -> str:
return "\n".join([header.center(80, "-"), content, ""])
def _write_content(self, report: TestReport, content: str, jheader: str) -> None:
tag = ET.Element(jheader)
tag.text = bin_xml_escape(content)
self.append(tag)
def append_pass(self, report: TestReport) -> None:
self.add_stats("passed")
def append_failure(self, report: TestReport) -> None:
# msg = str(report.longrepr.reprtraceback.extraline)
if hasattr(report, "wasxfail"):
self._add_simple("skipped", "xfail-marked test passes unexpectedly")
else:
assert report.longrepr is not None
reprcrash: ReprFileLocation | None = getattr(
report.longrepr, "reprcrash", None
)
if reprcrash is not None:
message = reprcrash.message
else:
message = str(report.longrepr)
message = bin_xml_escape(message)
self._add_simple("failure", message, str(report.longrepr))
def append_collect_error(self, report: TestReport) -> None:
# msg = str(report.longrepr.reprtraceback.extraline)
assert report.longrepr is not None
self._add_simple("error", "collection failure", str(report.longrepr))
def append_collect_skipped(self, report: TestReport) -> None:
self._add_simple("skipped", "collection skipped", str(report.longrepr))
def append_error(self, report: TestReport) -> None:
assert report.longrepr is not None
reprcrash: ReprFileLocation | None = getattr(report.longrepr, "reprcrash", None)
if reprcrash is not None:
reason = reprcrash.message
else:
reason = str(report.longrepr)
if report.when == "teardown":
msg = f'failed on teardown with "{reason}"'
else:
msg = f'failed on setup with "{reason}"'
self._add_simple("error", bin_xml_escape(msg), str(report.longrepr))
def append_skipped(self, report: TestReport) -> None:
if hasattr(report, "wasxfail"):
xfailreason = report.wasxfail
if xfailreason.startswith("reason: "):
xfailreason = xfailreason[8:]
xfailreason = bin_xml_escape(xfailreason)
skipped = ET.Element("skipped", type="pytest.xfail", message=xfailreason)
self.append(skipped)
else:
assert isinstance(report.longrepr, tuple)
filename, lineno, skipreason = report.longrepr
if skipreason.startswith("Skipped: "):
skipreason = skipreason[9:]
details = f"{filename}:{lineno}: {skipreason}"
skipped = ET.Element(
"skipped", type="pytest.skip", message=bin_xml_escape(skipreason)
)
skipped.text = bin_xml_escape(details)
self.append(skipped)
self.write_captured_output(report)
def finalize(self) -> None:
data = self.to_xml()
self.__dict__.clear()
# Type ignored because mypy doesn't like overriding a method.
# Also the return value doesn't match...
self.to_xml = lambda: data # type: ignore[method-assign]
def _warn_incompatibility_with_xunit2(
request: FixtureRequest, fixture_name: str
) -> None:
"""Emit a PytestWarning about the given fixture being incompatible with newer xunit revisions."""
from _pytest.warning_types import PytestWarning
xml = request.config.stash.get(xml_key, None)
if xml is not None and xml.family not in ("xunit1", "legacy"):
request.node.warn(
PytestWarning(
f"{fixture_name} is incompatible with junit_family '{xml.family}' (use 'legacy' or 'xunit1')"
)
)
@pytest.fixture
def record_property(request: FixtureRequest) -> Callable[[str, object], None]:
"""Add extra properties to the calling test.
User properties become part of the test report and are available to the
configured reporters, like JUnit XML.
The fixture is callable with ``name, value``. The value is automatically
XML-encoded.
Example::
def test_function(record_property):
record_property("example_key", 1)
"""
_warn_incompatibility_with_xunit2(request, "record_property")
def append_property(name: str, value: object) -> None:
request.node.user_properties.append((name, value))
return append_property
@pytest.fixture
def record_xml_attribute(request: FixtureRequest) -> Callable[[str, object], None]:
"""Add extra xml attributes to the tag for the calling test.
The fixture is callable with ``name, value``. The value is
automatically XML-encoded.
"""
from _pytest.warning_types import PytestExperimentalApiWarning
request.node.warn(
PytestExperimentalApiWarning("record_xml_attribute is an experimental feature")
)
_warn_incompatibility_with_xunit2(request, "record_xml_attribute")
# Declare noop
def add_attr_noop(name: str, value: object) -> None:
pass
attr_func = add_attr_noop
xml = request.config.stash.get(xml_key, None)
if xml is not None:
node_reporter = xml.node_reporter(request.node.nodeid)
attr_func = node_reporter.add_attribute
return attr_func
def _check_record_param_type(param: str, v: str) -> None:
"""Used by record_testsuite_property to check that the given parameter name is of the proper
type."""
__tracebackhide__ = True
if not isinstance(v, str):
msg = "{param} parameter needs to be a string, but {g} given" # type: ignore[unreachable]
raise TypeError(msg.format(param=param, g=type(v).__name__))
@pytest.fixture(scope="session")
def record_testsuite_property(request: FixtureRequest) -> Callable[[str, object], None]:
"""Record a new ``<property>`` tag as child of the root ``<testsuite>``.
This is suitable to writing global information regarding the entire test
suite, and is compatible with ``xunit2`` JUnit family.
This is a ``session``-scoped fixture which is called with ``(name, value)``. Example:
.. code-block:: python
def test_foo(record_testsuite_property):
record_testsuite_property("ARCH", "PPC")
record_testsuite_property("STORAGE_TYPE", "CEPH")
:param name:
The property name.
:param value:
The property value. Will be converted to a string.
.. warning::
Currently this fixture **does not work** with the
`pytest-xdist <https://github.com/pytest-dev/pytest-xdist>`__ plugin. See
:issue:`7767` for details.
"""
__tracebackhide__ = True
def record_func(name: str, value: object) -> None:
"""No-op function in case --junit-xml was not passed in the command-line."""
__tracebackhide__ = True
_check_record_param_type("name", name)
xml = request.config.stash.get(xml_key, None)
if xml is not None:
record_func = xml.add_global_property
return record_func
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("terminal reporting")
group.addoption(
"--junitxml",
"--junit-xml",
action="store",
dest="xmlpath",
metavar="path",
type=functools.partial(filename_arg, optname="--junitxml"),
default=None,
help="Create junit-xml style report file at given path",
)
group.addoption(
"--junitprefix",
"--junit-prefix",
action="store",
metavar="str",
default=None,
help="Prepend prefix to classnames in junit-xml output",
)
parser.addini(
"junit_suite_name", "Test suite name for JUnit report", default="pytest"
)
parser.addini(
"junit_logging",
"Write captured log messages to JUnit report: "
"one of no|log|system-out|system-err|out-err|all",
default="no",
)
parser.addini(
"junit_log_passing_tests",
"Capture log information for passing tests to JUnit report: ",
type="bool",
default=True,
)
parser.addini(
"junit_duration_report",
"Duration time to report: one of total|call",
default="total",
) # choices=['total', 'call'])
parser.addini(
"junit_family",
"Emit XML for schema: one of legacy|xunit1|xunit2",
default="xunit2",
)
def pytest_configure(config: Config) -> None:
xmlpath = config.option.xmlpath
# Prevent opening xmllog on worker nodes (xdist).
if xmlpath and not hasattr(config, "workerinput"):
junit_family = config.getini("junit_family")
config.stash[xml_key] = LogXML(
xmlpath,
config.option.junitprefix,
config.getini("junit_suite_name"),
config.getini("junit_logging"),
config.getini("junit_duration_report"),
junit_family,
config.getini("junit_log_passing_tests"),
)
config.pluginmanager.register(config.stash[xml_key])
def pytest_unconfigure(config: Config) -> None:
xml = config.stash.get(xml_key, None)
if xml:
del config.stash[xml_key]
config.pluginmanager.unregister(xml)
def mangle_test_address(address: str) -> list[str]:
path, possible_open_bracket, params = address.partition("[")
names = path.split("::")
# Convert file path to dotted path.
names[0] = names[0].replace(nodes.SEP, ".")
names[0] = re.sub(r"\.py$", "", names[0])
# Put any params back.
names[-1] += possible_open_bracket + params
return names
class LogXML:
def __init__(
self,
logfile,
prefix: str | None,
suite_name: str = "pytest",
logging: str = "no",
report_duration: str = "total",
family="xunit1",
log_passing_tests: bool = True,
) -> None:
logfile = os.path.expanduser(os.path.expandvars(logfile))
self.logfile = os.path.normpath(os.path.abspath(logfile))
self.prefix = prefix
self.suite_name = suite_name
self.logging = logging
self.log_passing_tests = log_passing_tests
self.report_duration = report_duration
self.family = family
self.stats: dict[str, int] = dict.fromkeys(
["error", "passed", "failure", "skipped"], 0
)
self.node_reporters: dict[tuple[str | TestReport, object], _NodeReporter] = {}
self.node_reporters_ordered: list[_NodeReporter] = []
self.global_properties: list[tuple[str, str]] = []
# List of reports that failed on call but teardown is pending.
self.open_reports: list[TestReport] = []
self.cnt_double_fail_tests = 0
# Replaces convenience family with real family.
if self.family == "legacy":
self.family = "xunit1"
def finalize(self, report: TestReport) -> None:
nodeid = getattr(report, "nodeid", report)
# Local hack to handle xdist report order.
workernode = getattr(report, "node", None)
reporter = self.node_reporters.pop((nodeid, workernode))
for propname, propvalue in report.user_properties:
reporter.add_property(propname, str(propvalue))
if reporter is not None:
reporter.finalize()
def node_reporter(self, report: TestReport | str) -> _NodeReporter:
nodeid: str | TestReport = getattr(report, "nodeid", report)
# Local hack to handle xdist report order.
workernode = getattr(report, "node", None)
key = nodeid, workernode
if key in self.node_reporters:
# TODO: breaks for --dist=each
return self.node_reporters[key]
reporter = _NodeReporter(nodeid, self)
self.node_reporters[key] = reporter
self.node_reporters_ordered.append(reporter)
return reporter
def add_stats(self, key: str) -> None:
if key in self.stats:
self.stats[key] += 1
def _opentestcase(self, report: TestReport) -> _NodeReporter:
reporter = self.node_reporter(report)
reporter.record_testreport(report)
return reporter
def pytest_runtest_logreport(self, report: TestReport) -> None:
"""Handle a setup/call/teardown report, generating the appropriate
XML tags as necessary.
Note: due to plugins like xdist, this hook may be called in interlaced
order with reports from other nodes. For example:
Usual call order:
-> setup node1
-> call node1
-> teardown node1
-> setup node2
-> call node2
-> teardown node2
Possible call order in xdist:
-> setup node1
-> call node1
-> setup node2
-> call node2
-> teardown node2
-> teardown node1
"""
close_report = None
if report.passed:
if report.when == "call": # ignore setup/teardown
reporter = self._opentestcase(report)
reporter.append_pass(report)
elif report.failed:
if report.when == "teardown":
# The following vars are needed when xdist plugin is used.
report_wid = getattr(report, "worker_id", None)
report_ii = getattr(report, "item_index", None)
close_report = next(
(
rep
for rep in self.open_reports
if (
rep.nodeid == report.nodeid
and getattr(rep, "item_index", None) == report_ii
and getattr(rep, "worker_id", None) == report_wid
)
),
None,
)
if close_report:
# We need to open new testcase in case we have failure in
# call and error in teardown in order to follow junit
# schema.
self.finalize(close_report)
self.cnt_double_fail_tests += 1
reporter = self._opentestcase(report)
if report.when == "call":
reporter.append_failure(report)
self.open_reports.append(report)
if not self.log_passing_tests:
reporter.write_captured_output(report)
else:
reporter.append_error(report)
elif report.skipped:
reporter = self._opentestcase(report)
reporter.append_skipped(report)
self.update_testcase_duration(report)
if report.when == "teardown":
reporter = self._opentestcase(report)
reporter.write_captured_output(report)
self.finalize(report)
report_wid = getattr(report, "worker_id", None)
report_ii = getattr(report, "item_index", None)
close_report = next(
(
rep
for rep in self.open_reports
if (
rep.nodeid == report.nodeid
and getattr(rep, "item_index", None) == report_ii
and getattr(rep, "worker_id", None) == report_wid
)
),
None,
)
if close_report:
self.open_reports.remove(close_report)
def update_testcase_duration(self, report: TestReport) -> None:
"""Accumulate total duration for nodeid from given report and update
the Junit.testcase with the new total if already created."""
if self.report_duration in {"total", report.when}:
reporter = self.node_reporter(report)
reporter.duration += getattr(report, "duration", 0.0)
def pytest_collectreport(self, report: TestReport) -> None:
if not report.passed:
reporter = self._opentestcase(report)
if report.failed:
reporter.append_collect_error(report)
else:
reporter.append_collect_skipped(report)
def pytest_internalerror(self, excrepr: ExceptionRepr) -> None:
reporter = self.node_reporter("internal")
reporter.attrs.update(classname="pytest", name="internal")
reporter._add_simple("error", "internal error", str(excrepr))
def pytest_sessionstart(self) -> None:
self.suite_start = timing.Instant()
def pytest_sessionfinish(self) -> None:
dirname = os.path.dirname(os.path.abspath(self.logfile))
# exist_ok avoids filesystem race conditions between checking path existence and requesting creation
os.makedirs(dirname, exist_ok=True)
with open(self.logfile, "w", encoding="utf-8") as logfile:
duration = self.suite_start.elapsed()
numtests = (
self.stats["passed"]
+ self.stats["failure"]
+ self.stats["skipped"]
+ self.stats["error"]
- self.cnt_double_fail_tests
)
logfile.write('<?xml version="1.0" encoding="utf-8"?>')
suite_node = ET.Element(
"testsuite",
name=self.suite_name,
errors=str(self.stats["error"]),
failures=str(self.stats["failure"]),
skipped=str(self.stats["skipped"]),
tests=str(numtests),
time=f"{duration.seconds:.3f}",
timestamp=self.suite_start.as_utc().astimezone().isoformat(),
hostname=platform.node(),
)
global_properties = self._get_global_properties_node()
if global_properties is not None:
suite_node.append(global_properties)
for node_reporter in self.node_reporters_ordered:
suite_node.append(node_reporter.to_xml())
testsuites = ET.Element("testsuites")
testsuites.set("name", "pytest tests")
testsuites.append(suite_node)
logfile.write(ET.tostring(testsuites, encoding="unicode"))
def pytest_terminal_summary(self, terminalreporter: TerminalReporter) -> None:
terminalreporter.write_sep("-", f"generated xml file: {self.logfile}")
def add_global_property(self, name: str, value: object) -> None:
__tracebackhide__ = True
_check_record_param_type("name", name)
self.global_properties.append((name, bin_xml_escape(value)))
def _get_global_properties_node(self) -> ET.Element | None:
"""Return a Junit node containing custom properties, if any."""
if self.global_properties:
properties = ET.Element("properties")
for name, value in self.global_properties:
properties.append(ET.Element("property", name=name, value=value))
return properties
return None

View File

@@ -0,0 +1,468 @@
# mypy: allow-untyped-defs
"""Add backward compatibility support for the legacy py path type."""
from __future__ import annotations
import dataclasses
from pathlib import Path
import shlex
import subprocess
from typing import Final
from typing import final
from typing import TYPE_CHECKING
from iniconfig import SectionWrapper
from _pytest.cacheprovider import Cache
from _pytest.compat import LEGACY_PATH
from _pytest.compat import legacy_path
from _pytest.config import Config
from _pytest.config import hookimpl
from _pytest.config import PytestPluginManager
from _pytest.deprecated import check_ispytest
from _pytest.fixtures import fixture
from _pytest.fixtures import FixtureRequest
from _pytest.main import Session
from _pytest.monkeypatch import MonkeyPatch
from _pytest.nodes import Collector
from _pytest.nodes import Item
from _pytest.nodes import Node
from _pytest.pytester import HookRecorder
from _pytest.pytester import Pytester
from _pytest.pytester import RunResult
from _pytest.terminal import TerminalReporter
from _pytest.tmpdir import TempPathFactory
if TYPE_CHECKING:
import pexpect
@final
class Testdir:
"""
Similar to :class:`Pytester`, but this class works with legacy legacy_path objects instead.
All methods just forward to an internal :class:`Pytester` instance, converting results
to `legacy_path` objects as necessary.
"""
__test__ = False
CLOSE_STDIN: Final = Pytester.CLOSE_STDIN
TimeoutExpired: Final = Pytester.TimeoutExpired
def __init__(self, pytester: Pytester, *, _ispytest: bool = False) -> None:
check_ispytest(_ispytest)
self._pytester = pytester
@property
def tmpdir(self) -> LEGACY_PATH:
"""Temporary directory where tests are executed."""
return legacy_path(self._pytester.path)
@property
def test_tmproot(self) -> LEGACY_PATH:
return legacy_path(self._pytester._test_tmproot)
@property
def request(self):
return self._pytester._request
@property
def plugins(self):
return self._pytester.plugins
@plugins.setter
def plugins(self, plugins):
self._pytester.plugins = plugins
@property
def monkeypatch(self) -> MonkeyPatch:
return self._pytester._monkeypatch
def make_hook_recorder(self, pluginmanager) -> HookRecorder:
"""See :meth:`Pytester.make_hook_recorder`."""
return self._pytester.make_hook_recorder(pluginmanager)
def chdir(self) -> None:
"""See :meth:`Pytester.chdir`."""
return self._pytester.chdir()
def finalize(self) -> None:
return self._pytester._finalize()
def makefile(self, ext, *args, **kwargs) -> LEGACY_PATH:
"""See :meth:`Pytester.makefile`."""
if ext and not ext.startswith("."):
# pytester.makefile is going to throw a ValueError in a way that
# testdir.makefile did not, because
# pathlib.Path is stricter suffixes than py.path
# This ext arguments is likely user error, but since testdir has
# allowed this, we will prepend "." as a workaround to avoid breaking
# testdir usage that worked before
ext = "." + ext
return legacy_path(self._pytester.makefile(ext, *args, **kwargs))
def makeconftest(self, source) -> LEGACY_PATH:
"""See :meth:`Pytester.makeconftest`."""
return legacy_path(self._pytester.makeconftest(source))
def makeini(self, source) -> LEGACY_PATH:
"""See :meth:`Pytester.makeini`."""
return legacy_path(self._pytester.makeini(source))
def getinicfg(self, source: str) -> SectionWrapper:
"""See :meth:`Pytester.getinicfg`."""
return self._pytester.getinicfg(source)
def makepyprojecttoml(self, source) -> LEGACY_PATH:
"""See :meth:`Pytester.makepyprojecttoml`."""
return legacy_path(self._pytester.makepyprojecttoml(source))
def makepyfile(self, *args, **kwargs) -> LEGACY_PATH:
"""See :meth:`Pytester.makepyfile`."""
return legacy_path(self._pytester.makepyfile(*args, **kwargs))
def maketxtfile(self, *args, **kwargs) -> LEGACY_PATH:
"""See :meth:`Pytester.maketxtfile`."""
return legacy_path(self._pytester.maketxtfile(*args, **kwargs))
def syspathinsert(self, path=None) -> None:
"""See :meth:`Pytester.syspathinsert`."""
return self._pytester.syspathinsert(path)
def mkdir(self, name) -> LEGACY_PATH:
"""See :meth:`Pytester.mkdir`."""
return legacy_path(self._pytester.mkdir(name))
def mkpydir(self, name) -> LEGACY_PATH:
"""See :meth:`Pytester.mkpydir`."""
return legacy_path(self._pytester.mkpydir(name))
def copy_example(self, name=None) -> LEGACY_PATH:
"""See :meth:`Pytester.copy_example`."""
return legacy_path(self._pytester.copy_example(name))
def getnode(self, config: Config, arg) -> Item | Collector | None:
"""See :meth:`Pytester.getnode`."""
return self._pytester.getnode(config, arg)
def getpathnode(self, path):
"""See :meth:`Pytester.getpathnode`."""
return self._pytester.getpathnode(path)
def genitems(self, colitems: list[Item | Collector]) -> list[Item]:
"""See :meth:`Pytester.genitems`."""
return self._pytester.genitems(colitems)
def runitem(self, source):
"""See :meth:`Pytester.runitem`."""
return self._pytester.runitem(source)
def inline_runsource(self, source, *cmdlineargs):
"""See :meth:`Pytester.inline_runsource`."""
return self._pytester.inline_runsource(source, *cmdlineargs)
def inline_genitems(self, *args):
"""See :meth:`Pytester.inline_genitems`."""
return self._pytester.inline_genitems(*args)
def inline_run(self, *args, plugins=(), no_reraise_ctrlc: bool = False):
"""See :meth:`Pytester.inline_run`."""
return self._pytester.inline_run(
*args, plugins=plugins, no_reraise_ctrlc=no_reraise_ctrlc
)
def runpytest_inprocess(self, *args, **kwargs) -> RunResult:
"""See :meth:`Pytester.runpytest_inprocess`."""
return self._pytester.runpytest_inprocess(*args, **kwargs)
def runpytest(self, *args, **kwargs) -> RunResult:
"""See :meth:`Pytester.runpytest`."""
return self._pytester.runpytest(*args, **kwargs)
def parseconfig(self, *args) -> Config:
"""See :meth:`Pytester.parseconfig`."""
return self._pytester.parseconfig(*args)
def parseconfigure(self, *args) -> Config:
"""See :meth:`Pytester.parseconfigure`."""
return self._pytester.parseconfigure(*args)
def getitem(self, source, funcname="test_func"):
"""See :meth:`Pytester.getitem`."""
return self._pytester.getitem(source, funcname)
def getitems(self, source):
"""See :meth:`Pytester.getitems`."""
return self._pytester.getitems(source)
def getmodulecol(self, source, configargs=(), withinit=False):
"""See :meth:`Pytester.getmodulecol`."""
return self._pytester.getmodulecol(
source, configargs=configargs, withinit=withinit
)
def collect_by_name(self, modcol: Collector, name: str) -> Item | Collector | None:
"""See :meth:`Pytester.collect_by_name`."""
return self._pytester.collect_by_name(modcol, name)
def popen(
self,
cmdargs,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
stdin=CLOSE_STDIN,
**kw,
):
"""See :meth:`Pytester.popen`."""
return self._pytester.popen(cmdargs, stdout, stderr, stdin, **kw)
def run(self, *cmdargs, timeout=None, stdin=CLOSE_STDIN) -> RunResult:
"""See :meth:`Pytester.run`."""
return self._pytester.run(*cmdargs, timeout=timeout, stdin=stdin)
def runpython(self, script) -> RunResult:
"""See :meth:`Pytester.runpython`."""
return self._pytester.runpython(script)
def runpython_c(self, command):
"""See :meth:`Pytester.runpython_c`."""
return self._pytester.runpython_c(command)
def runpytest_subprocess(self, *args, timeout=None) -> RunResult:
"""See :meth:`Pytester.runpytest_subprocess`."""
return self._pytester.runpytest_subprocess(*args, timeout=timeout)
def spawn_pytest(self, string: str, expect_timeout: float = 10.0) -> pexpect.spawn:
"""See :meth:`Pytester.spawn_pytest`."""
return self._pytester.spawn_pytest(string, expect_timeout=expect_timeout)
def spawn(self, cmd: str, expect_timeout: float = 10.0) -> pexpect.spawn:
"""See :meth:`Pytester.spawn`."""
return self._pytester.spawn(cmd, expect_timeout=expect_timeout)
def __repr__(self) -> str:
return f"<Testdir {self.tmpdir!r}>"
def __str__(self) -> str:
return str(self.tmpdir)
class LegacyTestdirPlugin:
@staticmethod
@fixture
def testdir(pytester: Pytester) -> Testdir:
"""
Identical to :fixture:`pytester`, and provides an instance whose methods return
legacy ``LEGACY_PATH`` objects instead when applicable.
New code should avoid using :fixture:`testdir` in favor of :fixture:`pytester`.
"""
return Testdir(pytester, _ispytest=True)
@final
@dataclasses.dataclass
class TempdirFactory:
"""Backward compatibility wrapper that implements ``py.path.local``
for :class:`TempPathFactory`.
.. note::
These days, it is preferred to use ``tmp_path_factory``.
:ref:`About the tmpdir and tmpdir_factory fixtures<tmpdir and tmpdir_factory>`.
"""
_tmppath_factory: TempPathFactory
def __init__(
self, tmppath_factory: TempPathFactory, *, _ispytest: bool = False
) -> None:
check_ispytest(_ispytest)
self._tmppath_factory = tmppath_factory
def mktemp(self, basename: str, numbered: bool = True) -> LEGACY_PATH:
"""Same as :meth:`TempPathFactory.mktemp`, but returns a ``py.path.local`` object."""
return legacy_path(self._tmppath_factory.mktemp(basename, numbered).resolve())
def getbasetemp(self) -> LEGACY_PATH:
"""Same as :meth:`TempPathFactory.getbasetemp`, but returns a ``py.path.local`` object."""
return legacy_path(self._tmppath_factory.getbasetemp().resolve())
class LegacyTmpdirPlugin:
@staticmethod
@fixture(scope="session")
def tmpdir_factory(request: FixtureRequest) -> TempdirFactory:
"""Return a :class:`pytest.TempdirFactory` instance for the test session."""
# Set dynamically by pytest_configure().
return request.config._tmpdirhandler # type: ignore
@staticmethod
@fixture
def tmpdir(tmp_path: Path) -> LEGACY_PATH:
"""Return a temporary directory (as `legacy_path`_ object)
which is unique to each test function invocation.
The temporary directory is created as a subdirectory
of the base temporary directory, with configurable retention,
as discussed in :ref:`temporary directory location and retention`.
.. note::
These days, it is preferred to use ``tmp_path``.
:ref:`About the tmpdir and tmpdir_factory fixtures<tmpdir and tmpdir_factory>`.
.. _legacy_path: https://py.readthedocs.io/en/latest/path.html
"""
return legacy_path(tmp_path)
def Cache_makedir(self: Cache, name: str) -> LEGACY_PATH:
"""Return a directory path object with the given name.
Same as :func:`mkdir`, but returns a legacy py path instance.
"""
return legacy_path(self.mkdir(name))
def FixtureRequest_fspath(self: FixtureRequest) -> LEGACY_PATH:
"""(deprecated) The file system path of the test module which collected this test."""
return legacy_path(self.path)
def TerminalReporter_startdir(self: TerminalReporter) -> LEGACY_PATH:
"""The directory from which pytest was invoked.
Prefer to use ``startpath`` which is a :class:`pathlib.Path`.
:type: LEGACY_PATH
"""
return legacy_path(self.startpath)
def Config_invocation_dir(self: Config) -> LEGACY_PATH:
"""The directory from which pytest was invoked.
Prefer to use :attr:`invocation_params.dir <InvocationParams.dir>`,
which is a :class:`pathlib.Path`.
:type: LEGACY_PATH
"""
return legacy_path(str(self.invocation_params.dir))
def Config_rootdir(self: Config) -> LEGACY_PATH:
"""The path to the :ref:`rootdir <rootdir>`.
Prefer to use :attr:`rootpath`, which is a :class:`pathlib.Path`.
:type: LEGACY_PATH
"""
return legacy_path(str(self.rootpath))
def Config_inifile(self: Config) -> LEGACY_PATH | None:
"""The path to the :ref:`configfile <configfiles>`.
Prefer to use :attr:`inipath`, which is a :class:`pathlib.Path`.
:type: Optional[LEGACY_PATH]
"""
return legacy_path(str(self.inipath)) if self.inipath else None
def Session_startdir(self: Session) -> LEGACY_PATH:
"""The path from which pytest was invoked.
Prefer to use ``startpath`` which is a :class:`pathlib.Path`.
:type: LEGACY_PATH
"""
return legacy_path(self.startpath)
def Config__getini_unknown_type(self, name: str, type: str, value: str | list[str]):
if type == "pathlist":
# TODO: This assert is probably not valid in all cases.
assert self.inipath is not None
dp = self.inipath.parent
input_values = shlex.split(value) if isinstance(value, str) else value
return [legacy_path(str(dp / x)) for x in input_values]
else:
raise ValueError(f"unknown configuration type: {type}", value)
def Node_fspath(self: Node) -> LEGACY_PATH:
"""(deprecated) returns a legacy_path copy of self.path"""
return legacy_path(self.path)
def Node_fspath_set(self: Node, value: LEGACY_PATH) -> None:
self.path = Path(value)
@hookimpl(tryfirst=True)
def pytest_load_initial_conftests(early_config: Config) -> None:
"""Monkeypatch legacy path attributes in several classes, as early as possible."""
mp = MonkeyPatch()
early_config.add_cleanup(mp.undo)
# Add Cache.makedir().
mp.setattr(Cache, "makedir", Cache_makedir, raising=False)
# Add FixtureRequest.fspath property.
mp.setattr(FixtureRequest, "fspath", property(FixtureRequest_fspath), raising=False)
# Add TerminalReporter.startdir property.
mp.setattr(
TerminalReporter, "startdir", property(TerminalReporter_startdir), raising=False
)
# Add Config.{invocation_dir,rootdir,inifile} properties.
mp.setattr(Config, "invocation_dir", property(Config_invocation_dir), raising=False)
mp.setattr(Config, "rootdir", property(Config_rootdir), raising=False)
mp.setattr(Config, "inifile", property(Config_inifile), raising=False)
# Add Session.startdir property.
mp.setattr(Session, "startdir", property(Session_startdir), raising=False)
# Add pathlist configuration type.
mp.setattr(Config, "_getini_unknown_type", Config__getini_unknown_type)
# Add Node.fspath property.
mp.setattr(Node, "fspath", property(Node_fspath, Node_fspath_set), raising=False)
@hookimpl
def pytest_configure(config: Config) -> None:
"""Installs the LegacyTmpdirPlugin if the ``tmpdir`` plugin is also installed."""
if config.pluginmanager.has_plugin("tmpdir"):
mp = MonkeyPatch()
config.add_cleanup(mp.undo)
# Create TmpdirFactory and attach it to the config object.
#
# This is to comply with existing plugins which expect the handler to be
# available at pytest_configure time, but ideally should be moved entirely
# to the tmpdir_factory session fixture.
try:
tmp_path_factory = config._tmp_path_factory # type: ignore[attr-defined]
except AttributeError:
# tmpdir plugin is blocked.
pass
else:
_tmpdirhandler = TempdirFactory(tmp_path_factory, _ispytest=True)
mp.setattr(config, "_tmpdirhandler", _tmpdirhandler, raising=False)
config.pluginmanager.register(LegacyTmpdirPlugin, "legacypath-tmpdir")
@hookimpl
def pytest_plugin_registered(plugin: object, manager: PytestPluginManager) -> None:
# pytester is not loaded by default and is commonly loaded from a conftest,
# so checking for it in `pytest_configure` is not enough.
is_pytester = plugin is manager.get_plugin("pytester")
if is_pytester and not manager.is_registered(LegacyTestdirPlugin):
manager.register(LegacyTestdirPlugin, "legacypath-pytester")

View File

@@ -0,0 +1,960 @@
# mypy: allow-untyped-defs
"""Access and control log capturing."""
from __future__ import annotations
from collections.abc import Generator
from collections.abc import Mapping
from collections.abc import Set as AbstractSet
from contextlib import contextmanager
from contextlib import nullcontext
from datetime import datetime
from datetime import timedelta
from datetime import timezone
import io
from io import StringIO
import logging
from logging import LogRecord
import os
from pathlib import Path
import re
from types import TracebackType
from typing import final
from typing import Generic
from typing import Literal
from typing import TYPE_CHECKING
from typing import TypeVar
from _pytest import nodes
from _pytest._io import TerminalWriter
from _pytest.capture import CaptureManager
from _pytest.config import _strtobool
from _pytest.config import Config
from _pytest.config import create_terminal_writer
from _pytest.config import hookimpl
from _pytest.config import UsageError
from _pytest.config.argparsing import Parser
from _pytest.deprecated import check_ispytest
from _pytest.fixtures import fixture
from _pytest.fixtures import FixtureRequest
from _pytest.main import Session
from _pytest.stash import StashKey
from _pytest.terminal import TerminalReporter
if TYPE_CHECKING:
logging_StreamHandler = logging.StreamHandler[StringIO]
else:
logging_StreamHandler = logging.StreamHandler
DEFAULT_LOG_FORMAT = "%(levelname)-8s %(name)s:%(filename)s:%(lineno)d %(message)s"
DEFAULT_LOG_DATE_FORMAT = "%H:%M:%S"
_ANSI_ESCAPE_SEQ = re.compile(r"\x1b\[[\d;]+m")
caplog_handler_key = StashKey["LogCaptureHandler"]()
caplog_records_key = StashKey[dict[str, list[logging.LogRecord]]]()
def _remove_ansi_escape_sequences(text: str) -> str:
return _ANSI_ESCAPE_SEQ.sub("", text)
class DatetimeFormatter(logging.Formatter):
"""A logging formatter which formats record with
:func:`datetime.datetime.strftime` formatter instead of
:func:`time.strftime` in case of microseconds in format string.
"""
def formatTime(self, record: LogRecord, datefmt: str | None = None) -> str:
if datefmt and "%f" in datefmt:
ct = self.converter(record.created)
tz = timezone(timedelta(seconds=ct.tm_gmtoff), ct.tm_zone)
# Construct `datetime.datetime` object from `struct_time`
# and msecs information from `record`
# Using int() instead of round() to avoid it exceeding 1_000_000 and causing a ValueError (#11861).
dt = datetime(*ct[0:6], microsecond=int(record.msecs * 1000), tzinfo=tz)
return dt.strftime(datefmt)
# Use `logging.Formatter` for non-microsecond formats
return super().formatTime(record, datefmt)
class ColoredLevelFormatter(DatetimeFormatter):
"""A logging formatter which colorizes the %(levelname)..s part of the
log format passed to __init__."""
LOGLEVEL_COLOROPTS: Mapping[int, AbstractSet[str]] = {
logging.CRITICAL: {"red"},
logging.ERROR: {"red", "bold"},
logging.WARNING: {"yellow"},
logging.WARN: {"yellow"},
logging.INFO: {"green"},
logging.DEBUG: {"purple"},
logging.NOTSET: set(),
}
LEVELNAME_FMT_REGEX = re.compile(r"%\(levelname\)([+-.]?\d*(?:\.\d+)?s)")
def __init__(self, terminalwriter: TerminalWriter, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._terminalwriter = terminalwriter
self._original_fmt = self._style._fmt
self._level_to_fmt_mapping: dict[int, str] = {}
for level, color_opts in self.LOGLEVEL_COLOROPTS.items():
self.add_color_level(level, *color_opts)
def add_color_level(self, level: int, *color_opts: str) -> None:
"""Add or update color opts for a log level.
:param level:
Log level to apply a style to, e.g. ``logging.INFO``.
:param color_opts:
ANSI escape sequence color options. Capitalized colors indicates
background color, i.e. ``'green', 'Yellow', 'bold'`` will give bold
green text on yellow background.
.. warning::
This is an experimental API.
"""
assert self._fmt is not None
levelname_fmt_match = self.LEVELNAME_FMT_REGEX.search(self._fmt)
if not levelname_fmt_match:
return
levelname_fmt = levelname_fmt_match.group()
formatted_levelname = levelname_fmt % {"levelname": logging.getLevelName(level)}
# add ANSI escape sequences around the formatted levelname
color_kwargs = {name: True for name in color_opts}
colorized_formatted_levelname = self._terminalwriter.markup(
formatted_levelname, **color_kwargs
)
self._level_to_fmt_mapping[level] = self.LEVELNAME_FMT_REGEX.sub(
colorized_formatted_levelname, self._fmt
)
def format(self, record: logging.LogRecord) -> str:
fmt = self._level_to_fmt_mapping.get(record.levelno, self._original_fmt)
self._style._fmt = fmt
return super().format(record)
class PercentStyleMultiline(logging.PercentStyle):
"""A logging style with special support for multiline messages.
If the message of a record consists of multiple lines, this style
formats the message as if each line were logged separately.
"""
def __init__(self, fmt: str, auto_indent: int | str | bool | None) -> None:
super().__init__(fmt)
self._auto_indent = self._get_auto_indent(auto_indent)
@staticmethod
def _get_auto_indent(auto_indent_option: int | str | bool | None) -> int:
"""Determine the current auto indentation setting.
Specify auto indent behavior (on/off/fixed) by passing in
extra={"auto_indent": [value]} to the call to logging.log() or
using a --log-auto-indent [value] command line or the
log_auto_indent [value] config option.
Default behavior is auto-indent off.
Using the string "True" or "on" or the boolean True as the value
turns auto indent on, using the string "False" or "off" or the
boolean False or the int 0 turns it off, and specifying a
positive integer fixes the indentation position to the value
specified.
Any other values for the option are invalid, and will silently be
converted to the default.
:param None|bool|int|str auto_indent_option:
User specified option for indentation from command line, config
or extra kwarg. Accepts int, bool or str. str option accepts the
same range of values as boolean config options, as well as
positive integers represented in str form.
:returns:
Indentation value, which can be
-1 (automatically determine indentation) or
0 (auto-indent turned off) or
>0 (explicitly set indentation position).
"""
if auto_indent_option is None:
return 0
elif isinstance(auto_indent_option, bool):
if auto_indent_option:
return -1
else:
return 0
elif isinstance(auto_indent_option, int):
return int(auto_indent_option)
elif isinstance(auto_indent_option, str):
try:
return int(auto_indent_option)
except ValueError:
pass
try:
if _strtobool(auto_indent_option):
return -1
except ValueError:
return 0
return 0
def format(self, record: logging.LogRecord) -> str:
if "\n" in record.message:
if hasattr(record, "auto_indent"):
# Passed in from the "extra={}" kwarg on the call to logging.log().
auto_indent = self._get_auto_indent(record.auto_indent)
else:
auto_indent = self._auto_indent
if auto_indent:
lines = record.message.splitlines()
formatted = self._fmt % {**record.__dict__, "message": lines[0]}
if auto_indent < 0:
indentation = _remove_ansi_escape_sequences(formatted).find(
lines[0]
)
else:
# Optimizes logging by allowing a fixed indentation.
indentation = auto_indent
lines[0] = formatted
return ("\n" + " " * indentation).join(lines)
return self._fmt % record.__dict__
def get_option_ini(config: Config, *names: str):
for name in names:
ret = config.getoption(name) # 'default' arg won't work as expected
if ret is None:
ret = config.getini(name)
if ret:
return ret
def pytest_addoption(parser: Parser) -> None:
"""Add options to control log capturing."""
group = parser.getgroup("logging")
def add_option_ini(option, dest, default=None, type=None, **kwargs):
parser.addini(
dest, default=default, type=type, help="Default value for " + option
)
group.addoption(option, dest=dest, **kwargs)
add_option_ini(
"--log-level",
dest="log_level",
default=None,
metavar="LEVEL",
help=(
"Level of messages to catch/display."
" Not set by default, so it depends on the root/parent log handler's"
' effective level, where it is "WARNING" by default.'
),
)
add_option_ini(
"--log-format",
dest="log_format",
default=DEFAULT_LOG_FORMAT,
help="Log format used by the logging module",
)
add_option_ini(
"--log-date-format",
dest="log_date_format",
default=DEFAULT_LOG_DATE_FORMAT,
help="Log date format used by the logging module",
)
parser.addini(
"log_cli",
default=False,
type="bool",
help='Enable log display during test run (also known as "live logging")',
)
add_option_ini(
"--log-cli-level", dest="log_cli_level", default=None, help="CLI logging level"
)
add_option_ini(
"--log-cli-format",
dest="log_cli_format",
default=None,
help="Log format used by the logging module",
)
add_option_ini(
"--log-cli-date-format",
dest="log_cli_date_format",
default=None,
help="Log date format used by the logging module",
)
add_option_ini(
"--log-file",
dest="log_file",
default=None,
help="Path to a file when logging will be written to",
)
add_option_ini(
"--log-file-mode",
dest="log_file_mode",
default="w",
choices=["w", "a"],
help="Log file open mode",
)
add_option_ini(
"--log-file-level",
dest="log_file_level",
default=None,
help="Log file logging level",
)
add_option_ini(
"--log-file-format",
dest="log_file_format",
default=None,
help="Log format used by the logging module",
)
add_option_ini(
"--log-file-date-format",
dest="log_file_date_format",
default=None,
help="Log date format used by the logging module",
)
add_option_ini(
"--log-auto-indent",
dest="log_auto_indent",
default=None,
help="Auto-indent multiline messages passed to the logging module. Accepts true|on, false|off or an integer.",
)
group.addoption(
"--log-disable",
action="append",
default=[],
dest="logger_disable",
help="Disable a logger by name. Can be passed multiple times.",
)
_HandlerType = TypeVar("_HandlerType", bound=logging.Handler)
# Not using @contextmanager for performance reasons.
class catching_logs(Generic[_HandlerType]):
"""Context manager that prepares the whole logging machinery properly."""
__slots__ = ("handler", "level", "orig_level")
def __init__(self, handler: _HandlerType, level: int | None = None) -> None:
self.handler = handler
self.level = level
def __enter__(self) -> _HandlerType:
root_logger = logging.getLogger()
if self.level is not None:
self.handler.setLevel(self.level)
root_logger.addHandler(self.handler)
if self.level is not None:
self.orig_level = root_logger.level
root_logger.setLevel(min(self.orig_level, self.level))
return self.handler
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
root_logger = logging.getLogger()
if self.level is not None:
root_logger.setLevel(self.orig_level)
root_logger.removeHandler(self.handler)
class LogCaptureHandler(logging_StreamHandler):
"""A logging handler that stores log records and the log text."""
def __init__(self) -> None:
"""Create a new log handler."""
super().__init__(StringIO())
self.records: list[logging.LogRecord] = []
def emit(self, record: logging.LogRecord) -> None:
"""Keep the log records in a list in addition to the log text."""
self.records.append(record)
super().emit(record)
def reset(self) -> None:
self.records = []
self.stream = StringIO()
def clear(self) -> None:
self.records.clear()
self.stream = StringIO()
def handleError(self, record: logging.LogRecord) -> None:
if logging.raiseExceptions:
# Fail the test if the log message is bad (emit failed).
# The default behavior of logging is to print "Logging error"
# to stderr with the call stack and some extra details.
# pytest wants to make such mistakes visible during testing.
raise # noqa: PLE0704
@final
class LogCaptureFixture:
"""Provides access and control of log capturing."""
def __init__(self, item: nodes.Node, *, _ispytest: bool = False) -> None:
check_ispytest(_ispytest)
self._item = item
self._initial_handler_level: int | None = None
# Dict of log name -> log level.
self._initial_logger_levels: dict[str | None, int] = {}
self._initial_disabled_logging_level: int | None = None
def _finalize(self) -> None:
"""Finalize the fixture.
This restores the log levels and the disabled logging levels changed by :meth:`set_level`.
"""
# Restore log levels.
if self._initial_handler_level is not None:
self.handler.setLevel(self._initial_handler_level)
for logger_name, level in self._initial_logger_levels.items():
logger = logging.getLogger(logger_name)
logger.setLevel(level)
# Disable logging at the original disabled logging level.
if self._initial_disabled_logging_level is not None:
logging.disable(self._initial_disabled_logging_level)
self._initial_disabled_logging_level = None
@property
def handler(self) -> LogCaptureHandler:
"""Get the logging handler used by the fixture."""
return self._item.stash[caplog_handler_key]
def get_records(
self, when: Literal["setup", "call", "teardown"]
) -> list[logging.LogRecord]:
"""Get the logging records for one of the possible test phases.
:param when:
Which test phase to obtain the records from.
Valid values are: "setup", "call" and "teardown".
:returns: The list of captured records at the given stage.
.. versionadded:: 3.4
"""
return self._item.stash[caplog_records_key].get(when, [])
@property
def text(self) -> str:
"""The formatted log text."""
return _remove_ansi_escape_sequences(self.handler.stream.getvalue())
@property
def records(self) -> list[logging.LogRecord]:
"""The list of log records."""
return self.handler.records
@property
def record_tuples(self) -> list[tuple[str, int, str]]:
"""A list of a stripped down version of log records intended
for use in assertion comparison.
The format of the tuple is:
(logger_name, log_level, message)
"""
return [(r.name, r.levelno, r.getMessage()) for r in self.records]
@property
def messages(self) -> list[str]:
"""A list of format-interpolated log messages.
Unlike 'records', which contains the format string and parameters for
interpolation, log messages in this list are all interpolated.
Unlike 'text', which contains the output from the handler, log
messages in this list are unadorned with levels, timestamps, etc,
making exact comparisons more reliable.
Note that traceback or stack info (from :func:`logging.exception` or
the `exc_info` or `stack_info` arguments to the logging functions) is
not included, as this is added by the formatter in the handler.
.. versionadded:: 3.7
"""
return [r.getMessage() for r in self.records]
def clear(self) -> None:
"""Reset the list of log records and the captured log text."""
self.handler.clear()
def _force_enable_logging(
self, level: int | str, logger_obj: logging.Logger
) -> int:
"""Enable the desired logging level if the global level was disabled via ``logging.disabled``.
Only enables logging levels greater than or equal to the requested ``level``.
Does nothing if the desired ``level`` wasn't disabled.
:param level:
The logger level caplog should capture.
All logging is enabled if a non-standard logging level string is supplied.
Valid level strings are in :data:`logging._nameToLevel`.
:param logger_obj: The logger object to check.
:return: The original disabled logging level.
"""
original_disable_level: int = logger_obj.manager.disable
if isinstance(level, str):
# Try to translate the level string to an int for `logging.disable()`
level = logging.getLevelName(level)
if not isinstance(level, int):
# The level provided was not valid, so just un-disable all logging.
logging.disable(logging.NOTSET)
elif not logger_obj.isEnabledFor(level):
# Each level is `10` away from other levels.
# https://docs.python.org/3/library/logging.html#logging-levels
disable_level = max(level - 10, logging.NOTSET)
logging.disable(disable_level)
return original_disable_level
def set_level(self, level: int | str, logger: str | None = None) -> None:
"""Set the threshold level of a logger for the duration of a test.
Logging messages which are less severe than this level will not be captured.
.. versionchanged:: 3.4
The levels of the loggers changed by this function will be
restored to their initial values at the end of the test.
Will enable the requested logging level if it was disabled via :func:`logging.disable`.
:param level: The level.
:param logger: The logger to update. If not given, the root logger.
"""
logger_obj = logging.getLogger(logger)
# Save the original log-level to restore it during teardown.
self._initial_logger_levels.setdefault(logger, logger_obj.level)
logger_obj.setLevel(level)
if self._initial_handler_level is None:
self._initial_handler_level = self.handler.level
self.handler.setLevel(level)
initial_disabled_logging_level = self._force_enable_logging(level, logger_obj)
if self._initial_disabled_logging_level is None:
self._initial_disabled_logging_level = initial_disabled_logging_level
@contextmanager
def at_level(self, level: int | str, logger: str | None = None) -> Generator[None]:
"""Context manager that sets the level for capturing of logs. After
the end of the 'with' statement the level is restored to its original
value.
Will enable the requested logging level if it was disabled via :func:`logging.disable`.
:param level: The level.
:param logger: The logger to update. If not given, the root logger.
"""
logger_obj = logging.getLogger(logger)
orig_level = logger_obj.level
logger_obj.setLevel(level)
handler_orig_level = self.handler.level
self.handler.setLevel(level)
original_disable_level = self._force_enable_logging(level, logger_obj)
try:
yield
finally:
logger_obj.setLevel(orig_level)
self.handler.setLevel(handler_orig_level)
logging.disable(original_disable_level)
@contextmanager
def filtering(self, filter_: logging.Filter) -> Generator[None]:
"""Context manager that temporarily adds the given filter to the caplog's
:meth:`handler` for the 'with' statement block, and removes that filter at the
end of the block.
:param filter_: A custom :class:`logging.Filter` object.
.. versionadded:: 7.5
"""
self.handler.addFilter(filter_)
try:
yield
finally:
self.handler.removeFilter(filter_)
@fixture
def caplog(request: FixtureRequest) -> Generator[LogCaptureFixture]:
"""Access and control log capturing.
Captured logs are available through the following properties/methods::
* caplog.messages -> list of format-interpolated log messages
* caplog.text -> string containing formatted log output
* caplog.records -> list of logging.LogRecord instances
* caplog.record_tuples -> list of (logger_name, level, message) tuples
* caplog.clear() -> clear captured records and formatted log output string
"""
result = LogCaptureFixture(request.node, _ispytest=True)
yield result
result._finalize()
def get_log_level_for_setting(config: Config, *setting_names: str) -> int | None:
for setting_name in setting_names:
log_level = config.getoption(setting_name)
if log_level is None:
log_level = config.getini(setting_name)
if log_level:
break
else:
return None
if isinstance(log_level, str):
log_level = log_level.upper()
try:
return int(getattr(logging, log_level, log_level))
except ValueError as e:
# Python logging does not recognise this as a logging level
raise UsageError(
f"'{log_level}' is not recognized as a logging level name for "
f"'{setting_name}'. Please consider passing the "
"logging level num instead."
) from e
# run after terminalreporter/capturemanager are configured
@hookimpl(trylast=True)
def pytest_configure(config: Config) -> None:
config.pluginmanager.register(LoggingPlugin(config), "logging-plugin")
class LoggingPlugin:
"""Attaches to the logging module and captures log messages for each test."""
def __init__(self, config: Config) -> None:
"""Create a new plugin to capture log messages.
The formatter can be safely shared across all handlers so
create a single one for the entire test session here.
"""
self._config = config
# Report logging.
self.formatter = self._create_formatter(
get_option_ini(config, "log_format"),
get_option_ini(config, "log_date_format"),
get_option_ini(config, "log_auto_indent"),
)
self.log_level = get_log_level_for_setting(config, "log_level")
self.caplog_handler = LogCaptureHandler()
self.caplog_handler.setFormatter(self.formatter)
self.report_handler = LogCaptureHandler()
self.report_handler.setFormatter(self.formatter)
# File logging.
self.log_file_level = get_log_level_for_setting(
config, "log_file_level", "log_level"
)
log_file = get_option_ini(config, "log_file") or os.devnull
if log_file != os.devnull:
directory = os.path.dirname(os.path.abspath(log_file))
if not os.path.isdir(directory):
os.makedirs(directory)
self.log_file_mode = get_option_ini(config, "log_file_mode") or "w"
self.log_file_handler = _FileHandler(
log_file, mode=self.log_file_mode, encoding="UTF-8"
)
log_file_format = get_option_ini(config, "log_file_format", "log_format")
log_file_date_format = get_option_ini(
config, "log_file_date_format", "log_date_format"
)
log_file_formatter = DatetimeFormatter(
log_file_format, datefmt=log_file_date_format
)
self.log_file_handler.setFormatter(log_file_formatter)
# CLI/live logging.
self.log_cli_level = get_log_level_for_setting(
config, "log_cli_level", "log_level"
)
if self._log_cli_enabled():
terminal_reporter = config.pluginmanager.get_plugin("terminalreporter")
# Guaranteed by `_log_cli_enabled()`.
assert terminal_reporter is not None
capture_manager = config.pluginmanager.get_plugin("capturemanager")
# if capturemanager plugin is disabled, live logging still works.
self.log_cli_handler: (
_LiveLoggingStreamHandler | _LiveLoggingNullHandler
) = _LiveLoggingStreamHandler(terminal_reporter, capture_manager)
else:
self.log_cli_handler = _LiveLoggingNullHandler()
log_cli_formatter = self._create_formatter(
get_option_ini(config, "log_cli_format", "log_format"),
get_option_ini(config, "log_cli_date_format", "log_date_format"),
get_option_ini(config, "log_auto_indent"),
)
self.log_cli_handler.setFormatter(log_cli_formatter)
self._disable_loggers(loggers_to_disable=config.option.logger_disable)
def _disable_loggers(self, loggers_to_disable: list[str]) -> None:
if not loggers_to_disable:
return
for name in loggers_to_disable:
logger = logging.getLogger(name)
logger.disabled = True
def _create_formatter(self, log_format, log_date_format, auto_indent):
# Color option doesn't exist if terminal plugin is disabled.
color = getattr(self._config.option, "color", "no")
if color != "no" and ColoredLevelFormatter.LEVELNAME_FMT_REGEX.search(
log_format
):
formatter: logging.Formatter = ColoredLevelFormatter(
create_terminal_writer(self._config), log_format, log_date_format
)
else:
formatter = DatetimeFormatter(log_format, log_date_format)
formatter._style = PercentStyleMultiline(
formatter._style._fmt, auto_indent=auto_indent
)
return formatter
def set_log_path(self, fname: str) -> None:
"""Set the filename parameter for Logging.FileHandler().
Creates parent directory if it does not exist.
.. warning::
This is an experimental API.
"""
fpath = Path(fname)
if not fpath.is_absolute():
fpath = self._config.rootpath / fpath
if not fpath.parent.exists():
fpath.parent.mkdir(exist_ok=True, parents=True)
# https://github.com/python/mypy/issues/11193
stream: io.TextIOWrapper = fpath.open(mode=self.log_file_mode, encoding="UTF-8") # type: ignore[assignment]
old_stream = self.log_file_handler.setStream(stream)
if old_stream:
old_stream.close()
def _log_cli_enabled(self) -> bool:
"""Return whether live logging is enabled."""
enabled = self._config.getoption(
"--log-cli-level"
) is not None or self._config.getini("log_cli")
if not enabled:
return False
terminal_reporter = self._config.pluginmanager.get_plugin("terminalreporter")
if terminal_reporter is None:
# terminal reporter is disabled e.g. by pytest-xdist.
return False
return True
@hookimpl(wrapper=True, tryfirst=True)
def pytest_sessionstart(self) -> Generator[None]:
self.log_cli_handler.set_when("sessionstart")
with catching_logs(self.log_cli_handler, level=self.log_cli_level):
with catching_logs(self.log_file_handler, level=self.log_file_level):
return (yield)
@hookimpl(wrapper=True, tryfirst=True)
def pytest_collection(self) -> Generator[None]:
self.log_cli_handler.set_when("collection")
with catching_logs(self.log_cli_handler, level=self.log_cli_level):
with catching_logs(self.log_file_handler, level=self.log_file_level):
return (yield)
@hookimpl(wrapper=True)
def pytest_runtestloop(self, session: Session) -> Generator[None, object, object]:
if session.config.option.collectonly:
return (yield)
if self._log_cli_enabled() and self._config.get_verbosity() < 1:
# The verbose flag is needed to avoid messy test progress output.
self._config.option.verbose = 1
with catching_logs(self.log_cli_handler, level=self.log_cli_level):
with catching_logs(self.log_file_handler, level=self.log_file_level):
return (yield) # Run all the tests.
@hookimpl
def pytest_runtest_logstart(self) -> None:
self.log_cli_handler.reset()
self.log_cli_handler.set_when("start")
@hookimpl
def pytest_runtest_logreport(self) -> None:
self.log_cli_handler.set_when("logreport")
@contextmanager
def _runtest_for(self, item: nodes.Item, when: str) -> Generator[None]:
"""Implement the internals of the pytest_runtest_xxx() hooks."""
with (
catching_logs(
self.caplog_handler,
level=self.log_level,
) as caplog_handler,
catching_logs(
self.report_handler,
level=self.log_level,
) as report_handler,
):
caplog_handler.reset()
report_handler.reset()
item.stash[caplog_records_key][when] = caplog_handler.records
item.stash[caplog_handler_key] = caplog_handler
try:
yield
finally:
log = report_handler.stream.getvalue().strip()
item.add_report_section(when, "log", log)
@hookimpl(wrapper=True)
def pytest_runtest_setup(self, item: nodes.Item) -> Generator[None]:
self.log_cli_handler.set_when("setup")
empty: dict[str, list[logging.LogRecord]] = {}
item.stash[caplog_records_key] = empty
with self._runtest_for(item, "setup"):
yield
@hookimpl(wrapper=True)
def pytest_runtest_call(self, item: nodes.Item) -> Generator[None]:
self.log_cli_handler.set_when("call")
with self._runtest_for(item, "call"):
yield
@hookimpl(wrapper=True)
def pytest_runtest_teardown(self, item: nodes.Item) -> Generator[None]:
self.log_cli_handler.set_when("teardown")
try:
with self._runtest_for(item, "teardown"):
yield
finally:
del item.stash[caplog_records_key]
del item.stash[caplog_handler_key]
@hookimpl
def pytest_runtest_logfinish(self) -> None:
self.log_cli_handler.set_when("finish")
@hookimpl(wrapper=True, tryfirst=True)
def pytest_sessionfinish(self) -> Generator[None]:
self.log_cli_handler.set_when("sessionfinish")
with catching_logs(self.log_cli_handler, level=self.log_cli_level):
with catching_logs(self.log_file_handler, level=self.log_file_level):
return (yield)
@hookimpl
def pytest_unconfigure(self) -> None:
# Close the FileHandler explicitly.
# (logging.shutdown might have lost the weakref?!)
self.log_file_handler.close()
class _FileHandler(logging.FileHandler):
"""A logging FileHandler with pytest tweaks."""
def handleError(self, record: logging.LogRecord) -> None:
# Handled by LogCaptureHandler.
pass
class _LiveLoggingStreamHandler(logging_StreamHandler):
"""A logging StreamHandler used by the live logging feature: it will
write a newline before the first log message in each test.
During live logging we must also explicitly disable stdout/stderr
capturing otherwise it will get captured and won't appear in the
terminal.
"""
# Officially stream needs to be a IO[str], but TerminalReporter
# isn't. So force it.
stream: TerminalReporter = None # type: ignore
def __init__(
self,
terminal_reporter: TerminalReporter,
capture_manager: CaptureManager | None,
) -> None:
super().__init__(stream=terminal_reporter) # type: ignore[arg-type]
self.capture_manager = capture_manager
self.reset()
self.set_when(None)
self._test_outcome_written = False
def reset(self) -> None:
"""Reset the handler; should be called before the start of each test."""
self._first_record_emitted = False
def set_when(self, when: str | None) -> None:
"""Prepare for the given test phase (setup/call/teardown)."""
self._when = when
self._section_name_shown = False
if when == "start":
self._test_outcome_written = False
def emit(self, record: logging.LogRecord) -> None:
ctx_manager = (
self.capture_manager.global_and_fixture_disabled()
if self.capture_manager
else nullcontext()
)
with ctx_manager:
if not self._first_record_emitted:
self.stream.write("\n")
self._first_record_emitted = True
elif self._when in ("teardown", "finish"):
if not self._test_outcome_written:
self._test_outcome_written = True
self.stream.write("\n")
if not self._section_name_shown and self._when:
self.stream.section("live log " + self._when, sep="-", bold=True)
self._section_name_shown = True
super().emit(record)
def handleError(self, record: logging.LogRecord) -> None:
# Handled by LogCaptureHandler.
pass
class _LiveLoggingNullHandler(logging.NullHandler):
"""A logging handler used when live logging is disabled."""
def reset(self) -> None:
pass
def set_when(self, when: str) -> None:
pass
def handleError(self, record: logging.LogRecord) -> None:
# Handled by LogCaptureHandler.
pass

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,301 @@
"""Generic mechanism for marking and selecting python functions."""
from __future__ import annotations
import collections
from collections.abc import Collection
from collections.abc import Iterable
from collections.abc import Set as AbstractSet
import dataclasses
from typing import Optional
from typing import TYPE_CHECKING
from .expression import Expression
from .expression import ParseError
from .structures import _HiddenParam
from .structures import EMPTY_PARAMETERSET_OPTION
from .structures import get_empty_parameterset_mark
from .structures import HIDDEN_PARAM
from .structures import Mark
from .structures import MARK_GEN
from .structures import MarkDecorator
from .structures import MarkGenerator
from .structures import ParameterSet
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import hookimpl
from _pytest.config import UsageError
from _pytest.config.argparsing import NOT_SET
from _pytest.config.argparsing import Parser
from _pytest.stash import StashKey
if TYPE_CHECKING:
from _pytest.nodes import Item
__all__ = [
"HIDDEN_PARAM",
"MARK_GEN",
"Mark",
"MarkDecorator",
"MarkGenerator",
"ParameterSet",
"get_empty_parameterset_mark",
]
old_mark_config_key = StashKey[Optional[Config]]()
def param(
*values: object,
marks: MarkDecorator | Collection[MarkDecorator | Mark] = (),
id: str | _HiddenParam | None = None,
) -> ParameterSet:
"""Specify a parameter in `pytest.mark.parametrize`_ calls or
:ref:`parametrized fixtures <fixture-parametrize-marks>`.
.. code-block:: python
@pytest.mark.parametrize(
"test_input,expected",
[
("3+5", 8),
pytest.param("6*9", 42, marks=pytest.mark.xfail),
],
)
def test_eval(test_input, expected):
assert eval(test_input) == expected
:param values: Variable args of the values of the parameter set, in order.
:param marks:
A single mark or a list of marks to be applied to this parameter set.
:ref:`pytest.mark.usefixtures <pytest.mark.usefixtures ref>` cannot be added via this parameter.
:type id: str | Literal[pytest.HIDDEN_PARAM] | None
:param id:
The id to attribute to this parameter set.
.. versionadded:: 8.4
:ref:`hidden-param` means to hide the parameter set
from the test name. Can only be used at most 1 time, as
test names need to be unique.
"""
return ParameterSet.param(*values, marks=marks, id=id)
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group._addoption( # private to use reserved lower-case short option
"-k",
action="store",
dest="keyword",
default="",
metavar="EXPRESSION",
help="Only run tests which match the given substring expression. "
"An expression is a Python evaluable expression "
"where all names are substring-matched against test names "
"and their parent classes. Example: -k 'test_method or test_"
"other' matches all test functions and classes whose name "
"contains 'test_method' or 'test_other', while -k 'not test_method' "
"matches those that don't contain 'test_method' in their names. "
"-k 'not test_method and not test_other' will eliminate the matches. "
"Additionally keywords are matched to classes and functions "
"containing extra names in their 'extra_keyword_matches' set, "
"as well as functions which have names assigned directly to them. "
"The matching is case-insensitive.",
)
group._addoption( # private to use reserved lower-case short option
"-m",
action="store",
dest="markexpr",
default="",
metavar="MARKEXPR",
help="Only run tests matching given mark expression. "
"For example: -m 'mark1 and not mark2'.",
)
group.addoption(
"--markers",
action="store_true",
help="show markers (builtin, plugin and per-project ones).",
)
parser.addini("markers", "Register new markers for test functions", "linelist")
parser.addini(EMPTY_PARAMETERSET_OPTION, "Default marker for empty parametersets")
@hookimpl(tryfirst=True)
def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
import _pytest.config
if config.option.markers:
config._do_configure()
tw = _pytest.config.create_terminal_writer(config)
for line in config.getini("markers"):
parts = line.split(":", 1)
name = parts[0]
rest = parts[1] if len(parts) == 2 else ""
tw.write(f"@pytest.mark.{name}:", bold=True)
tw.line(rest)
tw.line()
config._ensure_unconfigure()
return 0
return None
@dataclasses.dataclass
class KeywordMatcher:
"""A matcher for keywords.
Given a list of names, matches any substring of one of these names. The
string inclusion check is case-insensitive.
Will match on the name of colitem, including the names of its parents.
Only matches names of items which are either a :class:`Class` or a
:class:`Function`.
Additionally, matches on names in the 'extra_keyword_matches' set of
any item, as well as names directly assigned to test functions.
"""
__slots__ = ("_names",)
_names: AbstractSet[str]
@classmethod
def from_item(cls, item: Item) -> KeywordMatcher:
mapped_names = set()
# Add the names of the current item and any parent items,
# except the Session and root Directory's which are not
# interesting for matching.
import pytest
for node in item.listchain():
if isinstance(node, pytest.Session):
continue
if isinstance(node, pytest.Directory) and isinstance(
node.parent, pytest.Session
):
continue
mapped_names.add(node.name)
# Add the names added as extra keywords to current or parent items.
mapped_names.update(item.listextrakeywords())
# Add the names attached to the current function through direct assignment.
function_obj = getattr(item, "function", None)
if function_obj:
mapped_names.update(function_obj.__dict__)
# Add the markers to the keywords as we no longer handle them correctly.
mapped_names.update(mark.name for mark in item.iter_markers())
return cls(mapped_names)
def __call__(self, subname: str, /, **kwargs: str | int | bool | None) -> bool:
if kwargs:
raise UsageError("Keyword expressions do not support call parameters.")
subname = subname.lower()
return any(subname in name.lower() for name in self._names)
def deselect_by_keyword(items: list[Item], config: Config) -> None:
keywordexpr = config.option.keyword.lstrip()
if not keywordexpr:
return
expr = _parse_expression(keywordexpr, "Wrong expression passed to '-k'")
remaining = []
deselected = []
for colitem in items:
if not expr.evaluate(KeywordMatcher.from_item(colitem)):
deselected.append(colitem)
else:
remaining.append(colitem)
if deselected:
config.hook.pytest_deselected(items=deselected)
items[:] = remaining
@dataclasses.dataclass
class MarkMatcher:
"""A matcher for markers which are present.
Tries to match on any marker names, attached to the given colitem.
"""
__slots__ = ("own_mark_name_mapping",)
own_mark_name_mapping: dict[str, list[Mark]]
@classmethod
def from_markers(cls, markers: Iterable[Mark]) -> MarkMatcher:
mark_name_mapping = collections.defaultdict(list)
for mark in markers:
mark_name_mapping[mark.name].append(mark)
return cls(mark_name_mapping)
def __call__(self, name: str, /, **kwargs: str | int | bool | None) -> bool:
if not (matches := self.own_mark_name_mapping.get(name, [])):
return False
for mark in matches: # pylint: disable=consider-using-any-or-all
if all(mark.kwargs.get(k, NOT_SET) == v for k, v in kwargs.items()):
return True
return False
def deselect_by_mark(items: list[Item], config: Config) -> None:
matchexpr = config.option.markexpr
if not matchexpr:
return
expr = _parse_expression(matchexpr, "Wrong expression passed to '-m'")
remaining: list[Item] = []
deselected: list[Item] = []
for item in items:
if expr.evaluate(MarkMatcher.from_markers(item.iter_markers())):
remaining.append(item)
else:
deselected.append(item)
if deselected:
config.hook.pytest_deselected(items=deselected)
items[:] = remaining
def _parse_expression(expr: str, exc_message: str) -> Expression:
try:
return Expression.compile(expr)
except ParseError as e:
raise UsageError(f"{exc_message}: {expr}: {e}") from None
def pytest_collection_modifyitems(items: list[Item], config: Config) -> None:
deselect_by_keyword(items, config)
deselect_by_mark(items, config)
def pytest_configure(config: Config) -> None:
config.stash[old_mark_config_key] = MARK_GEN._config
MARK_GEN._config = config
empty_parameterset = config.getini(EMPTY_PARAMETERSET_OPTION)
if empty_parameterset not in ("skip", "xfail", "fail_at_collect", None, ""):
raise UsageError(
f"{EMPTY_PARAMETERSET_OPTION!s} must be one of skip, xfail or fail_at_collect"
f" but it is {empty_parameterset!r}"
)
def pytest_unconfigure(config: Config) -> None:
MARK_GEN._config = config.stash.get(old_mark_config_key, None)

View File

@@ -0,0 +1,331 @@
r"""Evaluate match expressions, as used by `-k` and `-m`.
The grammar is:
expression: expr? EOF
expr: and_expr ('or' and_expr)*
and_expr: not_expr ('and' not_expr)*
not_expr: 'not' not_expr | '(' expr ')' | ident kwargs?
ident: (\w|:|\+|-|\.|\[|\]|\\|/)+
kwargs: ('(' name '=' value ( ', ' name '=' value )* ')')
name: a valid ident, but not a reserved keyword
value: (unescaped) string literal | (-)?[0-9]+ | 'False' | 'True' | 'None'
The semantics are:
- Empty expression evaluates to False.
- ident evaluates to True or False according to a provided matcher function.
- or/and/not evaluate according to the usual boolean semantics.
- ident with parentheses and keyword arguments evaluates to True or False according to a provided matcher function.
"""
from __future__ import annotations
import ast
from collections.abc import Iterator
from collections.abc import Mapping
from collections.abc import Sequence
import dataclasses
import enum
import keyword
import re
import types
from typing import Literal
from typing import NoReturn
from typing import overload
from typing import Protocol
__all__ = [
"Expression",
"ParseError",
]
class TokenType(enum.Enum):
LPAREN = "left parenthesis"
RPAREN = "right parenthesis"
OR = "or"
AND = "and"
NOT = "not"
IDENT = "identifier"
EOF = "end of input"
EQUAL = "="
STRING = "string literal"
COMMA = ","
@dataclasses.dataclass(frozen=True)
class Token:
__slots__ = ("pos", "type", "value")
type: TokenType
value: str
pos: int
class ParseError(Exception):
"""The expression contains invalid syntax.
:param column: The column in the line where the error occurred (1-based).
:param message: A description of the error.
"""
def __init__(self, column: int, message: str) -> None:
self.column = column
self.message = message
def __str__(self) -> str:
return f"at column {self.column}: {self.message}"
class Scanner:
__slots__ = ("current", "tokens")
def __init__(self, input: str) -> None:
self.tokens = self.lex(input)
self.current = next(self.tokens)
def lex(self, input: str) -> Iterator[Token]:
pos = 0
while pos < len(input):
if input[pos] in (" ", "\t"):
pos += 1
elif input[pos] == "(":
yield Token(TokenType.LPAREN, "(", pos)
pos += 1
elif input[pos] == ")":
yield Token(TokenType.RPAREN, ")", pos)
pos += 1
elif input[pos] == "=":
yield Token(TokenType.EQUAL, "=", pos)
pos += 1
elif input[pos] == ",":
yield Token(TokenType.COMMA, ",", pos)
pos += 1
elif (quote_char := input[pos]) in ("'", '"'):
end_quote_pos = input.find(quote_char, pos + 1)
if end_quote_pos == -1:
raise ParseError(
pos + 1,
f'closing quote "{quote_char}" is missing',
)
value = input[pos : end_quote_pos + 1]
if (backslash_pos := input.find("\\")) != -1:
raise ParseError(
backslash_pos + 1,
r'escaping with "\" not supported in marker expression',
)
yield Token(TokenType.STRING, value, pos)
pos += len(value)
else:
match = re.match(r"(:?\w|:|\+|-|\.|\[|\]|\\|/)+", input[pos:])
if match:
value = match.group(0)
if value == "or":
yield Token(TokenType.OR, value, pos)
elif value == "and":
yield Token(TokenType.AND, value, pos)
elif value == "not":
yield Token(TokenType.NOT, value, pos)
else:
yield Token(TokenType.IDENT, value, pos)
pos += len(value)
else:
raise ParseError(
pos + 1,
f'unexpected character "{input[pos]}"',
)
yield Token(TokenType.EOF, "", pos)
@overload
def accept(self, type: TokenType, *, reject: Literal[True]) -> Token: ...
@overload
def accept(
self, type: TokenType, *, reject: Literal[False] = False
) -> Token | None: ...
def accept(self, type: TokenType, *, reject: bool = False) -> Token | None:
if self.current.type is type:
token = self.current
if token.type is not TokenType.EOF:
self.current = next(self.tokens)
return token
if reject:
self.reject((type,))
return None
def reject(self, expected: Sequence[TokenType]) -> NoReturn:
raise ParseError(
self.current.pos + 1,
"expected {}; got {}".format(
" OR ".join(type.value for type in expected),
self.current.type.value,
),
)
# True, False and None are legal match expression identifiers,
# but illegal as Python identifiers. To fix this, this prefix
# is added to identifiers in the conversion to Python AST.
IDENT_PREFIX = "$"
def expression(s: Scanner) -> ast.Expression:
if s.accept(TokenType.EOF):
ret: ast.expr = ast.Constant(False)
else:
ret = expr(s)
s.accept(TokenType.EOF, reject=True)
return ast.fix_missing_locations(ast.Expression(ret))
def expr(s: Scanner) -> ast.expr:
ret = and_expr(s)
while s.accept(TokenType.OR):
rhs = and_expr(s)
ret = ast.BoolOp(ast.Or(), [ret, rhs])
return ret
def and_expr(s: Scanner) -> ast.expr:
ret = not_expr(s)
while s.accept(TokenType.AND):
rhs = not_expr(s)
ret = ast.BoolOp(ast.And(), [ret, rhs])
return ret
def not_expr(s: Scanner) -> ast.expr:
if s.accept(TokenType.NOT):
return ast.UnaryOp(ast.Not(), not_expr(s))
if s.accept(TokenType.LPAREN):
ret = expr(s)
s.accept(TokenType.RPAREN, reject=True)
return ret
ident = s.accept(TokenType.IDENT)
if ident:
name = ast.Name(IDENT_PREFIX + ident.value, ast.Load())
if s.accept(TokenType.LPAREN):
ret = ast.Call(func=name, args=[], keywords=all_kwargs(s))
s.accept(TokenType.RPAREN, reject=True)
else:
ret = name
return ret
s.reject((TokenType.NOT, TokenType.LPAREN, TokenType.IDENT))
BUILTIN_MATCHERS = {"True": True, "False": False, "None": None}
def single_kwarg(s: Scanner) -> ast.keyword:
keyword_name = s.accept(TokenType.IDENT, reject=True)
if not keyword_name.value.isidentifier():
raise ParseError(
keyword_name.pos + 1,
f"not a valid python identifier {keyword_name.value}",
)
if keyword.iskeyword(keyword_name.value):
raise ParseError(
keyword_name.pos + 1,
f"unexpected reserved python keyword `{keyword_name.value}`",
)
s.accept(TokenType.EQUAL, reject=True)
if value_token := s.accept(TokenType.STRING):
value: str | int | bool | None = value_token.value[1:-1] # strip quotes
else:
value_token = s.accept(TokenType.IDENT, reject=True)
if (number := value_token.value).isdigit() or (
number.startswith("-") and number[1:].isdigit()
):
value = int(number)
elif value_token.value in BUILTIN_MATCHERS:
value = BUILTIN_MATCHERS[value_token.value]
else:
raise ParseError(
value_token.pos + 1,
f'unexpected character/s "{value_token.value}"',
)
ret = ast.keyword(keyword_name.value, ast.Constant(value))
return ret
def all_kwargs(s: Scanner) -> list[ast.keyword]:
ret = [single_kwarg(s)]
while s.accept(TokenType.COMMA):
ret.append(single_kwarg(s))
return ret
class MatcherCall(Protocol):
def __call__(self, name: str, /, **kwargs: str | int | bool | None) -> bool: ...
@dataclasses.dataclass
class MatcherNameAdapter:
matcher: MatcherCall
name: str
def __bool__(self) -> bool:
return self.matcher(self.name)
def __call__(self, **kwargs: str | int | bool | None) -> bool:
return self.matcher(self.name, **kwargs)
class MatcherAdapter(Mapping[str, MatcherNameAdapter]):
"""Adapts a matcher function to a locals mapping as required by eval()."""
def __init__(self, matcher: MatcherCall) -> None:
self.matcher = matcher
def __getitem__(self, key: str) -> MatcherNameAdapter:
return MatcherNameAdapter(matcher=self.matcher, name=key[len(IDENT_PREFIX) :])
def __iter__(self) -> Iterator[str]:
raise NotImplementedError()
def __len__(self) -> int:
raise NotImplementedError()
class Expression:
"""A compiled match expression as used by -k and -m.
The expression can be evaluated against different matchers.
"""
__slots__ = ("code",)
def __init__(self, code: types.CodeType) -> None:
self.code = code
@classmethod
def compile(cls, input: str) -> Expression:
"""Compile a match expression.
:param input: The input expression - one line.
"""
astexpr = expression(Scanner(input))
code: types.CodeType = compile(
astexpr,
filename="<pytest match expression>",
mode="eval",
)
return Expression(code)
def evaluate(self, matcher: MatcherCall) -> bool:
"""Evaluate the match expression.
:param matcher:
Given an identifier, should return whether it matches or not.
Should be prepared to handle arbitrary strings as input.
:returns: Whether the expression matches or not.
"""
ret: bool = bool(eval(self.code, {"__builtins__": {}}, MatcherAdapter(matcher)))
return ret

View File

@@ -0,0 +1,662 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import collections.abc
from collections.abc import Callable
from collections.abc import Collection
from collections.abc import Iterable
from collections.abc import Iterator
from collections.abc import Mapping
from collections.abc import MutableMapping
from collections.abc import Sequence
import dataclasses
import enum
import inspect
from typing import Any
from typing import final
from typing import NamedTuple
from typing import overload
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import warnings
from .._code import getfslineno
from ..compat import NOTSET
from ..compat import NotSetType
from _pytest.config import Config
from _pytest.deprecated import check_ispytest
from _pytest.deprecated import MARKED_FIXTURE
from _pytest.outcomes import fail
from _pytest.raises import AbstractRaises
from _pytest.scope import _ScopeName
from _pytest.warning_types import PytestUnknownMarkWarning
if TYPE_CHECKING:
from ..nodes import Node
EMPTY_PARAMETERSET_OPTION = "empty_parameter_set_mark"
# Singleton type for HIDDEN_PARAM, as described in:
# https://www.python.org/dev/peps/pep-0484/#support-for-singleton-types-in-unions
class _HiddenParam(enum.Enum):
token = 0
#: Can be used as a parameter set id to hide it from the test name.
HIDDEN_PARAM = _HiddenParam.token
def istestfunc(func) -> bool:
return callable(func) and getattr(func, "__name__", "<lambda>") != "<lambda>"
def get_empty_parameterset_mark(
config: Config, argnames: Sequence[str], func
) -> MarkDecorator:
from ..nodes import Collector
argslisting = ", ".join(argnames)
fs, lineno = getfslineno(func)
reason = f"got empty parameter set for ({argslisting})"
requested_mark = config.getini(EMPTY_PARAMETERSET_OPTION)
if requested_mark in ("", None, "skip"):
mark = MARK_GEN.skip(reason=reason)
elif requested_mark == "xfail":
mark = MARK_GEN.xfail(reason=reason, run=False)
elif requested_mark == "fail_at_collect":
raise Collector.CollectError(
f"Empty parameter set in '{func.__name__}' at line {lineno + 1}"
)
else:
raise LookupError(requested_mark)
return mark
class ParameterSet(NamedTuple):
"""A set of values for a set of parameters along with associated marks and
an optional ID for the set.
Examples::
pytest.param(1, 2, 3)
# ParameterSet(values=(1, 2, 3), marks=(), id=None)
pytest.param("hello", id="greeting")
# ParameterSet(values=("hello",), marks=(), id="greeting")
# Parameter set with marks
pytest.param(42, marks=pytest.mark.xfail)
# ParameterSet(values=(42,), marks=(MarkDecorator(...),), id=None)
# From parametrize mark (parameter names + list of parameter sets)
pytest.mark.parametrize(
("a", "b", "expected"),
[
(1, 2, 3),
pytest.param(40, 2, 42, id="everything"),
],
)
# ParameterSet(values=(1, 2, 3), marks=(), id=None)
# ParameterSet(values=(2, 2, 3), marks=(), id="everything")
"""
values: Sequence[object | NotSetType]
marks: Collection[MarkDecorator | Mark]
id: str | _HiddenParam | None
@classmethod
def param(
cls,
*values: object,
marks: MarkDecorator | Collection[MarkDecorator | Mark] = (),
id: str | _HiddenParam | None = None,
) -> ParameterSet:
if isinstance(marks, MarkDecorator):
marks = (marks,)
else:
assert isinstance(marks, collections.abc.Collection)
if any(i.name == "usefixtures" for i in marks):
raise ValueError(
"pytest.param cannot add pytest.mark.usefixtures; see "
"https://docs.pytest.org/en/stable/reference/reference.html#pytest-param"
)
if id is not None:
if not isinstance(id, str) and id is not HIDDEN_PARAM:
raise TypeError(
"Expected id to be a string or a `pytest.HIDDEN_PARAM` sentinel, "
f"got {type(id)}: {id!r}",
)
return cls(values, marks, id)
@classmethod
def extract_from(
cls,
parameterset: ParameterSet | Sequence[object] | object,
force_tuple: bool = False,
) -> ParameterSet:
"""Extract from an object or objects.
:param parameterset:
A legacy style parameterset that may or may not be a tuple,
and may or may not be wrapped into a mess of mark objects.
:param force_tuple:
Enforce tuple wrapping so single argument tuple values
don't get decomposed and break tests.
"""
if isinstance(parameterset, cls):
return parameterset
if force_tuple:
return cls.param(parameterset)
else:
# TODO: Refactor to fix this type-ignore. Currently the following
# passes type-checking but crashes:
#
# @pytest.mark.parametrize(('x', 'y'), [1, 2])
# def test_foo(x, y): pass
return cls(parameterset, marks=[], id=None) # type: ignore[arg-type]
@staticmethod
def _parse_parametrize_args(
argnames: str | Sequence[str],
argvalues: Iterable[ParameterSet | Sequence[object] | object],
*args,
**kwargs,
) -> tuple[Sequence[str], bool]:
if isinstance(argnames, str):
argnames = [x.strip() for x in argnames.split(",") if x.strip()]
force_tuple = len(argnames) == 1
else:
force_tuple = False
return argnames, force_tuple
@staticmethod
def _parse_parametrize_parameters(
argvalues: Iterable[ParameterSet | Sequence[object] | object],
force_tuple: bool,
) -> list[ParameterSet]:
return [
ParameterSet.extract_from(x, force_tuple=force_tuple) for x in argvalues
]
@classmethod
def _for_parametrize(
cls,
argnames: str | Sequence[str],
argvalues: Iterable[ParameterSet | Sequence[object] | object],
func,
config: Config,
nodeid: str,
) -> tuple[Sequence[str], list[ParameterSet]]:
argnames, force_tuple = cls._parse_parametrize_args(argnames, argvalues)
parameters = cls._parse_parametrize_parameters(argvalues, force_tuple)
del argvalues
if parameters:
# Check all parameter sets have the correct number of values.
for param in parameters:
if len(param.values) != len(argnames):
msg = (
'{nodeid}: in "parametrize" the number of names ({names_len}):\n'
" {names}\n"
"must be equal to the number of values ({values_len}):\n"
" {values}"
)
fail(
msg.format(
nodeid=nodeid,
values=param.values,
names=argnames,
names_len=len(argnames),
values_len=len(param.values),
),
pytrace=False,
)
else:
# Empty parameter set (likely computed at runtime): create a single
# parameter set with NOTSET values, with the "empty parameter set" mark applied to it.
mark = get_empty_parameterset_mark(config, argnames, func)
parameters.append(
ParameterSet(
values=(NOTSET,) * len(argnames), marks=[mark], id="NOTSET"
)
)
return argnames, parameters
@final
@dataclasses.dataclass(frozen=True)
class Mark:
"""A pytest mark."""
#: Name of the mark.
name: str
#: Positional arguments of the mark decorator.
args: tuple[Any, ...]
#: Keyword arguments of the mark decorator.
kwargs: Mapping[str, Any]
#: Source Mark for ids with parametrize Marks.
_param_ids_from: Mark | None = dataclasses.field(default=None, repr=False)
#: Resolved/generated ids with parametrize Marks.
_param_ids_generated: Sequence[str] | None = dataclasses.field(
default=None, repr=False
)
def __init__(
self,
name: str,
args: tuple[Any, ...],
kwargs: Mapping[str, Any],
param_ids_from: Mark | None = None,
param_ids_generated: Sequence[str] | None = None,
*,
_ispytest: bool = False,
) -> None:
""":meta private:"""
check_ispytest(_ispytest)
# Weirdness to bypass frozen=True.
object.__setattr__(self, "name", name)
object.__setattr__(self, "args", args)
object.__setattr__(self, "kwargs", kwargs)
object.__setattr__(self, "_param_ids_from", param_ids_from)
object.__setattr__(self, "_param_ids_generated", param_ids_generated)
def _has_param_ids(self) -> bool:
return "ids" in self.kwargs or len(self.args) >= 4
def combined_with(self, other: Mark) -> Mark:
"""Return a new Mark which is a combination of this
Mark and another Mark.
Combines by appending args and merging kwargs.
:param Mark other: The mark to combine with.
:rtype: Mark
"""
assert self.name == other.name
# Remember source of ids with parametrize Marks.
param_ids_from: Mark | None = None
if self.name == "parametrize":
if other._has_param_ids():
param_ids_from = other
elif self._has_param_ids():
param_ids_from = self
return Mark(
self.name,
self.args + other.args,
dict(self.kwargs, **other.kwargs),
param_ids_from=param_ids_from,
_ispytest=True,
)
# A generic parameter designating an object to which a Mark may
# be applied -- a test function (callable) or class.
# Note: a lambda is not allowed, but this can't be represented.
Markable = TypeVar("Markable", bound=Union[Callable[..., object], type])
@dataclasses.dataclass
class MarkDecorator:
"""A decorator for applying a mark on test functions and classes.
``MarkDecorators`` are created with ``pytest.mark``::
mark1 = pytest.mark.NAME # Simple MarkDecorator
mark2 = pytest.mark.NAME(name1=value) # Parametrized MarkDecorator
and can then be applied as decorators to test functions::
@mark2
def test_function():
pass
When a ``MarkDecorator`` is called, it does the following:
1. If called with a single class as its only positional argument and no
additional keyword arguments, it attaches the mark to the class so it
gets applied automatically to all test cases found in that class.
2. If called with a single function as its only positional argument and
no additional keyword arguments, it attaches the mark to the function,
containing all the arguments already stored internally in the
``MarkDecorator``.
3. When called in any other case, it returns a new ``MarkDecorator``
instance with the original ``MarkDecorator``'s content updated with
the arguments passed to this call.
Note: The rules above prevent a ``MarkDecorator`` from storing only a
single function or class reference as its positional argument with no
additional keyword or positional arguments. You can work around this by
using `with_args()`.
"""
mark: Mark
def __init__(self, mark: Mark, *, _ispytest: bool = False) -> None:
""":meta private:"""
check_ispytest(_ispytest)
self.mark = mark
@property
def name(self) -> str:
"""Alias for mark.name."""
return self.mark.name
@property
def args(self) -> tuple[Any, ...]:
"""Alias for mark.args."""
return self.mark.args
@property
def kwargs(self) -> Mapping[str, Any]:
"""Alias for mark.kwargs."""
return self.mark.kwargs
@property
def markname(self) -> str:
""":meta private:"""
return self.name # for backward-compat (2.4.1 had this attr)
def with_args(self, *args: object, **kwargs: object) -> MarkDecorator:
"""Return a MarkDecorator with extra arguments added.
Unlike calling the MarkDecorator, with_args() can be used even
if the sole argument is a callable/class.
"""
mark = Mark(self.name, args, kwargs, _ispytest=True)
return MarkDecorator(self.mark.combined_with(mark), _ispytest=True)
# Type ignored because the overloads overlap with an incompatible
# return type. Not much we can do about that. Thankfully mypy picks
# the first match so it works out even if we break the rules.
@overload
def __call__(self, arg: Markable) -> Markable: # type: ignore[overload-overlap]
pass
@overload
def __call__(self, *args: object, **kwargs: object) -> MarkDecorator:
pass
def __call__(self, *args: object, **kwargs: object):
"""Call the MarkDecorator."""
if args and not kwargs:
func = args[0]
is_class = inspect.isclass(func)
# For staticmethods/classmethods, the marks are eventually fetched from the
# function object, not the descriptor, so unwrap.
unwrapped_func = func
if isinstance(func, (staticmethod, classmethod)):
unwrapped_func = func.__func__
if len(args) == 1 and (istestfunc(unwrapped_func) or is_class):
store_mark(unwrapped_func, self.mark, stacklevel=3)
return func
return self.with_args(*args, **kwargs)
def get_unpacked_marks(
obj: object | type,
*,
consider_mro: bool = True,
) -> list[Mark]:
"""Obtain the unpacked marks that are stored on an object.
If obj is a class and consider_mro is true, return marks applied to
this class and all of its super-classes in MRO order. If consider_mro
is false, only return marks applied directly to this class.
"""
if isinstance(obj, type):
if not consider_mro:
mark_lists = [obj.__dict__.get("pytestmark", [])]
else:
mark_lists = [
x.__dict__.get("pytestmark", []) for x in reversed(obj.__mro__)
]
mark_list = []
for item in mark_lists:
if isinstance(item, list):
mark_list.extend(item)
else:
mark_list.append(item)
else:
mark_attribute = getattr(obj, "pytestmark", [])
if isinstance(mark_attribute, list):
mark_list = mark_attribute
else:
mark_list = [mark_attribute]
return list(normalize_mark_list(mark_list))
def normalize_mark_list(
mark_list: Iterable[Mark | MarkDecorator],
) -> Iterable[Mark]:
"""
Normalize an iterable of Mark or MarkDecorator objects into a list of marks
by retrieving the `mark` attribute on MarkDecorator instances.
:param mark_list: marks to normalize
:returns: A new list of the extracted Mark objects
"""
for mark in mark_list:
mark_obj = getattr(mark, "mark", mark)
if not isinstance(mark_obj, Mark):
raise TypeError(f"got {mark_obj!r} instead of Mark")
yield mark_obj
def store_mark(obj, mark: Mark, *, stacklevel: int = 2) -> None:
"""Store a Mark on an object.
This is used to implement the Mark declarations/decorators correctly.
"""
assert isinstance(mark, Mark), mark
from ..fixtures import getfixturemarker
if getfixturemarker(obj) is not None:
warnings.warn(MARKED_FIXTURE, stacklevel=stacklevel)
# Always reassign name to avoid updating pytestmark in a reference that
# was only borrowed.
obj.pytestmark = [*get_unpacked_marks(obj, consider_mro=False), mark]
# Typing for builtin pytest marks. This is cheating; it gives builtin marks
# special privilege, and breaks modularity. But practicality beats purity...
if TYPE_CHECKING:
class _SkipMarkDecorator(MarkDecorator):
@overload # type: ignore[override,no-overload-impl]
def __call__(self, arg: Markable) -> Markable: ...
@overload
def __call__(self, reason: str = ...) -> MarkDecorator: ...
class _SkipifMarkDecorator(MarkDecorator):
def __call__( # type: ignore[override]
self,
condition: str | bool = ...,
*conditions: str | bool,
reason: str = ...,
) -> MarkDecorator: ...
class _XfailMarkDecorator(MarkDecorator):
@overload # type: ignore[override,no-overload-impl]
def __call__(self, arg: Markable) -> Markable: ...
@overload
def __call__(
self,
condition: str | bool = False,
*conditions: str | bool,
reason: str = ...,
run: bool = ...,
raises: None
| type[BaseException]
| tuple[type[BaseException], ...]
| AbstractRaises[BaseException] = ...,
strict: bool = ...,
) -> MarkDecorator: ...
class _ParametrizeMarkDecorator(MarkDecorator):
def __call__( # type: ignore[override]
self,
argnames: str | Sequence[str],
argvalues: Iterable[ParameterSet | Sequence[object] | object],
*,
indirect: bool | Sequence[str] = ...,
ids: Iterable[None | str | float | int | bool]
| Callable[[Any], object | None]
| None = ...,
scope: _ScopeName | None = ...,
) -> MarkDecorator: ...
class _UsefixturesMarkDecorator(MarkDecorator):
def __call__(self, *fixtures: str) -> MarkDecorator: # type: ignore[override]
...
class _FilterwarningsMarkDecorator(MarkDecorator):
def __call__(self, *filters: str) -> MarkDecorator: # type: ignore[override]
...
@final
class MarkGenerator:
"""Factory for :class:`MarkDecorator` objects - exposed as
a ``pytest.mark`` singleton instance.
Example::
import pytest
@pytest.mark.slowtest
def test_function():
pass
applies a 'slowtest' :class:`Mark` on ``test_function``.
"""
# See TYPE_CHECKING above.
if TYPE_CHECKING:
skip: _SkipMarkDecorator
skipif: _SkipifMarkDecorator
xfail: _XfailMarkDecorator
parametrize: _ParametrizeMarkDecorator
usefixtures: _UsefixturesMarkDecorator
filterwarnings: _FilterwarningsMarkDecorator
def __init__(self, *, _ispytest: bool = False) -> None:
check_ispytest(_ispytest)
self._config: Config | None = None
self._markers: set[str] = set()
def __getattr__(self, name: str) -> MarkDecorator:
"""Generate a new :class:`MarkDecorator` with the given name."""
if name[0] == "_":
raise AttributeError("Marker name must NOT start with underscore")
if self._config is not None:
# We store a set of markers as a performance optimisation - if a mark
# name is in the set we definitely know it, but a mark may be known and
# not in the set. We therefore start by updating the set!
if name not in self._markers:
for line in self._config.getini("markers"):
# example lines: "skipif(condition): skip the given test if..."
# or "hypothesis: tests which use Hypothesis", so to get the
# marker name we split on both `:` and `(`.
marker = line.split(":")[0].split("(")[0].strip()
self._markers.add(marker)
# If the name is not in the set of known marks after updating,
# then it really is time to issue a warning or an error.
if name not in self._markers:
if self._config.option.strict_markers or self._config.option.strict:
fail(
f"{name!r} not found in `markers` configuration option",
pytrace=False,
)
# Raise a specific error for common misspellings of "parametrize".
if name in ["parameterize", "parametrise", "parameterise"]:
__tracebackhide__ = True
fail(f"Unknown '{name}' mark, did you mean 'parametrize'?")
warnings.warn(
f"Unknown pytest.mark.{name} - is this a typo? You can register "
"custom marks to avoid this warning - for details, see "
"https://docs.pytest.org/en/stable/how-to/mark.html",
PytestUnknownMarkWarning,
2,
)
return MarkDecorator(Mark(name, (), {}, _ispytest=True), _ispytest=True)
MARK_GEN = MarkGenerator(_ispytest=True)
@final
class NodeKeywords(MutableMapping[str, Any]):
__slots__ = ("_markers", "node", "parent")
def __init__(self, node: Node) -> None:
self.node = node
self.parent = node.parent
self._markers = {node.name: True}
def __getitem__(self, key: str) -> Any:
try:
return self._markers[key]
except KeyError:
if self.parent is None:
raise
return self.parent.keywords[key]
def __setitem__(self, key: str, value: Any) -> None:
self._markers[key] = value
# Note: we could've avoided explicitly implementing some of the methods
# below and use the collections.abc fallback, but that would be slow.
def __contains__(self, key: object) -> bool:
return key in self._markers or (
self.parent is not None and key in self.parent.keywords
)
def update( # type: ignore[override]
self,
other: Mapping[str, Any] | Iterable[tuple[str, Any]] = (),
**kwds: Any,
) -> None:
self._markers.update(other)
self._markers.update(kwds)
def __delitem__(self, key: str) -> None:
raise ValueError("cannot delete key in keywords dict")
def __iter__(self) -> Iterator[str]:
# Doesn't need to be fast.
yield from self._markers
if self.parent is not None:
for keyword in self.parent.keywords:
# self._marks and self.parent.keywords can have duplicates.
if keyword not in self._markers:
yield keyword
def __len__(self) -> int:
# Doesn't need to be fast.
return sum(1 for keyword in self)
def __repr__(self) -> str:
return f"<NodeKeywords for node {self.node}>"

View File

@@ -0,0 +1,415 @@
# mypy: allow-untyped-defs
"""Monkeypatching and mocking functionality."""
from __future__ import annotations
from collections.abc import Generator
from collections.abc import Mapping
from collections.abc import MutableMapping
from contextlib import contextmanager
import os
import re
import sys
from typing import Any
from typing import final
from typing import overload
from typing import TypeVar
import warnings
from _pytest.fixtures import fixture
from _pytest.warning_types import PytestWarning
RE_IMPORT_ERROR_NAME = re.compile(r"^No module named (.*)$")
K = TypeVar("K")
V = TypeVar("V")
@fixture
def monkeypatch() -> Generator[MonkeyPatch]:
"""A convenient fixture for monkey-patching.
The fixture provides these methods to modify objects, dictionaries, or
:data:`os.environ`:
* :meth:`monkeypatch.setattr(obj, name, value, raising=True) <pytest.MonkeyPatch.setattr>`
* :meth:`monkeypatch.delattr(obj, name, raising=True) <pytest.MonkeyPatch.delattr>`
* :meth:`monkeypatch.setitem(mapping, name, value) <pytest.MonkeyPatch.setitem>`
* :meth:`monkeypatch.delitem(obj, name, raising=True) <pytest.MonkeyPatch.delitem>`
* :meth:`monkeypatch.setenv(name, value, prepend=None) <pytest.MonkeyPatch.setenv>`
* :meth:`monkeypatch.delenv(name, raising=True) <pytest.MonkeyPatch.delenv>`
* :meth:`monkeypatch.syspath_prepend(path) <pytest.MonkeyPatch.syspath_prepend>`
* :meth:`monkeypatch.chdir(path) <pytest.MonkeyPatch.chdir>`
* :meth:`monkeypatch.context() <pytest.MonkeyPatch.context>`
All modifications will be undone after the requesting test function or
fixture has finished. The ``raising`` parameter determines if a :class:`KeyError`
or :class:`AttributeError` will be raised if the set/deletion operation does not have the
specified target.
To undo modifications done by the fixture in a contained scope,
use :meth:`context() <pytest.MonkeyPatch.context>`.
"""
mpatch = MonkeyPatch()
yield mpatch
mpatch.undo()
def resolve(name: str) -> object:
# Simplified from zope.dottedname.
parts = name.split(".")
used = parts.pop(0)
found: object = __import__(used)
for part in parts:
used += "." + part
try:
found = getattr(found, part)
except AttributeError:
pass
else:
continue
# We use explicit un-nesting of the handling block in order
# to avoid nested exceptions.
try:
__import__(used)
except ImportError as ex:
expected = str(ex).split()[-1]
if expected == used:
raise
else:
raise ImportError(f"import error in {used}: {ex}") from ex
found = annotated_getattr(found, part, used)
return found
def annotated_getattr(obj: object, name: str, ann: str) -> object:
try:
obj = getattr(obj, name)
except AttributeError as e:
raise AttributeError(
f"{type(obj).__name__!r} object at {ann} has no attribute {name!r}"
) from e
return obj
def derive_importpath(import_path: str, raising: bool) -> tuple[str, object]:
if not isinstance(import_path, str) or "." not in import_path:
raise TypeError(f"must be absolute import path string, not {import_path!r}")
module, attr = import_path.rsplit(".", 1)
target = resolve(module)
if raising:
annotated_getattr(target, attr, ann=module)
return attr, target
class Notset:
def __repr__(self) -> str:
return "<notset>"
notset = Notset()
@final
class MonkeyPatch:
"""Helper to conveniently monkeypatch attributes/items/environment
variables/syspath.
Returned by the :fixture:`monkeypatch` fixture.
.. versionchanged:: 6.2
Can now also be used directly as `pytest.MonkeyPatch()`, for when
the fixture is not available. In this case, use
:meth:`with MonkeyPatch.context() as mp: <context>` or remember to call
:meth:`undo` explicitly.
"""
def __init__(self) -> None:
self._setattr: list[tuple[object, str, object]] = []
self._setitem: list[tuple[Mapping[Any, Any], object, object]] = []
self._cwd: str | None = None
self._savesyspath: list[str] | None = None
@classmethod
@contextmanager
def context(cls) -> Generator[MonkeyPatch]:
"""Context manager that returns a new :class:`MonkeyPatch` object
which undoes any patching done inside the ``with`` block upon exit.
Example:
.. code-block:: python
import functools
def test_partial(monkeypatch):
with monkeypatch.context() as m:
m.setattr(functools, "partial", 3)
Useful in situations where it is desired to undo some patches before the test ends,
such as mocking ``stdlib`` functions that might break pytest itself if mocked (for examples
of this see :issue:`3290`).
"""
m = cls()
try:
yield m
finally:
m.undo()
@overload
def setattr(
self,
target: str,
name: object,
value: Notset = ...,
raising: bool = ...,
) -> None: ...
@overload
def setattr(
self,
target: object,
name: str,
value: object,
raising: bool = ...,
) -> None: ...
def setattr(
self,
target: str | object,
name: object | str,
value: object = notset,
raising: bool = True,
) -> None:
"""
Set attribute value on target, memorizing the old value.
For example:
.. code-block:: python
import os
monkeypatch.setattr(os, "getcwd", lambda: "/")
The code above replaces the :func:`os.getcwd` function by a ``lambda`` which
always returns ``"/"``.
For convenience, you can specify a string as ``target`` which
will be interpreted as a dotted import path, with the last part
being the attribute name:
.. code-block:: python
monkeypatch.setattr("os.getcwd", lambda: "/")
Raises :class:`AttributeError` if the attribute does not exist, unless
``raising`` is set to False.
**Where to patch**
``monkeypatch.setattr`` works by (temporarily) changing the object that a name points to with another one.
There can be many names pointing to any individual object, so for patching to work you must ensure
that you patch the name used by the system under test.
See the section :ref:`Where to patch <python:where-to-patch>` in the :mod:`unittest.mock`
docs for a complete explanation, which is meant for :func:`unittest.mock.patch` but
applies to ``monkeypatch.setattr`` as well.
"""
__tracebackhide__ = True
import inspect
if isinstance(value, Notset):
if not isinstance(target, str):
raise TypeError(
"use setattr(target, name, value) or "
"setattr(target, value) with target being a dotted "
"import string"
)
value = name
name, target = derive_importpath(target, raising)
else:
if not isinstance(name, str):
raise TypeError(
"use setattr(target, name, value) with name being a string or "
"setattr(target, value) with target being a dotted "
"import string"
)
oldval = getattr(target, name, notset)
if raising and oldval is notset:
raise AttributeError(f"{target!r} has no attribute {name!r}")
# avoid class descriptors like staticmethod/classmethod
if inspect.isclass(target):
oldval = target.__dict__.get(name, notset)
self._setattr.append((target, name, oldval))
setattr(target, name, value)
def delattr(
self,
target: object | str,
name: str | Notset = notset,
raising: bool = True,
) -> None:
"""Delete attribute ``name`` from ``target``.
If no ``name`` is specified and ``target`` is a string
it will be interpreted as a dotted import path with the
last part being the attribute name.
Raises AttributeError it the attribute does not exist, unless
``raising`` is set to False.
"""
__tracebackhide__ = True
import inspect
if isinstance(name, Notset):
if not isinstance(target, str):
raise TypeError(
"use delattr(target, name) or "
"delattr(target) with target being a dotted "
"import string"
)
name, target = derive_importpath(target, raising)
if not hasattr(target, name):
if raising:
raise AttributeError(name)
else:
oldval = getattr(target, name, notset)
# Avoid class descriptors like staticmethod/classmethod.
if inspect.isclass(target):
oldval = target.__dict__.get(name, notset)
self._setattr.append((target, name, oldval))
delattr(target, name)
def setitem(self, dic: Mapping[K, V], name: K, value: V) -> None:
"""Set dictionary entry ``name`` to value."""
self._setitem.append((dic, name, dic.get(name, notset)))
# Not all Mapping types support indexing, but MutableMapping doesn't support TypedDict
dic[name] = value # type: ignore[index]
def delitem(self, dic: Mapping[K, V], name: K, raising: bool = True) -> None:
"""Delete ``name`` from dict.
Raises ``KeyError`` if it doesn't exist, unless ``raising`` is set to
False.
"""
if name not in dic:
if raising:
raise KeyError(name)
else:
self._setitem.append((dic, name, dic.get(name, notset)))
# Not all Mapping types support indexing, but MutableMapping doesn't support TypedDict
del dic[name] # type: ignore[attr-defined]
def setenv(self, name: str, value: str, prepend: str | None = None) -> None:
"""Set environment variable ``name`` to ``value``.
If ``prepend`` is a character, read the current environment variable
value and prepend the ``value`` adjoined with the ``prepend``
character.
"""
if not isinstance(value, str):
warnings.warn( # type: ignore[unreachable]
PytestWarning(
f"Value of environment variable {name} type should be str, but got "
f"{value!r} (type: {type(value).__name__}); converted to str implicitly"
),
stacklevel=2,
)
value = str(value)
if prepend and name in os.environ:
value = value + prepend + os.environ[name]
self.setitem(os.environ, name, value)
def delenv(self, name: str, raising: bool = True) -> None:
"""Delete ``name`` from the environment.
Raises ``KeyError`` if it does not exist, unless ``raising`` is set to
False.
"""
environ: MutableMapping[str, str] = os.environ
self.delitem(environ, name, raising=raising)
def syspath_prepend(self, path) -> None:
"""Prepend ``path`` to ``sys.path`` list of import locations."""
if self._savesyspath is None:
self._savesyspath = sys.path[:]
sys.path.insert(0, str(path))
# https://github.com/pypa/setuptools/blob/d8b901bc/docs/pkg_resources.txt#L162-L171
# this is only needed when pkg_resources was already loaded by the namespace package
if "pkg_resources" in sys.modules:
from pkg_resources import fixup_namespace_packages
fixup_namespace_packages(str(path))
# A call to syspathinsert() usually means that the caller wants to
# import some dynamically created files, thus with python3 we
# invalidate its import caches.
# This is especially important when any namespace package is in use,
# since then the mtime based FileFinder cache (that gets created in
# this case already) gets not invalidated when writing the new files
# quickly afterwards.
from importlib import invalidate_caches
invalidate_caches()
def chdir(self, path: str | os.PathLike[str]) -> None:
"""Change the current working directory to the specified path.
:param path:
The path to change into.
"""
if self._cwd is None:
self._cwd = os.getcwd()
os.chdir(path)
def undo(self) -> None:
"""Undo previous changes.
This call consumes the undo stack. Calling it a second time has no
effect unless you do more monkeypatching after the undo call.
There is generally no need to call `undo()`, since it is
called automatically during tear-down.
.. note::
The same `monkeypatch` fixture is used across a
single test function invocation. If `monkeypatch` is used both by
the test function itself and one of the test fixtures,
calling `undo()` will undo all of the changes made in
both functions.
Prefer to use :meth:`context() <pytest.MonkeyPatch.context>` instead.
"""
for obj, name, value in reversed(self._setattr):
if value is not notset:
setattr(obj, name, value)
else:
delattr(obj, name)
self._setattr[:] = []
for dictionary, key, value in reversed(self._setitem):
if value is notset:
try:
# Not all Mapping types support indexing, but MutableMapping doesn't support TypedDict
del dictionary[key] # type: ignore[attr-defined]
except KeyError:
pass # Was already deleted, so we have the desired state.
else:
# Not all Mapping types support indexing, but MutableMapping doesn't support TypedDict
dictionary[key] = value # type: ignore[index]
self._setitem[:] = []
if self._savesyspath is not None:
sys.path[:] = self._savesyspath
self._savesyspath = None
if self._cwd is not None:
os.chdir(self._cwd)
self._cwd = None

View File

@@ -0,0 +1,772 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import abc
from collections.abc import Callable
from collections.abc import Iterable
from collections.abc import Iterator
from collections.abc import MutableMapping
from functools import cached_property
from functools import lru_cache
from inspect import signature
import os
import pathlib
from pathlib import Path
from typing import Any
from typing import cast
from typing import NoReturn
from typing import overload
from typing import TYPE_CHECKING
from typing import TypeVar
import warnings
import pluggy
import _pytest._code
from _pytest._code import getfslineno
from _pytest._code.code import ExceptionInfo
from _pytest._code.code import TerminalRepr
from _pytest._code.code import Traceback
from _pytest._code.code import TracebackStyle
from _pytest.compat import LEGACY_PATH
from _pytest.config import Config
from _pytest.config import ConftestImportFailure
from _pytest.config.compat import _check_path
from _pytest.deprecated import NODE_CTOR_FSPATH_ARG
from _pytest.mark.structures import Mark
from _pytest.mark.structures import MarkDecorator
from _pytest.mark.structures import NodeKeywords
from _pytest.outcomes import fail
from _pytest.pathlib import absolutepath
from _pytest.stash import Stash
from _pytest.warning_types import PytestWarning
if TYPE_CHECKING:
from typing_extensions import Self
# Imported here due to circular import.
from _pytest.main import Session
SEP = "/"
tracebackcutdir = Path(_pytest.__file__).parent
_T = TypeVar("_T")
def _imply_path(
node_type: type[Node],
path: Path | None,
fspath: LEGACY_PATH | None,
) -> Path:
if fspath is not None:
warnings.warn(
NODE_CTOR_FSPATH_ARG.format(
node_type_name=node_type.__name__,
),
stacklevel=6,
)
if path is not None:
if fspath is not None:
_check_path(path, fspath)
return path
else:
assert fspath is not None
return Path(fspath)
_NodeType = TypeVar("_NodeType", bound="Node")
class NodeMeta(abc.ABCMeta):
"""Metaclass used by :class:`Node` to enforce that direct construction raises
:class:`Failed`.
This behaviour supports the indirection introduced with :meth:`Node.from_parent`,
the named constructor to be used instead of direct construction. The design
decision to enforce indirection with :class:`NodeMeta` was made as a
temporary aid for refactoring the collection tree, which was diagnosed to
have :class:`Node` objects whose creational patterns were overly entangled.
Once the refactoring is complete, this metaclass can be removed.
See https://github.com/pytest-dev/pytest/projects/3 for an overview of the
progress on detangling the :class:`Node` classes.
"""
def __call__(cls, *k, **kw) -> NoReturn:
msg = (
"Direct construction of {name} has been deprecated, please use {name}.from_parent.\n"
"See "
"https://docs.pytest.org/en/stable/deprecations.html#node-construction-changed-to-node-from-parent"
" for more details."
).format(name=f"{cls.__module__}.{cls.__name__}")
fail(msg, pytrace=False)
def _create(cls: type[_T], *k, **kw) -> _T:
try:
return super().__call__(*k, **kw) # type: ignore[no-any-return,misc]
except TypeError:
sig = signature(getattr(cls, "__init__"))
known_kw = {k: v for k, v in kw.items() if k in sig.parameters}
from .warning_types import PytestDeprecationWarning
warnings.warn(
PytestDeprecationWarning(
f"{cls} is not using a cooperative constructor and only takes {set(known_kw)}.\n"
"See https://docs.pytest.org/en/stable/deprecations.html"
"#constructors-of-custom-pytest-node-subclasses-should-take-kwargs "
"for more details."
)
)
return super().__call__(*k, **known_kw) # type: ignore[no-any-return,misc]
class Node(abc.ABC, metaclass=NodeMeta):
r"""Base class of :class:`Collector` and :class:`Item`, the components of
the test collection tree.
``Collector``\'s are the internal nodes of the tree, and ``Item``\'s are the
leaf nodes.
"""
# Implemented in the legacypath plugin.
#: A ``LEGACY_PATH`` copy of the :attr:`path` attribute. Intended for usage
#: for methods not migrated to ``pathlib.Path`` yet, such as
#: :meth:`Item.reportinfo <pytest.Item.reportinfo>`. Will be deprecated in
#: a future release, prefer using :attr:`path` instead.
fspath: LEGACY_PATH
# Use __slots__ to make attribute access faster.
# Note that __dict__ is still available.
__slots__ = (
"__dict__",
"_nodeid",
"_store",
"config",
"name",
"parent",
"path",
"session",
)
def __init__(
self,
name: str,
parent: Node | None = None,
config: Config | None = None,
session: Session | None = None,
fspath: LEGACY_PATH | None = None,
path: Path | None = None,
nodeid: str | None = None,
) -> None:
#: A unique name within the scope of the parent node.
self.name: str = name
#: The parent collector node.
self.parent = parent
if config:
#: The pytest config object.
self.config: Config = config
else:
if not parent:
raise TypeError("config or parent must be provided")
self.config = parent.config
if session:
#: The pytest session this node is part of.
self.session: Session = session
else:
if not parent:
raise TypeError("session or parent must be provided")
self.session = parent.session
if path is None and fspath is None:
path = getattr(parent, "path", None)
#: Filesystem path where this node was collected from (can be None).
self.path: pathlib.Path = _imply_path(type(self), path, fspath=fspath)
# The explicit annotation is to avoid publicly exposing NodeKeywords.
#: Keywords/markers collected from all scopes.
self.keywords: MutableMapping[str, Any] = NodeKeywords(self)
#: The marker objects belonging to this node.
self.own_markers: list[Mark] = []
#: Allow adding of extra keywords to use for matching.
self.extra_keyword_matches: set[str] = set()
if nodeid is not None:
assert "::()" not in nodeid
self._nodeid = nodeid
else:
if not self.parent:
raise TypeError("nodeid or parent must be provided")
self._nodeid = self.parent.nodeid + "::" + self.name
#: A place where plugins can store information on the node for their
#: own use.
self.stash: Stash = Stash()
# Deprecated alias. Was never public. Can be removed in a few releases.
self._store = self.stash
@classmethod
def from_parent(cls, parent: Node, **kw) -> Self:
"""Public constructor for Nodes.
This indirection got introduced in order to enable removing
the fragile logic from the node constructors.
Subclasses can use ``super().from_parent(...)`` when overriding the
construction.
:param parent: The parent node of this Node.
"""
if "config" in kw:
raise TypeError("config is not a valid argument for from_parent")
if "session" in kw:
raise TypeError("session is not a valid argument for from_parent")
return cls._create(parent=parent, **kw)
@property
def ihook(self) -> pluggy.HookRelay:
"""fspath-sensitive hook proxy used to call pytest hooks."""
return self.session.gethookproxy(self.path)
def __repr__(self) -> str:
return "<{} {}>".format(self.__class__.__name__, getattr(self, "name", None))
def warn(self, warning: Warning) -> None:
"""Issue a warning for this Node.
Warnings will be displayed after the test session, unless explicitly suppressed.
:param Warning warning:
The warning instance to issue.
:raises ValueError: If ``warning`` instance is not a subclass of Warning.
Example usage:
.. code-block:: python
node.warn(PytestWarning("some message"))
node.warn(UserWarning("some message"))
.. versionchanged:: 6.2
Any subclass of :class:`Warning` is now accepted, rather than only
:class:`PytestWarning <pytest.PytestWarning>` subclasses.
"""
# enforce type checks here to avoid getting a generic type error later otherwise.
if not isinstance(warning, Warning):
raise ValueError(
f"warning must be an instance of Warning or subclass, got {warning!r}"
)
path, lineno = get_fslocation_from_item(self)
assert lineno is not None
warnings.warn_explicit(
warning,
category=None,
filename=str(path),
lineno=lineno + 1,
)
# Methods for ordering nodes.
@property
def nodeid(self) -> str:
"""A ::-separated string denoting its collection tree address."""
return self._nodeid
def __hash__(self) -> int:
return hash(self._nodeid)
def setup(self) -> None:
pass
def teardown(self) -> None:
pass
def iter_parents(self) -> Iterator[Node]:
"""Iterate over all parent collectors starting from and including self
up to the root of the collection tree.
.. versionadded:: 8.1
"""
parent: Node | None = self
while parent is not None:
yield parent
parent = parent.parent
def listchain(self) -> list[Node]:
"""Return a list of all parent collectors starting from the root of the
collection tree down to and including self."""
chain = []
item: Node | None = self
while item is not None:
chain.append(item)
item = item.parent
chain.reverse()
return chain
def add_marker(self, marker: str | MarkDecorator, append: bool = True) -> None:
"""Dynamically add a marker object to the node.
:param marker:
The marker.
:param append:
Whether to append the marker, or prepend it.
"""
from _pytest.mark import MARK_GEN
if isinstance(marker, MarkDecorator):
marker_ = marker
elif isinstance(marker, str):
marker_ = getattr(MARK_GEN, marker)
else:
raise ValueError("is not a string or pytest.mark.* Marker")
self.keywords[marker_.name] = marker_
if append:
self.own_markers.append(marker_.mark)
else:
self.own_markers.insert(0, marker_.mark)
def iter_markers(self, name: str | None = None) -> Iterator[Mark]:
"""Iterate over all markers of the node.
:param name: If given, filter the results by the name attribute.
:returns: An iterator of the markers of the node.
"""
return (x[1] for x in self.iter_markers_with_node(name=name))
def iter_markers_with_node(
self, name: str | None = None
) -> Iterator[tuple[Node, Mark]]:
"""Iterate over all markers of the node.
:param name: If given, filter the results by the name attribute.
:returns: An iterator of (node, mark) tuples.
"""
for node in self.iter_parents():
for mark in node.own_markers:
if name is None or getattr(mark, "name", None) == name:
yield node, mark
@overload
def get_closest_marker(self, name: str) -> Mark | None: ...
@overload
def get_closest_marker(self, name: str, default: Mark) -> Mark: ...
def get_closest_marker(self, name: str, default: Mark | None = None) -> Mark | None:
"""Return the first marker matching the name, from closest (for
example function) to farther level (for example module level).
:param default: Fallback return value if no marker was found.
:param name: Name to filter by.
"""
return next(self.iter_markers(name=name), default)
def listextrakeywords(self) -> set[str]:
"""Return a set of all extra keywords in self and any parents."""
extra_keywords: set[str] = set()
for item in self.listchain():
extra_keywords.update(item.extra_keyword_matches)
return extra_keywords
def listnames(self) -> list[str]:
return [x.name for x in self.listchain()]
def addfinalizer(self, fin: Callable[[], object]) -> None:
"""Register a function to be called without arguments when this node is
finalized.
This method can only be called when this node is active
in a setup chain, for example during self.setup().
"""
self.session._setupstate.addfinalizer(fin, self)
def getparent(self, cls: type[_NodeType]) -> _NodeType | None:
"""Get the closest parent node (including self) which is an instance of
the given class.
:param cls: The node class to search for.
:returns: The node, if found.
"""
for node in self.iter_parents():
if isinstance(node, cls):
return node
return None
def _traceback_filter(self, excinfo: ExceptionInfo[BaseException]) -> Traceback:
return excinfo.traceback
def _repr_failure_py(
self,
excinfo: ExceptionInfo[BaseException],
style: TracebackStyle | None = None,
) -> TerminalRepr:
from _pytest.fixtures import FixtureLookupError
if isinstance(excinfo.value, ConftestImportFailure):
excinfo = ExceptionInfo.from_exception(excinfo.value.cause)
if isinstance(excinfo.value, fail.Exception):
if not excinfo.value.pytrace:
style = "value"
if isinstance(excinfo.value, FixtureLookupError):
return excinfo.value.formatrepr()
tbfilter: bool | Callable[[ExceptionInfo[BaseException]], Traceback]
if self.config.getoption("fulltrace", False):
style = "long"
tbfilter = False
else:
tbfilter = self._traceback_filter
if style == "auto":
style = "long"
# XXX should excinfo.getrepr record all data and toterminal() process it?
if style is None:
if self.config.getoption("tbstyle", "auto") == "short":
style = "short"
else:
style = "long"
if self.config.get_verbosity() > 1:
truncate_locals = False
else:
truncate_locals = True
truncate_args = False if self.config.get_verbosity() > 2 else True
# excinfo.getrepr() formats paths relative to the CWD if `abspath` is False.
# It is possible for a fixture/test to change the CWD while this code runs, which
# would then result in the user seeing confusing paths in the failure message.
# To fix this, if the CWD changed, always display the full absolute path.
# It will be better to just always display paths relative to invocation_dir, but
# this requires a lot of plumbing (#6428).
try:
abspath = Path(os.getcwd()) != self.config.invocation_params.dir
except OSError:
abspath = True
return excinfo.getrepr(
funcargs=True,
abspath=abspath,
showlocals=self.config.getoption("showlocals", False),
style=style,
tbfilter=tbfilter,
truncate_locals=truncate_locals,
truncate_args=truncate_args,
)
def repr_failure(
self,
excinfo: ExceptionInfo[BaseException],
style: TracebackStyle | None = None,
) -> str | TerminalRepr:
"""Return a representation of a collection or test failure.
.. seealso:: :ref:`non-python tests`
:param excinfo: Exception information for the failure.
"""
return self._repr_failure_py(excinfo, style)
def get_fslocation_from_item(node: Node) -> tuple[str | Path, int | None]:
"""Try to extract the actual location from a node, depending on available attributes:
* "location": a pair (path, lineno)
* "obj": a Python object that the node wraps.
* "path": just a path
:rtype: A tuple of (str|Path, int) with filename and 0-based line number.
"""
# See Item.location.
location: tuple[str, int | None, str] | None = getattr(node, "location", None)
if location is not None:
return location[:2]
obj = getattr(node, "obj", None)
if obj is not None:
return getfslineno(obj)
return getattr(node, "path", "unknown location"), -1
class Collector(Node, abc.ABC):
"""Base class of all collectors.
Collector create children through `collect()` and thus iteratively build
the collection tree.
"""
class CollectError(Exception):
"""An error during collection, contains a custom message."""
@abc.abstractmethod
def collect(self) -> Iterable[Item | Collector]:
"""Collect children (items and collectors) for this collector."""
raise NotImplementedError("abstract")
# TODO: This omits the style= parameter which breaks Liskov Substitution.
def repr_failure( # type: ignore[override]
self, excinfo: ExceptionInfo[BaseException]
) -> str | TerminalRepr:
"""Return a representation of a collection failure.
:param excinfo: Exception information for the failure.
"""
if isinstance(excinfo.value, self.CollectError) and not self.config.getoption(
"fulltrace", False
):
exc = excinfo.value
return str(exc.args[0])
# Respect explicit tbstyle option, but default to "short"
# (_repr_failure_py uses "long" with "fulltrace" option always).
tbstyle = self.config.getoption("tbstyle", "auto")
if tbstyle == "auto":
tbstyle = "short"
return self._repr_failure_py(excinfo, style=tbstyle)
def _traceback_filter(self, excinfo: ExceptionInfo[BaseException]) -> Traceback:
if hasattr(self, "path"):
traceback = excinfo.traceback
ntraceback = traceback.cut(path=self.path)
if ntraceback == traceback:
ntraceback = ntraceback.cut(excludepath=tracebackcutdir)
return ntraceback.filter(excinfo)
return excinfo.traceback
@lru_cache(maxsize=1000)
def _check_initialpaths_for_relpath(
initial_paths: frozenset[Path], path: Path
) -> str | None:
if path in initial_paths:
return ""
for parent in path.parents:
if parent in initial_paths:
return str(path.relative_to(parent))
return None
class FSCollector(Collector, abc.ABC):
"""Base class for filesystem collectors."""
def __init__(
self,
fspath: LEGACY_PATH | None = None,
path_or_parent: Path | Node | None = None,
path: Path | None = None,
name: str | None = None,
parent: Node | None = None,
config: Config | None = None,
session: Session | None = None,
nodeid: str | None = None,
) -> None:
if path_or_parent:
if isinstance(path_or_parent, Node):
assert parent is None
parent = cast(FSCollector, path_or_parent)
elif isinstance(path_or_parent, Path):
assert path is None
path = path_or_parent
path = _imply_path(type(self), path, fspath=fspath)
if name is None:
name = path.name
if parent is not None and parent.path != path:
try:
rel = path.relative_to(parent.path)
except ValueError:
pass
else:
name = str(rel)
name = name.replace(os.sep, SEP)
self.path = path
if session is None:
assert parent is not None
session = parent.session
if nodeid is None:
try:
nodeid = str(self.path.relative_to(session.config.rootpath))
except ValueError:
nodeid = _check_initialpaths_for_relpath(session._initialpaths, path)
if nodeid and os.sep != SEP:
nodeid = nodeid.replace(os.sep, SEP)
super().__init__(
name=name,
parent=parent,
config=config,
session=session,
nodeid=nodeid,
path=path,
)
@classmethod
def from_parent(
cls,
parent,
*,
fspath: LEGACY_PATH | None = None,
path: Path | None = None,
**kw,
) -> Self:
"""The public constructor."""
return super().from_parent(parent=parent, fspath=fspath, path=path, **kw)
class File(FSCollector, abc.ABC):
"""Base class for collecting tests from a file.
:ref:`non-python tests`.
"""
class Directory(FSCollector, abc.ABC):
"""Base class for collecting files from a directory.
A basic directory collector does the following: goes over the files and
sub-directories in the directory and creates collectors for them by calling
the hooks :hook:`pytest_collect_directory` and :hook:`pytest_collect_file`,
after checking that they are not ignored using
:hook:`pytest_ignore_collect`.
The default directory collectors are :class:`~pytest.Dir` and
:class:`~pytest.Package`.
.. versionadded:: 8.0
:ref:`custom directory collectors`.
"""
class Item(Node, abc.ABC):
"""Base class of all test invocation items.
Note that for a single function there might be multiple test invocation items.
"""
nextitem = None
def __init__(
self,
name,
parent=None,
config: Config | None = None,
session: Session | None = None,
nodeid: str | None = None,
**kw,
) -> None:
# The first two arguments are intentionally passed positionally,
# to keep plugins who define a node type which inherits from
# (pytest.Item, pytest.File) working (see issue #8435).
# They can be made kwargs when the deprecation above is done.
super().__init__(
name,
parent,
config=config,
session=session,
nodeid=nodeid,
**kw,
)
self._report_sections: list[tuple[str, str, str]] = []
#: A list of tuples (name, value) that holds user defined properties
#: for this test.
self.user_properties: list[tuple[str, object]] = []
self._check_item_and_collector_diamond_inheritance()
def _check_item_and_collector_diamond_inheritance(self) -> None:
"""
Check if the current type inherits from both File and Collector
at the same time, emitting a warning accordingly (#8447).
"""
cls = type(self)
# We inject an attribute in the type to avoid issuing this warning
# for the same class more than once, which is not helpful.
# It is a hack, but was deemed acceptable in order to avoid
# flooding the user in the common case.
attr_name = "_pytest_diamond_inheritance_warning_shown"
if getattr(cls, attr_name, False):
return
setattr(cls, attr_name, True)
problems = ", ".join(
base.__name__ for base in cls.__bases__ if issubclass(base, Collector)
)
if problems:
warnings.warn(
f"{cls.__name__} is an Item subclass and should not be a collector, "
f"however its bases {problems} are collectors.\n"
"Please split the Collectors and the Item into separate node types.\n"
"Pytest Doc example: https://docs.pytest.org/en/latest/example/nonpython.html\n"
"example pull request on a plugin: https://github.com/asmeurer/pytest-flakes/pull/40/",
PytestWarning,
)
@abc.abstractmethod
def runtest(self) -> None:
"""Run the test case for this item.
Must be implemented by subclasses.
.. seealso:: :ref:`non-python tests`
"""
raise NotImplementedError("runtest must be implemented by Item subclass")
def add_report_section(self, when: str, key: str, content: str) -> None:
"""Add a new report section, similar to what's done internally to add
stdout and stderr captured output::
item.add_report_section("call", "stdout", "report section contents")
:param str when:
One of the possible capture states, ``"setup"``, ``"call"``, ``"teardown"``.
:param str key:
Name of the section, can be customized at will. Pytest uses ``"stdout"`` and
``"stderr"`` internally.
:param str content:
The full contents as a string.
"""
if content:
self._report_sections.append((when, key, content))
def reportinfo(self) -> tuple[os.PathLike[str] | str, int | None, str]:
"""Get location information for this item for test reports.
Returns a tuple with three elements:
- The path of the test (default ``self.path``)
- The 0-based line number of the test (default ``None``)
- A name of the test to be shown (default ``""``)
.. seealso:: :ref:`non-python tests`
"""
return self.path, None, ""
@cached_property
def location(self) -> tuple[str, int | None, str]:
"""
Returns a tuple of ``(relfspath, lineno, testname)`` for this item
where ``relfspath`` is file path relative to ``config.rootpath``
and lineno is a 0-based line number.
"""
location = self.reportinfo()
path = absolutepath(location[0])
relfspath = self.session._node_location_to_relpath(path)
assert type(location[2]) is str
return (relfspath, location[1], location[2])

View File

@@ -0,0 +1,317 @@
"""Exception classes and constants handling test outcomes as well as
functions creating them."""
from __future__ import annotations
from collections.abc import Callable
import sys
from typing import Any
from typing import cast
from typing import NoReturn
from typing import Protocol
from typing import TypeVar
from .warning_types import PytestDeprecationWarning
class OutcomeException(BaseException):
"""OutcomeException and its subclass instances indicate and contain info
about test and collection outcomes."""
def __init__(self, msg: str | None = None, pytrace: bool = True) -> None:
if msg is not None and not isinstance(msg, str):
error_msg = ( # type: ignore[unreachable]
"{} expected string as 'msg' parameter, got '{}' instead.\n"
"Perhaps you meant to use a mark?"
)
raise TypeError(error_msg.format(type(self).__name__, type(msg).__name__))
super().__init__(msg)
self.msg = msg
self.pytrace = pytrace
def __repr__(self) -> str:
if self.msg is not None:
return self.msg
return f"<{self.__class__.__name__} instance>"
__str__ = __repr__
TEST_OUTCOME = (OutcomeException, Exception)
class Skipped(OutcomeException):
# XXX hackish: on 3k we fake to live in the builtins
# in order to have Skipped exception printing shorter/nicer
__module__ = "builtins"
def __init__(
self,
msg: str | None = None,
pytrace: bool = True,
allow_module_level: bool = False,
*,
_use_item_location: bool = False,
) -> None:
super().__init__(msg=msg, pytrace=pytrace)
self.allow_module_level = allow_module_level
# If true, the skip location is reported as the item's location,
# instead of the place that raises the exception/calls skip().
self._use_item_location = _use_item_location
class Failed(OutcomeException):
"""Raised from an explicit call to pytest.fail()."""
__module__ = "builtins"
class Exit(Exception):
"""Raised for immediate program exits (no tracebacks/summaries)."""
def __init__(
self, msg: str = "unknown reason", returncode: int | None = None
) -> None:
self.msg = msg
self.returncode = returncode
super().__init__(msg)
# We need a callable protocol to add attributes, for discussion see
# https://github.com/python/mypy/issues/2087.
_F = TypeVar("_F", bound=Callable[..., object])
_ET = TypeVar("_ET", bound=type[BaseException])
class _WithException(Protocol[_F, _ET]):
Exception: _ET
__call__: _F
def _with_exception(exception_type: _ET) -> Callable[[_F], _WithException[_F, _ET]]:
def decorate(func: _F) -> _WithException[_F, _ET]:
func_with_exception = cast(_WithException[_F, _ET], func)
func_with_exception.Exception = exception_type
return func_with_exception
return decorate
# Exposed helper methods.
@_with_exception(Exit)
def exit(
reason: str = "",
returncode: int | None = None,
) -> NoReturn:
"""Exit testing process.
:param reason:
The message to show as the reason for exiting pytest. reason has a default value
only because `msg` is deprecated.
:param returncode:
Return code to be used when exiting pytest. None means the same as ``0`` (no error), same as :func:`sys.exit`.
:raises pytest.exit.Exception:
The exception that is raised.
"""
__tracebackhide__ = True
raise Exit(reason, returncode)
@_with_exception(Skipped)
def skip(
reason: str = "",
*,
allow_module_level: bool = False,
) -> NoReturn:
"""Skip an executing test with the given message.
This function should be called only during testing (setup, call or teardown) or
during collection by using the ``allow_module_level`` flag. This function can
be called in doctests as well.
:param reason:
The message to show the user as reason for the skip.
:param allow_module_level:
Allows this function to be called at module level.
Raising the skip exception at module level will stop
the execution of the module and prevent the collection of all tests in the module,
even those defined before the `skip` call.
Defaults to False.
:raises pytest.skip.Exception:
The exception that is raised.
.. note::
It is better to use the :ref:`pytest.mark.skipif ref` marker when
possible to declare a test to be skipped under certain conditions
like mismatching platforms or dependencies.
Similarly, use the ``# doctest: +SKIP`` directive (see :py:data:`doctest.SKIP`)
to skip a doctest statically.
"""
__tracebackhide__ = True
raise Skipped(msg=reason, allow_module_level=allow_module_level)
@_with_exception(Failed)
def fail(reason: str = "", pytrace: bool = True) -> NoReturn:
"""Explicitly fail an executing test with the given message.
:param reason:
The message to show the user as reason for the failure.
:param pytrace:
If False, msg represents the full failure information and no
python traceback will be reported.
:raises pytest.fail.Exception:
The exception that is raised.
"""
__tracebackhide__ = True
raise Failed(msg=reason, pytrace=pytrace)
class XFailed(Failed):
"""Raised from an explicit call to pytest.xfail()."""
@_with_exception(XFailed)
def xfail(reason: str = "") -> NoReturn:
"""Imperatively xfail an executing test or setup function with the given reason.
This function should be called only during testing (setup, call or teardown).
No other code is executed after using ``xfail()`` (it is implemented
internally by raising an exception).
:param reason:
The message to show the user as reason for the xfail.
.. note::
It is better to use the :ref:`pytest.mark.xfail ref` marker when
possible to declare a test to be xfailed under certain conditions
like known bugs or missing features.
:raises pytest.xfail.Exception:
The exception that is raised.
"""
__tracebackhide__ = True
raise XFailed(reason)
def importorskip(
modname: str,
minversion: str | None = None,
reason: str | None = None,
*,
exc_type: type[ImportError] | None = None,
) -> Any:
"""Import and return the requested module ``modname``, or skip the
current test if the module cannot be imported.
:param modname:
The name of the module to import.
:param minversion:
If given, the imported module's ``__version__`` attribute must be at
least this minimal version, otherwise the test is still skipped.
:param reason:
If given, this reason is shown as the message when the module cannot
be imported.
:param exc_type:
The exception that should be captured in order to skip modules.
Must be :py:class:`ImportError` or a subclass.
If the module can be imported but raises :class:`ImportError`, pytest will
issue a warning to the user, as often users expect the module not to be
found (which would raise :class:`ModuleNotFoundError` instead).
This warning can be suppressed by passing ``exc_type=ImportError`` explicitly.
See :ref:`import-or-skip-import-error` for details.
:returns:
The imported module. This should be assigned to its canonical name.
:raises pytest.skip.Exception:
If the module cannot be imported.
Example::
docutils = pytest.importorskip("docutils")
.. versionadded:: 8.2
The ``exc_type`` parameter.
"""
import warnings
__tracebackhide__ = True
compile(modname, "", "eval") # to catch syntaxerrors
# Until pytest 9.1, we will warn the user if we catch ImportError (instead of ModuleNotFoundError),
# as this might be hiding an installation/environment problem, which is not usually what is intended
# when using importorskip() (#11523).
# In 9.1, to keep the function signature compatible, we just change the code below to:
# 1. Use `exc_type = ModuleNotFoundError` if `exc_type` is not given.
# 2. Remove `warn_on_import` and the warning handling.
if exc_type is None:
exc_type = ImportError
warn_on_import_error = True
else:
warn_on_import_error = False
skipped: Skipped | None = None
warning: Warning | None = None
with warnings.catch_warnings():
# Make sure to ignore ImportWarnings that might happen because
# of existing directories with the same name we're trying to
# import but without a __init__.py file.
warnings.simplefilter("ignore")
try:
__import__(modname)
except exc_type as exc:
# Do not raise or issue warnings inside the catch_warnings() block.
if reason is None:
reason = f"could not import {modname!r}: {exc}"
skipped = Skipped(reason, allow_module_level=True)
if warn_on_import_error and not isinstance(exc, ModuleNotFoundError):
lines = [
"",
f"Module '{modname}' was found, but when imported by pytest it raised:",
f" {exc!r}",
"In pytest 9.1 this warning will become an error by default.",
"You can fix the underlying problem, or alternatively overwrite this behavior and silence this "
"warning by passing exc_type=ImportError explicitly.",
"See https://docs.pytest.org/en/stable/deprecations.html#pytest-importorskip-default-behavior-regarding-importerror",
]
warning = PytestDeprecationWarning("\n".join(lines))
if warning:
warnings.warn(warning, stacklevel=2)
if skipped:
raise skipped
mod = sys.modules[modname]
if minversion is None:
return mod
verattr = getattr(mod, "__version__", None)
if minversion is not None:
# Imported lazily to improve start-up time.
from packaging.version import Version
if verattr is None or Version(verattr) < Version(minversion):
raise Skipped(
f"module {modname!r} has __version__ {verattr!r}, required is: {minversion!r}",
allow_module_level=True,
)
return mod

View File

@@ -0,0 +1,117 @@
# mypy: allow-untyped-defs
"""Submit failure or test session information to a pastebin service."""
from __future__ import annotations
from io import StringIO
import tempfile
from typing import IO
from _pytest.config import Config
from _pytest.config import create_terminal_writer
from _pytest.config.argparsing import Parser
from _pytest.stash import StashKey
from _pytest.terminal import TerminalReporter
import pytest
pastebinfile_key = StashKey[IO[bytes]]()
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("terminal reporting")
group.addoption(
"--pastebin",
metavar="mode",
action="store",
dest="pastebin",
default=None,
choices=["failed", "all"],
help="Send failed|all info to bpaste.net pastebin service",
)
@pytest.hookimpl(trylast=True)
def pytest_configure(config: Config) -> None:
if config.option.pastebin == "all":
tr = config.pluginmanager.getplugin("terminalreporter")
# If no terminal reporter plugin is present, nothing we can do here;
# this can happen when this function executes in a worker node
# when using pytest-xdist, for example.
if tr is not None:
# pastebin file will be UTF-8 encoded binary file.
config.stash[pastebinfile_key] = tempfile.TemporaryFile("w+b")
oldwrite = tr._tw.write
def tee_write(s, **kwargs):
oldwrite(s, **kwargs)
if isinstance(s, str):
s = s.encode("utf-8")
config.stash[pastebinfile_key].write(s)
tr._tw.write = tee_write
def pytest_unconfigure(config: Config) -> None:
if pastebinfile_key in config.stash:
pastebinfile = config.stash[pastebinfile_key]
# Get terminal contents and delete file.
pastebinfile.seek(0)
sessionlog = pastebinfile.read()
pastebinfile.close()
del config.stash[pastebinfile_key]
# Undo our patching in the terminal reporter.
tr = config.pluginmanager.getplugin("terminalreporter")
del tr._tw.__dict__["write"]
# Write summary.
tr.write_sep("=", "Sending information to Paste Service")
pastebinurl = create_new_paste(sessionlog)
tr.write_line(f"pastebin session-log: {pastebinurl}\n")
def create_new_paste(contents: str | bytes) -> str:
"""Create a new paste using the bpaste.net service.
:contents: Paste contents string.
:returns: URL to the pasted contents, or an error message.
"""
import re
from urllib.error import HTTPError
from urllib.parse import urlencode
from urllib.request import urlopen
params = {"code": contents, "lexer": "text", "expiry": "1week"}
url = "https://bpa.st"
try:
response: str = (
urlopen(url, data=urlencode(params).encode("ascii")).read().decode("utf-8")
)
except HTTPError as e:
with e: # HTTPErrors are also http responses that must be closed!
return f"bad response: {e}"
except OSError as e: # eg urllib.error.URLError
return f"bad response: {e}"
m = re.search(r'href="/raw/(\w+)"', response)
if m:
return f"{url}/show/{m.group(1)}"
else:
return "bad response: invalid format ('" + response + "')"
def pytest_terminal_summary(terminalreporter: TerminalReporter) -> None:
if terminalreporter.config.option.pastebin != "failed":
return
if "failed" in terminalreporter.stats:
terminalreporter.write_sep("=", "Sending information to Paste Service")
for rep in terminalreporter.stats["failed"]:
try:
msg = rep.longrepr.reprtraceback.reprentries[-1].reprfileloc
except AttributeError:
msg = terminalreporter._getfailureheadline(rep)
file = StringIO()
tw = create_terminal_writer(terminalreporter.config, file)
rep.toterminal(tw)
s = file.getvalue()
assert len(s)
pastebinurl = create_new_paste(s)
terminalreporter.write_line(f"{msg} --> {pastebinurl}")

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,74 @@
"""Helper plugin for pytester; should not be loaded on its own."""
# This plugin contains assertions used by pytester. pytester cannot
# contain them itself, since it is imported by the `pytest` module,
# hence cannot be subject to assertion rewriting, which requires a
# module to not be already imported.
from __future__ import annotations
from collections.abc import Sequence
from _pytest.reports import CollectReport
from _pytest.reports import TestReport
def assertoutcome(
outcomes: tuple[
Sequence[TestReport],
Sequence[CollectReport | TestReport],
Sequence[CollectReport | TestReport],
],
passed: int = 0,
skipped: int = 0,
failed: int = 0,
) -> None:
__tracebackhide__ = True
realpassed, realskipped, realfailed = outcomes
obtained = {
"passed": len(realpassed),
"skipped": len(realskipped),
"failed": len(realfailed),
}
expected = {"passed": passed, "skipped": skipped, "failed": failed}
assert obtained == expected, outcomes
def assert_outcomes(
outcomes: dict[str, int],
passed: int = 0,
skipped: int = 0,
failed: int = 0,
errors: int = 0,
xpassed: int = 0,
xfailed: int = 0,
warnings: int | None = None,
deselected: int | None = None,
) -> None:
"""Assert that the specified outcomes appear with the respective
numbers (0 means it didn't occur) in the text output from a test run."""
__tracebackhide__ = True
obtained = {
"passed": outcomes.get("passed", 0),
"skipped": outcomes.get("skipped", 0),
"failed": outcomes.get("failed", 0),
"errors": outcomes.get("errors", 0),
"xpassed": outcomes.get("xpassed", 0),
"xfailed": outcomes.get("xfailed", 0),
}
expected = {
"passed": passed,
"skipped": skipped,
"failed": failed,
"errors": errors,
"xpassed": xpassed,
"xfailed": xfailed,
}
if warnings is not None:
obtained["warnings"] = outcomes.get("warnings", 0)
expected["warnings"] = warnings
if deselected is not None:
obtained["deselected"] = outcomes.get("deselected", 0)
expected["deselected"] = deselected
assert obtained == expected

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,793 @@
# mypy: allow-untyped-defs
from __future__ import annotations
from collections.abc import Collection
from collections.abc import Mapping
from collections.abc import Sequence
from collections.abc import Sized
from decimal import Decimal
import math
from numbers import Complex
import pprint
import sys
from typing import Any
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from numpy import ndarray
def _compare_approx(
full_object: object,
message_data: Sequence[tuple[str, str, str]],
number_of_elements: int,
different_ids: Sequence[object],
max_abs_diff: float,
max_rel_diff: float,
) -> list[str]:
message_list = list(message_data)
message_list.insert(0, ("Index", "Obtained", "Expected"))
max_sizes = [0, 0, 0]
for index, obtained, expected in message_list:
max_sizes[0] = max(max_sizes[0], len(index))
max_sizes[1] = max(max_sizes[1], len(obtained))
max_sizes[2] = max(max_sizes[2], len(expected))
explanation = [
f"comparison failed. Mismatched elements: {len(different_ids)} / {number_of_elements}:",
f"Max absolute difference: {max_abs_diff}",
f"Max relative difference: {max_rel_diff}",
] + [
f"{indexes:<{max_sizes[0]}} | {obtained:<{max_sizes[1]}} | {expected:<{max_sizes[2]}}"
for indexes, obtained, expected in message_list
]
return explanation
# builtin pytest.approx helper
class ApproxBase:
"""Provide shared utilities for making approximate comparisons between
numbers or sequences of numbers."""
# Tell numpy to use our `__eq__` operator instead of its.
__array_ufunc__ = None
__array_priority__ = 100
def __init__(self, expected, rel=None, abs=None, nan_ok: bool = False) -> None:
__tracebackhide__ = True
self.expected = expected
self.abs = abs
self.rel = rel
self.nan_ok = nan_ok
self._check_type()
def __repr__(self) -> str:
raise NotImplementedError
def _repr_compare(self, other_side: Any) -> list[str]:
return [
"comparison failed",
f"Obtained: {other_side}",
f"Expected: {self}",
]
def __eq__(self, actual) -> bool:
return all(
a == self._approx_scalar(x) for a, x in self._yield_comparisons(actual)
)
def __bool__(self):
__tracebackhide__ = True
raise AssertionError(
"approx() is not supported in a boolean context.\nDid you mean: `assert a == approx(b)`?"
)
# Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore
def __ne__(self, actual) -> bool:
return not (actual == self)
def _approx_scalar(self, x) -> ApproxScalar:
if isinstance(x, Decimal):
return ApproxDecimal(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok)
return ApproxScalar(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok)
def _yield_comparisons(self, actual):
"""Yield all the pairs of numbers to be compared.
This is used to implement the `__eq__` method.
"""
raise NotImplementedError
def _check_type(self) -> None:
"""Raise a TypeError if the expected value is not a valid type."""
# This is only a concern if the expected value is a sequence. In every
# other case, the approx() function ensures that the expected value has
# a numeric type. For this reason, the default is to do nothing. The
# classes that deal with sequences should reimplement this method to
# raise if there are any non-numeric elements in the sequence.
def _recursive_sequence_map(f, x):
"""Recursively map a function over a sequence of arbitrary depth"""
if isinstance(x, (list, tuple)):
seq_type = type(x)
return seq_type(_recursive_sequence_map(f, xi) for xi in x)
elif _is_sequence_like(x):
return [_recursive_sequence_map(f, xi) for xi in x]
else:
return f(x)
class ApproxNumpy(ApproxBase):
"""Perform approximate comparisons where the expected value is numpy array."""
def __repr__(self) -> str:
list_scalars = _recursive_sequence_map(
self._approx_scalar, self.expected.tolist()
)
return f"approx({list_scalars!r})"
def _repr_compare(self, other_side: ndarray | list[Any]) -> list[str]:
import itertools
import math
def get_value_from_nested_list(
nested_list: list[Any], nd_index: tuple[Any, ...]
) -> Any:
"""
Helper function to get the value out of a nested list, given an n-dimensional index.
This mimics numpy's indexing, but for raw nested python lists.
"""
value: Any = nested_list
for i in nd_index:
value = value[i]
return value
np_array_shape = self.expected.shape
approx_side_as_seq = _recursive_sequence_map(
self._approx_scalar, self.expected.tolist()
)
# convert other_side to numpy array to ensure shape attribute is available
other_side_as_array = _as_numpy_array(other_side)
assert other_side_as_array is not None
if np_array_shape != other_side_as_array.shape:
return [
"Impossible to compare arrays with different shapes.",
f"Shapes: {np_array_shape} and {other_side_as_array.shape}",
]
number_of_elements = self.expected.size
max_abs_diff = -math.inf
max_rel_diff = -math.inf
different_ids = []
for index in itertools.product(*(range(i) for i in np_array_shape)):
approx_value = get_value_from_nested_list(approx_side_as_seq, index)
other_value = get_value_from_nested_list(other_side_as_array, index)
if approx_value != other_value:
abs_diff = abs(approx_value.expected - other_value)
max_abs_diff = max(max_abs_diff, abs_diff)
if other_value == 0.0:
max_rel_diff = math.inf
else:
max_rel_diff = max(max_rel_diff, abs_diff / abs(other_value))
different_ids.append(index)
message_data = [
(
str(index),
str(get_value_from_nested_list(other_side_as_array, index)),
str(get_value_from_nested_list(approx_side_as_seq, index)),
)
for index in different_ids
]
return _compare_approx(
self.expected,
message_data,
number_of_elements,
different_ids,
max_abs_diff,
max_rel_diff,
)
def __eq__(self, actual) -> bool:
import numpy as np
# self.expected is supposed to always be an array here.
if not np.isscalar(actual):
try:
actual = np.asarray(actual)
except Exception as e:
raise TypeError(f"cannot compare '{actual}' to numpy.ndarray") from e
if not np.isscalar(actual) and actual.shape != self.expected.shape:
return False
return super().__eq__(actual)
def _yield_comparisons(self, actual):
import numpy as np
# `actual` can either be a numpy array or a scalar, it is treated in
# `__eq__` before being passed to `ApproxBase.__eq__`, which is the
# only method that calls this one.
if np.isscalar(actual):
for i in np.ndindex(self.expected.shape):
yield actual, self.expected[i].item()
else:
for i in np.ndindex(self.expected.shape):
yield actual[i].item(), self.expected[i].item()
class ApproxMapping(ApproxBase):
"""Perform approximate comparisons where the expected value is a mapping
with numeric values (the keys can be anything)."""
def __repr__(self) -> str:
return f"approx({ ({k: self._approx_scalar(v) for k, v in self.expected.items()})!r})"
def _repr_compare(self, other_side: Mapping[object, float]) -> list[str]:
import math
approx_side_as_map = {
k: self._approx_scalar(v) for k, v in self.expected.items()
}
number_of_elements = len(approx_side_as_map)
max_abs_diff = -math.inf
max_rel_diff = -math.inf
different_ids = []
for (approx_key, approx_value), other_value in zip(
approx_side_as_map.items(), other_side.values()
):
if approx_value != other_value:
if approx_value.expected is not None and other_value is not None:
try:
max_abs_diff = max(
max_abs_diff, abs(approx_value.expected - other_value)
)
if approx_value.expected == 0.0:
max_rel_diff = math.inf
else:
max_rel_diff = max(
max_rel_diff,
abs(
(approx_value.expected - other_value)
/ approx_value.expected
),
)
except ZeroDivisionError:
pass
different_ids.append(approx_key)
message_data = [
(str(key), str(other_side[key]), str(approx_side_as_map[key]))
for key in different_ids
]
return _compare_approx(
self.expected,
message_data,
number_of_elements,
different_ids,
max_abs_diff,
max_rel_diff,
)
def __eq__(self, actual) -> bool:
try:
if set(actual.keys()) != set(self.expected.keys()):
return False
except AttributeError:
return False
return super().__eq__(actual)
def _yield_comparisons(self, actual):
for k in self.expected.keys():
yield actual[k], self.expected[k]
def _check_type(self) -> None:
__tracebackhide__ = True
for key, value in self.expected.items():
if isinstance(value, type(self.expected)):
msg = "pytest.approx() does not support nested dictionaries: key={!r} value={!r}\n full mapping={}"
raise TypeError(msg.format(key, value, pprint.pformat(self.expected)))
class ApproxSequenceLike(ApproxBase):
"""Perform approximate comparisons where the expected value is a sequence of numbers."""
def __repr__(self) -> str:
seq_type = type(self.expected)
if seq_type not in (tuple, list):
seq_type = list
return f"approx({seq_type(self._approx_scalar(x) for x in self.expected)!r})"
def _repr_compare(self, other_side: Sequence[float]) -> list[str]:
import math
if len(self.expected) != len(other_side):
return [
"Impossible to compare lists with different sizes.",
f"Lengths: {len(self.expected)} and {len(other_side)}",
]
approx_side_as_map = _recursive_sequence_map(self._approx_scalar, self.expected)
number_of_elements = len(approx_side_as_map)
max_abs_diff = -math.inf
max_rel_diff = -math.inf
different_ids = []
for i, (approx_value, other_value) in enumerate(
zip(approx_side_as_map, other_side)
):
if approx_value != other_value:
try:
abs_diff = abs(approx_value.expected - other_value)
max_abs_diff = max(max_abs_diff, abs_diff)
# Ignore non-numbers for the diff calculations (#13012).
except TypeError:
pass
else:
if other_value == 0.0:
max_rel_diff = math.inf
else:
max_rel_diff = max(max_rel_diff, abs_diff / abs(other_value))
different_ids.append(i)
message_data = [
(str(i), str(other_side[i]), str(approx_side_as_map[i]))
for i in different_ids
]
return _compare_approx(
self.expected,
message_data,
number_of_elements,
different_ids,
max_abs_diff,
max_rel_diff,
)
def __eq__(self, actual) -> bool:
try:
if len(actual) != len(self.expected):
return False
except TypeError:
return False
return super().__eq__(actual)
def _yield_comparisons(self, actual):
return zip(actual, self.expected)
def _check_type(self) -> None:
__tracebackhide__ = True
for index, x in enumerate(self.expected):
if isinstance(x, type(self.expected)):
msg = "pytest.approx() does not support nested data structures: {!r} at index {}\n full sequence: {}"
raise TypeError(msg.format(x, index, pprint.pformat(self.expected)))
class ApproxScalar(ApproxBase):
"""Perform approximate comparisons where the expected value is a single number."""
# Using Real should be better than this Union, but not possible yet:
# https://github.com/python/typeshed/pull/3108
DEFAULT_ABSOLUTE_TOLERANCE: float | Decimal = 1e-12
DEFAULT_RELATIVE_TOLERANCE: float | Decimal = 1e-6
def __repr__(self) -> str:
"""Return a string communicating both the expected value and the
tolerance for the comparison being made.
For example, ``1.0 ± 1e-6``, ``(3+4j) ± 5e-6 ∠ ±180°``.
"""
# Don't show a tolerance for values that aren't compared using
# tolerances, i.e. non-numerics and infinities. Need to call abs to
# handle complex numbers, e.g. (inf + 1j).
if (
isinstance(self.expected, bool)
or (not isinstance(self.expected, (Complex, Decimal)))
or math.isinf(abs(self.expected) or isinstance(self.expected, bool))
):
return str(self.expected)
# If a sensible tolerance can't be calculated, self.tolerance will
# raise a ValueError. In this case, display '???'.
try:
if 1e-3 <= self.tolerance < 1e3:
vetted_tolerance = f"{self.tolerance:n}"
else:
vetted_tolerance = f"{self.tolerance:.1e}"
if (
isinstance(self.expected, Complex)
and self.expected.imag
and not math.isinf(self.tolerance)
):
vetted_tolerance += " ∠ ±180°"
except ValueError:
vetted_tolerance = "???"
return f"{self.expected} ± {vetted_tolerance}"
def __eq__(self, actual) -> bool:
"""Return whether the given value is equal to the expected value
within the pre-specified tolerance."""
def is_bool(val: Any) -> bool:
# Check if `val` is a native bool or numpy bool.
if isinstance(val, bool):
return True
try:
import numpy as np
return isinstance(val, np.bool_)
except ImportError:
return False
asarray = _as_numpy_array(actual)
if asarray is not None:
# Call ``__eq__()`` manually to prevent infinite-recursion with
# numpy<1.13. See #3748.
return all(self.__eq__(a) for a in asarray.flat)
# Short-circuit exact equality, except for bool and np.bool_
if is_bool(self.expected) and not is_bool(actual):
return False
elif actual == self.expected:
return True
# If either type is non-numeric, fall back to strict equality.
# NB: we need Complex, rather than just Number, to ensure that __abs__,
# __sub__, and __float__ are defined. Also, consider bool to be
# non-numeric, even though it has the required arithmetic.
if is_bool(self.expected) or not (
isinstance(self.expected, (Complex, Decimal))
and isinstance(actual, (Complex, Decimal))
):
return False
# Allow the user to control whether NaNs are considered equal to each
# other or not. The abs() calls are for compatibility with complex
# numbers.
if math.isnan(abs(self.expected)):
return self.nan_ok and math.isnan(abs(actual))
# Infinity shouldn't be approximately equal to anything but itself, but
# if there's a relative tolerance, it will be infinite and infinity
# will seem approximately equal to everything. The equal-to-itself
# case would have been short circuited above, so here we can just
# return false if the expected value is infinite. The abs() call is
# for compatibility with complex numbers.
if math.isinf(abs(self.expected)):
return False
# Return true if the two numbers are within the tolerance.
result: bool = abs(self.expected - actual) <= self.tolerance
return result
# Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore
@property
def tolerance(self):
"""Return the tolerance for the comparison.
This could be either an absolute tolerance or a relative tolerance,
depending on what the user specified or which would be larger.
"""
def set_default(x, default):
return x if x is not None else default
# Figure out what the absolute tolerance should be. ``self.abs`` is
# either None or a value specified by the user.
absolute_tolerance = set_default(self.abs, self.DEFAULT_ABSOLUTE_TOLERANCE)
if absolute_tolerance < 0:
raise ValueError(
f"absolute tolerance can't be negative: {absolute_tolerance}"
)
if math.isnan(absolute_tolerance):
raise ValueError("absolute tolerance can't be NaN.")
# If the user specified an absolute tolerance but not a relative one,
# just return the absolute tolerance.
if self.rel is None:
if self.abs is not None:
return absolute_tolerance
# Figure out what the relative tolerance should be. ``self.rel`` is
# either None or a value specified by the user. This is done after
# we've made sure the user didn't ask for an absolute tolerance only,
# because we don't want to raise errors about the relative tolerance if
# we aren't even going to use it.
relative_tolerance = set_default(
self.rel, self.DEFAULT_RELATIVE_TOLERANCE
) * abs(self.expected)
if relative_tolerance < 0:
raise ValueError(
f"relative tolerance can't be negative: {relative_tolerance}"
)
if math.isnan(relative_tolerance):
raise ValueError("relative tolerance can't be NaN.")
# Return the larger of the relative and absolute tolerances.
return max(relative_tolerance, absolute_tolerance)
class ApproxDecimal(ApproxScalar):
"""Perform approximate comparisons where the expected value is a Decimal."""
DEFAULT_ABSOLUTE_TOLERANCE = Decimal("1e-12")
DEFAULT_RELATIVE_TOLERANCE = Decimal("1e-6")
def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
"""Assert that two numbers (or two ordered sequences of numbers) are equal to each other
within some tolerance.
Due to the :doc:`python:tutorial/floatingpoint`, numbers that we
would intuitively expect to be equal are not always so::
>>> 0.1 + 0.2 == 0.3
False
This problem is commonly encountered when writing tests, e.g. when making
sure that floating-point values are what you expect them to be. One way to
deal with this problem is to assert that two floating-point numbers are
equal to within some appropriate tolerance::
>>> abs((0.1 + 0.2) - 0.3) < 1e-6
True
However, comparisons like this are tedious to write and difficult to
understand. Furthermore, absolute comparisons like the one above are
usually discouraged because there's no tolerance that works well for all
situations. ``1e-6`` is good for numbers around ``1``, but too small for
very big numbers and too big for very small ones. It's better to express
the tolerance as a fraction of the expected value, but relative comparisons
like that are even more difficult to write correctly and concisely.
The ``approx`` class performs floating-point comparisons using a syntax
that's as intuitive as possible::
>>> from pytest import approx
>>> 0.1 + 0.2 == approx(0.3)
True
The same syntax also works for ordered sequences of numbers::
>>> (0.1 + 0.2, 0.2 + 0.4) == approx((0.3, 0.6))
True
``numpy`` arrays::
>>> import numpy as np # doctest: +SKIP
>>> np.array([0.1, 0.2]) + np.array([0.2, 0.4]) == approx(np.array([0.3, 0.6])) # doctest: +SKIP
True
And for a ``numpy`` array against a scalar::
>>> import numpy as np # doctest: +SKIP
>>> np.array([0.1, 0.2]) + np.array([0.2, 0.1]) == approx(0.3) # doctest: +SKIP
True
Only ordered sequences are supported, because ``approx`` needs
to infer the relative position of the sequences without ambiguity. This means
``sets`` and other unordered sequences are not supported.
Finally, dictionary *values* can also be compared::
>>> {'a': 0.1 + 0.2, 'b': 0.2 + 0.4} == approx({'a': 0.3, 'b': 0.6})
True
The comparison will be true if both mappings have the same keys and their
respective values match the expected tolerances.
**Tolerances**
By default, ``approx`` considers numbers within a relative tolerance of
``1e-6`` (i.e. one part in a million) of its expected value to be equal.
This treatment would lead to surprising results if the expected value was
``0.0``, because nothing but ``0.0`` itself is relatively close to ``0.0``.
To handle this case less surprisingly, ``approx`` also considers numbers
within an absolute tolerance of ``1e-12`` of its expected value to be
equal. Infinity and NaN are special cases. Infinity is only considered
equal to itself, regardless of the relative tolerance. NaN is not
considered equal to anything by default, but you can make it be equal to
itself by setting the ``nan_ok`` argument to True. (This is meant to
facilitate comparing arrays that use NaN to mean "no data".)
Both the relative and absolute tolerances can be changed by passing
arguments to the ``approx`` constructor::
>>> 1.0001 == approx(1)
False
>>> 1.0001 == approx(1, rel=1e-3)
True
>>> 1.0001 == approx(1, abs=1e-3)
True
If you specify ``abs`` but not ``rel``, the comparison will not consider
the relative tolerance at all. In other words, two numbers that are within
the default relative tolerance of ``1e-6`` will still be considered unequal
if they exceed the specified absolute tolerance. If you specify both
``abs`` and ``rel``, the numbers will be considered equal if either
tolerance is met::
>>> 1 + 1e-8 == approx(1)
True
>>> 1 + 1e-8 == approx(1, abs=1e-12)
False
>>> 1 + 1e-8 == approx(1, rel=1e-6, abs=1e-12)
True
**Non-numeric types**
You can also use ``approx`` to compare non-numeric types, or dicts and
sequences containing non-numeric types, in which case it falls back to
strict equality. This can be useful for comparing dicts and sequences that
can contain optional values::
>>> {"required": 1.0000005, "optional": None} == approx({"required": 1, "optional": None})
True
>>> [None, 1.0000005] == approx([None,1])
True
>>> ["foo", 1.0000005] == approx([None,1])
False
If you're thinking about using ``approx``, then you might want to know how
it compares to other good ways of comparing floating-point numbers. All of
these algorithms are based on relative and absolute tolerances and should
agree for the most part, but they do have meaningful differences:
- ``math.isclose(a, b, rel_tol=1e-9, abs_tol=0.0)``: True if the relative
tolerance is met w.r.t. either ``a`` or ``b`` or if the absolute
tolerance is met. Because the relative tolerance is calculated w.r.t.
both ``a`` and ``b``, this test is symmetric (i.e. neither ``a`` nor
``b`` is a "reference value"). You have to specify an absolute tolerance
if you want to compare to ``0.0`` because there is no tolerance by
default. More information: :py:func:`math.isclose`.
- ``numpy.isclose(a, b, rtol=1e-5, atol=1e-8)``: True if the difference
between ``a`` and ``b`` is less that the sum of the relative tolerance
w.r.t. ``b`` and the absolute tolerance. Because the relative tolerance
is only calculated w.r.t. ``b``, this test is asymmetric and you can
think of ``b`` as the reference value. Support for comparing sequences
is provided by :py:func:`numpy.allclose`. More information:
:std:doc:`numpy:reference/generated/numpy.isclose`.
- ``unittest.TestCase.assertAlmostEqual(a, b)``: True if ``a`` and ``b``
are within an absolute tolerance of ``1e-7``. No relative tolerance is
considered , so this function is not appropriate for very large or very
small numbers. Also, it's only available in subclasses of ``unittest.TestCase``
and it's ugly because it doesn't follow PEP8. More information:
:py:meth:`unittest.TestCase.assertAlmostEqual`.
- ``a == pytest.approx(b, rel=1e-6, abs=1e-12)``: True if the relative
tolerance is met w.r.t. ``b`` or if the absolute tolerance is met.
Because the relative tolerance is only calculated w.r.t. ``b``, this test
is asymmetric and you can think of ``b`` as the reference value. In the
special case that you explicitly specify an absolute tolerance but not a
relative tolerance, only the absolute tolerance is considered.
.. note::
``approx`` can handle numpy arrays, but we recommend the
specialised test helpers in :std:doc:`numpy:reference/routines.testing`
if you need support for comparisons, NaNs, or ULP-based tolerances.
To match strings using regex, you can use
`Matches <https://github.com/asottile/re-assert#re_assertmatchespattern-str-args-kwargs>`_
from the
`re_assert package <https://github.com/asottile/re-assert>`_.
.. note::
Unlike built-in equality, this function considers
booleans unequal to numeric zero or one. For example::
>>> 1 == approx(True)
False
.. warning::
.. versionchanged:: 3.2
In order to avoid inconsistent behavior, :py:exc:`TypeError` is
raised for ``>``, ``>=``, ``<`` and ``<=`` comparisons.
The example below illustrates the problem::
assert approx(0.1) > 0.1 + 1e-10 # calls approx(0.1).__gt__(0.1 + 1e-10)
assert 0.1 + 1e-10 > approx(0.1) # calls approx(0.1).__lt__(0.1 + 1e-10)
In the second example one expects ``approx(0.1).__le__(0.1 + 1e-10)``
to be called. But instead, ``approx(0.1).__lt__(0.1 + 1e-10)`` is used to
comparison. This is because the call hierarchy of rich comparisons
follows a fixed behavior. More information: :py:meth:`object.__ge__`
.. versionchanged:: 3.7.1
``approx`` raises ``TypeError`` when it encounters a dict value or
sequence element of non-numeric type.
.. versionchanged:: 6.1.0
``approx`` falls back to strict equality for non-numeric types instead
of raising ``TypeError``.
"""
# Delegate the comparison to a class that knows how to deal with the type
# of the expected value (e.g. int, float, list, dict, numpy.array, etc).
#
# The primary responsibility of these classes is to implement ``__eq__()``
# and ``__repr__()``. The former is used to actually check if some
# "actual" value is equivalent to the given expected value within the
# allowed tolerance. The latter is used to show the user the expected
# value and tolerance, in the case that a test failed.
#
# The actual logic for making approximate comparisons can be found in
# ApproxScalar, which is used to compare individual numbers. All of the
# other Approx classes eventually delegate to this class. The ApproxBase
# class provides some convenient methods and overloads, but isn't really
# essential.
__tracebackhide__ = True
if isinstance(expected, Decimal):
cls: type[ApproxBase] = ApproxDecimal
elif isinstance(expected, Mapping):
cls = ApproxMapping
elif _is_numpy_array(expected):
expected = _as_numpy_array(expected)
cls = ApproxNumpy
elif _is_sequence_like(expected):
cls = ApproxSequenceLike
elif isinstance(expected, Collection) and not isinstance(expected, (str, bytes)):
msg = f"pytest.approx() only supports ordered sequences, but got: {expected!r}"
raise TypeError(msg)
else:
cls = ApproxScalar
return cls(expected, rel, abs, nan_ok)
def _is_sequence_like(expected: object) -> bool:
return (
hasattr(expected, "__getitem__")
and isinstance(expected, Sized)
and not isinstance(expected, (str, bytes))
)
def _is_numpy_array(obj: object) -> bool:
"""
Return true if the given object is implicitly convertible to ndarray,
and numpy is already imported.
"""
return _as_numpy_array(obj) is not None
def _as_numpy_array(obj: object) -> ndarray | None:
"""
Return an ndarray if the given object is implicitly convertible to ndarray,
and numpy is already imported, otherwise None.
"""
np: Any = sys.modules.get("numpy")
if np is not None:
# avoid infinite recursion on numpy scalars, which have __array__
if np.isscalar(obj):
return None
elif isinstance(obj, np.ndarray):
return obj
elif hasattr(obj, "__array__") or hasattr("obj", "__array_interface__"):
return np.asarray(obj)
return None

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,365 @@
# mypy: allow-untyped-defs
"""Record warnings during test function execution."""
from __future__ import annotations
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from pprint import pformat
import re
from types import TracebackType
from typing import Any
from typing import final
from typing import overload
from typing import TYPE_CHECKING
from typing import TypeVar
if TYPE_CHECKING:
from typing_extensions import Self
import warnings
from _pytest.deprecated import check_ispytest
from _pytest.fixtures import fixture
from _pytest.outcomes import Exit
from _pytest.outcomes import fail
T = TypeVar("T")
@fixture
def recwarn() -> Generator[WarningsRecorder]:
"""Return a :class:`WarningsRecorder` instance that records all warnings emitted by test functions.
See :ref:`warnings` for information on warning categories.
"""
wrec = WarningsRecorder(_ispytest=True)
with wrec:
warnings.simplefilter("default")
yield wrec
@overload
def deprecated_call(
*, match: str | re.Pattern[str] | None = ...
) -> WarningsRecorder: ...
@overload
def deprecated_call(func: Callable[..., T], *args: Any, **kwargs: Any) -> T: ...
def deprecated_call(
func: Callable[..., Any] | None = None, *args: Any, **kwargs: Any
) -> WarningsRecorder | Any:
"""Assert that code produces a ``DeprecationWarning`` or ``PendingDeprecationWarning`` or ``FutureWarning``.
This function can be used as a context manager::
>>> import warnings
>>> def api_call_v2():
... warnings.warn('use v3 of this api', DeprecationWarning)
... return 200
>>> import pytest
>>> with pytest.deprecated_call():
... assert api_call_v2() == 200
It can also be used by passing a function and ``*args`` and ``**kwargs``,
in which case it will ensure calling ``func(*args, **kwargs)`` produces one of
the warnings types above. The return value is the return value of the function.
In the context manager form you may use the keyword argument ``match`` to assert
that the warning matches a text or regex.
The context manager produces a list of :class:`warnings.WarningMessage` objects,
one for each warning raised.
"""
__tracebackhide__ = True
if func is not None:
args = (func, *args)
return warns(
(DeprecationWarning, PendingDeprecationWarning, FutureWarning), *args, **kwargs
)
@overload
def warns(
expected_warning: type[Warning] | tuple[type[Warning], ...] = ...,
*,
match: str | re.Pattern[str] | None = ...,
) -> WarningsChecker: ...
@overload
def warns(
expected_warning: type[Warning] | tuple[type[Warning], ...],
func: Callable[..., T],
*args: Any,
**kwargs: Any,
) -> T: ...
def warns(
expected_warning: type[Warning] | tuple[type[Warning], ...] = Warning,
*args: Any,
match: str | re.Pattern[str] | None = None,
**kwargs: Any,
) -> WarningsChecker | Any:
r"""Assert that code raises a particular class of warning.
Specifically, the parameter ``expected_warning`` can be a warning class or tuple
of warning classes, and the code inside the ``with`` block must issue at least one
warning of that class or classes.
This helper produces a list of :class:`warnings.WarningMessage` objects, one for
each warning emitted (regardless of whether it is an ``expected_warning`` or not).
Since pytest 8.0, unmatched warnings are also re-emitted when the context closes.
This function can be used as a context manager::
>>> import pytest
>>> with pytest.warns(RuntimeWarning):
... warnings.warn("my warning", RuntimeWarning)
In the context manager form you may use the keyword argument ``match`` to assert
that the warning matches a text or regex::
>>> with pytest.warns(UserWarning, match='must be 0 or None'):
... warnings.warn("value must be 0 or None", UserWarning)
>>> with pytest.warns(UserWarning, match=r'must be \d+$'):
... warnings.warn("value must be 42", UserWarning)
>>> with pytest.warns(UserWarning): # catch re-emitted warning
... with pytest.warns(UserWarning, match=r'must be \d+$'):
... warnings.warn("this is not here", UserWarning)
Traceback (most recent call last):
...
Failed: DID NOT WARN. No warnings of type ...UserWarning... were emitted...
**Using with** ``pytest.mark.parametrize``
When using :ref:`pytest.mark.parametrize ref` it is possible to parametrize tests
such that some runs raise a warning and others do not.
This could be achieved in the same way as with exceptions, see
:ref:`parametrizing_conditional_raising` for an example.
"""
__tracebackhide__ = True
if not args:
if kwargs:
argnames = ", ".join(sorted(kwargs))
raise TypeError(
f"Unexpected keyword arguments passed to pytest.warns: {argnames}"
"\nUse context-manager form instead?"
)
return WarningsChecker(expected_warning, match_expr=match, _ispytest=True)
else:
func = args[0]
if not callable(func):
raise TypeError(f"{func!r} object (type: {type(func)}) must be callable")
with WarningsChecker(expected_warning, _ispytest=True):
return func(*args[1:], **kwargs)
class WarningsRecorder(warnings.catch_warnings): # type:ignore[type-arg]
"""A context manager to record raised warnings.
Each recorded warning is an instance of :class:`warnings.WarningMessage`.
Adapted from `warnings.catch_warnings`.
.. note::
``DeprecationWarning`` and ``PendingDeprecationWarning`` are treated
differently; see :ref:`ensuring_function_triggers`.
"""
def __init__(self, *, _ispytest: bool = False) -> None:
check_ispytest(_ispytest)
super().__init__(record=True)
self._entered = False
self._list: list[warnings.WarningMessage] = []
@property
def list(self) -> list[warnings.WarningMessage]:
"""The list of recorded warnings."""
return self._list
def __getitem__(self, i: int) -> warnings.WarningMessage:
"""Get a recorded warning by index."""
return self._list[i]
def __iter__(self) -> Iterator[warnings.WarningMessage]:
"""Iterate through the recorded warnings."""
return iter(self._list)
def __len__(self) -> int:
"""The number of recorded warnings."""
return len(self._list)
def pop(self, cls: type[Warning] = Warning) -> warnings.WarningMessage:
"""Pop the first recorded warning which is an instance of ``cls``,
but not an instance of a child class of any other match.
Raises ``AssertionError`` if there is no match.
"""
best_idx: int | None = None
for i, w in enumerate(self._list):
if w.category == cls:
return self._list.pop(i) # exact match, stop looking
if issubclass(w.category, cls) and (
best_idx is None
or not issubclass(w.category, self._list[best_idx].category)
):
best_idx = i
if best_idx is not None:
return self._list.pop(best_idx)
__tracebackhide__ = True
raise AssertionError(f"{cls!r} not found in warning list")
def clear(self) -> None:
"""Clear the list of recorded warnings."""
self._list[:] = []
def __enter__(self) -> Self:
if self._entered:
__tracebackhide__ = True
raise RuntimeError(f"Cannot enter {self!r} twice")
_list = super().__enter__()
# record=True means it's None.
assert _list is not None
self._list = _list
warnings.simplefilter("always")
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if not self._entered:
__tracebackhide__ = True
raise RuntimeError(f"Cannot exit {self!r} without entering first")
super().__exit__(exc_type, exc_val, exc_tb)
# Built-in catch_warnings does not reset entered state so we do it
# manually here for this context manager to become reusable.
self._entered = False
@final
class WarningsChecker(WarningsRecorder):
def __init__(
self,
expected_warning: type[Warning] | tuple[type[Warning], ...] = Warning,
match_expr: str | re.Pattern[str] | None = None,
*,
_ispytest: bool = False,
) -> None:
check_ispytest(_ispytest)
super().__init__(_ispytest=True)
msg = "exceptions must be derived from Warning, not %s"
if isinstance(expected_warning, tuple):
for exc in expected_warning:
if not issubclass(exc, Warning):
raise TypeError(msg % type(exc))
expected_warning_tup = expected_warning
elif isinstance(expected_warning, type) and issubclass(
expected_warning, Warning
):
expected_warning_tup = (expected_warning,)
else:
raise TypeError(msg % type(expected_warning))
self.expected_warning = expected_warning_tup
self.match_expr = match_expr
def matches(self, warning: warnings.WarningMessage) -> bool:
assert self.expected_warning is not None
return issubclass(warning.category, self.expected_warning) and bool(
self.match_expr is None or re.search(self.match_expr, str(warning.message))
)
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
super().__exit__(exc_type, exc_val, exc_tb)
__tracebackhide__ = True
# BaseExceptions like pytest.{skip,fail,xfail,exit} or Ctrl-C within
# pytest.warns should *not* trigger "DID NOT WARN" and get suppressed
# when the warning doesn't happen. Control-flow exceptions should always
# propagate.
if exc_val is not None and (
not isinstance(exc_val, Exception)
# Exit is an Exception, not a BaseException, for some reason.
or isinstance(exc_val, Exit)
):
return
def found_str() -> str:
return pformat([record.message for record in self], indent=2)
try:
if not any(issubclass(w.category, self.expected_warning) for w in self):
fail(
f"DID NOT WARN. No warnings of type {self.expected_warning} were emitted.\n"
f" Emitted warnings: {found_str()}."
)
elif not any(self.matches(w) for w in self):
fail(
f"DID NOT WARN. No warnings of type {self.expected_warning} matching the regex were emitted.\n"
f" Regex: {self.match_expr}\n"
f" Emitted warnings: {found_str()}."
)
finally:
# Whether or not any warnings matched, we want to re-emit all unmatched warnings.
for w in self:
if not self.matches(w):
warnings.warn_explicit(
message=w.message,
category=w.category,
filename=w.filename,
lineno=w.lineno,
module=w.__module__,
source=w.source,
)
# Currently in Python it is possible to pass other types than an
# `str` message when creating `Warning` instances, however this
# causes an exception when :func:`warnings.filterwarnings` is used
# to filter those warnings. See
# https://github.com/python/cpython/issues/103577 for a discussion.
# While this can be considered a bug in CPython, we put guards in
# pytest as the error message produced without this check in place
# is confusing (#10865).
for w in self:
if type(w.message) is not UserWarning:
# If the warning was of an incorrect type then `warnings.warn()`
# creates a UserWarning. Any other warning must have been specified
# explicitly.
continue
if not w.message.args:
# UserWarning() without arguments must have been specified explicitly.
continue
msg = w.message.args[0]
if isinstance(msg, str):
continue
# It's possible that UserWarning was explicitly specified, and
# its first argument was not a string. But that case can't be
# distinguished from an invalid type.
raise TypeError(
f"Warning must be str or Warning, got {msg!r} (type {type(msg).__name__})"
)

View File

@@ -0,0 +1,637 @@
# mypy: allow-untyped-defs
from __future__ import annotations
from collections.abc import Iterable
from collections.abc import Iterator
from collections.abc import Mapping
from collections.abc import Sequence
import dataclasses
from io import StringIO
import os
from pprint import pprint
from typing import Any
from typing import cast
from typing import final
from typing import Literal
from typing import NoReturn
from typing import TYPE_CHECKING
from _pytest._code.code import ExceptionChainRepr
from _pytest._code.code import ExceptionInfo
from _pytest._code.code import ExceptionRepr
from _pytest._code.code import ReprEntry
from _pytest._code.code import ReprEntryNative
from _pytest._code.code import ReprExceptionInfo
from _pytest._code.code import ReprFileLocation
from _pytest._code.code import ReprFuncArgs
from _pytest._code.code import ReprLocals
from _pytest._code.code import ReprTraceback
from _pytest._code.code import TerminalRepr
from _pytest._io import TerminalWriter
from _pytest.config import Config
from _pytest.nodes import Collector
from _pytest.nodes import Item
from _pytest.outcomes import fail
from _pytest.outcomes import skip
if TYPE_CHECKING:
from typing_extensions import Self
from _pytest.runner import CallInfo
def getworkerinfoline(node):
try:
return node._workerinfocache
except AttributeError:
d = node.workerinfo
ver = "{}.{}.{}".format(*d["version_info"][:3])
node._workerinfocache = s = "[{}] {} -- Python {} {}".format(
d["id"], d["sysplatform"], ver, d["executable"]
)
return s
class BaseReport:
when: str | None
location: tuple[str, int | None, str] | None
longrepr: (
None | ExceptionInfo[BaseException] | tuple[str, int, str] | str | TerminalRepr
)
sections: list[tuple[str, str]]
nodeid: str
outcome: Literal["passed", "failed", "skipped"]
def __init__(self, **kw: Any) -> None:
self.__dict__.update(kw)
if TYPE_CHECKING:
# Can have arbitrary fields given to __init__().
def __getattr__(self, key: str) -> Any: ...
def toterminal(self, out: TerminalWriter) -> None:
if hasattr(self, "node"):
worker_info = getworkerinfoline(self.node)
if worker_info:
out.line(worker_info)
longrepr = self.longrepr
if longrepr is None:
return
if hasattr(longrepr, "toterminal"):
longrepr_terminal = cast(TerminalRepr, longrepr)
longrepr_terminal.toterminal(out)
else:
try:
s = str(longrepr)
except UnicodeEncodeError:
s = "<unprintable longrepr>"
out.line(s)
def get_sections(self, prefix: str) -> Iterator[tuple[str, str]]:
for name, content in self.sections:
if name.startswith(prefix):
yield prefix, content
@property
def longreprtext(self) -> str:
"""Read-only property that returns the full string representation of
``longrepr``.
.. versionadded:: 3.0
"""
file = StringIO()
tw = TerminalWriter(file)
tw.hasmarkup = False
self.toterminal(tw)
exc = file.getvalue()
return exc.strip()
@property
def caplog(self) -> str:
"""Return captured log lines, if log capturing is enabled.
.. versionadded:: 3.5
"""
return "\n".join(
content for (prefix, content) in self.get_sections("Captured log")
)
@property
def capstdout(self) -> str:
"""Return captured text from stdout, if capturing is enabled.
.. versionadded:: 3.0
"""
return "".join(
content for (prefix, content) in self.get_sections("Captured stdout")
)
@property
def capstderr(self) -> str:
"""Return captured text from stderr, if capturing is enabled.
.. versionadded:: 3.0
"""
return "".join(
content for (prefix, content) in self.get_sections("Captured stderr")
)
@property
def passed(self) -> bool:
"""Whether the outcome is passed."""
return self.outcome == "passed"
@property
def failed(self) -> bool:
"""Whether the outcome is failed."""
return self.outcome == "failed"
@property
def skipped(self) -> bool:
"""Whether the outcome is skipped."""
return self.outcome == "skipped"
@property
def fspath(self) -> str:
"""The path portion of the reported node, as a string."""
return self.nodeid.split("::")[0]
@property
def count_towards_summary(self) -> bool:
"""**Experimental** Whether this report should be counted towards the
totals shown at the end of the test session: "1 passed, 1 failure, etc".
.. note::
This function is considered **experimental**, so beware that it is subject to changes
even in patch releases.
"""
return True
@property
def head_line(self) -> str | None:
"""**Experimental** The head line shown with longrepr output for this
report, more commonly during traceback representation during
failures::
________ Test.foo ________
In the example above, the head_line is "Test.foo".
.. note::
This function is considered **experimental**, so beware that it is subject to changes
even in patch releases.
"""
if self.location is not None:
fspath, lineno, domain = self.location
return domain
return None
def _get_verbose_word_with_markup(
self, config: Config, default_markup: Mapping[str, bool]
) -> tuple[str, Mapping[str, bool]]:
_category, _short, verbose = config.hook.pytest_report_teststatus(
report=self, config=config
)
if isinstance(verbose, str):
return verbose, default_markup
if isinstance(verbose, Sequence) and len(verbose) == 2:
word, markup = verbose
if isinstance(word, str) and isinstance(markup, Mapping):
return word, markup
fail( # pragma: no cover
"pytest_report_teststatus() hook (from a plugin) returned "
f"an invalid verbose value: {verbose!r}.\nExpected either a string "
"or a tuple of (word, markup)."
)
def _to_json(self) -> dict[str, Any]:
"""Return the contents of this report as a dict of builtin entries,
suitable for serialization.
This was originally the serialize_report() function from xdist (ca03269).
Experimental method.
"""
return _report_to_json(self)
@classmethod
def _from_json(cls, reportdict: dict[str, object]) -> Self:
"""Create either a TestReport or CollectReport, depending on the calling class.
It is the callers responsibility to know which class to pass here.
This was originally the serialize_report() function from xdist (ca03269).
Experimental method.
"""
kwargs = _report_kwargs_from_json(reportdict)
return cls(**kwargs)
def _report_unserialization_failure(
type_name: str, report_class: type[BaseReport], reportdict
) -> NoReturn:
url = "https://github.com/pytest-dev/pytest/issues"
stream = StringIO()
pprint("-" * 100, stream=stream)
pprint(f"INTERNALERROR: Unknown entry type returned: {type_name}", stream=stream)
pprint(f"report_name: {report_class}", stream=stream)
pprint(reportdict, stream=stream)
pprint(f"Please report this bug at {url}", stream=stream)
pprint("-" * 100, stream=stream)
raise RuntimeError(stream.getvalue())
@final
class TestReport(BaseReport):
"""Basic test report object (also used for setup and teardown calls if
they fail).
Reports can contain arbitrary extra attributes.
"""
__test__ = False
# Defined by skipping plugin.
# xfail reason if xfailed, otherwise not defined. Use hasattr to distinguish.
wasxfail: str
def __init__(
self,
nodeid: str,
location: tuple[str, int | None, str],
keywords: Mapping[str, Any],
outcome: Literal["passed", "failed", "skipped"],
longrepr: None
| ExceptionInfo[BaseException]
| tuple[str, int, str]
| str
| TerminalRepr,
when: Literal["setup", "call", "teardown"],
sections: Iterable[tuple[str, str]] = (),
duration: float = 0,
start: float = 0,
stop: float = 0,
user_properties: Iterable[tuple[str, object]] | None = None,
**extra,
) -> None:
#: Normalized collection nodeid.
self.nodeid = nodeid
#: A (filesystempath, lineno, domaininfo) tuple indicating the
#: actual location of a test item - it might be different from the
#: collected one e.g. if a method is inherited from a different module.
#: The filesystempath may be relative to ``config.rootdir``.
#: The line number is 0-based.
self.location: tuple[str, int | None, str] = location
#: A name -> value dictionary containing all keywords and
#: markers associated with a test invocation.
self.keywords: Mapping[str, Any] = keywords
#: Test outcome, always one of "passed", "failed", "skipped".
self.outcome = outcome
#: None or a failure representation.
self.longrepr = longrepr
#: One of 'setup', 'call', 'teardown' to indicate runtest phase.
self.when: Literal["setup", "call", "teardown"] = when
#: User properties is a list of tuples (name, value) that holds user
#: defined properties of the test.
self.user_properties = list(user_properties or [])
#: Tuples of str ``(heading, content)`` with extra information
#: for the test report. Used by pytest to add text captured
#: from ``stdout``, ``stderr``, and intercepted logging events. May
#: be used by other plugins to add arbitrary information to reports.
self.sections = list(sections)
#: Time it took to run just the test.
self.duration: float = duration
#: The system time when the call started, in seconds since the epoch.
self.start: float = start
#: The system time when the call ended, in seconds since the epoch.
self.stop: float = stop
self.__dict__.update(extra)
def __repr__(self) -> str:
return f"<{self.__class__.__name__} {self.nodeid!r} when={self.when!r} outcome={self.outcome!r}>"
@classmethod
def from_item_and_call(cls, item: Item, call: CallInfo[None]) -> TestReport:
"""Create and fill a TestReport with standard item and call info.
:param item: The item.
:param call: The call info.
"""
when = call.when
# Remove "collect" from the Literal type -- only for collection calls.
assert when != "collect"
duration = call.duration
start = call.start
stop = call.stop
keywords = {x: 1 for x in item.keywords}
excinfo = call.excinfo
sections = []
if not call.excinfo:
outcome: Literal["passed", "failed", "skipped"] = "passed"
longrepr: (
None
| ExceptionInfo[BaseException]
| tuple[str, int, str]
| str
| TerminalRepr
) = None
else:
if not isinstance(excinfo, ExceptionInfo):
outcome = "failed"
longrepr = excinfo
elif isinstance(excinfo.value, skip.Exception):
outcome = "skipped"
r = excinfo._getreprcrash()
assert r is not None, (
"There should always be a traceback entry for skipping a test."
)
if excinfo.value._use_item_location:
path, line = item.reportinfo()[:2]
assert line is not None
longrepr = os.fspath(path), line + 1, r.message
else:
longrepr = (str(r.path), r.lineno, r.message)
else:
outcome = "failed"
if call.when == "call":
longrepr = item.repr_failure(excinfo)
else: # exception in setup or teardown
longrepr = item._repr_failure_py(
excinfo, style=item.config.getoption("tbstyle", "auto")
)
for rwhen, key, content in item._report_sections:
sections.append((f"Captured {key} {rwhen}", content))
return cls(
item.nodeid,
item.location,
keywords,
outcome,
longrepr,
when,
sections,
duration,
start,
stop,
user_properties=item.user_properties,
)
@final
class CollectReport(BaseReport):
"""Collection report object.
Reports can contain arbitrary extra attributes.
"""
when = "collect"
def __init__(
self,
nodeid: str,
outcome: Literal["passed", "failed", "skipped"],
longrepr: None
| ExceptionInfo[BaseException]
| tuple[str, int, str]
| str
| TerminalRepr,
result: list[Item | Collector] | None,
sections: Iterable[tuple[str, str]] = (),
**extra,
) -> None:
#: Normalized collection nodeid.
self.nodeid = nodeid
#: Test outcome, always one of "passed", "failed", "skipped".
self.outcome = outcome
#: None or a failure representation.
self.longrepr = longrepr
#: The collected items and collection nodes.
self.result = result or []
#: Tuples of str ``(heading, content)`` with extra information
#: for the test report. Used by pytest to add text captured
#: from ``stdout``, ``stderr``, and intercepted logging events. May
#: be used by other plugins to add arbitrary information to reports.
self.sections = list(sections)
self.__dict__.update(extra)
@property
def location( # type:ignore[override]
self,
) -> tuple[str, int | None, str] | None:
return (self.fspath, None, self.fspath)
def __repr__(self) -> str:
return f"<CollectReport {self.nodeid!r} lenresult={len(self.result)} outcome={self.outcome!r}>"
class CollectErrorRepr(TerminalRepr):
def __init__(self, msg: str) -> None:
self.longrepr = msg
def toterminal(self, out: TerminalWriter) -> None:
out.line(self.longrepr, red=True)
def pytest_report_to_serializable(
report: CollectReport | TestReport,
) -> dict[str, Any] | None:
if isinstance(report, (TestReport, CollectReport)):
data = report._to_json()
data["$report_type"] = report.__class__.__name__
return data
# TODO: Check if this is actually reachable.
return None # type: ignore[unreachable]
def pytest_report_from_serializable(
data: dict[str, Any],
) -> CollectReport | TestReport | None:
if "$report_type" in data:
if data["$report_type"] == "TestReport":
return TestReport._from_json(data)
elif data["$report_type"] == "CollectReport":
return CollectReport._from_json(data)
assert False, "Unknown report_type unserialize data: {}".format(
data["$report_type"]
)
return None
def _report_to_json(report: BaseReport) -> dict[str, Any]:
"""Return the contents of this report as a dict of builtin entries,
suitable for serialization.
This was originally the serialize_report() function from xdist (ca03269).
"""
def serialize_repr_entry(
entry: ReprEntry | ReprEntryNative,
) -> dict[str, Any]:
data = dataclasses.asdict(entry)
for key, value in data.items():
if hasattr(value, "__dict__"):
data[key] = dataclasses.asdict(value)
entry_data = {"type": type(entry).__name__, "data": data}
return entry_data
def serialize_repr_traceback(reprtraceback: ReprTraceback) -> dict[str, Any]:
result = dataclasses.asdict(reprtraceback)
result["reprentries"] = [
serialize_repr_entry(x) for x in reprtraceback.reprentries
]
return result
def serialize_repr_crash(
reprcrash: ReprFileLocation | None,
) -> dict[str, Any] | None:
if reprcrash is not None:
return dataclasses.asdict(reprcrash)
else:
return None
def serialize_exception_longrepr(rep: BaseReport) -> dict[str, Any]:
assert rep.longrepr is not None
# TODO: Investigate whether the duck typing is really necessary here.
longrepr = cast(ExceptionRepr, rep.longrepr)
result: dict[str, Any] = {
"reprcrash": serialize_repr_crash(longrepr.reprcrash),
"reprtraceback": serialize_repr_traceback(longrepr.reprtraceback),
"sections": longrepr.sections,
}
if isinstance(longrepr, ExceptionChainRepr):
result["chain"] = []
for repr_traceback, repr_crash, description in longrepr.chain:
result["chain"].append(
(
serialize_repr_traceback(repr_traceback),
serialize_repr_crash(repr_crash),
description,
)
)
else:
result["chain"] = None
return result
d = report.__dict__.copy()
if hasattr(report.longrepr, "toterminal"):
if hasattr(report.longrepr, "reprtraceback") and hasattr(
report.longrepr, "reprcrash"
):
d["longrepr"] = serialize_exception_longrepr(report)
else:
d["longrepr"] = str(report.longrepr)
else:
d["longrepr"] = report.longrepr
for name in d:
if isinstance(d[name], os.PathLike):
d[name] = os.fspath(d[name])
elif name == "result":
d[name] = None # for now
return d
def _report_kwargs_from_json(reportdict: dict[str, Any]) -> dict[str, Any]:
"""Return **kwargs that can be used to construct a TestReport or
CollectReport instance.
This was originally the serialize_report() function from xdist (ca03269).
"""
def deserialize_repr_entry(entry_data):
data = entry_data["data"]
entry_type = entry_data["type"]
if entry_type == "ReprEntry":
reprfuncargs = None
reprfileloc = None
reprlocals = None
if data["reprfuncargs"]:
reprfuncargs = ReprFuncArgs(**data["reprfuncargs"])
if data["reprfileloc"]:
reprfileloc = ReprFileLocation(**data["reprfileloc"])
if data["reprlocals"]:
reprlocals = ReprLocals(data["reprlocals"]["lines"])
reprentry: ReprEntry | ReprEntryNative = ReprEntry(
lines=data["lines"],
reprfuncargs=reprfuncargs,
reprlocals=reprlocals,
reprfileloc=reprfileloc,
style=data["style"],
)
elif entry_type == "ReprEntryNative":
reprentry = ReprEntryNative(data["lines"])
else:
_report_unserialization_failure(entry_type, TestReport, reportdict)
return reprentry
def deserialize_repr_traceback(repr_traceback_dict):
repr_traceback_dict["reprentries"] = [
deserialize_repr_entry(x) for x in repr_traceback_dict["reprentries"]
]
return ReprTraceback(**repr_traceback_dict)
def deserialize_repr_crash(repr_crash_dict: dict[str, Any] | None):
if repr_crash_dict is not None:
return ReprFileLocation(**repr_crash_dict)
else:
return None
if (
reportdict["longrepr"]
and "reprcrash" in reportdict["longrepr"]
and "reprtraceback" in reportdict["longrepr"]
):
reprtraceback = deserialize_repr_traceback(
reportdict["longrepr"]["reprtraceback"]
)
reprcrash = deserialize_repr_crash(reportdict["longrepr"]["reprcrash"])
if reportdict["longrepr"]["chain"]:
chain = []
for repr_traceback_data, repr_crash_data, description in reportdict[
"longrepr"
]["chain"]:
chain.append(
(
deserialize_repr_traceback(repr_traceback_data),
deserialize_repr_crash(repr_crash_data),
description,
)
)
exception_info: ExceptionChainRepr | ReprExceptionInfo = ExceptionChainRepr(
chain
)
else:
exception_info = ReprExceptionInfo(
reprtraceback=reprtraceback,
reprcrash=reprcrash,
)
for section in reportdict["longrepr"]["sections"]:
exception_info.addsection(*section)
reportdict["longrepr"] = exception_info
return reportdict

View File

@@ -0,0 +1,571 @@
# mypy: allow-untyped-defs
"""Basic collect and runtest protocol implementations."""
from __future__ import annotations
import bdb
from collections.abc import Callable
import dataclasses
import os
import sys
import types
from typing import cast
from typing import final
from typing import Generic
from typing import Literal
from typing import TYPE_CHECKING
from typing import TypeVar
from .reports import BaseReport
from .reports import CollectErrorRepr
from .reports import CollectReport
from .reports import TestReport
from _pytest import timing
from _pytest._code.code import ExceptionChainRepr
from _pytest._code.code import ExceptionInfo
from _pytest._code.code import TerminalRepr
from _pytest.config.argparsing import Parser
from _pytest.deprecated import check_ispytest
from _pytest.nodes import Collector
from _pytest.nodes import Directory
from _pytest.nodes import Item
from _pytest.nodes import Node
from _pytest.outcomes import Exit
from _pytest.outcomes import OutcomeException
from _pytest.outcomes import Skipped
from _pytest.outcomes import TEST_OUTCOME
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
if TYPE_CHECKING:
from _pytest.main import Session
from _pytest.terminal import TerminalReporter
#
# pytest plugin hooks.
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("terminal reporting", "Reporting", after="general")
group.addoption(
"--durations",
action="store",
type=int,
default=None,
metavar="N",
help="Show N slowest setup/test durations (N=0 for all)",
)
group.addoption(
"--durations-min",
action="store",
type=float,
default=None,
metavar="N",
help="Minimal duration in seconds for inclusion in slowest list. "
"Default: 0.005 (or 0.0 if -vv is given).",
)
def pytest_terminal_summary(terminalreporter: TerminalReporter) -> None:
durations = terminalreporter.config.option.durations
durations_min = terminalreporter.config.option.durations_min
verbose = terminalreporter.config.get_verbosity()
if durations is None:
return
if durations_min is None:
durations_min = 0.005 if verbose < 2 else 0.0
tr = terminalreporter
dlist = []
for replist in tr.stats.values():
for rep in replist:
if hasattr(rep, "duration"):
dlist.append(rep)
if not dlist:
return
dlist.sort(key=lambda x: x.duration, reverse=True)
if not durations:
tr.write_sep("=", "slowest durations")
else:
tr.write_sep("=", f"slowest {durations} durations")
dlist = dlist[:durations]
for i, rep in enumerate(dlist):
if rep.duration < durations_min:
tr.write_line("")
message = f"({len(dlist) - i} durations < {durations_min:g}s hidden."
if terminalreporter.config.option.durations_min is None:
message += " Use -vv to show these durations."
message += ")"
tr.write_line(message)
break
tr.write_line(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}")
def pytest_sessionstart(session: Session) -> None:
session._setupstate = SetupState()
def pytest_sessionfinish(session: Session) -> None:
session._setupstate.teardown_exact(None)
def pytest_runtest_protocol(item: Item, nextitem: Item | None) -> bool:
ihook = item.ihook
ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
runtestprotocol(item, nextitem=nextitem)
ihook.pytest_runtest_logfinish(nodeid=item.nodeid, location=item.location)
return True
def runtestprotocol(
item: Item, log: bool = True, nextitem: Item | None = None
) -> list[TestReport]:
hasrequest = hasattr(item, "_request")
if hasrequest and not item._request: # type: ignore[attr-defined]
# This only happens if the item is re-run, as is done by
# pytest-rerunfailures.
item._initrequest() # type: ignore[attr-defined]
rep = call_and_report(item, "setup", log)
reports = [rep]
if rep.passed:
if item.config.getoption("setupshow", False):
show_test_item(item)
if not item.config.getoption("setuponly", False):
reports.append(call_and_report(item, "call", log))
# If the session is about to fail or stop, teardown everything - this is
# necessary to correctly report fixture teardown errors (see #11706)
if item.session.shouldfail or item.session.shouldstop:
nextitem = None
reports.append(call_and_report(item, "teardown", log, nextitem=nextitem))
# After all teardown hooks have been called
# want funcargs and request info to go away.
if hasrequest:
item._request = False # type: ignore[attr-defined]
item.funcargs = None # type: ignore[attr-defined]
return reports
def show_test_item(item: Item) -> None:
"""Show test function, parameters and the fixtures of the test item."""
tw = item.config.get_terminal_writer()
tw.line()
tw.write(" " * 8)
tw.write(item.nodeid)
used_fixtures = sorted(getattr(item, "fixturenames", []))
if used_fixtures:
tw.write(" (fixtures used: {})".format(", ".join(used_fixtures)))
tw.flush()
def pytest_runtest_setup(item: Item) -> None:
_update_current_test_var(item, "setup")
item.session._setupstate.setup(item)
def pytest_runtest_call(item: Item) -> None:
_update_current_test_var(item, "call")
try:
del sys.last_type
del sys.last_value
del sys.last_traceback
if sys.version_info >= (3, 12, 0):
del sys.last_exc # type:ignore[attr-defined]
except AttributeError:
pass
try:
item.runtest()
except Exception as e:
# Store trace info to allow postmortem debugging
sys.last_type = type(e)
sys.last_value = e
if sys.version_info >= (3, 12, 0):
sys.last_exc = e # type:ignore[attr-defined]
assert e.__traceback__ is not None
# Skip *this* frame
sys.last_traceback = e.__traceback__.tb_next
raise
def pytest_runtest_teardown(item: Item, nextitem: Item | None) -> None:
_update_current_test_var(item, "teardown")
item.session._setupstate.teardown_exact(nextitem)
_update_current_test_var(item, None)
def _update_current_test_var(
item: Item, when: Literal["setup", "call", "teardown"] | None
) -> None:
"""Update :envvar:`PYTEST_CURRENT_TEST` to reflect the current item and stage.
If ``when`` is None, delete ``PYTEST_CURRENT_TEST`` from the environment.
"""
var_name = "PYTEST_CURRENT_TEST"
if when:
value = f"{item.nodeid} ({when})"
# don't allow null bytes on environment variables (see #2644, #2957)
value = value.replace("\x00", "(null)")
os.environ[var_name] = value
else:
os.environ.pop(var_name)
def pytest_report_teststatus(report: BaseReport) -> tuple[str, str, str] | None:
if report.when in ("setup", "teardown"):
if report.failed:
# category, shortletter, verbose-word
return "error", "E", "ERROR"
elif report.skipped:
return "skipped", "s", "SKIPPED"
else:
return "", "", ""
return None
#
# Implementation
def call_and_report(
item: Item, when: Literal["setup", "call", "teardown"], log: bool = True, **kwds
) -> TestReport:
ihook = item.ihook
if when == "setup":
runtest_hook: Callable[..., None] = ihook.pytest_runtest_setup
elif when == "call":
runtest_hook = ihook.pytest_runtest_call
elif when == "teardown":
runtest_hook = ihook.pytest_runtest_teardown
else:
assert False, f"Unhandled runtest hook case: {when}"
reraise: tuple[type[BaseException], ...] = (Exit,)
if not item.config.getoption("usepdb", False):
reraise += (KeyboardInterrupt,)
call = CallInfo.from_call(
lambda: runtest_hook(item=item, **kwds), when=when, reraise=reraise
)
report: TestReport = ihook.pytest_runtest_makereport(item=item, call=call)
if log:
ihook.pytest_runtest_logreport(report=report)
if check_interactive_exception(call, report):
ihook.pytest_exception_interact(node=item, call=call, report=report)
return report
def check_interactive_exception(call: CallInfo[object], report: BaseReport) -> bool:
"""Check whether the call raised an exception that should be reported as
interactive."""
if call.excinfo is None:
# Didn't raise.
return False
if hasattr(report, "wasxfail"):
# Exception was expected.
return False
if isinstance(call.excinfo.value, (Skipped, bdb.BdbQuit)):
# Special control flow exception.
return False
return True
TResult = TypeVar("TResult", covariant=True)
@final
@dataclasses.dataclass
class CallInfo(Generic[TResult]):
"""Result/Exception info of a function invocation."""
_result: TResult | None
#: The captured exception of the call, if it raised.
excinfo: ExceptionInfo[BaseException] | None
#: The system time when the call started, in seconds since the epoch.
start: float
#: The system time when the call ended, in seconds since the epoch.
stop: float
#: The call duration, in seconds.
duration: float
#: The context of invocation: "collect", "setup", "call" or "teardown".
when: Literal["collect", "setup", "call", "teardown"]
def __init__(
self,
result: TResult | None,
excinfo: ExceptionInfo[BaseException] | None,
start: float,
stop: float,
duration: float,
when: Literal["collect", "setup", "call", "teardown"],
*,
_ispytest: bool = False,
) -> None:
check_ispytest(_ispytest)
self._result = result
self.excinfo = excinfo
self.start = start
self.stop = stop
self.duration = duration
self.when = when
@property
def result(self) -> TResult:
"""The return value of the call, if it didn't raise.
Can only be accessed if excinfo is None.
"""
if self.excinfo is not None:
raise AttributeError(f"{self!r} has no valid result")
# The cast is safe because an exception wasn't raised, hence
# _result has the expected function return type (which may be
# None, that's why a cast and not an assert).
return cast(TResult, self._result)
@classmethod
def from_call(
cls,
func: Callable[[], TResult],
when: Literal["collect", "setup", "call", "teardown"],
reraise: type[BaseException] | tuple[type[BaseException], ...] | None = None,
) -> CallInfo[TResult]:
"""Call func, wrapping the result in a CallInfo.
:param func:
The function to call. Called without arguments.
:type func: Callable[[], _pytest.runner.TResult]
:param when:
The phase in which the function is called.
:param reraise:
Exception or exceptions that shall propagate if raised by the
function, instead of being wrapped in the CallInfo.
"""
excinfo = None
instant = timing.Instant()
try:
result: TResult | None = func()
except BaseException:
excinfo = ExceptionInfo.from_current()
if reraise is not None and isinstance(excinfo.value, reraise):
raise
result = None
duration = instant.elapsed()
return cls(
start=duration.start.time,
stop=duration.stop.time,
duration=duration.seconds,
when=when,
result=result,
excinfo=excinfo,
_ispytest=True,
)
def __repr__(self) -> str:
if self.excinfo is None:
return f"<CallInfo when={self.when!r} result: {self._result!r}>"
return f"<CallInfo when={self.when!r} excinfo={self.excinfo!r}>"
def pytest_runtest_makereport(item: Item, call: CallInfo[None]) -> TestReport:
return TestReport.from_item_and_call(item, call)
def pytest_make_collect_report(collector: Collector) -> CollectReport:
def collect() -> list[Item | Collector]:
# Before collecting, if this is a Directory, load the conftests.
# If a conftest import fails to load, it is considered a collection
# error of the Directory collector. This is why it's done inside of the
# CallInfo wrapper.
#
# Note: initial conftests are loaded early, not here.
if isinstance(collector, Directory):
collector.config.pluginmanager._loadconftestmodules(
collector.path,
collector.config.getoption("importmode"),
rootpath=collector.config.rootpath,
consider_namespace_packages=collector.config.getini(
"consider_namespace_packages"
),
)
return list(collector.collect())
call = CallInfo.from_call(
collect, "collect", reraise=(KeyboardInterrupt, SystemExit)
)
longrepr: None | tuple[str, int, str] | str | TerminalRepr = None
if not call.excinfo:
outcome: Literal["passed", "skipped", "failed"] = "passed"
else:
skip_exceptions = [Skipped]
unittest = sys.modules.get("unittest")
if unittest is not None:
skip_exceptions.append(unittest.SkipTest)
if isinstance(call.excinfo.value, tuple(skip_exceptions)):
outcome = "skipped"
r_ = collector._repr_failure_py(call.excinfo, "line")
assert isinstance(r_, ExceptionChainRepr), repr(r_)
r = r_.reprcrash
assert r
longrepr = (str(r.path), r.lineno, r.message)
else:
outcome = "failed"
errorinfo = collector.repr_failure(call.excinfo)
if not hasattr(errorinfo, "toterminal"):
assert isinstance(errorinfo, str)
errorinfo = CollectErrorRepr(errorinfo)
longrepr = errorinfo
result = call.result if not call.excinfo else None
rep = CollectReport(collector.nodeid, outcome, longrepr, result)
rep.call = call # type: ignore # see collect_one_node
return rep
class SetupState:
"""Shared state for setting up/tearing down test items or collectors
in a session.
Suppose we have a collection tree as follows:
<Session session>
<Module mod1>
<Function item1>
<Module mod2>
<Function item2>
The SetupState maintains a stack. The stack starts out empty:
[]
During the setup phase of item1, setup(item1) is called. What it does
is:
push session to stack, run session.setup()
push mod1 to stack, run mod1.setup()
push item1 to stack, run item1.setup()
The stack is:
[session, mod1, item1]
While the stack is in this shape, it is allowed to add finalizers to
each of session, mod1, item1 using addfinalizer().
During the teardown phase of item1, teardown_exact(item2) is called,
where item2 is the next item to item1. What it does is:
pop item1 from stack, run its teardowns
pop mod1 from stack, run its teardowns
mod1 was popped because it ended its purpose with item1. The stack is:
[session]
During the setup phase of item2, setup(item2) is called. What it does
is:
push mod2 to stack, run mod2.setup()
push item2 to stack, run item2.setup()
Stack:
[session, mod2, item2]
During the teardown phase of item2, teardown_exact(None) is called,
because item2 is the last item. What it does is:
pop item2 from stack, run its teardowns
pop mod2 from stack, run its teardowns
pop session from stack, run its teardowns
Stack:
[]
The end!
"""
def __init__(self) -> None:
# The stack is in the dict insertion order.
self.stack: dict[
Node,
tuple[
# Node's finalizers.
list[Callable[[], object]],
# Node's exception and original traceback, if its setup raised.
tuple[OutcomeException | Exception, types.TracebackType | None] | None,
],
] = {}
def setup(self, item: Item) -> None:
"""Setup objects along the collector chain to the item."""
needed_collectors = item.listchain()
# If a collector fails its setup, fail its entire subtree of items.
# The setup is not retried for each item - the same exception is used.
for col, (finalizers, exc) in self.stack.items():
assert col in needed_collectors, "previous item was not torn down properly"
if exc:
raise exc[0].with_traceback(exc[1])
for col in needed_collectors[len(self.stack) :]:
assert col not in self.stack
# Push onto the stack.
self.stack[col] = ([col.teardown], None)
try:
col.setup()
except TEST_OUTCOME as exc:
self.stack[col] = (self.stack[col][0], (exc, exc.__traceback__))
raise
def addfinalizer(self, finalizer: Callable[[], object], node: Node) -> None:
"""Attach a finalizer to the given node.
The node must be currently active in the stack.
"""
assert node and not isinstance(node, tuple)
assert callable(finalizer)
assert node in self.stack, (node, self.stack)
self.stack[node][0].append(finalizer)
def teardown_exact(self, nextitem: Item | None) -> None:
"""Teardown the current stack up until reaching nodes that nextitem
also descends from.
When nextitem is None (meaning we're at the last item), the entire
stack is torn down.
"""
needed_collectors = (nextitem and nextitem.listchain()) or []
exceptions: list[BaseException] = []
while self.stack:
if list(self.stack.keys()) == needed_collectors[: len(self.stack)]:
break
node, (finalizers, _) = self.stack.popitem()
these_exceptions = []
while finalizers:
fin = finalizers.pop()
try:
fin()
except TEST_OUTCOME as e:
these_exceptions.append(e)
if len(these_exceptions) == 1:
exceptions.extend(these_exceptions)
elif these_exceptions:
msg = f"errors while tearing down {node!r}"
exceptions.append(BaseExceptionGroup(msg, these_exceptions[::-1]))
if len(exceptions) == 1:
raise exceptions[0]
elif exceptions:
raise BaseExceptionGroup("errors during test teardown", exceptions[::-1])
if nextitem is None:
assert not self.stack
def collect_one_node(collector: Collector) -> CollectReport:
ihook = collector.ihook
ihook.pytest_collectstart(collector=collector)
rep: CollectReport = ihook.pytest_make_collect_report(collector=collector)
call = rep.__dict__.pop("call", None)
if call and check_interactive_exception(call, rep):
ihook.pytest_exception_interact(node=collector, call=call, report=rep)
return rep

View File

@@ -0,0 +1,91 @@
"""
Scope definition and related utilities.
Those are defined here, instead of in the 'fixtures' module because
their use is spread across many other pytest modules, and centralizing it in 'fixtures'
would cause circular references.
Also this makes the module light to import, as it should.
"""
from __future__ import annotations
from enum import Enum
from functools import total_ordering
from typing import Literal
_ScopeName = Literal["session", "package", "module", "class", "function"]
@total_ordering
class Scope(Enum):
"""
Represents one of the possible fixture scopes in pytest.
Scopes are ordered from lower to higher, that is:
->>> higher ->>>
Function < Class < Module < Package < Session
<<<- lower <<<-
"""
# Scopes need to be listed from lower to higher.
Function = "function"
Class = "class"
Module = "module"
Package = "package"
Session = "session"
def next_lower(self) -> Scope:
"""Return the next lower scope."""
index = _SCOPE_INDICES[self]
if index == 0:
raise ValueError(f"{self} is the lower-most scope")
return _ALL_SCOPES[index - 1]
def next_higher(self) -> Scope:
"""Return the next higher scope."""
index = _SCOPE_INDICES[self]
if index == len(_SCOPE_INDICES) - 1:
raise ValueError(f"{self} is the upper-most scope")
return _ALL_SCOPES[index + 1]
def __lt__(self, other: Scope) -> bool:
self_index = _SCOPE_INDICES[self]
other_index = _SCOPE_INDICES[other]
return self_index < other_index
@classmethod
def from_user(
cls, scope_name: _ScopeName, descr: str, where: str | None = None
) -> Scope:
"""
Given a scope name from the user, return the equivalent Scope enum. Should be used
whenever we want to convert a user provided scope name to its enum object.
If the scope name is invalid, construct a user friendly message and call pytest.fail.
"""
from _pytest.outcomes import fail
try:
# Holding this reference is necessary for mypy at the moment.
scope = Scope(scope_name)
except ValueError:
fail(
"{} {}got an unexpected scope value '{}'".format(
descr, f"from {where} " if where else "", scope_name
),
pytrace=False,
)
return scope
_ALL_SCOPES = list(Scope)
_SCOPE_INDICES = {scope: index for index, scope in enumerate(_ALL_SCOPES)}
# Ordered list of scopes which can contain many tests (in practice all except Function).
HIGH_SCOPES = [x for x in Scope if x is not Scope.Function]

View File

@@ -0,0 +1,98 @@
from __future__ import annotations
from collections.abc import Generator
from _pytest._io.saferepr import saferepr
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureDef
from _pytest.fixtures import SubRequest
from _pytest.scope import Scope
import pytest
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig")
group.addoption(
"--setuponly",
"--setup-only",
action="store_true",
help="Only setup fixtures, do not execute tests",
)
group.addoption(
"--setupshow",
"--setup-show",
action="store_true",
help="Show setup of fixtures while executing tests",
)
@pytest.hookimpl(wrapper=True)
def pytest_fixture_setup(
fixturedef: FixtureDef[object], request: SubRequest
) -> Generator[None, object, object]:
try:
return (yield)
finally:
if request.config.option.setupshow:
if hasattr(request, "param"):
# Save the fixture parameter so ._show_fixture_action() can
# display it now and during the teardown (in .finish()).
if fixturedef.ids:
if callable(fixturedef.ids):
param = fixturedef.ids(request.param)
else:
param = fixturedef.ids[request.param_index]
else:
param = request.param
fixturedef.cached_param = param # type: ignore[attr-defined]
_show_fixture_action(fixturedef, request.config, "SETUP")
def pytest_fixture_post_finalizer(
fixturedef: FixtureDef[object], request: SubRequest
) -> None:
if fixturedef.cached_result is not None:
config = request.config
if config.option.setupshow:
_show_fixture_action(fixturedef, request.config, "TEARDOWN")
if hasattr(fixturedef, "cached_param"):
del fixturedef.cached_param
def _show_fixture_action(
fixturedef: FixtureDef[object], config: Config, msg: str
) -> None:
capman = config.pluginmanager.getplugin("capturemanager")
if capman:
capman.suspend_global_capture()
tw = config.get_terminal_writer()
tw.line()
# Use smaller indentation the higher the scope: Session = 0, Package = 1, etc.
scope_indent = list(reversed(Scope)).index(fixturedef._scope)
tw.write(" " * 2 * scope_indent)
scopename = fixturedef.scope[0].upper()
tw.write(f"{msg:<8} {scopename} {fixturedef.argname}")
if msg == "SETUP":
deps = sorted(arg for arg in fixturedef.argnames if arg != "request")
if deps:
tw.write(" (fixtures used: {})".format(", ".join(deps)))
if hasattr(fixturedef, "cached_param"):
tw.write(f"[{saferepr(fixturedef.cached_param, maxsize=42)}]")
tw.flush()
if capman:
capman.resume_global_capture()
@pytest.hookimpl(tryfirst=True)
def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
if config.option.setuponly:
config.option.setupshow = True
return None

View File

@@ -0,0 +1,39 @@
from __future__ import annotations
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureDef
from _pytest.fixtures import SubRequest
import pytest
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig")
group.addoption(
"--setupplan",
"--setup-plan",
action="store_true",
help="Show what fixtures and tests would be executed but "
"don't execute anything",
)
@pytest.hookimpl(tryfirst=True)
def pytest_fixture_setup(
fixturedef: FixtureDef[object], request: SubRequest
) -> object | None:
# Will return a dummy fixture if the setuponly option is provided.
if request.config.option.setupplan:
my_cache_key = fixturedef.cache_key(request)
fixturedef.cached_result = (None, my_cache_key, None)
return fixturedef.cached_result
return None
@pytest.hookimpl(tryfirst=True)
def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
if config.option.setupplan:
config.option.setuponly = True
config.option.setupshow = True
return None

View File

@@ -0,0 +1,316 @@
# mypy: allow-untyped-defs
"""Support for skip/xfail functions and markers."""
from __future__ import annotations
from collections.abc import Generator
from collections.abc import Mapping
import dataclasses
import os
import platform
import sys
import traceback
from typing import Optional
from _pytest.config import Config
from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser
from _pytest.mark.structures import Mark
from _pytest.nodes import Item
from _pytest.outcomes import fail
from _pytest.outcomes import skip
from _pytest.outcomes import xfail
from _pytest.raises import AbstractRaises
from _pytest.reports import BaseReport
from _pytest.reports import TestReport
from _pytest.runner import CallInfo
from _pytest.stash import StashKey
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group.addoption(
"--runxfail",
action="store_true",
dest="runxfail",
default=False,
help="Report the results of xfail tests as if they were not marked",
)
parser.addini(
"xfail_strict",
"Default for the strict parameter of xfail "
"markers when not given explicitly (default: False)",
default=False,
type="bool",
)
def pytest_configure(config: Config) -> None:
if config.option.runxfail:
# yay a hack
import pytest
old = pytest.xfail
config.add_cleanup(lambda: setattr(pytest, "xfail", old))
def nop(*args, **kwargs):
pass
nop.Exception = xfail.Exception # type: ignore[attr-defined]
setattr(pytest, "xfail", nop)
config.addinivalue_line(
"markers",
"skip(reason=None): skip the given test function with an optional reason. "
'Example: skip(reason="no way of currently testing this") skips the '
"test.",
)
config.addinivalue_line(
"markers",
"skipif(condition, ..., *, reason=...): "
"skip the given test function if any of the conditions evaluate to True. "
"Example: skipif(sys.platform == 'win32') skips the test if we are on the win32 platform. "
"See https://docs.pytest.org/en/stable/reference/reference.html#pytest-mark-skipif",
)
config.addinivalue_line(
"markers",
"xfail(condition, ..., *, reason=..., run=True, raises=None, strict=xfail_strict): "
"mark the test function as an expected failure if any of the conditions "
"evaluate to True. Optionally specify a reason for better reporting "
"and run=False if you don't even want to execute the test function. "
"If only specific exception(s) are expected, you can list them in "
"raises, and if the test fails in other ways, it will be reported as "
"a true failure. See https://docs.pytest.org/en/stable/reference/reference.html#pytest-mark-xfail",
)
def evaluate_condition(item: Item, mark: Mark, condition: object) -> tuple[bool, str]:
"""Evaluate a single skipif/xfail condition.
If an old-style string condition is given, it is eval()'d, otherwise the
condition is bool()'d. If this fails, an appropriately formatted pytest.fail
is raised.
Returns (result, reason). The reason is only relevant if the result is True.
"""
# String condition.
if isinstance(condition, str):
globals_ = {
"os": os,
"sys": sys,
"platform": platform,
"config": item.config,
}
for dictionary in reversed(
item.ihook.pytest_markeval_namespace(config=item.config)
):
if not isinstance(dictionary, Mapping):
raise ValueError(
f"pytest_markeval_namespace() needs to return a dict, got {dictionary!r}"
)
globals_.update(dictionary)
if hasattr(item, "obj"):
globals_.update(item.obj.__globals__)
try:
filename = f"<{mark.name} condition>"
condition_code = compile(condition, filename, "eval")
result = eval(condition_code, globals_)
except SyntaxError as exc:
msglines = [
f"Error evaluating {mark.name!r} condition",
" " + condition,
" " + " " * (exc.offset or 0) + "^",
"SyntaxError: invalid syntax",
]
fail("\n".join(msglines), pytrace=False)
except Exception as exc:
msglines = [
f"Error evaluating {mark.name!r} condition",
" " + condition,
*traceback.format_exception_only(type(exc), exc),
]
fail("\n".join(msglines), pytrace=False)
# Boolean condition.
else:
try:
result = bool(condition)
except Exception as exc:
msglines = [
f"Error evaluating {mark.name!r} condition as a boolean",
*traceback.format_exception_only(type(exc), exc),
]
fail("\n".join(msglines), pytrace=False)
reason = mark.kwargs.get("reason", None)
if reason is None:
if isinstance(condition, str):
reason = "condition: " + condition
else:
# XXX better be checked at collection time
msg = (
f"Error evaluating {mark.name!r}: "
+ "you need to specify reason=STRING when using booleans as conditions."
)
fail(msg, pytrace=False)
return result, reason
@dataclasses.dataclass(frozen=True)
class Skip:
"""The result of evaluate_skip_marks()."""
reason: str = "unconditional skip"
def evaluate_skip_marks(item: Item) -> Skip | None:
"""Evaluate skip and skipif marks on item, returning Skip if triggered."""
for mark in item.iter_markers(name="skipif"):
if "condition" not in mark.kwargs:
conditions = mark.args
else:
conditions = (mark.kwargs["condition"],)
# Unconditional.
if not conditions:
reason = mark.kwargs.get("reason", "")
return Skip(reason)
# If any of the conditions are true.
for condition in conditions:
result, reason = evaluate_condition(item, mark, condition)
if result:
return Skip(reason)
for mark in item.iter_markers(name="skip"):
try:
return Skip(*mark.args, **mark.kwargs)
except TypeError as e:
raise TypeError(str(e) + " - maybe you meant pytest.mark.skipif?") from None
return None
@dataclasses.dataclass(frozen=True)
class Xfail:
"""The result of evaluate_xfail_marks()."""
__slots__ = ("raises", "reason", "run", "strict")
reason: str
run: bool
strict: bool
raises: (
type[BaseException]
| tuple[type[BaseException], ...]
| AbstractRaises[BaseException]
| None
)
def evaluate_xfail_marks(item: Item) -> Xfail | None:
"""Evaluate xfail marks on item, returning Xfail if triggered."""
for mark in item.iter_markers(name="xfail"):
run = mark.kwargs.get("run", True)
strict = mark.kwargs.get("strict", item.config.getini("xfail_strict"))
raises = mark.kwargs.get("raises", None)
if "condition" not in mark.kwargs:
conditions = mark.args
else:
conditions = (mark.kwargs["condition"],)
# Unconditional.
if not conditions:
reason = mark.kwargs.get("reason", "")
return Xfail(reason, run, strict, raises)
# If any of the conditions are true.
for condition in conditions:
result, reason = evaluate_condition(item, mark, condition)
if result:
return Xfail(reason, run, strict, raises)
return None
# Saves the xfail mark evaluation. Can be refreshed during call if None.
xfailed_key = StashKey[Optional[Xfail]]()
@hookimpl(tryfirst=True)
def pytest_runtest_setup(item: Item) -> None:
skipped = evaluate_skip_marks(item)
if skipped:
raise skip.Exception(skipped.reason, _use_item_location=True)
item.stash[xfailed_key] = xfailed = evaluate_xfail_marks(item)
if xfailed and not item.config.option.runxfail and not xfailed.run:
xfail("[NOTRUN] " + xfailed.reason)
@hookimpl(wrapper=True)
def pytest_runtest_call(item: Item) -> Generator[None]:
xfailed = item.stash.get(xfailed_key, None)
if xfailed is None:
item.stash[xfailed_key] = xfailed = evaluate_xfail_marks(item)
if xfailed and not item.config.option.runxfail and not xfailed.run:
xfail("[NOTRUN] " + xfailed.reason)
try:
return (yield)
finally:
# The test run may have added an xfail mark dynamically.
xfailed = item.stash.get(xfailed_key, None)
if xfailed is None:
item.stash[xfailed_key] = xfailed = evaluate_xfail_marks(item)
@hookimpl(wrapper=True)
def pytest_runtest_makereport(
item: Item, call: CallInfo[None]
) -> Generator[None, TestReport, TestReport]:
rep = yield
xfailed = item.stash.get(xfailed_key, None)
if item.config.option.runxfail:
pass # don't interfere
elif call.excinfo and isinstance(call.excinfo.value, xfail.Exception):
assert call.excinfo.value.msg is not None
rep.wasxfail = call.excinfo.value.msg
rep.outcome = "skipped"
elif not rep.skipped and xfailed:
if call.excinfo:
raises = xfailed.raises
if raises is None or (
(
isinstance(raises, (type, tuple))
and isinstance(call.excinfo.value, raises)
)
or (
isinstance(raises, AbstractRaises)
and raises.matches(call.excinfo.value)
)
):
rep.outcome = "skipped"
rep.wasxfail = xfailed.reason
else:
rep.outcome = "failed"
elif call.when == "call":
if xfailed.strict:
rep.outcome = "failed"
rep.longrepr = "[XPASS(strict)] " + xfailed.reason
else:
rep.outcome = "passed"
rep.wasxfail = xfailed.reason
return rep
def pytest_report_teststatus(report: BaseReport) -> tuple[str, str, str] | None:
if hasattr(report, "wasxfail"):
if report.skipped:
return "xfailed", "x", "XFAIL"
elif report.passed:
return "xpassed", "X", "XPASS"
return None

View File

@@ -0,0 +1,116 @@
from __future__ import annotations
from typing import Any
from typing import cast
from typing import Generic
from typing import TypeVar
__all__ = ["Stash", "StashKey"]
T = TypeVar("T")
D = TypeVar("D")
class StashKey(Generic[T]):
"""``StashKey`` is an object used as a key to a :class:`Stash`.
A ``StashKey`` is associated with the type ``T`` of the value of the key.
A ``StashKey`` is unique and cannot conflict with another key.
.. versionadded:: 7.0
"""
__slots__ = ()
class Stash:
r"""``Stash`` is a type-safe heterogeneous mutable mapping that
allows keys and value types to be defined separately from
where it (the ``Stash``) is created.
Usually you will be given an object which has a ``Stash``, for example
:class:`~pytest.Config` or a :class:`~_pytest.nodes.Node`:
.. code-block:: python
stash: Stash = some_object.stash
If a module or plugin wants to store data in this ``Stash``, it creates
:class:`StashKey`\s for its keys (at the module level):
.. code-block:: python
# At the top-level of the module
some_str_key = StashKey[str]()
some_bool_key = StashKey[bool]()
To store information:
.. code-block:: python
# Value type must match the key.
stash[some_str_key] = "value"
stash[some_bool_key] = True
To retrieve the information:
.. code-block:: python
# The static type of some_str is str.
some_str = stash[some_str_key]
# The static type of some_bool is bool.
some_bool = stash[some_bool_key]
.. versionadded:: 7.0
"""
__slots__ = ("_storage",)
def __init__(self) -> None:
self._storage: dict[StashKey[Any], object] = {}
def __setitem__(self, key: StashKey[T], value: T) -> None:
"""Set a value for key."""
self._storage[key] = value
def __getitem__(self, key: StashKey[T]) -> T:
"""Get the value for key.
Raises ``KeyError`` if the key wasn't set before.
"""
return cast(T, self._storage[key])
def get(self, key: StashKey[T], default: D) -> T | D:
"""Get the value for key, or return default if the key wasn't set
before."""
try:
return self[key]
except KeyError:
return default
def setdefault(self, key: StashKey[T], default: T) -> T:
"""Return the value of key if already set, otherwise set the value
of key to default and return default."""
try:
return self[key]
except KeyError:
self[key] = default
return default
def __delitem__(self, key: StashKey[T]) -> None:
"""Delete the value for key.
Raises ``KeyError`` if the key wasn't set before.
"""
del self._storage[key]
def __contains__(self, key: StashKey[T]) -> bool:
"""Return whether key was set."""
return key in self._storage
def __len__(self) -> int:
"""Return how many items exist in the stash."""
return len(self._storage)

View File

@@ -0,0 +1,209 @@
from __future__ import annotations
import dataclasses
from datetime import datetime
from datetime import timedelta
from typing import Any
from typing import TYPE_CHECKING
from _pytest import nodes
from _pytest.cacheprovider import Cache
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.main import Session
from _pytest.reports import TestReport
if TYPE_CHECKING:
from typing_extensions import Self
STEPWISE_CACHE_DIR = "cache/stepwise"
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group.addoption(
"--sw",
"--stepwise",
action="store_true",
default=False,
dest="stepwise",
help="Exit on test failure and continue from last failing test next time",
)
group.addoption(
"--sw-skip",
"--stepwise-skip",
action="store_true",
default=False,
dest="stepwise_skip",
help="Ignore the first failing test but stop on the next failing test. "
"Implicitly enables --stepwise.",
)
group.addoption(
"--sw-reset",
"--stepwise-reset",
action="store_true",
default=False,
dest="stepwise_reset",
help="Resets stepwise state, restarting the stepwise workflow. "
"Implicitly enables --stepwise.",
)
def pytest_configure(config: Config) -> None:
# --stepwise-skip/--stepwise-reset implies stepwise.
if config.option.stepwise_skip or config.option.stepwise_reset:
config.option.stepwise = True
if config.getoption("stepwise"):
config.pluginmanager.register(StepwisePlugin(config), "stepwiseplugin")
def pytest_sessionfinish(session: Session) -> None:
if not session.config.getoption("stepwise"):
assert session.config.cache is not None
if hasattr(session.config, "workerinput"):
# Do not update cache if this process is a xdist worker to prevent
# race conditions (#10641).
return
@dataclasses.dataclass
class StepwiseCacheInfo:
# The nodeid of the last failed test.
last_failed: str | None
# The number of tests in the last time --stepwise was run.
# We use this information as a simple way to invalidate the cache information, avoiding
# confusing behavior in case the cache is stale.
last_test_count: int | None
# The date when the cache was last updated, for information purposes only.
last_cache_date_str: str
@property
def last_cache_date(self) -> datetime:
return datetime.fromisoformat(self.last_cache_date_str)
@classmethod
def empty(cls) -> Self:
return cls(
last_failed=None,
last_test_count=None,
last_cache_date_str=datetime.now().isoformat(),
)
def update_date_to_now(self) -> None:
self.last_cache_date_str = datetime.now().isoformat()
class StepwisePlugin:
def __init__(self, config: Config) -> None:
self.config = config
self.session: Session | None = None
self.report_status: list[str] = []
assert config.cache is not None
self.cache: Cache = config.cache
self.skip: bool = config.getoption("stepwise_skip")
self.reset: bool = config.getoption("stepwise_reset")
self.cached_info = self._load_cached_info()
def _load_cached_info(self) -> StepwiseCacheInfo:
cached_dict: dict[str, Any] | None = self.cache.get(STEPWISE_CACHE_DIR, None)
if cached_dict:
try:
return StepwiseCacheInfo(
cached_dict["last_failed"],
cached_dict["last_test_count"],
cached_dict["last_cache_date_str"],
)
except (KeyError, TypeError) as e:
error = f"{type(e).__name__}: {e}"
self.report_status.append(f"error reading cache, discarding ({error})")
# Cache not found or error during load, return a new cache.
return StepwiseCacheInfo.empty()
def pytest_sessionstart(self, session: Session) -> None:
self.session = session
def pytest_collection_modifyitems(
self, config: Config, items: list[nodes.Item]
) -> None:
last_test_count = self.cached_info.last_test_count
self.cached_info.last_test_count = len(items)
if self.reset:
self.report_status.append("resetting state, not skipping.")
self.cached_info.last_failed = None
return
if not self.cached_info.last_failed:
self.report_status.append("no previously failed tests, not skipping.")
return
if last_test_count is not None and last_test_count != len(items):
self.report_status.append(
f"test count changed, not skipping (now {len(items)} tests, previously {last_test_count})."
)
self.cached_info.last_failed = None
return
# Check all item nodes until we find a match on last failed.
failed_index = None
for index, item in enumerate(items):
if item.nodeid == self.cached_info.last_failed:
failed_index = index
break
# If the previously failed test was not found among the test items,
# do not skip any tests.
if failed_index is None:
self.report_status.append("previously failed test not found, not skipping.")
else:
cache_age = datetime.now() - self.cached_info.last_cache_date
# Round up to avoid showing microseconds.
cache_age = timedelta(seconds=int(cache_age.total_seconds()))
self.report_status.append(
f"skipping {failed_index} already passed items (cache from {cache_age} ago,"
f" use --sw-reset to discard)."
)
deselected = items[:failed_index]
del items[:failed_index]
config.hook.pytest_deselected(items=deselected)
def pytest_runtest_logreport(self, report: TestReport) -> None:
if report.failed:
if self.skip:
# Remove test from the failed ones (if it exists) and unset the skip option
# to make sure the following tests will not be skipped.
if report.nodeid == self.cached_info.last_failed:
self.cached_info.last_failed = None
self.skip = False
else:
# Mark test as the last failing and interrupt the test session.
self.cached_info.last_failed = report.nodeid
assert self.session is not None
self.session.shouldstop = (
"Test failed, continuing from this test next run."
)
else:
# If the test was actually run and did pass.
if report.when == "call":
# Remove test from the failed ones, if exists.
if report.nodeid == self.cached_info.last_failed:
self.cached_info.last_failed = None
def pytest_report_collectionfinish(self) -> list[str] | None:
if self.config.get_verbosity() >= 0 and self.report_status:
return [f"stepwise: {x}" for x in self.report_status]
return None
def pytest_sessionfinish(self) -> None:
if hasattr(self.config, "workerinput"):
# Do not update cache if this process is a xdist worker to prevent
# race conditions (#10641).
return
self.cached_info.update_date_to_now()
self.cache.set(STEPWISE_CACHE_DIR, dataclasses.asdict(self.cached_info))

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,152 @@
from __future__ import annotations
import collections
from collections.abc import Callable
import functools
import sys
import threading
import traceback
from typing import NamedTuple
from typing import TYPE_CHECKING
import warnings
from _pytest.config import Config
from _pytest.nodes import Item
from _pytest.stash import StashKey
from _pytest.tracemalloc import tracemalloc_message
import pytest
if TYPE_CHECKING:
pass
if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup
class ThreadExceptionMeta(NamedTuple):
msg: str
cause_msg: str
exc_value: BaseException | None
thread_exceptions: StashKey[collections.deque[ThreadExceptionMeta | BaseException]] = (
StashKey()
)
def collect_thread_exception(config: Config) -> None:
pop_thread_exception = config.stash[thread_exceptions].pop
errors: list[pytest.PytestUnhandledThreadExceptionWarning | RuntimeError] = []
meta = None
hook_error = None
try:
while True:
try:
meta = pop_thread_exception()
except IndexError:
break
if isinstance(meta, BaseException):
hook_error = RuntimeError("Failed to process thread exception")
hook_error.__cause__ = meta
errors.append(hook_error)
continue
msg = meta.msg
try:
warnings.warn(pytest.PytestUnhandledThreadExceptionWarning(msg))
except pytest.PytestUnhandledThreadExceptionWarning as e:
# This except happens when the warning is treated as an error (e.g. `-Werror`).
if meta.exc_value is not None:
# Exceptions have a better way to show the traceback, but
# warnings do not, so hide the traceback from the msg and
# set the cause so the traceback shows up in the right place.
e.args = (meta.cause_msg,)
e.__cause__ = meta.exc_value
errors.append(e)
if len(errors) == 1:
raise errors[0]
if errors:
raise ExceptionGroup("multiple thread exception warnings", errors)
finally:
del errors, meta, hook_error
def cleanup(
*, config: Config, prev_hook: Callable[[threading.ExceptHookArgs], object]
) -> None:
try:
try:
# We don't join threads here, so exceptions raised from any
# threads still running by the time _threading_atexits joins them
# do not get captured (see #13027).
collect_thread_exception(config)
finally:
threading.excepthook = prev_hook
finally:
del config.stash[thread_exceptions]
def thread_exception_hook(
args: threading.ExceptHookArgs,
/,
*,
append: Callable[[ThreadExceptionMeta | BaseException], object],
) -> None:
try:
# we need to compute these strings here as they might change after
# the excepthook finishes and before the metadata object is
# collected by a pytest hook
thread_name = "<unknown>" if args.thread is None else args.thread.name
summary = f"Exception in thread {thread_name}"
traceback_message = "\n\n" + "".join(
traceback.format_exception(
args.exc_type,
args.exc_value,
args.exc_traceback,
)
)
tracemalloc_tb = "\n" + tracemalloc_message(args.thread)
msg = summary + traceback_message + tracemalloc_tb
cause_msg = summary + tracemalloc_tb
append(
ThreadExceptionMeta(
# Compute these strings here as they might change later
msg=msg,
cause_msg=cause_msg,
exc_value=args.exc_value,
)
)
except BaseException as e:
append(e)
# Raising this will cause the exception to be logged twice, once in our
# collect_thread_exception and once by sys.excepthook
# which is fine - this should never happen anyway and if it does
# it should probably be reported as a pytest bug.
raise
def pytest_configure(config: Config) -> None:
prev_hook = threading.excepthook
deque: collections.deque[ThreadExceptionMeta | BaseException] = collections.deque()
config.stash[thread_exceptions] = deque
config.add_cleanup(functools.partial(cleanup, config=config, prev_hook=prev_hook))
threading.excepthook = functools.partial(thread_exception_hook, append=deque.append)
@pytest.hookimpl(trylast=True)
def pytest_runtest_setup(item: Item) -> None:
collect_thread_exception(item.config)
@pytest.hookimpl(trylast=True)
def pytest_runtest_call(item: Item) -> None:
collect_thread_exception(item.config)
@pytest.hookimpl(trylast=True)
def pytest_runtest_teardown(item: Item) -> None:
collect_thread_exception(item.config)

View File

@@ -0,0 +1,94 @@
"""Indirection for time functions.
We intentionally grab some "time" functions internally to avoid tests mocking "time" to affect
pytest runtime information (issue #185).
Fixture "mock_timing" also interacts with this module for pytest's own tests.
"""
from __future__ import annotations
import dataclasses
from datetime import datetime
from datetime import timezone
from time import perf_counter
from time import sleep
from time import time
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from pytest import MonkeyPatch
@dataclasses.dataclass(frozen=True)
class Instant:
"""
Represents an instant in time, used to both get the timestamp value and to measure
the duration of a time span.
Inspired by Rust's `std::time::Instant`.
"""
# Creation time of this instant, using time.time(), to measure actual time.
# Note: using a `lambda` to correctly get the mocked time via `MockTiming`.
time: float = dataclasses.field(default_factory=lambda: time(), init=False)
# Performance counter tick of the instant, used to measure precise elapsed time.
# Note: using a `lambda` to correctly get the mocked time via `MockTiming`.
perf_count: float = dataclasses.field(
default_factory=lambda: perf_counter(), init=False
)
def elapsed(self) -> Duration:
"""Measure the duration since `Instant` was created."""
return Duration(start=self, stop=Instant())
def as_utc(self) -> datetime:
"""Instant as UTC datetime."""
return datetime.fromtimestamp(self.time, timezone.utc)
@dataclasses.dataclass(frozen=True)
class Duration:
"""A span of time as measured by `Instant.elapsed()`."""
start: Instant
stop: Instant
@property
def seconds(self) -> float:
"""Elapsed time of the duration in seconds, measured using a performance counter for precise timing."""
return self.stop.perf_count - self.start.perf_count
@dataclasses.dataclass
class MockTiming:
"""Mocks _pytest.timing with a known object that can be used to control timing in tests
deterministically.
pytest itself should always use functions from `_pytest.timing` instead of `time` directly.
This then allows us more control over time during testing, if testing code also
uses `_pytest.timing` functions.
Time is static, and only advances through `sleep` calls, thus tests might sleep over large
numbers and obtain accurate time() calls at the end, making tests reliable and instant."""
_current_time: float = datetime(2020, 5, 22, 14, 20, 50).timestamp()
def sleep(self, seconds: float) -> None:
self._current_time += seconds
def time(self) -> float:
return self._current_time
def patch(self, monkeypatch: MonkeyPatch) -> None:
from _pytest import timing # noqa: PLW0406
monkeypatch.setattr(timing, "sleep", self.sleep)
monkeypatch.setattr(timing, "time", self.time)
monkeypatch.setattr(timing, "perf_counter", self.time)
__all__ = ["perf_counter", "sleep", "time"]

View File

@@ -0,0 +1,312 @@
# mypy: allow-untyped-defs
"""Support for providing temporary directories to test functions."""
from __future__ import annotations
from collections.abc import Generator
import dataclasses
import os
from pathlib import Path
import re
from shutil import rmtree
import tempfile
from typing import Any
from typing import final
from typing import Literal
from .pathlib import cleanup_dead_symlinks
from .pathlib import LOCK_TIMEOUT
from .pathlib import make_numbered_dir
from .pathlib import make_numbered_dir_with_cleanup
from .pathlib import rm_rf
from _pytest.compat import get_user_id
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser
from _pytest.deprecated import check_ispytest
from _pytest.fixtures import fixture
from _pytest.fixtures import FixtureRequest
from _pytest.monkeypatch import MonkeyPatch
from _pytest.nodes import Item
from _pytest.reports import TestReport
from _pytest.stash import StashKey
tmppath_result_key = StashKey[dict[str, bool]]()
RetentionType = Literal["all", "failed", "none"]
@final
@dataclasses.dataclass
class TempPathFactory:
"""Factory for temporary directories under the common base temp directory,
as discussed at :ref:`temporary directory location and retention`.
"""
_given_basetemp: Path | None
# pluggy TagTracerSub, not currently exposed, so Any.
_trace: Any
_basetemp: Path | None
_retention_count: int
_retention_policy: RetentionType
def __init__(
self,
given_basetemp: Path | None,
retention_count: int,
retention_policy: RetentionType,
trace,
basetemp: Path | None = None,
*,
_ispytest: bool = False,
) -> None:
check_ispytest(_ispytest)
if given_basetemp is None:
self._given_basetemp = None
else:
# Use os.path.abspath() to get absolute path instead of resolve() as it
# does not work the same in all platforms (see #4427).
# Path.absolute() exists, but it is not public (see https://bugs.python.org/issue25012).
self._given_basetemp = Path(os.path.abspath(str(given_basetemp)))
self._trace = trace
self._retention_count = retention_count
self._retention_policy = retention_policy
self._basetemp = basetemp
@classmethod
def from_config(
cls,
config: Config,
*,
_ispytest: bool = False,
) -> TempPathFactory:
"""Create a factory according to pytest configuration.
:meta private:
"""
check_ispytest(_ispytest)
count = int(config.getini("tmp_path_retention_count"))
if count < 0:
raise ValueError(
f"tmp_path_retention_count must be >= 0. Current input: {count}."
)
policy = config.getini("tmp_path_retention_policy")
if policy not in ("all", "failed", "none"):
raise ValueError(
f"tmp_path_retention_policy must be either all, failed, none. Current input: {policy}."
)
return cls(
given_basetemp=config.option.basetemp,
trace=config.trace.get("tmpdir"),
retention_count=count,
retention_policy=policy,
_ispytest=True,
)
def _ensure_relative_to_basetemp(self, basename: str) -> str:
basename = os.path.normpath(basename)
if (self.getbasetemp() / basename).resolve().parent != self.getbasetemp():
raise ValueError(f"{basename} is not a normalized and relative path")
return basename
def mktemp(self, basename: str, numbered: bool = True) -> Path:
"""Create a new temporary directory managed by the factory.
:param basename:
Directory base name, must be a relative path.
:param numbered:
If ``True``, ensure the directory is unique by adding a numbered
suffix greater than any existing one: ``basename="foo-"`` and ``numbered=True``
means that this function will create directories named ``"foo-0"``,
``"foo-1"``, ``"foo-2"`` and so on.
:returns:
The path to the new directory.
"""
basename = self._ensure_relative_to_basetemp(basename)
if not numbered:
p = self.getbasetemp().joinpath(basename)
p.mkdir(mode=0o700)
else:
p = make_numbered_dir(root=self.getbasetemp(), prefix=basename, mode=0o700)
self._trace("mktemp", p)
return p
def getbasetemp(self) -> Path:
"""Return the base temporary directory, creating it if needed.
:returns:
The base temporary directory.
"""
if self._basetemp is not None:
return self._basetemp
if self._given_basetemp is not None:
basetemp = self._given_basetemp
if basetemp.exists():
rm_rf(basetemp)
basetemp.mkdir(mode=0o700)
basetemp = basetemp.resolve()
else:
from_env = os.environ.get("PYTEST_DEBUG_TEMPROOT")
temproot = Path(from_env or tempfile.gettempdir()).resolve()
user = get_user() or "unknown"
# use a sub-directory in the temproot to speed-up
# make_numbered_dir() call
rootdir = temproot.joinpath(f"pytest-of-{user}")
try:
rootdir.mkdir(mode=0o700, exist_ok=True)
except OSError:
# getuser() likely returned illegal characters for the platform, use unknown back off mechanism
rootdir = temproot.joinpath("pytest-of-unknown")
rootdir.mkdir(mode=0o700, exist_ok=True)
# Because we use exist_ok=True with a predictable name, make sure
# we are the owners, to prevent any funny business (on unix, where
# temproot is usually shared).
# Also, to keep things private, fixup any world-readable temp
# rootdir's permissions. Historically 0o755 was used, so we can't
# just error out on this, at least for a while.
uid = get_user_id()
if uid is not None:
rootdir_stat = rootdir.stat()
if rootdir_stat.st_uid != uid:
raise OSError(
f"The temporary directory {rootdir} is not owned by the current user. "
"Fix this and try again."
)
if (rootdir_stat.st_mode & 0o077) != 0:
os.chmod(rootdir, rootdir_stat.st_mode & ~0o077)
keep = self._retention_count
if self._retention_policy == "none":
keep = 0
basetemp = make_numbered_dir_with_cleanup(
prefix="pytest-",
root=rootdir,
keep=keep,
lock_timeout=LOCK_TIMEOUT,
mode=0o700,
)
assert basetemp is not None, basetemp
self._basetemp = basetemp
self._trace("new basetemp", basetemp)
return basetemp
def get_user() -> str | None:
"""Return the current user name, or None if getuser() does not work
in the current environment (see #1010)."""
try:
# In some exotic environments, getpass may not be importable.
import getpass
return getpass.getuser()
except (ImportError, OSError, KeyError):
return None
def pytest_configure(config: Config) -> None:
"""Create a TempPathFactory and attach it to the config object.
This is to comply with existing plugins which expect the handler to be
available at pytest_configure time, but ideally should be moved entirely
to the tmp_path_factory session fixture.
"""
mp = MonkeyPatch()
config.add_cleanup(mp.undo)
_tmp_path_factory = TempPathFactory.from_config(config, _ispytest=True)
mp.setattr(config, "_tmp_path_factory", _tmp_path_factory, raising=False)
def pytest_addoption(parser: Parser) -> None:
parser.addini(
"tmp_path_retention_count",
help="How many sessions should we keep the `tmp_path` directories, according to `tmp_path_retention_policy`.",
default=3,
)
parser.addini(
"tmp_path_retention_policy",
help="Controls which directories created by the `tmp_path` fixture are kept around, based on test outcome. "
"(all/failed/none)",
default="all",
)
@fixture(scope="session")
def tmp_path_factory(request: FixtureRequest) -> TempPathFactory:
"""Return a :class:`pytest.TempPathFactory` instance for the test session."""
# Set dynamically by pytest_configure() above.
return request.config._tmp_path_factory # type: ignore
def _mk_tmp(request: FixtureRequest, factory: TempPathFactory) -> Path:
name = request.node.name
name = re.sub(r"[\W]", "_", name)
MAXVAL = 30
name = name[:MAXVAL]
return factory.mktemp(name, numbered=True)
@fixture
def tmp_path(
request: FixtureRequest, tmp_path_factory: TempPathFactory
) -> Generator[Path]:
"""Return a temporary directory (as :class:`pathlib.Path` object)
which is unique to each test function invocation.
The temporary directory is created as a subdirectory
of the base temporary directory, with configurable retention,
as discussed in :ref:`temporary directory location and retention`.
"""
path = _mk_tmp(request, tmp_path_factory)
yield path
# Remove the tmpdir if the policy is "failed" and the test passed.
policy = tmp_path_factory._retention_policy
result_dict = request.node.stash[tmppath_result_key]
if policy == "failed" and result_dict.get("call", True):
# We do a "best effort" to remove files, but it might not be possible due to some leaked resource,
# permissions, etc, in which case we ignore it.
rmtree(path, ignore_errors=True)
del request.node.stash[tmppath_result_key]
def pytest_sessionfinish(session, exitstatus: int | ExitCode):
"""After each session, remove base directory if all the tests passed,
the policy is "failed", and the basetemp is not specified by a user.
"""
tmp_path_factory: TempPathFactory = session.config._tmp_path_factory
basetemp = tmp_path_factory._basetemp
if basetemp is None:
return
policy = tmp_path_factory._retention_policy
if (
exitstatus == 0
and policy == "failed"
and tmp_path_factory._given_basetemp is None
):
if basetemp.is_dir():
# We do a "best effort" to remove files, but it might not be possible due to some leaked resource,
# permissions, etc, in which case we ignore it.
rmtree(basetemp, ignore_errors=True)
# Remove dead symlinks.
if basetemp.is_dir():
cleanup_dead_symlinks(basetemp)
@hookimpl(wrapper=True, tryfirst=True)
def pytest_runtest_makereport(
item: Item, call
) -> Generator[None, TestReport, TestReport]:
rep = yield
assert rep.when is not None
empty: dict[str, bool] = {}
item.stash.setdefault(tmppath_result_key, empty)[rep.when] = rep.passed
return rep

View File

@@ -0,0 +1,24 @@
from __future__ import annotations
def tracemalloc_message(source: object) -> str:
if source is None:
return ""
try:
import tracemalloc
except ImportError:
return ""
tb = tracemalloc.get_object_traceback(source)
if tb is not None:
formatted_tb = "\n".join(tb.format())
# Use a leading new line to better separate the (large) output
# from the traceback to the previous warning text.
return f"\nObject allocated at:\n{formatted_tb}"
# No need for a leading new line.
url = "https://docs.pytest.org/en/stable/how-to/capture-warnings.html#resource-warnings"
return (
"Enable tracemalloc to get traceback where the object was allocated.\n"
f"See {url} for more info."
)

View File

@@ -0,0 +1,516 @@
# mypy: allow-untyped-defs
"""Discover and run std-library "unittest" style tests."""
from __future__ import annotations
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterable
from collections.abc import Iterator
from enum import auto
from enum import Enum
import inspect
import sys
import traceback
import types
from typing import TYPE_CHECKING
from typing import Union
import _pytest._code
from _pytest.compat import is_async_function
from _pytest.config import hookimpl
from _pytest.fixtures import FixtureRequest
from _pytest.monkeypatch import MonkeyPatch
from _pytest.nodes import Collector
from _pytest.nodes import Item
from _pytest.outcomes import exit
from _pytest.outcomes import fail
from _pytest.outcomes import skip
from _pytest.outcomes import xfail
from _pytest.python import Class
from _pytest.python import Function
from _pytest.python import Module
from _pytest.runner import CallInfo
import pytest
if sys.version_info[:2] < (3, 11):
from exceptiongroup import ExceptionGroup
if TYPE_CHECKING:
import unittest
import twisted.trial.unittest
_SysExcInfoType = Union[
tuple[type[BaseException], BaseException, types.TracebackType],
tuple[None, None, None],
]
def pytest_pycollect_makeitem(
collector: Module | Class, name: str, obj: object
) -> UnitTestCase | None:
try:
# Has unittest been imported?
ut = sys.modules["unittest"]
# Is obj a subclass of unittest.TestCase?
# Type ignored because `ut` is an opaque module.
if not issubclass(obj, ut.TestCase): # type: ignore
return None
except Exception:
return None
# Is obj a concrete class?
# Abstract classes can't be instantiated so no point collecting them.
if inspect.isabstract(obj):
return None
# Yes, so let's collect it.
return UnitTestCase.from_parent(collector, name=name, obj=obj)
class UnitTestCase(Class):
# Marker for fixturemanger.getfixtureinfo()
# to declare that our children do not support funcargs.
nofuncargs = True
def newinstance(self):
# TestCase __init__ takes the method (test) name. The TestCase
# constructor treats the name "runTest" as a special no-op, so it can be
# used when a dummy instance is needed. While unittest.TestCase has a
# default, some subclasses omit the default (#9610), so always supply
# it.
return self.obj("runTest")
def collect(self) -> Iterable[Item | Collector]:
from unittest import TestLoader
cls = self.obj
if not getattr(cls, "__test__", True):
return
skipped = _is_skipped(cls)
if not skipped:
self._register_unittest_setup_method_fixture(cls)
self._register_unittest_setup_class_fixture(cls)
self._register_setup_class_fixture()
self.session._fixturemanager.parsefactories(self.newinstance(), self.nodeid)
loader = TestLoader()
foundsomething = False
for name in loader.getTestCaseNames(self.obj):
x = getattr(self.obj, name)
if not getattr(x, "__test__", True):
continue
yield TestCaseFunction.from_parent(self, name=name)
foundsomething = True
if not foundsomething:
runtest = getattr(self.obj, "runTest", None)
if runtest is not None:
ut = sys.modules.get("twisted.trial.unittest", None)
if ut is None or runtest != ut.TestCase.runTest:
yield TestCaseFunction.from_parent(self, name="runTest")
def _register_unittest_setup_class_fixture(self, cls: type) -> None:
"""Register an auto-use fixture to invoke setUpClass and
tearDownClass (#517)."""
setup = getattr(cls, "setUpClass", None)
teardown = getattr(cls, "tearDownClass", None)
if setup is None and teardown is None:
return None
cleanup = getattr(cls, "doClassCleanups", lambda: None)
def process_teardown_exceptions() -> None:
# tearDown_exceptions is a list set in the class containing exc_infos for errors during
# teardown for the class.
exc_infos = getattr(cls, "tearDown_exceptions", None)
if not exc_infos:
return
exceptions = [exc for (_, exc, _) in exc_infos]
# If a single exception, raise it directly as this provides a more readable
# error (hopefully this will improve in #12255).
if len(exceptions) == 1:
raise exceptions[0]
else:
raise ExceptionGroup("Unittest class cleanup errors", exceptions)
def unittest_setup_class_fixture(
request: FixtureRequest,
) -> Generator[None]:
cls = request.cls
if _is_skipped(cls):
reason = cls.__unittest_skip_why__
raise pytest.skip.Exception(reason, _use_item_location=True)
if setup is not None:
try:
setup()
# unittest does not call the cleanup function for every BaseException, so we
# follow this here.
except Exception:
cleanup()
process_teardown_exceptions()
raise
yield
try:
if teardown is not None:
teardown()
finally:
cleanup()
process_teardown_exceptions()
self.session._fixturemanager._register_fixture(
# Use a unique name to speed up lookup.
name=f"_unittest_setUpClass_fixture_{cls.__qualname__}",
func=unittest_setup_class_fixture,
nodeid=self.nodeid,
scope="class",
autouse=True,
)
def _register_unittest_setup_method_fixture(self, cls: type) -> None:
"""Register an auto-use fixture to invoke setup_method and
teardown_method (#517)."""
setup = getattr(cls, "setup_method", None)
teardown = getattr(cls, "teardown_method", None)
if setup is None and teardown is None:
return None
def unittest_setup_method_fixture(
request: FixtureRequest,
) -> Generator[None]:
self = request.instance
if _is_skipped(self):
reason = self.__unittest_skip_why__
raise pytest.skip.Exception(reason, _use_item_location=True)
if setup is not None:
setup(self, request.function)
yield
if teardown is not None:
teardown(self, request.function)
self.session._fixturemanager._register_fixture(
# Use a unique name to speed up lookup.
name=f"_unittest_setup_method_fixture_{cls.__qualname__}",
func=unittest_setup_method_fixture,
nodeid=self.nodeid,
scope="function",
autouse=True,
)
class TestCaseFunction(Function):
nofuncargs = True
_excinfo: list[_pytest._code.ExceptionInfo[BaseException]] | None = None
def _getinstance(self):
assert isinstance(self.parent, UnitTestCase)
return self.parent.obj(self.name)
# Backward compat for pytest-django; can be removed after pytest-django
# updates + some slack.
@property
def _testcase(self):
return self.instance
def setup(self) -> None:
# A bound method to be called during teardown() if set (see 'runtest()').
self._explicit_tearDown: Callable[[], None] | None = None
super().setup()
def teardown(self) -> None:
if self._explicit_tearDown is not None:
self._explicit_tearDown()
self._explicit_tearDown = None
self._obj = None
del self._instance
super().teardown()
def startTest(self, testcase: unittest.TestCase) -> None:
pass
def _addexcinfo(self, rawexcinfo: _SysExcInfoType) -> None:
rawexcinfo = _handle_twisted_exc_info(rawexcinfo)
try:
excinfo = _pytest._code.ExceptionInfo[BaseException].from_exc_info(
rawexcinfo # type: ignore[arg-type]
)
# Invoke the attributes to trigger storing the traceback
# trial causes some issue there.
_ = excinfo.value
_ = excinfo.traceback
except TypeError:
try:
try:
values = traceback.format_exception(*rawexcinfo)
values.insert(
0,
"NOTE: Incompatible Exception Representation, "
"displaying natively:\n\n",
)
fail("".join(values), pytrace=False)
except (fail.Exception, KeyboardInterrupt):
raise
except BaseException:
fail(
"ERROR: Unknown Incompatible Exception "
f"representation:\n{rawexcinfo!r}",
pytrace=False,
)
except KeyboardInterrupt:
raise
except fail.Exception:
excinfo = _pytest._code.ExceptionInfo.from_current()
self.__dict__.setdefault("_excinfo", []).append(excinfo)
def addError(
self, testcase: unittest.TestCase, rawexcinfo: _SysExcInfoType
) -> None:
try:
if isinstance(rawexcinfo[1], exit.Exception):
exit(rawexcinfo[1].msg)
except TypeError:
pass
self._addexcinfo(rawexcinfo)
def addFailure(
self, testcase: unittest.TestCase, rawexcinfo: _SysExcInfoType
) -> None:
self._addexcinfo(rawexcinfo)
def addSkip(self, testcase: unittest.TestCase, reason: str) -> None:
try:
raise pytest.skip.Exception(reason, _use_item_location=True)
except skip.Exception:
self._addexcinfo(sys.exc_info())
def addExpectedFailure(
self,
testcase: unittest.TestCase,
rawexcinfo: _SysExcInfoType,
reason: str = "",
) -> None:
try:
xfail(str(reason))
except xfail.Exception:
self._addexcinfo(sys.exc_info())
def addUnexpectedSuccess(
self,
testcase: unittest.TestCase,
reason: twisted.trial.unittest.Todo | None = None,
) -> None:
msg = "Unexpected success"
if reason:
msg += f": {reason.reason}"
# Preserve unittest behaviour - fail the test. Explicitly not an XPASS.
try:
fail(msg, pytrace=False)
except fail.Exception:
self._addexcinfo(sys.exc_info())
def addSuccess(self, testcase: unittest.TestCase) -> None:
pass
def stopTest(self, testcase: unittest.TestCase) -> None:
pass
def addDuration(self, testcase: unittest.TestCase, elapsed: float) -> None:
pass
def runtest(self) -> None:
from _pytest.debugging import maybe_wrap_pytest_function_for_tracing
testcase = self.instance
assert testcase is not None
maybe_wrap_pytest_function_for_tracing(self)
# Let the unittest framework handle async functions.
if is_async_function(self.obj):
testcase(result=self)
else:
# When --pdb is given, we want to postpone calling tearDown() otherwise
# when entering the pdb prompt, tearDown() would have probably cleaned up
# instance variables, which makes it difficult to debug.
# Arguably we could always postpone tearDown(), but this changes the moment where the
# TestCase instance interacts with the results object, so better to only do it
# when absolutely needed.
# We need to consider if the test itself is skipped, or the whole class.
assert isinstance(self.parent, UnitTestCase)
skipped = _is_skipped(self.obj) or _is_skipped(self.parent.obj)
if self.config.getoption("usepdb") and not skipped:
self._explicit_tearDown = testcase.tearDown
setattr(testcase, "tearDown", lambda *args: None)
# We need to update the actual bound method with self.obj, because
# wrap_pytest_function_for_tracing replaces self.obj by a wrapper.
setattr(testcase, self.name, self.obj)
try:
testcase(result=self)
finally:
delattr(testcase, self.name)
def _traceback_filter(
self, excinfo: _pytest._code.ExceptionInfo[BaseException]
) -> _pytest._code.Traceback:
traceback = super()._traceback_filter(excinfo)
ntraceback = traceback.filter(
lambda x: not x.frame.f_globals.get("__unittest"),
)
if not ntraceback:
ntraceback = traceback
return ntraceback
@hookimpl(tryfirst=True)
def pytest_runtest_makereport(item: Item, call: CallInfo[None]) -> None:
if isinstance(item, TestCaseFunction):
if item._excinfo:
call.excinfo = item._excinfo.pop(0)
try:
del call.result
except AttributeError:
pass
# Convert unittest.SkipTest to pytest.skip.
# This is actually only needed for nose, which reuses unittest.SkipTest for
# its own nose.SkipTest. For unittest TestCases, SkipTest is already
# handled internally, and doesn't reach here.
unittest = sys.modules.get("unittest")
if unittest and call.excinfo and isinstance(call.excinfo.value, unittest.SkipTest):
excinfo = call.excinfo
call2 = CallInfo[None].from_call(
lambda: pytest.skip(str(excinfo.value)), call.when
)
call.excinfo = call2.excinfo
def _is_skipped(obj) -> bool:
"""Return True if the given object has been marked with @unittest.skip."""
return bool(getattr(obj, "__unittest_skip__", False))
def pytest_configure() -> None:
"""Register the TestCaseFunction class as an IReporter if twisted.trial is available."""
if _get_twisted_version() is not TwistedVersion.NotInstalled:
from twisted.trial.itrial import IReporter
from zope.interface import classImplements
classImplements(TestCaseFunction, IReporter)
class TwistedVersion(Enum):
"""
The Twisted version installed in the environment.
We have different workarounds in place for different versions of Twisted.
"""
# Twisted version 24 or prior.
Version24 = auto()
# Twisted version 25 or later.
Version25 = auto()
# Twisted version is not available.
NotInstalled = auto()
def _get_twisted_version() -> TwistedVersion:
# We need to check if "twisted.trial.unittest" is specifically present in sys.modules.
# This is because we intend to integrate with Trial only when it's actively running
# the test suite, but not needed when only other Twisted components are in use.
if "twisted.trial.unittest" not in sys.modules:
return TwistedVersion.NotInstalled
import importlib.metadata
import packaging.version
version_str = importlib.metadata.version("twisted")
version = packaging.version.parse(version_str)
if version.major <= 24:
return TwistedVersion.Version24
else:
return TwistedVersion.Version25
# Name of the attribute in `twisted.python.Failure` instances that stores
# the `sys.exc_info()` tuple.
# See twisted.trial support in `pytest_runtest_protocol`.
TWISTED_RAW_EXCINFO_ATTR = "_twisted_raw_excinfo"
@hookimpl(wrapper=True)
def pytest_runtest_protocol(item: Item) -> Iterator[None]:
if _get_twisted_version() is TwistedVersion.Version24:
import twisted.python.failure as ut
# Monkeypatch `Failure.__init__` to store the raw exception info.
original__init__ = ut.Failure.__init__
def store_raw_exception_info(
self, exc_value=None, exc_type=None, exc_tb=None, captureVars=None
): # pragma: no cover
if exc_value is None:
raw_exc_info = sys.exc_info()
else:
if exc_type is None:
exc_type = type(exc_value)
if exc_tb is None:
exc_tb = sys.exc_info()[2]
raw_exc_info = (exc_type, exc_value, exc_tb)
setattr(self, TWISTED_RAW_EXCINFO_ATTR, tuple(raw_exc_info))
try:
original__init__(
self, exc_value, exc_type, exc_tb, captureVars=captureVars
)
except TypeError: # pragma: no cover
original__init__(self, exc_value, exc_type, exc_tb)
with MonkeyPatch.context() as patcher:
patcher.setattr(ut.Failure, "__init__", store_raw_exception_info)
return (yield)
else:
return (yield)
def _handle_twisted_exc_info(
rawexcinfo: _SysExcInfoType | BaseException,
) -> _SysExcInfoType:
"""
Twisted passes a custom Failure instance to `addError()` instead of using `sys.exc_info()`.
Therefore, if `rawexcinfo` is a `Failure` instance, convert it into the equivalent `sys.exc_info()` tuple
as expected by pytest.
"""
twisted_version = _get_twisted_version()
if twisted_version is TwistedVersion.NotInstalled:
# Unfortunately, because we cannot import `twisted.python.failure` at the top of the file
# and use it in the signature, we need to use `type:ignore` here because we cannot narrow
# the type properly in the `if` statement above.
return rawexcinfo # type:ignore[return-value]
elif twisted_version is TwistedVersion.Version24:
# Twisted calls addError() passing its own classes (like `twisted.python.Failure`), which violates
# the `addError()` signature, so we extract the original `sys.exc_info()` tuple which is stored
# in the object.
if hasattr(rawexcinfo, TWISTED_RAW_EXCINFO_ATTR):
saved_exc_info = getattr(rawexcinfo, TWISTED_RAW_EXCINFO_ATTR)
# Delete the attribute from the original object to avoid leaks.
delattr(rawexcinfo, TWISTED_RAW_EXCINFO_ATTR)
return saved_exc_info # type:ignore[no-any-return]
return rawexcinfo # type:ignore[return-value]
elif twisted_version is TwistedVersion.Version25:
if isinstance(rawexcinfo, BaseException):
import twisted.python.failure
if isinstance(rawexcinfo, twisted.python.failure.Failure):
tb = rawexcinfo.__traceback__
if tb is None:
tb = sys.exc_info()[2]
return type(rawexcinfo.value), rawexcinfo.value, tb
return rawexcinfo # type:ignore[return-value]
else:
# Ideally we would use assert_never() here, but it is not available in all Python versions
# we support, plus we do not require `type_extensions` currently.
assert False, f"Unexpected Twisted version: {twisted_version}"

View File

@@ -0,0 +1,163 @@
from __future__ import annotations
import collections
from collections.abc import Callable
import functools
import gc
import sys
import traceback
from typing import NamedTuple
from typing import TYPE_CHECKING
import warnings
from _pytest.config import Config
from _pytest.nodes import Item
from _pytest.stash import StashKey
from _pytest.tracemalloc import tracemalloc_message
import pytest
if TYPE_CHECKING:
pass
if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup
# This is a stash item and not a simple constant to allow pytester to override it.
gc_collect_iterations_key = StashKey[int]()
def gc_collect_harder(iterations: int) -> None:
for _ in range(iterations):
gc.collect()
class UnraisableMeta(NamedTuple):
msg: str
cause_msg: str
exc_value: BaseException | None
unraisable_exceptions: StashKey[collections.deque[UnraisableMeta | BaseException]] = (
StashKey()
)
def collect_unraisable(config: Config) -> None:
pop_unraisable = config.stash[unraisable_exceptions].pop
errors: list[pytest.PytestUnraisableExceptionWarning | RuntimeError] = []
meta = None
hook_error = None
try:
while True:
try:
meta = pop_unraisable()
except IndexError:
break
if isinstance(meta, BaseException):
hook_error = RuntimeError("Failed to process unraisable exception")
hook_error.__cause__ = meta
errors.append(hook_error)
continue
msg = meta.msg
try:
warnings.warn(pytest.PytestUnraisableExceptionWarning(msg))
except pytest.PytestUnraisableExceptionWarning as e:
# This except happens when the warning is treated as an error (e.g. `-Werror`).
if meta.exc_value is not None:
# Exceptions have a better way to show the traceback, but
# warnings do not, so hide the traceback from the msg and
# set the cause so the traceback shows up in the right place.
e.args = (meta.cause_msg,)
e.__cause__ = meta.exc_value
errors.append(e)
if len(errors) == 1:
raise errors[0]
if errors:
raise ExceptionGroup("multiple unraisable exception warnings", errors)
finally:
del errors, meta, hook_error
def cleanup(
*, config: Config, prev_hook: Callable[[sys.UnraisableHookArgs], object]
) -> None:
# A single collection doesn't necessarily collect everything.
# Constant determined experimentally by the Trio project.
gc_collect_iterations = config.stash.get(gc_collect_iterations_key, 5)
try:
try:
gc_collect_harder(gc_collect_iterations)
collect_unraisable(config)
finally:
sys.unraisablehook = prev_hook
finally:
del config.stash[unraisable_exceptions]
def unraisable_hook(
unraisable: sys.UnraisableHookArgs,
/,
*,
append: Callable[[UnraisableMeta | BaseException], object],
) -> None:
try:
# we need to compute these strings here as they might change after
# the unraisablehook finishes and before the metadata object is
# collected by a pytest hook
err_msg = (
"Exception ignored in" if unraisable.err_msg is None else unraisable.err_msg
)
summary = f"{err_msg}: {unraisable.object!r}"
traceback_message = "\n\n" + "".join(
traceback.format_exception(
unraisable.exc_type,
unraisable.exc_value,
unraisable.exc_traceback,
)
)
tracemalloc_tb = "\n" + tracemalloc_message(unraisable.object)
msg = summary + traceback_message + tracemalloc_tb
cause_msg = summary + tracemalloc_tb
append(
UnraisableMeta(
msg=msg,
cause_msg=cause_msg,
exc_value=unraisable.exc_value,
)
)
except BaseException as e:
append(e)
# Raising this will cause the exception to be logged twice, once in our
# collect_unraisable and once by the unraisablehook calling machinery
# which is fine - this should never happen anyway and if it does
# it should probably be reported as a pytest bug.
raise
def pytest_configure(config: Config) -> None:
prev_hook = sys.unraisablehook
deque: collections.deque[UnraisableMeta | BaseException] = collections.deque()
config.stash[unraisable_exceptions] = deque
config.add_cleanup(functools.partial(cleanup, config=config, prev_hook=prev_hook))
sys.unraisablehook = functools.partial(unraisable_hook, append=deque.append)
@pytest.hookimpl(trylast=True)
def pytest_runtest_setup(item: Item) -> None:
collect_unraisable(item.config)
@pytest.hookimpl(trylast=True)
def pytest_runtest_call(item: Item) -> None:
collect_unraisable(item.config)
@pytest.hookimpl(trylast=True)
def pytest_runtest_teardown(item: Item) -> None:
collect_unraisable(item.config)

View File

@@ -0,0 +1,166 @@
from __future__ import annotations
import dataclasses
import inspect
from types import FunctionType
from typing import Any
from typing import final
from typing import Generic
from typing import TypeVar
import warnings
class PytestWarning(UserWarning):
"""Base class for all warnings emitted by pytest."""
__module__ = "pytest"
@final
class PytestAssertRewriteWarning(PytestWarning):
"""Warning emitted by the pytest assert rewrite module."""
__module__ = "pytest"
@final
class PytestCacheWarning(PytestWarning):
"""Warning emitted by the cache plugin in various situations."""
__module__ = "pytest"
@final
class PytestConfigWarning(PytestWarning):
"""Warning emitted for configuration issues."""
__module__ = "pytest"
@final
class PytestCollectionWarning(PytestWarning):
"""Warning emitted when pytest is not able to collect a file or symbol in a module."""
__module__ = "pytest"
class PytestDeprecationWarning(PytestWarning, DeprecationWarning):
"""Warning class for features that will be removed in a future version."""
__module__ = "pytest"
class PytestRemovedIn9Warning(PytestDeprecationWarning):
"""Warning class for features that will be removed in pytest 9."""
__module__ = "pytest"
@final
class PytestExperimentalApiWarning(PytestWarning, FutureWarning):
"""Warning category used to denote experiments in pytest.
Use sparingly as the API might change or even be removed completely in a
future version.
"""
__module__ = "pytest"
@classmethod
def simple(cls, apiname: str) -> PytestExperimentalApiWarning:
return cls(f"{apiname} is an experimental api that may change over time")
@final
class PytestReturnNotNoneWarning(PytestWarning):
"""
Warning emitted when a test function returns a value other than ``None``.
See :ref:`return-not-none` for details.
"""
__module__ = "pytest"
@final
class PytestUnknownMarkWarning(PytestWarning):
"""Warning emitted on use of unknown markers.
See :ref:`mark` for details.
"""
__module__ = "pytest"
@final
class PytestUnraisableExceptionWarning(PytestWarning):
"""An unraisable exception was reported.
Unraisable exceptions are exceptions raised in :meth:`__del__ <object.__del__>`
implementations and similar situations when the exception cannot be raised
as normal.
"""
__module__ = "pytest"
@final
class PytestUnhandledThreadExceptionWarning(PytestWarning):
"""An unhandled exception occurred in a :class:`~threading.Thread`.
Such exceptions don't propagate normally.
"""
__module__ = "pytest"
_W = TypeVar("_W", bound=PytestWarning)
@final
@dataclasses.dataclass
class UnformattedWarning(Generic[_W]):
"""A warning meant to be formatted during runtime.
This is used to hold warnings that need to format their message at runtime,
as opposed to a direct message.
"""
category: type[_W]
template: str
def format(self, **kwargs: Any) -> _W:
"""Return an instance of the warning category, formatted with given kwargs."""
return self.category(self.template.format(**kwargs))
@final
class PytestFDWarning(PytestWarning):
"""When the lsof plugin finds leaked fds."""
__module__ = "pytest"
def warn_explicit_for(method: FunctionType, message: PytestWarning) -> None:
"""
Issue the warning :param:`message` for the definition of the given :param:`method`
this helps to log warnings for functions defined prior to finding an issue with them
(like hook wrappers being marked in a legacy mechanism)
"""
lineno = method.__code__.co_firstlineno
filename = inspect.getfile(method)
module = method.__module__
mod_globals = method.__globals__
try:
warnings.warn_explicit(
message,
type(message),
filename=filename,
module=module,
registry=mod_globals.setdefault("__warningregistry__", {}),
lineno=lineno,
)
except Warning as w:
# If warnings are errors (e.g. -Werror), location information gets lost, so we add it to the message.
raise type(w)(f"{w}\n at {filename}:{lineno}") from None

View File

@@ -0,0 +1,152 @@
# mypy: allow-untyped-defs
from __future__ import annotations
from collections.abc import Generator
from contextlib import contextmanager
from contextlib import ExitStack
import sys
from typing import Literal
import warnings
from _pytest.config import apply_warning_filters
from _pytest.config import Config
from _pytest.config import parse_warning_filter
from _pytest.main import Session
from _pytest.nodes import Item
from _pytest.terminal import TerminalReporter
from _pytest.tracemalloc import tracemalloc_message
import pytest
@contextmanager
def catch_warnings_for_item(
config: Config,
ihook,
when: Literal["config", "collect", "runtest"],
item: Item | None,
*,
record: bool = True,
) -> Generator[None]:
"""Context manager that catches warnings generated in the contained execution block.
``item`` can be None if we are not in the context of an item execution.
Each warning captured triggers the ``pytest_warning_recorded`` hook.
"""
config_filters = config.getini("filterwarnings")
cmdline_filters = config.known_args_namespace.pythonwarnings or []
with warnings.catch_warnings(record=record) as log:
if not sys.warnoptions:
# If user is not explicitly configuring warning filters, show deprecation warnings by default (#2908).
warnings.filterwarnings("always", category=DeprecationWarning)
warnings.filterwarnings("always", category=PendingDeprecationWarning)
# To be enabled in pytest 9.0.0.
# warnings.filterwarnings("error", category=pytest.PytestRemovedIn9Warning)
apply_warning_filters(config_filters, cmdline_filters)
# apply filters from "filterwarnings" marks
nodeid = "" if item is None else item.nodeid
if item is not None:
for mark in item.iter_markers(name="filterwarnings"):
for arg in mark.args:
warnings.filterwarnings(*parse_warning_filter(arg, escape=False))
try:
yield
finally:
if record:
# mypy can't infer that record=True means log is not None; help it.
assert log is not None
for warning_message in log:
ihook.pytest_warning_recorded.call_historic(
kwargs=dict(
warning_message=warning_message,
nodeid=nodeid,
when=when,
location=None,
)
)
def warning_record_to_str(warning_message: warnings.WarningMessage) -> str:
"""Convert a warnings.WarningMessage to a string."""
return warnings.formatwarning(
str(warning_message.message),
warning_message.category,
warning_message.filename,
warning_message.lineno,
warning_message.line,
) + tracemalloc_message(warning_message.source)
@pytest.hookimpl(wrapper=True, tryfirst=True)
def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
with catch_warnings_for_item(
config=item.config, ihook=item.ihook, when="runtest", item=item
):
return (yield)
@pytest.hookimpl(wrapper=True, tryfirst=True)
def pytest_collection(session: Session) -> Generator[None, object, object]:
config = session.config
with catch_warnings_for_item(
config=config, ihook=config.hook, when="collect", item=None
):
return (yield)
@pytest.hookimpl(wrapper=True)
def pytest_terminal_summary(
terminalreporter: TerminalReporter,
) -> Generator[None]:
config = terminalreporter.config
with catch_warnings_for_item(
config=config, ihook=config.hook, when="config", item=None
):
return (yield)
@pytest.hookimpl(wrapper=True)
def pytest_sessionfinish(session: Session) -> Generator[None]:
config = session.config
with catch_warnings_for_item(
config=config, ihook=config.hook, when="config", item=None
):
return (yield)
@pytest.hookimpl(wrapper=True)
def pytest_load_initial_conftests(
early_config: Config,
) -> Generator[None]:
with catch_warnings_for_item(
config=early_config, ihook=early_config.hook, when="config", item=None
):
return (yield)
def pytest_configure(config: Config) -> None:
with ExitStack() as stack:
stack.enter_context(
catch_warnings_for_item(
config=config,
ihook=config.hook,
when="config",
item=None,
# this disables recording because the terminalreporter has
# finished by the time it comes to reporting logged warnings
# from the end of config cleanup. So for now, this is only
# useful for setting a warning filter with an 'error' action.
record=False,
)
)
config.addinivalue_line(
"markers",
"filterwarnings(warning): add a warning filter to the given test. "
"see https://docs.pytest.org/en/stable/how-to/capture-warnings.html#pytest-mark-filterwarnings ",
)
config.add_cleanup(stack.pop_all().close)

View File

@@ -0,0 +1 @@
pip

View File

@@ -0,0 +1,93 @@
Metadata-Version: 2.4
Name: anyio
Version: 4.10.0
Summary: High-level concurrency and networking framework on top of asyncio or Trio
Author-email: Alex Grönholm <alex.gronholm@nextday.fi>
License-Expression: MIT
Project-URL: Documentation, https://anyio.readthedocs.io/en/latest/
Project-URL: Changelog, https://anyio.readthedocs.io/en/stable/versionhistory.html
Project-URL: Source code, https://github.com/agronholm/anyio
Project-URL: Issue tracker, https://github.com/agronholm/anyio/issues
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Developers
Classifier: Framework :: AnyIO
Classifier: Typing :: Typed
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Programming Language :: Python :: 3.14
Requires-Python: >=3.9
Description-Content-Type: text/x-rst
License-File: LICENSE
Requires-Dist: exceptiongroup>=1.0.2; python_version < "3.11"
Requires-Dist: idna>=2.8
Requires-Dist: sniffio>=1.1
Requires-Dist: typing_extensions>=4.5; python_version < "3.13"
Provides-Extra: trio
Requires-Dist: trio>=0.26.1; extra == "trio"
Dynamic: license-file
.. image:: https://github.com/agronholm/anyio/actions/workflows/test.yml/badge.svg
:target: https://github.com/agronholm/anyio/actions/workflows/test.yml
:alt: Build Status
.. image:: https://coveralls.io/repos/github/agronholm/anyio/badge.svg?branch=master
:target: https://coveralls.io/github/agronholm/anyio?branch=master
:alt: Code Coverage
.. image:: https://readthedocs.org/projects/anyio/badge/?version=latest
:target: https://anyio.readthedocs.io/en/latest/?badge=latest
:alt: Documentation
.. image:: https://badges.gitter.im/gitterHQ/gitter.svg
:target: https://gitter.im/python-trio/AnyIO
:alt: Gitter chat
AnyIO is an asynchronous networking and concurrency library that works on top of either asyncio_ or
Trio_. It implements Trio-like `structured concurrency`_ (SC) on top of asyncio and works in harmony
with the native SC of Trio itself.
Applications and libraries written against AnyIO's API will run unmodified on either asyncio_ or
Trio_. AnyIO can also be adopted into a library or application incrementally bit by bit, no full
refactoring necessary. It will blend in with the native libraries of your chosen backend.
To find out why you might want to use AnyIO's APIs instead of asyncio's, you can read about it
`here <https://anyio.readthedocs.io/en/stable/why.html>`_.
Documentation
-------------
View full documentation at: https://anyio.readthedocs.io/
Features
--------
AnyIO offers the following functionality:
* Task groups (nurseries_ in trio terminology)
* High-level networking (TCP, UDP and UNIX sockets)
* `Happy eyeballs`_ algorithm for TCP connections (more robust than that of asyncio on Python
3.8)
* async/await style UDP sockets (unlike asyncio where you still have to use Transports and
Protocols)
* A versatile API for byte streams and object streams
* Inter-task synchronization and communication (locks, conditions, events, semaphores, object
streams)
* Worker threads
* Subprocesses
* Asynchronous file I/O (using worker threads)
* Signal handling
AnyIO also comes with its own pytest_ plugin which also supports asynchronous fixtures.
It even works with the popular Hypothesis_ library.
.. _asyncio: https://docs.python.org/3/library/asyncio.html
.. _Trio: https://github.com/python-trio/trio
.. _structured concurrency: https://en.wikipedia.org/wiki/Structured_concurrency
.. _nurseries: https://trio.readthedocs.io/en/stable/reference-core.html#nurseries-and-spawning
.. _Happy eyeballs: https://en.wikipedia.org/wiki/Happy_Eyeballs
.. _pytest: https://docs.pytest.org/en/latest/
.. _Hypothesis: https://hypothesis.works/

View File

@@ -0,0 +1,90 @@
anyio-4.10.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
anyio-4.10.0.dist-info/METADATA,sha256=1AD_60gPgqxWKsO54FUTbKDQHyni5j_56_XQinKJ9LQ,4014
anyio-4.10.0.dist-info/RECORD,,
anyio-4.10.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
anyio-4.10.0.dist-info/entry_points.txt,sha256=_d6Yu6uiaZmNe0CydowirE9Cmg7zUL2g08tQpoS3Qvc,39
anyio-4.10.0.dist-info/licenses/LICENSE,sha256=U2GsncWPLvX9LpsJxoKXwX8ElQkJu8gCO9uC6s8iwrA,1081
anyio-4.10.0.dist-info/top_level.txt,sha256=QglSMiWX8_5dpoVAEIHdEYzvqFMdSYWmCj6tYw2ITkQ,6
anyio/__init__.py,sha256=z3IyWgWQuxCi-KUwma-1LSys4WB50mV2N8FvS9_IePE,5955
anyio/__pycache__/__init__.cpython-312.pyc,,
anyio/__pycache__/from_thread.cpython-312.pyc,,
anyio/__pycache__/lowlevel.cpython-312.pyc,,
anyio/__pycache__/pytest_plugin.cpython-312.pyc,,
anyio/__pycache__/to_interpreter.cpython-312.pyc,,
anyio/__pycache__/to_process.cpython-312.pyc,,
anyio/__pycache__/to_thread.cpython-312.pyc,,
anyio/_backends/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
anyio/_backends/__pycache__/__init__.cpython-312.pyc,,
anyio/_backends/__pycache__/_asyncio.cpython-312.pyc,,
anyio/_backends/__pycache__/_trio.cpython-312.pyc,,
anyio/_backends/_asyncio.py,sha256=YXpQJ0C-tNiYvZdElVa3zGflG_Jdvf7FNDiG9-THhMg,97359
anyio/_backends/_trio.py,sha256=tRGDtos6xmqmGlstfI8wEjGvhZq0y_SYTaM2m8zatRU,41963
anyio/_core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
anyio/_core/__pycache__/__init__.cpython-312.pyc,,
anyio/_core/__pycache__/_asyncio_selector_thread.cpython-312.pyc,,
anyio/_core/__pycache__/_contextmanagers.cpython-312.pyc,,
anyio/_core/__pycache__/_eventloop.cpython-312.pyc,,
anyio/_core/__pycache__/_exceptions.cpython-312.pyc,,
anyio/_core/__pycache__/_fileio.cpython-312.pyc,,
anyio/_core/__pycache__/_resources.cpython-312.pyc,,
anyio/_core/__pycache__/_signals.cpython-312.pyc,,
anyio/_core/__pycache__/_sockets.cpython-312.pyc,,
anyio/_core/__pycache__/_streams.cpython-312.pyc,,
anyio/_core/__pycache__/_subprocesses.cpython-312.pyc,,
anyio/_core/__pycache__/_synchronization.cpython-312.pyc,,
anyio/_core/__pycache__/_tasks.cpython-312.pyc,,
anyio/_core/__pycache__/_tempfile.cpython-312.pyc,,
anyio/_core/__pycache__/_testing.cpython-312.pyc,,
anyio/_core/__pycache__/_typedattr.cpython-312.pyc,,
anyio/_core/_asyncio_selector_thread.py,sha256=2PdxFM3cs02Kp6BSppbvmRT7q7asreTW5FgBxEsflBo,5626
anyio/_core/_contextmanagers.py,sha256=YInBCabiEeS-UaP_Jdxa1CaFC71ETPW8HZTHIM8Rsc8,7215
anyio/_core/_eventloop.py,sha256=t_tAwBFPjF8jrZGjlJ6bbYy6KA3bjsbZxV9mvh9t1i0,4695
anyio/_core/_exceptions.py,sha256=uQ9yXs3gRghZiuxiWtbvVlHB6CvCRtxObKMWF-Mnz18,3683
anyio/_core/_fileio.py,sha256=KATysDZP7bvwwjpUwEaGAc0xGouJgAPqNVpnBMTsToY,23332
anyio/_core/_resources.py,sha256=NbmU5O5UX3xEyACnkmYX28Fmwdl-f-ny0tHym26e0w0,435
anyio/_core/_signals.py,sha256=vulT1M1xdLYtAR-eY5TamIgaf1WTlOwOrMGwswlTTr8,905
anyio/_core/_sockets.py,sha256=MRo3vVzBLnWwA0DqjWhJ2ICj_XKQ78BtWxdrSAwKcxU,32232
anyio/_core/_streams.py,sha256=OnaKgoDD-FcMSwLvkoAUGP51sG2ZdRvMpxt9q2w1gYA,1804
anyio/_core/_subprocesses.py,sha256=EXm5igL7dj55iYkPlbYVAqtbqxJxjU-6OndSTIx9SRg,8047
anyio/_core/_synchronization.py,sha256=76KyUbGD3A3eCXPrLnOccQfRsNSxIcoR36JeK1P4VFQ,20306
anyio/_core/_tasks.py,sha256=f3CuWwo06cCZ6jaOv-JHFKWkgpgf2cvaF25Oh4augMA,4757
anyio/_core/_tempfile.py,sha256=lHb7CW4FyIlpkf5ADAf4VmLHCKwEHF9nxqNyBCFFUiA,19697
anyio/_core/_testing.py,sha256=YUGwA5cgFFbUTv4WFd7cv_BSVr4ryTtPp8owQA3JdWE,2118
anyio/_core/_typedattr.py,sha256=P4ozZikn3-DbpoYcvyghS_FOYAgbmUxeoU8-L_07pZM,2508
anyio/abc/__init__.py,sha256=6mWhcl_pGXhrgZVHP_TCfMvIXIOp9mroEFM90fYCU_U,2869
anyio/abc/__pycache__/__init__.cpython-312.pyc,,
anyio/abc/__pycache__/_eventloop.cpython-312.pyc,,
anyio/abc/__pycache__/_resources.cpython-312.pyc,,
anyio/abc/__pycache__/_sockets.cpython-312.pyc,,
anyio/abc/__pycache__/_streams.cpython-312.pyc,,
anyio/abc/__pycache__/_subprocesses.cpython-312.pyc,,
anyio/abc/__pycache__/_tasks.cpython-312.pyc,,
anyio/abc/__pycache__/_testing.cpython-312.pyc,,
anyio/abc/_eventloop.py,sha256=_rrVDoNAS9yIFvSE70ewoppYd_9zNbRjPFl5UPMSR8I,10729
anyio/abc/_resources.py,sha256=DrYvkNN1hH6Uvv5_5uKySvDsnknGVDe8FCKfko0VtN8,783
anyio/abc/_sockets.py,sha256=ECTY0jLEF18gryANHR3vFzXzGdZ-xPwELq1QdgOb0Jo,13258
anyio/abc/_streams.py,sha256=005GKSCXGprxnhucILboSqc2JFovECZk9m3p-qqxXVc,7640
anyio/abc/_subprocesses.py,sha256=cumAPJTktOQtw63IqG0lDpyZqu_l1EElvQHMiwJgL08,2067
anyio/abc/_tasks.py,sha256=Jh4LXVz1DoRacOnw1rwAS9wujNiEWK9oqdF0cTEhhNA,3604
anyio/abc/_testing.py,sha256=tBJUzkSfOXJw23fe8qSJ03kJlShOYjjaEyFB6k6MYT8,1821
anyio/from_thread.py,sha256=t8B_amqFBqlJy8X18mhvpYkhzeSXmRsI-ep6Yg04H4M,17678
anyio/lowlevel.py,sha256=IisVkje5kwqOCpe-RgBjGCvlr-JBFGBrkobR7iZ3Fv4,4153
anyio/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
anyio/pytest_plugin.py,sha256=qXNwk9Pa7hPQKWocgLl9qijqKGMkGzdH2wJa-jPkGUM,9375
anyio/streams/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
anyio/streams/__pycache__/__init__.cpython-312.pyc,,
anyio/streams/__pycache__/buffered.cpython-312.pyc,,
anyio/streams/__pycache__/file.cpython-312.pyc,,
anyio/streams/__pycache__/memory.cpython-312.pyc,,
anyio/streams/__pycache__/stapled.cpython-312.pyc,,
anyio/streams/__pycache__/text.cpython-312.pyc,,
anyio/streams/__pycache__/tls.cpython-312.pyc,,
anyio/streams/buffered.py,sha256=joUPdz0OoRfKgGmMpHI9vZyMNm6ly9iFlofrZUPs9cQ,6162
anyio/streams/file.py,sha256=6uoTNb5KbMoj-6gS3_xrrL8uZN8Q4iIvOS1WtGyFfKw,4383
anyio/streams/memory.py,sha256=GcbF3cahdsdFZtcTZaIKpZXPDZKogj18wWPPmE0OmGU,10620
anyio/streams/stapled.py,sha256=U09pCrmOw9kkNhe6tKopsm1QIMT1lFTFvtb-A7SIe4k,4302
anyio/streams/text.py,sha256=tCJ8ljavGM-HY0aL-5Twxv-Kyw1BfR0B4OtVIB6kZ9w,5662
anyio/streams/tls.py,sha256=siSaaRyX-XnfC7Jbn9VjtIdVzJkDsvIW_2pSEVheDFQ,15275
anyio/to_interpreter.py,sha256=Z0-kLCxlITjFG_RM_TNdUlEnog94l48GXVDZ80w0URc,6986
anyio/to_process.py,sha256=ZvruelRM-HNmqDaql4sdNODg2QD_uSlwSCxnV4OhsfQ,9595
anyio/to_thread.py,sha256=WM2JQ2MbVsd5D5CM08bQiTwzZIvpsGjfH1Fy247KoDQ,2396

View File

@@ -0,0 +1,5 @@
Wheel-Version: 1.0
Generator: setuptools (80.9.0)
Root-Is-Purelib: true
Tag: py3-none-any

View File

@@ -0,0 +1,2 @@
[pytest11]
anyio = anyio.pytest_plugin

View File

@@ -0,0 +1,20 @@
The MIT License (MIT)
Copyright (c) 2018 Alex Grönholm
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@@ -0,0 +1,108 @@
from __future__ import annotations
from ._core._contextmanagers import AsyncContextManagerMixin as AsyncContextManagerMixin
from ._core._contextmanagers import ContextManagerMixin as ContextManagerMixin
from ._core._eventloop import current_time as current_time
from ._core._eventloop import get_all_backends as get_all_backends
from ._core._eventloop import get_cancelled_exc_class as get_cancelled_exc_class
from ._core._eventloop import run as run
from ._core._eventloop import sleep as sleep
from ._core._eventloop import sleep_forever as sleep_forever
from ._core._eventloop import sleep_until as sleep_until
from ._core._exceptions import BrokenResourceError as BrokenResourceError
from ._core._exceptions import BrokenWorkerInterpreter as BrokenWorkerInterpreter
from ._core._exceptions import BrokenWorkerProcess as BrokenWorkerProcess
from ._core._exceptions import BusyResourceError as BusyResourceError
from ._core._exceptions import ClosedResourceError as ClosedResourceError
from ._core._exceptions import ConnectionFailed as ConnectionFailed
from ._core._exceptions import DelimiterNotFound as DelimiterNotFound
from ._core._exceptions import EndOfStream as EndOfStream
from ._core._exceptions import IncompleteRead as IncompleteRead
from ._core._exceptions import TypedAttributeLookupError as TypedAttributeLookupError
from ._core._exceptions import WouldBlock as WouldBlock
from ._core._fileio import AsyncFile as AsyncFile
from ._core._fileio import Path as Path
from ._core._fileio import open_file as open_file
from ._core._fileio import wrap_file as wrap_file
from ._core._resources import aclose_forcefully as aclose_forcefully
from ._core._signals import open_signal_receiver as open_signal_receiver
from ._core._sockets import TCPConnectable as TCPConnectable
from ._core._sockets import UNIXConnectable as UNIXConnectable
from ._core._sockets import as_connectable as as_connectable
from ._core._sockets import connect_tcp as connect_tcp
from ._core._sockets import connect_unix as connect_unix
from ._core._sockets import create_connected_udp_socket as create_connected_udp_socket
from ._core._sockets import (
create_connected_unix_datagram_socket as create_connected_unix_datagram_socket,
)
from ._core._sockets import create_tcp_listener as create_tcp_listener
from ._core._sockets import create_udp_socket as create_udp_socket
from ._core._sockets import create_unix_datagram_socket as create_unix_datagram_socket
from ._core._sockets import create_unix_listener as create_unix_listener
from ._core._sockets import getaddrinfo as getaddrinfo
from ._core._sockets import getnameinfo as getnameinfo
from ._core._sockets import notify_closing as notify_closing
from ._core._sockets import wait_readable as wait_readable
from ._core._sockets import wait_socket_readable as wait_socket_readable
from ._core._sockets import wait_socket_writable as wait_socket_writable
from ._core._sockets import wait_writable as wait_writable
from ._core._streams import create_memory_object_stream as create_memory_object_stream
from ._core._subprocesses import open_process as open_process
from ._core._subprocesses import run_process as run_process
from ._core._synchronization import CapacityLimiter as CapacityLimiter
from ._core._synchronization import (
CapacityLimiterStatistics as CapacityLimiterStatistics,
)
from ._core._synchronization import Condition as Condition
from ._core._synchronization import ConditionStatistics as ConditionStatistics
from ._core._synchronization import Event as Event
from ._core._synchronization import EventStatistics as EventStatistics
from ._core._synchronization import Lock as Lock
from ._core._synchronization import LockStatistics as LockStatistics
from ._core._synchronization import ResourceGuard as ResourceGuard
from ._core._synchronization import Semaphore as Semaphore
from ._core._synchronization import SemaphoreStatistics as SemaphoreStatistics
from ._core._tasks import TASK_STATUS_IGNORED as TASK_STATUS_IGNORED
from ._core._tasks import CancelScope as CancelScope
from ._core._tasks import create_task_group as create_task_group
from ._core._tasks import current_effective_deadline as current_effective_deadline
from ._core._tasks import fail_after as fail_after
from ._core._tasks import move_on_after as move_on_after
from ._core._tempfile import NamedTemporaryFile as NamedTemporaryFile
from ._core._tempfile import SpooledTemporaryFile as SpooledTemporaryFile
from ._core._tempfile import TemporaryDirectory as TemporaryDirectory
from ._core._tempfile import TemporaryFile as TemporaryFile
from ._core._tempfile import gettempdir as gettempdir
from ._core._tempfile import gettempdirb as gettempdirb
from ._core._tempfile import mkdtemp as mkdtemp
from ._core._tempfile import mkstemp as mkstemp
from ._core._testing import TaskInfo as TaskInfo
from ._core._testing import get_current_task as get_current_task
from ._core._testing import get_running_tasks as get_running_tasks
from ._core._testing import wait_all_tasks_blocked as wait_all_tasks_blocked
from ._core._typedattr import TypedAttributeProvider as TypedAttributeProvider
from ._core._typedattr import TypedAttributeSet as TypedAttributeSet
from ._core._typedattr import typed_attribute as typed_attribute
# Re-export imports so they look like they live directly in this package
for __value in list(locals().values()):
if getattr(__value, "__module__", "").startswith("anyio."):
__value.__module__ = __name__
del __value
def __getattr__(attr: str) -> type[BrokenWorkerInterpreter]:
"""Support deprecated aliases."""
if attr == "BrokenWorkerIntepreter":
import warnings
warnings.warn(
"The 'BrokenWorkerIntepreter' alias is deprecated, use 'BrokenWorkerInterpreter' instead.",
DeprecationWarning,
stacklevel=2,
)
return BrokenWorkerInterpreter
raise AttributeError(f"module {__name__!r} has no attribute {attr!r}")

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,167 @@
from __future__ import annotations
import asyncio
import socket
import threading
from collections.abc import Callable
from selectors import EVENT_READ, EVENT_WRITE, DefaultSelector
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from _typeshed import FileDescriptorLike
_selector_lock = threading.Lock()
_selector: Selector | None = None
class Selector:
def __init__(self) -> None:
self._thread = threading.Thread(target=self.run, name="AnyIO socket selector")
self._selector = DefaultSelector()
self._send, self._receive = socket.socketpair()
self._send.setblocking(False)
self._receive.setblocking(False)
# This somewhat reduces the amount of memory wasted queueing up data
# for wakeups. With these settings, maximum number of 1-byte sends
# before getting BlockingIOError:
# Linux 4.8: 6
# macOS (darwin 15.5): 1
# Windows 10: 525347
# Windows you're weird. (And on Windows setting SNDBUF to 0 makes send
# blocking, even on non-blocking sockets, so don't do that.)
self._receive.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1)
self._send.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1)
# On Windows this is a TCP socket so this might matter. On other
# platforms this fails b/c AF_UNIX sockets aren't actually TCP.
try:
self._send.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
except OSError:
pass
self._selector.register(self._receive, EVENT_READ)
self._closed = False
def start(self) -> None:
self._thread.start()
threading._register_atexit(self._stop) # type: ignore[attr-defined]
def _stop(self) -> None:
global _selector
self._closed = True
self._notify_self()
self._send.close()
self._thread.join()
self._selector.unregister(self._receive)
self._receive.close()
self._selector.close()
_selector = None
assert not self._selector.get_map(), (
"selector still has registered file descriptors after shutdown"
)
def _notify_self(self) -> None:
try:
self._send.send(b"\x00")
except BlockingIOError:
pass
def add_reader(self, fd: FileDescriptorLike, callback: Callable[[], Any]) -> None:
loop = asyncio.get_running_loop()
try:
key = self._selector.get_key(fd)
except KeyError:
self._selector.register(fd, EVENT_READ, {EVENT_READ: (loop, callback)})
else:
if EVENT_READ in key.data:
raise ValueError(
"this file descriptor is already registered for reading"
)
key.data[EVENT_READ] = loop, callback
self._selector.modify(fd, key.events | EVENT_READ, key.data)
self._notify_self()
def add_writer(self, fd: FileDescriptorLike, callback: Callable[[], Any]) -> None:
loop = asyncio.get_running_loop()
try:
key = self._selector.get_key(fd)
except KeyError:
self._selector.register(fd, EVENT_WRITE, {EVENT_WRITE: (loop, callback)})
else:
if EVENT_WRITE in key.data:
raise ValueError(
"this file descriptor is already registered for writing"
)
key.data[EVENT_WRITE] = loop, callback
self._selector.modify(fd, key.events | EVENT_WRITE, key.data)
self._notify_self()
def remove_reader(self, fd: FileDescriptorLike) -> bool:
try:
key = self._selector.get_key(fd)
except KeyError:
return False
if new_events := key.events ^ EVENT_READ:
del key.data[EVENT_READ]
self._selector.modify(fd, new_events, key.data)
else:
self._selector.unregister(fd)
return True
def remove_writer(self, fd: FileDescriptorLike) -> bool:
try:
key = self._selector.get_key(fd)
except KeyError:
return False
if new_events := key.events ^ EVENT_WRITE:
del key.data[EVENT_WRITE]
self._selector.modify(fd, new_events, key.data)
else:
self._selector.unregister(fd)
return True
def run(self) -> None:
while not self._closed:
for key, events in self._selector.select():
if key.fileobj is self._receive:
try:
while self._receive.recv(4096):
pass
except BlockingIOError:
pass
continue
if events & EVENT_READ:
loop, callback = key.data[EVENT_READ]
self.remove_reader(key.fd)
try:
loop.call_soon_threadsafe(callback)
except RuntimeError:
pass # the loop was already closed
if events & EVENT_WRITE:
loop, callback = key.data[EVENT_WRITE]
self.remove_writer(key.fd)
try:
loop.call_soon_threadsafe(callback)
except RuntimeError:
pass # the loop was already closed
def get_selector() -> Selector:
global _selector
with _selector_lock:
if _selector is None:
_selector = Selector()
_selector.start()
return _selector

View File

@@ -0,0 +1,200 @@
from __future__ import annotations
from abc import abstractmethod
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from inspect import isasyncgen, iscoroutine, isgenerator
from types import TracebackType
from typing import Protocol, TypeVar, cast, final
_T_co = TypeVar("_T_co", covariant=True)
_ExitT_co = TypeVar("_ExitT_co", covariant=True, bound="bool | None")
class _SupportsCtxMgr(Protocol[_T_co, _ExitT_co]):
def __contextmanager__(self) -> AbstractContextManager[_T_co, _ExitT_co]: ...
class _SupportsAsyncCtxMgr(Protocol[_T_co, _ExitT_co]):
def __asynccontextmanager__(
self,
) -> AbstractAsyncContextManager[_T_co, _ExitT_co]: ...
class ContextManagerMixin:
"""
Mixin class providing context manager functionality via a generator-based
implementation.
This class allows you to implement a context manager via :meth:`__contextmanager__`
which should return a generator. The mechanics are meant to mirror those of
:func:`@contextmanager <contextlib.contextmanager>`.
.. note:: Classes using this mix-in are not reentrant as context managers, meaning
that once you enter it, you can't re-enter before first exiting it.
.. seealso:: :doc:`contextmanagers`
"""
__cm: AbstractContextManager[object, bool | None] | None = None
@final
def __enter__(self: _SupportsCtxMgr[_T_co, bool | None]) -> _T_co:
# Needed for mypy to assume self still has the __cm member
assert isinstance(self, ContextManagerMixin)
if self.__cm is not None:
raise RuntimeError(
f"this {self.__class__.__qualname__} has already been entered"
)
cm = self.__contextmanager__()
if not isinstance(cm, AbstractContextManager):
if isgenerator(cm):
raise TypeError(
"__contextmanager__() returned a generator object instead of "
"a context manager. Did you forget to add the @contextmanager "
"decorator?"
)
raise TypeError(
f"__contextmanager__() did not return a context manager object, "
f"but {cm.__class__!r}"
)
if cm is self:
raise TypeError(
f"{self.__class__.__qualname__}.__contextmanager__() returned "
f"self. Did you forget to add the @contextmanager decorator and a "
f"'yield' statement?"
)
value = cm.__enter__()
self.__cm = cm
return value
@final
def __exit__(
self: _SupportsCtxMgr[object, _ExitT_co],
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> _ExitT_co:
# Needed for mypy to assume self still has the __cm member
assert isinstance(self, ContextManagerMixin)
if self.__cm is None:
raise RuntimeError(
f"this {self.__class__.__qualname__} has not been entered yet"
)
# Prevent circular references
cm = self.__cm
del self.__cm
return cast(_ExitT_co, cm.__exit__(exc_type, exc_val, exc_tb))
@abstractmethod
def __contextmanager__(self) -> AbstractContextManager[object, bool | None]:
"""
Implement your context manager logic here.
This method **must** be decorated with
:func:`@contextmanager <contextlib.contextmanager>`.
.. note:: Remember that the ``yield`` will raise any exception raised in the
enclosed context block, so use a ``finally:`` block to clean up resources!
:return: a context manager object
"""
class AsyncContextManagerMixin:
"""
Mixin class providing async context manager functionality via a generator-based
implementation.
This class allows you to implement a context manager via
:meth:`__asynccontextmanager__`. The mechanics are meant to mirror those of
:func:`@asynccontextmanager <contextlib.asynccontextmanager>`.
.. note:: Classes using this mix-in are not reentrant as context managers, meaning
that once you enter it, you can't re-enter before first exiting it.
.. seealso:: :doc:`contextmanagers`
"""
__cm: AbstractAsyncContextManager[object, bool | None] | None = None
@final
async def __aenter__(self: _SupportsAsyncCtxMgr[_T_co, bool | None]) -> _T_co:
# Needed for mypy to assume self still has the __cm member
assert isinstance(self, AsyncContextManagerMixin)
if self.__cm is not None:
raise RuntimeError(
f"this {self.__class__.__qualname__} has already been entered"
)
cm = self.__asynccontextmanager__()
if not isinstance(cm, AbstractAsyncContextManager):
if isasyncgen(cm):
raise TypeError(
"__asynccontextmanager__() returned an async generator instead of "
"an async context manager. Did you forget to add the "
"@asynccontextmanager decorator?"
)
elif iscoroutine(cm):
cm.close()
raise TypeError(
"__asynccontextmanager__() returned a coroutine object instead of "
"an async context manager. Did you forget to add the "
"@asynccontextmanager decorator and a 'yield' statement?"
)
raise TypeError(
f"__asynccontextmanager__() did not return an async context manager, "
f"but {cm.__class__!r}"
)
if cm is self:
raise TypeError(
f"{self.__class__.__qualname__}.__asynccontextmanager__() returned "
f"self. Did you forget to add the @asynccontextmanager decorator and a "
f"'yield' statement?"
)
value = await cm.__aenter__()
self.__cm = cm
return value
@final
async def __aexit__(
self: _SupportsAsyncCtxMgr[object, _ExitT_co],
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> _ExitT_co:
assert isinstance(self, AsyncContextManagerMixin)
if self.__cm is None:
raise RuntimeError(
f"this {self.__class__.__qualname__} has not been entered yet"
)
# Prevent circular references
cm = self.__cm
del self.__cm
return cast(_ExitT_co, await cm.__aexit__(exc_type, exc_val, exc_tb))
@abstractmethod
def __asynccontextmanager__(
self,
) -> AbstractAsyncContextManager[object, bool | None]:
"""
Implement your async context manager logic here.
This method **must** be decorated with
:func:`@asynccontextmanager <contextlib.asynccontextmanager>`.
.. note:: Remember that the ``yield`` will raise any exception raised in the
enclosed context block, so use a ``finally:`` block to clean up resources!
:return: an async context manager object
"""

Some files were not shown because too many files have changed in this diff Show More