.gitignore (vendored)
@@ -24,4 +24,6 @@ bot.db
 tests/
 
 # Git
 .git/
+.venv/
+.history/
@@ -1,13 +0,0 @@
-__pycache__/
-*.pyc
-*.pyo
-*.pyd
-.env
-.env.local
-.env.*.local
-bot.db
-.idea/
-.vscode/
-tests/
-.git/
-.gitignore
@@ -1,14 +0,0 @@
-kind: pipeline
-name: default
-
-steps:
-  - name: build
-    image: docker:dind
-    privileged: true
-    commands:
-      - docker build -t post_bot .
-  - name: test
-    image: python:3.11-slim
-    commands:
-      - pip install --no-cache-dir python-telegram-bot sqlalchemy python-dotenv pytest
-      - pytest tests/
@@ -1,2 +0,0 @@
-TELEGRAM_TOKEN=your_token_here
-DATABASE_URL=sqlite:///bot.db
@@ -1,2 +0,0 @@
-TELEGRAM_TOKEN=6414100562:AAFxeXt331_sKf8ui1EJve9vinUHyKHBjiU
-DATABASE_URL=sqlite:///bot.db
@@ -1,2 +0,0 @@
-TELEGRAM_TOKEN=6414100562:AAEbDJZnFFJfddS0SV1xA4L0MQqxOArU4a0
-DATABASE_URL=sqlite:///bot.db
@@ -1,5 +0,0 @@
-.env
-.venv/
-.history
-__pycache__/
-.db
@@ -1,27 +0,0 @@
-
-# Python
-__pycache__/
-*.pyc
-*.pyo
-*.pyd
-
-# Env
-.env
-.env.local
-.env.*.local
-
-# DB
-bot.db
-
-# IDE
-.idea/
-.vscode/
-
-# Docker
-*.log
-
-# Tests
-tests/
-
-# Git
-.git/
@@ -1,5 +0,0 @@
-FROM python:3.11-slim
-WORKDIR /app
-COPY . /app
-RUN pip install --no-cache-dir python-telegram-bot sqlalchemy python-dotenv
-CMD ["python", "main.py"]
@@ -1,13 +0,0 @@
-import os
-from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker
-from models import Base
-from dotenv import load_dotenv
-
-load_dotenv()
-DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///bot.db')
-engine = create_engine(DATABASE_URL, echo=True)
-SessionLocal = sessionmaker(bind=engine)
-
-def init_db():
-    Base.metadata.create_all(engine)
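
For orientation: none of the deleted files includes the entry point, and the Dockerfile's CMD refers to a main.py that is not part of this diff. The sketch below is an assumption about how the deleted pieces (init_db and the conversation handlers) would typically be wired together with python-telegram-bot v20; the import paths and the handler list are illustrative, not taken from the commit.

import os
from dotenv import load_dotenv
from telegram.ext import ApplicationBuilder

from db import init_db                    # deleted module shown above
from add_button import add_button_conv    # assumed import path for the deleted handler

load_dotenv()   # reads TELEGRAM_TOKEN / DATABASE_URL from .env
init_db()       # create the SQLAlchemy tables before polling starts

app = ApplicationBuilder().token(os.environ['TELEGRAM_TOKEN']).build()
app.add_handler(add_button_conv)  # the other conversations/commands would be registered the same way
app.run_polling()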
@@ -1,9 +0,0 @@
-version: '3.8'
-services:
-  bot:
-    build: .
-    env_file:
-      - .env.example
-    volumes:
-      - ./bot.db:/app/bot.db
-    restart: unless-stopped
@@ -1,63 +0,0 @@
-from telegram import Update
-from telegram.ext import ContextTypes, ConversationHandler
-from db import SessionLocal
-from models import Channel, Group, Button
-
-SELECT_TARGET, INPUT_NAME, INPUT_URL = range(3)
-
-async def add_button_start(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    session = SessionLocal()
-    channels = session.query(Channel).all()
-    groups = session.query(Group).all()
-    session.close()
-    keyboard = []
-    for c in channels:
-        keyboard.append([{'text': f'Канал: {c.name}', 'callback_data': f'channel_{c.id}'}])
-    for g in groups:
-        keyboard.append([{'text': f'Группа: {g.name}', 'callback_data': f'group_{g.id}'}])
-    if not keyboard:
-        await update.message.reply_text('Нет каналов или групп для добавления кнопки.')
-        return ConversationHandler.END
-    await update.message.reply_text('Выберите канал или группу:', reply_markup=None)
-    context.user_data['keyboard'] = keyboard
-    return SELECT_TARGET
-
-async def select_target(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    query = update.callback_query
-    await query.answer()
-    data = query.data
-    context.user_data['target'] = data
-    await query.edit_message_text('Введите название кнопки:')
-    return INPUT_NAME
-
-async def input_name(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    context.user_data['button_name'] = update.message.text
-    await update.message.reply_text('Введите ссылку для кнопки:')
-    return INPUT_URL
-
-async def input_url(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    url = update.message.text
-    name = context.user_data['button_name']
-    target = context.user_data['target']
-    session = SessionLocal()
-    if target.startswith('channel_'):
-        channel_id = int(target.split('_')[1])
-        button = Button(name=name, url=url, channel_id=channel_id)
-    else:
-        group_id = int(target.split('_')[1])
-        button = Button(name=name, url=url, group_id=group_id)
-    session.add(button)
-    session.commit()
-    session.close()
-    await update.message.reply_text('Кнопка добавлена.')
-    return ConversationHandler.END
-
-add_button_conv = ConversationHandler(
-    entry_points=[CommandHandler('add_button', add_button_start)],
-    states={
-        SELECT_TARGET: [CallbackQueryHandler(select_target)],
-        INPUT_NAME: [MessageHandler(filters.TEXT & ~filters.COMMAND, input_name)],
-        INPUT_URL: [MessageHandler(filters.TEXT & ~filters.COMMAND, input_url)],
-    },
-    fallbacks=[]
-)
@@ -1,63 +0,0 @@
-from telegram import Update
-from telegram.ext import ContextTypes, ConversationHandler, CommandHandler, CallbackQueryHandler, MessageHandler, filters
-from db import SessionLocal
-from models import Channel, Group, Button
-
-SELECT_TARGET, INPUT_NAME, INPUT_URL = range(3)
-
-async def add_button_start(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    session = SessionLocal()
-    channels = session.query(Channel).all()
-    groups = session.query(Group).all()
-    session.close()
-    keyboard = []
-    for c in channels:
-        keyboard.append([{'text': f'Канал: {c.name}', 'callback_data': f'channel_{c.id}'}])
-    for g in groups:
-        keyboard.append([{'text': f'Группа: {g.name}', 'callback_data': f'group_{g.id}'}])
-    if not keyboard:
-        await update.message.reply_text('Нет каналов или групп для добавления кнопки.')
-        return ConversationHandler.END
-    await update.message.reply_text('Выберите канал или группу:', reply_markup=None)
-    context.user_data['keyboard'] = keyboard
-    return SELECT_TARGET
-
-async def select_target(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    query = update.callback_query
-    await query.answer()
-    data = query.data
-    context.user_data['target'] = data
-    await query.edit_message_text('Введите название кнопки:')
-    return INPUT_NAME
-
-async def input_name(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    context.user_data['button_name'] = update.message.text
-    await update.message.reply_text('Введите ссылку для кнопки:')
-    return INPUT_URL
-
-async def input_url(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    url = update.message.text
-    name = context.user_data['button_name']
-    target = context.user_data['target']
-    session = SessionLocal()
-    if target.startswith('channel_'):
-        channel_id = int(target.split('_')[1])
-        button = Button(name=name, url=url, channel_id=channel_id)
-    else:
-        group_id = int(target.split('_')[1])
-        button = Button(name=name, url=url, group_id=group_id)
-    session.add(button)
-    session.commit()
-    session.close()
-    await update.message.reply_text('Кнопка добавлена.')
-    return ConversationHandler.END
-
-add_button_conv = ConversationHandler(
-    entry_points=[CommandHandler('add_button', add_button_start)],
-    states={
-        SELECT_TARGET: [CallbackQueryHandler(select_target)],
-        INPUT_NAME: [MessageHandler(filters.TEXT & ~filters.COMMAND, input_name)],
-        INPUT_URL: [MessageHandler(filters.TEXT & ~filters.COMMAND, input_url)],
-    },
-    fallbacks=[]
-)
@@ -1,73 +0,0 @@
-from telegram import Update
-from telegram.ext import ContextTypes, ConversationHandler, CommandHandler, CallbackQueryHandler, MessageHandler, filters
-from db import SessionLocal
-from models import Channel, Group, Button
-
-SELECT_TARGET, INPUT_NAME, INPUT_URL = range(3)
-
-async def add_button_start(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    # Если выбран канал или группа уже сохранены — сразу переход к названию
-    if context.user_data.get('channel_id'):
-        context.user_data['target'] = f"channel_{context.user_data['channel_id']}"
-        await update.message.reply_text('Введите название кнопки:')
-        return INPUT_NAME
-    elif context.user_data.get('group_id'):
-        context.user_data['target'] = f"group_{context.user_data['group_id']}"
-        await update.message.reply_text('Введите название кнопки:')
-        return INPUT_NAME
-    # Если нет — стандартный выбор
-    session = SessionLocal()
-    channels = session.query(Channel).all()
-    groups = session.query(Group).all()
-    session.close()
-    keyboard = []
-    for c in channels:
-        keyboard.append([{'text': f'Канал: {c.name}', 'callback_data': f'channel_{c.id}'}])
-    for g in groups:
-        keyboard.append([{'text': f'Группа: {g.name}', 'callback_data': f'group_{g.id}'}])
-    if not keyboard:
-        await update.message.reply_text('Нет каналов или групп для добавления кнопки.')
-        return ConversationHandler.END
-    await update.message.reply_text('Выберите канал или группу:', reply_markup=None)
-    context.user_data['keyboard'] = keyboard
-    return SELECT_TARGET
-
-async def select_target(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    query = update.callback_query
-    await query.answer()
-    data = query.data
-    context.user_data['target'] = data
-    await query.edit_message_text('Введите название кнопки:')
-    return INPUT_NAME
-
-async def input_name(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    context.user_data['button_name'] = update.message.text
-    await update.message.reply_text('Введите ссылку для кнопки:')
-    return INPUT_URL
-
-async def input_url(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    url = update.message.text
-    name = context.user_data['button_name']
-    target = context.user_data['target']
-    session = SessionLocal()
-    if target.startswith('channel_'):
-        channel_id = int(target.split('_')[1])
-        button = Button(name=name, url=url, channel_id=channel_id)
-    else:
-        group_id = int(target.split('_')[1])
-        button = Button(name=name, url=url, group_id=group_id)
-    session.add(button)
-    session.commit()
-    session.close()
-    await update.message.reply_text('Кнопка добавлена.')
-    return ConversationHandler.END
-
-add_button_conv = ConversationHandler(
-    entry_points=[CommandHandler('add_button', add_button_start)],
-    states={
-        SELECT_TARGET: [CallbackQueryHandler(select_target)],
-        INPUT_NAME: [MessageHandler(filters.TEXT & ~filters.COMMAND, input_name)],
-        INPUT_URL: [MessageHandler(filters.TEXT & ~filters.COMMAND, input_url)],
-    },
-    fallbacks=[]
-)
@@ -1,84 +0,0 @@
-from telegram import Update
-from telegram.ext import ContextTypes, ConversationHandler, CommandHandler, CallbackQueryHandler, MessageHandler, filters
-from db import SessionLocal
-from models import Channel, Group, Button
-
-SELECT_TARGET, INPUT_NAME, INPUT_URL = range(3)
-
-async def add_button_start(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    # Если выбран канал или группа уже сохранены — сразу переход к названию
-    if context.user_data.get('channel_id'):
-        context.user_data['target'] = f"channel_{context.user_data['channel_id']}"
-        await update.message.reply_text('Введите название кнопки:')
-        return INPUT_NAME
-    elif context.user_data.get('group_id'):
-        context.user_data['target'] = f"group_{context.user_data['group_id']}"
-        await update.message.reply_text('Введите название кнопки:')
-        return INPUT_NAME
-    # Если нет — стандартный выбор
-    session = SessionLocal()
-    channels = session.query(Channel).all()
-    groups = session.query(Group).all()
-    session.close()
-    keyboard = []
-    for c in channels:
-        keyboard.append([{'text': f'Канал: {c.name}', 'callback_data': f'channel_{c.id}'}])
-    for g in groups:
-        keyboard.append([{'text': f'Группа: {g.name}', 'callback_data': f'group_{g.id}'}])
-    if not keyboard:
-        await update.message.reply_text('Нет каналов или групп для добавления кнопки.')
-        return ConversationHandler.END
-    await update.message.reply_text('Выберите канал или группу:', reply_markup=None)
-    context.user_data['keyboard'] = keyboard
-    return SELECT_TARGET
-
-async def select_target(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    query = update.callback_query
-    await query.answer()
-    data = query.data
-    context.user_data['target'] = data
-    await query.edit_message_text('Введите название кнопки:')
-    return INPUT_NAME
-
-async def input_name(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    context.user_data['button_name'] = update.message.text
-    await update.message.reply_text('Введите ссылку для кнопки:')
-    return INPUT_URL
-
-async def input_url(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    url = update.message.text
-    name = context.user_data.get('button_name')
-    target = context.user_data.get('target')
-    if not target or ('_' not in target):
-        await update.message.reply_text('Ошибка: не выбран канал или группа. Попробуйте снова.')
-        return ConversationHandler.END
-    session = SessionLocal()
-    try:
-        type_, obj_id = target.split('_', 1)
-        obj_id = int(obj_id)
-        if type_ == 'channel':
-            button = Button(name=name, url=url, channel_id=obj_id)
-        elif type_ == 'group':
-            button = Button(name=name, url=url, group_id=obj_id)
-        else:
-            await update.message.reply_text('Ошибка: неверный тип объекта.')
-            session.close()
-            return ConversationHandler.END
-        session.add(button)
-        session.commit()
-        await update.message.reply_text('Кнопка добавлена.')
-    except Exception as e:
-        await update.message.reply_text(f'Ошибка при добавлении кнопки: {e}')
-    finally:
-        session.close()
-    return ConversationHandler.END
-
-add_button_conv = ConversationHandler(
-    entry_points=[CommandHandler('add_button', add_button_start)],
-    states={
-        SELECT_TARGET: [CallbackQueryHandler(select_target)],
-        INPUT_NAME: [MessageHandler(filters.TEXT & ~filters.COMMAND, input_name)],
-        INPUT_URL: [MessageHandler(filters.TEXT & ~filters.COMMAND, input_url)],
-    },
-    fallbacks=[]
-)
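
A side note on all the add_button.py variants above: add_button_start collects the channel/group options as plain dicts and then sends reply_markup=None, so the selection keyboard is never actually shown and select_target only fires if a callback arrives some other way. The deleted channel_buttons.py further down already uses the working pattern; a hedged sketch of what the same step could look like here (a suggested fix, not something present in the commit):

from telegram import InlineKeyboardButton, InlineKeyboardMarkup

# build real inline buttons instead of bare dicts, then actually attach them to the prompt
keyboard = [[InlineKeyboardButton(f'Канал: {c.name}', callback_data=f'channel_{c.id}')] for c in channels]
keyboard += [[InlineKeyboardButton(f'Группа: {g.name}', callback_data=f'group_{g.id}')] for g in groups]
await update.message.reply_text('Выберите канал или группу:', reply_markup=InlineKeyboardMarkup(keyboard))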
@@ -1,17 +0,0 @@
-from telegram import Update
-from telegram.ext import ContextTypes
-from db import SessionLocal
-from models import Channel
-
-async def add_channel(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    args = context.args
-    if len(args) < 2:
-        await update.message.reply_text('Используйте: /add_channel <название> <ссылка>')
-        return
-    name, link = args[0], args[1]
-    session = SessionLocal()
-    channel = Channel(name=name, link=link)
-    session.add(channel)
-    session.commit()
-    session.close()
-    await update.message.reply_text(f'Канал "{name}" добавлен.')
@@ -1,17 +0,0 @@
-from telegram import Update
-from telegram.ext import ContextTypes
-from db import SessionLocal
-from models import Group
-
-async def add_group(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    args = context.args
-    if len(args) < 2:
-        await update.message.reply_text('Используйте: /add_group <название> <ссылка>')
-        return
-    name, link = args[0], args[1]
-    session = SessionLocal()
-    group = Group(name=name, link=link)
-    session.add(group)
-    session.commit()
-    session.close()
-    await update.message.reply_text(f'Группа "{name}" добавлена.')
@@ -1,40 +0,0 @@
-from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
-from telegram.ext import CommandHandler, CallbackQueryHandler, ConversationHandler, ContextTypes
-from db import SessionLocal
-from models import Channel, Button
-
-SELECT_CHANNEL, MANAGE_BUTTONS = range(2)
-
-async def channel_buttons_start(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    session = SessionLocal()
-    channels = session.query(Channel).all()
-    session.close()
-    keyboard = [[InlineKeyboardButton(c.name, callback_data=str(c.id))] for c in channels]
-    await update.message.reply_text(
-        "Выберите канал для настройки клавиатуры:",
-        reply_markup=InlineKeyboardMarkup(keyboard)
-    )
-    return SELECT_CHANNEL
-
-async def select_channel(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    query = update.callback_query
-    await query.answer()
-    channel_id = int(query.data)
-    context.user_data['channel_id'] = channel_id
-    session = SessionLocal()
-    buttons = session.query(Button).filter_by(channel_id=channel_id).all()
-    session.close()
-    text = "Кнопки этого канала:\n"
-    for b in buttons:
-        text += f"- {b.name}: {b.url}\n"
-    text += "\nДобавить новую кнопку: /add_button\nУдалить: /del_button <название>"
-    await query.edit_message_text(text)
-    return ConversationHandler.END
-
-channel_buttons_conv = ConversationHandler(
-    entry_points=[CommandHandler('channel_buttons', channel_buttons_start)],
-    states={
-        SELECT_CHANNEL: [CallbackQueryHandler(select_channel)],
-    },
-    fallbacks=[]
-)
@@ -1,21 +0,0 @@
-from telegram import Update
-from telegram.ext import CommandHandler, ContextTypes
-from db import SessionLocal
-from models import Button
-
-async def del_button(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    args = context.args
-    if not args:
-        await update.message.reply_text('Используйте: /del_button <название>')
-        return
-    name = args[0]
-    session = SessionLocal()
-    button = session.query(Button).filter_by(name=name).first()
-    if not button:
-        await update.message.reply_text('Кнопка не найдена.')
-        session.close()
-        return
-    session.delete(button)
-    session.commit()
-    session.close()
-    await update.message.reply_text(f'Кнопка "{name}" удалена.')
@@ -1,22 +0,0 @@
-from telegram import Update
-from telegram.ext import CommandHandler, ContextTypes
-from db import SessionLocal
-from models import Button
-
-async def edit_button(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    args = context.args
-    if len(args) < 3:
-        await update.message.reply_text('Используйте: /edit_button <название> <новое_название> <новая_ссылка>')
-        return
-    name, new_name, new_url = args[0], args[1], args[2]
-    session = SessionLocal()
-    button = session.query(Button).filter_by(name=name).first()
-    if not button:
-        await update.message.reply_text('Кнопка не найдена.')
-        session.close()
-        return
-    button.name = new_name
-    button.url = new_url
-    session.commit()
-    session.close()
-    await update.message.reply_text(f'Кнопка "{name}" изменена.')
@@ -1,2 +0,0 @@
-# Заглушка для будущей реализации редактирования поста
-# Можно реализовать хранение черновиков и их редактирование
@@ -1,40 +0,0 @@
-from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
-from telegram.ext import CommandHandler, CallbackQueryHandler, ConversationHandler, ContextTypes
-from db import SessionLocal
-from models import Group, Button
-
-SELECT_GROUP, MANAGE_BUTTONS = range(2)
-
-async def group_buttons_start(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    session = SessionLocal()
-    groups = session.query(Group).all()
-    session.close()
-    keyboard = [[InlineKeyboardButton(g.name, callback_data=str(g.id))] for g in groups]
-    await update.message.reply_text(
-        "Выберите группу для настройки клавиатуры:",
-        reply_markup=InlineKeyboardMarkup(keyboard)
-    )
-    return SELECT_GROUP
-
-async def select_group(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    query = update.callback_query
-    await query.answer()
-    group_id = int(query.data)
-    context.user_data['group_id'] = group_id
-    session = SessionLocal()
-    buttons = session.query(Button).filter_by(group_id=group_id).all()
-    session.close()
-    text = "Кнопки этой группы:\n"
-    for b in buttons:
-        text += f"- {b.name}: {b.url}\n"
-    text += "\nДобавить новую кнопку: /add_button\nУдалить: /del_button <название>"
-    await query.edit_message_text(text)
-    return ConversationHandler.END
-
-group_buttons_conv = ConversationHandler(
-    entry_points=[CommandHandler('group_buttons', group_buttons_start)],
-    states={
-        SELECT_GROUP: [CallbackQueryHandler(select_group)],
-    },
-    fallbacks=[]
-)
@@ -1,62 +0,0 @@
-from telegram import Update, InputMediaPhoto, InlineKeyboardMarkup, InlineKeyboardButton
-from telegram.ext import ContextTypes, ConversationHandler, MessageHandler, CommandHandler, filters
-from db import SessionLocal
-from models import Channel, Group, Button
-
-SELECT_MEDIA, SELECT_TEXT, SELECT_TARGET = range(3)
-
-async def new_post_start(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    await update.message.reply_text('Отправьте картинку для поста или /skip:')
-    return SELECT_MEDIA
-
-async def select_media(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    if update.message.photo:
-        context.user_data['photo'] = update.message.photo[-1].file_id
-    await update.message.reply_text('Введите текст поста или пересланное сообщение:')
-    return SELECT_TEXT
-
-async def select_text(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    context.user_data['text'] = update.message.text or update.message.caption
-    session = SessionLocal()
-    channels = session.query(Channel).all()
-    groups = session.query(Group).all()
-    session.close()
-    keyboard = []
-    for c in channels:
-        keyboard.append([InlineKeyboardButton(f'Канал: {c.name}', callback_data=f'channel_{c.id}')])
-    for g in groups:
-        keyboard.append([InlineKeyboardButton(f'Группа: {g.name}', callback_data=f'group_{g.id}')])
-    reply_markup = InlineKeyboardMarkup(keyboard)
-    await update.message.reply_text('Выберите, куда отправить пост:', reply_markup=reply_markup)
-    return SELECT_TARGET
-
-async def select_target(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    query = update.callback_query
-    await query.answer()
-    data = query.data
-    session = SessionLocal()
-    if data.startswith('channel_'):
-        channel_id = int(data.split('_')[1])
-        channel = session.query(Channel).get(channel_id)
-        buttons = session.query(Button).filter_by(channel_id=channel_id).all()
-        markup = InlineKeyboardMarkup([[InlineKeyboardButton(b.name, url=b.url)] for b in buttons]) if buttons else None
-        await context.bot.send_photo(chat_id=channel.link, photo=context.user_data.get('photo'), caption=context.user_data.get('text'), reply_markup=markup)
-    else:
-        group_id = int(data.split('_')[1])
-        group = session.query(Group).get(group_id)
-        buttons = session.query(Button).filter_by(group_id=group_id).all()
-        markup = InlineKeyboardMarkup([[InlineKeyboardButton(b.name, url=b.url)] for b in buttons]) if buttons else None
-        await context.bot.send_photo(chat_id=group.link, photo=context.user_data.get('photo'), caption=context.user_data.get('text'), reply_markup=markup)
-    session.close()
-    await query.edit_message_text('Пост отправлен!')
-    return ConversationHandler.END
-
-new_post_conv = ConversationHandler(
-    entry_points=[CommandHandler('new_post', new_post_start)],
-    states={
-        SELECT_MEDIA: [MessageHandler(filters.PHOTO | filters.Document.IMAGE | filters.COMMAND, select_media)],
-        SELECT_TEXT: [MessageHandler(filters.TEXT | filters.FORWARDED, select_text)],
-        SELECT_TARGET: [CallbackQueryHandler(select_target)],
-    },
-    fallbacks=[]
-)
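
A behavioural detail in select_target above: send_photo is called with chat_id=channel.link (or group.link), i.e. whatever string /add_channel or /add_group stored. The Bot API accepts a numeric chat id or an @username there, so posting only works if the stored value is one of those and the bot can post to that chat; an invite link of the form https://t.me/+... would fail. Illustrative values (placeholders, assuming the bot is an admin of the target chat):

# both forms are valid chat_id arguments for context.bot.send_photo
await context.bot.send_photo(chat_id='@my_channel', photo=file_id, caption=text)    # public @username
await context.bot.send_photo(chat_id=-1001234567890, photo=file_id, caption=text)   # numeric chat id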
@@ -1,63 +0,0 @@
-from telegram import Update, InputMediaPhoto, InlineKeyboardMarkup, InlineKeyboardButton
-from telegram.ext import ContextTypes, ConversationHandler, MessageHandler, CommandHandler, filters, CallbackQueryHandler, ContextTypes
-
-from db import SessionLocal
-from models import Channel, Group, Button
-
-SELECT_MEDIA, SELECT_TEXT, SELECT_TARGET = range(3)
-
-async def new_post_start(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    await update.message.reply_text('Отправьте картинку для поста или /skip:')
-    return SELECT_MEDIA
-
-async def select_media(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    if update.message.photo:
-        context.user_data['photo'] = update.message.photo[-1].file_id
-    await update.message.reply_text('Введите текст поста или пересланное сообщение:')
-    return SELECT_TEXT
-
-async def select_text(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    context.user_data['text'] = update.message.text or update.message.caption
-    session = SessionLocal()
-    channels = session.query(Channel).all()
-    groups = session.query(Group).all()
-    session.close()
-    keyboard = []
-    for c in channels:
-        keyboard.append([InlineKeyboardButton(f'Канал: {c.name}', callback_data=f'channel_{c.id}')])
-    for g in groups:
-        keyboard.append([InlineKeyboardButton(f'Группа: {g.name}', callback_data=f'group_{g.id}')])
-    reply_markup = InlineKeyboardMarkup(keyboard)
-    await update.message.reply_text('Выберите, куда отправить пост:', reply_markup=reply_markup)
-    return SELECT_TARGET
-
-async def select_target(update: Update, context: ContextTypes.DEFAULT_TYPE):
-    query = update.callback_query
-    await query.answer()
-    data = query.data
-    session = SessionLocal()
-    if data.startswith('channel_'):
-        channel_id = int(data.split('_')[1])
-        channel = session.query(Channel).get(channel_id)
-        buttons = session.query(Button).filter_by(channel_id=channel_id).all()
-        markup = InlineKeyboardMarkup([[InlineKeyboardButton(b.name, url=b.url)] for b in buttons]) if buttons else None
-        await context.bot.send_photo(chat_id=channel.link, photo=context.user_data.get('photo'), caption=context.user_data.get('text'), reply_markup=markup)
-    else:
-        group_id = int(data.split('_')[1])
-        group = session.query(Group).get(group_id)
-        buttons = session.query(Button).filter_by(group_id=group_id).all()
-        markup = InlineKeyboardMarkup([[InlineKeyboardButton(b.name, url=b.url)] for b in buttons]) if buttons else None
-        await context.bot.send_photo(chat_id=group.link, photo=context.user_data.get('photo'), caption=context.user_data.get('text'), reply_markup=markup)
-    session.close()
-    await query.edit_message_text('Пост отправлен!')
-    return ConversationHandler.END
-
-new_post_conv = ConversationHandler(
-    entry_points=[CommandHandler('new_post', new_post_start)],
-    states={
-        SELECT_MEDIA: [MessageHandler(filters.PHOTO | filters.Document.IMAGE | filters.COMMAND, select_media)],
-        SELECT_TEXT: [MessageHandler(filters.TEXT | filters.FORWARDED, select_text)],
-        SELECT_TARGET: [CallbackQueryHandler(select_target)],
-    },
-    fallbacks=[]
-)
@@ -1,73 +0,0 @@
from telegram import Update, InputMediaPhoto, InlineKeyboardMarkup, InlineKeyboardButton
from telegram.ext import ContextTypes, ConversationHandler, MessageHandler, CommandHandler, filters, CallbackQueryHandler, ContextTypes

from db import SessionLocal
from models import Channel, Group, Button

SELECT_MEDIA, SELECT_TEXT, SELECT_TARGET = range(3)

async def new_post_start(update: Update, context: ContextTypes.DEFAULT_TYPE):
    await update.message.reply_text('Отправьте картинку для поста или /skip:')
    return SELECT_MEDIA

async def select_media(update: Update, context: ContextTypes.DEFAULT_TYPE):
    if update.message.photo:
        context.user_data['photo'] = update.message.photo[-1].file_id
    await update.message.reply_text('Введите текст поста или пересланное сообщение:')
    return SELECT_TEXT

async def select_text(update: Update, context: ContextTypes.DEFAULT_TYPE):
    context.user_data['text'] = update.message.text or update.message.caption
    session = SessionLocal()
    channels = session.query(Channel).all()
    groups = session.query(Group).all()
    session.close()
    keyboard = []
    for c in channels:
        keyboard.append([InlineKeyboardButton(f'Канал: {c.name}', callback_data=f'channel_{c.id}')])
    for g in groups:
        keyboard.append([InlineKeyboardButton(f'Группа: {g.name}', callback_data=f'group_{g.id}')])
    reply_markup = InlineKeyboardMarkup(keyboard)
    await update.message.reply_text('Выберите, куда отправить пост:', reply_markup=reply_markup)
    return SELECT_TARGET

async def select_target(update: Update, context: ContextTypes.DEFAULT_TYPE):
    query = update.callback_query
    await query.answer()
    data = query.data
    session = SessionLocal()
    try:
        if data.startswith('channel_'):
            channel_id = int(data.split('_')[1])
            channel = session.query(Channel).get(channel_id)
            buttons = session.query(Button).filter_by(channel_id=channel_id).all()
            markup = InlineKeyboardMarkup([[InlineKeyboardButton(b.name, url=b.url)] for b in buttons]) if buttons else None
            chat_id = channel.link.strip()
        else:
            group_id = int(data.split('_')[1])
            group = session.query(Group).get(group_id)
            buttons = session.query(Button).filter_by(group_id=group_id).all()
            markup = InlineKeyboardMarkup([[InlineKeyboardButton(b.name, url=b.url)] for b in buttons]) if buttons else None
            chat_id = group.link.strip()
        # Проверка chat_id
        if not (chat_id.startswith('@') or chat_id.startswith('-')):
            await query.edit_message_text('Ошибка: ссылка должна быть username (@channel) или числовой ID (-100...)')
            return ConversationHandler.END
        try:
            await context.bot.send_photo(chat_id=chat_id, photo=context.user_data.get('photo'), caption=context.user_data.get('text'), reply_markup=markup)
            await query.edit_message_text('Пост отправлен!')
        except Exception as e:
            await query.edit_message_text(f'Ошибка отправки поста: {e}')
    finally:
        session.close()
    return ConversationHandler.END

new_post_conv = ConversationHandler(
    entry_points=[CommandHandler('new_post', new_post_start)],
    states={
        SELECT_MEDIA: [MessageHandler(filters.PHOTO | filters.Document.IMAGE | filters.COMMAND, select_media)],
        SELECT_TEXT: [MessageHandler(filters.TEXT | filters.FORWARDED, select_text)],
        SELECT_TARGET: [CallbackQueryHandler(select_target)],
    },
    fallbacks=[]
)
@@ -1,2 +0,0 @@
# Заглушка для будущей реализации отправки поста вручную
# Можно реализовать выбор поста и канал/группу для отправки
@@ -1,39 +0,0 @@
import logging
import os
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram.ext import Application, CommandHandler, MessageHandler, filters, CallbackQueryHandler, ConversationHandler, ContextTypes
from dotenv import load_dotenv
from db import SessionLocal, init_db
from models import Admin, Channel, Group, Button

load_dotenv()
TELEGRAM_TOKEN = os.getenv('TELEGRAM_TOKEN')

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

init_db()

async def start(update: Update, context: ContextTypes.DEFAULT_TYPE):
    session = SessionLocal()
    user_id = update.effective_user.id
    admin = session.query(Admin).filter_by(tg_id=user_id).first()
    if not admin:
        admin = Admin(tg_id=user_id)
        session.add(admin)
        session.commit()
        await update.message.reply_text('Вы зарегистрированы как админ.')
    else:
        await update.message.reply_text('Вы уже зарегистрированы.')
    session.close()

# ...handlers for add_channel, add_group, add_button, new_post, etc. будут добавлены...

def main():
    application = Application.builder().token(TELEGRAM_TOKEN).build()
    application.add_handler(CommandHandler('start', start))
    # ...add other handlers...
    application.run_polling()

if __name__ == '__main__':
    main()
@@ -1,48 +0,0 @@
import logging
import os
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram.ext import Application, CommandHandler, MessageHandler, filters, CallbackQueryHandler, ConversationHandler, ContextTypes
from dotenv import load_dotenv
from db import SessionLocal, init_db
from models import Admin, Channel, Group, Button

load_dotenv()
TELEGRAM_TOKEN = os.getenv('TELEGRAM_TOKEN')

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

init_db()

async def start(update: Update, context: ContextTypes.DEFAULT_TYPE):
    session = SessionLocal()
    user_id = update.effective_user.id
    admin = session.query(Admin).filter_by(tg_id=user_id).first()
    if not admin:
        admin = Admin(tg_id=user_id)
        session.add(admin)
        session.commit()
        await update.message.reply_text('Вы зарегистрированы как админ.')
    else:
        await update.message.reply_text('Вы уже зарегистрированы.')
    session.close()


# Импорт обработчиков
from handlers.add_channel import add_channel
from handlers.add_group import add_group
from handlers.add_button import add_button_conv
from handlers.new_post import new_post_conv

def main():

    application = Application.builder().token(TELEGRAM_TOKEN).build()
    application.add_handler(CommandHandler('start', start))
    application.add_handler(CommandHandler('add_channel', add_channel))
    application.add_handler(CommandHandler('add_group', add_group))
    application.add_handler(add_button_conv)
    application.add_handler(new_post_conv)
    application.run_polling()

if __name__ == '__main__':
    main()
@@ -1,56 +0,0 @@
import logging
import os
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
from telegram.ext import Application, CommandHandler, MessageHandler, filters, CallbackQueryHandler, ConversationHandler, ContextTypes
from dotenv import load_dotenv
from db import SessionLocal, init_db
from models import Admin, Channel, Group, Button

load_dotenv()
TELEGRAM_TOKEN = os.getenv('TELEGRAM_TOKEN')

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

init_db()

async def start(update: Update, context: ContextTypes.DEFAULT_TYPE):
    session = SessionLocal()
    user_id = update.effective_user.id
    admin = session.query(Admin).filter_by(tg_id=user_id).first()
    if not admin:
        admin = Admin(tg_id=user_id)
        session.add(admin)
        session.commit()
        await update.message.reply_text('Вы зарегистрированы как админ.')
    else:
        await update.message.reply_text('Вы уже зарегистрированы.')
    session.close()


# Импорт обработчиков
from handlers.add_channel import add_channel
from handlers.add_group import add_group
from handlers.add_button import add_button_conv
from handlers.new_post import new_post_conv
from handlers.group_buttons import group_buttons_conv
from handlers.channel_buttons import channel_buttons_conv
from handlers.edit_button import edit_button
from handlers.del_button import del_button

def main():

    application = Application.builder().token(TELEGRAM_TOKEN).build()
    application.add_handler(CommandHandler('start', start))
    application.add_handler(CommandHandler('add_channel', add_channel))
    application.add_handler(CommandHandler('add_group', add_group))
    application.add_handler(add_button_conv)
    application.add_handler(new_post_conv)
    application.add_handler(group_buttons_conv)
    application.add_handler(channel_buttons_conv)
    application.add_handler(CommandHandler('edit_button', edit_button))
    application.add_handler(CommandHandler('del_button', del_button))
    application.run_polling()

if __name__ == '__main__':
    main()
@@ -1,33 +0,0 @@
from sqlalchemy import Column, Integer, String, ForeignKey, Text
from sqlalchemy.orm import declarative_base, relationship

Base = declarative_base()

class Admin(Base):
    __tablename__ = 'admins'
    id = Column(Integer, primary_key=True)
    tg_id = Column(Integer, unique=True, nullable=False)

class Channel(Base):
    __tablename__ = 'channels'
    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    link = Column(String, nullable=False)
    buttons = relationship('Button', back_populates='channel')

class Group(Base):
    __tablename__ = 'groups'
    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    link = Column(String, nullable=False)
    buttons = relationship('Button', back_populates='group')

class Button(Base):
    __tablename__ = 'buttons'
    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    url = Column(String, nullable=False)
    channel_id = Column(Integer, ForeignKey('channels.id'), nullable=True)
    group_id = Column(Integer, ForeignKey('groups.id'), nullable=True)
    channel = relationship('Channel', back_populates='buttons')
    group = relationship('Group', back_populates='buttons')
@@ -1,4 +0,0 @@
python-telegram-bot>=20.0
sqlalchemy>=2.0
python-dotenv>=1.0
pytest>=7.0
@@ -1,225 +0,0 @@
# mypy: allow-untyped-defs
from __future__ import annotations

import ast
from bisect import bisect_right
from collections.abc import Iterable
from collections.abc import Iterator
import inspect
import textwrap
import tokenize
import types
from typing import overload
import warnings


class Source:
    """An immutable object holding a source code fragment.

    When using Source(...), the source lines are deindented.
    """

    def __init__(self, obj: object = None) -> None:
        if not obj:
            self.lines: list[str] = []
            self.raw_lines: list[str] = []
        elif isinstance(obj, Source):
            self.lines = obj.lines
            self.raw_lines = obj.raw_lines
        elif isinstance(obj, (tuple, list)):
            self.lines = deindent(x.rstrip("\n") for x in obj)
            self.raw_lines = list(x.rstrip("\n") for x in obj)
        elif isinstance(obj, str):
            self.lines = deindent(obj.split("\n"))
            self.raw_lines = obj.split("\n")
        else:
            try:
                rawcode = getrawcode(obj)
                src = inspect.getsource(rawcode)
            except TypeError:
                src = inspect.getsource(obj)  # type: ignore[arg-type]
            self.lines = deindent(src.split("\n"))
            self.raw_lines = src.split("\n")

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Source):
            return NotImplemented
        return self.lines == other.lines

    # Ignore type because of https://github.com/python/mypy/issues/4266.
    __hash__ = None  # type: ignore

    @overload
    def __getitem__(self, key: int) -> str: ...

    @overload
    def __getitem__(self, key: slice) -> Source: ...

    def __getitem__(self, key: int | slice) -> str | Source:
        if isinstance(key, int):
            return self.lines[key]
        else:
            if key.step not in (None, 1):
                raise IndexError("cannot slice a Source with a step")
            newsource = Source()
            newsource.lines = self.lines[key.start : key.stop]
            newsource.raw_lines = self.raw_lines[key.start : key.stop]
            return newsource

    def __iter__(self) -> Iterator[str]:
        return iter(self.lines)

    def __len__(self) -> int:
        return len(self.lines)

    def strip(self) -> Source:
        """Return new Source object with trailing and leading blank lines removed."""
        start, end = 0, len(self)
        while start < end and not self.lines[start].strip():
            start += 1
        while end > start and not self.lines[end - 1].strip():
            end -= 1
        source = Source()
        source.raw_lines = self.raw_lines
        source.lines[:] = self.lines[start:end]
        return source

    def indent(self, indent: str = " " * 4) -> Source:
        """Return a copy of the source object with all lines indented by the
        given indent-string."""
        newsource = Source()
        newsource.raw_lines = self.raw_lines
        newsource.lines = [(indent + line) for line in self.lines]
        return newsource

    def getstatement(self, lineno: int) -> Source:
        """Return Source statement which contains the given linenumber
        (counted from 0)."""
        start, end = self.getstatementrange(lineno)
        return self[start:end]

    def getstatementrange(self, lineno: int) -> tuple[int, int]:
        """Return (start, end) tuple which spans the minimal statement region
        which contains the given lineno."""
        if not (0 <= lineno < len(self)):
            raise IndexError("lineno out of range")
        ast, start, end = getstatementrange_ast(lineno, self)
        return start, end

    def deindent(self) -> Source:
        """Return a new Source object deindented."""
        newsource = Source()
        newsource.lines[:] = deindent(self.lines)
        newsource.raw_lines = self.raw_lines
        return newsource

    def __str__(self) -> str:
        return "\n".join(self.lines)


#
# helper functions
#


def findsource(obj) -> tuple[Source | None, int]:
    try:
        sourcelines, lineno = inspect.findsource(obj)
    except Exception:
        return None, -1
    source = Source()
    source.lines = [line.rstrip() for line in sourcelines]
    source.raw_lines = sourcelines
    return source, lineno


def getrawcode(obj: object, trycall: bool = True) -> types.CodeType:
    """Return code object for given function."""
    try:
        return obj.__code__  # type: ignore[attr-defined,no-any-return]
    except AttributeError:
        pass
    if trycall:
        call = getattr(obj, "__call__", None)
        if call and not isinstance(obj, type):
            return getrawcode(call, trycall=False)
    raise TypeError(f"could not get code object for {obj!r}")


def deindent(lines: Iterable[str]) -> list[str]:
    return textwrap.dedent("\n".join(lines)).splitlines()


def get_statement_startend2(lineno: int, node: ast.AST) -> tuple[int, int | None]:
    # Flatten all statements and except handlers into one lineno-list.
    # AST's line numbers start indexing at 1.
    values: list[int] = []
    for x in ast.walk(node):
        if isinstance(x, (ast.stmt, ast.ExceptHandler)):
            # The lineno points to the class/def, so need to include the decorators.
            if isinstance(x, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
                for d in x.decorator_list:
                    values.append(d.lineno - 1)
            values.append(x.lineno - 1)
            for name in ("finalbody", "orelse"):
                val: list[ast.stmt] | None = getattr(x, name, None)
                if val:
                    # Treat the finally/orelse part as its own statement.
                    values.append(val[0].lineno - 1 - 1)
    values.sort()
    insert_index = bisect_right(values, lineno)
    start = values[insert_index - 1]
    if insert_index >= len(values):
        end = None
    else:
        end = values[insert_index]
    return start, end


def getstatementrange_ast(
    lineno: int,
    source: Source,
    assertion: bool = False,
    astnode: ast.AST | None = None,
) -> tuple[ast.AST, int, int]:
    if astnode is None:
        content = str(source)
        # See #4260:
        # Don't produce duplicate warnings when compiling source to find AST.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            astnode = ast.parse(content, "source", "exec")

    start, end = get_statement_startend2(lineno, astnode)
    # We need to correct the end:
    # - ast-parsing strips comments
    # - there might be empty lines
    # - we might have lesser indented code blocks at the end
    if end is None:
        end = len(source.lines)

    if end > start + 1:
        # Make sure we don't span differently indented code blocks
        # by using the BlockFinder helper that inspect.getsource() itself uses.
        block_finder = inspect.BlockFinder()
        # If we start with an indented line, put blockfinder to "started" mode.
        block_finder.started = (
            bool(source.lines[start]) and source.lines[start][0].isspace()
        )
        it = ((x + "\n") for x in source.lines[start:end])
        try:
            for tok in tokenize.generate_tokens(lambda: next(it)):
                block_finder.tokeneater(*tok)
        except (inspect.EndOfBlock, IndentationError):
            end = block_finder.last + start
        except Exception:
            pass

    # The end might still point to a comment or empty line, correct it.
    while end:
        line = source.lines[end - 1].lstrip()
        if line.startswith("#") or not line:
            end -= 1
        else:
            break
    return astnode, start, end
@@ -1,10 +0,0 @@
from __future__ import annotations

from .terminalwriter import get_terminal_width
from .terminalwriter import TerminalWriter


__all__ = [
    "TerminalWriter",
    "get_terminal_width",
]
@@ -1,673 +0,0 @@
# mypy: allow-untyped-defs
# This module was imported from the cpython standard library
# (https://github.com/python/cpython/) at commit
# c5140945c723ae6c4b7ee81ff720ac8ea4b52cfd (python3.12).
#
#
# Original Author: Fred L. Drake, Jr.
#                  fdrake@acm.org
#
# This is a simple little module I wrote to make life easier.  I didn't
# see anything quite like it in the library, though I may have overlooked
# something.  I wrote this when I was trying to read some heavily nested
# tuples with fairly non-descriptive content.  This is modeled very much
# after Lisp/Scheme - style pretty-printing of lists.  If you find it
# useful, thank small children who sleep at night.
from __future__ import annotations

import collections as _collections
from collections.abc import Callable
from collections.abc import Iterator
import dataclasses as _dataclasses
from io import StringIO as _StringIO
import re
import types as _types
from typing import Any
from typing import IO


class _safe_key:
    """Helper function for key functions when sorting unorderable objects.

    The wrapped-object will fallback to a Py2.x style comparison for
    unorderable types (sorting first comparing the type name and then by
    the obj ids).  Does not work recursively, so dict.items() must have
    _safe_key applied to both the key and the value.
    """

    __slots__ = ["obj"]

    def __init__(self, obj):
        self.obj = obj

    def __lt__(self, other):
        try:
            return self.obj < other.obj
        except TypeError:
            return (str(type(self.obj)), id(self.obj)) < (
                str(type(other.obj)),
                id(other.obj),
            )


def _safe_tuple(t):
    """Helper function for comparing 2-tuples"""
    return _safe_key(t[0]), _safe_key(t[1])


class PrettyPrinter:
    def __init__(
        self,
        indent: int = 4,
        width: int = 80,
        depth: int | None = None,
    ) -> None:
        """Handle pretty printing operations onto a stream using a set of
        configured parameters.

        indent
            Number of spaces to indent for each level of nesting.

        width
            Attempted maximum number of columns in the output.

        depth
            The maximum depth to print out nested structures.

        """
        if indent < 0:
            raise ValueError("indent must be >= 0")
        if depth is not None and depth <= 0:
            raise ValueError("depth must be > 0")
        if not width:
            raise ValueError("width must be != 0")
        self._depth = depth
        self._indent_per_level = indent
        self._width = width

    def pformat(self, object: Any) -> str:
        sio = _StringIO()
        self._format(object, sio, 0, 0, set(), 0)
        return sio.getvalue()

    def _format(
        self,
        object: Any,
        stream: IO[str],
        indent: int,
        allowance: int,
        context: set[int],
        level: int,
    ) -> None:
        objid = id(object)
        if objid in context:
            stream.write(_recursion(object))
            return

        p = self._dispatch.get(type(object).__repr__, None)
        if p is not None:
            context.add(objid)
            p(self, object, stream, indent, allowance, context, level + 1)
            context.remove(objid)
        elif (
            _dataclasses.is_dataclass(object)
            and not isinstance(object, type)
            and object.__dataclass_params__.repr  # type:ignore[attr-defined]
            and
            # Check dataclass has generated repr method.
            hasattr(object.__repr__, "__wrapped__")
            and "__create_fn__" in object.__repr__.__wrapped__.__qualname__
        ):
            context.add(objid)
            self._pprint_dataclass(
                object, stream, indent, allowance, context, level + 1
            )
            context.remove(objid)
        else:
            stream.write(self._repr(object, context, level))

    def _pprint_dataclass(
        self,
        object: Any,
        stream: IO[str],
        indent: int,
        allowance: int,
        context: set[int],
        level: int,
    ) -> None:
        cls_name = object.__class__.__name__
        items = [
            (f.name, getattr(object, f.name))
            for f in _dataclasses.fields(object)
            if f.repr
        ]
        stream.write(cls_name + "(")
        self._format_namespace_items(items, stream, indent, allowance, context, level)
        stream.write(")")

    _dispatch: dict[
        Callable[..., str],
        Callable[[PrettyPrinter, Any, IO[str], int, int, set[int], int], None],
    ] = {}

    def _pprint_dict(
        self,
        object: Any,
        stream: IO[str],
        indent: int,
        allowance: int,
        context: set[int],
        level: int,
    ) -> None:
        write = stream.write
        write("{")
        items = sorted(object.items(), key=_safe_tuple)
        self._format_dict_items(items, stream, indent, allowance, context, level)
        write("}")

    _dispatch[dict.__repr__] = _pprint_dict

    def _pprint_ordered_dict(
        self,
        object: Any,
        stream: IO[str],
        indent: int,
        allowance: int,
        context: set[int],
        level: int,
    ) -> None:
        if not len(object):
            stream.write(repr(object))
            return
        cls = object.__class__
        stream.write(cls.__name__ + "(")
        self._pprint_dict(object, stream, indent, allowance, context, level)
        stream.write(")")

    _dispatch[_collections.OrderedDict.__repr__] = _pprint_ordered_dict

    def _pprint_list(
        self,
        object: Any,
        stream: IO[str],
        indent: int,
        allowance: int,
        context: set[int],
        level: int,
    ) -> None:
        stream.write("[")
        self._format_items(object, stream, indent, allowance, context, level)
        stream.write("]")

    _dispatch[list.__repr__] = _pprint_list

    def _pprint_tuple(
        self,
        object: Any,
        stream: IO[str],
        indent: int,
        allowance: int,
        context: set[int],
        level: int,
    ) -> None:
        stream.write("(")
        self._format_items(object, stream, indent, allowance, context, level)
        stream.write(")")

    _dispatch[tuple.__repr__] = _pprint_tuple

    def _pprint_set(
        self,
        object: Any,
        stream: IO[str],
        indent: int,
        allowance: int,
        context: set[int],
        level: int,
    ) -> None:
        if not len(object):
            stream.write(repr(object))
            return
        typ = object.__class__
        if typ is set:
            stream.write("{")
            endchar = "}"
        else:
            stream.write(typ.__name__ + "({")
            endchar = "})"
        object = sorted(object, key=_safe_key)
        self._format_items(object, stream, indent, allowance, context, level)
        stream.write(endchar)

    _dispatch[set.__repr__] = _pprint_set
    _dispatch[frozenset.__repr__] = _pprint_set

    def _pprint_str(
        self,
        object: Any,
        stream: IO[str],
        indent: int,
        allowance: int,
        context: set[int],
        level: int,
    ) -> None:
        write = stream.write
        if not len(object):
            write(repr(object))
            return
        chunks = []
        lines = object.splitlines(True)
        if level == 1:
            indent += 1
            allowance += 1
        max_width1 = max_width = self._width - indent
        for i, line in enumerate(lines):
            rep = repr(line)
            if i == len(lines) - 1:
                max_width1 -= allowance
            if len(rep) <= max_width1:
                chunks.append(rep)
            else:
                # A list of alternating (non-space, space) strings
                parts = re.findall(r"\S*\s*", line)
                assert parts
                assert not parts[-1]
                parts.pop()  # drop empty last part
                max_width2 = max_width
                current = ""
                for j, part in enumerate(parts):
                    candidate = current + part
                    if j == len(parts) - 1 and i == len(lines) - 1:
                        max_width2 -= allowance
                    if len(repr(candidate)) > max_width2:
                        if current:
                            chunks.append(repr(current))
                        current = part
                    else:
                        current = candidate
                if current:
                    chunks.append(repr(current))
        if len(chunks) == 1:
            write(rep)
            return
        if level == 1:
            write("(")
        for i, rep in enumerate(chunks):
            if i > 0:
                write("\n" + " " * indent)
            write(rep)
        if level == 1:
            write(")")

    _dispatch[str.__repr__] = _pprint_str

    def _pprint_bytes(
        self,
        object: Any,
        stream: IO[str],
        indent: int,
        allowance: int,
        context: set[int],
        level: int,
    ) -> None:
        write = stream.write
        if len(object) <= 4:
            write(repr(object))
            return
        parens = level == 1
        if parens:
            indent += 1
            allowance += 1
            write("(")
        delim = ""
        for rep in _wrap_bytes_repr(object, self._width - indent, allowance):
            write(delim)
            write(rep)
            if not delim:
                delim = "\n" + " " * indent
        if parens:
            write(")")

    _dispatch[bytes.__repr__] = _pprint_bytes

    def _pprint_bytearray(
        self,
        object: Any,
        stream: IO[str],
        indent: int,
        allowance: int,
        context: set[int],
        level: int,
    ) -> None:
        write = stream.write
        write("bytearray(")
        self._pprint_bytes(
            bytes(object), stream, indent + 10, allowance + 1, context, level + 1
        )
        write(")")

    _dispatch[bytearray.__repr__] = _pprint_bytearray

    def _pprint_mappingproxy(
        self,
        object: Any,
        stream: IO[str],
        indent: int,
        allowance: int,
        context: set[int],
        level: int,
    ) -> None:
        stream.write("mappingproxy(")
        self._format(object.copy(), stream, indent, allowance, context, level)
        stream.write(")")

    _dispatch[_types.MappingProxyType.__repr__] = _pprint_mappingproxy

    def _pprint_simplenamespace(
        self,
        object: Any,
        stream: IO[str],
        indent: int,
        allowance: int,
        context: set[int],
        level: int,
    ) -> None:
        if type(object) is _types.SimpleNamespace:
            # The SimpleNamespace repr is "namespace" instead of the class
            # name, so we do the same here. For subclasses; use the class name.
            cls_name = "namespace"
        else:
            cls_name = object.__class__.__name__
        items = object.__dict__.items()
        stream.write(cls_name + "(")
        self._format_namespace_items(items, stream, indent, allowance, context, level)
        stream.write(")")

    _dispatch[_types.SimpleNamespace.__repr__] = _pprint_simplenamespace

    def _format_dict_items(
        self,
        items: list[tuple[Any, Any]],
        stream: IO[str],
        indent: int,
        allowance: int,
        context: set[int],
        level: int,
    ) -> None:
        if not items:
            return

        write = stream.write
        item_indent = indent + self._indent_per_level
        delimnl = "\n" + " " * item_indent
        for key, ent in items:
            write(delimnl)
            write(self._repr(key, context, level))
            write(": ")
            self._format(ent, stream, item_indent, 1, context, level)
            write(",")

        write("\n" + " " * indent)

    def _format_namespace_items(
        self,
        items: list[tuple[Any, Any]],
        stream: IO[str],
        indent: int,
        allowance: int,
        context: set[int],
        level: int,
    ) -> None:
        if not items:
            return

        write = stream.write
        item_indent = indent + self._indent_per_level
        delimnl = "\n" + " " * item_indent
        for key, ent in items:
            write(delimnl)
            write(key)
            write("=")
            if id(ent) in context:
                # Special-case representation of recursion to match standard
                # recursive dataclass repr.
                write("...")
            else:
                self._format(
                    ent,
                    stream,
                    item_indent + len(key) + 1,
                    1,
                    context,
                    level,
                )

            write(",")

        write("\n" + " " * indent)

    def _format_items(
        self,
        items: list[Any],
        stream: IO[str],
        indent: int,
        allowance: int,
        context: set[int],
        level: int,
    ) -> None:
        if not items:
            return

        write = stream.write
        item_indent = indent + self._indent_per_level
        delimnl = "\n" + " " * item_indent

        for item in items:
            write(delimnl)
            self._format(item, stream, item_indent, 1, context, level)
            write(",")

        write("\n" + " " * indent)

    def _repr(self, object: Any, context: set[int], level: int) -> str:
        return self._safe_repr(object, context.copy(), self._depth, level)

    def _pprint_default_dict(
        self,
        object: Any,
        stream: IO[str],
        indent: int,
        allowance: int,
        context: set[int],
        level: int,
    ) -> None:
        rdf = self._repr(object.default_factory, context, level)
        stream.write(f"{object.__class__.__name__}({rdf}, ")
        self._pprint_dict(object, stream, indent, allowance, context, level)
        stream.write(")")

    _dispatch[_collections.defaultdict.__repr__] = _pprint_default_dict

    def _pprint_counter(
        self,
        object: Any,
        stream: IO[str],
        indent: int,
        allowance: int,
        context: set[int],
        level: int,
    ) -> None:
        stream.write(object.__class__.__name__ + "(")

        if object:
            stream.write("{")
            items = object.most_common()
            self._format_dict_items(items, stream, indent, allowance, context, level)
            stream.write("}")

        stream.write(")")

    _dispatch[_collections.Counter.__repr__] = _pprint_counter

    def _pprint_chain_map(
        self,
        object: Any,
        stream: IO[str],
        indent: int,
        allowance: int,
        context: set[int],
        level: int,
    ) -> None:
        if not len(object.maps) or (len(object.maps) == 1 and not len(object.maps[0])):
            stream.write(repr(object))
            return

        stream.write(object.__class__.__name__ + "(")
        self._format_items(object.maps, stream, indent, allowance, context, level)
        stream.write(")")

    _dispatch[_collections.ChainMap.__repr__] = _pprint_chain_map

    def _pprint_deque(
        self,
        object: Any,
        stream: IO[str],
        indent: int,
        allowance: int,
        context: set[int],
        level: int,
    ) -> None:
        stream.write(object.__class__.__name__ + "(")
        if object.maxlen is not None:
            stream.write(f"maxlen={object.maxlen}, ")
        stream.write("[")

        self._format_items(object, stream, indent, allowance + 1, context, level)
        stream.write("])")

    _dispatch[_collections.deque.__repr__] = _pprint_deque

    def _pprint_user_dict(
        self,
        object: Any,
        stream: IO[str],
        indent: int,
        allowance: int,
        context: set[int],
        level: int,
    ) -> None:
        self._format(object.data, stream, indent, allowance, context, level - 1)

    _dispatch[_collections.UserDict.__repr__] = _pprint_user_dict

    def _pprint_user_list(
        self,
        object: Any,
        stream: IO[str],
        indent: int,
        allowance: int,
        context: set[int],
        level: int,
    ) -> None:
        self._format(object.data, stream, indent, allowance, context, level - 1)

    _dispatch[_collections.UserList.__repr__] = _pprint_user_list

    def _pprint_user_string(
        self,
        object: Any,
        stream: IO[str],
        indent: int,
        allowance: int,
        context: set[int],
        level: int,
    ) -> None:
        self._format(object.data, stream, indent, allowance, context, level - 1)

    _dispatch[_collections.UserString.__repr__] = _pprint_user_string

    def _safe_repr(
        self, object: Any, context: set[int], maxlevels: int | None, level: int
    ) -> str:
        typ = type(object)
        if typ in _builtin_scalars:
            return repr(object)

        r = getattr(typ, "__repr__", None)

        if issubclass(typ, dict) and r is dict.__repr__:
            if not object:
                return "{}"
            objid = id(object)
            if maxlevels and level >= maxlevels:
                return "{...}"
            if objid in context:
                return _recursion(object)
            context.add(objid)
            components: list[str] = []
            append = components.append
            level += 1
            for k, v in sorted(object.items(), key=_safe_tuple):
                krepr = self._safe_repr(k, context, maxlevels, level)
                vrepr = self._safe_repr(v, context, maxlevels, level)
                append(f"{krepr}: {vrepr}")
            context.remove(objid)
            return "{{{}}}".format(", ".join(components))

        if (issubclass(typ, list) and r is list.__repr__) or (
            issubclass(typ, tuple) and r is tuple.__repr__
        ):
            if issubclass(typ, list):
                if not object:
                    return "[]"
                format = "[%s]"
            elif len(object) == 1:
                format = "(%s,)"
            else:
                if not object:
                    return "()"
                format = "(%s)"
            objid = id(object)
            if maxlevels and level >= maxlevels:
                return format % "..."
            if objid in context:
                return _recursion(object)
            context.add(objid)
            components = []
            append = components.append
            level += 1
            for o in object:
                orepr = self._safe_repr(o, context, maxlevels, level)
                append(orepr)
            context.remove(objid)
            return format % ", ".join(components)

        return repr(object)


_builtin_scalars = frozenset(
    {str, bytes, bytearray, float, complex, bool, type(None), int}
)


def _recursion(object: Any) -> str:
    return f"<Recursion on {type(object).__name__} with id={id(object)}>"


def _wrap_bytes_repr(object: Any, width: int, allowance: int) -> Iterator[str]:
    current = b""
    last = len(object) // 4 * 4
    for i in range(0, len(object), 4):
        part = object[i : i + 4]
        candidate = current + part
        if i == last:
            width -= allowance
        if len(repr(candidate)) > width:
            if current:
                yield repr(current)
            current = part
        else:
            current = candidate
    if current:
        yield repr(current)
@@ -1,130 +0,0 @@
from __future__ import annotations

import pprint
import reprlib


def _try_repr_or_str(obj: object) -> str:
    try:
        return repr(obj)
    except (KeyboardInterrupt, SystemExit):
        raise
    except BaseException:
        return f'{type(obj).__name__}("{obj}")'


def _format_repr_exception(exc: BaseException, obj: object) -> str:
    try:
        exc_info = _try_repr_or_str(exc)
    except (KeyboardInterrupt, SystemExit):
        raise
    except BaseException as inner_exc:
        exc_info = f"unpresentable exception ({_try_repr_or_str(inner_exc)})"
    return (
        f"<[{exc_info} raised in repr()] {type(obj).__name__} object at 0x{id(obj):x}>"
    )


def _ellipsize(s: str, maxsize: int) -> str:
    if len(s) > maxsize:
        i = max(0, (maxsize - 3) // 2)
        j = max(0, maxsize - 3 - i)
        return s[:i] + "..." + s[len(s) - j :]
    return s


class SafeRepr(reprlib.Repr):
    """
    repr.Repr that limits the resulting size of repr() and includes
    information on exceptions raised during the call.
    """

    def __init__(self, maxsize: int | None, use_ascii: bool = False) -> None:
        """
        :param maxsize:
            If not None, will truncate the resulting repr to that specific size, using ellipsis
            somewhere in the middle to hide the extra text.
            If None, will not impose any size limits on the returning repr.
        """
        super().__init__()
        # ``maxstring`` is used by the superclass, and needs to be an int; using a
        # very large number in case maxsize is None, meaning we want to disable
        # truncation.
        self.maxstring = maxsize if maxsize is not None else 1_000_000_000
        self.maxsize = maxsize
        self.use_ascii = use_ascii

    def repr(self, x: object) -> str:
        try:
            if self.use_ascii:
                s = ascii(x)
            else:
                s = super().repr(x)
        except (KeyboardInterrupt, SystemExit):
            raise
        except BaseException as exc:
            s = _format_repr_exception(exc, x)
        if self.maxsize is not None:
            s = _ellipsize(s, self.maxsize)
        return s

    def repr_instance(self, x: object, level: int) -> str:
        try:
            s = repr(x)
        except (KeyboardInterrupt, SystemExit):
            raise
        except BaseException as exc:
            s = _format_repr_exception(exc, x)
        if self.maxsize is not None:
            s = _ellipsize(s, self.maxsize)
        return s


def safeformat(obj: object) -> str:
    """Return a pretty printed string for the given object.

    Failing __repr__ functions of user instances will be represented
    with a short exception info.
    """
    try:
        return pprint.pformat(obj)
    except Exception as exc:
        return _format_repr_exception(exc, obj)


# Maximum size of overall repr of objects to display during assertion errors.
DEFAULT_REPR_MAX_SIZE = 240


def saferepr(
    obj: object, maxsize: int | None = DEFAULT_REPR_MAX_SIZE, use_ascii: bool = False
) -> str:
    """Return a size-limited safe repr-string for the given object.

    Failing __repr__ functions of user instances will be represented
    with a short exception info and 'saferepr' generally takes
    care to never raise exceptions itself.

    This function is a wrapper around the Repr/reprlib functionality of the
    stdlib.
    """
    return SafeRepr(maxsize, use_ascii).repr(obj)


def saferepr_unlimited(obj: object, use_ascii: bool = True) -> str:
    """Return an unlimited-size safe repr-string for the given object.

    As with saferepr, failing __repr__ functions of user instances
    will be represented with a short exception info.

    This function is a wrapper around simple repr.

    Note: a cleaner solution would be to alter ``saferepr`` this way
    when maxsize=None, but that might affect some other code.
    """
    try:
        if use_ascii:
            return ascii(obj)
        return repr(obj)
    except Exception as exc:
        return _format_repr_exception(exc, obj)
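For context, a minimal usage sketch of the saferepr helper deleted above (illustrative only; the _pytest._io.saferepr import path and the Broken class are assumptions, not shown in this diff):

    # Assumed import path for the module removed above.
    from _pytest._io.saferepr import saferepr

    class Broken:
        def __repr__(self) -> str:
            raise RuntimeError("boom")  # hypothetical failing __repr__

    # saferepr swallows the exception and returns a short placeholder
    # such as "<[RuntimeError('boom') raised in repr()] Broken object at 0x...>".
    print(saferepr(Broken()))

    # It also caps very long reprs (240 characters by default), inserting "..."
    # in the middle rather than raising or printing the whole thing.
    print(saferepr(list(range(10_000))))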
@@ -1,254 +0,0 @@
"""Helper functions for writing to terminals and files."""

from __future__ import annotations

from collections.abc import Sequence
import os
import shutil
import sys
from typing import final
from typing import Literal
from typing import TextIO

import pygments
from pygments.formatters.terminal import TerminalFormatter
from pygments.lexer import Lexer
from pygments.lexers.diff import DiffLexer
from pygments.lexers.python import PythonLexer

from ..compat import assert_never
from .wcwidth import wcswidth


# This code was initially copied from py 1.8.1, file _io/terminalwriter.py.


def get_terminal_width() -> int:
    width, _ = shutil.get_terminal_size(fallback=(80, 24))

    # The Windows get_terminal_size may be bogus, let's sanity-check it a bit.
    if width < 40:
        width = 80

    return width


def should_do_markup(file: TextIO) -> bool:
    if os.environ.get("PY_COLORS") == "1":
        return True
    if os.environ.get("PY_COLORS") == "0":
        return False
    if os.environ.get("NO_COLOR"):
        return False
    if os.environ.get("FORCE_COLOR"):
        return True
    return (
        hasattr(file, "isatty") and file.isatty() and os.environ.get("TERM") != "dumb"
    )


@final
class TerminalWriter:
    _esctable = dict(
        black=30,
        red=31,
        green=32,
        yellow=33,
        blue=34,
        purple=35,
        cyan=36,
        white=37,
        Black=40,
        Red=41,
        Green=42,
        Yellow=43,
        Blue=44,
        Purple=45,
        Cyan=46,
        White=47,
        bold=1,
        light=2,
        blink=5,
        invert=7,
    )

    def __init__(self, file: TextIO | None = None) -> None:
        if file is None:
            file = sys.stdout
        if hasattr(file, "isatty") and file.isatty() and sys.platform == "win32":
            try:
                import colorama
            except ImportError:
                pass
            else:
                file = colorama.AnsiToWin32(file).stream
                assert file is not None
        self._file = file
        self.hasmarkup = should_do_markup(file)
        self._current_line = ""
        self._terminal_width: int | None = None
        self.code_highlight = True

    @property
    def fullwidth(self) -> int:
        if self._terminal_width is not None:
            return self._terminal_width
        return get_terminal_width()

    @fullwidth.setter
    def fullwidth(self, value: int) -> None:
        self._terminal_width = value

    @property
    def width_of_current_line(self) -> int:
        """Return an estimate of the width so far in the current line."""
        return wcswidth(self._current_line)

    def markup(self, text: str, **markup: bool) -> str:
        for name in markup:
            if name not in self._esctable:
                raise ValueError(f"unknown markup: {name!r}")
        if self.hasmarkup:
            esc = [self._esctable[name] for name, on in markup.items() if on]
            if esc:
                text = "".join(f"\x1b[{cod}m" for cod in esc) + text + "\x1b[0m"
        return text

    def sep(
        self,
        sepchar: str,
        title: str | None = None,
        fullwidth: int | None = None,
        **markup: bool,
    ) -> None:
        if fullwidth is None:
            fullwidth = self.fullwidth
        # The goal is to have the line be as long as possible
        # under the condition that len(line) <= fullwidth.
        if sys.platform == "win32":
            # If we print in the last column on windows we are on a
            # new line but there is no way to verify/neutralize this
            # (we may not know the exact line width).
            # So let's be defensive to avoid empty lines in the output.
            fullwidth -= 1
        if title is not None:
            # we want 2 + 2*len(fill) + len(title) <= fullwidth
            # i.e.    2 + 2*len(sepchar)*N + len(title) <= fullwidth
            #         2*len(sepchar)*N <= fullwidth - len(title) - 2
            #         N <= (fullwidth - len(title) - 2) // (2*len(sepchar))
            N = max((fullwidth - len(title) - 2) // (2 * len(sepchar)), 1)
            fill = sepchar * N
            line = f"{fill} {title} {fill}"
        else:
            # we want len(sepchar)*N <= fullwidth
            # i.e.    N <= fullwidth // len(sepchar)
            line = sepchar * (fullwidth // len(sepchar))
        # In some situations there is room for an extra sepchar at the right,
        # in particular if we consider that with a sepchar like "_ " the
        # trailing space is not important at the end of the line.
        if len(line) + len(sepchar.rstrip()) <= fullwidth:
            line += sepchar.rstrip()

        self.line(line, **markup)

    def write(self, msg: str, *, flush: bool = False, **markup: bool) -> None:
        if msg:
            current_line = msg.rsplit("\n", 1)[-1]
            if "\n" in msg:
                self._current_line = current_line
            else:
                self._current_line += current_line

            msg = self.markup(msg, **markup)

            try:
                self._file.write(msg)
            except UnicodeEncodeError:
                # Some environments don't support printing general Unicode
                # strings, due to misconfiguration or otherwise; in that case,
                # print the string escaped to ASCII.
                # When the Unicode situation improves we should consider
                # letting the error propagate instead of masking it (see #7475
                # for one brief attempt).
                msg = msg.encode("unicode-escape").decode("ascii")
                self._file.write(msg)

            if flush:
                self.flush()

    def line(self, s: str = "", **markup: bool) -> None:
        self.write(s, **markup)
        self.write("\n")

    def flush(self) -> None:
        self._file.flush()

    def _write_source(self, lines: Sequence[str], indents: Sequence[str] = ()) -> None:
        """Write lines of source code possibly highlighted.

        Keeping this private for now because the API is clunky. We should discuss how
        to evolve the terminal writer so we can have more precise color support, for example
        being able to write part of a line in one color and the rest in another, and so on.
        """
        if indents and len(indents) != len(lines):
            raise ValueError(
                f"indents size ({len(indents)}) should have same size as lines ({len(lines)})"
            )
        if not indents:
            indents = [""] * len(lines)
        source = "\n".join(lines)
        new_lines = self._highlight(source).splitlines()
        for indent, new_line in zip(indents, new_lines):
            self.line(indent + new_line)

    def _get_pygments_lexer(self, lexer: Literal["python", "diff"]) -> Lexer:
        if lexer == "python":
            return PythonLexer()
        elif lexer == "diff":
            return DiffLexer()
        else:
            assert_never(lexer)

    def _get_pygments_formatter(self) -> TerminalFormatter:
        from _pytest.config.exceptions import UsageError

        theme = os.getenv("PYTEST_THEME")
        theme_mode = os.getenv("PYTEST_THEME_MODE", "dark")

        try:
            return TerminalFormatter(bg=theme_mode, style=theme)
        except pygments.util.ClassNotFound as e:
            raise UsageError(
                f"PYTEST_THEME environment variable has an invalid value: '{theme}'. "
                "Hint: See available pygments styles with `pygmentize -L styles`."
            ) from e
        except pygments.util.OptionError as e:
            raise UsageError(
                f"PYTEST_THEME_MODE environment variable has an invalid value: '{theme_mode}'. "
                "The allowed values are 'dark' (default) and 'light'."
            ) from e

    def _highlight(
        self, source: str, lexer: Literal["diff", "python"] = "python"
    ) -> str:
        """Highlight the given source if we have markup support."""
        if not source or not self.hasmarkup or not self.code_highlight:
            return source

        pygments_lexer = self._get_pygments_lexer(lexer)
        pygments_formatter = self._get_pygments_formatter()

        highlighted: str = pygments.highlight(
            source, pygments_lexer, pygments_formatter
        )
        # pygments terminal formatter may add a newline when there wasn't one.
        # We don't want this, remove.
        if highlighted[-1] == "\n" and source[-1] != "\n":
            highlighted = highlighted[:-1]

        # Some lexers will not set the initial color explicitly
        # which may lead to the previous color being propagated to the
        # start of the expression, so reset first.
        highlighted = "\x1b[0m" + highlighted

        return highlighted
@@ -1,57 +0,0 @@
from __future__ import annotations

from functools import lru_cache
import unicodedata


@lru_cache(100)
def wcwidth(c: str) -> int:
    """Determine how many columns are needed to display a character in a terminal.

    Returns -1 if the character is not printable.
    Returns 0, 1 or 2 for other characters.
    """
    o = ord(c)

    # ASCII fast path.
    if 0x20 <= o < 0x07F:
        return 1

    # Some Cf/Zp/Zl characters which should be zero-width.
    if (
        o == 0x0000
        or 0x200B <= o <= 0x200F
        or 0x2028 <= o <= 0x202E
        or 0x2060 <= o <= 0x2063
    ):
        return 0

    category = unicodedata.category(c)

    # Control characters.
    if category == "Cc":
        return -1

    # Combining characters with zero width.
    if category in ("Me", "Mn"):
        return 0

    # Full/Wide east asian characters.
    if unicodedata.east_asian_width(c) in ("F", "W"):
        return 2

    return 1


def wcswidth(s: str) -> int:
    """Determine how many columns are needed to display a string in a terminal.

    Returns -1 if the string contains non-printable characters.
    """
    width = 0
    for c in unicodedata.normalize("NFC", s):
        wc = wcwidth(c)
        if wc < 0:
            return -1
        width += wc
    return width
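As a brief illustration (not part of the diff itself): wcswidth counts terminal columns rather than characters, so wide East Asian characters take two columns and combining marks take none. The import path below is an assumption based on the module names above.

    # Assumed import path for the helpers removed above.
    from _pytest._io.wcwidth import wcswidth

    assert wcswidth("abc") == 3      # plain ASCII: one column per character
    assert wcswidth("日本語") == 6    # wide characters: two columns each
    assert wcswidth("e\u0301") == 1  # 'e' plus a combining accent normalizes to one column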
@@ -1,119 +0,0 @@
"""create errno-specific classes for IO or os calls."""

from __future__ import annotations

from collections.abc import Callable
import errno
import os
import sys
from typing import TYPE_CHECKING
from typing import TypeVar


if TYPE_CHECKING:
    from typing_extensions import ParamSpec

    P = ParamSpec("P")

R = TypeVar("R")


class Error(EnvironmentError):
    def __repr__(self) -> str:
        return "{}.{} {!r}: {} ".format(
            self.__class__.__module__,
            self.__class__.__name__,
            self.__class__.__doc__,
            " ".join(map(str, self.args)),
            # repr(self.args)
        )

    def __str__(self) -> str:
        s = "[{}]: {}".format(
            self.__class__.__doc__,
            " ".join(map(str, self.args)),
        )
        return s


_winerrnomap = {
    2: errno.ENOENT,
    3: errno.ENOENT,
    17: errno.EEXIST,
    18: errno.EXDEV,
    13: errno.EBUSY,  # empty cd drive, but ENOMEDIUM seems unavailable
    22: errno.ENOTDIR,
    20: errno.ENOTDIR,
    267: errno.ENOTDIR,
    5: errno.EACCES,  # anything better?
}


class ErrorMaker:
    """lazily provides Exception classes for each possible POSIX errno
    (as defined per the 'errno' module).  All such instances
    subclass EnvironmentError.
    """

    _errno2class: dict[int, type[Error]] = {}

    def __getattr__(self, name: str) -> type[Error]:
        if name[0] == "_":
            raise AttributeError(name)
        eno = getattr(errno, name)
        cls = self._geterrnoclass(eno)
        setattr(self, name, cls)
        return cls

    def _geterrnoclass(self, eno: int) -> type[Error]:
        try:
            return self._errno2class[eno]
        except KeyError:
            clsname = errno.errorcode.get(eno, f"UnknownErrno{eno}")
            errorcls = type(
                clsname,
                (Error,),
                {"__module__": "py.error", "__doc__": os.strerror(eno)},
            )
            self._errno2class[eno] = errorcls
            return errorcls

    def checked_call(
        self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
    ) -> R:
        """Call a function and raise an errno-exception if applicable."""
        __tracebackhide__ = True
        try:
            return func(*args, **kwargs)
        except Error:
            raise
        except OSError as value:
            if not hasattr(value, "errno"):
                raise
            if sys.platform == "win32":
                try:
                    # error: Invalid index type "Optional[int]" for "dict[int, int]"; expected type "int"  [index]
                    # OK to ignore because we catch the KeyError below.
                    cls = self._geterrnoclass(_winerrnomap[value.errno])  # type:ignore[index]
                except KeyError:
                    raise value
            else:
                # we are not on Windows, or we got a proper OSError
                if value.errno is None:
                    cls = type(
                        "UnknownErrnoNone",
                        (Error,),
                        {"__module__": "py.error", "__doc__": None},
                    )
                else:
                    cls = self._geterrnoclass(value.errno)

            raise cls(f"{func.__name__}{args!r}")


_error_maker = ErrorMaker()
checked_call = _error_maker.checked_call


def __getattr__(attr: str) -> type[Error]:
    return getattr(_error_maker, attr)  # type: ignore[no-any-return]
File diff suppressed because it is too large
@@ -1,21 +0,0 @@
# file generated by setuptools-scm
# don't change, don't track in version control

__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]

TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple
    from typing import Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    VERSION_TUPLE = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

__version__ = version = '8.4.1'
__version_tuple__ = version_tuple = (8, 4, 1)
@@ -1,208 +0,0 @@
# mypy: allow-untyped-defs
"""Support for presenting detailed information in failing assertions."""

from __future__ import annotations

from collections.abc import Generator
import sys
from typing import Any
from typing import Protocol
from typing import TYPE_CHECKING

from _pytest.assertion import rewrite
from _pytest.assertion import truncate
from _pytest.assertion import util
from _pytest.assertion.rewrite import assertstate_key
from _pytest.config import Config
from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser
from _pytest.nodes import Item


if TYPE_CHECKING:
    from _pytest.main import Session


def pytest_addoption(parser: Parser) -> None:
    group = parser.getgroup("debugconfig")
    group.addoption(
        "--assert",
        action="store",
        dest="assertmode",
        choices=("rewrite", "plain"),
        default="rewrite",
        metavar="MODE",
        help=(
            "Control assertion debugging tools.\n"
            "'plain' performs no assertion debugging.\n"
            "'rewrite' (the default) rewrites assert statements in test modules"
            " on import to provide assert expression information."
        ),
    )
    parser.addini(
        "enable_assertion_pass_hook",
        type="bool",
        default=False,
        help="Enables the pytest_assertion_pass hook. "
        "Make sure to delete any previously generated pyc cache files.",
    )

    parser.addini(
        "truncation_limit_lines",
        default=None,
        help="Set threshold of LINES after which truncation will take effect",
    )
    parser.addini(
        "truncation_limit_chars",
        default=None,
        help=("Set threshold of CHARS after which truncation will take effect"),
    )

    Config._add_verbosity_ini(
        parser,
        Config.VERBOSITY_ASSERTIONS,
        help=(
            "Specify a verbosity level for assertions, overriding the main level. "
            "Higher levels will provide more detailed explanation when an assertion fails."
        ),
    )


def register_assert_rewrite(*names: str) -> None:
    """Register one or more module names to be rewritten on import.

    This function will make sure that this module or all modules inside
    the package will get their assert statements rewritten.
    Thus you should make sure to call this before the module is
    actually imported, usually in your __init__.py if you are a plugin
    using a package.

    :param names: The module names to register.
    """
    for name in names:
        if not isinstance(name, str):
            msg = "expected module names as *args, got {0} instead"  # type: ignore[unreachable]
            raise TypeError(msg.format(repr(names)))
    rewrite_hook: RewriteHook
    for hook in sys.meta_path:
        if isinstance(hook, rewrite.AssertionRewritingHook):
            rewrite_hook = hook
            break
    else:
        rewrite_hook = DummyRewriteHook()
    rewrite_hook.mark_rewrite(*names)


class RewriteHook(Protocol):
    def mark_rewrite(self, *names: str) -> None: ...


class DummyRewriteHook:
    """A no-op import hook for when rewriting is disabled."""

    def mark_rewrite(self, *names: str) -> None:
        pass


class AssertionState:
    """State for the assertion plugin."""

    def __init__(self, config: Config, mode) -> None:
        self.mode = mode
        self.trace = config.trace.root.get("assertion")
        self.hook: rewrite.AssertionRewritingHook | None = None


def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:
    """Try to install the rewrite hook, raise SystemError if it fails."""
    config.stash[assertstate_key] = AssertionState(config, "rewrite")
    config.stash[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config)
    sys.meta_path.insert(0, hook)
    config.stash[assertstate_key].trace("installed rewrite import hook")

    def undo() -> None:
        hook = config.stash[assertstate_key].hook
        if hook is not None and hook in sys.meta_path:
            sys.meta_path.remove(hook)

    config.add_cleanup(undo)
    return hook


def pytest_collection(session: Session) -> None:
    # This hook is only called when test modules are collected
    # so for example not in the managing process of pytest-xdist
    # (which does not collect test modules).
    assertstate = session.config.stash.get(assertstate_key, None)
    if assertstate:
        if assertstate.hook is not None:
            assertstate.hook.set_session(session)


@hookimpl(wrapper=True, tryfirst=True)
def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
    """Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks.

    The rewrite module will use util._reprcompare if it exists to use custom
    reporting via the pytest_assertrepr_compare hook.  This sets up this custom
    comparison for the test.
    """
    ihook = item.ihook

    def callbinrepr(op, left: object, right: object) -> str | None:
        """Call the pytest_assertrepr_compare hook and prepare the result.

        This uses the first result from the hook and then ensures the
        following:
        * Overly verbose explanations are truncated unless configured otherwise
          (eg. if running in verbose mode).
        * Embedded newlines are escaped to help util.format_explanation()
          later.
        * If the rewrite mode is used embedded %-characters are replaced
          to protect later % formatting.

        The result can be formatted by util.format_explanation() for
        pretty printing.
        """
        hook_result = ihook.pytest_assertrepr_compare(
            config=item.config, op=op, left=left, right=right
        )
        for new_expl in hook_result:
            if new_expl:
                new_expl = truncate.truncate_if_required(new_expl, item)
                new_expl = [line.replace("\n", "\\n") for line in new_expl]
                res = "\n~".join(new_expl)
                if item.config.getvalue("assertmode") == "rewrite":
                    res = res.replace("%", "%%")
                return res
        return None

    saved_assert_hooks = util._reprcompare, util._assertion_pass
    util._reprcompare = callbinrepr
    util._config = item.config

    if ihook.pytest_assertion_pass.get_hookimpls():

        def call_assertion_pass_hook(lineno: int, orig: str, expl: str) -> None:
            ihook.pytest_assertion_pass(item=item, lineno=lineno, orig=orig, expl=expl)

        util._assertion_pass = call_assertion_pass_hook

    try:
        return (yield)
    finally:
        util._reprcompare, util._assertion_pass = saved_assert_hooks
        util._config = None


def pytest_sessionfinish(session: Session) -> None:
    assertstate = session.config.stash.get(assertstate_key, None)
    if assertstate:
        if assertstate.hook is not None:
            assertstate.hook.set_session(None)


def pytest_assertrepr_compare(
    config: Config, op: str, left: Any, right: Any
) -> list[str] | None:
    return util.assertrepr_compare(config=config, op=op, left=left, right=right)
File diff suppressed because it is too large
Load Diff
@@ -1,137 +0,0 @@
|
|||||||
"""Utilities for truncating assertion output.
|
|
||||||
|
|
||||||
Current default behaviour is to truncate assertion explanations at
|
|
||||||
terminal lines, unless running with an assertions verbosity level of at least 2 or running on CI.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from _pytest.assertion import util
|
|
||||||
from _pytest.config import Config
|
|
||||||
from _pytest.nodes import Item
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_MAX_LINES = 8
|
|
||||||
DEFAULT_MAX_CHARS = DEFAULT_MAX_LINES * 80
|
|
||||||
USAGE_MSG = "use '-vv' to show"
|
|
||||||
|
|
||||||
|
|
||||||
def truncate_if_required(explanation: list[str], item: Item) -> list[str]:
|
|
||||||
"""Truncate this assertion explanation if the given test item is eligible."""
|
|
||||||
should_truncate, max_lines, max_chars = _get_truncation_parameters(item)
|
|
||||||
if should_truncate:
|
|
||||||
return _truncate_explanation(
|
|
||||||
explanation,
|
|
||||||
max_lines=max_lines,
|
|
||||||
max_chars=max_chars,
|
|
||||||
)
|
|
||||||
return explanation
|
|
||||||
|
|
||||||
|
|
||||||
def _get_truncation_parameters(item: Item) -> tuple[bool, int, int]:
|
|
||||||
"""Return the truncation parameters related to the given item, as (should truncate, max lines, max chars)."""
|
|
||||||
# We do not need to truncate if one of conditions is met:
|
|
||||||
# 1. Verbosity level is 2 or more;
|
|
||||||
# 2. Test is being run in CI environment;
|
|
||||||
# 3. Both truncation_limit_lines and truncation_limit_chars
|
|
||||||
# .ini parameters are set to 0 explicitly.
|
|
||||||
max_lines = item.config.getini("truncation_limit_lines")
|
|
||||||
max_lines = int(max_lines if max_lines is not None else DEFAULT_MAX_LINES)
|
|
||||||
|
|
||||||
max_chars = item.config.getini("truncation_limit_chars")
|
|
||||||
max_chars = int(max_chars if max_chars is not None else DEFAULT_MAX_CHARS)
|
|
||||||
|
|
||||||
verbose = item.config.get_verbosity(Config.VERBOSITY_ASSERTIONS)
|
|
||||||
|
|
||||||
should_truncate = verbose < 2 and not util.running_on_ci()
|
|
||||||
should_truncate = should_truncate and (max_lines > 0 or max_chars > 0)
|
|
||||||
|
|
||||||
return should_truncate, max_lines, max_chars
|
|
||||||
|
|
||||||
|
|
||||||
def _truncate_explanation(
|
|
||||||
input_lines: list[str],
|
|
||||||
max_lines: int,
|
|
||||||
max_chars: int,
|
|
||||||
) -> list[str]:
|
|
||||||
"""Truncate given list of strings that makes up the assertion explanation.
|
|
||||||
|
|
||||||
Truncates to either max_lines, or max_chars - whichever the input reaches
|
|
||||||
first, taking the truncation explanation into account. The remaining lines
|
|
||||||
will be replaced by a usage message.
|
|
||||||
"""
|
|
||||||
# Check if truncation required
|
|
||||||
input_char_count = len("".join(input_lines))
|
|
||||||
# The length of the truncation explanation depends on the number of lines
|
|
||||||
# removed but is at least 68 characters:
|
|
||||||
# The real value is
|
|
||||||
# 64 (for the base message:
|
|
||||||
# '...\n...Full output truncated (1 line hidden), use '-vv' to show")'
|
|
||||||
# )
|
|
||||||
# + 1 (for plural)
|
|
||||||
# + int(math.log10(len(input_lines) - max_lines)) (number of hidden line, at least 1)
|
|
||||||
# + 3 for the '...' added to the truncated line
|
|
||||||
# But if there's more than 100 lines it's very likely that we're going to
|
|
||||||
# truncate, so we don't need the exact value using log10.
|
|
||||||
tolerable_max_chars = (
|
|
||||||
max_chars + 70 # 64 + 1 (for plural) + 2 (for '99') + 3 for '...'
|
|
||||||
)
|
|
||||||
# The truncation explanation add two lines to the output
|
|
||||||
tolerable_max_lines = max_lines + 2
|
|
||||||
if (
|
|
||||||
len(input_lines) <= tolerable_max_lines
|
|
||||||
and input_char_count <= tolerable_max_chars
|
|
||||||
):
|
|
||||||
return input_lines
|
|
||||||
# Truncate first to max_lines, and then truncate to max_chars if necessary
|
|
||||||
if max_lines > 0:
|
|
||||||
truncated_explanation = input_lines[:max_lines]
|
|
||||||
else:
|
|
||||||
truncated_explanation = input_lines
|
|
||||||
truncated_char = True
|
|
||||||
# We reevaluate the need to truncate chars following removal of some lines
|
|
||||||
if len("".join(truncated_explanation)) > tolerable_max_chars and max_chars > 0:
|
|
||||||
truncated_explanation = _truncate_by_char_count(
|
|
||||||
truncated_explanation, max_chars
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
truncated_char = False
|
|
||||||
|
|
||||||
if truncated_explanation == input_lines:
|
|
||||||
# No truncation happened, so we do not need to add any explanations
|
|
||||||
return truncated_explanation
|
|
||||||
|
|
||||||
truncated_line_count = len(input_lines) - len(truncated_explanation)
|
|
||||||
if truncated_explanation[-1]:
|
|
||||||
# Add ellipsis and take into account part-truncated final line
|
|
||||||
truncated_explanation[-1] = truncated_explanation[-1] + "..."
|
|
||||||
if truncated_char:
|
|
||||||
# It's possible that we did not remove any char from this line
|
|
||||||
truncated_line_count += 1
|
|
||||||
else:
|
|
||||||
# Add proper ellipsis when we were able to fit a full line exactly
|
|
||||||
truncated_explanation[-1] = "..."
|
|
||||||
return [
|
|
||||||
*truncated_explanation,
|
|
||||||
"",
|
|
||||||
f"...Full output truncated ({truncated_line_count} line"
|
|
||||||
f"{'' if truncated_line_count == 1 else 's'} hidden), {USAGE_MSG}",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _truncate_by_char_count(input_lines: list[str], max_chars: int) -> list[str]:
|
|
||||||
# Find point at which input length exceeds total allowed length
|
|
||||||
iterated_char_count = 0
|
|
||||||
for iterated_index, input_line in enumerate(input_lines):
|
|
||||||
if iterated_char_count + len(input_line) > max_chars:
|
|
||||||
break
|
|
||||||
iterated_char_count += len(input_line)
|
|
||||||
|
|
||||||
# Create truncated explanation with modified final line
|
|
||||||
truncated_result = input_lines[:iterated_index]
|
|
||||||
final_line = input_lines[iterated_index]
|
|
||||||
if final_line:
|
|
||||||
final_line_truncate_point = max_chars - iterated_char_count
|
|
||||||
final_line = final_line[:final_line_truncate_point]
|
|
||||||
truncated_result.append(final_line)
|
|
||||||
return truncated_result
|
|
||||||
@@ -1,621 +0,0 @@
|
|||||||
# mypy: allow-untyped-defs
|
|
||||||
"""Utilities for assertion debugging."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import collections.abc
|
|
||||||
from collections.abc import Callable
|
|
||||||
from collections.abc import Iterable
|
|
||||||
from collections.abc import Mapping
|
|
||||||
from collections.abc import Sequence
|
|
||||||
from collections.abc import Set as AbstractSet
|
|
||||||
import os
|
|
||||||
import pprint
|
|
||||||
from typing import Any
|
|
||||||
from typing import Literal
|
|
||||||
from typing import Protocol
|
|
||||||
from unicodedata import normalize
|
|
||||||
|
|
||||||
from _pytest import outcomes
|
|
||||||
import _pytest._code
|
|
||||||
from _pytest._io.pprint import PrettyPrinter
|
|
||||||
from _pytest._io.saferepr import saferepr
|
|
||||||
from _pytest._io.saferepr import saferepr_unlimited
|
|
||||||
from _pytest.config import Config
|
|
||||||
|
|
||||||
|
|
||||||
# The _reprcompare attribute on the util module is used by the new assertion
|
|
||||||
# interpretation code and assertion rewriter to detect this plugin was
|
|
||||||
# loaded and in turn call the hooks defined here as part of the
|
|
||||||
# DebugInterpreter.
|
|
||||||
_reprcompare: Callable[[str, object, object], str | None] | None = None
|
|
||||||
|
|
||||||
# Works similarly as _reprcompare attribute. Is populated with the hook call
|
|
||||||
# when pytest_runtest_setup is called.
|
|
||||||
_assertion_pass: Callable[[int, str, str], None] | None = None
|
|
||||||
|
|
||||||
# Config object which is assigned during pytest_runtest_protocol.
|
|
||||||
_config: Config | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class _HighlightFunc(Protocol):
|
|
||||||
def __call__(self, source: str, lexer: Literal["diff", "python"] = "python") -> str:
|
|
||||||
"""Apply highlighting to the given source."""
|
|
||||||
|
|
||||||
|
|
||||||
def dummy_highlighter(source: str, lexer: Literal["diff", "python"] = "python") -> str:
|
|
||||||
"""Dummy highlighter that returns the text unprocessed.
|
|
||||||
|
|
||||||
Needed for _notin_text, as the diff gets post-processed to only show the "+" part.
|
|
||||||
"""
|
|
||||||
return source
|
|
||||||
|
|
||||||
|
|
||||||
def format_explanation(explanation: str) -> str:
|
|
||||||
r"""Format an explanation.
|
|
||||||
|
|
||||||
Normally all embedded newlines are escaped, however there are
|
|
||||||
three exceptions: \n{, \n} and \n~. The first two are intended
|
|
||||||
cover nested explanations, see function and attribute explanations
|
|
||||||
for examples (.visit_Call(), visit_Attribute()). The last one is
|
|
||||||
for when one explanation needs to span multiple lines, e.g. when
|
|
||||||
displaying diffs.
|
|
||||||
"""
|
|
||||||
lines = _split_explanation(explanation)
|
|
||||||
result = _format_lines(lines)
|
|
||||||
return "\n".join(result)
|
|
||||||
|
|
||||||
|
|
||||||
def _split_explanation(explanation: str) -> list[str]:
|
|
||||||
r"""Return a list of individual lines in the explanation.
|
|
||||||
|
|
||||||
This will return a list of lines split on '\n{', '\n}' and '\n~'.
|
|
||||||
Any other newlines will be escaped and appear in the line as the
|
|
||||||
literal '\n' characters.
|
|
||||||
"""
|
|
||||||
raw_lines = (explanation or "").split("\n")
|
|
||||||
lines = [raw_lines[0]]
|
|
||||||
for values in raw_lines[1:]:
|
|
||||||
if values and values[0] in ["{", "}", "~", ">"]:
|
|
||||||
lines.append(values)
|
|
||||||
else:
|
|
||||||
lines[-1] += "\\n" + values
|
|
||||||
return lines
|
|
||||||
|
|
||||||
|
|
||||||
def _format_lines(lines: Sequence[str]) -> list[str]:
|
|
||||||
"""Format the individual lines.
|
|
||||||
|
|
||||||
This will replace the '{', '}' and '~' characters of our mini formatting
|
|
||||||
language with the proper 'where ...', 'and ...' and ' + ...' text, taking
|
|
||||||
care of indentation along the way.
|
|
||||||
|
|
||||||
Return a list of formatted lines.
|
|
||||||
"""
|
|
||||||
result = list(lines[:1])
|
|
||||||
stack = [0]
|
|
||||||
stackcnt = [0]
|
|
||||||
for line in lines[1:]:
|
|
||||||
if line.startswith("{"):
|
|
||||||
if stackcnt[-1]:
|
|
||||||
s = "and "
|
|
||||||
else:
|
|
||||||
s = "where "
|
|
||||||
stack.append(len(result))
|
|
||||||
stackcnt[-1] += 1
|
|
||||||
stackcnt.append(0)
|
|
||||||
result.append(" +" + " " * (len(stack) - 1) + s + line[1:])
|
|
||||||
elif line.startswith("}"):
|
|
||||||
stack.pop()
|
|
||||||
stackcnt.pop()
|
|
||||||
result[stack[-1]] += line[1:]
|
|
||||||
else:
|
|
||||||
assert line[0] in ["~", ">"]
|
|
||||||
stack[-1] += 1
|
|
||||||
indent = len(stack) if line.startswith("~") else len(stack) - 1
|
|
||||||
result.append(" " * indent + line[1:])
|
|
||||||
assert len(stack) == 1
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def issequence(x: Any) -> bool:
|
|
||||||
return isinstance(x, collections.abc.Sequence) and not isinstance(x, str)
|
|
||||||
|
|
||||||
|
|
||||||
def istext(x: Any) -> bool:
|
|
||||||
return isinstance(x, str)
|
|
||||||
|
|
||||||
|
|
||||||
def isdict(x: Any) -> bool:
|
|
||||||
return isinstance(x, dict)
|
|
||||||
|
|
||||||
|
|
||||||
def isset(x: Any) -> bool:
|
|
||||||
return isinstance(x, (set, frozenset))
|
|
||||||
|
|
||||||
|
|
||||||
def isnamedtuple(obj: Any) -> bool:
|
|
||||||
return isinstance(obj, tuple) and getattr(obj, "_fields", None) is not None
|
|
||||||
|
|
||||||
|
|
||||||
def isdatacls(obj: Any) -> bool:
|
|
||||||
return getattr(obj, "__dataclass_fields__", None) is not None
|
|
||||||
|
|
||||||
|
|
||||||
def isattrs(obj: Any) -> bool:
|
|
||||||
return getattr(obj, "__attrs_attrs__", None) is not None
|
|
||||||
|
|
||||||
|
|
||||||
def isiterable(obj: Any) -> bool:
|
|
||||||
try:
|
|
||||||
iter(obj)
|
|
||||||
return not istext(obj)
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def has_default_eq(
|
|
||||||
obj: object,
|
|
||||||
) -> bool:
|
|
||||||
"""Check if an instance of an object contains the default eq
|
|
||||||
|
|
||||||
First, we check if the object's __eq__ attribute has __code__,
|
|
||||||
if so, we check the equally of the method code filename (__code__.co_filename)
|
|
||||||
to the default one generated by the dataclass and attr module
|
|
||||||
for dataclasses the default co_filename is <string>, for attrs class, the __eq__ should contain "attrs eq generated"
|
|
||||||
"""
|
|
||||||
# inspired from https://github.com/willmcgugan/rich/blob/07d51ffc1aee6f16bd2e5a25b4e82850fb9ed778/rich/pretty.py#L68
|
|
||||||
if hasattr(obj.__eq__, "__code__") and hasattr(obj.__eq__.__code__, "co_filename"):
|
|
||||||
code_filename = obj.__eq__.__code__.co_filename
|
|
||||||
|
|
||||||
if isattrs(obj):
|
|
||||||
return "attrs generated " in code_filename
|
|
||||||
|
|
||||||
return code_filename == "<string>" # data class
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def assertrepr_compare(
|
|
||||||
config, op: str, left: Any, right: Any, use_ascii: bool = False
|
|
||||||
) -> list[str] | None:
|
|
||||||
"""Return specialised explanations for some operators/operands."""
|
|
||||||
verbose = config.get_verbosity(Config.VERBOSITY_ASSERTIONS)
|
|
||||||
|
|
||||||
# Strings which normalize equal are often hard to distinguish when printed; use ascii() to make this easier.
|
|
||||||
# See issue #3246.
|
|
||||||
use_ascii = (
|
|
||||||
isinstance(left, str)
|
|
||||||
and isinstance(right, str)
|
|
||||||
and normalize("NFD", left) == normalize("NFD", right)
|
|
||||||
)
|
|
||||||
|
|
||||||
if verbose > 1:
|
|
||||||
left_repr = saferepr_unlimited(left, use_ascii=use_ascii)
|
|
||||||
right_repr = saferepr_unlimited(right, use_ascii=use_ascii)
|
|
||||||
else:
|
|
||||||
# XXX: "15 chars indentation" is wrong
|
|
||||||
# ("E AssertionError: assert "); should use term width.
|
|
||||||
maxsize = (
|
|
||||||
80 - 15 - len(op) - 2
|
|
||||||
) // 2 # 15 chars indentation, 1 space around op
|
|
||||||
|
|
||||||
left_repr = saferepr(left, maxsize=maxsize, use_ascii=use_ascii)
|
|
||||||
right_repr = saferepr(right, maxsize=maxsize, use_ascii=use_ascii)
|
|
||||||
|
|
||||||
summary = f"{left_repr} {op} {right_repr}"
|
|
||||||
highlighter = config.get_terminal_writer()._highlight
|
|
||||||
|
|
||||||
explanation = None
|
|
||||||
try:
|
|
||||||
if op == "==":
|
|
||||||
explanation = _compare_eq_any(left, right, highlighter, verbose)
|
|
||||||
elif op == "not in":
|
|
||||||
if istext(left) and istext(right):
|
|
||||||
explanation = _notin_text(left, right, verbose)
|
|
||||||
elif op == "!=":
|
|
||||||
if isset(left) and isset(right):
|
|
||||||
explanation = ["Both sets are equal"]
|
|
||||||
elif op == ">=":
|
|
||||||
if isset(left) and isset(right):
|
|
||||||
explanation = _compare_gte_set(left, right, highlighter, verbose)
|
|
||||||
elif op == "<=":
|
|
||||||
if isset(left) and isset(right):
|
|
||||||
explanation = _compare_lte_set(left, right, highlighter, verbose)
|
|
||||||
elif op == ">":
|
|
||||||
if isset(left) and isset(right):
|
|
||||||
explanation = _compare_gt_set(left, right, highlighter, verbose)
|
|
||||||
elif op == "<":
|
|
||||||
if isset(left) and isset(right):
|
|
||||||
explanation = _compare_lt_set(left, right, highlighter, verbose)
|
|
||||||
|
|
||||||
except outcomes.Exit:
|
|
||||||
raise
|
|
||||||
except Exception:
|
|
||||||
repr_crash = _pytest._code.ExceptionInfo.from_current()._getreprcrash()
|
|
||||||
explanation = [
|
|
||||||
f"(pytest_assertion plugin: representation of details failed: {repr_crash}.",
|
|
||||||
" Probably an object has a faulty __repr__.)",
|
|
||||||
]
|
|
||||||
|
|
||||||
if not explanation:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if explanation[0] != "":
|
|
||||||
explanation = ["", *explanation]
|
|
||||||
return [summary, *explanation]
|
|
||||||
|
|
||||||
|
|
||||||
def _compare_eq_any(
|
|
||||||
left: Any, right: Any, highlighter: _HighlightFunc, verbose: int = 0
|
|
||||||
) -> list[str]:
|
|
||||||
explanation = []
|
|
||||||
if istext(left) and istext(right):
|
|
||||||
explanation = _diff_text(left, right, highlighter, verbose)
|
|
||||||
else:
|
|
||||||
from _pytest.python_api import ApproxBase
|
|
||||||
|
|
||||||
if isinstance(left, ApproxBase) or isinstance(right, ApproxBase):
|
|
||||||
# Although the common order should be obtained == expected, this ensures both ways
|
|
||||||
approx_side = left if isinstance(left, ApproxBase) else right
|
|
||||||
other_side = right if isinstance(left, ApproxBase) else left
|
|
||||||
|
|
||||||
explanation = approx_side._repr_compare(other_side)
|
|
||||||
elif type(left) is type(right) and (
|
|
||||||
isdatacls(left) or isattrs(left) or isnamedtuple(left)
|
|
||||||
):
|
|
||||||
# Note: unlike dataclasses/attrs, namedtuples compare only the
|
|
||||||
# field values, not the type or field names. But this branch
|
|
||||||
# intentionally only handles the same-type case, which was often
|
|
||||||
# used in older code bases before dataclasses/attrs were available.
|
|
||||||
explanation = _compare_eq_cls(left, right, highlighter, verbose)
|
|
||||||
elif issequence(left) and issequence(right):
|
|
||||||
explanation = _compare_eq_sequence(left, right, highlighter, verbose)
|
|
||||||
elif isset(left) and isset(right):
|
|
||||||
explanation = _compare_eq_set(left, right, highlighter, verbose)
|
|
||||||
elif isdict(left) and isdict(right):
|
|
||||||
explanation = _compare_eq_dict(left, right, highlighter, verbose)
|
|
||||||
|
|
||||||
if isiterable(left) and isiterable(right):
|
|
||||||
expl = _compare_eq_iterable(left, right, highlighter, verbose)
|
|
||||||
explanation.extend(expl)
|
|
||||||
|
|
||||||
return explanation
|
|
||||||
|
|
||||||
|
|
||||||
def _diff_text(
|
|
||||||
left: str, right: str, highlighter: _HighlightFunc, verbose: int = 0
|
|
||||||
) -> list[str]:
|
|
||||||
"""Return the explanation for the diff between text.
|
|
||||||
|
|
||||||
Unless --verbose is used this will skip leading and trailing
|
|
||||||
characters which are identical to keep the diff minimal.
|
|
||||||
"""
|
|
||||||
from difflib import ndiff
|
|
||||||
|
|
||||||
explanation: list[str] = []
|
|
||||||
|
|
||||||
if verbose < 1:
|
|
||||||
i = 0 # just in case left or right has zero length
|
|
||||||
for i in range(min(len(left), len(right))):
|
|
||||||
if left[i] != right[i]:
|
|
||||||
break
|
|
||||||
if i > 42:
|
|
||||||
i -= 10 # Provide some context
|
|
||||||
explanation = [
|
|
||||||
f"Skipping {i} identical leading characters in diff, use -v to show"
|
|
||||||
]
|
|
||||||
left = left[i:]
|
|
||||||
right = right[i:]
|
|
||||||
if len(left) == len(right):
|
|
||||||
for i in range(len(left)):
|
|
||||||
if left[-i] != right[-i]:
|
|
||||||
break
|
|
||||||
if i > 42:
|
|
||||||
i -= 10 # Provide some context
|
|
||||||
explanation += [
|
|
||||||
f"Skipping {i} identical trailing "
|
|
||||||
"characters in diff, use -v to show"
|
|
||||||
]
|
|
||||||
left = left[:-i]
|
|
||||||
right = right[:-i]
|
|
||||||
keepends = True
|
|
||||||
if left.isspace() or right.isspace():
|
|
||||||
left = repr(str(left))
|
|
||||||
right = repr(str(right))
|
|
||||||
explanation += ["Strings contain only whitespace, escaping them using repr()"]
|
|
||||||
# "right" is the expected base against which we compare "left",
|
|
||||||
# see https://github.com/pytest-dev/pytest/issues/3333
|
|
||||||
explanation.extend(
|
|
||||||
highlighter(
|
|
||||||
"\n".join(
|
|
||||||
line.strip("\n")
|
|
||||||
for line in ndiff(right.splitlines(keepends), left.splitlines(keepends))
|
|
||||||
),
|
|
||||||
lexer="diff",
|
|
||||||
).splitlines()
|
|
||||||
)
|
|
||||||
return explanation
|
|
||||||
|
|
||||||
|
|
||||||
def _compare_eq_iterable(
|
|
||||||
left: Iterable[Any],
|
|
||||||
right: Iterable[Any],
|
|
||||||
highlighter: _HighlightFunc,
|
|
||||||
verbose: int = 0,
|
|
||||||
) -> list[str]:
|
|
||||||
if verbose <= 0 and not running_on_ci():
|
|
||||||
return ["Use -v to get more diff"]
|
|
||||||
# dynamic import to speedup pytest
|
|
||||||
import difflib
|
|
||||||
|
|
||||||
left_formatting = PrettyPrinter().pformat(left).splitlines()
|
|
||||||
right_formatting = PrettyPrinter().pformat(right).splitlines()
|
|
||||||
|
|
||||||
explanation = ["", "Full diff:"]
|
|
||||||
# "right" is the expected base against which we compare "left",
|
|
||||||
# see https://github.com/pytest-dev/pytest/issues/3333
|
|
||||||
explanation.extend(
|
|
||||||
highlighter(
|
|
||||||
"\n".join(
|
|
||||||
line.rstrip()
|
|
||||||
for line in difflib.ndiff(right_formatting, left_formatting)
|
|
||||||
),
|
|
||||||
lexer="diff",
|
|
||||||
).splitlines()
|
|
||||||
)
|
|
||||||
return explanation
|
|
||||||
|
|
||||||
|
|
||||||
def _compare_eq_sequence(
|
|
||||||
left: Sequence[Any],
|
|
||||||
right: Sequence[Any],
|
|
||||||
highlighter: _HighlightFunc,
|
|
||||||
verbose: int = 0,
|
|
||||||
) -> list[str]:
|
|
||||||
comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes)
|
|
||||||
explanation: list[str] = []
|
|
||||||
len_left = len(left)
|
|
||||||
len_right = len(right)
|
|
||||||
for i in range(min(len_left, len_right)):
|
|
||||||
if left[i] != right[i]:
|
|
||||||
if comparing_bytes:
|
|
||||||
# when comparing bytes, we want to see their ascii representation
|
|
||||||
# instead of their numeric values (#5260)
|
|
||||||
# using a slice gives us the ascii representation:
|
|
||||||
# >>> s = b'foo'
|
|
||||||
# >>> s[0]
|
|
||||||
# 102
|
|
||||||
# >>> s[0:1]
|
|
||||||
# b'f'
|
|
||||||
left_value = left[i : i + 1]
|
|
||||||
right_value = right[i : i + 1]
|
|
||||||
else:
|
|
||||||
left_value = left[i]
|
|
||||||
right_value = right[i]
|
|
||||||
|
|
||||||
explanation.append(
|
|
||||||
f"At index {i} diff:"
|
|
||||||
f" {highlighter(repr(left_value))} != {highlighter(repr(right_value))}"
|
|
||||||
)
|
|
||||||
break
|
|
||||||
|
|
||||||
if comparing_bytes:
|
|
||||||
# when comparing bytes, it doesn't help to show the "sides contain one or more
|
|
||||||
# items" longer explanation, so skip it
|
|
||||||
|
|
||||||
return explanation
|
|
||||||
|
|
||||||
len_diff = len_left - len_right
|
|
||||||
if len_diff:
|
|
||||||
if len_diff > 0:
|
|
||||||
dir_with_more = "Left"
|
|
||||||
extra = saferepr(left[len_right])
|
|
||||||
else:
|
|
||||||
len_diff = 0 - len_diff
|
|
||||||
dir_with_more = "Right"
|
|
||||||
extra = saferepr(right[len_left])
|
|
||||||
|
|
||||||
if len_diff == 1:
|
|
||||||
explanation += [
|
|
||||||
f"{dir_with_more} contains one more item: {highlighter(extra)}"
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
explanation += [
|
|
||||||
f"{dir_with_more} contains {len_diff} more items, first extra item: {highlighter(extra)}"
|
|
||||||
]
|
|
||||||
return explanation
|
|
||||||
|
|
||||||
|
|
||||||
def _compare_eq_set(
|
|
||||||
left: AbstractSet[Any],
|
|
||||||
right: AbstractSet[Any],
|
|
||||||
highlighter: _HighlightFunc,
|
|
||||||
verbose: int = 0,
|
|
||||||
) -> list[str]:
|
|
||||||
explanation = []
|
|
||||||
explanation.extend(_set_one_sided_diff("left", left, right, highlighter))
|
|
||||||
explanation.extend(_set_one_sided_diff("right", right, left, highlighter))
|
|
||||||
return explanation
|
|
||||||
|
|
||||||
|
|
||||||
def _compare_gt_set(
|
|
||||||
left: AbstractSet[Any],
|
|
||||||
right: AbstractSet[Any],
|
|
||||||
highlighter: _HighlightFunc,
|
|
||||||
verbose: int = 0,
|
|
||||||
) -> list[str]:
|
|
||||||
explanation = _compare_gte_set(left, right, highlighter)
|
|
||||||
if not explanation:
|
|
||||||
return ["Both sets are equal"]
|
|
||||||
return explanation
|
|
||||||
|
|
||||||
|
|
||||||
def _compare_lt_set(
|
|
||||||
left: AbstractSet[Any],
|
|
||||||
right: AbstractSet[Any],
|
|
||||||
highlighter: _HighlightFunc,
|
|
||||||
verbose: int = 0,
|
|
||||||
) -> list[str]:
|
|
||||||
explanation = _compare_lte_set(left, right, highlighter)
|
|
||||||
if not explanation:
|
|
||||||
return ["Both sets are equal"]
|
|
||||||
return explanation
|
|
||||||
|
|
||||||
|
|
||||||
def _compare_gte_set(
|
|
||||||
left: AbstractSet[Any],
|
|
||||||
right: AbstractSet[Any],
|
|
||||||
highlighter: _HighlightFunc,
|
|
||||||
verbose: int = 0,
|
|
||||||
) -> list[str]:
|
|
||||||
return _set_one_sided_diff("right", right, left, highlighter)
|
|
||||||
|
|
||||||
|
|
||||||
def _compare_lte_set(
|
|
||||||
left: AbstractSet[Any],
|
|
||||||
right: AbstractSet[Any],
|
|
||||||
highlighter: _HighlightFunc,
|
|
||||||
verbose: int = 0,
|
|
||||||
) -> list[str]:
|
|
||||||
return _set_one_sided_diff("left", left, right, highlighter)
|
|
||||||
|
|
||||||
|
|
||||||
def _set_one_sided_diff(
|
|
||||||
posn: str,
|
|
||||||
set1: AbstractSet[Any],
|
|
||||||
set2: AbstractSet[Any],
|
|
||||||
highlighter: _HighlightFunc,
|
|
||||||
) -> list[str]:
|
|
||||||
explanation = []
|
|
||||||
diff = set1 - set2
|
|
||||||
if diff:
|
|
||||||
explanation.append(f"Extra items in the {posn} set:")
|
|
||||||
for item in diff:
|
|
||||||
explanation.append(highlighter(saferepr(item)))
|
|
||||||
return explanation
|
|
||||||
|
|
||||||
|
|
||||||
def _compare_eq_dict(
|
|
||||||
left: Mapping[Any, Any],
|
|
||||||
right: Mapping[Any, Any],
|
|
||||||
highlighter: _HighlightFunc,
|
|
||||||
verbose: int = 0,
|
|
||||||
) -> list[str]:
|
|
||||||
explanation: list[str] = []
|
|
||||||
set_left = set(left)
|
|
||||||
set_right = set(right)
|
|
||||||
common = set_left.intersection(set_right)
|
|
||||||
same = {k: left[k] for k in common if left[k] == right[k]}
|
|
||||||
if same and verbose < 2:
|
|
||||||
explanation += [f"Omitting {len(same)} identical items, use -vv to show"]
|
|
||||||
elif same:
|
|
||||||
explanation += ["Common items:"]
|
|
||||||
explanation += highlighter(pprint.pformat(same)).splitlines()
|
|
||||||
diff = {k for k in common if left[k] != right[k]}
|
|
||||||
if diff:
|
|
||||||
explanation += ["Differing items:"]
|
|
||||||
for k in diff:
|
|
||||||
explanation += [
|
|
||||||
highlighter(saferepr({k: left[k]}))
|
|
||||||
+ " != "
|
|
||||||
+ highlighter(saferepr({k: right[k]}))
|
|
||||||
]
|
|
||||||
extra_left = set_left - set_right
|
|
||||||
len_extra_left = len(extra_left)
|
|
||||||
if len_extra_left:
|
|
||||||
explanation.append(
|
|
||||||
f"Left contains {len_extra_left} more item{'' if len_extra_left == 1 else 's'}:"
|
|
||||||
)
|
|
||||||
explanation.extend(
|
|
||||||
highlighter(pprint.pformat({k: left[k] for k in extra_left})).splitlines()
|
|
||||||
)
|
|
||||||
extra_right = set_right - set_left
|
|
||||||
len_extra_right = len(extra_right)
|
|
||||||
if len_extra_right:
|
|
||||||
explanation.append(
|
|
||||||
f"Right contains {len_extra_right} more item{'' if len_extra_right == 1 else 's'}:"
|
|
||||||
)
|
|
||||||
explanation.extend(
|
|
||||||
highlighter(pprint.pformat({k: right[k] for k in extra_right})).splitlines()
|
|
||||||
)
|
|
||||||
return explanation
|
|
||||||
|
|
||||||
|
|
||||||
def _compare_eq_cls(
|
|
||||||
left: Any, right: Any, highlighter: _HighlightFunc, verbose: int
|
|
||||||
) -> list[str]:
|
|
||||||
if not has_default_eq(left):
|
|
||||||
return []
|
|
||||||
if isdatacls(left):
|
|
||||||
import dataclasses
|
|
||||||
|
|
||||||
all_fields = dataclasses.fields(left)
|
|
||||||
fields_to_check = [info.name for info in all_fields if info.compare]
|
|
||||||
elif isattrs(left):
|
|
||||||
all_fields = left.__attrs_attrs__
|
|
||||||
fields_to_check = [field.name for field in all_fields if getattr(field, "eq")]
|
|
||||||
elif isnamedtuple(left):
|
|
||||||
fields_to_check = left._fields
|
|
||||||
else:
|
|
||||||
assert False
|
|
||||||
|
|
||||||
indent = " "
|
|
||||||
same = []
|
|
||||||
diff = []
|
|
||||||
for field in fields_to_check:
|
|
||||||
if getattr(left, field) == getattr(right, field):
|
|
||||||
same.append(field)
|
|
||||||
else:
|
|
||||||
diff.append(field)
|
|
||||||
|
|
||||||
explanation = []
|
|
||||||
if same or diff:
|
|
||||||
explanation += [""]
|
|
||||||
if same and verbose < 2:
|
|
||||||
explanation.append(f"Omitting {len(same)} identical items, use -vv to show")
|
|
||||||
elif same:
|
|
||||||
explanation += ["Matching attributes:"]
|
|
||||||
explanation += highlighter(pprint.pformat(same)).splitlines()
|
|
||||||
if diff:
|
|
||||||
explanation += ["Differing attributes:"]
|
|
||||||
explanation += highlighter(pprint.pformat(diff)).splitlines()
|
|
||||||
for field in diff:
|
|
||||||
field_left = getattr(left, field)
|
|
||||||
field_right = getattr(right, field)
|
|
||||||
explanation += [
|
|
||||||
"",
|
|
||||||
f"Drill down into differing attribute {field}:",
|
|
||||||
f"{indent}{field}: {highlighter(repr(field_left))} != {highlighter(repr(field_right))}",
|
|
||||||
]
|
|
||||||
explanation += [
|
|
||||||
indent + line
|
|
||||||
for line in _compare_eq_any(
|
|
||||||
field_left, field_right, highlighter, verbose
|
|
||||||
)
|
|
||||||
]
|
|
||||||
return explanation
|
|
||||||
|
|
||||||
|
|
||||||
def _notin_text(term: str, text: str, verbose: int = 0) -> list[str]:
|
|
||||||
index = text.find(term)
|
|
||||||
head = text[:index]
|
|
||||||
tail = text[index + len(term) :]
|
|
||||||
correct_text = head + tail
|
|
||||||
diff = _diff_text(text, correct_text, dummy_highlighter, verbose)
|
|
||||||
newdiff = [f"{saferepr(term, maxsize=42)} is contained here:"]
|
|
||||||
for line in diff:
|
|
||||||
if line.startswith("Skipping"):
|
|
||||||
continue
|
|
||||||
if line.startswith("- "):
|
|
||||||
continue
|
|
||||||
if line.startswith("+ "):
|
|
||||||
newdiff.append(" " + line[2:])
|
|
||||||
else:
|
|
||||||
newdiff.append(line)
|
|
||||||
return newdiff
|
|
||||||
|
|
||||||
|
|
||||||
def running_on_ci() -> bool:
|
|
||||||
"""Check if we're currently running on a CI system."""
|
|
||||||
env_vars = ["CI", "BUILD_NUMBER"]
|
|
||||||
return any(var in os.environ for var in env_vars)
|
|
||||||
@@ -1,625 +0,0 @@
|
|||||||
# mypy: allow-untyped-defs
|
|
||||||
"""Implementation of the cache provider."""
|
|
||||||
|
|
||||||
# This plugin was not named "cache" to avoid conflicts with the external
|
|
||||||
# pytest-cache version.
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections.abc import Generator
|
|
||||||
from collections.abc import Iterable
|
|
||||||
import dataclasses
|
|
||||||
import errno
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
import tempfile
|
|
||||||
from typing import final
|
|
||||||
|
|
||||||
from .pathlib import resolve_from_str
|
|
||||||
from .pathlib import rm_rf
|
|
||||||
from .reports import CollectReport
|
|
||||||
from _pytest import nodes
|
|
||||||
from _pytest._io import TerminalWriter
|
|
||||||
from _pytest.config import Config
|
|
||||||
from _pytest.config import ExitCode
|
|
||||||
from _pytest.config import hookimpl
|
|
||||||
from _pytest.config.argparsing import Parser
|
|
||||||
from _pytest.deprecated import check_ispytest
|
|
||||||
from _pytest.fixtures import fixture
|
|
||||||
from _pytest.fixtures import FixtureRequest
|
|
||||||
from _pytest.main import Session
|
|
||||||
from _pytest.nodes import Directory
|
|
||||||
from _pytest.nodes import File
|
|
||||||
from _pytest.reports import TestReport
|
|
||||||
|
|
||||||
|
|
||||||
README_CONTENT = """\
|
|
||||||
# pytest cache directory #
|
|
||||||
|
|
||||||
This directory contains data from the pytest's cache plugin,
|
|
||||||
which provides the `--lf` and `--ff` options, as well as the `cache` fixture.
|
|
||||||
|
|
||||||
**Do not** commit this to version control.
|
|
||||||
|
|
||||||
See [the docs](https://docs.pytest.org/en/stable/how-to/cache.html) for more information.
|
|
||||||
"""
|
|
||||||
|
|
||||||
CACHEDIR_TAG_CONTENT = b"""\
|
|
||||||
Signature: 8a477f597d28d172789f06886806bc55
|
|
||||||
# This file is a cache directory tag created by pytest.
|
|
||||||
# For information about cache directory tags, see:
|
|
||||||
# https://bford.info/cachedir/spec.html
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@final
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class Cache:
|
|
||||||
"""Instance of the `cache` fixture."""
|
|
||||||
|
|
||||||
_cachedir: Path = dataclasses.field(repr=False)
|
|
||||||
_config: Config = dataclasses.field(repr=False)
|
|
||||||
|
|
||||||
# Sub-directory under cache-dir for directories created by `mkdir()`.
|
|
||||||
_CACHE_PREFIX_DIRS = "d"
|
|
||||||
|
|
||||||
# Sub-directory under cache-dir for values created by `set()`.
|
|
||||||
_CACHE_PREFIX_VALUES = "v"
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, cachedir: Path, config: Config, *, _ispytest: bool = False
|
|
||||||
) -> None:
|
|
||||||
check_ispytest(_ispytest)
|
|
||||||
self._cachedir = cachedir
|
|
||||||
self._config = config
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def for_config(cls, config: Config, *, _ispytest: bool = False) -> Cache:
|
|
||||||
"""Create the Cache instance for a Config.
|
|
||||||
|
|
||||||
:meta private:
|
|
||||||
"""
|
|
||||||
check_ispytest(_ispytest)
|
|
||||||
cachedir = cls.cache_dir_from_config(config, _ispytest=True)
|
|
||||||
if config.getoption("cacheclear") and cachedir.is_dir():
|
|
||||||
cls.clear_cache(cachedir, _ispytest=True)
|
|
||||||
return cls(cachedir, config, _ispytest=True)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def clear_cache(cls, cachedir: Path, _ispytest: bool = False) -> None:
|
|
||||||
"""Clear the sub-directories used to hold cached directories and values.
|
|
||||||
|
|
||||||
:meta private:
|
|
||||||
"""
|
|
||||||
check_ispytest(_ispytest)
|
|
||||||
for prefix in (cls._CACHE_PREFIX_DIRS, cls._CACHE_PREFIX_VALUES):
|
|
||||||
d = cachedir / prefix
|
|
||||||
if d.is_dir():
|
|
||||||
rm_rf(d)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def cache_dir_from_config(config: Config, *, _ispytest: bool = False) -> Path:
|
|
||||||
"""Get the path to the cache directory for a Config.
|
|
||||||
|
|
||||||
:meta private:
|
|
||||||
"""
|
|
||||||
check_ispytest(_ispytest)
|
|
||||||
return resolve_from_str(config.getini("cache_dir"), config.rootpath)
|
|
||||||
|
|
||||||
def warn(self, fmt: str, *, _ispytest: bool = False, **args: object) -> None:
|
|
||||||
"""Issue a cache warning.
|
|
||||||
|
|
||||||
:meta private:
|
|
||||||
"""
|
|
||||||
check_ispytest(_ispytest)
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
from _pytest.warning_types import PytestCacheWarning
|
|
||||||
|
|
||||||
warnings.warn(
|
|
||||||
PytestCacheWarning(fmt.format(**args) if args else fmt),
|
|
||||||
self._config.hook,
|
|
||||||
stacklevel=3,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _mkdir(self, path: Path) -> None:
|
|
||||||
self._ensure_cache_dir_and_supporting_files()
|
|
||||||
path.mkdir(exist_ok=True, parents=True)
|
|
||||||
|
|
||||||
def mkdir(self, name: str) -> Path:
|
|
||||||
"""Return a directory path object with the given name.
|
|
||||||
|
|
||||||
If the directory does not yet exist, it will be created. You can use
|
|
||||||
it to manage files to e.g. store/retrieve database dumps across test
|
|
||||||
sessions.
|
|
||||||
|
|
||||||
.. versionadded:: 7.0
|
|
||||||
|
|
||||||
:param name:
|
|
||||||
Must be a string not containing a ``/`` separator.
|
|
||||||
Make sure the name contains your plugin or application
|
|
||||||
identifiers to prevent clashes with other cache users.
|
|
||||||
"""
|
|
||||||
path = Path(name)
|
|
||||||
if len(path.parts) > 1:
|
|
||||||
raise ValueError("name is not allowed to contain path separators")
|
|
||||||
res = self._cachedir.joinpath(self._CACHE_PREFIX_DIRS, path)
|
|
||||||
self._mkdir(res)
|
|
||||||
return res
|
|
||||||
|
|
||||||
def _getvaluepath(self, key: str) -> Path:
|
|
||||||
return self._cachedir.joinpath(self._CACHE_PREFIX_VALUES, Path(key))
|
|
||||||
|
|
||||||
def get(self, key: str, default):
|
|
||||||
"""Return the cached value for the given key.
|
|
||||||
|
|
||||||
If no value was yet cached or the value cannot be read, the specified
|
|
||||||
default is returned.
|
|
||||||
|
|
||||||
:param key:
|
|
||||||
Must be a ``/`` separated value. Usually the first
|
|
||||||
name is the name of your plugin or your application.
|
|
||||||
:param default:
|
|
||||||
The value to return in case of a cache-miss or invalid cache value.
|
|
||||||
"""
|
|
||||||
path = self._getvaluepath(key)
|
|
||||||
try:
|
|
||||||
with path.open("r", encoding="UTF-8") as f:
|
|
||||||
return json.load(f)
|
|
||||||
except (ValueError, OSError):
|
|
||||||
return default
|
|
||||||
|
|
||||||
def set(self, key: str, value: object) -> None:
|
|
||||||
"""Save value for the given key.
|
|
||||||
|
|
||||||
:param key:
|
|
||||||
Must be a ``/`` separated value. Usually the first
|
|
||||||
name is the name of your plugin or your application.
|
|
||||||
:param value:
|
|
||||||
Must be of any combination of basic python types,
|
|
||||||
including nested types like lists of dictionaries.
|
|
||||||
"""
|
|
||||||
path = self._getvaluepath(key)
|
|
||||||
try:
|
|
||||||
self._mkdir(path.parent)
|
|
||||||
except OSError as exc:
|
|
||||||
self.warn(
|
|
||||||
f"could not create cache path {path}: {exc}",
|
|
||||||
_ispytest=True,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
data = json.dumps(value, ensure_ascii=False, indent=2)
|
|
||||||
try:
|
|
||||||
f = path.open("w", encoding="UTF-8")
|
|
||||||
except OSError as exc:
|
|
||||||
self.warn(
|
|
||||||
f"cache could not write path {path}: {exc}",
|
|
||||||
_ispytest=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
with f:
|
|
||||||
f.write(data)
|
|
||||||
|
|
||||||
def _ensure_cache_dir_and_supporting_files(self) -> None:
|
|
||||||
"""Create the cache dir and its supporting files."""
|
|
||||||
if self._cachedir.is_dir():
|
|
||||||
return
|
|
||||||
|
|
||||||
self._cachedir.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
with tempfile.TemporaryDirectory(
|
|
||||||
prefix="pytest-cache-files-",
|
|
||||||
dir=self._cachedir.parent,
|
|
||||||
) as newpath:
|
|
||||||
path = Path(newpath)
|
|
||||||
|
|
||||||
# Reset permissions to the default, see #12308.
|
|
||||||
# Note: there's no way to get the current umask atomically, eek.
|
|
||||||
umask = os.umask(0o022)
|
|
||||||
os.umask(umask)
|
|
||||||
path.chmod(0o777 - umask)
|
|
||||||
|
|
||||||
with open(path.joinpath("README.md"), "x", encoding="UTF-8") as f:
|
|
||||||
f.write(README_CONTENT)
|
|
||||||
with open(path.joinpath(".gitignore"), "x", encoding="UTF-8") as f:
|
|
||||||
f.write("# Created by pytest automatically.\n*\n")
|
|
||||||
with open(path.joinpath("CACHEDIR.TAG"), "xb") as f:
|
|
||||||
f.write(CACHEDIR_TAG_CONTENT)
|
|
||||||
|
|
||||||
try:
|
|
||||||
path.rename(self._cachedir)
|
|
||||||
except OSError as e:
|
|
||||||
# If 2 concurrent pytests both race to the rename, the loser
|
|
||||||
# gets "Directory not empty" from the rename. In this case,
|
|
||||||
# everything is handled so just continue (while letting the
|
|
||||||
# temporary directory be cleaned up).
|
|
||||||
# On Windows, the error is a FileExistsError which translates to EEXIST.
|
|
||||||
if e.errno not in (errno.ENOTEMPTY, errno.EEXIST):
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
# Create a directory in place of the one we just moved so that
|
|
||||||
# `TemporaryDirectory`'s cleanup doesn't complain.
|
|
||||||
#
|
|
||||||
# TODO: pass ignore_cleanup_errors=True when we no longer support python < 3.10.
|
|
||||||
# See https://github.com/python/cpython/issues/74168. Note that passing
|
|
||||||
# delete=False would do the wrong thing in case of errors and isn't supported
|
|
||||||
# until python 3.12.
|
|
||||||
path.mkdir()
|
|
||||||
|
|
||||||
|
|
||||||
class LFPluginCollWrapper:
|
|
||||||
def __init__(self, lfplugin: LFPlugin) -> None:
|
|
||||||
self.lfplugin = lfplugin
|
|
||||||
self._collected_at_least_one_failure = False
|
|
||||||
|
|
||||||
@hookimpl(wrapper=True)
|
|
||||||
def pytest_make_collect_report(
|
|
||||||
self, collector: nodes.Collector
|
|
||||||
) -> Generator[None, CollectReport, CollectReport]:
|
|
||||||
res = yield
|
|
||||||
if isinstance(collector, (Session, Directory)):
|
|
||||||
# Sort any lf-paths to the beginning.
|
|
||||||
lf_paths = self.lfplugin._last_failed_paths
|
|
||||||
|
|
||||||
# Use stable sort to prioritize last failed.
|
|
||||||
def sort_key(node: nodes.Item | nodes.Collector) -> bool:
|
|
||||||
return node.path in lf_paths
|
|
||||||
|
|
||||||
res.result = sorted(
|
|
||||||
res.result,
|
|
||||||
key=sort_key,
|
|
||||||
reverse=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif isinstance(collector, File):
|
|
||||||
if collector.path in self.lfplugin._last_failed_paths:
|
|
||||||
result = res.result
|
|
||||||
lastfailed = self.lfplugin.lastfailed
|
|
||||||
|
|
||||||
# Only filter with known failures.
|
|
||||||
if not self._collected_at_least_one_failure:
|
|
||||||
if not any(x.nodeid in lastfailed for x in result):
|
|
||||||
return res
|
|
||||||
self.lfplugin.config.pluginmanager.register(
|
|
||||||
LFPluginCollSkipfiles(self.lfplugin), "lfplugin-collskip"
|
|
||||||
)
|
|
||||||
self._collected_at_least_one_failure = True
|
|
||||||
|
|
||||||
session = collector.session
|
|
||||||
result[:] = [
|
|
||||||
x
|
|
||||||
for x in result
|
|
||||||
if x.nodeid in lastfailed
|
|
||||||
# Include any passed arguments (not trivial to filter).
|
|
||||||
or session.isinitpath(x.path)
|
|
||||||
# Keep all sub-collectors.
|
|
||||||
or isinstance(x, nodes.Collector)
|
|
||||||
]
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
class LFPluginCollSkipfiles:
|
|
||||||
def __init__(self, lfplugin: LFPlugin) -> None:
|
|
||||||
self.lfplugin = lfplugin
|
|
||||||
|
|
||||||
@hookimpl
|
|
||||||
def pytest_make_collect_report(
|
|
||||||
self, collector: nodes.Collector
|
|
||||||
) -> CollectReport | None:
|
|
||||||
if isinstance(collector, File):
|
|
||||||
if collector.path not in self.lfplugin._last_failed_paths:
|
|
||||||
self.lfplugin._skipped_files += 1
|
|
||||||
|
|
||||||
return CollectReport(
|
|
||||||
collector.nodeid, "passed", longrepr=None, result=[]
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class LFPlugin:
|
|
||||||
"""Plugin which implements the --lf (run last-failing) option."""
|
|
||||||
|
|
||||||
def __init__(self, config: Config) -> None:
|
|
||||||
self.config = config
|
|
||||||
active_keys = "lf", "failedfirst"
|
|
||||||
self.active = any(config.getoption(key) for key in active_keys)
|
|
||||||
assert config.cache
|
|
||||||
self.lastfailed: dict[str, bool] = config.cache.get("cache/lastfailed", {})
|
|
||||||
self._previously_failed_count: int | None = None
|
|
||||||
self._report_status: str | None = None
|
|
||||||
self._skipped_files = 0 # count skipped files during collection due to --lf
|
|
||||||
|
|
||||||
if config.getoption("lf"):
|
|
||||||
self._last_failed_paths = self.get_last_failed_paths()
|
|
||||||
config.pluginmanager.register(
|
|
||||||
LFPluginCollWrapper(self), "lfplugin-collwrapper"
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_last_failed_paths(self) -> set[Path]:
|
|
||||||
"""Return a set with all Paths of the previously failed nodeids and
|
|
||||||
their parents."""
|
|
||||||
rootpath = self.config.rootpath
|
|
||||||
result = set()
|
|
||||||
for nodeid in self.lastfailed:
|
|
||||||
path = rootpath / nodeid.split("::")[0]
|
|
||||||
result.add(path)
|
|
||||||
result.update(path.parents)
|
|
||||||
return {x for x in result if x.exists()}
|
|
||||||
|
|
||||||
def pytest_report_collectionfinish(self) -> str | None:
|
|
||||||
if self.active and self.config.get_verbosity() >= 0:
|
|
||||||
return f"run-last-failure: {self._report_status}"
|
|
||||||
return None
|
|
||||||
|
|
||||||
def pytest_runtest_logreport(self, report: TestReport) -> None:
|
|
||||||
if (report.when == "call" and report.passed) or report.skipped:
|
|
||||||
self.lastfailed.pop(report.nodeid, None)
|
|
||||||
elif report.failed:
|
|
||||||
self.lastfailed[report.nodeid] = True
|
|
||||||
|
|
||||||
def pytest_collectreport(self, report: CollectReport) -> None:
|
|
||||||
passed = report.outcome in ("passed", "skipped")
|
|
||||||
if passed:
|
|
||||||
if report.nodeid in self.lastfailed:
|
|
||||||
self.lastfailed.pop(report.nodeid)
|
|
||||||
self.lastfailed.update((item.nodeid, True) for item in report.result)
|
|
||||||
else:
|
|
||||||
self.lastfailed[report.nodeid] = True
|
|
||||||
|
|
||||||
@hookimpl(wrapper=True, tryfirst=True)
|
|
||||||
def pytest_collection_modifyitems(
|
|
||||||
self, config: Config, items: list[nodes.Item]
|
|
||||||
) -> Generator[None]:
|
|
||||||
res = yield
|
|
||||||
|
|
||||||
if not self.active:
|
|
||||||
return res
|
|
||||||
|
|
||||||
if self.lastfailed:
|
|
||||||
previously_failed = []
|
|
||||||
previously_passed = []
|
|
||||||
for item in items:
|
|
||||||
if item.nodeid in self.lastfailed:
|
|
||||||
previously_failed.append(item)
|
|
||||||
else:
|
|
||||||
previously_passed.append(item)
|
|
||||||
self._previously_failed_count = len(previously_failed)
|
|
||||||
|
|
||||||
if not previously_failed:
|
|
||||||
# Running a subset of all tests with recorded failures
|
|
||||||
# only outside of it.
|
|
||||||
self._report_status = (
|
|
||||||
f"{len(self.lastfailed)} known failures not in selected tests"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if self.config.getoption("lf"):
|
|
||||||
items[:] = previously_failed
|
|
||||||
config.hook.pytest_deselected(items=previously_passed)
|
|
||||||
else: # --failedfirst
|
|
||||||
items[:] = previously_failed + previously_passed
|
|
||||||
|
|
||||||
noun = "failure" if self._previously_failed_count == 1 else "failures"
|
|
||||||
suffix = " first" if self.config.getoption("failedfirst") else ""
|
|
||||||
self._report_status = (
|
|
||||||
f"rerun previous {self._previously_failed_count} {noun}{suffix}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._skipped_files > 0:
|
|
||||||
files_noun = "file" if self._skipped_files == 1 else "files"
|
|
||||||
self._report_status += f" (skipped {self._skipped_files} {files_noun})"
|
|
||||||
else:
|
|
||||||
self._report_status = "no previously failed tests, "
|
|
||||||
if self.config.getoption("last_failed_no_failures") == "none":
|
|
||||||
self._report_status += "deselecting all items."
|
|
||||||
config.hook.pytest_deselected(items=items[:])
|
|
||||||
items[:] = []
|
|
||||||
else:
|
|
||||||
self._report_status += "not deselecting items."
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
def pytest_sessionfinish(self, session: Session) -> None:
|
|
||||||
config = self.config
|
|
||||||
if config.getoption("cacheshow") or hasattr(config, "workerinput"):
|
|
||||||
return
|
|
||||||
|
|
||||||
assert config.cache is not None
|
|
||||||
saved_lastfailed = config.cache.get("cache/lastfailed", {})
|
|
||||||
if saved_lastfailed != self.lastfailed:
|
|
||||||
config.cache.set("cache/lastfailed", self.lastfailed)
|
|
||||||
|
|
||||||
|
|
||||||
class NFPlugin:
|
|
||||||
"""Plugin which implements the --nf (run new-first) option."""
|
|
||||||
|
|
||||||
def __init__(self, config: Config) -> None:
|
|
||||||
self.config = config
|
|
||||||
self.active = config.option.newfirst
|
|
||||||
assert config.cache is not None
|
|
||||||
self.cached_nodeids = set(config.cache.get("cache/nodeids", []))
|
|
||||||
|
|
||||||
@hookimpl(wrapper=True, tryfirst=True)
|
|
||||||
def pytest_collection_modifyitems(self, items: list[nodes.Item]) -> Generator[None]:
|
|
||||||
res = yield
|
|
||||||
|
|
||||||
if self.active:
|
|
||||||
new_items: dict[str, nodes.Item] = {}
|
|
||||||
other_items: dict[str, nodes.Item] = {}
|
|
||||||
for item in items:
|
|
||||||
if item.nodeid not in self.cached_nodeids:
|
|
||||||
new_items[item.nodeid] = item
|
|
||||||
else:
|
|
||||||
other_items[item.nodeid] = item
|
|
||||||
|
|
||||||
items[:] = self._get_increasing_order(
|
|
||||||
new_items.values()
|
|
||||||
) + self._get_increasing_order(other_items.values())
|
|
||||||
self.cached_nodeids.update(new_items)
|
|
||||||
else:
|
|
||||||
self.cached_nodeids.update(item.nodeid for item in items)
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
def _get_increasing_order(self, items: Iterable[nodes.Item]) -> list[nodes.Item]:
|
|
||||||
return sorted(items, key=lambda item: item.path.stat().st_mtime, reverse=True)
|
|
||||||
|
|
||||||
def pytest_sessionfinish(self) -> None:
|
|
||||||
config = self.config
|
|
||||||
if config.getoption("cacheshow") or hasattr(config, "workerinput"):
|
|
||||||
return
|
|
||||||
|
|
||||||
if config.getoption("collectonly"):
|
|
||||||
return
|
|
||||||
|
|
||||||
assert config.cache is not None
|
|
||||||
config.cache.set("cache/nodeids", sorted(self.cached_nodeids))
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_addoption(parser: Parser) -> None:
|
|
||||||
group = parser.getgroup("general")
|
|
||||||
group.addoption(
|
|
||||||
"--lf",
|
|
||||||
"--last-failed",
|
|
||||||
action="store_true",
|
|
||||||
dest="lf",
|
|
||||||
help="Rerun only the tests that failed at the last run (or all if none failed)",
|
|
||||||
)
|
|
||||||
group.addoption(
|
|
||||||
"--ff",
|
|
||||||
"--failed-first",
|
|
||||||
action="store_true",
|
|
||||||
dest="failedfirst",
|
|
||||||
help="Run all tests, but run the last failures first. "
|
|
||||||
"This may re-order tests and thus lead to "
|
|
||||||
"repeated fixture setup/teardown.",
|
|
||||||
)
|
|
||||||
group.addoption(
|
|
||||||
"--nf",
|
|
||||||
"--new-first",
|
|
||||||
action="store_true",
|
|
||||||
dest="newfirst",
|
|
||||||
help="Run tests from new files first, then the rest of the tests "
|
|
||||||
"sorted by file mtime",
|
|
||||||
)
|
|
||||||
group.addoption(
|
|
||||||
"--cache-show",
|
|
||||||
action="append",
|
|
||||||
nargs="?",
|
|
||||||
dest="cacheshow",
|
|
||||||
help=(
|
|
||||||
"Show cache contents, don't perform collection or tests. "
|
|
||||||
"Optional argument: glob (default: '*')."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
group.addoption(
|
|
||||||
"--cache-clear",
|
|
||||||
action="store_true",
|
|
||||||
dest="cacheclear",
|
|
||||||
help="Remove all cache contents at start of test run",
|
|
||||||
)
|
|
||||||
cache_dir_default = ".pytest_cache"
|
|
||||||
if "TOX_ENV_DIR" in os.environ:
|
|
||||||
cache_dir_default = os.path.join(os.environ["TOX_ENV_DIR"], cache_dir_default)
|
|
||||||
parser.addini("cache_dir", default=cache_dir_default, help="Cache directory path")
|
|
||||||
group.addoption(
|
|
||||||
"--lfnf",
|
|
||||||
"--last-failed-no-failures",
|
|
||||||
action="store",
|
|
||||||
dest="last_failed_no_failures",
|
|
||||||
choices=("all", "none"),
|
|
||||||
default="all",
|
|
||||||
help="With ``--lf``, determines whether to execute tests when there "
|
|
||||||
"are no previously (known) failures or when no "
|
|
||||||
"cached ``lastfailed`` data was found. "
|
|
||||||
"``all`` (the default) runs the full test suite again. "
|
|
||||||
"``none`` just emits a message about no known failures and exits successfully.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_cmdline_main(config: Config) -> int | ExitCode | None:
|
|
||||||
if config.option.cacheshow and not config.option.help:
|
|
||||||
from _pytest.main import wrap_session
|
|
||||||
|
|
||||||
return wrap_session(config, cacheshow)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@hookimpl(tryfirst=True)
|
|
||||||
def pytest_configure(config: Config) -> None:
|
|
||||||
config.cache = Cache.for_config(config, _ispytest=True)
|
|
||||||
config.pluginmanager.register(LFPlugin(config), "lfplugin")
|
|
||||||
config.pluginmanager.register(NFPlugin(config), "nfplugin")
|
|
||||||
|
|
||||||
|
|
||||||
@fixture
|
|
||||||
def cache(request: FixtureRequest) -> Cache:
|
|
||||||
"""Return a cache object that can persist state between testing sessions.
|
|
||||||
|
|
||||||
cache.get(key, default)
|
|
||||||
cache.set(key, value)
|
|
||||||
|
|
||||||
Keys must be ``/`` separated strings, where the first part is usually the
|
|
||||||
name of your plugin or application to avoid clashes with other cache users.
|
|
||||||
|
|
||||||
Values can be any object handled by the json stdlib module.
|
|
||||||
"""
|
|
||||||
assert request.config.cache is not None
|
|
||||||
return request.config.cache
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_report_header(config: Config) -> str | None:
|
|
||||||
"""Display cachedir with --cache-show and if non-default."""
|
|
||||||
if config.option.verbose > 0 or config.getini("cache_dir") != ".pytest_cache":
|
|
||||||
assert config.cache is not None
|
|
||||||
cachedir = config.cache._cachedir
|
|
||||||
# TODO: evaluate generating upward relative paths
|
|
||||||
# starting with .., ../.. if sensible
|
|
||||||
|
|
||||||
try:
|
|
||||||
displaypath = cachedir.relative_to(config.rootpath)
|
|
||||||
except ValueError:
|
|
||||||
displaypath = cachedir
|
|
||||||
return f"cachedir: {displaypath}"
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def cacheshow(config: Config, session: Session) -> int:
|
|
||||||
from pprint import pformat
|
|
||||||
|
|
||||||
assert config.cache is not None
|
|
||||||
|
|
||||||
tw = TerminalWriter()
|
|
||||||
tw.line("cachedir: " + str(config.cache._cachedir))
|
|
||||||
if not config.cache._cachedir.is_dir():
|
|
||||||
tw.line("cache is empty")
|
|
||||||
return 0
|
|
||||||
|
|
||||||
glob = config.option.cacheshow[0]
|
|
||||||
if glob is None:
|
|
||||||
glob = "*"
|
|
||||||
|
|
||||||
dummy = object()
|
|
||||||
basedir = config.cache._cachedir
|
|
||||||
vdir = basedir / Cache._CACHE_PREFIX_VALUES
|
|
||||||
tw.sep("-", f"cache values for {glob!r}")
|
|
||||||
for valpath in sorted(x for x in vdir.rglob(glob) if x.is_file()):
|
|
||||||
key = str(valpath.relative_to(vdir))
|
|
||||||
val = config.cache.get(key, dummy)
|
|
||||||
if val is dummy:
|
|
||||||
tw.line(f"{key} contains unreadable content, will be ignored")
|
|
||||||
else:
|
|
||||||
tw.line(f"{key} contains:")
|
|
||||||
for line in pformat(val).splitlines():
|
|
||||||
tw.line(" " + line)
|
|
||||||
|
|
||||||
ddir = basedir / Cache._CACHE_PREFIX_DIRS
|
|
||||||
if ddir.is_dir():
|
|
||||||
contents = sorted(ddir.rglob(glob))
|
|
||||||
tw.sep("-", f"cache directories for {glob!r}")
|
|
||||||
for p in contents:
|
|
||||||
# if p.is_dir():
|
|
||||||
# print("%s/" % p.relative_to(basedir))
|
|
||||||
if p.is_file():
|
|
||||||
key = str(p.relative_to(basedir))
|
|
||||||
tw.line(f"{key} is a file of length {p.stat().st_size}")
|
|
||||||
return 0
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,322 +0,0 @@
|
|||||||
# mypy: allow-untyped-defs
|
|
||||||
"""Python version compatibility code."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections.abc import Callable
|
|
||||||
import enum
|
|
||||||
import functools
|
|
||||||
import inspect
|
|
||||||
from inspect import Parameter
|
|
||||||
from inspect import signature
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
import sys
|
|
||||||
from typing import Any
|
|
||||||
from typing import Final
|
|
||||||
from typing import NoReturn
|
|
||||||
|
|
||||||
import py
|
|
||||||
|
|
||||||
|
|
||||||
#: constant to prepare valuing pylib path replacements/lazy proxies later on
|
|
||||||
# intended for removal in pytest 8.0 or 9.0
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
# intentional space to create a fake difference for the verification
|
|
||||||
LEGACY_PATH = py.path. local
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
def legacy_path(path: str | os.PathLike[str]) -> LEGACY_PATH:
|
|
||||||
"""Internal wrapper to prepare lazy proxies for legacy_path instances"""
|
|
||||||
return LEGACY_PATH(path)
|
|
||||||
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
# Singleton type for NOTSET, as described in:
|
|
||||||
# https://www.python.org/dev/peps/pep-0484/#support-for-singleton-types-in-unions
|
|
||||||
class NotSetType(enum.Enum):
|
|
||||||
token = 0
|
|
||||||
NOTSET: Final = NotSetType.token
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
def iscoroutinefunction(func: object) -> bool:
|
|
||||||
"""Return True if func is a coroutine function (a function defined with async
|
|
||||||
def syntax, and doesn't contain yield), or a function decorated with
|
|
||||||
@asyncio.coroutine.
|
|
||||||
|
|
||||||
Note: copied and modified from Python 3.5's builtin coroutines.py to avoid
|
|
||||||
importing asyncio directly, which in turns also initializes the "logging"
|
|
||||||
module as a side-effect (see issue #8).
|
|
||||||
"""
|
|
||||||
return inspect.iscoroutinefunction(func) or getattr(func, "_is_coroutine", False)
|
|
||||||
|
|
||||||
|
|
||||||
def is_async_function(func: object) -> bool:
|
|
||||||
"""Return True if the given function seems to be an async function or
|
|
||||||
an async generator."""
|
|
||||||
return iscoroutinefunction(func) or inspect.isasyncgenfunction(func)
|
|
||||||
|
|
||||||
|
|
||||||
def getlocation(function, curdir: str | os.PathLike[str] | None = None) -> str:
|
|
||||||
function = get_real_func(function)
|
|
||||||
fn = Path(inspect.getfile(function))
|
|
||||||
lineno = function.__code__.co_firstlineno
|
|
||||||
if curdir is not None:
|
|
||||||
try:
|
|
||||||
relfn = fn.relative_to(curdir)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
return f"{relfn}:{lineno + 1}"
|
|
||||||
return f"{fn}:{lineno + 1}"
|
|
||||||
|
|
||||||
|
|
||||||
def num_mock_patch_args(function) -> int:
|
|
||||||
"""Return number of arguments used up by mock arguments (if any)."""
|
|
||||||
patchings = getattr(function, "patchings", None)
|
|
||||||
if not patchings:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
mock_sentinel = getattr(sys.modules.get("mock"), "DEFAULT", object())
|
|
||||||
ut_mock_sentinel = getattr(sys.modules.get("unittest.mock"), "DEFAULT", object())
|
|
||||||
|
|
||||||
return len(
|
|
||||||
[
|
|
||||||
p
|
|
||||||
for p in patchings
|
|
||||||
if not p.attribute_name
|
|
||||||
and (p.new is mock_sentinel or p.new is ut_mock_sentinel)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def getfuncargnames(
|
|
||||||
function: Callable[..., object],
|
|
||||||
*,
|
|
||||||
name: str = "",
|
|
||||||
cls: type | None = None,
|
|
||||||
) -> tuple[str, ...]:
|
|
||||||
"""Return the names of a function's mandatory arguments.
|
|
||||||
|
|
||||||
Should return the names of all function arguments that:
|
|
||||||
* Aren't bound to an instance or type as in instance or class methods.
|
|
||||||
* Don't have default values.
|
|
||||||
* Aren't bound with functools.partial.
|
|
||||||
* Aren't replaced with mocks.
|
|
||||||
|
|
||||||
The cls arguments indicate that the function should be treated as a bound
|
|
||||||
method even though it's not unless the function is a static method.
|
|
||||||
|
|
||||||
The name parameter should be the original name in which the function was collected.
|
|
||||||
"""
|
|
||||||
# TODO(RonnyPfannschmidt): This function should be refactored when we
|
|
||||||
# revisit fixtures. The fixture mechanism should ask the node for
|
|
||||||
# the fixture names, and not try to obtain directly from the
|
|
||||||
# function object well after collection has occurred.
|
|
||||||
|
|
||||||
# The parameters attribute of a Signature object contains an
|
|
||||||
# ordered mapping of parameter names to Parameter instances. This
|
|
||||||
# creates a tuple of the names of the parameters that don't have
|
|
||||||
# defaults.
|
|
||||||
try:
|
|
||||||
parameters = signature(function).parameters.values()
|
|
||||||
except (ValueError, TypeError) as e:
|
|
||||||
from _pytest.outcomes import fail
|
|
||||||
|
|
||||||
fail(
|
|
||||||
f"Could not determine arguments of {function!r}: {e}",
|
|
||||||
pytrace=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
arg_names = tuple(
|
|
||||||
p.name
|
|
||||||
for p in parameters
|
|
||||||
if (
|
|
||||||
p.kind is Parameter.POSITIONAL_OR_KEYWORD
|
|
||||||
or p.kind is Parameter.KEYWORD_ONLY
|
|
||||||
)
|
|
||||||
and p.default is Parameter.empty
|
|
||||||
)
|
|
||||||
if not name:
|
|
||||||
name = function.__name__
|
|
||||||
|
|
||||||
# If this function should be treated as a bound method even though
|
|
||||||
# it's passed as an unbound method or function, and its first parameter
|
|
||||||
# wasn't defined as positional only, remove the first parameter name.
|
|
||||||
if not any(p.kind is Parameter.POSITIONAL_ONLY for p in parameters) and (
|
|
||||||
# Not using `getattr` because we don't want to resolve the staticmethod.
|
|
||||||
# Not using `cls.__dict__` because we want to check the entire MRO.
|
|
||||||
cls
|
|
||||||
and not isinstance(
|
|
||||||
inspect.getattr_static(cls, name, default=None), staticmethod
|
|
||||||
)
|
|
||||||
):
|
|
||||||
arg_names = arg_names[1:]
|
|
||||||
# Remove any names that will be replaced with mocks.
|
|
||||||
if hasattr(function, "__wrapped__"):
|
|
||||||
arg_names = arg_names[num_mock_patch_args(function) :]
|
|
||||||
return arg_names
|
|
||||||
|
|
||||||
|
|
||||||
def get_default_arg_names(function: Callable[..., Any]) -> tuple[str, ...]:
|
|
||||||
# Note: this code intentionally mirrors the code at the beginning of
|
|
||||||
# getfuncargnames, to get the arguments which were excluded from its result
|
|
||||||
# because they had default values.
|
|
||||||
return tuple(
|
|
||||||
p.name
|
|
||||||
for p in signature(function).parameters.values()
|
|
||||||
if p.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
|
|
||||||
and p.default is not Parameter.empty
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
_non_printable_ascii_translate_table = {
|
|
||||||
i: f"\\x{i:02x}" for i in range(128) if i not in range(32, 127)
|
|
||||||
}
|
|
||||||
_non_printable_ascii_translate_table.update(
|
|
||||||
{ord("\t"): "\\t", ord("\r"): "\\r", ord("\n"): "\\n"}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def ascii_escaped(val: bytes | str) -> str:
|
|
||||||
r"""If val is pure ASCII, return it as an str, otherwise, escape
|
|
||||||
bytes objects into a sequence of escaped bytes:
|
|
||||||
|
|
||||||
b'\xc3\xb4\xc5\xd6' -> r'\xc3\xb4\xc5\xd6'
|
|
||||||
|
|
||||||
and escapes strings into a sequence of escaped unicode ids, e.g.:
|
|
||||||
|
|
||||||
r'4\nV\U00043efa\x0eMXWB\x1e\u3028\u15fd\xcd\U0007d944'
|
|
||||||
|
|
||||||
Note:
|
|
||||||
The obvious "v.decode('unicode-escape')" will return
|
|
||||||
valid UTF-8 unicode if it finds them in bytes, but we
|
|
||||||
want to return escaped bytes for any byte, even if they match
|
|
||||||
a UTF-8 string.
|
|
||||||
"""
|
|
||||||
if isinstance(val, bytes):
|
|
||||||
ret = val.decode("ascii", "backslashreplace")
|
|
||||||
else:
|
|
||||||
ret = val.encode("unicode_escape").decode("ascii")
|
|
||||||
return ret.translate(_non_printable_ascii_translate_table)
|
|
||||||
|
|
||||||
|
|
||||||
def get_real_func(obj):
|
|
||||||
"""Get the real function object of the (possibly) wrapped object by
|
|
||||||
:func:`functools.wraps`, or :func:`functools.partial`."""
|
|
||||||
obj = inspect.unwrap(obj)
|
|
||||||
|
|
||||||
if isinstance(obj, functools.partial):
|
|
||||||
obj = obj.func
|
|
||||||
return obj
|
|
||||||
|
|
||||||
|
|
||||||
def getimfunc(func):
|
|
||||||
try:
|
|
||||||
return func.__func__
|
|
||||||
except AttributeError:
|
|
||||||
return func
|
|
||||||
|
|
||||||
|
|
||||||
def safe_getattr(object: Any, name: str, default: Any) -> Any:
|
|
||||||
"""Like getattr but return default upon any Exception or any OutcomeException.
|
|
||||||
|
|
||||||
Attribute access can potentially fail for 'evil' Python objects.
|
|
||||||
See issue #214.
|
|
||||||
It catches OutcomeException because of #2490 (issue #580), new outcomes
|
|
||||||
are derived from BaseException instead of Exception (for more details
|
|
||||||
check #2707).
|
|
||||||
"""
|
|
||||||
from _pytest.outcomes import TEST_OUTCOME
|
|
||||||
|
|
||||||
try:
|
|
||||||
return getattr(object, name, default)
|
|
||||||
except TEST_OUTCOME:
|
|
||||||
return default
|
|
||||||
|
|
||||||
|
|
||||||
def safe_isclass(obj: object) -> bool:
|
|
||||||
"""Ignore any exception via isinstance on Python 3."""
|
|
||||||
try:
|
|
||||||
return inspect.isclass(obj)
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def get_user_id() -> int | None:
|
|
||||||
"""Return the current process's real user id or None if it could not be
|
|
||||||
determined.
|
|
||||||
|
|
||||||
:return: The user id or None if it could not be determined.
|
|
||||||
"""
|
|
||||||
# mypy follows the version and platform checking expectation of PEP 484:
|
|
||||||
# https://mypy.readthedocs.io/en/stable/common_issues.html?highlight=platform#python-version-and-system-platform-checks
|
|
||||||
# Containment checks are too complex for mypy v1.5.0 and cause failure.
|
|
||||||
if sys.platform == "win32" or sys.platform == "emscripten":
|
|
||||||
# win32 does not have a getuid() function.
|
|
||||||
# Emscripten has a return 0 stub.
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
# On other platforms, a return value of -1 is assumed to indicate that
|
|
||||||
# the current process's real user id could not be determined.
|
|
||||||
ERROR = -1
|
|
||||||
uid = os.getuid()
|
|
||||||
return uid if uid != ERROR else None
|
|
||||||
|
|
||||||
|
|
||||||
# Perform exhaustiveness checking.
|
|
||||||
#
|
|
||||||
# Consider this example:
|
|
||||||
#
|
|
||||||
# MyUnion = Union[int, str]
|
|
||||||
#
|
|
||||||
# def handle(x: MyUnion) -> int {
|
|
||||||
# if isinstance(x, int):
|
|
||||||
# return 1
|
|
||||||
# elif isinstance(x, str):
|
|
||||||
# return 2
|
|
||||||
# else:
|
|
||||||
# raise Exception('unreachable')
|
|
||||||
#
|
|
||||||
# Now suppose we add a new variant:
|
|
||||||
#
|
|
||||||
# MyUnion = Union[int, str, bytes]
|
|
||||||
#
|
|
||||||
# After doing this, we must remember ourselves to go and update the handle
|
|
||||||
# function to handle the new variant.
|
|
||||||
#
|
|
||||||
# With `assert_never` we can do better:
|
|
||||||
#
|
|
||||||
# // raise Exception('unreachable')
|
|
||||||
# return assert_never(x)
|
|
||||||
#
|
|
||||||
# Now, if we forget to handle the new variant, the type-checker will emit a
|
|
||||||
# compile-time error, instead of the runtime error we would have gotten
|
|
||||||
# previously.
|
|
||||||
#
|
|
||||||
# This also work for Enums (if you use `is` to compare) and Literals.
|
|
||||||
def assert_never(value: NoReturn) -> NoReturn:
|
|
||||||
assert False, f"Unhandled value: {value} ({type(value).__name__})"
|
|
||||||
|
|
||||||
|
|
||||||
class CallableBool:
|
|
||||||
"""
|
|
||||||
A bool-like object that can also be called, returning its true/false value.
|
|
||||||
|
|
||||||
Used for backwards compatibility in cases where something was supposed to be a method
|
|
||||||
but was implemented as a simple attribute by mistake (see `TerminalReporter.isatty`).
|
|
||||||
|
|
||||||
Do not use in new code.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, value: bool) -> None:
|
|
||||||
self._value = value
|
|
||||||
|
|
||||||
def __bool__(self) -> bool:
|
|
||||||
return self._value
|
|
||||||
|
|
||||||
def __call__(self) -> bool:
|
|
||||||
return self._value
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,533 +0,0 @@
|
|||||||
# mypy: allow-untyped-defs
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
from collections.abc import Callable
|
|
||||||
from collections.abc import Mapping
|
|
||||||
from collections.abc import Sequence
|
|
||||||
import os
|
|
||||||
from typing import Any
|
|
||||||
from typing import cast
|
|
||||||
from typing import final
|
|
||||||
from typing import Literal
|
|
||||||
from typing import NoReturn
|
|
||||||
|
|
||||||
import _pytest._io
|
|
||||||
from _pytest.config.exceptions import UsageError
|
|
||||||
from _pytest.deprecated import check_ispytest
|
|
||||||
|
|
||||||
|
|
||||||
FILE_OR_DIR = "file_or_dir"
|
|
||||||
|
|
||||||
|
|
||||||
class NotSet:
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return "<notset>"
|
|
||||||
|
|
||||||
|
|
||||||
NOT_SET = NotSet()
|
|
||||||
|
|
||||||
|
|
||||||
@final
|
|
||||||
class Parser:
|
|
||||||
"""Parser for command line arguments and ini-file values.
|
|
||||||
|
|
||||||
:ivar extra_info: Dict of generic param -> value to display in case
|
|
||||||
there's an error processing the command line arguments.
|
|
||||||
"""
|
|
||||||
|
|
||||||
prog: str | None = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
usage: str | None = None,
|
|
||||||
processopt: Callable[[Argument], None] | None = None,
|
|
||||||
*,
|
|
||||||
_ispytest: bool = False,
|
|
||||||
) -> None:
|
|
||||||
check_ispytest(_ispytest)
|
|
||||||
self._anonymous = OptionGroup("Custom options", parser=self, _ispytest=True)
|
|
||||||
self._groups: list[OptionGroup] = []
|
|
||||||
self._processopt = processopt
|
|
||||||
self._usage = usage
|
|
||||||
self._inidict: dict[str, tuple[str, str | None, Any]] = {}
|
|
||||||
self._ininames: list[str] = []
|
|
||||||
self.extra_info: dict[str, Any] = {}
|
|
||||||
|
|
||||||
def processoption(self, option: Argument) -> None:
|
|
||||||
if self._processopt:
|
|
||||||
if option.dest:
|
|
||||||
self._processopt(option)
|
|
||||||
|
|
||||||
def getgroup(
|
|
||||||
self, name: str, description: str = "", after: str | None = None
|
|
||||||
) -> OptionGroup:
|
|
||||||
"""Get (or create) a named option Group.
|
|
||||||
|
|
||||||
:param name: Name of the option group.
|
|
||||||
:param description: Long description for --help output.
|
|
||||||
:param after: Name of another group, used for ordering --help output.
|
|
||||||
:returns: The option group.
|
|
||||||
|
|
||||||
The returned group object has an ``addoption`` method with the same
|
|
||||||
signature as :func:`parser.addoption <pytest.Parser.addoption>` but
|
|
||||||
will be shown in the respective group in the output of
|
|
||||||
``pytest --help``.
|
|
||||||
"""
|
|
||||||
for group in self._groups:
|
|
||||||
if group.name == name:
|
|
||||||
return group
|
|
||||||
group = OptionGroup(name, description, parser=self, _ispytest=True)
|
|
||||||
i = 0
|
|
||||||
for i, grp in enumerate(self._groups):
|
|
||||||
if grp.name == after:
|
|
||||||
break
|
|
||||||
self._groups.insert(i + 1, group)
|
|
||||||
return group
|
|
||||||
|
|
||||||
def addoption(self, *opts: str, **attrs: Any) -> None:
|
|
||||||
"""Register a command line option.
|
|
||||||
|
|
||||||
:param opts:
|
|
||||||
Option names, can be short or long options.
|
|
||||||
:param attrs:
|
|
||||||
Same attributes as the argparse library's :meth:`add_argument()
|
|
||||||
<argparse.ArgumentParser.add_argument>` function accepts.
|
|
||||||
|
|
||||||
After command line parsing, options are available on the pytest config
|
|
||||||
object via ``config.option.NAME`` where ``NAME`` is usually set
|
|
||||||
by passing a ``dest`` attribute, for example
|
|
||||||
``addoption("--long", dest="NAME", ...)``.
|
|
||||||
"""
|
|
||||||
self._anonymous.addoption(*opts, **attrs)
|
|
||||||
|
|
||||||
def parse(
|
|
||||||
self,
|
|
||||||
args: Sequence[str | os.PathLike[str]],
|
|
||||||
namespace: argparse.Namespace | None = None,
|
|
||||||
) -> argparse.Namespace:
|
|
||||||
from _pytest._argcomplete import try_argcomplete
|
|
||||||
|
|
||||||
self.optparser = self._getparser()
|
|
||||||
try_argcomplete(self.optparser)
|
|
||||||
strargs = [os.fspath(x) for x in args]
|
|
||||||
return self.optparser.parse_args(strargs, namespace=namespace)
|
|
||||||
|
|
||||||
def _getparser(self) -> MyOptionParser:
|
|
||||||
from _pytest._argcomplete import filescompleter
|
|
||||||
|
|
||||||
optparser = MyOptionParser(self, self.extra_info, prog=self.prog)
|
|
||||||
groups = [*self._groups, self._anonymous]
|
|
||||||
for group in groups:
|
|
||||||
if group.options:
|
|
||||||
desc = group.description or group.name
|
|
||||||
arggroup = optparser.add_argument_group(desc)
|
|
||||||
for option in group.options:
|
|
||||||
n = option.names()
|
|
||||||
a = option.attrs()
|
|
||||||
arggroup.add_argument(*n, **a)
|
|
||||||
file_or_dir_arg = optparser.add_argument(FILE_OR_DIR, nargs="*")
|
|
||||||
# bash like autocompletion for dirs (appending '/')
|
|
||||||
# Type ignored because typeshed doesn't know about argcomplete.
|
|
||||||
file_or_dir_arg.completer = filescompleter # type: ignore
|
|
||||||
return optparser
|
|
||||||
|
|
||||||
def parse_setoption(
|
|
||||||
self,
|
|
||||||
args: Sequence[str | os.PathLike[str]],
|
|
||||||
option: argparse.Namespace,
|
|
||||||
namespace: argparse.Namespace | None = None,
|
|
||||||
) -> list[str]:
|
|
||||||
parsedoption = self.parse(args, namespace=namespace)
|
|
||||||
for name, value in parsedoption.__dict__.items():
|
|
||||||
setattr(option, name, value)
|
|
||||||
return cast(list[str], getattr(parsedoption, FILE_OR_DIR))
|
|
||||||
|
|
||||||
def parse_known_args(
|
|
||||||
self,
|
|
||||||
args: Sequence[str | os.PathLike[str]],
|
|
||||||
namespace: argparse.Namespace | None = None,
|
|
||||||
) -> argparse.Namespace:
|
|
||||||
"""Parse the known arguments at this point.
|
|
||||||
|
|
||||||
:returns: An argparse namespace object.
|
|
||||||
"""
|
|
||||||
return self.parse_known_and_unknown_args(args, namespace=namespace)[0]
|
|
||||||
|
|
||||||
def parse_known_and_unknown_args(
|
|
||||||
self,
|
|
||||||
args: Sequence[str | os.PathLike[str]],
|
|
||||||
namespace: argparse.Namespace | None = None,
|
|
||||||
) -> tuple[argparse.Namespace, list[str]]:
|
|
||||||
"""Parse the known arguments at this point, and also return the
|
|
||||||
remaining unknown arguments.
|
|
||||||
|
|
||||||
:returns:
|
|
||||||
A tuple containing an argparse namespace object for the known
|
|
||||||
arguments, and a list of the unknown arguments.
|
|
||||||
"""
|
|
||||||
optparser = self._getparser()
|
|
||||||
strargs = [os.fspath(x) for x in args]
|
|
||||||
return optparser.parse_known_args(strargs, namespace=namespace)
|
|
||||||
|
|
||||||
def addini(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
help: str,
|
|
||||||
type: Literal["string", "paths", "pathlist", "args", "linelist", "bool"]
|
|
||||||
| None = None,
|
|
||||||
default: Any = NOT_SET,
|
|
||||||
) -> None:
|
|
||||||
"""Register an ini-file option.
|
|
||||||
|
|
||||||
:param name:
|
|
||||||
Name of the ini-variable.
|
|
||||||
:param type:
|
|
||||||
Type of the variable. Can be:
|
|
||||||
|
|
||||||
* ``string``: a string
|
|
||||||
* ``bool``: a boolean
|
|
||||||
* ``args``: a list of strings, separated as in a shell
|
|
||||||
* ``linelist``: a list of strings, separated by line breaks
|
|
||||||
* ``paths``: a list of :class:`pathlib.Path`, separated as in a shell
|
|
||||||
* ``pathlist``: a list of ``py.path``, separated as in a shell
|
|
||||||
* ``int``: an integer
|
|
||||||
* ``float``: a floating-point number
|
|
||||||
|
|
||||||
.. versionadded:: 8.4
|
|
||||||
|
|
||||||
The ``float`` and ``int`` types.
|
|
||||||
|
|
||||||
For ``paths`` and ``pathlist`` types, they are considered relative to the ini-file.
|
|
||||||
In case the execution is happening without an ini-file defined,
|
|
||||||
they will be considered relative to the current working directory (for example with ``--override-ini``).
|
|
||||||
|
|
||||||
.. versionadded:: 7.0
|
|
||||||
The ``paths`` variable type.
|
|
||||||
|
|
||||||
.. versionadded:: 8.1
|
|
||||||
Use the current working directory to resolve ``paths`` and ``pathlist`` in the absence of an ini-file.
|
|
||||||
|
|
||||||
Defaults to ``string`` if ``None`` or not passed.
|
|
||||||
:param default:
|
|
||||||
Default value if no ini-file option exists but is queried.
|
|
||||||
|
|
||||||
The value of ini-variables can be retrieved via a call to
|
|
||||||
:py:func:`config.getini(name) <pytest.Config.getini>`.
|
|
||||||
"""
|
|
||||||
assert type in (
|
|
||||||
None,
|
|
||||||
"string",
|
|
||||||
"paths",
|
|
||||||
"pathlist",
|
|
||||||
"args",
|
|
||||||
"linelist",
|
|
||||||
"bool",
|
|
||||||
"int",
|
|
||||||
"float",
|
|
||||||
)
|
|
||||||
if default is NOT_SET:
|
|
||||||
default = get_ini_default_for_type(type)
|
|
||||||
|
|
||||||
self._inidict[name] = (help, type, default)
|
|
||||||
self._ininames.append(name)
|
|
||||||
|
|
||||||
|
|
||||||
def get_ini_default_for_type(
|
|
||||||
type: Literal[
|
|
||||||
"string", "paths", "pathlist", "args", "linelist", "bool", "int", "float"
|
|
||||||
]
|
|
||||||
| None,
|
|
||||||
) -> Any:
|
|
||||||
"""
|
|
||||||
Used by addini to get the default value for a given ini-option type, when
|
|
||||||
default is not supplied.
|
|
||||||
"""
|
|
||||||
if type is None:
|
|
||||||
return ""
|
|
||||||
elif type in ("paths", "pathlist", "args", "linelist"):
|
|
||||||
return []
|
|
||||||
elif type == "bool":
|
|
||||||
return False
|
|
||||||
elif type == "int":
|
|
||||||
return 0
|
|
||||||
elif type == "float":
|
|
||||||
return 0.0
|
|
||||||
else:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
class ArgumentError(Exception):
|
|
||||||
"""Raised if an Argument instance is created with invalid or
|
|
||||||
inconsistent arguments."""
|
|
||||||
|
|
||||||
def __init__(self, msg: str, option: Argument | str) -> None:
|
|
||||||
self.msg = msg
|
|
||||||
self.option_id = str(option)
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
if self.option_id:
|
|
||||||
return f"option {self.option_id}: {self.msg}"
|
|
||||||
else:
|
|
||||||
return self.msg
|
|
||||||
|
|
||||||
|
|
||||||
class Argument:
|
|
||||||
"""Class that mimics the necessary behaviour of optparse.Option.
|
|
||||||
|
|
||||||
It's currently a least effort implementation and ignoring choices
|
|
||||||
and integer prefixes.
|
|
||||||
|
|
||||||
https://docs.python.org/3/library/optparse.html#optparse-standard-option-types
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, *names: str, **attrs: Any) -> None:
|
|
||||||
"""Store params in private vars for use in add_argument."""
|
|
||||||
self._attrs = attrs
|
|
||||||
self._short_opts: list[str] = []
|
|
||||||
self._long_opts: list[str] = []
|
|
||||||
try:
|
|
||||||
self.type = attrs["type"]
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
# Attribute existence is tested in Config._processopt.
|
|
||||||
self.default = attrs["default"]
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
self._set_opt_strings(names)
|
|
||||||
dest: str | None = attrs.get("dest")
|
|
||||||
if dest:
|
|
||||||
self.dest = dest
|
|
||||||
elif self._long_opts:
|
|
||||||
self.dest = self._long_opts[0][2:].replace("-", "_")
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
self.dest = self._short_opts[0][1:]
|
|
||||||
except IndexError as e:
|
|
||||||
self.dest = "???" # Needed for the error repr.
|
|
||||||
raise ArgumentError("need a long or short option", self) from e
|
|
||||||
|
|
||||||
def names(self) -> list[str]:
|
|
||||||
return self._short_opts + self._long_opts
|
|
||||||
|
|
||||||
def attrs(self) -> Mapping[str, Any]:
|
|
||||||
# Update any attributes set by processopt.
|
|
||||||
attrs = "default dest help".split()
|
|
||||||
attrs.append(self.dest)
|
|
||||||
for attr in attrs:
|
|
||||||
try:
|
|
||||||
self._attrs[attr] = getattr(self, attr)
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
return self._attrs
|
|
||||||
|
|
||||||
def _set_opt_strings(self, opts: Sequence[str]) -> None:
|
|
||||||
"""Directly from optparse.
|
|
||||||
|
|
||||||
Might not be necessary as this is passed to argparse later on.
|
|
||||||
"""
|
|
||||||
for opt in opts:
|
|
||||||
if len(opt) < 2:
|
|
||||||
raise ArgumentError(
|
|
||||||
f"invalid option string {opt!r}: "
|
|
||||||
"must be at least two characters long",
|
|
||||||
self,
|
|
||||||
)
|
|
||||||
elif len(opt) == 2:
|
|
||||||
if not (opt[0] == "-" and opt[1] != "-"):
|
|
||||||
raise ArgumentError(
|
|
||||||
f"invalid short option string {opt!r}: "
|
|
||||||
"must be of the form -x, (x any non-dash char)",
|
|
||||||
self,
|
|
||||||
)
|
|
||||||
self._short_opts.append(opt)
|
|
||||||
else:
|
|
||||||
if not (opt[0:2] == "--" and opt[2] != "-"):
|
|
||||||
raise ArgumentError(
|
|
||||||
f"invalid long option string {opt!r}: "
|
|
||||||
"must start with --, followed by non-dash",
|
|
||||||
self,
|
|
||||||
)
|
|
||||||
self._long_opts.append(opt)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
args: list[str] = []
|
|
||||||
if self._short_opts:
|
|
||||||
args += ["_short_opts: " + repr(self._short_opts)]
|
|
||||||
if self._long_opts:
|
|
||||||
args += ["_long_opts: " + repr(self._long_opts)]
|
|
||||||
args += ["dest: " + repr(self.dest)]
|
|
||||||
if hasattr(self, "type"):
|
|
||||||
args += ["type: " + repr(self.type)]
|
|
||||||
if hasattr(self, "default"):
|
|
||||||
args += ["default: " + repr(self.default)]
|
|
||||||
return "Argument({})".format(", ".join(args))
|
|
||||||
|
|
||||||
|
|
||||||
class OptionGroup:
|
|
||||||
"""A group of options shown in its own section."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
description: str = "",
|
|
||||||
parser: Parser | None = None,
|
|
||||||
*,
|
|
||||||
_ispytest: bool = False,
|
|
||||||
) -> None:
|
|
||||||
check_ispytest(_ispytest)
|
|
||||||
self.name = name
|
|
||||||
self.description = description
|
|
||||||
self.options: list[Argument] = []
|
|
||||||
self.parser = parser
|
|
||||||
|
|
||||||
def addoption(self, *opts: str, **attrs: Any) -> None:
|
|
||||||
"""Add an option to this group.
|
|
||||||
|
|
||||||
If a shortened version of a long option is specified, it will
|
|
||||||
be suppressed in the help. ``addoption('--twowords', '--two-words')``
|
|
||||||
results in help showing ``--two-words`` only, but ``--twowords`` gets
|
|
||||||
accepted **and** the automatic destination is in ``args.twowords``.
|
|
||||||
|
|
||||||
:param opts:
|
|
||||||
Option names, can be short or long options.
|
|
||||||
:param attrs:
|
|
||||||
Same attributes as the argparse library's :meth:`add_argument()
|
|
||||||
<argparse.ArgumentParser.add_argument>` function accepts.
|
|
||||||
"""
|
|
||||||
conflict = set(opts).intersection(
|
|
||||||
name for opt in self.options for name in opt.names()
|
|
||||||
)
|
|
||||||
if conflict:
|
|
||||||
raise ValueError(f"option names {conflict} already added")
|
|
||||||
option = Argument(*opts, **attrs)
|
|
||||||
self._addoption_instance(option, shortupper=False)
|
|
||||||
|
|
||||||
def _addoption(self, *opts: str, **attrs: Any) -> None:
|
|
||||||
option = Argument(*opts, **attrs)
|
|
||||||
self._addoption_instance(option, shortupper=True)
|
|
||||||
|
|
||||||
def _addoption_instance(self, option: Argument, shortupper: bool = False) -> None:
|
|
||||||
if not shortupper:
|
|
||||||
for opt in option._short_opts:
|
|
||||||
if opt[0] == "-" and opt[1].islower():
|
|
||||||
raise ValueError("lowercase shortoptions reserved")
|
|
||||||
if self.parser:
|
|
||||||
self.parser.processoption(option)
|
|
||||||
self.options.append(option)
|
|
||||||
|
|
||||||
|
|
||||||
class MyOptionParser(argparse.ArgumentParser):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
parser: Parser,
|
|
||||||
extra_info: dict[str, Any] | None = None,
|
|
||||||
prog: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
self._parser = parser
|
|
||||||
super().__init__(
|
|
||||||
prog=prog,
|
|
||||||
usage=parser._usage,
|
|
||||||
add_help=False,
|
|
||||||
formatter_class=DropShorterLongHelpFormatter,
|
|
||||||
allow_abbrev=False,
|
|
||||||
fromfile_prefix_chars="@",
|
|
||||||
)
|
|
||||||
# extra_info is a dict of (param -> value) to display if there's
|
|
||||||
# an usage error to provide more contextual information to the user.
|
|
||||||
self.extra_info = extra_info if extra_info else {}
|
|
||||||
|
|
||||||
def error(self, message: str) -> NoReturn:
|
|
||||||
"""Transform argparse error message into UsageError."""
|
|
||||||
msg = f"{self.prog}: error: {message}"
|
|
||||||
|
|
||||||
if hasattr(self._parser, "_config_source_hint"):
|
|
||||||
msg = f"{msg} ({self._parser._config_source_hint})"
|
|
||||||
|
|
||||||
raise UsageError(self.format_usage() + msg)
|
|
||||||
|
|
||||||
# Type ignored because typeshed has a very complex type in the superclass.
|
|
||||||
def parse_args( # type: ignore
|
|
||||||
self,
|
|
||||||
args: Sequence[str] | None = None,
|
|
||||||
namespace: argparse.Namespace | None = None,
|
|
||||||
) -> argparse.Namespace:
|
|
||||||
"""Allow splitting of positional arguments."""
|
|
||||||
parsed, unrecognized = self.parse_known_args(args, namespace)
|
|
||||||
if unrecognized:
|
|
||||||
for arg in unrecognized:
|
|
||||||
if arg and arg[0] == "-":
|
|
||||||
lines = [
|
|
||||||
"unrecognized arguments: {}".format(" ".join(unrecognized))
|
|
||||||
]
|
|
||||||
for k, v in sorted(self.extra_info.items()):
|
|
||||||
lines.append(f" {k}: {v}")
|
|
||||||
self.error("\n".join(lines))
|
|
||||||
getattr(parsed, FILE_OR_DIR).extend(unrecognized)
|
|
||||||
return parsed
|
|
||||||
|
|
||||||
|
|
||||||
class DropShorterLongHelpFormatter(argparse.HelpFormatter):
|
|
||||||
"""Shorten help for long options that differ only in extra hyphens.
|
|
||||||
|
|
||||||
- Collapse **long** options that are the same except for extra hyphens.
|
|
||||||
- Shortcut if there are only two options and one of them is a short one.
|
|
||||||
- Cache result on the action object as this is called at least 2 times.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
||||||
# Use more accurate terminal width.
|
|
||||||
if "width" not in kwargs:
|
|
||||||
kwargs["width"] = _pytest._io.get_terminal_width()
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def _format_action_invocation(self, action: argparse.Action) -> str:
|
|
||||||
orgstr = super()._format_action_invocation(action)
|
|
||||||
if orgstr and orgstr[0] != "-": # only optional arguments
|
|
||||||
return orgstr
|
|
||||||
res: str | None = getattr(action, "_formatted_action_invocation", None)
|
|
||||||
if res:
|
|
||||||
return res
|
|
||||||
options = orgstr.split(", ")
|
|
||||||
if len(options) == 2 and (len(options[0]) == 2 or len(options[1]) == 2):
|
|
||||||
# a shortcut for '-h, --help' or '--abc', '-a'
|
|
||||||
action._formatted_action_invocation = orgstr # type: ignore
|
|
||||||
return orgstr
|
|
||||||
return_list = []
|
|
||||||
short_long: dict[str, str] = {}
|
|
||||||
for option in options:
|
|
||||||
if len(option) == 2 or option[2] == " ":
|
|
||||||
continue
|
|
||||||
if not option.startswith("--"):
|
|
||||||
raise ArgumentError(
|
|
||||||
f'long optional argument without "--": [{option}]', option
|
|
||||||
)
|
|
||||||
xxoption = option[2:]
|
|
||||||
shortened = xxoption.replace("-", "")
|
|
||||||
if shortened not in short_long or len(short_long[shortened]) < len(
|
|
||||||
xxoption
|
|
||||||
):
|
|
||||||
short_long[shortened] = xxoption
|
|
||||||
# now short_long has been filled out to the longest with dashes
|
|
||||||
# **and** we keep the right option ordering from add_argument
|
|
||||||
for option in options:
|
|
||||||
if len(option) == 2 or option[2] == " ":
|
|
||||||
return_list.append(option)
|
|
||||||
if option[2:] == short_long.get(option.replace("-", "")):
|
|
||||||
return_list.append(option.replace(" ", "=", 1))
|
|
||||||
formatted_action_invocation = ", ".join(return_list)
|
|
||||||
action._formatted_action_invocation = formatted_action_invocation # type: ignore
|
|
||||||
return formatted_action_invocation
|
|
||||||
|
|
||||||
def _split_lines(self, text, width):
|
|
||||||
"""Wrap lines after splitting on original newlines.
|
|
||||||
|
|
||||||
This allows to have explicit line breaks in the help text.
|
|
||||||
"""
|
|
||||||
import textwrap
|
|
||||||
|
|
||||||
lines = []
|
|
||||||
for line in text.splitlines():
|
|
||||||
lines.extend(textwrap.wrap(line.strip(), width))
|
|
||||||
return lines
|
|
||||||
@@ -1,85 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections.abc import Mapping
|
|
||||||
import functools
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import pluggy
|
|
||||||
|
|
||||||
from ..compat import LEGACY_PATH
|
|
||||||
from ..compat import legacy_path
|
|
||||||
from ..deprecated import HOOK_LEGACY_PATH_ARG
|
|
||||||
|
|
||||||
|
|
||||||
# hookname: (Path, LEGACY_PATH)
|
|
||||||
imply_paths_hooks: Mapping[str, tuple[str, str]] = {
|
|
||||||
"pytest_ignore_collect": ("collection_path", "path"),
|
|
||||||
"pytest_collect_file": ("file_path", "path"),
|
|
||||||
"pytest_pycollect_makemodule": ("module_path", "path"),
|
|
||||||
"pytest_report_header": ("start_path", "startdir"),
|
|
||||||
"pytest_report_collectionfinish": ("start_path", "startdir"),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _check_path(path: Path, fspath: LEGACY_PATH) -> None:
|
|
||||||
if Path(fspath) != path:
|
|
||||||
raise ValueError(
|
|
||||||
f"Path({fspath!r}) != {path!r}\n"
|
|
||||||
"if both path and fspath are given they need to be equal"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PathAwareHookProxy:
|
|
||||||
"""
|
|
||||||
this helper wraps around hook callers
|
|
||||||
until pluggy supports fixingcalls, this one will do
|
|
||||||
|
|
||||||
it currently doesn't return full hook caller proxies for fixed hooks,
|
|
||||||
this may have to be changed later depending on bugs
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, hook_relay: pluggy.HookRelay) -> None:
|
|
||||||
self._hook_relay = hook_relay
|
|
||||||
|
|
||||||
def __dir__(self) -> list[str]:
|
|
||||||
return dir(self._hook_relay)
|
|
||||||
|
|
||||||
def __getattr__(self, key: str) -> pluggy.HookCaller:
|
|
||||||
hook: pluggy.HookCaller = getattr(self._hook_relay, key)
|
|
||||||
if key not in imply_paths_hooks:
|
|
||||||
self.__dict__[key] = hook
|
|
||||||
return hook
|
|
||||||
else:
|
|
||||||
path_var, fspath_var = imply_paths_hooks[key]
|
|
||||||
|
|
||||||
@functools.wraps(hook)
|
|
||||||
def fixed_hook(**kw: Any) -> Any:
|
|
||||||
path_value: Path | None = kw.pop(path_var, None)
|
|
||||||
fspath_value: LEGACY_PATH | None = kw.pop(fspath_var, None)
|
|
||||||
if fspath_value is not None:
|
|
||||||
warnings.warn(
|
|
||||||
HOOK_LEGACY_PATH_ARG.format(
|
|
||||||
pylib_path_arg=fspath_var, pathlib_path_arg=path_var
|
|
||||||
),
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
if path_value is not None:
|
|
||||||
if fspath_value is not None:
|
|
||||||
_check_path(path_value, fspath_value)
|
|
||||||
else:
|
|
||||||
fspath_value = legacy_path(path_value)
|
|
||||||
else:
|
|
||||||
assert fspath_value is not None
|
|
||||||
path_value = Path(fspath_value)
|
|
||||||
|
|
||||||
kw[path_var] = path_value
|
|
||||||
kw[fspath_var] = fspath_value
|
|
||||||
return hook(**kw)
|
|
||||||
|
|
||||||
fixed_hook.name = hook.name # type: ignore[attr-defined]
|
|
||||||
fixed_hook.spec = hook.spec # type: ignore[attr-defined]
|
|
||||||
fixed_hook.__name__ = key
|
|
||||||
self.__dict__[key] = fixed_hook
|
|
||||||
return fixed_hook # type: ignore[return-value]
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import final
|
|
||||||
|
|
||||||
|
|
||||||
@final
|
|
||||||
class UsageError(Exception):
|
|
||||||
"""Error in pytest usage or invocation."""
|
|
||||||
|
|
||||||
|
|
||||||
class PrintHelp(Exception):
|
|
||||||
"""Raised when pytest should print its help to skip the rest of the
|
|
||||||
argument parsing and validation."""
|
|
||||||
@@ -1,239 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections.abc import Iterable
|
|
||||||
from collections.abc import Sequence
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
import sys
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import iniconfig
|
|
||||||
|
|
||||||
from .exceptions import UsageError
|
|
||||||
from _pytest.outcomes import fail
|
|
||||||
from _pytest.pathlib import absolutepath
|
|
||||||
from _pytest.pathlib import commonpath
|
|
||||||
from _pytest.pathlib import safe_exists
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from typing_extensions import TypeAlias
|
|
||||||
|
|
||||||
# Even though TOML supports richer data types, all values are converted to str/list[str] during
|
|
||||||
# parsing to maintain compatibility with the rest of the configuration system.
|
|
||||||
ConfigDict: TypeAlias = dict[str, Union[str, list[str]]]
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_ini_config(path: Path) -> iniconfig.IniConfig:
|
|
||||||
"""Parse the given generic '.ini' file using legacy IniConfig parser, returning
|
|
||||||
the parsed object.
|
|
||||||
|
|
||||||
Raise UsageError if the file cannot be parsed.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return iniconfig.IniConfig(str(path))
|
|
||||||
except iniconfig.ParseError as exc:
|
|
||||||
raise UsageError(str(exc)) from exc
|
|
||||||
|
|
||||||
|
|
||||||
def load_config_dict_from_file(
|
|
||||||
filepath: Path,
|
|
||||||
) -> ConfigDict | None:
|
|
||||||
"""Load pytest configuration from the given file path, if supported.
|
|
||||||
|
|
||||||
Return None if the file does not contain valid pytest configuration.
|
|
||||||
"""
|
|
||||||
# Configuration from ini files are obtained from the [pytest] section, if present.
|
|
||||||
if filepath.suffix == ".ini":
|
|
||||||
iniconfig = _parse_ini_config(filepath)
|
|
||||||
|
|
||||||
if "pytest" in iniconfig:
|
|
||||||
return dict(iniconfig["pytest"].items())
|
|
||||||
else:
|
|
||||||
# "pytest.ini" files are always the source of configuration, even if empty.
|
|
||||||
if filepath.name == "pytest.ini":
|
|
||||||
return {}
|
|
||||||
|
|
||||||
# '.cfg' files are considered if they contain a "[tool:pytest]" section.
|
|
||||||
elif filepath.suffix == ".cfg":
|
|
||||||
iniconfig = _parse_ini_config(filepath)
|
|
||||||
|
|
||||||
if "tool:pytest" in iniconfig.sections:
|
|
||||||
return dict(iniconfig["tool:pytest"].items())
|
|
||||||
elif "pytest" in iniconfig.sections:
|
|
||||||
# If a setup.cfg contains a "[pytest]" section, we raise a failure to indicate users that
|
|
||||||
# plain "[pytest]" sections in setup.cfg files is no longer supported (#3086).
|
|
||||||
fail(CFG_PYTEST_SECTION.format(filename="setup.cfg"), pytrace=False)
|
|
||||||
|
|
||||||
# '.toml' files are considered if they contain a [tool.pytest.ini_options] table.
|
|
||||||
elif filepath.suffix == ".toml":
|
|
||||||
if sys.version_info >= (3, 11):
|
|
||||||
import tomllib
|
|
||||||
else:
|
|
||||||
import tomli as tomllib
|
|
||||||
|
|
||||||
toml_text = filepath.read_text(encoding="utf-8")
|
|
||||||
try:
|
|
||||||
config = tomllib.loads(toml_text)
|
|
||||||
except tomllib.TOMLDecodeError as exc:
|
|
||||||
raise UsageError(f"{filepath}: {exc}") from exc
|
|
||||||
|
|
||||||
result = config.get("tool", {}).get("pytest", {}).get("ini_options", None)
|
|
||||||
if result is not None:
|
|
||||||
# TOML supports richer data types than ini files (strings, arrays, floats, ints, etc),
|
|
||||||
# however we need to convert all scalar values to str for compatibility with the rest
|
|
||||||
# of the configuration system, which expects strings only.
|
|
||||||
def make_scalar(v: object) -> str | list[str]:
|
|
||||||
return v if isinstance(v, list) else str(v)
|
|
||||||
|
|
||||||
return {k: make_scalar(v) for k, v in result.items()}
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def locate_config(
|
|
||||||
invocation_dir: Path,
|
|
||||||
args: Iterable[Path],
|
|
||||||
) -> tuple[Path | None, Path | None, ConfigDict]:
|
|
||||||
"""Search in the list of arguments for a valid ini-file for pytest,
|
|
||||||
and return a tuple of (rootdir, inifile, cfg-dict)."""
|
|
||||||
config_names = [
|
|
||||||
"pytest.ini",
|
|
||||||
".pytest.ini",
|
|
||||||
"pyproject.toml",
|
|
||||||
"tox.ini",
|
|
||||||
"setup.cfg",
|
|
||||||
]
|
|
||||||
args = [x for x in args if not str(x).startswith("-")]
|
|
||||||
if not args:
|
|
||||||
args = [invocation_dir]
|
|
||||||
found_pyproject_toml: Path | None = None
|
|
||||||
for arg in args:
|
|
||||||
argpath = absolutepath(arg)
|
|
||||||
for base in (argpath, *argpath.parents):
|
|
||||||
for config_name in config_names:
|
|
||||||
p = base / config_name
|
|
||||||
if p.is_file():
|
|
||||||
if p.name == "pyproject.toml" and found_pyproject_toml is None:
|
|
||||||
found_pyproject_toml = p
|
|
||||||
ini_config = load_config_dict_from_file(p)
|
|
||||||
if ini_config is not None:
|
|
||||||
return base, p, ini_config
|
|
||||||
if found_pyproject_toml is not None:
|
|
||||||
return found_pyproject_toml.parent, found_pyproject_toml, {}
|
|
||||||
return None, None, {}
|
|
||||||
|
|
||||||
|
|
||||||
def get_common_ancestor(
|
|
||||||
invocation_dir: Path,
|
|
||||||
paths: Iterable[Path],
|
|
||||||
) -> Path:
|
|
||||||
common_ancestor: Path | None = None
|
|
||||||
for path in paths:
|
|
||||||
if not path.exists():
|
|
||||||
continue
|
|
||||||
if common_ancestor is None:
|
|
||||||
common_ancestor = path
|
|
||||||
else:
|
|
||||||
if common_ancestor in path.parents or path == common_ancestor:
|
|
||||||
continue
|
|
||||||
elif path in common_ancestor.parents:
|
|
||||||
common_ancestor = path
|
|
||||||
else:
|
|
||||||
shared = commonpath(path, common_ancestor)
|
|
||||||
if shared is not None:
|
|
||||||
common_ancestor = shared
|
|
||||||
if common_ancestor is None:
|
|
||||||
common_ancestor = invocation_dir
|
|
||||||
elif common_ancestor.is_file():
|
|
||||||
common_ancestor = common_ancestor.parent
|
|
||||||
return common_ancestor
|
|
||||||
|
|
||||||
|
|
||||||
def get_dirs_from_args(args: Iterable[str]) -> list[Path]:
|
|
||||||
def is_option(x: str) -> bool:
|
|
||||||
return x.startswith("-")
|
|
||||||
|
|
||||||
def get_file_part_from_node_id(x: str) -> str:
|
|
||||||
return x.split("::")[0]
|
|
||||||
|
|
||||||
def get_dir_from_path(path: Path) -> Path:
|
|
||||||
if path.is_dir():
|
|
||||||
return path
|
|
||||||
return path.parent
|
|
||||||
|
|
||||||
# These look like paths but may not exist
|
|
||||||
possible_paths = (
|
|
||||||
absolutepath(get_file_part_from_node_id(arg))
|
|
||||||
for arg in args
|
|
||||||
if not is_option(arg)
|
|
||||||
)
|
|
||||||
|
|
||||||
return [get_dir_from_path(path) for path in possible_paths if safe_exists(path)]
|
|
||||||
|
|
||||||
|
|
||||||
CFG_PYTEST_SECTION = "[pytest] section in {filename} files is no longer supported, change to [tool:pytest] instead."
|
|
||||||
|
|
||||||
|
|
||||||
def determine_setup(
|
|
||||||
*,
|
|
||||||
inifile: str | None,
|
|
||||||
args: Sequence[str],
|
|
||||||
rootdir_cmd_arg: str | None,
|
|
||||||
invocation_dir: Path,
|
|
||||||
) -> tuple[Path, Path | None, ConfigDict]:
|
|
||||||
"""Determine the rootdir, inifile and ini configuration values from the
|
|
||||||
command line arguments.
|
|
||||||
|
|
||||||
:param inifile:
|
|
||||||
The `--inifile` command line argument, if given.
|
|
||||||
:param args:
|
|
||||||
The free command line arguments.
|
|
||||||
:param rootdir_cmd_arg:
|
|
||||||
The `--rootdir` command line argument, if given.
|
|
||||||
:param invocation_dir:
|
|
||||||
The working directory when pytest was invoked.
|
|
||||||
"""
|
|
||||||
rootdir = None
|
|
||||||
dirs = get_dirs_from_args(args)
|
|
||||||
if inifile:
|
|
||||||
inipath_ = absolutepath(inifile)
|
|
||||||
inipath: Path | None = inipath_
|
|
||||||
inicfg = load_config_dict_from_file(inipath_) or {}
|
|
||||||
if rootdir_cmd_arg is None:
|
|
||||||
rootdir = inipath_.parent
|
|
||||||
else:
|
|
||||||
ancestor = get_common_ancestor(invocation_dir, dirs)
|
|
||||||
rootdir, inipath, inicfg = locate_config(invocation_dir, [ancestor])
|
|
||||||
if rootdir is None and rootdir_cmd_arg is None:
|
|
||||||
for possible_rootdir in (ancestor, *ancestor.parents):
|
|
||||||
if (possible_rootdir / "setup.py").is_file():
|
|
||||||
rootdir = possible_rootdir
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
if dirs != [ancestor]:
|
|
||||||
rootdir, inipath, inicfg = locate_config(invocation_dir, dirs)
|
|
||||||
if rootdir is None:
|
|
||||||
rootdir = get_common_ancestor(
|
|
||||||
invocation_dir, [invocation_dir, ancestor]
|
|
||||||
)
|
|
||||||
if is_fs_root(rootdir):
|
|
||||||
rootdir = ancestor
|
|
||||||
if rootdir_cmd_arg:
|
|
||||||
rootdir = absolutepath(os.path.expandvars(rootdir_cmd_arg))
|
|
||||||
if not rootdir.is_dir():
|
|
||||||
raise UsageError(
|
|
||||||
f"Directory '{rootdir}' not found. Check your '--rootdir' option."
|
|
||||||
)
|
|
||||||
assert rootdir is not None
|
|
||||||
return rootdir, inipath, inicfg or {}
|
|
||||||
|
|
||||||
|
|
||||||
def is_fs_root(p: Path) -> bool:
|
|
||||||
r"""
|
|
||||||
Return True if the given path is pointing to the root of the
|
|
||||||
file system ("/" on Unix and "C:\\" on Windows for example).
|
|
||||||
"""
|
|
||||||
return os.path.splitdrive(str(p))[1] == os.sep
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user